├── talk ├── PAC.pdf └── PAC_finall.pptx ├── reference ├── 07477467.pdf ├── 1509.09308.pdf ├── 2464-supp.pdf └── Optimization.pdf ├── src ├── build.sh ├── readme ├── winconv.cpp ├── winconv.hpp ├── test.cpp └── winconv_2x3.cpp ├── winconv_4x3 ├── build.sh ├── winconv.hpp ├── winconv.cpp ├── test.cpp └── winconv_4x3.cpp ├── .gitignore ├── LICENSE └── README.md /talk/PAC.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/talk/PAC.pdf -------------------------------------------------------------------------------- /reference/07477467.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/reference/07477467.pdf -------------------------------------------------------------------------------- /talk/PAC_finall.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/talk/PAC_finall.pptx -------------------------------------------------------------------------------- /reference/1509.09308.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/reference/1509.09308.pdf -------------------------------------------------------------------------------- /reference/2464-supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/reference/2464-supp.pdf -------------------------------------------------------------------------------- /reference/Optimization.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chasingegg/Winconv/HEAD/reference/Optimization.pdf -------------------------------------------------------------------------------- /src/build.sh: 
-------------------------------------------------------------------------------- 1 | icpc -qopenmp -O2 -xHost -restrict -I ./ -lmkl_rt winconv.cpp winconv_2x3.cpp test.cpp -o test 2 | -------------------------------------------------------------------------------- /winconv_4x3/build.sh: -------------------------------------------------------------------------------- 1 | mpiicc -O3 -qopenmp -xHost -restrict -I ./ -lmkl_rt -lmkl_blacs_intelmpi_ilp64 -liomp5 -lpthread -ldl winconv.cpp winconv_4x3.cpp test.cpp -o test -lm 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | ~* 35 | test 36 | .* 37 | -------------------------------------------------------------------------------- /src/readme: -------------------------------------------------------------------------------- 1 | 2 | #Code Description: This code is a sample code used for convolution computation based on winograd algorithm. It is paralleled with OpenMP and written in C/C++. It will compute 18 convolution layers and here are more information about code modification below. 3 | - "C_array, IH_array, IW_array, K_array, Batch_array" are convolution parameters, CANN’t be modified; 4 | 5 | #Usage: 6 | - How to compiler: ./build.sh 7 | - How to run: “./test 0” is for performance and “ ./test 1” is for result validation. 
8 | 9 | #Optimization Direction: 10 | - SIMD and Memory Access Optimization by intrinsic if needed 11 | - Winograd Algorithm Modification,such as F(3,3) 12 | - SGEMM performance improvement 13 | - Pipeline Optimization between data transform and SGEMM computing 14 | 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Gao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Winograd on CNN 2 | 3 | ## Background 4 | 5 | The success of convolutional neural networks is limited by how fast we can compute them. 
We use the winograd's minimal filtering algorithms for small, 3*3 filters on CNN for intel architecture, specifically the Knights Landing Architecture. However, the reduction of arithmetic operations in Winograd algorithm comes at the cost of complicating the memory accesses. 6 | 7 | --- 8 | 9 | Direct convolution of a batch of N input images of dimension HxWxC with F filters of dimension RxSxC requires O(NFHWCRS) floating point operations. One useful method is the Fast Fourier Transform method(FFT). It is required that the size of the filter should be large enough to take advantage of it. 10 | 11 | --- 12 | 13 | ## Winograd 14 | 15 | In simple words, Winograd works on one image tile at a time. Let the tile size be 4x4, and a 3x3 filter, we can get 2x2 output. We can do the transformation that 4x4 for input data, 4x3 for filter data and 2x4 for the inverse transform that produces a 2x2 filtered output on one tile. Ignoring the transformation, we only need 16 element wise product, this results in a speedup of 36/16=2.25. To ignore the transformation cost, we need to transform element wise product to a GEMM operation. 16 | 17 | --- 18 | 19 | Each image frame has T = (H-2)x(W-2)/4 tiles and there are NxC such frames. We hope that a transformed input tile can be reused to multiply with corresponding F filter tiles. Similarly, a transformed filter tile should be reused to multiply with corresponding input tiles across all the batches. This can be turned into GEMM form by scattering the 16 elements of every tile to 16 different matrices to form the inputs for the GEMM. Therefore, input is converted to 16 matrices each of T = (H-2)x(W-2)/4 rows and C columns. Filter is converted to 16 matrices each of C rows and F columns. 20 | 21 | --- 22 | 23 | So there are 16 matrices of dimension TxC for an image input and 16 matrices of dimension CxF that represent filters, resulting in 16 one to one matrix multiplications.
Turning to GEMM results in the maximum reuse of every transformed tile and increases the code-vectorization capabilities that will lead to better performance. OpenMP is used for multithreading since it is more scalable and portable than pthreads. 24 | 25 | ## Reference 26 | 27 | - [Fast Algorithms for Convolutional Neural Networks](reference/1509.09308.pdf) 28 | - [Optimization of Spatial Convolution in ConvNets on Intel KNL](reference/Optimization.pdf) 29 | - [FLACON Library](https://github.com/ColfaxResearch/FALCON) -------------------------------------------------------------------------------- /src/winconv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | float* t_filter; 11 | float* t_image; 12 | float* c_out; 13 | 14 | #if 1 15 | long ISTRIDE = (MAX_BATCH)*(MAX_IMAGE_CHANNELS+18)*(MAX_TILES+13); 16 | long FSTRIDE = (MAX_FILTER_CHANNELS+1)*(MAX_FILTERS+1); 17 | long OSTRIDE = ISTRIDE; 18 | #endif 19 | 20 | // setup scratch memory used in the algorithm 21 | void winconv_init_lib() 22 | { 23 | int ret; 24 | 25 | t_filter = (float *)mkl_malloc(16*FSTRIDE*sizeof(float), 64); 26 | assert(t_filter != NULL); 27 | t_image = (float *)mkl_malloc(16*ISTRIDE*sizeof(float), 64); 28 | assert(t_image != NULL); 29 | c_out = (float *)mkl_malloc(16*OSTRIDE*sizeof(float), 64); 30 | assert(c_out != NULL); 31 | } 32 | 33 | // free up 34 | void winconv_free_lib() 35 | { 36 | mkl_free(t_filter); 37 | mkl_free(t_image); 38 | mkl_free(c_out); 39 | } 40 | 41 | /* Make number not 4096 aligned. */ 42 | inline void no4k_aligned(long *num) 43 | { 44 | long flag = *num; 45 | 46 | if(flag%4096 == 0) 47 | (*num) += 128; 48 | } 49 | 50 | /* Compute stride for input, filter and output. 
*/ 51 | void compute_max_stride(const int lyn, const int N, 52 | const int *C, const int *H, const int *W, const int *K) 53 | { 54 | int tmp; 55 | long istride, fstride, ostride; 56 | istride = fstride = ostride = 0; 57 | 58 | for(int i = 0; i < lyn; i++){ 59 | tmp = N * (H[i]-2)/2 * (W[i]-2)/2 * C[i]; 60 | if(tmp > istride) istride = tmp; 61 | 62 | tmp = C[i] * K[i]; 63 | if(tmp > fstride) fstride = tmp; 64 | 65 | tmp = N * (H[i]-2)/2 * (W[i]-2)/2 * K[i]; 66 | if(tmp > ostride) ostride = tmp; 67 | } 68 | 69 | no4k_aligned(&istride); 70 | no4k_aligned(&fstride); 71 | no4k_aligned(&ostride); 72 | 73 | ISTRIDE = istride; 74 | FSTRIDE = fstride; 75 | OSTRIDE = ostride; 76 | } 77 | 78 | /* Decide to how to divide block for batch. */ 79 | void decide_batch_block(const int lyn, const int N, 80 | const int *C, const int *H, const int *W, const int *K, 81 | int *bblock2x3) 82 | { 83 | float m_used; 84 | 85 | /* F(2,3) */ 86 | for(int i = 0; i < lyn; i++){ 87 | m_used = 1.0f*(N*(H[i]-2)/2*(W[i]-2)/2*C[i] + C[i]*K[i] + N*(H[i]-2)/2*(W[i]-2)/2*K[i])/1024/1024/1024*16; 88 | m_used += 1.0f*(N*C[i]*H[i]*W[i] + K[i]*C[i]*3*3 + N*K[i]*(H[i]-2)*(W[i]-2))/1024/1024/1024; 89 | m_used *= 4; 90 | 91 | if(m_used <= BB_MEM_BOUND) 92 | bblock2x3[i] = BATCH_TOGETHER; 93 | else 94 | bblock2x3[i] = BATCH_BLOCK; 95 | } 96 | 97 | } 98 | 99 | 100 | void winconv(const int bblock2x3, const int M2x3, 101 | float *image, const int irows, const int icols, const int C, 102 | float *filter, const int K, const int N, float *out) 103 | { 104 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, N, out); 105 | } 106 | -------------------------------------------------------------------------------- /src/winconv.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _FALCON_HPP_ 2 | #define _FALCON_HPP_ 3 | 4 | #include 5 | #include 6 | 7 | // The below parameters are required to generate the scratch pad memory 8 | // It is required to reserve enough 
memory to store data for all the sizes you will be working on 9 | // For example : by default, the below parameters are set for the test network 10 | #define MAX_BATCH 128 11 | #define MAX_IMAGE_CHANNELS 64 12 | #define MAX_IROWS 1024 13 | #define MAX_FILTER_CHANNELS 512 14 | #define MAX_FILTERS 2048 15 | 16 | #define BB_MEM_BOUND 4 17 | #define BATCH_TOGETHER 0 18 | #define BATCH_BLOCK 1 19 | 20 | #define F2X3 2 21 | 22 | /* STRIDE is the max batch*ntile*channel for input 23 | * FSTRIDE2X3 is the max C*K for filter 24 | * OSTRIDE2X3 is the max batch*ntile*nfilter 25 | **/ 26 | const long MAX_TILES = (MAX_IROWS-2)*(MAX_IROWS-2)*0.25; 27 | #if 0 28 | const long ISTRIDE = (MAX_BATCH)*(MAX_IMAGE_CHANNELS+18)*(MAX_TILES+13); 29 | const long FSTRIDE = (MAX_FILTER_CHANNELS+1)*(MAX_FILTERS+1); 30 | const long OSTRIDE = ISTRIDE; 31 | #else 32 | extern long ISTRIDE; 33 | extern long FSTRIDE; 34 | extern long OSTRIDE; 35 | #endif 36 | extern float* t_filter; 37 | extern float* t_image; 38 | extern float* c_out; 39 | 40 | void winconv(const int bblock2x3, const int M2x3, 41 | float* image, const int irows, const int icols, 42 | const int C, float* filter, const int K, const int N, 43 | float* out); 44 | 45 | void winconv_2x3(const int bblock, const int M, float* image, const int irows, const int icols, 46 | const int C, float* filter, const int K, const int N, 47 | float* out); 48 | 49 | 50 | inline void no4k_aligned(long *); 51 | void compute_max_stride(const int, const int, const int *, const int *, const int *, const int *); 52 | 53 | void decide_batch_block(const int, const int, const int *, const int *, const int *, const int *, int *); 54 | 55 | void winconv_init_lib(); 56 | void winconv_free_lib(); 57 | 58 | // IMAGE LAYOUT : Image is a 4D data structure, image[N][C][H][W], where H=W=irows. 59 | // W is the inner most dimension with unit stride. Image data structure is stored in a linear 60 | // array I[N*channels*irows*irows]. 
61 | 62 | // FILTER LAYOUT: Filter is a 4D data structure, filter[K][C][R][S], where R=S=3. S is the inner most dimension 63 | // with unit stride. Filter data structure is stored in a linear array F[K*C*3*3]. 64 | 65 | // OUTPUT LAYOUT: Ouput of convolution is a 4D data structure, out[N][K][oH][oW], where oH=oW=(irows-2). 66 | // oW is the inner most dimension with unit stride. output data structure is stored in a linear 67 | // array O[N*K*oH*oW]. 68 | 69 | 70 | // M -> the merge factor 71 | // image -> pointer to I array 72 | // irows -> is height or width of a square image 73 | // C -> number of image Channels 74 | // Filter -> pointer to F array 75 | // K -> number of filters 76 | // N -> batch size 77 | // out -> pointer to O array 78 | 79 | 80 | // The Merge factor provides flexibility in the way the input data layout is used. 81 | // if M=1 --> NCHW 82 | // else if M=N --> CNHW 83 | // else (1 < M < N) --> (N/M)C(M*HW) 84 | 85 | #endif 86 | -------------------------------------------------------------------------------- /winconv_4x3/winconv.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _FALCON_HPP_ 2 | #define _FALCON_HPP_ 3 | 4 | #include 5 | #include 6 | 7 | // The below parameters are required to generate the scratch pad memory 8 | // It is required to reserve enough memory to store data for all the sizes you will be working on 9 | // For example : by default, the below parameters are set for the test network 10 | #define MAX_BATCH 128 11 | #define MAX_IMAGE_CHANNELS 64 12 | #define MAX_IROWS 1024 13 | #define MAX_FILTER_CHANNELS 512 14 | #define MAX_FILTERS 2048 15 | 16 | #define BB_MEM_BOUND 8 17 | #define BATCH_TOGETHER 0 18 | #define BATCH_BLOCK 1 19 | 20 | #define F2X3 2 21 | 22 | /* STRIDE is the max batch*ntile*channel for input 23 | * FSTRIDE2X3 is the max C*K for filter 24 | * OSTRIDE2X3 is the max batch*ntile*nfilter 25 | **/ 26 | const long MAX_TILES = (MAX_IROWS-2)*(MAX_IROWS-2)*0.25; 27 | 
#if 0 28 | const long ISTRIDE = (MAX_BATCH)*(MAX_IMAGE_CHANNELS+18)*(MAX_TILES+13); 29 | const long FSTRIDE = (MAX_FILTER_CHANNELS+1)*(MAX_FILTERS+1); 30 | const long OSTRIDE = ISTRIDE; 31 | #else 32 | extern long ISTRIDE; 33 | extern long FSTRIDE; 34 | extern long OSTRIDE; 35 | #endif 36 | extern float* t_filter; 37 | extern float* t_image; 38 | extern float* c_out; 39 | 40 | void winconv(const int bblock2x3, const int M2x3, 41 | float* image, const int irows, const int icols, 42 | const int C, float* filter, const int K, const int N, 43 | float* out); 44 | 45 | void winconv_2x3(const int bblock, const int M, float* image, const int irows, const int icols, 46 | const int C, float* filter, const int K, const int N, 47 | float* out); 48 | 49 | 50 | inline void no4k_aligned(long *); 51 | void compute_max_stride(const int, const int, const int *, const int *, const int *, const int *); 52 | 53 | void decide_batch_block(const int, const int, const int *, const int *, const int *, const int *, int *); 54 | 55 | void winconv_init_lib(); 56 | void winconv_free_lib(); 57 | 58 | // IMAGE LAYOUT : Image is a 4D data structure, image[N][C][H][W], where H=W=irows. 59 | // W is the inner most dimension with unit stride. Image data structure is stored in a linear 60 | // array I[N*channels*irows*irows]. 61 | 62 | // FILTER LAYOUT: Filter is a 4D data structure, filter[K][C][R][S], where R=S=3. S is the inner most dimension 63 | // with unit stride. Filter data structure is stored in a linear array F[K*C*3*3]. 64 | 65 | // OUTPUT LAYOUT: Ouput of convolution is a 4D data structure, out[N][K][oH][oW], where oH=oW=(irows-2). 66 | // oW is the inner most dimension with unit stride. output data structure is stored in a linear 67 | // array O[N*K*oH*oW]. 
68 | 69 | 70 | // M -> the merge factor 71 | // image -> pointer to I array 72 | // irows -> is height or width of a square image 73 | // C -> number of image Channels 74 | // Filter -> pointer to F array 75 | // K -> number of filters 76 | // N -> batch size 77 | // out -> pointer to O array 78 | 79 | 80 | // The Merge factor provides flexibility in the way the input data layout is used. 81 | // if M=1 --> NCHW 82 | // else if M=N --> CNHW 83 | // else (1 < M < N) --> (N/M)C(M*HW) 84 | 85 | #endif 86 | -------------------------------------------------------------------------------- /winconv_4x3/winconv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | using namespace std; 10 | float* t_filter; 11 | float* t_image; 12 | float* c_out; 13 | 14 | #if 1 15 | long ISTRIDE = (MAX_BATCH)*(MAX_IMAGE_CHANNELS+18)*(MAX_TILES+13); 16 | long FSTRIDE = (MAX_FILTER_CHANNELS+1)*(MAX_FILTERS+1); 17 | long OSTRIDE = ISTRIDE; 18 | #endif 19 | 20 | // setup scratch memory used in the algorithm 21 | void winconv_init_lib() 22 | { 23 | int ret; 24 | 25 | t_filter = (float *)mkl_malloc(36*FSTRIDE*sizeof(float), 64); 26 | //std::cout << 36 * FSTRIDE * sizeof(float) << std::endl; 27 | assert(t_filter != NULL); 28 | t_image = (float *)mkl_malloc(36*ISTRIDE*sizeof(float), 64); 29 | //std::cout << 36 * ISTRIDE * sizeof(float) << std::endl; 30 | assert(t_image != NULL); 31 | c_out = (float *)mkl_malloc(36*OSTRIDE*sizeof(float), 64); 32 | //std::cout << 36 * OSTRIDE * sizeof(float) << std::endl; 33 | assert(c_out != NULL); 34 | } 35 | 36 | // free up 37 | void winconv_free_lib() 38 | { 39 | mkl_free(t_filter); 40 | mkl_free(t_image); 41 | mkl_free(c_out); 42 | } 43 | 44 | /* Make number not 4096 aligned. 
*/ 45 | inline void no4k_aligned(long *num) 46 | { 47 | long flag = *num; 48 | 49 | if(flag%4096 == 0) 50 | (*num) += 128; 51 | } 52 | 53 | /* Compute stride for input, filter and output. */ 54 | void compute_max_stride(const int lyn, const int N, 55 | const int *C, const int *H, const int *W, const int *K) 56 | { 57 | int tmp; 58 | long istride, fstride, ostride; 59 | istride = fstride = ostride = 0; 60 | 61 | for(int i = 0; i < lyn; i++){ 62 | int htile = (H[i] + 1) / 4; // outH = H[i] - 2; (outH + 3) / 4; 63 | int wtile = (W[i] + 1) / 4; 64 | tmp = N * htile * wtile * C[i]; 65 | if(tmp > istride) istride = tmp; 66 | 67 | tmp = C[i] * K[i]; 68 | if(tmp > fstride) fstride = tmp; 69 | 70 | tmp = N * htile * wtile * K[i]; 71 | if(tmp > ostride) ostride = tmp; 72 | } 73 | 74 | no4k_aligned(&istride); 75 | no4k_aligned(&fstride); 76 | no4k_aligned(&ostride); 77 | 78 | ISTRIDE = istride; 79 | FSTRIDE = fstride; 80 | OSTRIDE = ostride; 81 | } 82 | 83 | /* Decide to how to divide block for batch. 
*/ 84 | void decide_batch_block(const int lyn, const int N, 85 | const int *C, const int *H, const int *W, const int *K, 86 | int *bblock2x3) 87 | { 88 | float m_used; 89 | 90 | /* F(2,3) */ 91 | for(int i = 0; i < lyn; i++){ 92 | int htile = (H[i] + 1) / 4; 93 | int wtile = (W[i] + 1) / 4; 94 | m_used = 1.0f*(N*htile*wtile*C[i] + C[i]*K[i] + N*htile*wtile*K[i])/1024/1024/1024*36; 95 | m_used += 1.0f*(N*C[i]*H[i]*W[i] + K[i]*C[i]*3*3 + N*K[i]*(H[i]-2)*(W[i]-2))/1024/1024/1024; 96 | m_used *= 4; 97 | 98 | if(m_used <= BB_MEM_BOUND) 99 | bblock2x3[i] = BATCH_TOGETHER; 100 | else 101 | bblock2x3[i] = BATCH_BLOCK; 102 | //bblock2x3[i] = BATCH_TOGETHER; 103 | } 104 | 105 | } 106 | 107 | 108 | void winconv(const int bblock2x3, const int M2x3, 109 | float *image, const int irows, const int icols, const int C, 110 | float *filter, const int K, const int N, float *out) 111 | { 112 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, N, out); 113 | } 114 | -------------------------------------------------------------------------------- /src/test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include "winconv.hpp" 12 | #include 13 | #include 14 | #include 15 | #include 16 | #define icrt_time_sec() ({ struct timeval tp; gettimeofday(&tp, 0); tp.tv_sec + tp. 
tv_usec * 1.e-6; }) 17 | 18 | #define CYCLE_NUM 100 19 | 20 | #define F_2X3 1 21 | //#define F_3X3 2 22 | //#define F_HYBRID 3 23 | 24 | int counter = 0; 25 | 26 | int myDirectConv(float *in, float *kn, float *out, 27 | const int N, const int C, const int H, const int W, const int K) 28 | { 29 | int inpos, knpos, outpos; 30 | 31 | int dimIn[4] = {N, C, H, W}; 32 | int dimKn[4] = {K, C, 3, 3}; 33 | int dimOut[4] = {N, K, H-2, W-2}; 34 | 35 | int ingap[3] = {dimIn[1]*dimIn[2]*dimIn[3], dimIn[2]*dimIn[3], dimIn[3]}; 36 | int kngap[3] = {dimKn[1]*dimKn[2]*dimKn[3], dimKn[2]*dimKn[3], dimKn[3]}; 37 | int outgap[3] = {dimOut[1]*dimOut[2]*dimOut[3], dimOut[2]*dimOut[3], dimOut[3]}; 38 | 39 | #pragma omp parallel for private(inpos, knpos, outpos) 40 | for(int inn = 0; inn < dimIn[0]; inn++) 41 | for(int knn = 0; knn < dimKn[0]; knn++) 42 | for(int inc = 0; inc < dimIn[1]; inc++){ 43 | for(int outh = 0; outh < dimOut[2]; outh++) 44 | for(int outw = 0; outw < dimOut[3]; outw++){ 45 | outpos = inn*outgap[0] + knn*outgap[1] + outh*outgap[2] + outw; 46 | for(int knh = 0; knh < dimKn[2]; knh++) 47 | for(int knw = 0; knw < dimKn[3]; knw++){ 48 | inpos = inn*ingap[0] + inc*ingap[1] + (outh+knh)*ingap[2] + (outw+knw); 49 | //knpos = knn*kngap[0] + inc*kngap[1] + 8 - (knh*kngap[2] + knw); 50 | knpos = knn*kngap[0] + inc*kngap[1] + knh*kngap[2] + knw; 51 | out[outpos] += in[inpos] * kn[knpos]; 52 | } 53 | } 54 | } 55 | 56 | return 0; 57 | } 58 | 59 | void winograd_conv(const int bblock2x3, const int M2x3, 60 | int irows, int icols,int C, int K, const int batch, 61 | long* total_flops, double* total_time, const int verify){ 62 | counter++; 63 | 64 | long i, j, n; 65 | const int outHeight = irows-2; 66 | const int outWidth = icols-2; 67 | const int sizeI = irows*icols; 68 | const int sizeF = 3*3; 69 | const int sizeO = outHeight*outWidth; 70 | const int tiles = (outHeight)*0.5*(outWidth)*0.5; 71 | 72 | int ret; 73 | 74 | float* image, *filter, *out; 75 | image = (float 
*)mkl_malloc(batch*C*sizeI*sizeof(float), 64); 76 | assert(image != NULL); 77 | filter = (float *)mkl_malloc(K*C*sizeF*sizeof(float), 64); 78 | assert(filter != NULL); 79 | out = (float *)mkl_malloc(batch*K*sizeO*sizeof(float), 64); 80 | assert(out != NULL); 81 | 82 | //initialize image in parallel 83 | #pragma omp parallel for private(i) 84 | for(i = 0; i < batch*C*sizeI; i++) 85 | image[i] = (float)(i%11); 86 | //image[i] = rand()%5; 87 | 88 | //initialize image in parallel 89 | #pragma omp parallel for private(i) 90 | for(i = 0; i < K*C*sizeF; i++) 91 | filter[i] = (float)(i%7); 92 | //filter[i] = rand()%3; 93 | 94 | 95 | double timer; 96 | double timer_acc = 0.0f; 97 | 98 | double stime, etime; 99 | 100 | /* First Time */ 101 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, batch, out); 102 | 103 | stime = icrt_time_sec(); 104 | for(i = 0; i < CYCLE_NUM; i++){ 105 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, batch, out); 106 | } 107 | etime = icrt_time_sec(); 108 | 109 | timer_acc = etime - stime; 110 | 111 | timer = timer_acc/CYCLE_NUM; 112 | long nflops = batch*K*C*(irows-2)*(icols-2)*3*3*2; 113 | double gflops = (double) nflops*1.0e-9/timer; 114 | *total_flops += nflops; 115 | *total_time += timer; 116 | 117 | if(verify){ 118 | float* vout = (float *)malloc(batch*K*sizeO*sizeof(float)); 119 | memset(vout, 0, batch*K*sizeO*sizeof(float)); 120 | 121 | myDirectConv(image, filter, vout, batch, C, irows, icols, K); 122 | printf("CONV[%-2d], N-C-H-W-K-(Merge2x3-Block2x3) = %-3d %-3d %-3d %-3d %-3d (%-3d %-2d) : ", 123 | counter, batch, C, irows, icols, K, M2x3, bblock2x3 ); 124 | for(n = 0; n < batch*sizeO*K; n++){ 125 | if(fabs((out[n] - vout[n])/vout[n]) > 1e-4){ 126 | printf("Output Error!!! 
winogradConv[%d] = %f || directConv[%d] = %f \n", n, out[n], n, vout[n]); 127 | break; 128 | } 129 | } 130 | if(n == batch*sizeO*K) 131 | printf("Output True!!!\n"); 132 | free(vout); 133 | } 134 | else{ 135 | printf("CONV[%d]:\tEFFECTIVE GFLOPS is %7.2f \tGFlops \tand timing is \t%f ms \n", counter, gflops, timer*1000); 136 | } 137 | 138 | mkl_free(image); 139 | mkl_free(filter); 140 | mkl_free(out); 141 | 142 | } 143 | 144 | int main(int argc, char** argv){ 145 | 146 | if(argc < 2){ 147 | printf("Enter the running mode\n"); 148 | printf("Example: ./test 0 or ./test 1\n"); 149 | // exit(-1); 150 | } 151 | int i, j; 152 | double timer; 153 | 154 | int verify = 0; //fix me 155 | if(argc>1) 156 | verify = atoi(argv[1]); 157 | 158 | //const int max_tiles = 224*224*0.25; 159 | const long max_tiles = MAX_TILES; 160 | 161 | const int layer_num = 18; 162 | const int C_array[18] = {1, 32, 64, 64, 128, 128, 256, 256, 512, 512, 3, 32, 64, 64, 128, 96, 128, 128}; 163 | const int IH_array[18] = {40, 20, 20, 10, 10, 6, 6, 6, 6, 6, 100, 100, 50, 50, 20, 12, 12, 8}; 164 | const int IW_array[18] = {1024, 512, 512, 512, 512, 512, 512, 512, 512, 1024, 100, 100, 50, 50, 26, 12, 12, 8}; 165 | const int K_array[18] = {32, 64, 64, 128, 128, 256, 256, 512, 512, 2048, 32, 64, 64, 128, 96, 192, 256, 512}; 166 | const int Batch_array[18] = {64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 128, 128}; 167 | int merge_array2x3[18] = {1, 1, 1, 1, 1, 1, 4, 4, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1}; 168 | int b_block2x3[18]; 169 | 170 | int t; 171 | double total_time; 172 | long total_flops; 173 | int batch = 64; 174 | /* for(t = 0; t < layer_num; t++){ 175 | if(batch 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #include "winconv.hpp" 13 | #include 14 | #include 15 | #include 16 | #include 17 | using namespace std; 18 | #define icrt_time_sec() ({ struct timeval tp; gettimeofday(&tp, 0); tp.tv_sec + tp. 
tv_usec * 1.e-6; }) 19 | 20 | #define CYCLE_NUM 100 21 | 22 | #define F_2X3 1 23 | //#define F_3X3 2 24 | //#define F_HYBRID 3 25 | 26 | int counter = 0; 27 | 28 | int myDirectConv(float *in, float *kn, float *out, 29 | const int N, const int C, const int H, const int W, const int K) 30 | { 31 | int inpos, knpos, outpos; 32 | 33 | int dimIn[4] = {N, C, H, W}; 34 | int dimKn[4] = {K, C, 3, 3}; 35 | int dimOut[4] = {N, K, H-2, W-2}; 36 | 37 | int ingap[3] = {dimIn[1]*dimIn[2]*dimIn[3], dimIn[2]*dimIn[3], dimIn[3]}; 38 | int kngap[3] = {dimKn[1]*dimKn[2]*dimKn[3], dimKn[2]*dimKn[3], dimKn[3]}; 39 | int outgap[3] = {dimOut[1]*dimOut[2]*dimOut[3], dimOut[2]*dimOut[3], dimOut[3]}; 40 | 41 | #pragma omp parallel for private(inpos, knpos, outpos) 42 | for(int inn = 0; inn < dimIn[0]; inn++) 43 | for(int knn = 0; knn < dimKn[0]; knn++) 44 | for(int inc = 0; inc < dimIn[1]; inc++){ 45 | for(int outh = 0; outh < dimOut[2]; outh++) 46 | for(int outw = 0; outw < dimOut[3]; outw++){ 47 | outpos = inn*outgap[0] + knn*outgap[1] + outh*outgap[2] + outw; 48 | for(int knh = 0; knh < dimKn[2]; knh++) 49 | for(int knw = 0; knw < dimKn[3]; knw++){ 50 | inpos = inn*ingap[0] + inc*ingap[1] + (outh+knh)*ingap[2] + (outw+knw); 51 | //knpos = knn*kngap[0] + inc*kngap[1] + 8 - (knh*kngap[2] + knw); 52 | knpos = knn*kngap[0] + inc*kngap[1] + knh*kngap[2] + knw; 53 | out[outpos] += in[inpos] * kn[knpos]; 54 | } 55 | } 56 | } 57 | 58 | return 0; 59 | } 60 | 61 | void winograd_conv(const int bblock2x3, const int M2x3, 62 | int irows, int icols,int C, int K, const int batch, 63 | long* total_flops, double* total_time, const int verify){ 64 | counter++; 65 | 66 | long i, j, n; 67 | const int outHeight = irows-2; 68 | const int outWidth = icols-2; 69 | const int sizeI = irows*icols; 70 | const int sizeF = 3*3; 71 | const int sizeO = outHeight*outWidth; 72 | const int tiles = (outHeight)*0.5*(outWidth)*0.5; 73 | 74 | int mpiRank, mpiSize; 75 | MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank); 76 | 
MPI_Comm_size(MPI_COMM_WORLD, &mpiSize); 77 | 78 | int num = (batch + mpiSize - 1) / mpiSize; 79 | int batchBegin, batchEnd; 80 | batchBegin = mpiRank * num; 81 | batchEnd = (mpiRank + 1) * num; 82 | if (mpiRank == mpiSize - 1) { 83 | batchEnd= batch; 84 | } 85 | int numBatch = batchEnd - batchBegin; 86 | int realBatch = numBatch; 87 | numBatch = (numBatch + M2x3 - 1) / M2x3 * M2x3; 88 | 89 | float* image, *filter, *out; 90 | if (mpiRank == 0) { 91 | image = (float *)mkl_malloc(batch*C*sizeI*sizeof(float), 64); 92 | } else { 93 | image = (float *)mkl_malloc(numBatch*C*sizeI*sizeof(float), 64); 94 | } 95 | assert(image != NULL); 96 | filter = (float *)mkl_malloc(K*C*sizeF*sizeof(float), 64); 97 | assert(filter != NULL); 98 | if (mpiRank == 0) { 99 | out = (float *)mkl_malloc(batch*K*sizeO*sizeof(float), 64); 100 | } else { 101 | out = (float *)mkl_malloc(numBatch*K*sizeO*sizeof(float), 64); 102 | } 103 | assert(out != NULL); 104 | 105 | //initialize image in parallel 106 | if (mpiRank == 0) { 107 | #pragma omp parallel for private(i) 108 | for (i = 0; i < batch*C*sizeI; i++) 109 | image[i] = (float)(i%11); 110 | } else { 111 | #pragma omp parallel for private(i) 112 | for(i = batchBegin*C*sizeI; i < batchEnd*C*sizeI; i++) 113 | image[i-batchBegin*C*sizeI] = (float)(i%11); 114 | //image[i] = rand()%5; 115 | } 116 | 117 | //initialize filter in parallel 118 | #pragma omp parallel for private(i) 119 | for(i = 0; i < K*C*sizeF; i++) 120 | filter[i] = (float)(i%7); 121 | //filter[i] = rand()%3; 122 | 123 | double timer; 124 | double timer_acc = 0.0f; 125 | 126 | double minTime = 1e32; 127 | int num_threads; 128 | 129 | double stime, etime; 130 | /* First Time */ 131 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, numBatch, out); 132 | 133 | stime = icrt_time_sec(); 134 | for(i = 0; i < CYCLE_NUM; i++){ 135 | winconv_2x3(bblock2x3, M2x3, image, irows, icols, C, filter, K, numBatch, out); 136 | } 137 | etime = icrt_time_sec(); 138 | 139 | timer_acc = 
etime - stime; 140 | 141 | if (mpiRank == 0) { 142 | MPI_Request req[4]; 143 | for (int i = 1; i < mpiSize; ++i) { 144 | int batchBegin = i * num; 145 | int recvNumBatch = num; 146 | if (i == mpiSize - 1) { 147 | recvNumBatch = batch - batchBegin; 148 | } 149 | MPI_Irecv(out + batchBegin * K * sizeO, recvNumBatch*K*sizeO, MPI_FLOAT, i, 0, MPI_COMM_WORLD, &(req[i])); 150 | } 151 | for (int i = 1; i < mpiSize; ++i) { 152 | MPI_Status sta; 153 | MPI_Wait(&(req[i]), &sta); 154 | } 155 | } else { 156 | MPI_Request req; 157 | MPI_Isend(out, numBatch*K*sizeO, MPI_FLOAT, 0, 0, MPI_COMM_WORLD, &req); 158 | } 159 | 160 | timer = timer_acc/CYCLE_NUM; 161 | long nflops = batch*K*C*(irows-2)*(icols-2)*3*3*2; 162 | double gflops = (double) nflops*1.0e-9/timer; 163 | *total_flops += nflops; 164 | *total_time += timer; 165 | 166 | if (mpiRank == 0) { 167 | if(verify){ 168 | float* vout = (float *)malloc(batch*K*sizeO*sizeof(float)); 169 | memset(vout, 0, batch*K*sizeO*sizeof(float)); 170 | 171 | myDirectConv(image, filter, vout, batch, C, irows, icols, K); 172 | printf("CONV[%-2d], N-C-H-W-K-(Merge2x3-Block2x3) = %-3d %-3d %-3d %-3d %-3d (%-3d %-2d) : ", 173 | counter, batch, C, irows, icols, K, M2x3, bblock2x3 ); 174 | for(n = 0; n < batch*sizeO*K; n++){ 175 | if(fabs((out[n] - vout[n])/vout[n]) > 1e-4){ 176 | printf("Output Error!!! 
winogradConv[%d] = %f || directConv[%d] = %f \n", n, out[n], n, vout[n]); 177 | break; 178 | } 179 | } 180 | if(n == batch*sizeO*K) 181 | printf("Output True!!!\n"); 182 | free(vout); 183 | } 184 | else{ 185 | printf("CONV[%d]:\tEFFECTIVE GFLOPS is %7.2f \tGFlops \tand timing is \t%f ms \n", counter, gflops, timer*1000); 186 | } 187 | } 188 | MPI_Barrier(MPI_COMM_WORLD); 189 | mkl_free(image); 190 | mkl_free(filter); 191 | mkl_free(out); 192 | } 193 | 194 | int main(int argc, char** argv){ 195 | 196 | if(argc < 2){ 197 | printf("Enter the running mode\n"); 198 | printf("Example: ./test 0 or ./test 1\n"); 199 | // exit(-1); 200 | } 201 | 202 | MPI_Init(&argc, &argv); 203 | 204 | int i, j; 205 | double timer; 206 | 207 | int verify = 0; //fix me 208 | if(argc>1) 209 | verify = atoi(argv[1]); 210 | 211 | //const int max_tiles = 224*224*0.25; 212 | const long max_tiles = MAX_TILES; 213 | 214 | const int layer_num = 18; //1 40 1024 32 64 64 20 512 64 64 215 | const int C_array[18] = { 1, 32, 64, 64, 128, 128, 256, 256, 512, 512, 3, 32, 64, 64, 128, 96, 128, 128}; 216 | const int IH_array[18] = { 40, 20, 20, 10, 10, 6, 6, 6, 6, 6, 100, 100, 50, 50, 20, 12, 12, 8}; 217 | const int IW_array[18] = { 1024, 512, 512, 512, 512, 512, 512, 512, 512, 1024, 100, 100, 50, 50, 26, 12, 12, 8}; 218 | const int K_array[18] = { 32, 64, 64, 128, 128, 256, 256, 512, 512, 2048, 32, 64, 64, 128, 96, 192, 256, 512}; 219 | const int Batch_array[18] = {64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 128, 128}; 220 | int merge_array2x3[18] = { 1, 1, 1, 1, 2, 4, 4, 4, 4, 2, 1, 4, 1, 4, 1, 1, 1, 2}; 221 | int b_block2x3[18]; 222 | 223 | int t; 224 | double total_time; 225 | long total_flops; 226 | int batch = 64; 227 | /* for(t = 0; t < layer_num; t++){ 228 | if(batch 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | /* Input data transform part with 16 tiles. 
*/
/* Input data transform for F(2,3): transforms 16 adjacent 4x4 input
 * tiles at once with AVX-512.  (x, y) is the top-left corner of the
 * first tile in the source plane (row stride nrows).  The 16 transform
 * components are scattered into dataDst at ISTRIDE-spaced offsets
 * starting at *counter; *counter is advanced by 16. */
static void get_tiles_2x3_16t(int x, int y, int nrows, const float *dataSrc,
                              float *dataDst, int *counter)
{
    const int coter = *counter;
    __m512 bufA[16], bufB, bufC, bufD;
    /* idx0/idx1 select the even/odd lanes of two concatenated vectors,
     * de-interleaving 32 consecutive pixels into per-tile columns. */
    __m512i idx0 = _mm512_set_epi32(30,28,26,24,22,20,18,16,14,12,10, 8,6,4,2,0);
    __m512i idx1 = _mm512_set_epi32(31,29,27,25,23,21,19,17,15,13,11, 9,7,5,3,1);

    /* Gather: bufA[4*r+c] holds element (r,c) of the 4x4 tile for each
     * of the 16 tiles, one tile per SIMD lane.  Unaligned loads are
     * required: the plane base (image + plane*sizeI) is not guaranteed
     * 64-byte aligned, and the y+2 / y+18 offsets never are. */

    /* row 0 -> bufA[0..3] */
    bufB = _mm512_loadu_ps(&dataSrc[(x+0)*nrows + (y+0)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+0)*nrows + (y+16)]);
    bufA[ 0] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[ 1] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    bufB = _mm512_loadu_ps(&dataSrc[(x+0)*nrows + (y+2)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+0)*nrows + (y+18)]);
    bufA[ 2] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[ 3] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    /* row 1 -> bufA[4..7] */
    bufB = _mm512_loadu_ps(&dataSrc[(x+1)*nrows + (y+0)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+1)*nrows + (y+16)]);
    bufA[ 4] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[ 5] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    bufB = _mm512_loadu_ps(&dataSrc[(x+1)*nrows + (y+2)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+1)*nrows + (y+18)]);
    bufA[ 6] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[ 7] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    /* row 2 -> bufA[8..11] */
    bufB = _mm512_loadu_ps(&dataSrc[(x+2)*nrows + (y+0)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+2)*nrows + (y+16)]);
    bufA[ 8] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[ 9] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    bufB = _mm512_loadu_ps(&dataSrc[(x+2)*nrows + (y+2)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+2)*nrows + (y+18)]);
    bufA[10] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[11] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    /* row 3 -> bufA[12..15] */
    bufB = _mm512_loadu_ps(&dataSrc[(x+3)*nrows + (y+0)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+3)*nrows + (y+16)]);
    bufA[12] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[13] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    bufB = _mm512_loadu_ps(&dataSrc[(x+3)*nrows + (y+2)]);
    bufC = _mm512_loadu_ps(&dataSrc[(x+3)*nrows + (y+18)]);
    bufA[14] = _mm512_permutex2var_ps(bufB, idx0, bufC);
    bufA[15] = _mm512_permutex2var_ps(bufB, idx1, bufC);

    /* B^T d B, manually simplified.  bufB/bufC carry common
     * subexpressions that are reused across consecutive components. */

    /* 0 */
    bufB = _mm512_sub_ps(bufA[ 0], bufA[ 8]);
    bufC = _mm512_sub_ps(bufA[ 2], bufA[10]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 0*ISTRIDE + coter], bufD);

    /* 1 */
    bufB = _mm512_sub_ps(bufA[ 1], bufA[ 9]);
    bufD = _mm512_add_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 1*ISTRIDE + coter], bufD);

    /* 2 (bufC still holds bufA[2] - bufA[10]; no need to recompute) */
    bufD = _mm512_sub_ps(bufC, bufB);
    _mm512_storeu_ps(&dataDst[ 2*ISTRIDE + coter], bufD);

    /* 3 */
    bufC = _mm512_sub_ps(bufA[ 3], bufA[11]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 3*ISTRIDE + coter], bufD);

    /* 4 */
    bufB = _mm512_add_ps(bufA[ 4], bufA[ 8]);
    bufC = _mm512_add_ps(bufA[ 6], bufA[10]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 4*ISTRIDE + coter], bufD);

    /* 5 */
    bufB = _mm512_add_ps(bufA[ 5], bufA[ 9]);
    bufD = _mm512_add_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 5*ISTRIDE + coter], bufD);

    /* 6 (bufB still holds bufA[5] + bufA[9]; no need to recompute) */
    bufD = _mm512_sub_ps(bufC, bufB);
    _mm512_storeu_ps(&dataDst[ 6*ISTRIDE + coter], bufD);

    /* 7 */
    bufC = _mm512_add_ps(bufA[ 7], bufA[11]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 7*ISTRIDE + coter], bufD);

    /* 8 */
    bufB = _mm512_sub_ps(bufA[ 8], bufA[ 4]);
    bufC = _mm512_sub_ps(bufA[10], bufA[ 6]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 8*ISTRIDE + coter], bufD);

    /* 9 */
    bufB = _mm512_sub_ps(bufA[ 9], bufA[ 5]);
    bufD = _mm512_add_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[ 9*ISTRIDE + coter], bufD);

    /* 10 */
    bufD = _mm512_sub_ps(bufC, bufB);
    _mm512_storeu_ps(&dataDst[10*ISTRIDE + coter], bufD);

    /* 11 */
    bufC = _mm512_sub_ps(bufA[11], bufA[ 7]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[11*ISTRIDE + coter], bufD);

    /* 12 */
    bufB = _mm512_sub_ps(bufA[ 4], bufA[12]);
    bufC = _mm512_sub_ps(bufA[ 6], bufA[14]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[12*ISTRIDE + coter], bufD);

    /* 13 */
    bufB = _mm512_sub_ps(bufA[ 5], bufA[13]);
    bufD = _mm512_add_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[13*ISTRIDE + coter], bufD);

    /* 14 */
    bufD = _mm512_sub_ps(bufC, bufB);
    _mm512_storeu_ps(&dataDst[14*ISTRIDE + coter], bufD);

    /* 15 */
    bufC = _mm512_sub_ps(bufA[ 7], bufA[15]);
    bufD = _mm512_sub_ps(bufB, bufC);
    _mm512_storeu_ps(&dataDst[15*ISTRIDE + coter], bufD);

    *counter += 16;
}

