├── .clang-format ├── .github └── workflows │ ├── format.yml │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .mypy.ini ├── .style.yapf ├── .yapfignore ├── LICENSE ├── README.md ├── code-check.sh ├── frontend ├── __init__.py ├── bytecode_analysis.py ├── bytecode_writter.py ├── c_api.pyi ├── cache.py ├── code.py ├── compile.py ├── config.py ├── control_flow.py ├── csrc │ ├── csrc.h │ ├── frame_evaluation.cpp │ ├── opcode.cpp │ └── parse_types.cpp ├── dynamic.py ├── fx_graph.py ├── guard_tracker.py ├── guards.py ├── instruction.py ├── no_preload.py ├── object_table.py ├── pycode_generator.py ├── pycode_writer.py ├── store_pos.py ├── tracer.py ├── utils.py └── variables │ ├── __init__.py │ ├── any_.py │ ├── base.py │ ├── builtin_types.py │ ├── const.py │ ├── dict_.py │ ├── iterator.py │ ├── list_.py │ ├── scalar.py │ ├── set_.py │ ├── tensor.py │ ├── torch_module.py │ └── tuple_.py ├── pytest.ini ├── requirements.txt ├── scripts ├── compile_longobj.sh ├── longobject.v3.9.12.patch └── pytest_with_preload.sh ├── setup.py └── test ├── common ├── checker.py └── plugin_disable_preload.py ├── conftest.py ├── example.py ├── test_builtins.py ├── test_call_function_ex.py ├── test_call_udf.py ├── test_cuda.py ├── test_dict.py ├── test_dyn_shape.py ├── test_end_of_control_flow.py ├── test_extend_arg.py ├── test_inplace.py ├── test_int_cache.py ├── test_list.py ├── test_model_a_tridentnet.py ├── test_model_bart.py ├── test_model_bert.py ├── test_model_blockdrop.py ├── test_model_deberta.py ├── test_model_densenet.py ├── test_model_lstm.py ├── test_model_monodepth.py ├── test_model_multi_align.py ├── test_model_quantized.py ├── test_model_resnet.py ├── test_model_seq2seq.py ├── test_nnmodule.py ├── test_numpy.py ├── test_random_key.py ├── test_scalar.py ├── test_set.py ├── test_stack_effect.py ├── test_static_control_flow.py ├── test_store.py ├── test_tensor.py ├── test_tuple.py └── test_ud_class.py /.clang-format: 
-------------------------------------------------------------------------------- 1 | Language: Cpp 2 | BasedOnStyle: LLVM 3 | IndentWidth: 4 4 | -------------------------------------------------------------------------------- /.github/workflows/format.yml: -------------------------------------------------------------------------------- 1 | name: Formatting Check 2 | on: [push] 3 | jobs: 4 | python-formatting-check: 5 | name: Python Formatting Check 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v2 9 | - name: run YAPF to test if python code is correctly formatted 10 | uses: AlexanderMelde/yapf-action@master 11 | with: 12 | args: --verbose 13 | cpp-formatting-check: 14 | name: C++ Formatting Check 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: run clang-format to test if C++ code is correctly formatted 19 | uses: RafikFarhad/clang-format-github-action@v3 20 | with: 21 | sources: frontend/**/*.cpp frontend/**/*.h 22 | style: file 23 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: [push] 3 | jobs: 4 | mypy: 5 | name: Python Type Check 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Setup Python 9 | uses: actions/setup-python@v1 10 | with: 11 | python-version: 3.9 12 | architecture: x64 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | - name: Install mypy 16 | run: pip install mypy 17 | pip install --upgrade -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 18 | - name: Run mypy 19 | run: mypy -p frontend 20 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Python unit tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: self-hosted 9 | strategy: 
10 | matrix: 11 | python-version: ["3.9"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | with: 16 | ssh-key: ${{ secrets.SSHKEY }} 17 | - name: install the package 18 | run: | 19 | source ~/venv/frontend-env/bin/activate 20 | pip install --force-reinstall . 21 | cd scripts && BUILD_DIR=~/frontend ./compile_longobj.sh && cd .. 22 | - name: install dependency 23 | run: | 24 | source /opt/spack/share/spack/setup-env.sh 25 | spack load cuda@11.8.0 /jb4mlxg 26 | spack load python@3.9.12%gcc@=11.3.0 27 | source ~/venv/frontend-env/bin/activate 28 | pip install --upgrade -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 29 | if ! pip show transformers &> /dev/null; then 30 | pip install transformers==v4.29.1 31 | fi 32 | - name: Test with pytest 33 | run: | 34 | source /opt/spack/share/spack/setup-env.sh 35 | spack load cuda@11.8.0 /jb4mlxg 36 | spack load python@3.9.12%gcc@=11.3.0 37 | source ~/venv/frontend-env/bin/activate 38 | srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test 39 | FORCE_RUN_SKIPPED_TEST=1 srun -p ja --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .vscode 3 | build 4 | __pycache__ 5 | *.so 6 | test/simple.py 7 | tmp -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | exclude = (?x)( 3 | ^build/ 4 | ) 5 | strict = True 6 | [mypy-torch.*] 7 | follow_imports = skip 8 | [mypy-sympy.*] 9 | follow_imports = skip -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 
2 | based_on_style = google -------------------------------------------------------------------------------- /.yapfignore: -------------------------------------------------------------------------------- 1 | build -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MagPy 2 | MagPy is a JIT compiler for PyTorch programs. It can extract the operator graph from PyTorch programs and optimize the graph with a wide range of deep learning graph compilers. 3 | 4 | # Installation 5 | MagPy now supports Python 3.9. The support of other Python versions is working in progress. 6 | 7 | 1. Install CUDA. CUDA 11.8 is recommended. 8 | 2. Install dependencies: 9 | ```bash 10 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 11 | ``` 12 | 3. Install MagPy: 13 | ```bash 14 | pip install -e . 15 | ``` 16 | 4. Compile a shared library to disable Python integer cache by LD_PRELOAD. This script will generates a ``ldlong.v3.9.12.so'' file in build/ directory. You need to set the LD_PRELOAD environment variable to this file when running the PyTorch program. 17 | ```bash 18 | cd scripts 19 | ./compile_longobj.sh 20 | ``` 21 | 22 | # Example Usage 23 | 24 | The following script compiles and runs a simple PyTorch program with MagPy. 25 | 26 | ```python 27 | LD_PRELOAD=build/ldlong.v3.9.12.so python test/example.py 28 | ``` 29 | 30 | # Citation 31 | If you find MagPy useful in your research, please consider citing the following paper: 32 | 33 | > MagPy: Effective Operator Graph Instantiation for Deep Learning by Execution State Monitoring; Chen Zhang, Rongchao Dong, Haojie Wang, Runxin Zhong, Jike Chen, and Jidong Zhai, Tsinghua University; will be appeared in USENIX ATC'24. 
34 | 35 | -------------------------------------------------------------------------------- /code-check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | # python lint check 4 | mypy -p frontend 5 | # python code style 6 | yapf -r . -d 7 | # C++ code style 8 | clang-format --style=file -n --Werror frontend/csrc/* 9 | 10 | echo "check passed" -------------------------------------------------------------------------------- /frontend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheda12345/MagPy/9819fc3242baa6509d7056bff70ba1d2126219dd/frontend/__init__.py -------------------------------------------------------------------------------- /frontend/c_api.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Iterable 2 | from types import FrameType, CellType 3 | 4 | if TYPE_CHECKING: 5 | from .code import ProcessedCode 6 | 7 | 8 | def set_eval_frame( 9 | new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]] 10 | ) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]: 11 | pass 12 | 13 | 14 | def set_fallback( 15 | new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]] 16 | ) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]: 17 | pass 18 | 19 | 20 | def set_skip_files(skip_file: set[str], end_file: set[str]) -> None: 21 | pass 22 | 23 | 24 | def get_value_stack_from_top(frame: FrameType, index: int) -> Any: 25 | pass 26 | 27 | 28 | def set_value_stack_from_top(frame: FrameType, index: int, value: Any) -> None: 29 | pass 30 | 31 | 32 | def get_value_stack_size(frame: FrameType) -> int: 33 | pass 34 | 35 | 36 | def guard_match(frame_id: int, callsite_id: int, 37 | locals: Dict[str, Any]) -> Optional[Callable[..., Any]]: 38 | pass 39 | 40 | def get_miss_locals(frame_id: int) -> 
list[str]: 41 | pass 42 | 43 | def finalize() -> None: 44 | pass 45 | 46 | 47 | def enter_nested_tracer() -> None: 48 | pass 49 | 50 | 51 | def exit_nested_tracer() -> None: 52 | pass 53 | 54 | 55 | def mark_need_postprocess() -> None: 56 | pass 57 | 58 | 59 | def add_to_cache(frame_id: int, callsite_id: int, id_in_callsite: int, 60 | guard_fn: Callable[..., Any], graph_fn: Callable[..., 61 | Any]) -> None: 62 | pass 63 | 64 | 65 | def c_reset() -> None: 66 | pass 67 | 68 | 69 | def stack_effect(op: int, oparg: int, 70 | jump: Optional[bool]) -> tuple[int, int, int, bool, bool]: 71 | pass 72 | 73 | 74 | def set_null_object(obj: Any) -> None: 75 | pass 76 | 77 | 78 | def set_miss_threshold(obj: Any) -> None: 79 | pass 80 | 81 | 82 | def get_next_frame_id() -> int: 83 | pass 84 | 85 | 86 | def get_code_map(frame: FrameType) -> 'ProcessedCode': 87 | pass 88 | 89 | 90 | def is_bound_method(obj: Any, name: str) -> bool: 91 | pass 92 | 93 | 94 | def parse_rangeiterobject(obj: Any) -> Tuple[int, int, int, int]: 95 | pass 96 | 97 | 98 | def parse_mapproxyobject(obj: Any) -> Any: 99 | pass 100 | 101 | 102 | def make_rangeiterobject(start: int, stop: int, step: int) -> Any: 103 | pass 104 | 105 | 106 | def get_from_freevars(frame: FrameType, idx: int) -> Any: 107 | pass 108 | 109 | 110 | def parse_mapobject(obj: Any) -> Tuple[Iterable[Any], Callable[..., Any]]: 111 | pass 112 | 113 | def parse_cell(cell: CellType) -> Any: 114 | pass 115 | 116 | def set_cell(cell: CellType, value: Any) -> None: 117 | pass 118 | 119 | def set_local(frame: FrameType, idx: int, value: Any) -> None: 120 | pass 121 | 122 | 123 | def parse_type_obj(obj: Any) -> str: 124 | pass -------------------------------------------------------------------------------- /frontend/cache.py: -------------------------------------------------------------------------------- 1 | from types import CodeType 2 | from typing import Callable, Any, Optional, Tuple 3 | from dataclasses import dataclass 4 | 5 | from 
frontend.code import ProcessedCode 6 | from .instruction import Instruction 7 | from .c_api import add_to_cache 8 | from .store_pos import StorePos 9 | 10 | 11 | @dataclass 12 | class CachedGraph: 13 | guard_fn: Callable[..., Any] 14 | graph_fn: Callable[..., Any] 15 | start_pc: int 16 | end_pc: int 17 | start_stack_size: int 18 | end_stack_size: int 19 | return_values: list[StorePos] 20 | key: int 21 | object_refs: list[Any] 22 | 23 | 24 | TOTAL_SIZE = 0 25 | 26 | 27 | class FrameCache: 28 | frame_id: int 29 | cached_graphs: dict[int, 30 | list[CachedGraph]] # start_pc -> list of cached graph 31 | callsite_id: dict[int, int] # start_pc -> callsite_id 32 | pre_cache_size: int 33 | updated: bool 34 | # 0 for root, 1 for callee 35 | code: list[Optional[Tuple[CodeType, ProcessedCode]]] 36 | 37 | def __init__(self, frame_id: int) -> None: 38 | self.frame_id = frame_id 39 | self.cached_graphs = {0: []} 40 | self.callsite_id = {0: 0} 41 | self.new_code = None 42 | self.code_map = None 43 | self.code = [None, None] 44 | self.updated = True # rewrite bytecode for the first time 45 | 46 | def add(self, traced_code: CachedGraph) -> None: 47 | start_pc = traced_code.start_pc 48 | assert traced_code.end_pc >= 0 49 | if start_pc not in self.cached_graphs: 50 | self.cached_graphs[start_pc] = [] 51 | self.callsite_id[start_pc] = len(self.cached_graphs) - 1 52 | 53 | self.cached_graphs[start_pc].append(traced_code) 54 | 55 | add_to_cache(self.frame_id, self.callsite_id[start_pc], 56 | len(self.cached_graphs[start_pc]) - 1, 57 | traced_code.guard_fn, traced_code.graph_fn) 58 | global TOTAL_SIZE 59 | TOTAL_SIZE += 1 60 | self.updated = True 61 | 62 | def set_new_code(self, new_code: CodeType, code_map: ProcessedCode, 63 | is_callee: bool) -> None: 64 | self.code[is_callee] = (new_code, code_map) 65 | 66 | def get_new_code(self, is_callee: bool) -> Tuple[CodeType, ProcessedCode]: 67 | code = self.code[is_callee] 68 | assert code is not None 69 | return code 70 | 71 | def 
is_valid(self, is_callee: bool) -> bool: 72 | return not self.updated and self.code[is_callee] is not None 73 | 74 | def update_code(self, f_code: CodeType, frame_id: int, 75 | is_callee: bool) -> None: 76 | if not self.is_valid(is_callee): 77 | from .bytecode_writter import rewrite_bytecode 78 | for i in (False, True): 79 | if i == is_callee or self.code[i] is not None: 80 | # print("new_code for is_callee =", i) 81 | new_code, code_map = rewrite_bytecode(f_code, frame_id, i) 82 | self.set_new_code(new_code, code_map, i) 83 | self.updated = False 84 | 85 | 86 | frame_caches: dict[int, FrameCache] = {} 87 | 88 | 89 | def get_frame_cache(frame_id: int) -> FrameCache: 90 | return frame_caches[frame_id] 91 | 92 | 93 | def enable_cache(frame_id: int) -> None: 94 | if frame_id not in frame_caches: 95 | frame_caches[frame_id] = FrameCache(frame_id) 96 | 97 | 98 | def check_cache_updated(frame_id: int) -> bool: 99 | assert frame_id in frame_caches 100 | return frame_caches[frame_id].updated 101 | 102 | 103 | def reset() -> None: 104 | global TOTAL_SIZE 105 | TOTAL_SIZE = 0 106 | frame_caches.clear() 107 | -------------------------------------------------------------------------------- /frontend/code.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, cast 2 | import dis 3 | from .instruction import Instruction 4 | 5 | dynamic_next_pc_opnames = { 6 | "POP_JUMP_IF_FALSE", 7 | "POP_JUMP_IF_TRUE", 8 | "JUMP_IF_FALSE_OR_POP", 9 | "JUMP_IF_TRUE_OR_POP", 10 | "JUMP_IF_NOT_EXC_MATCH", 11 | "FOR_ITER", 12 | } 13 | 14 | dynamic_next_pc_opcodes = { 15 | dis.opmap[opname] for opname in dynamic_next_pc_opnames 16 | } 17 | 18 | 19 | class ProcessedCode: 20 | ''' 21 | EXTENDED_ARG: last_pc and next_pc will both return the real instruction 22 | after EXTENDED_ARG 23 | Example: 24 | 0 LOAD_FAST 25 | 2 EXTENDED_ARG 26 | 4 EXTENDED_ARG 27 | 6 LOAD_CONST 28 | 8 RETURN_VALUE 29 | last_i = 0: 30 | last_pc = original_insts[0] 
(LOAD_FAST) 31 | next_pc = original_insts[3] (LOAD_CONST) 32 | last_i = 2 33 | last_pc = original_insts[3] (LOAD_CONST) 34 | next_pc = original_insts[4] (RETURN_VALUE) 35 | last_i cannot be 4 or 6 36 | JUMP_IF_{TRUE,FALSE} like opcodes needs the current TOS to decide the next pc 37 | FOR_ITER: a NOP is inserted after FOR_ITER, the next pc of FOR_ITER is -1 38 | RETURN_VALUE: the next pc of RETURN_VALUE is len(original_insts) 39 | naming: 40 | last_i: the index before diviing by sizeof(Instruction) 41 | pc: the index after dividing by sizeof(Instruction) 42 | ''' 43 | 44 | pc_guarded_to_origin: dict[int, int] # last pc guard -> origin 45 | # heheda: not sure whether we need this field 46 | original_insts: list[Instruction] 47 | guard_insts: list[Instruction] 48 | original_pc: dict[Instruction, 49 | int] # original instruction -> pc in original_insts 50 | guarded_pc: dict[Instruction, 51 | int] # guarded instruction -> pc in guard_insts 52 | next_original_pc: dict[ 53 | int, 54 | int] # pc guarded -> original, only for replaced code in the orignal section of the guarded code 55 | 56 | def __init__( 57 | self, original_insts: list[Instruction], 58 | guard_insts: list[Instruction], 59 | inside_trace_opcodes: list[Instruction], 60 | next_original_pc: list[tuple[Instruction, Instruction]]) -> None: 61 | self.original_insts = original_insts[:] 62 | self.guard_insts = guard_insts[:] 63 | 64 | self.original_pc = {} 65 | pc = -1 66 | for inst in original_insts: 67 | assert inst.offset is not None 68 | for inst in guard_insts: 69 | assert inst.offset is not None 70 | for inst in reversed(original_insts): 71 | if inst.opname != "EXTENDED_ARG": 72 | pc = cast(int, inst.offset) // 2 # mypy: no-strict-optional 73 | self.original_pc[inst] = pc 74 | 75 | self.guarded_pc = {} 76 | pc = -1 77 | for inst in reversed(guard_insts): 78 | if inst.opname != "EXTENDED_ARG": 79 | pc = cast(int, inst.offset) // 2 80 | self.guarded_pc[inst] = pc 81 | 82 | self.pc_guarded_to_origin = {} 83 | 
for inst in guard_insts: 84 | if inst.original_inst is not None: 85 | self.pc_guarded_to_origin[cast(int, inst.offset) // 86 | 2] = self.original_pc[ 87 | inst.original_inst] 88 | for inst in inside_trace_opcodes: 89 | self.pc_guarded_to_origin[cast(int, inst.offset) // 2] = -1 90 | 91 | self.next_original_pc = {} 92 | for o, g in next_original_pc: 93 | self.next_original_pc[self.guarded_pc[g]] = self.original_pc[o] 94 | 95 | def get_pc(self, inst_list: list[Instruction], pc: int) -> int: 96 | while pc < len(inst_list) and inst_list[pc].opname == "EXTENDED_ARG": 97 | pc += 1 98 | return pc 99 | 100 | def get_orig_pc(self, lasti: int) -> int: 101 | ''' 102 | returns -1 if the lasti is a helper opcode inside tracing region 103 | returns -2 if the lasti is outside tracing region 104 | ''' 105 | pc = lasti // 2 106 | while pc < len(self.guard_insts 107 | ) and self.guard_insts[pc].opname == "EXTENDED_ARG": 108 | pc += 1 109 | if pc not in self.pc_guarded_to_origin: 110 | return -2 111 | 112 | return self.pc_guarded_to_origin[pc] 113 | 114 | def get_orig_inst(self, lasti: int) -> tuple[int, Optional[Instruction]]: 115 | pc = lasti // 2 116 | while pc < len(self.guard_insts 117 | ) and self.guard_insts[pc].opname == "EXTENDED_ARG": 118 | pc += 1 119 | assert pc in self.pc_guarded_to_origin, ( 120 | "pc %d not in pc_guarded_to_origin" % pc) 121 | origin_pc = self.pc_guarded_to_origin[pc] 122 | if origin_pc == -1: # is a helper opcode inside tracing region 123 | inst = None 124 | else: 125 | inst = self.original_insts[self.pc_guarded_to_origin[pc]] 126 | return origin_pc, inst 127 | 128 | def get_next_orig_pc(self, lasti: int) -> int: 129 | pc = lasti // 2 130 | while pc < len(self.guard_insts 131 | ) and self.guard_insts[pc].opname == "EXTENDED_ARG": 132 | pc += 1 133 | if pc not in self.next_original_pc: 134 | raise ValueError("pc %d not in next_original_pc" % pc) 135 | 136 | return self.next_original_pc[pc] 137 | 138 | def get_inst(self, lasti: int) -> Instruction: 139 
| pc = lasti // 2 140 | while pc < len(self.guard_insts 141 | ) and self.guard_insts[pc].opname == "EXTENDED_ARG": 142 | pc += 1 143 | return self.guard_insts[pc] 144 | 145 | def get_pc_by_inst(self, inst: Instruction) -> int: 146 | return self.guarded_pc[inst] 147 | 148 | def is_match(self, original_pc: int, guard_pc: int) -> bool: 149 | return self.pc_guarded_to_origin[guard_pc] == original_pc 150 | 151 | def get_dependence_of_stack_var(self, original_inst: Instruction, 152 | stack_depth: int) -> list[Instruction]: 153 | raise NotImplementedError 154 | 155 | def get_dependence_of_local_var(self, original_inst: Instruction, 156 | local_name: str) -> list[Instruction]: 157 | raise NotImplementedError 158 | 159 | 160 | def generate_code_map( 161 | original_insts: list[Instruction], generated_insts: list[Instruction], 162 | inside_trace_opcodes: list[Instruction], 163 | next_original_pc: list[tuple[Instruction, 164 | Instruction]]) -> ProcessedCode: 165 | return ProcessedCode(original_insts, generated_insts, inside_trace_opcodes, 166 | next_original_pc) 167 | -------------------------------------------------------------------------------- /frontend/compile.py: -------------------------------------------------------------------------------- 1 | import dis 2 | import sys 3 | import traceback 4 | from types import FrameType, CodeType 5 | from typing import Any, Tuple, Callable, cast 6 | import logging 7 | import inspect 8 | import torch 9 | from . 
import tracer, utils, guard_tracker 10 | from .config import get_config 11 | from .c_api import set_eval_frame, set_skip_files, guard_match, c_reset, set_null_object, set_miss_threshold 12 | from .tracer import enable_trace, disable_trace, get_trace_func, get_process_frame 13 | from .cache import enable_cache 14 | from .utils import null_object 15 | from .fx_graph import set_frame_root 16 | from .control_flow import if_stmt 17 | 18 | logging.basicConfig( 19 | format='%(levelname)s [%(filename)s:%(lineno)d] %(message)s', 20 | level=logging.INFO) 21 | 22 | LOAD_OPCODES = list( 23 | map(dis.opmap.get, [ 24 | "LOAD_GLOBAL", "LOAD_NAME", "LOAD_FAST", "LOAD_DEREF", 25 | "LOAD_ASSERTION_ERROR", "LOAD_BUILD_CLASS", "LOAD_CONST", "LOAD_ATTR", 26 | "LOAD_CLOSURE", "LOAD_CLASSDEREF", "LOAD_METHOD" 27 | ])) 28 | STORE_OPCODES = list( 29 | map(dis.opmap.get, [ 30 | "STORE_SUBSCR", "STORE_NAME", "STORE_ATTR", "STORE_GLOBAL", 31 | "STORE_FAST", "STORE_DEREF" 32 | ])) 33 | 34 | last_op_code = dis.opmap.get("NOP") 35 | 36 | 37 | def run_graph(graph_id: int, *args: Any, **kwargs: Any) -> None: 38 | print("run_graph", graph_id, args, kwargs) 39 | return None 40 | 41 | 42 | init = False 43 | 44 | 45 | def compile(f: Callable[..., Any]) -> Callable[..., Any]: 46 | global init 47 | if not init: 48 | nn_module = inspect.getmodule(torch.nn.Module) 49 | assert nn_module is not None 50 | set_skip_files( 51 | set({ 52 | cast(str, nn_module.__file__), 53 | tracer.__file__, 54 | utils.__file__, 55 | torch.autograd.function.__file__, 56 | torch._functorch.utils.__file__, 57 | }), set({ 58 | guard_tracker.__file__, 59 | })) 60 | set_null_object(null_object) 61 | set_miss_threshold(get_config("miss_threshold")) 62 | init = True 63 | import builtins 64 | setattr(builtins, "guard_match", guard_match) 65 | setattr(builtins, "enable_trace", enable_trace) 66 | setattr(builtins, "disable_trace", disable_trace) 67 | setattr(builtins, "_frontend_compile_if_stmt", if_stmt) 68 | 69 | def _fn(*args: Any, 
**kwargs: Any) -> Any: 70 | pre, post = get_process_frame(f, False) 71 | prior = set_eval_frame((pre, post)) 72 | try: 73 | fn = f.forward if isinstance(f, torch.nn.Module) else f 74 | return fn(*args, **kwargs) 75 | except Exception as e: 76 | print("exception in _fn:", e, type(e)) 77 | raise e 78 | finally: 79 | set_eval_frame(prior) 80 | 81 | return _fn 82 | 83 | 84 | def reset() -> None: 85 | c_reset() 86 | from . import cache 87 | cache.reset() 88 | from . import guard_tracker 89 | guard_tracker.reset() 90 | from . import utils 91 | utils.reset() 92 | from . import fx_graph 93 | fx_graph.reset() 94 | from . import dynamic 95 | dynamic.reset() 96 | from . import tracer 97 | tracer.reset() 98 | -------------------------------------------------------------------------------- /frontend/config.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, Union 2 | 3 | CONFIG = { 4 | "backend": "inductor", # Union[str, Callable[..., Any]] 5 | "debug": True, 6 | "miss_threshold": 3, 7 | "dynshape": False, 8 | "model_name": "", 9 | "enable_fallback": False, 10 | } 11 | 12 | 13 | def set_config(key: str, value: Any) -> None: 14 | CONFIG[key] = value 15 | 16 | 17 | def get_config(key: str) -> Any: 18 | return CONFIG[key] -------------------------------------------------------------------------------- /frontend/csrc/csrc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | struct _object; 7 | typedef _object PyObject; 8 | 9 | namespace frontend_csrc { 10 | 11 | class NullObjectSingleton { 12 | public: 13 | static NullObjectSingleton &getInstance() { 14 | static NullObjectSingleton instance; 15 | return instance; 16 | } 17 | 18 | PyObject *getNullObject() { return this->null_object; } 19 | void setNullObject(PyObject *obj) { this->null_object = obj; } 20 | 21 | private: 22 | NullObjectSingleton() {} 23 | 24 | 
NullObjectSingleton(const NullObjectSingleton &) = delete; 25 | NullObjectSingleton &operator=(const NullObjectSingleton &) = delete; 26 | PyObject *null_object = nullptr; 27 | }; 28 | 29 | struct Cache { 30 | PyObject *check_fn; 31 | PyObject *graph_fn; 32 | Cache *next; 33 | bool move_to_start; 34 | }; 35 | 36 | struct FrameCache { 37 | std::vector caches; 38 | std::map> miss_locals; 39 | }; 40 | 41 | typedef std::vector ProgramCache; 42 | 43 | // When not understanding an opcode, mark it as {-1, 0, stack_effect} 44 | // if stack_effect > 0 or {-1, -stack_effect, 0, true, true} if stack_effect < 45 | // 0, and update the opcode when needed 46 | struct StackEffect { 47 | StackEffect(int read, int pop, int push, bool local_effect = false, 48 | bool global_effect = false) 49 | : read(read), write_old(pop), write_new(push), 50 | local_effect(local_effect), global_effect(global_effect) {} 51 | int read, write_old, write_new; 52 | bool local_effect, global_effect; 53 | }; 54 | StackEffect stack_effect(int opcode, int oparg, int jump); 55 | PyObject *parse_rangeiterobject(PyObject *self, PyObject *args); 56 | PyObject *make_rangeiterobject(PyObject *self, PyObject *args); 57 | PyObject *parse_mapproxyobject(PyObject *self, PyObject *args); 58 | PyObject *parse_mapobject(PyObject *self, PyObject *args); 59 | PyObject *parse_cell(PyObject *self, PyObject *args); 60 | PyObject *set_cell(PyObject *self, PyObject *args); 61 | PyObject *parse_type_obj(PyObject *self, PyObject *args); 62 | 63 | } // namespace frontend_csrc 64 | -------------------------------------------------------------------------------- /frontend/csrc/opcode.cpp: -------------------------------------------------------------------------------- 1 | #include "csrc.h" 2 | #include 3 | #include 4 | #include 5 | using namespace frontend_csrc; 6 | StackEffect frontend_csrc::stack_effect(int opcode, int oparg, int jump) { 7 | switch (opcode) { 8 | case NOP: 9 | case EXTENDED_ARG: 10 | return {0, 0, 0}; 11 | 12 | 
/* Stack manipulation */ 13 | case POP_TOP: 14 | return {0, 1, 0}; 15 | case ROT_TWO: 16 | return {0, 2, 2}; 17 | case ROT_THREE: 18 | return {0, 3, 3}; 19 | case ROT_FOUR: 20 | return {0, 4, 4}; 21 | case DUP_TOP: 22 | return {1, 0, 1}; 23 | case DUP_TOP_TWO: 24 | return {2, 0, 2}; 25 | 26 | /* Unary operators */ 27 | case UNARY_POSITIVE: 28 | case UNARY_NEGATIVE: 29 | case UNARY_NOT: 30 | case UNARY_INVERT: 31 | return {0, 1, 1}; 32 | 33 | // heheda: not sure 34 | case SET_ADD: 35 | case LIST_APPEND: 36 | return {-1, 1, 0}; 37 | case MAP_ADD: 38 | return {-1, 2, 0}; 39 | 40 | /* Binary operators */ 41 | case BINARY_POWER: 42 | case BINARY_MULTIPLY: 43 | case BINARY_MATRIX_MULTIPLY: 44 | case BINARY_MODULO: 45 | case BINARY_ADD: 46 | case BINARY_SUBTRACT: 47 | case BINARY_SUBSCR: 48 | case BINARY_FLOOR_DIVIDE: 49 | case BINARY_TRUE_DIVIDE: 50 | return {0, 2, 1}; 51 | case INPLACE_FLOOR_DIVIDE: 52 | case INPLACE_TRUE_DIVIDE: 53 | return {0, 2, 1}; 54 | 55 | case INPLACE_ADD: 56 | case INPLACE_SUBTRACT: 57 | case INPLACE_MULTIPLY: 58 | case INPLACE_MATRIX_MULTIPLY: 59 | case INPLACE_MODULO: 60 | return {0, 2, 1}; 61 | case STORE_SUBSCR: 62 | return {0, 3, 0}; 63 | case DELETE_SUBSCR: 64 | return {0, 2, 0}; 65 | 66 | case BINARY_LSHIFT: 67 | case BINARY_RSHIFT: 68 | case BINARY_AND: 69 | case BINARY_XOR: 70 | case BINARY_OR: 71 | return {0, 2, 1}; 72 | case INPLACE_POWER: 73 | return {0, 2, 1}; 74 | case GET_ITER: 75 | return {0, 1, 1}; 76 | 77 | case PRINT_EXPR: 78 | return {0, 1, 0}; 79 | case LOAD_BUILD_CLASS: 80 | return {0, 0, 1}; 81 | case INPLACE_LSHIFT: 82 | case INPLACE_RSHIFT: 83 | case INPLACE_AND: 84 | case INPLACE_XOR: 85 | case INPLACE_OR: 86 | return {0, 2, 1}; 87 | 88 | case SETUP_WITH: 89 | /* 1 in the normal flow. 90 | * Restore the stack position and push 6 values before jumping to 91 | * the handler if an exception be raised. */ 92 | return jump ? 
StackEffect{-1, 0, 6, true, true} 93 | : StackEffect{-1, 0, 1, true, true}; 94 | case RETURN_VALUE: 95 | return {0, 1, 0}; 96 | case IMPORT_STAR: 97 | return {0, 1, 0, true}; 98 | case SETUP_ANNOTATIONS: 99 | return {0, 0, 0, true}; 100 | case YIELD_VALUE: 101 | return {0, 1, 1}; 102 | case YIELD_FROM: 103 | return {0, 1, 0}; 104 | case POP_BLOCK: 105 | return {0, 0, 0, false, true}; 106 | case POP_EXCEPT: 107 | return {0, 3, 0, false, true}; 108 | 109 | case STORE_NAME: 110 | return {0, 1, 0, true}; 111 | case DELETE_NAME: 112 | return {0, 0, 0, true}; 113 | case UNPACK_SEQUENCE: 114 | return {0, 1, oparg}; 115 | case UNPACK_EX: 116 | return {1, 0, (oparg & 0xFF) + (oparg >> 8)}; // heheda: not sure 117 | case FOR_ITER: 118 | /* -1 at end of iterator, 1 if continue iterating. */ 119 | return jump > 0 ? StackEffect{0, 1, 0} : StackEffect{0, 1, 2}; 120 | 121 | case STORE_ATTR: 122 | return {0, 2, 0}; 123 | case DELETE_ATTR: 124 | return {0, 1, 0}; 125 | case STORE_GLOBAL: 126 | return {0, 1, 0, false, true}; 127 | case DELETE_GLOBAL: 128 | return {0, 0, 0, false, false}; 129 | case LOAD_CONST: 130 | return {0, 0, 1}; 131 | case LOAD_NAME: 132 | return {0, 0, 1, false}; 133 | case BUILD_TUPLE: 134 | case BUILD_LIST: 135 | case BUILD_SET: 136 | case BUILD_STRING: 137 | return {0, oparg, 1}; 138 | case BUILD_MAP: 139 | return {0, 2 * oparg, 1}; 140 | case BUILD_CONST_KEY_MAP: 141 | return {0, oparg + 1, 1}; 142 | case LOAD_ATTR: 143 | return {0, 1, 1}; 144 | case COMPARE_OP: 145 | case IS_OP: 146 | case CONTAINS_OP: 147 | return {0, 2, 1}; 148 | case JUMP_IF_NOT_EXC_MATCH: 149 | return {0, 2, 0}; 150 | case IMPORT_NAME: 151 | return {0, 2, 1}; 152 | case IMPORT_FROM: 153 | return {0, 0, 1}; 154 | 155 | /* Jumps */ 156 | case JUMP_FORWARD: 157 | case JUMP_ABSOLUTE: 158 | return {0, 0, 0}; 159 | 160 | case JUMP_IF_TRUE_OR_POP: 161 | case JUMP_IF_FALSE_OR_POP: 162 | return jump ? 
StackEffect{1, 0, 0} : StackEffect{0, 1, 0}; 163 | 164 | case POP_JUMP_IF_FALSE: 165 | case POP_JUMP_IF_TRUE: 166 | return {0, 1, 0}; 167 | 168 | case LOAD_GLOBAL: 169 | return {0, 0, 1, false, true}; 170 | 171 | /* Exception handling */ 172 | case SETUP_FINALLY: 173 | /* 0 in the normal flow. 174 | * Restore the stack position and push 6 values before jumping to 175 | * the handler if an exception be raised. */ 176 | return jump ? StackEffect{0, -6, 0} : StackEffect{0, 0, 0}; 177 | case RERAISE: 178 | return {0, 3, 0}; 179 | 180 | case WITH_EXCEPT_START: 181 | return {0, 7, 4}; 182 | 183 | case LOAD_FAST: 184 | return {0, 0, 1}; 185 | case STORE_FAST: 186 | return {0, 1, 0, true}; 187 | case DELETE_FAST: 188 | return {0, 0, 0, true}; 189 | 190 | case RAISE_VARARGS: 191 | if (oparg == 0) 192 | return {0, 0, 0, false, true}; 193 | if (oparg == 1) 194 | return {0, 1, 0, false, false}; 195 | if (oparg == 2) 196 | return {0, 2, 0, false, true}; 197 | return {PY_INVALID_STACK_EFFECT, PY_INVALID_STACK_EFFECT, 198 | PY_INVALID_STACK_EFFECT}; 199 | 200 | /* Functions and calls */ 201 | case CALL_FUNCTION: 202 | return {0, oparg + 1, 1}; 203 | case CALL_METHOD: 204 | return {0, oparg + 2, 1}; 205 | case CALL_FUNCTION_KW: 206 | return {0, oparg + 2, 1}; 207 | case CALL_FUNCTION_EX: 208 | return {0, 2 + ((oparg & 0x01) != 0), 1}; 209 | case MAKE_FUNCTION: 210 | return {0, 211 | 2 + ((oparg & 0x01) != 0) + ((oparg & 0x02) != 0) + 212 | ((oparg & 0x04) != 0) + ((oparg & 0x08) != 0), 213 | 1}; 214 | case BUILD_SLICE: 215 | if (oparg == 3) 216 | return {0, 3, 1}; 217 | else 218 | return {0, 2, 1}; 219 | 220 | /* Closures */ 221 | case LOAD_CLOSURE: 222 | return {0, 0, 1}; 223 | case LOAD_DEREF: 224 | case LOAD_CLASSDEREF: 225 | return {0, 0, 1}; 226 | case STORE_DEREF: 227 | return {0, 1, 0, false, true}; 228 | case DELETE_DEREF: 229 | return {0, 0, 0, false, true}; 230 | 231 | /* Iterators and generators */ 232 | case GET_AWAITABLE: 233 | return {0, 1, 1}; 234 | case 
SETUP_ASYNC_WITH: 235 | /* 0 in the normal flow. 236 | * Restore the stack position to the position before the result 237 | * of __aenter__ and push 6 values before jumping to the handler 238 | * if an exception be raised. */ 239 | return {0, 0, jump ? -1 + 6 : 0}; 240 | case BEFORE_ASYNC_WITH: 241 | return {0, 1, 2}; 242 | case GET_AITER: 243 | return {0, 1, 1}; 244 | case GET_ANEXT: 245 | return {0, 1, 1}; 246 | case GET_YIELD_FROM_ITER: 247 | return {0, 1, 1}; 248 | case END_ASYNC_FOR: 249 | return {-1, 7, 0}; // seems related with tos? (if tos is 250 | // StopAsyncIteration, pop 7 values) 251 | case FORMAT_VALUE: 252 | /* If there's a fmt_spec on the stack, we go from 2->1, 253 | else 1->1. */ 254 | return {0, (oparg & FVS_MASK) == FVS_HAVE_SPEC ? 2 : 1, 1}; 255 | case LOAD_METHOD: 256 | return {0, 1, 2}; 257 | case LOAD_ASSERTION_ERROR: 258 | return {0, 0, 1}; 259 | case LIST_TO_TUPLE: 260 | return {0, 1, 1}; 261 | case LIST_EXTEND: 262 | case SET_UPDATE: 263 | case DICT_MERGE: 264 | case DICT_UPDATE: 265 | return {0, 2, 1}; 266 | default: 267 | return {PY_INVALID_STACK_EFFECT, PY_INVALID_STACK_EFFECT, 268 | PY_INVALID_STACK_EFFECT}; 269 | } 270 | return {PY_INVALID_STACK_EFFECT, PY_INVALID_STACK_EFFECT, 271 | PY_INVALID_STACK_EFFECT}; /* not reachable */ 272 | } -------------------------------------------------------------------------------- /frontend/csrc/parse_types.cpp: -------------------------------------------------------------------------------- 1 | #include "csrc.h" 2 | #include 3 | #include 4 | #include 5 | 6 | namespace frontend_csrc { 7 | 8 | typedef struct { 9 | PyObject_HEAD long index; 10 | long start; 11 | long step; 12 | long len; 13 | } rangeiterobject; 14 | 15 | PyObject *parse_rangeiterobject(PyObject *self, PyObject *args) { 16 | PyObject *obj; 17 | if (!PyArg_ParseTuple(args, "O", &obj)) { 18 | return NULL; 19 | } 20 | if (Py_TYPE(obj) != &PyRangeIter_Type) { 21 | PyErr_SetString(PyExc_TypeError, "Expected rangeiterobject"); 22 | return 
NULL; 23 | } 24 | rangeiterobject *robj = (rangeiterobject *)obj; 25 | return PyTuple_Pack( 26 | 4, PyLong_FromLong(robj->index), PyLong_FromLong(robj->start), 27 | PyLong_FromLong(robj->step), PyLong_FromLong(robj->len)); 28 | } 29 | 30 | PyObject *make_rangeiterobject(PyObject *self, PyObject *args) { 31 | long index, start, step, len; 32 | if (!PyArg_ParseTuple(args, "llll", &index, &start, &step, &len)) { 33 | return NULL; 34 | } 35 | rangeiterobject *robj = PyObject_New(rangeiterobject, &PyRangeIter_Type); 36 | robj->index = index; 37 | robj->start = start; 38 | robj->step = step; 39 | robj->len = len; 40 | return (PyObject *)robj; 41 | } 42 | 43 | typedef struct { 44 | PyObject_HEAD PyObject *mapping; 45 | } mappingproxyobject; 46 | 47 | PyObject *parse_mapproxyobject(PyObject *self, PyObject *args) { 48 | PyObject *obj; 49 | if (!PyArg_ParseTuple(args, "O", &obj)) { 50 | return NULL; 51 | } 52 | if (Py_TYPE(obj) != &PyDictProxy_Type) { 53 | PyErr_SetString(PyExc_TypeError, "Expected mapproxyobject"); 54 | return NULL; 55 | } 56 | mappingproxyobject *mobj = (mappingproxyobject *)obj; 57 | Py_INCREF(mobj->mapping); 58 | return mobj->mapping; 59 | } 60 | 61 | typedef struct { 62 | PyObject_HEAD PyObject *iters; 63 | PyObject *func; 64 | } mapobject; 65 | 66 | PyObject *parse_mapobject(PyObject *self, PyObject *args) { 67 | PyObject *obj; 68 | if (!PyArg_ParseTuple(args, "O", &obj)) { 69 | return NULL; 70 | } 71 | if (Py_TYPE(obj) != &PyMap_Type) { 72 | PyErr_SetString(PyExc_TypeError, "Expected mapobject"); 73 | return NULL; 74 | } 75 | mapobject *mobj = (mapobject *)obj; 76 | Py_INCREF(mobj->iters); 77 | Py_INCREF(mobj->func); 78 | return PyTuple_Pack(2, mobj->iters, mobj->func); 79 | } 80 | 81 | PyObject *parse_cell(PyObject *self, PyObject *args) { 82 | PyObject *cell; 83 | if (!PyArg_ParseTuple(args, "O", &cell)) { 84 | return NULL; 85 | } 86 | if (Py_TYPE(cell) != &PyCell_Type) { 87 | PyErr_SetString(PyExc_TypeError, "Expected cell"); 88 | return NULL; 89 
import dataclasses
import typing
from typing import Any


class Dynamic:
    """Marker base class: a value or program point treated as dynamic."""
    pass


class ScalarWithUnknownValue(Dynamic):
    """A scalar whose concrete value is unknown at trace time."""
    pass


@dataclasses.dataclass
class DynamicControlFlow(Dynamic):
    """A bytecode instruction (``opcode`` at offset ``pc``) that branches on
    a dynamic value."""
    pc: int
    opcode: str


# id(obj) -> Dynamic marker for that object.
dynamic_vars: dict[int, Dynamic] = {}
# id(obj) -> obj; keeps marked objects alive so their ids stay unique.
dynamic_refs: dict[int, Any] = {}
# (frame_id, pc) -> Dynamic marker for that program point.
dynamic_pcs: dict[tuple[int, int], Dynamic] = {}
dynamic_need_branch_rewrite: dict[int, list[int]] = {}


def mark_dynamic(obj: Any, dyn: Dynamic) -> None:
    """Mark ``obj`` as dynamic, holding a reference so id(obj) stays valid."""
    idx = id(obj)
    dynamic_vars[idx] = dyn
    dynamic_refs[idx] = obj


def contains(obj: Any) -> bool:
    return id(obj) in dynamic_vars


def contains_by_id(idx: int) -> bool:
    return idx in dynamic_vars


def mark_dynamic_pc(frame_id: int, pc: int, dyn: Dynamic) -> None:
    dynamic_pcs[(frame_id, pc)] = dyn


def contains_pc(frame_id: int, pc: int) -> bool:
    return (frame_id, pc) in dynamic_pcs


def pop_dynamic_pc(frame_id: int, pc: int) -> Dynamic:
    """Remove and return the marker for (frame_id, pc); KeyError if absent."""
    return dynamic_pcs.pop((frame_id, pc))


def add_branch_rewrite_pc(frame_id: int, pc: int) -> None:
    dynamic_need_branch_rewrite.setdefault(frame_id, []).append(pc)


def need_branch_rewrite(frame_id: int) -> bool:
    return frame_id in dynamic_need_branch_rewrite


def get_branch_rewrite_pcs(frame_id: int) -> list[int]:
    return dynamic_need_branch_rewrite[frame_id]


def reset() -> None:
    """Clear all recorded dynamic state.

    Bug fix: ``dynamic_need_branch_rewrite`` was previously left untouched,
    leaking per-frame branch-rewrite requests across resets.
    """
    dynamic_vars.clear()
    dynamic_refs.clear()
    dynamic_pcs.clear()
    dynamic_need_branch_rewrite.clear()


@dataclasses.dataclass
class GuardedCode:
    """A compiled graph function paired with the guard that validates it."""
    check_fn: typing.Callable[..., Any]
    graph_fn: typing.Callable[..., Any]
from typing import Any, Optional
import dataclasses
import dis
import os


@dataclasses.dataclass
class Instruction:
    """A mutable version of dis.Instruction"""

    opcode: int
    opname: str
    arg: Any
    argval: Any
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False
    # extra fields to make modification easier:
    target: Optional["Instruction"] = None
    original_inst: Optional["Instruction"] = None
    comment: str = ""
    is_start: bool = False
    is_end: bool = False

    # Identity semantics: two field-equal instructions stay distinct so the
    # same instruction object can be located in a list via ``index``.
    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return id(self) == id(other)

    def __repr__(self) -> str:
        # yellow if is original inst, green if is generated inst
        color = "\033[33m" if self.original_inst else "\033[32m"
        color_gray = "\033[90m"
        comment = f"{color_gray}# {self.comment} \033[0m" if self.comment else ""
        return f"{color}{self.opname}\033[0m({self.arg}, {self.argval}) {comment}"


def convert_instruction(i: dis.Instruction) -> Instruction:
    """Copy an immutable dis.Instruction into the mutable Instruction."""
    return Instruction(
        i.opcode,
        i.opname,
        i.arg,
        i.argval,
        i.offset,
        i.starts_line,
        i.is_jump_target,
    )


def format_insts(insts: list[Instruction],
                 allow_unknown_target: bool = False) -> str:
    """Render instructions one per line, resolving jump targets to indices.

    Raises ValueError for a target not in ``insts`` unless
    ``allow_unknown_target`` is set (then the index is shown as -1).
    """
    lines = []
    for i, inst in enumerate(insts):
        if inst.target is not None:
            try:
                target_idx = insts.index(inst.target)
            except ValueError:
                if allow_unknown_target:
                    target_idx = -1
                else:
                    raise
            lines.append(f"{i}: {inst} -> inst {target_idx}\n")
        else:
            lines.append(f"{i}: {inst}\n")
    return "".join(lines)


class _NotProvided:
    """Sentinel: distinguishes "argval omitted" from an explicit None."""
    pass


# short for create_instruction
def ci(name: str,
       arg: Any = None,
       argval: Any = _NotProvided,
       target: Optional[Instruction] = None,
       comment: str = "") -> Instruction:
    if argval is _NotProvided:
        argval = arg
    return Instruction(opcode=dis.opmap[name],
                       opname=name,
                       arg=arg,
                       argval=argval,
                       target=target,
                       comment=comment)


class NO_LD_PRELOAD_CTX:
    """Temporarily remove LD_PRELOAD from the environment.

    Bug fixes: the previous version (a) restored on a truthiness test, so an
    LD_PRELOAD set to the empty string was silently dropped, and (b) kept the
    saved value in shared class state without resetting it, so re-entering
    after a use without LD_PRELOAD could restore a stale value.
    """
    old_ld_preload: str = ''

    def __enter__(self) -> None:
        self._had_ld_preload = 'LD_PRELOAD' in os.environ
        if self._had_ld_preload:
            self.old_ld_preload = os.environ['LD_PRELOAD']
            del os.environ['LD_PRELOAD']

    def __exit__(self, *args: Any) -> None:
        if self._had_ld_preload:
            os.environ['LD_PRELOAD'] = self.old_ld_preload
            self._had_ld_preload = False
var.add_subvars_to_table(self) 44 | 45 | def update_by_id(self, var: Variable, idx: int) -> None: 46 | if self.contains_by_id(idx): 47 | old_var = self.objs[idx] 48 | else: 49 | old_var = None 50 | var.set_prev(old_var) 51 | self.objs[idx] = var 52 | if old_var is not None: 53 | for attr_name, attr_var in old_var.modified_attrs.items(): 54 | if attr_name not in var.modified_attrs: 55 | var.add_modified_attr(attr_name, attr_var) 56 | 57 | def get_all(self) -> list[Variable]: 58 | return list(self.objs.values()) + self.objs_no_id 59 | 60 | def get_all_with_id(self) -> list[Tuple[int, Variable]]: 61 | return list(self.objs.items()) 62 | 63 | def get(self, 64 | value: Any, 65 | allow_unexist_const: bool = False, 66 | fx_graph: Optional[FxGraph] = None) -> Variable: 67 | if id(value) in self.objs: 68 | return self.objs[id(value)] 69 | elif value is None: 70 | return make_var_from_value(value, False, self.helper_functions, 71 | fx_graph) 72 | elif allow_unexist_const: 73 | if isinstance(value, get_args(CONST_TYPES)) or isinstance( 74 | value, 75 | (list, tuple, set, dict, range, CodeType, type(Ellipsis), 76 | np.ndarray, frozenset, torch.nn.Parameter)): 77 | return make_var_from_value(value, False, self.helper_functions, 78 | fx_graph) 79 | raise RuntimeError( 80 | f"Object({id(value)}) {value} {type(value)} not found in object table {id(self)}" 81 | ) 82 | 83 | def get_or_none(self, value: Any) -> Optional[Variable]: 84 | if id(value) in self.objs: 85 | return self.objs[id(value)] 86 | else: 87 | return None 88 | 89 | def get_or_none_by_id(self, idx: int) -> Optional[Variable]: 90 | if idx in self.objs: 91 | return self.objs[idx] 92 | else: 93 | return None 94 | 95 | def get_or_make_var(self, 96 | value: Any, 97 | need_guard_check: bool, 98 | fx_graph: Optional[FxGraph] = None, 99 | extract_code_at_start: list[StorePos] = []) -> Variable: 100 | if id(value) in self.objs: 101 | return self.objs[id(value)] 102 | else: 103 | return make_var_from_value(value, 
need_guard_check, 104 | self.helper_functions, fx_graph, 105 | extract_code_at_start) 106 | 107 | def get_by_id(self, idx: int) -> Variable: 108 | return self.objs[idx] 109 | 110 | def contains(self, value: Any) -> bool: 111 | return id(value) in self.objs 112 | 113 | def contains_by_id(self, idx: int) -> bool: 114 | return idx in self.objs 115 | -------------------------------------------------------------------------------- /frontend/pycode_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any 2 | from itertools import chain 3 | import torch 4 | import torch.fx 5 | from .pycode_writer import PyCodeWriter, new_name, is_valid_name 6 | from .store_pos import StorePos 7 | from .config import get_config 8 | 9 | 10 | def gen_imports(writer: PyCodeWriter, imports: set[str]) -> None: 11 | for module_import in imports: 12 | writer.wl(module_import) 13 | 14 | 15 | class FnCodegen: 16 | prepare_var_writer: PyCodeWriter 17 | writer: PyCodeWriter 18 | imports: set[str] 19 | key: int 20 | objs: dict[str, Any] # name -> obj 21 | statements: set[str] 22 | 23 | def __init__(self, key: int) -> None: 24 | self.key = key 25 | self.prepare_var_writer = PyCodeWriter() 26 | self.writer = PyCodeWriter() 27 | self.imports = set() 28 | self.objs = {} 29 | self.statements = set() 30 | 31 | def add_obj(self, obj: Any, name: str = "", force: bool = False) -> str: 32 | if force: 33 | assert name != "" 34 | assert is_valid_name(name) 35 | if name in self.objs: 36 | assert self.objs[name] == obj 37 | else: 38 | self.objs[name] = obj 39 | return name 40 | else: 41 | if name == "" or not is_valid_name(name): 42 | name = new_name("obj") 43 | elif name in self.objs: 44 | name = new_name(name) 45 | 46 | self.objs[name] = obj 47 | return name 48 | 49 | def add_import(self, module_name: str) -> None: 50 | self.imports.add(f"import {module_name}") 51 | 52 | def add_import_from(self, module_name: str, name: str) -> None: 53 | 
self.imports.add(f"from {module_name} import {name}") 54 | 55 | def add_stmt(self, stmt: str, is_prepare: bool = False) -> None: 56 | if is_prepare: 57 | self.prepare_var_writer.wl(stmt) 58 | else: 59 | self.writer.wl(stmt) 60 | 61 | def add_statements(self, stmt: str) -> None: 62 | self.statements.add(stmt) 63 | 64 | 65 | class GraphFnCodegen(FnCodegen): 66 | returns: list[Tuple[str, StorePos]] 67 | graph_inputs: list[tuple[StorePos, bool]] # (extract_code, to_tensor) 68 | graph_outputs: list[torch.fx.Node] 69 | id2name: dict[int, str] # idx -> name_in_graph_fn 70 | 71 | def __init__(self, key: int) -> None: 72 | super().__init__(key) 73 | self.postprossess = PyCodeWriter() 74 | self.returns = [] 75 | self.graph_inputs = [] 76 | self.graph_outputs = [] 77 | self.id2name = {} 78 | 79 | def output(self, name_in_graph_fn: str, store_pos: StorePos, code: str, 80 | in_return: bool, idx: int) -> None: 81 | self.writer.wl(f"{name_in_graph_fn} = {code}") 82 | if idx != 0: 83 | self.id2name[idx] = name_in_graph_fn 84 | if in_return: 85 | self.returns.append((name_in_graph_fn, store_pos)) 86 | 87 | def get_code(self) -> str: 88 | writer = PyCodeWriter() 89 | writer.wl( 90 | f"def ___make_graph_fn({', '.join(chain(('compiled_graph',), self.objs.keys()) )}):" 91 | ) 92 | writer.block_start() 93 | gen_imports(writer, self.imports) 94 | writer.wl(f"def fn(locals):") 95 | writer.block_start() 96 | for stmt in self.statements: 97 | writer.wl(stmt) 98 | if get_config('debug'): 99 | writer.wl( 100 | f"print('running graph_fn (key = {self.key})', locals.keys())") 101 | writer.write(self.prepare_var_writer.get_code()) 102 | # TODO: simplify 103 | graph_inputs = [] 104 | for x, to_tensor in self.graph_inputs: 105 | if to_tensor: 106 | graph_inputs.append(f"torch.tensor({x})") 107 | else: 108 | graph_inputs.append(f"{x}.contiguous()") 109 | writer.wl(f"graph_out = compiled_graph({', '.join(graph_inputs)})" 110 | ) # writer.wl(f"print('graph_out', graph_out)") 111 | 
writer.write(self.writer.get_code()) 112 | # writer.wl(f"print('graph_fn done', locals)") 113 | graph_retures = ", ".join( 114 | f"{target_name}" for target_name, _ in self.returns) 115 | writer.wl(f"return {graph_retures}") 116 | writer.block_end() 117 | writer.wl(f"return fn") 118 | writer.block_end() 119 | return writer.get_code() 120 | 121 | def get_return_values(self) -> list[StorePos]: 122 | return [store_pos for _, store_pos in self.returns] 123 | 124 | def add_graph_output(self, fx_node: torch.fx.Node) -> str: 125 | self.graph_outputs.append(fx_node) 126 | return f"graph_out[{len(self.graph_outputs)-1}]" 127 | 128 | def get_graph_outputs(self) -> list[torch.fx.Node]: 129 | return self.graph_outputs 130 | 131 | def add_graph_input(self, 132 | extract_code: StorePos, 133 | to_tensor: bool = False) -> None: 134 | self.graph_inputs.append((extract_code, to_tensor)) 135 | if to_tensor: 136 | self.add_import("torch") 137 | extract_code.add_name_to_fn(self) 138 | 139 | 140 | class GuardFnCodegen(FnCodegen): 141 | checks: set[tuple[str, StorePos]] 142 | imports: set[str] 143 | object_refs: list[Any] # the reference to objects for id check 144 | layout_sensitive: bool 145 | 146 | def __init__(self, key: int) -> None: 147 | super().__init__(key) 148 | self.checks = set() 149 | self.imports = set() 150 | self.object_refs = [] 151 | self.layout_sensitive = False 152 | 153 | def add_check(self, check: tuple[str, StorePos]) -> None: 154 | self.checks.add(check) 155 | 156 | def add_id_check(self, check: tuple[str, StorePos], obj: Any) -> None: 157 | self.add_check(check) 158 | self.object_refs.append(obj) 159 | 160 | def get_code(self) -> str: 161 | writer = PyCodeWriter() 162 | writer.wl(f"def ___make_guard_fn({', '.join(self.objs.keys())}):") 163 | writer.block_start() 164 | gen_imports(writer, self.imports) 165 | writer.wl(f"def fn(locals):") 166 | writer.block_start() 167 | writer.write(f"try:") 168 | writer.block_start() 169 | for stmt in self.statements: 170 | 
            writer.write(stmt)
        if get_config('debug'):
            writer.wl(
                f"print('running guard_fn (key = {self.key})', locals.keys())")
        writer.write(self.prepare_var_writer.get_code())
        writer.write(self.writer.get_code())
        if len(self.checks) == 0:
            writer.wl(f"ok = True")
            writer.wl(f"missed_check = []")
        else:
            writer.wl(f"ok = True")
            writer.wl(f"missed_check = []")
            # emit one "if not (<check>)" per recorded guard expression
            for x in self.checks:
                writer.wl(f"if not ({x[0]}):")
                writer.block_start()
                # NOTE(review): '_init_' (single underscores) is checked, not
                # '__init__' -- hasattr is False for most objects, so the
                # remap below usually runs, re-attaching a StorePos whose
                # string occurs in the failed expression. Looks deliberate
                # but is fragile; confirm before changing.
                if not hasattr(x[1], '_init_'):
                    for check in self.checks:
                        if str(check[1]) in x[0] and len(str(check[1])) > 0:
                            x = (x[0], check[1])
                writer.wl(f'''missed_check.append((r"{x[1]}", r"{x[0]}"))''')
                writer.wl(f"ok = False")
                writer.block_end()
        if get_config('debug'):
            writer.wl(f"print('ok = ', ok)")
        writer.block_end()
        # any exception inside the generated guard counts as a guard miss
        writer.wl(f"except Exception as e:")
        writer.block_start()
        writer.wl(f"print('exception in guard_fn:', e, type(e))")
        writer.wl(f'import traceback')
        writer.wl(f"print(traceback.format_exc())")
        writer.wl(f"return (missed_check, False)")
        writer.block_end()
        writer.wl(f"return (missed_check, ok)")
        writer.block_end()
        writer.wl(f"return fn")
        writer.block_end()
        return writer.get_code()

    def get_object_refs(self) -> list[Any]:
        return self.object_refs


# ---- frontend/pycode_writer.py ----
from typing import Any
import struct
import keyword


def get_float_string(value: float) -> str:
    """Return an expression that reconstructs ``value`` bit-exactly by
    packing/unpacking its IEEE-754 bytes (avoids repr precision issues)."""
    binary_data = struct.pack('d', value)
    hex_string = "b'" + ''.join(
        '\\x' + format(byte, '02x') for byte in binary_data) + "'"
    return f"struct.unpack('d', {hex_string})[0]"


# monotonically increasing suffix for generated names
NEW_VAR_ID = 0


def new_name(prefix: str) -> str:
| if prefix == "": 18 | prefix = "tmp" 19 | global NEW_VAR_ID 20 | NEW_VAR_ID += 1 21 | return f"{prefix}_{NEW_VAR_ID}" 22 | 23 | 24 | def is_valid_name(variable_name: str) -> bool: 25 | if not variable_name: 26 | return False 27 | 28 | if not variable_name[0].isalpha() and variable_name[0] != '_': 29 | return False 30 | 31 | for char in variable_name[1:]: 32 | if not (char.isalnum() or char == '_'): 33 | return False 34 | 35 | if keyword.iskeyword(variable_name): 36 | return False 37 | 38 | return True 39 | 40 | 41 | class PyCodeWriter: 42 | imports: set[str] 43 | code_strs: list[str] 44 | indent: int 45 | 46 | def __init__(self) -> None: 47 | self.imports = set() 48 | self.code_strs = [] 49 | self.indent = 0 50 | 51 | def block_start(self) -> None: 52 | self.indent += 1 53 | 54 | def block_end(self) -> None: 55 | self.indent -= 1 56 | 57 | def set_indent(self, indent: int) -> None: 58 | self.indent = indent 59 | 60 | def write(self, code_str: str) -> None: 61 | code = code_str.splitlines() 62 | for line in code: 63 | self.wl(line) 64 | 65 | def wl(self, code_str: str) -> None: 66 | if code_str.endswith('\n'): 67 | code_str = code_str[:-1] 68 | self.code_strs.append(' ' * self.indent + code_str) 69 | 70 | def get_code(self) -> str: 71 | return '\n'.join(self.code_strs) 72 | -------------------------------------------------------------------------------- /frontend/store_pos.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, TYPE_CHECKING, Callable, Union 2 | from types import FrameType 3 | 4 | from torch import Tensor 5 | 6 | from .c_api import get_value_stack_from_top 7 | if TYPE_CHECKING: 8 | from .pycode_generator import FnCodegen 9 | 10 | 11 | class StorePos: 12 | 13 | def get_value_from_frame(self, frame: FrameType) -> Any: 14 | raise NotImplementedError 15 | 16 | def add_name_to_fn(self, codegen: 'FnCodegen') -> None: 17 | pass 18 | 19 | 20 | class StoreInStack(StorePos): 21 | idx: int 22 
    def __init__(self, idx: int) -> None:
        self.idx = idx

    def __repr__(self) -> str:
        return f"__stack__{self.idx}"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return get_value_stack_from_top(frame, self.idx)


class StoreInLocal(StorePos):
    # value lives in a local variable of the frame
    name: str

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"locals['{self.name}']"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return frame.f_locals[self.name]


class StoreConstant(StorePos):
    # a literal value; self_id identifies the object it was extracted from
    value: Union[int, float]
    self_id: int

    def __init__(self, value: Union[int, float], self_id: int) -> None:
        self.value = value
        self.self_id = self_id

    def __repr__(self) -> str:
        return str(self.value)

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return self.value


class StoreInGlobal(StorePos):
    # value lives in a module-level global
    name: str

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"globals()['{self.name}']"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return frame.f_globals[self.name]


class StoreInFreeVar(StorePos):
    # value lives in a free variable (closure cell) of the frame
    free_idx: int

    def __init__(self, free_idx: int) -> None:
        super().__init__()
        self.free_idx = free_idx

    def __repr__(self) -> str:
        return f"get_from_freevars(frame, {self.free_idx})"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        raise NotImplementedError

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        # the generated code grabs the caller frame to reach its free vars
        codegen.add_import_from("frontend.c_api", "get_from_freevars")
        codegen.add_import("inspect")
        codegen.add_stmt("frame = inspect.currentframe().f_back",
                         is_prepare=True)


class StoreInBuiltin(StorePos):
    # value lives in __builtins__, accessed as a dict entry or an attribute
    name: str
    ty: str  # attr or dict

    def __init__(self, name: str, ty: str) -> None:
        self.name = name
        self.ty = ty
        assert ty in ['attr', 'dict']

    def __repr__(self) -> str:
        if self.ty == 'dict':
            return f"globals()['__builtins__']['{self.name}']"
        else:
            return f"globals()['__builtins__'].{self.name}"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        if self.ty == 'dict':
            return frame.f_globals['__builtins__'][self.name]
        else:
            return getattr(frame.f_globals['__builtins__'], self.name)


class StoreInAttr(StorePos):
    # value is an attribute of another tracked position
    self_pos: StorePos
    self_id: int
    attr_name: str

    def __init__(self, self_pos: StorePos, self_id: int,
                 attr_name: str) -> None:
        self.self_pos = self_pos
        self.self_id = self_id
        self.attr_name = attr_name

    def __repr__(self) -> str:
        return f"{self.self_pos}.{self.attr_name}"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return getattr(self.self_pos.get_value_from_frame(frame),
                       self.attr_name)

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        self.self_pos.add_name_to_fn(codegen)


class StoreInIndex(StorePos):
    self_pos: StorePos
    self_id: int  # id of the bind object
    self_index: Any  # array index
    subscriptable: bool

    # NOTE(review): the parameter name "subscritable" is a typo for
    # "subscriptable" but is part of the public keyword interface -- callers
    # may pass it by name, so it is kept as-is.
    def __init__(self,
                 self_pos: StorePos,
                 self_id: int,
                 self_index: Any,
                 subscritable: bool = True) -> None:
        self.self_pos = self_pos
        self.self_id = self_id
        self.self_index = self_index
        self.subscriptable = subscritable

    def __repr__(self) -> str:
        if self.subscriptable:
            return f"{self.self_pos}[{self.self_index}]"
        else:
            # non-subscriptable containers are materialized first
            return f'list({self.self_pos})[{self.self_index}]'

    def get_value_from_frame(self, frame: FrameType) -> Any:
        if self.subscriptable:
            return self.self_pos.get_value_from_frame(frame)[self.self_index]
        else:
            return list(
                self.self_pos.get_value_from_frame(frame))[self.self_index]

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        self.self_pos.add_name_to_fn(codegen)


class StoreNegate(StorePos):
    # arithmetic negation of another position
    pos: StorePos
    # NOTE(review): neg_id is annotated but never assigned in __init__ --
    # confirm whether any caller relies on it before removing.
    neg_id: int

    def __init__(self, pos: StorePos) -> None:
        self.pos = pos

    def __repr__(self) -> str:
        return f"-({self.pos})"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return -self.pos.get_value_from_frame(frame)

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        self.pos.add_name_to_fn(codegen)


class ExtractFromMethod(StorePos):
    # value is obtained by calling a zero-argument method on another position
    self_pos: StorePos
    self_id: int
    method_name: str

    def __init__(self, self_pos: StorePos, self_id: int,
                 method_name: str) -> None:
        self.self_pos = self_pos
        self.self_id = self_id
        self.method_name = method_name

    def __repr__(self) -> str:
        return f"{self.self_pos}.{self.method_name}()"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return getattr(self.self_pos.get_value_from_frame(frame),
                       self.method_name)()

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        self.self_pos.add_name_to_fn(codegen)


class ExtractFromFunction(StorePos):
    # value is obtained by calling func_obj on other tracked positions
    var_pos: list[StorePos]
    var_id: list[int]
    func_name: str
    func_obj: Any
    need_add_to_fn: bool
    preserved_name: str

    def __init__(self,
                 var_pos: list[StorePos],
                 var_id: list[int],
                 func_name: str,
                 func_obj: Callable[..., Any],
                 need_add_to_fn: bool = False) -> None:
        self.var_pos = var_pos
        self.var_id = var_id
        self.func_name = func_name
        self.func_obj = func_obj
        self.need_add_to_fn = need_add_to_fn
        # local import avoids a module-level import cycle with pycode_writer
        from .pycode_writer import new_name
        self.preserved_name = new_name(f"function_{self.func_name}")

    def __repr__(self) -> str:
        if self.need_add_to_fn:
            return f"{self.preserved_name}({','.join([str(pos) for pos in self.var_pos])})"
        else:
            return f"{self.func_name}({','.join([str(pos) for pos in self.var_pos])})"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        args = [pos.get_value_from_frame(frame) for pos in self.var_pos]
        return self.func_obj(*args)

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        if self.need_add_to_fn:
            codegen.add_obj(self.func_obj, self.preserved_name, force=True)
        for pos in self.var_pos:
            pos.add_name_to_fn(codegen)


class ExtractFromNew(StorePos):
    # value is a fresh, uninitialized instance of type_obj
    type_obj: Any
    preserved_name: Optional[str]

    def __init__(self, type_obj: Any) -> None:
        self.type_obj = type_obj
        self.preserved_name = None

    def gen_preserved_name(self) -> str:
        # lazily allocated so a name is only reserved when actually emitted
        if self.preserved_name is None:
            from .pycode_writer import new_name
            self.preserved_name = new_name(f"class_{self.type_obj.__name__}")
        return self.preserved_name

    def __repr__(self) -> str:
        return f"object.__new__({self.gen_preserved_name()})"

    def get_value_from_frame(self, frame: FrameType) -> Any:
        return self.type_obj.__new__(self.type_obj)

    def add_name_to_fn(self, codegen: 'FnCodegen') -> None:
        codegen.add_obj(self.type_obj, self.gen_preserved_name(), force=True)


class IterValue(StorePos):
    # placeholder for the current value produced by an iterator

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        return "__iter_value__"


class UnknownPosInCaller(StorePos):
    # position known to exist in a caller frame but not resolvable here

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        return "@__unknown_pos_in_caller__"


class voidpos(StorePos):
    # sentinel "no position" marker
    pass
import sys
import dis
import traceback
from types import FrameType, CodeType
from typing import Any, Callable, Tuple
import inspect
from .guard_tracker import push_tracker, pop_tracker, record, trackers
from .cache import enable_cache, check_cache_updated, get_frame_cache
from .fx_graph import set_frame_root
from .c_api import set_eval_frame, mark_need_postprocess, set_fallback
from .code import ProcessedCode
from .instruction import format_insts
from .config import get_config

# Module-level fallback state: after an unrecoverable tracing error we stop
# doing real work in trace callbacks and remember which frames must fall
# back to plain execution.
run_trace_func: bool = True
fall_back_frames: list[int] = []


def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:
    """Create the ``sys.settrace`` callback used while running ``frame_id``.

    The callback records each executed opcode via ``record``; on an internal
    error it either falls back to normal execution (when the
    ``enable_fallback`` config is set) or re-raises.
    """
    is_debug = get_config("debug")

    def trace_func(frame: FrameType, event: str, arg: Any) -> None:
        global run_trace_func
        if not run_trace_func and frame_id in fall_back_frames:
            return None
        try:
            if event == "opcode":
                opcode = frame.f_code.co_code[frame.f_lasti]
                opname = dis.opname[opcode]
                if is_debug:
                    print(
                        f"tracing {event} {opname} {arg} pc={frame.f_lasti} frame={frame_id}({hex(id(frame))})"
                    )
                record(frame, frame_id)
            elif event == "line":
                if is_debug:
                    print(
                        f"tracing {event} {frame.f_code.co_filename}:{frame.f_lineno}"
                    )
            else:
                if is_debug:
                    print(f"tracing {event} in {frame.f_code.co_filename}")
        except Exception as e:
            print("exception in trace_func:", e, type(e))
            print(traceback.format_exc())
            print("code stack:")
            traceback.print_stack(f=frame, file=sys.stdout)
            if get_config("enable_fallback"):
                # Mark every tracked frame as fallen back and let them run
                # under the stock interpreter.
                run_trace_func = False
                for i in trackers:
                    fall_back_frames.append(i.frame_id)
                print("fallback frames", fall_back_frames)
                set_fallback(None)
                return None
            else:
                raise e
        return None

    return trace_func


def empty_trace_func(_frame: FrameType, _event: str, _arg: Any) -> None:
    """No-op trace callback: keeps tracing armed without doing any work."""
    return None


def enable_trace(frame_id: int) -> None:
    """Push a tracker for the caller's frame and install the no-op tracer."""
    try:
        this_frame = inspect.currentframe()
        assert this_frame is not None
        caller_frame = this_frame.f_back
        assert caller_frame is not None
        push_tracker(caller_frame, frame_id)
        sys.settrace(empty_trace_func)
    except Exception as e:
        print("exception in enable_trace:", e, type(e))
        print(traceback.format_exc())
        raise e


def disable_trace(frame_id: int) -> None:
    """Pop the tracker for ``frame_id`` and uninstall the tracer."""
    try:
        pop_tracker(frame_id)
        sys.settrace(None)
    except Exception as e:
        print("exception in disable_trace:", e, type(e))
        print(traceback.format_exc())
        raise e


def get_process_frame(
        f: Callable[..., Any],
        is_callee: bool) -> Tuple[Callable[..., Any], Callable[..., Any]]:
    """Return the (preprocess, postprocess) hooks for frame evaluation of ``f``."""

    is_debug = get_config('debug')

    def preprocess_frame(
            frame: FrameType, frame_id: int
    ) -> Tuple[CodeType, Callable[..., Any], ProcessedCode]:
        # Rewrite the frame's bytecode (cached per frame_id) and build the
        # trace callback that will run alongside the rewritten code.
        try:
            if is_debug:
                print(f"preprocess frame {frame.f_code.co_filename}", frame_id,
                      hex(id(frame)), frame.f_code.co_name)
            enable_cache(frame_id)
            set_frame_root(frame_id, f)
            frame_cache = get_frame_cache(frame_id)
            frame_cache.update_code(frame.f_code, frame_id, is_callee)
            new_code, code_map = frame_cache.get_new_code(is_callee)
            if is_debug:
                print("bytecode to run:")
                print(format_insts(code_map.guard_insts))
            trace_func = get_trace_func(frame_id)

        except Exception as e:
            print("exception in preprocess:", e, type(e))
            print(traceback.format_exc())
            raise e
        return (new_code, trace_func, code_map)

    def postprocess_frame(frame: FrameType, frame_id: int) -> None:
        # Refresh the cached code object after the frame finished running.
        try:
            from .bytecode_writter import SHOULD_NOT_CALL_REWRITE
            if SHOULD_NOT_CALL_REWRITE:
                raise ValueError("should not call postprocess")
            if is_debug:
                print(f"postprocess frame {frame.f_code.co_filename}")
            set_frame_root(frame_id, f)
            frame_cache = get_frame_cache(frame_id)
            frame_cache.update_code(frame.f_code, frame_id, is_callee)
        except Exception as e:
            if is_debug:
                print("exception in postprocess:", e, type(e))
                print(traceback.format_exc())
            raise e
        return

    return (preprocess_frame, postprocess_frame)


def reset() -> None:
    """Restore the module-level fallback state to its initial values."""
    # BUG FIX: the original `run_trace_func = True` lacked a `global`
    # declaration, so it bound a dead local and the module flag stayed
    # False after a fallback had ever occurred.
    global run_trace_func
    run_trace_func = True
    fall_back_frames.clear()
# Dispatch table from a value's concrete runtime type to its Variable wrapper.
ty2var: dict[type[Any], type[Variable]] = {
    float: ScalarVar,
    int: ScalarVar,
    str: ScalarVar,
    bool: ScalarVar,
    torch.Tensor: TensorVar,
    NullObject: NullVar,
    type(None): NoneVar,
    slice: SliceVar,
    torch.nn.Parameter: TorchParamVar,
    tuple: TupleVar,
    list: ListVar,
    set: SetVar,
    frozenset: FrozensetVar,
    torch.Size: TorchSizeVar,
    torch.dtype: TorchDtypeVar,
    torch.device: TorchDeviceVar,
    torch.layout: TorchLayoutVar,
    dict: DictVar,
    CodeType: CodeVar,
    OrderedDict: OrderedDictVar,
    np.ndarray: NdarrayVar,
}

CONST_TYPES = Union[int, float, bool, str, NullObject, None, slice]


def make_var_from_value(
        value: Any,
        need_guard_check: bool,
        helper_functions: HelperFunctions,
        fx_graph: Optional[FxGraph] = None,
        extract_code_at_start: Optional[list[StorePos]] = None) -> Variable:
    """Wrap ``value`` in the matching Variable subclass.

    Exact-type matches go through ``ty2var``; the remaining checks run in a
    fixed order because several of them overlap (e.g. Modules are callable).
    """
    if extract_code_at_start is None:
        extract_code_at_start = []
    # Trailing arguments shared by every from_value call below.
    rest = (need_guard_check, helper_functions, fx_graph, extract_code_at_start)
    if type(value) == np.ndarray and value.size == 1:
        # Single-element arrays are treated as numpy scalars.
        return NumpyScalarVar.from_value(np.int64(value.tolist()), *rest)
    if type(value) in ty2var:
        return ty2var[type(value)].from_value(value, *rest)
    if isinstance(value, torch.nn.Module):
        return TorchModuleVar.from_value(value, *rest)
    if isinstance(value, ModuleType):
        return ModuleVar.from_value(value, *rest)
    if callable(value):
        return FunctionVar.from_value(value, *rest)
    if isinstance(value, range):
        return RangeVar.from_value(value, *rest)
    if isinstance(value, type(range(0).__iter__())):
        return RangeIterVar.from_value(value, *rest)
    if isinstance(value, CellType):
        return CellVar.from_value(value, *rest)
    if isinstance(value, np.generic):
        return NumpyScalarVar.from_value(value, *rest)
    if is_structseq(value):
        return TupleVar.from_value(value, *rest)
    if type(value) == MappingProxyType:
        return MappingProxyVar.from_value(value, *rest)
    if isinstance(value, type(Ellipsis)):
        return EllipsisVar.from_value(value, *rest)
    # NOTE: use AnyVar instead of IteratorVar to represent an iterator with
    # unknown source due to the hardness of getting iterable and num_iters
    return AnyVar.from_value(value, *rest)


__all__ = [
    'make_var_from_value', 'Variable', 'ScalarVar', 'TensorVar',
    'TorchModuleVar', 'NullVar', 'NoneVar', "ModuleVar", "FunctionVar",
    "TorchParamVar", "AnyVar", "IteratorVar", "RangeIterVar", "HelperFunctions"
]
class AnyVar(Variable):
    """Fallback wrapper for values that have no dedicated Variable subclass."""

    def __init__(self, need_guard_check: bool, obj: Any,
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, obj, extract_code_at_start)

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        # Nothing to check: AnyVar carries no inspectable structure.
        pass

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        # Re-read the value from its first known extract position.
        positions = self.fetch_extract_code_at_start()
        assert len(positions) > 0
        first = positions[0]
        first.add_name_to_fn(codegen)
        codegen.output(name_in_graph_fn, store_pos, str(first), in_return, idx)

    @classmethod
    def from_value(cls, value: None, need_guard_check: bool,
                   _helper_functions: 'HelperFunctions',
                   _fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "AnyVar":
        return cls(need_guard_check, value, extract_code_at_start)

    def as_fx_node(self) -> NodeArgs:
        raise NotImplementedError(self.obj)
@dataclass
class HelperFunctions:
    """Callbacks supplied by the tracker so Variables can build sub-variables
    without importing the tracker module directly."""
    get_or_make_var: Callable[[Any, bool, Optional[FxGraph], list[StorePos]],
                              'Variable']
    gen_by_caller: Callable[[Any], bool]
    mark_cannot_guard: Callable[[], None]


@dataclass
class Variable:
    """Base class of all traced-value wrappers.

    A Variable remembers how to re-extract its value from a frame
    (``extract_code_at_start``), whether that value must be guard-checked,
    and a prev/succ chain of older/newer versions of the same object.
    """
    need_guard_check: bool
    extract_code_at_start: list[StorePos]
    extract_code_hashs: set[int]
    obj: Any
    modified_attrs: dict[str, 'Variable']
    prev: Optional['Variable'] = None
    succ: Optional['Variable'] = None

    def __init__(self, need_guard_check: bool, obj: Any,
                 extract_code_at_start: list[StorePos]) -> None:
        from ..guard_tracker import trackers
        # Warn when a requested extract position is one the current frame
        # is known to be missing.
        for missing in get_miss_locals(trackers[-1].frame_id):
            for pos in extract_code_at_start:
                if missing == f"{pos}":
                    print(missing)
                    print("--------warning--------")

        self.need_guard_check = need_guard_check
        self.obj = obj
        self.extract_code_at_start = extract_code_at_start
        # Hashes of the stringified positions, used to deduplicate later adds.
        self.extract_code_hashs = {
            hash(str(pos)) for pos in extract_code_at_start
        }
        if need_guard_check:
            assert len(extract_code_at_start) > 0
        self.modified_attrs = {}

    @classmethod
    @abstractmethod
    def from_value(
        self,
        value: Any,
        need_guard_check: bool,
        _helper_functions: 'HelperFunctions',
        fx_graph: Optional[FxGraph],
        extract_code_at_start: list[StorePos],
    ) -> 'Variable':
        raise NotImplementedError

    def make_guard(self, codegen: "GuardFnCodegen") -> None:
        # Emit one guard per known extract position.
        if not self.need_guard_check:
            return
        assert len(self.extract_code_at_start) > 0
        for pos in self.extract_code_at_start:
            pos.add_name_to_fn(codegen)
            self.make_guard_inner(codegen, pos)

    @abstractmethod
    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        raise NotImplementedError

    def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
                    codegen: "GraphFnCodegen", in_return: bool,
                    idx: int) -> None:
        # Always emit through the newest version of this variable.
        if self.succ is not None:
            return self.succ.make_output(name_in_graph_fn, store_pos, codegen,
                                         in_return, idx)
        if idx in codegen.id2name:
            # The value already has a name in the generated function.
            codegen.output(name_in_graph_fn, store_pos, codegen.id2name[idx],
                           in_return, 0)
        else:
            self.make_output_inner(name_in_graph_fn, store_pos, codegen,
                                   in_return, idx)
        # Replay recorded attribute mutations on the regenerated object.
        for attr, attr_var in self.modified_attrs.items():
            skip = isinstance(attr_var.obj, torch.nn.Parameter) and len(
                attr_var.extract_code_at_start) == 0
            if skip:
                continue
            attr_var.make_output(f'{name_in_graph_fn}_dot_{attr}',
                                 StoreInAttr(store_pos, id(self.obj), attr),
                                 codegen, False, id(getattr(self.obj, attr)))
            codegen.add_stmt(
                f"setattr({name_in_graph_fn}, '{attr}', {name_in_graph_fn}_dot_{attr})"
            )

    @abstractmethod
    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        raise NotImplementedError

    @abstractmethod
    def as_fx_node(self) -> "NodeArgs":
        raise NotImplementedError

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        pass

    def set_prev(self, prev: Optional['Variable']) -> None:
        self.prev = prev
        if prev is not None:
            prev.succ = self

    def get_subvars_with_idx(self) -> Iterable[Tuple["Variable", int]]:
        return []

    def get_oldest_var(self) -> "Variable":
        node = self
        while node.prev is not None:
            node = node.prev
        return node

    def disable_guard_check(self) -> None:
        self.need_guard_check = False

    def clear_extract_code_at_start(self) -> None:
        self.extract_code_at_start = []
        self.extract_code_hashs = set()

    def add_extract_code_at_start(self, pos: StorePos) -> None:
        from ..guard_tracker import trackers
        # Same missing-local warning as in __init__, for a single position.
        for missing in get_miss_locals(trackers[-1].frame_id):
            if missing == f"{pos}":
                print(missing)
                print("--------warning--------")

        hash_value = hash(str(pos))
        if hash_value in self.extract_code_hashs:
            return
        self.extract_code_at_start.append(pos)
        self.extract_code_hashs.add(hash_value)

    def add_modified_attr(self, attr: str, var: 'Variable') -> None:
        self.modified_attrs[attr] = var

    def fetch_extract_code_at_start(self) -> list[StorePos]:
        # Walk back through older versions while their modified_attrs differ
        # from ours, and take that version's extract positions.

        def same_attrs(a: dict[str, Variable], b: dict[str, Variable]) -> bool:
            if len(a) != len(b):
                return False
            return all(k in b and b[k] is v for k, v in a.items())

        node = self
        while node.prev is not None and not same_attrs(node.prev.modified_attrs,
                                                       node.modified_attrs):
            node = node.prev
        return node.extract_code_at_start
class CellVar(Variable):
    """Wrapper for a closure cell; also wraps the cell's contents as a
    sub-Variable addressed through ``<pos>.cell_contents``."""
    sub_var: Variable
    sub_id: int

    def __init__(self, value: CellType, need_guard_check: bool,
                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, value, extract_code_at_start)
        # A cell can only be rebuilt through its position, so at least one
        # extract position is mandatory.
        assert len(extract_code_at_start) > 0
        sub_obj = parse_cell(value)
        new_extract: list[StorePos] = [
            StoreInAttr(pos, id(value), "cell_contents")
            for pos in self.extract_code_at_start
        ]
        self.sub_var = helper_functions.get_or_make_var(sub_obj,
                                                        need_guard_check,
                                                        fx_graph, new_extract)
        self.sub_id = id(sub_obj)

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_import_from("types", "CellType")
        codegen.add_check((f"isinstance({pos}, CellType)", pos))
        # Also guard the cell's current contents.
        self.sub_var.make_guard_inner(
            codegen, StoreInAttr(pos, self.sub_id, "cell_contents"))

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        codegen.add_import("inspect")
        codegen.add_import_from("frontend.c_api", "get_from_freevars")
        codegen.output(f"{name_in_graph_fn}_frame", store_pos,
                       "inspect.currentframe().f_back", False, idx)
        extract_pos = self.fetch_extract_code_at_start()
        assert len(extract_pos) > 0
        extract_pos[0].add_name_to_fn(codegen)
        cell_pos = extract_pos[0]
        # Only cells addressed as free variables can be re-fetched here.
        if isinstance(cell_pos, StoreInFreeVar):
            codegen.output(
                name_in_graph_fn, store_pos,
                f"get_from_freevars({name_in_graph_fn}_frame, {cell_pos.free_idx})",
                False, idx)
        else:
            raise NotImplementedError

    @classmethod
    def from_value(cls, value: CellType, need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "CellVar":
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        old_var = table.get_or_none_by_id(self.sub_id)
        if old_var is None:
            table.add_by_id(self.sub_var, self.sub_id)
            self.sub_var.add_subvars_to_table(table)
        # An existing entry is deliberately left untouched: merging extract
        # positions for cell contents is not supported yet.


class MappingProxyVar(Variable):
    """Wrapper for a ``mappingproxy``; wraps the underlying dict as a
    sub-Variable. Guarding and output generation are not implemented."""
    sub_var: Variable
    sub_id: int

    def __init__(self, value: Any, need_guard_check: bool,
                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, value, extract_code_at_start)
        sub_obj = parse_mapproxyobject(value)
        new_extract: list[StorePos] = []
        self.sub_var = helper_functions.get_or_make_var(sub_obj,
                                                        need_guard_check,
                                                        fx_graph, new_extract)
        self.sub_id = id(sub_obj)

    @classmethod
    def from_value(cls, value: Any, need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "MappingProxyVar":
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def make_guard_inner(self, codegen: 'GuardFnCodegen',
                         pos: StorePos) -> None:
        # BUG FIX: message typo "TOOD" corrected; still an unimplemented stub.
        raise ValueError("TODO")

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: 'GraphFnCodegen', in_return: bool,
                          idx: int) -> None:
        # BUG FIX: message typo "TOOD" corrected; still an unimplemented stub.
        raise ValueError("TODO")

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        old_var = table.get_or_none_by_id(self.sub_id)
        if old_var is not None:
            # TODO: handle extract_code_at_start
            pass
        else:
            table.add_by_id(self.sub_var, self.sub_id)
            self.sub_var.add_subvars_to_table(table)
class DictVar(Variable):
    """Wrapper for a dict; one sub-Variable per value, in insertion order."""
    vars: list[Variable]
    obj_ids: list[int]
    length: int

    def __init__(self,
                 value: dict[Any, Any],
                 need_guard_check: bool,
                 helper_functions: HelperFunctions,
                 fx_graph: Optional[FxGraph] = None,
                 extract_code_at_start: Optional[list[StorePos]] = None
                ) -> None:
        # BUG FIX: the default was a shared mutable list ([]); any in-place
        # append leaked into every later DictVar created with the default.
        if extract_code_at_start is None:
            extract_code_at_start = []
        super().__init__(need_guard_check, value, extract_code_at_start)
        self.value = value
        self.length = len(value)
        self.vars = []
        self.obj_ids = []
        for key, obj in self.value.items():
            assert not isinstance(key, torch.Tensor)
            # String keys are quoted in the generated index expression.
            if isinstance(key, str):
                new_extract: list[StorePos] = [
                    StoreInIndex(pos, id(obj), f"'{key}'")
                    for pos in self.extract_code_at_start
                ]
            else:
                new_extract = [
                    StoreInIndex(pos, id(obj), str(key))
                    for pos in self.extract_code_at_start
                ]
            var = helper_functions.get_or_make_var(obj, need_guard_check,
                                                   fx_graph, new_extract)
            self.vars.append(var)
            self.obj_ids.append(id(obj))

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_check((f"isinstance({pos}, dict)", pos))
        codegen.add_check((f"len({pos}) == {self.length}", pos))
        for key, obj in zip(self.value.keys(), self.vars):
            if not isinstance(obj, TensorVar):
                if isinstance(key, str):
                    obj.make_guard_inner(codegen,
                                         StoreInIndex(pos, id(obj), f"'{key}'"))
                else:
                    obj.make_guard_inner(codegen,
                                         StoreInIndex(pos, id(obj), str(key)))

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        oldest = self.get_oldest_var()
        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
                            False, idx_j)

        if len(oldest.extract_code_at_start) > 0:
            # In-place refill of the original dict object.
            assert isinstance(oldest, DictVar)
            old_store_pos = oldest.extract_code_at_start[0]
            codegen.add_stmt(f"{old_store_pos}.clear()")
            # NOTE(review): keys are interpolated unquoted here, so string
            # keys would produce invalid code — presumably this branch only
            # sees non-string keys; confirm against callers.
            for i, key in enumerate(self.value.keys()):
                codegen.add_stmt(
                    f"{old_store_pos}[{key}]={name_in_graph_fn}_{i}")
            codegen.output(name_in_graph_fn, store_pos, str(old_store_pos),
                           in_return, idx)
        else:
            # Build a fresh dict literal.
            items = []
            for key, j in zip(self.value.keys(), range(len(self.vars))):
                if isinstance(key, str):
                    if "\n" not in key:
                        key_part = f"'{key}'"
                    else:
                        key_part = f"'{repr(key)}'"
                        key_part = key_part.strip("'")
                else:
                    key_part = key
                item = f'{key_part}: {name_in_graph_fn}_{j}'
                items.append(item)
            target = f"{{{', '.join(i for i in items)}}}"
            codegen.output(name_in_graph_fn, store_pos,
                           target if len(self.vars) > 0 else "{}", in_return,
                           idx)

    @classmethod
    def from_value(cls,
                   value: dict[Any, Any],
                   need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph] = None,
                   extract_code_at_start: Optional[list[StorePos]] = None
                  ) -> "DictVar":
        # BUG FIX: mutable default replaced with None sentinel (see __init__).
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def as_fx_node(self) -> NodeArgs:
        return self.value

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        for key, var, idx in zip(self.value.keys(), self.vars, self.obj_ids):
            old_var = table.get_or_none_by_id(idx)
            if old_var is not None:
                # Merge our extract positions into the existing entry.
                if isinstance(key, str):
                    new_extract: list[StorePos] = [
                        StoreInIndex(pos, idx, f"'{key}'")
                        for pos in self.extract_code_at_start
                    ]
                else:
                    new_extract = [
                        StoreInIndex(pos, idx, str(key))
                        for pos in self.extract_code_at_start
                    ]
                old_var.extract_code_at_start.extend(new_extract)
                old_var.need_guard_check |= self.need_guard_check
            else:
                table.add_by_id(var, idx)
                var.add_subvars_to_table(table)


class OrderedDictVar(DictVar):
    """DictVar specialization for ``collections.OrderedDict``."""

    def __init__(self,
                 value: dict[Any, Any],
                 need_guard_check: bool,
                 helper_functions: HelperFunctions,
                 fx_graph: Optional[FxGraph] = None,
                 extract_code_at_start: Optional[list[StorePos]] = None
                ) -> None:
        # BUG FIX: mutable default replaced with None sentinel (handled in
        # the DictVar base __init__).
        super().__init__(value, need_guard_check, helper_functions, fx_graph,
                         extract_code_at_start)

    @classmethod
    def from_value(cls,
                   value: dict[Any, Any],
                   need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph] = None,
                   extract_code_at_start: Optional[list[StorePos]] = None
                  ) -> "OrderedDictVar":
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_check((f"isinstance({pos}, dict)", pos))
        codegen.add_check((f"len({pos}) == {self.length}", pos))
        for key, var in zip(self.value.keys(), self.vars):
            if not isinstance(var, TensorVar):
                if isinstance(key, str):
                    var.make_guard_inner(codegen,
                                         StoreInIndex(pos, id(var), f"'{key}'"))
                else:
                    var.make_guard_inner(codegen,
                                         StoreInIndex(pos, id(var), str(key)))
        # TODO: check order

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        oldest = self.get_oldest_var()
        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
                            False, idx_j)
        if len(oldest.extract_code_at_start) > 0:
            # In-place refill of the original dict object.
            assert isinstance(oldest, DictVar)
            old_store_pos = oldest.extract_code_at_start[0]
            codegen.add_stmt(f"{old_store_pos}.clear()")
            for i, key in enumerate(self.value.keys()):
                codegen.add_stmt(
                    f"{old_store_pos}[{key}]={name_in_graph_fn}_{i}")
            codegen.output(name_in_graph_fn, store_pos, str(old_store_pos),
                           in_return, idx)
        else:
            codegen.add_import("collections")

            def to_str(value: Any) -> str:
                if isinstance(value, str):
                    return f"'{value}'"
                else:
                    return str(value)

            codegen.output(
                name_in_graph_fn, store_pos,
                f"collections.OrderedDict([{','.join(f'({to_str(key)}, {name_in_graph_fn}_{j})' for key, j in zip(self.value.keys(), range(len(self.vars))))}])"
                if len(self.vars) > 0 else "{}", in_return, idx)
class IteratorVar(Variable):
    """Wrapper for a generic iterator, rebuilt by replaying ``__next__``.

    ``parent_var`` wraps the iterable the iterator came from; ``num_iters``
    is how many items have already been consumed.
    """
    parent_var: Optional[Variable]
    parent_idx: int
    num_iters: int

    def __init__(
        self,
        value: Any,
        parent_var: Optional[Variable],
        parent_idx: int,
        num_iters: int,
        need_guard_check: bool,
        extract_code_at_start: list[StorePos],
    ) -> None:
        super().__init__(need_guard_check, value, extract_code_at_start)
        # Iterators cannot be guard-checked; callers must not request it.
        assert not need_guard_check
        self.parent_var = parent_var
        self.parent_idx = parent_idx
        self.num_iters = num_iters

    @classmethod
    def from_parent_var(cls, value: Any, parent_var: Optional[Variable],
                        parent_idx: int, num_iters: int, need_guard_check: bool,
                        extract_code_at_start: list[StorePos]) -> "IteratorVar":
        return cls(value, parent_var, parent_idx, num_iters, need_guard_check,
                   extract_code_at_start)

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         store_pos: StorePos) -> None:
        raise NotImplementedError()

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        if self.parent_var is None:
            raise ValueError("cannot gen output for None parent_var")
        # Regenerate the iterable, take a fresh iterator from it, then
        # advance it past the items that were already consumed.
        self.parent_var.make_output(f"{name_in_graph_fn}_iterable", store_pos,
                                    codegen, False, idx)
        codegen.output(name_in_graph_fn, store_pos,
                       f"{name_in_graph_fn}_iterable.__iter__()", in_return,
                       idx)
        for _ in range(self.num_iters):
            codegen.add_stmt(f"{name_in_graph_fn}.__next__()")


class RangeIterVar(Variable):
    """Wrapper for a range iterator; fully described by four integers."""
    index: int
    start: int
    step: int
    len: int

    def __init__(self, obj: Any, need_guard_check: bool,
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, obj, extract_code_at_start)
        state = parse_rangeiterobject(obj)
        self.index, self.start, self.step, self.len = state

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_import("frontend.c_api")
        codegen.add_check((
            f"frontend.c_api.parse_rangeiterobject({pos}) == ({self.index}, {self.start}, {self.step}, {self.len})",
            pos))

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        codegen.add_import("frontend.c_api")
        codegen.output(
            name_in_graph_fn, store_pos,
            f"frontend.c_api.make_rangeiterobject({self.index}, {self.start}, {self.step}, {self.len})",
            in_return, idx)

    @classmethod
    def from_value(cls, value: Any, need_guard_check: bool,
                   _helper_functions: HelperFunctions,
                   _fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "RangeIterVar":
        return cls(value, need_guard_check, extract_code_at_start)
class ListVar(Variable):
    """Wrapper for a Python list; one sub-Variable per element."""
    vars: list[Variable]
    obj_ids: list[int]
    length: int
    helper_functions: HelperFunctions
    graph: Optional[FxGraph]

    def __init__(self, value: list[Any], need_guard_check: bool,
                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, value, extract_code_at_start)
        self.value = value
        self.length = len(value)
        self.vars = []
        self.obj_ids = []
        # Kept so make_output_inner can re-wrap elements if the list mutated.
        self.helper_functions = helper_functions
        self.graph = fx_graph
        for i, obj in enumerate(value):
            new_extract: list[StorePos] = [
                StoreInIndex(pos, id(obj), i)
                for pos in self.extract_code_at_start
            ]
            var = helper_functions.get_or_make_var(obj, need_guard_check,
                                                   fx_graph, new_extract)
            self.vars.append(var)
            self.obj_ids.append(id(obj))

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_check((f"isinstance({pos}, list)", pos))
        codegen.add_check((f"len({pos}) == {self.length}", pos))
        # NOTE(review): `obj` here is the sub-Variable, so id(obj) differs
        # from the element ids recorded in __init__ — confirm intended.
        for i, obj in enumerate(self.vars):
            obj.make_guard_inner(codegen, StoreInIndex(pos, id(obj), i))

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        oldest = self.get_oldest_var()
        if len(self.obj) != len(self.vars):
            # The underlying list changed length: re-wrap every element.
            self.vars.clear()
            self.obj_ids.clear()
            for i, obj in enumerate(self.obj):
                new_extract: list[StorePos] = [
                    StoreInIndex(pos, id(obj), i)
                    for pos in self.extract_code_at_start
                ]
                var = self.helper_functions.get_or_make_var(
                    obj, self.need_guard_check, self.graph, new_extract)
                self.vars.append(var)
                self.obj_ids.append(id(obj))
        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
                            False, idx_j)
        if len(oldest.extract_code_at_start) > 0:
            # In-place refill of the original list object.
            assert isinstance(oldest, ListVar)
            old_store_pos = oldest.extract_code_at_start[0]
            codegen.add_stmt(f"{old_store_pos}.clear()")
            # BUG FIX: iterate the current element count rather than the
            # stale `self.length` captured at construction time, so lists
            # whose length changed are re-emitted completely.
            for i in range(len(self.vars)):
                codegen.add_stmt(
                    f"{old_store_pos}.append({name_in_graph_fn}_{i})")
            codegen.output(name_in_graph_fn, store_pos, str(old_store_pos),
                           in_return, idx)
        else:
            # Build a fresh list literal.
            codegen.output(
                name_in_graph_fn, store_pos,
                f"[{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},]"
                if len(self.vars) > 0 else "[]", in_return, idx)

    @classmethod
    def from_value(cls, value: list[Any], need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "ListVar":
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def as_fx_node(self) -> NodeArgs:
        return self.value

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
            old_var = table.get_or_none_by_id(idx)
            if old_var is not None:
                # Merge our extract positions into the existing entry.
                new_extract: list[StorePos] = [
                    StoreInIndex(pos, idx, i)
                    for pos in self.extract_code_at_start
                ]
                old_var.extract_code_at_start.extend(new_extract)
                old_var.need_guard_check |= self.need_guard_check
            else:
                table.add_by_id(var, idx)
                var.add_subvars_to_table(table)


class NdarrayVar(Variable):
    """Wrapper for a numpy ndarray; iterates the first axis into
    sub-Variables and regenerates via ``numpy.array([...])``."""
    vars: list[Variable]
    obj_ids: list[int]
    length: int

    def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
                 extract_code_at_start: list[StorePos]) -> None:
        super().__init__(need_guard_check, value, extract_code_at_start)
        self.value = value
        self.length = value.size
        self.vars = []
        self.obj_ids = []
        for i, obj in enumerate(value):
            new_extract: list[StorePos] = [
                StoreInIndex(pos, id(obj), i)
                for pos in self.extract_code_at_start
            ]
            var = helper_functions.get_or_make_var(obj, need_guard_check,
                                                   fx_graph, new_extract)
            self.vars.append(var)
            self.obj_ids.append(id(obj))

    def make_guard_inner(self, codegen: "GuardFnCodegen",
                         pos: StorePos) -> None:
        codegen.add_import("numpy")
        codegen.add_check((f"isinstance({pos}, numpy.ndarray)", pos))
        # NOTE(review): self.length is value.size (total elements) but
        # len(pos) is the first-axis length — these differ for ndim > 1;
        # confirm only 1-D arrays reach this guard.
        codegen.add_check((f"len({pos}) == {self.length}", pos))
        for i, obj in enumerate(self.vars):
            obj.make_guard_inner(codegen, StoreInIndex(pos, id(obj), i))

    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                          codegen: "GraphFnCodegen", in_return: bool,
                          idx: int) -> None:
        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
                            False, idx_j)
        list_str = f"[{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},]" if len(
            self.vars) > 0 else "[]"
        codegen.add_import("numpy")
        var_str = f"numpy.array({list_str})"
        codegen.output(name_in_graph_fn, store_pos, var_str, in_return, idx)

    @classmethod
    def from_value(cls, value: np.ndarray[Any, Any], need_guard_check: bool,
                   helper_functions: HelperFunctions,
                   fx_graph: Optional[FxGraph],
                   extract_code_at_start: list[StorePos]) -> "NdarrayVar":
        return cls(value, need_guard_check, helper_functions, fx_graph,
                   extract_code_at_start)

    def as_fx_node(self) -> NodeArgs:
        return self.value

    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
        for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
            old_var = table.get_or_none_by_id(idx)
            if old_var is not None:
                new_extract: list[StorePos] = [
                    StoreInIndex(pos, idx, i)
                    for pos in self.extract_code_at_start
                ]
                old_var.extract_code_at_start.extend(new_extract)
                old_var.need_guard_check |= self.need_guard_check
            else:
                table.add_by_id(var, idx)
                var.add_subvars_to_table(table)
import dynamic as dyn 13 | if TYPE_CHECKING: 14 | from ..pycode_generator import GraphFnCodegen, GuardFnCodegen 15 | 16 | 17 | class ScalarVar(Variable): 18 | value_fix: bool 19 | fx_node: Optional[torch.fx.Node] 20 | 21 | def __init__(self, value: ScalarType, value_fix: bool, 22 | need_guard_check: bool, fx_node: Optional[torch.fx.Node], 23 | extract_code_at_start: list[StorePos]) -> None: 24 | super().__init__(need_guard_check, value, extract_code_at_start) 25 | # NOTE: should implement bool genererated from tensor 26 | # if isinstance(value, bool) and not value_fix: 27 | # raise NotImplementedError 28 | if not value_fix: 29 | assert fx_node is not None 30 | self.value_fix = value_fix 31 | self.fx_node = fx_node 32 | 33 | def make_guard_inner(self, codegen: "GuardFnCodegen", 34 | pos: StorePos) -> None: 35 | codegen.add_check( 36 | (f"isinstance({pos}, {type(self.obj).__name__})", pos)) 37 | if self.value_fix: 38 | if type(self.obj) == float: 39 | codegen.add_check( 40 | (f"{pos} == {get_float_string(self.obj)}", pos)) 41 | codegen.add_import("struct") 42 | elif isinstance(self.obj, str): 43 | codegen.add_check((f"{pos} == '{self.obj}'", pos)) 44 | else: 45 | codegen.add_check((f"{pos} == {self.obj}", pos)) 46 | 47 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 48 | codegen: "GraphFnCodegen", in_return: bool, 49 | idx: int) -> None: 50 | if self.value_fix: 51 | if type(self.obj) == float: 52 | codegen.output(name_in_graph_fn, store_pos, 53 | f"{get_float_string(self.obj)} # {self.obj}", 54 | in_return, idx) 55 | codegen.add_import("struct") 56 | elif isinstance(self.obj, str): 57 | codegen.output(name_in_graph_fn, store_pos, f"'{self.obj}'", 58 | in_return, idx) 59 | else: 60 | codegen.output(name_in_graph_fn, store_pos, str(self.obj), 61 | in_return, idx) 62 | else: 63 | name_in_graph_output = codegen.add_graph_output(self.fx_node) 64 | codegen.output( 65 | name_in_graph_fn, store_pos, 66 | f'{name_in_graph_output}.item() if 
isinstance({name_in_graph_output}, torch.Tensor) else {name_in_graph_output}', 67 | in_return, idx) 68 | 69 | @classmethod 70 | def from_value(cls, value: ScalarType, need_guard_check: bool, 71 | _helper_functions: HelperFunctions, 72 | fx_graph: Optional[FxGraph], 73 | extract_code_at_start: list[StorePos]) -> "ScalarVar": 74 | if id(value) not in dyn.dynamic_vars: 75 | return cls(value, True, need_guard_check, None, 76 | extract_code_at_start) 77 | else: 78 | assert fx_graph is not None 79 | if need_guard_check: 80 | assert len(extract_code_at_start) > 0 81 | name = new_name('scalar') 82 | if not config.get_config('dynshape'): 83 | fx_node = fx_graph.create_input(torch.tensor(value), name, (), 84 | {}, name) 85 | else: 86 | fx_node = fx_graph.create_sym_input(value, name, (), {}, name) 87 | var = cls.from_value_and_node(value, fx_node, need_guard_check, 88 | extract_code_at_start) 89 | return var 90 | 91 | @classmethod 92 | def from_value_and_node( 93 | cls, value: ScalarType, fx_node: torch.fx.Node, 94 | need_guard_check: bool, 95 | extract_code_at_start: list[StorePos]) -> 'ScalarVar': 96 | var = cls(value, False, need_guard_check, fx_node, 97 | extract_code_at_start) 98 | fx_node.meta["var"] = var 99 | return var 100 | 101 | def as_fx_node(self) -> NodeArgs: 102 | if self.value_fix: 103 | return self.obj 104 | else: 105 | return self.fx_node 106 | 107 | 108 | class NumpyScalarVar(Variable): 109 | dtype: type 110 | value: np.generic 111 | value_fix: bool 112 | fx_node: Optional[torch.fx.Node] 113 | 114 | def __init__(self, value: np.generic, value_fix: bool, 115 | need_guard_check: bool, fx_node: Optional[torch.fx.Node], 116 | extract_code_at_start: list[StorePos]) -> None: 117 | super().__init__(need_guard_check, value, extract_code_at_start) 118 | if not value_fix: 119 | assert fx_node is not None 120 | self.dtype = type(value) 121 | self.value = value.item() 122 | self.value_fix = value_fix 123 | self.fx_node = fx_node 124 | 125 | def make_guard_inner(self, 
codegen: "GuardFnCodegen", 126 | pos: StorePos) -> None: 127 | codegen.add_import("numpy") 128 | codegen.add_check( 129 | (f"isinstance({pos}, numpy.{type(self.obj).__name__})", pos)) 130 | if self.value_fix: 131 | item = self.obj.item() 132 | if type(item) == float: 133 | codegen.add_check( 134 | (f"{pos}.item() == {get_float_string(item)}", pos)) 135 | codegen.add_import("struct") 136 | elif isinstance(item, str): 137 | codegen.add_check((f"{pos}.item() == '{item}'", pos)) 138 | else: 139 | codegen.add_check((f"{pos}.item() == {item}", pos)) 140 | 141 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 142 | codegen: "GraphFnCodegen", in_return: bool, 143 | idx: int) -> None: 144 | if self.value_fix: 145 | if type(self.obj) == float: 146 | codegen.output( 147 | name_in_graph_fn, store_pos, 148 | f"{self.dtype}({get_float_string(self.obj)}) # {self.obj}", 149 | in_return, idx) 150 | codegen.add_import("struct") 151 | elif isinstance(self.obj, str): 152 | codegen.output(name_in_graph_fn, store_pos, f"'{self.obj}'", 153 | in_return, idx) 154 | else: 155 | codegen.output(name_in_graph_fn, store_pos, str(self.obj), 156 | in_return, idx) 157 | else: 158 | name_in_graph_output = codegen.add_graph_output(self.fx_node) 159 | codegen.output( 160 | name_in_graph_fn, store_pos, 161 | f'{name_in_graph_output}.item() if isinstance({name_in_graph_output}, torch.Tensor) else {name_in_graph_output}', 162 | in_return, idx) 163 | 164 | @classmethod 165 | def from_value(cls, value: np.generic, need_guard_check: bool, 166 | _helper_functions: HelperFunctions, 167 | fx_graph: Optional[FxGraph], 168 | extract_code_at_start: list[StorePos]) -> "NumpyScalarVar": 169 | if id(value) not in dyn.dynamic_vars: 170 | return cls(value, True, need_guard_check, None, 171 | extract_code_at_start) 172 | else: 173 | assert fx_graph is not None 174 | assert len(extract_code_at_start) > 0 175 | name = new_name('np_scalar') 176 | fx_node = 
fx_graph.create_input(torch.tensor(value), name, (), {}, 177 | name) 178 | var = cls.from_value_and_node(value, fx_node, need_guard_check, 179 | extract_code_at_start) 180 | return var 181 | 182 | @classmethod 183 | def from_value_and_node( 184 | cls, value: np.generic, fx_node: torch.fx.Node, 185 | need_guard_check: bool, 186 | extract_code_at_start: list[StorePos]) -> 'NumpyScalarVar': 187 | var = cls(value, False, need_guard_check, fx_node, 188 | extract_code_at_start) 189 | fx_node.meta["var"] = var 190 | return var 191 | 192 | def as_fx_node(self) -> NodeArgs: 193 | if self.value_fix: 194 | return self.obj 195 | else: 196 | return self.fx_node 197 | -------------------------------------------------------------------------------- /frontend/variables/set_.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional, Callable, Any 2 | from .base import Variable, HelperFunctions 3 | from ..fx_graph import NodeArgs, FxGraph 4 | from ..store_pos import StorePos, StoreInIndex 5 | import torch 6 | if TYPE_CHECKING: 7 | from ..pycode_generator import GraphFnCodegen, GuardFnCodegen 8 | from ..object_table import ObjectTable 9 | 10 | 11 | class SetVar(Variable): 12 | vars: list[Variable] 13 | obj_ids: list[int] 14 | length: int 15 | 16 | def __init__(self, value: set[Any], need_guard_check: bool, 17 | helper_functions: HelperFunctions, fx_graph: Optional[FxGraph], 18 | extract_code_at_start: list[StorePos]) -> None: 19 | super().__init__(need_guard_check, value, extract_code_at_start) 20 | self.value = value 21 | self.length = len(value) 22 | self.vars = [] 23 | self.obj_ids = [] 24 | for i, obj in enumerate(value): 25 | new_extract: list[StorePos] = [ 26 | StoreInIndex(pos, id(obj), i, False) 27 | for pos in self.extract_code_at_start 28 | ] 29 | var = helper_functions.get_or_make_var(obj, need_guard_check, 30 | fx_graph, new_extract) 31 | self.vars.append(var) 32 | self.obj_ids.append(id(obj)) 33 | 
34 | def make_guard_inner(self, codegen: "GuardFnCodegen", 35 | pos: StorePos) -> None: 36 | codegen.add_check((f'isinstance({pos}, set)', pos)) 37 | codegen.add_check((f"len({pos}) == {self.length}", pos)) 38 | for i, (var, obj) in enumerate(zip(self.vars, self.obj_ids)): 39 | var.make_guard_inner(codegen, StoreInIndex(pos, obj, i, False)) 40 | 41 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 42 | codegen: "GraphFnCodegen", in_return: bool, 43 | idx: int) -> None: 44 | for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)): 45 | var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen, 46 | False, idx_j) 47 | 48 | codegen.output( 49 | name_in_graph_fn, store_pos, 50 | f"{{{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},}}" 51 | if len(self.vars) > 0 else "set()", in_return, idx) 52 | 53 | @classmethod 54 | def from_value(cls, value: set[Any], need_guard_check: bool, 55 | helper_functions: HelperFunctions, 56 | fx_graph: Optional[FxGraph], 57 | extract_code_at_start: list[StorePos]) -> "SetVar": 58 | return cls(value, need_guard_check, helper_functions, fx_graph, 59 | extract_code_at_start) 60 | 61 | def as_fx_node(self) -> NodeArgs: 62 | return self.value 63 | 64 | def add_subvars_to_table(self, table: 'ObjectTable') -> None: 65 | for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)): 66 | old_var = table.get_or_none_by_id(idx) 67 | if old_var is not None: 68 | new_extract: list[StorePos] = [ 69 | StoreInIndex(pos, idx, i, False) 70 | for pos in self.extract_code_at_start 71 | ] 72 | old_var.extract_code_at_start.extend(new_extract) 73 | old_var.need_guard_check |= self.need_guard_check 74 | else: 75 | table.add_by_id(var, idx) 76 | var.add_subvars_to_table(table) 77 | 78 | 79 | class FrozensetVar(Variable): 80 | vars: list[Variable] 81 | obj_ids: list[int] 82 | length: int 83 | 84 | def __init__(self, value: frozenset[Any], need_guard_check: bool, 85 | helper_functions: HelperFunctions, 
fx_graph: Optional[FxGraph], 86 | extract_code_at_start: list[StorePos]) -> None: 87 | super().__init__(need_guard_check, value, extract_code_at_start) 88 | self.value = value 89 | self.length = len(value) 90 | self.vars = [] 91 | self.obj_ids = [] 92 | for i, obj in enumerate(value): 93 | new_extract: list[StorePos] = [ 94 | StoreInIndex(pos, id(obj), i, False) 95 | for pos in self.extract_code_at_start 96 | ] 97 | var = helper_functions.get_or_make_var(obj, need_guard_check, 98 | fx_graph, new_extract) 99 | self.vars.append(var) 100 | self.obj_ids.append(id(obj)) 101 | 102 | def make_guard_inner(self, codegen: "GuardFnCodegen", 103 | pos: StorePos) -> None: 104 | codegen.add_check((f'isinstance({pos}, frozenset)', pos)) 105 | codegen.add_check((f"len({pos}) == {self.length}", pos)) 106 | for i, (var, obj) in enumerate(zip(self.vars, self.obj_ids)): 107 | var.make_guard_inner(codegen, StoreInIndex(pos, obj, i, False)) 108 | 109 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 110 | codegen: "GraphFnCodegen", in_return: bool, 111 | idx: int) -> None: 112 | for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)): 113 | var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen, 114 | False, idx_j) 115 | 116 | codegen.output( 117 | name_in_graph_fn, store_pos, 118 | f"{{{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},}}" 119 | if len(self.vars) > 0 else "frozenset()", in_return, idx) 120 | 121 | @classmethod 122 | def from_value(cls, value: frozenset[Any], need_guard_check: bool, 123 | helper_functions: HelperFunctions, 124 | fx_graph: Optional[FxGraph], 125 | extract_code_at_start: list[StorePos]) -> "FrozensetVar": 126 | return cls(value, need_guard_check, helper_functions, fx_graph, 127 | extract_code_at_start) 128 | 129 | def as_fx_node(self) -> NodeArgs: 130 | return self.value 131 | 132 | def add_subvars_to_table(self, table: 'ObjectTable') -> None: 133 | for i, (var, idx) in enumerate(zip(self.vars, 
self.obj_ids)): 134 | old_var = table.get_or_none_by_id(idx) 135 | if old_var is not None: 136 | new_extract: list[StorePos] = [ 137 | StoreInIndex(pos, idx, i, False) 138 | for pos in self.extract_code_at_start 139 | ] 140 | old_var.extract_code_at_start.extend(new_extract) 141 | old_var.need_guard_check |= self.need_guard_check 142 | else: 143 | table.add_by_id(var, idx) 144 | var.add_subvars_to_table(table) -------------------------------------------------------------------------------- /frontend/variables/torch_module.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Union, Optional, Callable, Any 2 | import torch 3 | import torch.fx 4 | from frontend.pycode_generator import GraphFnCodegen, GuardFnCodegen 5 | from .base import Variable, HelperFunctions 6 | from ..utils import ScalarType 7 | from ..fx_graph import FxGraph, NodeArgs 8 | from ..store_pos import StorePos, StoreInIndex 9 | if TYPE_CHECKING: 10 | from ..pycode_generator import GraphFnCodegen, GuardFnCodegen 11 | from ..object_table import ObjectTable 12 | 13 | 14 | class TorchModuleVar(Variable): 15 | 16 | def __init__(self, value: torch.nn.Module, need_guard_check: bool, 17 | extract_code_at_start: list[StorePos]) -> None: 18 | super().__init__(need_guard_check, value, extract_code_at_start) 19 | 20 | @classmethod 21 | def from_value( 22 | cls, 23 | value: torch.nn.Module, 24 | need_guard_check: bool, 25 | helper_functions: HelperFunctions, 26 | _fx_graph: Optional[FxGraph], 27 | extract_code_at_start: list[StorePos], 28 | ) -> "TorchModuleVar": 29 | if isinstance(value, torch.nn.Sequential): 30 | return TorchSequentialVar(value, need_guard_check, helper_functions, 31 | extract_code_at_start) 32 | elif isinstance(value, torch.nn.ModuleList): 33 | return TorchModuleListVar(value, need_guard_check, helper_functions, 34 | extract_code_at_start) 35 | else: 36 | return cls(value, need_guard_check, extract_code_at_start) 37 | 38 | 
def make_guard_inner(self, codegen: GuardFnCodegen, pos: StorePos) -> None: 39 | codegen.add_id_check((f"id({pos}) == {id(self.obj)}", pos), self.obj) 40 | 41 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 42 | codegen: "GraphFnCodegen", in_return: bool, 43 | idx: int) -> None: 44 | extract_pos = self.fetch_extract_code_at_start() 45 | assert len(extract_pos) > 0 46 | extract_pos[0].add_name_to_fn(codegen) 47 | codegen.output(name_in_graph_fn, store_pos, str(extract_pos[0]), 48 | in_return, idx) 49 | 50 | def as_fx_node(self) -> NodeArgs: 51 | return self.obj 52 | raise ValueError("Cannot convert a module to a node") 53 | 54 | 55 | class TorchSequentialVar(TorchModuleVar): 56 | submodules: list[TorchModuleVar] 57 | submodule_ids: list[int] 58 | 59 | def __init__(self, value: torch.nn.Sequential, need_guard_check: bool, 60 | helper_functions: HelperFunctions, 61 | extract_code_at_start: list[StorePos]) -> None: 62 | super().__init__(value, need_guard_check, extract_code_at_start) 63 | self.submodules = [] 64 | self.submodule_ids = [] 65 | for i, m in enumerate(value): 66 | new_extract: list[StorePos] = [ 67 | StoreInIndex(pos, id(m), i) 68 | for pos in self.extract_code_at_start 69 | ] 70 | var = helper_functions.get_or_make_var(m, need_guard_check, None, 71 | new_extract) 72 | assert isinstance(var, TorchModuleVar) 73 | self.submodules.append(var) 74 | self.submodule_ids.append(id(m)) 75 | 76 | def add_subvars_to_table(self, table: 'ObjectTable') -> None: 77 | for i, (var, idx) in enumerate(zip(self.submodules, 78 | self.submodule_ids)): 79 | old_var = table.get_or_none_by_id(idx) 80 | if old_var is not None: 81 | new_extract: list[StorePos] = [ 82 | StoreInIndex(pos, idx, i) 83 | for pos in self.extract_code_at_start 84 | ] 85 | old_var.extract_code_at_start.extend(new_extract) 86 | old_var.need_guard_check |= self.need_guard_check 87 | else: 88 | table.add_by_id(var, idx) 89 | var.add_subvars_to_table(table) 90 | 91 | 92 | class 
TorchModuleListVar(TorchModuleVar): 93 | submodules: list[TorchModuleVar] 94 | submodule_ids: list[int] 95 | 96 | def __init__(self, value: torch.nn.ModuleList, need_guard_check: bool, 97 | helper_functions: HelperFunctions, 98 | extract_code_at_start: list[StorePos]) -> None: 99 | super().__init__(value, need_guard_check, extract_code_at_start) 100 | self.submodules = [] 101 | self.submodule_ids = [] 102 | for i, m in enumerate(value): 103 | new_extract: list[StorePos] = [ 104 | StoreInIndex(pos, id(m), i) 105 | for pos in self.extract_code_at_start 106 | ] 107 | var = helper_functions.get_or_make_var(m, need_guard_check, None, 108 | new_extract) 109 | assert isinstance(var, TorchModuleVar) 110 | self.submodules.append(var) 111 | self.submodule_ids.append(id(m)) 112 | 113 | def add_subvars_to_table(self, table: 'ObjectTable') -> None: 114 | for i, (var, idx) in enumerate(zip(self.submodules, 115 | self.submodule_ids)): 116 | old_var = table.get_or_none_by_id(idx) 117 | if old_var is not None: 118 | new_extract: list[StorePos] = [ 119 | StoreInIndex(pos, idx, i) 120 | for pos in self.extract_code_at_start 121 | ] 122 | old_var.extract_code_at_start.extend(new_extract) 123 | old_var.need_guard_check |= self.need_guard_check 124 | else: 125 | table.add_by_id(var, idx) 126 | var.add_subvars_to_table(table) -------------------------------------------------------------------------------- /frontend/variables/tuple_.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional, Tuple, Any, Callable 2 | from .base import Variable, HelperFunctions 3 | from ..fx_graph import NodeArgs, FxGraph 4 | from ..store_pos import StorePos, StoreInIndex 5 | import torch 6 | if TYPE_CHECKING: 7 | from ..pycode_generator import GraphFnCodegen, GuardFnCodegen 8 | from ..object_table import ObjectTable 9 | 10 | 11 | class TupleVar(Variable): 12 | vars: list[Variable] 13 | obj_ids: list[int] 14 | length: int 15 | 16 | def 
__init__(self, value: tuple[Any, ...], need_guard_check: bool, 17 | helper_functions: HelperFunctions, fx_graph: Optional[FxGraph], 18 | extract_code_at_start: list[StorePos]) -> None: 19 | super().__init__(need_guard_check, value, extract_code_at_start) 20 | self.value = value 21 | self.length = len(value) 22 | self.vars = [] 23 | self.obj_ids = [] 24 | for i, obj in enumerate(value): 25 | new_extract: list[StorePos] = [ 26 | StoreInIndex(pos, id(obj), i) 27 | for pos in self.extract_code_at_start 28 | ] 29 | var = helper_functions.get_or_make_var(obj, need_guard_check, 30 | fx_graph, new_extract) 31 | self.vars.append(var) 32 | self.obj_ids.append(id(obj)) 33 | 34 | def make_guard_inner(self, codegen: "GuardFnCodegen", 35 | pos: StorePos) -> None: 36 | codegen.add_check((f"isinstance({pos}, tuple)", pos)) 37 | codegen.add_check((f"len({pos}) == {self.length}", pos)) 38 | for i, (var, obj) in enumerate(zip(self.vars, self.obj_ids)): 39 | var.make_guard_inner(codegen, StoreInIndex(pos, obj, i)) 40 | 41 | def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, 42 | codegen: "GraphFnCodegen", in_return: bool, 43 | idx: int) -> None: 44 | for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)): 45 | var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen, 46 | False, idx_j) 47 | 48 | codegen.output( 49 | name_in_graph_fn, store_pos, 50 | f"({','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},)" 51 | if len(self.vars) > 0 else "()", in_return, idx) 52 | 53 | @classmethod 54 | def from_value(cls, value: Tuple[Any, ...], need_guard_check: bool, 55 | helper_functions: HelperFunctions, 56 | fx_graph: Optional[FxGraph], 57 | extract_code_at_start: list[StorePos]) -> "TupleVar": 58 | return cls(value, need_guard_check, helper_functions, fx_graph, 59 | extract_code_at_start) 60 | 61 | def as_fx_node(self) -> NodeArgs: 62 | return self.value 63 | 64 | def add_subvars_to_table(self, table: 'ObjectTable') -> None: 65 | for i, (var, 
idx) in enumerate(zip(self.vars, self.obj_ids)): 66 | old_var = table.get_or_none_by_id(idx) 67 | if old_var is not None: 68 | new_extract: list[StorePos] = [ 69 | StoreInIndex(pos, idx, i) 70 | for pos in self.extract_code_at_start 71 | ] 72 | old_var.extract_code_at_start.extend(new_extract) 73 | old_var.need_guard_check |= self.need_guard_check 74 | else: 75 | table.add_by_id(var, idx) 76 | var.add_subvars_to_table(table) 77 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = test/common 3 | markers = model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1+cu118 2 | torchvision==0.15.2+cu118 3 | pytest==7.4.0 -------------------------------------------------------------------------------- /scripts/compile_longobj.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | source /opt/spack/share/spack/setup-env.sh 4 | spack unload 5 | spack load gcc@11.3.0 /ohhhwxj 6 | spack load python@3.9.12%gcc@=11.3.0 7 | 8 | set -x 9 | BUILD_DIR="${BUILD_DIR:-`pwd`/../build}" 10 | CPYTHON_DIR=${BUILD_DIR}/cpython 11 | if [ ! -d $BUILD_DIR ]; then 12 | mkdir $BUILD_DIR 13 | fi 14 | if [ ! -d $CPYTHON_DIR ]; then 15 | git clone git@github.com:python/cpython.git $CPYTHON_DIR 16 | fi 17 | TAG=v3.9.12 18 | if [ ! 
-f $BUILD_DIR/ldlong.${TAG}.so ]; then 19 | SCRIPT_DIR=`pwd` 20 | pushd $CPYTHON_DIR 21 | git checkout $TAG 22 | git checkout -- ${CPYTHON_DIR}/Objects/longobject.c 23 | ./configure '--without-pydebug' '--enable-shared' '--without-ensurepip' '--with-openssl=/usr' '--with-dbmliborder=gdbm' '--with-system-expat' '--with-system-ffi' '--enable-loadable-sqlite-extensions' 'CFLAGS=-fPIC' 24 | make clean 25 | make -j 26 | git apply ${SCRIPT_DIR}/longobject.${TAG}.patch 27 | gcc -c -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -std=c99 -Wextra -Wno-unused-result -Wno-unused-parameter -Wno-missing-field-initializers -Werror=implicit-function-declaration -fvisibility=hidden -I${CPYTHON_DIR}/Include/internal -I${CPYTHON_DIR} -I${CPYTHON_DIR}/Include -DPy_BUILD_CORE -o ${BUILD_DIR}/ldlong.o ${CPYTHON_DIR}/Objects/longobject.c 28 | gcc -shared ${BUILD_DIR}/ldlong.o -o ${BUILD_DIR}/ldlong.${TAG}.so -ldl 29 | git checkout -- ${CPYTHON_DIR}/Objects/longobject.c 30 | popd 31 | fi 32 | -------------------------------------------------------------------------------- /scripts/longobject.v3.9.12.patch: -------------------------------------------------------------------------------- 1 | diff --git a/Objects/longobject.c b/Objects/longobject.c 2 | index cf13b2c430..396ad09ec6 100644 3 | --- a/Objects/longobject.c 4 | +++ b/Objects/longobject.c 5 | @@ -17,8 +17,8 @@ class int "PyObject *" "&PyLong_Type" 6 | [clinic start generated code]*/ 7 | /*[clinic end generated code: output=da39a3ee5e6b4b0d input=ec0275e3422a36e3]*/ 8 | 9 | -#define NSMALLPOSINTS _PY_NSMALLPOSINTS 10 | -#define NSMALLNEGINTS _PY_NSMALLNEGINTS 11 | +#define NSMALLPOSINTS 0 12 | +#define NSMALLNEGINTS 0 13 | 14 | _Py_IDENTIFIER(little); 15 | _Py_IDENTIFIER(big); 16 | @@ -29,9 +29,6 @@ _Py_IDENTIFIER(big); 17 | (Py_SIZE(x) == 0 ? 
(sdigit)0 : \ 18 | (sdigit)(x)->ob_digit[0])) 19 | 20 | -PyObject *_PyLong_Zero = NULL; 21 | -PyObject *_PyLong_One = NULL; 22 | - 23 | #if NSMALLNEGINTS + NSMALLPOSINTS > 0 24 | #define IS_SMALL_INT(ival) (-NSMALLNEGINTS <= (ival) && (ival) < NSMALLPOSINTS) 25 | #define IS_SMALL_UINT(ival) ((ival) < NSMALLPOSINTS) 26 | @@ -5712,99 +5709,6 @@ PyTypeObject PyLong_Type = { 27 | PyObject_Del, /* tp_free */ 28 | }; 29 | 30 | -static PyTypeObject Int_InfoType; 31 | - 32 | -PyDoc_STRVAR(int_info__doc__, 33 | -"sys.int_info\n\ 34 | -\n\ 35 | -A named tuple that holds information about Python's\n\ 36 | -internal representation of integers. The attributes are read only."); 37 | - 38 | -static PyStructSequence_Field int_info_fields[] = { 39 | - {"bits_per_digit", "size of a digit in bits"}, 40 | - {"sizeof_digit", "size in bytes of the C type used to represent a digit"}, 41 | - {NULL, NULL} 42 | -}; 43 | - 44 | -static PyStructSequence_Desc int_info_desc = { 45 | - "sys.int_info", /* name */ 46 | - int_info__doc__, /* doc */ 47 | - int_info_fields, /* fields */ 48 | - 2 /* number of fields */ 49 | -}; 50 | - 51 | -PyObject * 52 | -PyLong_GetInfo(void) 53 | -{ 54 | - PyObject* int_info; 55 | - int field = 0; 56 | - int_info = PyStructSequence_New(&Int_InfoType); 57 | - if (int_info == NULL) 58 | - return NULL; 59 | - PyStructSequence_SET_ITEM(int_info, field++, 60 | - PyLong_FromLong(PyLong_SHIFT)); 61 | - PyStructSequence_SET_ITEM(int_info, field++, 62 | - PyLong_FromLong(sizeof(digit))); 63 | - if (PyErr_Occurred()) { 64 | - Py_CLEAR(int_info); 65 | - return NULL; 66 | - } 67 | - return int_info; 68 | -} 69 | - 70 | -int 71 | -_PyLong_Init(PyThreadState *tstate) 72 | -{ 73 | -#if NSMALLNEGINTS + NSMALLPOSINTS > 0 74 | - for (Py_ssize_t i=0; i < NSMALLNEGINTS + NSMALLPOSINTS; i++) { 75 | - sdigit ival = (sdigit)i - NSMALLNEGINTS; 76 | - int size = (ival < 0) ? -1 : ((ival == 0) ? 
0 : 1); 77 | - 78 | - PyLongObject *v = _PyLong_New(1); 79 | - if (!v) { 80 | - return -1; 81 | - } 82 | - 83 | - Py_SET_SIZE(v, size); 84 | - v->ob_digit[0] = (digit)abs(ival); 85 | - 86 | - tstate->interp->small_ints[i] = v; 87 | - } 88 | -#endif 89 | - 90 | - if (_Py_IsMainInterpreter(tstate)) { 91 | - _PyLong_Zero = PyLong_FromLong(0); 92 | - if (_PyLong_Zero == NULL) { 93 | - return 0; 94 | - } 95 | - 96 | - _PyLong_One = PyLong_FromLong(1); 97 | - if (_PyLong_One == NULL) { 98 | - return 0; 99 | - } 100 | - 101 | - /* initialize int_info */ 102 | - if (Int_InfoType.tp_name == NULL) { 103 | - if (PyStructSequence_InitType2(&Int_InfoType, &int_info_desc) < 0) { 104 | - return 0; 105 | - } 106 | - } 107 | - } 108 | - 109 | - return 1; 110 | -} 111 | - 112 | -void 113 | -_PyLong_Fini(PyThreadState *tstate) 114 | -{ 115 | - if (_Py_IsMainInterpreter(tstate)) { 116 | - Py_CLEAR(_PyLong_One); 117 | - Py_CLEAR(_PyLong_Zero); 118 | - } 119 | - 120 | -#if NSMALLNEGINTS + NSMALLPOSINTS > 0 121 | - for (Py_ssize_t i = 0; i < NSMALLNEGINTS + NSMALLPOSINTS; i++) { 122 | - Py_CLEAR(tstate->interp->small_ints[i]); 123 | - } 124 | -#endif 125 | -} 126 | +PyTypeObject get_PyLong_Type() { 127 | + return PyLong_Type; 128 | +} 129 | \ No newline at end of file 130 | -------------------------------------------------------------------------------- /scripts/pytest_with_preload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LD_PRELOAD=~/frontend/ldlong.v3.9.12.so pytest $@ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='frontend', 5 | version='0.0.0', 6 | packages=setuptools.find_packages('.', exclude=['test']), 7 | include_dirs=['frontend'], 8 | ext_modules=[ 9 | setuptools.Extension('frontend.c_api', [ 10 | 'frontend/csrc/frame_evaluation.cpp', 
'frontend/csrc/opcode.cpp', 11 | 'frontend/csrc/parse_types.cpp' 12 | ], 13 | language='c++', 14 | define_macros=[('LOG_CACHE', 'None')]) 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /test/common/checker.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import logging 3 | from collections import Iterable 4 | import torch 5 | from frontend import cache 6 | 7 | HIT = 1 8 | MISS = 2 9 | ALL_MISS = 3 10 | 11 | 12 | def assert_equal(ref, out): 13 | precision = 1e-3 14 | assert type(ref) == type( 15 | out), f"wrong type: expect {type(ref)}, got {type(out)}" 16 | if isinstance(ref, torch.Tensor): 17 | assert (isinstance(out, torch.Tensor)) 18 | r = ref.cpu() 19 | o = out.cpu() 20 | if r.dtype == torch.bool and o.dtype == torch.int8: 21 | o = o.bool() 22 | all_close = torch.allclose(r, o, atol=precision, rtol=precision) 23 | if not all_close: 24 | close = torch.isclose(r, o, rtol=precision, atol=precision) 25 | print("ref:", torch.masked_select(r, ~close)) 26 | print("out:", torch.masked_select(o, ~close)) 27 | print(torch.sum(~close)) 28 | print("wrong answer !!!!!!!!!!!!!!!!!!!!!!!!!!") 29 | assert (False) 30 | elif isinstance(ref, float): 31 | assert torch.isclose(torch.tensor(ref), 32 | torch.tensor(out), 33 | atol=precision, 34 | rtol=precision) 35 | elif isinstance(ref, Iterable): 36 | assert (isinstance(out, Iterable)) 37 | if isinstance(ref, set): 38 | assert (len(ref) == len(out)) 39 | elif isinstance(ref, dict): 40 | assert (len(ref) == len(out)) 41 | for k, v in ref.items(): 42 | assert_equal(v, out[k]) 43 | else: 44 | for r, o in zip(ref, out): 45 | assert_equal(r, o) 46 | else: 47 | assert ref == out, f"wrong answer: expect {ref}, got {out}" 48 | 49 | 50 | def check_cache_log(caplog, expect_cache_logs, expect_cache_size: int): 51 | recorded_cache_logs = [] 52 | for record in caplog.records: 53 | if record.message.startswith("\033[31mguard 
cache"): 54 | if "hit" in record.message: 55 | recorded_cache_logs.append(HIT) 56 | elif "miss" in record.message: 57 | recorded_cache_logs.append(MISS) 58 | else: 59 | assert (False), "unknown cache log" 60 | if len(expect_cache_logs) == 1 and expect_cache_logs[0] == ALL_MISS: 61 | expect_cache_logs = [MISS for _ in range(len(recorded_cache_logs))] 62 | assert len(recorded_cache_logs) == len( 63 | expect_cache_logs 64 | ), f"wrong cache log: expect {expect_cache_logs}, got {recorded_cache_logs}" 65 | for recorded, expected in zip(recorded_cache_logs, expect_cache_logs): 66 | assert recorded == expected, f"wrong cache log: expect {expect_cache_logs}, got {recorded_cache_logs}" 67 | assert cache.TOTAL_SIZE == expect_cache_size, f"wrong cache size: expect {expect_cache_size}, got {cache.TOTAL_SIZE}" 68 | 69 | 70 | def should_not_call(*args, **kwargs): 71 | raise ValueError("should not rewrite bytecode") 72 | 73 | 74 | class DisableRewriteByteCode: 75 | old_should_call: bool 76 | 77 | def __enter__(self): 78 | from frontend import bytecode_writter 79 | self.old_should_call = bytecode_writter.SHOULD_NOT_CALL_REWRITE 80 | bytecode_writter.SHOULD_NOT_CALL_REWRITE = True 81 | 82 | def __exit__(self, exc_type, exc_value, traceback): 83 | from frontend import bytecode_writter 84 | bytecode_writter.SHOULD_NOT_CALL_REWRITE = self.old_should_call 85 | 86 | 87 | def run_and_check(compiled, expect_cache_logs, expect_cache_size: int, caplog, 88 | expected_result, *args, **kwargs): 89 | caplog.set_level(logging.INFO) 90 | caplog.clear() 91 | with torch.no_grad(): 92 | if all([x == HIT for x in expect_cache_logs]): 93 | with DisableRewriteByteCode(): 94 | out = compiled(*args, **kwargs) 95 | else: 96 | out = compiled(*args, **kwargs) 97 | assert_equal(expected_result, out) 98 | check_cache_log(caplog, expect_cache_logs, expect_cache_size) 99 | 100 | 101 | def run_and_check_cache(compiled, expect_cache_logs, expect_cache_size: int, 102 | caplog, *args, **kwargs): # do not perform 
result check 103 | caplog.set_level(logging.INFO) 104 | caplog.clear() 105 | with torch.no_grad(): 106 | if all([x == HIT for x in expect_cache_logs]): 107 | with DisableRewriteByteCode(): 108 | _ = compiled(*args, **kwargs) 109 | else: 110 | _ = compiled(*args, **kwargs) 111 | check_cache_log(caplog, expect_cache_logs, expect_cache_size) 112 | -------------------------------------------------------------------------------- /test/common/plugin_disable_preload.py: -------------------------------------------------------------------------------- 1 | from frontend.no_preload import NO_LD_PRELOAD_CTX 2 | 3 | no_ld_preload = NO_LD_PRELOAD_CTX() 4 | 5 | 6 | # content of plugins/example_plugin.py 7 | def pytest_configure(config): 8 | no_ld_preload.__enter__() 9 | 10 | 11 | def pytest_unconfigure(config): 12 | no_ld_preload.__exit__() -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = [ 2 | 'common.plugin_disable_preload', 3 | ] -------------------------------------------------------------------------------- /test/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from frontend.compile import compile 3 | from frontend.utils import SetConfig 4 | 5 | 6 | class Example(torch.nn.Module): 7 | 8 | def __init__(self): 9 | super(Example, self).__init__() 10 | self.conv = torch.nn.Conv2d(3, 3, 3) 11 | self.relu = torch.nn.ReLU() 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | x = self.relu(x) 16 | return x 17 | 18 | 19 | with torch.no_grad(): 20 | model = Example().eval() 21 | x = torch.randn(1, 3, 4, 4) 22 | expect_output = model(x) 23 | print("expect:", expect_output) 24 | 25 | # set the graph compiler to inductor 26 | with SetConfig({'backend': 'inductor'}): 27 | compiled = compile(model) 28 | # run the python code to compile the model. 
The fx graph and the guards will be printed out 29 | output1 = compiled(x) 30 | print("output1:", output1) 31 | 32 | # run the compiled model. "guard cache hit" means we find the compiled record and use it directly 33 | output2 = compiled(x) 34 | print("output2", output2) 35 | assert torch.allclose(expect_output, output1) 36 | assert torch.allclose(expect_output, output2) 37 | -------------------------------------------------------------------------------- /test/test_builtins.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile import compile, reset 3 | from frontend.utils import add_force_graph_break 4 | from frontend.c_api import get_next_frame_id 5 | import logging 6 | from common.checker import run_and_check, HIT, MISS 7 | 8 | 9 | def run_enumerate(x): 10 | s = 0 11 | for i, v in enumerate(x): 12 | s += i * v 13 | return s, enumerate(x) 14 | 15 | 16 | def run_enumerate2(x): 17 | s = 0 18 | for i, v in enumerate(x, 2): 19 | s += i * v 20 | return s 21 | 22 | 23 | def test_enumerate(caplog): 24 | reset() 25 | compiled_run_enumerate = compile(run_enumerate) 26 | expect_result = run_enumerate([1, 2, 3, 4, 5]) 27 | run_and_check(compiled_run_enumerate, [MISS], 1, caplog, expect_result, 28 | [1, 2, 3, 4, 5]) 29 | expect_result = run_enumerate([1, 2, 3, 4, 5]) 30 | run_and_check(compiled_run_enumerate, [HIT], 1, caplog, expect_result, 31 | [1, 2, 3, 4, 5]) 32 | compiled_run_enumerate2 = compile(run_enumerate2) 33 | expect_result2 = run_enumerate2([1, 2, 3, 4, 5]) 34 | run_and_check(compiled_run_enumerate2, [MISS], 2, caplog, expect_result2, 35 | [1, 2, 3, 4, 5]) 36 | run_and_check(compiled_run_enumerate2, [HIT], 2, caplog, expect_result2, 37 | [1, 2, 3, 4, 5]) 38 | -------------------------------------------------------------------------------- /test/test_call_function_ex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
common.checker import run_and_check, HIT, MISS, ALL_MISS 3 | from frontend.compile import compile, reset 4 | 5 | 6 | def view_operation(a): 7 | shape = (2, 3) 8 | b = a.view(*shape) 9 | return b 10 | 11 | 12 | def reshape_operation(a): 13 | shape = (2, 3) 14 | b = a.reshape(*shape) 15 | return b 16 | 17 | 18 | def test_call_function_ex(caplog): 19 | reset() 20 | compiled2 = compile(view_operation) 21 | compiled3 = compile(reshape_operation) 22 | tensor = torch.tensor([1, 2, 3, 4, 5, 6]) 23 | result = view_operation(tensor) 24 | run_and_check(compiled2, [MISS], 1, caplog, result, tensor) 25 | run_and_check(compiled2, [HIT], 1, caplog, result, tensor) 26 | result = reshape_operation(tensor) 27 | run_and_check(compiled3, [MISS], 2, caplog, result, tensor) 28 | run_and_check(compiled3, [HIT], 2, caplog, result, tensor) 29 | 30 | 31 | class closure_call(torch.nn.Module): 32 | 33 | def __init__(self): 34 | super().__init__() 35 | self.n_heads = 2 36 | self.key_value_proj_dim = 6 37 | 38 | def forward(self, x): 39 | 40 | def shape(states): 41 | return states.view(self.n_heads, self.key_value_proj_dim) 42 | 43 | def project(hidden_states): 44 | hidden_states = shape(hidden_states) 45 | return hidden_states 46 | 47 | key_states = project(x) 48 | return key_states 49 | 50 | 51 | def test_closure_call(caplog): 52 | reset() 53 | with torch.no_grad(): 54 | model = closure_call().eval() 55 | a = torch.arange(12).reshape(3, 4) 56 | compiled = compile(model) 57 | expect_result = model(a) 58 | run_and_check(compiled, [ALL_MISS], 1, caplog, expect_result, a) 59 | run_and_check(compiled, [HIT], 1, caplog, expect_result, a) 60 | 61 | 62 | def inner_call_ex(a, b, **kwargs): 63 | return torch.add(a, b, **kwargs) 64 | 65 | 66 | def outer_call_ex(a, b): 67 | return inner_call_ex(a, b, alpha=1.0) 68 | 69 | 70 | def test_call_ex(caplog): 71 | reset() 72 | with torch.no_grad(): 73 | a = torch.rand((2, 2)) 74 | b = torch.rand((2, 2)) 75 | expect = outer_call_ex(a, b) 76 | compiled = 
compile(outer_call_ex) 77 | run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b) 78 | run_and_check(compiled, [HIT], 1, caplog, expect, a, b) 79 | 80 | 81 | def inner_call_ex_with_update(a, b, **kwargs): 82 | kwargs.update(alpha=1.0) 83 | return torch.add(a, b, **kwargs) 84 | 85 | 86 | def outer_call_ex_with_update(a, b): 87 | return inner_call_ex_with_update(a, b, alpha=2.0) 88 | 89 | 90 | def test_call_ex_with_update(caplog): 91 | reset() 92 | with torch.no_grad(): 93 | a = torch.rand((2, 2)) 94 | b = torch.rand((2, 2)) 95 | expect = outer_call_ex_with_update(a, b) 96 | compiled = compile(outer_call_ex_with_update) 97 | run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b) 98 | run_and_check(compiled, [HIT], 1, caplog, expect, a, b) 99 | 100 | 101 | def callee_kw(a, b): 102 | return a[0] + b 103 | 104 | 105 | def caller_kw(a, b): 106 | return callee_kw((a, 2), b=b) 107 | 108 | 109 | def test_caller_kw(caplog): 110 | reset() 111 | with torch.no_grad(): 112 | a = 1 113 | b = 3 114 | expect = caller_kw(a, b) 115 | compiled = compile(caller_kw) 116 | run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b) 117 | run_and_check(compiled, [HIT], 1, caplog, expect, a, b) 118 | -------------------------------------------------------------------------------- /test/test_cuda.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from frontend.utils import add_force_graph_break 3 | from frontend.c_api import get_next_frame_id 4 | from common.checker import run_and_check, HIT, MISS 5 | import torch 6 | 7 | 8 | def simple_add(a): 9 | return a + 1 10 | 11 | 12 | def test_simple_add(caplog): 13 | reset() 14 | a = torch.full((1, 1), 1.0).cuda() 15 | expected = simple_add(a) 16 | compiled = compile(simple_add) 17 | run_and_check(compiled, [MISS], 1, caplog, expected, a) 18 | run_and_check(compiled, [HIT], 1, caplog, expected, a) 19 | 
-------------------------------------------------------------------------------- /test/test_dict.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from common.checker import run_and_check, HIT, MISS, assert_equal 3 | from collections import OrderedDict 4 | import torch 5 | 6 | 7 | def without_tensor_0(a): 8 | return a 9 | 10 | 11 | def func(a, *args, b, **kwargs): 12 | return a + args[0] + b + kwargs['c'] 13 | 14 | 15 | def call_function_ex(a, *args, b, **kwargs): 16 | return func(a, *args, b=b, **kwargs) 17 | 18 | 19 | def test_without_tensor(caplog): 20 | reset() 21 | compiled_no_tensor0 = compile(without_tensor_0) 22 | compiled_no_tensor1 = compile(call_function_ex) 23 | a = {2.8: 1.1, 2.1: 2, 2.2: 3.3} 24 | result = without_tensor_0(a) 25 | run_and_check(compiled_no_tensor0, [MISS], 1, caplog, result, a) 26 | run_and_check(compiled_no_tensor0, [HIT], 1, caplog, result, a) 27 | a = {} 28 | result = without_tensor_0(a) 29 | run_and_check(compiled_no_tensor0, [MISS], 2, caplog, result, a) 30 | run_and_check(compiled_no_tensor0, [HIT], 2, caplog, result, a) 31 | result = call_function_ex(1.0, 2.0, b=3.0, c=4.0) 32 | run_and_check(compiled_no_tensor1, [MISS, MISS], 33 | 3, 34 | caplog, 35 | result, 36 | 1.0, 37 | 2.0, 38 | b=3.0, 39 | c=4.0) 40 | run_and_check(compiled_no_tensor1, [HIT], 41 | 3, 42 | caplog, 43 | result, 44 | 1.0, 45 | 2.0, 46 | b=3.0, 47 | c=4.0) 48 | 49 | 50 | def tensor_0(a, b): 51 | return {1: 1, 2: 4, 3: a + b} 52 | 53 | 54 | def test_with_tensor(caplog): 55 | reset() 56 | compiled_tensor0 = compile(without_tensor_0) 57 | compiled_tensor5 = compile(tensor_0) 58 | compiled_tensor6 = compile(call_function_ex) 59 | a = torch.full((1,), 5.0) 60 | b = torch.full((1,), 7.0) 61 | c = torch.full((1,), 7.0) 62 | dict_a = {1: c, 2: b, 4: 1, 3: a} 63 | result = without_tensor_0(dict_a) 64 | run_and_check(compiled_tensor0, [MISS], 1, caplog, result, dict_a) 65 | 
run_and_check(compiled_tensor0, [HIT], 1, caplog, result, dict_a) 66 | result = tensor_0(a, b) 67 | run_and_check(compiled_tensor5, [MISS], 2, caplog, result, a, b) 68 | run_and_check(compiled_tensor5, [HIT], 2, caplog, result, a, b) 69 | a = torch.full((1,), 6.0) 70 | b = torch.full((1,), 7.0) 71 | result = tensor_0(a, b) 72 | run_and_check(compiled_tensor5, [HIT], 2, caplog, result, a, b) 73 | # test nested dict 74 | dict_a = {1: [a, b, c], 2: (b, 2.2), 4: 1, 3: a} 75 | result = without_tensor_0(dict_a) 76 | run_and_check(compiled_tensor0, [MISS], 3, caplog, result, dict_a) 77 | run_and_check(compiled_tensor0, [HIT], 3, caplog, result, dict_a) 78 | temp1 = torch.full((1,), 3.0) 79 | temp2 = torch.full((1,), 4.0) 80 | temp3 = torch.full((1,), 5.0) 81 | temp4 = torch.full((1,), 6.0) 82 | result = call_function_ex(temp1, temp2, b=temp3, c=temp4) 83 | run_and_check(compiled_tensor6, [MISS, MISS], 84 | 4, 85 | caplog, 86 | result, 87 | temp1, 88 | temp2, 89 | b=temp3, 90 | c=temp4) 91 | run_and_check(compiled_tensor6, [HIT], 92 | 4, 93 | caplog, 94 | result, 95 | temp1, 96 | temp2, 97 | b=temp3, 98 | c=temp4) 99 | 100 | 101 | def run_ordered_dict(a): 102 | b = OrderedDict([("three", 3)]) 103 | return OrderedDict([("one", a["one"]), ("two", a["two"]), 104 | ("three", b["three"])]) 105 | 106 | 107 | def test_ordered_dict(caplog): 108 | reset() 109 | numbers = OrderedDict([("one", 1), ("two", 2)]) 110 | compiled_ordered_dict = compile(run_ordered_dict) 111 | result = run_ordered_dict(numbers) 112 | run_and_check(compiled_ordered_dict, [MISS], 1, caplog, result, numbers) 113 | run_and_check(compiled_ordered_dict, [HIT], 1, caplog, result, numbers) 114 | -------------------------------------------------------------------------------- /test/test_dyn_shape.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from frontend.utils import enable_dyn_shape 3 | from common.checker import 
run_and_check, HIT, MISS, ALL_MISS, assert_equal 4 | import torch 5 | 6 | 7 | def dyn_shape1(a): 8 | b = a * 2 9 | return b.view((b.shape[0] * 2, 2)) 10 | 11 | 12 | def dyn_shape2(a): 13 | return a.view((a.shape[0] * 2, 2)) * 2 14 | 15 | 16 | def dyn_callee(sz): 17 | return torch.ones((sz,)) 18 | 19 | 20 | def dyn_caller(a): 21 | b = a * 2 22 | return dyn_callee(b.size(0)) 23 | 24 | 25 | def test_dyn_shape(caplog): 26 | reset() 27 | for i, fn in enumerate((dyn_shape1, dyn_shape2, dyn_caller)): 28 | with enable_dyn_shape(): 29 | inp1 = torch.randn((5, 2, 2)) 30 | y1 = fn(inp1) 31 | inp2 = torch.randn((10, 2, 2)) 32 | y2 = fn(inp2) 33 | 34 | compiled = compile(fn) 35 | run_and_check(compiled, [ALL_MISS], i + 1, caplog, y1, inp1) 36 | run_and_check(compiled, [HIT], i + 1, caplog, y1, inp1) 37 | run_and_check(compiled, [HIT], i + 1, caplog, y2, inp2) 38 | 39 | 40 | class Model(torch.nn.Module): 41 | 42 | def __init__(self, *args, **kwargs) -> None: 43 | super().__init__(*args, **kwargs) 44 | self.linear = torch.nn.Linear(3, 3, bias=False) 45 | 46 | def forward(self, x): 47 | return self.linear(x * 2) 48 | 49 | 50 | def test_dyn_module(caplog): 51 | reset() 52 | with enable_dyn_shape(): 53 | with torch.no_grad(): 54 | model = Model().cuda().eval() 55 | inp1 = torch.randn((4, 5, 3)).cuda() 56 | y1 = model(inp1) 57 | inp2 = torch.randn((4, 5, 3)).cuda() 58 | y2 = model(inp2) 59 | 60 | compiled = compile(model) 61 | run_and_check(compiled, [MISS], 1, caplog, y1, inp1) 62 | run_and_check(compiled, [HIT], 1, caplog, y1, inp1) 63 | run_and_check(compiled, [HIT], 1, caplog, y2, inp2) 64 | 65 | 66 | class ModelParam(torch.nn.Module): 67 | 68 | def __init__(self, *args, **kwargs) -> None: 69 | super().__init__(*args, **kwargs) 70 | self.param = torch.nn.Parameter(torch.randn((3, 3))) 71 | 72 | def forward(self, x): 73 | return x * self.param 74 | 75 | 76 | def test_dyn_module_param(caplog): 77 | reset() 78 | with enable_dyn_shape(): 79 | with torch.no_grad(): 80 | model = 
ModelParam().eval() 81 | inp1 = torch.randn((4, 3, 3)) 82 | y1 = model(inp1) 83 | inp2 = torch.randn((4, 3, 3)) 84 | y2 = model(inp2) 85 | 86 | compiled = compile(model) 87 | run_and_check(compiled, [MISS], 1, caplog, y1, inp1) 88 | run_and_check(compiled, [HIT], 1, caplog, y1, inp1) 89 | run_and_check(compiled, [HIT], 1, caplog, y2, inp2) 90 | 91 | 92 | def shape_min_max(a, b, sz): 93 | x = min(a.size(0), b.size(0)) 94 | y = max(a.size(1), b.size(1)) 95 | z = max(a.size(2), sz) 96 | return torch.ones((x, y, z)) 97 | 98 | 99 | def test_shape_min_max(caplog): 100 | reset() 101 | with enable_dyn_shape(): 102 | with torch.no_grad(): 103 | x1 = torch.randn((4, 3, 5)).cuda() 104 | y1 = torch.randn((4, 3, 5)).cuda() 105 | z1 = 7 106 | out1 = shape_min_max(x1, y1, z1) 107 | x2 = torch.randn((5, 3, 5)).cuda() 108 | y2 = torch.randn((4, 3, 5)).cuda() 109 | z2 = 5 110 | out2 = shape_min_max(x2, y2, z2) 111 | 112 | compiled = compile(shape_min_max) 113 | run_and_check(compiled, [MISS], 1, caplog, out1, x1, y1, z1) 114 | run_and_check(compiled, [HIT], 1, caplog, out1, x1, y1, z1) 115 | run_and_check(compiled, [MISS], 2, caplog, out2, x2, y2, z2) 116 | run_and_check(compiled, [HIT], 2, caplog, out2, x2, y2, z2) 117 | 118 | 119 | def dyn_slice_callee(a, sz): 120 | return a[:sz] 121 | 122 | 123 | def dyn_slice_caller(a): 124 | b = a * 2 125 | return dyn_slice_callee(b, b.size(0)) 126 | 127 | 128 | def test_dyn_slice(caplog): 129 | reset() 130 | with enable_dyn_shape(): 131 | with torch.no_grad(): 132 | x1 = torch.randn((4, 3, 5)).cuda() 133 | out1 = dyn_slice_caller(x1) 134 | x2 = torch.randn((8, 3, 5)).cuda() 135 | out2 = dyn_slice_caller(x2) 136 | 137 | compiled = compile(dyn_slice_caller) 138 | run_and_check(compiled, [MISS, MISS], 1, caplog, out1, x1) 139 | run_and_check(compiled, [HIT], 1, caplog, out1, x1) 140 | run_and_check(compiled, [HIT], 1, caplog, out2, x2) 141 | -------------------------------------------------------------------------------- 
/test/test_end_of_control_flow.py: -------------------------------------------------------------------------------- 1 | from frontend.bytecode_writter import get_instructions 2 | from frontend.bytecode_analysis import end_of_control_flow 3 | 4 | 5 | def while_loop(): 6 | i = 0 7 | while i < 10: 8 | i += 1 9 | print(i) 10 | 11 | 12 | def for_loop(): 13 | for i in range(10): 14 | print(i) 15 | for j in (1, 2, 3): 16 | print(j) 17 | 18 | 19 | def if_else(): 20 | a = 1 21 | if a > 0: 22 | print("a > 0") 23 | if a > 1: 24 | print("a > 1") 25 | else: 26 | print("a <= 0") 27 | if a > -1: 28 | print("a > -1") 29 | else: 30 | print("a <= -1") 31 | 32 | 33 | def forward_lstm(self, inputs): # seq_len, batch, input_size 34 | state_c = () 35 | state_h = () 36 | for i in range(inputs.size()[0]): 37 | cur_input = inputs[i] 38 | for j in range(self.num_layers): 39 | c = cur_input 40 | h = c + cur_input 41 | return state_h[self.num_layers - 1] 42 | 43 | 44 | def forward_seq2seq(self, encoder_output, std, h, c): 45 | batch_size = encoder_output.size()[1] 46 | cond = True 47 | id = 0 48 | while cond: 49 | x = self.embedding(output) 50 | id = id + 1 51 | cond = (torch.max(output) > self.EOS_token) & (id < self.max_length) 52 | return x 53 | 54 | 55 | def forward_blockdrop(self, x, policy): 56 | 57 | x = self.seed(x) 58 | 59 | t = 0 60 | for segment, num_blocks in enumerate(self.layer_config): 61 | for b in range(num_blocks): 62 | action = policy[:, t].contiguous() 63 | residual = self.ds[segment](x) if b == 0 else x 64 | 65 | # early termination if all actions in the batch are zero 66 | if action.data.sum() == 0: 67 | x = residual 68 | t += 1 69 | continue 70 | 71 | action_mask = action.float().view(-1, 1, 1, 1) 72 | fx = F.relu(residual + self.blocks[segment][b](x)) 73 | x = fx * action_mask + residual * (1 - action_mask) 74 | t += 1 75 | 76 | x = self.avgpool(x) 77 | x = x.view(x.size(0), -1) 78 | x = self.fc(x) 79 | return x 80 | 81 | 82 | def check_one(f, start_pc: int, end_pc: 
int): 83 | instructions = get_instructions(f) 84 | end_pc_out = end_of_control_flow(instructions, start_pc) 85 | assert end_pc == end_pc_out, f"end_pc_ref: {end_pc}, end_pc_out: {end_pc_out}" 86 | 87 | 88 | def test_end_of_control_flow(): 89 | check_one(while_loop, 5, 15) 90 | check_one(for_loop, 4, 11) 91 | check_one(for_loop, 13, 20) 92 | check_one(if_else, 5, 36) 93 | check_one(if_else, 13, 36) 94 | check_one(if_else, 26, 36) 95 | check_one(forward_lstm, 12, 33) 96 | check_one(forward_lstm, 23, 32) 97 | check_one(forward_seq2seq, 11, 35) 98 | check_one(forward_blockdrop, 12, 99) 99 | check_one(forward_blockdrop, 20, 98) 100 | check_one(forward_blockdrop, 35, 44) 101 | check_one(forward_blockdrop, 51, 20) -------------------------------------------------------------------------------- /test/test_extend_arg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile import compile, reset 3 | import logging 4 | from common.checker import run_and_check, HIT, MISS 5 | 6 | 7 | def with_extend_arg(a): 8 | return a + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15 + 16 + 17 + 18 + 19 + 20 + 21 + 22 + 23 + 24 + 25 + 26 + 27 + 28 + 29 + 30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 + 100 + 101 + 102 + 103 + 104 + 105 + 106 + 107 + 108 + 109 + 110 + 111 + 112 + 113 + 114 + 115 + 116 + 117 + 118 + 119 + 120 + 121 + 122 + 123 + 124 + 125 + 126 + 127 + 128 + 129 + 130 + 131 + 132 + 133 + 134 + 135 + 136 + 137 + 138 + 139 + 140 + 141 + 142 + 143 + 144 + 145 + 146 + 147 + 148 + 149 + 150 + 151 + 152 + 153 + 154 + 155 + 156 + 157 + 158 + 159 + 160 + 161 + 162 + 163 + 164 + 165 + 166 + 167 + 168 
+ 169 + 170 + 171 + 172 + 173 + 174 + 175 + 176 + 177 + 178 + 179 + 180 + 181 + 182 + 183 + 184 + 185 + 186 + 187 + 188 + 189 + 190 + 191 + 192 + 193 + 194 + 195 + 196 + 197 + 198 + 199 + 200 + 201 + 202 + 203 + 204 + 205 + 206 + 207 + 208 + 209 + 210 + 211 + 212 + 213 + 214 + 215 + 216 + 217 + 218 + 219 + 220 + 221 + 222 + 223 + 224 + 225 + 226 + 227 + 228 + 229 + 230 + 231 + 232 + 233 + 234 + 235 + 236 + 237 + 238 + 239 + 240 + 241 + 242 + 243 + 244 + 245 + 246 + 247 + 248 + 249 + 250 + 251 + 252 + 253 + 254 + 255 + 256 + 257 + 258 + 259 9 | 10 | 11 | def test_extend_arg(caplog): 12 | reset() 13 | compiled_with_extend_arg = compile(with_extend_arg) 14 | run_and_check(compiled_with_extend_arg, [MISS], 1, caplog, 33671, 1) 15 | run_and_check(compiled_with_extend_arg, [HIT], 1, caplog, 33671, 1) -------------------------------------------------------------------------------- /test/test_inplace.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from frontend.utils import add_force_graph_break 3 | from frontend.c_api import get_next_frame_id 4 | from common.checker import run_and_check, HIT, MISS, assert_equal 5 | import torch 6 | 7 | 8 | def inplace_add(a, b): 9 | a += b 10 | return a 11 | 12 | 13 | def test_inplace_add(caplog): 14 | reset() 15 | compiled = compile(inplace_add) 16 | result1 = inplace_add(1.0, 2.0) 17 | run_and_check(compiled, [MISS], 1, caplog, result1, 1.0, 2.0) 18 | run_and_check(compiled, [HIT], 1, caplog, result1, 1.0, 2.0) 19 | 20 | result2 = inplace_add((1, 2), (3, 4)) 21 | run_and_check(compiled, [MISS], 2, caplog, result2, (1, 2), (3, 4)) 22 | run_and_check(compiled, [HIT], 2, caplog, result2, (1, 2), (3, 4)) 23 | 24 | def get_input3(): 25 | return [1, 2], [3, 4] 26 | 27 | result3 = inplace_add(*get_input3()) 28 | run_and_check(compiled, [MISS], 3, caplog, result3, *get_input3()) 29 | run_and_check(compiled, [HIT], 3, caplog, result3, *get_input3()) 30 | 
input3 = get_input3() 31 | output3 = compiled(*input3) 32 | assert_equal(id(input3[0]), id(output3)) 33 | 34 | result4 = inplace_add(torch.tensor(1), torch.tensor(2)) 35 | run_and_check(compiled, [MISS], 4, caplog, result4, torch.tensor(1), 36 | torch.tensor(2)) 37 | result5 = inplace_add(torch.tensor(3), torch.tensor(4)) 38 | run_and_check(compiled, [HIT], 4, caplog, result5, torch.tensor(3), 39 | torch.tensor(4)) 40 | input6 = (torch.tensor(5), torch.tensor(6)) 41 | result6 = compiled(*input6) 42 | assert_equal(id(input6[0]), id(result6)) 43 | 44 | result7 = inplace_add(3, torch.tensor(5)) 45 | run_and_check(compiled, [MISS], 5, caplog, result7, 3, torch.tensor(5)) 46 | run_and_check(compiled, [HIT], 5, caplog, result7, 3, torch.tensor(5)) 47 | result8 = inplace_add(3, torch.tensor(9)) 48 | run_and_check(compiled, [HIT], 5, caplog, result8, 3, torch.tensor(9)) 49 | 50 | 51 | # TODO: 52 | # def inplace_add2(a, b): 53 | # a += b 54 | # return b # but a is still modified 55 | 56 | 57 | def store_subscr_add(a, b): 58 | a[1] += b 59 | return a 60 | 61 | 62 | def test_inplace_subscr_add(caplog): 63 | reset() 64 | compiled = compile(store_subscr_add) 65 | 66 | def get_input1(): 67 | return [1, 2], 3 68 | 69 | result1 = store_subscr_add(*get_input1()) 70 | run_and_check(compiled, [MISS], 1, caplog, result1, *get_input1()) 71 | run_and_check(compiled, [HIT], 1, caplog, result1, *get_input1()) 72 | input1 = get_input1() 73 | output1 = compiled(*input1) 74 | assert_equal(id(input1[0]), id(output1)) 75 | 76 | def get_input2(): 77 | return torch.tensor([1, 2]), torch.tensor(3) 78 | 79 | result2 = store_subscr_add(*get_input2()) 80 | run_and_check(compiled, [MISS], 2, caplog, result2, *get_input2()) 81 | run_and_check(compiled, [HIT], 2, caplog, result2, *get_input2()) 82 | input2 = get_input2() 83 | output2 = compiled(*input2) 84 | assert_equal(id(input2[0]), id(output2)) 85 | 86 | 87 | def store_subscr(a, b): 88 | a[1] = b 89 | return a, b 90 | 91 | 92 | def 
test_store_subscr(caplog): 93 | reset() 94 | compiled = compile(store_subscr) 95 | 96 | def get_input1(): 97 | return [1, 2], [3, 4] 98 | 99 | result = store_subscr(*get_input1()) 100 | run_and_check(compiled, [MISS], 1, caplog, result, *get_input1()) 101 | run_and_check(compiled, [HIT], 1, caplog, result, *get_input1()) 102 | a, b = get_input1() 103 | output = compiled(a, b) 104 | assert_equal(id(a), id(output[0])) 105 | assert_equal(id(b), id(output[0][1])) 106 | assert_equal(id(b), id(output[1])) 107 | 108 | 109 | def store_without_return(a, b): 110 | a[1] = b 111 | return b 112 | 113 | 114 | def test_store_without_return(caplog): 115 | reset() 116 | compiled = compile(store_without_return) 117 | 118 | def get_input1(): 119 | return [1, 2], [3, 4] 120 | 121 | a, b = get_input1() 122 | result = store_without_return(a, b) 123 | 124 | run_and_check(compiled, [MISS], 1, caplog, result, *get_input1()) 125 | run_and_check(compiled, [HIT], 1, caplog, result, *get_input1()) 126 | 127 | a1, b1 = get_input1() 128 | output = compiled(a1, b1) 129 | assert_equal(id(b1), id(output)) 130 | assert_equal(a1, a) 131 | 132 | 133 | def store_to_temp1(a): 134 | b = [1, 2, 3] 135 | b[2] = a 136 | return b 137 | 138 | 139 | def store_to_temp2(a): 140 | b = [1, 2, 3] 141 | b[2] = a 142 | return a 143 | 144 | 145 | def test_store_to_temp(caplog): 146 | reset() 147 | 148 | result = store_to_temp1(4) 149 | compiled = compile(store_to_temp1) 150 | run_and_check(compiled, [MISS], 1, caplog, result, 4) 151 | run_and_check(compiled, [HIT], 1, caplog, result, 4) 152 | 153 | result = store_to_temp2(4) 154 | compiled = compile(store_to_temp2) 155 | run_and_check(compiled, [MISS], 2, caplog, result, 4) 156 | run_and_check(compiled, [HIT], 2, caplog, result, 4) 157 | 158 | 159 | def inplace_callee_no_ret(a): 160 | a[1] = 2.0 161 | 162 | 163 | def inplace_callee_ret(a): 164 | a[1] = 2.0 165 | return a 166 | 167 | 168 | def inplace_callee_add_no_ret(a): 169 | a[1] += 2.0 170 | 171 | 172 | def 
inplace_callee_add_ret(a): 173 | a[1] += 2.0 174 | return a 175 | 176 | 177 | def caller1(a): 178 | inplace_callee_no_ret(a) 179 | return a 180 | 181 | 182 | def caller2(a): 183 | inplace_callee_no_ret(a) 184 | 185 | 186 | def caller3(a): 187 | inplace_callee_ret(a) 188 | return a 189 | 190 | 191 | def caller4(a): 192 | inplace_callee_ret(a) 193 | 194 | 195 | def caller5(a): 196 | inplace_callee_add_no_ret(a) 197 | return a 198 | 199 | 200 | def caller6(a): 201 | inplace_callee_add_no_ret(a) 202 | 203 | 204 | def caller7(a): 205 | inplace_callee_add_ret(a) 206 | return a 207 | 208 | 209 | def caller8(a): 210 | inplace_callee_add_ret(a) 211 | 212 | 213 | def test_inplace_function(caplog): 214 | reset() 215 | fs = [ 216 | caller1, caller2, caller3, caller4, caller5, caller6, caller7, caller8 217 | ] 218 | 219 | def get_input1(): 220 | return [1.0, 3.0] 221 | 222 | def get_input2(): 223 | return torch.tensor([1.0, 3.0]) 224 | 225 | cache_size = 0 226 | for f in fs: 227 | print("===============running", f) 228 | compiled = compile(f) 229 | cache_size += 1 230 | original_input = get_input1() 231 | result = f(original_input) 232 | run_and_check(compiled, [MISS, MISS], cache_size, caplog, result, 233 | get_input1()) 234 | run_and_check(compiled, [HIT], cache_size, caplog, result, get_input1()) 235 | input1 = get_input1() 236 | _output = compiled(input1) 237 | assert_equal(input1, original_input) 238 | 239 | compiled = compile(f) 240 | cache_size += 1 241 | original_input = get_input2() 242 | result = f(original_input) 243 | run_and_check(compiled, [MISS, MISS], cache_size, caplog, result, 244 | get_input2()) 245 | run_and_check(compiled, [HIT], cache_size, caplog, result, get_input2()) 246 | input2 = get_input2() 247 | _output = compiled(input2) 248 | assert_equal(input2, original_input) 249 | -------------------------------------------------------------------------------- /test/test_int_cache.py: 
-------------------------------------------------------------------------------- 1 | def fn(a, b, c, d): 2 | return a + b, c + d 3 | 4 | 5 | def test_int_cache(): 6 | a, b = fn(1, 4, 2, 3) 7 | assert id(a) != id(b) -------------------------------------------------------------------------------- /test/test_list.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def without_tensor_0(a): 8 | return a 9 | 10 | 11 | def without_tensor_1(a): 12 | return a * 3 13 | 14 | 15 | def without_tensor_2(a): 16 | return [a, 9] 17 | 18 | 19 | def without_tensor_3(a): 20 | return a[1:] 21 | 22 | 23 | def without_tensor_4(a, b): 24 | return a + b 25 | 26 | 27 | def test_without_tensor(caplog): 28 | reset() 29 | compiled_no_tensor0 = compile(without_tensor_0) 30 | compiled_no_tensor1 = compile(without_tensor_1) 31 | compiled_no_tensor2 = compile(without_tensor_2) 32 | compiled_no_tensor3 = compile(without_tensor_3) 33 | compiled_no_tensor4 = compile(without_tensor_4) 34 | a = [1, 2.5] 35 | b = [2, 4] 36 | result = without_tensor_0(a) 37 | run_and_check(compiled_no_tensor0, [MISS], 1, caplog, result, a) 38 | run_and_check(compiled_no_tensor0, [HIT], 1, caplog, result, a) 39 | result = without_tensor_1(a) 40 | run_and_check(compiled_no_tensor1, [MISS], 2, caplog, result, a) 41 | run_and_check(compiled_no_tensor1, [HIT], 2, caplog, result, a) 42 | result = without_tensor_2(a) 43 | run_and_check(compiled_no_tensor2, [MISS], 3, caplog, result, a) 44 | run_and_check(compiled_no_tensor2, [HIT], 3, caplog, result, a) 45 | result = without_tensor_3(a) 46 | run_and_check(compiled_no_tensor3, [MISS], 4, caplog, result, a) 47 | run_and_check(compiled_no_tensor3, [HIT], 4, caplog, result, a) 48 | result = without_tensor_4(a, b) 49 | run_and_check(compiled_no_tensor4, [MISS], 5, caplog, 
result, a, b) 50 | run_and_check(compiled_no_tensor4, [HIT], 5, caplog, result, a, b) 51 | a = [10, 6] 52 | b = [8, 7] 53 | result = without_tensor_0(a) 54 | run_and_check(compiled_no_tensor0, [MISS], 6, caplog, result, a) 55 | run_and_check(compiled_no_tensor0, [HIT], 6, caplog, result, a) 56 | result = without_tensor_4(a, b) 57 | run_and_check(compiled_no_tensor4, [MISS], 7, caplog, result, a, b) 58 | run_and_check(compiled_no_tensor4, [HIT], 7, caplog, result, a, b) 59 | 60 | 61 | def tensor_0(list_a, list_b): 62 | return list_a[3] + list_b[2] 63 | 64 | 65 | def tensor_1(list_a, list_b): 66 | return list_a[3] * list_b[2] 67 | 68 | 69 | def tensor_2(list_a): 70 | return list_a 71 | 72 | 73 | def tensor_3(list_a, list_b): 74 | return list_a + list_b 75 | 76 | 77 | def tensor_4(list_a, list_b): 78 | return list_a + [3] 79 | 80 | 81 | def tensor_5(list_a): 82 | return list_a[..., 2:] 83 | 84 | 85 | def list_id(list_a, list_b): 86 | c = list_a + list_b 87 | return c[3], c[6] 88 | 89 | 90 | def test_with_tensor(caplog): 91 | reset() 92 | compiled_tensor0 = compile(tensor_0) 93 | compiled_tensor1 = compile(tensor_1) 94 | compiled_tensor2 = compile(tensor_2) 95 | compiled_tensor3 = compile(tensor_3) 96 | compiled_tensor4 = compile(tensor_4) 97 | compiled_tensor5 = compile(list_id) 98 | compiled_tensor6 = compile(tensor_5) 99 | a = torch.full((1,), 5.0) 100 | b = torch.full((1,), 7.0) 101 | list_a = [1, 2, 4, a] 102 | list_b = [3.5, 7, b] 103 | result = tensor_0(list_a, list_b) 104 | run_and_check(compiled_tensor0, [MISS], 1, caplog, result, list_a, list_b) 105 | run_and_check(compiled_tensor0, [HIT], 1, caplog, result, list_a, list_b) 106 | result = tensor_1(list_a, list_b) 107 | run_and_check(compiled_tensor1, [MISS], 2, caplog, result, list_a, list_b) 108 | run_and_check(compiled_tensor1, [HIT], 2, caplog, result, list_a, list_b) 109 | result = tensor_2(list_a) 110 | run_and_check(compiled_tensor2, [MISS], 3, caplog, result, list_a) 111 | 
run_and_check(compiled_tensor2, [HIT], 3, caplog, result, list_a) 112 | result = tensor_3(list_a, list_b) 113 | run_and_check(compiled_tensor3, [MISS], 4, caplog, result, list_a, list_b) 114 | run_and_check(compiled_tensor3, [HIT], 4, caplog, result, list_a, list_b) 115 | result = tensor_4(list_a, list_b) 116 | run_and_check(compiled_tensor4, [MISS], 5, caplog, result, list_a, list_b) 117 | run_and_check(compiled_tensor4, [HIT], 5, caplog, result, list_a, list_b) 118 | list_a = [1, 2, 4, a] 119 | list_b = [3.5, 7, a] 120 | result = list_id(list_a, list_b) 121 | assert_equal(id(result[0]), id(result[1])) 122 | assert_equal(id(result[0]), id(compiled_tensor5(list_a, list_b)[1])) 123 | assert_equal(id(compiled_tensor5(list_a, list_b)[0]), 124 | id(compiled_tensor5(list_a, list_b)[1])) 125 | # test nested list 126 | list_a = [1, 2, 4, (6, 7), a, [8, (9, 10), 11]] 127 | list_b = [3.5, 7, b] 128 | result = tensor_3(list_a, list_b) 129 | run_and_check(compiled_tensor3, [MISS], 7, caplog, result, list_a, list_b) 130 | run_and_check(compiled_tensor3, [HIT], 7, caplog, result, list_a, list_b) 131 | #TODO: support numpy array variables 132 | # list_a = np.array([[1, 2, 3, 4], 133 | # [5, 6, 7, 8], 134 | # [9, 10, 11, 12]]) 135 | # result = tensor_5(list_a) 136 | # run_and_check(compiled_tensor6, [MISS], 8, caplog, result, list_a) 137 | # run_and_check(compiled_tensor6, [HIT], 8, caplog, result, list_a) 138 | 139 | 140 | def list_contains(a, b): 141 | return b in a 142 | 143 | 144 | def test_list_contains(caplog): 145 | reset() 146 | a = [1.0, 2.0, 3.0] 147 | b = 3.0 148 | compiled_list_contains = compile(list_contains) 149 | run_and_check(compiled_list_contains, [MISS], 1, caplog, True, a, b) 150 | run_and_check(compiled_list_contains, [HIT], 1, caplog, True, a, b) 151 | 152 | 153 | def list_comp(a, b): 154 | return [i + b for i in a] 155 | 156 | 157 | def test_list_comp(caplog): 158 | reset() 159 | a = [1.0, 2.0, 3.0] 160 | b = 3.0 161 | compiled_list_comp = 
compile(list_comp) 162 | run_and_check(compiled_list_comp, [MISS, MISS], 1, caplog, [4.0, 5.0, 6.0], 163 | a, b) 164 | run_and_check(compiled_list_comp, [HIT], 1, caplog, [4.0, 5.0, 6.0], a, b) 165 | 166 | 167 | def test_list_comp_tensor(caplog): 168 | reset() 169 | a = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)] 170 | b = 3.0 171 | expect = list_comp(a, b) 172 | compiled_list_comp = compile(list_comp) 173 | run_and_check(compiled_list_comp, [MISS, MISS], 1, caplog, expect, a, b) 174 | run_and_check(compiled_list_comp, [HIT], 1, caplog, expect, a, b) 175 | 176 | 177 | def list_comp_with_wrapper(a, b): 178 | c = [torch.tensor(x + y) for x, y in zip(a, b)] 179 | d = [torch.tensor(x + y) for x, y in zip(a, c)] 180 | return d 181 | 182 | 183 | def test_list_comp_with_wrapper(caplog): 184 | reset() 185 | a = [1.0, 2.0, 3.0] 186 | b = [1.0, 2.0, 3.0] 187 | expect = list_comp_with_wrapper(a, b) 188 | compiled_list_comp_with_wrapper = compile(list_comp_with_wrapper) 189 | run_and_check(compiled_list_comp_with_wrapper, [MISS, MISS, MISS], 1, 190 | caplog, expect, a, b) 191 | run_and_check(compiled_list_comp_with_wrapper, [HIT], 1, caplog, expect, a, 192 | b) 193 | 194 | 195 | def list_inplace(): 196 | a = [1] 197 | a.append(2) 198 | return (3, a) 199 | 200 | 201 | def test_list_inplace(caplog): 202 | reset() 203 | compiled = compile(list_inplace) 204 | expect = list_inplace() 205 | run_and_check(compiled, [MISS], 1, caplog, expect) 206 | run_and_check(compiled, [HIT], 1, caplog, expect) 207 | 208 | 209 | # def unpack_list(a, b): 210 | # a, b = (y + 1 for y in [a,b]) 211 | # return a + b 212 | 213 | # def test_unpack_list(caplog): 214 | # reset() 215 | # compiled = compile(unpack_list) 216 | # expect = unpack_list(1, 2) 217 | # run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1,2) 218 | # run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2) 219 | # a = torch.rand((2,2)) 220 | # b = torch.rand((2,2)) 221 | # expect = unpack_list(a, b) 222 | # 
run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b) 223 | # run_and_check(compiled, [HIT], 2, caplog, expect, a, b) 224 | -------------------------------------------------------------------------------- /test/test_model_seq2seq.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile import compile, reset 3 | from frontend.utils import add_force_graph_break 4 | from frontend.c_api import get_next_frame_id 5 | import logging 6 | from common.checker import run_and_check, HIT, MISS 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | MAX_LENGTH = 50 12 | OUTPUT_SIZE = 60 13 | HIDDEN_SIZE = 4 14 | 15 | 16 | class LSTMCell(nn.Module): 17 | 18 | def __init__(self, hidden_size, input_size): 19 | super().__init__() 20 | self.weight_ih_l0_t = nn.Parameter( 21 | torch.randn(4, input_size, hidden_size, dtype=torch.float32)) 22 | self.weight_hh_l0_t = nn.Parameter( 23 | torch.randn(4, input_size, hidden_size, dtype=torch.float32)) 24 | self.bias_ih_0 = nn.Parameter( 25 | torch.randn(hidden_size, dtype=torch.float32)) 26 | self.bias_hh_0 = nn.Parameter( 27 | torch.randn(hidden_size, dtype=torch.float32)) 28 | self.bias_ih_1 = nn.Parameter( 29 | torch.randn(hidden_size, dtype=torch.float32)) 30 | self.bias_hh_1 = nn.Parameter( 31 | torch.randn(hidden_size, dtype=torch.float32)) 32 | self.bias_ih_2 = nn.Parameter( 33 | torch.randn(hidden_size, dtype=torch.float32)) 34 | self.bias_hh_2 = nn.Parameter( 35 | torch.randn(hidden_size, dtype=torch.float32)) 36 | self.bias_ih_3 = nn.Parameter( 37 | torch.randn(hidden_size, dtype=torch.float32)) 38 | self.bias_hh_3 = nn.Parameter( 39 | torch.randn(hidden_size, dtype=torch.float32)) 40 | self.hidden_size = hidden_size 41 | self.input_size = input_size 42 | nn.init.xavier_uniform_(self.weight_ih_l0_t) 43 | nn.init.xavier_uniform_(self.weight_hh_l0_t) 44 | 45 | def forward(self, x, h, c): 46 | ih = torch.matmul(x, self.weight_ih_l0_t) 47 | hh = 
torch.matmul(h, self.weight_hh_l0_t) 48 | ih0 = ih[0] + self.bias_ih_0 49 | hh0 = hh[0] + self.bias_hh_0 50 | ih1 = ih[1] + self.bias_ih_1 51 | hh1 = hh[1] + self.bias_hh_1 52 | ih2 = ih[2] + self.bias_ih_2 53 | hh2 = hh[2] + self.bias_hh_2 54 | ih3 = ih[3] + self.bias_ih_3 55 | hh3 = hh[3] + self.bias_hh_3 56 | 57 | ingate = torch.sigmoid(ih0 + hh0) 58 | forgetgate = torch.sigmoid(ih1 + hh1) 59 | cellgate = torch.tanh(ih2 + hh2) 60 | outgate = torch.sigmoid(ih3 + hh3) 61 | 62 | c = (forgetgate * c) + (ingate * cellgate) 63 | h = outgate * torch.tanh(c) 64 | return h, c 65 | 66 | 67 | class AttnDecoderRNN(nn.Module): 68 | 69 | def __init__(self, 70 | hidden_size, 71 | output_size, 72 | dropout_p=0.1, 73 | max_length=MAX_LENGTH): 74 | super(AttnDecoderRNN, self).__init__() 75 | self.hidden_size = hidden_size 76 | self.output_size = output_size 77 | self.dropout_p = dropout_p 78 | self.max_length = max_length 79 | 80 | self.gru = LSTMCell(self.hidden_size, self.hidden_size) 81 | self.out = nn.Linear(self.hidden_size, self.output_size) 82 | self.embedding = nn.Embedding(self.output_size, self.hidden_size) 83 | self.EOS_token = 0 84 | self.SOS_token = 1 85 | 86 | def forward_one(self, encoder_output, std, h, c): 87 | batch_size = encoder_output.size()[1] 88 | output_all = torch.zeros( 89 | self.max_length, batch_size, dtype=torch.int64, device='cuda') + 0 90 | output = torch.full((batch_size,), 91 | self.SOS_token, 92 | dtype=torch.int64, 93 | device='cuda') 94 | # cond = True 95 | id = 0 96 | # while cond: 97 | x = self.embedding(output) 98 | h = torch.reshape(h, (batch_size, self.hidden_size)) 99 | # lstm start 100 | ih = torch.matmul(x, self.gru.weight_ih_l0_t) 101 | hh = torch.matmul(h, self.gru.weight_hh_l0_t) 102 | ih0 = ih[0] + self.gru.bias_ih_0 103 | hh0 = hh[0] + self.gru.bias_hh_0 104 | ih1 = ih[1] + self.gru.bias_ih_1 105 | hh1 = hh[1] + self.gru.bias_hh_1 106 | ih2 = ih[2] + self.gru.bias_ih_2 107 | hh2 = hh[2] + self.gru.bias_hh_2 108 | ih3 = ih[3] + 
self.gru.bias_ih_3 109 | hh3 = hh[3] + self.gru.bias_hh_3 110 | 111 | ingate = torch.sigmoid(ih0 + hh0) 112 | forgetgate = torch.sigmoid(ih1 + hh1) 113 | cellgate = torch.tanh(ih2 + hh2) 114 | outgate = torch.sigmoid(ih3 + hh3) 115 | 116 | c = (forgetgate * c) + (ingate * cellgate) 117 | h = outgate * torch.tanh(c) 118 | # lstm end 119 | output = self.out(h) + std[id] 120 | output = output.argmax(1) 121 | output_all[id] = output 122 | id = id + 1 123 | return output_all, h, id 124 | 125 | 126 | def gen_mask_from_sequence(std): 127 | bs = std.shape[0] 128 | padded_std = torch.zeros((bs, MAX_LENGTH), dtype=std.dtype, device='cuda') 129 | padded_std[:, :std.shape[1]] = std 130 | mask = torch.zeros(bs, MAX_LENGTH, OUTPUT_SIZE, device='cuda') 131 | mask[torch.arange(bs).unsqueeze(1), 132 | torch.arange(MAX_LENGTH).unsqueeze(0), padded_std] = 1000000.0 133 | mask = mask.transpose(0, 1).contiguous().clone() 134 | return mask 135 | 136 | 137 | def get_input(batch_size): 138 | std = [] 139 | for i in range(batch_size): 140 | l = max(i, 10) 141 | l = min(l, MAX_LENGTH) 142 | lst = list(range(1, l)) 143 | lst.append(0) 144 | assert (len(lst) <= MAX_LENGTH) 145 | # pad to MAX_LENGTH 146 | lst = lst + [0] * (MAX_LENGTH - len(lst)) 147 | std.append(lst) 148 | std = torch.tensor(std, device='cuda') 149 | mask = gen_mask_from_sequence(std) 150 | encoder_output = torch.randn(MAX_LENGTH, 151 | batch_size, 152 | HIDDEN_SIZE, 153 | device='cuda') 154 | h = torch.randn(batch_size, HIDDEN_SIZE, device='cuda') 155 | c = torch.randn(batch_size, HIDDEN_SIZE, device='cuda') 156 | return encoder_output, mask, h, c 157 | 158 | 159 | def test_seq2seq_one_token(caplog): 160 | reset() 161 | with torch.no_grad(): 162 | batch_size = 2 163 | seq_len = 50 164 | model = AttnDecoderRNN( 165 | HIDDEN_SIZE, 166 | OUTPUT_SIZE, 167 | ).cuda() 168 | model.eval() 169 | args = get_input(batch_size) 170 | expect_result = model.forward_one(*args) 171 | compiled = compile(model.forward_one) 172 | 
run_and_check(compiled, [MISS], 1, caplog, expect_result, *args) 173 | run_and_check(compiled, [HIT], 1, caplog, expect_result, *args) 174 | -------------------------------------------------------------------------------- /test/test_nnmodule.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile import compile, reset 3 | from frontend.utils import add_force_graph_break 4 | from frontend.c_api import get_next_frame_id 5 | import logging 6 | import torch 7 | from common.checker import run_and_check, HIT, MISS, ALL_MISS 8 | 9 | 10 | class Model(torch.nn.Module): 11 | 12 | def __init__(self): 13 | super(Model, self).__init__() 14 | self.linear = torch.nn.Linear(10, 5) 15 | 16 | def forward(self, x): 17 | y = self.linear(x) 18 | z = y + 1.0 19 | return z 20 | 21 | 22 | class ModelParam(torch.nn.Module): 23 | 24 | def __init__(self): 25 | super(ModelParam, self).__init__() 26 | self.param = torch.nn.Parameter(torch.randn(5, 5)) 27 | 28 | def forward(self, x): 29 | y = self.param + x 30 | return y 31 | 32 | 33 | class Model2(torch.nn.Module): 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self.linear = torch.nn.Linear(1, 1) 38 | 39 | def forward(self, x): 40 | return self.linear(x) - 4.0 41 | 42 | 43 | def call_model(model, a): 44 | b = model.linear(a) + 1 45 | return b 46 | 47 | 48 | def nn_module(a): 49 | b = torch.nn.Softmax(dim=-1)(a) 50 | return b 51 | 52 | 53 | def test_call_method(caplog): 54 | reset() 55 | with torch.no_grad(): 56 | model = Model().eval() 57 | x = torch.randn(1, 10) 58 | expect_result = model(x) 59 | add_force_graph_break(get_next_frame_id(), 3) 60 | compiled_model = compile(model) 61 | run_and_check(compiled_model, [MISS], 2, caplog, expect_result, x) 62 | run_and_check(compiled_model, [HIT, HIT], 2, caplog, expect_result, x) 63 | 64 | 65 | def test_module(caplog): 66 | reset() 67 | with torch.no_grad(): 68 | model = Model().eval() 69 | x = torch.randn(1, 10) 70 
| expect_result = model(x) 71 | compiled_model = compile(model) 72 | run_and_check(compiled_model, [MISS], 1, caplog, expect_result, x) 73 | run_and_check(compiled_model, [HIT], 1, caplog, expect_result, x) 74 | 75 | 76 | def test_module_param(caplog): 77 | reset() 78 | with torch.no_grad(): 79 | model = ModelParam().eval() 80 | x = torch.randn(1, 5) 81 | expect_result = model(x) 82 | compiled_model = compile(model) 83 | run_and_check(compiled_model, [MISS], 1, caplog, expect_result, x) 84 | run_and_check(compiled_model, [HIT], 1, caplog, expect_result, x) 85 | 86 | 87 | def test_external_module(caplog): 88 | reset() 89 | with torch.no_grad(): 90 | model = Model2().eval() 91 | x = torch.randn(1, 1) 92 | expect_result = call_model(model, x) 93 | compiled_model = compile(call_model) 94 | run_and_check(compiled_model, [MISS], 1, caplog, expect_result, model, 95 | x) 96 | run_and_check(compiled_model, [HIT], 1, caplog, expect_result, model, x) 97 | 98 | 99 | def test_nn_module(caplog): 100 | reset() 101 | compiled = compile(nn_module) 102 | x = torch.randn(1, 10) 103 | expect_result = nn_module(x) 104 | run_and_check(compiled, [MISS], 1, caplog, expect_result, x) 105 | run_and_check(compiled, [HIT], 1, caplog, expect_result, x) 106 | 107 | 108 | class MapModule(torch.nn.Module): 109 | 110 | def __init__(self): 111 | super().__init__() 112 | self.linears = torch.nn.ModuleList( 113 | [torch.nn.Linear(3, 3) for _ in range(3)]) 114 | 115 | def forward(self, x): 116 | fmaps = tuple(map(lambda l: l(x), self.linears)) 117 | return torch.cat(fmaps, dim=1) 118 | 119 | 120 | def test_map_module(caplog): 121 | reset() 122 | model = MapModule() 123 | compiled = compile(model) 124 | x = torch.randn(3, 3) 125 | expect_result = model(x) 126 | run_and_check(compiled, [ALL_MISS], 1, caplog, expect_result, x) 127 | run_and_check(compiled, [HIT], 1, caplog, expect_result, x) 128 | 129 | 130 | class InplaceRelu(torch.nn.Module): 131 | 132 | def __init__(self) -> None: 133 | 
super().__init__() 134 | self.conv = torch.nn.Conv2d(3, 3, 3) 135 | self.bn = torch.nn.BatchNorm2d(3) 136 | self.relu = torch.nn.ReLU(inplace=True) 137 | 138 | def forward(self, x): 139 | x = self.conv(x) 140 | x = self.bn(x) 141 | x = self.relu(x) 142 | return x + 1.0 143 | 144 | 145 | def test_inplace_relu(caplog): 146 | reset() 147 | model = InplaceRelu().eval() 148 | compiled = compile(model) 149 | x = torch.randn(1, 3, 3, 3) 150 | expect_result = model(x) 151 | run_and_check(compiled, [MISS], 1, caplog, expect_result, x) 152 | run_and_check(compiled, [HIT], 1, caplog, expect_result, x) 153 | 154 | 155 | if __name__ == "__main__": 156 | caplog = logging.getLogger(__name__) 157 | test_call_method(caplog) 158 | test_module(caplog) 159 | -------------------------------------------------------------------------------- /test/test_numpy.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from frontend.utils import add_force_graph_break 3 | from frontend.c_api import get_next_frame_id 4 | from common.checker import run_and_check, HIT, MISS, assert_equal 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def numpy_to_int(x): 10 | p = int(np.floor((x - 1) / 2)) 11 | return p 12 | 13 | 14 | def np_array(x): 15 | return x 16 | 17 | 18 | def test_array(caplog): 19 | reset() 20 | compiled = compile(np_array) 21 | a = np.array([1, 2.0, 3.33]) 22 | result = np_array(a) 23 | run_and_check(compiled, [MISS], 1, caplog, result, a) 24 | run_and_check(compiled, [HIT], 1, caplog, result, a) 25 | 26 | 27 | def test_numpy_to_int(caplog): 28 | reset() 29 | compiled_numpy_to_int = compile(numpy_to_int) 30 | result = numpy_to_int(10) 31 | run_and_check(compiled_numpy_to_int, [MISS], 1, caplog, result, 10) 32 | run_and_check(compiled_numpy_to_int, [HIT], 1, caplog, result, 10) 33 | 34 | 35 | def numpy_to_torch(x): 36 | y = np.floor((x - 1) / 2) 37 | return torch.tensor(y) 38 | 39 | 40 | def 
test_numpy_to_torch(caplog): 41 | from frontend.utils import SetConfig 42 | with SetConfig({"backend": "eager"}): 43 | reset() 44 | compiled = compile(numpy_to_torch) 45 | a = np.array([1, 2.0, 3.33]) 46 | result = numpy_to_torch(a) 47 | run_and_check(compiled, [MISS], 1, caplog, result, a) 48 | run_and_check(compiled, [HIT], 1, caplog, result, a) -------------------------------------------------------------------------------- /test/test_random_key.py: -------------------------------------------------------------------------------- 1 | from frontend.utils import new_random_key 2 | 3 | import random 4 | 5 | 6 | def test_random_key(): 7 | random.seed(123) 8 | a = random.randint(0, 10000) 9 | b = random.randint(0, 10000) 10 | random.seed(123) 11 | aa = random.randint(0, 10000) 12 | key = new_random_key() 13 | bb = random.randint(0, 10000) 14 | assert a == aa 15 | assert b == bb 16 | -------------------------------------------------------------------------------- /test/test_scalar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile import compile, reset 3 | from frontend.utils import add_force_graph_break 4 | from frontend.c_api import get_next_frame_id 5 | import logging 6 | from common.checker import run_and_check, HIT, MISS 7 | import torch 8 | 9 | 10 | def perfect0(a): 11 | return a + 1 12 | 13 | 14 | def perfect1(a): 15 | return a + 1 16 | 17 | 18 | def graph_break0(a): 19 | return (a + 1) // 2 20 | 21 | 22 | def graph_break1(a): 23 | return (a + 1) // 2 + 1 24 | 25 | 26 | def test_perfect(caplog): 27 | reset() 28 | compiled_perfect0 = compile(perfect0) 29 | compiled_perfect1 = compile(perfect1) 30 | run_and_check(compiled_perfect0, [MISS], 1, caplog, 4, 3) 31 | run_and_check(compiled_perfect0, [HIT], 1, caplog, 4, 3) 32 | run_and_check(compiled_perfect0, [MISS], 2, caplog, 5, 4) 33 | run_and_check(compiled_perfect0, [HIT], 2, caplog, 5, 4) 34 | run_and_check(compiled_perfect1, [MISS], 3, 
caplog, 4, 3) 35 | run_and_check(compiled_perfect1, [HIT], 3, caplog, 4, 3) 36 | run_and_check(compiled_perfect1, [MISS], 4, caplog, 5, 4) 37 | run_and_check(compiled_perfect1, [HIT], 4, caplog, 5, 4) 38 | 39 | 40 | def test_graph_break(caplog): 41 | reset() 42 | compiled_graph_break0 = compile(graph_break0) 43 | add_force_graph_break(get_next_frame_id(), 4) 44 | run_and_check(compiled_graph_break0, [MISS], 1, caplog, 2, 3) 45 | run_and_check(compiled_graph_break0, [HIT], 1, caplog, 2, 3) 46 | run_and_check(compiled_graph_break0, [MISS], 2, caplog, 2, 4) 47 | run_and_check(compiled_graph_break0, [HIT], 2, caplog, 2, 4) 48 | 49 | compiled_graph_break1 = compile(graph_break1) 50 | add_force_graph_break(get_next_frame_id(), 4) 51 | run_and_check(compiled_graph_break1, [MISS], 4, caplog, 2, 1) 52 | run_and_check(compiled_graph_break1, [HIT, HIT], 4, caplog, 2, 1) 53 | run_and_check(compiled_graph_break1, [MISS, HIT], 5, caplog, 2, 2) 54 | run_and_check(compiled_graph_break1, [HIT, HIT], 5, caplog, 2, 2) 55 | 56 | 57 | def perfect0_float(a): 58 | return a + 1.0 59 | 60 | 61 | def perfect1_float(a): 62 | return a + 1.0 63 | 64 | 65 | def graph_break0_float(a): 66 | return (a + 1.0) / 2 67 | 68 | 69 | def graph_break1_float(a): 70 | return (a + 1.0) / 2 + 1 71 | 72 | 73 | def test_perfect_float(caplog): 74 | reset() 75 | compiled_perfect0 = compile(perfect0_float) 76 | compiled_perfect1 = compile(perfect1_float) 77 | run_and_check(compiled_perfect0, [MISS], 1, caplog, 4.0, 3.0) 78 | run_and_check(compiled_perfect0, [HIT], 1, caplog, 4.0, 3.0) 79 | run_and_check(compiled_perfect0, [MISS], 2, caplog, 5.0, 4.0) 80 | run_and_check(compiled_perfect0, [HIT], 2, caplog, 5.0, 4.0) 81 | run_and_check(compiled_perfect1, [MISS], 3, caplog, 4.0, 3.0) 82 | run_and_check(compiled_perfect1, [HIT], 3, caplog, 4.0, 3.0) 83 | run_and_check(compiled_perfect1, [MISS], 4, caplog, 5.0, 4.0) 84 | run_and_check(compiled_perfect1, [HIT], 4, caplog, 5.0, 4.0) 85 | 86 | 87 | def 
test_graph_break_float(caplog): 88 | reset() 89 | add_force_graph_break(get_next_frame_id(), 4) 90 | compiled_graph_break0 = compile(graph_break0_float) 91 | run_and_check(compiled_graph_break0, [MISS], 1, caplog, 2.0, 3.0) 92 | run_and_check(compiled_graph_break0, [HIT], 1, caplog, 2.0, 3.0) 93 | run_and_check(compiled_graph_break0, [MISS], 2, caplog, 2.5, 4.0) 94 | run_and_check(compiled_graph_break0, [HIT], 2, caplog, 2.5, 4.0) 95 | 96 | add_force_graph_break(get_next_frame_id(), 4) 97 | compiled_graph_break1 = compile(graph_break1_float) 98 | run_and_check(compiled_graph_break1, [MISS], 4, caplog, 2.0, 1.0) 99 | run_and_check(compiled_graph_break1, [HIT, HIT], 4, caplog, 2.0, 1.0) 100 | 101 | 102 | def binary_add(a, b): 103 | return a + b 104 | 105 | 106 | def binary_subtract(a, b): 107 | return a - b 108 | 109 | 110 | def binary_multiply(a, b): 111 | return a * b 112 | 113 | 114 | def binary_floor_divide(a, b): 115 | return a // b 116 | 117 | 118 | def binary_true_divide(a, b): 119 | return a / b 120 | 121 | 122 | def binary_mod(a, b): 123 | return a % b 124 | 125 | 126 | def binary_power(a, b): 127 | return a**b 128 | 129 | 130 | def binary_lshift(a, b): 131 | return a << b 132 | 133 | 134 | def binary_rshift(a, b): 135 | return a >> b 136 | 137 | 138 | def binary_and(a, b): 139 | return a & b 140 | 141 | 142 | def binary_xor(a, b): 143 | return a ^ b 144 | 145 | 146 | def binary_or(a, b): 147 | return a | b 148 | 149 | 150 | def test_binary_op(caplog): 151 | reset() 152 | funcs = [ 153 | binary_add, binary_subtract, binary_multiply, binary_floor_divide, 154 | binary_true_divide, binary_mod, binary_power, binary_lshift, 155 | binary_rshift, binary_and, binary_xor, binary_or 156 | ] 157 | compiled_funcs = [compile(func) for func in funcs] 158 | cache_cnt = 0 159 | for func, compiled_func in zip(funcs, compiled_funcs): 160 | for a in [1, 2, 3]: 161 | for b in [1, 2, 3]: 162 | cache_cnt += 1 163 | run_and_check(compiled_func, [MISS], cache_cnt, caplog, 164 | 
func(a, b), a, b) 165 | for func, compiled_func in zip(funcs, compiled_funcs): 166 | for a in [1, 2, 3]: 167 | for b in [1, 2, 3]: 168 | run_and_check(compiled_func, [HIT], cache_cnt, caplog, 169 | func(a, b), a, b) 170 | 171 | 172 | def add3(a, b, c): 173 | return a + b + c 174 | 175 | 176 | def call_add3(a, b, c): 177 | return add3(a, b, c) 178 | 179 | 180 | def test_dyn_int(caplog): 181 | reset() 182 | import frontend.dynamic as dyn 183 | a = 1 184 | b = 2 185 | c = 3 186 | dyn.mark_dynamic(a, dyn.ScalarWithUnknownValue()) 187 | dyn.mark_dynamic(b, dyn.ScalarWithUnknownValue()) 188 | compiled_fn = compile(add3) 189 | run_and_check(compiled_fn, [MISS], 1, caplog, 6, a, b, c) 190 | run_and_check(compiled_fn, [HIT], 1, caplog, 6, a, b, c) 191 | run_and_check(compiled_fn, [HIT], 1, caplog, 12, 4, 5, 3) 192 | run_and_check(compiled_fn, [MISS], 2, caplog, 8, a, b, 5) 193 | 194 | 195 | def test_dyn_int_with_call(caplog): 196 | reset() 197 | import frontend.dynamic as dyn 198 | a = 1 199 | b = 2 200 | c = 3 201 | dyn.mark_dynamic(a, dyn.ScalarWithUnknownValue()) 202 | dyn.mark_dynamic(b, dyn.ScalarWithUnknownValue()) 203 | compiled_fn = compile(call_add3) 204 | run_and_check(compiled_fn, [MISS, MISS], 1, caplog, 6, a, b, c) 205 | run_and_check(compiled_fn, [HIT], 1, caplog, 6, a, b, c) 206 | run_and_check(compiled_fn, [HIT], 1, caplog, 12, 4, 5, 3) 207 | run_and_check(compiled_fn, [MISS, MISS], 2, caplog, 8, a, b, 5) 208 | 209 | 210 | def dynamic_scalar_from_tensor(a, b, c): 211 | d = float(a + b) 212 | return c + d 213 | 214 | 215 | def test_dynamic_scalar_from_tensor(caplog): 216 | reset() 217 | a = torch.tensor(1.0) 218 | b = torch.tensor(2.0) 219 | c = 3.0 220 | expect = dynamic_scalar_from_tensor(a, b, c) 221 | compiled = compile(dynamic_scalar_from_tensor) 222 | run_and_check(compiled, [MISS], 1, caplog, expect, a, b, c) 223 | run_and_check(compiled, [HIT], 1, caplog, expect, a, b, c) 224 | aa = torch.tensor(4.0) 225 | bb = torch.tensor(5.0) 226 | expect = 
dynamic_scalar_from_tensor(aa, bb, c) 227 | run_and_check(compiled, [HIT], 1, caplog, expect, aa, bb, c) 228 | 229 | 230 | def itertools_product(a, b): 231 | import itertools 232 | return list(itertools.product(a, b)) 233 | 234 | 235 | def test_itertools_product(caplog): 236 | reset() 237 | a = [1, 2] 238 | b = [3, 4] 239 | expect = itertools_product(a, b) 240 | compiled = compile(itertools_product) 241 | run_and_check(compiled, [MISS], 1, caplog, expect, a, b) 242 | run_and_check(compiled, [HIT], 1, caplog, expect, a, b) 243 | -------------------------------------------------------------------------------- /test/test_set.py: -------------------------------------------------------------------------------- 1 | from frontend.compile import compile, reset 2 | from common.checker import run_and_check, HIT, MISS, assert_equal 3 | import torch 4 | 5 | 6 | def without_tensor_0(a): 7 | return a 8 | 9 | 10 | def without_tensor_1(a, b): 11 | return a - b 12 | 13 | 14 | def without_tensor_2(a, b): 15 | return a | b 16 | 17 | 18 | def without_tensor_3(a, b): 19 | return a & b 20 | 21 | 22 | def without_tensor_4(a, b): 23 | return a ^ b 24 | 25 | 26 | def test_without_tensor(caplog): 27 | reset() 28 | compiled_no_tensor0 = compile(without_tensor_0) 29 | compiled_no_tensor1 = compile(without_tensor_1) 30 | compiled_no_tensor2 = compile(without_tensor_2) 31 | compiled_no_tensor3 = compile(without_tensor_3) 32 | compiled_no_tensor4 = compile(without_tensor_4) 33 | a = {2.8, 2.1, 2.2, 2.1} 34 | b = {2.8, 3.9, 4.1} 35 | result = without_tensor_0(a) 36 | run_and_check(compiled_no_tensor0, [MISS], 1, caplog, result, a) 37 | run_and_check(compiled_no_tensor0, [HIT], 1, caplog, result, a) 38 | result = without_tensor_1(a, b) 39 | run_and_check(compiled_no_tensor1, [MISS], 2, caplog, result, a, b) 40 | run_and_check(compiled_no_tensor1, [HIT], 2, caplog, result, a, b) 41 | result = without_tensor_2(a, b) 42 | run_and_check(compiled_no_tensor2, [MISS], 3, caplog, result, a, b) 43 | 
run_and_check(compiled_no_tensor2, [HIT], 3, caplog, result, a, b) 44 | result = without_tensor_3(a, b) 45 | run_and_check(compiled_no_tensor3, [MISS], 4, caplog, result, a, b) 46 | run_and_check(compiled_no_tensor3, [HIT], 4, caplog, result, a, b) 47 | result = without_tensor_4(a, b) 48 | run_and_check(compiled_no_tensor4, [MISS], 5, caplog, result, a, b) 49 | run_and_check(compiled_no_tensor4, [HIT], 5, caplog, result, a, b) 50 | a = {10.1, 6.2} 51 | b = {8.4, 7.2} 52 | result = without_tensor_0(a) 53 | run_and_check(compiled_no_tensor0, [MISS], 6, caplog, result, a) 54 | run_and_check(compiled_no_tensor0, [HIT], 6, caplog, result, a) 55 | result = without_tensor_4(a, b) 56 | run_and_check(compiled_no_tensor4, [MISS], 7, caplog, result, a, b) 57 | run_and_check(compiled_no_tensor4, [HIT], 7, caplog, result, a, b) 58 | 59 | 60 | def tensor_0(a, b): 61 | return {1, 2, 3, a + b} 62 | 63 | 64 | def test_with_tensor(caplog): 65 | reset() 66 | compiled_tensor0 = compile(without_tensor_0) 67 | compiled_tensor1 = compile(without_tensor_1) 68 | compiled_tensor2 = compile(without_tensor_2) 69 | compiled_tensor3 = compile(without_tensor_3) 70 | compiled_tensor4 = compile(without_tensor_4) 71 | compiled_tensor5 = compile(tensor_0) 72 | a = torch.full((1,), 5.0) 73 | b = torch.full((1,), 7.0) 74 | c = torch.full((1,), 7.0) 75 | set_a = {1, 2, 4, a, 4, 1} 76 | set_b = {3.5, 7, b, 4, 2, c} 77 | result = without_tensor_0(set_a) 78 | run_and_check(compiled_tensor0, [MISS], 1, caplog, result, set_a) 79 | run_and_check(compiled_tensor0, [HIT], 1, caplog, result, set_a) 80 | result = without_tensor_1(set_a, set_b) 81 | run_and_check(compiled_tensor1, [MISS], 2, caplog, result, set_a, set_b) 82 | run_and_check(compiled_tensor1, [HIT], 2, caplog, result, set_a, set_b) 83 | result = without_tensor_2(set_a, set_b) 84 | run_and_check(compiled_tensor2, [MISS], 3, caplog, result, set_a, set_b) 85 | run_and_check(compiled_tensor2, [HIT], 3, caplog, result, set_a, set_b) 86 | result = 
without_tensor_3(set_a, set_b) 87 | run_and_check(compiled_tensor3, [MISS], 4, caplog, result, set_a, set_b) 88 | run_and_check(compiled_tensor3, [HIT], 4, caplog, result, set_a, set_b) 89 | result = without_tensor_4(set_a, set_b) 90 | run_and_check(compiled_tensor4, [MISS], 5, caplog, result, set_a, set_b) 91 | run_and_check(compiled_tensor4, [HIT], 5, caplog, result, set_a, set_b) 92 | # test nested set 93 | set_a = {1, 2, 4, (6, 7), a, (8, (9, 10), 11)} 94 | set_b = {3.5, 7, b, (6.6, 8.8)} 95 | result = without_tensor_3(set_a, set_b) 96 | run_and_check(compiled_tensor3, [MISS], 6, caplog, result, set_a, set_b) 97 | run_and_check(compiled_tensor3, [HIT], 6, caplog, result, set_a, set_b) 98 | result = tensor_0(a, b) 99 | run_and_check(compiled_tensor5, [MISS], 7, caplog, result, a, b) 100 | run_and_check(compiled_tensor5, [HIT], 7, caplog, result, a, b) 101 | a = torch.full((1,), 6.0) 102 | b = torch.full((1,), 7.0) 103 | result = tensor_0(a, b) 104 | run_and_check(compiled_tensor5, [HIT], 7, caplog, result, a, b) 105 | -------------------------------------------------------------------------------- /test/test_stack_effect.py: -------------------------------------------------------------------------------- 1 | import dis 2 | from frontend.c_api import stack_effect 3 | 4 | 5 | def test_stack_effect(): 6 | for op in dis.opmap.values(): 7 | for oparg in range(0, 15): 8 | for jump in (None, True, False): 9 | try: 10 | ref = dis.stack_effect(op, oparg, jump=jump) 11 | except ValueError: 12 | continue 13 | if op == 130 and oparg >= 3: # RAISE_VARARGS 14 | continue 15 | out = stack_effect(op, oparg, jump) 16 | assert ref == out[2] - out[ 17 | 1], f"op: {dis.opname[op]}({op}), oparg: {oparg}, jump: {jump}, ref: {ref}, out: {out}" 18 | -------------------------------------------------------------------------------- /test/test_static_control_flow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from frontend.compile 
import pytest
from frontend.compile import compile, reset
from frontend.utils import add_force_graph_break
from frontend.c_api import get_next_frame_id
import logging
from common.checker import run_and_check, HIT, MISS

import torch
import torch.nn as nn


class Model(torch.nn.Module):
    """Toy module whose forward() branch is fixed at construction time."""

    def __init__(self, bn):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(2, 2))
        # The forward() branch below is decided once, here.
        self.bn = torch.nn.BatchNorm1d(2) if bn else None

    def forward(self, x):
        out = self.a * x
        if self.bn:
            out = self.bn(out)
        return out


def test_static_cf(caplog):
    """Each Model variant gets its own cache entry; a forced graph break
    inside model1's frame splits its graph into two HITs."""
    reset()
    x = torch.randn(2, 2)
    model1 = Model(True).eval()
    ref1 = model1(x)
    compiled_model1 = compile(model1)
    run_and_check(compiled_model1, [MISS], 1, caplog, ref1, x)
    run_and_check(compiled_model1, [HIT], 1, caplog, ref1, x)

    model2 = Model(False).eval()
    ref2 = model2(x)
    compiled_model2 = compile(model2)
    run_and_check(compiled_model2, [MISS], 2, caplog, ref2, x)
    run_and_check(compiled_model2, [HIT], 2, caplog, ref2, x)

    # Start over, this time forcing a break at bytecode offset 8 of forward().
    reset()
    model1 = Model(True).eval()
    ref1 = model1(x)
    add_force_graph_break(get_next_frame_id() - 1, 8)  # frame of model1
    compiled_model1 = compile(model1)
    run_and_check(compiled_model1, [MISS], 2, caplog, ref1, x)
    run_and_check(compiled_model1, [HIT, HIT], 2, caplog, ref1, x)

    model2 = Model(False).eval()
    ref2 = model2(x)
    compiled_model2 = compile(model2)
    run_and_check(compiled_model2, [MISS], 3, caplog, ref2, x)
    run_and_check(compiled_model2, [HIT], 3, caplog, ref2, x)

    run_and_check(compiled_model1, [HIT, HIT], 3, caplog, ref1, x)
from frontend.compile import compile, reset
from frontend.utils import add_force_graph_break
from frontend.c_api import get_next_frame_id
from common.checker import run_and_check, HIT, MISS
import torch


def store_no_break(a, b):
    # Single stored temporary; compiles as one graph.
    c = a + b
    return c


def store_with_break(a, b):
    # Several stores; a forced break at bytecode offset 14 splits the graph,
    # so the statement layout here must stay exactly as written.
    c = a + b
    d = c + a
    e = c + d
    f = c / 2 + e
    return f


def test_store(caplog):
    """Locals stored before a forced graph break survive across the split."""
    reset()
    compiled_store_no_break = compile(store_no_break)
    a = torch.full((1,), 1.0)
    b = torch.full((1,), 2.0)
    expected = store_no_break(a, b)
    run_and_check(compiled_store_no_break, [MISS], 1, caplog, expected, a, b)
    run_and_check(compiled_store_no_break, [HIT], 1, caplog, expected, a, b)

    compiled_store_with_break = compile(store_with_break)
    add_force_graph_break(get_next_frame_id(), 14)
    expected = store_with_break(a, b)
    run_and_check(compiled_store_with_break, [MISS], 3, caplog, expected, a, b)
    run_and_check(compiled_store_with_break, [HIT, HIT], 3, caplog, expected,
                  a, b)


# ---- test_tuple.py ----

from frontend.compile import compile, reset
from common.checker import run_and_check, HIT, MISS, assert_equal
import torch
from collections import namedtuple


def without_tensor_0(a):
    return a


def without_tensor_1(a):
    return a * 3


def without_tensor_2(a):
    return (a, 9)


def without_tensor_3(a):
    return a[1:]


def without_tensor_4(a, b):
    return a + b


def named_tuple(a, b):
    # Builds a fresh namedtuple type per call and instantiates it.
    funcs = namedtuple('output', ['relu', 'relu1'])
    out = funcs(a, b)
    return out
def test_without_tensor(caplog):
    """Tuple ops on pure-python tuples: one MISS per new value, then HITs."""
    reset()
    compiled_no_tensor0 = compile(without_tensor_0)
    compiled_no_tensor1 = compile(without_tensor_1)
    compiled_no_tensor2 = compile(without_tensor_2)
    compiled_no_tensor3 = compile(without_tensor_3)
    compiled_no_tensor4 = compile(without_tensor_4)
    compiled_no_tensor5 = compile(named_tuple)
    a = (1, 2.5)
    b = (2, 4)
    expected = without_tensor_0(a)
    run_and_check(compiled_no_tensor0, [MISS], 1, caplog, expected, a)
    run_and_check(compiled_no_tensor0, [HIT], 1, caplog, expected, a)
    expected = without_tensor_1(a)
    run_and_check(compiled_no_tensor1, [MISS], 2, caplog, expected, a)
    run_and_check(compiled_no_tensor1, [HIT], 2, caplog, expected, a)
    expected = without_tensor_2(a)
    run_and_check(compiled_no_tensor2, [MISS], 3, caplog, expected, a)
    run_and_check(compiled_no_tensor2, [HIT], 3, caplog, expected, a)
    expected = without_tensor_3(a)
    run_and_check(compiled_no_tensor3, [MISS], 4, caplog, expected, a)
    run_and_check(compiled_no_tensor3, [HIT], 4, caplog, expected, a)
    expected = without_tensor_4(a, b)
    run_and_check(compiled_no_tensor4, [MISS], 5, caplog, expected, a, b)
    run_and_check(compiled_no_tensor4, [HIT], 5, caplog, expected, a, b)
    # New tuple values invalidate the guards, so each function MISSes again.
    a = (10, 6)
    b = (8, 7)
    expected = without_tensor_0(a)
    run_and_check(compiled_no_tensor0, [MISS], 6, caplog, expected, a)
    run_and_check(compiled_no_tensor0, [HIT], 6, caplog, expected, a)
    expected = without_tensor_4(a, b)
    run_and_check(compiled_no_tensor4, [MISS], 7, caplog, expected, a, b)
    run_and_check(compiled_no_tensor4, [HIT], 7, caplog, expected, a, b)
    # TODO(review): namedtuple support is not asserted yet — kept disabled.
    # a = 2.2
    # b = 3.3
    # result = named_tuple(a, b)
    # run_and_check(compiled_no_tensor5, [MISS], 8, caplog, result, a, b)
    # run_and_check(compiled_no_tensor5, [HIT], 8, caplog, result, a, b)


def tensor_0(tuple_a, tuple_b):
    return tuple_a[3] + tuple_b[2]


def tensor_1(tuple_a, tuple_b):
    return tuple_a[3] * tuple_b[2]
def tensor_2(tuple_a):
    return tuple_a


def tensor_3(tuple_a, tuple_b):
    return tuple_a + tuple_b


def tensor_4(tuple_a, tuple_b):
    return tuple_a + (3,)


def tuple_id(tuple_a, tuple_b):
    # Concatenation must preserve the identity of the tensor elements.
    c = tuple_a + tuple_b
    return c[3], c[6]


def test_with_tensor(caplog):
    """Tuples mixing scalars and tensors: indexing, concat, object identity,
    and nesting."""
    reset()
    compiled_tensor0 = compile(tensor_0)
    compiled_tensor1 = compile(tensor_1)
    compiled_tensor2 = compile(tensor_2)
    compiled_tensor3 = compile(tensor_3)
    compiled_tensor4 = compile(tensor_4)
    compiled_tensor5 = compile(tuple_id)
    a = torch.full((1,), 5.0)
    b = torch.full((1,), 7.0)
    tuple_a = (1, 2, 4, a)
    tuple_b = (3.5, 7, b)
    expected = tensor_0(tuple_a, tuple_b)
    run_and_check(compiled_tensor0, [MISS], 1, caplog, expected, tuple_a,
                  tuple_b)
    run_and_check(compiled_tensor0, [HIT], 1, caplog, expected, tuple_a,
                  tuple_b)
    expected = tensor_1(tuple_a, tuple_b)
    run_and_check(compiled_tensor1, [MISS], 2, caplog, expected, tuple_a,
                  tuple_b)
    run_and_check(compiled_tensor1, [HIT], 2, caplog, expected, tuple_a,
                  tuple_b)
    expected = tensor_2(tuple_a)
    run_and_check(compiled_tensor2, [MISS], 3, caplog, expected, tuple_a)
    run_and_check(compiled_tensor2, [HIT], 3, caplog, expected, tuple_a)
    expected = tensor_3(tuple_a, tuple_b)
    run_and_check(compiled_tensor3, [MISS], 4, caplog, expected, tuple_a,
                  tuple_b)
    run_and_check(compiled_tensor3, [HIT], 4, caplog, expected, tuple_a,
                  tuple_b)
    expected = tensor_4(tuple_a, tuple_b)
    run_and_check(compiled_tensor4, [MISS], 5, caplog, expected, tuple_a,
                  tuple_b)
    run_and_check(compiled_tensor4, [HIT], 5, caplog, expected, tuple_a,
                  tuple_b)
    # Both tuples now share the same tensor object `a`; the compiled
    # function must return the very same object in both slots.
    tuple_a = (1, 2, 4, a)
    tuple_b = (3.5, 7, a)
    expected = tuple_id(tuple_a, tuple_b)
    assert_equal(id(expected[0]), id(expected[1]))
    assert_equal(id(expected[0]), id(compiled_tensor5(tuple_a, tuple_b)[1]))
    assert_equal(id(compiled_tensor5(tuple_a, tuple_b)[0]),
                 id(compiled_tensor5(tuple_a, tuple_b)[1]))
    # test nested tuple
    tuple_a = (1, 2, 4, (6, 7), a, (8, (9, 10), 11))
    tuple_b = (3.5, 7, b)
    expected = tensor_3(tuple_a, tuple_b)
    run_and_check(compiled_tensor3, [MISS], 7, caplog, expected, tuple_a,
                  tuple_b)
    run_and_check(compiled_tensor3, [HIT], 7, caplog, expected, tuple_a,
                  tuple_b)
import torch
from frontend.compile import compile, reset
from common.checker import run_and_check, HIT, MISS, assert_equal, run_and_check_cache


class A:
    """Minimal user-defined class with a single attribute."""

    def __init__(self, x) -> None:
        self.x = x


def assign_to_exist(a):
    # Rebinds the attribute on an existing instance; returns None, so the
    # checker compares the side effect on `a`, not a return value.
    a.x = a.x + 1


def test_assign_to_exist(caplog):
    """Attribute assignment on pre-existing instances is traced correctly."""
    reset()
    x = torch.randn(3, 4)
    inst1 = A(x)
    inst2 = A(x)
    inst3 = A(x)
    expect = assign_to_exist(inst1)
    compiled = compile(assign_to_exist)
    run_and_check(compiled, [MISS], 1, caplog, expect, inst2)
    run_and_check(compiled, [HIT], 1, caplog, expect, inst3)
    assert_equal(inst2.x, inst1.x)
    assert_equal(inst3.x, inst1.x)


def create_new_class(x):
    a = A(x + 1)
    return a


def test_create_new_class(caplog):
    """Instantiating a user-defined class inside a compiled frame works."""
    reset()
    x = torch.randn(4)
    expected = create_new_class(x)
    compiled = compile(create_new_class)
    run_and_check_cache(compiled, [MISS, MISS], 1, caplog, x)
    run_and_check_cache(compiled, [HIT], 1, caplog, x)
    actual = compiled(x)
    assert_equal(expected.x, actual.x)


def create_new_class_complex(x):
    # Create, mutate, and read back an instance within one frame.
    a = A(x + 1.0)
    a.x = a.x + 1.0
    return a.x + 1.0, a
def test_create_new_class_complex(caplog):
    """Instance creation plus attribute mutation inside one compiled frame."""
    reset()
    x = torch.randn(4)
    expected = create_new_class_complex(x)
    compiled = compile(create_new_class_complex)
    run_and_check_cache(compiled, [MISS, MISS], 1, caplog, x)
    run_and_check_cache(compiled, [HIT], 1, caplog, x)
    actual = compiled(x)
    assert_equal(expected[0], actual[0])
    assert_equal(expected[1].x, actual[1].x)