├── .gitignore
├── CI
└── Azure-Master.yml
├── LICENSE
├── README.md
├── YOLOv4.sln
├── YOLOv4.sln.DotSettings
├── data
├── performance.png
├── result-int8.png
├── result.png
└── summary.txt
├── samples
├── DetectUI
│ ├── DetectUI.csproj
│ ├── Program.cs
│ ├── YoloForm.Designer.cs
│ ├── YoloForm.cs
│ └── YoloForm.resx
└── TrainV4
│ ├── Program.cs
│ ├── ToSavedModel.cs
│ ├── TrainV4.cs
│ ├── TrainV4.csproj
│ └── TrainingLogger.cs
├── src
├── BufferedEnumerable.cs
├── InternalsVisibleTo.cs
├── ListLinq.cs
├── Tools.cs
├── Utils.cs
├── YOLOv4.csproj
├── data
│ └── ObjectDetectionDataset.cs
├── datasets
│ └── ObjectDetection
│ │ └── MS_COCO.cs
├── image
│ └── ImageTools.cs
└── keras
│ ├── Activations.cs
│ ├── Blocks.cs
│ ├── applications
│ ├── ObjectDetectionResult.cs
│ ├── YOLO.Common.cs
│ ├── YOLO.Evaluate.cs
│ ├── YOLO.LearningRateSchedule.cs
│ ├── YOLO.Raw.cs
│ ├── YOLO.SaveModel.cs
│ └── YOLO.Train.cs
│ ├── callbacks
│ └── LearningRateLogger.cs
│ ├── layers
│ ├── FreezableBatchNormalization.cs
│ └── YoloLossEndpoint.cs
│ ├── losses
│ └── ZeroLoss.cs
│ ├── models
│ ├── CrossStagePartialDarknet53.cs
│ ├── Darknet53.cs
│ └── YOLOv4.cs
│ └── utils
│ └── Sequence.cs
└── test
├── TensorFlowFixture.cs
└── YOLOv4.Tests.csproj
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/yolov4.weights
2 | /data/anchors/
3 | /data/classes/
4 | /data/dataset/
5 |
6 | /samples/TrainV4/TrainLog/
7 | /samples/TrainV4/Trained/
8 | /samples/TrainV4/train.err
9 | /samples/TrainV4/train.out
10 |
11 | LOCAL_TESTS.cs
12 |
13 | ## Ignore Visual Studio temporary files, build results, and
14 | ## files generated by popular Visual Studio add-ons.
15 | ##
16 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
17 |
18 | # User-specific files
19 | *.rsuser
20 | *.suo
21 | *.user
22 | *.userosscache
23 | *.sln.docstates
24 |
25 | # User-specific files (MonoDevelop/Xamarin Studio)
26 | *.userprefs
27 |
28 | # Mono auto generated files
29 | mono_crash.*
30 |
31 | # Build results
32 | [Dd]ebug/
33 | [Dd]ebugPublic/
34 | [Rr]elease/
35 | [Rr]eleases/
36 | x64/
37 | x86/
38 | [Aa][Rr][Mm]/
39 | [Aa][Rr][Mm]64/
40 | bld/
41 | [Bb]in/
42 | [Oo]bj/
43 | [Ll]og/
44 | [Ll]ogs/
45 |
46 | # Visual Studio 2015/2017 cache/options directory
47 | .vs/
48 | # Uncomment if you have tasks that create the project's static files in wwwroot
49 | #wwwroot/
50 |
51 | # Visual Studio 2017 auto generated files
52 | Generated\ Files/
53 |
54 | # MSTest test Results
55 | [Tt]est[Rr]esult*/
56 | [Bb]uild[Ll]og.*
57 |
58 | # NUnit
59 | *.VisualState.xml
60 | TestResult.xml
61 | nunit-*.xml
62 |
63 | # Build Results of an ATL Project
64 | [Dd]ebugPS/
65 | [Rr]eleasePS/
66 | dlldata.c
67 |
68 | # Benchmark Results
69 | BenchmarkDotNet.Artifacts/
70 |
71 | # .NET Core
72 | project.lock.json
73 | project.fragment.lock.json
74 | artifacts/
75 |
76 | # StyleCop
77 | StyleCopReport.xml
78 |
79 | # Files built by Visual Studio
80 | *_i.c
81 | *_p.c
82 | *_h.h
83 | *.ilk
84 | *.meta
85 | *.obj
86 | *.iobj
87 | *.pch
88 | *.pdb
89 | *.ipdb
90 | *.pgc
91 | *.pgd
92 | *.rsp
93 | *.sbr
94 | *.tlb
95 | *.tli
96 | *.tlh
97 | *.tmp
98 | *.tmp_proj
99 | *_wpftmp.csproj
100 | *.log
101 | *.vspscc
102 | *.vssscc
103 | .builds
104 | *.pidb
105 | *.svclog
106 | *.scc
107 |
108 | # Chutzpah Test files
109 | _Chutzpah*
110 |
111 | # Visual C++ cache files
112 | ipch/
113 | *.aps
114 | *.ncb
115 | *.opendb
116 | *.opensdf
117 | *.sdf
118 | *.cachefile
119 | *.VC.db
120 | *.VC.VC.opendb
121 |
122 | # Visual Studio profiler
123 | *.psess
124 | *.vsp
125 | *.vspx
126 | *.sap
127 |
128 | # Visual Studio Trace Files
129 | *.e2e
130 |
131 | # TFS 2012 Local Workspace
132 | $tf/
133 |
134 | # Guidance Automation Toolkit
135 | *.gpState
136 |
137 | # ReSharper is a .NET coding add-in
138 | _ReSharper*/
139 | *.[Rr]e[Ss]harper
140 | *.DotSettings.user
141 |
142 | # TeamCity is a build add-in
143 | _TeamCity*
144 |
145 | # DotCover is a Code Coverage Tool
146 | *.dotCover
147 |
148 | # AxoCover is a Code Coverage Tool
149 | .axoCover/*
150 | !.axoCover/settings.json
151 |
152 | # Visual Studio code coverage results
153 | *.coverage
154 | *.coveragexml
155 |
156 | # NCrunch
157 | _NCrunch_*
158 | .*crunch*.local.xml
159 | nCrunchTemp_*
160 |
161 | # MightyMoose
162 | *.mm.*
163 | AutoTest.Net/
164 |
165 | # Web workbench (sass)
166 | .sass-cache/
167 |
168 | # Installshield output folder
169 | [Ee]xpress/
170 |
171 | # DocProject is a documentation generator add-in
172 | DocProject/buildhelp/
173 | DocProject/Help/*.HxT
174 | DocProject/Help/*.HxC
175 | DocProject/Help/*.hhc
176 | DocProject/Help/*.hhk
177 | DocProject/Help/*.hhp
178 | DocProject/Help/Html2
179 | DocProject/Help/html
180 |
181 | # Click-Once directory
182 | publish/
183 |
184 | # Publish Web Output
185 | *.[Pp]ublish.xml
186 | *.azurePubxml
187 | # Note: Comment the next line if you want to checkin your web deploy settings,
188 | # but database connection strings (with potential passwords) will be unencrypted
189 | *.pubxml
190 | *.publishproj
191 |
192 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
193 | # checkin your Azure Web App publish settings, but sensitive information contained
194 | # in these scripts will be unencrypted
195 | PublishScripts/
196 |
197 | # NuGet Packages
198 | *.nupkg
199 | # NuGet Symbol Packages
200 | *.snupkg
201 | # The packages folder can be ignored because of Package Restore
202 | **/[Pp]ackages/*
203 | # except build/, which is used as an MSBuild target.
204 | !**/[Pp]ackages/build/
205 | # Uncomment if necessary however generally it will be regenerated when needed
206 | #!**/[Pp]ackages/repositories.config
207 | # NuGet v3's project.json files produces more ignorable files
208 | *.nuget.props
209 | *.nuget.targets
210 |
211 | # Microsoft Azure Build Output
212 | csx/
213 | *.build.csdef
214 |
215 | # Microsoft Azure Emulator
216 | ecf/
217 | rcf/
218 |
219 | # Windows Store app package directories and files
220 | AppPackages/
221 | BundleArtifacts/
222 | Package.StoreAssociation.xml
223 | _pkginfo.txt
224 | *.appx
225 | *.appxbundle
226 | *.appxupload
227 |
228 | # Visual Studio cache files
229 | # files ending in .cache can be ignored
230 | *.[Cc]ache
231 | # but keep track of directories ending in .cache
232 | !?*.[Cc]ache/
233 |
234 | # Others
235 | ClientBin/
236 | ~$*
237 | *~
238 | *.dbmdl
239 | *.dbproj.schemaview
240 | *.jfm
241 | *.pfx
242 | *.publishsettings
243 | orleans.codegen.cs
244 |
245 | # Including strong name files can present a security risk
246 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
247 | #*.snk
248 |
249 | # Since there are multiple workflows, uncomment next line to ignore bower_components
250 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
251 | #bower_components/
252 |
253 | # RIA/Silverlight projects
254 | Generated_Code/
255 |
256 | # Backup & report files from converting an old project file
257 | # to a newer Visual Studio version. Backup files are not needed,
258 | # because we have git ;-)
259 | _UpgradeReport_Files/
260 | Backup*/
261 | UpgradeLog*.XML
262 | UpgradeLog*.htm
263 | ServiceFabricBackup/
264 | *.rptproj.bak
265 |
266 | # SQL Server files
267 | *.mdf
268 | *.ldf
269 | *.ndf
270 |
271 | # Business Intelligence projects
272 | *.rdl.data
273 | *.bim.layout
274 | *.bim_*.settings
275 | *.rptproj.rsuser
276 | *- [Bb]ackup.rdl
277 | *- [Bb]ackup ([0-9]).rdl
278 | *- [Bb]ackup ([0-9][0-9]).rdl
279 |
280 | # Microsoft Fakes
281 | FakesAssemblies/
282 |
283 | # GhostDoc plugin setting file
284 | *.GhostDoc.xml
285 |
286 | # Node.js Tools for Visual Studio
287 | .ntvs_analysis.dat
288 | node_modules/
289 |
290 | # Visual Studio 6 build log
291 | *.plg
292 |
293 | # Visual Studio 6 workspace options file
294 | *.opt
295 |
296 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
297 | *.vbw
298 |
299 | # Visual Studio LightSwitch build output
300 | **/*.HTMLClient/GeneratedArtifacts
301 | **/*.DesktopClient/GeneratedArtifacts
302 | **/*.DesktopClient/ModelManifest.xml
303 | **/*.Server/GeneratedArtifacts
304 | **/*.Server/ModelManifest.xml
305 | _Pvt_Extensions
306 |
307 | # Paket dependency manager
308 | .paket/paket.exe
309 | paket-files/
310 |
311 | # FAKE - F# Make
312 | .fake/
313 |
314 | # CodeRush personal settings
315 | .cr/personal
316 |
317 | # Python Tools for Visual Studio (PTVS)
318 | __pycache__/
319 | *.pyc
320 |
321 | # Cake - Uncomment if you are using it
322 | # tools/**
323 | # !tools/packages.config
324 |
325 | # Tabs Studio
326 | *.tss
327 |
328 | # Telerik's JustMock configuration file
329 | *.jmconfig
330 |
331 | # BizTalk build output
332 | *.btp.cs
333 | *.btm.cs
334 | *.odx.cs
335 | *.xsd.cs
336 |
337 | # OpenCover UI analysis results
338 | OpenCover/
339 |
340 | # Azure Stream Analytics local run output
341 | ASALocalRun/
342 |
343 | # MSBuild Binary and Structured Log
344 | *.binlog
345 |
346 | # NVidia Nsight GPU debugger configuration file
347 | *.nvuser
348 |
349 | # MFractors (Xamarin productivity tool) working folder
350 | .mfractor/
351 |
352 | # Local History for Visual Studio
353 | .localhistory/
354 |
355 | # BeatPulse healthcheck temp database
356 | healthchecksdb
357 |
358 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
359 | MigrationBackup/
360 |
361 | # Ionide (cross platform F# VS Code tools) working folder
362 | .ionide/
363 |
--------------------------------------------------------------------------------
/CI/Azure-Master.yml:
--------------------------------------------------------------------------------
1 | trigger:
2 | - master
3 |
4 | pool:
5 | vmImage: 'ubuntu-latest'
6 |
7 | variables:
8 | buildConfiguration: 'Release'
9 |
10 | steps:
11 | - script: dotnet restore
12 | displayName: 'Restore'
13 | - script: dotnet build --configuration $(buildConfiguration)
14 | displayName: 'Build'
15 | - script: dotnet test --no-build --configuration $(buildConfiguration) --collect:"XPlat Code Coverage" --logger trx
16 | displayName: 'Test'
17 | - script: dotnet pack --configuration $(buildConfiguration)
18 | displayName: 'Pack'
19 |
20 | - task: PublishTestResults@2
21 | inputs:
22 | testRunner: VSTest
23 | testResultsFiles: '**/*.trx'
24 |
25 | - task: PublishCodeCoverageResults@1
26 | displayName: 'Upload Coverage'
27 | inputs:
28 | codeCoverageTool: 'cobertura'
29 | summaryFileLocation: '$(Build.SourcesDirectory)/**/coverage.cobertura.xml'
30 | failIfCoverageEmpty: true
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Lost Tech LLC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv4
2 |
3 | [](LICENSE)
4 |
5 | **NOTICE**: This is a port of https://github.com/hunglc007/tensorflow-yolov4-tflite
6 |
7 | YOLOv4 implemented in TensorFlow 1.15
8 |
9 | ### Prerequisites
10 | [](https://www.nuget.org/packages/LostTech.TensorFlow)
11 |
12 | ### Performance
13 |

14 |
15 | ### Demo
16 |
17 | TBD
18 |
19 | #### Output
20 |
21 | ##### Yolov4 original weight
22 | 
23 |
24 | ##### Yolov4 tflite int8
25 | 
26 |
27 | ### Convert to ONNX
28 |
29 | TBD
30 |
31 | ### Evaluate on COCO 2017 Dataset
32 |
33 | TBD
34 |
35 | # evaluate yolov4 model
36 |
37 | TBD
38 |
39 | #### mAP50 on COCO 2017 Dataset
40 |
41 | | Detection | 512x512 | 416x416 | 320x320 |
42 | |-------------|---------|---------|---------|
43 | | YoloV3 | 55.43 | | |
44 | | YoloV4 | 61.96 | 57.33 | |
45 |
46 | ### Benchmark
47 |
48 | TBD
49 |
50 | #### Tesla P100
51 |
52 | | Detection | 512x512 | 416x416 | 320x320 |
53 | |-------------|---------|---------|---------|
54 | | YoloV3 FPS | 40.6 | 49.4 | 61.3 |
55 | | YoloV4 FPS | 33.4 | 41.7 | 50.0 |
56 |
57 | #### Tesla K80
58 |
59 | | Detection | 512x512 | 416x416 | 320x320 |
60 | |-------------|---------|---------|---------|
61 | | YoloV3 FPS | 10.8 | 12.9 | 17.6 |
62 | | YoloV4 FPS | 9.6 | 11.7 | 16.0 |
63 |
64 | #### Tesla T4
65 |
66 | | Detection | 512x512 | 416x416 | 320x320 |
67 | |-------------|---------|---------|---------|
68 | | YoloV3 FPS | 27.6 | 32.3 | 45.1 |
69 | | YoloV4 FPS | 24.0 | 30.3 | 40.1 |
70 |
71 | #### Tesla P4
72 |
73 | | Detection | 512x512 | 416x416 | 320x320 |
74 | |-------------|---------|---------|---------|
75 | | YoloV3 FPS | 20.2 | 24.2 | 31.2 |
76 | | YoloV4 FPS | 16.2 | 20.2 | 26.5 |
77 |
78 | #### Macbook Pro 15 (2.3GHz i7)
79 |
80 | | Detection | 512x512 | 416x416 | 320x320 |
81 | |-------------|---------|---------|---------|
82 | | YoloV3 FPS | | | |
83 | | YoloV4 FPS | | | |
84 |
85 | ### Training your own model
86 |
87 | Sample training code available at [samples/TrainV4](samples/TrainV4)
88 |
89 | ### References
90 |
91 | * YOLOv4: Optimal Speed and Accuracy of Object Detection [YOLOv4](https://arxiv.org/abs/2004.10934).
92 | * [darknet](https://github.com/AlexeyAB/darknet)
93 |
94 | My project is inspired by these previous fantastic YOLOv3 implementations:
95 | * [Yolov3 tensorflow](https://github.com/YunYang1994/tensorflow-yolov3)
96 | * [Yolov3 tf2](https://github.com/zzh8829/yolov3-tf2)
97 |
--------------------------------------------------------------------------------
/YOLOv4.sln:
--------------------------------------------------------------------------------
1 | Microsoft Visual Studio Solution File, Format Version 12.00
2 | # Visual Studio Version 16
3 | VisualStudioVersion = 16.0.30309.148
4 | MinimumVisualStudioVersion = 15.0.26124.0
5 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Repo", "Repo", "{9E5831CD-6490-4374-8502-825AB423A8B7}"
6 | ProjectSection(SolutionItems) = preProject
7 | .gitignore = .gitignore
8 | LICENSE = LICENSE
9 | README.md = README.md
10 | EndProjectSection
11 | EndProject
12 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CI", "CI", "{8C3BD3D9-5F7E-4622-80E3-45BF5190320E}"
13 | ProjectSection(SolutionItems) = preProject
14 | CI\Azure-Master.yml = CI\Azure-Master.yml
15 | EndProjectSection
16 | EndProject
17 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "YOLOv4", "src\YOLOv4.csproj", "{04BE1707-4235-44E6-AB58-48621D5160D3}"
18 | EndProject
19 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "YOLOv4.Tests", "test\YOLOv4.Tests.csproj", "{241AC79D-0399-4B2A-9C01-07C609F74B69}"
20 | EndProject
21 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TrainV4", "samples\TrainV4\TrainV4.csproj", "{5299523B-B225-4A47-8C58-66D58B3547D2}"
22 | EndProject
23 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Samples", "Samples", "{4AA5CFE0-5A26-45D9-9748-FB3F89EDE83A}"
24 | EndProject
25 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DetectUI", "samples\DetectUI\DetectUI.csproj", "{D695A271-2BBA-4737-81E2-6A9BE0A8290E}"
26 | EndProject
27 | Global
28 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
29 | Debug|Any CPU = Debug|Any CPU
30 | Release|Any CPU = Release|Any CPU
31 | EndGlobalSection
32 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
33 | {04BE1707-4235-44E6-AB58-48621D5160D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
34 | {04BE1707-4235-44E6-AB58-48621D5160D3}.Debug|Any CPU.Build.0 = Debug|Any CPU
35 | {04BE1707-4235-44E6-AB58-48621D5160D3}.Release|Any CPU.ActiveCfg = Release|Any CPU
36 | {04BE1707-4235-44E6-AB58-48621D5160D3}.Release|Any CPU.Build.0 = Release|Any CPU
37 | {241AC79D-0399-4B2A-9C01-07C609F74B69}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
38 | {241AC79D-0399-4B2A-9C01-07C609F74B69}.Debug|Any CPU.Build.0 = Debug|Any CPU
39 | {241AC79D-0399-4B2A-9C01-07C609F74B69}.Release|Any CPU.ActiveCfg = Release|Any CPU
40 | {241AC79D-0399-4B2A-9C01-07C609F74B69}.Release|Any CPU.Build.0 = Release|Any CPU
41 | {5299523B-B225-4A47-8C58-66D58B3547D2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
42 | {5299523B-B225-4A47-8C58-66D58B3547D2}.Debug|Any CPU.Build.0 = Debug|Any CPU
43 | {5299523B-B225-4A47-8C58-66D58B3547D2}.Release|Any CPU.ActiveCfg = Release|Any CPU
44 | {5299523B-B225-4A47-8C58-66D58B3547D2}.Release|Any CPU.Build.0 = Release|Any CPU
45 | {D695A271-2BBA-4737-81E2-6A9BE0A8290E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
46 | {D695A271-2BBA-4737-81E2-6A9BE0A8290E}.Debug|Any CPU.Build.0 = Debug|Any CPU
47 | {D695A271-2BBA-4737-81E2-6A9BE0A8290E}.Release|Any CPU.ActiveCfg = Release|Any CPU
48 | {D695A271-2BBA-4737-81E2-6A9BE0A8290E}.Release|Any CPU.Build.0 = Release|Any CPU
49 | EndGlobalSection
50 | GlobalSection(SolutionProperties) = preSolution
51 | HideSolutionNode = FALSE
52 | EndGlobalSection
53 | GlobalSection(NestedProjects) = preSolution
54 | {5299523B-B225-4A47-8C58-66D58B3547D2} = {4AA5CFE0-5A26-45D9-9748-FB3F89EDE83A}
55 | {D695A271-2BBA-4737-81E2-6A9BE0A8290E} = {4AA5CFE0-5A26-45D9-9748-FB3F89EDE83A}
56 | EndGlobalSection
57 | GlobalSection(ExtensibilityGlobals) = postSolution
58 | SolutionGuid = {4D804A23-81D4-4FE4-9ADE-1110371610D4}
59 | EndGlobalSection
60 | EndGlobal
61 |
--------------------------------------------------------------------------------
/YOLOv4.sln.DotSettings:
--------------------------------------------------------------------------------
1 |
2 | True
--------------------------------------------------------------------------------
/data/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/losttech/YOLOv4/0f09f4e2d446699557aefcc225d5f8b11caa365d/data/performance.png
--------------------------------------------------------------------------------
/data/result-int8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/losttech/YOLOv4/0f09f4e2d446699557aefcc225d5f8b11caa365d/data/result-int8.png
--------------------------------------------------------------------------------
/data/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/losttech/YOLOv4/0f09f4e2d446699557aefcc225d5f8b11caa365d/data/result.png
--------------------------------------------------------------------------------
/data/summary.txt:
--------------------------------------------------------------------------------
1 | conv2d: 864
2 | batch_normalization: 64
3 | conv2d_1: 18432
4 | batch_normalization_1: 128
5 | conv2d_3: 4096
6 | batch_normalization_3: 128
7 | conv2d_4: 2048
8 | batch_normalization_4: 64
9 | conv2d_5: 18432
10 | batch_normalization_5: 128
11 | conv2d_6: 4096
12 | conv2d_2: 4096
13 | batch_normalization_6: 128
14 | batch_normalization_2: 128
15 | conv2d_7: 8192
16 | batch_normalization_7: 128
17 | conv2d_8: 73728
18 | batch_normalization_8: 256
19 | conv2d_10: 8192
20 | batch_normalization_10: 128
21 | conv2d_11: 4096
22 | batch_normalization_11: 128
23 | conv2d_12: 36864
24 | batch_normalization_12: 128
25 | conv2d_13: 4096
26 | batch_normalization_13: 128
27 | conv2d_14: 36864
28 | batch_normalization_14: 128
29 | conv2d_15: 4096
30 | conv2d_9: 8192
31 | batch_normalization_15: 128
32 | batch_normalization_9: 128
33 | conv2d_16: 16384
34 | batch_normalization_16: 256
35 | conv2d_17: 294912
36 | batch_normalization_17: 512
37 | conv2d_19: 32768
38 | batch_normalization_19: 256
39 | conv2d_20: 16384
40 | batch_normalization_20: 256
41 | conv2d_21: 147456
42 | batch_normalization_21: 256
43 | conv2d_22: 16384
44 | batch_normalization_22: 256
45 | conv2d_23: 147456
46 | batch_normalization_23: 256
47 | conv2d_24: 16384
48 | batch_normalization_24: 256
49 | conv2d_25: 147456
50 | batch_normalization_25: 256
51 | conv2d_26: 16384
52 | batch_normalization_26: 256
53 | conv2d_27: 147456
54 | batch_normalization_27: 256
55 | conv2d_28: 16384
56 | batch_normalization_28: 256
57 | conv2d_29: 147456
58 | batch_normalization_29: 256
59 | conv2d_30: 16384
60 | batch_normalization_30: 256
61 | conv2d_31: 147456
62 | batch_normalization_31: 256
63 | conv2d_32: 16384
64 | batch_normalization_32: 256
65 | conv2d_33: 147456
66 | batch_normalization_33: 256
67 | conv2d_34: 16384
68 | batch_normalization_34: 256
69 | conv2d_35: 147456
70 | batch_normalization_35: 256
71 | conv2d_36: 16384
72 | conv2d_18: 32768
73 | batch_normalization_36: 256
74 | batch_normalization_18: 256
75 | conv2d_37: 65536
76 | batch_normalization_37: 512
77 | conv2d_38: 1179648
78 | batch_normalization_38: 1024
79 | conv2d_40: 131072
80 | batch_normalization_40: 512
81 | conv2d_41: 65536
82 | batch_normalization_41: 512
83 | conv2d_42: 589824
84 | batch_normalization_42: 512
85 | conv2d_43: 65536
86 | batch_normalization_43: 512
87 | conv2d_44: 589824
88 | batch_normalization_44: 512
89 | conv2d_45: 65536
90 | batch_normalization_45: 512
91 | conv2d_46: 589824
92 | batch_normalization_46: 512
93 | conv2d_47: 65536
94 | batch_normalization_47: 512
95 | conv2d_48: 589824
96 | batch_normalization_48: 512
97 | conv2d_49: 65536
98 | batch_normalization_49: 512
99 | conv2d_50: 589824
100 | batch_normalization_50: 512
101 | conv2d_51: 65536
102 | batch_normalization_51: 512
103 | conv2d_52: 589824
104 | batch_normalization_52: 512
105 | conv2d_53: 65536
106 | batch_normalization_53: 512
107 | conv2d_54: 589824
108 | batch_normalization_54: 512
109 | conv2d_55: 65536
110 | batch_normalization_55: 512
111 | conv2d_56: 589824
112 | batch_normalization_56: 512
113 | conv2d_57: 65536
114 | conv2d_39: 131072
115 | batch_normalization_57: 512
116 | batch_normalization_39: 512
117 | conv2d_58: 262144
118 | batch_normalization_58: 1024
119 | conv2d_59: 4718592
120 | batch_normalization_59: 2048
121 | conv2d_61: 524288
122 | batch_normalization_61: 1024
123 | conv2d_62: 262144
124 | batch_normalization_62: 1024
125 | conv2d_63: 2359296
126 | batch_normalization_63: 1024
127 | conv2d_64: 262144
128 | batch_normalization_64: 1024
129 | conv2d_65: 2359296
130 | batch_normalization_65: 1024
131 | conv2d_66: 262144
132 | batch_normalization_66: 1024
133 | conv2d_67: 2359296
134 | batch_normalization_67: 1024
135 | conv2d_68: 262144
136 | batch_normalization_68: 1024
137 | conv2d_69: 2359296
138 | batch_normalization_69: 1024
139 | conv2d_70: 262144
140 | conv2d_60: 524288
141 | batch_normalization_70: 1024
142 | batch_normalization_60: 1024
143 | conv2d_71: 1048576
144 | batch_normalization_71: 2048
145 | conv2d_72: 524288
146 | batch_normalization_72: 1024
147 | conv2d_73: 4718592
148 | batch_normalization_73: 2048
149 | conv2d_74: 524288
150 | batch_normalization_74: 1024
151 | conv2d_75: 1048576
152 | batch_normalization_75: 1024
153 | conv2d_76: 4718592
154 | batch_normalization_76: 2048
155 | conv2d_77: 524288
156 | batch_normalization_77: 1024
157 | conv2d_78: 131072
158 | batch_normalization_78: 512
159 | conv2d_79: 131072
160 | batch_normalization_79: 512
161 | conv2d_80: 131072
162 | batch_normalization_80: 512
163 | conv2d_81: 1179648
164 | batch_normalization_81: 1024
165 | conv2d_82: 131072
166 | batch_normalization_82: 512
167 | conv2d_83: 1179648
168 | batch_normalization_83: 1024
169 | conv2d_84: 131072
170 | batch_normalization_84: 512
171 | conv2d_85: 32768
172 | batch_normalization_85: 256
173 | conv2d_86: 32768
174 | batch_normalization_86: 256
175 | conv2d_87: 32768
176 | batch_normalization_87: 256
177 | conv2d_88: 294912
178 | batch_normalization_88: 512
179 | conv2d_89: 32768
180 | batch_normalization_89: 256
181 | conv2d_90: 294912
182 | batch_normalization_90: 512
183 | conv2d_91: 32768
184 | batch_normalization_91: 256
185 | conv2d_94: 294912
186 | batch_normalization_93: 512
187 | conv2d_95: 131072
188 | batch_normalization_94: 512
189 | conv2d_96: 1179648
190 | batch_normalization_95: 1024
191 | conv2d_97: 131072
192 | batch_normalization_96: 512
193 | conv2d_98: 1179648
194 | batch_normalization_97: 1024
195 | conv2d_99: 131072
196 | batch_normalization_98: 512
197 | conv2d_102: 1179648
198 | batch_normalization_100: 1024
199 | conv2d_103: 524288
200 | batch_normalization_101: 1024
201 | conv2d_104: 4718592
202 | batch_normalization_102: 2048
203 | conv2d_105: 524288
204 | batch_normalization_103: 1024
205 | conv2d_106: 4718592
206 | batch_normalization_104: 2048
207 | conv2d_107: 524288
208 | batch_normalization_105: 1024
209 | conv2d_92: 294912
210 | conv2d_100: 1179648
211 | conv2d_108: 4718592
212 | batch_normalization_92: 512
213 | batch_normalization_99: 1024
214 | batch_normalization_106: 2048
215 | conv2d_93: 65535
216 | conv2d_101: 130815
217 | conv2d_109: 261375
--------------------------------------------------------------------------------
/samples/DetectUI/DetectUI.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | netcoreapp3.1
6 | true
7 | true
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/samples/DetectUI/Program.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Runtime.InteropServices;
5 | using System.Threading.Tasks;
6 | using System.Windows.Forms;
7 |
8 | namespace DetectUI {
9 | static class Program {
10 | ///
11 | /// The main entry point for the application.
12 | ///
13 | [STAThread]
14 | static void Main() {
15 | Application.SetHighDpiMode(HighDpiMode.SystemAware);
16 | Application.EnableVisualStyles();
17 | Application.SetCompatibleTextRenderingDefault(false);
18 | Application.Run(new YoloForm());
19 | }
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/samples/DetectUI/YoloForm.Designer.cs:
--------------------------------------------------------------------------------
1 | namespace DetectUI {
2 | partial class YoloForm {
3 | ///
4 | /// Required designer variable.
5 | ///
6 | private System.ComponentModel.IContainer components = null;
7 |
8 | ///
9 | /// Clean up any resources being used.
10 | ///
11 | /// true if managed resources should be disposed; otherwise, false.
12 | protected override void Dispose(bool disposing) {
13 | if (disposing && (components != null)) {
14 | components.Dispose();
15 | }
16 | base.Dispose(disposing);
17 | }
18 |
19 | #region Windows Form Designer generated code
20 |
21 | ///
22 | /// Required method for Designer support - do not modify
23 | /// the contents of this method with the code editor.
24 | ///
25 | private void InitializeComponent() {
26 | this.openPic = new System.Windows.Forms.Button();
27 | this.pictureBox = new System.Windows.Forms.PictureBox();
28 | this.openPicDialog = new System.Windows.Forms.OpenFileDialog();
29 | this.openWeightsDirDialog = new System.Windows.Forms.FolderBrowserDialog();
30 | ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).BeginInit();
31 | this.SuspendLayout();
32 | //
33 | // openPic
34 | //
35 | this.openPic.Anchor = ((System.Windows.Forms.AnchorStyles)(((System.Windows.Forms.AnchorStyles.Bottom | System.Windows.Forms.AnchorStyles.Left)
36 | | System.Windows.Forms.AnchorStyles.Right)));
37 | this.openPic.AutoSize = true;
38 | this.openPic.AutoSizeMode = System.Windows.Forms.AutoSizeMode.GrowAndShrink;
39 | this.openPic.Enabled = false;
40 | this.openPic.Location = new System.Drawing.Point(12, 408);
41 | this.openPic.Name = "openPic";
42 | this.openPic.Size = new System.Drawing.Size(101, 30);
43 | this.openPic.TabIndex = 0;
44 | this.openPic.Text = "Open Image";
45 | this.openPic.UseVisualStyleBackColor = true;
46 | this.openPic.Click += new System.EventHandler(this.openPic_Click);
47 | //
48 | // pictureBox
49 | //
50 | this.pictureBox.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom)
51 | | System.Windows.Forms.AnchorStyles.Left)
52 | | System.Windows.Forms.AnchorStyles.Right)));
53 | this.pictureBox.Location = new System.Drawing.Point(1, -1);
54 | this.pictureBox.Name = "pictureBox";
55 | this.pictureBox.Size = new System.Drawing.Size(800, 403);
56 | this.pictureBox.SizeMode = System.Windows.Forms.PictureBoxSizeMode.Zoom;
57 | this.pictureBox.TabIndex = 1;
58 | this.pictureBox.TabStop = false;
59 | //
60 | // openPicDialog
61 | //
62 | this.openPicDialog.Title = "Open Picture";
63 | //
64 | // openWeightsDirDialog
65 | //
66 | this.openWeightsDirDialog.Description = "Select directory with SavedModel";
67 | this.openWeightsDirDialog.RootFolder = System.Environment.SpecialFolder.MyDocuments;
68 | this.openWeightsDirDialog.ShowNewFolderButton = false;
69 | //
70 | // YoloForm
71 | //
72 | this.AutoScaleDimensions = new System.Drawing.SizeF(8F, 20F);
73 | this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
74 | this.ClientSize = new System.Drawing.Size(800, 450);
75 | this.Controls.Add(this.pictureBox);
76 | this.Controls.Add(this.openPic);
77 | this.Name = "YoloForm";
78 | this.Text = "YOLO";
79 | this.Load += new System.EventHandler(this.YoloForm_Load);
80 | ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).EndInit();
81 | this.ResumeLayout(false);
82 | this.PerformLayout();
83 |
84 | }
85 |
86 | #endregion
87 |
88 | private System.Windows.Forms.Button openPic;
89 | private System.Windows.Forms.PictureBox pictureBox;
90 | private System.Windows.Forms.OpenFileDialog openPicDialog;
91 | private System.Windows.Forms.FolderBrowserDialog openWeightsDirDialog;
92 | }
93 | }
94 |
95 |
--------------------------------------------------------------------------------
/samples/DetectUI/YoloForm.cs:
--------------------------------------------------------------------------------
namespace DetectUI {
    using System;
    using System.Collections.Generic;
    using System.Data;
    using System.Diagnostics;
    using System.IO;
    using System.Linq;
    using System.Windows.Forms;

    using LostTech.Gradient;
    using LostTech.Gradient.Exceptions;
    using LostTech.TensorFlow;

    using SixLabors.Fonts;
    using SixLabors.ImageSharp;
    using SixLabors.ImageSharp.Drawing;
    using SixLabors.ImageSharp.Drawing.Processing;
    using SixLabors.ImageSharp.PixelFormats;
    using SixLabors.ImageSharp.Processing;

    using tensorflow;
    using tensorflow.core.protobuf.config_pb2;
    using tensorflow.datasets.ObjectDetection;
    using tensorflow.keras.applications;

    /// <summary>
    /// Simple UI for running YOLO object detection: loads a SavedModel
    /// (from the DETECT_UI_WEIGHTS environment variable or a folder picker),
    /// then lets the user pick images and draws detections on top of them.
    /// </summary>
    public partial class YoloForm : Form {
        dynamic model; // SavedModel loaded via tf.saved_model.load_v2
        dynamic infer; // the model's "serving_default" signature
        bool loaded;   // set once a model has been loaded successfully

        public YoloForm() {
            this.InitializeComponent();

            this.openPicDialog.InitialDirectory = Environment.GetEnvironmentVariable("IMG_DIR")
                ?? Environment.GetFolderPath(Environment.SpecialFolder.MyPictures);

            GradientEngine.UseEnvironmentFromVariable();
            TensorFlowSetup.Instance.EnsureInitialized();

            // TODO: remove this after replacing tf.sigmoid in PostProcessBBBox
            tf.enable_eager_execution();
            tf.enable_v2_behavior();

            // allow_growth avoids TensorFlow reserving all GPU memory up front
            dynamic config = config_pb2.ConfigProto.CreateInstance();
            config.gpu_options.allow_growth = true;
            tf.keras.backend.set_session(Session.NewDyn(config: config));
        }

        /// <summary>Maps a class index to its MS COCO name, tolerating out-of-range indices.</summary>
        static string ClassName(int classIndex)
            => classIndex >= 0 && classIndex < MS_COCO.ClassCount
                ? MS_COCO.ClassNames[classIndex]
                : "imaginary class";

        private void openPic_Click(object sender, EventArgs e) {
            if (this.openPicDialog.ShowDialog(this) != DialogResult.OK)
                return;

            using var image = Image.Load(this.openPicDialog.FileName);

            var timer = Stopwatch.StartNew();
            ObjectDetectionResult[] detections = YOLO.Detect(this.infer,
                supportedSize: new Size(MS_COCO.InputSize, MS_COCO.InputSize),
                image: image);
            timer.Stop();

            // overlay each detection: "class: score" label above a bounding box
            image.Mutate(context => {
                var font = SystemFonts.CreateFont("Arial", 16);
                var textColor = Color.White;
                var boxPen = new Pen(Color.White, width: 4);
                foreach (var detection in detections) {
                    string text = $"{ClassName(detection.Class)}: {detection.Score:P0}";
                    var box = Scale(detection.Box, image.Size());
                    context.DrawText(text, font, textColor, TopLeft(box));
                    var drawingBox = new RectangularPolygon(box);
                    context.Draw(boxPen, drawingBox);
                }
            });

            using var temp = new MemoryStream();
            image.SaveAsBmp(temp);
            temp.Position = 0;

            // BUGFIX: dispose the previously displayed bitmap — PictureBox does not
            // dispose the old Image on reassignment, leaking GDI handles over time
            var oldImage = this.pictureBox.Image;
            this.pictureBox.Image = new System.Drawing.Bitmap(temp);
            oldImage?.Dispose();

            // BUGFIX: use the same bounds-checked lookup as the drawing loop above;
            // previously an out-of-range detection.Class threw IndexOutOfRangeException here
            this.Text = "YOLO " + string.Join(", ", detections.Select(d => ClassName(d.Class)))
                + " in " + timer.ElapsedMilliseconds + "ms";
        }

        static PointF TopLeft(RectangleF rect) => new PointF(rect.Left, rect.Top);
        // detection boxes appear to be in relative (0..1) coordinates; scale to pixels
        static RectangleF Scale(RectangleF rect, SizeF size)
            => new RectangleF(x: rect.Left * size.Width, width: rect.Width * size.Width,
                              y: rect.Top * size.Height, height: rect.Height * size.Height);

        /// <summary>
        /// Loads a SavedModel from DETECT_UI_WEIGHTS or a user-picked folder,
        /// retrying until a load succeeds; then enables the "open picture" button.
        /// </summary>
        void LoadWeights() {
            // NOTE(review): if DETECT_UI_WEIGHTS points at an invalid model, this loop
            // retries the same path forever without user interaction — confirm intended.
            while (!this.loaded) {
                string modelDir = Environment.GetEnvironmentVariable("DETECT_UI_WEIGHTS");
                if (modelDir is null) {
                    if (this.openWeightsDirDialog.ShowDialog(this) != DialogResult.OK)
                        continue;

                    modelDir = this.openWeightsDirDialog.SelectedPath;
                }

                try {
                    this.model = tf.saved_model.load_v2(modelDir, tags: tf.saved_model.SERVING);
                    this.infer = this.model.signatures["serving_default"];
                } catch (ValueError e) {
                    // surface the TensorFlow error in the title bar and let the user retry
                    this.Text = e.Message;
                    continue;
                }
                this.loaded = true;
                this.Text = "YOLO " + modelDir;
            }

            this.openPic.Enabled = true;
        }

        private void YoloForm_Load(object sender, EventArgs e) {
            this.LoadWeights();
        }
    }
}
--------------------------------------------------------------------------------
/samples/DetectUI/YoloForm.resx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 | text/microsoft-resx
50 |
51 |
52 | 2.0
53 |
54 |
55 | System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
56 |
57 |
58 | System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089
59 |
60 |
--------------------------------------------------------------------------------
/samples/TrainV4/Program.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System;

    using LostTech.Gradient;
    using LostTech.TensorFlow;

    using ManyConsole.CommandLineUtils;

    /// <summary>Entry point: initializes the TensorFlow runtime, then dispatches
    /// the command line to one of the console commands in this assembly.</summary>
    class Program {
        static int Main(string[] args) {
            Console.Title = "YOLOv4";
            GradientEngine.UseEnvironmentFromVariable();
            TensorFlowSetup.Instance.EnsureInitialized();

            var commands = ConsoleCommandDispatcher.FindCommandsInSameAssemblyAs(typeof(Program));
            return ConsoleCommandDispatcher.DispatchCommand(commands, args, Console.Out);
        }
    }
}
--------------------------------------------------------------------------------
/samples/TrainV4/ToSavedModel.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System;
    using System.Collections.Generic;

    using ManyConsole.CommandLineUtils;

    using numpy;

    using tensorflow.datasets.ObjectDetection;
    using tensorflow.keras.applications;
    using tensorflow.keras.models;

    /// <summary>
    /// Console command "to-saved-model": loads trained YOLOv4 weights and exports
    /// a TensorFlow SavedModel with decoding and score filtering included.
    /// </summary>
    class ToSavedModel : ConsoleCommand {
        /// <summary>Path to the weights file (.index) to load.</summary>
        public string WeightsPath { get; set; }
        /// <summary>Backward-compatible alias for <see cref="WeightsPath"/>.
        /// BUGFIX: the original property name was misspelled; kept as an alias for existing callers.</summary>
        [Obsolete("Use " + nameof(WeightsPath) + " instead")]
        public string WeigthsPath { get => this.WeightsPath; set => this.WeightsPath = value; }
        /// <summary>Path where the SavedModel is written.</summary>
        public string OutputPath { get; set; }
        public int InputSize { get; set; } = MS_COCO.InputSize;
        public int ClassCount { get; set; } = MS_COCO.ClassCount;
        /// <summary>Detections scoring below this are dropped by the exported model.</summary>
        public float ScoreThreshold { get; set; } = 0.2f;
        public int[] Strides { get; set; } = YOLOv4.Strides.ToArray();
        public ndarray Anchors { get; set; } = YOLOv4.Anchors;

        public override int Run(string[] remainingArguments) {
            // recreate the trainable network and load the trained weights into it
            var trainable = YOLO.CreateV4Trainable(inputSize: this.InputSize,
                classCount: this.ClassCount,
                strides: this.Strides);
            trainable.load_weights(this.WeightsPath);
            // wrap the raw network outputs into a saveable model that also
            // decodes boxes and filters by score
            var output = YOLOv4.Output.Get(trainable);
            Tensor input = trainable.input_dyn;
            var savable = YOLO.CreateSaveable(inputSize: this.InputSize, input, output,
                classCount: this.ClassCount,
                strides: this.Strides,
                anchors: tf.constant(this.Anchors),
                xyScale: YOLOv4.XYScale,
                scoreThreshold: this.ScoreThreshold);
            savable.summary();
            savable.save(this.OutputPath, save_format: "tf", include_optimizer: false);
            return 0;
        }

        public ToSavedModel() {
            this.IsCommand("to-saved-model");
            this.HasRequiredOption("w|weights=", "Path to weights file (.index)",
                path => this.WeightsPath = path);
            this.HasRequiredOption("o|output=", "Path to the output file",
                path => this.OutputPath = path);
            this.HasOption("t|score-threshold=", "Minimal score for detections",
                (float threshold) => this.ScoreThreshold = threshold);
        }
    }
}
--------------------------------------------------------------------------------
/samples/TrainV4/TrainV4.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.IO;
    using System.Linq;

    using LostTech.Gradient;

    using ManyConsole.CommandLineUtils;

    using numpy;

    using tensorflow.core.protobuf.config_pb2;
    using tensorflow.data;
    using tensorflow.datasets.ObjectDetection;
    using tensorflow.keras.applications;
    using tensorflow.keras.callbacks;
    using tensorflow.keras.models;
    using tensorflow.keras.optimizers;

    /// <summary>
    /// Console command "trainV4": trains a YOLOv4 model on an MS COCO-compatible
    /// annotation file in two stages (transfer epochs, then fine-tuning epochs),
    /// logging to TensorBoard and checkpointing weights per epoch.
    /// </summary>
    class TrainV4 : ConsoleCommand {
        // one annotation line per image (path + bounding boxes), see Tools.NonEmptyLines
        public string[] Annotations { get; set; }
        public string[] ClassNames { get; set; }
        public int InputSize { get; set; } = MS_COCO.InputSize;
        public int MaxBBoxPerScale { get; set; } = 150;
        public int BatchSize { get; set; } = 2;
        public ndarray Anchors { get; set; } = YOLOv4.Anchors.AsType();
        public int AnchorsPerScale { get; set; } = YOLOv4.AnchorsPerScale;
        public int[] Strides { get; set; } = YOLOv4.Strides.ToArray();
        public bool LogDevicePlacement { get; set; }
        public bool GpuAllowGrowth { get; set; }
        public bool ModelSummary { get; set; }
        // TestRun shrinks the dataset to 3 batches; Benchmark trains without saving weights
        public bool TestRun { get; set; }
        public bool Benchmark { get; set; }
        public int FirstStageEpochs { get; set; } = 20;
        public int SecondStageEpochs { get; set; } = 30;
        // epochs over which the learning rate schedule ramps up
        public int WarmupEpochs { get; set; } = 2;
        public string LogDir { get; set; }
        public string? WeightsPath { get; set; }

        public override int Run(string[] remainingArguments) {
            // mirror trace output to stderr so it is visible in console runs
            Trace.Listeners.Add(new ConsoleTraceListener(useErrorStream: true));

            tf.debugging.set_log_device_placement(this.LogDevicePlacement);

            if (this.GpuAllowGrowth) {
                // allocate GPU memory on demand instead of reserving it all upfront
                dynamic config = config_pb2.ConfigProto.CreateInstance();
                config.gpu_options.allow_growth = true;
                tf.keras.backend.set_session(Session.NewDyn(config: config));
            }

            if (this.TestRun)
                this.Annotations = this.Annotations.Take(this.BatchSize*3).ToArray();

            var dataset = new ObjectDetectionDataset(this.Annotations,
                classNames: this.ClassNames,
                strides: this.Strides,
                inputSize: this.InputSize,
                anchors: this.Anchors,
                anchorsPerScale: this.AnchorsPerScale,
                maxBBoxPerScale: this.MaxBBoxPerScale);
            var model = YOLO.CreateV4Trainable(dataset.InputSize, dataset.ClassNames.Length, dataset.Strides);

            // schedule spans both training stages; warmup counted in steps (batches)
            var learningRateSchedule = new YOLO.LearningRateSchedule(
                totalSteps: (long)(this.FirstStageEpochs + this.SecondStageEpochs) * dataset.BatchCount(this.BatchSize),
                warmupSteps: this.WarmupEpochs * dataset.BatchCount(this.BatchSize));
            // https://github.com/AlexeyAB/darknet/issues/1845
            var optimizer = new Adam(learning_rate: learningRateSchedule, epsilon: 0.000001);
            if (this.ModelSummary)
                model.summary();
            if (this.WeightsPath != null)
                model.load_weights(this.WeightsPath);

            var callbacks = new List {
                new LearningRateLogger(),
                new TensorBoard(log_dir: this.LogDir, batch_size: this.BatchSize, profile_batch: 4),
            };
            // only checkpoint real training runs
            if (!this.Benchmark && !this.TestRun)
                callbacks.Add(new ModelCheckpoint("yoloV4.weights.{epoch:02d}", save_weights_only: true));

            YOLO.TrainGenerator(model, optimizer, dataset, batchSize: this.BatchSize,
                firstStageEpochs: this.FirstStageEpochs,
                secondStageEpochs: this.SecondStageEpochs,
                callbacks: callbacks);

            if (!this.Benchmark && !this.TestRun)
                model.save_weights("yoloV4.weights-trained");

            // the following does not work due to the need to name layers properly
            // https://stackoverflow.com/questions/61402903/unable-to-create-group-name-already-exists
            // model.save("yoloV4-trained");
            return 0;
        }

        public TrainV4() {
            this.IsCommand("trainV4");
            this.HasRequiredOption("a|annotations=", "Path to MS COCO-compatible annotations file",
                filePath => this.Annotations = Tools.NonEmptyLines(filePath));
            this.HasRequiredOption("c|class-names=",
                "Path to MS COCO-compatible .names file listing all object classes",
                filePath => this.ClassNames = Tools.NonEmptyLines(filePath));
            this.HasOption("batch-size=", "Batch size during training",
                (int size) => this.BatchSize = size);
            this.HasOption("log-device-placement", "Enables TensorFlow device placement logging",
                (string onOff) => this.LogDevicePlacement = onOff == "on");
            this.HasOption("gpu-allow-growth", "Makes TensorFlow allocate GPU memory as needed (default: reserve all GPU memory)",
                (string onOff) => this.GpuAllowGrowth = onOff == "on");
            this.HasOption("model-summary", "Print model summary before training",
                (string onOff) => this.ModelSummary = onOff == "on");
            this.HasOption("log-dir=", "Write training logs to the specified directory",
                dir => {
                    dir = Path.GetFullPath(dir);
                    Directory.CreateDirectory(dir);
                    this.LogDir = dir;
                });
            this.HasOption("transfer-epochs=", "Number of epochs to run to adapt before fine-tuning",
                (int epochs) => this.FirstStageEpochs = epochs);
            this.HasOption("training-epochs=", "Number of epochs to run training/fine-tuning for",
                (int epochs) => this.SecondStageEpochs = epochs);
            this.HasOption("test-run", "Only does 1 batch per epoch instead of the entire dataset",
                (string onOff) => this.TestRun = onOff == "on");
            this.HasOption("weights=", "Path to pretrained model weights",
                (string path) => this.WeightsPath = path);
            this.HasOption("benchmark", "Run 1 epoch without training and output losses",
                (string onOff) => this.Benchmark = onOff == "on");
        }
    }
}
--------------------------------------------------------------------------------
/samples/TrainV4/TrainV4.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | netcoreapp3.1
6 | tensorflow
7 | enable
8 | Yolo.TrainV4
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/samples/TrainV4/TrainingLogger.cs:
--------------------------------------------------------------------------------
namespace tensorflow.keras.callbacks {
    using System;
    using System.Collections.Generic;
    using System.Linq;