/* Input data transform part with 1 tile.
*/ 153 | static inline void get_tiles_2x3_1t(int x, int y, int nrows, const float *dataSrc, 154 | float *dataDst, int *counter) 155 | { 156 | int coter = *counter; 157 | float tmp[16] __attribute__((aligned(64))); 158 | 159 | tmp[ 0] = dataSrc[(x+0)*nrows + y+0]; 160 | tmp[ 1] = dataSrc[(x+0)*nrows + y+1]; 161 | tmp[ 2] = dataSrc[(x+0)*nrows + y+2]; 162 | tmp[ 3] = dataSrc[(x+0)*nrows + y+3]; 163 | 164 | tmp[ 4] = dataSrc[(x+1)*nrows + y+0]; 165 | tmp[ 5] = dataSrc[(x+1)*nrows + y+1]; 166 | tmp[ 6] = dataSrc[(x+1)*nrows + y+2]; 167 | tmp[ 7] = dataSrc[(x+1)*nrows + y+3]; 168 | 169 | tmp[ 8] = dataSrc[(x+2)*nrows + y+0]; 170 | tmp[ 9] = dataSrc[(x+2)*nrows + y+1]; 171 | tmp[10] = dataSrc[(x+2)*nrows + y+2]; 172 | tmp[11] = dataSrc[(x+2)*nrows + y+3]; 173 | 174 | tmp[12] = dataSrc[(x+3)*nrows + y+0]; 175 | tmp[13] = dataSrc[(x+3)*nrows + y+1]; 176 | tmp[14] = dataSrc[(x+3)*nrows + y+2]; 177 | tmp[15] = dataSrc[(x+3)*nrows + y+3]; 178 | 179 | // The tranformation manually simplified 180 | dataDst[coter+ 0*ISTRIDE] =(tmp[0] - tmp[8 ]) - (tmp[2 ]- tmp[10]); 181 | dataDst[coter+ 1*ISTRIDE] =(tmp[1] - tmp[9 ]) + (tmp[2 ]- tmp[10]); 182 | dataDst[coter+ 2*ISTRIDE] =(tmp[2] - tmp[10]) - (tmp[1 ]- tmp[9 ]); 183 | dataDst[coter+ 3*ISTRIDE] =(tmp[1] - tmp[9 ]) - (tmp[3 ]- tmp[11]); 184 | dataDst[coter+ 4*ISTRIDE] =(tmp[4] + tmp[8 ]) - (tmp[6 ]+ tmp[10]); 185 | dataDst[coter+ 5*ISTRIDE] =(tmp[5] + tmp[9 ]) + (tmp[6 ]+ tmp[10]); 186 | dataDst[coter+ 6*ISTRIDE] =(tmp[6] + tmp[10]) - (tmp[5 ]+ tmp[9 ]); 187 | dataDst[coter+ 7*ISTRIDE] =(tmp[5] + tmp[9 ]) - (tmp[7 ]+ tmp[11]); 188 | dataDst[coter+ 8*ISTRIDE] =(tmp[8] - tmp[4 ]) - (tmp[10]- tmp[6 ]); 189 | dataDst[coter+ 9*ISTRIDE] =(tmp[9] - tmp[5 ]) + (tmp[10]- tmp[6 ]); 190 | dataDst[coter+10*ISTRIDE] =(tmp[10]- tmp[6 ]) - (tmp[9 ]- tmp[5 ]); 191 | dataDst[coter+11*ISTRIDE] =(tmp[9] - tmp[5 ]) - (tmp[11]- tmp[7 ]); 192 | dataDst[coter+12*ISTRIDE] =(tmp[4] - tmp[12]) - (tmp[6 ]- tmp[14]); 193 | dataDst[coter+13*ISTRIDE] =(tmp[5] - 
tmp[13]) + (tmp[6 ]- tmp[14]); 194 | dataDst[coter+14*ISTRIDE] =(tmp[6] - tmp[14]) - (tmp[5 ]- tmp[13]); 195 | dataDst[coter+15*ISTRIDE] =(tmp[5] - tmp[13]) - (tmp[7 ]- tmp[15]); 196 | 197 | (*counter)++; 198 | 199 | } 200 | 201 | // INTERNAL FUNCTION : FORM MATRIX A from input data, also includes transformation F(2,3) 202 | static void get_tiles_2x3(const float* restrict image, const int ldi, const int irows, const int icols, 203 | const int sizeI, const int C, float* restrict otile, const int N, const int ntiles, const int M) 204 | { 205 | 206 | int t, u; 207 | 208 | #pragma omp parallel for 209 | for(t = 0; t < N*C; t++){ 210 | int i, j; 211 | 212 | const int t1 = t/(C*M); 213 | const int t2 = (t%(C*M))/M; 214 | const int t3 = t%M; 215 | 216 | //const float* data = image+t*sizeI; 217 | const float *data = image + (t1*M*C + t3*C + t2)*sizeI; 218 | int tile_count = t*ntiles; 219 | const int num16t = (icols-2)/32*32; 220 | 221 | // work on one image plane at a time, irrespective of the order 222 | for(i = 0; i < irows-2; i += 2){ 223 | /* 16 tiles together */ 224 | for(j = 0; j < num16t; j += 32){ 225 | get_tiles_2x3_16t(i, j, ldi, data, otile, &tile_count); 226 | } 227 | 228 | /* 1 tile together */ 229 | #pragma simd 230 | for(j = num16t; j < (icols-2); j += 2){ 231 | get_tiles_2x3_1t(i, j, ldi, data, otile, &tile_count); 232 | } 233 | } 234 | } 235 | } 236 | 237 | // INTERNAL FUNCTION: FORM MATRIX B, also includes filter transform F(2,3) 238 | static void filter_transform_2x3(const float* restrict filter, const int C, const int K, float* restrict out) 239 | { 240 | 241 | int m, n, x; 242 | const float *F; 243 | 244 | #pragma omp parallel for collapse(2) private(m, n, x, F) 245 | #pragma simd 246 | for(m = 0; m < K; m++){ 247 | for(n = 0; n < C; n++){ 248 | float c1[16] __attribute__((aligned(64))); 249 | F = filter+n*3*3 + m*3*3*C; 250 | 251 | // work on in 3x3 plane at a time 252 | // The tranformation manually simplified 253 | c1[0] = F[0]; 254 | c1[1] = 
(F[0]+F[2]+F[1])*0.5f; 255 | c1[2] = (F[0]+F[2]-F[1])*0.5f; 256 | c1[3] = F[2]; 257 | c1[4] = (F[0]+F[6]+F[3])*0.5f; 258 | c1[5] = ((F[0]+F[6]+F[3])+(F[2]+F[8]+F[5])+(F[1]+F[7]+F[4]))*0.25f; 259 | c1[6] = ((F[0]+F[6]+F[3])+(F[2]+F[8]+F[5])-(F[1]+F[7]+F[4]))*0.25f; 260 | c1[7] = (F[2]+F[8]+F[5])*0.5f; 261 | c1[8] = (F[0]+F[6]-F[3])*0.5f; 262 | c1[9] = ((F[0]+F[6]-F[3])+(F[2]+F[8]-F[5])+(F[1]+F[7]-F[4]))*0.25f; 263 | c1[10] = ((F[0]+F[6]-F[3])+(F[2]+F[8]-F[5])-(F[1]+F[7]-F[4]))*0.25f; 264 | c1[11] = (F[2]+F[8]-F[5])*0.5f; 265 | c1[12] = F[6]; 266 | c1[13] = (F[6]+F[8]+F[7])*0.5f; 267 | c1[14] = (F[6]+F[8]-F[7])*0.5f; 268 | c1[15] = F[8]; 269 | 270 | // scatter 271 | #pragma unroll(16) 272 | for(x = 0; x < 16; x++){ 273 | out[x*FSTRIDE+m*C+n] = c1[x]; 274 | } 275 | } 276 | } 277 | } 278 | 279 | // INTERNAL FUNCTION F(2,3) 280 | // GEMM specific to Ist layer of VGG with (M, N, K) = (12544, 64, 3) 281 | // MKL performs bad 282 | static void gemm_ker(int m, int n, int k, const float* a, const int lda, const float* b, const int ldb, float* c, const int ldc) 283 | { 284 | 285 | const int BLK = 16; 286 | int x, xx, y, z, i; 287 | 288 | for(z = 0; z < n; z++){ 289 | for(x = 0; x < m; x += BLK){ 290 | float p[BLK] __attribute__((aligned(64))); 291 | p[0:BLK] = 0.0f; 292 | #pragma unroll(3) 293 | for(y = 0; y < 3; y++){ 294 | #pragma vector aligned 295 | for(i = 0; i < BLK; i++){ 296 | p[i] += a[x+i+y*lda]*b[y+z*ldb]; 297 | } 298 | } 299 | c[x+z*ldc:BLK] = p[0:BLK]; 300 | } 301 | } 302 | 303 | } 304 | 305 | 306 | // INTERNAL FUNCTION F(2,3) 307 | // C = A*B with beta = 0.0f and alpha = 1.0f 308 | // Number of gemm calls is 16*BATCH 309 | static void batched_gemm_2x3(const float* image, const int irows, const int icols, 310 | const float* filter, const int frows, const int fcols, float* restrict out, const int batch) 311 | { 312 | 313 | int t, i; 314 | const char trans ='n'; 315 | const float alpha = 1.0; 316 | const float beta = 0.0; 317 | const int ldi = irows; 318 | const 
int ldf = frows; 319 | const int ldo = irows; 320 | 321 | #pragma omp parallel for num_threads(68) collapse(2) private(t, i) 322 | for(i = 0; i < 16; i++){ 323 | for(t = 0; t < batch; t++){ 324 | const float* im = image+i*ISTRIDE+t*irows*icols; 325 | const float* fi = filter+i*FSTRIDE; 326 | float* ot = out+i*OSTRIDE+t*irows*fcols; 327 | sgemm(&trans, &trans, &irows, &fcols, &icols, &alpha, im, &ldi, fi, &ldf, &beta, ot, &ldo); 328 | } 329 | } 330 | 331 | } 332 | 333 | /* Output data transform part with 16 tiles. */ 334 | static void out_transform_2x3_16t(int x, int y, int nrows, const float *dataSrc, 335 | float *dataDst, int *counter) 336 | { 337 | int coter = *counter; 338 | float c1[256] __attribute__((aligned(64))); 339 | __m512 bufA[16], bufB, bufC, bufD, bufE; 340 | __m512i idx0 = _mm512_set_epi32(23, 7,22, 6,21, 5,20, 4,19, 3,18, 2,17, 1,16, 0); 341 | __m512i idx1 = _mm512_set_epi32(31,15,30,14,29,13,28,12,27,11,26,10,25, 9,24, 8); 342 | 343 | // gather the 16 elements form C to form a tile 344 | c1[ 0:16] = dataSrc[coter + 0*OSTRIDE:16]; 345 | c1[ 16:16] = dataSrc[coter + 1*OSTRIDE:16]; 346 | c1[ 32:16] = dataSrc[coter + 2*OSTRIDE:16]; 347 | c1[ 48:16] = dataSrc[coter + 3*OSTRIDE:16]; 348 | c1[ 64:16] = dataSrc[coter + 4*OSTRIDE:16]; 349 | c1[ 80:16] = dataSrc[coter + 5*OSTRIDE:16]; 350 | c1[ 96:16] = dataSrc[coter + 6*OSTRIDE:16]; 351 | c1[112:16] = dataSrc[coter + 7*OSTRIDE:16]; 352 | c1[128:16] = dataSrc[coter + 8*OSTRIDE:16]; 353 | c1[144:16] = dataSrc[coter + 9*OSTRIDE:16]; 354 | c1[160:16] = dataSrc[coter + 10*OSTRIDE:16]; 355 | c1[176:16] = dataSrc[coter + 11*OSTRIDE:16]; 356 | c1[192:16] = dataSrc[coter + 12*OSTRIDE:16]; 357 | c1[208:16] = dataSrc[coter + 13*OSTRIDE:16]; 358 | c1[224:16] = dataSrc[coter + 14*OSTRIDE:16]; 359 | c1[240:16] = dataSrc[coter + 15*OSTRIDE:16]; 360 | 361 | /* Register store the source data */ 362 | bufA[ 0] = _mm512_load_ps(c1+ 0); 363 | bufA[ 1] = _mm512_load_ps(c1+ 16); 364 | bufA[ 2] = _mm512_load_ps(c1+ 32); 365 | 
bufA[ 3] = _mm512_load_ps(c1+ 48); 366 | bufA[ 4] = _mm512_load_ps(c1+ 64); 367 | bufA[ 5] = _mm512_load_ps(c1+ 80); 368 | bufA[ 6] = _mm512_load_ps(c1+ 96); 369 | bufA[ 7] = _mm512_load_ps(c1+112); 370 | bufA[ 8] = _mm512_load_ps(c1+128); 371 | bufA[ 9] = _mm512_load_ps(c1+144); 372 | bufA[10] = _mm512_load_ps(c1+160); 373 | bufA[11] = _mm512_load_ps(c1+176); 374 | bufA[12] = _mm512_load_ps(c1+192); 375 | bufA[13] = _mm512_load_ps(c1+208); 376 | bufA[14] = _mm512_load_ps(c1+224); 377 | bufA[15] = _mm512_load_ps(c1+240); 378 | 379 | 380 | /* Compute the media result */ 381 | bufB = _mm512_add_ps(bufA[ 0], bufA[ 1]); 382 | bufB = _mm512_add_ps(bufB, bufA[ 2]); 383 | bufB = _mm512_add_ps(bufB, bufA[ 4]); 384 | bufB = _mm512_add_ps(bufB, bufA[ 5]); 385 | bufB = _mm512_add_ps(bufB, bufA[ 6]); 386 | bufB = _mm512_add_ps(bufB, bufA[ 8]); 387 | bufB = _mm512_add_ps(bufB, bufA[ 9]); 388 | bufB = _mm512_add_ps(bufB, bufA[10]); 389 | 390 | bufC = _mm512_sub_ps(bufA[ 1], bufA[ 2]); 391 | bufC = _mm512_sub_ps(bufC, bufA[ 3]); 392 | bufC = _mm512_add_ps(bufC, bufA[ 5]); 393 | bufC = _mm512_sub_ps(bufC, bufA[ 6]); 394 | bufC = _mm512_sub_ps(bufC, bufA[ 7]); 395 | bufC = _mm512_add_ps(bufC, bufA[ 9]); 396 | bufC = _mm512_sub_ps(bufC, bufA[10]); 397 | bufC = _mm512_sub_ps(bufC, bufA[11]); 398 | 399 | bufD = _mm512_add_ps(bufA[ 4], bufA[ 5]); 400 | bufD = _mm512_add_ps(bufD, bufA[ 6]); 401 | bufD = _mm512_sub_ps(bufD, bufA[ 8]); 402 | bufD = _mm512_sub_ps(bufD, bufA[ 9]); 403 | bufD = _mm512_sub_ps(bufD, bufA[10]); 404 | bufD = _mm512_sub_ps(bufD, bufA[12]); 405 | bufD = _mm512_sub_ps(bufD, bufA[13]); 406 | bufD = _mm512_sub_ps(bufD, bufA[14]); 407 | 408 | bufE = _mm512_sub_ps(bufA[ 5], bufA[ 6]); 409 | bufE = _mm512_sub_ps(bufE, bufA[ 7]); 410 | bufE = _mm512_sub_ps(bufE, bufA[ 9]); 411 | bufE = _mm512_add_ps(bufE, bufA[10]); 412 | bufE = _mm512_add_ps(bufE, bufA[11]); 413 | bufE = _mm512_sub_ps(bufE, bufA[13]); 414 | bufE = _mm512_add_ps(bufE, bufA[14]); 415 | bufE = 
_mm512_add_ps(bufE, bufA[15]); 416 | 417 | /* Store the finally output data */ 418 | bufA[0] = _mm512_permutex2var_ps(bufB, idx0, bufC); 419 | bufA[1] = _mm512_permutex2var_ps(bufB, idx1, bufC); 420 | bufA[2] = _mm512_permutex2var_ps(bufD, idx0, bufE); 421 | bufA[3] = _mm512_permutex2var_ps(bufD, idx1, bufE); 422 | 423 | _mm512_store_ps(&dataDst[(x+0)*nrows + (y+ 0)], bufA[0]); 424 | _mm512_store_ps(&dataDst[(x+0)*nrows + (y+16)], bufA[1]); 425 | _mm512_store_ps(&dataDst[(x+1)*nrows + (y+ 0)], bufA[2]); 426 | _mm512_store_ps(&dataDst[(x+1)*nrows + (y+16)], bufA[3]); 427 | 428 | (*counter) += 16; 429 | } 430 | 431 | /* Output data transform part with 1 tile. */ 432 | static inline void out_transform_2x3_1t(int x, int y, int nrows, const float *dataSrc, 433 | float *dataDst, int *counter) 434 | { 435 | int coter = *counter; 436 | float c1[16] __attribute__((aligned(64))); 437 | float temp[8] __attribute__((aligned(64))); 438 | 439 | // gather the 16 elements form C to form a tile 440 | c1[0 ] = dataSrc[coter+0 *OSTRIDE]; 441 | c1[1 ] = dataSrc[coter+1 *OSTRIDE]; 442 | c1[2 ] = dataSrc[coter+2 *OSTRIDE]; 443 | c1[3 ] = dataSrc[coter+3 *OSTRIDE]; 444 | c1[4 ] = dataSrc[coter+4 *OSTRIDE]; 445 | c1[5 ] = dataSrc[coter+5 *OSTRIDE]; 446 | c1[6 ] = dataSrc[coter+6 *OSTRIDE]; 447 | c1[7 ] = dataSrc[coter+7 *OSTRIDE]; 448 | c1[8 ] = dataSrc[coter+8 *OSTRIDE]; 449 | c1[9 ] = dataSrc[coter+9 *OSTRIDE]; 450 | c1[10] = dataSrc[coter+10*OSTRIDE]; 451 | c1[11] = dataSrc[coter+11*OSTRIDE]; 452 | c1[12] = dataSrc[coter+12*OSTRIDE]; 453 | c1[13] = dataSrc[coter+13*OSTRIDE]; 454 | c1[14] = dataSrc[coter+14*OSTRIDE]; 455 | c1[15] = dataSrc[coter+15*OSTRIDE]; 456 | 457 | 458 | // The tranformation manually simplified 459 | temp[0] = c1[0]+c1[1]+ c1[2]; 460 | temp[1] = c1[1]-c1[2]- c1[3]; 461 | temp[2] = c1[4]+c1[5]+ c1[6]; 462 | temp[3] = c1[5]-c1[6]- c1[7]; 463 | temp[4] = c1[8]+c1[9]+ c1[10]; 464 | temp[5] = c1[9]-c1[10]- c1[11]; 465 | temp[6] = c1[12]+c1[13]+ c1[14]; 466 | temp[7] = 
c1[13]-c1[14]- c1[15]; 467 | 468 | dataDst[(x+0)*nrows+y] = temp[0]+temp[2]+temp[4]; 469 | dataDst[(x+0)*nrows+y+1] = temp[1]+temp[3]+temp[5]; 470 | dataDst[(x+1)*nrows+y] = temp[2]-temp[4]-temp[6]; 471 | dataDst[(x+1)*nrows+y+1] = temp[3]-temp[5]-temp[7]; 472 | 473 | (*counter)++; 474 | } 475 | 476 | // INTERNAL FUNCTION F(2,3) 477 | // Transform matrix multiplication output 478 | static void out_transform_2x3(const float* restrict d, const int K, const int ntiles, 479 | float* restrict out, const int ldo,const int oH, const int oW, const int N, const int M) 480 | { 481 | 482 | int t; 483 | int sizeO = oH*oW; 484 | 485 | #pragma omp parallel for 486 | for(t = 0; t < N*K; t++){ 487 | int i, j; 488 | 489 | const int t1 = t/(K*M); 490 | const int t2 = (t%(K*M))/M; 491 | const int t3 = t%M; 492 | 493 | float *data = out + (t1*M*K + t3*K + t2)*sizeO; 494 | int tile_offset = t*ntiles; 495 | const int num16t = oW/32*32; 496 | 497 | // work on one output plane at a time, irrespective of the order 498 | for(i = 0; i < oH; i += 2){ 499 | /* 16 tiles together */ 500 | for(j = 0; j < num16t; j += 32){ 501 | out_transform_2x3_16t(i, j, ldo, d, data, &tile_offset); 502 | } 503 | 504 | /* 1 tile together */ 505 | #pragma simd 506 | for(j = num16t; j < oW; j += 2){ 507 | out_transform_2x3_1t(i, j, ldo, d, data, &tile_offset); 508 | } 509 | } 510 | } 511 | } 512 | 513 | // User API for winograd F(2,3) 514 | void winconv_2x3(const int bblock, const int M, float* restrict image, const int irows, const int icols, 515 | const int C, float* restrict filter, const int K, const int batch, 516 | float* restrict out) 517 | { 518 | 519 | const int outHeight = irows-2; 520 | const int outWidth = icols-2; 521 | const int sizeI = irows*icols; 522 | const int tiles = (outHeight)*0.5*(outWidth)*0.5; 523 | 524 | float *b_image; 525 | float *b_out; 526 | const int b_batchSize = 32; 527 | 528 | if(batch%b_batchSize != 0){ 529 | printf("Error: Batch can't be divided by %d!\n", b_batchSize); 530 | 
exit(0); 531 | } 532 | 533 | filter_transform_2x3(filter, C, K, t_filter); 534 | switch(bblock){ 535 | case BATCH_TOGETHER: 536 | get_tiles_2x3(image, icols, irows, icols, sizeI, C, t_image, batch, tiles, M); 537 | batched_gemm_2x3(t_image, M*tiles, C, t_filter, C, K, c_out, batch/M); 538 | out_transform_2x3(c_out, K, tiles, out, outWidth, outHeight, outWidth, batch, M); 539 | break; 540 | case BATCH_BLOCK: 541 | for(int i = 0; i < batch; i += b_batchSize){ 542 | b_image = image + i*C*irows*icols; 543 | b_out = out + i*K*outHeight*outWidth; 544 | get_tiles_2x3(b_image, icols, irows, icols, sizeI, C, t_image, b_batchSize, tiles, M); 545 | batched_gemm_2x3(t_image, M*tiles, C, t_filter, C, K, c_out, b_batchSize/M); 546 | out_transform_2x3(c_out, K, tiles, b_out, outWidth, outHeight, outWidth, b_batchSize, M); 547 | } 548 | break; 549 | default: 550 | printf("Error: You need to decide wether to divide block for batch!\n"); 551 | break; 552 | } 553 | } 554 | -------------------------------------------------------------------------------- /winconv_4x3/winconv_4x3.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | using namespace std; 15 | 16 | static void get_tiles_4x3_16t(int x, int y, int nrows, const float *dataSrc, 17 | float *dataDst, int *counter) 18 | { 19 | const int coter = *counter; 20 | //cout << "get tiles " << coter << endl; 21 | __m512 bufA[36]; 22 | __m512 bufB, bufC, bufD, bufE, bufF, bufG, bufH, bufI; 23 | __m512i idx0 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 24 | 14, 12, 10, 8, 6, 4, 2, 0); 25 | __m512i idx1 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 26 | 15, 13, 11, 9, 7, 5, 3, 1); 27 | 28 | /* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 29 | 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 30 | 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 31 | 48 49 50 
51 52 53 54 55 56 57 58 59 60 61 62 63 */ 32 | 33 | /* 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 34 | 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 35 | 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 36 | 33 35 37 39 41 43 45 47 49 51 53 55 57 59 61 63 */ 37 | 38 | /* 0 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 1, 3 permute0 39 | 1 5 9 13 17 21 25 29 33 37 41 45 49 53 57 61 2, 4 permute0 40 | 2 6 10 14 18 22 26 30 34 38 42 46 50 54 58 62 1, 3 permute1 41 | 3 7 11 15 19 23 27 31 35 39 43 47 51 55 59 63 2, 4 permute1 */ 42 | 43 | /* 0, 1, 2, 3, 4, 5 */ 44 | bufB = _mm512_load_ps(dataSrc + (x+0) * nrows + y); 45 | bufC = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 16); 46 | bufD = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 32); 47 | bufE = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 48); 48 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 49 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 50 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 51 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 52 | bufA[0] = _mm512_permutex2var_ps(bufF, idx0, bufH); 53 | bufA[1] = _mm512_permutex2var_ps(bufG, idx0, bufI); 54 | bufA[2] = _mm512_permutex2var_ps(bufF, idx1, bufH); 55 | bufA[3] = _mm512_permutex2var_ps(bufG, idx1, bufI); 56 | 57 | bufB = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 2); 58 | bufC = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 18); 59 | bufD = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 34); 60 | bufE = _mm512_load_ps(dataSrc + (x+0) * nrows + y + 50); 61 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 62 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 63 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 64 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 65 | bufA[4] = _mm512_permutex2var_ps(bufF, idx1, bufH); 66 | bufA[5] = _mm512_permutex2var_ps(bufG, idx1, bufI); 67 | /* 6, 7, 8, 9, 10, 11 */ 68 | bufB = _mm512_load_ps(dataSrc + (x+1) * nrows + y); 69 | bufC = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 16); 70 | bufD = 
_mm512_load_ps(dataSrc + (x+1) * nrows + y + 32); 71 | bufE = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 48); 72 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 73 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 74 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 75 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 76 | bufA[6] = _mm512_permutex2var_ps(bufF, idx0, bufH); 77 | bufA[7] = _mm512_permutex2var_ps(bufG, idx0, bufI); 78 | bufA[8] = _mm512_permutex2var_ps(bufF, idx1, bufH); 79 | bufA[9] = _mm512_permutex2var_ps(bufG, idx1, bufI); 80 | 81 | bufB = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 2); 82 | bufC = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 18); 83 | bufD = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 34); 84 | bufE = _mm512_load_ps(dataSrc + (x+1) * nrows + y + 50); 85 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 86 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 87 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 88 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 89 | bufA[10] = _mm512_permutex2var_ps(bufF, idx1, bufH); 90 | bufA[11] = _mm512_permutex2var_ps(bufG, idx1, bufI); 91 | /* 12, 13, 14, 15, 16, 17 */ 92 | bufB = _mm512_load_ps(dataSrc + (x+2) * nrows + y); 93 | bufC = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 16); 94 | bufD = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 32); 95 | bufE = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 48); 96 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 97 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 98 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 99 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 100 | bufA[12] = _mm512_permutex2var_ps(bufF, idx0, bufH); 101 | bufA[13] = _mm512_permutex2var_ps(bufG, idx0, bufI); 102 | bufA[14] = _mm512_permutex2var_ps(bufF, idx1, bufH); 103 | bufA[15] = _mm512_permutex2var_ps(bufG, idx1, bufI); 104 | 105 | bufB = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 2); 106 | bufC = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 18); 
107 | bufD = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 34); 108 | bufE = _mm512_load_ps(dataSrc + (x+2) * nrows + y + 50); 109 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 110 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 111 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 112 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 113 | bufA[16] = _mm512_permutex2var_ps(bufF, idx1, bufH); 114 | bufA[17] = _mm512_permutex2var_ps(bufG, idx1, bufI); 115 | /* 18, 19, 20, 21, 22, 23 */ 116 | bufB = _mm512_load_ps(dataSrc + (x+3) * nrows + y); 117 | bufC = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 16); 118 | bufD = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 32); 119 | bufE = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 48); 120 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 121 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 122 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 123 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 124 | bufA[18] = _mm512_permutex2var_ps(bufF, idx0, bufH); 125 | bufA[19] = _mm512_permutex2var_ps(bufG, idx0, bufI); 126 | bufA[20] = _mm512_permutex2var_ps(bufF, idx1, bufH); 127 | bufA[21] = _mm512_permutex2var_ps(bufG, idx1, bufI); 128 | 129 | bufB = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 2); 130 | bufC = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 18); 131 | bufD = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 34); 132 | bufE = _mm512_load_ps(dataSrc + (x+3) * nrows + y + 50); 133 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 134 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 135 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 136 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 137 | bufA[22] = _mm512_permutex2var_ps(bufF, idx1, bufH); 138 | bufA[23] = _mm512_permutex2var_ps(bufG, idx1, bufI); 139 | /* 24, 25, 26, 27, 28, 29 */ 140 | bufB = _mm512_load_ps(dataSrc + (x+4) * nrows + y); 141 | bufC = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 16); 142 | bufD = _mm512_load_ps(dataSrc + (x+4) * 
nrows + y + 32); 143 | bufE = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 48); 144 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 145 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 146 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 147 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 148 | bufA[24] = _mm512_permutex2var_ps(bufF, idx0, bufH); 149 | bufA[25] = _mm512_permutex2var_ps(bufG, idx0, bufI); 150 | bufA[26] = _mm512_permutex2var_ps(bufF, idx1, bufH); 151 | bufA[27] = _mm512_permutex2var_ps(bufG, idx1, bufI); 152 | 153 | bufB = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 2); 154 | bufC = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 18); 155 | bufD = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 34); 156 | bufE = _mm512_load_ps(dataSrc + (x+4) * nrows + y + 50); 157 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 158 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 159 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 160 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 161 | bufA[28] = _mm512_permutex2var_ps(bufF, idx1, bufH); 162 | bufA[29] = _mm512_permutex2var_ps(bufG, idx1, bufI); 163 | /* 30, 31, 32, 33, 34, 35 */ 164 | bufB = _mm512_load_ps(dataSrc + (x+5) * nrows + y); 165 | bufC = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 16); 166 | bufD = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 32); 167 | bufE = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 48); 168 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 169 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 170 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 171 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 172 | bufA[30] = _mm512_permutex2var_ps(bufF, idx0, bufH); 173 | bufA[31] = _mm512_permutex2var_ps(bufG, idx0, bufI); 174 | bufA[32] = _mm512_permutex2var_ps(bufF, idx1, bufH); 175 | bufA[33] = _mm512_permutex2var_ps(bufG, idx1, bufI); 176 | 177 | bufB = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 2); 178 | bufC = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 18); 
179 | bufD = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 34); 180 | bufE = _mm512_load_ps(dataSrc + (x+5) * nrows + y + 50); 181 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufC); 182 | bufG = _mm512_permutex2var_ps(bufB, idx1, bufC); 183 | bufH = _mm512_permutex2var_ps(bufD, idx0, bufE); 184 | bufI = _mm512_permutex2var_ps(bufD, idx1, bufE); 185 | bufA[34] = _mm512_permutex2var_ps(bufF, idx1, bufH); 186 | bufA[35] = _mm512_permutex2var_ps(bufG, idx1, bufI); 187 | 188 | 189 | __m512 bufTemp[36]; 190 | __m512 m0 = _mm512_setzero_ps(); 191 | __m512 m1 = _mm512_set1_ps(1.0f); 192 | __m512 m2 = _mm512_set1_ps(2.0f); 193 | __m512 m4 = _mm512_set1_ps(4.0f); 194 | __m512 m5 = _mm512_set1_ps(5.0f); 195 | 196 | bufTemp[0] = _mm512_mul_ps(m4, bufA[0]); 197 | bufTemp[1] = _mm512_mul_ps(m4, bufA[1]); 198 | bufTemp[2] = _mm512_mul_ps(m4, bufA[2]); 199 | bufTemp[3] = _mm512_mul_ps(m4, bufA[3]); 200 | bufTemp[4] = _mm512_mul_ps(m4, bufA[4]); 201 | bufTemp[5] = _mm512_mul_ps(m4, bufA[5]); 202 | bufTemp[0] = _mm512_fnmadd_ps(m5, bufA[12], bufTemp[0]); 203 | bufTemp[1] = _mm512_fnmadd_ps(m5, bufA[13], bufTemp[1]); 204 | bufTemp[2] = _mm512_fnmadd_ps(m5, bufA[14], bufTemp[2]); 205 | bufTemp[3] = _mm512_fnmadd_ps(m5, bufA[15], bufTemp[3]); 206 | bufTemp[4] = _mm512_fnmadd_ps(m5, bufA[16], bufTemp[4]); 207 | bufTemp[5] = _mm512_fnmadd_ps(m5, bufA[17], bufTemp[5]); 208 | bufTemp[0] = _mm512_add_ps(bufA[24], bufTemp[0]); 209 | bufTemp[1] = _mm512_add_ps(bufA[25], bufTemp[1]); 210 | bufTemp[2] = _mm512_add_ps(bufA[26], bufTemp[2]); 211 | bufTemp[3] = _mm512_add_ps(bufA[27], bufTemp[3]); 212 | bufTemp[4] = _mm512_add_ps(bufA[28], bufTemp[4]); 213 | bufTemp[5] = _mm512_add_ps(bufA[29], bufTemp[5]); 214 | 215 | bufTemp[6] = _mm512_fnmadd_ps(m4, bufA[6], m0); 216 | bufTemp[7] = _mm512_fnmadd_ps(m4, bufA[7], m0); 217 | bufTemp[8] = _mm512_fnmadd_ps(m4, bufA[8], m0); 218 | bufTemp[9] = _mm512_fnmadd_ps(m4, bufA[9], m0); 219 | bufTemp[10] = _mm512_fnmadd_ps(m4, bufA[10], m0); 220 | bufTemp[11] 
= _mm512_fnmadd_ps(m4, bufA[11], m0); 221 | bufTemp[6] = _mm512_fnmadd_ps(m4, bufA[12], bufTemp[6]); 222 | bufTemp[7] = _mm512_fnmadd_ps(m4, bufA[13], bufTemp[7]); 223 | bufTemp[8] = _mm512_fnmadd_ps(m4, bufA[14], bufTemp[8]); 224 | bufTemp[9] = _mm512_fnmadd_ps(m4, bufA[15], bufTemp[9]); 225 | bufTemp[10] = _mm512_fnmadd_ps(m4, bufA[16], bufTemp[10]); 226 | bufTemp[11] = _mm512_fnmadd_ps(m4, bufA[17], bufTemp[11]); 227 | bufTemp[6] = _mm512_add_ps(bufA[18], bufTemp[6]); 228 | bufTemp[7] = _mm512_add_ps(bufA[19], bufTemp[7]); 229 | bufTemp[8] = _mm512_add_ps(bufA[20], bufTemp[8]); 230 | bufTemp[9] = _mm512_add_ps(bufA[21], bufTemp[9]); 231 | bufTemp[10] = _mm512_add_ps(bufA[22], bufTemp[10]); 232 | bufTemp[11] = _mm512_add_ps(bufA[23], bufTemp[11]); 233 | bufTemp[6] = _mm512_add_ps(bufA[24], bufTemp[6]); 234 | bufTemp[7] = _mm512_add_ps(bufA[25], bufTemp[7]); 235 | bufTemp[8] = _mm512_add_ps(bufA[26], bufTemp[8]); 236 | bufTemp[9] = _mm512_add_ps(bufA[27], bufTemp[9]); 237 | bufTemp[10] = _mm512_add_ps(bufA[28], bufTemp[10]); 238 | bufTemp[11] = _mm512_add_ps(bufA[29], bufTemp[11]); 239 | 240 | bufTemp[12] = _mm512_mul_ps(m4, bufA[6]); 241 | bufTemp[13] = _mm512_mul_ps(m4, bufA[7]); 242 | bufTemp[14] = _mm512_mul_ps(m4, bufA[8]); 243 | bufTemp[15] = _mm512_mul_ps(m4, bufA[9]); 244 | bufTemp[16] = _mm512_mul_ps(m4, bufA[10]); 245 | bufTemp[17] = _mm512_mul_ps(m4, bufA[11]); 246 | bufTemp[12] = _mm512_fnmadd_ps(m4, bufA[12], bufTemp[12]); 247 | bufTemp[13] = _mm512_fnmadd_ps(m4, bufA[13], bufTemp[13]); 248 | bufTemp[14] = _mm512_fnmadd_ps(m4, bufA[14], bufTemp[14]); 249 | bufTemp[15] = _mm512_fnmadd_ps(m4, bufA[15], bufTemp[15]); 250 | bufTemp[16] = _mm512_fnmadd_ps(m4, bufA[16], bufTemp[16]); 251 | bufTemp[17] = _mm512_fnmadd_ps(m4, bufA[17], bufTemp[17]); 252 | bufTemp[12] = _mm512_sub_ps(bufTemp[12], bufA[18]); 253 | bufTemp[13] = _mm512_sub_ps(bufTemp[13], bufA[19]); 254 | bufTemp[14] = _mm512_sub_ps(bufTemp[14], bufA[20]); 255 | bufTemp[15] = 
_mm512_sub_ps(bufTemp[15], bufA[21]); 256 | bufTemp[16] = _mm512_sub_ps(bufTemp[16], bufA[22]); 257 | bufTemp[17] = _mm512_sub_ps(bufTemp[17], bufA[23]); 258 | bufTemp[12] = _mm512_add_ps(bufTemp[12], bufA[24]); 259 | bufTemp[13] = _mm512_add_ps(bufTemp[13], bufA[25]); 260 | bufTemp[14] = _mm512_add_ps(bufTemp[14], bufA[26]); 261 | bufTemp[15] = _mm512_add_ps(bufTemp[15], bufA[27]); 262 | bufTemp[16] = _mm512_add_ps(bufTemp[16], bufA[28]); 263 | bufTemp[17] = _mm512_add_ps(bufTemp[17], bufA[29]); 264 | 265 | bufTemp[18] = _mm512_fnmadd_ps(m2, bufA[6], m0); 266 | bufTemp[19] = _mm512_fnmadd_ps(m2, bufA[7], m0); 267 | bufTemp[20] = _mm512_fnmadd_ps(m2, bufA[8], m0); 268 | bufTemp[21] = _mm512_fnmadd_ps(m2, bufA[9], m0); 269 | bufTemp[22] = _mm512_fnmadd_ps(m2, bufA[10], m0); 270 | bufTemp[23] = _mm512_fnmadd_ps(m2, bufA[11], m0); 271 | bufTemp[18] = _mm512_sub_ps(bufTemp[18], bufA[12]); 272 | bufTemp[19] = _mm512_sub_ps(bufTemp[19], bufA[13]); 273 | bufTemp[20] = _mm512_sub_ps(bufTemp[20], bufA[14]); 274 | bufTemp[21] = _mm512_sub_ps(bufTemp[21], bufA[15]); 275 | bufTemp[22] = _mm512_sub_ps(bufTemp[22], bufA[16]); 276 | bufTemp[23] = _mm512_sub_ps(bufTemp[23], bufA[17]); 277 | bufTemp[18] = _mm512_fmadd_ps(m2, bufA[18], bufTemp[18]); 278 | bufTemp[19] = _mm512_fmadd_ps(m2, bufA[19], bufTemp[19]); 279 | bufTemp[20] = _mm512_fmadd_ps(m2, bufA[20], bufTemp[20]); 280 | bufTemp[21] = _mm512_fmadd_ps(m2, bufA[21], bufTemp[21]); 281 | bufTemp[22] = _mm512_fmadd_ps(m2, bufA[22], bufTemp[22]); 282 | bufTemp[23] = _mm512_fmadd_ps(m2, bufA[23], bufTemp[23]); 283 | bufTemp[18] = _mm512_add_ps(bufTemp[18], bufA[24]); 284 | bufTemp[19] = _mm512_add_ps(bufTemp[19], bufA[25]); 285 | bufTemp[20] = _mm512_add_ps(bufTemp[20], bufA[26]); 286 | bufTemp[21] = _mm512_add_ps(bufTemp[21], bufA[27]); 287 | bufTemp[22] = _mm512_add_ps(bufTemp[22], bufA[28]); 288 | bufTemp[23] = _mm512_add_ps(bufTemp[23], bufA[29]); 289 | 290 | bufTemp[24] = _mm512_mul_ps(m2, bufA[6]); 291 | bufTemp[25] = 
_mm512_mul_ps(m2, bufA[7]); 292 | bufTemp[26] = _mm512_mul_ps(m2, bufA[8]); 293 | bufTemp[27] = _mm512_mul_ps(m2, bufA[9]); 294 | bufTemp[28] = _mm512_mul_ps(m2, bufA[10]); 295 | bufTemp[29] = _mm512_mul_ps(m2, bufA[11]); 296 | bufTemp[24] = _mm512_sub_ps(bufTemp[24], bufA[12]); 297 | bufTemp[25] = _mm512_sub_ps(bufTemp[25], bufA[13]); 298 | bufTemp[26] = _mm512_sub_ps(bufTemp[26], bufA[14]); 299 | bufTemp[27] = _mm512_sub_ps(bufTemp[27], bufA[15]); 300 | bufTemp[28] = _mm512_sub_ps(bufTemp[28], bufA[16]); 301 | bufTemp[29] = _mm512_sub_ps(bufTemp[29], bufA[17]); 302 | bufTemp[24] = _mm512_fnmadd_ps(m2, bufA[18], bufTemp[24]); 303 | bufTemp[25] = _mm512_fnmadd_ps(m2, bufA[19], bufTemp[25]); 304 | bufTemp[26] = _mm512_fnmadd_ps(m2, bufA[20], bufTemp[26]); 305 | bufTemp[27] = _mm512_fnmadd_ps(m2, bufA[21], bufTemp[27]); 306 | bufTemp[28] = _mm512_fnmadd_ps(m2, bufA[22], bufTemp[28]); 307 | bufTemp[29] = _mm512_fnmadd_ps(m2, bufA[23], bufTemp[29]); 308 | bufTemp[24] = _mm512_add_ps(bufTemp[24], bufA[24]); 309 | bufTemp[25] = _mm512_add_ps(bufTemp[25], bufA[25]); 310 | bufTemp[26] = _mm512_add_ps(bufTemp[26], bufA[26]); 311 | bufTemp[27] = _mm512_add_ps(bufTemp[27], bufA[27]); 312 | bufTemp[28] = _mm512_add_ps(bufTemp[28], bufA[28]); 313 | bufTemp[29] = _mm512_add_ps(bufTemp[29], bufA[29]); 314 | 315 | bufTemp[30] = _mm512_mul_ps(m4, bufA[6]); 316 | bufTemp[31] = _mm512_mul_ps(m4, bufA[7]); 317 | bufTemp[32] = _mm512_mul_ps(m4, bufA[8]); 318 | bufTemp[33] = _mm512_mul_ps(m4, bufA[9]); 319 | bufTemp[34] = _mm512_mul_ps(m4, bufA[10]); 320 | bufTemp[35] = _mm512_mul_ps(m4, bufA[11]); 321 | bufTemp[30] = _mm512_fnmadd_ps(m5, bufA[18], bufTemp[30]); 322 | bufTemp[31] = _mm512_fnmadd_ps(m5, bufA[19], bufTemp[31]); 323 | bufTemp[32] = _mm512_fnmadd_ps(m5, bufA[20], bufTemp[32]); 324 | bufTemp[33] = _mm512_fnmadd_ps(m5, bufA[21], bufTemp[33]); 325 | bufTemp[34] = _mm512_fnmadd_ps(m5, bufA[22], bufTemp[34]); 326 | bufTemp[35] = _mm512_fnmadd_ps(m5, bufA[23], bufTemp[35]); 327 | 
bufTemp[30] = _mm512_add_ps(bufTemp[30], bufA[30]); 328 | bufTemp[31] = _mm512_add_ps(bufTemp[31], bufA[31]); 329 | bufTemp[32] = _mm512_add_ps(bufTemp[32], bufA[32]); 330 | bufTemp[33] = _mm512_add_ps(bufTemp[33], bufA[33]); 331 | bufTemp[34] = _mm512_add_ps(bufTemp[34], bufA[34]); 332 | bufTemp[35] = _mm512_add_ps(bufTemp[35], bufA[35]); 333 | 334 | /* 4 0 0 0 0 0 335 | 0 -4 4 -2 2 4 336 | -5 -4 -4 -1 -1 0 337 | 0 1 -1 2 -2 -5 338 | 1 1 1 1 1 0 339 | 0 0 0 0 0 1 */ 340 | 341 | bufB = _mm512_mul_ps(bufTemp[0], m4); 342 | bufB = _mm512_fnmadd_ps(m5, bufTemp[2], bufB); 343 | bufB = _mm512_add_ps(bufB, bufTemp[4]); 344 | _mm512_store_ps(dataDst + 0 * ISTRIDE + coter, bufB); 345 | 346 | bufC = _mm512_fnmadd_ps(m4, bufTemp[1], m0); 347 | bufC = _mm512_fnmadd_ps(m4, bufTemp[2], bufC); 348 | bufC = _mm512_add_ps(bufTemp[3], bufC); 349 | bufC = _mm512_add_ps(bufTemp[4], bufC); 350 | _mm512_store_ps(dataDst + 1 * ISTRIDE + coter, bufC); 351 | 352 | bufD = _mm512_mul_ps(m4, bufTemp[1]); 353 | bufD = _mm512_fnmadd_ps(m4, bufTemp[2], bufD); 354 | bufD = _mm512_sub_ps(bufD, bufTemp[3]); 355 | bufD = _mm512_add_ps(bufD, bufTemp[4]); 356 | _mm512_store_ps(dataDst + 2 * ISTRIDE + coter, bufD); 357 | 358 | bufE = _mm512_fnmadd_ps(m2, bufTemp[1], m0); 359 | bufE = _mm512_sub_ps(bufE, bufTemp[2]); 360 | bufE = _mm512_fmadd_ps(m2, bufTemp[3], bufE); 361 | bufE = _mm512_add_ps(bufE, bufTemp[4]); 362 | _mm512_store_ps(dataDst + 3 * ISTRIDE + coter, bufE); 363 | 364 | bufF = _mm512_mul_ps(m2, bufTemp[1]); 365 | bufF = _mm512_sub_ps(bufF, bufTemp[2]); 366 | bufF = _mm512_fnmadd_ps(m2, bufTemp[3], bufF); 367 | bufF = _mm512_add_ps(bufF, bufTemp[4]); 368 | _mm512_store_ps(dataDst + 4 * ISTRIDE + coter, bufF); 369 | 370 | bufG = _mm512_mul_ps(m4, bufTemp[1]); 371 | bufG = _mm512_fnmadd_ps(m5, bufTemp[3], bufG); 372 | bufG = _mm512_add_ps(bufG, bufTemp[5]); 373 | _mm512_store_ps(dataDst + 5 * ISTRIDE + coter, bufG); 374 | 375 | // ------------------------------------------------------- 376 | 
bufB = _mm512_mul_ps(bufTemp[6], m4); 377 | bufB = _mm512_fnmadd_ps(m5, bufTemp[8], bufB); 378 | bufB = _mm512_add_ps(bufB, bufTemp[10]); 379 | _mm512_store_ps(dataDst + 6 * ISTRIDE + coter, bufB); 380 | 381 | bufC = _mm512_fnmadd_ps(m4, bufTemp[7], m0); 382 | bufC = _mm512_fnmadd_ps(m4, bufTemp[8], bufC); 383 | bufC = _mm512_add_ps(bufTemp[9], bufC); 384 | bufC = _mm512_add_ps(bufTemp[10], bufC); 385 | _mm512_store_ps(dataDst + 7 * ISTRIDE + coter, bufC); 386 | 387 | bufD = _mm512_mul_ps(m4, bufTemp[7]); 388 | bufD = _mm512_fnmadd_ps(m4, bufTemp[8], bufD); 389 | bufD = _mm512_sub_ps(bufD, bufTemp[9]); 390 | bufD = _mm512_add_ps(bufD, bufTemp[10]); 391 | _mm512_store_ps(dataDst + 8 * ISTRIDE + coter, bufD); 392 | 393 | bufE = _mm512_fnmadd_ps(m2, bufTemp[7], m0); 394 | bufE = _mm512_sub_ps(bufE, bufTemp[8]); 395 | bufE = _mm512_fmadd_ps(m2, bufTemp[9], bufE); 396 | bufE = _mm512_add_ps(bufE, bufTemp[10]); 397 | _mm512_store_ps(dataDst + 9 * ISTRIDE + coter, bufE); 398 | 399 | bufF = _mm512_mul_ps(m2, bufTemp[7]); 400 | bufF = _mm512_sub_ps(bufF, bufTemp[8]); 401 | bufF = _mm512_fnmadd_ps(m2, bufTemp[9], bufF); 402 | bufF = _mm512_add_ps(bufF, bufTemp[10]); 403 | _mm512_store_ps(dataDst + 10 * ISTRIDE + coter, bufF); 404 | 405 | bufG = _mm512_mul_ps(m4, bufTemp[7]); 406 | bufG = _mm512_fnmadd_ps(m5, bufTemp[9], bufG); 407 | bufG = _mm512_add_ps(bufG, bufTemp[11]); 408 | _mm512_store_ps(dataDst + 11 * ISTRIDE + coter, bufG); 409 | 410 | // ------------------------------------ 411 | bufB = _mm512_mul_ps(bufTemp[12], m4); 412 | bufB = _mm512_fnmadd_ps(m5, bufTemp[14], bufB); 413 | bufB = _mm512_add_ps(bufB, bufTemp[16]); 414 | _mm512_store_ps(dataDst + 12 * ISTRIDE + coter, bufB); 415 | 416 | bufC = _mm512_fnmadd_ps(m4, bufTemp[13], m0); 417 | bufC = _mm512_fnmadd_ps(m4, bufTemp[14], bufC); 418 | bufC = _mm512_add_ps(bufTemp[15], bufC); 419 | bufC = _mm512_add_ps(bufTemp[16], bufC); 420 | _mm512_store_ps(dataDst + 13 * ISTRIDE + coter, bufC); 421 | 422 | bufD = 
_mm512_mul_ps(m4, bufTemp[13]); 423 | bufD = _mm512_fnmadd_ps(m4, bufTemp[14], bufD); 424 | bufD = _mm512_sub_ps(bufD, bufTemp[15]); 425 | bufD = _mm512_add_ps(bufD, bufTemp[16]); 426 | _mm512_store_ps(dataDst + 14 * ISTRIDE + coter, bufD); 427 | 428 | bufE = _mm512_fnmadd_ps(m2, bufTemp[13], m0); 429 | bufE = _mm512_sub_ps(bufE, bufTemp[14]); 430 | bufE = _mm512_fmadd_ps(m2, bufTemp[15], bufE); 431 | bufE = _mm512_add_ps(bufE, bufTemp[16]); 432 | _mm512_store_ps(dataDst + 15 * ISTRIDE + coter, bufE); 433 | 434 | bufF = _mm512_mul_ps(m2, bufTemp[13]); 435 | bufF = _mm512_sub_ps(bufF, bufTemp[14]); 436 | bufF = _mm512_fnmadd_ps(m2, bufTemp[15], bufF); 437 | bufF = _mm512_add_ps(bufF, bufTemp[16]); 438 | _mm512_store_ps(dataDst + 16 * ISTRIDE + coter, bufF); 439 | 440 | bufG = _mm512_mul_ps(m4, bufTemp[13]); 441 | bufG = _mm512_fnmadd_ps(m5, bufTemp[15], bufG); 442 | bufG = _mm512_add_ps(bufG, bufTemp[17]); 443 | _mm512_store_ps(dataDst + 17 * ISTRIDE + coter, bufG); 444 | 445 | // -------------------------------------------- 446 | bufB = _mm512_mul_ps(bufTemp[18], m4); 447 | bufB = _mm512_fnmadd_ps(m5, bufTemp[20], bufB); 448 | bufB = _mm512_add_ps(bufB, bufTemp[22]); 449 | _mm512_store_ps(dataDst + 18 * ISTRIDE + coter, bufB); 450 | 451 | bufC = _mm512_fnmadd_ps(m4, bufTemp[19], m0); 452 | bufC = _mm512_fnmadd_ps(m4, bufTemp[20], bufC); 453 | bufC = _mm512_add_ps(bufTemp[21], bufC); 454 | bufC = _mm512_add_ps(bufTemp[22], bufC); 455 | _mm512_store_ps(dataDst + 19 * ISTRIDE + coter, bufC); 456 | 457 | bufD = _mm512_mul_ps(m4, bufTemp[19]); 458 | bufD = _mm512_fnmadd_ps(m4, bufTemp[20], bufD); 459 | bufD = _mm512_sub_ps(bufD, bufTemp[21]); 460 | bufD = _mm512_add_ps(bufD, bufTemp[22]); 461 | _mm512_store_ps(dataDst + 20 * ISTRIDE + coter, bufD); 462 | 463 | bufE = _mm512_fnmadd_ps(m2, bufTemp[19], m0); 464 | bufE = _mm512_sub_ps(bufE, bufTemp[20]); 465 | bufE = _mm512_fmadd_ps(m2, bufTemp[21], bufE); 466 | bufE = _mm512_add_ps(bufE, bufTemp[22]); 467 | 
_mm512_store_ps(dataDst + 21 * ISTRIDE + coter, bufE); 468 | 469 | bufF = _mm512_mul_ps(m2, bufTemp[19]); 470 | bufF = _mm512_sub_ps(bufF, bufTemp[20]); 471 | bufF = _mm512_fnmadd_ps(m2, bufTemp[21], bufF); 472 | bufF = _mm512_add_ps(bufF, bufTemp[22]); 473 | _mm512_store_ps(dataDst + 22 * ISTRIDE + coter, bufF); 474 | 475 | bufG = _mm512_mul_ps(m4, bufTemp[19]); 476 | bufG = _mm512_fnmadd_ps(m5, bufTemp[21], bufG); 477 | bufG = _mm512_add_ps(bufG, bufTemp[23]); 478 | _mm512_store_ps(dataDst + 23 * ISTRIDE + coter, bufG); 479 | 480 | // -------------------------------------------- 481 | bufB = _mm512_mul_ps(bufTemp[24], m4); 482 | bufB = _mm512_fnmadd_ps(m5, bufTemp[26], bufB); 483 | bufB = _mm512_add_ps(bufB, bufTemp[28]); 484 | _mm512_store_ps(dataDst + 24 * ISTRIDE + coter, bufB); 485 | 486 | bufC = _mm512_fnmadd_ps(m4, bufTemp[25], m0); 487 | bufC = _mm512_fnmadd_ps(m4, bufTemp[26], bufC); 488 | bufC = _mm512_add_ps(bufTemp[27], bufC); 489 | bufC = _mm512_add_ps(bufTemp[28], bufC); 490 | _mm512_store_ps(dataDst + 25 * ISTRIDE + coter, bufC); 491 | 492 | bufD = _mm512_mul_ps(m4, bufTemp[25]); 493 | bufD = _mm512_fnmadd_ps(m4, bufTemp[26], bufD); 494 | bufD = _mm512_sub_ps(bufD, bufTemp[27]); 495 | bufD = _mm512_add_ps(bufD, bufTemp[28]); 496 | _mm512_store_ps(dataDst + 26 * ISTRIDE + coter, bufD); 497 | 498 | bufE = _mm512_fnmadd_ps(m2, bufTemp[25], m0); 499 | bufE = _mm512_sub_ps(bufE, bufTemp[26]); 500 | bufE = _mm512_fmadd_ps(m2, bufTemp[27], bufE); 501 | bufE = _mm512_add_ps(bufE, bufTemp[28]); 502 | _mm512_store_ps(dataDst + 27 * ISTRIDE + coter, bufE); 503 | 504 | bufF = _mm512_mul_ps(m2, bufTemp[25]); 505 | bufF = _mm512_sub_ps(bufF, bufTemp[26]); 506 | bufF = _mm512_fnmadd_ps(m2, bufTemp[27], bufF); 507 | bufF = _mm512_add_ps(bufF, bufTemp[28]); 508 | _mm512_store_ps(dataDst + 28 * ISTRIDE + coter, bufF); 509 | 510 | bufG = _mm512_mul_ps(m4, bufTemp[25]); 511 | bufG = _mm512_fnmadd_ps(m5, bufTemp[27], bufG); 512 | bufG = _mm512_add_ps(bufG, bufTemp[29]); 
    _mm512_store_ps(dataDst + 29 * ISTRIDE + coter, bufG);   // transform element 29, 16 tiles at once

    // ----------------------------------------
    // Column pass for the last row group: bufTemp[30..35] holds row 5 of
    // B^T*d for 16 tiles; combine its columns (multiply by B on the right)
    // to produce transformed elements 30..35.  ISTRIDE separates the 36
    // transform planes in dataDst; coter selects this batch of 16 tiles.
    bufB = _mm512_mul_ps(bufTemp[30], m4);
    bufB = _mm512_fnmadd_ps(m5, bufTemp[32], bufB);
    bufB = _mm512_add_ps(bufB, bufTemp[34]);                 // 4*c0 - 5*c2 + c4
    _mm512_store_ps(dataDst + 30 * ISTRIDE + coter, bufB);

    bufC = _mm512_fnmadd_ps(m4, bufTemp[31], m0);
    bufC = _mm512_fnmadd_ps(m4, bufTemp[32], bufC);
    bufC = _mm512_add_ps(bufTemp[33], bufC);
    bufC = _mm512_add_ps(bufTemp[34], bufC);                 // -4*c1 - 4*c2 + c3 + c4
    _mm512_store_ps(dataDst + 31 * ISTRIDE + coter, bufC);

    bufD = _mm512_mul_ps(m4, bufTemp[31]);
    bufD = _mm512_fnmadd_ps(m4, bufTemp[32], bufD);
    bufD = _mm512_sub_ps(bufD, bufTemp[33]);
    bufD = _mm512_add_ps(bufD, bufTemp[34]);                 // 4*c1 - 4*c2 - c3 + c4
    _mm512_store_ps(dataDst + 32 * ISTRIDE + coter, bufD);

    bufE = _mm512_fnmadd_ps(m2, bufTemp[31], m0);
    bufE = _mm512_sub_ps(bufE, bufTemp[32]);
    bufE = _mm512_fmadd_ps(m2, bufTemp[33], bufE);
    bufE = _mm512_add_ps(bufE, bufTemp[34]);                 // -2*c1 - c2 + 2*c3 + c4
    _mm512_store_ps(dataDst + 33 * ISTRIDE + coter, bufE);

    bufF = _mm512_mul_ps(m2, bufTemp[31]);
    bufF = _mm512_sub_ps(bufF, bufTemp[32]);
    bufF = _mm512_fnmadd_ps(m2, bufTemp[33], bufF);
    bufF = _mm512_add_ps(bufF, bufTemp[34]);                 // 2*c1 - c2 - 2*c3 + c4
    _mm512_store_ps(dataDst + 34 * ISTRIDE + coter, bufF);

    bufG = _mm512_mul_ps(m4, bufTemp[31]);
    bufG = _mm512_fnmadd_ps(m5, bufTemp[33], bufG);
    bufG = _mm512_add_ps(bufG, bufTemp[35]);                 // 4*c1 - 5*c3 + c5
    _mm512_store_ps(dataDst + 35 * ISTRIDE + coter, bufG);

    *counter += 16;   // this call consumed 16 tiles
}

