├── requirements.txt
├── README.md
├── tuner
    ├── dynamic_model_based_tuner.py
    └── RF_cost_model.py
├── tune.py
└── measure
    └── measure_methods.py


/requirements.txt:
--------------------------------------------------------------------------------
# python>=3.6
numpy
decorator
attrs
tornado
psutil
xgboost
tensorflow==1.13.1
scikit-learn
sklearn
scipy
antlr4-python3-runtime

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AdaTune: Adaptive Tensor Program Compilation Made Efficient

This repository is the official implementation of AdaTune: Adaptive Tensor Program Compilation Made Efficient.

## Requirements

Install TVM first. You can find TVM installation instructions [here](https://tvm.apache.org/docs/install/from_source.html).
>Prepare llvm:
```
wget https://releases.llvm.org/6.0.0/clang+llvm-6.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz
tar xvJf clang+llvm-6.0.0-x86_64-linux-gnu-ubuntu-16.04.tar.xz
```

>Clone the TVM project from github:
```
git clone --recursive https://github.com/apache/incubator-tvm tvm
sudo apt-get update
sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
cd tvm
mkdir build
cp cmake/config.cmake build
```
>Edit build/config.cmake:
```
set(USE_LLVM /bin/llvm-config)
set(USE_CUDA ON)  # optional: skip this line if you only want to test on CPU
```
>Building:
```
cd build
cmake ..
make -j6
```
>Add TVM into PYTHONPATH, edit your ~/.bashrc:
```
export TVM_HOME=/path/to/tvm
export PYTHONPATH=$TVM_HOME/python:$TVM_HOME/topi/python:${PYTHONPATH}
```
>Install other required packages:
```
pip install -r requirements.txt
```
>Add AdaTune files.
```
cp tuner/* /python/tvm/autotvm/tuner/
cp measure/measure_methods.py /python/tvm/autotvm/measure/
```

## Optimizing Models and Evaluation

To obtain the end-to-end experimental results in the paper, run the following command:

```
python tune.py
--model_name # for example: 'resnet-18','squeezenet_v1.1','vgg-16'
--use_gpu # bool, True/False
--tuner # for example: 'ada', 'xgb'
--ops # for example: 'conv2d', 'dense'
```

> If the use_gpu flag is set to True, TVM should have been compiled with CUDA.
> The tune.py file tunes the selected dense and conv ops in the model and then evaluates the inference latency of the optimized model. These models are constructed as TVM Relay modules. Please refer to the [TVM tutorial](https://tvm.apache.org/docs/tutorials/index.html) to tune more models in different formats.

### Testing environment
All the results from the paper are collected on the following hardware.
+ CPU: Intel Xeon x86 CPU E5-2690 v3
+ GPU: Nvidia Tesla P100

## Results

Our method achieves the following performance (optimization time) on the Resnet-18, Resnet-50, VGG-16, Squeezenet_V1.1, and Encoder models compared with AutoTVM (XGBTuner):

#### Compilation time comparison
| Model name      | AutoTVM(GPU) | AdaTune(GPU) | Speedup | AutoTVM(CPU) | AdaTune(CPU) | Speedup |
| --------------- | ------------ | ------------ | ------- | ------------ | ------------ | ------- |
| Resnet-18       | 22.6h        | 9.6h         | 2.4X    | 2.0h         | 1.0h         | 2.0X    |
| Resnet-50       | 20.0h        | 14.1h        | 1.4X    | 3.6h         | 1.7h         | 2.1X    |
| VGG-16          | 21.9h        | 16.7h        | 1.3X    | 18.9h        | 6.5h         | 2.9X    |
| Squeezenet_V1.1 | 7.6h         | 5.8h         | 1.3X    | 1.2h         | 0.7h         | 1.7X    |
| Encoder         | 3.8h         | 2.8h         | 1.4X    | 8.4h         | 3.8h         | 2.2X    |

#### Inference time comparison
| Model name      | TVM(GPU) | AutoTVM(GPU) | AdaTune(GPU) | TVM(CPU)  | AutoTVM(CPU) | AdaTune(CPU) |
| --------------- | -------- | ------------ | ------------ | --------- | ------------ | ------------ |
| Resnet-18       | 1.53ms   | 1.38ms       | 1.38ms       | 79.24ms   | 52.64ms      | 52.64ms      |
| Resnet-50       | 4.82ms   | 4.37ms       | 4.37ms       | 217.12ms  | 115.76ms     | 115.68ms     |
| VGG-16          | 3.95ms   | 3.86ms       | 3.86ms       | 884.94ms  | 442.01ms     | 438.68ms     |
| Squeezenet_V1.1 | 2.93ms   | 0.65ms       | 0.63ms       | 14.41ms   | 11.36ms      | 11.25ms      |
| Encoder         | 78.15ms  | 52.25ms      | 47.46ms      | 2897.27ms | 1620.88ms    | 1607.67ms    |

## Contributing
Under Apache License 2.0

--------------------------------------------------------------------------------
/tuner/dynamic_model_based_tuner.py:
--------------------------------------------------------------------------------
from .model_based_tuner import knob2point,point2knob,submodular_pick
from .tuner import Tuner
from ..util import sample_ints
from ..env import GLOBAL_SCOPE  # used by load_history below

import numpy as np


class ModelBasedTunerAda(Tuner):
    """Base class for model based tuner
    This type of tuner will fit a
    cost model and use an optimizer to
    find the maximums of the cost model as next trials

    Parameters
    ----------
    task: autotvm.task.Task
        The tuning task
    cost_model: CostModel
        The cost model that predicts the speed of a config (IR)
    model_optimizer:
        The optimizer to find local optimum points of cost model in tuning search space
    plan_size: int
        Tuner will re-fit model per `plan_size` new measure samples
    diversity_filter_ratio: int or float, optional
        If is not None, the tuner will first select
        top-(plan_size * diversity_filter_ratio) candidates according to the cost model
        and then pick plan_size of them according to the diversity metric.
    dynamic_ep: bool, optional
        If True, the random-exploration ratio `balance_ep` is re-estimated after
        every model fit from the cost model's uncertainty; otherwise it stays at 0.05.
    """

    def __init__(self, task, cost_model, model_optimizer, plan_size, diversity_filter_ratio=None, dynamic_ep=True):
        super(ModelBasedTunerAda, self).__init__(task)

        # space
        self.task = task
        self.target = task.target
        self.plan_size = plan_size
        self.space = task.config_space
        self.space_len = len(task.config_space)
        self.dims = [len(x) for x in self.space.space_map.values()]

        self.cost_model = cost_model
        self.model_optimizer = model_optimizer
        self.diversity_filter_ratio = diversity_filter_ratio

        if self.diversity_filter_ratio:
            assert self.diversity_filter_ratio >= 1, "Diversity filter ratio " \
                                                     "must be larger than one"

        # trial plan
        self.trials = []
        self.trial_pt = 0
        self.visited = set()

        # observed samples
        self.xs = []
        self.ys = []
        self.flops_max = 0.0
        self.train_ct = 0
        self.dynamic_ep = dynamic_ep
        self.balance_ep = 0.05

    def next_batch(self, batch_size):
        ret = []

        counter = 0
        while counter < batch_size:
            if len(self.visited) >= len(self.space):
                break

            while self.trial_pt < len(self.trials):
                index = self.trials[self.trial_pt]
                if index not in self.visited:
                    break
                self.trial_pt += 1

            if self.trial_pt >= len(self.trials) - int(self.balance_ep * self.plan_size):
                # if the trial list is empty or the tuner is doing the last
                # `balance_ep` fraction of trials (e-greedy, 5% by default), choose randomly
                index = np.random.randint(len(self.space))
                while index in self.visited:
                    index = np.random.randint(len(self.space))

            ret.append(self.space.get(index))
            self.visited.add(index)

            counter += 1
        return ret

    def update(self, inputs, results):
        for inp, res in zip(inputs, results):
            index = inp.config.index
            if res.error_no == 0:
                self.xs.append(index)
                flops = inp.task.flop / np.mean(res.costs)
                self.flops_max = max(self.flops_max, flops)
                self.ys.append(flops)
                # self.count_not0 = self.count_not0 + 1
            else:
                self.xs.append(index)
                self.ys.append(0.0)

        # if we have enough new training samples
        if len(self.xs) >= self.plan_size * (self.train_ct + 1) \
                and self.flops_max > 1e-6:
            self.cost_model.fit(self.xs, self.ys, self.plan_size)
            if self.diversity_filter_ratio:
                candidate = self.model_optimizer.find_maximums(
                    self.cost_model, self.plan_size * self.diversity_filter_ratio, self.visited)
                scores = self.cost_model.predict(candidate)
                knobs = [point2knob(x, self.dims) for x in candidate]
                pick_index = submodular_pick(0 * scores, knobs, self.plan_size, knob_weight=1)
                maximums = np.array(candidate)[pick_index]
            else:
                maximums = self.model_optimizer.find_maximums(
                    self.cost_model, self.plan_size, self.visited)
            if self.dynamic_ep:
                samples = np.array(sample_ints(0, len(self.space), 20))
                _, mean_of_variance = self.cost_model._expected_imporvement(samples)

                self.balance_ep = mean_of_variance/self.best_flops

            self.trials = maximums
            self.trial_pt = 0
            self.train_ct += 1

    def load_history(self, data_set):
        # set in_tuning as True to make the feature extraction consistent
        GLOBAL_SCOPE.in_tuning = True

        # fit base model
        base_model = self.cost_model.spawn_base_model()
        success = base_model.fit_log(data_set, self.plan_size)

        if not success:
            GLOBAL_SCOPE.in_tuning = False
            return

        # use base model to select initial points
        if not self.trials:
            # no plan yet, use base model to select initial trials
            maximums = self.model_optimizer.find_maximums(base_model, self.plan_size, self.visited)
            self.trials = maximums
            self.trial_pt = 0

        self.cost_model.load_basemodel(base_model)
        GLOBAL_SCOPE.in_tuning = False

    def has_next(self):
        return len(self.visited) < len(self.space)

--------------------------------------------------------------------------------
/tune.py:
--------------------------------------------------------------------------------
import os
import numpy as np

import tvm
from tvm import te
from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.tuner.RF_cost_model import RFTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime
import argparse


def str2bool(value):
    """Parse True/False flags; argparse's type=bool treats any non-empty string as True."""
    if value.lower() in ('true', '1', 'yes'):
        return True
    if value.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError("expected True or False, got " + value)

# Args
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_name', metavar='model_name', type = str, default='resnet-18', help="model_name")
parser.add_argument('-t', '--use_gpu', metavar='use_gpu', type = str2bool, default=False, help="use_gpu")
parser.add_argument('-w', '--tuner', metavar='tuner', type = str, default='ada', help="tuner")
parser.add_argument('-o', '--ops', dest='ops', nargs='*', default=['conv2d'])
args = parser.parse_args()

model_name = args.model_name
use_gpu = args.use_gpu
tuner = args.tuner
ops = ()
if 'conv2d' in args.ops:
    ops += (relay.op.get("nn.conv2d"),)
if 'dense' in args.ops:
    ops += (relay.op.get("nn.dense"),)

# ops = tuple(ops)


#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can either load some pre-defined network from :code:`relay.testing`
# or building :any:`relay.testing.resnet` with relay.
# We can also load models from MXNet, ONNX and TensorFlow.

def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif name == 'mobilenet':
        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'squeezenet_v1.1':
        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
    elif name == 'inception_v3':
        input_shape = (1, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'mxnet':
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
        mod = tvm.IRModule.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape


# Replace "llvm" with the correct target of your CPU.
# For example, for AWS EC2 c5 instance with Intel Xeon
# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512".
# For AWS EC2 c4 instance with Intel Xeon E5-2666 v3, it should be
# "llvm -mcpu=core-avx2".
if use_gpu:
    target = "cuda"
    ctx = tvm.gpu()
else:
    target = "llvm"
    ctx = tvm.cpu()
batch_size = 1
dtype = "float32"
log_file = "%s_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "data"

# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
# num_threads = 1
# os.environ["TVM_NUM_THREADS"] = str(num_threads)


#################################################################
# Configure tensor tuning settings and create tasks
# -------------------------------------------------
# To get better kernel execution performance on x86 CPU,
# we need to change data layout of convolution kernel from
# "NCHW" to "NCHWc". To deal with this situation, we define
# conv2d_NCHWc operator in topi. We will tune this operator
# instead of plain conv2d.
#
# We will use local mode for tuning configuration. RPC tracker
# mode can be setup similarly to the approach in
# :ref:`tune_relay_arm` tutorial.

tuning_option = {
    'log_filename': log_file,
    'tuner': tuner,
    'early_stopping': None,
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=500, repeat=1, timeout=100),
    ),
}


# You can skip the implementation of this function.
def tune_kernels(tasks,
                 measure_option,
                 tuner='ada',
                 early_stopping=None,
                 log_filename='tuning.log'):

    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(task, loss_type='rank',plan_size=32)
        elif tuner == 'ga':
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(task)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(task)
        elif tuner == "ada":
            tuner_obj = RFTuner(task,feature_type="itervar",plan_size=32, dynamic_ep=True)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # do tuning
        n_trial=len(task.config_space)
        # n_trial=6
        tuner_obj.tune(n_trial=n_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])

########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params,
                                              # ops=(relay.op.get("nn.conv2d"),
                                              # relay.op.get("nn.dense")))
                                              ops = ops)
    # run tuning tasks
    tune_kernels(tasks, **tuning_opt)

    # compile kernels with graph-level best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

        # upload parameters to device
        if use_gpu:
            ctx = tvm.gpu()
        else:
            ctx = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input(input_name, data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=500, repeat=1)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

# Tuning takes a long time (see the sample output below).
# Comment out the following line to skip it.

tune_and_evaluate(tuning_option)

######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# One sample output is listed below.
#
# .. code-block:: bash
#
#    Extract tasks...
#    Tuning...
#    [Task  1/12]  Current/Best:  598.05/2497.63 GFLOPS | Progress: (252/252) | 1357.95 s Done.
#    [Task  2/12]  Current/Best:  522.63/2279.24 GFLOPS | Progress: (784/784) | 3989.60 s Done.
#    [Task  3/12]  Current/Best:  447.33/1927.69 GFLOPS | Progress: (784/784) | 3869.14 s Done.
#    [Task  4/12]  Current/Best:  481.11/1912.34 GFLOPS | Progress: (672/672) | 3274.25 s Done.
#    [Task  5/12]  Current/Best:  414.09/1598.45 GFLOPS | Progress: (672/672) | 2720.78 s Done.
#    [Task  6/12]  Current/Best:  508.96/2273.20 GFLOPS | Progress: (768/768) | 3718.75 s Done.
#    [Task  7/12]  Current/Best:  469.14/1955.79 GFLOPS | Progress: (576/576) | 2665.67 s Done.
#    [Task  8/12]  Current/Best:  230.91/1658.97 GFLOPS | Progress: (576/576) | 2435.01 s Done.
#    [Task  9/12]  Current/Best:  487.75/2295.19 GFLOPS | Progress: (648/648) | 3009.95 s Done.
#    [Task 10/12]  Current/Best:  182.33/1734.45 GFLOPS | Progress: (360/360) | 1755.06 s Done.
#    [Task 11/12]  Current/Best:  372.18/1745.15 GFLOPS | Progress: (360/360) | 1684.50 s Done.
#    [Task 12/12]  Current/Best:  215.34/2271.11 GFLOPS | Progress: (400/400) | 2128.74 s Done.
#    Compile...
#    Evaluate inference time cost...
#    Mean inference time (std dev): 3.16 ms (0.03 ms)

--------------------------------------------------------------------------------
/tuner/RF_cost_model.py:
--------------------------------------------------------------------------------
import multiprocessing
import logging
import time
import numpy as np

from .model_based_tuner import CostModel, FeatureCache
from sklearn.ensemble import RandomForestRegressor
from scipy.stats import norm
from .. import feature

logger = logging.getLogger('autotvm')

class RFModel(CostModel):
    def __init__(self, task, fea_type="itervar",num_threads=None, log_interval=25, upper_model=None):
        super().__init__()
        self.task = task
        self.target = task.target
        self.space = task.config_space

        self.prior = RandomForestRegressor(n_estimators=10, random_state=2, max_features=10)
        self.fea_type = fea_type
        self.num_threads = num_threads
        self.log_interval = log_interval
        if fea_type == 'itervar':
            self.feature_extract_func = _extract_itervar_feature_index
        elif fea_type == 'knob':
            self.feature_extract_func = _extract_knob_feature_index
        elif fea_type == "simpleknob":
            self.feature_extract_func = _extract_simpleknob_feature_index
        elif fea_type == 'curve':
            self.feature_extract_func = _extract_curve_feature_index
        else:
            raise RuntimeError("Invalid feature type " + fea_type)

        # self.feature_cache = FeatureCache()
        self.best_flops = 0.0
        if upper_model:  # share a same feature cache with upper model
            self.feature_cache = upper_model.feature_cache
        else:
            self.feature_cache = FeatureCache()
        self.upper_model = upper_model
        self.pool = None
        self._reset_pool(self.space, self.target, self.task)


    def _reset_pool(self, space, target, task):
        """reset processing pool for feature extraction"""

        if self.upper_model:  # base model will reuse upper model's pool,
            self.upper_model._reset_pool(space, target, task)
            return

        self._close_pool()

        # use global variable to pass common arguments
        global _extract_space, _extract_target, _extract_task
        _extract_space = space
        _extract_target = target
        _extract_task = task
        self.pool = multiprocessing.Pool(self.num_threads)

    def _close_pool(self):
        if self.pool:
            self.pool.terminate()
            self.pool.join()
            self.pool = None

    def _get_pool(self):
        if self.upper_model:
            return self.upper_model._get_pool()
        return self.pool

    def _expected_imporvement(self, x_to_predict):
        feas = self._get_feature(x_to_predict)
        preds = np.array([tree.predict(feas) for tree in self.prior]).T
        eis = []
        variances = []
        for pred in preds:
            mu = np.mean(pred)
            sigma = pred.std()
            # print("mu: %f, sigma: %f" % (mu, sigma))
            best_flops = self.best_flops
            variances.append(sigma)
            if sigma == 0.0:
                # all trees agree, so the improvement is deterministic
                ei = max(0.0, mu - best_flops)
            else:
                Z = (mu - best_flops) / sigma
                ei = (mu - best_flops) * norm.cdf(Z) + sigma * norm.pdf(Z)
            eis.append(ei)
        # print("return eis: " + str(eis))
        mean_of_variance = sum(variances)/len(variances)
        return np.array(eis), mean_of_variance

    def _get_feature(self, indexes):
        """get features for indexes, run extraction if we do not have cache for them"""
        # free feature cache
        if self.feature_cache.size(self.fea_type) >= 100000:
            self.feature_cache.clear(self.fea_type)

        fea_cache = self.feature_cache.get(self.fea_type)

        indexes = np.array(indexes)
        need_extract = [x for x in indexes if x not in fea_cache]

        if need_extract:
            pool = self._get_pool()
            feas = pool.map(self.feature_extract_func, need_extract)
            for i, fea in zip(need_extract, feas):
                fea_cache[i] = fea

        feature_len = None
        for idx in indexes:
            if fea_cache[idx] is not None:
                feature_len = fea_cache[idx].shape[-1]
                break

        ret = np.empty((len(indexes), feature_len), dtype=np.float32)
        for i, ii in enumerate(indexes):
            t = fea_cache[ii]
            ret[i, :] = t if t is not None else 0
        return ret

    def fit(self, xs, ys, plan_size):
        """Fit to training data

        Parameters
        ----------
        xs: Array of int
            indexes of configs in the config space
        ys: Array of float
            The speed (flop, float number operations per second)
        plan_size: int
            The plan size of tuner
        """
        # here, xs is a list of config_index
        # transfer into corresponding x_list of fea_type
        x_list = self._get_feature(xs)
        self.best_flops = max(ys)
        # print(self.best_flops)
        self.prior.fit(x_list, ys)

    def fit_log(self, records, plan_size):
        """Fit training data from log.

        Parameters
        ----------
        records: Array of Tuple(MeasureInput, MeasureResult)
            The tuning records
        plan_size: int
            The plan size of tuner
        """
        raise NotImplementedError()

    def predict(self, xs, output_margin=False):
        """Predict the speed of configs

        Parameters
        ----------
        xs: Array of int
            The indexes of configs to predict
        output_margin: bool, optional
            Whether output the untransformed margin.
            When a model is used as base model, it should output untransformed margin

        Returns
        -------
        preds: Array of float
            The prediction
        """
        predicts, variance = self._expected_imporvement(xs)
        return predicts

    def load_basemodel(self, base_model):
        self.base_model = base_model
        self.base_model._close_pool()
        self.base_model.upper_model = self

    def spawn_base_model(self):
        # RFModel has no loss_type; pass only the arguments its __init__ accepts
        return RFModel(self.task, self.fea_type,
                       self.num_threads, self.log_interval, self)

    def __del__(self):
        self._close_pool()

_extract_space = None
_extract_target = None
_extract_task = None

def _extract_itervar_feature_index(index):
    """extract iteration var feature for an index in extract_space"""
    try:
        config = _extract_space.get(index)
        with _extract_target:
            sch, args = _extract_task.instantiate(config)
        fea = feature.get_itervar_feature_flatten(sch, args,
                                                  take_log=True)
        fea = np.concatenate((fea, list(config.get_other_option().values())))
        return fea
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_itervar_feature_log(arg):
    """extract iteration var feature for log items"""
    try:
        inp, res = arg
        config = inp.config
        with inp.target:
            sch, args = inp.task.instantiate(config)
        fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
        x = np.concatenate((fea, list(config.get_other_option().values())))

        if res.error_no == 0:
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_knob_feature_index(index):
    """extract knob feature for an index in extract_space"""
    try:
        config = _extract_space.get(index)
        return config.get_flatten_feature()
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_knob_feature_log(arg):
    """extract knob feature for log items"""
    try:
        inp, res = arg
        config = inp.config
        x = config.get_flatten_feature()

        if res.error_no == 0:
            with inp.target:  # necessary, for calculating flops of this task
                inp.task.instantiate(config)
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        return None

from .model_based_tuner import knob2point, point2knob
def _extract_simpleknob_feature_index(index):
    """take the knob as feature to train the model"""

    try:
        # config = _extract_space.get(index)
        # return config.get_flatten_feature()
        dims = [len(x) for x in _extract_space.space_map.values()]
        knob = point2knob(index, dims)
        return np.array(knob)
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_simpleknob_feature_log(arg):
    """extract knob feature for log items"""
    try:
        inp, res = arg
        config = inp.config
        # x = config.get_flatten_feature()
        dims = [len(x) for x in _extract_space.space_map.values()]
        x = point2knob(inp.config.index, dims)
        x = np.array(x)

        if res.error_no == 0:
            with inp.target:  # necessary, for calculating flops of this task
                inp.task.instantiate(config)
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_curve_feature_index(index):
    """extract sampled curve feature for an index in extract_space"""
    try:
        config = _extract_space.get(index)
        with _extract_target:
            sch, args = _extract_task.instantiate(config)
        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
        fea = np.concatenate((fea, list(config.get_other_option().values())))
        return np.array(fea)
    except Exception:  # pylint: disable=broad-except
        return None

def _extract_curve_feature_log(arg):
    """extract sampled curve feature for log items"""
    try:
        inp, res = arg
        config = inp.config
        with inp.target:
            sch, args = inp.task.instantiate(config)
        fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
        x = np.concatenate((fea, list(config.get_other_option().values())))

        if res.error_no == 0:
            y = inp.task.flop / np.mean(res.costs)
        else:
            y = 0.0
        return x, y
    except Exception:  # pylint: disable=broad-except
        return None

from .model_based_tuner import ModelOptimizer
from .dynamic_model_based_tuner import ModelBasedTunerAda
from .sa_model_optimizer import SimulatedAnnealingOptimizer
from .local_model_optimizer import BestNeighborsOptimizer

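Editorial note, not part of RF_cost_model.py: the expected-improvement score that `RFModel._expected_imporvement` derives from the per-tree predictions can be illustrated in isolation. The following is a minimal sketch using only the standard library, assuming the per-tree predictions for one config are already available as a list. The helper name `expected_improvement` and the sample numbers are hypothetical, and the population standard deviation is used on the assumption that it matches `pred.std()`.

```python
import math
from statistics import mean, pstdev


def expected_improvement(tree_preds, best_flops):
    """Expected improvement over best_flops for a single config.

    tree_preds holds one predicted speed per tree (hypothetical input);
    their mean and spread stand in for a Gaussian predictive distribution.
    """
    mu = mean(tree_preds)
    sigma = pstdev(tree_preds)
    if sigma == 0.0:
        # All trees agree, so the improvement is deterministic.
        return max(0.0, mu - best_flops)
    z = (mu - best_flops) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard normal CDF
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # standard normal PDF
    return (mu - best_flops) * cdf + sigma * pdf


if __name__ == "__main__":
    # Trees agree and beat the incumbent: EI is just the gap.
    print(expected_improvement([10.0, 10.0, 10.0], best_flops=8.0))
    # Trees disagree around the incumbent: EI comes from the spread.
    print(expected_improvement([11.0, 13.0], best_flops=12.0))
```

A config whose trees disagree keeps a positive score even when its mean does not beat the incumbent, which is why the spread of the tree predictions can also serve as an exploration signal.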
class RFTuner(ModelBasedTunerAda):
    def __init__(self, task, plan_size=32,
                 feature_type='itervar', loss_type='rank', num_threads=None,
                 optimizer='sa', diversity_filter_ratio=None, log_interval=50, dynamic_ep=False):

        cost_model = RFModel(task, fea_type=feature_type)
        if optimizer == 'sa':
            optimizer = SimulatedAnnealingOptimizer(task, log_interval=log_interval, parallel_size=plan_size*2)
        else:
            assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
                                                          "a supported name string " \
                                                          "or a ModelOptimizer object."
        super(RFTuner, self).__init__(task, cost_model, optimizer,
                                      plan_size, diversity_filter_ratio, dynamic_ep)

    def tune(self, *args, **kwargs):  # pylint: disable=arguments-differ
        super(RFTuner, self).tune(*args, **kwargs)
        # manually close pool to avoid multiprocessing issues
        self.cost_model._close_pool()

--------------------------------------------------------------------------------
/measure/measure_methods.py:
--------------------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
"""
Functions that run on executor for measurement.

These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
"""

import logging
import shutil
import os
import threading
import time
from random import getrandbits
from collections import namedtuple
import tempfile

import numpy as np

import tvm._ffi
from tvm import nd, rpc as _rpc, target as _target
from tvm.tir import ir_pass
from tvm.error import TVMError
from tvm.target import build_config
from tvm.driver import build
from tvm.contrib import nvcc, ndk, tar

from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError

from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
from .local_executor import LocalExecutor

logger = logging.getLogger('autotvm')

class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
    """
    Stores all the necessary inputs for a measurement.

    Parameters
    ----------
    filename : str
        The filename of generated library
    arg_info : Tuple
        The shape and dtype information of tvm tensor arguments
    error : Exception
        The error happens during compilation.
    time_cost : float
        The time cost of building
    """

class LocalBuilder(Builder):
    """Run compilation on local machine

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    n_parallel: int
        The number of tasks run in parallel.
        "None" will use all cpu cores
    build_func: callable or str
        If is 'default', use default build function
        If is 'ndk', use function for android ndk
        If is callable, use it as custom build function; it must have an output_format attribute.
    """
    def __init__(self, timeout=10, n_parallel=None, build_func='default'):
        super(LocalBuilder, self).__init__(timeout, n_parallel)

        if isinstance(build_func, str):
            if build_func == 'default':
                build_func = tar.tar
            elif build_func == 'ndk':
                build_func = ndk.create_shared
            else:
                raise ValueError("Invalid build_func: " + build_func)
        self.build_func = _wrap_build_func(build_func)
        self.executor = LocalExecutor(timeout=timeout)
        self.tmp_dir = tempfile.mkdtemp()

    def build(self, measure_inputs):
        results = []

        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        self.tmp_dir = tempfile.mkdtemp()

        for i in range(0, len(measure_inputs), self.n_parallel):
            # print("n_parallel" + str(self.n_parallel))
            futures = []
            for inp in measure_inputs[i:i + self.n_parallel]:
                ret = self.executor.submit(self.build_func,
                                           inp,
                                           self.tmp_dir,
                                           **self.build_kwargs)
                futures.append(ret)

            for future in futures:
                res = future.get()

                if isinstance(res, Exception):
                    # timeout or fleet error, return MeasureResult directly
                    results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
                                                 self.timeout, time.time()))
                elif res.error is not None:
                    # instantiation error
                    if isinstance(res.error, InstantiationError):
                        results.append(MeasureResult((res.error,),
                                                     MeasureErrorNo.INSTANTIATION_ERROR,
                                                     res.time_cost, time.time()))
                    else:
                        if "InstantiationError" in str(res.error):
                            msg = str(res.error)
                            try:
                                msg = msg.split('\n')[-2].split(": ")[1]
                            except Exception:  # pylint: disable=broad-except
                                pass
                            results.append(MeasureResult((InstantiationError(msg),),
                                                         MeasureErrorNo.INSTANTIATION_ERROR,
                                                         res.time_cost, time.time()))
                        else:  # tvm error
                            results.append(MeasureResult((res.error,),
                                                         MeasureErrorNo.COMPILE_HOST,
                                                         res.time_cost, time.time()))
                else:
                    # return BuildResult
                    results.append(res)

        return results


class RPCRunner(Runner):
    """Run generated code on remote devices.
    This function will ask a RPC Tracker to get device for measurement.

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    n_parallel: int
        The number of tasks run in parallel. "None" will use all cpu cores
    key: str
        The key of the device registered in the tracker
    host: str
        The host address of RPC Tracker
    port: int
        The port of RPC Tracker
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int, optional
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int, optional
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float, optional
        The cool down interval between two measurements.
    check_correctness: bool, optional
        Whether check correctness after measurement.
        This will use llvm cpu target to
        call your template and get the reference output.
        This can work for TOPI templates, but may not work for your custom template.
    """
    def __init__(self,
                 key, host, port, priority=1,
                 timeout=10, n_parallel=None,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        super(RPCRunner, self).__init__(timeout, n_parallel)

        self.key = key
        self.host = host
        self.port = port
        self.priority = priority
        self.timeout = timeout

        self.number = number
        self.repeat = repeat
        self.min_repeat_ms = min_repeat_ms

        self.ref_input = None
        self.ref_output = None
        self.check_correctness = check_correctness
        self.cooldown_interval = cooldown_interval

        self.executor = LocalExecutor()

    def set_task(self, task):
        self.task = task

        if check_remote(task.target, self.key, self.host, self.port):
            logger.info("Get devices for measurement successfully!")
        else:
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")

        if self.check_correctness:
            # use llvm cpu to generate a reference input/output
            # this option works for tuning topi, but might not work for your custom op
            with _target.create("llvm"):
                s, arg_bufs = task.instantiate(task.config_space.get(0))
            self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
                              for x in arg_bufs]
            func = build(s, arg_bufs, "llvm")
            tvm_buf = [nd.array(x) for x in self.ref_input]
            func(*tvm_buf)
            self.ref_output = [x.asnumpy() for x in tvm_buf]

    def get_build_kwargs(self):
        kwargs = {}
        if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys or \
           'rocm' in self.task.target.keys:
            remote = request_remote(self.key, self.host, self.port)
            ctx = remote.context(str(self.task.target), 0)
            max_dims = ctx.max_thread_dimensions
            kwargs['check_gpu'] = {
                'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
                'max_threads_per_block': ctx.max_threads_per_block,
                'max_thread_x': max_dims[0],
                'max_thread_y': max_dims[1],
                'max_thread_z': max_dims[2],
            }

            if 'cuda' in self.task.target.keys:
                kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))

        return kwargs

    def run(self, measure_inputs, build_results):
        results = []
        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)

        for i in range(0, len(measure_inputs), self.n_parallel):
            futures = []
            for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
                                              build_results[i:i+self.n_parallel]):
                ret = self.executor.submit(run_through_rpc,
                                           measure_inp,
                                           build_res,
                                           self.number,
                                           self.repeat,
                                           self.min_repeat_ms,
                                           self.cooldown_interval,
                                           remote_args,
                                           self.ref_input,
                                           self.ref_output)
                futures.append(ret)

            for future in futures:
                res = future.get()
                if isinstance(res, Exception):  # executor error or timeout
                    results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
                                                 self.timeout, time.time()))
                else:
                    results.append(res)

        return results

class LocalRunner(RPCRunner):
    """Run generated code on local devices.

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int, optional
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int, optional
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float, optional
        The cool down interval between two measurements.
    check_correctness: bool, optional
        Whether check correctness after measurement. This will use llvm cpu target to
        call your template and get the reference output.
        This can work for TOPI templates, but may not work for your custom template.

    Note
    ----
    This is a "fake" local mode.
    We start a silent rpc tracker and rpc server
    for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
    """
    def __init__(self,
                 timeout=10,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        super(LocalRunner, self).__init__('', None, None, 0,
                                          timeout=timeout, n_parallel=1,
                                          number=number, repeat=repeat,
                                          min_repeat_ms=min_repeat_ms,
                                          cooldown_interval=cooldown_interval,
                                          check_correctness=check_correctness)
        self.tracker = None
        self.server = None

    def set_task(self, task):
        # pylint: disable=import-outside-toplevel
        from ...rpc.tracker import Tracker
        from ...rpc.server import Server

        self.task = task
        tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True)
        device_key = '$local$device$%d' % tracker.port
        server = Server('0.0.0.0', port=9000, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))
        self.key = device_key
        self.host = tracker.host
        self.port = tracker.port

        super(LocalRunner, self).set_task(task)
        return server, tracker


def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
    """Common part for building a configuration"""
    target, task, config = measure_input
    with target:
        s, args = task.instantiate(config)

        # check invalidity of template and code hash consistency
        if not config.valid():
            raise InstantiationError(config.errors)

        opts = build_option or {}
        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
            opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
        if cuda_arch:
            set_cuda_target_arch(cuda_arch)

        # if target is vta, we need to use vta build
        if hasattr(measure_input.target, 'device_name') and \
            measure_input.target.device_name == 'vta':
            # pylint: disable=import-outside-toplevel
            import vta
            func = vta.build(s, args, target_host=task.target_host)
        else:
            with build_config(**opts):
                func = build(s, args, target_host=task.target_host)
    return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)


def _wrap_build_func(build_func):
    """
    Wrap build_func to a function that can be used in measure.

    Parameters
    ----------
    build_func : The compilation function
        We expect build_func to contain an attr "output_format"

    Returns
    -------
    wrapped_build_func : function
        The wrapped build function
    """
    if not hasattr(build_func, "output_format"):
        raise AttributeError("Expect build_func to have the attribute output_format.")
    output_format = build_func.output_format

    def _wrapped(measure_input, tmp_dir, **kwargs):
        """
        Wrapped build func.

        Parameters
        ----------
        measure_input: MeasureInput
            The input of measurement

        tmp_dir: str
            The path of temporary directory to export generated library
        """
        tic = time.time()
        try:
            filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
                getrandbits(64), output_format))
            # TODO(tvm-team) consider inline _build_func_common
            func, arg_info = _build_func_common(measure_input, **kwargs)
            func.export_library(filename, build_func)
        except Exception as e:  # pylint: disable=broad-except
            return BuildResult(None, None, e, time.time() - tic)
        return BuildResult(filename, arg_info, None, time.time() - tic)
    return _wrapped

# func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms
def adaptive_evaluator(epsilon, remote, build_result, measure_input, ref_input, number, repeat, min_repeat_ms):
    # print("####in adaptive evaluator###")
    func = remote.load_module(os.path.split(build_result.filename)[1])
    ctx = remote.context(str(measure_input.target), 0)
    flop = measure_input.task.flop
    # set input
    if ref_input:
        args = [nd.array(x, ctx=ctx) for x in ref_input]
    else:
        # create empty arrays on the remote device and copy them once.
        # This can avoid some memory issues that make the measurement results unreliable.
        args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
        args = [nd.array(x, ctx=ctx) for x in args]
        ctx.sync()
    # break the number*repeat into several batch
    # print("number=%d, repeat=%d" % (number, repeat))
    if repeat*number < 300:  # no need to do adaptive evaluator
        time_f = func.time_evaluator(
            func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
        eva_res = time_f(*args)
        costs = eva_res.results
    else:
        b_size = 50
        costs = []
        sum_num = 0
        max_iter = number*repeat
        pis = []
        bi = 1
        rep = 0
        flag = True
        while flag and sum_num 4: # remove the min and max to reduce variance
                pis_array.sort()
                pis_array = pis_array[1:-1]
            # calculate the coefficient of variation
            cv = pis_array.std()/pis_array.mean()
            if bi > 2 and cv < epsilon:
                # print("\nindex is %d, break at batch#%d, cv=%.10f, cost is %.8f." % (measure_input.config.index, bi, cv, b_mean))
                flag = False
            bi = bi + 1
    # also return the device arrays so the caller can check the output
    return costs, args

def run_through_rpc(measure_input, build_result,
                    number, repeat, min_repeat_ms, cooldown_interval,
                    remote_args, ref_input=None, ref_output=None):
    """Run a generated library through rpc

    Parameters
    ----------
    measure_input: MeasureInput
        The raw measure input
    build_result: BuildResult
        The result returned from Builder. This contains the path to the generated library.
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int, optional
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int, optional
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float
        The cool down interval between two measurements
    remote_args: Tuple
        The argument for request_remote
    ref_input: List of np.ndarray
        The reference input used for checking correctness
    ref_output: List of np.ndarray
        The reference output used for checking correctness
    """
    if isinstance(build_result, MeasureResult):
        return build_result

    tic = time.time()
    errno = MeasureErrorNo.NO_ERROR
    try:
        # upload built module
        remote = request_remote(*remote_args)
        # Program the FPGA every single time when targeting VTA
        if hasattr(measure_input.target, 'device_name') and \
            measure_input.target.device_name == 'vta':
            # pylint: disable=import-outside-toplevel
            from vta import program_fpga, reconfig_runtime
            program_fpga(remote, None)
            reconfig_runtime(remote)
        remote.upload(build_result.filename)
        epsilon = 0.1
        # args are the device arrays the kernel ran on; needed for the correctness check below
        costs, args = adaptive_evaluator(epsilon, remote, build_result, measure_input, ref_input, number, repeat, min_repeat_ms)

        # clean up remote files
        remote.remove(build_result.filename)
        remote.remove(os.path.splitext(build_result.filename)[0] + '.so')
        remote.remove('')

        if len(costs) > 2:  # remove largest and smallest value to reduce variance
            costs = list(costs)
            costs.sort()
            costs = tuple(costs[1:-1])

        # check correctness of output
        if ref_output:
            for expected, real in zip(ref_output, args):
                if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
                    logger.warning("Wrong Answer!")
                    errno = MeasureErrorNo.WRONG_ANSWER
    except TVMError as exc:
        msg = str(exc)
        if "Stack trace returned" in msg:
            msg = msg[:msg.index("Stack trace returned")]
        if "CUDA Source" in msg:
            msg = msg[:msg.index("CUDA Source")]
        costs = (RuntimeError(msg[:1024]),)
        errno = MeasureErrorNo.RUNTIME_DEVICE
    tstamp = time.time()
    time.sleep(cooldown_interval)
    return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)


def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
    """Request a remote session

    Parameters
    ----------
    device_key: string
        The device key of registered device in tracker
    host: host, optional
        The host address of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_HOST"
    port: int, optional
        The port of rpc tracker.
        If None, the environment variable "TVM_TRACKER_PORT" is used
    priority: int, optional
        The priority of this request; a larger value means higher priority
    timeout: float, optional
        The timeout of this session (units: seconds)

    Returns
    -------
    session: RPCSession
    """
    # connect to the tracker
    host = host or os.environ['TVM_TRACKER_HOST']
    port = port or int(os.environ['TVM_TRACKER_PORT'])

    tracker = _rpc.connect_tracker(host, port)
    remote = tracker.request(device_key, priority=priority,
                             session_timeout=timeout)
    return remote


def check_remote(target, device_key, host=None, port=None, priority=100, timeout=10):
    """
    Check the availability of a remote device

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    host: str, optional
        The host address of rpc tracker.
        If None, the environment variable "TVM_TRACKER_HOST" is used
    port: int, optional
        The port of rpc tracker.
        If None, the environment variable "TVM_TRACKER_PORT" is used
    priority: int, optional
        The priority of this request; a larger value means higher priority
    timeout: float, optional
        The timeout of this check (units: seconds).

    Returns
    -------
    available: bool
        True if an available device can be found
    """
    def _check():
        remote = request_remote(device_key, host, port, priority)
        ctx = remote.context(str(target))
        while not ctx.exist:  # wait until we get an available device
            pass
    t = threading.Thread(target=_check,)
    t.start()
    t.join(timeout)
    return not t.is_alive()


@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
    """use nvcc to generate ptx code for better optimization"""
    curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch
    # e.g., target arch could be [
    #   "-gencode", "arch=compute_52,code=sm_52",
    #   "-gencode", "arch=compute_70,code=sm_70"
    # ]
    target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx"
    ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch)
    return ptx


def set_cuda_target_arch(arch):
    """set target architecture of nvcc compiler

    Parameters
    ----------
    arch: str or list
        The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
        It can also be a list of gencode arguments passed to the nvcc command line,
        e.g., ["-gencode", "arch=compute_52,code=sm_52", "-gencode", "arch=compute_70,code=sm_70"]
    """
    AutotvmGlobalScope.current.cuda_target_arch = arch


def gpu_verify_pass(**kwargs):
    """Verify the validity of a gpu kernel.
    This pass will check memory usage and number of threads per block.
    """
    def verify_pass(stmt):
        valid = ir_pass.VerifyGPUCode(stmt, kwargs)
        if not valid:
            raise InstantiationError("Skipped because of invalid gpu kernel")
        return stmt
    return verify_pass

--------------------------------------------------------------------------------
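The early-stopping idea behind `adaptive_evaluator` in `measure/measure_methods.py` can be sketched without TVM: measure in batches of `b_size` runs, convert each batch mean into a throughput `flop / b_mean`, and stop once the coefficient of variation of those throughputs drops below `epsilon`. This is a minimal sketch under assumptions, not the repository's exact loop: `adaptive_costs` and `measure_batch` are hypothetical names, and `measure_batch(n)` stands in for a TVM time evaluator that returns the mean run time over `n` runs.

```python
import numpy as np


def adaptive_costs(measure_batch, flop, max_iter, b_size=50, epsilon=0.1):
    """Collect batch-mean costs until the throughput estimates stabilise.

    measure_batch(n) is a hypothetical callable returning the mean run time
    (in seconds) over n runs; it is an assumption made for this sketch.
    """
    costs, pis = [], []
    sum_num = 0   # total number of runs measured so far
    bi = 1        # 1-based batch index
    while sum_num < max_iter:
        b_mean = measure_batch(b_size)
        costs.append(b_mean)
        sum_num += b_size
        pis.append(flop / b_mean)              # throughput of this batch
        pis_array = np.array(pis)
        if len(pis_array) > 4:                 # trim min and max to reduce variance
            pis_array = np.sort(pis_array)[1:-1]
        cv = pis_array.std() / pis_array.mean()  # coefficient of variation
        if bi > 2 and cv < epsilon:            # require at least three batches
            break
        bi += 1
    return costs


if __name__ == "__main__":
    # A perfectly stable kernel stops after the third batch.
    print(len(adaptive_costs(lambda n: 1e-3, flop=1e6, max_iter=1000)))
```

The `bi > 2` guard keeps a single lucky batch from ending the measurement early, and `max_iter` bounds the total work at the original `number * repeat` budget when the readings never settle.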