    /// <summary>
    /// Keras callback that prints each epoch's metrics to the console
    /// when the epoch finishes.
    /// </summary>
    class TrainingLogger: Callback, ICallback {
        public override void on_epoch_end(int epoch, IDictionary logs) {
            // format as "metric1=value1; metric2=value2; ..."
            string metrics = string.Join("; ", logs.Select(entry =>
                $"{entry.Key}={entry.Value}"));
            Console.WriteLine($"epoch {epoch} @ {DateTime.Now.ToLongTimeString()}: {metrics}");
        }

        // explicit interface member: adapts the dynamically-typed call coming
        // from Keras into the strongly-typed overload above
        dynamic? ICallback.on_epoch_end(dynamic epoch, dynamic logs) {
            this.on_epoch_end((int)epoch, (IDictionary)logs);
            return null;
        }
    }
}
19 |
--------------------------------------------------------------------------------
/src/BufferedEnumerable.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System;
    using System.Collections;
    using System.Collections.Concurrent;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading.Tasks;

    using Python.Runtime;

    /// <summary>
    /// Enumerates a lazily-evaluated list while preloading upcoming elements
    /// on background tasks, keeping at most <c>bufferSize</c> items in flight.
    /// Used to overlap expensive item preparation with consumption.
    /// </summary>
    class BufferedEnumerable : IEnumerable, ICollection {
        readonly IReadOnlyList lazyList;
        readonly int bufferSize;

        public BufferedEnumerable(IReadOnlyList lazyList, int bufferSize) {
            this.lazyList = lazyList ?? throw new ArgumentNullException(nameof(lazyList));
            if (bufferSize < 1) throw new ArgumentOutOfRangeException(nameof(bufferSize));
            this.bufferSize = bufferSize;
        }

        public int Count => this.lazyList.Count;

        public IEnumerator GetEnumerator() {
            // "buffer" holds tasks handed to the consumer in order;
            // "readyToRun" holds started tasks awaiting transfer into "buffer".
            // Both are bounded so at most bufferSize items are materialized ahead.
            var buffer = new BlockingCollection>(boundedCapacity: this.bufferSize);
            var readyToRun = new BlockingCollection>(boundedCapacity: this.bufferSize);

            // moves tasks from readyToRun into buffer, preserving order
            void Load() {
                while (!readyToRun.IsCompleted) {
                    var task = readyToRun.Take();
                    buffer.Add(task);
                }
                buffer.CompleteAdding();
            }

            // starts one task per list element (bounded by readyToRun's capacity)
            void QueueLoading() {
                for(int i = 0; i < this.lazyList.Count; i++) {
                    int index = i;
                    var task = new Task(() => this.lazyList[index]);
                    readyToRun.Add(task);
                    task.Start();
                }
                readyToRun.CompleteAdding();
            }

            Task.Run(QueueLoading);
            Task.Run(Load);

            return buffer.GetConsumingEnumerable()
                .Select(t => {
                    // release the Python GIL while blocking on the task so the
                    // producer tasks (which may call into Python) can progress
                    IntPtr multithreadHandle = PythonEngine.BeginAllowThreads();
                    try {
                        return t.Result;
                    } finally {
                        PythonEngine.EndAllowThreads(multithreadHandle);
                    }
                })
                .GetEnumerator();
        }
        IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();

        bool ICollection.IsSynchronized => false;
        object ICollection.SyncRoot => this.lazyList;
        // copying is not needed by consumers of this type
        void ICollection.CopyTo(Array array, int index) => throw new NotImplementedException();
    }
}
--------------------------------------------------------------------------------
/src/InternalsVisibleTo.cs:
--------------------------------------------------------------------------------
1 | using System.Runtime.CompilerServices;
2 |
3 | [assembly: InternalsVisibleTo("YOLOv4.Tests")]
4 | [assembly: InternalsVisibleTo("Yolo.TrainV4")]
--------------------------------------------------------------------------------
/src/ListLinq.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;

    /// <summary>
    /// LINQ-style Select for IReadOnlyList that stays lazy per-element
    /// while preserving Count and random access by index.
    /// </summary>
    static class ListLinq {
        /// <summary>Projects each element through <paramref name="selector"/> lazily,
        /// returning a list view of the same length.</summary>
        public static IReadOnlyList Select(
            this IReadOnlyList source, Func selector)
            => new LazySelectList(source, selector);

        // view over "source": the selector runs on every indexer access
        // (results are NOT cached)
        class LazySelectList : IReadOnlyList {
            readonly IReadOnlyList source;
            readonly Func selector;
            public LazySelectList(IReadOnlyList source, Func selector) {
                this.source = source ?? throw new ArgumentNullException(nameof(source));
                this.selector = selector ?? throw new ArgumentNullException(nameof(selector));
            }
            public T this[int index] => this.selector(this.source[index]);
            public int Count => this.source.Count;
            public IEnumerator GetEnumerator()
                => Enumerable.Select(this.source, this.selector).GetEnumerator();

            IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
        }
    }
}
--------------------------------------------------------------------------------
/src/Tools.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System;
    using System.Collections.Generic;
    using System.Linq;

    using LostTech.Gradient.BuiltIns;

    /// <summary>Small general-purpose helpers shared across the project.</summary>
    static class Tools {
        /// <summary>Yields <paramref name="times"/> dummy values; useful for "repeat N times" loops.</summary>
        public static IEnumerable Repeat(int times) => Enumerable.Repeat(true, times);

        /// <summary>In-place Fisher-Yates shuffle using a non-seeded Random.</summary>
        [Obsolete("Use random.shuffle for reproducibility")]
        public static void Shuffle(IList list) {
            var random = new Random();
            for(int i = list.Count - 1; i > 0; i--) {
                int swapWith = random.Next(i+1);
                Swap(list, i, swapWith);
            }
        }

        /// <summary>Swaps the elements at the two given indices.</summary>
        public static void Swap(IList list, int index1, int index2) {
            T tmp = list[index1];
            list[index1] = list[index2];
            list[index2] = tmp;
        }

        /// <summary>Copies the sub-array described by <paramref name="range"/> into a new array.</summary>
        public static T[] Slice(this T[] array, Range range) {
            if (array is null) throw new ArgumentNullException(nameof(array));

            var (offset, len) = range.GetOffsetAndLength(array.Length);
            var result = new T[len];
            Array.Copy(array, offset, result, 0, len);
            return result;
        }

        /// <summary>Reads a file and returns its trimmed, non-empty lines.</summary>
        public static string[] NonEmptyLines(string filePath)
            => System.IO.File.ReadAllLines(filePath)
                .Select(l => l.Trim())
                .Where(l => !string.IsNullOrEmpty(l))
                .ToArray();

        /// <summary>Deconstructs an array of exactly 4 elements.</summary>
        public static void Deconstruct(this T[] array, out T i0, out T i1, out T i2, out T i3) {
            if (array is null) throw new ArgumentNullException(nameof(array));
            if (array.Length != 4) throw new ArgumentException();

            i0 = array[0];
            i1 = array[1];
            i2 = array[2];
            i3 = array[3];
        }
        /// <summary>Deconstructs an array of exactly 3 elements.</summary>
        public static void Deconstruct(this T[] array, out T i0, out T i1, out T i2) {
            if (array is null) throw new ArgumentNullException(nameof(array));
            if (array.Length != 3) throw new ArgumentException();

            i0 = array[0];
            i1 = array[1];
            i2 = array[2];
        }
        /// <summary>Deconstructs an array of exactly 2 elements.</summary>
        public static void Deconstruct(this T[] array, out T i0, out T i1) {
            if (array is null) throw new ArgumentNullException(nameof(array));
            if (array.Length != 2) throw new ArgumentException();

            i0 = array[0];
            i1 = array[1];
        }

        /// <summary>Pairs every item with its index, yielding (index, item) tuples.</summary>
        // NOTE(review): the lambda parameter names are swapped relative to LINQ's
        // (element, index) convention — "index" actually binds the element and
        // "item" binds the index. The returned tuples are (index, element) as the
        // return type declares, so behavior is correct; only the names mislead.
        public static IEnumerable<(int, T)> Enumerate(params T[] items)
            => items.Select((index, item) => (item, index));

        /// <summary>Enumerates <paramref name="list"/> with background preloading; see BufferedEnumerable.</summary>
        public static IEnumerable BufferedEnumerate(this IReadOnlyList list, int bufferSize)
            => new BufferedEnumerable(list, bufferSize);

        /// <summary>Materializes a sequence of key/value pairs into a PythonDict.</summary>
        internal static PythonDict ToDictionary(this IEnumerable> seq) {
            var result = new PythonDict();
            foreach (var (key, val) in seq)
                result.Add(key, val);
            return result;
        }
    }
}
--------------------------------------------------------------------------------
/src/Utils.cs:
--------------------------------------------------------------------------------
namespace tensorflow {
    using System.Collections.Generic;
    using System.Globalization;
    using System.Linq;