/* Copy a partial (image-edge) lenX x lenY region of dataSrc into the
 * zero-padded scratch buffer `temp` (6 rows x 66 cols) and run the 16-tile
 * vectorized input transform on the padded data.
 *   x, y       top-left corner of the region inside dataSrc
 *   lenX, lenY valid rows/cols remaining at the edge
 *              (assumes lenX <= 6 and lenY <= 66 -- TODO confirm callers)
 *   nrows      row stride of dataSrc in floats
 * NOTE(review): regions with exactly 2 valid rows or cols are skipped
 * entirely; presumably such tiles contribute no output -- verify against
 * the tiling loop that calls this.
 */
static inline void pad_get_tiles(int x, int y, int lenX, int lenY, int nrows, const float *dataSrc,
                                 float *temp, float *dataDst, int *counter) {
    if (2 == lenX || 2 == lenY) return;
    int i, j;
    for (i = 0; i < lenX; ++i) {
        // Copy the valid part of the row, then zero-fill the remainder.
        for (j = 0; j < lenY; ++j) {
            temp[i * 66 + j] = dataSrc[(x + i) * nrows + y + j];
        }
        for (; j < 66; ++j) {
            temp[i * 66 + j] = 0;
        }
        //memset(temp + i * 66 + j, 1, (66 - j) * sizeof(float));
    }
    /*if (i < 6) {
        memset(temp + i * 66, 0, 66 * (6 - i) * sizeof(float));
    }*/
    // Zero any rows below the valid region.
    for (; i < 6; ++i) {
        for (j = 0; j < 66; ++j) {
            temp[i * 66 + j] = 0;
        }
    }

    get_tiles_4x3_16t(0, 0, 66, temp, dataDst, counter);
}

