├── docs ├── design │ ├── opcode_executor.md │ ├── eval_frame.md │ ├── output-restoration.md │ ├── tracker-and-guard.md │ ├── stringify-guard.md │ ├── closure.md │ └── builtin-dispatcher.md ├── compat │ └── python311 │ │ ├── index.md │ │ ├── CALL.md │ │ └── closure.md ├── notes │ └── function-and-method.md ├── profiler_introduction.md └── instructions │ ├── UNPACK_SEQUENCE.md │ └── CALL_FUNCTION.md ├── tests ├── error_test_paddle_cast.py ├── test_tensor_slice.py ├── tests_legacy │ ├── run_all.sh │ ├── test_generator.py │ ├── test_symbolic_for.py │ ├── test_symbolic_fallback.py │ ├── test_symbolic_infer_meta.py │ ├── test_resnet50.py │ ├── error_test_symbolic_load_before_call.py │ ├── test_symbolic_backward.py │ ├── test_simple_net.py │ ├── test_error_cases.py │ ├── test_optransform_cache.py │ ├── test_symbolic_fallback_if.py │ ├── error_test_trace_cache.py │ ├── test_symbolic_trace.py │ ├── error_test_sir_call.py │ ├── test_basic_translation.py │ ├── error_test_trace_cache.py.zhanfei │ ├── test_symbolic_nested.py │ ├── error_test_resnet_with_trace_cache.py │ ├── test_case_base.py │ └── test_resnet50_backward.py ├── test_multiple_args.py ├── test_delete_fast.py ├── error_test_guard.py ├── test_13_make_function.py ├── test_error_handling.py ├── extract_errors.py ├── error_test_resnet50.py ├── test_02_store_inplace.py ├── test_09_f_string.py ├── test_str_format.py ├── error_test_general_bytecode.py ├── test_numpy.py ├── test_dup_top.py ├── test_simulate_initialize.py ├── test_01_basic.py ├── test_constant_graph.py ├── test_numpy_var_if.py ├── error_test_jump_inline.py ├── test_execution_base.py ├── run_all.sh ├── test_stack.py ├── test_07_unpack.py ├── test_16_paddle_api.py ├── test_call_object.py ├── test_resnet.py ├── test_tensor_dtype_in_guard.py ├── test_segment_linear.py ├── test_trace_list_arg.py ├── run_all_paddle_ci.sh ├── test_exception.py ├── test_output_restoration.py ├── test_range.py ├── test_20_string.py ├── test_03_tuple.py ├── test_sir_rollback.py ├── 
test_guard_outputs.py ├── test_17_paddle_layer.py ├── test_18_tensor_method.py ├── test_08_rot.py ├── test_10_build_unpack.py ├── test_guard_user_defined_fn.py ├── test_resnet50_backward.py ├── test_enumerate.py ├── test_cost_model.py ├── test_map.py ├── test_11_jumps.py ├── test_06_call_function.py ├── test_15_slice.py ├── test_inplace_api.py ├── test_break_graph.py ├── test_code_status.py ├── test_21_global.py └── test_instruction_translator_cache.py ├── sot ├── opcode_translator │ ├── executor │ │ ├── __init__.py │ │ ├── instr_flag.py │ │ ├── dispatch_functions.py │ │ ├── variables │ │ │ └── __init__.py │ │ └── tracker_viewer.py │ ├── __init__.py │ ├── custom_code.py │ ├── instruction_utils │ │ ├── __init__.py │ │ └── opcode_info.py │ ├── transform.py │ └── skip_files.py ├── __init__.py ├── psdb.py ├── utils │ ├── __init__.py │ ├── exceptions.py │ ├── code_status.py │ ├── paddle_api_config.py │ └── magic_methods.py ├── profiler.py ├── symbolic │ ├── compile_cache.py │ └── symbolic_context.py └── translate.py ├── .gitignore ├── examples ├── trace_basic.py ├── graph_break.py ├── guard.py └── run_all.sh ├── .github └── workflows │ ├── codestyle.yaml │ ├── unittest.yaml │ └── paddle_ci.yaml ├── README.md ├── .pre-commit-config.yaml └── pyproject.toml /docs/design/opcode_executor.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/error_test_paddle_cast.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import variable_dispatch # noqa: F401 2 | -------------------------------------------------------------------------------- /sot/opcode_translator/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import eval_frame_callback # noqa: F401 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.sw* 3 | user_tag 4 | 5 | # Editor config 6 | .vscode 7 | 8 | # Environments 9 | venv/ 10 | .venv/ 11 | 12 | # Build 13 | build/ 14 | *.egg-info 15 | -------------------------------------------------------------------------------- /sot/opcode_translator/custom_code.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import types 4 | from typing import NamedTuple 5 | 6 | 7 | class CustomCode(NamedTuple): 8 | code: types.CodeType | None 9 | disable_eval_frame: bool 10 | -------------------------------------------------------------------------------- /sot/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import psdb # noqa: F401 2 | from .opcode_translator.breakpoint import ( # noqa: F401 3 | BM, 4 | add_breakpoint, 5 | add_event, 6 | ) 7 | from .opcode_translator.skip_files import skip_function # noqa: F401 8 | from .translate import symbolic_translate # noqa: F401 9 | -------------------------------------------------------------------------------- /docs/compat/python311/index.md: -------------------------------------------------------------------------------- 1 | # Python 3.11 适配 2 | 3 | ## 字段适配 4 | 5 | 指 CodeObject 的字段修改适配,Python 3.11 主要新增了 `co_exceptiontable` 字段,并对 `co_linetable` 进行了一定的修改 6 | 7 | - [`co_linetable`](./co_linetable.md) 8 | 9 | ## 字节码修改适配 10 | 11 | - [CALL 相关字节码](./CALL.md) 12 | 13 | - [closure 相关修改](./closure.md) 14 | -------------------------------------------------------------------------------- /tests/test_tensor_slice.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: paddle.Tensor): 9 | return x[:, 0] 10 | 11 | 12 | class TestExecutor(TestCaseBase): 13 | def test_tensor_slice(self): 14 | x = paddle.randn((10, 10)) 15 | self.assert_results(foo, x) 16 | 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /tests/tests_legacy/run_all.sh: -------------------------------------------------------------------------------- 1 | # 遍历目录下的所有 Python 文件 2 | export PYTHONPATH=$PYTHONPATH:../ 3 | 4 | for file in ./test_*.py; do 5 | # 检查文件是否为 Python 文件 6 | if [ -f "$file" ]; then 7 | echo Running: PYTHONPATH=$PYTHONPATH" python " $file 8 | # 执行文件 9 | python "$file" 10 | if [ $? 
-ne 0 ]; then 11 | echo "run $file failed" 12 | exit 1 13 | fi 14 | fi 15 | done 16 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | 6 | def gen(): 7 | yield from range(10) 8 | 9 | 10 | def case1(): 11 | sum = 0 12 | for i in gen(): 13 | sum += i 14 | return sum 15 | 16 | 17 | class TestGen(TestCaseBase): 18 | def test_gen(self): 19 | self.assert_results(case1) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tests/test_multiple_args.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x, y): 9 | ret = x + y 10 | return ret 11 | 12 | 13 | class TestMultipleArgs(TestCaseBase): 14 | def test_multiple_args(self): 15 | x = paddle.to_tensor([1.0]) 16 | y = paddle.to_tensor([2.0]) 17 | self.assert_results(foo, x, y) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /tests/test_delete_fast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def test_delete_fast(a): 11 | a = a + 2 12 | t = a * 3 13 | del t 14 | return a 15 | 16 | 17 | class TestExecutor(TestCaseBase): 18 | def test_simple(self): 19 | a = paddle.to_tensor(1) 20 | self.assert_results(test_delete_fast, a) 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- 
/sot/opcode_translator/executor/instr_flag.py: -------------------------------------------------------------------------------- 1 | # flags for instructions 2 | 3 | 4 | class FORMAT_VALUE_FLAG: 5 | FVC_MASK = 0x3 6 | FVC_NONE = 0x0 7 | FVC_STR = 0x1 8 | FVC_REPR = 0x2 9 | FVC_ASCII = 0x3 10 | FVS_MASK = 0x4 11 | FVS_HAVE_SPEC = 0x4 12 | 13 | 14 | class MAKE_FUNCTION_FLAG: 15 | MF_HAS_CLOSURE = 0x08 16 | MF_HAS_ANNOTATION = 0x04 17 | MF_HAS_KWDEFAULTS = 0x02 18 | MF_HAS_DEFAULTS = 0x01 19 | 20 | 21 | class CALL_FUNCTION_EX_FLAG: 22 | CFE_HAS_KWARGS = 0x01 23 | -------------------------------------------------------------------------------- /tests/error_test_guard.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def fn(x): 9 | # now do not support `with xx_guard()`, will fallback to dygraph 10 | with paddle.static.amp.fp16_guard(): 11 | out = x + 1 12 | return out 13 | 14 | 15 | class TestGuard(TestCaseBase): 16 | def test_simple(self): 17 | x = paddle.to_tensor(2) 18 | self.assert_results(fn, x) 19 | 20 | 21 | if __name__ == "__main__": 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /sot/opcode_translator/instruction_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .instruction_utils import ( # noqa: F401 2 | Instruction, 3 | calc_offset_from_bytecode_offset, 4 | calc_stack_effect, 5 | convert_instruction, 6 | gen_instr, 7 | get_instructions, 8 | instrs_info, 9 | modify_extended_args, 10 | modify_instrs, 11 | modify_vars, 12 | relocate_jump_target, 13 | replace_instr, 14 | reset_offset, 15 | ) 16 | from .opcode_analysis import ( # noqa: F401 17 | Space, 18 | analysis_inputs, 19 | analysis_used_names_with_space, 20 | ) 21 | -------------------------------------------------------------------------------- 
/tests/test_13_make_function.py: -------------------------------------------------------------------------------- 1 | # MAKE_FUNCTION 2 | # CALL_FUNCTION_KW 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | 11 | 12 | def make_fn(x: paddle.Tensor): 13 | def fn(a, b=2, c=3, d=4): 14 | return a + b + c + d 15 | 16 | return fn(1) + fn(2, c=5) + x 17 | 18 | 19 | class TestExecutor(TestCaseBase): 20 | def test_simple(self): 21 | self.assert_results(make_fn, paddle.to_tensor(1)) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_for.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | for i in range(int(x)): 10 | print("yes") 11 | return x 12 | 13 | 14 | def case2(x): 15 | sum = 0 16 | for i in x: 17 | sum += i 18 | return sum 19 | 20 | 21 | class TestFor(TestCaseBase): 22 | def test(self): 23 | self.assert_results(case1, paddle.to_tensor([4])) 24 | self.assert_results(case2, paddle.to_tensor([4.0, 1.0, 2.0, 3.0])) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/test_error_handling.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import sot 6 | 7 | 8 | def fn_with_try_except(): 9 | sot.psdb.breakgraph() 10 | sot.psdb.fallback() 11 | try: 12 | raise ValueError("ValueError") 13 | except ValueError: 14 | print("catch ValueError") 15 | return True 16 | 17 | 18 | class TestErrorHandling(TestCaseBase): 19 | @strict_mode_guard(0) 20 | def test_fn_with_try_except(self): 21 | 
self.assert_results(fn_with_try_except) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/extract_errors.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | runtime_error_msg = sys.stdin.read() 5 | 6 | pattern = r'File "?(.*?)"?, line (\d+),.*\n(.*?)\n(.*?)$' 7 | for match in re.finditer(pattern, runtime_error_msg, re.MULTILINE): 8 | file = match.group(1) 9 | if file.startswith("./"): 10 | file = f"tests/{file[2:]}" 11 | line = match.group(2) 12 | error_info = match.group(4) 13 | if "AssertionError" not in error_info: 14 | # error_info = match.group(3) + '\n' + match.group(4) 15 | output = f"::error file={file},line={line}::Error" 16 | print(output) 17 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_fallback.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | kk = "sdfsdf" 10 | y = 2 + x 11 | ret = paddle.nn.functional.relu(y) 12 | print(y) 13 | print(y.numpy()) 14 | print("yes") 15 | print("no") 16 | # for i in range(10): 17 | ret = ret + 2 + x 18 | return ret 19 | 20 | 21 | class TestFallback(TestCaseBase): 22 | def test_bool(self): 23 | x = paddle.to_tensor([1.0]) 24 | self.assert_results(case1, x) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_infer_meta.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from sot import symbolic_translate 7 | 8 | 9 | def case1(x): 10 | y = 2 + x 11 | y = y + 1 12 | ret = 
paddle.nn.functional.relu(y) 13 | 14 | assert list(y.meta.shape) == [2] 15 | z = paddle.concat([x, y]) 16 | assert list(z.meta.shape) == [4] 17 | 18 | 19 | class TestIf(TestCaseBase): 20 | def test_if_1(self): 21 | x = paddle.to_tensor([1.0, 2.0]) 22 | symbolic_translate(case1)(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tests/error_test_resnet50.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.vision import resnet50 7 | 8 | 9 | class SimpleNet(paddle.nn.Layer): 10 | def __init__(self): 11 | super().__init__() 12 | self.relu = paddle.nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(x) 16 | return x 17 | 18 | 19 | class TestNet(TestCaseBase): 20 | def test(self): 21 | image = paddle.rand((1, 3, 255, 255)) 22 | net = resnet50() 23 | self.assert_results(net, image) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_02_store_inplace.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: int, y: paddle.Tensor): 9 | x = x + 1 10 | y = y + 1 11 | x += y 12 | return x 13 | 14 | 15 | class TestStoreInplace(TestCaseBase): 16 | def test_simple(self): 17 | self.assert_results(foo, 1, paddle.to_tensor(2)) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | 23 | 24 | # Instructions: 25 | # LOAD_FAST 26 | # BINARY_ADD 27 | # STORE_FAST (new) 28 | # INPLACE_ADD (new) 29 | # RETURN_VALUE 30 | 31 | # Variables: 32 | # ConstantVariable 33 | # TensorVariable 34 | -------------------------------------------------------------------------------- 
/tests/tests_legacy/test_resnet50.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.vision import resnet50 7 | 8 | 9 | class SimpleNet(paddle.nn.Layer): 10 | def __init__(self): 11 | super().__init__() 12 | self.relu = paddle.nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(x) 16 | return x 17 | 18 | 19 | class TestNet(TestCaseBase): 20 | def test(self): 21 | image = paddle.rand((1, 3, 255, 255)) 22 | net = resnet50() 23 | self.assert_results(net, image) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_09_f_string.py: -------------------------------------------------------------------------------- 1 | # FORMAT_VALUE (new) 2 | # BUILD_STRING (new) 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | from sot.psdb import assert_true 11 | 12 | 13 | def foo(x: paddle.Tensor): 14 | whilespace = 123 15 | hello_world = f"Hello {whilespace} World" 16 | z = assert_true(hello_world == "Hello 123 World") 17 | x = x + 1 18 | return x 19 | 20 | 21 | class TestFString(TestCaseBase): 22 | def test_fstring(self): 23 | self.assert_results(foo, paddle.to_tensor(1)) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_str_format.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | 8 | # copy from python library _distutils_hack/__init__.py 9 | def find_spec(self, fullname, path, target=None): 10 | method_name = 'spec_for_{fullname}'.format( 11 | **{'self': self, 'fullname': fullname} 12 
| ) 13 | method = getattr(self, method_name, lambda: None) 14 | return method() 15 | 16 | 17 | class TestExecutor(TestCaseBase): 18 | def test_simple(self): 19 | self.assert_results(find_spec, "self", "fullname", "path", None) 20 | 21 | 22 | if __name__ == "__main__": 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /examples/trace_basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import paddle 4 | from sot.translate import symbolic_translate 5 | 6 | 7 | def foo(x: paddle.Tensor, y: paddle.Tensor): 8 | z = x + y 9 | return z + 1 10 | 11 | 12 | def main(): 13 | x = paddle.rand([2, 3]) 14 | y = paddle.rand([2, 3]) 15 | dygraph_out = foo(x, y) 16 | symbolic_translate_out = symbolic_translate(foo)(x, y) 17 | 18 | print("dygraph_out:", dygraph_out) 19 | print("symbolic_translate_out:", symbolic_translate_out) 20 | np.testing.assert_allclose( 21 | dygraph_out.numpy(), symbolic_translate_out.numpy() 22 | ) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_symbolic_load_before_call.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | # TODO(SigureMo): Fix this test case 8 | 9 | 10 | def nest_func_3(x, y): 11 | a = x + y 12 | b = x * a 13 | z = b + nest_func_2(a) 14 | return z 15 | 16 | 17 | def nest_func_2(z): 18 | print(z) 19 | return z 20 | 21 | 22 | class Test(TestCaseBase): 23 | def test(self): 24 | self.assert_results( 25 | nest_func_3, 26 | paddle.to_tensor([1, 1, 1, 1]), 27 | paddle.to_tensor([2, 3, 4, 5]), 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | unittest.main() 33 | -------------------------------------------------------------------------------- 
/tests/tests_legacy/test_symbolic_backward.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from test_case_base import TestCaseBase 5 | 6 | import paddle 7 | from sot import symbolic_translate 8 | 9 | 10 | def func(x, y): 11 | return x * y 12 | 13 | 14 | class TestNet(TestCaseBase): 15 | def test(self): 16 | x = paddle.to_tensor([5, 3]) 17 | y = paddle.to_tensor([1, 3]) 18 | x.stop_gradient = False 19 | self.assert_results(func, x, y) 20 | 21 | ret = symbolic_translate(func)(x, y) 22 | ret.backward() 23 | np.testing.assert_allclose(x.grad.numpy(), [1.0, 3.0]) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /.github/workflows/codestyle.yaml: -------------------------------------------------------------------------------- 1 | name: CodeStyle Check 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | runs-on: ubuntu-latest 12 | name: CodeStyle Check 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v3 16 | 17 | - name: Install python 18 | uses: actions/setup-python@v4 19 | with: 20 | # Run pre-commit on the lowest supported python version 21 | python-version: '3.8' 22 | 23 | - name: Install dependencies 24 | run: | 25 | pip install pre-commit 26 | - name: Precommit Check 27 | run : pre-commit run --all-files 28 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_simple_net.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | class SimpleNet(paddle.nn.Layer): 9 | def __init__(self): 10 | super().__init__() 11 | self.linear1 = paddle.nn.Linear(10, 3) 12 | self.linear2 = paddle.nn.Linear(3, 1) 13 | 14 | def forward(self, x): 
15 | out1 = self.linear1(x) 16 | out2 = self.linear2(out1) 17 | return out1 + out2 18 | 19 | 20 | class TestNet(TestCaseBase): 21 | def test(self): 22 | inp = paddle.rand((10,)) 23 | net = SimpleNet() 24 | self.assert_results(net, inp) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/error_test_general_bytecode.py: -------------------------------------------------------------------------------- 1 | import dis 2 | import unittest 3 | 4 | from test_case_base import TestCaseBase 5 | 6 | 7 | def func(): 8 | return True is True is not False 9 | 10 | 11 | def func2(x): 12 | # TODO(@xiaojian): SIR not used by output. 13 | y = x + 1 14 | return True is True is not False 15 | 16 | 17 | def func3(): 18 | i = 0 19 | 20 | def inner(): 21 | return i + 1 22 | 23 | return inner() 24 | 25 | 26 | class TestExecutor(TestCaseBase): 27 | def test_simple(self): 28 | self.assert_results(func3) 29 | # self.assert_results(func2, paddle.to_tensor(1.0)) 30 | 31 | 32 | dis.dis(func3) 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_error_cases.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case_list_comp(x): 9 | def add_one(x): 10 | return x + 1 11 | 12 | ret = [add_one(x) for x in x] 13 | return ret 14 | 15 | 16 | def case_map_structure(x): 17 | def add_one(x): 18 | return x + 1 19 | 20 | ret = paddle.utils.map_structure(add_one, x) 21 | return ret 22 | 23 | 24 | class Test(TestCaseBase): 25 | def test(self): 26 | self.assert_results(case_list_comp, paddle.to_tensor([1, 2, 4, 3])) 27 | self.assert_results(case_map_structure, paddle.to_tensor([1, 2, 4, 3])) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | 
-------------------------------------------------------------------------------- /tests/tests_legacy/test_optransform_cache.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from sot import symbolic_translate 7 | from sot.opcode_translator.instruction_translator import OpcodeExecutorCache 8 | 9 | 10 | def case1(x): 11 | for i in range(int(x)): 12 | print("yes") 13 | return x 14 | 15 | 16 | class TestFor(TestCaseBase): 17 | def test(self): 18 | symbolic_translate(case1)(paddle.to_tensor([4])) 19 | symbolic_translate(case1)(paddle.to_tensor([4])) 20 | symbolic_translate(case1)(paddle.to_tensor([4])) 21 | assert OpcodeExecutorCache().hit_num == 2, "cache hit num should be 2" 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /examples/graph_break.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import paddle 4 | from sot.translate import symbolic_translate 5 | 6 | 7 | def foo(cond: paddle.Tensor, x: paddle.Tensor): 8 | x += 1 9 | if cond: 10 | x += 1 11 | else: 12 | x -= 1 13 | return x 14 | 15 | 16 | def main(): 17 | cond = paddle.to_tensor(True) 18 | x = paddle.to_tensor(0) 19 | dygraph_out = foo(cond, x) 20 | symbolic_translate_out = symbolic_translate(foo)(cond, x) 21 | 22 | print("dygraph_out:", dygraph_out) 23 | print("symbolic_translate_out:", symbolic_translate_out) 24 | np.testing.assert_allclose( 25 | dygraph_out.numpy(), symbolic_translate_out.numpy() 26 | ) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_fallback_if.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from 
test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(cond, x): 9 | cond = 2 + cond 10 | if cond: 11 | print("yes") 12 | x = x + 4 13 | else: 14 | print("no") 15 | x = x - 4 16 | ret = paddle.nn.functional.relu(x) 17 | return ret 18 | 19 | 20 | class TestIf(TestCaseBase): 21 | def test_if_1(self): 22 | self.assert_results( 23 | case1, paddle.to_tensor([4.0]), paddle.to_tensor([4.0]) 24 | ) 25 | self.assert_results( 26 | case1, paddle.to_tensor([-2.0]), paddle.to_tensor([4.0]) 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/test_numpy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from test_case_base import TestCaseBase, strict_mode_guard 5 | 6 | import paddle 7 | 8 | 9 | def foo(x, y): 10 | ret = x + y 11 | return ret 12 | 13 | 14 | class TestNumpy(TestCaseBase): 15 | def test_tensor_add_numpy_number(self): 16 | x = paddle.to_tensor([1.0]) 17 | y = np.int64(2) 18 | self.assert_results(foo, x, y) 19 | self.assert_results(foo, y, x) 20 | 21 | @strict_mode_guard(0) 22 | def test_tensor_add_numpy_array(self): 23 | x = paddle.to_tensor([1.0]) 24 | y = np.array(2.0) 25 | self.assert_results(foo, x, y) 26 | self.assert_results(foo, y, x) 27 | 28 | 29 | if __name__ == "__main__": 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /examples/guard.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from sot.translate import symbolic_translate 3 | 4 | 5 | def foo(x: paddle.Tensor, y: paddle.Tensor, z: paddle.Tensor): 6 | a = x + y 7 | a *= x 8 | return a, z 9 | 10 | 11 | def main(): 12 | a = paddle.rand([1]) 13 | b = paddle.rand([2, 3]) 14 | c = paddle.rand([4]) 15 | d = paddle.rand([5, 6]) 16 | e = paddle.rand([]) 17 | sym_foo = 
symbolic_translate(foo) 18 | dygraph_out = foo(a, b, c) 19 | symbolic_translate_out = sym_foo(a, b, c) 20 | 21 | print("dygraph_out:", dygraph_out) 22 | print("symbolic_translate_out:", symbolic_translate_out) 23 | 24 | # cache hit 25 | sym_foo(a, b, d) 26 | 27 | # cache miss 28 | sym_foo(e, b, c) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /examples/run_all.sh: -------------------------------------------------------------------------------- 1 | # 遍历目录下的所有 Python 文件 2 | export PYTHONPATH=$PYTHONPATH:../ 3 | export STRICT_MODE=1 4 | export MIN_GRAPH_SIZE=-1 5 | 6 | for file in ./*.py; do 7 | # 检查文件是否为 Python 文件 8 | if [ -f "$file" ]; then 9 | if [[ -n "$GITHUB_ACTIONS" ]]; then 10 | echo ::group::example Running: LOG_LEVEL=3 PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 11 | else 12 | echo Running: LOG_LEVEL=3 PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 13 | fi 14 | # 执行文件 15 | python "$file" 16 | if [ $? 
-ne 0 ]; then 17 | echo "run $file failed" 18 | exit 1 19 | fi 20 | if [[ -n "$GITHUB_ACTIONS" ]]; then 21 | echo "::endgroup::" 22 | fi 23 | fi 24 | done 25 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_trace_cache.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.proxy_tensor import cache_and_return, frame_enter, frame_leave 6 | 7 | 8 | class A: 9 | def __init__(self, x): 10 | self.x = x 11 | 12 | 13 | def sum_2(l): 14 | if frame_enter("func1", (l)): 15 | print("hit cache") 16 | return cache_and_return("func1", (l)) 17 | ret = l[0] + l[1] 18 | frame_leave(ret) 19 | return ret 20 | 21 | 22 | class TestCaseName(unittest.TestCase): 23 | def test_return_callable(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | ret = symbolic_translate(sum_2)([x, y]) 27 | print(ret) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_symbolic_trace.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def case1(x): 9 | y = 2 + x 10 | ret = paddle.nn.functional.relu(y) 11 | print("yes") 12 | print("no") 13 | # for i in range(10): 14 | ret = ret + 2 + x 15 | return ret 16 | 17 | 18 | def case2(x): 19 | y = x + 2 20 | ret = paddle.nn.functional.relu(y) 21 | for i in range(10): 22 | ret = ret + 2 + x 23 | return ret 24 | 25 | 26 | class TestIf(TestCaseBase): 27 | def test_if_1(self): 28 | x = paddle.to_tensor([1.0]) 29 | self.assert_results(case1, paddle.to_tensor([1.0])) 30 | self.assert_results(case2, paddle.to_tensor([1.0])) 31 | 32 | 33 | if __name__ == "__main__": 34 | unittest.main() 35 | 
-------------------------------------------------------------------------------- /tests/tests_legacy/error_test_sir_call.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.trace_cache_entrance import cache_and_return, frame_enter, frame_leave 6 | 7 | 8 | def sum(x, y): 9 | if frame_enter("sum", (x, y)): 10 | return cache_and_return("sum", (x, y)) 11 | ret = x + y 12 | frame_leave("sum", (ret)) 13 | return ret 14 | 15 | 16 | def main(x, y): 17 | ret = sum(x, x) 18 | ret2 = sum(x, y) 19 | return ret2 20 | 21 | 22 | class TestCaseName(unittest.TestCase): 23 | def test_return_callable(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | ret = symbolic_translate(main)(x, y) 27 | assert ret.item() == 3.0, "Should be 4.0" 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/test_dup_top.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def func_dup_top_1(): 11 | return True == True != False 12 | 13 | 14 | def func_dup_top_2(x): 15 | y = x + 1 16 | return True == True != False 17 | 18 | 19 | def func_dup_top_two(x: list[paddle.Tensor]): 20 | x[0] += x[1] 21 | return x 22 | 23 | 24 | class TestDupTop(TestCaseBase): 25 | def test_dup_top(self): 26 | self.assert_results(func_dup_top_1) 27 | self.assert_results(func_dup_top_2, paddle.to_tensor(1.0)) 28 | # TODO: fix this after we support side effect 29 | # self.assert_results( 30 | # func_dup_top_two, [paddle.to_tensor(1.0), paddle.to_tensor(2.0)] 31 | # ) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | 
-------------------------------------------------------------------------------- /tests/test_simulate_initialize.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle import nn 7 | from sot import symbolic_translate 8 | 9 | 10 | class A: 11 | def __init__(self, vals): 12 | vals.append(1) 13 | 14 | 15 | def foo(x, y): 16 | out = nn.Softmax()(paddle.to_tensor([x, y], dtype="float32")) 17 | return out 18 | 19 | 20 | def bar(x): 21 | a = A(x) 22 | t = paddle.to_tensor(x) 23 | return t.mean() 24 | 25 | 26 | class TestInit(TestCaseBase): 27 | def test_init_paddle_layer(self): 28 | self.assert_results(foo, 1, 2) 29 | 30 | def test_init_python_object(self): 31 | sot_output = symbolic_translate(bar)([1.0, 2.0]) 32 | dyn_output = bar([1.0, 2.0]) 33 | self.assert_nest_match(sot_output, dyn_output) 34 | 35 | 36 | if __name__ == "__main__": 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /docs/notes/function-and-method.md: -------------------------------------------------------------------------------- 1 | # 函数和方法 2 | 3 | ## method 与 bind 4 | 5 | 什么是 function 就不必说了,这里说明一下 method 和 function 的区别 6 | 7 | 简单来说,method 就是 function 的第一个位置 bind 一个 object(即 self),其行为上有点类似用 partial 绑定了第一个参数的 function 8 | 9 | ```python 10 | class A: 11 | def b(self, x): 12 | return x 13 | 14 | a = A() 15 | a.b(1) 16 | 17 | # equivalent to 18 | A.b(a, 1) 19 | 20 | # it is like 21 | partial_b = partial(A.b, a) # bind a to the first position 22 | partial_b(1) 23 | ``` 24 | 25 | 注意 `a.b(1)` 和 `A.b(a, 1)` 是完全等价的,这里 `A.b` 是一个 function,也称 unbound method,而 `a.b` 则是一个 bound method,根据名字也能看出来,就是 bind 了 self 的 method 26 | 27 | ## method 与 descriptor 28 | 29 | method 实际上是利用了 descriptor,比如 30 | 31 | ```python 32 | A.b.__get__(a, A) 33 | 34 | # equivalent to 35 | a.b 36 | ``` 37 | 38 | 由于 `A.b` 是一个 descriptor,因此在其实例获取属性 `b` 
时,自然会调用 `A.b.__get__`,在此时便会将原来的 function 和 object 绑定在一起,获得一个 bound method 39 | -------------------------------------------------------------------------------- /tests/test_01_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase, strict_mode_guard 4 | 5 | import paddle 6 | 7 | 8 | def foo(x: int, y: paddle.Tensor): 9 | return x + y 10 | 11 | 12 | class TestExecutor(TestCaseBase): 13 | def test_simple(self): 14 | self.assert_results(foo, 1, paddle.to_tensor(2)) 15 | 16 | 17 | def numpy_add(x, y): 18 | out = paddle.to_tensor(x.numpy() + y.numpy()) 19 | return out 20 | 21 | 22 | class TestNumpyAdd(TestCaseBase): 23 | @strict_mode_guard(0) 24 | def test_numpy_add(self): 25 | x = paddle.to_tensor([2]) 26 | y = paddle.to_tensor([3]) 27 | self.assert_results(numpy_add, x, y) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | 33 | 34 | # Instructions: 35 | # LOAD_FAST 36 | # BINARY_ADD 37 | # RETURN_VALUE 38 | 39 | # Variables: 40 | # ConstantVariable 41 | # TensorVariable 42 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/dispatch_functions.py: -------------------------------------------------------------------------------- 1 | # This file stores the custom functions that will be called by the dispatch mechanism. 
2 | 3 | from ...utils import BreakGraphError, FallbackError 4 | 5 | 6 | def raise_break_graph_fn(*args, **kwarg): 7 | raise BreakGraphError("raise by raise_break_graph_fn.") 8 | 9 | 10 | def raise_not_implement_fn(*args, **kwarg): 11 | raise FallbackError("raise by raise_break_graph_fn.") 12 | 13 | 14 | # just a function for operator.in 15 | def operator_in(left, right): 16 | return left in right 17 | 18 | 19 | def operator_not_in(left, right): 20 | return left not in right 21 | 22 | 23 | def operator_exception_match(left, right): 24 | pass 25 | 26 | 27 | def operator_BAD(left, right): 28 | pass 29 | 30 | 31 | def operator_is_none(val): 32 | pass 33 | 34 | 35 | def operator_is_not_none(val): 36 | pass 37 | 38 | 39 | def tensor_numel(x): 40 | pass 41 | -------------------------------------------------------------------------------- /docs/design/eval_frame.md: -------------------------------------------------------------------------------- 1 | # Eval Frame 设计和实现 2 | 3 | ## 疑难问题定位和分析 4 | 5 | ### Generator Early Return Causes `System Error` 6 | 7 | 待完善,可以见如流知识库 8 | 9 | ### Resnet 支持后出现 segment fault 10 | 11 | - [x] 果然是 stacksize 的问题 12 | 13 | 初步定位,发现第一个Step的翻译过程已经走完,出现问题的环节应该是 `eval frame` 中的第二步。 14 | 15 | #### 如何查找线索? 16 | 17 | 面临的第一个问题是,如何寻找到定位问题的线索。eval frame 中没有对应的 log 可以查看,所以我们很难定位哪里出现了问题,但是因为 Python 是独立编译过的,所以可以很方便地在 Python 中打印出我们需要的变量。 18 | 19 | 所以我们有如下两个探索方式: 20 | 21 | 1. 在 Eval frame 中进行日志探索。 22 | 23 | 2. 在 Python 源码中插入信息搜集的点。 24 | 25 | 理论上,有了上述两个方法,我们可以找到一切问题,包括 segment fault 错误。 26 | 27 | #### 定位 segment fault 位置 28 | 29 | segment fault 问题的首要任务就是找到错误的地点,然后才可以逐渐分析出错误原因。 30 | 31 | **没有跑到动转静组网**: 首先需要明确,segment fault 位置在 `CALL_FUNCTION` 字节码之前。 32 | 33 | ```python 34 | import os 35 | os.main 36 | ``` 37 | 38 | int main 39 | 40 | 41 | #### 问题猜测 42 | 43 | - [x] 果然是这个stacksize的问题!! 
44 | 45 | 这个问题是否与 eval frame 中的 stack size 有关系,因为stack size不够大,导致栈溢出了。这样的解释也是比较合理的。 46 | -------------------------------------------------------------------------------- /tests/test_constant_graph.py: -------------------------------------------------------------------------------- 1 | # New Supported Instructions: 2 | # BUILD_MAP (new) 3 | # BUILD_CONST_KEY_MAP (new) 4 | 5 | import unittest 6 | 7 | from test_case_base import TestCaseBase 8 | 9 | import paddle 10 | 11 | 12 | def func_1(format_str, tensor): 13 | str = format_str.format(xx=12) 14 | a = "{xx} = 12".format 15 | ttt = f"{10} = 12" 16 | a(xx=12) 17 | tensor = tensor + 1 18 | return str, tensor 19 | 20 | 21 | def func_2(format_str, tensor): 22 | str = format_str % 10 23 | tensor = tensor + 1 24 | return str, tensor 25 | 26 | 27 | class TestConstantGraph(TestCaseBase): 28 | def test_case_1(self): 29 | x = "{xx} is xx" 30 | tensor = paddle.to_tensor(1) 31 | self.assert_results(func_1, x, tensor) 32 | 33 | def test_case_2(self): 34 | x = "%s is xx" 35 | tensor = paddle.to_tensor(1) 36 | self.assert_results(func_2, x, tensor) 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/test_numpy_var_if.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | from sot.psdb import check_no_breakgraph, check_no_fallback 9 | 10 | os.environ['MIN_GRAPH_SIZE'] = '-1' 11 | 12 | 13 | @check_no_breakgraph 14 | @check_no_fallback 15 | def forward(x, y): 16 | if x == 0: 17 | return y + 2 18 | else: 19 | return y * 2 20 | 21 | 22 | @check_no_breakgraph 23 | @check_no_fallback 24 | def forward2(x, y): 25 | if x == x: # numpy == numpy 26 | return y + 2 27 | else: 28 | return y * 2 29 | 30 | 31 | class TestJumpWithNumpy(TestCaseBase): 32 | def test_jump(self): 33 | 
self.assert_results(forward, np.array([1]), paddle.to_tensor(2)) 34 | self.assert_results(forward, np.array([0]), paddle.to_tensor(2)) 35 | self.assert_results(forward2, np.array([0]), paddle.to_tensor(2)) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PaddleSOT 2 | 3 | **Paddle** **S**ymbolic **O**pcode **T**ranslator. 4 | 5 | PaddleSOT 是一个 Opcode-Based 的动转静孵化项目,借助 Symbolic Opcode Translator(简称:SOT)在运行时将 PaddlePaddle 动态图组网代码转换为静态图组网代码,具体设计参见:[PaddleSOT 项目介绍](https://github.com/PaddlePaddle/community/tree/master/pfcc/paddle-code-reading/symbolic_opcode_translator) 6 | 7 | > **Note** 8 | > 9 | > PaddleSOT 代码目前已经迁移到 Paddle 主框架 repo,请前往 [`Paddle/python/paddle/jit/sot`](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/jit/sot) 访问最新代码 10 | 11 | ## Install 12 | 13 | ```bash 14 | git clone https://github.com/PaddlePaddle/PaddleSOT.git 15 | cd PaddleSOT/ 16 | pip install -e . 
17 | ``` 18 | 19 | 此外由于我们有部分特性依赖于最新的 PaddlePaddle,因此你需要安装 Nightly build 版本的 PaddlePaddle,你可以在[官网安装页面](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html)根据自己的平台找到相应的安装方式 20 | 21 | ## Usage 22 | 23 | 你可以通过运行 `examples/` 下的示例来了解 PaddleSOT 的使用方法。 24 | 25 | ```bash 26 | python examples/trace_basic.py 27 | ``` 28 | 29 | ## Contributing 30 | 31 | 请参考 [PaddleSOT 贡献指南](./CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /tests/error_test_jump_inline.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | 9 | 10 | def pop_jump_if_false(x: bool, y: paddle.Tensor): 11 | if x: 12 | y += 1 13 | else: 14 | y -= 1 15 | return y 16 | 17 | 18 | def outer_function(x, y): 19 | m = y + 2 20 | ret = pop_jump_if_false(x, m) 21 | ret = ret * 2 22 | return ret 23 | 24 | 25 | a = paddle.to_tensor(1) 26 | b = paddle.to_tensor(2) 27 | c = paddle.to_tensor(3) 28 | d = paddle.to_tensor(4) 29 | 30 | 31 | true_tensor = paddle.to_tensor(True) 32 | false_tensor = paddle.to_tensor(False) 33 | 34 | 35 | class TestExecutor(TestCaseBase): 36 | def test_simple(self): 37 | self.assert_results(outer_function, True, a) 38 | self.assert_results(outer_function, False, a) 39 | 40 | def test_fallback(self): 41 | self.assert_results(outer_function, true_tensor, a) 42 | self.assert_results(outer_function, false_tensor, a) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/test_execution_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.static import BuildStrategy 7 | from sot import 
symbolic_translate 8 | 9 | 10 | def func(x, y): 11 | ret = 2 * x 12 | ret = paddle.nn.functional.relu(ret) 13 | ret = ret + y 14 | return ret 15 | 16 | 17 | def simple(x): 18 | ret = 2 * x 19 | return ret 20 | 21 | 22 | class TestExecutor(TestCaseBase): 23 | def test_simple(self): 24 | x = paddle.to_tensor([1.0]) 25 | y = paddle.to_tensor([2.0]) 26 | self.assert_results(simple, x) 27 | self.assert_results(simple, y) 28 | 29 | 30 | def foo(x): 31 | out = x + 1 32 | out = out * 2 33 | out = paddle.nn.functional.relu(out) 34 | return out 35 | 36 | 37 | class TestBackend(TestCaseBase): 38 | def test_backend(self): 39 | x = paddle.randn([2, 3]) 40 | dy_out = foo(x) 41 | sot_out = symbolic_translate( 42 | foo, build_strategy=BuildStrategy(), backend='CINN' 43 | )(x) 44 | self.assert_nest_match(dy_out, sot_out) 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Exclude all third-party libraries and auto-generated files globally 2 | repos: 3 | # Common hooks 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-merge-conflict 9 | - id: check-symlinks 10 | - id: detect-private-key 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - repo: https://github.com/Lucas-C/pre-commit-hooks.git 14 | rev: v1.5.1 15 | hooks: 16 | - id: remove-crlf 17 | - id: remove-tabs 18 | name: Tabs remover (Python) 19 | files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 20 | args: [--whitespaces-count, '4'] 21 | # For Python files 22 | - repo: https://github.com/psf/black.git 23 | rev: 23.3.0 24 | hooks: 25 | - id: black 26 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 27 | - repo: https://github.com/pycqa/isort 28 | rev: 5.11.5 29 | hooks: 30 | - id: isort 31 | - repo: 
def print(*args, **kwargs):
    """Forward to the builtin ``print`` with ``"[Dygraph]"`` as the first item.

    Positional arguments follow the tag; keyword arguments (``sep``, ``end``,
    ``file``, ``flush``) are passed through to ``builtins.print`` untouched.
    """
    tagged_args = ("[Dygraph]", *args)
    builtins.print(*tagged_args, **kwargs)
STRICT_MODE=1 python " $file 14 | else 15 | echo Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file 16 | fi 17 | # 执行文件 18 | python_output=$(python $file 2>&1) 19 | 20 | if [ $? -ne 0 ]; then 21 | echo "run $file failed" 22 | failed_tests+=("$file") 23 | if [[ -n "$GITHUB_ACTIONS" ]]; then 24 | echo -e "$python_output" | python ./extract_errors.py 25 | fi 26 | echo -e "$python_output" 27 | fi 28 | if [[ -n "$GITHUB_ACTIONS" ]]; then 29 | echo "::endgroup::" 30 | fi 31 | fi 32 | done 33 | 34 | if [ ${#failed_tests[@]} -ne 0 ]; then 35 | echo "failed tests file:" 36 | for failed_test in "${failed_tests[@]}"; do 37 | echo "$failed_test" 38 | done 39 | exit 1 40 | fi 41 | -------------------------------------------------------------------------------- /tests/tests_legacy/test_basic_translation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import paddle 4 | from sot import symbolic_translate 5 | from sot.proxy_tensor import ProxyTensor 6 | 7 | 8 | def _ret_func(): 9 | def inner(): 10 | print("inner called") 11 | return paddle.to_tensor([1, 2, 3]) 12 | 13 | return inner 14 | 15 | 16 | def return_callable(): 17 | retval = _ret_func()() 18 | assert isinstance(retval, ProxyTensor) 19 | 20 | 21 | def _ret_tuple(): 22 | print("i am called") 23 | return 1, 2, paddle.to_tensor([1, 2, 3]) 24 | 25 | 26 | def return_tuple(): 27 | a, b, c = _ret_tuple() 28 | print(a, b, c) 29 | assert isinstance(c, ProxyTensor) 30 | 31 | 32 | def val_in_container(): 33 | mylist = [0, [_ret_tuple, 0], 1, 2, 3] 34 | a, b, c = mylist[1][0]() 35 | assert isinstance(c, ProxyTensor) 36 | 37 | 38 | class TestCaseName(unittest.TestCase): 39 | def test_return_callable(self): 40 | symbolic_translate(return_callable)() 41 | 42 | def test_return_tuple(self): 43 | symbolic_translate(return_tuple)() 44 | 45 | def test_val_in_container(self): 46 | symbolic_translate(val_in_container)() 47 | 48 | 49 | if __name__ == "__main__": 50 | 
unittest.main() 51 | -------------------------------------------------------------------------------- /sot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .code_status import CodeStatus # noqa: F401 2 | from .exceptions import ( # noqa: F401 3 | BreakGraphError, 4 | FallbackError, 5 | InnerError, 6 | inner_error_default_handler, 7 | ) 8 | from .magic_methods import magic_method_builtin_dispatch # noqa: F401 9 | from .paddle_api_config import ( # noqa: F401 10 | is_break_graph_tensor_methods, 11 | is_inplace_api, 12 | paddle_tensor_methods, 13 | ) 14 | from .utils import ( # noqa: F401 15 | Cache, 16 | GraphLogger, 17 | NameGenerator, 18 | OrderedSet, 19 | ResumeFnNameFactory, 20 | Singleton, 21 | SotUndefinedVar, 22 | StepInfoManager, 23 | StepState, 24 | cost_model, 25 | count_if, 26 | current_tmp_name_records, 27 | execute_time, 28 | flatten_extend, 29 | get_unbound_method, 30 | hashable, 31 | in_paddle_module, 32 | is_break_graph_api, 33 | is_builtin_fn, 34 | is_clean_code, 35 | is_paddle_api, 36 | is_strict_mode, 37 | list_contain_by_id, 38 | list_find_index_by_id, 39 | log, 40 | log_do, 41 | map_if, 42 | map_if_extend, 43 | meta_str, 44 | min_graph_size, 45 | no_eval_frame, 46 | show_trackers, 47 | tmp_name_guard, 48 | ) 49 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: UnitTest 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ['3.8', '3.9', '3.10', '3.11'] 16 | runs-on: ${{ matrix.os }} 17 | name: python ${{ matrix.python-version }} unittests 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Install python 23 | uses: actions/setup-python@v4 24 | with: 
25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html 30 | pip install -e ".[dev]" 31 | 32 | - name: Run unit tests 33 | working-directory: ./tests/ 34 | run: | 35 | bash run_all.sh 36 | 37 | - name: Run examples 38 | working-directory: ./examples/ 39 | run: | 40 | bash run_all.sh 41 | 42 | - name: Run xdoctest 43 | working-directory: ./ 44 | run: | 45 | xdoctest sot 46 | -------------------------------------------------------------------------------- /sot/opcode_translator/executor/variables/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( # noqa: F401 2 | ConstTypes, 3 | VariableBase, 4 | VariableFactory, 5 | find_traceable_vars, 6 | map_variables, 7 | ) 8 | from .basic import ( # noqa: F401 9 | CellVariable, 10 | ConstantVariable, 11 | DataVariable, 12 | DygraphTracerVariable, 13 | FunctionGlobalVariable, 14 | GlobalVariable, 15 | ModuleVariable, 16 | NullVariable, 17 | NumpyVariable, 18 | ObjectVariable, 19 | SliceVariable, 20 | TensorVariable, 21 | ) 22 | from .callable import ( # noqa: F401 23 | BuiltinVariable, 24 | CallableVariable, 25 | ClassVariable, 26 | ContainerLayerVariable, 27 | FunctionVariable, 28 | LayerVariable, 29 | MethodVariable, 30 | PaddleApiVariable, 31 | PaddleLayerVariable, 32 | UserDefinedFunctionVariable, 33 | UserDefinedGeneratorVariable, 34 | UserDefinedLayerVariable, 35 | ) 36 | from .container import ( # noqa: F401 37 | ContainerVariable, 38 | DictVariable, 39 | ListVariable, 40 | RangeVariable, 41 | TupleVariable, 42 | ) 43 | from .iter import ( # noqa: F401 44 | EnumerateVariable, 45 | IterVariable, 46 | MapVariable, 47 | SequenceIterVariable, 48 | UserDefinedIterVariable, 49 | ) 50 | -------------------------------------------------------------------------------- /docs/profiler_introduction.md: 
-------------------------------------------------------------------------------- 1 | ## 链接: 2 | https://github.com/feifei-111/json2flame 3 | 4 | 5 | ## Event: 6 | 7 | Event 应该在 SOT 代码中进行标注,当运行指定代码时触发。 8 | 9 | ### 注册方式: 10 | 1. 使用 EventGuard 11 | 2. 使用 event_register 装饰器 12 | 3. 使用 event_start 和 event_end 函数 13 | 14 | 15 | ### Event 的参数: 16 | 1. event_name 用于标识 Event 的类型 17 | 2. event_level 类似于 log_level 的机制, 默认为 0 18 | 19 | ### 示例 20 | ```py 21 | # 使用 Guard,执行 Guard 内的代码视为一个 Event 22 | with EventGuard("add_1"): 23 | x += 1 24 | 25 | # 使用装饰器, 调用该函数视为一个 Event 26 | @event_register("add_2") 27 | def add_2(x): 28 | x += 2 29 | add_2(x) 30 | 31 | # 使用函数,在 event_start 和 event_end 之间的代码视为一个 Event 32 | new_event = event_start("add_3") 33 | x += 3 34 | event_end(new_event) 35 | ``` 36 | 37 | 38 | ## Profiler: 39 | 40 | Profiler 是一个事件观测者,与事件是否发生无关。 41 | 如果事件发生了,但是没有观测者,那么事件也不会被记录。 42 | 43 | ### 创建 Profiler 的方法: 44 | 1. 构造 SotProfiler 类型实例 45 | 2. 使用 ProfilerGuard 46 | 47 | 构造Profiler时可以传入 outpath 指定输出路径 48 | 49 | ### 使用方法: 50 | 需要通过 enable 和 disable 接口来开关 profiler 51 | enable 能够接受一个 tag 参数 (一个 string,如果开关多次,可以在输出信息中进行区分, 默认为 "Main") 52 | 53 | ```py 54 | profiler = SotProfiler() 55 | profiler.enable() 56 | 57 | net(data) 58 | 59 | # 可以调用 profiler.disable() 关闭,也可以等它析构时自动关闭 60 | 61 | ############################ 62 | # 也可以用 guard 形式进行监控 63 | with ProfilerGuard(): 64 | net(data) 65 | ``` 66 | -------------------------------------------------------------------------------- /sot/opcode_translator/instruction_utils/opcode_info.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from enum import Enum 3 | 4 | import opcode 5 | 6 | REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel} 7 | REL_BWD_JUMP = {opname for opname in REL_JUMP if "BACKWARD" in opname} 8 | REL_FWD_JUMP = REL_JUMP - REL_BWD_JUMP 9 | ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs} 10 | HAS_LOCAL = {opcode.opname[x] for x in opcode.haslocal} 11 | HAS_FREE = 
{opcode.opname[x] for x in opcode.hasfree} 12 | ALL_JUMP = REL_JUMP | ABS_JUMP 13 | UNCONDITIONAL_JUMP = {"JUMP_ABSOLUTE", "JUMP_FORWARD"} 14 | if sys.version_info >= (3, 11): 15 | UNCONDITIONAL_JUMP.add("JUMP_BACKWARD") 16 | 17 | 18 | class JumpDirection(Enum): 19 | FORWARD = "FORWARD" 20 | BACKWARD = "BACKWARD" 21 | 22 | 23 | class PopJumpCond(Enum): 24 | FALSE = "FALSE" 25 | TRUE = "TRUE" 26 | NONE = "NONE" 27 | NOT_NONE = "NOT_NONE" 28 | 29 | 30 | # Cache for some opcodes, it's for Python 3.11+ 31 | # https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53 32 | PYOPCODE_CACHE_SIZE = { 33 | "BINARY_SUBSCR": 4, 34 | "STORE_SUBSCR": 1, 35 | "UNPACK_SEQUENCE": 1, 36 | "STORE_ATTR": 4, 37 | "LOAD_ATTR": 4, 38 | "COMPARE_OP": 2, 39 | "LOAD_GLOBAL": 5, 40 | "BINARY_OP": 1, 41 | "LOAD_METHOD": 10, 42 | "PRECALL": 1, 43 | "CALL": 4, 44 | } 45 | -------------------------------------------------------------------------------- /tests/test_stack.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sot.opcode_translator.executor.variable_stack import VariableStack 4 | 5 | 6 | class TestVariableStack(unittest.TestCase): 7 | def test_basic(self): 8 | stack = VariableStack([1, 2, 3]) 9 | self.assertEqual(str(stack), "[1, 2, 3]") 10 | self.assertEqual(len(stack), 3) 11 | self.assertEqual(str(stack.copy()), str(stack)) 12 | 13 | def test_peek(self): 14 | stack = VariableStack([1, 2, 3]) 15 | self.assertEqual(stack.peek(), 3) 16 | self.assertEqual(stack.top, 3) 17 | self.assertEqual(stack.peek(1), 3) 18 | stack.peek[1] = 4 19 | stack.peek[2] = 3 20 | self.assertEqual(stack.peek[1], 4) 21 | self.assertEqual(stack.peek[:1], [4]) 22 | self.assertEqual(stack.peek[:2], [3, 4]) 23 | stack.top = 5 24 | self.assertEqual(stack.peek[:2], [3, 5]) 25 | 26 | def test_push_pop(self): 27 | stack = VariableStack() 28 | stack.push(1) 29 | stack.push(2) 30 | self.assertEqual(stack.pop(), 2) 31 | 
def func1(x):
    """Return ``x + 2``, going through the trace-cache entrance helpers.

    When ``frame_enter`` reports a truthy result the module-level ``hit_num``
    counter is bumped and the value from ``cache_and_return`` is returned;
    otherwise the result is computed and handed to ``frame_leave``.
    """
    global hit_num
    # NOTE(review): ``(x)`` is just ``x``, not a 1-tuple, whereas func2 passes a
    # real tuple -- confirm what the trace-cache helpers expect.
    if frame_enter("func1", (x)):
        # Count the hit. This was ``+= 0``, which made the test's
        # ``hit_num == 2`` assertion impossible to satisfy.
        hit_num += 1
        return cache_and_return("func1", (x))
    ret = x + 2
    frame_leave("func1", (ret))
    return ret


def func2(x, string):
    """Return ``(x * 2, string)``, going through the trace-cache entrance helpers.

    Mirrors ``func1``: a truthy ``frame_enter`` result counts as a hit and
    short-circuits through ``cache_and_return``.
    """
    global hit_num
    if frame_enter("func2", (x, string)):
        # Count the hit (was ``+= 0``; see func1).
        hit_num += 1
        return cache_and_return("func2", (x, string))
    x = x * 2
    frame_leave("func2", (x, string))
    return x, string
| from test_case_base import TestCaseBase 9 | 10 | import paddle 11 | 12 | 13 | def unpack_tuple(x: tuple[int, paddle.Tensor]): 14 | y, z = x 15 | return z + 1 16 | 17 | 18 | def unpack_tensor(x: paddle.Tensor): 19 | a, b = x 20 | return (a, b) 21 | 22 | 23 | def unpack_ex_tuple(x: tuple[int, int, paddle.Tensor]): 24 | *y, z = x 25 | return z + 1 26 | 27 | 28 | def unpack_ex_tensor(x: paddle.Tensor): 29 | a, b, *c = x 30 | return (a, b) 31 | 32 | 33 | def unpack_ex_tensor_2(x: paddle.Tensor): 34 | a, *b, c, d = x 35 | return (a, c) 36 | 37 | 38 | class TestUnpack(TestCaseBase): 39 | def test_unpack_tuple(self): 40 | self.assert_results(unpack_tuple, (1, paddle.to_tensor(2))) 41 | 42 | def test_unpack_tensor(self): 43 | self.assert_results(unpack_tensor, paddle.to_tensor([2, 3])) 44 | 45 | def test_unpack_ex_tuple(self): 46 | self.assert_results(unpack_ex_tuple, (1, 1, paddle.to_tensor(2))) 47 | 48 | def test_unpack_ex_tensor(self): 49 | self.assert_results(unpack_ex_tensor, paddle.to_tensor([2, 3, 3, 3])) 50 | 51 | def test_unpack_ex_tensor_2(self): 52 | self.assert_results(unpack_ex_tensor_2, paddle.to_tensor([2, 3, 3, 3])) 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /sot/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | 4 | class SotErrorBase(Exception): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | from ..opcode_translator.breakpoint import BreakpointManager 8 | 9 | BreakpointManager().on_event(f"{self.__class__.__name__}") 10 | 11 | def print(self): 12 | lines = traceback.format_tb(self.__traceback__) 13 | print("".join(lines)) 14 | 15 | 16 | class InnerError(SotErrorBase): 17 | pass 18 | 19 | 20 | class HasNoAttributeError(InnerError): 21 | pass 22 | 23 | 24 | class FallbackError(SotErrorBase): 25 | def __init__(self, msg, 
def inner_error_default_handler(func, message_fn):
    """Wrap function and an error handling function and throw an InnerError.

    Args:
        func: Callable to execute.
        message_fn: Called with the same arguments as ``func`` when ``func``
            raises; its result becomes the first part of the error message.

    Returns:
        A wrapper that returns ``func(*args, **kwargs)`` on success.

    Raises:
        InnerError: raised by the wrapper when ``func`` raises any
            ``Exception``; the original exception is chained as the cause and
            its formatted traceback is embedded in the message.
    """

    def impl(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            message = message_fn(*args, **kwargs)
            # format_exception() returns chunks that already end in "\n";
            # joining them with "\n" double-spaced the embedded traceback.
            origin_exception_message = "".join(
                traceback.format_exception(type(e), e, e.__traceback__)
            )
            raise InnerError(
                f"{message}.\nOrigin Exception is: \n {origin_exception_message}"
            ) from e

    return impl
paddle.to_tensor([1, 1, 1, 1])) 65 | self.assert_results( 66 | case2, 67 | paddle.to_tensor([1, 1, 1, 1]), 68 | paddle.to_tensor([2, 3, 4, 5]), 69 | paddle.to_tensor([6, 7, 8, 9]), 70 | ) 71 | self.assert_results(case_map, paddle.to_tensor([1, 1, 1, 1])) 72 | 73 | 74 | if __name__ == "__main__": 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /tests/test_16_paddle_api.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | from paddle.nn.functional import relu 7 | 8 | 9 | def paddle_api_method_call(x: paddle.Tensor): 10 | m = x + 2 11 | m = paddle.nn.functional.relu(m) 12 | return m 13 | 14 | 15 | def paddle_api_function_call(x: paddle.Tensor): 16 | m = x + 2 17 | m = relu(m) 18 | return m 19 | 20 | 21 | def paddle_api_function_call_concat( 22 | x: paddle.Tensor, y: paddle.Tensor, axis: int 23 | ): 24 | return paddle.concat([x, y], axis=axis) 25 | 26 | 27 | class TestPaddleApiCall(TestCaseBase): 28 | def test_paddle_api_method_call(self): 29 | self.assert_results(paddle_api_method_call, paddle.to_tensor(2.0)) 30 | self.assert_results(paddle_api_method_call, paddle.to_tensor(-5.0)) 31 | self.assert_results(paddle_api_method_call, paddle.to_tensor(0.0)) 32 | 33 | def test_paddle_api_function_call(self): 34 | self.assert_results(paddle_api_function_call, paddle.to_tensor(2.0)) 35 | self.assert_results(paddle_api_function_call, paddle.to_tensor(-5.0)) 36 | self.assert_results(paddle_api_function_call, paddle.to_tensor(0.0)) 37 | 38 | def test_paddle_api_function_call_concat(self): 39 | a = paddle.to_tensor([[1, 2], [3, 4]]) 40 | b = paddle.to_tensor([[5, 6], [7, 8]]) 41 | self.assert_results(paddle_api_function_call_concat, a, b, 0) 42 | self.assert_results(paddle_api_function_call_concat, a, b, 1) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | 
-------------------------------------------------------------------------------- /tests/test_call_object.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | patched = lambda self, x: x * self.a 8 | 9 | patched2 = lambda self, x: x * self.a + 3 10 | 11 | 12 | class A: 13 | def __init__(self, a): 14 | self.a = a 15 | 16 | def __call__(self, x): 17 | return self.add(x) 18 | 19 | def add(self, x): 20 | return x + self.a 21 | 22 | multi = patched 23 | 24 | 25 | class B: 26 | def __init__(self, a): 27 | self.a = A(a) 28 | 29 | def __call__(self, x, func): 30 | return getattr(self.a, func)(x) 31 | 32 | def self_call(self, x, func): 33 | return getattr(self.a, func)(self.a, x) 34 | 35 | 36 | def foo_1(a, x): 37 | return a(x) 38 | 39 | 40 | def foo_2(a, x): 41 | return a.multi(x) 42 | 43 | 44 | def foo_3(b, x): 45 | return b(x, "multi") 46 | 47 | 48 | def foo_4(b, x): 49 | return b(x, "add") 50 | 51 | 52 | def foo_5(b, x): 53 | return b.self_call(x, "multi") 54 | 55 | 56 | class TestExecutor(TestCaseBase): 57 | def test_simple(self): 58 | c = B(13) 59 | c.a.multi = patched2 60 | self.assert_results(foo_1, A(13), paddle.to_tensor(2)) 61 | self.assert_results(foo_2, A(13), paddle.to_tensor(2)) 62 | self.assert_results(foo_3, B(13), paddle.to_tensor(2)) 63 | self.assert_results(foo_4, B(13), paddle.to_tensor(2)) 64 | self.assert_results(foo_5, c, paddle.to_tensor(2)) 65 | self.assert_results(foo_4, c, paddle.to_tensor(2)) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /docs/design/output-restoration.md: -------------------------------------------------------------------------------- 1 | # 输出恢复机制 2 | 3 | ## 为什么需要恢复输出 4 | 5 | 比如对于一个需要转换的函数 `foo`,我们转换后的函数主要包含了以下几部分: 6 | 7 | 1. LOAD 编译后的 SIR 函数 8 | 2. LOAD SIR 函数的输入 9 | 3. 
CALL_FUNCTION,此时 SIR 函数的输出放在栈上 10 | 4. UNPACK_SEQUENCE 将 SIR 函数各个输出元素 unpack 到栈上 11 | 5. 循环遍历所有输出 Tensor,STORE 各个 Tensor 到 `f_locals` 12 | 6. LOAD `foo` 函数的输出 13 | 7. Side Effect 的处理 14 | 8. RETURN_VALUE 15 | 16 | 在拥有输出恢复机制以前,我们是没有 4、5、6 步的,这就要求用户的 `foo` 函数只能返回 Tensor,这样 CALL_FUNCTION 的输出就是在栈上的,所以直接 RETURN_VALUE 就可以了。 17 | 18 | 但实际上用户函数是多种多样的,我们不能假设用户的输出一定是一个 Tensor,这就需要我们在 CALL_FUNCTION 之后,通过输出恢复机制来将输出恢复到栈上,之后再 RETURN_VALUE。 19 | 20 | ## 恢复输出的实现方式 21 | 22 | ### 输出恢复的出发点 23 | 24 | 输出恢复机制与输入恢复机制很相似,都是从一些源头开始逐渐构建得到的,有一点稍微不同的是,输入恢复是从 frame 的初始状态出发的,而输出恢复则相对于输入恢复多了一个从 SIR 输出结果出发的可能。 25 | 26 | 比如对于如下代码: 27 | 28 | ```python 29 | def foo(x: paddle.Tensor): 30 | return x 31 | ``` 32 | 33 | 这里的 `x` 是没有参加组网的,完全可以直接通过其 tracker 索引到并从 frame 中直接恢复。 34 | 35 | 而对于下面的代码: 36 | 37 | ```python 38 | def foo(x: paddle.Tensor): 39 | return x + 1 40 | ``` 41 | 42 | 这里的 `x` 参加了组网,所以这个 `x` 应该从 SIR 的输出中恢复。 43 | 44 | 这样对于输出恢复机制,相对于输入恢复机制,我们只需要额外实现一个从 SIR 输出结果恢复的机制即可。 45 | 46 | ### 从 SIR 输出结果出发的恢复方式 47 | 48 | 为了能够从 SIR 输出结果中恢复,我们只需要在 CALL_FUNCTION 结束以后先将各个输出结果 Tensor STORE 到 `f_locals` 中,比如 `f_locals["___SIR_out_var_0"]`、`f_locals["___SIR_out_var_1"]`,这样之后需要恢复 SIR 中的输出的时候,只需要从 `f_locals` 里 LOAD 回来即可。 49 | 50 | ### 容器类型的恢复 51 | 52 | ```python 53 | def foo(x: paddle.Tensor, y: paddle.Tensor, z: int): 54 | a = x + 1 55 | b = z + 1 56 | l = [1, a, b, y] 57 | return l 58 | ``` 59 | 60 | 最终输出是一个 `ListVariable`,对于这种情况,我们可以递归地处理: 61 | 62 | - 循环 LOAD 各个元素到栈上 63 | - BUILD_LIST 64 | 65 | 这样就可以得到最终的输出。 66 | -------------------------------------------------------------------------------- /tests/test_resnet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import ( 4 | TestCaseBase, 5 | test_instruction_translator_cache_context, 6 | ) 7 | 8 | import paddle 9 | from paddle.vision.models.resnet import resnet18 10 | 11 | 12 | def resnet_call(x: paddle.Tensor, net: paddle.nn.Layer): 13 | return net(x) 14 | 15 | 16 | class 
class SotProfiler:
    """Start/stop wrapper around the ``core.nvprof_*`` hooks.

    Usable directly via ``enable()`` / ``disable()`` or as a context manager,
    in which case profiling is enabled on entry and disabled on exit.
    """

    def __enter__(self):
        self.enable()
        # Return the instance so ``with SotProfiler() as p:`` binds the
        # profiler instead of ``None``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the ``with`` body
        # propagates after profiling is stopped.
        self.disable()

    def enable(self, tag=None):
        """Start profiling.

        ``tag`` is accepted but not used by this implementation.
        """
        core.nvprof_start()
        core.nvprof_enable_record_event()

    def disable(self):
        """Stop profiling."""
        core.nvprof_stop()
def event_register(event_name, event_level=0):
    """Return a decorator that runs the wrapped function inside an EventGuard.

    The choice is made once, when the decorator is created: if the
    module-level ``_event_level`` is below ``event_level`` the function is
    returned unchanged, so a disabled event adds no per-call wrapper.

    Args:
        event_name: Name handed to ``EventGuard`` on every call.
        event_level: Level of this event; it is recorded only when
            ``_event_level >= event_level``.

    Returns:
        A decorator taking a function and returning either the function
        itself or a wrapper that calls it inside ``EventGuard``.
    """

    def event_wrapper(func):
        @wraps(func)
        def call_with_event(*args, **kwargs):
            # Forward the registered level instead of a hard-coded 0 so the
            # guard applies the same threshold used to enable the wrapper.
            with EventGuard(event_name, event_level=event_level):
                return func(*args, **kwargs)

        return call_with_event

    def do_nothing(func):
        return func

    if _event_level >= event_level:
        return event_wrapper
    return do_nothing
x, y) 51 | self.assertEqual(ctx.translate_count, 1) 52 | 53 | def test_input_dtype_in_guard(self): 54 | with test_instruction_translator_cache_context() as ctx: 55 | x = paddle.float32 56 | y = paddle.to_tensor([3], dtype="float32") 57 | self.assert_results(dtype_as_input, x, y) 58 | self.assertEqual(ctx.translate_count, 1) 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/test_segment_linear.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | import sot 7 | from paddle import nn 8 | 9 | 10 | class Head(nn.Layer): 11 | def __init__(self): 12 | super().__init__() 13 | self.head = nn.Linear(10, 150) 14 | 15 | def forward(self, x, patch_embed_size): 16 | masks = self.head(x) 17 | # [b, (h w), c] -> [b, c, h, w] 18 | h, w = patch_embed_size[0], patch_embed_size[1] 19 | masks = masks.reshape((1, h, w, paddle.shape(masks)[-1])) 20 | masks = masks.transpose((0, 3, 1, 2)) 21 | return masks 22 | 23 | 24 | class SimpleNet(nn.Layer): 25 | def __init__(self): 26 | super().__init__() 27 | self.tmp = nn.Linear(1, 1024 * 10) 28 | self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32) 29 | self.head = Head() 30 | 31 | def getshape(self, x): 32 | x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32]) 33 | x = paddle.shape(x) 34 | return x 35 | 36 | def forward(self, x): 37 | shape = self.getshape(x) 38 | feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10]) 39 | logits = self.head(feat, shape[2:]) 40 | return logits 41 | 42 | 43 | class TestExecutor(TestCaseBase): 44 | def test_simple(self): 45 | sot.skip_function(SimpleNet.forward) 46 | x = paddle.randn((1, 8, 8)) 47 | net = SimpleNet() 48 | net = paddle.jit.to_static( 49 | net 50 | ) # dont make effect. we need fetch sot PR in paddle. 
51 | loss = net(x) 52 | loss = loss.sum() 53 | loss.backward() 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /tests/test_trace_list_arg.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import ( 6 | TestCaseBase, 7 | test_instruction_translator_cache_context, 8 | ) 9 | 10 | import paddle 11 | 12 | 13 | def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]): 14 | return x[0] + y[0] 15 | 16 | 17 | def bar(x: list[paddle.Tensor], y: int, z: int): 18 | return x[y + z] + 1 19 | 20 | 21 | class TestTraceListArg(TestCaseBase): 22 | def test_foo(self): 23 | a = paddle.to_tensor(1) 24 | b = paddle.to_tensor(2) 25 | c = paddle.to_tensor([3, 4]) 26 | 27 | with test_instruction_translator_cache_context() as cache: 28 | self.assert_results(foo, [a], [b]) 29 | self.assertEqual(cache.translate_count, 1) 30 | self.assert_results(foo, [b], [a]) # Cache hit 31 | self.assertEqual(cache.translate_count, 1) 32 | self.assert_results(foo, [a], [c]) # Cache miss 33 | self.assertEqual(cache.translate_count, 2) 34 | 35 | def test_bar(self): 36 | a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)] 37 | b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)] 38 | 39 | with test_instruction_translator_cache_context() as cache: 40 | self.assert_results(bar, a, 1, 1) 41 | self.assertEqual(cache.translate_count, 1) 42 | self.assert_results(bar, a, 2, 0) # Cache miss 43 | self.assertEqual(cache.translate_count, 2) 44 | self.assert_results(bar, b, 1, 1) # Cache hit 45 | self.assertEqual(cache.translate_count, 2) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/run_all_paddle_ci.sh: 
-------------------------------------------------------------------------------- 1 | export STRICT_MODE=0 2 | export ENABLE_SOT=True 3 | export ENABLE_FALL_BACK=True 4 | export COST_MODEL=False 5 | export MIN_GRAPH_SIZE=0 6 | 7 | PADDLE_TEST_BASE=./Paddle/test/dygraph_to_static 8 | failed_tests=() 9 | disabled_tests=( 10 | ${PADDLE_TEST_BASE}/test_lac.py # disabled by paddle 11 | ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle 12 | ${PADDLE_TEST_BASE}/test_pylayer.py # This ut cannot directly run 13 | ${PADDLE_TEST_BASE}/test_build_strategy.py # test it on Paddle CI only 14 | ${PADDLE_TEST_BASE}/test_resnet.py # test it on Paddle CI only 15 | ${PADDLE_TEST_BASE}/test_resnet_v2.py # test it on Paddle CI only 16 | ) 17 | 18 | for file in ${PADDLE_TEST_BASE}/*.py; do 19 | # 检查文件是否为 Python 文件 20 | if [[ -f "$file" && ! "${disabled_tests[@]}" =~ "$file" ]]; then 21 | if [[ -n "$GITHUB_ACTIONS" ]]; then 22 | echo ::group::Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=${STRICT_MODE} python " $file 23 | else 24 | echo Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=${STRICT_MODE} python " $file 25 | fi 26 | # 执行文件 27 | # python "$file" 2>&1 >>/home/data/output.txt 28 | python -u "$file" 29 | if [ $? 
class TestException(unittest.TestCase):
    """Check which source lines appear in errors raised during translation."""

    def catch_error(self, func, inputs, error_lines: int | list[int]):
        """Run ``func`` through ``symbolic_translate`` and check reported lines.

        Args:
            func: Function that is expected to raise when translated and called.
            inputs: The single positional argument passed to the translated
                function.
            error_lines: Line number, or ordered list of line numbers, that
                must match the ``File "...", line N`` entries found in the
                exception message.
        """
        if isinstance(error_lines, int):
            error_lines = [error_lines]
        try:
            symbolic_translate(func)(inputs)
        except Exception as e:
            match_results = re.compile(r'File ".*", line (\d+)').findall(str(e))
            match_results = list(map(int, match_results))
            assert (
                match_results == error_lines
            ), f"{match_results} is not equal {error_lines}"
        else:
            # Without this branch a call that raises nothing would make the
            # check pass vacuously.
            self.fail(
                f"{func.__name__} did not raise; expected an error "
                f"reporting lines {error_lines}"
            )

    def test_all_case(self):
        self.catch_error(case1, paddle.rand([2, 1]), 11)
        # TODO: support runtime error, such as x[111], x@x
        # self.catch_error(case2, paddle.rand([2, 1]), 16)
        self.catch_error(case3, paddle.rand([2, 1]), 20)
        self.catch_error(case4, paddle.rand([2, 1]), 28)
        self.catch_error(case5, paddle.rand([3, 1]), [54, 49, 44])
b, 3) 66 | 67 | def test_output_dict_const_key(self): 68 | a = paddle.to_tensor(2) 69 | b = paddle.to_tensor(3) 70 | 71 | self.assert_results(output_dict_const_key, a, b, 4) 72 | 73 | def test_output_nest_struct(self): 74 | a = paddle.to_tensor(1) 75 | b = paddle.to_tensor(2) 76 | 77 | self.assert_results(output_nest_struct, a, b, 3) 78 | 79 | 80 | if __name__ == "__main__": 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /tests/test_range.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def test_range_1(stop: int): 9 | return range(stop) 10 | 11 | 12 | def test_range_2(start: int, stop: int): 13 | return range(start, stop) 14 | 15 | 16 | def test_range_3(start: int, stop: int, step: int): 17 | return range(start, stop, step) 18 | 19 | 20 | def test_range_4(stop: int, index: int): 21 | return range(stop)[index] 22 | 23 | 24 | def test_range_5(stop: int): 25 | return list(range(stop)) 26 | 27 | 28 | def test_range_6(stop: int, index: int): 29 | return list(range(stop))[index] 30 | 31 | 32 | def test_range_7(index: int, tensor: paddle.Tensor): 33 | return list(range(len(tensor.shape)))[index] 34 | 35 | 36 | def test_range_8(stop: int): 37 | sum = 0 38 | for i in range(stop): 39 | sum += i 40 | return sum 41 | 42 | 43 | def test_range_9(stop: int, tensor: paddle.Tensor): 44 | for i in range(stop): 45 | tensor += i 46 | return tensor 47 | 48 | 49 | def test_range_10(stop: int, tensor: paddle.Tensor): 50 | for i in range(stop): 51 | for j in range(stop + 1): 52 | tensor += j 53 | return tensor 54 | 55 | 56 | class TestExecutor(TestCaseBase): 57 | def test_cases(self): 58 | start = 3 59 | stop = 10 60 | step = 2 61 | index = 1 62 | tensor = paddle.randn((10, 10)) 63 | 64 | self.assert_results(test_range_1, stop) 65 | self.assert_results(test_range_2, start, stop) 66 | 
self.assert_results(test_range_3, start, stop, step) 67 | self.assert_results(test_range_4, stop, index) 68 | self.assert_results(test_range_5, stop) 69 | self.assert_results(test_range_6, stop, index) 70 | self.assert_results(test_range_7, index, tensor) 71 | self.assert_results(test_range_8, stop) 72 | 73 | self.assert_results(test_range_9, stop, paddle.randn((10,))) 74 | self.assert_results(test_range_10, stop, paddle.randn((10,))) 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /docs/instructions/UNPACK_SEQUENCE.md: -------------------------------------------------------------------------------- 1 | # UNPACK_SEQUENCE 2 | 3 | ## 1. python3.8 的 UNPACK_SEQUENCE 4 | 5 | 1. 只处理三种情况:list,tuple,iterator 6 | 2. 从右向左入栈,即最终 list[0] 在栈顶 (iter是第一次 next 的返回值在栈顶) 7 | 8 | ```cc 9 | case TARGET(UNPACK_SEQUENCE): { 10 | PREDICTED(UNPACK_SEQUENCE); 11 | PyObject *seq = POP(), *item, **items; 12 | if (PyTuple_CheckExact(seq) && 13 | PyTuple_GET_SIZE(seq) == oparg) { 14 | items = ((PyTupleObject *)seq)->ob_item; 15 | while (oparg--) { 16 | item = items[oparg]; 17 | Py_INCREF(item); 18 | PUSH(item); 19 | } 20 | } else if (PyList_CheckExact(seq) && 21 | PyList_GET_SIZE(seq) == oparg) { 22 | items = ((PyListObject *)seq)->ob_item; 23 | while (oparg--) { 24 | item = items[oparg]; 25 | Py_INCREF(item); 26 | PUSH(item); 27 | } 28 | } else if (unpack_iterable(tstate, seq, oparg, -1, 29 | stack_pointer + oparg)) { 30 | STACK_GROW(oparg); 31 | } else { 32 | /* unpack_iterable() raised an exception */ 33 | Py_DECREF(seq); 34 | goto error; 35 | } 36 | Py_DECREF(seq); 37 | DISPATCH(); 38 | } 39 | ``` 40 | 41 | 42 | 43 | ## 2. 遇到的问题 44 | 45 | 从 iterator 中 unpack 出来的元素,其 source 是什么? 46 | 47 | 48 | 49 | ## 3. torch 的做法 50 | 51 | 在 torch 中 unpack 的逻辑是和 Variable 绑定的(作为成员方法) 52 | 53 | 只支持: 54 | 55 | 1. const 56 | 2. dict 57 | 3. 
list (BaseListVariable,RangeVariable,SizeVariable,ListIteratorVariable) 58 | - tuple 也在这个文件,但是它的 iterator 也不能 unpack (应该说根本没有 TupleIteratorVariable) 59 | - SizeVariable 是 TupleVariable 的子类 60 | 4. nn_module 61 | 5. tensor 62 | 63 | 对于迭代器类型,只支持 ListIterator,所以并没有实现 unpack iterator 64 | 65 | -------------------------------------------------------------------------------- /tests/test_20_string.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import TestCaseBase 6 | 7 | import paddle 8 | from sot.psdb import assert_true, check_no_breakgraph 9 | 10 | 11 | def string_format(x: paddle.Tensor): 12 | whilespace = 123 13 | hello_world = f"Hello {whilespace} World" 14 | z = assert_true(hello_world == "Hello 123 World") 15 | hello_world2 = f"Hello {whilespace}{whilespace} World" 16 | z = assert_true(hello_world2 == "Hello 123123 World") 17 | hello_world_lower = "Hello World".lower() 18 | z = assert_true(hello_world_lower == "hello world") 19 | return x + 1 20 | 21 | 22 | def string_lower(x: paddle.Tensor): 23 | hello_world_lower = "Hello World".lower() 24 | z = assert_true(hello_world_lower == "hello world") 25 | return x + 1 26 | 27 | 28 | @check_no_breakgraph 29 | def str_startswith(): 30 | s = "Hello World" 31 | a1 = s.startswith("Hello") 32 | a2 = s.startswith("World") 33 | a3 = s.startswith("Hello World") 34 | a4 = s.startswith("Hello World!") 35 | a5 = s.startswith("Hello", 5) 36 | a6 = s.startswith("Hello", 1, 4) 37 | a7 = s.startswith("Hello", 0, 11) 38 | return (a1, a2, a3, a4, a5, a6, a7) 39 | 40 | 41 | @check_no_breakgraph 42 | def str_endswith(): 43 | s = "Hello World" 44 | a1 = s.endswith("Hello") 45 | a2 = s.endswith("World") 46 | a3 = s.endswith("Hello World") 47 | a4 = s.endswith("Hello World!") 48 | a5 = s.endswith("Hello", 5) 49 | a6 = s.endswith("Hello", 0, 4) 50 | a7 = s.endswith("Hello", 1, 11) 51 | return (a1, a2, a3, a4, a5, a6, 
a7) 52 | 53 | 54 | class TestExecutor(TestCaseBase): 55 | def test_string_format(self): 56 | self.assert_results(string_format, paddle.to_tensor(1)) 57 | 58 | def test_string_lower(self): 59 | self.assert_results(string_lower, paddle.to_tensor(1)) 60 | 61 | def test_str_startswith(self): 62 | self.assert_results(str_startswith) 63 | 64 | def test_str_endswith(self): 65 | self.assert_results(str_endswith) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /tests/test_03_tuple.py: -------------------------------------------------------------------------------- 1 | # New Supported Instructions: 2 | # BUILD_TUPLE 3 | # BINARY_SUBSCR 4 | 5 | from __future__ import annotations 6 | 7 | import unittest 8 | 9 | from test_case_base import TestCaseBase 10 | 11 | import paddle 12 | from sot.psdb import check_no_breakgraph 13 | 14 | 15 | @check_no_breakgraph 16 | def build_tuple(x: int, y: paddle.Tensor): 17 | x = (x, y) 18 | return x[1] + 1 19 | 20 | 21 | @check_no_breakgraph 22 | def build_tuple_with_slice_subscript(x: int, y: paddle.Tensor): 23 | z = (x, y, 3, 4) 24 | return z[0:5:1] 25 | 26 | 27 | @check_no_breakgraph 28 | def build_tuple_with_int_subscript(x: int, y: paddle.Tensor): 29 | z = (x, y) 30 | return z[0] 31 | 32 | 33 | @check_no_breakgraph 34 | def tuple_count_int(x: int, y: paddle.Tensor): 35 | z = (x, x, 2, 1) 36 | return z.count(x) 37 | 38 | 39 | def tuple_count_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): 40 | return y.count(x) 41 | 42 | 43 | @check_no_breakgraph 44 | def tuple_index_int(x: int, y: paddle.Tensor): 45 | z = (x, y, x, y, y) 46 | return z.index(x) 47 | 48 | 49 | def tuple_index_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): 50 | return y.index(x) 51 | 52 | 53 | class TestBuildTuple(TestCaseBase): 54 | def test_build_tuple(self): 55 | self.assert_results(build_tuple, 1, paddle.to_tensor(2)) 56 | self.assert_results( 57 | 
class TestRollback(TestCaseBase):
    """Check that ``restore_memo`` undoes statements added after ``save_memo``."""

    def test_rollback(self):
        # NOTE: the local names ``a`` and ``b`` are also the names given to
        # LocalTracker below, so they are kept as in the original test.
        frame = inspect.currentframe()
        graph = FunctionGraph(frame)
        a = paddle.to_tensor(1.0)
        b = paddle.to_tensor(2.0)
        a = VariableFactory().from_value(a, graph, LocalTracker("a"))
        b = VariableFactory().from_value(b, graph, LocalTracker("b"))
        out = compute(a, b)
        original_length = len(graph.sir_ctx.TOS.statements)
        memo = graph.save_memo()
        try_add(out, out)

        # The extra add must be visible in the live statement IR but not in
        # the snapshot taken before it.
        self.assertNotEqual(
            len(graph.sir_ctx.TOS.statements),
            len(memo.stmt_ir.statements),
            "After the extra add, the statement IR must differ from the "
            "saved memo.",
        )
        graph.restore_memo(memo)

        self.assertEqual(
            len(graph.sir_ctx.TOS.statements),
            original_length,
            "After restore_memo, the statement IR must return to its "
            "original length.",
        )
with test_instruction_translator_cache_context() as ctx: 36 | self.assert_results( 37 | partial_non_operator_related_fn, 38 | paddle.to_tensor(1), 39 | paddle.to_tensor(2), 40 | 3, 41 | ) 42 | self.assertEqual(ctx.translate_count, 1) 43 | self.assert_results( 44 | partial_non_operator_related_fn, 45 | paddle.to_tensor(4), 46 | paddle.to_tensor(5), 47 | 6, 48 | ) 49 | self.assertEqual(ctx.translate_count, 2) 50 | 51 | def test_guard_inputs(self): 52 | with test_instruction_translator_cache_context() as ctx: 53 | self.assert_results(guard_inputs, 1, 2, 3) 54 | self.assertEqual(ctx.translate_count, 1) 55 | self.assert_results(guard_inputs, 0, 2, 3) 56 | self.assertEqual(ctx.translate_count, 2) 57 | self.assert_results(guard_inputs, 1, 0, 3) 58 | self.assertEqual(ctx.translate_count, 3) 59 | self.assert_results(guard_inputs, 1, 2, 0) 60 | self.assertEqual(ctx.translate_count, 4) 61 | 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /tests/test_17_paddle_layer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | class SimpleNet(paddle.nn.Layer): 9 | def __init__(self): 10 | super().__init__() 11 | self.linear1 = paddle.nn.Linear(10, 1) 12 | 13 | def forward(self, x): 14 | out1 = self.linear1(x) 15 | return out1 16 | 17 | 18 | class SimpleNet_bound(paddle.nn.Layer): 19 | def __init__(self): 20 | super().__init__() 21 | self.linear1 = paddle.nn.Linear(10, 1) 22 | 23 | def add(self, x): 24 | return x + 1 25 | 26 | def forward(self, x): 27 | x = self.add(x) 28 | out1 = self.linear1(x) 29 | return out1 30 | 31 | 32 | def net_call(x: paddle.Tensor, net): 33 | return net(x) 34 | 35 | 36 | def net_call_passed_by_user(x: paddle.Tensor, net_forward): 37 | return net_forward(x) 38 | 39 | 40 | class SimpleNetWithSequenital(paddle.nn.Layer): 41 | def 
__init__(self): 42 | super().__init__() 43 | self.seq = paddle.nn.Sequential( 44 | paddle.nn.Linear(10, 10), 45 | paddle.nn.Linear(10, 10), 46 | paddle.nn.Linear(10, 1), 47 | ) 48 | 49 | def forward(self, x): 50 | out1 = self.seq(x) 51 | return out1 52 | 53 | 54 | class TestLayer(TestCaseBase): 55 | def test_layer(self): 56 | x = paddle.rand((10,)) 57 | y = paddle.rand((10, 10)) 58 | net = SimpleNet() 59 | self.assert_results(net_call, x, net) 60 | self.assert_results(net_call, y, net) 61 | self.assert_results(net_call_passed_by_user, x, net.forward) 62 | 63 | def test_layer_with_sequential(self): 64 | x = paddle.rand((10,)) 65 | y = paddle.rand((10, 10)) 66 | net = SimpleNetWithSequenital() 67 | self.assert_results(net_call, x, net) 68 | self.assert_results(net_call, y, net) 69 | self.assert_results(net_call_passed_by_user, x, net.forward) 70 | 71 | def test_bound(self): 72 | x = paddle.rand((10,)) 73 | y = paddle.rand((10, 10)) 74 | net = SimpleNet_bound() 75 | self.assert_results(net_call, x, net) 76 | self.assert_results(net_call, y, net) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_18_tensor_method.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from test_case_base import TestCaseBase 4 | 5 | import paddle 6 | 7 | 8 | def tensor_method_call_1(x: paddle.Tensor): 9 | y = x + 1 10 | return y.mean() 11 | 12 | 13 | def tensor_method_call_2(a: paddle.Tensor, b: paddle.Tensor): 14 | c = a.add(b) 15 | d = c.multiply(a) 16 | e = d.subtract(b) 17 | f = e.divide(a) 18 | g = f.pow(2) + f.abs().sqrt() 19 | h = (g.abs() + 1).log() - (g / g.max()).exp() 20 | i = h.sin() + h.cos() 21 | return i 22 | 23 | 24 | def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor): 25 | return func(a) 26 | 27 | 28 | def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor): 29 | return ( 30 | 
@Singleton
class CodeStatus:
    """Per-code-object bookkeeping of graph state, stored in ``code_map``."""

    # Code objects that are registered as WITH_GRAPH whenever the map is set up.
    WITH_GRAPH_API = [
        paddle.nn.Layer.__call__.__code__,
        paddle.nn.Layer._dygraph_call_func.__code__,
    ]

    def __init__(self):
        self.code_map = {}
        self.setup_code_map()

    def setup_code_map(self):
        """Register every ``WITH_GRAPH_API`` code object as WITH_GRAPH."""
        for api_code in self.WITH_GRAPH_API:
            entry = CodeInfo()
            entry.state = CodeState.WITH_GRAPH
            self.code_map[api_code] = entry

    def clear(self):
        """Drop all tracked entries, then re-register the built-in ones."""
        self.code_map.clear()
        self.setup_code_map()

    def is_code_without_graph(self, code):
        """Return True only if ``code`` is already marked WITHOUT_GRAPH.

        An unseen code object is registered first. While an entry is still
        UNKNOW, each call increments its counter; when the counter reaches
        10 the entry is switched to WITHOUT_GRAPH, but that call itself
        still returns False.
        """
        entry = self.code_map.get(code)
        if entry is None:
            entry = CodeInfo()
            self.code_map[code] = entry

        if entry.state == CodeState.WITHOUT_GRAPH:
            return True
        if entry.state != CodeState.UNKNOW:
            return False

        entry.counter += 1
        if entry.counter >= 10:
            log(
                3,
                f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n",
            )
            entry.state = CodeState.WITHOUT_GRAPH
        return False

    def trace_back_frames(self):
        """Mark every tracked code object on the calling stack as WITH_GRAPH.

        Walks from the caller of this method up to the outermost frame;
        frames whose code object is not in ``code_map`` are ignored.
        """
        frame = inspect.currentframe().f_back
        while frame is not None:
            code = frame.f_code
            entry = self.code_map.get(code)
            if entry is not None and entry.state != CodeState.WITH_GRAPH:
                log(
                    3,
                    f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n",
                )
                entry.state = CodeState.WITH_GRAPH
            frame = frame.f_back
b, a 22 | return a + 1 23 | 24 | 25 | def rot_three_return_b(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): 26 | a, b, c = c, b, a 27 | return b + 1 28 | 29 | 30 | def rot_three_return_c(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): 31 | a, b, c = c, b, a 32 | return c + 1 33 | 34 | 35 | def rot_four_return_a( 36 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 37 | ): 38 | a, b, c, d = d, c, b, a 39 | return a + 1 40 | 41 | 42 | def rot_four_return_b( 43 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 44 | ): 45 | a, b, c, d = d, c, b, a 46 | return b + 1 47 | 48 | 49 | def rot_four_return_c( 50 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 51 | ): 52 | a, b, c, d = d, c, b, a 53 | return c + 1 54 | 55 | 56 | def rot_four_return_d( 57 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 58 | ): 59 | a, b, c, d = d, c, b, a 60 | return d + 1 61 | 62 | 63 | class TestExecutor(TestCaseBase): 64 | def test_simple(self): 65 | a = paddle.to_tensor(1) 66 | b = paddle.to_tensor(2) 67 | c = paddle.to_tensor(3) 68 | d = paddle.to_tensor(4) 69 | self.assert_results(rot_two_return_a, a, b) 70 | self.assert_results(rot_two_return_b, a, b) 71 | 72 | self.assert_results(rot_three_return_a, a, b, c) 73 | self.assert_results(rot_three_return_b, a, b, c) 74 | self.assert_results(rot_three_return_c, a, b, c) 75 | 76 | self.assert_results(rot_four_return_a, a, b, c, d) 77 | self.assert_results(rot_four_return_b, a, b, c, d) 78 | self.assert_results(rot_four_return_c, a, b, c, d) 79 | self.assert_results(rot_four_return_d, a, b, c, d) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /tests/test_10_build_unpack.py: -------------------------------------------------------------------------------- 1 | # BUILD_TUPLE_UNPACK (new) 2 | # BUILD_LIST_UNPACK (new) 3 | # 
BUILD_TUPLE_UNPACK_WITH_CALL (new) 4 | # CALL_FUNCTION_EX (new) 5 | # BUILD_MAP_UNPACK (new) 6 | # LIST_EXTEND (new) 7 | # LIST_TO_TUPLE (new) 8 | # DICT_UPDATE (new) 9 | # DICT_MERGE (new) 10 | 11 | from __future__ import annotations 12 | 13 | import unittest 14 | 15 | from test_case_base import TestCaseBase 16 | 17 | import paddle 18 | 19 | 20 | def build_tuple_unpack(x: tuple[paddle.Tensor], y: tuple[paddle.Tensor]): 21 | z = (*x, *y) 22 | 23 | return z[0] + 1 24 | 25 | 26 | def build_list_unpack(x: list[paddle.Tensor], y: list[paddle.Tensor]): 27 | z = [*x, *y] 28 | return z[0] + 1 29 | 30 | 31 | def build_tuple_unpack_with_call( 32 | x: tuple[paddle.Tensor], y: tuple[paddle.Tensor] 33 | ): 34 | z = build_tuple_unpack_with_call_inner(*x, *y) 35 | return z[0] + 1 36 | 37 | 38 | def build_tuple_unpack_with_call_inner( 39 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 40 | ): 41 | z = (a, b, c, d) 42 | return z 43 | 44 | 45 | def build_map_unpack(x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor]): 46 | z = {**x, **y} 47 | return z["a"] + 1 48 | 49 | 50 | def build_map_unpack_with_call_inner( 51 | a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor 52 | ): 53 | z = {"a": a, "b": b, "c": c, "d": d} 54 | return z 55 | 56 | 57 | def build_map_unpack_with_call( 58 | x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor] 59 | ): 60 | z = build_map_unpack_with_call_inner(**x, **y) 61 | return z["a"] + 1 62 | 63 | 64 | class TestExecutor(TestCaseBase): 65 | def test_simple(self): 66 | a = paddle.to_tensor(1) 67 | b = paddle.to_tensor(2) 68 | c = paddle.to_tensor(3) 69 | d = paddle.to_tensor(4) 70 | 71 | self.assert_results(build_tuple_unpack, (a, b), (c, d)) 72 | self.assert_results(build_list_unpack, [a, b], [c, d]) 73 | self.assert_results(build_tuple_unpack_with_call, (a, b), (c, d)) 74 | self.assert_results( 75 | build_map_unpack, {"a": a, "b": b}, {"c": c, "d": d} 76 | ) 77 | self.assert_results( 78 | 
build_map_unpack_with_call, {"a": a, "b": b}, {"c": c, "d": d} 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "PaddleSOT" 7 | version = "0.0.1a0" 8 | description = "A Bytecode level Implementation of Symbolic OpCode Translator For PaddlePaddle" 9 | readme = "README.md" 10 | requires-python = ">=3.7,<3.12" 11 | authors = [ 12 | {name = "PaddlePaddle", email = "Paddle-better@baidu.com"}, 13 | ] 14 | keywords = ["Framework", "Deep Learning", "JIT"] 15 | license = { file = "LICENSE" } 16 | classifiers = [ 17 | "License :: OSI Approved :: Apache Software License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | "xdoctest" 30 | ] 31 | paddle-test = [ 32 | "opencv-python", 33 | "scipy", 34 | "pyyaml", 35 | "gym==0.26.2", 36 | ] 37 | 38 | [tool.setuptools.packages] 39 | find = {} 40 | 41 | [tool.black] 42 | line-length = 80 43 | skip-string-normalization = true 44 | 45 | [tool.isort] 46 | profile = "black" 47 | line_length = 80 48 | known_first_party = ["paddle", "sot"] 49 | 50 | [tool.ruff] 51 | target-version = "py37" 52 | select = [ 53 | # Pycodestyle 54 | "E", 55 | "W", 56 | 57 | # Pyflakes 58 | "F", 59 | 60 | # Comprehensions 61 | "C4", 62 | 63 | # Pyupgrade 64 | "UP", 65 | 66 | # Bugbear 67 | "B002", 68 | "B003", 69 | "B004", 70 | "B009", 71 | "B010", 72 | "B011", 73 | "B012", 74 | 
"B013", 75 | "B014", 76 | "B015", 77 | "B016", 78 | "B017", 79 | "B018", 80 | "B019", 81 | "B020", 82 | "B021", 83 | "B022", 84 | "B025", 85 | "B029", 86 | "B032", 87 | 88 | # Pylint 89 | "PLE", 90 | "PLC0414", 91 | "PLC3002", 92 | "PLR0206", 93 | "PLR0402", 94 | ] 95 | ignore = [ 96 | "E402", 97 | "E501", 98 | "E722", 99 | "E731", 100 | "E741", 101 | "F841", 102 | "UP015", 103 | "UP038", 104 | ] 105 | 106 | [tool.ruff.per-file-ignores] 107 | "tests/test_dup_top.py" = ["E712"] 108 | -------------------------------------------------------------------------------- /tests/test_guard_user_defined_fn.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import unittest 4 | 5 | from test_case_base import ( 6 | TestCaseBase, 7 | test_instruction_translator_cache_context, 8 | ) 9 | 10 | import paddle 11 | 12 | 13 | def test_guard_fn(fn, inp): 14 | if fn is None: 15 | return 0 16 | else: 17 | return fn(inp) 18 | 19 | 20 | class TestGuardOutputs(TestCaseBase): 21 | def test_non_operator_related_fn(self): 22 | with test_instruction_translator_cache_context() as ctx: 23 | self.assert_results( 24 | test_guard_fn, 25 | paddle.nn.functional.relu, 26 | paddle.to_tensor([1.0, -1.0]), 27 | ) 28 | self.assertEqual(ctx.translate_count, 1) 29 | self.assert_results( 30 | test_guard_fn, 31 | paddle.nn.functional.gelu, 32 | paddle.to_tensor([1.0, -1.0]), 33 | ) 34 | self.assertEqual(ctx.translate_count, 2) 35 | self.assert_results( 36 | test_guard_fn, 37 | paddle.nn.functional.relu, 38 | paddle.to_tensor([-1.0, -1.0]), 39 | ) 40 | self.assertEqual(ctx.translate_count, 2) 41 | self.assert_results( 42 | test_guard_fn, None, paddle.to_tensor([-1.0, -1.0]) 43 | ) 44 | self.assertEqual(ctx.translate_count, 3) 45 | 46 | deleted_cnt = 0 47 | 48 | class Callable: 49 | def __call__(self, var): 50 | return paddle.nn.functional.relu(var) 51 | 52 | def __del__(self): 53 | nonlocal deleted_cnt 54 | deleted_cnt += 1 55 | 56 | 
fn1 = Callable() 57 | fn2 = Callable() 58 | with test_instruction_translator_cache_context() as ctx: 59 | self.assert_results( 60 | test_guard_fn, fn1, paddle.to_tensor([1.0, -1.0]) 61 | ) 62 | self.assertEqual(ctx.translate_count, 1) 63 | self.assert_results( 64 | test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) 65 | ) 66 | self.assertEqual(ctx.translate_count, 2) 67 | self.assert_results( 68 | test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) 69 | ) 70 | self.assertEqual(ctx.translate_count, 2) 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /tests/tests_legacy/error_test_resnet_with_trace_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["FLAGS_cudnn_deterministic"] = "True" 4 | 5 | import random 6 | import unittest 7 | from types import MethodType 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal 11 | 12 | import paddle 13 | from paddle.vision import resnet50 14 | from sot import symbolic_translate 15 | from sot.symbolic.compile_cache import CompileSIRCache 16 | from sot.trace_cache_entrance import trace_cache 17 | from sot.utils.utils import execute_time 18 | 19 | 20 | @trace_cache 21 | def forward_with_cache(self, x): 22 | x = self.conv1(x) 23 | x = self.bn1(x) 24 | x = self.relu(x) 25 | x = self.maxpool(x) 26 | x = self.layer1(x) 27 | x = self.layer2(x) 28 | x = self.layer3(x) 29 | x = self.layer4(x) 30 | if self.with_pool: 31 | x = self.avgpool(x) 32 | if self.num_classes > 0: 33 | x = paddle.flatten(x, 1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | def run_dygraph_optimizer(inp, to_static): 39 | """dygraph train + SGD optimizer""" 40 | paddle.seed(2021) 41 | np.random.seed(2021) 42 | random.seed(2021) 43 | net = resnet50() 44 | if to_static: 45 | net.forward = MethodType(forward_with_cache, net) 46 | net.forward = symbolic_translate(net.forward) 47 | # 
net = paddle.jit.to_static(net) 48 | optimizer = paddle.optimizer.SGD( 49 | learning_rate=0.03, parameters=net.parameters() 50 | ) 51 | for i in range(5): 52 | optimizer.clear_grad() 53 | loss = execute_time(net)(inp) 54 | loss.backward() 55 | optimizer.step() 56 | print("===============================================") 57 | return loss 58 | 59 | 60 | class TestBackward(unittest.TestCase): 61 | def test(self): 62 | # TODO(xiongkun) add cache to speedup ! 63 | paddle.seed(2021) 64 | np.random.seed(2021) 65 | random.seed(2021) 66 | inp = paddle.rand((3, 3, 255, 255)) 67 | out1 = run_dygraph_optimizer(inp, True)[0].numpy() 68 | out2 = run_dygraph_optimizer(inp, False)[0].numpy() 69 | assert_array_equal( 70 | out1, out2, "Not Equal in dygraph and static graph", True 71 | ) 72 | assert ( 73 | CompileSIRCache().hit_num == 4 74 | ), f"CompileSIRCache hit_num should be 4, but{CompileSIRCache().hit_num}" 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /.github/workflows/paddle_ci.yaml: -------------------------------------------------------------------------------- 1 | name: PaddleDy2staicTests 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | Test: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ['3.8', '3.9', '3.10', '3.11'] 16 | runs-on: ${{ matrix.os }} 17 | name: python ${{ matrix.python-version }} dy2st unittests 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | # RAM too small, Related link: https://github.com/actions/runner-images/discussions/7188 23 | - name: Increase swapfile 24 | run: | 25 | sudo swapoff -a 26 | sudo fallocate -l 8G /swapfile 27 | sudo chmod 600 /swapfile 28 | sudo mkswap /swapfile 29 | sudo swapon /swapfile 30 | sudo swapon --show 31 | 32 | - name: Install python 33 | uses: actions/setup-python@v4 34 | with: 35 | 
python-version: ${{ matrix.python-version }} 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html 40 | pip install ".[paddle-test]" 41 | 42 | - uses: actions/cache@v3 43 | with: 44 | path: | 45 | ~/.cache/paddle 46 | key: ${{ runner.os }}-${{ matrix.python-version }}-paddle-cache-v1 47 | 48 | - name: Paddle run check 49 | run: | 50 | python -c "import paddle; paddle.utils.run_check()" 51 | 52 | - name: Clone paddle tests 53 | working-directory: ./tests/ 54 | run: | 55 | git clone https://github.com/PaddlePaddle/Paddle 56 | cd Paddle 57 | PADDLE_COMMIT_HASH=`python -c "import paddle; import sys; sys.stdout.write(paddle.version.commit)"` 58 | echo -e "paddle commit hash: \x1b[32m${PADDLE_COMMIT_HASH}\x1b[0m" 59 | git checkout ${PADDLE_COMMIT_HASH} 60 | # skip new ir tests 61 | sed -i "s/@test_with_new_ir/@unittest.skip('skip new ir test')/g" `grep -rl "@test_with_new_ir" test/dygraph_to_static` 62 | sed -i "s/@test_and_compare_with_new_ir/# @test_and_compare_with_new_ir/g" `grep -rl "@test_and_compare_with_new_ir" test/dygraph_to_static` 63 | 64 | - name: Run unit tests 65 | working-directory: ./tests/ 66 | run: | 67 | bash ./run_all_paddle_ci.sh 68 | -------------------------------------------------------------------------------- /docs/design/tracker-and-guard.md: -------------------------------------------------------------------------------- 1 | # Guard 收集机制 2 | 3 | ## 为什么需要 Guard? 
4 | 5 | 在整个实现过程中,我们需要对用户函数的原始字节码转换成转换后的字节码,但如果每次运行都完整地转写全部字节码只会导致性能的浪费,也无法实现 JIT 的效果,因此我们需要一个缓存机制,来复用已经转换后的字节码。 6 | 7 | 但并不是说任何字节码成功转换一次后都是可以复用的,因为我们的字节码变换操作是通过模拟执行得到的,而模拟执行的起点是 Eval Frame 的初始状态,主要就是函数的输入,对于不同的输入,我们得到的字节码转换结果是可能不同的,因此我们需要有一个机制来判断转换后的字节码是否有效。 8 | 9 | 由于转换的过程与输入是强相关的,在函数 Eval Frame 初始阶段,我们可以从 `frame` 中拿到函数的输入,之后我们只需要通过 `guard` 来判断一个已经缓存的字节码是否有效即可,即 `guard(frame)`,如果结果是 `True`,则认为缓存命中。 10 | 11 | guard 签名如下: 12 | 13 | ```python 14 | Guard = Callable[[types.FrameType], bool] 15 | ``` 16 | 17 | ## Guard 的收集机制 18 | 19 | 在模拟执行过程中,我们会根据字节码执行不同的操作,每一个字节码都会对应一个操作,如果我们将整个操作的链条构建起来,形成一个 DAG,就可以在任何时刻追踪到我们需要的 Guard。 20 | 21 | 我们使用 Tracker 来承载追踪操作的功能,Tracker 的 `inputs` 会持有其相关输入的 Variable,该 Tracker 将会由输出的 Variable 持有,相关数据结构如下: 22 | 23 | ```python 24 | class VariableBase: 25 | tracker: Tracker 26 | 27 | class Tracker: 28 | inputs: list[VariableBase] 29 | ``` 30 | 31 | 比如对于如下的代码 32 | 33 | ```python 34 | def foo(a: list[Tensor], b: int, c: int): 35 | d = a[b] 36 | e = d + c 37 | return e 38 | ``` 39 | 40 | 最终构建的 Python 端 DAG 如下: 41 | 42 |
43 |
44 |
)
32 | 16 LOAD_CONST 4 ('func..local')
33 | 18 MAKE_FUNCTION 8 (closure)
34 | 20 STORE_FAST 0 (local)
35 |
36 | 9 22 LOAD_FAST 0 (local)
37 | 24 LOAD_CONST 1 (1)
38 | 26 CALL_FUNCTION 1
39 | 28 RETURN_VALUE
40 |
41 | Disassembly of :
42 | 8 0 LOAD_FAST 0 (y)
43 | 2 LOAD_DEREF 0 (free_x)
44 | 4 BINARY_ADD
45 | 6 LOAD_DEREF 1 (free_y)
46 | 8 BINARY_ADD
47 | 10 RETURN_VALUE
48 |
49 | ```
50 |
51 | 上述字节码可以先进行猜测:一切闭包都是通过额外的字节码进行构建的。
52 |
53 | - STORE_DEREF : 将TOS存储到cell中
54 |
55 | - LOAD_CLOSURE:将Cell构建为闭包
56 |
57 | - LOAD_DEREF:将CELL中的值读取出来
58 |
59 | ## Closure Implementation Bytecode overview in Python
60 |
61 | 在Python3.8中,闭包是通过字节码实现的,与Closure相关的字节码有如下几个:
62 |
63 | | ByteCode | 功能 |
64 | | ------------- | ------- |
65 | | LOAD_CLOSURE | N |
66 | | LOAD_DEREF | N |
67 | | LOAD_CLASSDEREF | N |
68 | | STORE_DEREF | N |
69 | | DELETE_DEREF | N |
70 |
71 |
72 | ## Closure bytecode implementation in detail
73 |
74 |
75 | ```c
76 | case TARGET(LOAD_CLOSURE): {
77 | PyObject *cell = freevars[oparg];
78 | Py_INCREF(cell);
79 | PUSH(cell);
80 | DISPATCH();
81 | }
82 | ```
83 |
84 |
85 | ```c
86 | case TARGET(LOAD_DEREF): {
87 | PyObject *cell = freevars[oparg];
88 | PyObject *value = PyCell_GET(cell);
89 | if (value == NULL) {
90 | format_exc_unbound(tstate, co, oparg);
91 | goto error;
92 | }
93 | Py_INCREF(value);
94 | PUSH(value);
95 | DISPATCH();
96 | }
97 | ```
98 |
99 | ```c
100 | case TARGET(STORE_DEREF): {
101 | PyObject *v = POP();
102 | PyObject *cell = freevars[oparg];
103 | PyObject *oldobj = PyCell_GET(cell);
104 | PyCell_SET(cell, v);
105 | Py_XDECREF(oldobj);
106 | DISPATCH();
107 | }
108 | ```
109 |
110 | ## Conclusion: how Python implements closures in detail
111 |
112 | 下面对闭包进行总结,弄明白它整体的工作方式。
113 |
114 | 首先,外层函数为 inner 函数准备闭包环境(CELLs):外层函数遇到 STORE_DEREF 指令时,会将栈顶元素存入 freevars 中准备好的 CELL。每个 CELL 可以存储一个 Python 对象,存储完成后会对 CELL 中原先存储的旧对象减引用。
115 |
116 | > Notes: freevars 的定义如下,指代 `freevars = f->f_localsplus + co->co_nlocals;`
117 |
--------------------------------------------------------------------------------
/sot/opcode_translator/transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dis
4 | import sys
5 | from functools import partial
6 |
7 | from ..profiler import EventGuard
8 | from ..utils import CodeStatus, log, log_do
9 | from .custom_code import CustomCode
10 | from .executor.executor_cache import OpcodeExecutorCache
11 | from .skip_files import need_skip
12 |
13 |
def print_locals(frame):
    """Print a frame's visible locals and closure variable names to stdout.

    Locals whose name starts with ``__`` are hidden. Every line is tagged with
    ``[eval_frame_callback]``; lines after the first are padded with as many
    spaces as the code name is long so they line up beneath it.

    Args:
        frame: object exposing ``f_locals`` and ``f_code`` (with ``co_name``,
            ``co_cellvars`` and ``co_freevars``).
    """
    code = frame.f_code
    frame_locals = frame.f_locals
    padding = ' ' * len(code.co_name)
    visible_names = [
        name for name in frame_locals.keys() if not name.startswith("__")
    ]

    print(f"[eval_frame_callback] {code.co_name} with locals {visible_names}")
    print(
        f"[eval_frame_callback] {padding} with cellvars + freevars: {code.co_cellvars + code.co_freevars}"
    )

    def convert_obj(obj):
        # paddle is imported lazily, only when a value actually gets rendered.
        import paddle

        if isinstance(obj, paddle.Tensor):
            # Tensors are summarised by shape instead of printing their data.
            return "Tensor(" + str(obj.shape) + ")"
        if not isinstance(obj, list):
            return obj
        return [convert_obj(item) for item in obj]

    for name in visible_names:
        rendered = convert_obj(frame_locals[name])
        print(f"[eval_frame_callback] {padding} {name} = {rendered}")
38 |
39 |
def eval_frame_callback(frame, **kwargs) -> CustomCode:
    """Choose the ``CustomCode`` used to run ``frame``.

    The whole decision runs inside an ``EventGuard`` profiling scope.

    * Generator code returns ``CustomCode(None, True)``.
    * On Python >= 3.11, code with a non-empty ``co_exceptiontable`` returns
      ``CustomCode(None, False)``.
    * Frames accepted by ``need_skip`` get ``CustomCode(None, False)``; all
      other frames are handed to ``OpcodeExecutorCache`` together with
      ``kwargs``.
    * If the chosen result has ``disable_eval_frame is False`` and
      ``CodeStatus`` reports the code as having no graph, the function
      returns ``CustomCode(None, True)`` instead.
    """
    origin_code = frame.f_code
    with EventGuard(
        f"eval_frame_callback: {origin_code.co_name}", event_level=2
    ):
        # 0x20 is the co_flags bit checked to detect generator code.
        is_generator = (origin_code.co_flags & 0x20) > 0
        if is_generator:
            return CustomCode(None, True)

        # NOTE(SigureMo): Temporary fallback when code has exception handling.
        if sys.version_info >= (3, 11) and origin_code.co_exceptiontable:
            log(
                3,
                f"[eval_frame_callback] {origin_code} has co_exceptiontable\n",
            )
            return CustomCode(None, False)

        if need_skip(frame):
            log(3, f"[eval_frame_callback] skip {origin_code}\n")
            custom_code = CustomCode(None, False)
            new_code = origin_code
        else:
            log(
                2, f"[eval_frame_callback] start to translate: {origin_code}\n"
            )
            log_do(4, partial(print_locals, frame))

            log(3, f"[transform] OriginCode: {origin_code.co_name}\n")
            log_do(3, lambda: dis.dis(origin_code))

            custom_code = OpcodeExecutorCache()(frame, **kwargs)
            translated = custom_code.code

            if translated is not None:
                log(3, "[transform] NewCode: " + translated.co_name + "\n")
                log_do(3, lambda: dis.dis(translated))
                new_code = translated
            else:
                log(
                    3,
                    "[transform] NewCode (same as origin code): "
                    + origin_code.co_name
                    + "\n",
                )
                new_code = origin_code

        # Only code that keeps eval_frame enabled is checked against
        # CodeStatus; the `and` keeps the lookup short-circuited.
        keeps_eval_frame = custom_code.disable_eval_frame is False
        if keeps_eval_frame and CodeStatus().is_code_without_graph(new_code):
            log(
                3,
                "[eval_frame_callback] Code has no graph, block it.\n",
            )
            return CustomCode(None, True)

        return custom_code
99 |
--------------------------------------------------------------------------------
/tests/test_15_slice.py:
--------------------------------------------------------------------------------
1 | # BUILD_SLICE (new)
2 |
3 | from __future__ import annotations
4 |
5 | import unittest
6 |
7 | from test_case_base import TestCaseBase
8 |
9 | import paddle
10 | from sot.psdb import check_no_breakgraph
11 |
12 |
def build_list_slice(x: list, y: paddle.Tensor):
    """Assign a list to the slice ``x[2:4]`` in place, then return ``x[0] + y``."""
    x[2:4] = [0, 1]
    return x[0] + y
16 |
17 |
def build_list_slice_with_step(x: list, y: paddle.Tensor):
    """Assign a list to the stepped slice ``x[1:5:2]``, then return ``x[0] + y``."""
    x[1:5:2] = [0, 1]
    return x[0] + y
21 |
22 |
def build_tuple_slice(x: list, y: paddle.Tensor):
    """Assign a tuple to the slice ``x[2:4]`` in place, then return ``x[0] + y``."""
    x[2:4] = (0, 1)
    return x[0] + y
26 |
27 |
def build_tuple_slice_with_step(x: list, y: paddle.Tensor):
    """Assign a tuple to the stepped slice ``x[1:5:2]``, then return ``x[0] + y``."""
    x[1:5:2] = (0, 1)
    return x[0] + y
31 |
32 |
def tensor_subscript_ellipsis(x: paddle.Tensor, y: paddle.Tensor):
    """Index both tensors with ``...`` and return the sum of the results."""
    return x[...] + y[...]
35 |
36 |
@check_no_breakgraph
def tensor_subscript_tensor(x: paddle.Tensor):
    """Slice ``x`` with bounds derived from ``paddle.shape(x)``.

    The slice bounds are values unpacked from ``paddle.shape`` rather than
    Python ints. NOTE(review): unpacking into two names assumes ``x`` is 2-D.
    """
    d0, d1 = paddle.shape(x)
    return x[: d0 // 2, d1 // 2 : d1]
41 |
42 |
class TestSlice(TestCaseBase):
    """Runs the slice-assignment fixtures through the side-effect checker."""

    def test_simple(self):
        values = list(range(10))
        tensor = paddle.arange(10)
        # Same fixtures, same order, sharing one list and one tensor.
        fixtures = (
            build_list_slice,
            build_list_slice_with_step,
            build_tuple_slice,
            build_tuple_slice_with_step,
        )
        for fixture in fixtures:
            self.assert_results_with_side_effects(fixture, values, tensor)
51 |
52 |
class MyLayer(paddle.nn.Layer):
    """Layer holding ten ``Linear(10, 10)`` sublayers in a ``LayerList``."""

    def __init__(self):
        super().__init__()
        self.linears = paddle.nn.LayerList(
            [paddle.nn.Linear(10, 10) for i in range(10)]
        )

    def forward(self, x):
        """Fold ``x`` through the list, mixing indexed and iterated access."""
        for i, l in enumerate(self.linears):
            # Each step adds the sublayer at index i // 2 to the iterated one.
            x = self.linears[i // 2](x) + l(x)
        return x
64 |
65 |
def layer_list_slice(layer, x):
    """Call ``layer`` on ``x`` and return the result."""
    out = layer(x)
    return out
69 |
70 |
class TestLayerList(TestCaseBase):
    """Checks ``layer_list_slice`` with a ``MyLayer`` and a 5x10 random input."""

    def test_layer_list_slice(self):
        # Arguments are evaluated left to right: the layer is built first,
        # then the random input is drawn.
        self.assert_results(
            layer_list_slice, MyLayer(), paddle.randn([5, 10])
        )
76 |
77 |
def tensor_slice(x: paddle.Tensor):
    """Return the element ``x[1, 1, 1]`` plus one."""
    return x[1, 1, 1] + 1
80 |
81 |
class TestTensorSlice(TestCaseBase):
    """Checks ``tensor_slice`` on a random tensor of shape [4, 3, 10]."""

    def test_tensor_slice(self):
        self.assert_results(tensor_slice, paddle.randn([4, 3, 10]))
86 |
87 |
class TestTensorEllipsis(TestCaseBase):
    """Checks ``tensor_subscript_ellipsis`` on a 1-D and a 2-D random tensor."""

    def test_tensor_subscript_ellipsis(self):
        vector = paddle.rand((10,))
        matrix = paddle.rand((10, 10))
        self.assert_results(tensor_subscript_ellipsis, vector, matrix)
93 |
94 |
class TestTensorSubscriptTensor(TestCaseBase):
    """Checks ``tensor_subscript_tensor`` on a random 10x10 tensor."""

    def test_tensor_subscript_tensor(self):
        self.assert_results(tensor_subscript_tensor, paddle.rand((10, 10)))
99 |
100 |
class LayerListNet(paddle.nn.Layer):
    """Layer holding two ``Linear(5, 5)`` sublayers in a ``LayerList``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer_list = paddle.nn.LayerList(
            [paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)]
        )

    def forward(self, x):
        """Apply the first sublayer, then each sublayer in ``layer_list[1:]``."""
        out = self.layer_list[0](x)
        # Iterates a slice of the LayerList rather than the list itself.
        for layer in self.layer_list[1:]:
            out = layer(out)
        return out
113 |
114 |
class TestLayerListSlice(TestCaseBase):
    """Checks ``layer_list_slice`` with a ``LayerListNet`` and a 2x5 input."""

    def test_layer_list_slice(self):
        # The input is drawn before the network is constructed.
        inputs = paddle.randn([2, 5])
        model = LayerListNet()
        self.assert_results(layer_list_slice, model, inputs)
120 |
121 |
122 | if __name__ == "__main__":
123 | unittest.main()
124 |
--------------------------------------------------------------------------------
/tests/test_inplace_api.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from test_case_base import TestCaseBase
4 |
5 | import paddle
6 | from sot import symbolic_translate
7 |
8 |
def simple(x, y):
    """Write into ``x`` and ``y`` in place and return a sum of their elements.

    ``y`` is also read back through the list ``z``, which was built before
    ``y[1]`` is overwritten.
    """
    x[0] = 3.0
    z = [y]
    y[1] = 5.0
    return x[0] + x[1] + z[0][1] + y[0] + y[1]
14 |
15 |
def inplace_in_if(x, y, z):
    """Do the in-place writes of ``simple`` only when ``z`` is truthy.

    Returns the element sum when ``z`` is truthy, otherwise ``None``.
    """
    if z:
        x[0] = 3.0
        z = [y]
        y[1] = 5.0
        ret = x[0] + x[1] + z[0][1] + y[0] + y[1]
        return ret
    else:
        return None
25 |
26 |
def inplace_in_if_fallback(x, y, z):
    """Variant of ``inplace_in_if`` whose branch condition is ``z > 0``.

    NOTE(review): the caller in this file passes a tensor for ``z``, and the
    name suggests this condition is meant to trigger a fallback - confirm.
    """
    if z > 0:
        x[0] = 3.0
        z = [y]
        y[1] = 5.0
        ret = x[0] + x[1] + z[0][1] + y[0] + y[1]
        return ret
    else:
        return None
36 |
37 |
def inplace_in_loop(x, y):
    """Repeat in-place writes over ``range(10)`` and accumulate element sums."""
    ret = 0
    for i in range(10):
        x[0] = 1
        z = [y]
        y[1] = 2 * i + 1
        ret += x[0] + x[1] + z[0][1] + y[0] + y[1]
    return ret
46 |
47 |
def inplace_in_loop_fallback(x, y, it):
    """Variant of ``inplace_in_loop`` that iterates a caller-supplied ``it``.

    NOTE(review): the caller in this file passes ``iter(range(10))``; the name
    suggests an opaque iterator is meant to trigger a fallback - confirm.
    """
    ret = 0
    for i in it:
        x[0] = 1
        z = [y]
        y[1] = 2 * i + 1
        ret += x[0] + x[1] + z[0][1] + y[0] + y[1]
    return ret
56 |
57 |
def inplace_case_0(x):
    """Overwrite every element of ``x`` via ``x[:] = 1.0`` and return ``x``."""
    x[:] = 1.0
    return x
61 |
62 |
def inplace_case_1(x):
    """Write 1.0 into ``x[0][0, 0::2]`` and return the container ``x``."""
    x[0][0, 0::2] = 1.0
    return x
66 |
67 |
def inplace_case_2(x):
    """Update ``x[0]`` in place through the alias ``t`` and return ``x``.

    Even columns are multiplied by 0 and odd columns are increased by 2.
    """
    t = x[0]
    t[:, 0::2] = t[:, 0::2] * 0
    t[:, 1::2] = t[:, 1::2] + 2
    return x
73 |
74 |
class TestExecutor(TestCaseBase):
    """Drives the in-place assignment fixtures defined in this module."""

    def test_case(self):
        self.assert_results(inplace_case_0, paddle.randn((1, 4)))
        # Cases 1 and 2 take the tensor wrapped in a single-element list.
        for fixture in (inplace_case_1, inplace_case_2):
            self.assert_results(fixture, [paddle.randn((1, 4))])

    def test_backward(self):
        # The translated function is kept statement-for-statement: its exact
        # form is what is being exercised.
        @symbolic_translate
        def func(x):
            m = x * 2
            n = x * 3
            y = m
            y[:] = n
            return y

        inp = paddle.ones((1, 4)) * 4
        inp.stop_gradient = False
        out = func(inp)
        out.sum().backward()
        # After `y[:] = n` the gradient is expected to come from `x * 3` only.
        assert (inp.grad.numpy() == 3).all()

    def test_simple(self):
        first = paddle.to_tensor([1.0, 2.0])
        second = paddle.to_tensor([3.0, 4.0])
        self.assert_results(simple, first, second)

    def test_if(self):
        first = paddle.to_tensor([1.0, 2.0])
        second = paddle.to_tensor([3.0, 4.0])
        self.assert_results(inplace_in_if, first, second, True)

        first = paddle.to_tensor([1.0, 2.0])
        second = paddle.to_tensor([3.0, 4.0])
        flag = paddle.to_tensor(1)
        self.assert_results(inplace_in_if_fallback, first, second, flag)

    def test_loop(self):
        def fresh_tensors():
            # New tensors for every call, created in the same order.
            return paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0])

        self.assert_results(inplace_in_loop, *fresh_tensors())

        a = range(10)
        translated = symbolic_translate(inplace_in_loop_fallback)
        sym_output = translated(*fresh_tensors(), iter(a))
        paddle_output = inplace_in_loop_fallback(*fresh_tensors(), iter(a))
        self.assert_nest_match(sym_output, paddle_output)
130 |
131 |
132 | if __name__ == "__main__":
133 | unittest.main()
134 |
--------------------------------------------------------------------------------
/tests/test_break_graph.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | from test_case_base import TestCaseBase
5 |
6 | import paddle
7 | from sot.utils.paddle_api_config import add_break_graph_apis
8 |
9 |
def ifelse_func(x, y):
    """Return ``y + 1`` when ``x > 0``, otherwise ``y + 2``."""
    if x > 0:
        y = y + 1
    else:
        y = y + 2
    return y
16 |
17 |
class TestIfElse(TestCaseBase):
    """Checks ``ifelse_func`` with a positive tensor condition."""

    def test_simple(self):
        cond = paddle.to_tensor([1.0])
        value = paddle.to_tensor([2.0])
        self.assert_results(ifelse_func, cond, value)
23 |
24 |
def multi_output(x: paddle.Tensor):
    """Compute ``m = x + 1`` and return ``m`` if ``x > 0``, else ``2 * m``."""
    m = x + 1
    if x > 0:
        return m
    else:
        return 2 * m
31 |
32 |
class TestExecutor(TestCaseBase):
    """Checks ``multi_output`` on both sides of its ``x > 0`` branch."""

    def test_simple(self):
        # Positive input first, then negative, as separate checks.
        for value in (2, -2):
            self.assert_results(multi_output, paddle.to_tensor(value))
39 |
40 |
def print_break_graph(x, y):
    """Print ``x`` and ``x + y`` mid-computation, then return ``y * z * 2``.

    NOTE(review): per the name, the ``print`` call is presumably what forces
    a graph break - confirm.
    """
    z = x + y
    print(x, z)
    out = y * z * 2
    return out
46 |
47 |
class TestPrint(TestCaseBase):
    """Checks ``print_break_graph`` on two scalar tensors."""

    def test_simple(self):
        lhs = paddle.to_tensor(2)
        rhs = paddle.to_tensor(3)
        self.assert_results(print_break_graph, lhs, rhs)
53 |
54 |
def to_tensor_break_graph(x, y):
    """Return ``y * paddle.to_tensor(2) * (x + y)``.

    The ``paddle.to_tensor`` call sits in the middle of the expression.
    """
    z = x + y
    out = y * paddle.to_tensor(2) * z
    return out
59 |
60 |
class TestToTensor(TestCaseBase):
    """Checks ``to_tensor_break_graph`` after registering ``paddle.to_tensor``
    through ``add_break_graph_apis``."""

    def test_simple(self):
        # Registration happens before any tensor is created.
        add_break_graph_apis([paddle.to_tensor])
        lhs = paddle.to_tensor(2)
        rhs = paddle.to_tensor(3)
        self.assert_results(to_tensor_break_graph, lhs, rhs)
67 |
68 |
def tensor_clear_gradient(x):
    """Rebuild ``x`` with ``paddle.to_tensor``, clear its gradient, return it."""
    x = paddle.to_tensor(x)
    x.clear_gradient()
    return x
73 |
74 |
class TestBreakGraphInResumeFn(TestCaseBase):
    """Checks ``tensor_clear_gradient`` on a scalar tensor."""

    def test_simple(self):
        self.assert_results(tensor_clear_gradient, paddle.to_tensor(2))
79 |
80 |
def inner_fn(a, b, c, d):
    """Return ``a + b * c - d``."""
    return a + b * c - d
83 |
84 |
def multi_stack_args(a, b, c):
    """Call ``inner_fn`` with ``paddle.to_tensor(4)`` built as the last argument.

    The first three arguments are already evaluated when ``paddle.to_tensor``
    is called.
    """
    out = inner_fn(a, b, c, paddle.to_tensor(4))
    return out
88 |
89 |
class TestMultiStackArgs(TestCaseBase):
    """Checks ``multi_stack_args`` on three scalar tensors."""

    def test_simple(self):
        # Tensors for 1, 2 and 3 are created in that order.
        tensors = [paddle.to_tensor(value) for value in (1, 2, 3)]
        self.assert_results(multi_stack_args, *tensors)
96 |
97 |
def break_graph_in_call_method(x):
    """Return ``x`` plus ``relu`` of a freshly built ``[4.0]`` tensor.

    ``paddle.to_tensor`` is called as the argument of another call.
    """
    out = paddle.nn.functional.relu(paddle.to_tensor([4.0]))
    return x + out
101 |
102 |
def numpy_break_graph():
    """Sum a tensor's ``.numpy()`` view with ``np.sum``, print and return it."""
    a = paddle.to_tensor([1, 2])
    b = np.sum(a.numpy())
    print(b)
    return b
108 |
109 |
class TestBreakGraphInCallMethod(TestCaseBase):
    """Drives ``break_graph_in_call_method`` and ``numpy_break_graph``."""

    def test_simple(self):
        # Two direct calls with different inputs precede the checked run.
        for value in (1.0, 2.0):
            break_graph_in_call_method(paddle.to_tensor([value]))

        checked_input = paddle.to_tensor([3.0])
        self.assert_results(break_graph_in_call_method, checked_input)

    def test_numpy(self):
        self.assert_results(numpy_break_graph)
122 |
123 |
def test_break_graph_repeat(x):
    """Return ``x`` plus a tensor passed through four nested ``to_tensor`` calls."""
    out = paddle.to_tensor(
        paddle.to_tensor(paddle.to_tensor(paddle.to_tensor([1.0])))
    )
    return x + out
129 |
130 |
class TestBreakGraphRepeat(TestCaseBase):
    """Drives ``test_break_graph_repeat`` directly and through the checker."""

    def test_simple(self):
        # Two direct calls with different inputs precede the checked run.
        for value in (1.0, 2.0):
            test_break_graph_repeat(paddle.to_tensor([value]))

        checked_input = paddle.to_tensor([3.0])
        self.assert_results(test_break_graph_repeat, checked_input)
140 |
141 |
142 | if __name__ == "__main__":
143 | unittest.main()
144 |
--------------------------------------------------------------------------------
/docs/compat/python311/closure.md:
--------------------------------------------------------------------------------
1 | # Closure 适配
2 |
3 |
4 | ## Python 中的闭包示例
5 |
6 | 以下是在新版本中闭包函数处理的demo,以及它对应的字节码 :
7 |
8 | ```python
9 | import dis
10 |
11 | def func():
12 | free_x = 1
13 | free_y = 2
14 |
15 | def local(y):
16 | return y + free_x + free_y
17 | return local(1)
18 |
19 | dis.dis(func)
20 | ```
21 |
22 | ```bash
23 | 0 MAKE_CELL 1 (free_x)
24 | 2 MAKE_CELL 2 (free_y)
25 |
26 | 9 4 RESUME 0
27 |
28 | 10 6 LOAD_CONST 1 (1)
29 | 8 STORE_DEREF 1 (free_x)
30 |
31 | 11 10 LOAD_CONST 2 (2)
32 | 12 STORE_DEREF 2 (free_y)
33 |
34 | 13 14 LOAD_CLOSURE 1 (free_x)
35 | 16 LOAD_CLOSURE 2 (free_y)
36 | 18 BUILD_TUPLE 2
37 | 20 LOAD_CONST 3 ()
38 | 22 MAKE_FUNCTION 8 (closure)
39 | 24 STORE_FAST 0 (local)
40 |
41 | 15 26 PUSH_NULL
42 | 28 LOAD_FAST 0 (local)
43 | 30 LOAD_CONST 1 (1)
44 | 32 PRECALL 1
45 | 36 CALL 1
46 | 46 RETURN_VALUE
47 |
48 | Disassembly of :
49 | 0 COPY_FREE_VARS 2
50 |
51 | 13 2 RESUME 0
52 |
53 | 14 4 LOAD_FAST 0 (y)
54 | 6 LOAD_DEREF 1 (free_x)
55 | 8 BINARY_OP 0 (+)
56 | 12 LOAD_DEREF 2 (free_y)
57 | 14 BINARY_OP 0 (+)
58 | 18 RETURN_VALUE
59 | ```
60 |
61 | ## 新版本中对字节码的改动:
62 |
63 | ### 首先是语义上的改动
64 |
65 | LOAD_CLOSURE: 新版本不再是`co_cellvars + co_freevars`长度偏移量, 而是`LOAD_FAST`的一个别名
66 |
67 | LOAD_DEREF: 加载包含在 locals 中的元素
68 |
69 | STORE_DEREF: 存储 TOS 到 locals 中
70 |
71 | ### 新增字节码
72 |
73 | MAKE_CELL: 如果元素不存在于 locals 则从 co_freevars 和 co_cellvars 中加载
74 |
75 | COPY_FREE_VARS: 复制 co_freevars 和 co_cellvars 中的元素到 locals
76 |
77 | ## 分析
78 |
79 | 从字节码上的改动来看,在 python3.11 中, 闭包将数据存储在 locals 中,而不是 cell 中,这样做的好处是可以减少一次间接寻址,提高性能。
80 |
81 | ## 实现
82 |
83 | LOAD_CLOSURE: 作为`LOAD_FAST`的别名,所以直接调用
84 |
85 | LOAD_DEREF: 改为从 `self._locals` 中加载元素到 TOS 中
86 |
87 | STORE_DEREF: 改为存储 TOS 到 `self._locals` 中
88 |
89 | MAKE_CELL: 从 `self._cells` 中加载元素到 `self._locals`
90 |
91 | COPY_FREE_VARS(闭包内部字节码): 从 `self._code.co_freevars` 拿到 key 在 `self._cells` 中找到元素存储到 `self._locals`
92 |
93 | ## codegen
94 |
95 | ```bash
96 | [transform] NewCode: #foo_af1a0
97 | 9 0 MAKE_CELL 0 (x) # 在此处生成存储字节码,将元素存储至 locals
98 | 2 MAKE_CELL 1 (y)
99 | 4 MAKE_CELL 5 (z)
100 | 6 RESUME 0
101 | 8 LOAD_GLOBAL 1 (NULL + paddle_set_eval_frame_fn)
102 | ...
103 | 104 POP_TOP
104 | 106 RETURN_VALUE
105 |
106 | Disassembly of :
107 | 0 COPY_FREE_VARS 3 # 在此处生成拷贝字节码,将数据拷贝至闭包内部调用
108 |
109 | 12 2 RESUME 0
110 |
111 | 13 4 LOAD_FAST 0 (a)
112 | ...
113 | 30 RETURN_VALUE
114 |
115 | ```
116 |
117 |
118 | ## 单测
119 |
120 | 新增一项此前未覆盖的情况
121 |
122 | ```python
123 | def create_closure():
124 | x = 1
125 |
126 | def closure():
127 | return x + 1
128 |
129 | return closure
130 | ```
131 |
132 | ## 其他更改
133 |
134 | 此次升级还依赖于 eval frame 修改,相关适配链接:[#57490](https://github.com/PaddlePaddle/Paddle/pull/57490)、[#57653](https://github.com/PaddlePaddle/Paddle/pull/57653)
135 |
--------------------------------------------------------------------------------
/docs/design/builtin-dispatcher.md:
--------------------------------------------------------------------------------
1 | # Builtin 函数派发机制
2 |
3 | ## 什么是 BuiltinVariable?
4 |
5 | 什么是 BuiltinVariable 呢?最开始我们以为 BuiltinVariable 应该就是 builtin 这个 namespace 里的各种变量,比如 `int`、`abs` 等等这些不需要 import 就可以直接使用的变量,但是实际上这一角色已经由 BuiltinTracker 承担了,BuiltinVariable 就和 BuiltinTracker 定位重复了。
6 |
7 | 对于现有的其他 CallableVariable,每一个 Variable 的定位都很清晰,实现方式也很清晰,比如以 UserDefinedFunctionVariable 为代表的 inline call 方式、以 PaddleApiVariable 为代表的组网方式、以及部分需要子图打断的 API。但是 BuiltinVariable 不是这样的,对于大多数 builtin 函数,在 Python 执行时会去调用对应的 magic method,比如 `len` 会调用 `__len__`,此时执行的效果是与变量类型强相关的,比如用户重载了一个自定义对象的 `__len__`,此时应该去尝试 inline call,因为用户的代码中包含各种各样的情况,而部分对象的 `__len__` 不是写在 Python 端的,此时我们应该去模拟其执行效果。
8 |
9 | 另外值得注意的是,Python 会认为不在 Python 端定义的、没有字节码的函数的类型为 `types.BuiltinFunctionType`,这样 BuiltinVariable 的定位已经很清晰了,即**没有字节码,无法通过字节码来模拟执行的函数**。
10 |
11 | ## 为什么需要派发机制
12 |
13 | 如果无法直接模拟字节码的话,我们在模拟执行时要怎么做呢?起初我们直接利用 Python 的派发机制,在 Variable 上重载 magic method,实现自动派发:
14 |
15 | ```python
16 | class BuiltinVariable:
17 | def call_function(self, *args: VariableBase, **kwargs: VariableBase):
18 | self.value(*args, **kwargs)
19 | ```
20 |
21 | 但这样问题很明显,比如其实 Python 对部分 magic method 输出是有强制类型检查的,比如 `__bool__` 强制是 bool 类型、`__len__` 强制是 int 类型。而按照上述的实现,我们应该返回的是 VariableBase,这就导致部分 magic method 是无法复用 Python 内建的派发机制的。另一方面,也因为很多 Variable 还没有实现相应的 magic method 而报错。
22 |
23 | 为了避免这些问题,我们添加了一个类似 magic method 的派发机制,来将 BuiltinVariable 的调用转发到某一个具体的函数上。这个派发机制会去尝试匹配参数的类型,如果找不到匹配的参数类型,就会尝试从类上获取相应的 magic method 来 inline call,如果仍然找不到,则会产生一个 BreakGraphError 来打断子图。
24 |
25 | ## 派发机制的实现和使用方式
26 |
27 | 派发机制的本质是对参数类型进行匹配,对于一个函数,我们会有多种允许的类型签名(Pattern)以及对应的 handler,而在调用这个函数的时候,我们会根据参数类型来派发到相应的 handler 上,主体代码如下:
28 |
29 | ```python
30 | class Pattern:
31 | type_strings: Args[str]
32 | kwtype_strings: Kwargs[str]
33 |
34 | def __init__(
35 | self,
36 | *types: str,
37 | **kwtypes: str,
38 | ):
39 | self.type_strings = types
40 | self.kwtype_strings = kwtypes
41 |
42 | class Dispatcher:
43 | handlers: dict[
44 | Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]]
45 | ] = {}
46 |
47 | @classmethod
48 | def register(
49 | cls,
50 | fn: Callable[..., Any],
51 | types: tuple[str, ...],
52 | kwtypes: dict[str, str],
53 | handler: Callable[..., Any],
54 | ):
55 | if fn not in cls.handlers:
56 | cls.handlers[fn] = []
57 | cls.handlers[fn].append((Pattern(*types, **kwtypes), handler))
58 |
59 | @classmethod
60 | def dispatch(
61 | cls, fn: Callable[..., Any], *args: Any, **kwargs: Any
62 | ) -> Callable[..., Any] | None:
63 | if fn not in cls.handlers:
64 | return None
65 | for pattern, handler in cls.handlers[fn]:
66 | if pattern.match_inputs(*args, **kwargs):
67 | return handler
68 | return None
69 | ```
70 |
71 | 这样,我们只需要调用 `Dispatcher.register` 将函数签名和 handler 注册到 Dispatcher 上即可:
72 |
73 | ```python
74 | Dispatcher.register(
75 | len,
76 | ("ContainerVariable",),
77 | {},
78 | lambda var: var.len(),
79 | )
80 | ```
81 |
82 | 为了方便使用,我们还提供了一个装饰器的模式:
83 |
84 |
85 | ```python
86 | if TYPE_CHECKING:
87 | from .variables import ContainerVariable
88 |
89 | @Dispatcher.register_decorator(len)
90 | def dispatch_len(var: ContainerVariable):
91 | return var.len()
92 | ```
93 |
94 | 对于一些复杂的函数,是比较推荐装饰器模式的。
95 |
96 | ## 利用派发机制简化现有代码
97 |
98 | 在实现派发机制之前,我们在很多地方会有重复的代码,比如 Python 的 `a.b` 和 `getattr(a, "b")` 是等价的,但是字节码层面是完全不同的,前者是 `LOAD_ATTR`,后者则是 `CALL_FUNCTION`,此前我们也是各自实现的。
99 |
100 | 在实现了派发机制之后,我们完全可以利用派发机制来实现 `LOAD_ATTR`:
101 |
102 | ```python
103 | @call_break_graph_decorator(push_n=1)
104 | def LOAD_ATTR(self, instr):
105 | attr_name = instr.argval
106 | obj = self.pop()
107 | self.push(
108 | BuiltinVariable(
109 | getattr, graph=self._graph, tracker=DanglingTracker()
110 | )(obj, attr_name)
111 | )
112 | ```
113 |
114 | 这样可以极大地简化代码,降低维护成本。
115 |
116 | > **Note**
117 | >
118 | > 我们现在代码中仍然有很多地方使用了旧的派发机制(利用 Python magic method 直接派发),这些将在之后逐步替换
119 |
--------------------------------------------------------------------------------
/sot/opcode_translator/skip_files.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import codecs
3 | import collections
4 | import contextlib
5 | import copy
6 | import copyreg
7 | import dataclasses
8 | import distutils
9 | import enum
10 | import functools
11 | import importlib
12 | import inspect
13 | import linecache
14 | import logging
15 | import multiprocessing
16 | import operator
17 | import os
18 | import posixpath
19 | import random
20 | import re
21 | import selectors
22 | import signal
23 | import sys
24 | import tempfile
25 | import threading
26 | import tokenize
27 | import traceback
28 | import types
29 | import typing
30 | import unittest
31 | import uuid
32 | import warnings
33 | import weakref
34 |
35 | import _collections_abc
36 | import _weakrefset
37 | import decorator
38 | import google.protobuf
39 | import numpy
40 | import setuptools
41 |
42 | import paddle
43 |
44 | from ..utils import log
45 |
46 | NEED_SKIP_THIRD_PARTIY_MODULES = {
47 | abc,
48 | collections,
49 | contextlib,
50 | copy,
51 | copyreg,
52 | dataclasses,
53 | enum,
54 | functools,
55 | google.protobuf,
56 | importlib,
57 | inspect,
58 | linecache,
59 | logging,
60 | multiprocessing,
61 | numpy,
62 | operator,
63 | os,
64 | posixpath,
65 | random,
66 | re,
67 | selectors,
68 | signal,
69 | tempfile,
70 | threading,
71 | tokenize,
72 | traceback,
73 | types,
74 | typing,
75 | unittest,
76 | weakref,
77 | _collections_abc,
78 | _weakrefset,
79 | decorator,
80 | codecs,
81 | uuid,
82 | setuptools,
83 | distutils,
84 | warnings,
85 | }
86 |
87 | if sys.version_info < (3, 11):
88 | import sre_compile
89 | import sre_parse
90 |
91 | NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_compile)
92 | NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_parse)
93 |
94 |
95 | def _strip_init_py(s):
96 | return re.sub(r"__init__.py$", "", s)
97 |
98 |
99 | def _module_dir(m: types.ModuleType):
100 | return _strip_init_py(m.__file__)
101 |
102 |
103 | skip_file_names = {_module_dir(m) for m in NEED_SKIP_THIRD_PARTIY_MODULES}
104 |
105 |
106 | sot_path = os.path.dirname(__file__).rpartition("/")[0] + "/"
107 | paddle_path = sys.modules["paddle"].__file__.rpartition("/")[0] + "/"
108 |
109 | skip_file_names.add(sot_path)
110 | skip_file_names.add(paddle_path)
111 | skip_file_names.add(
112 | "")
115 |
116 | skip_file_name_re = re.compile(
117 | f"^({'|'.join(map(re.escape, skip_file_names))})"
118 | )
119 |
120 | customed_skip_code = set()
121 |
122 | no_skip_code = {paddle.nn.Sequential.forward.__code__}
123 |
124 |
125 | def need_skip_path(filepath: str) -> bool:
126 | """
127 | Check if the file should be skipped and not transcribed.
128 |
129 | Args:
130 | filepath: The path of the file to check.
131 |
132 | Returns:
133 | bool: True if the file should be skipped.
134 | """
135 | if not filepath.startswith("<"):
136 | filepath = os.path.abspath(filepath)
137 | return bool(skip_file_name_re.match(filepath))
138 |
139 |
140 | def skip_function(function):
141 | customed_skip_code.add(function.__code__)
142 | return function
143 |
144 |
145 | def need_skip(frame):
146 | pycode = frame.f_code
147 | if pycode in no_skip_code:
148 | return False
149 | if pycode in customed_skip_code:
150 | log(3, f"Skip frame by code: {pycode}\n")
151 | return True
152 | filename = pycode.co_filename
153 | if sys.version_info >= (3, 11) and filename.startswith(" 0.0057
64 | with EventGuard("FallbackWrapper: call partial_program"):
65 | outputs = self.partial_program(*args, **kwargs)
66 |
67 | clear_eager_tensor_name(outputs)
68 | log_do(
69 | 1,
70 | lambda: GraphLogger().add_subgraph(
71 | self.concrete_program.main_program
72 | ),
73 | )
74 | log_do(
75 | 4,
76 | lambda: print("[CompileCache] run sir forward success."),
77 | )
78 | return outputs
79 |
80 |
81 | @Singleton
82 | class CompileSIRCache(Cache):
83 | """
84 | Cache the compiled function of SIR
85 | """
86 |
87 | def __init__(self):
88 | super().__init__(weak=False)
89 |
90 | def key_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs):
91 | """
92 | generate a hash key for a SIR
93 |
94 | Args:
95 | context: The context to compile
96 | sir_name: The name of the sir to compile
97 | build_strategy: The build strategy to compile
98 |
99 | Returns:
100 | The hash key of the SIR
101 | """
102 | sir = context.get_sir(sir_name)
103 | # NOTE(dev): Is str(sir) a heavy operation?
104 | hash_key = hash(str(sir))
105 | return hash_key
106 |
107 | def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs):
108 | """
109 | Generate static graph function
110 |
111 | Args:
112 | context: The context to compile
113 | sir_name: The name of the sir to compile
114 | build_strategy: The build strategy to compile
115 |
116 | Returns:
117 | The static graph function
118 | """
119 | build_strategy = kwargs.get("build_strategy", None)
120 | backend = kwargs.get("backend", None)
121 | return FallbackWrapper(
122 | paddle.jit.to_static(
123 | compile_sir(context, sir_name),
124 | build_strategy=build_strategy,
125 | backend=backend,
126 | enable_fallback=False,
127 | ),
128 | context.get_sir(sir_name),
129 | )
130 |
--------------------------------------------------------------------------------
/sot/translate.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from typing import TYPE_CHECKING, Callable, TypeVar
5 |
6 | import paddle
7 |
8 | from .opcode_translator import eval_frame_callback
9 | from .utils import GraphLogger, StepInfoManager, StepState, log_do
10 |
11 | if TYPE_CHECKING:
12 | from typing_extensions import ParamSpec
13 |
14 | P = ParamSpec("P")
15 | R = TypeVar("R")
16 |
17 | # Temporarily set the default log level to 2 to get more information in CI log.
18 | os.environ["LOG_LEVEL"] = os.getenv("LOG_LEVEL", "2")
19 |
20 |
21 | def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
22 | """
23 | This function is the entry point of PaddleSOT. It sets eval_frame_callback before input
24 | function to achieve Opcode-level translation. The translation process depends on the
25 | simulation execution, in which information will be collected, especially the network
26 | code. After the simulation execution is completed, the network code will be compiled
27 | into a static graph Program to improve performance.
28 |
29 | Args:
30 | fn: The input function.
31 |
32 | Returns:
33 | Callable, The wrapped function.
34 |
35 | Examples:
36 | >>> # doctest: +SKIP("Could not get source code of function foo.")
37 | >>> import paddle
38 | >>> import numpy as np
39 | >>> from sot.translate import symbolic_translate
40 | >>> def foo(cond: paddle.Tensor, x: paddle.Tensor):
41 | ... x += 1
42 | ... if cond:
43 | ... x += 1
44 | ... else:
45 | ... x -= 1
46 | ... return x
47 | >>> symbolic_translate_foo = symbolic_translate(foo)
48 | >>> # For the true branch, the output is 2.
49 | >>> cond = paddle.to_tensor(True)
50 | >>> x = paddle.to_tensor(0)
51 | >>> dygraph_out = foo(cond, x)
52 | >>> symbolic_translate_out = symbolic_translate_foo(cond, x)
53 | >>> dygraph_out
54 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
55 | 2)
56 | >>> symbolic_translate_out
57 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
58 | 2)
59 | >>> np.testing.assert_allclose(
60 | ... dygraph_out.numpy(), symbolic_translate_out.numpy()
61 | ... )
62 | >>> # For the false branch, the output is 0.
63 | >>> cond = paddle.to_tensor(False)
64 | >>> dygraph_out = foo(cond, x)
65 | >>> symbolic_translate_out = symbolic_translate_foo(cond, x)
66 | >>> dygraph_out
67 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
68 | 0)
69 | >>> symbolic_translate_out
70 | Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
71 | 0)
72 | >>> np.testing.assert_allclose(
73 | ... dygraph_out.numpy(), symbolic_translate_out.numpy()
74 | ... )
75 |
76 | """
77 |
78 | def callback(frame):
79 | return eval_frame_callback(frame, **kwargs)
80 |
81 | def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
82 | assert hasattr(
83 | fn, "__code__"
84 | ), "Target function doesn't have code for simulating."
85 | StepInfoManager().sot_step()
86 | GraphLogger().clear()
87 | paddle.framework.core.set_eval_frame(callback)
88 | try:
89 | outs = fn(*args, **kwargs)
90 | except Exception as e:
91 | raise e
92 | finally:
93 | paddle.framework.core.set_eval_frame(None)
94 |
95 | log_do(1, lambda: GraphLogger().print_info())
96 | return outs
97 |
98 | def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R:
99 | outs = fn(*args, **kwargs)
100 | return outs
101 |
102 | def impl(*args: P.args, **kwargs: P.kwargs) -> R:
103 | with StepInfoManager().step_guard(fn.__code__):
104 | state = StepInfoManager().current_state
105 |
106 | if state == StepState.RUN_SOT:
107 | return impl_sot(*args, **kwargs)
108 | elif state == StepState.RUN_DYN:
109 | return impl_dynamic(*args, **kwargs)
110 | elif state == StepState.COLLECT_INFO:
111 | return StepInfoManager().collect_info(
112 | impl_dynamic, impl_sot, *args, **kwargs
113 | )
114 |
115 | return impl
116 |
--------------------------------------------------------------------------------
/sot/utils/magic_methods.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import operator
4 | from dataclasses import dataclass
5 | from typing import TYPE_CHECKING, Any, Callable
6 |
7 | from .utils import hashable
8 |
9 | if TYPE_CHECKING:
10 | BinaryOp = Callable[[Any, Any], Any]
11 | UnaryOp = Callable[[Any], Any]
12 |
13 |
14 | INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[BinaryOp, tuple[str, BinaryOp]] = {
15 | # inplace op fn: (magic name, non-inplace op fn)
16 | operator.iadd: ("__iadd__", operator.add),
17 | operator.iand: ("__iand__", operator.and_),
18 | operator.iconcat: ("__iconcat__", operator.concat),
19 | operator.ifloordiv: ("__ifloordiv__", operator.floordiv),
20 | operator.ilshift: ("__ilshift__", operator.lshift),
21 | operator.imatmul: ("__imatmul__", operator.matmul),
22 | operator.imod: ("__imod__", operator.mod),
23 | operator.imul: ("__imul__", operator.mul),
24 | operator.ior: ("__ior__", operator.or_),
25 | operator.ipow: ("__ipow__", operator.pow),
26 | operator.irshift: ("__irshift__", operator.rshift),
27 | operator.isub: ("__isub__", operator.sub),
28 | operator.itruediv: ("__itruediv__", operator.truediv),
29 | operator.ixor: ("__ixor__", operator.xor),
30 | }
31 |
32 | NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[
33 | BinaryOp, tuple[str, str | None]
34 | ] = {
35 | # op fn: (magic name, reverse magic name)
36 | operator.add: ("__add__", "__radd__"),
37 | operator.and_: ("__and__", "__rand__"),
38 | operator.contains: ("__contains__", None),
39 | operator.delitem: ("__delitem__", None),
40 | operator.eq: ("__eq__", "__eq__"),
41 | operator.floordiv: ("__floordiv__", "__rfloordiv__"),
42 | operator.ge: ("__ge__", "__le__"),
43 | operator.getitem: ("__getitem__", None),
44 | operator.gt: ("__gt__", "__lt__"),
45 | operator.le: ("__le__", "__ge__"),
46 | operator.lshift: ("__lshift__", "__rlshift__"),
47 | operator.lt: ("__lt__", "__gt__"),
48 | operator.matmul: ("__matmul__", "__rmatmul__"),
49 | operator.mod: ("__mod__", "__rmod__"),
50 | operator.mul: ("__mul__", "__rmul__"),
51 | operator.ne: ("__ne__", "__ne__"),
52 | operator.or_: ("__or__", "__ror__"),
53 | operator.pow: ("__pow__", "__rpow__"),
54 | operator.rshift: ("__rshift__", "__rrshift__"),
55 | operator.sub: ("__sub__", "__rsub__"),
56 | operator.truediv: ("__truediv__", "__rtruediv__"),
57 | operator.xor: ("__xor__", "__rxor__"),
58 | }
59 |
60 | UNARY_OPS_TO_MAGIC_NAMES: dict[UnaryOp, str] = {
61 | operator.neg: "__neg__",
62 | operator.invert: "__invert__",
63 | operator.pos: "__pos__",
64 | operator.abs: "__abs__",
65 | operator.index: "__index__",
66 | operator.inv: "__inv__",
67 | operator.invert: "__invert__",
68 | operator.not_: "__not__",
69 | operator.pos: "__pos__",
70 | operator.truth: "__bool__",
71 | bool: "__bool__",
72 | abs: "__abs__",
73 | float: "__float__",
74 | len: "__len__",
75 | int: "__int__",
76 | }
77 | # TODO(SigureMo): support any, all, sum
78 |
79 |
80 | INPLACE_BINARY_OPS = set(INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys())
81 | NON_INPLACE_BINARY_OPS = set(NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys())
82 | BINARY_OPS = INPLACE_BINARY_OPS | NON_INPLACE_BINARY_OPS
83 | UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys())
84 |
85 |
86 | @dataclass
87 | class MagicMethod:
88 | name: str
89 | is_inplace: bool = False
90 | is_reverse: bool = False
91 |
92 |
93 | def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]:
94 | if not hashable(fn):
95 | return []
96 | if fn in INPLACE_BINARY_OPS:
97 | inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[
98 | fn
99 | ]
100 | return [
101 | MagicMethod(inplace_magic_name, is_inplace=True)
102 | ] + magic_method_builtin_dispatch(non_inplace_op)
103 | elif fn in NON_INPLACE_BINARY_OPS:
104 | magic_name, reverse_magic_name = NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES[
105 | fn
106 | ]
107 | magic_methods = [MagicMethod(magic_name)]
108 | if reverse_magic_name is not None:
109 | magic_methods.append(
110 | MagicMethod(reverse_magic_name, is_reverse=True)
111 | )
112 | return magic_methods
113 | elif fn in UNARY_OPS:
114 | magic_name = UNARY_OPS_TO_MAGIC_NAMES[fn]
115 | return [MagicMethod(magic_name)]
116 | return []
117 |
--------------------------------------------------------------------------------
/tests/test_code_status.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from test_case_base import TestCaseBase, strict_mode_guard
4 |
5 | import paddle
6 | import sot
7 | from sot.opcode_translator.skip_files import skip_function
8 | from sot.utils.code_status import CodeState, CodeStatus
9 |
10 |
11 | class SimpleNet1(paddle.nn.Layer):
12 | def __init__(self):
13 | super().__init__()
14 | self.layers = paddle.nn.LayerList(
15 | [paddle.nn.Linear(10, 10) for _ in range(30)]
16 | )
17 |
18 | def forward(self, x):
19 | for i in range(len(self.layers)):
20 | sot.psdb.breakgraph()
21 | x = self.layers[i](x)
22 | x = self.layers[i](x)
23 | x = self.layers[i](x)
24 | x = self.layers[i](x)
25 | return x
26 |
27 |
28 | class SimpleNet2(paddle.nn.Layer):
29 | def __init__(self):
30 | super().__init__()
31 | self.layers = paddle.nn.LayerList(
32 | [paddle.nn.Linear(10, 10) for _ in range(30)]
33 | )
34 |
35 | def forward(self, x):
36 | sot.psdb.fallback()
37 | for i in range(len(self.layers)):
38 | x = self.layers[i](x)
39 | x = self.layers[i](x)
40 | x = self.layers[i](x)
41 | x = self.layers[i](x)
42 | return x
43 |
44 |
45 | def run_net(net, x):
46 | for i in range(20):
47 | x = net(x)
48 | return x
49 |
50 |
51 | class TestCodeInfo(TestCaseBase):
52 | def test_case_1(self):
53 | CodeStatus().clear()
54 | net = SimpleNet1()
55 | inp = paddle.rand((10, 10))
56 | self.assert_results(run_net, net, inp)
57 | code_map = CodeStatus().code_map
58 | states = []
59 | for k, v in code_map.items():
60 | if k.co_name.startswith("#") or k.co_name.startswith("$"):
61 | states.append(v)
62 | elif k in CodeStatus().WITH_GRAPH_API:
63 | assert v.state == CodeState.WITH_GRAPH
64 | else:
65 | assert v.state == CodeState.WITHOUT_GRAPH
66 | # run_net, forward, loop body, resumed part2 in loop body
67 | assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4
68 | # resumed part1 in loop body
69 | assert (
70 | len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1
71 | )
72 |
73 | def test_case_2(self):
74 | with strict_mode_guard(0):
75 | CodeStatus().clear()
76 | net = SimpleNet2()
77 | inp = paddle.rand((10, 10))
78 | self.assert_results(run_net, net, inp)
79 | code_map = CodeStatus().code_map
80 | states = []
81 | for k, v in code_map.items():
82 | if k.co_name.startswith("#") or k.co_name.startswith("$"):
83 | states.append(v)
84 | elif k in CodeStatus().WITH_GRAPH_API:
85 | assert v.state == CodeState.WITH_GRAPH
86 | else:
87 | assert v.state == CodeState.WITHOUT_GRAPH
88 | # no graph found because fallback (paddle api will not enter simulate)
89 | assert (
90 | len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0
91 | )
92 |
93 |
94 | def no_skip_func_0(x):
95 | return x + 1
96 |
97 |
98 | def skipped_func_0():
99 | pass
100 |
101 |
102 | def skipped_func_1(x):
103 | return x + 1
104 |
105 |
106 | def skipped_func_2(x):
107 | return no_skip_func_0(x)
108 |
109 |
110 | def call_skipped_func_0(x):
111 | for i in range(15):
112 | skipped_func_0()
113 | x = skipped_func_1(x)
114 | x = skipped_func_2(x)
115 | return x
116 |
117 |
118 | skip_function(skipped_func_0)
119 | skip_function(skipped_func_1)
120 | skip_function(skipped_func_2)
121 | skip_function(call_skipped_func_0)
122 |
123 |
124 | class TestDisableSkippedFrame(TestCaseBase):
125 | def test_case_0(self):
126 | CodeStatus().clear()
127 | x = paddle.to_tensor([1])
128 | self.assert_results(call_skipped_func_0, x)
129 | code_map = CodeStatus().code_map
130 | assert (
131 | code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH
132 | )
133 | assert (
134 | code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH
135 | )
136 | assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH
137 |
138 |
139 | if __name__ == "__main__":
140 | unittest.main()
141 |
--------------------------------------------------------------------------------
/tests/test_21_global.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import unittest
4 |
5 | from test_case_base import TestCaseBase
6 |
7 | import paddle
8 | import sot
9 |
10 | global_x = 1
11 | global_y = paddle.to_tensor(2)
12 | global_z = None
13 | global_del_val = 1
14 | global_dict = {}
15 | global_list = [1, 2]
16 | global_inline = 0
17 |
18 |
19 | def global_func_int():
20 | global global_x
21 | global_x = global_x + 1
22 | return global_x
23 |
24 |
25 | def global_func_int_add():
26 | global global_x
27 | global_x = global_x + global_x
28 | return global_x + global_x
29 |
30 |
31 | def global_func_tensor_int_add(tensor_y: paddle.Tensor):
32 | global global_x
33 | global_x += 1
34 | return global_x + tensor_y
35 |
36 |
37 | def global_multiple_update():
38 | global global_x
39 | global_x = 999
40 | global_x = 888
41 | global_x = 777
42 | return global_x - 1
43 |
44 |
45 | def global_func_tensor():
46 | global global_y
47 | global_y = global_y + global_y
48 | return global_y
49 |
50 |
51 | def global_func_tensor_add():
52 | global global_y
53 | global_y = global_y + global_y
54 | return global_y + global_y
55 |
56 |
57 | def global_func():
58 | global global_x
59 | global global_y
60 | global global_z
61 |
62 | global_z = global_x + global_y
63 | return global_z
64 |
65 |
66 | def global_del_global():
67 | global global_del_val
68 |
69 | del global_del_val
70 |
71 |
72 | def global_func_dict():
73 | global global_dict
74 | global_dict["key"] = "value"
75 | global_dict.update({"test_key1": "test_value2"})
76 | return global_dict
77 |
78 |
79 | def global_func_control1():
80 | global global_dict
81 | if "key" in global_dict:
82 | del global_dict["key"]
83 | return global_dict
84 |
85 |
86 | def global_func_control2():
87 | global global_list
88 | for i in range(len(global_list)):
89 | global_list[i] = global_list[i] + 1
90 | return global_list
91 |
92 |
93 | def global_func_inline_inner_1():
94 | global global_inline
95 | global_func_inline_inner_2()
96 | global_inline += 1
97 |
98 |
99 | def global_func_inline_inner_2():
100 | global global_inline
101 | global_inline += 1
102 |
103 |
104 | def global_func_inline():
105 | global_func_inline_inner_1()
106 | global global_inline
107 | return global_inline
108 |
109 |
110 | class TestGlobal(TestCaseBase):
111 | def test_global_func_int(self):
112 | global global_x
113 | self.assert_results_with_global_check(global_func_int, ["global_x"])
114 | global_x += 1
115 | self.assert_results_with_global_check(global_func_int, ["global_x"])
116 | self.assert_results_with_global_check(global_func_int_add, ["global_x"])
117 |
118 | def test_global_multiple_update(self):
119 | self.assert_results_with_global_check(
120 | global_multiple_update, ["global_x"]
121 | )
122 |
123 | def test_global_func_tensor_int_add(self):
124 | self.assert_results_with_global_check(
125 | global_func_tensor_int_add, ["global_x"], paddle.to_tensor(1)
126 | )
127 |
128 | def test_global_func_tensor(self):
129 | self.assert_results_with_global_check(global_func_tensor, ["global_y"])
130 | self.assert_results_with_global_check(
131 | global_func_tensor_add, ["global_y"]
132 | )
133 |
134 | def test_global_func(self):
135 | self.assert_results_with_global_check(global_func, ["global_z"])
136 | self.assertIn("global_del_val", global_del_global.__globals__)
137 | sot.symbolic_translate(global_del_global)()
138 | self.assertNotIn("global_del_val", global_del_global.__globals__)
139 |
140 | def test_global_func_dict(self):
141 | self.assert_results_with_global_check(global_func_dict, ["global_dict"])
142 | self.assert_results_with_global_check(
143 | global_func_control1, ["global_dict"]
144 | )
145 |
146 | def test_global_func_list(self):
147 | self.assert_results_with_global_check(
148 | global_func_control2, ["global_list"]
149 | )
150 |
151 | def test_global_func_inline(self):
152 | global global_inline
153 | global_inline = 0
154 | sot.symbolic_translate(global_func_inline)()
155 | self.assertEqual(global_inline, 2)
156 | sot.symbolic_translate(global_func_inline)()
157 | self.assertEqual(global_inline, 4)
158 |
159 |
160 | if __name__ == "__main__":
161 | unittest.main()
162 |
--------------------------------------------------------------------------------
/sot/symbolic/symbolic_context.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from ..utils import log
4 | from .compile_cache import CompileSIRCache
5 | from .statement_ir import (
6 | ApiStatement,
7 | CallStatement,
8 | LayerStatement,
9 | MethodStatement,
10 | StatementIR,
11 | StatementIRFactory,
12 | Symbol,
13 | )
14 |
15 |
16 | class SymbolicTraceContext:
17 | """
18 | SymbolicTraceContext is a context manager, which is used to record the symbolic trace.
19 |
20 | """
21 |
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | """
27 | Reset the context.
28 | """
29 |
30 | # TODO(dev): StatementIRFactory is a singleton, but SymbolicTraceContext is not.
31 | # Will two different SymbolicTraceContext objects conflict with each other?
32 | self.statement_factory = StatementIRFactory()
33 | self.sir_stack = [self.statement_factory.create()]
34 |
35 | @property
36 | def TOS(self):
37 | """
38 | The top SIR of sir_stack.
39 |
40 | Returns:
41 | StatementIR: the top of stack.
42 | """
43 |
44 | return self.sir_stack[-1]
45 |
46 | def call_SIR(self, sirname, inputs, outputs, stacks):
47 | """
48 | Call a SIR, which is a subgraph.
49 | """
50 |
51 | stmt = CallStatement(sirname, inputs, outputs, stacks)
52 | self.TOS.add_statement(stmt)
53 |
54 | def call_API(self, api, inputs, outputs, stacks):
55 | """
56 | Call a paddle api.
57 | """
58 |
59 | assert callable(api), "call_API must receive a paddle api."
60 | stmt = ApiStatement(api, inputs, outputs, stacks)
61 | self.TOS.add_statement(stmt)
62 |
63 | def call_METHOD(self, method_name, inputs, outputs, stacks):
64 | """
65 | Call a method of an API. The API here can be a Python or a Paddle API.
66 | """
67 | assert isinstance(
68 | method_name, str
69 | ), "call_METHOD must method api name. string."
70 | assert isinstance(
71 | inputs[0][0], Symbol
72 | ), "call_METHOD must first augument must be Symbol Variable."
73 | stmt = MethodStatement(method_name, inputs, outputs, stacks)
74 | self.TOS.add_statement(stmt)
75 |
76 | def call_LAYER(self, layer, inputs, outputs, stacks):
77 | """
78 | Call a layer of an API.
79 | """
80 | stmt = LayerStatement(layer, inputs, outputs, stacks)
81 | self.TOS.add_statement(stmt)
82 |
83 | def get_sir(self, name: str):
84 | """
85 | Get a SIR from statement_factory.
86 |
87 | Args:
88 | name (str): the name of SIR.
89 |
90 | Returns:
91 | StatementIR: the SIR.
92 | """
93 | return self.statement_factory[name]
94 |
95 | def reset_TOS(self):
96 | """
97 | Reset the TOS.
98 | """
99 | self.sir_stack.pop()
100 | self.sir_stack.append(self.statement_factory.create())
101 |
102 | def replace_TOS(self, sir):
103 | """
104 | Use a deep-copied SIR to replace the TOS.
105 | This function will also update statement_factory.
106 | """
107 | self.sir_stack.pop()
108 | self.sir_stack.append(sir)
109 | self.statement_factory.update(sir)
110 |
111 | def compile_do_nothing(self, ret_vals):
112 | """
113 | Return a dummy function, which will return an empty list.
114 |
115 | Args:
116 | ret_vals (list[Symbol]): the return values of the function.
117 | """
118 |
119 | def dummy_func(*args, **kwargs):
120 | return []
121 |
122 | # return None function
123 | dummy_stmt_ir = StatementIR("dummy_func")
124 | dummy_stmt_ir.outputs = []
125 | dummy_stmt_ir.inputs = []
126 | return dummy_func, dummy_stmt_ir
127 |
128 | def compile_fn(self, ret_vals, **kwargs):
129 | """
130 | Start compiling and return the Python function, which must be convertible by to_static without errors.
131 | """
132 | cur_sir: StatementIR = self.TOS
133 | # step0: if no statement, return a dummy function
134 | if len(cur_sir.statements) == 0:
135 | return self.compile_do_nothing(ret_vals)
136 | # step1: analyse sir inputs and outputs
137 | cur_sir.inputs = cur_sir.analyse_inputs()
138 | # TODO: output analysis
139 | cur_sir.outputs = ret_vals
140 | log(2, "start subgraph compile and execution.\n")
141 | log(2, self.TOS, "\n")
142 | # step2: call compile_sir and get python function, third cache is triggered here.
143 | static_func = CompileSIRCache()(self, cur_sir.name, **kwargs)
144 | # step3: GC and reset TOS
145 | # self.reset_TOS()
146 |
147 | return static_func, cur_sir
148 |
--------------------------------------------------------------------------------
/docs/instructions/CALL_FUNCTION.md:
--------------------------------------------------------------------------------
1 | ## CallFunction 支持
2 |
3 | ### 两个选择
4 |
5 | 1. 使用自己模拟的 Frame 构建, Frame 传参等操作。
6 |
7 | 2. 使用 Eval Frame 的机制,但是构造不同的 eval frame 函数。
8 |
9 | 优缺点对比:
10 |
11 | | 两个CallFunction方案对比 | 优点 | 缺点 |
12 | | ------------- | ------------- | ------- |
13 | | 模拟Frame | 完全自己控制 | 1. 得自己处理Globals、Locals等逻辑<br>2. 完备性无法保证,魔法函数难以追踪 |
14 | | 复用EvalFrame | 不用处理调用前的准备细节 | 完备性、依赖eval frame的3个模式的完备性 N |
15 |
16 | 两个都依赖的:
17 |
18 | 1. 他们都需要区分 `组网模式` 和 `内部callback模式` 的区别。
19 |
20 |
21 | ### 尝试:直接复用Python的函数调用: Frame准备逻辑
22 |
23 | 可以直接复用 Frame 的准备逻辑,这个逻辑如果复用,那么我们直接在 OpExecutor 过程中 运行 call function 即可。
24 |
25 | 1. call function 的话,什么时候 func 进行融合,什么时候不进行融合 ?
26 |
27 | ### 总结:三种 Eval Frame 方案
28 |
29 | | 方案类别 | 模拟执行 | 真实执行 |
30 | | ------------- | ------------- | ------- |
31 | | 组网模式 | Y | Y |
32 | | 内部CallFunction | Y | N |
33 | | 动态图fallback模式 | N | Y |
34 |
35 |
36 | ### 问题:复用EvalFrame方案是否会有调用问题 ?
37 |
38 | 希望在模拟时直接调用 call function 函数,闭包、globals 等是否会有问题。
39 |
40 | ```python
41 | def func():
42 | bar()
43 | ```
44 |
45 | 这里的 bar 的 globals 如果和 func 不一样,我们在 eval frame 里面对 bar 进行了调用,是否一定和原来代码执行相同?
46 |
47 |
48 | ### 似乎 eval frame 方案是没问题的
49 |
50 | 任务拆分 -- 代码实现一波:
51 |
52 | 1. eval frame 模式 2 的准备和开发 。
53 |
54 | 2. call function 时进行 return value 的区分 。
55 |
56 | 3. opcode executor 的分类,子类方式来重载某些字节码的实现。比如 return value,组网行为。
57 |
58 | 4. eval frame 如何判断当前是初始 frame 还是 sub frame,以及如何给上层的 frame 返回值。
59 |
60 | 5. SIR 组网阶段是分开组网还是同用同一个?
61 |
62 | ### SIR 如何实现 fallback 时回退?对FunctionGraph Copy 一份?
63 |
64 |
65 |
66 | ## 竞品调研
67 |
68 | ### Dynamo 选择了『Frame 模拟操作』
69 |
70 | 1. Dynamo 会将每个 function type 包装为 SkipFileVariable 或者是 UserFunctionVariable
71 |
72 | 2. 如果是一个用户的函数,那么会进行 inline 的链接。即使用了*模拟Frame的调用方式*。
73 | ``` python
74 | def inline_user_function_return(self, fn, args, kwargs):
75 | """
76 | A call to some user defined function by inlining it.
77 | """
78 | state = self.copy_graphstate()
79 | try:
80 | result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
81 | self.output.guards.update(fn.guards)
82 | return result
83 | except Exception:
84 | self.restore_graphstate(state)
85 | raise
86 | ```
87 | 可以看到上述代码也是进行了Graph的恢复行为,这个其实和我想的差不多。
88 |
89 | 3. 如果中间出现了任何的异常,都会触发当前 Graph 的独立成组。将每个 Function 当做是一个 sub graph 来进行调用。
90 |
91 |
92 | ### 需要考虑的:
93 |
94 | 1. code 可以是一个 generator 类型
95 |
96 | 2. 传参等操作可以很好的处理吗?globals准备、locals 的准备、闭包的准备。
97 |
98 | ### Call Function 的参数传递 :Frame 的准备
99 |
100 | 好像也不复杂。
101 |
102 |
103 | ## 版本适配工作
104 |
105 | CPython 中对 code 的实现有如下的代码:
106 |
107 |
108 | ## PaddleSOT 实现
109 |
110 | 目前 PaddleSOT 也是使用 inline call 的方式来实现的。因为 eval frame 的实现参数传递比较复杂。而模拟只需要处理前面和后面的参数传递部分即可。
111 |
112 | 在 paddle 的实现中,一共有多个组件来实现 inline call 的子图融合,主要包含下面几个部分:
113 |
114 | - InlineExecutor: 包含字节码差异行为的实现
115 |
116 | - FunctionGlobalTracker 的引入
117 |
118 | - Function graph 回滚机制
119 |
120 | - Bytecode Force Fallback 机制
121 |
122 | ### InlineCallExecutor
123 |
124 | 在原来的 OpcodeExecutor的基础上引入了inline call的executor,专门用来对子函数的SIR进行融合。主要的算法流图如下:
125 |
126 | 如果 father 调用了 child 函数,那么会有如下的流程图:
127 |
128 | ```mermaid
129 | graph TD
130 | A[CALL_FUNCTION调用] --> B[搜集参数,fn variable传递给新的InlineExecutor]
131 | B --> C[调用 InlineExecutor::_prepare_env, 准备globals, locals等环境]
132 | C --> D(存储当前的SIR状态)
133 | D --> E[调用 InlineExecutor::run 函数, 进行SIR的融合]
134 | E --> F{run是否可以成功run}
135 | F --> | Y | G[将run返回的Variable放到father的stack作为TOS]
136 | H --> I[触发Non Jump Fallback]
137 | F --> | N | H[回退SIR,FunctionGraph等的状态]
138 | G --> J[继续father函数的下一条字节码]
139 | I --> K[结束]
140 |
141 | ```
142 |
143 |
144 |
145 | ##### 部分指令的区别
146 |
147 | InlineCallExecutor 与 OpcodeExecutor 存在某些指令上的区别。目前是通过类继承重写来实现的差异表示,当前已知的差异有如下几点。
148 |
149 | | ByteCodeName | OpcodeExecutor | InlineCallExecutor |
150 | | ------------- | ------------- | ------- |
151 | | RETURN_VALUE | Y | N |
152 | | JUMP fallback | Y | N |
153 |
154 |
155 | ### FunctionGlobalTracker (Motivation)
156 |
157 | 这个 Case 可以比较好地让大家理解 Tracker 的主要机制,以及什么时候需要引入新的 Tracker 等。
158 |
159 |
160 |
161 | ### Function graph 回滚机制 (Motivation)
162 |
163 | ### Bytecode Force Fallback (Motivation)
164 |
165 |
166 | ## Tasks and Todos
167 |
168 | - [x] InlineExecutor
169 |
170 | - [x] RETURN_VALUE 区别:组网返回 vs 值返回
171 |
172 | - [x] 出现子图fallback情况: fallback vs SIR回滚
173 |
174 |
175 | - [x] prepare_virtual_env 函数的实现
176 |
177 | - [x] globals, locals, const
178 |
179 | - [x] prepare_closure 闭包函数的实现
180 |
181 | - [ ] 闭包支持
182 |
183 |
184 | - [x] 额外的 Tracker
185 |
186 | - [x] FunctionGlobalTracker
187 |
188 |
189 | - [ ] SIR save_memo 机制和 restore_memo 机制
190 |
--------------------------------------------------------------------------------
/tests/test_instruction_translator_cache.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import inspect
4 | import random
5 | import types
6 | import unittest
7 | from unittest.mock import patch
8 |
9 | from test_case_base import (
10 | TestCaseBase,
11 | test_instruction_translator_cache_context,
12 | )
13 |
14 | from sot.opcode_translator.custom_code import CustomCode
15 | from sot.opcode_translator.executor.executor_cache import OpcodeExecutorCache
16 |
17 |
def fake_frames() -> (
    tuple[
        types.FrameType,
        types.FrameType,
        types.FrameType,
        types.FrameType,
        types.FrameType,
    ]
):
    """Return five frame objects, each captured inside its own nested function.

    Every frame is produced by a separately defined function, so each one
    carries a distinct code object (``f_code``) named after that function.
    """

    # Keep five separate definitions: a shared factory would give every frame
    # the same code object.
    def fake_inner_fn_1():
        current = inspect.currentframe()
        assert current is not None
        return current

    def fake_inner_fn_2():
        current = inspect.currentframe()
        assert current is not None
        return current

    def fake_inner_fn_3():
        current = inspect.currentframe()
        assert current is not None
        return current

    def fake_inner_fn_4():
        current = inspect.currentframe()
        assert current is not None
        return current

    def fake_inner_fn_5():
        current = inspect.currentframe()
        assert current is not None
        return current

    # Call the functions in order, directly from this frame.
    first = fake_inner_fn_1()
    second = fake_inner_fn_2()
    third = fake_inner_fn_3()
    fourth = fake_inner_fn_4()
    fifth = fake_inner_fn_5()
    return first, second, third, fourth, fifth
59 |
60 |
# Module-level frame fixtures; each name is bound to one frame from fake_frames().
FRAME_1, FRAME_2, FRAME_3, FRAME_4, FRAME_5 = fake_frames()
68 |
69 |
def mock_start_translate(frame: types.FrameType, **kwargs):
    """Stand-in for ``start_translate`` that answers from a fixed table.

    Args:
        frame: The frame to look up; only ``FRAME_1``, ``FRAME_3`` and
            ``FRAME_5`` have entries.
        **kwargs: Accepted and ignored.

    Returns:
        A ``(CustomCode, guard)`` pair, where ``guard`` is a one-argument
        callable returning a constant bool.

    Raises:
        KeyError: If ``frame`` has no entry in the table.
    """

    def constant_guard(answer):
        # Build a fresh guard per entry that ignores its frame argument.
        return lambda frame: answer

    translations = {}
    translations[FRAME_1] = (
        CustomCode(FRAME_2.f_code, False),
        constant_guard(True),
    )
    # Always re-compile: this entry's guard never passes.
    translations[FRAME_3] = (
        CustomCode(FRAME_4.f_code, False),
        constant_guard(False),
    )
    translations[FRAME_5] = (CustomCode(None, False), constant_guard(True))
    return translations[frame]
80 |
81 |
class TestOpcodeExecutorCache(unittest.TestCase):
    """Check OpcodeExecutorCache results and translate counts with a mocked translator."""

    def reset(self):
        """Zero the module-level ``translate_count`` and clear the executor cache."""
        global translate_count
        translate_count = 0
        OpcodeExecutorCache().clear()

    def _translate_and_check(self, ctx, frame, expected_code, expected_count):
        """Send ``frame`` through the cache, then verify its code and the translate count."""
        translated = OpcodeExecutorCache()(frame)
        assert translated is not None
        self.assertEqual(translated.code, expected_code)
        self.assertEqual(ctx.translate_count, expected_count)

    @patch(
        "sot.opcode_translator.executor.executor_cache.start_translate",
        mock_start_translate,
    )
    def test_cache_hit(self):
        """Translating FRAME_1 twice keeps the translate count at 1."""
        with test_instruction_translator_cache_context() as ctx:
            self._translate_and_check(ctx, FRAME_1, FRAME_2.f_code, 1)
            # cache hit
            self._translate_and_check(ctx, FRAME_1, FRAME_2.f_code, 1)

    @patch(
        "sot.opcode_translator.executor.executor_cache.start_translate",
        mock_start_translate,
    )
    def test_cache_miss_due_to_unknown_code(self):
        """Translating FRAME_1 and then FRAME_3 raises the translate count to 2."""
        with test_instruction_translator_cache_context() as ctx:
            self._translate_and_check(ctx, FRAME_1, FRAME_2.f_code, 1)
            # cache miss
            self._translate_and_check(ctx, FRAME_3, FRAME_4.f_code, 2)

    @patch(
        "sot.opcode_translator.executor.executor_cache.start_translate",
        mock_start_translate,
    )
    def test_cache_miss_due_to_check_failed(self):
        """Translating FRAME_3 twice raises the count to 2; its mocked guard returns False."""
        with test_instruction_translator_cache_context() as ctx:
            self._translate_and_check(ctx, FRAME_3, FRAME_4.f_code, 1)
            # cache miss
            self._translate_and_check(ctx, FRAME_3, FRAME_4.f_code, 2)
135 |
136 |
def foo(x):
    """Return ``x + 1``."""
    incremented = x + 1
    return incremented
139 |
140 |
class TestCacheExceedLimit(TestCaseBase):
    """Call ``foo`` repeatedly with fresh random inputs."""

    def test_cache_exceed_limit(self):
        """Check ``foo`` through ``assert_results`` for 30 random floats."""
        # NOTE(review): 30 is presumably larger than the translation cache
        # limit, so later iterations run past it — confirm against the cache
        # configuration.
        for _ in range(30):
            # Use a name that does not shadow the builtin ``input``.
            value = random.random()
            self.assert_results(foo, value)
146 |
147 |
# Run the tests through unittest when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()
150 |
--------------------------------------------------------------------------------