    using tensorflow.keras;
    using tensorflow.keras.layers;

    /// <summary>Helpers for recursively freezing and unfreezing Keras layers.</summary>
    static class Utils {
        /// <summary>Sets <c>trainable</c> on <paramref name="parent"/> and,
        /// when it is a model, on every layer nested inside it.</summary>
        public static void SetTrainableRecursive(ILayer parent, bool trainable) {
            parent.trainable = trainable;
            if (!(parent is IModel model))
                return;
            foreach (ILayer child in model.layers)
                SetTrainableRecursive(child, trainable);
        }

        /// <summary>Marks the entire model (all nested layers) as non-trainable.</summary>
        public static void FreezeAll(IModel model) => SetTrainableRecursive(model, trainable: false);
        /// <summary>Marks the entire model (all nested layers) as trainable.</summary>
        public static void UnfreezeAll(IModel model) => SetTrainableRecursive(model, trainable: true);
    }
}
--------------------------------------------------------------------------------
/src/YOLOv4.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | LostTech.TensorFlow.YOLOv4
5 | 0.0.1
6 | tensorflow
7 | netstandard2.0
8 | 8.0
9 | enable
10 | true
11 |
12 |
13 | LICENSE
14 | Lost Tech LLC
15 | Lost Tech LLC
16 | Real-Time Object Detection network. TensorFlow-based implementation with support for fine-tuning and training from scratch.
17 | https://github.com/losttech/YOLOv4
18 | yolo;yolov4;object-detection;neural-network;darknet;tensorflow;machine-learning;ML;CNN;deep-learning
19 | See project site for samples
20 |
21 |
22 | true
23 | true
24 | true
25 | snupkg
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/src/data/ObjectDetectionDataset.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.data {
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Globalization;
5 | using System.IO;
6 | using System.Linq;
7 | using System.Runtime.InteropServices;
8 | using System.Threading;
9 |
10 | using numpy;
11 |
12 | using SixLabors.ImageSharp;
13 | using SixLabors.ImageSharp.Advanced;
14 | using SixLabors.ImageSharp.PixelFormats;
15 | using SixLabors.ImageSharp.Processing;
16 |
17 | using tensorflow.image;
18 |
19 | using Image = SixLabors.ImageSharp.Image;
20 | using Rectangle = SixLabors.ImageSharp.Rectangle;
21 | using Size = SixLabors.ImageSharp.Size;
22 |
23 | public class ObjectDetectionDataset {
        // one annotation line per image: image path followed by bounding boxes
        readonly string[] annotations;
        readonly string[] classNames;
        // downsampling factor of each detection scale
        readonly int[] strides;
        readonly ndarray anchors;
        readonly int anchorsPerScale;
        readonly int inputSize;
        // cap on the number of ground-truth boxes stored per scale
        readonly int maxBBoxPerScale;

        public int InputSize => (int)this.inputSize;
        public ReadOnlySpan ClassNames => this.classNames;
        public ReadOnlySpan Strides => this.strides;
        int ClassCount => this.classNames.Length;
        // number of annotated images in the dataset
        public int Count => this.annotations.Length;
37 |
        /// <summary>
        /// Creates a dataset over the given annotation lines.
        /// </summary>
        /// <param name="annotations">Non-empty list of annotation lines (image path + boxes).</param>
        /// <param name="classNames">Non-empty list of object class names.</param>
        /// <param name="strides">Positive downsampling factors, one per detection scale.</param>
        /// <param name="inputSize">Side length the square network input, positive.</param>
        /// <param name="anchors">Anchor boxes; must be a rank-3 array.</param>
        /// <param name="anchorsPerScale">Positive number of anchors per scale.</param>
        /// <param name="maxBBoxPerScale">Positive cap on ground-truth boxes per scale.</param>
        public ObjectDetectionDataset(string[] annotations, string[] classNames,
            int[] strides, int inputSize,
            ndarray anchors,
            int anchorsPerScale,
            int maxBBoxPerScale) {
            this.classNames = classNames ?? throw new ArgumentNullException(nameof(classNames));
            if (classNames.Length == 0)
                throw new ArgumentException(message: "List of class names must not be empty");

            this.annotations = annotations ?? throw new ArgumentNullException(nameof(annotations));
            if (annotations.Length == 0)
                throw new ArgumentException(message: "List of annotations must not be empty");

            if (strides is null || strides.Length == 0)
                throw new ArgumentNullException(nameof(strides));
            if (strides.Any(NotPositive)) throw new ArgumentOutOfRangeException(nameof(strides));
            // defensive copy: callers keep ownership of their array
            this.strides = strides.ToArray();

            if (anchors is null) throw new ArgumentNullException(nameof(anchors));
            if (anchors.ndim != 3) throw new ArgumentException("Bad shape", paramName: nameof(anchors));
            this.anchors = anchors;

            if (anchorsPerScale <= 0)
                throw new ArgumentOutOfRangeException(nameof(anchorsPerScale));
            this.anchorsPerScale = anchorsPerScale;

            if (inputSize <= 0)
                throw new ArgumentOutOfRangeException(nameof(inputSize));
            this.inputSize = inputSize;

            if (maxBBoxPerScale <= 0)
                throw new ArgumentOutOfRangeException(nameof(maxBBoxPerScale));
            this.maxBBoxPerScale = maxBBoxPerScale;
        }
72 |
        /// <summary>Shuffles the annotation order in place; affects subsequently produced batches.</summary>
        public void Shuffle() => Tools.Shuffle(this.annotations);

        /// <summary>
        /// Returns a lazy list of batches; each batch is loaded and preprocessed
        /// only when its element is accessed (see <see cref="GetBatch"/>).
        /// </summary>
        public IReadOnlyList Batch(int batchSize,
            Func? onloadAugmentation) {
            if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

            return new BatchList(this, batchSize: batchSize, onloadAugmentation);
        }
81 |
        // Lazy list view over the dataset's batches: indexing triggers batch
        // construction on the fly, nothing is cached.
        class BatchList: IReadOnlyList {
            readonly ObjectDetectionDataset dataset;
            // optional per-entry augmentation applied while loading
            public Func? OnloadAugmentation { get; }
            public int BatchSize { get; }
            public int Count { get; }

            public EntryBatch this[int index] => this.dataset.GetBatch(this.BatchSize, index, this.OnloadAugmentation);

            public BatchList(ObjectDetectionDataset dataset, int batchSize, Func? onloadAugmentation) {
                if (dataset is null) throw new ArgumentNullException(nameof(dataset));
                if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
                this.dataset = dataset;
                this.BatchSize = batchSize;
                this.Count = this.dataset.BatchCount(this.BatchSize);
                this.OnloadAugmentation = onloadAugmentation;
            }

            public IEnumerator GetEnumerator() {
                for(int batch = 0; batch < this.Count; batch++)
                    yield return this[batch];
            }
            System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => this.GetEnumerator();
        }
105 |
106 | public int BatchCount(int batchSize) {
107 | if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
108 | return (int)Math.Ceiling(this.Count * 1F / batchSize);
109 | }
110 |
        /// <summary>
        /// Loads and preprocesses the <paramref name="batchIndex"/>-th batch:
        /// images plus per-scale label tensors and ground-truth boxes.
        /// If the last batch is short, it wraps around to the start of the dataset.
        /// </summary>
        /// <exception cref="IndexOutOfRangeException">batchIndex is outside [0, BatchCount).</exception>
        public EntryBatch GetBatch(int batchSize, int batchIndex,
            Func? onloadAugmentation) {
            if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
            int totalBatches = this.BatchCount(batchSize);
            if (batchIndex < 0 || batchIndex >= totalBatches)
                throw new IndexOutOfRangeException();

            // feature-map side length for every detection scale
            int[] outputSizes = this.strides.Select(stride => this.inputSize / stride).ToArray();

            // batch of HWC images
            var batchImages = np.zeros(batchSize, this.inputSize, this.inputSize, 3);
            // per scale: [batch, outputSize, outputSize, anchorsPerScale, 4 box coords + 1 confidence + classes]
            var batchBBoxLabels = outputSizes.Select(outputSize
                => np.zeros(
                    batchSize, outputSize, outputSize,
                    this.anchorsPerScale, 5 + this.ClassCount)
                ).ToArray();

            // per scale: raw ground-truth boxes, capped at maxBBoxPerScale
            var batchBBoxes = outputSizes.Select(
                _ => np.zeros(batchSize, this.maxBBoxPerScale, 4))
                .ToArray();

            for (int itemNo = 0; itemNo < batchSize; itemNo++) {
                int index = batchIndex * batchSize + itemNo;
                // loop the last few items for the last batch if necessary
                if (index >= this.Count) index -= this.Count;

                string annotation = this.annotations[index];
                var rawEntry = LoadAnnotationClr(annotation);
                if (onloadAugmentation != null)
                    rawEntry = onloadAugmentation(rawEntry);

                // resize/letterbox to the network input size
                var entry = Preprocess(rawEntry, new Size(this.inputSize, this.inputSize));

                var (labels, boxes) = this.PreprocessTrueBoxes(entry.BoundingBoxes, outputSizes);

                batchImages[itemNo, .., .., ..] = entry.Image;
                for (int i = 0; i < outputSizes.Length; i++) {
                    batchBBoxLabels[i][itemNo, .., .., .., ..] = labels[i];
                    batchBBoxes[i][itemNo, .., ..] = boxes[i];
                }
            }

            return new EntryBatch {
                Images = batchImages,
                BBoxLabels = batchBBoxLabels,
                BBoxes = batchBBoxes
            };
        }
158 |
159 | public static string[] LoadAnnotations(TextReader reader) {
160 | var result = new List();
161 | for (string line = reader.ReadLine(); line != null; line = reader.ReadLine()) {
162 | string trimmed = line.Trim();
163 | if (trimmed.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries).Length > 1)
164 | result.Add(trimmed);
165 | }
166 |
167 | return result.ToArray();
168 | }
169 |
        // TODO: for reproducibility, use numpy.random
        // per-thread RNG: System.Random is not thread-safe, augmentations may run in parallel
        static readonly ThreadLocal random = new ThreadLocal(() => new Random());

        /// <summary>With 50% probability flips the image (ndarray form) and its
        /// bounding boxes horizontally; otherwise returns the entry unchanged.</summary>
        public static Entry RandomHorizontalFlip(Entry entry) {
            if (random.Value.Next(2) == 0)
                return entry;

            int width = entry.Image.shape.Item2;
            // reverse the second (X) axis of the image
            int[] reversedXs = Enumerable.Range(0, width).Reverse().ToArray();
            entry.Image = (ndarray)entry.Image[.., reversedXs, ..];
            // mirror x-coordinates: new xmin = width - old xmax, new xmax = width - old xmin
            // (columns 0 and 2 of the box array hold the x-coordinates)
            entry.BoundingBoxes[.., new[] { 0, 2 }] =
                width - entry.BoundingBoxes[.., new[] { 2, 0 }];

            return entry;
        }
        /// <summary>With 50% probability flips the image (ImageSharp form) and its
        /// bounding boxes horizontally; otherwise returns the entry unchanged.</summary>
        public static ClrEntry RandomHorizontalFlip(ClrEntry entry) {
            if (random.Value.Next(2) == 0)
                return entry;

            entry.Image.Mutate(x => x.Flip(FlipMode.Horizontal));
            // mirror x-coordinates of boxes (columns 0/2): new xmin = W - old xmax, etc.
            entry.BoundingBoxes[.., new[] { 0, 2 }] = entry.Image.Width - entry.BoundingBoxes[.., new[] { 2, 0 }];

            return entry;
        }
193 |
        /// <summary>With 50% probability crops the image (ndarray form) to a random
        /// window that still contains every bounding box, shifting box coordinates
        /// into the cropped frame.</summary>
        public static Entry RandomCrop(Entry entry) {
            if (random.Value.Next(2) == 0)
                return entry;

            int h = entry.Image.shape.Item1, w = entry.Image.shape.Item2;
            GetRandomCrop(entry.BoundingBoxes, h, w,
                          out int cropXMin, out int cropYMin, out int cropXMax, out int cropYMax);

            entry.Image = entry.Image[cropYMin..cropYMax, cropXMin..cropXMax];
            // translate boxes into the cropped coordinate frame
            entry.BoundingBoxes[.., new[] { 0, 2 }] -= cropXMin;
            entry.BoundingBoxes[.., new[] { 1, 3 }] -= cropYMin;

            return entry;
        }
208 |
        /// <summary>With 50% probability crops the image (ImageSharp form) to a random
        /// window that still contains every bounding box, shifting box coordinates
        /// into the cropped frame.</summary>
        public static ClrEntry RandomCrop(ClrEntry entry) {
            if (random.Value.Next(2) == 0)
                return entry;

            GetRandomCrop(entry.BoundingBoxes, height: entry.Image.Height, width: entry.Image.Width,
                          out int cropXMin, out int cropYMin, out int cropXMax, out int cropYMax);

            var rect = new Rectangle(cropXMin, cropYMin, cropXMax - cropXMin, cropYMax - cropYMin);

            entry.Image.Mutate(x => x.Crop(rect));
            // translate boxes into the cropped coordinate frame
            entry.BoundingBoxes[.., new[] { 0, 2 }] -= cropXMin;
            entry.BoundingBoxes[.., new[] { 1, 3 }] -= cropYMin;

            return entry;
        }
224 |
        /// <summary>
        /// Picks a random crop window within [0,width)x[0,height) that always
        /// contains the union of all bounding boxes; box columns are assumed to be
        /// [xmin, ymin, xmax, ymax].
        /// </summary>
        static void GetRandomCrop(ndarray boundingBoxes, int height, int width,
                                  out int cropXMin, out int cropYMin, out int cropXMax, out int cropYMax) {
            // bounding rectangle of ALL boxes: [min xmin, min ymin, max xmax, max ymax]
            ndarray maxBBox = np.concatenate(new[] {
                (ndarray)boundingBoxes[.., 0..2].min(axis: 0),
                (ndarray)boundingBoxes[.., 2..4].max(axis: 0),
            }, axis: -1);

            // slack available on each side (left/up/right/down) before a box is cut
            int maxLtrans = maxBBox[0].AsScalar();
            int maxUtrans = maxBBox[1].AsScalar();
            int maxRtrans = width - maxBBox[2].AsScalar();
            int maxDtrans = height - maxBBox[3].AsScalar();

            cropXMin = Math.Max(0, maxBBox[0].AsScalar() - random.Value.Next(maxLtrans));
            cropYMin = Math.Max(0, maxBBox[1].AsScalar() - random.Value.Next(maxUtrans));
            cropXMax = Math.Min(width, maxBBox[2].AsScalar() + random.Value.Next(maxRtrans));
            cropYMax = Math.Min(height, maxBBox[3].AsScalar() + random.Value.Next(maxDtrans));
        }
242 |
        /// <summary>With 50% probability translates the image (ndarray form) by a
        /// random offset that keeps all bounding boxes inside, padding with zeros,
        /// and shifts box coordinates accordingly.</summary>
        public static Entry RandomTranslate(Entry entry) where T : unmanaged {
            if (random.Value.Next(2) == 0)
                return entry;

            int h = entry.Image.shape.Item1, w = entry.Image.shape.Item2;
            GetRandomTranslation(entry.BoundingBoxes, h, w, out int tx, out int ty);

            entry.Image = TranslateImage(entry.Image, tx: tx, ty: ty);
            // shift box x-coordinates (cols 0,2) by tx and y-coordinates (cols 1,3) by ty
            entry.BoundingBoxes[.., new[] { 0, 2 }] += tx;
            entry.BoundingBoxes[.., new[] { 1, 3 }] += ty;

            return entry;
        }
256 |
/// <summary>
/// With probability 1/2, shifts the image and its bounding boxes by a random
/// translation that keeps every box inside the frame (CLR image variant).
/// </summary>
public static ClrEntry RandomTranslate(ClrEntry entry) {
// Coin flip: half of the calls leave the entry untouched.
if (random.Value.Next(2) == 0)
return entry;

GetRandomTranslation(entry.BoundingBoxes,
height: entry.Image.Height, width: entry.Image.Width,
out int tx, out int ty);

if (tx == 0 && ty == 0) return entry;

// Draw the source image at the (tx, ty) offset onto a black canvas of the
// same size; pixels shifted out of frame are discarded, uncovered area stays black.
// (Removed an unused `rect` local that was computed here but never referenced.)
var translated = new Image(entry.Image.Width, entry.Image.Height, Color.Black);
translated.Mutate(x => x.DrawImage(entry.Image, new Point(tx, ty), opacity: 1));
entry.Image = translated;
// Shift x coordinates (columns 0, 2) and y coordinates (columns 1, 3) in lockstep.
entry.BoundingBoxes[.., new[] { 0, 2 }] += tx;
entry.BoundingBoxes[.., new[] { 1, 3 }] += ty;

return entry;
}
276 |
/// <summary>
/// Computes a random (tx, ty) pixel translation such that no bounding box is
/// pushed outside the image (a 1px margin is kept on each side).
/// </summary>
static void GetRandomTranslation(ndarray boundingBoxes, int height, int width,
out int tx, out int ty) {
// Envelope of all boxes: [min x1, min y1, max x2, max y2] across rows.
ndarray maxBBox = np.concatenate(new[] {
(ndarray)boundingBoxes[.., 0..2].min(axis: 0),
(ndarray)boundingBoxes[.., 2..4].max(axis: 0),
}, axis: -1);

// Slack between the envelope and each image border.
int maxLtrans = maxBBox[0].AsScalar();
int maxUtrans = maxBBox[1].AsScalar();
int maxRtrans = width - maxBBox[2].AsScalar();
int maxDtrans = height - maxBBox[3].AsScalar();

// TODO: use numpy.random.uniform for reproducibility?
// Sort guards against the degenerate case where the negative bound exceeds
// the positive one (Random.Next requires min <= max).
var (min, max) = Sort(-(maxLtrans - 1), maxRtrans - 1);
tx = random.Value.Next(minValue: min, maxValue: max);
(min, max) = Sort(-(maxUtrans - 1), maxDtrans - 1);
ty = random.Value.Next(minValue: min, maxValue: max);
}
295 |
296 | static (int, int) Sort(int a, int b) => (Math.Min(a, b), Math.Max(a, b));
297 |
/// <summary>
/// Shifts an HWC image by (tx, ty) pixels, padding the uncovered region with zeros
/// while preserving the original height/width.
/// </summary>
// Implementation: paste the image into a larger zero canvas at the positive offset,
// then slice the original-sized window back out at the opposite offset.
// NOTE(review): generic arguments appear stripped in this dump (`where T : unmanaged`
// with no visible <T>); confirm against the original file.
static ndarray TranslateImage(ndarray image, int tx, int ty) where T : unmanaged {
if (tx == 0 && ty == 0) return image;

int h = image.shape.Item1, w = image.shape.Item2, c = image.shape.Item3;
// Where to paste the source inside the enlarged canvas.
int toX = tx < 0 ? 0 : tx;
int toY = ty < 0 ? 0 : ty;

var temp = np.zeros(h + Math.Abs(ty), w + Math.Abs(tx), c);
temp[toY..(toY + h), toX..(toX + w), ..] = image;

// Where to cut the result window back out.
int fromX = tx < 0 ? -tx : 0;
int fromY = ty < 0 ? -ty : 0;
return temp[fromY..(fromY + h), fromX..(fromX + w), ..];
}
312 |
/// <summary>
/// Parses one whitespace-separated annotation line: an image path followed by
/// comma-separated boxes; loads the image eagerly as an ndarray.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="annotation"/> is null.</exception>
public static Entry LoadAnnotation(string annotation) {
if (annotation is null) throw new ArgumentNullException(nameof(annotation));

string[] line = annotation.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
string imagePath = line[0];
ndarray image = ImageTools.LoadRGB8(imagePath).ToNumPyArray();
// Everything after the path is box data.
ndarray bboxes = LoadBBoxes(line.Slice(1..));

return new Entry {
Image = image,
BoundingBoxes = bboxes,
};
}
326 |
/// <summary>
/// Parses one whitespace-separated annotation line like <see cref="LoadAnnotation"/>,
/// but loads the image as a CLR <see cref="Image"/> instead of an ndarray.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="annotation"/> is null.</exception>
public static ClrEntry LoadAnnotationClr(string annotation) {
if (annotation is null) throw new ArgumentNullException(nameof(annotation));

string[] line = annotation.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
string imagePath = line[0];
var image = Image.Load(imagePath);
// Everything after the path is box data.
ndarray bboxes = LoadBBoxes(line.Slice(1..));

return new ClrEntry {
Image = image,
BoundingBoxes = bboxes,
};
}
340 |
/// <summary>
/// Converts box strings of the form "x1,y1,x2,y2,class" into an ndarray,
/// one row per box. Parsing is culture-invariant.
/// </summary>
static ndarray LoadBBoxes(string[] bboxTexts)
=> (ndarray)np.array(bboxTexts
.Select(box =>box
.Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries)
.Select(s => int.Parse(s, CultureInfo.InvariantCulture))
.ToArray()));
347 |
/// <summary>
/// Applies the flip, crop and translate augmentations in sequence; each one
/// internally decides at random whether to take effect (ndarray variant).
/// </summary>
public static Entry RandomlyApplyAugmentations(Entry entry) where T : unmanaged {
entry = RandomHorizontalFlip(entry);
entry = RandomCrop(entry);
entry = RandomTranslate(entry);
return entry;
}
354 |
/// <summary>
/// Applies the flip, crop and translate augmentations in sequence; each one
/// internally decides at random whether to take effect (CLR image variant).
/// </summary>
public static ClrEntry RandomlyApplyAugmentations(ClrEntry entry)
=> RandomTranslate(RandomCrop(RandomHorizontalFlip(entry)));
361 |
/// <summary>
/// Letterboxes <paramref name="entry"/> to <paramref name="targetSize"/> via
/// <see cref="ImageTools.YoloPreprocess"/> and returns the numeric entry.
/// </summary>
public static Entry Preprocess(ClrEntry entry, Size targetSize)
=> ImageTools.YoloPreprocess(entry, targetSize);
364 |
/// <summary>
/// Element-wise intersection-over-union between two box arrays given in
/// (cx, cy, w, h) format; leading axes broadcast per NumPy rules.
/// Result is clamped below by keras epsilon to stay strictly positive.
/// </summary>
internal static ndarray BBoxIOU(ndarray boxes1, ndarray boxes2) {
var area1 = boxes1[np.rest_of_the_axes, 2] * boxes1[np.rest_of_the_axes, 3];
// BUG FIX: area2 previously read width/height from boxes1, so the union term
// was wrong whenever the two box sets differed.
var area2 = boxes2[np.rest_of_the_axes, 2] * boxes2[np.rest_of_the_axes, 3];

// Convert (cx, cy, w, h) -> (x1, y1, x2, y2) corner form.
boxes1 = np.concatenate(new[] {
boxes1[np.rest_of_the_axes, ..2] - boxes1[np.rest_of_the_axes, 2..] * 0.5f,
boxes1[np.rest_of_the_axes, ..2] + boxes1[np.rest_of_the_axes, 2..] * 0.5f,
}, axis: -1);
boxes2 = np.concatenate(new[] {
boxes2[np.rest_of_the_axes, ..2] - boxes2[np.rest_of_the_axes, 2..]*0.5f,
boxes2[np.rest_of_the_axes, ..2] + boxes2[np.rest_of_the_axes, 2..]*0.5f,
}, axis: -1);

// Intersection rectangle corners; negative extents mean no overlap.
var leftUp = np.maximum(boxes1[np.rest_of_the_axes, ..2], boxes2[np.rest_of_the_axes, ..2]);
var rightDown = np.minimum(boxes1[np.rest_of_the_axes, 2..], boxes2[np.rest_of_the_axes, 2..]);

var intersection = np.maximum(rightDown - leftUp, 0.0f);
var intersectionArea = intersection[np.rest_of_the_axes, 0] * intersection[np.rest_of_the_axes, 1];
var epsilon = new float32(tf.keras.backend.epsilon());
// Epsilon floor prevents division by zero for degenerate boxes.
var unionArea = np.maximum(area1 + area2 - intersectionArea, epsilon);

return np.maximum(epsilon, intersectionArea / unionArea);
}
/// <summary>
/// Converts ground-truth boxes into per-scale YOLO training targets:
/// a dense label tensor (grid y, grid x, anchor, [xywh, objectness, class probs])
/// per output scale plus a flat per-scale list of box xywh values.
/// </summary>
/// <param name="bboxes">Rows of (x1, y1, x2, y2, classIndex).</param>
/// <param name="outputSizes">Grid side length for each detection scale.</param>
(ndarray[], ndarray[]) PreprocessTrueBoxes(ndarray bboxes, int[] outputSizes) {
var label = outputSizes
.Select(size => np.zeros(
size, size, this.anchorsPerScale, 5 + this.ClassCount))
.ToArray();
var bboxesXYWH = outputSizes
.Select(_ => np.zeros(this.maxBBoxPerScale, 4))
.ToArray();
// Running count of boxes assigned per scale (used as ring-buffer index below).
var bboxCount = np.zeros(outputSizes.Length);
var stridesPlus = np.array(this.strides)[(.., np.newaxis)].AsArray().AsType();

foreach (ndarray bbox in bboxes) {
var coords = bbox[..4];
var classIndex = bbox[4];

var oneHot = np.zeros(this.ClassCount);
oneHot[classIndex] = 1;

// Label smoothing: mix the one-hot target with a uniform distribution.
var uniform = np.full((int)this.ClassCount,
fill_value: 1.0f / this.ClassCount,
dtype: dtype.GetClass())
.AsArray();
const float deta = 0.01f;
var smoothOneHot = oneHot * (1 - deta) + deta * uniform;

// Corner form (x1,y1,x2,y2) -> center form (cx,cy,w,h).
var bboxXYWH = np.concatenate(new[] {
(coords[2..] + coords[..2]).AsType() * 0.5f,
(coords[2..] - coords[..2]).AsType()
}, axis: -1);
// Same box expressed in grid units for every scale at once.
var bboxXYWHScaled = (1.0f * bboxXYWH[np.newaxis, ..] / stridesPlus).AsArray();

var iou = new List>();

// Writes this box into the target tensors at the given scale, for either an
// anchor mask (multiple anchors) or a single best-anchor index.
void UpdateBoxesAtScale(int scale, object iouMaskOrIndex) {
var indices = bboxXYWHScaled[scale, 0..2].AsType();
ArrayOrElement xind = indices[0], yind = indices[1];

// Reset the cell, then write xywh, objectness = 1, and smoothed class probs.
label[scale][yind, xind, iouMaskOrIndex, ..] = 0;
label[scale][yind, xind, iouMaskOrIndex, 0..4] = bboxXYWH;
label[scale][yind, xind, iouMaskOrIndex, 4..5] = 1.0f;
label[scale][yind, xind, iouMaskOrIndex, 5..] = smoothOneHot;

// Ring buffer: wraps around once maxBBoxPerScale boxes have been stored.
int bboxIndex = bboxCount[scale].AsScalar() % (int)this.maxBBoxPerScale;
bboxesXYWH[scale][bboxIndex, ..4] = bboxXYWH;
bboxCount[scale] += 1;
}

bool positiveExists = false;
for (int scaleIndex = 0; scaleIndex < outputSizes.Length; scaleIndex++) {
int outputSize = outputSizes[scaleIndex];
// Anchor boxes centered on the box's grid cell (offset +0.5 to cell center).
var anchorsXYWH = np.zeros(this.anchorsPerScale, 4);
anchorsXYWH[.., 0..2] = anchorsXYWH[.., 0..2].AsType().AsType() + 0.5f;
anchorsXYWH[.., 2..4] = this.anchors[scaleIndex].AsArray();

var iouScale = BBoxIOU(bboxXYWHScaled[scaleIndex][(np.newaxis, ..)].AsArray(),
anchorsXYWH);
iou.Add(iouScale);
// Anchors with IOU > 0.3 are treated as positives for this box.
var iouMask = iouScale > 0.3f;

if (iouMask.any()) {
UpdateBoxesAtScale(scaleIndex, iouMask);

positiveExists = true;
}
}

// Fallback: if no anchor cleared the threshold at any scale, assign the box
// to the single best-matching anchor overall.
if (!positiveExists) {
int bestAnchorIndex = (int)np.array(iou).reshape(-1).argmax(axis: -1).AsScalar();
int bestDetection = bestAnchorIndex / (int)this.anchorsPerScale;
int bestAnchor = bestAnchorIndex % (int)this.anchorsPerScale;

UpdateBoxesAtScale(bestDetection, bestAnchor);
}
}

return (label, bboxesXYWH);
}
465 |
/// <summary>A training minibatch: stacked images plus per-scale label tensors and boxes.</summary>
public struct EntryBatch {
// Batch of preprocessed images (leading axis = batch).
public ndarray Images { get; set; }
// One dense label tensor per detection scale.
public ndarray[] BBoxLabels { get; set; }
// One flat xywh box list per detection scale.
public ndarray[] BBoxes { get; set; }

/// <summary>
/// Packs the batch into the (inputs dictionary, dummy targets) pair expected by
/// a Keras generator: keys "label{i}", "box{i}" and "image"; the target array
/// is all zeros because the real loss is computed inside the model.
/// </summary>
public (IDictionary>, ndarray) ToGeneratorOutput()
=> (this.BBoxLabels.Select((l, i) => ($"label{i}", l))
.Concat(this.BBoxes.Select((b, i) => ($"box{i}", b)))
.Append(("image", this.Images))
.ToDictionary(),

np.zeros(this.Images.shape.Item1));
}
479 |
/// <summary>A dataset sample in numeric form: an image plus its bounding boxes.</summary>
public struct Entry {
/// HWC image
public ndarray Image { get; set; }
// Rows of (x1, y1, x2, y2, classIndex) — TODO confirm column layout against LoadBBoxes callers.
public ndarray BoundingBoxes { get; set; }
}
485 |
/// <summary>A dataset sample holding a CLR (ImageSharp) image instead of an ndarray.</summary>
public struct ClrEntry {
public Image Image { get; set; }
public ndarray BoundingBoxes { get; set; }

/// <summary>
/// Converts the CLR image to a float HWC ndarray, copying row by row.
/// NOTE(review): no 1/255 scaling happens here — values keep the byte range; confirm callers expect that.
/// </summary>
public Entry ToNumPyEntry() {
// Rows are flattened to width*3 bytes first, then reshaped to H x W x 3.
var numpyImage = np.zeros(this.Image.Height, this.Image.Width * 3);
for(int y = 0; y < this.Image.Height; y++) {
var row = this.Image.GetPixelRowMemory(y);
numpyImage[y] = MarshalingExtensions.ToNumPyArray(
MemoryMarshal.Cast(row.Span));
}
return new Entry {
Image = ((ndarray)numpyImage.reshape(new[] { this.Image.Height, this.Image.Width, 3 }))
.AsType(),
BoundingBoxes = this.BoundingBoxes,
};
}

/// <summary>
/// Converts a numeric entry back into a CLR image; pixel values are assumed to be
/// normalized to [0, 1] and are scaled by 255 — TODO confirm against Preprocess output.
/// Bounding boxes are carried over unchanged.
/// </summary>
public static ClrEntry FromNumPyEntry(Entry entry) {
int height = entry.Image.shape.Item1;
int width = entry.Image.shape.Item2;
var image = new Image(width: width, height: height);
var bytes = entry.Image.reshape(new[] { height, width * 3 });
for (int y = 0; y < height; y++) {
var row = MemoryMarshal.Cast(image.GetPixelRowMemory(y).Span);
var byteRow = bytes[y];
for (int byteOffset = 0; byteOffset < width * 3; byteOffset++)
row[byteOffset] = (byte)(byteRow[byteOffset].AsScalar() * 255);
}
return new ClrEntry { Image = image, BoundingBoxes = entry.BoundingBoxes };
}
}
518 |
/// <summary>
/// Index helpers for box arrays: every row's x columns (0, 2) or y columns (1, 3).
/// (Coexists with the Entry struct — the struct is presumably generic in the original source.)
/// </summary>
public static class Entry {
public static (Range, int[]) AllHorizontal { get; } = (.., new[] { 0, 2 });
public static (Range, int[]) AllVertical { get; } = (.., new[] { 1, 3 });
}
523 |
524 | static bool NotPositive(int value) => value <= 0;
525 |
/// <summary>
/// Parses a comma-separated list of 18 anchor coordinates into a
/// (3 scales, 3 anchors, 2 coords) ndarray. Culture-invariant parsing.
/// </summary>
public static ndarray ParseAnchors(string anchors)
=> anchors.Split(',')
.Select(coord => float.Parse(coord.Trim(), CultureInfo.InvariantCulture))
.ToNumPyArray()
.reshape(new[] { 3, 3, 2 })
.AsArray();
532 |
/// <summary>
/// Reshapes a flat sequence of 18 anchor values into a
/// (3 scales, 3 anchors, 2 coords) ndarray, converting element type.
/// </summary>
public static ndarray ParseAnchors(IEnumerable anchors)
=> anchors.ToNumPyArray()
.reshape(new[] { 3, 3, 2 })
.AsArray()
.AsType();
538 | }
539 | }
540 |
--------------------------------------------------------------------------------
/src/datasets/ObjectDetection/MS_COCO.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.datasets.ObjectDetection {
2 | using System;
3 | using System.Linq;
4 |
/// <summary>
/// Static metadata for the MS COCO object-detection dataset: canonical input size,
/// class count, and the 80 class names in model output order.
/// </summary>
public static class MS_COCO {
// Network input resolution the pretrained weights expect.
public static int InputSize => 416;
public static int ClassCount => 80;
public static ReadOnlySpan ClassNames => classNames;

// Parsed once from the embedded list below; order defines the class indices.
static readonly string[] classNames = names.Split('\n', '\r')
.Select(l => l.Trim())
.Where(l => !string.IsNullOrEmpty(l))
.ToArray();

const string names = @"
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed
dining table
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
";
}
98 | }
99 |
--------------------------------------------------------------------------------
/src/image/ImageTools.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.image {
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Drawing;
5 | using System.Drawing.Imaging;
6 |
7 | using LostTech.Gradient;
8 |
9 | using numpy;
10 |
11 | using SixLabors.ImageSharp;
12 | using SixLabors.ImageSharp.Processing;
13 |
14 | using tensorflow.data;
15 |
16 | using Size = SixLabors.ImageSharp.Size;
17 |
18 | static class ImageTools {
19 | ///
20 | /// Returns bytes of the image in HWC order
21 | ///
/// <summary>
/// Returns bytes of the image in HWC order
/// </summary>
/// <remarks>
/// Uses GDI+ LockBits for a fast raw copy: 3 channels for opaque formats,
/// 4 when the source has alpha. NOTE(review): GDI+ 24/32bpp rows are BGR(A)
/// ordered in memory, so despite the name the channel order may be BGR — confirm
/// against consumers before relying on "RGB".
/// </remarks>
public unsafe static byte[,,] LoadRGB8(string filePath) {
using var bitmap = new Bitmap(filePath);
int channels = System.Drawing.Image.IsAlphaPixelFormat(bitmap.PixelFormat) ? 4 : 3;
byte[,,] result = new byte[bitmap.Height, bitmap.Width, channels];
var lockFormat = channels == 3 ? PixelFormat.Format24bppRgb : PixelFormat.Format32bppArgb;

var data = bitmap.LockBits(new System.Drawing.Rectangle { Width = bitmap.Width, Height = bitmap.Height },
ImageLockMode.ReadOnly, lockFormat);
try {
// Copy row by row: Stride may include padding, so each row is copied
// with the exact pixel width rather than the stride length.
for (int y = 0; y < bitmap.Height; y++) {
fixed (byte* targetStride = &result[y, 0, 0]) {
byte* sourceStride = (byte*)data.Scan0 + y * data.Stride;
var sourceSpan = new ReadOnlySpan(sourceStride, bitmap.Width * channels);
var targetSpan = new Span(targetStride, bitmap.Width * channels);
sourceSpan.CopyTo(targetSpan);
}
}
} finally {
// Always release the bitmap lock, even if the copy throws.
bitmap.UnlockBits(data);
}

return result;
}
45 |
/// <summary>
/// Letterboxes an entry to <paramref name="targetSize"/>: scales preserving aspect
/// ratio, centers on a gray (128) canvas, normalizes pixels to [0, 1], and rescales
/// the bounding boxes to match.
/// </summary>
/// <exception cref="ArgumentNullException">The entry has no image.</exception>
public static ObjectDetectionDataset.Entry YoloPreprocess(ObjectDetectionDataset.ClrEntry entry, Size targetSize) {
// BUG FIX: the exception previously used nameof(image), which resolved to the
// `tensorflow.image` namespace, not the faulty argument.
if (entry.Image is null) throw new ArgumentNullException(nameof(entry));

int h = entry.Image.Height, w = entry.Image.Width;
// Uniform scale that fits the whole image inside the target box.
float scale = Math.Min(targetSize.Width * 1f / w, targetSize.Height *1f / h);
int newW = (int)(scale * w), newH = (int)(scale * h);

Resize(entry.Image, width: newW, height: newH);

// Gray letterbox canvas; the resized image is pasted centered.
var padded = (ndarray)np.full(shape: new[] { targetSize.Height, targetSize.Width, 3 },
fill_value: 128f, dtype: dtype.GetClass());
int dw = (targetSize.Width - newW) / 2, dh = (targetSize.Height - newH) / 2;
padded[dh..(newH + dh), dw..(newW + dw)] = entry.ToNumPyEntry().Image;
padded /= 255f;

if (entry.BoundingBoxes != null) {
// Scale box corners and shift by the letterbox offsets.
var horIndex = (.., new[] { 0, 2 });
var vertIndex = (.., new[] { 1, 3 });
entry.BoundingBoxes[horIndex] = (entry.BoundingBoxes[horIndex] * scale).astype(np.int32_fn).AsArray() + dw;
entry.BoundingBoxes[vertIndex] = (entry.BoundingBoxes[vertIndex] * scale).astype(np.int32_fn).AsArray() + dh;
}

return new ObjectDetectionDataset.Entry {
Image = padded,
BoundingBoxes = entry.BoundingBoxes,
};
}
73 |
/// <summary>Resizes <paramref name="image"/> in place using the Box resampler.</summary>
static void Resize(SixLabors.ImageSharp.Image image, int width, int height)
=> image.Mutate(x => x.Resize(width, height, KnownResamplers.Box));
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/keras/Activations.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras {
2 | using LostTech.Gradient;
3 | using LostTech.Gradient.ManualWrappers;
/// <summary>Custom activation functions used by the YOLOv4 graphs.</summary>
static class Activations {
/// <summary>
/// Mish activation: x * tanh(softplus(x)).
/// Ported from https://github.com/hunglc007/tensorflow-yolov4-tflite/commit/a61f81f9118df9cec4d53736648174f6fb113e5f#diff-69d62c22a92472901b83e55ac7c153317c649564d4ae9945dcaed27d37295867R41
/// </summary>
public static Tensor Mish(IGraphNodeBase input) {
var softplus = tf.nn.softplus(input);
return input * tf.tanh(softplus);
}

/// <summary>Python-callable wrapper around <see cref="Mish"/> for use as a Keras activation.</summary>
public static PythonFunctionContainer Mish_fn { get; }
= PythonFunctionContainer.Of(Mish);
}
12 | }
13 |
--------------------------------------------------------------------------------
/src/keras/Blocks.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras {
2 | using System;
3 |
4 | using LostTech.Gradient.ManualWrappers;
5 |
6 | using tensorflow.image;
7 | using tensorflow.keras.layers;
/// <summary>Reusable Keras building blocks for the Darknet/YOLO model definitions.</summary>
static class Blocks {
/// <summary>
/// Convolution block: optional zero-pad + strided downsample, Conv2D with L2
/// regularization, optional batch norm, optional activation.
/// </summary>
/// <param name="filtersShape">[kernel, kernel, inChannels, outChannels]; only first and last entries are read here.</param>
public static Tensor Conv(IGraphNodeBase input, int[] filtersShape,
Func? activation,
bool downsample = false,
bool batchNorm = true
) {
if (input is null) throw new ArgumentNullException(nameof(input));
if (filtersShape is null) throw new ArgumentNullException(nameof(filtersShape));

int strides = 1;
IGraphNodeBase convolutionInput = input;
string padding = "same";

if (downsample) {
// Asymmetric top/left zero pad + 'valid' keeps output size exactly half.
convolutionInput = ZeroPadding2D.NewDyn(padding: ((1, 0), (1, 0))).__call__(input);
padding = "valid";
strides = 2;
}

var convLayer = new Conv2D(filters: filtersShape[^1], kernel_size: filtersShape[0],
strides: strides, padding: padding,
// Bias is redundant when batch norm follows (BN has its own shift).
use_bias: !batchNorm,
kernel_regularizer: tf.keras.regularizers.l2(0.0005),
kernel_initializer: new random_normal_initializer(stddev: 0.01),
bias_initializer: new constant_initializer(0.0));
var conv = convLayer.__call__(convolutionInput);

if (batchNorm)
conv = new FreezableBatchNormalization().__call__(conv);

return activation is null ? conv : activation(conv);
}

// Leaky ReLU with the slope YOLO uses (0.1 instead of Keras' default 0.3).
static Tensor TunedLeakyRelu(Tensor input) => tf.nn.leaky_relu(input, alpha: 0.1);

/// <summary>Convolution block with the default leaky-ReLU activation.</summary>
public static Tensor Conv(IGraphNodeBase input, int[] filtersShape, bool downsample = false, bool batchNorm = true)
=> Conv(input, filtersShape,
downsample: downsample, batchNorm: batchNorm,
activation: TunedLeakyRelu);

/// <summary>
/// Residual block: 1x1 bottleneck then 3x3 convolution, added back to the input.
/// </summary>
public static Tensor Residual(Tensor input, int inputChannel, int filter1, int filter2,
Func? activation) {
if (input is null) throw new ArgumentNullException(nameof(input));
if (inputChannel <= 0) throw new ArgumentOutOfRangeException(nameof(inputChannel));
if (filter1 <= 0) throw new ArgumentOutOfRangeException(nameof(filter1));
if (filter2 <= 0) throw new ArgumentOutOfRangeException(nameof(filter2));

var shortcut = input;
var conv = Conv(input, filtersShape: new[] { 1, 1, inputChannel, filter1 }, activation: activation);
conv = Conv(conv, filtersShape: new[] { 3, 3, filter1, filter2 }, activation: activation);
return shortcut + conv;
}

/// <summary>Residual block with the default leaky-ReLU activation.</summary>
public static Tensor Residual(Tensor input, int inputChannel, int filter1, int filter2)
=> Residual(input, inputChannel: inputChannel,
filter1: filter1, filter2: filter2,
activation: TunedLeakyRelu);

/// <summary>Doubles spatial resolution via bilinear resize (dynamic shape aware).</summary>
public static Tensor Upsample(Tensor input) {
var shape = tf.shape(input);
return tf.image.resize(input, new[] { shape[1] * 2, shape[2] * 2 },
method: ResizeMethod.BILINEAR);
}
}
72 | }
73 |
--------------------------------------------------------------------------------
/src/keras/applications/ObjectDetectionResult.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras.applications {
2 | using numpy;
3 |
4 | using SixLabors.ImageSharp;
5 |
/// <summary>One detected object: class index, confidence score, and pixel-space box.</summary>
public class ObjectDetectionResult {
public int Class { get; set; }
public float Score { get; set; }
public RectangleF Box { get; set; }

/// <summary>
/// Extracts the first batch element from combined_non_max_suppression outputs
/// (indexing [0, detection]) into an array of results.
/// </summary>
public static ObjectDetectionResult[] FromCombinedNonMaxSuppressionBatch(
ndarray boxes, ndarray scores, ndarray classes,
int detectionCount) {
var result = new ObjectDetectionResult[detectionCount];
for(int detection = 0; detection < detectionCount; detection++) {
result[detection] = new ObjectDetectionResult {
Class = checked((int)classes[0, detection].AsScalar()),
Box = ToBox(boxes[0, detection].AsArray()),
Score = scores[0, detection].AsScalar(),
};
}
return result;
}

// Converts TensorFlow's (y1, x1, y2, x2) box layout into a RectangleF.
static RectangleF ToBox(ndarray tlbr) {
var (y1, x1, y2, x2) = (tlbr[0].AsScalar(), tlbr[1].AsScalar(), tlbr[2].AsScalar(), tlbr[3].AsScalar());
return new RectangleF(x: x1, y: y1, width: x2 - x1, height: y2 - y1);
}
}
30 | }
31 |
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.Common.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras.applications {
2 | using System;
3 | using System.Collections.Generic;
4 |
5 | using LostTech.Gradient;
6 |
/// <summary>Shared decode helpers used by both training and evaluation YOLO graphs.</summary>
public static partial class YOLO {
/// <summary>
/// Decodes one raw feature map into predicted (x, y, w, h), objectness confidence,
/// and class probabilities, all in input-pixel units for the given scale.
/// </summary>
static (Tensor xywh, Tensor conf, Tensor prob) DecodeCommon(
Tensor convOut, int outputSize, int classCount,
ReadOnlySpan strides, Tensor anchors,
int scaleIndex, ReadOnlySpan xyScale) {
// Separate variable scope per scale keeps op names unique across the 3 heads.
var varScope = new variable_scope("scale" + scaleIndex.ToString(System.Globalization.CultureInfo.InvariantCulture));
using var _ = varScope.StartUsing();
Tensor batchSize = tf.shape(convOut)[0];

// (B, S, S, 3 anchors, 5 + classes), then split into dxdy / dwdh / conf / prob.
convOut = tf.reshape_dyn(convOut, new object[] { batchSize, outputSize, outputSize, 3, 5 + classCount });
Tensor[] raws = tf.split(convOut, new[] { 2, 2, 1, classCount }, axis: -1);
var (convRawDxDy, convRawDwDh, convRawConf, convRawProb) = raws;

// Grid of cell coordinates, tiled per batch and per anchor.
var meshgrid = tf.meshgrid(tf.range_dyn(outputSize), tf.range_dyn(outputSize));
meshgrid = tf.expand_dims(tf.stack(meshgrid, axis: -1), axis: 2); // [gx, gy, 1, 2]
Tensor xyGrid = tf.tile_dyn(
tf.expand_dims(meshgrid, axis: 0),
new object[] { tf.shape(convOut)[0], 1, 1, 3, 1 });

xyGrid = tf.cast(xyGrid, tf.float32);

// Center: scaled sigmoid offset within the cell, mapped to pixels via stride.
var predictedXY = ((tf.sigmoid(convRawDxDy) * xyScale[scaleIndex]) - 0.5 * (xyScale[scaleIndex] - 1) + xyGrid) * strides[scaleIndex];
// Size: exponential of raw dwdh times the anchor size for this scale.
var predictedWH = tf.exp(convRawDwDh) * tf.cast(anchors[scaleIndex], tf.float32);
var predictedXYWH = tf.concat(new[] { predictedXY, predictedWH }, axis: -1);

var predictedConf = tf.sigmoid(convRawConf);
var predictedProb = tf.sigmoid(convRawProb);

return (predictedXYWH, conf: predictedConf, prob: predictedProb);
}

/// <summary>
/// Drops boxes whose best class score is below the threshold and converts the
/// survivors from (cx, cy, w, h) pixels into normalized (y1, x1, y2, x2).
/// </summary>
static (Tensor boxes, Tensor conf) FilterBoxes(Tensor xywh, Tensor scores,
float scoreThreshold,
Tensor inputShape) {
Tensor scoresMax = tf.reduce_max(scores, axis: new[] { -1 });
Tensor mask = scoresMax >= scoreThreshold;
Tensor classBoxes = tf.boolean_mask(xywh, mask);
Tensor conf = tf.boolean_mask(scores, mask);

Tensor count = tf.shape(scores)[0];

classBoxes = tf.reshape_dyn(classBoxes, new object[] { count, -1, tf.shape(classBoxes)[^1] });
conf = tf.reshape_dyn(conf, new object[] { count, -1, tf.shape(conf)[^1] });

Tensor[] boxXY_WH = tf.split(classBoxes, new[] { 2, 2 }, axis: -1);
var (boxXY, boxWH) = boxXY_WH;

inputShape = tf.cast(inputShape, tf.float32);

// Swap (x, y) -> (y, x) to match TensorFlow's NMS box convention.
var boxYX = boxXY[tf.rest_of_the_axes, TensorDimensionSlice.Reverse];
var boxHW = boxWH[tf.rest_of_the_axes, TensorDimensionSlice.Reverse];

// Corner form normalized by the input resolution.
var boxMins = (boxYX - (boxHW / 2f)) / inputShape;
var boxMaxes = (boxYX + (boxHW / 2f)) / inputShape;

var boxes = tf.concat(new[] {
boxMins[tf.rest_of_the_axes, 0..1], //y_min
boxMins[tf.rest_of_the_axes, 1..2], //x_min
boxMaxes[tf.rest_of_the_axes, 0..1], //y_max
boxMaxes[tf.rest_of_the_axes, 1..2], //x_max
}, axis: -1);

return (boxes, conf);
}
}
72 | }
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.Evaluate.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras.applications {
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Diagnostics;
5 | using System.Linq;
6 |
7 | using LostTech.Gradient;
8 | using LostTech.Gradient.BuiltIns;
9 |
10 | using numpy;
11 |
12 | using SixLabors.ImageSharp;
13 | using SixLabors.ImageSharp.PixelFormats;
14 |
15 | using tensorflow.data;
16 | using tensorflow.image;
17 | using tensorflow.keras.models;
18 | static partial class YOLO {
/// <summary>
/// Evaluation-time decode of a raw feature map: reshapes to
/// (B, S, S, 3, 5 + classes) and applies sigmoid to confidence and class
/// logits, leaving the raw xywh untouched for later post-processing.
/// </summary>
static Tensor DecodeEval(Tensor convOut, int classCount) {
Tensor batchSize = tf.shape(convOut)[0];
Tensor outputSize = tf.shape(convOut)[1];

convOut = tf.reshape_dyn(convOut, new object[] { batchSize, outputSize, outputSize, 3, 5 + classCount });
Tensor[] raws = tf.split(convOut, new[] { 4, 1, classCount }, axis: -1);
var (convRawXYWH, convRawConf, convRawProb) = raws;

Tensor predConf = tf.sigmoid(convRawConf);
Tensor predProb = tf.sigmoid(convRawProb);

return tf.concat(new[] { convRawXYWH, predConf, predProb }, axis: -1);
}
32 |
/// <summary>
/// Builds a YOLOv4 Keras model for inference only: backbone + heads with
/// evaluation decode applied to each of the three scale outputs.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">inputSize or classCount is not positive.</exception>
public static Model CreateV4EvalOnly(int inputSize, int classCount) {
if (inputSize <= 0) throw new ArgumentOutOfRangeException(nameof(inputSize));
if (classCount <= 0) throw new ArgumentOutOfRangeException(nameof(classCount));

Tensor input = tf.keras.Input(new TensorShape(inputSize, inputSize, 3));
var featureMaps = YOLOv4.Apply(input, classCount: classCount);

// Decode small, medium, and large object heads in that order.
var bboxTensors = new PythonList();
foreach (var featureMap in new[] { featureMaps.SSBox, featureMaps.MBBox, featureMaps.LBBox }) {
var bbox = DecodeEval(featureMap, classCount: classCount);
bboxTensors.Add(bbox);
}
return new Model(new { inputs = input, outputs = bboxTensors }.AsKwArgs());
}
47 |
/// <summary>
/// Runs a SavedModel-style detector (outputs already include NMS) on one image
/// and converts the named outputs into result objects.
/// </summary>
/// <param name="detector">A callable returning a dict keyed "tf_op_layer_" + output name.</param>
public static ObjectDetectionResult[] Detect(dynamic detector, Size supportedSize, Image image) {
if (detector is null) throw new ArgumentNullException(nameof(detector));
if (image is null) throw new ArgumentNullException(nameof(image));

// Clone: YoloPreprocess resizes its input image in place.
var input = ImageTools.YoloPreprocess(new ObjectDetectionDataset.ClrEntry {
Image = image.Clone(),
}, supportedSize);
// Add the batch dimension.
var images = input.Image[np.newaxis, np.rest_of_the_axes].AsArray();

IDictionary prediction = detector(tf.constant(images));
_ArrayLike Get(string name) => prediction["tf_op_layer_" + name].numpy();
ndarray boxs = Get(nameof(SelectedBoxesOutput.Boxes)).AsArray();
ndarray scores = Get(nameof(SelectedBoxesOutput.Scores)).AsArray();
ndarray classes = Get(nameof(SelectedBoxesOutput.Classes)).AsArray();
ndarray detections = Get(nameof(SelectedBoxesOutput.Detections)).AsArray();

return ObjectDetectionResult.FromCombinedNonMaxSuppressionBatch(
boxs, scores, classes, detections[0].AsScalar());
}
67 |
/// <summary>
/// Runs a raw (no-NMS) detector on one image, then performs box selection and
/// non-max suppression on the CLR side before converting to result objects.
/// </summary>
public static ObjectDetectionResult[] DetectRaw(Model rawDetector,
Size supportedSize, int classCount,
Image image,
ReadOnlySpan strides, Tensor anchors,
ReadOnlySpan xyScale,
float scoreThreshold = 0.2f) {
if (rawDetector is null) throw new ArgumentNullException(nameof(rawDetector));
if (image is null) throw new ArgumentNullException(nameof(image));

// Clone: YoloPreprocess resizes its input image in place.
var input = ImageTools.YoloPreprocess(new ObjectDetectionDataset.ClrEntry {
Image = image.Clone(),
}, supportedSize);
// Add the batch dimension.
var images = input.Image[np.newaxis, np.rest_of_the_axes].AsArray();

// Model emits the three scale heads in small/medium/large order.
IList prediction = rawDetector.__call__(images);
Debug.Assert(prediction.Count == 3);
var output = new YOLOv4.Output {
SSBox = prediction[0],
MBBox = prediction[1],
LBBox = prediction[2],
};
// NOTE(review): inputSize assumes a square supported size (Width used for both axes).
var suppression = SelectBoxes(output, inputSize: supportedSize.Width, classCount: classCount,
strides: strides, anchors: anchors,
xyScale: xyScale,
scoreThreshold: scoreThreshold);

ndarray boxs = suppression.Boxes.numpy();
ndarray scores = suppression.Scores.numpy();
ndarray classes = suppression.Classes.numpy();
ndarray detections = suppression.Detections.numpy();

return ObjectDetectionResult.FromCombinedNonMaxSuppressionBatch(
boxs, scores, classes, detections[0].AsScalar());
}
102 |
/// <summary>
/// Decodes the three feature maps into boxes + scores and applies TensorFlow's
/// combined (per-class) non-max suppression, keeping at most 50 detections.
/// </summary>
public static SelectedBoxesOutput
SelectBoxes(YOLOv4.Output featureMaps, int inputSize, int classCount,
ReadOnlySpan strides, Tensor anchors,
ReadOnlySpan xyScale,
float scoreThreshold = 0.2f) {
var pred = ProcessPrediction(inputSize: inputSize, featureMaps,
classCount: classCount,
strides: strides,
anchors: anchors,
xyScale: xyScale,
scoreThreshold: scoreThreshold);

// Layout per box: [0..4) = coordinates, [4..) = per-class scores.
var boxes = pred[.., .., 0..4];
var conf = pred[.., .., 4..];

var batchSize = tf.shape(boxes)[0];

var suppression = tf.image.combined_non_max_suppression(
boxes: tf.reshape_dyn(boxes, new object[] { batchSize, -1, 1, 4 }),
scores: tf.reshape_dyn(conf, new object[] { batchSize, -1, tf.shape(conf)[^1] }),
max_output_size_per_class: tf.constant(50),
max_total_size: tf.constant(50),
iou_threshold: 0.45f,
// BUG FIX: was hardcoded to 0.20f, silently ignoring the scoreThreshold
// parameter the caller passed in.
score_threshold: scoreThreshold
);
return new SelectedBoxesOutput {
Boxes = tf.identity(suppression[0], name: nameof(SelectedBoxesOutput.Boxes)),
Scores = tf.identity(suppression[1], name: nameof(SelectedBoxesOutput.Scores)),
// NOTE(review): tf.cast here carries no visible dtype argument — possibly lost
// in this dump; confirm the target dtype against the original source.
Classes = tf.cast(suppression[2], name: nameof(SelectedBoxesOutput.Classes)),
Detections = tf.identity(suppression[3], name: nameof(SelectedBoxesOutput.Detections)),
};
}
135 |
/// <summary>
/// NumPy-side decode of raw per-scale predictions: converts raw dxdy/dwdh into
/// pixel-space xy/wh in place, then flattens all scales into a single
/// (N, 5 + classes) array.
/// </summary>
/// <remarks>Mutates the prediction arrays in place before reshaping.</remarks>
static ndarray PostProcessBBBox(IEnumerable> predictions,
ndarray anchors,
ReadOnlySpan strides,
ReadOnlySpan xyScale) {
foreach(var (scaleIndex, pred) in Tools.Enumerate(predictions.ToArray())) {
var convShape = pred.shape;
int outputSize = convShape.Item2;
var convRawDxDy = pred[.., .., .., .., 0..2];
var convRawDwDh = pred[.., .., .., .., 2..4];

// Build the cell-coordinate grid (same construction as DecodeCommon, on numpy).
dynamic numpy = PythonModuleContainer.Get();
var sizeRange = Enumerable.Range(0, outputSize).ToNumPyArray();
PythonList> tempGrid = numpy.meshgrid(sizeRange, sizeRange);
ndarray xyGrid = np.expand_dims(
np.stack(tempGrid, axis: -1),
axis: 2).AsArray(); // [gx, gy, 1, 2]

// Tile across the 3 anchors (batch axis stays 1).
xyGrid = numpy.tile(np.expand_dims(xyGrid, axis: 0), new[] { 1, 1, 1, 3, 1 });
var xyGridFloat = xyGrid.AsType();

// Scaled sigmoid offset -> pixel center; exp(dwdh) * anchor -> pixel size.
var predXY = ((tf.sigmoid_dyn(convRawDxDy).numpy() * xyScale[scaleIndex]) - 0.5f * (xyScale[scaleIndex] - 1) + xyGridFloat) * strides[scaleIndex];
ndarray predWH = numpy.exp(convRawDwDh) * anchors[scaleIndex];

pred[.., .., .., .., 0..2] = predXY;
pred[.., .., .., .., 2..4] = predWH.AsType();
}

// Flatten every scale to rows and stack them into one array.
var reshapedPredictions = predictions.Select(
x => x.reshape(new[] { -1, (int)x.shape.Item5 }).AsArray());
return np.concatenate(reshapedPredictions, axis: 0);
}
167 |
/// <summary>
/// Converts decoded predictions to original-image coordinates, discards invalid
/// and low-score boxes, and returns rows of (x1, y1, x2, y2, score, class).
/// </summary>
static ndarray PostProcessBoxes(ndarray predictions, Size originalSize, int inputSize, float scoreThreshold) {
var predXYWH = predictions[.., 0..4];
var predConf = predictions[.., 4];
var predProb = predictions[.., 5..];

// (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax)
var predCoor = np.concatenate(new[]{
predXYWH[.., ..2] - predXYWH[..,2..]*0.5f,
predXYWH[.., ..2] + predXYWH[..,2..]*0.5f,
}, axis: -1);

// (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org)
// Undo the letterbox transform applied by YoloPreprocess.
var (h, w) = (originalSize.Height, originalSize.Width);
float resizeRatio = Math.Min(inputSize * 1f / w, inputSize * 1f / h);

float dw = (inputSize - resizeRatio * w) / 2;
float dh = (inputSize - resizeRatio * h) / 2;

for (int i = 0; i < 2; i++) {
predCoor[.., i * 2] = (predCoor[.., i * 2] - dw) / resizeRatio;
predCoor[.., i * 2 + 1] = (predCoor[.., i * 2 + 1] - dh) / resizeRatio;
}

// (3) clip some boxes those are out of range
// BUG FIX: the max corner must be clipped DOWN to the image bounds with
// np.minimum; np.maximum let out-of-range corners through unchanged
// (upstream yolov4 postprocess uses np.minimum here).
predCoor = np.concatenate(new[] {
np.maximum(predCoor[.., ..2], np.zeros(2)),
np.minimum(predCoor[.., 2..], new []{w-1f, h-1f}.ToNumPyArray()),
}, axis: -1);
dynamic numpy = PythonModuleContainer.Get();
// Boxes whose min exceeds their max are degenerate; zero them out.
ndarray invalidMask = numpy.logical_or(predCoor[.., 0] > predCoor[.., 2], predCoor[.., 1] > predCoor[.., 3]);
predCoor[invalidMask] = 0;

// (4) discard some invalid boxes
ndarray bboxesScale = numpy.sqrt(numpy.multiply.reduce(
predCoor[.., 2..4] - predCoor[.., 0..2],
axis: -1));
ndarray scaleMask = numpy.logical_and(0 < bboxesScale, bboxesScale < float.PositiveInfinity);

// (5) discard some boxes with low scores
// Final score = objectness * probability of the argmax class.
var classes = predProb.argmax(axis: -1).AsArray();
ndarray scores = predConf * predProb[numpy.arange(predCoor.shape.Item1), classes];
ndarray? scoreMask = scores > scoreThreshold;
ndarray mask = numpy.logical_and(scaleMask, scoreMask);

var coords = predCoor[mask].AsArray();
scores = scores[mask].AsArray();
classes = classes[mask].AsArray();

return np.concatenate(new [] {
coords,
scores[.., np.newaxis],
classes[.., np.newaxis].AsArray().AsType(),
}, axis: -1);
}
222 |
        /// <summary>
        /// Result of box selection/non-max suppression: a set of parallel tensors
        /// describing the surviving detections.
        /// </summary>
        public struct SelectedBoxesOutput {
            /// <summary>Tensor of box coordinates for the selected detections.</summary>
            public Tensor Boxes { get; set; }
            /// <summary>Tensor of confidence scores, parallel to <see cref="Boxes"/>.</summary>
            public Tensor Scores { get; set; }
            /// <summary>Tensor of class indices, parallel to <see cref="Boxes"/>.</summary>
            public Tensor Classes { get; set; }
            /// <summary>Tensor holding the number of valid detections.</summary>
            public Tensor Detections { get; set; }
        }
229 | }
230 | }
231 |
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.LearningRateSchedule.cs:
--------------------------------------------------------------------------------
namespace tensorflow.keras.applications {
    using System;
    using System.Collections.Generic;
    using System.ComponentModel;

    using LostTech.Gradient;
    using LostTech.Gradient.ManualWrappers;
    partial class YOLO {
        /// <summary>
        /// Learning rate schedule used for YOLO training: linear warmup from 0 to
        /// <see cref="InitialLearningRate"/> over <see cref="WarmupSteps"/>, then cosine
        /// decay down to <see cref="FinalLearningRate"/> over the remaining steps.
        /// </summary>
        public class LearningRateSchedule : optimizers.schedules.LearningRateSchedule {
            internal const float defaultInitialLearningRate = 1e-3f;
            internal const float defaultFinalLearningRate = 1e-6f;

            // Constant tensors mirroring the scalar properties below, created once
            // in the constructor so __call__ does not rebuild them per step.
            readonly Tensor totalSteps, warmupSteps, initialLR, finalLR;

            public long TotalSteps { get; }
            public long WarmupSteps { get; }
            public float InitialLearningRate { get; }
            public float FinalLearningRate { get; }

            /// <summary>
            /// Creates the schedule.
            /// </summary>
            /// <param name="totalSteps">Total number of optimizer steps; must be positive.</param>
            /// <param name="warmupSteps">Steps of linear warmup; must be positive.</param>
            /// <param name="initialLearningRate">Peak learning rate reached at the end of warmup.</param>
            /// <param name="finalLearningRate">Learning rate the cosine decay approaches at the last step.</param>
            /// <exception cref="ArgumentOutOfRangeException">Thrown when a step count is non-positive
            /// or a learning rate is not a positive finite number.</exception>
            public LearningRateSchedule(long totalSteps, long warmupSteps,
                float initialLearningRate = defaultInitialLearningRate,
                float finalLearningRate = defaultFinalLearningRate) {
                if (totalSteps <= 0) throw new ArgumentOutOfRangeException(nameof(totalSteps));
                if (warmupSteps <= 0) throw new ArgumentOutOfRangeException(nameof(warmupSteps));
                if (!GoodLearningRate(initialLearningRate))
                    throw new ArgumentOutOfRangeException(nameof(initialLearningRate));
                if (!GoodLearningRate(finalLearningRate))
                    throw new ArgumentOutOfRangeException(nameof(finalLearningRate));

                this.TotalSteps = totalSteps;
                this.WarmupSteps = warmupSteps;
                this.InitialLearningRate = initialLearningRate;
                this.FinalLearningRate = finalLearningRate;

                this.totalSteps = tf.constant_scalar(totalSteps);
                this.warmupSteps = tf.constant_scalar(warmupSteps);
                this.initialLR = tf.constant_scalar(initialLearningRate);
                this.finalLR = tf.constant_scalar(finalLearningRate);
            }

            // Serialization of this schedule is not supported.
            public override IDictionary get_config() {
                throw new NotImplementedException();
            }

            /// <summary>Convenience wrapper over <see cref="__call__"/> for computing the rate at <paramref name="step"/>.</summary>
            public Tensor Get(IGraphNodeBase step) => this.__call__(step);

            // Warmup branch: step/warmupSteps * initialLR (linear ramp).
            // Decay branch: cosine interpolation from initialLR to finalLR over the
            // post-warmup portion of training.
            // NOTE(review): the warmup branch divides the step tensor by an int64
            // constant — assumes the binding performs true (floating-point) division
            // here, not integer division; TODO confirm against LostTech.Gradient semantics.
            [EditorBrowsable(EditorBrowsableState.Advanced)]
            public override dynamic __call__(IGraphNodeBase step)
                => tf.cond(step < this.warmupSteps,
                    PythonFunctionContainer.Of(() => (step / this.warmupSteps) * this.initialLR),
                    PythonFunctionContainer.Of(() => this.finalLR
                            + 0.5f * (this.initialLR - this.finalLR)
                              * (1 + tf.cos(
                                    (step - this.warmupSteps) / (this.totalSteps - this.warmupSteps)
                                    * Math.PI)))
                );

            // A learning rate is valid when it is a positive finite number
            // (NaN fails lr > 0, so it is rejected too).
            static bool GoodLearningRate(float lr)
                => lr > 0 && !float.IsPositiveInfinity(lr);

            public override dynamic __call___dyn(object step) => throw new NotImplementedException();
            public override dynamic get_config_dyn() => throw new NotImplementedException();

            public static float DefaultInitialLearningRate => defaultInitialLearningRate;
            public static float DefaultFinalLearningRate => defaultFinalLearningRate;
        }
    }
}
69 |
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.Raw.cs:
--------------------------------------------------------------------------------
namespace tensorflow.keras.applications {
    using LostTech.Gradient;
    using LostTech.Gradient.BuiltIns;

    using tensorflow.keras.models;
    partial class YOLO {
        /// <summary>
        /// Builds a "raw" YOLOv4 Keras model whose outputs are the three
        /// unprocessed feature maps (small, medium, large receptive field)
        /// produced by <see cref="YOLOv4.Apply"/> — no decoding or NMS applied.
        /// </summary>
        /// <param name="inputSize">Square input image size (pixels).</param>
        /// <param name="classCount">Number of object classes the heads predict.</param>
        public static Model CreateRaw(int inputSize, int classCount) {
            Tensor input = tf.keras.Input(new TensorShape(inputSize, inputSize, 3));
            var featureMaps = YOLOv4.Apply(input, classCount: classCount);
            // FIX: restored the generic type argument that was stripped from the
            // source (`new PythonList { … }` is not constructible with a collection
            // initializer of tensors).
            var featureMapTensors = new PythonList<Tensor> { featureMaps.SSBox, featureMaps.MBBox, featureMaps.LBBox };
            return new Model(new { inputs = input, outputs = featureMapTensors }.AsKwArgs());
        }
    }
}
15 |
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.SaveModel.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras.applications {
2 | using System;
3 | using System.Collections.Generic;
4 |
5 | using LostTech.Gradient;
6 | using LostTech.Gradient.BuiltIns;
7 |
8 | using tensorflow.keras.models;
9 |
10 | partial class YOLO {
11 | public static Model CreateSaveable(int inputSize, int classCount,
12 | ReadOnlySpan strides, Tensor anchors,
13 | ReadOnlySpan xyScale,
14 | float scoreThreshold) {
15 | Tensor input = tf.keras.Input(new TensorShape(inputSize, inputSize, 3));
16 | var featureMaps = YOLOv4.Apply(input, classCount: classCount);
17 | return CreateSaveable(inputSize: inputSize, input: input, featureMaps,
18 | classCount: classCount,
19 | strides: strides, anchors: anchors, xyScale: xyScale,
20 | scoreThreshold: scoreThreshold);
21 | }
22 |
23 | public static Model CreateSaveable(int inputSize, Tensor input, YOLOv4.Output featureMaps,
24 | int classCount,
25 | ReadOnlySpan strides, Tensor anchors,
26 | ReadOnlySpan xyScale, float scoreThreshold) {
27 | var suppression = SelectBoxes(featureMaps, inputSize: inputSize, classCount: classCount,
28 | strides: strides, anchors: anchors,
29 | xyScale: xyScale,
30 | scoreThreshold: scoreThreshold);
31 | return new Model(new { inputs = input, outputs = new PythonList {
32 | suppression.Boxes, suppression.Scores, suppression.Classes, suppression.Detections,
33 | }}.AsKwArgs());
34 | }
35 |
36 | public static Tensor ProcessPrediction(int inputSize, YOLOv4.Output modelOutput, int classCount, ReadOnlySpan strides, Tensor anchors, ReadOnlySpan xyScale, float scoreThreshold) {
37 | var bboxTensors = new List();
38 | var probTensors = new List();
39 | foreach (var (scaleIndex, featureMap) in Tools.Enumerate(modelOutput.SSBox, modelOutput.MBBox, modelOutput.LBBox)) {
40 | var outputTensors = Decode(featureMap,
41 | outputSize: inputSize / strides[scaleIndex],
42 | classCount: classCount,
43 | strides: strides,
44 | anchors: anchors,
45 | scaleIndex: scaleIndex,
46 | xyScale: xyScale);
47 | bboxTensors.Add(outputTensors.xywh);
48 | probTensors.Add(outputTensors.prob);
49 | }
50 | var bbox = tf.concat(bboxTensors.ToArray(), axis: 1);
51 | var prob = tf.concat(probTensors.ToArray(), axis: 1);
52 |
53 | var (boxes, conf) = FilterBoxes(bbox, prob,
54 | scoreThreshold: scoreThreshold,
55 | inputShape: tf.constant(new[] { inputSize, inputSize }));
56 |
57 | return tf.concat(new[] { boxes, conf }, axis: -1);
58 | }
59 |
60 | static (Tensor xywh, Tensor prob) Decode(
61 | Tensor convOut, int classCount, int outputSize,
62 | ReadOnlySpan strides, Tensor anchors,
63 | int scaleIndex, ReadOnlySpan xyScale) {
64 | var pred = DecodeCommon(convOut,
65 | classCount: classCount, outputSize: outputSize,
66 | strides: strides, anchors: anchors,
67 | scaleIndex: scaleIndex,
68 | xyScale: xyScale);
69 |
70 | Tensor batchSize = tf.shape(convOut)[0];
71 | pred.prob = pred.conf * pred.prob;
72 | pred.prob = tf.reshape_dyn(pred.prob, new object[] { batchSize, -1, classCount });
73 | pred.xywh = tf.reshape_dyn(pred.xywh, new object[] { batchSize, -1, 4 });
74 | return (pred.xywh, pred.prob);
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/src/keras/applications/YOLO.Train.cs:
--------------------------------------------------------------------------------
1 | namespace tensorflow.keras.applications {
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Diagnostics;
5 | using System.Linq;
6 |
7 | using LostTech.Gradient;
8 | using LostTech.Gradient.BuiltIns;
9 | using LostTech.Gradient.ManualWrappers;
10 |
11 | using numpy;
12 |
13 | using tensorflow.data;
14 | using tensorflow.datasets.ObjectDetection;
15 | using tensorflow.errors;
16 | using tensorflow.keras.callbacks;
17 | using tensorflow.keras.layers;
18 | using tensorflow.keras.losses;
19 | using tensorflow.keras.models;
20 | using tensorflow.keras.optimizers;
21 | using tensorflow.keras.utils;
22 | using tensorflow.python.eager.context;
23 | using tensorflow.python.ops.summary_ops_v2;
24 |
25 | public static partial class YOLO {
26 | public static void Train(Model model, IOptimizer optimizer, ObjectDetectionDataset dataset,
27 | ObjectDetectionDataset? testSet = null,
28 | IEnumerable? callbacks = null,
29 | int batchSize = 2,
30 | int warmupEpochs = 2, int firstStageEpochs = 20,
31 | int secondStageEpochs = 30,
32 | float initialLearningRate = 1e-3f,
33 | float finalLearningRate = 1e-6f,
34 | bool testRun = false,
35 | bool benchmark = false) {
36 | var globalSteps = new Variable(1, dtype: tf.int64);
37 |
38 | var learningRateSchedule = new YOLO.LearningRateSchedule(
39 | totalSteps: (long)(firstStageEpochs + secondStageEpochs) * dataset.BatchCount(batchSize),
40 | warmupSteps: warmupEpochs * dataset.BatchCount(batchSize),
41 | initialLearningRate: initialLearningRate,
42 | finalLearningRate: finalLearningRate);
43 |
44 | foreach (var callback in callbacks ?? Array.Empty()) {
45 | callback.DynamicInvoke