├── .gitattributes
├── .gitignore
├── LICENSE.txt
├── README.md
├── StableDiffusionDemo_Console
│   ├── Program.cs
│   └── StableDiffusionDemo_Console.csproj
├── StableDiffusionDemo_Winform
│   ├── FormMain.Designer.cs
│   ├── FormMain.cs
│   ├── FormMain.resx
│   ├── Program.cs
│   └── StableDiffusionDemo_Winform.csproj
├── StableDiffusionSharp.sln
└── StableDiffusionSharp
    ├── ModelLoader
    │   ├── ModelLoader.cs
    │   ├── PickleLoader.cs
    │   ├── SafetensorsLoader.cs
    │   └── TensorInfo.cs
    ├── Models
    │   ├── Clip
    │   │   ├── merges.txt
    │   │   └── vocab.json
    │   └── VAEApprox
    │       ├── vaeapp_sd15.pth
    │       └── xlvaeapp.pth
    ├── Modules
    │   ├── Clip.cs
    │   ├── Esrgan.cs
    │   ├── SD1.cs
    │   ├── SDModel.cs
    │   ├── SDXL.cs
    │   ├── Tokenizer.cs
    │   ├── Unet.cs
    │   ├── VAE.cs
    │   └── VAEApprox.cs
    ├── SDType.cs
    ├── Sampler
    │   ├── BasicSampler.cs
    │   ├── EulerAncestralSampler.cs
    │   └── EulerSampler.cs
    ├── Scheduler
    │   └── DiscreteSchedule.cs
    ├── StableDiffusion.cs
    ├── StableDiffusionSharp.csproj
    └── Tools.cs
/.gitattributes:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Set default behavior to automatically normalize line endings.
3 | ###############################################################################
4 | * text=auto
5 |
6 | ###############################################################################
7 | # Set default behavior for command prompt diff.
8 | #
9 | # This is needed for earlier builds of msysgit that do not have it on by
10 | # default for csharp files.
11 | # Note: This is only used by command line
12 | ###############################################################################
13 | #*.cs diff=csharp
14 |
15 | ###############################################################################
16 | # Set the merge driver for project and solution files
17 | #
18 | # Merging from the command prompt will add diff markers to the files if there
19 | # are conflicts (Merging from VS is not affected by the settings below, in VS
20 | # the diff markers are never inserted). Diff markers may cause the following
21 | # file extensions to fail to load in VS. An alternative would be to treat
22 | # these files as binary and thus will always conflict and require user
23 | # intervention with every merge. To do so, just uncomment the entries below
24 | ###############################################################################
25 | #*.sln merge=binary
26 | #*.csproj merge=binary
27 | #*.vbproj merge=binary
28 | #*.vcxproj merge=binary
29 | #*.vcproj merge=binary
30 | #*.dbproj merge=binary
31 | #*.fsproj merge=binary
32 | #*.lsproj merge=binary
33 | #*.wixproj merge=binary
34 | #*.modelproj merge=binary
35 | #*.sqlproj merge=binary
36 | #*.wwaproj merge=binary
37 |
38 | ###############################################################################
39 | # behavior for image files
40 | #
41 | # image files are treated as binary by default.
42 | ###############################################################################
43 | #*.jpg binary
44 | #*.png binary
45 | #*.gif binary
46 |
47 | ###############################################################################
48 | # diff behavior for common document formats
49 | #
50 | # Convert binary document formats to text before diffing them. This feature
51 | # is only available from the command line. Turn it on by uncommenting the
52 | # entries below.
53 | ###############################################################################
54 | #*.doc diff=astextplain
55 | #*.DOC diff=astextplain
56 | #*.docx diff=astextplain
57 | #*.DOCX diff=astextplain
58 | #*.dot diff=astextplain
59 | #*.DOT diff=astextplain
60 | #*.pdf diff=astextplain
61 | #*.PDF diff=astextplain
62 | #*.rtf diff=astextplain
63 | #*.RTF diff=astextplain
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Ww][Ii][Nn]32/
27 | [Aa][Rr][Mm]/
28 | [Aa][Rr][Mm]64/
29 | bld/
30 | [Bb]in/
31 | [Oo]bj/
32 | [Oo]ut/
33 | [Ll]og/
34 | [Ll]ogs/
35 |
36 | # Visual Studio 2015/2017 cache/options directory
37 | .vs/
38 | # Uncomment if you have tasks that create the project's static files in wwwroot
39 | #wwwroot/
40 |
41 | # Visual Studio 2017 auto generated files
42 | Generated\ Files/
43 |
44 | # MSTest test Results
45 | [Tt]est[Rr]esult*/
46 | [Bb]uild[Ll]og.*
47 |
48 | # NUnit
49 | *.VisualState.xml
50 | TestResult.xml
51 | nunit-*.xml
52 |
53 | # Build Results of an ATL Project
54 | [Dd]ebugPS/
55 | [Rr]eleasePS/
56 | dlldata.c
57 |
58 | # Benchmark Results
59 | BenchmarkDotNet.Artifacts/
60 |
61 | # .NET Core
62 | project.lock.json
63 | project.fragment.lock.json
64 | artifacts/
65 |
66 | # ASP.NET Scaffolding
67 | ScaffoldingReadMe.txt
68 |
69 | # StyleCop
70 | StyleCopReport.xml
71 |
72 | # Files built by Visual Studio
73 | *_i.c
74 | *_p.c
75 | *_h.h
76 | *.ilk
77 | *.meta
78 | *.obj
79 | *.iobj
80 | *.pch
81 | *.pdb
82 | *.ipdb
83 | *.pgc
84 | *.pgd
85 | *.rsp
86 | *.sbr
87 | *.tlb
88 | *.tli
89 | *.tlh
90 | *.tmp
91 | *.tmp_proj
92 | *_wpftmp.csproj
93 | *.log
94 | *.vspscc
95 | *.vssscc
96 | .builds
97 | *.pidb
98 | *.svclog
99 | *.scc
100 |
101 | # Chutzpah Test files
102 | _Chutzpah*
103 |
104 | # Visual C++ cache files
105 | ipch/
106 | *.aps
107 | *.ncb
108 | *.opendb
109 | *.opensdf
110 | *.sdf
111 | *.cachefile
112 | *.VC.db
113 | *.VC.VC.opendb
114 |
115 | # Visual Studio profiler
116 | *.psess
117 | *.vsp
118 | *.vspx
119 | *.sap
120 |
121 | # Visual Studio Trace Files
122 | *.e2e
123 |
124 | # TFS 2012 Local Workspace
125 | $tf/
126 |
127 | # Guidance Automation Toolkit
128 | *.gpState
129 |
130 | # ReSharper is a .NET coding add-in
131 | _ReSharper*/
132 | *.[Rr]e[Ss]harper
133 | *.DotSettings.user
134 |
135 | # TeamCity is a build add-in
136 | _TeamCity*
137 |
138 | # DotCover is a Code Coverage Tool
139 | *.dotCover
140 |
141 | # AxoCover is a Code Coverage Tool
142 | .axoCover/*
143 | !.axoCover/settings.json
144 |
145 | # Coverlet is a free, cross platform Code Coverage Tool
146 | coverage*.json
147 | coverage*.xml
148 | coverage*.info
149 |
150 | # Visual Studio code coverage results
151 | *.coverage
152 | *.coveragexml
153 |
154 | # NCrunch
155 | _NCrunch_*
156 | .*crunch*.local.xml
157 | nCrunchTemp_*
158 |
159 | # MightyMoose
160 | *.mm.*
161 | AutoTest.Net/
162 |
163 | # Web workbench (sass)
164 | .sass-cache/
165 |
166 | # Installshield output folder
167 | [Ee]xpress/
168 |
169 | # DocProject is a documentation generator add-in
170 | DocProject/buildhelp/
171 | DocProject/Help/*.HxT
172 | DocProject/Help/*.HxC
173 | DocProject/Help/*.hhc
174 | DocProject/Help/*.hhk
175 | DocProject/Help/*.hhp
176 | DocProject/Help/Html2
177 | DocProject/Help/html
178 |
179 | # Click-Once directory
180 | publish/
181 |
182 | # Publish Web Output
183 | *.[Pp]ublish.xml
184 | *.azurePubxml
185 | # Note: Comment the next line if you want to checkin your web deploy settings,
186 | # but database connection strings (with potential passwords) will be unencrypted
187 | *.pubxml
188 | *.publishproj
189 |
190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
191 | # checkin your Azure Web App publish settings, but sensitive information contained
192 | # in these scripts will be unencrypted
193 | PublishScripts/
194 |
195 | # NuGet Packages
196 | *.nupkg
197 | # NuGet Symbol Packages
198 | *.snupkg
199 | # The packages folder can be ignored because of Package Restore
200 | **/[Pp]ackages/*
201 | # except build/, which is used as an MSBuild target.
202 | !**/[Pp]ackages/build/
203 | # Uncomment if necessary; however, it will generally be regenerated when needed
204 | #!**/[Pp]ackages/repositories.config
205 | # NuGet v3's project.json files produce more ignorable files
206 | *.nuget.props
207 | *.nuget.targets
208 |
209 | # Microsoft Azure Build Output
210 | csx/
211 | *.build.csdef
212 |
213 | # Microsoft Azure Emulator
214 | ecf/
215 | rcf/
216 |
217 | # Windows Store app package directories and files
218 | AppPackages/
219 | BundleArtifacts/
220 | Package.StoreAssociation.xml
221 | _pkginfo.txt
222 | *.appx
223 | *.appxbundle
224 | *.appxupload
225 |
226 | # Visual Studio cache files
227 | # files ending in .cache can be ignored
228 | *.[Cc]ache
229 | # but keep track of directories ending in .cache
230 | !?*.[Cc]ache/
231 |
232 | # Others
233 | ClientBin/
234 | ~$*
235 | *~
236 | *.dbmdl
237 | *.dbproj.schemaview
238 | *.jfm
239 | *.pfx
240 | *.publishsettings
241 | orleans.codegen.cs
242 |
243 | # Including strong name files can present a security risk
244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
245 | #*.snk
246 |
247 | # Since there are multiple workflows, uncomment next line to ignore bower_components
248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
249 | #bower_components/
250 |
251 | # RIA/Silverlight projects
252 | Generated_Code/
253 |
254 | # Backup & report files from converting an old project file
255 | # to a newer Visual Studio version. Backup files are not needed,
256 | # because we have git ;-)
257 | _UpgradeReport_Files/
258 | Backup*/
259 | UpgradeLog*.XML
260 | UpgradeLog*.htm
261 | ServiceFabricBackup/
262 | *.rptproj.bak
263 |
264 | # SQL Server files
265 | *.mdf
266 | *.ldf
267 | *.ndf
268 |
269 | # Business Intelligence projects
270 | *.rdl.data
271 | *.bim.layout
272 | *.bim_*.settings
273 | *.rptproj.rsuser
274 | *- [Bb]ackup.rdl
275 | *- [Bb]ackup ([0-9]).rdl
276 | *- [Bb]ackup ([0-9][0-9]).rdl
277 |
278 | # Microsoft Fakes
279 | FakesAssemblies/
280 |
281 | # GhostDoc plugin setting file
282 | *.GhostDoc.xml
283 |
284 | # Node.js Tools for Visual Studio
285 | .ntvs_analysis.dat
286 | node_modules/
287 |
288 | # Visual Studio 6 build log
289 | *.plg
290 |
291 | # Visual Studio 6 workspace options file
292 | *.opt
293 |
294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
295 | *.vbw
296 |
297 | # Visual Studio LightSwitch build output
298 | **/*.HTMLClient/GeneratedArtifacts
299 | **/*.DesktopClient/GeneratedArtifacts
300 | **/*.DesktopClient/ModelManifest.xml
301 | **/*.Server/GeneratedArtifacts
302 | **/*.Server/ModelManifest.xml
303 | _Pvt_Extensions
304 |
305 | # Paket dependency manager
306 | .paket/paket.exe
307 | paket-files/
308 |
309 | # FAKE - F# Make
310 | .fake/
311 |
312 | # CodeRush personal settings
313 | .cr/personal
314 |
315 | # Python Tools for Visual Studio (PTVS)
316 | __pycache__/
317 | *.pyc
318 |
319 | # Cake - Uncomment if you are using it
320 | # tools/**
321 | # !tools/packages.config
322 |
323 | # Tabs Studio
324 | *.tss
325 |
326 | # Telerik's JustMock configuration file
327 | *.jmconfig
328 |
329 | # BizTalk build output
330 | *.btp.cs
331 | *.btm.cs
332 | *.odx.cs
333 | *.xsd.cs
334 |
335 | # OpenCover UI analysis results
336 | OpenCover/
337 |
338 | # Azure Stream Analytics local run output
339 | ASALocalRun/
340 |
341 | # MSBuild Binary and Structured Log
342 | *.binlog
343 |
344 | # NVidia Nsight GPU debugger configuration file
345 | *.nvuser
346 |
347 | # MFractors (Xamarin productivity tool) working folder
348 | .mfractor/
349 |
350 | # Local History for Visual Studio
351 | .localhistory/
352 |
353 | # BeatPulse healthcheck temp database
354 | healthchecksdb
355 |
356 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
357 | MigrationBackup/
358 |
359 | # Ionide (cross platform F# VS Code tools) working folder
360 | .ionide/
361 |
362 | # Fody - auto-generated XML schema
363 | FodyWeavers.xsd
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StableDiffusionSharp
2 |
3 | **Use Stable Diffusion with C# only.**
4 |
5 | StableDiffusionSharp is image-generation software. With the help of TorchSharp, Stable Diffusion can run without Python.
6 |
7 | 
8 |
9 | ## Features
10 |
11 | - Written in C# only.
12 | - Can load .safetensors or .ckpt models directly.
13 | - CUDA support.
14 | - Uses SDPA to speed up inference and save VRAM in fp16.
15 | - Text2Image support.
16 | - Image2Image support.
17 | - SD1.5 support.
18 | - SDXL support.
19 | - VAEApprox support.
20 | - ESRGAN 4x support.
21 | - NuGet package support.
22 |
23 | For SD1.5 text-to-image, it costs about 3 GB of VRAM and 2.4 seconds to generate a 512×512 image in 20 steps.
24 |
25 | ## Work to do
26 |
27 | - Lora support.
28 | - ControlNet support.
29 | - Inpaint support.
30 | - Tiled VAE.
31 |
32 | ## How to use
33 |
34 | You can download the code or add the package from NuGet.
35 |
36 | dotnet add package IntptrMax.StableDiffusionSharp
37 |
38 | Or use the code directly.
39 |
40 | > [!NOTE]
41 | > Please add one of the libtorch packages (libtorch-cpu, libtorch-cuda-12.1, libtorch-cuda-12.1-win-x64, or libtorch-cuda-12.1-linux-x64), version 2.5.1.0, or the native backend will not be available at run time.
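42 | >
43 | > For example, for CUDA on Windows, pick the TorchSharp libtorch package that matches your platform and backend:
44 | >
45 | >     dotnet add package libtorch-cuda-12.1-win-x64 --version 2.5.1.0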
42 |
43 | You have to download an SD model first. If you need a separate VAE, you have to download it too.
44 |
45 |
46 | If you want to use ESRGAN for upscaling, you have to download the model from [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth).
47 |
48 | Now you can use it like the code below.
49 |
50 | ``` C#
51 | static void Main(string[] args)
52 | {
53 | string sdModelPath = @".\Chilloutmix.safetensors";
54 | string vaeModelPath = @".\vae.safetensors";
55 |
56 | string esrganModelPath = @".\RealESRGAN_x4plus.pth";
57 | string i2iPrompt = "High quality, best quality, moon, grass, tree, boat.";
58 | string prompt = "cat with blue eyes";
59 | string nprompt = "";
60 |
61 | SDDeviceType deviceType = SDDeviceType.CUDA;
62 | SDScalarType scalarType = SDScalarType.Float16;
63 | SDSamplerType samplerType = SDSamplerType.EulerAncestral;
64 | int step = 20;
65 | float cfg = 7.0f;
66 | long seed = 0;
67 | long img2imgSubSeed = 0;
68 | int width = 512;
69 | int height = 512;
70 | float strength = 0.75f;
71 | long clipSkip = 2;
72 |
73 | StableDiffusion sd = new StableDiffusion(deviceType, scalarType);
74 | sd.StepProgress += Sd_StepProgress;
75 | Console.WriteLine("Loading model......");
76 | sd.LoadModel(sdModelPath, vaeModelPath);
77 | Console.WriteLine("Model loaded.");
78 |
79 | ImageMagick.MagickImage t2iImage = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg, samplerType);
80 | t2iImage.Write("output_t2i.png");
81 |
82 | ImageMagick.MagickImage i2iImage = sd.ImageToImage(t2iImage, i2iPrompt, nprompt, clipSkip, step, strength, seed, img2imgSubSeed, cfg, samplerType);
83 | i2iImage.Write("output_i2i.png");
84 |
85 | sd.Dispose();
86 | GC.Collect();
87 |
88 | Console.WriteLine("Doing upscale......");
89 | StableDiffusionSharp.Modules.Esrgan esrgan = new StableDiffusionSharp.Modules.Esrgan(deviceType: deviceType, scalarType: scalarType);
90 | esrgan.LoadModel(esrganModelPath);
91 | ImageMagick.MagickImage upscaleImg = esrgan.UpScale(t2iImage);
92 | upscaleImg.Write("upscale.png");
93 |
94 | Console.WriteLine(@"Done. Images have been saved.");
95 | }
96 |
97 | private static void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e)
98 | {
99 | Console.WriteLine($"Progress: {e.CurrentStep}/{e.TotalSteps}");
100 | }
101 | ```
--------------------------------------------------------------------------------
/StableDiffusionDemo_Console/Program.cs:
--------------------------------------------------------------------------------
1 | using StableDiffusionSharp;
2 |
3 | namespace StableDiffusionDemo_Console
4 | {
5 | internal class Program
6 | {
7 | static void Main(string[] args)
8 | {
9 | string sdModelPath = @".\Chilloutmix.safetensors";
10 | string vaeModelPath = @".\vae.safetensors";
11 |
12 | string esrganModelPath = @".\RealESRGAN_x4plus.pth";
13 | string i2iPrompt = "High quality, best quality, moon, grass, tree, boat.";
14 | string prompt = "cat with blue eyes";
15 | string nprompt = "";
16 |
17 | SDDeviceType deviceType = SDDeviceType.CUDA;
18 | SDScalarType scalarType = SDScalarType.Float16;
19 | SDSamplerType samplerType = SDSamplerType.Euler;
20 | int step = 20;
21 | float cfg = 7.0f;
22 | long seed = 0;
23 | long img2imgSubSeed = 0;
24 | int width = 512;
25 | int height = 512;
26 | float strength = 0.75f;
27 | long clipSkip = 2;
28 |
29 | StableDiffusion sd = new StableDiffusion(deviceType, scalarType);
30 | sd.StepProgress += Sd_StepProgress;
31 | Console.WriteLine("Loading model......");
32 | sd.LoadModel(sdModelPath, vaeModelPath);
33 | Console.WriteLine("Model loaded.");
34 |
35 | ImageMagick.MagickImage t2iImage = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg, samplerType);
36 | t2iImage.Write("output_t2i.png");
37 |
38 | ImageMagick.MagickImage i2iImage = sd.ImageToImage(t2iImage, i2iPrompt, nprompt, clipSkip, step, strength, seed, img2imgSubSeed, cfg, samplerType);
39 | i2iImage.Write("output_i2i.png");
40 |
41 | sd.Dispose();
42 | GC.Collect();
43 |
44 | Console.WriteLine("Doing upscale......");
45 | StableDiffusionSharp.Modules.Esrgan esrgan = new StableDiffusionSharp.Modules.Esrgan(deviceType: deviceType, scalarType: scalarType);
46 | esrgan.LoadModel(esrganModelPath);
47 | ImageMagick.MagickImage upscaleImg = esrgan.UpScale(t2iImage);
48 | upscaleImg.Write("upscale.png");
49 |
50 | Console.WriteLine(@"Done. Images have been saved.");
51 | }
52 |
53 | private static void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e)
54 | {
55 | Console.WriteLine($"Progress: {e.CurrentStep}/{e.TotalSteps}");
56 | }
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/StableDiffusionDemo_Console/StableDiffusionDemo_Console.csproj:
--------------------------------------------------------------------------------
1 | <Project Sdk="Microsoft.NET.Sdk">
2 |
3 |   <PropertyGroup>
4 |     <OutputType>Exe</OutputType>
5 |     <TargetFramework>net6.0</TargetFramework>
6 |     <ImplicitUsings>enable</ImplicitUsings>
7 |     <Nullable>enable</Nullable>
8 |   </PropertyGroup>
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | </Project>
--------------------------------------------------------------------------------
/StableDiffusionDemo_Winform/FormMain.Designer.cs:
--------------------------------------------------------------------------------
1 | namespace StableDiffusionDemo_Winform
2 | {
3 | partial class FormMain
4 | {
5 | /// <summary>
6 | /// Required designer variable.
7 | /// </summary>
8 | private System.ComponentModel.IContainer components = null;
9 |
10 | /// <summary>
11 | /// Clean up any resources being used.
12 | /// </summary>
13 | /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
14 | protected override void Dispose(bool disposing)
15 | {
16 | if (disposing && (components != null))
17 | {
18 | components.Dispose();
19 | }
20 | base.Dispose(disposing);
21 | }
22 |
23 | #region Windows Form Designer generated code
24 |
25 | /// <summary>
26 | /// Required method for Designer support - do not modify
27 | /// the contents of this method with the code editor.
28 | /// </summary>
29 | private void InitializeComponent()
30 | {
31 | groupBox1 = new GroupBox();
32 | label11 = new Label();
33 | NumericUpDown_ClipSkip = new NumericUpDown();
34 | label10 = new Label();
35 | Button_VAEModelScan = new Button();
36 | TextBox_VaePath = new TextBox();
37 | label9 = new Label();
38 | label8 = new Label();
39 | ComboBox_Precition = new ComboBox();
40 | ComboBox_Device = new ComboBox();
41 | Button_ModelLoad = new Button();
42 | Button_ModelScan = new Button();
43 | label1 = new Label();
44 | TextBox_ModelPath = new TextBox();
45 | tabControl1 = new TabControl();
46 | tabPage1 = new TabPage();
47 | groupBox2 = new GroupBox();
48 | Label_State = new Label();
49 | Button_Generate = new Button();
50 | label7 = new Label();
51 | label6 = new Label();
52 | label5 = new Label();
53 | NumericUpDown_Height = new NumericUpDown();
54 | NumericUpDown_CFG = new NumericUpDown();
55 | NumericUpDown_Step = new NumericUpDown();
56 | NumericUpDown_Width = new NumericUpDown();
57 | label4 = new Label();
58 | PictureBox_Output = new PictureBox();
59 | label3 = new Label();
60 | TextBox_NPrompt = new TextBox();
61 | TextBox_Prompt = new TextBox();
62 | label2 = new Label();
63 | tabPage2 = new TabPage();
64 | tabPage3 = new TabPage();
65 | groupBox1.SuspendLayout();
66 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_ClipSkip).BeginInit();
67 | tabControl1.SuspendLayout();
68 | tabPage1.SuspendLayout();
69 | groupBox2.SuspendLayout();
70 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Height).BeginInit();
71 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_CFG).BeginInit();
72 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Step).BeginInit();
73 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Width).BeginInit();
74 | ((System.ComponentModel.ISupportInitialize)PictureBox_Output).BeginInit();
75 | SuspendLayout();
76 | //
77 | // groupBox1
78 | //
79 | groupBox1.Controls.Add(label11);
80 | groupBox1.Controls.Add(NumericUpDown_ClipSkip);
81 | groupBox1.Controls.Add(label10);
82 | groupBox1.Controls.Add(Button_VAEModelScan);
83 | groupBox1.Controls.Add(TextBox_VaePath);
84 | groupBox1.Controls.Add(label9);
85 | groupBox1.Controls.Add(label8);
86 | groupBox1.Controls.Add(ComboBox_Precition);
87 | groupBox1.Controls.Add(ComboBox_Device);
88 | groupBox1.Controls.Add(Button_ModelLoad);
89 | groupBox1.Controls.Add(Button_ModelScan);
90 | groupBox1.Controls.Add(label1);
91 | groupBox1.Controls.Add(TextBox_ModelPath);
92 | groupBox1.Location = new Point(12, 12);
93 | groupBox1.Name = "groupBox1";
94 | groupBox1.Size = new Size(865, 178);
95 | groupBox1.TabIndex = 0;
96 | groupBox1.TabStop = false;
97 | groupBox1.Text = "Base";
98 | //
99 | // label11
100 | //
101 | label11.AutoSize = true;
102 | label11.Location = new Point(461, 117);
103 | label11.Name = "label11";
104 | label11.Size = new Size(59, 17);
105 | label11.TabIndex = 12;
106 | label11.Text = "Clip Skip";
107 | //
108 | // NumericUpDown_ClipSkip
109 | //
110 | NumericUpDown_ClipSkip.Location = new Point(526, 114);
111 | NumericUpDown_ClipSkip.Maximum = new decimal(new int[] { 10, 0, 0, 0 });
112 | NumericUpDown_ClipSkip.Name = "NumericUpDown_ClipSkip";
113 | NumericUpDown_ClipSkip.Size = new Size(62, 23);
114 | NumericUpDown_ClipSkip.TabIndex = 11;
115 | //
116 | // label10
117 | //
118 | label10.AutoSize = true;
119 | label10.Location = new Point(18, 78);
120 | label10.Name = "label10";
121 | label10.Size = new Size(60, 17);
122 | label10.TabIndex = 10;
123 | label10.Text = "VAE Path";
124 | //
125 | // Button_VAEModelScan
126 | //
127 | Button_VAEModelScan.Location = new Point(708, 72);
128 | Button_VAEModelScan.Name = "Button_VAEModelScan";
129 | Button_VAEModelScan.Size = new Size(101, 23);
130 | Button_VAEModelScan.TabIndex = 9;
131 | Button_VAEModelScan.Text = "Scan";
132 | Button_VAEModelScan.UseVisualStyleBackColor = true;
133 | Button_VAEModelScan.Click += Button_VAEModelScan_Click;
134 | //
135 | // TextBox_VaePath
136 | //
137 | TextBox_VaePath.Location = new Point(113, 72);
138 | TextBox_VaePath.Name = "TextBox_VaePath";
139 | TextBox_VaePath.ReadOnly = true;
140 | TextBox_VaePath.Size = new Size(564, 23);
141 | TextBox_VaePath.TabIndex = 8;
142 | //
143 | // label9
144 | //
145 | label9.AutoSize = true;
146 | label9.Location = new Point(217, 115);
147 | label9.Name = "label9";
148 | label9.Size = new Size(58, 17);
149 | label9.TabIndex = 7;
150 | label9.Text = "Precision";
151 | //
152 | // label8
153 | //
154 | label8.AutoSize = true;
155 | label8.Location = new Point(18, 115);
156 | label8.Name = "label8";
157 | label8.Size = new Size(46, 17);
158 | label8.TabIndex = 6;
159 | label8.Text = "Device";
160 | //
161 | // ComboBox_Precition
162 | //
163 | ComboBox_Precition.DropDownStyle = ComboBoxStyle.DropDownList;
164 | ComboBox_Precition.FormattingEnabled = true;
165 | ComboBox_Precition.Items.AddRange(new object[] { "fp16", "fp32" });
166 | ComboBox_Precition.Location = new Point(281, 112);
167 | ComboBox_Precition.Name = "ComboBox_Precition";
168 | ComboBox_Precition.Size = new Size(121, 25);
169 | ComboBox_Precition.TabIndex = 5;
170 | //
171 | // ComboBox_Device
172 | //
173 | ComboBox_Device.DropDownStyle = ComboBoxStyle.DropDownList;
174 | ComboBox_Device.FormattingEnabled = true;
175 | ComboBox_Device.Items.AddRange(new object[] { "CUDA", "CPU" });
176 | ComboBox_Device.Location = new Point(70, 112);
177 | ComboBox_Device.Name = "ComboBox_Device";
178 | ComboBox_Device.Size = new Size(121, 25);
179 | ComboBox_Device.TabIndex = 4;
180 | //
181 | // Button_ModelLoad
182 | //
183 | Button_ModelLoad.Location = new Point(708, 137);
184 | Button_ModelLoad.Name = "Button_ModelLoad";
185 | Button_ModelLoad.Size = new Size(101, 23);
186 | Button_ModelLoad.TabIndex = 3;
187 | Button_ModelLoad.Text = "Load Model";
188 | Button_ModelLoad.UseVisualStyleBackColor = true;
189 | Button_ModelLoad.Click += Button_ModelLoad_Click;
190 | //
191 | // Button_ModelScan
192 | //
193 | Button_ModelScan.Location = new Point(708, 40);
194 | Button_ModelScan.Name = "Button_ModelScan";
195 | Button_ModelScan.Size = new Size(101, 23);
196 | Button_ModelScan.TabIndex = 2;
197 | Button_ModelScan.Text = "Scan";
198 | Button_ModelScan.UseVisualStyleBackColor = true;
199 | Button_ModelScan.Click += Button_ModelScan_Click;
200 | //
201 | // label1
202 | //
203 | label1.AutoSize = true;
204 | label1.Location = new Point(18, 40);
205 | label1.Name = "label1";
206 | label1.Size = new Size(75, 17);
207 | label1.TabIndex = 1;
208 | label1.Text = "Model Path";
209 | //
210 | // TextBox_ModelPath
211 | //
212 | TextBox_ModelPath.Location = new Point(113, 34);
213 | TextBox_ModelPath.Name = "TextBox_ModelPath";
214 | TextBox_ModelPath.ReadOnly = true;
215 | TextBox_ModelPath.Size = new Size(564, 23);
216 | TextBox_ModelPath.TabIndex = 0;
217 | //
218 | // tabControl1
219 | //
220 | tabControl1.Controls.Add(tabPage1);
221 | tabControl1.Controls.Add(tabPage2);
222 | tabControl1.Controls.Add(tabPage3);
223 | tabControl1.Location = new Point(12, 196);
224 | tabControl1.Name = "tabControl1";
225 | tabControl1.SelectedIndex = 0;
226 | tabControl1.Size = new Size(865, 397);
227 | tabControl1.TabIndex = 1;
228 | //
229 | // tabPage1
230 | //
231 | tabPage1.Controls.Add(groupBox2);
232 | tabPage1.Location = new Point(4, 26);
233 | tabPage1.Name = "tabPage1";
234 | tabPage1.Padding = new Padding(3);
235 | tabPage1.Size = new Size(857, 367);
236 | tabPage1.TabIndex = 0;
237 | tabPage1.Text = "Text To Image";
238 | tabPage1.UseVisualStyleBackColor = true;
239 | //
240 | // groupBox2
241 | //
242 | groupBox2.Controls.Add(Label_State);
243 | groupBox2.Controls.Add(Button_Generate);
244 | groupBox2.Controls.Add(label7);
245 | groupBox2.Controls.Add(label6);
246 | groupBox2.Controls.Add(label5);
247 | groupBox2.Controls.Add(NumericUpDown_Height);
248 | groupBox2.Controls.Add(NumericUpDown_CFG);
249 | groupBox2.Controls.Add(NumericUpDown_Step);
250 | groupBox2.Controls.Add(NumericUpDown_Width);
251 | groupBox2.Controls.Add(label4);
252 | groupBox2.Controls.Add(PictureBox_Output);
253 | groupBox2.Controls.Add(label3);
254 | groupBox2.Controls.Add(TextBox_NPrompt);
255 | groupBox2.Controls.Add(TextBox_Prompt);
256 | groupBox2.Controls.Add(label2);
257 | groupBox2.Location = new Point(6, 6);
258 | groupBox2.Name = "groupBox2";
259 | groupBox2.Size = new Size(845, 355);
260 | groupBox2.TabIndex = 0;
261 | groupBox2.TabStop = false;
262 | groupBox2.Text = "Parameters";
263 | //
264 | // Label_State
265 | //
266 | Label_State.BorderStyle = BorderStyle.FixedSingle;
267 | Label_State.Location = new Point(6, 294);
268 | Label_State.Name = "Label_State";
269 | Label_State.Size = new Size(282, 58);
270 | Label_State.TabIndex = 15;
271 | Label_State.Text = "Please load a model first.";
272 | //
273 | // Button_Generate
274 | //
275 | Button_Generate.Enabled = false;
276 | Button_Generate.Location = new Point(294, 294);
277 | Button_Generate.Name = "Button_Generate";
278 | Button_Generate.Size = new Size(86, 55);
279 | Button_Generate.TabIndex = 14;
280 | Button_Generate.Text = "Generate";
281 | Button_Generate.UseVisualStyleBackColor = true;
282 | Button_Generate.Click += Button_Generate_Click;
283 | //
284 | // label7
285 | //
286 | label7.AutoSize = true;
287 | label7.Location = new Point(285, 267);
288 | label7.Name = "label7";
289 | label7.Size = new Size(31, 17);
290 | label7.TabIndex = 13;
291 | label7.Text = "CFG";
292 | //
293 | // label6
294 | //
295 | label6.AutoSize = true;
296 | label6.Location = new Point(167, 267);
297 | label6.Name = "label6";
298 | label6.Size = new Size(34, 17);
299 | label6.TabIndex = 12;
300 | label6.Text = "Step";
301 | //
302 | // label5
303 | //
304 | label5.AutoSize = true;
305 | label5.Location = new Point(89, 267);
306 | label5.Name = "label5";
307 | label5.Size = new Size(17, 17);
308 | label5.TabIndex = 11;
309 | label5.Text = "H";
310 | //
311 | // NumericUpDown_Height
312 | //
313 | NumericUpDown_Height.Increment = new decimal(new int[] { 64, 0, 0, 0 });
314 | NumericUpDown_Height.Location = new Point(112, 265);
315 | NumericUpDown_Height.Maximum = new decimal(new int[] { 2048, 0, 0, 0 });
316 | NumericUpDown_Height.Minimum = new decimal(new int[] { 64, 0, 0, 0 });
317 | NumericUpDown_Height.Name = "NumericUpDown_Height";
318 | NumericUpDown_Height.Size = new Size(49, 23);
319 | NumericUpDown_Height.TabIndex = 10;
320 | NumericUpDown_Height.Value = new decimal(new int[] { 512, 0, 0, 0 });
321 | //
322 | // NumericUpDown_CFG
323 | //
324 | NumericUpDown_CFG.Increment = new decimal(new int[] { 5, 0, 0, 65536 });
325 | NumericUpDown_CFG.Location = new Point(322, 265);
326 | NumericUpDown_CFG.Maximum = new decimal(new int[] { 25, 0, 0, 0 });
327 | NumericUpDown_CFG.Minimum = new decimal(new int[] { 5, 0, 0, 65536 });
328 | NumericUpDown_CFG.Name = "NumericUpDown_CFG";
329 | NumericUpDown_CFG.Size = new Size(58, 23);
330 | NumericUpDown_CFG.TabIndex = 9;
331 | NumericUpDown_CFG.Value = new decimal(new int[] { 7, 0, 0, 0 });
332 | //
333 | // NumericUpDown_Step
334 | //
335 | NumericUpDown_Step.Location = new Point(207, 265);
336 | NumericUpDown_Step.Minimum = new decimal(new int[] { 1, 0, 0, 0 });
337 | NumericUpDown_Step.Name = "NumericUpDown_Step";
338 | NumericUpDown_Step.Size = new Size(60, 23);
339 | NumericUpDown_Step.TabIndex = 8;
340 | NumericUpDown_Step.Value = new decimal(new int[] { 20, 0, 0, 0 });
341 | //
342 | // NumericUpDown_Width
343 | //
344 | NumericUpDown_Width.Increment = new decimal(new int[] { 64, 0, 0, 0 });
345 | NumericUpDown_Width.Location = new Point(34, 265);
346 | NumericUpDown_Width.Maximum = new decimal(new int[] { 2048, 0, 0, 0 });
347 | NumericUpDown_Width.Minimum = new decimal(new int[] { 64, 0, 0, 0 });
348 | NumericUpDown_Width.Name = "NumericUpDown_Width";
349 | NumericUpDown_Width.Size = new Size(49, 23);
350 | NumericUpDown_Width.TabIndex = 6;
351 | NumericUpDown_Width.Value = new decimal(new int[] { 512, 0, 0, 0 });
352 | //
353 | // label4
354 | //
355 | label4.AutoSize = true;
356 | label4.Location = new Point(8, 267);
357 | label4.Name = "label4";
358 | label4.Size = new Size(20, 17);
359 | label4.TabIndex = 5;
360 | label4.Text = "W";
361 | //
362 | // PictureBox_Output
363 | //
364 | PictureBox_Output.BorderStyle = BorderStyle.FixedSingle;
365 | PictureBox_Output.Location = new Point(398, 22);
366 | PictureBox_Output.Name = "PictureBox_Output";
367 | PictureBox_Output.Size = new Size(432, 327);
368 | PictureBox_Output.SizeMode = PictureBoxSizeMode.Zoom;
369 | PictureBox_Output.TabIndex = 4;
370 | PictureBox_Output.TabStop = false;
371 | //
372 | // label3
373 | //
374 | label3.AutoSize = true;
375 | label3.Location = new Point(8, 193);
376 | label3.Name = "label3";
377 | label3.Size = new Size(66, 17);
378 | label3.TabIndex = 3;
379 | label3.Text = "N_Prompt";
380 | //
381 | // TextBox_NPrompt
382 | //
383 | TextBox_NPrompt.Location = new Point(78, 158);
384 | TextBox_NPrompt.Multiline = true;
385 | TextBox_NPrompt.Name = "TextBox_NPrompt";
386 | TextBox_NPrompt.Size = new Size(302, 87);
387 | TextBox_NPrompt.TabIndex = 2;
388 | TextBox_NPrompt.Text = "2d, 3d, cartoon, paintings";
389 | //
390 | // TextBox_Prompt
391 | //
392 | TextBox_Prompt.Location = new Point(78, 36);
393 | TextBox_Prompt.Multiline = true;
394 | TextBox_Prompt.Name = "TextBox_Prompt";
395 | TextBox_Prompt.Size = new Size(302, 104);
396 | TextBox_Prompt.TabIndex = 1;
397 | TextBox_Prompt.Text = "realistic, best quality, 4k, 8k, trees, beach, moon, stars, boat, ";
398 | //
399 | // label2
400 | //
401 | label2.AutoSize = true;
402 | label2.Location = new Point(8, 90);
403 | label2.Name = "label2";
404 | label2.Size = new Size(51, 17);
405 | label2.TabIndex = 0;
406 | label2.Text = "Prompt";
407 | //
408 | // tabPage2
409 | //
410 | tabPage2.Location = new Point(4, 26);
411 | tabPage2.Name = "tabPage2";
412 | tabPage2.Padding = new Padding(3);
413 | tabPage2.Size = new Size(857, 367);
414 | tabPage2.TabIndex = 1;
415 | tabPage2.Text = "Image To Image";
416 | tabPage2.UseVisualStyleBackColor = true;
417 | //
418 | // tabPage3
419 | //
420 | tabPage3.Location = new Point(4, 26);
421 | tabPage3.Name = "tabPage3";
422 | tabPage3.Padding = new Padding(3);
423 | tabPage3.Size = new Size(857, 367);
424 | tabPage3.TabIndex = 2;
425 | tabPage3.Text = "Restore";
426 | tabPage3.UseVisualStyleBackColor = true;
427 | //
428 | // FormMain
429 | //
430 | AutoScaleDimensions = new SizeF(7F, 17F);
431 | AutoScaleMode = AutoScaleMode.Font;
432 | ClientSize = new Size(889, 605);
433 | Controls.Add(tabControl1);
434 | Controls.Add(groupBox1);
435 | Name = "FormMain";
436 | Text = "Stable Diffusion Sharp";
437 | Load += FormMain_Load;
438 | groupBox1.ResumeLayout(false);
439 | groupBox1.PerformLayout();
440 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_ClipSkip).EndInit();
441 | tabControl1.ResumeLayout(false);
442 | tabPage1.ResumeLayout(false);
443 | groupBox2.ResumeLayout(false);
444 | groupBox2.PerformLayout();
445 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Height).EndInit();
446 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_CFG).EndInit();
447 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Step).EndInit();
448 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Width).EndInit();
449 | ((System.ComponentModel.ISupportInitialize)PictureBox_Output).EndInit();
450 | ResumeLayout(false);
451 | }
452 |
453 | #endregion
454 |
455 | private GroupBox groupBox1;
456 | private Button Button_ModelScan;
457 | private Label label1;
458 | private TextBox TextBox_ModelPath;
459 | private TabControl tabControl1;
460 | private TabPage tabPage1;
461 | private Button Button_ModelLoad;
462 | private GroupBox groupBox2;
463 | private PictureBox PictureBox_Output;
464 | private Label label3;
465 | private TextBox TextBox_NPrompt;
466 | private TextBox TextBox_Prompt;
467 | private Label label2;
468 | private NumericUpDown NumericUpDown_Width;
469 | private Label label4;
470 | private Button Button_Generate;
471 | private Label label7;
472 | private Label label6;
473 | private Label label5;
474 | private NumericUpDown NumericUpDown_Height;
475 | private NumericUpDown NumericUpDown_CFG;
476 | private NumericUpDown NumericUpDown_Step;
477 | private Label Label_State;
478 | private TabPage tabPage2;
479 | private TabPage tabPage3;
480 | private Label label9;
481 | private Label label8;
482 | private ComboBox ComboBox_Precition;
483 | private ComboBox ComboBox_Device;
484 | private Button Button_VAEModelScan;
485 | private TextBox TextBox_VaePath;
486 | private Label label10;
487 | private Label label11;
488 | private NumericUpDown NumericUpDown_ClipSkip;
489 | }
490 | }
491 |
--------------------------------------------------------------------------------
/StableDiffusionDemo_Winform/FormMain.cs:
--------------------------------------------------------------------------------
1 | using StableDiffusionSharp;
2 | using System.Diagnostics;
3 |
4 | namespace StableDiffusionDemo_Winform
5 | {
6 | public partial class FormMain : Form
7 | {
8 | string modelPath = string.Empty;
9 | string vaeModelPath = string.Empty;
10 | StableDiffusion? sd;
11 |
12 | public FormMain()
13 | {
14 | InitializeComponent();
15 | }
16 |
17 | private void FormMain_Load(object sender, EventArgs e)
18 | {
19 | ComboBox_Device.SelectedIndex = 0;
20 | ComboBox_Precition.SelectedIndex = 0;
21 | }
22 |
23 | private void Button_ModelScan_Click(object sender, EventArgs e)
24 | {
25 | FileDialog fileDialog = new OpenFileDialog();
26 | fileDialog.Filter = "Model files|*.safetensors;*.ckpt;*.pt;*.pth|All files|*.*";
27 | if (fileDialog.ShowDialog() == DialogResult.OK)
28 | {
29 | TextBox_ModelPath.Text = fileDialog.FileName;
30 | modelPath = fileDialog.FileName;
31 | }
32 | }
33 |
34 | private void Button_ModelLoad_Click(object sender, EventArgs e)
35 | {
36 | if (File.Exists(modelPath))
37 | {
38 | SDDeviceType deviceType = ComboBox_Device.SelectedIndex == 0 ? SDDeviceType.CUDA : SDDeviceType.CPU;
39 | SDScalarType scalarType = ComboBox_Precition.SelectedIndex == 0 ? SDScalarType.Float16 : SDScalarType.Float32;
40 | Task.Run(() =>
41 | {
42 | base.Invoke(() =>
43 | {
44 | Button_ModelLoad.Enabled = false;
45 | Button_Generate.Enabled = false;
46 | });
47 | sd = new StableDiffusion(deviceType, scalarType);
48 | sd.StepProgress += Sd_StepProgress;
49 | sd.LoadModel(modelPath, vaeModelPath);
50 | base.Invoke(() =>
51 | {
52 | Button_ModelLoad.Enabled = true;
53 | Button_Generate.Enabled = true;
54 | Label_State.Text = "Model loaded.";
55 | });
56 | });
57 | }
58 | }
59 |
60 | private void Button_VAEModelScan_Click(object sender, EventArgs e)
61 | {
62 | FileDialog fileDialog = new OpenFileDialog();
63 | fileDialog.Filter = "Model files|*.safetensors;*.ckpt;*.pt;*.pth|All files|*.*";
64 | if (fileDialog.ShowDialog() == DialogResult.OK)
65 | {
66 | TextBox_VaePath.Text = fileDialog.FileName;
67 | vaeModelPath = fileDialog.FileName;
68 | }
69 | }
70 |
71 | private void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e)
72 | {
73 | base.Invoke(() =>
74 | {
75 | Label_State.Text = $"Progress: {e.CurrentStep}/{e.TotalSteps}";
76 | if (e.VaeApproxImg != null)
77 | {
78 | MemoryStream memoryStream = new MemoryStream();
79 | e.VaeApproxImg.Write(memoryStream, ImageMagick.MagickFormat.Jpg);
80 | memoryStream.Position = 0; // rewind so Image.FromStream reads from the start
81 | // Already on the UI thread here (outer Invoke), so no nested Invoke is needed.
82 | PictureBox_Output.Image = Image.FromStream(memoryStream);
84 | }
85 | });
86 | }
87 |
88 | private void Button_Generate_Click(object sender, EventArgs e)
89 | {
90 | string prompt = TextBox_Prompt.Text;
91 | string nprompt = TextBox_NPrompt.Text;
92 | int step = (int)NumericUpDown_Step.Value;
93 | float cfg = (float)NumericUpDown_CFG.Value;
94 | long seed = 0;
95 | int width = (int)NumericUpDown_Width.Value;
96 | int height = (int)NumericUpDown_Height.Value;
97 | int clipSkip = (int)NumericUpDown_ClipSkip.Value;
98 |
99 | Task.Run(() =>
100 | {
101 | Stopwatch stopwatch = Stopwatch.StartNew();
102 | base.Invoke(() =>
103 | {
104 | Button_ModelLoad.Enabled = false;
105 | Button_Generate.Enabled = false;
106 | Label_State.Text = "Generating...";
107 | });
108 | ImageMagick.MagickImage image = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg);
109 | MemoryStream memoryStream = new MemoryStream();
110 | image.Write(memoryStream, ImageMagick.MagickFormat.Jpg);
111 | memoryStream.Position = 0; // rewind so Image.FromStream reads from the start
111 | base.Invoke(() =>
112 | {
113 | PictureBox_Output.Image = Image.FromStream(memoryStream);
114 | Button_ModelLoad.Enabled = true;
115 | Button_Generate.Enabled = true;
116 | Label_State.Text = $"Done. It took {stopwatch.Elapsed.TotalSeconds:f2} s";
117 | });
118 | GC.Collect();
119 | });
120 | }
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/StableDiffusionDemo_Winform/FormMain.resx:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <root>
3 |   <resheader name="resmimetype">
4 |     <value>text/microsoft-resx</value>
5 |   </resheader>
6 |   <resheader name="version">
7 |     <value>2.0</value>
8 |   </resheader>
9 |   <resheader name="reader">
10 |     <value>System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
11 |   </resheader>
12 |   <resheader name="writer">
13 |     <value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
14 |   </resheader>
15 | </root>
--------------------------------------------------------------------------------
/StableDiffusionDemo_Winform/Program.cs:
--------------------------------------------------------------------------------
1 | namespace StableDiffusionDemo_Winform
2 | {
3 | internal static class Program
4 | {
5 | /// <summary>
6 | /// The main entry point for the application.
7 | /// </summary>
8 | [STAThread]
9 | static void Main()
10 | {
11 | // To customize application configuration such as set high DPI settings or default font,
12 | // see https://aka.ms/applicationconfiguration.
13 | ApplicationConfiguration.Initialize();
14 | Application.Run(new FormMain());
15 | }
16 | }
17 | }
--------------------------------------------------------------------------------
/StableDiffusionDemo_Winform/StableDiffusionDemo_Winform.csproj:
--------------------------------------------------------------------------------
1 | <Project Sdk="Microsoft.NET.Sdk">
2 |
3 |   <PropertyGroup>
4 |     <OutputType>WinExe</OutputType>
5 |     <TargetFramework>net6.0-windows7.0</TargetFramework>
6 |     <Nullable>enable</Nullable>
7 |     <UseWindowsForms>true</UseWindowsForms>
8 |     <ImplicitUsings>enable</ImplicitUsings>
9 |   </PropertyGroup>
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | </Project>
--------------------------------------------------------------------------------
/StableDiffusionSharp.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 17
4 | VisualStudioVersion = 17.12.35728.132
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionSharp", "StableDiffusionSharp\StableDiffusionSharp.csproj", "{BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}"
7 | EndProject
8 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionDemo_Console", "StableDiffusionDemo_Console\StableDiffusionDemo_Console.csproj", "{4F4250A4-B849-4821-AFA5-F8B5191BF08C}"
9 | EndProject
10 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionDemo_Winform", "StableDiffusionDemo_Winform\StableDiffusionDemo_Winform.csproj", "{7860DFE9-EC36-44B3-81E8-817029B849B5}"
11 | EndProject
12 | Global
13 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
14 | Debug|Any CPU = Debug|Any CPU
15 | Release|Any CPU = Release|Any CPU
16 | EndGlobalSection
17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
18 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
19 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Debug|Any CPU.Build.0 = Debug|Any CPU
20 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Release|Any CPU.ActiveCfg = Release|Any CPU
21 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Release|Any CPU.Build.0 = Release|Any CPU
22 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
23 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Debug|Any CPU.Build.0 = Debug|Any CPU
24 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Release|Any CPU.ActiveCfg = Release|Any CPU
25 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Release|Any CPU.Build.0 = Release|Any CPU
26 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
27 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Debug|Any CPU.Build.0 = Debug|Any CPU
28 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Release|Any CPU.ActiveCfg = Release|Any CPU
29 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Release|Any CPU.Build.0 = Release|Any CPU
30 | EndGlobalSection
31 | GlobalSection(SolutionProperties) = preSolution
32 | HideSolutionNode = FALSE
33 | EndGlobalSection
34 | GlobalSection(ExtensibilityGlobals) = postSolution
35 | SolutionGuid = {2D950EF2-E5CA-4631-8B81-9E0974E394D7}
36 | EndGlobalSection
37 | EndGlobal
38 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/ModelLoader/ModelLoader.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.ModelLoader
5 | {
6 | internal static class ModelLoader
7 | {
8 | public static nn.Module LoadModel(this torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "")
9 | {
10 | string extension = Path.GetExtension(fileName).ToLower();
11 | if (extension == ".pt" || extension == ".ckpt" || extension == ".pth")
12 | {
13 | PickleLoader pickleLoader = new PickleLoader();
14 | return pickleLoader.LoadPickle(module, fileName, maybeAddHeaderInBlock);
15 | }
16 | else if (extension == ".safetensors")
17 | {
18 | SafetensorsLoader safetensorsLoader = new SafetensorsLoader();
19 | return safetensorsLoader.LoadSafetensors(module, fileName, maybeAddHeaderInBlock);
20 | }
21 | else
22 | {
23 | throw new ArgumentException("Invalid file extension");
24 | }
25 | }
26 |
27 | public static ModelType GetModelType(string ModelPath)
28 | {
29 | string extension = Path.GetExtension(ModelPath).ToLower();
30 | List<TensorInfo> tensorInfos = new List<TensorInfo>();
31 |
32 | if (extension == ".pt" || extension == ".ckpt" || extension == ".pth")
33 | {
34 | PickleLoader pickleLoader = new PickleLoader();
35 | tensorInfos = pickleLoader.ReadTensorsInfoFromFile(ModelPath);
36 | }
37 | else if (extension == ".safetensors")
38 | {
39 | SafetensorsLoader safetensorsLoader = new SafetensorsLoader();
40 | tensorInfos = safetensorsLoader.ReadTensorsInfoFromFile(ModelPath);
41 | }
42 | else
43 | {
44 | throw new ArgumentException("Invalid file extension");
45 | }
46 |
47 | if (tensorInfos.Any(a => a.Name.Contains("model.diffusion_model.double_blocks.")))
48 | {
49 | return ModelType.FLUX;
50 | }
51 | else if (tensorInfos.Any(a => a.Name.Contains("model.diffusion_model.joint_blocks.")))
52 | {
53 | return ModelType.SD3;
54 | }
55 | else if (tensorInfos.Any(a => a.Name.Contains("conditioner.embedders.1")))
56 | {
57 | return ModelType.SDXL;
58 | }
59 | else
60 | {
61 | return ModelType.SD1;
62 | }
63 |
64 | }
65 |
66 |
67 | }
68 | }
69 |
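A minimal usage sketch for the loader above (the checkpoint path is a placeholder): GetModelType keys off characteristic tensor names in the state dict, and LoadModel dispatches on the file extension before copying weights into an existing module.

// Hypothetical caller; any .safetensors / .pt / .ckpt / .pth checkpoint works.
ModelType modelType = ModelLoader.GetModelType(@".\models\sd_checkpoint.safetensors");
Console.WriteLine(modelType); // SD1, SDXL, SD3 or FLUX

// Weights are then copied in place into a module built for that type, e.g.:
// unet.LoadModel(@".\models\sd_checkpoint.safetensors");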
--------------------------------------------------------------------------------
/StableDiffusionSharp/ModelLoader/PickleLoader.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.ObjectModel;
2 | using System.IO.Compression;
3 | using TorchSharp;
4 | using static TorchSharp.torch;
5 |
6 | namespace StableDiffusionSharp.ModelLoader
7 | {
8 | internal class PickleLoader
9 | {
10 | private ZipArchive zip;
11 | private ReadOnlyCollection<ZipArchiveEntry> entries;
12 |
13 | internal List<TensorInfo> ReadTensorsInfoFromFile(string fileName)
14 | {
15 | List<TensorInfo> tensors = new List<TensorInfo>();
16 |
17 | zip = ZipFile.OpenRead(fileName);
18 | entries = zip.Entries;
19 | ZipArchiveEntry headerEntry = entries.First(e => e.Name == "data.pkl");
20 | byte[] headerBytes = new byte[headerEntry.Length];
21 | // Header is always small enough to fit in memory, so we can read it all at once
22 | using (Stream stream = headerEntry.Open())
23 | {
24 | for (int n = 0, r = 1; n < headerBytes.Length && r > 0; n += r) { r = stream.Read(headerBytes, n, headerBytes.Length - n); } // deflate streams may return partial reads
25 | }
26 |
27 | if (headerBytes[0] != 0x80 || headerBytes[1] != 0x02)
28 | {
29 | throw new ArgumentException("Not a valid pickle file");
30 | }
31 |
32 | int index = 1;
33 | bool finished = false;
34 | bool readStrides = false;
35 | bool binPersid = false;
36 |
37 | TensorInfo tensor = new TensorInfo() { FileName = fileName, Offset = { 0 } };
38 |
39 | int depth = 0;
40 |
41 | Dictionary<int, string> BinPut = new Dictionary<int, string>();
42 |
43 | while (index < headerBytes.Length && !finished)
44 | {
45 | byte opcode = headerBytes[index];
46 | switch (opcode)
47 | {
48 | case (byte)'}': // EMPTY_DICT = b'}' # push empty dict
49 | break;
50 | case (byte)']': // EMPTY_LIST = b']' # push empty list
51 | break;
52 | // skip unused sections
53 | case (byte)'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg
54 | {
55 | int id = headerBytes[index + 1];
56 | BinPut.TryGetValue(id, out string precision);
57 | if (precision != null)
58 | {
59 | if (precision.Contains("FloatStorage"))
60 | {
61 | tensor.Type = TorchSharp.torch.ScalarType.Float32;
62 | }
63 | else if (precision.Contains("HalfStorage"))
64 | {
65 | tensor.Type = TorchSharp.torch.ScalarType.Float16;
66 | }
67 | else if (precision.Contains("BFloat16Storage"))
68 | {
69 | tensor.Type = TorchSharp.torch.ScalarType.BFloat16;
70 | }
71 | }
72 | index++;
73 | break;
74 | }
75 | case (byte)'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg
76 | {
77 | index++;
78 | break;
79 | }
80 | case (byte)'Q': // BINPERSID = b'Q' # " " " ; " " " " stack
81 | binPersid = true;
82 | break;
83 | case (byte)'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg
84 | index += 4;
85 | break;
86 | case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
87 | index += 8;
88 | break;
89 | case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo
90 | break;
91 | case (byte)'(': // MARK = b'(' # push special markobject on stack
92 | depth++;
93 | break;
94 | case (byte)'K': // BININT1 = b'K' # push 1-byte unsigned int
95 | {
96 | int value = headerBytes[index + 1];
97 | index++;
98 |
99 | if (depth > 1 && value != 0 && binPersid)
100 | {
101 | if (readStrides)
102 | {
103 | //tensor.Stride.Add((ulong)value);
104 | tensor.Stride.Add((ulong)value);
105 | }
106 | else
107 | {
108 | tensor.Shape.Add(value);
109 | }
110 | }
111 | }
112 | break;
113 | case (byte)'M': // BININT2 = b'M' # push 2-byte unsigned int
114 | {
115 | UInt16 value = BitConverter.ToUInt16(headerBytes, index + 1);
116 | index += 2;
117 |
118 | if (depth > 1 && value != 0 && binPersid)
119 | {
120 | if (readStrides)
121 | {
122 | tensor.Stride.Add(value);
123 | }
124 | else
125 | {
126 | tensor.Shape.Add(value);
127 | }
128 | }
129 |
130 | }
131 | break;
132 | case (byte)'J': // BININT = b'J' # push four-byte signed int
133 | {
134 | int value = BitConverter.ToInt32(headerBytes, index + 1);
135 | //int value = headerBytes[index + 4] << 24 + headerBytes[index + 3] << 16 + headerBytes[index + 2] << 8 + headerBytes[index + 1];
136 | index += 4;
137 |
138 | if (depth > 1 && value != 0 && binPersid)
139 | {
140 | if (readStrides)
141 | {
142 | tensor.Stride.Add((ulong)value);
143 | }
144 | else
145 | {
146 | tensor.Shape.Add(value);
147 | }
148 | }
149 | }
150 | break;
151 |
152 | case (byte)'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument
153 | {
154 | int length = BitConverter.ToInt32(headerBytes, index + 1); // 4-byte little-endian length
155 | int start = index + 5;
156 | // the length is followed by the UTF-8 bytes of the string
157 | string name = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
158 | index = index + 4 + length;
159 |
160 | if (depth == 1)
161 | {
162 | tensor.Name = name;
163 | }
164 | else if (depth == 3)
165 | {
166 | if ("cpu" != name && !name.Contains("cuda"))
167 | {
168 | tensor.DataNameInZipFile = name;
169 | }
170 | }
171 | }
172 | break;
173 | case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
174 | {
175 |
176 | }
177 | break;
178 | case (byte)'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args
179 | {
180 | int start = index + 1;
181 | while (headerBytes[index + 1] != (byte)'q')
182 | {
183 | index++;
184 | }
185 | int length = index - start + 1;
186 |
187 | string global = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
188 |
189 | // precision is stored in the global variable
190 | // next tensor will read the precision
191 | // so we can set the Type here
192 |
193 | BinPut.Add(headerBytes[index + 2], global);
194 |
195 | if (global.Contains("FloatStorage"))
196 | {
197 | tensor.Type = TorchSharp.torch.ScalarType.Float32;
198 | }
199 | else if (global.Contains("HalfStorage"))
200 | {
201 | tensor.Type = TorchSharp.torch.ScalarType.Float16;
202 | }
203 | else if (global.Contains("BFloat16Storage"))
204 | {
205 | tensor.Type = TorchSharp.torch.ScalarType.BFloat16;
206 | }
207 | break;
208 | }
209 | case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items
210 | {
211 | if (binPersid)
212 | {
213 | readStrides = true;
214 | }
215 | break;
216 | }
217 | case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top
218 | if (binPersid)
219 | {
220 | readStrides = true;
221 | }
222 | break;
223 | case (byte)'t': // TUPLE = b't' # build tuple from topmost stack items
224 | depth--;
225 | if (binPersid)
226 | {
227 | readStrides = true;
228 | }
229 | break;
230 | case (byte)'R': // REDUCE = b'R' # apply callable to argtuple, both on stack
231 | if (depth == 1)
232 | {
233 | if (tensor.Name.Contains("metadata"))
234 | {
235 | break;
236 | }
237 |
238 | if (string.IsNullOrEmpty(tensor.DataNameInZipFile))
239 | {
240 | tensor.DataNameInZipFile = tensors.Last().DataNameInZipFile;
241 | tensor.Offset = new List<ulong> { (ulong)(tensor.Shape[0] * tensor.Type.ElementSize()) };
242 | tensor.Shape.RemoveAt(0);
243 | //tensor.offset = tensors.Last().
244 | }
245 | tensors.Add(tensor);
246 |
247 | tensor = new TensorInfo() { FileName = fileName, Offset = { 0 } };
248 | readStrides = false;
249 | binPersid = false;
250 | }
251 | break;
252 | case (byte)'.': // STOP = b'.' # every pickle ends with STOP
253 | finished = true;
254 | break;
255 | default:
256 | break;
257 | }
258 | index++;
259 | }
260 | TensorInfo metaTensor = tensors.Find(x => x.Name.Contains("_metadata"));
261 | if (metaTensor != null)
262 | {
263 | tensors.Remove(metaTensor);
264 | }
265 | return tensors;
266 | }
267 |
268 | private byte[] ReadByteFromFile(TensorInfo tensor)
269 | {
270 | if (entries is null)
271 | {
272 | throw new ArgumentNullException(nameof(entries));
273 | }
274 |
275 | ZipArchiveEntry dataEntry = entries.First(e => e.Name == tensor.DataNameInZipFile);
276 | long i = 1;
277 | foreach (var ne in tensor.Shape)
278 | {
279 | i *= ne;
280 | }
281 | ulong length = (ulong)(tensor.Type.ElementSize() * i);
282 | byte[] data = new byte[dataEntry.Length];
283 |
284 | using (Stream stream = dataEntry.Open())
285 | {
286 | for (int n = 0, r = 1; n < data.Length && r > 0; n += r) { r = stream.Read(data, n, data.Length - n); } // deflate streams may return partial reads
287 | }
288 |
289 | //data = data.Take(new Range((int)tensor.Offset[0], (int)(tensor.Offset[0] + length))).ToArray();
290 | byte[] result = new byte[length];
291 | for (int j = 0; j < (int)length; j++)
292 | {
293 | result[j] = data[j + (int)tensor.Offset[0]];
294 | }
295 | return result;
296 | //return data;
297 | }
298 |
299 | internal Dictionary<string, Tensor> Load(string fileName, string addString = "")
300 | {
301 | Dictionary<string, Tensor> tensors = new Dictionary<string, Tensor>();
302 | List<TensorInfo> tensorInfos = ReadTensorsInfoFromFile(fileName);
303 | foreach (TensorInfo tensorInfo in tensorInfos)
304 | {
305 | TorchSharp.torch.Tensor tensor = TorchSharp.torch.empty(tensorInfo.Shape.ToArray(), dtype: tensorInfo.Type);
306 | tensor.bytes = ReadByteFromFile(tensorInfo);
307 | tensors.Add(addString + tensorInfo.Name, tensor);
308 | }
309 | return tensors;
310 | }
311 |
312 | internal nn.Module LoadPickle(torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "")
313 | {
314 | using (torch.no_grad())
315 | using (NewDisposeScope())
316 | {
317 | List<TensorInfo> tensorInfos = ReadTensorsInfoFromFile(fileName);
318 | foreach (var mod in module.named_parameters())
319 | {
320 | ScalarType dtype = mod.parameter.dtype;
321 | TensorInfo info = tensorInfos.First(a => ((a.Name == mod.name) || (maybeAddHeaderInBlock + a.Name == mod.name)));
322 | Tensor t = torch.zeros(mod.parameter.shape, info.Type);
323 | t.bytes = ReadByteFromFile(info);
324 | mod.parameter.copy_(t);
325 | t.Dispose();
326 | GC.Collect();
327 | }
328 | return module;
329 | }
330 | }
331 |
332 | }
333 | }
334 |
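For context on what the opcode walk above recovers: a PyTorch .pt/.ckpt/.pth checkpoint is a zip archive whose data.pkl entry pickles the state-dict structure, while each tensor's raw bytes typically live in a data/<n> entry; the parser reads names, dtypes, shapes and storage keys from the pickle stream without ever executing it. A small sketch (the path is a placeholder) that dumps that layout with the BCL alone:

using System.IO.Compression;

using (ZipArchive archive = ZipFile.OpenRead(@".\models\model.ckpt"))
{
    foreach (ZipArchiveEntry entry in archive.Entries)
    {
        // e.g. "archive/data.pkl", "archive/data/0", "archive/data/1", ...
        Console.WriteLine($"{entry.FullName} ({entry.Length} bytes)");
    }
}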
--------------------------------------------------------------------------------
/StableDiffusionSharp/ModelLoader/SafetensorsLoader.cs:
--------------------------------------------------------------------------------
1 | using Newtonsoft.Json.Linq;
2 | using System.Text;
3 | using static TorchSharp.torch;
4 | using TorchSharp;
5 |
6 | namespace StableDiffusionSharp.ModelLoader
7 | {
8 | internal class SafetensorsLoader
9 | {
10 | internal List<TensorInfo> ReadTensorsInfoFromFile(string inputFileName)
11 | {
12 | using (FileStream stream = File.OpenRead(inputFileName))
13 | {
14 | long len = stream.Length;
15 | if (len < 10)
16 | {
17 | throw new ArgumentOutOfRangeException("File cannot be valid safetensors: too short");
18 | }
19 |
20 | // Safetensors file first 8 byte to int64 is the header length
21 | byte[] headerBlock = new byte[8];
22 | stream.Read(headerBlock, 0, 8);
23 | long headerSize = BitConverter.ToInt64(headerBlock, 0);
24 | if (len < 8 + headerSize || headerSize <= 0 || headerSize > 100000000)
25 | {
26 | throw new ArgumentOutOfRangeException($"File cannot be valid safetensors: header len wrong, size:{headerSize}");
27 | }
28 |
29 | // Read the header, header file is a json file
30 | byte[] headerBytes = new byte[headerSize];
31 | stream.Read(headerBytes, 0, (int)headerSize);
32 |
33 | string header = Encoding.UTF8.GetString(headerBytes);
34 | long bodyPosition = stream.Position;
35 | JToken token = JToken.Parse(header);
36 |
37 | List<TensorInfo> tensors = new List<TensorInfo>();
38 | foreach (var sub in token.ToObject<Dictionary<string, JToken>>())
39 | {
40 | Dictionary<string, JToken> value = sub.Value.ToObject<Dictionary<string, JToken>>();
41 | value.TryGetValue("data_offsets", out JToken offsets);
42 | value.TryGetValue("dtype", out JToken dtype);
43 | value.TryGetValue("shape", out JToken shape);
44 |
45 | ulong[] offsetArray = offsets?.ToObject<ulong[]>();
46 | if (null == offsetArray)
47 | {
48 | continue;
49 | }
50 | long[] shapeArray = shape.ToObject<long[]>();
51 | if (shapeArray.Length < 1)
52 | {
53 | shapeArray = new long[] { 1 };
54 | }
55 | TorchSharp.torch.ScalarType tensor_type = TorchSharp.torch.ScalarType.Float32;
56 | switch (dtype.ToString())
57 | {
58 | case "I8": tensor_type = TorchSharp.torch.ScalarType.Int8; break;
59 | case "I16": tensor_type = TorchSharp.torch.ScalarType.Int16; break;
60 | case "I32": tensor_type = TorchSharp.torch.ScalarType.Int32; break;
61 | case "I64": tensor_type = TorchSharp.torch.ScalarType.Int64; break;
62 | case "BF16": tensor_type = TorchSharp.torch.ScalarType.BFloat16; break;
63 | case "F16": tensor_type = TorchSharp.torch.ScalarType.Float16; break;
64 | case "F32": tensor_type = TorchSharp.torch.ScalarType.Float32; break;
65 | case "F64": tensor_type = TorchSharp.torch.ScalarType.Float64; break;
66 | case "U8": tensor_type = TorchSharp.torch.ScalarType.Byte; break;
67 | case "BOOL": tensor_type = TorchSharp.torch.ScalarType.Bool; break;
68 | case "U16":
69 | case "U32":
70 | case "U64":
71 | case "F8_E4M3":
72 | case "F8_E5M2": break;
73 | }
74 |
75 | TensorInfo tensor = new TensorInfo
76 | {
77 | Name = sub.Key,
78 | Type = tensor_type,
79 | Shape = shapeArray.ToList(),
80 | Offset = offsetArray.ToList(),
81 | FileName = inputFileName,
82 | BodyPosition = bodyPosition
83 | };
84 |
85 | tensors.Add(tensor);
86 | }
87 | return tensors;
88 | }
89 | }
90 |
91 | private byte[] ReadByteFromFile(string inputFileName, long bodyPosition, long offset, int size)
92 | {
93 | using (FileStream stream = File.OpenRead(inputFileName))
94 | {
95 | stream.Seek(bodyPosition + offset, SeekOrigin.Begin);
96 | byte[] dest = new byte[size];
97 | stream.Read(dest, 0, size);
98 | return dest;
99 | }
100 | }
101 |
102 | private byte[] ReadByteFromFile(TensorInfo tensor)
103 | {
104 | string inputFileName = tensor.FileName;
105 | long bodyPosition = tensor.BodyPosition;
106 | ulong offset = tensor.Offset[0];
107 | int size = (int)(tensor.Offset[1] - tensor.Offset[0]);
108 | return ReadByteFromFile(inputFileName, bodyPosition, (long)offset, size);
109 | }
110 |
111 | internal Dictionary<string, Tensor> Load(string fileName, string addString = "")
112 | {
113 | Dictionary<string, Tensor> tensors = new Dictionary<string, Tensor>();
114 | List<TensorInfo> tensorInfos = ReadTensorsInfoFromFile(fileName);
115 | foreach (TensorInfo tensorInfo in tensorInfos)
116 | {
117 | TorchSharp.torch.Tensor tensor = TorchSharp.torch.empty(tensorInfo.Shape.ToArray(), dtype: tensorInfo.Type);
118 | tensor.bytes = ReadByteFromFile(tensorInfo);
119 | tensors.Add(addString + tensorInfo.Name, tensor);
120 | }
121 | return tensors;
122 | }
123 |
124 | internal nn.Module LoadSafetensors(torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "")
125 | {
126 | using (torch.no_grad())
127 | using (NewDisposeScope())
128 | {
129 | List<TensorInfo> tensorInfos = ReadTensorsInfoFromFile(fileName);
130 | foreach (var mod in module.named_parameters())
131 | {
132 | ScalarType dtype = mod.parameter.dtype;
133 | TensorInfo info = tensorInfos.First(a => ((a.Name == mod.name) || (maybeAddHeaderInBlock + a.Name == mod.name)));
134 | Tensor t = torch.zeros(mod.parameter.shape, info.Type);
135 | t.bytes = ReadByteFromFile(info);
136 | mod.parameter.copy_(t);
137 | t.Dispose();
138 | GC.Collect();
139 | }
140 | return module;
141 | }
142 | }
143 | }
144 | }
145 |
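The layout the loader above relies on is simple enough to inspect by hand: the first 8 bytes of a .safetensors file are a little-endian Int64 header size, followed by that many bytes of UTF-8 JSON mapping tensor names to { dtype, shape, data_offsets }, with the raw tensor data packed after the header. A minimal header dump (the path is a placeholder):

using (FileStream fs = File.OpenRead(@".\models\model.safetensors"))
{
    byte[] sizeBytes = new byte[8];
    fs.Read(sizeBytes, 0, 8); // 8-byte little-endian header length
    long headerSize = BitConverter.ToInt64(sizeBytes, 0);
    byte[] headerJson = new byte[headerSize];
    fs.Read(headerJson, 0, (int)headerSize); // UTF-8 JSON header
    Console.WriteLine(System.Text.Encoding.UTF8.GetString(headerJson));
}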
--------------------------------------------------------------------------------
/StableDiffusionSharp/ModelLoader/TensorInfo.cs:
--------------------------------------------------------------------------------
1 | namespace StableDiffusionSharp.ModelLoader
2 | {
3 | internal class TensorInfo
4 | {
5 | public string Name { get; set; }
6 | public TorchSharp.torch.ScalarType Type { get; set; } = TorchSharp.torch.ScalarType.Float16;
7 | public List<long> Shape { get; set; } = new List<long>();
8 | public List<ulong> Stride { get; set; } = new List<ulong>();
9 | public string DataNameInZipFile { get; set; }
10 | public string FileName { get; set; }
11 | public List<ulong> Offset { get; set; } = new List<ulong>();
12 | public long BodyPosition { get; set; }
13 | }
14 | }
15 |
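The loaders size every read from these fields: for pickle data the byte count is the element count times the element size, while for safetensors it is the span Offset[1] - Offset[0]. As a worked example, a Float16 tensor with Shape [320, 4, 3, 3] holds 320 * 4 * 3 * 3 = 11,520 elements, i.e. 23,040 bytes at 2 bytes per element.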
--------------------------------------------------------------------------------
/StableDiffusionSharp/Models/VAEApprox/vaeapp_sd15.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntptrMax/StableDiffusionSharp/cfdca9b5c50b86cd59aec08dfebd0961d91ba1c2/StableDiffusionSharp/Models/VAEApprox/vaeapp_sd15.pth
--------------------------------------------------------------------------------
/StableDiffusionSharp/Models/VAEApprox/xlvaeapp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntptrMax/StableDiffusionSharp/cfdca9b5c50b86cd59aec08dfebd0961d91ba1c2/StableDiffusionSharp/Models/VAEApprox/xlvaeapp.pth
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/Clip.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using TorchSharp.Modules;
3 | using static TorchSharp.torch;
4 | using static TorchSharp.torch.nn;
5 |
6 | namespace StableDiffusionSharp.Modules
7 | {
8 | internal class Clip
9 | {
10 | private enum Activations
11 | {
12 | ReLU,
13 | SiLU,
14 | QuickGELU,
15 | GELU
16 | }
17 |
18 | internal class ViT_L_Clip : Module<Tensor, long, bool, Tensor>
19 | {
20 | private readonly CLIPTextModel transformer;
21 |
22 | public ViT_L_Clip(long n_vocab = 49408, long n_token = 77, long num_layers = 12, long n_heads = 12, long embed_dim = 768, long intermediate_size = 768 * 4, Device? device = null, ScalarType? dtype = null) : base(nameof(ViT_L_Clip))
23 | {
24 | transformer = new CLIPTextModel(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype);
25 | RegisterComponents();
26 | }
27 |
28 | public override Tensor forward(Tensor token, long num_skip, bool with_final_ln)
29 | {
30 | Device device = transformer.parameters().First().device;
31 | token = token.to(device);
32 | return transformer.forward(token, num_skip, with_final_ln);
33 | }
34 |
35 | private class CLIPTextModel : Module<Tensor, long, bool, Tensor>
36 | {
37 | private readonly CLIPTextTransformer text_model;
38 | public CLIPTextModel(long n_vocab, long n_token, long num_layers, long n_heads, long embed_dim, long intermediate_size, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPTextModel))
39 | {
40 | text_model = new CLIPTextTransformer(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype);
41 | RegisterComponents();
42 | }
43 | public override Tensor forward(Tensor x, long num_skip, bool with_final_ln)
44 | {
45 | return text_model.forward(x, num_skip, with_final_ln);
46 | }
47 | }
48 |
49 | private class CLIPTextTransformer : Module<Tensor, long, bool, Tensor>
50 | {
51 | private readonly CLIPTextEmbeddings embeddings;
52 | private readonly CLIPEncoder encoder;
53 | private readonly LayerNorm final_layer_norm;
54 | private readonly long num_layers;
55 |
56 | public CLIPTextTransformer(long n_vocab, long n_token, long num_layers, long n_heads, long embed_dim, long intermediate_size, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPTextTransformer))
57 | {
58 | this.num_layers = num_layers;
59 | embeddings = new CLIPTextEmbeddings(n_vocab, embed_dim, n_token, device: device, dtype: dtype);
60 | encoder = new CLIPEncoder(num_layers, embed_dim, n_heads, intermediate_size, Activations.QuickGELU, device: device, dtype: dtype);
61 | final_layer_norm = LayerNorm(embed_dim, device: device, dtype: dtype);
62 | RegisterComponents();
63 | }
64 | public override Tensor forward(Tensor x, long num_skip, bool with_final_ln)
65 | {
66 | x = embeddings.forward(x);
67 | x = encoder.forward(x, num_skip);
68 | if (with_final_ln)
69 | {
70 | x = final_layer_norm.forward(x);
71 | }
72 | return x;
73 | }
74 | }
75 |
76 | private class CLIPTextEmbeddings : Module<Tensor, Tensor>
77 | {
78 | private readonly Embedding token_embedding;
79 | private readonly Embedding position_embedding;
80 | private readonly Parameter position_ids;
81 | public CLIPTextEmbeddings(long n_vocab, long n_embd, long n_token, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPTextEmbeddings))
82 | {
83 | position_ids = Parameter(zeros(size: new long[] { 1, n_token }, device: device, dtype: dtype));
84 | token_embedding = Embedding(n_vocab, n_embd, device: device, dtype: dtype);
85 | position_embedding = Embedding(n_token, n_embd, device: device, dtype: dtype);
86 | RegisterComponents();
87 | }
88 |
89 | public override Tensor forward(Tensor tokens)
90 | {
91 | return token_embedding.forward(tokens) + position_embedding.forward(position_ids.@long());
92 | }
93 | }
94 |
95 | private class CLIPEncoderLayer : Module<Tensor, Tensor>
96 | {
97 | private readonly LayerNorm layer_norm1;
98 | private readonly LayerNorm layer_norm2;
99 | private readonly CLIPAttention self_attn;
100 | private readonly CLIPMLP mlp;
101 |
102 | public CLIPEncoderLayer(long n_head, long embed_dim, long intermediate_size, Activations activations = Activations.QuickGELU, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPEncoderLayer))
103 | {
104 | layer_norm1 = LayerNorm(embed_dim, device: device, dtype: dtype);
105 | self_attn = new CLIPAttention(embed_dim, n_head, device: device, dtype: dtype);
106 | layer_norm2 = LayerNorm(embed_dim, device: device, dtype: dtype);
107 | mlp = new CLIPMLP(embed_dim, intermediate_size, embed_dim, activations, device: device, dtype: dtype);
108 | RegisterComponents();
109 | }
110 |
111 | public override Tensor forward(Tensor x)
112 | {
113 | x += self_attn.forward(layer_norm1.forward(x));
114 | x += mlp.forward(layer_norm2.forward(x));
115 | return x;
116 | }
117 | }
118 |
119 | private class CLIPMLP : Module<Tensor, Tensor>
120 | {
121 | private readonly Linear fc1;
122 | private readonly Linear fc2;
123 | private readonly Activations act_layer;
124 | public CLIPMLP(long in_features, long? hidden_features = null, long? out_features = null, Activations act_layer = Activations.QuickGELU, bool bias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPMLP))
125 | {
126 | out_features ??= in_features;
127 | hidden_features ??= out_features;
128 |
129 | fc1 = Linear(in_features, (long)hidden_features, hasBias: bias, device: device, dtype: dtype);
130 | fc2 = Linear((long)hidden_features, (long)out_features, hasBias: bias, device: device, dtype: dtype);
131 | this.act_layer = act_layer;
132 | RegisterComponents();
133 | }
134 |
135 | public override Tensor forward(Tensor x)
136 | {
137 | x = fc1.forward(x);
138 |
139 | switch (act_layer)
140 | {
141 | case Activations.ReLU:
142 | x = functional.relu(x);
143 | break;
144 | case Activations.SiLU:
145 | x = functional.silu(x);
146 | break;
147 | case Activations.QuickGELU:
148 | x = x * sigmoid(1.702 * x);
149 | break;
150 | case Activations.GELU:
151 | x = functional.gelu(x);
152 | break;
153 | }
154 | x = fc2.forward(x);
155 | return x;
156 | }
157 | }
158 |
159 | private class CLIPAttention : Module<Tensor, Tensor>
160 | {
161 | private readonly long heads;
162 | private readonly Linear q_proj;
163 | private readonly Linear k_proj;
164 | private readonly Linear v_proj;
165 | private readonly Linear out_proj;
166 |
167 | public CLIPAttention(long embed_dim, long heads, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPAttention))
168 | {
169 | this.heads = heads;
170 | q_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype);
171 | k_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype);
172 | v_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype);
173 | out_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype);
174 |
175 | RegisterComponents();
176 | }
177 |
178 | public override Tensor forward(Tensor x)
179 | {
180 | using (var _ = NewDisposeScope())
181 | {
182 | Tensor q = q_proj.forward(x);
183 | Tensor k = k_proj.forward(x);
184 | Tensor v = v_proj.forward(x);
185 | Tensor output = attention(q, k, v, heads);
186 | //Tensor output = self_atten(q, k, v, this.heads);
187 | return out_proj.forward(output).MoveToOuterDisposeScope();
188 | }
189 | }
190 |
191 | private static Tensor self_atten(Tensor q, Tensor k, Tensor v, long heads)
192 | {
193 | long[] input_shape = q.shape;
194 | long batch_size = q.shape[0];
195 | long sequence_length = q.shape[1];
196 | long d_head = q.shape[2] / heads;
197 | long[] interim_shape = new long[] { batch_size, sequence_length, heads, d_head };
198 |
199 | q = q.view(interim_shape).transpose(1, 2);
200 | k = k.view(interim_shape).transpose(1, 2);
201 | v = v.view(interim_shape).transpose(1, 2);
202 |
203 | var weight = matmul(q, k.transpose(-1, -2));
204 | var mask = ones_like(weight).triu(1).to(@bool);
205 | weight.masked_fill_(mask, float.NegativeInfinity);
206 |
207 | weight = weight / (float)Math.Sqrt(d_head);
208 | weight = functional.softmax(weight, dim: -1);
209 |
210 | var output = matmul(weight, v);
211 | output = output.transpose(1, 2);
212 | output = output.reshape(input_shape);
213 | return output;
214 | }
215 |
216 | // Convenience wrapper around a basic attention operation
217 | private static Tensor attention(Tensor q, Tensor k, Tensor v, long heads)
218 | {
219 | long b = q.shape[0];
220 | long dim_head = q.shape[2];
221 | dim_head /= heads;
222 | q = q.view(b, -1, heads, dim_head).transpose(1, 2);
223 | k = k.view(b, -1, heads, dim_head).transpose(1, 2);
224 | v = v.view(b, -1, heads, dim_head).transpose(1, 2);
225 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: true);
226 | output = output.transpose(1, 2);
227 | output = output.view(b, -1, heads * dim_head);
228 | return output;
229 | }
230 | }
231 |
232 | private class CLIPEncoder : Module<Tensor, long, Tensor>
233 | {
234 | private readonly ModuleList<CLIPEncoderLayer> layers;
235 |
236 | public CLIPEncoder(long num_layers, long embed_dim, long heads, long intermediate_size, Activations intermediate_activation, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPEncoder))
237 | {
238 | layers = new ModuleList<CLIPEncoderLayer>();
239 | for (int i = 0; i < num_layers; i++)
240 | {
241 | layers.append(new CLIPEncoderLayer(heads, embed_dim, intermediate_size, intermediate_activation, device: device, dtype: dtype));
242 | }
243 | RegisterComponents();
244 | }
245 |
246 | public override Tensor forward(Tensor x, long num_skip)
247 | {
248 | long num_act = num_skip > 0 ? layers.Count - num_skip : layers.Count;
249 | for (int i = 0; i < num_act; i++)
250 | {
251 | x = layers[i].forward(x);
252 | }
253 |
254 | return x;
255 | }
256 | }
257 | }
258 |
259 | private class ViT_bigG_Clip : Module<Tensor, int, bool, bool, Tensor>
260 | {
261 | private readonly int adm_in_channels;
262 |
263 | private readonly Embedding token_embedding;
264 | private readonly Parameter positional_embedding;
265 | private readonly Transformer transformer;
266 | private readonly LayerNorm ln_final;
267 | private readonly Parameter text_projection;
268 |
269 | public ViT_bigG_Clip(long n_vocab = 49408, long n_token = 77, long num_layers = 32, long n_heads = 20, long embed_dim = 1280, long intermediate_size = 1280 * 4, Device? device = null, ScalarType? dtype = null) : base(nameof(ViT_bigG_Clip))
270 | {
271 | token_embedding = Embedding(n_vocab, embed_dim, device: device, dtype: dtype);
272 | positional_embedding = Parameter(zeros(size: new long[] { n_token, embed_dim }, device: device, dtype: dtype));
273 | text_projection = Parameter(zeros(size: new long[] { embed_dim, embed_dim }, device: device, dtype: dtype));
274 | transformer = new Transformer(num_layers, embed_dim, n_heads, intermediate_size, Activations.GELU, device: device, dtype: dtype);
275 | ln_final = LayerNorm(embed_dim, device: device, dtype: dtype);
276 | RegisterComponents();
277 | }
278 |
279 | public override Tensor forward(Tensor x, int num_skip, bool with_final_ln, bool return_pooled)
280 | {
281 | using (NewDisposeScope())
282 | {
283 | Tensor input_ids = x;
284 | x = token_embedding.forward(x) + positional_embedding;
285 | x = transformer.forward(x, num_skip);
286 | if (with_final_ln || return_pooled)
287 | {
288 | x = ln_final.forward(x);
289 | }
290 | if (return_pooled)
291 | {
292 | x = x[torch.arange(x.shape[0], device: x.device), input_ids.to(type: ScalarType.Int32, device: x.device).argmax(dim: -1)];
293 | x = functional.linear(x, text_projection.transpose(0, 1));
294 | }
295 | return x.MoveToOuterDisposeScope();
296 | }
297 | }
298 |
299 | private class Transformer : Module<Tensor, int, Tensor>
300 | {
301 | private readonly ModuleList<ResidualAttentionBlock> resblocks;
302 | public Transformer(long num_layers, long embed_dim, long heads, long intermediate_size, Activations intermediate_activation, Device? device = null, ScalarType? dtype = null) : base(nameof(Transformer))
303 | {
304 | resblocks = new ModuleList<ResidualAttentionBlock>();
305 | for (int i = 0; i < num_layers; i++)
306 | {
307 | resblocks.append(new ResidualAttentionBlock(heads, embed_dim, intermediate_size, intermediate_activation, device: device, dtype: dtype));
308 | }
309 | RegisterComponents();
310 | }
311 |
312 | public override Tensor forward(Tensor x, int num_skip)
313 | {
314 | int num_act = num_skip > 0 ? resblocks.Count - num_skip : resblocks.Count;
315 | for (int i = 0; i < num_act; i++)
316 | {
317 | x = resblocks[i].forward(x);
318 | }
319 | return x;
320 | }
321 | }
322 |
323 | private class ResidualAttentionBlock : Module<Tensor, Tensor>
324 | {
325 | private readonly LayerNorm ln_1;
326 | private readonly LayerNorm ln_2;
327 | private readonly MultiheadAttention attn;
328 | private readonly Mlp mlp;
329 |
330 | public ResidualAttentionBlock(long n_head, long embed_dim, long intermediate_size, Activations activations = Activations.QuickGELU, Device? device = null, ScalarType? dtype = null) : base(nameof(ResidualAttentionBlock))
331 | {
332 | ln_1 = LayerNorm(embed_dim, device: device, dtype: dtype);
333 | attn = new MultiheadAttention(embed_dim, n_head, device: device, dtype: dtype);
334 | ln_2 = LayerNorm(embed_dim, device: device, dtype: dtype);
335 | mlp = new Mlp(embed_dim, intermediate_size, embed_dim, activations, device: device, dtype: dtype);
336 | RegisterComponents();
337 | }
338 |
339 | public override Tensor forward(Tensor x)
340 | {
341 | x += attn.forward(ln_1.forward(x));
342 | x += mlp.forward(ln_2.forward(x));
343 | return x;
344 | }
345 | }
346 |
347 | private class Mlp : Module<Tensor, Tensor>
348 | {
349 | private readonly Linear c_fc;
350 | private readonly Linear c_proj;
351 | private readonly Activations act_layer;
352 | public Mlp(long in_features, long? hidden_features = null, long? out_features = null, Activations act_layer = Activations.QuickGELU, bool bias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Mlp))
353 | {
354 | out_features ??= in_features;
355 | hidden_features ??= out_features;
356 |
357 | c_fc = Linear(in_features, (long)hidden_features, hasBias: bias, device: device, dtype: dtype);
358 | c_proj = Linear((long)hidden_features, (long)out_features, hasBias: bias, device: device, dtype: dtype);
359 | this.act_layer = act_layer;
360 | RegisterComponents();
361 | }
362 |
363 | public override Tensor forward(Tensor x)
364 | {
365 | x = c_fc.forward(x);
366 |
367 | switch (act_layer)
368 | {
369 | case Activations.ReLU:
370 | x = functional.relu(x);
371 | break;
372 | case Activations.SiLU:
373 | x = functional.silu(x);
374 | break;
375 | case Activations.QuickGELU:
376 | x = x * sigmoid(1.702 * x);
377 | break;
378 | case Activations.GELU:
379 | x = functional.gelu(x);
380 | break;
381 | }
382 | x = c_proj.forward(x);
383 | return x;
384 | }
385 | }
386 |
387 | private class MultiheadAttention : Module<Tensor, Tensor>
388 | {
389 | private readonly long heads;
390 | private readonly Parameter in_proj_weight;
391 | private readonly Parameter in_proj_bias;
392 | private readonly Linear out_proj;
393 |
394 | public MultiheadAttention(long embed_dim, long heads, Device? device = null, ScalarType? dtype = null) : base(nameof(MultiheadAttention))
395 | {
396 | this.heads = heads;
397 | in_proj_weight = Parameter(zeros(new long[] { 3 * embed_dim, embed_dim }, device: device, dtype: dtype));
398 | in_proj_bias = Parameter(zeros(new long[] { 3 * embed_dim }, device: device, dtype: dtype));
399 | out_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype);
400 |
401 | RegisterComponents();
402 | }
403 |
404 | public override Tensor forward(Tensor x)
405 | {
406 | using (var _ = NewDisposeScope())
407 | {
408 | Tensor[] qkv = functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, 2);
409 | Tensor q = qkv[0];
410 | Tensor k = qkv[1];
411 | Tensor v = qkv[2];
412 | Tensor output = attention(q, k, v, heads);
413 | //Tensor output = self_atten(q, k, v, this.heads);
414 | return out_proj.forward(output).MoveToOuterDisposeScope();
415 | }
416 | }
417 |
418 | private static Tensor self_atten(Tensor q, Tensor k, Tensor v, long heads)
419 | {
420 | long[] input_shape = q.shape;
421 | long batch_size = q.shape[0];
422 | long sequence_length = q.shape[1];
423 | long d_head = q.shape[2] / heads;
424 | long[] interim_shape = new long[] { batch_size, sequence_length, heads, d_head };
425 |
426 | q = q.view(interim_shape).transpose(1, 2);
427 | k = k.view(interim_shape).transpose(1, 2);
428 | v = v.view(interim_shape).transpose(1, 2);
429 |
430 | var weight = matmul(q, k.transpose(-1, -2));
431 | var mask = ones_like(weight).triu(1).to(@bool);
432 | weight.masked_fill_(mask, float.NegativeInfinity);
433 |
434 | weight = weight / (float)Math.Sqrt(d_head);
435 | weight = functional.softmax(weight, dim: -1);
436 |
437 | var output = matmul(weight, v);
438 | output = output.transpose(1, 2);
439 | output = output.reshape(input_shape);
440 | return output;
441 | }
442 |
443 | // Convenience wrapper around a basic attention operation
444 | private static Tensor attention(Tensor q, Tensor k, Tensor v, long heads)
445 | {
446 | long b = q.shape[0];
447 | long dim_head = q.shape[2];
448 | dim_head /= heads;
449 | q = q.view(b, -1, heads, dim_head).transpose(1, 2);
450 | k = k.view(b, -1, heads, dim_head).transpose(1, 2);
451 | v = v.view(b, -1, heads, dim_head).transpose(1, 2);
452 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: true);
453 | output = output.transpose(1, 2);
454 | output = output.view(b, -1, heads * dim_head);
455 | return output;
456 | }
457 | }
458 | }
459 |
460 | internal class SDCliper : Module<Tensor, long, (Tensor, Tensor)>
461 | {
462 | private readonly ViT_L_Clip cond_stage_model;
463 | private readonly long n_token;
464 | private readonly long endToken;
465 |
466 | public SDCliper(long n_vocab = 49408, long n_token = 77, long num_layers = 12, long n_heads = 12, long embed_dim = 768, long intermediate_size = 768 * 4, long endToken = 49407, Device? device = null, ScalarType? dtype = null) : base(nameof(SDCliper))
467 | {
468 | this.n_token = n_token;
469 | this.endToken = endToken;
470 | cond_stage_model = new ViT_L_Clip(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype);
471 | RegisterComponents();
472 | }
473 | public override (Tensor, Tensor) forward(Tensor token, long num_skip)
474 | {
475 | using (NewDisposeScope())
476 | {
477 | Device device = cond_stage_model.parameters().First().device;
478 | long padLength = n_token - token.shape[1];
479 | Tensor token1 = functional.pad(token, new long[] { 0, padLength, 0, 0 }, value: endToken);
480 | return (cond_stage_model.forward(token1, num_skip, true).MoveToOuterDisposeScope(), zeros(1).MoveToOuterDisposeScope());
481 | }
482 | }
483 | }
484 |
485 | internal class SDXLCliper : Module<Tensor, long, (Tensor, Tensor)>
486 | {
487 | private readonly Embedders conditioner;
488 | public SDXLCliper(long n_vocab = 49408, long n_token = 77, Device? device = null, ScalarType? dtype = null) : base(nameof(SDXLCliper))
489 | {
490 | conditioner = new Embedders(n_token, device: device, dtype: dtype);
491 | RegisterComponents();
492 | }
493 |
494 | public override (Tensor, Tensor) forward(Tensor token, long num_skip)
495 | {
496 | Device device = conditioner.parameters().First().device;
497 | token = token.to(device);
498 | return conditioner.forward(token);
499 | }
500 |
501 | private class Embedders : Module<Tensor, (Tensor, Tensor)>
502 | {
503 | private readonly ModuleList embedders;
504 | private readonly long n_token;
505 | private readonly long endToken;
506 | public Embedders(long n_token = 77, int endToken = 49407, Device? device = null, ScalarType? dtype = null) : base(nameof(Embedders))
507 | {
508 | this.n_token = n_token;
509 | this.endToken = endToken;
510 | Model model = new Model(device: device, dtype: dtype);
511 | embedders = ModuleList(new ViT_L_Clip(device: device, dtype: dtype), model);
512 | RegisterComponents();
513 | }
514 | public override (Tensor, Tensor) forward(Tensor token)
515 | {
516 | using (NewDisposeScope())
517 | {
518 | long padLength = n_token - token.shape[1];
519 | Tensor token1 = functional.pad(token, new long[] { 0, padLength, 0, 0 }, value: endToken);
520 | Tensor token2 = functional.pad(token, new long[] { 0, padLength, 0, 0 });
521 |
522 | Tensor vit_l_result = ((ViT_L_Clip)embedders[0]).forward(token1, 1, false);
523 | Tensor vit_bigG_result = ((Model)embedders[1]).forward(token2, 1, false, false);
524 | Tensor vit_bigG_vec = ((Model)embedders[1]).forward(token2, 0, false, true);
525 | Tensor crossattn = cat(new Tensor[] { vit_l_result, vit_bigG_result }, -1);
526 | return (crossattn.MoveToOuterDisposeScope(), vit_bigG_vec.MoveToOuterDisposeScope());
527 | }
528 | }
529 | }
530 |
531 | private class Model : Module<Tensor, int, bool, bool, Tensor>
532 | {
533 | private readonly ViT_bigG_Clip model;
534 | public Model(Device? device = null, ScalarType? dtype = null) : base(nameof(Model))
535 | {
536 | model = new ViT_bigG_Clip(device: device, dtype: dtype);
537 | RegisterComponents();
538 | }
539 | public override Tensor forward(Tensor token, int num_skip, bool with_final_ln, bool return_pooled)
540 | {
541 | return model.forward(token, num_skip, with_final_ln, return_pooled);
542 | }
543 | }
544 | }
545 |
546 | }
547 | }
548 |
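Two details above are easy to miss: num_skip implements "clip skip" by dropping the last num_skip encoder layers, and the QuickGELU branches compute the sigmoid approximation x * sigmoid(1.702 * x) used by the original CLIP weights. A standalone sketch of that activation:

// QuickGELU approximates GELU with a scaled sigmoid.
static float QuickGelu(float x) => x * (1f / (1f + MathF.Exp(-1.702f * x)));
// QuickGelu(1f) ≈ 0.846, versus exact GELU(1) ≈ 0.841.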
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/Esrgan.cs:
--------------------------------------------------------------------------------
1 | using StableDiffusionSharp.ModelLoader;
2 | using TorchSharp;
3 | using TorchSharp.Modules;
4 | using static TorchSharp.torch;
5 | using static TorchSharp.torch.nn;
6 |
7 | namespace StableDiffusionSharp.Modules
8 | {
9 | public class Esrgan : IDisposable
10 | {
11 | private readonly RRDBNet rrdbnet;
12 | Device device;
13 | ScalarType dtype;
14 |
15 | public Esrgan(int num_block = 23, SDDeviceType deviceType = SDDeviceType.CUDA, SDScalarType scalarType = SDScalarType.Float16)
16 | {
17 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
18 | device = new Device((DeviceType)deviceType);
19 | dtype = (ScalarType)scalarType;
20 | rrdbnet = new RRDBNet(num_in_ch: 3, num_out_ch: 3, num_feat: 64, num_block: num_block, num_grow_ch: 32, scale: 4, device: device, dtype: dtype);
21 | }
22 |
23 | /// <summary>
24 | /// Residual Dense Block.
25 | /// </summary>
26 | private class ResidualDenseBlock : Module<Tensor, Tensor>
27 | {
28 | private readonly Conv2d conv1;
29 | private readonly Conv2d conv2;
30 | private readonly Conv2d conv3;
31 | private readonly Conv2d conv4;
32 | private readonly Conv2d conv5;
33 | private readonly LeakyReLU lrelu;
34 |
35 | /// <summary>
36 | /// Used in RRDB block in ESRGAN.
37 | /// </summary>
38 | /// <param name="num_feat">Channel number of intermediate features.</param>
39 | /// <param name="num_grow_ch">Channels for each growth.</param>
40 | public ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(ResidualDenseBlock))
41 | {
42 | conv1 = Conv2d(num_feat, num_grow_ch, 3, 1, 1, device: device, dtype: dtype);
43 | conv2 = Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype);
44 | conv3 = Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype);
45 | conv4 = Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype);
46 | conv5 = Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1, device: device, dtype: dtype);
47 | lrelu = LeakyReLU(negative_slope: 0.2f, inplace: true);
48 | RegisterComponents();
49 | }
50 |
51 | public override Tensor forward(Tensor x)
52 | {
53 | using (NewDisposeScope())
54 | {
55 | Tensor x1 = lrelu.forward(conv1.forward(x));
56 | Tensor x2 = lrelu.forward(conv2.forward(cat(new Tensor[] { x, x1 }, 1)));
57 | Tensor x3 = lrelu.forward(conv3.forward(cat(new Tensor[] { x, x1, x2 }, 1)));
58 | Tensor x4 = lrelu.forward(conv4.forward(cat(new Tensor[] { x, x1, x2, x3 }, 1)));
59 | Tensor x5 = conv5.forward(cat(new Tensor[] { x, x1, x2, x3, x4 }, 1));
60 | // Empirically, we use 0.2 to scale the residual for better performance
61 | return (x5 * 0.2 + x).MoveToOuterDisposeScope();
62 | }
63 | }
64 | }
65 |
66 | /// <summary>
67 | /// Residual in Residual Dense Block.
68 | /// </summary>
69 | private class RRDB : Module<Tensor, Tensor>
70 | {
71 | private readonly ResidualDenseBlock rdb1;
72 | private readonly ResidualDenseBlock rdb2;
73 | private readonly ResidualDenseBlock rdb3;
74 |
75 | /// <summary>
76 | /// Used in RRDB-Net in ESRGAN.
77 | /// </summary>
78 | /// <param name="num_feat">Channel number of intermediate features.</param>
79 | /// <param name="num_grow_ch">Channels for each growth.</param>
80 | public RRDB(int num_feat, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(RRDB))
81 | {
82 | rdb1 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype);
83 | rdb2 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype);
84 | rdb3 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype);
85 | RegisterComponents();
86 | }
87 | public override Tensor forward(Tensor x)
88 | {
89 | using (NewDisposeScope())
90 | {
91 | Tensor @out = rdb1.forward(x);
92 | @out = rdb2.forward(@out);
93 | @out = rdb3.forward(@out);
94 | // Empirically, we use 0.2 to scale the residual for better performance
95 | return (@out * 0.2 + x).MoveToOuterDisposeScope();
96 | }
97 | }
98 | }
99 |
100 | private class RRDBNet : Module<Tensor, Tensor>
101 | {
102 | private readonly int scale;
103 | private readonly Conv2d conv_first;
104 | private readonly Sequential body;
105 | private readonly Conv2d conv_body;
106 | private readonly Conv2d conv_up1;
107 | private readonly Conv2d conv_up2;
108 | private readonly Conv2d conv_hr;
109 | private readonly Conv2d conv_last;
110 | private readonly LeakyReLU lrelu;
111 |
112 | public RRDBNet(int num_in_ch, int num_out_ch, int scale = 4, int num_feat = 64, int num_block = 23, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(RRDBNet))
113 | {
114 | this.scale = scale;
115 | if (scale == 2)
116 | {
117 | num_in_ch = num_in_ch * 4;
118 | }
119 | else if (scale == 1)
120 | {
121 | num_in_ch = num_in_ch * 16;
122 | }
123 | conv_first = Conv2d(num_in_ch, num_feat, 3, 1, 1, device: device, dtype: dtype);
124 | body = Sequential();
125 | for (int i = 0; i < num_block; i++)
126 | {
127 | body.append(new RRDB(num_feat: num_feat, num_grow_ch: num_grow_ch, device: device, dtype: dtype));
128 | }
129 | conv_body = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype);
130 | // upsample
131 | conv_up1 = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype);
132 | conv_up2 = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype);
133 | conv_hr = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype);
134 | conv_last = Conv2d(num_feat, num_out_ch, 3, 1, 1, device: device, dtype: dtype);
135 | lrelu = LeakyReLU(negative_slope: 0.2f, inplace: true);
136 | RegisterComponents();
137 | }
138 |
139 | public override Tensor forward(Tensor x)
140 | {
141 | using (NewDisposeScope())
142 | {
143 | Tensor feat = x;
144 | if (scale == 2)
145 | {
146 | feat = pixel_unshuffle(x, scale: 2);
147 | }
148 | else if (scale == 1)
149 | {
150 | feat = pixel_unshuffle(x, scale: 4);
151 | }
152 | feat = conv_first.forward(feat);
153 | Tensor body_feat = conv_body.forward(body.forward(feat));
154 | feat = feat + body_feat;
155 | // upsample
156 | feat = lrelu.forward(conv_up1.forward(functional.interpolate(feat, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.Nearest)));
157 | feat = lrelu.forward(conv_up2.forward(functional.interpolate(feat, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.Nearest)));
158 | Tensor @out = conv_last.forward(lrelu.forward(conv_hr.forward(feat)));
159 | return @out.MoveToOuterDisposeScope();
160 | }
161 | }
162 |
163 | /// <summary>
164 | /// Pixel unshuffle.
165 | /// </summary>
166 | /// <param name="x">Input feature with shape (b, c, hh, hw).</param>
167 | /// <param name="scale">Downsample ratio.</param>
168 | /// <returns>The pixel unshuffled feature.</returns>
169 | private Tensor pixel_unshuffle(Tensor x, int scale)
170 | {
171 | long b = x.shape[0];
172 | long c = x.shape[1];
173 | long hh = x.shape[2];
174 | long hw = x.shape[3];
175 |
176 | long out_channel = c * (scale * scale);
177 |
178 | if (hh % scale != 0 || hw % scale != 0)
179 | {
180 | throw new ArgumentException("Height and width must both be divisible by the scale");
181 | }
182 |
183 | long h = hh / scale;
184 | long w = hw / scale;
185 |
186 | Tensor x_view = x.view(b, c, h, scale, w, scale);
187 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w);
188 | }
189 | }
190 |
191 | public void LoadModel(string path)
192 | {
193 | rrdbnet.LoadModel(path);
194 | rrdbnet.eval();
195 | }
196 |
197 | public ImageMagick.MagickImage UpScale(ImageMagick.MagickImage inputImg)
198 | {
199 | using (no_grad())
200 | {
201 | Tensor tensor = Tools.GetTensorFromImage(inputImg);
202 | tensor = tensor.unsqueeze(0) / 255.0;
203 | tensor = tensor.to(dtype, device);
204 | Tensor op = rrdbnet.forward(tensor);
205 | op = (op.cpu() * 255.0f).clamp(0, 255).@byte();
206 | return Tools.GetImageFromTensor(op);
207 | }
208 | }
209 |
210 | public void Dispose()
211 | {
212 | rrdbnet?.Dispose();
213 | }
214 | }
215 | }
216 |
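A hypothetical end-to-end use of the upscaler above (paths are placeholders; 4x RRDB weights with 23 blocks, in the style of RealESRGAN_x4plus, are assumed). Note that pixel_unshuffle maps (b, c, H, W) to (b, c * s^2, H / s, W / s), which is how the 1x and 2x variants reuse the same 4x upsampling tail:

using (Esrgan esrgan = new Esrgan(num_block: 23))
{
    esrgan.LoadModel(@".\models\RealESRGAN_x4plus.pth");
    using (ImageMagick.MagickImage input = new ImageMagick.MagickImage(@".\input.png"))
    {
        ImageMagick.MagickImage output = esrgan.UpScale(input); // 4x in each dimension
        output.Write(@".\output.png");
    }
}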
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/SD1.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.Modules
5 | {
6 | public class SD1 : SDModel
7 | {
8 | public SD1(Device? device = null, ScalarType? dtype = null) : base(device, dtype)
9 | {
10 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
11 | this.device = device ?? torch.CPU;
12 | this.dtype = dtype ?? torch.float32;
13 |
14 | // Default parameters
15 | this.scale_factor = 0.18215f;
16 |
17 | // UNet config
18 | this.in_channels = 4;
19 | this.model_channels = 320;
20 | this.context_dim = 768;
21 | this.num_head = 8;
22 | this.dropout = 0.0f;
23 | this.embed_dim = 4;
24 |
25 | // first stage config:
26 | this.embed_dim = 4;
27 | this.double_z = true;
28 | this.z_channels = 4;
29 |
30 | }
31 | }
32 |
33 | }
34 |
35 |
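A hypothetical text-to-image call against the configuration above (paths are placeholders; passing an empty VAE path makes LoadModel fall back to the main checkpoint for the VAE weights):

using TorchSharp;

using (SD1 sd1 = new SD1(new torch.Device(DeviceType.CUDA), torch.ScalarType.Float16))
{
    sd1.LoadModel(@".\models\v1-5-pruned-emaonly.safetensors", "");
    ImageMagick.MagickImage image = sd1.TextToImage("a photo of a cat", steps: 20, cfg: 7.0f);
    image.Write(@".\cat.png");
}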
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/SDModel.cs:
--------------------------------------------------------------------------------
1 | using StableDiffusionSharp.ModelLoader;
2 | using StableDiffusionSharp.Sampler;
3 | using System.Diagnostics;
4 | using System.Text;
5 | using TorchSharp;
6 | using static TorchSharp.torch;
7 | using static TorchSharp.torch.nn;
8 |
9 | namespace StableDiffusionSharp.Modules
10 | {
11 | public class SDModel : IDisposable
12 | {
13 | // Default parameters
14 | private float linear_start = 0.00085f;
15 | private float linear_end = 0.0120f;
16 | private int num_timesteps_cond = 1;
17 | private int timesteps = 1000;
18 | internal float scale_factor = 0.18215f;
19 | internal int adm_in_channels = 2816;
20 |
21 | // UNet config
22 | internal int in_channels = 4;
23 | internal int model_channels = 320;
24 | internal int context_dim = 768;
25 | internal int num_head = 8;
26 | internal float dropout = 0.0f;
27 |
28 | // first stage config:
29 | internal int embed_dim = 4;
30 | internal bool double_z = true;
31 | internal int z_channels = 4;
32 |
33 | public class StepEventArgs : EventArgs
34 | {
35 | public int CurrentStep { get; }
36 | public int TotalSteps { get; }
37 | public ImageMagick.MagickImage VAEApproxImg { get; }
38 |
39 | public StepEventArgs(int currentStep, int totalSteps, ImageMagick.MagickImage vAEApproxImg)
40 | {
41 | CurrentStep = currentStep;
42 | TotalSteps = totalSteps;
43 | VAEApproxImg = vAEApproxImg;
44 | }
45 | }
46 |
47 | public event EventHandler<StepEventArgs> StepProgress;
48 | protected void OnStepProgress(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg)
49 | {
50 | StepProgress?.Invoke(this, new StepEventArgs(currentStep, totalSteps, vaeApproxImg));
51 | }
52 |
53 | internal Module<Tensor, long, (Tensor, Tensor)> cliper;
54 | internal Module<Tensor, Tensor, Tensor, Tensor, Tensor> diffusion;
55 | private VAE.Decoder decoder;
56 | private VAE.Encoder encoder;
57 | private Tokenizer tokenizer;
58 | private VAEApprox vaeApprox;
59 |
60 | internal Device device;
61 | internal ScalarType dtype;
62 |
63 | private int tempPromptHash;
64 | private Tensor tempTextContext;
65 | private Tensor tempPooled;
66 |
67 | bool is_loaded = false;
68 |
69 | public SDModel(Device? device = null, ScalarType? dtype = null)
70 | {
71 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
72 | this.device = device ?? torch.CPU;
73 | this.dtype = dtype ?? torch.float32;
74 | }
75 |
76 | public virtual void LoadModel(string modelPath, string vaeModelPath, string vocabPath = @".\models\clip\vocab.json", string mergesPath = @".\models\clip\merges.txt")
77 | {
78 | is_loaded = false;
79 | ModelType modelType = ModelLoader.ModelLoader.GetModelType(modelPath);
80 |
81 | cliper = modelType switch
82 | {
83 | ModelType.SD1 => new Clip.SDCliper(device: device, dtype: dtype),
84 | ModelType.SDXL => new Clip.SDXLCliper(device: device, dtype: dtype),
85 | _ => throw new ArgumentException("Invalid model type")
86 | };
87 | cliper.eval();
88 |
89 | diffusion = modelType switch
90 | {
91 | ModelType.SD1 => new SDUnet(model_channels, in_channels, num_head, context_dim, dropout, device: device, dtype: dtype),
92 | ModelType.SDXL => new SDXLUnet(model_channels, in_channels, num_head, context_dim, adm_in_channels, dropout, device: device, dtype: dtype),
93 | _ => throw new ArgumentException("Invalid model type")
94 | };
95 | diffusion.eval();
96 |
97 | decoder = new VAE.Decoder(embed_dim: embed_dim, z_channels: z_channels, device: device, dtype: dtype);
98 | decoder.eval();
99 | encoder = new VAE.Encoder(embed_dim: embed_dim, z_channels: z_channels, double_z: double_z, device: device, dtype: dtype);
100 | encoder.eval();
101 |
102 | vaeApprox = new VAEApprox(4, device, dtype);
103 | vaeApprox.eval();
104 |
105 | vaeModelPath = string.IsNullOrEmpty(vaeModelPath) ? modelPath : vaeModelPath;
106 |
107 | cliper.LoadModel(modelPath);
108 | diffusion.LoadModel(modelPath);
109 | decoder.LoadModel(vaeModelPath, "first_stage_model.");
110 | encoder.LoadModel(vaeModelPath, "first_stage_model.");
111 |
112 | string vaeApproxPath = modelType switch
113 | {
114 | ModelType.SD1 => @".\Models\VAEApprox\vaeapp_sd15.pth",
115 | ModelType.SDXL => @".\Models\VAEApprox\xlvaeapp.pth",
116 | _ => throw new ArgumentException("Invalid model type")
117 | };
118 |
119 | vaeApprox.LoadModel(vaeApproxPath);
120 |
121 | tokenizer = new Tokenizer(vocabPath, mergesPath);
122 | is_loaded = true;
123 |
124 | GC.Collect();
125 | }
126 |
127 | private void CheckModelLoaded()
128 | {
129 | if (!is_loaded)
130 | {
131 | throw new InvalidOperationException("Model not loaded");
132 | }
133 | }
134 |
135 | private static Tensor GetTimeEmbedding(Tensor timestep, int max_period = 10000, int dim = 320, bool repeat_only = false)
136 | {
137 | if (repeat_only)
138 | {
139 | return torch.repeat_interleave(timestep, dim);
140 | }
141 | else
142 | {
143 | int half = dim / 2;
144 | var freqs = torch.pow(max_period, -torch.arange(0, half, dtype: torch.float32) / half);
145 | var x = timestep * freqs.unsqueeze(0);
146 | x = torch.cat(new Tensor[] { x, x });
147 | return torch.cat(new Tensor[] { torch.cos(x), torch.sin(x) }, dim: -1);
148 | }
149 | }
150 |
151 | private (Tensor, Tensor) Clip(string prompt, string nprompt, long clip_skip)
152 | {
153 | CheckModelLoaded();
154 | if (tempPromptHash != (prompt + nprompt).GetHashCode())
155 | {
156 | using (no_grad())
157 | using (NewDisposeScope())
158 | {
159 | Tensor cond_tokens = tokenizer.Tokenize(prompt).to(device);
160 | (Tensor cond_context, Tensor cond_pooled) = cliper.forward(cond_tokens, clip_skip);
161 | Tensor uncond_tokens = tokenizer.Tokenize(nprompt).to(device);
162 | (Tensor uncond_context, Tensor uncond_pooled) = cliper.forward(uncond_tokens, clip_skip);
163 | Tensor context = cat(new Tensor[] { cond_context, uncond_context });
164 | tempPromptHash = (prompt + nprompt).GetHashCode();
165 | tempTextContext = context;
166 | tempPooled = cat(new Tensor[] { cond_pooled, uncond_pooled });
167 | tempTextContext = tempTextContext.MoveToOuterDisposeScope();
168 | tempPooled = tempPooled.MoveToOuterDisposeScope();
169 | }
170 | }
171 | return (tempTextContext, tempPooled);
172 | }
173 |
174 | /// <summary>
175 | /// Generate an image from text
176 | /// </summary>
177 | /// <param name="prompt">Prompt</param>
178 | /// <param name="nprompt">Negative prompt</param>
179 | /// <param name="width">Image width, must be a multiple of 64; otherwise it will be resized</param>
180 | /// <param name="height">Image height, must be a multiple of 64; otherwise it will be resized</param>
181 | /// <param name="steps">Number of sampling steps</param>
182 | /// <param name="seed">Random seed; a random seed is chosen when the value is 0</param>
183 | /// <param name="cfg">Classifier-Free Guidance scale</param>
184 | public virtual ImageMagick.MagickImage TextToImage(string prompt, string nprompt = "", long clip_skip = 0, int width = 512, int height = 512, int steps = 20, long seed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
185 | {
186 | CheckModelLoaded();
187 |
188 | using (no_grad())
189 | {
190 | if (steps < 1)
191 | {
192 | throw new ArgumentException("steps must be greater than 0");
193 | }
194 | if (cfg < 0.5)
195 | {
196 | throw new ArgumentException("cfg is too small, it may cause the image to be too noisy");
197 | }
198 |
199 | seed = seed == 0 ? Random.Shared.NextInt64() : seed;
200 | set_rng_state(manual_seed(seed).get_state());
201 |
202 | width = width / 64 * 8; // snap to a multiple of 64 px, then convert to latent width (px / 8)
203 | height = height / 64 * 8; // snap to a multiple of 64 px, then convert to latent height (px / 8)
204 | Console.WriteLine("Device:" + device);
205 | Console.WriteLine("Type:" + dtype);
206 | Console.WriteLine("CFG:" + cfg);
207 | Console.WriteLine("Seed:" + seed);
208 | Console.WriteLine("Width:" + width * 8);
209 | Console.WriteLine("Height:" + height * 8);
210 |
211 | Stopwatch sp = Stopwatch.StartNew();
212 | Console.WriteLine("Clip is doing......");
213 | (Tensor context, Tensor vector) = Clip(prompt, nprompt, clip_skip);
214 | using var _ = NewDisposeScope();
215 | Console.WriteLine("Getting latents......");
216 | Tensor latents = randn(new long[] { 1, 4, height, width }).to(dtype, device);
217 |
218 | BasicSampler sampler = samplerType switch
219 | {
220 | SDSamplerType.Euler => new EulerSampler(timesteps, linear_start, linear_end, num_timesteps_cond),
221 | SDSamplerType.EulerAncestral => new EulerAncestralSampler(timesteps, linear_start, linear_end, num_timesteps_cond),
222 | _ => throw new ArgumentException("Unknown sampler type")
223 | };
224 |
225 | sampler.SetTimesteps(steps);
226 | latents *= sampler.InitNoiseSigma();
227 |
228 | Console.WriteLine($"begin sampling");
229 | for (int i = 0; i < steps; i++)
230 | {
231 | Tensor approxTensor = vaeApprox.forward(latents);
232 | approxTensor = approxTensor * 127.5 + 127.5;
233 | approxTensor = approxTensor.clamp(0, 255).@byte().cpu();
234 | ImageMagick.MagickImage approxImg = Tools.GetImageFromTensor(approxTensor);
235 | OnStepProgress(i + 1, steps, approxImg);
236 | Tensor timestep = sampler.Timesteps[i];
237 | Tensor time_embedding = GetTimeEmbedding(timestep);
238 | Tensor input_latents = sampler.ScaleModelInput(latents, i);
239 | input_latents = input_latents.repeat(2, 1, 1, 1);
240 | Tensor output = diffusion.forward(input_latents, context, time_embedding, vector);
241 | Tensor[] ret = output.chunk(2);
242 | Tensor output_cond = ret[0];
243 | Tensor output_uncond = ret[1];
244 | output = cfg * (output_cond - output_uncond) + output_uncond; // classifier-free guidance
245 | latents = sampler.Step(output, i, latents, seed);
246 | }
247 | Console.WriteLine($"end sampling");
248 | Console.WriteLine($"begin decoder");
249 | latents = latents / scale_factor;
250 | Tensor image = decoder.forward(latents);
251 | Console.WriteLine($"end decoder");
252 |
253 |
254 | image = ((image + 0.5) * 255.0f).clamp(0, 255).@byte().cpu();
255 |
256 | ImageMagick.MagickImage img = Tools.GetImageFromTensor(image);
257 |
258 | StringBuilder stringBuilder = new StringBuilder();
259 | stringBuilder.AppendLine(prompt);
260 | if (!string.IsNullOrEmpty(nprompt))
261 | {
262 | stringBuilder.AppendLine("Negative prompt: " + nprompt);
263 | }
264 | stringBuilder.AppendLine($"Steps: {steps}, CFG scale_factor: {cfg}, Seed: {seed}, Size: {width}x{height}, Version: StableDiffusionSharp");
265 | img.SetAttribute("parameters", stringBuilder.ToString());
266 | sp.Stop();
267 | Console.WriteLine($"Total time is: {sp.ElapsedMilliseconds} ms.");
268 | return img;
269 | }
270 | }
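    | // Usage sketch (illustrative, not part of the class): once a checkpoint has
    | // been loaded into an instance `sd` of this class (loading call elided):
    | //   using ImageMagick.MagickImage image = sd.TextToImage(
    | //       "a castle on a hill", nprompt: "blurry", steps: 20, cfg: 7.0f);
    | //   image.Write("castle.png");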
271 |
272 |
273 | public virtual ImageMagick.MagickImage ImageToImage(ImageMagick.MagickImage orgImage, string prompt, string nprompt = "", long clip_skip = 0, int steps = 20, float strength = 0.75f, long seed = 0, long subSeed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
274 | {
275 | CheckModelLoaded();
276 |
277 | using (no_grad())
278 | {
279 | Stopwatch sp = Stopwatch.StartNew();
280 | seed = seed == 0 ? Random.Shared.NextInt64() : seed;
281 | Generator generator = manual_seed(seed);
282 | set_rng_state(generator.get_state());
283 |
284 | Console.WriteLine("Clip is doing......");
285 | (Tensor context, Tensor vector) = Clip(prompt, nprompt, clip_skip);
286 |
287 | Console.WriteLine("Getting latents......");
288 | Tensor inputTensor = Tools.GetTensorFromImage(orgImage).unsqueeze(0);
289 | inputTensor = inputTensor.to(dtype, device);
290 | inputTensor = inputTensor / 255.0f * 2 - 1.0f;
291 | Tensor lt = encoder.forward(inputTensor);
292 |
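    | // Reparameterization of the VAE posterior: the encoder predicts the mean and
    | // (clamped) log-variance of a diagonal Gaussian, and z = mean + std * eps
    | // with eps ~ N(0, I) gives a sample z ~ N(mean, std^2).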
293 | Tensor[] mean_var = lt.chunk(2, 1);
294 | Tensor mean = mean_var[0];
295 | Tensor logvar = mean_var[1].clamp(-30, 20);
296 | Tensor std = exp(0.5f * logvar);
297 | Tensor latents = mean + std * randn_like(mean);
298 |
299 | latents = latents * scale_factor;
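    | // strength in (0, 1] sets the entry point into the noise schedule: roughly
    | // strength * steps denoising steps are actually run, so values near 1
    | // re-noise the input almost completely while small values stay close to it.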
300 | int t_enc = (int)(strength * steps) - 1;
301 |
302 | BasicSampler sampler = samplerType switch
303 | {
304 | SDSamplerType.Euler => new EulerSampler(timesteps, linear_start, linear_end, num_timesteps_cond),
305 | SDSamplerType.EulerAncestral => new EulerAncestralSampler(timesteps, linear_start, linear_end, num_timesteps_cond),
306 | _ => throw new ArgumentException("Unknown sampler type")
307 | };
308 |
309 | sampler.SetTimesteps(steps);
310 | Tensor sigma_sched = sampler.Sigmas[(steps - t_enc - 1)..];
311 | Tensor noise = randn_like(latents);
312 | latents = latents + noise * sigma_sched.max();
313 |
314 | Console.WriteLine($"begin sampling");
315 | for (int i = 0; i < sigma_sched.NumberOfElements - 1; i++)
316 | {
317 | Tensor approxTensor = vaeApprox.forward(latents);
318 | approxTensor = approxTensor * 127.5 + 127.5;
319 | approxTensor = approxTensor.clamp(0, 255).@byte().cpu();
320 | ImageMagick.MagickImage approxImg = Tools.GetImageFromTensor(approxTensor);
321 | OnStepProgress(i + 1, steps, approxImg);
322 |
323 | int index = steps - t_enc + i - 1;
324 | Tensor timestep = sampler.Timesteps[index];
325 | Tensor time_embedding = GetTimeEmbedding(timestep);
326 | Tensor input_latents = sampler.ScaleModelInput(latents, index);
327 | input_latents = input_latents.repeat(2, 1, 1, 1);
328 | Tensor output = diffusion.forward(input_latents, context, time_embedding, vector);
329 | Tensor[] ret = output.chunk(2);
330 | Tensor output_cond = ret[0];
331 | Tensor output_uncond = ret[1];
332 | Tensor noisePred = cfg * (output_cond - output_uncond) + output_uncond; // classifier-free guidance
333 | latents = sampler.Step(noisePred, index, latents, seed);
334 | }
335 | Console.WriteLine($"end sampling");
336 | Console.WriteLine($"begin decoder");
337 | latents = latents / scale_factor;
338 | Tensor image = decoder.forward(latents);
339 | Console.WriteLine($"end decoder");
340 |
341 | sp.Stop();
342 | Console.WriteLine($"Total time is: {sp.ElapsedMilliseconds} ms.");
343 | image = ((image + 0.5) * 255.0f).clamp(0, 255).@byte().cpu();
344 |
345 | ImageMagick.MagickImage img = Tools.GetImageFromTensor(image);
346 |
347 | StringBuilder stringBuilder = new StringBuilder();
348 | stringBuilder.AppendLine(prompt);
349 | if (!string.IsNullOrEmpty(nprompt))
350 | {
351 | stringBuilder.AppendLine("Negative prompt: " + nprompt);
352 | }
353 | stringBuilder.AppendLine($"Steps: {steps}, CFG scale_factor: {cfg}, Seed: {seed}, Size: {img.Width}x{img.Height}, Version: StableDiffusionSharp");
354 | img.SetAttribute("parameters", stringBuilder.ToString());
355 | return img;
356 | }
357 | }
358 |
359 | public void Dispose()
360 | {
361 | cliper?.Dispose();
362 | diffusion?.Dispose();
363 | decoder?.Dispose();
364 | encoder?.Dispose();
365 | vaeApprox?.Dispose();
366 | tempTextContext?.Dispose();
367 | tempPooled?.Dispose();
368 | }
367 |
368 | }
369 |
370 | }
371 |
372 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/SDXL.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.Modules
5 | {
6 | public class SDXL : SD1
7 | {
8 | public SDXL(Device? device = null, ScalarType? dtype = null) : base(device, dtype)
9 | {
10 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
11 | this.device = device ?? torch.CPU;
12 | this.dtype = dtype ?? torch.float32;
13 |
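    | // 0.13025 is the latent scaling constant of the SDXL VAE
    | // (SD 1.x models commonly use 0.18215).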
14 | this.scale_factor = 0.13025f;
15 |
16 | this.in_channels = 4;
17 | this.model_channels = 320;
18 | this.context_dim = 2048;
19 | this.num_head = 20;
20 | this.dropout = 0.0f;
21 | this.adm_in_channels = 2816;
22 |
23 | this.embed_dim = 4;
24 | this.double_z = true;
25 | this.z_channels = 4;
26 | }
27 | }
28 | }
29 |
30 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/Tokenizer.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.Tokenizers;
2 | using System.Reflection;
3 | using static TorchSharp.torch;
4 |
5 | namespace StableDiffusionSharp.Modules
6 | {
7 | internal class Tokenizer
8 | {
9 | private readonly BpeTokenizer _tokenizer;
10 | private readonly int _startToken;
11 | private readonly int _endToken;
12 |
13 | public Tokenizer(string vocabPath, string mergesPath, int startToken = 49406, int endToken = 49407)
14 | {
15 | if (!File.Exists(vocabPath))
16 | {
17 | string path = Path.GetDirectoryName(vocabPath)!;
18 | if (!Directory.Exists(path))
19 | {
20 | Directory.CreateDirectory(path);
21 | }
22 | Assembly _assembly = Assembly.GetExecutingAssembly();
23 | string resourceName = "StableDiffusionSharp.Models.Clip.vocab.json";
24 | using (Stream stream = _assembly.GetManifestResourceStream(resourceName)!)
25 | {
26 | if (stream == null)
27 | {
28 | Console.WriteLine("Resource can't find!");
29 | return;
30 | }
31 | using (FileStream fileStream = new FileStream(vocabPath, FileMode.Create, FileAccess.Write))
32 | {
33 | stream.CopyTo(fileStream);
34 | }
35 | }
36 |
37 | }
38 |
39 | if (!File.Exists(mergesPath))
40 | {
41 | string path = Path.GetDirectoryName(mergesPath)!;
42 | if (!Directory.Exists(path))
43 | {
44 | Directory.CreateDirectory(path);
45 | }
46 | Assembly _assembly = Assembly.GetExecutingAssembly();
47 | string resourceName = "StableDiffusionSharp.Models.Clip.merges.txt";
48 | using (Stream stream = _assembly.GetManifestResourceStream(resourceName)!)
49 | {
50 | if (stream == null)
51 | {
52 | Console.WriteLine("Resource can't find!");
53 | return;
54 | }
55 | using (FileStream fileStream = new FileStream(mergesPath, FileMode.Create, FileAccess.Write))
56 | {
57 | stream.CopyTo(fileStream);
58 | }
59 | }
60 |
61 | }
62 |
63 | _tokenizer = BpeTokenizer.Create(vocabPath, mergesPath, endOfWordSuffix: "");
64 | _startToken = startToken;
65 | _endToken = endToken;
66 | }
67 |
68 | public Tensor Tokenize(string text, int maxTokens = 77)
69 | {
70 | var res = _tokenizer.EncodeToIds(text).ToList();
71 | if (res.Count > maxTokens - 2) // truncate so the start/end tokens still fit within maxTokens
72 | {
73 | res = res.Take(maxTokens - 2).ToList();
74 | }
75 | res.Insert(0, _startToken);
76 | res.Add(_endToken);
77 | return tensor(res, ScalarType.Int64).unsqueeze(0);
74 | }
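    | // Usage sketch (hypothetical paths; missing files are re-extracted from the
    | // embedded resources by the constructor):
    | //   var tok = new Tokenizer(@".\models\clip\vocab.json", @".\models\clip\merges.txt");
    | //   Tensor ids = tok.Tokenize("a photo of a cat"); // [1, n], start/end tokens added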
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/Unet.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using TorchSharp;
3 | using TorchSharp.Modules;
4 | using static TorchSharp.torch;
5 | using static TorchSharp.torch.nn;
8 |
9 | namespace StableDiffusionSharp.Modules
10 | {
11 | internal class CrossAttention : Module<Tensor, Tensor, Tensor>
12 | {
13 | private readonly Linear to_q;
14 | private readonly Linear to_k;
15 | private readonly Linear to_v;
16 | private readonly Sequential to_out;
17 | private readonly long n_heads_;
18 | private readonly long d_head;
19 | private readonly bool causal_mask_;
20 |
21 | public CrossAttention(long channels, long d_cross, long n_heads, bool causal_mask = false, bool in_proj_bias = false, bool out_proj_bias = true, float dropout_p = 0.0f, Device? device = null, ScalarType? dtype = null) : base(nameof(CrossAttention))
22 | {
23 | to_q = Linear(channels, channels, hasBias: in_proj_bias, device: device, dtype: dtype);
24 | to_k = Linear(d_cross, channels, hasBias: in_proj_bias, device: device, dtype: dtype);
25 | to_v = Linear(d_cross, channels, hasBias: in_proj_bias, device: device, dtype: dtype);
26 | to_out = Sequential(Linear(channels, channels, hasBias: out_proj_bias, device: device, dtype: dtype), Dropout(dropout_p, inplace: false));
27 | n_heads_ = n_heads;
28 | d_head = channels / n_heads;
29 | causal_mask_ = causal_mask;
30 | RegisterComponents();
31 | }
32 |
33 | public override Tensor forward(Tensor x, Tensor y)
34 | {
35 | using (NewDisposeScope())
36 | {
37 | long[] input_shape = x.shape;
38 | long batch_size = input_shape[0];
39 | long sequence_length = input_shape[1];
40 |
41 | long[] interim_shape = new long[] { batch_size, -1, n_heads_, d_head };
42 | Tensor q = to_q.forward(x);
43 | Tensor k = to_k.forward(y);
44 | Tensor v = to_v.forward(y);
45 |
46 | q = q.view(interim_shape).transpose(1, 2);
47 | k = k.view(interim_shape).transpose(1, 2);
48 | v = v.view(interim_shape).transpose(1, 2);
49 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: causal_mask_);
50 | output = output.transpose(1, 2).reshape(input_shape);
51 | output = to_out.forward(output);
52 | return output.MoveToOuterDisposeScope();
53 | }
54 | }
55 | }
56 |
57 | internal class ResnetBlock : Module<Tensor, Tensor, Tensor>
58 | {
59 | private readonly int in_channels;
60 | private readonly int out_channels;
61 |
62 | private readonly Module<Tensor, Tensor> skip_connection;
63 | private readonly Sequential emb_layers;
64 | private readonly Sequential in_layers;
65 | private readonly Sequential out_layers;
66 |
67 | public ResnetBlock(int in_channels, int out_channels, double dropout = 0.0, int temb_channels = 1280, Device? device = null, ScalarType? dtype = null) : base(nameof(ResnetBlock))
68 | {
69 | this.in_channels = in_channels;
70 | out_channels = out_channels < 1 ? in_channels : out_channels;
71 | this.out_channels = out_channels;
72 |
73 | in_layers = Sequential(GroupNorm(32, in_channels, device: device, dtype: dtype), SiLU(), Conv2d(in_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype));
74 |
75 | if (temb_channels > 0)
76 | {
77 | emb_layers = Sequential(SiLU(), Linear(temb_channels, out_channels, device: device, dtype: dtype));
78 | }
79 |
80 | out_layers = Sequential(GroupNorm(32, out_channels, device: device, dtype: dtype), SiLU(), Dropout(dropout), Conv2d(out_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype));
81 |
82 | if (this.in_channels != this.out_channels)
83 | {
84 | skip_connection = Conv2d(in_channels: in_channels, out_channels: this.out_channels, kernel_size: 1, stride: 1, device: device, dtype: dtype);
85 | }
86 | else
87 | {
88 | skip_connection = Identity();
89 | }
90 |
91 | RegisterComponents();
92 | }
93 |
94 | public override Tensor forward(Tensor x, Tensor time)
95 | {
96 | using (NewDisposeScope())
97 | {
98 | Tensor hidden = x;
99 | hidden = in_layers.forward(hidden);
100 |
101 | if (time is not null)
102 | {
103 | time = emb_layers.forward(time);
104 | hidden = hidden + time.unsqueeze(-1).unsqueeze(-1);
105 | }
106 |
107 | hidden = out_layers.forward(hidden);
108 | if (in_channels != out_channels)
109 | {
110 | x = skip_connection.forward(x);
111 | }
112 | return (x + hidden).MoveToOuterDisposeScope();
113 | }
114 | }
115 | }
116 |
117 | internal class TransformerBlock : Module<Tensor, Tensor, Tensor>
118 | {
119 | private LayerNorm norm1;
120 | private CrossAttention attn1;
121 | private LayerNorm norm2;
122 | private CrossAttention attn2;
123 | private LayerNorm norm3;
124 | private FeedForward ff;
125 |
126 | public TransformerBlock(int channels, int n_cross, int n_head, Device? device = null, ScalarType? dtype = null) : base(nameof(TransformerBlock))
127 | {
128 | norm1 = LayerNorm(channels, device: device, dtype: dtype);
129 | attn1 = new CrossAttention(channels, channels, n_head, device: device, dtype: dtype);
130 | norm2 = LayerNorm(channels, device: device, dtype: dtype);
131 | attn2 = new CrossAttention(channels, n_cross, n_head, device: device, dtype: dtype);
132 | norm3 = LayerNorm(channels, device: device, dtype: dtype);
133 | ff = new FeedForward(channels, glu: true, device: device, dtype: dtype);
134 | RegisterComponents();
135 | }
136 | public override Tensor forward(Tensor x, Tensor context)
137 | {
138 | var residue_short = x;
139 | x = norm1.forward(x);
140 | x = attn1.forward(x, x);
141 | x += residue_short;
142 | residue_short = x;
143 | x = norm2.forward(x);
144 | x = attn2.forward(x, context);
145 | x += residue_short;
146 | residue_short = x;
147 | x = norm3.forward(x);
148 | x = ff.forward(x);
149 | x += residue_short;
150 | return x.MoveToOuterDisposeScope();
151 | }
152 | }
153 |
154 | internal class SpatialTransformer : Module<Tensor, Tensor, Tensor>
155 | {
156 | private readonly GroupNorm norm;
157 | private readonly Module<Tensor, Tensor> proj_in;
158 | private readonly Module<Tensor, Tensor> proj_out;
159 | private readonly ModuleList<Module<Tensor, Tensor, Tensor>> transformer_blocks;
160 | private readonly bool use_linear;
161 |
162 | public SpatialTransformer(int channels, int n_cross, int n_head, int num_atten_blocks, float drop_out = 0.0f, bool use_linear = false, Device? device = null, ScalarType? dtype = null) : base(nameof(SpatialTransformer))
163 | {
164 | norm = Normalize(channels, device: device, dtype: dtype);
165 | this.use_linear = use_linear;
166 | proj_in = use_linear ? Linear(channels, channels, device: device, dtype: dtype) : Conv2d(channels, channels, kernel_size: 1, device: device, dtype: dtype);
167 | proj_out = use_linear ? Linear(channels, channels, device: device, dtype: dtype) : Conv2d(channels, channels, kernel_size: 1, device: device, dtype: dtype);
168 | transformer_blocks = new ModuleList<Module<Tensor, Tensor, Tensor>>();
169 | for (int i = 0; i < num_atten_blocks; i++)
170 | {
171 | transformer_blocks.Add(new TransformerBlock(channels, n_cross, n_head, device: device, dtype: dtype));
172 | }
173 | RegisterComponents();
174 | }
175 |
176 | public override Tensor forward(Tensor x, Tensor context)
177 | {
178 | using (NewDisposeScope())
179 | {
180 | long n = x.shape[0];
181 | long c = x.shape[1];
182 | long h = x.shape[2];
183 | long w = x.shape[3];
184 |
185 | Tensor residue_short = x;
186 | x = norm.forward(x);
187 |
188 | if (!use_linear)
189 | {
190 | x = proj_in.forward(x);
191 | }
192 |
193 | x = x.view(new long[] { n, c, h * w });
194 | x = x.transpose(-1, -2);
195 |
196 | if (use_linear)
197 | {
198 | x = proj_in.forward(x);
199 | }
200 |
201 | foreach (Module<Tensor, Tensor, Tensor> layer in transformer_blocks)
202 | {
203 | x = layer.forward(x, context);
204 | }
205 |
206 | if (use_linear)
207 | {
208 | x = proj_out.forward(x);
209 | }
210 | x = x.transpose(-1, -2);
211 | x = x.view(new long[] { n, c, h, w });
212 | if (!use_linear)
213 | {
214 | x = proj_out.forward(x);
215 | }
216 |
217 | residue_short = residue_short + x;
218 | return residue_short.MoveToOuterDisposeScope();
219 | }
220 | }
221 |
222 | private static GroupNorm Normalize(int in_channels, int num_groups = 32, float eps = 1e-6f, bool affine = true, Device? device = null, ScalarType? dtype = null)
223 | {
224 | return GroupNorm(num_groups: num_groups, num_channels: in_channels, eps: eps, affine: affine, device: device, dtype: dtype);
225 | }
226 |
227 | }
228 |
229 | internal class Upsample : Module<Tensor, Tensor>
230 | {
231 | private readonly Conv2d? conv;
232 | private readonly bool with_conv;
233 | public Upsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Upsample))
234 | {
235 | this.with_conv = with_conv;
236 | if (with_conv)
237 | {
238 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype);
239 | }
240 | RegisterComponents();
241 | }
242 | public override Tensor forward(Tensor x)
243 | {
244 | var output = functional.interpolate(x, scale_factor: new double[] { 2.0, 2.0 }, mode: InterpolationMode.Nearest);
245 | if (with_conv && conv is not null)
246 | {
247 | output = conv.forward(output);
248 | }
249 | return output;
250 | }
251 | }
252 |
253 | internal class Downsample : Module<Tensor, Tensor>
254 | {
255 | private readonly Conv2d op;
256 | public Downsample(int in_channels, Device? device = null, ScalarType? dtype = null) : base(nameof(Downsample))
257 | {
258 | op = Conv2d(in_channels: in_channels, out_channels: in_channels, kernel_size: 3, stride: 2, padding: 1, device: device, dtype: dtype);
259 | RegisterComponents();
260 | }
261 | public override Tensor forward(Tensor x)
262 | {
263 | x = op.forward(x);
264 | return x;
265 | }
266 | }
267 |
268 | internal class TimestepEmbedSequential : Sequential
269 | {
270 | internal TimestepEmbedSequential(params (string name, Module)[] modules) : base(modules)
271 | {
272 | RegisterComponents();
273 | }
274 |
275 | internal TimestepEmbedSequential(params Module[] modules) : base(modules)
276 | {
277 | RegisterComponents();
278 | }
279 |
280 | public override Tensor forward(Tensor x, Tensor context, Tensor time)
281 | {
282 | using (NewDisposeScope())
283 | {
284 | foreach (var layer in children())
285 | {
286 | switch (layer)
287 | {
288 | case ResnetBlock res:
289 | x = res.call(x, time);
290 | break;
291 | case SpatialTransformer abl:
292 | x = abl.call(x, context);
293 | break;
294 | case Module<Tensor, Tensor> m:
295 | x = m.call(x);
296 | break;
297 | }
298 | }
299 | return x.MoveToOuterDisposeScope();
300 | }
301 | }
302 | }
303 |
304 | internal class GEGLU : Module<Tensor, Tensor>
305 | {
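    | // GEGLU, the GELU-gated linear unit from "GLU Variants Improve Transformer":
    | // project to 2 * dim_out, split into value and gate halves, and return
    | // value * GELU(gate).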
306 | private readonly Linear proj;
307 | public GEGLU(int dim_in, int dim_out, Device? device = null, ScalarType? dtype = null) : base(nameof(GEGLU))
308 | {
309 | proj = Linear(dim_in, dim_out * 2, device: device, dtype: dtype);
310 | RegisterComponents();
311 | }
312 |
313 | public override Tensor forward(Tensor x)
314 | {
315 | using (NewDisposeScope())
316 | {
317 | Tensor[] result = proj.forward(x).chunk(2, dim: -1);
318 | x = result[0];
319 | Tensor gate = result[1];
320 | return (x * functional.gelu(gate)).MoveToOuterDisposeScope();
321 | }
322 | }
323 | }
324 |
325 | internal class FeedForward : Module<Tensor, Tensor>
326 | {
327 | private readonly Sequential net;
328 |
329 | public FeedForward(int dim, int? dim_out = null, int mult = 4, bool glu = true, float dropout = 0.0f, Device? device = null, ScalarType? dtype = null) : base(nameof(FeedForward))
330 | {
331 | int inner_dim = dim * mult;
332 | int dim_ot = dim_out ?? dim;
333 | Module<Tensor, Tensor> project_in = glu ? new GEGLU(dim, inner_dim, device: device, dtype: dtype) : Sequential(nn.Linear(dim, inner_dim, device: device, dtype: dtype), nn.GELU());
334 | net = Sequential(project_in, Dropout(dropout), Linear(inner_dim, dim_ot, device: device, dtype: dtype));
335 | RegisterComponents();
336 | }
337 |
338 | public override Tensor forward(Tensor input)
339 | {
340 | return net.forward(input);
341 | }
342 | }
343 |
344 | internal class SDUnet : Module<Tensor, Tensor, Tensor, Tensor, Tensor>
345 | {
346 | private class UNet : Module<Tensor, Tensor, Tensor, Tensor>
347 | {
348 | private readonly int ch;
349 | private readonly int time_embed_dim;
350 | private readonly int in_channels;
351 | private readonly bool use_timestep;
352 |
353 | private readonly Sequential time_embed;
354 | private readonly ModuleList input_blocks;
355 | private readonly TimestepEmbedSequential middle_block;
356 | private readonly ModuleList output_blocks;
357 | private readonly Sequential @out;
358 |
359 | public UNet(int model_channels, int in_channels, int[]? channel_mult = null, int num_res_blocks = 2, int num_atten_blocks = 1, int context_dim = 768, int num_heads = 8, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(UNet))
360 | {
361 | bool mask = false;
362 | channel_mult = channel_mult ?? new int[] { 1, 2, 4, 4 };
363 |
364 | ch = model_channels;
365 | time_embed_dim = model_channels * 4;
366 | this.in_channels = in_channels;
367 | this.use_timestep = use_timestep;
368 |
369 | List<int> input_block_channels = new List<int> { model_channels };
370 |
371 | if (use_timestep)
372 | {
373 | // timestep embedding
374 | time_embed = Sequential(new Module<Tensor, Tensor>[] { Linear(model_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype) });
375 | }
376 |
377 | // downsampling
378 | input_blocks = new ModuleList();
379 | input_blocks.Add(new TimestepEmbedSequential(Conv2d(in_channels, ch, kernel_size: 3, padding: 1, device: device, dtype: dtype)));
380 |
381 | for (int i = 0; i < channel_mult.Length; i++)
382 | {
383 | int in_ch = model_channels * channel_mult[i > 0 ? i - 1 : i];
384 | int out_ch = model_channels * channel_mult[i];
385 |
386 | for (int j = 0; j < num_res_blocks; j++)
387 | {
388 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(in_ch, out_ch, dropout, time_embed_dim, device: device, dtype: dtype), i < channel_mult.Length - 1 ? new SpatialTransformer(out_ch, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype) : Identity()));
389 | input_block_channels.Add(in_ch);
390 | in_ch = out_ch;
391 | }
392 | if (i < channel_mult.Length - 1)
393 | {
394 | input_blocks.Add(new TimestepEmbedSequential(Sequential(("op", Conv2d(out_ch, out_ch, 3, stride: 2, padding: 1, device: device, dtype: dtype)))));
395 | input_block_channels.Add(out_ch);
396 | }
397 | }
398 |
399 | // middle block
400 | middle_block = new TimestepEmbedSequential(new ResnetBlock(time_embed_dim, time_embed_dim, dropout, time_embed_dim, device: device, dtype: dtype), new SpatialTransformer(time_embed_dim, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype), new ResnetBlock(1280, 1280, device: device, dtype: dtype));
401 |
402 | // upsampling
403 | var reversed_mult = channel_mult.Reverse().ToList();
404 | int prev_channels = time_embed_dim;
405 | output_blocks = new ModuleList();
406 | for (int i = 0; i < reversed_mult.Count; i++)
407 | {
408 | int mult = reversed_mult[i];
409 | int current_channels = model_channels * mult;
410 | int down_stage_index = channel_mult.Length - 1 - i;
411 | int skip_channels = model_channels * channel_mult[down_stage_index];
412 | bool has_atten = i >= 1;
413 |
414 | for (int j = 0; j < num_res_blocks + 1; j++)
415 | {
416 | int current_skip = skip_channels;
417 | if (j == num_res_blocks && i < reversed_mult.Count - 1)
418 | {
419 | int next_down_stage_index = channel_mult.Length - 1 - (i + 1);
420 | current_skip = model_channels * channel_mult[next_down_stage_index];
421 | }
422 |
423 | int input_channels = prev_channels + current_skip;
424 | bool has_upsample = j == num_res_blocks && i != reversed_mult.Count - 1;
425 |
426 | if (has_atten)
427 | {
428 | output_blocks.Add(new TimestepEmbedSequential(
429 | new ResnetBlock(input_channels, current_channels, dropout, time_embed_dim, device: device, dtype: dtype),
430 | new SpatialTransformer(current_channels, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype),
431 | has_upsample ? new Upsample(current_channels, device: device, dtype: dtype) : Identity()));
432 | }
433 | else
434 | {
435 | output_blocks.Add(new TimestepEmbedSequential(
436 | new ResnetBlock(input_channels, current_channels, dropout, time_embed_dim, device: device, dtype: dtype),
437 | has_upsample ? new Upsample(current_channels, device: device, dtype: dtype) : Identity()));
438 | }
439 |
440 | prev_channels = current_channels;
441 | }
442 | }
443 |
444 | @out = Sequential(GroupNorm(32, model_channels, device: device, dtype: dtype), SiLU(), Conv2d(model_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype));
445 |
446 | RegisterComponents();
447 |
448 | }
449 | public override Tensor forward(Tensor x, Tensor context, Tensor time)
450 | {
451 | using (NewDisposeScope())
452 | {
453 | time = time_embed.forward(time);
454 |
455 | List<Tensor> skip_connections = new List<Tensor>();
456 | foreach (TimestepEmbedSequential layers in input_blocks)
457 | {
458 | x = layers.forward(x, context, time);
459 | skip_connections.Add(x);
460 | }
461 | x = middle_block.forward(x, context, time);
462 | foreach (TimestepEmbedSequential layers in output_blocks)
463 | {
464 | Tensor index = skip_connections.Last();
465 | x = cat(new Tensor[] { x, index }, 1);
466 | skip_connections.RemoveAt(skip_connections.Count - 1);
467 | x = layers.forward(x, context, time);
468 | }
469 |
470 | x = @out.forward(x);
471 | return x.MoveToOuterDisposeScope();
472 | }
473 | }
474 | }
475 |
476 | private class Model : Module<Tensor, Tensor, Tensor, Tensor>
477 | {
478 | private readonly UNet diffusion_model;
479 |
480 | public Model(int model_channels, int in_channels, int num_heads = 8, int context_dim = 768, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet))
481 | {
482 | diffusion_model = new UNet(model_channels, in_channels, context_dim: context_dim, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype);
483 | RegisterComponents();
484 | }
485 |
486 | public override Tensor forward(Tensor latent, Tensor context, Tensor time)
487 | {
488 | return diffusion_model.forward(latent, context, time);
489 | }
490 | }
491 |
492 | private readonly Model model;
493 |
494 | public SDUnet(int model_channels, int in_channels, int num_heads = 8, int context_dim = 768, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet))
495 | {
496 | model = new Model(model_channels, in_channels, context_dim: context_dim, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype);
497 | RegisterComponents();
498 | }
499 |
500 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y)
501 | {
502 | Device device = model.parameters().First().device;
503 | ScalarType dtype = model.parameters().First().dtype;
504 |
505 | latent = latent.to(dtype, device);
506 | time = time.to(dtype, device);
507 | context = context.to(dtype, device);
508 | return model.forward(latent, context, time);
509 | }
510 | }
511 |
512 | internal class SDXLUnet : Module<Tensor, Tensor, Tensor, Tensor, Tensor>
513 | {
514 | private class UNet : Module<Tensor, Tensor, Tensor, Tensor, Tensor>
515 | {
516 | private readonly int ch;
517 | private readonly int time_embed_dim;
518 | private readonly int in_channels;
519 | private readonly bool use_timestep;
520 |
521 | private readonly Sequential time_embed;
522 | private readonly Sequential label_emb;
523 | private readonly ModuleList input_blocks;
524 | private readonly TimestepEmbedSequential middle_block;
525 | private readonly ModuleList output_blocks;
526 | private readonly Sequential @out;
527 |
528 |
529 | public UNet(int model_channels, int in_channels, int[]? channel_mult = null, int num_res_blocks = 2, int context_dim = 768, int adm_in_channels = 2816, int num_heads = 20, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet))
530 | {
531 | channel_mult = channel_mult ?? new int[] { 1, 2, 4 };
532 |
533 | ch = model_channels;
534 | time_embed_dim = model_channels * 4;
535 | this.in_channels = in_channels;
536 | this.use_timestep = use_timestep;
537 |
538 | bool useLinear = true;
539 | bool mask = false;
540 |
541 | List<int> input_block_channels = new List<int> { model_channels };
542 |
543 | if (use_timestep)
544 | {
545 | int time_embed_dim = model_channels * 4;
546 | time_embed = Sequential(Linear(model_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype));
547 | label_emb = Sequential(Sequential(Linear(adm_in_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype)));
548 | }
549 |
550 | // downsampling
551 | input_blocks = new ModuleList();
552 | input_blocks.Add(new TimestepEmbedSequential(Conv2d(in_channels, ch, kernel_size: 3, padding: 1, device: device, dtype: dtype)));
553 |
554 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 320, device: device, dtype: dtype)));
555 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 320, device: device, dtype: dtype)));
556 | input_blocks.Add(new TimestepEmbedSequential(new Downsample(320, device: device, dtype: dtype)));
557 |
558 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype)));
559 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype)));
560 | input_blocks.Add(new TimestepEmbedSequential(new Downsample(640, device: device, dtype: dtype)));
561 |
562 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype)));
563 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1280, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype)));
564 |
565 | // mid_block
566 | middle_block = new TimestepEmbedSequential(new ResnetBlock(1280, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype), new ResnetBlock(1280, 1280, device: device, dtype: dtype));
567 |
568 | // upsampling
569 | output_blocks = new ModuleList();
570 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(2560, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype)));
571 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(2560, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype)));
572 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1920, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype), new Upsample(1280, device: device, dtype: dtype)));
573 |
574 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1920, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype)));
575 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1280, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype)));
576 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(960, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype), new Upsample(640, device: device, dtype: dtype)));
577 |
578 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(960, 320, device: device, dtype: dtype)));
579 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 320, device: device, dtype: dtype)));
580 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 320, device: device, dtype: dtype)));
581 |
582 | @out = Sequential(GroupNorm(32, model_channels, device: device, dtype: dtype), SiLU(), Conv2d(model_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype));
583 |
584 | RegisterComponents();
585 | }
586 |
587 | public override Tensor forward(Tensor x, Tensor context, Tensor time, Tensor y)
588 | {
589 | using (NewDisposeScope())
590 | {
591 | int dim = 512;
592 | Tensor embed = time_embed.forward(time);
593 | Tensor time_ids = tensor(new float[] { dim, dim, 0, 0, dim, dim }, embed.dtype, embed.device).repeat(new long[] { 2, 1 });
594 | Tensor time_embeds = get_timestep_embedding(time_ids.flatten(), dim / 2, true, 0, 1);
595 | time_embeds = time_embeds.reshape(new long[] { 2, -1 });
596 | y = cat(new Tensor[] { y, time_embeds }, dim: -1);
597 | Tensor label_embed = label_emb.forward(y.to(embed.dtype, embed.device));
598 | embed = embed + label_embed;
599 |
600 | List<Tensor> skip_connections = new List<Tensor>();
601 | foreach (TimestepEmbedSequential layers in input_blocks)
602 | {
603 | x = layers.forward(x, context, embed);
604 | skip_connections.Add(x);
605 | }
606 | x = middle_block.forward(x, context, embed);
607 | foreach (TimestepEmbedSequential layers in output_blocks)
608 | {
609 | Tensor index = skip_connections.Last();
610 | x = cat(new Tensor[] { x, index }, 1);
611 | skip_connections.RemoveAt(skip_connections.Count - 1);
612 | x = layers.forward(x, context, embed);
613 | }
614 |
615 | x = @out.forward(x);
616 | return x.MoveToOuterDisposeScope();
617 | }
618 |
619 | }
620 | }
621 |
622 | private class Model : Module<Tensor, Tensor, Tensor, Tensor, Tensor>
623 | {
624 | private UNet diffusion_model;
625 | public Model(int model_channels, int in_channels, int num_heads = 20, int context_dim = 2048, int adm_in_channels = 2816, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet))
626 | {
627 | diffusion_model = new UNet(model_channels, in_channels, context_dim: context_dim, adm_in_channels: adm_in_channels, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype);
628 | RegisterComponents();
629 | }
630 |
631 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y)
632 | {
633 | latent = diffusion_model.forward(latent, context, time, y);
634 | return latent;
635 | }
636 | }
637 |
638 | private readonly Model model;
639 |
640 | public SDXLUnet(int model_channels, int in_channels, int num_heads = 20, int context_dim = 2048, int adm_in_channels = 2816, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet))
641 | {
642 | model = new Model(model_channels, in_channels, context_dim: context_dim, adm_in_channels: adm_in_channels, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype);
643 | RegisterComponents();
644 | }
645 |
646 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y)
647 | {
648 | Device device = model.parameters().First().device;
649 | ScalarType dtype = model.parameters().First().dtype;
650 |
651 | latent = latent.to(dtype, device);
652 | time = time.to(dtype, device);
653 | y = y.to(dtype, device);
654 | context = context.to(dtype, device);
655 |
656 | latent = model.forward(latent, context, time, y);
657 | return latent;
658 | }
659 |
660 | /// <summary>
661 | /// This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
662 | /// </summary>
663 | /// <param name="timesteps">A 1-D Tensor of N indices, one per batch element. These may be fractional.</param>
664 | /// <param name="embedding_dim">The dimension of the output.</param>
665 | /// <param name="flip_sin_to_cos">Whether the embedding order should be `cos, sin` (if true) or `sin, cos` (if false).</param>
666 | /// <param name="downscale_freq_shift">Controls the delta between frequencies between dimensions.</param>
667 | /// <param name="scale">Scaling factor applied to the embeddings.</param>
668 | /// <param name="max_period">Controls the maximum frequency of the embeddings.</param>
669 | /// <returns>torch.Tensor: an [N x dim] Tensor of positional embeddings.</returns>
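    | /// <remarks>
    | /// With half_dim = embedding_dim / 2, the frequencies below are
    | /// freq_i = exp(-ln(max_period) * i / (half_dim - downscale_freq_shift)) for i in [0, half_dim),
    | /// and the output concatenates sin(t * freq) and cos(t * freq) (optionally flipped).
    | /// </remarks>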
670 | private static Tensor get_timestep_embedding(Tensor timesteps, int embedding_dim, bool flip_sin_to_cos = false, float downscale_freq_shift = 1, float scale = 1, int max_period = 10000)
671 | {
672 | using (NewDisposeScope())
673 | {
674 | if (timesteps.Dimensions != 1)
675 | {
676 | throw new ArgumentException("timesteps should be a 1-D tensor", nameof(timesteps));
677 | }
678 | int half_dim = embedding_dim / 2;
679 | Tensor exponent = -Math.Log(max_period) * torch.arange(start: 0, stop: half_dim, dtype: torch.float32, device: timesteps.device);
680 | exponent = exponent / (half_dim - downscale_freq_shift);
681 | Tensor emb = torch.exp(exponent);
682 | emb = timesteps[.., TensorIndex.None].@float() * emb[TensorIndex.None, ..];
683 |
684 | // scale embeddings
685 | emb = scale * emb;
686 |
687 | // concat sine and cosine embeddings
688 | emb = torch.cat(new Tensor[] { torch.sin(emb), torch.cos(emb) }, dim: -1);
689 |
690 | // flip sine and cosine embeddings
691 | if (flip_sin_to_cos)
692 | {
693 | emb = torch.cat(new Tensor[] { emb[.., half_dim..], emb[.., ..half_dim] }, dim: -1);
694 | }
695 |
696 | // zero pad
697 | if (embedding_dim % 2 == 1)
698 | {
699 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0));
700 | }
701 | return emb.MoveToOuterDisposeScope();
702 | }
703 | }
704 |
705 | }
706 | }
707 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/VAE.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using TorchSharp.Modules;
3 | using static TorchSharp.torch;
4 | using static TorchSharp.torch.nn;
5 |
6 | namespace StableDiffusionSharp.Modules
7 | {
8 | internal class VAE
9 | {
10 | private static GroupNorm Normalize(int in_channels, int num_groups = 32, float eps = 1e-6f, bool affine = true, Device? device = null, ScalarType? dtype = null)
11 | {
12 | return GroupNorm(num_groups: num_groups, num_channels: in_channels, eps: eps, affine: affine, device: device, dtype: dtype);
13 | }
14 |
15 | private class ResnetBlock : Module<Tensor, Tensor>
16 | {
17 | private readonly int in_channels;
18 | private readonly int out_channels;
19 | private readonly GroupNorm norm1;
20 | private readonly Conv2d conv1;
21 | private readonly GroupNorm norm2;
22 | private readonly Conv2d conv2;
23 | private readonly Module<Tensor, Tensor> nin_shortcut;
24 | private readonly SiLU swish;
25 |
26 | public ResnetBlock(int in_channels, int out_channels, Device? device = null, ScalarType? dtype = null) : base(nameof(ResnetBlock))
27 | {
28 | this.in_channels = in_channels;
29 | this.out_channels = out_channels;
30 | norm1 = Normalize(in_channels, device: device, dtype: dtype);
31 | conv1 = Conv2d(in_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype);
32 | norm2 = Normalize(out_channels, device: device, dtype: dtype);
33 | conv2 = Conv2d(out_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype);
34 |
35 | if (this.in_channels != this.out_channels)
36 | {
37 | nin_shortcut = Conv2d(in_channels: in_channels, out_channels: out_channels, kernel_size: 1, device: device, dtype: dtype);
38 | }
39 | else
40 | {
41 | nin_shortcut = Identity();
42 | }
43 |
44 | swish = SiLU(inplace: true);
45 | RegisterComponents();
46 | }
47 |
48 | public override Tensor forward(Tensor x)
49 | {
50 | Tensor hidden = x;
51 | hidden = norm1.forward(hidden);
52 | hidden = swish.forward(hidden);
53 | hidden = conv1.forward(hidden);
54 | hidden = norm2.forward(hidden);
55 | hidden = swish.forward(hidden);
56 | hidden = conv2.forward(hidden);
57 | if (in_channels != out_channels)
58 | {
59 | x = nin_shortcut.forward(x);
60 | }
61 | return x + hidden;
62 | }
63 | }
64 |
65 | private class AttnBlock : Module<Tensor, Tensor>
66 | {
67 | private readonly GroupNorm norm;
68 | private readonly Conv2d q;
69 | private readonly Conv2d k;
70 | private readonly Conv2d v;
71 | private readonly Conv2d proj_out;
72 |
73 | public AttnBlock(int in_channels, Device? device = null, ScalarType? dtype = null) : base(nameof(AttnBlock))
74 | {
75 | norm = Normalize(in_channels, device: device, dtype: dtype);
76 | q = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype);
77 | k = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype);
78 | v = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype);
79 | proj_out = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype);
80 | RegisterComponents();
81 | }
82 |
83 | public override Tensor forward(Tensor x)
84 | {
85 | using (NewDisposeScope())
86 | {
87 | var hidden = norm.forward(x);
88 | var q = this.q.forward(hidden);
89 | var k = this.k.forward(hidden);
90 | var v = this.v.forward(hidden);
91 |
92 | var (b, c, h, w) = (q.size(0), q.size(1), q.size(2), q.size(3));
93 |
94 | q = q.view(b, 1, h * w, c).contiguous();
95 | k = k.view(b, 1, h * w, c).contiguous();
96 | v = v.view(b, 1, h * w, c).contiguous();
97 |
98 | hidden = functional.scaled_dot_product_attention(q, k, v); // scale_factor is dim ** -0.5 per default
99 |
100 | hidden = hidden.view(b, c, h, w).contiguous();
101 | hidden = proj_out.forward(hidden);
102 |
103 | return (x + hidden).MoveToOuterDisposeScope();
104 | }
105 | }
106 |
107 | }
108 |
109 | private class Downsample : Module<Tensor, Tensor>
110 | {
111 | private readonly Conv2d? conv;
112 | private readonly bool with_conv;
113 |
114 | public Downsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Downsample))
115 | {
116 | this.with_conv = with_conv;
117 | if (with_conv)
118 | {
119 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, stride: 2, device: device, dtype: dtype);
120 |
121 | }
122 | RegisterComponents();
123 | }
124 |
125 | public override Tensor forward(Tensor x)
126 | {
127 | if (with_conv && conv != null)
128 | {
129 | long[] pad = new long[] { 0, 1, 0, 1 };
130 | x = functional.pad(x, pad, mode: PaddingModes.Constant, value: 0);
131 | x = conv.forward(x);
132 | }
133 | else
134 | {
135 | x = functional.avg_pool2d(x, kernel_size: 2, stride: 2);
136 | }
137 | return x;
138 | }
139 | }
140 |
141 | private class Upsample : Module<Tensor, Tensor>
142 | {
143 | private readonly Conv2d? conv;
144 | private readonly bool with_conv;
145 | public Upsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Upsample))
146 | {
147 | this.with_conv = with_conv;
148 | if (with_conv)
149 | {
150 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype);
151 | }
152 | RegisterComponents();
153 | }
154 | public override Tensor forward(Tensor x)
155 | {
156 | var output = functional.interpolate(x, scale_factor: new double[] { 2.0, 2.0 }, mode: InterpolationMode.Nearest);
157 | if (with_conv && conv != null)
158 | {
159 | output = conv.forward(output);
160 | }
161 | return output;
162 | }
163 | }
164 |
165 | private class VAEEncoder : Module<Tensor, Tensor>
166 | {
167 | private readonly int num_resolutions;
168 | private readonly int num_res_blocks;
169 | private readonly Conv2d conv_in;
170 | private readonly List<int> in_ch_mult;
171 | private readonly Sequential down;
172 | private readonly Sequential mid;
173 | private readonly GroupNorm norm_out;
174 | private readonly Conv2d conv_out;
175 | private readonly SiLU swish;
176 | private readonly int block_in;
177 | private readonly bool double_z;
178 |
179 |
180 | public VAEEncoder(int ch = 128, int[]? ch_mult = null, int num_res_blocks = 2, int in_channels = 3, int z_channels = 16, bool double_z = true, Device? device = null, ScalarType? dtype = null) : base(nameof(VAEEncoder))
181 | {
182 | this.double_z = double_z;
183 | ch_mult ??= new int[] { 1, 2, 4, 4 };
184 | num_resolutions = ch_mult.Length;
185 | this.num_res_blocks = num_res_blocks;
186 |
187 | // Input convolution
188 | conv_in = Conv2d(in_channels, ch, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype);
189 |
190 | // Downsampling layers
191 | in_ch_mult = new List<int> { 1 };
192 | in_ch_mult.AddRange(ch_mult);
193 | down = Sequential();
194 |
195 | block_in = ch * in_ch_mult[0];
196 |
197 | for (int i_level = 0; i_level < num_resolutions; i_level++)
198 | {
199 | var block = Sequential();
200 | var attn = Sequential();
201 | int block_out = ch * ch_mult[i_level];
202 | block_in = ch * in_ch_mult[i_level];
203 | for (int _ = 0; _ < num_res_blocks; _++)
204 | {
205 | block.append(new ResnetBlock(block_in, block_out, device: device, dtype: dtype));
206 | block_in = block_out;
207 | }
208 |
209 | var d = Sequential(
210 | ("block", block),
211 | ("attn", attn));
212 |
213 | if (i_level != num_resolutions - 1)
214 | {
215 | d.append("downsample", new Downsample(block_in, device: device, dtype: dtype));
216 | }
217 | down.append(d);
218 | }
219 |
220 | // Middle layers
221 | mid = Sequential(
222 | ("block_1", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)),
223 | ("attn_1", new AttnBlock(block_in, device: device, dtype: dtype)),
224 | ("block_2", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)));
225 |
226 |
227 | // Output layers
228 | norm_out = Normalize(block_in, device: device, dtype: dtype);
229 | conv_out = Conv2d(block_in, (double_z ? 2 : 1) * z_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype);
230 | swish = SiLU(inplace: true);
231 |
232 | RegisterComponents();
233 | }
234 |
235 | public override Tensor forward(Tensor x)
236 | {
237 | using var _ = NewDisposeScope();
238 |
239 | // Downsampling
240 | var h = conv_in.forward(x);
241 |
242 | h = down.forward(h);
243 |
244 | // Middle layers
245 | h = mid.forward(h);
246 |
247 | // Output layers
248 | h = norm_out.forward(h);
249 | h = swish.forward(h);
250 | h = conv_out.forward(h);
251 | return h.MoveToOuterDisposeScope();
252 | }
253 | }
254 |
255 | private class VAEDecoder : Module<Tensor, Tensor>
256 | {
257 | private readonly int num_resolutions;
258 | private readonly int num_res_blocks;
259 |
260 | private readonly Conv2d conv_in;
261 | private readonly Sequential mid;
262 |
263 | private readonly Sequential up;
264 |
265 | private readonly GroupNorm norm_out;
266 | private readonly Conv2d conv_out;
267 | private readonly SiLU swish; // swish == SiLU, matching the encoder and the reference VAE decoder
268 |
269 | public VAEDecoder(int ch = 128, int out_ch = 3, int[]? ch_mult = null, int num_res_blocks = 2, int resolution = 256, int z_channels = 16, Device? device = null, ScalarType? dtype = null) : base(nameof(VAEDecoder))
270 | {
271 | ch_mult ??= new int[] { 1, 2, 4, 4 };
272 | num_resolutions = ch_mult.Length;
273 | this.num_res_blocks = num_res_blocks;
274 | int block_in = ch * ch_mult[num_resolutions - 1];
275 |
276 | int curr_res = resolution / (int)Math.Pow(2, num_resolutions - 1);
277 | // z to block_in
278 | conv_in = Conv2d(z_channels, block_in, kernel_size: 3, padding: 1, device: device, dtype: dtype);
279 |
280 | // middle
281 | mid = Sequential(
282 | ("block_1", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)),
283 | ("attn_1", new AttnBlock(block_in, device: device, dtype: dtype)),
284 | ("block_2", new ResnetBlock(block_in, block_in, device: device, dtype: dtype))
285 | );
286 |
287 | // upsampling
288 | up = Sequential();
289 |
290 | List<Module<Tensor, Tensor>> list = new List<Module<Tensor, Tensor>>();
291 | for (int i_level = num_resolutions - 1; i_level >= 0; i_level--)
292 | {
293 | var block = Sequential();
294 |
295 | int block_out = ch * ch_mult[i_level];
296 |
297 | for (int i_block = 0; i_block < num_res_blocks + 1; i_block++)
298 | {
299 | block.append(new ResnetBlock(block_in, block_out, device: device, dtype: dtype));
300 | block_in = block_out;
301 | }
302 |
303 | Sequential u = Sequential(("block", block));
304 |
305 | if (i_level != 0)
306 | {
307 | u.append("upsample", new Upsample(block_in, device: device, dtype: dtype));
308 | curr_res *= 2;
309 | }
310 | //this.up.append(u);
311 | list.Insert(0, u);
312 | }
313 |
314 | up = Sequential(list);
315 |
316 | // end
317 | norm_out = Normalize(block_in, device: device, dtype: dtype);
318 | conv_out = Conv2d(block_in, out_ch, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype);
319 | swish = SiLU(inplace: true);
320 | RegisterComponents();
321 | }
322 |
323 | public override Tensor forward(Tensor z)
324 | {
325 | // z to block_in
326 | Tensor hidden = conv_in.forward(z);
327 |
328 | // middle
329 | hidden = mid.forward(hidden);
330 |
331 | // upsampling
332 | foreach (Module<Tensor, Tensor> md in up.children().Reverse())
333 | {
334 | hidden = md.forward(hidden);
335 | }
336 |
337 | // end
338 | hidden = norm_out.forward(hidden);
339 | hidden = swish.forward(hidden);
340 | hidden = conv_out.forward(hidden);
341 | return hidden;
342 | }
343 | }
344 |
345 | internal class Decoder : Module<Tensor, Tensor>
346 | {
347 | private Sequential first_stage_model;
348 |
349 | public Decoder(int embed_dim = 4, int z_channels = 4, Device? device = null, ScalarType? dtype = null) : base(nameof(Decoder))
350 | {
351 | first_stage_model = Sequential(("post_quant_conv", Conv2d(embed_dim, z_channels, 1, device: device, dtype: dtype)), ("decoder", new VAEDecoder(z_channels: z_channels, device: device, dtype: dtype)));
352 | RegisterComponents();
353 | }
354 |
355 | public override Tensor forward(Tensor latents)
356 | {
357 | Device device = first_stage_model.parameters().First().device;
358 | ScalarType dtype = first_stage_model.parameters().First().dtype;
359 | latents = latents.to(dtype, device);
360 | return first_stage_model.forward(latents);
361 | }
362 | }
363 |
364 | internal class Encoder : Module<Tensor, Tensor>
365 | {
366 | private Sequential first_stage_model;
367 | public Encoder(int embed_dim = 4, int z_channels = 4, bool double_z = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Encoder))
368 | {
369 | int factor = double_z ? 2 : 1;
370 | first_stage_model = Sequential(("encoder", new VAEEncoder(z_channels: z_channels, device: device, dtype: dtype)), ("quant_conv", Conv2d(factor * embed_dim, factor * z_channels, 1, device: device, dtype: dtype)));
371 | RegisterComponents();
372 | }
373 |
374 | public override Tensor forward(Tensor input)
375 | {
376 | Device device = first_stage_model.parameters().First().device;
377 | ScalarType dtype = first_stage_model.parameters().First().dtype;
378 | input = input.to(dtype, device);
379 | return first_stage_model.forward(input);
380 | }
381 | }
382 | }
383 | }
384 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Modules/VAEApprox.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using TorchSharp.Modules;
3 | using static TorchSharp.torch;
4 | using static TorchSharp.torch.nn;
5 |
6 | namespace StableDiffusionSharp.Modules
7 | {
8 | internal class VAEApprox : Module<Tensor, Tensor>
9 | {
10 | private readonly Conv2d conv1;
11 | private readonly Conv2d conv2;
12 | private readonly Conv2d conv3;
13 | private readonly Conv2d conv4;
14 | private readonly Conv2d conv5;
15 | private readonly Conv2d conv6;
16 | private readonly Conv2d conv7;
17 | private readonly Conv2d conv8;
18 |
19 | internal VAEApprox(int latent_channels = 4, Device? device = null, ScalarType? dtype = null) : base(nameof(VAEApprox))
20 | {
21 | string vaeSD15ApproxPath = @".\models\vaeapprox\vaeapp_sd15.pth";
22 | string vaeSDXLApproxPath = @".\models\vaeapprox\xlvaeapp.pth";
23 | string path = Path.GetDirectoryName(vaeSD15ApproxPath)!;
24 | if (!Directory.Exists(path))
25 | {
26 | Directory.CreateDirectory(path);
27 | }
28 | Assembly _assembly = Assembly.GetExecutingAssembly();
29 | if (!File.Exists(vaeSD15ApproxPath))
30 | {
31 | string sd15ResourceName = "StableDiffusionSharp.Models.VAEApprox.vaeapp_sd15.pth";
32 | using (Stream stream = _assembly.GetManifestResourceStream(sd15ResourceName)!)
33 | {
34 | if (stream == null)
35 | {
36 | Console.WriteLine("Resource can't find!");
37 | return;
38 | }
39 | using (FileStream fileStream = new FileStream(vaeSD15ApproxPath, FileMode.Create, FileAccess.Write))
40 | {
41 | stream.CopyTo(fileStream);
42 | }
43 | }
44 | }
45 | if (!File.Exists(vaeSDXLApproxPath))
46 | {
47 | string sdxlResourceName = "StableDiffusionSharp.Models.VAEApprox.xlvaeapp.pth";
48 | using (Stream stream = _assembly.GetManifestResourceStream(sdxlResourceName)!)
49 | {
50 | if (stream == null)
51 | {
52 | Console.WriteLine("Resource can't find!");
53 | return;
54 | }
55 | using (FileStream fileStream = new FileStream(vaeSDXLApproxPath, FileMode.Create, FileAccess.Write))
56 | {
57 | stream.CopyTo(fileStream);
58 | }
59 | }
60 | }
61 |
62 | conv1 = Conv2d(latent_channels, 8, (7, 7), device: device, dtype: dtype);
63 | conv2 = Conv2d(8, 16, (5, 5), device: device, dtype: dtype);
64 | conv3 = Conv2d(16, 32, (3, 3), device: device, dtype: dtype);
65 | conv4 = Conv2d(32, 64, (3, 3), device: device, dtype: dtype);
66 | conv5 = Conv2d(64, 32, (3, 3), device: device, dtype: dtype);
67 | conv6 = Conv2d(32, 16, (3, 3), device: device, dtype: dtype);
68 | conv7 = Conv2d(16, 8, (3, 3), device: device, dtype: dtype);
69 | conv8 = Conv2d(8, 3, (3, 3), device: device, dtype: dtype);
70 | RegisterComponents();
71 | }
72 |
73 | public override Tensor forward(Tensor x)
74 | {
75 | using (NewDisposeScope())
76 | {
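    | // extra = 11 offsets the shrinkage of the eight unpadded convolutions below:
    | // (7 - 1) + (5 - 1) + 6 * (3 - 1) = 22 px in total, i.e. 11 px per side.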
77 | int extra = 11;
78 | x = functional.interpolate(x, new long[] { x.shape[2] * 2, x.shape[3] * 2 });
79 | x = functional.pad(x, (extra, extra, extra, extra));
80 |
81 | foreach (var layer in ModuleList(new Conv2d[] { conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8 }))
82 | {
83 | x = layer.forward(x);
84 | x = functional.leaky_relu(x, 0.1);
85 | }
86 | return x.MoveToOuterDisposeScope();
87 | }
88 | }
89 |
90 | public enum SharedModel
91 | {
92 | SD3,
93 | SDXL,
94 | SD
95 | }
96 | }
97 | }
98 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/SDType.cs:
--------------------------------------------------------------------------------
1 | namespace StableDiffusionSharp
2 | {
3 | public enum SDScalarType
4 | {
5 | Float16 = 5,
6 | Float32 = 6,
7 | BFloat16 = 15,
8 | }
9 |
10 | public enum SDDeviceType
11 | {
12 | CPU = 0,
13 | CUDA = 1,
14 | }
15 |
16 | public enum SDSamplerType
17 | {
18 | EulerAncestral = 0,
19 | Euler = 1,
20 | }
21 |
22 | public enum ModelType
23 | {
24 | SD1,
25 | SD2,
26 | SD3,
27 | SDXL,
28 | FLUX,
29 | }
30 |
31 | public enum TimestepSpacing
32 | {
33 | Linspace,
34 | Leading,
35 | Trailing,
36 | }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/StableDiffusionSharp/Sampler/BasicSampler.cs:
--------------------------------------------------------------------------------
1 | using System.Linq;
2 | using TorchSharp;
3 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.Sampler
5 | {
6 | public abstract class BasicSampler
7 | {
8 | public Tensor Sigmas;
9 | internal Tensor Timesteps;
10 | private Scheduler.DiscreteSchedule schedule;
11 | private readonly TimestepSpacing timestepSpacing;
12 |
13 | public BasicSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1, TimestepSpacing timestepSpacing = TimestepSpacing.Leading)
14 | {
15 | this.timestepSpacing = timestepSpacing;
16 | Tensor betas = GetBetaSchedule(beta_start, beta_end, num_train_timesteps);
17 | Tensor alphas = 1.0f - betas;
18 | Tensor alphas_cumprod = torch.cumprod(alphas, 0);
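    | // DDPM betas -> cumulative alphas -> k-diffusion style sigmas:
    | // sigma_t = sqrt((1 - alphabar_t) / alphabar_t), the noise-to-signal ratio.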
19 | this.Sigmas = torch.pow((1.0f - alphas_cumprod) / alphas_cumprod, 0.5f);
20 | }
21 |
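    | // Initial latent scaling: with Linspace/Trailing spacing the latent is scaled
    | // by sigma_max directly; with Leading spacing the model input is later divided
    | // by sqrt(sigma^2 + 1) (see ScaleModelInput), hence sqrt(sigma_max^2 + 1) here.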
22 | public Tensor InitNoiseSigma()
23 | {
24 | if (timestepSpacing == TimestepSpacing.Linspace || timestepSpacing == TimestepSpacing.Trailing)
25 | {
26 | return Sigmas.max();
27 | }
28 | return torch.sqrt(torch.pow(Sigmas.max(), 2) + 1);
29 | }
30 |
31 | public Tensor ScaleModelInput(Tensor sample, int step_index)
32 | {
33 | Tensor sigma = Sigmas[step_index];
34 | return sample / torch.sqrt(torch.pow(sigma, 2) + 1);
35 | }
36 |
37 | /// <summary>
38 | /// Get the scalings for the given step index
39 | /// </summary>
40 | /// <param name="step_index">Step index into Sigmas</param>
41 | /// <returns>Tensor c_out, Tensor c_in</returns>
42 | public (Tensor, Tensor) GetScalings(int step_index)
43 | {
44 | Tensor sigma = Sigmas[step_index];
45 | Tensor c_out = -sigma;
46 | Tensor c_in = 1 / torch.sqrt(torch.pow(sigma, 2) + 1);
47 | return (c_out, c_in);
48 | }
49 | public Tensor append_dims(Tensor x, long target_dims)
50 | {
51 | long dims_to_append = target_dims - x.ndim;
52 | if (dims_to_append < 0)
53 | {
54 | throw new ArgumentException("target_dims must be greater than or equal to x.ndim");
55 | }
56 | // Append trailing singleton dims: e.g. shape [2, 3] with target_dims = 4 becomes a view of shape [2, 3, 1, 1].
57 | long[] dims = x.shape.Concat(Enumerable.Repeat(1L, (int)dims_to_append)).ToArray();
58 | return x.view(dims);
59 | }
63 |
64 |
65 |
66 | public void SetTimesteps(long num_inference_steps)
67 | {
68 | if (num_inference_steps < 1)
69 | {
70 | throw new ArgumentException("num_inference_steps must be greater than 0");
71 | }
72 | //long t_max = Sigmas.NumberOfElements - 1;
73 | //this.Timesteps = torch.linspace(t_max, 0, num_inference_steps);
74 | this.Timesteps = GetTimeSteps(Sigmas.NumberOfElements, num_inference_steps, timestepSpacing);
75 | schedule = new Scheduler.DiscreteSchedule(Sigmas);
76 | this.Sigmas = append_zero(schedule.t_to_sigma(this.Timesteps));
77 | }
78 |
79 | private Tensor GetTimeSteps(double t_max, long num_steps, TimestepSpacing timestepSpacing)
80 | {
81 | if (timestepSpacing == TimestepSpacing.Linspace)
82 | {
83 | return torch.linspace(t_max - 1, 0, num_steps);
84 | }
85 | else if (timestepSpacing == TimestepSpacing.Leading)
86 | {
87 | long step_ratio = (long)t_max / num_steps;
88 | return torch.linspace(t_max - step_ratio, 0, num_steps) + 1;
89 | }
90 | else
91 | {
92 | long step_ratio = (long)t_max / num_steps;
93 | return torch.arange(t_max, 0, -step_ratio).round() - 1;
94 | }
95 | }
96 |
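97 | // Worked example with t_max = 1000 and num_steps = 20:
98 | //   Linspace: linspace(999, 0, 20)              -> 999.0, 946.4, ..., 0.0
99 | //   Leading:  linspace(950, 0, 20) + 1          -> 951, 901, ..., 1
100 | //   Trailing: arange(1000, 0, -50).round() - 1  -> 999, 949, ..., 49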
97 |
98 | public virtual Tensor Step(Tensor model_output, int step_index, Tensor sample, long seed = 0, float s_churn = 0.0f, float s_tmin = 0.0f, float s_tmax = float.PositiveInfinity, float s_noise = 1.0f)
99 | {
100 | // Base implementation: identical to EulerSampler.Step, so plain Euler is the default.
101 | sample = sample.to(model_output.dtype, model_output.device);
102 | Generator generator = torch.manual_seed(seed);
103 | torch.set_rng_state(generator.get_state());
104 | float sigma = Sigmas[step_index].ToSingle();
105 | float gamma = s_tmin <= sigma && sigma <= s_tmax ? (float)Math.Min(s_churn / (Sigmas.NumberOfElements - 1f), Math.Sqrt(2.0f) - 1.0f) : 0f;
106 | Tensor epsilon = torch.randn_like(model_output) * s_noise;
107 | float sigma_hat = sigma * (gamma + 1);
108 | if (gamma > 0)
109 | {
110 | sample = sample + epsilon * (float)Math.Sqrt(Math.Pow(sigma_hat, 2f) - Math.Pow(sigma, 2f));
111 | }
112 | Tensor pred_original_sample = sample - sigma_hat * model_output; // denoised estimate; -sigma_hat plays the role of c_out in k-diffusion's to_d
113 | Tensor derivative = (sample - pred_original_sample) / sigma_hat;
114 | float dt = Sigmas[step_index + 1].ToSingle() - sigma_hat;
115 | return sample + derivative * dt;
116 | }
117 |
118 | private Tensor GetBetaSchedule(float beta_start, float beta_end, int num_train_timesteps)
119 | {
120 | return torch.pow(torch.linspace(Math.Pow(beta_start, 0.5), Math.Pow(beta_end, 0.5), num_train_timesteps, ScalarType.Float32), 2);
121 | }
122 |
123 | private static Tensor append_zero(Tensor x)
124 | {
125 | return torch.cat(new Tensor[] { x, x.new_zeros(1) });
126 | }
127 | }
128 | }
129 |
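130 | // Hedged usage sketch (not part of the original file). `unet`, `cond` and
131 | // `shape` are hypothetical placeholders for the caller's model and inputs:
132 | //
133 | //   var sampler = new EulerSampler();
134 | //   sampler.SetTimesteps(20);                     // builds Timesteps and the resampled Sigmas (plus a trailing zero)
135 | //   Tensor latents = torch.randn(shape) * sampler.InitNoiseSigma();
136 | //   for (int i = 0; i < 20; i++)
137 | //   {
138 | //       Tensor scaled = sampler.ScaleModelInput(latents, i);  // c_in: divide by sqrt(sigma^2 + 1)
139 | //       Tensor eps = unet.forward(scaled, sampler.Timesteps[i], cond);
140 | //       latents = sampler.Step(eps, i, latents, seed: 0);
141 | //   }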
--------------------------------------------------------------------------------
/StableDiffusionSharp/Sampler/EulerAncestralSampler.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.Sampler
5 | {
6 | internal class EulerAncestralSampler : BasicSampler
7 | {
8 | public EulerAncestralSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1) : base(num_train_timesteps, beta_start, beta_end, steps_offset)
9 | {
10 |
11 | }
12 | public override torch.Tensor Step(torch.Tensor model_output, int step_index, torch.Tensor sample, long seed = 0, float s_churn = 0, float s_tmin = 0, float s_tmax = float.PositiveInfinity, float s_noise = 1)
13 | {
14 | sample = sample.to(model_output.dtype, model_output.device);
15 | Generator generator = torch.manual_seed(seed);
16 | torch.set_rng_state(generator.get_state());
17 |
18 | float sigma = base.Sigmas[step_index].ToSingle();
19 |
20 | Tensor predOriginalSample = sample - model_output * sigma;
21 | Tensor sigmaFrom = base.Sigmas[step_index];
22 | Tensor sigmaTo = base.Sigmas[step_index + 1];
23 | Tensor sigmaFromLessSigmaTo = torch.pow(sigmaFrom, 2) - torch.pow(sigmaTo, 2);
24 | Tensor sigmaUpResult = torch.pow(sigmaTo, 2) * sigmaFromLessSigmaTo / torch.pow(sigmaFrom, 2);
25 |
26 | Tensor sigmaUp = sigmaUpResult.ToSingle() < 0 ? -torch.pow(torch.abs(sigmaUpResult), 0.5f) : torch.pow(sigmaUpResult, 0.5f);
27 | Tensor sigmaDownResult = torch.pow(sigmaTo, 2) - torch.pow(sigmaUp, 2);
28 | Tensor sigmaDown = sigmaDownResult.ToSingle() < 0 ? -torch.pow(torch.abs(sigmaDownResult), 0.5f) : torch.pow(sigmaDownResult, 0.5f);
29 | Tensor derivative = (sample - predOriginalSample) / sigma; // k-diffusion's to_d: back from the denoised estimate to a derivative in sigma
30 | Tensor delta = sigmaDown - sigma;
31 | Tensor prevSample = sample + derivative * delta;
32 | var noise = torch.randn_like(prevSample);
33 | prevSample = prevSample + noise * sigmaUp;
34 | return prevSample;
35 | }
36 | }
37 | }
38 |
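39 | // Math recap for Step above (mirrors k-diffusion's get_ancestral_step):
40 | //   sigma_up   = sqrt(sigma_to^2 * (sigma_from^2 - sigma_to^2) / sigma_from^2)
41 | //   sigma_down = sqrt(sigma_to^2 - sigma_up^2)
42 | // The deterministic Euler move steps from sigma down to sigma_down, then fresh
43 | // noise scaled by sigma_up is mixed back in, which is what makes the sampler
44 | // "ancestral"; the sign-flip guards above only matter if a schedule ever
45 | // produced a negative radicand.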
--------------------------------------------------------------------------------
/StableDiffusionSharp/Sampler/EulerSampler.cs:
--------------------------------------------------------------------------------
1 | using TorchSharp;
2 | using static TorchSharp.torch;
3 |
4 | namespace StableDiffusionSharp.Sampler
5 | {
6 | internal class EulerSampler : BasicSampler
7 | {
8 | public EulerSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1) : base(num_train_timesteps, beta_start, beta_end, steps_offset)
9 | {
10 |
11 | }
12 |
13 | public override torch.Tensor Step(torch.Tensor model_output, int step_index, torch.Tensor sample, long seed = 0, float s_churn = 0, float s_tmin = 0, float s_tmax = float.PositiveInfinity, float s_noise = 1)
14 | {
15 | sample = sample.to(model_output.dtype, model_output.device);
16 | Generator generator = torch.manual_seed(seed);
17 | torch.set_rng_state(generator.get_state());
18 | float sigma = base.Sigmas[step_index].ToSingle();
19 | float gamma = s_tmin <= sigma && sigma <= s_tmax ? (float)Math.Min(s_churn / (Sigmas.NumberOfElements - 1f), Math.Sqrt(2.0f) - 1.0f) : 0f;
20 | Tensor noise = torch.randn_like(model_output);
21 | Tensor epsilon = noise * s_noise;
22 | float sigma_hat = sigma * (gamma + 1.0f);
23 | if (gamma > 0)
24 | {
25 | sample = sample + epsilon * (float)Math.Sqrt(Math.Pow(sigma_hat, 2f) - Math.Pow(sigma, 2f));
26 | }
27 | Tensor pred_original_sample = sample - sigma_hat * model_output; // denoised estimate; -sigma_hat plays the role of c_out in k-diffusion's to_d
28 | Tensor derivative = (sample - pred_original_sample) / sigma_hat;
29 | Tensor dt = Sigmas[step_index + 1] - sigma_hat;
30 | return sample + derivative * dt;
31 | }
32 | }
33 | }
34 |
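35 | // Note on the churn parameters (Karras et al. 2022, Algorithm 2): when sigma
36 | // lies inside [s_tmin, s_tmax], the step first inflates the noise level to
37 | // sigma_hat = sigma * (1 + gamma), adding sqrt(sigma_hat^2 - sigma^2) worth of
38 | // fresh noise, then takes a plain Euler step from sigma_hat toward the next
39 | // sigma. With the default s_churn = 0, gamma is 0 and this reduces to
40 | // deterministic Euler.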
--------------------------------------------------------------------------------
/StableDiffusionSharp/Scheduler/DiscreteSchedule.cs:
--------------------------------------------------------------------------------
1 | using static TorchSharp.torch;
2 | using static TorchSharp.torch.nn;
3 |
4 | namespace StableDiffusionSharp.Scheduler
5 | {
6 | internal class DiscreteSchedule : Module
7 | {
8 | private Tensor sigmas;
9 | private Tensor log_sigmas;
10 | private bool quantize;
11 |
12 | public DiscreteSchedule(Tensor sigmas, bool quantize = false) : base(nameof(DiscreteSchedule))
13 | {
14 | this.sigmas = sigmas;
15 | log_sigmas = sigmas.log();
16 | this.quantize = quantize;
17 | RegisterComponents();
18 | }
19 |
20 | public Tensor sigma_min => sigmas.min();
21 | public Tensor sigma_max => sigmas.max();
22 |
23 | public Tensor t_to_sigma(Tensor t)
24 | {
25 | t = t.@float();
26 | Tensor low_idx = t.floor().@long();
27 | Tensor high_idx = t.ceil().@long();
28 | Tensor w = t.frac();
29 | Tensor log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
30 | return log_sigma.exp();
31 | }
32 |
33 | public Tensor sigma_to_t(Tensor sigma, bool? quantize = null)
34 | {
35 | quantize = quantize ?? this.quantize;
36 | Tensor log_sigma = sigma.log();
37 | Tensor dists = log_sigma - log_sigmas[.., TensorIndex.None];
38 |
39 | if (quantize == true)
40 | {
41 | return dists.abs().argmin(dim: 0).view(sigma.shape);
42 | }
43 |
44 | Tensor low_idx = dists.ge(0).cumsum(dim: 0).argmax(dim: 0).clamp(max: log_sigmas.shape[0] - 2);
45 | Tensor high_idx = low_idx + 1;
46 | var (low, high) = (log_sigmas[low_idx], log_sigmas[high_idx]);
47 | Tensor w = (low - log_sigma) / (low - high);
48 | w = w.clamp(0, 1);
49 | Tensor t = (1 - w) * low_idx + w * high_idx;
50 | return t.view(sigma.shape);
51 | }
52 | }
53 | }
54 |
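55 | // t_to_sigma interpolates linearly in log-sigma space: for a fractional timestep
56 | // t = 3.25 it returns exp(0.75 * log_sigmas[3] + 0.25 * log_sigmas[4]);
57 | // sigma_to_t is the inverse mapping, clamping to the valid index range.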
--------------------------------------------------------------------------------
/StableDiffusionSharp/StableDiffusion.cs:
--------------------------------------------------------------------------------
1 | using StableDiffusionSharp.Modules;
2 | using TorchSharp;
3 | using static TorchSharp.torch;
4 |
5 | namespace StableDiffusionSharp
6 | {
7 | public class StableDiffusion : nn.Module
8 | {
9 | private SDModel model;
10 | private readonly Device device;
11 | private readonly ScalarType dtype;
12 |
13 | public class StepEventArgs : EventArgs
14 | {
15 | public int CurrentStep { get; }
16 | public int TotalSteps { get; }
17 |
18 | public ImageMagick.MagickImage VaeApproxImg { get; }
19 |
20 | public StepEventArgs(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg)
21 | {
22 | CurrentStep = currentStep;
23 | TotalSteps = totalSteps;
24 | VaeApproxImg = vaeApproxImg;
25 | }
26 | }
27 |
28 | public event EventHandler<StepEventArgs>? StepProgress;
29 | protected void OnStepProgress(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg)
30 | {
31 | StepProgress?.Invoke(this, new StepEventArgs(currentStep, totalSteps, vaeApproxImg));
32 | }
33 |
34 | public StableDiffusion(SDDeviceType deviceType, SDScalarType scaleType) : base(nameof(StableDiffusion))
35 | {
36 | this.device = new Device((DeviceType)deviceType);
37 | this.dtype = (ScalarType)scaleType;
38 | }
39 |
40 | public void LoadModel(string modelPath, string vaeModelPath = "", string vocabPath = @".\models\clip\vocab.json", string mergesPath = @".\models\clip\merges.txt")
41 | {
42 | ModelType modelType = ModelLoader.ModelLoader.GetModelType(modelPath);
43 | Console.WriteLine($"Detected model type (best guess): {modelType}");
44 | model = modelType switch
45 | {
46 | ModelType.SD1 => new SD1(this.device, this.dtype),
47 | ModelType.SDXL => new SDXL(this.device, this.dtype),
48 | _ => throw new ArgumentException("Invalid model type")
49 | };
50 | model.LoadModel(modelPath, vaeModelPath, vocabPath, mergesPath);
51 | model.StepProgress += Model_StepProgress;
52 | }
53 |
54 | private void Model_StepProgress(object? sender, SDModel.StepEventArgs e)
55 | {
56 | OnStepProgress(e.CurrentStep, e.TotalSteps, e.VAEApproxImg);
57 | }
58 |
59 | public ImageMagick.MagickImage TextToImage(string prompt, string nprompt = "", long clip_skip = 0, int width = 512, int height = 512, int steps = 20, long seed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
60 | {
61 | return model.TextToImage(prompt, nprompt, clip_skip, width, height, steps, seed, cfg, samplerType);
62 | }
63 |
64 | public ImageMagick.MagickImage ImageToImage(ImageMagick.MagickImage orgImage, string prompt, string nprompt = "", long clip_skip = 0, int steps = 20, float strength = 0.75f, long seed = 0, long subSeed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
65 | {
66 | return model.ImageToImage(orgImage, prompt, nprompt, clip_skip, steps, strength, seed, subSeed, cfg, samplerType);
67 | }
68 |
69 | }
70 | }
71 |
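72 | // Hedged usage sketch (not from the original source); the model path and
73 | // prompt are placeholders:
74 | //
75 | //   var sd = new StableDiffusion(SDDeviceType.CUDA, SDScalarType.Float16);
76 | //   sd.LoadModel(@".\model.safetensors");
77 | //   sd.StepProgress += (s, e) => Console.WriteLine($"step {e.CurrentStep}/{e.TotalSteps}");
78 | //   using ImageMagick.MagickImage image = sd.TextToImage("a cat sitting on a sofa", steps: 20, seed: 42);
79 | //   image.Write("output.png");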
--------------------------------------------------------------------------------
/StableDiffusionSharp/StableDiffusionSharp.csproj:
--------------------------------------------------------------------------------
1 | <Project Sdk="Microsoft.NET.Sdk">
2 |
3 | <PropertyGroup>
4 | <TargetFramework>net6.0</TargetFramework>
5 | <ImplicitUsings>enable</ImplicitUsings>
6 | <Nullable>enable</Nullable>
7 | <RootNamespace>StableDiffusionSharp</RootNamespace>
8 | <Authors>IntptrMax</Authors>
9 | <Description>Use Stable Diffusion with C# with fast speed and less VRAM.
10 | Requires reference to one of libtorch-cpu, libtorch-cuda-12.1, libtorch-cuda-12.1-win-x64 or libtorch-cuda-12.1-linux-x64 version 2.5.1.0 to execute.</Description>
11 | <PackageProjectUrl>https://github.com/IntptrMax/StableDiffusionSharp</PackageProjectUrl>
12 | <PackageLicenseFile>LICENSE.txt</PackageLicenseFile>
13 | <Version>1.0.8</Version>
14 | <PackageReadmeFile>README.md</PackageReadmeFile>
15 | </PropertyGroup>
16 |
17 | <!-- NOTE: most of the original ItemGroup content (package references and packaged
18 | model/content files) is not recoverable here; only stray Pack/CopyToOutputDirectory
19 | values remain. Judging from the using directives, the project references at least
20 | TorchSharp, Magick.NET (ImageMagick) and SharpDX.DXGI. The license/readme items
21 | packed into the NuGet package plausibly looked like this: -->
22 | <ItemGroup>
23 | <None Include="..\LICENSE.txt">
24 | <CopyToOutputDirectory>Never</CopyToOutputDirectory>
25 | <Pack>True</Pack>
26 | <PackagePath>\</PackagePath>
27 | </None>
28 | <None Include="..\README.md">
29 | <CopyToOutputDirectory>Never</CopyToOutputDirectory>
30 | <Pack>True</Pack>
31 | <PackagePath>\</PackagePath>
32 | </None>
33 | </ItemGroup>
34 |
35 | </Project>
--------------------------------------------------------------------------------
/StableDiffusionSharp/Tools.cs:
--------------------------------------------------------------------------------
1 | using ImageMagick;
2 | using System.IO.Compression;
3 | using System.Text;
4 | using TorchSharp;
5 | using static TorchSharp.torch;
6 |
7 | namespace StableDiffusionSharp
8 | {
9 | internal class Tools
10 | {
11 | internal static Tensor GetTensorFromImage(MagickImage image)
12 | {
13 | using (MemoryStream memoryStream = new MemoryStream())
14 | {
15 | image.Write(memoryStream, MagickFormat.Png);
16 | memoryStream.Position = 0;
17 | return torchvision.io.read_image(memoryStream);
18 | }
19 | }
20 |
21 | public static MagickImage GetImageFromTensor(Tensor tensor)
22 | {
23 | MemoryStream memoryStream = new MemoryStream();
24 | torchvision.io.write_png(tensor.cpu(), memoryStream);
25 | memoryStream.Position = 0;
26 | return new MagickImage(memoryStream, MagickFormat.Png);
27 | }
28 |
29 | /// <summary>
30 | /// Load a Python .pt tensor file and move it to the dtype and device of the given tensor.
31 | /// </summary>
32 | /// <param name="path">tensor path</param>
33 | /// <param name="tensor">the given tensor</param>
34 | /// <returns>Tensor in TorchSharp</returns>
35 | public static Tensor LoadTensorFromPT(string path, Tensor tensor)
36 | {
37 | return LoadTensorFromPT(path).to(tensor.dtype, tensor.device);
38 | }
39 |
40 | /// <summary>
41 | /// Load a Python .pt tensor file.
42 | /// </summary>
43 | /// <param name="path">tensor path</param>
44 | /// <returns>Tensor in TorchSharp</returns>
45 | public static Tensor LoadTensorFromPT(string path)
46 | {
47 | torch.ScalarType dtype = torch.ScalarType.Float32;
48 | List<long> shape = new List<long>();
49 | using ZipArchive zip = ZipFile.OpenRead(path);
50 | ZipArchiveEntry headerEntry = zip.Entries.First(e => e.Name == "data.pkl");
51 |
52 | // The pickled header is small enough to buffer whole. Copy through a MemoryStream,
53 | // because a single Read on a ZipArchive stream may return fewer bytes than requested.
54 | byte[] headerBytes;
55 | using (Stream headerStream = headerEntry.Open())
56 | using (MemoryStream headerBuffer = new MemoryStream())
57 | {
58 | headerStream.CopyTo(headerBuffer);
59 | headerBytes = headerBuffer.ToArray();
60 | }
61 |
62 | string headerStr = Encoding.Default.GetString(headerBytes);
63 | if (headerStr.Contains("HalfStorage"))
64 | {
65 | dtype = torch.ScalarType.Float16;
66 | }
67 | else if (headerStr.Contains("BFloat"))
68 | {
69 | dtype = torch.ScalarType.BFloat16;
70 | }
71 | else if (headerStr.Contains("FloatStorage"))
72 | {
73 | dtype = torch.ScalarType.Float32;
74 | }
75 |
76 | // Scan the pickle opcodes for the shape tuple. The match keys on BINPERSID
77 | // ('Q' = 81) followed by BININT1 ('K' = 75) with value 0 (the storage offset);
78 | // shape entries then arrive as BININT1 ('K' n) or BININT2 ('M' lo hi), and the
79 | // next BINPUT ('q' = 113) ends the tuple.
80 | for (int i = 0; i < headerBytes.Length - 2; i++)
81 | {
82 | if (headerBytes[i] == 81 && headerBytes[i + 1] == 75 && headerBytes[i + 2] == 0)
83 | {
84 | for (int j = i + 2; j < headerBytes.Length - 2; j++)
85 | {
86 | if (headerBytes[j] == 75)
87 | {
88 | shape.Add(headerBytes[j + 1]);
89 | j++;
90 | }
91 | else if (headerBytes[j] == 77)
92 | {
93 | shape.Add(headerBytes[j + 1] + headerBytes[j + 2] * 256);
94 | j += 2;
95 | }
96 | else if (headerBytes[j] == 113)
97 | {
98 | break;
99 | }
100 | }
101 | break;
102 | }
103 | }
104 |
105 | Tensor tensor = torch.zeros(shape.ToArray(), dtype: dtype);
106 | ZipArchiveEntry dataEntry = zip.Entries.First(e => e.Name == "0");
107 |
108 | // Buffer the raw storage blob the same way and hand the bytes to the tensor.
109 | using (Stream dataStream = dataEntry.Open())
110 | using (MemoryStream dataBuffer = new MemoryStream())
111 | {
112 | dataStream.CopyTo(dataBuffer);
113 | tensor.bytes = dataBuffer.ToArray();
114 | }
115 | return tensor;
116 | }
105 |
106 | public static long GetFreeVRAM()
107 | {
108 | if (!cuda.is_available())
109 | {
110 | return 0;
111 | }
112 | else
113 | {
114 | using (var factory = new SharpDX.DXGI.Factory1())
115 | {
116 | var adapter = factory.Adapters[0];
117 | using (var adapter3 = adapter.QueryInterfaceOrNull<SharpDX.DXGI.Adapter3>())
118 | {
119 | if (adapter3 == null)
120 | {
121 | throw new ArgumentException($"Adapter {adapter.Description.Description} does not support querying video memory info");
122 | }
123 | var memoryInfo = adapter3.QueryVideoMemoryInfo(0, SharpDX.DXGI.MemorySegmentGroup.Local);
124 | long totalVRAM = adapter.Description.DedicatedVideoMemory;
125 | long usedVRAM = memoryInfo.CurrentUsage;
126 | long freeVRAM = memoryInfo.Budget - usedVRAM;
127 | return freeVRAM;
128 | }
129 | }
130 | }
131 | }
132 |
133 | }
134 | }
135 |
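136 | // Note: GetFreeVRAM goes through SharpDX/DXGI, so it is Windows-only; the value
137 | // it returns is the first adapter's memory budget minus current usage as DXGI
138 | // reports it, not TorchSharp's own allocator statistics.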
--------------------------------------------------------------------------------