├── .gitattributes
├── .gitignore
├── README.md
├── github_winograd.sln
└── github_winograd
├── github_winograd.vcxproj
├── github_winograd.vcxproj.filters
├── include
├── mathlib.h
├── tool.h
├── winograd_kernel.h
└── winograd_layer.h
└── src
└── winograd_test.cpp
/.gitattributes:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Set default behavior to automatically normalize line endings.
3 | ###############################################################################
4 | * text=auto
5 |
6 | ###############################################################################
7 | # Set default behavior for command prompt diff.
8 | #
9 | # This is needed for earlier builds of msysgit that do not have it on by
10 | # default for csharp files.
11 | # Note: This is only used by command line
12 | ###############################################################################
13 | #*.cs diff=csharp
14 |
15 | ###############################################################################
16 | # Set the merge driver for project and solution files
17 | #
18 | # Merging from the command prompt will add diff markers to the files if there
19 | # are conflicts (Merging from VS is not affected by the settings below, in VS
20 | # the diff markers are never inserted). Diff markers may cause the following
21 | # file extensions to fail to load in VS. An alternative would be to treat
22 | # these files as binary and thus will always conflict and require user
23 | # intervention with every merge. To do so, just uncomment the entries below
24 | ###############################################################################
25 | #*.sln merge=binary
26 | #*.csproj merge=binary
27 | #*.vbproj merge=binary
28 | #*.vcxproj merge=binary
29 | #*.vcproj merge=binary
30 | #*.dbproj merge=binary
31 | #*.fsproj merge=binary
32 | #*.lsproj merge=binary
33 | #*.wixproj merge=binary
34 | #*.modelproj merge=binary
35 | #*.sqlproj merge=binary
36 | #*.wwaproj merge=binary
37 |
38 | ###############################################################################
39 | # behavior for image files
40 | #
41 | # image files are treated as binary by default.
42 | ###############################################################################
43 | #*.jpg binary
44 | #*.png binary
45 | #*.gif binary
46 |
47 | ###############################################################################
48 | # diff behavior for common document formats
49 | #
50 | # Convert binary document formats to text before diffing them. This feature
51 | # is only available from the command line. Turn it on by uncommenting the
52 | # entries below.
53 | ###############################################################################
54 | #*.doc diff=astextplain
55 | #*.DOC diff=astextplain
56 | #*.docx diff=astextplain
57 | #*.DOCX diff=astextplain
58 | #*.dot diff=astextplain
59 | #*.DOT diff=astextplain
60 | #*.pdf diff=astextplain
61 | #*.PDF diff=astextplain
62 | #*.rtf diff=astextplain
63 | #*.RTF diff=astextplain
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 |
4 | # User-specific files
5 | *.suo
6 | *.user
7 | *.userosscache
8 | *.sln.docstates
9 |
10 | # User-specific files (MonoDevelop/Xamarin Studio)
11 | *.userprefs
12 |
13 | # Build results
14 | [Dd]ebug/
15 | [Dd]ebugPublic/
16 | [Rr]elease/
17 | [Rr]eleases/
18 | [Xx]64/
19 | [Xx]86/
20 | [Bb]uild/
21 | bld/
22 | [Bb]in/
23 | [Oo]bj/
24 |
25 | # Visual Studio 2015 cache/options directory
26 | .vs/
27 | # Uncomment if you have tasks that create the project's static files in wwwroot
28 | #wwwroot/
29 |
30 | # MSTest test Results
31 | [Tt]est[Rr]esult*/
32 | [Bb]uild[Ll]og.*
33 |
34 | # NUNIT
35 | *.VisualState.xml
36 | TestResult.xml
37 |
38 | # Build Results of an ATL Project
39 | [Dd]ebugPS/
40 | [Rr]eleasePS/
41 | dlldata.c
42 |
43 | # DNX
44 | project.lock.json
45 | artifacts/
46 |
47 | *_i.c
48 | *_p.c
49 | *_i.h
50 | *.ilk
51 | *.meta
52 | *.obj
53 | *.pch
54 | *.pdb
55 | *.pgc
56 | *.pgd
57 | *.rsp
58 | *.sbr
59 | *.tlb
60 | *.tli
61 | *.tlh
62 | *.tmp
63 | *.tmp_proj
64 | *.log
65 | *.vspscc
66 | *.vssscc
67 | .builds
68 | *.pidb
69 | *.svclog
70 | *.scc
71 |
72 | # Chutzpah Test files
73 | _Chutzpah*
74 |
75 | # Visual C++ cache files
76 | ipch/
77 | *.aps
78 | *.ncb
79 | *.opendb
80 | *.opensdf
81 | *.sdf
82 | *.cachefile
83 | *.VC.db
84 |
85 | # Visual Studio profiler
86 | *.psess
87 | *.vsp
88 | *.vspx
89 | *.sap
90 |
91 | # TFS 2012 Local Workspace
92 | $tf/
93 |
94 | # Guidance Automation Toolkit
95 | *.gpState
96 |
97 | # ReSharper is a .NET coding add-in
98 | _ReSharper*/
99 | *.[Rr]e[Ss]harper
100 | *.DotSettings.user
101 |
102 | # JustCode is a .NET coding add-in
103 | .JustCode
104 |
105 | # TeamCity is a build add-in
106 | _TeamCity*
107 |
108 | # DotCover is a Code Coverage Tool
109 | *.dotCover
110 |
111 | # NCrunch
112 | _NCrunch_*
113 | .*crunch*.local.xml
114 | nCrunchTemp_*
115 |
116 | # MightyMoose
117 | *.mm.*
118 | AutoTest.Net/
119 |
120 | # Web workbench (sass)
121 | .sass-cache/
122 |
123 | # Installshield output folder
124 | [Ee]xpress/
125 |
126 | # DocProject is a documentation generator add-in
127 | DocProject/buildhelp/
128 | DocProject/Help/*.HxT
129 | DocProject/Help/*.HxC
130 | DocProject/Help/*.hhc
131 | DocProject/Help/*.hhk
132 | DocProject/Help/*.hhp
133 | DocProject/Help/Html2
134 | DocProject/Help/html
135 |
136 | # Click-Once directory
137 | publish/
138 |
139 | # Publish Web Output
140 | *.[Pp]ublish.xml
141 | *.azurePubxml
142 |
143 | # TODO: Un-comment the next line if you do not want to checkin
144 | # your web deploy settings because they may include unencrypted
145 | # passwords
146 | #*.pubxml
147 | *.publishproj
148 |
149 | # NuGet Packages
150 | *.nupkg
151 | # The packages folder can be ignored because of Package Restore
152 | **/packages/*
153 | # except build/, which is used as an MSBuild target.
154 | !**/packages/build/
155 | # Uncomment if necessary however generally it will be regenerated when needed
156 | #!**/packages/repositories.config
157 | # NuGet v3's project.json files produces more ignoreable files
158 | *.nuget.props
159 | *.nuget.targets
160 |
161 | # Microsoft Azure Build Output
162 | csx/
163 | *.build.csdef
164 |
165 | # Microsoft Azure Emulator
166 | ecf/
167 | rcf/
168 |
169 | # Windows Store app package directory
170 | AppPackages/
171 | BundleArtifacts/
172 |
173 | # Visual Studio cache files
174 | # files ending in .cache can be ignored
175 | *.[Cc]ache
176 | # but keep track of directories ending in .cache
177 | !*.[Cc]ache/
178 |
179 | # Others
180 | ClientBin/
181 | [Ss]tyle[Cc]op.*
182 | ~$*
183 | *~
184 | *.dbmdl
185 | *.dbproj.schemaview
186 | *.pfx
187 | *.publishsettings
188 | node_modules/
189 | orleans.codegen.cs
190 |
191 | # RIA/Silverlight projects
192 | Generated_Code/
193 |
194 | # Backup & report files from converting an old project file
195 | # to a newer Visual Studio version. Backup files are not needed,
196 | # because we have git ;-)
197 | _UpgradeReport_Files/
198 | Backup*/
199 | UpgradeLog*.XML
200 | UpgradeLog*.htm
201 |
202 | # SQL Server files
203 | *.mdf
204 | *.ldf
205 |
206 | # Business Intelligence projects
207 | *.rdl.data
208 | *.bim.layout
209 | *.bim_*.settings
210 |
211 | # Microsoft Fakes
212 | FakesAssemblies/
213 |
214 | # GhostDoc plugin setting file
215 | *.GhostDoc.xml
216 |
217 | # Node.js Tools for Visual Studio
218 | .ntvs_analysis.dat
219 |
220 | # Visual Studio 6 build log
221 | *.plg
222 |
223 | # Visual Studio 6 workspace options file
224 | *.opt
225 |
226 | # Visual Studio LightSwitch build output
227 | **/*.HTMLClient/GeneratedArtifacts
228 | **/*.DesktopClient/GeneratedArtifacts
229 | **/*.DesktopClient/ModelManifest.xml
230 | **/*.Server/GeneratedArtifacts
231 | **/*.Server/ModelManifest.xml
232 | _Pvt_Extensions
233 |
234 | # LightSwitch generated files
235 | GeneratedArtifacts/
236 | ModelManifest.xml
237 |
238 | # Paket dependency manager
239 | .paket/paket.exe
240 |
241 | # FAKE - F# Make
242 | .fake/
243 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Winograd_Convolution
2 | Winograd_Convolution is a Winograd-based kernel for convolutions in deep learning frameworks, implementing the Winograd convolutions described in [1]. Three variants are provided: WT_6X6_F_4X4_3X3, WT_8X8_F_4X4_5X5, and WT_8X8_F_6X6_3X3; 3x3 convolution kernels are the best fit. Note that `WinogradLayer` currently accepts only the two 3x3 variants (WT_6X6_F_4X4_3X3 and WT_8X8_F_6X6_3X3). Parts of this work are derived from SkimCaffe [2], but this Winograd kernel is more portable.
3 |
4 | Dependencies
5 | -----------------------------------
6 |
7 | A fast BLAS library is recommended, such as Intel MKL or OpenBLAS [3].
8 |
9 | Building
10 | -----------------------------------
11 |
12 | The library consists only of C++ header files and supports Windows and Linux. This version was built with Visual Studio 2015.
13 |
14 | Testing
15 | -----------------------------------
16 |
17 | See winograd_test.cpp.
18 |
19 | Packaging
20 | -----------------------------------
21 |
22 | `include/winograd_layer.h` can be integrated directly into popular deep learning frameworks as a Winograd layer, such as Caffe (https://github.com/BVLC/caffe) and tiny-dnn (https://github.com/tiny-dnn/tiny-dnn).
23 |
24 | References & Dependencies
25 | -----------------------------------
26 | [1] Andrew Lavin, Scott Gray. Fast Algorithms for Convolutional Neural Networks. https://arxiv.org/abs/1509.09308
27 |
28 | [2] SkimCaffe, https://github.com/IntelLabs/SkimCaffe
29 |
30 | [3] OpenBLAS, https://github.com/xianyi/OpenBLAS.
31 |
32 | License
33 | -----------------------------------
34 | The BSD 3-Clause License
35 |
--------------------------------------------------------------------------------
/github_winograd.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 14
4 | VisualStudioVersion = 14.0.25420.1
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "github_winograd", "github_winograd\github_winograd.vcxproj", "{91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|x64 = Debug|x64
11 | Debug|x86 = Debug|x86
12 | Release|x64 = Release|x64
13 | Release|x86 = Release|x86
14 | EndGlobalSection
15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
16 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x64.ActiveCfg = Debug|x64
17 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x64.Build.0 = Debug|x64
18 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x86.ActiveCfg = Debug|Win32
19 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Debug|x86.Build.0 = Debug|Win32
20 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x64.ActiveCfg = Release|x64
21 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x64.Build.0 = Release|x64
22 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x86.ActiveCfg = Release|Win32
23 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}.Release|x86.Build.0 = Release|Win32
24 | EndGlobalSection
25 | GlobalSection(SolutionProperties) = preSolution
26 | HideSolutionNode = FALSE
27 | EndGlobalSection
28 | EndGlobal
29 |
--------------------------------------------------------------------------------
/github_winograd/github_winograd.vcxproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | Win32
7 |
8 |
9 | Release
10 | Win32
11 |
12 |
13 | Debug
14 | x64
15 |
16 |
17 | Release
18 | x64
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 | {91EDB49D-AAD0-47A7-A8C3-A87BBDF7099C}
32 | github_winograd
33 | 8.1
34 | Winograd_Convolution
35 |
36 |
37 |
38 | Application
39 | true
40 | v140
41 | MultiByte
42 |
43 |
44 | Application
45 | false
46 | v140
47 | true
48 | MultiByte
49 |
50 |
51 | Application
52 | true
53 | v140
54 | MultiByte
55 |
56 |
57 | Application
58 | false
59 | v140
60 | true
61 | MultiByte
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);
83 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(NETFXKitsDir)Lib\um\x86
84 |
85 |
86 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);
87 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(VC_LibraryPath_x86);$(WindowsSDK_LibraryPath_x86);$(NETFXKitsDir)Lib\um\x86
88 |
89 |
90 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(VC_IncludePath);$(WindowsSDK_IncludePath);
91 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64
92 |
93 |
94 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\lib\x64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64
95 | $(SolutionDir)\packages\OpenBLAS.0.2.14.1\lib\native\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);
96 |
97 |
98 |
99 | Level3
100 | Disabled
101 | true
102 |
103 |
104 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)
105 |
106 |
107 |
108 |
109 | Level3
110 | Disabled
111 | true
112 |
113 |
114 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)
115 |
116 |
117 |
118 |
119 | Level3
120 | MaxSpeed
121 | true
122 | true
123 | true
124 |
125 |
126 | true
127 | true
128 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)
129 |
130 |
131 |
132 |
133 | Level3
134 | MaxSpeed
135 | true
136 | true
137 | true
138 |
139 |
140 | true
141 | true
142 | libopenblas.dll.a;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)
143 |
144 |
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/github_winograd/github_winograd.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
7 |
8 |
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hh;hpp;hxx;hm;inl;inc;xsd
11 |
12 |
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
15 |
16 |
17 |
18 |
19 | Header Files
20 |
21 |
22 | Header Files
23 |
24 |
25 | Header Files
26 |
27 |
28 | Header Files
29 |
30 |
31 |
32 |
33 | Source Files
34 |
35 |
36 |
--------------------------------------------------------------------------------
/github_winograd/include/mathlib.h:
--------------------------------------------------------------------------------
1 | #ifndef MATHLAB_H
2 | #define MATHLAB_H
3 |
4 | #define USE_MKL 0
5 | #define USE_OPENBLAS 1
6 |
7 | #if USE_MKL
8 | #include
9 |
10 | #elif USE_OPENBLAS
11 | #include
12 |
13 | #endif
14 |
15 | #endif
--------------------------------------------------------------------------------
/github_winograd/include/tool.h:
--------------------------------------------------------------------------------
1 | #ifndef PUBLIC_TOOL_H
2 | #define PUBLIC_TOOL_H
3 |
4 | #include
5 | #include
6 | #include "mathlib.h"
7 |
8 | namespace PUBLIC_TOOL{
9 |
/**
 * Larger of two values. When the values compare equal, or when `a > b` is
 * false for any other reason (e.g. NaN), `b` is returned.
 */
