├── .gitignore ├── .travis.yml ├── README.md ├── ast_tools ├── __init__.py ├── common.py ├── cst_utils │ ├── __init__.py │ ├── deep_node.py │ └── insert_statements.py ├── macros.py ├── metadata │ ├── __init__.py │ ├── always_returns_provider.py │ └── condition_provider.py ├── passes │ ├── __init__.py │ ├── base.py │ ├── bool_to_bit.py │ ├── cse.py │ ├── debug.py │ ├── if_inline.py │ ├── if_to_phi.py │ ├── loop_unroll.py │ ├── remove_asserts.py │ ├── ssa.py │ └── util.py ├── pattern.py ├── stack.py ├── transformers │ ├── __init__.py │ ├── if_inliner.py │ ├── loop_unroller.py │ ├── node_replacer.py │ ├── node_tracker.py │ ├── normalizers.py │ ├── renamer.py │ └── symbol_replacer.py ├── utils.py └── visitors │ ├── __init__.py │ ├── collect_names.py │ ├── collect_targets.py │ ├── node_finder.py │ └── used_names.py ├── docs └── developer.md ├── setup.py ├── tests ├── test_apply_passes.py ├── test_assert_remover.py ├── test_bool_to_bit.py ├── test_common.py ├── test_cse.py ├── test_if_to_phi.py ├── test_immutable_ast.py ├── test_inline.py ├── test_normalizers.py ├── test_passes.py ├── test_pattern.py ├── test_ssa.py ├── test_stack.py ├── test_unroll.py └── test_visitors.py └── util └── generate_ast ├── __init__.py ├── _base.px ├── _functions.px ├── _meta.px └── generate.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | dist 3 | ast_tools.egg-info 4 | .ast_tools 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - '3.8' 4 | - '3.9' 5 | install: 6 | - pip install python-coveralls 7 | - pip install pytest-cov 8 | - pip install -e . 
9 | - echo $TRAVIS_PYTHON_VERSION 10 | script: py.test --cov=ast_tools tests/ 11 | after_success: coveralls 12 | before_deploy: 13 | # Hack to set python-tag in bdist_wheel since travis deploy doesn't accept 14 | # bdist_wheel arguments, see 15 | # https://stackoverflow.com/questions/52609945/how-do-i-add-a-python-tag-to-the-bdist-wheel-command-using-setuptools 16 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then echo -e "[bdist_wheel]\npython-tag=py37" > setup.cfg; fi 17 | - if [[ $TRAVIS_PYTHON_VERSION == 3.8 ]]; then echo -e "[bdist_wheel]\npython-tag=py38" > setup.cfg; fi 18 | 19 | deploy: 20 | skip_cleanup: true 21 | provider: pypi 22 | user: leonardt 23 | password: 24 | secure: fIKje89v2qyBZsjc+XbXegWjuxekMMz0iO/+pf112lLoBYoc0DER3CSH5ojXqHExviSX+d8TAQvrdDuX5Z13TWEX8VKMcbBjx0Z6YDDJ+Qnb45Pbc8rWhcuTscvT3e7ygPzFp/+9ZnA+OPHjRL8eDzVL0sfDZdkdytgs39ACVuKZSBtQs4ITYegeqxfy+l4MyAbFiLebhmOb9Y/6cIArPSaJLrEYQLfzJF03971yq0g4qukMRha/+OLEg/Hr3EkWfq0OVh5PC2dJ1c/nL0wYllR5Aflqh0+mXP6jWBjdvXQeuOEwIfhGp3uUzLJHbmnI53x9VsU7v1czI6Xgwmil/B9kpuudpGZ6742SU1dmwI6MFZ8Wj1EjmgZcMa21Cw2z+5mTdlwoylnl5KsYI5z7gptdD+qFulyUj9F/BwkQ5Wr6c1nrrLIsAgTs0YLbXXBZxrLKceJokkNX7Z+7CZv/dabms8LqmRW/OY52sDxA3Sx+DqbQjzZbv2DmM9rl4R4mg6JnE1xEFFW4YPqzNHkEpqgEUIi1XZBvZZJ+7rW/LZn35Mm9KGbyWEovCt4K1E2t9H2EfIXm3ech8RPE1OtI/rwGAsyc7RXoHvwq3x8LlcElN9I9RryE/ksQWjDwU/wKWM5A8/g0utCcLIM2EC0dmtnAIPvhsdrNtSSmeTJQKAM= 25 | distributions: bdist_wheel 26 | on: 27 | branch: master 28 | tags: true 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.com/leonardt/ast_tools.svg?branch=master)](https://travis-ci.com/leonardt/ast_tools) 2 | [![Coverage Status](https://coveralls.io/repos/github/leonardt/ast_tools/badge.svg?branch=master)](https://coveralls.io/github/leonardt/ast_tools?branch=master) 3 | 4 | Toolbox for working with the Python AST 5 | 6 | ``` 7 | pip install 
ast_tools 8 | ``` 9 | 10 | # Useful References 11 | * [Green Tree Snakes - the missing Python AST docs](https://greentreesnakes.readthedocs.io/) 12 | 13 | 14 | # Passes 15 | ast_tools provides a number of passes for rewriting functions and classes (passes could 16 | also work at the module level, but no such pass exists). Passes are 17 | applied with the `apply_passes` decorator: 18 | 19 | ```python 20 | @apply_passes([pass1(), pass2()]) 21 | def foo(...): ... 22 | ``` 23 | Each pass takes as arguments an AST, an environment, and metadata and 24 | returns (possibly) modified versions of each. 25 | `apply_passes` begins a chain of rewrites by first looking 26 | up the ast of the decorated object and then attempting to gather locals 27 | and globals from the call site to build the environment. 28 | 29 | After all rewrites have run, `apply_passes` serializes and 30 | executes the rewritten ast. 31 | 32 | ## Known Issues 33 | ### Collecting the AST 34 | `apply_passes` relies on `inspect.getsource` to get the 35 | source of the decorated definition (which is then parsed to get the initial ast). 36 | However, `inspect.getsource` has many limitations. 37 | 38 | ### Collecting the Environment 39 | `apply_passes` does its best to infer the environment; 40 | however, there is no way to do this in a fully correct way. Users are 41 | encouraged to pass the environment explicitly: 42 | ```python 43 | @apply_passes(..., env=SymbolTable(locals(), globals())) 44 | def foo(...): ... 45 | ``` 46 | 47 | ### Wrapping the apply_passes decorator 48 | The `apply_passes` decorator must not be wrapped. 49 | 50 | As decorators are a part of the AST of the object they are applied to, 51 | they must be removed from the rewritten AST before it is executed. If they 52 | are not removed, rewrites will recurse infinitely as 53 | 54 | ```python 55 | @apply_passes([...]) 56 | def foo(...): ...
57 | ``` 58 | 59 | would become 60 | 61 | ```python 62 | exec('''\ 63 | @apply_passes([...]) 64 | def rewritten_foo(...): ... 65 | ''') 66 | ``` 67 | Note: this would invoke `apply_passes([...])` on `rewritten_foo` 68 | 69 | To avoid this, the `apply_passes` decorator filters itself from the decorator list. If, however, 70 | the decorator is wrapped inside another decorator, this will fail. 71 | 72 | ### Inner decorators are called multiple times 73 | 74 | Decorators that are applied before a rewrite group will be called multiple times. 75 | See https://github.com/leonardt/ast_tools/issues/46 for a detailed explanation. 76 | To avoid this, users are encouraged to make rewrites the innermost decorators 77 | when possible. 78 | 79 | # Macros 80 | ## Loop Unrolling 81 | Unroll loops using the pattern 82 | ```python 83 | for <var> in ast_tools.macros.unroll(<iter>): 84 | ... 85 | ``` 86 | 87 | `<iter>` should be an iterable object that produces integers (e.g. `range(8)`) 88 | that can be evaluated at definition time (it can refer to variables in the scope 89 | of the function definition). 90 | 91 | For example, 92 | ```python 93 | import ast_tools 94 | from ast_tools.passes import apply_passes, loop_unroll 95 | @apply_passes([loop_unroll()]) 96 | def foo(): 97 | for i in ast_tools.macros.unroll(range(8)): 98 | print(i) 99 | ``` 100 | is rewritten into 101 | ```python 102 | def foo(): 103 | print(0) 104 | print(1) 105 | print(2) 106 | print(3) 107 | print(4) 108 | print(5) 109 | print(6) 110 | print(7) 111 | ``` 112 | 113 | You can also use a list of `int`s; here's an example that also uses a reference 114 | to a variable defined in the outer scope: 115 | ```python 116 | import ast_tools 117 | from ast_tools.passes import apply_passes, loop_unroll 118 | j = [1, 2, 3] 119 | @apply_passes([loop_unroll()]) 120 | def foo(): 121 | for i in ast_tools.macros.unroll(j): 122 | print(i) 123 | ``` 124 | becomes 125 | ```python 126 | def foo(): 127 | print(1) 128 | print(2) 129 | print(3) 130 | ``` 131 | 132 | ## Inlining If Statements 133
| This macro allows you to evaluate `if` statements at function definition time, 134 | so the resulting rewritten function will have the `if` statements marked 135 | "inlined" removed from the final code and replaced with the chosen branch based 136 | on evaluating the condition in the definition's enclosing scope. `if` 137 | statements are marked by using the form `if inline(...):` where `inline` is 138 | imported from the `ast_tools.macros` package. `if` statements not matching 139 | this pattern will be ignored by the rewrite logic. 140 | 141 | Here's an example 142 | ```python 143 | from ast_tools.macros import inline 144 | from ast_tools.passes import apply_passes, if_inline 145 | 146 | y = True 147 | 148 | @apply_passes([if_inline()]) 149 | def foo(x): 150 | if inline(y): 151 | return x + 1 152 | else: 153 | return x - 1 154 | 155 | 156 | import inspect 157 | assert inspect.getsource(foo) == f"""\ 158 | def foo(x): 159 | return x + 1 160 | """ 161 | ``` 162 | 163 | # Developing 164 | Interested in extending the library? Check out these [developer 165 | docs](https://github.com/leonardt/ast_tools/blob/master/docs/developer.md) 166 | -------------------------------------------------------------------------------- /ast_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ast_tools top level package 3 | """ 4 | from .common import * 5 | from . import immutable_ast 6 | from . import passes 7 | from . import stack 8 | from . import transformers 9 | from . 
import visitors  # tail of ``from . import visitors`` (ast_tools/__init__.py)
# ---- /ast_tools/common.py ----
import abc
import ast
import datetime
import functools
import inspect
import itertools
import logging
import os
import textwrap
import types
import typing as tp
import weakref

import astor

import libcst as cst

from ast_tools import stack
from ast_tools.stack import SymbolTable
from ast_tools.visitors import used_names
from ast_tools.cst_utils import to_module

# NOTE(review): Python only honours the lowercase ``__all__``.  This name has
# no effect, so ``from .common import *`` (as used by ast_tools/__init__.py)
# exports every public name in this module.  It is intentionally left as-is:
# renaming it would drop names such as ``to_source`` and ``gen_free_prefix``
# from the ``ast_tools`` namespace.
__ALL__ = ['exec_in_file', 'exec_def_in_file', 'exec_str_in_file', 'get_ast', 'get_cst', 'gen_free_name']

CSTDefStmt = tp.Union[
    cst.ClassDef,
    cst.FunctionDef,
]

ASTDefStmt = tp.Union[
    ast.AsyncFunctionDef,
    ast.ClassDef,
    ast.FunctionDef,
]

DefStmt = tp.Union[ASTDefStmt, CSTDefStmt]


def _timestamp() -> str:
    """
    Return the current local time formatted for use inside a file name.

    ``datetime.isoformat()`` separates the time fields with ``:``, which is
    not a legal file-name character on Windows, so ``-`` is used instead.
    """
    return datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S.%f')


def exec_def_in_file(
        tree: DefStmt,
        st: SymbolTable,
        path: tp.Optional[str] = None,
        file_name: tp.Optional[str] = None,
        serialized_tree: tp.Optional[DefStmt] = None,
        ) -> tp.Any:
    """
    Exec a definition in a file and return the object it defines.

    Args:
        tree: an ``ast`` or ``libcst`` class/function definition.
        st: symbol table supplying the globals for execution.
        path: output directory (see `exec_str_in_file`).
        file_name: output file name; defaults to
            ``<definition name>_<timestamp>.py``.
        serialized_tree: tree written to the file in place of ``tree``.

    Returns:
        The value bound to the definition's name after execution.

    For explanation of serialized_tree see
    https://github.com/leonardt/ast_tools/issues/46
    """
    tree_name = _get_name(tree)
    if file_name is None:
        file_name = f'{tree_name}_{_timestamp()}.py'

    return exec_in_file(tree, st, path, file_name, serialized_tree)[tree_name]


def _get_name(tree: DefStmt) -> str:
    """Return the defined name of an ``ast`` or ``libcst`` definition."""
    if isinstance(tree, ast.AST):
        return tree.name
    else:
        # libcst stores the name as a ``cst.Name`` node.
        return tree.name.value


def to_source(
        tree: DefStmt
        ) -> str:
    """Render an ``ast`` tree (via astor) or a ``libcst`` node to source."""
    if isinstance(tree, ast.AST):
        return astor.to_source(tree)
    else:
        return to_module(tree).code


def exec_in_file(
        tree: DefStmt,
        st: SymbolTable,
        path: tp.Optional[str] = None,
        file_name: tp.Optional[str] = None,
        serialized_tree: tp.Optional[DefStmt] = None,
        ) -> tp.MutableMapping[str, tp.Any]:
    """
    Exec a tree as a module and return the resulting environment.

    ``tree`` (and ``serialized_tree``, when given) are rendered with
    `to_source` and passed on to `exec_str_in_file`.

    For explanation of serialized_tree see
    https://github.com/leonardt/ast_tools/issues/46
    """
    source = to_source(tree)
    if serialized_tree is None:
        serialized_source = source
    else:
        serialized_source = to_source(serialized_tree)
    return exec_str_in_file(source, st, path, file_name, serialized_source)


def exec_str_in_file(
        source: str,
        st: SymbolTable,
        path: tp.Optional[str] = None,
        file_name: tp.Optional[str] = None,
        serialized_source: tp.Optional[str] = None,
        ) -> tp.MutableMapping[str, tp.Any]:
    """
    Exec a string as a module and return the resulting environment.

    ``serialized_source`` (default: ``source``) is written to
    ``path/file_name``; ``source`` is then compiled with that path as its
    file name and executed in a shallow ``dict`` copy of ``st``, so ``st``
    itself is not rebound.

    Args:
        source: the code to compile and execute.
        st: symbol table supplying the globals for execution.
        path: output directory, created if missing; defaults to
            ``.ast_tools``.
        file_name: output file name; defaults to
            ``ast_tools_exec_<timestamp>.py``.
        serialized_source: text written to the file in place of ``source``.

    Returns:
        The globals dict after execution.

    Raises:
        Whatever compiling or executing ``source`` raises; the error is
        logged before being re-raised.

    For explanation of serialized_source see
    https://github.com/leonardt/ast_tools/issues/46
    """
    if path is None:
        path = '.ast_tools'

    if file_name is None:
        file_name = f'ast_tools_exec_{_timestamp()}.py'

    if serialized_source is None:
        serialized_source = source

    file_name = os.path.join(path, file_name)
    os.makedirs(path, exist_ok=True)
    # Python decodes source files as UTF-8 by default (PEP 3120); write the
    # file that way regardless of the platform's locale encoding.
    with open(file_name, 'w', encoding='utf-8') as fp:
        fp.write(serialized_source)

    try:
        code = compile(source, filename=file_name, mode='exec')
    except Exception as e:
        logging.exception("Error compiling source")
        raise e from None

    st_dict = dict(st)
    try:
        exec(code, st_dict)
        return st_dict
    except Exception as e:
        logging.exception("Error executing code")
        raise e from None


_AST_CACHE: tp.MutableMapping[tp.Any, ast.AST] = weakref.WeakKeyDictionary()


def get_ast(obj) -> ast.AST:
    """
    Given an object, get the corresponding AST.

    Modules yield an ``ast.Module``; any other object yields the first
    statement parsed from its dedented source.  Results are cached in a
    weak-keyed mapping, so ``obj`` must support weak references.
    """
    try:
        return _AST_CACHE[obj]
    except KeyError:
        pass

    src = textwrap.dedent(inspect.getsource(obj))

    if isinstance(obj, types.ModuleType):
        tree = ast.parse(src)
    else:
        tree = ast.parse(src).body[0]

    return _AST_CACHE.setdefault(obj, tree)


_CST_CACHE: tp.MutableMapping[tp.Any, cst.CSTNode] = weakref.WeakKeyDictionary()


def get_cst(obj) -> cst.CSTNode:
    """
    Given an object, get the corresponding CST.

    Modules yield a ``cst.Module``; any other object is parsed as a single
    statement from its dedented source.  Results are cached in a weak-keyed
    mapping, so ``obj`` must support weak references.
    """
    try:
        return _CST_CACHE[obj]
    except KeyError:
        pass

    src = textwrap.dedent(inspect.getsource(obj))

    if isinstance(obj, types.ModuleType):
        tree = cst.parse_module(src)
    else:
        tree = cst.parse_statement(src)

    return _CST_CACHE.setdefault(obj, tree)


def is_free_name(tree: cst.CSTNode, env: SymbolTable, name: str):
    """Return True if ``name`` is neither used in ``tree`` nor a key of ``env``."""
    names = used_names(tree)
    return name not in names and name not in env


def is_free_prefix(tree: cst.CSTNode, env: SymbolTable, prefix: str):
    """Return True if no name used in ``tree`` or keyed in ``env`` starts with ``prefix``."""
    names = used_names(tree)
    return not any(
        name.startswith(prefix)
        for name in itertools.chain(names, env.keys()))


def gen_free_name(
        tree: cst.CSTNode,
        env: SymbolTable,
        prefix: tp.Optional[str] = None) -> str:
    """
    Return a name that is neither used in ``tree`` nor a key of ``env``.

    If ``prefix`` is given and already free it is returned unchanged;
    otherwise the result is ``prefix`` (default ``'_auto_name_'``) followed by
    the smallest non-negative integer that makes it free.
    """
    names = used_names(tree) | env.keys()
    if prefix is not None and prefix not in names:
        return prefix
    elif prefix is None:
        prefix = '_auto_name_'

    f_str = prefix + '{}'
    c = 0
    name = f_str.format(c)
    while name in names:
        c += 1
        name = f_str.format(c)

    return name


def gen_free_prefix(
        tree: cst.CSTNode,
        env: SymbolTable,
        preprefix: tp.Optional[str] = None) -> str:
    """
    Return a prefix that no name used in ``tree`` or keyed in ``env`` starts with.

    If ``preprefix`` is given and already free it is returned unchanged;
    otherwise the result is ``preprefix`` (default ``'_auto_prefix_'``)
    followed by the smallest non-negative integer that makes it free.
    """
    # The parameter is deliberately not called ``used_names`` so it does not
    # shadow the module-level ``used_names`` function.
    def check_prefix(prefix: str, taken: tp.AbstractSet[str]) -> bool:
        return not any(name.startswith(prefix) for name in taken)

    names = used_names(tree) | env.keys()

    if preprefix is not None and check_prefix(preprefix, names):
        return preprefix
    elif preprefix is None:
        preprefix = '_auto_prefix_'

    f_str = preprefix + '{}'
    c = 0
    prefix = f_str.format(c)
    while not check_prefix(prefix, names):
        c += 1
        prefix = f_str.format(c)

    return prefix


# ---- /ast_tools/cst_utils/__init__.py ----
import typing as tp

import libcst as cst

from .insert_statements import InsertStatementsVisitor
from .deep_node import DeepNode


_T = tp.Union[
    cst.BaseSuite,
    cst.BaseExpression,
    cst.BaseStatement,
    cst.BaseSmallStatement,
    cst.Module,
]


def to_module(node: _T) -> cst.Module:
    """
    Wrap ``node`` in a ``cst.Module`` (e.g. so it can be rendered via ``.code``).

    - A ``Module`` is returned unchanged.
    - A suite contributes its statements as the module body.
    - An expression is wrapped in ``Expr``.  Small statements, including that
      ``Expr``, are wrapped in a ``SimpleStatementLine``, because a module body
      holds whole statement lines; bare small statements would be rendered
      without their separators and trailing newline.
    - Any other statement becomes the module's only statement.

    Raises:
        TypeError: if ``node`` is none of the above.
    """
    if isinstance(node, cst.Module):
        return node

    if isinstance(node, cst.SimpleStatementSuite):
        # A suite's body is a sequence of small statements.
        return cst.Module(body=[cst.SimpleStatementLine(body=node.body)])
    elif isinstance(node, cst.IndentedBlock):
        return cst.Module(body=node.body)

    if isinstance(node, cst.BaseExpression):
        node = cst.Expr(value=node)

    if isinstance(node, cst.BaseSmallStatement):
        node = cst.SimpleStatementLine(body=[node])

    if isinstance(node, cst.BaseStatement):
        return cst.Module(body=[node])

    raise TypeError(f'{node} :: {type(node)} cannot be cast to Module')


def to_stmt(node: cst.BaseSmallStatement) -> cst.SimpleStatementLine:
    """Wrap a single small statement in its own statement line."""
    return cst.SimpleStatementLine(body=[node])


def make_assign(
        lhs: cst.BaseAssignTargetExpression,
        rhs: cst.BaseExpression,
        ) -> cst.Assign:
    """Build the small statement ``lhs = rhs``."""
    return cst.Assign(
        targets=[cst.AssignTarget(lhs), ],
        value=rhs,
    )


# ---- /ast_tools/cst_utils/deep_node.py ----
from abc import ABCMeta, abstractmethod
import dataclasses
import functools as ft
import itertools as it
import typing as tp

import libcst as cst


class FieldRemover(cst.CSTTransformer):
    """
    Transformer that rebuilds every node without selected dataclass fields.

    On leaving a node, the node is reconstructed from only the fields for
    which `skip_field` returns False.  Skipped fields therefore take their
    dataclass defaults, so only fields that have defaults can be skipped.
    """

    @abstractmethod
    def skip_field(self, field: dataclasses.Field) -> bool:
        """Return True if ``field`` should be reset to its default."""

    def on_leave(self,
                 original_node: cst.CSTNode,
                 updated_node: cst.CSTNode,
                 ) -> tp.Union[cst.CSTNode, cst.RemovalSentinel]:
        saved_fields = {
            field.name: getattr(updated_node, field.name)
            for field in dataclasses.fields(updated_node)
            if not self.skip_field(field)
        }
        final_node = type(updated_node)(**saved_fields)
        return super().on_leave(original_node, final_node)


# Node types treated as pure formatting by WhiteSpaceNormalizer.
_WHITE_SPACE_TYPES: tp.FrozenSet[tp.Type[cst.CSTNode]] = frozenset((
    cst.Comment,
    cst.EmptyLine,
    cst.Newline,
    cst.ParenthesizedWhitespace,
    cst.SimpleWhitespace,
    cst.TrailingWhitespace,
    cst.BaseParenthesizableWhitespace,
))


# ``Sequence[T]`` annotations for each whitespace type, so fields such as
# ``leading_lines: Sequence[EmptyLine]`` are matched too.
_WHITE_SPACE_SEQUENCE_TYPES: tp.FrozenSet[tp.Any] = frozenset(
    tp.Sequence[t] for t in _WHITE_SPACE_TYPES
)


class WhiteSpaceNormalizer(FieldRemover):
    """Reset every field whose declared type is a whitespace node (or a sequence of them)."""

    def skip_field(self, field: dataclasses.Field) -> bool:
        t = field.type
        return t in _WHITE_SPACE_TYPES or t in _WHITE_SPACE_SEQUENCE_TYPES


_PAREN_NAMES = ('lpar', 'rpar')


class StripParens(FieldRemover):
    """Reset the ``lpar``/``rpar`` fields, dropping explicit parentheses."""

    def skip_field(self, field: dataclasses.Field) -> bool:
        return field.name in _PAREN_NAMES


def _normalize(node: cst.CSTNode):
    """Return ``node`` with parentheses and whitespace stripped, type-checked deeply."""
    node = node.visit(StripParens())
    node = node.visit(WhiteSpaceNormalizer())
    node.validate_types_deep()
    return node


@ft.lru_cache(maxsize=2048)
def _deep_hash(node: cst.CSTNode):
    """
    Structural hash of ``node``, intended to agree with ``deep_equals``.

    Combines the node type, the node's ``evaluated_value`` when it has a
    hashable one, and the hashes of its children weighted by position.
    """
    h = hash(type(node))
    try:
        h += hash(node.evaluated_value)
    except Exception:
        # Most nodes have no ``evaluated_value`` (AttributeError), and some
        # values may not be hashable.  Only Exception is caught, so
        # KeyboardInterrupt/SystemExit still propagate.
        pass

    for i, c in enumerate(node.children):
        h += (1 + i) * _deep_hash(c)

    return h


class DeepNode(tp.Generic[cst.CSTNodeT], tp.Hashable):
    """
    Hashable wrapper that compares CST nodes structurally.

    Equality is ``deep_equals`` on the normalized nodes (parentheses and
    whitespace removed); the hash is `_deep_hash` of the normalized node.
    """
    # Note because of:
    # https://github.com/Instagram/LibCST/issues/341
    # the normalized node may not be equivelent to original node
    # as parens are removed from the normalized node
    original_node: cst.CSTNode
    normal_node: cst.CSTNode
    _hash: int

    def __init__(self, node: cst.CSTNode):
        self.original_node = node
        self.normal_node = norm = _normalize(node)
        self._hash = _deep_hash(norm)

    def __eq__(self, other: 'DeepNode') -> bool:
        if isinstance(other, DeepNode):
            return self.normal_node.deep_equals(other.normal_node)
        else:
            return NotImplemented

    def __ne__(self, other: 'DeepNode') -> bool:
        if isinstance(other, DeepNode):
            return not self.normal_node.deep_equals(other.normal_node)
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return self._hash


# ---- /ast_tools/cst_utils/insert_statements.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from dataclasses import dataclass
from typing import List, Optional, Sequence, Set, Union

import libcst as cst
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer


@dataclass
class StatementContext:
    # List of statements to insert before the current statement
    before_stmts: List[cst.BaseStatement]

    # List of statements to insert after the current statement
    after_stmts: List[cst.BaseStatement]

    # Whether the statement's leading lines are hoisted onto the next
    # statement emitted in the enclosing block; cleared by dont_keep_comments
    keep_comments: bool


@dataclass
class BlockContext:
    # Ordered list of all statements accumulated into the new block body
    new_body: List[cst.BaseStatement]

    # Set of all inserted statements (through insert_* methods)
    added_stmts: Set[cst.BaseStatement]

    # Leading lines taken from already-left statements, waiting to be
    # attached to the next statement emitted in this block
    hanging_lines: List[cst.EmptyLine]


@dataclass
class InsertStatementsVisitorContext:
    """
    Context for the InsertStatementsVisitor about which statements
    have been requested to insert before/after the current one.
    """

    # Stack of contexts for statements
    ctx_stmt: List[StatementContext]

    # Stack of contexts for blocks
    ctx_block: List[BlockContext]


def _pass_suite(suite: cst.BaseSuite) -> cst.BaseSuite:
    """Return ``suite`` with its statements replaced by a single ``pass``."""
    if isinstance(suite, cst.SimpleStatementSuite):
        return suite.with_changes(body=[cst.Pass()])
    return suite.with_changes(
        body=[cst.SimpleStatementLine(body=[cst.Pass()])])


class RemoveEmptyBlocks(cst.CSTTransformer):
    """
    Clean up statements left empty by a transformation.

    - A ``SimpleStatementLine`` with no statements is removed.
    - An ``else``/``elif`` branch whose body is empty is dropped.
    - An ``if``/``for`` whose body is empty is removed when it has no
      remaining ``else``/``elif`` branch.  Otherwise its body becomes ``pass``,
      so the still-live branch is not silently discarded.
    """

    def leave_SimpleStatementLine(self, original_node, updated_node):
        final_node = super().leave_SimpleStatementLine(original_node,
                                                       updated_node)
        if len(final_node.body) == 0:
            return cst.RemoveFromParent()
        return final_node

    def leave_If(self, original_node, updated_node):
        final_node = super().leave_If(original_node, updated_node)
        if isinstance(final_node, cst.RemovalSentinel):
            return final_node

        if (final_node.orelse is not None
                and len(final_node.orelse.body.body) == 0):
            final_node = final_node.with_changes(orelse=None)

        if len(final_node.body.body) == 0:
            if final_node.orelse is None:
                return cst.RemoveFromParent()
            # Removing the whole statement would drop the live else/elif
            # branch; keep the test with a no-op body instead.
            final_node = final_node.with_changes(
                body=_pass_suite(final_node.body))

        return final_node

    def leave_For(self, original_node, updated_node):
        final_node = super().leave_For(original_node, updated_node)
        if isinstance(final_node, cst.RemovalSentinel):
            return final_node

        if (final_node.orelse is not None
                and len(final_node.orelse.body.body) == 0):
            final_node = final_node.with_changes(orelse=None)

        if len(final_node.body.body) == 0:
            if final_node.orelse is None:
                return cst.RemoveFromParent()
            # The ``else`` of a loop runs after normal termination; keep
            # the loop (with a no-op body) so that branch still executes.
            final_node = final_node.with_changes(
                body=_pass_suite(final_node.body))

        return final_node


class InsertStatementsVisitor(ContextAwareTransformer, RemoveEmptyBlocks):
    """
    Allows transformers to insert multiple statements before and after the currently-visited statement.

    This class is a mixin for :class:`~libcst.codemod.ContextAwareTransformer`. Subclasses gain the methods :meth:`insert_statements_before_current` and :meth:`insert_statements_after_current`. For example, you can create a pass that inserts print statements before each use of a variable::

        from ast_tools.cst_utils import InsertStatementsVisitor
        from libcst.metadata import ExpressionContextProvider, ExpressionContext
        class InsertPrintVisitor(InsertStatementsVisitor):
            METADATA_DEPENDENCIES = (ExpressionContextProvider,)

            def __init__(self, context: CodemodContext, name: str) -> None:
                super().__init__(context)
                self.name = name

            def visit_Name(self, node: cst.Name) -> None:
                if (
                    node.value == self.name
                    and self.get_metadata(ExpressionContextProvider, node)
                    == ExpressionContext.LOAD
                ):
                    self.insert_statements_before_current(
                        [cst.parse_statement(f"print({self.name})")]
                    )

    After initializing this visitor with ``name = "y"``, it will take this code::

        y = 1
        x = y

    And transform it into this code::

        y = 1
        print(y)
        x = y

    Statements left empty by the transformation are cleaned up by
    :class:`RemoveEmptyBlocks`.

    You **must** call ``super()`` methods if you override any visit or leave method for: Module, IndentedBlock, SimpleStatementLine, If, Try, FunctionDef, ClassDef, With, For, While.
    """

    CONTEXT_KEY = "InsertStatementsVisitor"

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        self.context.scratch[InsertStatementsVisitor.CONTEXT_KEY] = (
            InsertStatementsVisitorContext([], []))

    def _context(self) -> InsertStatementsVisitorContext:
        """Return this visitor's statement/block stacks from the codemod scratch space."""
        return self.context.scratch[InsertStatementsVisitor.CONTEXT_KEY]

    def insert_statements_before_current(
            self, stmts: List[cst.BaseStatement]) -> None:
        """
        Inserts a list of statements before the currently visited statement.

        While traversing the AST, the InsertStatementVisitor collects a stack of visited statements. For example, in the snippet::

            if y:
                x = 1

        When visiting ``y`` in ``visit_Name``, the current statement is the ``if``. When visiting ``1`` in ``visit_Integer``, the current statement is the ``x = 1`` assignment. Calling ``insert_statements_before_current`` will add a list of statements to be inserted before ``current``, which is handled in :meth:`leave_Module` and :meth:`leave_IndentedBlock`.
        """
        ctx = self._context()
        # Both a statement (to attach to) and an enclosing block (to splice
        # into) are required; without a block the statements would be lost.
        assert len(ctx.ctx_block) > 0 and len(ctx.ctx_stmt) > 0, (
            "InsertStatementVisitor is inserting a statement before having "
            "entered a statement")
        ctx.ctx_stmt[-1].before_stmts.extend(stmts)

    def insert_statements_after_current(
            self, stmts: List[cst.BaseStatement]) -> None:
        """
        Inserts a list of statements after the currently visited statement.

        See :meth:`insert_statements_before_current` for details.
        """
        ctx = self._context()
        assert len(ctx.ctx_block) > 0 and len(ctx.ctx_stmt) > 0, (
            "InsertStatementVisitor is inserting a statement before having "
            "entered a statement")
        ctx.ctx_stmt[-1].after_stmts.extend(stmts)

    def dont_keep_comments(self) -> None:
        """Do not hoist the current statement's leading lines when it is left."""
        ctx = self._context()
        ctx.ctx_stmt[-1].keep_comments = False

    def reattach_comments(self, original_node, new_node):
        """
        Prepend ``original_node``'s leading lines to ``new_node``.

        If ``new_node`` is a list, only its first element receives them; an
        empty list is returned unchanged.
        """
        if isinstance(new_node, list):
            if len(new_node) == 0:
                return new_node
            first_node = new_node[0]
            return [
                first_node.with_changes(
                    leading_lines=list(original_node.leading_lines) +
                    list(first_node.leading_lines))
            ] + new_node[1:]
        else:
            return new_node.with_changes(
                leading_lines=list(original_node.leading_lines) +
                list(new_node.leading_lines))

    def _visit_block(self) -> None:
        """Push a fresh, empty block context."""
        ctx = self._context()
        ctx.ctx_block.append(
            BlockContext(new_body=[], added_stmts=set(), hanging_lines=[]))

    def visit_IndentedBlock(self, node: cst.IndentedBlock) -> Optional[bool]:
        self._visit_block()
        return super().visit_IndentedBlock(node)

    def _leave_block(self, updated_body: Sequence[cst.BaseStatement]
                     ) -> List[cst.BaseStatement]:
        """
        Pop the current block context and return its rebuilt body.

        Keeps, in accumulated order, each statement that is still present in
        ``updated_body`` or was inserted through an insert_* method.
        """
        ctx = self._context()
        ctx_block = ctx.ctx_block.pop()
        # Statements are already kept in a set (added_stmts), so they are
        # hashable; a set avoids a linear scan per accumulated statement.
        surviving = set(updated_body)
        return [
            stmt for stmt in ctx_block.new_body
            if stmt in surviving or stmt in ctx_block.added_stmts
        ]

    def leave_IndentedBlock(self, original_node: cst.IndentedBlock,
                            updated_node: cst.IndentedBlock) -> cst.BaseSuite:
        final_node = super().leave_IndentedBlock(original_node, updated_node)
        if isinstance(final_node, cst.IndentedBlock):
            new_body = self._leave_block(final_node.body)
            return final_node.with_changes(body=new_body)
        # The block was replaced by something else: still pop its context so
        # the block stack stays balanced.
        self._context().ctx_block.pop()
        return final_node

    def visit_Module(self, node: cst.Module) -> Optional[bool]:
        self._visit_block()
        return super().visit_Module(node)

    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        final_node = super().leave_Module(original_node, updated_node)
        new_body = self._leave_block(final_node.body)
        return final_node.with_changes(body=new_body)

    def _visit_stmt(self, node: cst.BaseStatement) -> None:
        """Push a fresh statement context for ``node``."""
        ctx = self._context()
        ctx.ctx_stmt.append(
            StatementContext(before_stmts=[],
                             after_stmts=[],
                             keep_comments=True))

    def _add_hanging_lines(self, node: cst.BaseStatement,
                           ctx: BlockContext) -> cst.BaseStatement:
        """Prepend the block's pending hanging lines to ``node`` and clear them."""
        if len(ctx.hanging_lines) > 0 and hasattr(node, 'leading_lines'):
            new_node = node.with_changes(leading_lines=ctx.hanging_lines +
                                         list(node.leading_lines))
            ctx.hanging_lines.clear()
            return new_node
        return node

    def _leave_stmt(
            self,
            original_node: cst.BaseStatement,
            final_node: Union[cst.BaseStatement, cst.RemovalSentinel],
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        """
        Pop the current statement context and emit its statements.

        Appends to the enclosing block's new body, in order: the statements
        queued before the current one, the statement itself unless it was
        removed, and the statements queued after it.  Returns ``final_node``.
        """
        ctx = self._context()
        ctx_stmt = ctx.ctx_stmt.pop()
        if not ctx.ctx_block:
            # No enclosing block (e.g. the root statement being visited is a
            # lone FunctionDef): nothing to splice into.
            return final_node

        ctx_block = ctx.ctx_block[-1]

        should_insert = not isinstance(final_node, cst.RemovalSentinel)
        if hasattr(original_node, 'leading_lines') and ctx_stmt.keep_comments:
            # Hoist the statement's leading lines so they precede the first
            # statement emitted for it, or the next one if it was removed.
            ctx_block.hanging_lines.extend(original_node.leading_lines)

            # NOTE(review): confirm this strip is meant to apply only when
            # the lines were hoisted (i.e. that it belongs in this branch).
            if should_insert:
                final_node = final_node.with_changes(leading_lines=[])

        if ctx_stmt.before_stmts:
            before_stmts = [
                self._add_hanging_lines(ctx_stmt.before_stmts[0], ctx_block)
            ] + ctx_stmt.before_stmts[1:]
            ctx_block.new_body.extend(before_stmts)
            ctx_block.added_stmts.update(before_stmts)

        if should_insert:
            final_node = self._add_hanging_lines(final_node, ctx_block)
            ctx_block.new_body.append(final_node)

        if ctx_stmt.after_stmts:
            ctx_block.new_body.extend(ctx_stmt.after_stmts)
            ctx_block.added_stmts.update(ctx_stmt.after_stmts)

        return final_node

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine
                                  ) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_SimpleStatementLine(node)

    def leave_SimpleStatementLine(
            self,
            original_node: cst.SimpleStatementLine,
            updated_node: cst.SimpleStatementLine,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_SimpleStatementLine(original_node,
                                                       updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_If(self, node: cst.If) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_If(node)

    def leave_If(
            self,
            original_node: cst.If,
            updated_node: cst.If,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_If(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_Try(self, node: cst.Try) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_Try(node)

    def leave_Try(
            self,
            original_node: cst.Try,
            updated_node: cst.Try,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_Try(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_FunctionDef(node)

    def leave_FunctionDef(
            self,
            original_node: cst.FunctionDef,
            updated_node: cst.FunctionDef,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_FunctionDef(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_ClassDef(node)

    def leave_ClassDef(
            self,
            original_node: cst.ClassDef,
            updated_node: cst.ClassDef,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_ClassDef(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_With(self, node: cst.With) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_With(node)

    def leave_With(
            self,
            original_node: cst.With,
            updated_node: cst.With,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_With(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_For(self, node: cst.For) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_For(node)

    def leave_For(
            self,
            original_node: cst.For,
            updated_node: cst.For,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_For(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)

    def visit_While(self, node: cst.While) -> Optional[bool]:
        self._visit_stmt(node)
        return super().visit_While(node)

    def leave_While(
            self,
            original_node: cst.While,
            updated_node: cst.While,
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        final_node = super().leave_While(original_node, updated_node)
        return self._leave_stmt(original_node, final_node)


# ---- /ast_tools/macros.py ----
class unroll:
    def __init__(self, _iter):
        self._iter = _iter

    def __iter__(self):
        return
iter(self._iter) 7 | 8 | class inline: 9 | def __init__(self, cond): 10 | self._cond = cond 11 | 12 | def __bool__(self): 13 | return self._cond 14 | 15 | -------------------------------------------------------------------------------- /ast_tools/metadata/__init__.py: -------------------------------------------------------------------------------- 1 | from .always_returns_provider import AlwaysReturnsProvider 2 | from .condition_provider import ConditionProvider, IncrementalConditionProvider 3 | -------------------------------------------------------------------------------- /ast_tools/metadata/always_returns_provider.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | class AlwaysReturnsProvider(cst.BatchableMetadataProvider[bool]): 6 | def _visit_simple_block(self, 7 | node: tp.Union[cst.SimpleStatementLine, cst.SimpleStatementSuite] 8 | ) -> tp.Optional[bool]: 9 | for child in node.body: 10 | if isinstance(child, cst.Return): 11 | self.set_metadata(node, True) 12 | return False 13 | self.set_metadata(node, False) 14 | return False 15 | 16 | def visit_SimpleStatementLine(self, 17 | node: cst.SimpleStatementLine, 18 | ) -> tp.Optional[bool]: 19 | return self._visit_simple_block(node) 20 | 21 | def visit_SimpleStatementSuite(self, 22 | node: cst.SimpleStatementLine, 23 | ) -> tp.Optional[bool]: 24 | return self._visit_simple_block(node) 25 | 26 | def leave_IndentedBlock(self, node: cst.IndentedBlock) -> None: 27 | for child in node.body: 28 | if self.get_metadata(type(self), child, False): 29 | self.set_metadata(node, True) 30 | return 31 | self.set_metadata(node, False) 32 | 33 | def leave_If(self, node: cst.If) -> None: 34 | if node.orelse is None: 35 | self.set_metadata(node, False) 36 | else: 37 | self.set_metadata(node, 38 | self.get_metadata(type(self), node.body, False) 39 | and self.get_metadata(type(self), node.orelse, False)) 40 | 41 | def leave_Else(self, node: 
cst.Else) -> None: 42 | self.set_metadata(node, self.get_metadata(type(self), node.body, False)) 43 | 44 | -------------------------------------------------------------------------------- /ast_tools/metadata/condition_provider.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from . import AlwaysReturnsProvider 6 | 7 | 8 | class ConditionProvider(cst.VisitorMetadataProvider): 9 | ''' 10 | Marks each node with the conditions under which they will be executed 11 | ''' 12 | cond_stack: tp.List[cst.BaseExpression] 13 | 14 | def __init__(self, simplify: bool = False): 15 | self.cond_stack = [] 16 | 17 | def on_leave(self, node: cst.CSTNode) -> None: 18 | self.set_metadata(node, tuple(self.cond_stack)) 19 | return super().on_leave(node) 20 | 21 | def visit_If_body(self, node: cst.If) -> None: 22 | self.cond_stack.append(node.test) 23 | 24 | def leave_If_body(self, node: cst.If) -> None: 25 | self.cond_stack.pop() 26 | 27 | def visit_If_orelse(self, node: cst.If) -> None: 28 | self.cond_stack.append(cst.UnaryOperation(cst.Not(), node.test)) 29 | 30 | def leave_If_orelse(self, node: cst.If) -> None: 31 | self.cond_stack.pop() 32 | 33 | def visit_IfExp_body(self, node: cst.IfExp) -> None: 34 | self.cond_stack.append(node.test) 35 | 36 | def leave_IfExp_body(self, node: cst.If) -> None: 37 | self.cond_stack.pop() 38 | 39 | def visit_IfExp_orelse(self, node: cst.IfExp) -> None: 40 | self.cond_stack.append(cst.UnaryOperation(cst.Not(), node.test)) 41 | 42 | def leave_IfExp_orelse(self, node: cst.IfExp) -> None: 43 | self.cond_stack.pop() 44 | 45 | 46 | class IncrementalConditionProvider(ConditionProvider): 47 | ''' 48 | Condition provider which implicitly negates previous conditions if 49 | they are not explicitly listed. Used in SSA to generate a "minimal" 50 | ite structures. 
51 | 52 | Consider: 53 | ``` 54 | if x: 55 | return 0 56 | else: 57 | return 1 58 | return 2 59 | ``` 60 | using the normal ConditonProvider ssa would generate the following: 61 | ``` 62 | return 0 if x else (1 if not x else 2) 63 | ``` 64 | However, do the structure of the program we can see that this can be 65 | simplified to: 66 | ``` 67 | return 0 if x else 1 68 | ``` 69 | ''' 70 | 71 | METADATA_DEPENDENCIES = (AlwaysReturnsProvider,) 72 | 73 | 74 | def visit_If_orelse(self, node: cst.If) -> None: 75 | if not self.get_metadata(AlwaysReturnsProvider, node.body): 76 | super().visit_If_orelse(node) 77 | 78 | def leave_If_orelse(self, node: cst.If) -> None: 79 | if not self.get_metadata(AlwaysReturnsProvider, node.body): 80 | super().leave_If_orelse(node) 81 | -------------------------------------------------------------------------------- /ast_tools/passes/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * # This MUST be first 2 | 3 | from .bool_to_bit import bool_to_bit 4 | from .debug import debug 5 | from .if_inline import if_inline 6 | from .if_to_phi import if_to_phi 7 | from .loop_unroll import loop_unroll 8 | from .remove_asserts import remove_asserts 9 | from .ssa import ssa 10 | from .util import apply_passes, apply_ast_passes, apply_cst_passes 11 | -------------------------------------------------------------------------------- /ast_tools/passes/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import typing as tp 3 | 4 | import libcst as cst 5 | 6 | from ast_tools.stack import SymbolTable 7 | 8 | __ALL__ = ['Pass', 'PASS_ARGS_T'] 9 | 10 | PASS_ARGS_T = tp.Tuple[cst.CSTNode, SymbolTable, tp.MutableMapping] 11 | 12 | 13 | class Pass(metaclass=ABCMeta): 14 | """ 15 | Abstract base class for passes 16 | Mostly a convience to unpack arguments 17 | """ 18 | 19 | def __call__(self, args: PASS_ARGS_T) -> 
PASS_ARGS_T: 20 | return self.rewrite(*args) 21 | 22 | @abstractmethod 23 | def rewrite(self, 24 | tree: cst.CSTNode, 25 | env: SymbolTable, 26 | metadata: tp.MutableMapping, 27 | ) -> PASS_ARGS_T: 28 | return tree, env, metadata 29 | -------------------------------------------------------------------------------- /ast_tools/passes/bool_to_bit.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from . import Pass 6 | from . import PASS_ARGS_T 7 | 8 | from ast_tools.stack import SymbolTable 9 | 10 | __ALL__ = ['bool_to_bit'] 11 | 12 | class BoolOpTransformer(cst.CSTTransformer): 13 | match: tp.Union[ 14 | tp.Sequence[tp.Type[cst.BaseBooleanOp]], 15 | tp.Type[cst.BaseBooleanOp], 16 | ] 17 | 18 | replace: tp.Type[cst.BaseBinaryOp] 19 | 20 | def leave_BooleanOperation( 21 | self, 22 | original_node: cst.BooleanOperation, 23 | updated_node: cst.BooleanOperation) -> cst.BinaryOperation: 24 | if isinstance(updated_node.operator, self.match): 25 | return cst.BinaryOperation( 26 | left=updated_node.left, 27 | operator=self.replace(), 28 | right=updated_node.right, 29 | lpar=updated_node.lpar, 30 | rpar=updated_node.rpar 31 | ) 32 | else: 33 | return updated_node 34 | 35 | 36 | class AndTransformer(BoolOpTransformer): 37 | match = cst.And 38 | replace = cst.BitAnd 39 | 40 | 41 | class OrTransformer(BoolOpTransformer): 42 | match = cst.Or 43 | replace = cst.BitOr 44 | 45 | 46 | class NotTransformer(cst.CSTTransformer): 47 | def leave_Not( 48 | self, 49 | original_node: cst.Not, 50 | updated_node: cst.Not) -> cst.BitInvert: 51 | return cst.BitInvert() 52 | 53 | 54 | class bool_to_bit(Pass): 55 | ''' 56 | Pass to replace bool operators (and, or, not) 57 | with bit operators (&, |, ~) 58 | ''' 59 | def __init__(self, 60 | replace_and: bool = True, 61 | replace_or: bool = True, 62 | replace_not: bool = True, 63 | ): 64 | self.replace_and = replace_and 65 | self.replace_or = replace_or 66 
| self.replace_not = replace_not 67 | 68 | def rewrite(self, 69 | tree: cst.CSTNode, 70 | env: SymbolTable, 71 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 72 | if self.replace_and: 73 | visitor = AndTransformer() 74 | tree = tree.visit(visitor) 75 | 76 | if self.replace_or: 77 | visitor = OrTransformer() 78 | tree = tree.visit(visitor) 79 | 80 | if self.replace_not: 81 | visitor = NotTransformer() 82 | tree = tree.visit(visitor) 83 | 84 | return tree, env, metadata 85 | -------------------------------------------------------------------------------- /ast_tools/passes/cse.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import astor 3 | from copy import deepcopy 4 | from collections import Counter 5 | import typing as tp 6 | import warnings 7 | 8 | from . import Pass 9 | from . import PASS_ARGS_T 10 | 11 | from ast_tools.common import gen_free_prefix, gen_free_name, is_free_name 12 | from ast_tools.immutable_ast import immutable, mutable 13 | from ast_tools.stack import SymbolTable 14 | from ast_tools.transformers.node_replacer import NodeReplacer 15 | from ast_tools.visitors.node_finder import NodeFinder 16 | 17 | __all__ = ['cse'] 18 | 19 | def _is_leaf_expr(node: ast.expr): 20 | assert isinstance(nod, ast.expr) 21 | 22 | return isinstance(node,( 23 | ast.Attribute, 24 | ast.Constant, 25 | ast.Name, 26 | ast.NameConstant, 27 | ast.Num, 28 | ast.Subscript, 29 | )) 30 | 31 | class ExprKeyGetter: 32 | @staticmethod 33 | def _get_key(node: ast.AST): 34 | if isinstance(node, ast.expr): 35 | # Need the immutable value so its comparable 36 | return immutable(node) 37 | else: 38 | return None 39 | 40 | 41 | class ExprFinder(ExprKeyGetter, NodeFinder): pass 42 | 43 | 44 | class ExprReplacer(ExprKeyGetter, NodeReplacer): pass 45 | 46 | 47 | class ExprCounter(ast.NodeVisitor): 48 | def __init__(self, count_calls: bool): 49 | self.cses = Counter() 50 | self.count_calls = count_calls 51 | 52 | def visit_UnaryOp(self, 
node: ast.UnaryOp): 53 | self.cses[immutable(node)] += 1 54 | self.generic_visit(node) 55 | 56 | def visit_BinOp(self, node: ast.BinOp): 57 | self.cses[immutable(node)] += 1 58 | self.generic_visit(node) 59 | 60 | def visit_BoolOp(self, node: ast.BoolOp): 61 | self.cses[immutable(node)] += 1 62 | self.generic_visit(node) 63 | 64 | def visit_Compare(self, node: ast.Compare): 65 | self.cses[immutable(node)] += 1 66 | self.generic_visit(node) 67 | 68 | def visit_IfExpr(self, node: ast.IfExp): 69 | self.cses[immutable(node)] += 1 70 | self.generic_visit(node) 71 | 72 | def vist_Call(self, node: ast.Call): 73 | if self.count_calls: 74 | self.cses[immutable(node)] += 1 75 | 76 | return self.generic_visit(node) 77 | 78 | 79 | class ExprSaver(ast.NodeTransformer): 80 | ''' 81 | Saves an expression in a variable then replaces 82 | future occurrences of that expression with the variable 83 | ''' 84 | # this could probably be more effecient by handling multiple exprs 85 | # at a time but this is simple 86 | 87 | def __init__(self, 88 | cse, 89 | cse_name): 90 | 91 | self.cse = cse 92 | self.cse_name = cse_name 93 | self.recorded = False 94 | self.root = None 95 | self.replacer = ExprReplacer({}) 96 | self.replacer.add_replacement(cse, ast.Name(cse_name, ast.Load())) 97 | 98 | def visit(self, node: ast.AST) -> ast.AST: 99 | # basically want to be able to visit a top level def 100 | # but don't want to generally recurse into them 101 | # also want to change behavior after recording the cse 102 | if self.root is None: 103 | self.root = node 104 | return super().generic_visit(node) 105 | elif self.recorded: 106 | return self.replacer.visit(node) 107 | else: 108 | return super().visit(node) 109 | 110 | 111 | def visit_Assign(self, node): 112 | assert not self.recorded 113 | finder = ExprFinder(self.cse) 114 | finder.visit(node) 115 | if finder.target is not None: 116 | self.recorded = True 117 | # save the expr into a variable 118 | save = ast.Assign( 119 | 
targets=[ast.Name(self.cse_name, ast.Store())], 120 | value=deepcopy(self.cse)) 121 | 122 | # eliminate the node from the expression 123 | stmt = self.replacer.visit(node) 124 | return [save, stmt] 125 | else: 126 | return super().generic_visit(node) 127 | 128 | # don't support control flow (assumes ssa) 129 | def visit_If(self, node: ast.If): 130 | raise SyntaxError(f"Cannot handle node {node}") 131 | 132 | def visit_For(self, node: ast.For): 133 | raise SyntaxError(f"Cannot handle node {node}") 134 | 135 | def visit_AsyncFor(self, node: ast.AsyncFor): 136 | raise SyntaxError(f"Cannot handle node {node}") 137 | 138 | def visit_While(self, node: ast.While): 139 | raise SyntaxError(f"Cannot handle node {node}") 140 | 141 | def visit_With(self, node: ast.With): 142 | raise SyntaxError(f"Cannot handle node {node}") 143 | 144 | def visit_AsyncWith(self, node: ast.AsyncWith): 145 | raise SyntaxError(f"Cannot handle node {node}") 146 | 147 | def visit_Try(self, node: ast.Try): 148 | raise SyntaxError(f"Cannot handle node {node}") 149 | 150 | # don't recurs into defs 151 | def visit_ClassDef(self, node: ast.ClassDef): 152 | return node 153 | 154 | def visit_FunctionDef(self, node: ast.FunctionDef): 155 | return node 156 | 157 | def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): 158 | return node 159 | 160 | 161 | class cse(Pass): 162 | ''' 163 | Performs common subexpression elimination 164 | 165 | cse_prefix controls the name of variable eliminated expressions are saved in 166 | This does not have any semantic effect. 167 | 168 | elim_calls controls whether calls repeated calls should be eliminated 169 | 170 | min_freq the minimum freq of an expression should have to be eliminated 171 | 172 | Must be run after ssa. 
173 | 174 | Post bool_to_bit will likely eliminate more as: 175 | `a and b and c` 176 | is a single `BoolOp` but 177 | `a & b & c` 178 | is: 179 | `(a & b) & c` 180 | 181 | this means `a and b` is not a subexpression of `a and b and c` 182 | but `a & b` is a subexpression of `a & b & c` 183 | ''' 184 | def __init__(self, 185 | cse_prefix: str = '__common_expr', 186 | elim_calls: bool = False, 187 | min_freq: int = 2, 188 | ): 189 | if min_freq < 2: 190 | raise ValueError('min_freq must be >= 2') 191 | self.cse_prefix = cse_prefix 192 | self.elim_calls = elim_calls 193 | self.min_freq = min_freq 194 | 195 | 196 | def rewrite(self, 197 | tree: ast.AST, 198 | env: SymbolTable, 199 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 200 | 201 | prefix = gen_free_prefix(tree, env, self.cse_prefix) 202 | c = 0 203 | while True: 204 | # Count all the expressions in the tree 205 | counter = ExprCounter(self.elim_calls) 206 | counter.visit(tree) 207 | 208 | # If there are no expression in the tree 209 | if not counter.cses: 210 | break 211 | 212 | # get the most common expression 213 | expr, freq = counter.cses.most_common()[0] 214 | if freq < self.min_freq: 215 | break 216 | 217 | expr = mutable(expr) 218 | 219 | # Find the first occurrence of the expression 220 | # and save it to a variable then replace 221 | # future occurrences of that expression with 222 | # references to that variable 223 | saver = ExprSaver(expr, prefix + repr(c)) 224 | c += 1 225 | tree = saver.visit(tree) 226 | 227 | 228 | return tree, env, metadata 229 | -------------------------------------------------------------------------------- /ast_tools/passes/debug.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import code 3 | import typing as tp 4 | import warnings 5 | 6 | import astor 7 | 8 | import libcst as cst 9 | 10 | from . import Pass 11 | from . 
import PASS_ARGS_T 12 | from ast_tools import to_module 13 | from ast_tools.stack import SymbolTable 14 | 15 | __ALL__ = ['debug'] 16 | 17 | class debug(Pass): 18 | def __init__(self, 19 | dump_ast: bool = False, 20 | dump_src: bool = False, 21 | dump_env: bool = False, 22 | file: tp.Optional[str] = None, 23 | append: tp.Optional[bool] = None, 24 | dump_source_filename: bool = False, 25 | dump_source_lines: bool = False, 26 | interactive: bool = False, 27 | ) -> PASS_ARGS_T: 28 | self.dump_ast = dump_ast 29 | self.dump_src = dump_src 30 | self.dump_env = dump_env 31 | if append is not None and file is None: 32 | warnings.warn('Option append has no effect when file is None', stacklevel=2) 33 | self.file = file 34 | self.append = append 35 | self.dump_source_filename = dump_source_filename 36 | self.dump_source_lines = dump_source_lines 37 | self.interactive = interactive 38 | 39 | def rewrite(self, 40 | tree: cst.CSTNode, 41 | env: SymbolTable, 42 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 43 | 44 | def _do_dumps(dumps, dump_writer): 45 | for dump in dumps: 46 | dump_writer(f'BEGIN {dump[0]}\n') 47 | dump_writer(dump[1].strip()) 48 | dump_writer(f'\nEND {dump[0]}\n\n') 49 | 50 | dumps = [] 51 | if self.dump_ast: 52 | dumps.append(('AST', repr(tree))) 53 | if self.dump_src: 54 | dumps.append(('SRC', to_module(tree).code)) 55 | if self.dump_env: 56 | dumps.append(('ENV', repr(env))) 57 | if self.dump_source_filename: 58 | if "source_filename" not in metadata: 59 | raise Exception("Cannot dump source filename without " 60 | "apply_passes(..., debug=True)") 61 | dumps.append(('SOURCE_FILENAME', metadata["source_filename"])) 62 | if self.dump_source_lines: 63 | if "source_lines" not in metadata: 64 | raise Exception("Cannot dump source lines without " 65 | "apply_passes(..., debug=True)") 66 | lines, start_line_number = metadata["source_lines"] 67 | dump_str = "".join(f"{start_line_number + i}:{line}" for i, line in 68 | enumerate(lines)) 69 | 
dumps.append(('SOURCE_LINES', dump_str)) 70 | 71 | if self.file is not None: 72 | if self.append: 73 | mode = 'wa' 74 | else: 75 | mode = 'w' 76 | with open(self.dump_file, mode) as fp: 77 | _do_dumps(dumps, fp.write) 78 | else: 79 | def _print(*args, **kwargs): print(*args, end='', **kwargs) 80 | _do_dumps(dumps, _print) 81 | 82 | if self.interactive: 83 | # Launch a repl loop 84 | code.interact( 85 | banner=('Warning: modifications to tree, env, and metadata ' 86 | 'will have side effects'), 87 | local=dict(tree=tree, env=env, metadata=metadata), 88 | ) 89 | 90 | return tree, env, metadata 91 | -------------------------------------------------------------------------------- /ast_tools/passes/if_inline.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from ast_tools.stack import SymbolTable 6 | from . import Pass, PASS_ARGS_T 7 | from ast_tools.transformers.if_inliner import inline_ifs 8 | from ast_tools.transformers.normalizers import ElifToElse 9 | 10 | 11 | class if_inline(Pass): 12 | def rewrite(self, 13 | tree: cst.CSTNode, 14 | env: SymbolTable, 15 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 16 | return inline_ifs(tree, env), env, metadata 17 | -------------------------------------------------------------------------------- /ast_tools/passes/if_to_phi.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import warnings 3 | 4 | import libcst as cst 5 | 6 | from . import Pass 7 | from . 
import PASS_ARGS_T 8 | 9 | from ast_tools.common import gen_free_name 10 | from ast_tools.stack import SymbolTable 11 | 12 | __ALL__ = ['if_to_phi'] 13 | 14 | 15 | class IfExpTransformer(cst.CSTTransformer): 16 | def __init__(self, phi_name: str): 17 | self.phi_name = phi_name 18 | 19 | def leave_IfExp( 20 | self, 21 | original_node: cst.IfExp, 22 | updated_node: cst.IfExp, 23 | ): 24 | return cst.Call( 25 | func=cst.Name(value=self.phi_name), 26 | args=[ 27 | cst.Arg(value=v) for v in ( 28 | updated_node.test, 29 | updated_node.body, 30 | updated_node.orelse 31 | ) 32 | ], 33 | ) 34 | 35 | class if_to_phi(Pass): 36 | ''' 37 | Pass to convert IfExp to call to phi functions 38 | phi should have signature: 39 | phi :: Condition -> T -> F -> Union[T, F] 40 | where: 41 | Condition is usually bool 42 | T is the True branch 43 | F is the False branch 44 | ''' 45 | 46 | def __init__(self, 47 | phi: tp.Union[tp.Callable, str], 48 | phi_name_prefix: tp.Optional[str] = None): 49 | self.phi = phi 50 | if isinstance(phi, str) and phi_name_prefix is not None: 51 | warnings.warn('phi_name_prefix has no effect ' 52 | 'if phi is a str', stacklevel=2) 53 | elif phi_name_prefix is None: 54 | phi_name_prefix = '__phi' 55 | 56 | self.phi_name_prefix = phi_name_prefix 57 | 58 | def rewrite(self, 59 | tree: cst.CSTNode, 60 | env: SymbolTable, 61 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 62 | 63 | if not isinstance(self.phi, str): 64 | phi_name = gen_free_name(tree, env, self.phi_name_prefix) 65 | env.locals[phi_name] = self.phi 66 | else: 67 | phi_name = self.phi 68 | 69 | visitor = IfExpTransformer(phi_name) 70 | tree = tree.visit(visitor) 71 | return tree, env, metadata 72 | -------------------------------------------------------------------------------- /ast_tools/passes/loop_unroll.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from ast_tools.stack import SymbolTable 6 | from . 
import Pass, PASS_ARGS_T 7 | from ast_tools.transformers.loop_unroller import unroll_for_loops 8 | 9 | 10 | class loop_unroll(Pass): 11 | def rewrite(self, 12 | tree: cst.CSTNode, 13 | env: SymbolTable, 14 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 15 | return unroll_for_loops(tree, env), env, metadata 16 | -------------------------------------------------------------------------------- /ast_tools/passes/remove_asserts.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from . import Pass, PASS_ARGS_T 6 | from ast_tools.stack import SymbolTable 7 | from ast_tools.transformers.node_replacer import NodeReplacer 8 | 9 | class AssertRemover(NodeReplacer): 10 | def __init__(self): 11 | # replace asserts with pass 12 | super().__init__({cst.Assert: cst.Pass()}) 13 | 14 | def _get_key(self, node): return type(node) 15 | 16 | 17 | class remove_asserts(Pass): 18 | def rewrite(self, 19 | tree: cst.CSTNode, 20 | env: SymbolTable, 21 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 22 | 23 | visitor = AssertRemover() 24 | tree = tree.visit(visitor) 25 | return tree, env, metadata 26 | 27 | -------------------------------------------------------------------------------- /ast_tools/passes/ssa.py: -------------------------------------------------------------------------------- 1 | from collections import ChainMap, Counter 2 | import builtins 3 | import types 4 | import functools as ft 5 | import typing as tp 6 | 7 | import libcst as cst 8 | from libcst.metadata import ExpressionContext, ExpressionContextProvider, PositionProvider 9 | from libcst import matchers as m 10 | 11 | from ast_tools.common import gen_free_prefix 12 | from ast_tools.cst_utils import DeepNode 13 | from ast_tools.cst_utils import to_module, make_assign, to_stmt 14 | from ast_tools.metadata import AlwaysReturnsProvider, IncrementalConditionProvider 15 | from ast_tools.stack import SymbolTable 16 | from 
ast_tools.transformers.node_tracker import NodeTrackingTransformer, with_tracking 17 | from ast_tools.transformers.node_replacer import NodeReplacer 18 | from ast_tools.transformers.normalizers import ElifToElse 19 | from ast_tools.utils import BiMap 20 | from . import Pass, PASS_ARGS_T 21 | 22 | __ALL__ = ['ssa'] 23 | 24 | #([gaurds], expr]) 25 | _GAURDED_EXPR = tp.Tuple[tp.Sequence[cst.BaseExpression], cst.BaseExpression] 26 | 27 | def _simplify_gaurds( 28 | gaurded_seq: tp.Sequence[_GAURDED_EXPR], 29 | ) -> tp.Sequence[_GAURDED_EXPR]: 30 | ''' 31 | Pretty simplistic simplifyication 32 | which looks for: 33 | [ 34 | ..., 35 | ([p, q], expri), 36 | ([p, not q], exprj), 37 | ... 38 | ] 39 | and simplifies it to: 40 | [ 41 | ..., 42 | ([p, q], expri), 43 | ([p], exprj), 44 | ... 45 | ] 46 | also truncates the list after empty gaurds: 47 | [ 48 | ..., 49 | ([], expr), 50 | ... 51 | ] 52 | becomes: 53 | [ 54 | ..., 55 | ([], expr) 56 | ] 57 | ''' 58 | # wrap the implementation so we can assert that 59 | # we reach a fixed point after a single invocation 60 | def impl(gaurded_seq): 61 | new_seq = [gaurded_seq[0]] 62 | for gaurd, expr in gaurded_seq[1:]: 63 | last_gaurd = new_seq[-1][0] 64 | # Truncate 65 | if not last_gaurd: 66 | break 67 | elif not gaurd: 68 | new_seq.append((gaurd, expr)) 69 | break 70 | 71 | pred = gaurd[-1] 72 | if (isinstance(pred, cst.UnaryOperation) 73 | and isinstance(pred.operator, cst.Not) 74 | and pred.expression == last_gaurd[-1] 75 | and gaurd[:-1] == last_gaurd[:-1]): 76 | # simplify 77 | new_seq.append((gaurd[:-1], expr)) 78 | else: 79 | new_seq.append((gaurd, expr)) 80 | return new_seq 81 | 82 | new_seq = impl(gaurded_seq) 83 | assert new_seq == impl(new_seq) 84 | return new_seq 85 | 86 | 87 | class IncompleteGaurdError(Exception): pass 88 | 89 | 90 | def _fold_conditions( 91 | gaurded_seq: tp.Sequence[_GAURDED_EXPR], 92 | strict: bool, 93 | ) -> cst.BaseExpression: 94 | def and_builder( 95 | left: cst.BaseExpression, 96 | right: 
cst.BaseExpression 97 | ) -> cst.BooleanOperation: 98 | return cst.BooleanOperation( 99 | left=left, 100 | operator=cst.And(), 101 | right=right, 102 | ) 103 | 104 | if not gaurded_seq and strict: 105 | raise IncompleteGaurdError() 106 | 107 | gaurd, expr = gaurded_seq[0] 108 | if not gaurd or (not strict and len(gaurded_seq) == 1): 109 | return expr 110 | else: 111 | if len(gaurd) == 1: 112 | test = gaurd[0] 113 | else: 114 | test = ft.reduce(and_builder, gaurd) 115 | 116 | conditional = cst.IfExp( 117 | test=test, 118 | body=expr, 119 | orelse=_fold_conditions(gaurded_seq[1:], strict) 120 | ) 121 | return conditional 122 | 123 | 124 | class NameTests(NodeTrackingTransformer): 125 | ''' 126 | Rewrites if statements so that their tests are a single name 127 | This gaurentees that the conditions are ssa 128 | e.g: 129 | if x == 0: 130 | ... 131 | else: 132 | ... 133 | becomes: 134 | cond = x == 0 135 | if cond: 136 | ... 137 | else: 138 | ... 139 | ''' 140 | format: str 141 | added_names: tp.MutableSet[str] 142 | 143 | def __init__(self, prefix): 144 | super().__init__() 145 | self.format = prefix+'_{}' 146 | self.added_names = set() 147 | 148 | def leave_If(self, 149 | original_node: cst.If, 150 | updated_node: cst.If, 151 | ) -> cst.If: 152 | c_name = cst.Name(value=self.format.format(len(self.added_names))) 153 | self.added_names.add(c_name.value) 154 | assign = to_stmt(make_assign(c_name, updated_node.test)) 155 | final_node = updated_node.with_changes(test=c_name) 156 | return cst.FlattenSentinel([assign, final_node]) 157 | 158 | 159 | class SingleReturn(NodeTrackingTransformer): 160 | METADATA_DEPENDENCIES = (IncrementalConditionProvider, AlwaysReturnsProvider) 161 | 162 | attr_format: tp.Optional[str] 163 | attr_states: tp.MutableMapping[str, tp.MutableSequence[_GAURDED_EXPR]] 164 | strict: bool 165 | debug: bool 166 | env: tp.Mapping[str, tp.Any] 167 | names_to_attr: tp.Mapping[str, cst.Attribute] 168 | return_format: tp.Optional[str] 169 | returns: 
tp.MutableSequence[_GAURDED_EXPR] 170 | scope: tp.Optional[cst.FunctionDef] 171 | tail: tp.MutableSequence[cst.BaseStatement] 172 | added_names: tp.MutableSet[str] 173 | returning_blocks: tp.MutableSet[cst.BaseSuite] 174 | 175 | def __init__(self, 176 | env: tp.Mapping[str, tp.Any], 177 | names_to_attr: tp.Mapping[str, cst.Attribute], 178 | strict: bool = True, 179 | ): 180 | 181 | super().__init__() 182 | self.attr_format = None 183 | self.attr_states = {} 184 | self.strict = strict 185 | self.env = env 186 | self.names_to_attr = names_to_attr 187 | self.returns = [] 188 | self.return_format = None 189 | self.scope = None 190 | self.tail = [] 191 | self.added_names = set() 192 | self.returning_blocks = set() 193 | 194 | def visit_FunctionDef(self, 195 | node: cst.FunctionDef) -> tp.Optional[bool]: 196 | # prevent recursion into inner functions 197 | super().visit_FunctionDef(node) 198 | if self.scope is None: 199 | self.scope = node 200 | prefix = gen_free_prefix(node, self.env, '__') 201 | self.attr_format = prefix + '_final_{}_{}_{}' 202 | self.return_format = prefix + '_return_{}' 203 | 204 | return True 205 | return False 206 | 207 | def leave_FunctionDef(self, 208 | original_node: cst.FunctionDef, 209 | updated_node: cst.FunctionDef 210 | ) -> cst.FunctionDef: 211 | final_node = updated_node 212 | if original_node is self.scope: 213 | suite = updated_node.body 214 | tail = self.tail 215 | for name, attr in self.names_to_attr.items(): 216 | state = self.attr_states.get(name, []) 217 | # default writeback initial value 218 | state.append(([], cst.Name(name))) 219 | attr_val = _fold_conditions(_simplify_gaurds(state), self.strict) 220 | write = to_stmt(make_assign(attr, attr_val)) 221 | tail.append(write) 222 | 223 | if self.returns: 224 | strict = self.strict 225 | 226 | try: 227 | return_val = _fold_conditions(_simplify_gaurds(self.returns), strict) 228 | except IncompleteGaurdError: 229 | raise SyntaxError('Cannot prove function always returns') from None 230 
| return_stmt = cst.SimpleStatementLine([cst.Return(value=return_val)]) 231 | tail.append(return_stmt) 232 | 233 | return final_node 234 | 235 | def visit_ClassDef(self, 236 | node: cst.ClassDef) -> tp.Optional[bool]: 237 | return False 238 | 239 | def leave_Return(self, 240 | original_node: cst.Return, 241 | updated_node: cst.Return 242 | ) -> cst.RemovalSentinel: 243 | assert self.return_format is not None 244 | assert self.attr_format is not None 245 | 246 | assignments = [] 247 | cond = self.get_metadata(IncrementalConditionProvider, original_node) 248 | 249 | for name, attr in self.names_to_attr.items(): 250 | assert isinstance(attr.value, cst.Name) 251 | state = self.attr_states.setdefault(name, []) 252 | attr_name = cst.Name( 253 | value=self.attr_format.format( 254 | attr.value.value, 255 | attr.attr.value, 256 | len(state) 257 | ) 258 | ) 259 | self.added_names.add(attr_name.value) 260 | state.append((cond, attr_name)) 261 | assignments.append(make_assign(attr_name, cst.Name(name))) 262 | 263 | r_name = cst.Name(value=self.return_format.format(len(self.returns))) 264 | self.added_names.add(r_name.value) 265 | self.returns.append((cond, r_name)) 266 | 267 | if updated_node.value is None: 268 | r_val = cst.Name(value='None') 269 | else: 270 | r_val = updated_node.value 271 | 272 | assignments.append(make_assign(r_name, r_val)) 273 | 274 | return cst.FlattenSentinel(assignments) 275 | 276 | def leave_SimpleStatementSuite(self, 277 | original_node: cst.SimpleStatementSuite, 278 | updated_node: cst.SimpleStatementSuite, 279 | ) -> cst.SimpleStatementSuite: 280 | final_node = super().leave_SimpleStatementSuite(original_node, updated_node) 281 | if self.get_metadata(AlwaysReturnsProvider, original_node): 282 | self.returning_blocks.add(final_node) 283 | return final_node 284 | 285 | def leave_IndentedBlock(self, 286 | original_node: cst.IndentedBlock, 287 | updated_node: cst.IndentedBlock, 288 | ) -> cst.IndentedBlock: 289 | final_node = 
super().leave_IndentedBlock(original_node, updated_node) 290 | if self.get_metadata(AlwaysReturnsProvider, original_node): 291 | self.returning_blocks.add(final_node) 292 | return final_node 293 | 294 | 295 | class WrittenAttrs(cst.CSTVisitor): 296 | METADATA_DEPENDENCIES = (ExpressionContextProvider,) 297 | 298 | written_attrs: tp.MutableSet[cst.Attribute] 299 | 300 | def __init__(self): 301 | super().__init__() 302 | self.written_attrs = set() 303 | 304 | 305 | def visit_Attribute(self, 306 | node: cst.Attribute) -> tp.Optional[bool]: 307 | ctx = self.get_metadata(ExpressionContextProvider, node) 308 | if ctx is ExpressionContext.STORE: 309 | self.written_attrs.add(node) 310 | 311 | 312 | class AttrReplacer(NodeReplacer): 313 | def _get_key(self, node): 314 | if isinstance(node, cst.Attribute): 315 | return DeepNode(node) 316 | else: 317 | return None 318 | 319 | 320 | def _wrap(tree: cst.CSTNode) -> cst.MetadataWrapper: 321 | return cst.MetadataWrapper(tree, unsafe_skip_copy=True) 322 | 323 | 324 | class SSATransformer(NodeTrackingTransformer): 325 | env: tp.Mapping[str, tp.Any] 326 | ctxs: tp.Mapping[cst.Name, ExpressionContext] 327 | scope: tp.Optional[cst.FunctionDef] 328 | name_table: tp.ChainMap[str, str] 329 | name_idx: Counter 330 | name_formats: tp.MutableMapping[str, str] 331 | name_assignments: tp.MutableMapping[str, tp.Union[cst.Assign, cst.Param]] 332 | original_names: tp.MutableMapping[str, str] 333 | final_names: tp.AbstractSet[str] 334 | returning_blocks: tp.AbstractSet[cst.BaseSuite] 335 | _skip: bool 336 | _assigned_names: tp.MutableSequence[str] 337 | 338 | 339 | def __init__(self, 340 | env: tp.Mapping[str, tp.Any], 341 | ctxs: tp.Mapping[cst.Name, ExpressionContext], 342 | final_names: tp.AbstractSet[str], 343 | returning_blocks: tp.AbstractSet[cst.BaseSuite], 344 | strict: bool = True, 345 | ): 346 | super().__init__() 347 | _builtins = env.get('__builtins__', builtins) 348 | if isinstance(_builtins, types.ModuleType): 349 | _builtins = 
builtins.__dict__ 350 | self.env = ChainMap(env, _builtins) 351 | self.ctxs = ctxs 352 | self.scope = None 353 | self.name_assignments = ChainMap() 354 | self.name_idx = Counter() 355 | self.name_table = ChainMap({k: k for k in self.env}) 356 | self.name_formats = {} 357 | self.original_names = {} 358 | self.final_names = final_names 359 | self.strict = strict 360 | self.returning_blocks = returning_blocks 361 | self._skip = 0 362 | self._assigned_names = [] 363 | 364 | 365 | def _make_name(self, name): 366 | if name not in self.name_formats: 367 | prefix = gen_free_prefix(self.scope, self.env, f'{name}_') 368 | self.name_formats[name] = prefix + '{}' 369 | 370 | ssa_name = self.name_formats[name].format(self.name_idx[name]) 371 | self.name_idx[name] += 1 372 | self.name_table[name] = ssa_name 373 | self.original_names[ssa_name] = name 374 | return ssa_name 375 | 376 | def visit_FunctionDef(self, 377 | node: cst.FunctionDef) -> tp.Optional[bool]: 378 | # prevent recursion into inner functions 379 | # and control recursion 380 | if self.scope is None: 381 | self.scope = node 382 | return False 383 | 384 | def leave_FunctionDef(self, 385 | original_node: cst.FunctionDef, 386 | updated_node: cst.FunctionDef) -> cst.FunctionDef: 387 | final_node = updated_node 388 | if original_node is self.scope: 389 | # Don't want to ssa params but do want them in the name table 390 | for param in updated_node.params.params: 391 | name = param.name.value 392 | self.name_table[name] = name 393 | self.name_assignments[name] = param 394 | 395 | # Need to visit params to get them to be rebuilt and therfore 396 | # tracked to build the symbol table 397 | update_params = updated_node.params.visit(self) 398 | assert not self._skip 399 | assert not self._assigned_names, self._assigned_names 400 | new_body = updated_node.body.visit(self) 401 | final_node = updated_node.with_changes(body=new_body, params=update_params) 402 | assert not self._skip 403 | assert not self._assigned_names, 
self._assigned_names 404 | return final_node 405 | 406 | def visit_If(self, node: cst.If) -> tp.Optional[bool]: 407 | super().visit_If(node) 408 | # Control recursion order 409 | return False 410 | 411 | def leave_If(self, 412 | original_node: cst.If, 413 | updated_node: cst.If, 414 | ) -> tp.Union[cst.If, cst.RemovalSentinel]: 415 | t_returns = original_node.body in self.returning_blocks 416 | if original_node.orelse is not None: 417 | f_returns = original_node.orelse.body in self.returning_blocks 418 | else: 419 | f_returns = False 420 | 421 | new_test = updated_node.test.visit(self) 422 | nt = self.name_table 423 | suite = [] 424 | self.name_table = t_nt = nt.new_child() 425 | new_body = updated_node.body.visit(self) 426 | 427 | suite.extend(new_body.body) 428 | 429 | self.name_table = f_nt = nt.new_child() 430 | orelse = updated_node.orelse 431 | if orelse is not None: 432 | assert isinstance(orelse, cst.Else) 433 | new_orelse = orelse.visit(self) 434 | suite.extend(new_orelse.body.body) 435 | else: 436 | assert not f_returns 437 | 438 | self.name_table = nt 439 | 440 | t_nt = t_nt.maps[0] 441 | f_nt = f_nt.maps[0] 442 | 443 | 444 | def _mux_name(name, t_name, f_name): 445 | new_name = self._make_name(name) 446 | assign = make_assign( 447 | cst.Name(new_name), 448 | cst.IfExp( 449 | test=new_test, 450 | body=cst.Name(t_name), 451 | orelse=cst.Name(f_name), 452 | ), 453 | ) 454 | self.name_assignments[new_name] = assign 455 | 456 | stmt = to_stmt(assign) 457 | 458 | assert isinstance(original_node, cst.If) 459 | assert isinstance(self.name_assignments[t_name], (cst.Assign, cst.Param)) 460 | assert isinstance(self.name_assignments[f_name], (cst.Assign, cst.Param)) 461 | self.track_with_children(( 462 | self.name_assignments[t_name], 463 | self.name_assignments[f_name], 464 | original_node, 465 | ), stmt) 466 | assert assign in self.node_tracking_table.i 467 | return stmt 468 | 469 | if t_returns and f_returns: 470 | # No need to mux any names they can't fall 
through anyway 471 | pass 472 | elif t_returns and not f_returns: 473 | # fall through from orelse 474 | nt.update(f_nt) 475 | elif f_returns and not t_returns: 476 | # fall through from body 477 | nt.update(t_nt) 478 | else: 479 | # Mux names 480 | for name in sorted(t_nt.keys() | f_nt.keys()): 481 | if name in t_nt and name in f_nt: 482 | # mux between true and false 483 | suite.append(_mux_name(name, t_nt[name], f_nt[name])) 484 | elif name in t_nt and name in nt: 485 | # mux between true and old value 486 | suite.append(_mux_name(name, t_nt[name], nt[name])) 487 | elif name in f_nt and name in nt: 488 | # mux between false and old value 489 | suite.append(_mux_name(name, nt[name], f_nt[name])) 490 | elif name in t_nt and not self.strict: 491 | # Assume name will fall through 492 | nt[name] = t_nt[name] 493 | elif name in f_nt and not self.strict: 494 | # Assume name will fall through 495 | nt[name] = f_nt[name] 496 | 497 | 498 | return cst.FlattenSentinel(suite) 499 | 500 | def visit_Assign(self, node: cst.Assign) -> tp.Optional[bool]: 501 | # Control recursion order 502 | return False 503 | 504 | def leave_Assign(self, 505 | original_node: cst.Assign, 506 | updated_node: cst.Assign) -> cst.Assign: 507 | new_value = updated_node.value.visit(self) 508 | assert not self._assigned_names, (to_module(original_node).code, self._assigned_names) 509 | new_targets = [t.visit(self) for t in updated_node.targets] 510 | final_node = updated_node.with_changes(value=new_value, targets=new_targets) 511 | for name in self._assigned_names: 512 | self.name_assignments[name] = original_node 513 | self._assigned_names = [] 514 | return final_node 515 | 516 | def visit_Attribute(self, node: cst.Attribute) -> tp.Optional[bool]: 517 | return False 518 | 519 | def leave_Attribute(self, 520 | original_node: cst.Attribute, 521 | updated_node: cst.Attribute) -> cst.Attribute: 522 | new_value = updated_node.value.visit(self) 523 | final_node = updated_node.with_changes(value=new_value) 
524 | return final_node 525 | 526 | def visit_Arg_keyword(self, node: cst.Arg): 527 | self._skip += 1 528 | 529 | def leave_Arg_keyword(self, node: cst.Arg): 530 | self._skip -= 1 531 | 532 | def visit_Parameters(self, node: cst.Parameters) -> tp.Optional[bool]: 533 | self._skip += 1 534 | return True 535 | 536 | def leave_Parameters(self, 537 | original_node: cst.Parameters, 538 | updated_node: cst.Parameters) -> cst.Parameters: 539 | self._skip -= 1 540 | return updated_node 541 | 542 | def leave_Name(self, 543 | original_node: cst.Name, 544 | updated_node: cst.Name) -> cst.Name: 545 | if self._skip: 546 | return updated_node 547 | 548 | name = updated_node.value 549 | # name is already ssa 550 | if name in self.final_names: 551 | return updated_node 552 | 553 | ctx = self.ctxs[original_node] 554 | if ctx is ExpressionContext.LOAD: 555 | # Names in Load context should not be added to the name table 556 | # as it makes them seem like they have been modified. 557 | try: 558 | return cst.Name(self.name_table[name]) 559 | except KeyError: 560 | if self.strict: 561 | raise SyntaxError(f'Cannot prove name `{name}` is defined') 562 | else: 563 | return cst.Name(name) 564 | elif ctx is ExpressionContext.STORE: 565 | new_name = self._make_name(name) 566 | self._assigned_names.append(new_name) 567 | return cst.Name(new_name) 568 | else: 569 | return updated_node 570 | 571 | class GenerateSymbolTable(cst.CSTVisitor): 572 | node_tracking_table: BiMap[cst.CSTNode, cst.CSTNode] 573 | 574 | def __init__(self, node_tracking_table, original_names, pos_info, start_ln, end_ln): 575 | self.node_tracking_table = node_tracking_table 576 | self.original_names = original_names 577 | self.pos_info = pos_info 578 | self.start_ln = start_ln 579 | self.end_ln = end_ln 580 | self.symbol_table = { 581 | i: {} for i in range(start_ln, end_ln+1) 582 | } 583 | self.scope = None 584 | 585 | 586 | def _set_name(self, name, new_name, origins): 587 | ln = self.start_ln 588 | for origin in origins: 
589 | pos = self.pos_info[origin] 590 | 591 | if isinstance(origin, (cst.BaseExpression, cst.BaseSmallStatement, cst.Param)): 592 | ln = max(ln, pos.end.line) 593 | else: 594 | assert isinstance(origin, cst.BaseCompoundStatement) 595 | ln = max(ln, pos.end.line + 1) 596 | 597 | for i in range(ln, self.end_ln+1): 598 | self.symbol_table[i][name] = new_name 599 | 600 | 601 | def visit_FunctionDef(self, 602 | node: cst.FunctionDef) -> tp.Optional[bool]: 603 | if self.scope is None: 604 | self.scope = node 605 | return True 606 | return False 607 | 608 | def visit_Param(self, node: cst.Param) -> tp.Optional[bool]: 609 | name = node.name.value 610 | origins = self.node_tracking_table.i[node] 611 | self._set_name(name, name, origins) 612 | 613 | def visit_Assign(self, node: cst.Assign) -> tp.Optional[bool]: 614 | for t in node.targets: 615 | t = t.target 616 | if m.matches(t, m.Name()): 617 | ssa_name = t.value 618 | if ssa_name in self.original_names: 619 | ln = self.start_ln 620 | name = self.original_names[ssa_name] 621 | # HACK attrs not currently tracked properly 622 | try: 623 | origins = self.node_tracking_table.i[t] 624 | except KeyError: 625 | continue 626 | self._set_name(name, ssa_name, origins) 627 | 628 | class ssa(Pass): 629 | def __init__(self, strict: bool = True): 630 | self.strict = strict 631 | 632 | def rewrite(self, 633 | original_tree: cst.FunctionDef, 634 | env: SymbolTable, 635 | metadata: tp.MutableMapping) -> PASS_ARGS_T: 636 | if not isinstance(original_tree, cst.FunctionDef): 637 | raise TypeError('ssa must be run on a FunctionDef') 638 | 639 | 640 | # resolve position information necessary for generating symbol table 641 | wrapper = _wrap(to_module(original_tree)) 642 | pos_info = wrapper.resolve(PositionProvider) 643 | 644 | # convert `elif cond:` to `else: if cond:` 645 | # (simplifies ssa logic) 646 | transformer = with_tracking(ElifToElse)() 647 | tree = original_tree.visit(transformer) 648 | 649 | # original node -> generated nodes 650 | 
node_tracking_table = transformer.node_tracking_table 651 | # node_tracking_table.i 652 | # generated node -> original nodes 653 | 654 | wrapper = _wrap(to_module(tree)) 655 | writter_attr_visitor = WrittenAttrs() 656 | wrapper.visit(writter_attr_visitor) 657 | 658 | replacer = with_tracking(AttrReplacer)() 659 | attr_format = gen_free_prefix(tree, env, '_attr') + '_{}_{}' 660 | init_reads = [] 661 | names_to_attr = {} 662 | seen = set() 663 | 664 | for written_attr in writter_attr_visitor.written_attrs: 665 | d_attr = DeepNode(written_attr) 666 | if d_attr in seen: 667 | continue 668 | if not isinstance(written_attr.value, cst.Name): 669 | raise NotImplementedError('writing non name nodes is not supported') 670 | 671 | seen.add(d_attr) 672 | 673 | attr_name = attr_format.format( 674 | written_attr.value.value, 675 | written_attr.attr.value, 676 | ) 677 | 678 | # using normal node instead of original node 679 | # is safe as parenthesis don't matter: 680 | # (name).attr == (name.attr) == name.attr 681 | norm = d_attr.normal_node 682 | names_to_attr[attr_name] = norm 683 | name = cst.Name(attr_name) 684 | replacer.add_replacement(written_attr, name) 685 | read = to_stmt(make_assign(name, norm)) 686 | init_reads.append(read) 687 | 688 | # Replace references to attr with the name generated above 689 | tree = tree.visit(replacer) 690 | 691 | 692 | node_tracking_table = replacer.trace_origins(node_tracking_table) 693 | 694 | # Rewrite conditions to be ssa 695 | cond_prefix = gen_free_prefix(tree, env, '_cond') 696 | wrapper = _wrap(tree) 697 | name_tests = NameTests(cond_prefix) 698 | tree = wrapper.visit(name_tests) 699 | 700 | node_tracking_table = name_tests.trace_origins(node_tracking_table) 701 | 702 | 703 | # Transform to single return format 704 | wrapper = _wrap(tree) 705 | single_return = SingleReturn(env, names_to_attr, self.strict) 706 | tree = wrapper.visit(single_return) 707 | 708 | node_tracking_table = single_return.trace_origins(node_tracking_table) 709 | 
710 | # insert the initial reads / final writes / return 711 | body = tree.body 712 | body = body.with_changes(body=(*init_reads, *body.body, *single_return.tail)) 713 | tree = tree.with_changes(body=body) 714 | 715 | # perform ssa 716 | wrapper = _wrap(to_module(tree)) 717 | ctxs = wrapper.resolve(ExpressionContextProvider) 718 | # These names were constructed in such a way that they are 719 | # guaranteed to be ssa and shouldn't be touched by the 720 | # transformer 721 | final_names = single_return.added_names | name_tests.added_names 722 | ssa_transformer = SSATransformer( 723 | env, 724 | ctxs, 725 | final_names, 726 | single_return.returning_blocks, 727 | strict=self.strict) 728 | tree = tree.visit(ssa_transformer) 729 | 730 | node_tracking_table = ssa_transformer.trace_origins(node_tracking_table) 731 | 732 | tree.validate_types_deep() 733 | # generate symbol table 734 | start_ln = pos_info[original_tree].start.line 735 | end_ln = pos_info[original_tree].end.line 736 | visitor = GenerateSymbolTable( 737 | node_tracking_table, 738 | ssa_transformer.original_names, 739 | pos_info, 740 | start_ln, 741 | end_ln, 742 | ) 743 | 744 | tree.visit(visitor) 745 | metadata.setdefault('SYMBOL-TABLE', list()).append((type(self), visitor.symbol_table)) 746 | return tree, env, metadata 747 | -------------------------------------------------------------------------------- /ast_tools/passes/util.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import ast 3 | import copy 4 | import inspect 5 | import typing as tp 6 | import warnings 7 | 8 | import libcst as cst 9 | 10 | from . import Pass 11 | from . 
import PASS_ARGS_T 12 | 13 | from ast_tools.stack import get_symbol_table, SymbolTable 14 | from ast_tools.common import get_ast, get_cst, exec_def_in_file, exec_str_in_file 15 | 16 | __ALL__ = ['begin_rewrite', 'end_rewrite', 'apply_ast_passes'] 17 | 18 | 19 | def _issubclass(t, s) -> bool: 20 | try: 21 | return issubclass(t, s) 22 | except TypeError: 23 | return False 24 | 25 | def _isinstance(t, s) -> bool: 26 | try: 27 | return isinstance(t, s) 28 | except TypeError: 29 | return False 30 | 31 | def _is_subclass_or_instance(t, s) -> bool: 32 | return _isinstance(t, s) or _issubclass(t, s) 33 | 34 | 35 | class _DecoratorStripper(metaclass=ABCMeta): 36 | @staticmethod 37 | @abstractmethod 38 | def get_decorators(tree): pass 39 | 40 | 41 | @staticmethod 42 | @abstractmethod 43 | def set_decorators(tree, decorators): pass 44 | 45 | 46 | @classmethod 47 | @abstractmethod 48 | def lookup(cls, node, env): pass 49 | 50 | 51 | @classmethod 52 | def strip(cls, tree, env, start_sentinel, end_sentinel): 53 | decorators = [] 54 | in_group = False 55 | first_group = True 56 | # filter passes from the decorator list 57 | for node in reversed(cls.get_decorators(tree)): 58 | if not first_group: 59 | decorators.append(node) 60 | continue 61 | 62 | deco = cls.lookup(node, env) 63 | if in_group: 64 | if _is_subclass_or_instance(deco, end_sentinel): 65 | in_group = False 66 | first_group = False 67 | elif start_sentinel is None or _is_subclass_or_instance(deco, start_sentinel): 68 | if start_sentinel is end_sentinel: 69 | # Just remove current decorator 70 | first_group = False 71 | else: 72 | in_group = True 73 | else: 74 | decorators.append(node) 75 | 76 | tree = cls.set_decorators(tree, reversed(decorators)) 77 | return tree 78 | 79 | 80 | class _ASTStripper(_DecoratorStripper): 81 | @staticmethod 82 | def get_decorators(tree): 83 | return tree.decorator_list 84 | 85 | 86 | @staticmethod 87 | def set_decorators(tree, decorators): 88 | tree = copy.deepcopy(tree) 89 | 
tree.decorator_list = decorators 90 | return tree 91 | 92 | 93 | @classmethod 94 | def lookup(cls, node, env): 95 | if isinstance(node, ast.Call): 96 | return cls.lookup(node.func, env) 97 | elif isinstance(node, ast.Attribute): 98 | return getattr(cls.lookup(node.value, env), node.attr) 99 | else: 100 | assert isinstance(node, ast.Name) 101 | return env[node.id] 102 | 103 | 104 | class _CSTStripper(_DecoratorStripper): 105 | @staticmethod 106 | def get_decorators(tree): 107 | return tree.decorators 108 | 109 | 110 | @staticmethod 111 | def set_decorators(tree, decorators): 112 | return tree.with_changes(decorators=decorators) 113 | 114 | 115 | @classmethod 116 | def lookup(cls, node: cst.CSTNode, env): 117 | if isinstance(node, cst.Decorator): 118 | return cls.lookup(node.decorator, env) 119 | elif isinstance(node, cst.Call): 120 | return cls.lookup(node.func, env) 121 | elif isinstance(node, cst.Attribute): 122 | return getattr(cls.lookup(node.value, env), node.attr.value) 123 | else: 124 | assert isinstance(node, cst.Name) 125 | return env[node.value] 126 | 127 | class begin_rewrite: 128 | """ 129 | begins a chain of passes 130 | """ 131 | def __init__(self, 132 | debug: bool = False, 133 | env: tp.Optional[SymbolTable] = None): 134 | warnings.warn( 135 | "begin_rewrite / end_rewrite are deprcated please use apply_ast_passes instead", 136 | DeprecationWarning) 137 | 138 | if env is None: 139 | env = get_symbol_table([self.__init__]) 140 | 141 | self.env = env 142 | self.debug = debug 143 | 144 | def __call__(self, fn) -> PASS_ARGS_T: 145 | tree = get_ast(fn) 146 | metadata = {} 147 | if self.debug: 148 | metadata["source_filename"] = inspect.getsourcefile(fn) 149 | metadata["source_lines"] = inspect.getsourcelines(fn) 150 | return tree, self.env, metadata 151 | 152 | 153 | class end_rewrite(Pass): 154 | """ 155 | ends a chain of passes 156 | """ 157 | def __init__(self, **kwargs): 158 | self.kwargs = kwargs 159 | 160 | def rewrite(self, 161 | tree: ast.AST, 162 
| env: SymbolTable, 163 | metadata: tp.MutableMapping) -> tp.Union[tp.Callable, type]: 164 | # tree to exec 165 | etree = _ASTStripper.strip(tree, env, begin_rewrite, None) 166 | etree = ast.fix_missing_locations(etree) 167 | # tree to serialize 168 | stree = _ASTStripper.strip(tree, env, begin_rewrite, end_rewrite) 169 | stree = ast.fix_missing_locations(stree) 170 | return exec_def_in_file(etree, env, serialized_tree=stree, **self.kwargs) 171 | 172 | 173 | class _apply_passes(metaclass=ABCMeta): 174 | ''' 175 | Applies a sequence of passes to a function or class 176 | ''' 177 | passes: tp.Sequence[Pass] 178 | env: SymbolTable 179 | debug: bool 180 | path: tp.Optional[str] 181 | file_name: tp.Optional[str] 182 | metadata_attr: tp.Optional[str] 183 | 184 | 185 | def __init__(self, 186 | passes: tp.Sequence[Pass], 187 | debug: bool = False, 188 | env: tp.Optional[SymbolTable] = None, 189 | path: tp.Optional[str] = None, 190 | file_name: tp.Optional[str] = None, 191 | metadata_attr: tp.Optional[str] = None, 192 | ): 193 | if env is None: 194 | env = get_symbol_table([self.__init__]) 195 | self.passes = passes 196 | self.env = env 197 | self.debug = debug 198 | self.path = path 199 | self.file_name = file_name 200 | self.metadata_attr = metadata_attr 201 | 202 | 203 | @staticmethod 204 | @abstractmethod 205 | def parse(tree): pass 206 | 207 | 208 | @staticmethod 209 | @abstractmethod 210 | def strip_decorators(tree): pass 211 | 212 | 213 | @abstractmethod 214 | def exec(self, etree, stree, env): pass 215 | 216 | def prologue(self, tree, env, metadata): 217 | """ 218 | Invoked before `do_passes`, redefine this method to add code that 219 | operates on the initial `tree`, `env`, and `metadata` 220 | """ 221 | return tree, env, metadata 222 | 223 | def do_passes(self, tree, env, metadata): 224 | args = (tree, env, metadata) 225 | for p in self.passes: 226 | args = p(args) 227 | return args 228 | 229 | def epilogue(self, tree, env, metadata): 230 | """ 231 | Invoked after 
`do_passes`, redefine this method to add code that 232 | operates on the final `tree`, `env`, and `metadata` 233 | """ 234 | return tree, env, metadata 235 | 236 | def __call__(self, fn): 237 | tree = self.parse(fn) 238 | self.i_tree = tree 239 | 240 | metadata = {} 241 | if self.debug: 242 | metadata["source_filename"] = inspect.getsourcefile(fn) 243 | metadata["source_lines"] = inspect.getsourcelines(fn) 244 | 245 | tree, env, metadata = self.prologue(tree, self.env, 246 | metadata) 247 | tree, env, metadata = self.do_passes(tree, env, metadata) 248 | tree, env, metadata = self.epilogue(tree, env, metadata) 249 | 250 | self.f_tree = tree 251 | self.metadata = metadata 252 | 253 | etree = self.strip_decorators(tree, env, type(self), None) 254 | stree = self.strip_decorators(tree, env, type(self), type(self)) 255 | fn = self.exec(etree, stree, env, metadata) 256 | 257 | if self.metadata_attr is not None: 258 | setattr(fn, self.metadata_attr, metadata) 259 | 260 | return fn 261 | 262 | class apply_ast_passes(_apply_passes): 263 | parse = staticmethod(get_ast) 264 | strip_decorators = staticmethod(_ASTStripper.strip) 265 | 266 | def exec(self, etree: ast.AST, stree: ast.AST, env: SymbolTable, 267 | metadata: tp.MutableMapping): 268 | etree = ast.fix_missing_locations(etree) 269 | stree = ast.fix_missing_locations(stree) 270 | return exec_def_in_file(etree, env, self.path, self.file_name, stree) 271 | 272 | 273 | class apply_cst_passes(_apply_passes): 274 | parse = staticmethod(get_cst) 275 | strip_decorators = staticmethod(_CSTStripper.strip) 276 | 277 | def exec(self, 278 | etree: tp.Union[cst.ClassDef, cst.FunctionDef], 279 | stree: tp.Union[cst.ClassDef, cst.FunctionDef], 280 | env: SymbolTable, 281 | metadata: tp.MutableMapping): 282 | return exec_def_in_file(etree, env, self.path, self.file_name, stree) 283 | 284 | 285 | apply_passes = apply_cst_passes 286 | -------------------------------------------------------------------------------- /ast_tools/pattern.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | import itertools 4 | 5 | 6 | class NodePattern: 7 | """An individual node match, e.g. {x:Name}""" 8 | def __init__(self, s): 9 | self._parse(s) 10 | 11 | def _parse(self, s): 12 | """ 13 | Extracts the name and type of the desired AST node. 14 | If no type is provided, then self.type is None. 15 | 16 | Type names are evaluated in a context with the ast module imported. 17 | """ 18 | parts = tuple(s[1:-1].split(':')) 19 | self.name = parts[0] 20 | if len(parts) == 1: 21 | self.type = None 22 | else: 23 | assert (len(parts) == 2) 24 | self.type = getattr(ast, parts[1]) 25 | 26 | 27 | class ASTPattern: 28 | """ 29 | A pattern for an AST subtree. 30 | 31 | Patterns are normal Python programs but with pieces replaced by NodePatterns. 32 | For example, to match a copy statement (x = y) looks like: 33 | 34 | {lhs:Name} = {rhs:Name} 35 | """ 36 | def __init__(self, s): 37 | self.template = self._parse(s) 38 | 39 | def _parse(self, s): 40 | """ 41 | Replace each node pattern with a unique variable name. 
42 | """ 43 | hole_id = 0 44 | self.var_map = {} 45 | 46 | def replace_incr(exp): 47 | nonlocal hole_id 48 | hole_id += 1 49 | name = '__hole{}'.format(hole_id) 50 | self.var_map[name] = NodePattern(exp.group(0)) 51 | return name 52 | 53 | return ast.parse(re.sub(r'{[^}]*}', replace_incr, s)).body[0] 54 | 55 | def _match(self, pattern_node, actual_node): 56 | # If pattern_node is a lone variable (Expr of Name, or plain Name) 57 | # then extract the variable name 58 | if isinstance(pattern_node, ast.Name): 59 | pattern_var = pattern_node.id 60 | elif isinstance(pattern_node, ast.Expr) and \ 61 | isinstance(pattern_node.value, ast.Name): 62 | pattern_var = pattern_node.value.id 63 | else: 64 | pattern_var = None 65 | 66 | # Check if pattern variable name corresponds to a hole in var_map 67 | if pattern_var is not None and pattern_var in self.var_map: 68 | node_pattern = self.var_map[pattern_var] 69 | 70 | # If the pattern variable type matches the actual node AST type, 71 | # then save the match and return True 72 | if node_pattern.type is None or \ 73 | isinstance(actual_node, node_pattern.type): 74 | self._matches[node_pattern.name] = actual_node 75 | return True 76 | return False 77 | 78 | # Structural AST equality, adapted from 79 | # https://stackoverflow.com/questions/3312989/elegant-way-to-test-python-asts-for-equality-not-reference-or-object-identity 80 | if type(pattern_node) is not type(actual_node): 81 | return False 82 | if isinstance(pattern_node, ast.AST): 83 | for k, v in vars(pattern_node).items(): 84 | if k in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx', '_pp'): 85 | continue 86 | if not self._match(v, getattr(actual_node, k)): 87 | return False 88 | return True 89 | elif isinstance(pattern_node, list): 90 | return all(itertools.starmap(self._match, zip(pattern_node, actual_node))) 91 | return pattern_node == actual_node 92 | 93 | def match(self, node): 94 | self._matches = {} 95 | if self._match(self.template, node): 96 | return 
self._matches.copy() 97 | return None 98 | 99 | 100 | def ast_match(pattern, node): 101 | return ASTPattern(pattern).match(node) 102 | -------------------------------------------------------------------------------- /ast_tools/stack.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Functions and classes the inspect or modify the stack 3 | ''' 4 | 5 | import copy 6 | import inspect 7 | import typing as tp 8 | import types 9 | import functools 10 | import itertools 11 | 12 | from collections import ChainMap 13 | from contextlib import contextmanager 14 | import logging 15 | 16 | _SKIP_FRAME_DEBUG_NAME = '_AST_TOOLS_STACK_DEBUG_SKIPPED_FRAME_' 17 | _SKIP_FRAME_DEBUG_VALUE = 0xdeadbeaf 18 | _SKIP_FRAME_DEBUG_STMT = f'{_SKIP_FRAME_DEBUG_NAME} = {_SKIP_FRAME_DEBUG_VALUE}' 19 | _SKIP_FRAME_DEBUG_FAIL = False 20 | 21 | class SymbolTable(tp.Mapping[str, tp.Any]): 22 | locals: tp.MutableMapping[str, tp.Any] 23 | globals: tp.Dict[str, tp.Any] 24 | 25 | def __init__(self, 26 | locals: tp.MutableMapping[str, tp.Any], 27 | globals: tp.Dict[str, tp.Any]): 28 | self.locals = locals 29 | self.globals = globals 30 | 31 | def __getitem__(self, key): 32 | try: 33 | return self.locals[key] 34 | except KeyError: 35 | pass 36 | return self.globals[key] 37 | 38 | def __iter__(self): 39 | # the implementation of chain map does things this way 40 | yield from set().union(self.locals, self.globals) 41 | 42 | def __len__(self): 43 | return len(set().union(self.locals, self.globals)) 44 | 45 | 46 | def get_symbol_table( 47 | decorators: tp.Optional[tp.Sequence[inspect.FrameInfo]] = None, 48 | copy_locals: bool = False 49 | ) -> SymbolTable: 50 | exec(_SKIP_FRAME_DEBUG_STMT) 51 | locals = ChainMap() 52 | globals = ChainMap() 53 | 54 | if decorators is None: 55 | decorators = set() 56 | else: 57 | decorators = {f.__code__ for f in decorators} 58 | decorators.add(get_symbol_table.__code__) 59 | 60 | 61 | stack = inspect.stack() 62 | for i in 
range(len(stack) - 1, 0, -1): 63 | frame = stack[i] 64 | if frame.frame.f_code in decorators: 65 | continue 66 | debug_check = frame.frame.f_locals.get(_SKIP_FRAME_DEBUG_NAME, None) 67 | if debug_check == _SKIP_FRAME_DEBUG_VALUE: 68 | if _SKIP_FRAME_DEBUG_FAIL: 69 | raise RuntimeError(f'{frame.function} @ {frame.filename}:{frame.lineno} might be leaking names') 70 | else: 71 | logging.debug(f'{frame.function} @ {frame.filename}:{frame.lineno} might be leaking names') 72 | f_locals = stack[i].frame.f_locals 73 | if copy_locals: 74 | f_locals = copy.copy(f_locals) 75 | locals = locals.new_child(f_locals) 76 | globals = globals.new_child(stack[i].frame.f_globals) 77 | return SymbolTable(locals=locals, globals=dict(globals)) 78 | 79 | def inspect_symbol_table( 80 | fn: tp.Callable, # tp.Callable[[SymbolTable, ...], tp.Any], 81 | *, 82 | decorators: tp.Optional[tp.Sequence[inspect.FrameInfo]] = None, 83 | ) -> tp.Callable: 84 | exec(_SKIP_FRAME_DEBUG_STMT) 85 | if decorators is None: 86 | decorators = () 87 | 88 | @functools.wraps(fn) 89 | def wrapped_0(*args, **kwargs): 90 | exec(_SKIP_FRAME_DEBUG_STMT) 91 | st = get_symbol_table(list(itertools.chain(decorators, [wrapped_0]))) 92 | return fn(st, *args, **kwargs) 93 | return wrapped_0 94 | 95 | 96 | # mostly equivelent to magma.ast_utils.inspect_enclosing_env 97 | def inspect_enclosing_env( 98 | fn: tp.Callable, # tp.Callable[[tp.Dict[str, tp.Any], ...], tp.Any], 99 | *, 100 | decorators: tp.Optional[tp.Sequence[inspect.FrameInfo]] = None, 101 | st: tp.Optional[SymbolTable] = None) -> tp.Callable: 102 | exec(_SKIP_FRAME_DEBUG_STMT) 103 | if decorators is None: 104 | decorators = () 105 | 106 | @functools.wraps(fn) 107 | def wrapped_0(*args, **kwargs): 108 | exec(_SKIP_FRAME_DEBUG_STMT) 109 | 110 | _st = get_symbol_table(list(itertools.chain(decorators, [wrapped_0]))) 111 | if st is not None: 112 | _st.locals.update(st) 113 | 114 | env = dict(_st.globals) 115 | env.update(_st.locals) 116 | return fn(env, *args, **kwargs) 
117 | 118 | return wrapped_0 119 | 120 | -------------------------------------------------------------------------------- /ast_tools/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | NodeTransformers of general utility 3 | """ 4 | from .if_inliner import Inliner 5 | from .loop_unroller import Unroller 6 | from .node_replacer import NodeReplacer 7 | from .normalizers import ElifToElse, NormalizeBlocks, NormalizeLines 8 | from .renamer import Renamer 9 | from .symbol_replacer import SymbolReplacer 10 | -------------------------------------------------------------------------------- /ast_tools/transformers/if_inliner.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from .symbol_replacer import replace_symbols 6 | from ..macros import inline 7 | from ast_tools.cst_utils import to_module 8 | 9 | class Inliner(cst.CSTTransformer): 10 | def __init__(self, env: tp.Mapping[str, tp.Any]): 11 | super().__init__() 12 | self.env = env 13 | 14 | def visit_If(self, node): 15 | # Control recursion order. 
16 | # Need to avoid putting an a flatten sentinel in orelse 17 | # can happen if the elif is being inlined 18 | return False 19 | 20 | def leave_If( 21 | self, 22 | original_node: cst.If, 23 | updated_node: cst.If 24 | ) -> tp.Union[cst.If, cst.RemovalSentinel, cst.FlattenSentinel[cst.BaseStatement]]: 25 | try: 26 | cond_obj = eval(to_module(updated_node.test).code, {}, self.env) 27 | is_constant = True 28 | except Exception as e: 29 | is_constant = False 30 | if is_constant and isinstance(cond_obj, inline): 31 | if cond_obj: 32 | new_body = updated_node.body.visit(self) 33 | updated_node = cst.FlattenSentinel(new_body.body) 34 | else: 35 | orelse = updated_node.orelse 36 | if orelse is None: 37 | updated_node = cst.RemoveFromParent() 38 | elif isinstance(orelse, cst.If): 39 | updated_node = updated_node.orelse.visit(self) 40 | else: 41 | assert isinstance(orelse, cst.Else) 42 | new_body = updated_node.orelse.body.visit(self) 43 | updated_node = cst.FlattenSentinel(new_body.body) 44 | else: 45 | new_body = updated_node.body.visit(self) 46 | if updated_node.orelse: 47 | new_orelse = updated_node.orelse.visit(self) 48 | else: 49 | new_orelse = None 50 | 51 | updated_node = updated_node.with_changes(body=new_body, orelse=new_orelse) 52 | 53 | return super().leave_If(original_node, updated_node) 54 | 55 | 56 | def inline_ifs(tree: cst.CSTNode, env: tp.Mapping[str, tp.Any]) -> cst.CSTNode: 57 | return tree.visit(Inliner(env)) 58 | -------------------------------------------------------------------------------- /ast_tools/transformers/loop_unroller.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | 5 | from .symbol_replacer import replace_symbols 6 | from ..macros import unroll 7 | from ast_tools.cst_utils import to_module 8 | 9 | class Unroller(cst.CSTTransformer): 10 | def __init__(self, env: tp.Mapping[str, tp.Any]): 11 | super().__init__() 12 | self.env = env 13 | 14 | def 
leave_For( 15 | self, 16 | original_node: cst.For, 17 | updated_node: cst.For) -> tp.Union[cst.For, cst.FlattenSentinel[cst.BaseStatement]]: 18 | 19 | try: 20 | iter_obj = eval(to_module(updated_node.iter).code, {}, self.env) 21 | is_constant = True 22 | except Exception as e: 23 | is_constant = False 24 | if is_constant and isinstance(iter_obj, unroll): 25 | body = [] 26 | if not isinstance(updated_node.target, cst.Name): 27 | raise NotImplementedError('Unrolling with non-name target') 28 | 29 | for i in iter_obj: 30 | if isinstance(i, int): 31 | symbol_table = {updated_node.target.value: cst.parse_expression(f'({i!r})' if i < 0 else repr(i))} 32 | for child in updated_node.body.body: 33 | body.append( 34 | replace_symbols(child, symbol_table) 35 | ) 36 | else: 37 | raise NotImplementedError('Unrolling non-int iterator') 38 | 39 | updated_node = cst.FlattenSentinel(body) 40 | return super().leave_For(original_node, updated_node) 41 | 42 | 43 | def unroll_for_loops(tree: cst.CSTNode, env: tp.Mapping[str, tp.Any]) -> cst.CSTNode: 44 | return tree.visit(Unroller(env)) 45 | -------------------------------------------------------------------------------- /ast_tools/transformers/node_replacer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing as tp 3 | 4 | import libcst as cst 5 | 6 | _NT = tp.MutableMapping[tp.Any, cst.CSTNode] 7 | class NodeReplacer(cst.CSTTransformer, metaclass=abc.ABCMeta): 8 | node_table: _NT 9 | 10 | def __init__(self, node_table: tp.Optional[_NT] = None): 11 | if node_table is None: 12 | node_table = {} 13 | self.node_table = node_table 14 | 15 | def on_leave(self, 16 | original_node: cst.CSTNode, 17 | updated_node: cst.CSTNode, 18 | ) -> tp.Union[cst.CSTNode, cst.RemovalSentinel]: 19 | key = self._get_key(original_node) 20 | if key is None or key not in self.node_table: 21 | return super().on_leave(original_node, updated_node) 22 | else: 23 | return self.node_table[key] 24 | 25 | def add_replacement(self,
node: cst.CSTNode, replacement: cst.CSTNode): 26 | key = self._get_key(node) 27 | if key is None: 28 | raise TypeError(f'Unsupported node {node}') 29 | self.node_table[key] = replacement 30 | 31 | @abc.abstractmethod 32 | def _get_key(self, node: cst.CSTNode) -> tp.Hashable: pass 33 | -------------------------------------------------------------------------------- /ast_tools/transformers/node_tracker.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import typing as tp 3 | import types 4 | from collections import defaultdict 5 | 6 | import libcst as cst 7 | import libcst.matchers as m 8 | from libcst import CSTNode, CSTNodeT, RemovalSentinel, FlattenSentinel 9 | 10 | 11 | from ast_tools.utils import BiMap 12 | 13 | class _NodeTrackerMixin: 14 | node_tracking_table: BiMap[CSTNode, CSTNode] 15 | 16 | def __init__(self, *args, **kwargs) -> None: 17 | super().__init__(*args, **kwargs) 18 | self.node_tracking_table = BiMap() 19 | self.new = set() 20 | 21 | def on_leave(self, 22 | original_node: CSTNodeT, 23 | updated_node: tp.Union[CSTNodeT, RemovalSentinel, FlattenSentinel[CSTNodeT]] 24 | ) -> tp.Union[CSTNodeT, RemovalSentinel, FlattenSentinel[CSTNodeT]]: 25 | final_node = super().on_leave(original_node, updated_node) 26 | self.track_with_children(original_node, final_node) 27 | return final_node 28 | 29 | 30 | def _track(self, 31 | original_node: CSTNode, 32 | updated_node: cst.CSTNode) -> None: 33 | if original_node in self.node_tracking_table.i: 34 | # original_node has a origin, track back 35 | for o_node in self.node_tracking_table.i[original_node]: 36 | self._track(o_node, updated_node) 37 | return 38 | 39 | if updated_node in self.node_tracking_table: 40 | # updated_node is an origin, skip it 41 | for u_node in self.node_tracking_table[updated_node]: 42 | self._track(original_node, u_node) 43 | return 44 | assert updated_node not in self.node_tracking_table, (original_node, updated_node) 45 | 
assert original_node not in self.node_tracking_table.i, (original_node, updated_node) 46 | self.node_tracking_table[original_node] = updated_node 47 | 48 | def _track_with_children(self, 49 | original_nodes: tp.Iterable[CSTNode], 50 | updated_nodes: tp.Iterable[CSTNode]) -> None: 51 | 52 | for o_node in original_nodes: 53 | for u_node in updated_nodes: 54 | if u_node not in self.node_tracking_table.i or u_node in self.new: 55 | # u_node has not been explained or has multiple origins 56 | self.new.add(u_node) 57 | self._track(o_node, u_node) 58 | self._track_with_children(original_nodes, u_node.children) 59 | 60 | def track(self, 61 | original_node: tp.Union[CSTNode, tp.Iterable[CSTNode]], 62 | updated_node: tp.Union[CSTNode, RemovalSentinel, tp.Iterable[CSTNode]]) -> None: 63 | 64 | if isinstance(updated_node, CSTNode): 65 | updated_node = updated_node, 66 | 67 | if isinstance(updated_node, RemovalSentinel) or not updated_node: 68 | return 69 | 70 | if isinstance(original_node, CSTNode): 71 | original_node = original_node, 72 | 73 | for o_node in original_node: 74 | for u_node in updated_node: 75 | self._track(o_node, u_node) 76 | 77 | 78 | def track_with_children(self, 79 | original_node: tp.Union[CSTNode, tp.Iterable[CSTNode]], 80 | updated_node: tp.Union[CSTNode, RemovalSentinel, tp.Iterable[CSTNode]]) -> None: 81 | 82 | if isinstance(updated_node, CSTNode): 83 | updated_node = updated_node, 84 | 85 | if isinstance(updated_node, RemovalSentinel) or not updated_node: 86 | return 87 | 88 | if isinstance(original_node, CSTNode): 89 | original_node = original_node, 90 | 91 | self._track_with_children(original_node, updated_node) 92 | self.new = set() 93 | 94 | def trace_origins(self, prev_table: BiMap[CSTNode, CSTNode]) -> BiMap[CSTNode, CSTNode]: 95 | new_table = BiMap() 96 | for update, origins in self.node_tracking_table.i.items(): 97 | for o in origins: 98 | for oo in prev_table.i.get(o, []): 99 | new_table[oo] = update 100 | 101 | for origin, updates in
prev_table.items(): 102 | for u in updates: 103 | for uu in self.node_tracking_table.get(u, [u]): 104 | new_table[origin] = uu 105 | 106 | 107 | return new_table 108 | 109 | 110 | class NodeTrackingTransformer( 111 | _NodeTrackerMixin, 112 | cst.CSTTransformer): pass 113 | 114 | 115 | class NodeTrackingMatcherTransformer( 116 | _NodeTrackerMixin, 117 | m.MatcherDecoratableTransformer): pass 118 | 119 | 120 | def with_tracking(transformer: tp.Type[cst.CSTTransformer]) -> tp.Type[cst.CSTTransformer]: 121 | """ Helper function that adds tracking to a transformer type """ 122 | return type(transformer.__name__, (_NodeTrackerMixin, transformer), {}) 123 | -------------------------------------------------------------------------------- /ast_tools/transformers/normalizers.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import libcst as cst 4 | import libcst.matchers as m 5 | 6 | from ast_tools.cst_utils import to_stmt 7 | 8 | class ElifToElse(m.MatcherDecoratableTransformer): 9 | @m.leave(m.If(orelse=m.If())) 10 | def _(self, 11 | original_node: cst.If, 12 | updated_node: cst.If, 13 | ) -> cst.If: 14 | orelse = cst.Else( 15 | body=cst.IndentedBlock( 16 | body=[updated_node.orelse] 17 | ) 18 | ) 19 | updated_node = updated_node.with_changes(orelse=orelse) 20 | return updated_node 21 | 22 | 23 | class NormalizeBlocks(cst.CSTTransformer): 24 | def leave_SimpleStatementSuite(self, 25 | original_node: cst.SimpleStatementSuite, 26 | updated_node: cst.SimpleStatementSuite, 27 | ) -> cst.IndentedBlock: 28 | body = tuple(to_stmt(stmt) for stmt in updated_node.body) 29 | return cst.IndentedBlock( 30 | body=body 31 | ).visit(self) 32 | 33 | 34 | class NormalizeLines(NormalizeBlocks): 35 | def _normalize_body(self, 36 | updated_node: tp.Union[cst.IndentedBlock, cst.Module]): 37 | body = [] 38 | for node in updated_node.body: 39 | if isinstance(node, cst.SimpleStatementLine): 40 | body.extend(map(to_stmt, node.body))
41 | else: 42 | body.append(node) 43 | return updated_node.with_changes(body=body) 44 | 45 | def leave_IndentedBlock(self, 46 | original_node: cst.IndentedBlock, 47 | updated_node: cst.IndentedBlock, 48 | ) -> cst.IndentedBlock: 49 | return self._normalize_body(updated_node) 50 | 51 | def leave_Module(self, 52 | original_node: cst.Module, 53 | updated_node: cst.Module 54 | ) -> cst.Module: 55 | return self._normalize_body(updated_node) 56 | 57 | def leave_SimpleStatementSuite(self, 58 | original_node: cst.SimpleStatementSuite, 59 | updated_node: cst.SimpleStatementSuite, 60 | ) -> tp.Union[cst.SimpleStatementSuite, cst.IndentedBlock]: 61 | # Only transform to IndentedBlock if node contains more than 1 statement 62 | if len(updated_node.body) > 1: 63 | return super().leave_SimpleStatementSuite( 64 | original_node, updated_node) 65 | return updated_node 66 | 67 | -------------------------------------------------------------------------------- /ast_tools/transformers/renamer.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import typing as tp 3 | 4 | 5 | class Renamer(ast.NodeTransformer): 6 | def __init__(self, name_map: tp.Mapping[str, str]): 7 | self.name_map = name_map 8 | 9 | def visit_Name(self, node): 10 | name = node.id 11 | new_name = self.name_map.get(name, name) 12 | return ast.copy_location( 13 | ast.Name( 14 | id=new_name, 15 | ctx=node.ctx, 16 | ), 17 | node, 18 | ) 19 | -------------------------------------------------------------------------------- /ast_tools/transformers/symbol_replacer.py: -------------------------------------------------------------------------------- 1 | import libcst as cst 2 | 3 | from .node_replacer import NodeReplacer 4 | 5 | class SymbolReplacer(NodeReplacer): 6 | def _get_key(self, node: cst.CSTNode): 7 | if isinstance(node, cst.Name): 8 | return node.value 9 | else: 10 | return None 11 | 12 | def replace_symbols(tree, symbol_table): 13 | return
tree.visit(SymbolReplacer(symbol_table)) 14 | -------------------------------------------------------------------------------- /ast_tools/utils.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | import typing as tp 3 | import sys 4 | 5 | # PEP 585 6 | if sys.version_info < (3, 9): 7 | from typing import MutableMapping, Mapping, MutableSet, AbstractSet as Set, Iterator, Iterable 8 | from typing import MappingView, ItemsView, KeysView, ValuesView 9 | else: 10 | from collections.abc import MutableMapping, Mapping, MutableSet, Set, Iterator, Iterable 11 | from collections.abc import MappingView, ItemsView, KeysView, ValuesView 12 | 13 | 14 | T = tp.TypeVar('T') 15 | S = tp.TypeVar('S') 16 | 17 | # Because the setitem signature does not match mutable mapping 18 | # we inherit from mapping. We lose MutableMapping mixin methods 19 | # for correct typing but we don't use them anyway 20 | class BiMap(Mapping[T, Set[S]]): 21 | _d: MutableMapping[T, MutableSet[S]] 22 | _r: MutableMapping[S, MutableSet[T]] 23 | 24 | def __init__(self, d: tp.Optional['BiMap[T, S]'] = None) -> None: 25 | self._d = {} 26 | self._r = {} 27 | if d is not None: 28 | for k, v in d.items(): 29 | for vv in v: 30 | self[k] = vv 31 | 32 | def __getitem__(self, idx: T) -> Set[S]: 33 | return frozenset(self._d[idx]) 34 | 35 | def __setitem__(self, idx: T, val: S) -> None: 36 | self._d.setdefault(idx, set()).add(val) 37 | self._r.setdefault(val, set()).add(idx) 38 | 39 | def __delitem__(self, idx: T) -> None: 40 | for val in self._d[idx]: 41 | self._r[val].remove(idx) 42 | if not self._r[val]: 43 | del self._r[val] 44 | del self._d[idx] 45 | 46 | def __iter__(self) -> Iterator[T]: 47 | return iter(self._d) 48 | 49 | def __len__(self) -> int: 50 | return len(self._d) 51 | 52 | def __eq__(self, other) -> bool: 53 | if isinstance(other, type(self)): 54 | if self._d == other._d: 55 | assert self._r == other._r 56 | return True 57 | else: 58 | 
assert self._r != other._r 59 | return False 60 | else: 61 | return NotImplemented 62 | 63 | def __ne__(self, other) -> bool: 64 | if isinstance(other, type(self)): 65 | return not self == other 66 | else: 67 | return NotImplemented 68 | 69 | @property 70 | def i(self) -> 'BiMap[S, T]': 71 | i: BiMap[S, T] = BiMap() 72 | i._d = self._r 73 | i._r = self._d 74 | return i 75 | 76 | def __repr__(self) -> str: 77 | kv = map(': '.join, (map(repr, items) for items in self.items())) 78 | return f'{type(self).__name__}(' + ', '.join(kv) + ')' 79 | 80 | 81 | def _attest(self: BiMap[T, S]) -> None: 82 | for dk, dvals in self._d.items(): 83 | for dv in dvals: 84 | assert dk in self._r[dv] 85 | 86 | for rk, rvals in self._r.items(): 87 | for rv in rvals: 88 | assert rk in self._d[rv] 89 | 90 | F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) 91 | 92 | def _with_attestation(f: F) -> F: 93 | @ft.wraps(f) 94 | def wrapper(self: BiMap[T, S], *args, **kwargs): 95 | _attest(self) 96 | r_val = f(self, *args, **kwargs) 97 | _attest(self) 98 | return r_val 99 | return tp.cast(F, wrapper) 100 | 101 | class _BiMapDebug(BiMap[T, S]): 102 | def __init__(self, d: tp.Optional[BiMap[T, S]] = None) -> None: 103 | super().__init__(d) 104 | _attest(self) 105 | 106 | @property 107 | def i(self) -> '_BiMapDebug[S, T]': 108 | _attest(self) 109 | i: _BiMapDebug[S, T] = _BiMapDebug() 110 | i._d = self._r 111 | i._r = self._d 112 | _attest(i) 113 | return i 114 | 115 | __getitem__ = _with_attestation(BiMap.__getitem__) 116 | __setitem__ = _with_attestation(BiMap.__setitem__) 117 | __delitem__ = _with_attestation(BiMap.__delitem__) 118 | __iter__ = _with_attestation(BiMap.__iter__) 119 | __len__ = _with_attestation(BiMap.__len__) 120 | __eq__ = _with_attestation(BiMap.__eq__) 121 | __ne__ = _with_attestation(BiMap.__ne__) 122 | -------------------------------------------------------------------------------- /ast_tools/visitors/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .collect_names import * 2 | from .collect_targets import * 3 | from .used_names import * 4 | -------------------------------------------------------------------------------- /ast_tools/visitors/collect_names.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a visitor that collects all names contained in an AST 3 | """ 4 | 5 | import collections.abc 6 | import typing as tp 7 | 8 | import libcst as cst 9 | from libcst.metadata import ExpressionContext, ExpressionContextProvider 10 | 11 | _OptContext = tp.Optional[ExpressionContext] 12 | 13 | class NameCollector(cst.CSTVisitor): 14 | """ 15 | Collect all instances of `Name` in a CST. 16 | """ 17 | METADATA_DEPENDENCIES = (ExpressionContextProvider,) 18 | 19 | def __init__(self, ctx: tp.Union[tp.Sequence[_OptContext], _OptContext] = ()): 20 | """ 21 | Set `ctx` to `STORE` or `LOAD` or `DEL` to filter names. 22 | `ctx=None` will filter for names which are not names in the AST 23 | (e.g. attrs). 
24 | """ 25 | super().__init__() 26 | self.names = set() 27 | if ctx == (): 28 | ctx = frozenset(( 29 | ExpressionContext.LOAD, 30 | ExpressionContext.STORE, 31 | ExpressionContext.DEL,)) 32 | elif isinstance(ctx, collections.abc.Iterable): 33 | ctx = frozenset(ctx) 34 | else: 35 | ctx = frozenset((ctx,)) 36 | self.ctx = ctx 37 | 38 | def visit_Name(self, node: cst.Name): 39 | ctx = self.get_metadata(ExpressionContextProvider, node) 40 | if ctx in self.ctx: 41 | self.names.add(node.value) 42 | 43 | 44 | def collect_names(tree, ctx=()): 45 | """ 46 | Convenience wrapper for NameCollector 47 | """ 48 | visitor = NameCollector(ctx) 49 | wrapper = cst.MetadataWrapper(tree, unsafe_skip_copy=True) 50 | wrapper.visit(visitor) 51 | return visitor.names 52 | -------------------------------------------------------------------------------- /ast_tools/visitors/collect_targets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a visitor that collects all assignment targets contained in an AST 3 | """ 4 | import functools 5 | 6 | import libcst as cst 7 | 8 | def _filt(t): 9 | def wrapped(obj): 10 | return isinstance(obj, t) 11 | return wrapped 12 | 13 | class TargetCollector(cst.CSTVisitor): 14 | def __init__(self, target_filter=None): 15 | if target_filter is None: 16 | target_filter = cst.CSTNode 17 | self.target_filter = _filt(target_filter) 18 | self.targets = [] 19 | 20 | def visit_Assign(self, node: cst.Assign): 21 | self.targets.extend(filter(self.target_filter, (n.target for n in node.targets))) 22 | 23 | 24 | def collect_targets(tree, target_filter=None): 25 | visitor = TargetCollector(target_filter) 26 | tree.visit(visitor) 27 | return visitor.targets 28 | 29 | 30 | -------------------------------------------------------------------------------- /ast_tools/visitors/node_finder.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import ast 3 | from copy import 
deepcopy 4 | 5 | class NodeFinder(ast.NodeVisitor, metaclass=abc.ABCMeta): 6 | def __init__(self, node): 7 | key = self._get_key(node) 8 | if key is None: 9 | raise TypeError(f'Unsupported node {node}') 10 | 11 | self.key = key 12 | self.target = None 13 | 14 | def visit(self, node): 15 | if self.target is not None: 16 | return 17 | 18 | key = self._get_key(node) 19 | if key is None or key != self.key: 20 | return super().visit(node) 21 | else: 22 | self.target = node 23 | 24 | @abc.abstractmethod 25 | def _get_key(self, node): pass 26 | -------------------------------------------------------------------------------- /ast_tools/visitors/used_names.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | import typing as tp 3 | 4 | import libcst as cst 5 | from libcst.metadata import ScopeProvider 6 | 7 | from ast_tools.cst_utils import to_module 8 | 9 | class UsedNames(cst.CSTVisitor): 10 | METADATA_DEPENDENCIES = (ScopeProvider,) 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.names: tp.MutableSet[str] = set() 15 | self.scope: tp.Optional[cst.metadata.Scope] = None 16 | 17 | def on_visit(self, node: cst.CSTNode): 18 | if self.scope is None: 19 | self.scope = self.get_metadata(ScopeProvider, node) 20 | 21 | return super().on_visit(node) 22 | 23 | 24 | def visit_Name(self, node: cst.Name): 25 | if node in self.scope.assignments: 26 | self.names.add(node.value) 27 | 28 | @lru_cache() 29 | def used_names(tree: cst.CSTNode): 30 | tree = to_module(tree) 31 | visitor = UsedNames() 32 | wrapper = cst.MetadataWrapper(tree, unsafe_skip_copy=True) 33 | wrapper.visit(visitor) 34 | return visitor.names 35 | -------------------------------------------------------------------------------- /docs/developer.md: -------------------------------------------------------------------------------- 1 | # Adding passes 2 | The `ast_tools` library is designed around the `apply_passes` decorator that is 3 | 
used as follows: 4 | ```python 5 | @apply_passes([pass1(), pass2()]) 6 | def foo(...): ... 7 | ``` 8 | 9 | The simplest way to extend the library is to define a new pass that works with 10 | this decorator. This is achieved by subclassing the `Pass` abstract class, see 11 | the definition 12 | [here](https://github.com/leonardt/ast_tools/blob/master/ast_tools/passes/base.py#L13). 13 | 14 | The essential method is the `rewrite` method: 15 | ```python 16 | def rewrite(self, 17 | tree: cst.CSTNode, 18 | env: SymbolTable, 19 | metadata: tp.MutableMapping, 20 | ) -> PASS_ARGS_T: 21 | return tree, env, metadata 22 | ``` 23 | 24 | There are three arguments: 25 | 1. `tree` -- the CST (from [libcst](https://libcst.readthedocs.io/en/latest/)) 26 | of the function being rewritten. 27 | 2. `env` -- the environment of the function being rewritten (the `SymbolTable` 28 | class is defined 29 | [here](https://github.com/leonardt/ast_tools/blob/master/ast_tools/stack.py#L21) 30 | ). 31 | 3. `metadata` -- a mapping containing arbitrary information either provided by 32 | the user or previous passes 33 | 34 | Each of these three arguments must be returned in the same order from the `rewrite` method. 35 | 36 | The ordering of passes is defined by the arguments to `apply_passes` and each 37 | pass will be called with the latest `tree`, `env`, and `metadata`. Passes can 38 | share information using the `metadata` mapping, and can update the `env` that 39 | will be used to execute the final `tree`. 40 | 41 | ## Pass Examples 42 | 43 | The [loop unrolling 44 | pass](https://github.com/leonardt/ast_tools/blob/master/ast_tools/passes/loop_unroll.py) 45 | is quite simple, with most of its logic dispatching to the [loop unroller 46 | transformer](https://github.com/leonardt/ast_tools/blob/master/ast_tools/transformers/loop_unroller.py). 
47 | For more information on how to write a `Transformer`, see the [libcst 48 | documentation](https://libcst.readthedocs.io/en/latest/tutorial.html#Build-Visitor-or-Transformer). 49 | Notice that the transformer relies on the `env` table to evaluate the unroll 50 | arguments. 51 | 52 | The 53 | [if_inliner](https://github.com/leonardt/ast_tools/blob/master/ast_tools/transformers/if_inliner.py) 54 | transformer is another good place to start, which similarly relies on the `env` 55 | to evaluate `if` statements at "macro" time. 56 | 57 | The 58 | [if_to_phi](https://github.com/leonardt/ast_tools/blob/master/ast_tools/passes/if_to_phi.py) 59 | pass provides an example of how `env` might be modified by a pass. 60 | 61 | The 62 | [debug](https://github.com/leonardt/ast_tools/blob/master/ast_tools/passes/debug.py) 63 | pass provides an example of how metadata might be used. 64 | 65 | After reviewing these examples, you're ready to look at the full suite of 66 | standard 67 | [passes](https://github.com/leonardt/ast_tools/tree/master/ast_tools/passes) 68 | and 69 | [transformers](https://github.com/leonardt/ast_tools/tree/master/ast_tools/transformers). 
70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | setup script for package 3 | ''' 4 | 5 | from setuptools import setup 6 | from setuptools.command.build_py import build_py 7 | from setuptools.command.develop import develop 8 | from os import path 9 | import util.generate_ast.generate as generate 10 | 11 | PACKAGE_NAME = 'ast_tools' 12 | 13 | with open('README.md', "r") as fh: 14 | LONG_DESCRIPTION = fh.read() 15 | 16 | class Install(build_py): 17 | def run(self, *args, **kwargs): 18 | self.generated_outputs = [] 19 | if not self.dry_run: 20 | src = generate.generate_immutable_ast() 21 | output_dir = path.join(self.build_lib, PACKAGE_NAME) 22 | self.mkpath(output_dir) 23 | output_file = path.join(output_dir, 'immutable_ast.py') 24 | self.announce(f'generating {output_file}', 2) 25 | with open(output_file, 'w') as f: 26 | f.write(src) 27 | self.generated_outputs.append(output_file) 28 | super().run(*args, **kwargs) 29 | 30 | def get_outputs(self, *args, **kwargs): 31 | outputs = super().get_outputs(*args, **kwargs) 32 | outputs.extend(self.generated_outputs) 33 | return outputs 34 | 35 | 36 | class Develop(develop): 37 | def run(self, *args, **kwargs): 38 | if not self.dry_run: 39 | src = generate.generate_immutable_ast() 40 | output_file = path.join(PACKAGE_NAME, 'immutable_ast.py') 41 | self.announce(f'generating {output_file}', 2) 42 | with open(output_file, 'w') as f: 43 | f.write(src) 44 | super().run(*args, **kwargs) 45 | 46 | setup( 47 | cmdclass={ 48 | 'build_py': Install, 49 | 'develop': Develop, 50 | }, 51 | name='ast_tools', 52 | url='https://github.com/leonardt/ast_tools', 53 | author='Leonard Truong', 54 | author_email='lenny@cs.stanford.edu', 55 | version='0.1.8', 56 | description='Toolbox for working with the Python AST', 57 | scripts=[], 58 | packages=[ 59 | f'{PACKAGE_NAME}', 60 | f'{PACKAGE_NAME}.cst_utils', 61 
| f'{PACKAGE_NAME}.metadata', 62 | f'{PACKAGE_NAME}.passes', 63 | f'{PACKAGE_NAME}.transformers', 64 | f'{PACKAGE_NAME}.visitors', 65 | ], 66 | install_requires=[ 67 | 'astor', 68 | 'libcst', 69 | ], 70 | long_description=LONG_DESCRIPTION, 71 | long_description_content_type='text/markdown' 72 | ) 73 | -------------------------------------------------------------------------------- /tests/test_apply_passes.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from ast_tools.passes import apply_passes, if_inline 4 | from ast_tools.macros import inline 5 | from ast_tools.common import to_source 6 | 7 | 8 | def test_apply_with_prologue(): 9 | class foo(apply_passes): 10 | def prologue(self, tree, env, metadata): 11 | env.locals['x'] = True 12 | return tree, env, metadata 13 | 14 | x = False 15 | 16 | @foo([if_inline()]) 17 | def bar(): 18 | if inline(x): 19 | return 0 20 | else: 21 | return 1 22 | 23 | assert inspect.getsource(bar) == '''\ 24 | def bar(): 25 | return 0 26 | ''' 27 | 28 | 29 | def test_apply_with_epilogue(): 30 | class foo(apply_passes): 31 | def epilogue(self, tree, env, metadata): 32 | expected = """\ 33 | @foo([if_inline()]) 34 | def bar(): 35 | return 1 36 | """ 37 | assert to_source(tree) == expected 38 | return tree, env, metadata 39 | 40 | x = False 41 | 42 | @foo([if_inline()]) 43 | def bar(): 44 | if inline(x): 45 | return 0 46 | else: 47 | return 1 48 | -------------------------------------------------------------------------------- /tests/test_assert_remover.py: -------------------------------------------------------------------------------- 1 | from ast_tools.passes import apply_passes, remove_asserts 2 | import inspect 3 | 4 | def test_remove_asserts(): 5 | @apply_passes([remove_asserts()]) 6 | def foo(): 7 | if True: 8 | assert False 9 | for i in range(10): 10 | assert i == 0 11 | assert name_error 12 | 13 | foo() 14 | 15 | assert inspect.getsource(foo) == f'''\ 16 | def foo(): 17 | 
if True: 18 | pass 19 | for i in range(10): 20 | pass 21 | pass 22 | ''' 23 | -------------------------------------------------------------------------------- /tests/test_bool_to_bit.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | 4 | import pytest 5 | 6 | 7 | from ast_tools.passes import apply_passes, bool_to_bit 8 | 9 | def test_and(): 10 | @apply_passes([bool_to_bit()]) 11 | def and_f(x, y): 12 | return x and y 13 | 14 | assert inspect.getsource(and_f) == '''\ 15 | def and_f(x, y): 16 | return x & y 17 | ''' 18 | 19 | def test_or(): 20 | @apply_passes([bool_to_bit()]) 21 | def or_f(x, y): 22 | return x or y 23 | 24 | assert inspect.getsource(or_f) == '''\ 25 | def or_f(x, y): 26 | return x | y 27 | ''' 28 | 29 | def test_not(): 30 | @apply_passes([bool_to_bit()]) 31 | def not_f(x): 32 | return not x 33 | 34 | assert inspect.getsource(not_f) == '''\ 35 | def not_f(x): 36 | return ~x 37 | ''' 38 | 39 | def test_xor(): 40 | @apply_passes([bool_to_bit()]) 41 | def xor(x, y): 42 | return x and not y or not x and y 43 | 44 | assert inspect.getsource(xor) == '''\ 45 | def xor(x, y): 46 | return x & ~y | ~x & y 47 | ''' 48 | 49 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import astor 3 | 4 | import libcst as cst 5 | 6 | from ast_tools.common import get_ast, get_cst, gen_free_name, gen_free_prefix, to_source 7 | from ast_tools.stack import SymbolTable 8 | from ast_tools.passes import apply_passes 9 | 10 | 11 | def test_get_ast(): 12 | def f(): pass 13 | f_str = 'def f(): pass' 14 | ast_str_0 = astor.dump_tree(get_ast(f)) 15 | ast_str_1 = astor.dump_tree(ast.parse(f_str).body[0]) 16 | assert ast_str_0 == ast_str_1 17 | 18 | def test_get_cst(): 19 | def f(): pass 20 | f_str = 'def f(): pass' 21 | ast_str_0 = to_source(get_cst(f)) 22 | 
ast_str_1 = to_source(cst.parse_module(f_str).body[0]) 23 | assert ast_str_0 == ast_str_1 24 | 25 | def test_gen_free_name(): 26 | src = ''' 27 | class P: 28 | P5 = 1 29 | def __init__(self): self.y = 0 30 | def P0(): 31 | return P.P5 32 | P1 = P0() 33 | ''' 34 | tree = cst.parse_module(src) 35 | env = SymbolTable({}, {}) 36 | 37 | free_name = gen_free_name(tree, env) 38 | assert free_name == '_auto_name_0' 39 | 40 | free_name = gen_free_name(tree, env, prefix='P') 41 | assert free_name == 'P2' 42 | env = SymbolTable({'P3': 'foo'}, {}) 43 | free_name = gen_free_name(tree, env, prefix='P') 44 | assert free_name == 'P2' 45 | env = SymbolTable({'P3': 'foo'}, {'P2' : 'bar'}) 46 | free_name = gen_free_name(tree, env, prefix='P') 47 | assert free_name == 'P4' 48 | 49 | def test_gen_free_prefix(): 50 | src = ''' 51 | class P: 52 | P5 = 1 53 | def __init__(self): self.y = 0 54 | def P0(): 55 | return P.P5 56 | P1 = P0() 57 | ''' 58 | tree = cst.parse_module(src) 59 | env = SymbolTable({}, {}) 60 | 61 | free_prefix = gen_free_prefix(tree, env) 62 | assert free_prefix == '_auto_prefix_0' 63 | 64 | free_prefix = gen_free_prefix(tree, env, 'P') 65 | assert free_prefix == 'P2' 66 | 67 | def test_exec_in_file(): 68 | x = 3 69 | def foo(): 70 | return x 71 | assert foo() == 3 72 | 73 | @apply_passes(passes=()) 74 | def foo(): 75 | return x 76 | 77 | assert foo() == 3 78 | -------------------------------------------------------------------------------- /tests/test_cse.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | 4 | import pytest 5 | 6 | from ast_tools.passes import apply_ast_passes, cse, ssa, debug 7 | 8 | @pytest.mark.skip() 9 | def test_basic(): 10 | @apply_ast_passes([ssa(), cse()]) 11 | def foo(a, b, c): 12 | x = a + b 13 | y = a + b + c 14 | z = a + b - c 15 | return x + y + z 16 | 17 | assert inspect.getsource(foo) == '''\ 18 | def foo(a, b, c): 19 | __common_expr0 = a + b 20 | x0 = __common_expr0 
21 | y0 = __common_expr0 + c 22 | z0 = __common_expr0 - c 23 | __return_value0 = x0 + y0 + z0 24 | return __return_value0 25 | ''' 26 | 27 | -------------------------------------------------------------------------------- /tests/test_if_to_phi.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | 4 | import pytest 5 | 6 | from ast_tools.passes import apply_passes, if_to_phi 7 | 8 | 9 | def mux(select, t, f): 10 | return t if select else f 11 | 12 | 13 | @pytest.mark.parametrize('phi_args, expected_name', [ 14 | ([mux], '__phi'), # test passing function 15 | (['mux'], 'mux'), # test passing name 16 | ([mux, 'foo'], 'foo'), # test passing function and free name 17 | ([mux, 'mux'], 'mux0'), # test passing function and used name 18 | ]) 19 | def test_basic(phi_args, expected_name): 20 | def basic(s): 21 | return 0 if s else 1 22 | 23 | phi_basic = apply_passes([if_to_phi(*phi_args)])(basic) 24 | 25 | for s in (True, False): 26 | assert basic(s) == phi_basic(s) 27 | 28 | assert inspect.getsource(phi_basic) == f'''\ 29 | def basic(s): 30 | return {expected_name}(s, 0, 1) 31 | ''' 32 | 33 | 34 | @pytest.mark.parametrize('phi_args, expected_name', [ 35 | ([mux], '__phi'), # test passing function 36 | (['mux'], 'mux'), # test passing name 37 | ([mux, 'foo'], 'foo'), # test passing function and free name 38 | ([mux, 'mux'], 'mux0'), # test passing function and used name 39 | ]) 40 | def test_nested(phi_args, expected_name): 41 | def nested(s, t): 42 | return 0 if s else 1 if t else 2 43 | 44 | phi_nested = apply_passes([if_to_phi(*phi_args)])(nested) 45 | 46 | for s in (True, False): 47 | for t in (True, False): 48 | assert nested(s, t) == phi_nested(s, t) 49 | 50 | assert inspect.getsource(phi_nested) == f'''\ 51 | def nested(s, t): 52 | return {expected_name}(s, 0, {expected_name}(t, 1, 2)) 53 | ''' 54 | -------------------------------------------------------------------------------- 
/tests/test_immutable_ast.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import ast 3 | import sys 4 | 5 | import inspect 6 | from ast_tools import immutable_ast 7 | from ast_tools.immutable_ast import ImmutableMeta 8 | 9 | 10 | trees = [] 11 | 12 | # inspect is about the largest module I know 13 | # hopefully it has a diverse ast 14 | for mod in (immutable_ast, inspect, ast, pytest): 15 | with open(mod.__file__, 'r') as f: 16 | text = f.read() 17 | tree = ast.parse(text) 18 | trees.append(tree) 19 | 20 | 21 | @pytest.mark.parametrize("tree", trees) 22 | def test_mutable_to_immutable(tree): 23 | def _test(tree, itree): 24 | if isinstance(tree, ast.AST): 25 | assert isinstance(itree, immutable_ast.AST) 26 | assert isinstance(tree, type(itree)) 27 | assert tree._fields == itree._fields 28 | assert ImmutableMeta._mutable_to_immutable[type(tree)] is type(itree) 29 | for field, value in ast.iter_fields(tree): 30 | _test(value, getattr(itree, field)) 31 | elif isinstance(tree, list): 32 | assert isinstance(itree, tuple) 33 | assert len(tree) == len(itree) 34 | for c, ic in zip(tree, itree): 35 | _test(c, ic) 36 | else: 37 | assert tree == itree 38 | 39 | 40 | itree = immutable_ast.immutable(tree) 41 | _test(tree, itree) 42 | 43 | @pytest.mark.parametrize("tree", trees) 44 | def test_immutable_to_mutable(tree): 45 | def _test(tree, mtree): 46 | assert type(tree) is type(mtree) 47 | if isinstance(tree, ast.AST): 48 | for field, value in ast.iter_fields(tree): 49 | _test(value, getattr(mtree, field)) 50 | elif isinstance(tree, list): 51 | assert len(tree) == len(mtree) 52 | for c, mc in zip(tree, mtree): 53 | _test(c, mc) 54 | else: 55 | assert tree == mtree 56 | 57 | itree = immutable_ast.immutable(tree) 58 | mtree = immutable_ast.mutable(itree) 59 | _test(tree, mtree) 60 | 61 | 62 | @pytest.mark.parametrize("tree", trees) 63 | def test_eq(tree): 64 | itree = immutable_ast.immutable(tree) 65 | jtree = 
immutable_ast.immutable(tree) 66 | assert itree == jtree 67 | assert hash(itree) == hash(jtree) 68 | 69 | def test_mutate(): 70 | node = immutable_ast.Name(id='foo', ctx=immutable_ast.Load()) 71 | # can add metadata to a node 72 | node.random = 0 73 | del node.random 74 | 75 | # but cant change its fields 76 | for field in node._fields: 77 | with pytest.raises(AttributeError): 78 | setattr(node, field, 'bar') 79 | 80 | with pytest.raises(AttributeError): 81 | delattr(node, field) 82 | 83 | 84 | def test_construct_from_mutable(): 85 | asts = [ast.Name(id='foo', ctx=ast.Store())] 86 | node = ( 87 | immutable_ast.Module(asts) 88 | if sys.version_info < (3, 8) 89 | else immutable_ast.Module(asts, type_ignores=None) 90 | ) 91 | 92 | assert isinstance(node.body, tuple) 93 | assert type(node.body[0]) is immutable_ast.Name 94 | assert type(node.body[0].ctx) is immutable_ast.Store 95 | -------------------------------------------------------------------------------- /tests/test_inline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import pytest 4 | 5 | from ast_tools.macros import inline 6 | from ast_tools.passes import apply_passes, if_inline 7 | 8 | 9 | @pytest.mark.parametrize('cond', [False, True]) 10 | def test_basic(cond): 11 | def basic(): 12 | if inline(cond): 13 | return 0 14 | else: 15 | return 1 16 | 17 | inlined = apply_passes([if_inline()])(basic) 18 | inlined_src = inspect.getsource(inlined) 19 | assert inlined_src == f'''\ 20 | def basic(): 21 | return {0 if cond else 1} 22 | ''' 23 | assert basic() == inlined() 24 | 25 | @pytest.mark.parametrize('cond_0', [False, True]) 26 | @pytest.mark.parametrize('cond_1', [False, True]) 27 | def test_nested(cond_0, cond_1): 28 | def nested(): 29 | if inline(cond_0): 30 | if inline(cond_1): 31 | return 3 32 | else: 33 | return 2 34 | else: 35 | if inline(cond_1): 36 | return 1 37 | else: 38 | return 0 39 | 40 | inlined = apply_passes([if_inline()])(nested) 41 
| assert inspect.getsource(inlined) == f'''\ 42 | def nested(): 43 | return {nested()} 44 | ''' 45 | assert nested() == inlined() 46 | 47 | @pytest.mark.parametrize('cond_0', [False, True]) 48 | @pytest.mark.parametrize('cond_1', [False, True]) 49 | def test_inner_inline(cond_0, cond_1): 50 | def nested(cond): 51 | if cond: 52 | if inline(cond_1): 53 | return 3 54 | else: 55 | return 2 56 | else: 57 | if inline(cond_1): 58 | return 1 59 | else: 60 | return 0 61 | 62 | inlined = apply_passes([if_inline()])(nested) 63 | assert inspect.getsource(inlined) == f'''\ 64 | def nested(cond): 65 | if cond: 66 | return {3 if cond_1 else 2} 67 | else: 68 | return {1 if cond_1 else 0} 69 | ''' 70 | assert nested(cond_0) == inlined(cond_0) 71 | 72 | @pytest.mark.parametrize('cond_0', [False, True]) 73 | @pytest.mark.parametrize('cond_1', [False, True]) 74 | def test_outer_inline(cond_0, cond_1): 75 | def nested(cond): 76 | if inline(cond_0): 77 | if cond: 78 | return 3 79 | else: 80 | return 2 81 | else: 82 | if cond: 83 | return 1 84 | else: 85 | return 0 86 | 87 | inlined = apply_passes([if_inline()])(nested) 88 | assert inspect.getsource(inlined) == f'''\ 89 | def nested(cond): 90 | if cond: 91 | return {3 if cond_0 else 1} 92 | else: 93 | return {2 if cond_0 else 0} 94 | ''' 95 | assert nested(cond_1) == inlined(cond_1) 96 | 97 | 98 | @pytest.mark.parametrize('cond', [False, True]) 99 | def test_if_no_else(cond): 100 | def if_no_else(): 101 | x = 0 102 | if inline(cond): 103 | x = 1 104 | return x 105 | 106 | if cond: 107 | gold_src = '''\ 108 | def if_no_else(): 109 | x = 0 110 | x = 1 111 | return x 112 | ''' 113 | else: 114 | gold_src = '''\ 115 | def if_no_else(): 116 | x = 0 117 | return x 118 | ''' 119 | inlined = apply_passes([if_inline()])(if_no_else) 120 | assert inspect.getsource(inlined) == gold_src 121 | assert if_no_else() == inlined() 122 | 123 | 124 | @pytest.mark.parametrize('cond_0', [False, True]) 125 | @pytest.mark.parametrize('cond_1', [False, True]) 126 
| def test_if_elif(cond_0, cond_1): 127 | def if_elif(): 128 | x = 0 129 | if inline(cond_0): 130 | x = 1 131 | elif inline(cond_1): 132 | x = 2 133 | return x 134 | 135 | gold_lines = ['def if_elif():', '{tab}x = 0'] 136 | if cond_0: 137 | gold_lines.append('{tab}x = 1') 138 | elif cond_1: 139 | gold_lines.append('{tab}x = 2') 140 | 141 | gold_lines.append('{tab}return x\n') 142 | 143 | gold_src = '\n'.join(gold_lines).format(tab=' ') 144 | inlined = apply_passes([if_inline()])(if_elif) 145 | assert inspect.getsource(inlined) == gold_src 146 | assert if_elif() == inlined() 147 | 148 | 149 | def test_readme_example(): 150 | y = True 151 | @apply_passes([if_inline()]) 152 | def foo(x): 153 | if inline(y): 154 | return x + 1 155 | else: 156 | return x - 1 157 | assert inspect.getsource(foo) == f"""\ 158 | def foo(x): 159 | return x + 1 160 | """ 161 | -------------------------------------------------------------------------------- /tests/test_normalizers.py: -------------------------------------------------------------------------------- 1 | import libcst as cst 2 | 3 | import pytest 4 | 5 | from ast_tools.cst_utils import to_module 6 | from ast_tools.transformers.normalizers import ElifToElse 7 | from ast_tools.transformers.normalizers import NormalizeBlocks 8 | from ast_tools.transformers.normalizers import NormalizeLines 9 | 10 | def test_elif_to_else(): 11 | src = '''\ 12 | if x: 13 | foo() 14 | elif y: 15 | bar() 16 | else: 17 | foo_bar() 18 | ''' 19 | 20 | gold_src = '''\ 21 | if x: 22 | foo() 23 | else: 24 | if y: 25 | bar() 26 | else: 27 | foo_bar() 28 | ''' 29 | 30 | tree = cst.parse_statement(src) 31 | norm = tree.visit(ElifToElse()) 32 | gold = cst.parse_statement(gold_src) 33 | assert norm.deep_equals(gold) 34 | 35 | 36 | def test_normalize_blocks_if(): 37 | src = '''\ 38 | if x: foo() 39 | elif y: bar() 40 | else: foo_bar() 41 | ''' 42 | 43 | gold_src = '''\ 44 | if x: 45 | foo() 46 | elif y: 47 | bar() 48 | else: 49 | foo_bar() 50 | ''' 51 | 52 | tree = 
cst.parse_statement(src) 53 | norm = tree.visit(NormalizeBlocks()) 54 | gold = cst.parse_statement(gold_src) 55 | assert norm.deep_equals(gold) 56 | 57 | def test_normalize_blocks_def(): 58 | src = '''\ 59 | def f(): return 0 60 | ''' 61 | 62 | gold_src = '''\ 63 | def f(): 64 | return 0 65 | ''' 66 | 67 | tree = cst.parse_statement(src) 68 | norm = tree.visit(NormalizeBlocks()) 69 | gold = cst.parse_statement(gold_src) 70 | assert norm.deep_equals(gold) 71 | 72 | def test_normalize_lines_module(): 73 | src = '''\ 74 | x = 1;y = 2 75 | ''' 76 | 77 | gold_src = '''\ 78 | x = 1; 79 | y = 2 80 | ''' 81 | 82 | tree = cst.parse_module(src) 83 | norm = tree.visit(NormalizeLines()) 84 | gold = cst.parse_module(gold_src) 85 | assert norm.deep_equals(gold) 86 | 87 | def test_normalize_lines_block(): 88 | 89 | src = '''\ 90 | if x: 91 | y = 0;z = 1 92 | ''' 93 | 94 | gold_src = '''\ 95 | if x: 96 | y = 0; 97 | z = 1 98 | ''' 99 | 100 | tree = cst.parse_statement(src) 101 | norm = tree.visit(NormalizeLines()) 102 | gold = cst.parse_statement(gold_src) 103 | assert norm.deep_equals(gold) 104 | 105 | 106 | def test_normalize_lines_partial(): 107 | src = '''\ 108 | if x: y = 0;z = 0 109 | else: z = 1 110 | ''' 111 | 112 | gold_src = '''\ 113 | if x: 114 | y = 0; 115 | z = 0 116 | else: z = 1 117 | ''' 118 | 119 | tree = cst.parse_statement(src) 120 | norm = tree.visit(NormalizeLines()) 121 | gold = cst.parse_statement(gold_src) 122 | norm_code = to_module(norm).code 123 | gold_code = to_module(gold).code 124 | assert norm.deep_equals(gold) 125 | 126 | -------------------------------------------------------------------------------- /tests/test_passes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | 4 | import pytest 5 | 6 | import ast_tools 7 | from ast_tools.passes import debug, apply_ast_passes, apply_cst_passes 8 | from ast_tools.passes.util import begin_rewrite, end_rewrite 9 | from ast_tools.stack 
import SymbolTable 10 | 11 | def attr_setter(attr): 12 | def wrapper(fn): 13 | assert not hasattr(fn, attr) 14 | setattr(fn, attr, True) 15 | return fn 16 | return wrapper 17 | 18 | wrapper1 = attr_setter('a') 19 | wrapper2 = attr_setter('b') 20 | wrapper3 = attr_setter('c') 21 | 22 | def test_begin_end(): 23 | with pytest.warns(DeprecationWarning): 24 | @wrapper1 25 | @end_rewrite(file_name='test_begin_end.py') 26 | @begin_rewrite() 27 | @wrapper2 28 | def foo(): 29 | pass 30 | 31 | assert foo.a 32 | assert foo.b 33 | assert inspect.getsource(foo) == '''\ 34 | @wrapper1 35 | @wrapper2 36 | def foo(): 37 | pass 38 | ''' 39 | 40 | @pytest.mark.parametrize('deco', [apply_ast_passes, apply_cst_passes]) 41 | def test_apply(deco): 42 | @wrapper1 43 | @deco([], file_name='test_apply.py') 44 | @wrapper2 45 | def foo(): 46 | pass 47 | 48 | assert foo.a 49 | assert foo.b 50 | assert inspect.getsource(foo) == '''\ 51 | @wrapper1 52 | @wrapper2 53 | def foo(): 54 | pass 55 | ''' 56 | 57 | 58 | def test_apply_mixed(): 59 | @wrapper1 60 | @apply_ast_passes([], file_name='test_apply_mixed_ast.py') 61 | @wrapper2 62 | @apply_cst_passes([], file_name='test_apply_mixed_cst.py') 63 | @wrapper3 64 | def foo(): 65 | pass 66 | 67 | assert foo.a 68 | assert foo.b 69 | assert foo.c 70 | assert inspect.getsource(foo) == '''\ 71 | @wrapper1 72 | @wrapper2 73 | @wrapper3 74 | def foo(): 75 | pass 76 | ''' 77 | 78 | def test_debug(capsys): 79 | l0 = inspect.currentframe().f_lineno + 1 80 | @apply_ast_passes( 81 | [ast_tools.passes.debug(dump_source_filename=True, dump_source_lines=True)], 82 | debug=True, 83 | file_name='test_debug.py', 84 | ) 85 | def foo(): 86 | print("bar") 87 | out = capsys.readouterr().out 88 | gold = f"""\ 89 | BEGIN SOURCE_FILENAME 90 | {os.path.abspath(__file__)} 91 | END SOURCE_FILENAME 92 | 93 | BEGIN SOURCE_LINES 94 | {l0+0}: @apply_ast_passes( 95 | {l0+1}: [ast_tools.passes.debug(dump_source_filename=True, dump_source_lines=True)], 96 | {l0+2}: debug=True, 97 | 
{l0+3}: file_name='test_debug.py', 98 | {l0+4}: ) 99 | {l0+5}: def foo(): 100 | {l0+6}: print("bar") 101 | END SOURCE_LINES 102 | 103 | """ 104 | assert out == gold 105 | 106 | 107 | def test_debug_error(): 108 | with pytest.raises( 109 | Exception, 110 | match=r"Cannot dump source filename without .*" 111 | ): 112 | @apply_cst_passes([debug(dump_source_filename=True)]) 113 | def foo(): 114 | print("bar") 115 | 116 | 117 | with pytest.raises( 118 | Exception, 119 | match=r"Cannot dump source lines without .*" 120 | ): 121 | @apply_cst_passes([debug(dump_source_lines=True)]) 122 | def foo(): 123 | print("bar") 124 | 125 | 126 | def test_custom_env(): 127 | @apply_cst_passes( 128 | [], 129 | env=SymbolTable({'x': 1}, globals=globals()), 130 | file_name='test_custom_env.py', 131 | ) 132 | def f(): 133 | return x 134 | 135 | assert f() == 1 136 | -------------------------------------------------------------------------------- /tests/test_pattern.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from ast_tools.pattern import ast_match 4 | from ast_tools.common import get_ast 5 | 6 | 7 | def parse_match(pattern, stmt): 8 | return ast_match(pattern, ast.parse(stmt).body[0]) 9 | 10 | 11 | def test_pattern_assign_copy(): 12 | stmt1 = "x = y" 13 | stmt2 = "x = 1" 14 | 15 | pattern = "{lhs:Name} = {rhs:Name}" 16 | match1 = parse_match(pattern, stmt1) 17 | assert match1 is not None 18 | assert match1['lhs'].id == 'x' 19 | assert match1['rhs'].id == 'y' 20 | 21 | match2 = parse_match(pattern, stmt2) 22 | assert match2 is None 23 | 24 | 25 | def test_pattern_assign_number(): 26 | stmt1 = "x = y" 27 | stmt2 = "x = 1" 28 | 29 | pattern = "{lhs:Name} = {rhs:Num}" 30 | match1 = parse_match(pattern, stmt1) 31 | assert match1 is None 32 | 33 | match2 = parse_match(pattern, stmt2) 34 | assert match2 is not None 35 | assert match2['lhs'].id == 'x' 36 | assert match2['rhs'].n == 1 37 | 38 | 39 | def test_pattern_if(): 40 | stmt1 = 
""" 41 | if x: 42 | y = 1 43 | else: 44 | z = 1 45 | """ 46 | 47 | stmt2 = """ 48 | if x: 49 | print(y) 50 | else: 51 | z = 1 52 | """ 53 | 54 | pattern = """ 55 | if {cond:Name}: 56 | {then_:Assign} 57 | else: 58 | {else_} 59 | """ 60 | 61 | match1 = parse_match(pattern, stmt1) 62 | assert match1 is not None 63 | 64 | match2 = parse_match(pattern, stmt2) 65 | assert match2 is None 66 | -------------------------------------------------------------------------------- /tests/test_ssa.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import inspect 3 | import random 4 | 5 | import libcst as cst 6 | 7 | import pytest 8 | 9 | from ast_tools.common import exec_def_in_file 10 | from ast_tools.passes.ssa import ssa 11 | from ast_tools.passes import apply_passes, debug 12 | from ast_tools.stack import SymbolTable 13 | 14 | 15 | 16 | NTEST = 16 17 | 18 | basic_template = '''\ 19 | def basic(x): 20 | if x: 21 | {} 0 22 | else: 23 | {} 2 24 | {} 25 | ''' 26 | 27 | template_options = ['r =', 'return'] 28 | 29 | def _do_ssa(func, strict, **kwargs): 30 | for dec in ( 31 | begin_rewrite(), 32 | debug(**kwargs), 33 | ssa(strict), 34 | debug(**kwargs), 35 | end_rewrite()): 36 | func = dec(func) 37 | return func 38 | 39 | 40 | @pytest.mark.parametrize('strict', [True, False]) 41 | @pytest.mark.parametrize('a', template_options) 42 | @pytest.mark.parametrize('b', template_options) 43 | def test_basic_if(strict, a, b): 44 | if a == b == 'return': 45 | final = '' 46 | else: 47 | final = 'return r' 48 | 49 | src = basic_template.format(a, b, final) 50 | tree = cst.parse_statement(src) 51 | env = SymbolTable({}, {}) 52 | basic = exec_def_in_file(tree, env) 53 | ssa_basic = apply_passes([ssa(strict)])(basic) 54 | 55 | for x in (False, True): 56 | assert basic(x) == ssa_basic(x) 57 | 58 | 59 | nested_template = '''\ 60 | def nested(x, y): 61 | if x: 62 | if y: 63 | {} 0 64 | else: 65 | {} 1 66 | else: 67 | if y: 
68 | {} 2 69 | else: 70 | {} 3 71 | {} 72 | ''' 73 | 74 | @pytest.mark.parametrize('strict', [True, False]) 75 | @pytest.mark.parametrize('a', template_options) 76 | @pytest.mark.parametrize('b', template_options) 77 | @pytest.mark.parametrize('c', template_options) 78 | @pytest.mark.parametrize('d', template_options) 79 | def test_nested(strict, a, b, c, d): 80 | if a == b == c == d == 'return': 81 | final = '' 82 | else: 83 | final = 'return r' 84 | 85 | src = nested_template.format(a, b, c, d, final) 86 | tree = cst.parse_statement(src) 87 | env = SymbolTable({}, {}) 88 | nested = exec_def_in_file(tree, env) 89 | ssa_nested = apply_passes([ssa(strict)])(nested) 90 | 91 | for x in (False, True): 92 | for y in (False, True): 93 | assert nested(x, y) == ssa_nested(x, y) 94 | 95 | imbalanced_template = '''\ 96 | def imbalanced(x, y): 97 | {} -1 98 | if x: 99 | {} -2 100 | if y: 101 | {} 0 102 | else: 103 | {} 1 104 | return r 105 | ''' 106 | 107 | init_template_options = ['r = ', '0'] 108 | 109 | @pytest.mark.parametrize('strict', [True, False]) 110 | @pytest.mark.parametrize('a', init_template_options) 111 | @pytest.mark.parametrize('b', init_template_options) 112 | @pytest.mark.parametrize('c', template_options) 113 | @pytest.mark.parametrize('d', template_options) 114 | def test_imbalanced(strict, a, b, c, d): 115 | src = imbalanced_template.format(a, b, c, d) 116 | tree = cst.parse_statement(src) 117 | env = SymbolTable({}, {}) 118 | imbalanced = exec_def_in_file(tree, env) 119 | can_name_error = False 120 | for x in (False, True): 121 | for y in (False, True): 122 | try: 123 | imbalanced(x, y) 124 | except NameError: 125 | can_name_error = True 126 | break 127 | 128 | if can_name_error and strict: 129 | with pytest.raises(SyntaxError): 130 | ssa_imbalanced = apply_passes([ssa(strict)])(imbalanced) 131 | else: 132 | ssa_imbalanced = apply_passes([ssa(strict)])(imbalanced) 133 | for x in (False, True): 134 | for y in (False, True): 135 | try: 136 | assert 
imbalanced(x, y) == ssa_imbalanced(x, y) 137 | except NameError: 138 | assert can_name_error 139 | 140 | 141 | 142 | def test_reassign_arg(): 143 | def bar(x): 144 | return x 145 | 146 | @apply_passes([ssa()], metadata_attr='metadata') 147 | def foo(a, b): 148 | if b: 149 | a = len(a) 150 | return a 151 | assert inspect.getsource(foo) == '''\ 152 | def foo(a, b): 153 | _cond_0 = b 154 | a_0 = len(a) 155 | a_1 = a_0 if _cond_0 else a 156 | __0_return_0 = a_1 157 | return __0_return_0 158 | ''' 159 | symbol_tables = foo.metadata['SYMBOL-TABLE'] 160 | assert len(symbol_tables) == 1 161 | assert symbol_tables[0][0] == ssa 162 | symbol_table = symbol_tables[0][1] 163 | assert symbol_table == { 164 | 2: { 165 | 'a': 'a', 166 | 'b': 'b', 167 | }, 168 | 3: { 169 | 'a': 'a', 170 | 'b': 'b', 171 | }, 172 | 4: { 173 | 'a': 'a_0', 174 | 'b': 'b', 175 | }, 176 | 5: { 177 | 'a': 'a_1', 178 | 'b': 'b', 179 | }, 180 | } 181 | 182 | 183 | def test_double_nested_function_call(): 184 | def bar(x): 185 | return x 186 | 187 | def baz(x): 188 | return x + 1 189 | 190 | @apply_passes([ssa()], metadata_attr='metadata') # 1 191 | def foo(a, b, c): # 2 192 | if b: # 3 193 | a = bar(a) # 4 194 | else: # 5 195 | a = bar(a) # 6 196 | if c: # 7 197 | b = bar(b) # 8 198 | else: # 9 199 | b = bar(b) # 10 200 | return a, b # 11 201 | assert inspect.getsource(foo) == '''\ 202 | def foo(a, b, c): # 2 203 | _cond_0 = b 204 | a_0 = bar(a) # 4 205 | a_1 = bar(a) # 6 206 | a_2 = a_0 if _cond_0 else a_1 207 | _cond_1 = c 208 | b_0 = bar(b) # 8 209 | b_1 = bar(b) # 10 210 | b_2 = b_0 if _cond_1 else b_1 211 | __0_return_0 = a_2, b_2 # 11 212 | return __0_return_0 213 | ''' 214 | 215 | symbol_tables = foo.metadata['SYMBOL-TABLE'] 216 | assert len(symbol_tables) == 1 217 | assert symbol_tables[0][0] == ssa 218 | symbol_table = symbol_tables[0][1] 219 | gold_table = {i: { 220 | 'a': 'a' if i < 4 else 'a_0' if i < 6 else 'a_1' if i < 7 else 'a_2', 221 | 'b': 'b' if i < 8 else 'b_0' if i < 10 else 'b_1' if i < 
11 else 'b_2', 222 | 'c': 'c', 223 | } for i in range(2, 12)} 224 | assert symbol_table == gold_table 225 | 226 | class Thing: 227 | def __init__(self, x=None): 228 | self.x = x 229 | 230 | def __eq__(self, other): 231 | if isinstance(other, type(self)): 232 | return self.x == other.x 233 | else: 234 | return self.x == other 235 | 236 | def __ne__(self, other): 237 | return not (self == other) 238 | 239 | def __repr__(self): 240 | return f'Thing({self.x})' 241 | 242 | @pytest.mark.parametrize('strict', [True, False]) 243 | def test_attrs_basic(strict): 244 | def f1(t, cond): 245 | old = t.x 246 | if cond: 247 | t.x = 1 248 | else: 249 | t.x = 0 250 | return old 251 | 252 | f2 = apply_passes([ssa(strict)])(f1) 253 | 254 | t1 = Thing() 255 | t2 = Thing() 256 | 257 | assert t1 == t2 == None 258 | f1(t1, True) 259 | assert t1 != t2 260 | f2(t2, True) 261 | assert t1 == t2 == 1 262 | f1(t1, False) 263 | assert t1 != t2 264 | f2(t2, False) 265 | assert t1 == t2 == 0 266 | 267 | 268 | @pytest.mark.parametrize('strict', [True, False]) 269 | def test_attrs_returns(strict): 270 | def f1(t, cond1, cond2): 271 | if cond1: 272 | t.x = 1 273 | if cond2: 274 | return 0 275 | else: 276 | t.x = 0 277 | if cond2: 278 | return 1 279 | return -1 280 | 281 | f2 = apply_passes([ssa(strict)])(f1) 282 | 283 | t1 = Thing() 284 | t2 = Thing() 285 | assert t1 == t2 286 | 287 | for _ in range(NTEST): 288 | c1 = random.randint(0, 1) 289 | c2 = random.randint(0, 1) 290 | o1 = f1(t1, c1, c2) 291 | o2 = f2(t2, c1, c2) 292 | assert o1 == o2 293 | assert t1 == t2 294 | 295 | 296 | @pytest.mark.parametrize('strict', [True, False]) 297 | def test_attrs_class(strict): 298 | class Counter1: 299 | def __init__(self, init, max): 300 | self.cnt = init 301 | self.max = max 302 | 303 | def __call__(self, en): 304 | if en and self.cnt < self.max - 1: 305 | self.cnt = self.cnt + 1 306 | elif en: 307 | self.cnt = 0 308 | 309 | class Counter2: 310 | __init__ = Counter1.__init__ 311 | __call__ = 
apply_passes([ssa(strict)])(Counter1.__call__) 312 | 313 | c1 = Counter1(3, 5) 314 | c2 = Counter2(3, 5) 315 | 316 | assert c1.cnt == c2.cnt 317 | 318 | for _ in range(NTEST): 319 | e = random.randint(0, 1) 320 | o1 = c1(e) 321 | o2 = c2(e) 322 | assert o1 == o2 323 | assert c1.cnt == c2.cnt 324 | 325 | @pytest.mark.parametrize('strict', [True, False]) 326 | def test_attrs_class_methods(strict): 327 | class Counter1: 328 | def __init__(self, init, max): 329 | self.cnt = init 330 | self.max = max 331 | 332 | def __call__(self, en): 333 | if en and self.cnt < self.max - 1: 334 | self.cnt = self.cnt + self.get_step(self.cnt) 335 | elif en: 336 | self.cnt = 0 337 | 338 | def get_step(self, cnt): 339 | return (cnt % 2) + 1 340 | 341 | class Counter2: 342 | __init__ = Counter1.__init__ 343 | __call__ = apply_passes([ssa(strict)])(Counter1.__call__) 344 | get_step = Counter1.get_step 345 | 346 | c1 = Counter1(3, 5) 347 | c2 = Counter2(3, 5) 348 | 349 | assert c1.cnt == c2.cnt 350 | 351 | for _ in range(NTEST): 352 | e = random.randint(0, 1) 353 | o1 = c1(e) 354 | o2 = c2(e) 355 | assert o1 == o2 356 | assert c1.cnt == c2.cnt 357 | 358 | 359 | def test_nstrict(): 360 | # This function would confuse strict ssa in so many ways 361 | def f1(cond): 362 | if cond: 363 | if cond: 364 | return 0 365 | elif not cond: 366 | z = 1 367 | 368 | if not cond: 369 | x = z 370 | return x 371 | 372 | f2 = apply_passes([ssa(False)])(f1) 373 | assert inspect.getsource(f2) == '''\ 374 | def f1(cond): 375 | _cond_2 = cond 376 | _cond_0 = cond 377 | __0_return_0 = 0 378 | _cond_1 = not cond 379 | z_0 = 1 380 | _cond_3 = not cond 381 | x_0 = z_0 382 | __0_return_1 = x_0 383 | return __0_return_0 if _cond_2 and _cond_0 else __0_return_1 384 | ''' 385 | for cond in [True, False]: 386 | assert f1(cond) == f2(cond) 387 | 388 | 389 | def test_attr(): 390 | bar = namedtuple('bar', ['x', 'y']) 391 | 392 | def f1(x, y): 393 | z = bar(1, 0) 394 | if x: 395 | a = z 396 | else: 397 | a = y 398 | a.x = 3 
399 | return a 400 | 401 | f2 = apply_passes([ssa(False)])(f1) 402 | assert inspect.getsource(f2) == '''\ 403 | def f1(x, y): 404 | _attr_a_x_0 = a.x 405 | z_0 = bar(1, 0) 406 | _cond_0 = x 407 | a_0 = z_0 408 | a_1 = y 409 | a_2 = a_0 if _cond_0 else a_1 410 | _attr_a_x_1 = 3 411 | __0_final_a_x_0 = _attr_a_x_1; __0_return_0 = a_2 412 | a_2.x = __0_final_a_x_0 413 | return __0_return_0 414 | ''' 415 | 416 | 417 | def test_call(): 418 | def f1(x): 419 | x = 2 420 | return g(x=x) 421 | 422 | f2 = apply_passes([ssa(False)])(f1) 423 | assert inspect.getsource(f2) == '''\ 424 | def f1(x): 425 | x_0 = 2 426 | __0_return_0 = g(x=x_0) 427 | return __0_return_0 428 | ''' 429 | 430 | 431 | def ident(x): return x 432 | 433 | sig_template = '''\ 434 | def f(x{}, y{}) -> ({}, {}): 435 | return x, y 436 | ''' 437 | 438 | 439 | template_options = ['', 'int', 'ident(int)', 'ident(x=int)'] 440 | 441 | @pytest.mark.parametrize('strict', [True, False]) 442 | @pytest.mark.parametrize('x', template_options) 443 | @pytest.mark.parametrize('y', template_options) 444 | def test_call_in_annotations(strict, x, y): 445 | r_x = x if x else 'int' 446 | r_y = y if y else 'int' 447 | x = f': {x}' if x else x 448 | y = f': {y}' if y else y 449 | src = sig_template.format(x, y, r_x, r_y) 450 | tree = cst.parse_statement(src) 451 | env = SymbolTable(locals(), globals()) 452 | f1 = exec_def_in_file(tree, env) 453 | f2 = apply_passes([ssa(strict)])(f1) 454 | 455 | 456 | @pytest.mark.parametrize('strict', [True, False]) 457 | def test_issue_79(strict): 458 | class Wrapper: 459 | def __init__(self, val): 460 | self.val = val 461 | def apply(self, f): 462 | return f(self.val) 463 | 464 | def f1(x): 465 | return x.apply(lambda x: x+1) 466 | 467 | f2 = apply_passes([ssa(strict)])(f1) 468 | 469 | for _ in range(8): 470 | x = Wrapper(random.randint(0, 1<<10)) 471 | assert f1(x) == f2(x) 472 | -------------------------------------------------------------------------------- /tests/test_stack.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from ast_tools import stack 3 | 4 | MAGIC = 'foo' 5 | 6 | def test_get_symbol_table(): 7 | MAGIC = 'bar' 8 | st = stack.get_symbol_table() 9 | assert st.globals['MAGIC'] == 'foo' 10 | assert st.locals['MAGIC'] == 'bar' 11 | 12 | def test_inspect_symbol_table(): 13 | MAGIC = 'bar' 14 | 15 | @stack.inspect_symbol_table 16 | def test(st): 17 | assert st.globals['MAGIC'] == 'foo' 18 | assert st.locals['MAGIC'] == 'bar' 19 | 20 | test() 21 | 22 | @stack.inspect_symbol_table 23 | def test(st): 24 | assert st.locals[stack._SKIP_FRAME_DEBUG_NAME] == 0xdeadbeaf 25 | 26 | stack._SKIP_FRAME_DEBUG_FAIL = True 27 | exec(stack._SKIP_FRAME_DEBUG_STMT) 28 | 29 | with pytest.raises(RuntimeError): 30 | test() 31 | 32 | stack._SKIP_FRAME_DEBUG_FAIL = False 33 | test() 34 | 35 | 36 | def test_inspect_enclosing_env(): 37 | MAGIC = 'bar' 38 | 39 | @stack.inspect_enclosing_env 40 | def test(env): 41 | assert env['MAGIC'] == 'bar' 42 | 43 | test() 44 | 45 | @stack.inspect_enclosing_env 46 | def test(env): 47 | assert env[stack._SKIP_FRAME_DEBUG_NAME] == 0xdeadbeaf 48 | 49 | stack._SKIP_FRAME_DEBUG_FAIL = True 50 | exec(stack._SKIP_FRAME_DEBUG_STMT) 51 | 52 | with pytest.raises(RuntimeError): 53 | test() 54 | 55 | stack._SKIP_FRAME_DEBUG_FAIL = False 56 | test() 57 | 58 | def test_custom_env(): 59 | MAGIC1 = 'foo' 60 | def test(env): 61 | assert env['MAGIC1'] == 'foo' 62 | assert env['MAGIC2'] == 'bar' 63 | 64 | st = stack.SymbolTable(locals={},globals={'MAGIC2':'bar'}) 65 | test = stack.inspect_enclosing_env(test, st=st) 66 | test() 67 | 68 | def test_get_symbol_table_copy_frames(): 69 | non_copy_sts = [] 70 | copy_sts = [] 71 | for i in range(5): 72 | non_copy_sts.append(stack.get_symbol_table()) 73 | copy_sts.append(stack.get_symbol_table(copy_locals=True)) 74 | for j in range(5): 75 | assert non_copy_sts[j].locals["i"] == 4 76 | assert copy_sts[j].locals["i"] == j 77 | 
--------------------------------------------------------------------------------
/tests/test_unroll.py:
--------------------------------------------------------------------------------
# Tests for loop unrolling: `for v in ast_tools.macros.unroll(iterable)` is
# expanded into one copy of the body per element, both through the raw
# unroll_for_loops transformer and through the loop_unroll pass.
import inspect

import pytest

import libcst as cst

import ast_tools
from ast_tools.transformers.loop_unroller import unroll_for_loops
from ast_tools.passes import apply_passes, loop_unroll
from ast_tools.cst_utils import to_module



def test_basic_unroll():
    src = """\
def foo():
    for i in ast_tools.macros.unroll(range(8)):
        print(i)
"""
    unrolled_src = """\
def foo():
    print(0)
    print(1)
    print(2)
    print(3)
    print(4)
    print(5)
    print(6)
    print(7)
"""
    tree = cst.parse_module(src)
    # globals() provides `ast_tools` so the unroll(...) call can be resolved.
    unrolled_tree = unroll_for_loops(tree, globals())
    assert to_module(unrolled_tree).code == unrolled_src


def test_basic_inside_if():
    src = """\
def foo(x):
    if x:
        for i in ast_tools.macros.unroll(range(8)):
            print(i)
        return x + 1 if x % 2 else x
    else:
        print(x)
        for j in ast_tools.macros.unroll(range(2)):
            print(j - 1)
"""
    unrolled_src = """\
def foo(x):
    if x:
        print(0)
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        print(7)
        return x + 1 if x % 2 else x
    else:
        print(x)
        print(0 - 1)
        print(1 - 1)
"""
    tree = cst.parse_module(src)
    unrolled_tree = unroll_for_loops(tree, globals())
    assert to_module(unrolled_tree).code == unrolled_src


def test_basic_inside_while():
    src = """\
def foo(x):
    while True:
        for i in ast_tools.macros.unroll(range(8)):
            print(i)
"""
    unrolled_src = """\
def foo(x):
    while True:
        print(0)
        print(1)
        print(2)
        print(3)
        print(4)
        print(5)
        print(6)
        print(7)
"""
    tree = cst.parse_module(src)
    unrolled_tree = unroll_for_loops(tree, globals())
    assert to_module(unrolled_tree).code == unrolled_src


def test_basic_env():
    src = """\
def foo(x):
    for i in ast_tools.macros.unroll(range(j)):
        print(i)
"""
    unrolled_src = """\
def foo(x):
    print(0)
    print(1)
"""
    # `j` is only defined in the env passed to the transformer.
    env = dict(globals(), **{"j": 2})
    tree = cst.parse_module(src)
    unrolled_tree = unroll_for_loops(tree, env)
    assert to_module(unrolled_tree).code == unrolled_src


def test_pass_basic():
    @apply_passes([loop_unroll()])
    def foo():
        for i in ast_tools.macros.unroll(range(8)):
            print(i)
    assert inspect.getsource(foo) == """\
def foo():
    print(0)
    print(1)
    print(2)
    print(3)
    print(4)
    print(5)
    print(6)
    print(7)
"""


def test_pass_env():
    # `j` is a local of this test; the pass must pick it up from the
    # enclosing scope.
    j = 3
    @apply_passes([loop_unroll()])
    def foo():
        for i in ast_tools.macros.unroll(range(j)):
            print(i)
    assert inspect.getsource(foo) == """\
def foo():
    print(0)
    print(1)
    print(2)
"""


def test_pass_nested():
    @apply_passes([loop_unroll()])
    def foo():
        for i in ast_tools.macros.unroll(range(2)):
            for j in ast_tools.macros.unroll(range(3)):
                print(i + j)
    assert inspect.getsource(foo) == """\
def foo():
    print(0 + 0)
    print(0 + 1)
    print(0 + 2)
    print(1 + 0)
    print(1 + 1)
    print(1 + 2)
"""


def test_pass_no_unroll():
    # Plain loops (no unroll(...) marker) are left untouched.
    j = 3
    @apply_passes([loop_unroll()])
    def foo():
        for i in range(j):
            print(i)
    assert inspect.getsource(foo) == """\
def foo():
    for i in range(j):
        print(i)
"""


def test_pass_no_unroll_nested():
    j = 3
    @apply_passes([loop_unroll()])
    def foo():
        for i in range(j):
            for k in ast_tools.macros.unroll(range(3)):
                print(i * k)
    assert inspect.getsource(foo) == """\
def foo():
    for i in range(j):
        print(i * 0)
        print(i * 1)
        print(i * 2)
"""


def test_pass_no_rewrite_range():
    # NOTE(review): no apply_passes here, so this presumably checks that
    # unroll(...) still iterates normally when the pass is NOT applied;
    # confirm that was the intent rather than a missing decorator.
    j = 3

    def foo():
        count = 0
        for i in range(j):
            for k in ast_tools.macros.unroll(range(3)):
                count += 1
        return count
    assert foo() == 3 * 3


def test_bad_iter():
    # Elements that cannot be written back as source literals must be rejected.
    with pytest.raises(Exception):
        @apply_passes([loop_unroll()])
        def foo():
            count = 0
            for k in ast_tools.macros.unroll([object(), object()]):
                count += 1
            return count


def test_list_of_ints():
    j = [1, 2, 3]
    @apply_passes([loop_unroll()])
    def foo():
        for i in ast_tools.macros.unroll(j):
            print(i)
    assert inspect.getsource(foo) == """\
def foo():
    print(1)
    print(2)
    print(3)
"""

--------------------------------------------------------------------------------
/tests/test_visitors.py:
--------------------------------------------------------------------------------
"""
Test visitors
"""
import libcst as cst

import pytest

from ast_tools.visitors import collect_names
from ast_tools.visitors import collect_targets
from ast_tools.visitors import used_names


def test_collect_targets():
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    x0 = cst.Subscript(
        value=x,
        slice=[
            cst.SubscriptElement(
                slice=cst.Index(value=cst.Integer('0'))
            )
        ],
    )
    xa = cst.Attribute(
        value=x,
        attr=cst.Name('attr'),
    )

    golds = x,x0,xa

    targets = collect_targets(tree)
    # NOTE(review): zip() stops at the shorter input, so missing or extra
    # targets would not fail this assertion; consider also checking counts.
    assert all(t.deep_equals(g) for t,g in zip(targets, golds))


def test_used_names():
    # Only names bound directly in the given scope count; nested defs and
    # attribute targets (x.f) do not.
    tree = cst.parse_module('''
x = 1
def foo():
    def g(): pass

class A:
    def __init__(self): pass

    class B: pass

x.f = 7

async def h(): pass
''')
    assert used_names(tree) == {'x', 'foo', 'A', 'h'}
    assert used_names(tree.body[1].body) == {'g'}

# Currently broken; requires a new release of LibCST.
def test_collect_names():
    """
    Test collecting names from a simple function including the `ctx` feature
    """
    s = '''
def foo(bar, baz):  # pylint: disable=blacklisted-name
    buzz = bar + baz
    name_error
    del bar
    return buzz
'''

    foo_ast = cst.parse_module(s)
    assert collect_names(foo_ast, ctx=cst.metadata.ExpressionContext.STORE) == {"foo", "bar", "baz", "buzz"}
    assert collect_names(foo_ast, ctx=cst.metadata.ExpressionContext.LOAD) == {"bar", "baz", "buzz", "name_error"}
    assert collect_names(foo_ast, ctx=cst.metadata.ExpressionContext.DEL) == {"bar"}
    assert collect_names(foo_ast) == {"foo", "bar", "baz", "buzz", "name_error"}

--------------------------------------------------------------------------------
/util/generate_ast/__init__.py:
--------------------------------------------------------------------------------
# Running this package as a script prints the generated immutable-AST module
# source to stdout.
if __name__ == '__main__':
    import generate
    print(generate.generate_immutable_ast())

--------------------------------------------------------------------------------
/util/generate_ast/_base.px:
--------------------------------------------------------------------------------
# Template: root class of the generated immutable AST hierarchy.  Fields are
# write-once; any other attribute can be freely set and deleted as metadata.
class AST(mutable=ast.AST, metaclass=ImmutableMeta):
    def __setattr__(self, attr, value):
        # A field may be set only while it is still unset (i.e. during
        # construction); lists and mutable nodes are converted on the way in.
        if attr in self._fields and hasattr(self, attr):
            raise AttributeError('Cannot modify ImmutableAST fields')
        elif isinstance(value, (list, ast.AST)):
            value = immutable(value)

        self.__dict__[attr] = value

    def __delattr__(self, attr):
        if attr in self._fields:
            raise AttributeError('Cannot modify ImmutableAST fields')
        del self.__dict__[attr]

    def __hash__(self):
        # The hash is memoized in the non-field attribute `_hash_`.
        try:
            return self._hash_
        except AttributeError:
            pass

        # NOTE(review): summing element hashes is order-insensitive, and str
        # fields are also tp.Sequence (hashed per character), so permuted
        # children or identifiers collide.  This is still valid, just weak.
        h = hash(type(self))
        for _, n in iter_fields(self):
            if isinstance(n, AST):
                h += hash(n)
            elif isinstance(n, tp.Sequence):
                for c in n:
                    h += hash(c)
            else:
                h += hash(n)
        self._hash_ = h
        return h

    def __eq__(self, other):
        # Equal only when both are exactly the same node class and every
        # field compares equal.
        if not isinstance(other, type(self)):
            return NotImplemented
        elif type(self) == type(other):
            for f in self._fields:
                if getattr(self, f) != getattr(other, f):
                    return False
            return True
        else:
            return False

    def __ne__(self, other):
        return not (self == other)

--------------------------------------------------------------------------------
/util/generate_ast/_functions.px:
--------------------------------------------------------------------------------
# NOTE(review): assumes an earlier template defines __ALL__; Python itself
# only honours the lowercase __all__, so confirm how __ALL__ is consumed.
__ALL__ += ['immutable', 'mutable', 'parse', 'dump',
            'iter_fields', 'iter_child_nodes', 'walk',
            'NodeVisitor', 'NodeTransformer']


def _cast_tree(seq_t, n_seq_t, type_look_up, tree):
    # Recursively rebuild `tree`: sequences of type seq_t become n_seq_t, and
    # nodes whose type is a key of type_look_up are rebuilt as the mapped type
    # from their _fields.  Anything else is returned unchanged.
    # NOTE(review): only _fields are copied, so attributes such as lineno and
    # col_offset appear to be dropped by the conversion.
    args = seq_t, n_seq_t, type_look_up

    if isinstance(tree, seq_t):
        return n_seq_t(_cast_tree(*args, c) for c in tree)

    try:
        T = type_look_up[type(tree)]
    except KeyError:
        return tree

    kwargs = {}
    for field, c in iter_fields(tree):
        kwargs[field] = _cast_tree(*args, c)

    return T(**kwargs)


def immutable(tree: ast.AST) -> 'AST':
    '''Converts a mutable ast to an immutable one'''
    return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree)

def mutable(tree: 'AST') -> ast.AST:
    '''Converts an immutable ast to a mutable one'''
    return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree)

# NOTE(review): ast.parse's own default filename is '<unknown>'; the empty
# default here may be an artefact of the text dump, so confirm it against
# the real template.
def parse(source, filename='', mode='exec') -> 'AST':
    '''Parse source with ast.parse and return it as an immutable tree.'''
    tree = ast.parse(source, filename, mode)
    return immutable(tree)

def dump(node, annotate_fields=True, include_attributes=False) -> str:
    '''Return a formatted dump of `node`, mirroring `ast.dump`.

    The tree is converted with `mutable` so the standard-library formatter can
    be reused; `annotate_fields` and `include_attributes` are forwarded to
    `ast.dump` with the same meaning they have there.

    NOTE(review): `_cast_tree` rebuilds nodes from their `_fields` only, so
    trees produced by `immutable` carry no location attributes. Confirm how
    `ast.dump` treats missing attributes on the oldest supported Python before
    relying on `include_attributes=True` for such trees.
    '''
    tree = mutable(node)
    return ast.dump(
        tree,
        annotate_fields=annotate_fields,
        include_attributes=include_attributes,
    )


# duck typing ftw: the generated node classes define `_fields` and store each
# field as an instance attribute, so the stdlib helper works on them unchanged.
iter_fields = ast.iter_fields

# The following is more or less copied verbatim from
# CPython/Lib/ast.py. Changes are:
#   s/list/tuple/
#
# The CPython license is very permissive so I am pretty sure this is cool.
# If it is not Guido please forgive me.
def iter_child_nodes(node):
    '''Yield the direct child nodes of `node`.

    A child is a field value that is an `AST` instance, or an `AST` item
    inside a tuple-valued field (immutable trees store sequences as tuples).
    '''
    for name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, tuple):
            for item in field:
                if isinstance(item, AST):
                    yield item

# Same note as above
def walk(node):
    '''Yield `node` and every node below it, in breadth-first order.'''
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node


# Same note as above
class NodeVisitor:
    '''Tree visitor that dispatches on the visited node's class name.'''

    def visit(self, node):
        '''Call `self.visit_<ClassName>(node)` if defined, else `generic_visit`.'''
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        '''Visit every `AST` child of `node`; the visit results are discarded.'''
        for field, value in iter_fields(node):
            if isinstance(value, tuple):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)


# Same note as above
class NodeTransformer(NodeVisitor):
    '''
    Mostly equivalent to ast.NodeTransformer, except returns new nodes
    instead of mutating them in place
    '''

    def generic_visit(self, node):
        '''Return a new node of the same type with every child visited.

        For tuple-valued fields, a visit result of None drops the item and a
        non-AST result (an iterable of nodes) is spliced in its place. A
        single-node field is visited only when its value's class was created
        by `ImmutableMeta`; any other value is copied unchanged.
        '''
        kwargs = {}
        for field, old_value in iter_fields(node):
            if isinstance(old_value, tuple):
                new_value = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            # visitor asked for the item to be removed
                            continue
                        elif not isinstance(value, AST):
                            # visitor returned several nodes: splice them in
                            new_value.extend(value)
                            continue
                    new_value.append(value)
                new_value = tuple(new_value)
            elif isinstance(type(old_value), ImmutableMeta):
                new_value = self.visit(old_value)
            else:
                new_value = old_value
            kwargs[field] = new_value
        return type(node)(**kwargs)

# --------------------------------------------------------------------------------
# /util/generate_ast/_meta.px:
# --------------------------------------------------------------------------------
__ALL__ += ['ImmutableMeta']

class ImmutableMeta(type):
    '''Metaclass pairing each immutable node class with its `ast` counterpart.

    Classes are declared as ``class X(..., mutable=ast.X, metaclass=ImmutableMeta)``.
    The pairing is recorded in both directions, which is what `immutable` and
    `mutable` use to translate trees. `isinstance`/`issubclass` checks against
    an immutable class also accept the paired mutable class.
    '''
    # immutable class -> paired mutable ast class
    _immutable_to_mutable = dict()
    # mutable ast class -> paired immutable class
    _mutable_to_immutable = dict()

    def __new__(mcs, name, bases, namespace, mutable, **kwargs):
        # `mutable` is consumed here so it is not forwarded to type.__new__.
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        ImmutableMeta._immutable_to_mutable[cls] = mutable
        ImmutableMeta._mutable_to_immutable[mutable] = cls

        return cls

    def __instancecheck__(cls, instance):
        # Also accept instances of the paired mutable `ast` class.
        return (super().__instancecheck__(instance)
                or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls]))

    def __subclasscheck__(cls, type_):
        # Also accept subclasses of the paired mutable `ast` class.
        return (super().__subclasscheck__(type_)
                or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls]))

# --------------------------------------------------------------------------------
# /util/generate_ast/generate.py:
# --------------------------------------------------------------------------------
# Generates the source of an immutable mirror of the `ast` module: the
# _functions.px, _meta.px and _base.px templates followed by one generated
# class per `ast` node class.
import ast
import datetime
import inspect
from os import path
import sys

# Directory holding the .px template files spliced into the output.
_BASE_PATH = path.dirname(__file__)
def _make_path(f):
    '''Return the absolute path of template file `f` next to this script.'''
    return path.abspath(path.join(_BASE_PATH, f))

META_FILE = _make_path('_meta.px')
FUNCTIONS_FILE = _make_path('_functions.px')
AST_BASE_FILE = _make_path('_base.px')
# Indentation unit used in the generated source.
TAB = ' '*4

def generate_class(name, bases, fields):
    '''Return the source of the immutable counterpart of ``ast.<name>``.

    Args:
        name: class name, also used for the ``mutable=ast.<name>`` pairing.
        bases: names of the immutable base classes.
        fields: the ast class's ``_fields``; each becomes a required
            ``__init__`` parameter stored as an attribute of the same name.
    '''
    bases = ', '.join(bases) + (', ' if bases else '')
    sig = (', ' if fields else '') + ', '.join(fields)
    body = [f'self.{arg} = {arg}' for arg in fields]
    if not body:
        # a field-less node still needs a non-empty __init__ body
        body.append('pass')

    body = f'\n{TAB}{TAB}'.join(body)

    class_ = f'''\
class {name}({bases}mutable=ast.{name}):
{TAB}_fields={fields}
{TAB}def __init__(self{sig}):
{TAB}{TAB}{body}
'''

    return class_

def generate_classes(class_tree, ALL):
    '''Return the source of every immutable node class, bases first.

    Args:
        class_tree: result of ``inspect.getclasstree`` rooted at ``object``
            whose only child is ``ast.AST`` (checked by the caller).
        ALL: list extended in place with the name of each generated class.
    '''
    # ast.AST itself maps to the hand-written `AST` class from _base.px.
    cls_to_args = {ast.AST : ('AST', (), ())}

    def pop_args_from_tree(tree):
        # Depth-first search for the first class not generated yet. Returns
        # its (name, base names, fields) and marks it as generated, or None
        # once every class is handled. getclasstree lists a class before the
        # nested list of its subclasses, so every base is resolved first.
        for item in tree:
            if isinstance(item, list):
                r = pop_args_from_tree(item)
                if r is not None:
                    return r
            elif item[0] not in cls_to_args:
                cls = item[0]
                bases = tuple(cls_to_args[base][0] for base in item[1] if base is not object)
                cls_to_args[cls] = r = cls.__name__, bases, cls._fields
                return r

    classes_ = []

    # Each call rescans from the root; the ast hierarchy is small enough
    # that the quadratic scan is not a concern.
    args = pop_args_from_tree(class_tree[1])
    while args is not None:
        class_ = generate_class(*args)
        classes_.append(class_)
        ALL.append(args[0])
        args = pop_args_from_tree(class_tree[1])

    return '\n'.join(classes_)


def generate_immutable_ast():
    '''Return the full source text of the immutable-ast module.

    The module consists of a header comment, a Python-version check, the
    ``__ALL__`` list, then the _functions.px, _meta.px and _base.px templates,
    and finally the generated node classes.
    '''
    def _issubclass(t, types):
        try:
            return issubclass(t, types)
        except TypeError:
            # `t` is not a class (the ast module also exports functions etc.)
            return False

    classes = []
    for name in dir(ast):
        obj = getattr(ast, name)
        if _issubclass(obj, ast.AST):
            classes.append(obj)

    class_tree = inspect.getclasstree(classes)
    # assert the class tree is a tree and not a dag
    assert class_tree == inspect.getclasstree(classes, unique=True)
    # assert the class tree has a root
    assert len(class_tree) == 2, class_tree[0]
    # assert the root is object
    assert class_tree[0][0] is object, class_tree[0][0]
    # assert the root has only 1 child
    assert len(class_tree[1]) == 2, class_tree[1]
    # assert that the child is ast.AST
    assert class_tree[1][0][0] is ast.AST, class_tree[1][0]

    # f-string expressions could not contain backslashes before Python 3.12
    nl = '\n'
    head_comment = f'''\
# file generated by {__file__} on {datetime.datetime.now()}
# for python {sys.version.split(nl)[0].strip()}'''

    # Emitted into the generated module: warns when it is imported by a
    # Python version other than the one that generated it.
    version_check = f'''\
if sys.version_info[:2] != {sys.version_info[:2]}:
{TAB}warnings.warn(f"{{__file__}} generated for {sys.version_info[:2]} "
{TAB}              f"does not match system version {{sys.version_info[:2]}}")'''



    ALL = ['AST']

    with open(FUNCTIONS_FILE, 'r', encoding='utf-8') as f:
        functions = f.read()

    with open(META_FILE, 'r', encoding='utf-8') as f:
        meta = f.read()

    with open(AST_BASE_FILE, 'r', encoding='utf-8') as f:
        ast_base = f.read()

    # Must run before ALL is formatted below: it appends every class name.
    classes = generate_classes(class_tree, ALL)

    # NOTE(review): the generated module binds `__ALL__`, not `__all__`, so it
    # does not restrict `from ... import *`; confirm whether `__all__` was
    # intended (the .px templates also extend `__ALL__`).
    immutable_ast = f'''\
{head_comment}

import ast
import sys
import typing as tp
import warnings

{version_check}

__ALL__ = {ALL}

{functions}

{meta}

{ast_base}

{classes}
'''
    return immutable_ast

# --------------------------------------------------------------------------------