├── fissix
│   ├── __init__.py
│   ├── pgen2
│   │   ├── __init__.py
│   │   ├── token.py
│   │   ├── literals.py
│   │   ├── grammar.py
│   │   ├── driver.py
│   │   ├── parse.py
│   │   ├── conv.py
│   │   └── pgen.py
│   ├── PatternGrammar.txt
│   ├── pygram.py
│   ├── Grammar.txt
│   ├── fixer_base.py
│   ├── btm_matcher.py
│   ├── patcomp.py
│   ├── btm_utils.py
│   └── fixer_util.py
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── find_match_node.py
│   ├── find_pattern.py
│   └── click.py
├── paddle_upgrade_tool
│   ├── tests
│   │   ├── __init__.py
│   │   └── test_refactor.py
│   ├── __init__.py
│   ├── __main__.py
│   ├── processors.py
│   ├── filters.py
│   ├── fixers.py
│   ├── common.py
│   ├── main.py
│   └── transformers.py
├── .gitignore
├── cases
│   └── test.py
├── bowler
│   ├── __init__.py
│   ├── types.py
│   ├── helpers.py
│   ├── type_inference.py
│   ├── imr.py
│   ├── README.md
│   └── tool.py
├── setup.py
├── .travis.yml
├── README.md
└── LICENSE

/fissix/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/paddle_upgrade_tool/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/paddle_upgrade_tool/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.0.25"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.log
*.pyc
*.swp
__pycache__
__pycache__/*
cases
cases/*
build/
dist/
paddle_upgrade_tool.egg-info/
--------------------------------------------------------------------------------
/paddle_upgrade_tool/__main__.py:
--------------------------------------------------------------------------------
import sys

from .main import main

if __name__ == "__main__":
    sys.argv[0] = "paddle_upgrade_tool"
    main()
--------------------------------------------------------------------------------
/fissix/pgen2/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
3 | 4 | """The pgen2 package.""" 5 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/processors.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'demo_post_processor', 3 | ] 4 | 5 | def demo_post_processor(filename, hunks): 6 | print('filename from processor:', filename) 7 | print('hunks from processor:', hunks) 8 | return True 9 | -------------------------------------------------------------------------------- /cases/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | this is docstring 3 | """ 4 | from path.to import old_api 5 | 6 | # this is comment 1 7 | a = paddle.old_api(1, 2) 8 | # this is comment 2 9 | b = paddle.to.old_api(1, 2) 10 | b = paddle.to.old_api(args=1, 2) 11 | 12 | c = paddle.to.old_api_alias1(1, 2) 13 | d = paddle.to1.to2.old_api_alias2(1, 2) 14 | 15 | class CClass: 16 | pass 17 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/filters.py: -------------------------------------------------------------------------------- 1 | from bowler.types import LN, Capture, Filename 2 | 3 | __all__ = [ 4 | 'print_match', 5 | ] 6 | 7 | def print_match(node: LN, capture: Capture, filename: Filename) -> bool: 8 | print('filename:', filename) 9 | print('code:\n"""{}"""\n'.format(str(node))) 10 | print('capture:\n"""{}"""\n'.format(capture)) 11 | print('-' * 10) 12 | return True 13 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/fixers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fissix import fixer_base 4 | from fissix.refactor import RefactoringTool 5 | 6 | __all__ = [ 7 | 'FixerDemo', 8 | ] 9 | 10 | class FixerDemo(fixer_base.BaseFix): 11 | BM_compatible = True 12 | # match all function call 13 | PATTERN = """power< any* >""" 14 | 15 | def __init__(self): 16 | _logger = logging.getLogger("RefactoringTool") 17 | super(FixerDemo, self).__init__(RefactoringTool._default_options, _logger) 18 | 19 | def transform(self, node, results): 20 | print('code passed to transform:', node) 21 | print('results passed to transform:', results) 22 | return node 23 | 24 | 25 | -------------------------------------------------------------------------------- /bowler/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Safe code refactoring for modern Python projects.""" 9 | 10 | __author__ = "John Reese, Facebook" 11 | __version__ = "0.8.0" 12 | 13 | from .imr import FunctionArgument, FunctionSpec 14 | from .query import Query 15 | from .tool import BowlerTool 16 | from .types import ( 17 | ARG_ELEMS, 18 | ARG_END, 19 | ARG_LISTS, 20 | DROP, 21 | LN, 22 | STARS, 23 | START, 24 | SYMBOL, 25 | TOKEN, 26 | BowlerException, 27 | Callback, 28 | Capture, 29 | Filename, 30 | Filter, 31 | Fixers, 32 | Hunk, 33 | IMRError, 34 | Processor, 35 | Stringish, 36 | ) 37 | -------------------------------------------------------------------------------- /fissix/PatternGrammar.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2006 Google, Inc. All Rights Reserved. 
# Licensed to PSF under a Contributor Agreement.

# A grammar to describe tree matching patterns.
# Not shown here:
# - 'TOKEN' stands for any token (leaf node)
# - 'any' stands for any node (leaf or interior)
# With 'any' we can still specify the sub-structure.

# The start symbol is 'Matcher'.

Matcher: Alternatives ENDMARKER

Alternatives: Alternative ('|' Alternative)*

Alternative: (Unit | NegatedUnit)+

Unit: [NAME '='] ( STRING [Repeater]
                 | NAME [Details] [Repeater]
                 | '(' Alternatives ')' [Repeater]
                 | '[' Alternatives ']'
                 )

NegatedUnit: 'not' (STRING | NAME [Details] | '(' Alternatives ')')

Repeater: '*' | '+' | '{' NUMBER [',' NUMBER] '}'

Details: '<' Alternatives '>'
--------------------------------------------------------------------------------
/fissix/pygram.py:
--------------------------------------------------------------------------------
# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Export the Python grammar and symbols."""

# Python imports
import os

# Local imports
from .pgen2 import token
from .pgen2 import driver
from . import pytree

# The grammar file
_GRAMMAR_FILE = os.path.join(os.path.dirname(__file__), "Grammar.txt")
_PATTERN_GRAMMAR_FILE = os.path.join(os.path.dirname(__file__), "PatternGrammar.txt")


class Symbols(object):
    def __init__(self, grammar):
        """Initializer.

        Creates an attribute for each grammar symbol (nonterminal),
        whose value is the symbol's type (an int >= 256).
        """
        for name, symbol in grammar.symbol2number.items():
            setattr(self, name, symbol)


python_grammar = driver.load_packaged_grammar("fissix", _GRAMMAR_FILE)

python_symbols = Symbols(python_grammar)

python_grammar_no_print_statement = python_grammar.copy()
del python_grammar_no_print_statement.keywords["print"]

python_grammar_no_print_and_exec_statement = python_grammar_no_print_statement.copy()
del python_grammar_no_print_and_exec_statement.keywords["exec"]

pattern_grammar = driver.load_packaged_grammar("fissix", _PATTERN_GRAMMAR_FILE)
pattern_symbols = Symbols(pattern_grammar)
--------------------------------------------------------------------------------
/fissix/pgen2/token.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Token constants (from "token.h")."""

# Taken from Python (r53757) and modified to include some tokens
# originally monkeypatched in by pgen2.tokenize

# --start constants--
ENDMARKER = 0
NAME = 1
NUMBER = 2
STRING = 3
NEWLINE = 4
INDENT = 5
DEDENT = 6
LPAR = 7
RPAR = 8
LSQB = 9
RSQB = 10
COLON = 11
COMMA = 12
SEMI = 13
PLUS = 14
MINUS = 15
STAR = 16
SLASH = 17
VBAR = 18
AMPER = 19
LESS = 20
GREATER = 21
EQUAL = 22
DOT = 23
PERCENT = 24
BACKQUOTE = 25
LBRACE = 26
RBRACE = 27
EQEQUAL = 28
NOTEQUAL = 29
LESSEQUAL = 30
GREATEREQUAL = 31
TILDE = 32
CIRCUMFLEX = 33
LEFTSHIFT = 34
RIGHTSHIFT = 35
DOUBLESTAR = 36
PLUSEQUAL = 37
MINEQUAL = 38
STAREQUAL = 39
SLASHEQUAL = 40
PERCENTEQUAL = 41
AMPEREQUAL = 42
VBAREQUAL = 43
CIRCUMFLEXEQUAL = 44
LEFTSHIFTEQUAL = 45
RIGHTSHIFTEQUAL = 46
DOUBLESTAREQUAL = 47
DOUBLESLASH = 48
DOUBLESLASHEQUAL = 49
AT = 50
ATEQUAL = 51
OP = 52
COMMENT = 53
NL = 54
RARROW = 55
AWAIT = 56
ASYNC = 57
ERRORTOKEN = 58
COLONEQUAL = 59
N_TOKENS = 60
NT_OFFSET = 256
# --end constants--

tok_name = {}
for _name, _value in list(globals().items()):
    if type(_value) is type(0):
        tok_name[_value] = _name


def ISTERMINAL(x):
    return x < NT_OFFSET


def ISNONTERMINAL(x):
    return x >= NT_OFFSET


def ISEOF(x):
    return x == ENDMARKER
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import paddle_upgrade_tool
from setuptools import setup, find_packages

with open('requirements.txt') as f:
    REQUIREMENTS = f.read().splitlines()

with open("README.md", "r") as f:
    LONG_DESCRIPTION = f.read()

setup(
    name='paddle_upgrade_tool',
    version=paddle_upgrade_tool.__version__,
    install_requires=REQUIREMENTS,
    author='T8T9, PaddlePaddle',
    author_email='taoshibo@baidu.com',
    keywords=('paddle_upgrade_tool', 'paddle', 'paddlepaddle'),
    url='https://github.com/PaddlePaddle/paddle_upgrade_tool',
    packages=find_packages(),
    package_data={'fissix': ['*.txt']},
    test_suite="paddle_upgrade_tool.tests",
    description='Upgrade python project from paddle-1.x to paddle-2.0',
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    license="Apache License 2.0",
    python_requires=">=3.5.4",
    setup_requires=['wheel'],
    scripts=[],
    entry_points={
        'console_scripts': [
            'paddle_upgrade_tool=paddle_upgrade_tool.main:main',
            'find_pattern=tools.find_pattern:main',
            'find_match_node=tools.find_match_node:main',
        ],
    },
    build_dir="build",
    zip_safe=False,
    classifiers=(
        "License :: OSI Approved",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Operating System :: OS Independent",
    ),
)
--------------------------------------------------------------------------------
/fissix/pgen2/literals.py:
--------------------------------------------------------------------------------
# Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Safely evaluate Python string literals without using eval()."""

import re

simple_escapes = {
    "a": "\a",
    "b": "\b",
    "f": "\f",
    "n": "\n",
    "r": "\r",
    "t": "\t",
    "v": "\v",
    "'": "'",
    '"': '"',
    "\\": "\\",
}


def escape(m):
    all, tail = m.group(0, 1)
    assert all.startswith("\\")
    esc = simple_escapes.get(tail)
    if esc is not None:
        return esc
    if tail.startswith("x"):
        hexes = tail[1:]
        if len(hexes) < 2:
            raise ValueError("invalid hex string escape ('\\%s')" % tail)
        try:
            i = int(hexes, 16)
        except ValueError:
            raise ValueError("invalid hex string escape ('\\%s')" % tail) from None
    else:
        try:
            i = int(tail, 8)
        except ValueError:
            raise ValueError("invalid octal string escape ('\\%s')" % tail) from None
    return chr(i)


def evalString(s):
    assert s.startswith("'") or s.startswith('"'), repr(s[:1])
    q = s[0]
    if s[:3] == q * 3:
        q = q * 3
    assert s.endswith(q), repr(s[-len(q) :])
    assert len(s) >= 2 * len(q)
    s = s[len(q) : -len(q)]
    return re.sub(r"\\(\'|\"|\\|[abfnrtv]|x.{0,2}|[0-7]{1,3})", escape, s)


def test():
    for i in range(256):
        c = chr(i)
        s = repr(c)
        e = evalString(s)
        if e != c:
            print(i, c, s, e)


if __name__ == "__main__":
    test()
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
branches:
  only:
    - master
language: python
os:
  - linux
python:
  - "3.5"
  - "3.6"
  - "3.7"
  - "3.8"

jobs:
  include:
    - name: "python: 3.5 on macOS"
      os: osx
      osx_image: xcode11
      language: shell
      env: PYTHON=35
      before_script:
        - python3 --version
      script:
        - python3 -m unittest discover paddle_upgrade_tool/tests/
        - python3 setup.py sdist bdist_wheel
        - pip3 install dist/paddle_upgrade_tool*.whl
        - cd ..
        - python3 -m unittest paddle_upgrade_tool.tests.test_refactor
      after_success: ""
    - name: "python: 3.5 on Windows"
      os: windows
      language: shell
      before_install:
        - choco install python --version 3.5.4
        - python -m pip install --upgrade pip
        - pip3 install wheel
      script:
        - python -m unittest discover paddle_upgrade_tool/tests/
        - python setup.py sdist bdist_wheel
        - pip3 install dist/paddle_upgrade_tool*.whl
        - cd ..
        - python -m unittest paddle_upgrade_tool.tests.test_refactor
      after_success: ""
      env: PATH=/c/Python35:/c/Python35/Scripts:$PATH

before_install:
  - echo "this is before_install"

before_script:
  - echo "this is before_script"
  - python --version

script:
  - echo "this is script"
  - |
    (
    python -m unittest discover paddle_upgrade_tool/tests/
    python setup.py sdist bdist_wheel
    pip3 install dist/paddle_upgrade_tool*.whl
    cd ..
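    # note: running the remaining steps from the parent directory exercises
    # the installed wheel rather than the local source checkout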
    python -m unittest paddle_upgrade_tool.tests.test_refactor
    paddle_upgrade_tool -h
    pip3 install pytest-cov coveralls
    cd paddle_upgrade_tool
    pytest --cov=paddle_upgrade_tool
    )

after_script:
  - echo "this is after_script"

after_success:
  - echo "this is after_success"
  - coveralls
--------------------------------------------------------------------------------
/bowler/types.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, NewType, Optional, Type, Union

from fissix.fixer_base import BaseFix
from fissix.pgen2 import token
from fissix.pygram import python_symbols
from fissix.pytree import Leaf, Node


class Passthrough:
    def __init__(self, target) -> None:
        self._target = target

    def __getattr__(self, name) -> Any:
        return getattr(self._target, name)


TOKEN = Passthrough(token)
SYMBOL = Passthrough(python_symbols)

SENTINEL = object()
START = object()
DROP = object()

STARS = {TOKEN.STAR, TOKEN.DOUBLESTAR}
ARG_END = {TOKEN.RPAR, TOKEN.COMMA}
ARG_LISTS = {SYMBOL.typedargslist, SYMBOL.arglist}  # function def, function call
ARG_ELEMS = {
    TOKEN.NAME,  # single argument
    SYMBOL.tname,  # type annotated
    SYMBOL.argument,  # keyword argument
    SYMBOL.star_expr,  # vararg
} | STARS

LN = Union[Leaf, Node]
Stringish = Union[str, object]
Filename = NewType("Filename", str)
FilenameMatcher = Callable[[Filename], bool]
Capture = Dict[str, Any]
Callback = Callable[[Node, Capture, Filename], Any]
Filter = Callable[[Node, Capture, Filename], bool]
Fixers = List[Type[BaseFix]]
Hunk = List[str]
Processor = Callable[[Filename, Hunk], bool]


class Transform:
    def __init__(self, selector="", filters=None, callbacks=None, fixer=None, kwargs=None):
        self.selector = selector
        self.kwargs = kwargs if kwargs is not None else {}
        self.filters = filters if filters is not None else []
        self.callbacks = callbacks if callbacks is not None else []
        self.fixer = fixer

class BowlerException(Exception):
    def __init__(self, message="", *, filename="", hunks=None):
        super().__init__(message)
        self.filename = filename
        self.hunks = hunks


class BowlerQuit(BowlerException):
    pass


class IMRError(BowlerException):
    pass


class RetryFile(BowlerException):
    pass


class BadTransform(BowlerException):
    pass
--------------------------------------------------------------------------------
/tools/find_match_node.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import sys
import argparse
from six import StringIO
from lib2to3.patcomp import PatternCompiler
from lib2to3 import pytree
from lib2to3.pgen2 import driver
from lib2to3.pygram import python_symbols, python_grammar

def main():
    parser = argparse.ArgumentParser()
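    # exactly one pattern source (-pf/-ps) and one code source (-sf/-ss)
    # must be supplied; argparse enforces this via the two groups below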
    g1 = parser.add_mutually_exclusive_group(required=True)
    g1.add_argument("-pf", "--pattern-file", dest="pattern_file", type=str, help='Read pattern from the specified file')
    g1.add_argument("-ps", "--pattern-string", dest="pattern_string", type=str, help='A pattern string')
    g2 = parser.add_mutually_exclusive_group(required=True)
    g2.add_argument("-sf", "--source-file", dest="source_file", type=str, help="Read code snippet from the specified file")
    g2.add_argument("-ss", "--source-string", dest="source_string", type=str, help="A code snippet string")
    parser.add_argument("--print-results", dest="print_results", action='store_true', default=False, help="Print match results")
    parser.add_argument("--print-lineno", dest="print_lineno", action='store_true', default=False, help="Print matched code with line numbers")
    # Parse command line arguments
    args = parser.parse_args()

    # parse source snippet to CST tree
    driver_ = driver.Driver(python_grammar, convert=pytree.convert)
    if args.source_file:
        tree = driver_.parse_file(args.source_file)
    else:
        tree = driver_.parse_stream(StringIO(args.source_string + "\n"))
    # compile pattern
    if args.pattern_file:
        with open(args.pattern_file, 'r') as f:
            pattern = f.read()
    else:
        pattern = args.pattern_string
    PC = PatternCompiler()
    pattern, pattern_tree = PC.compile_pattern(pattern, with_tree=True)
    for node in tree.post_order():
        results = {'node': node}
        if pattern.match(node, results):
            match_node = results['node']
            src_lines = str(match_node).splitlines()
            if args.print_lineno:
                # calculate lineno_list according to the right-most leaf node,
                # because some nodes include a prefix, which is not a node, so we can't get its lineno.
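                # e.g. if the matched node prints as three lines and its last
                # leaf is on line 10, the node spans lines 8, 9 and 10.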
                right_most_leaf = match_node
                while not isinstance(right_most_leaf, pytree.Leaf):
                    right_most_leaf = right_most_leaf.children[-1]
                last_lineno = right_most_leaf.get_lineno()
                lineno_list = list(range(last_lineno - len(src_lines) + 1, last_lineno + 1))
                src_lines = [str(lineno) + ' ' + line for lineno, line in zip(lineno_list, src_lines)]
            for line in src_lines:
                print(line)
            if args.print_results:
                print(results)
            print('-' * 20)

if __name__ == "__main__":
    sys.exit(main())
--------------------------------------------------------------------------------
/paddle_upgrade_tool/common.py:
--------------------------------------------------------------------------------
import sys
import logging
import threading
from multiprocessing import Manager


__all__ = [
    'logger',
    'log_to_file',
]

class ColorFormatter(logging.Formatter):
    """Logging formatter that adds per-level colors (colors are disabled on Windows)."""
    def __init__(self, fmt=None, datefmt=None, style='%'):
        super().__init__(fmt, datefmt, style)
        self.fmt = fmt

        light_gray = '\033[0;37m'
        dark_gray = '\033[1;30m'
        yellow = "\033[0;33m"
        red = "\033[0;31m"
        reset = "\033[0m"

        self.FORMATS = {
            logging.DEBUG: dark_gray + fmt + reset,
            logging.INFO: light_gray + fmt + reset,
            logging.WARNING: yellow + fmt + reset,
            logging.ERROR: red + fmt + reset,
        }

    def format(self, record):
        log_fmt = None
        # if not windows, add color info
        if sys.platform.lower() != 'win32':
            log_fmt = self.FORMATS.get(record.levelno)
        if log_fmt is None:
            log_fmt = self.fmt
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)

def log_to_file(log_filepath="report.log"):
    log_filepath = log_filepath or "report.log"
    file_handler = logging.FileHandler(log_filepath)
    log_format = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)

    headless_file_handler = logging.FileHandler(log_filepath)
    headless_log_format = logging.Formatter('%(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    headless_file_handler.setFormatter(headless_log_format)
    headless_logger.addHandler(headless_file_handler)

def _build_default_logger():
    logger = logging.getLogger('paddle_upgrade_tool')
    logger.setLevel("INFO")

    console_handler = logging.StreamHandler(stream=sys.stdout)  # default stream is sys.stderr
    log_format = ColorFormatter('%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)

    return logger

def _build_headless_logger():
    logger = logging.getLogger('statistic')
    logger.setLevel("INFO")

    console_handler = logging.StreamHandler(stream=sys.stdout)  # default stream is sys.stderr
    log_format = ColorFormatter('%(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)

    return logger

logger = _build_default_logger()
headless_logger = _build_headless_logger()

# record refactor log.
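# "statistic" is shared across worker processes via multiprocessing.Manager();
# writers are expected to hold statistic_lock while updating an entry.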
# example:
# statistic = {
#     '/path/to/file1.py': {
#         'info': ['rename "paddle.api1" to "paddle.api2"', 'rename "paddle.api3" to "paddle.api4"'],
#         'warn': ['delete "paddle.api5"'],
#         'error': ['parse "paddle.api6" error'],
#     },
# }
manager = Manager()
statistic = manager.dict()
statistic_lock = threading.Lock()
--------------------------------------------------------------------------------
/tools/find_pattern.py:
--------------------------------------------------------------------------------
"""Script that makes determining PATTERN for a new fix much easier.

Figuring out exactly what PATTERN I want for a given fixer class is
getting tedious. This script will step through each possible subtree
for a given string, allowing you to select which one you want. It will
then try to figure out an appropriate pattern to match that tree. This
pattern will require some editing (it will be overly restrictive) but
should provide a solid base to work with and handle the tricky parts.

Usage:

    python find_pattern.py "g.throw(E, V, T)"

This will step through each subtree in the parse. To reject a
candidate subtree, hit enter; to accept a candidate, hit "y" and
enter. The pattern will be spit out to stdout.

For example, the above will yield a succession of possible snippets,
skipping all leaf-only trees. I accept

    'g.throw(E, V, T)'

This causes find_pattern to spit out

    power< 'g' trailer< '.' 'throw' >
           trailer< '(' arglist< 'E' ',' 'V' ',' 'T' > ')' > > >

Some minor tweaks later, I'm left with

    power< any trailer< '.' 'throw' >
           trailer< '(' args=arglist< exc=any ',' val=any [',' tb=any] > ')' > >

which is exactly what I was after.

Larger snippets can be placed in a file (as opposed to a command-line
arg) and processed with the -f option.
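
For example (with a hypothetical snippet file):

    python find_pattern.py -f snippet.py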
38 | """ 39 | 40 | from __future__ import print_function 41 | from __future__ import division 42 | from __future__ import absolute_import 43 | 44 | __author__ = "Collin Winter " 45 | 46 | # Python imports 47 | import optparse 48 | import sys 49 | import six 50 | from six import StringIO 51 | from six.moves import input 52 | 53 | # Local imports 54 | from lib2to3 import pytree 55 | from lib2to3.pgen2 import driver 56 | from lib2to3.pygram import python_symbols, python_grammar 57 | 58 | driver = driver.Driver(python_grammar, convert=pytree.convert) 59 | 60 | def main(): 61 | args = sys.argv 62 | parser = optparse.OptionParser(usage="find_pattern.py [options] [string]") 63 | parser.add_option("-f", "--file", action="store", 64 | help="Read a code snippet from the specified file") 65 | 66 | # Parse command line arguments 67 | options, args = parser.parse_args(args) 68 | if options.file: 69 | tree = driver.parse_file(options.file) 70 | elif len(args) > 1: 71 | tree = driver.parse_stream(StringIO(args[1] + "\n")) 72 | else: 73 | print("You must specify an input file or an input string", file=sys.stderr) 74 | return 1 75 | 76 | examine_tree(tree) 77 | return 0 78 | 79 | def examine_tree(tree): 80 | for node in tree.post_order(): 81 | if isinstance(node, pytree.Leaf): 82 | continue 83 | print((repr(str(node)))) 84 | verdict = input() 85 | if verdict.strip(): 86 | print((find_pattern(node))) 87 | return 88 | 89 | def find_pattern(node): 90 | if isinstance(node, pytree.Leaf): 91 | return repr(node.value) 92 | 93 | return find_symbol(node.type) + \ 94 | "< " + " ".join(find_pattern(n) for n in node.children) + " >" 95 | 96 | def find_symbol(sym): 97 | for n, v in list(python_symbols.__dict__.items()): 98 | if v == sym: 99 | return n 100 | 101 | if __name__ == "__main__": 102 | sys.exit(main()) 103 | -------------------------------------------------------------------------------- /tools/click.py: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/pallets/click/ 2 | import sys 3 | 4 | _ansi_colors = { 5 | "black": 30, 6 | "red": 31, 7 | "green": 32, 8 | "yellow": 33, 9 | "blue": 34, 10 | "magenta": 35, 11 | "cyan": 36, 12 | "white": 37, 13 | "reset": 39, 14 | "bright_black": 90, 15 | "bright_red": 91, 16 | "bright_green": 92, 17 | "bright_yellow": 93, 18 | "bright_blue": 94, 19 | "bright_magenta": 95, 20 | "bright_cyan": 96, 21 | "bright_white": 97, 22 | } 23 | 24 | _ansi_reset_all = "\033[0m" 25 | 26 | def _interpret_color(color, offset=0): 27 | if isinstance(color, int): 28 | return "{};5;{}".format(38 + offset, color) 29 | 30 | if isinstance(color, (tuple, list)): 31 | r, g, b = color 32 | return "{};2;{};{};{}".format(38 + offset, r, g, b) 33 | 34 | return str(_ansi_colors[color] + offset) 35 | 36 | def style( 37 | text, 38 | fg=None, 39 | bg=None, 40 | bold=None, 41 | dim=None, 42 | underline=None, 43 | blink=None, 44 | reverse=None, 45 | reset=True, 46 | ): 47 | if not isinstance(text, str): 48 | text = str(text) 49 | 50 | if sys.platform.lower() == 'win32': 51 | return text 52 | 53 | bits = [] 54 | 55 | if fg: 56 | try: 57 | bits.append("\033[0;{}m".format(_interpret_color(fg))) 58 | except KeyError: 59 | raise TypeError("Unknown color {}".format(fg)) 60 | 61 | if bg: 62 | try: 63 | bits.append("\033[0;{}m".format(_interpret_color(bg, 10))) 64 | except KeyError: 65 | raise TypeError("Unknown color {}".format(fg)) 66 | 67 | if bold is not None: 68 | bits.append("\033[{}m".format(1 if bold else 22)) 69 | if dim is not None: 70 | 
bits.append("\033[{}m".format(2 if dim else 22)) 71 | if underline is not None: 72 | bits.append("\033[{}m".format(4 if underline else 24)) 73 | if blink is not None: 74 | bits.append("\033[{}m".format(5 if blink else 25)) 75 | if reverse is not None: 76 | bits.append("\033[{}m".format(7 if reverse else 27)) 77 | bits.append(text) 78 | if reset: 79 | bits.append(_ansi_reset_all) 80 | return "".join(bits) 81 | 82 | def secho(message=None, nl=True, err=False, **styles): 83 | if message is not None: 84 | message = style(message, **styles) 85 | 86 | return echo(message, nl=nl, err=err) 87 | 88 | def echo(message=None, nl=True, err=False): 89 | if err: 90 | out = sys.stderr 91 | else: 92 | out = sys.stdout 93 | 94 | # Convert non bytes/text into the native string type. 95 | if message is not None and not isinstance(message, str): 96 | message = str(message) 97 | 98 | if nl: 99 | message = message or "" 100 | if isinstance(message, str): 101 | message += "\n" 102 | else: 103 | message += b"\n" 104 | 105 | if message: 106 | out.write(message) 107 | out.flush() 108 | 109 | class Abort(RuntimeError): 110 | """An internal signalling exception that signals Click to abort.""" 111 | 112 | def confirm(text, default=False, abort=False, prompt_suffix=": ", err=False): 113 | prompt = "{} [{}]{}".format(text, "Y/n", prompt_suffix) 114 | while 1: 115 | try: 116 | # Write the prompt separately so that we get nice 117 | # coloring through colorama on Windows 118 | echo(prompt, nl=False, err=err) 119 | value = input("").lower().strip() 120 | except (KeyboardInterrupt, EOFError): 121 | raise Abort() 122 | if value in ("y", "yes"): 123 | rv = True 124 | elif value in ("n", "no"): 125 | rv = False 126 | elif value == "": 127 | rv = default 128 | else: 129 | echo("Error: invalid input", err=err) 130 | continue 131 | break 132 | if abort and not rv: 133 | raise Abort() 134 | return rv 135 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import sys 7 | import argparse 8 | from tools import click 9 | 10 | from bowler import Query 11 | 12 | from paddle_upgrade_tool.common import * 13 | from paddle_upgrade_tool import refactor, filters, utils 14 | from paddle_upgrade_tool.refactor import * 15 | from paddle_upgrade_tool.spec import change_spec 16 | from paddle_upgrade_tool.utils import backup_inpath, print_statistic 17 | 18 | def should_convert(inpath): 19 | """ 20 | check if convert should be run. 21 | """ 22 | # check if inpath exists, and python files in inpath are valid. 
    if not utils.valid_path(inpath):
        return False
    return True

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--log-level", dest="log_level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="set log level, default is INFO")
    parser.add_argument("--no-log-file", dest="no_log_file", action='store_true', default=False, help="don't log to file")
    parser.add_argument("--log-filepath", dest="log_filepath", type=str, help='set log file path, default is "report.log"')
    parser.add_argument("-i", "--inpath", required=True, type=str, help='the file or directory path you want to upgrade.')
    parser.add_argument("-b", "--backup", type=str, nargs='?', default=None, const=None, help='backup directory, default is "~/.paddle_upgrade_tool/".')
    parser.add_argument("-w", "--write", action='store_true', default=False, help='modify files in-place.')
    parser.add_argument("--no-confirm", dest="no_confirm", action='store_true', default=False, help='write files in-place without confirmation, ignored without --write.')
    parser.add_argument("-p", "--parallel", type=int, default=None, help='specify the maximum number of concurrent processes to use when refactoring, ignored with --no-confirm.')
    parser.add_argument("-r", "--refactor", action='append', choices=refactor.__all__, help='this is a debug option. Specify the refactor you want to run; if omitted, all refactors will be run.')
    parser.add_argument("--print-match", action='store_true', default=False, help='this is a debug option. Print the matched code and node for each file.')

    args = parser.parse_args()
    if args.refactor:
        args.refactor = set(args.refactor)
    if args.backup is None:
        home = os.path.expanduser('~')
        args.backup = os.path.join(home, '.paddle_upgrade_tool')
    else:
        args.backup = os.path.expanduser(args.backup)

    if args.log_level:
        logger.setLevel(args.log_level)
    if not args.no_log_file:
        log_to_file(args.log_filepath)
    if not should_convert(args.inpath):
        logger.error("convert abort!")
        sys.exit(1)

    # refactor code via "Query" step by step.
    q = Query(args.inpath)
    for fn in refactor.__all__:
        refactor_func = getattr(refactor, fn)
        if args.refactor and fn not in args.refactor:
            continue
        assert callable(refactor_func), "{} is not callable.".format(fn)
        logger.debug("run refactor: {}".format(fn))
        if args.print_match:
            refactor_func(q, change_spec).filter(filters.print_match)
        else:
            refactor_func(q, change_spec)

    if args.write:
        # backup args.inpath
        backup = backup_inpath(args.inpath, args.backup)
        # print diff to stdout, and modify files in place.
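        # Windows runs in a single process here; other platforms can fan the
        # work out across "--parallel" worker processes.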
        if utils.is_windows():
            q.execute(write=True, silent=False, need_confirm=not args.no_confirm, backup=backup, in_process=True)
        else:
            q.execute(write=True, silent=False, need_confirm=not args.no_confirm, parallel=args.parallel, backup=backup)
    else:
        # print diff to stdout
        if utils.is_windows():
            q.execute(write=False, silent=False, in_process=True)
        else:
            q.execute(write=False, silent=False, parallel=args.parallel)
        click.secho('Refactor finished without touching source files, add "--write" to modify source files in-place if everything is ok.', fg="red", bold=True)

    if not (sys.version_info.major == 3 and sys.version_info.minor == 5):
        print_statistic(levels=['warning'])

if __name__ == "__main__":
    sys.exit(main())
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Upgrade your python model from paddle-1.x to paddle-2.

[![Build Status](https://travis-ci.org/PaddlePaddle/paddle_upgrade_tool.svg?branch=master)](https://travis-ci.org/PaddlePaddle/paddle_upgrade_tool)
[![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/paddle_upgrade_tool/badge.svg?branch=master&kill_cache=1)](https://coveralls.io/github/PaddlePaddle/paddle_upgrade_tool?branch=master)
[![Version](https://img.shields.io/pypi/v/paddle_upgrade_tool)](https://pypi.org/project/paddle_upgrade_tool)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

### Attention
`paddle_upgrade_tool` aims to convert python files from paddle-1.x to paddle-2 one by one; it won't handle indirect imports, e.g.

```python
# filename "a.py"
import paddle.fluid as fluid
pass

# filename "b.py"
from a import *
class MyLayer(fluid.layers.Layer):
    pass
```

`fluid.layers.Layer` in "b.py" won't get converted.
**So you have to make sure you have imported all used paddle modules, classes, and objects directly in every python file before running `paddle_upgrade_tool`.**

### Install
paddle_upgrade_tool supports Linux, Mac OS, and Windows ([Git Bash](https://gitforwindows.org/) is recommended), **but it requires Python 3.5.4 or higher to run**. Multi-process mode is supported on Linux and Mac OS; only single-process mode is supported on Windows, which leads to a performance difference.

1. install with pip

```bash
pip install -U paddle_upgrade_tool
paddle_upgrade_tool --help # show help
paddle_upgrade_tool --inpath /path/to/model.py # upgrade your model from paddle-1.x to paddle-2.0
```

**ATTENTION**: If your device contains multiple versions of python, you may need to run the following commands instead:
```bash
python3 -m pip install -U paddle_upgrade_tool
python3 -m paddle_upgrade_tool -h
python3 -m paddle_upgrade_tool --inpath /path/to/model.py
```

2. install from source

```bash
git clone https://github.com/T8T9/paddle_upgrade_tool.git
cd paddle_upgrade_tool
python setup.py sdist bdist_wheel
pip install -U ./dist/paddle_upgrade_tool-*.whl
paddle_upgrade_tool --help # show help
paddle_upgrade_tool --inpath /path/to/model.py # upgrade your model from paddle-1.x to paddle-2.0
```

### Develop
If you are a developer and want to test your code quickly, you can run the following command in the project directory:

```bash
python -m paddle_upgrade_tool --inpath /path/to/model.py

# or

python paddle_upgrade_tool/main.py --inpath /path/to/model.py
```

Moreover, if you want to run a specific refactor, you can use the following command:

```bash
python -m paddle_upgrade_tool --inpath /path/to/model.py --refactor <refactor_name>
```

use `python -m paddle_upgrade_tool -h` to see the full list of all refactors.

if you want to run all unittests, use command:

```bash
python -m unittest discover paddle_upgrade_tool/tests/
# or
python setup.py test
```
or use command:

```bash
python -m unittest paddle_upgrade_tool/tests/test_refactor.py
```
to run a specific test file.

### Change Spec
`change_spec` is a python dict defined in spec.py; it defines the rules used to refactor your code.

```python
change_spec = {
    "path.to.old_api": {
        "alias": [
            "path.to.old_api_alias1",
            "path.to1.to2.old_api_alias2",
        ],
        "update_to": "path.to.new_api",
        "warning": "this api is deprecated.",
        "args_list": ["arg1", "arg2"],
        "args_change": [
            ["arg2", "arg2_rename"],
            ["arg3", ""],
            ["", "new_arg", "default_value"],
        ],
        "args_warning": {"arg1": "warning message"},
        "args_transformer": "_default_transformer",
    },
}
```

- `alias`: a list of aliases of the main alias `path.to.old_api`; all aliases will be replaced with the main alias.
- `update_to`: `path.to.old_api` will be replaced with this new api if specified.
- `warning`: print the specified warning message when `path.to.old_api` is found. This field will be ignored if `update_to` is specified.
- `args_list`: the argument list of `path.to.old_api`.
- `args_change`: a list of lists in the following formats:
  - `["arg", "new_arg"]`: rename an argument, e.g. `func(arg=value)` -> `func(new_arg=value)`
  - `["arg", ""]`: remove an argument, e.g. `func(arg=value)` -> `func()`
  - `["", "new_arg", "default_value"]`: add a new argument, e.g. `func(arg=value)` -> `func(arg=value, new_arg=default_value)`
- `args_warning`: print the specified warning message for the specified argument after applying `args_change`.
- `args_transformer`: execute a customized transformer on an [AST node](https://github.com/python/cpython/blob/75c80b0bda89debf312f075716b8c467d411f90e/Lib/lib2to3/pytree.py#L207); it will be called after applying `args_change` to do further refactoring.

### Other Tools
1. find the pattern of a specific code snippet, usage:

```bash
find_pattern 'import paddle'
```
The `find_pattern` command will traverse all nodes in the AST; when you see the code snippet you want, type 'y' to get its pattern.

2. find the matched node in specific code for a specific pattern, usage:

```bash
find_match_node -ss 'import paddle' -ps 'any'
```

You can also specify the "--print-results" option to get the representation of the matched node, or "--print-lineno" to get the line numbers of the matched code.


### Acknowledgements
- [Bowler](https://github.com/facebookincubator/Bowler/): Safe code refactoring for modern Python projects.
- [lib2to3](https://github.com/python/cpython/tree/master/Lib/lib2to3): A built-in python library to refactor python code.
- [fissix](https://github.com/jreese/fissix/): A backport of latest lib2to3, with enhancements.
--------------------------------------------------------------------------------
/fissix/pgen2/grammar.py:
--------------------------------------------------------------------------------
# Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""This module defines the data structures used to represent a grammar.

These are a bit arcane because they are derived from the data
structures used by Python's 'pgen' parser generator.

There's also a table here mapping operators to their names in the
token module; the Python tokenize module reports all operators as the
fallback token code OP, but the parser needs the actual token code.

"""

# Python imports
import pickle

# Local imports
from . import token


class Grammar(object):
    """Pgen parsing tables conversion class.

    Once initialized, this class supplies the grammar tables for the
    parsing engine implemented by parse.py.  The parsing engine
    accesses the instance variables directly.  The class here does not
    provide initialization of the tables; several subclasses exist to
    do this (see the conv and pgen modules).

    The load() method reads the tables from a pickle file, which is
    much faster than the other ways offered by subclasses.  The pickle
    file is written by calling dump() (after loading the grammar
    tables using a subclass).  The report() method prints a readable
    representation of the tables to stdout, for debugging.

    The instance variables are as follows:

    symbol2number -- a dict mapping symbol names to numbers.  Symbol
                     numbers are always 256 or higher, to distinguish
                     them from token numbers, which are between 0 and
                     255 (inclusive).

    number2symbol -- a dict mapping numbers to symbol names;
                     these two are each other's inverse.

    states        -- a list of DFAs, where each DFA is a list of
                     states, each state is a list of arcs, and each
                     arc is a (i, j) pair where i is a label and j is
                     a state number.  The DFA number is the index into
                     this list.  (This name is slightly confusing.)
                     Final states are represented by a special arc of
                     the form (0, j) where j is its own state number.

    dfas          -- a dict mapping symbol numbers to (DFA, first)
                     pairs, where DFA is an item from the states list
                     above, and first is a set of tokens that can
                     begin this grammar rule (represented by a dict
                     whose values are always 1).

    labels        -- a list of (x, y) pairs where x is either a token
                     number or a symbol number, and y is either None
                     or a string; the strings are keywords.  The label
                     number is the index in this list; label numbers
                     are used to mark state transitions (arcs) in the
                     DFAs.

    start         -- the number of the grammar's start symbol.

    keywords      -- a dict mapping keyword strings to arc labels.

    tokens        -- a dict mapping token numbers to arc labels.

    """

    def __init__(self):
        self.symbol2number = {}
        self.number2symbol = {}
        self.states = []
        self.dfas = {}
        self.labels = [(0, "EMPTY")]
        self.keywords = {}
        self.tokens = {}
        self.symbol2label = {}
        self.start = 256

    def dump(self, filename):
        """Dump the grammar tables to a pickle file."""
        with open(filename, "wb") as f:
            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)

    def load(self, filename):
        """Load the grammar tables from a pickle file."""
        with open(filename, "rb") as f:
            d = pickle.load(f)
        self.__dict__.update(d)

    def loads(self, pkl):
        """Load the grammar tables from a pickle bytes object."""
        self.__dict__.update(pickle.loads(pkl))

    def copy(self):
        """
        Copy the grammar.
        """
        new = self.__class__()
        for dict_attr in (
            "symbol2number",
            "number2symbol",
            "dfas",
            "keywords",
            "tokens",
            "symbol2label",
        ):
            setattr(new, dict_attr, getattr(self, dict_attr).copy())
        new.labels = self.labels[:]
        new.states = self.states[:]
        new.start = self.start
        return new

    def report(self):
        """Dump the grammar tables to standard output, for debugging."""
        from pprint import pprint

        print("s2n")
        pprint(self.symbol2number)
        print("n2s")
        pprint(self.number2symbol)
        print("states")
        pprint(self.states)
        print("dfas")
        pprint(self.dfas)
        print("labels")
        pprint(self.labels)
        print("start", self.start)


# Map from operator to number (since tokenize doesn't do this)

opmap_raw = """
( LPAR
) RPAR
[ LSQB
] RSQB
: COLON
, COMMA
; SEMI
+ PLUS
- MINUS
* STAR
/ SLASH
| VBAR
& AMPER
< LESS
> GREATER
= EQUAL
. DOT
% PERCENT
` BACKQUOTE
{ LBRACE
} RBRACE
@ AT
@= ATEQUAL
== EQEQUAL
!= NOTEQUAL
<> NOTEQUAL
<= LESSEQUAL
>= GREATEREQUAL
~ TILDE
^ CIRCUMFLEX
<< LEFTSHIFT
>> RIGHTSHIFT
** DOUBLESTAR
+= PLUSEQUAL
-= MINEQUAL
*= STAREQUAL
/= SLASHEQUAL
%= PERCENTEQUAL
&= AMPEREQUAL
|= VBAREQUAL
^= CIRCUMFLEXEQUAL
<<= LEFTSHIFTEQUAL
>>= RIGHTSHIFTEQUAL
**= DOUBLESTAREQUAL
// DOUBLESLASH
//= DOUBLESLASHEQUAL
-> RARROW
:= COLONEQUAL
"""

opmap = {}
for line in opmap_raw.splitlines():
    if line:
        op, name = line.split()
        opmap[op] = getattr(token, name)
--------------------------------------------------------------------------------
/fissix/pgen2/driver.py:
--------------------------------------------------------------------------------
# Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

# Modifications:
# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Parser driver.
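
Sketch of typical use (tools/find_pattern.py builds a Driver the same way):

    d = Driver(pygram.python_grammar, convert=pytree.convert)
    tree = d.parse_string("x = 1\n")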

This provides a high-level interface to parse a file into a syntax tree.

"""

__author__ = "Guido van Rossum <guido@python.org>"

__all__ = ["Driver", "load_grammar"]

# Python imports
import io
import os
import logging
import pkgutil
import sys

# Pgen imports
from . import grammar, parse, token, tokenize, pgen


class Driver(object):
    def __init__(self, grammar, convert=None, logger=None):
        self.grammar = grammar
        if logger is None:
            logger = logging.getLogger()
        self.logger = logger
        self.convert = convert

    def parse_tokens(self, tokens, debug=False):
        """Parse a series of tokens and return the syntax tree."""
        # XXX Move the prefix computation into a wrapper around tokenize.
        p = parse.Parser(self.grammar, self.convert)
        p.setup()
        lineno = 1
        column = 0
        type = value = start = end = line_text = None
        prefix = ""
        for quintuple in tokens:
            type, value, start, end, line_text = quintuple
            if start != (lineno, column):
                assert (lineno, column) <= start, ((lineno, column), start)
                s_lineno, s_column = start
                if lineno < s_lineno:
                    prefix += "\n" * (s_lineno - lineno)
                    lineno = s_lineno
                    column = 0
                if column < s_column:
                    prefix += line_text[column:s_column]
                    column = s_column
            if type in (tokenize.COMMENT, tokenize.NL):
                prefix += value
                lineno, column = end
                if value.endswith("\n"):
                    lineno += 1
                    column = 0
                continue
            if type == token.OP:
                type = grammar.opmap[value]
            if debug:
                self.logger.debug(
                    "%s %r (prefix=%r)", token.tok_name[type], value, prefix
                )
            if p.addtoken(type, value, (prefix, start)):
                if debug:
                    self.logger.debug("Stop.")
                break
            prefix = ""
            lineno, column = end
            if value.endswith("\n"):
                lineno += 1
                column = 0
        else:
            # We never broke out -- EOF is too soon (how can this happen???)
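            # (reached when the token stream ends before addtoken() ever
            # returned True, i.e. the parser never accepted a complete input)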
            raise parse.ParseError("incomplete input", type, value, (prefix, start))
        return p.rootnode

    def parse_stream_raw(self, stream, debug=False):
        """Parse a stream and return the syntax tree."""
        tokens = tokenize.generate_tokens(stream.readline)
        return self.parse_tokens(tokens, debug)

    def parse_stream(self, stream, debug=False):
        """Parse a stream and return the syntax tree."""
        return self.parse_stream_raw(stream, debug)

    def parse_file(self, filename, encoding=None, debug=False):
        """Parse a file and return the syntax tree."""
        with io.open(filename, "r", encoding=encoding) as stream:
            return self.parse_stream(stream, debug)

    def parse_string(self, text, debug=False):
        """Parse a string and return the syntax tree."""
        tokens = tokenize.generate_tokens(io.StringIO(text).readline)
        return self.parse_tokens(tokens, debug)


def _generate_pickle_name(gt):
    head, tail = os.path.splitext(gt)
    if tail == ".txt":
        tail = ""
    return head + tail + ".".join(map(str, sys.version_info)) + ".pickle"


def load_grammar(gt="Grammar.txt", gp=None, save=True, force=False, logger=None):
    """Load the grammar (maybe from a pickle)."""
    if logger is None:
        logger = logging.getLogger()
    gp = _generate_pickle_name(gt) if gp is None else gp
    if force or not _newer(gp, gt):
        logger.info("Generating grammar tables from %s", gt)
        g = pgen.generate_grammar(gt)
        if save:
            logger.info("Writing grammar tables to %s", gp)
            try:
                g.dump(gp)
            except OSError as e:
                logger.info("Writing failed: %s", e)
    else:
        g = grammar.Grammar()
        g.load(gp)
    return g


def _newer(a, b):
    """Inquire whether file a was written since file b."""
    if not os.path.exists(a):
        return False
    if not os.path.exists(b):
        return True
    return os.path.getmtime(a) >= os.path.getmtime(b)


def load_packaged_grammar(package, grammar_source):
    """Normally, loads a pickled grammar by doing
        pkgutil.get_data(package, pickled_grammar)
    where *pickled_grammar* is computed from *grammar_source* by adding the
    Python version and using a ``.pickle`` extension.

    However, if *grammar_source* is an extant file, load_grammar(grammar_source)
    is called instead. This facilitates using a packaged grammar file when needed
    but preserves load_grammar's automatic regeneration behavior when possible.

    """
    if os.path.isfile(grammar_source):
        return load_grammar(grammar_source, save=False, force=True)
    pickled_name = _generate_pickle_name(os.path.basename(grammar_source))
    data = pkgutil.get_data(package, pickled_name)
    g = grammar.Grammar()
    g.loads(data)
    return g


def main(*args):
    """Main program, when run as a script: produce grammar pickle files.

    Calls load_grammar for each argument, a path to a grammar text file.
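
    For example:  python -m fissix.pgen2.driver Grammar.txt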
165 | """ 166 | if not args: 167 | args = sys.argv[1:] 168 | logging.basicConfig(level=logging.INFO, stream=sys.stdout, format="%(message)s") 169 | for gt in args: 170 | load_grammar(gt, save=False, force=True) 171 | return True 172 | 173 | 174 | if __name__ == "__main__": 175 | sys.exit(int(not main())) 176 | -------------------------------------------------------------------------------- /bowler/helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | from typing import List, Optional, Sequence, Union 10 | 11 | from tools import click 12 | from fissix.pgen2.token import tok_name 13 | from fissix.pytree import Leaf, Node, type_repr 14 | 15 | from .types import LN, SYMBOL, TOKEN, Capture, Filename, FilenameMatcher 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | INDENT_STR = ". " 20 | 21 | 22 | def print_selector_pattern( 23 | node, results = None, filename = None 24 | ): 25 | key = "" 26 | if results: 27 | for k, v in results.items(): 28 | if node == v: 29 | key = k + "=" 30 | elif isinstance(v, list) and node in v: # v is a list? 31 | key = k + "=" 32 | 33 | if isinstance(node, Leaf): 34 | click.echo("{}{} ".format(key, repr(node.value)), nl=False) 35 | else: 36 | click.echo("{}{} ".format(key, type_repr(node.type)), nl=False) 37 | if node.children: 38 | click.echo("< ", nl=False) 39 | for child in node.children: 40 | print_selector_pattern(child, results, filename) 41 | click.echo("> ", nl=False) 42 | 43 | 44 | def print_tree( 45 | node, 46 | results = None, 47 | filename = None, 48 | indent = 0, 49 | recurse = -1, 50 | ): 51 | filename = filename or Filename("") 52 | tab = INDENT_STR * indent 53 | if filename and indent == 0: 54 | click.secho(filename, fg="red", bold=True) 55 | 56 | if isinstance(node, Leaf): 57 | click.echo( 58 | click.style(tab, fg="black", bold=True) 59 | + click.style( 60 | "[{}] {} {}".format(tok_name[node.type], repr(node.prefix), repr(node.value)), 61 | fg="yellow", 62 | ) 63 | ) 64 | else: 65 | click.echo( 66 | click.style(tab, fg="black", bold=True) 67 | + click.style("[{}] {}".format(type_repr(node.type), repr(node.prefix)), fg="blue") 68 | ) 69 | 70 | if node.children: 71 | if recurse: 72 | for child in node.children: 73 | # N.b. do not pass results here since we print them once 74 | # at the end. 75 | print_tree(child, indent=indent + 1, recurse=recurse - 1) 76 | else: 77 | click.echo(INDENT_STR * (indent + 1) + "...") 78 | 79 | if results is None: 80 | return 81 | 82 | for key in results: 83 | if key == "node": 84 | continue 85 | 86 | value = results[key] 87 | if isinstance(value, (Leaf, Node)): 88 | click.secho("results[{}] =".format(repr(key)), fg="red") 89 | print_tree(value, indent=1, recurse=1) 90 | else: 91 | # TODO: Improve display of multi-match here, see 92 | # test_print_tree_captures test. 
93 | click.secho("results[{}] = {}".format(repr(key), value), fg="red") 94 | 95 | 96 | def dotted_parts(name): 97 | pre, dot, post = name.partition(".") 98 | if post: 99 | post_parts = dotted_parts(post) 100 | else: 101 | post_parts = [] 102 | result = [] 103 | if pre: 104 | result.append(pre) 105 | if pre and dot: 106 | result.append(dot) 107 | if post_parts: 108 | result.extend(post_parts) 109 | return result 110 | 111 | 112 | def quoted_parts(name): 113 | return ["'{}'".format(part) for part in dotted_parts(name)] 114 | 115 | 116 | def power_parts(name): 117 | parts = quoted_parts(name) 118 | index = 0 119 | while index < len(parts): 120 | if parts[index] == "'.'": 121 | parts.insert(index, "trailer<") 122 | parts.insert(index + 3, ">") 123 | index += 1 124 | index += 1 125 | return parts 126 | 127 | 128 | def is_method(node): 129 | return ( 130 | node.type == SYMBOL.funcdef 131 | and node.parent is not None 132 | and node.parent.type == SYMBOL.suite 133 | and node.parent.parent is not None 134 | and node.parent.parent.type == SYMBOL.classdef 135 | ) 136 | 137 | 138 | def is_call_to(node, func_name): 139 | """Returns whether the node represents a call to the named function.""" 140 | return ( 141 | node.type == SYMBOL.power 142 | and node.children[0].type == TOKEN.NAME 143 | and node.children[0].value == func_name 144 | ) 145 | 146 | 147 | def find_first(node, target, recursive = False): 148 | queue = [node] 149 | queue.extend(node.children) 150 | while queue: 151 | child = queue.pop(0) 152 | if child.type == target: 153 | return child 154 | if recursive: 155 | queue = child.children + queue 156 | return None 157 | 158 | 159 | def find_previous(node, target, recursive = False): 160 | while node.prev_sibling is not None: 161 | node = node.prev_sibling 162 | result = find_last(node, target, recursive) 163 | if result: 164 | return result 165 | return None 166 | 167 | 168 | def find_next(node, target, recursive = False): 169 | while node.next_sibling is not None: 170 | node = node.next_sibling 171 | result = find_first(node, target, recursive) 172 | if result: 173 | return result 174 | return None 175 | 176 | 177 | def find_last(node, target, recursive = False): 178 | queue = [] 179 | queue.extend(reversed(node.children)) 180 | while queue: 181 | child = queue.pop(0) 182 | if recursive: 183 | result = find_last(child, target, recursive) 184 | if result: 185 | return result 186 | if child.type == target: 187 | return child 188 | return None 189 | 190 | 191 | def get_class(node): 192 | while node.parent is not None: 193 | if node.type == SYMBOL.classdef: 194 | return node 195 | node = node.parent 196 | raise ValueError("classdef node not found") 197 | 198 | 199 | class Once: 200 | """Simple object that evaluates to True once, and then always False.""" 201 | 202 | def __init__(self): 203 | self.done = False 204 | 205 | def __bool__(self): 206 | if self.done: 207 | return False 208 | else: 209 | self.done = True 210 | return True 211 | 212 | 213 | def filename_endswith(ext): 214 | if isinstance(ext, str): 215 | ext = [ext] 216 | 217 | def inner(filename): 218 | return any(filename.endswith(e) for e in ext) 219 | 220 | return inner 221 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/transformers.py: -------------------------------------------------------------------------------- 1 | from bowler.types import LN, Capture, Filename, SYMBOL, TOKEN 2 | from fissix.fixer_util import Name, Call, Number, KeywordArg, Comma, Newline 3 | from 
fissix.pytree import Leaf, Node, type_repr
4 | from fissix.pygram import python_grammar, python_symbols
5 | from fissix.pgen2 import token
6 | from fissix import patcomp
7 | 
8 | from paddle_upgrade_tool import utils
9 | from paddle_upgrade_tool.utils import log_debug, log_info, log_warning, log_error
10 | 
11 | 
12 | def default_transformer(node: LN, capture: Capture, filename: Filename):
13 | fp = capture.get("function_parameters")
14 | if fp and fp.children[1].type == SYMBOL.arglist:
15 | arg_node = KeywordArg(Name("trans_arg"), Number("1"))
16 | fp.children[1].append_child(Comma())
17 | fp.children[1].append_child(arg_node)
18 | 
19 | 
20 | def act_transformer(filename, trailer_node, removed_value):
21 | """
22 | Add act to the forward function after deleting the act arg from the API.
23 | """
24 | if removed_value == "None":
25 | return
26 | # parent must be a power node
27 | power_node = trailer_node.parent
28 | if not isinstance(power_node, Node) or power_node.type != python_symbols.power:
29 | return
30 | # parent of parent must be an expression
31 | expr_node = power_node.parent
32 | if not isinstance(expr_node, Node) or expr_node.type != python_symbols.expr_stmt:
33 | return
34 | assign_idx = -1 # "=" index
35 | for idx in range(len(expr_node.children)):
36 | if expr_node.children[idx].type == token.EQUAL:
37 | assign_idx = idx
38 | break
39 | if assign_idx == -1:
40 | return
41 | layer_name = utils.node2code(expr_node.children[0:assign_idx])
42 | # Layer Class
43 | if 'self.' in layer_name:
44 | _forward_act_transformer(filename, expr_node, layer_name, removed_value)
45 | # invoke activation function directly
46 | else:
47 | _function_act_transformer(filename, expr_node, removed_value)
48 | 
49 | _pattern_funcdef_forward = "funcdef< 'def' 'forward' parameters< '(' ( 'self' | typedargslist< 'self' any* > ) any* ')' > any* >"
50 | _pattern_funcdef_forward = patcomp.compile_pattern(_pattern_funcdef_forward)
51 | _pattern_expr_stmt = "simple_stmt< expr_stmt< left=(any*) '=' right=(any*) > any* >"
52 | _pattern_expr_stmt = patcomp.compile_pattern(_pattern_expr_stmt)
53 | 
54 | def _forward_act_transformer(filename, expr_node, layer_name, removed_value):
55 | # find funcdef node
56 | funcdef_node = None
57 | node = expr_node
58 | while node is not None:
59 | if node.type == python_symbols.funcdef:
60 | funcdef_node = node
61 | break
62 | node = node.parent
63 | if funcdef_node is None:
64 | return
65 | # find the `def forward` function node
66 | forward_node = None
67 | classdef_node = None
68 | node = funcdef_node
69 | while node is not None:
70 | if node.type == python_symbols.classdef:
71 | classdef_node = node
72 | break
73 | node = node.parent
74 | if classdef_node is None:
75 | return
76 | for node in classdef_node.pre_order():
77 | results = {'node': node}
78 | if _pattern_funcdef_forward.match(node, results) and results["node"] is node:
79 | forward_node = node
80 | break
81 | if forward_node is None:
82 | return
83 | for node in forward_node.post_order():
84 | results = {'node': node}
85 | if _pattern_expr_stmt.match(node, results) and results["node"] is node:
86 | right = utils.node2code(results['right']).strip()
87 | if not utils.startswith(right, layer_name):
88 | continue
89 | left = utils.node2code(results['left']).strip()
90 | # removed_value is a string literal
91 | if '"' in removed_value or "'" in removed_value:
92 | act = removed_value
93 | act = act.strip('"')
94 | act = act.strip("'")
95 | act = act.strip()
96 | # create statement like "x = paddle.nn.functional.act(x)"
97 | code = left + ' = ' +
'paddle.nn.functional.' + act + '(' + left + ')' 98 | _create_simple_stmt_node_and_insert_behind(code, node) 99 | # removed_value is a variable 100 | else: 101 | # add "self._act = act" after expr_node to make it visible to other methods 102 | act_var_name = "self._" + removed_value 103 | code = act_var_name + " = " + removed_value 104 | _create_simple_stmt_node_and_insert_behind(code, expr_node.parent) 105 | # create statement like "x = getattr(paddle.nn.functional, act)(x) if act else x" 106 | code = left + " = getattr(paddle.nn.functional, " + act_var_name + ")(" + left + ") if " + act_var_name + " else " + left 107 | _create_simple_stmt_node_and_insert_behind(code, node) 108 | log_warning(filename, expr_node.get_lineno(), 'variable "{}" may not be visible here.'.format(removed_value)) 109 | 110 | 111 | def _function_act_transformer(filename, expr_node, removed_value): 112 | simple_stmt_node = expr_node.parent 113 | results = {'node': simple_stmt_node} 114 | if _pattern_expr_stmt.match(simple_stmt_node, results) and results["node"] is simple_stmt_node: 115 | left = utils.node2code(results['left']).strip() 116 | # if removed_value type is str 117 | if '"' in removed_value or "'" in removed_value: 118 | act = removed_value 119 | act = act.strip('"') 120 | act = act.strip("'") 121 | act = act.strip() 122 | # create statement like "x = paddle.nn.functional.act(x)" 123 | code = left + ' = ' + 'paddle.nn.functional.' + act + '(' + left + ')' 124 | _create_simple_stmt_node_and_insert_behind(code, simple_stmt_node) 125 | # removed_value is a variable 126 | else: 127 | # create statement like "x = getattr(paddle.nn.functional, act)(x) if act else x" 128 | code = left + " = getattr(paddle.nn.functional, " + removed_value + ")(" + left + ") if " + removed_value + " else " + left 129 | _create_simple_stmt_node_and_insert_behind(code, simple_stmt_node) 130 | 131 | 132 | def _create_simple_stmt_node_and_insert_behind(code, node): 133 | if node is None or node.type != python_symbols.simple_stmt: 134 | return 135 | simple_stmt_node = Node(python_symbols.simple_stmt, [utils.newline_node(node)]) 136 | _node = utils.code_repr(code).children[0].children[0] 137 | _node.parent = None 138 | simple_stmt_node.insert_child(0, _node) 139 | simple_stmt_node.prefix = utils.get_indent(node) 140 | utils.insert_node_behind(node, simple_stmt_node) 141 | -------------------------------------------------------------------------------- /fissix/Grammar.txt: -------------------------------------------------------------------------------- 1 | # Grammar for 2to3. This grammar supports Python 2.x and 3.x. 2 | 3 | # NOTE WELL: You should also follow all the steps listed at 4 | # https://devguide.python.org/grammar/ 5 | 6 | # Start symbols for the grammar: 7 | # file_input is a module or sequence of commands read from an input file; 8 | # single_input is a single interactive statement; 9 | # eval_input is the input for the eval() and input() functions. 10 | # NB: compound_stmt in single_input is followed by extra NEWLINE! 
11 | file_input: (NEWLINE | stmt)* ENDMARKER 12 | single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE 13 | eval_input: testlist NEWLINE* ENDMARKER 14 | 15 | decorator: '@' dotted_name [ '(' [arglist] ')' ] NEWLINE 16 | decorators: decorator+ 17 | decorated: decorators (classdef | funcdef | async_funcdef) 18 | async_funcdef: ASYNC funcdef 19 | funcdef: 'def' NAME parameters ['->' test] ':' suite 20 | parameters: '(' [typedargslist] ')' 21 | typedargslist: ((tfpdef ['=' test] ',')* 22 | ('*' [tname] (',' tname ['=' test])* [',' ['**' tname [',']]] | '**' tname [',']) 23 | | tfpdef ['=' test] (',' tfpdef ['=' test])* [',']) 24 | tname: NAME [':' test] 25 | tfpdef: tname | '(' tfplist ')' 26 | tfplist: tfpdef (',' tfpdef)* [','] 27 | varargslist: ((vfpdef ['=' test] ',')* 28 | ('*' [vname] (',' vname ['=' test])* [',' ['**' vname [',']]] | '**' vname [',']) 29 | | vfpdef ['=' test] (',' vfpdef ['=' test])* [',']) 30 | vname: NAME 31 | vfpdef: vname | '(' vfplist ')' 32 | vfplist: vfpdef (',' vfpdef)* [','] 33 | 34 | stmt: simple_stmt | compound_stmt 35 | simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE 36 | small_stmt: (expr_stmt | print_stmt | del_stmt | pass_stmt | flow_stmt | 37 | import_stmt | global_stmt | exec_stmt | assert_stmt) 38 | expr_stmt: testlist_star_expr (annassign | augassign (yield_expr|testlist) | 39 | ('=' (yield_expr|testlist_star_expr))*) 40 | annassign: ':' test ['=' test] 41 | testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [','] 42 | augassign: ('+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' | 43 | '<<=' | '>>=' | '**=' | '//=') 44 | # For normal and annotated assignments, additional restrictions enforced by the interpreter 45 | print_stmt: 'print' ( [ test (',' test)* [','] ] | 46 | '>>' test [ (',' test)+ [','] ] ) 47 | del_stmt: 'del' exprlist 48 | pass_stmt: 'pass' 49 | flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt | yield_stmt 50 | break_stmt: 'break' 51 | continue_stmt: 'continue' 52 | return_stmt: 'return' [testlist_star_expr] 53 | yield_stmt: yield_expr 54 | raise_stmt: 'raise' [test ['from' test | ',' test [',' test]]] 55 | import_stmt: import_name | import_from 56 | import_name: 'import' dotted_as_names 57 | import_from: ('from' ('.'* dotted_name | '.'+) 58 | 'import' ('*' | '(' import_as_names ')' | import_as_names)) 59 | import_as_name: NAME ['as' NAME] 60 | dotted_as_name: dotted_name ['as' NAME] 61 | import_as_names: import_as_name (',' import_as_name)* [','] 62 | dotted_as_names: dotted_as_name (',' dotted_as_name)* 63 | dotted_name: NAME ('.' 
NAME)* 64 | global_stmt: ('global' | 'nonlocal') NAME (',' NAME)* 65 | exec_stmt: 'exec' expr ['in' test [',' test]] 66 | assert_stmt: 'assert' test [',' test] 67 | 68 | compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt 69 | async_stmt: ASYNC (funcdef | with_stmt | for_stmt) 70 | if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite] 71 | while_stmt: 'while' namedexpr_test ':' suite ['else' ':' suite] 72 | for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite] 73 | try_stmt: ('try' ':' suite 74 | ((except_clause ':' suite)+ 75 | ['else' ':' suite] 76 | ['finally' ':' suite] | 77 | 'finally' ':' suite)) 78 | with_stmt: 'with' with_item (',' with_item)* ':' suite 79 | with_item: test ['as' expr] 80 | with_var: 'as' expr 81 | # NB compile.c makes sure that the default except clause is last 82 | except_clause: 'except' [test [(',' | 'as') test]] 83 | suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT 84 | 85 | # Backward compatibility cruft to support: 86 | # [ x for x in lambda: True, lambda: False if x() ] 87 | # even while also allowing: 88 | # lambda x: 5 if x else 2 89 | # (But not a mix of the two) 90 | testlist_safe: old_test [(',' old_test)+ [',']] 91 | old_test: or_test | old_lambdef 92 | old_lambdef: 'lambda' [varargslist] ':' old_test 93 | 94 | namedexpr_test: test [':=' test] 95 | test: or_test ['if' or_test 'else' test] | lambdef 96 | or_test: and_test ('or' and_test)* 97 | and_test: not_test ('and' not_test)* 98 | not_test: 'not' not_test | comparison 99 | comparison: expr (comp_op expr)* 100 | comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not' 101 | star_expr: '*' expr 102 | expr: xor_expr ('|' xor_expr)* 103 | xor_expr: and_expr ('^' and_expr)* 104 | and_expr: shift_expr ('&' shift_expr)* 105 | shift_expr: arith_expr (('<<'|'>>') arith_expr)* 106 | arith_expr: term (('+'|'-') term)* 107 | term: factor (('*'|'@'|'/'|'%'|'//') factor)* 108 | factor: ('+'|'-'|'~') factor | power 109 | power: [AWAIT] atom trailer* ['**' factor] 110 | atom: ('(' [yield_expr|testlist_gexp] ')' | 111 | '[' [listmaker] ']' | 112 | '{' [dictsetmaker] '}' | 113 | '`' testlist1 '`' | 114 | NAME | NUMBER | STRING+ | '.' '.' '.') 115 | listmaker: (namedexpr_test|star_expr) ( comp_for | (',' (namedexpr_test|star_expr))* [','] ) 116 | testlist_gexp: (namedexpr_test|star_expr) ( comp_for | (',' (namedexpr_test|star_expr))* [','] ) 117 | lambdef: 'lambda' [varargslist] ':' test 118 | trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME 119 | subscriptlist: subscript (',' subscript)* [','] 120 | subscript: test | [test] ':' [test] [sliceop] 121 | sliceop: ':' [test] 122 | exprlist: (expr|star_expr) (',' (expr|star_expr))* [','] 123 | testlist: test (',' test)* [','] 124 | dictsetmaker: ( ((test ':' test | '**' expr) 125 | (comp_for | (',' (test ':' test | '**' expr))* [','])) | 126 | ((test | star_expr) 127 | (comp_for | (',' (test | star_expr))* [','])) ) 128 | 129 | classdef: 'class' NAME ['(' [arglist] ')'] ':' suite 130 | 131 | arglist: argument (',' argument)* [','] 132 | 133 | # "test '=' test" is really "keyword '=' test", but we have no such token. 134 | # These need to be in a single rule to avoid grammar that is ambiguous 135 | # to our LL(1) parser. Even though 'test' includes '*expr' in star_expr, 136 | # we explicitly match '*' here, too, to give it proper precedence. 
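# For example, "f(a=1)" parses its argument as test '=' test, and in
# "f(*args)" the leading star is matched by the explicit '*' test branch
# below rather than through star_expr.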
137 | # Illegal combinations and orderings are blocked in ast.c:
138 | # multiple (test comp_for) arguments are blocked; keyword unpackings
139 | # that precede iterable unpackings are blocked; etc.
140 | argument: ( test [comp_for] |
141 | test ':=' test |
142 | test '=' test |
143 | '**' test |
144 | '*' test )
145 | 
146 | comp_iter: comp_for | comp_if
147 | comp_for: [ASYNC] 'for' exprlist 'in' testlist_safe [comp_iter]
148 | comp_if: 'if' old_test [comp_iter]
149 | 
150 | testlist1: test (',' test)*
151 | 
152 | # not used in grammar, but may appear in "node" passed from Parser to Compiler
153 | encoding_decl: NAME
154 | 
155 | yield_expr: 'yield' [yield_arg]
156 | yield_arg: 'from' test | testlist_star_expr
157 | 
--------------------------------------------------------------------------------
/bowler/type_inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """bowler.type_inference
9 | 
10 | Given an expression, find its result type.
11 | 
12 | For sufficiently obvious expressions, we can find this using only local
13 | knowledge (numeric literals, and functions/names which have a standard
14 | meaning). Some obvious examples:
15 | 
16 | `1.0` -> InferredType.FLOAT
17 | `2L` -> InferredType.INT
18 | `1/2` -> Depends on use_py2_division
19 | 
20 | Even in cases where we don't know the full inputs, we can make some reasonable
21 | assumptions. For example when passing `use_py2_division=False,
22 | type_for_unknown=InferredType.INT_OR_FLOAT`:
23 | 
24 | `x+1.0` -> InferredType.FLOAT
25 | `x/y` -> InferredType.FLOAT
26 | `len(x) + 1` -> InferredType.INT
27 | `float(z)` -> InferredType.FLOAT
28 | 
29 | This is intended to be useful for either refactoring or flagging for humans
30 | syntax like `range(float(...))` or `"%d" % (float(...),)`.
31 | """
32 | 
33 | import enum
34 | from typing import Dict, Sequence, Union
35 | 
36 | from fissix import pygram, pytree
37 | from fissix.pgen2 import token
38 | from fissix.pgen2.driver import Driver
39 | 
40 | from .helpers import is_call_to
41 | from .types import LN, SYMBOL, TOKEN
42 | 
43 | __all__ = ["InferredType", "numeric_expr_type"]
44 | 
45 | 
46 | class InferredType(enum.IntEnum):
47 | # The order of these is important: for an expression using py3
48 | # semantics, most operators take max(OP_MIN_TYPE[op], max(children)) as
49 | # the result.
50 | UNSET = 0
51 | BOOL = 1
52 | INT = 2
53 | # This represents UNKNOWN but assumed to be restricted to normal numeric
54 | # values. It can still be promoted to COMPLEX or FLOAT, but if it is the
55 | # final result it should be treated as INT (or better).
56 | INT_OR_FLOAT = 3
57 | FLOAT = 4
58 | COMPLEX = 5
59 | UNKNOWN = 6
60 | 
61 | 
62 | # Note: SLASH and DOUBLESLASH are special-cased.
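# (Special-cased as follows: the SYMBOL.term handler below infers '//' as
# INT, and infers '/' as FLOAT under py3 semantics, falling back to the
# operand types only when use_py2_division is set.)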
63 | OP_MIN_TYPE: Dict = {
64 | TOKEN.PLUS: InferredType.INT,
65 | TOKEN.MINUS: InferredType.INT,
66 | TOKEN.STAR: InferredType.INT,
67 | TOKEN.PERCENT: InferredType.INT,
68 | TOKEN.SLASH: InferredType.INT,
69 | TOKEN.DOUBLESLASH: InferredType.INT,
70 | TOKEN.TILDE: InferredType.INT, # bitwise not
71 | TOKEN.DOUBLESTAR: InferredType.INT,
72 | TOKEN.LEFTSHIFT: InferredType.INT,
73 | TOKEN.RIGHTSHIFT: InferredType.INT,
74 | TOKEN.VBAR: InferredType.BOOL,
75 | TOKEN.CIRCUMFLEX: InferredType.BOOL,
76 | TOKEN.AMPER: InferredType.BOOL,
77 | TOKEN.LESS: InferredType.BOOL,
78 | }
79 | 
80 | 
81 | def numeric_expr_type(
82 | node,
83 | use_py2_division=False,
84 | type_for_unknown = InferredType.UNKNOWN,
85 | ):
86 | """Infer the type of an expression from its literals.
87 | 
88 | We broaden the definition of "literal" a bit to also include calls to
89 | certain functions like int() and float() where the return type does not
90 | change based on the arguments.
91 | 
92 | Args:
93 | node: A Node or leaf.
94 | use_py2_division: Whether to use magical python 2 style division.
95 | type_for_unknown: An InferredType to customize how you want unknowns
96 | handled. Use `INT_OR_FLOAT` if you trust your input to only work
97 | on numbers, but `UNKNOWN` if you want objects to be an option.
98 | 
99 | Returns: InferredType
100 | """
101 | if node.type == TOKEN.NUMBER:
102 | # It's important that we do not use eval here; some literals like `2L`
103 | # may be invalid in the current interpreter.
104 | if "j" in node.value:
105 | return InferredType.COMPLEX
106 | elif "." in node.value or "e" in node.value:
107 | return InferredType.FLOAT
108 | return InferredType.INT
109 | elif node.type == TOKEN.NAME and node.value in ("True", "False"):
110 | return InferredType.BOOL
111 | # TODO let the caller provide other known return types, or even a
112 | # collection of locals and their types.
113 | elif is_call_to(node, "bool"):
114 | return InferredType.BOOL
115 | elif is_call_to(node, "int") or is_call_to(node, "len"):
116 | return InferredType.INT
117 | elif is_call_to(node, "float"):
118 | return InferredType.FLOAT
119 | 
120 | elif node.type in (SYMBOL.comparison, SYMBOL.not_test):
121 | return InferredType.BOOL
122 | elif node.type == SYMBOL.factor:
123 | # unary ~ + -, always [op, number]
124 | return max(
125 | OP_MIN_TYPE[node.children[0].type],
126 | numeric_expr_type(node.children[1], use_py2_division, type_for_unknown),
127 | )
128 | elif node.type == SYMBOL.shift_expr:
129 | # << only valid on int
130 | return InferredType.INT
131 | 
132 | elif node.type == SYMBOL.power:
133 | # a**b, but also f(...)
134 | if node.children[1].type != TOKEN.DOUBLESTAR:
135 | # probably f(...)
136 | return type_for_unknown
137 | 
138 | return max(
139 | max(OP_MIN_TYPE[c.type] for c in node.children[1::2]),
140 | max(
141 | numeric_expr_type(c, use_py2_division, type_for_unknown)
142 | for c in node.children[::2]
143 | ),
144 | )
145 | elif node.type in (
146 | SYMBOL.arith_expr,
147 | SYMBOL.xor_expr,
148 | SYMBOL.and_expr,
149 | SYMBOL.expr,
150 | ):
151 | return max(
152 | max(OP_MIN_TYPE[c.type] for c in node.children[1::2]),
153 | max(
154 | numeric_expr_type(c, use_py2_division, type_for_unknown)
155 | for c in node.children[::2]
156 | ),
157 | )
158 | elif node.type == SYMBOL.term:
159 | # */%
160 | # This is where things get interesting, as we handle use_py2_division.
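# Children alternate operand, operator, operand, ...: walk them in order,
# folding each operand's inferred type into `t`. '//' forces INT; '/' forces
# FLOAT under py3 semantics but keeps INT for INT/INT operands when
# use_py2_division is set; other operators promote via OP_MIN_TYPE.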
161 | t = InferredType.UNSET 162 | last_op = None 163 | for i in range(len(node.children)): 164 | if i % 2 == 0: 165 | new = numeric_expr_type( 166 | node.children[i], use_py2_division, type_for_unknown 167 | ) 168 | if last_op == TOKEN.DOUBLESLASH: 169 | t = InferredType.INT 170 | elif last_op == TOKEN.SLASH: 171 | if use_py2_division: 172 | if t == InferredType.INT and new == InferredType.INT: 173 | t = InferredType.INT 174 | else: 175 | t = max(t, max(OP_MIN_TYPE[last_op], new)) 176 | else: 177 | t = max(t, InferredType.FLOAT) 178 | else: 179 | if last_op: 180 | t = max(t, OP_MIN_TYPE[last_op]) 181 | t = max(t, new) 182 | else: 183 | last_op = node.children[i].type 184 | return t 185 | elif node.type in (SYMBOL.or_test, SYMBOL.and_test): 186 | return max( 187 | numeric_expr_type(c, use_py2_division, type_for_unknown) 188 | for c in node.children[::2] 189 | ) 190 | 191 | return type_for_unknown 192 | -------------------------------------------------------------------------------- /fissix/fixer_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2006 Google, Inc. All Rights Reserved. 2 | # Licensed to PSF under a Contributor Agreement. 3 | 4 | """Base class for fixers (optional, but recommended).""" 5 | 6 | # Python imports 7 | import itertools 8 | 9 | # Local imports 10 | from .patcomp import PatternCompiler 11 | from . import pygram 12 | from .fixer_util import does_tree_import 13 | 14 | 15 | class BaseFix(object): 16 | 17 | """Optional base class for fixers. 18 | 19 | The subclass name must be FixFooBar where FooBar is the result of 20 | removing underscores and capitalizing the words of the fix name. 21 | For example, the class name for a fixer named 'has_key' should be 22 | FixHasKey. 23 | """ 24 | 25 | PATTERN = None # Most subclasses should override with a string literal 26 | pattern = None # Compiled pattern, set by compile_pattern() 27 | pattern_tree = None # Tree representation of the pattern 28 | options = None # Options object passed to initializer 29 | filename = None # The filename (set by set_filename) 30 | numbers = itertools.count(1) # For new_name() 31 | used_names = set() # A set of all used NAMEs 32 | order = "post" # Does the fixer prefer pre- or post-order traversal 33 | explicit = False # Is this ignored by refactor.py -f all? 34 | run_order = 5 # Fixers will be sorted by run order before execution 35 | # Lower numbers will be run first. 36 | _accept_type = None # [Advanced and not public] This tells RefactoringTool 37 | # which node type to accept when there's not a pattern. 38 | 39 | keep_line_order = False # For the bottom matcher: match with the 40 | # original line order 41 | BM_compatible = False # Compatibility with the bottom matching 42 | # module; every fixer should set this 43 | # manually 44 | 45 | # Shortcut for access to Python grammar symbols 46 | syms = pygram.python_symbols 47 | 48 | def __init__(self, options, log): 49 | """Initializer. Subclass may override. 50 | 51 | Args: 52 | options: a dict containing the options passed to RefactoringTool 53 | that could be used to customize the fixer through the command line. 54 | log: a list to append warnings and other messages to. 55 | """ 56 | self.options = options 57 | self.log = log 58 | self.compile_pattern() 59 | 60 | def compile_pattern(self): 61 | """Compiles self.PATTERN into self.pattern. 62 | 63 | Subclass may override if it doesn't want to use 64 | self.{pattern,PATTERN} in .match(). 
65 | """ 66 | if self.PATTERN is not None: 67 | PC = PatternCompiler() 68 | self.pattern, self.pattern_tree = PC.compile_pattern( 69 | self.PATTERN, with_tree=True 70 | ) 71 | 72 | def set_filename(self, filename): 73 | """Set the filename. 74 | 75 | The main refactoring tool should call this. 76 | """ 77 | self.filename = filename 78 | 79 | def match(self, node): 80 | """Returns match for a given parse tree node. 81 | 82 | Should return a true or false object (not necessarily a bool). 83 | It may return a non-empty dict of matching sub-nodes as 84 | returned by a matching pattern. 85 | 86 | Subclass may override. 87 | """ 88 | results = {"node": node} 89 | return self.pattern.match(node, results) and results 90 | 91 | def transform(self, node, results): 92 | """Returns the transformation for a given parse tree node. 93 | 94 | Args: 95 | node: the root of the parse tree that matched the fixer. 96 | results: a dict mapping symbolic names to part of the match. 97 | 98 | Returns: 99 | None, or a node that is a modified copy of the 100 | argument node. The node argument may also be modified in-place to 101 | effect the same change. 102 | 103 | Subclass *must* override. 104 | """ 105 | raise NotImplementedError() 106 | 107 | def new_name(self, template="xxx_todo_changeme"): 108 | """Return a string suitable for use as an identifier 109 | 110 | The new name is guaranteed not to conflict with other identifiers. 111 | """ 112 | name = template 113 | while name in self.used_names: 114 | name = template + str(next(self.numbers)) 115 | self.used_names.add(name) 116 | return name 117 | 118 | def log_message(self, message): 119 | if self.first_log: 120 | self.first_log = False 121 | self.log.append("### In file %s ###" % self.filename) 122 | self.log.append(message) 123 | 124 | def cannot_convert(self, node, reason=None): 125 | """Warn the user that a given chunk of code is not valid Python 3, 126 | but that it cannot be converted automatically. 127 | 128 | First argument is the top-level node for the code in question. 129 | Optional second argument is why it can't be converted. 130 | """ 131 | lineno = node.get_lineno() 132 | for_output = node.clone() 133 | for_output.prefix = "" 134 | msg = "Line %d: could not convert: %s" 135 | self.log_message(msg % (lineno, for_output)) 136 | if reason: 137 | self.log_message(reason) 138 | 139 | def warning(self, node, reason): 140 | """Used for warning the user about possible uncertainty in the 141 | translation. 142 | 143 | First argument is the top-level node for the code in question. 144 | Optional second argument is why it can't be converted. 145 | """ 146 | lineno = node.get_lineno() 147 | self.log_message("Line %d: %s" % (lineno, reason)) 148 | 149 | def start_tree(self, tree, filename): 150 | """Some fixers need to maintain tree-wide state. 151 | This method is called once, at the start of tree fix-up. 152 | 153 | tree - the root node of the tree to be processed. 154 | filename - the name of the file the tree came from. 155 | """ 156 | self.used_names = tree.used_names 157 | self.set_filename(filename) 158 | self.numbers = itertools.count(1) 159 | self.first_log = True 160 | 161 | def finish_tree(self, tree, filename): 162 | """Some fixers need to maintain tree-wide state. 163 | This method is called once, at the conclusion of tree fix-up. 164 | 165 | tree - the root node of the tree to be processed. 166 | filename - the name of the file the tree came from. 
167 | """ 168 | pass 169 | 170 | 171 | class ConditionalFix(BaseFix): 172 | """ Base class for fixers which not execute if an import is found. """ 173 | 174 | # This is the name of the import which, if found, will cause the test to be skipped 175 | skip_on = None 176 | 177 | def start_tree(self, *args): 178 | super(ConditionalFix, self).start_tree(*args) 179 | self._should_skip = None 180 | 181 | def should_skip(self, node): 182 | if self._should_skip is not None: 183 | return self._should_skip 184 | pkg = self.skip_on.split(".") 185 | name = pkg[-1] 186 | pkg = ".".join(pkg[:-1]) 187 | self._should_skip = does_tree_import(pkg, name, node) 188 | return self._should_skip 189 | -------------------------------------------------------------------------------- /fissix/btm_matcher.py: -------------------------------------------------------------------------------- 1 | """A bottom-up tree matching algorithm implementation meant to speed 2 | up 2to3's matching process. After the tree patterns are reduced to 3 | their rarest linear path, a linear Aho-Corasick automaton is 4 | created. The linear automaton traverses the linear paths from the 5 | leaves to the root of the AST and returns a set of nodes for further 6 | matching. This reduces significantly the number of candidate nodes.""" 7 | 8 | __author__ = "George Boutsioukis " 9 | 10 | import logging 11 | import itertools 12 | from collections import defaultdict 13 | 14 | from . import pytree 15 | from .btm_utils import reduce_tree 16 | 17 | 18 | class BMNode(object): 19 | """Class for a node of the Aho-Corasick automaton used in matching""" 20 | 21 | count = itertools.count() 22 | 23 | def __init__(self): 24 | self.transition_table = {} 25 | self.fixers = [] 26 | self.id = next(BMNode.count) 27 | self.content = "" 28 | 29 | 30 | class BottomMatcher(object): 31 | """The main matcher class. After instantiating the patterns should 32 | be added using the add_fixer method""" 33 | 34 | def __init__(self): 35 | self.match = set() 36 | self.root = BMNode() 37 | self.nodes = [self.root] 38 | self.fixers = [] 39 | self.logger = logging.getLogger("RefactoringTool") 40 | 41 | def add_fixer(self, fixer): 42 | """Reduces a fixer's pattern tree to a linear path and adds it 43 | to the matcher(a common Aho-Corasick automaton). 
The fixer is 44 | appended on the matching states and called when they are 45 | reached""" 46 | self.fixers.append(fixer) 47 | tree = reduce_tree(fixer.pattern_tree) 48 | linear = tree.get_linear_subpattern() 49 | match_nodes = self.add(linear, start=self.root) 50 | for match_node in match_nodes: 51 | match_node.fixers.append(fixer) 52 | 53 | def add(self, pattern, start): 54 | "Recursively adds a linear pattern to the AC automaton" 55 | # print("adding pattern", pattern, "to", start) 56 | if not pattern: 57 | # print("empty pattern") 58 | return [start] 59 | if isinstance(pattern[0], tuple): 60 | # alternatives 61 | # print("alternatives") 62 | match_nodes = [] 63 | for alternative in pattern[0]: 64 | # add all alternatives, and add the rest of the pattern 65 | # to each end node 66 | end_nodes = self.add(alternative, start=start) 67 | for end in end_nodes: 68 | match_nodes.extend(self.add(pattern[1:], end)) 69 | return match_nodes 70 | else: 71 | # single token 72 | # not last 73 | if pattern[0] not in start.transition_table: 74 | # transition did not exist, create new 75 | next_node = BMNode() 76 | start.transition_table[pattern[0]] = next_node 77 | else: 78 | # transition exists already, follow 79 | next_node = start.transition_table[pattern[0]] 80 | 81 | if pattern[1:]: 82 | end_nodes = self.add(pattern[1:], start=next_node) 83 | else: 84 | end_nodes = [next_node] 85 | return end_nodes 86 | 87 | def run(self, leaves): 88 | """The main interface with the bottom matcher. The tree is 89 | traversed from the bottom using the constructed 90 | automaton. Nodes are only checked once as the tree is 91 | retraversed. When the automaton fails, we give it one more 92 | shot(in case the above tree matches as a whole with the 93 | rejected leaf), then we break for the next leaf. 
There is the 94 | special case of multiple arguments(see code comments) where we 95 | recheck the nodes 96 | 97 | Args: 98 | The leaves of the AST tree to be matched 99 | 100 | Returns: 101 | A dictionary of node matches with fixers as the keys 102 | """ 103 | current_ac_node = self.root 104 | results = defaultdict(list) 105 | for leaf in leaves: 106 | current_ast_node = leaf 107 | while current_ast_node: 108 | current_ast_node.was_checked = True 109 | for child in current_ast_node.children: 110 | # multiple statements, recheck 111 | if isinstance(child, pytree.Leaf) and child.value == ";": 112 | current_ast_node.was_checked = False 113 | break 114 | if current_ast_node.type == 1: 115 | # name 116 | node_token = current_ast_node.value 117 | else: 118 | node_token = current_ast_node.type 119 | 120 | if node_token in current_ac_node.transition_table: 121 | # token matches 122 | current_ac_node = current_ac_node.transition_table[node_token] 123 | for fixer in current_ac_node.fixers: 124 | results[fixer].append(current_ast_node) 125 | else: 126 | # matching failed, reset automaton 127 | current_ac_node = self.root 128 | if ( 129 | current_ast_node.parent is not None 130 | and current_ast_node.parent.was_checked 131 | ): 132 | # the rest of the tree upwards has been checked, next leaf 133 | break 134 | 135 | # recheck the rejected node once from the root 136 | if node_token in current_ac_node.transition_table: 137 | # token matches 138 | current_ac_node = current_ac_node.transition_table[node_token] 139 | for fixer in current_ac_node.fixers: 140 | results[fixer].append(current_ast_node) 141 | 142 | current_ast_node = current_ast_node.parent 143 | return results 144 | 145 | def print_ac(self): 146 | "Prints a graphviz diagram of the BM automaton(for debugging)" 147 | print("digraph g{") 148 | 149 | def print_node(node): 150 | for subnode_key in node.transition_table.keys(): 151 | subnode = node.transition_table[subnode_key] 152 | print( 153 | "%d -> %d [label=%s] //%s" 154 | % (node.id, subnode.id, type_repr(subnode_key), str(subnode.fixers)) 155 | ) 156 | if subnode_key == 1: 157 | print(subnode.content) 158 | print_node(subnode) 159 | 160 | print_node(self.root) 161 | print("}") 162 | 163 | 164 | # taken from pytree.py for debugging; only used by print_ac 165 | _type_reprs = {} 166 | 167 | 168 | def type_repr(type_num): 169 | global _type_reprs 170 | if not _type_reprs: 171 | from .pygram import python_symbols 172 | 173 | # printing tokens is possible but not as useful 174 | # from .pgen2 import token // token.__dict__.items(): 175 | for name, val in python_symbols.__dict__.items(): 176 | if type(val) == int: 177 | _type_reprs[val] = name 178 | return _type_reprs.setdefault(type_num, type_num) 179 | -------------------------------------------------------------------------------- /fissix/patcomp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2006 Google, Inc. All Rights Reserved. 2 | # Licensed to PSF under a Contributor Agreement. 3 | 4 | """Pattern compiler. 5 | 6 | The grammar is taken from PatternGrammar.txt. 7 | 8 | The compiler compiles a pattern to a pytree.*Pattern instance. 9 | """ 10 | 11 | __author__ = "Guido van Rossum " 12 | 13 | # Python imports 14 | import io 15 | 16 | # Fairly local imports 17 | from .pgen2 import driver, literals, token, tokenize, parse, grammar 18 | 19 | # Really local imports 20 | from . import pytree 21 | from . 
import pygram 22 | 23 | 24 | class PatternSyntaxError(Exception): 25 | pass 26 | 27 | 28 | def tokenize_wrapper(input): 29 | """Tokenizes a string suppressing significant whitespace.""" 30 | skip = {token.NEWLINE, token.INDENT, token.DEDENT} 31 | tokens = tokenize.generate_tokens(io.StringIO(input).readline) 32 | for quintuple in tokens: 33 | type, value, start, end, line_text = quintuple 34 | if type not in skip: 35 | yield quintuple 36 | 37 | 38 | class PatternCompiler(object): 39 | def __init__(self, grammar_file=None): 40 | """Initializer. 41 | 42 | Takes an optional alternative filename for the pattern grammar. 43 | """ 44 | if grammar_file is None: 45 | self.grammar = pygram.pattern_grammar 46 | self.syms = pygram.pattern_symbols 47 | else: 48 | self.grammar = driver.load_grammar(grammar_file, save=False, force=True) 49 | self.syms = pygram.Symbols(self.grammar) 50 | self.pygrammar = pygram.python_grammar 51 | self.pysyms = pygram.python_symbols 52 | self.driver = driver.Driver(self.grammar, convert=pattern_convert) 53 | 54 | def compile_pattern(self, input, debug=False, with_tree=False): 55 | """Compiles a pattern string to a nested pytree.*Pattern object.""" 56 | tokens = tokenize_wrapper(input) 57 | try: 58 | root = self.driver.parse_tokens(tokens, debug=debug) 59 | except parse.ParseError as e: 60 | raise PatternSyntaxError(str(e)) from None 61 | if with_tree: 62 | return self.compile_node(root), root 63 | else: 64 | return self.compile_node(root) 65 | 66 | def compile_node(self, node): 67 | """Compiles a node, recursively. 68 | 69 | This is one big switch on the node type. 70 | """ 71 | # XXX Optimize certain Wildcard-containing-Wildcard patterns 72 | # that can be merged 73 | if node.type == self.syms.Matcher: 74 | node = node.children[0] # Avoid unneeded recursion 75 | 76 | if node.type == self.syms.Alternatives: 77 | # Skip the odd children since they are just '|' tokens 78 | alts = [self.compile_node(ch) for ch in node.children[::2]] 79 | if len(alts) == 1: 80 | return alts[0] 81 | p = pytree.WildcardPattern([[a] for a in alts], min=1, max=1) 82 | return p.optimize() 83 | 84 | if node.type == self.syms.Alternative: 85 | units = [self.compile_node(ch) for ch in node.children] 86 | if len(units) == 1: 87 | return units[0] 88 | p = pytree.WildcardPattern([units], min=1, max=1) 89 | return p.optimize() 90 | 91 | if node.type == self.syms.NegatedUnit: 92 | pattern = self.compile_basic(node.children[1:]) 93 | p = pytree.NegatedPattern(pattern) 94 | return p.optimize() 95 | 96 | assert node.type == self.syms.Unit 97 | 98 | name = None 99 | nodes = node.children 100 | if len(nodes) >= 3 and nodes[1].type == token.EQUAL: 101 | name = nodes[0].value 102 | nodes = nodes[2:] 103 | repeat = None 104 | if len(nodes) >= 2 and nodes[-1].type == self.syms.Repeater: 105 | repeat = nodes[-1] 106 | nodes = nodes[:-1] 107 | 108 | # Now we've reduced it to: STRING | NAME [Details] | (...) | [...] 
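# e.g. for the pattern unit "name='foo'*", at this point name == "name",
# nodes holds the 'foo' STRING leaf, and repeat is the trailing '*' Repeater.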
109 | pattern = self.compile_basic(nodes, repeat) 110 | 111 | if repeat is not None: 112 | assert repeat.type == self.syms.Repeater 113 | children = repeat.children 114 | child = children[0] 115 | if child.type == token.STAR: 116 | min = 0 117 | max = pytree.HUGE 118 | elif child.type == token.PLUS: 119 | min = 1 120 | max = pytree.HUGE 121 | elif child.type == token.LBRACE: 122 | assert children[-1].type == token.RBRACE 123 | assert len(children) in (3, 5) 124 | min = max = self.get_int(children[1]) 125 | if len(children) == 5: 126 | max = self.get_int(children[3]) 127 | else: 128 | assert False 129 | if min != 1 or max != 1: 130 | pattern = pattern.optimize() 131 | pattern = pytree.WildcardPattern([[pattern]], min=min, max=max) 132 | 133 | if name is not None: 134 | pattern.name = name 135 | return pattern.optimize() 136 | 137 | def compile_basic(self, nodes, repeat=None): 138 | # Compile STRING | NAME [Details] | (...) | [...] 139 | assert len(nodes) >= 1 140 | node = nodes[0] 141 | if node.type == token.STRING: 142 | value = str(literals.evalString(node.value)) 143 | return pytree.LeafPattern(_type_of_literal(value), value) 144 | elif node.type == token.NAME: 145 | value = node.value 146 | if value.isupper(): 147 | if value not in TOKEN_MAP: 148 | raise PatternSyntaxError("Invalid token: %r" % value) 149 | if nodes[1:]: 150 | raise PatternSyntaxError("Can't have details for token") 151 | return pytree.LeafPattern(TOKEN_MAP[value]) 152 | else: 153 | if value == "any": 154 | type = None 155 | elif not value.startswith("_"): 156 | type = getattr(self.pysyms, value, None) 157 | if type is None: 158 | raise PatternSyntaxError("Invalid symbol: %r" % value) 159 | if nodes[1:]: # Details present 160 | content = [self.compile_node(nodes[1].children[1])] 161 | else: 162 | content = None 163 | return pytree.NodePattern(type, content) 164 | elif node.value == "(": 165 | return self.compile_node(nodes[1]) 166 | elif node.value == "[": 167 | assert repeat is None 168 | subpattern = self.compile_node(nodes[1]) 169 | return pytree.WildcardPattern([[subpattern]], min=0, max=1) 170 | assert False, node 171 | 172 | def get_int(self, node): 173 | assert node.type == token.NUMBER 174 | return int(node.value) 175 | 176 | 177 | # Map named tokens to the type value for a LeafPattern 178 | TOKEN_MAP = { 179 | "NAME": token.NAME, 180 | "STRING": token.STRING, 181 | "NUMBER": token.NUMBER, 182 | "TOKEN": None, 183 | } 184 | 185 | 186 | def _type_of_literal(value): 187 | if value[0].isalpha(): 188 | return token.NAME 189 | elif value in grammar.opmap: 190 | return grammar.opmap[value] 191 | else: 192 | return None 193 | 194 | 195 | def pattern_convert(grammar, raw_node_info): 196 | """Converts raw node information to a Node or Leaf instance.""" 197 | type, value, context, children = raw_node_info 198 | if children or type in grammar.number2symbol: 199 | return pytree.Node(type, children, context=context) 200 | else: 201 | return pytree.Leaf(type, value, context=context) 202 | 203 | 204 | def compile_pattern(pattern): 205 | return PatternCompiler().compile_pattern(pattern) 206 | -------------------------------------------------------------------------------- /bowler/imr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
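"""Intermediate representation (IMR) for function definitions and calls.

FunctionArgument and FunctionSpec lift captured CST nodes into plain
Python objects that can be inspected and edited, then "exploded" back
into fissix Leaf/Node trees.
"""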
7 | 8 | import logging 9 | from typing import Any, List, Optional 10 | 11 | from fissix.fixer_util import LParen, Name 12 | 13 | from .helpers import find_last 14 | from .types import ( 15 | ARG_END, 16 | ARG_LISTS, 17 | LN, 18 | STARS, 19 | SYMBOL, 20 | TOKEN, 21 | Capture, 22 | IMRError, 23 | Leaf, 24 | Node, 25 | ) 26 | 27 | log = logging.getLogger(__name__) 28 | 29 | 30 | class FunctionArgument: 31 | def __init__(self, name="", value=None, annotation="", star=None, prefix=None): 32 | self.name = name 33 | self.value = value 34 | self.annotation = annotation 35 | self.star = star 36 | self.prefix = prefix 37 | 38 | @classmethod 39 | def build(cls, leaf, is_def, **kwargs): 40 | while leaf is not None and leaf.type not in ARG_END: 41 | if leaf.type in (SYMBOL.star_expr, SYMBOL.argument): 42 | return cls.build(leaf.children[0], is_def, prefix=leaf.prefix) 43 | 44 | elif leaf.type in STARS: 45 | kwargs["star"] = leaf.clone() 46 | 47 | elif leaf.type == SYMBOL.tname: 48 | kwargs["name"] = leaf.children[0].value 49 | kwargs["annotation"] = leaf.children[-1].value 50 | 51 | elif leaf.type == TOKEN.EQUAL: 52 | pass 53 | 54 | elif leaf.type == TOKEN.NAME: 55 | if (is_def and "name" not in kwargs) or ( 56 | leaf.next_sibling and leaf.next_sibling.type == TOKEN.EQUAL 57 | ): 58 | kwargs["name"] = leaf.value 59 | else: 60 | kwargs["value"] = leaf.clone() 61 | 62 | else: 63 | # assume everything else is a complex value 64 | kwargs["value"] = leaf.clone() 65 | 66 | kwargs.setdefault("prefix", leaf.prefix) 67 | leaf = leaf.next_sibling 68 | 69 | return FunctionArgument(**kwargs) 70 | 71 | @classmethod 72 | def build_list( 73 | cls, arguments, is_def 74 | ): 75 | result = [] 76 | 77 | # empty function 78 | if not arguments: 79 | return result 80 | 81 | # only care about what's on the inside 82 | if arguments[0].type in ARG_LISTS: 83 | leaf = arguments[0].children[0] 84 | else: 85 | leaf = arguments[0] 86 | 87 | while leaf is not None: 88 | arg = cls.build(leaf, is_def) 89 | log.debug("{} -> {}".format(leaf, arg)) 90 | result.append(arg) 91 | 92 | # consume leafs for this argument 93 | while leaf is not None and leaf.type not in ARG_END: 94 | log.debug("consuming {}".format(leaf)) 95 | leaf = leaf.next_sibling 96 | 97 | # assume we stopped on a comma or parenthesis 98 | if leaf: 99 | log.debug("separator {}".format(leaf)) 100 | leaf = leaf.next_sibling 101 | 102 | return result 103 | 104 | def explode(self, is_def, prefix = ""): 105 | result = [] 106 | prefix = self.prefix if self.prefix else prefix 107 | if is_def: 108 | if self.star: 109 | self.star.prefix = prefix 110 | result.append(self.star) 111 | prefix = "" 112 | 113 | if self.annotation: 114 | result.append( 115 | Node( 116 | SYMBOL.tname, 117 | [ 118 | Name(self.name, prefix=prefix), 119 | Leaf(TOKEN.COLON, ":", prefix=""), 120 | Name(self.annotation, prefix=" "), 121 | ], 122 | prefix=prefix, 123 | ) 124 | ) 125 | else: 126 | result.append(Name(self.name, prefix=prefix)) 127 | 128 | if self.value: 129 | space = " " if self.annotation else "" 130 | result.append(Leaf(TOKEN.EQUAL, "=", prefix=space)) 131 | result.append(self.value) 132 | 133 | else: 134 | if self.star: 135 | if self.star.type == TOKEN.STAR: 136 | node = Node(SYMBOL.star_expr, [self.star], prefix=prefix) 137 | elif self.star.type == TOKEN.DOUBLESTAR: 138 | node = Node(SYMBOL.argument, [self.star], prefix=prefix) 139 | 140 | if self.value: 141 | self.value.prefix = "" 142 | node.append_child(self.value) 143 | 144 | result.append(node) 145 | return result 146 | 147 | if self.name: 
148 | self.value.prefix = ""
149 | result.append(
150 | Node(
151 | SYMBOL.argument,
152 | [
153 | Name(self.name, prefix=prefix),
154 | Leaf(TOKEN.EQUAL, value="=", prefix=""),
155 | self.value,
156 | ],
157 | prefix=prefix,
158 | )
159 | )
160 | else:
161 | self.value.prefix = prefix
162 | result.append(self.value)
163 | 
164 | return result
165 | 
166 | @classmethod
167 | def explode_list(
168 | cls, arguments, is_def
169 | ):
170 | nodes = []
171 | prefix = ""
172 | index = 0
173 | for argument in arguments:
174 | if index:
175 | nodes.append(Leaf(TOKEN.COMMA, ",", prefix=""))
176 | prefix = " "
177 | 
178 | result = argument.explode(is_def, prefix=prefix)
179 | log.debug("{} -> {}".format(argument, result))
180 | nodes.extend(result)
181 | index += 1
182 | 
183 | if not nodes:
184 | return None
185 | 
186 | if len(nodes) == 1:
187 | return nodes[0]
188 | 
189 | elif is_def:
190 | return Node(SYMBOL.typedargslist, nodes, prefix=nodes[0].prefix)
191 | 
192 | else:
193 | return Node(SYMBOL.arglist, nodes, prefix=nodes[0].prefix)
194 | 
195 | 
196 | class FunctionSpec:
197 | def __init__(self, name, arguments, is_def, capture, node):
198 | self.name = name
199 | self.arguments = arguments
200 | self.is_def = is_def
201 | self.capture = capture
202 | self.node = node
203 | 
204 | @classmethod
205 | def build(cls, node, capture):
206 | try:
207 | name = capture["function_name"]
208 | is_def = "function_def" in capture
209 | args = capture["function_arguments"]
210 | except KeyError as e:
211 | raise IMRError("function spec invalid") from e
212 | 
213 | arguments = FunctionArgument.build_list(args, is_def)
214 | 
215 | return FunctionSpec(name.value, arguments, is_def, capture, node)
216 | 
217 | def explode(self):
218 | arguments = FunctionArgument.explode_list(self.arguments, self.is_def)
219 | 
220 | rparen = find_last(self.capture["function_parameters"], TOKEN.RPAR)
221 | rprefix = rparen.prefix if rparen else ""
222 | 
223 | if self.is_def:
224 | parameters = Node(
225 | SYMBOL.parameters,
226 | [LParen(), Leaf(TOKEN.RPAR, ")", prefix=rprefix)],
227 | prefix="",
228 | )
229 | else:
230 | parameters = Node(
231 | SYMBOL.trailer,
232 | [LParen(), Leaf(TOKEN.RPAR, ")", prefix=rprefix)],
233 | prefix="",
234 | )
235 | 
236 | if arguments:
237 | parameters.insert_child(1, arguments)
238 | 
239 | self.capture["function_parameters"].replace(parameters)
240 | 
241 | return self.node
242 | 
--------------------------------------------------------------------------------
/bowler/README.md:
--------------------------------------------------------------------------------
1 | Bowler
2 | ======
3 | 
4 | 
5 | Overview
6 | --------
7 | 
8 | Bowler is a refactoring tool for manipulating Python at the syntax tree level. It
9 | enables safe, large scale code modification while guaranteeing that the resulting code
10 | compiles and runs. It provides both a simple command line interface and a fluent API in
11 | Python for generating complex code modifications programmatically.
12 | 
13 | query = (
14 | Query([])
15 | # rename class Foo to Bar
16 | .select_class("Foo")
17 | .rename("Bar")
18 | # change method buzz(x) to buzzard(x: int)
19 | .select_method("buzz")
20 | .rename("buzzard")
21 | .modify_argument("x", type_annotation="int")
22 | )
23 | 
24 | query.diff() # generate unified diff on stdout
25 | query.write() # write changes directly to files
26 | 
27 | Bowler uses the concrete syntax tree (CST) as generated by the [lib2to3][] module.
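Custom logic plugs into the same fluent interface. The sketch below is
illustrative only: `old_api` and the `src/` path are placeholder names, and the
`function_arguments` capture name is an assumption based on the captures used
by `bowler/imr.py`:

    from bowler import Query
    from bowler.types import LN, Capture, Filename

    def has_arguments(node: LN, capture: Capture, filename: Filename) -> bool:
        # Keep only calls that actually pass arguments.
        return bool(capture.get("function_arguments"))

    def log_call(node: LN, capture: Capture, filename: Filename) -> None:
        # Modifiers may inspect or mutate the matched CST node in place.
        print(filename, str(node).strip())

    (
        Query(["src/"])
        .select_function("old_api")
        .is_call()
        .add_filter(has_arguments)
        .add_modifier(log_call)
        .diff()
    )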
28 | 
29 | 
30 | CLI Reference
31 | -------------
32 | 
33 | Using Bowler at the command line follows the pattern below:
34 | 
35 | $ bowler [--debug] [--help] <command> [<argument> ...]
36 | 
37 | Bowler supports the following commands:
38 | 
39 | do [<query> ...]
40 | Compile and run the given query, or open an IPython shell if none given.
41 | Common API elements will already be available in the global namespace.
42 | 
43 | dump [<path> ...]
44 | Dump the CST from the given paths to stdout.
45 | 
46 | rename_function [-i | --interactive] <name> [<path> ...]
47 | Rename a function and its calls.
48 | 
49 | 
50 | Query Reference
51 | ---------------
52 | 
53 | Queries use a fluent API to build a series of transforms over a given set of paths.
54 | Each transform consists of a selector, any number of filters, and one or more
55 | modifiers. Queries will only be compiled and executed once an appropriate action
56 | is triggered – like `diff()` or `write()`.
57 | 
58 | Constructing queries should follow this basic pattern:
59 | 
60 | 1. Create the query object, and specify all paths that should be considered
61 | 2. Specify a selector to define broad search criteria
62 | 3. Optionally specify one or more filters to refine the scope of modification
63 | 4. Specify one or more modifiers
64 | 5. Repeat from step 2 to include more transforms in the query
65 | 6. Execute the query with a terminal action, such as `diff()` or `write()`.
66 | 
67 | Queries are started by creating a `Query` instance, and passing a list of paths that
68 | should be considered for modification:
69 | 
70 | query = Query([path, ...])
71 | 
72 | All methods on a `Query` object will return the same `Query` object back, enabling
73 | "fluent" usage of the API – chaining one method call after the other:
74 | 
75 | query = Query(...).selector(...)[.filter(...)].modifier(...)
76 | 
77 | 
78 | ### Selectors
79 | 
80 | Selectors are query methods that generate a search pattern for the custom [lib2to3][]
81 | syntax. There are a number of prewritten selectors, but Bowler supports arbitrary
82 | selectors as well.
83 | 
84 | Bowler supports the following methods for choosing selectors:
85 | 
86 | .select_root()
87 | Selects the root of the syntax tree for each file.
88 | 
89 | .select_module(name)
90 | Selects all module imports and references with the given name.
91 | 
92 | .select_class(name)
93 | Selects all class definitions for – or subclasses of – the given name, as well
94 | as any calls or references to that name.
95 | 
96 | .select_subclass(name)
97 | Selects all class definitions that subclass the given name, as well as any calls
98 | or references to that name.
99 | 
100 | .select_attribute(name)
101 | Selects all class or object attributes, including assignments and references.
102 | 
103 | .select_method(name)
104 | Selects all class method definitions with the given name, as well as any method
105 | calls or references with that name.
106 | 
107 | .select_function(name)
108 | Selects all bare function definitions with the given name, as well as any calls
109 | or references with that name.
110 | 
111 | .select_var(name)
112 | Select all references to that name, regardless of context.
113 | 
114 | .select_pattern(pattern)
115 | Select nodes based on the arbitrary [lib2to3][] pattern given.
116 | 
117 | 
118 | ### Filters
119 | 
120 | Filters are functions that limit the scope of modifiers. They are functions with the
121 | signature of `filter(node, capture, filename) -> bool`, and return `True` if the current
122 | node should be eligible for modification, or `False` to skip the node.
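A conforming filter can be as small as this sketch (the name is arbitrary, and
the built-in `is_filename()` below covers this case more robustly):

    def tests_only(node, capture, filename) -> bool:
        # Only allow modifications within test modules.
        return "test" in filename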
123 | 
124 | - `node` refers to the base CST node matched by the active selector
125 | - `capture` is a dictionary, mapping named captures to their associated CST leaf or node
126 | - `filename` is the current filename being modified
127 | 
128 | Bowler supports the following methods for adding filters:
129 | 
130 | .is_call()
131 | Filters all nodes that aren't function or method calls.
132 | 
133 | .is_def()
134 | Filters all nodes that aren't function or method definitions.
135 | 
136 | .in_class(name, [include_subclasses = True])
137 | Filters all nodes that aren't part of either the given class definition or
138 | a subclass of the given class.
139 | 
140 | .is_filename([include = <regex>], [exclude = <regex>])
141 | Filters all nodes belonging to files that don't match the given include/exclude
142 | regular expressions.
143 | 
144 | .add_filter(function | str)
145 | Use an arbitrary function to filter nodes. If given a string, compile that
146 | and `eval()` it at each node to determine if the filter passed.
147 | 
148 | 
149 | ### Modifiers
150 | 
151 | Modifiers take each matched node – that passed all active filters – and optionally
152 | apply some number of modifications to the CST of that node. They are functions with
153 | the signature of `modifier(node, capture, filename)`, with no expected return value.
154 | 
155 | - `node` refers to the base CST node matched by the active selector
156 | - `capture` is a dictionary, mapping named captures to their associated CST leaf or node
157 | - `filename` is the current filename being modified
158 | 
159 | Bowler supports the following methods for adding modifiers:
160 | 
161 | .rename(new_name)
162 | Rename all `*_name` captures to the given new name.
163 | 
164 | .encapsulate([internal_name])
165 | Encapsulate a class attribute into an `@property` decorated getter and setter.
166 | Requires the `select_attribute()` selector.
167 | 
168 | .add_argument(name, value, [positional], [after], [type_annotation])
169 | Add a new argument to a method or function, with a default value and optional
170 | position or type annotation. Also updates all callers with the new argument.
171 | 
172 | .modify_argument(name, [new_name], [type_annotation], [default_value])
173 | Modify an existing argument to a method or function, optionally renaming it,
174 | adding/changing the type annotation, or adding/changing the default value.
175 | Also updates all callers with new names.
176 | 
177 | .remove_argument(name)
178 | Remove an existing argument from a method or function, as well as from callers.
179 | 
180 | .add_modifier(function | str)
181 | Add an arbitrary modifier function. If given a string, compile that and
182 | `exec()` it at each matched node to perform modifications.
183 | 
184 | 
185 | ### Actions
186 | 
187 | After building one or more transforms, those transforms are applied through the use of
188 | terminal actions. These include generating diffs, writing modifications to disk, and
189 | dumping matched nodes to stdout.
190 | 
191 | Bowler supports the following terminal actions:
192 | 
193 | .diff([interactive=False])
194 | Generate a unified diff and echo it to stdout. Colors will be used when stdout
195 | is a terminal. Interactive mode will prompt the user after each hunk, with
196 | actions similar to those found in `git add -p`.
197 | 
198 | .idiff()
199 | Shortcut for `.diff(interactive=True)`.
200 | 
201 | .write()
202 | Write all changes back to disk, overwriting existing files.
203 | 204 | .dump() 205 | For each node matched, for each transform, print the CST representation of the 206 | node along with all captured nodes. 207 | 208 | .execute([write=False], [interactive=False]) 209 | Longer form of `.diff()` or `.write()`. 210 | 211 | 212 | [lib2to3]: https://docs.python.org/3.6/library/2to3.html#module-lib2to3 213 | -------------------------------------------------------------------------------- /fissix/pgen2/parse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. 2 | # Licensed to PSF under a Contributor Agreement. 3 | 4 | """Parser engine for the grammar tables generated by pgen. 5 | 6 | The grammar table must be loaded first. 7 | 8 | See Parser/parser.c in the Python distribution for additional info on 9 | how this parsing engine works. 10 | 11 | """ 12 | 13 | # Local imports 14 | from . import token 15 | 16 | 17 | class ParseError(Exception): 18 | """Exception to signal the parser is stuck.""" 19 | 20 | def __init__(self, msg, type, value, context): 21 | Exception.__init__( 22 | self, "%s: type=%r, value=%r, context=%r" % (msg, type, value, context) 23 | ) 24 | self.msg = msg 25 | self.type = type 26 | self.value = value 27 | self.context = context 28 | 29 | def __reduce__(self): 30 | return type(self), (self.msg, self.type, self.value, self.context) 31 | 32 | 33 | class Parser(object): 34 | """Parser engine. 35 | 36 | The proper usage sequence is: 37 | 38 | p = Parser(grammar, [converter]) # create instance 39 | p.setup([start]) # prepare for parsing 40 | : 41 | if p.addtoken(...): # parse a token; may raise ParseError 42 | break 43 | root = p.rootnode # root of abstract syntax tree 44 | 45 | A Parser instance may be reused by calling setup() repeatedly. 46 | 47 | A Parser instance contains state pertaining to the current token 48 | sequence, and should not be used concurrently by different threads 49 | to parse separate token sequences. 50 | 51 | See driver.py for how to get input tokens by tokenizing a file or 52 | string. 53 | 54 | Parsing is complete when addtoken() returns True; the root of the 55 | abstract syntax tree can then be retrieved from the rootnode 56 | instance variable. When a syntax error occurs, addtoken() raises 57 | the ParseError exception. There is no error recovery; the parser 58 | cannot be used after a syntax error was reported (but it can be 59 | reinitialized by calling setup()). 60 | 61 | """ 62 | 63 | def __init__(self, grammar, convert=None): 64 | """Constructor. 65 | 66 | The grammar argument is a grammar.Grammar instance; see the 67 | grammar module for more information. 68 | 69 | The parser is not ready yet for parsing; you must call the 70 | setup() method to get it started. 71 | 72 | The optional convert argument is a function mapping concrete 73 | syntax tree nodes to abstract syntax tree nodes. If not 74 | given, no conversion is done and the syntax tree produced is 75 | the concrete syntax tree. If given, it must be a function of 76 | two arguments, the first being the grammar (a grammar.Grammar 77 | instance), and the second being the concrete syntax tree node 78 | to be converted. The syntax tree is converted from the bottom 79 | up. 
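        As an illustrative sketch only (not part of the original
        module), a converter that collapses any symbol node with
        exactly one child into that child could be written as:

            def convert(grammar, raw_node):
                # raw_node is the (type, value, context, children)
                # tuple described below; children is None for tokens.
                type, value, context, children = raw_node
                if children and len(children) == 1:
                    return children[0]
                return raw_node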
80 | 81 | A concrete syntax tree node is a (type, value, context, nodes) 82 | tuple, where type is the node type (a token or symbol number), 83 | value is None for symbols and a string for tokens, context is 84 | None or an opaque value used for error reporting (typically a 85 | (lineno, offset) pair), and nodes is a list of children for 86 | symbols, and None for tokens. 87 | 88 | An abstract syntax tree node may be anything; this is entirely 89 | up to the converter function. 90 | 91 | """ 92 | self.grammar = grammar 93 | self.convert = convert or (lambda grammar, node: node) 94 | 95 | def setup(self, start=None): 96 | """Prepare for parsing. 97 | 98 | This *must* be called before starting to parse. 99 | 100 | The optional argument is an alternative start symbol; it 101 | defaults to the grammar's start symbol. 102 | 103 | You can use a Parser instance to parse any number of programs; 104 | each time you call setup() the parser is reset to an initial 105 | state determined by the (implicit or explicit) start symbol. 106 | 107 | """ 108 | if start is None: 109 | start = self.grammar.start 110 | # Each stack entry is a tuple: (dfa, state, node). 111 | # A node is a tuple: (type, value, context, children), 112 | # where children is a list of nodes or None, and context may be None. 113 | newnode = (start, None, None, []) 114 | stackentry = (self.grammar.dfas[start], 0, newnode) 115 | self.stack = [stackentry] 116 | self.rootnode = None 117 | self.used_names = set() # Aliased to self.rootnode.used_names in pop() 118 | 119 | def addtoken(self, type, value, context): 120 | """Add a token; return True iff this is the end of the program.""" 121 | # Map from token to label 122 | ilabel = self.classify(type, value, context) 123 | # Loop until the token is shifted; may raise exceptions 124 | while True: 125 | dfa, state, node = self.stack[-1] 126 | states, first = dfa 127 | arcs = states[state] 128 | # Look for a state with this label 129 | for i, newstate in arcs: 130 | t, v = self.grammar.labels[i] 131 | if ilabel == i: 132 | # Look it up in the list of labels 133 | assert t < 256 134 | # Shift a token; we're done with it 135 | self.shift(type, value, newstate, context) 136 | # Pop while we are in an accept-only state 137 | state = newstate 138 | while states[state] == [(0, state)]: 139 | self.pop() 140 | if not self.stack: 141 | # Done parsing! 142 | return True 143 | dfa, state, node = self.stack[-1] 144 | states, first = dfa 145 | # Done with this token 146 | return False 147 | elif t >= 256: 148 | # See if it's a symbol and if we're in its first set 149 | itsdfa = self.grammar.dfas[t] 150 | itsstates, itsfirst = itsdfa 151 | if ilabel in itsfirst: 152 | # Push a symbol 153 | self.push(t, self.grammar.dfas[t], newstate, context) 154 | break # To continue the outer while loop 155 | else: 156 | if (0, state) in arcs: 157 | # An accepting state, pop it and try something else 158 | self.pop() 159 | if not self.stack: 160 | # Done parsing, but another token is input 161 | raise ParseError("too much input", type, value, context) 162 | else: 163 | # No success finding a transition 164 | raise ParseError("bad input", type, value, context) 165 | 166 | def classify(self, type, value, context): 167 | """Turn a token into a label. 
(Internal)""" 168 | if type == token.NAME: 169 | # Keep a listing of all used names 170 | self.used_names.add(value) 171 | # Check for reserved words 172 | ilabel = self.grammar.keywords.get(value) 173 | if ilabel is not None: 174 | return ilabel 175 | ilabel = self.grammar.tokens.get(type) 176 | if ilabel is None: 177 | raise ParseError("bad token", type, value, context) 178 | return ilabel 179 | 180 | def shift(self, type, value, newstate, context): 181 | """Shift a token. (Internal)""" 182 | dfa, state, node = self.stack[-1] 183 | newnode = (type, value, context, None) 184 | newnode = self.convert(self.grammar, newnode) 185 | if newnode is not None: 186 | node[-1].append(newnode) 187 | self.stack[-1] = (dfa, newstate, node) 188 | 189 | def push(self, type, newdfa, newstate, context): 190 | """Push a nonterminal. (Internal)""" 191 | dfa, state, node = self.stack[-1] 192 | newnode = (type, None, context, []) 193 | self.stack[-1] = (dfa, newstate, node) 194 | self.stack.append((newdfa, 0, newnode)) 195 | 196 | def pop(self): 197 | """Pop a nonterminal. (Internal)""" 198 | popdfa, popstate, popnode = self.stack.pop() 199 | newnode = self.convert(self.grammar, popnode) 200 | if newnode is not None: 201 | if self.stack: 202 | dfa, state, node = self.stack[-1] 203 | node[-1].append(newnode) 204 | else: 205 | self.rootnode = newnode 206 | self.rootnode.used_names = self.used_names 207 | -------------------------------------------------------------------------------- /fissix/pgen2/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. 2 | # Licensed to PSF under a Contributor Agreement. 3 | 4 | """Convert graminit.[ch] spit out by pgen to Python code. 5 | 6 | Pgen is the Python parser generator. It is useful to quickly create a 7 | parser from a grammar file in Python's grammar notation. But I don't 8 | want my parsers to be written in C (yet), so I'm translating the 9 | parsing tables to Python data structures and writing a Python parse 10 | engine. 11 | 12 | Note that the token numbers are constants determined by the standard 13 | Python tokenizer. The standard token module defines these numbers and 14 | their names (the names are not used much). The token numbers are 15 | hardcoded into the Python tokenizer and into pgen. A Python 16 | implementation of the Python tokenizer is also available, in the 17 | standard tokenize module. 18 | 19 | On the other hand, symbol numbers (representing the grammar's 20 | non-terminals) are assigned by pgen based on the actual grammar 21 | input. 22 | 23 | Note: this module is pretty much obsolete; the pgen module generates 24 | equivalent grammar tables directly from the Grammar.txt input file 25 | without having to invoke the Python pgen C program. 26 | 27 | """ 28 | 29 | # Python imports 30 | import re 31 | 32 | # Local imports 33 | from pgen2 import grammar, token 34 | 35 | 36 | class Converter(grammar.Grammar): 37 | """Grammar subclass that reads classic pgen output files. 38 | 39 | The run() method reads the tables as produced by the pgen parser 40 | generator, typically contained in two C files, graminit.h and 41 | graminit.c. The other methods are for internal use only. 42 | 43 | See the base class for more documentation. 
44 | 
45 |     """
46 | 
47 |     def run(self, graminit_h, graminit_c):
48 |         """Load the grammar tables from the text files written by pgen."""
49 |         self.parse_graminit_h(graminit_h)
50 |         self.parse_graminit_c(graminit_c)
51 |         self.finish_off()
52 | 
53 |     def parse_graminit_h(self, filename):
54 |         """Parse the .h file written by pgen. (Internal)
55 | 
56 |         This file is a sequence of #define statements defining the
57 |         nonterminals of the grammar as numbers. We build two tables
58 |         mapping the numbers to names and back.
59 | 
60 |         """
61 |         try:
62 |             f = open(filename)
63 |         except OSError as err:
64 |             print("Can't open %s: %s" % (filename, err))
65 |             return False
66 |         self.symbol2number = {}
67 |         self.number2symbol = {}
68 |         lineno = 0
69 |         for line in f:
70 |             lineno += 1
71 |             mo = re.match(r"^#define\s+(\w+)\s+(\d+)$", line)
72 |             if not mo and line.strip():
73 |                 print("%s(%s): can't parse %s" % (filename, lineno, line.strip()))
74 |             else:
75 |                 symbol, number = mo.groups()
76 |                 number = int(number)
77 |                 assert symbol not in self.symbol2number
78 |                 assert number not in self.number2symbol
79 |                 self.symbol2number[symbol] = number
80 |                 self.number2symbol[number] = symbol
81 |         return True
82 | 
83 |     def parse_graminit_c(self, filename):
84 |         """Parse the .c file written by pgen. (Internal)
85 | 
86 |         The file looks as follows. The first two lines are always this:
87 | 
88 |         #include "pgenheaders.h"
89 |         #include "grammar.h"
90 | 
91 |         After that come four blocks:
92 | 
93 |         1) one or more state definitions
94 |         2) a table defining dfas
95 |         3) a table defining labels
96 |         4) a struct defining the grammar
97 | 
98 |         A state definition has the following form:
99 |         - one or more arc arrays, each of the form:
100 |           static arc arcs_<n>_<m>[<k>] = {
101 |             {<i>, <j>},
102 |             ...
103 |           };
104 |         - followed by a state array, of the form:
105 |           static state states_<s>[<t>] = {
106 |             {<k>, arcs_<n>_<m>},
107 |             ...
108 |           };
109 | 
110 |         """
111 |         try:
112 |             f = open(filename)
113 |         except OSError as err:
114 |             print("Can't open %s: %s" % (filename, err))
115 |             return False
116 |         # The code below essentially uses f's iterator-ness!
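        # (Clarifying note, added: every parsing phase below advances the
        # shared file iterator with next(f), so each block resumes reading
        # exactly where the previous block stopped instead of re-reading
        # the file.)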
117 | lineno = 0 118 | 119 | # Expect the two #include lines 120 | lineno, line = lineno + 1, next(f) 121 | assert line == '#include "pgenheaders.h"\n', (lineno, line) 122 | lineno, line = lineno + 1, next(f) 123 | assert line == '#include "grammar.h"\n', (lineno, line) 124 | 125 | # Parse the state definitions 126 | lineno, line = lineno + 1, next(f) 127 | allarcs = {} 128 | states = [] 129 | while line.startswith("static arc "): 130 | while line.startswith("static arc "): 131 | mo = re.match(r"static arc arcs_(\d+)_(\d+)\[(\d+)\] = {$", line) 132 | assert mo, (lineno, line) 133 | n, m, k = list(map(int, mo.groups())) 134 | arcs = [] 135 | for _ in range(k): 136 | lineno, line = lineno + 1, next(f) 137 | mo = re.match(r"\s+{(\d+), (\d+)},$", line) 138 | assert mo, (lineno, line) 139 | i, j = list(map(int, mo.groups())) 140 | arcs.append((i, j)) 141 | lineno, line = lineno + 1, next(f) 142 | assert line == "};\n", (lineno, line) 143 | allarcs[(n, m)] = arcs 144 | lineno, line = lineno + 1, next(f) 145 | mo = re.match(r"static state states_(\d+)\[(\d+)\] = {$", line) 146 | assert mo, (lineno, line) 147 | s, t = list(map(int, mo.groups())) 148 | assert s == len(states), (lineno, line) 149 | state = [] 150 | for _ in range(t): 151 | lineno, line = lineno + 1, next(f) 152 | mo = re.match(r"\s+{(\d+), arcs_(\d+)_(\d+)},$", line) 153 | assert mo, (lineno, line) 154 | k, n, m = list(map(int, mo.groups())) 155 | arcs = allarcs[n, m] 156 | assert k == len(arcs), (lineno, line) 157 | state.append(arcs) 158 | states.append(state) 159 | lineno, line = lineno + 1, next(f) 160 | assert line == "};\n", (lineno, line) 161 | lineno, line = lineno + 1, next(f) 162 | self.states = states 163 | 164 | # Parse the dfas 165 | dfas = {} 166 | mo = re.match(r"static dfa dfas\[(\d+)\] = {$", line) 167 | assert mo, (lineno, line) 168 | ndfas = int(mo.group(1)) 169 | for i in range(ndfas): 170 | lineno, line = lineno + 1, next(f) 171 | mo = re.match(r'\s+{(\d+), "(\w+)", (\d+), (\d+), states_(\d+),$', line) 172 | assert mo, (lineno, line) 173 | symbol = mo.group(2) 174 | number, x, y, z = list(map(int, mo.group(1, 3, 4, 5))) 175 | assert self.symbol2number[symbol] == number, (lineno, line) 176 | assert self.number2symbol[number] == symbol, (lineno, line) 177 | assert x == 0, (lineno, line) 178 | state = states[z] 179 | assert y == len(state), (lineno, line) 180 | lineno, line = lineno + 1, next(f) 181 | mo = re.match(r'\s+("(?:\\\d\d\d)*")},$', line) 182 | assert mo, (lineno, line) 183 | first = {} 184 | rawbitset = eval(mo.group(1)) 185 | for i, c in enumerate(rawbitset): 186 | byte = ord(c) 187 | for j in range(8): 188 | if byte & (1 << j): 189 | first[i * 8 + j] = 1 190 | dfas[number] = (state, first) 191 | lineno, line = lineno + 1, next(f) 192 | assert line == "};\n", (lineno, line) 193 | self.dfas = dfas 194 | 195 | # Parse the labels 196 | labels = [] 197 | lineno, line = lineno + 1, next(f) 198 | mo = re.match(r"static label labels\[(\d+)\] = {$", line) 199 | assert mo, (lineno, line) 200 | nlabels = int(mo.group(1)) 201 | for i in range(nlabels): 202 | lineno, line = lineno + 1, next(f) 203 | mo = re.match(r'\s+{(\d+), (0|"\w+")},$', line) 204 | assert mo, (lineno, line) 205 | x, y = mo.groups() 206 | x = int(x) 207 | if y == "0": 208 | y = None 209 | else: 210 | y = eval(y) 211 | labels.append((x, y)) 212 | lineno, line = lineno + 1, next(f) 213 | assert line == "};\n", (lineno, line) 214 | self.labels = labels 215 | 216 | # Parse the grammar struct 217 | lineno, line = lineno + 1, next(f) 218 | assert line 
== "grammar _PyParser_Grammar = {\n", (lineno, line) 219 | lineno, line = lineno + 1, next(f) 220 | mo = re.match(r"\s+(\d+),$", line) 221 | assert mo, (lineno, line) 222 | ndfas = int(mo.group(1)) 223 | assert ndfas == len(self.dfas) 224 | lineno, line = lineno + 1, next(f) 225 | assert line == "\tdfas,\n", (lineno, line) 226 | lineno, line = lineno + 1, next(f) 227 | mo = re.match(r"\s+{(\d+), labels},$", line) 228 | assert mo, (lineno, line) 229 | nlabels = int(mo.group(1)) 230 | assert nlabels == len(self.labels), (lineno, line) 231 | lineno, line = lineno + 1, next(f) 232 | mo = re.match(r"\s+(\d+)$", line) 233 | assert mo, (lineno, line) 234 | start = int(mo.group(1)) 235 | assert start in self.number2symbol, (lineno, line) 236 | self.start = start 237 | lineno, line = lineno + 1, next(f) 238 | assert line == "};\n", (lineno, line) 239 | try: 240 | lineno, line = lineno + 1, next(f) 241 | except StopIteration: 242 | pass 243 | else: 244 | assert 0, (lineno, line) 245 | 246 | def finish_off(self): 247 | """Create additional useful structures. (Internal).""" 248 | self.keywords = {} # map from keyword strings to arc labels 249 | self.tokens = {} # map from numeric token values to arc labels 250 | for ilabel, (type, value) in enumerate(self.labels): 251 | if type == token.NAME and value is not None: 252 | self.keywords[value] = ilabel 253 | elif value is None: 254 | self.tokens[type] = ilabel 255 | -------------------------------------------------------------------------------- /fissix/btm_utils.py: -------------------------------------------------------------------------------- 1 | "Utility functions used by the btm_matcher module" 2 | 3 | from . import pytree 4 | from .pgen2 import grammar, token 5 | from .pygram import pattern_symbols, python_symbols 6 | 7 | syms = pattern_symbols 8 | pysyms = python_symbols 9 | tokens = grammar.opmap 10 | token_labels = token 11 | 12 | TYPE_ANY = -1 13 | TYPE_ALTERNATIVES = -2 14 | TYPE_GROUP = -3 15 | 16 | 17 | class MinNode(object): 18 | """This class serves as an intermediate representation of the 19 | pattern tree during the conversion to sets of leaf-to-root 20 | subpatterns""" 21 | 22 | def __init__(self, type=None, name=None): 23 | self.type = type 24 | self.name = name 25 | self.children = [] 26 | self.leaf = False 27 | self.parent = None 28 | self.alternatives = [] 29 | self.group = [] 30 | 31 | def __repr__(self): 32 | return str(self.type) + " " + str(self.name) 33 | 34 | def leaf_to_root(self): 35 | """Internal method. Returns a characteristic path of the 36 | pattern tree. 
This method must be run for all leaves until the 37 | linear subpatterns are merged into a single""" 38 | node = self 39 | subp = [] 40 | while node: 41 | if node.type == TYPE_ALTERNATIVES: 42 | node.alternatives.append(subp) 43 | if len(node.alternatives) == len(node.children): 44 | # last alternative 45 | subp = [tuple(node.alternatives)] 46 | node.alternatives = [] 47 | node = node.parent 48 | continue 49 | else: 50 | node = node.parent 51 | subp = None 52 | break 53 | 54 | if node.type == TYPE_GROUP: 55 | node.group.append(subp) 56 | # probably should check the number of leaves 57 | if len(node.group) == len(node.children): 58 | subp = get_characteristic_subpattern(node.group) 59 | node.group = [] 60 | node = node.parent 61 | continue 62 | else: 63 | node = node.parent 64 | subp = None 65 | break 66 | 67 | if node.type == token_labels.NAME and node.name: 68 | # in case of type=name, use the name instead 69 | subp.append(node.name) 70 | else: 71 | subp.append(node.type) 72 | 73 | node = node.parent 74 | return subp 75 | 76 | def get_linear_subpattern(self): 77 | """Drives the leaf_to_root method. The reason that 78 | leaf_to_root must be run multiple times is because we need to 79 | reject 'group' matches; for example the alternative form 80 | (a | b c) creates a group [b c] that needs to be matched. Since 81 | matching multiple linear patterns overcomes the automaton's 82 | capabilities, leaf_to_root merges each group into a single 83 | choice based on 'characteristic'ity, 84 | 85 | i.e. (a|b c) -> (a|b) if b more characteristic than c 86 | 87 | Returns: The most 'characteristic'(as defined by 88 | get_characteristic_subpattern) path for the compiled pattern 89 | tree. 90 | """ 91 | 92 | for l in self.leaves(): 93 | subp = l.leaf_to_root() 94 | if subp: 95 | return subp 96 | 97 | def leaves(self): 98 | "Generator that returns the leaves of the tree" 99 | for child in self.children: 100 | yield from child.leaves() 101 | if not self.children: 102 | yield self 103 | 104 | 105 | def reduce_tree(node, parent=None): 106 | """ 107 | Internal function. Reduces a compiled pattern tree to an 108 | intermediate representation suitable for feeding the 109 | automaton. This also trims off any optional pattern elements(like 110 | [a], a*). 
111 | """ 112 | 113 | new_node = None 114 | # switch on the node type 115 | if node.type == syms.Matcher: 116 | # skip 117 | node = node.children[0] 118 | 119 | if node.type == syms.Alternatives: 120 | # 2 cases 121 | if len(node.children) <= 2: 122 | # just a single 'Alternative', skip this node 123 | new_node = reduce_tree(node.children[0], parent) 124 | else: 125 | # real alternatives 126 | new_node = MinNode(type=TYPE_ALTERNATIVES) 127 | # skip odd children('|' tokens) 128 | for child in node.children: 129 | if node.children.index(child) % 2: 130 | continue 131 | reduced = reduce_tree(child, new_node) 132 | if reduced is not None: 133 | new_node.children.append(reduced) 134 | elif node.type == syms.Alternative: 135 | if len(node.children) > 1: 136 | 137 | new_node = MinNode(type=TYPE_GROUP) 138 | for child in node.children: 139 | reduced = reduce_tree(child, new_node) 140 | if reduced: 141 | new_node.children.append(reduced) 142 | if not new_node.children: 143 | # delete the group if all of the children were reduced to None 144 | new_node = None 145 | 146 | else: 147 | new_node = reduce_tree(node.children[0], parent) 148 | 149 | elif node.type == syms.Unit: 150 | if isinstance(node.children[0], pytree.Leaf) and node.children[0].value == "(": 151 | # skip parentheses 152 | return reduce_tree(node.children[1], parent) 153 | if ( 154 | isinstance(node.children[0], pytree.Leaf) and node.children[0].value == "[" 155 | ) or ( 156 | len(node.children) > 1 157 | and hasattr(node.children[1], "value") 158 | and node.children[1].value == "[" 159 | ): 160 | # skip whole unit if its optional 161 | return None 162 | 163 | leaf = True 164 | details_node = None 165 | alternatives_node = None 166 | has_repeater = False 167 | repeater_node = None 168 | has_variable_name = False 169 | 170 | for child in node.children: 171 | if child.type == syms.Details: 172 | leaf = False 173 | details_node = child 174 | elif child.type == syms.Repeater: 175 | has_repeater = True 176 | repeater_node = child 177 | elif child.type == syms.Alternatives: 178 | alternatives_node = child 179 | if hasattr(child, "value") and child.value == "=": # variable name 180 | has_variable_name = True 181 | 182 | # skip variable name 183 | if has_variable_name: 184 | # skip variable name, '=' 185 | name_leaf = node.children[2] 186 | if hasattr(name_leaf, "value") and name_leaf.value == "(": 187 | # skip parenthesis 188 | name_leaf = node.children[3] 189 | else: 190 | name_leaf = node.children[0] 191 | 192 | # set node type 193 | if name_leaf.type == token_labels.NAME: 194 | # (python) non-name or wildcard 195 | if name_leaf.value == "any": 196 | new_node = MinNode(type=TYPE_ANY) 197 | else: 198 | if hasattr(token_labels, name_leaf.value): 199 | new_node = MinNode(type=getattr(token_labels, name_leaf.value)) 200 | else: 201 | new_node = MinNode(type=getattr(pysyms, name_leaf.value)) 202 | 203 | elif name_leaf.type == token_labels.STRING: 204 | # (python) name or character; remove the apostrophes from 205 | # the string value 206 | name = name_leaf.value.strip("'") 207 | if name in tokens: 208 | new_node = MinNode(type=tokens[name]) 209 | else: 210 | new_node = MinNode(type=token_labels.NAME, name=name) 211 | elif name_leaf.type == syms.Alternatives: 212 | new_node = reduce_tree(alternatives_node, parent) 213 | 214 | # handle repeaters 215 | if has_repeater: 216 | if repeater_node.children[0].value == "*": 217 | # reduce to None 218 | new_node = None 219 | elif repeater_node.children[0].value == "+": 220 | # reduce to a single occurrence 
i.e. do nothing 221 | pass 222 | else: 223 | # TODO: handle {min, max} repeaters 224 | raise NotImplementedError 225 | pass 226 | 227 | # add children 228 | if details_node and new_node is not None: 229 | for child in details_node.children[1:-1]: 230 | # skip '<', '>' markers 231 | reduced = reduce_tree(child, new_node) 232 | if reduced is not None: 233 | new_node.children.append(reduced) 234 | if new_node: 235 | new_node.parent = parent 236 | return new_node 237 | 238 | 239 | def get_characteristic_subpattern(subpatterns): 240 | """Picks the most characteristic from a list of linear patterns 241 | Current order used is: 242 | names > common_names > common_chars 243 | """ 244 | if not isinstance(subpatterns, list): 245 | return subpatterns 246 | if len(subpatterns) == 1: 247 | return subpatterns[0] 248 | 249 | # first pick out the ones containing variable names 250 | subpatterns_with_names = [] 251 | subpatterns_with_common_names = [] 252 | common_names = ["in", "for", "if", "not", "None"] 253 | subpatterns_with_common_chars = [] 254 | common_chars = "[]().,:" 255 | for subpattern in subpatterns: 256 | if any(rec_test(subpattern, lambda x: type(x) is str)): 257 | if any( 258 | rec_test(subpattern, lambda x: isinstance(x, str) and x in common_chars) 259 | ): 260 | subpatterns_with_common_chars.append(subpattern) 261 | elif any( 262 | rec_test(subpattern, lambda x: isinstance(x, str) and x in common_names) 263 | ): 264 | subpatterns_with_common_names.append(subpattern) 265 | 266 | else: 267 | subpatterns_with_names.append(subpattern) 268 | 269 | if subpatterns_with_names: 270 | subpatterns = subpatterns_with_names 271 | elif subpatterns_with_common_names: 272 | subpatterns = subpatterns_with_common_names 273 | elif subpatterns_with_common_chars: 274 | subpatterns = subpatterns_with_common_chars 275 | # of the remaining subpatterns pick out the longest one 276 | return max(subpatterns, key=len) 277 | 278 | 279 | def rec_test(sequence, test_func): 280 | """Tests test_func on all items of sequence and items of included 281 | sub-iterables""" 282 | for x in sequence: 283 | if isinstance(x, (list, tuple)): 284 | yield from rec_test(x, test_func) 285 | else: 286 | yield test_func(x) 287 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 
27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
204 | 205 | -------------------------------------------------------------------------------- /bowler/tool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import difflib 9 | import logging 10 | import multiprocessing 11 | import os 12 | import time 13 | from queue import Empty 14 | from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple 15 | 16 | from tools import click 17 | from fissix.pgen2.parse import ParseError 18 | from fissix.refactor import RefactoringTool 19 | 20 | from .helpers import filename_endswith 21 | from .types import ( 22 | BadTransform, 23 | BowlerException, 24 | BowlerQuit, 25 | Filename, 26 | FilenameMatcher, 27 | Fixers, 28 | Hunk, 29 | Node, 30 | Processor, 31 | RetryFile, 32 | ) 33 | 34 | PROMPT_HELP = { 35 | "y": "apply this hunk", 36 | "n": "skip this hunk", 37 | "a": "apply this hunk and all remaining hunks for this file", 38 | "d": "skip this hunk and all remaining hunks for this file", 39 | "q": "quit; do not apply this hunk or any remaining hunks", 40 | "?": "show help", 41 | } 42 | 43 | log = logging.getLogger(__name__) 44 | 45 | 46 | def diff_texts(a, b, filename): 47 | lines_a = a.splitlines() 48 | lines_b = b.splitlines() 49 | return difflib.unified_diff(lines_a, lines_b, filename, filename, lineterm="") 50 | 51 | 52 | def prompt_user(question, options, default = ""): 53 | options = options.lower() 54 | default = default.lower() 55 | assert len(default) < 2 and default in options 56 | 57 | if "?" not in options: 58 | options += "?" 59 | 60 | prompt_options = ",".join(o.upper() if o == default else o for o in options) 61 | prompt = "{} [{}]? ".format(question, prompt_options) 62 | result = "" 63 | 64 | while True: 65 | result = input(prompt).strip().lower() 66 | if result == "?": 67 | for option in PROMPT_HELP: 68 | click.secho("{} - {}".format(option, PROMPT_HELP[option]), fg="red", bold=True) 69 | 70 | elif len(result) == 1 and result in options: 71 | return result 72 | 73 | elif result: 74 | click.echo('invalid response "{}"'.format(result)) 75 | 76 | elif default: 77 | return default 78 | 79 | 80 | class BowlerTool(RefactoringTool): 81 | NUM_PROCESSES = os.cpu_count() or 1 82 | IN_PROCESS = False # set when run DEBUG mode from command line 83 | 84 | def __init__( 85 | self, 86 | fixers, 87 | *args, 88 | need_confirm = False, 89 | parallel = None, 90 | write = False, 91 | silent = False, 92 | in_process = False, 93 | hunk_processor = None, 94 | filename_matcher = None, 95 | **kwargs): 96 | self.backup = kwargs.pop('backup', None) 97 | self.print_hint = kwargs.pop('print_hint', True) 98 | options = kwargs.pop("options", {}) 99 | options["print_function"] = True 100 | super().__init__(fixers, *args, options=options, **kwargs) 101 | self.need_confirm = need_confirm 102 | self.parallel = parallel 103 | self.queue_count = 0 104 | self.queue = multiprocessing.JoinableQueue() # type: ignore 105 | # if need_confirm, refactor files in one process one by one to avoid log disorder. 
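        # (Annotation added for clarity: with need_confirm=True the tool
        # collapses to a single worker process gated by semaphore_confirm
        # and a one-slot results queue, so confirmation prompts and diff
        # output never interleave; otherwise up to min(parallel, 100)
        # workers share the work queue.)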
106 |         if self.need_confirm:
107 |             self.results = multiprocessing.Queue(maxsize=1)  # type: ignore
108 |             self.NUM_PROCESSES = 1
109 |             self.semaphore_confirm = multiprocessing.Semaphore(1)
110 |             self.parallel = None
111 |         else:
112 |             if self.parallel is not None:
113 |                 self.NUM_PROCESSES = max(1, min(self.parallel, 100))
114 |             self.results = multiprocessing.Queue()  # type: ignore
115 |         self.semaphore = multiprocessing.Semaphore(self.NUM_PROCESSES)
116 |         self.write = write
117 |         self.silent = silent
118 |         # pick the most restrictive of flags
119 |         self.in_process = in_process or self.IN_PROCESS
120 |         self.exceptions = []
121 |         if hunk_processor is not None:
122 |             self.hunk_processor = hunk_processor
123 |         else:
124 |             self.hunk_processor = lambda f, h: True
125 |         self.filename_matcher = filename_matcher or filename_endswith(".py")
126 | 
127 |     def log_error(self, msg, *args, **kwds):
128 |         self.logger.error(msg, *args, **kwds)
129 | 
130 |     def get_fixers(self):
131 |         fixers = [f(self.options, self.fixer_log) for f in self.fixers]
132 |         pre = [f for f in fixers if f.order == "pre"]
133 |         post = [f for f in fixers if f.order == "post"]
134 |         return pre, post
135 | 
136 |     def processed_file(
137 |         self, new_text, filename, old_text = "", *args, **kwargs
138 |     ):
139 |         self.files.append(filename)
140 |         hunks = []
141 |         if old_text != new_text:
142 |             a, b, *lines = list(diff_texts(old_text, new_text, filename))
143 | 
144 |             hunk = []
145 |             for line in lines:
146 |                 if line.startswith("@@"):
147 |                     if hunk:
148 |                         hunks.append([a, b, *hunk])
149 |                         hunk = []
150 |                 hunk.append(line)
151 | 
152 |             if hunk:
153 |                 hunks.append([a, b, *hunk])
154 | 
155 |             try:
156 |                 new_tree = self.driver.parse_string(new_text)
157 |                 if new_tree is None:
158 |                     raise AssertionError("Re-parsed CST is None")
159 |             except Exception as e:
160 |                 raise BadTransform(
161 |                     "Transforms generated invalid CST for {}".format(filename),
162 |                     filename=filename,
163 |                     hunks=hunks,
164 |                 ) from e
165 | 
166 |         return hunks
167 | 
168 |     def refactor_file(self, filename, *a, **k):
169 |         hunks = []
170 |         try:
171 |             input, encoding = self._read_python_source(filename)
172 |             if input is None:
173 |                 # Reading the file failed.
174 |                 return hunks, None
175 |         except (OSError, UnicodeDecodeError) as e:
176 |             log.error("Skipping {}: failed to read because {}".format(filename, e))
177 |             return hunks, None
178 | 
179 |         try:
180 |             if not input.endswith("\n"):
181 |                 input += "\n"
182 |             tree = self.refactor_string(input, filename)
183 |             if tree:
184 |                 hunks = self.processed_file(str(tree), filename, input)
185 |         except ParseError as e:
186 |             log.exception("Skipping {}: failed to parse ({})".format(filename, e))
187 |             return hunks, None
188 |         return hunks, str(tree).encode(encoding) if tree else None
189 | 
190 |     def refactor_dir(self, dir_name, *a, **k):
191 |         """Descends into a directory and refactors every Python file found.
192 | 
193 |         Python files are those for which `self.filename_matcher(filename)`
194 |         returns true, to allow for custom extensions.
195 | 
196 |         Files and subdirectories starting with '.' are skipped.
197 | """ 198 | for dirpath, dirnames, filenames in os.walk(dir_name): 199 | self.log_debug("Descending into %s", dirpath) 200 | dirnames.sort() 201 | filenames.sort() 202 | for name in filenames: 203 | fullname = os.path.join(dirpath, name) 204 | if not name.startswith(".") and self.filename_matcher( 205 | Filename(fullname) 206 | ): 207 | self.queue_work(Filename(fullname)) 208 | # Modify dirnames in-place to remove subdirs with leading dots 209 | dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] 210 | 211 | def refactor_queue(self): 212 | self.semaphore.acquire() 213 | while True: 214 | filename = self.queue.get() 215 | 216 | if filename is None: 217 | break 218 | 219 | try: 220 | if self.need_confirm: 221 | self.semaphore_confirm.acquire() 222 | hunks, new_text = self.refactor_file(filename) 223 | self.results.put((filename, hunks, None, new_text)) 224 | 225 | except RetryFile: 226 | self.log_debug("Retrying {} later...".format(filename)) 227 | self.queue.put(filename) 228 | except BowlerException as e: 229 | log.exception("Bowler exception during transform of {}: {}".format(filename, e)) 230 | self.results.put((filename, e.hunks, e, None)) 231 | except Exception as e: 232 | log.exception("Skipping {}: failed to transform because {}".format(filename, e)) 233 | self.results.put((filename, [], e, None)) 234 | 235 | finally: 236 | self.queue.task_done() 237 | self.semaphore.release() 238 | 239 | def queue_work(self, filename): 240 | self.queue.put(filename) 241 | self.queue_count += 1 242 | 243 | def refactor(self, items, *a, **k): 244 | """Refactor a list of files and directories.""" 245 | 246 | for dir_or_file in sorted(items): 247 | if os.path.isdir(dir_or_file): 248 | self.refactor_dir(dir_or_file) 249 | else: 250 | self.queue_work(Filename(dir_or_file)) 251 | 252 | children = [] 253 | if self.in_process: 254 | self.queue.put(None) 255 | self.refactor_queue() 256 | else: 257 | child_count = max(1, min(self.NUM_PROCESSES, self.queue_count)) 258 | self.log_debug("starting {} processes".format(child_count)) 259 | for i in range(child_count): 260 | child = multiprocessing.Process(target=self.refactor_queue) 261 | child.start() 262 | children.append(child) 263 | self.queue.put(None) 264 | 265 | results_count = 0 266 | 267 | while True: 268 | try: 269 | filename, hunks, exc, new_text = self.results.get_nowait() 270 | results_count += 1 271 | 272 | if exc: 273 | self.log_error("{}: {}".format(type(exc).__name__, exc)) 274 | if exc.__cause__: 275 | self.log_error( 276 | " {}: {}".format(type(exc.__cause__).__name__, exc.__cause__) 277 | ) 278 | if isinstance(exc, BowlerException) and exc.hunks: 279 | diff = "\n".join("\n".join(hunk) for hunk in exc.hunks) 280 | self.log_error("Generated transform:\n{}".format(diff)) 281 | self.exceptions.append(exc) 282 | else: 283 | self.log_debug("results: got {} hunks for {}".format(len(hunks), filename)) 284 | self.print_hunks(filename, hunks) 285 | if hunks and self.write: 286 | if self.need_confirm: 287 | if click.confirm(click.style('"{}" will be modified in-place, and it has been backed up to "{}". Do you want to continue?'.format(filename, self.backup), fg='red', bold=True)): 288 | self.write_result(filename, new_text) 289 | if self.print_hint: 290 | click.secho('"{}" refactor done! 
Recover your files from "{}" if anything is wrong.'.format(filename, self.backup)) 291 | else: 292 | if self.print_hint: 293 | click.secho('"{}" refactor cancelled!'.format(filename), fg='red', bold=True) 294 | else: 295 | self.write_result(filename, new_text) 296 | if self.print_hint: 297 | click.secho('"{}" refactor done! Recover your files from "{}" if anything is wrong.'.format(filename, self.backup)) 298 | if self.need_confirm: 299 | self.semaphore_confirm.release() 300 | 301 | except Empty: 302 | if self.queue.empty() and results_count == self.queue_count: 303 | break 304 | 305 | elif not self.in_process and not any( 306 | child.is_alive() for child in children 307 | ): 308 | self.log_debug("child processes stopped without consuming work") 309 | break 310 | 311 | else: 312 | time.sleep(0.05) 313 | 314 | except BowlerQuit: 315 | for child in children: 316 | child.terminate() 317 | break 318 | 319 | self.log_debug("all children stopped and all diff hunks processed") 320 | 321 | def print_hunks(self, filename, hunks): 322 | auto_yes = False 323 | result = "" 324 | # print same filename header only once. 325 | hunks_header = set() 326 | for hunk in hunks: 327 | header = "{} {}".format(hunk[0], hunk[1]) 328 | if self.hunk_processor(filename, hunk) is False: 329 | continue 330 | if not self.silent: 331 | # print header, e.g. 332 | # --- ./model.py 333 | # +++ ./model.py 334 | if header not in hunks_header: 335 | for line in hunk[:2]: 336 | if line.startswith("---"): 337 | click.secho(line, fg="red", bold=True) 338 | elif line.startswith("+++"): 339 | click.secho(line, fg="green", bold=True) 340 | hunks_header.add(header) 341 | 342 | # print diff content 343 | for line in hunk[2:]: 344 | if line.startswith("-"): 345 | click.secho(line, fg="red") 346 | elif line.startswith("+"): 347 | click.secho(line, fg="green") 348 | else: 349 | click.echo(line) 350 | 351 | def write_result(self, filename, new_text): 352 | if isinstance(new_text, bytes): 353 | with open(filename, 'wb') as f: 354 | f.write(new_text) 355 | 356 | def run(self, paths): 357 | if not self.errors: 358 | self.refactor(paths) 359 | self.summarize() 360 | 361 | return int(bool(self.errors or self.exceptions)) 362 | -------------------------------------------------------------------------------- /fissix/pgen2/pgen.py: -------------------------------------------------------------------------------- 1 | # Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. 2 | # Licensed to PSF under a Contributor Agreement. 3 | 4 | # Pgen imports 5 | from . 
import grammar, token, tokenize 6 | 7 | 8 | class PgenGrammar(grammar.Grammar): 9 | pass 10 | 11 | 12 | class ParserGenerator(object): 13 | def __init__(self, filename, stream=None): 14 | close_stream = None 15 | if stream is None: 16 | stream = open(filename) 17 | close_stream = stream.close 18 | self.filename = filename 19 | self.stream = stream 20 | self.generator = tokenize.generate_tokens(stream.readline) 21 | self.gettoken() # Initialize lookahead 22 | self.dfas, self.startsymbol = self.parse() 23 | if close_stream is not None: 24 | close_stream() 25 | self.first = {} # map from symbol name to set of tokens 26 | self.addfirstsets() 27 | 28 | def make_grammar(self): 29 | c = PgenGrammar() 30 | names = list(self.dfas.keys()) 31 | names.sort() 32 | names.remove(self.startsymbol) 33 | names.insert(0, self.startsymbol) 34 | for name in names: 35 | i = 256 + len(c.symbol2number) 36 | c.symbol2number[name] = i 37 | c.number2symbol[i] = name 38 | for name in names: 39 | dfa = self.dfas[name] 40 | states = [] 41 | for state in dfa: 42 | arcs = [] 43 | for label, next in sorted(state.arcs.items()): 44 | arcs.append((self.make_label(c, label), dfa.index(next))) 45 | if state.isfinal: 46 | arcs.append((0, dfa.index(state))) 47 | states.append(arcs) 48 | c.states.append(states) 49 | c.dfas[c.symbol2number[name]] = (states, self.make_first(c, name)) 50 | c.start = c.symbol2number[self.startsymbol] 51 | return c 52 | 53 | def make_first(self, c, name): 54 | rawfirst = self.first[name] 55 | first = {} 56 | for label in sorted(rawfirst): 57 | ilabel = self.make_label(c, label) 58 | ##assert ilabel not in first # XXX failed on <> ... != 59 | first[ilabel] = 1 60 | return first 61 | 62 | def make_label(self, c, label): 63 | # XXX Maybe this should be a method on a subclass of converter? 
64 | ilabel = len(c.labels) 65 | if label[0].isalpha(): 66 | # Either a symbol name or a named token 67 | if label in c.symbol2number: 68 | # A symbol name (a non-terminal) 69 | if label in c.symbol2label: 70 | return c.symbol2label[label] 71 | else: 72 | c.labels.append((c.symbol2number[label], None)) 73 | c.symbol2label[label] = ilabel 74 | return ilabel 75 | else: 76 | # A named token (NAME, NUMBER, STRING) 77 | itoken = getattr(token, label, None) 78 | assert isinstance(itoken, int), label 79 | assert itoken in token.tok_name, label 80 | if itoken in c.tokens: 81 | return c.tokens[itoken] 82 | else: 83 | c.labels.append((itoken, None)) 84 | c.tokens[itoken] = ilabel 85 | return ilabel 86 | else: 87 | # Either a keyword or an operator 88 | assert label[0] in ('"', "'"), label 89 | value = eval(label) 90 | if value[0].isalpha(): 91 | # A keyword 92 | if value in c.keywords: 93 | return c.keywords[value] 94 | else: 95 | c.labels.append((token.NAME, value)) 96 | c.keywords[value] = ilabel 97 | return ilabel 98 | else: 99 | # An operator (any non-numeric token) 100 | itoken = grammar.opmap[value] # Fails if unknown token 101 | if itoken in c.tokens: 102 | return c.tokens[itoken] 103 | else: 104 | c.labels.append((itoken, None)) 105 | c.tokens[itoken] = ilabel 106 | return ilabel 107 | 108 | def addfirstsets(self): 109 | names = list(self.dfas.keys()) 110 | names.sort() 111 | for name in names: 112 | if name not in self.first: 113 | self.calcfirst(name) 114 | # print name, self.first[name].keys() 115 | 116 | def calcfirst(self, name): 117 | dfa = self.dfas[name] 118 | self.first[name] = None # dummy to detect left recursion 119 | state = dfa[0] 120 | totalset = {} 121 | overlapcheck = {} 122 | for label, next in state.arcs.items(): 123 | if label in self.dfas: 124 | if label in self.first: 125 | fset = self.first[label] 126 | if fset is None: 127 | raise ValueError("recursion for rule %r" % name) 128 | else: 129 | self.calcfirst(label) 130 | fset = self.first[label] 131 | totalset.update(fset) 132 | overlapcheck[label] = fset 133 | else: 134 | totalset[label] = 1 135 | overlapcheck[label] = {label: 1} 136 | inverse = {} 137 | for label, itsfirst in overlapcheck.items(): 138 | for symbol in itsfirst: 139 | if symbol in inverse: 140 | raise ValueError( 141 | "rule %s is ambiguous; %s is in the" 142 | " first sets of %s as well as %s" 143 | % (name, symbol, label, inverse[symbol]) 144 | ) 145 | inverse[symbol] = label 146 | self.first[name] = totalset 147 | 148 | def parse(self): 149 | dfas = {} 150 | startsymbol = None 151 | # MSTART: (NEWLINE | RULE)* ENDMARKER 152 | while self.type != token.ENDMARKER: 153 | while self.type == token.NEWLINE: 154 | self.gettoken() 155 | # RULE: NAME ':' RHS NEWLINE 156 | name = self.expect(token.NAME) 157 | self.expect(token.OP, ":") 158 | a, z = self.parse_rhs() 159 | self.expect(token.NEWLINE) 160 | # self.dump_nfa(name, a, z) 161 | dfa = self.make_dfa(a, z) 162 | # self.dump_dfa(name, dfa) 163 | oldlen = len(dfa) 164 | self.simplify_dfa(dfa) 165 | newlen = len(dfa) 166 | dfas[name] = dfa 167 | # print name, oldlen, newlen 168 | if startsymbol is None: 169 | startsymbol = name 170 | return dfas, startsymbol 171 | 172 | def make_dfa(self, start, finish): 173 | # To turn an NFA into a DFA, we define the states of the DFA 174 | # to correspond to *sets* of states of the NFA. Then do some 175 | # state reduction. Let's represent sets as dicts with 1 for 176 | # values. 
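        # (Illustrative note, added: this is the classic subset
        # construction. closure()/addclosure() below follow arcs whose
        # label is None to form the epsilon-closure, e.g. an NFA fragment
        # a --None--> b collapses into the single DFA state {a: 1, b: 1},
        # a dict standing in for an ordered set.)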
177 | assert isinstance(start, NFAState) 178 | assert isinstance(finish, NFAState) 179 | 180 | def closure(state): 181 | base = {} 182 | addclosure(state, base) 183 | return base 184 | 185 | def addclosure(state, base): 186 | assert isinstance(state, NFAState) 187 | if state in base: 188 | return 189 | base[state] = 1 190 | for label, next in state.arcs: 191 | if label is None: 192 | addclosure(next, base) 193 | 194 | states = [DFAState(closure(start), finish)] 195 | for state in states: # NB states grows while we're iterating 196 | arcs = {} 197 | for nfastate in state.nfaset: 198 | for label, next in nfastate.arcs: 199 | if label is not None: 200 | addclosure(next, arcs.setdefault(label, {})) 201 | for label, nfaset in sorted(arcs.items()): 202 | for st in states: 203 | if st.nfaset == nfaset: 204 | break 205 | else: 206 | st = DFAState(nfaset, finish) 207 | states.append(st) 208 | state.addarc(st, label) 209 | return states # List of DFAState instances; first one is start 210 | 211 | def dump_nfa(self, name, start, finish): 212 | print("Dump of NFA for", name) 213 | todo = [start] 214 | for i, state in enumerate(todo): 215 | print(" State", i, state is finish and "(final)" or "") 216 | for label, next in state.arcs: 217 | if next in todo: 218 | j = todo.index(next) 219 | else: 220 | j = len(todo) 221 | todo.append(next) 222 | if label is None: 223 | print(" -> %d" % j) 224 | else: 225 | print(" %s -> %d" % (label, j)) 226 | 227 | def dump_dfa(self, name, dfa): 228 | print("Dump of DFA for", name) 229 | for i, state in enumerate(dfa): 230 | print(" State", i, state.isfinal and "(final)" or "") 231 | for label, next in sorted(state.arcs.items()): 232 | print(" %s -> %d" % (label, dfa.index(next))) 233 | 234 | def simplify_dfa(self, dfa): 235 | # This is not theoretically optimal, but works well enough. 236 | # Algorithm: repeatedly look for two states that have the same 237 | # set of arcs (same labels pointing to the same nodes) and 238 | # unify them, until things stop changing. 
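        # (Worked example, added: if two states agree on isfinal and both
        # have exactly the arcs {"NAME": s3}, DFAState.__eq__ treats them
        # as equal; the later one is deleted and unifystate() re-points
        # every arc that targeted it at the survivor, then the scan
        # repeats until no such pair remains.)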
239 | 240 | # dfa is a list of DFAState instances 241 | changes = True 242 | while changes: 243 | changes = False 244 | for i, state_i in enumerate(dfa): 245 | for j in range(i + 1, len(dfa)): 246 | state_j = dfa[j] 247 | if state_i == state_j: 248 | # print " unify", i, j 249 | del dfa[j] 250 | for state in dfa: 251 | state.unifystate(state_j, state_i) 252 | changes = True 253 | break 254 | 255 | def parse_rhs(self): 256 | # RHS: ALT ('|' ALT)* 257 | a, z = self.parse_alt() 258 | if self.value != "|": 259 | return a, z 260 | else: 261 | aa = NFAState() 262 | zz = NFAState() 263 | aa.addarc(a) 264 | z.addarc(zz) 265 | while self.value == "|": 266 | self.gettoken() 267 | a, z = self.parse_alt() 268 | aa.addarc(a) 269 | z.addarc(zz) 270 | return aa, zz 271 | 272 | def parse_alt(self): 273 | # ALT: ITEM+ 274 | a, b = self.parse_item() 275 | while self.value in ("(", "[") or self.type in (token.NAME, token.STRING): 276 | c, d = self.parse_item() 277 | b.addarc(c) 278 | b = d 279 | return a, b 280 | 281 | def parse_item(self): 282 | # ITEM: '[' RHS ']' | ATOM ['+' | '*'] 283 | if self.value == "[": 284 | self.gettoken() 285 | a, z = self.parse_rhs() 286 | self.expect(token.OP, "]") 287 | a.addarc(z) 288 | return a, z 289 | else: 290 | a, z = self.parse_atom() 291 | value = self.value 292 | if value not in ("+", "*"): 293 | return a, z 294 | self.gettoken() 295 | z.addarc(a) 296 | if value == "+": 297 | return a, z 298 | else: 299 | return a, a 300 | 301 | def parse_atom(self): 302 | # ATOM: '(' RHS ')' | NAME | STRING 303 | if self.value == "(": 304 | self.gettoken() 305 | a, z = self.parse_rhs() 306 | self.expect(token.OP, ")") 307 | return a, z 308 | elif self.type in (token.NAME, token.STRING): 309 | a = NFAState() 310 | z = NFAState() 311 | a.addarc(z, self.value) 312 | self.gettoken() 313 | return a, z 314 | else: 315 | self.raise_error( 316 | "expected (...) 
or NAME or STRING, got %s/%s", self.type, self.value 317 | ) 318 | 319 | def expect(self, type, value=None): 320 | if self.type != type or (value is not None and self.value != value): 321 | self.raise_error( 322 | "expected %s/%s, got %s/%s", type, value, self.type, self.value 323 | ) 324 | value = self.value 325 | self.gettoken() 326 | return value 327 | 328 | def gettoken(self): 329 | tup = next(self.generator) 330 | while tup[0] in (tokenize.COMMENT, tokenize.NL): 331 | tup = next(self.generator) 332 | self.type, self.value, self.begin, self.end, self.line = tup 333 | # print token.tok_name[self.type], repr(self.value) 334 | 335 | def raise_error(self, msg, *args): 336 | if args: 337 | try: 338 | msg = msg % args 339 | except: 340 | msg = " ".join([msg] + list(map(str, args))) 341 | raise SyntaxError(msg, (self.filename, self.end[0], self.end[1], self.line)) 342 | 343 | 344 | class NFAState(object): 345 | def __init__(self): 346 | self.arcs = [] # list of (label, NFAState) pairs 347 | 348 | def addarc(self, next, label=None): 349 | assert label is None or isinstance(label, str) 350 | assert isinstance(next, NFAState) 351 | self.arcs.append((label, next)) 352 | 353 | 354 | class DFAState(object): 355 | def __init__(self, nfaset, final): 356 | assert isinstance(nfaset, dict) 357 | assert isinstance(next(iter(nfaset)), NFAState) 358 | assert isinstance(final, NFAState) 359 | self.nfaset = nfaset 360 | self.isfinal = final in nfaset 361 | self.arcs = {} # map from label to DFAState 362 | 363 | def addarc(self, next, label): 364 | assert isinstance(label, str) 365 | assert label not in self.arcs 366 | assert isinstance(next, DFAState) 367 | self.arcs[label] = next 368 | 369 | def unifystate(self, old, new): 370 | for label, next in self.arcs.items(): 371 | if next is old: 372 | self.arcs[label] = new 373 | 374 | def __eq__(self, other): 375 | # Equality test -- ignore the nfaset instance variable 376 | assert isinstance(other, DFAState) 377 | if self.isfinal != other.isfinal: 378 | return False 379 | # Can't just return self.arcs == other.arcs, because that 380 | # would invoke this method recursively, with cycles... 381 | if len(self.arcs) != len(other.arcs): 382 | return False 383 | for label, next in self.arcs.items(): 384 | if next is not other.arcs.get(label): 385 | return False 386 | return True 387 | 388 | __hash__ = None # For Py3 compatibility. 389 | 390 | 391 | def generate_grammar(filename="Grammar.txt"): 392 | p = ParserGenerator(filename) 393 | return p.make_grammar() 394 | -------------------------------------------------------------------------------- /fissix/fixer_util.py: -------------------------------------------------------------------------------- 1 | """Utility functions, node construction macros, etc.""" 2 | # Author: Collin Winter 3 | 4 | # Local imports 5 | from .pgen2 import token 6 | from .pytree import Leaf, Node 7 | from .pygram import python_symbols as syms 8 | from . 
import patcomp 9 | 10 | 11 | ########################################################### 12 | ### Common node-construction "macros" 13 | ########################################################### 14 | 15 | 16 | def KeywordArg(keyword, value): 17 | return Node(syms.argument, [keyword, Leaf(token.EQUAL, "="), value]) 18 | 19 | 20 | def LParen(): 21 | return Leaf(token.LPAR, "(") 22 | 23 | 24 | def RParen(): 25 | return Leaf(token.RPAR, ")") 26 | 27 | 28 | def Assign(target, source): 29 | """Build an assignment statement""" 30 | if not isinstance(target, list): 31 | target = [target] 32 | if not isinstance(source, list): 33 | source.prefix = " " 34 | source = [source] 35 | 36 | return Node(syms.atom, target + [Leaf(token.EQUAL, "=", prefix=" ")] + source) 37 | 38 | 39 | def Name(name, prefix=None): 40 | """Return a NAME leaf""" 41 | return Leaf(token.NAME, name, prefix=prefix) 42 | 43 | 44 | def Attr(obj, attr): 45 | """A node tuple for obj.attr""" 46 | return [obj, Node(syms.trailer, [Dot(), attr])] 47 | 48 | 49 | def Comma(): 50 | """A comma leaf""" 51 | return Leaf(token.COMMA, ",") 52 | 53 | 54 | def Dot(): 55 | """A period (.) leaf""" 56 | return Leaf(token.DOT, ".") 57 | 58 | 59 | def ArgList(args, lparen=LParen(), rparen=RParen()): 60 | """A parenthesised argument list, used by Call()""" 61 | node = Node(syms.trailer, [lparen.clone(), rparen.clone()]) 62 | if args: 63 | node.insert_child(1, Node(syms.arglist, args)) 64 | return node 65 | 66 | 67 | def Call(func_name, args=None, prefix=None): 68 | """A function call""" 69 | node = Node(syms.power, [func_name, ArgList(args)]) 70 | if prefix is not None: 71 | node.prefix = prefix 72 | return node 73 | 74 | 75 | def Newline(value="\n"): 76 | """A newline literal""" 77 | return Leaf(token.NEWLINE, value) 78 | 79 | 80 | def BlankLine(): 81 | """A blank line""" 82 | return Leaf(token.NEWLINE, "") 83 | 84 | 85 | def Number(n, prefix=None): 86 | return Leaf(token.NUMBER, n, prefix=prefix) 87 | 88 | 89 | def Subscript(index_node): 90 | """A numeric or string subscript""" 91 | return Node( 92 | syms.trailer, [Leaf(token.LBRACE, "["), index_node, Leaf(token.RBRACE, "]")] 93 | ) 94 | 95 | 96 | def String(string, prefix=None): 97 | """A string leaf""" 98 | return Leaf(token.STRING, string, prefix=prefix) 99 | 100 | 101 | def ListComp(xp, fp, it, test=None): 102 | """A list comprehension of the form [xp for fp in it if test]. 103 | 104 | If test is None, the "if test" part is omitted. 105 | """ 106 | xp.prefix = "" 107 | fp.prefix = " " 108 | it.prefix = " " 109 | for_leaf = Leaf(token.NAME, "for") 110 | for_leaf.prefix = " " 111 | in_leaf = Leaf(token.NAME, "in") 112 | in_leaf.prefix = " " 113 | inner_args = [for_leaf, fp, in_leaf, it] 114 | if test: 115 | test.prefix = " " 116 | if_leaf = Leaf(token.NAME, "if") 117 | if_leaf.prefix = " " 118 | inner_args.append(Node(syms.comp_if, [if_leaf, test])) 119 | inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)]) 120 | return Node(syms.atom, [Leaf(token.LBRACE, "["), inner, Leaf(token.RBRACE, "]")]) 121 | 122 | 123 | def FromImport(package_name, name_leafs): 124 | """ Return an import statement in the form: 125 | from package import name_leafs""" 126 | # XXX: May not handle dotted imports properly (eg, package_name='foo.bar') 127 | # assert package_name == '.' or '.' not in package_name, "FromImport has "\ 128 | # "not been tested with dotted package names -- use at your own "\ 129 | # "peril!" 
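    # --- Added illustration (hypothetical call); not upstream code. ---
    # FromImport("collections", [Name("OrderedDict", prefix=" ")]) builds a
    # node tree whose str() renders as:
    #     from collections import OrderedDict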
130 | 131 | for leaf in name_leafs: 132 | # Pull the leaves out of their old tree 133 | leaf.remove() 134 | 135 | children = [ 136 | Leaf(token.NAME, "from"), 137 | Leaf(token.NAME, package_name, prefix=" "), 138 | Leaf(token.NAME, "import", prefix=" "), 139 | Node(syms.import_as_names, name_leafs), 140 | ] 141 | imp = Node(syms.import_from, children) 142 | return imp 143 | 144 | 145 | def ImportAndCall(node, results, names): 146 | """Returns an import statement and calls a method 147 | of the module: 148 | 149 | import module 150 | module.name()""" 151 | obj = results["obj"].clone() 152 | if obj.type == syms.arglist: 153 | newarglist = obj.clone() 154 | else: 155 | newarglist = Node(syms.arglist, [obj.clone()]) 156 | after = results["after"] 157 | if after: 158 | after = [n.clone() for n in after] 159 | new = Node( 160 | syms.power, 161 | Attr(Name(names[0]), Name(names[1])) 162 | + [ 163 | Node( 164 | syms.trailer, 165 | [results["lpar"].clone(), newarglist, results["rpar"].clone()], 166 | ) 167 | ] 168 | + after, 169 | ) 170 | new.prefix = node.prefix 171 | return new 172 | 173 | 174 | ########################################################### 175 | ### Determine whether a node represents a given literal 176 | ########################################################### 177 | 178 | 179 | def is_tuple(node): 180 | """Does the node represent a tuple literal?""" 181 | if isinstance(node, Node) and node.children == [LParen(), RParen()]: 182 | return True 183 | return ( 184 | isinstance(node, Node) 185 | and len(node.children) == 3 186 | and isinstance(node.children[0], Leaf) 187 | and isinstance(node.children[1], Node) 188 | and isinstance(node.children[2], Leaf) 189 | and node.children[0].value == "(" 190 | and node.children[2].value == ")" 191 | ) 192 | 193 | 194 | def is_list(node): 195 | """Does the node represent a list literal?""" 196 | return ( 197 | isinstance(node, Node) 198 | and len(node.children) > 1 199 | and isinstance(node.children[0], Leaf) 200 | and isinstance(node.children[-1], Leaf) 201 | and node.children[0].value == "[" 202 | and node.children[-1].value == "]" 203 | ) 204 | 205 | 206 | ########################################################### 207 | ### Misc 208 | ########################################################### 209 | 210 | 211 | def parenthesize(node): 212 | return Node(syms.atom, [LParen(), node, RParen()]) 213 | 214 | 215 | consuming_calls = { 216 | "sorted", 217 | "list", 218 | "set", 219 | "any", 220 | "all", 221 | "tuple", 222 | "sum", 223 | "min", 224 | "max", 225 | "enumerate", 226 | } 227 | 228 | 229 | def attr_chain(obj, attr): 230 | """Follow an attribute chain. 231 | 232 | If you have a chain of objects where a.foo -> b, b.foo-> c, etc, 233 | use this to iterate over all objects in the chain. Iteration is 234 | terminated by getattr(x, attr) is None. 235 | 236 | Args: 237 | obj: the starting object 238 | attr: the name of the chaining attribute 239 | 240 | Yields: 241 | Each successive object in the chain. 242 | """ 243 | next = getattr(obj, attr) 244 | while next: 245 | yield next 246 | next = getattr(next, attr) 247 | 248 | 249 | p0 = """for_stmt< 'for' any 'in' node=any ':' any* > 250 | | comp_for< 'for' any 'in' node=any any* > 251 | """ 252 | p1 = """ 253 | power< 254 | ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' | 'dict' | 255 | 'any' | 'all' | 'enumerate' | (any* trailer< '.' 
'join' >) ) 256 | trailer< '(' node=any ')' > 257 | any* 258 | > 259 | """ 260 | p2 = """ 261 | power< 262 | ( 'sorted' | 'enumerate' ) 263 | trailer< '(' arglist ')' > 264 | any* 265 | > 266 | """ 267 | pats_built = False 268 | 269 | 270 | def in_special_context(node): 271 | """ Returns true if node is in an environment where all that is required 272 | of it is being iterable (ie, it doesn't matter if it returns a list 273 | or an iterator). 274 | See test_map_nochange in test_fixers.py for some examples and tests. 275 | """ 276 | global p0, p1, p2, pats_built 277 | if not pats_built: 278 | p0 = patcomp.compile_pattern(p0) 279 | p1 = patcomp.compile_pattern(p1) 280 | p2 = patcomp.compile_pattern(p2) 281 | pats_built = True 282 | patterns = [p0, p1, p2] 283 | for pattern, parent in zip(patterns, attr_chain(node, "parent")): 284 | results = {} 285 | if pattern.match(parent, results) and results["node"] is node: 286 | return True 287 | return False 288 | 289 | 290 | def is_probably_builtin(node): 291 | """ 292 | Check that something isn't an attribute or function name etc. 293 | """ 294 | prev = node.prev_sibling 295 | if prev is not None and prev.type == token.DOT: 296 | # Attribute lookup. 297 | return False 298 | parent = node.parent 299 | if parent.type in (syms.funcdef, syms.classdef): 300 | return False 301 | if parent.type == syms.expr_stmt and parent.children[0] is node: 302 | # Assignment. 303 | return False 304 | if parent.type == syms.parameters or ( 305 | parent.type == syms.typedargslist 306 | and ( 307 | (prev is not None and prev.type == token.COMMA) 308 | or parent.children[0] is node 309 | ) 310 | ): 311 | # The name of an argument. 312 | return False 313 | return True 314 | 315 | 316 | def find_indentation(node): 317 | """Find the indentation of *node*.""" 318 | while node is not None: 319 | if node.type == syms.suite and len(node.children) > 2: 320 | indent = node.children[1] 321 | if indent.type == token.INDENT: 322 | return indent.value 323 | node = node.parent 324 | return "" 325 | 326 | 327 | ########################################################### 328 | ### The following functions are to find bindings in a suite 329 | ########################################################### 330 | 331 | 332 | def make_suite(node): 333 | if node.type == syms.suite: 334 | return node 335 | node = node.clone() 336 | parent, node.parent = node.parent, None 337 | suite = Node(syms.suite, [node]) 338 | suite.parent = parent 339 | return suite 340 | 341 | 342 | def find_root(node): 343 | """Find the top level namespace.""" 344 | # Scamper up to the top level namespace 345 | while node.type != syms.file_input: 346 | node = node.parent 347 | if not node: 348 | raise ValueError("root found before file_input node was found.") 349 | return node 350 | 351 | 352 | def does_tree_import(package, name, node): 353 | """ Returns true if name is imported from package at the 354 | top level of the tree which node belongs to. 355 | To cover the case of an import like 'import foo', use 356 | None for the package and 'foo' for the name. """ 357 | binding = find_binding(name, find_root(node), package) 358 | return bool(binding) 359 | 360 | 361 | def is_import(node): 362 | """Returns true if the node is an import statement.""" 363 | return node.type in (syms.import_name, syms.import_from) 364 | 365 | 366 | def touch_import(package, name, node, force=False): 367 | """ Works like `does_tree_import` but adds an import statement 368 | if it was not imported. 
""" 369 | 370 | def is_import_stmt(node): 371 | return ( 372 | node.type == syms.simple_stmt 373 | and node.children 374 | and is_import(node.children[0]) 375 | ) 376 | 377 | root = find_root(node) 378 | 379 | if not force and does_tree_import(package, name, root): 380 | return 381 | 382 | # figure out where to insert the new import. First try to find 383 | # the first import and then skip to the last one. 384 | insert_pos = offset = 0 385 | for idx, node in enumerate(root.children): 386 | if not is_import_stmt(node): 387 | continue 388 | for offset, node2 in enumerate(root.children[idx:]): 389 | if not is_import_stmt(node2): 390 | break 391 | insert_pos = idx + offset 392 | break 393 | 394 | # if there are no imports where we can insert, find the docstring. 395 | # if that also fails, we stick to the beginning of the file 396 | if insert_pos == 0: 397 | for idx, node in enumerate(root.children): 398 | if ( 399 | node.type == syms.simple_stmt 400 | and node.children 401 | and node.children[0].type == token.STRING 402 | ): 403 | insert_pos = idx + 1 404 | break 405 | 406 | if package is None: 407 | import_ = Node( 408 | syms.import_name, 409 | [Leaf(token.NAME, "import"), Leaf(token.NAME, name, prefix=" ")], 410 | ) 411 | else: 412 | import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")]) 413 | 414 | children = [import_, Newline()] 415 | root.insert_child(insert_pos, Node(syms.simple_stmt, children)) 416 | 417 | 418 | _def_syms = {syms.classdef, syms.funcdef} 419 | 420 | 421 | def find_binding(name, node, package=None): 422 | """ Returns the node which binds variable name, otherwise None. 423 | If optional argument package is supplied, only imports will 424 | be returned. 425 | See test cases for examples.""" 426 | for child in node.children: 427 | ret = None 428 | if child.type == syms.for_stmt: 429 | if _find(name, child.children[1]): 430 | return child 431 | n = find_binding(name, make_suite(child.children[-1]), package) 432 | if n: 433 | ret = n 434 | elif child.type in (syms.if_stmt, syms.while_stmt): 435 | n = find_binding(name, make_suite(child.children[-1]), package) 436 | if n: 437 | ret = n 438 | elif child.type == syms.try_stmt: 439 | n = find_binding(name, make_suite(child.children[2]), package) 440 | if n: 441 | ret = n 442 | else: 443 | for i, kid in enumerate(child.children[3:]): 444 | if kid.type == token.COLON and kid.value == ":": 445 | # i+3 is the colon, i+4 is the suite 446 | n = find_binding( 447 | name, make_suite(child.children[i + 4]), package 448 | ) 449 | if n: 450 | ret = n 451 | elif child.type in _def_syms and child.children[1].value == name: 452 | ret = child 453 | elif _is_import_binding(child, name, package): 454 | ret = child 455 | elif child.type == syms.simple_stmt: 456 | ret = find_binding(name, child, package) 457 | elif child.type == syms.expr_stmt: 458 | if _find(name, child.children[0]): 459 | ret = child 460 | 461 | if ret: 462 | if not package: 463 | return ret 464 | if is_import(ret): 465 | return ret 466 | return None 467 | 468 | 469 | _block_syms = {syms.funcdef, syms.classdef, syms.trailer} 470 | 471 | 472 | def _find(name, node): 473 | nodes = [node] 474 | while nodes: 475 | node = nodes.pop() 476 | if node.type > 256 and node.type not in _block_syms: 477 | nodes.extend(node.children) 478 | elif node.type == token.NAME and node.value == name: 479 | return node 480 | return None 481 | 482 | 483 | def _is_import_binding(node, name, package=None): 484 | """ Will return node if node will import name, or node 485 | will import * from 
package. None is returned otherwise. 486 | See test cases for examples. """ 487 | 488 | if node.type == syms.import_name and not package: 489 | imp = node.children[1] 490 | if imp.type == syms.dotted_as_names: 491 | for child in imp.children: 492 | if child.type == syms.dotted_as_name: 493 | if child.children[2].value == name: 494 | return node 495 | elif child.type == token.NAME and child.value == name: 496 | return node 497 | elif imp.type == syms.dotted_as_name: 498 | last = imp.children[-1] 499 | if last.type == token.NAME and last.value == name: 500 | return node 501 | elif imp.type == token.NAME and imp.value == name: 502 | return node 503 | elif node.type == syms.import_from: 504 | # str(...) is used to make life easier here, because 505 | # from a.b import parses to ['import', ['a', '.', 'b'], ...] 506 | if package and str(node.children[1]).strip() != package: 507 | return None 508 | n = node.children[3] 509 | if package and _find("as", n): 510 | # See test_from_import_as for explanation 511 | return None 512 | elif n.type == syms.import_as_names and _find(name, n): 513 | return node 514 | elif n.type == syms.import_as_name: 515 | child = n.children[2] 516 | if child.type == token.NAME and child.value == name: 517 | return node 518 | elif n.type == token.NAME and n.value == name: 519 | return node 520 | elif package and n.type == token.STAR: 521 | return node 522 | return None 523 | -------------------------------------------------------------------------------- /paddle_upgrade_tool/tests/test_refactor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | import textwrap 5 | from tempfile import NamedTemporaryFile 6 | 7 | from bowler import Query 8 | from paddle_upgrade_tool.refactor import * 9 | from paddle_upgrade_tool import utils 10 | 11 | def _refactor_helper(refactor_func, input_src, change_spec): 12 | try: 13 | ntf = NamedTemporaryFile(suffix='.py', delete=False) 14 | ntf.write(input_src.encode('utf-8')) 15 | ntf.close() 16 | q = Query(ntf.name) 17 | if utils.is_windows(): 18 | refactor_func(q, change_spec).execute(write=True, silent=True, need_confirm=False, print_hint=False, in_process=True) 19 | else: 20 | refactor_func(q, change_spec).execute(write=True, silent=True, need_confirm=False, print_hint=False) 21 | with open(ntf.name, 'r') as f: 22 | output_src = f.read() 23 | return output_src 24 | finally: 25 | os.remove(ntf.name) 26 | 27 | 28 | class TestRefactorImport(unittest.TestCase): 29 | def _run(self, change_spec, input_src, expected_src): 30 | input_src = textwrap.dedent(input_src).strip() + '\n' 31 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 32 | output_src = _refactor_helper(refactor_import, input_src, change_spec) 33 | self.assertEqual(output_src, expected_src) 34 | 35 | def test_refactor_import(self): 36 | input_src = ''' 37 | import paddle 38 | ''' 39 | expected_src = ''' 40 | import paddle 41 | ''' 42 | self._run({}, input_src, expected_src) 43 | #-------------- 44 | input_src = ''' 45 | import paddle.fluid as fluid 46 | ''' 47 | expected_src = ''' 48 | import paddle 49 | ''' 50 | self._run({}, input_src, expected_src) 51 | #-------------- 52 | input_src = ''' 53 | import paddle 54 | import paddle.fluid as fluid 55 | ''' 56 | expected_src = ''' 57 | import paddle 58 | ''' 59 | self._run({}, input_src, expected_src) 60 | #-------------- 61 | input_src = ''' 62 | import paddle 63 | import paddle.fluid as fluid 64 | ''' 65 | expected_src = ''' 66 | import paddle 
67 | ''' 68 | self._run({}, input_src, expected_src) 69 | #-------------- 70 | input_src = ''' 71 | import paddle 72 | import paddle.fluid as fluid 73 | fluid.api() 74 | 75 | def func(): 76 | fluid.api() 77 | ''' 78 | expected_src = ''' 79 | import paddle 80 | paddle.fluid.api() 81 | 82 | def func(): 83 | paddle.fluid.api() 84 | ''' 85 | self._run({}, input_src, expected_src) 86 | #-------------- 87 | input_src = ''' 88 | import paddle.fluid as fluid 89 | fluid.api() 90 | 91 | def func(): 92 | fluid.api() 93 | ''' 94 | expected_src = ''' 95 | import paddle 96 | paddle.fluid.api() 97 | 98 | def func(): 99 | paddle.fluid.api() 100 | ''' 101 | self._run({}, input_src, expected_src) 102 | #-------------- 103 | input_src = ''' 104 | from paddle.fluid.layers import Layer 105 | 106 | class CustomLayer(Layer): 107 | pass 108 | print(Layer.__name__) 109 | print(type(Layer)) 110 | ''' 111 | expected_src = ''' 112 | import paddle 113 | 114 | class CustomLayer(paddle.fluid.layers.Layer): 115 | pass 116 | print(paddle.fluid.layers.Layer.__name__) 117 | print(type(paddle.fluid.layers.Layer)) 118 | ''' 119 | self._run({}, input_src, expected_src) 120 | #-------------- 121 | input_src = ''' 122 | import paddle 123 | import paddle.fluid as fluid 124 | fluid.api() 125 | func(fluid=1) 126 | ''' 127 | expected_src = ''' 128 | import paddle 129 | paddle.fluid.api() 130 | func(fluid=1) 131 | ''' 132 | self._run({}, input_src, expected_src) 133 | 134 | class TestNormApiAlias(unittest.TestCase): 135 | change_spec = { 136 | "paddle.fluid.Layer": { 137 | "alias": [ 138 | "paddle.fluid.layers.Layer", 139 | "paddle.fluid.layers1.layers2.Layer", 140 | ] 141 | } 142 | } 143 | 144 | def _run(self, change_spec, input_src, expected_src): 145 | input_src = textwrap.dedent(input_src).strip() + '\n' 146 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 147 | output_src = _refactor_helper(norm_api_alias, input_src, change_spec) 148 | self.assertEqual(output_src, expected_src) 149 | 150 | def test_norm_api_alias(self): 151 | input_src = ''' 152 | import paddle 153 | 154 | layer = paddle.fluid.Layer() 155 | layer = paddle.fluid.layers.Layer() 156 | layer = paddle.fluid.layers.Layer_With_Underscore() 157 | layer = paddle.fluid.layers1.layers2.Layer() 158 | ''' 159 | expected_src = ''' 160 | import paddle 161 | 162 | layer = paddle.fluid.Layer() 163 | layer = paddle.fluid.Layer() 164 | layer = paddle.fluid.layers.Layer_With_Underscore() 165 | layer = paddle.fluid.Layer() 166 | ''' 167 | self._run(self.change_spec, input_src, expected_src) 168 | 169 | 170 | class TestApiRename(unittest.TestCase): 171 | change_spec = { 172 | "paddle.fluid.Layer": { 173 | "update_to": "paddle.Layer", 174 | }, 175 | } 176 | 177 | def _run(self, change_spec, input_src, expected_src): 178 | input_src = textwrap.dedent(input_src).strip() + '\n' 179 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 180 | output_src = _refactor_helper(api_rename, input_src, change_spec) 181 | self.assertEqual(output_src, expected_src) 182 | 183 | def test_rename(self): 184 | input_src = ''' 185 | import paddle 186 | 187 | layer = paddle.fluid.Layer() 188 | layer = paddle.fluid.Layer_With_Underscore() 189 | ''' 190 | expected_src = ''' 191 | import paddle 192 | 193 | layer = paddle.Layer() 194 | layer = paddle.fluid.Layer_With_Underscore() 195 | ''' 196 | self._run(self.change_spec, input_src, expected_src) 197 | 198 | class TestArgsToKwargs(unittest.TestCase): 199 | change_spec = { 200 | "paddle.add": { 201 | "args_list": ["x", "y"], 202 | }, 203 | 
} 204 | def _run(self, change_spec, input_src, expected_src): 205 | input_src = textwrap.dedent(input_src).strip() + '\n' 206 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 207 | output_src = _refactor_helper(args_to_kwargs, input_src, change_spec) 208 | self.assertEqual(output_src, expected_src) 209 | 210 | def test_args_to_kwargs(self): 211 | input_src = ''' 212 | paddle.add(1,2) 213 | paddle.add(1, 2) 214 | paddle.add(1, y=2) 215 | paddle.add(1) 216 | paddle.add(z=1) 217 | paddle.add(paddle.to.api, paddle.to.api()) 218 | ''' 219 | expected_src = ''' 220 | paddle.add(x=1,y=2) 221 | paddle.add(x=1, y=2) 222 | paddle.add(x=1, y=2) 223 | paddle.add(x=1) 224 | paddle.add(z=1) 225 | paddle.add(x=paddle.to.api, y=paddle.to.api()) 226 | ''' 227 | self._run(self.change_spec, input_src, expected_src) 228 | 229 | class TestRefactorKwargs(unittest.TestCase): 230 | change_spec = { 231 | "paddle.add": { 232 | "args_change": [ 233 | [ "x", "x_new" ], 234 | [ "out", "" ], 235 | [ "", "name", "test" ], 236 | ], 237 | "args_warning": { 238 | "x_new": "x_new is deleted in paddle.add" 239 | } 240 | }, 241 | "paddle.reassign": { 242 | "args_change": [ 243 | [ "", "x", "2" ], 244 | ], 245 | }, 246 | } 247 | 248 | def _run(self, change_spec, input_src, expected_src): 249 | input_src = textwrap.dedent(input_src).strip() + '\n' 250 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 251 | output_src = _refactor_helper(refactor_kwargs, input_src, change_spec) 252 | self.assertEqual(output_src, expected_src) 253 | 254 | def test_refactor_kwargs(self): 255 | input_src = ''' 256 | paddle.add(x=1, out=2) 257 | paddle.add(1) 258 | paddle.add() 259 | paddle.add(a=1, b=2, c=3) 260 | ''' 261 | expected_src = ''' 262 | paddle.add(x_new=1, name=test) 263 | paddle.add(1, name=test) 264 | paddle.add(name=test) 265 | paddle.add(a=1, b=2, c=3, name=test) 266 | ''' 267 | self._run(self.change_spec, input_src, expected_src) 268 | 269 | input_src = ''' 270 | # comment line1 271 | # comment line2 272 | # comment line3 273 | paddle.add(x=1, out=2) 274 | ''' 275 | expected_src = ''' 276 | # comment line1 277 | # comment line2 278 | # comment line3 279 | paddle.add(x_new=1, name=test) 280 | ''' 281 | self._run(self.change_spec, input_src, expected_src) 282 | 283 | input_src = ''' 284 | paddle.reassign(x=1) 285 | ''' 286 | expected_src = ''' 287 | paddle.reassign(x=2) 288 | ''' 289 | self._run(self.change_spec, input_src, expected_src) 290 | 291 | class TestWithRefactor(unittest.TestCase): 292 | change_spec = {} 293 | 294 | def _run(self, change_spec, input_src, expected_src): 295 | input_src = textwrap.dedent(input_src).strip() + '\n' 296 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 297 | output_src = _refactor_helper(refactor_with, input_src, change_spec) 298 | self.assertEqual(output_src, expected_src) 299 | 300 | def test_rename(self): 301 | input_src = ''' 302 | import paddle 303 | 304 | with paddle.fluid.dygraph.guard(place): 305 | pass 306 | 307 | with paddle.fluid.dygraph.guard(): 308 | pass 309 | ''' 310 | expected_src = ''' 311 | import paddle 312 | 313 | paddle.disable_static(place) 314 | pass 315 | paddle.enable_static() 316 | 317 | paddle.disable_static() 318 | pass 319 | paddle.enable_static() 320 | ''' 321 | self._run(self.change_spec, input_src, expected_src) 322 | 323 | input_src = ''' 324 | import paddle 325 | 326 | with fluid.dygraph.guard(place): 327 | pass 328 | pass 329 | 330 | with fluid.dygraph.guard(): 331 | pass 332 | pass 333 | ''' 334 | expected_src = ''' 335 | 
import paddle 336 | 337 | paddle.disable_static(place) 338 | pass 339 | pass 340 | paddle.enable_static() 341 | 342 | paddle.disable_static() 343 | pass 344 | pass 345 | paddle.enable_static() 346 | ''' 347 | self._run(self.change_spec, input_src, expected_src) 348 | 349 | input_src = ''' 350 | import paddle 351 | 352 | with dygraph.guard(place): 353 | pass 354 | pass 355 | pass 356 | 357 | with dygraph.guard(): 358 | pass 359 | pass 360 | pass 361 | ''' 362 | expected_src = ''' 363 | import paddle 364 | 365 | paddle.disable_static(place) 366 | pass 367 | pass 368 | pass 369 | paddle.enable_static() 370 | 371 | paddle.disable_static() 372 | pass 373 | pass 374 | pass 375 | paddle.enable_static() 376 | ''' 377 | self._run(self.change_spec, input_src, expected_src) 378 | 379 | input_src = ''' 380 | import paddle 381 | 382 | # comment line1 383 | with dygraph.guard(place): 384 | pass 385 | pass 386 | 387 | # comment line2 388 | # comment line3 389 | with dygraph.guard(): 390 | pass 391 | pass 392 | ''' 393 | expected_src = ''' 394 | import paddle 395 | 396 | # comment line1 397 | paddle.disable_static(place) 398 | pass 399 | pass 400 | paddle.enable_static() 401 | 402 | # comment line2 403 | # comment line3 404 | paddle.disable_static() 405 | pass 406 | pass 407 | paddle.enable_static() 408 | ''' 409 | self._run(self.change_spec, input_src, expected_src) 410 | 411 | input_src = ''' 412 | import paddle 413 | 414 | if True is True: 415 | pass 416 | if place is None: 417 | pass 418 | with paddle.fluid.dygraph.guard(): 419 | pass 420 | pass 421 | pass 422 | else: 423 | pass 424 | ''' 425 | expected_src = ''' 426 | import paddle 427 | 428 | if True is True: 429 | pass 430 | if place is None: 431 | pass 432 | paddle.disable_static() 433 | pass 434 | pass 435 | paddle.enable_static() 436 | pass 437 | else: 438 | pass 439 | ''' 440 | self._run(self.change_spec, input_src, expected_src) 441 | 442 | input_src = ''' 443 | import paddle 444 | if True is True: 445 | with fluid.dygraph.guard(): 446 | if True is True: 447 | pass 448 | 449 | pass 450 | pass 451 | ''' 452 | expected_src = ''' 453 | import paddle 454 | if True is True: 455 | paddle.disable_static() 456 | if True is True: 457 | pass 458 | paddle.enable_static() 459 | 460 | pass 461 | pass 462 | ''' 463 | self._run(self.change_spec, input_src, expected_src) 464 | 465 | 466 | class TestActTransformer(unittest.TestCase): 467 | maxDiff = None 468 | change_spec = { 469 | "paddle.Conv2D": { 470 | "args_change": [ 471 | [ "act", "" ], 472 | ], 473 | }, 474 | "paddle.elementwise_add": { 475 | "args_change": [ 476 | [ "act", "" ], 477 | ], 478 | }, 479 | } 480 | 481 | def _run(self, change_spec, input_src, expected_src): 482 | input_src = textwrap.dedent(input_src).strip() + '\n' 483 | expected_src = textwrap.dedent(expected_src).strip() + '\n' 484 | output_src = _refactor_helper(refactor_kwargs, input_src, change_spec) 485 | self.assertEqual(output_src, expected_src) 486 | 487 | def test_act_transformer(self): 488 | input_src = ''' 489 | import paddle 490 | 491 | visible_act = "relu" 492 | 493 | class SimpleImgConvPool(): 494 | def __init__(self, act=None): 495 | self._conv2d_1 = paddle.Conv2D(act=act) 496 | self._conv2d_2 = paddle.Conv2D(act="relu") 497 | self._conv2d_3 = paddle.Conv2D(act=None) 498 | 499 | def forward(self, x): 500 | x = self._conv2d_1(x) 501 | x = self._conv2d_2(x) 502 | x = self._conv2d_3(x) 503 | x = paddle.elementwise_add(x, act="softmax") 504 | x = paddle.elementwise_add(x, act=visible_act) 505 | return x 506 | ''' 507 | 
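        # Added note: expected_src below is the hand-derived result of the
        # change_spec above -- the act= keyword is dropped from each call,
        # and an explicit paddle.nn.functional.<activation> application is
        # inserted after the corresponding forward call.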
expected_src = ''' 508 | import paddle 509 | 510 | visible_act = "relu" 511 | 512 | class SimpleImgConvPool(): 513 | def __init__(self, act=None): 514 | self._conv2d_1 = paddle.Conv2D() 515 | self._act = act 516 | self._conv2d_2 = paddle.Conv2D() 517 | self._conv2d_3 = paddle.Conv2D() 518 | 519 | def forward(self, x): 520 | x = self._conv2d_1(x) 521 | x = getattr(paddle.nn.functional, self._act)(x) if self._act else x 522 | x = self._conv2d_2(x) 523 | x = paddle.nn.functional.relu(x) 524 | x = self._conv2d_3(x) 525 | x = paddle.elementwise_add(x) 526 | x = paddle.nn.functional.softmax(x) 527 | x = paddle.elementwise_add(x) 528 | x = getattr(paddle.nn.functional, visible_act)(x) if visible_act else x 529 | return x 530 | ''' 531 | self._run(self.change_spec, input_src, expected_src) 532 | 533 | input_src = ''' 534 | import paddle 535 | 536 | class SimpleImgConvPool(): 537 | def __init__(self): 538 | self._conv2d = paddle.Conv2D(act="relu") 539 | 540 | @decorator 541 | def forward(self, x): 542 | x = self._conv2d(x) 543 | return x 544 | ''' 545 | expected_src = ''' 546 | import paddle 547 | 548 | class SimpleImgConvPool(): 549 | def __init__(self): 550 | self._conv2d = paddle.Conv2D() 551 | 552 | @decorator 553 | def forward(self, x): 554 | x = self._conv2d(x) 555 | x = paddle.nn.functional.relu(x) 556 | return x 557 | ''' 558 | self._run(self.change_spec, input_src, expected_src) 559 | 560 | input_src = ''' 561 | import paddle 562 | 563 | global_x = None 564 | 565 | class SimpleImgConvPool(): 566 | def __init__(self): 567 | self._conv2d = paddle.Conv2D(act="softmax") 568 | 569 | @decorator 570 | def forward(self): 571 | x = self._conv2d(global_x) 572 | return x 573 | ''' 574 | expected_src = ''' 575 | import paddle 576 | 577 | global_x = None 578 | 579 | class SimpleImgConvPool(): 580 | def __init__(self): 581 | self._conv2d = paddle.Conv2D() 582 | 583 | @decorator 584 | def forward(self): 585 | x = self._conv2d(global_x) 586 | x = paddle.nn.functional.softmax(x) 587 | return x 588 | ''' 589 | self._run(self.change_spec, input_src, expected_src) 590 | 591 | 592 | if __name__ == '__main__': 593 | unittest.main() 594 | --------------------------------------------------------------------------------