template <typename Dtype>
Dtype max(Dtype a, Dtype b) {
	return (a > b) ? a : b;
}
15 |
/**
 * Smaller of two values. When the values compare equal, or when `a < b` is
 * false for any other reason (e.g. NaN), `b` is returned.
 */
template <typename Dtype>
Dtype min(Dtype a, Dtype b) {
	return (a < b) ? a : b;
}
21 |
22 | void dlm_cpu_gemm(const CBLAS_TRANSPOSE TransA,
23 | const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
24 | const float alpha, const float* A, const float* B, const float beta,
25 | float* C) {
26 | int lda = (TransA == CblasNoTrans) ? K : M;
27 | int ldb = (TransB == CblasNoTrans) ? N : K;
28 | cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
29 | ldb, beta, C, N);
30 | }
31 |
32 | void dlm_cpu_gemm(const CBLAS_TRANSPOSE TransA,
33 | const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
34 | const double alpha, const double* A, const double* B, const double beta,
35 | double* C) {
36 | int lda = (TransA == CblasNoTrans) ? K : M;
37 | int ldb = (TransB == CblasNoTrans) ? N : K;
38 | cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
39 | ldb, beta, C, N);
40 | }
41 |
42 |
43 | };
44 |
45 | #endif
--------------------------------------------------------------------------------
/github_winograd/include/winograd_kernel.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef WINOGRAD_KERNEL_H
3 | #define WINOGRAD_KERNEL_H
4 |
5 | #include
6 |
7 | #define DEBUG_WINOGRAD 1
8 |
9 | #if DEBUG_WINOGRAD
10 | #include
11 | #endif
12 |
13 | namespace WINOGRAD_KERNEL {
14 |
// Which transform matrix of an algorithm is requested. Values are used as
// array indices, so keep them dense and starting at 0.
// (No `const` on these enum declarations: a cv-qualifier on a declaration that
// declares no object is ill-formed and rejected by GCC.)
enum WINOGRAD_MATRIX {
	WINOGRAD_A = 0, // output transform
	WINOGRAD_B,     // input transform
	WINOGRAD_G,     // kernel (filter) transform
};

// Supported algorithms, named WT_<input tile>_F_<output tile>_<kernel>.
enum WINOGRAD_ALG {
	WT_8X8_F_6X6_3X3 = 0,
	WT_6X6_F_4X4_3X3,
	WT_8X8_F_4X4_5X5,
};

const int MATRIX_KINDS = 3;        // number of WINOGRAD_MATRIX values
const int WINOGRAD_PAIR_KINDS = 3; // number of algorithms (WINOGRAD_ALG values)

/**
 * Transform matrices of one algorithm. Only the explicit specializations
 * provide get(); the primary template is intentionally empty.
 */
template <WINOGRAD_ALG alg>
struct WinogradTransformMatrix {};

/**
 * Compute the Kronecker product of in1 and in2, where in1 is an m by n matrix
 * and in2 is a p by q matrix.
 *
 * @param out an (m*p) by (n*q) matrix stored in row major
 * @param in1 an m by n matrix stored in row major
 * @param in2 a p by q matrix stored in row major
 */
inline void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q);

