├── .gitignore ├── LICENSE ├── README.md ├── conv-3x3s2p1c3x4-neonfma-2x2.c ├── conv-3x3s2p1c3x4-scalar-1x1.c ├── dwconv-3x3p1-neonfma.c ├── dwconv-3x3p1-scalar.c ├── dwconv-3x3s2p1-neonfma.c ├── dwconv-3x3s2p1-scalar.c ├── dwconv-5x5p2-neonfma.c ├── dwconv-5x5p2-scalar.c ├── dwconv-5x5s2p2-neonfma.c ├── dwconv-5x5s2p2-scalar.c ├── gavgpool-neon-x4.c ├── gavgpool-scalar-x1.c ├── spmm-16x1-neonfma-pipelined.c ├── spmm-16x2-neonfma.c ├── spmm-16x4-neonfma.c ├── spmm-8x1-scalar.c ├── spmm-8x2-scalar.c └── spmm-8x4-scalar.c /.gitignore: -------------------------------------------------------------------------------- 1 | *~ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | Copyright The Fast Sparse ConvNets Authors. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of the copyright holder nor the names of its contributors may be used to 16 | endorse or promote products derived from this software without specific 17 | prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 26 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cvpr2020 2 | This directory contains all of the ARM computational kernels used in the paper "Fast 3 | Sparse ConvNets" submitted to CVPR 2020. 4 | 5 | spmm-NxM-[scalar,neonfma].c - Sparse Matrix Multiplication with an unroll of N 6 | in the HW dimension and a block size of M in the output channel dimension. 7 | 8 | dwconv-KxKsSpP-[scalar,neonfma].c - Depthwise Convolution with a KxK filter, 9 | a stride of S and symmetric padding of P. 10 | 11 | gavgpool-[scalar,neon]-xN.c - Global average pooling unrolled over N rows. 12 | 13 | conv-3x3s2p1c3x4-[scalar,neonfma]-KxK.c - Full 3x3 stride-2 convolution with 14 | HWC input and CHW output. Operates on 3 input channels (the only supported count) and 4 15 | output channels in the inner loop. Produces a KxK block of output pixels per iteration. 16 | 17 | # TF-Lite Models 18 | 19 | The MobileNetV1 models have a block size of 4 in the last block, otherwise they are unstructured. A sparsity block here is a group of consecutive output channels that share one non-zero pattern; the sketch below illustrates the idea.
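As a rough illustration of what a block size of 4 means for the sparse 1x1-convolution (SpMM) weights, the sketch below multiplies a block-sparse weight matrix against one pixel's input channels: every stored column index is shared by 4 consecutive output channels, so the indexing overhead is amortized 4x. This is only a sketch with invented names (`block_sparse_matrix`, `spmm_1x4_pixel`); the kernels in this directory use their own packed encoding and process N pixels per iteration.

```c
#include <stddef.h>

typedef struct {
  size_t output_channels;       // multiple of 4 in this sketch
  const size_t* nnz_per_block;  // non-zeros per block row (one entry per 4 output channels)
  const size_t* input_indices;  // input-channel index of each stored block
  const float* values;          // 4 weights per stored block, one per output channel
} block_sparse_matrix;

// y[oc] = sum_k w[oc][idx[k]] * x[idx[k]] for a single pixel, 4 output channels at a time.
static void spmm_1x4_pixel(const block_sparse_matrix* m, const float* x, float* y) {
  const size_t* idx = m->input_indices;
  const float* val = m->values;
  for (size_t oc = 0; oc < m->output_channels; oc += 4) {
    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
    for (size_t k = 0; k < m->nnz_per_block[oc / 4]; k++) {
      const float xi = x[*idx++];   // one index load feeds 4 multiply-adds
      acc0 += val[0] * xi;
      acc1 += val[1] * xi;
      acc2 += val[2] * xi;
      acc3 += val[3] * xi;
      val += 4;
    }
    y[oc + 0] = acc0;
    y[oc + 1] = acc1;
    y[oc + 2] = acc2;
    y[oc + 3] = acc3;
  }
}
```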
20 | 21 | The MobileNetV2 models have a block size of 2 from block 11 onwards, otherwise they are unstructured. The exception is the width 1.8, 80% sparse model, which is unstructured throughout. 22 | 23 | The first full convolution and the final fully connected layer are both dense in all models. 24 | 25 | The EfficientNet models are fully unstructured and, unlike the models above, their final fully connected layer is sparse. 26 | 27 | | Model | Top-1 Accuracy | Sparsity | Latency (ms) on SD 835 | Download | 28 | |-------|:----------------:|:----------:|:---------------------:|:----------:| 29 | | MobileNetV1 .75 | 64.4% | 90% | 21 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_.75_12_90_64.4.tflite) 30 | | MobileNetV1 1.0 | 68.4% | 90% | 31 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_1.0_12_90_68.4.tflite) 31 | | MobileNetV1 1.4 | 72.0% | 90% | 58 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_1.4_12_90_72.0.tflite) 32 | | MobileNetV2 .8 | 65.2% | 85% | 26 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_.80_11-16b2_85_65_2.tflite) 33 | | Cache Aware MobileNetV2 1.0 | 69.7% | 85% | 33 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/humannasnet_1.0_x_85_69_7.tflite) 34 | | MobileNetV2 1.15 | 70.2% | 85% | 40 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_1.15_11-16b2_85_70_2.tflite) 35 | | MobileNetV2 1.4 | 72.0% | 85% | 54 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_1.4_11-16b2_85_72_0.tflite) 36 | | MobileNetV2 1.8 | 74.9% | 80% | 102 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_1.8_x_80_74.9.tflite) 37 | | MobileNetV2 2.0 | 74.5% | 85% | 93 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_2.0_11-16b2_85_74_5.tflite) 38 | | EfficientNet B0 | 75.1% | 80% | 80 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/en_b0_x_80_75_1.tflite) 39 | | EfficientNet B1 | 76.7% | 85% | 110 | [link](https://storage.googleapis.com/fast-convnets/tflite-models/en_b1_x_85_76.7.tflite) 40 | -------------------------------------------------------------------------------- /conv-3x3s2p1c3x4-neonfma-2x2.c: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree.
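//
// Direct 3x3, stride-2, padding-1 convolution micro-kernel that reads dense
// HWC input with exactly 3 input channels and writes spatially-packed CHW
// output. Each pass of the inner loop produces a 2x2 block of output pixels
// for a group of 4 output channels and applies the fused [min, max] clamp from
// `params`. Weights are expected packed in groups of 112 floats per 4 output
// channels: 4 biases followed by 3x3x3x4 kernel values (see the weight-layout
// note inside the main loop).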
5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__neonfma_2x2( 11 | size_t input_height, 12 | size_t input_width, 13 | size_t output_y_start, 14 | size_t output_y_end, 15 | const float* input, 16 | const float* zero, 17 | const float* weights, 18 | float* output, 19 | size_t input_padding_top, 20 | size_t output_channels, 21 | size_t output_height_stride, 22 | size_t output_channel_stride, 23 | const union f32_output_params params[restrict static 1]) 24 | { 25 | assert(input_width != 0); 26 | assert(output_y_end > output_y_start); 27 | assert(input_padding_top <= 1); 28 | assert(output_channels != 0); 29 | 30 | const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float); 31 | const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float); 32 | const size_t output_width = (input_width + 1) / 2; 33 | const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float); 34 | 35 | // Adjustment for padding processed below 36 | const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top)); 37 | const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride); 38 | const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride); 39 | const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride); 40 | const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride); 41 | float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start); 42 | float* output1 = (float*) ((uintptr_t) output0 + output_height_stride); 43 | 44 | if (output_y_start < input_padding_top) { 45 | i0 = zero; 46 | } 47 | 48 | const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); 49 | const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); 50 | 51 | for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) { 52 | const size_t input_y2 = output_y * 2 + 2 - input_padding_top; 53 | const size_t input_y4 = input_y2 + 2; 54 | if (input_y2 >= input_height) { 55 | i2 = zero; 56 | } 57 | if (input_y4 > input_height) { 58 | i3 = zero; 59 | } 60 | if (input_y4 >= input_height) { 61 | i4 = zero; 62 | } 63 | if (output_y + 2 > output_y_end) { 64 | output1 = output0; 65 | } 66 | 67 | const float* w = weights; 68 | size_t c = output_channels; 69 | float* o0c0 = output0; 70 | float* o1c0 = output1; 71 | float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride); 72 | float* o1c1 = (float*) ((uintptr_t) o1c0 + output_channel_stride); 73 | float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride); 74 | float* o1c2 = (float*) ((uintptr_t) o1c1 + output_channel_stride); 75 | float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride); 76 | float* o1c3 = (float*) ((uintptr_t) o1c2 + output_channel_stride); 77 | do { 78 | if (c < 2) { 79 | o0c1 = o0c0; 80 | o1c1 = o1c0; 81 | } 82 | if (c <= 2) { 83 | o0c2 = o0c1; 84 | o1c2 = o1c1; 85 | } 86 | if (c < 4) { 87 | o0c3 = o0c2; 88 | o1c3 = o1c2; 89 | } 90 | 91 | // viMx0 = ( iM0c2, iM0c1, iM0c0, --- ) 92 | float32x4_t vi0x0 = vmovq_n_f32(0.0f); 93 | float32x4_t vi1x0 = vmovq_n_f32(0.0f); 94 | float32x4_t vi2x0 = vmovq_n_f32(0.0f); 95 | float32x4_t vi3x0 = vmovq_n_f32(0.0f); 96 | float32x4_t vi4x0 = vmovq_n_f32(0.0f); 97 | 98 | size_t iw = input_width; 99 | for (; iw >= 4; iw -= 4) { 100 | float32x4_t vo0x0 = vld1q_f32(w); 101 | float32x4_t vo1x0 = vo0x0; 102 | float32x4_t vo0x1 = vo0x0; 103 | float32x4_t vo1x1 = vo0x0; 104 | 
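      // Weight layout for this group of 4 output channels (112 floats in total):
      // w[0..3] hold the 4 biases (already loaded into the accumulators above);
      // the remaining 108 values are ordered by kernel column (kx = 0..2), then
      // input channel (c = 0..2), then kernel row (ky = 0..2), with 4 consecutive
      // per-output-channel weights at each position: vk00c0 = w[4..7],
      // vk10c0 = w[8..11], ..., vk22c2 = w[108..111].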
105 | const float32x4_t vk00c0 = vld1q_f32(w + 4); 106 | 107 | // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 ) 108 | const float32x4_t vi0x1 = vld1q_f32(i0); i0 += 4; 109 | const float32x4_t vi1x1 = vld1q_f32(i1); i1 += 4; 110 | const float32x4_t vi2x1 = vld1q_f32(i2); i2 += 4; 111 | const float32x4_t vi3x1 = vld1q_f32(i3); i3 += 4; 112 | const float32x4_t vi4x1 = vld1q_f32(i4); i4 += 4; 113 | 114 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1); 115 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1); 116 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3); 117 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3); 118 | 119 | const float32x4_t vk10c0 = vld1q_f32(w + 8); 120 | 121 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1); 122 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1); 123 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3); 124 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3); 125 | 126 | const float32x4_t vk20c0 = vld1q_f32(w + 12); 127 | 128 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1); 129 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1); 130 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3); 131 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3); 132 | 133 | const float32x4_t vk00c1 = vld1q_f32(w + 16); 134 | 135 | // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 ) 136 | const float32x4_t vi0x2 = vld1q_f32(i0); i0 += 4; 137 | const float32x4_t vi1x2 = vld1q_f32(i1); i1 += 4; 138 | const float32x4_t vi2x2 = vld1q_f32(i2); i2 += 4; 139 | const float32x4_t vi3x2 = vld1q_f32(i3); i3 += 4; 140 | const float32x4_t vi4x2 = vld1q_f32(i4); i4 += 4; 141 | 142 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2); 143 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2); 144 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0); 145 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0); 146 | 147 | const float32x4_t vk10c1 = vld1q_f32(w + 20); 148 | 149 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2); 150 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2); 151 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0); 152 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0); 153 | 154 | const float32x4_t vk20c1 = vld1q_f32(w + 24); 155 | 156 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2); 157 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2); 158 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0); 159 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0); 160 | 161 | const float32x4_t vk00c2 = vld1q_f32(w + 28); 162 | 163 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3); 164 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3); 165 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1); 166 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1); 167 | 168 | const float32x4_t vk10c2 = vld1q_f32(w + 32); 169 | 170 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3); 171 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3); 172 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1); 173 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1); 174 | 175 | const float32x4_t vk20c2 = vld1q_f32(w + 36); 176 | 177 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3); 178 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3); 179 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1); 180 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1); 181 | 182 | const float32x4_t vk01c0 = vld1q_f32(w + 40); 183 | 184 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0); 185 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0); 186 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2); 187 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2); 188 | 189 | 
const float32x4_t vk11c0 = vld1q_f32(w + 44); 190 | 191 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0); 192 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0); 193 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2); 194 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2); 195 | 196 | const float32x4_t vk21c0 = vld1q_f32(w + 48); 197 | 198 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0); 199 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0); 200 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2); 201 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2); 202 | 203 | const float32x4_t vk01c1 = vld1q_f32(w + 52); 204 | 205 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1); 206 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1); 207 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3); 208 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3); 209 | 210 | const float32x4_t vk11c1 = vld1q_f32(w + 56); 211 | 212 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1); 213 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1); 214 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3); 215 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3); 216 | 217 | const float32x4_t vk21c1 = vld1q_f32(w + 60); 218 | 219 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1); 220 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1); 221 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3); 222 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3); 223 | 224 | const float32x4_t vk01c2 = vld1q_f32(w + 64); 225 | 226 | // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 ) 227 | const float32x4_t vi0x3 = vld1q_f32(i0); i0 += 4; 228 | const float32x4_t vi1x3 = vld1q_f32(i1); i1 += 4; 229 | const float32x4_t vi2x3 = vld1q_f32(i2); i2 += 4; 230 | const float32x4_t vi3x3 = vld1q_f32(i3); i3 += 4; 231 | const float32x4_t vi4x3 = vld1q_f32(i4); i4 += 4; 232 | 233 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2); 234 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2); 235 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0); 236 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0); 237 | 238 | const float32x4_t vk11c2 = vld1q_f32(w + 68); 239 | 240 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2); 241 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2); 242 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0); 243 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0); 244 | 245 | const float32x4_t vk21c2 = vld1q_f32(w + 72); 246 | 247 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2); 248 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2); 249 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0); 250 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0); 251 | 252 | const float32x4_t vk02c0 = vld1q_f32(w + 76); 253 | 254 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3); 255 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3); 256 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c0, vi0x3, 1); 257 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c0, vi2x3, 1); 258 | 259 | const float32x4_t vk12c0 = vld1q_f32(w + 80); 260 | 261 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3); 262 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3); 263 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c0, vi1x3, 1); 264 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c0, vi3x3, 1); 265 | 266 | const float32x4_t vk22c0 = vld1q_f32(w + 84); 267 | 268 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3); 269 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3); 270 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c0, vi2x3, 1); 271 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c0, vi4x3, 1); 272 | 273 | const float32x4_t vk02c1 = vld1q_f32(w + 88); 274 | 
275 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0); 276 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0); 277 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c1, vi0x3, 2); 278 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c1, vi2x3, 2); 279 | 280 | const float32x4_t vk12c1 = vld1q_f32(w + 92); 281 | 282 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0); 283 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0); 284 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c1, vi1x3, 2); 285 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c1, vi3x3, 2); 286 | 287 | const float32x4_t vk22c1 = vld1q_f32(w + 96); 288 | 289 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0); 290 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0); 291 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c1, vi2x3, 2); 292 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c1, vi4x3, 2); 293 | 294 | const float32x4_t vk02c2 = vld1q_f32(w + 100); 295 | 296 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1); 297 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1); 298 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c2, vi0x3, 3); 299 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c2, vi2x3, 3); 300 | 301 | const float32x4_t vk12c2 = vld1q_f32(w + 104); 302 | 303 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1); 304 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1); 305 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c2, vi1x3, 3); 306 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c2, vi3x3, 3); 307 | 308 | const float32x4_t vk22c2 = vld1q_f32(w + 108); 309 | 310 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1); 311 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1); 312 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c2, vi2x3, 3); 313 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c2, vi4x3, 3); 314 | 315 | vi0x0 = vi0x3; 316 | vi1x0 = vi1x3; 317 | vi2x0 = vi2x3; 318 | vi3x0 = vi3x3; 319 | vi4x0 = vi4x3; 320 | 321 | vo0x0 = vmaxq_f32(vo0x0, vmin); 322 | vo1x0 = vmaxq_f32(vo1x0, vmin); 323 | vo0x1 = vmaxq_f32(vo0x1, vmin); 324 | vo1x1 = vmaxq_f32(vo1x1, vmin); 325 | 326 | vo0x0 = vminq_f32(vo0x0, vmax); 327 | vo1x0 = vminq_f32(vo1x0, vmax); 328 | vo0x1 = vminq_f32(vo0x1, vmax); 329 | vo1x1 = vminq_f32(vo1x1, vmax); 330 | 331 | const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1); 332 | const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1); 333 | const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1); 334 | const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1); 335 | 336 | // Always 2+ output width elements remaining 337 | vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2; 338 | vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2; 339 | vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2; 340 | vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2; 341 | 342 | vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2; 343 | vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2; 344 | vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2; 345 | vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2; 346 | } 347 | assert(iw < 4); 348 | if (iw != 0) { 349 | float32x4_t vo0x0 = vld1q_f32(w); 350 | float32x4_t vo1x0 = vo0x0; 351 | float32x4_t vo0x1 = vo0x0; 352 | float32x4_t vo1x1 = vo0x0; 353 | 354 | const float32x4_t vk00c0 = vld1q_f32(w + 4); 355 | 356 | // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 ) 357 | float32x4_t vi0x1 = vld1q_f32(i0); 358 | float32x4_t vi1x1 = vld1q_f32(i1); 359 | float32x4_t vi2x1 = vld1q_f32(i2); 360 | float32x4_t vi3x1 = vld1q_f32(i3); 361 | float32x4_t vi4x1 = vld1q_f32(i4); 362 | 363 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1); 364 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1); 365 | if (iw > 2) { 366 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3); 367 | vo1x1 
= vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3); 368 | } 369 | 370 | const float32x4_t vk10c0 = vld1q_f32(w + 8); 371 | 372 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1); 373 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1); 374 | if (iw > 2) { 375 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3); 376 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3); 377 | } 378 | 379 | const float32x4_t vk20c0 = vld1q_f32(w + 12); 380 | 381 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1); 382 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1); 383 | if (iw > 2) { 384 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3); 385 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3); 386 | } 387 | 388 | const float32x4_t vk00c1 = vld1q_f32(w + 16); 389 | 390 | float32x4_t vi0x2 = vmovq_n_f32(0.0f); 391 | float32x4_t vi1x2 = vmovq_n_f32(0.0f); 392 | float32x4_t vi2x2 = vmovq_n_f32(0.0f); 393 | float32x4_t vi3x2 = vmovq_n_f32(0.0f); 394 | float32x4_t vi4x2 = vmovq_n_f32(0.0f); 395 | if (iw >= 2) { 396 | // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 ) 397 | vi0x2 = vld1q_f32(i0 + 4); 398 | vi1x2 = vld1q_f32(i1 + 4); 399 | vi2x2 = vld1q_f32(i2 + 4); 400 | vi3x2 = vld1q_f32(i3 + 4); 401 | vi4x2 = vld1q_f32(i4 + 4); 402 | } 403 | 404 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2); 405 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2); 406 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0); 407 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0); 408 | 409 | const float32x4_t vk10c1 = vld1q_f32(w + 20); 410 | 411 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2); 412 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2); 413 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0); 414 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0); 415 | 416 | const float32x4_t vk20c1 = vld1q_f32(w + 24); 417 | 418 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2); 419 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2); 420 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0); 421 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0); 422 | 423 | const float32x4_t vk00c2 = vld1q_f32(w + 28); 424 | 425 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3); 426 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3); 427 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1); 428 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1); 429 | 430 | const float32x4_t vk10c2 = vld1q_f32(w + 32); 431 | 432 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3); 433 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3); 434 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1); 435 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1); 436 | 437 | const float32x4_t vk20c2 = vld1q_f32(w + 36); 438 | 439 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3); 440 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3); 441 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1); 442 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1); 443 | 444 | const float32x4_t vk01c0 = vld1q_f32(w + 40); 445 | 446 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0); 447 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0); 448 | if (iw > 2) { 449 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2); 450 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2); 451 | } 452 | 453 | const float32x4_t vk11c0 = vld1q_f32(w + 44); 454 | 455 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0); 456 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0); 457 | if (iw > 2) { 458 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2); 459 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2); 460 | } 461 | 462 | const float32x4_t vk21c0 = 
vld1q_f32(w + 48); 463 | 464 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0); 465 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0); 466 | if (iw > 2) { 467 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2); 468 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2); 469 | } 470 | 471 | const float32x4_t vk01c1 = vld1q_f32(w + 52); 472 | 473 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1); 474 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1); 475 | if (iw > 2) { 476 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3); 477 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3); 478 | } 479 | 480 | const float32x4_t vk11c1 = vld1q_f32(w + 56); 481 | 482 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1); 483 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1); 484 | if (iw > 2) { 485 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3); 486 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3); 487 | } 488 | 489 | const float32x4_t vk21c1 = vld1q_f32(w + 60); 490 | 491 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1); 492 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1); 493 | if (iw > 2) { 494 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3); 495 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3); 496 | } 497 | 498 | const float32x4_t vk01c2 = vld1q_f32(w + 64); 499 | 500 | float32x4_t vi0x3 = vmovq_n_f32(0.0f); 501 | float32x4_t vi1x3 = vmovq_n_f32(0.0f); 502 | float32x4_t vi2x3 = vmovq_n_f32(0.0f); 503 | float32x4_t vi3x3 = vmovq_n_f32(0.0f); 504 | float32x4_t vi4x3 = vmovq_n_f32(0.0f); 505 | if (iw > 2) { 506 | // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 ) 507 | vi0x3 = vld1q_lane_f32(i0 + 8, vi0x3, 0); 508 | vi1x3 = vld1q_lane_f32(i1 + 8, vi1x3, 0); 509 | vi2x3 = vld1q_lane_f32(i2 + 8, vi2x3, 0); 510 | vi3x3 = vld1q_lane_f32(i3 + 8, vi3x3, 0); 511 | vi4x3 = vld1q_lane_f32(i4 + 8, vi4x3, 0); 512 | } 513 | 514 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2); 515 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2); 516 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0); 517 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0); 518 | 519 | const float32x4_t vk11c2 = vld1q_f32(w + 68); 520 | 521 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2); 522 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2); 523 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0); 524 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0); 525 | 526 | const float32x4_t vk21c2 = vld1q_f32(w + 72); 527 | 528 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2); 529 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2); 530 | vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0); 531 | vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0); 532 | 533 | if (iw >= 2) { 534 | const float32x4_t vk02c0 = vld1q_f32(w + 76); 535 | 536 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3); 537 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3); 538 | 539 | const float32x4_t vk12c0 = vld1q_f32(w + 80); 540 | 541 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3); 542 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3); 543 | 544 | const float32x4_t vk22c0 = vld1q_f32(w + 84); 545 | 546 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3); 547 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3); 548 | 549 | const float32x4_t vk02c1 = vld1q_f32(w + 88); 550 | 551 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0); 552 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0); 553 | 554 | const float32x4_t vk12c1 = vld1q_f32(w + 92); 555 | 556 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0); 557 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0); 558 | 559 | const 
float32x4_t vk22c1 = vld1q_f32(w + 96); 560 | 561 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0); 562 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0); 563 | 564 | const float32x4_t vk02c2 = vld1q_f32(w + 100); 565 | 566 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1); 567 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1); 568 | 569 | const float32x4_t vk12c2 = vld1q_f32(w + 104); 570 | 571 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1); 572 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1); 573 | 574 | const float32x4_t vk22c2 = vld1q_f32(w + 108); 575 | 576 | vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1); 577 | vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1); 578 | } 579 | 580 | vo0x0 = vmaxq_f32(vo0x0, vmin); 581 | vo1x0 = vmaxq_f32(vo1x0, vmin); 582 | vo0x1 = vmaxq_f32(vo0x1, vmin); 583 | vo1x1 = vmaxq_f32(vo1x1, vmin); 584 | 585 | vo0x0 = vminq_f32(vo0x0, vmax); 586 | vo1x0 = vminq_f32(vo1x0, vmax); 587 | vo0x1 = vminq_f32(vo0x1, vmax); 588 | vo1x1 = vminq_f32(vo1x1, vmax); 589 | 590 | if (iw == 3) { 591 | // Exactly 2 output width elements remaining 592 | const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1); 593 | const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1); 594 | const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1); 595 | const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1); 596 | 597 | vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2; 598 | vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2; 599 | vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2; 600 | vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2; 601 | 602 | vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2; 603 | vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2; 604 | vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2; 605 | vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2; 606 | } else { 607 | // Exactly 1 output width element remaining 608 | 609 | vst1q_lane_f32(o1c0, vo1x0, 0); o1c0 += 1; 610 | vst1q_lane_f32(o1c1, vo1x0, 1); o1c1 += 1; 611 | vst1q_lane_f32(o1c2, vo1x0, 2); o1c2 += 1; 612 | vst1q_lane_f32(o1c3, vo1x0, 3); o1c3 += 1; 613 | 614 | vst1q_lane_f32(o0c0, vo0x0, 0); o0c0 += 1; 615 | vst1q_lane_f32(o0c1, vo0x0, 1); o0c1 += 1; 616 | vst1q_lane_f32(o0c2, vo0x0, 2); o0c2 += 1; 617 | vst1q_lane_f32(o0c3, vo0x0, 3); o0c3 += 1; 618 | } 619 | } 620 | // Move output pointers back to the position of the first pixel in a row, 621 | // and forward to the next block of output channels. 
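      // output_channel_increment == 4 * output_channel_stride - output_width * sizeof(float),
      // so each pointer moves past the 4 channel planes just written and back by
      // the row of output_width pixels that was advanced above.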
622 | o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment); 623 | o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment); 624 | o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment); 625 | o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment); 626 | o1c0 = (float*) ((uintptr_t) o1c0 + output_channel_increment); 627 | o1c1 = (float*) ((uintptr_t) o1c1 + output_channel_increment); 628 | o1c2 = (float*) ((uintptr_t) o1c2 + output_channel_increment); 629 | o1c3 = (float*) ((uintptr_t) o1c3 + output_channel_increment); 630 | // Revert input pointers to the position of the first pixel in a row 631 | i0 = (const float*) ((uintptr_t) i0 - input_width_increment); 632 | i1 = (const float*) ((uintptr_t) i1 - input_width_increment); 633 | i2 = (const float*) ((uintptr_t) i2 - input_width_increment); 634 | i3 = (const float*) ((uintptr_t) i3 - input_width_increment); 635 | i4 = (const float*) ((uintptr_t) i4 - input_width_increment); 636 | // Move to the block of weights for the next 4 output channels 637 | w += 112; 638 | c = doz(c, 4); 639 | } while (c != 0); 640 | // Move output pointers forward to the next two rows 641 | output0 = (float*) ((uintptr_t) output1 + output_height_stride); 642 | output1 = (float*) ((uintptr_t) output0 + output_height_stride); 643 | // Move input pointers forward to the next four rows 644 | i0 = i4; 645 | i1 = (const float*) ((uintptr_t) i0 + input_height_stride); 646 | i2 = (const float*) ((uintptr_t) i1 + input_height_stride); 647 | i3 = (const float*) ((uintptr_t) i2 + input_height_stride); 648 | i4 = (const float*) ((uintptr_t) i3 + input_height_stride); 649 | } 650 | } 651 | -------------------------------------------------------------------------------- /conv-3x3s2p1c3x4-scalar-1x1.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
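//
// Scalar reference version of the HWC -> spatially-packed CHW 3x3 stride-2
// convolution above: it produces one output row and one output pixel per inner
// iteration (hence the "1x1" suffix), still with 3 input channels, 4 output
// channels per weight group, and the same 112-float weight layout as the
// NEON/FMA 2x2 kernel.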
5 | 6 | #include 7 | 8 | void f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__scalar_1x1( 9 | size_t input_height, 10 | size_t input_width, 11 | size_t output_y_start, 12 | size_t output_y_end, 13 | const float* input, 14 | const float* zero, 15 | const float* weights, 16 | float* output, 17 | size_t input_padding_top, 18 | size_t output_channels, 19 | size_t output_height_stride, 20 | size_t output_channel_stride, 21 | const union f32_output_params params[restrict static 1]) 22 | { 23 | assert(input_width != 0); 24 | assert(output_y_end > output_y_start); 25 | assert(input_padding_top <= 1); 26 | assert(output_channels != 0); 27 | 28 | const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float); 29 | const size_t input_width_increment = round_down_po2(input_width, 2) * 3 /* channels */ * sizeof(float); 30 | const size_t output_width = (input_width + 1) / 2; 31 | const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float); 32 | 33 | // Adjustment for padding processed below 34 | const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top)); 35 | const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride); 36 | const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride); 37 | float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start); 38 | 39 | if (output_y_start < input_padding_top) { 40 | i0 = zero; 41 | } 42 | 43 | const float voutput_max = params->scalar.max; 44 | const float voutput_min = params->scalar.min; 45 | 46 | for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 1) { 47 | const size_t input_y2 = output_y * 2 + 2 - input_padding_top; 48 | if (input_y2 >= input_height) { 49 | i2 = zero; 50 | } 51 | 52 | const float* w = weights; 53 | size_t c = output_channels; 54 | float* o0c0 = output0; 55 | float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride); 56 | float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride); 57 | float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride); 58 | do { 59 | if (c < 2) { 60 | o0c1 = o0c0; 61 | } 62 | if (c <= 2) { 63 | o0c2 = o0c1; 64 | } 65 | if (c < 4) { 66 | o0c3 = o0c2; 67 | } 68 | 69 | // Left edge padding 70 | float vr0c0 = 0.f; 71 | float vr0c1 = 0.f; 72 | float vr0c2 = 0.f; 73 | float vr1c0 = 0.f; 74 | float vr1c1 = 0.f; 75 | float vr1c2 = 0.f; 76 | float vr2c0 = 0.f; 77 | float vr2c1 = 0.f; 78 | float vr2c2 = 0.f; 79 | 80 | size_t iw = input_width; 81 | for (; iw >= 2; iw -= 2) { 82 | // start with biases 83 | float vc0_out = w[0]; 84 | float vc1_out = w[1]; 85 | float vc2_out = w[2]; 86 | float vc3_out = w[3]; 87 | 88 | const float vk00ic0oc0 = w[4]; 89 | const float vk00ic0oc1 = w[5]; 90 | const float vk00ic0oc2 = w[6]; 91 | const float vk00ic0oc3 = w[7]; 92 | 93 | vc0_out += vk00ic0oc0 * vr0c0; 94 | vc1_out += vk00ic0oc1 * vr0c0; 95 | vc2_out += vk00ic0oc2 * vr0c0; 96 | vc3_out += vk00ic0oc3 * vr0c0; 97 | 98 | const float vk10ic0oc0 = w[8]; 99 | const float vk10ic0oc1 = w[9]; 100 | const float vk10ic0oc2 = w[10]; 101 | const float vk10ic0oc3 = w[11]; 102 | 103 | vc0_out += vk10ic0oc0 * vr1c0; 104 | vc1_out += vk10ic0oc1 * vr1c0; 105 | vc2_out += vk10ic0oc2 * vr1c0; 106 | vc3_out += vk10ic0oc3 * vr1c0; 107 | 108 | const float vk20ic0oc0 = w[12]; 109 | const float vk20ic0oc1 = w[13]; 110 | const float vk20ic0oc2 = w[14]; 111 | const float vk20ic0oc3 = w[15]; 112 | 113 | vc0_out += vk20ic0oc0 * vr2c0; 114 | vc1_out += vk20ic0oc1 * 
vr2c0; 115 | vc2_out += vk20ic0oc2 * vr2c0; 116 | vc3_out += vk20ic0oc3 * vr2c0; 117 | 118 | const float vk00ic1oc0 = w[16]; 119 | const float vk00ic1oc1 = w[17]; 120 | const float vk00ic1oc2 = w[18]; 121 | const float vk00ic1oc3 = w[19]; 122 | 123 | 124 | vc0_out += vk00ic1oc0 * vr0c1; 125 | vc1_out += vk00ic1oc1 * vr0c1; 126 | vc2_out += vk00ic1oc2 * vr0c1; 127 | vc3_out += vk00ic1oc3 * vr0c1; 128 | 129 | const float vk10ic1oc0 = w[20]; 130 | const float vk10ic1oc1 = w[21]; 131 | const float vk10ic1oc2 = w[22]; 132 | const float vk10ic1oc3 = w[23]; 133 | 134 | vc0_out += vk10ic1oc0 * vr1c1; 135 | vc1_out += vk10ic1oc1 * vr1c1; 136 | vc2_out += vk10ic1oc2 * vr1c1; 137 | vc3_out += vk10ic1oc3 * vr1c1; 138 | 139 | const float vk20ic1oc0 = w[24]; 140 | const float vk20ic1oc1 = w[25]; 141 | const float vk20ic1oc2 = w[26]; 142 | const float vk20ic1oc3 = w[27]; 143 | 144 | vc0_out += vk20ic1oc0 * vr2c1; 145 | vc1_out += vk20ic1oc1 * vr2c1; 146 | vc2_out += vk20ic1oc2 * vr2c1; 147 | vc3_out += vk20ic1oc3 * vr2c1; 148 | 149 | const float vk00ic2oc0 = w[28]; 150 | const float vk00ic2oc1 = w[29]; 151 | const float vk00ic2oc2 = w[30]; 152 | const float vk00ic2oc3 = w[31]; 153 | 154 | vc0_out += vk00ic2oc0 * vr0c2; 155 | vc1_out += vk00ic2oc1 * vr0c2; 156 | vc2_out += vk00ic2oc2 * vr0c2; 157 | vc3_out += vk00ic2oc3 * vr0c2; 158 | 159 | const float vk10ic2oc0 = w[32]; 160 | const float vk10ic2oc1 = w[33]; 161 | const float vk10ic2oc2 = w[34]; 162 | const float vk10ic2oc3 = w[35]; 163 | 164 | vc0_out += vk10ic2oc0 * vr1c2; 165 | vc1_out += vk10ic2oc1 * vr1c2; 166 | vc2_out += vk10ic2oc2 * vr1c2; 167 | vc3_out += vk10ic2oc3 * vr1c2; 168 | 169 | const float vk20ic2oc0 = w[36]; 170 | const float vk20ic2oc1 = w[37]; 171 | const float vk20ic2oc2 = w[38]; 172 | const float vk20ic2oc3 = w[39]; 173 | 174 | vc0_out += vk20ic2oc0 * vr2c2; 175 | vc1_out += vk20ic2oc1 * vr2c2; 176 | vc2_out += vk20ic2oc2 * vr2c2; 177 | vc3_out += vk20ic2oc3 * vr2c2; 178 | 179 | const float vk01ic0oc0 = w[40]; 180 | const float vk01ic0oc1 = w[41]; 181 | const float vk01ic0oc2 = w[42]; 182 | const float vk01ic0oc3 = w[43]; 183 | 184 | const float i00 = i0[0]; 185 | 186 | vc0_out += vk01ic0oc0 * i00; 187 | vc1_out += vk01ic0oc1 * i00; 188 | vc2_out += vk01ic0oc2 * i00; 189 | vc3_out += vk01ic0oc3 * i00; 190 | 191 | const float vk11ic0oc0 = w[44]; 192 | const float vk11ic0oc1 = w[45]; 193 | const float vk11ic0oc2 = w[46]; 194 | const float vk11ic0oc3 = w[47]; 195 | 196 | const float i10 = i1[0]; 197 | 198 | vc0_out += vk11ic0oc0 * i10; 199 | vc1_out += vk11ic0oc1 * i10; 200 | vc2_out += vk11ic0oc2 * i10; 201 | vc3_out += vk11ic0oc3 * i10; 202 | 203 | const float vk21ic0oc0 = w[48]; 204 | const float vk21ic0oc1 = w[49]; 205 | const float vk21ic0oc2 = w[50]; 206 | const float vk21ic0oc3 = w[51]; 207 | 208 | const float i20 = i2[0]; 209 | 210 | vc0_out += vk21ic0oc0 * i20; 211 | vc1_out += vk21ic0oc1 * i20; 212 | vc2_out += vk21ic0oc2 * i20; 213 | vc3_out += vk21ic0oc3 * i20; 214 | 215 | const float vk01ic1oc0 = w[52]; 216 | const float vk01ic1oc1 = w[53]; 217 | const float vk01ic1oc2 = w[54]; 218 | const float vk01ic1oc3 = w[55]; 219 | 220 | const float i01 = i0[1]; 221 | 222 | vc0_out += vk01ic1oc0 * i01; 223 | vc1_out += vk01ic1oc1 * i01; 224 | vc2_out += vk01ic1oc2 * i01; 225 | vc3_out += vk01ic1oc3 * i01; 226 | 227 | const float vk11ic1oc0 = w[56]; 228 | const float vk11ic1oc1 = w[57]; 229 | const float vk11ic1oc2 = w[58]; 230 | const float vk11ic1oc3 = w[59]; 231 | 232 | const float i11 = i1[1]; 233 | 234 | vc0_out += vk11ic1oc0 * 
i11; 235 | vc1_out += vk11ic1oc1 * i11; 236 | vc2_out += vk11ic1oc2 * i11; 237 | vc3_out += vk11ic1oc3 * i11; 238 | 239 | const float vk21ic1oc0 = w[60]; 240 | const float vk21ic1oc1 = w[61]; 241 | const float vk21ic1oc2 = w[62]; 242 | const float vk21ic1oc3 = w[63]; 243 | 244 | const float i21 = i2[1]; 245 | 246 | vc0_out += vk21ic1oc0 * i21; 247 | vc1_out += vk21ic1oc1 * i21; 248 | vc2_out += vk21ic1oc2 * i21; 249 | vc3_out += vk21ic1oc3 * i21; 250 | 251 | const float vk01ic2oc0 = w[64]; 252 | const float vk01ic2oc1 = w[65]; 253 | const float vk01ic2oc2 = w[66]; 254 | const float vk01ic2oc3 = w[67]; 255 | 256 | const float i02 = i0[2]; 257 | 258 | vc0_out += vk01ic2oc0 * i02; 259 | vc1_out += vk01ic2oc1 * i02; 260 | vc2_out += vk01ic2oc2 * i02; 261 | vc3_out += vk01ic2oc3 * i02; 262 | 263 | const float vk11ic2oc0 = w[68]; 264 | const float vk11ic2oc1 = w[69]; 265 | const float vk11ic2oc2 = w[70]; 266 | const float vk11ic2oc3 = w[71]; 267 | 268 | const float i12 = i1[2]; 269 | 270 | vc0_out += vk11ic2oc0 * i12; 271 | vc1_out += vk11ic2oc1 * i12; 272 | vc2_out += vk11ic2oc2 * i12; 273 | vc3_out += vk11ic2oc3 * i12; 274 | 275 | const float vk21ic2oc0 = w[72]; 276 | const float vk21ic2oc1 = w[73]; 277 | const float vk21ic2oc2 = w[74]; 278 | const float vk21ic2oc3 = w[75]; 279 | 280 | const float i22 = i2[2]; 281 | 282 | vc0_out += vk21ic2oc0 * i22; 283 | vc1_out += vk21ic2oc1 * i22; 284 | vc2_out += vk21ic2oc2 * i22; 285 | vc3_out += vk21ic2oc3 * i22; 286 | 287 | const float vk02ic0oc0 = w[76]; 288 | const float vk02ic0oc1 = w[77]; 289 | const float vk02ic0oc2 = w[78]; 290 | const float vk02ic0oc3 = w[79]; 291 | 292 | const float i03 = i0[3]; 293 | 294 | vc0_out += vk02ic0oc0 * i03; 295 | vc1_out += vk02ic0oc1 * i03; 296 | vc2_out += vk02ic0oc2 * i03; 297 | vc3_out += vk02ic0oc3 * i03; 298 | 299 | const float vk12ic0oc0 = w[80]; 300 | const float vk12ic0oc1 = w[81]; 301 | const float vk12ic0oc2 = w[82]; 302 | const float vk12ic0oc3 = w[83]; 303 | 304 | const float i13 = i1[3]; 305 | 306 | vc0_out += vk12ic0oc0 * i13; 307 | vc1_out += vk12ic0oc1 * i13; 308 | vc2_out += vk12ic0oc2 * i13; 309 | vc3_out += vk12ic0oc3 * i13; 310 | 311 | const float vk22ic0oc0 = w[84]; 312 | const float vk22ic0oc1 = w[85]; 313 | const float vk22ic0oc2 = w[86]; 314 | const float vk22ic0oc3 = w[87]; 315 | 316 | const float i23 = i2[3]; 317 | 318 | vc0_out += vk22ic0oc0 * i23; 319 | vc1_out += vk22ic0oc1 * i23; 320 | vc2_out += vk22ic0oc2 * i23; 321 | vc3_out += vk22ic0oc3 * i23; 322 | 323 | vr0c0 = i03; 324 | vr1c0 = i13; 325 | vr2c0 = i23; 326 | 327 | const float vk02ic1oc0 = w[88]; 328 | const float vk02ic1oc1 = w[89]; 329 | const float vk02ic1oc2 = w[90]; 330 | const float vk02ic1oc3 = w[91]; 331 | 332 | const float i04 = i0[4]; 333 | 334 | vc0_out += vk02ic1oc0 * i04; 335 | vc1_out += vk02ic1oc1 * i04; 336 | vc2_out += vk02ic1oc2 * i04; 337 | vc3_out += vk02ic1oc3 * i04; 338 | 339 | const float vk12ic1oc0 = w[92]; 340 | const float vk12ic1oc1 = w[93]; 341 | const float vk12ic1oc2 = w[94]; 342 | const float vk12ic1oc3 = w[95]; 343 | 344 | const float i14 = i1[4]; 345 | 346 | vc0_out += vk12ic1oc0 * i14; 347 | vc1_out += vk12ic1oc1 * i14; 348 | vc2_out += vk12ic1oc2 * i14; 349 | vc3_out += vk12ic1oc3 * i14; 350 | 351 | const float vk22ic1oc0 = w[96]; 352 | const float vk22ic1oc1 = w[97]; 353 | const float vk22ic1oc2 = w[98]; 354 | const float vk22ic1oc3 = w[99]; 355 | 356 | const float i24 = i2[4]; 357 | 358 | vc0_out += vk22ic1oc0 * i24; 359 | vc1_out += vk22ic1oc1 * i24; 360 | vc2_out += vk22ic1oc2 * i24; 361 | 
vc3_out += vk22ic1oc3 * i24; 362 | 363 | vr0c1 = i04; 364 | vr1c1 = i14; 365 | vr2c1 = i24; 366 | 367 | const float vk02ic2oc0 = w[100]; 368 | const float vk02ic2oc1 = w[101]; 369 | const float vk02ic2oc2 = w[102]; 370 | const float vk02ic2oc3 = w[103]; 371 | 372 | const float i05 = i0[5]; 373 | 374 | vc0_out += vk02ic2oc0 * i05; 375 | vc1_out += vk02ic2oc1 * i05; 376 | vc2_out += vk02ic2oc2 * i05; 377 | vc3_out += vk02ic2oc3 * i05; 378 | 379 | const float vk12ic2oc0 = w[104]; 380 | const float vk12ic2oc1 = w[105]; 381 | const float vk12ic2oc2 = w[106]; 382 | const float vk12ic2oc3 = w[107]; 383 | 384 | const float i15 = i1[5]; 385 | 386 | vc0_out += vk12ic2oc0 * i15; 387 | vc1_out += vk12ic2oc1 * i15; 388 | vc2_out += vk12ic2oc2 * i15; 389 | vc3_out += vk12ic2oc3 * i15; 390 | 391 | const float vk22ic2oc0 = w[108]; 392 | const float vk22ic2oc1 = w[109]; 393 | const float vk22ic2oc2 = w[110]; 394 | const float vk22ic2oc3 = w[111]; 395 | 396 | const float i25 = i2[5]; 397 | 398 | vc0_out += vk22ic2oc0 * i25; 399 | vc1_out += vk22ic2oc1 * i25; 400 | vc2_out += vk22ic2oc2 * i25; 401 | vc3_out += vk22ic2oc3 * i25; 402 | 403 | vr0c2 = i05; 404 | vr1c2 = i15; 405 | vr2c2 = i25; 406 | 407 | vc0_out = math_min_f32(vc0_out, voutput_max); 408 | vc0_out = math_max_f32(vc0_out, voutput_min); 409 | vc1_out = math_min_f32(vc1_out, voutput_max); 410 | vc1_out = math_max_f32(vc1_out, voutput_min); 411 | vc2_out = math_min_f32(vc2_out, voutput_max); 412 | vc2_out = math_max_f32(vc2_out, voutput_min); 413 | vc3_out = math_min_f32(vc3_out, voutput_max); 414 | vc3_out = math_max_f32(vc3_out, voutput_min); 415 | 416 | *o0c0 = vc0_out; o0c0 += 1; 417 | *o0c1 = vc1_out; o0c1 += 1; 418 | *o0c2 = vc2_out; o0c2 += 1; 419 | *o0c3 = vc3_out; o0c3 += 1; 420 | 421 | i0 += 6; 422 | i1 += 6; 423 | i2 += 6; 424 | } 425 | assert(iw < 2); 426 | if (iw != 0) { 427 | // start with biases 428 | float vc0_out = w[0]; 429 | float vc1_out = w[1]; 430 | float vc2_out = w[2]; 431 | float vc3_out = w[3]; 432 | 433 | const float vk00ic0oc0 = w[4]; 434 | const float vk00ic0oc1 = w[5]; 435 | const float vk00ic0oc2 = w[6]; 436 | const float vk00ic0oc3 = w[7]; 437 | 438 | vc0_out += vk00ic0oc0 * vr0c0; 439 | vc1_out += vk00ic0oc1 * vr0c0; 440 | vc2_out += vk00ic0oc2 * vr0c0; 441 | vc3_out += vk00ic0oc3 * vr0c0; 442 | 443 | const float vk10ic0oc0 = w[8]; 444 | const float vk10ic0oc1 = w[9]; 445 | const float vk10ic0oc2 = w[10]; 446 | const float vk10ic0oc3 = w[11]; 447 | 448 | vc0_out += vk10ic0oc0 * vr1c0; 449 | vc1_out += vk10ic0oc1 * vr1c0; 450 | vc2_out += vk10ic0oc2 * vr1c0; 451 | vc3_out += vk10ic0oc3 * vr1c0; 452 | 453 | const float vk20ic0oc0 = w[12]; 454 | const float vk20ic0oc1 = w[13]; 455 | const float vk20ic0oc2 = w[14]; 456 | const float vk20ic0oc3 = w[15]; 457 | 458 | vc0_out += vk20ic0oc0 * vr2c0; 459 | vc1_out += vk20ic0oc1 * vr2c0; 460 | vc2_out += vk20ic0oc2 * vr2c0; 461 | vc3_out += vk20ic0oc3 * vr2c0; 462 | 463 | const float vk00ic1oc0 = w[16]; 464 | const float vk00ic1oc1 = w[17]; 465 | const float vk00ic1oc2 = w[18]; 466 | const float vk00ic1oc3 = w[19]; 467 | 468 | 469 | vc0_out += vk00ic1oc0 * vr0c1; 470 | vc1_out += vk00ic1oc1 * vr0c1; 471 | vc2_out += vk00ic1oc2 * vr0c1; 472 | vc3_out += vk00ic1oc3 * vr0c1; 473 | 474 | const float vk10ic1oc0 = w[20]; 475 | const float vk10ic1oc1 = w[21]; 476 | const float vk10ic1oc2 = w[22]; 477 | const float vk10ic1oc3 = w[23]; 478 | 479 | vc0_out += vk10ic1oc0 * vr1c1; 480 | vc1_out += vk10ic1oc1 * vr1c1; 481 | vc2_out += vk10ic1oc2 * vr1c1; 482 | vc3_out += vk10ic1oc3 * vr1c1; 
483 | 484 | const float vk20ic1oc0 = w[24]; 485 | const float vk20ic1oc1 = w[25]; 486 | const float vk20ic1oc2 = w[26]; 487 | const float vk20ic1oc3 = w[27]; 488 | 489 | vc0_out += vk20ic1oc0 * vr2c1; 490 | vc1_out += vk20ic1oc1 * vr2c1; 491 | vc2_out += vk20ic1oc2 * vr2c1; 492 | vc3_out += vk20ic1oc3 * vr2c1; 493 | 494 | const float vk00ic2oc0 = w[28]; 495 | const float vk00ic2oc1 = w[29]; 496 | const float vk00ic2oc2 = w[30]; 497 | const float vk00ic2oc3 = w[31]; 498 | 499 | vc0_out += vk00ic2oc0 * vr0c2; 500 | vc1_out += vk00ic2oc1 * vr0c2; 501 | vc2_out += vk00ic2oc2 * vr0c2; 502 | vc3_out += vk00ic2oc3 * vr0c2; 503 | 504 | const float vk10ic2oc0 = w[32]; 505 | const float vk10ic2oc1 = w[33]; 506 | const float vk10ic2oc2 = w[34]; 507 | const float vk10ic2oc3 = w[35]; 508 | 509 | vc0_out += vk10ic2oc0 * vr1c2; 510 | vc1_out += vk10ic2oc1 * vr1c2; 511 | vc2_out += vk10ic2oc2 * vr1c2; 512 | vc3_out += vk10ic2oc3 * vr1c2; 513 | 514 | const float vk20ic2oc0 = w[36]; 515 | const float vk20ic2oc1 = w[37]; 516 | const float vk20ic2oc2 = w[38]; 517 | const float vk20ic2oc3 = w[39]; 518 | 519 | vc0_out += vk20ic2oc0 * vr2c2; 520 | vc1_out += vk20ic2oc1 * vr2c2; 521 | vc2_out += vk20ic2oc2 * vr2c2; 522 | vc3_out += vk20ic2oc3 * vr2c2; 523 | 524 | const float vk01ic0oc0 = w[40]; 525 | const float vk01ic0oc1 = w[41]; 526 | const float vk01ic0oc2 = w[42]; 527 | const float vk01ic0oc3 = w[43]; 528 | 529 | const float i00 = i0[0]; 530 | 531 | vc0_out += vk01ic0oc0 * i00; 532 | vc1_out += vk01ic0oc1 * i00; 533 | vc2_out += vk01ic0oc2 * i00; 534 | vc3_out += vk01ic0oc3 * i00; 535 | 536 | const float vk11ic0oc0 = w[44]; 537 | const float vk11ic0oc1 = w[45]; 538 | const float vk11ic0oc2 = w[46]; 539 | const float vk11ic0oc3 = w[47]; 540 | 541 | const float i10 = i1[0]; 542 | 543 | vc0_out += vk11ic0oc0 * i10; 544 | vc1_out += vk11ic0oc1 * i10; 545 | vc2_out += vk11ic0oc2 * i10; 546 | vc3_out += vk11ic0oc3 * i10; 547 | 548 | const float vk21ic0oc0 = w[48]; 549 | const float vk21ic0oc1 = w[49]; 550 | const float vk21ic0oc2 = w[50]; 551 | const float vk21ic0oc3 = w[51]; 552 | 553 | const float i20 = i2[0]; 554 | 555 | vc0_out += vk21ic0oc0 * i20; 556 | vc1_out += vk21ic0oc1 * i20; 557 | vc2_out += vk21ic0oc2 * i20; 558 | vc3_out += vk21ic0oc3 * i20; 559 | 560 | const float vk01ic1oc0 = w[52]; 561 | const float vk01ic1oc1 = w[53]; 562 | const float vk01ic1oc2 = w[54]; 563 | const float vk01ic1oc3 = w[55]; 564 | 565 | const float i01 = i0[1]; 566 | 567 | vc0_out += vk01ic1oc0 * i01; 568 | vc1_out += vk01ic1oc1 * i01; 569 | vc2_out += vk01ic1oc2 * i01; 570 | vc3_out += vk01ic1oc3 * i01; 571 | 572 | const float vk11ic1oc0 = w[56]; 573 | const float vk11ic1oc1 = w[57]; 574 | const float vk11ic1oc2 = w[58]; 575 | const float vk11ic1oc3 = w[59]; 576 | 577 | const float i11 = i1[1]; 578 | 579 | vc0_out += vk11ic1oc0 * i11; 580 | vc1_out += vk11ic1oc1 * i11; 581 | vc2_out += vk11ic1oc2 * i11; 582 | vc3_out += vk11ic1oc3 * i11; 583 | 584 | const float vk21ic1oc0 = w[60]; 585 | const float vk21ic1oc1 = w[61]; 586 | const float vk21ic1oc2 = w[62]; 587 | const float vk21ic1oc3 = w[63]; 588 | 589 | const float i21 = i2[1]; 590 | 591 | vc0_out += vk21ic1oc0 * i21; 592 | vc1_out += vk21ic1oc1 * i21; 593 | vc2_out += vk21ic1oc2 * i21; 594 | vc3_out += vk21ic1oc3 * i21; 595 | 596 | const float vk01ic2oc0 = w[64]; 597 | const float vk01ic2oc1 = w[65]; 598 | const float vk01ic2oc2 = w[66]; 599 | const float vk01ic2oc3 = w[67]; 600 | 601 | const float i02 = i0[2]; 602 | 603 | vc0_out += vk01ic2oc0 * i02; 604 | vc1_out += 
vk01ic2oc1 * i02; 605 | vc2_out += vk01ic2oc2 * i02; 606 | vc3_out += vk01ic2oc3 * i02; 607 | 608 | const float vk11ic2oc0 = w[68]; 609 | const float vk11ic2oc1 = w[69]; 610 | const float vk11ic2oc2 = w[70]; 611 | const float vk11ic2oc3 = w[71]; 612 | 613 | const float i12 = i1[2]; 614 | 615 | vc0_out += vk11ic2oc0 * i12; 616 | vc1_out += vk11ic2oc1 * i12; 617 | vc2_out += vk11ic2oc2 * i12; 618 | vc3_out += vk11ic2oc3 * i12; 619 | 620 | const float vk21ic2oc0 = w[72]; 621 | const float vk21ic2oc1 = w[73]; 622 | const float vk21ic2oc2 = w[74]; 623 | const float vk21ic2oc3 = w[75]; 624 | 625 | const float i22 = i2[2]; 626 | 627 | vc0_out += vk21ic2oc0 * i22; 628 | vc1_out += vk21ic2oc1 * i22; 629 | vc2_out += vk21ic2oc2 * i22; 630 | vc3_out += vk21ic2oc3 * i22; 631 | 632 | vc0_out = math_min_f32(vc0_out, voutput_max); 633 | vc0_out = math_max_f32(vc0_out, voutput_min); 634 | vc1_out = math_min_f32(vc1_out, voutput_max); 635 | vc1_out = math_max_f32(vc1_out, voutput_min); 636 | vc2_out = math_min_f32(vc2_out, voutput_max); 637 | vc2_out = math_max_f32(vc2_out, voutput_min); 638 | vc3_out = math_min_f32(vc3_out, voutput_max); 639 | vc3_out = math_max_f32(vc3_out, voutput_min); 640 | 641 | *o0c0 = vc0_out; o0c0 += 1; 642 | *o0c1 = vc1_out; o0c1 += 1; 643 | *o0c2 = vc2_out; o0c2 += 1; 644 | *o0c3 = vc3_out; o0c3 += 1; 645 | } 646 | // Move output pointers back to the position of the first pixel in a row, 647 | // and forward to the next block of output channels. 648 | o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment); 649 | o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment); 650 | o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment); 651 | o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment); 652 | // Revert input pointers to the position of the first pixel in a row 653 | i0 = (const float*) ((uintptr_t) i0 - input_width_increment); 654 | i1 = (const float*) ((uintptr_t) i1 - input_width_increment); 655 | i2 = (const float*) ((uintptr_t) i2 - input_width_increment); 656 | // Move to the block of weights for the next 4 output channels 657 | w += 112; 658 | c = doz(c, 4); 659 | } while (c != 0); 660 | // Move output pointers forward to the next two rows 661 | output0 = (float*) ((uintptr_t) output0 + output_height_stride); 662 | // Move input pointers forward to the next four rows 663 | i0 = i2; 664 | i1 = (const float*) ((uintptr_t) i0 + input_height_stride); 665 | i2 = (const float*) ((uintptr_t) i1 + input_height_stride); 666 | } 667 | } 668 | -------------------------------------------------------------------------------- /dwconv-3x3p1-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
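//
// Depthwise 3x3 convolution with unit stride for spatially-packed CHW data:
// m counts rows and n is the row width in pixels of one channel plane. The main
// loop emits 3 output rows per pass and 4 pixels per NEON vector, using the 10
// per-channel weights (1 bias + 9 filter taps) held in vw0123/vw4567/vw89.
// Left padding of 1 is materialized with zero vectors at the start of each row,
// the partial last block of a row is handled through params->neon.mask, and
// every output is clamped to [params->neon.min, params->neon.max].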
5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_dwconv_spchw_ukernel_3x3p1__neonfma( 11 | size_t m, 12 | size_t n, 13 | const float* input, 14 | const float* weights, 15 | float* output, 16 | size_t input_tuple_stride, 17 | size_t output_tuple_stride, 18 | size_t input_width_stride, 19 | size_t output_width_stride, 20 | const union f32_spchw_params params[restrict static 1]) 21 | { 22 | assert(n != 0); 23 | 24 | const uint32x4_t vmask = vld1q_u32(params->neon.mask); 25 | const float32x4_t vmax = vld1q_dup_f32(¶ms->neon.max); 26 | const float32x4_t vmin = vld1q_dup_f32(¶ms->neon.min); 27 | 28 | const size_t input_width_increment = 3 * input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride; 29 | const size_t output_width_increment = 3 * output_width_stride - (n - 1) / 4 * output_tuple_stride; 30 | const size_t input_width_increment_single = input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride; 31 | const size_t output_width_increment_single = output_width_stride - (n - 1) / 4 * output_tuple_stride; 32 | 33 | // No vertical padding. 34 | const float* i0 = input; 35 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 36 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 37 | const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride); 38 | const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride); 39 | 40 | float* output0 = output; 41 | float* output1 = (float *)((uintptr_t)output0 + output_width_stride); 42 | float* output2 = (float *)((uintptr_t)output1 + output_width_stride); 43 | 44 | const float32x4_t vw0123 = vld1q_f32(weights); 45 | const float32x4_t vw4567 = vld1q_f32(weights + 4); 46 | const float32x2_t vw89 = vld1_f32(weights + 8); 47 | 48 | while (m >= 3) { 49 | float32x4_t vi0x0123 = vmovq_n_f32(0.0f); 50 | float32x4_t vi1x0123 = vmovq_n_f32(0.0f); 51 | float32x4_t vi2x0123 = vmovq_n_f32(0.0f); 52 | float32x4_t vi3x0123 = vmovq_n_f32(0.0f); 53 | float32x4_t vi4x0123 = vmovq_n_f32(0.0f); 54 | float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 55 | float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 56 | float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 57 | float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 58 | float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 59 | 60 | size_t k = n; 61 | for (; k > 4; k -= 4) { 62 | float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0); 63 | float32x4_t vo4567p01 = vdupq_laneq_f32(vw0123, 0); 64 | float32x4_t vo4567p02 = vdupq_laneq_f32(vw0123, 0); 65 | 66 | const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 67 | const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 68 | const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 69 | const float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 70 | const float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 71 | 72 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 2); 73 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x4567, vw4567, 1); 74 | vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x4567, vw89, 0); 75 | 76 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, 
vi1x4567, vw0123, 2); 77 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x4567, vw4567, 1); 78 | vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x4567, vw89, 0); 79 | 80 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x4567, vw0123, 2); 81 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x4567, vw4567, 1); 82 | vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x4567, vw89, 0); 83 | 84 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 85 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 86 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 87 | const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3); 88 | const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3); 89 | 90 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 1); 91 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x3456, vw4567, 0); 92 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vw4567, 3); 93 | 94 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw0123, 1); 95 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x3456, vw4567, 0); 96 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vw4567, 3); 97 | 98 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x3456, vw0123, 1); 99 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x3456, vw4567, 0); 100 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi4x3456, vw4567, 3); 101 | 102 | vi0x0123 = vi0x4567; 103 | vi1x0123 = vi1x4567; 104 | vi2x0123 = vi2x4567; 105 | vi3x0123 = vi3x4567; 106 | vi4x0123 = vi4x4567; 107 | 108 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1); 109 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1); 110 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1); 111 | const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1); 112 | const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1); 113 | 114 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw0123, 3); 115 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x5678, vw4567, 2); 116 | vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x5678, vw89, 1); 117 | 118 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw0123, 3); 119 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x5678, vw4567, 2); 120 | vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x5678, vw89, 1); 121 | 122 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x5678, vw0123, 3); 123 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x5678, vw4567, 2); 124 | vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x5678, vw89, 1); 125 | 126 | vi0x4567 = vi0x89AB; 127 | vi1x4567 = vi1x89AB; 128 | vi2x4567 = vi2x89AB; 129 | vi3x4567 = vi3x89AB; 130 | vi4x4567 = vi4x89AB; 131 | 132 | float32x4_t vo0 = vo4567p00; 133 | float32x4_t vo1 = vo4567p01; 134 | float32x4_t vo2 = vo4567p02; 135 | 136 | vo0 = vmaxq_f32(vo0, vmin); 137 | vo0 = vminq_f32(vo0, vmax); 138 | vo1 = vmaxq_f32(vo1, vmin); 139 | vo1 = vminq_f32(vo1, vmax); 140 | vo2 = vmaxq_f32(vo2, vmin); 141 | vo2 = vminq_f32(vo2, vmax); 142 | 143 | vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 144 | vst1q_f32(output1, vo1); output1 = (float*) ((uintptr_t) output1 + output_tuple_stride); 145 | vst1q_f32(output2, vo2); output2 = (float*) ((uintptr_t) output2 + output_tuple_stride); 146 | } 147 | // Always process the last block of 1..4 pixels. 
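      // The row width is not necessarily a multiple of 4, so the final 1..4 pixels
      // reuse the same FMA chain as the main loop: vmask zeroes the input lanes
      // that lie past the end of the row before they are multiplied, and the store
      // below selects vst1q / vst1 / vst1_lane based on the remaining count k.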
148 | assert(k >= 1); 149 | assert(k <= 4); 150 | { 151 | float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0); 152 | float32x4_t vo4567p01 = vdupq_laneq_f32(vw0123, 0); 153 | float32x4_t vo4567p02 = vdupq_laneq_f32(vw0123, 0); 154 | 155 | vi0x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x4567))); 156 | vi1x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x4567))); 157 | vi2x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x4567))); 158 | vi3x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3x4567))); 159 | vi4x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi4x4567))); 160 | 161 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 2); 162 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x4567, vw4567, 1); 163 | vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x4567, vw89, 0); 164 | 165 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x4567, vw0123, 2); 166 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x4567, vw4567, 1); 167 | vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x4567, vw89, 0); 168 | 169 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x4567, vw0123, 2); 170 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x4567, vw4567, 1); 171 | vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x4567, vw89, 0); 172 | 173 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 174 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 175 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 176 | const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3); 177 | const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3); 178 | 179 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 1); 180 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x3456, vw4567, 0); 181 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vw4567, 3); 182 | 183 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw0123, 1); 184 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x3456, vw4567, 0); 185 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vw4567, 3); 186 | 187 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x3456, vw0123, 1); 188 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x3456, vw4567, 0); 189 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi4x3456, vw4567, 3); 190 | 191 | const float32x4_t vzero = vmovq_n_f32(0.0f); 192 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1); 193 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1); 194 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1); 195 | const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vzero, 1); 196 | const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vzero, 1); 197 | 198 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw0123, 3); 199 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x5678, vw4567, 2); 200 | vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x5678, vw89, 1); 201 | 202 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw0123, 3); 203 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x5678, vw4567, 2); 204 | vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x5678, vw89, 1); 205 | 206 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x5678, vw0123, 3); 207 | vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x5678, vw4567, 2); 208 | vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x5678, vw89, 1); 209 | 210 | float32x4_t vo0 = vo4567p00; 211 | float32x4_t vo1 = vo4567p01; 212 | float32x4_t vo2 = vo4567p02; 213 | 214 | vo0 = vmaxq_f32(vo0, vmin); 215 | vo0 = vminq_f32(vo0, vmax); 216 | vo1 = vmaxq_f32(vo1, vmin); 217 | vo1 = vminq_f32(vo1, vmax); 218 | vo2 = vmaxq_f32(vo2, 
vmin); 219 | vo2 = vminq_f32(vo2, vmax); 220 | 221 | if (k & 4) { 222 | vst1q_f32(output0, vo0); 223 | vst1q_f32(output1, vo1); 224 | vst1q_f32(output2, vo2); 225 | } else { 226 | float* output0_lo = output0; 227 | float* output1_lo = output1; 228 | float* output2_lo = output2; 229 | float32x2_t vo0_lo = vget_low_f32(vo0); 230 | float32x2_t vo1_lo = vget_low_f32(vo1); 231 | float32x2_t vo2_lo = vget_low_f32(vo2); 232 | if (k & 2) { 233 | vst1_f32(output0_lo, vo0_lo); output0_lo += 2; 234 | vst1_f32(output1_lo, vo1_lo); output1_lo += 2; 235 | vst1_f32(output2_lo, vo2_lo); output2_lo += 2; 236 | vo0_lo = vget_high_f32(vo0); 237 | vo1_lo = vget_high_f32(vo1); 238 | vo2_lo = vget_high_f32(vo2); 239 | } 240 | if (k & 1) { 241 | vst1_lane_f32(output0_lo, vo0_lo, 0); 242 | vst1_lane_f32(output1_lo, vo1_lo, 0); 243 | vst1_lane_f32(output2_lo, vo2_lo, 0); 244 | } 245 | } 246 | } 247 | 248 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment); 249 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment); 250 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment); 251 | i3 = (const float*) ((uintptr_t) i3 + input_width_increment); 252 | i4 = (const float*) ((uintptr_t) i4 + input_width_increment); 253 | output0 = (float*) ((uintptr_t) output0 + output_width_increment); 254 | output1 = (float*) ((uintptr_t) output1 + output_width_increment); 255 | output2 = (float*) ((uintptr_t) output2 + output_width_increment); 256 | m -= 3; 257 | } 258 | 259 | while (m != 0) { 260 | float32x4_t vi0x0123 = vmovq_n_f32(0.0f); 261 | float32x4_t vi1x0123 = vmovq_n_f32(0.0f); 262 | float32x4_t vi2x0123 = vmovq_n_f32(0.0f); 263 | float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 264 | float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 265 | float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 266 | 267 | size_t k = n; 268 | for (; k > 4; k -= 4) { 269 | float32x4_t vo4567p0 = vdupq_laneq_f32(vw0123, 0); 270 | 271 | const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 272 | const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 273 | const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 274 | 275 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x4567, vw0123, 2); 276 | float32x4_t vo4567p1 = vmulq_laneq_f32(vi1x4567, vw4567, 1); 277 | float32x4_t vo4567p2 = vmulq_lane_f32(vi2x4567, vw89, 0); 278 | 279 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 280 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 281 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 282 | 283 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x3456, vw0123, 1); 284 | vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x3456, vw4567, 0); 285 | vo4567p2 = vfmaq_laneq_f32(vo4567p2, vi2x3456, vw4567, 3); 286 | 287 | vi0x0123 = vi0x4567; 288 | vi1x0123 = vi1x4567; 289 | vi2x0123 = vi2x4567; 290 | 291 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1); 292 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1); 293 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1); 294 | 295 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x5678, vw0123, 3); 296 | vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x5678, vw4567, 2); 297 | vo4567p2 = vfmaq_lane_f32(vo4567p2, vi2x5678, vw89, 1); 298 | 299 | vi0x4567 = vi0x89AB; 300 | 
vi1x4567 = vi1x89AB; 301 | vi2x4567 = vi2x89AB; 302 | 303 | float32x4_t vo = vaddq_f32(vo4567p0, vo4567p1); 304 | vo = vaddq_f32(vo, vo4567p2); 305 | 306 | vo = vmaxq_f32(vo, vmin); 307 | vo = vminq_f32(vo, vmax); 308 | 309 | vst1q_f32(output0, vo); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 310 | } 311 | // Always process the last block of 1..4 pixels. 312 | assert(k >= 1); 313 | assert(k <= 4); 314 | { 315 | float32x4_t vo4567p0 = vdupq_laneq_f32(vw0123, 0); 316 | 317 | vi0x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x4567))); 318 | vi1x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x4567))); 319 | vi2x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x4567))); 320 | 321 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x4567, vw0123, 2); 322 | float32x4_t vo4567p1 = vmulq_laneq_f32(vi1x4567, vw4567, 1); 323 | float32x4_t vo4567p2 = vmulq_lane_f32(vi2x4567, vw89, 0); 324 | 325 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 326 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 327 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 328 | 329 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x3456, vw0123, 1); 330 | vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x3456, vw4567, 0); 331 | vo4567p2 = vfmaq_laneq_f32(vo4567p2, vi2x3456, vw4567, 3); 332 | 333 | const float32x4_t vzero = vmovq_n_f32(0.0f); 334 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1); 335 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1); 336 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1); 337 | 338 | vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x5678, vw0123, 3); 339 | vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x5678, vw4567, 2); 340 | vo4567p2 = vfmaq_lane_f32(vo4567p2, vi2x5678, vw89, 1); 341 | 342 | float32x4_t vo = vaddq_f32(vo4567p0, vo4567p1); 343 | vo = vaddq_f32(vo, vo4567p2); 344 | 345 | vo = vmaxq_f32(vo, vmin); 346 | vo = vminq_f32(vo, vmax); 347 | 348 | if (k & 4) { 349 | vst1q_f32(output0, vo); 350 | } else { 351 | float* output0_lo = output0; 352 | float32x2_t vo_lo = vget_low_f32(vo); 353 | if (k & 2) { 354 | vst1_f32(output0_lo, vo_lo); output0_lo += 2; 355 | vo_lo = vget_high_f32(vo); 356 | } 357 | if (k & 1) { 358 | vst1_lane_f32(output0_lo, vo_lo, 0); 359 | } 360 | } 361 | } 362 | 363 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single); 364 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single); 365 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single); 366 | output0 = (float*) ((uintptr_t) output0 + output_width_increment_single); 367 | m -= 1; 368 | } 369 | } 370 | -------------------------------------------------------------------------------- /dwconv-3x3p1-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
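// Scalar 3x3 depthwise convolution, stride 1, with one column of implicit
// zero padding on each side of the row (vertical padding, if any, is the
// caller's responsibility).  m is the number of output rows, n the row
// width in pixels, and all strides are byte strides.  weights[0] is the
// bias and weights[1..9] the 3x3 filter in row-major order, so each output
// pixel is, in effect (illustrative sketch of the accumulation below):
//
//   out[y][x] = clamp(w[0]
//       + w[1]*in[y+0][x-1] + w[2]*in[y+0][x] + w[3]*in[y+0][x+1]
//       + w[4]*in[y+1][x-1] + w[5]*in[y+1][x] + w[6]*in[y+1][x+1]
//       + w[7]*in[y+2][x-1] + w[8]*in[y+2][x] + w[9]*in[y+2][x+1],
//       min, max)
//
// with out-of-row input columns treated as zero.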
5 | 6 | #include 7 | 8 | 9 | void f32_dwconv_spchw_ukernel_3x3p1__scalar( 10 | size_t m, 11 | size_t n, 12 | const float* input, 13 | const float* weights, 14 | float* output, 15 | size_t input_tuple_stride, 16 | size_t output_tuple_stride, 17 | size_t input_width_stride, 18 | size_t output_width_stride, 19 | const union f32_spchw_params params[restrict static 1]) 20 | { 21 | assert(n != 0); 22 | 23 | const size_t input_width_increment = input_width_stride - n * input_tuple_stride; 24 | const size_t output_width_increment = output_width_stride - (n - 1) * output_tuple_stride; 25 | 26 | const float params_min = params->scalar.min; 27 | const float params_max = params->scalar.max; 28 | 29 | // No vertical padding. 30 | const float* i0 = input; 31 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 32 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 33 | 34 | float* output0 = output; 35 | 36 | const float vw0 = weights[0]; 37 | const float vw1 = weights[1]; 38 | const float vw2 = weights[2]; 39 | const float vw3 = weights[3]; 40 | const float vw4 = weights[4]; 41 | const float vw5 = weights[5]; 42 | const float vw6 = weights[6]; 43 | const float vw7 = weights[7]; 44 | const float vw8 = weights[8]; 45 | const float vw9 = weights[9]; 46 | 47 | while (m > 0) { 48 | float vi0x0 = 0.0f; 49 | float vi1x0 = 0.0f; 50 | float vi2x0 = 0.0f; 51 | float vi0x1 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 52 | float vi1x1 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 53 | float vi2x1 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 54 | 55 | size_t k = n; 56 | for (; k > 1; k--) { 57 | const float vi0x2 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 58 | const float vi1x2 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 59 | const float vi2x2 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 60 | 61 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2; 62 | vi0x0 = vi0x1; 63 | vi0x1 = vi0x2; 64 | const float vrow1_accum = vw4 * vi1x0 + vw5 * vi1x1 + vw6 * vi1x2; 65 | vi1x0 = vi1x1; 66 | vi1x1 = vi1x2; 67 | const float vrow2_accum = vw7 * vi2x0 + vw8 * vi2x1 + vw9 * vi2x2; 68 | vi2x0 = vi2x1; 69 | vi2x1 = vi2x2; 70 | 71 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum); 72 | 73 | voutput = math_max_f32(voutput, params_min); 74 | voutput = math_min_f32(voutput, params_max); 75 | 76 | *output0 = voutput; output0 = (float *) ((uintptr_t) output0 + output_tuple_stride); 77 | } 78 | // Always process the last pixel separately to account for right edge. 
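// Only the left and center taps contribute here: the right neighbor of the
// last pixel lies in the implicit zero padding, so the vw3/vw6/vw9 terms
// drop out of the three row accumulators below.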
79 | assert(k == 1); 80 | { 81 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1; 82 | const float vrow1_accum = vw4 * vi1x0 + vw5 * vi1x1; 83 | const float vrow2_accum = vw7 * vi2x0 + vw8 * vi2x1; 84 | 85 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum); 86 | 87 | voutput = math_max_f32(voutput, params_min); 88 | voutput = math_min_f32(voutput, params_max); 89 | 90 | *output0 = voutput; 91 | } 92 | 93 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment); 94 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment); 95 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment); 96 | output0 = (float*) ((uintptr_t) output0 + output_width_increment); 97 | m--; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /dwconv-3x3s2p1-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_dwconv_spchw_ukernel_3x3s2p1__neonfma( 11 | size_t m, 12 | size_t n, 13 | const float* input, 14 | const float* weights, 15 | float* output, 16 | size_t input_tuple_stride, 17 | size_t output_tuple_stride, 18 | size_t input_width_stride, 19 | size_t output_width_stride, 20 | const union f32_spchw_params params[restrict static 1]) 21 | { 22 | assert(n != 0); 23 | 24 | const uint32x4_t vmask_even = vld1q_u32(params->neon.mask_even); 25 | const uint32x4_t vmask_odd = vld1q_u32(params->neon.mask_odd); 26 | const float32x4_t vmax = vld1q_dup_f32(¶ms->neon.max); 27 | const float32x4_t vmin = vld1q_dup_f32(¶ms->neon.min); 28 | 29 | const size_t input_width_increment = input_width_stride * 2 - n / 8 * input_tuple_stride * 2; 30 | const size_t output_width_increment = output_width_stride - n / 8 * output_tuple_stride; 31 | 32 | // No vertical padding. 
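// With stride 2 and padding 1, output pixel j of output row y reads input
// columns 2*j-1 .. 2*j+1 of input rows 2*y .. 2*y+2 (column -1 is implicit
// zero padding).  The main loop below therefore loads 8 consecutive input
// pixels, de-interleaves them into even- and odd-indexed lanes with
// vuzp1q/vuzp2q, and produces 4 outputs per iteration; mask_even and
// mask_odd zero the out-of-range lanes of the two de-interleaved vectors in
// the remainder block.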
33 | const float* i0 = input; 34 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 35 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 36 | 37 | const float32x4_t vw0123 = vld1q_f32(weights); 38 | const float32x4_t vw4567 = vld1q_f32(weights + 4); 39 | const float32x2_t vw89 = vld1_f32(weights + 8); 40 | 41 | do { 42 | float32x4_t vi0x0123 = vmovq_n_f32(0.0f); 43 | float32x4_t vi1x0123 = vmovq_n_f32(0.0f); 44 | float32x4_t vi2x0123 = vmovq_n_f32(0.0f); 45 | 46 | size_t k = n; 47 | for (; k >= 8; k -= 8) { 48 | float32x4_t vo468Ap0 = vdupq_laneq_f32(vw0123, 0); 49 | 50 | const float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 51 | const float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 52 | const float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 53 | 54 | const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 55 | const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 56 | const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 57 | 58 | const float32x4_t vi0x468A = vuzp1q_f32(vi0x4567, vi0x89AB); 59 | const float32x4_t vi0x579B = vuzp2q_f32(vi0x4567, vi0x89AB); 60 | const float32x4_t vi1x468A = vuzp1q_f32(vi1x4567, vi1x89AB); 61 | const float32x4_t vi1x579B = vuzp2q_f32(vi1x4567, vi1x89AB); 62 | const float32x4_t vi2x468A = vuzp1q_f32(vi2x4567, vi2x89AB); 63 | const float32x4_t vi2x579B = vuzp2q_f32(vi2x4567, vi2x89AB); 64 | // add bias only to first row, it will then get added 65 | // to the final result 66 | // multiply each row by corresponding row of center column of filter 67 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x468A, vw0123, 2); 68 | float32x4_t vo468Ap1 = vmulq_laneq_f32(vi1x468A, vw4567, 1); 69 | float32x4_t vo468Ap2 = vmulq_lane_f32(vi2x468A, vw89, 0); 70 | 71 | // grab the values corresponding the left filter tap 72 | const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3); 73 | const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3); 74 | const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3); 75 | 76 | vi0x0123 = vi0x89AB; 77 | vi1x0123 = vi1x89AB; 78 | vi2x0123 = vi2x89AB; 79 | 80 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x3579, vw0123, 1); 81 | vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x3579, vw4567, 0); 82 | vo468Ap2 = vfmaq_laneq_f32(vo468Ap2, vi2x3579, vw4567, 3); 83 | 84 | // Do multiplication by right filter tap. 85 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x579B, vw0123, 3); 86 | vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x579B, vw4567, 2); 87 | vo468Ap2 = vfmaq_lane_f32 (vo468Ap2, vi2x579B, vw89, 1); 88 | 89 | // Add up across rows to get the final outputs. 90 | float32x4_t vo = vaddq_f32(vo468Ap0, vo468Ap1); 91 | vo = vaddq_f32(vo, vo468Ap2); 92 | 93 | vo = vmaxq_f32(vo, vmin); 94 | vo = vminq_f32(vo, vmax); 95 | 96 | vst1q_f32(output, vo); output = (float*) ((uintptr_t) output + output_tuple_stride); 97 | } 98 | // Last block has 0-7 pixels to process. 
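// Up to 7 input pixels, i.e. 1..4 outputs, remain at this point.  The two
// loads per row below always fetch 8 pixels, so the lanes past the last
// valid column are cleared with mask_even/mask_odd before they feed the
// accumulators, and the store sequence at the bottom writes exactly
// ceil(k/2) output values.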
99 | assert(k < 8); 100 | if (k != 0) { 101 | float32x4_t vo468Ap0 = vdupq_laneq_f32(vw0123, 0); 102 | 103 | const float32x4_t vi0x4567 = vld1q_f32(i0); 104 | const float32x4_t vi1x4567 = vld1q_f32(i1); 105 | const float32x4_t vi2x4567 = vld1q_f32(i2); 106 | 107 | const float32x4_t vi0x89AB = vld1q_f32((const float*) ((uintptr_t) i0 + input_tuple_stride)); 108 | const float32x4_t vi1x89AB = vld1q_f32((const float*) ((uintptr_t) i1 + input_tuple_stride)); 109 | const float32x4_t vi2x89AB = vld1q_f32((const float*) ((uintptr_t) i2 + input_tuple_stride)); 110 | 111 | const float32x4_t vi0x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vuzp1q_f32(vi0x4567, vi0x89AB)))); 112 | const float32x4_t vi0x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vuzp2q_f32(vi0x4567, vi0x89AB)))); 113 | const float32x4_t vi1x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vuzp1q_f32(vi1x4567, vi1x89AB)))); 114 | const float32x4_t vi1x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vuzp2q_f32(vi1x4567, vi1x89AB)))); 115 | const float32x4_t vi2x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vuzp1q_f32(vi2x4567, vi2x89AB)))); 116 | const float32x4_t vi2x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vuzp2q_f32(vi2x4567, vi2x89AB)))); 117 | // add bias only to first row, it will then get added 118 | // to the final result 119 | // multiply each row by corresponding row of center column of filter 120 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x468A, vw0123, 2); 121 | float32x4_t vo468Ap1 = vmulq_laneq_f32(vi1x468A, vw4567, 1); 122 | float32x4_t vo468Ap2 = vmulq_lane_f32(vi2x468A, vw89, 0); 123 | 124 | // grab the values corresponding the left filter tap 125 | const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3); 126 | const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3); 127 | const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3); 128 | 129 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x3579, vw0123, 1); 130 | vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x3579, vw4567, 0); 131 | vo468Ap2 = vfmaq_laneq_f32(vo468Ap2, vi2x3579, vw4567, 3); 132 | 133 | // do multiplication by right filter tap 134 | vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x579B, vw0123, 3); 135 | vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x579B, vw4567, 2); 136 | vo468Ap2 = vfmaq_lane_f32 (vo468Ap2, vi2x579B, vw89, 1); 137 | 138 | // add up across rows to get the final outputs 139 | float32x4_t vo = vaddq_f32(vo468Ap0, vo468Ap1); 140 | vo = vaddq_f32(vo, vo468Ap2); 141 | 142 | vo = vmaxq_f32(vo, vmin); 143 | vo = vminq_f32(vo, vmax); 144 | 145 | k += 1; 146 | if (k & 8) { 147 | vst1q_f32(output, vo); 148 | } else { 149 | float* output_lo = output; 150 | float32x2_t vo_lo = vget_low_f32(vo); 151 | if (k & 4) { 152 | vst1_f32(output_lo, vo_lo); output_lo += 2; 153 | vo_lo = vget_high_f32(vo); 154 | } 155 | if (k & 2) { 156 | vst1_lane_f32(output_lo, vo_lo, 0); 157 | } 158 | } 159 | } 160 | 161 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment); 162 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment); 163 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment); 164 | output = (float*) ((uintptr_t) output + output_width_increment); 165 | } while (--m != 0); 166 | } 167 | -------------------------------------------------------------------------------- /dwconv-3x3s2p1-scalar.c: -------------------------------------------------------------------------------- 1 | // 
Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | 9 | void f32_dwconv_spchw_ukernel_3x3s2p1__scalar( 10 | size_t m, 11 | size_t n, 12 | const float* input, 13 | const float* weights, 14 | float* output, 15 | size_t input_tuple_stride, 16 | size_t output_tuple_stride, 17 | size_t input_width_stride, 18 | size_t output_width_stride, 19 | const union f32_spchw_params params[restrict static 1]) 20 | { 21 | assert(n != 0); 22 | 23 | const size_t input_width_increment = 2 * input_width_stride - (n/2) * 2 * input_tuple_stride; 24 | const size_t output_width_increment = output_width_stride - (n/2) * output_tuple_stride; 25 | 26 | const float params_min = params->scalar.min; 27 | const float params_max = params->scalar.max; 28 | 29 | // No vertical padding. 30 | const float* i0 = input; 31 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 32 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 33 | 34 | float* output0 = output; 35 | 36 | const float vw0 = weights[0]; 37 | const float vw1 = weights[1]; 38 | const float vw2 = weights[2]; 39 | const float vw3 = weights[3]; 40 | const float vw4 = weights[4]; 41 | const float vw5 = weights[5]; 42 | const float vw6 = weights[6]; 43 | const float vw7 = weights[7]; 44 | const float vw8 = weights[8]; 45 | const float vw9 = weights[9]; 46 | 47 | while (m > 0) { 48 | float vi0x0 = 0.0f; 49 | float vi1x0 = 0.0f; 50 | float vi2x0 = 0.0f; 51 | 52 | size_t k = n; 53 | for (; k >= 2; k -= 2) { 54 | const float vi0x1 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 55 | const float vi1x1 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 56 | const float vi2x1 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 57 | const float vi0x2 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 58 | const float vi1x2 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 59 | const float vi2x2 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 60 | 61 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2; 62 | vi0x0 = vi0x2; 63 | const float vrow1_accum = vw4 * vi1x0 + vw5 * vi1x1 + vw6 * vi1x2; 64 | vi1x0 = vi1x2; 65 | const float vrow2_accum = vw7 * vi2x0 + vw8 * vi2x1 + vw9 * vi2x2; 66 | vi2x0 = vi2x2; 67 | 68 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum); 69 | 70 | voutput = math_max_f32(voutput, params_min); 71 | voutput = math_min_f32(voutput, params_max); 72 | 73 | *output0 = voutput; output0 = (float *) ((uintptr_t) output0 + output_tuple_stride); 74 | } 75 | // Possibly process the last pixel separately to account for right edge. 
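// When n is odd a single input column is left over.  Its output uses only
// the left and center taps of each filter row, since the right neighbor
// falls in the implicit zero padding.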
76 | if (k == 1) 77 | { 78 | const float vi0x1 = i0[0]; 79 | const float vi1x1 = i1[0]; 80 | const float vi2x1 = i2[0]; 81 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1; 82 | const float vrow1_accum = vw4 * vi1x0 + vw5 * vi1x1; 83 | const float vrow2_accum = vw7 * vi2x0 + vw8 * vi2x1; 84 | 85 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum); 86 | 87 | voutput = math_max_f32(voutput, params_min); 88 | voutput = math_min_f32(voutput, params_max); 89 | 90 | *output0 = voutput; 91 | } 92 | 93 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment); 94 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment); 95 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment); 96 | output0 = (float*) ((uintptr_t) output0 + output_width_increment); 97 | m--; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /dwconv-5x5p2-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_dwconv_spchw_ukernel_5x5p2__neonfma( 11 | size_t m, 12 | size_t n, 13 | const float* input, 14 | const float* weights, 15 | float* output, 16 | size_t input_tuple_stride, 17 | size_t output_tuple_stride, 18 | size_t input_width_stride, 19 | size_t output_width_stride, 20 | const union f32_spchw_params params[restrict static 1]) 21 | { 22 | assert(n != 0); 23 | 24 | const uint32x4_t vmask = vld1q_u32(params->neon.mask); 25 | const float32x4_t vmax = vld1q_dup_f32(¶ms->neon.max); 26 | const float32x4_t vmin = vld1q_dup_f32(¶ms->neon.min); 27 | 28 | const size_t input_width_increment_single = input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride; 29 | const size_t output_width_increment_single = output_width_stride - (n - 1) / 4 * output_tuple_stride; 30 | 31 | // No vertical padding. 
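// The 26 weights are a bias followed by the 5x5 filter in row-major order;
// they are kept in vw0123..vwKLMN plus the two-element vwOP below.  Each
// output pixel is, in effect (sketch of the lane arithmetic that follows),
//
//   out(y, x) = clamp(w[0] + sum_{r=0..4} sum_{c=0..4}
//                     w[1 + 5*r + c] * in(y + r, x + c - 2), min, max)
//
// with the two columns of horizontal padding supplied as implicit zeros.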
32 | const float* i0 = input; 33 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 34 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 35 | const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride); 36 | const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride); 37 | 38 | float* output0 = output; 39 | 40 | const float32x4_t vw0123 = vld1q_f32(weights); 41 | const float32x4_t vw4567 = vld1q_f32(weights + 4); 42 | const float32x4_t vw89AB = vld1q_f32(weights + 8); 43 | const float32x4_t vwCDEF = vld1q_f32(weights + 12); 44 | const float32x4_t vwGHIJ = vld1q_f32(weights + 16); 45 | const float32x4_t vwKLMN = vld1q_f32(weights + 20); 46 | const float32x2_t vwOP = vld1_f32( weights + 24); 47 | 48 | do { 49 | float32x4_t vi0x0123 = vmovq_n_f32(0.0f); 50 | float32x4_t vi1x0123 = vmovq_n_f32(0.0f); 51 | float32x4_t vi2x0123 = vmovq_n_f32(0.0f); 52 | float32x4_t vi3x0123 = vmovq_n_f32(0.0f); 53 | float32x4_t vi4x0123 = vmovq_n_f32(0.0f); 54 | float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 55 | float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 56 | float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 57 | float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 58 | float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 59 | 60 | size_t k = n; 61 | for (; k > 8; k -= 4) { 62 | float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0); 63 | 64 | const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 65 | const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 66 | const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 67 | const float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 68 | const float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 69 | 70 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3); 71 | float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0); 72 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1); 73 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2); 74 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3); 75 | 76 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 77 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 78 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 79 | const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3); 80 | const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3); 81 | 82 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2); 83 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3); 84 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0); 85 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1); 86 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2); 87 | 88 | const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2); 89 | const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2); 90 | const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2); 91 | const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2); 92 | const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2); 93 | 94 | 
vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1); 95 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2); 96 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3); 97 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0); 98 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1); 99 | 100 | vi0x0123 = vi0x4567; 101 | vi1x0123 = vi1x4567; 102 | vi2x0123 = vi2x4567; 103 | vi3x0123 = vi3x4567; 104 | vi4x0123 = vi4x4567; 105 | 106 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1); 107 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1); 108 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1); 109 | const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1); 110 | const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1); 111 | 112 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0); 113 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1); 114 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2); 115 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 3); 116 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0); 117 | 118 | const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vi0x89AB, 2); 119 | const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vi1x89AB, 2); 120 | const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vi2x89AB, 2); 121 | const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vi3x89AB, 2); 122 | const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vi4x89AB, 2); 123 | 124 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1); 125 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2); 126 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3); 127 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0); 128 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1); 129 | 130 | vi0x4567 = vi0x89AB; 131 | vi1x4567 = vi1x89AB; 132 | vi2x4567 = vi2x89AB; 133 | vi3x4567 = vi3x89AB; 134 | vi4x4567 = vi4x89AB; 135 | 136 | vo4567p00 = vaddq_f32(vo4567p00, vo4567p01); 137 | 138 | float32x4_t vo0 = vo4567p00; 139 | 140 | vo0 = vmaxq_f32(vo0, vmin); 141 | vo0 = vminq_f32(vo0, vmax); 142 | 143 | vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 144 | } 145 | // Always process the last block of 5..8 pixels. 
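// (More precisely: the masked 4-pixel block below runs only when more than
// 4 pixels remain; the final 1..4 pixels are then handled by the block that
// follows it, with the right-hand padding supplied as zeros.)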
146 | if (k > 4) 147 | { 148 | float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0); 149 | 150 | float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 151 | float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 152 | float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 153 | float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 154 | float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 155 | 156 | vi0x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x89AB))); 157 | vi1x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x89AB))); 158 | vi2x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x89AB))); 159 | vi3x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3x89AB))); 160 | vi4x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi4x89AB))); 161 | 162 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3); 163 | float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0); 164 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1); 165 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2); 166 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3); 167 | 168 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 169 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 170 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 171 | const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3); 172 | const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3); 173 | 174 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2); 175 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3); 176 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0); 177 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1); 178 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2); 179 | 180 | const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2); 181 | const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2); 182 | const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2); 183 | const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2); 184 | const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2); 185 | 186 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1); 187 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2); 188 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3); 189 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0); 190 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1); 191 | 192 | vi0x0123 = vi0x4567; 193 | vi1x0123 = vi1x4567; 194 | vi2x0123 = vi2x4567; 195 | vi3x0123 = vi3x4567; 196 | vi4x0123 = vi4x4567; 197 | 198 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1); 199 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1); 200 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1); 201 | const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1); 202 | const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1); 203 | 204 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0); 205 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1); 206 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2); 207 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 
3); 208 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0); 209 | 210 | const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vi0x89AB, 2); 211 | const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vi1x89AB, 2); 212 | const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vi2x89AB, 2); 213 | const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vi3x89AB, 2); 214 | const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vi4x89AB, 2); 215 | 216 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1); 217 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2); 218 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3); 219 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0); 220 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1); 221 | 222 | vi0x4567 = vi0x89AB; 223 | vi1x4567 = vi1x89AB; 224 | vi2x4567 = vi2x89AB; 225 | vi3x4567 = vi3x89AB; 226 | vi4x4567 = vi4x89AB; 227 | 228 | vo4567p00 = vaddq_f32(vo4567p00, vo4567p01); 229 | float32x4_t vo0 = vo4567p00; 230 | 231 | vo0 = vmaxq_f32(vo0, vmin); 232 | vo0 = vminq_f32(vo0, vmax); 233 | 234 | vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 235 | k -= 4; 236 | } 237 | assert(k >= 1); 238 | assert(k <= 4); 239 | { 240 | float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0); 241 | 242 | // This might have already happened if there are more than 4 pixels, but 243 | // we can't count on it. 244 | vi0x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x4567))); 245 | vi1x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x4567))); 246 | vi2x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x4567))); 247 | vi3x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3x4567))); 248 | vi4x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi4x4567))); 249 | 250 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3); 251 | float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0); 252 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1); 253 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2); 254 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3); 255 | 256 | const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3); 257 | const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3); 258 | const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3); 259 | const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3); 260 | const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3); 261 | 262 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2); 263 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3); 264 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0); 265 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1); 266 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2); 267 | 268 | const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2); 269 | const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2); 270 | const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2); 271 | const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2); 272 | const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2); 273 | 274 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1); 275 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2); 276 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3); 277 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0); 278 | vo4567p00 = 
vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1); 279 | 280 | const float32x4_t vzero = vmovq_n_f32(0.0f); 281 | const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1); 282 | const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1); 283 | const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1); 284 | const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vzero, 1); 285 | const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vzero, 1); 286 | 287 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0); 288 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1); 289 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2); 290 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 3); 291 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0); 292 | 293 | const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vzero, 2); 294 | const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vzero, 2); 295 | const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vzero, 2); 296 | const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vzero, 2); 297 | const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vzero, 2); 298 | 299 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1); 300 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2); 301 | vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3); 302 | vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0); 303 | vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1); 304 | 305 | vo4567p00 = vaddq_f32(vo4567p00, vo4567p01); 306 | float32x4_t vo0 = vo4567p00; 307 | 308 | vo0 = vmaxq_f32(vo0, vmin); 309 | vo0 = vminq_f32(vo0, vmax); 310 | 311 | if (k & 4) { 312 | vst1q_f32(output0, vo0); 313 | } else { 314 | float* output0_lo = output0; 315 | float32x2_t vo0_lo = vget_low_f32(vo0); 316 | if (k & 2) { 317 | vst1_f32(output0_lo, vo0_lo); output0_lo += 2; 318 | vo0_lo = vget_high_f32(vo0); 319 | } 320 | if (k & 1) { 321 | vst1_lane_f32(output0_lo, vo0_lo, 0); 322 | } 323 | } 324 | } 325 | 326 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single); 327 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single); 328 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single); 329 | i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single); 330 | i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single); 331 | output0 = (float*) ((uintptr_t) output0 + output_width_increment_single); 332 | m -= 1; 333 | } while (m > 0); 334 | } 335 | -------------------------------------------------------------------------------- /dwconv-5x5p2-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
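// Scalar reference version of the 5x5 stride-1 kernel: one output pixel per
// inner-loop iteration, using a 5-element sliding window per input row.
// The weight layout matches the NEON kernel above -- weights[0] is the bias
// and weights[1..25] hold the 5x5 filter in row-major order.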
5 | 6 | #include 7 | 8 | void f32_dwconv_spchw_ukernel_5x5p2__scalar( 9 | size_t m, 10 | size_t n, 11 | const float* input, 12 | const float* weights, 13 | float* output, 14 | size_t input_tuple_stride, 15 | size_t output_tuple_stride, 16 | size_t input_width_stride, 17 | size_t output_width_stride, 18 | const union f32_spchw_params params[restrict static 1]) 19 | { 20 | assert(n != 0); 21 | 22 | const float params_max = params->scalar.max; 23 | const float params_min = params->scalar.min; 24 | 25 | const size_t input_width_increment_single = input_width_stride - n * input_tuple_stride; 26 | const size_t output_width_increment_single = output_width_stride - (n - 1) * output_tuple_stride; 27 | 28 | // No vertical padding. 29 | const float* i0 = input; 30 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 31 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 32 | const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride); 33 | const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride); 34 | 35 | float* output0 = output; 36 | 37 | // this almost certainly will use too many scalar registers 38 | // hope the compiler is good at spilling... 39 | const float vw0 = weights[0]; 40 | const float vw1 = weights[1]; 41 | const float vw2 = weights[2]; 42 | const float vw3 = weights[3]; 43 | const float vw4 = weights[4]; 44 | const float vw5 = weights[5]; 45 | const float vw6 = weights[6]; 46 | const float vw7 = weights[7]; 47 | const float vw8 = weights[8]; 48 | const float vw9 = weights[9]; 49 | const float vw10 = weights[10]; 50 | const float vw11 = weights[11]; 51 | const float vw12 = weights[12]; 52 | const float vw13 = weights[13]; 53 | const float vw14 = weights[14]; 54 | const float vw15 = weights[15]; 55 | const float vw16 = weights[16]; 56 | const float vw17 = weights[17]; 57 | const float vw18 = weights[18]; 58 | const float vw19 = weights[19]; 59 | const float vw20 = weights[20]; 60 | const float vw21 = weights[21]; 61 | const float vw22 = weights[22]; 62 | const float vw23 = weights[23]; 63 | const float vw24 = weights[24]; 64 | const float vw25 = weights[25]; 65 | 66 | do { 67 | float vi0x0 = 0.0f; 68 | float vi1x0 = 0.0f; 69 | float vi2x0 = 0.0f; 70 | float vi3x0 = 0.0f; 71 | float vi4x0 = 0.0f; 72 | float vi0x1 = 0.0f; 73 | float vi1x1 = 0.0f; 74 | float vi2x1 = 0.0f; 75 | float vi3x1 = 0.0f; 76 | float vi4x1 = 0.0f; 77 | float vi0x2 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 78 | float vi1x2 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 79 | float vi2x2 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 80 | float vi3x2 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 81 | float vi4x2 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 82 | 83 | float vi0x3; 84 | float vi1x3; 85 | float vi2x3; 86 | float vi3x3; 87 | float vi4x3; 88 | if (n > 1) { 89 | vi0x3 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 90 | vi1x3 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 91 | vi2x3 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 92 | vi3x3 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 93 | vi4x3 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 94 | } 95 | 96 | size_t k = n; 97 | for (; k > 2; k -= 1) { 98 | const float vi0x4 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 99 | const float vi1x4 = *i1; i1 = (const float*) ((uintptr_t) 
i1 + input_tuple_stride); 100 | const float vi2x4 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 101 | const float vi3x4 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 102 | const float vi4x4 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 103 | 104 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2 + vw4 * vi0x3 + vw5 * vi0x4; 105 | vi0x0 = vi0x1; 106 | vi0x1 = vi0x2; 107 | vi0x2 = vi0x3; 108 | vi0x3 = vi0x4; 109 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2 + vw9 * vi1x3 + vw10 * vi1x4; 110 | vi1x0 = vi1x1; 111 | vi1x1 = vi1x2; 112 | vi1x2 = vi1x3; 113 | vi1x3 = vi1x4; 114 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2 + vw14 * vi2x3 + vw15 * vi2x4; 115 | vi2x0 = vi2x1; 116 | vi2x1 = vi2x2; 117 | vi2x2 = vi2x3; 118 | vi2x3 = vi2x4; 119 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2 + vw19 * vi3x3 + vw20 * vi3x4; 120 | vi3x0 = vi3x1; 121 | vi3x1 = vi3x2; 122 | vi3x2 = vi3x3; 123 | vi3x3 = vi3x4; 124 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2 + vw24 * vi4x3 + vw25 * vi4x4; 125 | vi4x0 = vi4x1; 126 | vi4x1 = vi4x2; 127 | vi4x2 = vi4x3; 128 | vi4x3 = vi4x4; 129 | 130 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 131 | 132 | voutput = math_max_f32(voutput, params_min); 133 | voutput = math_min_f32(voutput, params_max); 134 | 135 | *output0 = voutput; output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 136 | } 137 | if (k > 1) { 138 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2 + vw4 * vi0x3; 139 | vi0x0 = vi0x1; 140 | vi0x1 = vi0x2; 141 | vi0x2 = vi0x3; 142 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2 + vw9 * vi1x3; 143 | vi1x0 = vi1x1; 144 | vi1x1 = vi1x2; 145 | vi1x2 = vi1x3; 146 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2 + vw14 * vi2x3; 147 | vi2x0 = vi2x1; 148 | vi2x1 = vi2x2; 149 | vi2x2 = vi2x3; 150 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2 + vw19 * vi3x3; 151 | vi3x0 = vi3x1; 152 | vi3x1 = vi3x2; 153 | vi3x2 = vi3x3; 154 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2 + vw24 * vi4x3; 155 | vi4x0 = vi4x1; 156 | vi4x1 = vi4x2; 157 | vi4x2 = vi4x3; 158 | 159 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 160 | 161 | voutput = math_max_f32(voutput, params_min); 162 | voutput = math_min_f32(voutput, params_max); 163 | 164 | *output0 = voutput; output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 165 | k -= 1; 166 | } 167 | assert(k == 1); 168 | { 169 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2; 170 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2; 171 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2; 172 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2; 173 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2; 174 | 175 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 176 | 177 | voutput = math_max_f32(voutput, params_min); 178 | voutput = math_min_f32(voutput, params_max); 179 | 180 | *output0 = voutput;; 181 | } 182 | 183 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single); 184 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single); 185 | i2 = (const float*) ((uintptr_t) i2 + 
input_width_increment_single); 186 | i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single); 187 | i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single); 188 | output0 = (float*) ((uintptr_t) output0 + output_width_increment_single); 189 | m -= 1; 190 | } while (m > 0); 191 | } 192 | -------------------------------------------------------------------------------- /dwconv-5x5s2p2-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | 11 | void f32_dwconv_spchw_ukernel_5x5s2p2__neonfma( 12 | size_t m, 13 | size_t n, 14 | const float* input, 15 | const float* weights, 16 | float* output, 17 | size_t input_tuple_stride, 18 | size_t output_tuple_stride, 19 | size_t input_width_stride, 20 | size_t output_width_stride, 21 | const union f32_spchw_params params[restrict static 1]) 22 | { 23 | assert(n != 0); 24 | 25 | const uint32x4_t vmask_even = vld1q_u32(params->neon.mask_even); 26 | const uint32x4_t vmask_odd = vld1q_u32(params->neon.mask_odd); 27 | const float32x4_t vmax = vld1q_dup_f32(¶ms->neon.max); 28 | const float32x4_t vmin = vld1q_dup_f32(¶ms->neon.min); 29 | 30 | const size_t input_width_increment_single = input_width_stride * 2 - input_tuple_stride * ( (n - 1) / 4 + 1); 31 | const size_t output_width_increment_single = output_width_stride - (n + 1) / 8 * output_tuple_stride; 32 | 33 | // No vertical padding. 34 | const float* i0 = input; 35 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 36 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 37 | const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride); 38 | const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride); 39 | 40 | float* output0 = output; 41 | 42 | const float32x4_t vw0123 = vld1q_f32(weights); 43 | const float32x4_t vw4567 = vld1q_f32(weights + 4); 44 | const float32x4_t vw89AB = vld1q_f32(weights + 8); 45 | const float32x4_t vwCDEF = vld1q_f32(weights + 12); 46 | const float32x4_t vwGHIJ = vld1q_f32(weights + 16); 47 | const float32x4_t vwKLMN = vld1q_f32(weights + 20); 48 | const float32x2_t vwOP = vld1_f32( weights + 24); 49 | 50 | do { 51 | float32x4_t vi0x0123 = vmovq_n_f32(0.0f); 52 | float32x4_t vi1x0123 = vmovq_n_f32(0.0f); 53 | float32x4_t vi2x0123 = vmovq_n_f32(0.0f); 54 | float32x4_t vi3x0123 = vmovq_n_f32(0.0f); 55 | float32x4_t vi4x0123 = vmovq_n_f32(0.0f); 56 | float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 57 | float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 58 | float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 59 | float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 60 | float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 61 | 62 | long long k = n; 63 | for (; k > 0; k -= 8) { 64 | float32x4_t vo468Ap00 = vdupq_laneq_f32(vw0123, 0); 65 | 66 | float32x4_t vi0x89AB; 67 | float32x4_t vi1x89AB; 68 | float32x4_t vi2x89AB; 69 | float32x4_t vi3x89AB; 70 | float32x4_t vi4x89AB; 71 | 72 | if (k > 4) { 73 | vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 74 | vi1x89AB = 
vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 75 | vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 76 | vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 77 | vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 78 | } else { 79 | vi0x89AB = vmovq_n_f32(0.f); 80 | vi1x89AB = vmovq_n_f32(0.f); 81 | vi2x89AB = vmovq_n_f32(0.f); 82 | vi3x89AB = vmovq_n_f32(0.f); 83 | vi4x89AB = vmovq_n_f32(0.f); 84 | } 85 | 86 | float32x4_t vi0xCDEF; 87 | float32x4_t vi1xCDEF; 88 | float32x4_t vi2xCDEF; 89 | float32x4_t vi3xCDEF; 90 | float32x4_t vi4xCDEF; 91 | 92 | if (k > 8) { 93 | vi0xCDEF = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 94 | vi1xCDEF = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 95 | vi2xCDEF = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 96 | vi3xCDEF = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 97 | vi4xCDEF = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 98 | } else { 99 | vi0xCDEF = vmovq_n_f32(0.f); 100 | vi1xCDEF = vmovq_n_f32(0.f); 101 | vi2xCDEF = vmovq_n_f32(0.f); 102 | vi3xCDEF = vmovq_n_f32(0.f); 103 | vi4xCDEF = vmovq_n_f32(0.f); 104 | } 105 | float32x4_t vi0x468A = vuzp1q_f32(vi0x4567, vi0x89AB); 106 | float32x4_t vi0x579B = vuzp2q_f32(vi0x4567, vi0x89AB); 107 | float32x4_t vi1x468A = vuzp1q_f32(vi1x4567, vi1x89AB); 108 | float32x4_t vi1x579B = vuzp2q_f32(vi1x4567, vi1x89AB); 109 | float32x4_t vi2x468A = vuzp1q_f32(vi2x4567, vi2x89AB); 110 | float32x4_t vi2x579B = vuzp2q_f32(vi2x4567, vi2x89AB); 111 | float32x4_t vi3x468A = vuzp1q_f32(vi3x4567, vi3x89AB); 112 | float32x4_t vi3x579B = vuzp2q_f32(vi3x4567, vi3x89AB); 113 | float32x4_t vi4x468A = vuzp1q_f32(vi4x4567, vi4x89AB); 114 | float32x4_t vi4x579B = vuzp2q_f32(vi4x4567, vi4x89AB); 115 | 116 | if (k <= 8) { 117 | vi0x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vi0x468A))); 118 | vi1x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vi1x468A))); 119 | vi2x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vi2x468A))); 120 | vi3x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vi3x468A))); 121 | vi4x468A = vreinterpretq_u32_f32(vandq_u32(vmask_even, vreinterpretq_f32_u32(vi4x468A))); 122 | 123 | vi0x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vi0x579B))); 124 | vi1x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vi1x579B))); 125 | vi2x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vi2x579B))); 126 | vi3x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vi3x579B))); 127 | vi4x579B = vreinterpretq_u32_f32(vandq_u32(vmask_odd, vreinterpretq_f32_u32(vi4x579B))); 128 | } 129 | 130 | // middle tap 131 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x468A, vw0123, 3); 132 | float32x4_t vo468Ap01 = vmulq_laneq_f32(vi1x468A, vw89AB, 0); 133 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x468A, vwCDEF, 1); 134 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x468A, vwGHIJ, 2); 135 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x468A, vwKLMN, 3); 136 | 137 | // one left 138 | const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3); 139 | const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3); 140 | const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3); 141 | const 
float32x4_t vi3x3579 = vextq_f32(vi3x0123, vi3x579B, 3); 142 | const float32x4_t vi4x3579 = vextq_f32(vi4x0123, vi4x579B, 3); 143 | 144 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x3579, vw0123, 2); 145 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x3579, vw4567, 3); 146 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x3579, vwCDEF, 0); 147 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x3579, vwGHIJ, 1); 148 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x3579, vwKLMN, 2); 149 | 150 | // two left 151 | // getting the vector to use for the far left tap is annoying 152 | // as we can't ext anything we currently have to get it. 153 | // To do this, we get a bit ugly. Interpret the float 32x4 154 | // vector as int 64x2. Then left shift by 32. Interpret 155 | // again as float 32x4. Now the right most bits are what we 156 | // want them to be for the following ext. 157 | const float32x4_t vi0x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi0x0123), 32)); 158 | const float32x4_t vi1x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi1x0123), 32)); 159 | const float32x4_t vi2x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi2x0123), 32)); 160 | const float32x4_t vi3x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi3x0123), 32)); 161 | const float32x4_t vi4x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi4x0123), 32)); 162 | 163 | const float32x4_t vi0x2468 = vextq_f32(vi0x0012, vi0x468A, 3); 164 | const float32x4_t vi1x2468 = vextq_f32(vi1x0012, vi1x468A, 3); 165 | const float32x4_t vi2x2468 = vextq_f32(vi2x0012, vi2x468A, 3); 166 | const float32x4_t vi3x2468 = vextq_f32(vi3x0012, vi3x468A, 3); 167 | const float32x4_t vi4x2468 = vextq_f32(vi4x0012, vi4x468A, 3); 168 | 169 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x2468, vw0123, 1); 170 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x2468, vw4567, 2); 171 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x2468, vw89AB, 3); 172 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x2468, vwGHIJ, 0); 173 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x2468, vwKLMN, 1); 174 | 175 | vi0x0123 = vi0x89AB; 176 | vi1x0123 = vi1x89AB; 177 | vi2x0123 = vi2x89AB; 178 | vi3x0123 = vi3x89AB; 179 | vi4x0123 = vi4x89AB; 180 | 181 | // one right 182 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x579B, vw4567, 0); 183 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x579B, vw89AB, 1); 184 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x579B, vwCDEF, 2); 185 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x579B, vwGHIJ, 3); 186 | vo468Ap00 = vfmaq_lane_f32( vo468Ap00, vi4x579B, vwOP, 0); 187 | 188 | // two right 189 | const float32x4_t vi0x68AC = vextq_f32(vi0x468A, vi0xCDEF, 1); 190 | const float32x4_t vi1x68AC = vextq_f32(vi1x468A, vi1xCDEF, 1); 191 | const float32x4_t vi2x68AC = vextq_f32(vi2x468A, vi2xCDEF, 1); 192 | const float32x4_t vi3x68AC = vextq_f32(vi3x468A, vi3xCDEF, 1); 193 | const float32x4_t vi4x68AC = vextq_f32(vi4x468A, vi4xCDEF, 1); 194 | 195 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x68AC, vw4567, 1); 196 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x68AC, vw89AB, 2); 197 | vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x68AC, vwCDEF, 3); 198 | vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x68AC, vwKLMN, 0); 199 | vo468Ap00 = vfmaq_lane_f32( vo468Ap00, vi4x68AC, vwOP, 1); 200 | 201 | vi0x4567 = vi0xCDEF; 202 | vi1x4567 = vi1xCDEF; 203 | vi2x4567 = vi2xCDEF; 204 | vi3x4567 = vi3xCDEF; 205 | vi4x4567 = vi4xCDEF; 206 | 207 | float32x4_t vo0 = vaddq_f32(vo468Ap00, vo468Ap01); 208 | 209 | vo0 = vmaxq_f32(vo0, vmin); 210 | 
vo0 = vminq_f32(vo0, vmax); 211 | 212 | size_t k_tmp = (k + 1) / 2; 213 | if (k_tmp >= 4) { 214 | vst1q_f32(output0, vo0); 215 | output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 216 | } else { 217 | float* output0_lo = output0; 218 | float32x2_t vo0_lo = vget_low_f32(vo0); 219 | if (k_tmp & 2) { 220 | vst1_f32(output0_lo, vo0_lo); output0_lo += 2; 221 | vo0_lo = vget_high_f32(vo0); 222 | } 223 | if (k_tmp & 1) { 224 | vst1_lane_f32(output0_lo, vo0_lo, 0); 225 | } 226 | } 227 | } 228 | 229 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single); 230 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single); 231 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single); 232 | i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single); 233 | i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single); 234 | output0 = (float*) ((uintptr_t) output0 + output_width_increment_single); 235 | m -= 1; 236 | } while (m > 0); 237 | } 238 | -------------------------------------------------------------------------------- /dwconv-5x5s2p2-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | void f32_dwconv_spchw_ukernel_5x5s2p2__scalar( 9 | size_t m, 10 | size_t n, 11 | const float* input, 12 | const float* weights, 13 | float* output, 14 | size_t input_tuple_stride, 15 | size_t output_tuple_stride, 16 | size_t input_width_stride, 17 | size_t output_width_stride, 18 | const union f32_spchw_params params[restrict static 1]) 19 | { 20 | assert(n != 0); 21 | 22 | const float params_max = params->scalar.max; 23 | const float params_min = params->scalar.min; 24 | 25 | const size_t input_width_increment_single = input_width_stride * 2 - (1 + 2 * ((n - 1) / 2)) * input_tuple_stride; 26 | const size_t output_width_increment_single = output_width_stride - (n - 1) / 2 * output_tuple_stride; 27 | 28 | // No vertical padding. 29 | const float* i0 = input; 30 | const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride); 31 | const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride); 32 | const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride); 33 | const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride); 34 | 35 | float* output0 = output; 36 | 37 | // this almost certainly will use too many scalar registers 38 | // hope the compiler is good at spilling... 
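// Weight layout, as the accumulators below consume it: weights[0] is the bias
// and weights[1 + 5*r + c] is the filter tap applied to row r, column c of the
// current 5x5 input window (vrow0_accum uses vw1..vw5 on row 0, vrow1_accum
// uses vw6..vw10 on row 1, and so on). Each pass of the main loop emits one
// output and then slides the window right by two columns (stride 2) by keeping
// vi*x2..vi*x4 and loading two fresh columns on the next iteration.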
39 | const float vw0 = weights[0]; 40 | const float vw1 = weights[1]; 41 | const float vw2 = weights[2]; 42 | const float vw3 = weights[3]; 43 | const float vw4 = weights[4]; 44 | const float vw5 = weights[5]; 45 | const float vw6 = weights[6]; 46 | const float vw7 = weights[7]; 47 | const float vw8 = weights[8]; 48 | const float vw9 = weights[9]; 49 | const float vw10 = weights[10]; 50 | const float vw11 = weights[11]; 51 | const float vw12 = weights[12]; 52 | const float vw13 = weights[13]; 53 | const float vw14 = weights[14]; 54 | const float vw15 = weights[15]; 55 | const float vw16 = weights[16]; 56 | const float vw17 = weights[17]; 57 | const float vw18 = weights[18]; 58 | const float vw19 = weights[19]; 59 | const float vw20 = weights[20]; 60 | const float vw21 = weights[21]; 61 | const float vw22 = weights[22]; 62 | const float vw23 = weights[23]; 63 | const float vw24 = weights[24]; 64 | const float vw25 = weights[25]; 65 | 66 | do { 67 | float vi0x0 = 0.0f; 68 | float vi1x0 = 0.0f; 69 | float vi2x0 = 0.0f; 70 | float vi3x0 = 0.0f; 71 | float vi4x0 = 0.0f; 72 | float vi0x1 = 0.0f; 73 | float vi1x1 = 0.0f; 74 | float vi2x1 = 0.0f; 75 | float vi3x1 = 0.0f; 76 | float vi4x1 = 0.0f; 77 | float vi0x2 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 78 | float vi1x2 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 79 | float vi2x2 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 80 | float vi3x2 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 81 | float vi4x2 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 82 | 83 | 84 | size_t k = n; 85 | for (; k > 2; k -= 2) { 86 | const float vi0x3 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 87 | const float vi1x3 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 88 | const float vi2x3 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 89 | const float vi3x3 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 90 | const float vi4x3 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 91 | 92 | const float vi0x4 = *i0; i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride); 93 | const float vi1x4 = *i1; i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride); 94 | const float vi2x4 = *i2; i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride); 95 | const float vi3x4 = *i3; i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride); 96 | const float vi4x4 = *i4; i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride); 97 | 98 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2 + vw4 * vi0x3 + vw5 * vi0x4; 99 | vi0x0 = vi0x2; 100 | vi0x1 = vi0x3; 101 | vi0x2 = vi0x4; 102 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2 + vw9 * vi1x3 + vw10 * vi1x4; 103 | vi1x0 = vi1x2; 104 | vi1x1 = vi1x3; 105 | vi1x2 = vi1x4; 106 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2 + vw14 * vi2x3 + vw15 * vi2x4; 107 | vi2x0 = vi2x2; 108 | vi2x1 = vi2x3; 109 | vi2x2 = vi2x4; 110 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2 + vw19 * vi3x3 + vw20 * vi3x4; 111 | vi3x0 = vi3x2; 112 | vi3x1 = vi3x3; 113 | vi3x2 = vi3x4; 114 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2 + vw24 * vi4x3 + vw25 * vi4x4; 115 | vi4x0 = vi4x2; 116 | vi4x1 = vi4x3; 117 | vi4x2 = vi4x4; 118 | 119 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 120 | 121 | voutput = 
math_max_f32(voutput, params_min); 122 | voutput = math_min_f32(voutput, params_max); 123 | 124 | *output0 = voutput; output0 = (float*) ((uintptr_t) output0 + output_tuple_stride); 125 | } 126 | if (k == 2) { 127 | const float vi0x3 = *i0; 128 | const float vi1x3 = *i1; 129 | const float vi2x3 = *i2; 130 | const float vi3x3 = *i3; 131 | const float vi4x3 = *i4; 132 | 133 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2 + vw4 * vi0x3; 134 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2 + vw9 * vi1x3; 135 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2 + vw14 * vi2x3; 136 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2 + vw19 * vi3x3; 137 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2 + vw24 * vi4x3; 138 | 139 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 140 | 141 | voutput = math_max_f32(voutput, params_min); 142 | voutput = math_min_f32(voutput, params_max); 143 | 144 | *output0 = voutput; 145 | } 146 | else { 147 | const float vrow0_accum = vw1 * vi0x0 + vw2 * vi0x1 + vw3 * vi0x2; 148 | const float vrow1_accum = vw6 * vi1x0 + vw7 * vi1x1 + vw8 * vi1x2; 149 | const float vrow2_accum = vw11 * vi2x0 + vw12 * vi2x1 + vw13 * vi2x2; 150 | const float vrow3_accum = vw16 * vi3x0 + vw17 * vi3x1 + vw18 * vi3x2; 151 | const float vrow4_accum = vw21 * vi4x0 + vw22 * vi4x1 + vw23 * vi4x2; 152 | 153 | float voutput = (vw0 + vrow0_accum) + (vrow1_accum + vrow2_accum) + (vrow3_accum + vrow4_accum); 154 | 155 | voutput = math_max_f32(voutput, params_min); 156 | voutput = math_min_f32(voutput, params_max); 157 | 158 | *output0 = voutput; 159 | } 160 | 161 | i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single); 162 | i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single); 163 | i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single); 164 | i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single); 165 | i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single); 166 | output0 = (float*) ((uintptr_t) output0 + output_width_increment_single); 167 | m -= 1; 168 | } while (m > 0); 169 | } 170 | -------------------------------------------------------------------------------- /gavgpool-neon-x4.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
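// Layout notes for this kernel: the input is channel-major ("spchw"), with
// `elements` bytes of contiguous floats per channel, and the output is one
// float per channel. `multiplier` in params is expected to hold the reciprocal
// of the per-channel element count, so that sum * multiplier is the mean.
// Four channels are reduced per iteration of the main loop, with a
// one-channel-at-a-time tail loop. When fewer than four floats remain in a
// row, a full vector is still loaded and the lanes past the remainder are
// zeroed with the precomputed mask from params before being accumulated, so
// they do not affect the sum; the small over-read is assumed to be safe for
// the caller's buffer layout.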
5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_gavgpool_spchw_ukernel__neon_x4( 11 | size_t elements, 12 | size_t channels, 13 | const float* input, 14 | float* output, 15 | const union f32_gavgpool_params params[restrict static 1]) 16 | { 17 | assert(elements != 0); 18 | assert(elements % sizeof(float) == 0); 19 | assert(channels != 0); 20 | 21 | const float* i0 = input; 22 | const float* i1 = (const float*) ((uintptr_t) i0 + elements); 23 | const float* i2 = (const float*) ((uintptr_t) i1 + elements); 24 | const float* i3 = (const float*) ((uintptr_t) i2 + elements); 25 | 26 | const uint32x4_t vmask = vld1q_u32(params->neon.mask); 27 | const float32x4_t vmultiplier = vld1q_dup_f32(¶ms->neon.multiplier); 28 | const float32x4_t voutput_min = vld1q_dup_f32(¶ms->neon.output_min); 29 | const float32x4_t voutput_max = vld1q_dup_f32(¶ms->neon.output_max); 30 | 31 | while (channels >= 4) { 32 | float32x4_t vsum0 = vmovq_n_f32(0.0f); 33 | float32x4_t vsum1 = vmovq_n_f32(0.0f); 34 | float32x4_t vsum2 = vmovq_n_f32(0.0f); 35 | float32x4_t vsum3 = vmovq_n_f32(0.0f); 36 | size_t n = elements; 37 | while (n >= 4 * sizeof(float)) { 38 | const float32x4_t vi0 = vld1q_f32(i0); i0 += 4; 39 | const float32x4_t vi1 = vld1q_f32(i1); i1 += 4; 40 | const float32x4_t vi2 = vld1q_f32(i2); i2 += 4; 41 | const float32x4_t vi3 = vld1q_f32(i3); i3 += 4; 42 | 43 | vsum0 = vaddq_f32(vsum0, vi0); 44 | vsum1 = vaddq_f32(vsum1, vi1); 45 | vsum2 = vaddq_f32(vsum2, vi2); 46 | vsum3 = vaddq_f32(vsum3, vi3); 47 | n -= 4 * sizeof(float); 48 | } 49 | 50 | if (n != 0) { 51 | float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n); 52 | float32x4_t vi1 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + n); 53 | float32x4_t vi2 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + n); 54 | float32x4_t vi3 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + n); 55 | 56 | vi0 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0))); 57 | vi1 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1))); 58 | vi2 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2))); 59 | vi3 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3))); 60 | 61 | vsum0 = vaddq_f32(vsum0, vi0); 62 | vsum1 = vaddq_f32(vsum1, vi1); 63 | vsum2 = vaddq_f32(vsum2, vi2); 64 | vsum3 = vaddq_f32(vsum3, vi3); 65 | } 66 | 67 | // Having exaclty 4 rows makes this work out nicely as we end up with 68 | // the 4 totals in 4 different lanes of the same vector. 
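// Concretely, vpaddq_f32(vsum0, vsum1) yields
//   { vsum0[0]+vsum0[1], vsum0[2]+vsum0[3], vsum1[0]+vsum1[1], vsum1[2]+vsum1[3] },
// so the second vpaddq_f32 leaves the full horizontal sums of vsum0..vsum3 in
// lanes 0..3 of vsum, ready to be scaled and clamped in one go. The AArch32
// path below builds the same four totals from 64-bit halves instead.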
69 | #ifdef __aarch64__ 70 | const float32x4_t vsum01 = vpaddq_f32(vsum0, vsum1); 71 | const float32x4_t vsum23 = vpaddq_f32(vsum2, vsum3); 72 | const float32x4_t vsum = vpaddq_f32(vsum01, vsum23); 73 | #else 74 | const float32x4_t vsum01 = vcombine_f32(vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0)), 75 | vadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1))); 76 | const float32x4_t vsum23 = vcombine_f32(vadd_f32(vget_low_f32(vsum2), vget_high_f32(vsum2)), 77 | vadd_f32(vget_low_f32(vsum3), vget_high_f32(vsum3))); 78 | const float32x4_t vsum = vcombine_f32(vpadd_f32(vget_low_f32(vsum01), vget_high_f32(vsum01)), 79 | vpadd_f32(vget_low_f32(vsum23), vget_high_f32(vsum23))); 80 | #endif 81 | 82 | float32x4_t vout = vmulq_f32(vsum, vmultiplier); 83 | 84 | vout = vmaxq_f32(vout, voutput_min); 85 | vout = vminq_f32(vout, voutput_max); 86 | 87 | vst1q_f32(output, vout); output += 4; 88 | i0 = i3; 89 | i1 = (const float*) ((uintptr_t) i0 + elements); 90 | i2 = (const float*) ((uintptr_t) i1 + elements); 91 | i3 = (const float*) ((uintptr_t) i2 + elements); 92 | channels -= 4; 93 | } 94 | 95 | while (channels != 0) { 96 | float32x4_t vsum0 = vmovq_n_f32(0.0f); 97 | size_t n = elements; 98 | while (n >= 4 * sizeof(float)) { 99 | const float32x4_t vi0 = vld1q_f32(i0); i0 += 4; 100 | vsum0 = vaddq_f32(vsum0, vi0); 101 | n -= 4 * sizeof(float); 102 | } 103 | 104 | if (n != 0) { 105 | float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n); 106 | vi0 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0))); 107 | vsum0 = vaddq_f32(vsum0, vi0); 108 | } 109 | 110 | float32x2_t vsum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0)); 111 | vsum = vpadd_f32(vsum, vsum); 112 | 113 | float32x2_t vout = vmul_f32(vsum, vget_low_f32(vmultiplier)); 114 | 115 | vout = vmax_f32(vout, vget_low_f32(voutput_min)); 116 | vout = vmin_f32(vout, vget_low_f32(voutput_max)); 117 | 118 | vst1_lane_f32(output, vout, 0); output += 1; 119 | channels -= 1; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /gavgpool-scalar-x1.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
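// Scalar counterpart of the NEON kernel above: four running sums each take
// every fourth element so the additions are independent of one another (better
// instruction-level parallelism than one serial accumulator), any leftover
// elements fall into vsum0, and the four sums are combined pairwise once per
// channel before scaling and clamping.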
5 | 6 | #include 7 | 8 | 9 | void f32_gavgpool_spchw_ukernel__scalar_x1( 10 | size_t elements, 11 | size_t channels, 12 | const float* input, 13 | float* output, 14 | const union f32_gavgpool_params params[restrict static 1]) 15 | { 16 | assert(elements != 0); 17 | assert(elements % sizeof(float) == 0); 18 | assert(channels != 0); 19 | 20 | const float* i0 = input; 21 | 22 | const float vmultiplier = params->scalar.multiplier; 23 | const float voutput_max = params->scalar.output_max; 24 | const float voutput_min = params->scalar.output_min; 25 | 26 | while (channels != 0) { 27 | float vsum0 = 0.f; 28 | float vsum1 = 0.f; 29 | float vsum2 = 0.f; 30 | float vsum3 = 0.f; 31 | size_t n = elements; 32 | while (n >= 4 * sizeof(float)) { 33 | vsum0 += i0[0]; 34 | vsum1 += i0[1]; 35 | vsum2 += i0[2]; 36 | vsum3 += i0[3]; 37 | 38 | i0 += 4; 39 | n -= 4 * sizeof(float); 40 | } 41 | 42 | while (n != 0) { 43 | vsum0 += *i0++; 44 | n -= sizeof(float); 45 | } 46 | 47 | float vout = ( (vsum0 + vsum1) + (vsum2 + vsum3) ) * vmultiplier; 48 | 49 | vout = math_min_f32(vout, voutput_max); 50 | vout = math_max_f32(vout, voutput_min); 51 | 52 | *output++ = vout; 53 | channels -= 1; 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /spmm-16x1-neonfma-pipelined.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_spmm_ukernel_16x1__neonfma_pipelined( 11 | uint32_t m, 12 | uint32_t n, 13 | const float*restrict a, 14 | const float*restrict weights, 15 | const int32_t*restrict widx_dmap, 16 | const uint32_t*restrict nidx_nnzmap, 17 | float*restrict c, 18 | const union f32_output_params params[restrict static 1]) 19 | { 20 | assert(m != 0); 21 | 22 | const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); 23 | const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); 24 | size_t i = m; 25 | while (i >= 16) { 26 | const float*restrict w = weights; 27 | const int32_t* dmap = widx_dmap; 28 | const uint32_t* nnzmap = nidx_nnzmap; 29 | float32x4_t vw = vld1q_dup_f32(w); w += 1; 30 | intptr_t diff = *dmap++; 31 | float32x4_t va0123 = vld1q_f32(a); 32 | float32x4_t va4567 = vld1q_f32(a + 4); 33 | float32x4_t va89AB = vld1q_f32(a + 8); 34 | float32x4_t vaCDEF = vld1q_f32(a + 12); 35 | __builtin_prefetch(a + 16); 36 | size_t j = n; 37 | do { 38 | uint32_t nnz = *nnzmap++; 39 | float32x4_t vacc0123 = vw; 40 | float32x4_t vacc4567 = vw; 41 | float32x4_t vacc89AB = vw; 42 | float32x4_t vaccCDEF = vw; 43 | vw = vld1q_dup_f32(w); w += 1; 44 | if (nnz != 0) { 45 | do { 46 | vacc0123 = vfmaq_f32(vacc0123, va0123, vw); 47 | vacc4567 = vfmaq_f32(vacc4567, va4567, vw); 48 | vacc89AB = vfmaq_f32(vacc89AB, va89AB, vw); 49 | vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vw); 50 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 51 | 52 | diff = *dmap++; 53 | vw = vld1q_dup_f32(w); w += 1; 54 | va0123 = vld1q_f32(a); 55 | va4567 = vld1q_f32(a + 4); 56 | va89AB = vld1q_f32(a + 8); 57 | vaCDEF = vld1q_f32(a + 12); 58 | __builtin_prefetch(a + 16); 59 | } while (--nnz != 0); 60 | } 61 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 62 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 63 | float32x4_t vout89AB = vminq_f32(vacc89AB, vmax); 64 | float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax); 65 | vout0123 
= vmaxq_f32(vout0123, vmin); 66 | vout4567 = vmaxq_f32(vout4567, vmin); 67 | vout89AB = vmaxq_f32(vout89AB, vmin); 68 | voutCDEF = vmaxq_f32(voutCDEF, vmin); 69 | vst1q_f32(c, vout0123); 70 | vst1q_f32(c + 4, vout4567); 71 | vst1q_f32(c + 8, vout89AB); 72 | vst1q_f32(c + 12, voutCDEF); 73 | c += m; 74 | } while (--j != 0); 75 | c -= m * n; 76 | c += 16; 77 | a += 16; 78 | i -= 16; 79 | } 80 | if (i != 0) { 81 | if (i & 8) { 82 | const float*restrict w = weights; 83 | const int32_t* dmap = widx_dmap; 84 | const uint32_t* nnzmap = nidx_nnzmap; 85 | size_t j = n; 86 | do { 87 | uint32_t nnz = *nnzmap++; 88 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 89 | float32x4_t vacc4567 = vacc0123; 90 | if (nnz != 0) { 91 | do { 92 | const intptr_t diff = *dmap++; 93 | const float32x4_t va0123 = vld1q_f32(a); 94 | const float32x4_t va4567 = vld1q_f32(a + 4); 95 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 96 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 97 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 98 | vacc4567 = vfmaq_f32(vacc4567, va4567, vb); 99 | } while (--nnz != 0); 100 | } 101 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 102 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 103 | vout0123 = vmaxq_f32(vout0123, vmin); 104 | vout4567 = vmaxq_f32(vout4567, vmin); 105 | vst1q_f32(c, vout0123); 106 | vst1q_f32(c + 4, vout4567); 107 | c += m; 108 | } while (--j != 0); 109 | c -= m * n; 110 | c += 8; 111 | a += 8; 112 | } 113 | if (i & 4) { 114 | const float*restrict w = weights; 115 | const int32_t* dmap = widx_dmap; 116 | const uint32_t* nnzmap = nidx_nnzmap; 117 | size_t j = n; 118 | do { 119 | uint32_t nnz = *nnzmap++; 120 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 121 | if (nnz != 0) { 122 | do { 123 | const intptr_t diff = *dmap++; 124 | const float32x4_t va0123 = vld1q_f32(a); 125 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 126 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 127 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 128 | } while (--nnz != 0); 129 | } 130 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 131 | vout0123 = vmaxq_f32(vout0123, vmin); 132 | vst1q_f32(c, vout0123); 133 | c += m; 134 | } while (--j != 0); 135 | c -= m * n; 136 | c += 4; 137 | a += 4; 138 | } 139 | if (i & 2) { 140 | const float*restrict w = weights; 141 | const int32_t* dmap = widx_dmap; 142 | const uint32_t* nnzmap = nidx_nnzmap; 143 | size_t j = n; 144 | do { 145 | uint32_t nnz = *nnzmap++; 146 | float32x2_t vacc01 = vld1_dup_f32(w); w += 1; 147 | if (nnz != 0) { 148 | do { 149 | const intptr_t diff = *dmap++; 150 | const float32x2_t va01 = vld1_f32(a); 151 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 152 | const float32x2_t vb = vld1_dup_f32(w); w += 1; 153 | vacc01 = vfma_f32(vacc01, va01, vb); 154 | } while (--nnz != 0); 155 | } 156 | float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax)); 157 | vout01 = vmax_f32(vout01, vget_low_f32(vmin)); 158 | vst1_f32(c, vout01); 159 | c += m; 160 | } while (--j != 0); 161 | c -= m * n; 162 | c += 2; 163 | a += 2; 164 | } 165 | if (i & 1) { 166 | const float*restrict w = weights; 167 | const int32_t* dmap = widx_dmap; 168 | const uint32_t* nnzmap = nidx_nnzmap; 169 | size_t j = n; 170 | do { 171 | uint32_t nnz = *nnzmap++; 172 | float32x2_t vacc0 = vld1_dup_f32(w); w += 1; 173 | if (nnz != 0) { 174 | do { 175 | const intptr_t diff = *dmap++; 176 | const float32x2_t va0 = vld1_dup_f32(a); 177 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 178 | const 
float32x2_t vb = vld1_dup_f32(w); w += 1; 179 | vacc0 = vfma_f32(vacc0, va0, vb); 180 | } while (--nnz != 0); 181 | } 182 | float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax)); 183 | vout0 = vmax_f32(vout0, vget_low_f32(vmin)); 184 | vst1_lane_f32(c, vout0, 0); 185 | c += m; 186 | } while (--j != 0); 187 | c -= m * n; 188 | c += 1; 189 | a += 1; 190 | } 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /spmm-16x2-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_spmm_ukernel_16x2__neonfma( 11 | uint32_t m, 12 | uint32_t n, 13 | const float*restrict a, 14 | const float*restrict weights, 15 | const int32_t*restrict widx_dmap, 16 | const uint32_t*restrict nidx_nnzmap, 17 | float*restrict c, 18 | const union f32_output_params params[restrict static 1]) 19 | { 20 | assert(m != 0); 21 | 22 | const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); 23 | const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); 24 | size_t i = m; 25 | while (i >= 16) { 26 | const float*restrict w = weights; 27 | const int32_t* dmap = widx_dmap; 28 | const uint32_t* nnzmap = nidx_nnzmap; 29 | size_t j = n; 30 | while (j >= 2) { 31 | uint32_t nnz = *nnzmap++; 32 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 33 | float32x4_t vacc4567c0 = vacc0123c0; 34 | float32x4_t vacc89ABc0 = vacc0123c0; 35 | float32x4_t vaccCDEFc0 = vacc0123c0; 36 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 37 | float32x4_t vacc4567c1 = vacc0123c1; 38 | float32x4_t vacc89ABc1 = vacc0123c1; 39 | float32x4_t vaccCDEFc1 = vacc0123c1; 40 | if (nnz != 0) { 41 | do { 42 | const intptr_t diff = *dmap++; 43 | const float32x4_t va0123 = vld1q_f32(a); 44 | const float32x4_t va4567 = vld1q_f32(a + 4); 45 | const float32x4_t va89AB = vld1q_f32(a + 8); 46 | const float32x4_t vaCDEF = vld1q_f32(a + 12); 47 | __builtin_prefetch(a + 16); 48 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 49 | const float32x2_t vb = vld1_f32(w); w += 2; 50 | 51 | vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0); 52 | vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0); 53 | vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0); 54 | vaccCDEFc0 = vfmaq_lane_f32(vaccCDEFc0, vaCDEF, vb, 0); 55 | vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1); 56 | vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1); 57 | vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1); 58 | vaccCDEFc1 = vfmaq_lane_f32(vaccCDEFc1, vaCDEF, vb, 1); 59 | } while (--nnz != 0); 60 | } 61 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 62 | float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax); 63 | float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax); 64 | float32x4_t voutCDEFc0 = vminq_f32(vaccCDEFc0, vmax); 65 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 66 | float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax); 67 | float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax); 68 | float32x4_t voutCDEFc1 = vminq_f32(vaccCDEFc1, vmax); 69 | 70 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 71 | vout4567c0 = vmaxq_f32(vout4567c0, vmin); 72 | vout89ABc0 = vmaxq_f32(vout89ABc0, vmin); 73 | voutCDEFc0 = vmaxq_f32(voutCDEFc0, vmin); 74 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 75 | vout4567c1 = 
vmaxq_f32(vout4567c1, vmin); 76 | vout89ABc1 = vmaxq_f32(vout89ABc1, vmin); 77 | voutCDEFc1 = vmaxq_f32(voutCDEFc1, vmin); 78 | 79 | vst1q_f32(c + 0 * m + 0, vout0123c0); 80 | vst1q_f32(c + 0 * m + 4, vout4567c0); 81 | vst1q_f32(c + 0 * m + 8, vout89ABc0); 82 | vst1q_f32(c + 0 * m + 12, voutCDEFc0); 83 | vst1q_f32(c + 1 * m + 0, vout0123c1); 84 | vst1q_f32(c + 1 * m + 4, vout4567c1); 85 | vst1q_f32(c + 1 * m + 8, vout89ABc1); 86 | vst1q_f32(c + 1 * m + 12, voutCDEFc1); 87 | c += 2 * m; 88 | j -= 2; 89 | } 90 | 91 | // clean up loop, fall back to nr=1 92 | if (j != 0) { 93 | do { 94 | uint32_t nnz = *nnzmap++; 95 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 96 | float32x4_t vacc4567 = vacc0123; 97 | float32x4_t vacc89AB = vacc0123; 98 | float32x4_t vaccCDEF = vacc0123; 99 | if (nnz != 0) { 100 | do { 101 | const intptr_t diff = *dmap++; 102 | const float32x4_t va0123 = vld1q_f32(a); 103 | const float32x4_t va4567 = vld1q_f32(a + 4); 104 | const float32x4_t va89AB = vld1q_f32(a + 8); 105 | const float32x4_t vaCDEF = vld1q_f32(a + 12); 106 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 107 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 108 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 109 | vacc4567 = vfmaq_f32(vacc4567, va4567, vb); 110 | vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb); 111 | vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb); 112 | } while (--nnz != 0); 113 | } 114 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 115 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 116 | float32x4_t vout89AB = vminq_f32(vacc89AB, vmax); 117 | float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax); 118 | 119 | vout0123 = vmaxq_f32(vout0123, vmin); 120 | vout4567 = vmaxq_f32(vout4567, vmin); 121 | vout89AB = vmaxq_f32(vout89AB, vmin); 122 | voutCDEF = vmaxq_f32(voutCDEF, vmin); 123 | 124 | vst1q_f32(c + 0, vout0123); 125 | vst1q_f32(c + 4, vout4567); 126 | vst1q_f32(c + 8, vout89AB); 127 | vst1q_f32(c + 12, voutCDEF); 128 | c += m; 129 | j -= 1; 130 | } while (j != 0); 131 | } 132 | c -= m * n; 133 | c += 16; 134 | a += 16; 135 | i -= 16; 136 | } 137 | if (i != 0) { 138 | if (i & 8) { 139 | const float*restrict w = weights; 140 | const int32_t* dmap = widx_dmap; 141 | const uint32_t* nnzmap = nidx_nnzmap; 142 | size_t j = n; 143 | while (j >= 2) { 144 | uint32_t nnz = *nnzmap++; 145 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 146 | float32x4_t vacc4567c0 = vacc0123c0; 147 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 148 | float32x4_t vacc4567c1 = vacc0123c1; 149 | if (nnz != 0) { 150 | do { 151 | const intptr_t diff = *dmap++; 152 | const float32x4_t va0123 = vld1q_f32(a); 153 | const float32x4_t va4567 = vld1q_f32(a + 4); 154 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 155 | const float32x2_t vb = vld1_f32(w); w += 2; 156 | 157 | vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0); 158 | vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0); 159 | vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1); 160 | vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1); 161 | } while (--nnz != 0); 162 | } 163 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 164 | float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax); 165 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 166 | float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax); 167 | 168 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 169 | vout4567c0 = vmaxq_f32(vout4567c0, vmin); 170 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 171 | vout4567c1 = vmaxq_f32(vout4567c1, vmin); 172 | 
173 | vst1q_f32(c + 0 * m + 0, vout0123c0); 174 | vst1q_f32(c + 0 * m + 4, vout4567c0); 175 | vst1q_f32(c + 1 * m + 0, vout0123c1); 176 | vst1q_f32(c + 1 * m + 4, vout4567c1); 177 | c += 2 * m; 178 | j -= 2; 179 | } 180 | 181 | // clean up loop, fall back to nr=1 182 | if (j != 0) { 183 | do { 184 | uint32_t nnz = *nnzmap++; 185 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 186 | float32x4_t vacc4567 = vacc0123; 187 | if (nnz != 0) { 188 | do { 189 | const intptr_t diff = *dmap++; 190 | const float32x4_t va0123 = vld1q_f32(a); 191 | const float32x4_t va4567 = vld1q_f32(a + 4); 192 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 193 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 194 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 195 | vacc4567 = vfmaq_f32(vacc4567, va4567, vb); 196 | } while (--nnz != 0); 197 | } 198 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 199 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 200 | 201 | vout0123 = vmaxq_f32(vout0123, vmin); 202 | vout4567 = vmaxq_f32(vout4567, vmin); 203 | 204 | vst1q_f32(c + 0, vout0123); 205 | vst1q_f32(c + 4, vout4567); 206 | c += m; 207 | j -= 1; 208 | } while (j != 0); 209 | } 210 | c -= m * n; 211 | c += 8; 212 | a += 8; 213 | } 214 | if (i & 4) { 215 | const float*restrict w = weights; 216 | const int32_t* dmap = widx_dmap; 217 | const uint32_t* nnzmap = nidx_nnzmap; 218 | size_t j = n; 219 | while (j >= 2) { 220 | uint32_t nnz = *nnzmap++; 221 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 222 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 223 | if (nnz != 0) { 224 | do { 225 | const intptr_t diff = *dmap++; 226 | const float32x4_t va0123 = vld1q_f32(a); 227 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 228 | const float32x2_t vb = vld1_f32(w); w += 2; 229 | 230 | vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0); 231 | vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1); 232 | } while (--nnz != 0); 233 | } 234 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 235 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 236 | 237 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 238 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 239 | 240 | vst1q_f32(c + 0 * m + 0, vout0123c0); 241 | vst1q_f32(c + 1 * m + 0, vout0123c1); 242 | c += 2 * m; 243 | j -= 2; 244 | } 245 | 246 | // clean up loop, fall back to nr=1 247 | if (j != 0) { 248 | do { 249 | uint32_t nnz = *nnzmap++; 250 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 251 | if (nnz != 0) { 252 | do { 253 | const intptr_t diff = *dmap++; 254 | const float32x4_t va0123 = vld1q_f32(a); 255 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 256 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 257 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 258 | } while (--nnz != 0); 259 | } 260 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 261 | 262 | vout0123 = vmaxq_f32(vout0123, vmin); 263 | 264 | vst1q_f32(c + 0, vout0123); 265 | c += m; 266 | j -= 1; 267 | } while (j != 0); 268 | } 269 | c -= m * n; 270 | c += 4; 271 | a += 4; 272 | } 273 | if (i & 2) { 274 | const float*restrict w = weights; 275 | const int32_t* dmap = widx_dmap; 276 | const uint32_t* nnzmap = nidx_nnzmap; 277 | size_t j = n; 278 | while (j >= 2) { 279 | uint32_t nnz = *nnzmap++; 280 | float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1; 281 | float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1; 282 | if (nnz != 0) { 283 | do { 284 | const intptr_t diff = *dmap++; 285 | const float32x2_t va01 = vld1_f32(a); 286 | a = (const 
float*restrict) ((uintptr_t) a + (uintptr_t) diff); 287 | const float32x2_t vb = vld1_f32(w); w += 2; 288 | 289 | vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0); 290 | vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1); 291 | } while (--nnz != 0); 292 | } 293 | float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax)); 294 | float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax)); 295 | 296 | vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin)); 297 | vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin)); 298 | 299 | vst1_f32(c + 0 * m + 0, vout01c0); 300 | vst1_f32(c + 1 * m + 0, vout01c1); 301 | c += 2 * m; 302 | j -= 2; 303 | } 304 | 305 | // clean up loop, fall back to nr=1 306 | if (j != 0) { 307 | do { 308 | uint32_t nnz = *nnzmap++; 309 | float32x2_t vacc01 = vld1_dup_f32(w); w += 1; 310 | if (nnz != 0) { 311 | do { 312 | const intptr_t diff = *dmap++; 313 | const float32x2_t va01 = vld1_f32(a); 314 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 315 | const float32x2_t vb = vld1_dup_f32(w); w += 1; 316 | vacc01 = vfma_f32(vacc01, va01, vb); 317 | } while (--nnz != 0); 318 | } 319 | float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax)); 320 | vout01 = vmax_f32(vout01, vget_low_f32(vmin)); 321 | 322 | vst1_f32(c, vout01); 323 | c += m; 324 | j -= 1; 325 | } while (j != 0); 326 | } 327 | c -= m * n; 328 | c += 2; 329 | a += 2; 330 | } 331 | if (i & 1) { 332 | const float*restrict w = weights; 333 | const int32_t* dmap = widx_dmap; 334 | const uint32_t* nnzmap = nidx_nnzmap; 335 | size_t j = n; 336 | while (j >= 2) { 337 | uint32_t nnz = *nnzmap++; 338 | float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1; 339 | float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1; 340 | if (nnz != 0) { 341 | do { 342 | const intptr_t diff = *dmap++; 343 | const float32x2_t va0 = vld1_dup_f32(a); 344 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 345 | const float32x2_t vb = vld1_f32(w); w += 2; 346 | 347 | vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0); 348 | vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1); 349 | } while (--nnz != 0); 350 | } 351 | float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax)); 352 | float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax)); 353 | 354 | vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin)); 355 | vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin)); 356 | 357 | vst1_lane_f32(c + 0 * m + 0, vout0c0, 0); 358 | vst1_lane_f32(c + 1 * m + 0, vout0c1, 0); 359 | c += 2 * m; 360 | j -= 2; 361 | } 362 | 363 | // clean up loop, fall back to nr=1 364 | if (j != 0) { 365 | do { 366 | uint32_t nnz = *nnzmap++; 367 | float32x2_t vacc0 = vld1_dup_f32(w); w += 1; 368 | if (nnz != 0) { 369 | do { 370 | const intptr_t diff = *dmap++; 371 | const float32x2_t va0 = vld1_dup_f32(a); 372 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 373 | const float32x2_t vb = vld1_dup_f32(w); w += 1; 374 | vacc0 = vfma_f32(vacc0, va0, vb); 375 | } while (--nnz != 0); 376 | } 377 | float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax)); 378 | vout0 = vmax_f32(vout0, vget_low_f32(vmin)); 379 | 380 | vst1_lane_f32(c, vout0, 1); 381 | c += m; 382 | j -= 1; 383 | } while (j != 0); 384 | } 385 | c -= m * n; 386 | c += 1; 387 | a += 1; 388 | } 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /spmm-16x4-neonfma.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 
2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | #include 9 | 10 | void f32_spmm_ukernel_16x4__neonfma( 11 | uint32_t m, 12 | uint32_t n, 13 | const float*restrict a, 14 | const float*restrict weights, 15 | const int32_t*restrict widx_dmap, 16 | const uint32_t*restrict nidx_nnzmap, 17 | float*restrict c, 18 | const union f32_output_params params[restrict static 1]) 19 | { 20 | assert(m != 0); 21 | 22 | const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); 23 | const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); 24 | size_t i = m; 25 | while (i >= 16) { 26 | const float*restrict w = weights; 27 | const int32_t* dmap = widx_dmap; 28 | const uint32_t* nnzmap = nidx_nnzmap; 29 | size_t j = n; 30 | while (j >= 4) { 31 | uint32_t nnz = *nnzmap++; 32 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 33 | float32x4_t vacc4567c0 = vacc0123c0; 34 | float32x4_t vacc89ABc0 = vacc0123c0; 35 | float32x4_t vaccCDEFc0 = vacc0123c0; 36 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 37 | float32x4_t vacc4567c1 = vacc0123c1; 38 | float32x4_t vacc89ABc1 = vacc0123c1; 39 | float32x4_t vaccCDEFc1 = vacc0123c1; 40 | float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1; 41 | float32x4_t vacc4567c2 = vacc0123c2; 42 | float32x4_t vacc89ABc2 = vacc0123c2; 43 | float32x4_t vaccCDEFc2 = vacc0123c2; 44 | float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1; 45 | float32x4_t vacc4567c3 = vacc0123c3; 46 | float32x4_t vacc89ABc3 = vacc0123c3; 47 | float32x4_t vaccCDEFc3 = vacc0123c3; 48 | if (nnz != 0) { 49 | do { 50 | const intptr_t diff = *dmap++; 51 | const float32x4_t va0123 = vld1q_f32(a); 52 | const float32x4_t va4567 = vld1q_f32(a + 4); 53 | const float32x4_t va89AB = vld1q_f32(a + 8); 54 | const float32x4_t vaCDEF = vld1q_f32(a + 12); 55 | __builtin_prefetch(a + 16); 56 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 57 | const float32x4_t vb = vld1q_f32(w); w += 4; 58 | 59 | vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0); 60 | vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0); 61 | vacc89ABc0 = vfmaq_laneq_f32(vacc89ABc0, va89AB, vb, 0); 62 | vaccCDEFc0 = vfmaq_laneq_f32(vaccCDEFc0, vaCDEF, vb, 0); 63 | vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1); 64 | vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1); 65 | vacc89ABc1 = vfmaq_laneq_f32(vacc89ABc1, va89AB, vb, 1); 66 | vaccCDEFc1 = vfmaq_laneq_f32(vaccCDEFc1, vaCDEF, vb, 1); 67 | vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2); 68 | vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2); 69 | vacc89ABc2 = vfmaq_laneq_f32(vacc89ABc2, va89AB, vb, 2); 70 | vaccCDEFc2 = vfmaq_laneq_f32(vaccCDEFc2, vaCDEF, vb, 2); 71 | vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3); 72 | vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3); 73 | vacc89ABc3 = vfmaq_laneq_f32(vacc89ABc3, va89AB, vb, 3); 74 | vaccCDEFc3 = vfmaq_laneq_f32(vaccCDEFc3, vaCDEF, vb, 3); 75 | } while (--nnz != 0); 76 | } 77 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 78 | float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax); 79 | float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax); 80 | float32x4_t voutCDEFc0 = vminq_f32(vaccCDEFc0, vmax); 81 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 82 | float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax); 83 | float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax); 84 | float32x4_t voutCDEFc1 = vminq_f32(vaccCDEFc1, vmax); 85 | float32x4_t 
vout0123c2 = vminq_f32(vacc0123c2, vmax); 86 | float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax); 87 | float32x4_t vout89ABc2 = vminq_f32(vacc89ABc2, vmax); 88 | float32x4_t voutCDEFc2 = vminq_f32(vaccCDEFc2, vmax); 89 | float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax); 90 | float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax); 91 | float32x4_t vout89ABc3 = vminq_f32(vacc89ABc3, vmax); 92 | float32x4_t voutCDEFc3 = vminq_f32(vaccCDEFc3, vmax); 93 | 94 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 95 | vout4567c0 = vmaxq_f32(vout4567c0, vmin); 96 | vout89ABc0 = vmaxq_f32(vout89ABc0, vmin); 97 | voutCDEFc0 = vmaxq_f32(voutCDEFc0, vmin); 98 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 99 | vout4567c1 = vmaxq_f32(vout4567c1, vmin); 100 | vout89ABc1 = vmaxq_f32(vout89ABc1, vmin); 101 | voutCDEFc1 = vmaxq_f32(voutCDEFc1, vmin); 102 | vout0123c2 = vmaxq_f32(vout0123c2, vmin); 103 | vout4567c2 = vmaxq_f32(vout4567c2, vmin); 104 | vout89ABc2 = vmaxq_f32(vout89ABc2, vmin); 105 | voutCDEFc2 = vmaxq_f32(voutCDEFc2, vmin); 106 | vout0123c3 = vmaxq_f32(vout0123c3, vmin); 107 | vout4567c3 = vmaxq_f32(vout4567c3, vmin); 108 | vout89ABc3 = vmaxq_f32(vout89ABc3, vmin); 109 | voutCDEFc3 = vmaxq_f32(voutCDEFc3, vmin); 110 | 111 | vst1q_f32(c + 0 * m + 0, vout0123c0); 112 | vst1q_f32(c + 0 * m + 4, vout4567c0); 113 | vst1q_f32(c + 0 * m + 8, vout89ABc0); 114 | vst1q_f32(c + 0 * m + 12, voutCDEFc0); 115 | vst1q_f32(c + 1 * m + 0, vout0123c1); 116 | vst1q_f32(c + 1 * m + 4, vout4567c1); 117 | vst1q_f32(c + 1 * m + 8, vout89ABc1); 118 | vst1q_f32(c + 1 * m + 12, voutCDEFc1); 119 | vst1q_f32(c + 2 * m + 0, vout0123c2); 120 | vst1q_f32(c + 2 * m + 4, vout4567c2); 121 | vst1q_f32(c + 2 * m + 8, vout89ABc2); 122 | vst1q_f32(c + 2 * m + 12, voutCDEFc2); 123 | vst1q_f32(c + 3 * m + 0, vout0123c3); 124 | vst1q_f32(c + 3 * m + 4, vout4567c3); 125 | vst1q_f32(c + 3 * m + 8, vout89ABc3); 126 | vst1q_f32(c + 3 * m + 12, voutCDEFc3); 127 | c += 4 * m; 128 | j -= 4; 129 | } 130 | 131 | // clean up loop, fall back to nr=1 132 | if (j != 0) { 133 | do { 134 | uint32_t nnz = *nnzmap++; 135 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 136 | float32x4_t vacc4567 = vacc0123; 137 | float32x4_t vacc89AB = vacc0123; 138 | float32x4_t vaccCDEF = vacc0123; 139 | if (nnz != 0) { 140 | do { 141 | const intptr_t diff = *dmap++; 142 | const float32x4_t va0123 = vld1q_f32(a); 143 | const float32x4_t va4567 = vld1q_f32(a + 4); 144 | const float32x4_t va89AB = vld1q_f32(a + 8); 145 | const float32x4_t vaCDEF = vld1q_f32(a + 12); 146 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 147 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 148 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 149 | vacc4567 = vfmaq_f32(vacc4567, va4567, vb); 150 | vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb); 151 | vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb); 152 | } while (--nnz != 0); 153 | } 154 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 155 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 156 | float32x4_t vout89AB = vminq_f32(vacc89AB, vmax); 157 | float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax); 158 | 159 | vout0123 = vmaxq_f32(vout0123, vmin); 160 | vout4567 = vmaxq_f32(vout4567, vmin); 161 | vout89AB = vmaxq_f32(vout89AB, vmin); 162 | voutCDEF = vmaxq_f32(voutCDEF, vmin); 163 | 164 | vst1q_f32(c + 0, vout0123); 165 | vst1q_f32(c + 4, vout4567); 166 | vst1q_f32(c + 8, vout89AB); 167 | vst1q_f32(c + 12, voutCDEF); 168 | c += m; 169 | j -= 1; 170 | } while (j != 0); 171 | } 172 | c -= m * n; 173 | c += 16; 174 | a += 16; 175 | i 
-= 16; 176 | } 177 | if (i != 0) { 178 | if (i & 8) { 179 | const float*restrict w = weights; 180 | const int32_t* dmap = widx_dmap; 181 | const uint32_t* nnzmap = nidx_nnzmap; 182 | size_t j = n; 183 | while (j >= 4) { 184 | uint32_t nnz = *nnzmap++; 185 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 186 | float32x4_t vacc4567c0 = vacc0123c0; 187 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 188 | float32x4_t vacc4567c1 = vacc0123c1; 189 | float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1; 190 | float32x4_t vacc4567c2 = vacc0123c2; 191 | float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1; 192 | float32x4_t vacc4567c3 = vacc0123c3; 193 | if (nnz != 0) { 194 | do { 195 | const intptr_t diff = *dmap++; 196 | const float32x4_t va0123 = vld1q_f32(a); 197 | const float32x4_t va4567 = vld1q_f32(a + 4); 198 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 199 | const float32x4_t vb = vld1q_f32(w); w += 4; 200 | 201 | vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0); 202 | vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0); 203 | vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1); 204 | vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1); 205 | vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2); 206 | vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2); 207 | vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3); 208 | vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3); 209 | } while (--nnz != 0); 210 | } 211 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 212 | float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax); 213 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 214 | float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax); 215 | float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax); 216 | float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax); 217 | float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax); 218 | float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax); 219 | 220 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 221 | vout4567c0 = vmaxq_f32(vout4567c0, vmin); 222 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 223 | vout4567c1 = vmaxq_f32(vout4567c1, vmin); 224 | vout0123c2 = vmaxq_f32(vout0123c2, vmin); 225 | vout4567c2 = vmaxq_f32(vout4567c2, vmin); 226 | vout0123c3 = vmaxq_f32(vout0123c3, vmin); 227 | vout4567c3 = vmaxq_f32(vout4567c3, vmin); 228 | 229 | vst1q_f32(c + 0 * m + 0, vout0123c0); 230 | vst1q_f32(c + 0 * m + 4, vout4567c0); 231 | vst1q_f32(c + 1 * m + 0, vout0123c1); 232 | vst1q_f32(c + 1 * m + 4, vout4567c1); 233 | vst1q_f32(c + 2 * m + 0, vout0123c2); 234 | vst1q_f32(c + 2 * m + 4, vout4567c2); 235 | vst1q_f32(c + 3 * m + 0, vout0123c3); 236 | vst1q_f32(c + 3 * m + 4, vout4567c3); 237 | c += 4 * m; 238 | j -= 4; 239 | } 240 | 241 | // clean up loop, fall back to nr=1 242 | if (j != 0) { 243 | do { 244 | uint32_t nnz = *nnzmap++; 245 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 246 | float32x4_t vacc4567 = vacc0123; 247 | if (nnz != 0) { 248 | do { 249 | const intptr_t diff = *dmap++; 250 | const float32x4_t va0123 = vld1q_f32(a); 251 | const float32x4_t va4567 = vld1q_f32(a + 4); 252 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 253 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 254 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 255 | vacc4567 = vfmaq_f32(vacc4567, va4567, vb); 256 | } while (--nnz != 0); 257 | } 258 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 259 | float32x4_t vout4567 = vminq_f32(vacc4567, vmax); 260 | 261 | vout0123 = vmaxq_f32(vout0123, vmin); 262 | 
vout4567 = vmaxq_f32(vout4567, vmin); 263 | 264 | vst1q_f32(c + 0, vout0123); 265 | vst1q_f32(c + 4, vout4567); 266 | c += m; 267 | j -= 1; 268 | } while (j != 0); 269 | } 270 | c -= m * n; 271 | c += 8; 272 | a += 8; 273 | } 274 | if (i & 4) { 275 | const float*restrict w = weights; 276 | const int32_t* dmap = widx_dmap; 277 | const uint32_t* nnzmap = nidx_nnzmap; 278 | size_t j = n; 279 | while (j >= 4) { 280 | uint32_t nnz = *nnzmap++; 281 | float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1; 282 | float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1; 283 | float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1; 284 | float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1; 285 | if (nnz != 0) { 286 | do { 287 | const intptr_t diff = *dmap++; 288 | const float32x4_t va0123 = vld1q_f32(a); 289 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 290 | const float32x4_t vb = vld1q_f32(w); w += 4; 291 | 292 | vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0); 293 | vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1); 294 | vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2); 295 | vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3); 296 | } while (--nnz != 0); 297 | } 298 | float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax); 299 | float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax); 300 | float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax); 301 | float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax); 302 | 303 | vout0123c0 = vmaxq_f32(vout0123c0, vmin); 304 | vout0123c1 = vmaxq_f32(vout0123c1, vmin); 305 | vout0123c2 = vmaxq_f32(vout0123c2, vmin); 306 | vout0123c3 = vmaxq_f32(vout0123c3, vmin); 307 | 308 | vst1q_f32(c + 0 * m + 0, vout0123c0); 309 | vst1q_f32(c + 1 * m + 0, vout0123c1); 310 | vst1q_f32(c + 2 * m + 0, vout0123c2); 311 | vst1q_f32(c + 3 * m + 0, vout0123c3); 312 | c += 4 * m; 313 | j -= 4; 314 | } 315 | 316 | // clean up loop, fall back to nr=1 317 | if (j != 0) { 318 | do { 319 | uint32_t nnz = *nnzmap++; 320 | float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1; 321 | if (nnz != 0) { 322 | do { 323 | const intptr_t diff = *dmap++; 324 | const float32x4_t va0123 = vld1q_f32(a); 325 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 326 | const float32x4_t vb = vld1q_dup_f32(w); w += 1; 327 | vacc0123 = vfmaq_f32(vacc0123, va0123, vb); 328 | } while (--nnz != 0); 329 | } 330 | float32x4_t vout0123 = vminq_f32(vacc0123, vmax); 331 | 332 | vout0123 = vmaxq_f32(vout0123, vmin); 333 | 334 | vst1q_f32(c + 0, vout0123); 335 | c += m; 336 | j -= 1; 337 | } while (j != 0); 338 | } 339 | c -= m * n; 340 | c += 4; 341 | a += 4; 342 | } 343 | if (i & 2) { 344 | const float*restrict w = weights; 345 | const int32_t* dmap = widx_dmap; 346 | const uint32_t* nnzmap = nidx_nnzmap; 347 | size_t j = n; 348 | while (j >= 4) { 349 | uint32_t nnz = *nnzmap++; 350 | float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1; 351 | float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1; 352 | float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1; 353 | float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1; 354 | if (nnz != 0) { 355 | do { 356 | const intptr_t diff = *dmap++; 357 | const float32x2_t va01 = vld1_f32(a); 358 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 359 | const float32x4_t vb = vld1q_f32(w); w += 4; 360 | 361 | vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0); 362 | vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1); 363 | vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2); 364 | vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3); 365 | } while (--nnz != 0); 366 | } 367 | 
float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax)); 368 | float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax)); 369 | float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax)); 370 | float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax)); 371 | 372 | vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin)); 373 | vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin)); 374 | vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin)); 375 | vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin)); 376 | 377 | vst1_f32(c + 0 * m + 0, vout01c0); 378 | vst1_f32(c + 1 * m + 0, vout01c1); 379 | vst1_f32(c + 2 * m + 0, vout01c2); 380 | vst1_f32(c + 3 * m + 0, vout01c3); 381 | c += 4 * m; 382 | j -= 4; 383 | } 384 | 385 | // clean up loop, fall back to nr=1 386 | if (j != 0) { 387 | do { 388 | uint32_t nnz = *nnzmap++; 389 | float32x2_t vacc01 = vld1_dup_f32(w); w += 1; 390 | if (nnz != 0) { 391 | do { 392 | const intptr_t diff = *dmap++; 393 | const float32x2_t va01 = vld1_f32(a); 394 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 395 | const float32x2_t vb = vld1_dup_f32(w); w += 1; 396 | vacc01 = vfma_f32(vacc01, va01, vb); 397 | } while (--nnz != 0); 398 | } 399 | float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax)); 400 | vout01 = vmax_f32(vout01, vget_low_f32(vmin)); 401 | 402 | vst1_f32(c, vout01); 403 | c += m; 404 | j -= 1; 405 | } while (j != 0); 406 | } 407 | c -= m * n; 408 | c += 2; 409 | a += 2; 410 | } 411 | if (i & 1) { 412 | const float*restrict w = weights; 413 | const int32_t* dmap = widx_dmap; 414 | const uint32_t* nnzmap = nidx_nnzmap; 415 | size_t j = n; 416 | while (j >= 4) { 417 | uint32_t nnz = *nnzmap++; 418 | float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1; 419 | float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1; 420 | float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1; 421 | float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1; 422 | if (nnz != 0) { 423 | do { 424 | const intptr_t diff = *dmap++; 425 | const float32x2_t va0 = vld1_dup_f32(a); 426 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 427 | const float32x4_t vb = vld1q_f32(w); w += 4; 428 | 429 | vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0); 430 | vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1); 431 | vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2); 432 | vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3); 433 | } while (--nnz != 0); 434 | } 435 | float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax)); 436 | float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax)); 437 | float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax)); 438 | float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax)); 439 | 440 | vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin)); 441 | vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin)); 442 | vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin)); 443 | vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin)); 444 | 445 | vst1_lane_f32(c + 0 * m + 0, vout0c0, 0); 446 | vst1_lane_f32(c + 1 * m + 0, vout0c1, 0); 447 | vst1_lane_f32(c + 2 * m + 0, vout0c2, 0); 448 | vst1_lane_f32(c + 3 * m + 0, vout0c3, 0); 449 | c += 4 * m; 450 | j -= 4; 451 | } 452 | 453 | // clean up loop, fall back to nr=1 454 | if (j != 0) { 455 | do { 456 | uint32_t nnz = *nnzmap++; 457 | float32x2_t vacc0 = vld1_dup_f32(w); w += 1; 458 | if (nnz != 0) { 459 | do { 460 | const intptr_t diff = *dmap++; 461 | const float32x2_t va0 = vld1_dup_f32(a); 462 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 463 | const float32x2_t vb = vld1_dup_f32(w); w += 1; 464 | vacc0 = vfma_f32(vacc0, va0, vb); 
465 | } while (--nnz != 0); 466 | } 467 | float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax)); 468 | vout0 = vmax_f32(vout0, vget_low_f32(vmin)); 469 | 470 | vst1_lane_f32(c, vout0, 1); 471 | c += m; 472 | j -= 1; 473 | } while (j != 0); 474 | } 475 | c -= m * n; 476 | c += 1; 477 | a += 1; 478 | } 479 | } 480 | } 481 | -------------------------------------------------------------------------------- /spmm-8x1-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | 8 | void f32_spmm_ukernel_8x1__scalar( 9 | uint32_t m, 10 | uint32_t n, 11 | const float*restrict a, 12 | const float*restrict weights, 13 | const int32_t*restrict widx_dmap, 14 | const uint32_t*restrict nidx_nnzmap, 15 | float*restrict c, 16 | const union f32_output_params params[restrict static 1]) 17 | { 18 | assert(m != 0); 19 | 20 | const float vmin = params->scalar.min; 21 | const float vmax = params->scalar.max; 22 | size_t i = m; 23 | while (i >= 8) { 24 | const float*restrict w = weights; 25 | const int32_t* dmap = widx_dmap; 26 | const uint32_t* nnzmap = nidx_nnzmap; 27 | size_t j = n; 28 | do { 29 | uint32_t nnz = *nnzmap++; 30 | float vacc0 = *w++; 31 | float vacc1 = vacc0; 32 | float vacc2 = vacc0; 33 | float vacc3 = vacc0; 34 | float vacc4 = vacc0; 35 | float vacc5 = vacc0; 36 | float vacc6 = vacc0; 37 | float vacc7 = vacc0; 38 | if (nnz != 0) { 39 | do { 40 | const intptr_t diff = *dmap++; 41 | const float va0 = a[0]; 42 | const float va1 = a[1]; 43 | const float va2 = a[2]; 44 | const float va3 = a[3]; 45 | const float va4 = a[4]; 46 | const float va5 = a[5]; 47 | const float va6 = a[6]; 48 | const float va7 = a[7]; 49 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 50 | const float vb = *w++; 51 | vacc0 += va0 * vb; 52 | vacc1 += va1 * vb; 53 | vacc2 += va2 * vb; 54 | vacc3 += va3 * vb; 55 | vacc4 += va4 * vb; 56 | vacc5 += va5 * vb; 57 | vacc6 += va6 * vb; 58 | vacc7 += va7 * vb; 59 | } while (--nnz != 0); 60 | } 61 | float vout0 = math_min_f32(vacc0, vmax); 62 | float vout1 = math_min_f32(vacc1, vmax); 63 | float vout2 = math_min_f32(vacc2, vmax); 64 | float vout3 = math_min_f32(vacc3, vmax); 65 | float vout4 = math_min_f32(vacc4, vmax); 66 | float vout5 = math_min_f32(vacc5, vmax); 67 | float vout6 = math_min_f32(vacc6, vmax); 68 | float vout7 = math_min_f32(vacc7, vmax); 69 | vout0 = math_max_f32(vout0, vmin); 70 | vout1 = math_max_f32(vout1, vmin); 71 | vout2 = math_max_f32(vout2, vmin); 72 | vout3 = math_max_f32(vout3, vmin); 73 | vout4 = math_max_f32(vout4, vmin); 74 | vout5 = math_max_f32(vout5, vmin); 75 | vout6 = math_max_f32(vout6, vmin); 76 | vout7 = math_max_f32(vout7, vmin); 77 | c[0] = vout0; 78 | c[1] = vout1; 79 | c[2] = vout2; 80 | c[3] = vout3; 81 | c[4] = vout4; 82 | c[5] = vout5; 83 | c[6] = vout6; 84 | c[7] = vout7; 85 | c += m; 86 | } while (--j != 0); 87 | c -= m * n; 88 | c += 8; 89 | a += 8; 90 | i -= 8; 91 | } 92 | if (i != 0) { 93 | if (i & 4) { 94 | const float*restrict w = weights; 95 | const int32_t* dmap = widx_dmap; 96 | const uint32_t* nnzmap = nidx_nnzmap; 97 | size_t j = n; 98 | do { 99 | uint32_t nnz = *nnzmap++; 100 | float vacc0 = *w++; 101 | float vacc1 = vacc0; 102 | float vacc2 = vacc0; 103 | float vacc3 = vacc0; 104 | if (nnz != 0) { 105 | do { 106 | const intptr_t diff = *dmap++; 107 | const float 
va0 = a[0]; 108 | const float va1 = a[1]; 109 | const float va2 = a[2]; 110 | const float va3 = a[3]; 111 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 112 | const float vb = *w++; 113 | vacc0 += va0 * vb; 114 | vacc1 += va1 * vb; 115 | vacc2 += va2 * vb; 116 | vacc3 += va3 * vb; 117 | } while (--nnz != 0); 118 | } 119 | float vout0 = math_min_f32(vacc0, vmax); 120 | float vout1 = math_min_f32(vacc1, vmax); 121 | float vout2 = math_min_f32(vacc2, vmax); 122 | float vout3 = math_min_f32(vacc3, vmax); 123 | vout0 = math_max_f32(vout0, vmin); 124 | vout1 = math_max_f32(vout1, vmin); 125 | vout2 = math_max_f32(vout2, vmin); 126 | vout3 = math_max_f32(vout3, vmin); 127 | c[0] = vout0; 128 | c[1] = vout1; 129 | c[2] = vout2; 130 | c[3] = vout3; 131 | c += m; 132 | } while (--j != 0); 133 | c -= m * n; 134 | c += 4; 135 | a += 4; 136 | } 137 | if (i & 2) { 138 | const float*restrict w = weights; 139 | const int32_t* dmap = widx_dmap; 140 | const uint32_t* nnzmap = nidx_nnzmap; 141 | size_t j = n; 142 | do { 143 | uint32_t nnz = *nnzmap++; 144 | float vacc0 = *w++; 145 | float vacc1 = vacc0; 146 | if (nnz != 0) { 147 | do { 148 | const intptr_t diff = *dmap++; 149 | const float va0 = a[0]; 150 | const float va1 = a[1]; 151 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 152 | const float vb = *w++; 153 | vacc0 += va0 * vb; 154 | vacc1 += va1 * vb; 155 | } while (--nnz != 0); 156 | } 157 | float vout0 = math_min_f32(vacc0, vmax); 158 | float vout1 = math_min_f32(vacc1, vmax); 159 | vout0 = math_max_f32(vout0, vmin); 160 | vout1 = math_max_f32(vout1, vmin); 161 | c[0] = vout0; 162 | c[1] = vout1; 163 | c += m; 164 | } while (--j != 0); 165 | c -= m * n; 166 | c += 2; 167 | a += 2; 168 | } 169 | if (i & 1) { 170 | const float*restrict w = weights; 171 | const int32_t* dmap = widx_dmap; 172 | const uint32_t* nnzmap = nidx_nnzmap; 173 | size_t j = n; 174 | do { 175 | uint32_t nnz = *nnzmap++; 176 | float vacc0 = *w++; 177 | if (nnz != 0) { 178 | do { 179 | const intptr_t diff = *dmap++; 180 | const float va0 = a[0]; 181 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 182 | const float vb = *w++; 183 | vacc0 += va0 * vb; 184 | } while (--nnz != 0); 185 | } 186 | float vout0 = math_min_f32(vacc0, vmax); 187 | vout0 = math_max_f32(vout0, vmin); 188 | c[0] = vout0; 189 | c += m; 190 | } while (--j != 0); 191 | c -= m * n; 192 | c += 1; 193 | a += 1; 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /spmm-8x2-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
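// Note on the sparse encoding shared by the spmm kernels in this directory:
// `nidx_nnzmap` holds one entry per group of output channels processed
// together (a single channel for the Nx1 kernels, a pair for this 8x2 kernel,
// four for the Nx4 kernels, with single-channel entries for any leftover),
// giving the number of nonzero weight groups for that group of outputs.
// `weights` is a packed stream that, per group, carries the biases first and
// then the nonzero weight values (pairs here, read as vb0/vb1 below).
// `widx_dmap` holds one signed byte offset per nonzero group, added to the
// input pointer right after the corresponding input values are loaded so that
// it lands on the values needed next. The exact construction of the delta map
// is left to the packing code; the kernels only assume that following the
// deltas visits the right input value for every nonzero and leaves the pointer
// where the trailing `a += 8` can step to the next block of pixels. Output is
// written with the `m` pixels of one output channel contiguous, hence the
// `c += m` / `c -= m * n` pointer bookkeeping.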
5 | 6 | #include 7 | 8 | void f32_spmm_ukernel_8x2__scalar( 9 | uint32_t m, 10 | uint32_t n, 11 | const float*restrict a, 12 | const float*restrict weights, 13 | const int32_t*restrict widx_dmap, 14 | const uint32_t*restrict nidx_nnzmap, 15 | float*restrict c, 16 | const union f32_output_params params[restrict static 1]) 17 | { 18 | assert(m != 0); 19 | 20 | const float vmin = params->scalar.min; 21 | const float vmax = params->scalar.max; 22 | size_t i = m; 23 | while (i >= 8) { 24 | const float*restrict w = weights; 25 | const int32_t* dmap = widx_dmap; 26 | const uint32_t* nnzmap = nidx_nnzmap; 27 | size_t j = n; 28 | while (j >= 2) { 29 | uint32_t nnz = *nnzmap++; 30 | float vacc0x0 = *w++; 31 | float vacc1x0 = vacc0x0; 32 | float vacc2x0 = vacc0x0; 33 | float vacc3x0 = vacc0x0; 34 | float vacc4x0 = vacc0x0; 35 | float vacc5x0 = vacc0x0; 36 | float vacc6x0 = vacc0x0; 37 | float vacc7x0 = vacc0x0; 38 | float vacc0x1 = *w++; 39 | float vacc1x1 = vacc0x1; 40 | float vacc2x1 = vacc0x1; 41 | float vacc3x1 = vacc0x1; 42 | float vacc4x1 = vacc0x1; 43 | float vacc5x1 = vacc0x1; 44 | float vacc6x1 = vacc0x1; 45 | float vacc7x1 = vacc0x1; 46 | if (nnz != 0) { 47 | do { 48 | const intptr_t diff = *dmap++; 49 | const float va0 = a[0]; 50 | const float va1 = a[1]; 51 | const float va2 = a[2]; 52 | const float va3 = a[3]; 53 | const float va4 = a[4]; 54 | const float va5 = a[5]; 55 | const float va6 = a[6]; 56 | const float va7 = a[7]; 57 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 58 | const float vb0 = *w++; 59 | const float vb1 = *w++; 60 | vacc0x0 += va0 * vb0; 61 | vacc1x0 += va1 * vb0; 62 | vacc2x0 += va2 * vb0; 63 | vacc3x0 += va3 * vb0; 64 | vacc4x0 += va4 * vb0; 65 | vacc5x0 += va5 * vb0; 66 | vacc6x0 += va6 * vb0; 67 | vacc7x0 += va7 * vb0; 68 | vacc0x1 += va0 * vb1; 69 | vacc1x1 += va1 * vb1; 70 | vacc2x1 += va2 * vb1; 71 | vacc3x1 += va3 * vb1; 72 | vacc4x1 += va4 * vb1; 73 | vacc5x1 += va5 * vb1; 74 | vacc6x1 += va6 * vb1; 75 | vacc7x1 += va7 * vb1; 76 | } while (--nnz != 0); 77 | } 78 | float vout0x0 = math_min_f32(vacc0x0, vmax); 79 | float vout1x0 = math_min_f32(vacc1x0, vmax); 80 | float vout2x0 = math_min_f32(vacc2x0, vmax); 81 | float vout3x0 = math_min_f32(vacc3x0, vmax); 82 | float vout4x0 = math_min_f32(vacc4x0, vmax); 83 | float vout5x0 = math_min_f32(vacc5x0, vmax); 84 | float vout6x0 = math_min_f32(vacc6x0, vmax); 85 | float vout7x0 = math_min_f32(vacc7x0, vmax); 86 | float vout0x1 = math_min_f32(vacc0x1, vmax); 87 | float vout1x1 = math_min_f32(vacc1x1, vmax); 88 | float vout2x1 = math_min_f32(vacc2x1, vmax); 89 | float vout3x1 = math_min_f32(vacc3x1, vmax); 90 | float vout4x1 = math_min_f32(vacc4x1, vmax); 91 | float vout5x1 = math_min_f32(vacc5x1, vmax); 92 | float vout6x1 = math_min_f32(vacc6x1, vmax); 93 | float vout7x1 = math_min_f32(vacc7x1, vmax); 94 | vout0x0 = math_max_f32(vout0x0, vmin); 95 | vout1x0 = math_max_f32(vout1x0, vmin); 96 | vout2x0 = math_max_f32(vout2x0, vmin); 97 | vout3x0 = math_max_f32(vout3x0, vmin); 98 | vout4x0 = math_max_f32(vout4x0, vmin); 99 | vout5x0 = math_max_f32(vout5x0, vmin); 100 | vout6x0 = math_max_f32(vout6x0, vmin); 101 | vout7x0 = math_max_f32(vout7x0, vmin); 102 | vout0x1 = math_max_f32(vout0x1, vmin); 103 | vout1x1 = math_max_f32(vout1x1, vmin); 104 | vout2x1 = math_max_f32(vout2x1, vmin); 105 | vout3x1 = math_max_f32(vout3x1, vmin); 106 | vout4x1 = math_max_f32(vout4x1, vmin); 107 | vout5x1 = math_max_f32(vout5x1, vmin); 108 | vout6x1 = math_max_f32(vout6x1, vmin); 109 | vout7x1 = math_max_f32(vout7x1, 
vmin); 110 | c[0 * m + 0] = vout0x0; 111 | c[0 * m + 1] = vout1x0; 112 | c[0 * m + 2] = vout2x0; 113 | c[0 * m + 3] = vout3x0; 114 | c[0 * m + 4] = vout4x0; 115 | c[0 * m + 5] = vout5x0; 116 | c[0 * m + 6] = vout6x0; 117 | c[0 * m + 7] = vout7x0; 118 | c[1 * m + 0] = vout0x1; 119 | c[1 * m + 1] = vout1x1; 120 | c[1 * m + 2] = vout2x1; 121 | c[1 * m + 3] = vout3x1; 122 | c[1 * m + 4] = vout4x1; 123 | c[1 * m + 5] = vout5x1; 124 | c[1 * m + 6] = vout6x1; 125 | c[1 * m + 7] = vout7x1; 126 | c += 2 * m; 127 | j -= 2; 128 | } 129 | if (j != 0) { 130 | do { 131 | uint32_t nnz = *nnzmap++; 132 | float vacc0 = *w++; 133 | float vacc1 = vacc0; 134 | float vacc2 = vacc0; 135 | float vacc3 = vacc0; 136 | float vacc4 = vacc0; 137 | float vacc5 = vacc0; 138 | float vacc6 = vacc0; 139 | float vacc7 = vacc0; 140 | if (nnz != 0) { 141 | do { 142 | const intptr_t diff = *dmap++; 143 | const float va0 = a[0]; 144 | const float va1 = a[1]; 145 | const float va2 = a[2]; 146 | const float va3 = a[3]; 147 | const float va4 = a[4]; 148 | const float va5 = a[5]; 149 | const float va6 = a[6]; 150 | const float va7 = a[7]; 151 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 152 | const float vb = *w++; 153 | vacc0 += va0 * vb; 154 | vacc1 += va1 * vb; 155 | vacc2 += va2 * vb; 156 | vacc3 += va3 * vb; 157 | vacc4 += va4 * vb; 158 | vacc5 += va5 * vb; 159 | vacc6 += va6 * vb; 160 | vacc7 += va7 * vb; 161 | } while (--nnz != 0); 162 | } 163 | float vout0 = math_min_f32(vacc0, vmax); 164 | float vout1 = math_min_f32(vacc1, vmax); 165 | float vout2 = math_min_f32(vacc2, vmax); 166 | float vout3 = math_min_f32(vacc3, vmax); 167 | float vout4 = math_min_f32(vacc4, vmax); 168 | float vout5 = math_min_f32(vacc5, vmax); 169 | float vout6 = math_min_f32(vacc6, vmax); 170 | float vout7 = math_min_f32(vacc7, vmax); 171 | vout0 = math_max_f32(vout0, vmin); 172 | vout1 = math_max_f32(vout1, vmin); 173 | vout2 = math_max_f32(vout2, vmin); 174 | vout3 = math_max_f32(vout3, vmin); 175 | vout4 = math_max_f32(vout4, vmin); 176 | vout5 = math_max_f32(vout5, vmin); 177 | vout6 = math_max_f32(vout6, vmin); 178 | vout7 = math_max_f32(vout7, vmin); 179 | c[0] = vout0; 180 | c[1] = vout1; 181 | c[2] = vout2; 182 | c[3] = vout3; 183 | c[4] = vout4; 184 | c[5] = vout5; 185 | c[6] = vout6; 186 | c[7] = vout7; 187 | c += m; 188 | j -= 1; 189 | } while (j != 0); 190 | } 191 | c -= m * n; 192 | c += 8; 193 | a += 8; 194 | i -= 8; 195 | } 196 | if (i != 0) { 197 | if (i & 4) { 198 | const float*restrict w = weights; 199 | const int32_t* dmap = widx_dmap; 200 | const uint32_t* nnzmap = nidx_nnzmap; 201 | size_t j = n; 202 | while (j >= 2) { 203 | uint32_t nnz = *nnzmap++; 204 | float vacc0x0 = *w++; 205 | float vacc1x0 = vacc0x0; 206 | float vacc2x0 = vacc0x0; 207 | float vacc3x0 = vacc0x0; 208 | float vacc0x1 = *w++; 209 | float vacc1x1 = vacc0x1; 210 | float vacc2x1 = vacc0x1; 211 | float vacc3x1 = vacc0x1; 212 | if (nnz != 0) { 213 | do { 214 | const intptr_t diff = *dmap++; 215 | const float va0 = a[0]; 216 | const float va1 = a[1]; 217 | const float va2 = a[2]; 218 | const float va3 = a[3]; 219 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 220 | const float vb0 = *w++; 221 | const float vb1 = *w++; 222 | vacc0x0 += va0 * vb0; 223 | vacc1x0 += va1 * vb0; 224 | vacc2x0 += va2 * vb0; 225 | vacc3x0 += va3 * vb0; 226 | vacc0x1 += va0 * vb1; 227 | vacc1x1 += va1 * vb1; 228 | vacc2x1 += va2 * vb1; 229 | vacc3x1 += va3 * vb1; 230 | } while (--nnz != 0); 231 | } 232 | float vout0x0 = math_min_f32(vacc0x0, vmax); 233 
| float vout1x0 = math_min_f32(vacc1x0, vmax); 234 | float vout2x0 = math_min_f32(vacc2x0, vmax); 235 | float vout3x0 = math_min_f32(vacc3x0, vmax); 236 | float vout0x1 = math_min_f32(vacc0x1, vmax); 237 | float vout1x1 = math_min_f32(vacc1x1, vmax); 238 | float vout2x1 = math_min_f32(vacc2x1, vmax); 239 | float vout3x1 = math_min_f32(vacc3x1, vmax); 240 | vout0x0 = math_max_f32(vout0x0, vmin); 241 | vout1x0 = math_max_f32(vout1x0, vmin); 242 | vout2x0 = math_max_f32(vout2x0, vmin); 243 | vout3x0 = math_max_f32(vout3x0, vmin); 244 | vout0x1 = math_max_f32(vout0x1, vmin); 245 | vout1x1 = math_max_f32(vout1x1, vmin); 246 | vout2x1 = math_max_f32(vout2x1, vmin); 247 | vout3x1 = math_max_f32(vout3x1, vmin); 248 | c[0 * m + 0] = vout0x0; 249 | c[0 * m + 1] = vout1x0; 250 | c[0 * m + 2] = vout2x0; 251 | c[0 * m + 3] = vout3x0; 252 | c[1 * m + 0] = vout0x1; 253 | c[1 * m + 1] = vout1x1; 254 | c[1 * m + 2] = vout2x1; 255 | c[1 * m + 3] = vout3x1; 256 | c += 2 * m; 257 | j -= 2; 258 | } 259 | if (j != 0) { 260 | do { 261 | uint32_t nnz = *nnzmap++; 262 | float vacc0 = *w++; 263 | float vacc1 = vacc0; 264 | float vacc2 = vacc0; 265 | float vacc3 = vacc0; 266 | if (nnz != 0) { 267 | do { 268 | const intptr_t diff = *dmap++; 269 | const float va0 = a[0]; 270 | const float va1 = a[1]; 271 | const float va2 = a[2]; 272 | const float va3 = a[3]; 273 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 274 | const float vb = *w++; 275 | vacc0 += va0 * vb; 276 | vacc1 += va1 * vb; 277 | vacc2 += va2 * vb; 278 | vacc3 += va3 * vb; 279 | } while (--nnz != 0); 280 | } 281 | float vout0 = math_min_f32(vacc0, vmax); 282 | float vout1 = math_min_f32(vacc1, vmax); 283 | float vout2 = math_min_f32(vacc2, vmax); 284 | float vout3 = math_min_f32(vacc3, vmax); 285 | vout0 = math_max_f32(vout0, vmin); 286 | vout1 = math_max_f32(vout1, vmin); 287 | vout2 = math_max_f32(vout2, vmin); 288 | vout3 = math_max_f32(vout3, vmin); 289 | c[0] = vout0; 290 | c[1] = vout1; 291 | c[2] = vout2; 292 | c[3] = vout3; 293 | c += m; 294 | j -= 1; 295 | } while (j != 0); 296 | } 297 | c -= m * n; 298 | c += 4; 299 | a += 4; 300 | } 301 | if (i & 2) { 302 | const float*restrict w = weights; 303 | const int32_t* dmap = widx_dmap; 304 | const uint32_t* nnzmap = nidx_nnzmap; 305 | size_t j = n; 306 | while (j >= 2) { 307 | uint32_t nnz = *nnzmap++; 308 | float vacc0x0 = *w++; 309 | float vacc1x0 = vacc0x0; 310 | float vacc0x1 = *w++; 311 | float vacc1x1 = vacc0x1; 312 | if (nnz != 0) { 313 | do { 314 | const intptr_t diff = *dmap++; 315 | const float va0 = a[0]; 316 | const float va1 = a[1]; 317 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 318 | const float vb0 = *w++; 319 | const float vb1 = *w++; 320 | vacc0x0 += va0 * vb0; 321 | vacc1x0 += va1 * vb0; 322 | vacc0x1 += va0 * vb1; 323 | vacc1x1 += va1 * vb1; 324 | } while (--nnz != 0); 325 | } 326 | float vout0x0 = math_min_f32(vacc0x0, vmax); 327 | float vout1x0 = math_min_f32(vacc1x0, vmax); 328 | float vout0x1 = math_min_f32(vacc0x1, vmax); 329 | float vout1x1 = math_min_f32(vacc1x1, vmax); 330 | vout0x0 = math_max_f32(vout0x0, vmin); 331 | vout1x0 = math_max_f32(vout1x0, vmin); 332 | vout0x1 = math_max_f32(vout0x1, vmin); 333 | vout1x1 = math_max_f32(vout1x1, vmin); 334 | c[0 * m + 0] = vout0x0; 335 | c[0 * m + 1] = vout1x0; 336 | c[1 * m + 0] = vout0x1; 337 | c[1 * m + 1] = vout1x1; 338 | c += 2 * m; 339 | j -= 2; 340 | } 341 | if (j != 0) { 342 | do { 343 | uint32_t nnz = *nnzmap++; 344 | float vacc0 = *w++; 345 | float vacc1 = vacc0; 346 | if (nnz != 0) 
{ 347 | do { 348 | const intptr_t diff = *dmap++; 349 | const float va0 = a[0]; 350 | const float va1 = a[1]; 351 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 352 | const float vb = *w++; 353 | vacc0 += va0 * vb; 354 | vacc1 += va1 * vb; 355 | } while (--nnz != 0); 356 | } 357 | float vout0 = math_min_f32(vacc0, vmax); 358 | float vout1 = math_min_f32(vacc1, vmax); 359 | vout0 = math_max_f32(vout0, vmin); 360 | vout1 = math_max_f32(vout1, vmin); 361 | c[0] = vout0; 362 | c[1] = vout1; 363 | c += m; 364 | j -= 1; 365 | } while (j != 0); 366 | } 367 | c -= m * n; 368 | c += 2; 369 | a += 2; 370 | } 371 | if (i & 1) { 372 | const float*restrict w = weights; 373 | const int32_t* dmap = widx_dmap; 374 | const uint32_t* nnzmap = nidx_nnzmap; 375 | size_t j = n; 376 | while (j >= 2) { 377 | uint32_t nnz = *nnzmap++; 378 | float vacc0x0 = *w++; 379 | float vacc0x1 = *w++; 380 | if (nnz != 0) { 381 | do { 382 | const intptr_t diff = *dmap++; 383 | const float va0 = a[0]; 384 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 385 | const float vb0 = *w++; 386 | const float vb1 = *w++; 387 | vacc0x0 += va0 * vb0; 388 | vacc0x1 += va0 * vb1; 389 | } while (--nnz != 0); 390 | } 391 | float vout0x0 = math_min_f32(vacc0x0, vmax); 392 | float vout0x1 = math_min_f32(vacc0x1, vmax); 393 | vout0x0 = math_max_f32(vout0x0, vmin); 394 | vout0x1 = math_max_f32(vout0x1, vmin); 395 | c[0 * m + 0] = vout0x0; 396 | c[1 * m + 0] = vout0x1; 397 | c += 2 * m; 398 | j -= 2; 399 | } 400 | if (j != 0) { 401 | do { 402 | uint32_t nnz = *nnzmap++; 403 | float vacc0 = *w++; 404 | if (nnz != 0) { 405 | do { 406 | const intptr_t diff = *dmap++; 407 | const float va0 = a[0]; 408 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 409 | const float vb = *w++; 410 | vacc0 += va0 * vb; 411 | } while (--nnz != 0); 412 | } 413 | float vout0 = math_min_f32(vacc0, vmax); 414 | vout0 = math_max_f32(vout0, vmin); 415 | c[0] = vout0; 416 | c += m; 417 | j -= 1; 418 | } while (j != 0); 419 | } 420 | c -= m * n; 421 | c += 1; 422 | a += 1; 423 | } 424 | } 425 | } 426 | -------------------------------------------------------------------------------- /spmm-8x4-scalar.c: -------------------------------------------------------------------------------- 1 | // Copyright The Fast Sparse ConvNets Authors. 2 | // 3 | // This source code is licensed under the BSD-style license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | 8 | void f32_spmm_ukernel_8x4__scalar( 9 | uint32_t m, 10 | uint32_t n, 11 | const float*restrict a, 12 | const float*restrict weights, 13 | const int32_t*restrict widx_dmap, 14 | const uint32_t*restrict nidx_nnzmap, 15 | float*restrict c, 16 | const union f32_output_params params[restrict static 1]) 17 | { 18 | assert(m != 0); 19 | 20 | const float vmin = params->scalar.min; 21 | const float vmax = params->scalar.max; 22 | size_t i = m; 23 | while (i >= 8) { 24 | const float*restrict w = weights; 25 | const int32_t* dmap = widx_dmap; 26 | const uint32_t* nnzmap = nidx_nnzmap; 27 | size_t j = n; 28 | while (j >= 4) { 29 | uint32_t nnz = *nnzmap++; 30 | float vacc0x0 = *w++; 31 | float vacc1x0 = vacc0x0; 32 | float vacc2x0 = vacc0x0; 33 | float vacc3x0 = vacc0x0; 34 | float vacc4x0 = vacc0x0; 35 | float vacc5x0 = vacc0x0; 36 | float vacc6x0 = vacc0x0; 37 | float vacc7x0 = vacc0x0; 38 | float vacc0x1 = *w++; 39 | float vacc1x1 = vacc0x1; 40 | float vacc2x1 = vacc0x1; 41 | float vacc3x1 = vacc0x1; 42 | float vacc4x1 = vacc0x1; 43 | float vacc5x1 = vacc0x1; 44 | float vacc6x1 = vacc0x1; 45 | float vacc7x1 = vacc0x1; 46 | float vacc0x2 = *w++; 47 | float vacc1x2 = vacc0x2; 48 | float vacc2x2 = vacc0x2; 49 | float vacc3x2 = vacc0x2; 50 | float vacc4x2 = vacc0x2; 51 | float vacc5x2 = vacc0x2; 52 | float vacc6x2 = vacc0x2; 53 | float vacc7x2 = vacc0x2; 54 | float vacc0x3 = *w++; 55 | float vacc1x3 = vacc0x3; 56 | float vacc2x3 = vacc0x3; 57 | float vacc3x3 = vacc0x3; 58 | float vacc4x3 = vacc0x3; 59 | float vacc5x3 = vacc0x3; 60 | float vacc6x3 = vacc0x3; 61 | float vacc7x3 = vacc0x3; 62 | if (nnz != 0) { 63 | do { 64 | const intptr_t diff = *dmap++; 65 | const float va0 = a[0]; 66 | const float va1 = a[1]; 67 | const float va2 = a[2]; 68 | const float va3 = a[3]; 69 | const float va4 = a[4]; 70 | const float va5 = a[5]; 71 | const float va6 = a[6]; 72 | const float va7 = a[7]; 73 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 74 | const float vb0 = *w++; 75 | const float vb1 = *w++; 76 | const float vb2 = *w++; 77 | const float vb3 = *w++; 78 | vacc0x0 += va0 * vb0; 79 | vacc1x0 += va1 * vb0; 80 | vacc2x0 += va2 * vb0; 81 | vacc3x0 += va3 * vb0; 82 | vacc4x0 += va4 * vb0; 83 | vacc5x0 += va5 * vb0; 84 | vacc6x0 += va6 * vb0; 85 | vacc7x0 += va7 * vb0; 86 | vacc0x1 += va0 * vb1; 87 | vacc1x1 += va1 * vb1; 88 | vacc2x1 += va2 * vb1; 89 | vacc3x1 += va3 * vb1; 90 | vacc4x1 += va4 * vb1; 91 | vacc5x1 += va5 * vb1; 92 | vacc6x1 += va6 * vb1; 93 | vacc7x1 += va7 * vb1; 94 | vacc0x2 += va0 * vb2; 95 | vacc1x2 += va1 * vb2; 96 | vacc2x2 += va2 * vb2; 97 | vacc3x2 += va3 * vb2; 98 | vacc4x2 += va4 * vb2; 99 | vacc5x2 += va5 * vb2; 100 | vacc6x2 += va6 * vb2; 101 | vacc7x2 += va7 * vb2; 102 | vacc0x3 += va0 * vb3; 103 | vacc1x3 += va1 * vb3; 104 | vacc2x3 += va2 * vb3; 105 | vacc3x3 += va3 * vb3; 106 | vacc4x3 += va4 * vb3; 107 | vacc5x3 += va5 * vb3; 108 | vacc6x3 += va6 * vb3; 109 | vacc7x3 += va7 * vb3; 110 | } while (--nnz != 0); 111 | } 112 | float vout0x0 = math_min_f32(vacc0x0, vmax); 113 | float vout1x0 = math_min_f32(vacc1x0, vmax); 114 | float vout2x0 = math_min_f32(vacc2x0, vmax); 115 | float vout3x0 = math_min_f32(vacc3x0, vmax); 116 | float vout4x0 = math_min_f32(vacc4x0, vmax); 117 | float vout5x0 = math_min_f32(vacc5x0, vmax); 118 | float vout6x0 = math_min_f32(vacc6x0, vmax); 119 | float vout7x0 = math_min_f32(vacc7x0, vmax); 120 | float vout0x1 = math_min_f32(vacc0x1, vmax); 121 | float vout1x1 = math_min_f32(vacc1x1, vmax); 122 | float vout2x1 = 
math_min_f32(vacc2x1, vmax); 123 | float vout3x1 = math_min_f32(vacc3x1, vmax); 124 | float vout4x1 = math_min_f32(vacc4x1, vmax); 125 | float vout5x1 = math_min_f32(vacc5x1, vmax); 126 | float vout6x1 = math_min_f32(vacc6x1, vmax); 127 | float vout7x1 = math_min_f32(vacc7x1, vmax); 128 | float vout0x2 = math_min_f32(vacc0x2, vmax); 129 | float vout1x2 = math_min_f32(vacc1x2, vmax); 130 | float vout2x2 = math_min_f32(vacc2x2, vmax); 131 | float vout3x2 = math_min_f32(vacc3x2, vmax); 132 | float vout4x2 = math_min_f32(vacc4x2, vmax); 133 | float vout5x2 = math_min_f32(vacc5x2, vmax); 134 | float vout6x2 = math_min_f32(vacc6x2, vmax); 135 | float vout7x2 = math_min_f32(vacc7x2, vmax); 136 | float vout0x3 = math_min_f32(vacc0x3, vmax); 137 | float vout1x3 = math_min_f32(vacc1x3, vmax); 138 | float vout2x3 = math_min_f32(vacc2x3, vmax); 139 | float vout3x3 = math_min_f32(vacc3x3, vmax); 140 | float vout4x3 = math_min_f32(vacc4x3, vmax); 141 | float vout5x3 = math_min_f32(vacc5x3, vmax); 142 | float vout6x3 = math_min_f32(vacc6x3, vmax); 143 | float vout7x3 = math_min_f32(vacc7x3, vmax); 144 | vout0x0 = math_max_f32(vout0x0, vmin); 145 | vout1x0 = math_max_f32(vout1x0, vmin); 146 | vout2x0 = math_max_f32(vout2x0, vmin); 147 | vout3x0 = math_max_f32(vout3x0, vmin); 148 | vout4x0 = math_max_f32(vout4x0, vmin); 149 | vout5x0 = math_max_f32(vout5x0, vmin); 150 | vout6x0 = math_max_f32(vout6x0, vmin); 151 | vout7x0 = math_max_f32(vout7x0, vmin); 152 | vout0x1 = math_max_f32(vout0x1, vmin); 153 | vout1x1 = math_max_f32(vout1x1, vmin); 154 | vout2x1 = math_max_f32(vout2x1, vmin); 155 | vout3x1 = math_max_f32(vout3x1, vmin); 156 | vout4x1 = math_max_f32(vout4x1, vmin); 157 | vout5x1 = math_max_f32(vout5x1, vmin); 158 | vout6x1 = math_max_f32(vout6x1, vmin); 159 | vout7x1 = math_max_f32(vout7x1, vmin); 160 | vout0x2 = math_max_f32(vout0x2, vmin); 161 | vout1x2 = math_max_f32(vout1x2, vmin); 162 | vout2x2 = math_max_f32(vout2x2, vmin); 163 | vout3x2 = math_max_f32(vout3x2, vmin); 164 | vout4x2 = math_max_f32(vout4x2, vmin); 165 | vout5x2 = math_max_f32(vout5x2, vmin); 166 | vout6x2 = math_max_f32(vout6x2, vmin); 167 | vout7x2 = math_max_f32(vout7x2, vmin); 168 | vout0x3 = math_max_f32(vout0x3, vmin); 169 | vout1x3 = math_max_f32(vout1x3, vmin); 170 | vout2x3 = math_max_f32(vout2x3, vmin); 171 | vout3x3 = math_max_f32(vout3x3, vmin); 172 | vout4x3 = math_max_f32(vout4x3, vmin); 173 | vout5x3 = math_max_f32(vout5x3, vmin); 174 | vout6x3 = math_max_f32(vout6x3, vmin); 175 | vout7x3 = math_max_f32(vout7x3, vmin); 176 | c[0 * m + 0] = vout0x0; 177 | c[0 * m + 1] = vout1x0; 178 | c[0 * m + 2] = vout2x0; 179 | c[0 * m + 3] = vout3x0; 180 | c[0 * m + 4] = vout4x0; 181 | c[0 * m + 5] = vout5x0; 182 | c[0 * m + 6] = vout6x0; 183 | c[0 * m + 7] = vout7x0; 184 | c[1 * m + 0] = vout0x1; 185 | c[1 * m + 1] = vout1x1; 186 | c[1 * m + 2] = vout2x1; 187 | c[1 * m + 3] = vout3x1; 188 | c[1 * m + 4] = vout4x1; 189 | c[1 * m + 5] = vout5x1; 190 | c[1 * m + 6] = vout6x1; 191 | c[1 * m + 7] = vout7x1; 192 | c[2 * m + 0] = vout0x2; 193 | c[2 * m + 1] = vout1x2; 194 | c[2 * m + 2] = vout2x2; 195 | c[2 * m + 3] = vout3x2; 196 | c[2 * m + 4] = vout4x2; 197 | c[2 * m + 5] = vout5x2; 198 | c[2 * m + 6] = vout6x2; 199 | c[2 * m + 7] = vout7x2; 200 | c[3 * m + 0] = vout0x3; 201 | c[3 * m + 1] = vout1x3; 202 | c[3 * m + 2] = vout2x3; 203 | c[3 * m + 3] = vout3x3; 204 | c[3 * m + 4] = vout4x3; 205 | c[3 * m + 5] = vout5x3; 206 | c[3 * m + 6] = vout6x3; 207 | c[3 * m + 7] = vout7x3; 208 | c += 4 * m; 209 | j -= 4; 210 | } 211 | if (j != 
0) { 212 | do { 213 | uint32_t nnz = *nnzmap++; 214 | float vacc0 = *w++; 215 | float vacc1 = vacc0; 216 | float vacc2 = vacc0; 217 | float vacc3 = vacc0; 218 | float vacc4 = vacc0; 219 | float vacc5 = vacc0; 220 | float vacc6 = vacc0; 221 | float vacc7 = vacc0; 222 | if (nnz != 0) { 223 | do { 224 | const intptr_t diff = *dmap++; 225 | const float va0 = a[0]; 226 | const float va1 = a[1]; 227 | const float va2 = a[2]; 228 | const float va3 = a[3]; 229 | const float va4 = a[4]; 230 | const float va5 = a[5]; 231 | const float va6 = a[6]; 232 | const float va7 = a[7]; 233 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 234 | const float vb = *w++; 235 | vacc0 += va0 * vb; 236 | vacc1 += va1 * vb; 237 | vacc2 += va2 * vb; 238 | vacc3 += va3 * vb; 239 | vacc4 += va4 * vb; 240 | vacc5 += va5 * vb; 241 | vacc6 += va6 * vb; 242 | vacc7 += va7 * vb; 243 | } while (--nnz != 0); 244 | } 245 | float vout0 = math_min_f32(vacc0, vmax); 246 | float vout1 = math_min_f32(vacc1, vmax); 247 | float vout2 = math_min_f32(vacc2, vmax); 248 | float vout3 = math_min_f32(vacc3, vmax); 249 | float vout4 = math_min_f32(vacc4, vmax); 250 | float vout5 = math_min_f32(vacc5, vmax); 251 | float vout6 = math_min_f32(vacc6, vmax); 252 | float vout7 = math_min_f32(vacc7, vmax); 253 | vout0 = math_max_f32(vout0, vmin); 254 | vout1 = math_max_f32(vout1, vmin); 255 | vout2 = math_max_f32(vout2, vmin); 256 | vout3 = math_max_f32(vout3, vmin); 257 | vout4 = math_max_f32(vout4, vmin); 258 | vout5 = math_max_f32(vout5, vmin); 259 | vout6 = math_max_f32(vout6, vmin); 260 | vout7 = math_max_f32(vout7, vmin); 261 | c[0] = vout0; 262 | c[1] = vout1; 263 | c[2] = vout2; 264 | c[3] = vout3; 265 | c[4] = vout4; 266 | c[5] = vout5; 267 | c[6] = vout6; 268 | c[7] = vout7; 269 | c += m; 270 | j -= 1; 271 | } while (j != 0); 272 | } 273 | c -= m * n; 274 | c += 8; 275 | a += 8; 276 | i -= 8; 277 | } 278 | if (i != 0) { 279 | if (i & 4) { 280 | const float*restrict w = weights; 281 | const int32_t* dmap = widx_dmap; 282 | const uint32_t* nnzmap = nidx_nnzmap; 283 | size_t j = n; 284 | while (j >= 4) { 285 | uint32_t nnz = *nnzmap++; 286 | float vacc0x0 = *w++; 287 | float vacc1x0 = vacc0x0; 288 | float vacc2x0 = vacc0x0; 289 | float vacc3x0 = vacc0x0; 290 | float vacc0x1 = *w++; 291 | float vacc1x1 = vacc0x1; 292 | float vacc2x1 = vacc0x1; 293 | float vacc3x1 = vacc0x1; 294 | float vacc0x2 = *w++; 295 | float vacc1x2 = vacc0x2; 296 | float vacc2x2 = vacc0x2; 297 | float vacc3x2 = vacc0x2; 298 | float vacc0x3 = *w++; 299 | float vacc1x3 = vacc0x3; 300 | float vacc2x3 = vacc0x3; 301 | float vacc3x3 = vacc0x3; 302 | if (nnz != 0) { 303 | do { 304 | const intptr_t diff = *dmap++; 305 | const float va0 = a[0]; 306 | const float va1 = a[1]; 307 | const float va2 = a[2]; 308 | const float va3 = a[3]; 309 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 310 | const float vb0 = *w++; 311 | const float vb1 = *w++; 312 | const float vb2 = *w++; 313 | const float vb3 = *w++; 314 | vacc0x0 += va0 * vb0; 315 | vacc1x0 += va1 * vb0; 316 | vacc2x0 += va2 * vb0; 317 | vacc3x0 += va3 * vb0; 318 | vacc0x1 += va0 * vb1; 319 | vacc1x1 += va1 * vb1; 320 | vacc2x1 += va2 * vb1; 321 | vacc3x1 += va3 * vb1; 322 | vacc0x2 += va0 * vb2; 323 | vacc1x2 += va1 * vb2; 324 | vacc2x2 += va2 * vb2; 325 | vacc3x2 += va3 * vb2; 326 | vacc0x3 += va0 * vb3; 327 | vacc1x3 += va1 * vb3; 328 | vacc2x3 += va2 * vb3; 329 | vacc3x3 += va3 * vb3; 330 | } while (--nnz != 0); 331 | } 332 | float vout0x0 = math_min_f32(vacc0x0, vmax); 333 | float vout1x0 
= math_min_f32(vacc1x0, vmax); 334 | float vout2x0 = math_min_f32(vacc2x0, vmax); 335 | float vout3x0 = math_min_f32(vacc3x0, vmax); 336 | float vout0x1 = math_min_f32(vacc0x1, vmax); 337 | float vout1x1 = math_min_f32(vacc1x1, vmax); 338 | float vout2x1 = math_min_f32(vacc2x1, vmax); 339 | float vout3x1 = math_min_f32(vacc3x1, vmax); 340 | float vout0x2 = math_min_f32(vacc0x2, vmax); 341 | float vout1x2 = math_min_f32(vacc1x2, vmax); 342 | float vout2x2 = math_min_f32(vacc2x2, vmax); 343 | float vout3x2 = math_min_f32(vacc3x2, vmax); 344 | float vout0x3 = math_min_f32(vacc0x3, vmax); 345 | float vout1x3 = math_min_f32(vacc1x3, vmax); 346 | float vout2x3 = math_min_f32(vacc2x3, vmax); 347 | float vout3x3 = math_min_f32(vacc3x3, vmax); 348 | vout0x0 = math_max_f32(vout0x0, vmin); 349 | vout1x0 = math_max_f32(vout1x0, vmin); 350 | vout2x0 = math_max_f32(vout2x0, vmin); 351 | vout3x0 = math_max_f32(vout3x0, vmin); 352 | vout0x1 = math_max_f32(vout0x1, vmin); 353 | vout1x1 = math_max_f32(vout1x1, vmin); 354 | vout2x1 = math_max_f32(vout2x1, vmin); 355 | vout3x1 = math_max_f32(vout3x1, vmin); 356 | vout0x2 = math_max_f32(vout0x2, vmin); 357 | vout1x2 = math_max_f32(vout1x2, vmin); 358 | vout2x2 = math_max_f32(vout2x2, vmin); 359 | vout3x2 = math_max_f32(vout3x2, vmin); 360 | vout0x3 = math_max_f32(vout0x3, vmin); 361 | vout1x3 = math_max_f32(vout1x3, vmin); 362 | vout2x3 = math_max_f32(vout2x3, vmin); 363 | vout3x3 = math_max_f32(vout3x3, vmin); 364 | c[0 * m + 0] = vout0x0; 365 | c[0 * m + 1] = vout1x0; 366 | c[0 * m + 2] = vout2x0; 367 | c[0 * m + 3] = vout3x0; 368 | c[1 * m + 0] = vout0x1; 369 | c[1 * m + 1] = vout1x1; 370 | c[1 * m + 2] = vout2x1; 371 | c[1 * m + 3] = vout3x1; 372 | c[2 * m + 0] = vout0x2; 373 | c[2 * m + 1] = vout1x2; 374 | c[2 * m + 2] = vout2x2; 375 | c[2 * m + 3] = vout3x2; 376 | c[3 * m + 0] = vout0x3; 377 | c[3 * m + 1] = vout1x3; 378 | c[3 * m + 2] = vout2x3; 379 | c[3 * m + 3] = vout3x3; 380 | c += 4 * m; 381 | j -= 4; 382 | } 383 | if (j != 0) { 384 | do { 385 | uint32_t nnz = *nnzmap++; 386 | float vacc0 = *w++; 387 | float vacc1 = vacc0; 388 | float vacc2 = vacc0; 389 | float vacc3 = vacc0; 390 | if (nnz != 0) { 391 | do { 392 | const intptr_t diff = *dmap++; 393 | const float va0 = a[0]; 394 | const float va1 = a[1]; 395 | const float va2 = a[2]; 396 | const float va3 = a[3]; 397 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 398 | const float vb = *w++; 399 | vacc0 += va0 * vb; 400 | vacc1 += va1 * vb; 401 | vacc2 += va2 * vb; 402 | vacc3 += va3 * vb; 403 | } while (--nnz != 0); 404 | } 405 | float vout0 = math_min_f32(vacc0, vmax); 406 | float vout1 = math_min_f32(vacc1, vmax); 407 | float vout2 = math_min_f32(vacc2, vmax); 408 | float vout3 = math_min_f32(vacc3, vmax); 409 | vout0 = math_max_f32(vout0, vmin); 410 | vout1 = math_max_f32(vout1, vmin); 411 | vout2 = math_max_f32(vout2, vmin); 412 | vout3 = math_max_f32(vout3, vmin); 413 | c[0] = vout0; 414 | c[1] = vout1; 415 | c[2] = vout2; 416 | c[3] = vout3; 417 | c += m; 418 | j -= 1; 419 | } while (j != 0); 420 | } 421 | c -= m * n; 422 | c += 4; 423 | a += 4; 424 | } 425 | if (i & 2) { 426 | const float*restrict w = weights; 427 | const int32_t* dmap = widx_dmap; 428 | const uint32_t* nnzmap = nidx_nnzmap; 429 | size_t j = n; 430 | while (j >= 4) { 431 | uint32_t nnz = *nnzmap++; 432 | float vacc0x0 = *w++; 433 | float vacc1x0 = vacc0x0; 434 | float vacc0x1 = *w++; 435 | float vacc1x1 = vacc0x1; 436 | float vacc0x2 = *w++; 437 | float vacc1x2 = vacc0x2; 438 | float vacc0x3 = *w++; 439 | 
float vacc1x3 = vacc0x3; 440 | if (nnz != 0) { 441 | do { 442 | const intptr_t diff = *dmap++; 443 | const float va0 = a[0]; 444 | const float va1 = a[1]; 445 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 446 | const float vb0 = *w++; 447 | const float vb1 = *w++; 448 | const float vb2 = *w++; 449 | const float vb3 = *w++; 450 | vacc0x0 += va0 * vb0; 451 | vacc1x0 += va1 * vb0; 452 | vacc0x1 += va0 * vb1; 453 | vacc1x1 += va1 * vb1; 454 | vacc0x2 += va0 * vb2; 455 | vacc1x2 += va1 * vb2; 456 | vacc0x3 += va0 * vb3; 457 | vacc1x3 += va1 * vb3; 458 | } while (--nnz != 0); 459 | } 460 | float vout0x0 = math_min_f32(vacc0x0, vmax); 461 | float vout1x0 = math_min_f32(vacc1x0, vmax); 462 | float vout0x1 = math_min_f32(vacc0x1, vmax); 463 | float vout1x1 = math_min_f32(vacc1x1, vmax); 464 | float vout0x2 = math_min_f32(vacc0x2, vmax); 465 | float vout1x2 = math_min_f32(vacc1x2, vmax); 466 | float vout0x3 = math_min_f32(vacc0x3, vmax); 467 | float vout1x3 = math_min_f32(vacc1x3, vmax); 468 | vout0x0 = math_max_f32(vout0x0, vmin); 469 | vout1x0 = math_max_f32(vout1x0, vmin); 470 | vout0x1 = math_max_f32(vout0x1, vmin); 471 | vout1x1 = math_max_f32(vout1x1, vmin); 472 | vout0x2 = math_max_f32(vout0x2, vmin); 473 | vout1x2 = math_max_f32(vout1x2, vmin); 474 | vout0x3 = math_max_f32(vout0x3, vmin); 475 | vout1x3 = math_max_f32(vout1x3, vmin); 476 | c[0 * m + 0] = vout0x0; 477 | c[0 * m + 1] = vout1x0; 478 | c[1 * m + 0] = vout0x1; 479 | c[1 * m + 1] = vout1x1; 480 | c[2 * m + 0] = vout0x2; 481 | c[2 * m + 1] = vout1x2; 482 | c[3 * m + 0] = vout0x3; 483 | c[3 * m + 1] = vout1x3; 484 | c += 4 * m; 485 | j -= 4; 486 | } 487 | if (j != 0) { 488 | do { 489 | uint32_t nnz = *nnzmap++; 490 | float vacc0 = *w++; 491 | float vacc1 = vacc0; 492 | if (nnz != 0) { 493 | do { 494 | const intptr_t diff = *dmap++; 495 | const float va0 = a[0]; 496 | const float va1 = a[1]; 497 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 498 | const float vb = *w++; 499 | vacc0 += va0 * vb; 500 | vacc1 += va1 * vb; 501 | } while (--nnz != 0); 502 | } 503 | float vout0 = math_min_f32(vacc0, vmax); 504 | float vout1 = math_min_f32(vacc1, vmax); 505 | vout0 = math_max_f32(vout0, vmin); 506 | vout1 = math_max_f32(vout1, vmin); 507 | c[0] = vout0; 508 | c[1] = vout1; 509 | c += m; 510 | j -= 1; 511 | } while (j != 0); 512 | } 513 | c -= m * n; 514 | c += 2; 515 | a += 2; 516 | } 517 | if (i & 1) { 518 | const float*restrict w = weights; 519 | const int32_t* dmap = widx_dmap; 520 | const uint32_t* nnzmap = nidx_nnzmap; 521 | size_t j = n; 522 | while (j >= 4) { 523 | uint32_t nnz = *nnzmap++; 524 | float vacc0x0 = *w++; 525 | float vacc0x1 = *w++; 526 | float vacc0x2 = *w++; 527 | float vacc0x3 = *w++; 528 | if (nnz != 0) { 529 | do { 530 | const intptr_t diff = *dmap++; 531 | const float va0 = a[0]; 532 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 533 | const float vb0 = *w++; 534 | const float vb1 = *w++; 535 | const float vb2 = *w++; 536 | const float vb3 = *w++; 537 | vacc0x0 += va0 * vb0; 538 | vacc0x1 += va0 * vb1; 539 | vacc0x2 += va0 * vb2; 540 | vacc0x3 += va0 * vb3; 541 | } while (--nnz != 0); 542 | } 543 | float vout0x0 = math_min_f32(vacc0x0, vmax); 544 | float vout0x1 = math_min_f32(vacc0x1, vmax); 545 | float vout0x2 = math_min_f32(vacc0x2, vmax); 546 | float vout0x3 = math_min_f32(vacc0x3, vmax); 547 | vout0x0 = math_max_f32(vout0x0, vmin); 548 | vout0x1 = math_max_f32(vout0x1, vmin); 549 | vout0x2 = math_max_f32(vout0x2, vmin); 550 | vout0x3 = math_max_f32(vout0x3, 
vmin); 551 | c[0 * m + 0] = vout0x0; 552 | c[1 * m + 0] = vout0x1; 553 | c[2 * m + 0] = vout0x2; 554 | c[3 * m + 0] = vout0x3; 555 | c += 4 * m; 556 | j -= 4; 557 | } 558 | if (j != 0) { 559 | do { 560 | uint32_t nnz = *nnzmap++; 561 | float vacc0 = *w++; 562 | if (nnz != 0) { 563 | do { 564 | const intptr_t diff = *dmap++; 565 | const float va0 = a[0]; 566 | a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff); 567 | const float vb = *w++; 568 | vacc0 += va0 * vb; 569 | } while (--nnz != 0); 570 | } 571 | float vout0 = math_min_f32(vacc0, vmax); 572 | vout0 = math_max_f32(vout0, vmin); 573 | c[0] = vout0; 574 | c += m; 575 | j -= 1; 576 | } while (j != 0); 577 | } 578 | c -= m * n; 579 | c += 1; 580 | a += 1; 581 | } 582 | } 583 | } 584 | --------------------------------------------------------------------------------
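
The three scalar SpMM micro-kernels above consume one and the same sparse weight encoding, so a single reference routine is enough to illustrate it. The sketch below is not part of the repository: `f32_spmm_ref`, the packing loop and the toy sizes are illustrative assumptions based on how the kernels walk `weights`, `widx_dmap` and `nidx_nnzmap`, and the output min/max clamp driven by `params` is omitted. As the kernels show, the weight stream carries, per output channel, a bias followed by that channel's nonzero values; `nidx_nnzmap` carries the per-channel nonzero count; and each `widx_dmap` entry is the byte offset from the input row of the current nonzero to the input row of the next one (the kernels load first, then advance), with the final entry wrapping back so the activation pointer returns to its starting row before the kernel steps to the next block of spatial positions.

// Sketch only: a plain-C reference for the sparse encoding used by the
// f32_spmm_ukernel_*__scalar kernels above. All names below are illustrative.
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// Computes c[j][i] = bias[j] + sum over channel j's nonzeros of w * a[ic][i], reading
// the same streams as the kernels:
//   weights     : per output channel, a bias followed by that channel's nonzero values
//   widx_dmap   : per nonzero, byte delta from its input row to the next nonzero's row
//   nidx_nnzmap : per output channel, the number of nonzeros
// `a` must point at the input row of the very first nonzero; input is laid out
// [input_channels][m], output is [n][m]. The real kernels additionally clamp the
// result with params->scalar.min/max, which this reference skips.
static void f32_spmm_ref(
    uint32_t m, uint32_t n,
    const float* a,
    const float* weights,
    const int32_t* widx_dmap,
    const uint32_t* nidx_nnzmap,
    float* c)
{
  for (uint32_t i = 0; i < m; i++) {      // one spatial position at a time
    const float* w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    const float* ai = a + i;
    for (uint32_t j = 0; j < n; j++) {
      uint32_t nnz = *nnzmap++;
      float acc = *w++;                   // bias
      for (; nnz != 0; nnz--) {
        const intptr_t diff = *dmap++;
        acc += (*ai) * (*w++);            // load first ...
        ai = (const float*) ((uintptr_t) ai + (uintptr_t) diff);  // ... then advance
      }
      c[(size_t) j * m + i] = acc;
    }
  }
}

int main(void) {
  enum { M = 3, N = 2, IC = 4 };          // toy sizes: 3 pixels, 2 out-, 4 in-channels
  const float dense_w[N][IC] = {{1.f, 0.f, 2.f, 0.f}, {0.f, 3.f, 0.f, 4.f}};
  const float bias[N] = {0.5f, -1.f};
  float a[IC][M];
  for (int k = 0; k < IC; k++)
    for (int i = 0; i < M; i++) a[k][i] = (float) (k * M + i);

  // Pack the dense weights into the sparse streams described above.
  float weights[N + N * IC];
  int32_t dmap[N * IC];
  uint32_t nnzmap[N];
  int first_ic = -1, prev_ic = -1;
  size_t wi = 0, di = 0;
  for (int j = 0; j < N; j++) {
    uint32_t nnz = 0;
    weights[wi++] = bias[j];
    for (int k = 0; k < IC; k++) {
      if (dense_w[j][k] != 0.f) {
        weights[wi++] = dense_w[j][k];
        if (first_ic < 0) {
          first_ic = k;                   // the caller starts `a` at this row
        } else {
          dmap[di++] = (int32_t) ((k - prev_ic) * (int) (M * sizeof(float)));
        }
        prev_ic = k;
        nnz++;
      }
    }
    nnzmap[j] = nnz;
  }
  // Final delta wraps back to the first nonzero's row, so the pointer ends where it began.
  dmap[di++] = (int32_t) ((first_ic - prev_ic) * (int) (M * sizeof(float)));

  float c[N][M];
  f32_spmm_ref(M, N, &a[first_ic][0], weights, dmap, nnzmap, &c[0][0]);

  for (int j = 0; j < N; j++)
    for (int i = 0; i < M; i++)
      printf("c[%d][%d] = %.1f\n", j, i, c[j][i]);  // e.g. c[0][0] = 12.5, c[1][0] = 44.0
  return 0;
}

Encoding the column deltas as byte offsets is what lets each kernel advance the activation pointer with a single add per nonzero; the 8-wide unroll then reuses every loaded weight and every pointer step across eight consecutive spatial positions.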