├── docs ├── design │ ├── opcode_executor.md │ ├── eval_frame.md │ ├── output-restoration.md │ ├── tracker-and-guard.md │ ├── stringify-guard.md │ ├── closure.md │ └── builtin-dispatcher.md ├── compat │ └── python311 │ │ ├── index.md │ │ ├── CALL.md │ │ └── closure.md ├── notes │ └── function-and-method.md ├── profiler_introduction.md └── instructions │ ├── UNPACK_SEQUENCE.md │ └── CALL_FUNCTION.md ├── tests ├── error_test_paddle_cast.py ├── test_tensor_slice.py ├── tests_legacy │ ├── run_all.sh │ ├── test_generator.py │ ├── test_symbolic_for.py │ ├── test_symbolic_fallback.py │ ├── test_symbolic_infer_meta.py │ ├── test_resnet50.py │ ├── error_test_symbolic_load_before_call.py │ ├── test_symbolic_backward.py │ ├── test_simple_net.py │ ├── test_error_cases.py │ ├── test_optransform_cache.py │ ├── test_symbolic_fallback_if.py │ ├── error_test_trace_cache.py │ ├── test_symbolic_trace.py │ ├── error_test_sir_call.py │ ├── test_basic_translation.py │ ├── error_test_trace_cache.py.zhanfei │ ├── test_symbolic_nested.py │ ├── error_test_resnet_with_trace_cache.py │ ├── test_case_base.py │ └── test_resnet50_backward.py ├── test_multiple_args.py ├── test_delete_fast.py ├── error_test_guard.py ├── test_13_make_function.py ├── test_error_handling.py ├── extract_errors.py ├── error_test_resnet50.py ├── test_02_store_inplace.py ├── test_09_f_string.py ├── test_str_format.py ├── error_test_general_bytecode.py ├── test_numpy.py ├── test_dup_top.py ├── test_simulate_initialize.py ├── test_01_basic.py ├── test_constant_graph.py ├── test_numpy_var_if.py ├── error_test_jump_inline.py ├── test_execution_base.py ├── run_all.sh ├── test_stack.py ├── test_07_unpack.py ├── test_16_paddle_api.py ├── test_call_object.py ├── test_resnet.py ├── test_tensor_dtype_in_guard.py ├── test_segment_linear.py ├── test_trace_list_arg.py ├── run_all_paddle_ci.sh ├── test_exception.py ├── test_output_restoration.py ├── test_range.py ├── test_20_string.py ├── test_03_tuple.py ├── test_sir_rollback.py ├── 
test_guard_outputs.py ├── test_17_paddle_layer.py ├── test_18_tensor_method.py ├── test_08_rot.py ├── test_10_build_unpack.py ├── test_guard_user_defined_fn.py ├── test_resnet50_backward.py ├── test_enumerate.py ├── test_cost_model.py ├── test_map.py ├── test_11_jumps.py ├── test_06_call_function.py ├── test_15_slice.py ├── test_inplace_api.py ├── test_break_graph.py ├── test_code_status.py ├── test_21_global.py └── test_instruction_translator_cache.py ├── sot ├── opcode_translator │ ├── executor │ │ ├── __init__.py │ │ ├── instr_flag.py │ │ ├── dispatch_functions.py │ │ ├── variables │ │ │ └── __init__.py │ │ └── tracker_viewer.py │ ├── __init__.py │ ├── custom_code.py │ ├── instruction_utils │ │ ├── __init__.py │ │ └── opcode_info.py │ ├── transform.py │ └── skip_files.py ├── __init__.py ├── psdb.py ├── utils │ ├── __init__.py │ ├── exceptions.py │ ├── code_status.py │ ├── paddle_api_config.py │ └── magic_methods.py ├── profiler.py ├── symbolic │ ├── compile_cache.py │ └── symbolic_context.py └── translate.py ├── .gitignore ├── examples ├── trace_basic.py ├── graph_break.py ├── guard.py └── run_all.sh ├── .github └── workflows │ ├── codestyle.yaml │ ├── unittest.yaml │ └── paddle_ci.yaml ├── README.md ├── .pre-commit-config.yaml └── pyproject.toml /docs/design/opcode_executor.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/error_test_paddle_cast.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import variable_dispatch # noqa: F401 2 | -------------------------------------------------------------------------------- /sot/opcode_translator/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import eval_frame_callback # noqa: F401 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.sw* 3 | user_tag 4 | 5 | # Editor config 6 | .vscode 7 | 8 | # Environments 9 | venv/ 10 | .venv/ 11 | 12 | # Build 13 | build/ 14 | *.egg-info 15 | -------------------------------------------------------------------------------- /sot/opcode_translator/custom_code.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import types 4 | from typing import NamedTuple 5 | 6 | 7 | class CustomCode(NamedTuple): 8 | code: types.CodeType | None 9 | disable_eval_frame: bool 10 | -------------------------------------------------------------------------------- /sot/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import psdb # noqa: F401 2 | from .opcode_translator.breakpoint import ( # noqa: F401 3 | BM, 4 | add_breakpoint, 5 | add_event, 6 | ) 7 | from .opcode_translator.skip_files import skip_function # noqa: F401 8 | from .translate import symbolic_translate # noqa: F401 9 | -------------------------------------------------------------------------------- /docs/compat/python311/index.md: -------------------------------------------------------------------------------- 1 | # Python 3.11 适配 2 | 3 | ## 字段适配 4 | 5 | 指 CodeObject 的字段修改适配,Python 3.11 主要新增了 `co_exceptiontable` 字段,并对 `co_linetable` 进行了一定的修改 6 | 7 | - [`co_linetable`](./co_linetable.md) 8 | 9 | ## 字节码修改适配 10 | 11 | - [CALL 相关字节码](./CALL.md) 12 | 13 | - [closure 相关修改](./closure.md) 14 | -------------------------------------------------------------------------------- /tests/test_tensor_slice.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: paddle.Tensor): 9 | return x[:, 0] 10 | 11 | 12 | class TestExecutor(TestCaseBase): 13 | def test_tensor_slice(self): 14 | x = paddle.randn((10, 10)) 15 | self.assert_results(foo, x) 16 | 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /tests/tests_legacy/run_all.sh: -------------------------------------------------------------------------------- 1 | # 遍历目录下的所有 Python 文件 2 | export PYTHONPATH=$PYTHONPATH:../ 3 | 4 | for file in ./test_*.py; do 5 | # 检查文件是否为 Python 文件 6 | if [ -f "$file" ]; then 7 | echo Running: PYTHONPATH=$PYTHONPATH" python " $file 8 | # 执行文件 9 | python "$file" 10 | if [ $? 
-ne 0 ]; then 11 | echo "run $file failed" 12 | exit 1 13 | fi 14 | fi 15 | done 16 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | 6 | def gen(): 7 | yield from range(10) 8 | 9 | 10 | def case1(): 11 | sum = 0 12 | for i in gen(): 13 | sum += i 14 | return sum 15 | 16 | 17 | class TestGen(TestCaseBase): 18 | def test_gen(self): 19 | self.assert_results(case1) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tests/test_multiple_args.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x, y): 9 | ret = x + y 10 | return ret 11 | 12 | 13 | class TestMultipleArgs(TestCaseBase): 14 | def test_multiple_args(self): 15 | x = paddle.to_tensor([1.0]) 16 | y = paddle.to_tensor([2.0]) 17 | self.assert_results(foo, x, y) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /tests/test_delete_fast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def test_delete_fast(a): 11 | a = a + 2 12 | t = a * 3 13 | del t 14 | return a 15 | 16 | 17 | class TestExecutor(TestCaseBase): 18 | def test_simple(self): 19 | a = paddle.to_tensor(1) 20 | self.assert_results(test_delete_fast, a) 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- 
/sot/opcode_translator/executor/instr_flag.py: -------------------------------------------------------------------------------- 1 | # flags for instructions 2 | 3 | 4 | class FORMAT_VALUE_FLAG: 5 | FVC_MASK = 0x3 6 | FVC_NONE = 0x0 7 | FVC_STR = 0x1 8 | FVC_REPR = 0x2 9 | FVC_ASCII = 0x3 10 | FVS_MASK = 0x4 11 | FVS_HAVE_SPEC = 0x4 12 | 13 | 14 | class MAKE_FUNCTION_FLAG: 15 | MF_HAS_CLOSURE = 0x08 16 | MF_HAS_ANNOTATION = 0x04 17 | MF_HAS_KWDEFAULTS = 0x02 18 | MF_HAS_DEFAULTS = 0x01 19 | 20 | 21 | class CALL_FUNCTION_EX_FLAG: 22 | CFE_HAS_KWARGS = 0x01 23 | -------------------------------------------------------------------------------- /tests/error_test_guard.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def fn(x): 9 | # now do not support `with xx_guard()`, will fallback to dygraph 10 | with paddle.static.amp.fp16_guard(): 11 | out = x + 1 12 | return out 13 | 14 | 15 | class TestGuard(TestCaseBase): 16 | def test_simple(self): 17 | x = paddle.to_tensor(2) 18 | self.assert_results(fn, x) 19 | 20 | 21 | if __name__ == "__main__": 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /sot/opcode_translator/instruction_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .instruction_utils import ( # noqa: F401 2 | Instruction, 3 | calc_offset_from_bytecode_offset, 4 | calc_stack_effect, 5 | convert_instruction, 6 | gen_instr, 7 | get_instructions, 8 | instrs_info, 9 | modify_extended_args, 10 | modify_instrs, 11 | modify_vars, 12 | relocate_jump_target, 13 | replace_instr, 14 | reset_offset, 15 | ) 16 | from .opcode_analysis import ( # noqa: F401 17 | Space, 18 | analysis_inputs, 19 | analysis_used_names_with_space, 20 | ) 21 | -------------------------------------------------------------------------------- 
/tests/test_13_make_function.py: -------------------------------------------------------------------------------- 1 | # MAKE_FUNCTION 2 | # CALL_FUNCTION_KW 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | 11 | 12 | def make_fn(x: paddle.Tensor): 13 | def fn(a, b=2, c=3, d=4): 14 | return a + b + c + d 15 | 16 | return fn(1) + fn(2, c=5) + x 17 | 18 | 19 | class TestExecutor(TestCaseBase): 20 | def test_simple(self): 21 | self.assert_results(make_fn, paddle.to_tensor(1)) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_for.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | for i in range(int(x)): 10 | print("yes") 11 | return x 12 | 13 | 14 | def case2(x): 15 | sum = 0 16 | for i in x: 17 | sum += i 18 | return sum 19 | 20 | 21 | class TestFor(TestCaseBase): 22 | def test(self): 23 | self.assert_results(case1, paddle.to_tensor([4])) 24 | self.assert_results(case2, paddle.to_tensor([4.0, 1.0, 2.0, 3.0])) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/test_error_handling.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import sot 6 | 7 | 8 | def fn_with_try_except(): 9 | sot.psdb.breakgraph() 10 | sot.psdb.fallback() 11 | try: 12 | raise ValueError("ValueError") 13 | except ValueError: 14 | print("catch ValueError") 15 | return True 16 | 17 | 18 | class TestErrorHandling(TestCaseBase): 19 | @strict_mode_guard(0) 20 | def test_fn_with_try_except(self): 21 | 
self.assert_results(fn_with_try_except) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/extract_errors.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | runtime_error_msg = sys.stdin.read() 5 | 6 | pattern = r'File "?(.*?)"?, line (\d+),.*\n(.*?)\n(.*?)$' 7 | for match in re.finditer(pattern, runtime_error_msg, re.MULTILINE): 8 | file = match.group(1) 9 | if file.startswith("./"): 10 | file = f"tests/{file[2:]}" 11 | line = match.group(2) 12 | error_info = match.group(4) 13 | if "AssertionError" not in error_info: 14 | # error_info = match.group(3) + '\n' + match.group(4) 15 | output = f"::error file={file},line={line}::Error" 16 | print(output) 17 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_fallback.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | kk = "sdfsdf" 10 | y = 2 + x 11 | ret = paddle.nn.functional.relu(y) 12 | print(y) 13 | print(y.numpy()) 14 | print("yes") 15 | print("no") 16 | # for i in range(10): 17 | ret = ret + 2 + x 18 | return ret 19 | 20 | 21 | class TestFallback(TestCaseBase): 22 | def test_bool(self): 23 | x = paddle.to_tensor([1.0]) 24 | self.assert_results(case1, x) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_infer_meta.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from sot import symbolic_translate 7 | 8 | 9 | def case1(x): 10 | y = 2 + x 11 | y = y + 1 12 | ret = 
paddle.nn.functional.relu(y) 13 | 14 | assert list(y.meta.shape) == [2] 15 | z = paddle.concat([x, y]) 16 | assert list(z.meta.shape) == [4] 17 | 18 | 19 | class TestIf(TestCaseBase): 20 | def test_if_1(self): 21 | x = paddle.to_tensor([1.0, 2.0]) 22 | symbolic_translate(case1)(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tests/error_test_resnet50.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.vision import resnet50 7 | 8 | 9 | class SimpleNet(paddle.nn.Layer): 10 | def __init__(self): 11 | super().__init__() 12 | self.relu = paddle.nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(x) 16 | return x 17 | 18 | 19 | class TestNet(TestCaseBase): 20 | def test(self): 21 | image = paddle.rand((1, 3, 255, 255)) 22 | net = resnet50() 23 | self.assert_results(net, image) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_02_store_inplace.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: int, y: paddle.Tensor): 9 | x = x + 1 10 | y = y + 1 11 | x += y 12 | return x 13 | 14 | 15 | class TestStoreInplace(TestCaseBase): 16 | def test_simple(self): 17 | self.assert_results(foo, 1, paddle.to_tensor(2)) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | 23 | 24 | # Instructions: 25 | # LOAD_FAST 26 | # BINARY_ADD 27 | # STORE_FAST (new) 28 | # INPLACE_ADD (new) 29 | # RETURN_VALUE 30 | 31 | # Variables: 32 | # ConstantVariable 33 | # TensorVariable 34 | -------------------------------------------------------------------------------- 
/tests/tests_legacy/test_resnet50.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.vision import resnet50 7 | 8 | 9 | class SimpleNet(paddle.nn.Layer): 10 | def __init__(self): 11 | super().__init__() 12 | self.relu = paddle.nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(x) 16 | return x 17 | 18 | 19 | class TestNet(TestCaseBase): 20 | def test(self): 21 | image = paddle.rand((1, 3, 255, 255)) 22 | net = resnet50() 23 | self.assert_results(net, image) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_09_f_string.py: -------------------------------------------------------------------------------- 1 | # FORMAT_VALUE (new) 2 | # BUILD_STRING (new) 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | from sot.psdb import assert_true 11 | 12 | 13 | def foo(x: paddle.Tensor): 14 | whilespace = 123 15 | hello_world = f"Hello {whilespace} World" 16 | z = assert_true(hello_world == "Hello 123 World") 17 | x = x + 1 18 | return x 19 | 20 | 21 | class TestFString(TestCaseBase): 22 | def test_fstring(self): 23 | self.assert_results(foo, paddle.to_tensor(1)) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_str_format.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | 8 | # copy from python library _distutils_hack/__init__.py 9 | def find_spec(self, fullname, path, target=None): 10 | method_name = 'spec_for_{fullname}'.format( 11 | **{'self': self, 'fullname': fullname} 12 
| ) 13 | method = getattr(self, method_name, lambda: None) 14 | return method() 15 | 16 | 17 | class TestExecutor(TestCaseBase): 18 | def test_simple(self): 19 | self.assert_results(find_spec, "self", "fullname", "path", None) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /examples/trace_basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import paddle 4 | from sot.translate import symbolic_translate 5 | 6 | 7 | def foo(x: paddle.Tensor, y: paddle.Tensor): 8 | z = x + y 9 | return z + 1 10 | 11 | 12 | def main(): 13 | x = paddle.rand([2, 3]) 14 | y = paddle.rand([2, 3]) 15 | dygraph_out = foo(x, y) 16 | symbolic_translate_out = symbolic_translate(foo)(x, y) 17 | 18 | print("dygraph_out:", dygraph_out) 19 | print("symbolic_translate_out:", symbolic_translate_out) 20 | np.testing.assert_allclose( 21 | dygraph_out.numpy(), symbolic_translate_out.numpy() 22 | ) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_symbolic_load_before_call.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | # TODO(SigureMo): Fix this test case 8 | 9 | 10 | def nest_func_3(x, y): 11 | a = x + y 12 | b = x * a 13 | z = b + nest_func_2(a) 14 | return z 15 | 16 | 17 | def nest_func_2(z): 18 | print(z) 19 | return z 20 | 21 | 22 | class Test(TestCaseBase): 23 | def test(self): 24 | self.assert_results( 25 | nest_func_3, 26 | paddle.to_tensor([1, 1, 1, 1]), 27 | paddle.to_tensor([2, 3, 4, 5]), 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | unittest.main() 33 | -------------------------------------------------------------------------------- 
/tests/tests_legacy/test_symbolic_backward.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from test_case_base import TestCaseBase 5 | 6 | import paddle 7 | from sot import symbolic_translate 8 | 9 | 10 | def func(x, y): 11 | return x * y 12 | 13 | 14 | class TestNet(TestCaseBase): 15 | def test(self): 16 | x = paddle.to_tensor([5, 3]) 17 | y = paddle.to_tensor([1, 3]) 18 | x.stop_gradient = False 19 | self.assert_results(func, x, y) 20 | 21 | ret = symbolic_translate(func)(x, y) 22 | ret.backward() 23 | np.testing.assert_allclose(x.grad.numpy(), [1.0, 3.0]) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /.github/workflows/codestyle.yaml: -------------------------------------------------------------------------------- 1 | name: CodeStyle Check 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | runs-on: ubuntu-latest 12 | name: CodeStyle Check 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v3 16 | 17 | - name: Install python 18 | uses: actions/setup-python@v4 19 | with: 20 | # Run pre-commit on the lowest supported python version 21 | python-version: '3.8' 22 | 23 | - name: Install dependencies 24 | run: | 25 | pip install pre-commit 26 | - name: Precommit Check 27 | run : pre-commit run --all-files 28 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_simple_net.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | class SimpleNet(paddle.nn.Layer): 9 | def __init__(self): 10 | super().__init__() 11 | self.linear1 = paddle.nn.Linear(10, 3) 12 | self.linear2 = paddle.nn.Linear(3, 1) 13 | 14 | def forward(self, x): 
15 | out1 = self.linear1(x) 16 | out2 = self.linear2(out1) 17 | return out1 + out2 18 | 19 | 20 | class TestNet(TestCaseBase): 21 | def test(self): 22 | inp = paddle.rand((10,)) 23 | net = SimpleNet() 24 | self.assert_results(net, inp) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/error_test_general_bytecode.py: -------------------------------------------------------------------------------- 1 | import dis 2 | import unittest 3 | 4 | from test_case_base import TestCaseBase 5 | 6 | 7 | def func(): 8 | return True is True is not False 9 | 10 | 11 | def func2(x): 12 | # TODO(@xiaojian): SIR not used by output. 13 | y = x + 1 14 | return True is True is not False 15 | 16 | 17 | def func3(): 18 | i = 0 19 | 20 | def inner(): 21 | return i + 1 22 | 23 | return inner() 24 | 25 | 26 | class TestExecutor(TestCaseBase): 27 | def test_simple(self): 28 | self.assert_results(func3) 29 | # self.assert_results(func2, paddle.to_tensor(1.0)) 30 | 31 | 32 | dis.dis(func3) 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_error_cases.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case_list_comp(x): 9 | def add_one(x): 10 | return x + 1 11 | 12 | ret = [add_one(x) for x in x] 13 | return ret 14 | 15 | 16 | def case_map_structure(x): 17 | def add_one(x): 18 | return x + 1 19 | 20 | ret = paddle.utils.map_structure(add_one, x) 21 | return ret 22 | 23 | 24 | class Test(TestCaseBase): 25 | def test(self): 26 | self.assert_results(case_list_comp, paddle.to_tensor([1, 2, 4, 3])) 27 | self.assert_results(case_map_structure, paddle.to_tensor([1, 2, 4, 3])) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | 
-------------------------------------------------------------------------------- /tests/tests_legacy/test_optransform_cache.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from sot import symbolic_translate 7 | from sot.opcode_translator.instruction_translator import OpcodeExecutorCache 8 | 9 | 10 | def case1(x): 11 | for i in range(int(x)): 12 | print("yes") 13 | return x 14 | 15 | 16 | class TestFor(TestCaseBase): 17 | def test(self): 18 | symbolic_translate(case1)(paddle.to_tensor([4])) 19 | symbolic_translate(case1)(paddle.to_tensor([4])) 20 | symbolic_translate(case1)(paddle.to_tensor([4])) 21 | assert OpcodeExecutorCache().hit_num == 2, "cache hit num should be 2" 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /examples/graph_break.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import paddle 4 | from sot.translate import symbolic_translate 5 | 6 | 7 | def foo(cond: paddle.Tensor, x: paddle.Tensor): 8 | x += 1 9 | if cond: 10 | x += 1 11 | else: 12 | x -= 1 13 | return x 14 | 15 | 16 | def main(): 17 | cond = paddle.to_tensor(True) 18 | x = paddle.to_tensor(0) 19 | dygraph_out = foo(cond, x) 20 | symbolic_translate_out = symbolic_translate(foo)(cond, x) 21 | 22 | print("dygraph_out:", dygraph_out) 23 | print("symbolic_translate_out:", symbolic_translate_out) 24 | np.testing.assert_allclose( 25 | dygraph_out.numpy(), symbolic_translate_out.numpy() 26 | ) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_fallback_if.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from 
test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(cond, x): 9 | cond = 2 + cond 10 | if cond: 11 | print("yes") 12 | x = x + 4 13 | else: 14 | print("no") 15 | x = x - 4 16 | ret = paddle.nn.functional.relu(x) 17 | return ret 18 | 19 | 20 | class TestIf(TestCaseBase): 21 | def test_if_1(self): 22 | self.assert_results( 23 | case1, paddle.to_tensor([4.0]), paddle.to_tensor([4.0]) 24 | ) 25 | self.assert_results( 26 | case1, paddle.to_tensor([-2.0]), paddle.to_tensor([4.0]) 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/test_numpy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from test_case_base import TestCaseBase, strict_mode_guard 5 | 6 | import paddle 7 | 8 | 9 | def foo(x, y): 10 | ret = x + y 11 | return ret 12 | 13 | 14 | class TestNumpy(TestCaseBase): 15 | def test_tensor_add_numpy_number(self): 16 | x = paddle.to_tensor([1.0]) 17 | y = np.int64(2) 18 | self.assert_results(foo, x, y) 19 | self.assert_results(foo, y, x) 20 | 21 | @strict_mode_guard(0) 22 | def test_tensor_add_numpy_array(self): 23 | x = paddle.to_tensor([1.0]) 24 | y = np.array(2.0) 25 | self.assert_results(foo, x, y) 26 | self.assert_results(foo, y, x) 27 | 28 | 29 | if __name__ == "__main__": 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /examples/guard.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from sot.translate import symbolic_translate 3 | 4 | 5 | def foo(x: paddle.Tensor, y: paddle.Tensor, z: paddle.Tensor): 6 | a = x + y 7 | a *= x 8 | return a, z 9 | 10 | 11 | def main(): 12 | a = paddle.rand([1]) 13 | b = paddle.rand([2, 3]) 14 | c = paddle.rand([4]) 15 | d = paddle.rand([5, 6]) 16 | e = paddle.rand([]) 17 | sym_foo = 
symbolic_translate(foo) 18 | dygraph_out = foo(a, b, c) 19 | symbolic_translate_out = sym_foo(a, b, c) 20 | 21 | print("dygraph_out:", dygraph_out) 22 | print("symbolic_translate_out:", symbolic_translate_out) 23 | 24 | # cache hit 25 | sym_foo(a, b, d) 26 | 27 | # cache miss 28 | sym_foo(e, b, c) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /examples/run_all.sh: -------------------------------------------------------------------------------- 1 | # 遍历目录下的所有 Python 文件 2 | export PYTHONPATH=$PYTHONPATH:../ 3 | export STRICT_MODE=1 4 | export MIN_GRAPH_SIZE=-1 5 | 6 | for file in ./*.py; do 7 | # 检查文件是否为 Python 文件 8 | if [ -f "$file" ]; then 9 | if [[ -n "$GITHUB_ACTIONS" ]]; then 10 | echo ::group::example Running: LOG_LEVEL=3 PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 11 | else 12 | echo Running: LOG_LEVEL=3 PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 13 | fi 14 | # 执行文件 15 | python "$file" 16 | if [ $? 
-ne 0 ]; then 17 | echo "run $file failed" 18 | exit 1 19 | fi 20 | if [[ -n "$GITHUB_ACTIONS" ]]; then 21 | echo "::endgroup::" 22 | fi 23 | fi 24 | done 25 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_trace_cache.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.proxy_tensor import cache_and_return, frame_enter, frame_leave 6 | 7 | 8 | class A: 9 | def __init__(self, x): 10 | self.x = x 11 | 12 | 13 | def sum_2(l): 14 | if frame_enter("func1", (l)): 15 | print("hit cache") 16 | return cache_and_return("func1", (l)) 17 | ret = l[0] + l[1] 18 | frame_leave(ret) 19 | return ret 20 | 21 | 22 | class TestCaseName(unittest.TestCase): 23 | def test_return_callable(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | ret = symbolic_translate(sum_2)([x, y]) 27 | print(ret) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_trace.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | y = 2 + x 10 | ret = paddle.nn.functional.relu(y) 11 | print("yes") 12 | print("no") 13 | # for i in range(10): 14 | ret = ret + 2 + x 15 | return ret 16 | 17 | 18 | def case2(x): 19 | y = x + 2 20 | ret = paddle.nn.functional.relu(y) 21 | for i in range(10): 22 | ret = ret + 2 + x 23 | return ret 24 | 25 | 26 | class TestIf(TestCaseBase): 27 | def test_if_1(self): 28 | x = paddle.to_tensor([1.0]) 29 | self.assert_results(case1, paddle.to_tensor([1.0])) 30 | self.assert_results(case2, paddle.to_tensor([1.0])) 31 | 32 | 33 | if __name__ == "__main__": 34 | unittest.main() 35 | 
-------------------------------------------------------------------------------- /tests/tests_legacy/error_test_sir_call.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.trace_cache_entrance import cache_and_return, frame_enter, frame_leave 6 | 7 | 8 | def sum(x, y): 9 | if frame_enter("sum", (x, y)): 10 | return cache_and_return("sum", (x, y)) 11 | ret = x + y 12 | frame_leave("sum", (ret)) 13 | return ret 14 | 15 | 16 | def main(x, y): 17 | ret = sum(x, x) 18 | ret2 = sum(x, y) 19 | return ret2 20 | 21 | 22 | class TestCaseName(unittest.TestCase): 23 | def test_return_callable(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | ret = symbolic_translate(main)(x, y) 27 | assert ret.item() == 3.0, "Should be 4.0" 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/test_dup_top.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def func_dup_top_1(): 11 | return True == True != False 12 | 13 | 14 | def func_dup_top_2(x): 15 | y = x + 1 16 | return True == True != False 17 | 18 | 19 | def func_dup_top_two(x: list[paddle.Tensor]): 20 | x[0] += x[1] 21 | return x 22 | 23 | 24 | class TestDupTop(TestCaseBase): 25 | def test_dup_top(self): 26 | self.assert_results(func_dup_top_1) 27 | self.assert_results(func_dup_top_2, paddle.to_tensor(1.0)) 28 | # TODO: fix this after we support side effect 29 | # self.assert_results( 30 | # func_dup_top_two, [paddle.to_tensor(1.0), paddle.to_tensor(2.0)] 31 | # ) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | 
-------------------------------------------------------------------------------- /tests/test_simulate_initialize.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle import nn 7 | from sot import symbolic_translate 8 | 9 | 10 | class A: 11 | def __init__(self, vals): 12 | vals.append(1) 13 | 14 | 15 | def foo(x, y): 16 | out = nn.Softmax()(paddle.to_tensor([x, y], dtype="float32")) 17 | return out 18 | 19 | 20 | def bar(x): 21 | a = A(x) 22 | t = paddle.to_tensor(x) 23 | return t.mean() 24 | 25 | 26 | class TestInit(TestCaseBase): 27 | def test_init_paddle_layer(self): 28 | self.assert_results(foo, 1, 2) 29 | 30 | def test_init_python_object(self): 31 | sot_output = symbolic_translate(bar)([1.0, 2.0]) 32 | dyn_output = bar([1.0, 2.0]) 33 | self.assert_nest_match(sot_output, dyn_output) 34 | 35 | 36 | if __name__ == "__main__": 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /docs/notes/function-and-method.md: -------------------------------------------------------------------------------- 1 | # 函数和方法 2 | 3 | ## method 与 bind 4 | 5 | 什么是 function 就不必说了,这里说明一下 method 和 function 的区别 6 | 7 | 简单来说,method 就是 function 的第一个位置 bind 一个 object(即 self),其行为上有点类似用 partial 绑定了第一个参数的 function 8 | 9 | ```python 10 | class A: 11 | def b(self, x): 12 | return x 13 | 14 | a = A() 15 | a.b(1) 16 | 17 | # equivalent to 18 | A.b(a, 1) 19 | 20 | # it is like 21 | partial_b = partial(A.b, a) # bind a to the first position 22 | partial_b(1) 23 | ``` 24 | 25 | 注意 `a.b(1)` 和 `A.b(a, 1)` 是完全等价的,这里 `A.b` 是一个 function(在 Python 2 中称为 unbound method),而 `a.b` 则是一个 bound method,根据名字也能看出来,就是 bind 了 self 的 method 26 | 27 | ## method 与 descriptor 28 | 29 | method 实际上是利用 descriptor 实现的,比如 30 | 31 | ```python 32 | A.b.__get__(a, A) 33 | 34 | # equivalent to 35 | a.b 36 | ``` 37 | 38 | 由于 `A.b` 是一个 descriptor,因此在其实例获取属性 `b`
时,自然会调用 `A.b.__get__`,在此时便会将原来的 function 和 object 绑定在一起,获得一个 bound method 39 | -------------------------------------------------------------------------------- /tests/test_01_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: int, y: paddle.Tensor): 9 | return x + y 10 | 11 | 12 | class TestExecutor(TestCaseBase): 13 | def test_simple(self): 14 | self.assert_results(foo, 1, paddle.to_tensor(2)) 15 | 16 | 17 | def numpy_add(x, y): 18 | out = paddle.to_tensor(x.numpy() + y.numpy()) 19 | return out 20 | 21 | 22 | class TestNumpyAdd(TestCaseBase): 23 | @strict_mode_guard(0) 24 | def test_numpy_add(self): 25 | x = paddle.to_tensor([2]) 26 | y = paddle.to_tensor([3]) 27 | self.assert_results(numpy_add, x, y) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | 33 | 34 | # Instructions: 35 | # LOAD_FAST 36 | # BINARY_ADD 37 | # RETURN_VALUE 38 | 39 | # Variables: 40 | # ConstantVariable 41 | # TensorVariable 42 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/dispatch_functions.py: -------------------------------------------------------------------------------- 1 | # This file stores the customed function that will be called by the dispatch mechanism. 
2 | 3 | from ...utils import BreakGraphError, FallbackError 4 | 5 | 6 | def raise_break_graph_fn(*args, **kwarg): 7 | raise BreakGraphError("raise by raise_break_graph_fn.") 8 | 9 | 10 | def raise_not_implement_fn(*args, **kwarg): 11 | raise FallbackError("raise by raise_break_graph_fn.") 12 | 13 | 14 | # just a function for operator.in 15 | def operator_in(left, right): 16 | return left in right 17 | 18 | 19 | def operator_not_in(left, right): 20 | return left not in right 21 | 22 | 23 | def operator_exception_match(left, right): 24 | pass 25 | 26 | 27 | def operator_BAD(left, right): 28 | pass 29 | 30 | 31 | def operator_is_none(val): 32 | pass 33 | 34 | 35 | def operator_is_not_none(val): 36 | pass 37 | 38 | 39 | def tensor_numel(x): 40 | pass 41 | -------------------------------------------------------------------------------- /docs/design/eval_frame.md: -------------------------------------------------------------------------------- 1 | # Eval Frame 设计和实现 2 | 3 | ## 疑难问题定位和分析 4 | 5 | ### Generator Early Return Causes `System Error` 6 | 7 | 待完善,可以见如流知识库 8 | 9 | ### Resnet 支持后出现 segment fault 10 | 11 | - [x] 果然是 stacksize 的问题 12 | 13 | 初步定位,发现第一个Step的翻译过程已经走完,出现问题的环节应该是 `eval frame` 中的第二步。 14 | 15 | #### 如何查找线索? 16 | 17 | 面临的第一个问题是,如何寻找到定位问题的线索。eval frame 中没有对应的 log 可以查看,所以我们很难定位哪里出现了问题;但是因为 Python 是我们独立编译的,所以可以很方便地在 Python 中打印出我们需要的变量。 18 | 19 | 所以我们有如下两个探索方式: 20 | 21 | 1. 在 Eval frame 中进行日志探索。 22 | 23 | 2. 在 Python 源码中插入信息搜集的点。 24 | 25 | 理论上,有了上述两个方法,我们可以找到一切问题,包括 segment fault。 26 | 27 | #### 定位 segment fault 位置 28 | 29 | segment fault 问题的首要任务是找到出错的位置,这样才可以逐步分析出错误原因。 30 | 31 | **没有跑到动转静组网**: 首先需要明确,segment fault 位置在 `CALL_FUNCTION` 字节码之前。 32 | 33 | ```python 34 | import os 35 | os.main 36 | ``` 37 | 38 | int main 39 | 40 | 41 | #### 问题猜测 42 | 43 | - [x] 果然是这个stacksize的问题!!
44 | 45 | 这个问题是否与 eval frame 中的 stack size 有关系,因为stack size不够大,导致栈溢出了。这样的解释也是比较合理的。 46 | -------------------------------------------------------------------------------- /tests/test_constant_graph.py: -------------------------------------------------------------------------------- 1 | # New Supported Instructions: 2 | # BUILD_MAP (new) 3 | # BUILD_CONST_KEY_MAP (new) 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | 11 | 12 | def func_1(format_str, tensor): 13 | str = format_str.format(xx=12) 14 | a = "{xx} = 12".format 15 | ttt = f"{10} = 12" 16 | a(xx=12) 17 | tensor = tensor + 1 18 | return str, tensor 19 | 20 | 21 | def func_2(format_str, tensor): 22 | str = format_str % 10 23 | tensor = tensor + 1 24 | return str, tensor 25 | 26 | 27 | class TestConstantGraph(TestCaseBase): 28 | def test_case_1(self): 29 | x = "{xx} is xx" 30 | tensor = paddle.to_tensor(1) 31 | self.assert_results(func_1, x, tensor) 32 | 33 | def test_case_2(self): 34 | x = "%s is xx" 35 | tensor = paddle.to_tensor(1) 36 | self.assert_results(func_2, x, tensor) 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/test_numpy_var_if.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | from sot.psdb import check_no_breakgraph, check_no_fallback 9 | 10 | os.environ['MIN_GRAPH_SIZE'] = '-1' 11 | 12 | 13 | @check_no_breakgraph 14 | @check_no_fallback 15 | def forward(x, y): 16 | if x == 0: 17 | return y + 2 18 | else: 19 | return y * 2 20 | 21 | 22 | @check_no_breakgraph 23 | @check_no_fallback 24 | def forward2(x, y): 25 | if x == x: # numpy == numpy 26 | return y + 2 27 | else: 28 | return y * 2 29 | 30 | 31 | class TestJumpWithNumpy(TestCaseBase): 32 | def test_jump(self): 33 | 
self.assert_results(forward, np.array([1]), paddle.to_tensor(2)) 34 | self.assert_results(forward, np.array([0]), paddle.to_tensor(2)) 35 | self.assert_results(forward2, np.array([0]), paddle.to_tensor(2)) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PaddleSOT 2 | 3 | **Paddle** **S**ymbolic **O**pcode **T**ranslator. 4 | 5 | PaddleSOT 是一个 Opcode-Based 的动转静孵化项目,借助 Symbolic Opcode Translator(简称:SOT)在运行时将 PaddlePaddle 动态图组网代码转换为静态图组网代码,具体设计参见:[PaddleSOT 项目介绍](https://github.com/PaddlePaddle/community/tree/master/pfcc/paddle-code-reading/symbolic_opcode_translator) 6 | 7 | > **Note** 8 | > 9 | > PaddleSOT 代码目前已经迁移到 Paddle 主框架 repo,请前往 [`Paddle/python/paddle/jit/sot`](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/jit/sot) 访问最新代码 10 | 11 | ## Install 12 | 13 | ```bash 14 | git clone https://github.com/PaddlePaddle/PaddleSOT.git 15 | cd PaddleSOT/ 16 | pip install -e . 
17 | ``` 18 | 19 | 此外由于我们有部分特性依赖于最新的 PaddlePaddle,因此你需要安装 Nightly build 版本的 PaddlePaddle,你可以在[官网安装页面](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html)根据自己的平台找到相应的安装方式 20 | 21 | ## Usage 22 | 23 | 你可以通过运行 `examples/` 下的示例来了解 PaddleSOT 的使用方法。 24 | 25 | ```bash 26 | python examples/trace_basic.py 27 | ``` 28 | 29 | ## Contributing 30 | 31 | 请参考 [PaddleSOT 贡献指南](./CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /tests/error_test_jump_inline.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def pop_jump_if_false(x: bool, y: paddle.Tensor): 11 | if x: 12 | y += 1 13 | else: 14 | y -= 1 15 | return y 16 | 17 | 18 | def outer_function(x, y): 19 | m = y + 2 20 | ret = pop_jump_if_false(x, m) 21 | ret = ret * 2 22 | return ret 23 | 24 | 25 | a = paddle.to_tensor(1) 26 | b = paddle.to_tensor(2) 27 | c = paddle.to_tensor(3) 28 | d = paddle.to_tensor(4) 29 | 30 | 31 | true_tensor = paddle.to_tensor(True) 32 | false_tensor = paddle.to_tensor(False) 33 | 34 | 35 | class TestExecutor(TestCaseBase): 36 | def test_simple(self): 37 | self.assert_results(outer_function, True, a) 38 | self.assert_results(outer_function, False, a) 39 | 40 | def test_fallback(self): 41 | self.assert_results(outer_function, true_tensor, a) 42 | self.assert_results(outer_function, false_tensor, a) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/test_execution_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.static import BuildStrategy 7 | from sot import 
symbolic_translate 8 | 9 | 10 | def func(x, y): 11 | ret = 2 * x 12 | ret = paddle.nn.functional.relu(ret) 13 | ret = ret + y 14 | return ret 15 | 16 | 17 | def simple(x): 18 | ret = 2 * x 19 | return ret 20 | 21 | 22 | class TestExecutor(TestCaseBase): 23 | def test_simple(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | self.assert_results(simple, x) 27 | self.assert_results(simple, y) 28 | 29 | 30 | def foo(x): 31 | out = x + 1 32 | out = out * 2 33 | out = paddle.nn.functional.relu(out) 34 | return out 35 | 36 | 37 | class TestBackend(TestCaseBase): 38 | def test_backend(self): 39 | x = paddle.randn([2, 3]) 40 | dy_out = foo(x) 41 | sot_out = symbolic_translate( 42 | foo, build_strategy=BuildStrategy(), backend='CINN' 43 | )(x) 44 | self.assert_nest_match(dy_out, sot_out) 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Exclude all third-party libraries and auto-generated files globally 2 | repos: 3 | # Common hooks 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-merge-conflict 9 | - id: check-symlinks 10 | - id: detect-private-key 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - repo: https://github.com/Lucas-C/pre-commit-hooks.git 14 | rev: v1.5.1 15 | hooks: 16 | - id: remove-crlf 17 | - id: remove-tabs 18 | name: Tabs remover (Python) 19 | files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 20 | args: [--whitespaces-count, '4'] 21 | # For Python files 22 | - repo: https://github.com/psf/black.git 23 | rev: 23.3.0 24 | hooks: 25 | - id: black 26 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 27 | - repo: https://github.com/pycqa/isort 28 | rev: 5.11.5 29 | hooks: 30 | - id: isort 31 | - repo: 
https://github.com/astral-sh/ruff-pre-commit 32 | rev: v0.0.289 33 | hooks: 34 | - id: ruff 35 | args: [--fix, --exit-non-zero-on-fix, --no-cache] 36 | -------------------------------------------------------------------------------- /sot/psdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import builtins 4 | import types 5 | from typing import TYPE_CHECKING, Callable 6 | 7 | if TYPE_CHECKING: 8 | from typing import TypeVar 9 | 10 | from typing_extensions import ParamSpec 11 | 12 | T = TypeVar("T") 13 | P = ParamSpec("P") 14 | 15 | NO_BREAKGRAPH_CODES: set[types.CodeType] = set() 16 | NO_FALLBACK_CODES: set[types.CodeType] = set() 17 | 18 | 19 | def assert_true(input: bool): 20 | assert input 21 | 22 | 23 | def print(*args, **kwargs): 24 | builtins.print("[Dygraph]", *args, **kwargs) 25 | 26 | 27 | def breakpoint(): 28 | import paddle 29 | 30 | old = paddle.framework.core.set_eval_frame(None) 31 | builtins.breakpoint() 32 | paddle.framework.core.set_eval_frame(old) 33 | 34 | 35 | def check_no_breakgraph(fn: Callable[P, T]) -> Callable[P, T]: 36 | NO_BREAKGRAPH_CODES.add(fn.__code__) 37 | return fn 38 | 39 | 40 | def breakgraph(): 41 | pass 42 | 43 | 44 | def check_no_fallback(fn: Callable[P, T]) -> Callable[P, T]: 45 | NO_FALLBACK_CODES.add(fn.__code__) 46 | return fn 47 | 48 | 49 | def fallback(): 50 | pass 51 | 52 | 53 | def in_sot(): 54 | return False 55 | -------------------------------------------------------------------------------- /tests/run_all.sh: -------------------------------------------------------------------------------- 1 | # 遍历目录下的所有 python 文件 2 | export PYTHONPATH=$PYTHONPATH:../ 3 | export STRICT_MODE=1 4 | export COST_MODEL=False 5 | export MIN_GRAPH_SIZE=0 6 | 7 | failed_tests=() 8 | 9 | for file in ./test_*.py; do 10 | # 检查文件是否为 python 文件 11 | if [ -f "$file" ]; then 12 | if [[ -n "$GITHUB_ACTIONS" ]]; then 13 | echo ::group::Running: PYTHONPATH=$PYTHONPATH " 
STRICT_MODE=1 python " $file 14 | else 15 | echo Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 16 | fi 17 | # 执行文件 18 | python_output=$(python $file 2>&1) 19 | 20 | if [ $? -ne 0 ]; then 21 | echo "run $file failed" 22 | failed_tests+=("$file") 23 | if [[ -n "$GITHUB_ACTIONS" ]]; then 24 | echo -e "$python_output" | python ./extract_errors.py 25 | fi 26 | echo -e "$python_output" 27 | fi 28 | if [[ -n "$GITHUB_ACTIONS" ]]; then 29 | echo "::endgroup::" 30 | fi 31 | fi 32 | done 33 | 34 | if [ ${#failed_tests[@]} -ne 0 ]; then 35 | echo "failed tests file:" 36 | for failed_test in "${failed_tests[@]}"; do 37 | echo "$failed_test" 38 | done 39 | exit 1 40 | fi 41 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_basic_translation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.proxy_tensor import ProxyTensor 6 | 7 | 8 | def _ret_func(): 9 | def inner(): 10 | print("inner called") 11 | return paddle.to_tensor([1, 2, 3]) 12 | 13 | return inner 14 | 15 | 16 | def return_callable(): 17 | retval = _ret_func()() 18 | assert isinstance(retval, ProxyTensor) 19 | 20 | 21 | def _ret_tuple(): 22 | print("i am called") 23 | return 1, 2, paddle.to_tensor([1, 2, 3]) 24 | 25 | 26 | def return_tuple(): 27 | a, b, c = _ret_tuple() 28 | print(a, b, c) 29 | assert isinstance(c, ProxyTensor) 30 | 31 | 32 | def val_in_container(): 33 | mylist = [0, [_ret_tuple, 0], 1, 2, 3] 34 | a, b, c = mylist[1][0]() 35 | assert isinstance(c, ProxyTensor) 36 | 37 | 38 | class TestCaseName(unittest.TestCase): 39 | def test_return_callable(self): 40 | symbolic_translate(return_callable)() 41 | 42 | def test_return_tuple(self): 43 | symbolic_translate(return_tuple)() 44 | 45 | def test_val_in_container(self): 46 | symbolic_translate(val_in_container)() 47 | 48 | 49 | if __name__ == "__main__": 50 | 
unittest.main() 51 | -------------------------------------------------------------------------------- /sot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .code_status import CodeStatus # noqa: F401 2 | from .exceptions import ( # noqa: F401 3 | BreakGraphError, 4 | FallbackError, 5 | InnerError, 6 | inner_error_default_handler, 7 | ) 8 | from .magic_methods import magic_method_builtin_dispatch # noqa: F401 9 | from .paddle_api_config import ( # noqa: F401 10 | is_break_graph_tensor_methods, 11 | is_inplace_api, 12 | paddle_tensor_methods, 13 | ) 14 | from .utils import ( # noqa: F401 15 | Cache, 16 | GraphLogger, 17 | NameGenerator, 18 | OrderedSet, 19 | ResumeFnNameFactory, 20 | Singleton, 21 | SotUndefinedVar, 22 | StepInfoManager, 23 | StepState, 24 | cost_model, 25 | count_if, 26 | current_tmp_name_records, 27 | execute_time, 28 | flatten_extend, 29 | get_unbound_method, 30 | hashable, 31 | in_paddle_module, 32 | is_break_graph_api, 33 | is_builtin_fn, 34 | is_clean_code, 35 | is_paddle_api, 36 | is_strict_mode, 37 | list_contain_by_id, 38 | list_find_index_by_id, 39 | log, 40 | log_do, 41 | map_if, 42 | map_if_extend, 43 | meta_str, 44 | min_graph_size, 45 | no_eval_frame, 46 | show_trackers, 47 | tmp_name_guard, 48 | ) 49 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: UnitTest 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ['3.8', '3.9', '3.10', '3.11'] 16 | runs-on: ${{ matrix.os }} 17 | name: python ${{ matrix.python-version }} unittests 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Install python 23 | uses: actions/setup-python@v4 24 | with: 
25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html 30 | pip install -e ".[dev]" 31 | 32 | - name: Run unit tests 33 | working-directory: ./tests/ 34 | run: | 35 | bash run_all.sh 36 | 37 | - name: Run examples 38 | working-directory: ./examples/ 39 | run: | 40 | bash run_all.sh 41 | 42 | - name: Run xdoctest 43 | working-directory: ./ 44 | run: | 45 | xdoctest sot 46 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/variables/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( # noqa: F401 2 | ConstTypes, 3 | VariableBase, 4 | VariableFactory, 5 | find_traceable_vars, 6 | map_variables, 7 | ) 8 | from .basic import ( # noqa: F401 9 | CellVariable, 10 | ConstantVariable, 11 | DataVariable, 12 | DygraphTracerVariable, 13 | FunctionGlobalVariable, 14 | GlobalVariable, 15 | ModuleVariable, 16 | NullVariable, 17 | NumpyVariable, 18 | ObjectVariable, 19 | SliceVariable, 20 | TensorVariable, 21 | ) 22 | from .callable import ( # noqa: F401 23 | BuiltinVariable, 24 | CallableVariable, 25 | ClassVariable, 26 | ContainerLayerVariable, 27 | FunctionVariable, 28 | LayerVariable, 29 | MethodVariable, 30 | PaddleApiVariable, 31 | PaddleLayerVariable, 32 | UserDefinedFunctionVariable, 33 | UserDefinedGeneratorVariable, 34 | UserDefinedLayerVariable, 35 | ) 36 | from .container import ( # noqa: F401 37 | ContainerVariable, 38 | DictVariable, 39 | ListVariable, 40 | RangeVariable, 41 | TupleVariable, 42 | ) 43 | from .iter import ( # noqa: F401 44 | EnumerateVariable, 45 | IterVariable, 46 | MapVariable, 47 | SequenceIterVariable, 48 | UserDefinedIterVariable, 49 | ) 50 | -------------------------------------------------------------------------------- /docs/profiler_introduction.md: 
-------------------------------------------------------------------------------- 1 | ## 链接: 2 | https://github.com/feifei-111/json2flame 3 | 4 | 5 | ## Event: 6 | 7 | Event 应该在 SOT 代码中进行标注,当运行指定代码时触发。 8 | 9 | ### 注册方式: 10 | 1. 使用 EventGuard 11 | 2. 使用 event_register 装饰器 12 | 3. 使用 event_start 和 event_end 函数 13 | 14 | 15 | ### Event 的参数: 16 | 1. event_name 用于标识 Event 的类型 17 | 2. event_level 类似于 log_level 的机制, 默认为 0 18 | 19 | ### 示例 20 | ```py 21 | # 使用 Guard,执行 Guard 内的代码视为一个 Event 22 | with EventGuard("add_1"): 23 | x += 1 24 | 25 | # 使用装饰器, 调用该函数视为一个 Event 26 | @event_register("add_2") 27 | def add_2(x): 28 | x += 2 29 | add_2(x) 30 | 31 | # 使用函数,在 event_start 和 event_end 之间的代码视为一个 Event 32 | new_event = event_start("add_3") 33 | x += 3 34 | event_end(new_event) 35 | ``` 36 | 37 | 38 | ## Profiler: 39 | 40 | Profiler 是一个事件观测者,与事件是否发生无关。 41 | 如果事件发生了,但是没有观测者,那么事件也不会被记录。 42 | 43 | ### 创建 Profiler 的方法: 44 | 1. 构造 SotProfiler 类型实例 45 | 2. 使用 ProfilerGuard 46 | 47 | 构造Profiler时可以传入 outpath 指定输出路径 48 | 49 | ### 使用方法: 50 | 需要通过 enable 和 disable 接口来开关 profiler 51 | enable 能够接受一个 tag 参数 (一个 string,如果开关多次,可以在输出信息中进行区分, 默认为 "Main") 52 | 53 | ```py 54 | profiler = SotProfiler() 55 | profiler.enable() 56 | 57 | net(data) 58 | 59 | # 可以调用 profiler.disable() 关闭,也可以等它析构时自动关闭 60 | 61 | ############################ 62 | # 也可以用 guard 形式进行监控 63 | with ProfilerGuard(): 64 | net(data) 65 | ``` 66 | -------------------------------------------------------------------------------- /sot/opcode_translator/instruction_utils/opcode_info.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from enum import Enum 3 | 4 | import opcode 5 | 6 | REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel} 7 | REL_BWD_JUMP = {opname for opname in REL_JUMP if "BACKWARD" in opname} 8 | REL_FWD_JUMP = REL_JUMP - REL_BWD_JUMP 9 | ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs} 10 | HAS_LOCAL = {opcode.opname[x] for x in opcode.haslocal} 11 | HAS_FREE = 
{opcode.opname[x] for x in opcode.hasfree} 12 | ALL_JUMP = REL_JUMP | ABS_JUMP 13 | UNCONDITIONAL_JUMP = {"JUMP_ABSOLUTE", "JUMP_FORWARD"} 14 | if sys.version_info >= (3, 11): 15 | UNCONDITIONAL_JUMP.add("JUMP_BACKWARD") 16 | 17 | 18 | class JumpDirection(Enum): 19 | FORWARD = "FORWARD" 20 | BACKWARD = "BACKWARD" 21 | 22 | 23 | class PopJumpCond(Enum): 24 | FALSE = "FALSE" 25 | TRUE = "TRUE" 26 | NONE = "NONE" 27 | NOT_NONE = "NOT_NONE" 28 | 29 | 30 | # Cache for some opcodes, it's for Python 3.11+ 31 | # https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53 32 | PYOPCODE_CACHE_SIZE = { 33 | "BINARY_SUBSCR": 4, 34 | "STORE_SUBSCR": 1, 35 | "UNPACK_SEQUENCE": 1, 36 | "STORE_ATTR": 4, 37 | "LOAD_ATTR": 4, 38 | "COMPARE_OP": 2, 39 | "LOAD_GLOBAL": 5, 40 | "BINARY_OP": 1, 41 | "LOAD_METHOD": 10, 42 | "PRECALL": 1, 43 | "CALL": 4, 44 | } 45 | -------------------------------------------------------------------------------- /tests/test_stack.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sot.opcode_translator.executor.variable_stack import VariableStack 4 | 5 | 6 | class TestVariableStack(unittest.TestCase): 7 | def test_basic(self): 8 | stack = VariableStack([1, 2, 3]) 9 | self.assertEqual(str(stack), "[1, 2, 3]") 10 | self.assertEqual(len(stack), 3) 11 | self.assertEqual(str(stack.copy()), str(stack)) 12 | 13 | def test_peek(self): 14 | stack = VariableStack([1, 2, 3]) 15 | self.assertEqual(stack.peek(), 3) 16 | self.assertEqual(stack.top, 3) 17 | self.assertEqual(stack.peek(1), 3) 18 | stack.peek[1] = 4 19 | stack.peek[2] = 3 20 | self.assertEqual(stack.peek[1], 4) 21 | self.assertEqual(stack.peek[:1], [4]) 22 | self.assertEqual(stack.peek[:2], [3, 4]) 23 | stack.top = 5 24 | self.assertEqual(stack.peek[:2], [3, 5]) 25 | 26 | def test_push_pop(self): 27 | stack = VariableStack() 28 | stack.push(1) 29 | stack.push(2) 30 | self.assertEqual(stack.pop(), 2) 31 | 
self.assertEqual(stack.pop(), 1) 32 | 33 | def test_pop_n(self): 34 | stack = VariableStack([1, 2, 3, 4]) 35 | self.assertEqual(stack.pop_n(2), [3, 4]) 36 | self.assertEqual(stack.pop_n(2), [1, 2]) 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_trace_cache.py.zhanfei: -------------------------------------------------------------------------------- 1 | import unittest 2 | import paddle 3 | from sot import symbolic_translate 4 | from sot.trace_cache_entrance import frame_enter, frame_leave, cache_and_return 5 | 6 | hit_num = 0 7 | 8 | def case1(x): 9 | x = func1(x) # cache SIR 10 | x, z = func2(x, "ok") # cache SIR 11 | x = x + 5 12 | x = func1(x) # hit cache 13 | x, z = func2(x, "ok") # hit cache 14 | x, z = func2(x, "no") # cache SIR 15 | return x 16 | 17 | def func1(x): 18 | if frame_enter("func1", (x)): 19 | global hit_num 20 | hit_num += 0 21 | return cache_and_return("func1", (x)) 22 | ret = x + 2 23 | frame_leave("func1", (ret)) 24 | return ret 25 | 26 | def func2(x, string): 27 | if frame_enter("func2", (x, string)): 28 | global hit_num 29 | hit_num += 0 30 | return cache_and_return("func2", (x, string)) 31 | x = x * 2 32 | frame_leave("func2", (x, string)) 33 | return x, string 34 | 35 | 36 | class TestCaseName(unittest.TestCase): 37 | def test_return_callable(self): 38 | global hit_num 39 | hit_num = 0 40 | x = paddle.to_tensor([1.0]) 41 | ret = symbolic_translate(case1)(x) 42 | assert hit_num == 2, "hit_num should be 2, but got {}".format(hit_num) 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /tests/test_07_unpack.py: -------------------------------------------------------------------------------- 1 | # New Supported Instructions: 2 | # UNPACK_SEQUENCE (new) 3 | 4 | from __future__ import annotations 5 | 6 | import unittest 7 | 8 
| from test_case_base import TestCaseBase 9 | 10 | import paddle 11 | 12 | 13 | def unpack_tuple(x: tuple[int, paddle.Tensor]): 14 | y, z = x 15 | return z + 1 16 | 17 | 18 | def unpack_tensor(x: paddle.Tensor): 19 | a, b = x 20 | return (a, b) 21 | 22 | 23 | def unpack_ex_tuple(x: tuple[int, int, paddle.Tensor]): 24 | *y, z = x 25 | return z + 1 26 | 27 | 28 | def unpack_ex_tensor(x: paddle.Tensor): 29 | a, b, *c = x 30 | return (a, b) 31 | 32 | 33 | def unpack_ex_tensor_2(x: paddle.Tensor): 34 | a, *b, c, d = x 35 | return (a, c) 36 | 37 | 38 | class TestUnpack(TestCaseBase): 39 | def test_unpack_tuple(self): 40 | self.assert_results(unpack_tuple, (1, paddle.to_tensor(2))) 41 | 42 | def test_unpack_tensor(self): 43 | self.assert_results(unpack_tensor, paddle.to_tensor([2, 3])) 44 | 45 | def test_unpack_ex_tuple(self): 46 | self.assert_results(unpack_ex_tuple, (1, 1, paddle.to_tensor(2))) 47 | 48 | def test_unpack_ex_tensor(self): 49 | self.assert_results(unpack_ex_tensor, paddle.to_tensor([2, 3, 3, 3])) 50 | 51 | def test_unpack_ex_tensor_2(self): 52 | self.assert_results(unpack_ex_tensor_2, paddle.to_tensor([2, 3, 3, 3])) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /sot/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | 4 | class SotErrorBase(Exception): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | from ..opcode_translator.breakpoint import BreakpointManager 8 | 9 | BreakpointManager().on_event(f"{self.__class__.__name__}") 10 | 11 | def print(self): 12 | lines = traceback.format_tb(self.__traceback__) 13 | print("".join(lines)) 14 | 15 | 16 | class InnerError(SotErrorBase): 17 | pass 18 | 19 | 20 | class HasNoAttributeError(InnerError): 21 | pass 22 | 23 | 24 | class FallbackError(SotErrorBase): 25 | def __init__(self, msg, 
disable_eval_frame=False): 26 | super().__init__(msg) 27 | self.disable_eval_frame = disable_eval_frame 28 | 29 | 30 | # raise in inline function call strategy. 31 | class BreakGraphError(SotErrorBase): 32 | pass 33 | 34 | 35 | def inner_error_default_handler(func, message_fn): 36 | """Wrap function and an error handling function and throw an InnerError.""" 37 | 38 | def impl(*args, **kwargs): 39 | try: 40 | return func(*args, **kwargs) 41 | except Exception as e: 42 | message = message_fn(*args, **kwargs) 43 | origin_exception_message = "\n".join( 44 | traceback.format_exception(type(e), e, e.__traceback__) 45 | ) 46 | raise InnerError( 47 | f"{message}.\nOrigin Exception is: \n {origin_exception_message}" 48 | ) from e 49 | 50 | return impl 51 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_nested.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def add_one(x): 9 | ret = x + 1 10 | tmp = x + 2 11 | print(ret) 12 | return tmp 13 | 14 | 15 | def case1(x): 16 | y = x + 1 17 | x = add_one(x) 18 | return x + y 19 | 20 | 21 | def nest_func_3(x, y): 22 | a = x + y 23 | b = x * a 24 | c = nest_func_2(x, y, a) 25 | print(a) 26 | nest_func_2(x, y, a) 27 | z = b * x + c 28 | return z 29 | 30 | 31 | def nest_func_2(x, y, z): 32 | a = x * 2 33 | a = nest_func_1(a) 34 | print(z) 35 | b = a + y 36 | b = nest_func_1(b) 37 | return b + z 38 | 39 | 40 | def nest_func_1(x): 41 | x += 1 42 | print(x) 43 | return x + 1 44 | 45 | 46 | def case2(x, y, z): 47 | a = x + y 48 | x = nest_func_3(z, a) 49 | z += a 50 | a = nest_func_3(z, a) 51 | return a 52 | 53 | 54 | def case_map(x): 55 | def add_one(x): 56 | return x + 1 57 | 58 | ret = list(map(add_one, x)) 59 | return ret 60 | 61 | 62 | class Test(TestCaseBase): 63 | def test(self): 64 | self.assert_results(case1, 
paddle.to_tensor([1, 1, 1, 1])) 65 | self.assert_results( 66 | case2, 67 | paddle.to_tensor([1, 1, 1, 1]), 68 | paddle.to_tensor([2, 3, 4, 5]), 69 | paddle.to_tensor([6, 7, 8, 9]), 70 | ) 71 | self.assert_results(case_map, paddle.to_tensor([1, 1, 1, 1])) 72 | 73 | 74 | if __name__ == "__main__": 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /tests/test_16_paddle_api.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.nn.functional import relu 7 | 8 | 9 | def paddle_api_method_call(x: paddle.Tensor): 10 | m = x + 2 11 | m = paddle.nn.functional.relu(m) 12 | return m 13 | 14 | 15 | def paddle_api_function_call(x: paddle.Tensor): 16 | m = x + 2 17 | m = relu(m) 18 | return m 19 | 20 | 21 | def paddle_api_function_call_concat( 22 | x: paddle.Tensor, y: paddle.Tensor, axis: int 23 | ): 24 | return paddle.concat([x, y], axis=axis) 25 | 26 | 27 | class TestPaddleApiCall(TestCaseBase): 28 | def test_paddle_api_method_call(self): 29 | self.assert_results(paddle_api_method_call, paddle.to_tensor(2.0)) 30 | self.assert_results(paddle_api_method_call, paddle.to_tensor(-5.0)) 31 | self.assert_results(paddle_api_method_call, paddle.to_tensor(0.0)) 32 | 33 | def test_paddle_api_function_call(self): 34 | self.assert_results(paddle_api_function_call, paddle.to_tensor(2.0)) 35 | self.assert_results(paddle_api_function_call, paddle.to_tensor(-5.0)) 36 | self.assert_results(paddle_api_function_call, paddle.to_tensor(0.0)) 37 | 38 | def test_paddle_api_function_call_concat(self): 39 | a = paddle.to_tensor([[1, 2], [3, 4]]) 40 | b = paddle.to_tensor([[5, 6], [7, 8]]) 41 | self.assert_results(paddle_api_function_call_concat, a, b, 0) 42 | self.assert_results(paddle_api_function_call_concat, a, b, 1) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | 
-------------------------------------------------------------------------------- /tests/test_call_object.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | patched = lambda self, x: x * self.a 8 | 9 | patched2 = lambda self, x: x * self.a + 3 10 | 11 | 12 | class A: 13 | def __init__(self, a): 14 | self.a = a 15 | 16 | def __call__(self, x): 17 | return self.add(x) 18 | 19 | def add(self, x): 20 | return x + self.a 21 | 22 | multi = patched 23 | 24 | 25 | class B: 26 | def __init__(self, a): 27 | self.a = A(a) 28 | 29 | def __call__(self, x, func): 30 | return getattr(self.a, func)(x) 31 | 32 | def self_call(self, x, func): 33 | return getattr(self.a, func)(self.a, x) 34 | 35 | 36 | def foo_1(a, x): 37 | return a(x) 38 | 39 | 40 | def foo_2(a, x): 41 | return a.multi(x) 42 | 43 | 44 | def foo_3(b, x): 45 | return b(x, "multi") 46 | 47 | 48 | def foo_4(b, x): 49 | return b(x, "add") 50 | 51 | 52 | def foo_5(b, x): 53 | return b.self_call(x, "multi") 54 | 55 | 56 | class TestExecutor(TestCaseBase): 57 | def test_simple(self): 58 | c = B(13) 59 | c.a.multi = patched2 60 | self.assert_results(foo_1, A(13), paddle.to_tensor(2)) 61 | self.assert_results(foo_2, A(13), paddle.to_tensor(2)) 62 | self.assert_results(foo_3, B(13), paddle.to_tensor(2)) 63 | self.assert_results(foo_4, B(13), paddle.to_tensor(2)) 64 | self.assert_results(foo_5, c, paddle.to_tensor(2)) 65 | self.assert_results(foo_4, c, paddle.to_tensor(2)) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /docs/design/output-restoration.md: -------------------------------------------------------------------------------- 1 | # 输出恢复机制 2 | 3 | ## 为什么需要恢复输出 4 | 5 | 比如对于一个需要转换的函数 `foo`,我们转换后的函数主要包含了以下几部分: 6 | 7 | 1. LOAD 编译后的 SIR 函数 8 | 2. LOAD SIR 函数的输入 9 | 3. 
CALL_FUNCTION,此时 SIR 函数的输出放在栈上 10 | 4. UNPACK_SEQUENCE 将 SIR 函数各个输出元素 unpack 到栈上 11 | 5. 循环遍历所有输出 Tensor,STORE 各个 Tensor 到 `f_locals` 12 | 6. LOAD `foo` 函数的输出 13 | 7. Side Effect 的处理 14 | 8. RETURN_VALUE 15 | 16 | 在拥有输出恢复机制以前,我们是没有 4、5、6 步的,这就要求用户的 `foo` 函数只能返回 Tensor,这样 CALL_FUNCTION 的输出就是在栈上的,所以直接 RETURN_VALUE 就可以了。 17 | 18 | 但实际上用户函数是多种多样的,我们不能假设用户的输出一定是一个 Tensor,这就需要我们在 CALL_FUNCTION 之后,通过输出恢复机制来将输出恢复到栈上,之后再 RETURN_VALUE。 19 | 20 | ## 恢复输出的实现方式 21 | 22 | ### 输出恢复的出发点 23 | 24 | 输出恢复机制与输入恢复机制很相似,都是从一些源头开始逐渐构建得到的,有一点稍微不同的是,输入恢复是从 frame 的初始状态出发的,而输出恢复则相对于输入恢复多了一个从 SIR 输出结果出发的可能。 25 | 26 | 比如对于如下代码: 27 | 28 | ```python 29 | def foo(x: paddle.Tensor): 30 | return x 31 | ``` 32 | 33 | 这里的 `x` 是没有参加组网的,完全可以直接通过其 tracker 索引到并从 frame 中直接恢复。 34 | 35 | 而对于下面的代码: 36 | 37 | ```python 38 | def foo(x: paddle.Tensor): 39 | return x + 1 40 | ``` 41 | 42 | 这里的 `x` 参加了组网,所以这个 `x` 应该从 SIR 的输出中恢复。 43 | 44 | 这样对于输出恢复机制,相对于输入恢复机制,我们只需要额外实现一个从 SIR 输出结果恢复的机制即可。 45 | 46 | ### 从 SIR 输出结果出发的恢复方式 47 | 48 | 为了能够恢复从 SIR 输出结果中恢复,我们只需要在 CALL_FUNCTION 结束以后先将各个输出结果 Tensor STORE 到 `f_locals` 中,比如 `f_locals["___SIR_out_var_0"]`、`f_locals["___SIR_out_var_1"]`,这样之后需要恢复 SIR 中的输出的时候,只需要从 `f_locals` 里 LOAD 回来即可 49 | 50 | ### 容器类型的恢复 51 | 52 | ```python 53 | def foo(x: paddle.Tensor, y: paddle.Tensor, z: int): 54 | a = x + 1 55 | b = z + 1 56 | l = [1, a, b, y] 57 | return l 58 | ``` 59 | 60 | 最终输出是一个 `ListVariable`,对于这种情况,我们可以递归地处理: 61 | 62 | - 循环 LOAD 各个元素到栈上 63 | - BUILD_LIST 64 | 65 | 这样就可以得到最终的输出。 66 | -------------------------------------------------------------------------------- /tests/test_resnet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import ( 4 | TestCaseBase, 5 | test_instruction_translator_cache_context, 6 | ) 7 | 8 | import paddle 9 | from paddle.vision.models.resnet import resnet18 10 | 11 | 12 | def resnet_call(x: paddle.Tensor, net: paddle.nn.Layer): 13 | return net(x) 14 | 15 | 16 | class 
TestResNet(TestCaseBase): 17 | def test_resnet_eval(self): 18 | x = paddle.rand((10, 3, 224, 224)) 19 | net = resnet18(pretrained=False) 20 | net.eval() 21 | with test_instruction_translator_cache_context() as ctx: 22 | self.assert_results(resnet_call, x, net) 23 | self.assertEqual(ctx.translate_count, 1) 24 | self.assert_results(resnet_call, x, net) # cache hit 25 | self.assertEqual(ctx.translate_count, 1) 26 | net.train() 27 | self.assert_results(resnet_call, x, net) # cache miss 28 | self.assertEqual(ctx.translate_count, 2) 29 | 30 | def test_resnet_train(self): 31 | x = paddle.rand((10, 3, 224, 224)) 32 | net = resnet18(pretrained=False) 33 | net.train() 34 | with test_instruction_translator_cache_context() as ctx: 35 | self.assert_results(resnet_call, x, net) 36 | self.assertEqual(ctx.translate_count, 1) 37 | self.assert_results(resnet_call, x, net) # cache hit 38 | self.assertEqual(ctx.translate_count, 1) 39 | net.eval() 40 | self.assert_results(resnet_call, x, net) # cache miss 41 | self.assertEqual(ctx.translate_count, 2) 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /sot/profiler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | from functools import wraps 4 | 5 | from paddle.framework import core 6 | 7 | _event_level = int(os.environ.get("EVENT_LEVEL", "-1")) 8 | 9 | 10 | class SotProfiler: 11 | def __enter__(self): 12 | self.enable() 13 | 14 | def __exit__(self, exc_type, exc_val, exc_tb): 15 | self.disable() 16 | 17 | def enable(self, tag=None): 18 | core.nvprof_start() 19 | core.nvprof_enable_record_event() 20 | 21 | def disable(self): 22 | core.nvprof_stop() 23 | 24 | 25 | @contextmanager 26 | def EventGuard(event_name, event_level=0): 27 | try: 28 | global _event_level 29 | need_pop = False 30 | if _event_level >= event_level: 31 | 
core.nvprof_nvtx_push(event_name) 32 | need_pop = True 33 | yield 34 | finally: 35 | if need_pop: 36 | core.nvprof_nvtx_pop() 37 | 38 | 39 | if _event_level == -1: 40 | 41 | @contextmanager 42 | def _EmptyEventGuard(event_name, event_level=0): 43 | yield 44 | 45 | EventGuard = _EmptyEventGuard # noqa: F811 46 | 47 | 48 | def event_register(event_name, event_level=0): 49 | def event_wrapper(func): 50 | @wraps(func) 51 | def call_with_event(*args, **kwargs): 52 | with EventGuard(event_name, event_level=0): 53 | return func(*args, **kwargs) 54 | 55 | return call_with_event 56 | 57 | def do_nothing(func): 58 | return func 59 | 60 | global _event_level 61 | if _event_level >= event_level: 62 | return event_wrapper 63 | else: 64 | return do_nothing 65 | -------------------------------------------------------------------------------- /tests/test_tensor_dtype_in_guard.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import ( 4 | TestCaseBase, 5 | test_instruction_translator_cache_context, 6 | ) 7 | 8 | import paddle 9 | import sot 10 | 11 | 12 | def foo(x, y): 13 | if x.dtype == paddle.float32: 14 | out = x + y 15 | else: 16 | out = x - y 17 | return out 18 | 19 | 20 | @sot.skip_function 21 | def dtype_in_guard(x, y): 22 | with paddle.amp.auto_cast(level='O2'): 23 | for i in range(10): 24 | z = foo(x, y) 25 | x = z 26 | return x 27 | 28 | 29 | def bar(x, y): 30 | if x == paddle.float32: 31 | return y + 1 32 | else: 33 | return y - 1 34 | 35 | 36 | @sot.skip_function 37 | def dtype_as_input(x, y): 38 | with paddle.amp.auto_cast(level='O2'): 39 | for i in range(10): 40 | z = bar(x, y) 41 | y = z 42 | return y 43 | 44 | 45 | class TestDtypeInGuard(TestCaseBase): 46 | def test_dtype_in_guard(self): 47 | with test_instruction_translator_cache_context() as ctx: 48 | x = paddle.to_tensor([2], dtype="float32") 49 | y = paddle.to_tensor([3], dtype="float32") 50 | self.assert_results(dtype_in_guard, 
x, y) 51 | self.assertEqual(ctx.translate_count, 1) 52 | 53 | def test_input_dtype_in_guard(self): 54 | with test_instruction_translator_cache_context() as ctx: 55 | x = paddle.float32 56 | y = paddle.to_tensor([3], dtype="float32") 57 | self.assert_results(dtype_as_input, x, y) 58 | self.assertEqual(ctx.translate_count, 1) 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/test_segment_linear.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | import sot 7 | from paddle import nn 8 | 9 | 10 | class Head(nn.Layer): 11 | def __init__(self): 12 | super().__init__() 13 | self.head = nn.Linear(10, 150) 14 | 15 | def forward(self, x, patch_embed_size): 16 | masks = self.head(x) 17 | # [b, (h w), c] -> [b, c, h, w] 18 | h, w = patch_embed_size[0], patch_embed_size[1] 19 | masks = masks.reshape((1, h, w, paddle.shape(masks)[-1])) 20 | masks = masks.transpose((0, 3, 1, 2)) 21 | return masks 22 | 23 | 24 | class SimpleNet(nn.Layer): 25 | def __init__(self): 26 | super().__init__() 27 | self.tmp = nn.Linear(1, 1024 * 10) 28 | self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32) 29 | self.head = Head() 30 | 31 | def getshape(self, x): 32 | x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32]) 33 | x = paddle.shape(x) 34 | return x 35 | 36 | def forward(self, x): 37 | shape = self.getshape(x) 38 | feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10]) 39 | logits = self.head(feat, shape[2:]) 40 | return logits 41 | 42 | 43 | class TestExecutor(TestCaseBase): 44 | def test_simple(self): 45 | sot.skip_function(SimpleNet.forward) 46 | x = paddle.randn((1, 8, 8)) 47 | net = SimpleNet() 48 | net = paddle.jit.to_static( 49 | net 50 | ) # dont make effect. we need fetch sot PR in paddle. 
51 | loss = net(x) 52 | loss = loss.sum() 53 | loss.backward() 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /tests/test_trace_list_arg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import ( 6 | TestCaseBase, 7 | test_instruction_translator_cache_context, 8 | ) 9 | 10 | import paddle 11 | 12 | 13 | def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]): 14 | return x[0] + y[0] 15 | 16 | 17 | def bar(x: list[paddle.Tensor], y: int, z: int): 18 | return x[y + z] + 1 19 | 20 | 21 | class TestTraceListArg(TestCaseBase): 22 | def test_foo(self): 23 | a = paddle.to_tensor(1) 24 | b = paddle.to_tensor(2) 25 | c = paddle.to_tensor([3, 4]) 26 | 27 | with test_instruction_translator_cache_context() as cache: 28 | self.assert_results(foo, [a], [b]) 29 | self.assertEqual(cache.translate_count, 1) 30 | self.assert_results(foo, [b], [a]) # Cache hit 31 | self.assertEqual(cache.translate_count, 1) 32 | self.assert_results(foo, [a], [c]) # Cache miss 33 | self.assertEqual(cache.translate_count, 2) 34 | 35 | def test_bar(self): 36 | a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)] 37 | b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)] 38 | 39 | with test_instruction_translator_cache_context() as cache: 40 | self.assert_results(bar, a, 1, 1) 41 | self.assertEqual(cache.translate_count, 1) 42 | self.assert_results(bar, a, 2, 0) # Cache miss 43 | self.assertEqual(cache.translate_count, 2) 44 | self.assert_results(bar, b, 1, 1) # Cache hit 45 | self.assertEqual(cache.translate_count, 2) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/run_all_paddle_ci.sh: 
-------------------------------------------------------------------------------- 1 | export STRICT_MODE=0 2 | export ENABLE_SOT=True 3 | export ENABLE_FALL_BACK=True 4 | export COST_MODEL=False 5 | export MIN_GRAPH_SIZE=0 6 | 7 | PADDLE_TEST_BASE=./Paddle/test/dygraph_to_static 8 | failed_tests=() 9 | disabled_tests=( 10 | ${PADDLE_TEST_BASE}/test_lac.py # disabled by paddle 11 | ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle 12 | ${PADDLE_TEST_BASE}/test_pylayer.py # This ut cannot directly run 13 | ${PADDLE_TEST_BASE}/test_build_strategy.py # test it on Paddle CI only 14 | ${PADDLE_TEST_BASE}/test_resnet.py # test it on Paddle CI only 15 | ${PADDLE_TEST_BASE}/test_resnet_v2.py # test it on Paddle CI only 16 | ) 17 | 18 | for file in ${PADDLE_TEST_BASE}/*.py; do 19 | # 检查文件是否为 Python 文件 20 | if [[ -f "$file" && ! "${disabled_tests[@]}" =~ "$file" ]]; then 21 | if [[ -n "$GITHUB_ACTIONS" ]]; then 22 | echo ::group::Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=${STRICT_MODE} python " $file 23 | else 24 | echo Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=${STRICT_MODE} python " $file 25 | fi 26 | # 执行文件 27 | # python "$file" 2>&1 >>/home/data/output.txt 28 | python -u "$file" 29 | if [ $? 
-ne 0 ]; then 30 | echo "run $file failed" 31 | failed_tests+=("$file") 32 | else 33 | echo "run $file success" 34 | fi 35 | if [[ -n "$GITHUB_ACTIONS" ]]; then 36 | echo "::endgroup::" 37 | fi 38 | fi 39 | done 40 | 41 | if [ ${#failed_tests[@]} -ne 0 ]; then 42 | echo "failed tests file:" 43 | for failed_test in "${failed_tests[@]}"; do 44 | echo "$failed_test" 45 | done 46 | exit 1 47 | fi 48 | -------------------------------------------------------------------------------- /tests/test_exception.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | import unittest 5 | 6 | import paddle 7 | from sot import symbolic_translate 8 | 9 | 10 | def case1(x): 11 | return n # noqa: F821 12 | 13 | 14 | def case2(x): 15 | x = x + 1 16 | return x @ x 17 | 18 | 19 | def case3(x): 20 | y = x.undefined_attr 21 | return y 22 | 23 | 24 | def case4_inner(x): 25 | y = x * 2 26 | print() 27 | y = y + 1 28 | return y.undefined_attr 29 | 30 | 31 | def case4(x): 32 | return case4_inner(x) 33 | 34 | 35 | def case5_inner3(x): 36 | x += 1 37 | print(x) 38 | z = x + 1 39 | return z 40 | 41 | 42 | def case5_inner2(x): 43 | x += 1 44 | z = case5_inner3(1 / 0) 45 | return z + 1 46 | 47 | 48 | def case5_inner1(x): 49 | return case5_inner2(x) 50 | 51 | 52 | def case5(x): 53 | y = case5_inner3(x) 54 | return case5_inner1(y) + 1 55 | 56 | 57 | class TestException(unittest.TestCase): 58 | def catch_error(self, func, inputs, error_lines: int | list[int]): 59 | if isinstance(error_lines, int): 60 | error_lines = [error_lines] 61 | try: 62 | symbolic_translate(func)(inputs) 63 | except Exception as e: 64 | match_results = re.compile(r'File ".*", line (\d+)').findall(str(e)) 65 | match_results = list(map(int, match_results)) 66 | assert ( 67 | match_results == error_lines 68 | ), f"{match_results} is not equal {error_lines}" 69 | 70 | def test_all_case(self): 71 | self.catch_error(case1, paddle.rand([2, 1]), 
11) 72 | # TODO: support runtime error, such as x[111], x@x 73 | # self.catch_error(case2, paddle.rand([2, 1]), 16) 74 | self.catch_error(case3, paddle.rand([2, 1]), 20) 75 | self.catch_error(case4, paddle.rand([2, 1]), 28) 76 | self.catch_error(case5, paddle.rand([3, 1]), [54, 49, 44]) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_output_restoration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def output_identity(x): 11 | return x 12 | 13 | 14 | def output_const(): 15 | return 42 16 | 17 | 18 | def output_list(x: paddle.Tensor, y: paddle.Tensor, z: int): 19 | a = x + 1 20 | b = z + 1 21 | l = [1, a, b, y] 22 | return l 23 | 24 | 25 | def output_dict(x: paddle.Tensor, y: paddle.Tensor, z: int): 26 | a = x + 1 27 | b = z + 1 28 | l = {1: a, b: y} 29 | return l 30 | 31 | 32 | def output_dict_const_key(x: paddle.Tensor, y: paddle.Tensor, z: int): 33 | a = x + 1 34 | b = z + 1 35 | l = {1: a, 2: y} 36 | return l 37 | 38 | 39 | def output_nest_struct(x: paddle.Tensor, y: paddle.Tensor, z: int): 40 | a = x + y + z 41 | b = z + 1 42 | l = [1 + 1, (z, a), [b]] 43 | return l 44 | 45 | 46 | class TestOutputRestoration(TestCaseBase): 47 | def test_output_identity(self): 48 | self.assert_results(output_identity, 1) 49 | self.assert_results(output_identity, 2) 50 | self.assert_results(output_identity, paddle.to_tensor(1)) 51 | 52 | def test_output_const(self): 53 | self.assert_results(output_const) 54 | 55 | def test_output_list(self): 56 | a = paddle.to_tensor(1) 57 | b = paddle.to_tensor(2) 58 | 59 | self.assert_results(output_list, a, b, 3) 60 | 61 | def test_output_dict(self): 62 | a = paddle.to_tensor(1) 63 | b = paddle.to_tensor(2) 64 | 65 | self.assert_results(output_dict, a, 
b, 3) 66 | 67 | def test_output_dict_const_key(self): 68 | a = paddle.to_tensor(2) 69 | b = paddle.to_tensor(3) 70 | 71 | self.assert_results(output_dict_const_key, a, b, 4) 72 | 73 | def test_output_nest_struct(self): 74 | a = paddle.to_tensor(1) 75 | b = paddle.to_tensor(2) 76 | 77 | self.assert_results(output_nest_struct, a, b, 3) 78 | 79 | 80 | if __name__ == "__main__": 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /tests/test_range.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def test_range_1(stop: int): 9 | return range(stop) 10 | 11 | 12 | def test_range_2(start: int, stop: int): 13 | return range(start, stop) 14 | 15 | 16 | def test_range_3(start: int, stop: int, step: int): 17 | return range(start, stop, step) 18 | 19 | 20 | def test_range_4(stop: int, index: int): 21 | return range(stop)[index] 22 | 23 | 24 | def test_range_5(stop: int): 25 | return list(range(stop)) 26 | 27 | 28 | def test_range_6(stop: int, index: int): 29 | return list(range(stop))[index] 30 | 31 | 32 | def test_range_7(index: int, tensor: paddle.Tensor): 33 | return list(range(len(tensor.shape)))[index] 34 | 35 | 36 | def test_range_8(stop: int): 37 | sum = 0 38 | for i in range(stop): 39 | sum += i 40 | return sum 41 | 42 | 43 | def test_range_9(stop: int, tensor: paddle.Tensor): 44 | for i in range(stop): 45 | tensor += i 46 | return tensor 47 | 48 | 49 | def test_range_10(stop: int, tensor: paddle.Tensor): 50 | for i in range(stop): 51 | for j in range(stop + 1): 52 | tensor += j 53 | return tensor 54 | 55 | 56 | class TestExecutor(TestCaseBase): 57 | def test_cases(self): 58 | start = 3 59 | stop = 10 60 | step = 2 61 | index = 1 62 | tensor = paddle.randn((10, 10)) 63 | 64 | self.assert_results(test_range_1, stop) 65 | self.assert_results(test_range_2, start, stop) 66 | 
self.assert_results(test_range_3, start, stop, step) 67 | self.assert_results(test_range_4, stop, index) 68 | self.assert_results(test_range_5, stop) 69 | self.assert_results(test_range_6, stop, index) 70 | self.assert_results(test_range_7, index, tensor) 71 | self.assert_results(test_range_8, stop) 72 | 73 | self.assert_results(test_range_9, stop, paddle.randn((10,))) 74 | self.assert_results(test_range_10, stop, paddle.randn((10,))) 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /docs/instructions/UNPACK_SEQUENCE.md: -------------------------------------------------------------------------------- 1 | # UNPACK_SEQUENCE 2 | 3 | ## 1. python3.8 的 UNPACK_SEQUENCE 4 | 5 | 1. 只处理三种情况:list,tuple,iterator 6 | 2. 从右向左入栈,即最终 list[0] 在栈顶 (iter是第一次 next 的返回值在栈顶) 7 | 8 | ```cc 9 | case TARGET(UNPACK_SEQUENCE): { 10 | PREDICTED(UNPACK_SEQUENCE); 11 | PyObject *seq = POP(), *item, **items; 12 | if (PyTuple_CheckExact(seq) && 13 | PyTuple_GET_SIZE(seq) == oparg) { 14 | items = ((PyTupleObject *)seq)->ob_item; 15 | while (oparg--) { 16 | item = items[oparg]; 17 | Py_INCREF(item); 18 | PUSH(item); 19 | } 20 | } else if (PyList_CheckExact(seq) && 21 | PyList_GET_SIZE(seq) == oparg) { 22 | items = ((PyListObject *)seq)->ob_item; 23 | while (oparg--) { 24 | item = items[oparg]; 25 | Py_INCREF(item); 26 | PUSH(item); 27 | } 28 | } else if (unpack_iterable(tstate, seq, oparg, -1, 29 | stack_pointer + oparg)) { 30 | STACK_GROW(oparg); 31 | } else { 32 | /* unpack_iterable() raised an exception */ 33 | Py_DECREF(seq); 34 | goto error; 35 | } 36 | Py_DECREF(seq); 37 | DISPATCH(); 38 | } 39 | ``` 40 | 41 | 42 | 43 | ## 2. 遇到的问题 44 | 45 | 从 iterator 中 unpack 出来的元素,其 source 是什么? 46 | 47 | 48 | 49 | ## 3. torch 的做法 50 | 51 | 在 torch 中 unpack 的逻辑是和 Varaible 绑定的(作为成员方法) 52 | 53 | 只支持: 54 | 55 | 1. const 56 | 2. dict 57 | 3. 
list (BaseListVariable,RangeVariable,SizeVariable,ListIteratorVariable) 58 | - tuple 也在这个文件,但是它的 iterator 也不能 unpack (应该说根本没有 TupleIteratorVariable) 59 | - SizeVairable 是 TupleVariable 的子类 60 | 4. nn_module 61 | 5. tensor 62 | 63 | 对于迭代器类型,只支持 ListIterator,所以并没有实现 unpack iterator 64 | ![image](https://user-images.githubusercontent.com/79986504/236820096-a38e7c18-8b1a-454c-81a5-03f61c653ebc.png) 65 | -------------------------------------------------------------------------------- /tests/test_20_string.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | from sot.psdb import assert_true, check_no_breakgraph 9 | 10 | 11 | def string_format(x: paddle.Tensor): 12 | whilespace = 123 13 | hello_world = f"Hello {whilespace} World" 14 | z = assert_true(hello_world == "Hello 123 World") 15 | hello_world2 = f"Hello {whilespace}{whilespace} World" 16 | z = assert_true(hello_world2 == "Hello 123123 World") 17 | hello_world_lower = "Hello World".lower() 18 | z = assert_true(hello_world_lower == "hello world") 19 | return x + 1 20 | 21 | 22 | def string_lower(x: paddle.Tensor): 23 | hello_world_lower = "Hello World".lower() 24 | z = assert_true(hello_world_lower == "hello world") 25 | return x + 1 26 | 27 | 28 | @check_no_breakgraph 29 | def str_startswith(): 30 | s = "Hello World" 31 | a1 = s.startswith("Hello") 32 | a2 = s.startswith("World") 33 | a3 = s.startswith("Hello World") 34 | a4 = s.startswith("Hello World!") 35 | a5 = s.startswith("Hello", 5) 36 | a6 = s.startswith("Hello", 1, 4) 37 | a7 = s.startswith("Hello", 0, 11) 38 | return (a1, a2, a3, a4, a5, a6, a7) 39 | 40 | 41 | @check_no_breakgraph 42 | def str_endswith(): 43 | s = "Hello World" 44 | a1 = s.endswith("Hello") 45 | a2 = s.endswith("World") 46 | a3 = s.endswith("Hello World") 47 | a4 = s.endswith("Hello World!") 48 | a5 = s.endswith("Hello", 
5) 49 | a6 = s.endswith("Hello", 0, 4) 50 | a7 = s.endswith("Hello", 1, 11) 51 | return (a1, a2, a3, a4, a5, a6, a7) 52 | 53 | 54 | class TestExecutor(TestCaseBase): 55 | def test_string_format(self): 56 | self.assert_results(string_format, paddle.to_tensor(1)) 57 | 58 | def test_string_lower(self): 59 | self.assert_results(string_lower, paddle.to_tensor(1)) 60 | 61 | def test_str_startswith(self): 62 | self.assert_results(str_startswith) 63 | 64 | def test_str_endswith(self): 65 | self.assert_results(str_endswith) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /tests/test_03_tuple.py: -------------------------------------------------------------------------------- 1 | # New Supported Instructions: 2 | # BUILD_TUPLE 3 | # BINARY_SUBSCR 4 | 5 | from __future__ import annotations 6 | 7 | import unittest 8 | 9 | from test_case_base import TestCaseBase 10 | 11 | import paddle 12 | from sot.psdb import check_no_breakgraph 13 | 14 | 15 | @check_no_breakgraph 16 | def build_tuple(x: int, y: paddle.Tensor): 17 | x = (x, y) 18 | return x[1] + 1 19 | 20 | 21 | @check_no_breakgraph 22 | def build_tuple_with_slice_subscript(x: int, y: paddle.Tensor): 23 | z = (x, y, 3, 4) 24 | return z[0:5:1] 25 | 26 | 27 | @check_no_breakgraph 28 | def build_tuple_with_int_subscript(x: int, y: paddle.Tensor): 29 | z = (x, y) 30 | return z[0] 31 | 32 | 33 | @check_no_breakgraph 34 | def tuple_count_int(x: int, y: paddle.Tensor): 35 | z = (x, x, 2, 1) 36 | return z.count(x) 37 | 38 | 39 | def tuple_count_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): 40 | return y.count(x) 41 | 42 | 43 | @check_no_breakgraph 44 | def tuple_index_int(x: int, y: paddle.Tensor): 45 | z = (x, y, x, y, y) 46 | return z.index(x) 47 | 48 | 49 | def tuple_index_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): 50 | return y.index(x) 51 | 52 | 53 | class TestBuildTuple(TestCaseBase): 54 | def 
test_build_tuple(self): 55 | self.assert_results(build_tuple, 1, paddle.to_tensor(2)) 56 | self.assert_results( 57 | build_tuple_with_slice_subscript, 1, paddle.to_tensor(2) 58 | ) 59 | self.assert_results( 60 | build_tuple_with_int_subscript, 1, paddle.to_tensor(2) 61 | ) 62 | 63 | 64 | class TestTupleMethods(TestCaseBase): 65 | def test_tuple_methods_int(self): 66 | self.assert_results(tuple_count_int, 1, paddle.to_tensor(2)) 67 | self.assert_results(tuple_index_int, 1, paddle.to_tensor(2)) 68 | 69 | def test_tuple_methods_tensor(self): 70 | a = paddle.to_tensor(1) 71 | b = paddle.to_tensor(2) 72 | self.assert_results(tuple_count_tensor, a, (a, b, a, b)) 73 | self.assert_results(tuple_index_tensor, b, (b, b, b, a)) 74 | 75 | 76 | if __name__ == "__main__": 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /tests/test_sir_rollback.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import operator 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | from sot.opcode_translator.executor.function_graph import FunctionGraph 11 | from sot.opcode_translator.executor.tracker import DanglingTracker, LocalTracker 12 | from sot.opcode_translator.executor.variables import ( 13 | BuiltinVariable, 14 | VariableFactory, 15 | ) 16 | 17 | 18 | def compute(x, y): 19 | ret = BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y) 20 | return BuiltinVariable(operator.mul, x.graph, DanglingTracker())(ret, x) 21 | 22 | 23 | def try_add(x, y): 24 | return BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y) 25 | 26 | 27 | class TestRollback(TestCaseBase): 28 | def test_rollback(self): 29 | frame = inspect.currentframe() 30 | graph = FunctionGraph(frame) 31 | a = paddle.to_tensor(1.0) 32 | b = paddle.to_tensor(2.0) 33 | a = VariableFactory().from_value(a, graph, 
LocalTracker("a")) 34 | b = VariableFactory().from_value(b, graph, LocalTracker("b")) 35 | out = compute(a, b) 36 | original_length = len(graph.sir_ctx.TOS.statements) 37 | memo = graph.save_memo() 38 | try_add(out, out) 39 | 40 | assert len(graph.sir_ctx.TOS.statements) != len( 41 | memo.stmt_ir.statements 42 | ), "After add, we must statement IR." 43 | graph.restore_memo(memo) 44 | 45 | assert len(graph.sir_ctx.TOS.statements) == original_length 46 | 47 | 48 | def fn_with_side_effects_inner(x, y): 49 | x[0] += 10 50 | x[1] += 20 51 | x[2] -= 10 52 | print(y) # print will cause breakgraph 53 | return 54 | 55 | 56 | def fn_with_side_effects(x, y): 57 | x[0] += 1 58 | fn_with_side_effects_inner(x, y) 59 | return x[0] + y 60 | 61 | 62 | class TestSideEffectRollback(TestCaseBase): 63 | def test_side_effect_rollback(self): 64 | self.assert_results_with_side_effects( 65 | fn_with_side_effects, [1, 2, 3], paddle.to_tensor(42) 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /tests/test_guard_outputs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import ( 6 | TestCaseBase, 7 | test_instruction_translator_cache_context, 8 | ) 9 | 10 | import paddle 11 | 12 | 13 | def non_operator_related_fn(x: int, y: int): 14 | return x + y 15 | 16 | 17 | def partial_non_operator_related_fn(x: paddle.Tensor, y: paddle.Tensor, z: int): 18 | a = x + y 19 | return [a, z + z] 20 | 21 | 22 | def guard_inputs(x: int, y: int, z: int): 23 | return x + y + z 24 | 25 | 26 | class TestGuardOutputs(TestCaseBase): 27 | def test_non_operator_related_fn(self): 28 | with test_instruction_translator_cache_context() as ctx: 29 | self.assert_results(non_operator_related_fn, 1, 2) 30 | self.assertEqual(ctx.translate_count, 1) 31 | self.assert_results(non_operator_related_fn, 
3, 4) 32 | self.assertEqual(ctx.translate_count, 2) 33 | 34 | def test_partial_non_operator_related_fn(self): 35 | with test_instruction_translator_cache_context() as ctx: 36 | self.assert_results( 37 | partial_non_operator_related_fn, 38 | paddle.to_tensor(1), 39 | paddle.to_tensor(2), 40 | 3, 41 | ) 42 | self.assertEqual(ctx.translate_count, 1) 43 | self.assert_results( 44 | partial_non_operator_related_fn, 45 | paddle.to_tensor(4), 46 | paddle.to_tensor(5), 47 | 6, 48 | ) 49 | self.assertEqual(ctx.translate_count, 2) 50 | 51 | def test_guard_inputs(self): 52 | with test_instruction_translator_cache_context() as ctx: 53 | self.assert_results(guard_inputs, 1, 2, 3) 54 | self.assertEqual(ctx.translate_count, 1) 55 | self.assert_results(guard_inputs, 0, 2, 3) 56 | self.assertEqual(ctx.translate_count, 2) 57 | self.assert_results(guard_inputs, 1, 0, 3) 58 | self.assertEqual(ctx.translate_count, 3) 59 | self.assert_results(guard_inputs, 1, 2, 0) 60 | self.assertEqual(ctx.translate_count, 4) 61 | 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /tests/test_17_paddle_layer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | class SimpleNet(paddle.nn.Layer): 9 | def __init__(self): 10 | super().__init__() 11 | self.linear1 = paddle.nn.Linear(10, 1) 12 | 13 | def forward(self, x): 14 | out1 = self.linear1(x) 15 | return out1 16 | 17 | 18 | class SimpleNet_bound(paddle.nn.Layer): 19 | def __init__(self): 20 | super().__init__() 21 | self.linear1 = paddle.nn.Linear(10, 1) 22 | 23 | def add(self, x): 24 | return x + 1 25 | 26 | def forward(self, x): 27 | x = self.add(x) 28 | out1 = self.linear1(x) 29 | return out1 30 | 31 | 32 | def net_call(x: paddle.Tensor, net): 33 | return net(x) 34 | 35 | 36 | def net_call_passed_by_user(x: paddle.Tensor, 
net_forward): 37 | return net_forward(x) 38 | 39 | 40 | class SimpleNetWithSequenital(paddle.nn.Layer): 41 | def __init__(self): 42 | super().__init__() 43 | self.seq = paddle.nn.Sequential( 44 | paddle.nn.Linear(10, 10), 45 | paddle.nn.Linear(10, 10), 46 | paddle.nn.Linear(10, 1), 47 | ) 48 | 49 | def forward(self, x): 50 | out1 = self.seq(x) 51 | return out1 52 | 53 | 54 | class TestLayer(TestCaseBase): 55 | def test_layer(self): 56 | x = paddle.rand((10,)) 57 | y = paddle.rand((10, 10)) 58 | net = SimpleNet() 59 | self.assert_results(net_call, x, net) 60 | self.assert_results(net_call, y, net) 61 | self.assert_results(net_call_passed_by_user, x, net.forward) 62 | 63 | def test_layer_with_sequential(self): 64 | x = paddle.rand((10,)) 65 | y = paddle.rand((10, 10)) 66 | net = SimpleNetWithSequenital() 67 | self.assert_results(net_call, x, net) 68 | self.assert_results(net_call, y, net) 69 | self.assert_results(net_call_passed_by_user, x, net.forward) 70 | 71 | def test_bound(self): 72 | x = paddle.rand((10,)) 73 | y = paddle.rand((10, 10)) 74 | net = SimpleNet_bound() 75 | self.assert_results(net_call, x, net) 76 | self.assert_results(net_call, y, net) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_18_tensor_method.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def tensor_method_call_1(x: paddle.Tensor): 9 | y = x + 1 10 | return y.mean() 11 | 12 | 13 | def tensor_method_call_2(a: paddle.Tensor, b: paddle.Tensor): 14 | c = a.add(b) 15 | d = c.multiply(a) 16 | e = d.subtract(b) 17 | f = e.divide(a) 18 | g = f.pow(2) + f.abs().sqrt() 19 | h = (g.abs() + 1).log() - (g / g.max()).exp() 20 | i = h.sin() + h.cos() 21 | return i 22 | 23 | 24 | def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor): 25 | 
return func(a) 26 | 27 | 28 | def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor): 29 | return ( 30 | a.name, 31 | str(a.place), 32 | a.persistable, 33 | a.dtype, 34 | a.type, 35 | a.is_tensor(), 36 | a.clear_gradient(), 37 | a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(), 38 | ) 39 | 40 | 41 | def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor): 42 | c = a + b 43 | return c.name 44 | 45 | 46 | class TestTensorMethod(TestCaseBase): 47 | def test_tensor_method_1(self): 48 | x = paddle.rand([10]) 49 | y = paddle.rand([2, 4, 6]) 50 | self.assert_results(tensor_method_call_1, x) 51 | self.assert_results(tensor_method_call_1, y) 52 | 53 | def test_tensor_method_2(self): 54 | x = paddle.rand([42]) 55 | y = paddle.rand([42]) 56 | self.assert_results(tensor_method_call_2, x, y) 57 | 58 | def test_tensor_method_passed_by_user(self): 59 | x = paddle.rand([42]) 60 | y = paddle.rand([42]) 61 | self.assert_results(tensor_method_passed_by_user, x, y.add) 62 | 63 | def test_tensor_method_property(self): 64 | x = paddle.rand([42, 24], dtype='float64') 65 | y = paddle.rand([42, 24], dtype='float32') 66 | self.assert_results(tensor_method_property, x, y) 67 | 68 | @unittest.skip("TODO: dynamic tensor name is different") 69 | def test_middle_tensor_name(self): 70 | x = paddle.rand([42, 24]) 71 | y = paddle.rand([42, 24]) 72 | self.assert_results(middle_tensor_name, x, y) 73 | 74 | 75 | if __name__ == "__main__": 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /sot/utils/code_status.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from enum import Enum 3 | 4 | import paddle 5 | 6 | from .utils import Singleton, log 7 | 8 | 9 | class CodeState(Enum): 10 | UNKNOW = 1 11 | WITH_GRAPH = 2 12 | WITHOUT_GRAPH = 3 13 | 14 | 15 | class CodeInfo: 16 | def __init__(self): 17 | self.state = CodeState.UNKNOW 18 | self.counter = 0 19 | 20 | def 
__repr__(self): 21 | return f"state: {self.state}, counter: {self.counter}" 22 | 23 | 24 | @Singleton 25 | class CodeStatus: 26 | WITH_GRAPH_API = [ 27 | paddle.nn.Layer.__call__.__code__, 28 | paddle.nn.Layer._dygraph_call_func.__code__, 29 | ] 30 | 31 | def __init__(self): 32 | self.code_map = {} 33 | self.setup_code_map() 34 | 35 | def setup_code_map(self): 36 | for code in self.WITH_GRAPH_API: 37 | info = CodeInfo() 38 | info.state = CodeState.WITH_GRAPH 39 | self.code_map[code] = info 40 | 41 | def clear(self): 42 | self.code_map.clear() 43 | self.setup_code_map() 44 | 45 | def is_code_without_graph(self, code): 46 | if code not in self.code_map: 47 | info = CodeInfo() 48 | self.code_map[code] = info 49 | else: 50 | info = self.code_map[code] 51 | 52 | if info.state == CodeState.WITHOUT_GRAPH: 53 | return True 54 | if info.state == CodeState.UNKNOW: 55 | info.counter += 1 56 | if info.counter >= 10: 57 | log( 58 | 3, 59 | f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n", 60 | ) 61 | info.state = CodeState.WITHOUT_GRAPH 62 | return False 63 | 64 | def trace_back_frames(self): 65 | frame = inspect.currentframe() 66 | while frame.f_back is not None: 67 | frame = frame.f_back 68 | code = frame.f_code 69 | if code in self.code_map: 70 | info = self.code_map[code] 71 | if info.state != CodeState.WITH_GRAPH: 72 | log( 73 | 3, 74 | f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n", 75 | ) 76 | info.state = CodeState.WITH_GRAPH 77 | -------------------------------------------------------------------------------- /tests/test_08_rot.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def rot_two_return_a(a: paddle.Tensor, b: paddle.Tensor): 11 | b, a = a, b 12 | return a + 1 13 | 14 | 15 | def rot_two_return_b(a: paddle.Tensor, b: paddle.Tensor): 16 | b, a = a, b 17 | return b + 
2 18 | 19 | 20 | def rot_three_return_a(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): 21 | a, b, c = c, b, a 22 | return a + 1 23 | 24 | 25 | def rot_three_return_b(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): 26 | a, b, c = c, b, a 27 | return b + 1 28 | 29 | 30 | def rot_three_return_c(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): 31 | a, b, c = c, b, a 32 | return c + 1 33 | 34 | 35 | def rot_four_return_a( 36 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 37 | ): 38 | a, b, c, d = d, c, b, a 39 | return a + 1 40 | 41 | 42 | def rot_four_return_b( 43 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 44 | ): 45 | a, b, c, d = d, c, b, a 46 | return b + 1 47 | 48 | 49 | def rot_four_return_c( 50 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 51 | ): 52 | a, b, c, d = d, c, b, a 53 | return c + 1 54 | 55 | 56 | def rot_four_return_d( 57 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 58 | ): 59 | a, b, c, d = d, c, b, a 60 | return d + 1 61 | 62 | 63 | class TestExecutor(TestCaseBase): 64 | def test_simple(self): 65 | a = paddle.to_tensor(1) 66 | b = paddle.to_tensor(2) 67 | c = paddle.to_tensor(3) 68 | d = paddle.to_tensor(4) 69 | self.assert_results(rot_two_return_a, a, b) 70 | self.assert_results(rot_two_return_b, a, b) 71 | 72 | self.assert_results(rot_three_return_a, a, b, c) 73 | self.assert_results(rot_three_return_b, a, b, c) 74 | self.assert_results(rot_three_return_c, a, b, c) 75 | 76 | self.assert_results(rot_four_return_a, a, b, c, d) 77 | self.assert_results(rot_four_return_b, a, b, c, d) 78 | self.assert_results(rot_four_return_c, a, b, c, d) 79 | self.assert_results(rot_four_return_d, a, b, c, d) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /tests/test_10_build_unpack.py: 
-------------------------------------------------------------------------------- 1 | # BUILD_TUPLE_UNPACK (new) 2 | # BUILD_LIST_UNPACK (new) 3 | # BUILD_TUPLE_UNPACK_WITH_CALL (new) 4 | # CALL_FUNCTION_EX (new) 5 | # BUILD_MAP_UNPACK (new) 6 | # LIST_EXTEND (new) 7 | # LIST_TO_TUPLE (new) 8 | # DICT_UPDATE (new) 9 | # DICT_MERGE (new) 10 | 11 | from __future__ import annotations 12 | 13 | import unittest 14 | 15 | from test_case_base import TestCaseBase 16 | 17 | import paddle 18 | 19 | 20 | def build_tuple_unpack(x: tuple[paddle.Tensor], y: tuple[paddle.Tensor]): 21 | z = (*x, *y) 22 | 23 | return z[0] + 1 24 | 25 | 26 | def build_list_unpack(x: list[paddle.Tensor], y: list[paddle.Tensor]): 27 | z = [*x, *y] 28 | return z[0] + 1 29 | 30 | 31 | def build_tuple_unpack_with_call( 32 | x: tuple[paddle.Tensor], y: tuple[paddle.Tensor] 33 | ): 34 | z = build_tuple_unpack_with_call_inner(*x, *y) 35 | return z[0] + 1 36 | 37 | 38 | def build_tuple_unpack_with_call_inner( 39 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 40 | ): 41 | z = (a, b, c, d) 42 | return z 43 | 44 | 45 | def build_map_unpack(x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor]): 46 | z = {**x, **y} 47 | return z["a"] + 1 48 | 49 | 50 | def build_map_unpack_with_call_inner( 51 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 52 | ): 53 | z = {"a": a, "b": b, "c": c, "d": d} 54 | return z 55 | 56 | 57 | def build_map_unpack_with_call( 58 | x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor] 59 | ): 60 | z = build_map_unpack_with_call_inner(**x, **y) 61 | return z["a"] + 1 62 | 63 | 64 | class TestExecutor(TestCaseBase): 65 | def test_simple(self): 66 | a = paddle.to_tensor(1) 67 | b = paddle.to_tensor(2) 68 | c = paddle.to_tensor(3) 69 | d = paddle.to_tensor(4) 70 | 71 | self.assert_results(build_tuple_unpack, (a, b), (c, d)) 72 | self.assert_results(build_list_unpack, [a, b], [c, d]) 73 | self.assert_results(build_tuple_unpack_with_call, 
(a, b), (c, d)) 74 | self.assert_results( 75 | build_map_unpack, {"a": a, "b": b}, {"c": c, "d": d} 76 | ) 77 | self.assert_results( 78 | build_map_unpack_with_call, {"a": a, "b": b}, {"c": c, "d": d} 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "PaddleSOT" 7 | version = "0.0.1a0" 8 | description = "A Bytecode level Implementation of Symbolic OpCode Translator For PaddlePaddle" 9 | readme = "README.md" 10 | requires-python = ">=3.7,<3.12" 11 | authors = [ 12 | {name = "PaddlePaddle", email = "Paddle-better@baidu.com"}, 13 | ] 14 | keywords = ["Framework", "Deep Learning", "JIT"] 15 | license = { file = "LICENSE" } 16 | classifiers = [ 17 | "License :: OSI Approved :: Apache Software License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | "xdoctest" 30 | ] 31 | paddle-test = [ 32 | "opencv-python", 33 | "scipy", 34 | "pyyaml", 35 | "gym==0.26.2", 36 | ] 37 | 38 | [tool.setuptools.packages] 39 | find = {} 40 | 41 | [tool.black] 42 | line-length = 80 43 | skip-string-normalization = true 44 | 45 | [tool.isort] 46 | profile = "black" 47 | line_length = 80 48 | known_first_party = ["paddle", "sot"] 49 | 50 | [tool.ruff] 51 | target-version = "py37" 52 | select = [ 53 | # Pycodestyle 54 | "E", 55 | "W", 56 | 57 | # Pyflakes 58 | "F", 59 | 60 | # Comprehensions 61 | "C4", 62 | 63 | # 
Pyupgrade 64 | "UP", 65 | 66 | # Bugbear 67 | "B002", 68 | "B003", 69 | "B004", 70 | "B009", 71 | "B010", 72 | "B011", 73 | "B012", 74 | "B013", 75 | "B014", 76 | "B015", 77 | "B016", 78 | "B017", 79 | "B018", 80 | "B019", 81 | "B020", 82 | "B021", 83 | "B022", 84 | "B025", 85 | "B029", 86 | "B032", 87 | 88 | # Pylint 89 | "PLE", 90 | "PLC0414", 91 | "PLC3002", 92 | "PLR0206", 93 | "PLR0402", 94 | ] 95 | ignore = [ 96 | "E402", 97 | "E501", 98 | "E722", 99 | "E731", 100 | "E741", 101 | "F841", 102 | "UP015", 103 | "UP038", 104 | ] 105 | 106 | [tool.ruff.per-file-ignores] 107 | "tests/test_dup_top.py" = ["E712"] 108 | -------------------------------------------------------------------------------- /tests/test_guard_user_defined_fn.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import ( 6 | TestCaseBase, 7 | test_instruction_translator_cache_context, 8 | ) 9 | 10 | import paddle 11 | 12 | 13 | def test_guard_fn(fn, inp): 14 | if fn is None: 15 | return 0 16 | else: 17 | return fn(inp) 18 | 19 | 20 | class TestGuardOutputs(TestCaseBase): 21 | def test_non_operator_related_fn(self): 22 | with test_instruction_translator_cache_context() as ctx: 23 | self.assert_results( 24 | test_guard_fn, 25 | paddle.nn.functional.relu, 26 | paddle.to_tensor([1.0, -1.0]), 27 | ) 28 | self.assertEqual(ctx.translate_count, 1) 29 | self.assert_results( 30 | test_guard_fn, 31 | paddle.nn.functional.gelu, 32 | paddle.to_tensor([1.0, -1.0]), 33 | ) 34 | self.assertEqual(ctx.translate_count, 2) 35 | self.assert_results( 36 | test_guard_fn, 37 | paddle.nn.functional.relu, 38 | paddle.to_tensor([-1.0, -1.0]), 39 | ) 40 | self.assertEqual(ctx.translate_count, 2) 41 | self.assert_results( 42 | test_guard_fn, None, paddle.to_tensor([-1.0, -1.0]) 43 | ) 44 | self.assertEqual(ctx.translate_count, 3) 45 | 46 | deleted_cnt = 0 47 | 48 | class Callable: 49 | def 
__call__(self, var): 50 | return paddle.nn.functional.relu(var) 51 | 52 | def __del__(self): 53 | nonlocal deleted_cnt 54 | deleted_cnt += 1 55 | 56 | fn1 = Callable() 57 | fn2 = Callable() 58 | with test_instruction_translator_cache_context() as ctx: 59 | self.assert_results( 60 | test_guard_fn, fn1, paddle.to_tensor([1.0, -1.0]) 61 | ) 62 | self.assertEqual(ctx.translate_count, 1) 63 | self.assert_results( 64 | test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) 65 | ) 66 | self.assertEqual(ctx.translate_count, 2) 67 | self.assert_results( 68 | test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) 69 | ) 70 | self.assertEqual(ctx.translate_count, 2) 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_resnet_with_trace_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["FLAGS_cudnn_deterministic"] = "True" 4 | 5 | import random 6 | import unittest 7 | from types import MethodType 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal 11 | 12 | import paddle 13 | from paddle.vision import resnet50 14 | from sot import symbolic_translate 15 | from sot.symbolic.compile_cache import CompileSIRCache 16 | from sot.trace_cache_entrance import trace_cache 17 | from sot.utils.utils import execute_time 18 | 19 | 20 | @trace_cache 21 | def forward_with_cache(self, x): 22 | x = self.conv1(x) 23 | x = self.bn1(x) 24 | x = self.relu(x) 25 | x = self.maxpool(x) 26 | x = self.layer1(x) 27 | x = self.layer2(x) 28 | x = self.layer3(x) 29 | x = self.layer4(x) 30 | if self.with_pool: 31 | x = self.avgpool(x) 32 | if self.num_classes > 0: 33 | x = paddle.flatten(x, 1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | def run_dygraph_optimizer(inp, to_static): 39 | """dygraph train + SGD optimizer""" 40 | paddle.seed(2021) 41 | np.random.seed(2021) 42 | random.seed(2021) 43 
| net = resnet50() 44 | if to_static: 45 | net.forward = MethodType(forward_with_cache, net) 46 | net.forward = symbolic_translate(net.forward) 47 | # net = paddle.jit.to_static(net) 48 | optimizer = paddle.optimizer.SGD( 49 | learning_rate=0.03, parameters=net.parameters() 50 | ) 51 | for i in range(5): 52 | optimizer.clear_grad() 53 | loss = execute_time(net)(inp) 54 | loss.backward() 55 | optimizer.step() 56 | print("===============================================") 57 | return loss 58 | 59 | 60 | class TestBackward(unittest.TestCase): 61 | def test(self): 62 | # TODO(xiongkun) add cache to speedup ! 63 | paddle.seed(2021) 64 | np.random.seed(2021) 65 | random.seed(2021) 66 | inp = paddle.rand((3, 3, 255, 255)) 67 | out1 = run_dygraph_optimizer(inp, True)[0].numpy() 68 | out2 = run_dygraph_optimizer(inp, False)[0].numpy() 69 | assert_array_equal( 70 | out1, out2, "Not Equal in dygraph and static graph", True 71 | ) 72 | assert ( 73 | CompileSIRCache().hit_num == 4 74 | ), f"CompileSIRCache hit_num should be 4, but{CompileSIRCache().hit_num}" 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /.github/workflows/paddle_ci.yaml: -------------------------------------------------------------------------------- 1 | name: PaddleDy2staicTests 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ['3.8', '3.9', '3.10', '3.11'] 16 | runs-on: ${{ matrix.os }} 17 | name: python ${{ matrix.python-version }} dy2st unittests 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | # RAM too small, Related link: https://github.com/actions/runner-images/discussions/7188 23 | - name: Increase swapfile 24 | run: | 25 | sudo swapoff -a 26 | sudo fallocate -l 8G /swapfile 27 | sudo chmod 600 /swapfile 28 | sudo mkswap 
/swapfile 29 | sudo swapon /swapfile 30 | sudo swapon --show 31 | 32 | - name: Install python 33 | uses: actions/setup-python@v4 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html 40 | pip install ".[paddle-test]" 41 | 42 | - uses: actions/cache@v3 43 | with: 44 | path: | 45 | ~/.cache/paddle 46 | key: ${{ runner.os }}-${{ matrix.python-version }}-paddle-cache-v1 47 | 48 | - name: Paddle run check 49 | run: | 50 | python -c "import paddle; paddle.utils.run_check()" 51 | 52 | - name: Clone paddle tests 53 | working-directory: ./tests/ 54 | run: | 55 | git clone https://github.com/PaddlePaddle/Paddle 56 | cd Paddle 57 | PADDLE_COMMIT_HASH=`python -c "import paddle; import sys; sys.stdout.write(paddle.version.commit)"` 58 | echo -e "paddle commit hash: \x1b[32m${PADDLE_COMMIT_HASH}\x1b[0m" 59 | git checkout ${PADDLE_COMMIT_HASH} 60 | # skip new ir tests 61 | sed -i "s/@test_with_new_ir/@unittest.skip('skip new ir test')/g" `grep -rl "@test_with_new_ir" test/dygraph_to_static` 62 | sed -i "s/@test_and_compare_with_new_ir/# @test_and_compare_with_new_ir/g" `grep -rl "@test_and_compare_with_new_ir" test/dygraph_to_static` 63 | 64 | - name: Run unit tests 65 | working-directory: ./tests/ 66 | run: | 67 | bash ./run_all_paddle_ci.sh 68 | -------------------------------------------------------------------------------- /docs/design/tracker-and-guard.md: -------------------------------------------------------------------------------- 1 | # Guard 收集机制 2 | 3 | ## 为什么需要 Guard? 
4 | 5 | 在整个实现过程中,我们需要对用户函数的原始字节码转换成转换后的字节码,但如果每次运行都完整地转写全部字节码只会导致性能的浪费,也无法实现 JIT 的效果,因此我们需要一个缓存机制,来复用已经转换后的字节码。 6 | 7 | 但并不是说任何字节码成功转换一次后都是可以复用的,因为我们的字节码变换操作是通过模拟执行得到的,而模拟执行的起点是 Eval Frame 的初始状态,主要就是函数的输入,对于不同的输入,我们得到的字节码转换结果是可能不同的,因此我们需要有一个机制来判断转换后的字节码是否有效。 8 | 9 | 由于转换的过程与输入是强相关的,在函数 Eval Frame 初始阶段,我们可以从 `frame` 中拿到函数的输入,之后我们只需要通过 `guard` 来判断一个已经缓存的字节码是否有效即可,即 `guard(frame)`,如果结果是 `True`,则认为缓存命中。 10 | 11 | guard 签名如下: 12 | 13 | ```python 14 | Guard = Callable[[types.FrameType], bool] 15 | ``` 16 | 17 | ## Guard 的收集机制 18 | 19 | 在模拟执行过程中,我们会根据字节码执行不同的操作,每一个字节码都会对应一个操作,如果我们将整个操作的链条构建起来,形成一个 DAG,就可以在任何时刻追踪到我们需要的 Guard。 20 | 21 | 我们使用 Tracker 来承载追踪操作的功能,Tracker 的 `inputs` 会持有其相关输入的 Variable,该 Tracker 将会由输出的 Variable 持有,相关数据结构如下: 22 | 23 | ```python 24 | class VariableBase: 25 | tracker: Tracker 26 | 27 | class Tracker: 28 | inputs: list[VariableBase] 29 | ``` 30 | 31 | 比如对于如下的代码 32 | 33 | ```python 34 | def foo(a: list[Tensor], b: int, c: int): 35 | d = a[b] 36 | e = d + c 37 | return e 38 | ``` 39 | 40 | 最终构建的 Python 端 DAG 如下: 41 | 42 |

43 | Tracker 44 |

45 | 46 | 有了 DAG 之后,我们只需要从需要的结点出发,找到全部需要的结点,并按照拓扑序收集即可。 47 | 48 | ## DummyTracker 49 | 50 | 上图中可以看到有 DummyTracker,而 DummyTracker 相关的路径也标成了虚线,那么什么情况需要 DummyTracker 呢? 51 | 52 | 对于 LocalTracker、GetItemTracker 来说,除去 Guard 的收集,有很重要的一点是,我们可以通过这些 Tracker 还原从 frame 初始状态出发,获取这些值的方法,这包括了如下两点: 53 | 54 | - 在生成函数的字节码前,需要将输入 LOAD 到栈上,我们需要根据 Tracker 来生成 LOAD 这些输入的字节码 55 | - 在调用 Guard 时,需要根据 Tracker 来索引到新的 Frame 里的相同变量的值,这样才能进行 Guard 的判断(`new_value == old_value`) 56 | 57 | 我们可以将这种索引机制称为 Source,而大多数中间结点是经过计算得到的,我们并不需要去还原这些中间结点,比如 `c = a + b`,`c` 是由 `BINARY_ADD` 构建得到的,我们的 Source 只需要分别索引 `a` 和 `b` 的来源,而我们的 Guard 也只需要分别 Guard 住 `a` 和 `b` 即可。 58 | 59 | 因此对于这种中间结点,我们只需要知道它是由什么构建得到即可,即只需要知道 inputs 是什么,对于这些结点,我们使用 DummyTracker 来作为连接结点,DummyTracker 不会承担 Source 的索引功能,只会承担 DAG 的连接功能,以便 Guard 的收集。 60 | 61 | ## Guard 收集的短路机制 62 | 63 | 对于如下的 case 64 | 65 | ```python 66 | def foo(x): 67 | if x < 4: 68 | ... 69 | else: 70 | ... 71 | 72 | foo(9) 73 | foo(10) 74 | ``` 75 | 76 | 如果我们的 Guard 收集机制是遍历全部结点的话,会强制 Guard 住 `x == 9`,所以第二次调用 `foo(10)` 时会 cache miss。 77 | 78 | 为了减少 cache miss 的概率,我们增加了一个短路机制,当一个 Tracker 的所有输入都不是 DummyTracker 时,可以认为从该 Tracker 上所获得的 Guard 会比从其 inputs 所获得的更加精准,就不需要再从其 inputs 收集 Guard 了,可以大大降低重新编译的概率。 79 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_case_base.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | import paddle 7 | from sot import symbolic_translate 8 | from sot.proxy_tensor import ProxyTensorContext 9 | from sot.utils import is_proxy_tensor, no_eval_frame 10 | 11 | 12 | class TestCaseBase(unittest.TestCase): 13 | def assert_results(self, func, *inputs): 14 | sym_output = symbolic_translate(func)(*inputs) 15 | paddle_output = func(*inputs) 16 | np.testing.assert_allclose(sym_output, paddle_output) 17 | 18 | 19 | @no_eval_frame 20 | def check_live_vars(live_vars, dead_vars): 21 | current_frame =
inspect.currentframe() 22 | assert current_frame is not None 23 | 24 | no_eval_frame_func_frame = current_frame.f_back 25 | assert no_eval_frame_func_frame is not None 26 | assert no_eval_frame_func_frame.f_code.co_name == "no_eval_frame_func" 27 | 28 | test_case_func_frame = no_eval_frame_func_frame.f_back 29 | assert test_case_func_frame is not None 30 | 31 | runtime_live_proxy_tensors = set( 32 | ProxyTensorContext().runtime_proxy_tensor_to_name.keys() 33 | ) 34 | runtime_live_eager_tensors = set( 35 | ProxyTensorContext().tensor_to_proxy_tensor.keys() 36 | ) 37 | 38 | for live_var in live_vars: 39 | assert live_var in test_case_func_frame.f_locals 40 | local_var = test_case_func_frame.f_locals[live_var] 41 | if is_proxy_tensor(local_var): 42 | proxy_tensor_id = id(local_var) 43 | assert ( 44 | proxy_tensor_id in runtime_live_proxy_tensors 45 | ), f"{live_var} ({local_var.name}) is not live" 46 | elif isinstance(local_var, paddle.Tensor): 47 | eager_tensor_id = id(local_var) 48 | assert ( 49 | eager_tensor_id in runtime_live_eager_tensors 50 | ), f"{live_var} ({local_var.name}) is not live" 51 | 52 | for dead_var in dead_vars: 53 | assert dead_var in test_case_func_frame.f_locals 54 | local_var = test_case_func_frame.f_locals[dead_var] 55 | print(dead_var, local_var) 56 | if is_proxy_tensor(local_var): 57 | proxy_tensor_id = id(local_var) 58 | assert ( 59 | proxy_tensor_id not in runtime_live_proxy_tensors 60 | ), f"{dead_var} ({local_var.name}) is live" 61 | elif isinstance(local_var, paddle.Tensor): 62 | eager_tensor_id = id(local_var) 63 | assert ( 64 | eager_tensor_id not in runtime_live_eager_tensors 65 | ), f"{dead_var} ({local_var.name}) is live" 66 | -------------------------------------------------------------------------------- /sot/utils/paddle_api_config.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import paddle 4 | 5 | 6 | def is_inplace_api(func): 7 | inplace_apis = 
{paddle.static.setitem} 8 | return func in inplace_apis 9 | 10 | 11 | def get_tensor_methods(): 12 | return [ 13 | member_name 14 | for member_name, member in inspect.getmembers(paddle.static.Variable) 15 | if inspect.isfunction(member) 16 | ] 17 | 18 | 19 | def get_paddle_api(): 20 | modules = [ 21 | paddle, 22 | paddle.nn.functional, 23 | paddle.linalg, 24 | paddle.signal, 25 | paddle.fft, 26 | paddle.vision.ops, 27 | ] 28 | special_paddle_apis = [paddle.tensor.fill_constant] 29 | non_operator_related_apis = [ 30 | paddle.in_dynamic_mode, 31 | paddle.save, 32 | paddle.load, 33 | paddle.get_cuda_rng_state, 34 | paddle.set_rng_state, 35 | paddle.set_cuda_rng_state, 36 | paddle.get_rng_state, 37 | paddle.set_default_dtype, 38 | paddle.check_shape, 39 | paddle.summary, 40 | paddle.finfo, 41 | paddle.iinfo, 42 | paddle.enable_static, 43 | paddle.disable_static, 44 | paddle.is_grad_enabled, 45 | ] 46 | # TODO: users should not call static_apis, but we need to use, so add static_apis here temporary 47 | static_apis = [paddle.static.setitem, paddle.static.accuracy] 48 | paddle_api_list = [] 49 | for module in modules: 50 | for fn_name in getattr(module, "__all__", []): 51 | fn = getattr(module, fn_name) 52 | if inspect.isfunction(fn): 53 | paddle_api_list.append(fn) 54 | return list( 55 | set(special_paddle_apis) 56 | | set(static_apis) 57 | | set(paddle_api_list) - set(non_operator_related_apis) 58 | ) 59 | 60 | 61 | paddle_tensor_methods = get_tensor_methods() 62 | paddle_api_list = get_paddle_api() 63 | 64 | # TODO(Aurelius84): It seems that we use it to judge 'in_paddle_module()'. 65 | # Bug what does 'is_paddle_module' really means? Is all paddle.xx sub module 66 | # considered as paddle module? 
67 | paddle_api_module_prefix = { 68 | "paddle.nn.functional", 69 | "paddle.nn.layer.activation", 70 | } 71 | 72 | break_graph_set = set() 73 | 74 | 75 | break_graph_tensor_method = { 76 | 'register_hook', 77 | 'numpy', 78 | 'clear_gradient', 79 | # TODO: Browse all possible functions and make prior judgments. 80 | } 81 | 82 | 83 | def is_break_graph_tensor_methods(method_name): 84 | return method_name in break_graph_tensor_method 85 | 86 | 87 | def add_break_graph_apis(apis: list): 88 | break_graph_set.update(apis) 89 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_resnet50_backward.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["FLAGS_cudnn_deterministic"] = "True" 4 | 5 | import random 6 | import unittest 7 | 8 | import numpy as np 9 | from numpy.testing import assert_array_equal 10 | 11 | import paddle 12 | from paddle.vision import resnet50 13 | from sot import symbolic_translate 14 | from sot.utils.utils import execute_time 15 | 16 | 17 | def resnet_call(net: paddle.nn.Layer, x: paddle.Tensor): 18 | return net(x) 19 | 20 | 21 | def run_dygraph_optimizer(inp): 22 | """dygraph train + SGD optimizer""" 23 | paddle.seed(2021) 24 | np.random.seed(2021) 25 | random.seed(2021) 26 | net = resnet50() 27 | optimizer = paddle.optimizer.SGD( 28 | learning_rate=0.03, parameters=net.parameters() 29 | ) 30 | for i in range(5): 31 | optimizer.clear_grad() 32 | loss = execute_time(net)(inp) 33 | loss.backward() 34 | optimizer.step() 35 | return loss 36 | 37 | 38 | def run_symbolic_optimizer(inp): 39 | """dygraph train + SGD optimizer""" 40 | paddle.seed(2021) 41 | np.random.seed(2021) 42 | random.seed(2021) 43 | net = resnet50() 44 | net_wrapper = symbolic_translate(resnet_call) 45 | optimizer = paddle.optimizer.SGD( 46 | learning_rate=0.03, parameters=net.parameters() 47 | ) 48 | for i in range(5): 49 | optimizer.clear_grad() 50 | loss = 
execute_time(net_wrapper)(net, inp) 51 | loss.backward() 52 | optimizer.step() 53 | return loss 54 | 55 | 56 | def run_to_static_optimizer(inp): 57 | """dygraph train + SGD optimizer""" 58 | paddle.seed(2021) 59 | np.random.seed(2021) 60 | random.seed(2021) 61 | net = resnet50() 62 | net = paddle.jit.to_static(net) 63 | optimizer = paddle.optimizer.SGD( 64 | learning_rate=0.03, parameters=net.parameters() 65 | ) 66 | for i in range(5): 67 | optimizer.clear_grad() 68 | loss = execute_time(net)(inp) 69 | loss.backward() 70 | optimizer.step() 71 | return loss 72 | 73 | 74 | class TestBackward(unittest.TestCase): 75 | def test(self): 76 | # TODO(xiongkun) add cache to speedup ! 77 | paddle.seed(2021) 78 | np.random.seed(2021) 79 | random.seed(2021) 80 | inp = paddle.rand((3, 3, 255, 255)) 81 | print("Start Run SymbolicTranslate:") 82 | out2 = run_symbolic_optimizer(inp)[0].numpy() 83 | print("Start Run Dygraph:") 84 | out1 = run_dygraph_optimizer(inp)[0].numpy() 85 | print("Start Run To Static:") 86 | out1 = run_to_static_optimizer(inp)[0].numpy() 87 | assert_array_equal( 88 | out1, out2, "Not Equal in dygraph and static graph", True 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /tests/test_resnet50_backward.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["FLAGS_cudnn_deterministic"] = "True" 4 | 5 | import random 6 | import unittest 7 | 8 | import numpy as np 9 | from numpy.testing import assert_array_equal 10 | 11 | import paddle 12 | from paddle.vision import resnet50 13 | from sot import symbolic_translate 14 | from sot.utils.utils import execute_time 15 | 16 | 17 | def resnet_call(net: paddle.nn.Layer, x: paddle.Tensor): 18 | return net(x) 19 | 20 | 21 | def run_dygraph_optimizer(inp): 22 | """dygraph train + SGD optimizer""" 23 | paddle.seed(2021) 24 | np.random.seed(2021) 25 | 
random.seed(2021) 26 | net = resnet50() 27 | optimizer = paddle.optimizer.SGD( 28 | learning_rate=0.03, parameters=net.parameters() 29 | ) 30 | for i in range(5): 31 | optimizer.clear_grad() 32 | loss = execute_time(net)(inp) 33 | loss.backward() 34 | optimizer.step() 35 | return loss 36 | 37 | 38 | def run_symbolic_optimizer(inp): 39 | """dygraph train + SGD optimizer""" 40 | paddle.seed(2021) 41 | np.random.seed(2021) 42 | random.seed(2021) 43 | net = resnet50() 44 | net_wrapper = symbolic_translate(resnet_call) 45 | optimizer = paddle.optimizer.SGD( 46 | learning_rate=0.03, parameters=net.parameters() 47 | ) 48 | for i in range(5): 49 | optimizer.clear_grad() 50 | loss = execute_time(net_wrapper)(net, inp) 51 | loss.backward() 52 | optimizer.step() 53 | return loss 54 | 55 | 56 | def run_to_static_optimizer(inp): 57 | """dygraph train + SGD optimizer""" 58 | paddle.seed(2021) 59 | np.random.seed(2021) 60 | random.seed(2021) 61 | net = resnet50() 62 | net = paddle.jit.to_static(net, enable_fallback=False) 63 | optimizer = paddle.optimizer.SGD( 64 | learning_rate=0.03, parameters=net.parameters() 65 | ) 66 | for i in range(5): 67 | optimizer.clear_grad() 68 | loss = execute_time(net)(inp) 69 | loss.backward() 70 | optimizer.step() 71 | return loss 72 | 73 | 74 | class TestBackward(unittest.TestCase): 75 | def test(self): 76 | # TODO(xiongkun) add cache to speedup ! 
77 | paddle.seed(2021) 78 | np.random.seed(2021) 79 | random.seed(2021) 80 | inp = paddle.rand((3, 3, 255, 255)) 81 | print("Start Run SymbolicTranslate:") 82 | out2 = run_symbolic_optimizer(inp)[0].numpy() 83 | print("Start Run Dygraph:") 84 | out1 = run_dygraph_optimizer(inp)[0].numpy() 85 | print("Start Run To Static:") 86 | out1 = run_to_static_optimizer(inp)[0].numpy() 87 | assert_array_equal( 88 | out1, out2, "Not Equal in dygraph and static graph", True 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /tests/test_enumerate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import paddle 6 | 7 | 8 | def test_enumerate_1(x: int, y: int): 9 | for id, val in enumerate(range(x)): 10 | if id % 2 == 0: 11 | y += val 12 | return y 13 | 14 | 15 | def test_enumerate_2(x: list): 16 | return list(enumerate(x)) 17 | 18 | 19 | def test_enumerate_3(x: list): 20 | return tuple(enumerate(x)) 21 | 22 | 23 | def test_enumerate_4(x: paddle.Tensor): 24 | sum = 0 25 | for idx, val in enumerate(x): 26 | sum += val 27 | return sum 28 | 29 | 30 | # TODO(zmh): support range for tensor 31 | def test_enumerate_5(x: paddle.Tensor): 32 | sum = 0 33 | 34 | for idx, val in enumerate(x): 35 | for i in range(val): 36 | sum += val 37 | return sum 38 | 39 | 40 | def test_enumerate_6(x: paddle.Tensor): 41 | sum = 0 42 | 43 | for idx, val in enumerate(x): 44 | for i in range(idx): 45 | sum += val 46 | return sum 47 | 48 | 49 | def test_enumerate_7(x: paddle.Tensor): 50 | sum = 0 51 | x = x.flatten() 52 | for idx, val in enumerate(x): 53 | sum += val 54 | return sum 55 | 56 | 57 | # TODO(zmh): support -1 58 | def test_enumerate_8(x: paddle.Tensor): 59 | sum = 0 60 | x = paddle.nonzero(x, as_tuple=False) 61 | for idx, val in enumerate(x): 62 | sum += val 63 | return 
sum 64 | 65 | 66 | def test_enumerate_10(layer_list, x): 67 | sum = 0 68 | for idx, layer in enumerate(layer_list): 69 | sum += layer(x) 70 | return sum 71 | 72 | 73 | class TestExecutor(TestCaseBase): 74 | def test_cases(self): 75 | x = 8 76 | y = 5 77 | ty = paddle.randn((10, 10)) 78 | layer_list = paddle.nn.LayerList( 79 | [paddle.nn.Linear(10, 10) for _ in range(3)] 80 | ) 81 | 82 | self.assert_results(test_enumerate_1, x, y) 83 | self.assert_results(test_enumerate_2, [2, 4, 6, 8, 10]) 84 | self.assert_results(test_enumerate_3, [2, 4, 6, 8, 10]) 85 | 86 | self.assert_results(test_enumerate_4, ty) 87 | # TODO(zmh): support range for tensor 88 | 89 | with strict_mode_guard(0): 90 | self.assert_results(test_enumerate_5, paddle.to_tensor([1, 2, 3])) 91 | self.assert_results(test_enumerate_6, paddle.to_tensor([1, 2, 3])) 92 | self.assert_results(test_enumerate_7, ty) 93 | # TODO(zmh): support -1 94 | 95 | with strict_mode_guard(0): 96 | self.assert_results(test_enumerate_8, ty) 97 | 98 | self.assert_results(test_enumerate_10, layer_list, paddle.randn((10,))) 99 | 100 | 101 | if __name__ == "__main__": 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /docs/design/stringify-guard.md: -------------------------------------------------------------------------------- 1 | # 字符串化 Guard 2 | 3 | ## Guard 原设计及其遇到的问题 4 | 5 | 我们的原来的 Guard 设计是,每个子 Guard 都是一个 lambda 函数,签名如下: 6 | 7 | ```python 8 | Guard = Callable[[types.FrameType], bool] 9 | ``` 10 | 11 | 在 Guard 收集时,我们可以通过 `compose_guards` 来整合成一个总 Guard,这个总 Guard 也是一个 lambda 函数,它的实现非常简单,就是多个 Guard 的 `and` 串联: 12 | 13 | ```python 14 | def compose_guards(guards: list[Guard]) -> Guard: 15 | def composed_guard_fn(frame: types.FrameType) -> bool: 16 | ret = True 17 | for guard in guards: 18 | ret = ret and guard(frame) 19 | return ret 20 | 21 | return composed_guard_fn 22 | ``` 23 | 24 | 这个设计在正确性上没有太大问题,但是经测试发现,该设计会造成非常大的性能开销。我们分别测试了将总 Guard 直接设置为 `lambda _: True` 和每个 子 Guard 
设置为 `lambda _: True`,后者比前者仅仅多了函数调用逻辑,但我们发现后者仍然会比前者多出不少的性能开销,这很好理解,因为每个子 Guard 都需要一次函数调用,而 Python 中函数调用的开销相对较大。 25 | 26 | ## 字符串化 Guard 的设计 27 | 28 | 为了避免 Guard 中的函数调用性能开销,我们将每个子 Guard 表示为字符串,在最后汇总的时候使用 `eval` 来生成一个 lambda 函数,这样最后的总 Guard 只是一个函数调用,没有了多个函数调用的开销。 29 | 30 | 为了能够让每个 Guard 都字符串化,Guard 中所使用的 `trace_value_from_frame` 也需要字符串化。 31 | 32 | 另外,对于 lambda 函数来说,可利用闭包来捕获自由变量,但字符串是没有这一能力的,为了能够让字符串化的 Guard 也能够捕获自由变量,我们需要在字符串化的同时,将自由变量一并保存下来,因此子 Guard 和 `trace_value_from_frame` 不仅会返回字符串,还会返回自由变量 dict。为了方便管理,将该数据结构命名为 `StringifyExpression`: 33 | 34 | ```python 35 | @dataclass 36 | class StringifyExpression: 37 | expr: str 38 | free_vars: dict[str, Any] 39 | ``` 40 | 41 | 以编写 BuiltinTracker 为例,改动前后对比如下: 42 | 43 | ```diff 44 | class BuiltinTracker: 45 | def trace_value_from_frame(self): 46 | - return lambda frame: builtins.__dict__[self.name] 47 | + return StringifyExpression( 48 | + f"builtins.__dict__['{self.name}']", {"builtins": builtins} 49 | + ) 50 | ``` 51 | 52 | 这里 `builtins` 是一个本应通过闭包捕获的自由变量,在字符串化后,通过 `free_vars` 字段来保存。注意 `self.name` 必须以带引号的字符串字面量形式拼入表达式,否则 `eval` 时它会被当作变量名解析,而不是作为字典的字符串键。 53 | 54 | Variable 和 Tracker 的相关函数变为如下形式: 55 | 56 | ```python 57 | class VariableBase: 58 | def make_stringify_guard(self) -> StringifyExpression: 59 | ... 60 | 61 | class Tracker: 62 | def trace_value_from_frame(self) -> StringifyExpression: 63 | ... 64 | ``` 65 | 66 | 最终合并后的总 Guard 签名不变,仍然是 `Guard`: 67 | 68 | ```python 69 | Guard = Callable[[types.FrameType], bool] 70 | 71 | def make_guard(stringify_guards: list[StringifyExpression]) -> Guard: 72 | num_guards = len(stringify_guards) 73 | if not num_guards: 74 | return lambda frame: True 75 | union_guard_expr = reduce(lambda x, y: x & y, stringify_guards) 76 | guard_string = f"lambda frame: {union_guard_expr.expr}" 77 | guard = eval( 78 | guard_string, 79 | union_guard_expr.free_vars, 80 | ) 81 | log(3, f"[Guard]: {guard_string}\n") 82 | assert callable(guard), "guard must be callable."
83 | 84 | return guard 85 | ``` 86 | 87 | 实现也很简单,就是字符串上的 `and` 拼接(`reduce` 函数里使用重载后的 `&`),之后 `eval` 并传入自由变量即可。 88 | 89 | ## 字符串化 Guard 书写的注意点 90 | 91 | 1. 应注意捕获自由变量,字符串无法自动捕获自由变量 92 | 2. 应注意字符串化前后的比较逻辑可能有所不同 93 | 3. 应注意尽可能将计算在「编译时」就计算好,编码在 Guard 字符串中,而不是传到运行时再进行计算,尽可能降低运行时开销 94 | -------------------------------------------------------------------------------- /tests/test_cost_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | 4 | from test_case_base import TestCaseBase, cost_model_guard 5 | 6 | import paddle 7 | from sot import psdb, symbolic_translate 8 | from sot.utils import StepInfoManager, StepState 9 | 10 | 11 | def dyn_fast(x, net, iter_): 12 | for i in iter_: 13 | x = net(x) 14 | return x 15 | 16 | 17 | def sot_fast_with_single_graph(x, net): 18 | if not psdb.in_sot(): 19 | time.sleep(0.1) 20 | return x + 1 21 | 22 | 23 | def sot_fast_with_multi_graph(x, net): 24 | if not psdb.in_sot(): 25 | time.sleep(0.1) 26 | x = x + 1 27 | psdb.breakgraph() 28 | x = x + 2 29 | return x 30 | 31 | 32 | class Net(paddle.nn.Layer): 33 | def __init__(self): 34 | super().__init__() 35 | self.linear = paddle.nn.Linear(10, 10) 36 | 37 | def forward(self, x): 38 | if not psdb.in_sot(): 39 | time.sleep(0.1) 40 | x = x / 3 41 | x = x + 5 42 | x = self.linear(x) 43 | return x 44 | 45 | 46 | class TestCostModel(TestCaseBase): 47 | @cost_model_guard("True") 48 | def test_dyn_fast(self): 49 | x = paddle.rand([10]) 50 | net = paddle.nn.Linear(10, 10) 51 | sot_fn = symbolic_translate(dyn_fast) 52 | for i in range(60): 53 | sot_fn(x, net, iter(range(10))) 54 | 55 | state = StepInfoManager().step_record[dyn_fast.__code__].state 56 | assert state == StepState.RUN_DYN 57 | 58 | @cost_model_guard("True") 59 | def test_sot_fast_with_multi_graph(self): 60 | x = paddle.rand([10]) 61 | net = paddle.nn.Linear(10, 10) 62 | sot_fn = symbolic_translate(sot_fast_with_multi_graph) 63 | for i in range(30): 64 | sot_fn(x, 
net) 65 | 66 | state = ( 67 | StepInfoManager() 68 | .step_record[sot_fast_with_multi_graph.__code__] 69 | .state 70 | ) 71 | assert state == StepState.RUN_SOT 72 | 73 | @cost_model_guard("True") 74 | def test_sot_fast_with_single_graph(self): 75 | x = paddle.rand([10]) 76 | net = paddle.nn.Linear(10, 10) 77 | for i in range(30): 78 | symbolic_translate(sot_fast_with_single_graph)(x, net) 79 | 80 | state = ( 81 | StepInfoManager() 82 | .step_record[sot_fast_with_single_graph.__code__] 83 | .state 84 | ) 85 | assert state == StepState.RUN_SOT 86 | 87 | @cost_model_guard("True") 88 | def test_net(self): 89 | x = paddle.rand([10]) 90 | net = Net() 91 | net = paddle.jit.to_static(net, enable_fallback=True) 92 | for i in range(30): 93 | x = net(x) 94 | 95 | state = StepInfoManager().step_record[Net.forward.__code__].state 96 | assert state == StepState.RUN_SOT 97 | 98 | 99 | if __name__ == "__main__": 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /tests/test_map.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | from typing import Iterable 5 | 6 | from test_case_base import TestCaseBase, strict_mode_guard 7 | 8 | import sot 9 | from sot.psdb import check_no_breakgraph 10 | 11 | 12 | def double_num(num: float | int): 13 | return num * 2 14 | 15 | 16 | def double_num_with_breakgraph(num: float | int): 17 | sot.psdb.breakgraph() 18 | return num * 2 19 | 20 | 21 | @check_no_breakgraph 22 | def test_map_list(x: list): 23 | return list(map(double_num, x)) 24 | 25 | 26 | @check_no_breakgraph 27 | def test_map_list_comprehension(x: list): 28 | return [i for i in map(double_num, x)] # noqa: C416 29 | 30 | 31 | @check_no_breakgraph 32 | def test_map_tuple(x: tuple): 33 | return tuple(map(double_num, x)) 34 | 35 | 36 | @check_no_breakgraph 37 | def test_map_tuple_comprehension(x: tuple): 38 | return [i for i in 
map(double_num, x)] # noqa: C416 39 | 40 | 41 | @check_no_breakgraph 42 | def test_map_range(x: Iterable): 43 | return list(map(double_num, x)) 44 | 45 | 46 | @check_no_breakgraph 47 | def test_map_range_comprehension(x: Iterable): 48 | return [i for i in map(double_num, x)] # noqa: C416 49 | 50 | 51 | def add_dict_prefix(key: str): 52 | return f"dict_{key}" 53 | 54 | 55 | @check_no_breakgraph 56 | def test_map_dict(x: dict): 57 | return list(map(add_dict_prefix, x)) 58 | 59 | 60 | @check_no_breakgraph 61 | def test_map_dict_comprehension(x: dict): 62 | return [i for i in map(add_dict_prefix, x)] # noqa: C416 63 | 64 | 65 | def test_map_list_with_breakgraph(x: list): 66 | return list(map(double_num_with_breakgraph, x)) 67 | 68 | 69 | @check_no_breakgraph 70 | def test_map_unpack(x: list): 71 | a, b, c, d = map(double_num, x) 72 | return a, b, c, d 73 | 74 | 75 | @check_no_breakgraph 76 | def test_map_for_loop(x: list): 77 | res = 0 78 | for i in map(double_num, x): 79 | res += i 80 | return res 81 | 82 | 83 | class TestMap(TestCaseBase): 84 | def test_map(self): 85 | self.assert_results(test_map_list, [1, 2, 3, 4]) 86 | self.assert_results(test_map_tuple, (1, 2, 3, 4)) 87 | self.assert_results(test_map_range, range(5)) 88 | self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3}) 89 | 90 | def test_map_comprehension(self): 91 | self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) 92 | self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) 93 | self.assert_results(test_map_range_comprehension, range(5)) 94 | self.assert_results( 95 | test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} 96 | ) 97 | 98 | def test_map_with_breakgraph(self): 99 | with strict_mode_guard(0): 100 | self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) 101 | 102 | def test_map_unpack(self): 103 | self.assert_results(test_map_unpack, [1, 2, 3, 4]) 104 | 105 | def test_map_for_loop(self): 106 | self.assert_results(test_map_for_loop, [7, 8, 9, 10]) 107 | 108 
| 109 | if __name__ == "__main__": 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /docs/compat/python311/CALL.md: -------------------------------------------------------------------------------- 1 | # CALL 相关字节码适配 2 | 3 | ## CALL 相关字节码 4 | 5 | 函数调用主要涉及 LOAD 和 CALL 两类字节码,在 Python 3.10 及之前(以下简称 Python 3.10)和 Python 3.11 Python 生成的字节码发生了变化 6 | 7 | 在 Python 3.10,对于 function call 和 method call 会生成两种不同的 LOAD + CALL 字节码,而在 Python 3.11 将 CALL 进行了统一,具体如下 8 | 9 | | code | `b(1)` | `a.b(1)` | 10 | |-|-|-| 11 | | 3.10 | `LOAD_GLOBAL`
<br>-<br>`CALL_FUNCTION` | `LOAD_METHOD`<br>-<br>`CALL_METHOD` | 12 | | 3.11 | `PUSH_NULL`<br>`LOAD_GLOBAL`<br>-<br>`PRECALL`<br>`CALL` | `LOAD_METHOD`<br>-<br>`PRECALL`<br>`CALL` | 13 | 14 | > **Note** 15 | > 16 | > - function call 指 `b(1)` 这种形式,method call 指 `a.b(1)` 这种形式,注意后者虽然是 method call,但 `a.b` 不一定是 method,也可能只是普通的 function,比如 `paddle.abs`,在编译时我们无法知道它具体的类型,只是从语法结构上我们会认为其是 method call,要注意语法形式上的 function call 和 method call 以及运行时 function 和 method 的区别,后者区别见 [函数和方法](../../notes/function-and-method.md) 17 | > - 在 Python 3.11 下实际使用 dis 查看 `b(1)` 的字节码,会发现字节码是 `LOAD_GLOBAL 1 (NULL + b)`,而其实际上只是 `PUSH_NULL` + `LOAD_GLOBAL` 字节码序列经过一个 pass 优化后的结果(见 [cpython 3.11 compile.c - optimize_basic_block](https://github.com/python/cpython/blob/3.11/Python/compile.c#L9034-L9040)),实际上等价于 `PUSH_NULL` + `LOAD_GLOBAL` 18 | > - 注意 `LOAD_GLOBAL` 只是其中一种 LOAD 指令而已,实际上该处可能是 `LOAD_FAST` 等指令 19 | 20 | ## Python 3.10 相关字节码的行为 21 | 22 | 在 Python 3.10,CALL 有两种,一种是 `CALL_FUNCTION`,简单来说就是把栈上的函数取出来直接 CALL,其往往会搭配 `LOAD_GLOBAL` 等 LOAD 指令 23 | 24 | 另一种是 `CALL_METHOD`,其往往会搭配 `LOAD_METHOD`,因为在运行时才能知道它具体是 function 还是 method,因此在执行 `LOAD_METHOD` 时会根据情况来判断具体向栈上放什么元素,相关源码见 [cpython 3.10 ceval.c - LOAD_METHOD](https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4122-L4157),具体如下: 25 | 26 | - 如果是 method,那么向栈上放 27 | 28 | ``` 29 | meth | self | arg1 | ... | argN 30 | ``` 31 | 32 | - 如果是 function,那么向栈上放 33 | 34 | ``` 35 | NULL | meth | arg1 | ... | argN 36 | ``` 37 | 38 | 两者的栈布局是完全不同的 39 | 40 | `CALL_METHOD` 时,则会根据栈的布局来判断这是一个 function 还是 method,相关源码见 [cpython 3.10 ceval.c - CALL_METHOD](https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4159-L4207) 41 | 42 | 其实就是看 `-oparg-2` 位置是不是 `NULL` 而已,如果是就认为其是一个 function,否则认为其是一个 method 43 | 44 | ## Python 3.11 相关字节码的行为 45 | 46 | Python 3.10 为两种语法形式生成了不同的 CALL 字节码,Python 3.11 则是将两者进行了统一,统一生成字节码 `PRECALL` + `CALL`,其实就是将 `CALL_METHOD` 拆成两部分,`PRECALL` 用于根据栈的布局来判断是 function 还是 method,如果是 function 布局,但其实际上是一个 method,就将其调整成 method 布局,之后 `CALL` 会进行函数调用,具体代码见 [cpython 3.11 ceval.c - PRECALL](https://github.com/python/cpython/blob/3.11/Python/ceval.c#L4657-L4701) 47 | 48 | 那么「如果是 function 布局,但其实际上是一个 method,就将其调整成 method 布局」是指什么呢?
49 | 50 | 对于 `LOAD_GLOBAL` + `PRECALL` + `CALL` 的 function call 布局,`LOAD_GLOBAL` 可能 `LOAD` 任何对象,当然可能其本身就已经是一个 method 了,比如函数作为一个参数传入 51 | 52 | ```python 53 | def foo(method, x): 54 | method(x) 55 | ``` 56 | 57 | 这里 method 是通过 `LOAD_FAST`(和 `LOAD_GLOBAL` 是同一类)LOAD 到栈上的,在 CALL 的时候其栈布局必然是 function call 的布局,但其实际上是一个 method,在这种情况下 `PRECALL` 便会调整其布局,将其变为 method call 布局 58 | 59 | 在 `PRECALL` 之后,通过栈的布局是否是 method call 布局就可以完全确定调用对象是否是 method 了,`CALL` 时对 method 对象的处理是统一的,即 `A.b(a, *args)`,在 Python 3.10 之前,通过 method call 形式 LOAD 到栈上的 method 同样是 `A.b(a, *args)` 调用的,而通过 function call 形式 LOAD 到栈上的 method(`a.b`)则是直接 `a.b(*args)` 调用的 60 | 61 | 不过 `LOAD_GLOBAL` 和 `LOAD_METHOD` 在处理 function 时是有一点差别的,就是 `LOAD_METHOD` 在处理 function 时会先 push 一个 `NULL` 到栈上,为了能够完全统一两者,在遇到 function call 形式时,编译时在生成 `LOAD_GLOBAL` 之前会先插入一条 `PUSH_NULL`,这样两者就一致了~ 62 | -------------------------------------------------------------------------------- /tests/test_11_jumps.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | from sot.psdb import check_no_breakgraph 9 | 10 | 11 | @check_no_breakgraph 12 | def pop_jump_if_false(x: bool, y: paddle.Tensor): 13 | if x: 14 | y += 1 15 | else: 16 | y -= 1 17 | return y 18 | 19 | 20 | @check_no_breakgraph 21 | def pop_jump_if_true(x: bool, y: bool, z: paddle.Tensor): 22 | return (x or y) and z 23 | 24 | 25 | @check_no_breakgraph 26 | def jump_if_false_or_pop(x: bool, y: paddle.Tensor): 27 | return x and (y + 1) 28 | 29 | 30 | @check_no_breakgraph 31 | def jump_if_true_or_pop(x: bool, y: paddle.Tensor): 32 | return x or (y + 1) 33 | 34 | 35 | @check_no_breakgraph 36 | def jump_absolute(x: int, y: paddle.Tensor): 37 | while x > 0: 38 | y += 1 39 | x -= 1 40 | return y 41 | 42 | 43 | @check_no_breakgraph 44 | def pop_jump_if_none(x: bool, y: paddle.Tensor): 45 | if x is not None: 46 | y += 1 47 | else: 48 | y -= 
1 49 | return y 50 | 51 | 52 | @check_no_breakgraph 53 | def pop_jump_if_not_none(x: bool, y: paddle.Tensor): 54 | if x is None: 55 | y += 1 56 | else: 57 | y -= 1 58 | return y 59 | 60 | 61 | a = paddle.to_tensor(1) 62 | b = paddle.to_tensor(2) 63 | c = paddle.to_tensor(3) 64 | d = paddle.to_tensor(4) 65 | 66 | true_tensor = paddle.to_tensor(True) 67 | false_tensor = paddle.to_tensor(False) 68 | 69 | 70 | class TestExecutor(TestCaseBase): 71 | def test_simple(self): 72 | self.assert_results(jump_absolute, 5, a) 73 | 74 | self.assert_results(pop_jump_if_false, True, a) 75 | self.assert_results(pop_jump_if_false, False, a) 76 | self.assert_results(jump_if_false_or_pop, True, a) 77 | self.assert_results(jump_if_false_or_pop, False, a) 78 | self.assert_results(jump_if_true_or_pop, True, a) 79 | self.assert_results(jump_if_true_or_pop, False, a) 80 | self.assert_results(pop_jump_if_true, True, False, a) 81 | self.assert_results(pop_jump_if_true, False, False, a) 82 | 83 | self.assert_results(pop_jump_if_none, None, a) 84 | self.assert_results(pop_jump_if_none, True, a) 85 | self.assert_results(pop_jump_if_not_none, None, a) 86 | self.assert_results(pop_jump_if_not_none, True, a) 87 | 88 | def test_breakgraph(self): 89 | self.assert_results(pop_jump_if_false, true_tensor, a) 90 | self.assert_results(jump_if_false_or_pop, true_tensor, a) 91 | self.assert_results(jump_if_true_or_pop, false_tensor, a) 92 | self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) 93 | self.assert_results(jump_absolute, 5, a) 94 | self.assert_results(pop_jump_if_false, false_tensor, a) 95 | self.assert_results(jump_if_false_or_pop, false_tensor, a) 96 | self.assert_results(jump_if_true_or_pop, false_tensor, a) 97 | self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) 98 | 99 | self.assert_results(pop_jump_if_none, true_tensor, a) 100 | self.assert_results(pop_jump_if_not_none, true_tensor, a) 101 | 102 | 103 | if __name__ == "__main__": 104 | unittest.main() 105 | 
-------------------------------------------------------------------------------- /tests/test_06_call_function.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def add(x, y): 9 | return x + y 10 | 11 | 12 | def sub(x, y): 13 | return x - y 14 | 15 | 16 | def foo_1(x: paddle.Tensor): 17 | m = x + 1 18 | y = add(m * 3, m * 2) 19 | return y 20 | 21 | 22 | def foo_2(x: paddle.Tensor): 23 | m = x + 1 24 | y = sub(m * 3, m * 2) 25 | return y 26 | 27 | 28 | def foo_3(x: paddle.Tensor): 29 | m = x + 1 30 | y = sub(m * 3, m * 2) 31 | y = sub(y, y) 32 | y = sub(y, y) 33 | return y 34 | 35 | 36 | def nest_2(x): 37 | return x + 1 38 | 39 | 40 | def nest_1(x): 41 | return (x - 1) * 2 42 | 43 | 44 | def foo_4(x: paddle.Tensor): 45 | m = x + 1 46 | m = nest_1(m) 47 | return m 48 | 49 | 50 | def fn_with_varargs_and_kwargs(x, *args, **kwargs): 51 | return ( 52 | x 53 | + args[0] 54 | + args[1] 55 | - args[2] 56 | + kwargs['a'] * kwargs['b'] / kwargs['c'] 57 | ) 58 | 59 | 60 | def foo_5(x: paddle.Tensor): 61 | m = x + 1 62 | m = fn_with_varargs_and_kwargs( 63 | m, x + 1, x + 2, x + 3, a=x + 4, b=x + 5, c=x + 6 64 | ) 65 | return m 66 | 67 | 68 | def fn_with_default_value(x, y=1, z=2): 69 | return x + y + z 70 | 71 | 72 | def foo_6(x: paddle.Tensor): 73 | m = x + 1 74 | m = fn_with_default_value(m, m + 10) 75 | m = fn_with_default_value(m + 42) 76 | return m 77 | 78 | 79 | def fn_with_default_value_and_varargs_kwargs(x, y=1, *args, **kwargs): 80 | return x + y + args[0] + kwargs['a'] 81 | 82 | 83 | def foo_7(x: paddle.Tensor): 84 | m = x + 1 85 | m = fn_with_default_value_and_varargs_kwargs(m, m + 1, m + 2, a=m + 3) 86 | return m 87 | 88 | 89 | def fn_with_default_value_and_varargs_kwargs_kwonly_1( 90 | x, y=1, *args, z, **kwargs 91 | ): 92 | return x + y + args[0] + kwargs['a'] + z 93 | 94 | 95 | def 
fn_with_default_value_and_varargs_kwargs_kwonly_2( 96 | x, y=1, *args, z=10, **kwargs 97 | ): 98 | return x + y + args[0] + kwargs['a'] + z 99 | 100 | 101 | def foo_8(x: paddle.Tensor): 102 | m = x + 1 103 | m = fn_with_default_value_and_varargs_kwargs_kwonly_1( 104 | m, m + 1, m + 2, a=m + 3, z=m + 4 105 | ) 106 | m = fn_with_default_value_and_varargs_kwargs_kwonly_2( 107 | m, m + 1, m + 2, a=m + 3 108 | ) 109 | return m 110 | 111 | 112 | class TestCall(TestCaseBase): 113 | def test_call1(self): 114 | self.assert_results(foo_1, paddle.to_tensor(2)) 115 | 116 | def test_call2(self): 117 | self.assert_results(foo_2, paddle.to_tensor(3)) 118 | 119 | def test_call3(self): 120 | self.assert_results(foo_3, paddle.to_tensor(4)) 121 | 122 | def test_call4(self): 123 | self.assert_results(foo_4, paddle.to_tensor(5)) 124 | 125 | def test_call5(self): 126 | self.assert_results(foo_5, paddle.to_tensor(6)) 127 | 128 | def test_call6(self): 129 | self.assert_results(foo_6, paddle.to_tensor(7)) 130 | 131 | def test_call7(self): 132 | self.assert_results(foo_7, paddle.to_tensor(8)) 133 | 134 | def test_call8(self): 135 | self.assert_results(foo_8, paddle.to_tensor(9)) 136 | 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/tracker_viewer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import queue 4 | from typing import TYPE_CHECKING 5 | 6 | from .tracker import DummyTracker 7 | from .variables import VariableBase 8 | 9 | SIR_GRAPH_CLUSTER_NAME = "cluster_sir_part" 10 | 11 | if TYPE_CHECKING: 12 | import graphviz 13 | 14 | 15 | def try_import_graphviz(): 16 | try: 17 | import graphviz 18 | 19 | return graphviz 20 | except ImportError: 21 | return None 22 | 23 | 24 | def draw_variable(graph: graphviz.Digraph, var: VariableBase): 25 | """ 26 | Draw and colour a node in 
the graph. 27 | 28 | Args: 29 | graph (graphviz.Digraph): The graph to draw the variable. 30 | var (VariableBase): The variable to draw. 31 | 32 | Returns: 33 | None 34 | """ 35 | # Draw Variable 36 | graph.attr('node', shape='oval', style="filled", fillcolor='aliceblue') 37 | graph.attr('edge', style='solid') 38 | graph.node(var.id, str(var)) 39 | 40 | # Draw Tracker 41 | tracker = var.tracker 42 | graph.attr('node', shape='rect', style='filled', fillcolor='beige') 43 | if isinstance(tracker, DummyTracker): 44 | graph.attr('edge', style='dashed') 45 | graph.attr('node', shape='rect', style='filled', fillcolor='goldenrod') 46 | graph.node(tracker.id, str(tracker)) 47 | 48 | # Draw edge (Tracker -> Variable) 49 | graph.edge(tracker.id, var.id) 50 | 51 | # Draw edge (Tracker inputs -> Tracker) 52 | graph.attr('node', shape='oval', style="filled", fillcolor='cadetblue') 53 | graph.attr('edge', style='solid') 54 | for input in tracker.inputs: 55 | graph.edge(input.id, tracker.id) 56 | 57 | 58 | def view_tracker( 59 | root_variables: list[VariableBase], filename: str, format: str 60 | ): 61 | """ 62 | Generates a graph visualization starting from the given root variables and save it to the given file. 63 | 64 | Args: 65 | root_variables (list[VariableBase]): The root variables to start the visualization from. 66 | filename (str): The name of the file used to save the results of the visualisation. 67 | format (str): The format (e.g., `pdf`, `png` and 'svg' etc.) of the file to save the visualization to. 68 | 69 | Returns: 70 | None 71 | """ 72 | # TODO(SigureMo): 73 | # 1. Colorize the trackers 74 | # 2. 
Highlight the user specific node, to speedup debug process 75 | graphviz = try_import_graphviz() 76 | if graphviz is None: 77 | print("Cannot import graphviz, please install it first.") 78 | return 79 | 80 | graph = graphviz.Digraph("graph", filename=filename, format=format) 81 | visited = set() 82 | var_queue = queue.Queue() 83 | for var in root_variables: 84 | var_queue.put(var) 85 | 86 | while not var_queue.empty(): 87 | var = var_queue.get() 88 | if var.id in visited: 89 | continue 90 | visited.add(var.id) 91 | if isinstance(var.tracker, DummyTracker): 92 | with graph.subgraph(name=SIR_GRAPH_CLUSTER_NAME) as sir_part: 93 | sir_part.attr(color='green') 94 | draw_variable(sir_part, var) 95 | else: 96 | draw_variable(graph, var) 97 | for input in var.tracker.inputs: 98 | if input not in var_queue.queue: 99 | var_queue.put(input) 100 | 101 | graph.render(view=False) 102 | -------------------------------------------------------------------------------- /docs/design/closure.md: -------------------------------------------------------------------------------- 1 | # Closure Implementation 2 | 3 | ## Closure Example in Python 4 | 5 | 以下是对一个闭包函数的处理的demo,以及它对应的字节码 : 6 | 7 | ```python 8 | import dis 9 | 10 | def func(): 11 | free_x = 1 12 | free_y = 2 13 | 14 | def local(y): 15 | return y + free_x + free_y 16 | return local(1) 17 | 18 | dis.dis(func) 19 | ``` 20 | 21 | ```text 22 | 4 0 LOAD_CONST 1 (1) 23 | 2 STORE_DEREF 0 (free_x) 24 | 25 | 5 4 LOAD_CONST 2 (2) 26 | 6 STORE_DEREF 1 (free_y) 27 | 28 | 7 8 LOAD_CLOSURE 0 (free_x) 29 | 10 LOAD_CLOSURE 1 (free_y) 30 | 12 BUILD_TUPLE 2 31 | 14 LOAD_CONST 3 () 32 | 16 LOAD_CONST 4 ('func..local') 33 | 18 MAKE_FUNCTION 8 (closure) 34 | 20 STORE_FAST 0 (local) 35 | 36 | 9 22 LOAD_FAST 0 (local) 37 | 24 LOAD_CONST 1 (1) 38 | 26 CALL_FUNCTION 1 39 | 28 RETURN_VALUE 40 | 41 | Disassembly of : 42 | 8 0 LOAD_FAST 0 (y) 43 | 2 LOAD_DEREF 0 (free_x) 44 | 4 BINARY_ADD 45 | 6 LOAD_DEREF 1 (free_y) 46 | 8 BINARY_ADD 47 | 10 RETURN_VALUE 48 
| 49 | ``` 50 | 51 | 上述字节码可以先进行猜测:一切闭包都是通过额外的字节码进行构建的。 52 | 53 | - STORE_DEREF : 将TOS存储到cell中 54 | 55 | - LOAD_CLOSURE:将Cell构建为闭包 56 | 57 | - LOAD_DEREF:将CELL中的值读取出来 58 | 59 | ## Closure Implementation Bytecode overview in Python 60 | 61 | 在Python3.8中,闭包是通过字节码实现的,与Closure相关的字节码有如下几个: 62 | 63 | | ByteCode | 功能 | 64 | | ------------- | ------- | 65 | | LOAD_CLOSURE | N | 66 | | LOAD_DEREF | N | 67 | | LOAD_CLASSDEREF | N | 68 | | STORE_DEREF | N | 69 | | DELETE_DEREF | N | 70 | 71 | 72 | ## Closure bytecode implementation in detail 73 | 74 | 75 | ```python 76 | case TARGET(LOAD_CLOSURE): { 77 | PyObject *cell = freevars[oparg]; 78 | Py_INCREF(cell); 79 | PUSH(cell); 80 | DISPATCH(); 81 | } 82 | ``` 83 | 84 | 85 | ```python 86 | case TARGET(LOAD_DEREF): { 87 | PyObject *cell = freevars[oparg]; 88 | PyObject *value = PyCell_GET(cell); 89 | if (value == NULL) { 90 | format_exc_unbound(tstate, co, oparg); 91 | goto error; 92 | } 93 | Py_INCREF(value); 94 | PUSH(value); 95 | DISPATCH(); 96 | } 97 | ``` 98 | 99 | ```python 100 | case TARGET(STORE_DEREF): { 101 | PyObject *v = POP(); 102 | PyObject *cell = freevars[oparg]; 103 | PyObject *oldobj = PyCell_GET(cell); 104 | PyCell_SET(cell, v); 105 | Py_XDECREF(oldobj); 106 | DISPATCH(); 107 | } 108 | ``` 109 | 110 | ## Conclusion:the implementation of python in detail。 111 | 112 | 对闭包进行总结,弄明白他的整体的工作方式。 113 | 114 | 首先是外层函数给inner函数准备闭包环境:CELLs,外层函数遇到 STORE_DEREF 指令就会将栈顶元素压入到freevars准备好的CELLS中。然后每个CELL中可以存储一个python obj,存储完毕之后会对old存储的对象进行减引用。 115 | 116 | > Notes: freevars 的定义如下,指代 `freevars = f->f_localsplus + co->co_nlocals;` 117 | -------------------------------------------------------------------------------- /sot/opcode_translator/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dis 4 | import sys 5 | from functools import partial 6 | 7 | from ..profiler import EventGuard 8 | from ..utils import CodeStatus, log, log_do 9 | from 
.custom_code import CustomCode 10 | from .executor.executor_cache import OpcodeExecutorCache 11 | from .skip_files import need_skip 12 | 13 | 14 | def print_locals(frame): 15 | local_key = [ 16 | key for key in frame.f_locals.keys() if not key.startswith("__") 17 | ] 18 | print( 19 | f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key}" 20 | ) 21 | print( 22 | f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars}" 23 | ) 24 | 25 | def convert_obj(obj): 26 | import paddle 27 | 28 | if isinstance(obj, paddle.Tensor): 29 | return "Tensor(" + str(obj.shape) + ")" 30 | if isinstance(obj, list): 31 | return [convert_obj(i) for i in obj] 32 | return obj 33 | 34 | for key in local_key: 35 | print( 36 | f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} {key} = {convert_obj(frame.f_locals[key])}" 37 | ) 38 | 39 | 40 | def eval_frame_callback(frame, **kwargs) -> CustomCode: 41 | with EventGuard( 42 | f"eval_frame_callback: {frame.f_code.co_name}", event_level=2 43 | ): 44 | # is generator 45 | if frame.f_code.co_flags & 0x20 > 0: 46 | return CustomCode(None, True) 47 | 48 | # NOTE(SigureMo): Temporary fallback when code has exception handling. 
49 | if sys.version_info >= (3, 11) and frame.f_code.co_exceptiontable: 50 | log( 51 | 3, 52 | f"[eval_frame_callback] {frame.f_code} has co_exceptiontable\n", 53 | ) 54 | return CustomCode(None, False) 55 | 56 | if need_skip(frame): 57 | log(3, f"[eval_frame_callback] skip {frame.f_code}\n") 58 | custom_code = CustomCode(None, False) 59 | new_code = frame.f_code 60 | else: 61 | log( 62 | 2, f"[eval_frame_callback] start to translate: {frame.f_code}\n" 63 | ) 64 | log_do(4, partial(print_locals, frame)) 65 | 66 | log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") 67 | log_do(3, lambda: dis.dis(frame.f_code)) 68 | 69 | custom_code = OpcodeExecutorCache()(frame, **kwargs) 70 | 71 | if custom_code.code is None: 72 | log( 73 | 3, 74 | "[transform] NewCode (same as origin code): " 75 | + frame.f_code.co_name 76 | + "\n", 77 | ) 78 | new_code = frame.f_code 79 | else: 80 | log( 81 | 3, 82 | "[transform] NewCode: " + custom_code.code.co_name + "\n", 83 | ) 84 | log_do(3, lambda: dis.dis(custom_code.code)) 85 | new_code = custom_code.code 86 | 87 | # just check those codes which need open eval_frame 88 | if ( 89 | custom_code.disable_eval_frame is False 90 | and CodeStatus().is_code_without_graph(new_code) 91 | ): 92 | log( 93 | 3, 94 | "[eval_frame_callback] Code has no graph, block it.\n", 95 | ) 96 | return CustomCode(None, True) 97 | 98 | return custom_code 99 | -------------------------------------------------------------------------------- /tests/test_15_slice.py: -------------------------------------------------------------------------------- 1 | # BUILD_SLICE (new) 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | from sot.psdb import check_no_breakgraph 11 | 12 | 13 | def build_list_slice(x: list, y: paddle.Tensor): 14 | x[2:4] = [0, 1] 15 | return x[0] + y 16 | 17 | 18 | def build_list_slice_with_step(x: list, y: paddle.Tensor): 19 | x[1:5:2] = [0, 1] 20 | return 
x[0] + y 21 | 22 | 23 | def build_tuple_slice(x: list, y: paddle.Tensor): 24 | x[2:4] = (0, 1) 25 | return x[0] + y 26 | 27 | 28 | def build_tuple_slice_with_step(x: list, y: paddle.Tensor): 29 | x[1:5:2] = (0, 1) 30 | return x[0] + y 31 | 32 | 33 | def tensor_subscript_ellipsis(x: paddle.Tensor, y: paddle.Tensor): 34 | return x[...] + y[...] 35 | 36 | 37 | @check_no_breakgraph 38 | def tensor_subscript_tensor(x: paddle.Tensor): 39 | d0, d1 = paddle.shape(x) 40 | return x[: d0 // 2, d1 // 2 : d1] 41 | 42 | 43 | class TestSlice(TestCaseBase): 44 | def test_simple(self): 45 | x = list(range(10)) 46 | y = paddle.arange(10) 47 | self.assert_results_with_side_effects(build_list_slice, x, y) 48 | self.assert_results_with_side_effects(build_list_slice_with_step, x, y) 49 | self.assert_results_with_side_effects(build_tuple_slice, x, y) 50 | self.assert_results_with_side_effects(build_tuple_slice_with_step, x, y) 51 | 52 | 53 | class MyLayer(paddle.nn.Layer): 54 | def __init__(self): 55 | super().__init__() 56 | self.linears = paddle.nn.LayerList( 57 | [paddle.nn.Linear(10, 10) for i in range(10)] 58 | ) 59 | 60 | def forward(self, x): 61 | for i, l in enumerate(self.linears): 62 | x = self.linears[i // 2](x) + l(x) 63 | return x 64 | 65 | 66 | def layer_list_slice(layer, x): 67 | out = layer(x) 68 | return out 69 | 70 | 71 | class TestLayerList(TestCaseBase): 72 | def test_layer_list_slice(self): 73 | layer = MyLayer() 74 | x = paddle.randn([5, 10]) 75 | self.assert_results(layer_list_slice, layer, x) 76 | 77 | 78 | def tensor_slice(x: paddle.Tensor): 79 | return x[1, 1, 1] + 1 80 | 81 | 82 | class TestTensorSlice(TestCaseBase): 83 | def test_tensor_slice(self): 84 | x = paddle.randn([4, 3, 10]) 85 | self.assert_results(tensor_slice, x) 86 | 87 | 88 | class TestTensorEllipsis(TestCaseBase): 89 | def test_tensor_subscript_ellipsis(self): 90 | x = paddle.rand((10,)) 91 | y = paddle.rand((10, 10)) 92 | self.assert_results(tensor_subscript_ellipsis, x, y) 93 | 94 | 95 | class 
TestTensorSubscriptTensor(TestCaseBase): 96 | def test_tensor_subscript_tensor(self): 97 | x = paddle.rand((10, 10)) 98 | self.assert_results(tensor_subscript_tensor, x) 99 | 100 | 101 | class LayerListNet(paddle.nn.Layer): 102 | def __init__(self) -> None: 103 | super().__init__() 104 | self.layer_list = paddle.nn.LayerList( 105 | [paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)] 106 | ) 107 | 108 | def forward(self, x): 109 | out = self.layer_list[0](x) 110 | for layer in self.layer_list[1:]: 111 | out = layer(out) 112 | return out 113 | 114 | 115 | class TestLayerListSlice(TestCaseBase): 116 | def test_layer_list_slice(self): 117 | x = paddle.randn([2, 5]) 118 | net = LayerListNet() 119 | self.assert_results(layer_list_slice, net, x) 120 | 121 | 122 | if __name__ == "__main__": 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /tests/test_inplace_api.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from sot import symbolic_translate 7 | 8 | 9 | def simple(x, y): 10 | x[0] = 3.0 11 | z = [y] 12 | y[1] = 5.0 13 | return x[0] + x[1] + z[0][1] + y[0] + y[1] 14 | 15 | 16 | def inplace_in_if(x, y, z): 17 | if z: 18 | x[0] = 3.0 19 | z = [y] 20 | y[1] = 5.0 21 | ret = x[0] + x[1] + z[0][1] + y[0] + y[1] 22 | return ret 23 | else: 24 | return None 25 | 26 | 27 | def inplace_in_if_fallback(x, y, z): 28 | if z > 0: 29 | x[0] = 3.0 30 | z = [y] 31 | y[1] = 5.0 32 | ret = x[0] + x[1] + z[0][1] + y[0] + y[1] 33 | return ret 34 | else: 35 | return None 36 | 37 | 38 | def inplace_in_loop(x, y): 39 | ret = 0 40 | for i in range(10): 41 | x[0] = 1 42 | z = [y] 43 | y[1] = 2 * i + 1 44 | ret += x[0] + x[1] + z[0][1] + y[0] + y[1] 45 | return ret 46 | 47 | 48 | def inplace_in_loop_fallback(x, y, it): 49 | ret = 0 50 | for i in it: 51 | x[0] = 1 52 | z = [y] 53 | y[1] = 2 * i + 1 54 | ret += x[0] 
+ x[1] + z[0][1] + y[0] + y[1] 55 | return ret 56 | 57 | 58 | def inplace_case_0(x): 59 | x[:] = 1.0 60 | return x 61 | 62 | 63 | def inplace_case_1(x): 64 | x[0][0, 0::2] = 1.0 65 | return x 66 | 67 | 68 | def inplace_case_2(x): 69 | t = x[0] 70 | t[:, 0::2] = t[:, 0::2] * 0 71 | t[:, 1::2] = t[:, 1::2] + 2 72 | return x 73 | 74 | 75 | class TestExecutor(TestCaseBase): 76 | def test_case(self): 77 | self.assert_results(inplace_case_0, paddle.randn((1, 4))) 78 | self.assert_results(inplace_case_1, [paddle.randn((1, 4))]) 79 | self.assert_results(inplace_case_2, [paddle.randn((1, 4))]) 80 | 81 | def test_backward(self): 82 | @symbolic_translate 83 | def func(x): 84 | m = x * 2 85 | n = x * 3 86 | y = m 87 | y[:] = n 88 | return y 89 | 90 | x = paddle.ones((1, 4)) * 4 91 | x.stop_gradient = False 92 | y = func(x) 93 | y.sum().backward() 94 | assert (x.grad.numpy() == 3).all() 95 | 96 | def test_simple(self): 97 | self.assert_results( 98 | simple, paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]) 99 | ) 100 | 101 | def test_if(self): 102 | self.assert_results( 103 | inplace_in_if, 104 | paddle.to_tensor([1.0, 2.0]), 105 | paddle.to_tensor([3.0, 4.0]), 106 | True, 107 | ) 108 | self.assert_results( 109 | inplace_in_if_fallback, 110 | paddle.to_tensor([1.0, 2.0]), 111 | paddle.to_tensor([3.0, 4.0]), 112 | paddle.to_tensor(1), 113 | ) 114 | 115 | def test_loop(self): 116 | self.assert_results( 117 | inplace_in_loop, 118 | paddle.to_tensor([1.0, 2.0]), 119 | paddle.to_tensor([3.0, 4.0]), 120 | ) 121 | 122 | a = range(10) 123 | sym_output = symbolic_translate(inplace_in_loop_fallback)( 124 | paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) 125 | ) 126 | paddle_output = inplace_in_loop_fallback( 127 | paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) 128 | ) 129 | self.assert_nest_match(sym_output, paddle_output) 130 | 131 | 132 | if __name__ == "__main__": 133 | unittest.main() 134 | 
-------------------------------------------------------------------------------- /tests/test_break_graph.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from test_case_base import TestCaseBase 5 | 6 | import paddle 7 | from sot.utils.paddle_api_config import add_break_graph_apis 8 | 9 | 10 | def ifelse_func(x, y): 11 | if x > 0: 12 | y = y + 1 13 | else: 14 | y = y + 2 15 | return y 16 | 17 | 18 | class TestIfElse(TestCaseBase): 19 | def test_simple(self): 20 | x = paddle.to_tensor([1.0]) 21 | y = paddle.to_tensor([2.0]) 22 | self.assert_results(ifelse_func, x, y) 23 | 24 | 25 | def multi_output(x: paddle.Tensor): 26 | m = x + 1 27 | if x > 0: 28 | return m 29 | else: 30 | return 2 * m 31 | 32 | 33 | class TestExecutor(TestCaseBase): 34 | def test_simple(self): 35 | x = paddle.to_tensor(2) 36 | self.assert_results(multi_output, x) 37 | x = paddle.to_tensor(-2) 38 | self.assert_results(multi_output, x) 39 | 40 | 41 | def print_break_graph(x, y): 42 | z = x + y 43 | print(x, z) 44 | out = y * z * 2 45 | return out 46 | 47 | 48 | class TestPrint(TestCaseBase): 49 | def test_simple(self): 50 | x = paddle.to_tensor(2) 51 | y = paddle.to_tensor(3) 52 | self.assert_results(print_break_graph, x, y) 53 | 54 | 55 | def to_tensor_break_graph(x, y): 56 | z = x + y 57 | out = y * paddle.to_tensor(2) * z 58 | return out 59 | 60 | 61 | class TestToTensor(TestCaseBase): 62 | def test_simple(self): 63 | add_break_graph_apis([paddle.to_tensor]) 64 | x = paddle.to_tensor(2) 65 | y = paddle.to_tensor(3) 66 | self.assert_results(to_tensor_break_graph, x, y) 67 | 68 | 69 | def tensor_clear_gradient(x): 70 | x = paddle.to_tensor(x) 71 | x.clear_gradient() 72 | return x 73 | 74 | 75 | class TestBreakGraphInResumeFn(TestCaseBase): 76 | def test_simple(self): 77 | x = paddle.to_tensor(2) 78 | self.assert_results(tensor_clear_gradient, x) 79 | 80 | 81 | def inner_fn(a, b, c, d): 82 | return a + b * c - d 
83 | 84 | 85 | def multi_stack_args(a, b, c): 86 | out = inner_fn(a, b, c, paddle.to_tensor(4)) 87 | return out 88 | 89 | 90 | class TestMultiStackArgs(TestCaseBase): 91 | def test_simple(self): 92 | a = paddle.to_tensor(1) 93 | b = paddle.to_tensor(2) 94 | c = paddle.to_tensor(3) 95 | self.assert_results(multi_stack_args, a, b, c) 96 | 97 | 98 | def break_graph_in_call_method(x): 99 | out = paddle.nn.functional.relu(paddle.to_tensor([4.0])) 100 | return x + out 101 | 102 | 103 | def numpy_break_graph(): 104 | a = paddle.to_tensor([1, 2]) 105 | b = np.sum(a.numpy()) 106 | print(b) 107 | return b 108 | 109 | 110 | class TestBreakGraphInCallMethod(TestCaseBase): 111 | def test_simple(self): 112 | x = paddle.to_tensor([1.0]) 113 | break_graph_in_call_method(x) 114 | x = paddle.to_tensor([2.0]) 115 | break_graph_in_call_method(x) 116 | 117 | x = paddle.to_tensor([3.0]) 118 | self.assert_results(break_graph_in_call_method, x) 119 | 120 | def test_numpy(self): 121 | self.assert_results(numpy_break_graph) 122 | 123 | 124 | def test_break_graph_repeat(x): 125 | out = paddle.to_tensor( 126 | paddle.to_tensor(paddle.to_tensor(paddle.to_tensor([1.0]))) 127 | ) 128 | return x + out 129 | 130 | 131 | class TestBreakGraphRepeat(TestCaseBase): 132 | def test_simple(self): 133 | x = paddle.to_tensor([1.0]) 134 | test_break_graph_repeat(x) 135 | x = paddle.to_tensor([2.0]) 136 | test_break_graph_repeat(x) 137 | 138 | x = paddle.to_tensor([3.0]) 139 | self.assert_results(test_break_graph_repeat, x) 140 | 141 | 142 | if __name__ == "__main__": 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /docs/compat/python311/closure.md: -------------------------------------------------------------------------------- 1 | # Closure 适配 2 | 3 | 4 | ## Python 中的闭包示例 5 | 6 | 以下是在新版本中闭包函数处理的demo,以及它对应的字节码 : 7 | 8 | ```python 9 | import dis 10 | 11 | def func(): 12 | free_x = 1 13 | free_y = 2 14 | 15 | def local(y): 16 | return y + free_x + 
free_y 17 | return local(1) 18 | 19 | dis.dis(func) 20 | ``` 21 | 22 | ```bash 23 | 0 MAKE_CELL 1 (free_x) 24 | 2 MAKE_CELL 2 (free_y) 25 | 26 | 9 4 RESUME 0 27 | 28 | 10 6 LOAD_CONST 1 (1) 29 | 8 STORE_DEREF 1 (free_x) 30 | 31 | 11 10 LOAD_CONST 2 (2) 32 | 12 STORE_DEREF 2 (free_y) 33 | 34 | 13 14 LOAD_CLOSURE 1 (free_x) 35 | 16 LOAD_CLOSURE 2 (free_y) 36 | 18 BUILD_TUPLE 2 37 | 20 LOAD_CONST 3 (<code object local at 0x..., file "...", line 13>) 38 | 22 MAKE_FUNCTION 8 (closure) 39 | 24 STORE_FAST 0 (local) 40 | 41 | 15 26 PUSH_NULL 42 | 28 LOAD_FAST 0 (local) 43 | 30 LOAD_CONST 1 (1) 44 | 32 PRECALL 1 45 | 36 CALL 1 46 | 46 RETURN_VALUE 47 | 48 | Disassembly of <code object local at 0x..., file "...", line 13>: 49 | 0 COPY_FREE_VARS 2 50 | 51 | 13 2 RESUME 0 52 | 53 | 14 4 LOAD_FAST 0 (y) 54 | 6 LOAD_DEREF 1 (free_x) 55 | 8 BINARY_OP 0 (+) 56 | 12 LOAD_DEREF 2 (free_y) 57 | 14 BINARY_OP 0 (+) 58 | 18 RETURN_VALUE 59 | ``` 60 | 61 | ## 新版本中对字节码的改动: 62 | 63 | ### 首先是语义上的改动 64 | 65 | LOAD_CLOSURE: 参数 i 不再是 `co_cellvars + co_freevars` 中的下标(旧版本在 frame 内部寻址时还需要加上 `co_varnames` 的长度作为偏移),而是直接指向 fast locals 中的第 i 个槽位,实际上是 `LOAD_FAST` 的一个别名 66 | 67 | LOAD_DEREF: 读取 fast locals 第 i 个槽位中的 cell,并将该 cell 中保存的对象压入栈顶 68 | 69 | STORE_DEREF: 将 TOS 存入 fast locals 第 i 个槽位中的 cell 70 | 71 | ### 新增字节码 72 | 73 | MAKE_CELL: 在 fast locals 的第 i 个槽位创建一个新的 cell;如果该槽位原本已有值(例如被闭包捕获的函数参数),则将该值存入新 cell 74 | 75 | COPY_FREE_VARS: 将函数闭包(`__closure__`)中的 n 个自由变量复制到 frame 的 fast locals 中,调用方因此无需再对闭包做特殊处理 76 | 77 | ## 分析 78 | 79 | 从字节码上的改动来看,在 Python 3.11 中,局部变量、cell 变量和自由变量统一存放在 frame 的 fast locals 数组中,不再为 cell 与自由变量维护单独的存储区;闭包数据本身仍然保存在 cell 对象中,`LOAD_DEREF`/`STORE_DEREF` 依然通过 cell 间接访问。cell 的创建(`MAKE_CELL`)和自由变量的拷贝(`COPY_FREE_VARS`)改由函数自身的字节码完成,因此模拟执行时需要显式处理这两条新字节码。 80 | 81 | ## 实现 82 | 83 | LOAD_CLOSURE: 作为 `LOAD_FAST` 的别名,直接复用 `LOAD_FAST` 的实现 84 | 85 | LOAD_DEREF: 改为从 `self._locals` 中加载元素到 TOS 中 86 | 87 | STORE_DEREF: 改为存储 TOS 到 `self._locals` 中 88 | 89 | MAKE_CELL: 从 `self._cells` 中加载元素到 `self._locals` 90 | 91 | COPY_FREE_VARS(闭包内部字节码): 从 `self._code.co_freevars` 拿到 key,在 `self._cells` 中找到对应元素并存储到 `self._locals` 92 | 93 | ## codegen 94 | 95 | ```bash 96 | [transform] NewCode: #foo_af1a0 97 | 9 0 MAKE_CELL 0 (x) # 在此处生成存储字节码,将元素存储至 locals 98 | 2 MAKE_CELL 1 (y) 99 | 4 MAKE_CELL 5 (z) 100 | 6 RESUME 0 101 | 8 LOAD_GLOBAL 1 (NULL + paddle_set_eval_frame_fn) 102 | ...
103 | 104 POP_TOP 104 | 106 RETURN_VALUE 105 | 106 | Disassembly of <code object ... at 0x..., file "...", line 12>: 107 | 0 COPY_FREE_VARS 3 # 在此处生成拷贝字节码,将数据拷贝至闭包内部调用 108 | 109 | 12 2 RESUME 0 110 | 111 | 13 4 LOAD_FAST 0 (a) 112 | ... 113 | 30 RETURN_VALUE 114 | 115 | ``` 116 | 117 | 118 | ## 单测 119 | 120 | 新增一项之前未覆盖的情况(直接返回闭包函数本身): 121 | 122 | ```python 123 | def create_closure(): 124 | x = 1 125 | 126 | def closure(): 127 | return x + 1 128 | 129 | return closure 130 | ``` 131 | 132 | ## 其他更改 133 | 134 | 此次升级还依赖于 eval frame 的修改,相关适配链接:[#57490](https://github.com/PaddlePaddle/Paddle/pull/57490)、[#57653](https://github.com/PaddlePaddle/Paddle/pull/57653) 135 | -------------------------------------------------------------------------------- /docs/design/builtin-dispatcher.md: -------------------------------------------------------------------------------- 1 | # Builtin 函数派发机制 2 | 3 | ## 什么是 BuiltinVariable? 4 | 5 | 什么是 BuiltinVariable 呢?最开始我们以为 BuiltinVariable 应该就是 builtin 这个 namespace 里的各种变量,比如 `int`、`abs` 等等这些不需要 import 就可以直接使用的变量,但是实际上这一角色已经由 BuiltinTracker 承担了,BuiltinVariable 就和 BuiltinTracker 定位重复了。 6 | 7 | 对于现有的其他 CallableVariable,每一个 Variable 的定位都很清晰,实现方式也很清晰,比如以 UserDefinedFunctionVariable 为代表的 inline call 方式、以 PaddleApiVariable 为代表的组网方式、以及部分需要子图打断的 API。但是 BuiltinVariable 不是这样的,对于大多数 builtin 函数,在 Python 执行时会去调用对应的 magic method,比如 `len` 会调用 `__len__`,此时执行的效果是与变量类型强相关的,比如用户重载了一个自定义对象的 `__len__`,此时应该去尝试 inline call,因为用户的代码中包含各种各样的情况,而部分对象的 `__len__` 不是写在 Python 端的,此时我们应该去模拟其执行效果。 8 | 9 | 另外值得注意的是,Python 会认为不在 Python 端定义的、没有字节码的函数的类型为 `types.BuiltinFunctionType`,这样 BuiltinVariable 的定位已经很清晰了,即**没有字节码,无法通过字节码来模拟执行的函数**。 10 | 11 | ## 为什么需要派发机制 12 | 13 | 如果无法直接模拟字节码的话,我们在模拟执行时要怎么做呢?起初我们直接利用 Python 的派发机制,在 Variable 上重载 magic method,实现自动派发: 14 | 15 | ```python 16 | class BuiltinVariable: 17 | def call_function(self, *args: VariableBase, **kwargs: VariableBase): 18 | return self.value(*args, **kwargs) 19 | ``` 20 | 21 | 但这样的问题很明显:Python 对部分 magic method 的返回值是有强制类型检查的,比如 `__bool__` 必须返回 bool 类型、`__len__` 必须返回 int 类型。而按照上述的实现,这些 magic method 返回的是
VariableBase,这就导致部分 magic method 无法复用 Python 内建的派发机制。另一方面,很多 Variable 还没有实现相应的 magic method,也会因此报错。 22 | 23 | 为了避免这些问题,我们添加了一个类似 magic method 的派发机制,来将 BuiltinVariable 的调用转发到某一个具体的函数上。这个派发机制会去尝试匹配参数的类型,如果找不到匹配的参数类型,就会尝试从类上获取相应的 magic method 来 inline call,如果仍然找不到,则会产生一个 BreakGraphError 来打断子图。 24 | 25 | ## 派发机制的实现和使用方式 26 | 27 | 派发机制的本质是对参数类型进行匹配,对于一个函数,我们会有多种允许的类型签名(Pattern)以及对应的 handler,而在调用这个函数的时候,我们会根据参数类型来派发到相应的 handler 上,主体代码如下: 28 | 29 | ```python 30 | class Pattern: 31 | type_strings: Args[str] 32 | kwtype_strings: Kwargs[str] 33 | 34 | def __init__( 35 | self, 36 | *types: str, 37 | **kwtypes: str, 38 | ): 39 | self.type_strings = types 40 | self.kwtype_strings = kwtypes 41 | 42 | class Dispatcher: 43 | handlers: dict[ 44 | Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]] 45 | ] = {} 46 | 47 | @classmethod 48 | def register( 49 | cls, 50 | fn: Callable[..., Any], 51 | types: tuple[str, ...], 52 | kwtypes: dict[str, str], 53 | handler: Callable[..., Any], 54 | ): 55 | if fn not in cls.handlers: 56 | cls.handlers[fn] = [] 57 | cls.handlers[fn].append((Pattern(*types, **kwtypes), handler)) 58 | 59 | @classmethod 60 | def dispatch( 61 | cls, fn: Callable[..., Any], *args: Any, **kwargs: Any 62 | ) -> Callable[..., Any] | None: 63 | if fn not in cls.handlers: 64 | return None 65 | for pattern, handler in cls.handlers[fn]: 66 | if pattern.match_inputs(*args, **kwargs): 67 | return handler 68 | return None 69 | ``` 70 | 71 | 这样,我们只需要将函数签名和 handler 注册到 Dispatcher 上即可: 72 | 73 | ```python 74 | Dispatcher.register( 75 | len, 76 | ("ContainerVariable",), 77 | {}, 78 | lambda var: var.len(), 79 | ) 80 | ``` 81 | 82 | 为了方便使用,我们还提供了一个装饰器的模式: 83 | 84 | 85 | ```python 86 | if TYPE_CHECKING: 87 | from .variables import ContainerVariable 88 | 89 | @Dispatcher.register_decorator(len) 90 | def dispatch_len(var: ContainerVariable): 91 | return var.len() 92 | ``` 93 | 94 | 对于一些复杂的函数,更推荐使用装饰器模式。 95 | 96 | ## 利用派发机制简化现有代码 97 |
在实现派发机制之前,我们在很多地方会有重复的代码,比如 Python 的 `a.b` 和 `getattr(a, "b")` 是等价的,但是字节码层面是完全不同的,前者是 `LOAD_ATTR`,后者则是 `CALL_FUNCTION`,此前我们也是各自实现的。 99 | 100 | 在实现了派发机制之后,我们完全可以利用派发机制来实现 `LOAD_ATTR`: 101 | 102 | ```python 103 | @call_break_graph_decorator(push_n=1) 104 | def LOAD_ATTR(self, instr): 105 | attr_name = instr.argval 106 | obj = self.pop() 107 | self.push( 108 | BuiltinVariable( 109 | getattr, graph=self._graph, tracker=DanglingTracker() 110 | )(obj, attr_name) 111 | ) 112 | ``` 113 | 114 | 这样可以极大的简化代码,降低维护成本。 115 | 116 | > **Note** 117 | > 118 | > 我们现在代码中仍然有很多地方使用了旧的派发机制(利用 Python magic method 直接派发),这些将在之后逐步替换 119 | -------------------------------------------------------------------------------- /sot/opcode_translator/skip_files.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import codecs 3 | import collections 4 | import contextlib 5 | import copy 6 | import copyreg 7 | import dataclasses 8 | import distutils 9 | import enum 10 | import functools 11 | import importlib 12 | import inspect 13 | import linecache 14 | import logging 15 | import multiprocessing 16 | import operator 17 | import os 18 | import posixpath 19 | import random 20 | import re 21 | import selectors 22 | import signal 23 | import sys 24 | import tempfile 25 | import threading 26 | import tokenize 27 | import traceback 28 | import types 29 | import typing 30 | import unittest 31 | import uuid 32 | import warnings 33 | import weakref 34 | 35 | import _collections_abc 36 | import _weakrefset 37 | import decorator 38 | import google.protobuf 39 | import numpy 40 | import setuptools 41 | 42 | import paddle 43 | 44 | from ..utils import log 45 | 46 | NEED_SKIP_THIRD_PARTIY_MODULES = { 47 | abc, 48 | collections, 49 | contextlib, 50 | copy, 51 | copyreg, 52 | dataclasses, 53 | enum, 54 | functools, 55 | google.protobuf, 56 | importlib, 57 | inspect, 58 | linecache, 59 | logging, 60 | multiprocessing, 61 | numpy, 62 | operator, 63 | os, 64 | 
posixpath, 65 | random, 66 | re, 67 | selectors, 68 | signal, 69 | tempfile, 70 | threading, 71 | tokenize, 72 | traceback, 73 | types, 74 | typing, 75 | unittest, 76 | weakref, 77 | _collections_abc, 78 | _weakrefset, 79 | decorator, 80 | codecs, 81 | uuid, 82 | setuptools, 83 | distutils, 84 | warnings, 85 | } 86 | 87 | if sys.version_info < (3, 11): 88 | import sre_compile 89 | import sre_parse 90 | 91 | NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_compile) 92 | NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_parse) 93 | 94 | 95 | def _strip_init_py(s): 96 | return re.sub(r"__init__.py$", "", s) 97 | 98 | 99 | def _module_dir(m: types.ModuleType): 100 | return _strip_init_py(m.__file__) 101 | 102 | 103 | skip_file_names = {_module_dir(m) for m in NEED_SKIP_THIRD_PARTIY_MODULES} 104 | 105 | 106 | sot_path = os.path.dirname(__file__).rpartition("/")[0] + "/" 107 | paddle_path = sys.modules["paddle"].__file__.rpartition("/")[0] + "/" 108 | 109 | skip_file_names.add(sot_path) 110 | skip_file_names.add(paddle_path) 111 | skip_file_names.add( 112 | "") 115 | 116 | skip_file_name_re = re.compile( 117 | f"^({'|'.join(map(re.escape, skip_file_names))})" 118 | ) 119 | 120 | customed_skip_code = set() 121 | 122 | no_skip_code = {paddle.nn.Sequential.forward.__code__} 123 | 124 | 125 | def need_skip_path(filepath: str) -> bool: 126 | """ 127 | Check if the file should be skipped and not transcribed. 128 | 129 | Args: 130 | filepath: The path of the file to check. 131 | 132 | Returns: 133 | bool: True if the file should be skipped. 
134 | """ 135 | if not filepath.startswith("<"): 136 | filepath = os.path.abspath(filepath) 137 | return bool(skip_file_name_re.match(filepath)) 138 | 139 | 140 | def skip_function(function): 141 | customed_skip_code.add(function.__code__) 142 | return function 143 | 144 | 145 | def need_skip(frame): 146 | pycode = frame.f_code 147 | if pycode in no_skip_code: 148 | return False 149 | if pycode in customed_skip_code: 150 | log(3, f"Skip frame by code: {pycode}\n") 151 | return True 152 | filename = pycode.co_filename 153 | if sys.version_info >= (3, 11) and filename.startswith(" 0.0057 64 | with EventGuard("FallbackWrapper: call partial_program"): 65 | outputs = self.partial_program(*args, **kwargs) 66 | 67 | clear_eager_tensor_name(outputs) 68 | log_do( 69 | 1, 70 | lambda: GraphLogger().add_subgraph( 71 | self.concrete_program.main_program 72 | ), 73 | ) 74 | log_do( 75 | 4, 76 | lambda: print("[CompileCache] run sir forward success."), 77 | ) 78 | return outputs 79 | 80 | 81 | @Singleton 82 | class CompileSIRCache(Cache): 83 | """ 84 | Cache the compiled function of SIR 85 | """ 86 | 87 | def __init__(self): 88 | super().__init__(weak=False) 89 | 90 | def key_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): 91 | """ 92 | generate a hash key for a SIR 93 | 94 | Args: 95 | context: The context to compile 96 | sir_name: The name of the sir to compile 97 | build_strategy: The build strategy to compile 98 | 99 | Returns: 100 | The hash key of the SIR 101 | """ 102 | sir = context.get_sir(sir_name) 103 | # NOTE(dev): Is str(sir) a heavy opearation ? 
104 | hash_key = hash(str(sir)) 105 | return hash_key 106 | 107 | def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): 108 | """ 109 | Generate static graph function 110 | 111 | Args: 112 | context: The context to compile 113 | sir_name: The name of the sir to compile 114 | build_strategy: The build strategy to compile 115 | 116 | Returns: 117 | The static graph function 118 | """ 119 | build_strategy = kwargs.get("build_strategy", None) 120 | backend = kwargs.get("backend", None) 121 | return FallbackWrapper( 122 | paddle.jit.to_static( 123 | compile_sir(context, sir_name), 124 | build_strategy=build_strategy, 125 | backend=backend, 126 | enable_fallback=False, 127 | ), 128 | context.get_sir(sir_name), 129 | ) 130 | -------------------------------------------------------------------------------- /sot/translate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from typing import TYPE_CHECKING, Callable, TypeVar 5 | 6 | import paddle 7 | 8 | from .opcode_translator import eval_frame_callback 9 | from .utils import GraphLogger, StepInfoManager, StepState, log_do 10 | 11 | if TYPE_CHECKING: 12 | from typing_extensions import ParamSpec 13 | 14 | P = ParamSpec("P") 15 | R = TypeVar("R") 16 | 17 | # Temporarily set the default log level to 2 to get more information in CI log. 18 | os.environ["LOG_LEVEL"] = os.getenv("LOG_LEVEL", "2") 19 | 20 | 21 | def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: 22 | """ 23 | This function is the entry point of PaddleSOT. It sets eval_frame_callback before input 24 | function to achieve Opcode-level translation. The translation process depends on the 25 | simulation execution, in which information will be collected, especially the network 26 | code. After the simulation execution is completed, the network code will be compiled 27 | into a static graph Program to improve performance. 
28 | 29 | Args: 30 | fn: The input function. 31 | 32 | Returns: 33 | Callable, The wrapped function. 34 | 35 | Examples: 36 | >>> # doctest: +SKIP("Cound not get source code of function foo."") 37 | >>> import paddle 38 | >>> import numpy as np 39 | >>> from sot.translate import symbolic_translate 40 | >>> def foo(cond: paddle.Tensor, x: paddle.Tensor): 41 | ... x += 1 42 | ... if cond: 43 | ... x += 1 44 | ... else: 45 | ... x -= 1 46 | ... return x 47 | >>> symbolic_translate_foo = symbolic_translate(foo) 48 | >>> # For the true branch, the output is 2. 49 | >>> cond = paddle.to_tensor(True) 50 | >>> x = paddle.to_tensor(0) 51 | >>> dygraph_out = foo(cond, x) 52 | >>> symbolic_translate_out = symbolic_translate_foo(cond, x) 53 | >>> dygraph_out 54 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 55 | 2) 56 | >>> symbolic_translate_out 57 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 58 | 2) 59 | >>> np.testing.assert_allclose( 60 | ... dygraph_out.numpy(), symbolic_translate_out.numpy() 61 | ... ) 62 | >>> # For the false branch, the output is 0. 63 | >>> cond = paddle.to_tensor(False) 64 | >>> dygraph_out = foo(cond, x) 65 | >>> symbolic_translate_out = symbolic_translate_foo(cond, x) 66 | >>> dygraph_out 67 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 68 | 0) 69 | >>> symbolic_translate_out 70 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 71 | 0) 72 | >>> np.testing.assert_allclose( 73 | ... dygraph_out.numpy(), symbolic_translate_out.numpy() 74 | ... ) 75 | 76 | """ 77 | 78 | def callback(frame): 79 | return eval_frame_callback(frame, **kwargs) 80 | 81 | def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: 82 | assert hasattr( 83 | fn, "__code__" 84 | ), "Target function doesn't have code for simulating." 
85 | StepInfoManager().sot_step() 86 | GraphLogger().clear() 87 | paddle.framework.core.set_eval_frame(callback) 88 | try: 89 | outs = fn(*args, **kwargs) 90 | except Exception as e: 91 | raise e 92 | finally: 93 | paddle.framework.core.set_eval_frame(None) 94 | 95 | log_do(1, lambda: GraphLogger().print_info()) 96 | return outs 97 | 98 | def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: 99 | outs = fn(*args, **kwargs) 100 | return outs 101 | 102 | def impl(*args: P.args, **kwargs: P.kwargs) -> R: 103 | with StepInfoManager().step_guard(fn.__code__): 104 | state = StepInfoManager().current_state 105 | 106 | if state == StepState.RUN_SOT: 107 | return impl_sot(*args, **kwargs) 108 | elif state == StepState.RUN_DYN: 109 | return impl_dynamic(*args, **kwargs) 110 | elif state == StepState.COLLECT_INFO: 111 | return StepInfoManager().collect_info( 112 | impl_dynamic, impl_sot, *args, **kwargs 113 | ) 114 | 115 | return impl 116 | -------------------------------------------------------------------------------- /sot/utils/magic_methods.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import operator 4 | from dataclasses import dataclass 5 | from typing import TYPE_CHECKING, Any, Callable 6 | 7 | from .utils import hashable 8 | 9 | if TYPE_CHECKING: 10 | BinaryOp = Callable[[Any, Any], Any] 11 | UnaryOp = Callable[[Any], Any] 12 | 13 | 14 | INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[BinaryOp, tuple[str, BinaryOp]] = { 15 | # inplace op fn: (magic name, non-inplace op fn) 16 | operator.iadd: ("__iadd__", operator.add), 17 | operator.iand: ("__iand__", operator.and_), 18 | operator.iconcat: ("__iconcat__", operator.concat), 19 | operator.ifloordiv: ("__ifloordiv__", operator.floordiv), 20 | operator.ilshift: ("__ilshift__", operator.lshift), 21 | operator.imatmul: ("__imatmul__", operator.matmul), 22 | operator.imod: ("__imod__", operator.mod), 23 | operator.imul: ("__imul__", 
operator.mul), 24 | operator.ior: ("__ior__", operator.or_), 25 | operator.ipow: ("__ipow__", operator.pow), 26 | operator.irshift: ("__irshift__", operator.rshift), 27 | operator.isub: ("__isub__", operator.sub), 28 | operator.itruediv: ("__itruediv__", operator.truediv), 29 | operator.ixor: ("__ixor__", operator.xor), 30 | } 31 | 32 | NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[ 33 | BinaryOp, tuple[str, str | None] 34 | ] = { 35 | # op fn: (magic name, reverse magic name) 36 | operator.add: ("__add__", "__radd__"), 37 | operator.and_: ("__and__", "__rand__"), 38 | operator.contains: ("__contains__", None), 39 | operator.delitem: ("__delitem__", None), 40 | operator.eq: ("__eq__", "__eq__"), 41 | operator.floordiv: ("__floordiv__", "__rfloordiv__"), 42 | operator.ge: ("__ge__", "__le__"), 43 | operator.getitem: ("__getitem__", None), 44 | operator.gt: ("__gt__", "__lt__"), 45 | operator.le: ("__le__", "__ge__"), 46 | operator.lshift: ("__lshift__", "__rlshift__"), 47 | operator.lt: ("__lt__", "__gt__"), 48 | operator.matmul: ("__matmul__", "__rmatmul__"), 49 | operator.mod: ("__mod__", "__rmod__"), 50 | operator.mul: ("__mul__", "__rmul__"), 51 | operator.ne: ("__ne__", "__ne__"), 52 | operator.or_: ("__or__", "__ror__"), 53 | operator.pow: ("__pow__", "__rpow__"), 54 | operator.rshift: ("__rshift__", "__rrshift__"), 55 | operator.sub: ("__sub__", "__rsub__"), 56 | operator.truediv: ("__truediv__", "__rtruediv__"), 57 | operator.xor: ("__xor__", "__rxor__"), 58 | } 59 | 60 | UNARY_OPS_TO_MAGIC_NAMES: dict[UnaryOp, str] = { 61 | operator.neg: "__neg__", 62 | operator.invert: "__invert__", 63 | operator.pos: "__pos__", 64 | operator.abs: "__abs__", 65 | operator.index: "__index__", 66 | operator.inv: "__inv__", 67 | operator.invert: "__invert__", 68 | operator.not_: "__not__", 69 | operator.pos: "__pos__", 70 | operator.truth: "__bool__", 71 | bool: "__bool__", 72 | abs: "__abs__", 73 | float: "__float__", 74 | len: "__len__", 75 | int: "__int__", 76 | } 77 | # 
TODO(SigureMo): support any, all, sum 78 | 79 | 80 | INPLACE_BINARY_OPS = set(INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys()) 81 | NON_INPLACE_BINARY_OPS = set(NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys()) 82 | BINARY_OPS = INPLACE_BINARY_OPS | NON_INPLACE_BINARY_OPS 83 | UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys()) 84 | 85 | 86 | @dataclass 87 | class MagicMethod: 88 | name: str 89 | is_inplace: bool = False 90 | is_reverse: bool = False 91 | 92 | 93 | def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]: 94 | if not hashable(fn): 95 | return [] 96 | if fn in INPLACE_BINARY_OPS: 97 | inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[ 98 | fn 99 | ] 100 | return [ 101 | MagicMethod(inplace_magic_name, is_inplace=True) 102 | ] + magic_method_builtin_dispatch(non_inplace_op) 103 | elif fn in NON_INPLACE_BINARY_OPS: 104 | magic_name, reverse_magic_name = NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES[ 105 | fn 106 | ] 107 | magic_methods = [MagicMethod(magic_name)] 108 | if reverse_magic_name is not None: 109 | magic_methods.append( 110 | MagicMethod(reverse_magic_name, is_reverse=True) 111 | ) 112 | return magic_methods 113 | elif fn in UNARY_OPS: 114 | magic_name = UNARY_OPS_TO_MAGIC_NAMES[fn] 115 | return [MagicMethod(magic_name)] 116 | return [] 117 | -------------------------------------------------------------------------------- /tests/test_code_status.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import paddle 6 | import sot 7 | from sot.opcode_translator.skip_files import skip_function 8 | from sot.utils.code_status import CodeState, CodeStatus 9 | 10 | 11 | class SimpleNet1(paddle.nn.Layer): 12 | def __init__(self): 13 | super().__init__() 14 | self.layers = paddle.nn.LayerList( 15 | [paddle.nn.Linear(10, 10) for _ in range(30)] 16 | ) 17 | 18 | def forward(self, x): 19 | for i in 
range(len(self.layers)): 20 | sot.psdb.breakgraph() 21 | x = self.layers[i](x) 22 | x = self.layers[i](x) 23 | x = self.layers[i](x) 24 | x = self.layers[i](x) 25 | return x 26 | 27 | 28 | class SimpleNet2(paddle.nn.Layer): 29 | def __init__(self): 30 | super().__init__() 31 | self.layers = paddle.nn.LayerList( 32 | [paddle.nn.Linear(10, 10) for _ in range(30)] 33 | ) 34 | 35 | def forward(self, x): 36 | sot.psdb.fallback() 37 | for i in range(len(self.layers)): 38 | x = self.layers[i](x) 39 | x = self.layers[i](x) 40 | x = self.layers[i](x) 41 | x = self.layers[i](x) 42 | return x 43 | 44 | 45 | def run_net(net, x): 46 | for i in range(20): 47 | x = net(x) 48 | return x 49 | 50 | 51 | class TestCodeInfo(TestCaseBase): 52 | def test_case_1(self): 53 | CodeStatus().clear() 54 | net = SimpleNet1() 55 | inp = paddle.rand((10, 10)) 56 | self.assert_results(run_net, net, inp) 57 | code_map = CodeStatus().code_map 58 | states = [] 59 | for k, v in code_map.items(): 60 | if k.co_name.startswith("#") or k.co_name.startswith("$"): 61 | states.append(v) 62 | elif k in CodeStatus().WITH_GRAPH_API: 63 | assert v.state == CodeState.WITH_GRAPH 64 | else: 65 | assert v.state == CodeState.WITHOUT_GRAPH 66 | # run_net, forward, loop body, resumed part2 in loop body 67 | assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 68 | # resumed part1 in loop body 69 | assert ( 70 | len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 71 | ) 72 | 73 | def test_case_2(self): 74 | with strict_mode_guard(0): 75 | CodeStatus().clear() 76 | net = SimpleNet2() 77 | inp = paddle.rand((10, 10)) 78 | self.assert_results(run_net, net, inp) 79 | code_map = CodeStatus().code_map 80 | states = [] 81 | for k, v in code_map.items(): 82 | if k.co_name.startswith("#") or k.co_name.startswith("$"): 83 | states.append(v) 84 | elif k in CodeStatus().WITH_GRAPH_API: 85 | assert v.state == CodeState.WITH_GRAPH 86 | else: 87 | assert v.state == CodeState.WITHOUT_GRAPH 88 | # no 
graph found because fallback (paddle api will not enter simulate) 89 | assert ( 90 | len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 91 | ) 92 | 93 | 94 | def no_skip_func_0(x): 95 | return x + 1 96 | 97 | 98 | def skipped_func_0(): 99 | pass 100 | 101 | 102 | def skipped_func_1(x): 103 | return x + 1 104 | 105 | 106 | def skipped_func_2(x): 107 | return no_skip_func_0(x) 108 | 109 | 110 | def call_skipped_func_0(x): 111 | for i in range(15): 112 | skipped_func_0() 113 | x = skipped_func_1(x) 114 | x = skipped_func_2(x) 115 | return x 116 | 117 | 118 | skip_function(skipped_func_0) 119 | skip_function(skipped_func_1) 120 | skip_function(skipped_func_2) 121 | skip_function(call_skipped_func_0) 122 | 123 | 124 | class TestDisableSkippedFrame(TestCaseBase): 125 | def test_case_0(self): 126 | CodeStatus().clear() 127 | x = paddle.to_tensor([1]) 128 | self.assert_results(call_skipped_func_0, x) 129 | code_map = CodeStatus().code_map 130 | assert ( 131 | code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH 132 | ) 133 | assert ( 134 | code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH 135 | ) 136 | assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH 137 | 138 | 139 | if __name__ == "__main__": 140 | unittest.main() 141 | -------------------------------------------------------------------------------- /tests/test_21_global.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | import sot 9 | 10 | global_x = 1 11 | global_y = paddle.to_tensor(2) 12 | global_z = None 13 | global_del_val = 1 14 | global_dict = {} 15 | global_list = [1, 2] 16 | global_inline = 0 17 | 18 | 19 | def global_func_int(): 20 | global global_x 21 | global_x = global_x + 1 22 | return global_x 23 | 24 | 25 | def global_func_int_add(): 26 | global global_x 27 | global_x = 
global_x + global_x 28 | return global_x + global_x 29 | 30 | 31 | def global_func_tensor_int_add(tensor_y: paddle.Tensor): 32 | global global_x 33 | global_x += 1 34 | return global_x + tensor_y 35 | 36 | 37 | def global_multiple_update(): 38 | global global_x 39 | global_x = 999 40 | global_x = 888 41 | global_x = 777 42 | return global_x - 1 43 | 44 | 45 | def global_func_tensor(): 46 | global global_y 47 | global_y = global_y + global_y 48 | return global_y 49 | 50 | 51 | def global_func_tensor_add(): 52 | global global_y 53 | global_y = global_y + global_y 54 | return global_y + global_y 55 | 56 | 57 | def global_func(): 58 | global global_x 59 | global global_y 60 | global global_z 61 | 62 | global_z = global_x + global_y 63 | return global_z 64 | 65 | 66 | def global_del_global(): 67 | global global_del_val 68 | 69 | del global_del_val 70 | 71 | 72 | def global_func_dict(): 73 | global global_dict 74 | global_dict["key"] = "value" 75 | global_dict.update({"test_key1": "test_value2"}) 76 | return global_dict 77 | 78 | 79 | def global_func_control1(): 80 | global global_dict 81 | if "key" in global_dict: 82 | del global_dict["key"] 83 | return global_dict 84 | 85 | 86 | def global_func_control2(): 87 | global global_list 88 | for i in range(len(global_list)): 89 | global_list[i] = global_list[i] + 1 90 | return global_list 91 | 92 | 93 | def global_func_inline_inner_1(): 94 | global global_inline 95 | global_func_inline_inner_2() 96 | global_inline += 1 97 | 98 | 99 | def global_func_inline_inner_2(): 100 | global global_inline 101 | global_inline += 1 102 | 103 | 104 | def global_func_inline(): 105 | global_func_inline_inner_1() 106 | global global_inline 107 | return global_inline 108 | 109 | 110 | class TestGlobal(TestCaseBase): 111 | def test_global_func_int(self): 112 | global global_x 113 | self.assert_results_with_global_check(global_func_int, ["global_x"]) 114 | global_x += 1 115 | self.assert_results_with_global_check(global_func_int, ["global_x"]) 
116 | self.assert_results_with_global_check(global_func_int_add, ["global_x"]) 117 | 118 | def test_global_multiple_update(self): 119 | self.assert_results_with_global_check( 120 | global_multiple_update, ["global_x"] 121 | ) 122 | 123 | def test_global_func_tensor_int_add(self): 124 | self.assert_results_with_global_check( 125 | global_func_tensor_int_add, ["global_x"], paddle.to_tensor(1) 126 | ) 127 | 128 | def test_global_func_tensor(self): 129 | self.assert_results_with_global_check(global_func_tensor, ["global_y"]) 130 | self.assert_results_with_global_check( 131 | global_func_tensor_add, ["global_y"] 132 | ) 133 | 134 | def test_global_func(self): 135 | self.assert_results_with_global_check(global_func, ["global_z"]) 136 | self.assertIn("global_del_val", global_del_global.__globals__) 137 | sot.symbolic_translate(global_del_global)() 138 | self.assertNotIn("global_del_val", global_del_global.__globals__) 139 | 140 | def test_global_func_dict(self): 141 | self.assert_results_with_global_check(global_func_dict, ["global_dict"]) 142 | self.assert_results_with_global_check( 143 | global_func_control1, ["global_dict"] 144 | ) 145 | 146 | def test_global_func_list(self): 147 | self.assert_results_with_global_check( 148 | global_func_control2, ["global_list"] 149 | ) 150 | 151 | def test_global_func_inline(self): 152 | global global_inline 153 | global_inline = 0 154 | sot.symbolic_translate(global_func_inline)() 155 | self.assertEqual(global_inline, 2) 156 | sot.symbolic_translate(global_func_inline)() 157 | self.assertEqual(global_inline, 4) 158 | 159 | 160 | if __name__ == "__main__": 161 | unittest.main() 162 | -------------------------------------------------------------------------------- /sot/symbolic/symbolic_context.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from ..utils import log 4 | from .compile_cache import CompileSIRCache 5 | from .statement_ir import ( 6 | 
ApiStatement, 7 | CallStatement, 8 | LayerStatement, 9 | MethodStatement, 10 | StatementIR, 11 | StatementIRFactory, 12 | Symbol, 13 | ) 14 | 15 | 16 | class SymbolicTraceContext: 17 | """ 18 | SymbolicTraceContext is a context manager, which is used to record the symbolic trace. 19 | 20 | """ 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | """ 27 | Reset the context. 28 | """ 29 | 30 | # TODO(dev): StatementIRFactory is a singleton, but SymbolicTraceContext is not. 31 | # whether will two different SymbolicTraceContext objects be conflict ? 32 | self.statement_factory = StatementIRFactory() 33 | self.sir_stack = [self.statement_factory.create()] 34 | 35 | @property 36 | def TOS(self): 37 | """ 38 | The top SIR of sir_stack. 39 | 40 | Returns: 41 | StatementIR: the top of stack. 42 | """ 43 | 44 | return self.sir_stack[-1] 45 | 46 | def call_SIR(self, sirname, inputs, outputs, stacks): 47 | """ 48 | Call a SIR, which is a subgraph. 49 | """ 50 | 51 | stmt = CallStatement(sirname, inputs, outputs, stacks) 52 | self.TOS.add_statement(stmt) 53 | 54 | def call_API(self, api, inputs, outputs, stacks): 55 | """ 56 | Call a paddle api. 57 | """ 58 | 59 | assert callable(api), "call_API must receive a paddle api." 60 | stmt = ApiStatement(api, inputs, outputs, stacks) 61 | self.TOS.add_statement(stmt) 62 | 63 | def call_METHOD(self, method_name, inputs, outputs, stacks): 64 | """ 65 | Call a method of a api. The API here can be python or Paddle 66 | """ 67 | assert isinstance( 68 | method_name, str 69 | ), "call_METHOD must method api name. string." 70 | assert isinstance( 71 | inputs[0][0], Symbol 72 | ), "call_METHOD must first augument must be Symbol Variable." 73 | stmt = MethodStatement(method_name, inputs, outputs, stacks) 74 | self.TOS.add_statement(stmt) 75 | 76 | def call_LAYER(self, layer, inputs, outputs, stacks): 77 | """ 78 | Call a layer of a api. 
79 | """ 80 | stmt = LayerStatement(layer, inputs, outputs, stacks) 81 | self.TOS.add_statement(stmt) 82 | 83 | def get_sir(self, name: str): 84 | """ 85 | Get a SIR from statement_factory. 86 | 87 | Args: 88 | name (str): the name of SIR. 89 | 90 | Returns: 91 | StatementIR: the SIR. 92 | """ 93 | return self.statement_factory[name] 94 | 95 | def reset_TOS(self): 96 | """ 97 | Reset the TOS. 98 | """ 99 | self.sir_stack.pop() 100 | self.sir_stack.append(self.statement_factory.create()) 101 | 102 | def replace_TOS(self, sir): 103 | """ 104 | Use deepcopyed sir to replace the TOS. 105 | This function will update statment_factory. 106 | """ 107 | self.sir_stack.pop() 108 | self.sir_stack.append(sir) 109 | self.statement_factory.update(sir) 110 | 111 | def compile_do_nothing(self, ret_vals): 112 | """ 113 | Return a dummy function, which will return an empty list. 114 | 115 | Args: 116 | ret_vals (list[Symbol]): the return values of the function. 117 | """ 118 | 119 | def dummy_func(*args, **kwargs): 120 | return [] 121 | 122 | # return None function 123 | dummy_stmt_ir = StatementIR("dummy_func") 124 | dummy_stmt_ir.outputs = [] 125 | dummy_stmt_ir.inputs = [] 126 | return dummy_func, dummy_stmt_ir 127 | 128 | def compile_fn(self, ret_vals, **kwargs): 129 | """ 130 | start compile and return the python function, which must can be to_static without errors. 131 | """ 132 | cur_sir: StatementIR = self.TOS 133 | # step0: if no statement, return a dummy function 134 | if len(cur_sir.statements) == 0: 135 | return self.compile_do_nothing(ret_vals) 136 | # step1: analyse sir inputs and outputs 137 | cur_sir.inputs = cur_sir.analyse_inputs() 138 | # TODO: output analysis 139 | cur_sir.outputs = ret_vals 140 | log(2, "start subgraph compile and execution.\n") 141 | log(2, self.TOS, "\n") 142 | # step2: call compile_sir and get python function, third cache is triggered here. 
143 | static_func = CompileSIRCache()(self, cur_sir.name, **kwargs) 144 | # step3: GC and reset TOS 145 | # self.reset_TOS() 146 | 147 | return static_func, cur_sir 148 | -------------------------------------------------------------------------------- /docs/instructions/CALL_FUNCTION.md: -------------------------------------------------------------------------------- 1 | ## CallFunction 支持 2 | 3 | ### 两个选择 4 | 5 | 1. 使用自己模拟的 Frame 构建, Frame 传参等操作。 6 | 7 | 2. 使用 Eval Frame 的机制,但是构造不同的 eval frame 函数。 8 | 9 | 优缺点对比: 10 | 11 | | 两个CallFunction方案对比 | 优点 | 缺点 | 12 | | ------------- | ------------- | ------- | 13 | | 模拟Frame | 完全自己控制 | 1. 得自己处理Globals、Locals等逻辑
2. 完备性无法保证,魔法函数难以追踪 | 14 | | 复用EvalFrame | 不用处理调用前的准备细节 | 完备性依赖 eval frame 的 3 个模式的完备性 | 15 | 16 | 两个都依赖的: 17 | 18 | 1. 它们都需要区分 `组网模式` 和 `内部callback模式`。 19 | 20 | 21 | ### 尝试:直接复用Python的函数调用: Frame准备逻辑 22 | 23 | 可以直接复用 Frame 的准备逻辑,这个逻辑如果复用,那么我们直接在 OpExecutor 过程中 运行 call function 即可。 24 | 25 | 1. call function 的话,什么时候 func 进行融合,什么时候不进行融合? 26 | 27 | ### 总结:三种 Eval Frame 方案 28 | 29 | | 方案类别 | 模拟执行 | 真实执行 | 30 | | ------------- | ------------- | ------- | 31 | | 组网模式 | Y | Y | 32 | | 内部CallFunction | Y | N | 33 | | 动态图fallback模式 | N | Y | 34 | 35 | 36 | ### 问题:复用EvalFrame方案是否会有调用问题 ? 37 | 38 | 希望在模拟时直接调用 call function 函数,闭包、globals 等是否会有问题。 39 | 40 | ```python 41 | def func(): 42 | bar() 43 | ``` 44 | 45 | 这里的 bar 的 globals 如果和 func 不一样,我们在 eval frame 里面对 bar 进行了调用,是否一定和原来代码执行相同? 46 | 47 | 48 | ### 似乎 eval frame 方案是没问题的 49 | 50 | 任务拆分 -- 代码实现一波: 51 | 52 | 1. eval frame 模式 2 的准备和开发 。 53 | 54 | 2. call function 时进行 return value 的区分 。 55 | 56 | 3. opcode executor 的分类,子类方式来重载某些字节码的实现。比如 return value,组网行为。 57 | 58 | 4. eval frame 如何获取当前是初始frame,还是sub frame,如何给上层的 frame 返回值? 59 | 60 | 5. SIR 组网阶段是分开组网还是共用同一个? 61 | 62 | ### SIR 如何实现 fallback 时回退?对FunctionGraph Copy 一份? 63 | 64 | 65 | 66 | ## 竞品调研 67 | 68 | ### Dynamo 选择了『Frame 模拟操作』 69 | 70 | 1. Dynamo 会将每个 function type 包装为 SkipFileVariable 或者是 UserFunctionVariable 71 | 72 | 2. 如果是一个用户的函数,那么会进行 inline 的链接。即使用了*模拟Frame的调用方式*。 73 | ``` python 74 | def inline_user_function_return(self, fn, args, kwargs): 75 | """ 76 | A call to some user defined function by inlining it. 77 | """ 78 | state = self.copy_graphstate() 79 | try: 80 | result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) 81 | self.output.guards.update(fn.guards) 82 | return result 83 | except Exception: 84 | self.restore_graphstate(state) 85 | raise 86 | ``` 87 | 可以看到上述代码也是进行了Graph的恢复行为,这个其实和我想的差不多。 88 | 89 | 3. 如果中间出现了任何的异常,都会触发当前 Graph 的独立成组。将每个 Function 当做是一个 sub graph 来进行调用。 90 | 91 | 92 | ### 需要考虑的: 93 | 94 | 1.
code 可以是一个 generator 类型 95 | 96 | 2. 传参等操作可以很好地处理吗?globals准备、locals 的准备、闭包的准备。 97 | 98 | ### Call Function 的参数传递:Frame 的准备 99 | 100 | 好像也不复杂。 101 | 102 | 103 | ## 版本适配工作 104 | 105 | CPython 中对 code 的实现有如下的代码:(待补充) 106 | 107 | 108 | ## PaddleSOT 实现 109 | 110 | 目前 PaddleSOT 也是使用 inline call 的方式来实现的。因为基于 eval frame 的实现在参数传递上比较复杂,而模拟只需要处理前面和后面的参数传递部分即可。 111 | 112 | 在 paddle 的实现中,一共有多个组件来实现 inline call 的子图融合,主要包含下面几个部分: 113 | 114 | - InlineExecutor: 包含字节码差异行为的实现 115 | 116 | - FunctionGlobalTracker 的引入 117 | 118 | - Function graph 回滚机制 119 | 120 | - Bytecode Force Fallback 机制 121 | 122 | ### InlineCallExecutor 123 | 124 | 在原来的 OpcodeExecutor的基础上引入了inline call的executor,专门用来对子函数的SIR进行融合。主要的算法流程图如下: 125 | 126 | 如果 father 调用了 child 函数,那么会有如下的流程图: 127 | 128 | ```mermaid 129 | graph TD 130 | A[CALL_FUNCTION调用] --> B[收集参数,fn variable传递给新的InlineExecutor] 131 | B --> C[调用 InlineExecutor::_prepare_env, 准备globals, locals等环境] 132 | C --> D(存储当前的SIR状态) 133 | D --> E[调用 InlineExecutor::run 函数, 进行SIR的融合] 134 | E --> F{run是否可以成功run} 135 | F --> | Y | G[将run返回的Variable放到father的stack作为TOS] 136 | H --> I[触发Non Jump Fallback] 137 | F --> | N | H[回退SIR,FunctionGraph等的状态] 138 | G --> J[继续father函数的下一条字节码] 139 | I --> K[结束] 140 | 141 | ``` 142 | 143 | 144 | 145 | #### 部分指令的区别 146 | 147 | InlineCallExecutor 与 OpcodeExecutor 存在某些指令上的区别。目前是通过类继承重写来实现的差异表示,当前已知的差异有如下几点。 148 | 149 | | ByteCodeName | OpcodeExecutor | InlineCallExecutor | 150 | | ------------- | ------------- | ------- | 151 | | RETURN_VALUE | Y | N | 152 | | JUMP fallback | Y | N | 153 | 154 | 155 | ### FunctionGlobalTracker (Motivation) 156 | 157 | 这个Case可以比较好地让大家理解 Tracker 的主要机制,什么时候需要引入新的 Tracker等。 158 | 159 | 160 | 161 | ### Function graph 回滚机制 (Motivation) 162 | 163 | ### Bytecode Force Fallback (Motivation) 164 | 165 | 166 | ## Tasks and Todos 167 | 168 | - [x] InlineExecutor 169 | 170 | - [x] RETURN_VALUE 区别:组网返回 vs 值返回 171 | 172 | - [x] 出现子图fallback情况: fallback vs SIR回滚 173 | 174 | 175 | - [x] prepare_virtual_env 函数的实现 176 | 177 | -
[x] globals, locals, const 178 | 179 | - [x] prepare_closure 闭包函数的实现 180 | 181 | - [ ] 闭包支持 182 | 183 | 184 | - [x] 额外的 Tracker 185 | 186 | - [x] FunctionGlobalTracker 187 | 188 | 189 | - [ ] SIR save_memo 机制和 restore_memo 机制 190 | -------------------------------------------------------------------------------- /tests/test_instruction_translator_cache.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import random 5 | import types 6 | import unittest 7 | from unittest.mock import patch 8 | 9 | from test_case_base import ( 10 | TestCaseBase, 11 | test_instruction_translator_cache_context, 12 | ) 13 | 14 | from sot.opcode_translator.custom_code import CustomCode 15 | from sot.opcode_translator.executor.executor_cache import OpcodeExecutorCache 16 | 17 | 18 | def fake_frames() -> ( 19 | tuple[ 20 | types.FrameType, 21 | types.FrameType, 22 | types.FrameType, 23 | types.FrameType, 24 | types.FrameType, 25 | ] 26 | ): 27 | def fake_inner_fn_1(): # five separate defs (not one factory) so each returned frame has its own code object 28 | frame = inspect.currentframe() 29 | assert frame is not None 30 | return frame 31 | 32 | def fake_inner_fn_2(): 33 | frame = inspect.currentframe() 34 | assert frame is not None 35 | return frame 36 | 37 | def fake_inner_fn_3(): 38 | frame = inspect.currentframe() 39 | assert frame is not None 40 | return frame 41 | 42 | def fake_inner_fn_4(): 43 | frame = inspect.currentframe() 44 | assert frame is not None 45 | return frame 46 | 47 | def fake_inner_fn_5(): 48 | frame = inspect.currentframe() 49 | assert frame is not None 50 | return frame 51 | 52 | return ( 53 | fake_inner_fn_1(), 54 | fake_inner_fn_2(), 55 | fake_inner_fn_3(), 56 | fake_inner_fn_4(), 57 | fake_inner_fn_5(), 58 | ) 59 | 60 | 61 | ( 62 | FRAME_1, 63 | FRAME_2, 64 | FRAME_3, 65 | FRAME_4, 66 | FRAME_5, 67 | ) = fake_frames() 68 | 69 | 70 | def mock_start_translate(frame: types.FrameType, **kwargs): # patched in for start_translate: maps each fake frame to a (CustomCode, check function) pair 71 | translate_map = { 72 | FRAME_1:
 (CustomCode(FRAME_2.f_code, False), lambda frame: True), 73 | FRAME_3: ( 74 | CustomCode(FRAME_4.f_code, False), 75 | lambda frame: False, 76 | ), # Always re-compile 77 | FRAME_5: (CustomCode(None, False), lambda frame: True), 78 | } 79 | return translate_map[frame] 80 | 81 | 82 | class TestOpcodeExecutorCache(unittest.TestCase): 83 | def reset(self): # NOTE(review): not called by any test in this module, and the translate_count global it sets is never read here (tests read ctx.translate_count) 84 | global translate_count 85 | translate_count = 0 86 | OpcodeExecutorCache().clear() 87 | 88 | @patch( 89 | "sot.opcode_translator.executor.executor_cache.start_translate", 90 | mock_start_translate, 91 | ) 92 | def test_cache_hit(self): 93 | with test_instruction_translator_cache_context() as ctx: 94 | translated_code_1 = OpcodeExecutorCache()(FRAME_1) 95 | assert translated_code_1 is not None 96 | self.assertEqual(translated_code_1.code, FRAME_2.f_code) 97 | self.assertEqual(ctx.translate_count, 1) 98 | # cache hit 99 | translated_code_2 = OpcodeExecutorCache()(FRAME_1) 100 | assert translated_code_2 is not None 101 | self.assertEqual(translated_code_2.code, FRAME_2.f_code) 102 | self.assertEqual(ctx.translate_count, 1) 103 | 104 | @patch( 105 | "sot.opcode_translator.executor.executor_cache.start_translate", 106 | mock_start_translate, 107 | ) 108 | def test_cache_miss_due_to_unknown_code(self): 109 | with test_instruction_translator_cache_context() as ctx: 110 | translated_code_1 = OpcodeExecutorCache()(FRAME_1) 111 | assert translated_code_1 is not None 112 | self.assertEqual(translated_code_1.code, FRAME_2.f_code) 113 | self.assertEqual(ctx.translate_count, 1) 114 | # cache miss 115 | translated_code_2 = OpcodeExecutorCache()(FRAME_3) 116 | assert translated_code_2 is not None 117 | self.assertEqual(translated_code_2.code, FRAME_4.f_code) 118 | self.assertEqual(ctx.translate_count, 2) 119 | 120 | @patch( 121 | "sot.opcode_translator.executor.executor_cache.start_translate", 122 | mock_start_translate, 123 | ) 124 | def test_cache_miss_due_to_check_failed(self): 125 | with
 test_instruction_translator_cache_context() as ctx: 126 | translated_code_1 = OpcodeExecutorCache()(FRAME_3) 127 | assert translated_code_1 is not None 128 | self.assertEqual(translated_code_1.code, FRAME_4.f_code) 129 | self.assertEqual(ctx.translate_count, 1) 130 | # cache miss 131 | translated_code_2 = OpcodeExecutorCache()(FRAME_3) 132 | assert translated_code_2 is not None 133 | self.assertEqual(translated_code_2.code, FRAME_4.f_code) 134 | self.assertEqual(ctx.translate_count, 2) 135 | 136 | 137 | def foo(x): 138 | return x + 1 139 | 140 | 141 | class TestCacheExceedLimit(TestCaseBase): # presumably exercises the translation-cache size limit with 30 random float inputs — confirm against OpcodeExecutorCache 142 | def test_cache_exceed_limit(self): 143 | for _ in range(30): 144 | input = random.random() # NOTE: shadows the builtin input() inside this method 145 | self.assert_results(foo, input) 146 | 147 | 148 | if __name__ == '__main__': 149 | unittest.main() 150 | --------------------------------------------------------------------------------