├── .gitattributes ├── .gitignore ├── README.md ├── github_winograd.sln └── github_winograd ├── github_winograd.vcxproj ├── github_winograd.vcxproj.filters ├── include ├── mathlib.h ├── tool.h ├── winograd_kernel.h └── winograd_layer.h └── src └── winograd_test.cpp /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. 
To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | 10 | # User-specific files (MonoDevelop/Xamarin Studio) 11 | *.userprefs 12 | 13 | # Build results 14 | [Dd]ebug/ 15 | [Dd]ebugPublic/ 16 | [Rr]elease/ 17 | [Rr]eleases/ 18 | [Xx]64/ 19 | [Xx]86/ 20 | [Bb]uild/ 21 | bld/ 22 | [Bb]in/ 23 | [Oo]bj/ 24 | 25 | # Visual Studio 2015 cache/options directory 26 | .vs/ 27 | # Uncomment if you have tasks that create the project's static files in wwwroot 28 | #wwwroot/ 29 | 30 | # MSTest test Results 31 | [Tt]est[Rr]esult*/ 32 | [Bb]uild[Ll]og.* 33 | 34 | # NUNIT 35 | *.VisualState.xml 36 | TestResult.xml 37 | 38 | # Build Results of an ATL Project 39 | [Dd]ebugPS/ 40 | [Rr]eleasePS/ 41 | dlldata.c 42 | 43 | # DNX 44 | project.lock.json 45 | artifacts/ 46 | 47 | *_i.c 48 | *_p.c 49 | *_i.h 50 | *.ilk 51 | *.meta 52 | *.obj 53 | *.pch 54 | *.pdb 55 | *.pgc 56 | *.pgd 57 | *.rsp 58 | *.sbr 59 | *.tlb 60 | *.tli 61 | *.tlh 62 | *.tmp 63 | *.tmp_proj 64 | *.log 65 | *.vspscc 66 | *.vssscc 67 | .builds 68 | *.pidb 69 | *.svclog 70 | *.scc 71 | 72 | # Chutzpah Test files 73 | _Chutzpah* 74 | 75 | # Visual C++ cache files 76 | ipch/ 77 | *.aps 78 | *.ncb 79 | *.opendb 80 | *.opensdf 81 | *.sdf 82 | *.cachefile 83 | *.VC.db 84 | 85 | # Visual Studio profiler 86 | *.psess 87 | *.vsp 88 | *.vspx 89 | *.sap 90 | 91 | # TFS 2012 Local Workspace 92 | $tf/ 93 | 94 | # Guidance Automation Toolkit 95 | *.gpState 96 | 97 | # ReSharper is a .NET coding add-in 98 | _ReSharper*/ 99 | *.[Rr]e[Ss]harper 100 | *.DotSettings.user 101 | 102 | # JustCode is a .NET coding add-in 103 | .JustCode 104 | 105 | # TeamCity is a build add-in 106 | _TeamCity* 107 | 108 | # DotCover is a Code Coverage Tool 109 | *.dotCover 110 | 111 | # NCrunch 112 | _NCrunch_* 113 | .*crunch*.local.xml 114 | nCrunchTemp_* 115 | 116 | # MightyMoose 117 | *.mm.* 118 | AutoTest.Net/ 119 | 120 | # Web workbench (sass) 121 | .sass-cache/ 122 | 123 | # Installshield output folder 124 | 
[Ee]xpress/ 125 | 126 | # DocProject is a documentation generator add-in 127 | DocProject/buildhelp/ 128 | DocProject/Help/*.HxT 129 | DocProject/Help/*.HxC 130 | DocProject/Help/*.hhc 131 | DocProject/Help/*.hhk 132 | DocProject/Help/*.hhp 133 | DocProject/Help/Html2 134 | DocProject/Help/html 135 | 136 | # Click-Once directory 137 | publish/ 138 | 139 | # Publish Web Output 140 | *.[Pp]ublish.xml 141 | *.azurePubxml 142 | 143 | # TODO: Un-comment the next line if you do not want to checkin 144 | # your web deploy settings because they may include unencrypted 145 | # passwords 146 | #*.pubxml 147 | *.publishproj 148 | 149 | # NuGet Packages 150 | *.nupkg 151 | # The packages folder can be ignored because of Package Restore 152 | **/packages/* 153 | # except build/, which is used as an MSBuild target. 154 | !**/packages/build/ 155 | # Uncomment if necessary however generally it will be regenerated when needed 156 | #!**/packages/repositories.config 157 | # NuGet v3's project.json files produces more ignoreable files 158 | *.nuget.props 159 | *.nuget.targets 160 | 161 | # Microsoft Azure Build Output 162 | csx/ 163 | *.build.csdef 164 | 165 | # Microsoft Azure Emulator 166 | ecf/ 167 | rcf/ 168 | 169 | # Windows Store app package directory 170 | AppPackages/ 171 | BundleArtifacts/ 172 | 173 | # Visual Studio cache files 174 | # files ending in .cache can be ignored 175 | *.[Cc]ache 176 | # but keep track of directories ending in .cache 177 | !*.[Cc]ache/ 178 | 179 | # Others 180 | ClientBin/ 181 | [Ss]tyle[Cc]op.* 182 | ~$* 183 | *~ 184 | *.dbmdl 185 | *.dbproj.schemaview 186 | *.pfx 187 | *.publishsettings 188 | node_modules/ 189 | orleans.codegen.cs 190 | 191 | # RIA/Silverlight projects 192 | Generated_Code/ 193 | 194 | # Backup & report files from converting an old project file 195 | # to a newer Visual Studio version. 
Backup files are not needed, 196 | # because we have git ;-) 197 | _UpgradeReport_Files/ 198 | Backup*/ 199 | UpgradeLog*.XML 200 | UpgradeLog*.htm 201 | 202 | # SQL Server files 203 | *.mdf 204 | *.ldf 205 | 206 | # Business Intelligence projects 207 | *.rdl.data 208 | *.bim.layout 209 | *.bim_*.settings 210 | 211 | # Microsoft Fakes 212 | FakesAssemblies/ 213 | 214 | # GhostDoc plugin setting file 215 | *.GhostDoc.xml 216 | 217 | # Node.js Tools for Visual Studio 218 | .ntvs_analysis.dat 219 | 220 | # Visual Studio 6 build log 221 | *.plg 222 | 223 | # Visual Studio 6 workspace options file 224 | *.opt 225 | 226 | # Visual Studio LightSwitch build output 227 | **/*.HTMLClient/GeneratedArtifacts 228 | **/*.DesktopClient/GeneratedArtifacts 229 | **/*.DesktopClient/ModelManifest.xml 230 | **/*.Server/GeneratedArtifacts 231 | **/*.Server/ModelManifest.xml 232 | _Pvt_Extensions 233 | 234 | # LightSwitch generated files 235 | GeneratedArtifacts/ 236 | ModelManifest.xml 237 | 238 | # Paket dependency manager 239 | .paket/paket.exe 240 | 241 | # FAKE - F# Make 242 | .fake/ 243 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Winograd_Convolution 2 | Winograd_Convolution is a Winograd-based kernel for convolutions in deep learning frameworks; it implements the Winograd convolution algorithms described in [1]. Three WT methods are supported: WT_6X6_F_4X4_3X3, WT_8X8_F_4X4_5X5, and WT_8X8_F_6X6_3X3. A 3x3 convolution kernel is the best choice. Parts of this work are taken from SkimCaffe [2], but this Winograd kernel is more portable. 3 | 4 | Dependencies 5 | ----------------------------------- 6 | 7 | A fast BLAS library is recommended, such as MKL (mkl-gemm) or OpenBLAS [3]. 8 | 9 | Building 10 | ----------------------------------- 11 | 12 | The library consists only of header files written in C++ and supports Windows and Linux. This version was built with VS 2015. 
13 | 14 | Testing 15 | ----------------------------------- 16 | 17 | See winograd_test.cpp. 18 | 19 | Packaging 20 | ----------------------------------- 21 | 22 | "include/winograd_layer.h", can be natively integrated into some famous deep learning frameworks as a winograd_layer, like caffe (https://github.com/BVLC/caffe) and tiny-dnn (https://github.com/tiny-dnn/tiny-dnn). 23 | 24 | References & Dependencies 25 | ----------------------------------- 26 | [1] Andrew Lavin, Scott Gray. Fast Algorithms for Convolutional Neural Networks. https://arxiv.org/abs/1509.09308 27 | 28 | [2] SkimCaffe, https://github.com/IntelLabs/SkimCaffe 29 | 30 | [3] OpenBLAS, https://github.com/xianyi/OpenBLAS. 31 | 32 | License 33 | ----------------------------------- 34 | The BSD 3-Clause License 35 | -------------------------------------------------------------------------------- /github_winograd.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 14 4 | VisualStudioVersion = 14.0.25420.1 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "github_winograd", "github_winograd\github_winograd.vcxproj", "{91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x64.ActiveCfg = Debug|x64 17 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x64.Build.0 = Debug|x64 18 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x86.ActiveCfg = Debug|Win32 19 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x86.Build.0 = Debug|Win32 20 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x64.ActiveCfg = Release|x64 21 | 
{91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x64.Build.0 = Release|x64 22 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x86.ActiveCfg = Release|Win32 23 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /github_winograd/github_winograd.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C} 32 | github_winograd 33 | 8.1 34 | Winograd_Convolution 35 | 36 | 37 | 38 | Application 39 | true 40 | v140 41 | MultiByte 42 | 43 | 44 | Application 45 | false 46 | v140 47 | true 48 | MultiByte 49 | 50 | 51 | Application 52 | true 53 | v140 54 | MultiByte 55 | 56 | 57 | Application 58 | false 59 | v140 60 | true 61 | MultiByte 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 83 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(NETFXKitsDir)Lib\um\x86 84 | 85 | 86 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 87 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(NETFXKitsDir)Lib\um\x86 88 | 89 | 90 | 
$(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(VC_IncludePath);$(WindowsSDK_IncludePath); 91 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64 92 | 93 | 94 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64 95 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 96 | 97 | 98 | 99 | Level3 100 | Disabled 101 | true 102 | 103 | 104 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 105 | 106 | 107 | 108 | 109 | Level3 110 | Disabled 111 | true 112 | 113 | 114 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 115 | 116 | 117 | 118 | 119 | Level3 120 | MaxSpeed 121 | true 122 | true 123 | true 124 | 125 | 126 | true 127 | true 128 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 129 | 130 | 131 | 132 | 133 | Level3 134 | MaxSpeed 135 | true 136 | true 137 | true 138 | 139 | 140 | true 141 | true 142 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /github_winograd/github_winograd.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 
{4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | Header Files 20 | 21 | 22 | Header Files 23 | 24 | 25 | Header Files 26 | 27 | 28 | Header Files 29 | 30 | 31 | 32 | 33 | Source Files 34 | 35 | 36 | -------------------------------------------------------------------------------- /github_winograd/include/mathlib.h: -------------------------------------------------------------------------------- 1 | #ifndef MATHLAB_H 2 | #define MATHLAB_H 3 | 4 | #define USE_MKL 0 5 | #define USE_OPENBLAS 1 6 | 7 | #if USE_MKL 8 | #include 9 | 10 | #elif USE_OPENBLAS 11 | #include 12 | 13 | #endif 14 | 15 | #endif -------------------------------------------------------------------------------- /github_winograd/include/tool.h: -------------------------------------------------------------------------------- 1 | #ifndef PUBLIC_TOOL_H 2 | #define PUBLIC_TOOL_H 3 | 4 | #include 5 | #include 6 | #include "mathlib.h" 7 | 8 | namespace PUBLIC_TOOL{ 9 | 10 | template 11 | Dtype max(Dtype a, Dtype b) { 12 | if (a > b) return a; 13 | else return b; 14 | } 15 | 16 | template 17 | Dtype min(Dtype a, Dtype b) { 18 | if (a < b) return a; 19 | else return b; 20 | } 21 | /* Single-precision row-major GEMM wrapper around cblas_sgemm: C = alpha*op(A)*op(B) + beta*C, where op(A) is M x K, op(B) is K x N and C is M x N with leading dimension N. lda/ldb are derived from the transpose flags. Declared inline so that this header can be included from several translation units without multiple-definition link errors. */ 22 | inline void dlm_cpu_gemm(const CBLAS_TRANSPOSE TransA, 23 | const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, 24 | const float alpha, const float* A, const float* B, const float beta, 25 | float* C) { 26 | int lda = (TransA == CblasNoTrans) ? K : M; 27 | int ldb = (TransB == CblasNoTrans) ? 
N : K; 28 | cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, 29 | ldb, beta, C, N); 30 | } 31 | /* Double-precision overload with the same contract, forwarding to cblas_dgemm. Also inline for header-only use. */ 32 | inline void dlm_cpu_gemm(const CBLAS_TRANSPOSE TransA, 33 | const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, 34 | const double alpha, const double* A, const double* B, const double beta, 35 | double* C) { 36 | int lda = (TransA == CblasNoTrans) ? K : M; 37 | int ldb = (TransB == CblasNoTrans) ? N : K; 38 | cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, 39 | ldb, beta, C, N); 40 | } 41 | 42 | 43 | }; 44 | 45 | #endif -------------------------------------------------------------------------------- /github_winograd/include/winograd_kernel.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef WINOGRAD_KERNEL_H 3 | #define WINOGRAD_KERNEL_H 4 | 5 | #include 6 | 7 | #define DEBUG_WINOGRAD 1 8 | 9 | #if DEBUG_WINOGRAD 10 | #include 11 | #endif 12 | 13 | namespace WINOGRAD_KERNEL { 14 | /* Selects one of the three Winograd transform matrices. A cv-qualifier on an enum declaration without a declarator is ill-formed, so the former leading 'const' was dropped. */ 15 | enum WINOGRAD_MATRIX { 16 | WINOGRAD_A = 0, 17 | WINOGRAD_B, 18 | WINOGRAD_G, 19 | }; 20 | enum WINOGRAD_ALG { 21 | WT_8X8_F_6X6_3X3 = 0, 22 | WT_6X6_F_4X4_3X3, 23 | WT_8X8_F_4X4_5X5, 24 | }; 25 | 26 | const int MATRIX_KINDS = 3; 27 | const int WINOGRAD_PAIR_KINDS = 3; 28 | 29 | template 30 | struct WinogradTransformMatrix {}; 31 | 32 | /** 33 | * compute Kronecker product of in1 and in2, where in1 is an m by n matrix and in2 is a p by q matrix; declared inline because the definition lives in this header 34 | * 35 | * @param out an (m*p) by (n*q) matrix stored in row major 36 | * @param in1 an m by n matrix stored in row major 37 | * @param in2 a p by q matrix stored in row major 38 | */ 39 | inline void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q); 40 | 41 | //singleton, precomputation before inference 42 | inline void winograd2D_initialize(); 43 | 44 | template<> 45 | struct WinogradTransformMatrix 46 | { 47 | // wt6x6, F(4x4,3x3) 48 | private: 49 | static const int O = 4; 50 | static const int 
K = 3; 51 | static const int T = O + K - 1; 52 | 53 | static const float *getG() { 54 | static const float G[T*K] = { 55 | 1. / 4., 0, 0, 56 | -1. / 6., -1. / 6., -1. / 6., 57 | -1. / 6., 1. / 6., -1. / 6., 58 | 1. / 24., 1. / 12., 1. / 6., 59 | 1. / 24., -1. / 12., 1. / 6., 60 | 0, 0, 1, 61 | }; 62 | return G; 63 | } 64 | 65 | static const float *getA() { 66 | static const float A[T*O] = { 67 | 1, 0, 0, 0, 68 | 1, 1, 1, 1, 69 | 1, -1, 1, -1, 70 | 1, 2, 4, 8, 71 | 1, -2, 4, -8, 72 | 0, 0, 0, 1, 73 | }; 74 | return A; 75 | } 76 | 77 | static const float *getB() { 78 | static const float B[T*T] = { 79 | 4, 0, 0, 0, 0, 0, 80 | 0, -4, 4, -2, 2, 4, 81 | -5, -4, -4, -1, -1, 0, 82 | 0, 1, -1, 2, -2, -5, 83 | 1, 1, 1, 1, 1, 0, 84 | 0, 0, 0, 0, 0, 1, 85 | }; 86 | return B; 87 | }; 88 | 89 | public: 90 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) { 91 | 92 | #if DEBUG_WINOGRAD 93 | assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G); 94 | #endif 95 | switch (mat) { 96 | 97 | case WINOGRAD_A: row = T; col = O; return getA(); 98 | case WINOGRAD_B: row = T; col = T; return getB(); 99 | case WINOGRAD_G: row = T; col = K; return getG(); 100 | 101 | } 102 | 103 | } 104 | 105 | }; 106 | 107 | template<> 108 | struct WinogradTransformMatrix 109 | { 110 | 111 | private: 112 | 113 | // wt8x8, F(6x6,3x3) 114 | 115 | static const int O = 6; 116 | static const int K = 3; 117 | static const int T = O + K - 1; 118 | 119 | public: 120 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) { 121 | 122 | #if DEBUG_WINOGRAD 123 | assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G); 124 | #endif 125 | switch (mat) { 126 | 127 | case WINOGRAD_A: row = T; col = O; return getA(); 128 | case WINOGRAD_B: row = T; col = T; return getB(); 129 | case WINOGRAD_G: row = T; col = K; return getG(); 130 | 131 | } 132 | 133 | } 134 | 135 | private: 136 | static const float *getG() { 137 | static const float G[T*K] = { 138 | 1.f, 0.f , 0.f , 139 | -2.f / 9 , -2.f / 9 , -2.f / 9, 140 | 
-2.f / 9 , 2.f / 9 , -2.f / 9, 141 | 1.f / 90 , 1.f / 45 , 2.f / 45, 142 | 1.f / 90 , -1.f / 45 , 2.f / 45, 143 | 32.f / 45, 16.f / 45, 8.f / 45, 144 | 32.f / 45, -16.f / 45, 8.f / 45, 145 | 0.f , 0.f , 1.f , 146 | }; 147 | return G; 148 | } 149 | 150 | static const float *getA() { 151 | static const float A[T*(T - K + 1)] = { 152 | 1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 153 | 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 154 | 1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f, 155 | 1 * 1.f, 2 * 1.f, 4 * 1.f, 8 * 1.f, 16 * 1.f, 32 * 1.f, 156 | 1 * 1.f, -2 * 1.f, 4 * 1.f, -8 * 1.f, 16 * 1.f, -32 * 1.f, 157 | 1 * 1.f, 0.5*1.f, 0.25*1.f, 0.125*1.f, 0.0625*1.f, 0.03125*1.f, 158 | 1 * 1.f, -0.5*1.f, 0.25*1.f, -0.125*1.f, 0.0625*1.f, -0.03125*1.f, 159 | 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f, 160 | }; 161 | return A; 162 | } 163 | 164 | static const float *getB() { 165 | static const float B[T*T] = { 166 | 1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 167 | 0 * 1.f, 1 * 1.f, -1 * 1.f, 0.5*1.f, -0.5*1.f, 2 * 1.f, -2 * 1.f, -1 * 1.f, 168 | -5.25*1.f, 1 * 1.f, 1 * 1.f, 0.25*1.f, 0.25*1.f, 4 * 1.f, 4 * 1.f, 0 * 1.f, 169 | 0 * 1.f, -4.25*1.f, 4.25*1.f, -2.5*1.f, 2.5*1.f, -2.5*1.f, 2.5*1.f, 5.25*1.f, 170 | 5.25*1.f, -4.25*1.f, -4.25*1.f, -1.25*1.f, -1.25*1.f, -5 * 1.f, -5 * 1.f, 0 * 1.f, 171 | 0 * 1.f, 1 * 1.f, -1 * 1.f, 2 * 1.f, -2 * 1.f, 0.5*1.f, -0.5*1.f, -5.25*1.f, 172 | -1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 0 * 1.f, 173 | 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f, 174 | }; 175 | return B; 176 | }; 177 | }; 178 | 179 | template<> 180 | struct WinogradTransformMatrix 181 | { 182 | 183 | private: 184 | // wt8x8, F(4x4,5x5) 185 | static const int T = 5 + 4 - 1; 186 | static const int K = 5; 187 | static const int O = 4; 188 | 189 | public: 190 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) { 191 | 192 | #if DEBUG_WINOGRAD 193 | 
assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G); 194 | #endif 195 | switch (mat) { 196 | 197 | case WINOGRAD_A: row = T; col = O; return getA(); 198 | case WINOGRAD_B: row = T; col = T; return getB(); 199 | case WINOGRAD_G: row = T; col = K; return getG(); 200 | 201 | } 202 | 203 | } 204 | 205 | private: 206 | 207 | // from https://github.com/Maratyszcza/NNPACK/issues/12 208 | 209 | static const float *getG() { 210 | static const float G[T*K] = { 211 | 1, 0, 0, 0, 0, 212 | -2. / 9., -2. / 9., -2. / 9., -2. / 9., -2. / 9., 213 | -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9., 214 | 1. / 90., 1. / 45., 2. / 45., 4. / 45., 8. / 45., 215 | 1. / 90., -1. / 45., 2. / 45., -4. / 45., 8. / 45., 216 | 4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180., 217 | 4. / 45., -2. / 45., 1. / 45., -1. / 90., 1. / 180., 218 | 0, 0, 0, 0, 1, 219 | }; 220 | return G; 221 | } 222 | 223 | 224 | 225 | 226 | static const float *getA() { 227 | static const float A[T*(O)] = { 228 | 1, 0, 0, 0, 229 | 1, 1, 1, 1, 230 | 1, -1, 1, -1, 231 | 1, 2, 4, 8, 232 | 1, -2, 4, -8, 233 | 8, 4, 2, 1, 234 | 8, -4, 2, -1, 235 | 0, 0, 0, 1 236 | }; 237 | return A; 238 | } 239 | 240 | static const float *getB() { 241 | static const float B[T*T] = { 242 | 1, 0, 0, 0, 0, 0, 0, 0, 243 | 0, 1, -1, 1. / 2, -1. / 2, 2, -2, -1, 244 | -21. / 4, 1, 1, 1. / 4, 1. / 4, 4, 4, 0, 245 | 0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2, 5. / 2, 21. / 4, 246 | 21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0, 247 | 0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. 
/ 4, 248 | -1, 1, 1, 1, 1, 1, 1, 0, 249 | 0, 0, 0, 0, 0, 0, 0, 1, 250 | }; 251 | return B; 252 | } 253 | }; 254 | /* Caches the Kronecker product of one Winograd transform matrix with itself. One instance exists per (algorithm, matrix) pair and is obtained through getInstance(); the product is computed on the first call to get(). */ 255 | class Winograd_Kron 256 | { 257 | 258 | private: 259 | 260 | Winograd_Kron(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) { 261 | 262 | isCalc = false; 263 | matrix = NULL; row = 0; col = 0; /* defined state if alg matches none of the cases below */ 264 | switch (alg) { 265 | 266 | case WT_8X8_F_6X6_3X3: 267 | matrix = WinogradTransformMatrix::get(mat, row, col); break; 268 | case WT_6X6_F_4X4_3X3: 269 | matrix = WinogradTransformMatrix::get(mat, row, col); break; 270 | case WT_8X8_F_4X4_5X5: 271 | matrix = WinogradTransformMatrix::get(mat, row, col); break; 272 | 273 | } 274 | 275 | } 276 | 277 | private: 278 | const float *matrix; // = A, B, G 279 | int row, col;// matrix: row*col 280 | // A: T*O 281 | // B: T*T 282 | // G: T*K 283 | 284 | std::shared_ptr kron; 285 | 286 | bool isCalc; 287 | 288 | //std::shared_ptr normOfRowsInv; 289 | 290 | public: 291 | 292 | static Winograd_Kron *getInstance(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) { 293 | 294 | // 9 instances 3*3 295 | static Winograd_Kron * instances[MATRIX_KINDS *WINOGRAD_PAIR_KINDS] = { NULL }; // indexed as [WINOGRAD_ALG][WINOGRAD_MATRIX] 296 | /* the row stride of an [alg][mat] layout is the number of matrix kinds */ 297 | int index = alg*MATRIX_KINDS + mat; 298 | 299 | if (instances[index] == NULL) 300 | instances[index] = new Winograd_Kron(alg, mat); 301 | 302 | return instances[index]; 303 | 304 | } 305 | 306 | const std::shared_ptr get() { 307 | if (isCalc) 308 | return kron; 309 | else { 310 | calc(); 311 | return kron; 312 | } 313 | 314 | } 315 | 316 | private: 317 | 318 | void calc() { 319 | 320 | kron = std::shared_ptr(new float[row*col*row*col]); 321 | 322 | kronecker_product(kron.get(), matrix, matrix, row, col, row, col); 323 | 324 | isCalc = true; 325 | 326 | } 327 | 328 | }; 329 | 330 | void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q) 331 | { 332 | for (int i = 0; i < m; ++i) { 333 | for (int j = 0; j < n; ++j) { 334 | for (int k = 0; k < p; ++k) { 335 | for (int l = 0; l < q; ++l) { 336 
| out[(p*i + k)*n*q + q*j + l] = in1[n*i + j] * in2[k*q + l]; 337 | /* out is row-major (m*p) x (n*q); each entry is a single float product of one in1 element and one in2 element */ 338 | } 339 | } 340 | } 341 | } 342 | } 343 | 344 | void winograd2D_initialize() { 345 | //precompute the cached Kronecker products of A, B and G for WT_6X6_F_4X4_3X3 and WT_8X8_F_6X6_3X3 before inference 346 | 347 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_A)->get(); 348 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_B)->get(); 349 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_G)->get(); 350 | 351 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_A)->get(); 352 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_B)->get(); 353 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_G)->get(); 354 | } 355 | } 356 | 357 | 358 | #endif -------------------------------------------------------------------------------- /github_winograd/include/winograd_layer.h: -------------------------------------------------------------------------------- 1 | #ifndef WINOGRAD_LAYER_H 2 | #define WINOGRAD_LAYER_H 3 | 4 | #include 5 | #include "winograd_kernel.h" 6 | #include "tool.h" 7 | //winograd for cpu inference 8 | 9 | // default 3x3 10 | const int KERNEL_SIZE = 3; 11 | 12 | namespace WINOGRAD_KERNEL 13 | { 14 | 15 | template 16 | class WinogradLayer { 17 | 18 | private: 19 | 20 | int m_group_; 21 | int m_batchSize; 22 | 23 | int m_bottom_dim_;// per-image offset applied to the input pointer in get_inference_cpu 24 | int m_top_dim_; 25 | 26 | // The following variables are initialized in the constructor 27 | int tile_h_in_, tile_w_in_; /* input tile size */ 28 | int tile_h_out_, tile_w_out_; /* output tile size */ 29 | int ntiles_h_, ntiles_w_; /* number of tiles */ 30 | 31 | int conv_in_channels_; //ic 32 | int conv_out_channels_;//oc 33 | 34 | int m_iH; 35 | int m_iW; 36 | 37 | int m_oH; 38 | int m_oW; 39 | 40 | int m_kH; 41 | int m_kW; 42 | int m_sH; 43 | int m_sW; 44 | 45 | int m_pad; 46 | bool m_bias; 47 | 48 | private: 49 | 50 | Dtype* m_inputOrg; 51 | const Dtype* m_weightOrg; 52 | 53 | Dtype* m_winogradWeight; // 
support NCHW storage 54 | Dtype* m_winogradInput; 55 | 56 | Dtype* m_col_buff;//buffer 57 | 58 | WINOGRAD_ALG m_alg; 59 | 60 | public: 61 | 62 | WinogradLayer(WINOGRAD_ALG alg, int batch_size, int iH, int iW, int iC, int kH, int kW, int sH, int sW, int oC, int pad, bool bias = true) : m_alg(alg) { 63 | 64 | #if DEBUG_WINOGRAD 65 | assert(kH == kW, "kernel 3x3 is the best choice, some errors may occur for other kernels"); 66 | #endif 67 | m_iH = iH; 68 | m_iW = iW; 69 | conv_in_channels_ = iC; 70 | m_kH = kH; 71 | m_kW = kW; 72 | m_sH = sH; 73 | m_sW = sW; 74 | conv_out_channels_ = oC; 75 | m_pad = pad; // pad_h = pad_w 76 | m_bias = bias; 77 | 78 | m_batchSize = batch_size; 79 | m_group_ = 1; 80 | 81 | m_bottom_dim_ = 0;// default batch =1 82 | m_top_dim_ = 0; 83 | 84 | m_winogradWeight = NULL; 85 | m_winogradInput = NULL; 86 | 87 | 88 | // Output width. 89 | m_oW = (m_iW + m_pad * 2 - m_kW) / m_sW + 1; 90 | m_oH = (m_iH + m_pad * 2 - m_kH) / m_sH + 1; 91 | 92 | if (alg == WT_8X8_F_6X6_3X3) { 93 | 94 | tile_h_in_ = 8; 95 | tile_w_in_ = 8; /* input tile size */ 96 | 97 | tile_h_out_ = tile_h_in_ - m_kH + 1; 98 | tile_w_out_ = tile_w_in_ - m_kW + 1; /* output tile size */ 99 | 100 | ntiles_h_ = (PUBLIC_TOOL::max(m_iH + m_pad - tile_h_in_ + 1, m_oH) + tile_h_out_ - 1) / tile_h_out_; 101 | ntiles_w_ = (PUBLIC_TOOL::max(m_iW + m_pad - tile_w_in_ + 1, m_oW) + tile_w_out_ - 1) / tile_w_out_; 102 | 103 | } 104 | else if (alg == WT_6X6_F_4X4_3X3) { 105 | 106 | tile_h_in_ = 6; 107 | tile_w_in_ = 6; /* input tile size */ 108 | 109 | tile_h_out_ = tile_h_in_ - m_kH + 1; 110 | tile_w_out_ = tile_w_in_ - m_kW + 1; /* output tile size */ 111 | 112 | ntiles_h_ = (PUBLIC_TOOL::max(m_iH + m_pad - tile_h_in_ + 1, m_oH) + tile_h_out_ - 1) / tile_h_out_; 113 | ntiles_w_ = (PUBLIC_TOOL::max(m_iW + m_pad - tile_w_in_ + 1, m_oW) + tile_w_out_ - 1) / tile_w_out_; 114 | 115 | } 116 | else throw("convolution algorithm error!"); 117 | 118 | } 119 | 120 | template 121 | const std::shared_ptr 
get_inference_cpu(Dtype* data, const Dtype* par, Dtype* col_buff) { 122 | /* Runs the forward pass: transforms the weights once, then for each batch index transforms the input, multiplies in the Winograd domain, transforms back and optionally adds the bias. Returns a newly allocated output buffer. */ 123 | m_inputOrg = data; 124 | m_weightOrg = par; 125 | m_col_buff = col_buff; 126 | 127 | 128 | std::shared_ptr resOut = std::shared_ptr(new Dtype[m_oH*m_oW*conv_out_channels_]); 129 | 130 | //trans weight to winograd domain 131 | trans_weight2wiongrad(); 132 | 133 | /* NOTE(review): the constructor sets m_bottom_dim_ and m_top_dim_ to 0, so every iteration reads the same input offset and writes the same output offset, and resOut holds a single image (m_oH*m_oW*conv_out_channels_). Confirm whether batch sizes above 1 are meant to be supported. */ 134 | for (int n = 0; n < m_batchSize; n++) { 135 | 136 | //trans input to winograd domain 137 | trans_input2winograd(m_inputOrg + n*m_bottom_dim_, m_col_buff); 138 | 139 | 140 | // Convolution in Winograd domain 141 | winograd_conv(); 142 | 143 | 144 | // Transform back to spatial domain 145 | trans2spatial(resOut.get() + n*this->m_top_dim_); 146 | 147 | //bias 148 | if (this->m_bias) { 149 | /* the bias values are read from par directly after the conv_in_channels_*conv_out_channels_*m_kW*m_kH weights */ 150 | int base = conv_in_channels_ * conv_out_channels_ * m_kW * m_kH; 151 | 152 | const Dtype* bias = &par[base]; 153 | 154 | this->forward_cpu_bias(resOut.get() + n * this->m_top_dim_, bias); 155 | } 156 | } 157 | 158 | return resOut; 159 | } 160 | 161 | 162 | public: 163 | ~WinogradLayer() { 164 | /* NOTE(review): cleanup is disabled, so the buffers that trans_weight2wiongrad() and trans_input2winograd() allocate with new[] are never released. The disabled checks below are also inverted (they would delete only null pointers). Re-enabling cleanup needs a decision on copy semantics first. if (!m_winogradInput) delete[] m_winogradInput; 165 | if (!m_winogradWeight) delete[] m_winogradWeight;*/ 166 | } 167 | 168 | 169 | private: 170 | 171 | void trans_weight2wiongrad() {// weight: hwcn --> cn hw 172 | 173 | // transform weights to Winograd domain; the buffer is allocated on first use and reused afterwards 174 | if (!m_winogradWeight) m_winogradWeight = new Dtype[conv_in_channels_*conv_out_channels_*tile_h_in_*tile_w_in_]; 175 | 176 | PUBLIC_TOOL::dlm_cpu_gemm(CblasNoTrans, CblasTrans, 177 | tile_h_in_*tile_w_in_, (conv_in_channels_ / m_group_)*conv_out_channels_, m_kH*m_kW, 178 | (Dtype)1, 179 | Winograd_Kron::getInstance(m_alg, WINOGRAD_G)->get().get(), 180 | m_weightOrg, 181 | (Dtype)0, 182 | m_winogradWeight); 183 | 184 | } 185 | 186 | template 187 | void trans_input2winograd(const Dtype *data, Dtype *col_buff) { 188 | // Transform input to Winograd domain 189 | 190 | winograd_input_im2col_cpu(data, col_buff); 191 | 192 | 193 | int M = this->conv_in_channels_*ntiles_h_*ntiles_w_; 194 | 
195 | if (!m_winogradInput) m_winogradInput = new Dtype[tile_h_in_*tile_w_in_*M]; 196 | 197 | PUBLIC_TOOL::dlm_cpu_gemm(CblasTrans, CblasTrans, 198 | tile_h_in_*tile_w_in_, M, tile_h_in_*tile_w_in_, 199 | (Dtype)1, 200 | Winograd_Kron::getInstance(m_alg, WINOGRAD_B)->get().get(), 201 | col_buff, 202 | (Dtype)0, this->m_winogradInput); 203 | 204 | } 205 | 206 | void winograd_conv() { 207 | 208 | // Convolution in Winograd domain 209 | for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) { 210 | for (int g = 0; g < this->m_group_; ++g) { 211 | PUBLIC_TOOL::dlm_cpu_gemm(CblasNoTrans, CblasNoTrans, 212 | this->conv_out_channels_ / this->m_group_, ntiles_h_*ntiles_w_, this->conv_in_channels_ / this->m_group_, 213 | (Dtype)1, 214 | m_winogradWeight + (j*this->m_group_ + g)*(this->conv_out_channels_ / this->m_group_)*(this->conv_in_channels_ / this->m_group_), 215 | m_winogradInput + (j*this->m_group_ + g)*(this->conv_in_channels_ / this->m_group_)*ntiles_h_*ntiles_w_, 216 | (Dtype)0, m_col_buff + (j*this->m_group_ + g)*(this->conv_out_channels_ / this->m_group_)*ntiles_h_*ntiles_w_); 217 | } 218 | } 219 | // col_buff has (tile_h_in*tile_w_in) x (conv_out_channels) x (ntiles_h*ntiles_w) 220 | } 221 | 222 | template 223 | void trans2spatial(Dtype *data) { 224 | 225 | Dtype *winogradRes = new Dtype[this->conv_out_channels_*ntiles_h_*ntiles_w_*tile_h_out_*tile_w_out_]; 226 | 227 | PUBLIC_TOOL::dlm_cpu_gemm(CblasTrans, CblasNoTrans, 228 | this->conv_out_channels_*ntiles_h_*ntiles_w_, tile_h_out_*tile_w_out_, tile_h_in_*tile_w_in_, 229 | (Dtype)1, m_col_buff, 230 | Winograd_Kron::getInstance(m_alg, WINOGRAD_A)->get().get(), 231 | (Dtype)0, winogradRes); 232 | 233 | winograd_output_col2im_cpu(winogradRes, data); 234 | 235 | delete[] winogradRes; 236 | } 237 | 238 | template 239 | void winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff) 240 | { 241 | int height = m_iH; 242 | int width = m_iW; 243 | int pad_h = m_pad, pad_w = m_pad; 244 | 245 | for (int c = 0; c < 
this->conv_in_channels_; ++c) { 246 | for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) { 247 | for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) { 248 | 249 | 250 | 251 | for (int y = 0; y < tile_h_in_; ++y) { 252 | for (int x = 0; x < tile_w_in_; ++x) { 253 | int in_y = tile_h*tile_h_out_ + y - pad_h; 254 | int in_x = tile_w*tile_w_out_ + x - pad_w; 255 | 256 | if (in_y < 0 || in_x < 0 || in_y >= height || in_x >= width) { 257 | col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] = 0; 258 | } 259 | else { 260 | col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] = 261 | data[(c*height + in_y)*width + in_x]; 262 | } 263 | } 264 | } 265 | 266 | 267 | } // for each tile 268 | } // for each tile 269 | } // for each input channel 270 | } 271 | 272 | 273 | template 274 | void forward_cpu_bias(Dtype* output, 275 | const Dtype* bias) { 276 | 277 | int out_spatial_dim_ = m_oH * m_oW; 278 | 279 | for (int i = 0; i < conv_out_channels_; i++) { 280 | 281 | for (int j = 0; j < out_spatial_dim_; j++) 282 | output[i*out_spatial_dim_ + j] += bias[i]; 283 | 284 | } 285 | 286 | //dlm_cpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_, 287 | // out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), 288 | // (Dtype)1., output); 289 | } 290 | 291 | 292 | template 293 | void winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data) 294 | { 295 | const int output_h = m_iH, output_w = m_iW; 296 | 297 | for (int n = 0; n < this->conv_out_channels_; ++n) { 298 | for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) { 299 | for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) { 300 | for (int y = 0; y < tile_h_out_; ++y) { 301 | for (int x = 0; x < tile_w_out_; ++x) { 302 | int out_y = tile_h*tile_h_out_ + y; 303 | int out_x = tile_w*tile_w_out_ + x; 304 | 305 | if (out_y < output_h && out_x < output_w) { 306 | 307 | /*int kk = 0; 308 | if ((((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + 
y)*tile_w_out_ + x == 184604) 309 | kk++; 310 | 311 | cout << "dat: "<<(n*output_h + out_y)*output_w + out_x << " col : " << (((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x << endl;*/ 312 | 313 | data[(n*output_h + out_y)*output_w + out_x] = 314 | col_buff[(((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x]; 315 | } 316 | } 317 | } 318 | } // for each tile 319 | } // for each tile 320 | } // for each input channel 321 | } 322 | 323 | }; 324 | } 325 | 326 | #endif -------------------------------------------------------------------------------- /github_winograd/src/winograd_test.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../include/winograd_kernel.h" 4 | #include "../include/winograd_layer.h" 5 | #include 6 | #include 7 | 8 | using namespace WINOGRAD_KERNEL; 9 | using namespace std; 10 | 11 | const int CIN = 3; 12 | const int COUT = 7; 13 | 14 | const int IH = 25; 15 | const int IW = 25; 16 | 17 | const int PRECISE = 0; 18 | 19 | #define INPUT_INTEGER 1 20 | #define KERNEL_INTEGER 1 21 | 22 | void testWinograd(); 23 | 24 | int main() { 25 | 26 | 27 | WINOGRAD_KERNEL::winograd2D_initialize(); 28 | 29 | testWinograd(); 30 | 31 | return 0; 32 | } 33 | 34 | void testWinograd() { 35 | 36 | //int batch_size = 1; 37 | 38 | int tiH = IH; 39 | int tiW = IW; 40 | 41 | int tkW = 3; 42 | int tkH = 3; 43 | 44 | int tsW = 1; 45 | int tsH = 1; 46 | 47 | int tiC = CIN; 48 | const int toC = COUT; 49 | 50 | bool tbias = true; 51 | 52 | int tpad = 1; 53 | 54 | const auto toH = (tiH + tpad * 2 - tkH) / tsH + 1; 55 | 56 | // Output width. 
57 | const auto toW = (tiW + tpad * 2 - tkW) / tsW + 1; 58 | 59 | cout << setprecision(PRECISE); 60 | 61 | //NCHW 62 | float* input = new float[tiC*tiH*tiW]; 63 | float* kernel = new float[tiC*tkH*tkW*toC + toC]; 64 | 65 | //initInput 66 | for (int c = 0; c wt8X8( 90 | WINOGRAD_KERNEL::WT_8X8_F_6X6_3X3, //WT_6X6_F_4X4_3X3 91 | 1, 92 | tiH, 93 | tiW, 94 | tiC, 95 | tkH, 96 | tkW, 97 | tsH, 98 | tsW, 99 | toC, 100 | tpad, 101 | tbias 102 | ); 103 | 104 | WINOGRAD_KERNEL::WinogradLayer wt6x6( 105 | WINOGRAD_KERNEL::WT_6X6_F_4X4_3X3, //WT_6X6_F_4X4_3X3 106 | 1, 107 | tiH, 108 | tiW, 109 | tiC, 110 | tkH, 111 | tkW, 112 | tsH, 113 | tsW, 114 | toC, 115 | tpad, 116 | tbias 117 | ); 118 | 119 | float* buffer=new float [toH*toW*tiC*100];// enough buffer, used as medium buffer flowing through each layer 120 | 121 | shared_ptr output = wt8X8.get_inference_cpu(input, kernel, (float*)buffer); // 122 | 123 | cout << "the first three elements and the last one of the wt8x8 result:" << endl; 124 | cout << output.get()[0] << " " << output.get()[1] << " " << output.get()[2] << " " << output.get()[toC*toH*toW - 1] << " " << endl; 125 | 126 | 127 | output = wt6x6.get_inference_cpu(input, kernel, (float*)buffer); // 128 | 129 | cout << "the first three elements and the last one of the wt6x6 result:" << endl; 130 | cout << output.get()[0] << " " << output.get()[1] << " " << output.get()[2] << " " << output.get()[toC*toH*toW - 1] << " " << endl; 131 | 132 | delete[] buffer; 133 | } 134 | --------------------------------------------------------------------------------