/* Scalar input transform for a single 6x6 Winograd F(4x4, 3x3) tile:
 * computes V = B^T * d * B for the tile whose top-left corner is (x, y)
 * in dataSrc and scatters the 36 results into dataDst, one per
 * ISTRIDE-sized plane, at offset *counter; advances *counter by 1.
 * Handles leftover tiles that do not fill a 16-wide vector batch.
 */
static inline void get_tiles_4x3_1t(int x, int y, int nrows, const float *dataSrc,
                                    float *dataDst, int *counter) {
    int coter = *counter;
    float temp[36] __attribute__((aligned(64)));

    // Gather the 6x6 input tile d into temp (row-major).
    temp[0] = dataSrc[(x + 0) * nrows + y + 0];
    temp[1] = dataSrc[(x + 0) * nrows + y + 1];
    temp[2] = dataSrc[(x + 0) * nrows + y + 2];
    temp[3] = dataSrc[(x + 0) * nrows + y + 3];
    temp[4] = dataSrc[(x + 0) * nrows + y + 4];
    temp[5] = dataSrc[(x + 0) * nrows + y + 5];
    temp[6] = dataSrc[(x + 1) * nrows + y + 0];
    temp[7] = dataSrc[(x + 1) * nrows + y + 1];
    temp[8] = dataSrc[(x + 1) * nrows + y + 2];
    temp[9] = dataSrc[(x + 1) * nrows + y + 3];
    temp[10] = dataSrc[(x + 1) * nrows + y + 4];
    temp[11] = dataSrc[(x + 1) * nrows + y + 5];
    temp[12] = dataSrc[(x + 2) * nrows + y + 0];
    temp[13] = dataSrc[(x + 2) * nrows + y + 1];
    temp[14] = dataSrc[(x + 2) * nrows + y + 2];
    temp[15] = dataSrc[(x + 2) * nrows + y + 3];
    temp[16] = dataSrc[(x + 2) * nrows + y + 4];
    temp[17] = dataSrc[(x + 2) * nrows + y + 5];
    temp[18] = dataSrc[(x + 3) * nrows + y + 0];
    temp[19] = dataSrc[(x + 3) * nrows + y + 1];
    temp[20] = dataSrc[(x + 3) * nrows + y + 2];
    temp[21] = dataSrc[(x + 3) * nrows + y + 3];
    temp[22] = dataSrc[(x + 3) * nrows + y + 4];
    temp[23] = dataSrc[(x + 3) * nrows + y + 5];
    temp[24] = dataSrc[(x + 4) * nrows + y + 0];
    temp[25] = dataSrc[(x + 4) * nrows + y + 1];
    temp[26] = dataSrc[(x + 4) * nrows + y + 2];
    temp[27] = dataSrc[(x + 4) * nrows + y + 3];
    temp[28] = dataSrc[(x + 4) * nrows + y + 4];
    temp[29] = dataSrc[(x + 4) * nrows + y + 5];
    temp[30] = dataSrc[(x + 5) * nrows + y + 0];
    temp[31] = dataSrc[(x + 5) * nrows + y + 1];
    temp[32] = dataSrc[(x + 5) * nrows + y + 2];
    temp[33] = dataSrc[(x + 5) * nrows + y + 3];
    temp[34] = dataSrc[(x + 5) * nrows + y + 4];
    temp[35] = dataSrc[(x + 5) * nrows + y + 5];

    // Row pass: temp2 = B^T * d.  Rows of B^T:
    //   [4 0 -5 0 1 0] [0 -4 -4 1 1 0] [0 4 -4 -1 1 0]
    //   [0 -2 -1 2 1 0] [0 2 -1 -2 1 0] [0 4 0 -5 0 1]
    float temp2[36]__attribute__((aligned(64)));
    temp2[0] = 4 * temp[0] - 5 * temp[12] + temp[24];
    temp2[1] = 4 * temp[1] - 5 * temp[13] + temp[25];
    temp2[2] = 4 * temp[2] - 5 * temp[14] + temp[26];
    temp2[3] = 4 * temp[3] - 5 * temp[15] + temp[27];
    temp2[4] = 4 * temp[4] - 5 * temp[16] + temp[28];
    temp2[5] = 4 * temp[5] - 5 * temp[17] + temp[29];
    temp2[6] = -4 * temp[6] - 4 * temp[12] + temp[18] + temp[24];
    temp2[7] = -4 * temp[7] - 4 * temp[13] + temp[19] + temp[25];
    temp2[8] = -4 * temp[8] - 4 * temp[14] + temp[20] + temp[26];
    temp2[9] = -4 * temp[9] - 4 * temp[15] + temp[21] + temp[27];
    temp2[10] = -4 * temp[10] - 4 * temp[16] + temp[22] + temp[28];
    temp2[11] = -4 * temp[11] - 4 * temp[17] + temp[23] + temp[29];
    temp2[12] = 4 * temp[6] - 4 * temp[12] - temp[18] + temp[24];
    temp2[13] = 4 * temp[7] - 4 * temp[13] - temp[19] + temp[25];
    temp2[14] = 4 * temp[8] - 4 * temp[14] - temp[20] + temp[26];
    temp2[15] = 4 * temp[9] - 4 * temp[15] - temp[21] + temp[27];
    temp2[16] = 4 * temp[10] - 4 * temp[16] - temp[22] + temp[28];
    temp2[17] = 4 * temp[11] - 4 * temp[17] - temp[23] + temp[29];
    temp2[18] = -2 * temp[6] - temp[12] + 2 * temp[18] + temp[24];
    temp2[19] = -2 * temp[7] - temp[13] + 2 * temp[19] + temp[25];
    temp2[20] = -2 * temp[8] - temp[14] + 2 * temp[20] + temp[26];
    temp2[21] = -2 * temp[9] - temp[15] + 2 * temp[21] + temp[27];
    temp2[22] = -2 * temp[10] - temp[16] + 2 * temp[22] + temp[28];
    temp2[23] = -2 * temp[11] - temp[17] + 2 * temp[23] + temp[29];
    temp2[24] = 2 * temp[6] - temp[12] - 2 * temp[18] + temp[24];
    temp2[25] = 2 * temp[7] - temp[13] - 2 * temp[19] + temp[25];
    temp2[26] = 2 * temp[8] - temp[14] - 2 * temp[20] + temp[26];
    temp2[27] = 2 * temp[9] - temp[15] - 2 * temp[21] + temp[27];
    temp2[28] = 2 * temp[10] - temp[16] - 2 * temp[22] + temp[28];
    temp2[29] = 2 * temp[11] - temp[17] - 2 * temp[23] + temp[29];
    temp2[30] = 4 * temp[6] - 5 * temp[18] + temp[30];
    temp2[31] = 4 * temp[7] - 5 * temp[19] + temp[31];
    temp2[32] = 4 * temp[8] - 5 * temp[20] + temp[32];
    temp2[33] = 4 * temp[9] - 5 * temp[21] + temp[33];
    temp2[34] = 4 * temp[10] - 5 * temp[22] + temp[34];
    temp2[35] = 4 * temp[11] - 5 * temp[23] + temp[35];

    // Column pass: (B^T d) * B, same coefficients applied within each row
    // group of 6, scattered into dataDst with stride ISTRIDE.
    dataDst[0 * ISTRIDE + coter] = temp2[0] * 4 - temp2[2] * 5 + temp2[4];
    dataDst[1 * ISTRIDE + coter] = -temp2[1] * 4 - temp2[2] * 4 + temp2[3] + temp2[4];
    dataDst[2 * ISTRIDE + coter] = temp2[1] * 4 - temp2[2] * 4 - temp2[3] + temp2[4];
    dataDst[3 * ISTRIDE + coter] = -temp2[1] * 2 - temp2[2] + temp2[3] * 2 + temp2[4];
    dataDst[4 * ISTRIDE + coter] = temp2[1] * 2 - temp2[2] - temp2[3] * 2 + temp2[4];
    dataDst[5 * ISTRIDE + coter] = temp2[1] * 4 - temp2[3] * 5 + temp2[5];
    dataDst[6 * ISTRIDE + coter] = temp2[6] * 4 - temp2[8] * 5 + temp2[10];
    dataDst[7 * ISTRIDE + coter] = -temp2[7] * 4 - temp2[8] * 4 + temp2[9] + temp2[10];
    dataDst[8 * ISTRIDE + coter] = temp2[7] * 4 - temp2[8] * 4 - temp2[9] + temp2[10];
    dataDst[9 * ISTRIDE + coter] = -temp2[7] * 2 - temp2[8] + temp2[9] * 2 + temp2[10];
    dataDst[10 * ISTRIDE + coter] = temp2[7] * 2 - temp2[8] - temp2[9] * 2 + temp2[10];
    dataDst[11 * ISTRIDE + coter] = temp2[7] * 4 - temp2[9] * 5 + temp2[11];
    dataDst[12 * ISTRIDE + coter] = temp2[12] * 4 - temp2[14] * 5 + temp2[16];
    dataDst[13 * ISTRIDE + coter] = -temp2[13] * 4 - temp2[14] * 4 + temp2[15] + temp2[16];
    dataDst[14 * ISTRIDE + coter] = temp2[13] * 4 - temp2[14] * 4 - temp2[15] + temp2[16];
    dataDst[15 * ISTRIDE + coter] = -temp2[13] * 2 - temp2[14] + temp2[15] * 2 + temp2[16];
    dataDst[16 * ISTRIDE + coter] = temp2[13] * 2 - temp2[14] - temp2[15] * 2 + temp2[16];
    dataDst[17 * ISTRIDE + coter] = temp2[13] * 4 - temp2[15] * 5 + temp2[17];
    dataDst[18 * ISTRIDE + coter] = temp2[18] * 4 - temp2[20] * 5 + temp2[22];
    dataDst[19 * ISTRIDE + coter] = -temp2[19] * 4 - temp2[20] * 4 + temp2[21] + temp2[22];
    dataDst[20 * ISTRIDE + coter] = temp2[19] * 4 - temp2[20] * 4 - temp2[21] + temp2[22];
    dataDst[21 * ISTRIDE + coter] = -temp2[19] * 2 - temp2[20] + temp2[21] * 2 + temp2[22];
    dataDst[22 * ISTRIDE + coter] = temp2[19] * 2 - temp2[20] - temp2[21] * 2 + temp2[22];
    dataDst[23 * ISTRIDE + coter] = temp2[19] * 4 - temp2[21] * 5 + temp2[23];
    dataDst[24 * ISTRIDE + coter] = temp2[24] * 4 - temp2[26] * 5 + temp2[28];
    dataDst[25 * ISTRIDE + coter] = -temp2[25] * 4 - temp2[26] * 4 + temp2[27] + temp2[28];
    dataDst[26 * ISTRIDE + coter] = temp2[25] * 4 - temp2[26] * 4 - temp2[27] + temp2[28];
    dataDst[27 * ISTRIDE + coter] = -temp2[25] * 2 - temp2[26] + temp2[27] * 2 + temp2[28];
    dataDst[28 * ISTRIDE + coter] = temp2[25] * 2 - temp2[26] - temp2[27] * 2 + temp2[28];
    dataDst[29 * ISTRIDE + coter] = temp2[25] * 4 - temp2[27] * 5 + temp2[29];
    dataDst[30 * ISTRIDE + coter] = temp2[30] * 4 - temp2[32] * 5 + temp2[34];
    dataDst[31 * ISTRIDE + coter] = -temp2[31] * 4 - temp2[32] * 4 + temp2[33] + temp2[34];
    dataDst[32 * ISTRIDE + coter] = temp2[31] * 4 - temp2[32] * 4 - temp2[33] + temp2[34];
    dataDst[33 * ISTRIDE + coter] = -temp2[31] * 2 - temp2[32] + temp2[33] * 2 + temp2[34];
    dataDst[34 * ISTRIDE + coter] = temp2[31] * 2 - temp2[32] - temp2[33] * 2 + temp2[34];
    dataDst[35 * ISTRIDE + coter] = temp2[31] * 4 - temp2[33] * 5 + temp2[35];

    (*counter)++;   // one tile consumed

}

