├── .github └── workflows │ ├── pythonpublish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.rst ├── codemod_unittest_to_pytest_asserts ├── __init__.py └── tests │ ├── __init__.py │ ├── pytest_code.py │ ├── test_codemod.py │ └── unittest_code.py ├── setup.cfg └── setup.py /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | "on": 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | branches: 8 | - master 9 | jobs: 10 | Test: 11 | runs-on: ubuntu-20.04 12 | steps: 13 | - uses: actions/setup-python@v2 14 | with: 15 | python-version: "3.10" 16 | - uses: actions/checkout@v2 17 | - run: pip install pytest-cov -e . 18 | - run: py.test -vvv --cov . 
19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | /*.egg-info 3 | /build 4 | /dist 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Hans-Wilhelm Warlo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test_codemod: 2 | py.test 3 | 4 | build: 5 | python setup.py sdist bdist_wheel 6 | 7 | clean: 8 | rm -r build/ dist/ 9 | 10 | upload: 11 | twine upload dist/* 12 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | codemod-unittest-to-pytest-asserts 3 | ********************************** 4 | 5 | A `codemod <https://github.com/facebook/codemod>`_ to automatically replace 6 | unittest assertions with pytest assertions. 7 | 8 | 9 | Installation 10 | ============ 11 | 12 | This codemod requires Python 3.8 or newer. 13 | 14 | With pip, assuming Python 3.8 or newer is used:: 15 | 16 | python3 -m pip install codemod-unittest-to-pytest-asserts 17 | 18 | With pipx, assuming Python 3.8 is installed on the system:: 19 | 20 | pipx install --python $(which python3.8) codemod-unittest-to-pytest-asserts 21 | 22 | 23 | Usage 24 | ===== 25 | 26 | Run the installed command on the Python file or the directory of files you want to refactor:: 27 | 28 | codemod-unittest-to-pytest-asserts some_python_file.py 29 | 30 | or:: 31 | 32 | codemod-unittest-to-pytest-asserts some_directory/ 33 | 34 | You'll be asked to confirm all changes. Without a path argument, the current directory is used. 35 | 36 | It is recommended to run an autoformatter, like Black, after the 37 | refactoring.
38 | -------------------------------------------------------------------------------- /codemod_unittest_to_pytest_asserts/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import ast 3 | import re 4 | 5 | import astunparse 6 | import codemod 7 | 8 | TRUE_FALSE_NONE = {"True", "False", "None"} 9 | 10 | COMMENT_REGEX = re.compile(r"(\s*).*\)(\s*\#.*)") 11 | 12 | 13 | class Malformed(Exception): 14 | def __init__(self, message="Malformed", *, node): 15 | super().__init__(f"{message}: {node}: {astunparse.unparse(node)}") 16 | 17 | 18 | def parse_args(node): 19 | args = [] 20 | kwarg_list = [] 21 | 22 | for arg in node.args: 23 | args.append(astunparse.unparse(arg).replace("\n", "")) 24 | for kwarg in node.keywords: 25 | kwarg_list.append(astunparse.unparse(kwarg).replace("\n", "")) 26 | 27 | return args, kwarg_list 28 | 29 | 30 | def parse_args_and_msg(node, required_args_count, *, raise_if_malformed=True): 31 | args, kwarg_list = parse_args(node) 32 | msg = "" 33 | 34 | for i, kwarg in enumerate(kwarg_list): 35 | key, val = kwarg.split("=") 36 | if key == "msg": 37 | msg = val 38 | kwarg_list.pop(i) 39 | break 40 | 41 | if len(args) > required_args_count and type(args[required_args_count]) == str: 42 | msg = args.pop(required_args_count) 43 | 44 | if raise_if_malformed and len(args) != required_args_count: 45 | raise Malformed(node=node) 46 | 47 | return args, kwarg_list, f", {msg}" if msg else "" 48 | 49 | 50 | def _handle_equal_or_unequal(node, *, is_op, cmp_op): 51 | args, kwarg_list, msg_with_comma = parse_args_and_msg(node, 2, raise_if_malformed=False) 52 | 53 | if len(args) != 2 or len(kwarg_list) > 0: 54 | raise Malformed(f"Potentially malformed", node=node) 55 | 56 | if args[0] in TRUE_FALSE_NONE: 57 | return f"assert {args[1]} {is_op} {args[0]}{msg_with_comma}" 58 | if args[1] in TRUE_FALSE_NONE: 59 | return f"assert {args[0]} {is_op} {args[1]}{msg_with_comma}" 60 | 61 | # De-yoda 
expressions like assertEqual("foo", bar) to bar == "foo" 62 | if node.args[0].__class__ == ast.Constant and node.args[1].__class__ != ast.Constant: 63 | args = [args[1], args[0]] 64 | 65 | return f"assert {args[0]} {cmp_op} {args[1]}{msg_with_comma}" 66 | 67 | 68 | def _handle_prefix_or_suffix(node, *, prefix="", suffix=""): 69 | args, _, msg_with_comma = parse_args_and_msg(node, 1) 70 | return f"assert {prefix}{args[0]}{suffix}{msg_with_comma}" 71 | 72 | 73 | def _handle_generic_binary(node, *, op): 74 | args, _, msg_with_comma = parse_args_and_msg(node, 2) 75 | return f"assert {args[0]} {op} {args[1]}{msg_with_comma}" 76 | 77 | 78 | def _handle_generic_call(node, *, func): 79 | args, _, msg_with_comma = parse_args_and_msg(node, 2) 80 | return f"assert {func}({args[0]}, {args[1]}){msg_with_comma}" 81 | 82 | 83 | def _handle_almost_equal(node, *, op): 84 | args, _, msg_with_comma = parse_args_and_msg(node, 2) 85 | return f"assert round({args[0]} - {args[1]}, 7) {op} 0{msg_with_comma}" 86 | 87 | 88 | def handle_equal(node): 89 | return _handle_equal_or_unequal(node, is_op="is", cmp_op="==") 90 | 91 | 92 | def handle_not_equal(node): 93 | return _handle_equal_or_unequal(node, is_op="is not", cmp_op="!=") 94 | 95 | 96 | def handle_true(node): 97 | return _handle_prefix_or_suffix(node, prefix="") 98 | 99 | 100 | def handle_false(node): 101 | return _handle_prefix_or_suffix(node, prefix="not ") 102 | 103 | 104 | def handle_in(node): 105 | return _handle_generic_binary(node, op="in") 106 | 107 | 108 | def handle_not_in(node): 109 | return _handle_generic_binary(node, op="not in") 110 | 111 | 112 | def handle_is(node): 113 | return _handle_generic_binary(node, op="is") 114 | 115 | 116 | def handle_is_not(node): 117 | return _handle_generic_binary(node, op="is not") 118 | 119 | 120 | def handle_is_none(node): 121 | return _handle_prefix_or_suffix(node, suffix=" is None") 122 | 123 | 124 | def handle_is_not_none(node): 125 | return _handle_prefix_or_suffix(node, suffix=" is 
not None") 126 | 127 | 128 | def handle_is_instance(node): 129 | return _handle_generic_call(node, func="isinstance") 130 | 131 | 132 | def handle_not_is_instance(node): 133 | return _handle_generic_call(node, func="not isinstance") 134 | 135 | 136 | def handle_less(node): 137 | return _handle_generic_binary(node, op="<") 138 | 139 | 140 | def handle_less_equal(node): 141 | return _handle_generic_binary(node, op="<=") 142 | 143 | 144 | def handle_greater(node): 145 | return _handle_generic_binary(node, op=">") 146 | 147 | 148 | def handle_greater_equal(node): 149 | return _handle_generic_binary(node, op=">=") 150 | 151 | 152 | def handle_almost_equal(node): 153 | return _handle_almost_equal(node, op=">=") 154 | 155 | 156 | def handle_not_almost_equal(node): 157 | return _handle_almost_equal(node, op="!=") 158 | 159 | 160 | def handle_raises(node, **kwargs): 161 | if kwargs.get("withitem"): 162 | return handle_with_raises(node, **kwargs) 163 | args, _ = parse_args(node) 164 | if len(args) > 2: 165 | raise Malformed(node=node) 166 | if len(args) == 2: 167 | return f"pytest.raises({args[0]}, {args[1]})" 168 | 169 | 170 | def handle_with_raises(node, **kwargs): 171 | args, _ = parse_args(node) 172 | optional_vars = kwargs.get('optional_vars', None) 173 | if len(args) > 1: 174 | raise Malformed(node=node) 175 | 176 | if optional_vars: 177 | return f"with pytest.raises({args[0]}) as {optional_vars.id}:" 178 | return f"with pytest.raises({args[0]}):" 179 | 180 | 181 | assert_mapping = { 182 | "assertEqual": handle_equal, 183 | "assertEquals": handle_equal, 184 | "assertNotEqual": handle_not_equal, 185 | "assertNotEquals": handle_not_equal, 186 | "assert_": handle_true, 187 | "assertTrue": handle_true, 188 | "assertFalse": handle_false, 189 | "assertIn": handle_in, 190 | "assertNotIn": handle_not_in, 191 | "assertIs": handle_is, 192 | "assertIsNot": handle_is_not, 193 | "assertIsNone": handle_is_none, 194 | "assertIsNotNone": handle_is_not_none, 195 | "assertIsInstance": 
handle_is_instance, 196 | "assertNotIsInstance": handle_not_is_instance, 197 | "assertLess": handle_less, 198 | "assertLessEqual": handle_less_equal, 199 | "assertGreater": handle_greater, 200 | "assertGreaterEqual": handle_greater_equal, 201 | "assertAlmostEqual": handle_almost_equal, 202 | "assertNotAlmostEqual": handle_not_almost_equal, 203 | "assertRaises": handle_raises, 204 | } 205 | 206 | 207 | def convert(node): 208 | node_call = node_get_call(node) 209 | f = assert_mapping.get(node_get_func_attr(node_call), None) 210 | if not f: 211 | return None 212 | 213 | try: 214 | if isinstance(node, ast.With): 215 | return f(node_call, withitem=True, optional_vars=node.items[0].optional_vars) 216 | return f(node_call) 217 | except Malformed as e: 218 | print(str(e)) 219 | return None 220 | 221 | 222 | def dfs_walk(node): 223 | """ 224 | Walk along the nodes of the AST in a DFS fashion returning the pre-order-tree-traversal 225 | """ 226 | 227 | stack = [node] 228 | for child in ast.iter_child_nodes(node): 229 | stack.extend(dfs_walk(child)) 230 | return stack 231 | 232 | 233 | def node_get_func_attr(node): 234 | if isinstance(node, ast.Call): 235 | return getattr(node.func, "attr", None) 236 | 237 | 238 | def node_get_call(node): 239 | if not (isinstance(node, ast.Expr) or isinstance(node, ast.With)): 240 | return False 241 | 242 | if isinstance(node, ast.Expr): 243 | value = getattr(node, "value", None) 244 | if isinstance(value, ast.Call): 245 | return value 246 | 247 | if isinstance(node, ast.With): 248 | value = getattr( 249 | node.items[0], "context_expr", None 250 | ) # Naively choosing the first item in the with 251 | if isinstance(value, ast.Call): 252 | return value 253 | return None 254 | 255 | 256 | def get_col_offset(node): 257 | return node.col_offset 258 | 259 | 260 | def get_lineno(node): 261 | # We generally use `lineno` from the AST node, but special case for `With` expressions 262 | if isinstance(node, ast.With): 263 | return 
node.items[0].context_expr.lineno 264 | 265 | return node.lineno 266 | 267 | 268 | def get_end_lineno(node): 269 | # We generally use `end_lineno` directly from the AST node, but special case for `With` expressions 270 | if isinstance(node, ast.With): 271 | return node.items[0].context_expr.end_lineno 272 | 273 | return node.end_lineno 274 | 275 | 276 | def assert_patches(list_of_lines): 277 | """ 278 | Main method where we get the list of lines from codemod. 279 | 1. Parses it with AST 280 | 2. Traverses the AST in a pre-order-tree traversal 281 | 3. Grab the Call values we are interested in e.g. `assertEqual()` 282 | 4. Try executing `convert` on the Call node or continue 283 | 5. Construct a codemod.Patch for the conversion and replace start->end lines with the conversion 284 | 6. Handle special cases with importing pytest if it used and not imported, and append comment if it exists. 285 | """ 286 | 287 | patches = [] 288 | joined_lines = "".join(list_of_lines) 289 | ast_parsed = ast.parse(joined_lines) 290 | 291 | pytest_imported = "import pytest" in joined_lines 292 | 293 | line_deviation = 0 294 | for node in dfs_walk(ast_parsed): 295 | if not node_get_call(node): 296 | continue 297 | 298 | converted = convert(node) 299 | if not converted: 300 | continue 301 | 302 | assert_line = get_col_offset(node) * " " + converted + "\n" 303 | start_line = get_lineno(node) 304 | end_line = get_end_lineno(node) 305 | 306 | patches.append( 307 | codemod.Patch( 308 | start_line - line_deviation - 1, 309 | end_line_number=end_line - line_deviation, 310 | new_lines=assert_line, 311 | ) 312 | ) 313 | 314 | requires_import = "pytest." 
in assert_line 315 | if requires_import and not pytest_imported: 316 | patches.append( 317 | codemod.Patch(0, end_line_number=0, new_lines="import pytest\n") 318 | ) 319 | line_deviation -= 1 320 | pytest_imported = True 321 | 322 | comment_line = COMMENT_REGEX.search( 323 | list_of_lines[min(end_line - 1, len(list_of_lines) - 1)] 324 | ) 325 | line_deviation += end_line - start_line 326 | 327 | if comment_line: 328 | comment = comment_line.group(1) + comment_line.group(2).lstrip() + "\n" 329 | patches.append( 330 | codemod.Patch( 331 | end_line - line_deviation, 332 | end_line_number=end_line - line_deviation, 333 | new_lines=comment, 334 | ) 335 | ) 336 | line_deviation -= 1 337 | 338 | return patches 339 | 340 | 341 | def is_py(filename): 342 | """ 343 | Filter method using filename's to select what files to evaluate for codemodding 344 | """ 345 | 346 | return filename.split(".")[-1] == "py" 347 | 348 | 349 | def main(): 350 | import sys 351 | if sys.version_info < (3, 8): 352 | raise RuntimeError("This script requires Python version >=3.8") 353 | 354 | try: 355 | path = sys.argv[1] 356 | except IndexError: 357 | path = "." 358 | 359 | codemod.Query( 360 | assert_patches, path_filter=is_py, root_directory=path 361 | ).run_interactive() 362 | print( 363 | "\nHINT: Consider running a formatter to correctly format your new assertions!" 
364 | ) 365 | 366 | 367 | if __name__ == "__main__": 368 | main() 369 | -------------------------------------------------------------------------------- /codemod_unittest_to_pytest_asserts/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warlo/codemod-unittest-to-pytest-asserts/eaabf0e28a7c3712b20d0e0ca14481d13bcf4c46/codemod_unittest_to_pytest_asserts/tests/__init__.py -------------------------------------------------------------------------------- /codemod_unittest_to_pytest_asserts/tests/pytest_code.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | class ExampleTest: 3 | def test_something(self): 4 | assert 1 == 1 5 | assert 1 == 1, '1 should always be 1' 6 | assert 1 == 1, '1 should always be 1' 7 | assert 1 == 1 8 | # 1 should always be one 9 | 10 | def inner_test_method(): 11 | assert 1 == 1 12 | with pytest.raises(ValueError): 13 | raise ValueError("SomeError") 14 | 15 | innerTestMethod() 16 | 17 | assert 1 == 1 18 | 19 | def test_lots_of_arguments(self): 20 | def inside_another_function(): 21 | 22 | assert get_product_from_backend_product_with_supplier_product_and_cart_etc_etc_etc(product__backend_product_id=self.backend_product.id, product_id=self.product.id, name='Julebrus') is True 23 | assert True 24 | 25 | assert not False 26 | 27 | def test_assert_raises(self): 28 | with pytest.raises(ZeroDivisionError) as exc: 29 | divide_by_zero = 3 / 0 30 | assert exc.exception.args[0] == 'division by zero' 31 | 32 | def test_assert_raises_legacy(self): 33 | def foo(): 34 | raise ValueError("bar") 35 | pytest.raises(ValueError, foo) 36 | 37 | def test_various_ops(self): 38 | assert 'a' in 'abc' 39 | assert 'a' not in 'def' 40 | assert 1 != 2 41 | assert None is None 42 | assert True is not False 43 | assert None is None 44 | assert True is not None 45 | assert isinstance(1, int) 46 | assert not isinstance(1, str) 47 | assert 
1 < 2 48 | assert 2 <= 2 49 | assert 3 > 2 50 | assert 4 >= 3 51 | 52 | def test_de_yoda(self): 53 | foo = "baz" 54 | assert foo == 'bar' 55 | -------------------------------------------------------------------------------- /codemod_unittest_to_pytest_asserts/tests/test_codemod.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import shutil 3 | 4 | import codemod 5 | 6 | from codemod_unittest_to_pytest_asserts import assert_patches, is_py 7 | 8 | DIRNAME = pathlib.Path(__file__).parent 9 | 10 | UNITTEST_FILE = DIRNAME / "unittest_code.py" 11 | PYTEST_FILE = DIRNAME / "pytest_code.py" 12 | 13 | 14 | def test_codemod(tmp_path): 15 | victim = tmp_path / "victim.py" 16 | shutil.copy(UNITTEST_FILE, victim) 17 | for patch in codemod.Query( 18 | assert_patches, 19 | root_directory=tmp_path, 20 | path_filter=is_py, 21 | ).generate_patches(): 22 | lines = list(open(patch.path)) 23 | patch.apply_to(lines) 24 | pathlib.Path(patch.path).write_text("".join(lines)) 25 | 26 | assert victim.read_text() == (DIRNAME / PYTEST_FILE).read_text() 27 | -------------------------------------------------------------------------------- /codemod_unittest_to_pytest_asserts/tests/unittest_code.py: -------------------------------------------------------------------------------- 1 | class ExampleTest: 2 | def test_something(self): 3 | self.assertEqual(1, 1) 4 | self.assertEqual(1, 1, msg="1 should always be 1") 5 | self.assertEqual(1, 1, "1 should always be 1") 6 | self.assertEqual(1, 1) # 1 should always be one 7 | 8 | def inner_test_method(): 9 | self.assertEqual(1, 1) 10 | with self.assertRaises(ValueError): # This error is always raised! 
11 | raise ValueError("SomeError") 12 | 13 | innerTestMethod() 14 | 15 | self.assertEqual(1, 1) 16 | 17 | def test_lots_of_arguments(self): 18 | def inside_another_function(): 19 | 20 | self.assertEqual( 21 | get_product_from_backend_product_with_supplier_product_and_cart_etc_etc_etc( 22 | product__backend_product_id=self.backend_product.id, 23 | product_id=self.product.id, 24 | name="Julebrus", 25 | ), 26 | True 27 | ) 28 | self.assertTrue(True) 29 | 30 | self.assertFalse(False) 31 | 32 | def test_assert_raises(self): 33 | with self.assertRaises(ZeroDivisionError) as exc: 34 | divide_by_zero = 3 / 0 35 | self.assertEqual(exc.exception.args[0], 'division by zero') 36 | 37 | def test_assert_raises_legacy(self): 38 | def foo(): 39 | raise ValueError("bar") 40 | self.assertRaises(ValueError, foo) 41 | 42 | def test_various_ops(self): 43 | self.assertIn("a", "abc") 44 | self.assertNotIn("a", "def") 45 | self.assertNotEqual(1, 2) 46 | self.assertIs(None, None) 47 | self.assertIsNot(True, False) 48 | self.assertIsNone(None) 49 | self.assertIsNotNone(True) 50 | self.assertIsInstance(1, int) 51 | self.assertNotIsInstance(1, str) 52 | self.assertLess(1, 2) 53 | self.assertLessEqual(2, 2) 54 | self.assertGreater(3, 2) 55 | self.assertGreaterEqual(4, 3) 56 | 57 | def test_de_yoda(self): 58 | foo = "baz" 59 | self.assertEqual('bar', foo) 60 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = codemod-unittest-to-pytest-asserts 3 | version = 1.2.0 4 | author = Hans-Wilhelm Warlo 5 | author_email = hw@warlo.no 6 | license = MIT 7 | description = Codemod to refactor unittest assertions to pytest assertions. 
8 | url = https://github.com/hanswilw/codemod-unittest-to-pytest-asserts 9 | long_description = file: README.rst 10 | classifiers = 11 | Programming Language :: Python :: 3 12 | Programming Language :: Python :: 3.8 13 | 14 | [options] 15 | zip_safe = False 16 | include_package_data = False 17 | packages = find: 18 | python_requires = >= 3.8 19 | install_requires = 20 | codemod >= 1.0.0 21 | astunparse >= 1.6.2 22 | 23 | [options.packages.find] 24 | exclude = 25 | *.tests 26 | *.tests.* 27 | 28 | [options.entry_points] 29 | console_scripts = 30 | codemod-unittest-to-pytest-asserts = codemod_unittest_to_pytest_asserts:main 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | --------------------------------------------------------------------------------