├── .github ├── dependabot.yml └── workflows │ └── check.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pyproject.toml ├── pyrefact ├── __init__.py ├── __main__.py ├── abstractions.py ├── constants.py ├── core.py ├── fixes.py ├── formatting.py ├── logs.py ├── main.py ├── object_oriented.py ├── parsing.py ├── pattern_matching.py ├── performance.py ├── performance_numpy.py ├── performance_pandas.py ├── processing.py ├── style.py ├── symbolic_math.py └── tracing.py └── tests ├── integration ├── integration_test_cases.py ├── test_format_code.py ├── test_format_code_safe.py ├── test_format_file.py ├── test_imports.py ├── test_tracing.py └── tracing_test_files │ ├── a.py │ ├── b.py │ ├── c.py │ ├── d.py │ └── e.py ├── main.py ├── main_profile.py ├── numpy.sh ├── testing_infra.py └── unit ├── test_abstractions.py ├── test_add_missing_imports.py ├── test_align_variable-names_with_convention.py ├── test_breakout_common_code_in_ifs.py ├── test_breakout_starred_args.py ├── test_deinterpolate_logging_args.py ├── test_delete_commented_code.py ├── test_delete_pointless_statements.py ├── test_delete_unreachable_code.py ├── test_delete_unused_functions_and_classes.py ├── test_early_continue.py ├── test_early_return.py ├── test_fix_duplicate_imports.py ├── test_fix_if_assign.py ├── test_fix_if_return.py ├── test_fix_import_spacing.py ├── test_fix_raise_missing_from.py ├── test_fix_reimported_names.py ├── test_fix_starred_imports.py ├── test_fix_unconventional_class_definitions.py ├── test_has_side_effect.py ├── test_hash_node.py ├── test_ignore_comments.py ├── test_implicit_defaultdict.py ├── test_implicit_dict_keys_values_items.py ├── test_implicit_dot.py ├── test_implicit_matmul.py ├── test_inline_math_comprehensions.py ├── test_invalid_escape_sequence.py ├── test_is_blocking.py ├── test_literal_value.py ├── test_match_template.py ├── test_merge_chained_comps.py ├── test_merge_nested_comprehensions.py ├── test_missing_context_manager.py ├── 
test_move_before_loop.py ├── test_move_imports_to_toplevel.py ├── test_optimize_contains_types.py ├── test_overused_constant.py ├── test_pattern_matching.py ├── test_pattern_zeroormore_zeroorone_zeroormany.py ├── test_redundant_elses.py ├── test_redundant_enumerate.py ├── test_remove_dead_ifs.py ├── test_remove_duplicate_dict_keys.py ├── test_remove_duplicate_functions.py ├── test_remove_duplicate_set_elts.py ├── test_remove_redundant_boolop_values.py ├── test_remove_redundant_chain_casts.py ├── test_remove_redundant_chained_calls.py ├── test_remove_redundant_comprehension_casts.py ├── test_remove_redundant_comprehensions.py ├── test_remove_redundant_iter.py ├── test_remove_unused_imports.py ├── test_replace_collection_add_update_with_collection_literal.py ├── test_replace_dict_assign_with_dict_literal.py ├── test_replace_dict_update_with_dict_literal.py ├── test_replace_dictcomp_assign_with_dict_literal.py ├── test_replace_dictcomp_update_with_dict_literal.py ├── test_replace_filter_lambda_with_comp.py ├── test_replace_for_loops_with_dict_comp.py ├── test_replace_for_loops_with_set_list_comp.py ├── test_replace_functions_with_literals.py ├── test_replace_iterrows_index.py ├── test_replace_iterrows_itertuples.py ├── test_replace_listcomp_append_with_plus.py ├── test_replace_loc_at_iloc_iat.py ├── test_replace_map_lambda_with_comp.py ├── test_replace_negated_numeric_comparison.py ├── test_replace_nested_loops_with_set_list_comp.py ├── test_replace_redundant_starred.py ├── test_replace_setcomp_add_with_union.py ├── test_replace_subscript_looping.py ├── test_replace_with_filter.py ├── test_simplify_assign_immediate_return.py ├── test_simplify_boolean_expressions.py ├── test_simplify_boolean_expressions_symmath.py ├── test_simplify_collection_unpacks.py ├── test_simplify_constrained_range.py ├── test_simplify_dict_unpacks.py ├── test_simplify_if_control_flow.py ├── test_simplify_math_iterators.py ├── test_simplify_matrix_operations.py ├── 
test_simplify_redundant_lambda.py ├── test_simplify_transposes.py ├── test_singleton_comparison.py ├── test_sort_imports.py ├── test_sorted_heapq.py ├── test_swap_if_else.py ├── test_trace_origin.py ├── test_undefine_unused_variables.py ├── test_unravel_classes.py └── test_unused_zip_args.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Dependabot generates periodic pull requests with version updates for dependencies. 2 | 3 | version: 2 4 | updates: 5 | - package-ecosystem: "pip" # See documentation for possible values 6 | directory: "/" # Location of package manifests 7 | schedule: 8 | interval: "weekly" 9 | 10 | - package-ecosystem: "github-actions" 11 | # Workflow files stored in the 12 | # default location of `.github/workflows` 13 | directory: "/" 14 | schedule: 15 | interval: "weekly" 16 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | ruff-lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.13 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.13 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install -e .[dev] 18 | - name: Run ruff lint 19 | run: | 20 | ruff check ./pyrefact 21 | 22 | check: 23 | needs: [ruff-lint] 24 | runs-on: ubuntu-latest 25 | strategy: 26 | matrix: 27 | python-version: ["3.9", "3.10", "3.11", "3.12", "pypy3.10"] 28 | steps: 29 | - uses: actions/checkout@v4 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install . 
38 | - name: Running unit tests 39 | run: | 40 | python ./tests/main.py 41 | - name: Running on self with --safe, -s, --preserve . and variants thereof 42 | run: | 43 | pyrefact . --safe --n_cores 1 44 | python -m pyrefact . -s 45 | python -m pyrefact . --preserve . --n_cores 10 46 | pyrefact . -p . 47 | pyrefact . --safe --preserve . 48 | python -m pyrefact . -sp . 49 | cat ./pyrefact/main.py | python -m pyrefact --from-stdin 50 | - name: Install as editable 51 | run: | 52 | pip install -e . 53 | - name: Rerun unit tests 54 | run: | 55 | python ./tests/main.py 56 | 57 | check-macos: 58 | needs: [ruff-lint] 59 | runs-on: macos-latest 60 | strategy: 61 | matrix: 62 | python-version: ["3.12"] 63 | steps: 64 | - uses: actions/checkout@v4 65 | - name: Set up Python ${{ matrix.python-version }} 66 | uses: actions/setup-python@v5 67 | with: 68 | python-version: ${{ matrix.python-version }} 69 | - name: Install dependencies 70 | run: | 71 | python -m pip install --upgrade pip 72 | pip install . 73 | - name: Running unit tests 74 | run: python ./tests/main.py 75 | - name: Run on self 76 | run: pyrefact . --safe --preserve . 77 | - name: Rerun unit tests 78 | run: python ./tests/main.py 79 | 80 | check-windows: 81 | needs: [ruff-lint] 82 | runs-on: macos-latest 83 | strategy: 84 | matrix: 85 | python-version: ["3.12"] 86 | steps: 87 | - uses: actions/checkout@v4 88 | - name: Set up Python ${{ matrix.python-version }} 89 | uses: actions/setup-python@v5 90 | with: 91 | python-version: ${{ matrix.python-version }} 92 | - name: Install dependencies 93 | run: | 94 | python -m pip install --upgrade pip 95 | pip install . 96 | - name: Running unit tests 97 | run: python ./tests/main.py 98 | - name: Run on self 99 | run: pyrefact . --safe --preserve . 
100 | - name: Rerun unit tests 101 | run: python ./tests/main.py 102 | 103 | check-slow: 104 | needs: [ruff-lint] 105 | runs-on: ubuntu-latest 106 | strategy: 107 | matrix: 108 | python-version: ["3.12"] 109 | steps: 110 | - uses: actions/checkout@v4 111 | - name: Set up Python ${{ matrix.python-version }} 112 | uses: actions/setup-python@v5 113 | with: 114 | python-version: ${{ matrix.python-version }} 115 | - name: Install dependencies 116 | run: | 117 | python -m pip install --upgrade pip 118 | pip install . 119 | - name: Formatting numpy repo 120 | run: | 121 | ./tests/numpy.sh 122 | 123 | deploy: 124 | environment: deploy 125 | needs: [ruff-lint, check, check-macos, check-windows] 126 | if: github.ref == 'refs/heads/main' 127 | runs-on: ubuntu-latest 128 | steps: 129 | - uses: actions/checkout@v4 130 | - uses: technote-space/get-diff-action@v6 131 | with: 132 | FILES: | 133 | pyproject.toml 134 | - name: Set up Python 135 | uses: actions/setup-python@v5 136 | with: 137 | python-version: '3.12' 138 | - name: Install dependencies 139 | run: | 140 | python -m pip install --upgrade pip 141 | pip install build 142 | - name: Build package 143 | run: python -m build 144 | - name: Publish package 145 | if: env.GIT_DIFF && env.MATCHED_FILES 146 | uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 147 | with: 148 | user: __token__ 149 | password: ${{ secrets.PYPI_DEPLOY_TOKEN_MAINONLY }} 150 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | dist 4 | .vscode 5 | build 6 | *.pstats 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OlleLindgren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pyrefact" 7 | version = "100" 8 | description = "Automated Python refactoring" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {email = "olle.ln@outlook.com"}, 14 | {name = "Olle Lindgren"} 15 | ] 16 | classifiers = ["Programming Language :: Python :: 3"] 17 | dependencies = [ 18 | "black>=23.1.0", 19 | "compactify>=2", 20 | "rmspace>=7", 21 | "sympy>=1.11.0", 22 | 'tomli>=2.0.0; python_version<"3.11"', 23 | ] 24 | 25 | [project.optional-dependencies] 26 | dev = ["ruff==0.8.1"] 27 | 28 | [project.urls] 29 | repository="https://github.com/OlleLindgren/pyrefact" 30 | 31 | [project.scripts] 32 | pyrefact = "pyrefact.__main__:main" 33 | pyrefind = "pyrefact.pattern_matching:pyrefind_main" 34 | pyreplace = "pyrefact.pattern_matching:pyreplace_main" 35 | 36 | [tool.setuptools.packages.find] 37 | include = ["pyrefact*"] 38 | 39 | [tool.black] 40 | skip_magic_trailing_comma = true 41 | line_length = 100 42 | 43 | [tool.pyrefact] 44 | line_length = 100 45 | 46 | [tool.ruff.lint.per-file-ignores] 47 | "tests/unit/test_literal_value.py" = ["F403", "E711", "E712"] 48 | "tests/unit/test_trace_origin.py" = ["F403"] 49 | "tests/integration/tracing_test_files/a.py" = ["F401"] 50 | "tests/integration/tracing_test_files/b.py" = ["F403"] 51 | "tests/integration/tracing_test_files/c.py" = ["F401"] 52 | "tests/integration/tracing_test_files/d.py" = ["F401", "E402"] 53 | "tests/integration/tracing_test_files/e.py" = ["F821"] 54 | -------------------------------------------------------------------------------- /pyrefact/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .main import * # noqa: F403 2 | from .pattern_matching import * # noqa: F403 3 | from .processing import * # noqa: F403 4 | -------------------------------------------------------------------------------- /pyrefact/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | 4 | from pyrefact.main import main 5 | 6 | if __name__ == "__main__": 7 | sys.exit(main(sys.argv[1:])) 8 | -------------------------------------------------------------------------------- /pyrefact/formatting.py: -------------------------------------------------------------------------------- 1 | """Code related to formatting""" 2 | 3 | import textwrap 4 | 5 | import black 6 | import compactify 7 | from pyrefact import logs as logger 8 | 9 | 10 | def format_with_black(source: str, *, line_length: int = 100) -> str: 11 | """Format code with black. 12 | 13 | Args: 14 | source (str): Python source code 15 | 16 | Returns: 17 | str: Formatted source code. 18 | """ 19 | original_source = source 20 | indent = indentation_level(source) 21 | if indent > 0: 22 | source = textwrap.dedent(source) 23 | 24 | try: 25 | source = black.format_str( 26 | source, mode=black.Mode(line_length=max(60, line_length - indent)) 27 | ) 28 | except (SyntaxError, black.parsing.InvalidInput): 29 | logger.error("Black raised InvalidInput on code:\n{}", source) 30 | return original_source 31 | 32 | if indent > 0: 33 | source = textwrap.indent(source, " " * indent) 34 | 35 | return source 36 | 37 | 38 | def collapse_trailing_parentheses(source: str) -> str: 39 | """Collapse trailing ])} together. 40 | 41 | Args: 42 | source (str): _description_ 43 | 44 | Returns: 45 | str: _description_ 46 | """ 47 | return compactify.format_code(source) 48 | 49 | 50 | def _inspect_indentsize(line: str) -> int: 51 | """Return the indent size, in spaces, at the start of a line of text. 
52 | 53 | This function is the same as the undocumented inspect.indentsize() function in the stdlib. 54 | For stability, we copy the code here rather than depending on the undocumented stdlib function. 55 | """ 56 | expline = line.expandtabs() 57 | return len(expline) - len(expline.lstrip()) 58 | 59 | 60 | def indentation_level(source: str) -> int: 61 | """Return the indentation level of source code.""" 62 | return min( 63 | (_inspect_indentsize(line) for line in source.splitlines() if line.strip()), default=0 64 | ) 65 | -------------------------------------------------------------------------------- /pyrefact/logs.py: -------------------------------------------------------------------------------- 1 | """Logging""" 2 | 3 | import functools 4 | import logging 5 | 6 | 7 | class _Message: 8 | def __init__(self, /, fmt, *args, **kwargs): 9 | self.fmt = fmt 10 | self.args = args 11 | self.kwargs = kwargs 12 | 13 | def __str__(self) -> str: 14 | return self.fmt.format(*self.args, **self.kwargs) 15 | 16 | 17 | @functools.lru_cache(maxsize=1) 18 | def _get_logger() -> logging.Logger: 19 | # Why is this so complicated 20 | 21 | logger = logging.getLogger("pyrefact") 22 | handler = logging.StreamHandler() 23 | formatter = logging.Formatter("%(message)s") 24 | handler.setFormatter(formatter) 25 | logger.addHandler(handler) 26 | logger.setLevel(logging.INFO) 27 | 28 | return logger 29 | 30 | 31 | def info(fmt, /, *args, **kwargs): 32 | return _get_logger().info(_Message(fmt, *args, **kwargs)) 33 | 34 | 35 | def debug(fmt, /, *args, **kwargs): 36 | return _get_logger().debug(_Message(fmt, *args, **kwargs)) 37 | 38 | 39 | def error(fmt, /, *args, **kwargs): 40 | return _get_logger().error(_Message(fmt, *args, **kwargs)) 41 | 42 | 43 | def set_level(level: int) -> None: 44 | _get_logger().setLevel(level) 45 | -------------------------------------------------------------------------------- /pyrefact/style.py: 
-------------------------------------------------------------------------------- 1 | """Code relating to coding style""" 2 | 3 | import re 4 | from typing import Sequence 5 | 6 | from pyrefact import parsing 7 | 8 | 9 | def _list_words(name: str) -> Sequence[str]: 10 | return [ 11 | match.group() 12 | for match in re.finditer(r"([A-Z]{2,}(?![a-z])|[A-Z]?[a-z]*)\d*", name) 13 | if match.end() > match.start() 14 | ] 15 | 16 | 17 | def _make_snakecase(name: str, *, uppercase: bool = False) -> str: 18 | return "_".join(word.upper() if uppercase else word.lower() for word in _list_words(name)) 19 | 20 | 21 | def _make_camelcase(name: str) -> str: 22 | return "".join(word[0].upper() + word[1:].lower() for word in _list_words(name)) 23 | 24 | 25 | def rename_class(name: str, *, private: bool) -> str: 26 | name = re.sub("_{1,}", "_", name) 27 | if len(name) == 0: 28 | raise ValueError("Cannot rename empty name") 29 | 30 | name = _make_camelcase(name) 31 | 32 | if private and not parsing.is_private(name): 33 | return f"_{name}" 34 | if not private and parsing.is_private(name): 35 | return name[1:] 36 | 37 | return name 38 | 39 | 40 | def rename_variable(variable: str, *, static: bool, private: bool) -> str: 41 | if variable == "_": 42 | return variable 43 | 44 | if variable.startswith("__") and variable.endswith("__"): 45 | return variable 46 | 47 | renamed_variable = _make_snakecase(variable, uppercase=static) 48 | 49 | if private and not parsing.is_private(renamed_variable): 50 | renamed_variable = f"_{renamed_variable}" 51 | if not private and parsing.is_private(renamed_variable): 52 | renamed_variable = renamed_variable.lstrip("_") 53 | 54 | if renamed_variable: 55 | return renamed_variable 56 | 57 | raise RuntimeError(f"Unable to find a replacement name for {variable}") 58 | -------------------------------------------------------------------------------- /tests/integration/test_format_code.py: 
#!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | import pyrefact 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) # make tests/testing_infra.py importable 9 | import testing_infra 10 | from integration_test_cases import INTEGRATION_TEST_CASES 11 | 12 | 13 | def main() -> int: 14 | for source, expected_abstraction in INTEGRATION_TEST_CASES: # same cases as test_format_file.py 15 | processed_content = pyrefact.format_code(source) # no safe=True here; test_format_code_safe.py covers that 16 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 17 | return 1 18 | 19 | return 0 20 | 21 | 22 | if __name__ == "__main__": 23 | sys.exit(main()) 24 | -------------------------------------------------------------------------------- /tests/integration/test_format_code_safe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | import pyrefact 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) # make tests/testing_infra.py importable 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( # (source, expected output of pyrefact.format_code(source, safe=True)) pairs 14 | """ 15 | def q() -> None: 16 | print(1) 17 | Spam.weeee() 18 | 19 | class Spam: 20 | @staticmethod 21 | def weeee(): 22 | print(3) 23 | 24 | "Very important string statement" 25 | 26 | class Foo: 27 | def __init__(self): 28 | self.x = 2 29 | 30 | @staticmethod 31 | def some_static_function(x, y) -> int: 32 | return 2 + x + y 33 | 34 | @staticmethod 35 | def some_other_static_function(): 36 | print(3) 37 | """, 38 | """ 39 | def q() -> None: 40 | print(1) 41 | Spam.weeee() 42 | 43 | class Spam: 44 | @staticmethod 45 | def weeee(): 46 | print(3) 47 | 48 | class Foo: 49 | def __init__(self): 50 | self.x = 2 51 | 52 | @staticmethod 53 | def some_static_function(x, y) -> int: 54 | return 2 + x + y 55 | 56 | @staticmethod 57 | def some_other_static_function(): 58 | print(3) 59 | """, 60 | ), 61 | ( 62 | """ 63 | def asdf(): 64 | x = None 65 | if 2 in {1, 2, 3}: 66 | print(3) 67 |
class Foo: 68 | @staticmethod 69 | def asdf(): 70 | x = None 71 | if 2 in {1, 2, 3}: 72 | y = x is not None 73 | z = y or not y 74 | print(3) 75 | """, 76 | """ 77 | def asdf(): 78 | print(3) 79 | class Foo: 80 | @staticmethod 81 | def asdf(): 82 | print(3) 83 | """, 84 | ), 85 | ( 86 | """ 87 | class TestSomeStuff(unittest.TestCase): 88 | def test_important_stuff(self): 89 | assert 1 == 3 90 | @classmethod 91 | def test_important_stuff2(cls): 92 | assert 1 == 3 93 | def test_nonsense(self): 94 | self.assertEqual(1, 3) 95 | """, 96 | """ 97 | import unittest 98 | class TestSomeStuff(unittest.TestCase): 99 | @staticmethod 100 | def test_important_stuff(): 101 | assert False 102 | @staticmethod 103 | def test_important_stuff2(): 104 | assert False 105 | def test_nonsense(self): 106 | self.assertEqual(1, 3) 107 | """, 108 | ), 109 | ( 110 | ''' 111 | def foo() -> int: 112 | """This seems useless, but pyrefact shouldn't remove it with --safe""" 113 | return 10 114 | ''', 115 | ''' 116 | def foo() -> int: 117 | """This seems useless, but pyrefact shouldn't remove it with --safe""" 118 | return 10 119 | ''', 120 | ),) 121 | 122 | for source, expected_abstraction in test_cases: 123 | processed_content = pyrefact.format_code(source, safe=True) 124 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 125 | return 1 126 | 127 | return 0 128 | 129 | 130 | if __name__ == "__main__": 131 | sys.exit(main()) 132 | -------------------------------------------------------------------------------- /tests/integration/test_format_file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import pyrefact 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) # make tests/testing_infra.py importable 9 | import testing_infra 10 | from integration_test_cases import INTEGRATION_TEST_CASES 11 | 12 | 13 | def main() -> int: 14 | for source, expected_abstraction in
INTEGRATION_TEST_CASES: 15 | with tempfile.NamedTemporaryFile() as temp: 16 | temp = temp.name 17 | with open(temp, "w", encoding="utf-8") as stream: 18 | stream.write(source) 19 | 20 | pyrefact.format_file(temp) 21 | 22 | with open(temp, "r", encoding="utf-8") as stream: 23 | processed_content = stream.read() 24 | 25 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 26 | return 1 27 | 28 | return 0 29 | 30 | 31 | if __name__ == "__main__": 32 | sys.exit(main()) 33 | -------------------------------------------------------------------------------- /tests/integration/test_imports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import unittest 3 | 4 | 5 | class TestImports(unittest.TestCase): 6 | def test_main_imports(self): 7 | import pyrefact 8 | 9 | # pyrefact.main 10 | assert callable(pyrefact.main) 11 | assert callable(pyrefact.format_code) 12 | assert callable(pyrefact.format_file) 13 | assert callable(pyrefact.format_files) 14 | 15 | # pyrefact.pattern_matching 16 | assert callable(pyrefact.compile) 17 | assert callable(pyrefact.findall) 18 | assert callable(pyrefact.finditer) 19 | assert callable(pyrefact.search) 20 | assert callable(pyrefact.match) 21 | assert callable(pyrefact.fullmatch) 22 | assert callable(pyrefact.sub) 23 | assert callable(pyrefact.subn) 24 | 25 | def runTest(self): 26 | self.test_main_imports() 27 | 28 | 29 | def main() -> int: 30 | test_result = TestImports().run() 31 | 32 | if not test_result.wasSuccessful(): 33 | print("FAILED") 34 | return 1 35 | 36 | print("PASSED") 37 | return 0 38 | 39 | 40 | if __name__ == "__main__": 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /tests/integration/tracing_test_files/a.py: -------------------------------------------------------------------------------- 1 | # pyrefact: skip_file 2 | 3 | from c import z 4 | from b import x as k 5 | from d 
import sys 6 | from e import hh 7 | -------------------------------------------------------------------------------- /tests/integration/tracing_test_files/b.py: -------------------------------------------------------------------------------- 1 | # pyrefact: skip_file 2 | 3 | from c import * 4 | from e import * 5 | -------------------------------------------------------------------------------- /tests/integration/tracing_test_files/c.py: -------------------------------------------------------------------------------- 1 | # pyrefact: skip_file 2 | 3 | from d import x, y as z 4 | -------------------------------------------------------------------------------- /tests/integration/tracing_test_files/d.py: -------------------------------------------------------------------------------- 1 | # pyrefact: skip_file 2 | 3 | x = 1 4 | y = 100 5 | 6 | import sys 7 | 8 | import e 9 | -------------------------------------------------------------------------------- /tests/integration/tracing_test_files/e.py: -------------------------------------------------------------------------------- 1 | # pyrefact: skip_file 2 | 3 | ww = 10 4 | 5 | hh = aabb # yes, this is intentional 6 | -------------------------------------------------------------------------------- /tests/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Main script for running all unit and integration tests.""" 3 | import argparse 4 | import itertools 5 | import logging 6 | import sys 7 | import traceback 8 | from pathlib import Path 9 | from typing import Sequence 10 | 11 | sys.path.append(str(Path(__file__).parent)) # test modules are imported by bare name below 12 | sys.path.append(str(Path(__file__).parent / "unit")) 13 | sys.path.append(str(Path(__file__).parent / "integration")) 14 | 15 | import testing_infra 16 | 17 | from pyrefact import logs as logger 18 | 19 | logger.set_level(logging.DEBUG) 20 | 21 | 22 | def _parse_args(args: Sequence[str]) -> argparse.Namespace: 23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--verbose", action="store_true") 25 | return parser.parse_args(args) 26 | 27 | 28 | def main(args: Sequence[str]) -> int: 29 | """Run the main() of every unit and integration test module and report the results. 30 | 31 | Returns: 32 | int: 0 if successful, otherwise 1. 33 | """ 34 | args = _parse_args(args) 35 | return_codes = {} 36 | unit_tests = testing_infra.iter_unit_tests() 37 | integration_tests = testing_infra.iter_integration_tests() 38 | 39 | for filename in itertools.chain(unit_tests, integration_tests): 40 | module = __import__(filename.stem) # NOTE(review): import errors are not caught below and abort the whole run 41 | relpath = str(filename.absolute().relative_to(Path.cwd())) # NOTE(review): raises ValueError if the cwd is not an ancestor of the test file 42 | try: 43 | return_codes[relpath] = module.main() 44 | except Exception as error: 45 | if args.verbose: 46 | return_codes[relpath] = "".join( 47 | traceback.format_exception(type(error), error, error.__traceback__) 48 | ) 49 | else: 50 | return_codes[relpath] = error 51 | 52 | if not set(return_codes.values()) - {0}: # every test module returned 0 53 | print("PASSED") 54 | return 0 55 | 56 | print("Some tests failed") 57 | print(f"{'Test path':<50} Return code") 58 | for test, return_code in return_codes.items(): 59 | if return_code != 0: 60 | print(f"./{test:<50} {return_code}") 61 | print("FAILED") 62 | return 1 63 | 64 | 65 | if __name__ == "__main__": 66 | sys.exit(main(sys.argv[1:])) 67 | -------------------------------------------------------------------------------- /tests/main_profile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script for profiling pyrefact on a shallow clone of the numpy repository.""" 3 | 4 | import cProfile 5 | import pstats 6 | import subprocess 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | import pyrefact 12 | 13 | 14 | def main() -> int: 15 | """Profile pyrefact on a fresh shallow clone of numpy, saving the stats to ./pyrefact_profiling.pstats. 16 | 17 | Returns: 18 | int: Always returns 0.
19 | """ 20 | out_filename = Path.cwd() / "pyrefact_profiling.pstats" 21 | with tempfile.TemporaryDirectory() as tmpdir: 22 | subprocess.check_call( 23 | ["git", "clone", "--depth=1", "https://github.com/numpy/numpy.git", tmpdir] 24 | ) 25 | 26 | with cProfile.Profile() as profile: 27 | try: 28 | pyrefact.main([tmpdir]) 29 | finally: 30 | with open(out_filename, "w") as stream: 31 | stats = pstats.Stats(profile, stream=stream).sort_stats( 32 | pstats.SortKey.CUMULATIVE 33 | ) 34 | stats.dump_stats(out_filename) 35 | print(f"Saved profiling to {out_filename}") 36 | return 0 37 | 38 | 39 | if __name__ == "__main__": 40 | sys.exit(main()) 41 | -------------------------------------------------------------------------------- /tests/numpy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Clone the numpy git repo and try to rewrite all python files in it 3 | # TODO fix the commit hash but keep shallow commit somehow 4 | 5 | mkdir -p ~/.cache/pyrefact/tests 6 | cd ~/.cache/pyrefact/tests 7 | rm -rf numpy 8 | git clone --depth=1 https://github.com/numpy/numpy.git 9 | cd numpy 10 | pyrefact -sv $(git ls-files | grep 'numpy\/core\/\w*\.py$') 11 | -------------------------------------------------------------------------------- /tests/testing_infra.py: -------------------------------------------------------------------------------- 1 | """General convenience functions used in tests.""" 2 | import itertools 3 | import re 4 | import sys 5 | from pathlib import Path 6 | from typing import Iterable 7 | 8 | from pyrefact import formatting 9 | 10 | 11 | def _remove_multi_whitespace(source: str) -> str: 12 | source = re.sub(r"(? 
str: 21 | proc_lines = processed_content.splitlines() or [""] 22 | exp_lines = expected_content.splitlines() or [""] 23 | length = max(max(map(len, proc_lines)), max(map(len, exp_lines))) 24 | diff_view = [ 25 | f"{p.ljust(length, ' ')} {'=' if p==e else '!'} {e}\n" 26 | for p, e in itertools.zip_longest(proc_lines, exp_lines, fillvalue="") 27 | ] 28 | return "".join(diff_view) 29 | 30 | 31 | def check_fixes_equal( 32 | processed_content: str, 33 | expected_abstraction: str, 34 | clear_paranthesises=False, 35 | clear_whitespace=True, 36 | ) -> int: 37 | if clear_whitespace: 38 | processed_content = _remove_multi_whitespace(processed_content) 39 | expected_abstraction = _remove_multi_whitespace(expected_abstraction) 40 | 41 | if tuple(sys.version_info) < (3, 9): 42 | processed_content = formatting.format_with_black(processed_content) 43 | processed_content = formatting.collapse_trailing_parentheses(processed_content) 44 | expected_abstraction = formatting.format_with_black(expected_abstraction) 45 | expected_abstraction = formatting.collapse_trailing_parentheses(expected_abstraction) 46 | 47 | diff_view = _create_diff_view(processed_content, expected_abstraction) 48 | 49 | if tuple(sys.version_info) < (3, 9): 50 | processed_content = re.sub(r"[()]", "", processed_content) 51 | expected_abstraction = re.sub(r"[()]", "", expected_abstraction) 52 | 53 | if clear_paranthesises: 54 | processed_content = re.sub(r"[\(\)]", "", processed_content) 55 | expected_abstraction = re.sub(r"[\(\)]", "", expected_abstraction) 56 | 57 | if processed_content != expected_abstraction: 58 | print(diff_view) 59 | return False 60 | 61 | return True 62 | 63 | 64 | def iter_unit_tests() -> Iterable[Path]: 65 | """Iterate over all unit test files""" 66 | return sorted((Path(__file__).parent / "unit").rglob("test_*.py")) 67 | 68 | 69 | def iter_integration_tests() -> Iterable[Path]: 70 | return sorted((Path(__file__).parent / "integration").rglob("test_*.py")) 71 | 72 | 73 | def 
ignore_on_version(major: int, minor: int): 74 | if (major, minor) == sys.version_info[:2]: 75 | return lambda before, after: ("", "") 76 | 77 | return lambda before, after: (before, after) 78 | -------------------------------------------------------------------------------- /tests/unit/test_add_missing_imports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | sys.path.append(os.getcwd()) 16 | """, 17 | """ 18 | import os 19 | import sys 20 | 21 | sys.path.append(os.getcwd()) 22 | """, 23 | ), 24 | ( 25 | """ 26 | functools.reduce(lambda x: x+y, [1, 2, 3]) 27 | """, 28 | """ 29 | import functools 30 | 31 | functools.reduce(lambda x: x+y, [1, 2, 3]) 32 | """, 33 | ), 34 | ( 35 | """ 36 | import scipy.stats 37 | 38 | a, b = 1.25, 0.5 39 | mean, var, skew, kurt = scipy.stats.norminvgauss.stats(a, b, moments='mvsk') 40 | """, 41 | """ 42 | import scipy.stats 43 | 44 | a, b = 1.25, 0.5 45 | mean, var, skew, kurt = scipy.stats.norminvgauss.stats(a, b, moments='mvsk') 46 | """, 47 | ), 48 | ( 49 | """ 50 | print(wierdo_library.strange_function()) 51 | """, 52 | """ 53 | print(wierdo_library.strange_function()) 54 | """, 55 | ), 56 | ( 57 | """ 58 | x = np.array() 59 | z = pd.DataFrame() 60 | """, 61 | """ 62 | import numpy as np 63 | import pandas as pd 64 | 65 | 66 | x = np.array() 67 | z = pd.DataFrame() 68 | """, 69 | ), 70 | ( 71 | """ 72 | w = numpy.zeros(10, dtype=numpy.float32) 73 | """, 74 | """ 75 | import numpy 76 | 77 | w = numpy.zeros(10, dtype=numpy.float32) 78 | """, 79 | ), 80 | ( 81 | """ 82 | from typing import ( 83 | Sequence, 84 | Tuple, 85 | 86 | TypeVar, 87 | ) 88 | y = Iterable 89 | """, 90 | """ 91 | from typing import Iterable 92 | from typing import ( 93 | 
Sequence, 94 | Tuple, 95 | TypeVar, 96 | ) 97 | y = Iterable 98 | """, 99 | ),) 100 | 101 | for source, expected_abstraction in test_cases: 102 | processed_content = fixes.add_missing_imports(source) 103 | processed_content = fixes.sort_imports(processed_content) # sort, or the order of the added imports would be random 104 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 105 | return 1 106 | 107 | return 0 108 | 109 | 110 | if __name__ == "__main__": 111 | sys.exit(main()) 112 | -------------------------------------------------------------------------------- /tests/unit/test_breakout_starred_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Returns 0 if every case matches, 1 at the first mismatch 13 | test_cases = (( 14 | """ 15 | x = foo(a, b, *(c, d), e, *{f}, *[k, v, h]) 16 | """, 17 | """ 18 | x = foo(a, b, c, d, e, f, k, v, h) 19 | """, 20 | ), 21 | ( 22 | """ 23 | x = foo(*(1, 2)) 24 | """, 25 | """ 26 | x = foo(1, 2) 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = foo(*[1, 2]) 32 | """, 33 | """ 34 | x = foo(1, 2) 35 | """, 36 | ), 37 | ( # A set literal with more than one element is not unpacked (presumably because set iteration order is not guaranteed) 38 | """ 39 | x = foo(*{1, 2}) 40 | """, 41 | """ 42 | x = foo(*{1, 2}) 43 | """, 44 | ), 45 | ( 46 | """ 47 | x = foo(*(1,)) 48 | """, 49 | """ 50 | x = foo(1) 51 | """, 52 | ), 53 | ( 54 | """ 55 | x = foo(*[1]) 56 | """, 57 | """ 58 | x = foo(1) 59 | """, 60 | ), 61 | ( 62 | """ 63 | x = foo(*{1}) 64 | """, 65 | """ 66 | x = foo(1) 67 | """, 68 | ),) 69 | 70 | for source, expected_abstraction in test_cases: 71 | processed_content = fixes.breakout_starred_args(source) 72 | if not testing_infra.check_fixes_equal( 73 | processed_content, expected_abstraction, clear_paranthesises=True 74 | ): 75 | return 1 76 | 77 | return 0 78 | 79 | 80 | if __name__ == "__main__": 81 |
sys.exit(main()) 82 | -------------------------------------------------------------------------------- /tests/unit/test_deinterpolate_logging_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # NOTE(review): the expected rewrites keep {}-style templates with separate args (and kwargs such as value=1337); stdlib logging formats msg % args and rejects unknown kwargs, so confirm which logger API these outputs target 13 | test_cases = (( 14 | """ 15 | logging.info("interesting information: {value}".format(value=1337)) 16 | """, 17 | """ 18 | logging.info("interesting information: {value}", value=1337) 19 | """, 20 | ), 21 | ( 22 | """ 23 | logging.critical("interesting information: {}, {}, {}".format(13, 14, 15)) 24 | """, 25 | """ 26 | logging.critical("interesting information: {}, {}, {}", 13, 14, 15) 27 | """, 28 | ), 29 | ( 30 | """ 31 | logging.log(logging.INFO, 'interesting information: {}, {}, {}'.format(13, 14, 15)) 32 | """, 33 | """ 34 | logging.log(logging.INFO, 'interesting information: {}, {}, {}', 13, 14, 15) 35 | """, 36 | ), 37 | ( 38 | """ 39 | logging.log(25, f'interesting information: {value}') 40 | """, 41 | """ 42 | logging.log(25, 'interesting information: {}', value) 43 | """, 44 | ), 45 | ( 46 | """ 47 | logging.warning(f"interesting information: {value}") 48 | """, 49 | """ 50 | logging.warning('interesting information: {}', value) 51 | """, 52 | ), 53 | ( 54 | """ 55 | logging.error(f'interesting information: {value}') 56 | """, 57 | """ 58 | logging.error('interesting information: {}', value) 59 | """, 60 | ), 61 | ( # Left unchanged: an f-string combined with additional positional args is too complex to rewrite 62 | """ 63 | logging.error(f"interesting information: {value}", more_args) 64 | """, 65 | """ 66 | logging.error(f"interesting information: {value}", more_args) 67 | """, 68 | ), 69 | ( # Likewise left unchanged with an additional keyword arg 70 | """ 71 | logging.info(f"interesting information: {value}", foo=more_args) 72 | """, 73 | """ 74 | logging.info(f"interesting information: {value}", foo=more_args) 75 |
""", 76 | ), 77 | ( # Too complex if additional args. 78 | """ 79 | logging.log(10, f"interesting information: {value}", more_args) 80 | """, 81 | """ 82 | logging.log(10, f"interesting information: {value}", more_args) 83 | """, 84 | ), 85 | ( 86 | """ 87 | logging.log(10, f'interesting information: {value}', foo=more_args) 88 | """, 89 | """ 90 | logging.log(10, f'interesting information: {value}', foo=more_args) 91 | """, 92 | ), 93 | ( 94 | """ 95 | logger.info("interesting information: {value}".format(value=1337)) 96 | """, 97 | """ 98 | logger.info("interesting information: {value}", value=1337) 99 | """, 100 | ), 101 | ( 102 | """ 103 | logger.critical('interesting information: {}, {}, {}'.format(13, 14, 15)) 104 | """, 105 | """ 106 | logger.critical('interesting information: {}, {}, {}', 13, 14, 15) 107 | """, 108 | ), 109 | ( 110 | """ 111 | log.log(logging.INFO, "interesting information: {}, {}, {}".format(13, 14, 15)) 112 | """, 113 | """ 114 | log.log(logging.INFO, "interesting information: {}, {}, {}", 13, 14, 15) 115 | """, 116 | ),) 117 | 118 | for source, expected_abstraction in test_cases: 119 | processed_content = fixes.deinterpolate_logging_args(source) 120 | if not testing_infra.check_fixes_equal( 121 | processed_content, expected_abstraction, clear_paranthesises=True 122 | ): 123 | return 1 124 | 125 | return 0 126 | 127 | 128 | if __name__ == "__main__": 129 | sys.exit(main()) 130 | -------------------------------------------------------------------------------- /tests/unit/test_delete_pointless_statements.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | from pyrefact.fixes import delete_pointless_statements 6 | 7 | CODE = """12 8 | 9 | def pointless_function() -> None: 10 | 555 11 | 12 | 13 | 1234 14 | {1: sum((2, 3, 6, 0)), "asdf": 13-12} 15 | 16 | """ 17 | 18 | 19 | EXPECTED = """ 20 | 21 | def pointless_function() -> None: 22 | pass 23 | 24 | """ 25 
| 26 | # The shebang-only and docstring modules below must come back unchanged 27 | SHEBANG = """#!/usr/bin/env python3 28 | 29 | 30 | """ 31 | 32 | 33 | REGULAR_MODULE_DOCSTRING = '''"""A normal module docstring""" 34 | 35 | def f() -> int: 36 | return 0 37 | 38 | import sys 39 | if __name__ == "__main__": 40 | sys.exit(f()) 41 | ''' 42 | 43 | 44 | def main() -> int: # NOTE(review): these checks use assert, which python -O strips, so main() would then return 0 without verifying anything 45 | got = delete_pointless_statements(CODE) 46 | assert got.strip() == EXPECTED.strip(), "\n".join( 47 | ("Wrong result: (got, expected)", got, "\n***\n", EXPECTED) 48 | ) 49 | assert SHEBANG == delete_pointless_statements(SHEBANG) 50 | assert REGULAR_MODULE_DOCSTRING == delete_pointless_statements(REGULAR_MODULE_DOCSTRING) 51 | 52 | return 0 53 | 54 | 55 | if __name__ == "__main__": 56 | sys.exit(main()) 57 | -------------------------------------------------------------------------------- /tests/unit/test_delete_unreachable_code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Expected: branches with a constant falsy condition, and statements that can never run (after continue, or after a break that always executes), are removed 13 | test_cases = (( 14 | """ 15 | if False: 16 | print(2) 17 | if True: 18 | print(3) 19 | """, 20 | """ 21 | if True: 22 | print(3) 23 | """, 24 | ), 25 | ( 26 | """ 27 | if (1, 2, 3): 28 | print(2) 29 | if (): 30 | print(3) 31 | """, 32 | """ 33 | if (1, 2, 3): 34 | print(2) 35 | """, 36 | ), 37 | ( 38 | """ 39 | for i in range(10): 40 | print(2) 41 | continue 42 | import os 43 | print(os.getcwd()) 44 | """, 45 | """ 46 | for i in range(10): 47 | print(2) 48 | continue 49 | """, 50 | ), 51 | ( 52 | """ 53 | for i in range(10): 54 | print(2) 55 | if [1]: 56 | break 57 | import os 58 | print(os.getcwd()) 59 | """, 60 | """ 61 | for i in range(10): 62 | print(2) 63 | if [1]: 64 | break 65 | """, 66 | ), 67 | ( 68 | """ 69 | def foo(): 70 | import random 71 | return random.random() > 0.5 72 | print(3) 73 | for i in range(10): 74 | print(2) 75
| if foo(): 76 | break 77 | else: 78 | continue 79 | import os 80 | print(os.getcwd()) 81 | """, 82 | """ 83 | def foo(): 84 | import random 85 | return random.random() > 0.5 86 | for i in range(10): 87 | print(2) 88 | if foo(): 89 | break 90 | else: 91 | continue 92 | """, 93 | ), 94 | ( # a while loop with a falsy constant condition is removed; one with a truthy condition is kept 95 | """ 96 | while 0: 97 | print(0) 98 | while 3: 99 | print(3) 100 | """, 101 | """ 102 | while 3: 103 | print(3) 104 | """, 105 | ),) 106 | 107 | for source, expected_abstraction in test_cases: 108 | processed_content = fixes.delete_unreachable_code(source) 109 | 110 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 111 | return 1 112 | 113 | return 0 114 | 115 | 116 | if __name__ == "__main__": 117 | sys.exit(main()) 118 | -------------------------------------------------------------------------------- /tests/unit/test_delete_unused_functions_and_classes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Expected: function, _private_function and the unused method Foo.bar are deleted; a mention inside a docstring does not count as a use 13 | test_cases = (( 14 | """ 15 | 16 | def function() -> int: 17 | return 0 18 | 19 | def _private_function() -> bool: 20 | return 11 21 | 22 | def _used_function() -> str: 23 | '''Docstring mentioning _private_function()''' 24 | return "this function is used" 25 | 26 | def _user_of_used_function() -> str: 27 | return _used_function() 28 | 29 | class Foo: 30 | def bar(self) -> bool: 31 | return False 32 | 33 | @property 34 | def spammy(self) -> bool: 35 | return True 36 | 37 | if __name__ == "__main__": 38 | Foo().spammy 39 | _user_of_used_function() 40 | 41 | """, 42 | """ 43 | def _used_function() -> str: 44 | '''Docstring mentioning _private_function()''' 45 | return "this function is used" 46 | 47 | def _user_of_used_function() -> str: 48 | return _used_function() 49 | 50 | class Foo:
51 | @property 52 | def spammy(self) -> bool: 53 | return True 54 | 55 | if __name__ == "__main__": 56 | Foo().spammy 57 | _user_of_used_function() 58 | 59 | """, 60 | ), 61 | ) 62 | 63 | for source, expected_abstraction in test_cases: 64 | processed_content = fixes.delete_unused_functions_and_classes(source) 65 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 66 | return 1 67 | 68 | return 0 69 | 70 | 71 | if __name__ == "__main__": 72 | sys.exit(main()) 73 | -------------------------------------------------------------------------------- /tests/unit/test_early_continue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Cases whose expected output equals the input are loops that early_continue must leave unchanged 13 | test_cases = (( 14 | """ 15 | for x in range(100): 16 | if x > 10: 17 | y = 13 18 | else: 19 | x += 1 20 | x *= 12 21 | print(x > 30) 22 | y = 100 - sum(x, 2, 3) 23 | print(x) 24 | """, 25 | """ 26 | for x in range(100): 27 | if x > 10: 28 | y = 13 29 | continue 30 | else: 31 | x += 1 32 | x *= 12 33 | print(x > 30) 34 | y = 100 - sum(x, 2, 3) 35 | print(x) 36 | """, 37 | ), 38 | ( 39 | """ 40 | for i in range(100): 41 | if i % 3 == 2: 42 | print(i ** i) 43 | print(i ** 3) 44 | import os 45 | import sys 46 | print(os.getcwd()) 47 | print(sys is os) 48 | """, 49 | """ 50 | for i in range(100): 51 | if i % 3 != 2: 52 | continue 53 | else: 54 | print(i ** i) 55 | print(i ** 3) 56 | import os 57 | import sys 58 | print(os.getcwd()) 59 | print(sys is os) 60 | """, 61 | ), 62 | ( 63 | """ 64 | for i in range(100): 65 | if i % 3 == 2: 66 | print(i ** i) 67 | print(i ** 3) 68 | print(i ** 4) 69 | """, 70 | """ 71 | for i in range(100): 72 | if i % 3 == 2: 73 | print(i ** i) 74 | print(i ** 3) 75 | print(i ** 4) 76 | """, 77 | ), 78 | ( 79 | """ 80 | for i in
range(100): 81 | if i % 3 == 2: 82 | if i % 6 == 1: 83 | print(i ** i) 84 | print(i ** 3) 85 | print(i ** 4) 86 | """, 87 | """ 88 | for i in range(100): 89 | if i % 3 == 2: 90 | if i % 6 == 1: 91 | print(i ** i) 92 | print(i ** 3) 93 | print(i ** 4) 94 | """, 95 | ), 96 | ( 97 | """ 98 | for i in range(100): 99 | if i % 3 == 2: 100 | print(i ** i) 101 | if i % 6 == 1: 102 | print(i ** 3) 103 | print(i ** 4) 104 | """, 105 | """ 106 | for i in range(100): 107 | if i % 3 == 2: 108 | print(i ** i) 109 | if i % 6 == 1: 110 | print(i ** 3) 111 | print(i ** 4) 112 | """, 113 | ), 114 | ( # Only the outer if is inverted into an early continue; the nested loop and its ifs are kept as they are 115 | """ 116 | for i in range(100): 117 | if i % 3 == 2: 118 | print(i ** i) 119 | print(i ** (i - 1)) 120 | for i in range(10000): 121 | if i >= 1337 and i is not 99: 122 | print(i - i) 123 | if i % 6 == 1: 124 | print(i ** 3) 125 | print(i ** 4) 126 | """, 127 | """ 128 | for i in range(100): 129 | if i % 3 != 2: 130 | continue 131 | else: 132 | print(i ** i) 133 | print(i ** (i - 1)) 134 | for i in range(10000): 135 | if i >= 1337 and i is not 99: 136 | print(i - i) 137 | if i % 6 == 1: 138 | print(i ** 3) 139 | print(i ** 4) 140 | """, 141 | ),) 142 | 143 | for source, expected_abstraction in test_cases: 144 | processed_content = fixes.early_continue(source) 145 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 146 | return 1 147 | 148 | return 0 149 | 150 | 151 | if __name__ == "__main__": 152 | sys.exit(main()) 153 | -------------------------------------------------------------------------------- /tests/unit/test_early_return.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Returns 0 if every case matches, 1 at the first mismatch 13 | test_cases = (( 14 | """ 15 | def f(x) -> int: 16 | if x > 10: 17 | x += 1 18 | x *= 12 19 | print(x > 30) 20 | y
= 100 - sum(x, 2, 3) 21 | else: 22 | y = 13 23 | return y 24 | """, 25 | """ 26 | def f(x) -> int: 27 | if x > 10: 28 | x += 1 29 | x *= 12 30 | print(x > 30) 31 | return 100 - sum(x, 2, 3) 32 | else: 33 | return 13 34 | """, 35 | ), 36 | ( # Left unchanged: print(3) runs after the if/else, so y cannot be returned directly from the branches 37 | """ 38 | def f(x) -> int: 39 | if x > 10: 40 | x += 1 41 | x *= 12 42 | print(x > 30) 43 | y = 100 - sum(x, 2, 3) 44 | else: 45 | y = 13 46 | print(3) 47 | return y 48 | """, 49 | """ 50 | def f(x) -> int: 51 | if x > 10: 52 | x += 1 53 | x *= 12 54 | print(x > 30) 55 | y = 100 - sum(x, 2, 3) 56 | else: 57 | y = 13 58 | print(3) 59 | return y 60 | """, 61 | ),) 62 | 63 | for source, expected_abstraction in test_cases: 64 | processed_content = fixes.early_return(source) 65 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 66 | return 1 67 | 68 | return 0 69 | 70 | 71 | if __name__ == "__main__": 72 | sys.exit(main()) 73 | -------------------------------------------------------------------------------- /tests/unit/test_fix_duplicate_imports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Compared with clear_paranthesises=True, so parenthesized and flat import lists compare equal 13 | test_cases = (( 14 | """ 15 | from logging import info 16 | from logging import warning 17 | from logging import error, info, log 18 | from logging import ( 19 | warning, 20 | critical , 21 | error, error, error as error) 22 | from logging import critical 23 | """, 24 | """ 25 | from logging import critical, error, info, log, warning 26 | """, 27 | ), 28 | ( 29 | """ 30 | from logging import info 31 | from numpy import ndarray 32 | from logging import warning 33 | from numpy import array 34 | from logging import error as info, warning as error 35 | """, 36 | """ 37 | from logging import error as info, info, warning, warning as error 38 | from numpy
import array, ndarray 39 | """, 40 | ), 41 | ( 42 | """ 43 | import logging 44 | import logging 45 | import logging, numpy, pandas as pd, os as sys, os as os 46 | import pandas as pd, os, os, os 47 | import os 48 | """, 49 | """ 50 | import logging 51 | import numpy 52 | import os 53 | import os as sys 54 | import pandas as pd 55 | """, 56 | ), 57 | ( # Imports in different branches are not merged 58 | """ 59 | if foo(): 60 | from spam import eggs 61 | else: 62 | from spam import spam 63 | """, 64 | """ 65 | if foo(): 66 | from spam import eggs 67 | else: 68 | from spam import spam 69 | """, 70 | ), 71 | ( # Imports separated by another statement are not merged 72 | """ 73 | from spam import eggs 74 | print(10) 75 | from spam import spam 76 | """, 77 | """ 78 | from spam import eggs 79 | print(10) 80 | from spam import spam 81 | """, 82 | ), 83 | ( 84 | """ 85 | from spam import eggs 86 | from spam import spam 87 | """, 88 | """ 89 | from spam import eggs, spam 90 | """, 91 | ), 92 | ( # Redundant aliases: "import foo as foo" becomes "import foo", "import a.b as b" becomes "from a import b" 93 | """ 94 | import foo as foo 95 | """, 96 | """ 97 | import foo 98 | """, 99 | ), 100 | ( 101 | """ 102 | import foo as food 103 | """, 104 | """ 105 | import foo as food 106 | """, 107 | ), 108 | ( 109 | """ 110 | import foo.bar as bar, spam.eggs as eggs 111 | """, 112 | """ 113 | from foo import bar 114 | from spam import eggs 115 | """, 116 | ), 117 | ( 118 | """ 119 | import foo.bar as bars 120 | """, 121 | """ 122 | import foo.bar as bars 123 | """, 124 | ), 125 | ( 126 | """ 127 | import a.b.c.d.e.f.g.h.i as i 128 | """, 129 | """ 130 | from a.b.c.d.e.f.g.h import i 131 | """, 132 | ),) 133 | 134 | for source, expected_abstraction in test_cases: 135 | processed_content = fixes.fix_duplicate_imports(source) 136 | if not testing_infra.check_fixes_equal( 137 | processed_content, expected_abstraction, clear_paranthesises=True 138 | ): 139 | return 1 140 | 141 | return 0 142 | 143 | 144 | if __name__ == "__main__": 145 | sys.exit(main()) 146 | -------------------------------------------------------------------------------- /tests/unit/test_fix_if_assign.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # An if/else that assigns False/True becomes one assignment of the negated condition 13 | test_cases = (( 14 | """ 15 | if x == 100: 16 | y = False 17 | else: 18 | y = True 19 | """, 20 | """ 21 | y = not x == 100 22 | """, 23 | ), 24 | ( 25 | """ 26 | if x == 100 and z >= foo(): 27 | k = False 28 | else: 29 | k = True 30 | """, 31 | """ 32 | k = not (x == 100 and z >= foo()) 33 | """, 34 | ), 35 | ( # Left unchanged: the else branch assigns a non-bool value 36 | """ 37 | if x == 100: 38 | k = False 39 | else: 40 | k = 100 41 | """, 42 | """ 43 | if x == 100: 44 | k = False 45 | else: 46 | k = 100 47 | """, 48 | ),) 49 | 50 | for source, expected_abstraction in test_cases: 51 | processed_content = fixes.fix_if_assign(source) 52 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 53 | return 1 54 | 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /tests/unit/test_fix_if_return.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # "if cond: return <bool>" followed by the opposite return collapses into one return of the (possibly negated) condition 13 | test_cases = (( 14 | """ 15 | def f(x: int) -> int: 16 | if x == 100: 17 | return False 18 | return True 19 | """, 20 | """ 21 | def f(x: int) -> int: 22 | return not x == 100 23 | """, 24 | ), 25 | ( 26 | """ 27 | def f(x: int) -> int: 28 | if 2 ** x >= -1 and y ** 3 == 100 and not foo(False): 29 | return False 30 | 31 | return True 32 | """, 33 | """ 34 | def f(x: int) -> int: 35 | return not (2 ** x >= -1 and y ** 3 == 100 and (not foo(False))) 36 | """, 37 | ), 38 | ( 39 | """
40 | def f(x: int) -> int: 41 | foo() 42 | if x == 100: 43 | return True 44 | return False 45 | """, 46 | """ 47 | def f(x: int) -> int: 48 | foo() 49 | return x == 100 50 | """, 51 | ),) 52 | 53 | for source, expected_abstraction in test_cases: 54 | processed_content = fixes.fix_if_return(source) 55 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 56 | return 1 57 | 58 | return 0 59 | 60 | 61 | if __name__ == "__main__": 62 | sys.exit(main()) 63 | -------------------------------------------------------------------------------- /tests/unit/test_fix_import_spacing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: # Compared with clear_whitespace=False, so the blank lines in each expected output are significant 14 | test_cases = (( 15 | """ 16 | import os 17 | import re 18 | """, 19 | """ 20 | import os 21 | import re 22 | """, 23 | ), 24 | ( 25 | """ 26 | import os 27 | 28 | 29 | import re 30 | """, 31 | """ 32 | import os 33 | import re 34 | """, 35 | ), 36 | ( 37 | """ 38 | def f(): 39 | import os 40 | 41 | 42 | import re 43 | """, 44 | """ 45 | def f(): 46 | import os 47 | import re 48 | """, 49 | ), 50 | ( 51 | """ 52 | import os; import sys 53 | 54 | 55 | import re 56 | """, 57 | """ 58 | import os; import sys 59 | import re 60 | """, 61 | ), 62 | ( 63 | """ 64 | import os 65 | 66 | from re import findall 67 | """, 68 | """ 69 | import os 70 | from re import findall 71 | """, 72 | ), 73 | ( 74 | """ 75 | import os 76 | 77 | 78 | import re 79 | from pathlib import ( 80 | Path, 81 | PurePath, PosixPath, 82 | ) 83 | import numpy 84 | import pandas as pd 85 | """, 86 | """ 87 | import os 88 | import re 89 | from pathlib import ( 90 | Path, 91 | PurePath, PosixPath, 92 | ) 93 | 94 | import numpy 95 | import pandas as pd 96 | """, 97 | ), 98 | ( 99 | """ 100 | import os
101 | 102 | # Interesting comment 103 | import re 104 | import numpy 105 | import pandas as pd 106 | def foo(): 107 | import os as re 108 | 109 | import re as os 110 | 111 | print(100) 112 | import pandas 113 | foo() 114 | """, 115 | """ 116 | import os 117 | 118 | # Interesting comment 119 | import re 120 | 121 | import numpy 122 | import pandas as pd 123 | 124 | 125 | def foo(): 126 | import os as re 127 | import re as os 128 | 129 | print(100) 130 | 131 | import pandas 132 | 133 | foo() 134 | """, 135 | ),) 136 | 137 | for source, expected_abstraction in test_cases: 138 | processed_content = fixes.fix_import_spacing(source) 139 | 140 | if not testing_infra.check_fixes_equal( 141 | processed_content, expected_abstraction, clear_whitespace=False # spacing is exactly what this fix changes, so it must not be normalized away 142 | ): 143 | return 1 144 | 145 | return 0 146 | 147 | 148 | if __name__ == "__main__": 149 | sys.exit(main()) 150 | -------------------------------------------------------------------------------- /tests/unit/test_fix_raise_missing_from.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # A raise inside an except block gains "from <name>", binding the caught exception as error when it is unnamed; already-chained raises and non-raising handlers are left as is 13 | test_cases = (( 14 | """ 15 | try: 16 | sketchy_function() 17 | except ValueError: 18 | raise RuntimeError() 19 | """, 20 | """ 21 | try: 22 | sketchy_function() 23 | except ValueError as error: 24 | raise RuntimeError() from error 25 | """, 26 | ), 27 | ( 28 | """ 29 | try: 30 | sketchy_function() 31 | except ValueError as foo: 32 | raise RuntimeError() from foo 33 | """, 34 | """ 35 | try: 36 | sketchy_function() 37 | except ValueError as foo: 38 | raise RuntimeError() from foo 39 | """, 40 | ), 41 | ( 42 | """ 43 | try: 44 | sketchy_function() 45 | except ValueError: 46 | pass 47 | """, 48 | """ 49 | try: 50 | sketchy_function() 51 | except ValueError: 52 | pass 53 | """, 54
| ),) 55 | 56 | for source, expected_abstraction in test_cases: 57 | processed_content = fixes.fix_raise_missing_from(source) 58 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 59 | return 1 60 | 61 | return 0 62 | 63 | 64 | if __name__ == "__main__": 65 | sys.exit(main()) 66 | -------------------------------------------------------------------------------- /tests/unit/test_fix_reimported_names.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import tracing 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Names that a helper module merely re-imports (presumably; see tracing_test_files/c.py and d.py) should be imported from where they come from, e.g. "from d import sys" becomes "import sys" 13 | test_cases = (( 14 | """ 15 | import os 16 | import sys 17 | """, 18 | """ 19 | import os 20 | import sys 21 | """, 22 | ), 23 | ( 24 | """ 25 | from c import z 26 | """, 27 | """ 28 | from d import y as z 29 | """, 30 | ), 31 | ( 32 | """ 33 | from d import sys 34 | """, 35 | """ 36 | import sys 37 | """, 38 | ), 39 | ( 40 | """ 41 | from c import z 42 | from b import x as k 43 | from d import sys 44 | """, 45 | """ 46 | from d import x as k 47 | from d import y as z 48 | import sys 49 | """, 50 | ), 51 | ( 52 | """ 53 | from c import z 54 | from b import x as k 55 | from d import sys 56 | from e import hh 57 | """, 58 | """ 59 | from d import x as k 60 | from d import y as z 61 | import sys 62 | from e import hh 63 | """, 64 | ), 65 | ( # This fix doesn't touch starred imports.
They're fixed by fix_starred_imports 66 | """ 67 | from c import * 68 | from e import * 69 | """, 70 | """ 71 | from c import * 72 | from e import * 73 | """, 74 | ),) 75 | 76 | sys.path.append(str(Path(__file__).parents[1] / "integration" / "tracing_test_files")) # make the helper modules in tests/integration/tracing_test_files importable for the cases above 77 | for source, expected_abstraction in test_cases: 78 | processed_content = tracing.fix_reimported_names(source) 79 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 80 | return 1 81 | 82 | return 0 83 | 84 | 85 | if __name__ == "__main__": 86 | sys.exit(main()) 87 | -------------------------------------------------------------------------------- /tests/unit/test_fix_starred_imports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import tracing 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Star imports are narrowed to the names actually used; star imports that provide no used name (e.g. from sys import *) are dropped 13 | test_cases = (( 14 | """ 15 | from os import * 16 | 17 | print(getcwd()) 18 | """, 19 | """ 20 | from os import getcwd 21 | 22 | print(getcwd()) 23 | """, 24 | ), 25 | ( 26 | """ 27 | from os import * 28 | from pathlib import * 29 | from sys import * 30 | 31 | print(Path(getcwd())) 32 | """, 33 | """ 34 | from os import getcwd 35 | from pathlib import Path 36 | 37 | print(Path(getcwd())) 38 | """, 39 | ), 40 | ( 41 | """ 42 | import time 43 | 44 | print(time.time()) 45 | """, 46 | """ 47 | import time 48 | 49 | print(time.time()) 50 | """, 51 | ), 52 | ( 53 | """ 54 | from time import * 55 | from os import * 56 | from pathlib import * 57 | from datetime import * 58 | 59 | print(f''' 60 | Working directory: {getcwd()} 61 | Current time: {datetime.now().isoformat()} 62 | Created: {datetime.datetime.utcfromtimestamp(stat(__file__).st_ctime)} 63 | Last modified: {datetime.datetime.utcfromtimestamp(stat(__file__).st_mtime)} 64 | ''') 65 | """, 66 | """ 67 | from os import getcwd, stat
68 | from datetime import datetime 69 | 70 | print(f''' 71 | Working directory: {getcwd()} 72 | Current time: {datetime.now().isoformat()} 73 | Created: {datetime.datetime.utcfromtimestamp(stat(__file__).st_ctime)} 74 | Last modified: {datetime.datetime.utcfromtimestamp(stat(__file__).st_mtime)} 75 | ''') 76 | """, 77 | ), 78 | ( # These names are re-imported from other files, but fix_reimported_names fixes that instead. 79 | """ 80 | from c import * 81 | from e import * 82 | 83 | print(x) 84 | print(z) 85 | print(ww) 86 | """, 87 | """ 88 | from c import x, z 89 | from e import ww 90 | 91 | print(x) 92 | print(z) 93 | print(ww) 94 | """, 95 | ),) 96 | 97 | sys.path.append(str(Path(__file__).parents[1] / "integration" / "tracing_test_files")) # make the helper modules c and e importable for the last case 98 | for source, expected_abstraction in test_cases: 99 | processed_content = tracing.fix_starred_imports(source) 100 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 101 | return 1 102 | 103 | return 0 104 | 105 | 106 | if __name__ == "__main__": 107 | sys.exit(main()) 108 | -------------------------------------------------------------------------------- /tests/unit/test_fix_unconventional_class_definitions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import object_oriented 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # "Foo.x = ..." after "class Foo" becomes "x = ..." (presumably inside the class body); assignments to another class (Bar.x) are left alone 13 | test_cases = (( 14 | """ 15 | class Foo: 16 | pass 17 | 18 | Foo.x = 1 19 | """, 20 | """ 21 | class Foo: 22 | pass 23 | x = 1 24 | """, 25 | ), 26 | ( 27 | """ 28 | class Foo: 29 | pass 30 | 31 | Foo.x = (1, 2, 3) 32 | """, 33 | """ 34 | class Foo: 35 | pass 36 | x = (1, 2, 3) 37 | """, 38 | ), 39 | ( 40 | """ 41 | class Foo: 42 | pass 43 | 44 | Bar.x = 1 45 | """, 46 | """ 47 | class Foo: 48 | pass 49 | 50 | Bar.x = 1 51 | """, 52 | ), 53 | ( 54 | """ 55 | class Foo(object):
56 | pass 57 | 58 | Foo.x = 1 59 | """, 60 | """ 61 | class Foo(object): 62 | pass 63 | x = 1 64 | """, 65 | ), 66 | ( 67 | """ 68 | @a 69 | @bunch 70 | @of 71 | @decorators 72 | class Foo(object, list, set, tuple, []): 73 | pass 74 | 75 | Foo.x = 1 76 | """, 77 | """ 78 | @a 79 | @bunch 80 | @of 81 | @decorators 82 | class Foo(object, list, set, tuple, []): 83 | pass 84 | x = 1 85 | """, 86 | ), 87 | ( 88 | """ 89 | class Foo: 90 | pass 91 | 92 | Foo.x = 1 93 | Foo.y = z 94 | Foo.z = func() 95 | """, 96 | """ 97 | class Foo: 98 | pass 99 | x = 1 100 | y = z 101 | z = func() 102 | """, 103 | ),) 104 | 105 | for source, expected_abstraction in test_cases: 106 | processed_content = object_oriented.fix_unconventional_class_definitions(source) 107 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 108 | return 1 109 | 110 | return 0 111 | 112 | 113 | if __name__ == "__main__": 114 | sys.exit(main()) 115 | -------------------------------------------------------------------------------- /tests/unit/test_has_side_effect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import ast 4 | import sys 5 | 6 | from pyrefact import constants, core 7 | 8 | 9 | def main() -> int: 10 | """Test core.has_side_effect 11 | 12 | Returns: 13 | int: 1 if the function behaves incorrectly, otherwise 0 14 | """ 15 | whitelist = constants.SAFE_CALLABLES # presumably the callables (such as range, list, sum) whose calls count as side-effect free 16 | for source in ( # snippets that must NOT be reported as having a side effect 17 | "{}", 18 | "()", 19 | "[]", 20 | "1", 21 | "-1", 22 | "2-1", 23 | "2.-1", 24 | "False", 25 | "pass", 26 | "print", 27 | "exit", 28 | "x", 29 | "x+1", 30 | "x**10", 31 | "x > 2", 32 | "x < 2 < x", 33 | "{1, 2, 3}", 34 | "[1, 2, 3]", 35 | "(1, 2, 3)", 36 | "[1]*2", 37 | "[] + []", 38 | "range(2)", 39 | "list(range(2, 6))", 40 | "[x for x in range(3)]", 41 | "(x for x in range(3))", 42 | "{x for x in range(3)}", 43 | "{x: x-1 for x in range(3)}", 44 | "{**{1: 2}}", 45 | "lambda: 2", 46 | "{1: 2}[1]", 47 | "{1: sum}[1]((3,
3, 3))", 48 | '{1: sum((2, 3, 6, 0)), "asdf": 13-12}', 49 | "_=2", 50 | "(_:=3)", 51 | "_+=11", 52 | "_: int = 'q'", 53 | """for _ in range(10): 54 | 1 55 | """, 56 | "f'''y={1}'''", 57 | "b'bytes_string'", 58 | "r'''raw_string_literal\n'''", 59 | 'f"i={i:.3f}"', 60 | "f'{x=}'", 61 | "'foo = {}'.format(foo)", 62 | ): 63 | node = core.parse(source).body[0] 64 | if not core.has_side_effect(node, whitelist): 65 | continue 66 | 67 | print("Ast has side effect, but should not:") 68 | print(source) 69 | print("Ast structure:") 70 | print(ast.dump(node, indent=2)) 71 | return 1 72 | 73 | for source in ( 74 | "x=100", 75 | "requests.post(*args, **kwargs)", 76 | "x.y=z", 77 | "print()", 78 | "exit()", 79 | """def f() -> None: 80 | return 1 81 | """, 82 | "g=2", 83 | "(h:=3)", 84 | """for i in range(10): 85 | 1 86 | """, 87 | "mysterious_function()", 88 | "flat_dict[value] = something", 89 | "nested_dict[value][item] = something", 90 | "deep_nested_dict[a][b][c][d][e][f][g] = something", 91 | "f'''y={1 + foo()}'''", 92 | 'f"i={i - i ** (1 - f(i)):.3f}"', 93 | "f'{(x := 10)=}'", 94 | "'foo() = {}'.format(foo())", 95 | "x.append(10)", 96 | ): 97 | node = core.parse(source).body[0] 98 | if core.has_side_effect(node, whitelist): 99 | continue 100 | 101 | print("Ast has no side effect, but should:") 102 | print(source) 103 | print("Ast structure:") 104 | print(ast.dump(node, indent=2)) 105 | return 1 106 | 107 | return 0 108 | 109 | 110 | if __name__ == "__main__": 111 | sys.exit(main()) 112 | -------------------------------------------------------------------------------- /tests/unit/test_hash_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import ast 4 | import sys 5 | 6 | from pyrefact import abstractions, constants, core 7 | 8 | 9 | def _error_message(left: ast.AST, right: ast.AST, *, positive: bool) -> str: 10 | if positive: 11 | msg = "Hashes are different, but should be the same for" 12 | else: 13 
| msg = "Hashes are the same, but should be different for" 14 | return f"""{msg}: 15 | 16 | {ast.dump(left, indent=2)} 17 | 18 | and 19 | 20 | {ast.dump(right, indent=2)} 21 | """ 22 | 23 | 24 | def main() -> None: 25 | preserved_names = constants.BUILTIN_FUNCTIONS 26 | positives = ( 27 | ("lambda x: x", "lambda y: y"), 28 | ("lambda x: (x**3 - x**2) // 11", "lambda aaaa: (aaaa**3 - aaaa**2) // 11"), 29 | ( 30 | """ 31 | def q(a, b, c): 32 | if a and b and c: 33 | return q(a) 34 | elif c and d > 0: 35 | return -a 36 | return a*b +c*a 37 | """, 38 | """ 39 | def qz(aaa, bbb, ccc): 40 | if aaa and bbb and ccc: 41 | return qz(aaa) 42 | elif ccc and ddd > 0: 43 | return -aaa 44 | return aaa*bbb +ccc*aaa 45 | """, 46 | ),) 47 | negatives = (("lambda x: list(x)", "lambda x: set(x)"),) 48 | 49 | for left_expression, right_expression in positives: 50 | left_node = core.parse(left_expression).body[0] 51 | right_node = core.parse(right_expression).body[0] 52 | left_hash = abstractions.hash_node(left_node, preserved_names) 53 | right_hash = abstractions.hash_node(right_node, preserved_names) 54 | assert left_hash == right_hash, _error_message(left_node, right_node, positive=True) 55 | 56 | for left_expression, right_expression in negatives: 57 | left_node = core.parse(left_expression).body[0] 58 | right_node = core.parse(right_expression).body[0] 59 | left_hash = abstractions.hash_node(left_node, preserved_names) 60 | right_hash = abstractions.hash_node(right_node, preserved_names) 61 | assert left_hash != right_hash, _error_message(left_node, right_node, positive=False) 62 | 63 | return 0 64 | 65 | 66 | if __name__ == "__main__": 67 | sys.exit(main()) 68 | -------------------------------------------------------------------------------- /tests/unit/test_ignore_comments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import format_code 7 | 8 | 
sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Returns 1 at the first case whose format_code output differs from the expected text, otherwise 0. 13 | test_cases = ( 14 | ( # No ignore comment 15 | """ 16 | print(0) 17 | if False: 18 | print(1) 19 | """, 20 | """ 21 | print(0) 22 | """, 23 | ), 24 | ( # Valid ignore comment 25 | """ 26 | print(0) 27 | if False: # pyrefact: ignore 28 | print(1) 29 | """, 30 | """ 31 | print(0) 32 | if False: # pyrefact: ignore 33 | print(1) 34 | """, 35 | ), 36 | ( # Valid skip_file comment 37 | """ 38 | print(0) 39 | if False: # pyrefact: skip_file 40 | print(1) 41 | """, 42 | """ 43 | print(0) 44 | if False: # pyrefact: skip_file 45 | print(1) 46 | """, 47 | ), 48 | ( # Unrelated ignore comment 49 | """ 50 | print(0) 51 | if False: # type: ignore 52 | print(1) 53 | """, 54 | """ 55 | print(0) 56 | """, 57 | ), 58 | ( # Invalid ignore comment 59 | """ 60 | print(0) 61 | if False: # pyrefact: asdfdsas 62 | print(1) 63 | """, 64 | """ 65 | print(0) 66 | """, 67 | ), 68 | ( # Valid ignore comment with extra spaces 69 | """ 70 | print(0) 71 | if False:# pyrefact :ignore 72 | print(1) 73 | """, 74 | """ 75 | print(0) 76 | if False:# pyrefact :ignore 77 | print(1) 78 | """, 79 | ), 80 | ( # Valid skip_file comment with extra spaces 81 | """ 82 | print(0) 83 | if False:# pyrefact :skip_file 84 | print(1) 85 | """, 86 | """ 87 | print(0) 88 | if False:# pyrefact :skip_file 89 | print(1) 90 | """, 91 | ),) 92 | 93 | for source, expected_abstraction in test_cases: # Each case is (input source, expected format_code output). 94 | processed_content = format_code(source) 95 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 96 | return 1 97 | 98 | return 0 99 | 100 | 101 | if __name__ == "__main__": 102 | sys.exit(main()) 103 | -------------------------------------------------------------------------------- /tests/unit/test_implicit_dict_keys_values_items.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 |
6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: # Returns 1 at the first case that does not match its expected rewrite, otherwise 0. 13 | test_cases = ( 14 | ( # Items to keys 15 | """ 16 | (x for x, _ in d.items()) 17 | """, 18 | """ 19 | (x for x in d.keys()) 20 | """, 21 | ), 22 | ( 23 | """ 24 | {x: 100 - x for x, _ in d.items()} 25 | """, 26 | """ 27 | {x: 100 - x for x in d.keys()} 28 | """, 29 | ), 30 | ( # Items to values 31 | """ 32 | [x for _, x in d.items() if foo if bar if x if baz] 33 | """, 34 | """ 35 | [x for x in d.values() if foo if bar if x if baz] 36 | """, 37 | ), 38 | ( 39 | """ 40 | {x for _, x in d.items() if foo if bar if x if baz} 41 | """, 42 | """ 43 | {x for x in d.values() if foo if bar if x if baz} 44 | """, 45 | ), 46 | ( 47 | """ 48 | for x, _ in d.items(): 49 | print(x) 50 | """, 51 | """ 52 | for x in d.keys(): 53 | print(x) 54 | """, 55 | ), 56 | ( 57 | """ 58 | for _, x in d.items(): 59 | print(x) 60 | """, 61 | """ 62 | for x in d.values(): 63 | print(x) 64 | """, 65 | ), 66 | ( # Implicit items 67 | """ 68 | {(x, d[x]) for x in d.keys()} 69 | """, 70 | """ 71 | {(x, d_x) for x, d_x in d.items()} 72 | """, 73 | ), 74 | ( 75 | """ 76 | for x in d.keys(): 77 | print(x) 78 | print(d[x]) 79 | """, 80 | """ 81 | for x, d_x in d.items(): 82 | print(x) 83 | print(d_x) 84 | """, 85 | ), 86 | ( # Implicit values 87 | """ 88 | [d[x] for x in d.keys()] 89 | """, 90 | """ 91 | [d_x for x, d_x in d.items()] 92 | """, 93 | ), 94 | ( 95 | """ 96 | for x in d.keys(): 97 | print(d[x]) 98 | """, 99 | """ 100 | for x, d_x in d.items(): 101 | print(d_x) 102 | """, 103 | ),) 104 | 105 | for source, expected_abstraction in test_cases: # Each case is (input source, expected output of fixes.implicit_dict_keys_values_items). 106 | processed_content = fixes.implicit_dict_keys_values_items(source) 107 | if not testing_infra.check_fixes_equal( 108 | processed_content, expected_abstraction, clear_paranthesises=True # NOTE(review): presumably makes the comparison ignore redundant parentheses - confirm in testing_infra 109 | ): 110 | return 1 111 | 112 | return 0 113 | 114 | 115 | if __name__ == "__main__": 116 | sys.exit(main()) 117 |
| -------------------------------------------------------------------------------- /tests/unit/test_implicit_dot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes, performance_numpy 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = ( 15 | ( 16 | """ 17 | import numpy as np 18 | 19 | a = np.random.random(n) 20 | b = np.random.random(n) 21 | 22 | c = sum([a_ * b_ for a_, b_ in zip(a, b)]) 23 | d = np.sum(a_ * b_ for a_, b_ in zip(a, b)) 24 | print(c, d) 25 | """, 26 | """ 27 | import numpy as np 28 | 29 | a = np.random.random(n) 30 | b = np.random.random(n) 31 | 32 | c = np.dot(a, b) 33 | d = np.dot(a, b) 34 | print(c, d) 35 | """, 36 | ), 37 | ( 38 | """ 39 | n = 10 40 | def _mysterious_function(a: np.array, b: np.array): 41 | return sum([a_ * b_ for a_, b_ in zip(a, b)]) 42 | 43 | a = np.random.random(n) 44 | b = np.random.random(n) 45 | 46 | c = _mysterious_function(a, b) 47 | print(c, np.dot(a, b)) 48 | """, 49 | """ 50 | n = 10 51 | def _mysterious_function(a: np.array, b: np.array): 52 | return np.dot(a, b) 53 | 54 | a = np.random.random(n) 55 | b = np.random.random(n) 56 | 57 | c = _mysterious_function(a, b) 58 | print(c, np.dot(a, b)) 59 | """, 60 | ), 61 | ( 62 | """ 63 | import numpy as np 64 | 65 | i, j, k = 10, 11, 12 66 | 67 | a = np.random.random((i, j)) 68 | b = np.random.random((j, k)) 69 | 70 | u = np.array( 71 | [ 72 | [ 73 | np.sum( 74 | a__ * b__ 75 | for a__, b__ in zip(a_, b_) 76 | ) 77 | for a_ in a 78 | ] 79 | for b_ in b.T 80 | ] 81 | ).T 82 | 83 | print(u) 84 | """, 85 | """ 86 | import numpy as np 87 | 88 | i, j, k = 10, 11, 12 89 | 90 | a = np.random.random((i, j)) 91 | b = np.random.random((j, k)) 92 | 93 | u = np.array( 94 | [ 95 | [ 96 | np.dot(a_, b_) 97 | for a_ in a 98 | ] 99 | for b_ in b.T 100 | ] 101 | 
).T 102 | 103 | print(u) 104 | """, 105 | ), 106 | ) 107 | 108 | for source, expected_abstraction in test_cases: 109 | 110 | processed_content = performance_numpy.replace_implicit_dot(source) 111 | processed_content = fixes.simplify_transposes(processed_content) 112 | 113 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 114 | return 1 115 | 116 | return 0 117 | 118 | 119 | if __name__ == "__main__": 120 | sys.exit(main()) 121 | -------------------------------------------------------------------------------- /tests/unit/test_implicit_matmul.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes, performance_numpy 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = ( 15 | ( 16 | """ 17 | import numpy as np 18 | 19 | i, j, k = 10, 11, 12 20 | 21 | a = np.random.random((i, j)) 22 | b = np.random.random((j, k)) 23 | 24 | u = np.array([[np.dot(a_, b_) for a_ in a] for b_ in b.T]).T 25 | v = np.array([[np.dot(b_, a_) for b_ in b.T] for a_ in a]) 26 | 27 | print(np.sum((u - np.matmul(a, b)).ravel())) 28 | print(np.sum((v - np.matmul(a, b)).ravel())) 29 | """, 30 | """ 31 | import numpy as np 32 | 33 | i, j, k = 10, 11, 12 34 | 35 | a = np.random.random((i, j)) 36 | b = np.random.random((j, k)) 37 | 38 | u = np.matmul(a, b) 39 | v = np.matmul(a, b) 40 | 41 | print(np.sum((u - np.matmul(a, b)).ravel())) 42 | print(np.sum((v - np.matmul(a, b)).ravel())) 43 | """, 44 | ), 45 | ( 46 | """ 47 | import numpy as np 48 | 49 | i, j, k = 10, 11, 12 50 | 51 | a = np.random.random((i, j)) 52 | b = np.random.random((j, k)) 53 | c = np.random.random((k, j)) 54 | d = np.random.random((j, i)) 55 | 56 | u = np.array([[np.dot(b[:, i], a[j, :]) for i in range(b.shape[1])] for j in range(a.shape[0])]) 57 | v = np.array([[np.dot(c[i, :], 
a[j, :]) for i in range(c.shape[0])] for j in range(a.shape[0])]) 58 | w = np.array([[np.dot(b[:, i], d[:, j]) for i in range(b.shape[1])] for j in range(d.shape[1])]) 59 | z = np.array([[np.dot(a[i, :], b[:, j]) for i in range(a.shape[0])] for j in range(b.shape[1])]) 60 | 61 | print(np.sum((u - np.matmul(a, b)).ravel())) 62 | print(np.sum((v - np.matmul(a, c.T)).ravel())) 63 | print(np.sum((w - np.matmul(b.T, d).T).ravel())) 64 | print(np.sum((z - np.matmul(b.T, a.T)).ravel())) 65 | """, 66 | """ 67 | import numpy as np 68 | 69 | i, j, k = 10, 11, 12 70 | 71 | a = np.random.random((i, j)) 72 | b = np.random.random((j, k)) 73 | c = np.random.random((k, j)) 74 | d = np.random.random((j, i)) 75 | 76 | u = np.matmul(a, b) 77 | v = np.matmul(c, a.T).T 78 | w = np.matmul(b.T, d).T 79 | z = np.matmul(a, b).T 80 | 81 | print(np.sum((u - np.matmul(a, b)).ravel())) 82 | print(np.sum((v - np.matmul(a, c.T)).ravel())) 83 | print(np.sum((w - np.matmul(b.T, d).T).ravel())) 84 | print(np.sum((z - np.matmul(b.T, a.T)).ravel())) 85 | """, 86 | ), 87 | ( 88 | """ 89 | for i in range(len(left)): 90 | for j in range(len(right[0])): 91 | result[i][j] = np.dot(left[i] * right.T[j]) 92 | """, 93 | """ 94 | result = np.matmul(left, right) 95 | """, 96 | ), 97 | ( 98 | """ 99 | for i in range(len(left)): 100 | for j in range(len(right[0])): 101 | for k in range(len(right)): 102 | result[i][j] += left[i][k] * right[k][j] 103 | """, 104 | """ 105 | result = np.matmul(left, right) 106 | """, 107 | ), 108 | ( 109 | """ 110 | result = [ 111 | [ 112 | np.dot(left[i] * right.T[j]) 113 | for j in range(len(right[0])) 114 | ] 115 | for i in range(len(left)) 116 | ] 117 | """, 118 | """ 119 | result = np.matmul(left, right) 120 | """, 121 | ), 122 | ( 123 | """ 124 | result = [ 125 | [ 126 | sum( 127 | left[i][k] * right[k][j] 128 | for k in range(len(right)) 129 | ) 130 | for j in range(len(right[0])) 131 | ] 132 | for i in range(len(left)) 133 | ] 134 | """, 135 | """ 136 | result = 
np.matmul(left, right) 137 | """, 138 | ), 139 | ) 140 | 141 | for source, expected_abstraction in test_cases: 142 | 143 | processed_content = performance_numpy.replace_implicit_matmul(source) 144 | processed_content = fixes.simplify_transposes(processed_content) 145 | 146 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 147 | return 1 148 | 149 | return 0 150 | 151 | 152 | if __name__ == "__main__": 153 | sys.exit(main()) 154 | -------------------------------------------------------------------------------- /tests/unit/test_inline_math_comprehensions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = ( # the original assignment becomes dead code and is removed by later steps 15 | ( 16 | """ 17 | z = {a for a in range(10)} 18 | x = sum(z) 19 | """, 20 | """ 21 | z = {a for a in range(10)} 22 | x = sum({a for a in range(10)}) 23 | """, 24 | ), 25 | ( 26 | """ 27 | w = [a ** 2 for a in range(10)] 28 | y = sum(w) 29 | """, 30 | """ 31 | w = [a ** 2 for a in range(10)] 32 | y = sum([a ** 2 for a in range(10)]) 33 | """, 34 | ), 35 | ( 36 | """ 37 | k = True 38 | w = [a ** 2 for a in range(11) if k] 39 | k = False 40 | y = sum(w) 41 | """, 42 | """ 43 | k = True 44 | w = [a ** 2 for a in range(11) if k] 45 | k = False 46 | y = sum(w) 47 | """, 48 | ), 49 | ( 50 | """ 51 | for i in range(10): 52 | w = [a ** 2 for a in range(10)] 53 | y = sum(w) 54 | """, 55 | """ 56 | for i in range(10): 57 | w = [a ** 2 for a in range(10)] 58 | y = sum([a ** 2 for a in range(10)]) 59 | """, 60 | ), 61 | ( 62 | """ 63 | w = [] 64 | for i in range(10): 65 | y = sum(w) 66 | w = [a ** 2 for a in range(i)] 67 | """, 68 | """ 69 | w = [] 70 | for i in range(10): 71 | y = sum(w) 72 | 
w = [a ** 2 for a in range(i)] 73 | """, 74 | ),) 75 | 76 | for source, expected_abstraction in test_cases: 77 | processed_content = fixes.inline_math_comprehensions(source) 78 | 79 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 80 | return 1 81 | 82 | return 0 83 | 84 | 85 | if __name__ == "__main__": 86 | sys.exit(main()) 87 | -------------------------------------------------------------------------------- /tests/unit/test_invalid_escape_sequence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | r""" 15 | import re 16 | print(re.findall("\d+", "1234x23")) 17 | """, 18 | r""" 19 | import re 20 | print(re.findall(r"\d+", "1234x23")) 21 | """, 22 | ), 23 | ( 24 | r""" 25 | import re 26 | print(re.findall("\+", "1234+23")) 27 | """, 28 | r""" 29 | import re 30 | print(re.findall(r"\+", "1234+23")) 31 | """, 32 | ), 33 | ( # Watch out with f strings 34 | r""" 35 | import re 36 | print(re.findall(f"\d{'+'}", "1234x23")) 37 | """, 38 | r""" 39 | import re 40 | print(re.findall(f"\d{'+'}", "1234x23")) 41 | """, 42 | ),) 43 | 44 | for source, expected_abstraction in test_cases: 45 | processed_content = fixes.invalid_escape_sequence(source) 46 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 47 | return 1 48 | 49 | return 0 50 | 51 | 52 | if __name__ == "__main__": 53 | sys.exit(main()) 54 | -------------------------------------------------------------------------------- /tests/unit/test_is_blocking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import ast 4 | import sys 5 | 6 | from pyrefact import core 7 | 8 | 9 | def main() -> int: 10 | """Test 
core.is_blocking 11 | 12 | Returns: 13 | int: 1 if the function behaves incorrectly, otherwise 0 14 | """ 15 | for source in ( 16 | """ 17 | for i in range(10): 18 | continue 19 | """, 20 | """ 21 | for i in range(10): 22 | break 23 | """, 24 | """ 25 | for i in range(10): 26 | print(1) 27 | print(2) 28 | with x as y: 29 | continue 30 | """, 31 | """ 32 | for i in x: 33 | while True: 34 | break 35 | """, 36 | """ 37 | for i in x: 38 | while True: 39 | break 40 | continue 41 | """, 42 | """ 43 | while True: 44 | break 45 | raise Exception() 46 | """, 47 | """ 48 | while statement: 49 | if x: 50 | statement = False 51 | if random.random(): 52 | statement = False 53 | """, 54 | """ 55 | while False: 56 | print(2) 57 | raise RuntimeError() 58 | """, 59 | """ 60 | if 66: 61 | pass 62 | """, 63 | """ 64 | if 66: 65 | print(6) 66 | else: 67 | return 22 68 | """, 69 | """ 70 | if None: 71 | break 72 | else: 73 | f = 2 + x() 74 | """, 75 | """ 76 | for i in []: 77 | raise ValueError() 78 | """, 79 | """ 80 | for i in something: 81 | raise ValueError() 82 | """, 83 | ): 84 | node = core.parse(source).body[0] 85 | if not core.is_blocking(node): 86 | continue 87 | 88 | print("Ast is blocking, but should not be:") 89 | print(source) 90 | print("Ast structure:") 91 | print(ast.dump(node, indent=2)) 92 | return 1 93 | 94 | for source in ( 95 | """ 96 | for i in [1, 2, 3]: 97 | raise ValueError() 98 | """, 99 | """ 100 | for i in (1, 2, 3): 101 | print(1) 102 | print(2) 103 | with x as y: 104 | raise RuntimeError() 105 | """, 106 | """ 107 | for i in [None, False]: 108 | while True: 109 | break 110 | assert False 111 | """, 112 | """ 113 | if x: 114 | raise RuntimeError() 115 | elif y: 116 | if z: 117 | for a in range(10): 118 | continue 119 | return 1 120 | break 121 | else: 122 | print(2) 123 | return 99 124 | """, 125 | """ 126 | while True: 127 | while True: 128 | while True: 129 | print(3) 130 | """, 131 | """ 132 | while True: 133 | while True: 134 | while True: 135 | raise 
Exception() 136 | """, 137 | """ 138 | while 1: 139 | while False: 140 | pass 141 | """, 142 | """ 143 | if 66: 144 | return 0 145 | """, 146 | """ 147 | if 66: 148 | return 0 149 | else: 150 | print(8) 151 | """, 152 | """ 153 | if None: 154 | pass 155 | else: 156 | return 0 157 | """, 158 | ): 159 | node = core.parse(source).body[0] 160 | if core.is_blocking(node): 161 | continue 162 | 163 | print("Ast is not blocking, but should be:") 164 | print(source) 165 | print("Ast structure:") 166 | print(ast.dump(node, indent=2)) 167 | return 1 168 | 169 | return 0 170 | 171 | 172 | if __name__ == "__main__": 173 | sys.exit(main()) 174 | -------------------------------------------------------------------------------- /tests/unit/test_literal_value.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import unittest 3 | 4 | from pyrefact import core 5 | 6 | 7 | def parse_one_node(source: str) -> ast.AST: 8 | module = ast.parse(source) 9 | assert len(module.body) == 1 10 | assert isinstance(module.body[0], ast.Expr) 11 | return module.body[0].value 12 | 13 | 14 | class TestLiteralValue(unittest.TestCase): 15 | def test_basic_datatypes(self): 16 | node = parse_one_node("1") 17 | assert core.literal_value(node) == 1 18 | 19 | node = parse_one_node("1.5") 20 | assert core.literal_value(node) == 1.5 21 | 22 | node = parse_one_node("False") 23 | assert core.literal_value(node) == False 24 | 25 | node = parse_one_node("()") 26 | assert core.literal_value(node) == () 27 | 28 | node = parse_one_node("[]") 29 | assert core.literal_value(node) == [] 30 | 31 | node = parse_one_node("{}") 32 | assert core.literal_value(node) == {} 33 | 34 | node = parse_one_node("'abcd'") 35 | assert core.literal_value(node) == "abcd" 36 | 37 | node = parse_one_node("None") 38 | assert core.literal_value(node) == None 39 | 40 | def test_nested_data(self): 41 | node = parse_one_node("[1, 2, 3]") 42 | assert core.literal_value(node) == [1, 2, 3] 43 | 44 
| node = parse_one_node("{'a': 1, 'b': 2}") 45 | assert core.literal_value(node) == {"a": 1, "b": 2} 46 | 47 | node = parse_one_node("{'a': [1, 2, 3], 'b': {'c': 4}}") 48 | assert core.literal_value(node) == {"a": [1, 2, 3], "b": {"c": 4}} 49 | 50 | def test_comparisons(self): 51 | node = parse_one_node("1 < 2") 52 | assert core.literal_value(node) == True 53 | 54 | node = parse_one_node("1 > 2") 55 | assert core.literal_value(node) == False 56 | 57 | node = parse_one_node("1 == 2") 58 | assert core.literal_value(node) == False 59 | 60 | node = parse_one_node("1 != 2") 61 | assert core.literal_value(node) == True 62 | 63 | node = parse_one_node("1 <= 2") 64 | assert core.literal_value(node) == True 65 | 66 | node = parse_one_node("1 >= 2") 67 | assert core.literal_value(node) == False 68 | 69 | node = parse_one_node("1 < 2 < 3 < 4 == 4 < 4.5") 70 | assert core.literal_value(node) == True 71 | 72 | node = parse_one_node("not True") 73 | assert core.literal_value(node) == False 74 | 75 | node = parse_one_node("[] or ()") 76 | assert core.literal_value(node) == () 77 | 78 | node = parse_one_node("[] and ()") 79 | assert core.literal_value(node) == [] 80 | 81 | node = parse_one_node("not []") 82 | assert core.literal_value(node) == True 83 | 84 | node = parse_one_node("1 and 2") 85 | assert core.literal_value(node) == 2 86 | 87 | node = parse_one_node("1 or 2") 88 | assert core.literal_value(node) == 1 89 | 90 | def test_arithmetic(self): 91 | node = parse_one_node("1 + 2") 92 | assert core.literal_value(node) == 3 93 | 94 | node = parse_one_node("1 - 2") 95 | assert core.literal_value(node) == -1 96 | 97 | node = parse_one_node("[1] + [2]") 98 | assert core.literal_value(node) == [1, 2] 99 | 100 | def test_literal_calls(self): 101 | node = parse_one_node("''.join(['a', 'b'])") 102 | assert core.literal_value(node) == "ab" 103 | 104 | node = parse_one_node("len('abc')") 105 | assert core.literal_value(node) == 3 106 | 107 | node = parse_one_node("len([1, 2, 3])") 108 | 
assert core.literal_value(node) == 3 109 | 110 | def runTest(self): 111 | self.test_basic_datatypes() 112 | self.test_nested_data() 113 | self.test_comparisons() 114 | self.test_arithmetic() 115 | self.test_literal_calls() 116 | 117 | 118 | def main() -> int: 119 | # For use with ./tests/main.py, which looks for these main functions. 120 | # unittest.main() will do sys.exit() or something, it quits the whole 121 | # program after done and prevents further tests from running. 122 | 123 | test_result = TestLiteralValue().run() 124 | if not test_result.wasSuccessful(): 125 | test_result.printErrors() 126 | return 1 127 | 128 | return 0 129 | 130 | 131 | if __name__ == "__main__": 132 | unittest.main() 133 | -------------------------------------------------------------------------------- /tests/unit/test_merge_chained_comps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = (y for y in (y for y in (3, 4, 5))) 16 | """, 17 | """ 18 | x = (y for y in (3, 4, 5)) 19 | """, 20 | ), 21 | ( 22 | """ 23 | x = {y for y in {y for y in (3, 4, 5)}} 24 | """, 25 | """ 26 | x = {y for y in (3, 4, 5)} 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = [y for y in [y for y in (3, 4, 5)]] 32 | """, 33 | """ 34 | x = [y for y in (3, 4, 5)] 35 | """, 36 | ), 37 | ( # Not same comprehension type 38 | """ 39 | x = {y for y in (y for y in (3, 4, 5))} 40 | """, 41 | """ 42 | x = {y for y in (y for y in (3, 4, 5))} 43 | """, 44 | ), 45 | ( 46 | """ 47 | x = (y for y in (y for y in (3, 4, 5) if y > 3)) 48 | """, 49 | """ 50 | x = (y for y in (3, 4, 5) if y > 3) 51 | """, 52 | ), 53 | ( 54 | """ 55 | x = (y for y in (y for y in (3, 4, 5)) if y > 3) 56 | """, 57 | """ 58 | x = (y for y in (3, 4, 5) if y > 3) 59 
| """, 60 | ), 61 | ( 62 | """ 63 | x = (y for y in (y for y in (3, 4, 5) if y > 3) if y < 5) 64 | """, 65 | """ 66 | x = (y for y in (3, 4, 5) if y > 3 if y < 5) 67 | """, 68 | ), 69 | ( 70 | """ 71 | x = (y ** 2 for y in (y for y in (3, 4, 5))) 72 | """, 73 | """ 74 | x = (y ** 2 for y in (3, 4, 5)) 75 | """, 76 | ), 77 | ( # Inner elt is not same as inner target 78 | """ 79 | x = (y ** 2 for y in (y ** 2 for y in (3, 4, 5))) 80 | """, 81 | """ 82 | x = (y ** 2 for y in (y ** 2 for y in (3, 4, 5))) 83 | """, 84 | ), 85 | ( 86 | """ 87 | x = (y ** z for y, z in ((y, z) for y, z in zip((3, 4, 5), [3, 4, 5]))) 88 | """, 89 | """ 90 | x = (y ** z for y, z in zip((3, 4, 5), [3, 4, 5])) 91 | """, 92 | ),) 93 | 94 | for source, expected_abstraction in test_cases: 95 | processed_content = fixes.merge_chained_comps(source) 96 | if not testing_infra.check_fixes_equal( 97 | processed_content, expected_abstraction, clear_paranthesises=True 98 | ): 99 | return 1 100 | 101 | return 0 102 | 103 | 104 | if __name__ == "__main__": 105 | sys.exit(main()) 106 | -------------------------------------------------------------------------------- /tests/unit/test_merge_nested_comprehensions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = ( 15 | ( # GeneratorExp in SetComp 16 | """ 17 | {x for x in (y for y in range(10))} 18 | """, 19 | """ 20 | {x for x in range(10)} 21 | """, 22 | ), 23 | ( # GeneratorExp in ListComp 24 | """ 25 | [x for x in (y for y in range(10))] 26 | """, 27 | """ 28 | [x for x in range(10)] 29 | """, 30 | ), 31 | ( # GeneratorExp in DictComp 32 | """ 33 | {x: 1 for x in (y for y in range(10))} 34 | """, 35 | """ 36 | {x: 1 for x in range(10)} 37 | """, 38 | ), 39 | ( # GeneratorExp in 
ListComp 40 | """ 41 | {x for x in [y for y in range(10)]} 42 | """, 43 | """ 44 | {x for x in range(10)} 45 | """, 46 | ), 47 | ( # GeneratorExp in GeneratorExp 48 | """ 49 | (x for x in (y for y in range(10))) 50 | """, 51 | """ 52 | (x for x in range(10)) 53 | """, 54 | ), 55 | ( # SetComp in SetComp 56 | """ 57 | {x for x in {y for y in range(10)}} 58 | """, 59 | """ 60 | {x for x in range(10)} 61 | """, 62 | ), 63 | ( # SetComp in SetComp, wrong name 64 | """ 65 | {x for x in {h for y in range(10)}} 66 | """, 67 | """ 68 | {x for x in {h for y in range(10)}} 69 | """, 70 | ), 71 | ( # SetComp in SetComp, non-trivial target 72 | """ 73 | {x for x in {y + 1 for y in range(10)}} 74 | """, 75 | """ 76 | {x for x in {y + 1 for y in range(10)}} 77 | """, 78 | ), 79 | ( # SetComp in SetComp, non-trivial iter 80 | """ 81 | {x for x in {y + 1 for y, z in range(10)}} 82 | """, 83 | """ 84 | {x for x in {y + 1 for y, z in range(10)}} 85 | """, 86 | ), 87 | ( # SetComp in ListComp 88 | """ 89 | [x for x in {y for y in range(10)}] 90 | """, 91 | """ 92 | [x for x in {y for y in range(10)}] 93 | """, 94 | ), 95 | ( # SetComp in GeneratorExp 96 | """ 97 | (x for x in {y for y in range(10)}) 98 | """, 99 | """ 100 | (x for x in {y for y in range(10)}) 101 | """, 102 | ), 103 | ( # SetComp in DictComp 104 | """ 105 | {x: 99 for x in {y for y in range(10)}} 106 | """, 107 | """ 108 | {x: 99 for x in range(10)} 109 | """, 110 | ), 111 | ( # DictComp in DictComp 112 | """ 113 | {x: 99 for x in {y: 3131 for y in range(10)}} 114 | """, 115 | """ 116 | {x: 99 for x in range(10)} 117 | """, 118 | ), 119 | ( # DictComp in DictComp 120 | """ 121 | {x: 99 for x in {3131: y for y in range(10)}} 122 | """, 123 | """ 124 | {x: 99 for x in {3131: y for y in range(10)}} 125 | """, 126 | ), 127 | ( # DictComp in DictComp 128 | """ 129 | {x: 99 for x in {z: y for y in range(10)}} 130 | """, 131 | """ 132 | {x: 99 for x in {z: y for y in range(10)}} 133 | """, 134 | ),) 135 | 136 | for source, 
expected_abstraction in test_cases: 137 | processed_content = fixes.merge_nested_comprehensions(source) 138 | 139 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 140 | return 1 141 | 142 | return 0 143 | 144 | 145 | if __name__ == "__main__": 146 | sys.exit(main()) 147 | -------------------------------------------------------------------------------- /tests/unit/test_missing_context_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | import requests 17 | 18 | session = requests.Session() 19 | session.get("https://www.google.com") 20 | """, 21 | """ 22 | import requests 23 | 24 | with requests.Session() as session: 25 | session.get("https://www.google.com") 26 | """, 27 | ), 28 | ( 29 | """ 30 | x = open("path/to/file.py") 31 | print(x.read()) 32 | 33 | x.close() 34 | """, 35 | """ 36 | with open("path/to/file.py") as x: 37 | print(x.read()) 38 | """, 39 | ), 40 | ( 41 | """ 42 | with open("path/to/file.py") as x: 43 | print(x.read()) 44 | """, 45 | """ 46 | with open("path/to/file.py") as x: 47 | print(x.read()) 48 | """, 49 | ), 50 | ( 51 | """ 52 | @app.route('/capacities', methods=['GET']) 53 | @cross_origin() 54 | def get_capacity(): 55 | connection = psycopg2.connect(DATABASE_URI) 56 | cursor = connection.cursor() 57 | cursor.execute("SELECT id, name FROM capacity") 58 | capacities = [{"id": row[0], "name": row[1]} for row in cursor.fetchall()] 59 | 60 | return jsonify(capacities) 61 | """, 62 | """ 63 | @app.route('/capacities', methods=['GET']) 64 | @cross_origin() 65 | def get_capacity(): 66 | with psycopg2.connect(DATABASE_URI) as connection: 67 | with connection.cursor() as cursor: 68 | cursor.execute("SELECT id, name 
FROM capacity") 69 | capacities = [{"id": row[0], "name": row[1]} for row in cursor.fetchall()] 70 | 71 | return jsonify(capacities) 72 | """, 73 | ),) 74 | 75 | for source, expected_abstraction in test_cases: 76 | processed_content = fixes.missing_context_manager(source) 77 | 78 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 79 | return 1 80 | 81 | return 0 82 | 83 | 84 | if __name__ == "__main__": 85 | sys.exit(main()) 86 | -------------------------------------------------------------------------------- /tests/unit/test_move_imports_to_toplevel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | #!/usr/bin/env python3 16 | '''docstring''' 17 | import time 18 | import os 19 | import sys 20 | 21 | sys.path.append(os.getcwd()) 22 | from somewhere import something 23 | 24 | def function_call(): 25 | from somewhere import something_else 26 | return something_else() 27 | 28 | def call2(): 29 | from somewhere_else import qwerty 30 | return qwerty() 31 | 32 | def call3(): 33 | import math 34 | print(math.sum([3])) 35 | """, 36 | """ 37 | #!/usr/bin/env python3 38 | '''docstring''' 39 | import time 40 | import os 41 | import sys 42 | import math 43 | 44 | sys.path.append(os.getcwd()) 45 | from somewhere import something_else 46 | from somewhere import something 47 | 48 | def function_call(): 49 | return something_else() 50 | 51 | def call2(): 52 | from somewhere_else import qwerty 53 | return qwerty() 54 | 55 | def call3(): 56 | print(math.sum([3])) 57 | """, 58 | ), 59 | ( 60 | """ 61 | import os 62 | from numpy import ( 63 | integer, ndarray, dtype as _dtype, asarray, frombuffer 64 | ) 65 | from numpy.core.multiarray import _flagdict, flagsobj 66 
| def foo(): 67 | from numpy import intp as c_intp 68 | print(199) 69 | def woo(): 70 | from numpy import intp as c_intp 71 | print(199) 72 | def hoo(): 73 | from numpy import intp as c_intp 74 | print(199) 75 | """, 76 | """ 77 | from numpy import intp as c_intp 78 | import os 79 | from numpy import ( 80 | integer, ndarray, dtype as _dtype, asarray, frombuffer 81 | ) 82 | from numpy.core.multiarray import _flagdict, flagsobj 83 | def foo(): 84 | print(199) 85 | def woo(): 86 | print(199) 87 | def hoo(): 88 | print(199) 89 | """, 90 | ), 91 | ( 92 | """ 93 | import sys 94 | if sys.version_info >= (3, 11): 95 | from tomllib import load # pyrefact: ignore 96 | else: 97 | from tomli import load # pyrefact: ignore 98 | """, 99 | """ 100 | import sys 101 | if sys.version_info >= (3, 11): 102 | from tomllib import load # pyrefact: ignore 103 | else: 104 | from tomli import load # pyrefact: ignore 105 | """, 106 | ), 107 | ( 108 | """ 109 | import sys 110 | if sys.version_info >= (3, 11): 111 | import re 112 | from tomllib import load # pyrefact: ignore 113 | else: 114 | from tomli import load # pyrefact: ignore 115 | """, 116 | """ 117 | import re 118 | import sys 119 | if sys.version_info >= (3, 11): 120 | from tomllib import load # pyrefact: ignore 121 | else: 122 | from tomli import load # pyrefact: ignore 123 | """, 124 | ),) 125 | 126 | for source, expected_abstraction in test_cases: 127 | processed_content = fixes.move_imports_to_toplevel(source) 128 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 129 | return 1 130 | 131 | return 0 132 | 133 | 134 | if __name__ == "__main__": 135 | sys.exit(main()) 136 | -------------------------------------------------------------------------------- /tests/unit/test_optimize_contains_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import performance 8 | 
9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 14 | test_cases = (( 15 | """ 16 | 1 in (1, 2, 3) 17 | x in {1, 2, ()} 18 | x in [1, 2, []] 19 | w in [1, 2, {}] 20 | w in {foo, bar, "asdf", coo} 21 | w in (foo, bar, "asdf", coo) 22 | w in {x for x in range(10)} 23 | w in [x for x in range(10)] 24 | w in (x for x in range(10)) 25 | w in {x for x in [1, 3, "", 909, ()]} 26 | w in [x for x in [1, 3, "", 909, ()]] 27 | w in (x for x in [1, 3, "", 909, ()]) 28 | x in sorted([1, 2, 3]) 29 | """, 30 | """ 31 | 1 in {1, 2, 3} 32 | x in {1, 2, ()} 33 | x in (1, 2, []) 34 | w in (1, 2, {}) 35 | w in {foo, bar, "asdf", coo} 36 | w in (foo, bar, "asdf", coo) 37 | w in (x for x in range(10)) 38 | w in (x for x in range(10)) 39 | w in (x for x in range(10)) 40 | w in (x for x in [1, 3, "", 909, ()]) 41 | w in (x for x in [1, 3, "", 909, ()]) 42 | w in (x for x in [1, 3, "", 909, ()]) 43 | x in {1, 2, 3} 44 | """, 45 | ),) 46 | 47 | for source, expected_abstraction in test_cases: 48 | processed_content = performance.optimize_contains_types(source) 49 | 50 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 51 | return 1 52 | 53 | return 0 54 | 55 | 56 | if __name__ == "__main__": 57 | sys.exit(main()) 58 | -------------------------------------------------------------------------------- /tests/unit/test_pattern_zeroormore_zeroorone_zeroormany.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from pyrefact.core import match_template, ZeroOrOne, ZeroOrMany, OneOrMany 4 | 5 | 6 | class TestZeroOrOne(unittest.TestCase):  # ZeroOrOne(T) should match an empty list or a single T-matching element. 7 | def test(self): 8 | assert match_template([], [ZeroOrOne(object)]) 9 | assert match_template([object], [ZeroOrOne(object)]) 10 | assert not match_template([object, "asdf"], [ZeroOrOne(object)]) 11 | 12 | assert match_template([], [ZeroOrOne(str)]) 13 | assert match_template([""], [ZeroOrOne(str)]) 14 | assert
not match_template(["", "asdf"], [ZeroOrOne(str)]) 15 | 16 | def runTest(self): 17 | self.test() 18 | 19 | 20 | class TestZeroOrMany(unittest.TestCase): 21 | def test(self): 22 | assert match_template([], [ZeroOrMany(object)]) 23 | assert match_template([object], [ZeroOrMany(object)]) 24 | assert match_template([object, object], [ZeroOrMany(object)]) 25 | 26 | assert match_template([], [ZeroOrMany(object)]) 27 | assert match_template([""], [ZeroOrMany(str)]) 28 | assert match_template(["a", "b"], [ZeroOrMany(object)]) 29 | 30 | def runTest(self): 31 | self.test() 32 | 33 | 34 | class TestOneOrMany(unittest.TestCase): 35 | def test(self): 36 | assert not match_template([], [OneOrMany(object)]) 37 | assert match_template([object], [OneOrMany(object)]) 38 | assert match_template([object, object], [OneOrMany(object)]) 39 | 40 | assert not match_template([], [OneOrMany(object)]) 41 | assert match_template([object], [OneOrMany(object)]) 42 | assert match_template([1, 31], [OneOrMany(int)]) 43 | 44 | def runTest(self): 45 | self.test() 46 | 47 | 48 | class TestCombination(unittest.TestCase): 49 | def test(self): 50 | assert match_template([], [ZeroOrMany(object), ZeroOrOne(object)]) 51 | assert match_template([object], [ZeroOrMany(object), ZeroOrOne(object)]) 52 | assert match_template([object, object], [ZeroOrMany(object), ZeroOrOne(object)]) 53 | assert match_template(["qwerty", "asdf", object], [ZeroOrMany(str), ZeroOrOne(object)]) 54 | 55 | assert not match_template([], [ZeroOrMany(object), object]) 56 | assert match_template([object], [ZeroOrMany(object), object]) 57 | assert not match_template([object], [OneOrMany(object), object]) 58 | assert match_template([object], [ZeroOrOne(object), object]) 59 | assert not match_template([object], [ZeroOrOne(object), str]) 60 | assert match_template([22], [ZeroOrOne(str), int]) 61 | assert not match_template([], [ZeroOrMany(object), ZeroOrOne(object), OneOrMany(object)]) 62 | 63 | 64 | def main() -> int: 65 | # For use with 
./tests/main.py, which looks for these main functions. 66 | # unittest.main() will do sys.exit() or something, it quits the whole 67 | # program after done and prevents further tests from running. 68 | test_result = TestZeroOrOne().run() 69 | if not test_result.wasSuccessful(): 70 | test_result.printErrors() 71 | return 1 72 | 73 | test_result = TestZeroOrMany().run() 74 | if not test_result.wasSuccessful(): 75 | test_result.printErrors() 76 | return 1 77 | 78 | test_result = TestOneOrMany().run() 79 | if not test_result.wasSuccessful(): 80 | test_result.printErrors() 81 | return 1 82 | 83 | return 0 84 | 85 | 86 | if __name__ == "__main__": 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /tests/unit/test_redundant_elses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | def f(x: int) -> int: 17 | if x == 2: 18 | return 10 19 | else: 20 | return 11 - x 21 | """, 22 | """ 23 | def f(x: int) -> int: 24 | if x == 2: 25 | return 10 26 | 27 | return 11 - x 28 | """, 29 | ), 30 | ( 31 | """ 32 | def f(x: int) -> int: 33 | if x == 2: 34 | return 10 35 | elif x == 12: 36 | return x**x - 3 37 | else: 38 | return 11 - x 39 | """, 40 | """ 41 | def f(x: int) -> int: 42 | if x == 2: 43 | return 10 44 | if x == 12: 45 | return x**x - 3 46 | 47 | return 11 - x 48 | """, 49 | ), 50 | ( 51 | """ 52 | def f(x: int) -> int: 53 | if x < 0: 54 | if x > -100: 55 | return 10 56 | else: 57 | return 101 58 | elif x >= 12: 59 | if x ** 2 >= 99: 60 | return x**x - 3 61 | elif x ** 3 >= 99: 62 | return x**2 63 | else: 64 | return 0 65 | else: 66 | return 11 - x 67 | """, 68 | """ 69 | def f(x: int) -> int: 70 | if x < 0: 71 | if x > -100: 
72 | return 10 73 | 74 | return 101 75 | if x >= 12: 76 | if x ** 2 >= 99: 77 | return x**x - 3 78 | if x ** 3 >= 99: 79 | return x**2 80 | 81 | return 0 82 | 83 | return 11 - x 84 | """, 85 | ), 86 | ( 87 | """ 88 | for i in range(10): 89 | if i == 3: 90 | continue 91 | else: 92 | print(2) 93 | """, 94 | """ 95 | for i in range(10): 96 | if i == 3: 97 | continue 98 | 99 | print(2) 100 | """, 101 | ), 102 | (  # `while True` with no `break` never completes normally, so the else is redundant 103 | """ 104 | for i in range(10): 105 | if i == 3: 106 | while True: 107 | print(1) 108 | time.sleep(3) 109 | else: 110 | print(2) 111 | """, 112 | """ 113 | for i in range(10): 114 | if i == 3: 115 | while True: 116 | print(1) 117 | time.sleep(3) 118 | 119 | print(2) 120 | """, 121 | ), 122 | (  # the `break` can end the while loop, so the if-branch may fall through and the else is kept 123 | """ 124 | for i in range(10): 125 | if i == 3: 126 | while True: 127 | print(1) 128 | time.sleep(3) 129 | break 130 | else: 131 | print(2) 132 | """, 133 | """ 134 | for i in range(10): 135 | if i == 3: 136 | while True: 137 | print(1) 138 | time.sleep(3) 139 | break 140 | else: 141 | print(2) 142 | """, 143 | ), 144 | ( 145 | """ 146 | def foo() -> bool: 147 | if x == 1: 148 | return False 149 | elif x == 2: 150 | return True 151 | else: 152 | if z: 153 | return False 154 | else: 155 | return True 156 | """, 157 | """ 158 | def foo() -> bool: 159 | if x == 1: 160 | return False 161 | if x == 2: 162 | return True 163 | if z: 164 | return False 165 | return True 166 | """, 167 | ), # First pass 168 | ( 169 | """ 170 | def foo() -> bool: 171 | if x == 1: 172 | return False 173 | if x == 2: 174 | return True 175 | if z: 176 | return False 177 | else: 178 | return True 179 | """, 180 | """ 181 | def foo() -> bool: 182 | if x == 1: 183 | return False 184 | if x == 2: 185 | return True 186 | if z: 187 | return False 188 | 189 | return True 190 | """, 191 | ), # Second pass 192 | ) 193 | 194 | for source, expected_abstraction in test_cases: 195 | processed_content = fixes.remove_redundant_else(source) 196 | 197 | if not
testing_infra.check_fixes_equal(processed_content, expected_abstraction): 198 | return 1 199 | 200 | return 0 201 | 202 | 203 | if __name__ == "__main__": 204 | sys.exit(main()) 205 | -------------------------------------------------------------------------------- /tests/unit/test_redundant_enumerate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | (x for _, x in enumerate(y)) 16 | """, 17 | """ 18 | (x for x in y) 19 | """, 20 | ), 21 | (  # the index is bound to a name other than `_`, so enumerate is kept 22 | """ 23 | (x for i, x in enumerate(y)) 24 | """, 25 | """ 26 | (x for i, x in enumerate(y)) 27 | """, 28 | ), 29 | ( 30 | """ 31 | for _, x in enumerate(y): 32 | print(100 * x) 33 | """, 34 | """ 35 | for x in y: 36 | print(100 * x) 37 | """, 38 | ), 39 | ( 40 | """ 41 | for i, x in enumerate(y): 42 | print(100 * x) 43 | """, 44 | """ 45 | for i, x in enumerate(y): 46 | print(100 * x) 47 | """, 48 | ),) 49 | 50 | for source, expected_abstraction in test_cases: 51 | processed_content = fixes.redundant_enumerate(source) 52 | if not testing_infra.check_fixes_equal( 53 | processed_content, expected_abstraction, clear_paranthesises=True 54 | ): 55 | return 1 56 | 57 | return 0 58 | 59 | 60 | if __name__ == "__main__": 61 | sys.exit(main()) 62 | -------------------------------------------------------------------------------- /tests/unit/test_remove_dead_ifs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | if False: 16 | print(3) 17 | if True: 18 |
print(2) 19 | """, 20 | """ 21 | print(2) 22 | """, 23 | ), 24 | (  # conditions on empty literals are always falsy, so both ifs are removed 25 | """ 26 | if (): 27 | print(3) 28 | if []: 29 | print(2) 30 | """, 31 | """ 32 | """, 33 | ), 34 | ( 35 | """ 36 | x = 0 37 | if x == 3: 38 | print(x) 39 | if []: 40 | print(2) 41 | else: 42 | print(x + x) 43 | if [1]: 44 | print(222222) 45 | else: 46 | print(x ** x) 47 | """, 48 | """ 49 | x = 0 50 | if x == 3: 51 | print(x) 52 | 53 | print(x + x) 54 | print(222222) 55 | """, 56 | ), 57 | (  # only the `while False` loop is removed; loops with other conditions are kept 58 | """ 59 | import sys 60 | while False: 61 | sys.exit(0) 62 | while sys.executable == "/usr/bin/python": 63 | print(7) 64 | while True: 65 | sys.exit(2) 66 | """, 67 | """ 68 | import sys 69 | while sys.executable == "/usr/bin/python": 70 | print(7) 71 | while True: 72 | sys.exit(2) 73 | """, 74 | ), 75 | (  # conditional expressions with a constant test collapse to the taken branch 76 | """ 77 | x = 13 78 | a = x if x > 3 else 0 79 | b = x if True else 0 80 | c = x if False else 2 81 | d = 13 if () else {2: 3} 82 | e = 14 if list((1, 2, 3)) else 13 83 | print(3 if 2 > 0 else 2) 84 | print(14 if False else 2) 85 | """, 86 | """ 87 | x = 13 88 | a = x if x > 3 else 0 89 | b = x 90 | c = 2 91 | d = {2: 3} 92 | e = 14 93 | print(3) 94 | print(2) 95 | """, 96 | ), 97 | ( 98 | """ 99 | y = (i for i in range(11) if not ()) 100 | """, 101 | """ 102 | y = (i for i in range(11)) 103 | """, 104 | ), 105 | ( 106 | """ 107 | y = (i for i in range(11) if 7 and not () if foo() if bar() if baz() and wombat() ** 3) 108 | """, 109 | """ 110 | y = (i for i in range(11) if foo() if bar() if baz() and wombat() ** 3) 111 | """, 112 | ), 113 | ( 114 | """ 115 | y = (i for i in range(11) if ()) 116 | """, 117 | """ 118 | y = () 119 | """, 120 | ), 121 | ( 122 | """ 123 | y = [i for i in range(11) if ()] 124 | """, 125 | """ 126 | y = [] 127 | """, 128 | ), 129 | ( 130 | """ 131 | y = {i for i in range(11) if ()} 132 | """, 133 | """ 134 | y = set() 135 | """, 136 | ), 137 | ( 138 | """ 139 | y = {i: i for i in range(11) if ()} 140 | """, 141 | """ 142 | y = {} 143 | """, 144 | ),) 145 | 146 | for
source, expected_abstraction in test_cases: 147 | processed_content = fixes.remove_dead_ifs(source) 148 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 149 | return 1 150 | 151 | return 0 152 | 153 | 154 | if __name__ == "__main__": 155 | sys.exit(main()) 156 | -------------------------------------------------------------------------------- /tests/unit/test_remove_duplicate_dict_keys.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 14 | test_cases = ((  # only the literal key 1 is deduplicated (its later value is kept); repeated call keys such as sum(range(11)) stay 15 | """ 16 | {1: 2, 99: 101, "s": 4, 1: 22, sum(range(11)): 9999, sum(range(11)): 9999} 17 | """, 18 | """ 19 | {99: 101, "s": 4, 1: 22, sum(range(11)): 9999, sum(range(11)): 9999} 20 | """, 21 | ),) 22 | 23 | for source, expected_abstraction in test_cases: 24 | processed_content = fixes.remove_duplicate_dict_keys(source) 25 | 26 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 27 | return 1 28 | 29 | return 0 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(main()) 34 | -------------------------------------------------------------------------------- /tests/unit/test_remove_duplicate_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | def f(a, b, c): 16 | return 1, 2, 3 17 | def g(a, b, c): 18 | return 1, 2, 3 19 | y = f(1, 2, 3) 20 | h = g(1, 2, 3) 21 | """, 22 | """ 23 | def f(a, b, c): 24 | return 1, 2, 3 25 | y = f(1, 2, 3) 26 | h = f(1, 2, 3) 27 | """, 28 | ), 29 | ( 30 | """ 31
| def f(a, b, c): 32 | w = a ** (b - c) 33 | return 1 + w // 2 34 | def g(c, b, k): 35 | w = c ** (b - k) 36 | return 1 + w // 2 37 | y = f(1, 2, 3) 38 | h = g(1, 2, 3) 39 | """, 40 | """ 41 | def f(a, b, c): 42 | w = a ** (b - c) 43 | return 1 + w // 2 44 | y = f(1, 2, 3) 45 | h = f(1, 2, 3) 46 | """, 47 | ), 48 | (  # the bodies differ (`1 + w // 2` vs `1 - w // 2`), so both functions are kept 49 | """ 50 | def f(a, b, c): 51 | w = a ** (b - c) 52 | return 1 + w // 2 53 | def g(c, b, k): 54 | w = c ** (b - k) 55 | return 1 - w // 2 56 | y = f(1, 2, 3) 57 | h = g(1, 2, 3) 58 | """, 59 | """ 60 | def f(a, b, c): 61 | w = a ** (b - c) 62 | return 1 + w // 2 63 | def g(c, b, k): 64 | w = c ** (b - k) 65 | return 1 - w // 2 66 | y = f(1, 2, 3) 67 | h = g(1, 2, 3) 68 | """, 69 | ), 70 | (  # same parameter names but swapped operands (b - c vs c - b), so the functions are not duplicates 71 | """ 72 | def f(a, b, c): 73 | w = a ** (b - c) 74 | return 1 + w // 2 75 | def g(a, b, c): 76 | w = a ** (c - b) 77 | return 1 + w // 2 78 | y = f(1, 2, 3) 79 | h = g(1, 2, 3) 80 | """, 81 | """ 82 | def f(a, b, c): 83 | w = a ** (b - c) 84 | return 1 + w // 2 85 | def g(a, b, c): 86 | w = a ** (c - b) 87 | return 1 + w // 2 88 | y = f(1, 2, 3) 89 | h = g(1, 2, 3) 90 | """, 91 | ),) 92 | 93 | for source, expected_abstraction in test_cases: 94 | processed_content = fixes.remove_duplicate_functions(source, set()) 95 | 96 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 97 | return 1 98 | 99 | return 0 100 | 101 | 102 | if __name__ == "__main__": 103 | sys.exit(main()) 104 | -------------------------------------------------------------------------------- /tests/unit/test_remove_duplicate_set_elts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 14 | test_cases = ((  # the repeated literal 1 is removed; repeated calls such as sum(range(11)) are kept 15 | """ 16 | {1, 99, "s", 1, sum(range(11)), sum(range(11))} 17 | """, 18 | """ 19 | {1, 99, "s",
sum(range(11)), sum(range(11))} 20 | """, 21 | ),) 22 | 23 | for source, expected_abstraction in test_cases: 24 | processed_content = fixes.remove_duplicate_set_elts(source) 25 | 26 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 27 | return 1 28 | 29 | return 0 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(main()) 34 | -------------------------------------------------------------------------------- /tests/unit/test_remove_redundant_boolop_values.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | 1 or 2 or 3 16 | 0 or 4 or 5 17 | """, 18 | """ 19 | 1 20 | 4 21 | """, 22 | ), 23 | ( 24 | """ 25 | 1 and 2 and 3 26 | 0 and 4 and 5 27 | """, 28 | """ 29 | 3 30 | 0 31 | """, 32 | ), 33 | (  # a trailing falsy `or` operand is kept: it is the value returned when everything before it is falsy 34 | """ 35 | print(None or os.getcwd() or False) 36 | """, 37 | """ 38 | print(os.getcwd() or False) 39 | """, 40 | ), 41 | ( 42 | """ 43 | print(None or os.getcwd() or False or sys.path) 44 | """, 45 | """ 46 | print(os.getcwd() or sys.path) 47 | """, 48 | ), 49 | ( 50 | """ 51 | print(None and os.getcwd() and False) 52 | """, 53 | """ 54 | print(None) 55 | """, 56 | ), 57 | ( 58 | """ 59 | print(None and os.getcwd() and False and sys.path) 60 | """, 61 | """ 62 | print(None) 63 | """, 64 | ), 65 | ( 66 | """ 67 | print(os.getcwd() and False) 68 | """, 69 | """ 70 | print(os.getcwd() and False) 71 | """, 72 | ), 73 | ( 74 | """ 75 | print(os.getcwd() and sys.path) 76 | """, 77 | """ 78 | print(os.getcwd() and sys.path) 79 | """, 80 | ),) 81 | 82 | for source, expected_abstraction in test_cases: 83 | processed_content = fixes.remove_redundant_boolop_values(source) 84 | if not testing_infra.check_fixes_equal( 85 | processed_content,
expected_abstraction, clear_paranthesises=True 86 | ): 87 | return 1 88 | 89 | return 0 90 | 91 | 92 | if __name__ == "__main__": 93 | sys.exit(main()) 94 | -------------------------------------------------------------------------------- /tests/unit/test_remove_redundant_chain_casts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | set(itertools.chain(range(10), range(11))) 16 | """, 17 | """ 18 | {*range(10), *range(11)} 19 | """, 20 | ), 21 | ( 22 | """ 23 | set(itertools.chain()) 24 | """, 25 | """ 26 | set() 27 | """, 28 | ), 29 | ( 30 | """ 31 | tuple(itertools.chain(range(10), range(11))) 32 | """, 33 | """ 34 | (*range(10), *range(11)) 35 | """, 36 | ), 37 | ( 38 | """ 39 | tuple(itertools.chain()) 40 | """, 41 | """ 42 | () 43 | """, 44 | ), 45 | ( 46 | """ 47 | list(itertools.chain()) 48 | """, 49 | """ 50 | [] 51 | """, 52 | ), 53 | ( 54 | """ 55 | iter(itertools.chain()) 56 | """, 57 | """ 58 | iter(()) 59 | """, 60 | ), 61 | ( 62 | """ 63 | list(itertools.chain(range(10), range(11))) 64 | """, 65 | """ 66 | [*range(10), *range(11)] 67 | """, 68 | ), 69 | (  # itertools.chain already returns an iterator, so the outer iter() is dropped 70 | """ 71 | iter(itertools.chain(range(10), range(11))) 72 | """, 73 | """ 74 | itertools.chain(range(10), range(11)) 75 | """, 76 | ),) 77 | 78 | for source, expected_abstraction in test_cases: 79 | processed_content = fixes.remove_redundant_chain_casts(source) 80 | if not testing_infra.check_fixes_equal( 81 | processed_content, expected_abstraction, clear_paranthesises=True 82 | ): 83 | return 1 84 | 85 | return 0 86 | 87 | 88 | if __name__ == "__main__": 89 | sys.exit(main()) 90 | --------------------------------------------------------------------------------
/tests/unit/test_remove_redundant_chained_calls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = ((  # NOTE(review): not every pinned rewrite is an exact equivalence: sorted(x, reverse=True) keeps equal elements in their original order while reversed(sorted(x)) inverts them, and reordering float operands can change sum() 14 | """ 15 | sorted(reversed(v)) 16 | sorted(sorted(v)) 17 | sorted(iter(v)) 18 | sorted(tuple(v)) 19 | sorted(list(v)) 20 | list(iter(v)) 21 | list(tuple(v)) 22 | list(list(v)) 23 | list(sorted(v)) 24 | sorted(sorted(v)) 25 | set(set(v)) 26 | set(reversed(v)) 27 | set(sorted(v)) 28 | set(iter(v)) 29 | set(tuple(v)) 30 | set(list(v)) 31 | iter(iter(v)) 32 | iter(tuple(v)) 33 | iter(list(v)) 34 | reversed(tuple(v)) 35 | reversed(list(v)) 36 | tuple(iter(v)) 37 | tuple(tuple(v)) 38 | tuple(list(v)) 39 | sum(reversed(v)) 40 | sum(sorted(v)) 41 | sum(iter(v)) 42 | sum(tuple(v)) 43 | sum(list(v)) 44 | sorted(foo(list(foo(iter((foo(v))))))) 45 | reversed(sorted(foo)) 46 | reversed(sorted(asdf, reverse=True)) 47 | reversed(sorted(k, reverse=False)) 48 | reversed(sorted(k, reverse=foo() == 313)) 49 | """, 50 | """ 51 | sorted(v) 52 | sorted(v) 53 | sorted(v) 54 | sorted(v) 55 | sorted(v) 56 | list(v) 57 | list(v) 58 | list(v) 59 | sorted(v) 60 | sorted(v) 61 | set(v) 62 | set(v) 63 | set(v) 64 | set(v) 65 | set(v) 66 | set(v) 67 | iter(v) 68 | iter(v) 69 | iter(v) 70 | reversed(v) 71 | reversed(v) 72 | tuple(v) 73 | tuple(v) 74 | tuple(v) 75 | sum(v) 76 | sum(v) 77 | sum(v) 78 | sum(v) 79 | sum(v) 80 | sorted(foo(list(foo(iter((foo(v))))))) 81 | sorted(foo, reverse=True) 82 | sorted(asdf) 83 | sorted(k, reverse=True) 84 | sorted(k, reverse=not foo() == 313) 85 | """, 86 | ),) 87 | 88 | for source, expected_abstraction in test_cases: 89 | processed_content = performance.remove_redundant_chained_calls(source) 90 | 91 | if not testing_infra.check_fixes_equal(processed_content,
expected_abstraction): 92 | return 1 93 | 94 | return 0 95 | 96 | 97 | if __name__ == "__main__": 98 | sys.exit(main()) 99 | -------------------------------------------------------------------------------- /tests/unit/test_remove_redundant_comprehension_casts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | set({x for x in range(10)}) 16 | """, 17 | """ 18 | {x for x in range(10)} 19 | """, 20 | ), 21 | ( 22 | """ 23 | iter(x for x in range(10)) 24 | """, 25 | """ 26 | (x for x in range(10)) 27 | """, 28 | ), 29 | ( 30 | """ 31 | list((x for x in range(10))) 32 | """, 33 | """ 34 | [x for x in range(10)] 35 | """, 36 | ), 37 | ( 38 | """ 39 | list([x for y in range(10) for x in range(12 + y)]) 40 | """, 41 | """ 42 | [x for y in range(10) for x in range(12 + y)] 43 | """, 44 | ), 45 | (  # iterating a dict yields only its keys, so the dict comprehension is reduced to a set comprehension 46 | """ 47 | list({x: 100 for x in range(10)}) 48 | """, 49 | """ 50 | list({x for x in range(10)}) 51 | """, 52 | ), 53 | ( 54 | """ 55 | set({x: 100 for x in range(10)}) 56 | """, 57 | """ 58 | {x for x in range(10)} 59 | """, 60 | ), 61 | ( 62 | """ 63 | dict({x: 100 for x in range(10)}) 64 | """, 65 | """ 66 | {x: 100 for x in range(10)} 67 | """, 68 | ), 69 | ( 70 | """ 71 | iter({x: 100 for x in range(10)}) 72 | """, 73 | """ 74 | iter({x for x in range(10)}) 75 | """, 76 | ), 77 | ( 78 | """ 79 | list({x for x in range(10)}) 80 | """, 81 | """ 82 | list({x for x in range(10)}) 83 | """, 84 | ),) 85 | 86 | for source, expected_abstraction in test_cases: 87 | processed_content = fixes.remove_redundant_comprehension_casts(source) 88 | if not testing_infra.check_fixes_equal( 89 | processed_content, expected_abstraction, clear_paranthesises=True 90 | ): 91 | return 1 92 | 93 | return 0 94
| 95 | 96 | if __name__ == "__main__": 97 | sys.exit(main()) 98 | -------------------------------------------------------------------------------- /tests/unit/test_remove_redundant_comprehensions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 14 | test_cases = ((  # identity comprehensions become constructor calls; filtered, swapped-target or constant-element ones are kept 15 | """ 16 | a = {x: y for x, y in zip(range(4), range(1, 5))} 17 | b = [w for w in (1, 2, 3, 99)] 18 | c = {v for v in [1, 2, 3]} 19 | d = (u for u in (1, 2, 3, 5)) 20 | aa = (1 for u in (1, 2, 3, 5)) 21 | ww = {x: y for x, y in zip((1, 2, 3), range(3)) if x > y > 1} 22 | ww = {x: y for y, x in zip((1, 2, 3), range(3))} 23 | """, 24 | """ 25 | a = dict(zip(range(4), range(1, 5))) 26 | b = list((1, 2, 3, 99)) 27 | c = set([1, 2, 3]) 28 | d = iter((1, 2, 3, 5)) 29 | aa = (1 for u in (1, 2, 3, 5)) 30 | ww = {x: y for x, y in zip((1, 2, 3), range(3)) if x > y > 1} 31 | ww = {x: y for y, x in zip((1, 2, 3), range(3))} 32 | """, 33 | ),) 34 | 35 | for source, expected_abstraction in test_cases: 36 | processed_content = fixes.remove_redundant_comprehensions(source) 37 | 38 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 39 | return 1 40 | 41 | return 0 42 | 43 | 44 | if __name__ == "__main__": 45 | sys.exit(main()) 46 | -------------------------------------------------------------------------------- /tests/unit/test_remove_redundant_iter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | for key in (1, 2, 3): 16 |
print(key) 17 | """, 18 | """ 19 | for key in (1, 2, 3): 20 | print(key) 21 | """, 22 | ), 23 | ( 24 | """ 25 | for key in list((1, 2, 3)): 26 | print(key) 27 | """, 28 | """ 29 | for key in (1, 2, 3): 30 | print(key) 31 | """, 32 | ), 33 | ( 34 | """ 35 | for q in tuple((1, 2, 3)): 36 | print(q) 37 | """, 38 | """ 39 | for q in (1, 2, 3): 40 | print(q) 41 | """, 42 | ), 43 | ( 44 | """ 45 | values = (1, 2, 3) 46 | for q in list(values): 47 | print(q) 48 | """, 49 | """ 50 | values = (1, 2, 3) 51 | for q in values: 52 | print(q) 53 | """, 54 | ), 55 | (  # sorted() changes the iteration order, so it is not removed 56 | """ 57 | values = (1, 2, 3) 58 | for q in sorted(values): 59 | print(q) 60 | """, 61 | """ 62 | values = (1, 2, 3) 63 | for q in sorted(values): 64 | print(q) 65 | """, 66 | ), 67 | ( 68 | """ 69 | values = range(50) 70 | w = [x for x in list(values)] 71 | print(w) 72 | """, 73 | """ 74 | values = range(50) 75 | w = [x for x in values] 76 | print(w) 77 | """, 78 | ), 79 | ( 80 | """ 81 | values = range(50) 82 | w = [x for x in iter(values)] 83 | print(w) 84 | """, 85 | """ 86 | values = range(50) 87 | w = [x for x in values] 88 | print(w) 89 | """, 90 | ), 91 | ( 92 | """ 93 | values = range(50) 94 | w = {x for x in list(values)} 95 | print(w) 96 | """, 97 | """ 98 | values = range(50) 99 | w = {x for x in values} 100 | print(w) 101 | """, 102 | ), 103 | ( 104 | """ 105 | values = range(50) 106 | w = (x for x in list(values)) 107 | print(w) 108 | """, 109 | """ 110 | values = range(50) 111 | w = (x for x in values) 112 | print(w) 113 | """, 114 | ),) 115 | 116 | for source, expected_abstraction in test_cases: 117 | processed_content = performance.remove_redundant_iter(source) 118 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 119 | return 1 120 | 121 | return 0 122 | 123 | 124 | if __name__ == "__main__": 125 | sys.exit(main()) 126 | -------------------------------------------------------------------------------- /tests/unit/test_remove_unused_imports.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | import numpy 16 | """, 17 | """ 18 | """, 19 | ), 20 | ( 21 | """ 22 | import numpy as np, pandas as pd 23 | print(pd) 24 | """, 25 | """ 26 | import pandas as pd 27 | print(pd) 28 | """, 29 | ), 30 | ( 31 | """ 32 | from . import a, c 33 | c(2) 34 | """, 35 | """ 36 | from . import c 37 | c(2) 38 | """, 39 | ), 40 | ( 41 | """ 42 | from ... import a, c 43 | c(2) 44 | """, 45 | """ 46 | from ... import c 47 | c(2) 48 | """, 49 | ), 50 | ( 51 | """ 52 | from ....af.qwerty import a, b, c as d, q, w as f 53 | print(a, b, d) 54 | """, 55 | """ 56 | from ....af.qwerty import a, b, c as d 57 | print(a, b, d) 58 | """, 59 | ),) 60 | 61 | for source, expected_abstraction in test_cases: 62 | processed_content = fixes.remove_unused_imports(source) 63 | if not testing_infra.check_fixes_equal( 64 | processed_content, expected_abstraction, clear_paranthesises=True 65 | ): 66 | return 1 67 | 68 | return 0 69 | 70 | 71 | if __name__ == "__main__": 72 | sys.exit(main()) 73 | -------------------------------------------------------------------------------- /tests/unit/test_replace_collection_add_update_with_collection_literal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | x = {1, 2, 3} 16 | x.add(7) 17 | x.add(191) 18 | """, 19 | """ 20 | x = {1, 2, 3, 7, 191} 21 | """, 22 | ), 23 | ( 24 | """ 25 | x = [1, 2, 3] 26 | x.append(7) 27 | x.append(191) 28 |
""", 29 | """ 30 | x = [1, 2, 3, 7, 191] 31 | """, 32 | ), 33 | ( # simplify_collection_unpacks is assumed to run after this 34 | """ 35 | x = {1, 2, 3} 36 | x.update((7, 22)) 37 | x.update((191, 191)) 38 | """, 39 | """ 40 | x = {1, 2, 3, 7, 22, 191, 191} 41 | """, 42 | ), 43 | ( 44 | """ 45 | x = [1, 2, 3] 46 | x.extend((7, 22)) 47 | x.extend(foo) 48 | """, 49 | """ 50 | x = [1, 2, 3, 7, 22, *foo] 51 | """, 52 | ), 53 | ( 54 | """ 55 | f = [1, 2, 3] 56 | f.extend(foo, bar) 57 | f.extend((7, 22)) 58 | """, 59 | """ 60 | f = [1, 2, 3, *foo, *bar, 7, 22] 61 | """, 62 | ), 63 | ( 64 | """ 65 | f = [1, 2, 3] 66 | x.extend(foo, bar) 67 | x.extend((7, 22)) 68 | """, 69 | """ 70 | f = [1, 2, 3] 71 | x.extend(foo, bar) 72 | x.extend((7, 22)) 73 | """, 74 | ), 75 | ( 76 | """ 77 | f = set() 78 | f.update(foo, bar) 79 | f.update((7, 22)) 80 | """, 81 | """ 82 | f = {*foo, *bar, 7, 22} 83 | """, 84 | ),) 85 | 86 | for source, expected_abstraction in test_cases: 87 | processed_content = fixes.replace_collection_add_update_with_collection_literal(source) 88 | if not testing_infra.check_fixes_equal( 89 | processed_content, expected_abstraction, clear_paranthesises=True 90 | ): 91 | return 1 92 | 93 | return 0 94 | 95 | 96 | if __name__ == "__main__": 97 | sys.exit(main()) 98 | -------------------------------------------------------------------------------- /tests/unit/test_replace_dict_assign_with_dict_literal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = {} 16 | x[10] = 100 17 | """, 18 | """ 19 | x = {10: 100} 20 | """, 21 | ), 22 | ( 23 | """ 24 | x = {} 25 | x[10] = 100 26 | x[101] = 220 27 | x[103] = 223 28 | """, 29 | """ 30 | x = {10: 100, 101: 220, 103: 223} 31 | """, 32 | 
), 33 | (  # x[103] = 223 is appended after the unpacked {..., 103: 909}, so the later value still wins in the merged literal 34 | """ 35 | x = {5: 13, **{102: 101, 103: 909}, 19: 14} 36 | x[10] = 100 37 | x[101] = 220 38 | x[103] = 223 39 | """, 40 | """ 41 | x = {5: 13, **{102: 101, 103: 909}, 19: 14, 10: 100, 101: 220, 103: 223} 42 | """, 43 | ),) 44 | 45 | for source, expected_abstraction in test_cases: 46 | processed_content = fixes.replace_dict_assign_with_dict_literal(source) 47 | if not testing_infra.check_fixes_equal( 48 | processed_content, expected_abstraction, clear_paranthesises=True 49 | ): 50 | return 1 51 | 52 | return 0 53 | 54 | 55 | if __name__ == "__main__": 56 | sys.exit(main()) 57 | -------------------------------------------------------------------------------- /tests/unit/test_replace_dict_update_with_dict_literal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int:  # Returns 1 at the first case that fails testing_infra.check_fixes_equal, otherwise 0. 13 | test_cases = (( 14 | """ 15 | x = {} 16 | x.update({10: 100}) 17 | """, 18 | """ 19 | x = {**{10: 100}} 20 | """, 21 | ), 22 | ( 23 | """ 24 | x = {} 25 | x.update({10: 100}) 26 | x.update({101: 220}) 27 | x.update({103: 223}) 28 | """, 29 | """ 30 | x = {**{10: 100}, **{101: 220}, **{103: 223}} 31 | """, 32 | ), 33 | ( 34 | """ 35 | x = {5: 13, **{102: 101, 103: 909}, 19: 14} 36 | x.update({10: 100, 101: 220, 103: 223}) 37 | """, 38 | """ 39 | x = {5: 13, **{102: 101, 103: 909}, 19: 14, **{10: 100, 101: 220, 103: 223}} 40 | """, 41 | ),) 42 | 43 | for source, expected_abstraction in test_cases: 44 | processed_content = fixes.replace_dict_update_with_dict_literal(source) 45 | if not testing_infra.check_fixes_equal( 46 | processed_content, expected_abstraction, clear_paranthesises=True 47 | ): 48 | return 1 49 | 50 | return 0 51 | 52 | 53 | if __name__ == "__main__": 54 | sys.exit(main()) 55 |
-------------------------------------------------------------------------------- /tests/unit/test_replace_dictcomp_assign_with_dict_literal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = {z: 21 for z in range(3)} 16 | x[10] = 100 17 | """, 18 | """ 19 | x = {**{z: 21 for z in range(3)}, 10: 100} 20 | """, 21 | ), 22 | ( 23 | """ 24 | x = {z: 21 for z in range(3)} 25 | x[10] = 100 26 | x[101] = 220 27 | x[103] = 223 28 | """, 29 | """ 30 | x = {**{z: 21 for z in range(3)}, 10: 100, 101: 220, 103: 223} 31 | """, 32 | ),) 33 | 34 | for source, expected_abstraction in test_cases: 35 | processed_content = fixes.replace_dictcomp_assign_with_dict_literal(source) 36 | if not testing_infra.check_fixes_equal( 37 | processed_content, expected_abstraction, clear_paranthesises=True 38 | ): 39 | return 1 40 | 41 | return 0 42 | 43 | 44 | if __name__ == "__main__": 45 | sys.exit(main()) 46 | -------------------------------------------------------------------------------- /tests/unit/test_replace_dictcomp_update_with_dict_literal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = {z: 21 for z in range(3)} 16 | x.update({10: 100}) 17 | """, 18 | """ 19 | x = {**{z: 21 for z in range(3)}, **{10: 100}} 20 | """, 21 | ), 22 | ( 23 | """ 24 | x = {z: 21 for z in range(3)} 25 | x.update({10: 100}) 26 | x.update({101: 220}) 27 | x.update({103: 223}) 28 | """, 29 | """ 30 | x = {**{z: 21 for z in 
range(3)}, **{10: 100}, **{101: 220}, **{103: 223}} 31 | """, 32 | ),) 33 | 34 | for source, expected_abstraction in test_cases: 35 | processed_content = fixes.replace_dictcomp_update_with_dict_literal(source) 36 | if not testing_infra.check_fixes_equal( 37 | processed_content, expected_abstraction, clear_paranthesises=True 38 | ): 39 | return 1 40 | 41 | return 0 42 | 43 | 44 | if __name__ == "__main__": 45 | sys.exit(main()) 46 | -------------------------------------------------------------------------------- /tests/unit/test_replace_filter_lambda_with_comp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = filter(lambda y: y > 0, (1, 2, 3)) 16 | """, 17 | """ 18 | x = (y for y in (1, 2, 3) if y > 0) 19 | """, 20 | ), 21 | ( # Invalid syntax 22 | """ 23 | x = filter(lambda y, z: y > z, zip((1, 2, 3), [3, 2, 1])) 24 | """, 25 | """ 26 | x = filter(lambda y, z: y > z, zip((1, 2, 3), [3, 2, 1])) 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = itertools.filterfalse(lambda y: y > 0, (1, 2, 3)) 32 | """, 33 | """ 34 | x = (y for y in (1, 2, 3) if not y > 0) 35 | """, 36 | ), 37 | ( 38 | """ 39 | for x in filter(lambda y: y > 0, (1, 2, 3)): 40 | print(x) 41 | """, 42 | """ 43 | for x in filter(lambda y: y > 0, (1, 2, 3)): 44 | print(x) 45 | """, 46 | ), 47 | ( 48 | """ 49 | r = filter(lambda: True, (1, 2, 3)) # syntax error? 50 | """, 51 | """ 52 | r = filter(lambda: True, (1, 2, 3)) # syntax error? 
53 | """, 54 | ),) 55 | 56 | for source, expected_abstraction in test_cases: 57 | processed_content = fixes.replace_filter_lambda_with_comp(source) 58 | if not testing_infra.check_fixes_equal( 59 | processed_content, expected_abstraction, clear_paranthesises=True 60 | ): 61 | return 1 62 | 63 | return 0 64 | 65 | 66 | if __name__ == "__main__": 67 | sys.exit(main()) 68 | -------------------------------------------------------------------------------- /tests/unit/test_replace_for_loops_with_dict_comp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | x = {} 17 | for i in range(10): 18 | x[i] = 10 19 | """, 20 | """ 21 | x = {i: 10 for i in range(10)} 22 | """, 23 | ), 24 | ( 25 | """ 26 | x = {} 27 | for i in range(10): 28 | if i % 3 == 0: 29 | if i % 2 == 0: 30 | x[i] = 10 31 | """, 32 | """ 33 | x = {i: 10 for i in range(10) if i % 3 == 0 and i % 2 == 0} 34 | """, 35 | ), 36 | ( 37 | """ 38 | x = {} 39 | for i in range(10): 40 | if i % 3 == 0: 41 | if i % 2 == 0: 42 | x[i] = 10 43 | else: 44 | x[i] = 2 45 | """, 46 | """ 47 | x = {} 48 | for i in range(10): 49 | if i % 3 == 0: 50 | if i % 2 == 0: 51 | x[i] = 10 52 | else: 53 | x[i] = 2 54 | """, 55 | ), 56 | ( 57 | """ 58 | x = {} 59 | for i in range(10): 60 | if i % 3 == 0: 61 | if i % 2 == 0: 62 | x[i] = 10 63 | else: 64 | x[i] = 2 65 | """, 66 | """ 67 | x = {} 68 | for i in range(10): 69 | if i % 3 == 0: 70 | if i % 2 == 0: 71 | x[i] = 10 72 | else: 73 | x[i] = 2 74 | """, 75 | ), 76 | ( 77 | """ 78 | x = {} 79 | for i in range(10): 80 | x[i] = 10 ** i - 1 81 | """, 82 | """ 83 | x = {i: 10 ** i - 1 for i in range(10)} 84 | """, 85 | ), 86 | ( 87 | """ 88 | x = {1: 2} 89 | for i in range(10): 90 | x[i] = 10 ** i - 1 
91 | """, 92 | """ 93 | x = {**{1: 2}, **{i: 10 ** i - 1 for i in range(10)}} 94 | """, 95 | ), 96 | ( 97 | """ 98 | x = {i: 10 - 1 for i in range(33)} 99 | for i in range(77, 22): 100 | x[i] = 10 ** i - 1 101 | """, 102 | """ 103 | x = {**{i: 10 - 1 for i in range(33)}, **{i: 10 ** i - 1 for i in range(77, 22)}} 104 | """, 105 | ), 106 | ( 107 | """ 108 | u = {i: 10 - 1 for i in range(33)} 109 | v = {i: 10 ** i - 1 for i in range(77, 22)} 110 | w = {11: 342, 'key': "value"} 111 | x = {**u, **v, **w} 112 | for i in range(2, 4): 113 | x[i] = 10 ** i - 1 114 | """, 115 | """ 116 | u = {i: 10 - 1 for i in range(33)} 117 | v = {i: 10 ** i - 1 for i in range(77, 22)} 118 | w = {11: 342, 'key': "value"} 119 | x = {**u, **v, **w, **{i: 10 ** i - 1 for i in range(2, 4)}} 120 | """, 121 | ),) 122 | 123 | for source, expected_abstraction in test_cases: 124 | processed_content = fixes.replace_for_loops_with_dict_comp(source) 125 | 126 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 127 | return 1 128 | 129 | return 0 130 | 131 | 132 | if __name__ == "__main__": 133 | sys.exit(main()) 134 | -------------------------------------------------------------------------------- /tests/unit/test_replace_functions_with_literals.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | u = list() 17 | v = tuple() 18 | w = dict() 19 | a = dict(zip(range(4), range(1, 5))) 20 | b = list((1, 2, 3, 99)) 21 | c = set([1, 2, 3]) 22 | d = iter((1, 2, 3, 5)) 23 | aa = (1 for u in (1, 2, 3, 5)) 24 | """, 25 | """ 26 | u = [] 27 | v = () 28 | w = {} 29 | a = dict(zip(range(4), range(1, 5))) 30 | b = [1, 2, 3, 99] 31 | c = {1, 2, 3} 32 | d = iter((1, 2, 3, 5)) 33 | aa = (1 for u 
in (1, 2, 3, 5)) 34 | """, 35 | ),) 36 | 37 | for source, expected_abstraction in test_cases: 38 | processed_content = fixes.replace_functions_with_literals(source) 39 | 40 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 41 | return 1 42 | 43 | return 0 44 | 45 | 46 | if __name__ == "__main__": 47 | sys.exit(main()) 48 | -------------------------------------------------------------------------------- /tests/unit/test_replace_iterrows_index.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance_pandas 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | for x, _ in df.iterrows(): 16 | print(x) 17 | """, 18 | """ 19 | for x in df.index: 20 | print(x) 21 | """, 22 | ), 23 | ( 24 | """ 25 | stuff = [x for x, _ in df.iterrows()] 26 | print(stuff[-1]) 27 | """, 28 | """ 29 | stuff = [x for x in df.index] 30 | print(stuff[-1]) 31 | """, 32 | ), 33 | ( 34 | """ 35 | stuff = df.iterrows() 36 | print(sum(stuff)) 37 | """, 38 | """ 39 | stuff = df.iterrows() 40 | print(sum(stuff)) 41 | """, 42 | ), 43 | ( 44 | """ 45 | for x, i in df.iterrows(): 46 | print(x) 47 | """, 48 | """ 49 | for x, i in df.iterrows(): 50 | print(x) 51 | """, 52 | ), 53 | ( 54 | """ 55 | stuff = [x for x, q in df.iterrows()] 56 | print(stuff[-1]) 57 | """, 58 | """ 59 | stuff = [x for x, q in df.iterrows()] 60 | print(stuff[-1]) 61 | """, 62 | ),) 63 | 64 | for source, expected_abstraction in test_cases: 65 | processed_content = performance_pandas.replace_iterrows_index(source) 66 | if not testing_infra.check_fixes_equal( 67 | processed_content, expected_abstraction, clear_paranthesises=True 68 | ): 69 | return 1 70 | 71 | return 0 72 | 73 | 74 | if __name__ == "__main__": 75 | sys.exit(main()) 76 | 
-------------------------------------------------------------------------------- /tests/unit/test_replace_iterrows_itertuples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance_pandas 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | for _, x in df.iterrows(): 16 | print(x["value"]) 17 | """, 18 | """ 19 | for x in df.itertuples(): 20 | print(x.value) 21 | """, 22 | ), 23 | ( 24 | """ 25 | for _, x in df.iterrows(): 26 | print(x.at["value"]) 27 | """, 28 | """ 29 | for x in df.itertuples(): 30 | print(x.value) 31 | """, 32 | ), 33 | ( 34 | """ 35 | for _, x in df.iterrows(): 36 | print(x.iat[9]) 37 | """, 38 | """ 39 | for x in df.itertuples(): 40 | print(x[9 + 1]) 41 | """, 42 | ), 43 | ( 44 | """ 45 | for _, x in df.iterrows(): 46 | print(x.iat[9 + value - 1]) 47 | """, 48 | """ 49 | for x in df.itertuples(): 50 | print(x[9 + value - 1 + 1]) 51 | """, 52 | ), 53 | ( 54 | """ 55 | for _, x in df.iterrows(): 56 | print(x) 57 | """, 58 | """ 59 | for _, x in df.iterrows(): 60 | print(x) 61 | """, 62 | ), 63 | ( 64 | """ 65 | for i, x in df.iterrows(): 66 | print(x["value"]) 67 | """, 68 | """ 69 | for i, x in df.iterrows(): 70 | print(x["value"]) 71 | """, 72 | ), 73 | ( # Anti-pattern attribute access of column 74 | """ 75 | for _, x in df.iterrows(): 76 | print(x.value) 77 | """, 78 | """ 79 | for _, x in df.iterrows(): 80 | print(x.value) 81 | """, 82 | ), 83 | ( 84 | """ 85 | for _, x in df.iterrows(): 86 | print(x["value"]) 87 | y = 0 88 | y += x.at["qwerty"] ** x.iat[9] 89 | if y >= 199 and x.iat[13] + x.iat[8] > x["q"]: 90 | print(x["jk"] + x["e"]) 91 | """, 92 | """ 93 | for x in df.itertuples(): 94 | print(x.value) 95 | y = 0 96 | y += x.qwerty ** x[9 + 1] 97 | if y >= 199 and x[13 + 1] + x[8 + 1] > x.q: 98 | 
print(x.jk + x.e) 99 | """, 100 | ), 101 | ( 102 | """ 103 | for _, x in df.iterrows(): 104 | print(x["value"]) 105 | y = 0 106 | y += x.at["qwerty"] ** x.iat[9] 107 | y += x.__getattr__(1) 108 | """, 109 | """ 110 | for _, x in df.iterrows(): 111 | print(x["value"]) 112 | y = 0 113 | y += x.at["qwerty"] ** x.iat[9] 114 | y += x.__getattr__(1) 115 | """, 116 | ), 117 | ( 118 | """ 119 | for i, x in df.iterrows(): 120 | x["value"] = 1 121 | """, 122 | """ 123 | for i, x in df.iterrows(): 124 | x["value"] = 1 125 | """, 126 | ),) 127 | 128 | for source, expected_abstraction in test_cases: 129 | processed_content = performance_pandas.replace_iterrows_itertuples(source) 130 | if not testing_infra.check_fixes_equal( 131 | processed_content, expected_abstraction, clear_paranthesises=True 132 | ): 133 | return 1 134 | 135 | return 0 136 | 137 | 138 | if __name__ == "__main__": 139 | sys.exit(main()) 140 | -------------------------------------------------------------------------------- /tests/unit/test_replace_listcomp_append_with_plus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = [foo(z) ** 2 for z in range(3)] 16 | for zua in range(3): 17 | x.append(zua - 1) 18 | """, 19 | """ 20 | x = [foo(z) ** 2 for z in range(3)] + [zua - 1 for zua in range(3)] 21 | """, 22 | ), 23 | ( 24 | """ 25 | x = [foo(z) ** 2 for z in range(3)] 26 | for zua in range(3): 27 | x.append(zua - 1) 28 | x.extend([1, 3, 2]) 29 | x.extend({"sadf", 312}) 30 | """, 31 | """ 32 | x = [foo(z) ** 2 for z in range(3)] + [zua - 1 for zua in range(3)] + list([1, 3, 2]) + list({"sadf", 312}) 33 | """, 34 | ), 35 | ( 36 | """ 37 | x = {foo(z) ** 2 for z in range(3)} 38 | for zua in range(3): 39 | x.add(zua - 1) 40 
| """, 41 | """ 42 | x = {foo(z) ** 2 for z in range(3)} 43 | for zua in range(3): 44 | x.add(zua - 1) 45 | """, 46 | ), 47 | ( 48 | """ 49 | x = [1, 2, 3] 50 | for zua in range(3): 51 | x.append(zua - 1) 52 | for zua in range(9): 53 | x.append(zua ** 3 - 1) 54 | for fua in range(9): 55 | x.append(fua ** 2 - 1) 56 | """, 57 | """ 58 | x = [1, 2, 3] + [zua - 1 for zua in range(3)] + [zua ** 3 - 1 for zua in range(9)] + [fua ** 2 - 1 for fua in range(9)] 59 | """, 60 | ) 61 | ) 62 | 63 | for source, expected_abstraction in test_cases: 64 | processed_content = fixes.replace_listcomp_append_with_plus(source) 65 | if not testing_infra.check_fixes_equal( 66 | processed_content, expected_abstraction, clear_paranthesises=True 67 | ): 68 | return 1 69 | 70 | return 0 71 | 72 | 73 | if __name__ == "__main__": 74 | sys.exit(main()) 75 | -------------------------------------------------------------------------------- /tests/unit/test_replace_loc_at_iloc_iat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance_pandas 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | y = x.loc[1, 2] 16 | """, 17 | """ 18 | y = x.at[1, 2] 19 | """, 20 | ), 21 | ( 22 | """ 23 | y = x.iloc[1, 2] 24 | """, 25 | """ 26 | y = x.iat[1, 2] 27 | """, 28 | ), 29 | ( # Series 30 | """ 31 | y = x.loc[1] 32 | """, 33 | """ 34 | y = x.at[1] 35 | """, 36 | ), 37 | ( 38 | """ 39 | y = x.iloc[2] 40 | """, 41 | """ 42 | y = x.iat[2] 43 | """, 44 | ), 45 | ( 46 | """ 47 | y = x.loc[1.13, "name_of_column"] 48 | """, 49 | """ 50 | y = x.at[(1.13, "name_of_column")] 51 | """, 52 | ), 53 | ( 54 | """ 55 | y = x.loc[1.13, ["name_of_column", "name_of_other_column"]] 56 | """, 57 | """ 58 | y = x.loc[1.13, ["name_of_column", "name_of_other_column"]] 59 | """, 60 | ), 61 | ( 62 
| """ 63 | y = x.loc[1, var] 64 | """, 65 | """ 66 | y = x.loc[1, var] 67 | """, 68 | ), 69 | ( 70 | """ 71 | y = x.iloc[[2, 3], 2] 72 | """, 73 | """ 74 | y = x.iloc[[2, 3], 2] 75 | """, 76 | ), 77 | ( 78 | """ 79 | y = iloc[3, 2] 80 | """, 81 | """ 82 | y = iloc[3, 2] 83 | """, 84 | ), 85 | ( 86 | """ 87 | y = loc[3, 2] 88 | """, 89 | """ 90 | y = loc[3, 2] 91 | """, 92 | ), 93 | ( # Series 94 | """ 95 | y = x.loc[var] 96 | """, 97 | """ 98 | y = x.loc[var] 99 | """, 100 | ), 101 | ( # Series 102 | """ 103 | y = x.iloc[[2, 3]] 104 | """, 105 | """ 106 | y = x.iloc[[2, 3]] 107 | """, 108 | ), 109 | ( 110 | """ 111 | y = iloc[3, 2] 112 | """, 113 | """ 114 | y = iloc[3, 2] 115 | """, 116 | ), 117 | ( 118 | """ 119 | y = loc[3, 2] 120 | """, 121 | """ 122 | y = loc[3, 2] 123 | """, 124 | ),) 125 | 126 | for source, expected_abstraction in test_cases: 127 | processed_content = performance_pandas.replace_loc_at_iloc_iat(source) 128 | if not testing_infra.check_fixes_equal( 129 | processed_content, expected_abstraction, clear_paranthesises=True 130 | ): 131 | return 1 132 | 133 | return 0 134 | 135 | 136 | if __name__ == "__main__": 137 | sys.exit(main()) 138 | -------------------------------------------------------------------------------- /tests/unit/test_replace_map_lambda_with_comp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = map(lambda y: y > 0, (1, 2, 3)) 16 | """, 17 | """ 18 | x = (y > 0 for y in (1, 2, 3)) 19 | """, 20 | ), 21 | ( # Invalid syntax 22 | """ 23 | x = map(lambda y, z: y > z, zip((1, 2, 3), [3, 2, 1])) 24 | """, 25 | """ 26 | x = map(lambda y, z: y > z, zip((1, 2, 3), [3, 2, 1])) 27 | """, 28 | ), 29 | ( 30 | """ 31 | for x in map(lambda y: y > 0, 
(1, 2, 3)): 32 | print(x) 33 | """, 34 | """ 35 | for x in map(lambda y: y > 0, (1, 2, 3)): 36 | print(x) 37 | """, 38 | ), 39 | ( 40 | """ 41 | r = map(lambda: True, (1, 2, 3)) # syntax error? 42 | """, 43 | """ 44 | r = map(lambda: True, (1, 2, 3)) # syntax error? 45 | """, 46 | ),) 47 | 48 | for source, expected_abstraction in test_cases: 49 | processed_content = fixes.replace_map_lambda_with_comp(source) 50 | if not testing_infra.check_fixes_equal( 51 | processed_content, expected_abstraction, clear_paranthesises=True 52 | ): 53 | return 1 54 | 55 | return 0 56 | 57 | 58 | if __name__ == "__main__": 59 | sys.exit(main()) 60 | -------------------------------------------------------------------------------- /tests/unit/test_replace_negated_numeric_comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x > y 16 | """, 17 | """ 18 | x > y 19 | """, 20 | ), 21 | ( 22 | """ 23 | not x > y 24 | """, 25 | """ 26 | not x > y 27 | """, 28 | ), 29 | ( 30 | """ 31 | x > 3 32 | """, 33 | """ 34 | x > 3 35 | """, 36 | ), 37 | ( 38 | """ 39 | not x > 3 40 | """, 41 | """ 42 | x <= 3 43 | """, 44 | ), 45 | ( 46 | """ 47 | not 3 > 3.3 48 | """, 49 | """ 50 | 3 <= 3.3 51 | """, 52 | ), 53 | ( 54 | """ 55 | not 500 == h 56 | """, 57 | """ 58 | 500 != h 59 | """, 60 | ), 61 | ( 62 | """ 63 | not a == b 64 | not a != b 65 | not a < b 66 | not a <= b 67 | not a > b 68 | not a >= b 69 | not a is b 70 | not a is not b 71 | not a in b 72 | not a not in b 73 | """, 74 | """ 75 | a != b 76 | a == b 77 | not a < b 78 | not a <= b 79 | not a > b 80 | not a >= b 81 | a is not b 82 | a is b 83 | a not in b 84 | a in b 85 | """, 86 | ), 87 | ( 88 | """ 89 | not a == 44.1 90 | not a != 44.1 91 | 
not a < 44.1 92 | not a <= 44.1 93 | not a > 44.1 94 | not a >= 44.1 95 | not a is 44.1 96 | not a is not 44.1 97 | """, 98 | """ 99 | a != 44.1 100 | a == 44.1 101 | a >= 44.1 102 | a > 44.1 103 | a <= 44.1 104 | a < 44.1 105 | a is not 44.1 106 | a is 44.1 107 | """, 108 | ), 109 | ( 110 | """ 111 | not a == -999 112 | not a != -999 113 | not a < -999 114 | not a <= -999 115 | not a > -999 116 | not a >= -999 117 | not a is -999 118 | not a is not -999 119 | """, 120 | """ 121 | a != -999 122 | a == -999 123 | a >= -999 124 | a > -999 125 | a <= -999 126 | a < -999 127 | a is not -999 128 | a is -999 129 | """, 130 | ), 131 | ( 132 | """ 133 | not y.xa() == (hqx - 999) 134 | not y.xa() != (hqx - 999) 135 | not y.xa() < (hqx - 999) 136 | not y.xa() <= (hqx - 999) 137 | not y.xa() > (hqx - 999) 138 | not y.xa() >= (hqx - 999) 139 | not y.xa() is (hqx - 999) 140 | not y.xa() is not (hqx - 999) 141 | """, 142 | """ 143 | y.xa() != hqx - 999 144 | y.xa() == hqx - 999 145 | y.xa() >= hqx - 999 146 | y.xa() > hqx - 999 147 | y.xa() <= hqx - 999 148 | y.xa() < hqx - 999 149 | y.xa() is not hqx - 999 150 | y.xa() is hqx - 999 151 | """, 152 | ), 153 | ( 154 | """ 155 | not (hqx - 999) == y.xa() 156 | not (hqx - 999) != y.xa() 157 | not (hqx - 999) < y.xa() 158 | not (hqx - 999) <= y.xa() 159 | not (hqx - 999) > y.xa() 160 | not (hqx - 999) >= y.xa() 161 | not (hqx - 999) is y.xa() 162 | not (hqx - 999) is not y.xa() 163 | """, 164 | """ 165 | hqx - 999 != y.xa() 166 | hqx - 999 == y.xa() 167 | hqx - 999 >= y.xa() 168 | hqx - 999 > y.xa() 169 | hqx - 999 <= y.xa() 170 | hqx - 999 < y.xa() 171 | hqx - 999 is not y.xa() 172 | hqx - 999 is y.xa() 173 | """, 174 | ), 175 | ( 176 | """ 177 | not (0 + k) == y.xa() 178 | not (0 + k) != y.xa() 179 | not (0 + k) < y.xa() 180 | not (0 + k) <= y.xa() 181 | not (0 + k) > y.xa() 182 | not (0 + k) >= y.xa() 183 | not (0 + k) is y.xa() 184 | not (0 + k) is not y.xa() 185 | """, 186 | """ 187 | 0 + k != y.xa() 188 | 0 + k == y.xa() 189 | 
0 + k >= y.xa() 190 | 0 + k > y.xa() 191 | 0 + k <= y.xa() 192 | 0 + k < y.xa() 193 | 0 + k is not y.xa() 194 | 0 + k is y.xa() 195 | """, 196 | ),) 197 | 198 | for source, expected_abstraction in test_cases: 199 | processed_content = fixes.replace_negated_numeric_comparison(source) 200 | if not testing_infra.check_fixes_equal( 201 | processed_content, expected_abstraction, clear_paranthesises=True 202 | ): 203 | return 1 204 | 205 | return 0 206 | 207 | 208 | if __name__ == "__main__": 209 | sys.exit(main()) 210 | -------------------------------------------------------------------------------- /tests/unit/test_replace_redundant_starred.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | [*(x for x in iterator)] 17 | """, 18 | """ 19 | list((x for x in iterator)) 20 | """, 21 | ), 22 | ( 23 | """ 24 | (*(x for x in iterator),) 25 | """, 26 | """ 27 | tuple((x for x in iterator)) 28 | """, 29 | ), 30 | ( 31 | """ 32 | {*[x for x in iterator]} 33 | """, 34 | """ 35 | set([x for x in iterator]) 36 | """, 37 | ),) 38 | 39 | for source, expected_abstraction in test_cases: 40 | processed_content = fixes.replace_redundant_starred(source) 41 | 42 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 43 | return 1 44 | 45 | return 0 46 | 47 | 48 | if __name__ == "__main__": 49 | sys.exit(main()) 50 | -------------------------------------------------------------------------------- /tests/unit/test_replace_setcomp_add_with_union.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | 
sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = {foo(z) ** 2 for z in range(3)} 16 | for zua in range(3): 17 | x.add(zua - 1) 18 | """, 19 | """ 20 | x = {foo(z) ** 2 for z in range(3)} | {zua - 1 for zua in range(3)} 21 | """, 22 | ), 23 | ( 24 | """ 25 | x = {foo(z) ** 2 for z in range(3)} 26 | for zua in range(3): 27 | x.add(zua - 1) 28 | x.update((1, 3, 2)) 29 | x.update({"sadf", 312}) 30 | """, 31 | """ 32 | x = {foo(z) ** 2 for z in range(3)} | {zua - 1 for zua in range(3)} | set((1, 3, 2)) | set({"sadf", 312}) 33 | """, 34 | ), 35 | ( 36 | """ 37 | x = [foo(z) ** 2 for z in range(3)] 38 | for zua in range(3): 39 | x.append(zua - 1) 40 | """, 41 | """ 42 | x = [foo(z) ** 2 for z in range(3)] 43 | for zua in range(3): 44 | x.append(zua - 1) 45 | """, 46 | ), 47 | ( 48 | """ 49 | x = {1, 2, 3} 50 | for zua in range(3): 51 | x.add(zua - 1) 52 | for zua in range(9): 53 | x.add(zua ** 3 - 1) 54 | for fua in range(9): 55 | x.add(fua ** 2 - 1) 56 | """, 57 | """ 58 | x = {1, 2, 3} | {zua - 1 for zua in range(3)} | {zua ** 3 - 1 for zua in range(9)} | {fua ** 2 - 1 for fua in range(9)} 59 | """, 60 | ), 61 | ) 62 | 63 | for source, expected_abstraction in test_cases: 64 | processed_content = fixes.replace_setcomp_add_with_union(source) 65 | if not testing_infra.check_fixes_equal( 66 | processed_content, expected_abstraction, clear_paranthesises=True 67 | ): 68 | return 1 69 | 70 | return 0 71 | 72 | 73 | if __name__ == "__main__": 74 | sys.exit(main()) 75 | -------------------------------------------------------------------------------- /tests/unit/test_replace_subscript_looping.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import performance 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 
13 | def main() -> int: 14 | test_cases = ( 15 | ( 16 | """ 17 | import numpy 18 | [a[i] for i in range(len(a))] 19 | [a[i, :] for i in range(len(a))] 20 | [a[i, :] for i in range(a.shape[0])] 21 | [a[:, i] for i in range(a.shape[1])] 22 | """, 23 | """ 24 | import numpy 25 | list(a) 26 | list(a) 27 | list(a) 28 | list(a.T) 29 | """, 30 | ), 31 | ( 32 | """ 33 | import numpy 34 | (a[i] for i in range(len(a))) 35 | (a[i, :] for i in range(len(a))) 36 | (a[i, :] for i in range(a.shape[0])) 37 | (a[:, i] for i in range(a.shape[1])) 38 | """, 39 | """ 40 | import numpy 41 | iter(a) 42 | iter(a) 43 | iter(a) 44 | iter(a.T) 45 | """, 46 | ), 47 | ( 48 | """ 49 | import numpy as np 50 | [ 51 | [ 52 | np.dot(b[:, i], a[j, :]) 53 | for i in range(b.shape[1]) 54 | ] 55 | for j in range(a.shape[0]) 56 | ] 57 | """, 58 | """ 59 | import numpy as np 60 | [ 61 | [ 62 | np.dot(b_i, a_j) 63 | for b_i in zip(*b) 64 | ] 65 | for a_j in a 66 | ] 67 | """, 68 | ), 69 | ) 70 | 71 | for source, expected_abstraction in test_cases: 72 | 73 | processed_content = performance.replace_subscript_looping(source) 74 | 75 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 76 | return 1 77 | 78 | return 0 79 | 80 | 81 | if __name__ == "__main__": 82 | sys.exit(main()) 83 | -------------------------------------------------------------------------------- /tests/unit/test_replace_with_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | for x in range(10): 16 | if x: 17 | print(3) 18 | """, 19 | """ 20 | for x in filter(None, range(10)): 21 | print(3) 22 | """, 23 | ), 24 | ( 25 | """ 26 | for x in range(10): 27 | if not f(x): 28 | continue 29 | print(x) 30 | """, 31 | """ 32 
| for x in filter(f, range(10)): 33 | print(x) 34 | """, 35 | ), 36 | ( 37 | """ 38 | for x in range(10): 39 | if f(x): 40 | print(3) 41 | """, 42 | """ 43 | for x in filter(f, range(10)): 44 | print(3) 45 | """, 46 | ), 47 | ( 48 | """ 49 | for x in range(10): 50 | if f(x): 51 | print(3) 52 | else: 53 | print(76) 54 | """, 55 | """ 56 | for x in range(10): 57 | if f(x): 58 | print(3) 59 | else: 60 | print(76) 61 | """, 62 | ), 63 | ( 64 | """ 65 | for x in range(10): 66 | if not x: 67 | continue 68 | print(3) 69 | """, 70 | """ 71 | for x in filter(None, range(10)): 72 | print(3) 73 | """, 74 | ), 75 | ( # I find itertools.filterfalse much less readable 76 | """ 77 | for x in range(10): 78 | if f(x): 79 | continue 80 | print(x) 81 | """, 82 | """ 83 | for x in range(10): 84 | if f(x): 85 | continue 86 | print(x) 87 | """, 88 | ), 89 | ( 90 | """ 91 | for x in range(10): 92 | if f(x): 93 | print(x) 94 | """, 95 | """ 96 | for x in filter(f, range(10)): 97 | print(x) 98 | """, 99 | ), 100 | ( # Another filterfalse opportunity that I will not implement 101 | """ 102 | for x in range(10): 103 | if not f(x): 104 | print(x) 105 | """, 106 | """ 107 | for x in range(10): 108 | if not f(x): 109 | print(x) 110 | """, 111 | ), 112 | ( # Do not chain filter with filter 113 | """ 114 | for x in filter(bool, range(10)): 115 | if not f(x): 116 | continue 117 | print(x) 118 | """, 119 | """ 120 | for x in filter(bool, range(10)): 121 | if not f(x): 122 | continue 123 | print(x) 124 | """, 125 | ), 126 | ( 127 | """ 128 | for x in filter(int, range(10)): 129 | if f(x): 130 | print(x) 131 | """, 132 | """ 133 | for x in filter(int, range(10)): 134 | if f(x): 135 | print(x) 136 | """, 137 | ), 138 | ( # Do not chain filter with filterfalse 139 | """ 140 | from itertools import filterfalse 141 | for x in filterfalse(bool, range(10)): 142 | if not f(x): 143 | continue 144 | print(x) 145 | """, 146 | """ 147 | from itertools import filterfalse 148 | for x in filterfalse(bool, 
range(10)): 149 | if not f(x): 150 | continue 151 | print(x) 152 | """, 153 | ), 154 | ( 155 | """ 156 | import itertools 157 | for x in itertools.filterfalse(int, range(10)): 158 | if f(x): 159 | print(x) 160 | """, 161 | """ 162 | import itertools 163 | for x in itertools.filterfalse(int, range(10)): 164 | if f(x): 165 | print(x) 166 | """, 167 | ),) 168 | 169 | for source, expected_abstraction in test_cases: 170 | processed_content = fixes.replace_with_filter(source) 171 | processed_content = fixes.remove_dead_ifs(processed_content) 172 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 173 | return 1 174 | 175 | return 0 176 | 177 | 178 | if __name__ == "__main__": 179 | sys.exit(main()) 180 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_assign_immediate_return.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | def foo(): 16 | x = 100 17 | return x 18 | """, 19 | """ 20 | def foo(): 21 | return 100 22 | """, 23 | ), 24 | ( 25 | """ 26 | def foo(): 27 | x = sorted(set(range(1000))) 28 | return x 29 | """, 30 | """ 31 | def foo(): 32 | return sorted(set(range(1000))) 33 | """, 34 | ), 35 | ( # Variable names aren't the same 36 | """ 37 | y = 3 38 | def foo(): 39 | x = sorted(set(range(1000))) 40 | return y 41 | """, 42 | """ 43 | y = 3 44 | def foo(): 45 | x = sorted(set(range(1000))) 46 | return y 47 | """, 48 | ), 49 | ( # This pattern is ok, I don't want it removed. 
50 | """ 51 | def foo(): 52 | x = 100 53 | x = 3 - x 54 | return x 55 | """, 56 | """ 57 | def foo(): 58 | x = 100 59 | x = 3 - x 60 | return x 61 | """, 62 | ), 63 | ( # Same variable in different places, both should be removed 64 | """ 65 | def foo(): 66 | x = 100 67 | return x 68 | 69 | def bar(): 70 | x = 301 - foo() 71 | return x 72 | """, 73 | """ 74 | def foo(): 75 | return 100 76 | 77 | def bar(): 78 | return 301 - foo() 79 | """, 80 | ), 81 | ( 82 | r""" 83 | def fix_too_many_blank_lines(source: str) -> str: 84 | # At module level, remove all above 2 blank lines 85 | source = re.sub(r"(\n\s*){3,}\n", "\n" * 3, source) 86 | 87 | # At EOF, remove all newlines and whitespace above 1 88 | source = re.sub(r"(\n\s*){2,}\Z", "\n", source) 89 | 90 | # At non-module (any indented) level, remove all newlines above 1, preserve indent 91 | source = re.sub(r"(\n\s*){2,}(\n\s+)(?=[^\n\s])", r"\n\g<2>", source) 92 | 93 | return source 94 | """, 95 | r""" 96 | def fix_too_many_blank_lines(source: str) -> str: 97 | # At module level, remove all above 2 blank lines 98 | source = re.sub(r"(\n\s*){3,}\n", "\n" * 3, source) 99 | 100 | # At EOF, remove all newlines and whitespace above 1 101 | source = re.sub(r"(\n\s*){2,}\Z", "\n", source) 102 | 103 | # At non-module (any indented) level, remove all newlines above 1, preserve indent 104 | source = re.sub(r"(\n\s*){2,}(\n\s+)(?=[^\n\s])", r"\n\g<2>", source) 105 | 106 | return source 107 | """, 108 | ),) 109 | 110 | for source, expected_abstraction in test_cases: 111 | processed_content = fixes.simplify_assign_immediate_return(source) 112 | if not testing_infra.check_fixes_equal( 113 | processed_content, expected_abstraction, clear_paranthesises=True 114 | ): 115 | return 1 116 | 117 | return 0 118 | 119 | 120 | if __name__ == "__main__": 121 | sys.exit(main()) 122 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_boolean_expressions_symmath.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import symbolic_math 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | x and not x 17 | """, 18 | """ 19 | False 20 | """, 21 | ), 22 | ( 23 | """ 24 | x or not x 25 | """, 26 | """ 27 | True 28 | """, 29 | ), 30 | ( 31 | """ 32 | (A and B) and (not A and not B) 33 | (A and B) and (A or B) 34 | a and b and not (not c or not d) 35 | """, 36 | """ 37 | False 38 | A and B 39 | a and b and c and d 40 | """, 41 | ), 42 | ( 43 | """ 44 | ( 45 | testing_infra.check_fixes_equal(processed_content, expected_abstraction) 46 | and True and not 47 | testing_infra.check_fixes_equal(processed_content, expected_abstraction) 48 | ) 49 | """, 50 | """ 51 | ( 52 | False 53 | ) 54 | """, 55 | ), 56 | ( 57 | """ 58 | x = [a for a in range(10) if a % 2 == 0 and a > 5 and a % 2 == 0] 59 | """, 60 | """ 61 | x = [a for a in range(10) if a % 2 == 0 and a > 5] 62 | """, 63 | ),) 64 | 65 | for source, expected_abstraction in test_cases: 66 | processed_content = symbolic_math.simplify_boolean_expressions_symmath(source) 67 | 68 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 69 | return 1 70 | 71 | return 0 72 | 73 | 74 | if __name__ == "__main__": 75 | sys.exit(main()) 76 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_collection_unpacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = [*()] 16 | """, 17 | """ 18 | x = [] 19 
| """, 20 | ), 21 | ( 22 | """ 23 | x = (*(),) 24 | """, 25 | """ 26 | x = () 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = {*{}} 32 | """, 33 | """ 34 | x = set() 35 | """, 36 | ), 37 | ( 38 | """ 39 | x = [*{}] 40 | """, 41 | """ 42 | x = [] 43 | """, 44 | ), 45 | ( 46 | """ 47 | x = (*{},) 48 | """, 49 | """ 50 | x = () 51 | """, 52 | ), 53 | ( 54 | """ 55 | x = {*()} 56 | """, 57 | """ 58 | x = set() 59 | """, 60 | ), 61 | ( # 1 element is unique, hence a dict provides no non-trivial uniqueness 62 | """ 63 | x = [*{1: 3}] 64 | """, 65 | """ 66 | x = [1] 67 | """, 68 | ), 69 | ( 70 | """ 71 | x = (*{1: 3},) 72 | """, 73 | """ 74 | x = (1,) 75 | """, 76 | ), 77 | ( 78 | """ 79 | x = {*{1: 3}} 80 | """, 81 | """ 82 | x = {1} 83 | """, 84 | ), 85 | ( # 2 elements need to be compared, so can only be safely moved to the set 86 | """ 87 | x = [*{1: 3, 2: 19}] 88 | """, 89 | """ 90 | x = [*{1: 3, 2: 19}] 91 | """, 92 | ), 93 | ( 94 | """ 95 | x = (*{1: 3, 2: 19},) 96 | """, 97 | """ 98 | x = (*{1: 3, 2: 19},) 99 | """, 100 | ), 101 | ( 102 | """ 103 | x = {*{1: 3, 2: 19}} 104 | """, 105 | """ 106 | x = {1, 2} 107 | """, 108 | ), 109 | ( # One element is unique and may be unpacked into a tuple or set 110 | """ 111 | x = {*(), 2, 3, *{99}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9} 112 | """, 113 | """ 114 | x = {2, 3, 99, 199, 991, 2, [], *tuple([1, 2, 3]), 9} 115 | """, 116 | ), 117 | ( 118 | """ 119 | x = (*(), 2, 3, *{99}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9) 120 | """, 121 | """ 122 | x = (2, 3, 99, 199, 991, 2, [], *tuple([1, 2, 3]), 9) 123 | """, 124 | ), 125 | ( 126 | """ 127 | x = [*(), 2, 3, *{99}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9] 128 | """, 129 | """ 130 | x = [2, 3, 99, 199, 991, 2, [], *tuple([1, 2, 3]), 9] 131 | """, 132 | ), 133 | ( # Two elements must be compared/hashed to say if unique 134 | """ 135 | x = {*(), 2, 3, *{99, 44}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9} 136 | """, 137 | """ 138 | x = {2, 3, 99, 44, 199, 991, 2, 
[], *tuple([1, 2, 3]), 9} 139 | """, 140 | ), 141 | ( 142 | """ 143 | x = (*(), 2, 3, *{99, 44}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9) 144 | """, 145 | """ 146 | x = (2, 3, *{99, 44}, 199, 991, 2, [], *tuple([1, 2, 3]), 9) 147 | """, 148 | ), 149 | ( 150 | """ 151 | x = [*(), 2, 3, *{99, 44}, (199, 991, 2), *[], [], *tuple([1, 2, 3]), 9] 152 | """, 153 | """ 154 | x = [2, 3, *{99, 44}, 199, 991, 2, [], *tuple([1, 2, 3]), 9] 155 | """, 156 | ),) 157 | 158 | for source, expected_abstraction in test_cases: 159 | processed_content = fixes.simplify_collection_unpacks(source) 160 | if not testing_infra.check_fixes_equal( 161 | processed_content, expected_abstraction, clear_paranthesises=True 162 | ): 163 | return 1 164 | 165 | return 0 166 | 167 | 168 | if __name__ == "__main__": 169 | sys.exit(main()) 170 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_constrained_range.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import symbolic_math 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | (x for x in range(10) if x > 4) 17 | """, 18 | """ 19 | (x for x in range(5, 10) if True) 20 | """, 21 | ), 22 | ( 23 | """ 24 | {x for x in range(10) if x > 4} 25 | """, 26 | """ 27 | {x for x in range(5, 10) if True} 28 | """, 29 | ), 30 | ( 31 | """ 32 | [x for x in range(10) if x > 4] 33 | """, 34 | """ 35 | [x for x in range(5, 10) if True] 36 | """, 37 | ), 38 | ( 39 | """ 40 | [x for x in range(10) if x >= 4] 41 | """, 42 | """ 43 | [x for x in range(4, 10) if True] 44 | """, 45 | ), 46 | ( 47 | """ 48 | [x for x in range(10) if x < 4] 49 | """, 50 | """ 51 | [x for x in range(4) if True] 52 | """, 53 | ), 54 | ( 55 | """ 56 | [x for x in range(10) if x <= 4] 57 | """, 58 | """ 
59 | [x for x in range(5) if True] 60 | """, 61 | ), 62 | ( 63 | """ 64 | [x for x in range(10) if x < 4 and x > 1] 65 | """, 66 | """ 67 | [x for x in range(2, 4) if True and True] 68 | """, 69 | ), 70 | ( 71 | """ 72 | [x for x in range(10) if x < 4 and x > 1 and x == 88] 73 | """, 74 | """ 75 | [x for x in ()] 76 | """, 77 | ), 78 | ( 79 | """ 80 | [x for x in range(-1, 89) if foo() and bar and x == 88] 81 | """, 82 | """ 83 | [x for x in range(88, 89) if foo() and bar and True] 84 | """, 85 | ), 86 | ( 87 | """ 88 | [x for x in range(-1, 89) if foo() if bar if x == 88] 89 | """, 90 | """ 91 | [x for x in range(88, 89) if foo() if bar if True] 92 | """, 93 | ), 94 | ( 95 | """ 96 | [x for x in range(-1, 89) if foo() and bar and x == 89] 97 | """, 98 | """ 99 | [x for x in ()] 100 | """, 101 | ),) 102 | 103 | for source, expected_abstraction in test_cases: 104 | processed_content = symbolic_math.simplify_constrained_range(source) 105 | 106 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 107 | return 1 108 | 109 | return 0 110 | 111 | 112 | if __name__ == "__main__": 113 | sys.exit(main()) 114 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_dict_unpacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = {**{}} 16 | """, 17 | """ 18 | x = {} 19 | """, 20 | ), 21 | ( 22 | """ 23 | x = {**{}, 13: 14} 24 | """, 25 | """ 26 | x = {13: 14} 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = {3: {}, 13: 14} 32 | """, 33 | """ 34 | x = {3: {}, 13: 14} 35 | """, 36 | ), 37 | ( 38 | """ 39 | x = {1: 2, 3: 4, **{99: 109, None: None}, 4: 5, **{"asdf": 12 - 13}} 40 | """, 41 | """ 42 | x = {1: 2, 3: 
4, 99: 109, None: None, 4: 5, "asdf": 12 - 13} 43 | """, 44 | ),) 45 | 46 | for source, expected_abstraction in test_cases: 47 | processed_content = fixes.simplify_dict_unpacks(source) 48 | if not testing_infra.check_fixes_equal( 49 | processed_content, expected_abstraction, clear_paranthesises=True 50 | ): 51 | return 1 52 | 53 | return 0 54 | 55 | 56 | if __name__ == "__main__": 57 | sys.exit(main()) 58 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_if_control_flow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import abstractions, fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | import random 16 | x = 11 17 | y = 12 18 | z = random.random() 19 | if z > 1 - z: 20 | do_stuff(x) 21 | do_stuff(y - x ** 2) 22 | print(doing_other_stuff(x) - do_stuff(y ** y)) 23 | else: 24 | do_stuff(y) 25 | do_stuff(x - y ** 2) 26 | print(doing_other_stuff(y) - do_stuff(x ** x)) 27 | """, 28 | """ 29 | import random 30 | x = 11 31 | y = 12 32 | z = random.random() 33 | if z > 1 - z: 34 | var_1 = x 35 | var_2 = y 36 | else: 37 | var_1 = y 38 | var_2 = x 39 | 40 | do_stuff(var_1) 41 | do_stuff(var_2 - var_1 ** 2) 42 | print(doing_other_stuff(var_1) - do_stuff(var_2 ** var_2)) 43 | """, 44 | ), 45 | ( # Too little code would be simplified => do not replace this 46 | """ 47 | import random 48 | x = 11 49 | y = 12 50 | z = random.random() 51 | if z > 1 - z: 52 | do_stuff(x) 53 | else: 54 | do_stuff(y) 55 | """, 56 | """ 57 | import random 58 | x = 11 59 | y = 12 60 | z = random.random() 61 | if z > 1 - z: 62 | do_stuff(x) 63 | else: 64 | do_stuff(y) 65 | """, 66 | ), 67 | ( 68 | """ 69 | import random 70 | x = 11 71 | y = 12 72 | z = random.random() 73 | if z > 1 - z: 74 | do_stuff(x) 75 | 
if x > 1 / z: 76 | raise RuntimeError(f"Invalid value for {x}") 77 | 78 | print(random.randint(1 / z ** 2, - 1 / (z + y - x) ** 2)) 79 | else: 80 | do_stuff(y) 81 | if y > 1 / z: 82 | raise RuntimeError(f"Invalid value for {y}") 83 | 84 | print(random.randint(1 / z ** 2, - 1 / (z + x - y) ** 2)) 85 | """, 86 | """ 87 | import random 88 | x = 11 89 | y = 12 90 | z = random.random() 91 | if z > 1 - z: 92 | var_1 = x 93 | var_2 = y 94 | else: 95 | var_1 = y 96 | var_2 = x 97 | 98 | do_stuff(var_1) 99 | if var_1 > 1 / z: 100 | raise RuntimeError(f"Invalid value for {var_1}") 101 | 102 | print(random.randint(1 / z ** 2, -1 / (z + var_2 - var_1) ** 2)) 103 | """, 104 | ),) 105 | 106 | for source, expected_abstraction in test_cases: 107 | processed_content = abstractions.simplify_if_control_flow(source) 108 | processed_content = fixes.breakout_common_code_in_ifs(processed_content) 109 | if not testing_infra.check_fixes_equal( 110 | processed_content, expected_abstraction, clear_paranthesises=True 111 | ): 112 | return 1 113 | 114 | return 0 115 | 116 | 117 | if __name__ == "__main__": 118 | sys.exit(main()) 119 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_math_iterators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import symbolic_math 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | x = sum(range(10)) 17 | y = sum(range(3, 17)) 18 | h = sum((1, 2, 3)) 19 | q = sum([1, 2, 3, 99, 99, -8]) 20 | r = sum({1, 2, 3, 99, 99, -8}) 21 | """, 22 | """ 23 | x = 45 24 | y = 133 25 | h = 6 26 | q = 196 27 | r = sum({1, 2, 3, 99, 99, -8}) 28 | """, 29 | ), 30 | ( 31 | """ 32 | y = sum([a ** 2 for a in range(10)]) 33 | """, 34 | """ 35 | y = 285 36 | """, 37 | ), 38 | ( 39 | """ 40 | x 
= 111 41 | z = 44 42 | y = sum([x * a ** 3 - a * z ** 2 for a in range(10)]) 43 | """, 44 | """ 45 | x = 111 46 | z = 44 47 | y = 2025 * x - 45 * z ** 2 48 | """, 49 | ), 50 | ( 51 | """ 52 | y = sum([x * a ** 3 - a * z ** 2 for a in range(10, 19) for z in range(3, 7) for x in range(1, 3)]) 53 | """, 54 | """ 55 | y = 304920 56 | """, 57 | ), 58 | ( 59 | """ 60 | y = sum([ 61 | x * a ** 3 - a * z ** 2 62 | for a in range(10, 19, 2) 63 | for z in range(3, 7) 64 | for x in range(1, 9) 65 | ]) 66 | """, 67 | """ 68 | y = 2169440 69 | """, 70 | ), 71 | ( 72 | """ 73 | y = sum([x * a ** 3 - a * z ** 2 for a in range(10, 19, 2) for z in range(3, 7) for x in range(1, 9, 5)]) 74 | y = sum([x * (a * 2) ** 3 - (a * 2) * z ** 2 for a in range(5, 10) for z in range(3, 7) for x in range(1, 9, 5)]) 75 | """, 76 | """ 77 | y = 419160 78 | y = 419160 79 | """, 80 | ), 81 | ( 82 | """ 83 | y = sum([x * a ** 3 - a * z ** 2 for a in range(10, 19, 2) for z in (3, 4, 5, 6) for x in range(1, 9, 5)]) 84 | h = sum(x * a ** 3 - a * z ** 2 for a in range(10, 19, 2) for z in [3, 4, 5, 6] for x in range(1, 9, 5)) 85 | w = sum([x * a ** 3 - a * z ** 2 for a in range(10, 19, 2) for z in {3, 4, 5, 6, 6, 5, 4} for x in range(1, 9, 5)]) 86 | """, 87 | """ 88 | y = 419160 89 | h = 419160 90 | w = 419160 91 | """, 92 | ), 93 | ( 94 | """ 95 | import math 96 | x = sum([math.sqrt(x) for x in range(11)]) 97 | """, 98 | """ 99 | import math 100 | x = sum([math.sqrt(x) for x in range(11)]) 101 | """, 102 | ), 103 | ( 104 | """ 105 | import random 106 | def foo(): 107 | return random.random() 108 | x = sum([foo() for x in range(11)]) 109 | """, 110 | """ 111 | import random 112 | def foo(): 113 | return random.random() 114 | x = sum([foo() for x in range(11)]) 115 | """, 116 | ), 117 | ( 118 | """ 119 | from math import exp 120 | x = sum([exp(x) for x in range(11)]) 121 | """, 122 | """ 123 | from math import exp 124 | x = sum([exp(x) for x in range(11)]) 125 | """, 126 | ), 127 | ( 128 | """ 129 | x = 
sum([a + b + c + d for _ in range(k, w)]) 130 | """, 131 | """ 132 | x = -(k - w) * (a + b + c + d) 133 | """, 134 | ),) 135 | 136 | for source, expected_abstraction in test_cases: 137 | processed_content = symbolic_math.simplify_math_iterators(source) 138 | 139 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 140 | return 1 141 | 142 | return 0 143 | 144 | 145 | if __name__ == "__main__": 146 | sys.exit(main()) 147 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_matrix_operations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes, performance_numpy 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | np.matmul(a.T, b.T).T 16 | np.matmul(a, b.T).T 17 | np.matmul(a.T, b).T 18 | np.matmul(a.T, b.T) 19 | """, 20 | """ 21 | np.matmul(b, a) 22 | np.matmul(a, b.T).T 23 | np.matmul(a.T, b).T 24 | np.matmul(a.T, b.T) 25 | """, 26 | ),) 27 | 28 | for source, expected_abstraction in test_cases: 29 | processed_content = fixes.simplify_transposes(source) 30 | processed_content = performance_numpy.simplify_matmul_transposes(processed_content) 31 | processed_content = fixes.simplify_transposes(processed_content) 32 | 33 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 34 | return 1 35 | 36 | return 0 37 | 38 | 39 | if __name__ == "__main__": 40 | sys.exit(main()) 41 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_redundant_lambda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | 
sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = ( 15 | # ( 16 | # """ 17 | # lambda: complicated_function() 18 | # lambda: pd.DataFrame() 19 | # lambda: [] 20 | # lambda: {} 21 | # lambda: set() 22 | # lambda: () 23 | # """, 24 | # """ 25 | # complicated_function 26 | # pd.DataFrame 27 | # list 28 | # dict 29 | # set 30 | # tuple 31 | # """, 32 | # ), 33 | # ( 34 | # """ 35 | # lambda value: [*value] 36 | # lambda value: {*value,} 37 | # lambda value: (*value,) 38 | # lambda value, /: [*value] 39 | # """, 40 | # """ 41 | # list 42 | # set 43 | # tuple 44 | # list 45 | # """, 46 | # ), 47 | # ( 48 | # """ 49 | # lambda value, /, value2: (*value, *value2) 50 | # lambda value, /, value2: (*value,) 51 | # lambda: complicated_function(some_argument) 52 | # lambda: complicated_function(some_argument=2) 53 | # """, 54 | # """ 55 | # lambda value, /, value2: (*value, *value2) 56 | # lambda value, /, value2: (*value,) 57 | # lambda: complicated_function(some_argument) 58 | # lambda: complicated_function(some_argument=2) 59 | # """, 60 | # ), 61 | # ( 62 | # """ 63 | # lambda x: [] 64 | # lambda x: list() 65 | # """, 66 | # """ 67 | # lambda x: [] 68 | # lambda x: list() 69 | # """, 70 | # ), 71 | ( 72 | """ 73 | lambda *args: w(*args) 74 | lambda **kwargs: r(**kwargs) 75 | """, 76 | """ 77 | w 78 | r 79 | """, 80 | ), 81 | ( 82 | """ 83 | lambda q: h(q) 84 | lambda z, w: f(z, w) 85 | lambda *args, **kwargs: hh(*args, **kwargs) 86 | lambda z, k, /, w, h, *args: rrr(z, k, w, h, *args) 87 | lambda z, k, /, w, h, *args, **kwargs: rfr(z, k, w, h, *args, **kwargs) 88 | lambda z, k, /, w, h, *args: rrr(z, k, w, w, *args) 89 | lambda z, k, /, w, h: rrr(z, k, w, w, *args) 90 | """, 91 | """ 92 | h 93 | f 94 | hh 95 | rrr 96 | rfr 97 | lambda z, k, /, w, h, *args: rrr(z, k, w, w, *args) 98 | lambda z, k, /, w, h: rrr(z, k, w, w, *args) 99 | """, 100 | ),) 101 | 102 | for source, expected_abstraction in 
test_cases: 103 | processed_content = fixes.simplify_redundant_lambda(source) 104 | 105 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 106 | return 1 107 | 108 | return 0 109 | 110 | 111 | if __name__ == "__main__": 112 | sys.exit(main()) 113 | -------------------------------------------------------------------------------- /tests/unit/test_simplify_transposes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | arr = [[1, 2, 3], [4, 5, 6]] 16 | assert list(zip(*arr)) == [[1, 4], [2, 5], [3, 6]] 17 | assert list(zip(*zip(*arr))) == [[1, 4], [2, 5], [3, 6]] 18 | """, 19 | """ 20 | arr = [[1, 2, 3], [4, 5, 6]] 21 | assert list(zip(*arr)) == [[1, 4], [2, 5], [3, 6]] 22 | assert list(arr) == [[1, 4], [2, 5], [3, 6]] 23 | """, 24 | ), 25 | ( 26 | """ 27 | arr = np.array([[1, 2, 3], [4, 5, 6]]) 28 | assert list(arr.T) == [[1, 4], [2, 5], [3, 6]] 29 | assert list(arr.T.T) == [[1, 2, 3], [4, 5, 6]] 30 | """, 31 | """ 32 | arr = np.array([[1, 2, 3], [4, 5, 6]]) 33 | assert list(arr.T) == [[1, 4], [2, 5], [3, 6]] 34 | assert list(arr) == [[1, 2, 3], [4, 5, 6]] 35 | """, 36 | ), 37 | ( 38 | """ 39 | arr = np.array([[1, 2, 3], [4, 5, 6]]) 40 | assert list(zip(*arr.T)) == [[1, 2, 3], [4, 5, 6]] 41 | assert list(zip(*arr.T.T)) == [[1, 4], [2, 5], [3, 6]] 42 | assert list(zip(*zip(*arr))) == [[1, 2, 3], [4, 5, 6]] 43 | assert list(zip(*zip(*arr.T))) == [[1, 4], [2, 5], [3, 6]] 44 | assert list(zip(*zip(*arr.T.T))) == [[1, 2, 3], [4, 5, 6]] 45 | """, 46 | """ 47 | arr = np.array([[1, 2, 3], [4, 5, 6]]) 48 | assert list(arr) == [[1, 2, 3], [4, 5, 6]] 49 | assert list(arr.T) == [[1, 4], [2, 5], [3, 6]] 50 | assert list(arr) == [[1, 2, 3], [4, 5, 6]] 51 | assert 
list(arr.T) == [[1, 4], [2, 5], [3, 6]] 52 | assert list(arr) == [[1, 2, 3], [4, 5, 6]] 53 | """, 54 | ),) 55 | 56 | for source, expected_abstraction in test_cases: 57 | processed_content = fixes.simplify_transposes(source) 58 | 59 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 60 | return 1 61 | 62 | return 0 63 | 64 | 65 | if __name__ == "__main__": 66 | sys.exit(main()) 67 | -------------------------------------------------------------------------------- /tests/unit/test_singleton_comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import fixes 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def main() -> int: 14 | test_cases = (( 15 | """ 16 | x == None and k != None 17 | """, 18 | """ 19 | x is None and k is not None 20 | """, 21 | ), 22 | ( 23 | """ 24 | x == None or k != None 25 | """, 26 | """ 27 | x is None or k is not None 28 | """, 29 | ), 30 | ( 31 | """ 32 | if a == False: 33 | print(1) 34 | """, 35 | """ 36 | if a is False: 37 | print(1) 38 | """, 39 | ), 40 | ( 41 | """ 42 | print(q == True) 43 | print(k != True) 44 | """, 45 | """ 46 | print(q is True) 47 | print(k is not True) 48 | """, 49 | ), 50 | ( 51 | """ 52 | print(q == True is x) 53 | print(k != True != q != None is not False) 54 | """, 55 | """ 56 | print(q is True is x) 57 | print(k is not True != q is not None is not False) 58 | """, 59 | ),) 60 | 61 | for source, expected_abstraction in test_cases: 62 | processed_content = fixes.singleton_eq_comparison(source) 63 | 64 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 65 | return 1 66 | 67 | return 0 68 | 69 | 70 | if __name__ == "__main__": 71 | sys.exit(main()) 72 | -------------------------------------------------------------------------------- /tests/unit/test_sort_imports.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | from logging import info 16 | from logging import warning 17 | from logging import error, info, log 18 | from logging import ( 19 | warning, 20 | critical , 21 | error, error, error as error) 22 | from logging import critical 23 | """, 24 | """ 25 | from logging import critical 26 | from logging import critical, error, error, error, warning 27 | from logging import error, info, log 28 | from logging import info 29 | from logging import warning 30 | """, 31 | ), 32 | ( 33 | """ 34 | from logging import info 35 | from numpy import ndarray 36 | from logging import warning 37 | from numpy import array 38 | from logging import error as info, warning as error 39 | """, 40 | """ 41 | from logging import error as info, warning as error 42 | from logging import info 43 | from logging import warning 44 | from numpy import array 45 | from numpy import ndarray 46 | """, 47 | ), 48 | ( 49 | """ 50 | import logging 51 | import logging 52 | import logging, numpy, pandas as pd, os as sys, os as os 53 | import pandas as pd, os, os, os 54 | import os 55 | """, 56 | """ 57 | import logging 58 | import logging 59 | import os 60 | import logging, numpy, os, os as sys, pandas as pd 61 | import os, os, os, pandas as pd 62 | """, 63 | ), 64 | ( 65 | """ 66 | from __future__ import annotations 67 | from __future__ import absolute_import 68 | import os 69 | import logging 70 | from typing import List, Dict, Tuple 71 | from typing import List, Dict, Tuple, Any, Union 72 | from numpy import ndarray 73 | import numpy as np 74 | from .. import utils2 75 | from re import findall 76 | from . import utils 77 | from .. import utils 78 | from .. 
import utils1 79 | from .utils import * 80 | import re 81 | """, 82 | """ 83 | from __future__ import absolute_import 84 | from __future__ import annotations 85 | import logging 86 | import os 87 | import re 88 | from re import findall 89 | from typing import Any, Dict, List, Tuple, Union 90 | from typing import Dict, List, Tuple 91 | 92 | import numpy as np 93 | from numpy import ndarray 94 | 95 | from .. import utils 96 | from .. import utils1 97 | from .. import utils2 98 | from . import utils 99 | from .utils import * 100 | """, 101 | ), 102 | ( 103 | """ 104 | import logging as logging 105 | """, 106 | """ 107 | import logging 108 | """, 109 | ), 110 | ( 111 | """ 112 | from logging import info as info 113 | """, 114 | """ 115 | from logging import info 116 | """, 117 | ), 118 | ( 119 | """ 120 | import logging as logging 121 | from logging import info as info 122 | """, 123 | """ 124 | import logging 125 | from logging import info 126 | """, 127 | ),) 128 | 129 | for source, expected_abstraction in test_cases: 130 | processed_content = fixes.sort_imports(source) 131 | if not testing_infra.check_fixes_equal( 132 | processed_content, expected_abstraction, clear_paranthesises=True 133 | ): 134 | return 1 135 | 136 | return 0 137 | 138 | 139 | if __name__ == "__main__": 140 | sys.exit(main()) 141 | -------------------------------------------------------------------------------- /tests/unit/test_sorted_heapq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import performance 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = sorted(y) 16 | z = sorted(p)[0] 17 | k = sorted(p)[-1] 18 | r = sorted([q, x], key=foo)[-1] 19 | w = sorted(p)[:5] 20 | w = sorted(p)[:k] 21 | w = sorted(p)[:-5] 22 | f = sorted(p)[-8:] 23 | f = sorted(p)[-q():] 
24 | f = sorted(p)[13:] 25 | sorted(x)[3:8] 26 | print(sorted(z, key=lambda x: -x)[:94]) 27 | print(sorted(z, key=lambda x: -x)[-4:]) 28 | """, 29 | """ 30 | x = sorted(y) 31 | z = min(p) 32 | k = max(p) 33 | r = max([q, x], key=foo) 34 | w = heapq.nsmallest(5, p) 35 | w = heapq.nsmallest(k, p) 36 | w = sorted(p)[:-5] 37 | f = list(reversed(heapq.nlargest(8, p))) 38 | f = list(reversed(heapq.nlargest(q(), p))) 39 | f = sorted(p)[13:] 40 | sorted(x)[3:8] 41 | print(heapq.nsmallest(94, z, key=lambda x: -x)) 42 | print(list(reversed(heapq.nlargest(4, z, key=lambda x: -x)))) 43 | """, 44 | ),) 45 | 46 | for source, expected_abstraction in test_cases: 47 | processed_content = performance.replace_sorted_heapq(source) 48 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 49 | return 1 50 | 51 | return 0 52 | 53 | 54 | if __name__ == "__main__": 55 | sys.exit(main()) 56 | -------------------------------------------------------------------------------- /tests/unit/test_swap_if_else.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = ( 14 | ( # Explicit only 15 | """ 16 | def f(x) -> int: 17 | if x > 10: 18 | x += 1 19 | x *= 12 20 | print(x > 30) 21 | return 100 - sum(x, 2, 3) 22 | else: 23 | return 13 24 | """, 25 | """ 26 | def f(x) -> int: 27 | if x <= 10: 28 | return 13 29 | else: 30 | x += 1 31 | x *= 12 32 | print(x > 30) 33 | return 100 - sum(x, 2, 3) 34 | """, 35 | ), 36 | ( # Implicit only 37 | """ 38 | def f(x) -> int: 39 | if x > 10: 40 | x += 1 41 | x *= 12 42 | print(x > 30) 43 | return 100 - sum(x, 2, 3) 44 | 45 | return 13 46 | """, 47 | """ 48 | def f(x) -> int: 49 | if x <= 10: 50 | return 13 51 | else: 52 | x += 1 53 | x *= 12 54 | print(x > 30) 55 | 
return 100 - sum(x, 2, 3) 56 | """, 57 | ), 58 | ( # No body 59 | """ 60 | def f(x): 61 | if x > 10: 62 | pass 63 | else: 64 | print(2) 65 | """, 66 | """ 67 | def f(x): 68 | if x <= 10: 69 | print(2) 70 | """, 71 | ), 72 | ( # Combined 73 | """ 74 | def f(x) -> int: 75 | if x > 10: 76 | if x < 100: 77 | return 4 78 | elif x >= 12: 79 | return 2 80 | return 99 81 | else: 82 | return 14 83 | """, 84 | """ 85 | def f(x) -> int: 86 | if x <= 10: 87 | return 14 88 | else: 89 | if x < 100: 90 | return 4 91 | elif x >= 12: 92 | return 2 93 | return 99 94 | """, 95 | ), 96 | ( # Non-blocking body -> swap not equivalent 97 | """ 98 | if X % 5 == 0: 99 | if X % 61 == 0: 100 | if X % (X - 4) == 0: 101 | return 61 102 | return 12 103 | """, 104 | """ 105 | if X % 5 == 0: 106 | if X % 61 == 0: 107 | if X % (X - 4) == 0: 108 | return 61 109 | return 12 110 | """, 111 | ), 112 | ( # There's a pattern going on here, it shouldn't be disturbed. 113 | """ 114 | def foo(x): 115 | if isinstance(x, str): 116 | y = foo(x) 117 | w = faa(x) 118 | return y, w 119 | if isinstance(x, int): 120 | y = faf(x) 121 | w = wow(x) 122 | return y, w 123 | if isinstance(x, float): 124 | y = wowow(x) 125 | w = papaf(x) 126 | return y, w 127 | return 1, 2 128 | """, 129 | """ 130 | def foo(x): 131 | if isinstance(x, str): 132 | y = foo(x) 133 | w = faa(x) 134 | return y, w 135 | if isinstance(x, int): 136 | y = faf(x) 137 | w = wow(x) 138 | return y, w 139 | if isinstance(x, float): 140 | y = wowow(x) 141 | w = papaf(x) 142 | return y, w 143 | return 1, 2 144 | """, 145 | ),) 146 | 147 | for source, expected_abstraction in test_cases: 148 | processed_content = fixes.swap_if_else(source) 149 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 150 | return 1 151 | 152 | return 0 153 | 154 | 155 | if __name__ == "__main__": 156 | sys.exit(main()) 157 | -------------------------------------------------------------------------------- /tests/unit/test_trace_origin.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # fmt: off 4 | # isort: skip_file 5 | # pyrefact: skip_file 6 | from os import * 7 | import sys as sys 8 | from pathlib import Path 9 | 10 | from pyrefact import tracing 11 | 12 | important_variable = 42 13 | if foo := important_variable - 1: 14 | pass 15 | 16 | 17 | def main() -> int: 18 | with Path(__file__).open("r", encoding="utf-8") as stream: 19 | source = stream.read() 20 | 21 | assert tracing.trace_origin("tracing", source)[:2] == ("from pyrefact import tracing", 10) 22 | assert tracing.trace_origin("sys", source)[:2] == ("import sys as sys", 7) 23 | assert tracing.trace_origin("sys", source)[:2] == ("import sys as sys", 7) 24 | assert tracing.trace_origin("main", source)[1] == 17 25 | assert tracing.trace_origin("important_variable", source)[:2] == ("important_variable = 42", 12) 26 | 27 | assert tracing.trace_origin("foo", source)[:2] == ("foo := important_variable - 1", 13) 28 | 29 | assert tracing.trace_origin("getcwd", source)[:2] == ("from os import *", 6) 30 | assert tracing.trace_origin("source", source)[:2] == ("source = stream.read()", 19) 31 | 32 | return 0 33 | 34 | 35 | if __name__ == "__main__": 36 | sys.exit(main()) 37 | -------------------------------------------------------------------------------- /tests/unit/test_unravel_classes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | from pyrefact import object_oriented 8 | 9 | sys.path.append(str(Path(__file__).parents[1])) 10 | import testing_infra 11 | 12 | 13 | def _test_remove_unused_self_cls() -> int: 14 | test_cases = (( 15 | """ 16 | class Foo: 17 | def __init__(self): 18 | self.bar = 3 19 | 20 | def do_stuff(self): 21 | print(self.bar) 22 | 23 | @staticmethod 24 | def do_stuff_static(var, arg): 25 | print(var + arg) 26 | 27 | @classmethod 28 | def 
do_stuff_classmethod(cls, arg): 29 | cls.do_stuff_static(1, arg) 30 | 31 | def do_stuff_classmethod_2(self, arg): 32 | self.do_stuff_static(1, arg) 33 | 34 | @classmethod 35 | def do_stuff_classmethod_unused(cls, arg): 36 | print(arg) 37 | 38 | def do_stuff_no_self(self): 39 | print(3) 40 | 41 | @classmethod 42 | @functools.lru_cache(maxsize=None) 43 | @custom_decorator 44 | def i_have_many_decorators(cls): 45 | return 10 46 | """, 47 | """ 48 | class Foo: 49 | def __init__(self): 50 | self.bar = 3 51 | 52 | def do_stuff(self): 53 | print(self.bar) 54 | 55 | @staticmethod 56 | def do_stuff_static(var, arg): 57 | print(var + arg) 58 | 59 | @classmethod 60 | def do_stuff_classmethod(cls, arg): 61 | cls.do_stuff_static(1, arg) 62 | 63 | @classmethod 64 | def do_stuff_classmethod_2(cls, arg): 65 | cls.do_stuff_static(1, arg) 66 | 67 | @staticmethod 68 | def do_stuff_classmethod_unused(arg): 69 | print(arg) 70 | 71 | @staticmethod 72 | def do_stuff_no_self(): 73 | print(3) 74 | 75 | @staticmethod 76 | @functools.lru_cache(maxsize=None) 77 | @custom_decorator 78 | def i_have_many_decorators(): 79 | return 10 80 | """, 81 | ),) 82 | 83 | for source, expected_abstraction in test_cases: 84 | processed_content = object_oriented.remove_unused_self_cls(source) 85 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 86 | return 1 87 | 88 | return 0 89 | 90 | 91 | def _test_move_staticmethod_global() -> int: 92 | test_cases = (( 93 | """ 94 | def q() -> None: 95 | print(1) 96 | Spam.weeee() 97 | 98 | class Spam: 99 | @staticmethod 100 | def weeee(): 101 | print(3) 102 | 103 | "Very important string statement" 104 | 105 | class Foo: 106 | def __init__(self): 107 | self.x = 2 108 | 109 | @staticmethod 110 | def some_static_function(x, y) -> int: 111 | return 2 + x + y 112 | 113 | @staticmethod 114 | def some_other_static_function(): 115 | print(3) 116 | """, 117 | """ 118 | def q() -> None: 119 | print(1) 120 | _weeee() 121 | 122 | def _weeee(): 123 | 
print(3) 124 | 125 | class Spam: 126 | pass 127 | 128 | "Very important string statement" 129 | 130 | def _some_other_static_function(): 131 | print(3) 132 | 133 | def _some_static_function(x, y) -> int: 134 | return 2 + x + y 135 | 136 | class Foo: 137 | def __init__(self): 138 | self.x = 2 139 | """, 140 | ), 141 | ( 142 | """ 143 | class Foo(object): 144 | @staticmethod 145 | def h(): 146 | print(1) 147 | 148 | class Bar: 149 | @staticmethod 150 | def h(): 151 | print(1) 152 | """, 153 | """ 154 | class Foo(object): 155 | @staticmethod 156 | def h(): 157 | print(1) 158 | 159 | def _h(): 160 | print(1) 161 | class Bar: 162 | pass 163 | """, 164 | ),) 165 | 166 | for source, expected_abstraction in test_cases: 167 | processed_content = object_oriented.move_staticmethod_static_scope(source, preserve=set()) 168 | if not testing_infra.check_fixes_equal(processed_content, expected_abstraction): 169 | return 1 170 | 171 | return 0 172 | 173 | 174 | def main() -> int: 175 | returncode = 0 176 | returncode += _test_remove_unused_self_cls() 177 | returncode += _test_move_staticmethod_global() 178 | 179 | return returncode 180 | 181 | 182 | if __name__ == "__main__": 183 | sys.exit(main()) 184 | -------------------------------------------------------------------------------- /tests/unit/test_unused_zip_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from pyrefact import fixes 7 | 8 | sys.path.append(str(Path(__file__).parents[1])) 9 | import testing_infra 10 | 11 | 12 | def main() -> int: 13 | test_cases = (( 14 | """ 15 | x = (a for _, a in zip(range(3), range(1, 3))) 16 | """, 17 | """ 18 | x = (a for a in range(1, 3)) 19 | """, 20 | ), 21 | ( 22 | """ 23 | x = {a for _, a in zip(range(3), range(1, 3))} 24 | """, 25 | """ 26 | x = {a for a in range(1, 3)} 27 | """, 28 | ), 29 | ( 30 | """ 31 | x = [a for _, a in zip(range(3), range(1, 3))] 32 | 
""", 33 | """ 34 | x = [a for a in range(1, 3)] 35 | """, 36 | ), 37 | ( 38 | """ 39 | x = (1 for _, _ in zip(range(3), range(1, 3))) 40 | """, 41 | """ 42 | x = (1 for _ in range(3)) 43 | """, 44 | ), 45 | ( 46 | """ 47 | x = (1 for a, q, _, _ in zip(range(3), range(1, 3), range(3, 5), (1, 2, 3))) 48 | """, 49 | """ 50 | x = (1 for a, q in zip(range(3), range(1, 3))) 51 | """, 52 | ), 53 | ( 54 | """ 55 | for _, a in zip(range(3), range(1, 3)): 56 | print(a - 1) 57 | """, 58 | """ 59 | for a in range(1, 3): 60 | print(a - 1) 61 | """, 62 | ), 63 | ( 64 | """ 65 | for a, _ in zip(range(3), range(1, 3)): 66 | print(a - 1) 67 | """, 68 | """ 69 | for a in range(3): 70 | print(a - 1) 71 | """, 72 | ), 73 | ( 74 | """ 75 | for _, _ in zip(range(3), range(1, 3)): 76 | print(10) 77 | """, 78 | """ 79 | for _ in range(3): 80 | print(10) 81 | """, 82 | ), 83 | ( 84 | """ 85 | for a, _, e, _ in zip(range(3), range(1, 3), range(3, 5), (1, 2, 3)): 86 | print(a != e != e is e) 87 | """, 88 | """ 89 | for a, e in zip(range(3), range(3, 5)): 90 | print(a != e != e is e) 91 | """, 92 | ),) 93 | 94 | for source, expected_abstraction in test_cases: 95 | processed_content = fixes.unused_zip_args(source) 96 | if not testing_infra.check_fixes_equal( 97 | processed_content, expected_abstraction, clear_paranthesises=True 98 | ): 99 | return 1 100 | 101 | return 0 102 | 103 | 104 | if __name__ == "__main__": 105 | sys.exit(main()) 106 | --------------------------------------------------------------------------------