/* Filter (kernel) transform for F(4x4, 3x3): for each of the K*C 3x3
 * kernels g, compute the 6x6 matrix U = G * g * G^T, where G has rows
 *   [1/4 0 0] [-1/6 -1/6 -1/6] [-1/6 1/6 -1/6]
 *   [1/24 1/12 1/6] [1/24 -1/12 1/6] [0 0 1],
 * and scatter the 36 coefficients into `out`, one FSTRIDE-sized plane per
 * coefficient, at offset m*C + n within each plane.
 *   filter : K x C x 3 x 3 weights (row-major)
 *   C, K   : input / output channel counts
 * Parallelized over the (m, n) kernel grid with OpenMP collapse(2).
 */
static void filter_transform_4x3(const float* restrict filter, const int C, const int K, float* restrict out) {
    int m, n, x;
    const float *F;
    // Rational constants of the G matrix.
    const float r4 = 1.0 / 4;
    const float r6 = 1.0 / 6;
    const float r12 = 1.0 / 12;
    const float r24 = 1.0 / 24;

#pragma omp parallel for collapse(2) private(m, n, x, F)
#pragma simd
    for (m = 0; m < K; ++m) {
        for (n = 0; n < C; ++n) {
            // c1 = G * g  (6x3): each group of 3 entries is one row of G
            // applied to the columns of the 3x3 kernel.
            float c1[18] __attribute__((aligned(64)));
            F = filter + n * 3 * 3 + m * 3 * 3 * C;
            c1[0] = r4 * F[0];
            c1[1] = r4 * F[1];
            c1[2] = r4 * F[2];
            c1[3] = -r6 * (F[0] + F[3] + F[6]);
            c1[4] = -r6 * (F[1] + F[4] + F[7]);
            c1[5] = -r6 * (F[2] + F[5] + F[8]);
            c1[6] = -r6 * (F[0] - F[3] + F[6]);
            c1[7] = -r6 * (F[1] - F[4] + F[7]);
            c1[8] = -r6 * (F[2] - F[5] + F[8]);
            c1[9] = r24 * F[0] + r12 * F[3] + r6 * F[6];
            c1[10] = r24 * F[1] + r12 * F[4] + r6 * F[7];
            c1[11] = r24 * F[2] + r12 * F[5] + r6 * F[8];
            c1[12] = r24 * F[0] - r12 * F[3] + r6 * F[6];
            c1[13] = r24 * F[1] - r12 * F[4] + r6 * F[7];
            c1[14] = r24 * F[2] - r12 * F[5] + r6 * F[8];
            c1[15] = F[6];
            c1[16] = F[7];
            c1[17] = F[8];

            // c2 = (G * g) * G^T  (6x6): the same coefficients applied to
            // each row group of c1.
            float c2[36] __attribute__((aligned(64)));
            c2[0] = r4 * c1[0];
            c2[1] = -r6 * (c1[0] + c1[1] + c1[2]);
            c2[2] = -r6 * (c1[0] - c1[1] + c1[2]);
            c2[3] = r24 * c1[0] + r12 * c1[1] + r6 * c1[2];
            c2[4] = r24 * c1[0] - r12 * c1[1] + r6 * c1[2];
            c2[5] = c1[2];

            c2[6] = r4 * c1[3];
            c2[7] = -r6 * (c1[3] + c1[4] + c1[5]);
            c2[8] = -r6 * (c1[3] - c1[4] + c1[5]);
            c2[9] = r24 * c1[3] + r12 * c1[4] + r6 * c1[5];
            c2[10] = r24 * c1[3] - r12 * c1[4] + r6 * c1[5];
            c2[11] = c1[5];

            c2[12] = r4 * c1[6];
            c2[13] = -r6 * (c1[6] + c1[7] + c1[8]);
            c2[14] = -r6 * (c1[6] - c1[7] + c1[8]);
            c2[15] = r24 * c1[6] + r12 * c1[7] + r6 * c1[8];
            c2[16] = r24 * c1[6] - r12 * c1[7] + r6 * c1[8];
            c2[17] = c1[8];

            c2[18] = r4 * c1[9];
            c2[19] = -r6 * (c1[9] + c1[10] + c1[11]);
            c2[20] = -r6 * (c1[9] - c1[10] + c1[11]);
            c2[21] = r24 * c1[9] + r12 * c1[10] + r6 * c1[11];
            c2[22] = r24 * c1[9] - r12 * c1[10] + r6 * c1[11];
            c2[23] = c1[11];

            c2[24] = r4 * c1[12];
            c2[25] = -r6 * (c1[12] + c1[13] + c1[14]);
            c2[26] = -r6 * (c1[12] - c1[13] + c1[14]);
            c2[27] = r24 * c1[12] + r12 * c1[13] + r6 * c1[14];
            c2[28] = r24 * c1[12] - r12 * c1[13] + r6 * c1[14];
            c2[29] = c1[14];

            c2[30] = r4 * c1[15];
            c2[31] = -r6 * (c1[15] + c1[16] + c1[17]);
            c2[32] = -r6 * (c1[15] - c1[16] + c1[17]);
            c2[33] = r24 * c1[15] + r12 * c1[16] + r6 * c1[17];
            c2[34] = r24 * c1[15] - r12 * c1[16] + r6 * c1[17];
            c2[35] = c1[17];

            // Scatter: one FSTRIDE plane per transform element.
#pragma unroll(9)
            for (x = 0; x < 36; ++x) {
                out[x * FSTRIDE + m * C + n] = c2[x];
            }
        }
    }
}