// Singleton precomputation before inference (defined later in this header;
// `inline` so that multiple translation units may include it).
inline void winograd2D_initialize();
43 |
44 | template<>
45 | struct WinogradTransformMatrix
46 | {
47 | // wt6x6, F(4x4,3x3)
48 | private:
49 | static const int O = 4;
50 | static const int K = 3;
51 | static const int T = O + K - 1;
52 |
53 | static const float *getG() {
54 | static const float G[T*K] = {
55 | 1. / 4., 0, 0,
56 | -1. / 6., -1. / 6., -1. / 6.,
57 | -1. / 6., 1. / 6., -1. / 6.,
58 | 1. / 24., 1. / 12., 1. / 6.,
59 | 1. / 24., -1. / 12., 1. / 6.,
60 | 0, 0, 1,
61 | };
62 | return G;
63 | }
64 |
65 | static const float *getA() {
66 | static const float A[T*O] = {
67 | 1, 0, 0, 0,
68 | 1, 1, 1, 1,
69 | 1, -1, 1, -1,
70 | 1, 2, 4, 8,
71 | 1, -2, 4, -8,
72 | 0, 0, 0, 1,
73 | };
74 | return A;
75 | }
76 |
77 | static const float *getB() {
78 | static const float B[T*T] = {
79 | 4, 0, 0, 0, 0, 0,
80 | 0, -4, 4, -2, 2, 4,
81 | -5, -4, -4, -1, -1, 0,
82 | 0, 1, -1, 2, -2, -5,
83 | 1, 1, 1, 1, 1, 0,
84 | 0, 0, 0, 0, 0, 1,
85 | };
86 | return B;
87 | };
88 |
89 | public:
90 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) {
91 |
92 | #if DEBUG_WINOGRAD
93 | assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
94 | #endif
95 | switch (mat) {
96 |
97 | case WINOGRAD_A: row = T; col = O; return getA();
98 | case WINOGRAD_B: row = T; col = T; return getB();
99 | case WINOGRAD_G: row = T; col = K; return getG();
100 |
101 | }
102 |
103 | }
104 |
105 | };
106 |
107 | template<>
108 | struct WinogradTransformMatrix
109 | {
110 |
111 | private:
112 |
113 | // wt8x8, F(6x6,3x3)
114 |
115 | static const int O = 6;
116 | static const int K = 3;
117 | static const int T = O + K - 1;
118 |
119 | public:
120 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) {
121 |
122 | #if DEBUG_WINOGRAD
123 | assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
124 | #endif
125 | switch (mat) {
126 |
127 | case WINOGRAD_A: row = T; col = O; return getA();
128 | case WINOGRAD_B: row = T; col = T; return getB();
129 | case WINOGRAD_G: row = T; col = K; return getG();
130 |
131 | }
132 |
133 | }
134 |
135 | private:
136 | static const float *getG() {
137 | static const float G[T*K] = {
138 | 1.f, 0.f , 0.f ,
139 | -2.f / 9 , -2.f / 9 , -2.f / 9,
140 | -2.f / 9 , 2.f / 9 , -2.f / 9,
141 | 1.f / 90 , 1.f / 45 , 2.f / 45,
142 | 1.f / 90 , -1.f / 45 , 2.f / 45,
143 | 32.f / 45, 16.f / 45, 8.f / 45,
144 | 32.f / 45, -16.f / 45, 8.f / 45,
145 | 0.f , 0.f , 1.f ,
146 | };
147 | return G;
148 | }
149 |
150 | static const float *getA() {
151 | static const float A[T*(T - K + 1)] = {
152 | 1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f,
153 | 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f,
154 | 1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f,
155 | 1 * 1.f, 2 * 1.f, 4 * 1.f, 8 * 1.f, 16 * 1.f, 32 * 1.f,
156 | 1 * 1.f, -2 * 1.f, 4 * 1.f, -8 * 1.f, 16 * 1.f, -32 * 1.f,
157 | 1 * 1.f, 0.5*1.f, 0.25*1.f, 0.125*1.f, 0.0625*1.f, 0.03125*1.f,
158 | 1 * 1.f, -0.5*1.f, 0.25*1.f, -0.125*1.f, 0.0625*1.f, -0.03125*1.f,
159 | 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f,
160 | };
161 | return A;
162 | }
163 |
164 | static const float *getB() {
165 | static const float B[T*T] = {
166 | 1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f,
167 | 0 * 1.f, 1 * 1.f, -1 * 1.f, 0.5*1.f, -0.5*1.f, 2 * 1.f, -2 * 1.f, -1 * 1.f,
168 | -5.25*1.f, 1 * 1.f, 1 * 1.f, 0.25*1.f, 0.25*1.f, 4 * 1.f, 4 * 1.f, 0 * 1.f,
169 | 0 * 1.f, -4.25*1.f, 4.25*1.f, -2.5*1.f, 2.5*1.f, -2.5*1.f, 2.5*1.f, 5.25*1.f,
170 | 5.25*1.f, -4.25*1.f, -4.25*1.f, -1.25*1.f, -1.25*1.f, -5 * 1.f, -5 * 1.f, 0 * 1.f,
171 | 0 * 1.f, 1 * 1.f, -1 * 1.f, 2 * 1.f, -2 * 1.f, 0.5*1.f, -0.5*1.f, -5.25*1.f,
172 | -1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 0 * 1.f,
173 | 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f,
174 | };
175 | return B;
176 | };
177 | };
178 |
179 | template<>
180 | struct WinogradTransformMatrix
181 | {
182 |
183 | private:
184 | // wt8x8, F(4x4,5x5)
185 | static const int T = 5 + 4 - 1;
186 | static const int K = 5;
187 | static const int O = 4;
188 |
189 | public:
190 | static const float *get(WINOGRAD_MATRIX mat, int &row, int& col) {
191 |
192 | #if DEBUG_WINOGRAD
193 | assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
194 | #endif
195 | switch (mat) {
196 |
197 | case WINOGRAD_A: row = T; col = O; return getA();
198 | case WINOGRAD_B: row = T; col = T; return getB();
199 | case WINOGRAD_G: row = T; col = K; return getG();
200 |
201 | }
202 |
203 | }
204 |
205 | private:
206 |
207 | // from https://github.com/Maratyszcza/NNPACK/issues/12
208 |
209 | static const float *getG() {
210 | static const float G[T*K] = {
211 | 1, 0, 0, 0, 0,
212 | -2. / 9., -2. / 9., -2. / 9., -2. / 9., -2. / 9.,
213 | -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9.,
214 | 1. / 90., 1. / 45., 2. / 45., 4. / 45., 8. / 45.,
215 | 1. / 90., -1. / 45., 2. / 45., -4. / 45., 8. / 45.,
216 | 4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180.,
217 | 4. / 45., -2. / 45., 1. / 45., -1. / 90., 1. / 180.,
218 | 0, 0, 0, 0, 1,
219 | };
220 | return G;
221 | }
222 |
223 |
224 |
225 |
226 | static const float *getA() {
227 | static const float A[T*(O)] = {
228 | 1, 0, 0, 0,
229 | 1, 1, 1, 1,
230 | 1, -1, 1, -1,
231 | 1, 2, 4, 8,
232 | 1, -2, 4, -8,
233 | 8, 4, 2, 1,
234 | 8, -4, 2, -1,
235 | 0, 0, 0, 1
236 | };
237 | return A;
238 | }
239 |
240 | static const float *getB() {
241 | static const float B[T*T] = {
242 | 1, 0, 0, 0, 0, 0, 0, 0,
243 | 0, 1, -1, 1. / 2, -1. / 2, 2, -2, -1,
244 | -21. / 4, 1, 1, 1. / 4, 1. / 4, 4, 4, 0,
245 | 0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2, 5. / 2, 21. / 4,
246 | 21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0,
247 | 0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. / 4,
248 | -1, 1, 1, 1, 1, 1, 1, 0,
249 | 0, 0, 0, 0, 0, 0, 0, 1,
250 | };
251 | return B;
252 | }
253 | };
254 |
255 | class Winograd_Kron
256 | {
257 |
258 | private:
259 |
260 | Winograd_Kron(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) {
261 |
262 | isCalc = false;
263 |
264 | switch (alg) {
265 |
266 | case WT_8X8_F_6X6_3X3:
267 | matrix = WinogradTransformMatrix::get(mat, row, col); break;
268 | case WT_6X6_F_4X4_3X3:
269 | matrix = WinogradTransformMatrix::get(mat, row, col); break;
270 | case WT_8X8_F_4X4_5X5:
271 | matrix = WinogradTransformMatrix::get(mat, row, col); break;
272 |
273 | }
274 |
275 | }
276 |
277 | private:
278 | const float *matrix; // = A, B, G
279 | int row, col;// matrix: row*col
280 | // A: T*O
281 | // B: M*M
282 | // G: T*K
283 |
284 | std::shared_ptr kron;
285 |
286 | bool isCalc;
287 |
288 | //std::shared_ptr normOfRowsInv;
289 |
290 | public:
291 |
292 | static Winograd_Kron *getInstance(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) {
293 |
294 | // 9 instances 3*3
295 | static Winograd_Kron * instances[MATRIX_KINDS *WINOGRAD_PAIR_KINDS] = { NULL }; // according to [WINOGRAD_MATRIX] [WINOGRAD_PAIR]
296 |
297 | int index = alg*WINOGRAD_PAIR_KINDS + mat;
298 |
299 | if (instances[index] == NULL)
300 | instances[index] = new Winograd_Kron(alg, mat);
301 |
302 | return instances[index];
303 |
304 | }
305 |
306 | const std::shared_ptr get() {
307 | if (isCalc)
308 | return kron;
309 | else {
310 | calc();
311 | return kron;
312 | }
313 |
314 | }
315 |
316 | private:
317 |
318 | void calc() {
319 |
320 | kron = std::shared_ptr(new float[row*col*row*col]);
321 |
322 | kronecker_product(kron.get(), matrix, matrix, row, col, row, col);
323 |
324 | isCalc = true;
325 |
326 | }
327 |
328 | };
329 |
/**
 * Kronecker product out = in1 (x) in2.
 *
 * @param out (m*p) x (n*q) row-major result; every element is written
 * @param in1 m x n row-major matrix
 * @param in2 p x q row-major matrix
 *
 * Element (i, j) of in1 scales a p x q copy of in2, which is placed at rows
 * [p*i, p*i + p) and columns [q*j, q*j + q) of out. Declared inline because
 * this header defines it and may be included from several translation units.
 */
