├── .gitignore ├── LICENSE ├── README.md ├── dataloaders ├── dataloader.py ├── nyu.py └── transforms.py ├── deploy ├── data │ ├── depth.npy │ ├── depth.png │ ├── pred.npy │ ├── pred.png │ ├── rgb.npy │ ├── rgb.png │ └── visualize.py └── tx2_run_tvm.py ├── imagenet ├── __init__.py └── mobilenet.py ├── img ├── acc_fps_cpu.png ├── acc_fps_gpu.png └── visualization.png ├── main.py ├── metrics.py ├── models.py ├── tvm_compile └── tuning │ ├── tx2-cpu.mobilenet-nnconv5.trials=1000.stop=250.log │ ├── tx2-cpu.mobilenet-nnconv5dw-skipadd-pruned.trials=1000.stop=250.log │ ├── tx2-cpu.mobilenet-nnconv5dw-skipadd.trials=1000.stop=250.log │ ├── tx2-cpu.mobilenet-nnconv5dw.trials=1000.stop=250.log │ ├── tx2-gpu.mobilenet-nnconv5.trials=2000.stop=600.log │ ├── tx2-gpu.mobilenet-nnconv5dw-skipadd-pruned.trials=2000.stop=600.log │ ├── tx2-gpu.mobilenet-nnconv5dw-skipadd.trials=2000.stop=600.log │ └── tx2-gpu.mobilenet-nnconv5dw.trials=2000.stop=600.log └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Diana Wofk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | FastDepth 2 | ============================ 3 | 4 | This repo offers trained models and evaluation code for the [FastDepth](http://fastdepth.mit.edu/) project at MIT. 5 | 6 |

7 | photo not available 8 |

9 | 10 | ## Contents 11 | 0. [Requirements](#requirements) 12 | 0. [Trained Models](#trained-models) 13 | 0. [Evaluation](#evaluation) 14 | 0. [Deployment](#deployment) 15 | 0. [Results](#results) 16 | 0. [Citation](#citation) 17 | 18 | ## Requirements 19 | - Install [PyTorch](https://pytorch.org/) on a machine with a CUDA GPU. Our code was developed on a system running PyTorch v0.4.1. 20 | - Install the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) format libraries. Files in our pre-processed datasets are in HDF5 format. 21 | ```bash 22 | sudo apt-get update 23 | sudo apt-get install -y libhdf5-serial-dev hdf5-tools 24 | pip3 install h5py matplotlib imageio scikit-image opencv-python 25 | ``` 26 | - Download the preprocessed [NYU Depth V2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) dataset in HDF5 format and place it under a `data` folder outside the repo directory. The NYU dataset requires 32G of storage space. 27 | ```bash 28 | mkdir data; cd data 29 | wget http://datasets.lids.mit.edu/fastdepth/data/nyudepthv2.tar.gz 30 | tar -xvf nyudepthv2.tar.gz && rm -f nyudepthv2.tar.gz 31 | cd .. 32 | ``` 33 | 34 | ## Trained Models ## 35 | The following trained models can be found at [http://datasets.lids.mit.edu/fastdepth/results/](http://datasets.lids.mit.edu/fastdepth/results/). 36 | - MobileNet-NNConv5 37 | - MobileNet-NNConv5(depthwise) 38 | - MobileNet-NNConv5(depthwise), with additive skip connections 39 | - **MobileNet-NNConv5(depthwise), with additive skip connections, pruned** 40 | 41 | Our final model is `mobilenet-nnconv5dw-skipadd-pruned`, i.e. a MobileNet-NNConv5 architecture with depthwise separable layers in the decoder, with additive skip connections between the encoder and decoder, and after network pruning using [NetAdapt](https://github.com/denru01/netadapt/). The other models are offered to provide insight into our approach.
42 | 43 | When downloading, save models to a `results` folder outside the repo directory: 44 | ```bash 45 | mkdir results; cd results 46 | wget -r -np -nH --cut-dirs=2 --reject "index.html*" http://datasets.lids.mit.edu/fastdepth/results/ 47 | cd .. 48 | ``` 49 | ### Pretrained MobileNet ### 50 | 51 | The model file for the pretrained MobileNet used in our model definition can be downloaded from [http://datasets.lids.mit.edu/fastdepth/imagenet/](http://datasets.lids.mit.edu/fastdepth/imagenet/). 52 | 53 | ## Evaluation ## 54 | 55 | This step requires a valid PyTorch installation and a saved copy of the NYU Depth v2 dataset. It is meant to be performed on a host machine with a CUDA GPU, not on an embedded platform. Deployment on an embedded device is discussed in the [next section](#deployment). 56 | 57 | To evaluate a model, navigate to the repo directory and run: 58 | 59 | ```bash 60 | python3 main.py --evaluate [path_to_trained_model] 61 | ``` 62 | 63 | The evaluation code will report model accuracy in terms of the delta1 metric as well as RMSE in millimeters. 64 | 65 | Note: This evaluation code was sourced and modified from [here](https://github.com/fangchangma/sparse-to-dense.pytorch). 66 | 67 | ## Deployment ## 68 | 69 | We use the [TVM compiler stack](https://tvm.ai/) to compile trained models for **deployment on an NVIDIA Jetson TX2**. Models are cross-compiled on a host machine and then deployed on the TX2. The `tvm_compile/tuning` folder in this repo contains the results of [auto-tuning](https://docs.tvm.ai/tutorials/index.html#auto-tuning) the layers within our models for both the TX2 GPU and CPU. These can be used during the compilation process to achieve low model runtimes on the TX2. Outputs of TVM compilation for our trained models can be found at [http://datasets.lids.mit.edu/fastdepth/results/tvm_compiled/](http://datasets.lids.mit.edu/fastdepth/results/tvm_compiled/).
70 | 71 | On the TX2, download the trained models as explained above in the section [Trained Models](#trained-models). The compiled model files should be located in `results/tvm_compiled`. 72 | 73 | ### Installing the TVM Runtime ### 74 | 75 | Deployment requires building the TVM runtime code on the target embedded device (which will be used solely for running a trained and compiled model). The following instructions are taken from [this TVM tutorial](https://docs.tvm.ai/tutorials/cross_compilation_and_rpc.html#build-tvm-runtime-on-device) and have been tested on a **TX2 with CUDA-8.0 and LLVM-4.0 installed**. 76 | 77 | First, clone the TVM repo and modify the config file: 78 | ```bash 79 | git clone --recursive https://github.com/dmlc/tvm 80 | cd tvm 81 | git reset --hard ab4946c8b80da510a5a518dca066d8159473345f 82 | git submodule update --init 83 | cp cmake/config.cmake . 84 | ``` 85 | Make the following edits to the `config.cmake` file: 86 | ```cmake 87 | set(USE_CUDA OFF) -> set(USE_CUDA [path_to_cuda]) # e.g. /usr/local/cuda-8.0/ 88 | set(USE_LLVM OFF) -> set(USE_LLVM [path_to_llvm-config]) # e.g. /usr/lib/llvm-4.0/bin/llvm-config 89 | ``` 90 | 91 | Then build the runtime: 92 | ```bash 93 | make runtime -j2 94 | ``` 95 | Finally, update the `PYTHONPATH` environment variable: 96 | ```bash 97 | export PYTHONPATH=$PYTHONPATH:~/tvm/python 98 | ``` 99 | ### Running a Compiled Model ### 100 | 101 | To run a compiled model on the device, navigate to the `deploy` folder and run: 102 | 103 | ```bash 104 | python3 tx2_run_tvm.py --input-fp [path_to_input_npy_file] --output-fp [path_to_output_npy_file] --model-dir [path_to_folder_with_tvm_compiled_model_files] 105 | ``` 106 | 107 | Note that when running a model compiled for the GPU, a `cuda` argument must be specified.
For instance: 108 | 109 | ```bash 110 | python3 tx2_run_tvm.py --input-fp data/rgb.npy --output-fp data/pred.npy --model-dir ../../results/tvm_compiled/tx2_cpu_mobilenet_nnconv5dw_skipadd_pruned/ 111 | python3 tx2_run_tvm.py --input-fp data/rgb.npy --output-fp data/pred.npy --model-dir ../../results/tvm_compiled/tx2_gpu_mobilenet_nnconv5dw_skipadd_pruned/ --cuda True 112 | ``` 113 | 114 | Example RGB input, ground truth, and model prediction data (as numpy arrays) is provided in the `data` folder. To convert the `.npy` files into `.png` format, navigate into `data` and run `python3 visualize.py`. 115 | 116 | ### Measuring Power Consumption ### 117 | 118 | On the TX2, power consumption on the main VDD_IN rail can be measured by running the following command: 119 | 120 | ```bash 121 | cat /sys/devices/3160000.i2c/i2c-0/0-0041/iio_device/in_power0_input 122 | ``` 123 | 124 | ## Results 125 | 126 | Comparison against prior work. Runtimes were measured on an NVIDIA Jetson TX2 in max-N mode. 127 | 128 | | on NYU Depth v2 | Input Size | MACs [G] | RMSE [m] | delta1 | CPU [ms] | GPU [ms] | 129 | |---------------------------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|:--:| 130 | | [Eigen et al. [NIPS 2014]](https://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) | 228×304 | 2.06 | 0.907 | 0.611 | 300 | 23 | 131 | | [Eigen et al. [ICCV 2015]](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Eigen_Predicting_Depth_Surface_ICCV_2015_paper.pdf) (with AlexNet) | 228×304 | 8.39 | 0.753 | 0.697 | 1400 | 96 | 132 | | [Eigen et al. [ICCV 2015]](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Eigen_Predicting_Depth_Surface_ICCV_2015_paper.pdf) (with VGG) | 228×304 | 23.4 | 0.641 | 0.769 | 2800 | 195 | 133 | | [Laina et al. [3DV 2016]](https://arxiv.org/pdf/1606.00373.pdf) (with UpConv) | 228×304 | 22.9 | 0.604 | 0.789 | 2400 | 237 | 134 | | [Laina et al. 
[3DV 2016]](https://arxiv.org/pdf/1606.00373.pdf) (with UpProj) | 228×304 | 42.7 | **0.573** | **0.811** | 3900 | 319 | 135 | | [Xian et al. [CVPR 2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xian_Monocular_Relative_Depth_CVPR_2018_paper.pdf) (with UpProj) | 384×384 | 61.8 | 0.660 | 0.781 | 4400 | 283 | 136 | | This Work | 224×224 | **0.37** | 0.604 | 0.771 | **37** | **5.6** | 137 | 138 | "This Work" refers to MobileNet-NNConv5(depthwise), with additive skip connections, pruned. 139 | 140 |

141 | photo not available 142 | photo not available 143 |

144 | 145 | ## Citation 146 | If you reference our work, please consider citing the following: 147 | 148 | @inproceedings{icra_2019_fastdepth, 149 | author = {{Wofk, Diana and Ma, Fangchang and Yang, Tien-Ju and Karaman, Sertac and Sze, Vivienne}}, 150 | title = {{FastDepth: Fast Monocular Depth Estimation on Embedded Systems}}, 151 | booktitle = {{IEEE International Conference on Robotics and Automation (ICRA)}}, 152 | year = {{2019}} 153 | } 154 | -------------------------------------------------------------------------------- /dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch.utils.data as data 5 | import h5py 6 | import dataloaders.transforms as transforms 7 | 8 | def h5_loader(path):  # Reads the 'rgb' and 'depth' datasets of one HDF5 file; rgb axes are permuted with (1, 2, 0) (presumably C,H,W -> H,W,C -- confirm) 9 | h5f = h5py.File(path, "r")  # NOTE(review): this handle is never closed; wrapping lines 9-12 in `with h5py.File(path, "r") as h5f:` would release it 10 | rgb = np.array(h5f['rgb']) 11 | rgb = np.transpose(rgb, (1, 2, 0)) 12 | depth = np.array(h5f['depth']) 13 | return rgb, depth 14 | 15 | # def rgb2grayscale(rgb): 16 | # return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114 17 | 18 | class MyDataloader(data.Dataset):  # Base dataset: indexes files accepted by is_image_file() under each subdirectory of root; subclasses provide train_transform/val_transform 19 | modality_names = ['rgb'] 20 | 21 | def is_image_file(self, filename):  # overridden by subclasses (e.g. NYUDataset) to filter files per split 22 | IMG_EXTENSIONS = ['.h5'] 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | def find_classes(self, dir):  # Returns the sorted names of dir's immediate subdirectories and a name -> index map. NOTE(review): `dir` shadows the builtin 26 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 27 | classes.sort() 28 | class_to_idx = {classes[i]: i for i in range(len(classes))} 29 | return classes, class_to_idx 30 | 31 | def make_dataset(self, dir, class_to_idx):  # Walks each class subdirectory recursively and returns sorted (path, class_index) pairs 32 | images = [] 33 | dir = os.path.expanduser(dir) 34 | for target in sorted(os.listdir(dir)): 35 | d = os.path.join(dir, target) 36 | if not os.path.isdir(d): 37 | continue 38 | for root, _, fnames in sorted(os.walk(d)): 39 | for fname in sorted(fnames): 40 | if self.is_image_file(fname): 41 | path = os.path.join(root, fname) 42 | item = (path, class_to_idx[target]) 43 |
images.append(item) 44 | return images 45 | 46 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)  # class attribute shared by all instances: brightness/contrast/saturation 0.4, hue left at its default 0 47 | 48 | def __init__(self, root, split, modality='rgb', loader=h5_loader):  # Indexes files under root; 'train' uses train_transform, 'holdout' and 'val' use val_transform 49 | classes, class_to_idx = self.find_classes(root) 50 | imgs = self.make_dataset(root, class_to_idx) 51 | assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"  # NOTE(review): assert is stripped under `python -O`; an explicit raise would keep this check 52 | # print("Found {} images in {} folder.".format(len(imgs), split)) 53 | self.root = root 54 | self.imgs = imgs 55 | self.classes = classes 56 | self.class_to_idx = class_to_idx 57 | if split == 'train': 58 | self.transform = self.train_transform 59 | elif split == 'holdout': 60 | self.transform = self.val_transform 61 | elif split == 'val': 62 | self.transform = self.val_transform 63 | else: 64 | raise (RuntimeError("Invalid dataset split: " + split + "\n" 65 | "Supported dataset splits are: train, val"))  # NOTE(review): this message omits 'holdout', which is accepted above 66 | self.loader = loader 67 | 68 | assert (modality in self.modality_names), "Invalid modality split: " + modality + "\n" + \ 69 | "Supported dataset splits are: " + ''.join(self.modality_names)  # NOTE(review): this reports modalities, not splits, and ''.join() puts no separator between names 70 | self.modality = modality 71 | 72 | def train_transform(self, rgb, depth): 73 | raise (RuntimeError("train_transform() is not implemented. ")) 74 | 75 | def val_transform(rgb, depth):  # NOTE(review): missing `self`; if a subclass does not override this, self.val_transform(rgb, depth) raises TypeError (3 positional args for 2) instead of the RuntimeError below -- should be `def val_transform(self, rgb, depth):` 76 | raise (RuntimeError("val_transform() is not implemented.")) 77 | 78 | def __getraw__(self, index): 79 | """ 80 | Args: 81 | index (int): Index 82 | 83 | Returns: 84 | tuple: (rgb, depth) the raw data.
85 | """ 86 | path, target = self.imgs[index]  # NOTE(review): target (the class index) is unused 87 | rgb, depth = self.loader(path) 88 | return rgb, depth 89 | 90 | def __getitem__(self, index):  # Loads sample `index`, applies self.transform, and returns (input_tensor, depth_tensor); depth gets a leading axis via unsqueeze(0) 91 | rgb, depth = self.__getraw__(index) 92 | if self.transform is not None: 93 | rgb_np, depth_np = self.transform(rgb, depth) 94 | else: 95 | raise(RuntimeError("transform not defined")) 96 | 97 | # color normalization 98 | # rgb_tensor = normalize_rgb(rgb_tensor) 99 | # rgb_np = normalize_np(rgb_np) 100 | 101 | if self.modality == 'rgb': 102 | input_np = rgb_np  # NOTE(review): input_np is bound only for 'rgb'; this is safe only because __init__ asserts modality is in modality_names (['rgb']) 103 | 104 | to_tensor = transforms.ToTensor() 105 | input_tensor = to_tensor(input_np) 106 | while input_tensor.dim() < 3: 107 | input_tensor = input_tensor.unsqueeze(0) 108 | depth_tensor = to_tensor(depth_np) 109 | depth_tensor = depth_tensor.unsqueeze(0) 110 | 111 | return input_tensor, depth_tensor 112 | 113 | def __len__(self): 114 | return len(self.imgs) 115 | -------------------------------------------------------------------------------- /dataloaders/nyu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from dataloaders.dataloader import MyDataloader 4 | 5 | iheight, iwidth = 480, 640 # raw image size 6 | 7 | class NYUDataset(MyDataloader): 8 | def __init__(self, root, split, modality='rgb'): 9 | self.split = split  # set before super().__init__(), whose make_dataset() calls is_image_file(), which reads self.split 10 | super(NYUDataset, self).__init__(root, split, modality) 11 | self.output_size = (224, 224) 12 | 13 | def is_image_file(self, filename):  # Per-split filter: 'train' excludes names containing '00001.h5' or '00201.h5', which make up 'holdout'; 'val' takes every .h5 file 14 | # IMG_EXTENSIONS = ['.h5'] 15 | if self.split == 'train': 16 | return (filename.endswith('.h5') and \ 17 | '00001.h5' not in filename and '00201.h5' not in filename) 18 | elif self.split == 'holdout': 19 | return ('00001.h5' in filename or '00201.h5' in filename) 20 | elif self.split == 'val': 21 | return (filename.endswith('.h5')) 22 | else: 23 | raise (RuntimeError("Invalid dataset split: " + split + "\n" 24 | "Supported dataset splits are: train, val"))  # NOTE(review): `split` is undefined in this method (only `filename` is in scope), so an invalid split raises NameError here -- before MyDataloader.__init__ gets to validate it; should use self.split 25 | 26 | def train_transform(self,
rgb, depth): 27 | s = np.random.uniform(1.0, 1.5) # random scaling 28 | depth_np = depth / s  # depth values are divided by s while the image is resized by s below (Resize(s)) 29 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 30 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 31 | 32 | # perform 1st step of data augmentation 33 | transform = transforms.Compose([ 34 | transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow 35 | transforms.Rotate(angle), 36 | transforms.Resize(s), 37 | transforms.CenterCrop((228, 304)), 38 | transforms.HorizontalFlip(do_flip), 39 | transforms.Resize(self.output_size), 40 | ]) 41 | rgb_np = transform(rgb) 42 | rgb_np = self.color_jitter(rgb_np) # random color jittering 43 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255  # NOTE(review): np.asfarray was removed in NumPy 2.0; np.asarray(rgb_np, dtype=float) is the drop-in equivalent 44 | depth_np = transform(depth_np)  # the same Compose (same angle, scale and flip) is applied so depth stays aligned with rgb 45 | 46 | return rgb_np, depth_np 47 | 48 | def val_transform(self, rgb, depth): 49 | depth_np = depth 50 | transform = transforms.Compose([ 51 | transforms.Resize(250.0 / iheight), 52 | transforms.CenterCrop((228, 304)), 53 | transforms.Resize(self.output_size), 54 | ]) 55 | rgb_np = transform(rgb) 56 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255  # NOTE(review): np.asfarray was removed in NumPy 2.0; np.asarray(rgb_np, dtype=float) is the drop-in equivalent 57 | depth_np = transform(depth_np) 58 | 59 | return rgb_np, depth_np 60 | -------------------------------------------------------------------------------- /dataloaders/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | 6 | from PIL import Image, ImageOps, ImageEnhance 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | import scipy.ndimage.interpolation as itpl  # NOTE(review): the scipy.ndimage.interpolation namespace is deprecated (SciPy 1.8); scipy.ndimage.rotate is the supported path 19 | import scipy.misc as misc  # NOTE(review): Resize relies on scipy.misc.imresize, removed in SciPy 1.3.0, so this module requires an older SciPy 20 | 21 | 22 | def _is_numpy_image(img): 23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 24 | 25 | def _is_pil_image(img): 26 | if accimage is
not None: 27 | return isinstance(img, (Image.Image, accimage.Image)) 28 | else: 29 | return isinstance(img, Image.Image) 30 | 31 | def _is_tensor_image(img): 32 | return torch.is_tensor(img) and img.ndimension() == 3 33 | 34 | def adjust_brightness(img, brightness_factor): 35 | """Adjust brightness of an Image. 36 | 37 | Args: 38 | img (PIL Image): PIL Image to be adjusted. 39 | brightness_factor (float): How much to adjust the brightness. Can be 40 | any non negative number. 0 gives a black image, 1 gives the 41 | original image while 2 increases the brightness by a factor of 2. 42 | 43 | Returns: 44 | PIL Image: Brightness adjusted image. 45 | """ 46 | if not _is_pil_image(img): 47 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 48 | 49 | enhancer = ImageEnhance.Brightness(img) 50 | img = enhancer.enhance(brightness_factor) 51 | return img 52 | 53 | 54 | def adjust_contrast(img, contrast_factor): 55 | """Adjust contrast of an Image. 56 | 57 | Args: 58 | img (PIL Image): PIL Image to be adjusted. 59 | contrast_factor (float): How much to adjust the contrast. Can be any 60 | non negative number. 0 gives a solid gray image, 1 gives the 61 | original image while 2 increases the contrast by a factor of 2. 62 | 63 | Returns: 64 | PIL Image: Contrast adjusted image. 65 | """ 66 | if not _is_pil_image(img): 67 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 68 | 69 | enhancer = ImageEnhance.Contrast(img) 70 | img = enhancer.enhance(contrast_factor) 71 | return img 72 | 73 | 74 | def adjust_saturation(img, saturation_factor): 75 | """Adjust color saturation of an image. 76 | 77 | Args: 78 | img (PIL Image): PIL Image to be adjusted. 79 | saturation_factor (float): How much to adjust the saturation. 0 will 80 | give a black and white image, 1 will give the original image while 81 | 2 will enhance the saturation by a factor of 2. 82 | 83 | Returns: 84 | PIL Image: Saturation adjusted image. 
85 | """ 86 | if not _is_pil_image(img): 87 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 88 | 89 | enhancer = ImageEnhance.Color(img) 90 | img = enhancer.enhance(saturation_factor) 91 | return img 92 | 93 | 94 | def adjust_hue(img, hue_factor): 95 | """Adjust hue of an image. 96 | 97 | The image hue is adjusted by converting the image to HSV and 98 | cyclically shifting the intensities in the hue channel (H). 99 | The image is then converted back to original image mode. 100 | 101 | `hue_factor` is the amount of shift in H channel and must be in the 102 | interval `[-0.5, 0.5]`. 103 | 104 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 105 | 106 | Args: 107 | img (PIL Image): PIL Image to be adjusted. 108 | hue_factor (float): How much to shift the hue channel. Should be in 109 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 110 | HSV space in positive and negative direction respectively. 111 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 112 | with complementary colors while 0 gives the original image. 113 | 114 | Returns: 115 | PIL Image: Hue adjusted image. Images in mode 'L', '1', 'I' or 'F' are returned unchanged. 116 | """ 117 | if not(-0.5 <= hue_factor <= 0.5): 118 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))  # NOTE(review): the message has no {} placeholder, so the offending value is never shown; e.g. 'hue_factor ({}) is not in [-0.5, 0.5].' 119 | 120 | if not _is_pil_image(img): 121 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 122 | 123 | input_mode = img.mode 124 | if input_mode in {'L', '1', 'I', 'F'}: 125 | return img 126 | 127 | h, s, v = img.convert('HSV').split() 128 | 129 | np_h = np.array(h, dtype=np.uint8) 130 | # uint8 addition takes care of rotation across boundaries 131 | with np.errstate(over='ignore'): 132 | np_h += np.uint8(hue_factor * 255)  # NOTE(review): for hue_factor < 0 this casts a negative float to uint8, which NumPy leaves platform-dependent; np.uint8(int(hue_factor * 255) % 256) would wrap explicitly -- confirm 133 | h = Image.fromarray(np_h, 'L') 134 | 135 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 136 | return img 137 | 138 | 139 | def adjust_gamma(img, gamma, gain=1): 140 | """Perform gamma correction on an image.
141 | 142 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 143 | based on the following equation: 144 | 145 | I_out = 255 * gain * ((I_in / 255) ** gamma) 146 | 147 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 148 | 149 | Args: 150 | img (PIL Image): PIL Image to be adjusted. 151 | gamma (float): Non negative real number. gamma larger than 1 make the 152 | shadows darker, while gamma smaller than 1 make dark regions 153 | lighter. 154 | gain (float): The constant multiplier. 155 | """ 156 | if not _is_pil_image(img): 157 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 158 | 159 | if gamma < 0: 160 | raise ValueError('Gamma should be a non-negative real number') 161 | 162 | input_mode = img.mode 163 | img = img.convert('RGB') 164 | 165 | np_img = np.array(img, dtype=np.float32) 166 | np_img = 255 * gain * ((np_img / 255) ** gamma) 167 | np_img = np.uint8(np.clip(np_img, 0, 255)) 168 | 169 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 170 | return img 171 | 172 | 173 | class Compose(object): 174 | """Composes several transforms together. 175 | 176 | Args: 177 | transforms (list of ``Transform`` objects): list of transforms to compose. 178 | 179 | Example: 180 | >>> transforms.Compose([ 181 | >>> transforms.CenterCrop(10), 182 | >>> transforms.ToTensor(), 183 | >>> ]) 184 | """ 185 | 186 | def __init__(self, transforms): 187 | self.transforms = transforms 188 | 189 | def __call__(self, img): 190 | for t in self.transforms: 191 | img = t(img) 192 | return img 193 | 194 | 195 | class ToTensor(object): 196 | """Convert a ``numpy.ndarray`` to tensor. 197 | 198 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 199 | """ 200 | 201 | def __call__(self, img): 202 | """Convert a ``numpy.ndarray`` to tensor. 203 | 204 | Args: 205 | img (numpy.ndarray): Image to be converted to tensor. 206 | 207 | Returns: 208 | Tensor: Converted image. 
209 | """ 210 | if not(_is_numpy_image(img)): 211 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 212 | 213 | if isinstance(img, np.ndarray):  # always true here: _is_numpy_image() above already requires an ndarray 214 | # handle numpy array 215 | if img.ndim == 3: 216 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) 217 | elif img.ndim == 2: 218 | img = torch.from_numpy(img.copy()) 219 | else: 220 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))  # NOTE(review): unreachable -- _is_numpy_image() already restricts ndim to 2 or 3 221 | 222 | # backward compatibility 223 | # return img.float().div(255) 224 | return img.float() 225 | 226 | 227 | class NormalizeNumpyArray(object): 228 | """Normalize a ``numpy.ndarray`` with mean and standard deviation. 229 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 230 | will normalize each channel of the input ``numpy.ndarray`` i.e. 231 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 232 | 233 | Args: 234 | mean (sequence): Sequence of means for each channel. 235 | std (sequence): Sequence of standard deviations for each channel. 236 | """ 237 | 238 | def __init__(self, mean, std): 239 | self.mean = mean 240 | self.std = std 241 | 242 | def __call__(self, img): 243 | """ 244 | Args: 245 | img (numpy.ndarray): Image of size (H, W, C) to be normalized. 246 | 247 | Returns: 248 | numpy.ndarray: Normalized image (``img`` itself, modified in place). 249 | """ 250 | if not(_is_numpy_image(img)): 251 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 252 | # TODO: make efficient 253 | print(img.shape)  # NOTE(review): leftover debug print; runs on every call 254 | for i in range(3):  # NOTE(review): hard-codes 3 channels and indexes img[:,:,i], so a 2-D array accepted by the check above fails; integer arrays also truncate the normalized values 255 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i] 256 | return img 257 | 258 | class NormalizeTensor(object): 259 | """Normalize a tensor image with mean and standard deviation. 260 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 261 | will normalize each channel of the input ``torch.*Tensor`` i.e.
262 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 263 | 264 | Args: 265 | mean (sequence): Sequence of means for each channel. 266 | std (sequence): Sequence of standard deviations for each channel. 267 | """ 268 | 269 | def __init__(self, mean, std): 270 | self.mean = mean 271 | self.std = std 272 | 273 | def __call__(self, tensor): 274 | """ 275 | Args: 276 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 277 | 278 | Returns: 279 | Tensor: Normalized Tensor image. 280 | """ 281 | if not _is_tensor_image(tensor): 282 | raise TypeError('tensor is not a torch image.') 283 | # TODO: make efficient 284 | for t, m, s in zip(tensor, self.mean, self.std): 285 | t.sub_(m).div_(s) 286 | return tensor 287 | 288 | class Rotate(object): 289 | """Rotates the given ``numpy.ndarray``. 290 | 291 | Args: 292 | angle (float): The rotation angle in degrees. 293 | """ 294 | 295 | def __init__(self, angle): 296 | self.angle = angle 297 | 298 | def __call__(self, img): 299 | """ 300 | Args: 301 | img (numpy.ndarray (C x H x W)): Image to be rotated. 302 | 303 | Returns: 304 | img (numpy.ndarray (C x H x W)): Rotated image. 305 | """ 306 | 307 | # order=0 means nearest-neighbor type interpolation 308 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0) 309 | 310 | 311 | class Resize(object): 312 | """Resize the the given ``numpy.ndarray`` to the given size. 313 | Args: 314 | size (sequence or int): Desired output size. If size is a sequence like 315 | (h, w), output size will be matched to this. If size is an int, 316 | smaller edge of the image will be matched to this number. 317 | i.e, if height > width, then image will be rescaled to 318 | (size * height / width, size) 319 | interpolation (int, optional): Desired interpolation. 
Default is 320 | ``PIL.Image.BILINEAR`` 321 | """ 322 | 323 | def __init__(self, size, interpolation='nearest'): 324 | assert isinstance(size, int) or isinstance(size, float) or \ 325 | (isinstance(size, collections.Iterable) and len(size) == 2) 326 | self.size = size 327 | self.interpolation = interpolation 328 | 329 | def __call__(self, img): 330 | """ 331 | Args: 332 | img (PIL Image): Image to be scaled. 333 | Returns: 334 | PIL Image: Rescaled image. 335 | """ 336 | if img.ndim == 3: 337 | return misc.imresize(img, self.size, self.interpolation) 338 | elif img.ndim == 2: 339 | return misc.imresize(img, self.size, self.interpolation, 'F') 340 | else: 341 | RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 342 | 343 | 344 | class CenterCrop(object): 345 | """Crops the given ``numpy.ndarray`` at the center. 346 | 347 | Args: 348 | size (sequence or int): Desired output size of the crop. If size is an 349 | int instead of sequence like (h, w), a square crop (size, size) is 350 | made. 351 | """ 352 | 353 | def __init__(self, size): 354 | if isinstance(size, numbers.Number): 355 | self.size = (int(size), int(size)) 356 | else: 357 | self.size = size 358 | 359 | @staticmethod 360 | def get_params(img, output_size): 361 | """Get parameters for ``crop`` for center crop. 362 | 363 | Args: 364 | img (numpy.ndarray (C x H x W)): Image to be cropped. 365 | output_size (tuple): Expected output size of the crop. 366 | 367 | Returns: 368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 369 | """ 370 | h = img.shape[0] 371 | w = img.shape[1] 372 | th, tw = output_size 373 | i = int(round((h - th) / 2.)) 374 | j = int(round((w - tw) / 2.)) 375 | 376 | # # randomized cropping 377 | # i = np.random.randint(i-3, i+4) 378 | # j = np.random.randint(j-3, j+4) 379 | 380 | return i, j, th, tw 381 | 382 | def __call__(self, img): 383 | """ 384 | Args: 385 | img (numpy.ndarray (C x H x W)): Image to be cropped. 
386 | 387 | Returns: 388 | img (numpy.ndarray (C x H x W)): Cropped image. 389 | """ 390 | i, j, h, w = self.get_params(img, self.size) 391 | 392 | """ 393 | i: Upper pixel coordinate. 394 | j: Left pixel coordinate. 395 | h: Height of the cropped image. 396 | w: Width of the cropped image. 397 | """ 398 | if not(_is_numpy_image(img)): 399 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 400 | if img.ndim == 3: 401 | return img[i:i+h, j:j+w, :] 402 | elif img.ndim == 2: 403 | return img[i:i + h, j:j + w] 404 | else: 405 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 406 | 407 | class BottomCrop(object): 408 | """Crops the given ``numpy.ndarray`` at the bottom. 409 | 410 | Args: 411 | size (sequence or int): Desired output size of the crop. If size is an 412 | int instead of sequence like (h, w), a square crop (size, size) is 413 | made. 414 | """ 415 | 416 | def __init__(self, size): 417 | if isinstance(size, numbers.Number): 418 | self.size = (int(size), int(size)) 419 | else: 420 | self.size = size 421 | 422 | @staticmethod 423 | def get_params(img, output_size): 424 | """Get parameters for ``crop`` for bottom crop. 425 | 426 | Args: 427 | img (numpy.ndarray (C x H x W)): Image to be cropped. 428 | output_size (tuple): Expected output size of the crop. 429 | 430 | Returns: 431 | tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop. 432 | """ 433 | h = img.shape[0] 434 | w = img.shape[1] 435 | th, tw = output_size 436 | i = h - th 437 | j = int(round((w - tw) / 2.)) 438 | 439 | # randomized left and right cropping 440 | # i = np.random.randint(i-3, i+4) 441 | # j = np.random.randint(j-1, j+1) 442 | 443 | return i, j, th, tw 444 | 445 | def __call__(self, img): 446 | """ 447 | Args: 448 | img (numpy.ndarray (C x H x W)): Image to be cropped. 449 | 450 | Returns: 451 | img (numpy.ndarray (C x H x W)): Cropped image. 
452 | """ 453 | i, j, h, w = self.get_params(img, self.size) 454 | 455 | """ 456 | i: Upper pixel coordinate. 457 | j: Left pixel coordinate. 458 | h: Height of the cropped image. 459 | w: Width of the cropped image. 460 | """ 461 | if not(_is_numpy_image(img)): 462 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 463 | if img.ndim == 3: 464 | return img[i:i+h, j:j+w, :] 465 | elif img.ndim == 2: 466 | return img[i:i + h, j:j + w] 467 | else: 468 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 469 | 470 | class Lambda(object): 471 | """Apply a user-defined lambda as a transform. 472 | 473 | Args: 474 | lambd (function): Lambda/function to be used for transform. 475 | """ 476 | 477 | def __init__(self, lambd): 478 | assert isinstance(lambd, types.LambdaType) 479 | self.lambd = lambd 480 | 481 | def __call__(self, img): 482 | return self.lambd(img) 483 | 484 | 485 | class HorizontalFlip(object): 486 | """Horizontally flip the given ``numpy.ndarray``. 487 | 488 | Args: 489 | do_flip (boolean): whether or not do horizontal flip. 490 | 491 | """ 492 | 493 | def __init__(self, do_flip): 494 | self.do_flip = do_flip 495 | 496 | def __call__(self, img): 497 | """ 498 | Args: 499 | img (numpy.ndarray (C x H x W)): Image to be flipped. 500 | 501 | Returns: 502 | img (numpy.ndarray (C x H x W)): flipped image. 503 | """ 504 | if not(_is_numpy_image(img)): 505 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 506 | 507 | if self.do_flip: 508 | return np.fliplr(img) 509 | else: 510 | return img 511 | 512 | 513 | class ColorJitter(object): 514 | """Randomly change the brightness, contrast and saturation of an image. 515 | 516 | Args: 517 | brightness (float): How much to jitter brightness. brightness_factor 518 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 519 | contrast (float): How much to jitter contrast. 
contrast_factor 520 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 521 | saturation (float): How much to jitter saturation. saturation_factor 522 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 523 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 524 | [-hue, hue]. Should be >=0 and <= 0.5. 525 | """ 526 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 527 | self.brightness = brightness 528 | self.contrast = contrast 529 | self.saturation = saturation 530 | self.hue = hue 531 | 532 | @staticmethod 533 | def get_params(brightness, contrast, saturation, hue): 534 | """Get a randomized transform to be applied on image. 535 | 536 | Arguments are same as that of __init__. 537 | 538 | Returns: 539 | Transform which randomly adjusts brightness, contrast and 540 | saturation in a random order. 541 | """ 542 | transforms = [] 543 | if brightness > 0: 544 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 545 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 546 | 547 | if contrast > 0: 548 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 549 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 550 | 551 | if saturation > 0: 552 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 553 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 554 | 555 | if hue > 0: 556 | hue_factor = np.random.uniform(-hue, hue) 557 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 558 | 559 | np.random.shuffle(transforms) 560 | transform = Compose(transforms) 561 | 562 | return transform 563 | 564 | def __call__(self, img): 565 | """ 566 | Args: 567 | img (numpy.ndarray (C x H x W)): Input image. 568 | 569 | Returns: 570 | img (numpy.ndarray (C x H x W)): Color jittered image. 
571 | """ 572 | if not(_is_numpy_image(img)): 573 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 574 | 575 | pil = Image.fromarray(img) 576 | transform = self.get_params(self.brightness, self.contrast, 577 | self.saturation, self.hue) 578 | return np.array(transform(pil)) 579 | 580 | class Crop(object): 581 | """Crops the given PIL Image to a rectangular region based on a given 582 | 4-tuple defining the left, upper pixel coordinated, hight and width size. 583 | 584 | Args: 585 | a tuple: (upper pixel coordinate, left pixel coordinate, hight, width)-tuple 586 | """ 587 | 588 | def __init__(self, i, j, h, w): 589 | """ 590 | i: Upper pixel coordinate. 591 | j: Left pixel coordinate. 592 | h: Height of the cropped image. 593 | w: Width of the cropped image. 594 | """ 595 | self.i = i 596 | self.j = j 597 | self.h = h 598 | self.w = w 599 | 600 | def __call__(self, img): 601 | """ 602 | Args: 603 | img (numpy.ndarray (C x H x W)): Image to be cropped. 604 | Returns: 605 | img (numpy.ndarray (C x H x W)): Cropped image. 606 | """ 607 | 608 | i, j, h, w = self.i, self.j, self.h, self.w 609 | 610 | if not(_is_numpy_image(img)): 611 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 612 | if img.ndim == 3: 613 | return img[i:i + h, j:j + w, :] 614 | elif img.ndim == 2: 615 | return img[i:i + h, j:j + w] 616 | else: 617 | raise RuntimeError( 618 | 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format(img.ndim)) 619 | 620 | def __repr__(self): 621 | return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format( 622 | self.i, self.j, self.h, self.w) 623 | -------------------------------------------------------------------------------- /deploy/data/depth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/depth.npy -------------------------------------------------------------------------------- /deploy/data/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/depth.png -------------------------------------------------------------------------------- /deploy/data/pred.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/pred.npy -------------------------------------------------------------------------------- /deploy/data/pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/pred.png -------------------------------------------------------------------------------- /deploy/data/rgb.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/rgb.npy -------------------------------------------------------------------------------- /deploy/data/rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/deploy/data/rgb.png 
# --------------------------------------------------------------------------------
# /deploy/data/visualize.py:
# --------------------------------------------------------------------------------
"""Render the saved demo arrays (rgb / depth / pred .npy files) as PNG images."""
import numpy as np
import matplotlib as mp

mp.use("pdf")  # non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

cmap = plt.cm.viridis


def colored_depthmap(depth, d_min=None, d_max=None):
    """Map a 2-D depth array to an (H, W, 3) float RGB array in [0, 255].

    Args:
        depth (numpy.ndarray): 2-D depth values.
        d_min, d_max (float, optional): Values mapped to the low/high ends of
            ``cmap``; default to the min/max of ``depth``.

    Returns:
        numpy.ndarray: ``depth.shape + (3,)`` array (alpha channel dropped).
    """
    if d_min is None:
        d_min = np.min(depth)
    if d_max is None:
        d_max = np.max(depth)
    d_range = d_max - d_min
    if d_range == 0:
        # Constant depth: avoid 0/0 (NaN), which the colormap would render in
        # its "bad" color; map every pixel to the low end of the colormap.
        depth_relative = np.zeros_like(depth, dtype=float)
    else:
        depth_relative = (depth - d_min) / d_range
    return 255 * cmap(depth_relative)[:, :, :3]  # HWC


def _save_uint8(filename, image):
    """Clip ``image`` to [0, 255] and save it as an 8-bit image.

    Clipping prevents the silent wrap-around of a plain ``astype('uint8')``.
    """
    mp.image.imsave(filename, np.clip(image, 0, 255).astype('uint8'))


def save_rgb_image(rgb_npy_fp, filename):
    """Save the RGB array stored in ``rgb_npy_fp`` as image ``filename``.

    NOTE(review): assumes the array is HWC (singleton axes are squeezed) with
    values in [0, 1] -- confirm against the producer of the .npy file.
    """
    rgb_np = np.squeeze(np.load(rgb_npy_fp))  # HWC
    _save_uint8(filename, 255 * rgb_np)


def save_depth_image(depth_npy_fp, filename):
    """Save the 2-D depth array stored in ``depth_npy_fp`` as a colored image."""
    depth_np = np.load(depth_npy_fp)
    _save_uint8(filename, colored_depthmap(depth_np))


def save_pred_image(pred_npy_fp, filename):
    """Save the first channel of the first item of an NCHW prediction as a colored image."""
    pred_np = np.load(pred_npy_fp)
    pred_np_2d = pred_np[0, 0, :, :]  # HW
    _save_uint8(filename, colored_depthmap(pred_np_2d))


if __name__ == '__main__':
    # Guarded so importing this module does not read/write files.
    save_rgb_image('rgb.npy', 'rgb.png')
    save_depth_image('depth.npy', 'depth.png')
    save_pred_image('pred.npy', 'pred.png')

# --------------------------------------------------------------------------------
# /deploy/tx2_run_tvm.py:
# --------------------------------------------------------------------------------
import tvm
import numpy as np
import argparse
import os
import time


def run_model(model_dir, input_fp, output_fp, warmup_trials, run_trials, cuda, try_randin):
    """Load a TVM-compiled model, run it once on ``input_fp`` and benchmark it.

    Args:
        model_dir (str): Folder with deploy_lib.o, deploy_graph.json,
            deploy_param.params and, when ``cuda`` is set, deploy_cuda.ptx.
        input_fp (str): .npy file holding an HWC RGB input; it is copied into
            a (1, 3, 224, 224) float32 tensor.
        output_fp (str): .npy file the (1, 1, 224, 224) prediction is saved to.
        warmup_trials (int): Number of untimed runs before profiling.
        run_trials (int): Number of timed repeats.
        cuda (bool): Run on ``tvm.gpu(0)`` instead of ``tvm.cpu(0)``.
        try_randin (bool): Also time runs with a fresh random input each time.
    """
    # import compiled graph
    print("=> [TVM on TX2] using model files in {}".format(model_dir))
    assert(os.path.isdir(model_dir))

    print("=> [TVM on TX2] loading model lib and ptx")
    loaded_lib = tvm.module.load(os.path.join(model_dir, "deploy_lib.o"))
    if cuda:
        dev_lib = tvm.module.load(os.path.join(model_dir, "deploy_cuda.ptx"))
        loaded_lib.import_module(dev_lib)

    print("=> [TVM on TX2] loading model graph and params")
    # Context managers close the files as soon as they have been read.
    with open(os.path.join(model_dir, "deploy_graph.json")) as graph_file:
        loaded_graph = graph_file.read()
    with open(os.path.join(model_dir, "deploy_param.params"), "rb") as params_file:
        loaded_params = bytearray(params_file.read())

    print("=> [TVM on TX2] creating TVM runtime module")
    fcreate = tvm.get_global_func("tvm.graph_runtime.create")
    ctx = tvm.gpu(0) if cuda else tvm.cpu(0)
    gmodule = fcreate(loaded_graph, loaded_lib, ctx.device_type, ctx.device_id)
    set_input, get_output, run = gmodule["set_input"], gmodule["get_output"], gmodule["run"]

    print("=> [TVM on TX2] feeding inputs and params into TVM module")
    rgb_np = np.load(input_fp)  # HWC
    x = np.zeros([1, 3, 224, 224])  # NCHW
    x[0, :, :, :] = np.transpose(rgb_np, (2, 0, 1))
    set_input('0', tvm.nd.array(x.astype('float32')))
    gmodule["load_params"](loaded_params)

    print("=> [TVM on TX2] running TVM module, saving output")
    run()  # not gmodule.run()
    out_shape = (1, 1, 224, 224)
    out = tvm.nd.empty(out_shape, "float32")
    get_output(0, out)
    np.save(output_fp, out.asnumpy())

    print("=> [TVM on TX2] benchmarking: {} warmup, {} run trials".format(warmup_trials, run_trials))
    # run model several times as a warmup
    for _ in range(warmup_trials):
        run()
        ctx.sync()

    # profile runtime using TVM time evaluator
    ftimer = gmodule.time_evaluator("run", ctx, number=1, repeat=run_trials)
    profile_result = ftimer()
    profiled_runtime = profile_result[0]

    print("=> [TVM on TX2] profiled runtime (in ms): {:.5f}".format(1000*profiled_runtime))

    # try randomizing input
    if try_randin:
        randin_runtime = 0
        for _ in range(run_trials):
            x = np.random.randn(1, 3, 224, 224)
            set_input('0', tvm.nd.array(x.astype('float32')))
            randin_ftimer = gmodule.time_evaluator("run", ctx, number=1, repeat=1)
            randin_profile_result = randin_ftimer()
            randin_runtime += randin_profile_result[0]
        randomized_input_runtime = randin_runtime/run_trials
        print("=> [TVM on TX2] with randomized input on every run, profiled runtime (in ms): {:.5f}".format(1000*randomized_input_runtime))


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` turns every non-empty string (including "False")
    into True, so it cannot be used for options that take a value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Parse command-line options and run ``run_model`` with them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', type=str, required=True,
                        help='path to folder with TVM-compiled model files (required)')
    parser.add_argument('--input-fp', type=str, default='data/rgb.npy',
                        help='numpy file containing input rgb data (default: data/rgb.npy)')
    parser.add_argument('--output-fp', type=str, default='data/pred.npy',
                        help='numpy file to store output prediction data (default: data/pred.npy)')
    parser.add_argument('--warmup', type=int, default=10,
                        help='number of inference warmup trials (default: 10)')
    parser.add_argument('--run', type=int, default=100,
                        help='number of inference run trials (default: 100)')
    parser.add_argument('--cuda', type=_str2bool, default=False,
                        help='run with CUDA (default: False)')
    parser.add_argument('--randin', type=_str2bool, default=False,
                        help='profile runtime while randomizing input on every run (default: False)')

    args = parser.parse_args()
    run_model(args.model_dir, args.input_fp, args.output_fp, args.warmup, args.run,
              args.cuda, try_randin=args.randin)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /imagenet/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/imagenet/__init__.py
# --------------------------------------------------------------------------------
# /imagenet/mobilenet.py:
# --------------------------------------------------------------------------------
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data


class MobileNet(nn.Module):
    """MobileNet ImageNet classifier built from depthwise-separable conv blocks.

    Args:
        relu6 (bool): Use ``nn.ReLU6`` activations if True, ``nn.ReLU`` otherwise.

    The trunk has five stride-2 stages, and ``nn.AvgPool2d(7)`` together with
    ``view(-1, 1024)`` assumes a 7x7 final feature map, i.e. 224x224 input.
    """

    def __init__(self, relu6=True):
        super(MobileNet, self).__init__()

        def relu(relu6):
            if relu6:
                return nn.ReLU6(inplace=True)
            else:
                return nn.ReLU(inplace=True)

        def conv_bn(inp, oup, stride, relu6):
            # Standard 3x3 conv -> BN -> activation.
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                relu(relu6),
            )

        def conv_dw(inp, oup, stride, relu6):
            # Depthwise 3x3 (groups=inp) followed by pointwise 1x1 conv.
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                relu(relu6),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                relu(relu6),
            )

        # Layer order/indices must stay fixed: checkpoints key weights by them.
        self.model = nn.Sequential(
            conv_bn(3, 32, 2, relu6),
            conv_dw(32, 64, 1, relu6),
            conv_dw(64, 128, 2, relu6),
            conv_dw(128, 128, 1, relu6),
            conv_dw(128, 256, 2, relu6),
            conv_dw(256, 256, 1, relu6),
            conv_dw(256, 512, 2, relu6),
            conv_dw(512, 512, 1, relu6),
            conv_dw(512, 512, 1, relu6),
            conv_dw(512, 512, 1, relu6),
            conv_dw(512, 512, 1, relu6),
            conv_dw(512, 512, 1, relu6),
            conv_dw(512, 1024, 2, relu6),
            conv_dw(1024, 1024, 1, relu6),
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x


def main():
    """Build a DataParallel MobileNet on GPU and load ImageNet weights if present."""
    model = MobileNet(relu6=True)
    model = torch.nn.DataParallel(model).cuda()
    model_filename = os.path.join('results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar')
    if os.path.isfile(model_filename):
        print("=> loading Imagenet pretrained model '{}'".format(model_filename))
        checkpoint = torch.load(model_filename)
        epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded Imagenet pretrained model '{}' (epoch {}). best_prec1={}".format(model_filename, epoch, best_prec1))
    else:
        print("=> no Imagenet pretrained model found at '{}'".format(model_filename))


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /img/acc_fps_cpu.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/img/acc_fps_cpu.png
# --------------------------------------------------------------------------------
# /img/acc_fps_gpu.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/img/acc_fps_gpu.png
# --------------------------------------------------------------------------------
# /img/visualization.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dwofk/fast-depth/e68492011609c9bfb7de6d402da5d1d201d95bd9/img/visualization.png

# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import os
import time
import csv
import numpy as np

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
cudnn.benchmark = True

import models  # NOTE(review): presumably needed so torch.load can unpickle saved model objects
from metrics import AverageMeter, Result
import utils

args = utils.parse_command()
print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu  # Set the GPU.

fieldnames = ['rmse', 'mae', 'delta1', 'absrel',
              'lg10', 'mse', 'delta2', 'delta3', 'data_time', 'gpu_time']
best_fieldnames = ['best_epoch'] + fieldnames
best_result = Result()
best_result.set_to_worst()


def main():
    """Build the validation loader and, with ``args.evaluate``, evaluate a saved model."""
    global args, best_result, output_directory, train_csv, test_csv

    # Data loading code
    print("=> creating data loaders...")
    valdir = os.path.join('..', 'data', args.data, 'val')

    if args.data == 'nyudepthv2':
        from dataloaders.nyu import NYUDataset
        val_dataset = NYUDataset(valdir, split='val', modality=args.modality)
    else:
        raise RuntimeError('Dataset not found.')

    # set batch size to be 1 for validation
    val_loader = torch.utils.data.DataLoader(val_dataset,
        batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        assert os.path.isfile(args.evaluate), \
            "=> no model found at '{}'".format(args.evaluate)
        print("=> loading model '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        # A dict checkpoint carries metadata; otherwise the file is the model itself.
        if type(checkpoint) is dict:
            args.start_epoch = checkpoint['epoch']
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        else:
            model = checkpoint
            args.start_epoch = 0
        output_directory = os.path.dirname(args.evaluate)
        validate(val_loader, model, args.start_epoch, write_to_file=False)
        return


def validate(val_loader, model, epoch, write_to_file=True):
    """Evaluate ``model`` on ``val_loader`` and report averaged depth metrics.

    Args:
        val_loader: DataLoader yielding (input, target) batches.
        model: Network mapping an input batch to a depth prediction.
        epoch: Used only to name the saved comparison image.
        write_to_file (bool): Append the averages to the global ``test_csv``.

    Returns:
        tuple: (averaged ``Result``, comparison image). The image is only built
        when ``args.modality == 'rgb'`` and is None otherwise.
    """
    average_meter = AverageMeter()
    model.eval()  # switch to evaluate mode
    img_merge = None  # stays None if no visualization rows are collected
    skip = 50  # every 50th sample becomes a row of the 8-row comparison image
    end = time.time()
    for i, (inputs, target) in enumerate(val_loader):
        inputs, target = inputs.cuda(), target.cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # compute output
        end = time.time()
        with torch.no_grad():
            pred = model(inputs)
        # CUDA kernels run asynchronously; wait for them so gpu_time measures
        # the forward pass rather than only the kernel launches.
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # measure accuracy and record loss
        result = Result()
        result.evaluate(pred.data, target.data)
        average_meter.update(result, gpu_time, data_time, inputs.size(0))
        end = time.time()

        # save 8 images for visualization; only rgb input can be rendered,
        # so other modalities skip this instead of hitting an undefined name.
        if args.modality == 'rgb':
            rgb = inputs
            if i == 0:
                img_merge = utils.merge_into_row(rgb, target, pred)
            elif (i < 8*skip) and (i % skip == 0):
                row = utils.merge_into_row(rgb, target, pred)
                img_merge = utils.add_row(img_merge, row)
            elif i == 8*skip:
                filename = output_directory + '/comparison_' + str(epoch) + '.png'
                utils.save_image(img_merge, filename)

        if (i+1) % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                  'MAE={result.mae:.2f}({average.mae:.2f}) '
                  'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                  'REL={result.absrel:.3f}({average.absrel:.3f}) '
                  'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
                  i+1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))

    avg = average_meter.average()

    print('\n*\n'
        'RMSE={average.rmse:.3f}\n'
        'MAE={average.mae:.3f}\n'
        'Delta1={average.delta1:.3f}\n'
        'REL={average.absrel:.3f}\n'
        'Lg10={average.lg10:.3f}\n'
        't_GPU={time:.3f}\n'.format(
        average=avg, time=avg.gpu_time))

    if write_to_file:
        with open(test_csv, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
                'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
                'data_time': avg.data_time, 'gpu_time': avg.gpu_time})
    return avg, img_merge


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /metrics.py:
# --------------------------------------------------------------------------------
import torch
import math
import numpy as np


def log10(x):
    """Return a new tensor with the base-10 logarithm of the elements of x."""
    return torch.log(x) / math.log(10)


class Result(object):
    """One set of depth-estimation metrics plus data/GPU timings."""

    def __init__(self):
        self.irmse, self.imae = 0, 0
        self.mse, self.rmse, self.mae = 0, 0, 0
        self.absrel, self.lg10 = 0, 0
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0

    def set_to_worst(self):
        """Set error metrics to +inf, delta accuracies to 0 and timings to 0."""
        self.irmse, self.imae = np.inf, np.inf
        self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
        self.absrel, self.lg10 = np.inf, np.inf
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0

    def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
        """Overwrite every stored metric with the given values."""
        self.irmse, self.imae = irmse, imae
        self.mse, self.rmse, self.mae = mse, rmse, mae
        self.absrel, self.lg10 = absrel, lg10
        self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
        self.data_time, self.gpu_time = data_time, gpu_time

    def evaluate(self, output, target):
        """Compute metrics of ``output`` against ``target`` over valid pixels.

        A pixel is valid when its ground truth is positive. Both tensors are
        scaled by 1e3 before the metrics are computed (NOTE(review): presumably
        metres -> millimetres; confirm units). Timings are reset to 0.
        """
        # Only pixels with ground truth are scored. The former mask
        # ((target>0) + (output>0)) > 0 was a logical OR that also admitted
        # target == 0 pixels, where absrel, the delta ratios and the inverse
        # errors divide by zero and turn the metrics into inf/nan.
        valid_mask = target > 0

        output = 1e3 * output[valid_mask]
        target = 1e3 * target[valid_mask]
        abs_diff = (output - target).abs()

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output) - log10(target)).abs().mean())
        self.absrel = float((abs_diff / target).mean())

        maxRatio = torch.max(output / target, target / output)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
        self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        inv_output = 1 / output
        inv_target = 1 / target
        abs_inv_diff = (inv_output - inv_target).abs()
        self.irmse = math.sqrt(float((torch.pow(abs_inv_diff, 2)).mean()))
        self.imae = float(abs_inv_diff.mean())


class AverageMeter(object):
    """Accumulate count-weighted sums of ``Result`` metrics and timings."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the sample count and every running sum."""
        self.count = 0.0

        self.sum_irmse, self.sum_imae = 0, 0
        self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0
        self.sum_absrel, self.sum_lg10 = 0, 0
        self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0
        self.sum_data_time, self.sum_gpu_time = 0, 0

    def update(self, result, gpu_time, data_time, n=1):
        """Add ``result`` (and the given timings) weighted by ``n`` samples."""
        self.count += n

        self.sum_irmse += n*result.irmse
        self.sum_imae += n*result.imae
        self.sum_mse += n*result.mse
        self.sum_rmse += n*result.rmse
        self.sum_mae += n*result.mae
        self.sum_absrel += n*result.absrel
        self.sum_lg10 += n*result.lg10
        self.sum_delta1 += n*result.delta1
        self.sum_delta2 += n*result.delta2
        self.sum_delta3 += n*result.delta3
        self.sum_data_time += n*data_time
        self.sum_gpu_time += n*gpu_time

    def average(self):
        """Return a ``Result`` holding the weighted means (requires count > 0)."""
        avg = Result()
        avg.update(
            self.sum_irmse / self.count, self.sum_imae / self.count,
            self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count,
            self.sum_absrel / self.count, self.sum_lg10 / self.count,
            self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count,
            self.sum_gpu_time / self.count, self.sum_data_time / self.count)
        return avg

# --------------------------------------------------------------------------------
# /tvm_compile/tuning/tx2-cpu.mobilenet-nnconv5.trials=1000.stop=250.log:
-------------------------------------------------------------------------------- 1 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 7410, "c": null, "e": [["tile_co", "sp", [1, 1]], ["tile_oh", "sp", [14, 16]], ["tile_ow", "sp", [16, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["none", "vec", "none"]]]}], "v": 0.1, "r": [[0.0038207925], 0, 1.7591049671173096, 1549478477.5367444]} 2 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"t": "direct", "i": 66074, "c": null, "e": [["tile_co", "sp", [8, 4]], ["tile_oh", "sp", [28, 4]], ["tile_ow", "sp", [56, 2]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.066266069], 0, 5.3768932819366455, 1549479743.947523]} 3 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"t": "direct", "i": 49401, "c": null, "e": [["tile_co", "sp", [16, 4]], ["tile_oh", "sp", [28, 2]], ["tile_ow", "sp", [14, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": 
[[0.051313814], 0, 9.66435980796814, 1549480240.59895]} 4 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"t": "direct", "i": 31787, "c": null, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [14, 2]], ["tile_ow", "sp", [7, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.04352342825], 0, 3.9346649646759033, 1549480763.6560743]} 5 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"t": "direct", "i": 6734, "c": null, "e": [["tile_co", "sp", [64, 4]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.03991823], 0, 3.566746950149536, 1549481391.9678566]} 6 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"t": "direct", "i": 3463, "c": null, "e": [["tile_co", "sp", [64, 8]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.036470917], 0, 
5.302130937576294, 1549481777.5447953]} 7 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 4591, "c": null, "e": [["tile_co", "sp", [64, 16]], ["tile_oh", "sp", [1, 7]], ["tile_ow", "sp", [7, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.0043714415], 0, 2.8076863288879395, 1549482462.154945]} 8 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "i": 455, "c": null, "e": [["tile_c", "sp", [64, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.000124567], 0, 0.8551723957061768, 1549482638.7107537]} 9 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 3667, "c": null, "e": [["tile_co", "sp", [64, 16]], ["tile_oh", "sp", [1, 7]], ["tile_ow", "sp", [7, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.0021189205], 0, 1.0700788497924805, 1549482899.913476]} 10 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", 
"topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "i": 412, "c": null, "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.0001378475], 0, 1.3542444705963135, 1549483028.7694135]} 11 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 12884, "c": null, "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.00439193], 0, 2.0783283710479736, 1549483215.359623]} 12 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "i": 2112, "c": null, "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [14, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.0002192565], 0, 0.5895946025848389, 1549483345.4555154]} 13 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 
1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 13364, "c": null, "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.00211282625], 0, 0.6594874858856201, 1549483478.8246555]} 14 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "i": 1928, "c": null, "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [2, 7]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.000224376], 0, 0.7611992359161377, 1549483607.5230222]} 15 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 6646, "c": null, "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "none", "vec"]]]}], "v": 0.1, "r": [[0.004570014], 0, 1.2378997802734375, 1549483753.9382432]} 16 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "i": 3332, "c": null, "e": [["tile_c", "sp", [64, 4]], 
["tile_h", "sp", [2, 14]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.0005125755], 0, 0.9667601585388184, 1549483890.7859952]} 17 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 14692, "c": null, "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [7, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.00236881275], 0, 0.5344700813293457, 1549484047.5860903]} 18 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "i": 3818, "c": null, "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.00043479925], 0, 0.5477871894836426, 1549484179.4244192]} 19 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 21315, "c": null, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [4, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]]}], "v": 0.1, 
"r": [[0.005667829], 0, 1.7964444160461426, 1549484435.3359888]} 20 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "i": 6698, "c": null, "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [4, 14]], ["tile_w", "sp", [56, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.0010919125], 0, 1.353485107421875, 1549484632.265847]} 21 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 53955, "c": null, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [8, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.00302072925], 0, 1.6285948753356934, 1549484777.1761692]} 22 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "i": 2417, "c": null, "e": [["tile_c", "sp", [16, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [8, 7]], ["ann", "an", ["none", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.00089098925], 0, 1.8967888355255127, 1549484944.7167633]} 23 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], 
"float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "i": 75618, "c": null, "e": [["tile_co", "sp", [4, 16]], ["tile_oh", "sp", [28, 4]], ["tile_ow", "sp", [112, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}], "v": 0.1, "r": [[0.0042586435], 0, 0.5511445999145508, 1549485179.0085921]} 24 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "i": 6086, "c": null, "e": [["tile_c", "sp", [8, 4]], ["tile_h", "sp", [14, 8]], ["tile_w", "sp", [56, 2]], ["ann", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.0013867625], 0, 0.6528422832489014, 1549485331.7410119]} 25 | {"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [32, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"t": "direct", "i": 49329, "c": null, "e": [["tile_co", "sp", [4, 8]], ["tile_oh", "sp", [56, 2]], ["tile_ow", "sp", [28, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}], "v": 0.1, "r": [[0.00171880075], 0, 0.7888402938842773, 1549485642.3346002]} 26 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-cpu.mobilenet-nnconv5dw-skipadd-pruned.trials=1000.stop=250.log: -------------------------------------------------------------------------------- 1 | {"v": 
0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 16, 224, 224], "float32"], ["TENSOR", [1, 16, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 16, 224, 224, "float32"], [1, 16, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [1, 1]], ["tile_oh", "sp", [1, 224]], ["tile_ow", "sp", [14, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "vec", "unroll"]]], "i": 9155}], "r": [[0.00116569925], 0, 1.1003811359405518, 1548189982.8573406]} 2 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [16, 56, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 56, 112, 112, "float32"], [16, 56, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [4, 4]], ["tile_oh", "sp", [112, 1]], ["tile_ow", "sp", [7, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 53802}], "r": [[0.00162019275], 0, 0.8826901912689209, 1548190186.063045]} 3 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [56, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 56, 112, 112, "float32"], [56, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [14, 4]], ["tile_h", "sp", [7, 16]], ["tile_w", "sp", [56, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 10530}], "r": [[0.002066204], 0, 2.2892937660217285, 1548190458.9113562]} 4 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 120, 56, 56], 
"float32"], ["TENSOR", [56, 120, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 120, 56, 56, "float32"], [56, 120, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [7, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [4, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 54084}], "r": [[0.001725823], 0, 0.8528237342834473, 1548190698.1614182]} 5 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 120, 56, 56], "float32"], ["TENSOR", [120, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 120, 56, 56, "float32"], [120, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [30, 4]], ["tile_h", "sp", [4, 14]], ["tile_w", "sp", [56, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 13395}], "r": [[0.00095593075], 0, 1.2980289459228516, 1548190892.7128842]} 6 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [120, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [120, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [15, 8]], ["tile_oh", "sp", [14, 2]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]], "i": 50422}], "r": [[0.00122721775], 0, 1.411975622177124, 1548191121.6755435]} 7 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, 
["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [32, 8]], ["tile_h", "sp", [7, 4]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 3315}], "r": [[0.000616226], 0, 0.5349960327148438, 1548191327.3241053]} 8 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 200, 14, 14], "float32"], ["TENSOR", [256, 200, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 200, 14, 14, "float32"], [256, 200, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [64, 4]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 15374}], "r": [[0.00046795525], 0, 1.549518346786499, 1548191591.9854326]} 9 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 200, 14, 14], "float32"], ["TENSOR", [200, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 200, 14, 14, "float32"], [200, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [50, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 2006}], "r": [[0.0001491905], 0, 1.5461997985839844, 1548191767.1845741]} 10 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [200, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [200, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [25, 8]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], 
["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 5068}], "r": [[0.0003510605], 0, 0.4742095470428467, 1548191892.1782413]} 11 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 7, 7, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [64, 8]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 533}], "r": [[0.0001124635], 0, 0.676560640335083, 1548191991.434562]} 12 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 480, 7, 7], "float32"], ["TENSOR", [512, 480, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 480, 7, 7, "float32"], [512, 480, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "i": 1624}], "r": [[0.0005328295], 0, 1.406437635421753, 1548192119.4241006]} 13 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 480, 7, 7], "float32"], ["TENSOR", [480, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 480, 7, 7, "float32"], [480, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [30, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 994}], "r": [[6.130375e-05], 0, 1.4277822971343994, 1548192234.7198637]} 14 | 
{"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 328, 7, 7], "float32"], ["TENSOR", [480, 328, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 328, 7, 7, "float32"], [480, 328, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [30, 16]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 10426}], "r": [[0.00031776475], 0, 0.4927072525024414, 1548192479.2059495]} 15 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 328, 14, 14], "float32"], ["TENSOR", [328, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 328, 14, 14, "float32"], [328, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [82, 4]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 330}], "r": [[8.818325e-05], 0, 0.7043302059173584, 1548192581.0602925]} 16 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 296, 14, 14], "float32"], ["TENSOR", [328, 296, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 296, 14, 14, "float32"], [328, 296, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [41, 8]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "i": 5219}], "r": [[0.00074605375], 0, 2.0092272758483887, 1548192829.0489042]} 17 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 296, 
14, 14], "float32"], ["TENSOR", [296, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 296, 14, 14, "float32"], [296, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [74, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 1338}], "r": [[0.0001298155], 0, 0.9202358722686768, 1548192962.0136871]} 18 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 288, 14, 14], "float32"], ["TENSOR", [296, 288, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 288, 14, 14, "float32"], [296, 288, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [37, 8]], ["tile_oh", "sp", [7, 2]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]], "i": 11211}], "r": [[0.00079630925], 0, 0.8467371463775635, 1548193166.7224784]} 19 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 288, 14, 14], "float32"], ["TENSOR", [288, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 288, 14, 14, "float32"], [288, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [72, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 3009}], "r": [[0.0001182315], 0, 0.7226011753082275, 1548193289.163559]} 20 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 272, 14, 14], "float32"], ["TENSOR", [288, 272, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 272, 14, 14, "float32"], [288, 272, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": 
null, "e": [["tile_co", "sp", [36, 8]], ["tile_oh", "sp", [1, 14]], ["tile_ow", "sp", [7, 2]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]], "i": 23459}], "r": [[0.00073303275], 0, 0.9118449687957764, 1548193427.542316]} 21 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 272, 14, 14], "float32"], ["TENSOR", [272, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 272, 14, 14, "float32"], [272, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [68, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 1672}], "r": [[0.0001212955], 0, 1.2856500148773193, 1548193581.368387]} 22 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 376, 14, 14], "float32"], ["TENSOR", [272, 376, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 376, 14, 14, "float32"], [272, 376, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [34, 8]], ["tile_oh", "sp", [7, 2]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "i": 6973}], "r": [[0.0008803325], 0, 0.6495048999786377, 1548193731.0867894]} 23 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 376, 14, 14], "float32"], ["TENSOR", [376, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 376, 14, 14, "float32"], [376, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [94, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [14, 1]], ["ann", "an", 
["unroll", "none", "vec"]]], "i": 1306}], "r": [[0.0001601915], 0, 1.889883279800415, 1548193863.9992387]} 24 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 408, 14, 14], "float32"], ["TENSOR", [376, 408, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 408, 14, 14, "float32"], [376, 408, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [47, 8]], ["tile_oh", "sp", [7, 2]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 13515}], "r": [[0.0011498275], 0, 0.4511895179748535, 1548193991.0625558]} 25 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 408, 14, 14], "float32"], ["TENSOR", [408, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 408, 14, 14, "float32"], [408, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [102, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "i": 2675}], "r": [[0.000155239], 0, 0.7136528491973877, 1548194156.493113]} 26 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [408, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [408, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [51, 8]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "i": 10693}], "r": [[0.00085011425], 0, 1.063706874847412, 1548194341.4088423]} 27 | {"v": 
0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [32, 8]], ["tile_h", "sp", [7, 2]], ["tile_w", "sp", [2, 7]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 1956}], "r": [[0.00021250325], 0, 0.9065120220184326, 1548194495.7803543]} 28 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 144, 28, 28], "float32"], ["TENSOR", [256, 144, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 144, 28, 28, "float32"], [256, 144, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [32, 8]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [2, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 33915}], "r": [[0.00119772325], 0, 0.663311243057251, 1548194752.2460217]} 29 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 144, 28, 28], "float32"], ["TENSOR", [144, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 144, 28, 28, "float32"], [144, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [36, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [7, 4]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 7248}], "r": [[0.00026295], 0, 1.8018908500671387, 1548194896.5325284]} 30 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 120, 28, 28], "float32"], ["TENSOR", [144, 120, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, 
["conv2d", [1, 120, 28, 28, "float32"], [144, 120, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [18, 8]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "i": 24035}], "r": [[0.00057663775], 0, 0.8634822368621826, 1548195053.5468194]} 31 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 120, 56, 56], "float32"], ["TENSOR", [120, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 120, 56, 56, "float32"], [120, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [30, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [7, 4]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 7731}], "r": [[0.00031674125], 0, 1.1324889659881592, 1548195190.1497178]} 32 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 88, 56, 56], "float32"], ["TENSOR", [120, 88, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 88, 56, 56, "float32"], [120, 88, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [15, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "i": 112134}], "r": [[0.001929271], 0, 1.1447086334228516, 1548195327.1973894]} 33 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 88, 56, 56], "float32"], ["TENSOR", [88, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 88, 56, 56, "float32"], [88, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], 
{"t": "direct", "c": null, "e": [["tile_c", "sp", [22, 4]], ["tile_h", "sp", [56, 1]], ["tile_w", "sp", [14, 4]], ["ann", "an", ["none", "unroll", "vec"]]], "i": 2690}], "r": [[0.00058562175], 0, 1.124485731124878, 1548195501.234142]} 34 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 56, 56, 56], "float32"], ["TENSOR", [88, 56, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 56, 56, 56, "float32"], [88, 56, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [11, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "none", "vec"]]], "i": 11011}], "r": [[0.000900928], 0, 1.7441248893737793, 1548195753.743164]} 35 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [56, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 56, 112, 112, "float32"], [56, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [14, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [28, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 6730}], "r": [[0.0007634405], 0, 0.5247933864593506, 1548195883.5107229]} 36 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 16, 112, 112], "float32"], ["TENSOR", [56, 16, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 16, 112, 112, "float32"], [56, 16, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [14, 4]], ["tile_oh", "sp", [112, 1]], ["tile_ow", "sp", [16, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", 
"an", ["none", "unroll", "vec"]]], "i": 33842}], "r": [[0.00138687475], 0, 0.5438559055328369, 1548196184.7103753]} 37 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 16, 112, 112], "float32"], ["TENSOR", [16, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 16, 112, 112, "float32"], [16, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"t": "direct", "c": null, "e": [["tile_c", "sp", [4, 4]], ["tile_h", "sp", [56, 2]], ["tile_w", "sp", [28, 4]], ["ann", "an", ["unroll", "unroll", "vec"]]], "i": 6607}], "r": [[0.0006212455], 0, 0.6399178504943848, 1548196338.593711]} 38 | {"v": 0.1, "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [16, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [16, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"t": "direct", "c": null, "e": [["tile_co", "sp", [4, 4]], ["tile_oh", "sp", [56, 2]], ["tile_ow", "sp", [28, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "none", "vec"]]], "i": 11607}], "r": [[0.00103108375], 0, 0.8736927509307861, 1548196616.90433]} 39 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-cpu.mobilenet-nnconv5dw-skipadd.trials=1000.stop=250.log: -------------------------------------------------------------------------------- 1 | {"r": [[0.00218976625], 0, 0.38044118881225586, 1547737329.8211834], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 7283, "e": [["tile_co", "sp", [1, 1]], 
["tile_oh", "sp", [1, 224]], ["tile_ow", "sp", [14, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["none", "vec", "none"]]], "t": "direct", "c": null}], "v": 0.1} 2 | {"r": [[0.002453024], 0, 0.8830351829528809, 1547737670.3910737], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 65168, "e": [["tile_co", "sp", [8, 4]], ["tile_oh", "sp", [56, 2]], ["tile_ow", "sp", [7, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 3 | {"r": [[0.00222503425], 0, 1.423985481262207, 1547737855.3445117], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 9130, "e": [["tile_c", "sp", [16, 4]], ["tile_h", "sp", [14, 8]], ["tile_w", "sp", [112, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 4 | {"r": [[0.0017814385], 0, 0.6939797401428223, 1547738053.2760224], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 8292, "e": [["tile_co", "sp", [4, 16]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], 
["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["none", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 5 | {"r": [[0.001141443], 0, 0.5493149757385254, 1547738328.2206848], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 5386, "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [7, 8]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 6 | {"r": [[0.0012462585], 0, 1.180959701538086, 1547738477.9994175], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 31299, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [2, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 7 | {"r": [[0.000699989], 0, 1.2548935413360596, 1547738647.2521586], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 3377, "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [7, 4]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 8 | {"r": [[0.001157523], 0, 0.510570764541626, 1547738817.8810546], "i": ["llvm 
-device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 6735, "e": [["tile_co", "sp", [32, 8]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 9 | {"r": [[0.00032969425], 0, 1.0283777713775635, 1547738958.5527909], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 2112, "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [14, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 10 | {"r": [[0.001088833], 0, 1.3669013977050781, 1547739167.7177014], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 942, "e": [["tile_co", "sp", [128, 4]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 11 | {"r": [[0.0002016545], 0, 1.0431005954742432, 1547739286.0030966], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], 
"float32"], ["TENSOR", [1024, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 587, "e": [["tile_c", "sp", [64, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 12 | {"r": [[0.00204971875], 0, 1.4387884140014648, 1547739409.8011386], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 729, "e": [["tile_co", "sp", [128, 8]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 13 | {"r": [[0.000125911], 0, 0.5861396789550781, 1547739528.3804939], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 455, "e": [["tile_c", "sp", [64, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 14 | {"r": [[0.0010378515], 0, 1.5314667224884033, 1547739659.6202457], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 3799, "e": [["tile_co", "sp", 
[64, 16]], ["tile_oh", "sp", [1, 7]], ["tile_ow", "sp", [7, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 15 | {"r": [[0.0001344795], 0, 1.138458490371704, 1547739772.782865], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 412, "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 16 | {"r": [[0.00178856], 0, 0.4340789318084717, 1547740023.9774854], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 16993, "e": [["tile_co", "sp", [64, 8]], ["tile_oh", "sp", [1, 14]], ["tile_ow", "sp", [14, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 17 | {"r": [[0.000210239], 0, 1.2972590923309326, 1547740177.1609933], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 1672, "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], 
"t": "direct", "c": null}], "v": 0.1} 18 | {"r": [[0.0009993475], 0, 0.8285174369812012, 1547740306.637762], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 17693, "e": [["tile_co", "sp", [64, 8]], ["tile_oh", "sp", [7, 2]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 19 | {"r": [[0.00019471025], 0, 1.4993979930877686, 1547740464.7630417], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 1489, "e": [["tile_c", "sp", [16, 16]], ["tile_h", "sp", [7, 2]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 20 | {"r": [[0.00211812525], 0, 1.3814301490783691, 1547740591.591791], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 28030, "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 21 | {"r": [[0.00051201125], 0, 1.5567152500152588, 1547740743.3866448], 
"i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 3323, "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "none", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 22 | {"r": [[0.001125614], 0, 1.1552152633666992, 1547740871.2665403], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 14692, "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [7, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 23 | {"r": [[0.00040230925], 0, 1.6127750873565674, 1547741005.5860286], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 3802, "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [14, 2]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 24 | {"r": [[0.00276929975], 0, 1.7560184001922607, 1547741243.1428356], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", 
[1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 54595, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [4, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 25 | {"r": [[0.00082235325], 0, 1.116337537765503, 1547741456.8146865], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 2690, "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [56, 1]], ["tile_w", "sp", [14, 4]], ["ann", "an", ["none", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 26 | {"r": [[0.00145455625], 0, 1.3054676055908203, 1547741597.89791], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 22787, "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 27 | {"r": [[0.00079617225], 0, 1.41221022605896, 1547741742.9475193], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 6001, 
"e": [["tile_c", "sp", [16, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [8, 7]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 28 | {"r": [[0.002438363], 0, 0.6403403282165527, 1547742139.42199], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 29122, "e": [["tile_co", "sp", [16, 4]], ["tile_oh", "sp", [112, 1]], ["tile_ow", "sp", [7, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 29 | {"r": [[0.00110020375], 0, 1.101663589477539, 1547742312.38935], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 7862, "e": [["tile_c", "sp", [8, 4]], ["tile_h", "sp", [112, 1]], ["tile_w", "sp", [56, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]], "t": "direct", "c": null}], "v": 0.1} 30 | {"r": [[0.00148117875], 0, 1.003594160079956, 1547742463.732416], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [32, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"i": 65048, "e": [["tile_co", "sp", [8, 4]], ["tile_oh", "sp", [56, 2]], ["tile_ow", "sp", [14, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", 
"vec"]]], "t": "direct", "c": null}], "v": 0.1} 31 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-cpu.mobilenet-nnconv5dw.trials=1000.stop=250.log: -------------------------------------------------------------------------------- 1 | {"v": 0.1, "r": [[0.00211120525], 0, 1.6490120887756348, 1549463275.3477194], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 7706, "t": "direct", "e": [["tile_co", "sp", [1, 1]], ["tile_oh", "sp", [56, 4]], ["tile_ow", "sp", [14, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "vec", "none"]]]}]} 2 | {"v": 0.1, "r": [[0.002641528], 0, 1.8992490768432617, 1549463522.9436731], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 62768, "t": "direct", "e": [["tile_co", "sp", [8, 4]], ["tile_oh", "sp", [56, 2]], ["tile_ow", "sp", [7, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 3 | {"v": 0.1, "r": [[0.002179639], 0, 0.5747640132904053, 1549463794.6300278], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"c": 
null, "i": 9193, "t": "direct", "e": [["tile_c", "sp", [16, 4]], ["tile_h", "sp", [16, 7]], ["tile_w", "sp", [56, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]]}]} 4 | {"v": 0.1, "r": [[0.00177205575], 0, 0.9056341648101807, 1549464008.0640798], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 8292, "t": "direct", "e": [["tile_co", "sp", [4, 16]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["none", "none", "vec"]]]}]} 5 | {"v": 0.1, "r": [[0.00111706975], 0, 1.5156757831573486, 1549464225.2041118], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"c": null, "i": 5218, "t": "direct", "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [7, 8]], ["tile_w", "sp", [28, 2]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 6 | {"v": 0.1, "r": [[0.00118750475], 0, 1.0649268627166748, 1549464359.02914], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 31827, "t": "direct", "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], 
["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 7 | {"v": 0.1, "r": [[0.00073739475], 0, 0.8508305549621582, 1549464569.377289], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"c": null, "i": 3377, "t": "direct", "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [7, 4]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 8 | {"v": 0.1, "r": [[0.00099514], 0, 0.7310960292816162, 1549464698.1474411], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 11892, "t": "direct", "e": [["tile_co", "sp", [32, 8]], ["tile_oh", "sp", [7, 2]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}]} 9 | {"v": 0.1, "r": [[0.0003016135], 0, 0.6347737312316895, 1549464859.3759325], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"c": null, "i": 952, "t": "direct", "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [1, 14]], ["ann", "an", ["none", "unroll", "vec"]]]}]} 10 | {"v": 0.1, "r": [[0.0011650785], 0, 1.9754915237426758, 1549465000.6372852], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", 
[["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 813, "t": "direct", "e": [["tile_co", "sp", [64, 8]], ["tile_oh", "sp", [1, 7]], ["tile_ow", "sp", [7, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "none", "vec"]]]}]} 11 | {"v": 0.1, "r": [[0.0002012395], 0, 0.5135416984558105, 1549465119.4674332], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"c": null, "i": 587, "t": "direct", "e": [["tile_c", "sp", [64, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]]}]} 12 | {"v": 0.1, "r": [[0.001770547], 0, 1.3441791534423828, 1549465333.001983], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 4733, "t": "direct", "e": [["tile_co", "sp", [128, 8]], ["tile_oh", "sp", [7, 1]], ["tile_ow", "sp", [1, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 9, 7, 8]], ["ann_reduce", "an", ["none", "unroll"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 13 | {"v": 0.1, "r": [[0.000129191], 0, 0.7576065063476562, 1549465462.6113133], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], 
"float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "i": 587, "t": "direct", "e": [["tile_c", "sp", [64, 16]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]]}]} 14 | {"v": 0.1, "r": [[0.001034831], 0, 1.3871076107025146, 1549465588.4565344], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 3711, "t": "direct", "e": [["tile_co", "sp", [64, 16]], ["tile_oh", "sp", [1, 7]], ["tile_ow", "sp", [7, 1]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}]} 15 | {"v": 0.1, "r": [[0.00013741475], 0, 0.7633135318756104, 1549465716.011089], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "i": 532, "t": "direct", "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 7]], ["tile_w", "sp", [7, 1]], ["ann", "an", ["unroll", "unroll", "vec"]]]}]} 16 | {"v": 0.1, "r": [[0.0018755755], 0, 1.5863986015319824, 1549465865.7965143], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 7163, "t": "direct", "e": [["tile_co", "sp", [64, 8]], ["tile_oh", "sp", [14, 1]], 
["tile_ow", "sp", [1, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["none", "unroll", "vec"]]]}]} 17 | {"v": 0.1, "r": [[0.00021360725], 0, 1.4937946796417236, 1549466005.5232635], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "i": 1672, "t": "direct", "e": [["tile_c", "sp", [128, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 18 | {"v": 0.1, "r": [[0.0008690005], 0, 0.654231071472168, 1549466252.5401561], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 3604, "t": "direct", "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [14, 1]], ["tile_ow", "sp", [2, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "unroll"]], ["ann_spatial", "an", ["none", "none", "vec"]]]}]} 19 | {"v": 0.1, "r": [[0.00021412725], 0, 1.4013373851776123, 1549466444.3549302], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "i": 1505, "t": "direct", "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [1, 14]], ["tile_w", "sp", [7, 2]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 20 | {"v": 0.1, "r": 
[[0.0020150435], 0, 0.5900120735168457, 1549466623.6702547], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 27382, "t": "direct", "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [4, 7]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}]} 21 | {"v": 0.1, "r": [[0.00048243825], 0, 0.46898937225341797, 1549466769.5930836], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "i": 3278, "t": "direct", "e": [["tile_c", "sp", [64, 4]], ["tile_h", "sp", [2, 14]], ["tile_w", "sp", [28, 1]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 22 | {"v": 0.1, "r": [[0.00114010975], 0, 1.0067620277404785, 1549466900.7609394], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 26032, "t": "direct", "e": [["tile_co", "sp", [16, 16]], ["tile_oh", "sp", [28, 1]], ["tile_ow", "sp", [7, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}]} 23 | {"v": 0.1, "r": [[0.00035789325], 0, 0.7935357093811035, 1549467113.2244818], "i": ["llvm -device=arm_cpu 
-target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "i": 3819, "t": "direct", "e": [["tile_c", "sp", [16, 8]], ["tile_h", "sp", [4, 7]], ["tile_w", "sp", [14, 2]], ["ann", "an", ["unroll", "unroll", "vec"]]]}]} 24 | {"v": 0.1, "r": [[0.00285133], 0, 1.7403016090393066, 1549467283.6627488], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 53571, "t": "direct", "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [4, 14]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 25 | {"v": 0.1, "r": [[0.0008334615], 0, 0.5077493190765381, 1549467438.3719728], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "i": 5386, "t": "direct", "e": [["tile_c", "sp", [32, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [7, 8]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 26 | {"v": 0.1, "r": [[0.00140673975], 0, 0.7620842456817627, 1549467661.36005], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 
64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 43267, "t": "direct", "e": [["tile_co", "sp", [16, 8]], ["tile_oh", "sp", [56, 1]], ["tile_ow", "sp", [7, 8]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "none", "vec"]]]}]} 27 | {"v": 0.1, "r": [[0.0008076695], 0, 1.599564790725708, 1549467809.7291658], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "i": 4601, "t": "direct", "e": [["tile_c", "sp", [16, 4]], ["tile_h", "sp", [28, 2]], ["tile_w", "sp", [14, 4]], ["ann", "an", ["unroll", "none", "vec"]]]}]} 28 | {"v": 0.1, "r": [[0.00245742775], 0, 0.8367326259613037, 1549467959.4589097], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "i": 73222, "t": "direct", "e": [["tile_co", "sp", [16, 4]], ["tile_oh", "sp", [112, 1]], ["tile_ow", "sp", [7, 16]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["none", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 29 | {"v": 0.1, "r": [[0.0010120925], 0, 0.7711925506591797, 1549468147.8392646], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "i": 1273, "t": "direct", "e": 
[["tile_c", "sp", [16, 2]], ["tile_h", "sp", [28, 4]], ["tile_w", "sp", [56, 2]], ["ann", "an", ["none", "none", "vec"]]]}]} 30 | {"v": 0.1, "r": [[0.001093943], 0, 1.7501792907714844, 1549468286.1650174], "i": ["llvm -device=arm_cpu -target=aarch64-linux-gnu", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [32, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"c": null, "i": 64940, "t": "direct", "e": [["tile_co", "sp", [8, 4]], ["tile_oh", "sp", [16, 7]], ["tile_ow", "sp", [28, 4]], ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], ["ann_reduce", "an", ["unroll", "none"]], ["ann_spatial", "an", ["unroll", "unroll", "vec"]]]}]} 31 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-gpu.mobilenet-nnconv5.trials=2000.stop=600.log: -------------------------------------------------------------------------------- 1 | {"v": 0.1, "r": [[0.000195583], 0, 2.7407631874084473, 1549520386.8809931], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [1, 1, 1, 1]], ["tile_y", "sp", [112, 1, 2, 1]], ["tile_x", "sp", [1, 2, 112, 1]], ["tile_rc", "sp", [16, 2]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "i": 364908, "t": "direct"}]} 2 | {"v": 0.1, "r": [[0.003038952], 0, 2.0225138664245605, 1549522084.273167], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"c": 
null, "e": [["tile_f", "sp", [1, 1, 8, 4]], ["tile_y", "sp", [7, 2, 4, 2]], ["tile_x", "sp", [14, 1, 4, 2]], ["tile_rc", "sp", [32, 2]], ["tile_ry", "sp", [1, 5]], ["tile_rx", "sp", [5, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "i": 70712365, "t": "direct"}]} 3 | {"v": 0.1, "r": [[0.00275695625], 0, 5.14209771156311, 1549523356.497182], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [4, 1, 16, 1]], ["tile_y", "sp", [2, 2, 2, 7]], ["tile_x", "sp", [7, 1, 4, 2]], ["tile_rc", "sp", [128, 1]], ["tile_ry", "sp", [1, 5]], ["tile_rx", "sp", [1, 5]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "i": 47582830, "t": "direct"}]} 4 | {"v": 0.1, "r": [[0.00238482725], 0, 7.642461776733398, 1549525645.12028], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [8, 1, 16, 1]], ["tile_y", "sp", [1, 1, 4, 7]], ["tile_x", "sp", [2, 1, 2, 7]], ["tile_rc", "sp", [64, 4]], ["tile_ry", "sp", [1, 5]], ["tile_rx", "sp", [1, 5]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 40290626, "t": "direct"}]} 5 | {"v": 0.1, "r": [[0.002864893], 0, 7.662211894989014, 1549526803.3654418], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [4, 1, 64, 1]], ["tile_y", "sp", [1, 1, 2, 7]], ["tile_x", 
"sp", [2, 1, 1, 7]], ["tile_rc", "sp", [256, 2]], ["tile_ry", "sp", [1, 5]], ["tile_rx", "sp", [1, 5]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 9791469, "t": "direct"}]} 6 | {"v": 0.1, "r": [[0.003657288], 0, 5.146381616592407, 1549528047.2862632], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 5, 5], "float32"], [1, 1], [2, 2], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 5, 5, "float32"], [1, 1], [2, 2], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [16, 2, 16, 1]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 1, 7]], ["tile_rc", "sp", [512, 2]], ["tile_ry", "sp", [1, 5]], ["tile_rx", "sp", [1, 5]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 897195, "t": "direct"}]} 7 | {"v": 0.1, "r": [[0.0005521495], 0, 2.424633264541626, 1549529378.1834185], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [16, 1, 64, 1]], ["tile_y", "sp", [1, 7, 1, 1]], ["tile_x", "sp", [1, 7, 1, 1]], ["tile_rc", "sp", [128, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "i": 216553, "t": "direct"}]} 8 | {"v": 0.1, "r": [[4.76715e-05], 0, 1.9822640419006348, 1549530309.6870823], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [128, 2, 1, 4]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], 
["unroll_explicit", "ot", 1]], "i": 25862, "t": "direct"}]} 9 | {"v": 0.1, "r": [[0.000293095], 0, 1.1059186458587646, 1549531195.1452584], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [32, 2, 8, 2]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "i": 154248, "t": "direct"}]} 10 | {"v": 0.1, "r": [[4.919975e-05], 0, 2.612053871154785, 1549531758.9769905], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [64, 1, 4, 2]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 19872, "t": "direct"}]} 11 | {"v": 0.1, "r": [[0.0003131495], 0, 2.075115203857422, 1549532492.7750611], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [8, 8, 8, 1]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "i": 1323990, "t": "direct"}]} 12 | {"v": 0.1, "r": [[5.81515e-05], 0, 3.1690168380737305, 1549533246.0323243], "i": ["cuda", 
"topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [128, 1, 2, 2]], ["tile_y", "sp", [1, 2, 7, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "i": 255044, "t": "direct"}]} 13 | {"v": 0.1, "r": [[0.00017847925], 0, 2.3081159591674805, 1549533918.8541245], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [8, 4, 8, 2]], ["tile_y", "sp", [2, 1, 1, 7]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "i": 706721, "t": "direct"}]} 14 | {"v": 0.1, "r": [[7.766325e-05], 0, 2.9981043338775635, 1549534785.984523], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [64, 1, 4, 1]], ["tile_y", "sp", [1, 1, 7, 2]], ["tile_x", "sp", [2, 1, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "i": 186632, "t": "direct"}]} 15 | {"v": 0.1, "r": [[0.00035476675], 0, 2.6585144996643066, 1549535835.6294103], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], 
"NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [4, 8, 8, 1]], ["tile_y", "sp", [7, 1, 1, 4]], ["tile_x", "sp", [1, 2, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "i": 3278082, "t": "direct"}]} 16 | {"v": 0.1, "r": [[8.427175e-05], 0, 1.690295934677124, 1549536635.6460495], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [1, 1, 4, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "i": 1173975, "t": "direct"}]} 17 | {"v": 0.1, "r": [[0.00021443025], 0, 4.84061861038208, 1549537368.9350924], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [4, 1, 8, 8]], ["tile_y", "sp", [7, 4, 1, 1]], ["tile_x", "sp", [1, 2, 14, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "i": 5122054, "t": "direct"}]} 18 | {"v": 0.1, "r": [[0.0001168395], 0, 1.664238691329956, 1549538332.091168], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [1, 1, 14, 2]], ["tile_x", "sp", [1, 1, 28, 1]], 
["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 1044720, "t": "direct"}]} 19 | {"v": 0.1, "r": [[0.00031253325], 0, 2.2684757709503174, 1549539680.4623065], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [2, 1, 8, 8]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "i": 27101857, "t": "direct"}]} 20 | {"v": 0.1, "r": [[0.0001404955], 0, 1.7230756282806396, 1549540462.374373], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [1, 1, 8, 7]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "i": 4126320, "t": "direct"}]} 21 | {"v": 0.1, "r": [[0.00020255025], 0, 1.367140293121338, 1549541573.3382344], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [2, 1, 8, 8]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [8, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "i": 18653857, "t": "direct"}]} 22 | {"v": 0.1, "r": [[0.00021456725], 0, 
1.408292531967163, 1549542340.7373421], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [64, 1, 1, 1]], ["tile_y", "sp", [7, 2, 4, 1]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "i": 2346540, "t": "direct"}]} 23 | {"v": 0.1, "r": [[0.000223671], 0, 3.9969608783721924, 1549543261.223412], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [2, 2, 2, 8]], ["tile_y", "sp", [56, 1, 2, 1]], ["tile_x", "sp", [1, 7, 16, 1]], ["tile_rc", "sp", [8, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "i": 3740589, "t": "direct"}]} 24 | {"v": 0.1, "r": [[0.00013695075], 0, 2.816110610961914, 1549544188.9147506], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"c": null, "e": [["tile_f", "sp", [32, 1, 1, 1]], ["tile_y", "sp", [4, 1, 4, 7]], ["tile_x", "sp", [1, 1, 112, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "i": 4741072, "t": "direct"}]} 25 | {"v": 0.1, "r": [[0.00019237425], 0, 2.6703319549560547, 1549545517.7321806], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], 
[32, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"c": null, "e": [["tile_f", "sp", [1, 1, 8, 4]], ["tile_y", "sp", [56, 1, 1, 2]], ["tile_x", "sp", [2, 1, 56, 1]], ["tile_rc", "sp", [1, 3]], ["tile_ry", "sp", [1, 3]], ["tile_rx", "sp", [1, 3]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "i": 43138245, "t": "direct"}]} 26 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-gpu.mobilenet-nnconv5dw-skipadd-pruned.trials=2000.stop=600.log: -------------------------------------------------------------------------------- 1 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 16, 224, 224], "float32"], ["TENSOR", [1, 16, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 16, 224, 224, "float32"], [1, 16, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 314944, "e": [["tile_f", "sp", [1, 1, 1, 1]], ["tile_y", "sp", [224, 1, 1, 1]], ["tile_x", "sp", [1, 1, 224, 1]], ["tile_rc", "sp", [8, 2]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.0001044715], 0, 1.65395188331604, 1548646492.0790715], "v": 0.1} 2 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [16, 56, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 56, 112, 112, "float32"], [16, 56, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 13920913, "e": [["tile_f", "sp", [1, 2, 8, 1]], ["tile_y", "sp", [112, 1, 1, 1]], ["tile_x", "sp", [1, 4, 28, 1]], ["tile_rc", "sp", [7, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.000170911], 0, 2.5805726051330566, 1548647107.7181647], "v": 0.1} 3 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [56, 1, 
5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 56, 112, 112, "float32"], [56, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 6693600, "e": [["tile_f", "sp", [56, 1, 1, 1]], ["tile_y", "sp", [2, 1, 14, 4]], ["tile_x", "sp", [7, 1, 16, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.0002963425], 0, 2.6830427646636963, 1548648186.9543624], "v": 0.1} 4 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 120, 56, 56], "float32"], ["TENSOR", [56, 120, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 120, 56, 56, "float32"], [56, 120, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 9907345, "e": [["tile_f", "sp", [1, 2, 4, 7]], ["tile_y", "sp", [28, 2, 1, 1]], ["tile_x", "sp", [1, 2, 28, 1]], ["tile_rc", "sp", [30, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.00020867025], 0, 1.3627350330352783, 1548649315.0133436], "v": 0.1} 5 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 120, 56, 56], "float32"], ["TENSOR", [120, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 120, 56, 56, "float32"], [120, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 11456640, "e": [["tile_f", "sp", [120, 1, 1, 1]], ["tile_y", "sp", [4, 1, 7, 2]], ["tile_x", "sp", [1, 1, 28, 2]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.0001770705], 0, 2.9571971893310547, 1548650130.5959702], "v": 0.1} 6 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [120, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [120, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 15587259, "e": [["tile_f", "sp", [5, 3, 8, 1]], ["tile_y", "sp", [4, 
1, 1, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.000176463], 0, 1.957888126373291, 1548652038.6319573], "v": 0.1} 7 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 1495890, "e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [1, 1, 14, 2]], ["tile_x", "sp", [1, 1, 14, 2]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.000116847], 0, 2.5286483764648438, 1548652840.9417276], "v": 0.1} 8 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 200, 14, 14], "float32"], ["TENSOR", [256, 200, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 200, 14, 14, "float32"], [256, 200, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 697884, "e": [["tile_f", "sp", [8, 1, 8, 4]], ["tile_y", "sp", [1, 7, 2, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [25, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[7.766325e-05], 0, 1.3543050289154053, 1548653774.0519202], "v": 0.1} 9 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 200, 14, 14], "float32"], ["TENSOR", [200, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 200, 14, 14, "float32"], [200, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 273621, "e": [["tile_f", "sp", [50, 1, 4, 1]], ["tile_y", "sp", [1, 1, 14, 1]], ["tile_x", "sp", [1, 7, 2, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": 
[[3.818375e-05], 0, 1.5769789218902588, 1548654909.867864], "v": 0.1} 10 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [200, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [200, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 43436, "e": [["tile_f", "sp", [5, 5, 8, 1]], ["tile_y", "sp", [1, 7, 1, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[9.72875e-05], 0, 2.179405450820923, 1548655486.6924798], "v": 0.1} 11 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 7, 7, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"i": 19900, "e": [["tile_f", "sp", [128, 1, 1, 4]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[3.839175e-05], 0, 2.85693097114563, 1548656199.745016], "v": 0.1} 12 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 480, 7, 7], "float32"], ["TENSOR", [512, 480, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 480, 7, 7, "float32"], [512, 480, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 287630, "e": [["tile_f", "sp", [8, 1, 32, 2]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [32, 15]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00013743925], 0, 1.9187650680541992, 1548657065.7516673], "v": 0.1} 13 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 480, 7, 7], "float32"], ["TENSOR", 
[480, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 480, 7, 7, "float32"], [480, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 23621, "e": [["tile_f", "sp", [80, 2, 1, 3]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[3.796775e-05], 0, 2.5406243801116943, 1548658554.4247355], "v": 0.1} 14 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 328, 7, 7], "float32"], ["TENSOR", [480, 328, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 328, 7, 7, "float32"], [480, 328, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 51212, "e": [["tile_f", "sp", [8, 3, 20, 1]], ["tile_y", "sp", [1, 7, 1, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [41, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.000105911], 0, 1.2114758491516113, 1548659160.2813828], "v": 0.1} 15 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 328, 14, 14], "float32"], ["TENSOR", [328, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 328, 14, 14, "float32"], [328, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 7218, "e": [["tile_f", "sp", [41, 1, 8, 1]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[3.831175e-05], 0, 1.9223744869232178, 1548659793.643159], "v": 0.1} 16 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 296, 14, 14], "float32"], ["TENSOR", [328, 296, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 296, 14, 14, "float32"], [328, 296, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 69322, "e": [["tile_f", "sp", [4, 1, 41, 2]], ["tile_y", "sp", [2, 7, 1, 1]], 
["tile_x", "sp", [2, 1, 7, 1]], ["tile_rc", "sp", [37, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.000235302], 0, 2.302255630493164, 1548660675.0394955], "v": 0.1} 17 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 296, 14, 14], "float32"], ["TENSOR", [296, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 296, 14, 14, "float32"], [296, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 112248, "e": [["tile_f", "sp", [148, 1, 2, 1]], ["tile_y", "sp", [1, 1, 7, 2]], ["tile_x", "sp", [1, 2, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[3.78955e-05], 0, 1.5320491790771484, 1548661622.8226697], "v": 0.1} 18 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 288, 14, 14], "float32"], ["TENSOR", [296, 288, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 288, 14, 14, "float32"], [296, 288, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 130741, "e": [["tile_f", "sp", [4, 2, 37, 1]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [2, 1, 7, 1]], ["tile_rc", "sp", [32, 9]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.00021435875], 0, 1.6423635482788086, 1548662485.159752], "v": 0.1} 19 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 288, 14, 14], "float32"], ["TENSOR", [288, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 288, 14, 14, "float32"], [288, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 676499, "e": [["tile_f", "sp", [72, 2, 2, 1]], ["tile_y", "sp", [1, 1, 14, 1]], ["tile_x", "sp", [1, 1, 7, 2]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[3.786375e-05], 0, 
1.390625238418579, 1548663463.9980013], "v": 0.1} 20 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 272, 14, 14], "float32"], ["TENSOR", [288, 272, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 272, 14, 14, "float32"], [288, 272, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 4805431, "e": [["tile_f", "sp", [6, 6, 8, 1]], ["tile_y", "sp", [1, 7, 2, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [34, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00011731875], 0, 2.9426913261413574, 1548664977.3522089], "v": 0.1} 21 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 272, 14, 14], "float32"], ["TENSOR", [272, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 272, 14, 14, "float32"], [272, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 160044, "e": [["tile_f", "sp", [34, 1, 8, 1]], ["tile_y", "sp", [1, 2, 7, 1]], ["tile_x", "sp", [1, 2, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[3.8048e-05], 0, 1.3805427551269531, 1548666124.2337215], "v": 0.1} 22 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 376, 14, 14], "float32"], ["TENSOR", [272, 376, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 376, 14, 14, "float32"], [272, 376, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 986321, "e": [["tile_f", "sp", [4, 17, 4, 1]], ["tile_y", "sp", [1, 7, 2, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [47, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00017648625], 0, 2.901970386505127, 1548666802.045066], "v": 0.1} 23 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 376, 14, 14], "float32"], ["TENSOR", [376, 1, 
3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 376, 14, 14, "float32"], [376, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 31608, "e": [["tile_f", "sp", [188, 1, 2, 1]], ["tile_y", "sp", [1, 1, 7, 2]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[5.722375e-05], 0, 1.8894000053405762, 1548667848.8311331], "v": 0.1} 24 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 408, 14, 14], "float32"], ["TENSOR", [376, 408, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 408, 14, 14, "float32"], [376, 408, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 1399221, "e": [["tile_f", "sp", [4, 2, 47, 1]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 7, 2, 1]], ["tile_rc", "sp", [68, 6]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.000361302], 0, 1.5910117626190186, 1548668723.9305615], "v": 0.1} 25 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 408, 14, 14], "float32"], ["TENSOR", [408, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 408, 14, 14, "float32"], [408, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 371374, "e": [["tile_f", "sp", [102, 1, 1, 4]], ["tile_y", "sp", [1, 1, 14, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[5.76475e-05], 0, 2.138455390930176, 1548669478.0002713], "v": 0.1} 26 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [408, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [408, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 3237695, "e": [["tile_f", "sp", [6, 1, 4, 17]], ["tile_y", "sp", [1, 7, 
2, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00014766275], 0, 1.5402188301086426, 1548670679.6535883], "v": 0.1} 27 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 233311, "e": [["tile_f", "sp", [128, 2, 1, 1]], ["tile_y", "sp", [2, 1, 7, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[7.772775e-05], 0, 2.0413858890533447, 1548671505.4201894], "v": 0.1} 28 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 144, 28, 28], "float32"], ["TENSOR", [256, 144, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 144, 28, 28, "float32"], [256, 144, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 1432721, "e": [["tile_f", "sp", [8, 4, 8, 1]], ["tile_y", "sp", [4, 7, 1, 1]], ["tile_x", "sp", [1, 1, 28, 1]], ["tile_rc", "sp", [18, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.000216294], 0, 1.6201062202453613, 1548672455.3859437], "v": 0.1} 29 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 144, 28, 28], "float32"], ["TENSOR", [144, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 144, 28, 28, "float32"], [144, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 3046490, "e": [["tile_f", "sp", [72, 1, 1, 2]], ["tile_y", "sp", [2, 1, 7, 2]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": 
[[5.779175e-05], 0, 2.2191083431243896, 1548673391.4412506], "v": 0.1} 30 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 120, 28, 28], "float32"], ["TENSOR", [144, 120, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 120, 28, 28, "float32"], [144, 120, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 3047741, "e": [["tile_f", "sp", [2, 1, 8, 9]], ["tile_y", "sp", [7, 1, 1, 4]], ["tile_x", "sp", [1, 1, 28, 1]], ["tile_rc", "sp", [20, 6]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null}], "r": [[0.0001163435], 0, 1.2421369552612305, 1548675456.083652], "v": 0.1} 31 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 120, 56, 56], "float32"], ["TENSOR", [120, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 120, 56, 56, "float32"], [120, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 2269760, "e": [["tile_f", "sp", [120, 1, 1, 1]], ["tile_y", "sp", [2, 2, 7, 1]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00011662325], 0, 1.597099781036377, 1548676329.4550035], "v": 0.1} 32 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 88, 56, 56], "float32"], ["TENSOR", [120, 88, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 88, 56, 56, "float32"], [120, 88, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 53975012, "e": [["tile_f", "sp", [2, 1, 4, 15]], ["tile_y", "sp", [14, 2, 1, 2]], ["tile_x", "sp", [1, 2, 28, 1]], ["tile_rc", "sp", [22, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00027512525], 0, 1.369657278060913, 1548677300.418834], "v": 0.1} 33 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 88, 56, 56], 
"float32"], ["TENSOR", [88, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 88, 56, 56, "float32"], [88, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 2750720, "e": [["tile_f", "sp", [88, 1, 1, 1]], ["tile_y", "sp", [2, 1, 4, 7]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[9.8191e-05], 0, 1.5733342170715332, 1548678185.3843915], "v": 0.1} 34 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 56, 56, 56], "float32"], ["TENSOR", [88, 56, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 56, 56, 56, "float32"], [88, 56, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 13501456, "e": [["tile_f", "sp", [2, 11, 4, 1]], ["tile_y", "sp", [14, 1, 1, 4]], ["tile_x", "sp", [1, 1, 56, 1]], ["tile_rc", "sp", [14, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.0001444785], 0, 1.7148137092590332, 1548679401.0876276], "v": 0.1} 35 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 56, 112, 112], "float32"], ["TENSOR", [56, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 56, 112, 112, "float32"], [56, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"i": 2748800, "e": [["tile_f", "sp", [56, 1, 1, 1]], ["tile_y", "sp", [7, 1, 4, 2]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.0001829665], 0, 1.6487436294555664, 1548680224.10838], "v": 0.1} 36 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 16, 112, 112], "float32"], ["TENSOR", [56, 16, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 16, 112, 112, "float32"], [56, 16, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"i": 28698016, "e": [["tile_f", "sp", [2, 7, 4, 1]], 
["tile_y", "sp", [56, 1, 1, 2]], ["tile_x", "sp", [2, 1, 56, 1]], ["tile_rc", "sp", [2, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.000139111], 0, 1.4088413715362549, 1548681204.6446605], "v": 0.1} 37 | {"i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 16, 112, 112], "float32"], ["TENSOR", [16, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 16, 112, 112, "float32"], [16, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"i": 3638775, "e": [["tile_f", "sp", [16, 1, 1, 1]], ["tile_y", "sp", [7, 1, 4, 4]], ["tile_x", "sp", [2, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[7.75195e-05], 0, 1.5174787044525146, 1548682004.0559516], "v": 0.1} 38 | {"i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [16, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [16, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"i": 20787688, "e": [["tile_f", "sp", [1, 1, 2, 8]], ["tile_y", "sp", [28, 1, 2, 2]], ["tile_x", "sp", [2, 1, 56, 1]], ["tile_rc", "sp", [3, 1]], ["tile_ry", "sp", [1, 3]], ["tile_rx", "sp", [1, 3]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null}], "r": [[0.00013608675], 0, 3.140852451324463, 1548683091.1406922], "v": 0.1} 39 | -------------------------------------------------------------------------------- /tvm_compile/tuning/tx2-gpu.mobilenet-nnconv5dw-skipadd.trials=2000.stop=600.log: -------------------------------------------------------------------------------- 1 | {"v": 0.1, "r": [[0.00019598275], 0, 1.5730626583099365, 1548562540.6536844], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, 
["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [1, 1, 1, 1]], ["tile_y", "sp", [28, 2, 4, 1]], ["tile_x", "sp", [1, 1, 224, 1]], ["tile_rc", "sp", [32, 1]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 13911}]} 2 | {"v": 0.1, "r": [[0.0002421905], 0, 3.849091053009033, 1548564960.4150162], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [1, 1, 8, 4]], ["tile_y", "sp", [56, 1, 1, 2]], ["tile_x", "sp", [1, 7, 16, 1]], ["tile_rc", "sp", [16, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 2495685}]} 3 | {"v": 0.1, "r": [[0.00033507775], 0, 1.67569899559021, 1548565783.6229427], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [64, 1, 1, 1]], ["tile_y", "sp", [2, 1, 14, 4]], ["tile_x", "sp", [7, 1, 16, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 8674680}]} 4 | {"v": 0.1, "r": [[0.00019454275], 0, 1.5050935745239258, 1548566644.4516716], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 4, 8, 1]], ["tile_y", "sp", [28, 
1, 1, 2]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [32, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 5533100}]} 5 | {"v": 0.1, "r": [[0.00019157425], 0, 3.138878345489502, 1548567436.7737145], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [2, 2, 7, 2]], ["tile_x", "sp", [1, 1, 28, 2]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 4296360}]} 6 | {"v": 0.1, "r": [[0.0001684785], 0, 1.6706435680389404, 1548568288.6996257], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 1, 8, 8]], ["tile_y", "sp", [7, 1, 4, 1]], ["tile_x", "sp", [1, 7, 4, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 7542097}]} 7 | {"v": 0.1, "r": [[0.000109935], 0, 1.4281127452850342, 1548568975.349207], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [2, 1, 2, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 
1437645}]} 8 | {"v": 0.1, "r": [[0.000167879], 0, 1.5164942741394043, 1548570086.3063977], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [4, 8, 8, 1]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 570597}]} 9 | {"v": 0.1, "r": [[7.83035e-05], 0, 3.1352176666259766, 1548570831.7553627], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [512, 1, 1, 1]], ["tile_y", "sp", [1, 1, 7, 2]], ["tile_x", "sp", [1, 2, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 252340}]} 10 | {"v": 0.1, "r": [[0.0002710375], 0, 2.723203420639038, 1548571856.4664495], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 4, 16, 1]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 7, 1, 1]], ["tile_rc", "sp", [64, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 209036}]} 11 | {"v": 0.1, "r": [[7.43915e-05], 0, 1.9873511791229248, 1548572384.122891], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 
7], "float32"], ["TENSOR", [1024, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [256, 1, 1, 4]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 21285}]} 12 | {"v": 0.1, "r": [[0.00056902825], 0, 3.107520818710327, 1548573097.8638668], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [16, 2, 32, 1]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 21496}]} 13 | {"v": 0.1, "r": [[4.303975e-05], 0, 1.3151054382324219, 1548573702.4212985], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 8, 1]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 26056}]} 14 | {"v": 0.1, "r": [[0.000293478], 0, 1.6003823280334473, 1548574300.4192855], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [32, 2, 8, 2]], 
["tile_y", "sp", [1, 7, 1, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 153676}]} 15 | {"v": 0.1, "r": [[5.61915e-05], 0, 1.6853110790252686, 1548574886.3446763], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [32, 1, 8, 2]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 19879}]} 16 | {"v": 0.1, "r": [[0.000312902], 0, 2.8101165294647217, 1548575397.150415], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 4, 8, 2]], ["tile_y", "sp", [2, 1, 1, 7]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 1326241}]} 17 | {"v": 0.1, "r": [[5.80155e-05], 0, 2.5316574573516846, 1548576505.3454807], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 4, 1]], ["tile_y", "sp", [1, 7, 2, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": 
null, "i": 29279}]} 18 | {"v": 0.1, "r": [[0.00017500625], 0, 3.4283621311187744, 1548577207.0975342], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 4, 8, 2]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 704521}]} 19 | {"v": 0.1, "r": [[7.74235e-05], 0, 1.7664618492126465, 1548578059.412158], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [32, 1, 4, 2]], ["tile_y", "sp", [2, 1, 7, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 233370}]} 20 | {"v": 0.1, "r": [[0.000316189], 0, 1.5466210842132568, 1548578728.0832877], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [4, 1, 8, 8]], ["tile_y", "sp", [7, 1, 4, 1]], ["tile_x", "sp", [1, 7, 4, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 12746374}]} 21 | {"v": 0.1, "r": [[7.8927e-05], 0, 1.5628066062927246, 1548579523.0229578], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", 
[1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [1, 1, 4, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 1173975}]} 22 | {"v": 0.1, "r": [[0.00019647875], 0, 2.4384193420410156, 1548580509.127791], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [4, 1, 8, 8]], ["tile_y", "sp", [7, 1, 4, 1]], ["tile_x", "sp", [1, 7, 4, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 9314374}]} 23 | {"v": 0.1, "r": [[0.00011714325], 0, 1.3188731670379639, 1548581309.9010916], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [2, 1, 14, 1]], ["tile_x", "sp", [1, 2, 14, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 1038600}]} 24 | {"v": 0.1, "r": [[0.00031403], 0, 1.9832639694213867, 1548581896.1469371], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", 
[2, 2, 8, 4]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 27101840}]} 25 | {"v": 0.1, "r": [[0.000137879], 0, 1.4539649486541748, 1548582835.3741553], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [2, 1, 4, 7]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 4126080}]} 26 | {"v": 0.1, "r": [[0.00020232675], 0, 2.3310959339141846, 1548584787.231282], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 1, 8, 8]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [8, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 2525857}]} 27 | {"v": 0.1, "r": [[0.000208358], 0, 2.127153158187866, 1548585621.7478325], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [64, 1, 1, 1]], ["tile_y", "sp", [7, 2, 4, 1]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": 
"direct", "c": null, "i": 2346540}]} 28 | {"v": 0.1, "r": [[0.00021625475], 0, 2.464158296585083, 1548587361.9766982], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 2, 8, 2]], ["tile_y", "sp", [28, 1, 1, 4]], ["tile_x", "sp", [7, 1, 16, 1]], ["tile_rc", "sp", [8, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 3734264}]} 29 | {"v": 0.1, "r": [[0.00013695125], 0, 2.6825642585754395, 1548588199.2572803], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [32, 1, 1, 1]], ["tile_y", "sp", [4, 1, 4, 7]], ["tile_x", "sp", [2, 1, 56, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 4725392}]} 30 | {"v": 0.1, "r": [[0.000175575], 0, 1.2250316143035889, 1548589605.8968496], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [32, 3, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"e": [["tile_f", "sp", [1, 2, 8, 2]], ["tile_y", "sp", [28, 4, 1, 1]], ["tile_x", "sp", [7, 1, 16, 1]], ["tile_rc", "sp", [3, 1]], ["tile_ry", "sp", [1, 3]], ["tile_rx", "sp", [1, 3]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 33218226}]} 31 | -------------------------------------------------------------------------------- 
/tvm_compile/tuning/tx2-gpu.mobilenet-nnconv5dw.trials=2000.stop=600.log: -------------------------------------------------------------------------------- 1 | {"v": 0.1, "r": [[0.00019565525], 0, 2.689757823944092, 1549429087.5617292], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 224, 224], "float32"], ["TENSOR", [1, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 224, 224, "float32"], [1, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [1, 1, 1, 1]], ["tile_y", "sp", [56, 4, 1, 1]], ["tile_x", "sp", [1, 1, 224, 1]], ["tile_rc", "sp", [32, 1]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 314946}]} 2 | {"v": 0.1, "r": [[0.0003639425], 0, 4.725115060806274, 1549430548.7858527], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [32, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 112, 112, "float32"], [32, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 4, 1, 4]], ["tile_y", "sp", [112, 1, 1, 1]], ["tile_x", "sp", [1, 1, 112, 1]], ["tile_rc", "sp", [32, 2]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 16808998}]} 3 | {"v": 0.1, "r": [[0.00033395075], 0, 2.3271985054016113, 1549431980.0935845], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [64, 1, 1, 1]], ["tile_y", "sp", [2, 1, 14, 4]], ["tile_x", "sp", [7, 1, 16, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 8674680}]} 4 
| {"v": 0.1, "r": [[0.0002864375], 0, 2.3175532817840576, 1549433087.935274], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [64, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [64, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 32, 1, 1]], ["tile_y", "sp", [56, 1, 1, 1]], ["tile_x", "sp", [1, 2, 28, 1]], ["tile_rc", "sp", [32, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 9864965}]} 5 | {"v": 0.1, "r": [[0.00019396625], 0, 1.462557315826416, 1549433838.8822808], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [2, 2, 7, 2]], ["tile_x", "sp", [1, 1, 28, 2]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 3528360}]} 6 | {"v": 0.1, "r": [[0.000215031], 0, 1.3755156993865967, 1549435352.1497386], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [128, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [128, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 4, 16, 1]], ["tile_y", "sp", [14, 1, 1, 2]], ["tile_x", "sp", [1, 2, 14, 1]], ["tile_rc", "sp", [16, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 6030988}]} 7 | {"v": 0.1, "r": [[0.0001163515], 0, 3.0618081092834473, 1549436102.657577], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], 
"float32"], ["TENSOR", [256, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [2, 1, 2, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 1173645}]} 8 | {"v": 0.1, "r": [[0.00017022325], 0, 3.1339035034179688, 1549437395.9139726], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [256, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [256, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [4, 4, 16, 1]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [32, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 612842}]} 9 | {"v": 0.1, "r": [[7.806325e-05], 0, 1.9228358268737793, 1549438115.9921758], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [512, 1, 1, 1]], ["tile_y", "sp", [1, 2, 7, 1]], ["tile_x", "sp", [1, 1, 7, 2]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 265540}]} 10 | {"v": 0.1, "r": [[0.00027059775], 0, 1.8419475555419922, 1549439012.3243341], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [512, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [512, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 4, 16, 1]], 
["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 210136}]} 11 | {"v": 0.1, "r": [[7.395175e-05], 0, 1.3589377403259277, 1549439569.3904214], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 5, 5], "float32"], [1, 1], [2, 2], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 5, 5, "float32"], [1, 1], [2, 2], "float32"], {"e": [["tile_f", "sp", [128, 1, 4, 2]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 256], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 21249}]} 12 | {"v": 0.1, "r": [[0.000567654], 0, 2.060800075531006, 1549440270.1668124], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1024, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 1024, 7, 7, "float32"], [1024, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [16, 2, 32, 1]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 16]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 172504}]} 13 | {"v": 0.1, "r": [[4.21195e-05], 0, 1.4836952686309814, 1549440881.346037], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 1024, 7, 7], "float32"], ["TENSOR", [1024, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 1024, 7, 7, "float32"], [1024, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [256, 1, 4, 1]], ["tile_y", "sp", [1, 7, 1, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", 
"c": null, "i": 25475}]} 14 | {"v": 0.1, "r": [[0.00029675725], 0, 2.273451805114746, 1549441441.2646382], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 7, 7, "float32"], [1024, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [32, 2, 8, 2]], ["tile_y", "sp", [1, 1, 1, 7]], ["tile_x", "sp", [1, 1, 7, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 154248}]} 15 | {"v": 0.1, "r": [[5.194375e-05], 0, 1.7700073719024658, 1549442009.8341005], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [32, 1, 8, 2]], ["tile_y", "sp", [1, 1, 7, 1]], ["tile_x", "sp", [1, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 19879}]} 16 | {"v": 0.1, "r": [[0.00031262125], 0, 1.2954254150390625, 1549442532.348055], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 512, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 512, 14, 14, "float32"], [512, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 4, 8, 2]], ["tile_y", "sp", [2, 7, 1, 1]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [64, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 0]], "t": "direct", "c": null, "i": 1324041}]} 17 | {"v": 0.1, "r": [[5.823175e-05], 0, 1.1607718467712402, 1549443333.8850622], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", 
[["TENSOR", [1, 512, 14, 14], "float32"], ["TENSOR", [512, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 512, 14, 14, "float32"], [512, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [512, 1, 1, 1]], ["tile_y", "sp", [1, 2, 7, 1]], ["tile_x", "sp", [1, 1, 7, 2]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 321860}]} 18 | {"v": 0.1, "r": [[0.0001771745], 0, 2.1588521003723145, 1549443916.2324407], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [8, 8, 8, 1]], ["tile_y", "sp", [2, 1, 1, 7]], ["tile_x", "sp", [1, 1, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 512], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 2227310}]} 19 | {"v": 0.1, "r": [[7.740725e-05], 0, 2.187483310699463, 1549444777.0748699], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [64, 1, 4, 1]], ["tile_y", "sp", [1, 1, 7, 2]], ["tile_x", "sp", [2, 1, 7, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 228872}]} 20 | {"v": 0.1, "r": [[0.00036090275], 0, 1.9614224433898926, 1549445574.4298754], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 256, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 256, 28, 28, "float32"], [256, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": 
[["tile_f", "sp", [4, 8, 8, 1]], ["tile_y", "sp", [7, 1, 1, 4]], ["tile_x", "sp", [1, 2, 14, 1]], ["tile_rc", "sp", [32, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 12782082}]} 21 | {"v": 0.1, "r": [[7.871125e-05], 0, 1.6131527423858643, 1549446317.7387452], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 256, 28, 28], "float32"], ["TENSOR", [256, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 256, 28, 28, "float32"], [256, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [256, 1, 1, 1]], ["tile_y", "sp", [1, 1, 4, 7]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 1437975}]} 22 | {"v": 0.1, "r": [[0.00019607125], 0, 1.8411102294921875, 1549446989.5312948], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [4, 4, 8, 2]], ["tile_y", "sp", [7, 1, 1, 4]], ["tile_x", "sp", [1, 1, 28, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 7244723}]} 23 | {"v": 0.1, "r": [[0.000116695], 0, 1.8163161277770996, 1549447789.34965], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [2, 1, 7, 2]], ["tile_x", "sp", [1, 1, 28, 1]], ["auto_unroll_max_step", "ot", 256], 
["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 852480}]} 24 | {"v": 0.1, "r": [[0.00031092525], 0, 3.726764440536499, 1549448612.6387038], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 128, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 128, 56, 56, "float32"], [128, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 4, 8, 2]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [16, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 33245816}]} 25 | {"v": 0.1, "r": [[0.0001369275], 0, 2.400529623031616, 1549449447.2495077], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 128, 56, 56], "float32"], ["TENSOR", [128, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 128, 56, 56, "float32"], [128, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [128, 1, 1, 1]], ["tile_y", "sp", [1, 1, 8, 7]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 4126320}]} 26 | {"v": 0.1, "r": [[0.00020027125], 0, 2.450047016143799, 1549450572.8107345], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 2, 8, 4]], ["tile_y", "sp", [28, 1, 2, 1]], ["tile_x", "sp", [1, 7, 8, 1]], ["tile_rc", "sp", [8, 8]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 18653840}]} 27 | {"v": 0.1, "r": [[0.000202495], 0, 1.4411323070526123, 1549451405.1190436], "i": 
["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 64, 112, 112], "float32"], ["TENSOR", [64, 1, 3, 3], "float32"], [2, 2], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 64, 112, 112, "float32"], [64, 1, 3, 3, "float32"], [2, 2], [1, 1], "float32"], {"e": [["tile_f", "sp", [64, 1, 1, 1]], ["tile_y", "sp", [7, 2, 4, 1]], ["tile_x", "sp", [1, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 2884140}]} 28 | {"v": 0.1, "r": [[0.000216318], 0, 1.7042787075042725, 1549452490.2196207], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [64, 32, 1, 1], "float32"], [1, 1], [0, 0], "NCHW", "float32"], {}, ["conv2d", [1, 32, 112, 112, "float32"], [64, 32, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "float32"], {"e": [["tile_f", "sp", [2, 2, 8, 2]], ["tile_y", "sp", [28, 2, 1, 2]], ["tile_x", "sp", [7, 1, 16, 1]], ["tile_rc", "sp", [8, 4]], ["tile_ry", "sp", [1, 1]], ["tile_rx", "sp", [1, 1]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 33367028}]} 29 | {"v": 0.1, "r": [[0.0001368875], 0, 1.4162554740905762, 1549453299.762188], "i": ["cuda", "topi_nn_depthwise_conv2d_nchw", [["TENSOR", [1, 32, 112, 112], "float32"], ["TENSOR", [32, 1, 3, 3], "float32"], [1, 1], [1, 1], "float32"], {}, ["depthwise_conv2d_nchw", [1, 32, 112, 112, "float32"], [32, 1, 3, 3, "float32"], [1, 1], [1, 1], "float32"], {"e": [["tile_f", "sp", [32, 1, 1, 1]], ["tile_y", "sp", [4, 1, 4, 7]], ["tile_x", "sp", [2, 1, 56, 1]], ["auto_unroll_max_step", "ot", 1500], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 5822992}]} 30 | {"v": 0.1, "r": [[0.0001747745], 0, 1.469916820526123, 1549454076.4118538], "i": ["cuda", "topi_nn_conv2d", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [32, 3, 3, 3], "float32"], [2, 2], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 224, 224, "float32"], [32, 3, 3, 3, "float32"], [2, 2], [1, 1], 
"NCHW", "float32"], {"e": [["tile_f", "sp", [1, 2, 8, 2]], ["tile_y", "sp", [56, 2, 1, 1]], ["tile_x", "sp", [7, 1, 16, 1]], ["tile_rc", "sp", [3, 1]], ["tile_ry", "sp", [1, 3]], ["tile_rx", "sp", [1, 3]], ["auto_unroll_max_step", "ot", 0], ["unroll_explicit", "ot", 1]], "t": "direct", "c": null, "i": 33218170}]} 31 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import math 8 | 9 | cmap = plt.cm.viridis 10 | 11 | 12 | def parse_command(): 13 | data_names = ['nyudepthv2'] 14 | 15 | from dataloaders.dataloader import MyDataloader 16 | modality_names = MyDataloader.modality_names 17 | 18 | import argparse 19 | parser = argparse.ArgumentParser(description='FastDepth') 20 | parser.add_argument('--data', metavar='DATA', default='nyudepthv2', 21 | choices=data_names, 22 | help='dataset: ' + ' | '.join(data_names) + ' (default: nyudepthv2)') 23 | parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb', choices=modality_names, 24 | help='modality: ' + ' | '.join(modality_names) + ' (default: rgb)') 25 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 26 | help='number of data loading workers (default: 16)') 27 | parser.add_argument('--print-freq', '-p', default=50, type=int, 28 | metavar='N', help='print frequency (default: 50)') 29 | parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH',) 30 | parser.add_argument('--gpu', default='0', type=str, metavar='N', help="gpu id") 31 | parser.set_defaults(cuda=True) 32 | 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def colored_depthmap(depth, d_min=None, d_max=None): 38 | if d_min is None: 39 | d_min = np.min(depth) 40 | if d_max is None: 41 | d_max = np.max(depth) 42 | depth_relative = 
(depth - d_min) / (d_max - d_min) 43 | return 255 * cmap(depth_relative)[:,:,:3] # H, W, C 44 | 45 | 46 | def merge_into_row(input, depth_target, depth_pred): 47 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C 48 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 49 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 50 | 51 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) 52 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) 53 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 54 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 55 | img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) 56 | 57 | return img_merge 58 | 59 | 60 | def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred): 61 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C 62 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) 63 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 64 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 65 | 66 | d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu)) 67 | d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu)) 68 | depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max) 69 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 70 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 71 | 72 | img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col]) 73 | 74 | return img_merge 75 | 76 | 77 | def add_row(img_merge, row): 78 | return np.vstack([img_merge, row]) 79 | 80 | 81 | def save_image(img_merge, filename): 82 | img_merge = Image.fromarray(img_merge.astype('uint8')) 83 | img_merge.save(filename) 84 | --------------------------------------------------------------------------------