├── .gitignore ├── LICENSE ├── README.md ├── acl_test.cc ├── data └── kernels.cl ├── install.sh ├── layer-test ├── README.md ├── conv2d.cc ├── depth.cc ├── gemm.cc ├── test_conv2d.py ├── test_dense.py ├── test_depth.py ├── util.cc └── util.h ├── mali_imagenet_bench.py ├── mxnet_test.py ├── results.png ├── run_test.sh └── spatial.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs 
documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lianmin Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Note: The data and scripts here are all stale. Please go to https://github.com/dmlc/tvm/wiki/Benchmark#mobile-gpu for the latest results. 2 | 3 |


10 | 11 | # Benchmarking Deep Neural Networks on ARM CPU/GPU 12 | 13 | This repo is the supporting material for [Optimizing Mobile Deep Learning on ARM GPU with TVM](http://tvmlang.org/2018/01/16/opt-mali-gpu.html) 14 | 15 | ## Inference Speed on ImageNet 16 | Tested on 17 | ``` 18 | Firefly-RK3399 4G, CPU: dual-core Cortex-A72 + quad-core Cortex-A53, GPU: Mali-T860MP4 19 | Arm Compute Library: v17.12, MXNet: v1.0.1, Openblas: v0.2.18 20 | ``` 21 | 22 | ![result](results.png) 23 | 24 |   25 | ## Set Up the Test Environment 26 | ``` 27 | sudo /etc/init.d/lightdm stop 28 | sudo -i 29 | echo performance > /sys/class/misc/mali0/device/devfreq/ff9a0000.gpu/governor 30 | ``` 31 | Stopping the desktop display manager and pinning the GPU frequency governor to `performance` makes the measurements more stable. 32 | 33 | **Note**: You need more than 2.5GB of memory to run the following tests. 34 | Otherwise, skip the vgg16 test by replacing `--model all` with `--model resnet18` or `--model mobilenet` 35 | in the command. 36 | 37 | ## Run Test for TVM/NNVM 38 | In TVM, we use [RPC](http://nnvm.tvmlang.org/tutorials/deploy_model_on_mali_gpu.html) to run the tests, 39 | so you need to build the TVM runtime and start an RPC server on your device. 40 | ``` 41 | python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090 42 | ``` 43 | 44 | Then, on your host machine, run the test command 45 | ``` bash 46 | python mali_imagenet_bench.py --target-host TARGET_HOST --host HOST --port PORT --model all 47 | ``` 48 | Replace `TARGET_HOST`, `HOST`, and `PORT` with the corresponding values for your environment. 49 | 50 | For example, on my Firefly-RK3399, the command is 51 | ``` bash 52 | python mali_imagenet_bench.py --target-host 'llvm -target=aarch64-linux-gnu -mattr=+neon' --host 10.42.0.96 --port 9090 --model all 53 | ``` 54 | 55 | ## Run Test for MXNet + Openblas 56 | This test runs locally on your device, so you need to install MXNet with OpenBLAS on your device first.
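The TVM and MXNet outputs under Result print "warm up.." before "test.." and a final "cost per image". The source of mxnet_test.py is not shown here. However, `measure()` in acl_test.cc shows the general pattern: run the graph once untimed, synchronize the OpenCL queue, and then divide the wall-clock time of `n_times` runs by `n_times`. `run_case()` calls it once for warm-up (10 runs, or 50 for mobilenet) and once for the reported cost (60 runs, or 300 for mobilenet).

The sketch below expresses the same idea in Python. `measure_cost_per_image` and `run_once` are hypothetical names. `run_once` is assumed to perform one blocking inference. For an asynchronous GPU backend, it must therefore wait for the device to finish, as `measure()` does with `CLScheduler::get().sync()`.

```python
import time
from typing import Callable


def measure_cost_per_image(run_once: Callable[[], object],
                           num_warmup: int = 10,
                           num_test: int = 60) -> float:
    """Return the average wall-clock seconds per call of run_once.

    Hypothetical helper: num_warmup untimed calls absorb one-time costs
    (for example, lazy initialization), then the mean over num_test timed
    calls is returned. The defaults mirror the vgg16 settings in acl_test.cc.
    """
    if num_test <= 0:
        raise ValueError("num_test must be positive")
    # Warm-up: results are discarded and not timed.
    for _ in range(num_warmup):
        run_once()
    # Timed region: run_once is assumed to block until the work is done.
    start = time.perf_counter()
    for _ in range(num_test):
        run_once()
    return (time.perf_counter() - start) / num_test


if __name__ == "__main__":
    # Stand-in workload; replace with a blocking call into your framework.
    cost = measure_cost_per_image(lambda: sum(range(100_000)), 2, 5)
    print("cost per image: %.4fs" % cost)
```

Keeping the warm-up outside the timed region prevents one-time setup costs from inflating the average. Using more timed runs, as acl_test.cc does for the fast mobilenet, reduces run-to-run noise.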
57 | 58 | ``` bash 59 | python mxnet_test.py --model all 60 | ``` 61 | 62 | ## Run Test for Arm Compute Library 63 | Build ACL by cross-compiling it on the host system. 64 | ``` bash 65 | scons Werror=1 neon=1 opencl=1 examples=1 benchmark_tests=1 os=linux arch=arm64-v8a embed_kernels=1 -j$(nproc) 66 | ``` 67 | 68 | Copy acl\_test.cc to the root directory of ACL and build acl\_test with 69 | ``` bash 70 | aarch64-linux-gnu-g++ acl_test.cc build/utils/*.o -O2 -std=c++11 \ 71 | -I. -Iinclude -Lbuild -Lbuild/opencl-1.2-stubs/ \ 72 | -larm_compute -larm_compute_graph -larm_compute_core -lOpenCL -o acl_test 73 | ``` 74 | 75 | Copy the acl\_test binary to your device and run 76 | ``` 77 | ./acl_test all 78 | cat result-acl.txt 79 | ``` 80 | Results are recorded in `result-acl.txt`. 81 | 82 | **Note**: Some test cases (e.g. resnet) are missing because Arm Compute Library (v17.12) does not 83 | support skip connections in its graph runtime. Some other test cases are skipped because they are too slow. 84 | 85 | ## Result 86 | The outputs on my board are pasted below. 87 | 88 | ### TVM/NNVM 89 | ``` 90 | ============================================================ 91 | model: vgg16, dtype: float32 92 | warm up.. 93 | test.. 94 | cost per image: 1.2926s 95 | ============================================================ 96 | model: vgg16, dtype: float16 97 | warm up.. 98 | test.. 99 | cost per image: 0.6896s 100 | ============================================================ 101 | model: resnet18, dtype: float32 102 | warm up.. 103 | test.. 104 | cost per image: 0.2041s 105 | ============================================================ 106 | model: resnet18, dtype: float16 107 | warm up.. 108 | test.. 109 | cost per image: 0.1183s 110 | ============================================================ 111 | model: mobilenet, dtype: float32 112 | warm up.. 113 | test..
114 | cost per image: 0.0767s 115 | ============================================================ 116 | model: mobilenet, dtype: float16 117 | warm up.. 118 | test.. 119 | cost per image: 0.0479s 120 | ``` 121 | 122 | ### MXNet + Openblas 123 | ``` 124 | ============================================================ 125 | model: vgg16, dtype: float32 126 | warm up... 127 | test.. 128 | cost per image: 3.0250s 129 | ============================================================ 130 | model: resnet18, dtype: float32 131 | warm up... 132 | test.. 133 | cost per image: 0.3977s 134 | ============================================================ 135 | model: mobilenet, dtype: float32 136 | warm up... 137 | test.. 138 | cost per image: 0.2914s 139 | ``` 140 | 141 | ### ACL 142 | ``` 143 | backend: cl model: vgg16 conv_method: gemm dtype: float32 cost: 1.64456 144 | backend: cl model: vgg16 conv_method: gemm dtype: float16 cost: 0.969372 145 | backend: cl model: vgg16 conv_method: direct dtype: float32 cost: 3.90031 146 | backend: cl model: vgg16 conv_method: direct dtype: float16 cost: 1.61179 147 | backend: cl model: mobilenet conv_method: gemm dtype: float32 cost: 0.170934 148 | backend: cl model: mobilenet conv_method: direct dtype: float32 cost: 0.173883 149 | backend: neon model: vgg16 conv_method: gemm dtype: float32 cost: 4.10269 150 | ``` 151 | 152 | -------------------------------------------------------------------------------- /acl_test.cc: -------------------------------------------------------------------------------- 1 | #include "arm_compute/graph/Graph.h" 2 | #include "arm_compute/graph/Nodes.h" 3 | #include "arm_compute/runtime/CL/CLScheduler.h" 4 | #include "arm_compute/runtime/CPP/CPPScheduler.h" 5 | #include "arm_compute/runtime/Scheduler.h" 6 | #include "support/ToolchainSupport.h" 7 | #include "utils/GraphUtils.h" 8 | #include "utils/Utils.h" 9 | 10 | #include <chrono> 11 | #include <cstring> 12 | #include <fstream> 13 | #include <iostream> 14 | #include <sstream> 15 | #include <unistd.h> 16 | 17 | using namespace
arm_compute::graph; 18 | using namespace arm_compute::graph_utils; 19 | // Placeholder accessor: no real weights or inputs are loaded, only speed is measured 20 | std::unique_ptr<ITensorAccessor> dummy() { 21 | return arm_compute::support::cpp14::make_unique<DummyAccessor>(1); 22 | } 23 | 24 | void get_vgg16(Graph *graph, arm_compute::DataType type) { 25 | *graph << Tensor(TensorInfo(TensorShape(224U, 224U, 3U, 1U), 1, type)) 26 | // Layer 1 27 | << ConvolutionLayer( 3U, 3U, 64U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 28 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 29 | // Layer 2 30 | << ConvolutionLayer( 3U, 3U, 64U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 31 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 32 | << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))) 33 | // Layer 3 34 | << ConvolutionLayer( 3U, 3U, 128U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 35 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 36 | // Layer 4 37 | << ConvolutionLayer( 3U, 3U, 128U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 38 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 39 | << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))) 40 | // Layer 5 41 | << ConvolutionLayer( 3U, 3U, 256U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 42 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 43 | // Layer 6 44 | << ConvolutionLayer( 3U, 3U, 256U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 45 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 46 | // Layer 7 47 | << ConvolutionLayer( 3U, 3U, 256U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 48 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 49 | << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))) 50 | // Layer 8 51 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), 
PadStrideInfo(1, 1, 1, 1)) 52 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 53 | // Layer 9 54 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 55 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 56 | // Layer 10 57 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 58 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 59 | << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))) 60 | // Layer 11 61 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 62 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 63 | // Layer 12 64 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 65 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 66 | // Layer 13 67 | << ConvolutionLayer( 3U, 3U, 512U, dummy(), dummy(), PadStrideInfo(1, 1, 1, 1)) 68 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 69 | << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))) 70 | // Layer 14 71 | << FullyConnectedLayer( 4096U, dummy(), dummy()) 72 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 73 | // Layer 15 74 | << FullyConnectedLayer( 4096U, dummy(), dummy()) 75 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)) 76 | // Layer 16 77 | << FullyConnectedLayer( 1000U, dummy(), dummy()) 78 | // Softmax 79 | << SoftmaxLayer() 80 | << Tensor(TensorInfo(TensorShape(1000U), 1, type)); 81 | } 82 | 83 | BranchLayer get_dwsc_node(const std::string &data_path, std::string &&param_path, 84 | unsigned int conv_filt, 85 | PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info) 86 | { 87 | std::string total_path = 
"/cnn_data/mobilenet_v1_model/" + param_path + "_"; 88 | SubGraph sg; 89 | sg << DepthwiseConvolutionLayer( 90 | 3U, 3U, dummy(), 91 | std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), 92 | dwc_pad_stride_info, 93 | true) 94 | << BatchNormalizationLayer(dummy(), dummy(), dummy(), dummy(), 0.001f) 95 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)) 96 | << ConvolutionLayer( 1U, 1U, conv_filt, dummy(), 97 | std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), conv_pad_stride_info) 98 | << BatchNormalizationLayer( dummy(), dummy(), dummy(), dummy(), 0.001f) 99 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)); 100 | 101 | return BranchLayer(std::move(sg)); 102 | } 103 | 104 | void get_mobilenet(Graph *graph, arm_compute::DataType type) { 105 | std::string data_path; /* Path to the trainable data */ 106 | 107 | *graph << Tensor(TensorInfo(TensorShape(224U, 224U, 3U, 1U), 1, type)) 108 | << ConvolutionLayer( 3U, 3U, 32U, dummy(), 109 | std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), 110 | PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR)) 111 | << BatchNormalizationLayer( dummy(), dummy(), dummy(), dummy(), 0.001f) 112 | << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)) 113 | << get_dwsc_node(data_path, "Conv2d_1", 64, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0)) 114 | << get_dwsc_node(data_path, "Conv2d_2", 128, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 115 | << get_dwsc_node(data_path, "Conv2d_3", 128, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 116 | << get_dwsc_node(data_path, "Conv2d_4", 256, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 117 | << get_dwsc_node(data_path, "Conv2d_5", 256, PadStrideInfo(1, 1, 1, 1, 1, 1, 
DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 118 | << get_dwsc_node(data_path, "Conv2d_6", 512, 
PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 119 | << get_dwsc_node(data_path, "Conv2d_7", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 120 | << get_dwsc_node(data_path, "Conv2d_8", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 121 | << get_dwsc_node(data_path, "Conv2d_9", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 122 | << get_dwsc_node(data_path, "Conv2d_10", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 123 | << get_dwsc_node(data_path, "Conv2d_11", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 124 | << get_dwsc_node(data_path, "Conv2d_12", 1024, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 125 | << get_dwsc_node(data_path, "Conv2d_13", 1024, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0)) 126 | << PoolingLayer(PoolingLayerInfo(PoolingType::AVG)) 127 | << ConvolutionLayer( 1U, 1U, 1000U, dummy(), dummy(), PadStrideInfo(1, 1, 0, 0)) 128 | << ReshapeLayer(TensorShape(1000U)) 129 | << SoftmaxLayer() 130 | << Tensor(TensorInfo(TensorShape(1000U), 1, type)); 131 | } 132 | 133 | double measure(Graph *graph, int n_times) { 134 | arm_compute::CLScheduler::get().default_init(); 135 | graph->run(); 136 | arm_compute::CLScheduler::get().sync(); 137 | 138 | auto tbegin = std::chrono::high_resolution_clock::now(); 139 | for (int i = 0; i < n_times; i++) { 140 | graph->run(); 141 | } 142 | arm_compute::CLScheduler::get().sync(); 143 | auto tend = std::chrono::high_resolution_clock::now(); 144 | 145 | // elapsed wall-clock time in seconds, averaged over n_times runs 146 | double cost = std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count(); 147 | return cost / n_times; 148 | } 149 | 150 | 
double run_case(std::string backend, std::string model, std::string conv_method, 
std::string dtype) { 151 | TargetHint target_hint; 152 | ConvolutionMethodHint convolution_hint; 153 | arm_compute::DataType type; 154 | 155 | if (conv_method == "gemm") { 156 | convolution_hint = ConvolutionMethodHint::GEMM; 157 | } else { 158 | convolution_hint = ConvolutionMethodHint::DIRECT; 159 | } 160 | 161 | if (backend == "cl") { 162 | target_hint = TargetHint::OPENCL; 163 | } else { 164 | target_hint = TargetHint::NEON; 165 | } 166 | 167 | if (dtype == "float32") { 168 | type = DataType::F32; 169 | } else { 170 | type = DataType::F16; 171 | } 172 | 173 | Graph graph; 174 | graph << target_hint << convolution_hint; 175 | 176 | if (model == "mobilenet") 177 | get_mobilenet(&graph, type); 178 | else if (model == "vgg16") 179 | get_vgg16(&graph, type); 180 | else 181 | std::cout << "unknown model" << std::endl; 182 | 183 | int num_warmup, num_test; 184 | 185 | num_warmup = 10; 186 | num_test = 60; 187 | 188 | if (model == "mobilenet") { // mobilenet is fast, so it needs more runs for a stable measurement 189 | num_warmup *= 5; 190 | num_test *= 5; 191 | } 192 | 193 | // warm up 194 | measure(&graph, num_warmup); 195 | 196 | // test 197 | double cost = measure(&graph, num_test); 198 | return cost; 199 | } 200 | 201 | int main(int argc, const char **argv) 202 | { 203 | // Results are printed and appended to result-acl.txt 204 | // Usage 1: ./acl_test all 205 | // Usage 2: ./acl_test [cl|neon] [mobilenet|vgg16] [gemm|direct] [float32|float16] 206 | if (argc < 2 || (strcmp(argv[1], "all") != 0 && argc < 5)) { std::cerr << "usage: " << argv[0] << " all | [cl|neon] [mobilenet|vgg16] [gemm|direct] [float32|float16]" << std::endl; return 1; } 207 | std::ofstream fout("result-acl.txt", std::ios::app); 208 | 209 | if (strcmp(argv[1], "all") == 0) { // test all 210 | std::string backend[] = {"cl", "neon"}; 211 | std::string model[] = {"vgg16", "mobilenet"}; 212 | std::string conv_method[] = {"gemm", "direct"}; 213 | std::string dtype[] = {"float32", "float16"}; 214 | 215 | for (int i = 0; i 
< sizeof(backend)/sizeof(backend[0]); i++) { 216 | for (int j = 0; j < sizeof(model)/sizeof(model[0]); j++) { 217 | for (int k = 0; k < sizeof(conv_method)/sizeof(conv_method[0]); 
k++) { 218 | for (int l = 0; l < sizeof(dtype)/sizeof(dtype[0]); l++) { 219 | 220 | // skip some tests for neon: only vgg16 + gemm + float32 runs on the CPU 221 | if (backend[i] == "neon" ) { 222 | 223 | if (conv_method[k] == "direct") // this config is too slow, skip it 224 | continue; 225 | if (model[j] == "mobilenet") // too slow, skip it 226 | continue; 227 | if (dtype[l] == "float16") // skip the test of fp16 on CPU 228 | continue; 229 | } else { 230 | // ACL does not support FP16 depthwise conv 231 | if (model[j] == "mobilenet" && dtype[l] == "float16") 232 | continue; 233 | } 234 | 235 | double cost = run_case(backend[i], model[j], conv_method[k], dtype[l]); 236 | 237 | std::stringstream ss; 238 | 239 | std::string back_name; 240 | if (backend[i] == "cl") 241 | back_name = "mali"; 242 | else 243 | back_name = "neon"; 244 | 245 | ss << "backend: ARMComputeLib-" << back_name << "\tmodel: " << model[j] 246 | << "\tconv_method: " << conv_method[k] << "\tdtype: " << dtype[l] 247 | << "\tcost: " << cost; 248 | std::cout << ss.str() << std::endl; 249 | fout << ss.str() << std::endl; 250 | sleep(20); 251 | } 252 | } 253 | } 254 | } 255 | } else { // test single case 256 | std::string backend = argv[1]; 257 | std::string model = argv[2]; 258 | std::string conv_method = argv[3]; 259 | std::string dtype = argv[4]; 260 | 261 | double cost = run_case(backend, model, conv_method, dtype); 262 | std::stringstream ss; 263 | ss << "backend: " << backend << "\tmodel: " << model 264 | << "\tconv_method: " << conv_method << "\tdtype: " << dtype 265 | << "\tcost: " << cost; 266 | std::cout << ss.str() << std::endl; 267 | fout << ss.str() << std::endl; 268 | } 269 | 270 | return 0; 271 | } 272 | 273 | -------------------------------------------------------------------------------- /data/kernels.cl: -------------------------------------------------------------------------------- 1 | // kernel for packing data 2 | __kernel void default_function__kernel0(__global void* restrict data_vec, __global float* restrict data)
{ 3 | for (int vh = 0; vh < 3; ++vh) { 4 | ((__global float*)data_vec)[(((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (1 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 15)) ? data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -57)] : 0.000000e+00f); 5 | ((__global float*)data_vec)[((((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6) + 1)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (0 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 14)) ? data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -56)] : 0.000000e+00f); 6 | ((__global float*)data_vec)[((((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6) + 2)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (0 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 14)) ? data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -55)] : 0.000000e+00f); 7 | ((__global float*)data_vec)[((((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6) + 3)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (0 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 14)) ? 
data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -54)] : 0.000000e+00f); 8 | ((__global float*)data_vec)[((((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6) + 4)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (0 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 14)) ? data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -53)] : 0.000000e+00f); 9 | ((__global float*)data_vec)[((((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) * 256) + ((int)get_group_id(0))) * 3) + vh) * 6) + 5)] = ((((((1 - vh) <= ((int)get_group_id(2))) && (((int)get_group_id(2)) < (57 - vh))) && (-1 <= ((int)get_group_id(1)))) && (((int)get_group_id(1)) < 13)) ? data[((((((((int)get_group_id(2)) * 14) + ((int)get_group_id(1))) + (((int)get_group_id(0)) * 784)) + (vh * 14)) * 4) + -52)] : 0.000000e+00f); 10 | } 11 | } 12 | 13 | // kernel for packing filter 14 | __kernel void default_function__kernel1(__global void* restrict kernel_vec, __global float* restrict weight) { 15 | float4 _1; 16 | int4 _2 = (int4)(((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9))+(2304*0), ((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9))+(2304*1), ((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9))+(2304*2), ((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9))+(2304*3)); 17 | _1.s0 = weight[_2.s0]; 18 | _1.s1 = weight[_2.s1]; 19 | _1.s2 = weight[_2.s2]; 20 | _1.s3 = weight[_2.s3]; 21 | vstore4(_1, 0, (__global float*)kernel_vec + (((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36)); 22 | float4 _3; 23 | int4 _4 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 1))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 1))+(2304*1), 
(((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 1))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 1))+(2304*3)); 24 | _3.s0 = weight[_4.s0]; 25 | _3.s1 = weight[_4.s1]; 26 | _3.s2 = weight[_4.s2]; 27 | _3.s3 = weight[_4.s3]; 28 | vstore4(_3, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 4)); 29 | float4 _5; 30 | int4 _6 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 2))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 2))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 2))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 2))+(2304*3)); 31 | _5.s0 = weight[_6.s0]; 32 | _5.s1 = weight[_6.s1]; 33 | _5.s2 = weight[_6.s2]; 34 | _5.s3 = weight[_6.s3]; 35 | vstore4(_5, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 8)); 36 | float4 _7; 37 | int4 _8 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 3))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 3))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 3))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 3))+(2304*3)); 38 | _7.s0 = weight[_8.s0]; 39 | _7.s1 = weight[_8.s1]; 40 | _7.s2 = weight[_8.s2]; 41 | _7.s3 = weight[_8.s3]; 42 | vstore4(_7, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 12)); 43 | float4 _9; 44 | int4 _10 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 4))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 4))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 4))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 4))+(2304*3)); 45 | _9.s0 = 
weight[_10.s0]; 46 | _9.s1 = weight[_10.s1]; 47 | _9.s2 = weight[_10.s2]; 48 | _9.s3 = weight[_10.s3]; 49 | vstore4(_9, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 16)); 50 | float4 _11; 51 | int4 _12 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 5))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 5))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 5))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 5))+(2304*3)); 52 | _11.s0 = weight[_12.s0]; 53 | _11.s1 = weight[_12.s1]; 54 | _11.s2 = weight[_12.s2]; 55 | _11.s3 = weight[_12.s3]; 56 | vstore4(_11, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 20)); 57 | float4 _13; 58 | int4 _14 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 6))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 6))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 6))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 6))+(2304*3)); 59 | _13.s0 = weight[_14.s0]; 60 | _13.s1 = weight[_14.s1]; 61 | _13.s2 = weight[_14.s2]; 62 | _13.s3 = weight[_14.s3]; 63 | vstore4(_13, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 24)); 64 | float4 _15; 65 | int4 _16 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 7))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 7))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 7))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 7))+(2304*3)); 66 | _15.s0 = weight[_16.s0]; 67 | _15.s1 = weight[_16.s1]; 68 | _15.s2 = weight[_16.s2]; 69 | _15.s3 = weight[_16.s3]; 70 | vstore4(_15, 0, (__global 
float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 28)); 71 | float4 _17; 72 | int4 _18 = (int4)((((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 8))+(2304*0), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 8))+(2304*1), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 8))+(2304*2), (((((((int)get_group_id(1)) * 1024) + ((int)get_group_id(0))) * 9) + 8))+(2304*3)); 73 | _17.s0 = weight[_18.s0]; 74 | _17.s1 = weight[_18.s1]; 75 | _17.s2 = weight[_18.s2]; 76 | _17.s3 = weight[_18.s3]; 77 | vstore4(_17, 0, (__global float*)kernel_vec + ((((((int)get_group_id(1)) * 256) + ((int)get_group_id(0))) * 36) + 32)); 78 | } 79 | 80 | // kernel for convolution 81 | __kernel void default_function__kernel2(__global void* restrict conv, __global void* restrict data_vec, __global void* restrict kernel_vec) { 82 | vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 83 | vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 84 | vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 85 | vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 86 | for (int ci = 0; ci < 256; ++ci) { 87 | vstore4((vload4(0, (__global float*)conv + 
(((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[(((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18)], ((__global float*)data_vec)[(((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18)], ((__global float*)data_vec)[(((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18)], ((__global float*)data_vec)[(((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18)])) * vload4(0, (__global float*)kernel_vec + (((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 88 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)])) * vload4(0, (__global float*)kernel_vec + (((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 89 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 
8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)])) * vload4(0, (__global float*)kernel_vec + (((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 90 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)])) * vload4(0, (__global float*)kernel_vec + (((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 91 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + 
((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 1)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 4)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 92 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 4)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 93 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global 
float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 4)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 94 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 4)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 95 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) 
* 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 2)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 8)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 96 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 3)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 8)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 97 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + 
((int)get_group_id(0))) * 256) + ci) * 18) + 4)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 4)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 8)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 98 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 5)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 5)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 5)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 5)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 8)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 99 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 6)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 6)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 6)], ((__global 
float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 6)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 12)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 100 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 12)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 101 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + 
ci) * 18) + 8)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 12)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 102 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 12)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 103 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 7)])) * vload4(0, (__global float*)kernel_vec + 
((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 16)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 104 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 16)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 105 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 16)))), 
0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 106 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 16)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 107 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 8)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 20)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + 
((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 108 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 9)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 20)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 109 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 10)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 20)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + 
((int)get_group_id(0))) * 16) + 8)); 110 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 11)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 11)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 11)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 11)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 20)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 111 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 12)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 12)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 12)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 12)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 24)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 112 | vstore4((vload4(0, (__global float*)conv + 
((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 24)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 113 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 24)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 114 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + 
((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 24)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 115 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 13)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 28)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 116 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + 
(((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 28)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 117 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 28)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 118 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) 
+ ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 28)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 119 | vstore4((vload4(0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 14)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 32)))), 0, (__global float*)conv + (((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16)); 120 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global 
float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 15)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 32)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 4)); 121 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 16)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 32)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 8)); 122 | vstore4((vload4(0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)) + (((float4)(((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 17)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + 
((int)get_group_id(0))) * 256) + ci) * 18) + 17)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 17)], ((__global float*)data_vec)[((((((((int)get_group_id(1)) * 14) + ((int)get_group_id(0))) * 256) + ci) * 18) + 17)])) * vload4(0, (__global float*)kernel_vec + ((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 256) + ci) * 36) + 32)))), 0, (__global float*)conv + ((((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 14) + ((int)get_group_id(0))) * 16) + 12)); 123 | } 124 | } 125 | 126 | // kernel for unpacking the output 127 | __kernel void default_function__kernel3(__global float* restrict output_unpack, __global void* restrict conv) { 128 | output_unpack[((((((((int)get_group_id(2)) * 8) + ((int)get_local_id(2))) * 56) + ((int)get_group_id(1))) * 56) + ((int)get_group_id(0)))] = ((__global float*)conv)[((((((((int)get_group_id(2)) * 2) + (((int)get_local_id(2)) / 4)) * 12544) + (((int)get_local_id(2)) % 4)) + (((int)get_group_id(1)) * 224)) + (((int)get_group_id(0)) * 4))]; 129 | } 130 | 131 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get update 2 | 3 | sudo apt install llvm-4.0 4 | sudo apt install scons 5 | sudo apt install libopenblas-dev 6 | sudo apt-get -y install git cmake build-essential g++-4.8 c++-4.8 liblapack* libblas* libopencv* 7 | 8 | git clone --recursive https://github.com/dmlc/nnvm.git 9 | git clone https://github.com/ARM-software/ComputeLibrary.git --branch v17.12 10 | git clone --recursive https://github.com/apache/incubator-mxnet.git 11 | 12 | # build nnvm/tvm 13 | cd nnvm/tvm 14 | make USE_OPENCL=1 LLVM_CONFIG=llvm-config-4.0 -j4 15 | cd .. 16 | make 17 | cd .. 
18 | 19 | # build arm compute library 20 | cd ComputeLibrary 21 | scons Werror=1 neon=1 opencl=1 examples=1 os=linux arch=arm64-v8a embed_kernels=1 build=native -j4 22 | cp ../acl_test.cc . 23 | 24 | g++ acl_test.cc build/utils/*.o -O2 -std=c++11 -I. -Iinclude -Lbuild -larm_compute -larm_compute_graph -larm_compute_core -lOpenCL -o acl_test 25 | cp acl_test .. 26 | cd .. 27 | 28 | # build mxnet 29 | cd incubator-mxnet 30 | make -j2 USE_OPENCV=0 USE_BLAS=openblas 31 | cd .. 32 | 33 | -------------------------------------------------------------------------------- /layer-test/README.md: -------------------------------------------------------------------------------- 1 | # Test script for layer-wise benchmark 2 | 3 | to be documented... 4 | 5 | -------------------------------------------------------------------------------- /layer-test/conv2d.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "arm_compute/core/Types.h" 7 | #include "arm_compute/runtime/CL/CLFunctions.h" 8 | #include "arm_compute/runtime/CL/CLScheduler.h" 9 | 10 | #include "util.h" 11 | 12 | using namespace arm_compute; 13 | 14 | struct Workload { 15 | std::string in_dtype; 16 | std::string out_dtype; 17 | size_t batch; 18 | size_t height; 19 | size_t width; 20 | size_t in_filter; 21 | size_t out_filter; 22 | size_t hkernel; 23 | size_t wkernel; 24 | size_t hpad; 25 | size_t wpad; 26 | size_t hstride; 27 | size_t wstride; 28 | }; 29 | 30 | // measure the cost and gflops of 2d convolution 31 | std::pair MeasureConv(const Workload &w, int times=30) { 32 | assert(w.in_dtype == w.out_dtype); 33 | Format format = DtypeToFormat(w.in_dtype); 34 | 35 | CLTensor input, weight, output; 36 | PadStrideInfo conv_info(w.wstride, w.hstride, w.wpad, w.hpad); 37 | 38 | // init OpenCL 39 | CLScheduler::get().default_init(); 40 | 41 | // allocate tensors 42 | input.allocator()->init(TensorInfo(TensorShape(w.width, 
w.height, w.in_filter, w.batch), format)); 43 | weight.allocator()->init(TensorInfo(TensorShape(w.wkernel, w.hkernel, w.in_filter, w.out_filter), format)); 44 | size_t w_out = (w.width - w.wkernel + w.wpad * 2) / w.wstride + 1; 45 | size_t h_out = (w.height - w.hkernel + w.hpad * 2) / w.hstride + 1; 46 | output.allocator()->init(TensorInfo(TensorShape(w_out, h_out, w.out_filter, w.batch), format)); 47 | input.allocator()->allocate(); 48 | weight.allocator()->allocate(); 49 | output.allocator()->allocate(); 50 | CLScheduler::get().sync(); 51 | 52 | // configure conv2d function 53 | CLConvolutionLayer conv2d; 54 | conv2d.configure(&input, &weight, nullptr, &output, conv_info); 55 | 56 | // warm-up run; wait for it to finish so it is not counted in the timed loop 57 | conv2d.run(); CLScheduler::get().sync(); 58 | std::chrono::high_resolution_clock::time_point begin = std::chrono::high_resolution_clock::now(); 59 | 60 | for (int i = 0; i < times; i++) { 61 | conv2d.run(); 62 | } 63 | CLScheduler::get().sync(); 64 | 65 | std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); 66 | 67 | // calculate gflops 68 | std::chrono::duration fp_ms = end - begin; 69 | double cost = fp_ms.count() / times; 70 | return std::make_pair(cost, 2.0 * w.batch * w_out * h_out * w.out_filter * 71 | w.hkernel * w.wkernel * w.in_filter / 1e9 / cost); 72 | } 73 | 74 | 75 | int main(int argc, const char **argv) 76 | { 77 | Workload to_test[] = { 78 | // vgg16 79 | // Workload{"float32", "float32", 1, 224, 224, 3, 64, 3, 3, 1, 1, 1, 1}, 80 | // Workload{"float32", "float32", 1, 224, 224, 64, 64, 3, 3, 1, 1, 1, 1}, 81 | // Workload{"float32", "float32", 1, 112, 112, 64, 128,3, 3, 1, 1, 1, 1}, 82 | // Workload{"float32", "float32", 1, 112, 112,128, 128,3, 3, 1, 1, 1, 1}, 83 | // Workload{"float32", "float32", 1, 56, 56, 128, 256, 3, 3, 1, 1, 1, 1}, 84 | // Workload{"float32", "float32", 1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1}, 85 | // Workload{"float32", "float32", 1, 28, 28, 256, 512, 3, 3, 1, 1, 1, 1}, 86 | // Workload{"float32", "float32", 1, 28, 
28, 512, 512, 3, 3, 1, 1, 1, 1}, 87 | // Workload{"float32", "float32", 1, 14, 14, 512, 512, 3, 3, 1, 1, 1, 1}, 88 | 89 | // resnet 90 | Workload{"float32", "float32", 1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2}, 91 | Workload{"float32", "float32", 32, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2}, 92 | Workload{"float32", "float32", 1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1}, 93 | Workload{"float32", "float32", 32, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1}, 94 | Workload{"float32", "float32", 1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1}, 95 | Workload{"float32", "float32", 32, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1}, 96 | Workload{"float32", "float32", 1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2}, 97 | Workload{"float32", "float32", 32, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2}, 98 | Workload{"float32", "float32", 1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2}, 99 | Workload{"float32", "float32", 32, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2}, 100 | Workload{"float32", "float32", 1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1}, 101 | Workload{"float32", "float32", 32, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1}, 102 | Workload{"float32", "float32", 1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2}, 103 | Workload{"float32", "float32", 32, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2}, 104 | Workload{"float32", "float32", 1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2}, 105 | Workload{"float32", "float32", 32, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2}, 106 | Workload{"float32", "float32", 1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1}, 107 | Workload{"float32", "float32", 32, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1}, 108 | Workload{"float32", "float32", 1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2}, 109 | Workload{"float32", "float32", 32, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2}, 110 | Workload{"float32", "float32", 1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2}, 111 | Workload{"float32", "float32", 32, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2}, 112 | Workload{"float32", "float32", 1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1}, 113 | Workload{"float32", "float32", 32, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1}, 114 | 115 | // // mobilenet 116 | 
// Workload{"float32", "float32", 1, 224, 224, 3, 32, 3, 3, 1, 1, 2, 2}, 117 | // Workload{"float32", "float32", 1, 112, 112, 32, 64, 1, 1, 0, 0, 1, 1}, 118 | // Workload{"float32", "float32", 1, 56, 56, 64, 128, 1, 1, 0, 0, 1, 1}, 119 | // Workload{"float32", "float32", 1, 56, 56, 128, 128, 1, 1, 0, 0, 1, 1}, 120 | // Workload{"float32", "float32", 1, 28, 28, 128, 256, 1, 1, 0, 0, 1, 1}, 121 | // Workload{"float32", "float32", 1, 28, 28, 256, 256, 1, 1, 0, 0, 1, 1}, 122 | // Workload{"float32", "float32", 1, 14, 14, 256, 512, 1, 1, 0, 0, 1, 1}, 123 | // Workload{"float32", "float32", 1, 14, 14, 512, 512, 1, 1, 0, 0, 1, 1}, 124 | // Workload{"float32", "float32", 1, 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1}, 125 | // Workload{"float32", "float32", 1, 7, 7, 1024,1024, 1, 1, 0, 0, 1, 1}, 126 | }; 127 | 128 | for (size_t i = 0; i < sizeof(to_test) / sizeof(to_test[0]); i++) { 129 | Workload &w = to_test[i]; 130 | double cost, gflops; 131 | std::tie(cost, gflops) = MeasureConv(w); 132 | 133 | std::cout << std::fixed << std::setprecision(4); 134 | std::cout << w.height << "x" << w.width << 'x' << w.in_filter << "x" << w.out_filter 135 | << " " << w.hkernel << "\t"; 136 | std::cout << "cost: " << cost << ", " 137 | << "GFLOPS: " << gflops << std::endl; 138 | } 139 | 140 | return 0; 141 | } 142 | 143 | -------------------------------------------------------------------------------- /layer-test/depth.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "arm_compute/core/Types.h" 7 | #include "arm_compute/runtime/CL/CLFunctions.h" 8 | #include "arm_compute/runtime/CL/CLScheduler.h" 9 | 10 | #include "util.h" 11 | 12 | using namespace arm_compute; 13 | 14 | struct Workload { 15 | std::string in_dtype; 16 | std::string out_dtype; 17 | size_t n; 18 | size_t height; 19 | size_t in_filter; 20 | int channel_m; 21 | size_t hkernel; 22 | size_t hpad; 23 | size_t hstride; 24 | }; 25 | 26 | 
// measure the cost and gflops of depthwise 2d convolution 27 | std::pair MeasureConv(const Workload &w, int times=100) { 28 | assert(w.in_dtype == w.out_dtype); 29 | Format format = DtypeToFormat(w.in_dtype); 30 | 31 | CLTensor input, weight, output; 32 | PadStrideInfo conv_info(w.hstride, w.hstride, w.hpad, w.hpad); 33 | 34 | // init OpenCL 35 | CLScheduler::get().default_init(); 36 | 37 | // allocate tensors 38 | input.allocator()->init(TensorInfo(TensorShape(w.height, w.height, w.in_filter), format)); 39 | weight.allocator()->init(TensorInfo(TensorShape(w.hkernel, w.hkernel, w.in_filter), format)); 40 | size_t h_out = (w.height - w.hkernel + w.hpad * 2) / w.hstride + 1; 41 | output.allocator()->init(TensorInfo(TensorShape(h_out, h_out, w.in_filter), format)); 42 | input.allocator()->allocate(); 43 | weight.allocator()->allocate(); 44 | output.allocator()->allocate(); 45 | CLScheduler::get().sync(); 46 | 47 | // configure depthwise convolution function 48 | CLDepthwiseConvolutionLayer conv2d; 49 | conv2d.configure(&input, &weight, nullptr, &output, conv_info); 50 | 51 | // run test 52 | conv2d.run(); 53 | CLScheduler::get().sync(); 54 | std::chrono::high_resolution_clock::time_point begin = std::chrono::high_resolution_clock::now(); 55 | 56 | for (int i = 0; i < times; i++) { 57 | conv2d.run(); 58 | } 59 | CLScheduler::get().sync(); 60 | 61 | std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); 62 | 63 | // calculate gflops 64 | std::chrono::duration fp_ms = end - begin; 65 | double cost = fp_ms.count() / times; 66 | return std::make_pair(cost, 2.0 * h_out * h_out * 67 | w.hkernel * w.hkernel * w.in_filter / 1e9 / cost); 68 | } 69 | 70 | 71 | int main(int argc, const char **argv) 72 | { 73 | Workload to_test[] = { 74 | // mobilenet 75 | Workload{"float32", "float32", 1, 112, 32, 1, 3, 1, 1}, 76 | Workload{"float32", "float32", 1, 112, 64, 1, 3, 1, 2}, 77 | Workload{"float32", "float32", 1, 56, 128, 1, 3, 1, 1}, 78 | Workload{"float32", "float32", 
1, 56, 128, 1, 3, 1, 2}, 79 | Workload{"float32", "float32", 1, 28, 256, 1, 3, 1, 1}, 80 | Workload{"float32", "float32", 1, 28, 256, 1, 3, 1, 2}, 81 | Workload{"float32", "float32", 1, 14, 512, 1, 3, 1, 1}, 82 | Workload{"float32", "float32", 1, 14, 512, 1, 3, 1, 2}, 83 | Workload{"float32", "float32", 1, 7, 1024, 1, 3, 1, 1}, 84 | }; 85 | 86 | for (size_t i = 0; i < sizeof(to_test) / sizeof(to_test[0]); i++) { 87 | Workload &w = to_test[i]; 88 | double cost, gflops; 89 | std::tie(cost, gflops) = MeasureConv(w); 90 | 91 | std::cout << std::fixed << std::setprecision(6); 92 | std::cout << w.height << "x" << w.height << 'x' << w.in_filter << "x" << w.in_filter 93 | << " " << w.hkernel << "\t"; 94 | std::cout << "cost: " << cost << ", " 95 | << "GFLOPS: " << gflops << std::endl; 96 | } 97 | 98 | return 0; 99 | } 100 | 101 | -------------------------------------------------------------------------------- /layer-test/gemm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "arm_compute/core/Types.h" 6 | #include "arm_compute/runtime/CL/CLFunctions.h" 7 | #include "arm_compute/runtime/CL/CLScheduler.h" 8 | 9 | #include "util.h" 10 | 11 | using namespace arm_compute; 12 | 13 | // measure the cost and gflops of gemm 14 | std::pair MeasureGemm(int n, int l, int m, std::string dtype, int times=30) { 15 | Format format = DtypeToFormat(dtype); 16 | 17 | CLTensor a, b, dst; 18 | 19 | // init OpenCL 20 | CLScheduler::get().default_init(); 21 | 22 | // allocate tensors 23 | a.allocator()->init(TensorInfo(l, n, format)); 24 | b.allocator()->init(TensorInfo(m, l, format)); 25 | dst.allocator()->init(TensorInfo(m, n, format)); 26 | a.allocator()->allocate(); 27 | b.allocator()->allocate(); 28 | dst.allocator()->allocate(); 29 | CLScheduler::get().sync(); 30 | 31 | // configure gemm function 32 | CLGEMM gemm; 33 | gemm.configure(&a, &b, nullptr, &dst, 1.0, 0.0); 34 | 35 | // run test 36 | 
gemm.run(); CLScheduler::get().sync(); 37 | std::chrono::high_resolution_clock::time_point begin = std::chrono::high_resolution_clock::now(); 38 | 39 | for (int i = 0; i < times; i++) { 40 | gemm.run(); 41 | } 42 | CLScheduler::get().sync(); 43 | 44 | std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); 45 | 46 | // calculate gflops 47 | std::chrono::duration fp_ms = end - begin; 48 | double cost = fp_ms.count() / times; 49 | return std::make_pair(cost, 2.0 * n * l * m / (1e9) / cost); 50 | } 51 | 52 | int main(int argc, const char **argv) 53 | { 54 | size_t to_test[][3] = { 55 | {1024, 1024, 1024}, 56 | {2048, 2048, 2048}, 57 | }; 58 | 59 | for (size_t i = 0; i < sizeof(to_test) / sizeof(to_test[0]); i++) { 60 | int n, l, m; 61 | double cost, gflops; 62 | n = to_test[i][0]; 63 | l = to_test[i][1]; 64 | m = to_test[i][2]; 65 | 66 | std::tie(cost, gflops) = MeasureGemm(n, l, m, "float"); 67 | 68 | std::cout << std::fixed << std::setprecision(4); 69 | std::cout << "size: " << n << "x" << l << "x" << m << ", " << "cost: " << cost << ", " 70 | << "GFLOPS: " << gflops << std::endl; 71 | } 72 | 73 | return 0; 74 | } 75 | 76 | -------------------------------------------------------------------------------- /layer-test/test_conv2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import tvm 6 | import topi 7 | from tvm.contrib import rpc, util 8 | from topi.util import get_const_tuple 9 | import topi.testing 10 | from tvm.contrib.pickle_memoize import memoize 11 | 12 | dtype = 'float32' 13 | 14 | def convert_to_remote(func, remote): 15 | temp = util.tempdir() 16 | prefix = str(np.random.randint(1 << 31)) + "_" 17 | path_dso = temp.relpath(prefix + "tmp_func.tar") 18 | func.export_library(path_dso) 19 | 20 | remote.upload(path_dso) 21 | func = remote.load_module(prefix + "tmp_func.tar") 22 | return func 23 | 24 | 25 | def generate_tune_packs(item_list): 26 | ret = [] 27 | 28 | now = {} 29 
| def dfs(depth): 30 | if depth == len(item_list): 31 | ret.append(now.copy()) 32 | return 33 | 34 | name = item_list[depth][0] 35 | for value in item_list[depth][1]: 36 | now[name] = value 37 | dfs(depth + 1) 38 | 39 | dfs(0) 40 | 41 | return ret 42 | 43 | USE_MANUAL_CODE = False 44 | @tvm.register_func 45 | def tvm_callback_opencl_postproc(code): 46 | if not os.path.exists("perf"): 47 | os.mkdir("perf") 48 | with open("generated.cl", 'w') as fout: 49 | fout.write(code) 50 | if USE_MANUAL_CODE: 51 | split = code.split("\n") 52 | code = '\n'.join(split) 53 | return code 54 | 55 | 56 | 57 | def tune_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, padding, stride, ctx, 58 | n_times=1, target_host=None, remote=None): 59 | in_height = in_width = in_size 60 | 61 | A = tvm.placeholder((batch, in_channel, in_height, in_width), name='data') 62 | W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='weight') 63 | 64 | # get verify data 65 | a_shape = get_const_tuple(A.shape) 66 | w_shape = get_const_tuple(W.shape) 67 | dtype = A.dtype 68 | 69 | @memoize("topi.tests.test_topi_conv2d.verify_conv2d") 70 | def get_ref_data(): 71 | a_np = np.random.uniform(size=a_shape).astype(dtype) 72 | w_np = np.random.uniform(size=w_shape).astype(dtype) 73 | b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) 74 | return a_np, w_np, b_np 75 | 76 | a_np, w_np, b_np = get_ref_data() 77 | a = tvm.nd.array(a_np, ctx) 78 | w = tvm.nd.array(w_np, ctx) 79 | b = tvm.nd.array(np.zeros(b_np.shape, dtype=dtype), ctx) 80 | 81 | # generate static config 82 | #tune_pack = generate_tune_packs([ 83 | # ["bn", [4]], 84 | # ["num_thread", [1, 2, 4, 8, 16]], 85 | # ["unroll_step", [1, 4, 16]], 86 | # ]) 87 | 88 | tune_pack = generate_tune_packs([ 89 | ["VH", [1]], 90 | ["VW", [1, 7]], 91 | ["VC", [1, 2, 4, 8, 16]], 92 | ["num_thread", [1, 2, 4, 8, 16, 32]], 93 | ]) 94 | 95 | # search 96 | best_cost = 1e9 97 | best_config = None 98 | for config in reversed(tune_pack): 
99 | with tvm.target.mali(): 100 | tvm.target.current_target().tune_config = config 101 | B = topi.nn.conv2d(A, W, stride, padding) 102 | s = topi.generic.schedule_conv2d_nchw([B]) 103 | func = tvm.build(s, [A, W, B], target_host=target_host) 104 | 105 | if remote is not None: 106 | func = convert_to_remote(func, remote) 107 | 108 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 109 | cost = time_f(a, w, b).mean 110 | 111 | try: 112 | np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-4) 113 | except Exception as e: 114 | pass 115 | 116 | gflops = 2.0 * np.prod(b.shape) * kernel * kernel * in_channel /(1e9)/ cost 117 | print(config, cost, gflops) 118 | if cost < best_cost: 119 | best_cost = cost 120 | best_config = config 121 | 122 | return best_cost, 2.0 * np.prod(b.shape) * kernel * kernel * in_channel /(1e9)/ best_cost, best_config 123 | 124 | 125 | def verify_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, padding, stride, ctx, 126 | n_times=1, target=None, target_host=None, remote=None): 127 | in_height = in_width = in_size 128 | 129 | A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=dtype, name='data') 130 | W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=dtype, name='weight') 131 | 132 | with target: 133 | B = topi.nn.conv2d(A, W, stride, padding) 134 | s = topi.generic.schedule_conv2d_nchw([B]) 135 | func = tvm.build(s, [A, W, B], target_host=target_host) 136 | #print(func.imported_modules[0].get_source()) 137 | 138 | a_shape = get_const_tuple(A.shape) 139 | w_shape = get_const_tuple(W.shape) 140 | 141 | @memoize("topi.tests.test_topi_conv2d.verify_conv2d") 142 | def get_ref_data(): 143 | a_np = np.random.uniform(size=a_shape) 144 | w_np = np.random.uniform(size=w_shape) 145 | b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) 146 | return a_np, w_np, b_np 147 | 148 | a_np, w_np, b_np = get_ref_data() 149 | a = tvm.nd.array(a_np.astype(dtype), ctx) 150 | w = 
tvm.nd.array(w_np.astype(dtype), ctx) 151 | b = tvm.nd.array(np.zeros(get_const_tuple(B.shape)).astype(B.dtype), ctx) 152 | 153 | if remote is not None: 154 | func = convert_to_remote(func, remote) 155 | 156 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 157 | cost = time_f(a, w, b).mean 158 | 159 | try: 160 | if dtype == 'float32': 161 | np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-4) 162 | elif dtype == 'float16': 163 | np.testing.assert_allclose(b.asnumpy(), b_np, rtol=0.2) 164 | else: 165 | raise NotImplementedError 166 | except Exception as e: 167 | print(e) 168 | 169 | return cost, 2.0 * np.prod(b.shape) * kernel * kernel * in_channel / (1e9) / cost 170 | 171 | workloads = [ 172 | # vgg16 173 | # (1, 224, 3, 64, 3, 1, 1), 174 | # (1, 224, 64, 64, 3, 1, 1), 175 | # (1, 112, 64, 128, 3, 1, 1), 176 | # (1, 112,128, 128, 3, 1, 1), 177 | # (1, 56, 128, 256, 3, 1, 1), 178 | # (1, 56, 256, 256, 3, 1, 1), 179 | # (1, 28, 256, 512, 3, 1, 1), 180 | # (1, 28, 512, 512, 3, 1, 1), 181 | # (1, 14, 512, 512, 3, 1, 1), 182 | # (1, 7, 512, 512, 3, 1, 1), 183 | # 184 | # # resnet-18 185 | # (1, 224, 3, 64, 7, 3, 2), 186 | # (1, 56, 64, 64, 3, 1, 1), 187 | # (1, 56, 64, 64, 1, 0, 1), 188 | # (1, 56, 64, 128, 3, 1, 2), 189 | # (1, 56, 64, 128, 1, 0, 2), 190 | (1, 28, 128, 128, 3, 1, 1), 191 | # (1, 28, 128, 256, 3, 1, 2), 192 | # (1, 28, 128, 256, 1, 0, 2), 193 | # (1, 14, 256, 256, 3, 1, 1), 194 | # (1, 14, 256, 512, 3, 1, 2), 195 | # (1, 14, 256, 512, 1, 0, 2), 196 | # (1, 7, 512, 512, 3, 1, 1), 197 | # 198 | # mobilenet 199 | # (1, 224, 3, 32, 3, 1, 2), 200 | # (1, 112, 32, 64, 1, 0, 1), 201 | # (1, 56, 64, 128, 1, 0, 1), 202 | # (1, 56, 128, 128, 1, 0, 1), 203 | # (1, 28, 128, 256, 1, 0, 1), 204 | # (1, 28, 256, 256, 1, 0, 1), 205 | # (1, 14, 256, 512, 1, 0, 1), 206 | # (1, 14, 512, 512, 1, 0, 1), 207 | # (1, 7, 512, 1024, 1, 0, 1), 208 | # (1, 7, 1024, 1024, 1, 0, 1), 209 | ] 210 | 211 | def verify_workloads(ctx, n_times=1, target=None, 
target_host=None, remote=None): 212 | for item in workloads: 213 | cost, gflops = verify_conv2d_nchw(*item, ctx=ctx, n_times=n_times, target=target, 214 | target_host=target_host, remote=remote) 215 | print("%-30s %.6f %.6f" % (item, cost, gflops)) 216 | 217 | #def tune_workloads(ctx, n_times=1, target=None, target_host=None, remote=None): 218 | # ret = [] 219 | # for item in workloads: 220 | # cost, gflops, config = tune_conv2d_nchw(*item, ctx=ctx, target_host=target_host, remote=remote) 221 | # print(item, cost, gflops, config) 222 | # ret.append([item, config]) 223 | # for item in ret: 224 | # print(item, config) 225 | 226 | if __name__ == "__main__": 227 | host = os.environ["TVM_OPENCL_DEVICE_HOST"] 228 | port = 9090 229 | remote = rpc.connect(host, port) 230 | target_host = "llvm -target=aarch64-linux-gnu -mattr=+neon" 231 | 232 | verify_workloads(remote.cl(), 10, tvm.target.mali(), target_host, remote) 233 | 234 | -------------------------------------------------------------------------------- /layer-test/test_dense.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import tvm 6 | import topi 7 | from tvm.contrib import rpc, util 8 | from topi.util import get_const_tuple 9 | from tvm.contrib.pickle_memoize import memoize 10 | 11 | dtype = 'float16' 12 | 13 | USE_MANUAL_CODE = False 14 | @tvm.register_func 15 | def tvm_callback_opencl_postproc(code): 16 | if not os.path.exists("perf"): 17 | os.mkdir("perf") 18 | with open("generated.cl", 'w') as fout: 19 | fout.write(code) 20 | if USE_MANUAL_CODE: 21 | with open("manual.cl") as fin: 22 | code = fin.read() 23 | print(code) 24 | return code 25 | 26 | 27 | def convert_to_remote(func, remote): 28 | temp = util.tempdir() 29 | prefix = str(np.random.randint(1 << 31)) + "_" 30 | path_dso = temp.relpath(prefix + "tmp_func.tar") 31 | func.export_library(path_dso) 32 | 33 | remote.upload(path_dso) 34 | func = 
remote.load_module(prefix + "tmp_func.tar") 35 | return func 36 | 37 | 38 | def generate_tune_packs(item_list): 39 | ret = [] 40 | 41 | now = {} 42 | def dfs(depth): 43 | if depth == len(item_list): 44 | ret.append(now.copy()) 45 | return 46 | 47 | name = item_list[depth][0] 48 | for value in item_list[depth][1]: 49 | now[name] = value 50 | dfs(depth + 1) 51 | 52 | dfs(0) 53 | 54 | return ret 55 | 56 | 57 | def tune_dense(batch, hidden, out, ctx, 58 | n_times=1, target_host=None, remote=None): 59 | A = tvm.placeholder((1, hidden), dtype=dtype, name='A') 60 | B = tvm.placeholder((out, hidden), dtype=dtype, name='B') 61 | BIAS = tvm.placeholder((out,), dtype=dtype, name='bias') 62 | 63 | # generate static config 64 | tune_pack = generate_tune_packs([ 65 | # ["bn", [1, 2, 4, 8, 16]], 66 | # ["reuse", [1, 2, 4, 8]], 67 | ["num_thread", [1, 2, 4, 32, 64, 256]], 68 | ["unroll_step", [1, 2, 4, 5, 6, 16, 32]], 69 | ]) 70 | 71 | a_shape = get_const_tuple(A.shape) 72 | b_shape = get_const_tuple(B.shape) 73 | bias_shape = get_const_tuple(BIAS.shape) 74 | c_shape = (1, out) 75 | 76 | a_np = np.random.uniform(size=a_shape) 77 | b_np = np.random.uniform(size=b_shape) 78 | bias_np = np.random.uniform(size=bias_shape) 79 | c_np = np.random.uniform(size=c_shape) 80 | 81 | # search 82 | tic = time.time() 83 | best_cost = 1e9 84 | best_config = None 85 | for i, config in enumerate(tune_pack): 86 | with tvm.target.mali(): 87 | tvm.target.current_target().tune_config = config 88 | C = topi.nn.dense(A, B, BIAS) 89 | s = topi.generic.schedule_dense([C]) 90 | func = tvm.build(s, [A, B, BIAS, C], target_host=target_host) 91 | 92 | a = tvm.nd.array(a_np.astype(dtype), ctx=ctx) 93 | b = tvm.nd.array(b_np.astype(dtype), ctx=ctx) 94 | bias = tvm.nd.array(bias_np.astype(dtype), ctx=ctx) 95 | c = tvm.nd.array(c_np.astype(dtype), ctx=ctx) 96 | 97 | if remote is not None: 98 | func = convert_to_remote(func, remote) 99 | 100 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 101 | 
cost = time_f(a, b, bias, c).mean 102 | 103 | gflops = 2.0 * np.prod(b.shape) / (1e9) / cost 104 | if cost < best_cost: 105 | print(config, cost, gflops) 106 | best_cost = cost 107 | best_config = config 108 | 109 | if i % 20 == 0: 110 | print(i, len(tune_pack), time.time()- tic, (time.time() - tic) / (i+1)) 111 | 112 | try: 113 | np.testing.assert_allclose(np.dot(a_np, b_np.T) + bias_np, c.asnumpy(), rtol=1e-2) 114 | except Exception as e: 115 | pass 116 | print(e) 117 | 118 | return best_cost, 2.0 * np.prod(b.shape) / (1e9) / best_cost, best_config 119 | 120 | def verify_dense(batch, hidden, out, ctx, 121 | n_times=1, target_host=None, remote=None): 122 | A = tvm.placeholder((1, hidden), dtype=dtype, name='A') 123 | B = tvm.placeholder((out, hidden), dtype=dtype, name='B') 124 | bias = tvm.placeholder((out,), dtype=dtype, name='bias') 125 | 126 | with tvm.target.mali(): 127 | C = topi.nn.dense(A, B, bias) 128 | s = topi.generic.schedule_dense([C]) 129 | func = tvm.build(s, [A, B, bias, C], target_host=target_host) 130 | 131 | a_shape = get_const_tuple(A.shape) 132 | b_shape = get_const_tuple(B.shape) 133 | bias_shape = get_const_tuple(bias.shape) 134 | c_shape = get_const_tuple(C.shape) 135 | 136 | a_np = np.random.uniform(size=a_shape) 137 | b_np = np.random.uniform(size=b_shape) 138 | bias_np = np.random.uniform(size=bias_shape) 139 | c_np = np.random.uniform(size=c_shape) 140 | 141 | a = tvm.nd.array(a_np.astype(dtype), ctx=ctx) 142 | b = tvm.nd.array(b_np.astype(dtype), ctx=ctx) 143 | bias = tvm.nd.array(bias_np.astype(dtype), ctx=ctx) 144 | c = tvm.nd.array(np.zeros_like(c_np).astype(dtype), ctx=ctx) 145 | 146 | if remote is not None: 147 | func = convert_to_remote(func, remote) 148 | 149 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 150 | cost = time_f(a, b, bias, c).mean 151 | 152 | try: 153 | np.testing.assert_allclose(np.dot(a_np, b_np.T) + bias_np, c.asnumpy(), rtol=1e-1) 154 | except Exception as e: 155 | pass 156 | print(e) 157 
| 158 | return cost, 2.0 * np.prod(b.shape) / (1e9) / cost 159 | 160 | workloads = [ 161 | (1, 25088, 4096), 162 | # (1, 4096, 4096), 163 | # (1, 4096, 1000), 164 | # (1, 1024, 1000), 165 | ] 166 | 167 | def verify_workloads(ctx, n_times=1, target_host=None, remote=None): 168 | for item in workloads: 169 | cost, gflops = verify_dense(*item, ctx=ctx, n_times=n_times, target_host=target_host, remote=remote) 170 | print("%-30s %.6f %.6f" % (item, cost, gflops)) 171 | 172 | def tune_workloads(ctx, n_times=1, target_host=None, remote=None): 173 | for item in workloads: 174 | cost, gflops, config = tune_dense(*item, ctx=ctx, n_times=n_times, target_host=target_host, remote=remote) 175 | print(item, cost, gflops, config) 176 | 177 | if __name__ == "__main__": 178 | host = os.environ["TVM_OPENCL_DEVICE_HOST"] 179 | port = 9090 180 | remote = rpc.connect(host, port) 181 | target_host = "llvm -target=aarch64-linux-gnu" 182 | 183 | verify_workloads(remote.cl(), 10, target_host, remote) 184 | 185 | -------------------------------------------------------------------------------- /layer-test/test_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tvm 5 | import topi 6 | from tvm.contrib import rpc, util 7 | from topi.util import get_const_tuple 8 | from tvm.contrib.pickle_memoize import memoize 9 | 10 | dtype = 'float32' 11 | 12 | def convert_to_remote(func, remote): 13 | temp = util.tempdir() 14 | prefix = str(np.random.randint(1 << 31)) + "_" 15 | path_dso = temp.relpath(prefix + "tmp_func.tar") 16 | func.export_library(path_dso) 17 | 18 | remote.upload(path_dso) 19 | func = remote.load_module(prefix + "tmp_func.tar") 20 | return func 21 | 22 | 23 | def generate_tune_packs(item_list): 24 | ret = [] 25 | 26 | now = {} 27 | def dfs(depth): 28 | if depth == len(item_list): 29 | ret.append(now.copy()) 30 | return 31 | 32 | name = item_list[depth][0] 33 | for value in item_list[depth][1]: 34 | now[name] = value 35 | 
dfs(depth + 1) 36 | 37 | dfs(0) 38 | 39 | return ret 40 | 41 | USE_MANUAL_CODE = False 42 | @tvm.register_func 43 | def tvm_callback_opencl_postproc(code): 44 | if not os.path.exists("perf"): 45 | os.mkdir("perf") 46 | with open("generated.cl", 'w') as fout: 47 | fout.write(code) 48 | if USE_MANUAL_CODE: 49 | split = code.split("\n") 50 | code = '\n'.join(split) 51 | return code 52 | 53 | 54 | 55 | def tune_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, padding, stride, ctx, 56 | n_times=1, target_host=None, remote=None): 57 | in_height = in_width = in_size 58 | 59 | A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=dtype, name='data') 60 | W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=dtype, name='weight') 61 | 62 | # get verify data 63 | a_shape = get_const_tuple(A.shape) 64 | w_shape = get_const_tuple(W.shape) 65 | 66 | @memoize("topi.tests.test_topi_conv2d.verify_conv2d") 67 | def get_ref_data(): 68 | a_np = np.random.uniform(size=a_shape) 69 | w_np = np.random.uniform(size=w_shape) 70 | b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) 71 | return a_np, w_np, b_np 72 | 73 | a_np, w_np, b_np = get_ref_data() 74 | a = tvm.nd.array(a_np.astype(dtype), ctx) 75 | w = tvm.nd.array(w_np.astype(dtype), ctx) 76 | b = tvm.nd.array(np.zeros(b_np.shape).astype(dtype), ctx) 77 | 78 | # generate static config 79 | #tune_pack = generate_tune_packs([ 80 | # ["bn", [4]], 81 | # ["num_thread", [1, 2, 4, 8, 16]], 82 | # ["unroll_step", [1, 4, 16]], 83 | # ]) 84 | 85 | tune_pack = generate_tune_packs([ 86 | ["VH", [1, 2, 4]], 87 | ["VW", [1, 2, 4, 8]], 88 | ["VC", [1, 2, 4, 8]], 89 | ["num_thread", [1, 2, 4, 16, 32, 64]], 90 | ]) 91 | 92 | # search 93 | best_cost = 1e9 94 | best_config = None 95 | for config in reversed(tune_pack): 96 | with tvm.target.mali(): 97 | tvm.target.current_target().tune_config = config 98 | B = topi.nn.conv2d(A, W, stride, padding) 99 | s = topi.generic.schedule_conv2d_nchw([B]) 100 
| func = tvm.build(s, [A, W, B], target_host=target_host) 101 | 102 | if remote is not None: 103 | func = convert_to_remote(func, remote) 104 | 105 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 106 | cost = time_f(a, w, b).mean 107 | 108 | try: 109 | np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-4) 110 | except Exception as e: 111 | pass 112 | 113 | gflops = 2.0 * np.prod(b.shape) * kernel * kernel * in_channel /(1e9)/ cost 114 | print(config, cost, gflops) 115 | if cost < best_cost: 116 | best_cost = cost 117 | best_config = config 118 | 119 | return best_cost, 2.0 * np.prod(b.shape) * kernel * kernel * in_channel /(1e9)/ best_cost, best_config 120 | 121 | 122 | def verify_conv2d_nchw(batch, in_size, in_channel, channel_multiplier, kernel, padding, stride, ctx, 123 | n_times=1, target_host=None, remote=None): 124 | in_height = in_width = in_size 125 | 126 | A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=dtype, name='data') 127 | W = tvm.placeholder((in_channel, channel_multiplier, kernel, kernel), dtype=dtype, name='weight') 128 | 129 | with tvm.target.mali(): 130 | B = topi.nn.depthwise_conv2d_nchw(A, W, stride, padding) 131 | #B = topi.nn.relu(B) 132 | s = topi.generic.schedule_depthwise_conv2d_nchw([B]) 133 | func = tvm.build(s, [A, W, B], target_host=target_host) 134 | 135 | a_shape = get_const_tuple(A.shape) 136 | w_shape = get_const_tuple(W.shape) 137 | 138 | @memoize("topi.tests.test_topi_depthconv.verify_depthconv") 139 | def get_ref_data(): 140 | a_np = np.random.uniform(size=a_shape).astype('float32') 141 | w_np = np.random.uniform(size=w_shape).astype('float32') 142 | b_np = topi.testing.depthwise_conv2d_python_nchw(a_np, w_np, stride, padding) 143 | return a_np, w_np, b_np 144 | 145 | a_np, w_np, b_np = get_ref_data() 146 | a = tvm.nd.array(a_np.astype(dtype), ctx) 147 | w = tvm.nd.array(w_np.astype(dtype), ctx) 148 | b = tvm.nd.array(np.zeros(get_const_tuple(B.shape)).astype(dtype), ctx) 149 | 150 | 
if remote is not None: 151 | func = convert_to_remote(func, remote) 152 | 153 | time_f = func.time_evaluator(func.entry_name, ctx, number=n_times) 154 | cost = time_f(a, w, b).mean 155 | 156 | try: 157 | np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-1) 158 | except Exception as e: 159 | print(e) 160 | 161 | return cost, 2.0 * np.prod(b.shape) * kernel * kernel / (1e9) / cost 162 | 163 | workloads = [ 164 | # mobilenet 165 | (1, 112, 32, 1, 3, 1, 1), 166 | (1, 112, 64, 1, 3, 1, 2), 167 | (1, 56, 128, 1, 3, 1, 1), 168 | (1, 56, 128, 1, 3, 1, 2), 169 | (1, 28, 256, 1, 3, 1, 1), 170 | (1, 28, 256, 1, 3, 1, 2), 171 | (1, 14, 512, 1, 3, 1, 1), 172 | (1, 14, 512, 1, 3, 1, 2), 173 | (1, 7, 1024, 1, 3, 1, 1), 174 | ] 175 | 176 | def verify_workloads(ctx, n_times=1, target_host=None, remote=None): 177 | for item in workloads: 178 | cost, gflops = verify_conv2d_nchw(*item, ctx=ctx, n_times=n_times, target_host=target_host, remote=remote) 179 | print("%-30s %.6f %.6f" % (item, cost, gflops)) 180 | 181 | def tune_workloads(ctx, n_times=1, target_host=None, remote=None): 182 | for item in workloads: 183 | cost, gflops, config = tune_conv2d_nchw(*item, ctx=ctx, n_times=n_times, target_host=target_host, remote=remote) 184 | print(item, cost, gflops, config) 185 | 186 | if __name__ == "__main__": 187 | host = os.environ["TVM_OPENCL_DEVICE_HOST"] 188 | port = 9090 189 | remote = rpc.connect(host, port) 190 | target_host = "llvm -target=aarch64-linux-gnu -mattr=+neon" 191 | 192 | verify_workloads(remote.cl(), 10000, target_host, remote) 193 | 194 | -------------------------------------------------------------------------------- /layer-test/util.cc: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 3 | // read data from CLTensor 4 | void ReadTensor(const CLTensor *tensor, void *to, size_t size) { 5 | cl::CommandQueue &queue = CLScheduler::get().queue(); 6 | queue.enqueueReadBuffer(tensor->cl_buffer(), true, 0, size, to); 7 | queue.finish(); 8 | } 9 | 10 
| // write data to CLTensor 11 | void WriteTensor(CLTensor *tensor, const void *from, size_t size) { 12 | cl::CommandQueue &queue = CLScheduler::get().queue(); 13 | queue.enqueueWriteBuffer(tensor->cl_buffer(), true, 0, size, from); 14 | queue.finish(); 15 | } 16 | 17 | // transform dtype to format in arm compute 18 | Format DtypeToFormat(std::string dtype) { 19 | if (dtype == "float" || dtype == "float32") 20 | return Format::F32; 21 | else if (dtype == "float16") 22 | return Format::F16; 23 | else { 24 | std::cerr << "Unsupported type: " << dtype << std::endl; 25 | exit(-1); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /layer-test/util.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_COMPUTE_UTIL_H_ 2 | #define ARM_COMPUTE_UTIL_H_ 3 | 4 | #include "arm_compute/runtime/CL/CLFunctions.h" 5 | #include "arm_compute/core/Types.h" 6 | #include 7 | 8 | using namespace arm_compute; 9 | 10 | // read data from CLTensor 11 | void ReadTensor(const CLTensor *tensor, void *to, size_t size); 12 | 13 | // write data to CLTensor 14 | void WriteTensor(CLTensor *tensor, const void *from, size_t size); 15 | 16 | // transform dtype to format in arm compute 17 | Format DtypeToFormat(std::string); 18 | 19 | #endif // ARM_COMPUTE_UTIL_H_ 20 | 21 | -------------------------------------------------------------------------------- /mali_imagenet_bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | Benchmark inference speed on ImageNet 3 | Example (run on Firefly RK3399): 4 | python mali_imagenet_bench.py --target-host 'llvm -target=aarch64-linux-gnu' --host 192.168.0.100 --port 9090 --model mobilenet 5 | """ 6 | 7 | import time 8 | import argparse 9 | import numpy as np 10 | import tvm 11 | import nnvm.compiler 12 | import nnvm.testing 13 | from tvm.contrib import util, rpc 14 | from tvm.contrib import graph_runtime as runtime 15 | 16 | def
run_case(model, dtype): 17 | # load model 18 | if model == 'vgg16': 19 | net, params = nnvm.testing.vgg.get_workload(num_layers=16, 20 | batch_size=1, image_shape=image_shape, dtype=dtype) 21 | elif model == 'resnet18': 22 | net, params = nnvm.testing.resnet.get_workload(num_layers=18, 23 | batch_size=1, image_shape=image_shape, dtype=dtype) 24 | elif model == 'mobilenet': 25 | net, params = nnvm.testing.mobilenet.get_workload( 26 | batch_size=1, image_shape=image_shape, dtype=dtype) 27 | else: 28 | raise ValueError('no benchmark prepared for {}.'.format(model)) 29 | 30 | # compile 31 | opt_level = 2 if dtype == 'float32' else 1 32 | with nnvm.compiler.build_config(opt_level=opt_level): 33 | graph, lib, params = nnvm.compiler.build( 34 | net, tvm.target.mali(), shape={"data": data_shape}, params=params, 35 | dtype=dtype, target_host=args.target_host) 36 | 37 | # upload model to remote device 38 | tmp = util.tempdir() 39 | lib_fname = tmp.relpath('net.tar') 40 | lib.export_library(lib_fname) 41 | 42 | if args.host is not None: 43 | remote = rpc.connect(args.host, args.port) 44 | remote.upload(lib_fname) 45 | 46 | ctx = remote.cl(0) 47 | rlib = remote.load_module('net.tar') 48 | rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()} 49 | else: 50 | ctx = tvm.cl(0) 51 | rlib = lib 52 | rparams = params 53 | 54 | # create graph runtime 55 | module = runtime.create(graph, rlib, ctx) 56 | module.set_input('data', tvm.nd.array(np.random.uniform(size=(data_shape)).astype(dtype))) 57 | module.set_input(**rparams) 58 | 59 | # benchmark 60 | # print("============================================================") 61 | # print("model: %s, dtype: %s" % (model, dtype)) 62 | 63 | # number of runs for warm-up and test 64 | num_warmup = 10 65 | num_test = 60 66 | if model == 'mobilenet': # mobilenet is fast, so it needs more runs for a stable measurement 67 | num_warmup *= 5 68 | num_test *= 5 69 | 70 | # perform some warm up runs 71 | # print("warm up..") 72 | warm_up_timer =
module.module.time_evaluator("run", ctx, num_warmup) 73 | warm_up_timer() 74 | 75 | # test 76 | # print("test..") 77 | ftimer = module.module.time_evaluator("run", ctx, num_test) 78 | prof_res = ftimer() 79 | # print("cost per image: %.4fs" % prof_res.mean) 80 | 81 | print("backend: TVM-mali\tmodel: %s\tdtype: %s\tcost:%.4f" % (model, dtype, prof_res.mean)) 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--model', type=str, required=True, choices=['vgg16', 'resnet18', 'mobilenet', 'all'], 86 | help="The model type.") 87 | parser.add_argument('--dtype', type=str, default='float32', choices=['float16', 'float32']) 88 | parser.add_argument('--host', type=str, help="The host address of your arm device.", default=None) 89 | parser.add_argument('--port', type=int, help="The port number of your arm device", default=None) 90 | parser.add_argument('--target-host', type=str, help="The compilation target of host device.", default=None) 91 | args = parser.parse_args() 92 | 93 | # set parameter 94 | batch_size = 1 95 | num_classes = 1000 96 | image_shape = (3, 224, 224) 97 | 98 | # load model 99 | data_shape = (batch_size,) + image_shape 100 | out_shape = (batch_size, num_classes) 101 | 102 | if args.model == 'all': # test all 103 | for model in ['vgg16', 'resnet18', 'mobilenet']: 104 | for dtype in ['float32', 'float16']: 105 | run_case(model, dtype) 106 | time.sleep(10) 107 | 108 | else: # test single 109 | run_case(args.model, args.dtype) 110 | 111 | -------------------------------------------------------------------------------- /mxnet_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | import argparse 5 | 6 | import numpy as np 7 | 8 | import mxnet as mx 9 | from mxnet import gluon 10 | from mxnet.gluon.model_zoo.vision import get_model 11 | from mxnet.gluon.utils import download 12 | 13 | input_size = 224 14 | 15 | def 
test_module(model, dtype): 16 | assert dtype == 'float32' 17 | 18 | if model == 'vgg16': 19 | model_block = mx.gluon.model_zoo.vision.get_vgg(16, pretrained=False) 20 | elif model == 'mobilenet': 21 | model_block = mx.gluon.model_zoo.vision.get_mobilenet(1.0, pretrained=False) 22 | elif model == 'resnet18': 23 | model_block = mx.gluon.model_zoo.vision.get_resnet(version=1, num_layers=18, pretrained=False) 24 | else: 25 | raise RuntimeError("invalid model: " + model) 26 | model_block.collect_params().initialize(mx.init.Xavier()) 27 | 28 | # define input and test function 29 | x = mx.nd.array(np.zeros((1, 3, input_size, input_size))) 30 | def measure(n_time): 31 | out = model_block(x).asnumpy() 32 | tic = time.time() 33 | for i in range(n_time): 34 | out = model_block(x).asnumpy() 35 | cost = time.time() - tic 36 | return cost / n_time 37 | 38 | # benchmark 39 | # print("============================================================") 40 | # print("model: %s, dtype: %s" % (model, dtype)) 41 | 42 | num_warmup = 15 43 | num_test = 80 44 | if model == 'mobilenet': # mobilenet is fast, so it needs more runs for a stable measurement 45 | num_warmup *= 4 46 | num_test *= 4 47 | 48 | # warm up 49 | # print("warm up...") 50 | measure(num_warmup) 51 | 52 | # print("test..") 53 | cost = measure(num_test) 54 | # print("cost per image: %.4fs" % cost) 55 | 56 | print("backend: MXNet+OpenBLAS\tmodel: %s\tdtype: %s\tcost:%.4f" % (model, dtype, cost)) 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--model', type=str, required=True, choices=['vgg16', 'mobilenet', 'resnet18', 'all']) 61 | args = parser.parse_args() 62 | 63 | if args.model == 'all': 64 | for model in ['resnet18', 'mobilenet', 'vgg16']: 65 | test_module(model, 'float32') 66 | time.sleep(20) 67 | else: 68 | test_module(args.model, 'float32') 69 | 70 | -------------------------------------------------------------------------------- /results.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/merrymercy/tvm-mali/70f57626b45b484edcb9ef03e01e5ff3d63296ed/results.png -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | sudo /etc/init.d/lightdm stop 2 | echo performance | sudo tee /sys/class/misc/mali0/device/devfreq/ff9a0000.gpu/governor > /dev/null 3 | 4 | export PYTHONPATH=$(pwd)/nnvm/python:$(pwd)/nnvm/tvm/python:$(pwd)/nnvm/tvm/topi/python:$(pwd)/incubator-mxnet/python 5 | 6 | python mxnet_test.py --model all 7 | python mali_imagenet_bench.py --model all 8 | LD_LIBRARY_PATH=ComputeLibrary/build ./acl_test all 9 | 10 | -------------------------------------------------------------------------------- /spatial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merrymercy/tvm-mali/70f57626b45b484edcb9ef03e01e5ff3d63296ed/spatial.png --------------------------------------------------------------------------------
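The GFLOPS figure printed by `layer-test/test_depth.py` counts one multiply and one add per kernel tap for every output element, i.e. `2 * prod(output_shape) * kernel * kernel / 1e9 / cost`. The sketch below reproduces that arithmetic for the MobileNet depthwise workloads. It assumes (the field names are hypothetical, since the signature of `verify_conv2d_nchw` is not shown here) that each workload tuple reads `(batch, in_size, in_channel, channel_multiplier, kernel, pad, stride)` and that the output is a square NCHW tensor with `in_channel * channel_multiplier` channels.

```python
# Sketch of the FLOP count behind the GFLOPS column of layer-test/test_depth.py.
# ASSUMPTION: workload tuples are (batch, in_size, in_channel,
# channel_multiplier, kernel, pad, stride); these names are hypothetical.

def depthwise_out_size(in_size, kernel, pad, stride):
    # Standard convolution output-size formula for one spatial dimension.
    return (in_size + 2 * pad - kernel) // stride + 1

def depthwise_gflop(batch, in_size, in_channel, multiplier, kernel, pad, stride):
    out = depthwise_out_size(in_size, kernel, pad, stride)
    # Number of elements in the (assumed) square NCHW output tensor.
    out_elems = batch * in_channel * multiplier * out * out
    # One multiply and one add per kernel tap per output element.
    return 2.0 * out_elems * kernel * kernel / 1e9

def gflops(workload, cost_seconds):
    # Same shape as the test_depth.py formula: GFLOP divided by mean run time.
    return depthwise_gflop(*workload) / cost_seconds

if __name__ == "__main__":
    for wl in [(1, 112, 32, 1, 3, 1, 1), (1, 112, 64, 1, 3, 1, 2)]:
        print(wl, depthwise_gflop(*wl))
```

For the first workload the output stays 112x112 (stride 1, pad 1), giving 2 * 32 * 112 * 112 * 9 FLOPs, about 0.0072 GFLOP per run; a measured cost of 1 ms would therefore be reported as roughly 7.2 GFLOPS.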