inline void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q)
{
	const int outCols = n * q;
	for (int i = 0; i < m; ++i) {
		for (int j = 0; j < n; ++j) {
			const float scale = in1[n * i + j];
			for (int k = 0; k < p; ++k) {
				for (int l = 0; l < q; ++l) {
					out[(p * i + k) * outCols + q * j + l] = scale * in2[k * q + l];
				}
			}
		}
	}
}
343 |
344 | void winograd2D_initialize() {
345 | //singleton, precomputation before inference
346 |
347 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_A)->get();
348 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_B)->get();
349 | Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_G)->get();
350 |
351 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_A)->get();
352 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_B)->get();
353 | Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_G)->get();
354 | }
355 | }
356 |
357 |
358 | #endif
--------------------------------------------------------------------------------
/github_winograd/include/winograd_layer.h:
--------------------------------------------------------------------------------
1 | #ifndef WINOGRAD_LAYER_H
2 | #define WINOGRAD_LAYER_H
3 |
4 | #include
5 | #include "winograd_kernel.h"
6 | #include "tool.h"
7 | //winograd for cpu inference
8 |
// Default kernel edge length: 3x3 kernels are the intended use case.
constexpr int KERNEL_SIZE = 3;
11 |
12 | namespace WINOGRAD_KERNEL
13 | {
14 |
15 | template
16 | class WinogradLayer {
17 |
18 | private:
19 |
20 | int m_group_;
21 | int m_batchSize;
22 |
23 | int m_bottom_dim_;// par size
24 | int m_top_dim_;
25 |
26 | // The following variables are initialized in WeightAlign
27 | int tile_h_in_, tile_w_in_; /* input tile size */
28 | int tile_h_out_, tile_w_out_; /* output tile size */
29 | int ntiles_h_, ntiles_w_; /* number of tiles */
30 |
31 | int conv_in_channels_; //ic
32 | int conv_out_channels_;//oc
33 |
34 | int m_iH;
35 | int m_iW;
36 |
37 | int m_oH;
38 | int m_oW;
39 |
40 | int m_kH;
41 | int m_kW;
42 | int m_sH;
43 | int m_sW;
44 |
45 | int m_pad;
46 | bool m_bias;
47 |
48 | private:
49 |
50 | Dtype* m_inputOrg;
51 | const Dtype* m_weightOrg;
52 |
53 | Dtype* m_winogradWeight; // support NCHW storage
54 | Dtype* m_winogradInput;
55 |
56 | Dtype* m_col_buff;//buffer
57 |
58 | WINOGRAD_ALG m_alg;
59 |
60 | public:
61 |
62 | WinogradLayer(WINOGRAD_ALG alg, int batch_size, int iH, int iW, int iC, int kH, int kW, int sH, int sW, int oC, int pad, bool bias = true) : m_alg(alg) {
63 |
64 | #if DEBUG_WINOGRAD
65 | assert(kH == kW, "kernel 3x3 is the best choice, some errors may occur for other kernels");
66 | #endif
67 | m_iH = iH;
68 | m_iW = iW;
69 | conv_in_channels_ = iC;
70 | m_kH = kH;
71 | m_kW = kW;
72 | m_sH = sH;
73 | m_sW = sW;
74 | conv_out_channels_ = oC;
75 | m_pad = pad; // pad_h = pad_w
76 | m_bias = bias;
77 |
78 | m_batchSize = batch_size;
79 | m_group_ = 1;
80 |
81 | m_bottom_dim_ = 0;// default batch =1
82 | m_top_dim_ = 0;
83 |
84 | m_winogradWeight = NULL;
85 | m_winogradInput = NULL;
86 |
87 |
88 | // Output width.
89 | m_oW = (m_iW + m_pad * 2 - m_kW) / m_sW + 1;
90 | m_oH = (m_iH + m_pad * 2 - m_kH) / m_sH + 1;
91 |
92 | if (alg == WT_8X8_F_6X6_3X3) {
93 |
94 | tile_h_in_ = 8;
95 | tile_w_in_ = 8; /* input tile size */
96 |
97 | tile_h_out_ = tile_h_in_ - m_kH + 1;
98 | tile_w_out_ = tile_w_in_ - m_kW + 1; /* output tile size */
99 |
100 | ntiles_h_ = (PUBLIC_TOOL::max(m_iH + m_pad - tile_h_in_ + 1, m_oH) + tile_h_out_ - 1) / tile_h_out_;
101 | ntiles_w_ = (PUBLIC_TOOL::max(m_iW + m_pad - tile_w_in_ + 1, m_oW) + tile_w_out_ - 1) / tile_w_out_;
102 |
103 | }
104 | else if (alg == WT_6X6_F_4X4_3X3) {
105 |
106 | tile_h_in_ = 6;
107 | tile_w_in_ = 6; /* input tile size */
108 |
109 | tile_h_out_ = tile_h_in_ - m_kH + 1;
110 | tile_w_out_ = tile_w_in_ - m_kW + 1; /* output tile size */
111 |
112 | ntiles_h_ = (PUBLIC_TOOL::max(m_iH + m_pad - tile_h_in_ + 1, m_oH) + tile_h_out_ - 1) / tile_h_out_;
113 | ntiles_w_ = (PUBLIC_TOOL::max(m_iW + m_pad - tile_w_in_ + 1, m_oW) + tile_w_out_ - 1) / tile_w_out_;
114 |
115 | }
116 | else throw("convolution algorithm error!");
117 |
118 | }
119 |
120 | template
121 | const std::shared_ptr get_inference_cpu(Dtype* data, const Dtype* par, Dtype* col_buff) {
122 |
123 | m_inputOrg = data;
124 | m_weightOrg = par;
125 | m_col_buff = col_buff;
126 |
127 |
128 | std::shared_ptr resOut = std::shared_ptr(new Dtype[m_oH*m_oW*conv_out_channels_]);
129 |
130 | //trans weight to winograd domain
131 | trans_weight2wiongrad();
132 |
133 |
134 | for (int n = 0; n < m_batchSize; n++) {
135 |
136 | //trans input to winograd domain
137 | trans_input2winograd(m_inputOrg + n*m_bottom_dim_, m_col_buff);
138 |
139 |
140 | // Convolution in Winograd domain
141 | winograd_conv();
142 |
143 |
144 | // Transform back to time domain
145 | trans2spatial(resOut.get() + n*this->m_top_dim_);
146 |
147 | //bias
148 | if (this->m_bias) {
149 |
150 | int base = conv_in_channels_ * conv_out_channels_ * m_kW * m_kH;
151 |
152 | const Dtype* bias = &par[base];
153 |
154 | this->forward_cpu_bias(resOut.get() + n * this->m_top_dim_, bias);
155 | }
156 | }
157 |
158 | return resOut;
159 | }
160 |
161 |
162 | public:
163 | ~WinogradLayer() {
164 | /*if (!m_winogradInput) delete[] m_winogradInput;
165 | if (!m_winogradWeight) delete[] m_winogradWeight;*/
166 | }
167 |
168 |
169 | private:
170 |
171 | void trans_weight2wiongrad() {// weight: hwcn --> cn hw
172 |
173 | // transform weights to Winograd domain
174 | if (!m_winogradWeight) m_winogradWeight = new Dtype[conv_in_channels_*conv_out_channels_*tile_h_in_*tile_w_in_];
175 |
176 | PUBLIC_TOOL::dlm_cpu_gemm(CblasNoTrans, CblasTrans,
177 | tile_h_in_*tile_w_in_, (conv_in_channels_ / m_group_)*conv_out_channels_, m_kH*m_kW,
178 | (Dtype)1,
179 | Winograd_Kron::getInstance(m_alg, WINOGRAD_G)->get().get(),
180 | m_weightOrg,
181 | (Dtype)0,
182 | m_winogradWeight);
183 |
184 | }
185 |
186 | template
187 | void trans_input2winograd(const Dtype *data, Dtype *col_buff) {
188 | // Transform input to Winograd domain
189 |
190 | winograd_input_im2col_cpu(data, col_buff);
191 |
192 |
193 | int M = this->conv_in_channels_*ntiles_h_*ntiles_w_;
194 |
195 | if (!m_winogradInput) m_winogradInput = new Dtype[tile_h_in_*tile_w_in_*M];
196 |
197 | PUBLIC_TOOL::dlm_cpu_gemm(CblasTrans, CblasTrans,
198 | tile_h_in_*tile_w_in_, M, tile_h_in_*tile_w_in_,
199 | (Dtype)1,
200 | Winograd_Kron::getInstance(m_alg, WINOGRAD_B)->get().get(),
201 | col_buff,
202 | (Dtype)0, this->m_winogradInput);
203 |
204 | }
205 |
206 | void winograd_conv() {
207 |
208 | // Convolution in Winograd domain
209 | for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
210 | for (int g = 0; g < this->m_group_; ++g) {
211 | PUBLIC_TOOL::dlm_cpu_gemm(CblasNoTrans, CblasNoTrans,
212 | this->conv_out_channels_ / this->m_group_, ntiles_h_*ntiles_w_, this->conv_in_channels_ / this->m_group_,
213 | (Dtype)1,
214 | m_winogradWeight + (j*this->m_group_ + g)*(this->conv_out_channels_ / this->m_group_)*(this->conv_in_channels_ / this->m_group_),
215 | m_winogradInput + (j*this->m_group_ + g)*(this->conv_in_channels_ / this->m_group_)*ntiles_h_*ntiles_w_,
216 | (Dtype)0, m_col_buff + (j*this->m_group_ + g)*(this->conv_out_channels_ / this->m_group_)*ntiles_h_*ntiles_w_);
217 | }
218 | }
219 | // col_buff has (tile_h_in*tile_w_in) x (conv_out_channels) x (ntiles_h*ntiles_w)
220 | }
221 |
222 | template
223 | void trans2spatial(Dtype *data) {
224 |
225 | Dtype *winogradRes = new Dtype[this->conv_out_channels_*ntiles_h_*ntiles_w_*tile_h_out_*tile_w_out_];
226 |
227 | PUBLIC_TOOL::dlm_cpu_gemm(CblasTrans, CblasNoTrans,
228 | this->conv_out_channels_*ntiles_h_*ntiles_w_, tile_h_out_*tile_w_out_, tile_h_in_*tile_w_in_,
229 | (Dtype)1, m_col_buff,
230 | Winograd_Kron::getInstance(m_alg, WINOGRAD_A)->get().get(),
231 | (Dtype)0, winogradRes);
232 |
233 | winograd_output_col2im_cpu(winogradRes, data);
234 |
235 | delete[] winogradRes;
236 | }
237 |
238 | template
239 | void winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff)
240 | {
241 | int height = m_iH;
242 | int width = m_iW;
243 | int pad_h = m_pad, pad_w = m_pad;
244 |
245 | for (int c = 0; c < this->conv_in_channels_; ++c) {
246 | for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
247 | for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
248 |
249 |
250 |
251 | for (int y = 0; y < tile_h_in_; ++y) {
252 | for (int x = 0; x < tile_w_in_; ++x) {
253 | int in_y = tile_h*tile_h_out_ + y - pad_h;
254 | int in_x = tile_w*tile_w_out_ + x - pad_w;
255 |
256 | if (in_y < 0 || in_x < 0 || in_y >= height || in_x >= width) {
257 | col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] = 0;
258 | }
259 | else {
260 | col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] =
261 | data[(c*height + in_y)*width + in_x];
262 | }
263 | }
264 | }
265 |
266 |
267 | } // for each tile
268 | } // for each tile
269 | } // for each input channel
270 | }
271 |
272 |
273 | template
274 | void forward_cpu_bias(Dtype* output,
275 | const Dtype* bias) {
276 |
277 | int out_spatial_dim_ = m_oH * m_oW;
278 |
279 | for (int i = 0; i < conv_out_channels_; i++) {
280 |
281 | for (int j = 0; j < out_spatial_dim_; j++)
282 | output[i*out_spatial_dim_ + j] += bias[i];
283 |
284 | }
285 |
286 | //dlm_cpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_,
287 | // out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(),
288 | // (Dtype)1., output);
289 | }
290 |
291 |
292 | template
293 | void winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data)
294 | {
295 | const int output_h = m_iH, output_w = m_iW;
296 |
297 | for (int n = 0; n < this->conv_out_channels_; ++n) {
298 | for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
299 | for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
300 | for (int y = 0; y < tile_h_out_; ++y) {
301 | for (int x = 0; x < tile_w_out_; ++x) {
302 | int out_y = tile_h*tile_h_out_ + y;
303 | int out_x = tile_w*tile_w_out_ + x;
304 |
305 | if (out_y < output_h && out_x < output_w) {
306 |
307 | /*int kk = 0;
308 | if ((((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x == 184604)
309 | kk++;
310 |
311 | cout << "dat: "<<(n*output_h + out_y)*output_w + out_x << " col : " << (((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x << endl;*/
312 |
313 | data[(n*output_h + out_y)*output_w + out_x] =
314 | col_buff[(((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x];
315 | }
316 | }
317 | }
318 | } // for each tile
319 | } // for each tile
320 | } // for each input channel
321 | }
322 |
323 | };
324 | }
325 |
326 | #endif
--------------------------------------------------------------------------------
/github_winograd/src/winograd_test.cpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "../include/winograd_kernel.h"
4 | #include "../include/winograd_layer.h"
5 | #include
6 | #include
7 |
8 | using namespace WINOGRAD_KERNEL;
9 | using namespace std;
10 |
// Test-case dimensions used by testWinograd: input channels, output channels,
// and input height/width.
constexpr int CIN = 3;
constexpr int COUT = 7;