/* Output transform for 16 tiles at once: loads the 36 element-wise GEMM
 * result planes (stride OSTRIDE) at offset *counter and applies the
 * inverse transform A^T * M * A (A^T rows visible below: [1 1 1 1 1 0],
 * [0 1 -1 2 -2 0], [0 1 1 4 4 0], [0 1 -1 8 -8 1]) to produce the 4x4
 * output tiles.  Body continues past this excerpt.
 */
static void out_transform_4x3_16t(int x, int y, int nrows,
                                  const float* dataSrc, float* dataDst,
                                  int *counter) {
    int coter = *counter;
    // NOTE(review): c1 appears unused in the visible portion of this
    // function -- confirm before removing.
    float c1[384] __attribute__((aligned(64)));
    __m512 bufA[36], bufB, bufC, bufD, bufE, bufF, bufG, bufH, bufI;
    __m512 bufTemp[24];

    // Cross-lane interleave patterns for the transpose of the 16x4 result
    // layout (even/odd element selectors from two source vectors).
    __m512i idx0 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
                                    19, 3, 18, 2, 17, 1, 16, 0);
    __m512i idx1 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
                                    27, 11, 26, 10, 25, 9, 24, 8);

    /* 0 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60
       1 5 9 13 17 21 25 29 33 37 41 45 49 53 57 61
       2 6 10 14 18 22 26 30 34 38 42 46 50 54 58 62
       3 7 11 15 19 23 27 31 35 39 43 47 51 55 59 63 */

    /* 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30
       1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31
       32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62
       33 35 37 39 41 43 45 47 49 51 53 55 57
59 61 63 */ 804 | 805 | /* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 806 | 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 807 | 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 808 | 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 */ 809 | 810 | bufA[0] = _mm512_load_ps(dataSrc + 0 * OSTRIDE + coter); 811 | bufA[1] = _mm512_load_ps(dataSrc + 1 * OSTRIDE + coter); 812 | bufA[2] = _mm512_load_ps(dataSrc + 2 * OSTRIDE + coter); 813 | bufA[3] = _mm512_load_ps(dataSrc + 3 * OSTRIDE + coter); 814 | bufA[4] = _mm512_load_ps(dataSrc + 4 * OSTRIDE + coter); 815 | bufA[5] = _mm512_load_ps(dataSrc + 5 * OSTRIDE + coter); 816 | bufA[6] = _mm512_load_ps(dataSrc + 6 * OSTRIDE + coter); 817 | bufA[7] = _mm512_load_ps(dataSrc + 7 * OSTRIDE + coter); 818 | bufA[8] = _mm512_load_ps(dataSrc + 8 * OSTRIDE + coter); 819 | bufA[9] = _mm512_load_ps(dataSrc + 9 * OSTRIDE + coter); 820 | bufA[10] = _mm512_load_ps(dataSrc + 10 * OSTRIDE + coter); 821 | bufA[11] = _mm512_load_ps(dataSrc + 11 * OSTRIDE + coter); 822 | bufA[12] = _mm512_load_ps(dataSrc + 12 * OSTRIDE + coter); 823 | bufA[13] = _mm512_load_ps(dataSrc + 13 * OSTRIDE + coter); 824 | bufA[14] = _mm512_load_ps(dataSrc + 14 * OSTRIDE + coter); 825 | bufA[15] = _mm512_load_ps(dataSrc + 15 * OSTRIDE + coter); 826 | bufA[16] = _mm512_load_ps(dataSrc + 16 * OSTRIDE + coter); 827 | bufA[17] = _mm512_load_ps(dataSrc + 17 * OSTRIDE + coter); 828 | bufA[18] = _mm512_load_ps(dataSrc + 18 * OSTRIDE + coter); 829 | bufA[19] = _mm512_load_ps(dataSrc + 19 * OSTRIDE + coter); 830 | bufA[20] = _mm512_load_ps(dataSrc + 20 * OSTRIDE + coter); 831 | bufA[21] = _mm512_load_ps(dataSrc + 21 * OSTRIDE + coter); 832 | bufA[22] = _mm512_load_ps(dataSrc + 22 * OSTRIDE + coter); 833 | bufA[23] = _mm512_load_ps(dataSrc + 23 * OSTRIDE + coter); 834 | bufA[24] = _mm512_load_ps(dataSrc + 24 * OSTRIDE + coter); 835 | bufA[25] = _mm512_load_ps(dataSrc + 25 * OSTRIDE + coter); 836 | bufA[26] = _mm512_load_ps(dataSrc + 26 * OSTRIDE + coter); 837 | bufA[27] = 
_mm512_load_ps(dataSrc + 27 * OSTRIDE + coter); 838 | bufA[28] = _mm512_load_ps(dataSrc + 28 * OSTRIDE + coter); 839 | bufA[29] = _mm512_load_ps(dataSrc + 29 * OSTRIDE + coter); 840 | bufA[30] = _mm512_load_ps(dataSrc + 30 * OSTRIDE + coter); 841 | bufA[31] = _mm512_load_ps(dataSrc + 31 * OSTRIDE + coter); 842 | bufA[32] = _mm512_load_ps(dataSrc + 32 * OSTRIDE + coter); 843 | bufA[33] = _mm512_load_ps(dataSrc + 33 * OSTRIDE + coter); 844 | bufA[34] = _mm512_load_ps(dataSrc + 34 * OSTRIDE + coter); 845 | bufA[35] = _mm512_load_ps(dataSrc + 35 * OSTRIDE + coter); 846 | 847 | __m512 m2 = _mm512_set1_ps(2); 848 | __m512 m4 = _mm512_set1_ps(4); 849 | __m512 m8 = _mm512_set1_ps(8); 850 | 851 | bufTemp[0] = _mm512_add_ps(bufA[0], bufA[6]); 852 | bufTemp[1] = _mm512_add_ps(bufA[1], bufA[7]); 853 | bufTemp[2] = _mm512_add_ps(bufA[2], bufA[8]); 854 | bufTemp[3] = _mm512_add_ps(bufA[3], bufA[9]); 855 | bufTemp[4] = _mm512_add_ps(bufA[4], bufA[10]); 856 | bufTemp[5] = _mm512_add_ps(bufA[5], bufA[11]); 857 | bufTemp[0] = _mm512_add_ps(bufTemp[0], bufA[12]); 858 | bufTemp[1] = _mm512_add_ps(bufTemp[1], bufA[13]); 859 | bufTemp[2] = _mm512_add_ps(bufTemp[2], bufA[14]); 860 | bufTemp[3] = _mm512_add_ps(bufTemp[3], bufA[15]); 861 | bufTemp[4] = _mm512_add_ps(bufTemp[4], bufA[16]); 862 | bufTemp[5] = _mm512_add_ps(bufTemp[5], bufA[17]); 863 | bufTemp[0] = _mm512_add_ps(bufTemp[0], bufA[18]); 864 | bufTemp[1] = _mm512_add_ps(bufTemp[1], bufA[19]); 865 | bufTemp[2] = _mm512_add_ps(bufTemp[2], bufA[20]); 866 | bufTemp[3] = _mm512_add_ps(bufTemp[3], bufA[21]); 867 | bufTemp[4] = _mm512_add_ps(bufTemp[4], bufA[22]); 868 | bufTemp[5] = _mm512_add_ps(bufTemp[5], bufA[23]); 869 | bufTemp[0] = _mm512_add_ps(bufTemp[0], bufA[24]); 870 | bufTemp[1] = _mm512_add_ps(bufTemp[1], bufA[25]); 871 | bufTemp[2] = _mm512_add_ps(bufTemp[2], bufA[26]); 872 | bufTemp[3] = _mm512_add_ps(bufTemp[3], bufA[27]); 873 | bufTemp[4] = _mm512_add_ps(bufTemp[4], bufA[28]); 874 | bufTemp[5] = 
_mm512_add_ps(bufTemp[5], bufA[29]); 875 | 876 | bufTemp[6] = _mm512_sub_ps(bufA[6], bufA[12]); 877 | bufTemp[7] = _mm512_sub_ps(bufA[7], bufA[13]); 878 | bufTemp[8] = _mm512_sub_ps(bufA[8], bufA[14]); 879 | bufTemp[9] = _mm512_sub_ps(bufA[9], bufA[15]); 880 | bufTemp[10] = _mm512_sub_ps(bufA[10], bufA[16]); 881 | bufTemp[11] = _mm512_sub_ps(bufA[11], bufA[17]); 882 | bufTemp[6] = _mm512_fmadd_ps(bufA[18], m2, bufTemp[6]); 883 | bufTemp[7] = _mm512_fmadd_ps(bufA[19], m2, bufTemp[7]); 884 | bufTemp[8] = _mm512_fmadd_ps(bufA[20], m2, bufTemp[8]); 885 | bufTemp[9] = _mm512_fmadd_ps(bufA[21], m2, bufTemp[9]); 886 | bufTemp[10] = _mm512_fmadd_ps(bufA[22], m2, bufTemp[10]); 887 | bufTemp[11] = _mm512_fmadd_ps(bufA[23], m2, bufTemp[11]); 888 | bufTemp[6] = _mm512_fnmadd_ps(bufA[24], m2, bufTemp[6]); 889 | bufTemp[7] = _mm512_fnmadd_ps(bufA[25], m2, bufTemp[7]); 890 | bufTemp[8] = _mm512_fnmadd_ps(bufA[26], m2, bufTemp[8]); 891 | bufTemp[9] = _mm512_fnmadd_ps(bufA[27], m2, bufTemp[9]); 892 | bufTemp[10] = _mm512_fnmadd_ps(bufA[28], m2, bufTemp[10]); 893 | bufTemp[11] = _mm512_fnmadd_ps(bufA[29], m2, bufTemp[11]); 894 | 895 | bufTemp[12] = _mm512_add_ps(bufA[6], bufA[12]); 896 | bufTemp[13] = _mm512_add_ps(bufA[7], bufA[13]); 897 | bufTemp[14] = _mm512_add_ps(bufA[8], bufA[14]); 898 | bufTemp[15] = _mm512_add_ps(bufA[9], bufA[15]); 899 | bufTemp[16] = _mm512_add_ps(bufA[10], bufA[16]); 900 | bufTemp[17] = _mm512_add_ps(bufA[11], bufA[17]); 901 | bufTemp[12] = _mm512_fmadd_ps(m4, bufA[18], bufTemp[12]); 902 | bufTemp[13] = _mm512_fmadd_ps(m4, bufA[19], bufTemp[13]); 903 | bufTemp[14] = _mm512_fmadd_ps(m4, bufA[20], bufTemp[14]); 904 | bufTemp[15] = _mm512_fmadd_ps(m4, bufA[21], bufTemp[15]); 905 | bufTemp[16] = _mm512_fmadd_ps(m4, bufA[22], bufTemp[16]); 906 | bufTemp[17] = _mm512_fmadd_ps(m4, bufA[23], bufTemp[17]); 907 | bufTemp[12] = _mm512_fmadd_ps(m4, bufA[24], bufTemp[12]); 908 | bufTemp[13] = _mm512_fmadd_ps(m4, bufA[25], bufTemp[13]); 909 | bufTemp[14] = 
_mm512_fmadd_ps(m4, bufA[26], bufTemp[14]); 910 | bufTemp[15] = _mm512_fmadd_ps(m4, bufA[27], bufTemp[15]); 911 | bufTemp[16] = _mm512_fmadd_ps(m4, bufA[28], bufTemp[16]); 912 | bufTemp[17] = _mm512_fmadd_ps(m4, bufA[29], bufTemp[17]); 913 | 914 | bufTemp[18] = _mm512_sub_ps(bufA[6], bufA[12]); 915 | bufTemp[19] = _mm512_sub_ps(bufA[7], bufA[13]); 916 | bufTemp[20] = _mm512_sub_ps(bufA[8], bufA[14]); 917 | bufTemp[21] = _mm512_sub_ps(bufA[9], bufA[15]); 918 | bufTemp[22] = _mm512_sub_ps(bufA[10], bufA[16]); 919 | bufTemp[23] = _mm512_sub_ps(bufA[11], bufA[17]); 920 | bufTemp[18] = _mm512_fmadd_ps(m8, bufA[18], bufTemp[18]); 921 | bufTemp[19] = _mm512_fmadd_ps(m8, bufA[19], bufTemp[19]); 922 | bufTemp[20] = _mm512_fmadd_ps(m8, bufA[20], bufTemp[20]); 923 | bufTemp[21] = _mm512_fmadd_ps(m8, bufA[21], bufTemp[21]); 924 | bufTemp[22] = _mm512_fmadd_ps(m8, bufA[22], bufTemp[22]); 925 | bufTemp[23] = _mm512_fmadd_ps(m8, bufA[23], bufTemp[23]); 926 | bufTemp[18] = _mm512_fnmadd_ps(m8, bufA[24], bufTemp[18]); 927 | bufTemp[19] = _mm512_fnmadd_ps(m8, bufA[25], bufTemp[19]); 928 | bufTemp[20] = _mm512_fnmadd_ps(m8, bufA[26], bufTemp[20]); 929 | bufTemp[21] = _mm512_fnmadd_ps(m8, bufA[27], bufTemp[21]); 930 | bufTemp[22] = _mm512_fnmadd_ps(m8, bufA[28], bufTemp[22]); 931 | bufTemp[23] = _mm512_fnmadd_ps(m8, bufA[29], bufTemp[23]); 932 | bufTemp[18] = _mm512_add_ps(bufA[30], bufTemp[18]); 933 | bufTemp[19] = _mm512_add_ps(bufA[31], bufTemp[19]); 934 | bufTemp[20] = _mm512_add_ps(bufA[32], bufTemp[20]); 935 | bufTemp[21] = _mm512_add_ps(bufA[33], bufTemp[21]); 936 | bufTemp[22] = _mm512_add_ps(bufA[34], bufTemp[22]); 937 | bufTemp[23] = _mm512_add_ps(bufA[35], bufTemp[23]); 938 | 939 | bufB = _mm512_add_ps(bufTemp[0], bufTemp[1]); 940 | bufB = _mm512_add_ps(bufB, bufTemp[2]); 941 | bufB = _mm512_add_ps(bufB, bufTemp[3]); 942 | bufB = _mm512_add_ps(bufB, bufTemp[4]); 943 | 944 | bufC = _mm512_sub_ps(bufTemp[1], bufTemp[2]); 945 | bufC = _mm512_fmadd_ps(m2, bufTemp[3], bufC); 946 
| bufC = _mm512_fnmadd_ps(m2, bufTemp[4], bufC); 947 | 948 | bufD = _mm512_add_ps(bufTemp[1], bufTemp[2]); 949 | bufD = _mm512_fmadd_ps(m4, bufTemp[3], bufD); 950 | bufD = _mm512_fmadd_ps(m4, bufTemp[4], bufD); 951 | 952 | bufE = _mm512_sub_ps(bufTemp[1], bufTemp[2]); 953 | bufE = _mm512_fmadd_ps(m8, bufTemp[3], bufE); 954 | bufE = _mm512_fnmadd_ps(m8, bufTemp[4], bufE); 955 | bufE = _mm512_add_ps(bufTemp[5], bufE); 956 | 957 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufD); 958 | bufG = _mm512_permutex2var_ps(bufC, idx0, bufE); 959 | bufH = _mm512_permutex2var_ps(bufB, idx1, bufD); 960 | bufI = _mm512_permutex2var_ps(bufC, idx1, bufE); 961 | 962 | bufB = _mm512_permutex2var_ps(bufF, idx0, bufG); 963 | bufC = _mm512_permutex2var_ps(bufF, idx1, bufG); 964 | bufD = _mm512_permutex2var_ps(bufH, idx0, bufI); 965 | bufE = _mm512_permutex2var_ps(bufH, idx1, bufI); 966 | 967 | _mm512_store_ps(dataDst + (x + 0) * nrows + y + 0, bufB); 968 | _mm512_store_ps(dataDst + (x + 0) * nrows + y + 16, bufC); 969 | _mm512_store_ps(dataDst + (x + 0) * nrows + y + 32, bufD); 970 | _mm512_store_ps(dataDst + (x + 0) * nrows + y + 48, bufE); 971 | 972 | bufB = _mm512_add_ps(bufTemp[6], bufTemp[7]); 973 | bufB = _mm512_add_ps(bufB, bufTemp[8]); 974 | bufB = _mm512_add_ps(bufB, bufTemp[9]); 975 | bufB = _mm512_add_ps(bufB, bufTemp[10]); 976 | 977 | bufC = _mm512_sub_ps(bufTemp[7], bufTemp[8]); 978 | bufC = _mm512_fmadd_ps(m2, bufTemp[9], bufC); 979 | bufC = _mm512_fnmadd_ps(m2, bufTemp[10], bufC); 980 | 981 | bufD = _mm512_add_ps(bufTemp[7], bufTemp[8]); 982 | bufD = _mm512_fmadd_ps(m4, bufTemp[9], bufD); 983 | bufD = _mm512_fmadd_ps(m4, bufTemp[10], bufD); 984 | 985 | bufE = _mm512_sub_ps(bufTemp[7], bufTemp[8]); 986 | bufE = _mm512_fmadd_ps(m8, bufTemp[9], bufE); 987 | bufE = _mm512_fnmadd_ps(m8, bufTemp[10], bufE); 988 | bufE = _mm512_add_ps(bufTemp[11], bufE); 989 | 990 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufD); 991 | bufG = _mm512_permutex2var_ps(bufC, idx0, bufE); 992 | bufH 
= _mm512_permutex2var_ps(bufB, idx1, bufD); 993 | bufI = _mm512_permutex2var_ps(bufC, idx1, bufE); 994 | 995 | bufB = _mm512_permutex2var_ps(bufF, idx0, bufG); 996 | bufC = _mm512_permutex2var_ps(bufF, idx1, bufG); 997 | bufD = _mm512_permutex2var_ps(bufH, idx0, bufI); 998 | bufE = _mm512_permutex2var_ps(bufH, idx1, bufI); 999 | 1000 | _mm512_store_ps(dataDst + (x + 1) * nrows + y + 0, bufB); 1001 | _mm512_store_ps(dataDst + (x + 1) * nrows + y + 16, bufC); 1002 | _mm512_store_ps(dataDst + (x + 1) * nrows + y + 32, bufD); 1003 | _mm512_store_ps(dataDst + (x + 1) * nrows + y + 48, bufE); 1004 | 1005 | bufB = _mm512_add_ps(bufTemp[12], bufTemp[13]); 1006 | bufB = _mm512_add_ps(bufB, bufTemp[14]); 1007 | bufB = _mm512_add_ps(bufB, bufTemp[15]); 1008 | bufB = _mm512_add_ps(bufB, bufTemp[16]); 1009 | 1010 | bufC = _mm512_sub_ps(bufTemp[13], bufTemp[14]); 1011 | bufC = _mm512_fmadd_ps(m2, bufTemp[15], bufC); 1012 | bufC = _mm512_fnmadd_ps(m2, bufTemp[16], bufC); 1013 | 1014 | bufD = _mm512_add_ps(bufTemp[13], bufTemp[14]); 1015 | bufD = _mm512_fmadd_ps(m4, bufTemp[15], bufD); 1016 | bufD = _mm512_fmadd_ps(m4, bufTemp[16], bufD); 1017 | 1018 | bufE = _mm512_sub_ps(bufTemp[13], bufTemp[14]); 1019 | bufE = _mm512_fmadd_ps(m8, bufTemp[15], bufE); 1020 | bufE = _mm512_fnmadd_ps(m8, bufTemp[16], bufE); 1021 | bufE = _mm512_add_ps(bufTemp[17], bufE); 1022 | 1023 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufD); 1024 | bufG = _mm512_permutex2var_ps(bufC, idx0, bufE); 1025 | bufH = _mm512_permutex2var_ps(bufB, idx1, bufD); 1026 | bufI = _mm512_permutex2var_ps(bufC, idx1, bufE); 1027 | 1028 | bufB = _mm512_permutex2var_ps(bufF, idx0, bufG); 1029 | bufC = _mm512_permutex2var_ps(bufF, idx1, bufG); 1030 | bufD = _mm512_permutex2var_ps(bufH, idx0, bufI); 1031 | bufE = _mm512_permutex2var_ps(bufH, idx1, bufI); 1032 | 1033 | _mm512_store_ps(dataDst + (x + 2) * nrows + y + 0, bufB); 1034 | _mm512_store_ps(dataDst + (x + 2) * nrows + y + 16, bufC); 1035 | _mm512_store_ps(dataDst + (x + 
2) * nrows + y + 32, bufD); 1036 | _mm512_store_ps(dataDst + (x + 2) * nrows + y + 48, bufE); 1037 | 1038 | bufB = _mm512_add_ps(bufTemp[18], bufTemp[19]); 1039 | bufB = _mm512_add_ps(bufB, bufTemp[20]); 1040 | bufB = _mm512_add_ps(bufB, bufTemp[21]); 1041 | bufB = _mm512_add_ps(bufB, bufTemp[22]); 1042 | 1043 | bufC = _mm512_sub_ps(bufTemp[19], bufTemp[20]); 1044 | bufC = _mm512_fmadd_ps(m2, bufTemp[21], bufC); 1045 | bufC = _mm512_fnmadd_ps(m2, bufTemp[22], bufC); 1046 | 1047 | bufD = _mm512_add_ps(bufTemp[19], bufTemp[20]); 1048 | bufD = _mm512_fmadd_ps(m4, bufTemp[21], bufD); 1049 | bufD = _mm512_fmadd_ps(m4, bufTemp[22], bufD); 1050 | 1051 | bufE = _mm512_sub_ps(bufTemp[19], bufTemp[20]); 1052 | bufE = _mm512_fmadd_ps(m8, bufTemp[21], bufE); 1053 | bufE = _mm512_fnmadd_ps(m8, bufTemp[22], bufE); 1054 | bufE = _mm512_add_ps(bufTemp[23], bufE); 1055 | 1056 | bufF = _mm512_permutex2var_ps(bufB, idx0, bufD); 1057 | bufG = _mm512_permutex2var_ps(bufC, idx0, bufE); 1058 | bufH = _mm512_permutex2var_ps(bufB, idx1, bufD); 1059 | bufI = _mm512_permutex2var_ps(bufC, idx1, bufE); 1060 | 1061 | bufB = _mm512_permutex2var_ps(bufF, idx0, bufG); 1062 | bufC = _mm512_permutex2var_ps(bufF, idx1, bufG); 1063 | bufD = _mm512_permutex2var_ps(bufH, idx0, bufI); 1064 | bufE = _mm512_permutex2var_ps(bufH, idx1, bufI); 1065 | 1066 | _mm512_store_ps(dataDst + (x + 3) * nrows + y + 0, bufB); 1067 | _mm512_store_ps(dataDst + (x + 3) * nrows + y + 16, bufC); 1068 | _mm512_store_ps(dataDst + (x + 3) * nrows + y + 32, bufD); 1069 | _mm512_store_ps(dataDst + (x + 3) * nrows + y + 48, bufE); 1070 | 1071 | *counter += 16; 1072 | } 1073 | 1074 | static inline void pad_out_transform(int x, int y, int lenX, int lenY, int nrows, const float *dataSrc, 1075 | float *temp, float *dataDst, int *counter) { 1076 | if (0 == lenX || 0 == lenY) { 1077 | return; 1078 | } 1079 | out_transform_4x3_16t(0, 0, 64, dataSrc, temp, counter); 1080 | /*for (int i = 0; i < 4; ++i) { 1081 | for (int j = 0; j < 64; 
++j) { 1082 | cout << temp[i * 64 + j] << ' '; 1083 | } 1084 | cout << endl; 1085 | }*/ 1086 | for (int i = 0; i < lenX; ++i) { 1087 | for (int j = 0; j < lenY; ++j) { 1088 | dataDst[(x + i) * nrows + y + j] = temp[i * 64 + j]; 1089 | } 1090 | } 1091 | } 1092 | 1093 | static inline void out_transform_4x3_1t(int x, int y, int nrows, const float *dataSrc, 1094 | float *dataDst, int *counter) { 1095 | int coter = *counter; 1096 | float c1[36]__attribute__((aligned(64))); 1097 | c1[0] = dataSrc[0 * OSTRIDE + coter]; 1098 | c1[1] = dataSrc[1 * OSTRIDE + coter]; 1099 | c1[2] = dataSrc[2 * OSTRIDE + coter]; 1100 | c1[3] = dataSrc[3 * OSTRIDE + coter]; 1101 | c1[4] = dataSrc[4 * OSTRIDE + coter]; 1102 | c1[5] = dataSrc[5 * OSTRIDE + coter]; 1103 | c1[6] = dataSrc[6 * OSTRIDE + coter]; 1104 | c1[7] = dataSrc[7 * OSTRIDE + coter]; 1105 | c1[8] = dataSrc[8 * OSTRIDE + coter]; 1106 | c1[9] = dataSrc[9 * OSTRIDE + coter]; 1107 | c1[10] = dataSrc[10 * OSTRIDE + coter]; 1108 | c1[11] = dataSrc[11 * OSTRIDE + coter]; 1109 | c1[12] = dataSrc[12 * OSTRIDE + coter]; 1110 | c1[13] = dataSrc[13 * OSTRIDE + coter]; 1111 | c1[14] = dataSrc[14 * OSTRIDE + coter]; 1112 | c1[15] = dataSrc[15 * OSTRIDE + coter]; 1113 | c1[16] = dataSrc[16 * OSTRIDE + coter]; 1114 | c1[17] = dataSrc[17 * OSTRIDE + coter]; 1115 | c1[18] = dataSrc[18 * OSTRIDE + coter]; 1116 | c1[19] = dataSrc[19 * OSTRIDE + coter]; 1117 | c1[20] = dataSrc[20 * OSTRIDE + coter]; 1118 | c1[21] = dataSrc[21 * OSTRIDE + coter]; 1119 | c1[22] = dataSrc[22 * OSTRIDE + coter]; 1120 | c1[23] = dataSrc[23 * OSTRIDE + coter]; 1121 | c1[24] = dataSrc[24 * OSTRIDE + coter]; 1122 | c1[25] = dataSrc[25 * OSTRIDE + coter]; 1123 | c1[26] = dataSrc[26 * OSTRIDE + coter]; 1124 | c1[27] = dataSrc[27 * OSTRIDE + coter]; 1125 | c1[28] = dataSrc[28 * OSTRIDE + coter]; 1126 | c1[29] = dataSrc[29 * OSTRIDE + coter]; 1127 | c1[30] = dataSrc[30 * OSTRIDE + coter]; 1128 | c1[31] = dataSrc[31 * OSTRIDE + coter]; 1129 | c1[32] = dataSrc[32 * OSTRIDE + 
coter]; 1130 | c1[33] = dataSrc[33 * OSTRIDE + coter]; 1131 | c1[34] = dataSrc[34 * OSTRIDE + coter]; 1132 | c1[35] = dataSrc[35 * OSTRIDE + coter]; 1133 | 1134 | float temp[24]__attribute__((aligned(64))); 1135 | temp[0] = c1[0] + c1[6] + c1[12] + c1[18] + c1[24]; 1136 | temp[1] = c1[1] + c1[7] + c1[13] + c1[19] + c1[25]; 1137 | temp[2] = c1[2] + c1[8] + c1[14] + c1[20] + c1[26]; 1138 | temp[3] = c1[3] + c1[9] + c1[15] + c1[21] + c1[27]; 1139 | temp[4] = c1[4] + c1[10] + c1[16] + c1[22] + c1[28]; 1140 | temp[5] = c1[5] + c1[11] + c1[17] + c1[23] + c1[29]; 1141 | temp[6] = c1[6] - c1[12] + 2 * c1[18] - 2 * c1[24]; 1142 | temp[7] = c1[7] - c1[13] + 2 * c1[19] - 2 * c1[25]; 1143 | temp[8] = c1[8] - c1[14] + 2 * c1[20] - 2 * c1[26]; 1144 | temp[9] = c1[9] - c1[15] + 2 * c1[21] - 2 * c1[27]; 1145 | temp[10] = c1[10] - c1[16] + 2 * c1[22] - 2 * c1[28]; 1146 | temp[11] = c1[11] - c1[17] + 2 * c1[23] - 2 * c1[29]; 1147 | temp[12] = c1[6] + c1[12] + 4 * c1[18] + 4 * c1[24]; 1148 | temp[13] = c1[7] + c1[13] + 4 * c1[19] + 4 * c1[25]; 1149 | temp[14] = c1[8] + c1[14] + 4 * c1[20] + 4 * c1[26]; 1150 | temp[15] = c1[9] + c1[15] + 4 * c1[21] + 4 * c1[27]; 1151 | temp[16] = c1[10] + c1[16] + 4 * c1[22] + 4 * c1[28]; 1152 | temp[17] = c1[11] + c1[17] + 4 * c1[23] + 4 * c1[29]; 1153 | temp[18] = c1[6] - c1[12] + 8 * c1[18] - 8 * c1[24] + c1[30]; 1154 | temp[19] = c1[7] - c1[13] + 8 * c1[19] - 8 * c1[25] + c1[31]; 1155 | temp[20] = c1[8] - c1[14] + 8 * c1[20] - 8 * c1[26] + c1[32]; 1156 | temp[21] = c1[9] - c1[15] + 8 * c1[21] - 8 * c1[27] + c1[33]; 1157 | temp[22] = c1[10] - c1[16] + 8 * c1[22] - 8 * c1[28] + c1[34]; 1158 | temp[23] = c1[11] - c1[17] + 8 * c1[23] - 8 * c1[29] + c1[35]; 1159 | 1160 | dataDst[(x + 0) * nrows + y] = temp[0] + temp[1] + temp[2] + temp[3] + temp[4]; 1161 | dataDst[(x + 0) * nrows + y + 1] = temp[1] - temp[2] + 2 * temp[3] - 2 * temp[4]; 1162 | dataDst[(x + 0) * nrows + y + 2] = temp[1] + temp[2] + 4 * temp[3] + 4 * temp[4]; 1163 | dataDst[(x + 0) * 
nrows + y + 3] = temp[1] - temp[2] + 8 * temp[3] - 8 * temp[4] + temp[5]; 1164 | dataDst[(x + 1) * nrows + y] = temp[6] + temp[7] + temp[8] + temp[9] + temp[10]; 1165 | dataDst[(x + 1) * nrows + y + 1] = temp[7] - temp[8] + 2 * temp[9] - 2 * temp[10]; 1166 | dataDst[(x + 1) * nrows + y + 2] = temp[7] + temp[8] + 4 * temp[9] + 4 * temp[10]; 1167 | dataDst[(x + 1) * nrows + y + 3] = temp[7] - temp[8] + 8 * temp[9] - 8 * temp[10] + temp[11]; 1168 | dataDst[(x + 2) * nrows + y] = temp[12] + temp[13] + temp[14] + temp[15] + temp[16]; 1169 | dataDst[(x + 2) * nrows + y + 1] = temp[13] - temp[14] + 2 * temp[15] - 2 * temp[16]; 1170 | dataDst[(x + 2) * nrows + y + 2] = temp[13] + temp[14] + 4 * temp[15] + 4 * temp[16]; 1171 | dataDst[(x + 2) * nrows + y + 3] = temp[13] - temp[14] + 8 * temp[15] - 8 * temp[16] + temp[17]; 1172 | dataDst[(x + 3) * nrows + y] = temp[18] + temp[19] + temp[20] + temp[21] + temp[22]; 1173 | dataDst[(x + 3) * nrows + y + 1] = temp[19] - temp[20] + 2 * temp[21] - 2 * temp[22]; 1174 | dataDst[(x + 3) * nrows + y + 2] = temp[19] + temp[20] + 4 * temp[21] + 4 * temp[22]; 1175 | dataDst[(x + 3) * nrows + y + 3] = temp[19] - temp[20] + 8 * temp[21] - 8 * temp[22] + temp[23]; 1176 | 1177 | (*counter)++; 1178 | } 1179 | 1180 | static void get_tiles_4x3(const float* restrict image, const int ldi, const int irows, const int icols, 1181 | const int sizeI, const int C, float* restrict otile, const int N, const int ntiles, const int M) { 1182 | int outHeight = irows - 2; 1183 | int outWidth = icols - 2; 1184 | int fullOutHeight = outHeight / 4 * 4; 1185 | int fullOutWidth = outWidth / 64 * 64; 1186 | 1187 | //cout << "get tiles " << ntiles << ' ' << N * C << endl; 1188 | #pragma omp parallel for 1189 | for (int t = 0; t < N * C; ++t) { 1190 | int i, j; 1191 | 1192 | const int t1 = t / (C * M); 1193 | const int t2 = (t % (C * M)) / M; 1194 | const int t3 = t % M; 1195 | 1196 | const float *data = image + (t1 * M * C + t3 * C + t2) * sizeI; 1197 | int 
tile_count = t * ntiles; 1198 | 1199 | const int num16t = (icols - 2) / 64 * 64; 1200 | 1201 | float temp[6 * 66]__attribute__((aligned(64))); 1202 | for (i = 0; i < fullOutHeight; i += 4) { 1203 | for (j = 0; j < fullOutWidth; j += 64) { 1204 | get_tiles_4x3_16t(i, j, ldi, data, otile, &tile_count); 1205 | } 1206 | pad_get_tiles(i, j, 6, outWidth - fullOutWidth + 2, ldi, data, temp, otile, &tile_count); 1207 | } 1208 | for (j = 0; j < fullOutWidth; j += 64) { 1209 | pad_get_tiles(i, j, outHeight - fullOutHeight + 2, 66, ldi, data, temp, otile, &tile_count); 1210 | } 1211 | pad_get_tiles(i, j, outHeight - fullOutHeight + 2, outWidth - fullOutWidth + 2, ldi, data, temp, otile, &tile_count); 1212 | } 1213 | 1214 | /*for (int i = 0; i < 36; ++i) { 1215 | for (int j = 0; j < 32; ++j) { 1216 | cout << otile[i * ISTRIDE + j] << ' '; 1217 | } 1218 | cout << endl; 1219 | }*/ 1220 | /* for (i = 0; i < irows - 4; i += 4) { 1221 | for (j = 0; j < num16t; j += 64) { 1222 | get_tiles_4x3_16t(i, j, ldi, data, otile, &tile_count); 1223 | } 1224 | #pragma simd 1225 | for (; j < (icols - 4); j += 4) { 1226 | get_tiles_4x3_1t(i, j, ldi, data, otile, &tile_count); 1227 | } 1228 | 1229 | }*/ 1230 | 1231 | } 1232 | 1233 | static void batched_gemm_4x3(const float* image, const int irows, const int icols, const float* filter, const int frows, const int fcols, float* restrict out, const int batch) { 1234 | int t, i; 1235 | const char trans = 'n'; 1236 | const float alpha = 1.0; 1237 | const float beta = 0.0; 1238 | const int ldi = irows; 1239 | const int ldf = frows; 1240 | const int ldo = irows; 1241 | 1242 | //cout << "batched_gemm " << 36 * batch << ' ' << ISTRIDE << ' ' << OSTRIDE << ' ' << irows << ' ' << fcols << ' ' << icols << endl; 1243 | #pragma omp parallel for collapse(2) private(t, i) 1244 | for (i = 0; i < 36; ++i) { 1245 | for (t = 0; t < batch; ++t) { 1246 | const float* im = image + i * ISTRIDE + t * irows * icols; 1247 | const float* fi = filter + i * FSTRIDE; 1248 | 
float *ot = out + i * OSTRIDE + t * irows * fcols; 1249 | 1250 | sgemm(&trans, &trans, &irows, &fcols, &icols, &alpha, im, &ldi, fi, &ldf, &beta, ot, &ldo); 1251 | } 1252 | } 1253 | 1254 | /*for (int i = 0; i < 36; ++i) { 1255 | for (int j = 0; j < 16; ++j) { 1256 | cout << out[i * OSTRIDE + j] << ' '; 1257 | } 1258 | cout << endl; 1259 | }*/ 1260 | } 1261 | 1262 | static void out_transform_4x3(const float* restrict d, const int K, const int ntiles, float* restrict out, const int ldo, const int oH, const int oW, const int N, const int M) { 1263 | int t; 1264 | int sizeO = oH * oW; 1265 | const int OHP = oH / 4 * 4; 1266 | const int OWP = oW / 4 * 4; 1267 | 1268 | //cout << "out transform " << N * K << endl; 1269 | #pragma omp parallel for private(t) 1270 | for (t = 0; t < N * K; ++t) { 1271 | int i, j; 1272 | 1273 | const int t1 = t / (K * M); 1274 | const int t2 = (t % (K * M)) / M; 1275 | const int t3 = t % M; 1276 | 1277 | float *data = out + (t1 * M * K + t3 * K + t2) * sizeO; 1278 | int tile_offset = t * ntiles; 1279 | const int num16t = oW / 64 * 64; 1280 | float temp[4 * 64]__attribute__((aligned(64))); 1281 | for (i = 0; i < OHP; i += 4) { 1282 | for (j = 0; j < num16t; j += 64) { 1283 | out_transform_4x3_16t(i, j, ldo, d, data, &tile_offset); 1284 | } 1285 | pad_out_transform(i, j, 4, oW - j, ldo, d, temp, data, &tile_offset); 1286 | } 1287 | for (j = 0; j < num16t; j += 64) { 1288 | pad_out_transform(i, j, oH - i, 64, ldo, d, temp, data, &tile_offset); 1289 | } 1290 | pad_out_transform(i, j, oH - i, oW - j, ldo, d, temp, data, &tile_offset); 1291 | } 1292 | /* #pragma simd 1293 | for (; j < OWP; j += 4) { 1294 | out_transform_4x3_1t(i, j, ldo, d, data, &tile_offset); 1295 | }*/ 1296 | 1297 | } 1298 | 1299 | void winconv_2x3(const int bblock, const int M, float* restrict image, const int irows, const int icols, 1300 | const int C, float* restrict filter, const int K, const int batch, 1301 | float* restrict out) { 1302 | const int outHeight = irows - 2; 
1303 | const int outWidth = icols - 2; 1304 | const int sizeI = irows * icols; 1305 | const int tiles = (outHeight) * 0.25 * (outWidth) * 0.25; 1306 | const int padHeight = (outHeight + 3) / 4 * 4; 1307 | const int padWidth = (outWidth + 63) / 64 * 64; 1308 | const int padTiles = padHeight / 4 * padWidth / 4; 1309 | float *b_image; 1310 | float *b_out; 1311 | const int b_batchSize = 64; 1312 | 1313 | 1314 | filter_transform_4x3(filter, C, K, t_filter); 1315 | 1316 | switch(bblock) { 1317 | case BATCH_TOGETHER: 1318 | timeval begin, end; 1319 | double elapse_time; 1320 | 1321 | int temp1 = ISTRIDE; 1322 | int temp2 = OSTRIDE; 1323 | 1324 | //ISTRIDE = batch * padTiles * C + 128; 1325 | //OSTRIDE = batch * padTiles * K + 128; 1326 | //gettimeofday(&begin, NULL); 1327 | get_tiles_4x3(image, icols, irows, icols, sizeI, C, t_image, batch, padTiles, M); 1328 | //gettimeofday(&end, NULL); 1329 | //elapse_time = (end.tv_sec - begin.tv_sec) * 1e3 + (end.tv_usec - begin.tv_usec) * 1e-3; 1330 | //cout << "get tiles time = " << elapse_time << endl; 1331 | 1332 | //gettimeofday(&begin, NULL); 1333 | batched_gemm_4x3(t_image, M * padTiles, C, t_filter, C, K, c_out, batch / M); 1334 | //gettimeofday(&end, NULL); 1335 | //elapse_time = (end.tv_sec - begin.tv_sec) * 1e3 + (end.tv_usec - begin.tv_usec) * 1e-3; 1336 | //cout << "gemm time = " << elapse_time << endl; 1337 | 1338 | 1339 | //gettimeofday(&begin, NULL); 1340 | out_transform_4x3(c_out, K, padTiles, out, outWidth, outHeight, outWidth, batch, M); 1341 | //gettimeofday(&end, NULL); 1342 | //elapse_time = (end.tv_sec - begin.tv_sec) * 1e3 + (end.tv_usec - begin.tv_usec) * 1e-3; 1343 | //cout << "out_transform time = " << elapse_time << endl << endl; 1344 | //ISTRIDE = temp1; 1345 | //OSTRIDE = temp2; 1346 | break; 1347 | case BATCH_BLOCK: 1348 | //ISTRIDE = batch * padTiles * C + 128; 1349 | //OSTRIDE = batch * padTiles * K + 128; 1350 | for (int i = 0; i < batch; i += b_batchSize) { 1351 | b_image = image + i * C * irows * 
icols; 1352 | b_out = out + i * K * outHeight * outWidth; 1353 | get_tiles_4x3(b_image, icols, irows, icols, sizeI, C, t_image, b_batchSize, padTiles, M); 1354 | batched_gemm_4x3(t_image, M * padTiles, C, t_filter, C, K, c_out, b_batchSize / M); 1355 | out_transform_4x3(c_out, K, padTiles, b_out, outWidth, outHeight, outWidth, b_batchSize, M); 1356 | } 1357 | break; 1358 | } 1359 | } 1360 | 1361 | /*int main() { 1362 | float *dataSrc = (float*) _mm_malloc(6 * 66 * sizeof(float), 64); 1363 | float *dataDst = (float*) _mm_malloc(36 * 16 * sizeof(float), 64); 1364 | int counter = 0; 1365 | for (int i = 0; i < 6; ++i) { 1366 | for (int j = 0; j < 66; ++j) { 1367 | dataSrc[i * 66 + j] = i * 66 + j; 1368 | } 1369 | } 1370 | get_tiles_4x3_1t(0, 0, 66, dataSrc, dataDst, &counter); 1371 | for (int i = 0; i < 36; ++i) { 1372 | cout << dataDst[i * ISTRIDE] << endl; 1373 | } 1374 | 1375 | float *filterSrc = (float*) _mm_malloc(9 * sizeof(float), 64); 1376 | float *filterDst = (float*) _mm_malloc(36 * sizeof(float), 64); 1377 | for (int i = 0; i < 9; ++i) { 1378 | filterSrc[i] = i + 1; 1379 | } 1380 | filter_transform_4x3(filterSrc, 1, 1, filterDst); 1381 | 1382 | float *out = (float*) _mm_malloc(36 * 16 * sizeof(float), 64); 1383 | float *outDst = (float*) _mm_malloc(4 * 64 * sizeof(float), 64); 1384 | 1385 | int t, i; 1386 | const char trans = 'n'; 1387 | const float alpha = 1.0; 1388 | const float beta = 0.0; 1389 | 1390 | int irows = 16; 1391 | int frows = 1; 1392 | int fcols = 1; 1393 | int icols = 1; 1394 | const int ldi = irows; 1395 | const int ldf = frows; 1396 | const int ldo = irows; 1397 | 1398 | for (i = 0; i < 36; ++i) { 1399 | const float* im = dataDst + i * ISTRIDE; 1400 | const float* fi = filterDst + i * FSTRIDE; 1401 | float *ot = out + i * OSTRIDE; 1402 | 1403 | sgemm(&trans, &trans, &irows, &fcols, &icols, &alpha, im, 1404 | &ldi, fi, &ldf, &beta, ot, &ldo); 1405 | } 1406 | 1407 | int tile_offset = 0; 1408 | out_transform_4x3_16t(0, 0, ldo, out, outDst, 
&tile_offset); 1409 | 1410 | cout << outDst[0] << ' ' << outDst[1] << ' ' << outDst[16] << endl; 1411 | 1412 | _mm_free(out); 1413 | _mm_free(outDst); 1414 | _mm_free(filterSrc); 1415 | _mm_free(filterDst); 1416 | _mm_free(dataSrc); 1417 | _mm_free(dataDst); 1418 | return 0; 1419 | } 1420 | */ 1421 | --------------------------------------------------------------------------------