├── .gitignore
├── __init__.py
├── adv_autotvm
│   ├── autotvm_test.json
│   ├── demo.py
│   ├── rpc_msg.jpg
│   └── visualizes
│       └── optimized_mod.prototxt
├── adv_interpreter
│   ├── demo.py
│   └── visualizes
│       └── optimized_mod.prototxt
├── adv_quantize
│   ├── demo.py
│   └── visualizes
│       ├── annotate.prototxt
│       ├── partition.prototxt
│       ├── prerequisite.prototxt
│       ├── quantized.prototxt
│       ├── quantized_opt.prototxt
│       └── realize.prototxt
├── adv_tensorrt
│   ├── demo.py
│   └── visualizes
│       ├── annotate_target.prototxt
│       ├── convert_layout.prototxt
│       ├── merge_compiler_regions.prototxt
│       ├── optimized_mod.prototxt
│       ├── partition_graph.prototxt
│       ├── partition_graph_tensorrt_0.prototxt
│       └── raw_mod.prototxt
├── adv_virtualmachine
│   ├── demo.py
│   └── visualizes
│       ├── inline_primitives.prototxt
│       ├── memory_opt.prototxt
│       ├── raw_mod.prototxt
│       └── to_a_normal_form.prototxt
├── basic_whole.jpg
├── dev_msir
│   └── torch_test.py
├── graph_execute
│   └── demo.py
├── relay_codegen
│   └── demo.py
├── relay_optimize
│   ├── demo.py
│   ├── test_dependency_graph.py
│   ├── test_index_graph.py
│   ├── visualize.cpp
│   └── visualizes
│       ├── alter_op_layout.prototxt
│       ├── backward_fold_scale_axis.prototxt
│       ├── base.prototxt
│       ├── dependcy_graph.prototxt
│       ├── fold_constant.prototxt
│       ├── fuse_ops.prototxt
│       ├── graph_paritioner.prototxt
│       ├── indexed_forward_graph.prototxt
│       ├── optimized_mod.prototxt
│       ├── post_dom_tree.prototxt
│       └── simplify_inference.prototxt
├── relay_parse
│   └── demo.py
├── stmt_build
│   ├── demo.py
│   ├── test_cases.py
│   └── visualizes
│       ├── bind_after.prototxt
│       ├── bind_before.prototxt
│       ├── cache_read_after.prototxt
│       ├── cache_read_before.prototxt
│       ├── compute_inline_after.prototxt
│       ├── compute_inline_before.prototxt
│       ├── normal_compute.prototxt
│       ├── normal_stmt.prototxt
│       ├── normal_stmt_complete.prototxt
│       ├── split_after.prototxt
│       ├── split_before.prototxt
│       ├── tensorize_after.prototxt
│       └── tensorize_before.prototxt
├── stmt_optimize
│   ├── demo.py
│   └── visualizes
│       ├── simplify_after.prototxt
│       ├── simplify_before.prototxt
│       ├── storage_flatten_after.prototxt
│       ├── storage_flatten_before.prototxt
│       ├── vectorize_loop_after.prototxt
│       └── vectorize_loop_before.prototxt
├── tvm_build
│   ├── build_transform.py
│   ├── demo.py
│   └── visualizes
│       ├── lower_tvm_builtin_after.prototxt
│       ├── lower_tvm_builtin_before.prototxt
│       ├── make_packed_api_after.prototxt
│       ├── make_packed_api_before.prototxt
│       ├── split_host_device_after_device.prototxt
│       ├── split_host_device_after_host.prototxt
│       └── split_host_device_before.prototxt
├── tvm_concepts.pdf
├── utils.py
└── visualize.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.docx
2 | *.doc
3 | *.graffle
4 | *.DS_Store
5 | *.pkl_memoize_py3
6 | tmp_test.sh
7 | __pycache__
8 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Archermmt/tvm_walk_through/7e1e6e8f8061c5e79e5e399b2a85c1f404ef5893/__init__.py
--------------------------------------------------------------------------------
/adv_autotvm/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | from tvm import autotvm
8 | from tvm.autotvm.tuner import GridSearchTuner
9 | 
10 | import sys
11 | sys.path.append("..")
12 | from visualize import RelayVisualizer
13 | 
14 | def check_optimize(mod,target,params):
15 |     with tvm.transform.PassContext(opt_level=3):
16 |         opt_mod, _ = relay.optimize(mod, target, params)
17 |     visualizer=RelayVisualizer()
18 |     visualizer.visualize(opt_mod,path="visualizes/optimized_mod.prototxt")
19 |     print("optimized main func "+str(opt_mod["main"]))
20 | 
21 | if __name__=='__main__':
22 |     #prepare model and input
23 |     model = models.resnet18(pretrained=True)
24 |     shape_list = [("input0",(1,3,224,224))]
25 |     fake_input = np.random.random_sample(shape_list[0][1]).astype('float32')
26 |     graph = torch.jit.trace(model,torch.from_numpy(fake_input))
27 |     #step 1 parse
28 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
29 |     target = tvm.target.Target("llvm", host="llvm")
30 |     #[optional] step 2.1 check optimize, for debug only
31 |     #check_optimize(mod,target,params)
32 |     #step 2 extract tasks
33 |     tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)
34 |     #step 3 finetune
35 |     runner = autotvm.LocalRunner(number=10,repeat=1,timeout=10,min_repeat_ms=0,)
36 |     measure_option=autotvm.measure_option(
37 |         builder=autotvm.LocalBuilder(n_parallel=1,build_func="default"), runner=runner)
38 |     for i, task in enumerate(tasks[:1]):
39 |         prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
40 |         tuner_obj = GridSearchTuner(task)
41 |         tuner_obj.tune(
42 |             n_trial=min(10, len(task.config_space)),
43 |             early_stopping=100,
44 |             measure_option=measure_option,
45 |             callbacks=[
46 |                 autotvm.callback.progress_bar(10, prefix=prefix),
47 |                 autotvm.callback.log_to_file("autotvm_test.json"),
48 |             ],
49 |         )
--------------------------------------------------------------------------------
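Note: the tuning loop above only writes measurement records to autotvm_test.json; demo.py never applies them. A minimal sketch of the missing follow-up step (hypothetical addition, not a file in this repo, reusing mod/target/params from the script):

# hypothetical follow-up to adv_autotvm/demo.py: rebuild with the best tuned records
with autotvm.apply_history_best("autotvm_test.json"):
    with tvm.transform.PassContext(opt_level=3):
        tuned_lib = relay.build(mod, target=target, params=params)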
/adv_autotvm/rpc_msg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Archermmt/tvm_walk_through/7e1e6e8f8061c5e79e5e399b2a85c1f404ef5893/adv_autotvm/rpc_msg.jpg
--------------------------------------------------------------------------------
/adv_interpreter/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | 
8 | import sys
9 | sys.path.append("..")
10 | from utils import array_des
11 | from visualize import RelayVisualizer
12 | 
13 | def check_optimize(inter):
14 |     opt_mod=inter.optimize()
15 |     visualizer=RelayVisualizer()
16 |     visualizer.visualize(opt_mod,path="visualizes/optimized_mod.prototxt")
17 | 
18 | if __name__=='__main__':
19 |     #prepare model and input
20 |     model = models.resnet18(pretrained=True)
21 |     shape_list = [("input0",(1,3,224,224))]
22 |     fake_input = np.random.random_sample(shape_list[0][1]).astype('float32')
23 |     graph = torch.jit.trace(model,torch.from_numpy(fake_input))
24 |     #step 1 parse
25 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
26 |     target = tvm.target.Target("llvm", host="llvm")
27 |     with tvm.transform.PassContext(opt_level=3):
28 |         #step 2 optimize
29 |         mod,params=relay.optimize(mod, target=target, params=params)
30 |         #step 3 create Interpreter
31 |         inter = relay.create_executor("debug", mod=mod, device=tvm.cpu(0), target=target)
32 |         '''
33 |         #[optional] step 3.1 optimize, only for debug use
34 |         check_optimize(inter)
35 |         '''
36 |         results = inter.evaluate()(fake_input)
37 |     print("results "+array_des(results))
--------------------------------------------------------------------------------
/adv_quantize/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | from tvm.contrib import graph_executor
8 | 
9 | import sys
10 | sys.path.append("..")
11 | from utils import array_des
12 | from visualize import RelayVisualizer
13 | 
14 | def calibrate_dataset():
15 |     for i in range(10):
16 |         print("Creating {} th data".format(i))
17 |         cal_data=np.random.random_sample((1,3,224,224)).astype('float32')
18 |         yield {"input0": cal_data}
19 | 
20 | def check_optimize(mod,target,params):
21 |     visualizer=RelayVisualizer()
22 |     with tvm.transform.PassContext(opt_level=3):
23 |         mod, params = relay.optimize(mod, params=params, target=target)
24 |     print("mod "+str(mod["main"]))
25 |     visualizer.visualize(mod,path="visualizes/quantized_opt.prototxt")
26 | 
27 | if __name__=='__main__':
28 |     visualizer=RelayVisualizer()
29 |     #prepare model and input
30 |     model = models.resnet18(pretrained=True)
31 |     shape_list = [("input0",(1,3,224,224))]
32 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
33 |     graph = torch.jit.trace(model,fake_input)
34 |     #step 1 parse
35 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
36 |     target = tvm.target.Target("llvm", host="llvm")
37 |     #step 2 quantize
38 |     with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
39 |         #step 2.1 [optional] debug the prerequisite_optimize process
40 |         #mod=relay.quantize.prerequisite_optimize(mod,params)
41 |         #visualizer.visualize(mod,path="visualizes/prerequisite.prototxt")
42 |         mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())
43 |         visualizer.visualize(mod,path="visualizes/quantized.prototxt")
44 | 
45 |     #step 3.1 [optional] debug the optimize process
46 |     #check_optimize(mod,target,params)
47 | 
48 |     #step 3 build lib
49 |     with tvm.transform.PassContext(opt_level=3):
50 |         lib = relay.build(mod, target=target, params=params)
51 |     dev = tvm.cpu(0)
52 |     m = graph_executor.GraphModule(lib["default"](dev))
53 |     # Set inputs
54 |     m.set_input("input0", tvm.nd.array(fake_input))
55 |     # Execute
56 |     m.run()
57 |     # Get outputs
58 |     res = m.get_output(0)
59 |     print("output "+array_des(res))
--------------------------------------------------------------------------------
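relay.quantize.qconfig takes more knobs than the two used above. A hedged sketch of an alternative step 2 that keeps the first conv layer in float32 and calibrates with a fixed global scale instead of running the dataset through kl_divergence (values are illustrative, not tuned):

# illustrative variant of step 2 in adv_quantize/demo.py: no calibration dataset needed
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0,
                            skip_conv_layers=[0], weight_scale="power2"):
    mod = relay.quantize.quantize(mod, params)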
/adv_tensorrt/demo.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | try:
3 |     tf_v1 = tf.compat.v1
4 | except AttributeError:
5 |     tf_v1 = tf
6 | 
7 | import numpy as np
8 | 
9 | import tvm
10 | from tvm import relay, runtime
11 | from tvm.contrib import graph_executor
12 | from tvm.contrib.download import download_testdata
13 | import tvm.relay.testing.tf as tf_testing
14 | from tvm.relay.op.contrib import tensorrt
15 | 
16 | import sys
17 | sys.path.append("..")
18 | from utils import array_des
19 | from visualize import RelayVisualizer
20 | 
21 | def check_optimize(mod,params):
22 |     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
23 |         mod, params = relay.optimize(mod, params=params, target="cuda")
24 |     visualizer.visualize(mod,path="visualizes/optimized_mod.prototxt")
25 | 
26 | if __name__=='__main__':
27 |     #prepare model and input
28 |     model_url="https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/classify_image_graph_def-with_shapes.pb"
29 |     model_name="classify_image_graph_def-with_shapes.pb"
30 |     model_path = download_testdata(model_url, model_name, module=["tf", "InceptionV1"])
31 |     visualizer = RelayVisualizer()
32 |     #load graph and create input
33 |     fake_input = np.random.random_sample([299,299,3]).astype('uint8')
34 |     with tf_v1.gfile.GFile(model_path, "rb") as f:
35 |         graph_def = tf_v1.GraphDef()
36 |         graph_def.ParseFromString(f.read())
37 |         graph = tf.import_graph_def(graph_def, name="")
38 |         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
39 |         with tf_v1.Session() as sess:
40 |             graph_def = tf_testing.AddShapesToGraphDef(sess, "softmax")
41 | 
42 |     #step 1 parse to relay
43 |     shape_dict = {"DecodeJpeg/contents": fake_input.shape}
44 |     mod, params = relay.frontend.from_tensorflow(graph_def, layout=None, shape=shape_dict)
45 |     #visualizer.visualize(mod,path="visualizes/raw_mod.prototxt")
46 | 
47 |     #step 2 pre optimize for tensorrt
48 |     mod, config = tensorrt.partition_for_tensorrt(mod,params)
49 |     #visualizer.visualize(mod,path="visualizes/partition_graph.prototxt")
50 | 
51 |     #step 3.1 [optional] check optimize
52 |     #check_optimize(mod,params)
53 | 
54 |     #step 3 build the mod
55 |     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
56 |         graph, lib, params = relay.build(mod, params=params, target="cuda")
57 |     params = runtime.save_param_dict(params)
58 | 
59 |     #step 4 run executor and get results
60 |     mod_ = graph_executor.create(graph, lib, device=tvm.gpu(0))
61 |     mod_.load_params(params)
62 |     mod_.run(data=fake_input)
63 |     res = mod_.get_output(0)
64 |     print("output "+array_des(res))
65 | 
--------------------------------------------------------------------------------
/adv_tensorrt/visualizes/partition_graph.prototxt:
--------------------------------------------------------------------------------
1 | name : "relay_ir"
2 | layer {
3 |   name:"DecodeJpeg/contents"
4 |   type:"input"
5 |   top:"DecodeJpeg/contents"
6 |   layer_param {
7 |     idx:0
8 |     out_0 {name:"DecodeJpeg/contents:0",dtype:uint8,shape:[299, 299, 3]}
9 |   }
10 | }
11 | layer {
12 |   name:"Node_6"
13 |   type:"cast"
14 |   top:"Node_6"
15 |   bottom:"DecodeJpeg/contents"
16 |   layer_param {
17 |     idx:6
18 |     in_0 {name:"DecodeJpeg/contents:0",dtype:uint8,shape:[299, 299, 3]}
19 |     out_0 {name:"Node_6:0",dtype:float32,shape:[299, 299, 3]}
20 |     attrs {'dtype': 'float32'}
21 |   }
22 | }
23 | layer {
24 |   name:"Node_7"
25 |   type:"expand_dims"
26 |   top:"Node_7"
27 |   bottom:"Node_6"
28 |   layer_param {
29 |     idx:7
30 |     in_0 {name:"Node_6:0",dtype:float32,shape:[299, 299, 3]}
31 |     out_0 {name:"Node_7:0",dtype:float32,shape:[1, 299, 299, 3]}
32 |     attrs {'axis': 0, 'num_newaxis': 1}
33 |   }
34 | }
35 | layer {
36 |   name:"Node_8"
37 |   type:"image_resize"
38 |   top:"Node_8"
39 |   bottom:"Node_7"
40 |   layer_param {
41 |     idx:8
42 |     in_0 {name:"Node_7:0",dtype:float32,shape:[1, 299, 299, 3]}
43 |     out_0 {name:"Node_8:0",dtype:float32,shape:[1, 299, 299, 3]}
44 |     attrs {'size': [299, 299], 'layout': 'NHWC', 'method': 'bilinear', 'coordinate_transformation_mode': 'asymmetric', 'rounding_method': '', 'bicubic_alpha': -0.5, 'bicubic_exclude': 0, 'out_dtype': ''}
45 |   }
46 | }
47 | layer {
48 |   name:"tensorrt_0"
49 |   type:"global_var"
50 |   top:"tensorrt_0"
51 |   bottom:"Node_8"
52 |   layer_param {
53 |     idx:9
54 |     in_0 {name:"Node_8:0",dtype:float32,shape:[1, 299, 299, 3]}
55 |     out_0 {name:"tensorrt_0:0",dtype:,shape:[]}
56 |   }
57 | }
58 | layer {
59 |   name:"Node_10"
60 |   type:"function"
61 |   top:"Node_10"
62 |   bottom:"DecodeJpeg/contents"
63 |   bottom:"tensorrt_0"
64 |   layer_param {
65 |     idx:10
66 |     in_0 {name:"DecodeJpeg/contents:0",dtype:uint8,shape:[299, 299, 3]}
{name:"DecodeJpeg/contents:0",dtype:uint8,shape:[299, 299, 3]} 67 | in_1 {name:"tensorrt_0:0",dtype:,shape:[]} 68 | out_0 {name:"Node_10:0",dtype:float32,shape:[1, 1008]} 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /adv_virtualmachine/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.models as models 4 | 5 | import tvm 6 | from tvm import relay 7 | from tvm.runtime.vm import VirtualMachine 8 | 9 | import sys 10 | sys.path.append("..") 11 | from utils import array_des 12 | from visualize import RelayVisualizer 13 | 14 | def check_optimize(mod,target,params): 15 | visualizer=RelayVisualizer() 16 | with tvm.transform.PassContext(opt_level=3): 17 | compiler = relay.vm.VMCompiler() 18 | mod,params=compiler.optimize(mod, target=target, params=params) 19 | print("mod "+str(mod["main"])) 20 | visualizer.visualize(mod,path="visualizes/memory_opt.prototxt") 21 | 22 | if __name__=='__main__': 23 | #prepare model and input 24 | model = models.resnet18(pretrained=True) 25 | shape_list = [("input0",(1,3,224,224))] 26 | fake_input = np.random.random_sample(shape_list[0][1]).astype('float32') 27 | graph = torch.jit.trace(model,torch.from_numpy(fake_input)) 28 | 29 | #step 1 parse to relay 30 | mod, params = relay.frontend.from_pytorch(graph, shape_list) 31 | target = tvm.target.Target("llvm", host="llvm") 32 | 33 | #step 2.1.1 [optional] debug the optimize process 34 | #check_optimize(mod,target,params) 35 | 36 | #step 2 compile the module 37 | with tvm.transform.PassContext(opt_level=3): 38 | vm_exec = relay.vm.compile(mod, target=target, params=params) 39 | 40 | #step 3 run the VirtualMachine 41 | dev = tvm.device("llvm", 0) 42 | vm = VirtualMachine(vm_exec, dev) 43 | vm.set_input("main", **{"input0": fake_input}) 44 | res=vm.run() 45 | print("res "+array_des(res)) -------------------------------------------------------------------------------- /basic_whole.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Archermmt/tvm_walk_through/7e1e6e8f8061c5e79e5e399b2a85c1f404ef5893/basic_whole.jpg -------------------------------------------------------------------------------- /dev_msir/torch_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torchvision.models as models 5 | 6 | import tvm 7 | import tvm.testing 8 | from tvm import relay 9 | from tvm.contrib.msir.torch.transform import partition_for_torch 10 | from tvm.contrib.msir.core.ir import MSIRGraphRuntimeModule 11 | from tvm.contrib.msir.core import utils as msir_utils 12 | 13 | def debug_optimize(mod, params, disabled_pass, config): 14 | with tvm.transform.PassContext(opt_level=3, disabled_pass=disabled_pass, config=config): 15 | mod, params=relay.optimize(mod, params=params, target="llvm") 16 | mod = relay.transform.InferLayout()(mod) 17 | from tvm.contrib.msir.core.ir import build_from_relay 18 | graph=build_from_relay(mod,"main") 19 | graph.visualize("visualizes/test.prototxt") 20 | 21 | if __name__=='__main__': 22 | # prepare model and input 23 | model = models.resnet18(pretrained=True).eval() 24 | shape_list = [("input0",(1,3,224,224))] 25 | fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32')) 26 | golden = model(fake_input) 27 | graph = torch.jit.trace(model,fake_input) 28 | 29 | # 
/dev_msir/torch_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torchvision.models as models
5 | 
6 | import tvm
7 | import tvm.testing
8 | from tvm import relay
9 | from tvm.contrib.msir.torch.transform import partition_for_torch
10 | from tvm.contrib.msir.core.ir import MSIRGraphRuntimeModule
11 | from tvm.contrib.msir.core import utils as msir_utils
12 | 
13 | def debug_optimize(mod, params, disabled_pass, config):
14 |     with tvm.transform.PassContext(opt_level=3, disabled_pass=disabled_pass, config=config):
15 |         mod, params=relay.optimize(mod, params=params, target="llvm")
16 |         mod = relay.transform.InferLayout()(mod)
17 |     from tvm.contrib.msir.core.ir import build_from_relay
18 |     graph=build_from_relay(mod,"main")
19 |     graph.visualize("visualizes/test.prototxt")
20 | 
21 | if __name__=='__main__':
22 |     # prepare model and input
23 |     model = models.resnet18(pretrained=True).eval()
24 |     shape_list = [("input0",(1,3,224,224))]
25 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
26 |     golden = model(fake_input)
27 |     graph = torch.jit.trace(model,fake_input)
28 | 
29 |     # partition main function
30 |     mod, params = relay.frontend.from_pytorch(graph, shape_list, with_name=True)
31 |     mod, disabled_pass, config = partition_for_torch(mod,params)
32 | 
33 |     # optional optimize
34 |     # debug_optimize(mod, params, disabled_pass, config)
35 |     with tvm.transform.PassContext(opt_level=3, disabled_pass=disabled_pass, config=config):
36 |         graph_json, mod1, params = relay.build(mod, target="llvm", params=params)
37 | 
38 |     msir_utils.set_work_dir("/tmp/dev_torch_test")
39 |     module = MSIRGraphRuntimeModule(mod1)
40 | 
41 |     model = module.load_source()()
42 |     state_dict = {k:torch.from_numpy(v) for k,v in module.load_weights().items()}
43 |     model.load_state_dict(state_dict)
44 |     res = model(fake_input)
45 | 
46 |     tvm.testing.assert_allclose(golden.detach().cpu().numpy(), res.detach().cpu().numpy(), rtol=5e-2)
47 | 
--------------------------------------------------------------------------------
/graph_execute/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | from tvm.contrib import graph_executor
8 | 
9 | import sys
10 | sys.path.append("..")
11 | from utils import array_des
12 | 
13 | if __name__=='__main__':
14 |     #prepare model and input
15 |     model = models.resnet18(pretrained=True)
16 |     shape_list = [("input0",(1,3,224,224))]
17 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
18 |     graph = torch.jit.trace(model,fake_input)
19 |     #main function
20 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
21 |     #optimize the mod
22 |     target = tvm.target.Target("llvm", host="llvm")
23 |     with tvm.transform.PassContext(opt_level=3):
24 |         lib = relay.build(mod, target=target, params=params)
25 |     #execute
26 |     dev = tvm.cpu(0)
27 |     m = graph_executor.GraphModule(lib["default"](dev))
28 |     # Set inputs
29 |     m.set_input("input0", tvm.nd.array(fake_input))
30 |     # Execute
31 |     m.run()
32 |     # Get outputs
33 |     res = m.get_output(0)
34 |     print("output "+array_des(res))
--------------------------------------------------------------------------------
/relay_codegen/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | 
8 | if __name__=='__main__':
9 |     #prepare model and input
10 |     model = models.resnet18(pretrained=True)
11 |     shape_list = [("input0",(1,3,224,224))]
12 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
13 |     graph = torch.jit.trace(model,fake_input)
14 |     #main function
15 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
16 |     #optimize the mod
17 |     target = tvm.target.Target("llvm", host="llvm")
18 |     with tvm.transform.PassContext(opt_level=3):
19 |         graph_json, mod, params = relay.build(mod, target=target, params=params)
--------------------------------------------------------------------------------
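The graph_json returned by relay.build above is the graph executor's JSON execution plan, so the result of codegen can be inspected directly. A hedged sketch (key names follow the graph-executor JSON layout; not part of the repo):

# illustrative: list the fused kernels recorded in graph_json
import json
for node in json.loads(graph_json)["nodes"]:
    if node["op"] == "tvm_op":
        print(node["name"], "->", node["attrs"]["func_name"])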
/relay_optimize/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | 
8 | import sys
9 | sys.path.append("..")
10 | from visualize import RelayVisualizer
11 | 
12 | def auto_optimize(mod,target,params):
13 |     mod,params=relay.optimize(mod, target=target, params=params)
14 |     visualizer=RelayVisualizer()
15 |     visualizer.visualize(mod,path="visualizes/optimized_mod.prototxt")
16 |     return mod,params
17 | 
18 | def debug_optimize(mod,target,params):
19 |     mod["main"]=relay.build_module.bind_params_by_name(mod["main"],params)
20 |     #add transform passes
21 |     seq = tvm.transform.Sequential(
22 |         [
23 |             relay.transform.SimplifyInference(),
24 |             relay.transform.BackwardFoldScaleAxis(),
25 |             relay.transform.ForwardFoldScaleAxis(),
26 |             relay.transform.FoldConstant(),
27 |             relay.transform.AlterOpLayout(),
28 |             relay.transform.FoldConstant(),
29 |             relay.transform.FuseOps(),
30 |         ]
31 |     )
32 |     with target:
33 |         mod=seq(mod)
34 | 
35 |     visualizer=RelayVisualizer()
36 |     visualizer.visualize(mod,path="visualizes/fuse_ops.prototxt")
37 |     return mod,params
38 | 
39 | if __name__=='__main__':
40 |     #prepare model and input
41 |     model = models.resnet18(pretrained=True)
42 |     shape_list = [("input0",(1,3,224,224))]
43 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
44 |     graph = torch.jit.trace(model,fake_input)
45 |     #main function
46 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
47 |     #optimize the mod
48 |     #step 1 create target
49 |     target = tvm.target.Target("llvm", host="llvm")
50 |     #step 2 create PassContext
51 |     with tvm.transform.PassContext(opt_level=3):
52 |         #step 3 optimize
53 |         mod,params=auto_optimize(mod,target,params)
54 |     print("optimize func "+str(mod["main"]))
--------------------------------------------------------------------------------
/relay_optimize/test_dependency_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | 
8 | if __name__=='__main__':
9 |     #prepare model and input
10 |     model = models.resnet18(pretrained=True)
11 |     shape_list = [("input0",(1,3,224,224))]
12 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
13 |     graph = torch.jit.trace(model,fake_input)
14 |     #main function
15 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
16 |     #dependency will be built during PrettyPrint
17 |     print("entry func "+str(mod["main"]))
--------------------------------------------------------------------------------
/relay_optimize/test_index_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.models as models
4 | 
5 | import tvm
6 | from tvm import relay
7 | 
8 | def debug_optimize(mod,target,params):
9 |     mod["main"]=relay.build_module.bind_params_by_name(mod["main"],params)
10 |     #add transform passes
11 |     seq = tvm.transform.Sequential(
12 |         [
13 |             relay.transform.SimplifyInference(),
14 |         ]
15 |     )
16 |     mod=seq(mod)
17 |     print("base func "+str(mod["main"]))
18 | 
19 |     seq = tvm.transform.Sequential(
20 |         [
21 |             relay.transform.FuseOps(),
22 |         ]
23 |     )
24 |     mod=seq(mod)
25 |     print("optimize func "+str(mod["main"]))
26 | 
27 |     return mod,params
28 | 
29 | if __name__=='__main__':
30 |     #prepare model and input
31 |     model = models.resnet18(pretrained=True)
32 |     shape_list = [("input0",(1,3,224,224))]
33 |     fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32'))
34 |     graph = torch.jit.trace(model,fake_input)
35 |     #main function
36 |     mod, params = relay.frontend.from_pytorch(graph, shape_list)
37 |     #optimize the mod
38 |     target = tvm.target.Target("llvm", host="llvm")
39 | 
with tvm.transform.PassContext(opt_level=3): 40 | mod,params=debug_optimize(mod,target,params) -------------------------------------------------------------------------------- /relay_optimize/visualize.cpp: -------------------------------------------------------------------------------- 1 | //for DependencyGraph, add to DependencyGraph class 2 | void visualize(const std::string& file_path){ 3 | std::unordered_map node_names; 4 | int cnt=0; 5 | std::ofstream out(file_path,std::ofstream::binary); 6 | if (out.is_open()){ 7 | std::streambuf *coutbuf = std::cout.rdbuf(); 8 | std::cout.rdbuf(out.rdbuf()); 9 | //build map 10 | std::unordered_map node_to_expr; 11 | for (auto expr_node : this->expr_node) { 12 | node_to_expr[expr_node.second] = expr_node.first; 13 | } 14 | //write the nodes 15 | std::cout<<"name : \"dependency\"\n"; 16 | for (auto it = this->post_dfs_order.rbegin(); it != this->post_dfs_order.rend(); ++it) { 17 | DependencyGraph::Node* n = *it; 18 | auto iit = n->parents.head; 19 | if(node_names.find(n)==node_names.end()){ 20 | node_names[n]="Node_"+std::to_string(cnt++); 21 | } 22 | std::cout<<"layer { name:\""<next) { 27 | std::cout<<" bottom : \""<value]<<"\"\n"; 28 | } 29 | } 30 | //add type 31 | Expr expr = node_to_expr[n]; 32 | if(!expr.defined()){ 33 | std::cout<<" type : \"Connect\"\n"; 34 | }else if (expr.as()){ 35 | auto call=Downcast(expr); 36 | auto op=Downcast(call->op); 37 | std::cout<<" type : \"Call_"<name<<"\"\n"; 38 | }else if(expr.as()){ 39 | std::cout<<" type : \"Function\"\n"; 40 | }else if(expr.as()){ 41 | auto node=Downcast(expr); 42 | std::cout<<" type : \"TupleGetItemNode\"\n"; 43 | }else if(expr.as()){ 44 | auto node=Downcast(expr); 45 | std::cout<<" type : \"Op_"<name<<"\"\n"; 46 | }else if(expr.as()){ 47 | auto node=Downcast(expr); 48 | std::cout<<" type : \"Var\""<<"\n"; 49 | }else{ 50 | std::cout<<" type : \"UNKNOWN\""<<"\n"; 51 | } 52 | //add attributes 53 | std::cout<<" layer_param : {\n"; 54 | std::cout<<" addr : \""<()){ 56 | auto node=Downcast(expr); 57 | std::cout<<" index : "<index<<"\n"; 58 | }else if(expr.as()){ 59 | auto node=Downcast(expr); 60 | std::cout<<" name_hint : \""<name_hint()<<"\"\n"; 61 | } 62 | std::cout<<" }\n}\n"; 63 | } 64 | std::cout.rdbuf(coutbuf); 65 | out.close(); 66 | } 67 | } 68 | 69 | //base utils 70 | std::string get_pattern_kind(const OpPatternKind& kind){ 71 | std::string kind_name="kOpaque"; 72 | switch(kind){ 73 | case kElemWise: 74 | kind_name="kElemWise"; 75 | break; 76 | case kBroadcast: 77 | kind_name="kBroadcast"; 78 | break; 79 | case kInjective: 80 | kind_name="kInjective"; 81 | break; 82 | case kCommReduce: 83 | kind_name="kCommReduce"; 84 | break; 85 | case kOutEWiseFusable: 86 | kind_name="kOutEWiseFusable"; 87 | break; 88 | case kTuple: 89 | kind_name="kTuple"; 90 | break; 91 | default: 92 | break; 93 | } 94 | return kind_name; 95 | } 96 | 97 | //for IndexedForwardGraph, add to IndexedForwardGraph class 98 | void visualize(const std::string& file_path){ 99 | std::ofstream out(file_path,std::ofstream::binary); 100 | if (out.is_open()){ 101 | std::streambuf *coutbuf = std::cout.rdbuf(); 102 | std::cout.rdbuf(out.rdbuf()); 103 | //write the nodes 104 | std::cout<<"name : \"dependency\"\n"; 105 | for (auto it = this->post_dfs_order.rbegin(); it != this->post_dfs_order.rend(); ++it) { 106 | Node* n = *it; 107 | auto iit = n->outputs.head; 108 | std::cout<<"layer { name:\"Node_"<index<<"\"\n"; 109 | //add topo information 110 | std::cout<<" top : \"Node_"<index<<"\"\n"; 111 | if(iit!=nullptr){ 112 | for (; iit 
!= nullptr; iit = iit->next) { 113 | std::cout<<" bottom : \"Node_"<value.node->index<<"\"\n"; 114 | } 115 | } 116 | //add type 117 | auto expr=GetRef(n->ref); 118 | auto pattern_name=get_pattern_kind(n->pattern); 119 | if(!expr.defined()){ 120 | std::cout<<" type : \"Connect["<()){ 122 | auto call=Downcast(expr); 123 | auto op=Downcast(call->op); 124 | std::cout<<" type : \"Call_"<name<<"["<()){ 126 | std::cout<<" type : \"Constant["<()){ 128 | std::cout<<" type : \"Function["<()){ 130 | auto node=Downcast(expr); 131 | std::cout<<" type : \"TupleGetItemNode["<()){ 133 | auto node=Downcast(expr); 134 | std::cout<<" type : \"Op_"<name<<"["<()){ 136 | auto node=Downcast(expr); 137 | std::cout<<" type : \"Var["<extern_ref ? "true" : "false")<<"\"\n"; 144 | if(expr.as()){ 145 | auto node=Downcast(expr); 146 | std::cout<<" index : "<index<<"\n"; 147 | }else if(expr.as()){ 148 | auto node=Downcast(expr); 149 | std::cout<<" tensor_type : \""<tensor_type()<<"\"\n"; 150 | }else if(expr.as()){ 151 | auto node=Downcast(expr); 152 | std::cout<<" name_hint : \""<name_hint()<<"\"\n"; 153 | } 154 | std::cout<<" }\n}\n"; 155 | } 156 | std::cout.rdbuf(coutbuf); 157 | out.close(); 158 | } 159 | } 160 | 161 | //for DominatorTree, add to DominatorTree class 162 | void visualize(const std::string& file_path){ 163 | std::ofstream out(file_path,std::ofstream::binary); 164 | if (out.is_open()){ 165 | std::streambuf *coutbuf = std::cout.rdbuf(); 166 | std::cout.rdbuf(out.rdbuf()); 167 | //write the nodes 168 | std::cout<<"name : \"dependency\"\n"; 169 | for (auto it = this->nodes.rbegin(); it != this->nodes.rend(); ++it) { 170 | Node* node = *it; 171 | IndexedForwardGraph::Node* gnode=node->gnode; 172 | std::cout<<"layer { name:\"Node_"<index<<"\"\n"; 173 | //add topo information 174 | std::cout<<" top : \"Node_"<index<<"\"\n"; 175 | if(node->parent!=nullptr){ 176 | std::cout<<" bottom : \"Node_"<parent->gnode->index<<"\"\n"; 177 | } 178 | //add type 179 | auto expr=GetRef(gnode->ref); 180 | auto pattern_name=get_pattern_kind(node->pattern); 181 | if(!expr.defined()){ 182 | std::cout<<" type : \"Connect\n["<()){ 184 | auto call=Downcast(expr); 185 | auto op=Downcast(call->op); 186 | std::cout<<" type : \"Call_"<name<<"["<()){ 188 | std::cout<<" type : \"Constant["<()){ 190 | std::cout<<" type : \"Function["<()){ 192 | std::cout<<" type : \"TupleGetItemNode["<()){ 194 | auto e_node=Downcast(expr); 195 | std::cout<<" type : \"Op_"<name<<"["<()){ 197 | auto e_node=Downcast(expr); 198 | std::cout<<" type : \"Var["<depth<<"\"\n"; 205 | if(expr.as()){ 206 | auto e_node=Downcast(expr); 207 | std::cout<<" index : "<index<<"\n"; 208 | }else if(expr.as()){ 209 | auto e_node=Downcast(expr); 210 | std::cout<<" tensor_type : \""<tensor_type()<<"\"\n"; 211 | }else if(expr.as()){ 212 | auto e_node=Downcast(expr); 213 | std::cout<<" name_hint : \""<name_hint()<<"\"\n"; 214 | } 215 | std::cout<<" }\n}\n"; 216 | } 217 | std::cout.rdbuf(coutbuf); 218 | out.close(); 219 | } 220 | } 221 | 222 | //for GraphPartitioner, add to GraphPartitioner class 223 | void visualize(const std::string& file_path){ 224 | std::unordered_map group_names; 225 | std::unordered_map ref_names; 226 | std::ofstream out(file_path,std::ofstream::binary); 227 | if (out.is_open()){ 228 | std::streambuf *coutbuf = std::cout.rdbuf(); 229 | std::cout.rdbuf(out.rdbuf()); 230 | //build names map 231 | for (int i=0;iroot_ref!=nullptr){ 235 | ref_names[group->root_ref]="Node_"+std::to_string(i); 236 | } 237 | } 238 | //write the nodes 239 | std::cout<<"name : 
\"graph_paritioner\"\n"; 240 | for (int i=0;iparent!=nullptr){ 246 | std::cout<<" bottom : \""<parent]<<"\"\n"; 247 | } 248 | //add type 249 | auto expr=GetRef(group->root_ref); 250 | auto pattern_name=get_pattern_kind(group->pattern); 251 | if(!expr.defined()){ 252 | std::cout<<" type : \"Connect\n["<()){ 254 | auto call=Downcast(expr); 255 | auto op=Downcast(call->op); 256 | std::cout<<" type : \"Call_"<name<<"["<()){ 258 | std::cout<<" type : \"Constant["<()){ 260 | std::cout<<" type : \"Function["<()){ 262 | std::cout<<" type : \"TupleGetItemNode["<()){ 264 | auto e_node=Downcast(expr); 265 | std::cout<<" type : \"Op_"<name<<"["<()){ 267 | auto e_node=Downcast(expr); 268 | std::cout<<" type : \"Var["<anchor_ref!=nullptr){ 275 | std::cout<<" anchor_ref : \""<anchor_ref]<<"\"\n"; 276 | } 277 | if(expr.as()){ 278 | auto e_node=Downcast(expr); 279 | std::cout<<" index : "<index<<"\n"; 280 | }else if(expr.as()){ 281 | auto e_node=Downcast(expr); 282 | std::cout<<" tensor_type : \""<tensor_type()<<"\"\n"; 283 | }else if(expr.as()){ 284 | auto e_node=Downcast(expr); 285 | std::cout<<" name_hint : \""<name_hint()<<"\"\n"; 286 | } 287 | std::cout<<" }\n}\n"; 288 | } 289 | std::cout.rdbuf(coutbuf); 290 | out.close(); 291 | } 292 | } 293 | -------------------------------------------------------------------------------- /relay_parse/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.models as models 4 | 5 | from tvm import relay 6 | 7 | if __name__=='__main__': 8 | #prepare model and input 9 | model = models.resnet18(pretrained=True) 10 | shape_list = [("input0",(1,3,224,224))] 11 | fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32')) 12 | graph = torch.jit.trace(model,fake_input) 13 | #main function 14 | mod, params = relay.frontend.from_pytorch(graph, shape_list) 15 | print("Parsed mod "+str(mod)) -------------------------------------------------------------------------------- /stmt_build/demo.py: -------------------------------------------------------------------------------- 1 | import tvm 2 | 3 | import sys 4 | sys.path.append("..") 5 | from visualize import PrimExprVisualizer 6 | 7 | if __name__=='__main__': 8 | visualizer=PrimExprVisualizer() 9 | n = tvm.te.var() 10 | A = tvm.te.placeholder((n, n), name='A') 11 | B = tvm.te.placeholder((n, n), name='B') 12 | #step 1 get ComputeOp 13 | C = tvm.te.compute((n, n), lambda i, j: A[i, j] + B[i, j], name='C') 14 | print("Compute struct "+str(C)) 15 | visualizer.visualize(C,"visualizes/normal_compute.prototxt") 16 | 17 | #step 2 get Schedule 18 | s = tvm.te.create_schedule(C.op) 19 | 20 | #step 3 build stmt 21 | mod=tvm.driver.build_module.form_irmodule(s,[A,B,C],"main",binds=None) 22 | visualizer.visualize(mod,"visualizes/normal_stmt.prototxt") 23 | visualizer.visualize(mod,"visualizes/normal_stmt_complete.prototxt",simple_mode=False) 24 | print("Stmt struct "+str(mod["main"])) -------------------------------------------------------------------------------- /stmt_build/test_cases.py: -------------------------------------------------------------------------------- 1 | import tvm 2 | 3 | import sys 4 | sys.path.append("..") 5 | from visualize import PrimExprVisualizer 6 | 7 | def cache_read(visualizer): 8 | print("\nTest cache_read") 9 | n = 1024 10 | dtype = "float32" 11 | A = tvm.te.placeholder((n, n), dtype=dtype, name='A') 12 | k = tvm.te.reduce_axis((0, n), name='k') 13 | B = tvm.te.compute((n,), lambda i: 
tvm.te.sum(A[i, k], axis=k), name='B') 14 | 15 | s = tvm.te.create_schedule(B.op) 16 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 17 | visualizer.visualize(mod,"visualizes/cache_read_before.prototxt") 18 | print(""+str(mod["main"])) 19 | 20 | #with cache_read 21 | AA = s.cache_read(A, "shared", [B]) 22 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 23 | visualizer.visualize(mod,"visualizes/cache_read_after.prototxt") 24 | print(""+str(mod["main"])) 25 | 26 | def compute_inline(visualizer): 27 | print("\nTest compute_inline") 28 | n = 1024 29 | k = 3 30 | pad = 2 31 | A = tvm.te.placeholder((n, n), name='A') 32 | W = tvm.te.placeholder((k, k), name='W') 33 | m = (n - k + 2 * pad) + 1 34 | Apad = tvm.te.compute((n + 2 * pad, n + 2 * pad), 35 | lambda yy, xx: tvm.te.if_then_else( 36 | tvm.te.all(yy >= pad, yy < pad + n, xx >= pad, xx < pad + n), 37 | A[yy - pad, xx - pad], tvm.tir.const(0., "float32")), 38 | name='Apad') 39 | 40 | ry = tvm.te.reduce_axis((0, k), name='ry') 41 | rx = tvm.te.reduce_axis((0, k), name='rx') 42 | 43 | B = tvm.te.compute((m, m), 44 | lambda yy, xx: 45 | tvm.te.sum(Apad[yy + ry, xx + rx] * W[ry, rx], 46 | axis=[ry, rx]), 47 | name='B') 48 | 49 | s = tvm.te.create_schedule(B.op) 50 | mod=tvm.driver.build_module.form_irmodule(s,[A,W,B],"main",binds=None) 51 | visualizer.visualize(mod,"visualizes/compute_inline_before.prototxt") 52 | print(""+str(mod["main"])) 53 | 54 | #with compute_inline 55 | s[Apad].compute_inline() 56 | mod=tvm.driver.build_module.form_irmodule(s,[A,W,B],"main",binds=None) 57 | visualizer.visualize(mod,"visualizes/compute_inline_after.prototxt") 58 | print(""+str(mod["main"])) 59 | 60 | def split(visualizer): 61 | print("\nTest split") 62 | n = 1024 63 | A = tvm.te.placeholder((n,), name='A') 64 | k = tvm.te.reduce_axis((0, n), name='k') 65 | B = tvm.te.compute((1,), lambda i: tvm.te.sum(A[k], axis=k), name='B') 66 | 67 | s = tvm.te.create_schedule(B.op) 68 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 69 | visualizer.visualize(mod,"visualizes/split_before.prototxt") 70 | print(""+str(mod["main"])) 71 | 72 | #with split 73 | ko, ki = s[B].split(B.op.reduce_axis[0], factor=32) 74 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 75 | visualizer.visualize(mod,"visualizes/split_after.prototxt") 76 | print(""+str(mod["main"])) 77 | 78 | def tensorize(visualizer): 79 | print("\nTest tensorize") 80 | N, M, L = 1024, 512, 64 81 | A = tvm.te.placeholder((N, L), name='A') 82 | B = tvm.te.placeholder((M, L), name='B') 83 | k = tvm.te.reduce_axis((0, L), name='k') 84 | C = tvm.te.compute((N, M), lambda i, j: tvm.te.sum(A[i, k] * B[j, k], axis=k), name='C') 85 | s = tvm.te.create_schedule(C.op) 86 | 87 | def intrin_gemv(m, l): 88 | a = tvm.te.placeholder((l,), name='a') 89 | b = tvm.te.placeholder((m, l), name='b') 90 | k = tvm.te.reduce_axis((0, l), name='k') 91 | c = tvm.te.compute((m,), lambda i: tvm.te.sum(a[k] * b[i, k], axis=k), name='c') 92 | Abuf = tvm.tir.decl_buffer(a.shape, a.dtype, name='A', offset_factor=1, strides=[1]) 93 | Bbuf = tvm.tir.decl_buffer(b.shape, b.dtype, name='B', offset_factor=1, strides=[tvm.te.var("s1"), 1]) 94 | Cbuf = tvm.tir.decl_buffer(c.shape, c.dtype, name='C', offset_factor=1, strides=[1]) 95 | 96 | def intrin_func(ins, outs): 97 | ib = tvm.tir.ir_builder.create() 98 | aa, bb = ins 99 | cc = outs[0] 100 | ib.emit(tvm.tir.call_extern("int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, 
bb.strides[0])) 101 | return ib.get() 102 | return tvm.te.decl_tensor_intrin(c.op, intrin_func, binds={a: Abuf, b: Bbuf, c: Cbuf}) 103 | 104 | factor = 16 105 | x, y = C.op.axis 106 | z, = C.op.reduce_axis 107 | yo, yi = s[C].split(y, factor=factor) 108 | s[C].reorder(x, yo, yi, z) 109 | 110 | mod=tvm.driver.build_module.form_irmodule(s,[A,B,C],"main",binds=None) 111 | visualizer.visualize(mod,"visualizes/tensorize_before.prototxt") 112 | print(""+str(mod["main"])) 113 | 114 | #with tensorize 115 | gemv = intrin_gemv(factor, L) 116 | s[C].tensorize(yi, gemv) 117 | mod=tvm.driver.build_module.form_irmodule(s,[A,B,C],"main",binds=None) 118 | visualizer.visualize(mod,"visualizes/tensorize_after.prototxt") 119 | print(""+str(mod["main"])) 120 | 121 | def bind(visualizer): 122 | print("\nTest bind") 123 | n = 1024 124 | A = tvm.te.placeholder((n,), name='A') 125 | k = tvm.te.reduce_axis((0, n), name='k') 126 | B = tvm.te.compute((1,), lambda i: tvm.te.sum(A[k], axis=k), name='B') 127 | s = tvm.te.create_schedule(B.op) 128 | ko, ki = s[B].split(B.op.reduce_axis[0], factor=32) 129 | 130 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 131 | visualizer.visualize(mod,"visualizes/bind_before.prototxt") 132 | print(""+str(mod["main"])) 133 | 134 | #with bind 135 | s[B].bind(ko, tvm.te.thread_axis("blockIdx.x")) 136 | s[B].bind(ki, tvm.te.thread_axis("threadIdx.x")) 137 | mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 138 | visualizer.visualize(mod,"visualizes/bind_after.prototxt") 139 | print(""+str(mod["main"])) 140 | 141 | if __name__=='__main__': 142 | visualizer=PrimExprVisualizer() 143 | cache_read(visualizer) 144 | compute_inline(visualizer) 145 | split(visualizer) 146 | bind(visualizer) 147 | tensorize(visualizer) -------------------------------------------------------------------------------- /stmt_build/visualizes/bind_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"blockIdx.x" 15 | type:"var(iter)" 16 | top:"blockIdx.x" 17 | layer_param { 18 | idx:1 19 | dtype:int32 20 | } 21 | } 22 | layer { 23 | name:"Node_2" 24 | type:"itervar(node)" 25 | top:"Node_2" 26 | bottom:"blockIdx.x" 27 | layer_param { 28 | idx:2 29 | dom:"None" 30 | iter_type:"1" 31 | thread_tag:"blockIdx.x" 32 | } 33 | } 34 | layer { 35 | name:"threadIdx.x" 36 | type:"var(iter)" 37 | top:"threadIdx.x" 38 | layer_param { 39 | idx:3 40 | dtype:int32 41 | } 42 | } 43 | layer { 44 | name:"Node_4" 45 | type:"itervar(node)" 46 | top:"Node_4" 47 | bottom:"threadIdx.x" 48 | layer_param { 49 | idx:4 50 | dom:"None" 51 | iter_type:"1" 52 | thread_tag:"threadIdx.x" 53 | } 54 | } 55 | layer { 56 | name:"reduce_temp0" 57 | type:"var(node)" 58 | top:"reduce_temp0" 59 | layer_param { 60 | idx:5 61 | dtype:handle 62 | } 63 | } 64 | layer { 65 | name:"x" 66 | type:"var(reduce_l)" 67 | top:"x" 68 | layer_param { 69 | idx:6 70 | dtype:float32 71 | } 72 | } 73 | layer { 74 | name:"y" 75 | type:"var(reduce_r)" 76 | top:"y" 77 | layer_param { 78 | idx:7 79 | dtype:float32 80 | } 81 | } 82 | layer { 83 | name:"Node_8" 84 | type:"add(reduce_res)" 85 | top:"Node_8" 86 | bottom:"x" 87 | bottom:"y" 88 | layer_param { 89 | idx:8 90 | } 91 | } 92 | layer { 93 | name:"Node_9" 94 | type:"float(reduce_ind)" 95 | top:"Node_9" 96 | layer_param { 97 | idx:9 98 | 
value:0.0 99 | dtype:float32 100 | } 101 | } 102 | layer { 103 | name:"Node_10" 104 | type:"common_reducer(node)" 105 | top:"Node_10" 106 | bottom:"x" 107 | bottom:"y" 108 | bottom:"Node_8" 109 | bottom:"Node_9" 110 | layer_param { 111 | idx:10 112 | result_00:"[(x + y)]" 113 | } 114 | } 115 | layer { 116 | name:"Node_11" 117 | type:"int" 118 | top:"Node_11" 119 | layer_param { 120 | idx:11 121 | value:1 122 | dtype:uint32 123 | } 124 | } 125 | layer { 126 | name:"A" 127 | type:"buffer(buffer)" 128 | top:"A" 129 | layer_param { 130 | idx:12 131 | buffer_name:"A" 132 | shape:[1024] 133 | dtype:float32 134 | } 135 | } 136 | layer { 137 | name:"Node_13" 138 | type:"int(b)" 139 | top:"Node_13" 140 | layer_param { 141 | idx:13 142 | value:32 143 | dtype:int32 144 | } 145 | } 146 | layer { 147 | name:"Node_14" 148 | type:"mul(b)" 149 | top:"Node_14" 150 | bottom:"blockIdx.x" 151 | bottom:"Node_13" 152 | layer_param { 153 | idx:14 154 | } 155 | } 156 | layer { 157 | name:"Node_15" 158 | type:"add(indice)" 159 | top:"Node_15" 160 | bottom:"threadIdx.x" 161 | bottom:"Node_14" 162 | layer_param { 163 | idx:15 164 | } 165 | } 166 | layer { 167 | name:"Node_16" 168 | type:"buffer_load" 169 | top:"Node_16" 170 | bottom:"A" 171 | bottom:"Node_15" 172 | layer_param { 173 | idx:16 174 | } 175 | } 176 | layer { 177 | name:"Node_17" 178 | type:"Call_tir.tvm_thread_allreduce(value)" 179 | top:"Node_17" 180 | bottom:"Node_11" 181 | bottom:"Node_16" 182 | bottom:"Node_11" 183 | bottom:"reduce_temp0" 184 | bottom:"blockIdx.x" 185 | bottom:"threadIdx.x" 186 | layer_param { 187 | idx:17 188 | } 189 | } 190 | layer { 191 | name:"Node_18" 192 | type:"evaluate" 193 | top:"Node_18" 194 | bottom:"Node_17" 195 | layer_param { 196 | idx:18 197 | } 198 | } 199 | layer { 200 | name:"Node_19" 201 | type:"attribute(seq_0)" 202 | top:"Node_19" 203 | bottom:"Node_10" 204 | bottom:"Node_18" 205 | layer_param { 206 | idx:19 207 | attr_key:reduce_scope 208 | body_00:"tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 209 | value_00:"@tir.reinterpret(0u64, dtype=handle)" 210 | } 211 | } 212 | layer { 213 | name:"Node_20" 214 | type:"int(load_index)" 215 | top:"Node_20" 216 | layer_param { 217 | idx:20 218 | value:0 219 | dtype:int32 220 | } 221 | } 222 | layer { 223 | name:"Node_21" 224 | type:"load(value)" 225 | top:"Node_21" 226 | bottom:"reduce_temp0" 227 | bottom:"Node_20" 228 | layer_param { 229 | idx:21 230 | predicate_00:"True" 231 | body_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 232 | } 233 | } 234 | layer { 235 | name:"Node_22" 236 | type:"buffer_store(true)" 237 | top:"Node_22" 238 | bottom:"B" 239 | bottom:"Node_21" 240 | layer_param { 241 | idx:22 242 | value_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 243 | indices_00:"[0]" 244 | } 245 | } 246 | layer { 247 | name:"Node_23" 248 | type:"ifthenelse(seq_1)" 249 | top:"Node_23" 250 | bottom:"Node_22" 251 | layer_param { 252 | idx:23 253 | condition:"True" 254 | } 255 | } 256 | layer { 257 | name:"Node_24" 258 | type:"seq" 259 | top:"Node_24" 260 | bottom:"Node_19" 261 | bottom:"Node_23" 262 | layer_param { 263 | idx:24 264 | seq_00:"[// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 265 | seq_01:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 266 | seq_02:" , if ((bool)1)" 267 | seq_03:" B[0] = reduce_temp0[0]" 268 | seq_04:" ]" 269 | } 270 | 
} 271 | layer { 272 | name:"Node_25" 273 | type:"allocate" 274 | top:"Node_25" 275 | bottom:"Node_24" 276 | layer_param { 277 | idx:25 278 | dtype:float32 279 | extents:"[1]" 280 | condition:"True" 281 | body_00:"// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 282 | body_01:"tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 283 | body_02:" if ((bool)1)" 284 | body_03:" B[0] = reduce_temp0[0]" 285 | } 286 | } 287 | layer { 288 | name:"Node_26" 289 | type:"attribute" 290 | top:"Node_26" 291 | bottom:"reduce_temp0" 292 | bottom:"Node_25" 293 | layer_param { 294 | idx:26 295 | attr_key:storage_scope 296 | body_00:"allocate reduce_temp0[float32 * 1]" 297 | body_01:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 298 | body_02:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 299 | body_03:" if ((bool)1)" 300 | body_04:" B[0] = reduce_temp0[0]" 301 | value_00:"'local'" 302 | } 303 | } 304 | layer { 305 | name:"Node_27" 306 | type:"attribute" 307 | top:"Node_27" 308 | bottom:"Node_4" 309 | bottom:"Node_26" 310 | layer_param { 311 | idx:27 312 | attr_key:thread_extent 313 | body_00:"// attr [reduce_temp0] storage_scope = 'local'" 314 | body_01:"allocate reduce_temp0[float32 * 1]" 315 | body_02:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 316 | body_03:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 317 | body_04:" if ((bool)1)" 318 | body_05:" B[0] = reduce_temp0[0]" 319 | value_00:"32" 320 | } 321 | } 322 | layer { 323 | name:"Node_28" 324 | type:"attribute" 325 | top:"Node_28" 326 | bottom:"Node_2" 327 | bottom:"Node_27" 328 | layer_param { 329 | idx:28 330 | attr_key:thread_extent 331 | body_00:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 332 | body_01:"// attr [reduce_temp0] storage_scope = 'local'" 333 | body_02:"allocate reduce_temp0[float32 * 1]" 334 | body_03:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 335 | body_04:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 336 | body_05:" if ((bool)1)" 337 | body_06:" B[0] = reduce_temp0[0]" 338 | value_00:"32" 339 | } 340 | } 341 | layer { 342 | name:"Node_29" 343 | type:"buffer_realize" 344 | top:"Node_29" 345 | bottom:"Node_28" 346 | layer_param { 347 | idx:29 348 | condition:True 349 | body_00:"// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32" 350 | body_01:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 351 | body_02:"// attr [reduce_temp0] storage_scope = 'local'" 352 | body_03:"allocate reduce_temp0[float32 * 1]" 353 | body_04:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 354 | body_05:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 355 | body_06:" if ((bool)1)" 356 | body_07:" B[0] = reduce_temp0[0]" 357 | bounds_00:"[range(min=0, ext=1)]" 358 | } 359 | } 360 | layer { 361 | name:"Node_30" 362 | type:"attribute" 363 | 
top:"Node_30" 364 | bottom:"B" 365 | bottom:"Node_29" 366 | layer_param { 367 | idx:30 368 | attr_key:realize_scope 369 | body_00:"buffer_realize B([0, 1])" 370 | body_01:" // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32" 371 | body_02:" // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 372 | body_03:" // attr [reduce_temp0] storage_scope = 'local'" 373 | body_04:" allocate reduce_temp0[float32 * 1]" 374 | body_05:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 375 | body_06:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 376 | body_07:" if ((bool)1)" 377 | body_08:" B[0] = reduce_temp0[0]" 378 | value_00:"''" 379 | } 380 | } 381 | layer { 382 | name:"Node_31" 383 | type:"primfunc" 384 | top:"Node_31" 385 | bottom:"Node_30" 386 | layer_param { 387 | idx:31 388 | body_00:"// attr [buffer(B, 0x7ff979c2a0a0)] realize_scope = ''" 389 | body_01:"buffer_realize B([0, 1])" 390 | body_02:" // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32" 391 | body_03:" // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 392 | body_04:" // attr [reduce_temp0] storage_scope = 'local'" 393 | body_05:" allocate reduce_temp0[float32 * 1]" 394 | body_06:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 395 | body_07:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 396 | body_08:" if ((bool)1)" 397 | body_09:" B[0] = reduce_temp0[0]" 398 | } 399 | } 400 | -------------------------------------------------------------------------------- /stmt_build/visualizes/bind_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"B" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[0]" 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"int(indice)" 38 | top:"Node_3" 39 | layer_param { 40 | idx:3 41 | value:0 42 | dtype:int32 43 | } 44 | } 45 | layer { 46 | name:"Node_4" 47 | type:"buffer_load(a)" 48 | top:"Node_4" 49 | bottom:"B" 50 | bottom:"Node_3" 51 | layer_param { 52 | idx:4 53 | } 54 | } 55 | layer { 56 | name:"A" 57 | type:"buffer(buffer)" 58 | top:"A" 59 | layer_param { 60 | idx:5 61 | buffer_name:"A" 62 | shape:[1024] 63 | dtype:float32 64 | } 65 | } 66 | layer { 67 | name:"k.inner" 68 | type:"var(a)" 69 | top:"k.inner" 70 | layer_param { 71 | idx:6 72 | dtype:int32 73 | } 74 | } 75 | layer { 76 | name:"k.outer" 77 | type:"var(a)" 78 | top:"k.outer" 79 | layer_param { 80 | idx:7 81 | dtype:int32 82 | } 83 | } 84 | layer { 85 | name:"Node_8" 86 | type:"int(b)" 87 | top:"Node_8" 88 | layer_param { 89 | idx:8 90 | value:32 91 | dtype:int32 92 | } 93 | } 94 | layer { 95 | name:"Node_9" 96 | type:"mul(b)" 97 | top:"Node_9" 98 | bottom:"k.outer" 99 | bottom:"Node_8" 100 | layer_param { 101 | idx:9 102 | } 103 | } 104 | layer { 105 | name:"Node_10" 106 | 
type:"add(indice)" 107 | top:"Node_10" 108 | bottom:"k.inner" 109 | bottom:"Node_9" 110 | layer_param { 111 | idx:10 112 | } 113 | } 114 | layer { 115 | name:"Node_11" 116 | type:"buffer_load(b)" 117 | top:"Node_11" 118 | bottom:"A" 119 | bottom:"Node_10" 120 | layer_param { 121 | idx:11 122 | } 123 | } 124 | layer { 125 | name:"Node_12" 126 | type:"add(value)" 127 | top:"Node_12" 128 | bottom:"Node_4" 129 | bottom:"Node_11" 130 | layer_param { 131 | idx:12 132 | } 133 | } 134 | layer { 135 | name:"Node_13" 136 | type:"buffer_store" 137 | top:"Node_13" 138 | bottom:"B" 139 | bottom:"Node_12" 140 | layer_param { 141 | idx:13 142 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1], [])[0] + A: Buffer(A_1: Pointer(float32), float32, [1024], [])[(k.inner: int32 + (k.outer: int32*32))])" 143 | indices_00:"[0]" 144 | } 145 | } 146 | layer { 147 | name:"Node_14" 148 | type:"for" 149 | top:"Node_14" 150 | bottom:"Node_13" 151 | layer_param { 152 | idx:14 153 | kind:0 154 | body_00:"B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 155 | } 156 | } 157 | layer { 158 | name:"Node_15" 159 | type:"for(seq_1)" 160 | top:"Node_15" 161 | bottom:"Node_14" 162 | layer_param { 163 | idx:15 164 | kind:0 165 | body_00:"for (k.inner, 0, 32)" 166 | body_01:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 167 | } 168 | } 169 | layer { 170 | name:"Node_16" 171 | type:"seq" 172 | top:"Node_16" 173 | bottom:"Node_2" 174 | bottom:"Node_15" 175 | layer_param { 176 | idx:16 177 | seq_00:"[B[0] = 0f" 178 | seq_01:" , for (k.outer, 0, 32)" 179 | seq_02:" for (k.inner, 0, 32)" 180 | seq_03:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 181 | seq_04:" ]" 182 | } 183 | } 184 | layer { 185 | name:"Node_17" 186 | type:"buffer_realize" 187 | top:"Node_17" 188 | bottom:"Node_16" 189 | layer_param { 190 | idx:17 191 | condition:True 192 | body_00:"B[0] = 0f" 193 | body_01:" for (k.outer, 0, 32)" 194 | body_02:" for (k.inner, 0, 32)" 195 | body_03:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 196 | bounds_00:"[range(min=0, ext=1)]" 197 | } 198 | } 199 | layer { 200 | name:"Node_18" 201 | type:"attribute" 202 | top:"Node_18" 203 | bottom:"B" 204 | bottom:"Node_17" 205 | layer_param { 206 | idx:18 207 | attr_key:realize_scope 208 | body_00:"buffer_realize B([0, 1])" 209 | body_01:" B[0] = 0f" 210 | body_02:" for (k.outer, 0, 32)" 211 | body_03:" for (k.inner, 0, 32)" 212 | body_04:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 213 | value_00:"''" 214 | } 215 | } 216 | layer { 217 | name:"Node_19" 218 | type:"primfunc" 219 | top:"Node_19" 220 | bottom:"Node_18" 221 | layer_param { 222 | idx:19 223 | body_00:"// attr [buffer(B, 0x7ff97c52aaf0)] realize_scope = ''" 224 | body_01:"buffer_realize B([0, 1])" 225 | body_02:" B[0] = 0f" 226 | body_03:" for (k.outer, 0, 32)" 227 | body_04:" for (k.inner, 0, 32)" 228 | body_05:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /stmt_build/visualizes/cache_read_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A.shared" 4 | type:"buffer(node)" 5 | top:"A.shared" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A.shared" 9 | shape:[1024, 1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"A" 15 | type:"buffer(buffer)" 16 | top:"A" 17 | layer_param { 18 | idx:1 19 | buffer_name:"A" 20 | shape:[1024, 1024] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"ax0" 26 | type:"var(indice)" 27 | 
top:"ax0" 28 | layer_param { 29 | idx:2 30 | dtype:int32 31 | } 32 | } 33 | layer { 34 | name:"ax1" 35 | type:"var(indice)" 36 | top:"ax1" 37 | layer_param { 38 | idx:3 39 | dtype:int32 40 | } 41 | } 42 | layer { 43 | name:"Node_4" 44 | type:"buffer_load(value)" 45 | top:"Node_4" 46 | bottom:"A" 47 | bottom:"ax0" 48 | bottom:"ax1" 49 | layer_param { 50 | idx:4 51 | } 52 | } 53 | layer { 54 | name:"Node_5" 55 | type:"buffer_store" 56 | top:"Node_5" 57 | bottom:"A.shared" 58 | bottom:"Node_4" 59 | layer_param { 60 | idx:5 61 | value_00:"A: Buffer(A_1: Pointer(float32), float32, [1024, 1024], [])[ax0: int32, ax1: int32]" 62 | indices_00:"[ax0, ax1]" 63 | } 64 | } 65 | layer { 66 | name:"Node_6" 67 | type:"for" 68 | top:"Node_6" 69 | bottom:"Node_5" 70 | layer_param { 71 | idx:6 72 | kind:0 73 | body_00:"A.shared[ax0, ax1] = A[ax0, ax1]" 74 | } 75 | } 76 | layer { 77 | name:"Node_7" 78 | type:"for(seq_0)" 79 | top:"Node_7" 80 | bottom:"Node_6" 81 | layer_param { 82 | idx:7 83 | kind:0 84 | body_00:"for (ax1, 0, 1024)" 85 | body_01:" A.shared[ax0, ax1] = A[ax0, ax1]" 86 | } 87 | } 88 | layer { 89 | name:"B" 90 | type:"buffer(node)" 91 | top:"B" 92 | layer_param { 93 | idx:8 94 | buffer_name:"B" 95 | shape:[1024] 96 | dtype:float32 97 | } 98 | } 99 | layer { 100 | name:"Node_9" 101 | type:"float(value)" 102 | top:"Node_9" 103 | layer_param { 104 | idx:9 105 | value:0.0 106 | dtype:float32 107 | } 108 | } 109 | layer { 110 | name:"Node_10" 111 | type:"buffer_store(seq_0)" 112 | top:"Node_10" 113 | bottom:"B" 114 | bottom:"Node_9" 115 | layer_param { 116 | idx:10 117 | value_00:"0f32" 118 | indices_00:"[i]" 119 | } 120 | } 121 | layer { 122 | name:"i" 123 | type:"var(indice)" 124 | top:"i" 125 | layer_param { 126 | idx:11 127 | dtype:int32 128 | } 129 | } 130 | layer { 131 | name:"Node_12" 132 | type:"buffer_load(a)" 133 | top:"Node_12" 134 | bottom:"B" 135 | bottom:"i" 136 | layer_param { 137 | idx:12 138 | } 139 | } 140 | layer { 141 | name:"k" 142 | type:"var(indice)" 143 | top:"k" 144 | layer_param { 145 | idx:13 146 | dtype:int32 147 | } 148 | } 149 | layer { 150 | name:"Node_14" 151 | type:"buffer_load(b)" 152 | top:"Node_14" 153 | bottom:"A.shared" 154 | bottom:"i" 155 | bottom:"k" 156 | layer_param { 157 | idx:14 158 | } 159 | } 160 | layer { 161 | name:"Node_15" 162 | type:"add(value)" 163 | top:"Node_15" 164 | bottom:"Node_12" 165 | bottom:"Node_14" 166 | layer_param { 167 | idx:15 168 | } 169 | } 170 | layer { 171 | name:"Node_16" 172 | type:"buffer_store" 173 | top:"Node_16" 174 | bottom:"B" 175 | bottom:"Node_15" 176 | layer_param { 177 | idx:16 178 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1024], [])[i: int32] + A.shared: Buffer(A.shared_1: Pointer(float32), float32, [1024, 1024], [])[i, k: int32])" 179 | indices_00:"[i]" 180 | } 181 | } 182 | layer { 183 | name:"Node_17" 184 | type:"for(seq_1)" 185 | top:"Node_17" 186 | bottom:"Node_16" 187 | layer_param { 188 | idx:17 189 | kind:0 190 | body_00:"B[i] = (B[i] + A.shared[i, k])" 191 | } 192 | } 193 | layer { 194 | name:"Node_18" 195 | type:"seq" 196 | top:"Node_18" 197 | bottom:"Node_10" 198 | bottom:"Node_17" 199 | layer_param { 200 | idx:18 201 | seq_00:"[B[i] = 0f" 202 | seq_01:" , for (k, 0, 1024)" 203 | seq_02:" B[i] = (B[i] + A.shared[i, k])" 204 | seq_03:" ]" 205 | } 206 | } 207 | layer { 208 | name:"Node_19" 209 | type:"for" 210 | top:"Node_19" 211 | bottom:"Node_18" 212 | layer_param { 213 | idx:19 214 | kind:0 215 | body_00:"B[i] = 0f" 216 | body_01:" for (k, 0, 1024)" 217 | body_02:" B[i] = (B[i] + 
A.shared[i, k])" 218 | } 219 | } 220 | layer { 221 | name:"Node_20" 222 | type:"buffer_realize" 223 | top:"Node_20" 224 | bottom:"Node_19" 225 | layer_param { 226 | idx:20 227 | condition:True 228 | body_00:"for (i, 0, 1024)" 229 | body_01:" B[i] = 0f" 230 | body_02:" for (k, 0, 1024)" 231 | body_03:" B[i] = (B[i] + A.shared[i, k])" 232 | bounds_00:"[range(min=0, ext=1024)]" 233 | } 234 | } 235 | layer { 236 | name:"Node_21" 237 | type:"attribute(seq_1)" 238 | top:"Node_21" 239 | bottom:"B" 240 | bottom:"Node_20" 241 | layer_param { 242 | idx:21 243 | attr_key:realize_scope 244 | body_00:"buffer_realize B([0, 1024])" 245 | body_01:" for (i, 0, 1024)" 246 | body_02:" B[i] = 0f" 247 | body_03:" for (k, 0, 1024)" 248 | body_04:" B[i] = (B[i] + A.shared[i, k])" 249 | value_00:"''" 250 | } 251 | } 252 | layer { 253 | name:"Node_22" 254 | type:"seq" 255 | top:"Node_22" 256 | bottom:"Node_7" 257 | bottom:"Node_21" 258 | layer_param { 259 | idx:22 260 | seq_00:"[for (ax0, 0, 1024)" 261 | seq_01:" for (ax1, 0, 1024)" 262 | seq_02:" A.shared[ax0, ax1] = A[ax0, ax1]" 263 | seq_03:" , // attr [buffer(B, 0x7ff97c733100)] realize_scope = ''" 264 | seq_04:" buffer_realize B([0, 1024])" 265 | seq_05:" for (i, 0, 1024)" 266 | seq_06:" B[i] = 0f" 267 | seq_07:" for (k, 0, 1024)" 268 | seq_08:" B[i] = (B[i] + A.shared[i, k])" 269 | seq_09:" ]" 270 | } 271 | } 272 | layer { 273 | name:"Node_23" 274 | type:"buffer_realize" 275 | top:"Node_23" 276 | bottom:"Node_22" 277 | layer_param { 278 | idx:23 279 | condition:True 280 | body_00:"for (ax0, 0, 1024)" 281 | body_01:" for (ax1, 0, 1024)" 282 | body_02:" A.shared[ax0, ax1] = A[ax0, ax1]" 283 | body_03:" // attr [buffer(B, 0x7ff97c733100)] realize_scope = ''" 284 | body_04:" buffer_realize B([0, 1024])" 285 | body_05:" for (i, 0, 1024)" 286 | body_06:" B[i] = 0f" 287 | body_07:" for (k, 0, 1024)" 288 | body_08:" B[i] = (B[i] + A.shared[i, k])" 289 | bounds_00:"[range(min=0, ext=1024), range(min=0, ext=1024)]" 290 | } 291 | } 292 | layer { 293 | name:"Node_24" 294 | type:"attribute" 295 | top:"Node_24" 296 | bottom:"A.shared" 297 | bottom:"Node_23" 298 | layer_param { 299 | idx:24 300 | attr_key:realize_scope 301 | body_00:"buffer_realize A.shared([0, 1024], [0, 1024])" 302 | body_01:" for (ax0, 0, 1024)" 303 | body_02:" for (ax1, 0, 1024)" 304 | body_03:" A.shared[ax0, ax1] = A[ax0, ax1]" 305 | body_04:" // attr [buffer(B, 0x7ff97c733100)] realize_scope = ''" 306 | body_05:" buffer_realize B([0, 1024])" 307 | body_06:" for (i, 0, 1024)" 308 | body_07:" B[i] = 0f" 309 | body_08:" for (k, 0, 1024)" 310 | body_09:" B[i] = (B[i] + A.shared[i, k])" 311 | value_00:"'shared'" 312 | } 313 | } 314 | layer { 315 | name:"Node_25" 316 | type:"primfunc" 317 | top:"Node_25" 318 | bottom:"Node_24" 319 | layer_param { 320 | idx:25 321 | body_00:"// attr [buffer(A.shared, 0x7ff97c732c10)] realize_scope = 'shared'" 322 | body_01:"buffer_realize A.shared([0, 1024], [0, 1024])" 323 | body_02:" for (ax0, 0, 1024)" 324 | body_03:" for (ax1, 0, 1024)" 325 | body_04:" A.shared[ax0, ax1] = A[ax0, ax1]" 326 | body_05:" // attr [buffer(B, 0x7ff97c733100)] realize_scope = ''" 327 | body_06:" buffer_realize B([0, 1024])" 328 | body_07:" for (i, 0, 1024)" 329 | body_08:" B[i] = 0f" 330 | body_09:" for (k, 0, 1024)" 331 | body_10:" B[i] = (B[i] + A.shared[i, k])" 332 | } 333 | } 334 | -------------------------------------------------------------------------------- /stmt_build/visualizes/cache_read_before.prototxt: 
-------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"B" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[i]" 33 | } 34 | } 35 | layer { 36 | name:"i" 37 | type:"var(indice)" 38 | top:"i" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"Node_4" 46 | type:"buffer_load(a)" 47 | top:"Node_4" 48 | bottom:"B" 49 | bottom:"i" 50 | layer_param { 51 | idx:4 52 | } 53 | } 54 | layer { 55 | name:"A" 56 | type:"buffer(buffer)" 57 | top:"A" 58 | layer_param { 59 | idx:5 60 | buffer_name:"A" 61 | shape:[1024, 1024] 62 | dtype:float32 63 | } 64 | } 65 | layer { 66 | name:"k" 67 | type:"var(indice)" 68 | top:"k" 69 | layer_param { 70 | idx:6 71 | dtype:int32 72 | } 73 | } 74 | layer { 75 | name:"Node_7" 76 | type:"buffer_load(b)" 77 | top:"Node_7" 78 | bottom:"A" 79 | bottom:"i" 80 | bottom:"k" 81 | layer_param { 82 | idx:7 83 | } 84 | } 85 | layer { 86 | name:"Node_8" 87 | type:"add(value)" 88 | top:"Node_8" 89 | bottom:"Node_4" 90 | bottom:"Node_7" 91 | layer_param { 92 | idx:8 93 | } 94 | } 95 | layer { 96 | name:"Node_9" 97 | type:"buffer_store" 98 | top:"Node_9" 99 | bottom:"B" 100 | bottom:"Node_8" 101 | layer_param { 102 | idx:9 103 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1024], [])[i: int32] + A: Buffer(A_1: Pointer(float32), float32, [1024, 1024], [])[i, k: int32])" 104 | indices_00:"[i]" 105 | } 106 | } 107 | layer { 108 | name:"Node_10" 109 | type:"for(seq_1)" 110 | top:"Node_10" 111 | bottom:"Node_9" 112 | layer_param { 113 | idx:10 114 | kind:0 115 | body_00:"B[i] = (B[i] + A[i, k])" 116 | } 117 | } 118 | layer { 119 | name:"Node_11" 120 | type:"seq" 121 | top:"Node_11" 122 | bottom:"Node_2" 123 | bottom:"Node_10" 124 | layer_param { 125 | idx:11 126 | seq_00:"[B[i] = 0f" 127 | seq_01:" , for (k, 0, 1024)" 128 | seq_02:" B[i] = (B[i] + A[i, k])" 129 | seq_03:" ]" 130 | } 131 | } 132 | layer { 133 | name:"Node_12" 134 | type:"for" 135 | top:"Node_12" 136 | bottom:"Node_11" 137 | layer_param { 138 | idx:12 139 | kind:0 140 | body_00:"B[i] = 0f" 141 | body_01:" for (k, 0, 1024)" 142 | body_02:" B[i] = (B[i] + A[i, k])" 143 | } 144 | } 145 | layer { 146 | name:"Node_13" 147 | type:"buffer_realize" 148 | top:"Node_13" 149 | bottom:"Node_12" 150 | layer_param { 151 | idx:13 152 | condition:True 153 | body_00:"for (i, 0, 1024)" 154 | body_01:" B[i] = 0f" 155 | body_02:" for (k, 0, 1024)" 156 | body_03:" B[i] = (B[i] + A[i, k])" 157 | bounds_00:"[range(min=0, ext=1024)]" 158 | } 159 | } 160 | layer { 161 | name:"Node_14" 162 | type:"attribute" 163 | top:"Node_14" 164 | bottom:"B" 165 | bottom:"Node_13" 166 | layer_param { 167 | idx:14 168 | attr_key:realize_scope 169 | body_00:"buffer_realize B([0, 1024])" 170 | body_01:" for (i, 0, 1024)" 171 | body_02:" B[i] = 0f" 172 | body_03:" for (k, 0, 1024)" 173 | body_04:" B[i] = (B[i] + A[i, k])" 174 | value_00:"''" 175 | } 176 | } 177 | layer { 178 | name:"Node_15" 179 | type:"primfunc" 180 | top:"Node_15" 181 | bottom:"Node_14" 182 | layer_param { 183 | idx:15 184 | body_00:"// attr [buffer(B, 0x7ff97c72e280)] 
realize_scope = ''" 185 | body_01:"buffer_realize B([0, 1024])" 186 | body_02:" for (i, 0, 1024)" 187 | body_03:" B[i] = 0f" 188 | body_04:" for (k, 0, 1024)" 189 | body_05:" B[i] = (B[i] + A[i, k])" 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /stmt_build/visualizes/compute_inline_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1026, 1026] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"B" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[yy, xx]" 33 | } 34 | } 35 | layer { 36 | name:"yy" 37 | type:"var(indice)" 38 | top:"yy" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"xx" 46 | type:"var(indice)" 47 | top:"xx" 48 | layer_param { 49 | idx:4 50 | dtype:int32 51 | } 52 | } 53 | layer { 54 | name:"Node_5" 55 | type:"buffer_load(a)" 56 | top:"Node_5" 57 | bottom:"B" 58 | bottom:"yy" 59 | bottom:"xx" 60 | layer_param { 61 | idx:5 62 | } 63 | } 64 | layer { 65 | name:"ry" 66 | type:"var(b)" 67 | top:"ry" 68 | layer_param { 69 | idx:6 70 | dtype:int32 71 | } 72 | } 73 | layer { 74 | name:"Node_7" 75 | type:"add(a)" 76 | top:"Node_7" 77 | bottom:"yy" 78 | bottom:"ry" 79 | layer_param { 80 | idx:7 81 | } 82 | } 83 | layer { 84 | name:"Node_8" 85 | type:"int(b)" 86 | top:"Node_8" 87 | layer_param { 88 | idx:8 89 | value:2 90 | dtype:int32 91 | } 92 | } 93 | layer { 94 | name:"Node_9" 95 | type:"greater_equal(a)" 96 | top:"Node_9" 97 | bottom:"Node_7" 98 | bottom:"Node_8" 99 | layer_param { 100 | idx:9 101 | } 102 | } 103 | layer { 104 | name:"Node_10" 105 | type:"add(a)" 106 | top:"Node_10" 107 | bottom:"yy" 108 | bottom:"ry" 109 | layer_param { 110 | idx:10 111 | } 112 | } 113 | layer { 114 | name:"Node_11" 115 | type:"int(b)" 116 | top:"Node_11" 117 | layer_param { 118 | idx:11 119 | value:1026 120 | dtype:int32 121 | } 122 | } 123 | layer { 124 | name:"Node_12" 125 | type:"less_than(b)" 126 | top:"Node_12" 127 | bottom:"Node_10" 128 | bottom:"Node_11" 129 | layer_param { 130 | idx:12 131 | } 132 | } 133 | layer { 134 | name:"Node_13" 135 | type:"and(a)" 136 | top:"Node_13" 137 | bottom:"Node_9" 138 | bottom:"Node_12" 139 | layer_param { 140 | idx:13 141 | } 142 | } 143 | layer { 144 | name:"rx" 145 | type:"var(b)" 146 | top:"rx" 147 | layer_param { 148 | idx:14 149 | dtype:int32 150 | } 151 | } 152 | layer { 153 | name:"Node_15" 154 | type:"add(a)" 155 | top:"Node_15" 156 | bottom:"xx" 157 | bottom:"rx" 158 | layer_param { 159 | idx:15 160 | } 161 | } 162 | layer { 163 | name:"Node_16" 164 | type:"greater_equal(b)" 165 | top:"Node_16" 166 | bottom:"Node_15" 167 | bottom:"Node_8" 168 | layer_param { 169 | idx:16 170 | } 171 | } 172 | layer { 173 | name:"Node_17" 174 | type:"and(a)" 175 | top:"Node_17" 176 | bottom:"Node_13" 177 | bottom:"Node_16" 178 | layer_param { 179 | idx:17 180 | } 181 | } 182 | layer { 183 | name:"Node_18" 184 | type:"add(a)" 185 | top:"Node_18" 186 | bottom:"xx" 187 | bottom:"rx" 188 | layer_param { 189 | idx:18 190 | } 191 | } 192 | layer { 193 | name:"Node_19" 194 | type:"less_than(b)" 195 | top:"Node_19" 196 | 
bottom:"Node_18" 197 | bottom:"Node_11" 198 | layer_param { 199 | idx:19 200 | } 201 | } 202 | layer { 203 | name:"Node_20" 204 | type:"and" 205 | top:"Node_20" 206 | bottom:"Node_17" 207 | bottom:"Node_19" 208 | layer_param { 209 | idx:20 210 | } 211 | } 212 | layer { 213 | name:"A" 214 | type:"buffer(buffer)" 215 | top:"A" 216 | layer_param { 217 | idx:21 218 | buffer_name:"A" 219 | shape:[1024, 1024] 220 | dtype:float32 221 | } 222 | } 223 | layer { 224 | name:"Node_22" 225 | type:"add(a)" 226 | top:"Node_22" 227 | bottom:"yy" 228 | bottom:"ry" 229 | layer_param { 230 | idx:22 231 | } 232 | } 233 | layer { 234 | name:"Node_23" 235 | type:"sub(indice)" 236 | top:"Node_23" 237 | bottom:"Node_22" 238 | bottom:"Node_8" 239 | layer_param { 240 | idx:23 241 | } 242 | } 243 | layer { 244 | name:"Node_24" 245 | type:"add(a)" 246 | top:"Node_24" 247 | bottom:"xx" 248 | bottom:"rx" 249 | layer_param { 250 | idx:24 251 | } 252 | } 253 | layer { 254 | name:"Node_25" 255 | type:"sub(indice)" 256 | top:"Node_25" 257 | bottom:"Node_24" 258 | bottom:"Node_8" 259 | layer_param { 260 | idx:25 261 | } 262 | } 263 | layer { 264 | name:"Node_26" 265 | type:"buffer_load" 266 | top:"Node_26" 267 | bottom:"A" 268 | bottom:"Node_23" 269 | bottom:"Node_25" 270 | layer_param { 271 | idx:26 272 | } 273 | } 274 | layer { 275 | name:"Node_27" 276 | type:"float" 277 | top:"Node_27" 278 | layer_param { 279 | idx:27 280 | value:0.0 281 | dtype:float32 282 | } 283 | } 284 | layer { 285 | name:"Node_28" 286 | type:"Call_tir.if_then_else(a)" 287 | top:"Node_28" 288 | bottom:"Node_20" 289 | bottom:"Node_26" 290 | bottom:"Node_27" 291 | layer_param { 292 | idx:28 293 | } 294 | } 295 | layer { 296 | name:"W" 297 | type:"buffer(buffer)" 298 | top:"W" 299 | layer_param { 300 | idx:29 301 | buffer_name:"W" 302 | shape:[3, 3] 303 | dtype:float32 304 | } 305 | } 306 | layer { 307 | name:"Node_30" 308 | type:"buffer_load(b)" 309 | top:"Node_30" 310 | bottom:"W" 311 | bottom:"ry" 312 | bottom:"rx" 313 | layer_param { 314 | idx:30 315 | } 316 | } 317 | layer { 318 | name:"Node_31" 319 | type:"mul(b)" 320 | top:"Node_31" 321 | bottom:"Node_28" 322 | bottom:"Node_30" 323 | layer_param { 324 | idx:31 325 | } 326 | } 327 | layer { 328 | name:"Node_32" 329 | type:"add(value)" 330 | top:"Node_32" 331 | bottom:"Node_5" 332 | bottom:"Node_31" 333 | layer_param { 334 | idx:32 335 | } 336 | } 337 | layer { 338 | name:"Node_33" 339 | type:"buffer_store" 340 | top:"Node_33" 341 | bottom:"B" 342 | bottom:"Node_32" 343 | layer_param { 344 | idx:33 345 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1026, 1026], [])[yy: int32, xx: int32] + (@tir.if_then_else((((((yy + ry: int32) >= 2) && ((yy + ry) < 1026)) && ((xx + rx: int32) >= 2)) && ((xx + rx) < 1026)), A: Buffer(A_1: Pointer(float32), float32, [1024, 1024], [])[((yy + ry) - 2), ((xx + rx) - 2)], 0f32, dtype=float32)*W: Buffer(W_1: Pointer(float32), float32, [3, 3], [])[ry, rx]))" 346 | indices_00:"[yy, xx]" 347 | } 348 | } 349 | layer { 350 | name:"Node_34" 351 | type:"for" 352 | top:"Node_34" 353 | bottom:"Node_33" 354 | layer_param { 355 | idx:34 356 | kind:0 357 | body_00:"B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 358 | } 359 | } 360 | layer { 361 | name:"Node_35" 362 | type:"for(seq_1)" 363 | top:"Node_35" 364 | bottom:"Node_34" 365 | layer_param { 366 | idx:35 367 | kind:0 368 | body_00:"for (rx, 0, 3)" 369 | body_01:" B[yy, xx] = (B[yy, 
xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 370 | } 371 | } 372 | layer { 373 | name:"Node_36" 374 | type:"seq" 375 | top:"Node_36" 376 | bottom:"Node_2" 377 | bottom:"Node_35" 378 | layer_param { 379 | idx:36 380 | seq_00:"[B[yy, xx] = 0f" 381 | seq_01:" , for (ry, 0, 3)" 382 | seq_02:" for (rx, 0, 3)" 383 | seq_03:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 384 | seq_04:" ]" 385 | } 386 | } 387 | layer { 388 | name:"Node_37" 389 | type:"for" 390 | top:"Node_37" 391 | bottom:"Node_36" 392 | layer_param { 393 | idx:37 394 | kind:0 395 | body_00:"B[yy, xx] = 0f" 396 | body_01:" for (ry, 0, 3)" 397 | body_02:" for (rx, 0, 3)" 398 | body_03:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 399 | } 400 | } 401 | layer { 402 | name:"Node_38" 403 | type:"for" 404 | top:"Node_38" 405 | bottom:"Node_37" 406 | layer_param { 407 | idx:38 408 | kind:0 409 | body_00:"for (xx, 0, 1026)" 410 | body_01:" B[yy, xx] = 0f" 411 | body_02:" for (ry, 0, 3)" 412 | body_03:" for (rx, 0, 3)" 413 | body_04:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 414 | } 415 | } 416 | layer { 417 | name:"Node_39" 418 | type:"buffer_realize" 419 | top:"Node_39" 420 | bottom:"Node_38" 421 | layer_param { 422 | idx:39 423 | condition:True 424 | body_00:"for (yy, 0, 1026)" 425 | body_01:" for (xx, 0, 1026)" 426 | body_02:" B[yy, xx] = 0f" 427 | body_03:" for (ry, 0, 3)" 428 | body_04:" for (rx, 0, 3)" 429 | body_05:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 430 | bounds_00:"[range(min=0, ext=1026), range(min=0, ext=1026)]" 431 | } 432 | } 433 | layer { 434 | name:"Node_40" 435 | type:"attribute" 436 | top:"Node_40" 437 | bottom:"B" 438 | bottom:"Node_39" 439 | layer_param { 440 | idx:40 441 | attr_key:realize_scope 442 | body_00:"buffer_realize B([0, 1026], [0, 1026])" 443 | body_01:" for (yy, 0, 1026)" 444 | body_02:" for (xx, 0, 1026)" 445 | body_03:" B[yy, xx] = 0f" 446 | body_04:" for (ry, 0, 3)" 447 | body_05:" for (rx, 0, 3)" 448 | body_06:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 449 | value_00:"''" 450 | } 451 | } 452 | layer { 453 | name:"Node_41" 454 | type:"primfunc" 455 | top:"Node_41" 456 | bottom:"Node_40" 457 | layer_param { 458 | idx:41 459 | body_00:"// attr [buffer(B, 0x7ff979c6d770)] realize_scope = ''" 460 | body_01:"buffer_realize B([0, 1026], [0, 1026])" 461 | body_02:" for (yy, 0, 1026)" 462 | body_03:" for (xx, 0, 1026)" 463 | body_04:" B[yy, xx] = 0f" 464 | body_05:" for (ry, 0, 3)" 465 | body_06:" for (rx, 0, 3)" 466 | body_07:" B[yy, xx] = (B[yy, xx] + (tir.if_then_else((((((yy + ry) >= 2) && ((yy + ry) < 1026)) && ((xx + rx) >= 2)) && ((xx + rx) < 1026)), A[((yy + ry) - 2), ((xx + rx) - 2)], 0f)*W[ry, rx]))" 467 | } 468 | } 469 | 
-------------------------------------------------------------------------------- /stmt_build/visualizes/compute_inline_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"Apad" 4 | type:"buffer(node)" 5 | top:"Apad" 6 | layer_param { 7 | idx:0 8 | buffer_name:"Apad" 9 | shape:[1028, 1028] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"yy" 15 | type:"var(a)" 16 | top:"yy" 17 | layer_param { 18 | idx:1 19 | dtype:int32 20 | } 21 | } 22 | layer { 23 | name:"Node_2" 24 | type:"int(b)" 25 | top:"Node_2" 26 | layer_param { 27 | idx:2 28 | value:2 29 | dtype:int32 30 | } 31 | } 32 | layer { 33 | name:"Node_3" 34 | type:"greater_equal(a)" 35 | top:"Node_3" 36 | bottom:"yy" 37 | bottom:"Node_2" 38 | layer_param { 39 | idx:3 40 | } 41 | } 42 | layer { 43 | name:"Node_4" 44 | type:"int(b)" 45 | top:"Node_4" 46 | layer_param { 47 | idx:4 48 | value:1026 49 | dtype:int32 50 | } 51 | } 52 | layer { 53 | name:"Node_5" 54 | type:"less_than(b)" 55 | top:"Node_5" 56 | bottom:"yy" 57 | bottom:"Node_4" 58 | layer_param { 59 | idx:5 60 | } 61 | } 62 | layer { 63 | name:"Node_6" 64 | type:"and(a)" 65 | top:"Node_6" 66 | bottom:"Node_3" 67 | bottom:"Node_5" 68 | layer_param { 69 | idx:6 70 | } 71 | } 72 | layer { 73 | name:"xx" 74 | type:"var(a)" 75 | top:"xx" 76 | layer_param { 77 | idx:7 78 | dtype:int32 79 | } 80 | } 81 | layer { 82 | name:"Node_8" 83 | type:"greater_equal(b)" 84 | top:"Node_8" 85 | bottom:"xx" 86 | bottom:"Node_2" 87 | layer_param { 88 | idx:8 89 | } 90 | } 91 | layer { 92 | name:"Node_9" 93 | type:"and(a)" 94 | top:"Node_9" 95 | bottom:"Node_6" 96 | bottom:"Node_8" 97 | layer_param { 98 | idx:9 99 | } 100 | } 101 | layer { 102 | name:"Node_10" 103 | type:"less_than(b)" 104 | top:"Node_10" 105 | bottom:"xx" 106 | bottom:"Node_4" 107 | layer_param { 108 | idx:10 109 | } 110 | } 111 | layer { 112 | name:"Node_11" 113 | type:"and" 114 | top:"Node_11" 115 | bottom:"Node_9" 116 | bottom:"Node_10" 117 | layer_param { 118 | idx:11 119 | } 120 | } 121 | layer { 122 | name:"A" 123 | type:"buffer(buffer)" 124 | top:"A" 125 | layer_param { 126 | idx:12 127 | buffer_name:"A" 128 | shape:[1024, 1024] 129 | dtype:float32 130 | } 131 | } 132 | layer { 133 | name:"Node_13" 134 | type:"sub(indice)" 135 | top:"Node_13" 136 | bottom:"yy" 137 | bottom:"Node_2" 138 | layer_param { 139 | idx:13 140 | } 141 | } 142 | layer { 143 | name:"Node_14" 144 | type:"sub(indice)" 145 | top:"Node_14" 146 | bottom:"xx" 147 | bottom:"Node_2" 148 | layer_param { 149 | idx:14 150 | } 151 | } 152 | layer { 153 | name:"Node_15" 154 | type:"buffer_load" 155 | top:"Node_15" 156 | bottom:"A" 157 | bottom:"Node_13" 158 | bottom:"Node_14" 159 | layer_param { 160 | idx:15 161 | } 162 | } 163 | layer { 164 | name:"Node_16" 165 | type:"float" 166 | top:"Node_16" 167 | layer_param { 168 | idx:16 169 | value:0.0 170 | dtype:float32 171 | } 172 | } 173 | layer { 174 | name:"Node_17" 175 | type:"Call_tir.if_then_else(value)" 176 | top:"Node_17" 177 | bottom:"Node_11" 178 | bottom:"Node_15" 179 | bottom:"Node_16" 180 | layer_param { 181 | idx:17 182 | } 183 | } 184 | layer { 185 | name:"Node_18" 186 | type:"buffer_store" 187 | top:"Node_18" 188 | bottom:"Apad" 189 | bottom:"Node_17" 190 | layer_param { 191 | idx:18 192 | value_00:"@tir.if_then_else(((((yy: int32 >= 2) && (yy < 1026)) && (xx: int32 >= 2)) && (xx < 1026)), A: Buffer(A_1: Pointer(float32), float32, [1024, 1024], [])[(yy - 2), (xx - 2)], 0f32, dtype=float32)" 193 | indices_00:"[yy, 
xx]" 194 | } 195 | } 196 | layer { 197 | name:"Node_19" 198 | type:"for" 199 | top:"Node_19" 200 | bottom:"Node_18" 201 | layer_param { 202 | idx:19 203 | kind:0 204 | body_00:"Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 205 | } 206 | } 207 | layer { 208 | name:"Node_20" 209 | type:"for(seq_0)" 210 | top:"Node_20" 211 | bottom:"Node_19" 212 | layer_param { 213 | idx:20 214 | kind:0 215 | body_00:"for (xx, 0, 1028)" 216 | body_01:" Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 217 | } 218 | } 219 | layer { 220 | name:"B" 221 | type:"buffer(node)" 222 | top:"B" 223 | layer_param { 224 | idx:21 225 | buffer_name:"B" 226 | shape:[1026, 1026] 227 | dtype:float32 228 | } 229 | } 230 | layer { 231 | name:"Node_22" 232 | type:"float(value)" 233 | top:"Node_22" 234 | layer_param { 235 | idx:22 236 | value:0.0 237 | dtype:float32 238 | } 239 | } 240 | layer { 241 | name:"Node_23" 242 | type:"buffer_store(seq_0)" 243 | top:"Node_23" 244 | bottom:"B" 245 | bottom:"Node_22" 246 | layer_param { 247 | idx:23 248 | value_00:"0f32" 249 | indices_00:"[yy, xx]" 250 | } 251 | } 252 | layer { 253 | name:"yy_1" 254 | type:"var(indice)" 255 | top:"yy_1" 256 | layer_param { 257 | idx:24 258 | dtype:int32 259 | } 260 | } 261 | layer { 262 | name:"xx_1" 263 | type:"var(indice)" 264 | top:"xx_1" 265 | layer_param { 266 | idx:25 267 | dtype:int32 268 | } 269 | } 270 | layer { 271 | name:"Node_26" 272 | type:"buffer_load(a)" 273 | top:"Node_26" 274 | bottom:"B" 275 | bottom:"yy_1" 276 | bottom:"xx_1" 277 | layer_param { 278 | idx:26 279 | } 280 | } 281 | layer { 282 | name:"ry" 283 | type:"var(b)" 284 | top:"ry" 285 | layer_param { 286 | idx:27 287 | dtype:int32 288 | } 289 | } 290 | layer { 291 | name:"Node_28" 292 | type:"add(indice)" 293 | top:"Node_28" 294 | bottom:"yy_1" 295 | bottom:"ry" 296 | layer_param { 297 | idx:28 298 | } 299 | } 300 | layer { 301 | name:"rx" 302 | type:"var(b)" 303 | top:"rx" 304 | layer_param { 305 | idx:29 306 | dtype:int32 307 | } 308 | } 309 | layer { 310 | name:"Node_30" 311 | type:"add(indice)" 312 | top:"Node_30" 313 | bottom:"xx_1" 314 | bottom:"rx" 315 | layer_param { 316 | idx:30 317 | } 318 | } 319 | layer { 320 | name:"Node_31" 321 | type:"buffer_load(a)" 322 | top:"Node_31" 323 | bottom:"Apad" 324 | bottom:"Node_28" 325 | bottom:"Node_30" 326 | layer_param { 327 | idx:31 328 | } 329 | } 330 | layer { 331 | name:"W" 332 | type:"buffer(buffer)" 333 | top:"W" 334 | layer_param { 335 | idx:32 336 | buffer_name:"W" 337 | shape:[3, 3] 338 | dtype:float32 339 | } 340 | } 341 | layer { 342 | name:"Node_33" 343 | type:"buffer_load(b)" 344 | top:"Node_33" 345 | bottom:"W" 346 | bottom:"ry" 347 | bottom:"rx" 348 | layer_param { 349 | idx:33 350 | } 351 | } 352 | layer { 353 | name:"Node_34" 354 | type:"mul(b)" 355 | top:"Node_34" 356 | bottom:"Node_31" 357 | bottom:"Node_33" 358 | layer_param { 359 | idx:34 360 | } 361 | } 362 | layer { 363 | name:"Node_35" 364 | type:"add(value)" 365 | top:"Node_35" 366 | bottom:"Node_26" 367 | bottom:"Node_34" 368 | layer_param { 369 | idx:35 370 | } 371 | } 372 | layer { 373 | name:"Node_36" 374 | type:"buffer_store" 375 | top:"Node_36" 376 | bottom:"B" 377 | bottom:"Node_35" 378 | layer_param { 379 | idx:36 380 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1026, 1026], [])[yy: int32, xx: int32] + (Apad: Buffer(Apad_1: Pointer(float32), float32, [1028, 1028], [])[(yy + ry: int32), (xx + rx: 
int32)]*W: Buffer(W_1: Pointer(float32), float32, [3, 3], [])[ry, rx]))" 381 | indices_00:"[yy, xx]" 382 | } 383 | } 384 | layer { 385 | name:"Node_37" 386 | type:"for" 387 | top:"Node_37" 388 | bottom:"Node_36" 389 | layer_param { 390 | idx:37 391 | kind:0 392 | body_00:"B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 393 | } 394 | } 395 | layer { 396 | name:"Node_38" 397 | type:"for(seq_1)" 398 | top:"Node_38" 399 | bottom:"Node_37" 400 | layer_param { 401 | idx:38 402 | kind:0 403 | body_00:"for (rx, 0, 3)" 404 | body_01:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 405 | } 406 | } 407 | layer { 408 | name:"Node_39" 409 | type:"seq" 410 | top:"Node_39" 411 | bottom:"Node_23" 412 | bottom:"Node_38" 413 | layer_param { 414 | idx:39 415 | seq_00:"[B[yy, xx] = 0f" 416 | seq_01:" , for (ry, 0, 3)" 417 | seq_02:" for (rx, 0, 3)" 418 | seq_03:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 419 | seq_04:" ]" 420 | } 421 | } 422 | layer { 423 | name:"Node_40" 424 | type:"for" 425 | top:"Node_40" 426 | bottom:"Node_39" 427 | layer_param { 428 | idx:40 429 | kind:0 430 | body_00:"B[yy, xx] = 0f" 431 | body_01:" for (ry, 0, 3)" 432 | body_02:" for (rx, 0, 3)" 433 | body_03:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 434 | } 435 | } 436 | layer { 437 | name:"Node_41" 438 | type:"for" 439 | top:"Node_41" 440 | bottom:"Node_40" 441 | layer_param { 442 | idx:41 443 | kind:0 444 | body_00:"for (xx, 0, 1026)" 445 | body_01:" B[yy, xx] = 0f" 446 | body_02:" for (ry, 0, 3)" 447 | body_03:" for (rx, 0, 3)" 448 | body_04:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 449 | } 450 | } 451 | layer { 452 | name:"Node_42" 453 | type:"buffer_realize" 454 | top:"Node_42" 455 | bottom:"Node_41" 456 | layer_param { 457 | idx:42 458 | condition:True 459 | body_00:"for (yy, 0, 1026)" 460 | body_01:" for (xx, 0, 1026)" 461 | body_02:" B[yy, xx] = 0f" 462 | body_03:" for (ry, 0, 3)" 463 | body_04:" for (rx, 0, 3)" 464 | body_05:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 465 | bounds_00:"[range(min=0, ext=1026), range(min=0, ext=1026)]" 466 | } 467 | } 468 | layer { 469 | name:"Node_43" 470 | type:"attribute(seq_1)" 471 | top:"Node_43" 472 | bottom:"B" 473 | bottom:"Node_42" 474 | layer_param { 475 | idx:43 476 | attr_key:realize_scope 477 | body_00:"buffer_realize B([0, 1026], [0, 1026])" 478 | body_01:" for (yy, 0, 1026)" 479 | body_02:" for (xx, 0, 1026)" 480 | body_03:" B[yy, xx] = 0f" 481 | body_04:" for (ry, 0, 3)" 482 | body_05:" for (rx, 0, 3)" 483 | body_06:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 484 | value_00:"''" 485 | } 486 | } 487 | layer { 488 | name:"Node_44" 489 | type:"seq" 490 | top:"Node_44" 491 | bottom:"Node_20" 492 | bottom:"Node_43" 493 | layer_param { 494 | idx:44 495 | seq_00:"[for (yy, 0, 1028)" 496 | seq_01:" for (xx, 0, 1028)" 497 | seq_02:" Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 498 | seq_03:" , // attr [buffer(B, 0x7ff979d7f8b0)] realize_scope = ''" 499 | seq_04:" buffer_realize B([0, 1026], [0, 1026])" 500 | seq_05:" for (yy, 0, 1026)" 501 | seq_06:" for (xx, 0, 1026)" 502 | seq_07:" B[yy, xx] = 0f" 503 | seq_08:" for (ry, 0, 3)" 504 | seq_09:" for (rx, 0, 3)" 505 | seq_10:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 506 | seq_11:" ]" 507 | } 508 | } 509 | layer { 510 | name:"Node_45" 511 | type:"buffer_realize" 512 | top:"Node_45" 
513 | bottom:"Node_44" 514 | layer_param { 515 | idx:45 516 | condition:True 517 | body_00:"for (yy, 0, 1028)" 518 | body_01:" for (xx, 0, 1028)" 519 | body_02:" Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 520 | body_03:" // attr [buffer(B, 0x7ff979d7f8b0)] realize_scope = ''" 521 | body_04:" buffer_realize B([0, 1026], [0, 1026])" 522 | body_05:" for (yy, 0, 1026)" 523 | body_06:" for (xx, 0, 1026)" 524 | body_07:" B[yy, xx] = 0f" 525 | body_08:" for (ry, 0, 3)" 526 | body_09:" for (rx, 0, 3)" 527 | body_10:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 528 | bounds_00:"[range(min=0, ext=1028), range(min=0, ext=1028)]" 529 | } 530 | } 531 | layer { 532 | name:"Node_46" 533 | type:"attribute" 534 | top:"Node_46" 535 | bottom:"Apad" 536 | bottom:"Node_45" 537 | layer_param { 538 | idx:46 539 | attr_key:realize_scope 540 | body_00:"buffer_realize Apad([0, 1028], [0, 1028])" 541 | body_01:" for (yy, 0, 1028)" 542 | body_02:" for (xx, 0, 1028)" 543 | body_03:" Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 544 | body_04:" // attr [buffer(B, 0x7ff979d7f8b0)] realize_scope = ''" 545 | body_05:" buffer_realize B([0, 1026], [0, 1026])" 546 | body_06:" for (yy, 0, 1026)" 547 | body_07:" for (xx, 0, 1026)" 548 | body_08:" B[yy, xx] = 0f" 549 | body_09:" for (ry, 0, 3)" 550 | body_10:" for (rx, 0, 3)" 551 | body_11:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 552 | value_00:"''" 553 | } 554 | } 555 | layer { 556 | name:"Node_47" 557 | type:"primfunc" 558 | top:"Node_47" 559 | bottom:"Node_46" 560 | layer_param { 561 | idx:47 562 | body_00:"// attr [buffer(Apad, 0x7ff979d812d0)] realize_scope = ''" 563 | body_01:"buffer_realize Apad([0, 1028], [0, 1028])" 564 | body_02:" for (yy, 0, 1028)" 565 | body_03:" for (xx, 0, 1028)" 566 | body_04:" Apad[yy, xx] = tir.if_then_else(((((yy >= 2) && (yy < 1026)) && (xx >= 2)) && (xx < 1026)), A[(yy - 2), (xx - 2)], 0f)" 567 | body_05:" // attr [buffer(B, 0x7ff979d7f8b0)] realize_scope = ''" 568 | body_06:" buffer_realize B([0, 1026], [0, 1026])" 569 | body_07:" for (yy, 0, 1026)" 570 | body_08:" for (xx, 0, 1026)" 571 | body_09:" B[yy, xx] = 0f" 572 | body_10:" for (ry, 0, 3)" 573 | body_11:" for (rx, 0, 3)" 574 | body_12:" B[yy, xx] = (B[yy, xx] + (Apad[(yy + ry), (xx + rx)]*W[ry, rx]))" 575 | } 576 | } 577 | -------------------------------------------------------------------------------- /stmt_build/visualizes/normal_compute.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"place_holder" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | dtype:float32 9 | } 10 | } 11 | layer { 12 | name:"i" 13 | type:"var(indice)" 14 | top:"i" 15 | layer_param { 16 | idx:1 17 | dtype:int32 18 | } 19 | } 20 | layer { 21 | name:"j" 22 | type:"var(indice)" 23 | top:"j" 24 | layer_param { 25 | idx:2 26 | dtype:int32 27 | } 28 | } 29 | layer { 30 | name:"Node_3" 31 | type:"producer_load(a)" 32 | top:"Node_3" 33 | bottom:"A" 34 | bottom:"i" 35 | bottom:"j" 36 | layer_param { 37 | idx:3 38 | } 39 | } 40 | layer { 41 | name:"B" 42 | type:"place_holder" 43 | top:"B" 44 | layer_param { 45 | idx:4 46 | dtype:float32 47 | } 48 | } 49 | layer { 50 | name:"Node_5" 51 | type:"producer_load(b)" 52 | top:"Node_5" 53 | bottom:"B" 54 | bottom:"i" 55 | bottom:"j" 56 | layer_param { 57 | idx:5 58 | } 59 | } 60 | layer { 
61 | name:"Node_6" 62 | type:"add" 63 | top:"Node_6" 64 | bottom:"Node_3" 65 | bottom:"Node_5" 66 | layer_param { 67 | idx:6 68 | } 69 | } 70 | layer { 71 | name:"C" 72 | type:"compute" 73 | top:"C" 74 | bottom:"Node_6" 75 | layer_param { 76 | idx:7 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /stmt_build/visualizes/normal_stmt.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"C" 4 | type:"buffer(node)" 5 | top:"C" 6 | layer_param { 7 | idx:0 8 | buffer_name:"C" 9 | shape:[tindex, tindex] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"A" 15 | type:"buffer(buffer)" 16 | top:"A" 17 | layer_param { 18 | idx:1 19 | buffer_name:"A" 20 | shape:[tindex, tindex] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"i" 26 | type:"var(indice)" 27 | top:"i" 28 | layer_param { 29 | idx:2 30 | dtype:int32 31 | } 32 | } 33 | layer { 34 | name:"j" 35 | type:"var(indice)" 36 | top:"j" 37 | layer_param { 38 | idx:3 39 | dtype:int32 40 | } 41 | } 42 | layer { 43 | name:"Node_4" 44 | type:"buffer_load(a)" 45 | top:"Node_4" 46 | bottom:"A" 47 | bottom:"i" 48 | bottom:"j" 49 | layer_param { 50 | idx:4 51 | } 52 | } 53 | layer { 54 | name:"B" 55 | type:"buffer(buffer)" 56 | top:"B" 57 | layer_param { 58 | idx:5 59 | buffer_name:"B" 60 | shape:[tindex, tindex] 61 | dtype:float32 62 | } 63 | } 64 | layer { 65 | name:"Node_6" 66 | type:"buffer_load(b)" 67 | top:"Node_6" 68 | bottom:"B" 69 | bottom:"i" 70 | bottom:"j" 71 | layer_param { 72 | idx:6 73 | } 74 | } 75 | layer { 76 | name:"Node_7" 77 | type:"add(value)" 78 | top:"Node_7" 79 | bottom:"Node_4" 80 | bottom:"Node_6" 81 | layer_param { 82 | idx:7 83 | } 84 | } 85 | layer { 86 | name:"Node_8" 87 | type:"buffer_store" 88 | top:"Node_8" 89 | bottom:"C" 90 | bottom:"Node_7" 91 | layer_param { 92 | idx:8 93 | value_00:"(A: Buffer(A_1: Pointer(float32), float32, [tindex: int32, tindex], [stride: int32, stride_1: int32], type='auto')[i: int32, j: int32] + B: Buffer(B_1: Pointer(float32), float32, [tindex, tindex], [stride_2: int32, stride_3: int32], type='auto')[i, j])" 94 | indices_00:"[i, j]" 95 | } 96 | } 97 | layer { 98 | name:"Node_9" 99 | type:"for" 100 | top:"Node_9" 101 | bottom:"Node_8" 102 | layer_param { 103 | idx:9 104 | kind:0 105 | body_00:"C[i, j] = (A[i, j] + B[i, j])" 106 | } 107 | } 108 | layer { 109 | name:"Node_10" 110 | type:"for" 111 | top:"Node_10" 112 | bottom:"Node_9" 113 | layer_param { 114 | idx:10 115 | kind:0 116 | body_00:"for (j, 0, tindex)" 117 | body_01:" C[i, j] = (A[i, j] + B[i, j])" 118 | } 119 | } 120 | layer { 121 | name:"Node_11" 122 | type:"buffer_realize" 123 | top:"Node_11" 124 | bottom:"Node_10" 125 | layer_param { 126 | idx:11 127 | condition:True 128 | body_00:"for (i, 0, tindex)" 129 | body_01:" for (j, 0, tindex)" 130 | body_02:" C[i, j] = (A[i, j] + B[i, j])" 131 | bounds_00:"[range(min=0, ext=tindex), range(min=0, ext=tindex)]" 132 | } 133 | } 134 | layer { 135 | name:"Node_12" 136 | type:"attribute" 137 | top:"Node_12" 138 | bottom:"C" 139 | bottom:"Node_11" 140 | layer_param { 141 | idx:12 142 | attr_key:realize_scope 143 | body_00:"buffer_realize C([0, tindex], [0, tindex])" 144 | body_01:" for (i, 0, tindex)" 145 | body_02:" for (j, 0, tindex)" 146 | body_03:" C[i, j] = (A[i, j] + B[i, j])" 147 | value_00:"''" 148 | } 149 | } 150 | layer { 151 | name:"Node_13" 152 | type:"primfunc" 153 | top:"Node_13" 154 | bottom:"Node_12" 155 | layer_param { 156 | idx:13 157 | 
body_00:"// attr [buffer(C, 0x7fc805f8b9e0)] realize_scope = ''" 158 | body_01:"buffer_realize C([0, tindex], [0, tindex])" 159 | body_02:" for (i, 0, tindex)" 160 | body_03:" for (j, 0, tindex)" 161 | body_04:" C[i, j] = (A[i, j] + B[i, j])" 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /stmt_build/visualizes/normal_stmt_complete.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[tindex, tindex] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[tindex, tindex] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"C" 26 | type:"buffer" 27 | top:"C" 28 | layer_param { 29 | idx:2 30 | buffer_name:"C" 31 | shape:[tindex, tindex] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"int" 38 | top:"Node_3" 39 | layer_param { 40 | idx:3 41 | value:0 42 | dtype:int32 43 | } 44 | } 45 | layer { 46 | name:"tindex" 47 | type:"var" 48 | top:"tindex" 49 | layer_param { 50 | idx:4 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"Node_5" 56 | type:"range(bound_0)" 57 | top:"Node_5" 58 | bottom:"Node_3" 59 | bottom:"tindex" 60 | layer_param { 61 | idx:5 62 | range_00:"range(min=0, ext=tindex)" 63 | } 64 | } 65 | layer { 66 | name:"Node_6" 67 | type:"range(bound_1)" 68 | top:"Node_6" 69 | bottom:"Node_3" 70 | bottom:"tindex" 71 | layer_param { 72 | idx:6 73 | range_00:"range(min=0, ext=tindex)" 74 | } 75 | } 76 | layer { 77 | name:"i" 78 | type:"var(loop_var)" 79 | top:"i" 80 | layer_param { 81 | idx:7 82 | dtype:int32 83 | } 84 | } 85 | layer { 86 | name:"j" 87 | type:"var(loop_var)" 88 | top:"j" 89 | layer_param { 90 | idx:8 91 | dtype:int32 92 | } 93 | } 94 | layer { 95 | name:"Node_9" 96 | type:"buffer_load(a)" 97 | top:"Node_9" 98 | bottom:"A" 99 | bottom:"i" 100 | bottom:"j" 101 | layer_param { 102 | idx:9 103 | } 104 | } 105 | layer { 106 | name:"Node_10" 107 | type:"buffer_load(b)" 108 | top:"Node_10" 109 | bottom:"B" 110 | bottom:"i" 111 | bottom:"j" 112 | layer_param { 113 | idx:10 114 | } 115 | } 116 | layer { 117 | name:"Node_11" 118 | type:"add(value)" 119 | top:"Node_11" 120 | bottom:"Node_9" 121 | bottom:"Node_10" 122 | layer_param { 123 | idx:11 124 | } 125 | } 126 | layer { 127 | name:"Node_12" 128 | type:"buffer_store" 129 | top:"Node_12" 130 | bottom:"C" 131 | bottom:"Node_11" 132 | bottom:"i" 133 | bottom:"j" 134 | layer_param { 135 | idx:12 136 | value_00:"(A: Buffer(A_1: Pointer(float32), float32, [tindex: int32, tindex], [stride: int32, stride_1: int32], type='auto')[i: int32, j: int32] + B: Buffer(B_1: Pointer(float32), float32, [tindex, tindex], [stride_2: int32, stride_3: int32], type='auto')[i, j])" 137 | indices_00:"[i, j]" 138 | } 139 | } 140 | layer { 141 | name:"Node_13" 142 | type:"for" 143 | top:"Node_13" 144 | bottom:"j" 145 | bottom:"Node_3" 146 | bottom:"tindex" 147 | bottom:"Node_12" 148 | layer_param { 149 | idx:13 150 | kind:0 151 | body_00:"C[i, j] = (A[i, j] + B[i, j])" 152 | } 153 | } 154 | layer { 155 | name:"Node_14" 156 | type:"for" 157 | top:"Node_14" 158 | bottom:"i" 159 | bottom:"Node_3" 160 | bottom:"tindex" 161 | bottom:"Node_13" 162 | layer_param { 163 | idx:14 164 | kind:0 165 | body_00:"for (j, 0, tindex)" 166 | body_01:" C[i, j] = (A[i, j] + B[i, j])" 167 | } 168 | } 169 | layer { 170 | 
name:"Node_15" 171 | type:"buffer_realize" 172 | top:"Node_15" 173 | bottom:"Node_5" 174 | bottom:"Node_6" 175 | bottom:"C" 176 | bottom:"Node_14" 177 | layer_param { 178 | idx:15 179 | condition:True 180 | body_00:"for (i, 0, tindex)" 181 | body_01:" for (j, 0, tindex)" 182 | body_02:" C[i, j] = (A[i, j] + B[i, j])" 183 | bounds_00:"[range(min=0, ext=tindex), range(min=0, ext=tindex)]" 184 | } 185 | } 186 | layer { 187 | name:"Node_16" 188 | type:"attribute" 189 | top:"Node_16" 190 | bottom:"C" 191 | bottom:"Node_15" 192 | layer_param { 193 | idx:16 194 | attr_key:realize_scope 195 | body_00:"buffer_realize C([0, tindex], [0, tindex])" 196 | body_01:" for (i, 0, tindex)" 197 | body_02:" for (j, 0, tindex)" 198 | body_03:" C[i, j] = (A[i, j] + B[i, j])" 199 | value_00:"''" 200 | } 201 | } 202 | layer { 203 | name:"Node_17" 204 | type:"primfunc" 205 | top:"Node_17" 206 | bottom:"A" 207 | bottom:"B" 208 | bottom:"C" 209 | bottom:"Node_16" 210 | layer_param { 211 | idx:17 212 | body_00:"// attr [buffer(C, 0x7fc805f8b9e0)] realize_scope = ''" 213 | body_01:"buffer_realize C([0, tindex], [0, tindex])" 214 | body_02:" for (i, 0, tindex)" 215 | body_03:" for (j, 0, tindex)" 216 | body_04:" C[i, j] = (A[i, j] + B[i, j])" 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /stmt_build/visualizes/split_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"B" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[0]" 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"int(indice)" 38 | top:"Node_3" 39 | layer_param { 40 | idx:3 41 | value:0 42 | dtype:int32 43 | } 44 | } 45 | layer { 46 | name:"Node_4" 47 | type:"buffer_load(a)" 48 | top:"Node_4" 49 | bottom:"B" 50 | bottom:"Node_3" 51 | layer_param { 52 | idx:4 53 | } 54 | } 55 | layer { 56 | name:"A" 57 | type:"buffer(buffer)" 58 | top:"A" 59 | layer_param { 60 | idx:5 61 | buffer_name:"A" 62 | shape:[1024] 63 | dtype:float32 64 | } 65 | } 66 | layer { 67 | name:"k.inner" 68 | type:"var(a)" 69 | top:"k.inner" 70 | layer_param { 71 | idx:6 72 | dtype:int32 73 | } 74 | } 75 | layer { 76 | name:"k.outer" 77 | type:"var(a)" 78 | top:"k.outer" 79 | layer_param { 80 | idx:7 81 | dtype:int32 82 | } 83 | } 84 | layer { 85 | name:"Node_8" 86 | type:"int(b)" 87 | top:"Node_8" 88 | layer_param { 89 | idx:8 90 | value:32 91 | dtype:int32 92 | } 93 | } 94 | layer { 95 | name:"Node_9" 96 | type:"mul(b)" 97 | top:"Node_9" 98 | bottom:"k.outer" 99 | bottom:"Node_8" 100 | layer_param { 101 | idx:9 102 | } 103 | } 104 | layer { 105 | name:"Node_10" 106 | type:"add(indice)" 107 | top:"Node_10" 108 | bottom:"k.inner" 109 | bottom:"Node_9" 110 | layer_param { 111 | idx:10 112 | } 113 | } 114 | layer { 115 | name:"Node_11" 116 | type:"buffer_load(b)" 117 | top:"Node_11" 118 | bottom:"A" 119 | bottom:"Node_10" 120 | layer_param { 121 | idx:11 122 | } 123 | } 124 | layer { 125 | name:"Node_12" 126 | type:"add(value)" 127 | top:"Node_12" 128 | bottom:"Node_4" 129 | bottom:"Node_11" 130 | layer_param { 131 | idx:12 132 | } 
133 | } 134 | layer { 135 | name:"Node_13" 136 | type:"buffer_store" 137 | top:"Node_13" 138 | bottom:"B" 139 | bottom:"Node_12" 140 | layer_param { 141 | idx:13 142 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1], [])[0] + A: Buffer(A_1: Pointer(float32), float32, [1024], [])[(k.inner: int32 + (k.outer: int32*32))])" 143 | indices_00:"[0]" 144 | } 145 | } 146 | layer { 147 | name:"Node_14" 148 | type:"for" 149 | top:"Node_14" 150 | bottom:"Node_13" 151 | layer_param { 152 | idx:14 153 | kind:0 154 | body_00:"B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 155 | } 156 | } 157 | layer { 158 | name:"Node_15" 159 | type:"for(seq_1)" 160 | top:"Node_15" 161 | bottom:"Node_14" 162 | layer_param { 163 | idx:15 164 | kind:0 165 | body_00:"for (k.inner, 0, 32)" 166 | body_01:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 167 | } 168 | } 169 | layer { 170 | name:"Node_16" 171 | type:"seq" 172 | top:"Node_16" 173 | bottom:"Node_2" 174 | bottom:"Node_15" 175 | layer_param { 176 | idx:16 177 | seq_00:"[B[0] = 0f" 178 | seq_01:" , for (k.outer, 0, 32)" 179 | seq_02:" for (k.inner, 0, 32)" 180 | seq_03:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 181 | seq_04:" ]" 182 | } 183 | } 184 | layer { 185 | name:"Node_17" 186 | type:"buffer_realize" 187 | top:"Node_17" 188 | bottom:"Node_16" 189 | layer_param { 190 | idx:17 191 | condition:True 192 | body_00:"B[0] = 0f" 193 | body_01:" for (k.outer, 0, 32)" 194 | body_02:" for (k.inner, 0, 32)" 195 | body_03:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 196 | bounds_00:"[range(min=0, ext=1)]" 197 | } 198 | } 199 | layer { 200 | name:"Node_18" 201 | type:"attribute" 202 | top:"Node_18" 203 | bottom:"B" 204 | bottom:"Node_17" 205 | layer_param { 206 | idx:18 207 | attr_key:realize_scope 208 | body_00:"buffer_realize B([0, 1])" 209 | body_01:" B[0] = 0f" 210 | body_02:" for (k.outer, 0, 32)" 211 | body_03:" for (k.inner, 0, 32)" 212 | body_04:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 213 | value_00:"''" 214 | } 215 | } 216 | layer { 217 | name:"Node_19" 218 | type:"primfunc" 219 | top:"Node_19" 220 | bottom:"Node_18" 221 | layer_param { 222 | idx:19 223 | body_00:"// attr [buffer(B, 0x7ff97c726d90)] realize_scope = ''" 224 | body_01:"buffer_realize B([0, 1])" 225 | body_02:" B[0] = 0f" 226 | body_03:" for (k.outer, 0, 32)" 227 | body_04:" for (k.inner, 0, 32)" 228 | body_05:" B[0] = (B[0] + A[(k.inner + (k.outer*32))])" 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /stmt_build/visualizes/split_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"B" 4 | type:"buffer(node)" 5 | top:"B" 6 | layer_param { 7 | idx:0 8 | buffer_name:"B" 9 | shape:[1] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"B" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[0]" 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"int(indice)" 38 | top:"Node_3" 39 | layer_param { 40 | idx:3 41 | value:0 42 | dtype:int32 43 | } 44 | } 45 | layer { 46 | name:"Node_4" 47 | type:"buffer_load(a)" 48 | top:"Node_4" 49 | bottom:"B" 50 | bottom:"Node_3" 51 | layer_param { 52 | idx:4 53 | } 54 | } 55 | layer { 56 | name:"A" 57 | type:"buffer(buffer)" 58 | top:"A" 59 | 
layer_param { 60 | idx:5 61 | buffer_name:"A" 62 | shape:[1024] 63 | dtype:float32 64 | } 65 | } 66 | layer { 67 | name:"k" 68 | type:"var(indice)" 69 | top:"k" 70 | layer_param { 71 | idx:6 72 | dtype:int32 73 | } 74 | } 75 | layer { 76 | name:"Node_7" 77 | type:"buffer_load(b)" 78 | top:"Node_7" 79 | bottom:"A" 80 | bottom:"k" 81 | layer_param { 82 | idx:7 83 | } 84 | } 85 | layer { 86 | name:"Node_8" 87 | type:"add(value)" 88 | top:"Node_8" 89 | bottom:"Node_4" 90 | bottom:"Node_7" 91 | layer_param { 92 | idx:8 93 | } 94 | } 95 | layer { 96 | name:"Node_9" 97 | type:"buffer_store" 98 | top:"Node_9" 99 | bottom:"B" 100 | bottom:"Node_8" 101 | layer_param { 102 | idx:9 103 | value_00:"(B: Buffer(B_1: Pointer(float32), float32, [1], [])[0] + A: Buffer(A_1: Pointer(float32), float32, [1024], [])[k: int32])" 104 | indices_00:"[0]" 105 | } 106 | } 107 | layer { 108 | name:"Node_10" 109 | type:"for(seq_1)" 110 | top:"Node_10" 111 | bottom:"Node_9" 112 | layer_param { 113 | idx:10 114 | kind:0 115 | body_00:"B[0] = (B[0] + A[k])" 116 | } 117 | } 118 | layer { 119 | name:"Node_11" 120 | type:"seq" 121 | top:"Node_11" 122 | bottom:"Node_2" 123 | bottom:"Node_10" 124 | layer_param { 125 | idx:11 126 | seq_00:"[B[0] = 0f" 127 | seq_01:" , for (k, 0, 1024)" 128 | seq_02:" B[0] = (B[0] + A[k])" 129 | seq_03:" ]" 130 | } 131 | } 132 | layer { 133 | name:"Node_12" 134 | type:"buffer_realize" 135 | top:"Node_12" 136 | bottom:"Node_11" 137 | layer_param { 138 | idx:12 139 | condition:True 140 | body_00:"B[0] = 0f" 141 | body_01:" for (k, 0, 1024)" 142 | body_02:" B[0] = (B[0] + A[k])" 143 | bounds_00:"[range(min=0, ext=1)]" 144 | } 145 | } 146 | layer { 147 | name:"Node_13" 148 | type:"attribute" 149 | top:"Node_13" 150 | bottom:"B" 151 | bottom:"Node_12" 152 | layer_param { 153 | idx:13 154 | attr_key:realize_scope 155 | body_00:"buffer_realize B([0, 1])" 156 | body_01:" B[0] = 0f" 157 | body_02:" for (k, 0, 1024)" 158 | body_03:" B[0] = (B[0] + A[k])" 159 | value_00:"''" 160 | } 161 | } 162 | layer { 163 | name:"Node_14" 164 | type:"primfunc" 165 | top:"Node_14" 166 | bottom:"Node_13" 167 | layer_param { 168 | idx:14 169 | body_00:"// attr [buffer(B, 0x7ff979d89590)] realize_scope = ''" 170 | body_01:"buffer_realize B([0, 1])" 171 | body_02:" B[0] = 0f" 172 | body_03:" for (k, 0, 1024)" 173 | body_04:" B[0] = (B[0] + A[k])" 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /stmt_build/visualizes/tensorize_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"C" 4 | type:"buffer(node)" 5 | top:"C" 6 | layer_param { 7 | idx:0 8 | buffer_name:"C" 9 | shape:[1024, 512] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"A" 15 | type:"buffer(array_0)" 16 | top:"A" 17 | layer_param { 18 | idx:1 19 | buffer_name:"A" 20 | shape:[64] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"A_1" 26 | type:"buffer(array_1)" 27 | top:"A_1" 28 | layer_param { 29 | idx:2 30 | buffer_name:"A" 31 | shape:[1024, 64] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"array(node)" 38 | top:"Node_3" 39 | bottom:"A" 40 | bottom:"A_1" 41 | layer_param { 42 | idx:3 43 | } 44 | } 45 | layer { 46 | name:"B" 47 | type:"buffer(array_0)" 48 | top:"B" 49 | layer_param { 50 | idx:4 51 | buffer_name:"B" 52 | shape:[16, 64] 53 | dtype:float32 54 | } 55 | } 56 | layer { 57 | name:"B_1" 58 | type:"buffer(array_1)" 59 | top:"B_1" 60 | layer_param { 61 | idx:5 62 | 
buffer_name:"B" 63 | shape:[512, 64] 64 | dtype:float32 65 | } 66 | } 67 | layer { 68 | name:"Node_6" 69 | type:"array(node)" 70 | top:"Node_6" 71 | bottom:"B" 72 | bottom:"B_1" 73 | layer_param { 74 | idx:6 75 | } 76 | } 77 | layer { 78 | name:"C_1" 79 | type:"buffer(array_0)" 80 | top:"C_1" 81 | layer_param { 82 | idx:7 83 | buffer_name:"C" 84 | shape:[16] 85 | dtype:float32 86 | } 87 | } 88 | layer { 89 | name:"Node_8" 90 | type:"array(node)" 91 | top:"Node_8" 92 | bottom:"C_1" 93 | bottom:"C" 94 | layer_param { 95 | idx:8 96 | } 97 | } 98 | layer { 99 | name:"Node_9" 100 | type:"Call_tir.call_extern(value)" 101 | top:"Node_9" 102 | layer_param { 103 | idx:9 104 | body_00:"@tir.call_extern('gemv_update', @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), C: Pointer(float32), C_elem_offset: int32, 16, 2, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), A: Pointer(float32), A_elem_offset: int32, 64, 1, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), B: Pointer(float32), B_elem_offset: int32, (s1: int32*16), 1, dtype=handle), 16, 64, s1, dtype=int32)" 105 | } 106 | } 107 | layer { 108 | name:"Node_10" 109 | type:"evaluate" 110 | top:"Node_10" 111 | bottom:"Node_9" 112 | layer_param { 113 | idx:10 114 | } 115 | } 116 | layer { 117 | name:"Node_11" 118 | type:"attribute" 119 | top:"Node_11" 120 | bottom:"Node_8" 121 | bottom:"Node_10" 122 | layer_param { 123 | idx:11 124 | attr_key:buffer_bind_scope 125 | body_00:"tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 126 | value_00:"@tir.tvm_tuple(i: int32, 1, (j.outer: int32*16), 16, dtype=handle)" 127 | } 128 | } 129 | layer { 130 | name:"Node_12" 131 | type:"attribute" 132 | top:"Node_12" 133 | bottom:"Node_6" 134 | bottom:"Node_11" 135 | layer_param { 136 | idx:12 137 | attr_key:buffer_bind_scope 138 | body_00:"// attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 139 | body_01:"tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 140 | value_00:"@tir.tvm_tuple((j.outer: int32*16), 16, 0, 64, dtype=handle)" 141 | } 142 | } 143 | layer { 144 | name:"Node_13" 145 | type:"attribute" 146 | top:"Node_13" 147 | bottom:"Node_3" 148 | bottom:"Node_12" 149 | layer_param { 150 | idx:13 151 | attr_key:buffer_bind_scope 152 | body_00:"// attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 153 | body_01:"// attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 154 | body_02:"tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 155 | value_00:"@tir.tvm_tuple(i: int32, 1, 0, 64, dtype=handle)" 156 | } 157 | } 158 | layer { 159 | name:"Node_14" 160 | type:"for" 161 | top:"Node_14" 162 | bottom:"Node_13" 163 | layer_param { 164 | idx:14 165 | kind:0 166 | body_00:"// attr [[buffer(A, 0x7ff97c52b910), 
buffer(A, 0x7ff97c561dd0)]] buffer_bind_scope = tir.tvm_tuple(i, 1, 0, 64)" 167 | body_01:"// attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 168 | body_02:"// attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 169 | body_03:"tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 170 | } 171 | } 172 | layer { 173 | name:"Node_15" 174 | type:"for" 175 | top:"Node_15" 176 | bottom:"Node_14" 177 | layer_param { 178 | idx:15 179 | kind:0 180 | body_00:"for (j.outer, 0, 32)" 181 | body_01:" // attr [[buffer(A, 0x7ff97c52b910), buffer(A, 0x7ff97c561dd0)]] buffer_bind_scope = tir.tvm_tuple(i, 1, 0, 64)" 182 | body_02:" // attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 183 | body_03:" // attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 184 | body_04:" tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 185 | } 186 | } 187 | layer { 188 | name:"Node_16" 189 | type:"buffer_realize" 190 | top:"Node_16" 191 | bottom:"Node_15" 192 | layer_param { 193 | idx:16 194 | condition:True 195 | body_00:"for (i, 0, 1024)" 196 | body_01:" for (j.outer, 0, 32)" 197 | body_02:" // attr [[buffer(A, 0x7ff97c52b910), buffer(A, 0x7ff97c561dd0)]] buffer_bind_scope = tir.tvm_tuple(i, 1, 0, 64)" 198 | body_03:" // attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 199 | body_04:" // attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 200 | body_05:" tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 201 | bounds_00:"[range(min=0, ext=1024), range(min=0, ext=512)]" 202 | } 203 | } 204 | layer { 205 | name:"Node_17" 206 | type:"attribute" 207 | top:"Node_17" 208 | bottom:"C" 209 | bottom:"Node_16" 210 | layer_param { 211 | idx:17 212 | attr_key:realize_scope 213 | body_00:"buffer_realize C([0, 1024], [0, 512])" 214 | body_01:" for (i, 0, 1024)" 215 | body_02:" for (j.outer, 0, 32)" 216 | body_03:" // attr [[buffer(A, 0x7ff97c52b910), buffer(A, 0x7ff97c561dd0)]] buffer_bind_scope = tir.tvm_tuple(i, 1, 0, 64)" 217 | body_04:" // attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 218 | body_05:" // attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 219 | body_06:" tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 220 | value_00:"''" 221 | } 222 | } 223 | layer { 224 | 
name:"Node_18" 225 | type:"primfunc" 226 | top:"Node_18" 227 | bottom:"Node_17" 228 | layer_param { 229 | idx:18 230 | body_00:"// attr [buffer(C, 0x7ff97c558780)] realize_scope = ''" 231 | body_01:"buffer_realize C([0, 1024], [0, 512])" 232 | body_02:" for (i, 0, 1024)" 233 | body_03:" for (j.outer, 0, 32)" 234 | body_04:" // attr [[buffer(A, 0x7ff97c52b910), buffer(A, 0x7ff97c561dd0)]] buffer_bind_scope = tir.tvm_tuple(i, 1, 0, 64)" 235 | body_05:" // attr [[buffer(B, 0x7ff97c52cd90), buffer(B, 0x7ff97c558ae0)]] buffer_bind_scope = tir.tvm_tuple((j.outer*16), 16, 0, 64)" 236 | body_06:" // attr [[buffer(C, 0x7ff97c52cfe0), buffer(C, 0x7ff97c558780)]] buffer_bind_scope = tir.tvm_tuple(i, 1, (j.outer*16), 16)" 237 | body_07:" tir.call_extern('gemv_update', tir.tvm_access_ptr(tir.type_annotation(), C, C_elem_offset, 16, 2), tir.tvm_access_ptr(tir.type_annotation(), A, A_elem_offset, 64, 1), tir.tvm_access_ptr(tir.type_annotation(), B, B_elem_offset, (s1*16), 1), 16, 64, s1)" 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /stmt_build/visualizes/tensorize_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"C" 4 | type:"buffer(node)" 5 | top:"C" 6 | layer_param { 7 | idx:0 8 | buffer_name:"C" 9 | shape:[1024, 512] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"Node_1" 15 | type:"float(value)" 16 | top:"Node_1" 17 | layer_param { 18 | idx:1 19 | value:0.0 20 | dtype:float32 21 | } 22 | } 23 | layer { 24 | name:"Node_2" 25 | type:"buffer_store(seq_0)" 26 | top:"Node_2" 27 | bottom:"C" 28 | bottom:"Node_1" 29 | layer_param { 30 | idx:2 31 | value_00:"0f32" 32 | indices_00:"[i, (j.inner + (j.outer*16))]" 33 | } 34 | } 35 | layer { 36 | name:"i" 37 | type:"var(indice)" 38 | top:"i" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"j.inner" 46 | type:"var(a)" 47 | top:"j.inner" 48 | layer_param { 49 | idx:4 50 | dtype:int32 51 | } 52 | } 53 | layer { 54 | name:"j.outer" 55 | type:"var(a)" 56 | top:"j.outer" 57 | layer_param { 58 | idx:5 59 | dtype:int32 60 | } 61 | } 62 | layer { 63 | name:"Node_6" 64 | type:"int(b)" 65 | top:"Node_6" 66 | layer_param { 67 | idx:6 68 | value:16 69 | dtype:int32 70 | } 71 | } 72 | layer { 73 | name:"Node_7" 74 | type:"mul(b)" 75 | top:"Node_7" 76 | bottom:"j.outer" 77 | bottom:"Node_6" 78 | layer_param { 79 | idx:7 80 | } 81 | } 82 | layer { 83 | name:"Node_8" 84 | type:"add(indice)" 85 | top:"Node_8" 86 | bottom:"j.inner" 87 | bottom:"Node_7" 88 | layer_param { 89 | idx:8 90 | } 91 | } 92 | layer { 93 | name:"Node_9" 94 | type:"buffer_load(a)" 95 | top:"Node_9" 96 | bottom:"C" 97 | bottom:"i" 98 | bottom:"Node_8" 99 | layer_param { 100 | idx:9 101 | } 102 | } 103 | layer { 104 | name:"A" 105 | type:"buffer(buffer)" 106 | top:"A" 107 | layer_param { 108 | idx:10 109 | buffer_name:"A" 110 | shape:[1024, 64] 111 | dtype:float32 112 | } 113 | } 114 | layer { 115 | name:"k" 116 | type:"var(indice)" 117 | top:"k" 118 | layer_param { 119 | idx:11 120 | dtype:int32 121 | } 122 | } 123 | layer { 124 | name:"Node_12" 125 | type:"buffer_load(a)" 126 | top:"Node_12" 127 | bottom:"A" 128 | bottom:"i" 129 | bottom:"k" 130 | layer_param { 131 | idx:12 132 | } 133 | } 134 | layer { 135 | name:"B" 136 | type:"buffer(buffer)" 137 | top:"B" 138 | layer_param { 139 | idx:13 140 | buffer_name:"B" 141 | shape:[512, 64] 142 | dtype:float32 143 | } 144 | } 145 | layer { 146 | name:"Node_14" 147 | 
type:"buffer_load(b)" 148 | top:"Node_14" 149 | bottom:"B" 150 | bottom:"Node_8" 151 | bottom:"k" 152 | layer_param { 153 | idx:14 154 | } 155 | } 156 | layer { 157 | name:"Node_15" 158 | type:"mul(b)" 159 | top:"Node_15" 160 | bottom:"Node_12" 161 | bottom:"Node_14" 162 | layer_param { 163 | idx:15 164 | } 165 | } 166 | layer { 167 | name:"Node_16" 168 | type:"add(value)" 169 | top:"Node_16" 170 | bottom:"Node_9" 171 | bottom:"Node_15" 172 | layer_param { 173 | idx:16 174 | } 175 | } 176 | layer { 177 | name:"Node_17" 178 | type:"buffer_store" 179 | top:"Node_17" 180 | bottom:"C" 181 | bottom:"Node_16" 182 | layer_param { 183 | idx:17 184 | value_00:"(C: Buffer(C_1: Pointer(float32), float32, [1024, 512], [])[i: int32, (j.inner: int32 + (j.outer: int32*16))] + (A: Buffer(A_1: Pointer(float32), float32, [1024, 64], [])[i, k: int32]*B: Buffer(B_1: Pointer(float32), float32, [512, 64], [])[(j.inner + (j.outer*16)), k]))" 185 | indices_00:"[i, (j.inner + (j.outer*16))]" 186 | } 187 | } 188 | layer { 189 | name:"Node_18" 190 | type:"for(seq_1)" 191 | top:"Node_18" 192 | bottom:"Node_17" 193 | layer_param { 194 | idx:18 195 | kind:0 196 | body_00:"C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 197 | } 198 | } 199 | layer { 200 | name:"Node_19" 201 | type:"seq" 202 | top:"Node_19" 203 | bottom:"Node_2" 204 | bottom:"Node_18" 205 | layer_param { 206 | idx:19 207 | seq_00:"[C[i, (j.inner + (j.outer*16))] = 0f" 208 | seq_01:" , for (k, 0, 64)" 209 | seq_02:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 210 | seq_03:" ]" 211 | } 212 | } 213 | layer { 214 | name:"Node_20" 215 | type:"for" 216 | top:"Node_20" 217 | bottom:"Node_19" 218 | layer_param { 219 | idx:20 220 | kind:0 221 | body_00:"C[i, (j.inner + (j.outer*16))] = 0f" 222 | body_01:" for (k, 0, 64)" 223 | body_02:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 224 | } 225 | } 226 | layer { 227 | name:"Node_21" 228 | type:"for" 229 | top:"Node_21" 230 | bottom:"Node_20" 231 | layer_param { 232 | idx:21 233 | kind:0 234 | body_00:"for (j.inner, 0, 16)" 235 | body_01:" C[i, (j.inner + (j.outer*16))] = 0f" 236 | body_02:" for (k, 0, 64)" 237 | body_03:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 238 | } 239 | } 240 | layer { 241 | name:"Node_22" 242 | type:"for" 243 | top:"Node_22" 244 | bottom:"Node_21" 245 | layer_param { 246 | idx:22 247 | kind:0 248 | body_00:"for (j.outer, 0, 32)" 249 | body_01:" for (j.inner, 0, 16)" 250 | body_02:" C[i, (j.inner + (j.outer*16))] = 0f" 251 | body_03:" for (k, 0, 64)" 252 | body_04:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 253 | } 254 | } 255 | layer { 256 | name:"Node_23" 257 | type:"buffer_realize" 258 | top:"Node_23" 259 | bottom:"Node_22" 260 | layer_param { 261 | idx:23 262 | condition:True 263 | body_00:"for (i, 0, 1024)" 264 | body_01:" for (j.outer, 0, 32)" 265 | body_02:" for (j.inner, 0, 16)" 266 | body_03:" C[i, (j.inner + (j.outer*16))] = 0f" 267 | body_04:" for (k, 0, 64)" 268 | body_05:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 269 | bounds_00:"[range(min=0, ext=1024), range(min=0, ext=512)]" 270 | } 271 | } 272 | layer { 273 | name:"Node_24" 274 | type:"attribute" 275 | top:"Node_24" 276 | bottom:"C" 
277 | bottom:"Node_23" 278 | layer_param { 279 | idx:24 280 | attr_key:realize_scope 281 | body_00:"buffer_realize C([0, 1024], [0, 512])" 282 | body_01:" for (i, 0, 1024)" 283 | body_02:" for (j.outer, 0, 32)" 284 | body_03:" for (j.inner, 0, 16)" 285 | body_04:" C[i, (j.inner + (j.outer*16))] = 0f" 286 | body_05:" for (k, 0, 64)" 287 | body_06:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 288 | value_00:"''" 289 | } 290 | } 291 | layer { 292 | name:"Node_25" 293 | type:"primfunc" 294 | top:"Node_25" 295 | bottom:"Node_24" 296 | layer_param { 297 | idx:25 298 | body_00:"// attr [buffer(C, 0x7ff97c52bca0)] realize_scope = ''" 299 | body_01:"buffer_realize C([0, 1024], [0, 512])" 300 | body_02:" for (i, 0, 1024)" 301 | body_03:" for (j.outer, 0, 32)" 302 | body_04:" for (j.inner, 0, 16)" 303 | body_05:" C[i, (j.inner + (j.outer*16))] = 0f" 304 | body_06:" for (k, 0, 64)" 305 | body_07:" C[i, (j.inner + (j.outer*16))] = (C[i, (j.inner + (j.outer*16))] + (A[i, k]*B[(j.inner + (j.outer*16)), k]))" 306 | } 307 | } 308 | -------------------------------------------------------------------------------- /stmt_optimize/demo.py: -------------------------------------------------------------------------------- 1 | import tvm 2 | 3 | import sys 4 | sys.path.append("..") 5 | from visualize import PrimExprVisualizer 6 | 7 | def simple(visualizer): 8 | print("\nTest StorageFlatten") 9 | n = tvm.te.var() 10 | A = tvm.te.placeholder((n, n), name='A') 11 | B = tvm.te.placeholder((n, n), name='B') 12 | C = tvm.te.compute((n, n), lambda i, j: A[i, j] + B[i, j], name='C') 13 | s = tvm.te.create_schedule(C.op) 14 | mod=tvm.driver.build_module.form_irmodule(s,[A,B,C],"main",binds=None) 15 | visualizer.visualize(mod,"visualizes/storage_flatten_before.prototxt") 16 | print(""+str(mod["main"])) 17 | 18 | #optimize 19 | pass_list=[tvm.tir.transform.StorageFlatten(64)] 20 | optimize = tvm.transform.Sequential(pass_list) 21 | mod = optimize(mod) 22 | visualizer.visualize(mod,"visualizes/storage_flatten_after.prototxt") 23 | print(""+str(mod["main"])) 24 | 25 | def vectorize(visualizer): 26 | print("\nTest VectorizeLoop") 27 | M = 1024 28 | N = 1024 29 | A = tvm.te.placeholder((M, N), name='A') 30 | B = tvm.te.placeholder((M, N), name='B') 31 | C = tvm.te.compute( 32 | (M, N), 33 | lambda x, y: A[x, y] + B[x, y], 34 | name='C') 35 | 36 | s = tvm.te.create_schedule(C.op) 37 | xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32) 38 | s[C].vectorize(yi) 39 | mod=tvm.driver.build_module.form_irmodule(s,[A,B,C],"main",binds=None) 40 | pass_list=[tvm.tir.transform.StorageFlatten(64)] 41 | optimize = tvm.transform.Sequential(pass_list) 42 | mod = optimize(mod) 43 | visualizer.visualize(mod,"visualizes/vectorize_loop_before.prototxt") 44 | print(""+str(mod["main"])) 45 | 46 | pass_list=[tvm.tir.transform.VectorizeLoop()] 47 | optimize = tvm.transform.Sequential(pass_list) 48 | mod = optimize(mod) 49 | visualizer.visualize(mod,"visualizes/vectorize_loop_after.prototxt") 50 | print(""+str(mod["main"])) 51 | 52 | def bind(visualizer): 53 | print("\nTest Simplify") 54 | n = 1024 55 | A = tvm.te.placeholder((n,), name='A') 56 | k = tvm.te.reduce_axis((0, n), name='k') 57 | B = tvm.te.compute((1,), lambda i: tvm.te.sum(A[k], axis=k), name='B') 58 | s = tvm.te.create_schedule(B.op) 59 | ko, ki = s[B].split(B.op.reduce_axis[0], factor=32) 60 | 61 | s[B].bind(ko, tvm.te.thread_axis("blockIdx.x")) 62 | s[B].bind(ki, tvm.te.thread_axis("threadIdx.x")) 63 | 
mod=tvm.driver.build_module.form_irmodule(s,[A,B],"main",binds=None) 64 | pass_list=[tvm.tir.transform.StorageFlatten(64)] 65 | optimize = tvm.transform.Sequential(pass_list) 66 | mod = optimize(mod) 67 | visualizer.visualize(mod,"visualizes/simplify_before.prototxt") 68 | print(""+str(mod["main"])) 69 | 70 | pass_list=[tvm.tir.transform.Simplify()] 71 | optimize = tvm.transform.Sequential(pass_list) 72 | mod = optimize(mod) 73 | visualizer.visualize(mod,"visualizes/simplify_after.prototxt") 74 | print(""+str(mod["main"])) 75 | 76 | if __name__=='__main__': 77 | visualizer=PrimExprVisualizer(simple_mode=False) 78 | simple(visualizer) 79 | vectorize(visualizer) 80 | bind(visualizer) -------------------------------------------------------------------------------- /stmt_optimize/visualizes/simplify_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[1] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"blockIdx.x" 26 | type:"var(iter)" 27 | top:"blockIdx.x" 28 | layer_param { 29 | idx:2 30 | dtype:int32 31 | } 32 | } 33 | layer { 34 | name:"Node_3" 35 | type:"itervar(node)" 36 | top:"Node_3" 37 | bottom:"blockIdx.x" 38 | layer_param { 39 | idx:3 40 | dom:"None" 41 | iter_type:"1" 42 | thread_tag:"blockIdx.x" 43 | } 44 | } 45 | layer { 46 | name:"threadIdx.x" 47 | type:"var(iter)" 48 | top:"threadIdx.x" 49 | layer_param { 50 | idx:4 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"Node_5" 56 | type:"itervar(node)" 57 | top:"Node_5" 58 | bottom:"threadIdx.x" 59 | layer_param { 60 | idx:5 61 | dom:"None" 62 | iter_type:"1" 63 | thread_tag:"threadIdx.x" 64 | } 65 | } 66 | layer { 67 | name:"reduce_temp0" 68 | type:"var(node)" 69 | top:"reduce_temp0" 70 | layer_param { 71 | idx:6 72 | dtype:handle 73 | } 74 | } 75 | layer { 76 | name:"x" 77 | type:"var(reduce_l)" 78 | top:"x" 79 | layer_param { 80 | idx:7 81 | dtype:float32 82 | } 83 | } 84 | layer { 85 | name:"y" 86 | type:"var(reduce_r)" 87 | top:"y" 88 | layer_param { 89 | idx:8 90 | dtype:float32 91 | } 92 | } 93 | layer { 94 | name:"Node_9" 95 | type:"add(reduce_res)" 96 | top:"Node_9" 97 | bottom:"x" 98 | bottom:"y" 99 | layer_param { 100 | idx:9 101 | } 102 | } 103 | layer { 104 | name:"Node_10" 105 | type:"float(reduce_ind)" 106 | top:"Node_10" 107 | layer_param { 108 | idx:10 109 | value:0.0 110 | dtype:float32 111 | } 112 | } 113 | layer { 114 | name:"Node_11" 115 | type:"common_reducer(node)" 116 | top:"Node_11" 117 | bottom:"x" 118 | bottom:"y" 119 | bottom:"Node_9" 120 | bottom:"Node_10" 121 | layer_param { 122 | idx:11 123 | result_00:"[(x + y)]" 124 | } 125 | } 126 | layer { 127 | name:"Node_12" 128 | type:"int" 129 | top:"Node_12" 130 | layer_param { 131 | idx:12 132 | value:1 133 | dtype:uint32 134 | } 135 | } 136 | layer { 137 | name:"A_1" 138 | type:"var(load_buffer)" 139 | top:"A_1" 140 | layer_param { 141 | idx:13 142 | dtype:handle 143 | } 144 | } 145 | layer { 146 | name:"Node_14" 147 | type:"int(b)" 148 | top:"Node_14" 149 | layer_param { 150 | idx:14 151 | value:32 152 | dtype:int32 153 | } 154 | } 155 | layer { 156 | name:"Node_15" 157 | type:"mul(a)" 158 | top:"Node_15" 159 | bottom:"blockIdx.x" 160 | bottom:"Node_14" 161 | layer_param { 162 | idx:15 163 | } 164 | } 165 | 
layer { 166 | name:"Node_16" 167 | type:"add(load_index)" 168 | top:"Node_16" 169 | bottom:"Node_15" 170 | bottom:"threadIdx.x" 171 | layer_param { 172 | idx:16 173 | } 174 | } 175 | layer { 176 | name:"Node_17" 177 | type:"load" 178 | top:"Node_17" 179 | bottom:"A_1" 180 | bottom:"Node_16" 181 | layer_param { 182 | idx:17 183 | predicate_00:"True" 184 | body_00:"(float32*)A: Pointer(float32)[((blockIdx.x: int32*32) + threadIdx.x: int32)]" 185 | } 186 | } 187 | layer { 188 | name:"Node_18" 189 | type:"Call_tir.tvm_thread_allreduce(value)" 190 | top:"Node_18" 191 | bottom:"Node_12" 192 | bottom:"Node_17" 193 | bottom:"Node_12" 194 | bottom:"reduce_temp0" 195 | bottom:"blockIdx.x" 196 | bottom:"threadIdx.x" 197 | layer_param { 198 | idx:18 199 | } 200 | } 201 | layer { 202 | name:"Node_19" 203 | type:"evaluate" 204 | top:"Node_19" 205 | bottom:"Node_18" 206 | layer_param { 207 | idx:19 208 | } 209 | } 210 | layer { 211 | name:"Node_20" 212 | type:"attribute(seq_0)" 213 | top:"Node_20" 214 | bottom:"Node_11" 215 | bottom:"Node_19" 216 | layer_param { 217 | idx:20 218 | attr_key:reduce_scope 219 | body_00:"tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 220 | value_00:"@tir.reinterpret(0u64, dtype=handle)" 221 | } 222 | } 223 | layer { 224 | name:"B_1" 225 | type:"var(store_buffer)" 226 | top:"B_1" 227 | layer_param { 228 | idx:21 229 | dtype:handle 230 | } 231 | } 232 | layer { 233 | name:"Node_22" 234 | type:"int(load_index)" 235 | top:"Node_22" 236 | layer_param { 237 | idx:22 238 | value:0 239 | dtype:int32 240 | } 241 | } 242 | layer { 243 | name:"Node_23" 244 | type:"load(store_value)" 245 | top:"Node_23" 246 | bottom:"reduce_temp0" 247 | bottom:"Node_22" 248 | layer_param { 249 | idx:23 250 | predicate_00:"True" 251 | body_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 252 | } 253 | } 254 | layer { 255 | name:"Node_24" 256 | type:"store(seq_1)" 257 | top:"Node_24" 258 | bottom:"B_1" 259 | bottom:"Node_23" 260 | bottom:"Node_22" 261 | layer_param { 262 | idx:24 263 | predicate_00:"True" 264 | value_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 265 | index_00:"0" 266 | body_00:"B[0] = reduce_temp0[0]" 267 | } 268 | } 269 | layer { 270 | name:"Node_25" 271 | type:"seq" 272 | top:"Node_25" 273 | bottom:"Node_20" 274 | bottom:"Node_24" 275 | layer_param { 276 | idx:25 277 | seq_00:"[// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 278 | seq_01:" tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 279 | seq_02:" , B[0] = reduce_temp0[0]" 280 | seq_03:" ]" 281 | } 282 | } 283 | layer { 284 | name:"Node_26" 285 | type:"allocate" 286 | top:"Node_26" 287 | bottom:"reduce_temp0" 288 | bottom:"Node_25" 289 | layer_param { 290 | idx:26 291 | dtype:float32 292 | extents:"[1]" 293 | condition:"True" 294 | body_00:"// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 295 | body_01:"tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 296 | body_02:" B[0] = reduce_temp0[0]" 297 | } 298 | } 299 | layer { 300 | name:"Node_27" 301 | type:"attribute" 302 | top:"Node_27" 303 | bottom:"reduce_temp0" 304 | bottom:"Node_26" 305 | layer_param { 306 | idx:27 307 | attr_key:storage_scope 308 | body_00:"allocate reduce_temp0[float32 * 1]" 309 | body_01:" // 
attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 310 | body_02:" tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 311 | body_03:" B[0] = reduce_temp0[0]" 312 | value_00:"'local'" 313 | } 314 | } 315 | layer { 316 | name:"Node_28" 317 | type:"attribute" 318 | top:"Node_28" 319 | bottom:"Node_5" 320 | bottom:"Node_27" 321 | layer_param { 322 | idx:28 323 | attr_key:thread_extent 324 | body_00:"// attr [reduce_temp0] storage_scope = 'local'" 325 | body_01:"allocate reduce_temp0[float32 * 1]" 326 | body_02:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 327 | body_03:" tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 328 | body_04:" B[0] = reduce_temp0[0]" 329 | value_00:"32" 330 | } 331 | } 332 | layer { 333 | name:"Node_29" 334 | type:"attribute" 335 | top:"Node_29" 336 | bottom:"Node_3" 337 | bottom:"Node_28" 338 | layer_param { 339 | idx:29 340 | attr_key:thread_extent 341 | body_00:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 342 | body_01:"// attr [reduce_temp0] storage_scope = 'local'" 343 | body_02:"allocate reduce_temp0[float32 * 1]" 344 | body_03:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 345 | body_04:" tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 346 | body_05:" B[0] = reduce_temp0[0]" 347 | value_00:"32" 348 | } 349 | } 350 | layer { 351 | name:"Node_30" 352 | type:"primfunc" 353 | top:"Node_30" 354 | bottom:"A" 355 | bottom:"B" 356 | bottom:"Node_29" 357 | layer_param { 358 | idx:30 359 | body_00:"// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32" 360 | body_01:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 361 | body_02:"// attr [reduce_temp0] storage_scope = 'local'" 362 | body_03:"allocate reduce_temp0[float32 * 1]" 363 | body_04:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 364 | body_05:" tir.tvm_thread_allreduce((uint32)1, A[((blockIdx.x*32) + threadIdx.x)], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 365 | body_06:" B[0] = reduce_temp0[0]" 366 | } 367 | } 368 | -------------------------------------------------------------------------------- /stmt_optimize/visualizes/simplify_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[1] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"blockIdx.x" 26 | type:"var(iter)" 27 | top:"blockIdx.x" 28 | layer_param { 29 | idx:2 30 | dtype:int32 31 | } 32 | } 33 | layer { 34 | name:"Node_3" 35 | type:"itervar(node)" 36 | top:"Node_3" 37 | bottom:"blockIdx.x" 38 | layer_param { 39 | idx:3 40 | dom:"None" 41 | iter_type:"1" 42 | thread_tag:"blockIdx.x" 43 | } 44 | } 45 | layer { 46 | name:"threadIdx.x" 47 | type:"var(iter)" 48 | top:"threadIdx.x" 49 | layer_param { 50 | idx:4 51 | dtype:int32 52 | } 53 | } 54 
| layer { 55 | name:"Node_5" 56 | type:"itervar(node)" 57 | top:"Node_5" 58 | bottom:"threadIdx.x" 59 | layer_param { 60 | idx:5 61 | dom:"None" 62 | iter_type:"1" 63 | thread_tag:"threadIdx.x" 64 | } 65 | } 66 | layer { 67 | name:"reduce_temp0" 68 | type:"var(node)" 69 | top:"reduce_temp0" 70 | layer_param { 71 | idx:6 72 | dtype:handle 73 | } 74 | } 75 | layer { 76 | name:"x" 77 | type:"var(reduce_l)" 78 | top:"x" 79 | layer_param { 80 | idx:7 81 | dtype:float32 82 | } 83 | } 84 | layer { 85 | name:"y" 86 | type:"var(reduce_r)" 87 | top:"y" 88 | layer_param { 89 | idx:8 90 | dtype:float32 91 | } 92 | } 93 | layer { 94 | name:"Node_9" 95 | type:"add(reduce_res)" 96 | top:"Node_9" 97 | bottom:"x" 98 | bottom:"y" 99 | layer_param { 100 | idx:9 101 | } 102 | } 103 | layer { 104 | name:"Node_10" 105 | type:"float(reduce_ind)" 106 | top:"Node_10" 107 | layer_param { 108 | idx:10 109 | value:0.0 110 | dtype:float32 111 | } 112 | } 113 | layer { 114 | name:"Node_11" 115 | type:"common_reducer(node)" 116 | top:"Node_11" 117 | bottom:"x" 118 | bottom:"y" 119 | bottom:"Node_9" 120 | bottom:"Node_10" 121 | layer_param { 122 | idx:11 123 | result_00:"[(x + y)]" 124 | } 125 | } 126 | layer { 127 | name:"Node_12" 128 | type:"int" 129 | top:"Node_12" 130 | layer_param { 131 | idx:12 132 | value:1 133 | dtype:uint32 134 | } 135 | } 136 | layer { 137 | name:"A_1" 138 | type:"var(load_buffer)" 139 | top:"A_1" 140 | layer_param { 141 | idx:13 142 | dtype:handle 143 | } 144 | } 145 | layer { 146 | name:"Node_14" 147 | type:"int(b)" 148 | top:"Node_14" 149 | layer_param { 150 | idx:14 151 | value:32 152 | dtype:int32 153 | } 154 | } 155 | layer { 156 | name:"Node_15" 157 | type:"mul(b)" 158 | top:"Node_15" 159 | bottom:"blockIdx.x" 160 | bottom:"Node_14" 161 | layer_param { 162 | idx:15 163 | } 164 | } 165 | layer { 166 | name:"Node_16" 167 | type:"add(load_index)" 168 | top:"Node_16" 169 | bottom:"threadIdx.x" 170 | bottom:"Node_15" 171 | layer_param { 172 | idx:16 173 | } 174 | } 175 | layer { 176 | name:"Node_17" 177 | type:"load" 178 | top:"Node_17" 179 | bottom:"A_1" 180 | bottom:"Node_16" 181 | layer_param { 182 | idx:17 183 | predicate_00:"True" 184 | body_00:"(float32*)A: Pointer(float32)[(threadIdx.x: int32 + (blockIdx.x: int32*32))]" 185 | } 186 | } 187 | layer { 188 | name:"Node_18" 189 | type:"Call_tir.tvm_thread_allreduce(value)" 190 | top:"Node_18" 191 | bottom:"Node_12" 192 | bottom:"Node_17" 193 | bottom:"Node_12" 194 | bottom:"reduce_temp0" 195 | bottom:"blockIdx.x" 196 | bottom:"threadIdx.x" 197 | layer_param { 198 | idx:18 199 | } 200 | } 201 | layer { 202 | name:"Node_19" 203 | type:"evaluate" 204 | top:"Node_19" 205 | bottom:"Node_18" 206 | layer_param { 207 | idx:19 208 | } 209 | } 210 | layer { 211 | name:"Node_20" 212 | type:"attribute(seq_0)" 213 | top:"Node_20" 214 | bottom:"Node_11" 215 | bottom:"Node_19" 216 | layer_param { 217 | idx:20 218 | attr_key:reduce_scope 219 | body_00:"tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 220 | value_00:"@tir.reinterpret(0u64, dtype=handle)" 221 | } 222 | } 223 | layer { 224 | name:"B_1" 225 | type:"var(store_buffer)" 226 | top:"B_1" 227 | layer_param { 228 | idx:21 229 | dtype:handle 230 | } 231 | } 232 | layer { 233 | name:"Node_22" 234 | type:"int(load_index)" 235 | top:"Node_22" 236 | layer_param { 237 | idx:22 238 | value:0 239 | dtype:int32 240 | } 241 | } 242 | layer { 243 | name:"Node_23" 244 | type:"load(store_value)" 245 | top:"Node_23" 246 | 
bottom:"reduce_temp0" 247 | bottom:"Node_22" 248 | layer_param { 249 | idx:23 250 | predicate_00:"True" 251 | body_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 252 | } 253 | } 254 | layer { 255 | name:"Node_24" 256 | type:"store(true)" 257 | top:"Node_24" 258 | bottom:"B_1" 259 | bottom:"Node_23" 260 | bottom:"Node_22" 261 | layer_param { 262 | idx:24 263 | predicate_00:"True" 264 | value_00:"(float32*)reduce_temp0: Pointer(float32)[0]" 265 | index_00:"0" 266 | body_00:"B[0] = reduce_temp0[0]" 267 | } 268 | } 269 | layer { 270 | name:"Node_25" 271 | type:"ifthenelse(seq_1)" 272 | top:"Node_25" 273 | bottom:"Node_24" 274 | layer_param { 275 | idx:25 276 | condition:"True" 277 | } 278 | } 279 | layer { 280 | name:"Node_26" 281 | type:"seq" 282 | top:"Node_26" 283 | bottom:"Node_20" 284 | bottom:"Node_25" 285 | layer_param { 286 | idx:26 287 | seq_00:"[// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 288 | seq_01:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 289 | seq_02:" , if ((bool)1)" 290 | seq_03:" B[0] = reduce_temp0[0]" 291 | seq_04:" ]" 292 | } 293 | } 294 | layer { 295 | name:"Node_27" 296 | type:"allocate" 297 | top:"Node_27" 298 | bottom:"reduce_temp0" 299 | bottom:"Node_26" 300 | layer_param { 301 | idx:27 302 | dtype:float32 303 | extents:"[1]" 304 | condition:"True" 305 | body_00:"// attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 306 | body_01:"tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 307 | body_02:" if ((bool)1)" 308 | body_03:" B[0] = reduce_temp0[0]" 309 | } 310 | } 311 | layer { 312 | name:"Node_28" 313 | type:"attribute" 314 | top:"Node_28" 315 | bottom:"reduce_temp0" 316 | bottom:"Node_27" 317 | layer_param { 318 | idx:28 319 | attr_key:storage_scope 320 | body_00:"allocate reduce_temp0[float32 * 1]" 321 | body_01:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 322 | body_02:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 323 | body_03:" if ((bool)1)" 324 | body_04:" B[0] = reduce_temp0[0]" 325 | value_00:"'local'" 326 | } 327 | } 328 | layer { 329 | name:"Node_29" 330 | type:"attribute" 331 | top:"Node_29" 332 | bottom:"Node_5" 333 | bottom:"Node_28" 334 | layer_param { 335 | idx:29 336 | attr_key:thread_extent 337 | body_00:"// attr [reduce_temp0] storage_scope = 'local'" 338 | body_01:"allocate reduce_temp0[float32 * 1]" 339 | body_02:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 340 | body_03:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 341 | body_04:" if ((bool)1)" 342 | body_05:" B[0] = reduce_temp0[0]" 343 | value_00:"32" 344 | } 345 | } 346 | layer { 347 | name:"Node_30" 348 | type:"attribute" 349 | top:"Node_30" 350 | bottom:"Node_3" 351 | bottom:"Node_29" 352 | layer_param { 353 | idx:30 354 | attr_key:thread_extent 355 | body_00:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 356 | body_01:"// attr [reduce_temp0] storage_scope = 'local'" 357 | body_02:"allocate reduce_temp0[float32 * 1]" 358 | body_03:" // attr 
[comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 359 | body_04:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 360 | body_05:" if ((bool)1)" 361 | body_06:" B[0] = reduce_temp0[0]" 362 | value_00:"32" 363 | } 364 | } 365 | layer { 366 | name:"Node_31" 367 | type:"primfunc" 368 | top:"Node_31" 369 | bottom:"A" 370 | bottom:"B" 371 | bottom:"Node_30" 372 | layer_param { 373 | idx:31 374 | body_00:"// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32" 375 | body_01:"// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32" 376 | body_02:"// attr [reduce_temp0] storage_scope = 'local'" 377 | body_03:"allocate reduce_temp0[float32 * 1]" 378 | body_04:" // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)" 379 | body_05:" tir.tvm_thread_allreduce((uint32)1, A[(threadIdx.x + (blockIdx.x*32))], (bool)1, reduce_temp0, blockIdx.x, threadIdx.x)" 380 | body_06:" if ((bool)1)" 381 | body_07:" B[0] = reduce_temp0[0]" 382 | } 383 | } 384 | -------------------------------------------------------------------------------- /stmt_optimize/visualizes/storage_flatten_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[tindex, tindex] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[tindex, tindex] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"C" 26 | type:"buffer" 27 | top:"C" 28 | layer_param { 29 | idx:2 30 | buffer_name:"C" 31 | shape:[tindex, tindex] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"i" 37 | type:"var(loop_var)" 38 | top:"i" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"Node_4" 46 | type:"int(for_min)" 47 | top:"Node_4" 48 | layer_param { 49 | idx:4 50 | value:0 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"tindex" 56 | type:"var(for_extent)" 57 | top:"tindex" 58 | layer_param { 59 | idx:5 60 | dtype:int32 61 | } 62 | } 63 | layer { 64 | name:"j" 65 | type:"var(loop_var)" 66 | top:"j" 67 | layer_param { 68 | idx:6 69 | dtype:int32 70 | } 71 | } 72 | layer { 73 | name:"C_1" 74 | type:"var(store_buffer)" 75 | top:"C_1" 76 | layer_param { 77 | idx:7 78 | dtype:handle 79 | } 80 | } 81 | layer { 82 | name:"A_1" 83 | type:"var(load_buffer)" 84 | top:"A_1" 85 | layer_param { 86 | idx:8 87 | dtype:handle 88 | } 89 | } 90 | layer { 91 | name:"stride" 92 | type:"var(b)" 93 | top:"stride" 94 | layer_param { 95 | idx:9 96 | dtype:int32 97 | } 98 | } 99 | layer { 100 | name:"Node_10" 101 | type:"mul(a)" 102 | top:"Node_10" 103 | bottom:"i" 104 | bottom:"stride" 105 | layer_param { 106 | idx:10 107 | } 108 | } 109 | layer { 110 | name:"stride_1" 111 | type:"var(b)" 112 | top:"stride_1" 113 | layer_param { 114 | idx:11 115 | dtype:int32 116 | } 117 | } 118 | layer { 119 | name:"Node_12" 120 | type:"mul(b)" 121 | top:"Node_12" 122 | bottom:"j" 123 | bottom:"stride_1" 124 | layer_param { 125 | idx:12 126 | } 127 | } 128 | layer { 129 | name:"Node_13" 130 | type:"add(load_index)" 131 | top:"Node_13" 132 | bottom:"Node_10" 133 | bottom:"Node_12" 134 | layer_param { 135 | idx:13 136 | } 137 | } 138 | layer { 139 | name:"Node_14" 140 | 
type:"load(a)" 141 | top:"Node_14" 142 | bottom:"A_1" 143 | bottom:"Node_13" 144 | layer_param { 145 | idx:14 146 | predicate_00:"True" 147 | body_00:"(float32*)A: Pointer(float32)[((i: int32*stride: int32) + (j: int32*stride_1: int32))]" 148 | } 149 | } 150 | layer { 151 | name:"B_1" 152 | type:"var(load_buffer)" 153 | top:"B_1" 154 | layer_param { 155 | idx:15 156 | dtype:handle 157 | } 158 | } 159 | layer { 160 | name:"stride_2" 161 | type:"var(b)" 162 | top:"stride_2" 163 | layer_param { 164 | idx:16 165 | dtype:int32 166 | } 167 | } 168 | layer { 169 | name:"Node_17" 170 | type:"mul(a)" 171 | top:"Node_17" 172 | bottom:"i" 173 | bottom:"stride_2" 174 | layer_param { 175 | idx:17 176 | } 177 | } 178 | layer { 179 | name:"stride_3" 180 | type:"var(b)" 181 | top:"stride_3" 182 | layer_param { 183 | idx:18 184 | dtype:int32 185 | } 186 | } 187 | layer { 188 | name:"Node_19" 189 | type:"mul(b)" 190 | top:"Node_19" 191 | bottom:"j" 192 | bottom:"stride_3" 193 | layer_param { 194 | idx:19 195 | } 196 | } 197 | layer { 198 | name:"Node_20" 199 | type:"add(load_index)" 200 | top:"Node_20" 201 | bottom:"Node_17" 202 | bottom:"Node_19" 203 | layer_param { 204 | idx:20 205 | } 206 | } 207 | layer { 208 | name:"Node_21" 209 | type:"load(b)" 210 | top:"Node_21" 211 | bottom:"B_1" 212 | bottom:"Node_20" 213 | layer_param { 214 | idx:21 215 | predicate_00:"True" 216 | body_00:"(float32*)B: Pointer(float32)[((i: int32*stride: int32) + (j: int32*stride_1: int32))]" 217 | } 218 | } 219 | layer { 220 | name:"Node_22" 221 | type:"add(store_value)" 222 | top:"Node_22" 223 | bottom:"Node_14" 224 | bottom:"Node_21" 225 | layer_param { 226 | idx:22 227 | } 228 | } 229 | layer { 230 | name:"stride_4" 231 | type:"var(b)" 232 | top:"stride_4" 233 | layer_param { 234 | idx:23 235 | dtype:int32 236 | } 237 | } 238 | layer { 239 | name:"Node_24" 240 | type:"mul(a)" 241 | top:"Node_24" 242 | bottom:"i" 243 | bottom:"stride_4" 244 | layer_param { 245 | idx:24 246 | } 247 | } 248 | layer { 249 | name:"stride_5" 250 | type:"var(b)" 251 | top:"stride_5" 252 | layer_param { 253 | idx:25 254 | dtype:int32 255 | } 256 | } 257 | layer { 258 | name:"Node_26" 259 | type:"mul(b)" 260 | top:"Node_26" 261 | bottom:"j" 262 | bottom:"stride_5" 263 | layer_param { 264 | idx:26 265 | } 266 | } 267 | layer { 268 | name:"Node_27" 269 | type:"add(store_index)" 270 | top:"Node_27" 271 | bottom:"Node_24" 272 | bottom:"Node_26" 273 | layer_param { 274 | idx:27 275 | } 276 | } 277 | layer { 278 | name:"Node_28" 279 | type:"store" 280 | top:"Node_28" 281 | bottom:"C_1" 282 | bottom:"Node_22" 283 | bottom:"Node_27" 284 | layer_param { 285 | idx:28 286 | predicate_00:"True" 287 | value_00:"((float32*)A: Pointer(float32)[((i: int32*stride: int32) + (j: int32*stride_1: int32))] + (float32*)B: Pointer(float32)[((i*stride_2: int32) + (j*stride_3: int32))])" 288 | index_00:"((i: int32*stride: int32) + (j: int32*stride_1: int32))" 289 | body_00:"C[((i*stride) + (j*stride))] = (A[((i*stride) + (j*stride))] + B[((i*stride) + (j*stride))])" 290 | } 291 | } 292 | layer { 293 | name:"Node_29" 294 | type:"for" 295 | top:"Node_29" 296 | bottom:"j" 297 | bottom:"Node_4" 298 | bottom:"tindex" 299 | bottom:"Node_28" 300 | layer_param { 301 | idx:29 302 | kind:0 303 | body_00:"C[((i*stride) + (j*stride))] = (A[((i*stride) + (j*stride))] + B[((i*stride) + (j*stride))])" 304 | } 305 | } 306 | layer { 307 | name:"Node_30" 308 | type:"for" 309 | top:"Node_30" 310 | bottom:"i" 311 | bottom:"Node_4" 312 | bottom:"tindex" 313 | bottom:"Node_29" 314 | layer_param { 
315 | idx:30 316 | kind:0 317 | body_00:"for (j, 0, tindex)" 318 | body_01:" C[((i*stride) + (j*stride))] = (A[((i*stride) + (j*stride))] + B[((i*stride) + (j*stride))])" 319 | } 320 | } 321 | layer { 322 | name:"Node_31" 323 | type:"primfunc" 324 | top:"Node_31" 325 | bottom:"A" 326 | bottom:"B" 327 | bottom:"C" 328 | bottom:"Node_30" 329 | layer_param { 330 | idx:31 331 | body_00:"for (i, 0, tindex)" 332 | body_01:" for (j, 0, tindex)" 333 | body_02:" C[((i*stride) + (j*stride))] = (A[((i*stride) + (j*stride))] + B[((i*stride) + (j*stride))])" 334 | } 335 | } 336 | -------------------------------------------------------------------------------- /stmt_optimize/visualizes/storage_flatten_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[tindex, tindex] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[tindex, tindex] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"C" 26 | type:"buffer" 27 | top:"C" 28 | layer_param { 29 | idx:2 30 | buffer_name:"C" 31 | shape:[tindex, tindex] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"Node_3" 37 | type:"int" 38 | top:"Node_3" 39 | layer_param { 40 | idx:3 41 | value:0 42 | dtype:int32 43 | } 44 | } 45 | layer { 46 | name:"tindex" 47 | type:"var" 48 | top:"tindex" 49 | layer_param { 50 | idx:4 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"Node_5" 56 | type:"range(bound_0)" 57 | top:"Node_5" 58 | bottom:"Node_3" 59 | bottom:"tindex" 60 | layer_param { 61 | idx:5 62 | range_00:"range(min=0, ext=tindex)" 63 | } 64 | } 65 | layer { 66 | name:"Node_6" 67 | type:"range(bound_1)" 68 | top:"Node_6" 69 | bottom:"Node_3" 70 | bottom:"tindex" 71 | layer_param { 72 | idx:6 73 | range_00:"range(min=0, ext=tindex)" 74 | } 75 | } 76 | layer { 77 | name:"i" 78 | type:"var(loop_var)" 79 | top:"i" 80 | layer_param { 81 | idx:7 82 | dtype:int32 83 | } 84 | } 85 | layer { 86 | name:"j" 87 | type:"var(loop_var)" 88 | top:"j" 89 | layer_param { 90 | idx:8 91 | dtype:int32 92 | } 93 | } 94 | layer { 95 | name:"Node_9" 96 | type:"buffer_load(a)" 97 | top:"Node_9" 98 | bottom:"A" 99 | bottom:"i" 100 | bottom:"j" 101 | layer_param { 102 | idx:9 103 | } 104 | } 105 | layer { 106 | name:"Node_10" 107 | type:"buffer_load(b)" 108 | top:"Node_10" 109 | bottom:"B" 110 | bottom:"i" 111 | bottom:"j" 112 | layer_param { 113 | idx:10 114 | } 115 | } 116 | layer { 117 | name:"Node_11" 118 | type:"add(value)" 119 | top:"Node_11" 120 | bottom:"Node_9" 121 | bottom:"Node_10" 122 | layer_param { 123 | idx:11 124 | } 125 | } 126 | layer { 127 | name:"Node_12" 128 | type:"buffer_store" 129 | top:"Node_12" 130 | bottom:"C" 131 | bottom:"Node_11" 132 | bottom:"i" 133 | bottom:"j" 134 | layer_param { 135 | idx:12 136 | value_00:"(A: Buffer(A_1: Pointer(float32), float32, [tindex: int32, tindex], [stride: int32, stride_1: int32], type='auto')[i: int32, j: int32] + B: Buffer(B_1: Pointer(float32), float32, [tindex, tindex], [stride_2: int32, stride_3: int32], type='auto')[i, j])" 137 | indices_00:"[i, j]" 138 | } 139 | } 140 | layer { 141 | name:"Node_13" 142 | type:"for" 143 | top:"Node_13" 144 | bottom:"j" 145 | bottom:"Node_3" 146 | bottom:"tindex" 147 | bottom:"Node_12" 148 | layer_param { 149 | idx:13 150 | kind:0 151 | body_00:"C[i, j] = (A[i, j] + B[i, j])" 152 | } 153 | } 154 | 
layer { 155 | name:"Node_14" 156 | type:"for" 157 | top:"Node_14" 158 | bottom:"i" 159 | bottom:"Node_3" 160 | bottom:"tindex" 161 | bottom:"Node_13" 162 | layer_param { 163 | idx:14 164 | kind:0 165 | body_00:"for (j, 0, tindex)" 166 | body_01:" C[i, j] = (A[i, j] + B[i, j])" 167 | } 168 | } 169 | layer { 170 | name:"Node_15" 171 | type:"buffer_realize" 172 | top:"Node_15" 173 | bottom:"Node_5" 174 | bottom:"Node_6" 175 | bottom:"C" 176 | bottom:"Node_14" 177 | layer_param { 178 | idx:15 179 | condition:True 180 | body_00:"for (i, 0, tindex)" 181 | body_01:" for (j, 0, tindex)" 182 | body_02:" C[i, j] = (A[i, j] + B[i, j])" 183 | bounds_00:"[range(min=0, ext=tindex), range(min=0, ext=tindex)]" 184 | } 185 | } 186 | layer { 187 | name:"Node_16" 188 | type:"attribute" 189 | top:"Node_16" 190 | bottom:"C" 191 | bottom:"Node_15" 192 | layer_param { 193 | idx:16 194 | attr_key:realize_scope 195 | body_00:"buffer_realize C([0, tindex], [0, tindex])" 196 | body_01:" for (i, 0, tindex)" 197 | body_02:" for (j, 0, tindex)" 198 | body_03:" C[i, j] = (A[i, j] + B[i, j])" 199 | value_00:"''" 200 | } 201 | } 202 | layer { 203 | name:"Node_17" 204 | type:"primfunc" 205 | top:"Node_17" 206 | bottom:"A" 207 | bottom:"B" 208 | bottom:"C" 209 | bottom:"Node_16" 210 | layer_param { 211 | idx:17 212 | body_00:"// attr [buffer(C, 0x7fed84a9e610)] realize_scope = ''" 213 | body_01:"buffer_realize C([0, tindex], [0, tindex])" 214 | body_02:" for (i, 0, tindex)" 215 | body_03:" for (j, 0, tindex)" 216 | body_04:" C[i, j] = (A[i, j] + B[i, j])" 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /stmt_optimize/visualizes/vectorize_loop_after.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[1024, 1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[1024, 1024] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"C" 26 | type:"buffer" 27 | top:"C" 28 | layer_param { 29 | idx:2 30 | buffer_name:"C" 31 | shape:[1024, 1024] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"x.outer" 37 | type:"var(loop_var)" 38 | top:"x.outer" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"Node_4" 46 | type:"int(for_min)" 47 | top:"Node_4" 48 | layer_param { 49 | idx:4 50 | value:0 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"Node_5" 56 | type:"int(for_extent)" 57 | top:"Node_5" 58 | layer_param { 59 | idx:5 60 | value:32 61 | dtype:int32 62 | } 63 | } 64 | layer { 65 | name:"y.outer" 66 | type:"var(loop_var)" 67 | top:"y.outer" 68 | layer_param { 69 | idx:6 70 | dtype:int32 71 | } 72 | } 73 | layer { 74 | name:"x.inner" 75 | type:"var(loop_var)" 76 | top:"x.inner" 77 | layer_param { 78 | idx:7 79 | dtype:int32 80 | } 81 | } 82 | layer { 83 | name:"C_1" 84 | type:"var(store_buffer)" 85 | top:"C_1" 86 | layer_param { 87 | idx:8 88 | dtype:handle 89 | } 90 | } 91 | layer { 92 | name:"A_1" 93 | type:"var(load_buffer)" 94 | top:"A_1" 95 | layer_param { 96 | idx:9 97 | dtype:handle 98 | } 99 | } 100 | layer { 101 | name:"Node_10" 102 | type:"int(b)" 103 | top:"Node_10" 104 | layer_param { 105 | idx:10 106 | value:32768 107 | dtype:int32 108 | } 109 | } 110 | layer { 111 | name:"Node_11" 112 | type:"mul(a)" 113 | top:"Node_11" 114 | bottom:"x.outer" 
115 | bottom:"Node_10" 116 | layer_param { 117 | idx:11 118 | } 119 | } 120 | layer { 121 | name:"Node_12" 122 | type:"int(b)" 123 | top:"Node_12" 124 | layer_param { 125 | idx:12 126 | value:1024 127 | dtype:int32 128 | } 129 | } 130 | layer { 131 | name:"Node_13" 132 | type:"mul(b)" 133 | top:"Node_13" 134 | bottom:"x.inner" 135 | bottom:"Node_12" 136 | layer_param { 137 | idx:13 138 | } 139 | } 140 | layer { 141 | name:"Node_14" 142 | type:"add(a)" 143 | top:"Node_14" 144 | bottom:"Node_11" 145 | bottom:"Node_13" 146 | layer_param { 147 | idx:14 148 | } 149 | } 150 | layer { 151 | name:"Node_15" 152 | type:"mul(b)" 153 | top:"Node_15" 154 | bottom:"y.outer" 155 | bottom:"Node_5" 156 | layer_param { 157 | idx:15 158 | } 159 | } 160 | layer { 161 | name:"Node_16" 162 | type:"add(base)" 163 | top:"Node_16" 164 | bottom:"Node_14" 165 | bottom:"Node_15" 166 | layer_param { 167 | idx:16 168 | } 169 | } 170 | layer { 171 | name:"Node_17" 172 | type:"int(stride)" 173 | top:"Node_17" 174 | layer_param { 175 | idx:17 176 | value:1 177 | dtype:int32 178 | } 179 | } 180 | layer { 181 | name:"Node_18" 182 | type:"ramp(load_index)" 183 | top:"Node_18" 184 | bottom:"Node_16" 185 | bottom:"Node_17" 186 | layer_param { 187 | idx:18 188 | lanes:32 189 | base_00:"(((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32))" 190 | stride_00:"1" 191 | } 192 | } 193 | layer { 194 | name:"Node_19" 195 | type:"load(a)" 196 | top:"Node_19" 197 | bottom:"A_1" 198 | bottom:"Node_18" 199 | layer_param { 200 | idx:19 201 | predicate_00:"broadcast(True, 32)" 202 | body_00:"(float32x32*)A: Pointer(float32)[ramp((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)), 1, 32)]" 203 | } 204 | } 205 | layer { 206 | name:"B_1" 207 | type:"var(load_buffer)" 208 | top:"B_1" 209 | layer_param { 210 | idx:20 211 | dtype:handle 212 | } 213 | } 214 | layer { 215 | name:"Node_21" 216 | type:"mul(a)" 217 | top:"Node_21" 218 | bottom:"x.outer" 219 | bottom:"Node_10" 220 | layer_param { 221 | idx:21 222 | } 223 | } 224 | layer { 225 | name:"Node_22" 226 | type:"mul(b)" 227 | top:"Node_22" 228 | bottom:"x.inner" 229 | bottom:"Node_12" 230 | layer_param { 231 | idx:22 232 | } 233 | } 234 | layer { 235 | name:"Node_23" 236 | type:"add(a)" 237 | top:"Node_23" 238 | bottom:"Node_21" 239 | bottom:"Node_22" 240 | layer_param { 241 | idx:23 242 | } 243 | } 244 | layer { 245 | name:"Node_24" 246 | type:"mul(b)" 247 | top:"Node_24" 248 | bottom:"y.outer" 249 | bottom:"Node_5" 250 | layer_param { 251 | idx:24 252 | } 253 | } 254 | layer { 255 | name:"Node_25" 256 | type:"add(base)" 257 | top:"Node_25" 258 | bottom:"Node_23" 259 | bottom:"Node_24" 260 | layer_param { 261 | idx:25 262 | } 263 | } 264 | layer { 265 | name:"Node_26" 266 | type:"ramp(load_index)" 267 | top:"Node_26" 268 | bottom:"Node_25" 269 | bottom:"Node_17" 270 | layer_param { 271 | idx:26 272 | lanes:32 273 | base_00:"(((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32))" 274 | stride_00:"1" 275 | } 276 | } 277 | layer { 278 | name:"Node_27" 279 | type:"load(b)" 280 | top:"Node_27" 281 | bottom:"B_1" 282 | bottom:"Node_26" 283 | layer_param { 284 | idx:27 285 | predicate_00:"broadcast(True, 32)" 286 | body_00:"(float32x32*)B: Pointer(float32)[ramp((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)), 1, 32)]" 287 | } 288 | } 289 | layer { 290 | name:"Node_28" 291 | type:"add(store_value)" 292 | top:"Node_28" 293 | bottom:"Node_19" 294 | bottom:"Node_27" 295 | layer_param { 296 | idx:28 297 | } 298 | } 299 | 
layer { 300 | name:"Node_29" 301 | type:"mul(a)" 302 | top:"Node_29" 303 | bottom:"x.outer" 304 | bottom:"Node_10" 305 | layer_param { 306 | idx:29 307 | } 308 | } 309 | layer { 310 | name:"Node_30" 311 | type:"mul(b)" 312 | top:"Node_30" 313 | bottom:"x.inner" 314 | bottom:"Node_12" 315 | layer_param { 316 | idx:30 317 | } 318 | } 319 | layer { 320 | name:"Node_31" 321 | type:"add(a)" 322 | top:"Node_31" 323 | bottom:"Node_29" 324 | bottom:"Node_30" 325 | layer_param { 326 | idx:31 327 | } 328 | } 329 | layer { 330 | name:"Node_32" 331 | type:"mul(b)" 332 | top:"Node_32" 333 | bottom:"y.outer" 334 | bottom:"Node_5" 335 | layer_param { 336 | idx:32 337 | } 338 | } 339 | layer { 340 | name:"Node_33" 341 | type:"add(base)" 342 | top:"Node_33" 343 | bottom:"Node_31" 344 | bottom:"Node_32" 345 | layer_param { 346 | idx:33 347 | } 348 | } 349 | layer { 350 | name:"Node_34" 351 | type:"ramp(store_index)" 352 | top:"Node_34" 353 | bottom:"Node_33" 354 | bottom:"Node_17" 355 | layer_param { 356 | idx:34 357 | lanes:32 358 | base_00:"(((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32))" 359 | stride_00:"1" 360 | } 361 | } 362 | layer { 363 | name:"Node_35" 364 | type:"store" 365 | top:"Node_35" 366 | bottom:"C_1" 367 | bottom:"Node_28" 368 | bottom:"Node_34" 369 | layer_param { 370 | idx:35 371 | predicate_00:"broadcast(True, 32)" 372 | value_00:"((float32x32*)A: Pointer(float32)[ramp((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)), 1, 32)] + (float32x32*)B: Pointer(float32)[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 373 | index_00:"ramp((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)), 1, 32)" 374 | body_00:"C[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = (A[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + B[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 375 | } 376 | } 377 | layer { 378 | name:"Node_36" 379 | type:"for" 380 | top:"Node_36" 381 | bottom:"x.inner" 382 | bottom:"Node_4" 383 | bottom:"Node_5" 384 | bottom:"Node_35" 385 | layer_param { 386 | idx:36 387 | kind:0 388 | body_00:"C[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = (A[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + B[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 389 | } 390 | } 391 | layer { 392 | name:"Node_37" 393 | type:"for" 394 | top:"Node_37" 395 | bottom:"y.outer" 396 | bottom:"Node_4" 397 | bottom:"Node_5" 398 | bottom:"Node_36" 399 | layer_param { 400 | idx:37 401 | kind:0 402 | body_00:"for (x.inner, 0, 32)" 403 | body_01:" C[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = (A[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + B[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 404 | } 405 | } 406 | layer { 407 | name:"Node_38" 408 | type:"for" 409 | top:"Node_38" 410 | bottom:"x.outer" 411 | bottom:"Node_4" 412 | bottom:"Node_5" 413 | bottom:"Node_37" 414 | layer_param { 415 | idx:38 416 | kind:0 417 | body_00:"for (y.outer, 0, 32)" 418 | body_01:" for (x.inner, 0, 32)" 419 | body_02:" C[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = (A[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + B[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 420 | } 421 | } 422 | layer { 423 | name:"Node_39" 424 | type:"primfunc" 425 | top:"Node_39" 426 | bottom:"A" 427 | bottom:"B" 428 | bottom:"C" 429 | 
bottom:"Node_38" 430 | layer_param { 431 | idx:39 432 | body_00:"for (x.outer, 0, 32)" 433 | body_01:" for (y.outer, 0, 32)" 434 | body_02:" for (x.inner, 0, 32)" 435 | body_03:" C[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = (A[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + B[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)])" 436 | } 437 | } 438 | -------------------------------------------------------------------------------- /stmt_optimize/visualizes/vectorize_loop_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"A" 4 | type:"buffer" 5 | top:"A" 6 | layer_param { 7 | idx:0 8 | buffer_name:"A" 9 | shape:[1024, 1024] 10 | dtype:float32 11 | } 12 | } 13 | layer { 14 | name:"B" 15 | type:"buffer" 16 | top:"B" 17 | layer_param { 18 | idx:1 19 | buffer_name:"B" 20 | shape:[1024, 1024] 21 | dtype:float32 22 | } 23 | } 24 | layer { 25 | name:"C" 26 | type:"buffer" 27 | top:"C" 28 | layer_param { 29 | idx:2 30 | buffer_name:"C" 31 | shape:[1024, 1024] 32 | dtype:float32 33 | } 34 | } 35 | layer { 36 | name:"x.outer" 37 | type:"var(loop_var)" 38 | top:"x.outer" 39 | layer_param { 40 | idx:3 41 | dtype:int32 42 | } 43 | } 44 | layer { 45 | name:"Node_4" 46 | type:"int(for_min)" 47 | top:"Node_4" 48 | layer_param { 49 | idx:4 50 | value:0 51 | dtype:int32 52 | } 53 | } 54 | layer { 55 | name:"Node_5" 56 | type:"int(for_extent)" 57 | top:"Node_5" 58 | layer_param { 59 | idx:5 60 | value:32 61 | dtype:int32 62 | } 63 | } 64 | layer { 65 | name:"y.outer" 66 | type:"var(loop_var)" 67 | top:"y.outer" 68 | layer_param { 69 | idx:6 70 | dtype:int32 71 | } 72 | } 73 | layer { 74 | name:"x.inner" 75 | type:"var(loop_var)" 76 | top:"x.inner" 77 | layer_param { 78 | idx:7 79 | dtype:int32 80 | } 81 | } 82 | layer { 83 | name:"y.inner" 84 | type:"var(loop_var)" 85 | top:"y.inner" 86 | layer_param { 87 | idx:8 88 | dtype:int32 89 | } 90 | } 91 | layer { 92 | name:"C_1" 93 | type:"var(store_buffer)" 94 | top:"C_1" 95 | layer_param { 96 | idx:9 97 | dtype:handle 98 | } 99 | } 100 | layer { 101 | name:"A_1" 102 | type:"var(load_buffer)" 103 | top:"A_1" 104 | layer_param { 105 | idx:10 106 | dtype:handle 107 | } 108 | } 109 | layer { 110 | name:"Node_11" 111 | type:"int(b)" 112 | top:"Node_11" 113 | layer_param { 114 | idx:11 115 | value:32768 116 | dtype:int32 117 | } 118 | } 119 | layer { 120 | name:"Node_12" 121 | type:"mul(a)" 122 | top:"Node_12" 123 | bottom:"x.outer" 124 | bottom:"Node_11" 125 | layer_param { 126 | idx:12 127 | } 128 | } 129 | layer { 130 | name:"Node_13" 131 | type:"int(b)" 132 | top:"Node_13" 133 | layer_param { 134 | idx:13 135 | value:1024 136 | dtype:int32 137 | } 138 | } 139 | layer { 140 | name:"Node_14" 141 | type:"mul(b)" 142 | top:"Node_14" 143 | bottom:"x.inner" 144 | bottom:"Node_13" 145 | layer_param { 146 | idx:14 147 | } 148 | } 149 | layer { 150 | name:"Node_15" 151 | type:"add(a)" 152 | top:"Node_15" 153 | bottom:"Node_12" 154 | bottom:"Node_14" 155 | layer_param { 156 | idx:15 157 | } 158 | } 159 | layer { 160 | name:"Node_16" 161 | type:"mul(b)" 162 | top:"Node_16" 163 | bottom:"y.outer" 164 | bottom:"Node_5" 165 | layer_param { 166 | idx:16 167 | } 168 | } 169 | layer { 170 | name:"Node_17" 171 | type:"add(a)" 172 | top:"Node_17" 173 | bottom:"Node_15" 174 | bottom:"Node_16" 175 | layer_param { 176 | idx:17 177 | } 178 | } 179 | layer { 180 | name:"Node_18" 181 | type:"add(load_index)" 182 | top:"Node_18" 183 | 
bottom:"Node_17" 184 | bottom:"y.inner" 185 | layer_param { 186 | idx:18 187 | } 188 | } 189 | layer { 190 | name:"Node_19" 191 | type:"load(a)" 192 | top:"Node_19" 193 | bottom:"A_1" 194 | bottom:"Node_18" 195 | layer_param { 196 | idx:19 197 | predicate_00:"True" 198 | body_00:"(float32*)A: Pointer(float32)[((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)) + y.inner: int32)]" 199 | } 200 | } 201 | layer { 202 | name:"B_1" 203 | type:"var(load_buffer)" 204 | top:"B_1" 205 | layer_param { 206 | idx:20 207 | dtype:handle 208 | } 209 | } 210 | layer { 211 | name:"Node_21" 212 | type:"mul(a)" 213 | top:"Node_21" 214 | bottom:"x.outer" 215 | bottom:"Node_11" 216 | layer_param { 217 | idx:21 218 | } 219 | } 220 | layer { 221 | name:"Node_22" 222 | type:"mul(b)" 223 | top:"Node_22" 224 | bottom:"x.inner" 225 | bottom:"Node_13" 226 | layer_param { 227 | idx:22 228 | } 229 | } 230 | layer { 231 | name:"Node_23" 232 | type:"add(a)" 233 | top:"Node_23" 234 | bottom:"Node_21" 235 | bottom:"Node_22" 236 | layer_param { 237 | idx:23 238 | } 239 | } 240 | layer { 241 | name:"Node_24" 242 | type:"mul(b)" 243 | top:"Node_24" 244 | bottom:"y.outer" 245 | bottom:"Node_5" 246 | layer_param { 247 | idx:24 248 | } 249 | } 250 | layer { 251 | name:"Node_25" 252 | type:"add(a)" 253 | top:"Node_25" 254 | bottom:"Node_23" 255 | bottom:"Node_24" 256 | layer_param { 257 | idx:25 258 | } 259 | } 260 | layer { 261 | name:"Node_26" 262 | type:"add(load_index)" 263 | top:"Node_26" 264 | bottom:"Node_25" 265 | bottom:"y.inner" 266 | layer_param { 267 | idx:26 268 | } 269 | } 270 | layer { 271 | name:"Node_27" 272 | type:"load(b)" 273 | top:"Node_27" 274 | bottom:"B_1" 275 | bottom:"Node_26" 276 | layer_param { 277 | idx:27 278 | predicate_00:"True" 279 | body_00:"(float32*)B: Pointer(float32)[((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)) + y.inner: int32)]" 280 | } 281 | } 282 | layer { 283 | name:"Node_28" 284 | type:"add(store_value)" 285 | top:"Node_28" 286 | bottom:"Node_19" 287 | bottom:"Node_27" 288 | layer_param { 289 | idx:28 290 | } 291 | } 292 | layer { 293 | name:"Node_29" 294 | type:"mul(a)" 295 | top:"Node_29" 296 | bottom:"x.outer" 297 | bottom:"Node_11" 298 | layer_param { 299 | idx:29 300 | } 301 | } 302 | layer { 303 | name:"Node_30" 304 | type:"mul(b)" 305 | top:"Node_30" 306 | bottom:"x.inner" 307 | bottom:"Node_13" 308 | layer_param { 309 | idx:30 310 | } 311 | } 312 | layer { 313 | name:"Node_31" 314 | type:"add(a)" 315 | top:"Node_31" 316 | bottom:"Node_29" 317 | bottom:"Node_30" 318 | layer_param { 319 | idx:31 320 | } 321 | } 322 | layer { 323 | name:"Node_32" 324 | type:"mul(b)" 325 | top:"Node_32" 326 | bottom:"y.outer" 327 | bottom:"Node_5" 328 | layer_param { 329 | idx:32 330 | } 331 | } 332 | layer { 333 | name:"Node_33" 334 | type:"add(a)" 335 | top:"Node_33" 336 | bottom:"Node_31" 337 | bottom:"Node_32" 338 | layer_param { 339 | idx:33 340 | } 341 | } 342 | layer { 343 | name:"Node_34" 344 | type:"add(store_index)" 345 | top:"Node_34" 346 | bottom:"Node_33" 347 | bottom:"y.inner" 348 | layer_param { 349 | idx:34 350 | } 351 | } 352 | layer { 353 | name:"Node_35" 354 | type:"store" 355 | top:"Node_35" 356 | bottom:"C_1" 357 | bottom:"Node_28" 358 | bottom:"Node_34" 359 | layer_param { 360 | idx:35 361 | predicate_00:"True" 362 | value_00:"((float32*)A: Pointer(float32)[((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)) + y.inner: int32)] + (float32*)B: Pointer(float32)[((((x.outer*32768) + (x.inner*1024)) + 
(y.outer*32)) + y.inner)])" 363 | index_00:"((((x.outer: int32*32768) + (x.inner: int32*1024)) + (y.outer: int32*32)) + y.inner: int32)" 364 | body_00:"C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 365 | } 366 | } 367 | layer { 368 | name:"Node_36" 369 | type:"for" 370 | top:"Node_36" 371 | bottom:"y.inner" 372 | bottom:"Node_4" 373 | bottom:"Node_5" 374 | bottom:"Node_35" 375 | layer_param { 376 | idx:36 377 | kind:2 378 | body_00:"C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 379 | } 380 | } 381 | layer { 382 | name:"Node_37" 383 | type:"for" 384 | top:"Node_37" 385 | bottom:"x.inner" 386 | bottom:"Node_4" 387 | bottom:"Node_5" 388 | bottom:"Node_36" 389 | layer_param { 390 | idx:37 391 | kind:0 392 | body_00:"vectorized (y.inner, 0, 32)" 393 | body_01:" C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 394 | } 395 | } 396 | layer { 397 | name:"Node_38" 398 | type:"for" 399 | top:"Node_38" 400 | bottom:"y.outer" 401 | bottom:"Node_4" 402 | bottom:"Node_5" 403 | bottom:"Node_37" 404 | layer_param { 405 | idx:38 406 | kind:0 407 | body_00:"for (x.inner, 0, 32)" 408 | body_01:" vectorized (y.inner, 0, 32)" 409 | body_02:" C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 410 | } 411 | } 412 | layer { 413 | name:"Node_39" 414 | type:"for" 415 | top:"Node_39" 416 | bottom:"x.outer" 417 | bottom:"Node_4" 418 | bottom:"Node_5" 419 | bottom:"Node_38" 420 | layer_param { 421 | idx:39 422 | kind:0 423 | body_00:"for (y.outer, 0, 32)" 424 | body_01:" for (x.inner, 0, 32)" 425 | body_02:" vectorized (y.inner, 0, 32)" 426 | body_03:" C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 427 | } 428 | } 429 | layer { 430 | name:"Node_40" 431 | type:"primfunc" 432 | top:"Node_40" 433 | bottom:"A" 434 | bottom:"B" 435 | bottom:"C" 436 | bottom:"Node_39" 437 | layer_param { 438 | idx:40 439 | body_00:"for (x.outer, 0, 32)" 440 | body_01:" for (y.outer, 0, 32)" 441 | body_02:" for (x.inner, 0, 32)" 442 | body_03:" vectorized (y.inner, 0, 32)" 443 | body_04:" C[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (A[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] + B[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)])" 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /tvm_build/build_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.models as models 4 | 5 | import tvm 6 | from tvm import relay 7 | from tvm.relay.backend import graph_executor_codegen 8 | 9 | import sys 10 | sys.path.append("..") 11 | from visualize import PrimExprVisualizer 12 | visualizer=PrimExprVisualizer(simple_mode=False) 13 | 14 | def 
process_codegen(target,optimize_mixed=False): 15 | #prepare model and input 16 | model = models.resnet18(pretrained=True) 17 | shape_list = [("input0",(1,3,224,224))] 18 | fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32')) 19 | graph = torch.jit.trace(model,fake_input) 20 | #parse the model to get the main function 21 | mod, params = relay.frontend.from_pytorch(graph, shape_list) 22 | #optimize the mod 23 | with tvm.transform.PassContext(opt_level=3): 24 | mod, _ = relay.optimize(mod, target, params) 25 | grc = graph_executor_codegen.GraphExecutorCodegen(None, target) 26 | graph_json, lowered_func, params=grc.codegen(mod["main"]) 27 | for tar, input_mod in lowered_func.items(): 28 | if optimize_mixed: 29 | input_mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(input_mod) 30 | passes = [ 31 | tvm.tir.transform.VerifyMemory(), 32 | tvm.tir.transform.ThreadSync("shared"), 33 | tvm.tir.transform.ThreadSync("warp"), 34 | tvm.tir.transform.InferFragment(), 35 | tvm.tir.transform.LowerThreadAllreduce(), 36 | tvm.tir.transform.MakePackedAPI(), 37 | tvm.tir.transform.SplitHostDevice(), 38 | ] 39 | input_mod = tvm.transform.Sequential(passes)(input_mod) 40 | return input_mod 41 | 42 | def compare_mod(raw_mod,opt_mod,func_var=None,visual_name=None): 43 | #find a global var whose function is changed by the pass 44 | if not func_var: 45 | for var in raw_mod.get_global_vars(): 46 | if var not in opt_mod.get_global_vars(): 47 | continue 48 | try: 49 | raw_des=str(raw_mod[var]) 50 | opt_des=str(opt_mod[var]) 51 | except Exception: 52 | continue 53 | if raw_des!=opt_des: 54 | func_var=var 55 | break 56 | if not func_var: 57 | print("raw mod and optimized mod are the same") 58 | return 59 | print(" {} : {}".format(func_var,raw_mod[func_var])) 60 | print(" {} : {}".format(func_var,opt_mod[func_var])) 61 | 62 | if visual_name: 63 | visualizer.visualize(raw_mod[func_var],"visualizes/{}_before.prototxt".format(visual_name)) 64 | visualizer.visualize(opt_mod[func_var],"visualizes/{}_after.prototxt".format(visual_name)) 65 | 66 | def test_thread_sync(func_var=None): 67 | print("\nTest ThreadSync shared") 68 | target = tvm.target.Target("cuda") 69 | mod_mixed = process_codegen(target) 70 | raw_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) 71 | passes = [ 72 | tvm.tir.transform.VerifyMemory(), 73 | ] 74 | raw_mixed = tvm.transform.Sequential(passes)(raw_mixed) 75 | opt_mixed = tvm.transform.Sequential([tvm.tir.transform.ThreadSync("shared")])(raw_mixed) 76 | compare_mod(raw_mixed,opt_mixed,func_var) 77 | 78 | def test_make_packed_api(func_var=None): 79 | print("\nTest MakePackedAPI") 80 | target = tvm.target.Target("llvm") 81 | mod_mixed = process_codegen(target) 82 | raw_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) 83 | passes = [ 84 | tvm.tir.transform.VerifyMemory(), 85 | tvm.tir.transform.ThreadSync("shared"), 86 | tvm.tir.transform.ThreadSync("warp"), 87 | tvm.tir.transform.InferFragment(), 88 | tvm.tir.transform.LowerThreadAllreduce(), 89 | ] 90 | raw_mixed = tvm.transform.Sequential(passes)(raw_mixed) 91 | opt_mixed = tvm.transform.Sequential([tvm.tir.transform.MakePackedAPI()])(raw_mixed) 92 | compare_mod(raw_mixed,opt_mixed,func_var,visual_name="make_packed_api") 93 | 94 | def test_split_host_device(func_var=None): 95 | print("\nTest SplitHostDevice") 96 | target = tvm.target.Target("cuda") 97 | mod_mixed = process_codegen(target) 98 | raw_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) 99 | passes = [ 100
| tvm.tir.transform.VerifyMemory(), 101 | tvm.tir.transform.ThreadSync("shared"), 102 | tvm.tir.transform.ThreadSync("warp"), 103 | tvm.tir.transform.InferFragment(), 104 | tvm.tir.transform.LowerThreadAllreduce(), 105 | tvm.tir.transform.MakePackedAPI() 106 | ] 107 | raw_mixed = tvm.transform.Sequential(passes)(raw_mixed) 108 | opt_mixed = tvm.transform.Sequential([tvm.tir.transform.SplitHostDevice()])(raw_mixed) 109 | print(" fused_nn_dense_add : "+str(raw_mixed["fused_nn_dense_add"])) 110 | print(" fused_nn_dense_add "+str(opt_mixed["fused_nn_dense_add"])) 111 | print(" fused_nn_dense_add_kernel0 "+str(opt_mixed["fused_nn_dense_add_kernel0"])) 112 | visualizer.visualize(raw_mixed["fused_nn_dense_add"],"visualizes/split_host_device_before.prototxt") 113 | visualizer.visualize(opt_mixed["fused_nn_dense_add"],"visualizes/split_host_device_after_host.prototxt") 114 | visualizer.visualize(opt_mixed["fused_nn_dense_add_kernel0"],"visualizes/split_host_device_after_device.prototxt") 115 | 116 | def test_lower_tvm_builtin(func_var=None): 117 | print("\nTest LowerTVMBuiltin") 118 | target = tvm.target.Target("cuda") 119 | target, target_host = tvm.target.Target.check_and_update_host_consist(target) 120 | mod_mixed = process_codegen(target,optimize_mixed=True) 121 | passes = [ 122 | tvm.tir.transform.Filter( 123 | lambda f: "calling_conv" in f.attrs 124 | and f.attrs["calling_conv"].value != tvm.ir.CallingConv.DEVICE_KERNEL_LAUNCH 125 | ), 126 | tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), 127 | ] 128 | raw_dev = tvm.transform.Sequential(passes)(mod_mixed) 129 | opt_dev = tvm.transform.Sequential([tvm.tir.transform.LowerTVMBuiltin(),])(raw_dev) 130 | compare_mod(raw_dev,opt_dev,func_var,visual_name="lower_tvm_builtin") 131 | 132 | if __name__=='__main__': 133 | test_thread_sync() 134 | test_make_packed_api() 135 | test_split_host_device() 136 | test_lower_tvm_builtin("fused_nn_dense_add") -------------------------------------------------------------------------------- /tvm_build/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.models as models 4 | 5 | import tvm 6 | from tvm import relay 7 | 8 | if __name__=='__main__': 9 | #prepare model and input 10 | model = models.resnet18(pretrained=True) 11 | shape_list = [("input0",(1,3,224,224))] 12 | fake_input = torch.from_numpy(np.random.random_sample(shape_list[0][1]).astype('float32')) 13 | graph = torch.jit.trace(model,fake_input) 14 | #main function 15 | mod, params = relay.frontend.from_pytorch(graph, shape_list) 16 | #optimize the mod 17 | target = tvm.target.Target("llvm", host="llvm") 18 | with tvm.transform.PassContext(opt_level=3): 19 | graph_json, mod, params = relay.build(mod, target=target, params=params) -------------------------------------------------------------------------------- /tvm_build/visualizes/make_packed_api_before.prototxt: -------------------------------------------------------------------------------- 1 | name : "prim_expr" 2 | layer { 3 | name:"placeholder" 4 | type:"buffer" 5 | top:"placeholder" 6 | layer_param { 7 | idx:0 8 | shape:[1, 2, 56, 56, 32] 9 | dtype:float32 10 | name:"placeholder" 11 | } 12 | } 13 | layer { 14 | name:"T_layout_trans" 15 | type:"buffer" 16 | top:"T_layout_trans" 17 | layer_param { 18 | idx:1 19 | shape:[1, 8, 56, 56, 8] 20 | dtype:float32 21 | name:"T_layout_trans" 22 | } 23 | } 24 | layer { 25 | name:"ax0.ax1.fused.ax2.fused" 26 | type:"var(loop_var)" 27 | 
top:"ax0.ax1.fused.ax2.fused" 28 | layer_param { 29 | idx:2 30 | dtype:int32 31 | } 32 | } 33 | layer { 34 | name:"Node_3" 35 | type:"int(for_min)" 36 | top:"Node_3" 37 | layer_param { 38 | idx:3 39 | value:0 40 | dtype:int32 41 | } 42 | } 43 | layer { 44 | name:"Node_4" 45 | type:"int(for_extent)" 46 | top:"Node_4" 47 | layer_param { 48 | idx:4 49 | value:448 50 | dtype:int32 51 | } 52 | } 53 | layer { 54 | name:"ax3" 55 | type:"var(loop_var)" 56 | top:"ax3" 57 | layer_param { 58 | idx:5 59 | dtype:int32 60 | } 61 | } 62 | layer { 63 | name:"Node_6" 64 | type:"int(for_extent)" 65 | top:"Node_6" 66 | layer_param { 67 | idx:6 68 | value:56 69 | dtype:int32 70 | } 71 | } 72 | layer { 73 | name:"T_layout_trans_1" 74 | type:"var(store_buffer)" 75 | top:"T_layout_trans_1" 76 | layer_param { 77 | idx:7 78 | dtype:handle 79 | } 80 | } 81 | layer { 82 | name:"placeholder_1" 83 | type:"var(load_buffer)" 84 | top:"placeholder_1" 85 | layer_param { 86 | idx:8 87 | dtype:handle 88 | } 89 | } 90 | layer { 91 | name:"Node_9" 92 | type:"int(b)" 93 | top:"Node_9" 94 | layer_param { 95 | idx:9 96 | value:224 97 | dtype:int32 98 | } 99 | } 100 | layer { 101 | name:"Node_10" 102 | type:"floor_div(a)" 103 | top:"Node_10" 104 | bottom:"ax0.ax1.fused.ax2.fused" 105 | bottom:"Node_9" 106 | layer_param { 107 | idx:10 108 | } 109 | } 110 | layer { 111 | name:"Node_11" 112 | type:"int(b)" 113 | top:"Node_11" 114 | layer_param { 115 | idx:11 116 | value:100352 117 | dtype:int32 118 | } 119 | } 120 | layer { 121 | name:"Node_12" 122 | type:"mul(a)" 123 | top:"Node_12" 124 | bottom:"Node_10" 125 | bottom:"Node_11" 126 | layer_param { 127 | idx:12 128 | } 129 | } 130 | layer { 131 | name:"Node_13" 132 | type:"floor_mod(a)" 133 | top:"Node_13" 134 | bottom:"ax0.ax1.fused.ax2.fused" 135 | bottom:"Node_6" 136 | layer_param { 137 | idx:13 138 | } 139 | } 140 | layer { 141 | name:"Node_14" 142 | type:"int(b)" 143 | top:"Node_14" 144 | layer_param { 145 | idx:14 146 | value:1792 147 | dtype:int32 148 | } 149 | } 150 | layer { 151 | name:"Node_15" 152 | type:"mul(b)" 153 | top:"Node_15" 154 | bottom:"Node_13" 155 | bottom:"Node_14" 156 | layer_param { 157 | idx:15 158 | } 159 | } 160 | layer { 161 | name:"Node_16" 162 | type:"add(a)" 163 | top:"Node_16" 164 | bottom:"Node_12" 165 | bottom:"Node_15" 166 | layer_param { 167 | idx:16 168 | } 169 | } 170 | layer { 171 | name:"Node_17" 172 | type:"int(b)" 173 | top:"Node_17" 174 | layer_param { 175 | idx:17 176 | value:32 177 | dtype:int32 178 | } 179 | } 180 | layer { 181 | name:"Node_18" 182 | type:"mul(b)" 183 | top:"Node_18" 184 | bottom:"ax3" 185 | bottom:"Node_17" 186 | layer_param { 187 | idx:18 188 | } 189 | } 190 | layer { 191 | name:"Node_19" 192 | type:"add(a)" 193 | top:"Node_19" 194 | bottom:"Node_16" 195 | bottom:"Node_18" 196 | layer_param { 197 | idx:19 198 | } 199 | } 200 | layer { 201 | name:"Node_20" 202 | type:"floor_mod(a)" 203 | top:"Node_20" 204 | bottom:"ax0.ax1.fused.ax2.fused" 205 | bottom:"Node_9" 206 | layer_param { 207 | idx:20 208 | } 209 | } 210 | layer { 211 | name:"Node_21" 212 | type:"floor_div(a)" 213 | top:"Node_21" 214 | bottom:"Node_20" 215 | bottom:"Node_6" 216 | layer_param { 217 | idx:21 218 | } 219 | } 220 | layer { 221 | name:"Node_22" 222 | type:"int(b)" 223 | top:"Node_22" 224 | layer_param { 225 | idx:22 226 | value:8 227 | dtype:int32 228 | } 229 | } 230 | layer { 231 | name:"Node_23" 232 | type:"mul(b)" 233 | top:"Node_23" 234 | bottom:"Node_21" 235 | bottom:"Node_22" 236 | layer_param { 237 | idx:23 238 | } 239 | } 240 | layer { 241 
| name:"Node_24" 242 | type:"add(base)" 243 | top:"Node_24" 244 | bottom:"Node_19" 245 | bottom:"Node_23" 246 | layer_param { 247 | idx:24 248 | } 249 | } 250 | layer { 251 | name:"Node_25" 252 | type:"int(stride)" 253 | top:"Node_25" 254 | layer_param { 255 | idx:25 256 | value:1 257 | dtype:int32 258 | } 259 | } 260 | layer { 261 | name:"Node_26" 262 | type:"ramp(load_index)" 263 | top:"Node_26" 264 | bottom:"Node_24" 265 | bottom:"Node_25" 266 | layer_param { 267 | idx:26 268 | lanes:8 269 | base:"((((floordiv(ax0.ax1.fused.ax2.fused: int32, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3: int32*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8))" 270 | stride:"1" 271 | } 272 | } 273 | layer { 274 | name:"Node_27" 275 | type:"load(store_value)" 276 | top:"Node_27" 277 | bottom:"placeholder_1" 278 | bottom:"Node_26" 279 | layer_param { 280 | idx:27 281 | predicate:"broadcast(True, 8)" 282 | body:"(float32x8*)placeholder: Pointer(float32)[ramp(((((floordiv(ax0.ax1.fused.ax2.fused: int32, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3: int32*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]" 283 | } 284 | } 285 | layer { 286 | name:"Node_28" 287 | type:"mul(a)" 288 | top:"Node_28" 289 | bottom:"ax0.ax1.fused.ax2.fused" 290 | bottom:"Node_4" 291 | layer_param { 292 | idx:28 293 | } 294 | } 295 | layer { 296 | name:"Node_29" 297 | type:"mul(b)" 298 | top:"Node_29" 299 | bottom:"ax3" 300 | bottom:"Node_22" 301 | layer_param { 302 | idx:29 303 | } 304 | } 305 | layer { 306 | name:"Node_30" 307 | type:"add(base)" 308 | top:"Node_30" 309 | bottom:"Node_28" 310 | bottom:"Node_29" 311 | layer_param { 312 | idx:30 313 | } 314 | } 315 | layer { 316 | name:"Node_31" 317 | type:"ramp(store_index)" 318 | top:"Node_31" 319 | bottom:"Node_30" 320 | bottom:"Node_25" 321 | layer_param { 322 | idx:31 323 | lanes:8 324 | base:"((ax0.ax1.fused.ax2.fused: int32*448) + (ax3: int32*8))" 325 | stride:"1" 326 | } 327 | } 328 | layer { 329 | name:"Node_32" 330 | type:"store" 331 | top:"Node_32" 332 | bottom:"T_layout_trans_1" 333 | bottom:"Node_27" 334 | bottom:"Node_31" 335 | layer_param { 336 | idx:32 337 | predicate:"broadcast(True, 8)" 338 | value:"(float32x8*)placeholder: Pointer(float32)[ramp(((((floordiv(ax0.ax1.fused.ax2.fused: int32, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3: int32*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]" 339 | index:"ramp(((ax0.ax1.fused.ax2.fused: int32*448) + (ax3: int32*8)), 1, 8)" 340 | body_:"T_layout_trans[ramp(((ax0.ax1.fused.ax2.fused*448) + (ax3*8)), 1, 8)] = placeholder[ramp(((((floordiv(ax0.ax1.fused.ax2.fused, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]" 341 | } 342 | } 343 | layer { 344 | name:"Node_33" 345 | type:"for" 346 | top:"Node_33" 347 | bottom:"ax3" 348 | bottom:"Node_3" 349 | bottom:"Node_6" 350 | bottom:"Node_32" 351 | layer_param { 352 | idx:33 353 | kind:0 354 | body_:"T_layout_trans[ramp(((ax0.ax1.fused.ax2.fused*448) + (ax3*8)), 1, 8)] = placeholder[ramp(((((floordiv(ax0.ax1.fused.ax2.fused, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]" 355 | } 356 | } 357 | layer { 358 | name:"Node_34" 359 | type:"for" 360 | top:"Node_34" 361 | bottom:"ax0.ax1.fused.ax2.fused" 362 | bottom:"Node_3" 363 | bottom:"Node_4" 364 | bottom:"Node_33" 365 | 
365 |   layer_param {
366 |     idx:34
367 |     kind:1
368 |     body_00:"for (ax3, 0, 56)"
369 |     body_01:" T_layout_trans[ramp(((ax0.ax1.fused.ax2.fused*448) + (ax3*8)), 1, 8)] = placeholder[ramp(((((floordiv(ax0.ax1.fused.ax2.fused, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]"
370 |   }
371 | }
372 | layer {
373 |   name:"Node_35"
374 |   type:"primfunc"
375 |   top:"Node_35"
376 |   bottom:"placeholder"
377 |   bottom:"T_layout_trans"
378 |   bottom:"Node_34"
379 |   layer_param {
380 |     idx:35
381 |     body_00:"parallel (ax0.ax1.fused.ax2.fused, 0, 448)"
382 |     body_01:" for (ax3, 0, 56)"
383 |     body_02:" T_layout_trans[ramp(((ax0.ax1.fused.ax2.fused*448) + (ax3*8)), 1, 8)] = placeholder[ramp(((((floordiv(ax0.ax1.fused.ax2.fused, 224)*100352) + (floormod(ax0.ax1.fused.ax2.fused, 56)*1792)) + (ax3*32)) + (floordiv(floormod(ax0.ax1.fused.ax2.fused, 224), 56)*8)), 1, 8)]"
384 |   }
385 | }
386 | 
--------------------------------------------------------------------------------
/tvm_concepts.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Archermmt/tvm_walk_through/7e1e6e8f8061c5e79e5e399b2a85c1f404ef5893/tvm_concepts.pdf
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tvm
2 | import torch
3 | import numpy as np
4 | 
5 | def cast_array(array):
6 |     if isinstance(array,tvm.runtime.ndarray.NDArray):
7 |         array=array.asnumpy()
8 |     elif isinstance(array,torch.Tensor):
9 |         array=array.detach().cpu().numpy()
10 |     assert isinstance(array,np.ndarray),"Only accept array as numpy.ndarray, got "+str(type(array))
11 |     return array
12 | 
13 | def array_des(array):
14 |     type_des=array.__class__.__name__
15 |     array=cast_array(array)
16 |     return "<{}>[{};{}] max {:g}, min {:g}, sum {:g}".format(
17 |         type_des,','.join([str(s) for s in array.shape]),array.dtype.name,
18 |         array.max(),array.min(),array.sum())
19 | 
20 | def array_compare(arrayA,arrayB,nameA="A",nameB="B",error=0.05):
21 |     arrayA=cast_array(arrayA)
22 |     arrayB=cast_array(arrayB)
23 |     if arrayA.dtype!=arrayB.dtype:
24 |         print("dtype mismatch between {} and {}".format(arrayA.dtype,arrayB.dtype))
25 |     if arrayA.shape!=arrayB.shape:
26 |         print("shape mismatch between {} and {}".format(arrayA.shape,arrayB.shape))
27 |     diff=(arrayA-arrayB)/(abs(arrayA)+0.0001) #relative difference, eps avoids division by zero
28 |     msg="max : {:g}, min : {:g}, sum : {:g}".format(diff.max(),diff.min(),diff.sum())
29 |     if abs(diff).max()>error:
30 |         print("[FAIL] "+msg)
31 |         print("{} : {}".format(nameA,array_des(arrayA)))
32 |         print("{} : {}".format(nameB,array_des(arrayB)))
33 |         return False
34 |     print("[PASS] "+msg)
35 |     return True
--------------------------------------------------------------------------------