constexpr int IH = 25;
constexpr int IW = 25;

// Digits passed to std::setprecision when printing the results.
constexpr int PRECISE = 0;

// NOTE(review): presumably select integer-valued input / kernel test data in
// testWinograd's initialisation code; confirm there. Kept as macros in case
// they are tested with #if.
#define INPUT_INTEGER 1
#define KERNEL_INTEGER 1

void testWinograd();
23 |
24 | int main() {
25 |
26 |
27 | WINOGRAD_KERNEL::winograd2D_initialize();
28 |
29 | testWinograd();
30 |
31 | return 0;
32 | }
33 |
34 | void testWinograd() {
35 |
36 | //int batch_size = 1;
37 |
38 | int tiH = IH;
39 | int tiW = IW;
40 |
41 | int tkW = 3;
42 | int tkH = 3;
43 |
44 | int tsW = 1;
45 | int tsH = 1;
46 |
47 | int tiC = CIN;
48 | const int toC = COUT;
49 |
50 | bool tbias = true;
51 |
52 | int tpad = 1;
53 |
54 | const auto toH = (tiH + tpad * 2 - tkH) / tsH + 1;
55 |
56 | // Output width.
57 | const auto toW = (tiW + tpad * 2 - tkW) / tsW + 1;
58 |
59 | cout << setprecision(PRECISE);
60 |
61 | //NCHW
62 | float* input = new float[tiC*tiH*tiW];
63 | float* kernel = new float[tiC*tkH*tkW*toC + toC];
64 |
65 | //initInput
66 | for (int c = 0; c wt8X8(
90 | WINOGRAD_KERNEL::WT_8X8_F_6X6_3X3, //WT_6X6_F_4X4_3X3
91 | 1,
92 | tiH,
93 | tiW,
94 | tiC,
95 | tkH,
96 | tkW,
97 | tsH,
98 | tsW,
99 | toC,
100 | tpad,
101 | tbias
102 | );
103 |
104 | WINOGRAD_KERNEL::WinogradLayer wt6x6(
105 | WINOGRAD_KERNEL::WT_6X6_F_4X4_3X3, //WT_6X6_F_4X4_3X3
106 | 1,
107 | tiH,
108 | tiW,
109 | tiC,
110 | tkH,
111 | tkW,
112 | tsH,
113 | tsW,
114 | toC,
115 | tpad,
116 | tbias
117 | );
118 |
119 | float* buffer=new float [toH*toW*tiC*100];// enough buffer, used as medium buffer flowing through each layer
120 |
121 | shared_ptr output = wt8X8.get_inference_cpu(input, kernel, (float*)buffer); //
122 |
123 | cout << "the first three elements and the last one of the wt8x8 result:" << endl;
124 | cout << output.get()[0] << " " << output.get()[1] << " " << output.get()[2] << " " << output.get()[toC*toH*toW - 1] << " " << endl;
125 |
126 |
127 | output = wt6x6.get_inference_cpu(input, kernel, (float*)buffer); //
128 |
129 | cout << "the first three elements and the last one of the wt6x6 result:" << endl;
130 | cout << output.get()[0] << " " << output.get()[1] << " " << output.get()[2] << " " << output.get()[toC*toH*toW - 1] << " " << endl;
131 |
132 | delete[] buffer;
133 | }
134 |
--------------------------------------------------------------------------------