├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── README.md ├── StableDiffusionDemo_Console ├── Program.cs └── StableDiffusionDemo_Console.csproj ├── StableDiffusionDemo_Winform ├── FormMain.Designer.cs ├── FormMain.cs ├── FormMain.resx ├── Program.cs └── StableDiffusionDemo_Winform.csproj ├── StableDiffusionSharp.sln └── StableDiffusionSharp ├── ModelLoader ├── ModelLoader.cs ├── PickleLoader.cs ├── SafetensorsLoader.cs └── TensorInfo.cs ├── Models ├── Clip │ ├── merges.txt │ └── vocab.json └── VAEApprox │ ├── vaeapp_sd15.pth │ └── xlvaeapp.pth ├── Modules ├── Clip.cs ├── Esrgan.cs ├── SD1.cs ├── SDModel.cs ├── SDXL.cs ├── Tokenizer.cs ├── Unet.cs ├── VAE.cs └── VAEApprox.cs ├── SDType.cs ├── Sampler ├── BasicSampler.cs ├── EulerAncestralSampler.cs └── EulerSampler.cs ├── Scheduler └── DiscreteSchedule.cs ├── StableDiffusion.cs ├── StableDiffusionSharp.csproj └── Tools.cs /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 
53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Oo]ut/ 33 | [Ll]og/ 34 | [Ll]ogs/ 35 | 36 | # Visual Studio 2015/2017 cache/options directory 37 | .vs/ 38 | # Uncomment if you have tasks that create the project's static files in wwwroot 39 | #wwwroot/ 40 | 41 | # Visual Studio 2017 auto generated files 42 | Generated\ Files/ 43 | 44 | # MSTest test Results 45 | [Tt]est[Rr]esult*/ 46 | [Bb]uild[Ll]og.* 47 | 48 | # NUnit 49 | *.VisualState.xml 50 | TestResult.xml 51 | nunit-*.xml 52 | 53 | # Build Results of an ATL Project 54 | [Dd]ebugPS/ 55 | [Rr]eleasePS/ 56 | dlldata.c 57 | 58 | # Benchmark Results 59 | BenchmarkDotNet.Artifacts/ 60 | 61 | # .NET Core 62 | project.lock.json 63 | project.fragment.lock.json 64 | artifacts/ 65 | 66 | # ASP.NET Scaffolding 67 | ScaffoldingReadMe.txt 68 | 69 | # StyleCop 70 | StyleCopReport.xml 71 | 72 | # Files built by Visual Studio 73 | *_i.c 74 | *_p.c 75 | *_h.h 76 | *.ilk 77 | *.meta 78 | *.obj 79 | *.iobj 80 | *.pch 81 | *.pdb 82 | *.ipdb 83 | *.pgc 84 | *.pgd 85 | *.rsp 86 | *.sbr 87 | *.tlb 88 | *.tli 89 | *.tlh 90 | *.tmp 91 | *.tmp_proj 92 | *_wpftmp.csproj 93 | *.log 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | 
AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. 
Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio LightSwitch build output 298 | **/*.HTMLClient/GeneratedArtifacts 299 | **/*.DesktopClient/GeneratedArtifacts 300 | **/*.DesktopClient/ModelManifest.xml 301 | **/*.Server/GeneratedArtifacts 302 | **/*.Server/ModelManifest.xml 303 | _Pvt_Extensions 304 | 305 | # Paket dependency manager 306 | .paket/paket.exe 307 | paket-files/ 308 | 309 | # FAKE - F# Make 310 | .fake/ 311 | 312 | # CodeRush personal settings 313 | .cr/personal 314 | 315 | # Python Tools for Visual Studio (PTVS) 316 | __pycache__/ 317 | *.pyc 318 | 319 | # Cake - Uncomment if you are using it 320 | # tools/** 321 | # !tools/packages.config 322 | 323 | # Tabs Studio 324 | *.tss 325 | 326 | # Telerik's JustMock configuration file 327 | *.jmconfig 328 | 329 | # BizTalk build output 330 | *.btp.cs 331 | *.btm.cs 332 | *.odx.cs 333 | *.xsd.cs 334 | 335 | # OpenCover UI analysis results 336 | OpenCover/ 337 | 338 | # Azure Stream Analytics local run output 339 | ASALocalRun/ 340 | 341 | # MSBuild Binary and Structured Log 342 | *.binlog 343 | 344 | # NVidia Nsight GPU debugger configuration file 345 | *.nvuser 346 | 347 | # MFractors (Xamarin productivity tool) working folder 348 | .mfractor/ 349 | 350 | # Local History for Visual Studio 351 | .localhistory/ 352 | 353 | # BeatPulse healthcheck temp database 354 | healthchecksdb 355 | 356 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 357 | MigrationBackup/ 358 | 359 | # Ionide (cross platform F# VS Code tools) working folder 360 | .ionide/ 361 | 362 | # Fody - auto-generated XML schema 363 | FodyWeavers.xsd -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StableDiffusionSharp 2 | 3 | **Use Stable Diffusion with C# only.** 4 | 5 | StableDiffusionSharp is an image-generation library. With the help of TorchSharp, Stable Diffusion can run without Python. 6 | 7 | ![Demo](https://github.com/user-attachments/assets/6924f4fa-0c5b-495c-81a8-3bcaf9b7d6e6) 8 | 9 | ## Features 10 | 11 | - Written in C# only. 12 | - Can load .safetensors or .ckpt models directly. 13 | - CUDA support. 14 | - Uses SDPA to speed up inference and save VRAM in fp16. 15 | - Text2Image support. 16 | - Image2Image support. 17 | - SD1.5 support. 18 | - SDXL support. 19 | - VAEApprox support. 20 | - ESRGAN 4x support. 21 | - NuGet package support. 22 | 23 | For SD1.5 text-to-image, it takes about 3 GB of VRAM and about 2.4 seconds to generate a 512×512 image in 20 steps. 24 | 25 | ## Work to do 26 | 27 | - Lora support. 28 | - ControlNet support. 29 | - Inpaint support. 30 | - Tiled VAE. 31 | 32 | ## How to use 33 | 34 | You can download the code or add the package from NuGet: 35 | 36 | dotnet add package IntptrMax.StableDiffusionSharp 37 | 38 | Or use the code directly. 39 | 40 | > [!NOTE] 41 | > Please add one of the libtorch-cpu, libtorch-cuda-12.1, libtorch-cuda-12.1-win-x64 or libtorch-cuda-12.1-linux-x64 packages (version 2.5.1.0) to run. 42 | 43 | You have to download an SD model first. If you need a separate VAE, you have to download that too. 44 | 45 | 46 | If you want to use ESRGAN for upscaling, you have to download the model from [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) 47 | 48 | Now you can use it like the code below.
49 | 50 | ``` C# 51 | static void Main(string[] args) 52 | { 53 | string sdModelPath = @".\Chilloutmix.safetensors"; 54 | string vaeModelPath = @".\vae.safetensors"; 55 | 56 | string esrganModelPath = @".\RealESRGAN_x4plus.pth"; 57 | string i2iPrompt = "High quality, best quality, moon, grass, tree, boat."; 58 | string prompt = "cat with blue eyes"; 59 | string nprompt = ""; 60 | 61 | SDDeviceType deviceType = SDDeviceType.CUDA; 62 | SDScalarType scalarType = SDScalarType.Float16; 63 | SDSamplerType samplerType = SDSamplerType.EulerAncestral; 64 | int step = 20; 65 | float cfg = 7.0f; 66 | long seed = 0; 67 | long img2imgSubSeed = 0; 68 | int width = 512; 69 | int height = 512; 70 | float strength = 0.75f; 71 | long clipSkip = 2; 72 | 73 | StableDiffusion sd = new StableDiffusion(deviceType, scalarType); 74 | sd.StepProgress += Sd_StepProgress; 75 | Console.WriteLine("Loading model......"); 76 | sd.LoadModel(sdModelPath, vaeModelPath); 77 | Console.WriteLine("Model loaded."); 78 | 79 | ImageMagick.MagickImage t2iImage = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg, samplerType); 80 | t2iImage.Write("output_t2i.png"); 81 | 82 | ImageMagick.MagickImage i2iImage = sd.ImageToImage(t2iImage, i2iPrompt, nprompt, clipSkip, step, strength, seed, img2imgSubSeed, cfg, samplerType); 83 | i2iImage.Write("output_i2i.png"); 84 | 85 | sd.Dispose(); 86 | GC.Collect(); 87 | 88 | Console.WriteLine("Doing upscale......"); 89 | StableDiffusionSharp.Modules.Esrgan esrgan = new StableDiffusionSharp.Modules.Esrgan(deviceType: deviceType, scalarType: scalarType); 90 | esrgan.LoadModel(esrganModelPath); 91 | ImageMagick.MagickImage upscaleImg = esrgan.UpScale(t2iImage); 92 | upscaleImg.Write("upscale.png"); 93 | 94 | Console.WriteLine(@"Done. Images have been saved."); 95 | } 96 | 97 | private static void Sd_StepProgress(object? 
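sender, StableDiffusion.StepEventArgs e) 98 | { 99 | Console.WriteLine($"Progress: {e.CurrentStep}/{e.TotalSteps}"); 100 | } 101 | ```

`StepEventArgs` also carries a fast approximate preview of the in-progress latent via VAEApprox (the WinForms demo below uses it to refresh the UI each step). A minimal sketch of a progress handler that saves that preview — the member names match the demo sources, but writing one file per step is illustrative only:

``` C#
private static void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e)
{
    Console.WriteLine($"Progress: {e.CurrentStep}/{e.TotalSteps}");
    if (e.VaeApproxImg != null)
    {
        // VaeApproxImg is a cheap VAE-approx decode of the current latent.
        e.VaeApproxImg.Write($"preview_{e.CurrentStep}.jpg");
    }
}
```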
-------------------------------------------------------------------------------- /StableDiffusionDemo_Console/Program.cs: -------------------------------------------------------------------------------- 1 | using StableDiffusionSharp; 2 | 3 | namespace StableDiffusionDemo_Console 4 | { 5 | internal class Program 6 | { 7 | static void Main(string[] args) 8 | { 9 | string sdModelPath = @".\Chilloutmix.safetensors"; 10 | string vaeModelPath = @".\vae.safetensors"; 11 | 12 | string esrganModelPath = @".\RealESRGAN_x4plus.pth"; 13 | string i2iPrompt = "High quality, best quality, moon, grass, tree, boat."; 14 | string prompt = "cat with blue eyes"; 15 | string nprompt = ""; 16 | 17 | SDDeviceType deviceType = SDDeviceType.CUDA; 18 | SDScalarType scalarType = SDScalarType.Float16; 19 | SDSamplerType samplerType = SDSamplerType.Euler; 20 | int step = 20; 21 | float cfg = 7.0f; 22 | long seed = 0; 23 | long img2imgSubSeed = 0; 24 | int width = 512; 25 | int height = 512; 26 | float strength = 0.75f; 27 | long clipSkip = 2; 28 | 29 | StableDiffusion sd = new StableDiffusion(deviceType, scalarType); 30 | sd.StepProgress += Sd_StepProgress; 31 | Console.WriteLine("Loading model......"); 32 | sd.LoadModel(sdModelPath, vaeModelPath); 33 | Console.WriteLine("Model loaded."); 34 | 35 | ImageMagick.MagickImage t2iImage = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg, samplerType); 36 | t2iImage.Write("output_t2i.png"); 37 | 38 | ImageMagick.MagickImage i2iImage = sd.ImageToImage(t2iImage, i2iPrompt, nprompt, clipSkip, step, strength, seed, img2imgSubSeed, cfg, samplerType); 39 | i2iImage.Write("output_i2i.png"); 40 | 41 | sd.Dispose(); 42 | GC.Collect(); 43 | 44 | Console.WriteLine("Doing upscale......"); 45 | StableDiffusionSharp.Modules.Esrgan esrgan = new StableDiffusionSharp.Modules.Esrgan(deviceType: deviceType, scalarType: scalarType); 46 | esrgan.LoadModel(esrganModelPath); 47 | ImageMagick.MagickImage upscaleImg = esrgan.UpScale(t2iImage); 48 | upscaleImg.Write("upscale.png"); 49 | 50 | Console.WriteLine(@"Done. Images have been saved."); 51 | } 52 | 53 | private static void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e) 54 | { 55 | Console.WriteLine($"Progress: {e.CurrentStep}/{e.TotalSteps}"); 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /StableDiffusionDemo_Console/StableDiffusionDemo_Console.csproj: -------------------------------------------------------------------------------- 1 | <Project Sdk="Microsoft.NET.Sdk"> 2 | 3 | <PropertyGroup> 4 | <OutputType>Exe</OutputType> 5 | <TargetFramework>net6.0</TargetFramework> 6 | <ImplicitUsings>enable</ImplicitUsings> 7 | <Nullable>enable</Nullable> 8 | </PropertyGroup> 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | </Project> -------------------------------------------------------------------------------- /StableDiffusionDemo_Winform/FormMain.Designer.cs: -------------------------------------------------------------------------------- 1 | namespace StableDiffusionDemo_Winform 2 | { 3 | partial class FormMain 4 | { 5 | /// <summary> 6 | /// Required designer variable. 7 | /// </summary> 8 | private System.ComponentModel.IContainer components = null; 9 | 10 | /// <summary> 11 | /// Clean up any resources being used. 12 | /// </summary> 13 | /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
14 | protected override void Dispose(bool disposing) 15 | { 16 | if (disposing && (components != null)) 17 | { 18 | components.Dispose(); 19 | } 20 | base.Dispose(disposing); 21 | } 22 | 23 | #region Windows Form Designer generated code 24 | 25 | /// 26 | /// Required method for Designer support - do not modify 27 | /// the contents of this method with the code editor. 28 | /// 29 | private void InitializeComponent() 30 | { 31 | groupBox1 = new GroupBox(); 32 | label11 = new Label(); 33 | NumericUpDown_ClipSkip = new NumericUpDown(); 34 | label10 = new Label(); 35 | Button_VAEModelScan = new Button(); 36 | TextBox_VaePath = new TextBox(); 37 | label9 = new Label(); 38 | label8 = new Label(); 39 | ComboBox_Precition = new ComboBox(); 40 | ComboBox_Device = new ComboBox(); 41 | Button_ModelLoad = new Button(); 42 | Button_ModelScan = new Button(); 43 | label1 = new Label(); 44 | TextBox_ModelPath = new TextBox(); 45 | tabControl1 = new TabControl(); 46 | tabPage1 = new TabPage(); 47 | groupBox2 = new GroupBox(); 48 | Label_State = new Label(); 49 | Button_Generate = new Button(); 50 | label7 = new Label(); 51 | label6 = new Label(); 52 | label5 = new Label(); 53 | NumericUpDown_Height = new NumericUpDown(); 54 | NumericUpDown_CFG = new NumericUpDown(); 55 | NumericUpDown_Step = new NumericUpDown(); 56 | NumericUpDown_Width = new NumericUpDown(); 57 | label4 = new Label(); 58 | PictureBox_Output = new PictureBox(); 59 | label3 = new Label(); 60 | TextBox_NPrompt = new TextBox(); 61 | TextBox_Prompt = new TextBox(); 62 | label2 = new Label(); 63 | tabPage2 = new TabPage(); 64 | tabPage3 = new TabPage(); 65 | groupBox1.SuspendLayout(); 66 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_ClipSkip).BeginInit(); 67 | tabControl1.SuspendLayout(); 68 | tabPage1.SuspendLayout(); 69 | groupBox2.SuspendLayout(); 70 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Height).BeginInit(); 71 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_CFG).BeginInit(); 72 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Step).BeginInit(); 73 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Width).BeginInit(); 74 | ((System.ComponentModel.ISupportInitialize)PictureBox_Output).BeginInit(); 75 | SuspendLayout(); 76 | // 77 | // groupBox1 78 | // 79 | groupBox1.Controls.Add(label11); 80 | groupBox1.Controls.Add(NumericUpDown_ClipSkip); 81 | groupBox1.Controls.Add(label10); 82 | groupBox1.Controls.Add(Button_VAEModelScan); 83 | groupBox1.Controls.Add(TextBox_VaePath); 84 | groupBox1.Controls.Add(label9); 85 | groupBox1.Controls.Add(label8); 86 | groupBox1.Controls.Add(ComboBox_Precition); 87 | groupBox1.Controls.Add(ComboBox_Device); 88 | groupBox1.Controls.Add(Button_ModelLoad); 89 | groupBox1.Controls.Add(Button_ModelScan); 90 | groupBox1.Controls.Add(label1); 91 | groupBox1.Controls.Add(TextBox_ModelPath); 92 | groupBox1.Location = new Point(12, 12); 93 | groupBox1.Name = "groupBox1"; 94 | groupBox1.Size = new Size(865, 178); 95 | groupBox1.TabIndex = 0; 96 | groupBox1.TabStop = false; 97 | groupBox1.Text = "Base"; 98 | // 99 | // label11 100 | // 101 | label11.AutoSize = true; 102 | label11.Location = new Point(461, 117); 103 | label11.Name = "label11"; 104 | label11.Size = new Size(59, 17); 105 | label11.TabIndex = 12; 106 | label11.Text = "Clip Skip"; 107 | // 108 | // NumericUpDown_ClipSkip 109 | // 110 | NumericUpDown_ClipSkip.Location = new Point(526, 114); 111 | NumericUpDown_ClipSkip.Maximum = new decimal(new int[] { 10, 0, 0, 0 }); 112 | 
NumericUpDown_ClipSkip.Name = "NumericUpDown_ClipSkip"; 113 | NumericUpDown_ClipSkip.Size = new Size(62, 23); 114 | NumericUpDown_ClipSkip.TabIndex = 11; 115 | // 116 | // label10 117 | // 118 | label10.AutoSize = true; 119 | label10.Location = new Point(18, 78); 120 | label10.Name = "label10"; 121 | label10.Size = new Size(60, 17); 122 | label10.TabIndex = 10; 123 | label10.Text = "VAE Path"; 124 | // 125 | // Button_VAEModelScan 126 | // 127 | Button_VAEModelScan.Location = new Point(708, 72); 128 | Button_VAEModelScan.Name = "Button_VAEModelScan"; 129 | Button_VAEModelScan.Size = new Size(101, 23); 130 | Button_VAEModelScan.TabIndex = 9; 131 | Button_VAEModelScan.Text = "Scan"; 132 | Button_VAEModelScan.UseVisualStyleBackColor = true; 133 | Button_VAEModelScan.Click += Button_VAEModelScan_Click; 134 | // 135 | // TextBox_VaePath 136 | // 137 | TextBox_VaePath.Location = new Point(113, 72); 138 | TextBox_VaePath.Name = "TextBox_VaePath"; 139 | TextBox_VaePath.ReadOnly = true; 140 | TextBox_VaePath.Size = new Size(564, 23); 141 | TextBox_VaePath.TabIndex = 8; 142 | // 143 | // label9 144 | // 145 | label9.AutoSize = true; 146 | label9.Location = new Point(217, 115); 147 | label9.Name = "label9"; 148 | label9.Size = new Size(58, 17); 149 | label9.TabIndex = 7; 150 | label9.Text = "Precision"; 151 | // 152 | // label8 153 | // 154 | label8.AutoSize = true; 155 | label8.Location = new Point(18, 115); 156 | label8.Name = "label8"; 157 | label8.Size = new Size(46, 17); 158 | label8.TabIndex = 6; 159 | label8.Text = "Device"; 160 | // 161 | // ComboBox_Precition 162 | // 163 | ComboBox_Precition.DropDownStyle = ComboBoxStyle.DropDownList; 164 | ComboBox_Precition.FormattingEnabled = true; 165 | ComboBox_Precition.Items.AddRange(new object[] { "fp16", "fp32" }); 166 | ComboBox_Precition.Location = new Point(281, 112); 167 | ComboBox_Precition.Name = "ComboBox_Precition"; 168 | ComboBox_Precition.Size = new Size(121, 25); 169 | ComboBox_Precition.TabIndex = 5; 170 | // 171 | // ComboBox_Device 172 | // 173 | ComboBox_Device.DropDownStyle = ComboBoxStyle.DropDownList; 174 | ComboBox_Device.FormattingEnabled = true; 175 | ComboBox_Device.Items.AddRange(new object[] { "CUDA", "CPU" }); 176 | ComboBox_Device.Location = new Point(70, 112); 177 | ComboBox_Device.Name = "ComboBox_Device"; 178 | ComboBox_Device.Size = new Size(121, 25); 179 | ComboBox_Device.TabIndex = 4; 180 | // 181 | // Button_ModelLoad 182 | // 183 | Button_ModelLoad.Location = new Point(708, 137); 184 | Button_ModelLoad.Name = "Button_ModelLoad"; 185 | Button_ModelLoad.Size = new Size(101, 23); 186 | Button_ModelLoad.TabIndex = 3; 187 | Button_ModelLoad.Text = "Load Model"; 188 | Button_ModelLoad.UseVisualStyleBackColor = true; 189 | Button_ModelLoad.Click += Button_ModelLoad_Click; 190 | // 191 | // Button_ModelScan 192 | // 193 | Button_ModelScan.Location = new Point(708, 40); 194 | Button_ModelScan.Name = "Button_ModelScan"; 195 | Button_ModelScan.Size = new Size(101, 23); 196 | Button_ModelScan.TabIndex = 2; 197 | Button_ModelScan.Text = "Scan"; 198 | Button_ModelScan.UseVisualStyleBackColor = true; 199 | Button_ModelScan.Click += Button_ModelScan_Click; 200 | // 201 | // label1 202 | // 203 | label1.AutoSize = true; 204 | label1.Location = new Point(18, 40); 205 | label1.Name = "label1"; 206 | label1.Size = new Size(75, 17); 207 | label1.TabIndex = 1; 208 | label1.Text = "Model Path"; 209 | // 210 | // TextBox_ModelPath 211 | // 212 | TextBox_ModelPath.Location = new Point(113, 34); 213 | TextBox_ModelPath.Name =
"TextBox_ModelPath"; 214 | TextBox_ModelPath.ReadOnly = true; 215 | TextBox_ModelPath.Size = new Size(564, 23); 216 | TextBox_ModelPath.TabIndex = 0; 217 | // 218 | // tabControl1 219 | // 220 | tabControl1.Controls.Add(tabPage1); 221 | tabControl1.Controls.Add(tabPage2); 222 | tabControl1.Controls.Add(tabPage3); 223 | tabControl1.Location = new Point(12, 196); 224 | tabControl1.Name = "tabControl1"; 225 | tabControl1.SelectedIndex = 0; 226 | tabControl1.Size = new Size(865, 397); 227 | tabControl1.TabIndex = 1; 228 | // 229 | // tabPage1 230 | // 231 | tabPage1.Controls.Add(groupBox2); 232 | tabPage1.Location = new Point(4, 26); 233 | tabPage1.Name = "tabPage1"; 234 | tabPage1.Padding = new Padding(3); 235 | tabPage1.Size = new Size(857, 367); 236 | tabPage1.TabIndex = 0; 237 | tabPage1.Text = "Text To Image"; 238 | tabPage1.UseVisualStyleBackColor = true; 239 | // 240 | // groupBox2 241 | // 242 | groupBox2.Controls.Add(Label_State); 243 | groupBox2.Controls.Add(Button_Generate); 244 | groupBox2.Controls.Add(label7); 245 | groupBox2.Controls.Add(label6); 246 | groupBox2.Controls.Add(label5); 247 | groupBox2.Controls.Add(NumericUpDown_Height); 248 | groupBox2.Controls.Add(NumericUpDown_CFG); 249 | groupBox2.Controls.Add(NumericUpDown_Step); 250 | groupBox2.Controls.Add(NumericUpDown_Width); 251 | groupBox2.Controls.Add(label4); 252 | groupBox2.Controls.Add(PictureBox_Output); 253 | groupBox2.Controls.Add(label3); 254 | groupBox2.Controls.Add(TextBox_NPrompt); 255 | groupBox2.Controls.Add(TextBox_Prompt); 256 | groupBox2.Controls.Add(label2); 257 | groupBox2.Location = new Point(6, 6); 258 | groupBox2.Name = "groupBox2"; 259 | groupBox2.Size = new Size(845, 355); 260 | groupBox2.TabIndex = 0; 261 | groupBox2.TabStop = false; 262 | groupBox2.Text = "Parameters"; 263 | // 264 | // Label_State 265 | // 266 | Label_State.BorderStyle = BorderStyle.FixedSingle; 267 | Label_State.Location = new Point(6, 294); 268 | Label_State.Name = "Label_State"; 269 | Label_State.Size = new Size(282, 58); 270 | Label_State.TabIndex = 15; 271 | Label_State.Text = "Please load a model first."; 272 | // 273 | // Button_Generate 274 | // 275 | Button_Generate.Enabled = false; 276 | Button_Generate.Location = new Point(294, 294); 277 | Button_Generate.Name = "Button_Generate"; 278 | Button_Generate.Size = new Size(86, 55); 279 | Button_Generate.TabIndex = 14; 280 | Button_Generate.Text = "Generate"; 281 | Button_Generate.UseVisualStyleBackColor = true; 282 | Button_Generate.Click += Button_Generate_Click; 283 | // 284 | // label7 285 | // 286 | label7.AutoSize = true; 287 | label7.Location = new Point(285, 267); 288 | label7.Name = "label7"; 289 | label7.Size = new Size(31, 17); 290 | label7.TabIndex = 13; 291 | label7.Text = "CFG"; 292 | // 293 | // label6 294 | // 295 | label6.AutoSize = true; 296 | label6.Location = new Point(167, 267); 297 | label6.Name = "label6"; 298 | label6.Size = new Size(34, 17); 299 | label6.TabIndex = 12; 300 | label6.Text = "Step"; 301 | // 302 | // label5 303 | // 304 | label5.AutoSize = true; 305 | label5.Location = new Point(89, 267); 306 | label5.Name = "label5"; 307 | label5.Size = new Size(17, 17); 308 | label5.TabIndex = 11; 309 | label5.Text = "H"; 310 | // 311 | // NumericUpDown_Height 312 | // 313 | NumericUpDown_Height.Increment = new decimal(new int[] { 64, 0, 0, 0 }); 314 | NumericUpDown_Height.Location = new Point(112, 265); 315 | NumericUpDown_Height.Maximum = new decimal(new int[] { 2048, 0, 0, 0 }); 316 | NumericUpDown_Height.Minimum = new decimal(new int[] { 64, 0, 0, 
0 }); 317 | NumericUpDown_Height.Name = "NumericUpDown_Height"; 318 | NumericUpDown_Height.Size = new Size(49, 23); 319 | NumericUpDown_Height.TabIndex = 10; 320 | NumericUpDown_Height.Value = new decimal(new int[] { 512, 0, 0, 0 }); 321 | // 322 | // NumericUpDown_CFG 323 | // 324 | NumericUpDown_CFG.Increment = new decimal(new int[] { 5, 0, 0, 65536 }); 325 | NumericUpDown_CFG.Location = new Point(322, 265); 326 | NumericUpDown_CFG.Maximum = new decimal(new int[] { 25, 0, 0, 0 }); 327 | NumericUpDown_CFG.Minimum = new decimal(new int[] { 5, 0, 0, 65536 }); 328 | NumericUpDown_CFG.Name = "NumericUpDown_CFG"; 329 | NumericUpDown_CFG.Size = new Size(58, 23); 330 | NumericUpDown_CFG.TabIndex = 9; 331 | NumericUpDown_CFG.Value = new decimal(new int[] { 7, 0, 0, 0 }); 332 | // 333 | // NumericUpDown_Step 334 | // 335 | NumericUpDown_Step.Location = new Point(207, 265); 336 | NumericUpDown_Step.Minimum = new decimal(new int[] { 1, 0, 0, 0 }); 337 | NumericUpDown_Step.Name = "NumericUpDown_Step"; 338 | NumericUpDown_Step.Size = new Size(60, 23); 339 | NumericUpDown_Step.TabIndex = 8; 340 | NumericUpDown_Step.Value = new decimal(new int[] { 20, 0, 0, 0 }); 341 | // 342 | // NumericUpDown_Width 343 | // 344 | NumericUpDown_Width.Increment = new decimal(new int[] { 64, 0, 0, 0 }); 345 | NumericUpDown_Width.Location = new Point(34, 265); 346 | NumericUpDown_Width.Maximum = new decimal(new int[] { 2048, 0, 0, 0 }); 347 | NumericUpDown_Width.Minimum = new decimal(new int[] { 64, 0, 0, 0 }); 348 | NumericUpDown_Width.Name = "NumericUpDown_Width"; 349 | NumericUpDown_Width.Size = new Size(49, 23); 350 | NumericUpDown_Width.TabIndex = 6; 351 | NumericUpDown_Width.Value = new decimal(new int[] { 512, 0, 0, 0 }); 352 | // 353 | // label4 354 | // 355 | label4.AutoSize = true; 356 | label4.Location = new Point(8, 267); 357 | label4.Name = "label4"; 358 | label4.Size = new Size(20, 17); 359 | label4.TabIndex = 5; 360 | label4.Text = "W"; 361 | // 362 | // PictureBox_Output 363 | // 364 | PictureBox_Output.BorderStyle = BorderStyle.FixedSingle; 365 | PictureBox_Output.Location = new Point(398, 22); 366 | PictureBox_Output.Name = "PictureBox_Output"; 367 | PictureBox_Output.Size = new Size(432, 327); 368 | PictureBox_Output.SizeMode = PictureBoxSizeMode.Zoom; 369 | PictureBox_Output.TabIndex = 4; 370 | PictureBox_Output.TabStop = false; 371 | // 372 | // label3 373 | // 374 | label3.AutoSize = true; 375 | label3.Location = new Point(8, 193); 376 | label3.Name = "label3"; 377 | label3.Size = new Size(66, 17); 378 | label3.TabIndex = 3; 379 | label3.Text = "N_Prompt"; 380 | // 381 | // TextBox_NPrompt 382 | // 383 | TextBox_NPrompt.Location = new Point(78, 158); 384 | TextBox_NPrompt.Multiline = true; 385 | TextBox_NPrompt.Name = "TextBox_NPrompt"; 386 | TextBox_NPrompt.Size = new Size(302, 87); 387 | TextBox_NPrompt.TabIndex = 2; 388 | TextBox_NPrompt.Text = "2d, 3d, cartoon, paintings"; 389 | // 390 | // TextBox_Prompt 391 | // 392 | TextBox_Prompt.Location = new Point(78, 36); 393 | TextBox_Prompt.Multiline = true; 394 | TextBox_Prompt.Name = "TextBox_Prompt"; 395 | TextBox_Prompt.Size = new Size(302, 104); 396 | TextBox_Prompt.TabIndex = 1; 397 | TextBox_Prompt.Text = "realistic, best quality, 4k, 8k, trees, beach, moon, stars, boat, "; 398 | // 399 | // label2 400 | // 401 | label2.AutoSize = true; 402 | label2.Location = new Point(8, 90); 403 | label2.Name = "label2"; 404 | label2.Size = new Size(51, 17); 405 | label2.TabIndex = 0; 406 | label2.Text = "Prompt"; 407 | // 408 | // tabPage2 409 | // 410 | 
tabPage2.Location = new Point(4, 26); 411 | tabPage2.Name = "tabPage2"; 412 | tabPage2.Padding = new Padding(3); 413 | tabPage2.Size = new Size(857, 367); 414 | tabPage2.TabIndex = 1; 415 | tabPage2.Text = "Image To Image"; 416 | tabPage2.UseVisualStyleBackColor = true; 417 | // 418 | // tabPage3 419 | // 420 | tabPage3.Location = new Point(4, 26); 421 | tabPage3.Name = "tabPage3"; 422 | tabPage3.Padding = new Padding(3); 423 | tabPage3.Size = new Size(857, 367); 424 | tabPage3.TabIndex = 2; 425 | tabPage3.Text = "Restore"; 426 | tabPage3.UseVisualStyleBackColor = true; 427 | // 428 | // FormMain 429 | // 430 | AutoScaleDimensions = new SizeF(7F, 17F); 431 | AutoScaleMode = AutoScaleMode.Font; 432 | ClientSize = new Size(889, 605); 433 | Controls.Add(tabControl1); 434 | Controls.Add(groupBox1); 435 | Name = "FormMain"; 436 | Text = "Stable Diffusion Sharp"; 437 | Load += FormMain_Load; 438 | groupBox1.ResumeLayout(false); 439 | groupBox1.PerformLayout(); 440 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_ClipSkip).EndInit(); 441 | tabControl1.ResumeLayout(false); 442 | tabPage1.ResumeLayout(false); 443 | groupBox2.ResumeLayout(false); 444 | groupBox2.PerformLayout(); 445 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Height).EndInit(); 446 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_CFG).EndInit(); 447 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Step).EndInit(); 448 | ((System.ComponentModel.ISupportInitialize)NumericUpDown_Width).EndInit(); 449 | ((System.ComponentModel.ISupportInitialize)PictureBox_Output).EndInit(); 450 | ResumeLayout(false); 451 | } 452 | 453 | #endregion 454 | 455 | private GroupBox groupBox1; 456 | private Button Button_ModelScan; 457 | private Label label1; 458 | private TextBox TextBox_ModelPath; 459 | private TabControl tabControl1; 460 | private TabPage tabPage1; 461 | private Button Button_ModelLoad; 462 | private GroupBox groupBox2; 463 | private PictureBox PictureBox_Output; 464 | private Label label3; 465 | private TextBox TextBox_NPrompt; 466 | private TextBox TextBox_Prompt; 467 | private Label label2; 468 | private NumericUpDown NumericUpDown_Width; 469 | private Label label4; 470 | private Button Button_Generate; 471 | private Label label7; 472 | private Label label6; 473 | private Label label5; 474 | private NumericUpDown NumericUpDown_Height; 475 | private NumericUpDown NumericUpDown_CFG; 476 | private NumericUpDown NumericUpDown_Step; 477 | private Label Label_State; 478 | private TabPage tabPage2; 479 | private TabPage tabPage3; 480 | private Label label9; 481 | private Label label8; 482 | private ComboBox ComboBox_Precition; 483 | private ComboBox ComboBox_Device; 484 | private Button Button_VAEModelScan; 485 | private TextBox TextBox_VaePath; 486 | private Label label10; 487 | private Label label11; 488 | private NumericUpDown NumericUpDown_ClipSkip; 489 | } 490 | } 491 | -------------------------------------------------------------------------------- /StableDiffusionDemo_Winform/FormMain.cs: -------------------------------------------------------------------------------- 1 | using StableDiffusionSharp; 2 | using System.Diagnostics; 3 | 4 | namespace StableDiffusionDemo_Winform 5 | { 6 | public partial class FormMain : Form 7 | { 8 | string modelPath = string.Empty; 9 | string vaeModelPath = string.Empty; 10 | StableDiffusion?
sd; 11 | 12 | public FormMain() 13 | { 14 | InitializeComponent(); 15 | } 16 | 17 | private void FormMain_Load(object sender, EventArgs e) 18 | { 19 | ComboBox_Device.SelectedIndex = 0; 20 | ComboBox_Precition.SelectedIndex = 0; 21 | } 22 | 23 | private void Button_ModelScan_Click(object sender, EventArgs e) 24 | { 25 | FileDialog fileDialog = new OpenFileDialog(); 26 | fileDialog.Filter = "Model files|*.safetensors;*.ckpt;*.pt;*.pth|All files|*.*"; 27 | if (fileDialog.ShowDialog() == DialogResult.OK) 28 | { 29 | TextBox_ModelPath.Text = fileDialog.FileName; 30 | modelPath = fileDialog.FileName; 31 | } 32 | } 33 | 34 | private void Button_ModelLoad_Click(object sender, EventArgs e) 35 | { 36 | if (File.Exists(modelPath)) 37 | { 38 | SDDeviceType deviceType = ComboBox_Device.SelectedIndex == 0 ? SDDeviceType.CUDA : SDDeviceType.CPU; 39 | SDScalarType scalarType = ComboBox_Precition.SelectedIndex == 0 ? SDScalarType.Float16 : SDScalarType.Float32; 40 | Task.Run(() => 41 | { 42 | base.Invoke(() => 43 | { 44 | Button_ModelLoad.Enabled = false; 45 | Button_Generate.Enabled = false; 46 | }); 47 | sd = new StableDiffusion(deviceType, scalarType); 48 | sd.StepProgress += Sd_StepProgress; 49 | sd.LoadModel(modelPath, vaeModelPath); 50 | base.Invoke(() => 51 | { 52 | Button_ModelLoad.Enabled = true; 53 | Button_Generate.Enabled = true; 54 | Label_State.Text = "Model loaded."; 55 | }); 56 | }); 57 | } 58 | } 59 | 60 | private void Button_VAEModelScan_Click(object sender, EventArgs e) 61 | { 62 | FileDialog fileDialog = new OpenFileDialog(); 63 | fileDialog.Filter = "Model files|*.safetensors;*.ckpt;*.pt;*.pth|All files|*.*"; 64 | if (fileDialog.ShowDialog() == DialogResult.OK) 65 | { 66 | TextBox_VaePath.Text = fileDialog.FileName; 67 | vaeModelPath = fileDialog.FileName; 68 | } 69 | } 70 | 71 | private void Sd_StepProgress(object? sender, StableDiffusion.StepEventArgs e) 72 | { 73 | base.Invoke(() => 74 | { 75 | Label_State.Text = $"Progress: {e.CurrentStep}/{e.TotalSteps}"; 76 | if (e.VaeApproxImg != null) 77 | { 78 | MemoryStream memoryStream = new MemoryStream(); 79 | e.VaeApproxImg.Write(memoryStream, ImageMagick.MagickFormat.Jpg); 80 | base.Invoke(() => 81 | { 82 | PictureBox_Output.Image = Image.FromStream(memoryStream); 83 | }); 84 | } 85 | }); 86 | } 87 | 88 | private void Button_Generate_Click(object sender, EventArgs e) 89 | { 90 | string prompt = TextBox_Prompt.Text; 91 | string nprompt = TextBox_NPrompt.Text; 92 | int step = (int)NumericUpDown_Step.Value; 93 | float cfg = (float)NumericUpDown_CFG.Value; 94 | long seed = 0; 95 | int width = (int)NumericUpDown_Width.Value; 96 | int height = (int)NumericUpDown_Height.Value; 97 | int clipSkip = (int)NumericUpDown_ClipSkip.Value; 98 | 99 | Task.Run(() => 100 | { 101 | Stopwatch stopwatch = Stopwatch.StartNew(); 102 | base.Invoke(() => 103 | { 104 | Button_ModelLoad.Enabled = false; 105 | Button_Generate.Enabled = false; 106 | Label_State.Text = "Generating..."; 107 | }); 108 | ImageMagick.MagickImage image = sd.TextToImage(prompt, nprompt, clipSkip, width, height, step, seed, cfg); 109 | MemoryStream memoryStream = new MemoryStream(); 110 | image.Write(memoryStream, ImageMagick.MagickFormat.Jpg); 111 | base.Invoke(() => 112 | { 113 | PictureBox_Output.Image = Image.FromStream(memoryStream); 114 | Button_ModelLoad.Enabled = true; 115 | Button_Generate.Enabled = true; 116 | Label_State.Text = $"Done. 
It takes {stopwatch.Elapsed.TotalSeconds.ToString("f2")} s"; 117 | }); 118 | GC.Collect(); 119 | }); 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /StableDiffusionDemo_Winform/FormMain.resx: -------------------------------------------------------------------------------- 1 |  2 | 3 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | text/microsoft-resx 110 | 111 | 112 | 2.0 113 | 114 | 115 | System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 116 | 117 | 118 | System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 119 | 120 | -------------------------------------------------------------------------------- /StableDiffusionDemo_Winform/Program.cs: -------------------------------------------------------------------------------- 1 | namespace StableDiffusionDemo_Winform 2 | { 3 | internal static class Program 4 | { 5 | /// 6 | /// The main entry point for the application. 7 | /// 8 | [STAThread] 9 | static void Main() 10 | { 11 | // To customize application configuration such as set high DPI settings or default font, 12 | // see https://aka.ms/applicationconfiguration. 13 | ApplicationConfiguration.Initialize(); 14 | Application.Run(new FormMain()); 15 | } 16 | } 17 | } -------------------------------------------------------------------------------- /StableDiffusionDemo_Winform/StableDiffusionDemo_Winform.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | WinExe 5 | net6.0-windows7.0 6 | enable 7 | true 8 | enable 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /StableDiffusionSharp.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.12.35728.132 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionSharp", "StableDiffusionSharp\StableDiffusionSharp.csproj", "{BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}" 7 | EndProject 8 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionDemo_Console", "StableDiffusionDemo_Console\StableDiffusionDemo_Console.csproj", "{4F4250A4-B849-4821-AFA5-F8B5191BF08C}" 9 | EndProject 10 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionDemo_Winform", "StableDiffusionDemo_Winform\StableDiffusionDemo_Winform.csproj", "{7860DFE9-EC36-44B3-81E8-817029B849B5}" 11 | EndProject 12 | Global 13 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 14 | Debug|Any CPU = Debug|Any CPU 15 | Release|Any CPU = Release|Any CPU 16 | EndGlobalSection 17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 18 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 19 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Debug|Any CPU.Build.0 = Debug|Any CPU 20 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Release|Any CPU.ActiveCfg = Release|Any CPU 21 | {BF6F0C17-D34A-4EFB-9194-DF0ED1FBB4D8}.Release|Any CPU.Build.0 = Release|Any CPU 22 | 
{4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 23 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Debug|Any CPU.Build.0 = Debug|Any CPU 24 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Release|Any CPU.ActiveCfg = Release|Any CPU 25 | {4F4250A4-B849-4821-AFA5-F8B5191BF08C}.Release|Any CPU.Build.0 = Release|Any CPU 26 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 27 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Debug|Any CPU.Build.0 = Debug|Any CPU 28 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Release|Any CPU.ActiveCfg = Release|Any CPU 29 | {7860DFE9-EC36-44B3-81E8-817029B849B5}.Release|Any CPU.Build.0 = Release|Any CPU 30 | EndGlobalSection 31 | GlobalSection(SolutionProperties) = preSolution 32 | HideSolutionNode = FALSE 33 | EndGlobalSection 34 | GlobalSection(ExtensibilityGlobals) = postSolution 35 | SolutionGuid = {2D950EF2-E5CA-4631-8B81-9E0974E394D7} 36 | EndGlobalSection 37 | EndGlobal 38 | -------------------------------------------------------------------------------- /StableDiffusionSharp/ModelLoader/ModelLoader.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionSharp.ModelLoader 5 | { 6 | internal static class ModelLoader 7 | { 8 | public static nn.Module LoadModel(this torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "") 9 | { 10 | string extension = Path.GetExtension(fileName).ToLower(); 11 | if (extension == ".pt" || extension == ".ckpt" || extension == ".pth") 12 | { 13 | PickleLoader pickleLoader = new PickleLoader(); 14 | return pickleLoader.LoadPickle(module, fileName, maybeAddHeaderInBlock); 15 | } 16 | else if (extension == ".safetensors") 17 | { 18 | SafetensorsLoader safetensorsLoader = new SafetensorsLoader(); 19 | return safetensorsLoader.LoadSafetensors(module, fileName, maybeAddHeaderInBlock); 20 | } 21 | else 22 | { 23 | throw new ArgumentException("Invalid file extension"); 24 | } 25 | } 26 | 27 | public static ModelType GetModelType(string ModelPath) 28 | { 29 | string extension = Path.GetExtension(ModelPath).ToLower(); 30 | List<TensorInfo> tensorInfos = new List<TensorInfo>(); 31 | 32 | if (extension == ".pt" || extension == ".ckpt" || extension == ".pth") 33 | { 34 | PickleLoader pickleLoader = new PickleLoader(); 35 | tensorInfos = pickleLoader.ReadTensorsInfoFromFile(ModelPath); 36 | } 37 | else if (extension == ".safetensors") 38 | { 39 | SafetensorsLoader safetensorsLoader = new SafetensorsLoader(); 40 | tensorInfos = safetensorsLoader.ReadTensorsInfoFromFile(ModelPath); 41 | } 42 | else 43 | { 44 | throw new ArgumentException("Invalid file extension"); 45 | } 46 | // Detect the model family from tensor names that only occur in that architecture. 47 | if (tensorInfos.Count(a => a.Name.Contains("model.diffusion_model.double_blocks.")) > 0) 48 | { 49 | return ModelType.FLUX; 50 | } 51 | else if (tensorInfos.Count(a => a.Name.Contains("model.diffusion_model.joint_blocks.")) > 0) 52 | { 53 | return ModelType.SD3; 54 | } 55 | else if (tensorInfos.Count(a => a.Name.Contains("conditioner.embedders.1")) > 0) 56 | { 57 | return ModelType.SDXL; 58 | } 59 | else 60 | { 61 | return ModelType.SD1; 62 | } 63 | 64 | } 65 | 66 | 67 | } 68 | } 69 |
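// Usage sketch (hypothetical call site; both members above are internal, so real
// callers live elsewhere in this assembly, e.g. where StableDiffusion loads weights).
// `sdModelPath` and `module` are placeholders, not identifiers from this file:
//
//     ModelType kind = ModelLoader.GetModelType(sdModelPath); // picks SD1/SDXL/SD3/FLUX from tensor names
//     module = module.LoadModel(sdModelPath);                  // picks pickle vs. safetensors by extension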
TorchSharp.torch; 5 | 6 | namespace StableDiffusionSharp.ModelLoader 7 | { 8 | internal class PickleLoader 9 | { 10 | private ZipArchive zip; 11 | private ReadOnlyCollection entries; 12 | 13 | internal List ReadTensorsInfoFromFile(string fileName) 14 | { 15 | List tensors = new List(); 16 | 17 | zip = ZipFile.OpenRead(fileName); 18 | entries = zip.Entries; 19 | ZipArchiveEntry headerEntry = entries.First(e => e.Name == "data.pkl"); 20 | byte[] headerBytes = new byte[headerEntry.Length]; 21 | // Header is always small enough to fit in memory, so we can read it all at once 22 | using (Stream stream = headerEntry.Open()) 23 | { 24 | stream.Read(headerBytes, 0, headerBytes.Length); 25 | } 26 | 27 | if (headerBytes[0] != 0x80 || headerBytes[1] != 0x02) 28 | { 29 | throw new ArgumentException("Not a valid pickle file"); 30 | } 31 | 32 | int index = 1; 33 | bool finished = false; 34 | bool readStrides = false; 35 | bool binPersid = false; 36 | 37 | TensorInfo tensor = new TensorInfo() { FileName = fileName, Offset = { 0 } }; 38 | 39 | int deepth = 0; 40 | 41 | Dictionary BinPut = new Dictionary(); 42 | 43 | while (index < headerBytes.Length && !finished) 44 | { 45 | byte opcode = headerBytes[index]; 46 | switch (opcode) 47 | { 48 | case (byte)'}': // EMPTY_DICT = b'}' # push empty dict 49 | break; 50 | case (byte)']': // EMPTY_LIST = b']' # push empty list 51 | break; 52 | // skip unused sections 53 | case (byte)'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg 54 | { 55 | int id = headerBytes[index + 1]; 56 | BinPut.TryGetValue(id, out string precision); 57 | if (precision != null) 58 | { 59 | if (precision.Contains("FloatStorage")) 60 | { 61 | tensor.Type = TorchSharp.torch.ScalarType.Float32; 62 | } 63 | else if (precision.Contains("HalfStorage")) 64 | { 65 | tensor.Type = TorchSharp.torch.ScalarType.Float16; 66 | } 67 | else if (precision.Contains("BFloat16Storage")) 68 | { 69 | tensor.Type = TorchSharp.torch.ScalarType.BFloat16; 70 | } 71 | } 72 | index++; 73 | break; 74 | } 75 | case (byte)'q': // BINPUT = b'to_q' # " " " " " ; " " 1-byte arg 76 | { 77 | index++; 78 | break; 79 | } 80 | case (byte)'Q': // BINPERSID = b'Q' # " " " ; " " " " stack 81 | binPersid = true; 82 | break; 83 | case (byte)'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg 84 | index += 4; 85 | break; 86 | case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame 87 | index += 8; 88 | break; 89 | case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo 90 | break; 91 | case (byte)'(': // MARK = b'(' # push special markobject on stack 92 | deepth++; 93 | break; 94 | case (byte)'K': // BININT1 = b'K' # push 1-byte unsigned int 95 | { 96 | int value = headerBytes[index + 1]; 97 | index++; 98 | 99 | if (deepth > 1 && value != 0 && binPersid) 100 | { 101 | if (readStrides) 102 | { 103 | //tensor.Stride.Add((ulong)value); 104 | tensor.Stride.Add((ulong)value); 105 | } 106 | else 107 | { 108 | tensor.Shape.Add(value); 109 | } 110 | } 111 | } 112 | break; 113 | case (byte)'M': // BININT2 = b'M' # push 2-byte unsigned int 114 | { 115 | UInt16 value = BitConverter.ToUInt16(headerBytes, index + 1); 116 | index += 2; 117 | 118 | if (deepth > 1 && value != 0 && binPersid) 119 | { 120 | if (readStrides) 121 | { 122 | tensor.Stride.Add(value); 123 | } 124 | else 125 | { 126 | tensor.Shape.Add(value); 127 | } 128 | } 129 | 130 | } 131 | break; 132 | case (byte)'J': // BININT = b'J' # push four-byte signed int 133 | { 134 | int value = BitConverter.ToInt32(headerBytes, index + 1); 135 | //int value = 
headerBytes[index + 4] << 24 + headerBytes[index + 3] << 16 + headerBytes[index + 2] << 8 + headerBytes[index + 1]; 136 | index += 4; 137 | 138 | if (deepth > 1 && value != 0 && binPersid) 139 | { 140 | if (readStrides) 141 | { 142 | tensor.Stride.Add((ulong)value); 143 | } 144 | else 145 | { 146 | tensor.Shape.Add(value); 147 | } 148 | } 149 | } 150 | break; 151 | 152 | case (byte)'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument 153 | { 154 | int length = headerBytes[index + 1]; 155 | int start = index + 5; 156 | byte module = headerBytes[index + 1]; 157 | string name = System.Text.Encoding.UTF8.GetString(headerBytes, start, length); 158 | index = index + 4 + length; 159 | 160 | if (deepth == 1) 161 | { 162 | tensor.Name = name; 163 | } 164 | else if (deepth == 3) 165 | { 166 | if ("cpu" != name && !name.Contains("cuda")) 167 | { 168 | tensor.DataNameInZipFile = name; 169 | } 170 | } 171 | } 172 | break; 173 | case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes 174 | { 175 | 176 | } 177 | break; 178 | case (byte)'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args 179 | { 180 | int start = index + 1; 181 | while (headerBytes[index + 1] != (byte)'q') 182 | { 183 | index++; 184 | } 185 | int length = index - start + 1; 186 | 187 | string global = System.Text.Encoding.UTF8.GetString(headerBytes, start, length); 188 | 189 | // precision is stored in the global variable 190 | // next tensor will read the precision 191 | // so we can set the Type here 192 | 193 | BinPut.Add(headerBytes[index + 2], global); 194 | 195 | if (global.Contains("FloatStorage")) 196 | { 197 | tensor.Type = TorchSharp.torch.ScalarType.Float32; 198 | } 199 | else if (global.Contains("HalfStorage")) 200 | { 201 | tensor.Type = TorchSharp.torch.ScalarType.Float16; 202 | } 203 | else if (global.Contains("BFloat16Storage")) 204 | { 205 | tensor.Type = TorchSharp.torch.ScalarType.BFloat16; 206 | } 207 | break; 208 | } 209 | case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items 210 | { 211 | if (binPersid) 212 | { 213 | readStrides = true; 214 | } 215 | break; 216 | } 217 | case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top 218 | if (binPersid) 219 | { 220 | readStrides = true; 221 | } 222 | break; 223 | case (byte)'t': // TUPLE = b't' # build tuple from topmost stack items 224 | deepth--; 225 | if (binPersid) 226 | { 227 | readStrides = true; 228 | } 229 | break; 230 | case (byte)'R': // REDUCE = b'R' # apply callable to argtuple, both on stack 231 | if (deepth == 1) 232 | { 233 | if (tensor.Name.Contains("metadata")) 234 | { 235 | break; 236 | } 237 | 238 | if (string.IsNullOrEmpty(tensor.DataNameInZipFile)) 239 | { 240 | tensor.DataNameInZipFile = tensors.Last().DataNameInZipFile; 241 | tensor.Offset = new List { (ulong)(tensor.Shape[0] * tensor.Type.ElementSize()) }; 242 | tensor.Shape.RemoveAt(0); 243 | //tensor.offset = tensors.Last(). 244 | } 245 | tensors.Add(tensor); 246 | 247 | tensor = new TensorInfo() { FileName = fileName, Offset = { 0 } }; 248 | readStrides = false; 249 | binPersid = false; 250 | } 251 | break; 252 | case (byte)'.': // STOP = b'.' 
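// ---- Annotation (not part of the repository): what ReadTensorsInfoFromFile is parsing.
// A .pt/.ckpt/.pth checkpoint is a ZIP archive holding a pickle stream ("data.pkl") plus one
// raw binary entry per tensor storage; the opcode walk above recovers, per tensor, its name
// (BINUNICODE at depth 1), its dtype (GLOBAL ...Storage, memoized via BINPUT/BINGET), its
// shape and strides (BININT* inside the persistent-id tuple), and the storage entry that
// holds its bytes. (The commented-out manual decode under BININT would also need
// parentheses — in C#, '+' binds tighter than '<<' — which is presumably why BitConverter
// is used instead.) A minimal sketch to inspect that layout; the path is a placeholder:

using System;
using System.IO.Compression;

class CheckpointZipPeek
{
    static void Main()
    {
        using ZipArchive zip = ZipFile.OpenRead(@"D:\models\model.ckpt"); // hypothetical path
        foreach (ZipArchiveEntry e in zip.Entries)
            Console.WriteLine($"{e.FullName} ({e.Length} bytes)");
        // typical entries: archive/data.pkl, archive/data/0, archive/data/1, ...
    }
}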
# every pickle ends with STOP 253 | finished = true; 254 | break; 255 | default: 256 | break; 257 | } 258 | index++; 259 | } 260 | TensorInfo metaTensor = tensors.Find(x => x.Name.Contains("_metadata")); 261 | if (metaTensor != null) 262 | { 263 | tensors.Remove(metaTensor); 264 | } 265 | return tensors; 266 | } 267 | 268 | private byte[] ReadByteFromFile(TensorInfo tensor) 269 | { 270 | if (entries is null) 271 | { 272 | throw new ArgumentNullException(nameof(entries)); 273 | } 274 | 275 | ZipArchiveEntry dataEntry = entries.First(e => e.Name == tensor.DataNameInZipFile); 276 | long i = 1; 277 | foreach (var ne in tensor.Shape) 278 | { 279 | i *= ne; 280 | } 281 | ulong length = (ulong)(tensor.Type.ElementSize() * i); 282 | byte[] data = new byte[dataEntry.Length]; 283 | 284 | using (Stream stream = dataEntry.Open()) 285 | { 286 | stream.Read(data, 0, data.Length); 287 | } 288 | 289 | //data = data.Take(new Range((int)tensor.Offset[0], (int)(tensor.Offset[0] + length))).ToArray(); 290 | byte[] result = new byte[length]; 291 | for (int j = 0; j < (int)length; j++) 292 | { 293 | result[j] = data[j + (int)tensor.Offset[0]]; 294 | } 295 | return result; 296 | //return data; 297 | } 298 | 299 | internal Dictionary Load(string fileName, string addString = "") 300 | { 301 | Dictionary tensors = new Dictionary(); 302 | List tensorInfos = ReadTensorsInfoFromFile(fileName); 303 | foreach (TensorInfo tensorInfo in tensorInfos) 304 | { 305 | TorchSharp.torch.Tensor tensor = TorchSharp.torch.empty(tensorInfo.Shape.ToArray(), dtype: tensorInfo.Type); 306 | tensor.bytes = ReadByteFromFile(tensorInfo); 307 | tensors.Add(addString + tensorInfo.Name, tensor); 308 | } 309 | return tensors; 310 | } 311 | 312 | internal nn.Module LoadPickle(torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "") 313 | { 314 | using (torch.no_grad()) 315 | using (NewDisposeScope()) 316 | { 317 | List tensorInfos = ReadTensorsInfoFromFile(fileName); 318 | foreach (var mod in module.named_parameters()) 319 | { 320 | ScalarType dtype = mod.parameter.dtype; 321 | TensorInfo info = tensorInfos.First(a => ((a.Name == mod.name) || (maybeAddHeaderInBlock + a.Name == mod.name))); 322 | Tensor t = torch.zeros(mod.parameter.shape, info.Type); 323 | t.bytes = ReadByteFromFile(info); 324 | mod.parameter.copy_(t); 325 | t.Dispose(); 326 | GC.Collect(); 327 | } 328 | return module; 329 | } 330 | } 331 | 332 | } 333 | } 334 | -------------------------------------------------------------------------------- /StableDiffusionSharp/ModelLoader/SafetensorsLoader.cs: -------------------------------------------------------------------------------- 1 | using Newtonsoft.Json.Linq; 2 | using System.Text; 3 | using static TorchSharp.torch; 4 | using TorchSharp; 5 | 6 | namespace StableDiffusionSharp.ModelLoader 7 | { 8 | internal class SafetensorsLoader 9 | { 10 | internal List ReadTensorsInfoFromFile(string inputFileName) 11 | { 12 | using (FileStream stream = File.OpenRead(inputFileName)) 13 | { 14 | long len = stream.Length; 15 | if (len < 10) 16 | { 17 | throw new ArgumentOutOfRangeException("File cannot be valid safetensors: too short"); 18 | } 19 | 20 | // Safetensors file first 8 byte to int64 is the header length 21 | byte[] headerBlock = new byte[8]; 22 | stream.Read(headerBlock, 0, 8); 23 | long headerSize = BitConverter.ToInt64(headerBlock, 0); 24 | if (len < 8 + headerSize || headerSize <= 0 || headerSize > 100000000) 25 | { 26 | throw new ArgumentOutOfRangeException($"File cannot be valid safetensors: header len wrong, 
size:{headerSize}"); 27 | } 28 | 29 | // Read the header, header file is a json file 30 | byte[] headerBytes = new byte[headerSize]; 31 | stream.Read(headerBytes, 0, (int)headerSize); 32 | 33 | string header = Encoding.UTF8.GetString(headerBytes); 34 | long bodyPosition = stream.Position; 35 | JToken token = JToken.Parse(header); 36 | 37 | List tensors = new List(); 38 | foreach (var sub in token.ToObject>()) 39 | { 40 | Dictionary value = sub.Value.ToObject>(); 41 | value.TryGetValue("data_offsets", out JToken offsets); 42 | value.TryGetValue("dtype", out JToken dtype); 43 | value.TryGetValue("shape", out JToken shape); 44 | 45 | ulong[] offsetArray = offsets?.ToObject(); 46 | if (null == offsetArray) 47 | { 48 | continue; 49 | } 50 | long[] shapeArray = shape.ToObject(); 51 | if (shapeArray.Length < 1) 52 | { 53 | shapeArray = new long[] { 1 }; 54 | } 55 | TorchSharp.torch.ScalarType tensor_type = TorchSharp.torch.ScalarType.Float32; 56 | switch (dtype.ToString()) 57 | { 58 | case "I8": tensor_type = TorchSharp.torch.ScalarType.Int8; break; 59 | case "I16": tensor_type = TorchSharp.torch.ScalarType.Int16; break; 60 | case "I32": tensor_type = TorchSharp.torch.ScalarType.Int32; break; 61 | case "I64": tensor_type = TorchSharp.torch.ScalarType.Int64; break; 62 | case "BF16": tensor_type = TorchSharp.torch.ScalarType.BFloat16; break; 63 | case "F16": tensor_type = TorchSharp.torch.ScalarType.Float16; break; 64 | case "F32": tensor_type = TorchSharp.torch.ScalarType.Float32; break; 65 | case "F64": tensor_type = TorchSharp.torch.ScalarType.Float64; break; 66 | case "U8": tensor_type = TorchSharp.torch.ScalarType.Byte; break; 67 | case "BOOL": tensor_type = TorchSharp.torch.ScalarType.Bool; break; 68 | case "U16": 69 | case "U32": 70 | case "U64": 71 | case "F8_E4M3": 72 | case "F8_E5M2": break; 73 | } 74 | 75 | TensorInfo tensor = new TensorInfo 76 | { 77 | Name = sub.Key, 78 | Type = tensor_type, 79 | Shape = shapeArray.ToList(), 80 | Offset = offsetArray.ToList(), 81 | FileName = inputFileName, 82 | BodyPosition = bodyPosition 83 | }; 84 | 85 | tensors.Add(tensor); 86 | } 87 | return tensors; 88 | } 89 | } 90 | 91 | private byte[] ReadByteFromFile(string inputFileName, long bodyPosition, long offset, int size) 92 | { 93 | using (FileStream stream = File.OpenRead(inputFileName)) 94 | { 95 | stream.Seek(bodyPosition + offset, SeekOrigin.Begin); 96 | byte[] dest = new byte[size]; 97 | stream.Read(dest, 0, size); 98 | return dest; 99 | } 100 | } 101 | 102 | private byte[] ReadByteFromFile(TensorInfo tensor) 103 | { 104 | string inputFileName = tensor.FileName; 105 | long bodyPosition = tensor.BodyPosition; 106 | ulong offset = tensor.Offset[0]; 107 | int size = (int)(tensor.Offset[1] - tensor.Offset[0]); 108 | return ReadByteFromFile(inputFileName, bodyPosition, (long)offset, size); 109 | } 110 | 111 | internal Dictionary Load(string fileName, string addString = "") 112 | { 113 | Dictionary tensors = new Dictionary(); 114 | List tensorInfos = ReadTensorsInfoFromFile(fileName); 115 | foreach (TensorInfo tensorInfo in tensorInfos) 116 | { 117 | TorchSharp.torch.Tensor tensor = TorchSharp.torch.empty(tensorInfo.Shape.ToArray(), dtype: tensorInfo.Type); 118 | tensor.bytes = ReadByteFromFile(tensorInfo); 119 | tensors.Add(addString + tensorInfo.Name, tensor); 120 | } 121 | return tensors; 122 | } 123 | 124 | internal nn.Module LoadSafetensors(torch.nn.Module module, string fileName, string maybeAddHeaderInBlock = "") 125 | { 126 | using (torch.no_grad()) 127 | using (NewDisposeScope()) 128 | { 129 
| List tensorInfos = ReadTensorsInfoFromFile(fileName); 130 | foreach (var mod in module.named_parameters()) 131 | { 132 | ScalarType dtype = mod.parameter.dtype; 133 | TensorInfo info = tensorInfos.First(a => ((a.Name == mod.name) || (maybeAddHeaderInBlock + a.Name == mod.name))); 134 | Tensor t = torch.zeros(mod.parameter.shape, info.Type); 135 | t.bytes = ReadByteFromFile(info); 136 | mod.parameter.copy_(t); 137 | t.Dispose(); 138 | GC.Collect(); 139 | } 140 | return module; 141 | } 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /StableDiffusionSharp/ModelLoader/TensorInfo.cs: -------------------------------------------------------------------------------- 1 | namespace StableDiffusionSharp.ModelLoader 2 | { 3 | internal class TensorInfo 4 | { 5 | public string Name { get; set; } 6 | public TorchSharp.torch.ScalarType Type { get; set; } = TorchSharp.torch.ScalarType.Float16; 7 | public List Shape { get; set; } = new List(); 8 | public List Stride { get; set; } = new List(); 9 | public string DataNameInZipFile { get; set; } 10 | public string FileName { get; set; } 11 | public List Offset { get; set; } = new List(); 12 | public long BodyPosition { get; set; } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Models/VAEApprox/vaeapp_sd15.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntptrMax/StableDiffusionSharp/cfdca9b5c50b86cd59aec08dfebd0961d91ba1c2/StableDiffusionSharp/Models/VAEApprox/vaeapp_sd15.pth -------------------------------------------------------------------------------- /StableDiffusionSharp/Models/VAEApprox/xlvaeapp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntptrMax/StableDiffusionSharp/cfdca9b5c50b86cd59aec08dfebd0961d91ba1c2/StableDiffusionSharp/Models/VAEApprox/xlvaeapp.pth -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/Clip.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using TorchSharp.Modules; 3 | using static TorchSharp.torch; 4 | using static TorchSharp.torch.nn; 5 | 6 | namespace StableDiffusionSharp.Modules 7 | { 8 | internal class Clip 9 | { 10 | private enum Activations 11 | { 12 | ReLU, 13 | SiLU, 14 | QuickGELU, 15 | GELU 16 | } 17 | 18 | internal class ViT_L_Clip : Module 19 | { 20 | private readonly CLIPTextModel transformer; 21 | 22 | public ViT_L_Clip(long n_vocab = 49408, long n_token = 77, long num_layers = 12, long n_heads = 12, long embed_dim = 768, long intermediate_size = 768 * 4, Device? device = null, ScalarType? dtype = null) : base(nameof(ViT_L_Clip)) 23 | { 24 | transformer = new CLIPTextModel(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype); 25 | RegisterComponents(); 26 | } 27 | 28 | public override Tensor forward(Tensor token, long num_skip, bool with_final_ln) 29 | { 30 | Device device = transformer.parameters().First().device; 31 | token = token.to(device); 32 | return transformer.forward(token, num_skip, with_final_ln); 33 | } 34 | 35 | private class CLIPTextModel : Module 36 | { 37 | private readonly CLIPTextTransformer text_model; 38 | public CLIPTextModel(long n_vocab, long n_token, long num_layers, long n_heads, long embed_dim, long intermediate_size, Device? 
device = null, ScalarType? dtype = null) : base(nameof(CLIPTextModel)) 39 | { 40 | text_model = new CLIPTextTransformer(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype); 41 | RegisterComponents(); 42 | } 43 | public override Tensor forward(Tensor x, long num_skip, bool with_final_ln) 44 | { 45 | return text_model.forward(x, num_skip, with_final_ln); 46 | } 47 | } 48 | 49 | private class CLIPTextTransformer : Module 50 | { 51 | private readonly CLIPTextEmbeddings embeddings; 52 | private readonly CLIPEncoder encoder; 53 | private readonly LayerNorm final_layer_norm; 54 | private readonly long num_layers; 55 | 56 | public CLIPTextTransformer(long n_vocab, long n_token, long num_layers, long n_heads, long embed_dim, long intermediate_size, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPTextTransformer)) 57 | { 58 | this.num_layers = num_layers; 59 | embeddings = new CLIPTextEmbeddings(n_vocab, embed_dim, n_token, device: device, dtype: dtype); 60 | encoder = new CLIPEncoder(num_layers, embed_dim, n_heads, intermediate_size, Activations.QuickGELU, device: device, dtype: dtype); 61 | final_layer_norm = LayerNorm(embed_dim, device: device, dtype: dtype); 62 | RegisterComponents(); 63 | } 64 | public override Tensor forward(Tensor x, long num_skip, bool with_final_ln) 65 | { 66 | x = embeddings.forward(x); 67 | x = encoder.forward(x, num_skip); 68 | if (with_final_ln) 69 | { 70 | x = final_layer_norm.forward(x); 71 | } 72 | return x; 73 | } 74 | } 75 | 76 | private class CLIPTextEmbeddings : Module 77 | { 78 | private readonly Embedding token_embedding; 79 | private readonly Embedding position_embedding; 80 | private readonly Parameter position_ids; 81 | public CLIPTextEmbeddings(long n_vocab, long n_embd, long n_token, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPTextEmbeddings)) 82 | { 83 | position_ids = Parameter(zeros(size: new long[] { 1, n_token }, device: device, dtype: dtype)); 84 | token_embedding = Embedding(n_vocab, n_embd, device: device, dtype: dtype); 85 | position_embedding = Embedding(n_token, n_embd, device: device, dtype: dtype); 86 | RegisterComponents(); 87 | } 88 | 89 | public override Tensor forward(Tensor tokens) 90 | { 91 | return token_embedding.forward(tokens) + position_embedding.forward(position_ids.@long()); 92 | } 93 | } 94 | 95 | private class CLIPEncoderLayer : Module 96 | { 97 | private readonly LayerNorm layer_norm1; 98 | private readonly LayerNorm layer_norm2; 99 | private readonly CLIPAttention self_attn; 100 | private readonly CLIPMLP mlp; 101 | 102 | public CLIPEncoderLayer(long n_head, long embed_dim, long intermediate_size, Activations activations = Activations.QuickGELU, Device? device = null, ScalarType? 
dtype = null) : base(nameof(CLIPEncoderLayer)) 103 | { 104 | layer_norm1 = LayerNorm(embed_dim, device: device, dtype: dtype); 105 | self_attn = new CLIPAttention(embed_dim, n_head, device: device, dtype: dtype); 106 | layer_norm2 = LayerNorm(embed_dim, device: device, dtype: dtype); 107 | mlp = new CLIPMLP(embed_dim, intermediate_size, embed_dim, activations, device: device, dtype: dtype); 108 | RegisterComponents(); 109 | } 110 | 111 | public override Tensor forward(Tensor x) 112 | { 113 | x += self_attn.forward(layer_norm1.forward(x)); 114 | x += mlp.forward(layer_norm2.forward(x)); 115 | return x; 116 | } 117 | } 118 | 119 | private class CLIPMLP : Module 120 | { 121 | private readonly Linear fc1; 122 | private readonly Linear fc2; 123 | private readonly Activations act_layer; 124 | public CLIPMLP(long in_features, long? hidden_features = null, long? out_features = null, Activations act_layer = Activations.QuickGELU, bool bias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPMLP)) 125 | { 126 | out_features ??= in_features; 127 | hidden_features ??= out_features; 128 | 129 | fc1 = Linear(in_features, (long)hidden_features, hasBias: bias, device: device, dtype: dtype); 130 | fc2 = Linear((long)hidden_features, (long)out_features, hasBias: bias, device: device, dtype: dtype); 131 | this.act_layer = act_layer; 132 | RegisterComponents(); 133 | } 134 | 135 | public override Tensor forward(Tensor x) 136 | { 137 | x = fc1.forward(x); 138 | 139 | switch (act_layer) 140 | { 141 | case Activations.ReLU: 142 | x = functional.relu(x); 143 | break; 144 | case Activations.SiLU: 145 | x = functional.silu(x); 146 | break; 147 | case Activations.QuickGELU: 148 | x = x * sigmoid(1.702 * x); 149 | break; 150 | case Activations.GELU: 151 | x = functional.gelu(x); 152 | break; 153 | } 154 | x = fc2.forward(x); 155 | return x; 156 | } 157 | } 158 | 159 | private class CLIPAttention : Module 160 | { 161 | private readonly long heads; 162 | private readonly Linear q_proj; 163 | private readonly Linear k_proj; 164 | private readonly Linear v_proj; 165 | private readonly Linear out_proj; 166 | 167 | public CLIPAttention(long embed_dim, long heads, Device? device = null, ScalarType? 
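// ---- Annotation (not part of the repository): QuickGELU above is the sigmoid approximation
// the original CLIP weights were trained with, GELU(x) ~= x * sigmoid(1.702 * x), which is
// why the text encoder keeps it instead of the erf-based GELU. A quick numerical check:

using System;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

class QuickGeluCheck
{
    static void Main()
    {
        Tensor x = linspace(-4.0, 4.0, 17);
        Tensor quick = x * sigmoid(1.702 * x);   // QuickGELU, as in CLIPMLP above
        Tensor exact = functional.gelu(x);       // erf-based GELU
        Console.WriteLine((quick - exact).abs().max().item<float>()); // small approximation error
    }
}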
dtype = null) : base(nameof(CLIPAttention)) 168 | { 169 | this.heads = heads; 170 | q_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype); 171 | k_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype); 172 | v_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype); 173 | out_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype); 174 | 175 | RegisterComponents(); 176 | } 177 | 178 | public override Tensor forward(Tensor x) 179 | { 180 | using (var _ = NewDisposeScope()) 181 | { 182 | Tensor q = q_proj.forward(x); 183 | Tensor k = k_proj.forward(x); 184 | Tensor v = v_proj.forward(x); 185 | Tensor output = attention(q, k, v, heads); 186 | //TensorInfo output = self_atten(to_q, to_k, to_v, this.heads); 187 | return out_proj.forward(output).MoveToOuterDisposeScope(); 188 | } 189 | } 190 | 191 | private static Tensor self_atten(Tensor q, Tensor k, Tensor v, long heads) 192 | { 193 | long[] input_shape = q.shape; 194 | long batch_size = q.shape[0]; 195 | long sequence_length = q.shape[1]; 196 | long d_head = q.shape[2] / heads; 197 | long[] interim_shape = new long[] { batch_size, sequence_length, heads, d_head }; 198 | 199 | q = q.view(interim_shape).transpose(1, 2); 200 | k = k.view(interim_shape).transpose(1, 2); 201 | v = v.view(interim_shape).transpose(1, 2); 202 | 203 | var weight = matmul(q, k.transpose(-1, -2)); 204 | var mask = ones_like(weight).triu(1).to(@bool); 205 | weight.masked_fill_(mask, float.NegativeInfinity); 206 | 207 | weight = weight / (float)Math.Sqrt(d_head); 208 | weight = functional.softmax(weight, dim: -1); 209 | 210 | var output = matmul(weight, v); 211 | output = output.transpose(1, 2); 212 | output = output.reshape(input_shape); 213 | return output; 214 | } 215 | 216 | // Convenience wrapper around a basic attention operation 217 | private static Tensor attention(Tensor q, Tensor k, Tensor v, long heads) 218 | { 219 | long b = q.shape[0]; 220 | long dim_head = q.shape[2]; 221 | dim_head /= heads; 222 | q = q.view(b, -1, heads, dim_head).transpose(1, 2); 223 | k = k.view(b, -1, heads, dim_head).transpose(1, 2); 224 | v = v.view(b, -1, heads, dim_head).transpose(1, 2); 225 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: true); 226 | output = output.transpose(1, 2); 227 | output = output.view(b, -1, heads * dim_head); 228 | return output; 229 | } 230 | } 231 | 232 | private class CLIPEncoder : Module 233 | { 234 | private readonly ModuleList layers; 235 | 236 | public CLIPEncoder(long num_layers, long embed_dim, long heads, long intermediate_size, Activations intermediate_activation, Device? device = null, ScalarType? dtype = null) : base(nameof(CLIPEncoder)) 237 | { 238 | layers = new ModuleList(); 239 | for (int i = 0; i < num_layers; i++) 240 | { 241 | layers.append(new CLIPEncoderLayer(heads, embed_dim, intermediate_size, intermediate_activation, device: device, dtype: dtype)); 242 | } 243 | RegisterComponents(); 244 | } 245 | 246 | public override Tensor forward(Tensor x, long num_skip) 247 | { 248 | long num_act = num_skip > 0 ? 
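// ---- Annotation (not part of the repository): the two private helpers above compute the
// same thing, softmax(Q.K^T / sqrt(d_head) + causal mask) . V. self_atten builds the
// upper-triangular mask explicitly (masking with -inf before the 1/sqrt(d) scale is fine,
// since -inf survives the division), while attention() defers to TorchSharp's fused kernel,
// whose causal flag is spelled "is_casual" in this TorchSharp version. The head split works
// like this (sketch, assuming ViT-L sizes: 77 tokens, 768 channels, 12 heads):

using System;
using static TorchSharp.torch;

class HeadSplitSketch
{
    static void Main()
    {
        Tensor q = randn(new long[] { 1, 77, 768 });   // (batch, tokens, channels)
        long heads = 12, dimHead = 768 / heads;        // 64 channels per head
        q = q.view(1, -1, heads, dimHead).transpose(1, 2);
        Console.WriteLine(string.Join(",", q.shape));  // 1,12,77,64 — attention runs per head
    }
}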
layers.Count - num_skip : layers.Count; 249 | for (int i = 0; i < num_act; i++) 250 | { 251 | x = layers[i].forward(x); 252 | } 253 | 254 | return x; 255 | } 256 | } 257 | } 258 | 259 | private class ViT_bigG_Clip : Module 260 | { 261 | private readonly int adm_in_channels; 262 | 263 | private readonly Embedding token_embedding; 264 | private readonly Parameter positional_embedding; 265 | private readonly Transformer transformer; 266 | private readonly LayerNorm ln_final; 267 | private readonly Parameter text_projection; 268 | 269 | public ViT_bigG_Clip(long n_vocab = 49408, long n_token = 77, long num_layers = 32, long n_heads = 20, long embed_dim = 1280, long intermediate_size = 1280 * 4, Device? device = null, ScalarType? dtype = null) : base(nameof(ViT_bigG_Clip)) 270 | { 271 | token_embedding = Embedding(n_vocab, embed_dim, device: device, dtype: dtype); 272 | positional_embedding = Parameter(zeros(size: new long[] { n_token, embed_dim }, device: device, dtype: dtype)); 273 | text_projection = Parameter(zeros(size: new long[] { embed_dim, embed_dim }, device: device, dtype: dtype)); 274 | transformer = new Transformer(num_layers, embed_dim, n_heads, intermediate_size, Activations.GELU, device: device, dtype: dtype); 275 | ln_final = LayerNorm(embed_dim, device: device, dtype: dtype); 276 | RegisterComponents(); 277 | } 278 | 279 | public override Tensor forward(Tensor x, int num_skip, bool with_final_ln, bool return_pooled) 280 | { 281 | using (NewDisposeScope()) 282 | { 283 | Tensor input_ids = x; 284 | x = token_embedding.forward(x) + positional_embedding; 285 | x = transformer.forward(x, num_skip); 286 | if (with_final_ln || return_pooled) 287 | { 288 | x = ln_final.forward(x); 289 | } 290 | if (return_pooled) 291 | { 292 | x = x[torch.arange(x.shape[0], device: x.device), input_ids.to(type: ScalarType.Int32, device: x.device).argmax(dim: -1)]; 293 | x = functional.linear(x, text_projection.transpose(0, 1)); 294 | } 295 | return x.MoveToOuterDisposeScope(); 296 | } 297 | } 298 | 299 | private class Transformer : Module 300 | { 301 | private readonly ModuleList resblocks; 302 | public Transformer(long num_layers, long embed_dim, long heads, long intermediate_size, Activations intermediate_activation, Device? device = null, ScalarType? dtype = null) : base(nameof(Transformer)) 303 | { 304 | resblocks = new ModuleList(); 305 | for (int i = 0; i < num_layers; i++) 306 | { 307 | resblocks.append(new ResidualAttentionBlock(heads, embed_dim, intermediate_size, intermediate_activation, device: device, dtype: dtype)); 308 | } 309 | RegisterComponents(); 310 | } 311 | 312 | public override Tensor forward(Tensor x, int num_skip) 313 | { 314 | int num_act = num_skip > 0 ? resblocks.Count - num_skip : resblocks.Count; 315 | for (int i = 0; i < num_act; i++) 316 | { 317 | x = resblocks[i].forward(x); 318 | } 319 | return x; 320 | } 321 | } 322 | 323 | private class ResidualAttentionBlock : Module 324 | { 325 | private readonly LayerNorm ln_1; 326 | private readonly LayerNorm ln_2; 327 | private readonly MultiheadAttention attn; 328 | private readonly Mlp mlp; 329 | 330 | public ResidualAttentionBlock(long n_head, long embed_dim, long intermediate_size, Activations activations = Activations.QuickGELU, Device? device = null, ScalarType? 
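// ---- Annotation (not part of the repository): the return_pooled branch above is
// OpenCLIP-style pooling — take the final hidden state at the end-of-text position and
// project it with text_projection. argmax over the token ids finds that position because
// the end token (49407) is the largest id in the 49408-entry vocabulary. Gather sketch:

using System;
using static TorchSharp.torch;

class PooledOutputSketch
{
    static void Main()
    {
        Tensor hidden = randn(new long[] { 2, 77, 1280 });       // ViT-bigG hidden states
        Tensor ids = zeros(new long[] { 2, 77 }, dtype: ScalarType.Int64);
        ids[0, 5] = tensor(49407L);                              // EOS at position 5 ...
        ids[1, 9] = tensor(49407L);                              // ... and at position 9
        Tensor pooled = hidden[arange(2), ids.argmax(dim: -1)];  // one row per prompt
        Console.WriteLine(string.Join(",", pooled.shape));       // 2,1280
    }
}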
dtype = null) : base(nameof(ResidualAttentionBlock)) 331 | { 332 | ln_1 = LayerNorm(embed_dim, device: device, dtype: dtype); 333 | attn = new MultiheadAttention(embed_dim, n_head, device: device, dtype: dtype); 334 | ln_2 = LayerNorm(embed_dim, device: device, dtype: dtype); 335 | mlp = new Mlp(embed_dim, intermediate_size, embed_dim, activations, device: device, dtype: dtype); 336 | RegisterComponents(); 337 | } 338 | 339 | public override Tensor forward(Tensor x) 340 | { 341 | x += attn.forward(ln_1.forward(x)); 342 | x += mlp.forward(ln_2.forward(x)); 343 | return x; 344 | } 345 | } 346 | 347 | private class Mlp : Module 348 | { 349 | private readonly Linear c_fc; 350 | private readonly Linear c_proj; 351 | private readonly Activations act_layer; 352 | public Mlp(long in_features, long? hidden_features = null, long? out_features = null, Activations act_layer = Activations.QuickGELU, bool bias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Mlp)) 353 | { 354 | out_features ??= in_features; 355 | hidden_features ??= out_features; 356 | 357 | c_fc = Linear(in_features, (long)hidden_features, hasBias: bias, device: device, dtype: dtype); 358 | c_proj = Linear((long)hidden_features, (long)out_features, hasBias: bias, device: device, dtype: dtype); 359 | this.act_layer = act_layer; 360 | RegisterComponents(); 361 | } 362 | 363 | public override Tensor forward(Tensor x) 364 | { 365 | x = c_fc.forward(x); 366 | 367 | switch (act_layer) 368 | { 369 | case Activations.ReLU: 370 | x = functional.relu(x); 371 | break; 372 | case Activations.SiLU: 373 | x = functional.silu(x); 374 | break; 375 | case Activations.QuickGELU: 376 | x = x * sigmoid(1.702 * x); 377 | break; 378 | case Activations.GELU: 379 | x = functional.gelu(x); 380 | break; 381 | } 382 | x = c_proj.forward(x); 383 | return x; 384 | } 385 | } 386 | 387 | private class MultiheadAttention : Module 388 | { 389 | private readonly long heads; 390 | private readonly Parameter in_proj_weight; 391 | private readonly Parameter in_proj_bias; 392 | private readonly Linear out_proj; 393 | 394 | public MultiheadAttention(long embed_dim, long heads, Device? device = null, ScalarType? 
dtype = null) : base(nameof(MultiheadAttention)) 395 | { 396 | this.heads = heads; 397 | in_proj_weight = Parameter(zeros(new long[] { 3 * embed_dim, embed_dim }, device: device, dtype: dtype)); 398 | in_proj_bias = Parameter(zeros(new long[] { 3 * embed_dim }, device: device, dtype: dtype)); 399 | out_proj = Linear(embed_dim, embed_dim, hasBias: true, device: device, dtype: dtype); 400 | 401 | RegisterComponents(); 402 | } 403 | 404 | public override Tensor forward(Tensor x) 405 | { 406 | using (var _ = NewDisposeScope()) 407 | { 408 | Tensor[] qkv = functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, 2); 409 | Tensor q = qkv[0]; 410 | Tensor k = qkv[1]; 411 | Tensor v = qkv[2]; 412 | Tensor output = attention(q, k, v, heads); 413 | //TensorInfo output = self_atten(to_q, to_k, to_v, this.heads); 414 | return out_proj.forward(output).MoveToOuterDisposeScope(); 415 | } 416 | } 417 | 418 | private static Tensor self_atten(Tensor q, Tensor k, Tensor v, long heads) 419 | { 420 | long[] input_shape = q.shape; 421 | long batch_size = q.shape[0]; 422 | long sequence_length = q.shape[1]; 423 | long d_head = q.shape[2] / heads; 424 | long[] interim_shape = new long[] { batch_size, sequence_length, heads, d_head }; 425 | 426 | q = q.view(interim_shape).transpose(1, 2); 427 | k = k.view(interim_shape).transpose(1, 2); 428 | v = v.view(interim_shape).transpose(1, 2); 429 | 430 | var weight = matmul(q, k.transpose(-1, -2)); 431 | var mask = ones_like(weight).triu(1).to(@bool); 432 | weight.masked_fill_(mask, float.NegativeInfinity); 433 | 434 | weight = weight / (float)Math.Sqrt(d_head); 435 | weight = functional.softmax(weight, dim: -1); 436 | 437 | var output = matmul(weight, v); 438 | output = output.transpose(1, 2); 439 | output = output.reshape(input_shape); 440 | return output; 441 | } 442 | 443 | // Convenience wrapper around a basic attention operation 444 | private static Tensor attention(Tensor q, Tensor k, Tensor v, long heads) 445 | { 446 | long b = q.shape[0]; 447 | long dim_head = q.shape[2]; 448 | dim_head /= heads; 449 | q = q.view(b, -1, heads, dim_head).transpose(1, 2); 450 | k = k.view(b, -1, heads, dim_head).transpose(1, 2); 451 | v = v.view(b, -1, heads, dim_head).transpose(1, 2); 452 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: true); 453 | output = output.transpose(1, 2); 454 | output = output.view(b, -1, heads * dim_head); 455 | return output; 456 | } 457 | } 458 | } 459 | 460 | internal class SDCliper : Module 461 | { 462 | private readonly ViT_L_Clip cond_stage_model; 463 | private readonly long n_token; 464 | private readonly long endToken; 465 | 466 | public SDCliper(long n_vocab = 49408, long n_token = 77, long num_layers = 12, long n_heads = 12, long embed_dim = 768, long intermediate_size = 768 * 4, long endToken = 49407, Device? device = null, ScalarType? 
dtype = null) : base(nameof(SDCliper)) 467 | { 468 | this.n_token = n_token; 469 | this.endToken = endToken; 470 | cond_stage_model = new ViT_L_Clip(n_vocab, n_token, num_layers, n_heads, embed_dim, intermediate_size, device: device, dtype: dtype); 471 | RegisterComponents(); 472 | } 473 | public override (Tensor, Tensor) forward(Tensor token, long num_skip) 474 | { 475 | using (NewDisposeScope()) 476 | { 477 | Device device = cond_stage_model.parameters().First().device; 478 | long padLength = n_token - token.shape[1]; 479 | Tensor token1 = functional.pad(token, new long[] { 0, padLength, 0, 0 }, value: endToken); 480 | return (cond_stage_model.forward(token1, num_skip, true).MoveToOuterDisposeScope(), zeros(1).MoveToOuterDisposeScope()); 481 | } 482 | } 483 | } 484 | 485 | internal class SDXLCliper : Module 486 | { 487 | private readonly Embedders conditioner; 488 | public SDXLCliper(long n_vocab = 49408, long n_token = 77, Device? device = null, ScalarType? dtype = null) : base(nameof(SDXLCliper)) 489 | { 490 | conditioner = new Embedders(n_token, device: device, dtype: dtype); 491 | RegisterComponents(); 492 | } 493 | 494 | public override (Tensor, Tensor) forward(Tensor token, long num_skip) 495 | { 496 | Device device = conditioner.parameters().First().device; 497 | token = token.to(device); 498 | return conditioner.forward(token); 499 | } 500 | 501 | private class Embedders : Module 502 | { 503 | private readonly ModuleList embedders; 504 | private readonly long n_token; 505 | private readonly long endToken; 506 | public Embedders(long n_token = 77, int endToken = 49407, Device? device = null, ScalarType? dtype = null) : base(nameof(Embedders)) 507 | { 508 | this.n_token = n_token; 509 | this.endToken = endToken; 510 | Model model = new Model(device: device, dtype: dtype); 511 | embedders = ModuleList(new ViT_L_Clip(device: device, dtype: dtype), model); 512 | RegisterComponents(); 513 | } 514 | public override (Tensor, Tensor) forward(Tensor token) 515 | { 516 | using (NewDisposeScope()) 517 | { 518 | long padLength = n_token - token.shape[1]; 519 | Tensor token1 = functional.pad(token, new long[] { 0, padLength, 0, 0 }, value: endToken); 520 | Tensor token2 = functional.pad(token, new long[] { 0, padLength, 0, 0 }); 521 | 522 | Tensor vit_l_result = ((ViT_L_Clip)embedders[0]).forward(token1, 1, false); 523 | Tensor vit_bigG_result = ((Model)embedders[1]).forward(token2, 1, false, false); 524 | Tensor vit_bigG_vec = ((Model)embedders[1]).forward(token2, 0, false, true); 525 | Tensor crossattn = cat(new Tensor[] { vit_l_result, vit_bigG_result }, -1); 526 | return (crossattn.MoveToOuterDisposeScope(), vit_bigG_vec.MoveToOuterDisposeScope()); 527 | } 528 | } 529 | } 530 | 531 | private class Model : Module 532 | { 533 | private readonly ViT_bigG_Clip model; 534 | public Model(Device? device = null, ScalarType? 
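// ---- Annotation (not part of the repository): Embedders below is SDXL's dual text encoder.
// The same prompt tokens run through ViT-L (768-wide, penultimate layer, no final LayerNorm)
// and ViT-bigG (1280-wide); the per-token states are concatenated into the 2048-wide context
// the SDXL UNet consumes, and a second bigG pass (return_pooled: true) yields the pooled
// vector. Note the padding asymmetry: the ViT-L copy is padded with the end token (49407),
// the bigG copy with 0. Shape sketch:

using System;
using static TorchSharp.torch;

class SdxlContextSketch
{
    static void Main()
    {
        Tensor vitL = randn(new long[] { 1, 77, 768 });
        Tensor bigG = randn(new long[] { 1, 77, 1280 });
        Tensor context = cat(new Tensor[] { vitL, bigG }, -1);
        Console.WriteLine(string.Join(",", context.shape)); // 1,77,2048 == SDXL context_dim
    }
}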
dtype = null) : base(nameof(Model)) 535 | { 536 | model = new ViT_bigG_Clip(device: device, dtype: dtype); 537 | RegisterComponents(); 538 | } 539 | public override Tensor forward(Tensor token, int num_skip, bool with_final_ln, bool return_pooled) 540 | { 541 | return model.forward(token, num_skip, with_final_ln, return_pooled); 542 | } 543 | } 544 | } 545 | 546 | } 547 | } 548 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/Esrgan.cs: -------------------------------------------------------------------------------- 1 | using StableDiffusionSharp.ModelLoader; 2 | using TorchSharp; 3 | using TorchSharp.Modules; 4 | using static TorchSharp.torch; 5 | using static TorchSharp.torch.nn; 6 | 7 | namespace StableDiffusionSharp.Modules 8 | { 9 | public class Esrgan : IDisposable 10 | { 11 | private readonly RRDBNet rrdbnet; 12 | Device device; 13 | ScalarType dtype; 14 | 15 | public Esrgan(int num_block = 23, SDDeviceType deviceType = SDDeviceType.CUDA, SDScalarType scalarType = SDScalarType.Float16) 16 | { 17 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager(); 18 | device = new Device((DeviceType)deviceType); 19 | dtype = (ScalarType)scalarType; 20 | rrdbnet = new RRDBNet(num_in_ch: 3, num_out_ch: 3, num_feat: 64, num_block: num_block, num_grow_ch: 32, scale: 4, device: device, dtype: dtype); 21 | } 22 | 23 | /// 24 | /// Residual Dense Block. 25 | /// 26 | private class ResidualDenseBlock : Module 27 | { 28 | private readonly Conv2d conv1; 29 | private readonly Conv2d conv2; 30 | private readonly Conv2d conv3; 31 | private readonly Conv2d conv4; 32 | private readonly Conv2d conv5; 33 | private readonly LeakyReLU lrelu; 34 | 35 | /// 36 | /// Used in RRDB block in ESRGAN. 37 | /// 38 | /// Channel number of intermediate features. 39 | /// Channels for each growth. 40 | public ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(ResidualDenseBlock)) 41 | { 42 | conv1 = Conv2d(num_feat, num_grow_ch, 3, 1, 1, device: device, dtype: dtype); 43 | conv2 = Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype); 44 | conv3 = Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype); 45 | conv4 = Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1, device: device, dtype: dtype); 46 | conv5 = Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1, device: device, dtype: dtype); 47 | lrelu = LeakyReLU(negative_slope: 0.2f, inplace: true); 48 | RegisterComponents(); 49 | } 50 | 51 | public override Tensor forward(Tensor x) 52 | { 53 | using (NewDisposeScope()) 54 | { 55 | Tensor x1 = lrelu.forward(conv1.forward(x)); 56 | Tensor x2 = lrelu.forward(conv2.forward(cat(new Tensor[] { x, x1 }, 1))); 57 | Tensor x3 = lrelu.forward(conv3.forward(cat(new Tensor[] { x, x1, x2 }, 1))); 58 | Tensor x4 = lrelu.forward(conv4.forward(cat(new Tensor[] { x, x1, x2, x3 }, 1))); 59 | Tensor x5 = conv5.forward(cat(new Tensor[] { x, x1, x2, x3, x4 }, 1)); 60 | // Empirically, we use 0.2 to scale the residual for better performance 61 | return (x5 * 0.2 + x).MoveToOuterDisposeScope(); 62 | } 63 | } 64 | } 65 | 66 | /// 67 | /// Residual in Residual Dense Block. 68 | /// 69 | private class RRDB : Module 70 | { 71 | private readonly ResidualDenseBlock rdb1; 72 | private readonly ResidualDenseBlock rdb2; 73 | private readonly ResidualDenseBlock rdb3; 74 | 75 | /// 76 | /// Used in RRDB-Net in ESRGAN. 
77 | /// 78 | /// Channel number of intermediate features. 79 | /// Channels for each growth. 80 | public RRDB(int num_feat, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(RRDB)) 81 | { 82 | rdb1 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype); 83 | rdb2 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype); 84 | rdb3 = new ResidualDenseBlock(num_feat, num_grow_ch, device: device, dtype: dtype); 85 | RegisterComponents(); 86 | } 87 | public override Tensor forward(Tensor x) 88 | { 89 | using (NewDisposeScope()) 90 | { 91 | Tensor @out = rdb1.forward(x); 92 | @out = rdb2.forward(@out); 93 | @out = rdb3.forward(@out); 94 | // Empirically, we use 0.2 to scale the residual for better performance 95 | return (@out * 0.2 + x).MoveToOuterDisposeScope(); 96 | } 97 | } 98 | } 99 | 100 | private class RRDBNet : Module 101 | { 102 | private readonly int scale; 103 | private readonly Conv2d conv_first; 104 | private readonly Sequential body; 105 | private readonly Conv2d conv_body; 106 | private readonly Conv2d conv_up1; 107 | private readonly Conv2d conv_up2; 108 | private readonly Conv2d conv_hr; 109 | private readonly Conv2d conv_last; 110 | private readonly LeakyReLU lrelu; 111 | 112 | public RRDBNet(int num_in_ch, int num_out_ch, int scale = 4, int num_feat = 64, int num_block = 23, int num_grow_ch = 32, Device? device = null, ScalarType? dtype = null) : base(nameof(RRDBNet)) 113 | { 114 | this.scale = scale; 115 | if (scale == 2) 116 | { 117 | num_in_ch = num_in_ch * 4; 118 | } 119 | else if (scale == 1) 120 | { 121 | num_in_ch = num_in_ch * 16; 122 | } 123 | conv_first = Conv2d(num_in_ch, num_feat, 3, 1, 1, device: device, dtype: dtype); 124 | body = Sequential(); 125 | for (int i = 0; i < num_block; i++) 126 | { 127 | body.append(new RRDB(num_feat: num_feat, num_grow_ch: num_grow_ch, device: device, dtype: dtype)); 128 | } 129 | conv_body = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype); 130 | // upsample 131 | conv_up1 = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype); 132 | conv_up2 = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype); 133 | conv_hr = Conv2d(num_feat, num_feat, 3, 1, 1, device: device, dtype: dtype); 134 | conv_last = Conv2d(num_feat, num_out_ch, 3, 1, 1, device: device, dtype: dtype); 135 | lrelu = LeakyReLU(negative_slope: 0.2f, inplace: true); 136 | RegisterComponents(); 137 | } 138 | 139 | public override Tensor forward(Tensor x) 140 | { 141 | using (NewDisposeScope()) 142 | { 143 | Tensor feat = x; 144 | if (scale == 2) 145 | { 146 | feat = pixel_unshuffle(x, scale: 2); 147 | } 148 | else if (scale == 1) 149 | { 150 | feat = pixel_unshuffle(x, scale: 4); 151 | } 152 | feat = conv_first.forward(feat); 153 | Tensor body_feat = conv_body.forward(body.forward(feat)); 154 | feat = feat + body_feat; 155 | // upsample 156 | feat = lrelu.forward(conv_up1.forward(functional.interpolate(feat, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.Nearest))); 157 | feat = lrelu.forward(conv_up2.forward(functional.interpolate(feat, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.Nearest))); 158 | Tensor @out = conv_last.forward(lrelu.forward(conv_hr.forward(feat))); 159 | return @out.MoveToOuterDisposeScope(); 160 | } 161 | } 162 | 163 | /// 164 | /// Pixel unshuffle. 165 | /// 166 | /// Input feature with shape (b, c, hh, hw). 167 | /// Downsample ratio. 168 | /// the pixel unshuffled feature. 
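// ---- Annotation (not part of the repository): pixel_unshuffle (defined next) is the inverse
// of pixel shuffle — it folds space into channels, (b, c, H, W) -> (b, c*s^2, H/s, W/s), so
// the x4 RRDB backbone can also serve scale 1 and 2 by first shrinking the spatial grid.
// For s = 2, input pixel (ch, y, x) lands in output channel ch*4 + (y % 2)*2 + (x % 2).
// Shape sketch:

using System;
using static TorchSharp.torch;

class PixelUnshuffleSketch
{
    static void Main()
    {
        Tensor x = randn(new long[] { 1, 3, 8, 8 });
        Tensor y = x.view(1, 3, 4, 2, 4, 2).permute(0, 1, 3, 5, 2, 4).reshape(1, 12, 4, 4);
        Console.WriteLine(string.Join(",", y.shape)); // 1,12,4,4 — same data, 4x the channels
    }
}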
169 | private Tensor pixel_unshuffle(Tensor x, int scale) 170 | { 171 | long b = x.shape[0]; 172 | long c = x.shape[1]; 173 | long hh = x.shape[2]; 174 | long hw = x.shape[3]; 175 | 176 | long out_channel = c * (scale * scale); 177 | 178 | if (hh % scale != 0 || hw % scale != 0) 179 | { 180 | throw new ArgumentException("Height and width must both be divisible by the scale"); 181 | } 182 | 183 | long h = hh / scale; 184 | long w = hw / scale; 185 | 186 | Tensor x_view = x.view(b, c, h, scale, w, scale); 187 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w); 188 | } 189 | } 190 | 191 | public void LoadModel(string path) 192 | { 193 | rrdbnet.LoadModel(path); 194 | rrdbnet.eval(); 195 | } 196 | 197 | public ImageMagick.MagickImage UpScale(ImageMagick.MagickImage inputImg) 198 | { 199 | using (no_grad()) 200 | { 201 | Tensor tensor = Tools.GetTensorFromImage(inputImg); 202 | tensor = tensor.unsqueeze(0) / 255.0; 203 | tensor = tensor.to(dtype, device); 204 | Tensor op = rrdbnet.forward(tensor); 205 | op = (op.cpu() * 255.0f).clamp(0, 255).@byte(); 206 | return Tools.GetImageFromTensor(op); 207 | } 208 | } 209 | 210 | public void Dispose() 211 | { 212 | rrdbnet?.Dispose(); 213 | } 214 | } 215 | } 216 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/SD1.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionSharp.Modules 5 | { 6 | public class SD1 : SDModel 7 | { 8 | public SD1(Device? device = null, ScalarType? dtype = null) : base(device, dtype) 9 | { 10 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager(); 11 | this.device = device ?? torch.CPU; 12 | this.dtype = dtype ??
torch.float32; 13 | 14 | // Default parameters 15 | this.scale_factor = 0.18215f; 16 | 17 | // UNet config 18 | this.in_channels = 4; 19 | this.model_channels = 320; 20 | this.context_dim = 768; 21 | this.num_head = 8; 22 | this.dropout = 0.0f; 23 | this.embed_dim = 4; 24 | 25 | // first stage config: 26 | this.embed_dim = 4; 27 | this.double_z = true; 28 | this.z_channels = 4; 29 | 30 | } 31 | } 32 | 33 | } 34 | 35 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/SDModel.cs: -------------------------------------------------------------------------------- 1 | using StableDiffusionSharp.ModelLoader; 2 | using StableDiffusionSharp.Sampler; 3 | using System.Diagnostics; 4 | using System.Text; 5 | using TorchSharp; 6 | using static TorchSharp.torch; 7 | using static TorchSharp.torch.nn; 8 | 9 | namespace StableDiffusionSharp.Modules 10 | { 11 | public class SDModel : IDisposable 12 | { 13 | // Default parameters 14 | private float linear_start = 0.00085f; 15 | private float linear_end = 0.0120f; 16 | private int num_timesteps_cond = 1; 17 | private int timesteps = 1000; 18 | internal float scale_factor = 0.18215f; 19 | internal int adm_in_channels = 2816; 20 | 21 | // UNet config 22 | internal int in_channels = 4; 23 | internal int model_channels = 320; 24 | internal int context_dim = 768; 25 | internal int num_head = 8; 26 | internal float dropout = 0.0f; 27 | 28 | // first stage config: 29 | internal int embed_dim = 4; 30 | internal bool double_z = true; 31 | internal int z_channels = 4; 32 | 33 | public class StepEventArgs : EventArgs 34 | { 35 | public int CurrentStep { get; } 36 | public int TotalSteps { get; } 37 | public ImageMagick.MagickImage VAEApproxImg { get; } 38 | 39 | public StepEventArgs(int currentStep, int totalSteps, ImageMagick.MagickImage vAEApproxImg) 40 | { 41 | CurrentStep = currentStep; 42 | TotalSteps = totalSteps; 43 | VAEApproxImg = vAEApproxImg; 44 | } 45 | } 46 | 47 | public event EventHandler StepProgress; 48 | protected void OnStepProgress(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg) 49 | { 50 | StepProgress?.Invoke(this, new StepEventArgs(currentStep, totalSteps, vaeApproxImg)); 51 | } 52 | 53 | internal Module cliper; 54 | internal Module diffusion; 55 | private VAE.Decoder decoder; 56 | private VAE.Encoder encoder; 57 | private Tokenizer tokenizer; 58 | private VAEApprox vaeApprox; 59 | 60 | internal Device device; 61 | internal ScalarType dtype; 62 | 63 | private int tempPromptHash; 64 | private Tensor tempTextContext; 65 | private Tensor tempPooled; 66 | 67 | bool is_loaded = false; 68 | 69 | public SDModel(Device? device = null, ScalarType? dtype = null) 70 | { 71 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager(); 72 | this.device = device ?? torch.CPU; 73 | this.dtype = dtype ?? 
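// ---- Annotation (not part of the repository): linear_start/linear_end above parameterize
// the "scaled linear" beta schedule conventional for SD-class models — betas are the squares
// of a linspace between the square roots of the two bounds over 1000 steps. DiscreteSchedule's
// exact code isn't shown here, so treat this as the assumed convention:

using System;
using static TorchSharp.torch;

class BetaScheduleSketch
{
    static void Main()
    {
        Tensor betas = linspace(Math.Sqrt(0.00085), Math.Sqrt(0.0120), 1000).pow(2);
        Tensor alphasCumprod = (1.0f - betas).cumprod(0);
        Tensor sigmas = ((1.0f - alphasCumprod) / alphasCumprod).sqrt(); // noise levels the samplers descend
        Console.WriteLine($"{sigmas.min().item<float>():f4} .. {sigmas.max().item<float>():f4}");
    }
}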
torch.float32; 74 | } 75 | 76 | public virtual void LoadModel(string modelPath, string vaeModelPath, string vocabPath = @".\models\clip\vocab.json", string mergesPath = @".\models\clip\merges.txt") 77 | { 78 | is_loaded = false; 79 | ModelType modelType = ModelLoader.ModelLoader.GetModelType(modelPath); 80 | 81 | cliper = modelType switch 82 | { 83 | ModelType.SD1 => new Clip.SDCliper(device: device, dtype: dtype), 84 | ModelType.SDXL => new Clip.SDXLCliper(device: device, dtype: dtype), 85 | _ => throw new ArgumentException("Invalid model type") 86 | }; 87 | cliper.eval(); 88 | 89 | diffusion = modelType switch 90 | { 91 | ModelType.SD1 => new SDUnet(model_channels, in_channels, num_head, context_dim, dropout, device: device, dtype: dtype), 92 | ModelType.SDXL => new SDXLUnet(model_channels, in_channels, num_head, context_dim, adm_in_channels, dropout, device: device, dtype: dtype), 93 | _ => throw new ArgumentException("Invalid model type") 94 | }; 95 | diffusion.eval(); 96 | 97 | decoder = new VAE.Decoder(embed_dim: embed_dim, z_channels: z_channels, device: device, dtype: dtype); 98 | decoder.eval(); 99 | encoder = new VAE.Encoder(embed_dim: embed_dim, z_channels: z_channels, double_z: double_z, device: device, dtype: dtype); 100 | encoder.eval(); 101 | 102 | vaeApprox = new VAEApprox(4, device, dtype); 103 | vaeApprox.eval(); 104 | 105 | vaeModelPath = string.IsNullOrEmpty(vaeModelPath) ? modelPath : vaeModelPath; 106 | 107 | cliper.LoadModel(modelPath); 108 | diffusion.LoadModel(modelPath); 109 | decoder.LoadModel(vaeModelPath, "first_stage_model."); 110 | encoder.LoadModel(vaeModelPath, "first_stage_model."); 111 | 112 | string vaeApproxPath = modelType switch 113 | { 114 | ModelType.SD1 => @".\Models\VAEApprox\vaeapp_sd15.pth", 115 | ModelType.SDXL => @".\Models\VAEApprox\xlvaeapp.pth", 116 | _ => throw new ArgumentException("Invalid model type") 117 | }; 118 | 119 | vaeApprox.LoadModel(vaeApproxPath); 120 | 121 | tokenizer = new Tokenizer(vocabPath, mergesPath); 122 | is_loaded = true; 123 | 124 | GC.Collect(); 125 | } 126 | 127 | private void CheckModelLoaded() 128 | { 129 | if (!is_loaded) 130 | { 131 | throw new InvalidOperationException("Model not loaded"); 132 | } 133 | } 134 | 135 | private static Tensor GetTimeEmbedding(Tensor timestep, int max_period = 10000, int dim = 320, bool repeat_only = false) 136 | { 137 | if (repeat_only) 138 | { 139 | return torch.repeat_interleave(timestep, dim); 140 | } 141 | else 142 | { 143 | int half = dim / 2; 144 | var freqs = torch.pow(max_period, -torch.arange(0, half, dtype: torch.float32) / half); 145 | var x = timestep * freqs.unsqueeze(0); 146 | x = torch.cat(new Tensor[] { x, x }); 147 | return torch.cat(new Tensor[] { torch.cos(x), torch.sin(x) }, dim: -1); 148 | } 149 | } 150 | 151 | private (Tensor, Tensor) Clip(string prompt, string nprompt, long clip_skip) 152 | { 153 | CheckModelLoaded(); 154 | if (tempPromptHash != (prompt + nprompt).GetHashCode()) 155 | { 156 | using (no_grad()) 157 | using (NewDisposeScope()) 158 | { 159 | Tensor cond_tokens = tokenizer.Tokenize(prompt).to(device); 160 | (Tensor cond_context, Tensor cond_pooled) = cliper.forward(cond_tokens, clip_skip); 161 | Tensor uncond_tokens = tokenizer.Tokenize(nprompt).to(device); 162 | (Tensor uncond_context, Tensor uncond_pooled) = cliper.forward(uncond_tokens, clip_skip); 163 | Tensor context = cat(new Tensor[] { cond_context, uncond_context }); 164 | tempPromptHash = (prompt + nprompt).GetHashCode(); 165 | tempTextContext = context; 166 | tempPooled = cat(new 
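// ---- Annotation (not part of the repository): GetTimeEmbedding above is the standard
// sinusoidal timestep embedding, e(t) = [cos(t*w_i), sin(t*w_i)] with w_i = 10000^(-i/half)
// and half = 160, giving a 320-wide vector. The cat(x, x) duplication doubles the batch to 2
// so a single UNet pass serves both the conditional and unconditional halves used for CFG
// in the sampling loop below. Sketch:

using System;
using static TorchSharp.torch;

class TimeEmbeddingSketch
{
    static void Main()
    {
        int half = 160;                                           // dim 320 / 2
        Tensor freqs = pow(10000, -arange(0, half, dtype: float32) / half);
        Tensor x = tensor(981.0f) * freqs.unsqueeze(0);           // one timestep -> (1, 160)
        x = cat(new Tensor[] { x, x });                           // (2, 160): cond + uncond rows
        Tensor emb = cat(new Tensor[] { cos(x), sin(x) }, dim: -1);
        Console.WriteLine(string.Join(",", emb.shape));           // 2,320
    }
}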
Tensor[] { cond_pooled, uncond_pooled }); 167 | tempTextContext = tempTextContext.MoveToOuterDisposeScope(); 168 | tempPooled = tempPooled.MoveToOuterDisposeScope(); 169 | } 170 | } 171 | return (tempTextContext, tempPooled); 172 | } 173 | 174 | /// 175 | /// Generate image from text 176 | /// 177 | /// Prompt 178 | /// Negtive Prompt 179 | /// Image width, must be multiples of 64, otherwise, it will be resized 180 | /// Image width, must be multiples of 64, otherwise, it will be resized 181 | /// Step to generate image 182 | /// Random seed for generating image, it will get random when the value is 0 183 | /// Classifier Free Guidance 184 | public virtual ImageMagick.MagickImage TextToImage(string prompt, string nprompt = "", long clip_skip = 0, int width = 512, int height = 512, int steps = 20, long seed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler) 185 | { 186 | CheckModelLoaded(); 187 | 188 | using (no_grad()) 189 | { 190 | if (steps < 1) 191 | { 192 | throw new ArgumentException("steps must be greater than 0"); 193 | } 194 | if (cfg < 0.5) 195 | { 196 | throw new ArgumentException("cfg is too small, it may cause the image to be too noisy"); 197 | } 198 | 199 | seed = seed == 0 ? Random.Shared.NextInt64() : seed; 200 | set_rng_state(manual_seed(seed).get_state()); 201 | 202 | width = width / 64 * 8; // must be multiples of 64 203 | height = height / 64 * 8; // must be multiples of 64 204 | Console.WriteLine("Device:" + device); 205 | Console.WriteLine("Type:" + dtype); 206 | Console.WriteLine("CFG:" + cfg); 207 | Console.WriteLine("Seed:" + seed); 208 | Console.WriteLine("Width:" + width * 8); 209 | Console.WriteLine("Height:" + height * 8); 210 | 211 | Stopwatch sp = Stopwatch.StartNew(); 212 | Console.WriteLine("Clip is doing......"); 213 | (Tensor context, Tensor vector) = Clip(prompt, nprompt, clip_skip); 214 | using var _ = NewDisposeScope(); 215 | Console.WriteLine("Getting latents......"); 216 | Tensor latents = randn(new long[] { 1, 4, height, width }).to(dtype, device); 217 | 218 | BasicSampler sampler = samplerType switch 219 | { 220 | SDSamplerType.Euler => new EulerSampler(timesteps, linear_start, linear_end, num_timesteps_cond), 221 | SDSamplerType.EulerAncestral => new EulerAncestralSampler(timesteps, linear_start, linear_end, num_timesteps_cond), 222 | _ => throw new ArgumentException("Unknown sampler type") 223 | }; 224 | 225 | sampler.SetTimesteps(steps); 226 | latents *= sampler.InitNoiseSigma(); 227 | 228 | Console.WriteLine($"begin sampling"); 229 | for (int i = 0; i < steps; i++) 230 | { 231 | Tensor approxTensor = vaeApprox.forward(latents); 232 | approxTensor = approxTensor * 127.5 + 127.5; 233 | approxTensor = approxTensor.clamp(0, 255).@byte().cpu(); 234 | ImageMagick.MagickImage approxImg = Tools.GetImageFromTensor(approxTensor); 235 | OnStepProgress(i + 1, steps, approxImg); 236 | Tensor timestep = sampler.Timesteps[i]; 237 | Tensor time_embedding = GetTimeEmbedding(timestep); 238 | Tensor input_latents = sampler.ScaleModelInput(latents, i); 239 | input_latents = input_latents.repeat(2, 1, 1, 1); 240 | Tensor output = diffusion.forward(input_latents, context, time_embedding, vector); 241 | Tensor[] ret = output.chunk(2); 242 | Tensor output_cond = ret[0]; 243 | Tensor output_uncond = ret[1]; 244 | output = cfg * (output_cond - output_uncond) + output_uncond; 245 | latents = sampler.Step(output, i, latents, seed); 246 | } 247 | Console.WriteLine($"end sampling"); 248 | Console.WriteLine($"begin decoder"); 249 | latents = latents / 
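// ---- Annotation (not part of the repository): the chunk/recombine step in the loop above is
// classifier-free guidance. The latents were repeated to batch 2, so one UNet call predicts
// both the conditional and unconditional noise; the guided prediction is then
// uncond + cfg * (cond - uncond), algebraically identical to the line above. cfg = 1 turns
// guidance off; larger values pull the sample harder toward the prompt. Sketch:

using System;
using static TorchSharp.torch;

class CfgSketch
{
    static void Main()
    {
        Tensor pred = randn(new long[] { 2, 4, 64, 64 });  // rows: [conditional, unconditional]
        Tensor[] halves = pred.chunk(2);
        Tensor cond = halves[0], uncond = halves[1];
        float cfg = 7.0f;
        Tensor guided = cfg * (cond - uncond) + uncond;    // same form as in TextToImage
        Console.WriteLine(string.Join(",", guided.shape)); // 1,4,64,64
    }
}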
scale_factor; 250 | Tensor image = decoder.forward(latents); 251 | Console.WriteLine($"end decoder"); 252 | 253 | 254 | image = ((image + 0.5) * 255.0f).clamp(0, 255).@byte().cpu(); 255 | 256 | ImageMagick.MagickImage img = Tools.GetImageFromTensor(image); 257 | 258 | StringBuilder stringBuilder = new StringBuilder(); 259 | stringBuilder.AppendLine(prompt); 260 | if (!string.IsNullOrEmpty(nprompt)) 261 | { 262 | stringBuilder.AppendLine("Negative prompt: " + nprompt); 263 | } 264 | stringBuilder.AppendLine($"Steps: {steps}, CFG scale_factor: {cfg}, Seed: {seed}, Size: {width}x{height}, Version: StableDiffusionSharp"); 265 | img.SetAttribute("parameters", stringBuilder.ToString()); 266 | sp.Stop(); 267 | Console.WriteLine($"Total time is: {sp.ElapsedMilliseconds} ms."); 268 | return img; 269 | } 270 | } 271 | 272 | 273 | public virtual ImageMagick.MagickImage ImageToImage(ImageMagick.MagickImage orgImage, string prompt, string nprompt = "", long clip_skip = 0, int steps = 20, float strength = 0.75f, long seed = 0, long subSeed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler) 274 | { 275 | CheckModelLoaded(); 276 | 277 | using (no_grad()) 278 | { 279 | Stopwatch sp = Stopwatch.StartNew(); 280 | seed = seed == 0 ? Random.Shared.NextInt64() : seed; 281 | Generator generator = manual_seed(seed); 282 | set_rng_state(generator.get_state()); 283 | 284 | Console.WriteLine("Clip is doing......"); 285 | (Tensor context, Tensor vector) = Clip(prompt, nprompt, clip_skip); 286 | 287 | Console.WriteLine("Getting latents......"); 288 | Tensor inputTensor = Tools.GetTensorFromImage(orgImage).unsqueeze(0); 289 | inputTensor = inputTensor.to(dtype, device); 290 | inputTensor = inputTensor / 255.0f * 2 - 1.0f; 291 | Tensor lt = encoder.forward(inputTensor); 292 | 293 | Tensor[] mean_var = lt.chunk(2, 1); 294 | Tensor mean = mean_var[0]; 295 | Tensor logvar = mean_var[1].clamp(-30, 20); 296 | Tensor std = exp(0.5f * logvar); 297 | Tensor latents = mean + std * randn_like(mean); 298 | 299 | latents = latents * scale_factor; 300 | int t_enc = (int)(strength * steps) - 1; 301 | 302 | BasicSampler sampler = samplerType switch 303 | { 304 | SDSamplerType.Euler => new EulerSampler(timesteps, linear_start, linear_end, num_timesteps_cond), 305 | SDSamplerType.EulerAncestral => new EulerAncestralSampler(timesteps, linear_start, linear_end, num_timesteps_cond), 306 | _ => throw new ArgumentException("Unknown sampler type") 307 | }; 308 | 309 | sampler.SetTimesteps(steps); 310 | Tensor sigma_sched = sampler.Sigmas[(steps - t_enc - 1)..]; 311 | Tensor noise = randn_like(latents); 312 | latents = latents + noise * sigma_sched.max(); 313 | 314 | Console.WriteLine($"begin sampling"); 315 | for (int i = 0; i < sigma_sched.NumberOfElements - 1; i++) 316 | { 317 | Tensor approxTensor = vaeApprox.forward(latents); 318 | approxTensor = approxTensor * 127.5 + 127.5; 319 | approxTensor = approxTensor.clamp(0, 255).@byte().cpu(); 320 | ImageMagick.MagickImage approxImg = Tools.GetImageFromTensor(approxTensor); 321 | OnStepProgress(i + 1, steps, approxImg); 322 | 323 | int index = steps - t_enc + i - 1; 324 | Tensor timestep = sampler.Timesteps[index]; 325 | Tensor time_embedding = GetTimeEmbedding(timestep); 326 | Tensor input_latents = sampler.ScaleModelInput(latents, index); 327 | input_latents = input_latents.repeat(2, 1, 1, 1); 328 | Tensor output = diffusion.forward(input_latents, context, time_embedding, vector); 329 | Tensor[] ret = output.chunk(2); 330 | Tensor output_cond = ret[0]; 331 | Tensor 
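// ---- Annotation (not part of the repository): two details of ImageToImage above. First, the
// encoder output is split into mean/logvar halves and the latent is drawn with the VAE
// reparameterization, z = mean + exp(0.5 * logvar) * eps. Second, strength decides how far up
// the noise ladder to re-enter: with t_enc = (int)(strength * steps) - 1, roughly
// strength * steps denoising iterations remain — strength near 1.0 behaves like txt2img,
// near 0.0 it mostly preserves the input. Sketch of that arithmetic:

using System;

class StrengthSketch
{
    static void Main()
    {
        int steps = 20;
        foreach (float strength in new[] { 0.25f, 0.5f, 0.75f, 1.0f })
        {
            int tEnc = (int)(strength * steps) - 1;   // as in ImageToImage above
            Console.WriteLine($"strength {strength}: enter the schedule at index {steps - tEnc - 1}, ~{tEnc + 1} steps left");
        }
    }
}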
output_uncond = ret[1]; 332 | Tensor noisePred = cfg * (output_cond - output_uncond) + output_uncond; 333 | latents = sampler.Step(noisePred, index, latents, seed); 334 | } 335 | Console.WriteLine($"end sampling"); 336 | Console.WriteLine($"begin decoder"); 337 | latents = latents / scale_factor; 338 | Tensor image = decoder.forward(latents); 339 | Console.WriteLine($"end decoder"); 340 | 341 | sp.Stop(); 342 | Console.WriteLine($"Total time is: {sp.ElapsedMilliseconds} ms."); 343 | image = ((image + 0.5) * 255.0f).clamp(0, 255).@byte().cpu(); 344 | 345 | ImageMagick.MagickImage img = Tools.GetImageFromTensor(image); 346 | 347 | StringBuilder stringBuilder = new StringBuilder(); 348 | stringBuilder.AppendLine(prompt); 349 | if (!string.IsNullOrEmpty(nprompt)) 350 | { 351 | stringBuilder.AppendLine("Negative prompt: " + nprompt); 352 | } 353 | stringBuilder.AppendLine($"Steps: {steps}, CFG scale_factor: {cfg}, Seed: {seed}, Size: {img.Width}x{img.Height}, Version: StableDiffusionSharp"); 354 | img.SetAttribute("parameters", stringBuilder.ToString()); 355 | return img; 356 | } 357 | } 358 | 359 | public void Dispose() 360 | { 361 | cliper?.Dispose(); 362 | diffusion?.Dispose(); 363 | decoder?.Dispose(); 364 | encoder?.Dispose(); 365 | tempTextContext?.Dispose(); tempPooled?.Dispose(); // both cached CLIP outputs are fields and need disposing 366 | } 367 | 368 | } 369 | 370 | } 371 | 372 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/SDXL.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionSharp.Modules 5 | { 6 | public class SDXL : SD1 7 | { 8 | public SDXL(Device? device = null, ScalarType? dtype = null) : base(device, dtype) 9 | { 10 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager(); 11 | this.device = device ?? torch.CPU; 12 | this.dtype = dtype ?? torch.float32; 13 | 14 | this.scale_factor = 0.13025f; 15 | 16 | this.in_channels = 4; 17 | this.model_channels = 320; 18 | this.context_dim = 2048; 19 | this.num_head = 20; 20 | this.dropout = 0.0f; 21 | this.adm_in_channels = 2816; 22 | 23 | this.embed_dim = 4; 24 | this.double_z = true; 25 | this.z_channels = 4; 26 | } 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/Tokenizer.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.Tokenizers; 2 | using System.Reflection; 3 | using static TorchSharp.torch; 4 | 5 | namespace StableDiffusionSharp.Modules 6 | { 7 | internal class Tokenizer 8 | { 9 | private readonly BpeTokenizer _tokenizer; 10 | private readonly int _startToken; 11 | private readonly int _endToken; 12 | 13 | public Tokenizer(string vocabPath, string mergesPath, int startToken = 49406, int endToken = 49407) 14 | { 15 | if (!File.Exists(vocabPath)) 16 | { 17 | string path = Path.GetDirectoryName(vocabPath)!; 18 | if (!Directory.Exists(path)) 19 | { 20 | Directory.CreateDirectory(path); 21 | } 22 | Assembly _assembly = Assembly.GetExecutingAssembly(); 23 | string resourceName = "StableDiffusionSharp.Models.Clip.vocab.json"; 24 | using (Stream stream = _assembly.GetManifestResourceStream(resourceName)!)
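// First-run bootstrap: vocab.json ships as an embedded resource and is
// copied next to the executable when the file is missing, so no download
// is required; the same pattern repeats for merges.txt below.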
25 | { 26 | if (stream == null) 27 | { 28 | Console.WriteLine("Resource could not be found!"); 29 | return; 30 | } 31 | using (FileStream fileStream = new FileStream(vocabPath, FileMode.Create, FileAccess.Write)) 32 | { 33 | stream.CopyTo(fileStream); 34 | } 35 | } 36 | 37 | } 38 | 39 | if (!File.Exists(mergesPath)) 40 | { 41 | string path = Path.GetDirectoryName(mergesPath)!; 42 | if (!Directory.Exists(path)) 43 | { 44 | Directory.CreateDirectory(path); 45 | } 46 | Assembly _assembly = Assembly.GetExecutingAssembly(); 47 | string resourceName = "StableDiffusionSharp.Models.Clip.merges.txt"; 48 | using (Stream stream = _assembly.GetManifestResourceStream(resourceName)!) 49 | { 50 | if (stream == null) 51 | { 52 | Console.WriteLine("Resource could not be found!"); 53 | return; 54 | } 55 | using (FileStream fileStream = new FileStream(mergesPath, FileMode.Create, FileAccess.Write)) 56 | { 57 | stream.CopyTo(fileStream); 58 | } 59 | } 60 | 61 | } 62 | 63 | _tokenizer = BpeTokenizer.Create(vocabPath, mergesPath, endOfWordSuffix: "</w>"); // CLIP's BPE marks word ends with "</w>" 64 | _startToken = startToken; 65 | _endToken = endToken; 66 | } 67 | 68 | public Tensor Tokenize(string text, int maxTokens = 77) 69 | { 70 | var res = _tokenizer.EncodeToIds(text).ToList(); 71 | res.Insert(0, _startToken); 72 | res.Add(_endToken); 73 | return tensor(res, ScalarType.Int64).unsqueeze(0); 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/Unet.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using TorchSharp; 3 | using TorchSharp.Modules; 4 | using static Tensorboard.CostGraphDef.Types; 5 | using static Tensorboard.TensorShapeProto.Types; 6 | using static TorchSharp.torch; 7 | using static TorchSharp.torch.nn; 8 | 9 | namespace StableDiffusionSharp.Modules 10 | { 11 | internal class CrossAttention : Module<Tensor, Tensor, Tensor> 12 | { 13 | private readonly Linear to_q; 14 | private readonly Linear to_k; 15 | private readonly Linear to_v; 16 | private readonly Sequential to_out; 17 | private readonly long n_heads_; 18 | private readonly long d_head; 19 | private readonly bool causal_mask_; 20 | 21 | public CrossAttention(long channels, long d_cross, long n_heads, bool causal_mask = false, bool in_proj_bias = false, bool out_proj_bias = true, float dropout_p = 0.0f, Device? device = null, ScalarType?
dtype = null) : base(nameof(CrossAttention)) 22 | { 23 | to_q = Linear(channels, channels, hasBias: in_proj_bias, device: device, dtype: dtype); 24 | to_k = Linear(d_cross, channels, hasBias: in_proj_bias, device: device, dtype: dtype); 25 | to_v = Linear(d_cross, channels, hasBias: in_proj_bias, device: device, dtype: dtype); 26 | to_out = Sequential(Linear(channels, channels, hasBias: out_proj_bias, device: device, dtype: dtype), Dropout(dropout_p, inplace: false)); 27 | n_heads_ = n_heads; 28 | d_head = channels / n_heads; 29 | causal_mask_ = causal_mask; 30 | RegisterComponents(); 31 | } 32 | 33 | public override Tensor forward(Tensor x, Tensor y) 34 | { 35 | using (NewDisposeScope()) 36 | { 37 | long[] input_shape = x.shape; 38 | long batch_size = input_shape[0]; 39 | long sequence_length = input_shape[1]; 40 | 41 | long[] interim_shape = new long[] { batch_size, -1, n_heads_, d_head }; 42 | Tensor q = to_q.forward(x); 43 | Tensor k = to_k.forward(y); 44 | Tensor v = to_v.forward(y); 45 | 46 | q = q.view(interim_shape).transpose(1, 2); 47 | k = k.view(interim_shape).transpose(1, 2); 48 | v = v.view(interim_shape).transpose(1, 2); 49 | Tensor output = functional.scaled_dot_product_attention(q, k, v, is_casual: causal_mask_); 50 | output = output.transpose(1, 2).reshape(input_shape); 51 | output = to_out.forward(output); 52 | return output.MoveToOuterDisposeScope(); 53 | } 54 | } 55 | } 56 | 57 | internal class ResnetBlock : Module 58 | { 59 | private readonly int in_channels; 60 | private readonly int out_channels; 61 | 62 | private readonly Module skip_connection; 63 | private readonly Sequential emb_layers; 64 | private readonly Sequential in_layers; 65 | private readonly Sequential out_layers; 66 | 67 | public ResnetBlock(int in_channels, int out_channels, double dropout = 0.0, int temb_channels = 1280, Device? device = null, ScalarType? dtype = null) : base(nameof(ResnetBlock)) 68 | { 69 | this.in_channels = in_channels; 70 | out_channels = out_channels < 1 ? 
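// A non-positive out_channels means "same as in_channels". The block built
// below is GroupNorm -> SiLU -> Conv3x3 twice, with the time embedding
// broadcast-added between the two stacks and a 1x1 conv on the skip path
// whenever the channel count changes.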
in_channels : out_channels; 71 | this.out_channels = out_channels; 72 | 73 | in_layers = Sequential(GroupNorm(32, in_channels, device: device, dtype: dtype), SiLU(), Conv2d(in_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype)); 74 | 75 | if (temb_channels > 0) 76 | { 77 | emb_layers = Sequential(SiLU(), Linear(temb_channels, out_channels, device: device, dtype: dtype)); 78 | } 79 | 80 | out_layers = Sequential(GroupNorm(32, out_channels, device: device, dtype: dtype), SiLU(), Dropout(dropout), Conv2d(out_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype)); 81 | 82 | if (this.in_channels != this.out_channels) 83 | { 84 | skip_connection = Conv2d(in_channels: in_channels, out_channels: this.out_channels, kernel_size: 1, stride: 1, device: device, dtype: dtype); 85 | } 86 | else 87 | { 88 | skip_connection = Identity(); 89 | } 90 | 91 | RegisterComponents(); 92 | } 93 | 94 | public override Tensor forward(Tensor x, Tensor time) 95 | { 96 | using (NewDisposeScope()) 97 | { 98 | Tensor hidden = x; 99 | hidden = in_layers.forward(hidden); 100 | 101 | if (time is not null) 102 | { 103 | time = emb_layers.forward(time); 104 | hidden = hidden + time.unsqueeze(-1).unsqueeze(-1); 105 | } 106 | 107 | hidden = out_layers.forward(hidden); 108 | if (in_channels != out_channels) 109 | { 110 | x = skip_connection.forward(x); 111 | } 112 | return (x + hidden).MoveToOuterDisposeScope(); 113 | } 114 | } 115 | } 116 | 117 | internal class TransformerBlock : Module<Tensor, Tensor, Tensor> 118 | { 119 | private LayerNorm norm1; 120 | private CrossAttention attn1; 121 | private LayerNorm norm2; 122 | private CrossAttention attn2; 123 | private LayerNorm norm3; 124 | private FeedForward ff; 125 | 126 | public TransformerBlock(int channels, int n_cross, int n_head, Device? device = null, ScalarType? dtype = null) : base(nameof(TransformerBlock)) 127 | { 128 | norm1 = LayerNorm(channels, device: device, dtype: dtype); 129 | attn1 = new CrossAttention(channels, channels, n_head, device: device, dtype: dtype); 130 | norm2 = LayerNorm(channels, device: device, dtype: dtype); 131 | attn2 = new CrossAttention(channels, n_cross, n_head, device: device, dtype: dtype); 132 | norm3 = LayerNorm(channels, device: device, dtype: dtype); 133 | ff = new FeedForward(channels, glu: true, device: device, dtype: dtype); 134 | RegisterComponents(); 135 | } 136 | public override Tensor forward(Tensor x, Tensor context) 137 | { 138 | var residue_short = x; 139 | x = norm1.forward(x); 140 | x = attn1.forward(x, x); 141 | x += residue_short; 142 | residue_short = x; 143 | x = norm2.forward(x); 144 | x = attn2.forward(x, context); 145 | x += residue_short; 146 | residue_short = x; 147 | x = norm3.forward(x); 148 | x = ff.forward(x); 149 | x += residue_short; 150 | return x.MoveToOuterDisposeScope(); 151 | } 152 | } 153 | 154 | internal class SpatialTransformer : Module<Tensor, Tensor, Tensor> 155 | { 156 | private readonly Module<Tensor, Tensor> proj_in; 157 | private readonly GroupNorm norm; 158 | private readonly Module<Tensor, Tensor> proj_out; 159 | private readonly ModuleList<Module<Tensor, Tensor, Tensor>> transformer_blocks; 160 | private readonly bool use_linear; 161 | 162 | public SpatialTransformer(int channels, int n_cross, int n_head, int num_atten_blocks, float drop_out = 0.0f, bool use_linear = false, Device? device = null, ScalarType? dtype = null) : base(nameof(SpatialTransformer)) 163 | { 164 | norm = Normalize(channels, device: device, dtype: dtype); 165 | this.use_linear = use_linear; 166 | proj_in = use_linear ?
Linear(channels, channels, device: device, dtype: dtype) : Conv2d(channels, channels, kernel_size: 1, device: device, dtype: dtype); 167 | proj_out = use_linear ? Linear(channels, channels, device: device, dtype: dtype) : Conv2d(channels, channels, kernel_size: 1, device: device, dtype: dtype); 168 | transformer_blocks = new ModuleList<Module<Tensor, Tensor, Tensor>>(); 169 | for (int i = 0; i < num_atten_blocks; i++) 170 | { 171 | transformer_blocks.Add(new TransformerBlock(channels, n_cross, n_head, device: device, dtype: dtype)); 172 | } 173 | RegisterComponents(); 174 | } 175 | 176 | public override Tensor forward(Tensor x, Tensor context) 177 | { 178 | using (NewDisposeScope()) 179 | { 180 | long n = x.shape[0]; 181 | long c = x.shape[1]; 182 | long h = x.shape[2]; 183 | long w = x.shape[3]; 184 | 185 | Tensor residue_short = x; 186 | x = norm.forward(x); 187 | 188 | if (!use_linear) 189 | { 190 | x = proj_in.forward(x); 191 | } 192 | 193 | x = x.view(new long[] { n, c, h * w }); 194 | x = x.transpose(-1, -2); 195 | 196 | if (use_linear) 197 | { 198 | x = proj_in.forward(x); 199 | } 200 | 201 | foreach (Module<Tensor, Tensor, Tensor> layer in transformer_blocks) 202 | { 203 | x = layer.forward(x, context); 204 | } 205 | 206 | if (use_linear) 207 | { 208 | x = proj_out.forward(x); 209 | } 210 | x = x.transpose(-1, -2); 211 | x = x.view(new long[] { n, c, h, w }); 212 | if (!use_linear) 213 | { 214 | x = proj_out.forward(x); 215 | } 216 | 217 | residue_short = residue_short + x; 218 | return residue_short.MoveToOuterDisposeScope(); 219 | } 220 | } 221 | 222 | private static GroupNorm Normalize(int in_channels, int num_groups = 32, float eps = 1e-6f, bool affine = true, Device? device = null, ScalarType? dtype = null) 223 | { 224 | return GroupNorm(num_groups: num_groups, num_channels: in_channels, eps: eps, affine: affine, device: device, dtype: dtype); 225 | } 226 | 227 | } 228 | 229 | internal class Upsample : Module<Tensor, Tensor> 230 | { 231 | private readonly Conv2d? conv; 232 | private readonly bool with_conv; 233 | public Upsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Upsample)) 234 | { 235 | this.with_conv = with_conv; 236 | if (with_conv) 237 | { 238 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype); 239 | } 240 | RegisterComponents(); 241 | } 242 | public override Tensor forward(Tensor x) 243 | { 244 | var output = functional.interpolate(x, scale_factor: new double[] { 2.0, 2.0 }, mode: InterpolationMode.Nearest); 245 | if (with_conv && conv is not null) 246 | { 247 | output = conv.forward(output); 248 | } 249 | return output; 250 | } 251 | } 252 | 253 | internal class Downsample : Module<Tensor, Tensor> 254 | { 255 | private readonly Conv2d op; 256 | public Downsample(int in_channels, Device? device = null, ScalarType?
dtype = null) : base(nameof(Downsample)) 257 | { 258 | op = Conv2d(in_channels: in_channels, out_channels: in_channels, kernel_size: 3, stride: 2, padding: 1, device: device, dtype: dtype); 259 | RegisterComponents(); 260 | } 261 | public override Tensor forward(Tensor x) 262 | { 263 | x = op.forward(x); 264 | return x; 265 | } 266 | } 267 | 268 | internal class TimestepEmbedSequential : Sequential 269 | { 270 | internal TimestepEmbedSequential(params (string name, Module)[] modules) : base(modules) 271 | { 272 | RegisterComponents(); 273 | } 274 | 275 | internal TimestepEmbedSequential(params Module[] modules) : base(modules) 276 | { 277 | RegisterComponents(); 278 | } 279 | 280 | public override Tensor forward(Tensor x, Tensor context, Tensor time) 281 | { 282 | using (NewDisposeScope()) 283 | { 284 | foreach (var layer in children()) 285 | { 286 | switch (layer) 287 | { 288 | case ResnetBlock res: 289 | x = res.call(x, time); 290 | break; 291 | case SpatialTransformer abl: 292 | x = abl.call(x, context); 293 | break; 294 | case Module m: 295 | x = m.call(x); 296 | break; 297 | } 298 | } 299 | return x.MoveToOuterDisposeScope(); 300 | } 301 | } 302 | } 303 | 304 | internal class GEGLU : Module 305 | { 306 | private readonly Linear proj; 307 | public GEGLU(int dim_in, int dim_out, Device? device = null, ScalarType? dtype = null) : base(nameof(GEGLU)) 308 | { 309 | proj = Linear(dim_in, dim_out * 2, device: device, dtype: dtype); 310 | RegisterComponents(); 311 | } 312 | 313 | public override Tensor forward(Tensor x) 314 | { 315 | using (NewDisposeScope()) 316 | { 317 | Tensor[] result = proj.forward(x).chunk(2, dim: -1); 318 | x = result[0]; 319 | Tensor gate = result[1]; 320 | return (x * functional.gelu(gate)).MoveToOuterDisposeScope(); 321 | } 322 | } 323 | } 324 | 325 | internal class FeedForward : Module 326 | { 327 | private readonly Sequential net; 328 | 329 | public FeedForward(int dim, int? dim_out = null, int mult = 4, bool glu = true, float dropout = 0.0f, Device? device = null, ScalarType? dtype = null) : base(nameof(FeedForward)) 330 | { 331 | int inner_dim = dim * mult; 332 | int dim_ot = dim_out ?? dim; 333 | Module project_in = glu ? new GEGLU(dim, inner_dim, device: device, dtype: dtype) : Sequential(nn.Linear(dim, inner_dim, device: device, dtype: dtype), nn.GELU()); 334 | net = Sequential(project_in, Dropout(dropout), Linear(inner_dim, dim_ot, device: device, dtype: dtype)); 335 | RegisterComponents(); 336 | } 337 | 338 | public override Tensor forward(Tensor input) 339 | { 340 | return net.forward(input); 341 | } 342 | } 343 | 344 | internal class SDUnet : Module 345 | { 346 | private class UNet : Module 347 | { 348 | private readonly int ch; 349 | private readonly int time_embed_dim; 350 | private readonly int in_channels; 351 | private readonly bool use_timestep; 352 | 353 | private readonly Sequential time_embed; 354 | private readonly ModuleList input_blocks; 355 | private readonly TimestepEmbedSequential middle_block; 356 | private readonly ModuleList output_blocks; 357 | private readonly Sequential @out; 358 | 359 | public UNet(int model_channels, int in_channels, int[]? channel_mult = null, int num_res_blocks = 2, int num_atten_blocks = 1, int context_dim = 768, int num_heads = 8, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(UNet)) 360 | { 361 | bool mask = false; 362 | channel_mult = channel_mult ?? 
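// channel_mult scales model_channels per resolution level; the default
// { 1, 2, 4, 4 } with model_channels = 320 gives the familiar SD 1.x
// U-Net widths of 320/640/1280/1280.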
new int[] { 1, 2, 4, 4 }; 363 | 364 | ch = model_channels; 365 | time_embed_dim = model_channels * 4; 366 | this.in_channels = in_channels; 367 | this.use_timestep = use_timestep; 368 | 369 | List input_block_channels = new List { model_channels }; 370 | 371 | if (use_timestep) 372 | { 373 | // timestep embedding 374 | time_embed = Sequential(new Module[] { Linear(model_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype) }); 375 | } 376 | 377 | // downsampling 378 | input_blocks = new ModuleList(); 379 | input_blocks.Add(new TimestepEmbedSequential(Conv2d(in_channels, ch, kernel_size: 3, padding: 1, device: device, dtype: dtype))); 380 | 381 | for (int i = 0; i < channel_mult.Length; i++) 382 | { 383 | int in_ch = model_channels * channel_mult[i > 0 ? i - 1 : i]; 384 | int out_ch = model_channels * channel_mult[i]; 385 | 386 | for (int j = 0; j < num_res_blocks; j++) 387 | { 388 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(in_ch, out_ch, dropout, time_embed_dim, device: device, dtype: dtype), i < channel_mult.Length - 1 ? new SpatialTransformer(out_ch, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype) : Identity())); 389 | input_block_channels.Add(in_ch); 390 | in_ch = out_ch; 391 | } 392 | if (i < channel_mult.Length - 1) 393 | { 394 | input_blocks.Add(new TimestepEmbedSequential(Sequential(("op", Conv2d(out_ch, out_ch, 3, stride: 2, padding: 1, device: device, dtype: dtype))))); 395 | input_block_channels.Add(out_ch); 396 | } 397 | } 398 | 399 | // middle block 400 | middle_block = new TimestepEmbedSequential(new ResnetBlock(time_embed_dim, time_embed_dim, dropout, time_embed_dim, device: device, dtype: dtype), new SpatialTransformer(time_embed_dim, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype), new ResnetBlock(1280, 1280, device: device, dtype: dtype)); 401 | 402 | // upsampling 403 | var reversed_mult = channel_mult.Reverse().ToList(); 404 | int prev_channels = time_embed_dim; 405 | output_blocks = new ModuleList(); 406 | for (int i = 0; i < reversed_mult.Count; i++) 407 | { 408 | int mult = reversed_mult[i]; 409 | int current_channels = model_channels * mult; 410 | int down_stage_index = channel_mult.Length - 1 - i; 411 | int skip_channels = model_channels * channel_mult[down_stage_index]; 412 | bool has_atten = i >= 1; 413 | 414 | for (int j = 0; j < num_res_blocks + 1; j++) 415 | { 416 | int current_skip = skip_channels; 417 | if (j == num_res_blocks && i < reversed_mult.Count - 1) 418 | { 419 | int next_down_stage_index = channel_mult.Length - 1 - (i + 1); 420 | current_skip = model_channels * channel_mult[next_down_stage_index]; 421 | } 422 | 423 | int input_channels = prev_channels + current_skip; 424 | bool has_upsample = j == num_res_blocks && i != reversed_mult.Count - 1; 425 | 426 | if (has_atten) 427 | { 428 | output_blocks.Add(new TimestepEmbedSequential( 429 | new ResnetBlock(input_channels, current_channels, dropout, time_embed_dim, device: device, dtype: dtype), 430 | new SpatialTransformer(current_channels, context_dim, num_heads, num_atten_blocks, dropout, device: device, dtype: dtype), 431 | has_upsample ? new Upsample(current_channels, device: device, dtype: dtype) : Identity())); 432 | } 433 | else 434 | { 435 | output_blocks.Add(new TimestepEmbedSequential( 436 | new ResnetBlock(input_channels, current_channels, dropout, time_embed_dim, device: device, dtype: dtype), 437 | has_upsample ? 
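// Each up-block consumes prev_channels plus the width of the skip tensor
// popped from the down path, which is why input_channels above is
// prev_channels + current_skip rather than a plain channel_mult product.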
new Upsample(current_channels, device: device, dtype: dtype) : Identity())); 438 | } 439 | 440 | prev_channels = current_channels; 441 | } 442 | } 443 | 444 | @out = Sequential(GroupNorm(32, model_channels, device: device, dtype: dtype), SiLU(), Conv2d(model_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype)); 445 | 446 | RegisterComponents(); 447 | 448 | } 449 | public override Tensor forward(Tensor x, Tensor context, Tensor time) 450 | { 451 | using (NewDisposeScope()) 452 | { 453 | time = time_embed.forward(time); 454 | 455 | List skip_connections = new List(); 456 | foreach (TimestepEmbedSequential layers in input_blocks) 457 | { 458 | x = layers.forward(x, context, time); 459 | skip_connections.Add(x); 460 | } 461 | x = middle_block.forward(x, context, time); 462 | foreach (TimestepEmbedSequential layers in output_blocks) 463 | { 464 | Tensor index = skip_connections.Last(); 465 | x = cat(new Tensor[] { x, index }, 1); 466 | skip_connections.RemoveAt(skip_connections.Count - 1); 467 | x = layers.forward(x, context, time); 468 | } 469 | 470 | x = @out.forward(x); 471 | return x.MoveToOuterDisposeScope(); 472 | } 473 | } 474 | } 475 | 476 | private class Model : Module 477 | { 478 | private readonly UNet diffusion_model; 479 | 480 | public Model(int model_channels, int in_channels, int num_heads = 8, int context_dim = 768, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet)) 481 | { 482 | diffusion_model = new UNet(model_channels, in_channels, context_dim: context_dim, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype); 483 | RegisterComponents(); 484 | } 485 | 486 | public override Tensor forward(Tensor latent, Tensor context, Tensor time) 487 | { 488 | return diffusion_model.forward(latent, context, time); 489 | } 490 | } 491 | 492 | private readonly Model model; 493 | 494 | public SDUnet(int model_channels, int in_channels, int num_heads = 8, int context_dim = 768, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet)) 495 | { 496 | model = new Model(model_channels, in_channels, context_dim: context_dim, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype); 497 | RegisterComponents(); 498 | } 499 | 500 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y) 501 | { 502 | Device device = model.parameters().First().device; 503 | ScalarType dtype = model.parameters().First().dtype; 504 | 505 | latent = latent.to(dtype, device); 506 | time = time.to(dtype, device); 507 | context = context.to(dtype, device); 508 | return model.forward(latent, context, time); 509 | } 510 | } 511 | 512 | internal class SDXLUnet : Module 513 | { 514 | private class UNet : Module 515 | { 516 | private readonly int ch; 517 | private readonly int time_embed_dim; 518 | private readonly int in_channels; 519 | private readonly bool use_timestep; 520 | 521 | private readonly Sequential time_embed; 522 | private readonly Sequential label_emb; 523 | private readonly ModuleList input_blocks; 524 | private readonly TimestepEmbedSequential middle_block; 525 | private readonly ModuleList output_blocks; 526 | private readonly Sequential @out; 527 | 528 | 529 | public UNet(int model_channels, int in_channels, int[]? 
channel_mult = null, int num_res_blocks = 2, int context_dim = 768, int adm_in_channels = 2816, int num_heads = 20, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet)) 530 | { 531 | channel_mult = channel_mult ?? new int[] { 1, 2, 4 }; 532 | 533 | ch = model_channels; 534 | time_embed_dim = model_channels * 4; 535 | this.in_channels = in_channels; 536 | this.use_timestep = use_timestep; 537 | 538 | bool useLinear = true; 539 | bool mask = false; 540 | 541 | List input_block_channels = new List { model_channels }; 542 | 543 | if (use_timestep) 544 | { 545 | int time_embed_dim = model_channels * 4; 546 | time_embed = Sequential(Linear(model_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype)); 547 | label_emb = Sequential(Sequential(Linear(adm_in_channels, time_embed_dim, device: device, dtype: dtype), SiLU(), Linear(time_embed_dim, time_embed_dim, device: device, dtype: dtype))); 548 | } 549 | 550 | // downsampling 551 | input_blocks = new ModuleList(); 552 | input_blocks.Add(new TimestepEmbedSequential(Conv2d(in_channels, ch, kernel_size: 3, padding: 1, device: device, dtype: dtype))); 553 | 554 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 320, device: device, dtype: dtype))); 555 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 320, device: device, dtype: dtype))); 556 | input_blocks.Add(new TimestepEmbedSequential(new Downsample(320, device: device, dtype: dtype))); 557 | 558 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(320, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype))); 559 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype))); 560 | input_blocks.Add(new TimestepEmbedSequential(new Downsample(640, device: device, dtype: dtype))); 561 | 562 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype))); 563 | input_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1280, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype))); 564 | 565 | // mid_block 566 | middle_block = new TimestepEmbedSequential(new ResnetBlock(1280, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype), new ResnetBlock(1280, 1280, device: device, dtype: dtype)); 567 | 568 | // upsampling 569 | output_blocks = new ModuleList(); 570 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(2560, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype))); 571 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(2560, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype))); 572 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1920, 1280, device: device, dtype: dtype), new SpatialTransformer(1280, 2048, num_heads, 10, 0, useLinear, device: device, dtype: dtype), new Upsample(1280, device: device, dtype: dtype))); 573 | 
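// The SDXL decoder is written out literally instead of being generated from
// channel_mult; input widths such as 2560 and 1920 are simply prev_channels
// plus the skip width (2560 = 1280 + 1280, 1920 = 1280 + 640).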
574 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1920, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype))); 575 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(1280, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype))); 576 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(960, 640, device: device, dtype: dtype), new SpatialTransformer(640, 2048, num_heads, 2, 0, useLinear, device: device, dtype: dtype), new Upsample(640, device: device, dtype: dtype))); 577 | 578 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(960, 320, device: device, dtype: dtype))); 579 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 320, device: device, dtype: dtype))); 580 | output_blocks.Add(new TimestepEmbedSequential(new ResnetBlock(640, 320, device: device, dtype: dtype))); 581 | 582 | @out = Sequential(GroupNorm(32, model_channels, device: device, dtype: dtype), SiLU(), Conv2d(model_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype)); 583 | 584 | RegisterComponents(); 585 | } 586 | 587 | public override Tensor forward(Tensor x, Tensor context, Tensor time, Tensor y) 588 | { 589 | using (NewDisposeScope()) 590 | { 591 | int dim = 512; 592 | Tensor embed = time_embed.forward(time); 593 | Tensor time_ids = tensor(new float[] { dim, dim, 0, 0, dim, dim }, embed.dtype, embed.device).repeat(new long[] { 2, 1 }); 594 | Tensor time_embeds = get_timestep_embedding(time_ids.flatten(), dim / 2, true, 0, 1); 595 | time_embeds = time_embeds.reshape(new long[] { 2, -1 }); 596 | y = cat(new Tensor[] { y, time_embeds }, dim: -1); 597 | Tensor label_embed = label_emb.forward(y.to(embed.dtype, embed.device)); 598 | embed = embed + label_embed; 599 | 600 | List skip_connections = new List(); 601 | foreach (TimestepEmbedSequential layers in input_blocks) 602 | { 603 | x = layers.forward(x, context, embed); 604 | skip_connections.Add(x); 605 | } 606 | x = middle_block.forward(x, context, embed); 607 | foreach (TimestepEmbedSequential layers in output_blocks) 608 | { 609 | Tensor index = skip_connections.Last(); 610 | x = cat(new Tensor[] { x, index }, 1); 611 | skip_connections.RemoveAt(skip_connections.Count - 1); 612 | x = layers.forward(x, context, embed); 613 | } 614 | 615 | x = @out.forward(x); 616 | return x.MoveToOuterDisposeScope(); 617 | } 618 | 619 | } 620 | } 621 | 622 | private class Model : Module 623 | { 624 | private UNet diffusion_model; 625 | public Model(int model_channels, int in_channels, int num_heads = 20, int context_dim = 2048, int adm_in_channels = 2816, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? 
dtype = null) : base(nameof(SDUnet)) 626 | { 627 | diffusion_model = new UNet(model_channels, in_channels, context_dim: context_dim, adm_in_channels: adm_in_channels, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype); 628 | RegisterComponents(); 629 | } 630 | 631 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y) 632 | { 633 | latent = diffusion_model.forward(latent, context, time, y); 634 | return latent; 635 | } 636 | } 637 | 638 | private readonly Model model; 639 | 640 | public SDXLUnet(int model_channels, int in_channels, int num_heads = 20, int context_dim = 2048, int adm_in_channels = 2816, float dropout = 0.0f, bool use_timestep = true, Device? device = null, ScalarType? dtype = null) : base(nameof(SDUnet)) 641 | { 642 | model = new Model(model_channels, in_channels, context_dim: context_dim, adm_in_channels: adm_in_channels, num_heads: num_heads, dropout: dropout, use_timestep: use_timestep, device: device, dtype: dtype); 643 | RegisterComponents(); 644 | } 645 | 646 | public override Tensor forward(Tensor latent, Tensor context, Tensor time, Tensor y) 647 | { 648 | Device device = model.parameters().First().device; 649 | ScalarType dtype = model.parameters().First().dtype; 650 | 651 | latent = latent.to(dtype, device); 652 | time = time.to(dtype, device); 653 | y = y.to(dtype, device); 654 | context = context.to(dtype, device); 655 | 656 | latent = model.forward(latent, context, time, y); 657 | return latent; 658 | } 659 | 660 | /// <summary> 661 | /// This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 662 | /// </summary> 663 | /// <param name="timesteps">a 1-D Tensor of N indices, one per batch element. These may be fractional.</param> 664 | /// <param name="embedding_dim">the dimension of the output.</param> 665 | /// <param name="flip_sin_to_cos">Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)</param> 666 | /// <param name="downscale_freq_shift">Controls the delta between frequencies between dimensions</param> 667 | /// <param name="scale">Scaling factor applied to the embeddings.</param> 668 | /// <param name="max_period">Controls the maximum frequency of the embeddings</param> 669 | /// <returns>torch.Tensor: an [N x dim] Tensor of positional embeddings.</returns>
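// A quick numeric sketch of the method below, assuming embedding_dim = 4 and
// the default downscale_freq_shift = 1: half_dim = 2, the frequencies are
// exp(-ln(10000) * [0, 1] / 1) = [1, 1e-4], so a timestep t maps to
// [sin(t), sin(1e-4 * t), cos(t), cos(1e-4 * t)], reordered cos-first when
// flip_sin_to_cos is true (as in the SDXL size/crop conditioning above).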
670 | private static Tensor get_timestep_embedding(Tensor timesteps, int embedding_dim, bool flip_sin_to_cos = false, float downscale_freq_shift = 1, float scale = 1, int max_period = 10000) 671 | { 672 | using (NewDisposeScope()) 673 | { 674 | if (timesteps.Dimensions != 1) 675 | { 676 | throw new ArgumentOutOfRangeException("Timesteps should be a 1d-array"); 677 | } 678 | int half_dim = embedding_dim / 2; 679 | Tensor exponent = -Math.Log(max_period) * torch.arange(start: 0, stop: half_dim, dtype: torch.float32, device: timesteps.device); 680 | exponent = exponent / (half_dim - downscale_freq_shift); 681 | Tensor emb = torch.exp(exponent); 682 | emb = timesteps[.., TensorIndex.None].@float() * emb[TensorIndex.None, ..]; 683 | 684 | // scale embeddings 685 | emb = scale * emb; 686 | 687 | // concat sine and cosine embeddings 688 | emb = torch.cat(new Tensor[] { torch.sin(emb), torch.cos(emb) }, dim: -1); 689 | 690 | // flip sine and cosine embeddings 691 | if (flip_sin_to_cos) 692 | { 693 | emb = torch.cat(new Tensor[] { emb[.., half_dim..], emb[.., ..half_dim] }, dim: -1); 694 | } 695 | 696 | // zero pad 697 | if (embedding_dim % 2 == 1) 698 | { 699 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)); 700 | } 701 | return emb.MoveToOuterDisposeScope(); 702 | } 703 | } 704 | 705 | } 706 | } 707 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/VAE.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using TorchSharp.Modules; 3 | using static TorchSharp.torch; 4 | using static TorchSharp.torch.nn; 5 | 6 | namespace StableDiffusionSharp.Modules 7 | { 8 | internal class VAE 9 | { 10 | private static GroupNorm Normalize(int in_channels, int num_groups = 32, float eps = 1e-6f, bool affine = true, Device? device = null, ScalarType? dtype = null) 11 | { 12 | return GroupNorm(num_groups: num_groups, num_channels: in_channels, eps: eps, affine: affine, device: device, dtype: dtype); 13 | } 14 | 15 | private class ResnetBlock : Module 16 | { 17 | private readonly int in_channels; 18 | private readonly int out_channels; 19 | private readonly GroupNorm norm1; 20 | private readonly Conv2d conv1; 21 | private readonly GroupNorm norm2; 22 | private readonly Conv2d conv2; 23 | private readonly Module nin_shortcut; 24 | private readonly SiLU swish; 25 | 26 | public ResnetBlock(int in_channels, int out_channels, Device? device = null, ScalarType? 
dtype = null) : base(nameof(ResnetBlock)) 27 | { 28 | this.in_channels = in_channels; 29 | this.out_channels = out_channels; 30 | norm1 = Normalize(in_channels, device: device, dtype: dtype); 31 | conv1 = Conv2d(in_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype); 32 | norm2 = Normalize(out_channels, device: device, dtype: dtype); 33 | conv2 = Conv2d(out_channels, out_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype); 34 | 35 | if (this.in_channels != this.out_channels) 36 | { 37 | nin_shortcut = Conv2d(in_channels: in_channels, out_channels: out_channels, kernel_size: 1, device: device, dtype: dtype); 38 | } 39 | else 40 | { 41 | nin_shortcut = Identity(); 42 | } 43 | 44 | swish = SiLU(inplace: true); 45 | RegisterComponents(); 46 | } 47 | 48 | public override Tensor forward(Tensor x) 49 | { 50 | Tensor hidden = x; 51 | hidden = norm1.forward(hidden); 52 | hidden = swish.forward(hidden); 53 | hidden = conv1.forward(hidden); 54 | hidden = norm2.forward(hidden); 55 | hidden = swish.forward(hidden); 56 | hidden = conv2.forward(hidden); 57 | if (in_channels != out_channels) 58 | { 59 | x = nin_shortcut.forward(x); 60 | } 61 | return x + hidden; 62 | } 63 | } 64 | 65 | private class AttnBlock : Module<Tensor, Tensor> 66 | { 67 | private readonly GroupNorm norm; 68 | private readonly Conv2d q; 69 | private readonly Conv2d k; 70 | private readonly Conv2d v; 71 | private readonly Conv2d proj_out; 72 | 73 | public AttnBlock(int in_channels, Device? device = null, ScalarType? dtype = null) : base(nameof(AttnBlock)) 74 | { 75 | norm = Normalize(in_channels, device: device, dtype: dtype); 76 | q = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype); 77 | k = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype); 78 | v = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype); 79 | proj_out = Conv2d(in_channels, in_channels, kernel_size: 1, device: device, dtype: dtype); 80 | RegisterComponents(); 81 | } 82 | 83 | public override Tensor forward(Tensor x) 84 | { 85 | using (NewDisposeScope()) 86 | { 87 | var hidden = norm.forward(x); 88 | var q = this.q.forward(hidden); 89 | var k = this.k.forward(hidden); 90 | var v = this.v.forward(hidden); 91 | 92 | var (b, c, h, w) = (q.size(0), q.size(1), q.size(2), q.size(3)); 93 | 94 | q = q.view(b, 1, h * w, c).contiguous(); 95 | k = k.view(b, 1, h * w, c).contiguous(); 96 | v = v.view(b, 1, h * w, c).contiguous(); 97 | 98 | hidden = functional.scaled_dot_product_attention(q, k, v); // scale_factor is dim ** -0.5 per default 99 | 100 | hidden = hidden.view(b, c, h, w).contiguous(); 101 | hidden = proj_out.forward(hidden); 102 | 103 | return (x + hidden).MoveToOuterDisposeScope(); 104 | } 105 | } 106 | 107 | } 108 | 109 | private class Downsample : Module<Tensor, Tensor> 110 | { 111 | private readonly Conv2d? conv; 112 | private readonly bool with_conv; 113 | 114 | public Downsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType?
dtype = null) : base(nameof(Downsample)) 115 | { 116 | this.with_conv = with_conv; 117 | if (with_conv) 118 | { 119 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, stride: 2, device: device, dtype: dtype); 120 | 121 | } 122 | RegisterComponents(); 123 | } 124 | 125 | public override Tensor forward(Tensor x) 126 | { 127 | if (with_conv && conv != null) 128 | { 129 | long[] pad = new long[] { 0, 1, 0, 1 }; 130 | x = functional.pad(x, pad, mode: PaddingModes.Constant, value: 0); 131 | x = conv.forward(x); 132 | } 133 | else 134 | { 135 | x = functional.avg_pool2d(x, kernel_size: 2, stride: 2); 136 | } 137 | return x; 138 | } 139 | } 140 | 141 | private class Upsample : Module 142 | { 143 | private readonly Conv2d? conv; 144 | private readonly bool with_conv; 145 | public Upsample(int in_channels, bool with_conv = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Upsample)) 146 | { 147 | this.with_conv = with_conv; 148 | if (with_conv) 149 | { 150 | conv = Conv2d(in_channels, in_channels, kernel_size: 3, padding: 1, device: device, dtype: dtype); 151 | } 152 | RegisterComponents(); 153 | } 154 | public override Tensor forward(Tensor x) 155 | { 156 | var output = functional.interpolate(x, scale_factor: new double[] { 2.0, 2.0 }, mode: InterpolationMode.Nearest); 157 | if (with_conv && conv != null) 158 | { 159 | output = conv.forward(output); 160 | } 161 | return output; 162 | } 163 | } 164 | 165 | private class VAEEncoder : Module 166 | { 167 | private readonly int num_resolutions; 168 | private readonly int num_res_blocks; 169 | private readonly Conv2d conv_in; 170 | private readonly List in_ch_mult; 171 | private readonly Sequential down; 172 | private readonly Sequential mid; 173 | private readonly GroupNorm norm_out; 174 | private readonly Conv2d conv_out; 175 | private readonly SiLU swish; 176 | private readonly int block_in; 177 | private readonly bool double_z; 178 | 179 | 180 | public VAEEncoder(int ch = 128, int[]? ch_mult = null, int num_res_blocks = 2, int in_channels = 3, int z_channels = 16, bool double_z = true, Device? device = null, ScalarType? 
dtype = null) : base(nameof(VAEEncoder)) 181 | { 182 | this.double_z = double_z; 183 | ch_mult ??= new int[] { 1, 2, 4, 4 }; 184 | num_resolutions = ch_mult.Length; 185 | this.num_res_blocks = num_res_blocks; 186 | 187 | // Input convolution 188 | conv_in = Conv2d(in_channels, ch, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype); 189 | 190 | // Downsampling layers 191 | in_ch_mult = new List { 1 }; 192 | in_ch_mult.AddRange(ch_mult); 193 | down = Sequential(); 194 | 195 | block_in = ch * in_ch_mult[0]; 196 | 197 | for (int i_level = 0; i_level < num_resolutions; i_level++) 198 | { 199 | var block = Sequential(); 200 | var attn = Sequential(); 201 | int block_out = ch * ch_mult[i_level]; 202 | block_in = ch * in_ch_mult[i_level]; 203 | for (int _ = 0; _ < num_res_blocks; _++) 204 | { 205 | block.append(new ResnetBlock(block_in, block_out, device: device, dtype: dtype)); 206 | block_in = block_out; 207 | } 208 | 209 | var d = Sequential( 210 | ("block", block), 211 | ("attn", attn)); 212 | 213 | if (i_level != num_resolutions - 1) 214 | { 215 | d.append("downsample", new Downsample(block_in, device: device, dtype: dtype)); 216 | } 217 | down.append(d); 218 | } 219 | 220 | // Middle layers 221 | mid = Sequential( 222 | ("block_1", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)), 223 | ("attn_1", new AttnBlock(block_in, device: device, dtype: dtype)), 224 | ("block_2", new ResnetBlock(block_in, block_in, device: device, dtype: dtype))); 225 | 226 | 227 | // Output layers 228 | norm_out = Normalize(block_in, device: device, dtype: dtype); 229 | conv_out = Conv2d(block_in, (double_z ? 2 : 1) * z_channels, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype); 230 | swish = SiLU(inplace: true); 231 | 232 | RegisterComponents(); 233 | } 234 | 235 | public override Tensor forward(Tensor x) 236 | { 237 | using var _ = NewDisposeScope(); 238 | 239 | // Downsampling 240 | var h = conv_in.forward(x); 241 | 242 | h = down.forward(h); 243 | 244 | // Middle layers 245 | h = mid.forward(h); 246 | 247 | // Output layers 248 | h = norm_out.forward(h); 249 | h = swish.forward(h); 250 | h = conv_out.forward(h); 251 | return h.MoveToOuterDisposeScope(); 252 | } 253 | } 254 | 255 | private class VAEDecoder : Module 256 | { 257 | private readonly int num_resolutions; 258 | private readonly int num_res_blocks; 259 | 260 | private readonly Conv2d conv_in; 261 | private readonly Sequential mid; 262 | 263 | private readonly Sequential up; 264 | 265 | private readonly GroupNorm norm_out; 266 | private readonly Conv2d conv_out; 267 | private readonly GELU swish; 268 | 269 | public VAEDecoder(int ch = 128, int out_ch = 3, int[]? ch_mult = null, int num_res_blocks = 2, int resolution = 256, int z_channels = 16, Device? device = null, ScalarType? 
dtype = null) : base(nameof(VAEDecoder)) 270 | { 271 | ch_mult ??= new int[] { 1, 2, 4, 4 }; 272 | num_resolutions = ch_mult.Length; 273 | this.num_res_blocks = num_res_blocks; 274 | int block_in = ch * ch_mult[num_resolutions - 1]; 275 | 276 | int curr_res = resolution / (int)Math.Pow(2, num_resolutions - 1); 277 | // z to block_in 278 | conv_in = Conv2d(z_channels, block_in, kernel_size: 3, padding: 1, device: device, dtype: dtype); 279 | 280 | // middle 281 | mid = Sequential( 282 | ("block_1", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)), 283 | ("attn_1", new AttnBlock(block_in, device: device, dtype: dtype)), 284 | ("block_2", new ResnetBlock(block_in, block_in, device: device, dtype: dtype)) 285 | ); 286 | 287 | // upsampling 288 | up = Sequential(); 289 | 290 | List list = new List(); 291 | for (int i_level = num_resolutions - 1; i_level >= 0; i_level--) 292 | { 293 | var block = Sequential(); 294 | 295 | int block_out = ch * ch_mult[i_level]; 296 | 297 | for (int i_block = 0; i_block < num_res_blocks + 1; i_block++) 298 | { 299 | block.append(new ResnetBlock(block_in, block_out, device: device, dtype: dtype)); 300 | block_in = block_out; 301 | } 302 | 303 | Sequential u = Sequential(("block", block)); 304 | 305 | if (i_level != 0) 306 | { 307 | u.append("upsample", new Upsample(block_in, device: device, dtype: dtype)); 308 | curr_res *= 2; 309 | } 310 | //this.up.append(u); 311 | list.Insert(0, u); 312 | } 313 | 314 | up = Sequential(list); 315 | 316 | // end 317 | norm_out = Normalize(block_in, device: device, dtype: dtype); 318 | conv_out = Conv2d(block_in, out_ch, kernel_size: 3, stride: 1, padding: 1, device: device, dtype: dtype); 319 | swish = GELU(inplace: true); 320 | RegisterComponents(); 321 | } 322 | 323 | public override Tensor forward(Tensor z) 324 | { 325 | // z to block_in 326 | Tensor hidden = conv_in.forward(z); 327 | 328 | // middle 329 | hidden = mid.forward(hidden); 330 | 331 | // upsampling 332 | foreach (Module md in up.children().Reverse()) 333 | { 334 | hidden = md.forward(hidden); 335 | } 336 | 337 | // end 338 | hidden = norm_out.forward(hidden); 339 | hidden = swish.forward(hidden); 340 | hidden = conv_out.forward(hidden); 341 | return hidden; 342 | } 343 | } 344 | 345 | internal class Decoder : Module 346 | { 347 | private Sequential first_stage_model; 348 | 349 | public Decoder(int embed_dim = 4, int z_channels = 4, Device? device = null, ScalarType? dtype = null) : base(nameof(Decoder)) 350 | { 351 | first_stage_model = Sequential(("post_quant_conv", Conv2d(embed_dim, z_channels, 1, device: device, dtype: dtype)), ("decoder", new VAEDecoder(z_channels: z_channels, device: device, dtype: dtype))); 352 | RegisterComponents(); 353 | } 354 | 355 | public override Tensor forward(Tensor latents) 356 | { 357 | Device device = first_stage_model.parameters().First().device; 358 | ScalarType dtype = first_stage_model.parameters().First().dtype; 359 | latents = latents.to(dtype, device); 360 | return first_stage_model.forward(latents); 361 | } 362 | } 363 | 364 | internal class Encoder : Module 365 | { 366 | private Sequential first_stage_model; 367 | public Encoder(int embed_dim = 4, int z_channels = 4, bool double_z = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Encoder)) 368 | { 369 | int factor = double_z ? 
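// double_z makes the encoder emit 2 * z_channels feature maps so the latent
// carries a mean half and a log-variance half; the 1x1 quant_conv below
// keeps that factor, and the actual sampling (mean + std * noise) happens
// in the caller.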
2 : 1; 370 | first_stage_model = Sequential(("encoder", new VAEEncoder(z_channels: z_channels, device: device, dtype: dtype)), ("quant_conv", Conv2d(factor * embed_dim, factor * z_channels, 1, device: device, dtype: dtype))); 371 | RegisterComponents(); 372 | } 373 | 374 | public override Tensor forward(Tensor input) 375 | { 376 | Device device = first_stage_model.parameters().First().device; 377 | ScalarType dtype = first_stage_model.parameters().First().dtype; 378 | input = input.to(dtype, device); 379 | return first_stage_model.forward(input); 380 | } 381 | } 382 | } 383 | } 384 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Modules/VAEApprox.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using TorchSharp.Modules; 3 | using static TorchSharp.torch; 4 | using static TorchSharp.torch.nn; 5 | 6 | namespace StableDiffusionSharp.Modules 7 | { 8 | internal class VAEApprox : Module<Tensor, Tensor> 9 | { 10 | private readonly Conv2d conv1; 11 | private readonly Conv2d conv2; 12 | private readonly Conv2d conv3; 13 | private readonly Conv2d conv4; 14 | private readonly Conv2d conv5; 15 | private readonly Conv2d conv6; 16 | private readonly Conv2d conv7; 17 | private readonly Conv2d conv8; 18 | 19 | internal VAEApprox(int latent_channels = 4, Device? device = null, ScalarType? dtype = null) : base(nameof(VAEApprox)) 20 | { 21 | string vaeSD15ApproxPath = @".\models\vaeapprox\vaeapp_sd15.pth"; 22 | string vaeSDXLApproxPath = @".\models\vaeapprox\xlvaeapp.pth"; 23 | string path = Path.GetDirectoryName(vaeSD15ApproxPath)!; 24 | if (!Directory.Exists(path)) 25 | { 26 | Directory.CreateDirectory(path); 27 | } 28 | Assembly _assembly = Assembly.GetExecutingAssembly(); 29 | if (!File.Exists(vaeSD15ApproxPath)) // check the SD 1.5 file here, not the SDXL one 30 | { 31 | string sd15ResourceName = "StableDiffusionSharp.Models.VAEApprox.vaeapp_sd15.pth"; 32 | using (Stream stream = _assembly.GetManifestResourceStream(sd15ResourceName)!) 33 | { 34 | if (stream == null) 35 | { 36 | Console.WriteLine("Resource could not be found!"); 37 | return; 38 | } 39 | using (FileStream fileStream = new FileStream(vaeSD15ApproxPath, FileMode.Create, FileAccess.Write)) 40 | { 41 | stream.CopyTo(fileStream); 42 | } 43 | } 44 | } 45 | if (!File.Exists(vaeSDXLApproxPath)) 46 | { 47 | string sdxlResourceName = "StableDiffusionSharp.Models.VAEApprox.xlvaeapp.pth"; 48 | using (Stream stream = _assembly.GetManifestResourceStream(sdxlResourceName)!)
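// Same embedded-resource bootstrap as the SD 1.5 branch above, this time
// for the SDXL approximation weights. These tiny conv nets only drive the
// rough per-step previews; the full VAE still performs the final decode.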
49 | { 50 | if (stream == null) 51 | { 52 | Console.WriteLine("Resource could not be found!"); 53 | return; 54 | } 55 | using (FileStream fileStream = new FileStream(vaeSDXLApproxPath, FileMode.Create, FileAccess.Write)) 56 | { 57 | stream.CopyTo(fileStream); 58 | } 59 | } 60 | } 61 | 62 | conv1 = Conv2d(latent_channels, 8, (7, 7), device: device, dtype: dtype); 63 | conv2 = Conv2d(8, 16, (5, 5), device: device, dtype: dtype); 64 | conv3 = Conv2d(16, 32, (3, 3), device: device, dtype: dtype); 65 | conv4 = Conv2d(32, 64, (3, 3), device: device, dtype: dtype); 66 | conv5 = Conv2d(64, 32, (3, 3), device: device, dtype: dtype); 67 | conv6 = Conv2d(32, 16, (3, 3), device: device, dtype: dtype); 68 | conv7 = Conv2d(16, 8, (3, 3), device: device, dtype: dtype); 69 | conv8 = Conv2d(8, 3, (3, 3), device: device, dtype: dtype); 70 | RegisterComponents(); 71 | } 72 | 73 | public override Tensor forward(Tensor x) 74 | { 75 | using (NewDisposeScope()) 76 | { 77 | int extra = 11; 78 | x = functional.interpolate(x, new long[] { x.shape[2] * 2, x.shape[3] * 2 }); 79 | x = functional.pad(x, (extra, extra, extra, extra)); 80 | 81 | foreach (var layer in ModuleList(new Conv2d[] { conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8 })) 82 | { 83 | x = layer.forward(x); 84 | x = functional.leaky_relu(x, 0.1); 85 | } 86 | return x.MoveToOuterDisposeScope(); 87 | } 88 | } 89 | 90 | public enum SharedModel 91 | { 92 | SD3, 93 | SDXL, 94 | SD 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /StableDiffusionSharp/SDType.cs: -------------------------------------------------------------------------------- 1 | namespace StableDiffusionSharp 2 | { 3 | public enum SDScalarType 4 | { 5 | Float16 = 5, 6 | Float32 = 6, 7 | BFloat16 = 15, 8 | } 9 | 10 | public enum SDDeviceType 11 | { 12 | CPU = 0, 13 | CUDA = 1, 14 | } 15 | 16 | public enum SDSamplerType 17 | { 18 | EulerAncestral = 0, 19 | Euler = 1, 20 | } 21 | 22 | public enum ModelType 23 | { 24 | SD1, 25 | SD2, 26 | SD3, 27 | SDXL, 28 | FLUX, 29 | } 30 | 31 | public enum TimestepSpacing 32 | { 33 | Linspace, 34 | Leading, 35 | Trailing, 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Sampler/BasicSampler.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionSharp.Sampler 5 | { 6 | public abstract class BasicSampler 7 | { 8 | public Tensor Sigmas; 9 | internal Tensor Timesteps; 10 | private Scheduler.DiscreteSchedule schedule; 11 | private readonly TimestepSpacing timestepSpacing; 12 | 13 | public BasicSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1, TimestepSpacing timestepSpacing = TimestepSpacing.Leading) 14 | { 15 | this.timestepSpacing = timestepSpacing; 16 | Tensor betas = GetBetaSchedule(beta_start, beta_end, num_train_timesteps); 17 | Tensor alphas = 1.0f - betas; 18 | Tensor alphas_cumprod = torch.cumprod(alphas, 0); 19 | this.Sigmas = torch.pow((1.0f - alphas_cumprod) / alphas_cumprod, 0.5f); 20 | } 21 | 22 | public Tensor InitNoiseSigma() 23 | { 24 | if (timestepSpacing == TimestepSpacing.Linspace || timestepSpacing == TimestepSpacing.Trailing) 25 | { 26 | return Sigmas.max(); 27 | } 28 | return torch.sqrt(torch.pow(Sigmas.max(), 2) + 1); 29 | } 30 | 31 | public Tensor ScaleModelInput(Tensor sample, int step_index) 32 | { 33 | Tensor sigma =
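// Sigmas come from the k-diffusion parameterization sqrt((1 - a_bar) / a_bar)
// built in the constructor; dividing the sample by sqrt(sigma^2 + 1) rescales
// the noised input to roughly unit variance before it enters the U-Net.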
Sigmas[step_index]; 34 | return sample / torch.sqrt(torch.pow(sigma, 2) + 1); 35 | } 36 | 37 | /// <summary> 38 | /// Get the scalings for the given step index 39 | /// </summary> 40 | /// <param name="step_index"></param> 41 | /// <returns>Tensor c_out, Tensor c_in</returns> 42 | public (Tensor, Tensor) GetScalings(int step_index) 43 | { 44 | Tensor sigma = Sigmas[step_index]; 45 | Tensor c_out = -sigma; 46 | Tensor c_in = 1 / torch.sqrt(torch.pow(sigma, 2) + 1); 47 | return (c_out, c_in); 48 | } 49 | public Tensor append_dims(Tensor x, long target_dims) 50 | { 51 | long dims_to_append = target_dims - x.ndim; 52 | if (dims_to_append < 0) 53 | { 54 | throw new ArgumentException("target_dims must not be less than x.ndim"); 55 | } 56 | List<long> dims = x.shape.ToList(); 57 | for (int i = 0; i < dims_to_append; i++) 58 | { 59 | dims.Add(1); // Append() on the array was a no-op (LINQ returns a new sequence); Add() actually extends the shape 60 | } 61 | return x.view(dims.ToArray()); 62 | } 63 | 64 | 65 | 66 | public void SetTimesteps(long num_inference_steps) 67 | { 68 | if (num_inference_steps < 1) 69 | { 70 | throw new ArgumentException("num_inference_steps must be greater than 0"); 71 | } 72 | //long t_max = Sigmas.NumberOfElements - 1; 73 | //this.Timesteps = torch.linspace(t_max, 0, num_inference_steps); 74 | this.Timesteps = GetTimeSteps(Sigmas.NumberOfElements, num_inference_steps, timestepSpacing); 75 | schedule = new Scheduler.DiscreteSchedule(Sigmas); 76 | this.Sigmas = append_zero(schedule.t_to_sigma(this.Timesteps)); 77 | } 78 | 79 | private Tensor GetTimeSteps(double t_max, long num_steps, TimestepSpacing timestepSpacing) 80 | { 81 | if (timestepSpacing == TimestepSpacing.Linspace) 82 | { 83 | return torch.linspace(t_max - 1, 0, num_steps); 84 | } 85 | else if (timestepSpacing == TimestepSpacing.Leading) 86 | { 87 | long step_ratio = (long)t_max / num_steps; 88 | return torch.linspace(t_max - step_ratio, 0, num_steps) + 1; 89 | } 90 | else 91 | { 92 | long step_ratio = (long)t_max / num_steps; 93 | return torch.arange(t_max, 0, -step_ratio).round() - 1; 94 | } 95 | } 96 | 97 | 98 | public virtual Tensor Step(Tensor model_output, int step_index, Tensor sample, long seed = 0, float s_churn = 0.0f, float s_tmin = 0.0f, float s_tmax = float.PositiveInfinity, float s_noise = 1.0f) 99 | { 100 | // It is the same as EulerSampler 101 | sample = sample.to(model_output.dtype, model_output.device); 102 | Generator generator = torch.manual_seed(seed); 103 | torch.set_rng_state(generator.get_state()); 104 | float sigma = Sigmas[step_index].ToSingle(); 105 | float gamma = s_tmin <= sigma && sigma <= s_tmax ?
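// Karras-style Euler step: a positive gamma ("churn") would inflate sigma to
// sigma_hat and add matching noise, but with the default s_churn = 0 it stays
// zero and the update reduces to plain Euler:
//     x_next = x + (x - denoised) / sigma_hat * (sigma_next - sigma_hat);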
(float)Math.Min(s_churn / (Sigmas.NumberOfElements - 1f), Math.Sqrt(2.0f) - 1.0f) : 0f; 106 | Tensor epsilon = torch.randn_like(model_output) * s_noise; 107 | float sigma_hat = sigma * (gamma + 1); 108 | if (gamma > 0) 109 | { 110 | sample = sample + epsilon * (float)Math.Sqrt(Math.Pow(sigma_hat, 2f) - Math.Pow(sigma, 2f)); 111 | } 112 | Tensor pred_original_sample = sample - sigma_hat * model_output; // to_d and sigma is c_out 113 | Tensor derivative = (sample - pred_original_sample) / sigma_hat; 114 | float dt = Sigmas[step_index + 1].ToSingle() - sigma_hat; 115 | return sample + derivative * dt; 116 | } 117 | 118 | private Tensor GetBetaSchedule(float beta_start, float beta_end, int num_train_timesteps) 119 | { 120 | return torch.pow(torch.linspace(Math.Pow(beta_start, 0.5), Math.Pow(beta_end, 0.5), num_train_timesteps, ScalarType.Float32), 2); 121 | } 122 | 123 | private static Tensor append_zero(Tensor x) 124 | { 125 | return torch.cat(new Tensor[] { x, x.new_zeros(1) }); 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /StableDiffusionSharp/Sampler/EulerAncestralSampler.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionSharp.Sampler 5 | { 6 | internal class EulerAncestralSampler : BasicSampler 7 | { 8 | public EulerAncestralSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1) : base(num_train_timesteps, beta_start, beta_end, steps_offset) 9 | { 10 | 11 | } 12 | public override torch.Tensor Step(torch.Tensor model_output, int step_index, torch.Tensor sample, long seed = 0, float s_churn = 0, float s_tmin = 0, float s_tmax = float.PositiveInfinity, float s_noise = 1) 13 | { 14 | sample = sample.to(model_output.dtype, model_output.device); 15 | Generator generator = torch.manual_seed(seed); 16 | torch.set_rng_state(generator.get_state()); 17 | 18 | float sigma = base.Sigmas[step_index].ToSingle(); 19 | 20 | Tensor predOriginalSample = sample - model_output * sigma; 21 | Tensor sigmaFrom = base.Sigmas[step_index]; 22 | Tensor sigmaTo = base.Sigmas[step_index + 1]; 23 | Tensor sigmaFromLessSigmaTo = torch.pow(sigmaFrom, 2) - torch.pow(sigmaTo, 2); 24 | Tensor sigmaUpResult = torch.pow(sigmaTo, 2) * sigmaFromLessSigmaTo / torch.pow(sigmaFrom, 2); 25 | 26 | Tensor sigmaUp = sigmaUpResult.ToSingle() < 0 ? -torch.pow(torch.abs(sigmaUpResult), 0.5f) : torch.pow(sigmaUpResult, 0.5f); 27 | Tensor sigmaDownResult = torch.pow(sigmaTo, 2) - torch.pow(sigmaUp, 2); 28 | Tensor sigmaDown = sigmaDownResult.ToSingle() < 0 ? 
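// Ancestral split of the step: sigma_down drives the deterministic Euler move
// and fresh noise scaled by sigma_up is added afterwards, so that
// sigma_down^2 + sigma_up^2 = sigma_to^2; the sign guards here only absorb
// float round-off that can push the radicand slightly below zero.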
--------------------------------------------------------------------------------
/StableDiffusionSharp/Sampler/EulerAncestralSampler.cs:
--------------------------------------------------------------------------------
using TorchSharp;
using static TorchSharp.torch;

namespace StableDiffusionSharp.Sampler
{
	internal class EulerAncestralSampler : BasicSampler
	{
		public EulerAncestralSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1) : base(num_train_timesteps, beta_start, beta_end, steps_offset)
		{
		}

		public override torch.Tensor Step(torch.Tensor model_output, int step_index, torch.Tensor sample, long seed = 0, float s_churn = 0, float s_tmin = 0, float s_tmax = float.PositiveInfinity, float s_noise = 1)
		{
			sample = sample.to(model_output.dtype, model_output.device);
			Generator generator = torch.manual_seed(seed);
			torch.set_rng_state(generator.get_state());

			float sigma = base.Sigmas[step_index].ToSingle();

			Tensor predOriginalSample = sample - model_output * sigma;
			Tensor sigmaFrom = base.Sigmas[step_index];
			Tensor sigmaTo = base.Sigmas[step_index + 1];
			Tensor sigmaFromLessSigmaTo = torch.pow(sigmaFrom, 2) - torch.pow(sigmaTo, 2);
			Tensor sigmaUpResult = torch.pow(sigmaTo, 2) * sigmaFromLessSigmaTo / torch.pow(sigmaFrom, 2);
			// Signed square roots guard against small negative intermediate values
			// caused by floating-point rounding near the end of the schedule.
			Tensor sigmaUp = sigmaUpResult.ToSingle() < 0 ? -torch.pow(torch.abs(sigmaUpResult), 0.5f) : torch.pow(sigmaUpResult, 0.5f);
			Tensor sigmaDownResult = torch.pow(sigmaTo, 2) - torch.pow(sigmaUp, 2);
			Tensor sigmaDown = sigmaDownResult.ToSingle() < 0 ? -torch.pow(torch.abs(sigmaDownResult), 0.5f) : torch.pow(sigmaDownResult, 0.5f);
			Tensor derivative = (sample - predOriginalSample) / sigma; // to_d, with sigma as c_out
			Tensor delta = sigmaDown - sigma;
			// Deterministic move down to sigmaDown, then fresh noise scaled by sigmaUp.
			Tensor prevSample = sample + derivative * delta;
			var noise = torch.randn_like(prevSample);
			prevSample = prevSample + noise * sigmaUp;
			return prevSample;
		}
	}
}
--------------------------------------------------------------------------------
/StableDiffusionSharp/Sampler/EulerSampler.cs:
--------------------------------------------------------------------------------
using TorchSharp;
using static TorchSharp.torch;

namespace StableDiffusionSharp.Sampler
{
	internal class EulerSampler : BasicSampler
	{
		public EulerSampler(int num_train_timesteps = 1000, float beta_start = 0.00085f, float beta_end = 0.012f, int steps_offset = 1) : base(num_train_timesteps, beta_start, beta_end, steps_offset)
		{
		}

		public override torch.Tensor Step(torch.Tensor model_output, int step_index, torch.Tensor sample, long seed = 0, float s_churn = 0, float s_tmin = 0, float s_tmax = float.PositiveInfinity, float s_noise = 1)
		{
			sample = sample.to(model_output.dtype, model_output.device);
			Generator generator = torch.manual_seed(seed);
			torch.set_rng_state(generator.get_state());
			float sigma = base.Sigmas[step_index].ToSingle();
			float gamma = s_tmin <= sigma && sigma <= s_tmax ? (float)Math.Min(s_churn / (Sigmas.NumberOfElements - 1f), Math.Sqrt(2.0f) - 1.0f) : 0f;
			Tensor noise = torch.randn_like(model_output);
			Tensor epsilon = noise * s_noise;
			float sigma_hat = sigma * (gamma + 1.0f);
			if (gamma > 0)
			{
				// Churn: raise the noise level from sigma to sigma_hat before stepping.
				sample = sample + epsilon * (float)Math.Sqrt(Math.Pow(sigma_hat, 2f) - Math.Pow(sigma, 2f));
			}
			Tensor pred_original_sample = sample - sigma_hat * model_output; // to_d, with sigma as c_out
			Tensor derivative = (sample - pred_original_sample) / sigma_hat;
			Tensor dt = Sigmas[step_index + 1] - sigma_hat;
			return sample + derivative * dt;
		}
	}
}
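
The ancestral step above splits the move from sigmaFrom down to sigmaTo into a deterministic part (sigmaDown) and freshly injected noise (sigmaUp), chosen so that sigmaDown² + sigmaUp² = sigmaTo². A quick scalar check of that identity, using plain floats outside the samplers:

float sigmaFrom = 8.0f, sigmaTo = 5.0f;
float sigmaUp = MathF.Sqrt(sigmaTo * sigmaTo * (sigmaFrom * sigmaFrom - sigmaTo * sigmaTo) / (sigmaFrom * sigmaFrom));
float sigmaDown = MathF.Sqrt(sigmaTo * sigmaTo - sigmaUp * sigmaUp);
Console.WriteLine($"sigma_up = {sigmaUp:F4}, sigma_down = {sigmaDown:F4}");          // 3.9031, 3.1250
Console.WriteLine($"variance check: {sigmaDown * sigmaDown + sigmaUp * sigmaUp:F4}"); // 25.0000 = sigmaTo^2
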
--------------------------------------------------------------------------------
/StableDiffusionSharp/Scheduler/DiscreteSchedule.cs:
--------------------------------------------------------------------------------
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace StableDiffusionSharp.Scheduler
{
	internal class DiscreteSchedule : Module
	{
		private Tensor sigmas;
		private Tensor log_sigmas;
		private bool quantize;

		public DiscreteSchedule(Tensor sigmas, bool quantize = false) : base(nameof(DiscreteSchedule))
		{
			this.sigmas = sigmas;
			log_sigmas = sigmas.log();
			this.quantize = quantize;
			RegisterComponents();
		}

		public Tensor sigma_min => sigmas.min();
		public Tensor sigma_max => sigmas.max();

		/// <summary>
		/// Interpolate a (possibly fractional) timestep back to a sigma in log-space.
		/// </summary>
		public Tensor t_to_sigma(Tensor t)
		{
			t = t.@float();
			Tensor low_idx = t.floor().@long();
			Tensor high_idx = t.ceil().@long();
			Tensor w = t.frac();
			Tensor log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
			return log_sigma.exp();
		}

		public Tensor sigma_to_t(Tensor sigma, bool? quantize = null)
		{
			quantize = quantize ?? this.quantize;
			Tensor log_sigma = sigma.log();
			Tensor dists = log_sigma - log_sigmas[.., TensorIndex.None];

			if (quantize == true)
			{
				// Snap to the nearest discrete training timestep.
				return dists.abs().argmin(dim: 0).view(sigma.shape);
			}

			Tensor low_idx = dists.ge(0).cumsum(dim: 0).argmax(dim: 0).clamp(max: log_sigmas.shape[0] - 2);
			Tensor high_idx = low_idx + 1;
			var (low, high) = (log_sigmas[low_idx], log_sigmas[high_idx]);
			Tensor w = (low - log_sigma) / (low - high);
			w = w.clamp(0, 1);
			Tensor t = (1 - w) * low_idx + w * high_idx;
			return t.view(sigma.shape);
		}
	}
}
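
DiscreteSchedule treats the training sigmas as a lookup table and interpolates between adjacent entries in log-space, so sigma_to_t and t_to_sigma are approximate inverses of each other. A small round-trip check with a toy ascending sigma table (the class is internal, so a check like this would live inside the library):

Tensor sigmas = torch.linspace(0.1, 14.6, 1000);   // toy training sigma table
DiscreteSchedule schedule = new DiscreteSchedule(sigmas);
Tensor sigma = torch.tensor(2.5f);
Tensor t = schedule.sigma_to_t(sigma);             // fractional timestep
Tensor roundTrip = schedule.t_to_sigma(t);         // recovers ~2.5
Console.WriteLine($"t = {t.ToSingle():F3}, sigma back = {roundTrip.ToSingle():F4}");
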
--------------------------------------------------------------------------------
/StableDiffusionSharp/StableDiffusion.cs:
--------------------------------------------------------------------------------
using StableDiffusionSharp.Modules;
using TorchSharp;
using static TorchSharp.torch;

namespace StableDiffusionSharp
{
	public class StableDiffusion : nn.Module
	{
		private SDModel model;
		private readonly Device device;
		private readonly ScalarType dtype;

		public class StepEventArgs : EventArgs
		{
			public int CurrentStep { get; }
			public int TotalSteps { get; }

			public ImageMagick.MagickImage VaeApproxImg { get; }

			public StepEventArgs(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg)
			{
				CurrentStep = currentStep;
				TotalSteps = totalSteps;
				VaeApproxImg = vaeApproxImg;
			}
		}

		public event EventHandler<StepEventArgs> StepProgress;

		protected void OnStepProgress(int currentStep, int totalSteps, ImageMagick.MagickImage vaeApproxImg)
		{
			StepProgress?.Invoke(this, new StepEventArgs(currentStep, totalSteps, vaeApproxImg));
		}

		public StableDiffusion(SDDeviceType deviceType, SDScalarType scaleType) : base(nameof(StableDiffusion))
		{
			this.device = new Device((DeviceType)deviceType);
			this.dtype = (ScalarType)scaleType;
		}

		public void LoadModel(string modelPath, string vaeModelPath = "", string vocabPath = @".\models\clip\vocab.json", string mergesPath = @".\models\clip\merges.txt")
		{
			ModelType modelType = ModelLoader.ModelLoader.GetModelType(modelPath);
			Console.WriteLine($"Detected model type: {modelType}");
			model = modelType switch
			{
				ModelType.SD1 => new SD1(this.device, this.dtype),
				ModelType.SDXL => new SDXL(this.device, this.dtype),
				_ => throw new ArgumentException("Invalid model type")
			};
			model.LoadModel(modelPath, vaeModelPath, vocabPath, mergesPath);
			model.StepProgress += Model_StepProgress;
		}

		private void Model_StepProgress(object? sender, SDModel.StepEventArgs e)
		{
			OnStepProgress(e.CurrentStep, e.TotalSteps, e.VAEApproxImg);
		}

		public ImageMagick.MagickImage TextToImage(string prompt, string nprompt = "", long clip_skip = 0, int width = 512, int height = 512, int steps = 20, long seed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
		{
			return model.TextToImage(prompt, nprompt, clip_skip, width, height, steps, seed, cfg, samplerType);
		}

		public ImageMagick.MagickImage ImageToImage(ImageMagick.MagickImage orgImage, string prompt, string nprompt = "", long clip_skip = 0, int steps = 20, float strength = 0.75f, long seed = 0, long subSeed = 0, float cfg = 7.0f, SDSamplerType samplerType = SDSamplerType.Euler)
		{
			return model.ImageToImage(orgImage, prompt, nprompt, clip_skip, steps, strength, seed, subSeed, cfg, samplerType);
		}
	}
}
--------------------------------------------------------------------------------
/StableDiffusionSharp/StableDiffusionSharp.csproj:
--------------------------------------------------------------------------------
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net6.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <AssemblyName>StableDiffusionSharp</AssemblyName>
    <Authors>IntptrMax</Authors>
    <Description>Use Stable Diffusion with C# with fast speed and less VRAM. Requires a reference to one of libtorch-cpu, libtorch-cuda-12.1, libtorch-cuda-12.1-win-x64 or libtorch-cuda-12.1-linux-x64 version 2.5.1.0 to execute.</Description>
    <PackageProjectUrl>https://github.com/IntptrMax/StableDiffusionSharp</PackageProjectUrl>
    <PackageLicenseFile>LICENSE.txt</PackageLicenseFile>
    <Version>1.0.8</Version>
    <PackageReadmeFile>README.md</PackageReadmeFile>
  </PropertyGroup>

  <!-- The ItemGroup entries lost their XML tags in this dump; only the values survive
       (CopyToOutputDirectory=Never, Pack=True, PackagePath=\), i.e. the usual packaging
       entries for LICENSE.txt, README.md and the bundled Models files, plus package
       references to the libraries named in the Description above. -->

</Project>
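
The public surface is small: construct, LoadModel, then TextToImage or ImageToImage, with StepProgress reporting per-step previews. A minimal console sketch; the model path and the SDDeviceType/SDScalarType member names are assumptions, since those enums are defined elsewhere in the repo:

StableDiffusion sd = new StableDiffusion(SDDeviceType.CUDA, SDScalarType.Float16); // enum member names assumed
sd.StepProgress += (sender, e) => Console.WriteLine($"step {e.CurrentStep}/{e.TotalSteps}");
sd.LoadModel(@".\models\v1-5-pruned-emaonly.safetensors"); // hypothetical model path
using ImageMagick.MagickImage image = sd.TextToImage(
	prompt: "a lighthouse on a cliff at sunset, highly detailed",
	nprompt: "lowres, blurry",
	steps: 20, seed: 42, cfg: 7.0f, samplerType: SDSamplerType.Euler);
image.Write("output.png");
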
--------------------------------------------------------------------------------
/StableDiffusionSharp/Tools.cs:
--------------------------------------------------------------------------------
using ImageMagick;
using System.IO.Compression;
using System.Text;
using TorchSharp;
using static TorchSharp.torch;

namespace StableDiffusionSharp
{
	internal class Tools
	{
		internal static Tensor GetTensorFromImage(MagickImage image)
		{
			using (MemoryStream memoryStream = new MemoryStream())
			{
				image.Write(memoryStream, MagickFormat.Png);
				memoryStream.Position = 0;
				return torchvision.io.read_image(memoryStream);
			}
		}

		public static MagickImage GetImageFromTensor(Tensor tensor)
		{
			MemoryStream memoryStream = new MemoryStream();
			torchvision.io.write_png(tensor.cpu(), memoryStream);
			memoryStream.Position = 0;
			return new MagickImage(memoryStream, MagickFormat.Png);
		}

		/// <summary>
		/// Load a Python .pt tensor file and move it to the same dtype and device as the given tensor.
		/// </summary>
		/// <param name="path">tensor path</param>
		/// <param name="tensor">the given tensor</param>
		/// <returns>Tensor in TorchSharp</returns>
		public static Tensor LoadTensorFromPT(string path, Tensor tensor)
		{
			return LoadTensorFromPT(path).to(tensor.dtype, tensor.device);
		}

		/// <summary>
		/// Load a Python .pt tensor file.
		/// </summary>
		/// <param name="path">tensor path</param>
		/// <returns>Tensor in TorchSharp</returns>
		public static Tensor LoadTensorFromPT(string path)
		{
			torch.ScalarType dtype = torch.ScalarType.Float32;
			List<long> shape = new List<long>();
			using ZipArchive zip = ZipFile.OpenRead(path);
			ZipArchiveEntry headerEntry = zip.Entries.First(e => e.Name == "data.pkl");

			// The header is always small enough to buffer completely; CopyTo also avoids
			// the short reads that Stream.Read is allowed to return.
			byte[] headerBytes;
			using (Stream headerStream = headerEntry.Open())
			using (MemoryStream headerBuffer = new MemoryStream())
			{
				headerStream.CopyTo(headerBuffer);
				headerBytes = headerBuffer.ToArray();
			}

			string headerStr = Encoding.Default.GetString(headerBytes);
			if (headerStr.Contains("HalfStorage"))
			{
				dtype = torch.ScalarType.Float16;
			}
			else if (headerStr.Contains("BFloat"))
			{
				dtype = torch.ScalarType.BFloat16;
			}
			else if (headerStr.Contains("FloatStorage"))
			{
				dtype = torch.ScalarType.Float32;
			}

			// Scan the raw pickle opcodes for the shape tuple: the byte sequence
			// 'Q' (0x51) 'K' 0 marks the end of the storage tuple, after which the
			// dimensions follow as BININT1 ('K', one byte) or BININT2 ('M', two bytes
			// little-endian) until the next BINPUT ('q').
			for (int i = 0; i < headerBytes.Length; i++)
			{
				if (headerBytes[i] == 81 && headerBytes[i + 1] == 75 && headerBytes[i + 2] == 0)
				{
					for (int j = i + 2; j < headerBytes.Length; j++)
					{
						if (headerBytes[j] == 75)
						{
							shape.Add(headerBytes[j + 1]);
							j++;
						}
						else if (headerBytes[j] == 77)
						{
							shape.Add(headerBytes[j + 1] + headerBytes[j + 2] * 256);
							j += 2;
						}
						else if (headerBytes[j] == 113)
						{
							break;
						}
					}
					break;
				}
			}

			Tensor tensor = torch.zeros(shape.ToArray(), dtype: dtype);
			ZipArchiveEntry dataEntry = zip.Entries.First(e => e.Name == "0");

			using (Stream dataStream = dataEntry.Open())
			using (MemoryStream dataBuffer = new MemoryStream())
			{
				dataStream.CopyTo(dataBuffer);
				tensor.bytes = dataBuffer.ToArray();
			}
			return tensor;
		}

		public static long GetFreeVRAM()
		{
			if (!cuda.is_available())
			{
				return 0;
			}
			else
			{
				using (var factory = new SharpDX.DXGI.Factory1())
				{
					var adapter = factory.Adapters[0];
					using (var adapter3 = adapter.QueryInterfaceOrNull<SharpDX.DXGI.Adapter3>())
					{
						if (adapter3 == null)
						{
							throw new ArgumentException($"Adapter {adapter.Description.Description} is not supported");
						}
						var memoryInfo = adapter3.QueryVideoMemoryInfo(0, SharpDX.DXGI.MemorySegmentGroup.Local);
						long totalVRAM = adapter.Description.DedicatedVideoMemory;
						long usedVRAM = memoryInfo.CurrentUsage;
						long freeVRAM = memoryInfo.Budget - usedVRAM;
						return freeVRAM;
					}
				}
			}
		}
	}
}
--------------------------------------------------------------------------------
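
Because LoadTensorFromPT only recovers dtype, shape and raw bytes, the quickest sanity check is a round trip against a tensor saved from Python with torch.save(torch.randn(2, 3).half(), "test.pt"). A sketch, inside the library since Tools is internal; the file path is the obvious assumption:

Tensor t = Tools.LoadTensorFromPT(@".\test.pt"); // hypothetical file written by Python's torch.save
Console.WriteLine($"dtype = {t.dtype}, shape = [{string.Join(", ", t.shape)}]"); // expect Float16, [2, 3]
Console.WriteLine($"free VRAM: {Tools.GetFreeVRAM() / (1024 * 1024)} MB");       // 0 when CUDA is unavailable
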