├── .gitignore ├── LICENSE ├── README.md ├── RELEASENOTES.md ├── TorchSharp.PyBridge.Tests ├── SaveUtils.cs ├── TestLoadSaveModules.cs ├── TestLoadSaveOptimizers.cs ├── TestLoadSaveTensors.cs ├── TorchSharp.PyBridge.Tests.csproj ├── pickled_modules │ ├── module_load.pth │ └── module_load.safetensors ├── pickled_optimizers │ ├── adadelta_load.pth │ ├── adagrad_load.pth │ ├── adam_emptystate_load.pth │ ├── adam_load.pth │ ├── adamax_load.pth │ ├── adamw_load.pth │ ├── asgd_load.bin │ ├── asgd_load.pth │ ├── create_pyload_opts.py │ ├── nadam_load.pth │ ├── radam_load.pth │ ├── rmsprop_load.pth │ ├── rprop_load.pth │ ├── sgd_load.pth │ └── test_sharpsaved_opts.py ├── pickled_tensors │ ├── tensors.pth │ └── tensors.safetensors ├── test_safetensors_load.py ├── test_safetensors_tensor_load.py ├── test_torch_load.py └── test_torch_tensor_load.py ├── TorchSharp.PyBridge.sln ├── TorchSharp.PyBridge ├── Extensions.cs ├── OptimizerUtils.cs ├── PyBridgeModuleExtensions.cs ├── PyBridgeOptimizerExtensions.cs ├── PyTorchPickler.cs ├── PyTorchUnpickler.cs ├── Safetensors.cs └── TorchSharp.PyBridge.csproj └── pack.bat /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET 
coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. 
Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | !.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 
| Copyright (c) 2023 shaltielshmid 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchSharp.PyBridge 2 | 3 | [![NuGet](https://img.shields.io/nuget/v/TorchSharp.PyBridge.svg)](https://www.nuget.org/packages/TorchSharp.PyBridge/) 4 | 5 | TorchSharp.PyBridge is an extension library for [TorchSharp](https://github.com/dotnet/TorchSharp), providing seamless interoperability between .NET and Python for model serialization. It simplifies saving and loading PyTorch models in a .NET environment, enabling developers to build models in either .NET or Python and transfer them between the two with ease. 6 | 7 | ## Features 8 | 9 | - `module.load_py(...)`, `optim.load_py(...)`: Extension methods for modules and optimizers that load PyTorch models saved in the standard Python format (using `torch.save`) directly into TorchSharp. 10 | 11 | > This only works when the `state_dict` was saved, not the whole model; see the example below. 12 | 13 | - `module.save_py(...)`, `optim.save_py(...)`: Extension methods for modules and optimizers that save TorchSharp models in a format that can be loaded directly in PyTorch (using `torch.load`), offering cross-platform model compatibility. 14 | 15 | - `module.load_safetensors(...)`, `module.save_safetensors(...)`: Extension methods for modules that save and load model weights using the [safetensors](https://github.com/huggingface/safetensors) format. 16 | 17 | - `module.load_checkpoint(...)`: Extension method for loading a checkpoint (both safetensors and regular PyTorch, including sharded models) from a directory saved using HuggingFace's `PreTrainedModel.save_pretrained()` method. 18 | 19 | ## Getting Started 20 | 21 | ### Installation 22 | 23 | TorchSharp.PyBridge is available on NuGet. You can install it using the following command: 24 | 25 | #### .NET CLI 26 | ```bash 27 | dotnet add package TorchSharp.PyBridge 28 | ``` 29 | 30 | #### NuGet Package Manager 31 | ```powershell 32 | Install-Package TorchSharp.PyBridge 33 | ``` 34 | 35 | ### Prerequisites 36 | 37 | - .NET SDK 38 | - TorchSharp library 39 | 40 | ## Usage 41 | 42 | ### Loading a PyTorch Model in .NET 43 | 44 | Saving the model in Python: 45 | 46 | ```python 47 | import torch 48 | 49 | model = ...
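# Note: save the state_dict (as on the next line), not the whole pickled
# model object -- load_py on the TorchSharp side only supports state_dicts.
# (Hypothetical variation, not in the original example: to produce a
# safetensors checkpoint instead, readable with model.load_safetensors()
# on the .NET side, one could write:
#     from safetensors.torch import save_file
#     save_file(model.state_dict(), 'path_to_your_model.safetensors'))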
50 | torch.save(model.state_dict(), 'path_to_your_model.pth') 51 | ``` 52 | 53 | Loading it in C#: 54 | 55 | ```csharp 56 | using TorchSharp.PyBridge; 57 | 58 | var model = ...; 59 | model.load_py("path_to_your_model.pth"); 60 | ``` 61 | 62 | ### Saving a TorchSharp Model for PyTorch 63 | 64 | To save a model in a format compatible with PyTorch: 65 | 66 | ```csharp 67 | using TorchSharp.PyBridge; 68 | 69 | var model = ...; 70 | model.save_py("path_to_save_model.pth"); 71 | ``` 72 | 73 | And loading it in Python: 74 | 75 | ```python 76 | import torch 77 | 78 | model = ... 79 | model.load_state_dict(torch.load('path_to_save_model.pth')) 80 | ``` 81 | 82 | ## Contributing 83 | 84 | Contributions to TorchSharp.PyBridge are welcome. 85 | 86 | ## Acknowledgments 87 | 88 | This project makes use of the `pickle` library, a Java and .NET implementation of Python's pickle serialization protocol, developed by Irmen de Jong. The `pickle` library plays a vital role in enabling the serialization features within TorchSharp.PyBridge. We extend our thanks to the developer for their significant contributions to the open-source community. For more details about the `pickle` library, please visit its [GitHub repository](https://github.com/irmen/pickle). 89 | 90 | ## Support and Contact 91 | 92 | For support, questions, or feedback, please open an issue in the [GitHub repository](https://github.com/shaltielshmid/TorchSharp.PyBridge). 93 | -------------------------------------------------------------------------------- /RELEASENOTES.md: -------------------------------------------------------------------------------- 1 | # TorchSharp.PyBridge Release Notes 2 | 3 | 1.4.3: 4 | - Fixed #21: `strict` is not passed to `load_safetensor` in `load_checkpoint` extension 5 | 6 | 1.4.2: 7 | - PR #20: Optimize load_py for memory and speed (@ejhg) 8 | 9 | 1.4.1: 10 | - Fixed #17: How to disable tqdm output when loading sharded safetensors 11 | 12 | 1.4.0: 13 | - Exposed `Safetensors`, `PytorchPickler` and `PytorchUnpickler` to allow for loading/saving python tensors outside of a model. 14 | - Fixed #16: SaveStateDict calls itself recursively and fails on locked file 15 | 16 | 1.3.2: 17 | - Fixed #13: UnpickleStateDict on BatchNorm2d error 18 | 19 | 1.3.1: 20 | - Fixed error on Apple Silicon devices 21 | 22 | 1.3.0: 23 | - Added support for loading tensors that are greater than 2GB (following the update in TorchSharp 0.102.0) 24 | - Added support for loading and saving safetensors when model isn't on CPU. 26 | 
L5 (continued) 1.1.0: 27 | - Added `load_py` and `save_py` extensions to optimizers.
28 | 29 | 1.0.0: 30 | - Initial release of `load_py` and `save_py` extensions for `torch.nn.Module` 31 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/SaveUtils.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System; 3 | using System.Collections.Generic; 4 | using System.IO.Compression; 5 | using System.Linq; 6 | using System.Text; 7 | using System.Threading.Tasks; 8 | using TorchSharp.Modules; 9 | 10 | namespace TorchSharp.PyBridge.Tests { 11 | internal static class SaveUtils { 12 | public static bool SaveAndCompare(OptimizerHelper optim, string file) { 13 | // Save the model to a memory stream 14 | var ms = new MemoryStream(); 15 | optim.save_py(ms, leaveOpen: true); 16 | ms.Position = 0; 17 | 18 | // Compare the saved bytes to the golden file 19 | return CompareSavedModules(ms, File.OpenRead(file)); 20 | } 21 | public static bool CompareSavedModules(Stream baseFile, Stream targetFile) { 22 | // One catch: They are zip files, and therefore the timestamp is embedded 23 | // in the bytes. Therefore, we are going to extract every entry in the archive 24 | // and compare the raw bytes then. 25 | byte[] compBytes = new ZipArchive(baseFile).ExtractAllContentBytes(); 26 | byte[] goldBytes = new ZipArchive(targetFile).ExtractAllContentBytes(); 27 | 28 | return Enumerable.SequenceEqual(goldBytes, compBytes); 29 | } 30 | } 31 | static class ZipArchiveExtensions { 32 | public static byte[] ExtractAllContentBytes(this ZipArchive archive) { 33 | var ms = new MemoryStream(); 34 | foreach (var entry in archive.Entries.OrderBy(e => e.FullName)) 35 | entry.Open().CopyTo(ms); 36 | 37 | return ms.ToArray(); 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/TestLoadSaveModules.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System.Diagnostics; 3 | using System.IO.Compression; 4 | using TorchSharp.PyBridge.Tests; 5 | using static TorchSharp.torch.nn; 6 | 7 | namespace TorchSharp.PyBridge.Tests { 8 | 9 | public class TestLoadSaveModules { 10 | 11 | [Test] 12 | public void TestPythonModuleLoad() { 13 | // We already saved a python module using `torch.save` to the file `module_load.pth` 14 | // Load in that model and make sure that the results are the same 15 | var model = Sequential(("lin1", Linear(5, 1, hasBias: false)), ("lin2", Linear(1, 2, hasBias: false))); 16 | model.load_py("pickled_modules/module_load.pth"); 17 | 18 | // The weights are all ones, so make sure that if we give it an array of ones we get 19 | // back [5,5] as the result 20 | var res = model.forward(torch.tensor(new[] { 1, 1, 1, 1, 1 }).@float()); 21 | Assert.Multiple(() => { 22 | Assert.That(res[0].ToSingle(), Is.EqualTo(5)); 23 | Assert.That(res[1].ToSingle(), Is.EqualTo(5)); 24 | }); 25 | } 26 | 27 | [Test] 28 | public void TestPythonModuleSave() { 29 | // Create a sequential model and set the weights to fixed, known values, save the file, and make sure 30 | // loading the module in pytorch gives us the expected numbers 31 | var model = Sequential(("lin1", Linear(5, 1, hasBias: false)), ("lin2", Linear(1, 2, hasBias: false))); 32 | model.state_dict()["lin1.weight"].bytes = torch.full(1, 5, 2, torch.ScalarType.Float32).bytes; 33 | model.state_dict()["lin2.weight"].bytes = torch.full(2, 1, 2, torch.ScalarType.Float32).bytes; 34 | 35 | // Create a temporary
filename to test 36 | var tempFile = Guid.NewGuid().ToString() + ".pth"; 37 | try { 38 | // Save the module 39 | model.save_py(tempFile); 40 | 41 | // Run the python 42 | var p = new Process(); 43 | p.StartInfo.FileName = "python"; 44 | p.StartInfo.Arguments = "test_torch_load.py " + tempFile; 45 | p.StartInfo.UseShellExecute = false; 46 | p.StartInfo.RedirectStandardOutput = true; 47 | 48 | p.Start(); 49 | p.WaitForExit(10); 50 | // Read the output 51 | var output = p.StandardOutput.ReadToEnd().Trim(); 52 | Assert.That(output, Is.EqualTo("tensor([20., 20.])")); 53 | } finally { 54 | if (File.Exists(tempFile)) 55 | File.Delete(tempFile); 56 | } 57 | } 58 | 59 | 60 | [Test] 61 | public void TestSafetensorsModuleLoad() { 62 | // We already saved a safetensors module using `safetensors.torch.save_file` to the file `module_load.safetensors` 63 | // Load in that model and make sure that the results are the same 64 | var model = Sequential(("lin1", Linear(5, 2, hasBias: false)), ("lin2", Linear(2, 2, hasBias: false))); 65 | model.load_safetensors("pickled_modules/module_load.safetensors"); 66 | 67 | // The weights are ones for lin1 and twos for lin2. Therefore, for the input of (11, 11, 11, 11, 11) we should get 68 | // back the result of [220,220] 69 | var res = model.forward(torch.tensor(new[] { 11, 11, 11, 11, 11 }).@float()); 70 | Assert.Multiple(() => { 71 | Assert.That(res[0].ToSingle(), Is.EqualTo(220)); 72 | Assert.That(res[1].ToSingle(), Is.EqualTo(220)); 73 | }); 74 | } 75 | 76 | [Test] 77 | public void TestSafetensorsModuleSave() { 78 | // Create a sequential model and set the weights to fixed, known values, save the file, and make sure 79 | // loading the safetensors in python gives us the expected numbers 80 | var model = Sequential(("lin1", Linear(5, 1, hasBias: false)), ("lin2", Linear(1, 2, hasBias: false))); 81 | model.state_dict()["lin1.weight"].bytes = torch.full(1, 5, 2, torch.ScalarType.Float32).bytes; 82 | model.state_dict()["lin2.weight"].bytes = torch.full(2, 1, 2, torch.ScalarType.Float32).bytes; 83 | 84 | // Create a temporary filename to test 85 | var tempFile = Guid.NewGuid().ToString() + ".safetensors"; 86 | try { 87 | // Save the module 88 | model.save_safetensors(tempFile); 89 | 90 | // Run the python 91 | var p = new Process(); 92 | p.StartInfo.FileName = "python"; 93 | p.StartInfo.Arguments = "test_safetensors_load.py " + tempFile; 94 | p.StartInfo.UseShellExecute = false; 95 | p.StartInfo.RedirectStandardOutput = true; 96 | 97 | p.Start(); 98 | p.WaitForExit(10); 99 | // Read the output 100 | var output = p.StandardOutput.ReadToEnd().Trim(); 101 | Assert.That(output, Is.EqualTo("tensor([20., 20.])")); 102 | } 103 | finally { 104 | if (File.Exists(tempFile)) 105 | File.Delete(tempFile); 106 | } 107 | 108 | 109 | } 110 | 111 | [Test] 112 | public void TestLoadBatchNorm2D_Bug13() { 113 | var model = BatchNorm2d(5); 114 | // Run a few inputs through, to increment the `num_batches_tracked` field 115 | for (int i = 0; i < 5; i++) { 116 | using var d = torch.NewDisposeScope(); 117 | model.forward(torch.rand(new[] { 5L, 5L, 5L, 5L })); 118 | } 119 | 120 | Assert.That(model.num_batches_tracked!.item<long>(), Is.EqualTo(5)); 121 | 122 | // Save to a file, and try reloading 123 | using var stream = new MemoryStream(); 124 | model.save_py(stream, leaveOpen: true); 125 | stream.Position = 0; 126 | 127 | // Create the new model and load it 128 | var model2 = BatchNorm2d(5); 129 | model2.load_py(stream); 130 | 131 | Assert.That(model2.num_batches_tracked!.item<long>(), Is.EqualTo(5)); 132
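// (This round-trip guards against the regression noted in the release notes
// as "Fixed #13: UnpickleStateDict on BatchNorm2d error" -- the
// num_batches_tracked buffer must survive save_py/load_py intact.)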
| } 133 | } 134 | 135 | 136 | } -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/TestLoadSaveOptimizers.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System.IO.Compression; 3 | using TorchSharp.Modules; 4 | using TorchSharp.PyBridge.Tests; 5 | using static TorchSharp.torch.nn; 6 | 7 | namespace TorchSharp.PyBridge.Tests { 8 | public class TestLoadSaveOptimizers { 9 | 10 | private static void TestSaveOptim<T>(Func<IEnumerable<Parameter>, T> func, bool withLoss = false) where T : OptimizerHelper { 11 | // Set the manual seed so that the randoms don't change between runs 12 | // and our tests will succeed 13 | torch.manual_seed(423812); 14 | 15 | var l1 = Linear(10, 10, true); 16 | var l2 = Linear(10, 10, true); 17 | var seq = Sequential(l1, l2); 18 | 19 | torch.manual_seed(423812); 20 | var optim = func(seq.parameters()); 21 | 22 | // Force the buffers that are only created after loss to be created. 23 | if (withLoss) { 24 | using var d = torch.NewDisposeScope(); 25 | optim.zero_grad(); 26 | var x = torch.randn(new[] { 64L, 10L }); 27 | var y = torch.randn(new[] { 64L, 10L }); 28 | torch.nn.functional.mse_loss(seq.call(x), y).backward(); 29 | optim.step(); 30 | } 31 | 32 | // Save that optim to memory 33 | using var ms = new MemoryStream(); 34 | optim.save_py(ms, leaveOpen: true); 35 | ms.Position = 0; 36 | 37 | // Create a new optimizer and load in the saved values 38 | torch.manual_seed(423812); 39 | var optim2 = func(seq.parameters()); 40 | optim2.load_py(ms, leaveOpen: true); 41 | 42 | // Save the second optim to a memory stream, and make sure they are identical 43 | using var ms2 = new MemoryStream(); 44 | optim2.save_py(ms2, leaveOpen: true); 45 | 46 | // Compare the two serialized streams byte-for-byte 47 | ms.Position = 0; ms2.Position = 0; 48 | Assert.That(SaveUtils.CompareSavedModules(ms, ms2)); 49 | } 50 | 51 | [Test] 52 | public void TestSaveRprop() { 53 | TestSaveOptim(p => torch.optim.Rprop(p)); 54 | } 55 | 56 | [Test] 57 | public void TestSaveRpropWithLoss() { 58 | TestSaveOptim(p => torch.optim.Rprop(p), true); 59 | } 60 | 61 | [Test] 62 | public void TestSaveSGD() { 63 | TestSaveOptim(p => torch.optim.SGD(p, 0.01)); 64 | } 65 | 66 | [Test] 67 | public void TestSaveSGDWithLoss() { 68 | TestSaveOptim(p => torch.optim.SGD(p, 0.01), true); 69 | } 70 | 71 | [Test] 72 | public void TestSaveASGD() { 73 | TestSaveOptim(p => torch.optim.ASGD(p)); 74 | } 75 | 76 | [Test] 77 | public void TestSaveASGDWithLoss() { 78 | TestSaveOptim(p => torch.optim.ASGD(p), true); 79 | } 80 | 81 | [Test] 82 | public void TestSaveRMSProp() { 83 | TestSaveOptim(p => torch.optim.RMSProp(p)); 84 | } 85 | 86 | [Test] 87 | public void TestSaveRMSPropWithLoss() { 88 | TestSaveOptim(p => torch.optim.RMSProp(p), true); 89 | } 90 | 91 | [Test] 92 | public void TestSaveRAdam() { 93 | TestSaveOptim(p => torch.optim.RAdam(p)); 94 | } 95 | 96 | [Test] 97 | public void TestSaveRAdamWithLoss() { 98 | TestSaveOptim(p => torch.optim.RAdam(p), true); 99 | } 100 | 101 | [Test] 102 | public void TestSaveNAdam() { 103 | TestSaveOptim(p => torch.optim.NAdam(p)); 104 | } 105 | 106 | [Test] 107 | public void TestSaveNAdamWithLoss() { 108 | TestSaveOptim(p => torch.optim.NAdam(p), true); 109 | } 110 | 111 | [Test] 112 | public void TestSaveAdam() { 113 | TestSaveOptim(p => torch.optim.Adam(p)); 114 | } 115 | 116 | [Test] 117 | public void TestSaveAdamWithLoss() { 118 | TestSaveOptim(p => torch.optim.Adam(p),
true); 119 | } 120 | 121 | [Test] 122 | public void TestSaveAdamW() { 123 | TestSaveOptim(p => torch.optim.AdamW(p)); 124 | } 125 | 126 | [Test] 127 | public void TestSaveAdamWWithLoss() { 128 | TestSaveOptim(p => torch.optim.AdamW(p), true); 129 | } 130 | 131 | [Test] 132 | public void TestSaveAdamax() { 133 | TestSaveOptim(p => torch.optim.Adamax(p)); 134 | } 135 | 136 | [Test] 137 | public void TestSaveAdamaxWithLoss() { 138 | TestSaveOptim(p => torch.optim.Adamax(p), true); 139 | } 140 | 141 | [Test] 142 | public void TestSaveAdagrad() { 143 | TestSaveOptim(p => torch.optim.Adagrad(p)); 144 | } 145 | 146 | [Test] 147 | public void TestSaveAdagradWithLoss() { 148 | TestSaveOptim(p => torch.optim.Adagrad(p), true); 149 | } 150 | 151 | [Test] 152 | public void TestSaveAdadelta() { 153 | TestSaveOptim(p => torch.optim.Adadelta(p)); 154 | } 155 | 156 | [Test] 157 | public void TestSaveAdadeltaWithLoss() { 158 | TestSaveOptim(p => torch.optim.Adadelta(p), true); 159 | } 160 | 161 | 162 | [Test] 163 | public void TestLoadSGD() { 164 | var lin = torch.nn.Linear(10, 10); 165 | 166 | double learning_rate = 0.00004f; 167 | var optimizer = torch.optim.SGD(lin.parameters(), learning_rate); 168 | 169 | optimizer.load_py("pickled_optimizers/sgd_load.pth"); 170 | 171 | var sd = optimizer.state_dict(); 172 | Assert.Multiple(() => { 173 | Assert.That(sd.Options, Has.One.Items); 174 | Assert.That(sd.State, Has.Exactly(2).Items); 175 | }); 176 | 177 | foreach (var opts in sd.Options) { 178 | var options = opts as Modules.SGD.Options; 179 | Assert.Multiple(() => { 180 | Assert.That(options!.momentum, Is.EqualTo(0.1)); 181 | Assert.That(options!.LearningRate, Is.Not.EqualTo(learning_rate)); 182 | }); 183 | } 184 | 185 | foreach (var st in sd.State) { 186 | var state = st as Modules.SGD.State; 187 | Assert.That(state!.momentum_buffer, Is.Not.Null); 188 | } 189 | } 190 | 191 | [Test] 192 | public void TestLoadASGD() { 193 | var lin = torch.nn.Linear(10, 10); 194 | 195 | double learning_rate = 0.004f; 196 | var optimizer = torch.optim.ASGD(lin.parameters(), learning_rate); 197 | 198 | optimizer.load_py("pickled_optimizers/asgd_load.pth"); 199 | 200 | var sd = optimizer.state_dict(); 201 | Assert.Multiple(() => { 202 | Assert.That(sd.Options, Has.One.Items); 203 | Assert.That(sd.State, Has.Exactly(2).Items); 204 | }); 205 | foreach (var opts in sd.Options) { 206 | var options = opts as Modules.ASGD.Options; 207 | Assert.Multiple(() => { 208 | Assert.That(options!.alpha, Is.EqualTo(0.65)); 209 | Assert.That(options!.lambd, Is.EqualTo(1e-3)); 210 | Assert.That(options!.t0, Is.EqualTo(1e5)); 211 | Assert.That(options!.LearningRate, Is.Not.EqualTo(learning_rate)); 212 | }); 213 | } 214 | 215 | foreach (var st in sd.State) { 216 | var state = st as Modules.ASGD.State; 217 | Assert.Multiple(() => { 218 | Assert.That(state!.step, Is.EqualTo(1)); 219 | Assert.That(state!.ax, Is.Not.Null); 220 | }); 221 | } 222 | } 223 | 224 | [Test] 225 | public void TestLoadRMSprop() { 226 | var lin1 = torch.nn.Linear(10, 10); 227 | var lin2 = torch.nn.Linear(10, 10); 228 | 229 | var seq = Sequential(("lin1", lin1), ("lin2", lin2)); 230 | 231 | var pgs = new RMSProp.ParamGroup[] { 232 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 233 | new (lin2.parameters(), lr: 0.00003, centered: false, momentum: 0.25) 234 | }; 235 | 236 | double learning_rate = 0.00004f; 237 | var optimizer = torch.optim.RMSProp(pgs, learning_rate); 238 | 239 | optimizer.load_py("pickled_optimizers/rmsprop_load.pth"); 240 
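// (For reference: this fixture was written by save_rmsprop in
// pickled_optimizers/create_pyload_opts.py -- group 1 with lr=0.001,
// momentum=0.1 and group 2 with lr=0.01, momentum=0, centered=True,
// stepped once -- which is exactly what the assertions below expect.)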
| 241 | var sd = optimizer.state_dict(); 242 | Assert.Multiple(() => { 243 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 244 | Assert.That(sd.State, Has.Count.EqualTo(4)); 245 | }); 246 | 247 | var options = (sd.Options[0] as Modules.RMSProp.Options)!; 248 | Assert.Multiple(() => { 249 | Assert.That(options.momentum, Is.EqualTo(0.1)); 250 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 251 | Assert.That(options.centered, Is.False); 252 | }); 253 | 254 | options = (sd.Options[1] as Modules.RMSProp.Options)!; 255 | Assert.Multiple(() => { 256 | Assert.That(options.momentum, Is.EqualTo(0)); 257 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 258 | Assert.That(options.centered, Is.True); 259 | }); 260 | 261 | var state = (sd.State[0] as Modules.RMSProp.State)!; 262 | Assert.Multiple(() => { 263 | Assert.That(state.step, Is.EqualTo(1)); 264 | Assert.That(state.square_avg, Is.Not.Null); 265 | Assert.That(state.momentum_buffer, Is.Not.Null); 266 | Assert.That(state.grad_avg, Is.Null); 267 | }); 268 | 269 | state = (sd.State[1] as Modules.RMSProp.State)!; 270 | Assert.Multiple(() => { 271 | Assert.That(state.step, Is.EqualTo(1)); 272 | Assert.That(state.square_avg, Is.Not.Null); 273 | Assert.That(state.momentum_buffer, Is.Not.Null); 274 | Assert.That(state.grad_avg, Is.Null); 275 | }); 276 | 277 | state = (sd.State[2] as Modules.RMSProp.State)!; 278 | Assert.Multiple(() => { 279 | Assert.That(state.step, Is.EqualTo(1)); 280 | Assert.That(state.square_avg, Is.Not.Null); 281 | Assert.That(state.momentum_buffer, Is.Null); 282 | Assert.That(state.grad_avg, Is.Not.Null); 283 | }); 284 | 285 | state = (sd.State[3] as Modules.RMSProp.State)!; 286 | Assert.Multiple(() => { 287 | Assert.That(state.step, Is.EqualTo(1)); 288 | Assert.That(state.square_avg, Is.Not.Null); 289 | Assert.That(state.momentum_buffer, Is.Null); 290 | Assert.That(state.grad_avg, Is.Not.Null); 291 | }); 292 | } 293 | 294 | [Test] 295 | public void TestLoadRprop() { 296 | var lin1 = torch.nn.Linear(10, 10); 297 | var lin2 = torch.nn.Linear(10, 10); 298 | 299 | var seq = Sequential(("lin1", lin1), ("lin2", lin2)); 300 | 301 | var pgs = new Rprop.ParamGroup[] { 302 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 303 | new (lin2.parameters(), lr: 0.00003, maximize: false) 304 | }; 305 | 306 | double learning_rate = 0.00004f; 307 | var optimizer = torch.optim.Rprop(pgs, learning_rate); 308 | 309 | optimizer.load_py("pickled_optimizers/rprop_load.pth"); 310 | 311 | var sd = optimizer.state_dict(); 312 | Assert.Multiple(() => { 313 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 314 | Assert.That(sd.State, Has.Count.EqualTo(4)); 315 | }); 316 | 317 | var options = (sd.Options[0] as Modules.Rprop.Options)!; 318 | Assert.Multiple(() => { 319 | Assert.That(options.etaminus, Is.EqualTo(0.35)); 320 | Assert.That(options.etaplus, Is.EqualTo(1.5)); 321 | Assert.That(options.min_step, Is.EqualTo(1e-5)); 322 | Assert.That(options.max_step, Is.EqualTo(5)); 323 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 324 | Assert.That(options.maximize, Is.False); 325 | }); 326 | 327 | options = (sd.Options[1] as Modules.Rprop.Options)!; 328 | Assert.Multiple(() => { 329 | Assert.That(options.etaminus, Is.EqualTo(0.45)); 330 | Assert.That(options.etaplus, Is.EqualTo(1.5)); 331 | Assert.That(options.min_step, Is.EqualTo(1e-5)); 332 | Assert.That(options.max_step, Is.EqualTo(5)); 333 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 334 | Assert.That(options.maximize, Is.True); 335 | }); 336 
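// (save_rprop in create_pyload_opts.py performs a single optimizer.step(),
// so every state entry below should report step == 1 with the prev and
// step_size buffers populated.)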
| 337 | foreach (var st in sd.State) { 338 | var state = (st as Modules.Rprop.State)!; 339 | Assert.Multiple(() => { 340 | Assert.That(state.step, Is.EqualTo(1)); 341 | Assert.That(state.prev, Is.Not.Null); 342 | Assert.That(state.step_size, Is.Not.Null); 343 | }); 344 | } 345 | } 346 | 347 | [Test] 348 | public void TestLoadAdam() { 349 | var lin1 = torch.nn.Linear(10, 10); 350 | var lin2 = torch.nn.Linear(10, 10); 351 | 352 | var pgs = new Adam.ParamGroup[] { 353 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 354 | new (lin2.parameters(), lr: 0.00003, amsgrad: false, beta1: 0.25) 355 | }; 356 | 357 | double learning_rate = 0.00004f; 358 | var optimizer = torch.optim.Adam(pgs, learning_rate); 359 | 360 | optimizer.load_py("pickled_optimizers/adam_load.pth"); 361 | 362 | var sd = optimizer.state_dict(); 363 | Assert.Multiple(() => { 364 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 365 | Assert.That(sd.State, Has.Count.EqualTo(4)); 366 | }); 367 | 368 | var options = (sd.Options[0] as Modules.Adam.Options)!; 369 | Assert.Multiple(() => { 370 | Assert.That(options.beta1, Is.EqualTo(0.8)); 371 | Assert.That(options.beta2, Is.EqualTo(0.9)); 372 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 373 | Assert.That(options.amsgrad, Is.False); 374 | }); 375 | 376 | options = (sd.Options[1] as Modules.Adam.Options)!; 377 | Assert.Multiple(() => { 378 | Assert.That(options.beta1, Is.EqualTo(0.7)); 379 | Assert.That(options.beta2, Is.EqualTo(0.79)); 380 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 381 | Assert.That(options!.amsgrad, Is.True); 382 | }); 383 | 384 | var state = (sd.State[0] as Modules.Adam.State)!; 385 | Assert.Multiple(() => { 386 | Assert.That(state.step, Is.EqualTo(1)); 387 | Assert.That(state.exp_avg, Is.Not.Null); 388 | Assert.That(state.exp_avg_sq, Is.Not.Null); 389 | Assert.That(state.max_exp_avg_sq, Is.Null); 390 | }); 391 | 392 | state = (sd.State[1] as Modules.Adam.State)!; 393 | Assert.Multiple(() => { 394 | Assert.That(state.step, Is.EqualTo(1)); 395 | Assert.That(state.exp_avg, Is.Not.Null); 396 | Assert.That(state.exp_avg_sq, Is.Not.Null); 397 | Assert.That(state.max_exp_avg_sq, Is.Null); 398 | }); 399 | 400 | state = (sd.State[2] as Modules.Adam.State)!; 401 | Assert.Multiple(() => { 402 | Assert.That(state.step, Is.EqualTo(1)); 403 | Assert.That(state.exp_avg, Is.Not.Null); 404 | Assert.That(state.exp_avg_sq, Is.Not.Null); 405 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 406 | }); 407 | 408 | state = (sd.State[3] as Modules.Adam.State)!; 409 | Assert.Multiple(() => { 410 | Assert.That(state.step, Is.EqualTo(1)); 411 | Assert.That(state.exp_avg, Is.Not.Null); 412 | Assert.That(state.exp_avg_sq, Is.Not.Null); 413 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 414 | }); 415 | } 416 | 417 | [Test] 418 | public void TestLoadAdamW() { 419 | var lin1 = torch.nn.Linear(10, 10); 420 | var lin2 = torch.nn.Linear(10, 10); 421 | 422 | var pgs = new AdamW.ParamGroup[] { 423 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 424 | new (lin2.parameters(), lr: 0.00003, amsgrad: false, beta1: 0.25) 425 | }; 426 | 427 | double learning_rate = 0.00004f; 428 | var optimizer = torch.optim.AdamW(pgs, learning_rate); 429 | 430 | optimizer.load_py("pickled_optimizers/adamw_load.pth"); 431 | 432 | var sd = optimizer.state_dict(); 433 | Assert.Multiple(() => { 434 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 435 | Assert.That(sd.State, Has.Count.EqualTo(4)); 436 | }); 437 | 438
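// (Fixture written by save_adamw in create_pyload_opts.py: group 1 uses
// betas=(0.8, 0.9) with amsgrad=False, group 2 uses betas=(0.7, 0.79),
// lr=0.01 and amsgrad=True -- matching the assertions that follow.)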
| var options = (sd.Options[0] as Modules.AdamW.Options)!; 439 | Assert.Multiple(() => { 440 | Assert.That(options.beta1, Is.EqualTo(0.8)); 441 | Assert.That(options.beta2, Is.EqualTo(0.9)); 442 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 443 | Assert.That(options.amsgrad, Is.False); 444 | }); 445 | 446 | options = (sd.Options[1] as Modules.AdamW.Options)!; 447 | Assert.Multiple(() => { 448 | Assert.That(options.beta1, Is.EqualTo(0.7)); 449 | Assert.That(options.beta2, Is.EqualTo(0.79)); 450 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 451 | Assert.That(options.amsgrad, Is.True); 452 | }); 453 | 454 | 455 | var state = (sd.State[0] as Modules.AdamW.State)!; 456 | Assert.Multiple(() => { 457 | Assert.That(state.step, Is.EqualTo(1)); 458 | Assert.That(state.exp_avg, Is.Not.Null); 459 | Assert.That(state.exp_avg_sq, Is.Not.Null); 460 | Assert.That(state.max_exp_avg_sq, Is.Null); 461 | }); 462 | 463 | state = (sd.State[1] as Modules.AdamW.State)!; 464 | Assert.Multiple(() => { 465 | Assert.That(state.step, Is.EqualTo(1)); 466 | Assert.That(state.exp_avg, Is.Not.Null); 467 | Assert.That(state.exp_avg_sq, Is.Not.Null); 468 | Assert.That(state.max_exp_avg_sq, Is.Null); 469 | }); 470 | 471 | state = (sd.State[2] as Modules.AdamW.State)!; 472 | Assert.Multiple(() => { 473 | Assert.That(state.step, Is.EqualTo(1)); 474 | Assert.That(state.exp_avg, Is.Not.Null); 475 | Assert.That(state.exp_avg_sq, Is.Not.Null); 476 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 477 | }); 478 | 479 | state = (sd.State[3] as Modules.AdamW.State)!; 480 | Assert.Multiple(() => { 481 | Assert.That(state.step, Is.EqualTo(1)); 482 | Assert.That(state.exp_avg, Is.Not.Null); 483 | Assert.That(state.exp_avg_sq, Is.Not.Null); 484 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 485 | }); 486 | } 487 | 488 | [Test] 489 | public void TestLoadNAdam() { 490 | var lin1 = torch.nn.Linear(10, 10); 491 | var lin2 = torch.nn.Linear(10, 10); 492 | 493 | var pgs = new NAdam.ParamGroup[] { 494 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 495 | new (lin2.parameters(), lr: 0.00003, beta1: 0.25, weight_decay: 0.1) 496 | }; 497 | 498 | double learning_rate = 0.00004f; 499 | var optimizer = torch.optim.NAdam(pgs, learning_rate); 500 | 501 | optimizer.load_py("pickled_optimizers/nadam_load.pth"); 502 | 503 | var sd = optimizer.state_dict(); 504 | Assert.Multiple(() => { 505 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 506 | Assert.That(sd.State, Has.Count.EqualTo(4)); 507 | }); 508 | 509 | var options = (sd.Options[0] as Modules.NAdam.Options)!; 510 | Assert.Multiple(() => { 511 | Assert.That(options.beta1, Is.EqualTo(0.8)); 512 | Assert.That(options.beta2, Is.EqualTo(0.9)); 513 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 514 | Assert.That(options.weight_decay, Is.EqualTo(0)); 515 | }); 516 | 517 | options = (sd.Options[1] as Modules.NAdam.Options)!; 518 | Assert.Multiple(() => { 519 | Assert.That(options.beta1, Is.EqualTo(0.7)); 520 | Assert.That(options.beta2, Is.EqualTo(0.79)); 521 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 522 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 523 | }); 524 | 525 | foreach (var st in sd.State) { 526 | var state = (st as Modules.NAdam.State)!; 527 | Assert.Multiple(() => { 528 | Assert.That(state.step, Is.EqualTo(1)); 529 | Assert.That(state.exp_avg, Is.Not.Null); 530 | Assert.That(state.exp_avg_sq, Is.Not.Null); 531 | }); 532 | } 533 | } 534 | 535 | [Test] 536 | public void TestLoadingRAdam() { 537 | var lin1 
= torch.nn.Linear(10, 10); 538 | var lin2 = torch.nn.Linear(10, 10); 539 | 540 | var pgs = new RAdam.ParamGroup[] { 541 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 542 | new (lin2.parameters(), lr: 0.00003, beta1: 0.25) 543 | }; 544 | 545 | double learning_rate = 0.00004f; 546 | var optimizer = torch.optim.RAdam(pgs, learning_rate); 547 | 548 | optimizer.load_py("pickled_optimizers/radam_load.pth"); 549 | 550 | var sd = optimizer.state_dict(); 551 | Assert.Multiple(() => { 552 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 553 | Assert.That(sd.State, Has.Count.EqualTo(4)); 554 | }); 555 | 556 | var options = (sd.Options[0] as Modules.RAdam.Options)!; 557 | Assert.Multiple(() => { 558 | Assert.That(options.beta1, Is.EqualTo(0.8)); 559 | Assert.That(options.beta2, Is.EqualTo(0.9)); 560 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 561 | Assert.That(options.weight_decay, Is.EqualTo(0)); 562 | }); 563 | 564 | options = (sd.Options[1] as Modules.RAdam.Options)!; 565 | Assert.Multiple(() => { 566 | Assert.That(options.beta1, Is.EqualTo(0.7)); 567 | Assert.That(options.beta2, Is.EqualTo(0.79)); 568 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 569 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 570 | }); 571 | 572 | foreach (var st in sd.State) { 573 | var state = (st as Modules.RAdam.State)!; 574 | Assert.Multiple(() => { 575 | Assert.That(state.step, Is.EqualTo(1)); 576 | Assert.That(state.exp_avg, Is.Not.Null); 577 | Assert.That(state.exp_avg_sq, Is.Not.Null); 578 | }); 579 | } 580 | } 581 | 582 | [Test] 583 | public void TestLoadAdadelta() { 584 | var lin1 = torch.nn.Linear(10, 10); 585 | var lin2 = torch.nn.Linear(10, 10); 586 | 587 | var pgs = new Adadelta.ParamGroup[] { 588 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 589 | new (lin2.parameters(), lr: 0.00003, maximize: false, rho: 0.25) 590 | }; 591 | 592 | double learning_rate = 0.00004f; 593 | var optimizer = torch.optim.Adadelta(pgs, learning_rate); 594 | 595 | optimizer.load_py("pickled_optimizers/adadelta_load.pth"); 596 | 597 | var sd = optimizer.state_dict(); 598 | Assert.Multiple(() => { 599 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 600 | Assert.That(sd.State, Has.Count.EqualTo(4)); 601 | }); 602 | 603 | var options = (sd.Options[0] as Modules.Adadelta.Options)!; 604 | Assert.Multiple(() => { 605 | Assert.That(options.rho, Is.EqualTo(0.85)); 606 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 607 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 608 | Assert.That(options.maximize, Is.False); 609 | }); 610 | 611 | options = (sd.Options[1] as Modules.Adadelta.Options)!; 612 | Assert.Multiple(() => { 613 | Assert.That(options.rho, Is.EqualTo(0.79)); 614 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 615 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 616 | Assert.That(options.maximize, Is.True); 617 | }); 618 | 619 | 620 | foreach (var st in sd.State) { 621 | var state = (st as Modules.Adadelta.State)!; 622 | Assert.Multiple(() => { 623 | Assert.That(state.step, Is.EqualTo(1)); 624 | Assert.That(state.square_avg, Is.Not.Null); 625 | Assert.That(state.acc_delta, Is.Not.Null); 626 | }); 627 | } 628 | } 629 | 630 | [Test] 631 | public void TestLoadAdagrad() { 632 | var lin = torch.nn.Linear(10, 10); 633 | 634 | double learning_rate = 0.00004f; 635 | var optimizer = torch.optim.Adagrad(lin.parameters(), learning_rate); 636 | 637 | optimizer.load_py("pickled_optimizers/adagrad_load.pth"); 638 
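// (As in the other load tests, load_py replaces both the options and the
// state with the values from the Python checkpoint, so the learning rate
// passed to the constructor should no longer be present -- hence the
// Is.Not.EqualTo(learning_rate) assertion below.)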
| 639 | var sd = optimizer.state_dict(); 640 | Assert.Multiple(() => { 641 | Assert.That(sd.Options, Has.One.Items); 642 | Assert.That(sd.State, Has.Count.EqualTo(2)); 643 | }); 644 | 645 | foreach (var opts in sd.Options) { 646 | var options = (opts as Modules.Adagrad.Options)!; 647 | Assert.Multiple(() => { 648 | Assert.That(options.lr_decay, Is.EqualTo(0.85)); 649 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 650 | Assert.That(options.LearningRate, Is.Not.EqualTo(learning_rate)); 651 | }); 652 | } 653 | 654 | foreach (var st in sd.State) { 655 | var state = (st as Modules.Adagrad.State)!; 656 | Assert.Multiple(() => { 657 | Assert.That(state.step, Is.EqualTo(1)); 658 | Assert.That(state.sum, Is.Not.Null); 659 | }); 660 | } 661 | } 662 | 663 | [Test] 664 | public void TestLoadAdamax() { 665 | var lin1 = torch.nn.Linear(10, 10); 666 | var lin2 = torch.nn.Linear(10, 10); 667 | 668 | var pgs = new Adamax.ParamGroup[] { 669 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 670 | new (lin2.parameters(), lr: 0.00003, weight_decay: 0.25, beta1: 0.25) 671 | }; 672 | 673 | double learning_rate = 0.00004f; 674 | var optimizer = torch.optim.Adamax(pgs, learning_rate); 675 | 676 | optimizer.load_py("pickled_optimizers/adamax_load.pth"); 677 | 678 | var sd = optimizer.state_dict(); 679 | Assert.Multiple(() => { 680 | Assert.That(sd.Options, Has.Count.EqualTo(2)); 681 | Assert.That(sd.State, Has.Count.EqualTo(4)); 682 | }); 683 | 684 | var options = (sd.Options[0] as Modules.Adamax.Options)!; 685 | Assert.Multiple(() => { 686 | Assert.That(options.beta1, Is.EqualTo(0.8)); 687 | Assert.That(options.beta2, Is.EqualTo(0.9)); 688 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 689 | Assert.That(options.weight_decay, Is.EqualTo(0)); 690 | }); 691 | 692 | options = (sd.Options[1] as Modules.Adamax.Options)!; 693 | Assert.Multiple(() => { 694 | Assert.That(options.beta1, Is.EqualTo(0.7)); 695 | Assert.That(options.beta2, Is.EqualTo(0.79)); 696 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 697 | Assert.That(options.weight_decay, Is.EqualTo(0.3)); 698 | }); 699 | 700 | foreach (var state in sd.State.Cast<Modules.Adamax.State>()) { 701 | Assert.Multiple(() => { 702 | Assert.That(state.step, Is.EqualTo(1)); 703 | Assert.That(state.exp_avg, Is.Not.Null); 704 | Assert.That(state.exp_inf, Is.Not.Null); 705 | }); 706 | } 707 | } 708 | [Test] 709 | public void TestLoadAdamEmptyState() { 710 | var lin1 = torch.nn.Linear(10, 10); 711 | var lin2 = torch.nn.Linear(10, 10); 712 | var lin3 = torch.nn.Linear(10, 10); 713 | 714 | var pgs = new Adam.ParamGroup[] { 715 | new () { Parameters = lin1.parameters(), Options = new() { LearningRate = 0.00005 } }, 716 | new (lin2.parameters(), lr: 0.00003, amsgrad: false, beta1: 0.25), 717 | new (lin3.parameters(), lr: 0.00003, amsgrad: false, beta1: 0.9, beta2: 0.99) 718 | }; 719 | 720 | double learning_rate = 0.00004f; 721 | var optimizer = torch.optim.Adam(pgs, learning_rate); 722 | 723 | // Calculate loss for lin3, so that it steps through 724 | torch.nn.functional.mse_loss(lin3.call(torch.rand(10)), torch.rand(10)).backward(); 725 | optimizer.step(); 726 | 727 | optimizer.load_py("pickled_optimizers/adam_emptystate_load.pth"); 728 | 729 | var sd = optimizer.state_dict(); 730 | Assert.Multiple(() => { 731 | Assert.That(sd.Options, Has.Count.EqualTo(3)); 732 | Assert.That(sd.State, Has.Count.EqualTo(6)); 733 | }); 734 | 735 | var options = (sd.Options[0] as Modules.Adam.Options)!; 736 | Assert.Multiple(() => { 737
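// Options for all three groups below come from the pickled file; the lin3
// group was saved with empty state, so its two entries are expected to be
// reset to defaults (step == 0, zeroed buffers) -- checked at the end.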
Assert.That(options.beta1, Is.EqualTo(0.8)); 738 | Assert.That(options.beta2, Is.EqualTo(0.9)); 739 | Assert.That(options.LearningRate, Is.EqualTo(0.001)); 740 | Assert.That(options.amsgrad, Is.False); 741 | }); 742 | 743 | options = (sd.Options[1] as Modules.Adam.Options)!; 744 | Assert.Multiple(() => { 745 | Assert.That(options.beta1, Is.EqualTo(0.7)); 746 | Assert.That(options.beta2, Is.EqualTo(0.79)); 747 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 748 | Assert.That(options!.amsgrad, Is.True); 749 | }); 750 | 751 | options = (sd.Options[2] as Modules.Adam.Options)!; 752 | Assert.Multiple(() => { 753 | Assert.That(options.beta1, Is.EqualTo(0.6)); 754 | Assert.That(options.beta2, Is.EqualTo(0.69)); 755 | Assert.That(options.LearningRate, Is.EqualTo(0.01)); 756 | Assert.That(options!.amsgrad, Is.True); 757 | }); 758 | 759 | var state = (sd.State[0] as Modules.Adam.State)!; 760 | Assert.Multiple(() => { 761 | Assert.That(state.step, Is.EqualTo(1)); 762 | Assert.That(state.exp_avg, Is.Not.Null); 763 | Assert.That(state.exp_avg_sq, Is.Not.Null); 764 | Assert.That(state.max_exp_avg_sq, Is.Null); 765 | }); 766 | 767 | state = (sd.State[1] as Modules.Adam.State)!; 768 | Assert.Multiple(() => { 769 | Assert.That(state.step, Is.EqualTo(1)); 770 | Assert.That(state.exp_avg, Is.Not.Null); 771 | Assert.That(state.exp_avg_sq, Is.Not.Null); 772 | Assert.That(state.max_exp_avg_sq, Is.Null); 773 | }); 774 | 775 | state = (sd.State[2] as Modules.Adam.State)!; 776 | Assert.Multiple(() => { 777 | Assert.That(state.step, Is.EqualTo(1)); 778 | Assert.That(state.exp_avg, Is.Not.Null); 779 | Assert.That(state.exp_avg_sq, Is.Not.Null); 780 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 781 | }); 782 | 783 | state = (sd.State[3] as Modules.Adam.State)!; 784 | Assert.Multiple(() => { 785 | Assert.That(state.step, Is.EqualTo(1)); 786 | Assert.That(state.exp_avg, Is.Not.Null); 787 | Assert.That(state.exp_avg_sq, Is.Not.Null); 788 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 789 | }); 790 | 791 | // Make sure lin3 was reset to the defaults 792 | state = (sd.State[4] as Modules.Adam.State)!; 793 | Assert.Multiple(() => { 794 | Assert.That(state.step, Is.EqualTo(0)); 795 | Assert.That(torch.count_nonzero(state.exp_avg).ToInt32(), Is.EqualTo(0)); 796 | Assert.That(torch.count_nonzero(state.exp_avg_sq).ToInt32(), Is.EqualTo(0)); 797 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 798 | }); 799 | 800 | state = (sd.State[5] as Modules.Adam.State)!; 801 | Assert.Multiple(() => { 802 | Assert.That(state.step, Is.EqualTo(0)); 803 | Assert.That(torch.count_nonzero(state.exp_avg).ToInt32(), Is.EqualTo(0)); 804 | Assert.That(torch.count_nonzero(state.exp_avg_sq).ToInt32(), Is.EqualTo(0)); 805 | Assert.That(state.max_exp_avg_sq, Is.Not.Null); 806 | }); 807 | } 808 | } 809 | } 810 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/TestLoadSaveTensors.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System.Diagnostics; 3 | using System.IO.Compression; 4 | using TorchSharp.PyBridge.Tests; 5 | using static Tensorboard.ApiDef.Types; 6 | using static TorchSharp.torch; 7 | using static TorchSharp.torch.nn; 8 | 9 | namespace TorchSharp.PyBridge.Tests { 10 | 11 | public class TestLoadSaveTensors { 12 | [Test] 13 | public void TestPythonTensorsLoad() { 14 | // We already saved a python dictionary using `torch.save` to the file `tensors.pth` 15 | var sd = 
PyTorchUnpickler.UnpickleStateDict("pickled_tensors/tensors.pth"); 16 | // Confirm keys 17 | Assert.Multiple(() => { 18 | Assert.That(sd.ContainsKey("arr"), Is.True); 19 | Assert.That(sd.ContainsKey("arr_2d"), Is.True); 20 | }); 21 | var arr = (torch.Tensor)sd["arr"]!; 22 | var arr2D = (torch.Tensor)sd["arr_2d"]!; 23 | 24 | // arr = torch.tensor([1, 2, 3, 4, 5, 6]) 25 | // arr_2d = arr.clone().reshape(2, 3) 26 | // Confirm type 27 | Assert.Multiple(() => { 28 | Assert.That(arr.dtype, Is.EqualTo(ScalarType.Int64)); 29 | Assert.That(arr2D.dtype, Is.EqualTo(ScalarType.Int64)); 30 | }); 31 | // Confirm shape 32 | Assert.Multiple(() => { 33 | Assert.That(arr.shape, Is.EquivalentTo(new[] { 6L })); 34 | Assert.That(arr2D.shape, Is.EquivalentTo(new[] { 2L, 3L })); 35 | }); 36 | 37 | // Confirm content 38 | Assert.Multiple(() => { 39 | Assert.That(arr.data<long>().ToArray(), Is.EquivalentTo(new long[] { 1, 2, 3, 4, 5, 6 })); 40 | Assert.That(arr2D.data<long>().ToArray(), Is.EquivalentTo(new long[] { 1, 2, 3, 4, 5, 6 })); 41 | }); 42 | } 43 | 44 | [Test] 45 | public void TestPythonTensorsSave() { 46 | 47 | var arr = torch.tensor(new long[] { 1, 2, 3, 4, 5, 6 }); 48 | var arr2D = arr.clone().reshape(2, 3); 49 | var dict = new Dictionary<string, torch.Tensor>() { 50 | { "arr", arr }, 51 | { "arr_2d", arr2D } 52 | }; 53 | 54 | // Create a temporary filename to test 55 | var tempFile = Guid.NewGuid().ToString() + ".pth"; 56 | try { 57 | // Save the tensors 58 | PyTorchPickler.PickleStateDict(tempFile, dict); 59 | 60 | // Run the python 61 | var p = new Process(); 62 | p.StartInfo.FileName = "python"; 63 | p.StartInfo.Arguments = "test_torch_tensor_load.py " + tempFile; 64 | p.StartInfo.UseShellExecute = false; 65 | p.StartInfo.RedirectStandardOutput = true; 66 | 67 | p.Start(); 68 | p.WaitForExit(10); 69 | // Read the output 70 | var output = p.StandardOutput.ReadToEnd().Trim(); 71 | Assert.That(output, Is.EqualTo( 72 | "arr" + Environment.NewLine + 73 | "tensor([1, 2, 3, 4, 5, 6])" + Environment.NewLine + 74 | "arr_2d" + Environment.NewLine + 75 | "tensor([[1, 2, 3]," + Environment.NewLine + " [4, 5, 6]])")); 76 | } finally { 77 | if (File.Exists(tempFile)) 78 | File.Delete(tempFile); 79 | } 80 | } 81 | 82 | 83 | [Test] 84 | public void TestSafetensorsTensorLoad() { 85 | // We already saved the tensors using safetensors' `save_file` to the file `tensors.safetensors` 86 | var sd = Safetensors.LoadStateDict("pickled_tensors/tensors.safetensors"); 87 | // Confirm keys 88 | Assert.Multiple(() => { 89 | Assert.That(sd.ContainsKey("arr"), Is.True); 90 | Assert.That(sd.ContainsKey("arr_2d"), Is.True); 91 | }); 92 | var arr = sd["arr"]!; 93 | var arr2D = sd["arr_2d"]!; 94 | 95 | // arr = torch.tensor([1, 2, 3, 4, 5, 6]) 96 | // arr_2d = arr.clone().reshape(2, 3) 97 | // Confirm type 98 | Assert.Multiple(() => { 99 | Assert.That(arr.dtype, Is.EqualTo(ScalarType.Int64)); 100 | Assert.That(arr2D.dtype, Is.EqualTo(ScalarType.Int64)); 101 | }); 102 | // Confirm shape 103 | Assert.Multiple(() => { 104 | Assert.That(arr.shape, Is.EquivalentTo(new[] { 6L })); 105 | Assert.That(arr2D.shape, Is.EquivalentTo(new[] { 2L, 3L })); 106 | }); 107 | 108 | // Confirm content 109 | Assert.Multiple(() => { 110 | Assert.That(arr.data<long>().ToArray(), Is.EquivalentTo(new long[] { 1, 2, 3, 4, 5, 6 })); 111 | Assert.That(arr2D.data<long>().ToArray(), Is.EquivalentTo(new long[] { 1, 2, 3, 4, 5, 6 })); 112 | }); 113 | } 114 | 115 | [Test] 116 | public void TestSafetensorsModuleSave() { 117 | 118 | var arr = torch.tensor(new long[] { 1, 2, 3, 4, 5, 6 }); 119 | var arr2D =
arr.clone().reshape(2, 3); 120 | var dict = new Dictionary<string, torch.Tensor>() { 121 | { "arr", arr }, 122 | { "arr_2d", arr2D } 123 | }; 124 | 125 | // Create a temporary filename to test 126 | var tempFile = Guid.NewGuid().ToString() + ".safetensors"; 127 | try { 128 | // Save the tensors 129 | Safetensors.SaveStateDict(tempFile, dict); 130 | 131 | // Run the python 132 | var p = new Process(); 133 | p.StartInfo.FileName = "python"; 134 | p.StartInfo.Arguments = "test_safetensors_tensor_load.py " + tempFile; 135 | p.StartInfo.UseShellExecute = false; 136 | p.StartInfo.RedirectStandardOutput = true; 137 | 138 | p.Start(); 139 | p.WaitForExit(10); 140 | // Read the output 141 | var output = p.StandardOutput.ReadToEnd().Trim(); 142 | Assert.That(output, Is.EqualTo( 143 | "arr" + Environment.NewLine + 144 | "tensor([1, 2, 3, 4, 5, 6])" + Environment.NewLine + 145 | "arr_2d" + Environment.NewLine + 146 | "tensor([[1, 2, 3]," + Environment.NewLine + " [4, 5, 6]])")); 147 | } 148 | finally { 149 | if (File.Exists(tempFile)) 150 | File.Delete(tempFile); 151 | } 152 | } 153 | 154 | [Test] 155 | public void TestLoadBatchNorm2D_Bug13() { 156 | var model = BatchNorm2d(5); 157 | // Run a few inputs through, to increment the `num_batches_tracked` field 158 | for (int i = 0; i < 5; i++) { 159 | using var d = torch.NewDisposeScope(); 160 | model.forward(torch.rand(new[] { 5L, 5L, 5L, 5L })); 161 | } 162 | 163 | Assert.That(model.num_batches_tracked!.item<long>(), Is.EqualTo(5)); 164 | 165 | // Save to a file, and try reloading 166 | using var stream = new MemoryStream(); 167 | model.save_py(stream, leaveOpen: true); 168 | stream.Position = 0; 169 | 170 | // Create the new model and load it 171 | var model2 = BatchNorm2d(5); 172 | model2.load_py(stream); 173 | 174 | Assert.That(model2.num_batches_tracked!.item<long>(), Is.EqualTo(5)); 175 | } 176 | } 177 | 178 | 179 | } -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/TorchSharp.PyBridge.Tests.csproj: -------------------------------------------------------------------------------- [Project-file markup garbled: the XML tags were stripped during extraction, leaving only element values. What remains indicates: TargetFramework net6.0; ImplicitUsings and Nullable enable; false/true flags (likely IsPackable/IsTestProject); x64 platform and platform target; and ItemGroup entries copying the Python test scripts and pickled test fixtures to the output directory with CopyToOutputDirectory values of Always/PreserveNewest (one entry Never).] -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_modules/module_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_modules/module_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_modules/module_load.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_modules/module_load.safetensors -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adadelta_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adadelta_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adagrad_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adagrad_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adam_emptystate_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adam_emptystate_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adam_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adam_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adamax_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adamax_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/adamw_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/adamw_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/asgd_load.bin: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/asgd_load.bin -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/asgd_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/asgd_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/create_pyload_opts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear 3 | from torch.optim import * 4 | 5 | def calc_loss(opt, linears): 6 | opt.zero_grad() 7 | out = torch.rand(10) 8 | for lin in linears: out = lin(out) 9 | torch.nn.functional.mse_loss(out, torch.rand(10)).backward() 10 | opt.step() 11 | 12 | def save_sgd(): 13 | lin = Linear(10, 10) 14 | opt = SGD(lin.parameters(), 0.01, 0.1) 15 | calc_loss(opt, [lin]) 16 | torch.save(opt.state_dict(), 'sgd_load.pth') 17 | save_sgd() 18 | 19 | def save_asgd(): 20 | lin = Linear(10, 10) 21 | opt = ASGD(lin.parameters(), 0.01, 1e-3, 0.65, 1e5) 22 | calc_loss(opt, [lin]) 23 | torch.save(opt.state_dict(), 'asgd_load.pth') 24 | save_asgd() 25 | 26 | def save_rmsprop(): 27 | lin1 = Linear(10, 10) 28 | lin2 = Linear(10, 10) 29 | opt = RMSprop(lin1.parameters(), 0.001, momentum=0.1) 30 | opt.add_param_group(dict(params=lin2.parameters(), lr=0.01, momentum=0, centered=True)) 31 | calc_loss(opt, [lin1, lin2]) 32 | torch.save(opt.state_dict(), 'rmsprop_load.pth') 33 | save_rmsprop() 34 | 35 | def save_rprop(): 36 | lin1 = Linear(10, 10) 37 | lin2 = Linear(10, 10) 38 | opt = Rprop(lin1.parameters(), lr=0.001, etas=(0.35, 1.5), step_sizes=(1e-5, 5), maximize=False) 39 | opt.add_param_group(dict(params=lin2.parameters(), etas=(0.45, 1.5), step_sizes=(1e-5, 5), lr=0.01, maximize=True)) 40 | calc_loss(opt, [lin1, lin2]) 41 | torch.save(opt.state_dict(), 'rprop_load.pth') 42 | save_rprop() 43 | 44 | def save_adam(): 45 | lin1 = Linear(10, 10) 46 | lin2 = Linear(10, 10) 47 | opt = Adam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), amsgrad=False) 48 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), lr=0.01, amsgrad=True)) 49 | calc_loss(opt, [lin1, lin2]) 50 | torch.save(opt.state_dict(), 'adam_load.pth') 51 | save_adam() 52 | 53 | def save_adamw(): 54 | lin1 = Linear(10, 10) 55 | lin2 = Linear(10, 10) 56 | opt = AdamW(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), amsgrad=False) 57 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), lr=0.01, amsgrad=True)) 58 | calc_loss(opt, [lin1, lin2]) 59 | torch.save(opt.state_dict(), 'adamw_load.pth') 60 | save_adamw() 61 | 62 | def save_nadam(): 63 | lin1 = Linear(10, 10) 64 | lin2 = Linear(10, 10) 65 | opt = NAdam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), weight_decay=0) 66 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), lr=0.01, weight_decay=0.3)) 67 | calc_loss(opt, [lin1, lin2]) 68 | torch.save(opt.state_dict(), 'nadam_load.pth') 69 | save_nadam() 70 | 71 | def save_radam(): 72 | lin1 = Linear(10, 10) 73 | lin2 = Linear(10, 10) 74 | opt = RAdam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), weight_decay=0) 75 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), 
lr=0.01, weight_decay=0.3)) 76 | calc_loss(opt, [lin1, lin2]) 77 | torch.save(opt.state_dict(), 'radam_load.pth') 78 | save_radam() 79 | 80 | def save_adadelta(): 81 | lin1 = Linear(10, 10) 82 | lin2 = Linear(10, 10) 83 | opt = Adadelta(lin1.parameters(), lr=0.001, rho=0.85, weight_decay=0.3, maximize=False) 84 | opt.add_param_group(dict(params=lin2.parameters(), rho=0.79, lr=0.01, weight_decay=0.3, maximize=True)) 85 | calc_loss(opt, [lin1, lin2]) 86 | torch.save(opt.state_dict(), 'adadelta_load.pth') 87 | save_adadelta() 88 | 89 | def save_adagrad(): 90 | lin = Linear(10, 10) 91 | opt = Adagrad(lin.parameters(), lr=0.001, lr_decay=0.85, weight_decay=0.3) 92 | calc_loss(opt, [lin]) 93 | torch.save(opt.state_dict(), 'adagrad_load.pth') 94 | save_adagrad() 95 | 96 | def save_adamax(): 97 | lin1 = Linear(10, 10) 98 | lin2 = Linear(10, 10) 99 | opt = Adamax(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), weight_decay=0) 100 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), lr=0.01, weight_decay=0.3)) 101 | calc_loss(opt, [lin1, lin2]) 102 | torch.save(opt.state_dict(), 'adamax_load.pth') 103 | save_adamax() 104 | 105 | def save_adam_empty_state(): 106 | lin1 = Linear(10, 10) 107 | lin2 = Linear(10, 10) 108 | lin3 = Linear(10, 10) 109 | opt = Adam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9), amsgrad=False) 110 | opt.add_param_group(dict(params=lin2.parameters(), betas=(0.7, 0.79), lr=0.01, amsgrad=True)) 111 | opt.add_param_group(dict(params=lin3.parameters(), betas=(0.6, 0.69), lr=0.01, amsgrad=True)) 112 | calc_loss(opt, [lin1, lin2]) 113 | torch.save(opt.state_dict(), 'adam_emptystate_load.pth') 114 | save_adam_empty_state() -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/nadam_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/nadam_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/radam_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/radam_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/rmsprop_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/rmsprop_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/rprop_load.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/rprop_load.pth -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/sgd_load.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_optimizers/sgd_load.pth
-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_optimizers/test_sharpsaved_opts.py: --------------------------------------------------------------------------------
import torch
from torch.optim import *
from torch.nn import Linear, Sequential
import sys, os

path = sys.argv[1]

# base sequence for all of these
l1 = Linear(10, 10, bias=True)
l2 = Linear(10, 10, bias=True)
seq = Sequential(l1, l2)

# Go through the list of all the optimizers and make sure they all managed to
# load both the regular and withloss saved files
for (optim, name) in [(Adadelta, "Adadelta"), (Adagrad, "Adagrad"), (Adam, "Adam"), (Adamax, "Adamax"), (AdamW, "AdamW"), (ASGD, "ASGD"), (NAdam, "NAdam"), (RAdam, "RAdam"), (RMSprop, "RMSProp"), (Rprop, "Rprop"), (SGD, "SGD")]:
    opt = optim(seq.parameters(), 0.99) # init with 0.99 to make sure it changed
    assert opt.param_groups[0]['lr'] == 0.99

    # load and make sure the lr changed
    opt.load_state_dict(torch.load(os.path.join(path, name + '_save.bin')))
    assert opt.param_groups[0]['lr'] != 0.99

    # load the after loss model, and confirm that it loads without error
    opt.load_state_dict(torch.load(os.path.join(path, name + '_withloss_save.bin')))
    # TODO: do a better test and check that the parameters were actually loaded

-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_tensors/tensors.pth: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaltielshmid/TorchSharp.PyBridge/059476b6110566352b1b255feb3d5ac79bac981a/TorchSharp.PyBridge.Tests/pickled_tensors/tensors.pth
-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/pickled_tensors/tensors.safetensors: --------------------------------------------------------------------------------
x{"arr":{"dtype":"I64","shape":[6],"data_offsets":[0,48]},"arr_2d":{"dtype":"I64","shape":[2,3],"data_offsets":[48,96]}}
-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/test_safetensors_load.py: --------------------------------------------------------------------------------
from safetensors import safe_open
import torch
from torch.nn import *
from collections import OrderedDict

# start by building our model
model = Sequential(OrderedDict([("lin1", Linear(5, 1, bias=False)), ("lin2", Linear(1, 2, bias=False))]))

import sys
path = sys.argv[1]
tensors = {}
with safe_open(path, framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)
model.load_state_dict(tensors)

with torch.no_grad():
    print(model.forward(torch.tensor([1, 1, 1, 1, 1]).float()))
-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/test_safetensors_tensor_load.py: --------------------------------------------------------------------------------
from safetensors.torch import load_file
import sys

path = sys.argv[1]
tensors = load_file(path)

for (k, v) in tensors.items():
    print(k)
    print(v)
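These scripts form the Python half of the cross-language checks. The C# half is a direct use of the `save_safetensors`/`load_safetensors` module extensions defined in TorchSharp.PyBridge; a minimal sketch, assuming a hypothetical output path and mirroring the two-Linear model used by test_safetensors_load.py:

using TorchSharp;
using TorchSharp.PyBridge;
using static TorchSharp.torch.nn;

// Save a module's weights in the safetensors format, then load them into a fresh copy.
var model = Sequential(("lin1", Linear(5, 1, hasBias: false)), ("lin2", Linear(1, 2, hasBias: false)));
model.save_safetensors("model.safetensors");

var reloaded = Sequential(("lin1", Linear(5, 1, hasBias: false)), ("lin2", Linear(1, 2, hasBias: false)));
reloaded.load_safetensors("model.safetensors"); // strict by default: the keys must match exactly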
-------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/test_torch_load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import * 3 | from collections import OrderedDict 4 | 5 | # start by building our model 6 | model = Sequential(OrderedDict([("lin1", Linear(5, 1, bias=False)), ("lin2", Linear(1, 2, bias=False))])) 7 | 8 | import sys 9 | path = sys.argv[1] 10 | model.load_state_dict(torch.load(path)) 11 | 12 | with torch.no_grad(): 13 | print(model.forward(torch.tensor([1, 1, 1, 1, 1]).float())) -------------------------------------------------------------------------------- /TorchSharp.PyBridge.Tests/test_torch_tensor_load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | path = sys.argv[1] 5 | sd = torch.load(path) 6 | 7 | for (k,v) in sd.items(): 8 | print(k) 9 | print(v) -------------------------------------------------------------------------------- /TorchSharp.PyBridge.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.6.33723.286 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TorchSharp.PyBridge", "TorchSharp.PyBridge\TorchSharp.PyBridge.csproj", "{CA09DEDA-DDEE-483F-BA53-EEE395125A4C}" 7 | EndProject 8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TorchSharp.PyBridge.Tests", "TorchSharp.PyBridge.Tests\TorchSharp.PyBridge.Tests.csproj", "{082A67F7-B936-4B6B-937D-3EDB8A42874D}" 9 | EndProject 10 | Global 11 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 12 | Debug|x64 = Debug|x64 13 | Release|x64 = Release|x64 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {CA09DEDA-DDEE-483F-BA53-EEE395125A4C}.Debug|x64.ActiveCfg = Debug|x64 17 | {CA09DEDA-DDEE-483F-BA53-EEE395125A4C}.Debug|x64.Build.0 = Debug|x64 18 | {CA09DEDA-DDEE-483F-BA53-EEE395125A4C}.Release|x64.ActiveCfg = Release|x64 19 | {CA09DEDA-DDEE-483F-BA53-EEE395125A4C}.Release|x64.Build.0 = Release|x64 20 | {082A67F7-B936-4B6B-937D-3EDB8A42874D}.Debug|x64.ActiveCfg = Debug|x64 21 | {082A67F7-B936-4B6B-937D-3EDB8A42874D}.Debug|x64.Build.0 = Debug|x64 22 | {082A67F7-B936-4B6B-937D-3EDB8A42874D}.Release|x64.ActiveCfg = Release|x64 23 | {082A67F7-B936-4B6B-937D-3EDB8A42874D}.Release|x64.Build.0 = Release|x64 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | GlobalSection(ExtensibilityGlobals) = postSolution 29 | SolutionGuid = {0F7199C3-1A7C-46F8-9083-C53579DC33B3} 30 | EndGlobalSection 31 | EndGlobal 32 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge/Extensions.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using System.Threading.Tasks; 6 | 7 | namespace TorchSharp.PyBridge { 8 | static class Extensions { 9 | public static byte[] ReadBytes(this Stream stream, int count) { 10 | var ret = new byte[count]; 11 | stream.Read(ret, 0, count); 12 | return ret; 13 | } 14 | 15 | public static void RemoveKeys(this IDictionary dict, IList keys) { 16 | foreach (var key in keys) 
dict.Remove(key); 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge/OptimizerUtils.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections; 3 | using System.Collections.Generic; 4 | using System.Linq; 5 | using System.Reflection; 6 | using System.Runtime.InteropServices; 7 | using System.Text; 8 | using System.Threading.Tasks; 9 | using static TorchSharp.torch; 10 | 11 | namespace TorchSharp.PyBridge { 12 | internal static class OptimizerUtils { 13 | internal static void AssignFieldsAndPropsToTargetTable(T obj, IDictionary targetTable) where T : notnull { 14 | // Go through all the fields 15 | foreach (var field in obj.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public)) { 16 | object? value = field.GetValue(obj); 17 | 18 | SetValueInTargetTable(field.Name, value, targetTable); 19 | }// next property 20 | 21 | // Go through all the properties 22 | foreach (var property in obj.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public)) { 23 | object? value = property.GetValue(obj); 24 | 25 | SetValueInTargetTable(property.Name, value, targetTable); 26 | }// next property 27 | } 28 | 29 | internal static void AssignFieldsAndPropsFromReferenceTable(T obj, IDictionary referenceTable) where T : notnull { 30 | // Go through all the fields 31 | foreach (var field in obj.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) { 32 | object? value = GetValueFromReferenceTable(field.Name, field.FieldType, referenceTable); 33 | 34 | // Set the value! 35 | // If it's a tensor - first dispose the old tensor, and then set the new value 36 | if (field.FieldType == typeof(torch.Tensor)) { 37 | var orig = (torch.Tensor?)field.GetValue(obj); 38 | if (orig is not null) 39 | orig?.Dispose(); 40 | } 41 | field.SetValue(obj, Convert.ChangeType(value, Nullable.GetUnderlyingType(field.FieldType) ?? field.FieldType)); 42 | }// next property 43 | 44 | // Go through all the properties 45 | foreach (var property in obj.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) { 46 | object? value = GetValueFromReferenceTable(property.Name, property.PropertyType, referenceTable); 47 | 48 | // Set the value! 49 | // If it's a tensor - first dispose the old tensor, and then set the new value 50 | if (property.PropertyType == typeof(torch.Tensor)) { 51 | var orig = (torch.Tensor?)property.GetValue(obj); 52 | if (orig is not null) 53 | orig?.Dispose(); 54 | } 55 | property.SetValue(obj, Convert.ChangeType(value, Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType)); 56 | }// next property 57 | } 58 | 59 | private static object? GetValueFromReferenceTable(string name, Type type, IDictionary referenceTable) { 60 | // Special handling for lrs/betas/eta/step_sizes/step: 61 | return name switch { 62 | "LearningRate" or "InitialLearningRate" => referenceTable["lr"], 63 | "beta1" or "beta2" => ((object[])referenceTable["betas"]!)[name == "beta1" ? 0 : 1], 64 | "etaminus" or "etaplus" => ((object[])referenceTable["etas"]!)[name == "etaminus" ? 0 : 1], 65 | "min_step" or "max_step" => ((object[])referenceTable["step_sizes"]!)[name == "min_step" ? 0 : 1], 66 | _ when type == typeof(torch.Tensor) => referenceTable[name!], 67 | _ => GetValueFromMaybeTensor(referenceTable[name]!), 68 | }; 69 | } 70 | 71 | private static object? 
GetValueFromMaybeTensor(object obj) { 72 | if (obj is null || obj is not torch.Tensor) 73 | return obj; 74 | // Stored as a tensor, so return it as a float and dispose the tensor 75 | using var tensor = (torch.Tensor)obj; 76 | return tensor.dtype switch { 77 | ScalarType.Byte => tensor.ToByte(), 78 | ScalarType.Int8 => tensor.@short().ToInt16(), 79 | ScalarType.Int16 => tensor.ToInt16(), 80 | ScalarType.Int32 => tensor.ToInt32(), 81 | ScalarType.Int64 => tensor.ToInt64(), 82 | ScalarType.Float16 => tensor.ToHalf(), 83 | ScalarType.Float32 => tensor.ToSingle(), 84 | ScalarType.Float64 => tensor.ToDouble(), 85 | ScalarType.ComplexFloat32 => tensor.ToComplexFloat32(), 86 | ScalarType.ComplexFloat64 => tensor.ToComplexFloat64(), 87 | ScalarType.Bool => tensor.ToBoolean(), 88 | ScalarType.BFloat16 => tensor.@half().ToHalf(), 89 | _ => throw new ArgumentException($"Loaded tensor of type unknown to `TorchSharp.PyBridge`: {tensor.dtype}. Please open an issue in the repository.") 90 | }; 91 | } 92 | 93 | private static void SetValueInTargetTable(string name, object? value, IDictionary targetTable) { 94 | // Special lrs/handling for betas/eta/step_sizes/step: 95 | switch (name) { 96 | case "InitialLearningRate": break; 97 | case "LearningRate": targetTable["lr"] = value; break; 98 | case "beta1": case "beta2": Set2ItemTupleValue(targetTable, "betas", value, name == "beta1" ? 0 : 1); break; 99 | case "etaminus": case "etaplus": Set2ItemTupleValue(targetTable, "etas", value, name == "etaminus" ? 0 : 1); break; 100 | case "min_step": case "max_step": Set2ItemTupleValue(targetTable, "step_sizes", value, name == "min_step" ? 0 : 1); break; 101 | case "step": targetTable[name] = torch.tensor((long)value!); break; 102 | default: targetTable[name] = value; break; 103 | } 104 | } 105 | 106 | private static void Set2ItemTupleValue(IDictionary targetTable, string key, object? value, int idx) { 107 | if (targetTable[key] is null) 108 | targetTable[key] = new object?[2]; 109 | ((object?[])targetTable[key]!)[idx] = value; 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /TorchSharp.PyBridge/PyBridgeModuleExtensions.cs: -------------------------------------------------------------------------------- 1 | using System.Text.Json; 2 | using System.Text.Json.Nodes; 3 | using TqdmSharp; 4 | using static TorchSharp.torch.nn; 5 | 6 | namespace TorchSharp.PyBridge { 7 | public static class PyBridgeModuleExtensions { 8 | 9 | /// 10 | /// Save the parameters and buffers of the module to a python-compatible file to be loaded using `torch.load`. 11 | /// 12 | /// The file path. 13 | /// A list of keys not to consider when saving the weights. 14 | /// 15 | public static Module save_py(this Module module, string location, IList? skip = null) { 16 | using var stream = System.IO.File.Create(location); 17 | module.save_py(stream, skip); 18 | 19 | return module; 20 | } 21 | 22 | /// 23 | /// Save the parameters and buffers of the module to a python-compatible file to be loaded using `torch.load`. 24 | /// 25 | /// A writable stream instance. 26 | /// A list of keys not to consider when saving the weights. 27 | /// true to leave the stream open after saving the file 28 | /// 29 | public static Module save_py(this Module module, System.IO.Stream stream, IList? 
skip = null, bool leaveOpen = false) { 30 | using var d = torch.NewDisposeScope(); // Create a new dispose scope for any tensors we create 31 | 32 | // Construct our state_dict, without the skip parameters 33 | var sd = module.state_dict(); 34 | if (skip is not null) sd.RemoveKeys(skip); 35 | 36 | PyTorchPickler.PickleStateDict(stream, sd, leaveOpen); 37 | 38 | return module; 39 | } 40 | 41 | 42 | /// 43 | /// Save the parameters and buffers of the module to a file using the safetensors format (https://github.com/huggingface/safetensors) 44 | /// 45 | /// The file path. 46 | /// A list of keys not to consider when saving the weights. 47 | /// 48 | public static Module save_safetensors(this Module module, string location, IList? skip = null) { 49 | using var stream = System.IO.File.Create(location); 50 | module.save_safetensors(stream, skip); 51 | 52 | return module; 53 | } 54 | 55 | /// 56 | /// Save the parameters and buffers of the module to a file using the safetensors format (https://github.com/huggingface/safetensors) 57 | /// 58 | /// A writable stream instance. 59 | /// A list of keys not to consider when saving the weights. 60 | /// true to leave the stream open after saving the file 61 | /// 62 | public static Module save_safetensors(this Module module, System.IO.Stream stream, IList? skip = null, bool leaveOpen = false) { 63 | using var d = torch.NewDisposeScope(); // Create a new dispose scope for any tensors we create 64 | 65 | // Construct our state_dict, without the skip parameters 66 | var sd = module.state_dict(); 67 | if (skip is not null) sd.RemoveKeys(skip); 68 | 69 | Safetensors.SaveStateDict(stream, sd, leaveOpen); 70 | 71 | return module; 72 | } 73 | 74 | /// 75 | /// Load the parameters and buffers from a file saved using `torch.save` 76 | /// 77 | /// The file path. 78 | /// 79 | /// If true, will only load a module if it exactly corresponds to the current module's state. 80 | /// If false, will load the parameters and buffers that it finds in the saved file, 81 | /// leaving everything else alone. 82 | /// 83 | /// A list of keys not to consider when loading the dictionary. 84 | /// A dictionary to populate with the list of parameters loaded and whether they were matched/skipped. Useful when loading in non-strict mode. 85 | /// The module, with parameters and buffers loaded. 86 | /// 87 | /// This method only supports loading the newer format used by `torch.save`, using a zip file. 88 | /// The model will be fully loaded and all the validation checks will only run after the state 89 | /// dictionary has been fully loaded. 90 | /// 91 | public static Module load_py(this Module module, string location, bool strict = true, IList? skip = null, Dictionary? loadedParameters = null) { 92 | if (!System.IO.File.Exists(location)) 93 | throw new System.IO.FileNotFoundException(location); 94 | 95 | using var stream = System.IO.File.OpenRead(location); 96 | module.load_py(stream, strict, skip, loadedParameters); 97 | 98 | return module; 99 | } 100 | 101 | /// 102 | /// Load the parameters and buffers from a file saved using `torch.save` 103 | /// 104 | /// A readable stream instance. 105 | /// 106 | /// If true, will only load a module if it exactly corresponds to the current module's state. 107 | /// If false, will load the parameters and buffers that it finds in the saved file, 108 | /// leaving everything else alone. 109 | /// 110 | /// A list of keys not to consider when loading the dictionary. 
111 | /// A dictionary to populate with the list of parameters loaded and whether they were matched/skipped. Useful when loading in non-strict mode. 112 | /// true to leave the stream open after saving the file 113 | /// The module, with parameters and buffers loaded. 114 | /// 115 | /// This method only supports loading the newer format used by `torch.save`, using a zip file. 116 | /// The model will be fully loaded and all the validation checks will only run after the state 117 | /// dictionary has been fully loaded. 118 | /// 119 | public static Module load_py(this Module module, System.IO.Stream stream, bool strict = true, IList? skip = null, Dictionary? loadedParameters = null, bool leaveOpen = false) { 120 | // Create a dispose score so that we don't keep any of the loaded tensors past this function 121 | using var d = torch.NewDisposeScope(); 122 | using var d2 = torch.no_grad(); // To circumvent a bug introduced in 0.102.0 123 | 124 | // Unpickle the state dictionary into memory. 125 | // Keep stream open because tensors will not get deserialized yet. 126 | var unpickled = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen: true, skipTensorRead: true); 127 | 128 | // Convert the hashtable to a dictionary of string->tensor 129 | var unpickledConstructors = new Dictionary(); 130 | 131 | foreach (string key in unpickled.Keys) { 132 | unpickledConstructors.Add(key, (PyTorchUnpickler.TensorConstructorArgs)unpickled[key]!); 133 | } 134 | 135 | var (_, unexpectedKeys) = load_state_dict(module, unpickledConstructors, strict, skip); 136 | 137 | if (!leaveOpen) { 138 | // Close stream now that tensor streams have been read. 139 | stream.Close (); 140 | } 141 | 142 | if (loadedParameters is null) { 143 | return module; 144 | } 145 | 146 | // Fill in the loadedParameters dictionary 147 | foreach (var key in unpickledConstructors.Keys) { 148 | loadedParameters[key] = true; 149 | } 150 | 151 | foreach (var key in unexpectedKeys) { 152 | loadedParameters[key] = false; 153 | } 154 | 155 | return module; 156 | } 157 | 158 | /// 159 | /// Mirrors the implementation of module.load_state_dict but performs tensor reading 160 | /// with less intermediate memory overhead. 161 | /// 162 | static (IList missing_keys, IList unexpected_keys) load_state_dict( 163 | Module module, 164 | Dictionary unpickled, 165 | bool strict = true, 166 | IList? skip = null 167 | ) { 168 | var missingKeys = new List(); 169 | var unexpectedKeys = new List(); 170 | skip ??= Array.Empty(); 171 | 172 | var state = module.state_dict(); 173 | 174 | foreach (string key in unpickled.Keys) { 175 | if (!skip.Contains(key) && !state.ContainsKey(key)) 176 | unexpectedKeys.Add(key); 177 | } 178 | 179 | foreach (string key in state.Keys) { 180 | if (!skip.Contains(key) && !unpickled.ContainsKey(key)) { 181 | missingKeys.Add(key); 182 | } 183 | } 184 | 185 | if (strict && (missingKeys.Count > 0 || unexpectedKeys.Count > 0)) { 186 | throw new InvalidOperationException("The loaded state_dict is not identical to the target dictionary."); 187 | } 188 | 189 | var inputStreams = unpickled 190 | .Where(e => state.ContainsKey(e.Key)) 191 | // Avoid random stream seeks by reading archive files in the order that they are stored. 
192 | .OrderBy(e => e.Value.ArchiveIndex) 193 | .ToArray(); 194 | 195 | foreach (var (key, constructor) in inputStreams) { 196 | var target = state[key]; 197 | target.with_requires_grad(constructor.RequiresGrad); 198 | 199 | if (constructor.DType == state[key].dtype) { 200 | using var stream = constructor.Data; 201 | // Read directly into target tensor. 202 | target 203 | .as_strided(constructor.Shape, constructor.Stride, constructor.StorageOffset) 204 | .ReadBytesFromStream(stream); 205 | } 206 | else { 207 | // Type conversion with intermediate tensor required. 208 | // This will load onto cpu first before copying to target. 209 | using torch.Tensor temp = constructor.ReadTensorFromStream(); 210 | state[key].copy_(temp); 211 | } 212 | } 213 | 214 | return (missingKeys, unexpectedKeys); 215 | } 216 | 217 | /// 218 | /// Load the parameters and buffers from a file saved using the safetensors format (https://github.com/huggingface/safetensors) 219 | /// 220 | /// The file path. 221 | /// 222 | /// If true, will only load a module if it exactly corresponds to the current module's state. 223 | /// If false, will load the parameters and buffers that it finds in the saved file, 224 | /// leaving everything else alone. 225 | /// 226 | /// A list of keys not to consider when loading the dictionary. 227 | /// A dictionary to populate with the list of parameters loaded and whether they were matched/skipped. Useful when loading in non-strict mode. 228 | /// The module, with parameters and buffers loaded. 229 | public static Module load_safetensors(this Module module, string location, bool strict = true, IList? skip = null, Dictionary? loadedParameters = null) { 230 | if (!System.IO.File.Exists(location)) 231 | throw new System.IO.FileNotFoundException(location); 232 | 233 | using var stream = System.IO.File.OpenRead(location); 234 | module.load_safetensors(stream, strict, skip, loadedParameters); 235 | 236 | return module; 237 | } 238 | 239 | /// 240 | /// Load the parameters and buffers from a file saved using the safetensors format (https://github.com/huggingface/safetensors) 241 | /// 242 | /// A readable stream instance. 243 | /// 244 | /// If true, will only load a module if it exactly corresponds to the current module's state. 245 | /// If false, will load the parameters and buffers that it finds in the saved file, 246 | /// leaving everything else alone. 247 | /// 248 | /// A list of keys not to consider when loading the dictionary. 249 | /// A dictionary to populate with the list of parameters loaded and whether they were matched/skipped. Useful when loading in non-strict mode. 250 | /// true to leave the stream open after saving the file 251 | /// The module, with parameters and buffers loaded. 252 | public static Module load_safetensors(this Module module, System.IO.Stream stream, bool strict = true, IList? skip = null, Dictionary? loadedParameters = null, bool leaveOpen = false) { 253 | // Create a dispose score so that we don't keep anyof the loaded tensors past this function 254 | using var d = torch.NewDisposeScope(); 255 | using var d2 = torch.no_grad(); // To circumvent a bug introduced in 0.102.0 256 | 257 | // Retrieve the current state dict of the module, so that we can make sure to only load the relevant 258 | // tensors from the file. 
259 | var curStateDict = module.state_dict(); 260 | if (skip is not null) curStateDict.RemoveKeys(skip); 261 | // Unlike the pickler format, here we can load in the whole index quickly to check for mismatches 262 | var index = Safetensors.LoadIndex(stream); 263 | index.Remove("__metadata__"); 264 | if (skip is not null) index.RemoveKeys(skip); 265 | 266 | if (strict) { 267 | // Make sure the keys match exactly 268 | if (index.Count != curStateDict.Count || !index.Keys.All(curStateDict.ContainsKey)) 269 | throw new InvalidOperationException("The specified state dict is not identical to the target dictionary."); 270 | } 271 | 272 | // Load in the state dict, only the relevant keys (make sure to reset the position) 273 | stream.Position = 0; 274 | var loadedStateDict = Safetensors.LoadStateDict(stream, leaveOpen, curStateDict.Keys.ToList()); 275 | 276 | // Load it in using the builtin function 277 | var (_, unexpectedKeys) = module.load_state_dict(loadedStateDict, strict, skip); 278 | // Add to the unexpected all the keys in the index that weren't in the state dict 279 | unexpectedKeys = unexpectedKeys.Concat(index.Keys.Except(curStateDict.Keys)).ToList(); 280 | 281 | // Fill in the loadedParameters dictionary, if relevant 282 | if (loadedParameters is not null) { 283 | foreach (string key in loadedStateDict.Keys) 284 | loadedParameters[key] = true; 285 | foreach (string key in unexpectedKeys) 286 | loadedParameters[key] = false; 287 | } 288 | 289 | return module; 290 | } 291 | 292 | /// 293 | /// Load the parameters and buffers from a directory containing a potentially sharded checkpoint saved using the regular pytorch format or the safetensors format (https://github.com/huggingface/safetensors). 294 | /// The filenames are expected to be the way HuggingFace's `save_pretrained` saves them, which are "pytorch_model.bin.index.json" and "model.safetensors.index.json". 295 | /// Alternatively, one can specify the exact name of the main checkpoint. 296 | /// 297 | /// A path to a directory containing the checkpoint. 298 | /// 299 | /// Optional; The function defaults to look for a model.safetensors, then pytorch_model.bin, with checking for a sharded index equivalent of them. 300 | /// If specified then will use that file instead. If the checkpoint you are specifying is sharded, make sure to point to the index.json file 301 | /// Note that this parameter should be a filename without a path. 302 | /// 303 | /// If true, will only load a module if it exactly corresponds to the current module's state. 304 | /// If false, will load the parameters and buffers that it finds in the saved file, 305 | /// leaving everything else alone. 306 | /// 307 | /// A list of keys not to consider when loading the dictionary. 308 | /// A dictionary to populate with the list of parameters loaded and whether they were matched/skipped. Useful when loading in non-strict mode. 309 | /// Display the tqdm progress bar when loading in sharded checkpoint. 310 | /// The module, with parameters and buffers loaded. 311 | public static Module load_checkpoint(this Module module, string path, string? checkpointName = null, bool strict = true, IList? skip = null, Dictionary? loadedParameters = null, bool useTqdm = true) { 312 | if (!Directory.Exists(path)) 313 | throw new DirectoryNotFoundException(); 314 | 315 | // Figure out the name of the checkpoint. 
If unspecified, try the hierarchy used by huggingface 316 | if (checkpointName is null) { 317 | foreach (var potential in new[] { "model.safetensors", "pytorch_model.bin" }) { 318 | foreach (var suffix in new[] { "", ".index.json" }) { 319 | string name = Path.Combine(path, potential + suffix); 320 | if (File.Exists(name)) { 321 | checkpointName = potential + suffix; 322 | break; 323 | } 324 | }// next potential suffix 325 | 326 | if (checkpointName is not null) 327 | break; 328 | }// next potential checkpoint 329 | 330 | if (checkpointName is null) 331 | throw new ArgumentException("Couldn't find checkpoint in given directory. Make sure it is named correctly or specify the name of the checkpoint explicitly."); 332 | } 333 | else { 334 | // Make sure that checkpoint name isn't a full path, but just the name of the file 335 | if (checkpointName.Contains('/') || checkpointName.Contains('\\')) 336 | throw new ArgumentException("The checkpoint name should be just the name of a file, not a path", nameof(checkpointName)); 337 | } 338 | 339 | string mainFilename = Path.Combine(path, checkpointName!); 340 | // If the file ends with .safetensors - load it in using that method 341 | if (mainFilename.EndsWith(".safetensors")) 342 | return module.load_safetensors(mainFilename, strict, skip, loadedParameters); 343 | // If the file doesn't end with .json - try loading it in using the regular pytorch method 344 | if (!mainFilename.EndsWith(".json")) 345 | return module.load_py(mainFilename, strict, skip, loadedParameters); 346 | 347 | // We have an index json for a sharded file. 348 | string indexJson = File.ReadAllText(mainFilename); 349 | var fullIndex = JsonSerializer.Deserialize>(indexJson) ?? throw new NotImplementedException("Invalid JSON encountered when loading in sharded index"); 350 | 351 | // Extract just the weight map 352 | if (!fullIndex.ContainsKey("weight_map")) 353 | throw new NotImplementedException("Invalid JSON encountered when loading in sharded index"); 354 | 355 | var weightMap = fullIndex["weight_map"].Deserialize>() ?? throw new NotImplementedException("Invalid JSON encountered when loading in sharded index"); 356 | if (skip is not null) weightMap.RemoveKeys(skip); 357 | 358 | // Retrieve the current state dict of the module, so that we can make sure to only load the relevant 359 | // tensors from the file and to check for strictness 360 | var curStateDict = module.state_dict(); 361 | if (skip is not null) curStateDict.RemoveKeys(skip); 362 | 363 | // If we requested strict - confirm the state dicts match exactly 364 | if (strict) { 365 | // Make sure the keys match exactly 366 | if (weightMap.Count != curStateDict.Count || !weightMap.Keys.All(curStateDict.ContainsKey)) 367 | throw new InvalidOperationException("The specified state dict is not identical to the target dictionary."); 368 | } 369 | // Otherwise, add to the skip list all the parameters that aren't in the state dict. Also remove them from the 370 | // weight map, so that we won't even look at a file where we don't want to load any of the tensors. 
            else {
                skip ??= new List<string>();
                foreach (var badKey in weightMap.Keys.Except(curStateDict.Keys)) {
                    skip.Add(badKey);
                    weightMap.Remove(badKey);
                    loadedParameters?.TryAdd(badKey, false);
                }// next bad key
            }

            if (weightMap.Count == 0)
                return module;

            // Load in each of the files, with an optional progress bar
            var weightMapFiles = weightMap.Values.ToHashSet();
            var iterWeightMapFiles = useTqdm ? Tqdm.Wrap(weightMapFiles) : weightMapFiles;
            foreach (var key in iterWeightMapFiles) {
                string fullPath = Path.Combine(path, key);
                if (fullPath.EndsWith(".safetensors"))
                    module.load_safetensors(fullPath, false, skip: skip, loadedParameters: loadedParameters);
                else
                    module.load_py(fullPath, false, skip: skip, loadedParameters: loadedParameters);
            }

            return module;
        }
    }
}
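Taken together, `load_checkpoint` gives the calling side a one-line way to consume a HuggingFace-style checkpoint directory. A minimal sketch of that calling side, assuming a `model` Module built elsewhere and a placeholder directory path:

using TorchSharp.PyBridge;

// load_checkpoint resolves model.safetensors / pytorch_model.bin (or their
// .index.json sharded equivalents) inside the directory, then loads each shard.
var loadedParams = new Dictionary<string, bool>();
model.load_checkpoint("path/to/checkpoint_dir", strict: false, loadedParameters: loadedParams);

// In non-strict mode, loadedParams records which parameters were matched.
foreach (var kv in loadedParams)
    if (!kv.Value)
        Console.WriteLine($"not loaded: {kv.Key}");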
-------------------------------------------------------------------------------- /TorchSharp.PyBridge/PyBridgeOptimizerExtensions.cs: --------------------------------------------------------------------------------
using Google.Protobuf.WellKnownTypes;
using SkiaSharp;
using System.Collections;
using System.Reflection;
using TorchSharp.Modules;
using static Tensorboard.Summary.Types;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace TorchSharp.PyBridge {
    public static class PyBridgeOptimizerExtensions {

        /// <summary>
        /// Saves the optimizer state to a python-compatible file to be loaded using `torch.load`.
        /// </summary>
        /// <param name="location">The file path.</param>
        public static void save_py(this OptimizerHelper optim, string location) {
            using var stream = System.IO.File.Create(location);
            optim.save_py(stream);
        }

        /// <summary>
        /// Saves the optimizer state to a python-compatible file to be loaded using `torch.load`.
        /// </summary>
        /// <param name="stream">A writable stream instance.</param>
        /// <param name="leaveOpen">true to leave the stream open after saving the file</param>
        public static void save_py(this OptimizerHelper optim, System.IO.Stream stream, bool leaveOpen = false) {
            using var d = torch.NewDisposeScope(); // Create a new dispose scope for any tensors we create

            // Get our state_dict from our optimizer
            var sd = optim.state_dict();

            // sd.Options -> ArrayList with all the properties
            var optionsList = new ArrayList();
            for (int iOption = 0; iOption < sd.Options.Count; iOption++) {
                var tgtOption = new Dictionary<string, object?>();
                OptimizerUtils.AssignFieldsAndPropsToTargetTable(sd.Options[iOption], tgtOption);
                // Add the params variable, which is created separately
                tgtOption["params"] = sd.StateIndexRef[iOption];

                optionsList.Add(tgtOption);
            }

            // sd.State -> IDictionary with the key being the index
            var stateTable = new Dictionary<int, object>();
            for (int iState = 0; iState < sd.State.Count; iState++) {
                var tgtState = new Dictionary<string, object?>();
                OptimizerUtils.AssignFieldsAndPropsToTargetTable(sd.State[iState], tgtState);

                stateTable[iState] = tgtState;
            }

            // Add it to the pickle format
            var pickleSd = new Dictionary<string, object> {
                ["param_groups"] = optionsList,
                ["state"] = stateTable
            };

            PyTorchPickler.PickleStateDict(stream, pickleSd, leaveOpen);
        }


        /// <summary>
        /// Load the optimizer state from a file saved using `torch.save`
        /// </summary>
        /// <param name="location">The file path.</param>
        /// <remarks>
        /// This method only supports loading the newer format used by `torch.save`, using a zip file.
        /// The order in which the parameters were added to the optimizer must be identical when saving and loading.
        /// </remarks>
        public static void load_py(this OptimizerHelper optim, string location) {
            if (!System.IO.File.Exists(location))
                throw new System.IO.FileNotFoundException(location);

            using var stream = System.IO.File.OpenRead(location);
            optim.load_py(stream);
        }

        /// <summary>
        /// Load the optimizer state from a file saved using `torch.save`
        /// </summary>
        /// <param name="stream">A readable stream instance.</param>
        /// <param name="leaveOpen">true to leave the stream open after loading the file</param>
        /// <remarks>
        /// This method only supports loading the newer format used by `torch.save`, using a zip file.
        /// The order in which the parameters were added to the optimizer must be identical when saving and loading.
        /// </remarks>
        public static void load_py(this OptimizerHelper optim, System.IO.Stream stream, bool leaveOpen = false) {
            // Unpickle the state dictionary into memory
            var loadedStateDict = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen);

            // We will get the state dict from the optimizer, and then set the properties using reflection
            var optimStateDict = optim.state_dict();

            // The stateDict should have two keys:
            // 1] "param_groups" => equivalent to Options (should have the name of options)
            var loadedParamGroups = (ArrayList)loadedStateDict["param_groups"]!;
            if (loadedParamGroups.Count != optimStateDict.Options.Count)
                throw new ArgumentException("Identified a mismatch between the number of parameter groups in the loaded state dict and the current one. Are you sure you added the same number of parameter groups?");

            // Store a mapping between state index in the TorchSharp model to state index in the PyTorch model
            var stateIndexToKeyAndOptions = new Dictionary<int, (int, OptimizerOptions)>();

            // Assign all the fields in the param groups
            for (int iOption = 0; iOption < optimStateDict.Options.Count; iOption++) {
                var reference = (IDictionary)loadedParamGroups[iOption]!;
                OptimizerUtils.AssignFieldsAndPropsFromReferenceTable(optimStateDict.Options[iOption], reference);

                // Map the indices stored in StateIndexRef to the indices stored in "params"
                // `referenceIdxs` is the list of keys in the loaded state object.
                var referenceIdxs = (ArrayList)reference["params"]!;
                // `targetIdxs` is the list of state indexes mapping to this group's options
                var targetIdxs = optimStateDict.StateIndexRef[iOption];

                // Both `referenceIdxs` and `targetIdxs` should have the same number of parameters (confirm this)
                if (targetIdxs.Count != referenceIdxs.Count)
                    throw new ArgumentException("Identified a mismatch between the parameter groups in the loaded state dict and the current one. Are you sure you added all the parameter groups the same way?");

                for (int iState = 0; iState < targetIdxs.Count; iState++)
                    stateIndexToKeyAndOptions[targetIdxs[iState]] = ((int)referenceIdxs[iState]!, optimStateDict.Options[iOption]);
            }

            // 2] "state" => equivalent to State
            var loadedState = (IDictionary)loadedStateDict["state"]!;
            // Assign all the fields in the state (note: we don't have to have all the values in the state)
            for (int iState = 0; iState < optimStateDict.State.Count; iState++) {
                // Get the key and the options
                var (stateKey, stateParamOptions) = stateIndexToKeyAndOptions[iState];
                // Retrieve the reference value from the loaded state
                var reference = (IDictionary?)loadedState[stateKey];

                // If it doesn't exist - that means that the state was never initialized. Reinitialize it with the
                // new parameter groups.
                if (reference is null || reference.Count == 0) {
                    optimStateDict.State[iState].Initialize(stateParamOptions);
                    continue;
                }

                // Assign all the fields from the reference into our state.
                OptimizerUtils.AssignFieldsAndPropsFromReferenceTable(optimStateDict.State[iState], reference);
            }
        }

    }
}
-------------------------------------------------------------------------------- /TorchSharp.PyBridge/PyTorchPickler.cs: --------------------------------------------------------------------------------
using System.Collections;
using System.IO.Compression;
using System.Text;
using Razorvine.Pickle;
using TorchSharp.Modules;

namespace TorchSharp.PyBridge {
    public static class PyTorchPickler {
        static PyTorchPickler() {
            Pickler.registerCustomPickler(typeof(Storage), new StoragePickler());
            Pickler.registerCustomDeconstructor(typeof(EmptyOrderedDict), new EmptyOrderedDictDeconstructor());
            Pickler.registerCustomDeconstructor(typeof(torch.Tensor), new TensorDeconstructor());
            Pickler.registerCustomDeconstructor(typeof(Parameter), new TensorDeconstructor());
        }

        static readonly byte MinProducedFileFormatVersion = 0x3;

        /// <summary>
        /// Pickle the state_dict to a python compatible file to be loaded using `torch.load`
        /// </summary>
        /// <param name="file">Path to the file</param>
        /// <param name="source">The state_dict to pickle</param>
        public static void PickleStateDict(string file, IDictionary source) {
            PickleStateDict(File.OpenWrite(file), source);
        }

        /// <summary>
        /// Pickle the state_dict to a python compatible file to be loaded using `torch.load`
        /// </summary>
        /// <param name="stream">Stream of the file to write</param>
        /// <param name="source">The state_dict to pickle</param>
        /// <param name="leaveOpen">true to leave the stream open after saving the file</param>
        public static void PickleStateDict(Stream stream, IDictionary source, bool leaveOpen = false) {
            // Create a new archive
            using var archive = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen);
            // Start with writing out the pytorch version, #3
            using (var versionStream = new StreamWriter(archive.CreateEntry("model/version").Open()))
                versionStream.WriteLine(MinProducedFileFormatVersion);

            // Create our pickler with the archive, so it can store all the relevant files
            // using the persistentId
            var pickler = new CustomPickler(archive);

            // Create and dump our main data.pkl file
            using var ms = new MemoryStream();
            pickler.dump(source, ms);

            // Copy it into the entry
            var dataPkl = archive.CreateEntry("model/data.pkl");
            using var
dataStream = dataPkl.Open(); 50 | ms.Seek(0, SeekOrigin.Begin); 51 | ms.CopyTo(dataStream); 52 | } 53 | 54 | /// 55 | /// This class implements custom behavior for pickling, specifically regarding persistent storage. 56 | /// In the PyTorch library, instead of serializing the tensors using Pickle, they replace the tensor 57 | /// objects in the pickle file with the metadata of the tensor, plus a link to an external file in the 58 | /// archive which contains all the byte data of the tensor. 59 | /// Therefore, our custom pickler does the same thing - whenever we pickle a tensor object, we break 60 | /// it down and store the byte data in a file in the archive and return the persistent id. 61 | /// 62 | class CustomPickler : Pickler { 63 | readonly ZipArchive _archive; 64 | int _tensorCount; 65 | 66 | public CustomPickler(ZipArchive archive) { 67 | _archive = archive; 68 | _tensorCount = 0; 69 | } 70 | 71 | protected override bool persistentId(object pid, out object? newpid) { 72 | if (pid is not TensorWrapper) { 73 | newpid = null; 74 | return false; 75 | } 76 | 77 | var tensor = ((TensorWrapper)pid).Tensor; 78 | 79 | bool copied = false; 80 | if (tensor.device_type != DeviceType.CPU) { 81 | tensor = tensor.to(torch.CPU); 82 | copied = true; 83 | } 84 | 85 | // The persistentId function in pickler is a way of serializing an object using a different 86 | // stream and then pickling just a key representing it. 87 | // The data itself you store in another source. The `torch.load` function uses this functionality 88 | // and lists for the pid a tuple with the following items: 89 | // Tuple Item1: "storage" 90 | // Tuple Item2: storage_type (e.g., torch.LongTensor) 91 | // Tuple Item3: key (link to file with the byte data in the archive) 92 | // Tuple Item4: location (cpu/gpu) 93 | // Tuple Item5: numElements (number of elements in the tensor) 94 | 95 | // Start by serializing the object to a file in the archive 96 | var entry = _archive.CreateEntry($"model/data/{_tensorCount}"); 97 | using (var stream = entry.Open()) 98 | tensor.WriteBytesToStream(stream); 99 | 100 | // Collect the items for our persistentId, as above. 
101 | newpid = new object[] { 102 | "storage", 103 | new Storage(GetStorageNameFromScalarType(tensor.dtype)), // storage_type 104 | _tensorCount.ToString(), // key 105 | "cpu", // location 106 | tensor.NumberOfElements // numel 107 | }; 108 | 109 | // Post-cleanup items 110 | _tensorCount++; 111 | if (copied) tensor.Dispose(); 112 | 113 | return true; 114 | } 115 | 116 | static string GetStorageNameFromScalarType(torch.ScalarType storage) { 117 | return storage switch { 118 | torch.ScalarType.Float64 => "DoubleStorage", 119 | torch.ScalarType.Float32 => "FloatStorage", 120 | torch.ScalarType.Float16 => "HalfStorage", 121 | torch.ScalarType.Int64 => "LongStorage", 122 | torch.ScalarType.Int32 => "IntStorage", 123 | torch.ScalarType.Int16 => "ShortStorage", 124 | torch.ScalarType.Int8 => "CharStorage", 125 | torch.ScalarType.Byte => "ByteStorage", 126 | torch.ScalarType.Bool => "BoolStorage", 127 | torch.ScalarType.BFloat16 => "BFloat16Storage", 128 | torch.ScalarType.ComplexFloat64 => "ComplexDoubleStorage", 129 | torch.ScalarType.ComplexFloat32 => "ComplexFloatStorage", 130 | _ => throw new NotImplementedException() 131 | }; 132 | } 133 | } 134 | 135 | #region CustomPickleClasses 136 | // This region contains private custom classes which are only used for the pickling process 137 | // These classes are used so that we can register custom picklers/class deconstructors to them 138 | // specifically. 139 | 140 | /// 141 | /// A class wrapper for a StorageType (e.g., FloatStorage, LongStorage) so that the pickler can 142 | /// assign the the storage object to our custom pickler, and we can serialize it as an python class 143 | /// without any arguments. 144 | /// 145 | class Storage { 146 | public string Type { get; set; } 147 | 148 | public Storage(string type) { 149 | Type = type; 150 | } 151 | } 152 | 153 | /// 154 | /// A class representing an empty OrderedDict, which is used in the PyTorch serializing, for the 155 | /// backward_hooks reconstructions. They recommend against serializing them, and we don't support 156 | /// them. 157 | /// 158 | class EmptyOrderedDict { 159 | public static EmptyOrderedDict Instance => new(); 160 | } 161 | 162 | /// 163 | /// A wrapper class which just contains the tensor, in order for the pickler to be able to assign 164 | /// the class to the TensorWrapper persistentId. We need a wrapper since we need to differentiate 165 | /// between the deconstructor handler and the persistentId handler. 166 | /// 167 | class TensorWrapper { 168 | public torch.Tensor Tensor { get; set; } 169 | public TensorWrapper(torch.Tensor tensor) { 170 | Tensor = tensor; 171 | } 172 | } 173 | 174 | #endregion 175 | 176 | #region CustomPicklersAndDeconstructors 177 | 178 | // This region contains the custom picklers and class deconstructors. 179 | // The way the pickle module serializes classes is by defining the name of the module, followed 180 | // by the argument list for reconstructing the module. Since we use different classes in C#, 181 | // we defined these deconstructor to recreate the structure of the Python classes/functions being used. 182 | 183 | /// 184 | /// A custom class pickler for pickling the Storage object the way PyTorch expects to recieve, as a 185 | /// class Module with no arguments for construction. 
186 |         /// The reason we use a custom pickler instead of a class deconstructor is that the name of the
187 |         /// module depends on the object being deconstructed, and a custom pickler is more
188 |         /// efficient than defining a class deconstructor for each storage type.
189 |         /// </summary>
190 |         class StoragePickler : IObjectPickler {
191 |             public void pickle(object o, Stream outs, Pickler currentPickler) {
192 |                 outs.WriteByte(Opcodes.GLOBAL);
193 |                 var nameBytes = Encoding.ASCII.GetBytes($"torch\n{((Storage)o).Type}\n");
194 |                 outs.Write(nameBytes, 0, nameBytes.Length);
195 |             }
196 |         }
197 | 
198 |         /// <summary>
199 |         /// A class deconstructor which, given an EmptyOrderedDict, deconstructs the object into the individual
200 |         /// members needed for reconstruction. The module in Python is `collections.OrderedDict`,
201 |         /// and the arguments for the constructor are what the `deconstruct` function returns.
202 |         /// </summary>
203 |         class EmptyOrderedDictDeconstructor : IObjectDeconstructor {
204 |             public string get_module() {
205 |                 return "collections";
206 |             }
207 | 
208 |             public string get_name() {
209 |                 return "OrderedDict";
210 |             }
211 | 
212 |             public object[] deconstruct(object obj) {
213 |                 // Empty dictionary, so the argument will be an empty array.
214 |                 return new[] { Array.Empty<object>() };
215 |             }
216 |         }
217 | 
218 | 
219 |         /// <summary>
220 |         /// A class deconstructor which, given a Tensor, deconstructs the Tensor into the individual
221 |         /// members needed for reconstruction. The PyTorch reconstructor is the function `torch._utils._rebuild_tensor_v2`,
222 |         /// and the arguments for that function are what the `deconstruct` function returns.
223 |         /// </summary>
224 |         class TensorDeconstructor : IObjectDeconstructor {
225 |             public string get_module() {
226 |                 return "torch._utils";
227 |             }
228 | 
229 |             public string get_name() {
230 |                 return "_rebuild_tensor_v2";
231 |             }
232 | 
233 |             public object[] deconstruct(object obj) {
234 |                 var tensor = (torch.Tensor)obj;
235 |                 // Arg 0: Tensor
236 |                 // Arg 1: storage_offset
237 |                 // Arg 2: tensor_size (the shape; important)
238 |                 // Arg 3: stride (we aren't reconstructing from stride, just inserting the bytes)
239 |                 // Arg 4: requires_grad
240 |                 // Arg 5: backward_hooks; we don't support adding them in, and PyTorch recommends against serializing them.
241 |                 return new object[] {
242 |                     new TensorWrapper(tensor),
243 |                     tensor.storage_offset(),
244 |                     tensor.shape.Select(i => (object)i).ToArray(), // cast to object so it's stored as a tuple, not an array
245 |                     tensor.stride().Select(i => (object)i).ToArray(), // cast to object so it's stored as a tuple, not an array
246 |                     tensor.requires_grad,
247 |                     EmptyOrderedDict.Instance
248 |                 };
249 |             }
250 |         }
251 | 
252 |         #endregion
253 |     }
254 | }
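A minimal usage sketch for the entry points above (not part of the file itself; the checkpoint path is hypothetical). The returned Hashtable maps state_dict keys to torch.Tensor values:

    using System.Collections;
    using TorchSharp;
    using TorchSharp.PyBridge;

    // Eagerly load every tensor from a checkpoint written by torch.save(model.state_dict(), ...).
    Hashtable stateDict = PyTorchUnpickler.UnpickleStateDict("model.pth"); // hypothetical path
    foreach (DictionaryEntry kv in stateDict) {
        var tensor = (torch.Tensor)kv.Value!;
        Console.WriteLine($"{kv.Key}: [{string.Join(", ", tensor.shape)}]");
    }
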
57 |         /// <summary>
58 |         /// This class implements custom behavior for unpickling, specifically regarding persistent storage.
59 |         /// In the PyTorch library, instead of serializing the tensors using pickle, they replace the tensor
60 |         /// objects in the pickle file with the metadata of the tensor, plus a link to an external file in the
61 |         /// archive which contains all the byte data of the tensor.
62 |         /// Therefore, our custom unpickler defines the behavior for restoring the data from the archive
63 |         /// when the unpickler encounters a persistentId.
64 |         /// </summary>
65 |         class CustomUnpickler : Unpickler {
66 |             readonly ZipArchive _archive;
67 | 
68 |             readonly bool _skipTensorRead;
69 | 
70 |             public CustomUnpickler(ZipArchive archive, bool skipTensorRead) {
71 |                 _archive = archive;
72 |                 _skipTensorRead = skipTensorRead;
73 |             }
74 | 
75 |             protected override object persistentLoad(object pid) {
76 |                 // The persistentLoad function gives the pickler a way of writing just a key, and then loading
77 |                 // the data yourself from another source. The `torch.save` function uses this functionality
78 |                 // and lists for the pid a tuple with the following items:
79 |                 var opid = (object[])pid;
80 | 
81 |                 // Tuple Item0: "storage"
82 |                 if ((string)opid[0] != "storage")
83 |                     throw new NotImplementedException("Unknown persistent id loaded");
84 | 
85 |                 // Tuple Item1: storage_type (e.g., torch.LongStorage), which is broken into module=torch, name=LongStorage
86 |                 string storageType = ((ClassDictConstructor)opid[1]).name;
87 |                 // Tuple Item2: key (filename in the archive)
88 |                 string archiveKey = (string)opid[2];
89 |                 // Tuple Item3: location (cpu/gpu), but we always load onto CPU.
90 |                 // Tuple Item4: numElems (the number of elements in the tensor)
91 | 
92 |                 // Convert the storage name into the relevant scalar type (e.g., LongStorage => torch.long)
93 |                 // and then check how many bytes each element is
94 |                 var dtype = GetScalarTypeFromStorageName(storageType);
95 | 
96 |                 // Retrieve the entry from the archive
97 |                 var entry = _archive.Entries
98 |                     .Select((archiveEntry, index) => (archiveEntry, index))
99 |                     .First(e => e.archiveEntry.FullName.EndsWith($"data/{archiveKey}"));
100 | 
101 |                 // Send this back, so our TensorObjectConstructor can create our torch.tensor from the object.
102 |                 return new TensorStream {
103 |                     ArchiveIndex = entry!.index,
104 |                     ArchiveEntry = entry!.archiveEntry,
105 |                     DType = dtype,
106 |                     SkipTensorRead = _skipTensorRead,
107 |                 };
108 |             }
109 | 
110 |             static torch.ScalarType GetScalarTypeFromStorageName(string storage) {
111 |                 return storage switch {
112 |                     "DoubleStorage" => torch.float64,
113 |                     "FloatStorage" => torch.@float,
114 |                     "HalfStorage" => torch.half,
115 |                     "LongStorage" => torch.@long,
116 |                     "IntStorage" => torch.@int,
117 |                     "ShortStorage" => torch.int16,
118 |                     "CharStorage" => torch.int8,
119 |                     "ByteStorage" => torch.uint8,
120 |                     "BoolStorage" => torch.@bool,
121 |                     "BFloat16Storage" => torch.bfloat16,
122 |                     "ComplexDoubleStorage" => torch.cdouble,
123 |                     "ComplexFloatStorage" => torch.cfloat,
124 |                     _ => throw new NotImplementedException()
125 |                 };
126 |             }
127 |         }
128 | 
129 |         /// <summary>
130 |         /// The unpickler implementation requires a __setstate__ function for unpickling an ordered dict, due
131 |         /// to the way it was saved. This class is just a regular Hashtable with an implementation for
132 |         /// __setstate__.
133 |         /// </summary>
134 |         class OrderedDict : Hashtable {
135 |             public void __setstate__(Hashtable arg) {
136 |                 foreach (string key in arg.Keys) {
137 |                     if (arg[key] is torch.Tensor)
138 |                         this[key] = arg[key];
139 |                 }
140 |             }
141 |         }
142 | 
143 |         /// <summary>
144 |         /// The PyTorch library stores the parameters in an OrderedDict, and we don't have that class in C#,
145 |         /// so instead we treat it as a Hashtable.
146 |         /// </summary>
147 |         class OrderedDictObjectConstructor : IObjectConstructor {
148 |             public object construct(object[] args) {
149 |                 return new OrderedDict();
150 |             }
151 |         }
152 | 
153 |         /// <summary>
154 |         /// This constructor recreates the behavior of the `torch._utils._rebuild_tensor_v2` function, which
155 |         /// receives all the parameters for a tensor and constructs the tensor with all the relevant properties.
156 |         /// </summary>
157 |         class TensorObjectConstructor : IObjectConstructor {
158 |             public object construct(object[] args) {
159 |                 // Arg 0: returned from our custom pickler
160 |                 var tensorStream = (TensorStream)args[0];
161 | 
162 |                 var constructor = new TensorConstructorArgs {
163 |                     ArchiveIndex = tensorStream.ArchiveIndex,
164 |                     Data = tensorStream.ArchiveEntry!.Open(),
165 |                     DType = tensorStream.DType,
166 |                     // Arg 1: storage_offset
167 |                     StorageOffset = (int)args[1],
168 |                     // Arg 2: tensor_shape
169 |                     Shape = ((object[])args[2]).Select(i => (long)(int)i).ToArray(),
170 |                     // Arg 3: stride
171 |                     Stride = ((object[])args[3]).Select(i => (long)(int)i).ToArray(),
172 |                     // Arg 4: requires_grad
173 |                     RequiresGrad = (bool)args[4],
174 |                 };
175 | 
176 |                 // Arg 5: backward_hooks; we don't support adding them in, and PyTorch recommends
177 |                 // against serializing them.
178 | 
179 |                 return tensorStream.SkipTensorRead
180 |                     ? constructor
181 |                     : constructor.ReadTensorFromStream();
182 |             }
183 |         }
184 | 
185 |         /// <summary>
186 |         /// This object constructor identifies when a torch.nn.Parameter object is being reconstructed from
187 |         /// the pickle file. This is used to detect that the user is trying to load a saved model and not a
188 |         /// saved state dict.
189 |         /// </summary>
190 |         class ParameterObjectConstructor : IObjectConstructor {
191 |             public object construct(object[] args) {
192 |                 // If the user got here, that means they saved the entire model and not the state dictionary,
193 |                 // and we only support loading the state dict
194 |                 throw new NotImplementedException("The file being loaded contains the entire model and not just the state_dict. Please re-save using `torch.save(model.state_dict(), ...)`");
195 |             }
196 |         }
197 | 
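Following on from the skipTensorRead flag above, here is a minimal sketch of the deferred-read flow built on the intermediate records described next. TensorConstructorArgs is internal, so this only applies to code inside the assembly, and the checkpoint path is hypothetical:

    using System.Collections;
    using TorchSharp;

    // Keep the stream open: with skipTensorRead, the returned values still point into the archive.
    using var stream = File.OpenRead("model.pth"); // hypothetical path
    Hashtable stateDict = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen: true, skipTensorRead: true);

    foreach (DictionaryEntry kv in stateDict) {
        if (kv.Value is TensorConstructorArgs args) {
            // Materialize only the tensors we actually need; each one may be read exactly once.
            torch.Tensor t = args.ReadTensorFromStream();
            Console.WriteLine($"{kv.Key}: [{string.Join(", ", t.shape)}]");
        }
    }
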
198 |         /// <summary>
199 |         /// This class is an intermediate record which contains all the information needed to construct the tensor
200 |         /// from the stream provided in the class (usually this stream is set to a DeflateStream from a ZipArchiveEntry).
201 |         /// It can be used to read the tensor immediately, or held on to so the tensor can be read at a later time,
202 |         /// avoiding excessive memory use.
203 |         /// </summary>
204 |         internal class TensorConstructorArgs {
205 |             public int ArchiveIndex { get; init; }
206 |             public Stream Data { get; init; }
207 |             public torch.ScalarType DType { get; init; }
208 |             public int StorageOffset { get; init; }
209 |             public long[] Shape { get; init; }
210 |             public long[] Stride { get; init; }
211 |             public bool RequiresGrad { get; init; }
212 | 
213 |             private bool _alreadyRead = false;
214 |             public torch.Tensor ReadTensorFromStream() {
215 |                 if (_alreadyRead)
216 |                     throw new InvalidOperationException("The tensor has already been constructed, cannot read tensor twice.");
217 | 
218 |                 var temp = torch
219 |                     .empty(Shape, DType, device: torch.CPU)
220 |                     .as_strided(Shape, Stride, StorageOffset);
221 |                 temp.ReadBytesFromStream(Data);
222 |                 Data.Close();
223 | 
224 |                 _alreadyRead = true;
225 |                 return temp;
226 |             }
227 |         }
228 | 
229 |         /// <summary>
230 |         /// When the unpickler first loads in the tensor, it only has access to metadata about the storage
231 |         /// of the tensor, but not the info about stride/shape etc. That part is done in the TensorObjectConstructor.
232 |         /// Therefore, this class is a simple wrapper for the bytes + dtype of the storage.
233 |         /// </summary>
234 |         class TensorStream {
235 |             public int ArchiveIndex { get; init; }
236 |             public ZipArchiveEntry ArchiveEntry { get; init; }
237 |             public torch.ScalarType DType { get; init; }
238 |             public bool SkipTensorRead { get; init; }
239 |         }
240 |     }
241 | }
--------------------------------------------------------------------------------
/TorchSharp.PyBridge/Safetensors.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Text;
5 | using System.Text.Json;
6 | using System.Text.Json.Serialization;
7 | using System.Threading.Tasks;
8 | 
9 | namespace TorchSharp.PyBridge {
10 |     public static class Safetensors {
11 | 
12 |         public static Dictionary<string, torch.Tensor> LoadStateDict(string path, List<string>? keysToKeep = null) {
13 |             using var stream = File.OpenRead(path);
14 |             return LoadStateDict(stream, keysToKeep: keysToKeep);
15 |         }
16 | 
17 |         public static Dictionary<string, torch.Tensor> LoadStateDict(Stream stream, bool leaveOpen = false, List<string>? keysToKeep = null) {
18 | 
19 |             // Start by loading in the index of all the tensors
20 |             var index = LoadIndex(stream);
21 | 
22 |             long offset = stream.Position;
23 |             // Each entry in the index contains all the info for reconstructing the tensors
24 |             var ret = new Dictionary<string, torch.Tensor>();
25 |             foreach (var kvp in index) {
26 |                 if (kvp.Key == "__metadata__") continue;
27 |                 if (keysToKeep is not null && !keysToKeep.Contains(kvp.Key)) continue;
28 | 
29 |                 var tensor = torch.empty(kvp.Value.Shape, dtype: ConvertToTorchDType(kvp.Value.DataType));
30 | 
31 |                 // Make sure the length matches the number of bytes to load
32 |                 long length = kvp.Value.Offsets[1] - kvp.Value.Offsets[0];
33 |                 if (length != tensor.ElementSize * tensor.NumberOfElements)
34 |                     throw new NotImplementedException($"Error when loading tensor {kvp.Key} - mismatched # of elements");
35 | 
36 |                 stream.Position = offset + kvp.Value.Offsets[0];
37 |                 tensor.ReadBytesFromStream(stream);
38 | 
39 |                 ret.Add(kvp.Key, tensor);
40 |             }
41 | 
42 | 
43 |             if (!leaveOpen)
44 |                 stream.Close();
45 | 
46 |             return ret;
47 |         }
48 | 
49 |         public static void SaveStateDict(string path, Dictionary<string, torch.Tensor> stateDict) {
50 |             using var stream = File.Create(path); // Create truncates an existing file, so a shorter save doesn't leave stale bytes behind
51 |             SaveStateDict(stream, stateDict);
52 |         }
53 | 
54 |         public static void SaveStateDict(Stream stream, Dictionary<string, torch.Tensor> stateDict, bool leaveOpen = false) {
55 |             // We want to first build the index and then write out the tensors themselves. Therefore, first convert
56 |             // the state dict to an ordered collection, and build the index.
57 |             var orderedState = stateDict.ToList();
58 |             var index = new Dictionary<string, SafetensorsEntry>();
59 |             long offset = 0;
60 |             foreach (var kvp in orderedState) {
61 |                 long length = kvp.Value.NumberOfElements * kvp.Value.ElementSize;
62 | 
63 |                 index.Add(kvp.Key, new SafetensorsEntry() {
64 |                     DataType = ConvertToSafeTensorsDType(kvp.Value.dtype),
65 |                     Shape = kvp.Value.shape,
66 |                     Offsets = new[] { offset, offset + length }
67 |                 });
68 |                 offset += length;
69 |             }// next key
70 | 
71 |             byte[] indexJson = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(index));
72 | 
73 |             // Write out the JSON followed by the bytes of the tensors
74 |             var br = new BinaryWriter(stream);
75 |             br.Write((ulong)indexJson.Length);
76 |             br.Write(indexJson);
77 |             foreach (var kvp in orderedState) {
78 |                 if (kvp.Value.device.type == DeviceType.CPU)
79 |                     kvp.Value.WriteBytesToStream(stream);
80 |                 else {
81 |                     using var tmp = kvp.Value.cpu();
82 |                     tmp.WriteBytesToStream(stream);
83 |                 }
84 |             }
85 |             if (!leaveOpen)
86 |                 br.Close();
87 |         }
88 | 
89 | 
90 |         internal static Dictionary<string, SafetensorsEntry> LoadIndex(string path) {
91 |             using var stream = File.OpenRead(path);
92 |             return LoadIndex(stream);
93 |         }
94 | 
95 |         internal static Dictionary<string, SafetensorsEntry> LoadIndex(Stream stream) {
96 |             // First 8 bytes represent the length of the UTF8 JSON header, as a little-endian unsigned 64-bit integer.
97 |             ulong length = BitConverter.ToUInt64(stream.ReadBytes(8));
98 |             if (length > int.MaxValue)
99 |                 throw new ArgumentOutOfRangeException(nameof(length), "Length of JSON exceeded int.MaxValue, not supported yet");
100 | 
101 |             // Read the rest of the JSON, and deserialize it
102 |             var jsonBytes = stream.ReadBytes((int)length);
103 |             return JsonSerializer.Deserialize<Dictionary<string, SafetensorsEntry>>(Encoding.UTF8.GetString(jsonBytes)) ?? throw new NotImplementedException("Loaded header string failed to deserialize into the correct format.");
104 |         }
105 | 
106 |         private static torch.ScalarType ConvertToTorchDType(string dataType) {
107 |             return dataType switch {
108 |                 "F64" => torch.ScalarType.Float64,
109 |                 "F32" => torch.ScalarType.Float32,
110 |                 "F16" => torch.ScalarType.Float16,
111 |                 "BF16" => torch.ScalarType.BFloat16,
112 |                 "I64" => torch.ScalarType.Int64,
113 |                 "I32" => torch.ScalarType.Int32,
114 |                 "I16" => torch.ScalarType.Int16,
115 |                 "I8" => torch.ScalarType.Int8,
116 |                 "U8" => torch.ScalarType.Byte,
117 |                 "BOOL" => torch.ScalarType.Bool,
118 |                 _ => throw new NotImplementedException($"Unrecognized data type listed: {dataType}")
119 |             };
120 |         }
121 | 
122 |         private static string ConvertToSafeTensorsDType(torch.ScalarType dtype) {
123 |             return dtype switch {
124 |                 torch.ScalarType.Float64 => "F64",
125 |                 torch.ScalarType.Float32 => "F32",
126 |                 torch.ScalarType.Float16 => "F16",
127 |                 torch.ScalarType.BFloat16 => "BF16",
128 |                 torch.ScalarType.Int64 => "I64",
129 |                 torch.ScalarType.Int32 => "I32",
130 |                 torch.ScalarType.Int16 => "I16",
131 |                 torch.ScalarType.Int8 => "I8",
132 |                 torch.ScalarType.Byte => "U8",
133 |                 torch.ScalarType.Bool => "BOOL",
134 |                 _ => throw new NotImplementedException($"Unrecognized data type listed: {dtype}")
135 |             };
136 |         }
137 |     }
138 |     internal class SafetensorsEntry {
139 |         [JsonPropertyName("dtype")]
140 |         public string DataType { get; init; }
141 | 
142 |         [JsonPropertyName("shape")]
143 |         public long[] Shape { get; init; }
144 | 
145 |         [JsonPropertyName("data_offsets")]
146 |         public long[] Offsets { get; init; }
147 |     }
148 | }
149 | 
--------------------------------------------------------------------------------
/TorchSharp.PyBridge/TorchSharp.PyBridge.csproj:
--------------------------------------------------------------------------------
1 | <Project Sdk="Microsoft.NET.Sdk">
2 | 
3 |   <PropertyGroup>
4 |     <OutputType>Library</OutputType>
5 |     <TargetFramework>net6.0</TargetFramework>
6 |     <ImplicitUsings>enable</ImplicitUsings>
7 |     <Nullable>enable</Nullable>
8 |     <Platforms>AnyCPU;x64;arm64</Platforms>
9 |   </PropertyGroup>
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
17 |   <PropertyGroup>
18 |     <Authors>Shaltiel Shmidman</Authors>
19 |     <PackageReadmeFile>README.md</PackageReadmeFile>
20 |     <PackageLicenseExpression>MIT</PackageLicenseExpression>
21 |     <PackageProjectUrl>https://github.com/shaltielshmid/TorchSharp.PyBridge</PackageProjectUrl>
22 |     <RepositoryUrl>https://github.com/shaltielshmid/TorchSharp.PyBridge.git</RepositoryUrl>
23 |     <RepositoryType>git</RepositoryType>
24 |     <Version>1.4.3</Version>
25 |     <AssemblyVersion>1.4.3.0</AssemblyVersion>
26 |     <FileVersion>1.4.3.0</FileVersion>
27 |     <PackageReleaseNotes>
28 |       1.4.3:
29 |       - Fixed #21: `strict` is not passed to `load_safetensor` in `load_checkpoint` extension
30 |       1.4.2:
31 |       - PR #20: Optimize load_py for memory and speed (@ejhg)
32 |       1.4.1:
33 |       - Fixed #17: How to disable tqdm output when loading sharded safetensors
34 |       1.4.0:
35 |       - Exposed `Safetensors`, `PytorchPickler` and `PytorchUnpickler` to allow for loading/saving python tensors outside of a model.
36 |       - Fixed #16: SaveStateDict calls itself recursively and fails on locked file
37 |       1.3.2:
38 |       - Fixed #13: UnpickleStateDict on BatchNorm2d error
39 |       1.3.1:
40 |       - Fixed error on Apple Silicon devices
41 |       1.3.0:
42 |       - Added support for loading tensors that are greater than 2GB (following the update in TorchSharp 0.102.0)
43 |       - Added support for loading and saving safetensors when model isn't on CPU.
44 |     </PackageReleaseNotes>
45 |   </PropertyGroup>
46 | 
47 | 
48 | 
49 | 
50 | 
51 | 
52 | </Project>
--------------------------------------------------------------------------------
/pack.bat:
--------------------------------------------------------------------------------
1 | dotnet pack TorchSharp.PyBridge -c Release /p:SignAssembly=true /p:AssemblyOriginatorKeyFile=%1
2 | 
3 | 
4 | 
5 | 
--------------------------------------------------------------------------------
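
As a closing illustration, here is a minimal round-trip sketch for the Safetensors helpers defined above (the file name and keys are hypothetical):

    using TorchSharp;
    using TorchSharp.PyBridge;

    // Save: 8-byte little-endian header length, JSON index, then the raw tensor bytes in index order.
    var stateDict = new Dictionary<string, torch.Tensor> {
        ["weight"] = torch.randn(4, 3),
        ["bias"] = torch.zeros(4)
    };
    Safetensors.SaveStateDict("model.safetensors", stateDict); // hypothetical path

    // Load it back, optionally keeping only a subset of keys.
    var loaded = Safetensors.LoadStateDict("model.safetensors", keysToKeep: new List<string> { "weight" });
    Console.WriteLine(string.Join(", ", loaded["weight"].shape)); // 4, 3
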