├── StableDiffusionTorchSharp ├── ModelLoader │ ├── IModelLoader.cs │ ├── Tensor.cs │ ├── SafetensorsLoader.cs │ ├── PickleLoader.cs │ └── LoadData.cs ├── StableDiffusionTorchSharp.csproj ├── Encoder.cs ├── EulerDiscreteScheduler.cs ├── Attention.cs ├── Program.cs ├── Clip.cs ├── Decoder.cs └── Diffusion.cs ├── README.md ├── LICENSE.md └── StableDiffusionTorchSharp.sln /StableDiffusionTorchSharp/ModelLoader/IModelLoader.cs: --------------------------------------------------------------------------------
namespace StableDiffusionTorchSharp.ModelLoader
{
    /// <summary>
    /// Contract shared by the checkpoint readers in this folder: one call lists
    /// the tensors stored in a file, the other fetches the raw bytes of one of them.
    /// </summary>
    public interface IModelLoader
    {
        /// <summary>Reads the tensor descriptors (metadata only) stored in <paramref name="fileName"/>.</summary>
        /// <param name="fileName">Path of the checkpoint file to inspect.</param>
        /// <returns>One <see cref="Tensor"/> descriptor per tensor found in the file.</returns>
        List<Tensor> ReadTensorsInfoFromFile(string fileName);

        /// <summary>Reads the raw data bytes of a tensor previously returned by <see cref="ReadTensorsInfoFromFile"/>.</summary>
        /// <param name="tensor">Descriptor identifying the file and the location of the data.</param>
        /// <returns>The tensor's bytes exactly as stored in the file.</returns>
        byte[] ReadByteFromFile(Tensor tensor);
    }
}
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | Stable Diffusion model v1.5 for TorchSharp. 3 | The cpu requires a minimum of 16GB of memory. 4 | # download checkpoint 5 | https://huggingface.co/williamlzw/stable-diffusion-1-5-torchsharp 6 | 7 | and now you can download safetensors/ckpt checkpoint. 
8 | 9 | # inference model with torchsharp 10 | run c# program 11 | 12 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/ModelLoader/Tensor.cs: -------------------------------------------------------------------------------- 1 | namespace StableDiffusionTorchSharp.ModelLoader 2 | { 3 | public class Tensor 4 | { 5 | public string Name { get; set; } 6 | public TorchSharp.torch.ScalarType Type { get; set; } = TorchSharp.torch.ScalarType.Float16; 7 | public List Shape { get; set; } = new List(); 8 | public List Stride { get; set; } = new List(); 9 | public string DataNameInZipFile { get; set; } 10 | public string FileName { get; set; } 11 | public List Offset { get; set; } = new List(); 12 | public long BodyPosition { get; set; } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/StableDiffusionTorchSharp.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | net6.0 6 | enable 7 | enable 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/Encoder.cs: --------------------------------------------------------------------------------
using Microsoft.ML.Tokenizers;

namespace StableDiffusionTorchSharp
{
    /// <summary>
    /// Wraps a BPE tokenizer and turns a prompt into a fixed-length list of token ids
    /// framed by a start token and an end token.
    /// </summary>
    public class ClipTokenizer
    {
        private readonly Tokenizer _tokenizer;
        private readonly int _startToken;
        private readonly int _endToken;

        /// <summary>Builds the tokenizer from a vocabulary file and a merges file.</summary>
        /// <param name="vocabPath">Path to the BPE vocabulary (vocab.json).</param>
        /// <param name="mergesPath">Path to the BPE merges file (merges.txt).</param>
        /// <param name="startToken">Id emitted as the first token of every sequence.</param>
        /// <param name="endToken">Id emitted as the last token of every sequence.</param>
        public ClipTokenizer(string vocabPath, string mergesPath, int startToken = 49406, int endToken = 49407)
        {
            // NOTE(review): the suffix literal was empty in the extracted text, where every
            // "<...>" sequence had been stripped; "</w>" is the CLIP BPE end-of-word marker
            // and is presumably what the original contained. Confirm against the upstream file.
            _tokenizer = new Tokenizer(new Bpe(vocabPath, mergesPath, endOfWordSuffix: "</w>"));
            _startToken = startToken;
            _endToken = endToken;
        }

        /// <summary>
        /// Encodes <paramref name="text"/> into exactly <paramref name="maxTokens"/> ids laid out as
        /// [start, prompt ids..., 0 padding..., end].
        /// </summary>
        /// <param name="text">Prompt to encode.</param>
        /// <param name="maxTokens">Total length of the returned list, including start and end tokens.</param>
        /// <returns>A list of exactly <paramref name="maxTokens"/> token ids.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxTokens"/> is smaller than 2.</exception>
        public List<int> Tokenize(string text, int maxTokens = 77)
        {
            if (maxTokens < 2)
            {
                throw new ArgumentOutOfRangeException(nameof(maxTokens), "maxTokens must leave room for the start and end tokens.");
            }

            var ids = _tokenizer.Encode(text).Ids;

            // Prompts longer than the budget used to yield a negative padding count and an
            // exception from Enumerable.Repeat; they are truncated to fit instead.
            int budget = maxTokens - 2;
            int kept = Math.Min(ids.Count, budget);

            var tokens = new List<int>(maxTokens) { _startToken };
            tokens.AddRange(ids.Take(kept));
            // NOTE(review): zero padding is placed before the end token, as in the original
            // layout. Confirm this matches what the CLIP checkpoint expects.
            tokens.AddRange(Enumerable.Repeat(0, budget - kept));
            tokens.Add(_endToken);
            return tokens;
        }
    }
}
-------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 william_lzw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /StableDiffusionTorchSharp.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.7.34031.279 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StableDiffusionTorchSharp", "StableDiffusionTorchSharp\StableDiffusionTorchSharp.csproj", "{D209A46A-07DB-4648-810B-67FFE92798B9}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {D209A46A-07DB-4648-810B-67FFE92798B9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {D209A46A-07DB-4648-810B-67FFE92798B9}.Debug|Any CPU.Build.0 = Debug|Any CPU 16 | {D209A46A-07DB-4648-810B-67FFE92798B9}.Release|Any CPU.ActiveCfg = Release|Any CPU 17 | {D209A46A-07DB-4648-810B-67FFE92798B9}.Release|Any CPU.Build.0 = Release|Any CPU 18 | EndGlobalSection 19 | GlobalSection(SolutionProperties) = preSolution 20 | HideSolutionNode = FALSE 21 | EndGlobalSection 22 | GlobalSection(ExtensibilityGlobals) = postSolution 23 | SolutionGuid = {39792448-4816-4F15-9F39-21B053A80024} 24 | EndGlobalSection 25 | EndGlobal 26 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/ModelLoader/SafetensorsLoader.cs: --------------------------------------------------------------------------------
using Newtonsoft.Json.Linq;
using System.Text;

namespace StableDiffusionTorchSharp.ModelLoader
{
    /// <summary>
    /// Reads tensor descriptors and raw tensor bytes from a safetensors file:
    /// an 8-byte header length, a JSON header of that length, then the tensor data.
    /// </summary>
    public class SafetensorsLoader : IModelLoader
    {
        // Upper bound accepted for the JSON header length (same limit as before).
        private const long MaxHeaderSize = 100_000_000;

        /// <summary>Parses the JSON header and returns one descriptor per entry that has "data_offsets".</summary>
        /// <param name="inputFileName">Path of the safetensors file.</param>
        /// <returns>Descriptors carrying name, dtype, shape, data offsets and the position where the data body starts.</returns>
        /// <exception cref="ArgumentOutOfRangeException">The file is too short or its header length is implausible.</exception>
        /// <exception cref="EndOfStreamException">The file ends before the announced header was fully read.</exception>
        public List<Tensor> ReadTensorsInfoFromFile(string inputFileName)
        {
            using FileStream stream = File.OpenRead(inputFileName);

            long len = stream.Length;
            if (len < 10)
            {
                throw new ArgumentOutOfRangeException(nameof(inputFileName), "File cannot be valid safetensors: too short");
            }

            // The first 8 bytes hold the header length; read them as little-endian
            // explicitly instead of relying on the host byte order.
            byte[] headerBlock = new byte[8];
            ReadFully(stream, headerBlock);
            long headerSize = System.Buffers.Binary.BinaryPrimitives.ReadInt64LittleEndian(headerBlock);

            // Range checks come first so that "8 + headerSize" cannot overflow.
            if (headerSize <= 0 || headerSize > MaxHeaderSize || len < 8 + headerSize)
            {
                throw new ArgumentOutOfRangeException(nameof(inputFileName), $"File cannot be valid safetensors: header len wrong, size:{headerSize}");
            }

            // The header itself is a JSON document.
            byte[] headerBytes = new byte[headerSize];
            ReadFully(stream, headerBytes);
            string header = Encoding.UTF8.GetString(headerBytes);
            long bodyPosition = stream.Position;

            Dictionary<string, JToken> entries =
                JToken.Parse(header).ToObject<Dictionary<string, JToken>>() ?? new Dictionary<string, JToken>();

            List<Tensor> tensors = new List<Tensor>(entries.Count);
            foreach (KeyValuePair<string, JToken> entry in entries)
            {
                if (entry.Value is not JObject fields)
                {
                    continue;
                }

                fields.TryGetValue("data_offsets", out JToken? offsets);
                fields.TryGetValue("dtype", out JToken? dtype);
                fields.TryGetValue("shape", out JToken? shape);

                // Entries without data offsets do not describe a tensor; skip them.
                ulong[]? offsetArray = offsets?.ToObject<ulong[]>();
                if (offsetArray is null)
                {
                    continue;
                }

                // A missing or empty shape is recorded as a single-element shape.
                long[]? shapeArray = shape?.ToObject<long[]>();
                if (shapeArray is null || shapeArray.Length < 1)
                {
                    shapeArray = new long[] { 1 };
                }

                tensors.Add(new Tensor
                {
                    Name = entry.Key,
                    Type = ParseDType(dtype?.ToString()),
                    Shape = shapeArray.ToList(),
                    Offset = offsetArray.ToList(),
                    FileName = inputFileName,
                    BodyPosition = bodyPosition
                });
            }

            return tensors;
        }

        /// <summary>Reads the bytes between the tensor's two data offsets, relative to the start of the data body.</summary>
        /// <param name="tensor">Descriptor produced by <see cref="ReadTensorsInfoFromFile"/>.</param>
        /// <returns>The raw bytes of the tensor.</returns>
        /// <exception cref="OverflowException">The tensor data does not fit in a single byte array.</exception>
        /// <exception cref="EndOfStreamException">The file ends before the tensor data was fully read.</exception>
        public byte[] ReadByteFromFile(Tensor tensor)
        {
            ulong begin = tensor.Offset[0];
            ulong end = tensor.Offset[1];

            // checked: an oversized or reversed range fails loudly instead of wrapping silently.
            int size = checked((int)(end - begin));
            return ReadByteFromFile(tensor.FileName, tensor.BodyPosition, checked((long)begin), size);
        }

        // Opens the file, seeks to bodyPosition + offset and reads exactly `size` bytes.
        private static byte[] ReadByteFromFile(string inputFileName, long bodyPosition, long offset, int size)
        {
            using FileStream stream = File.OpenRead(inputFileName);
            stream.Seek(bodyPosition + offset, SeekOrigin.Begin);
            byte[] dest = new byte[size];
            ReadFully(stream, dest);
            return dest;
        }

        // Stream.Read may return fewer bytes than requested; loop until the buffer is full.
        private static void ReadFully(Stream stream, byte[] buffer)
        {
            int filled = 0;
            while (filled < buffer.Length)
            {
                int read = stream.Read(buffer, filled, buffer.Length - filled);
                if (read == 0)
                {
                    throw new EndOfStreamException("Unexpected end of safetensors file.");
                }
                filled += read;
            }
        }

        // Maps a safetensors dtype name to a TorchSharp scalar type. Names without a mapping
        // (U16, U32, U64, F8_E4M3, F8_E5M2, unknown or missing) fall back to Float32 as before.
        private static TorchSharp.torch.ScalarType ParseDType(string? dtype) => dtype switch
        {
            "I8" => TorchSharp.torch.ScalarType.Int8,
            "I16" => TorchSharp.torch.ScalarType.Int16,
            "I32" => TorchSharp.torch.ScalarType.Int32,
            "I64" => TorchSharp.torch.ScalarType.Int64,
            "BF16" => TorchSharp.torch.ScalarType.BFloat16,
            "F16" => TorchSharp.torch.ScalarType.Float16,
            "F32" => TorchSharp.torch.ScalarType.Float32,
            "F64" => TorchSharp.torch.ScalarType.Float64,
            "U8" => TorchSharp.torch.ScalarType.Byte,
            "BOOL" => TorchSharp.torch.ScalarType.Bool,
            _ => TorchSharp.torch.ScalarType.Float32,
        };
    }
}
-------------------------------------------------------------------------------- /StableDiffusionTorchSharp/EulerDiscreteScheduler.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | namespace StableDiffusionTorchSharp 5 | { 6 | public class EulerDiscreteScheduler 7 | { 8 | private long num_train_timesteps_; 9 | private int steps_offset_; 10 | private Tensor betas_; 11 | private Tensor alphas_; 12 | private Tensor alphas_cumprod_; 13 | private Tensor sigmas_; 14 | public Tensor timesteps_; 15 | private long num_inference_steps_; 16 | 17 | public EulerDiscreteScheduler(long num_train_timesteps = 1000, float beta_start = 
0.00085f, float beta_end = 0.012f, int steps_offset = 1) 18 | { 19 | num_train_timesteps_ = num_train_timesteps; 20 | steps_offset_ = steps_offset; 21 | betas_ = torch.pow(torch.linspace(Math.Pow(beta_start, 0.5), Math.Pow(beta_end, 0.5), num_train_timesteps, ScalarType.Float32), 2); 22 | alphas_ = 1f - betas_; 23 | alphas_cumprod_ = torch.cumprod(alphas_, 0); 24 | var sigmas = torch.pow((1.0f - alphas_cumprod_) / alphas_cumprod_, 0.5f); 25 | sigmas_ = torch.cat(new Tensor[] { sigmas.flip(0), torch.tensor(new float[] { 0.0f }) }); 26 | timesteps_ = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps_).flip(0); 27 | } 28 | 29 | public Tensor InitNoiseSigma() 30 | { 31 | return torch.pow(torch.pow(sigmas_.max(), 2) + 1, 0.5f); 32 | } 33 | 34 | public Tensor ScaleModelInput(Tensor sample, Tensor timestep) 35 | { 36 | var step_index = (timesteps_ == timestep).nonzero().ToInt64(); 37 | var sigma = sigmas_[step_index].ToSingle(); 38 | sample = sample / (float)(Math.Pow(Math.Pow(sigma, 2) + 1, 0.5f)); 39 | return sample; 40 | } 41 | 42 | public void SetTimesteps(long num_inference_steps, torch.Device device) 43 | { 44 | num_inference_steps_ = num_inference_steps; 45 | long step_ratio = num_train_timesteps_ / num_inference_steps_; 46 | var timesteps = (torch.arange(0, num_inference_steps, ScalarType.Float32, device: device) * step_ratio).round().flip(0); 47 | timesteps = timesteps + steps_offset_; 48 | var sigmas = torch.pow((1.0f - alphas_cumprod_) / alphas_cumprod_, 0.5f); 49 | sigmas = Interp(timesteps, torch.arange(0, sigmas.shape[0], device: device), sigmas.to(device)); 50 | sigmas_ = torch.cat(new Tensor[] { sigmas, torch.tensor(new float[] { 0.0f }, device: device) }); 51 | timesteps_ = timesteps; 52 | } 53 | 54 | public Tensor Step(Tensor model_output, Tensor timestep, Tensor sample) 55 | { 56 | var step_index = (timesteps_ == timestep).nonzero().ToInt64(); 57 | var sigma = sigmas_[step_index].ToSingle(); 58 | float gamma = 0; 59 | if (sigma >= 0 && 
sigma <= System.Single.PositiveInfinity) 60 | { 61 | gamma = (float)Math.Min(0f, Math.Pow(2, 0.5f) - 1); 62 | } 63 | var noise = torch.randn(model_output.shape, model_output.dtype, model_output.device); 64 | var sigma_hat = sigma * (gamma + 1); 65 | if (gamma > 0) 66 | { 67 | sample = sample + noise * (torch.pow(torch.pow(sigma_hat, 2) - torch.pow(sigma, 2), 0.5)); 68 | } 69 | var pred_original_sample = sample - sigma_hat * model_output; 70 | var derivative = (sample - pred_original_sample) / sigma_hat; 71 | var dt = sigmas_[step_index + 1].ToSingle() - sigma_hat; 72 | var prev_sample = sample + derivative * dt; 73 | return prev_sample; 74 | } 75 | 76 | private Tensor Interp(Tensor x, Tensor xp, Tensor fp) 77 | { 78 | var sort_idx = torch.argsort(xp); 79 | xp = xp[sort_idx]; 80 | fp = fp[sort_idx]; 81 | var idx = torch.searchsorted(xp, x); 82 | idx = torch.clamp(idx, 0, xp.shape[0] - 2); 83 | var weight = (x - xp[idx]) / (xp[idx + 1] - xp[idx]); 84 | return fp[idx] * (1 - weight) + fp[idx + 1] * weight; 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/Attention.cs: --------------------------------------------------------------------------------
using TorchSharp;
using TorchSharp.FlashAttention;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace StableDiffusionTorchSharp
{
    /// <summary>
    /// Multi-head self-attention: one fused linear layer produces q, k and v, the heads
    /// attend over the sequence, and a second linear layer projects the result.
    /// </summary>
    public class SelfAttention : Module<Tensor, Tensor>
    {
        internal readonly Linear in_proj;
        internal readonly Linear out_proj;
        private readonly long n_heads_;
        private readonly long d_head;
        bool causal_mask_;
        float dropout_p;
        bool useFlashAtten;
        Dropout dropout;

        /// <param name="n_heads">Number of attention heads.</param>
        /// <param name="d_embed">Embedding width; each head gets d_embed / n_heads channels.</param>
        /// <param name="in_proj_bias">Whether the fused q/k/v projection has a bias.</param>
        /// <param name="out_proj_bias">Whether the output projection has a bias.</param>
        /// <param name="causal_mask">When true, positions may not attend to later positions.</param>
        /// <param name="dropout_p">Dropout probability for the attention weights (training only).</param>
        /// <param name="useFlashAtten">Selects the flash-attention path instead of the explicit matmul path.</param>
        public SelfAttention(long n_heads, long d_embed, bool in_proj_bias = true, bool out_proj_bias = true, bool causal_mask = false, float dropout_p = 0.1f, bool useFlashAtten = false) : base("SelfAttention")
        {
            in_proj = Linear(d_embed, 3 * d_embed, hasBias: in_proj_bias);
            out_proj = Linear(d_embed, d_embed, hasBias: out_proj_bias);
            n_heads_ = n_heads;
            d_head = d_embed / n_heads;
            causal_mask_ = causal_mask;
            this.dropout_p = dropout_p;
            this.useFlashAtten = useFlashAtten;
            dropout = Dropout(dropout_p);
            RegisterComponents();
        }

        /// <summary>Applies self-attention to <paramref name="x"/> and returns a tensor of the same shape.</summary>
        /// <param name="x">Input indexed as (batch, sequence, embedding): dims 0 and 1 are read as batch size and sequence length.</param>
        public override Tensor forward(Tensor x)
        {
            using var scope = NewDisposeScope();
            var input_shape = x.shape;
            var batch_size = input_shape[0];
            var sequence_length = input_shape[1];
            float scale = 1 / (float)Math.Sqrt(d_head);

            var qkv = in_proj.forward(x);
            Tensor output;

            if (useFlashAtten)
            {
                // Dropout must only act while training; passing dropout_p unconditionally
                // applied it during inference as well.
                float effective_dropout = training ? dropout_p : 0f;
                long[] packed_shape = new long[] { batch_size, sequence_length, 3, n_heads_, d_head };
                output = new FlashAttention(softmax_scale: scale, effective_dropout, causal_mask_).forward(qkv.view(packed_shape));
            }
            else
            {
                long[] interim_shape = new long[] { batch_size, sequence_length, n_heads_, d_head };
                var parts = qkv.chunk(3, dim: -1);
                var q = parts[0].view(interim_shape).transpose(1, 2);
                var k = parts[1].view(interim_shape).transpose(1, 2);
                var v = parts[2].view(interim_shape).transpose(1, 2);

                var weight = torch.matmul(q, k.transpose(-1, -2));
                if (causal_mask_)
                {
                    // Entries above the diagonal are set to -inf before the softmax.
                    var mask = torch.ones_like(weight).triu(1).to(torch.@bool);
                    weight.masked_fill_(mask, Single.NegativeInfinity);
                }
                weight = weight / (float)Math.Sqrt(d_head);
                weight = torch.nn.functional.softmax(weight, dim: -1);
                // Registered module, so it follows train()/eval() of this module.
                weight = dropout.forward(weight);

                output = torch.matmul(weight, v).transpose(1, 2);
            }

            output = output.reshape(input_shape);
            output = out_proj.forward(output);
            return output.MoveToOuterDisposeScope();
        }
    }

    /// <summary>
    /// Multi-head cross-attention: queries come from <c>x</c>, keys and values from <c>y</c>.
    /// </summary>
    public class CrossAttention : Module<Tensor, Tensor, Tensor>
    {
        internal readonly Linear q_proj;
        internal readonly Linear k_proj;
        internal readonly Linear v_proj;
        internal readonly Linear out_proj;
        internal readonly long n_heads_;
        internal readonly long d_head;
        bool causal_mask_;
        bool useFlashAtten;
        float dropout_p;

        /// <param name="n_heads">Number of attention heads.</param>
        /// <param name="d_embed">Width of <c>x</c> and of the projected q/k/v.</param>
        /// <param name="d_cross">Width of the context tensor <c>y</c>.</param>
        /// <param name="in_proj_bias">Whether the q/k/v projections have a bias.</param>
        /// <param name="out_proj_bias">Whether the output projection has a bias.</param>
        /// <param name="dropout_p">Dropout probability for the attention weights (training only).</param>
        /// <param name="causal_mask">Forwarded as <c>causal</c> to the flash-attention kernel only.</param>
        /// <param name="useFlashAtten">Selects the flash-attention path instead of the explicit matmul path.</param>
        public CrossAttention(long n_heads, long d_embed, long d_cross, bool in_proj_bias = true, bool out_proj_bias = true, float dropout_p = 0.2f, bool causal_mask = true, bool useFlashAtten = false) : base("CrossAttention")
        {
            q_proj = Linear(d_embed, d_embed, hasBias: in_proj_bias);
            k_proj = Linear(d_cross, d_embed, hasBias: in_proj_bias);
            v_proj = Linear(d_cross, d_embed, hasBias: in_proj_bias);
            out_proj = Linear(d_embed, d_embed, hasBias: out_proj_bias);
            n_heads_ = n_heads;
            d_head = d_embed / n_heads;
            causal_mask_ = causal_mask;
            this.useFlashAtten = useFlashAtten;
            this.dropout_p = dropout_p;
            RegisterComponents();
        }

        /// <summary>Attends from <paramref name="x"/> to <paramref name="y"/> and returns a tensor shaped like <paramref name="x"/>.</summary>
        /// <param name="x">Query source; dim 0 is read as the batch size.</param>
        /// <param name="y">Key/value source (context).</param>
        public override Tensor forward(Tensor x, Tensor y)
        {
            using var scope = NewDisposeScope();
            var input_shape = x.shape;
            var batch_size = input_shape[0];
            long[] interim_shape = new long[] { batch_size, -1, n_heads_, d_head };

            var q = q_proj.forward(x).view(interim_shape);
            var k = k_proj.forward(y).view(interim_shape);
            var v = v_proj.forward(y).view(interim_shape);

            Tensor output;
            if (useFlashAtten)
            {
                // Functional kernel: it applies whatever probability it is given, so the
                // train/eval switch has to be made here.
                float effective_dropout = training ? dropout_p : 0f;
                (var attended, var _, var _) = FlashAttentionInterface.flash_attn_func(q, k, v, dropout_p: effective_dropout, softmax_scale: 1 / (float)Math.Sqrt(d_head), causal: causal_mask_);
                output = attended;
            }
            else
            {
                // NOTE(review): causal_mask_ is not applied on this path although the flash
                // path forwards it; confirm which behaviour is intended.
                q = q.transpose(1, 2);
                k = k.transpose(1, 2);
                v = v.transpose(1, 2);

                var weight = torch.matmul(q, k.transpose(-1, -2));
                weight = weight / (float)Math.Sqrt(d_head);
                weight = torch.nn.functional.softmax(weight, dim: -1);

                // A freshly constructed Dropout module starts in training mode, so the old
                // Dropout(dropout_p).forward(...) dropped weights even after eval().
                weight = torch.nn.functional.dropout(weight, dropout_p, training);

                output = torch.matmul(weight, v).transpose(1, 2).contiguous();
            }

            output = output.reshape(input_shape);
            output = out_proj.forward(output);
            return output.MoveToOuterDisposeScope();
        }
    }
}
-------------------------------------------------------------------------------- /StableDiffusionTorchSharp/Program.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using TorchSharp.Modules; 3 | using static TorchSharp.torch; 4 | 5 | namespace StableDiffusionTorchSharp 6 | { 7 | public class Program 8 | { 9 | public static void Main() 10 | { 11 | Generate(); 12 | } 13 | 14 | public static Tensor GetTimeEmbedding(float timestep) 15 | { 16 | var freqs = torch.pow(10000, -torch.arange(0, 160, dtype: torch.float32) / 160); 17 | var x = torch.tensor(new float[] { timestep }, dtype: torch.float32)[torch.TensorIndex.Colon, torch.TensorIndex.None] * freqs[torch.TensorIndex.None]; 18 | return torch.cat(new Tensor[] { torch.cos(x), torch.sin(x) }, dim: -1); 19 | } 20 | 21 | public static void Generate() 22 | { 23 | using 
(torch.no_grad()) 24 | { 25 | bool useFlashAttention = true; // flash attention only support fp16 or bf16 type 26 | int num_inference_steps = 20; 27 | var device = torch.device("cuda"); // if use flash attention, device must be cuda 28 | float cfg = 7.5f; 29 | ulong seed = (ulong)new Random().Next(0, int.MaxValue); 30 | //string modelname = @".\model\v1-5-pruned.safetensors"; 31 | string modelname = @".\model\unet.dat"; 32 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager(100); 33 | 34 | Console.WriteLine("Device:" + device); 35 | Console.WriteLine("CFG:" + cfg); 36 | Console.WriteLine("Seed:" + seed); 37 | Console.WriteLine("Loading clip......"); 38 | var clip = new CLIP(useFlashAttention: useFlashAttention); 39 | clip.load(@".\model\clip.dat"); 40 | 41 | Console.WriteLine("Loading unet......"); 42 | var diffusion = new Diffusion(useFlashAttention: useFlashAttention); 43 | diffusion.load(modelname); 44 | 45 | Console.WriteLine("Loading vae......"); 46 | var decoder = new Decoder(useFlashAttention: useFlashAttention); 47 | decoder.load(@".\model\decoder.dat"); 48 | if (useFlashAttention) 49 | { 50 | clip = clip.half(); 51 | } 52 | clip = clip.to(device); 53 | clip.eval(); 54 | 55 | if (useFlashAttention) 56 | { 57 | diffusion = diffusion.half(); 58 | } 59 | diffusion = diffusion.to(device); 60 | diffusion.eval(); 61 | 62 | if (useFlashAttention) 63 | { 64 | decoder = decoder.half(); 65 | } 66 | decoder = decoder.to(device); 67 | decoder.eval(); 68 | 69 | ScalarType clipType = clip.embedding.token_embedding.weight.dtype; 70 | ScalarType diffusionType = diffusion.final.groupnorm.weight.dtype; 71 | ScalarType decoderType = ((Conv2d)decoder.children().First()).weight.dtype; 72 | 73 | string VocabPath = @".\model\vocab.json"; 74 | string MergesPath = @".\model\merges.txt"; 75 | var tokenizer = new ClipTokenizer(VocabPath, MergesPath); 76 | 77 | string prompt = "typographic art bird. 
stylized, intricate, detailed, artistic, text-based"; 78 | string uncond_prompts = ""; 79 | 80 | Console.WriteLine("Clip is doing......"); 81 | var cond_tokens_ids = tokenizer.Tokenize(prompt); 82 | var cond_tokens = torch.tensor(cond_tokens_ids, torch.@long).unsqueeze(0).to(device); 83 | var cond_context = clip.forward(cond_tokens); 84 | 85 | var uncond_tokens_ids = tokenizer.Tokenize(uncond_prompts); 86 | var uncond_tokens = torch.tensor(uncond_tokens_ids, torch.@long).unsqueeze(0).to(device); 87 | var uncond_context = clip.forward(uncond_tokens); 88 | 89 | var context = torch.cat(new Tensor[] { cond_context, uncond_context }).to(diffusionType).to(device); 90 | 91 | Console.WriteLine("Getting latents......"); 92 | long[] noise_shape = new long[] { 1, 4, 64, 64 }; 93 | Generator generator = new Generator(seed, device); 94 | var latents = torch.randn(noise_shape, generator: generator); 95 | latents = latents.to(diffusionType).to(device); 96 | var sampler = new EulerDiscreteScheduler(); 97 | 98 | sampler.SetTimesteps(num_inference_steps, device); 99 | latents *= sampler.InitNoiseSigma(); 100 | Console.WriteLine($"begin step"); 101 | for (int i = 0; i < num_inference_steps; i++) 102 | { 103 | var timestep = sampler.timesteps_[i]; 104 | var time_embedding = GetTimeEmbedding(timestep.ToSingle()).to(diffusionType).to(device); 105 | var input_latents = sampler.ScaleModelInput(latents, timestep); 106 | input_latents = input_latents.repeat(2, 1, 1, 1).to(diffusionType).to(device); 107 | var output = diffusion.forward(input_latents, context, time_embedding); 108 | var ret = output.chunk(2); 109 | var output_cond = ret[0]; 110 | var output_uncond = ret[1]; 111 | output = cfg * (output_cond - output_uncond) + output_uncond; 112 | latents = sampler.Step(output, timestep, latents); 113 | } 114 | Console.WriteLine($"end step"); 115 | latents = latents.to(decoderType); 116 | Console.WriteLine($"begin decoder"); 117 | var images = decoder.forward(latents); 118 | 
/// <summary>
/// Input embedding of the CLIP text encoder: a token lookup table plus a
/// trainable position table that is added to the looked-up embeddings.
/// </summary>
internal class CLIPEmbedding : Module<Tensor, Tensor>
{
    internal readonly Embedding token_embedding;
    internal readonly Parameter position_value;

    /// <param name="n_vocab">Number of rows in the token embedding table.</param>
    /// <param name="n_embd">Embedding width.</param>
    /// <param name="n_token">Number of positions in the position table.</param>
    internal CLIPEmbedding(long n_vocab, long n_embd, long n_token) : base("CLIPEmbedding")
    {
        token_embedding = Embedding(n_vocab, n_embd);
        // Zero-initialised; real values arrive through CLIP.load.
        position_value = Parameter(torch.zeros(n_token, n_embd));
        RegisterComponents();
    }

    /// <summary>Looks up the token embeddings and adds the position table in place.</summary>
    public override Tensor forward(Tensor tokens)
    {
        using var _ = NewDisposeScope();
        var x = token_embedding.forward(tokens);
        x += position_value;
        return x.MoveToOuterDisposeScope();
    }
}

/// <summary>
/// One CLIP encoder layer: LayerNorm -> self-attention -> residual add, then
/// LayerNorm -> Linear(4x) -> x * sigmoid(1.702 * x) -> Linear -> residual add.
/// </summary>
internal class CLIPLayer : Module<Tensor, Tensor>
{
    internal readonly LayerNorm layernorm_1;
    internal readonly LayerNorm layernorm_2;
    internal readonly SelfAttention attention;
    internal readonly Linear linear_1;
    internal readonly Linear linear_2;

    internal CLIPLayer(long n_head, long n_embd, bool useFlashAttention = false) : base("CLIPLayer")
    {
        layernorm_1 = LayerNorm(n_embd);
        attention = new SelfAttention(n_head, n_embd, causal_mask: true, useFlashAtten: useFlashAttention);
        layernorm_2 = LayerNorm(n_embd);
        linear_1 = Linear(n_embd, 4 * n_embd);
        linear_2 = Linear(4 * n_embd, n_embd);
        RegisterComponents();
    }

    public override Tensor forward(Tensor x)
    {
        using var _ = NewDisposeScope();

        // Attention sub-block with residual connection.
        var residue = x;
        x = layernorm_1.forward(x);
        x = attention.forward(x);
        x += residue;

        // Feed-forward sub-block with residual connection.
        residue = x;
        x = layernorm_2.forward(x);
        x = linear_1.forward(x);
        x = x * torch.sigmoid(1.702 * x);
        x = linear_2.forward(x);
        x += residue;
        return x.MoveToOuterDisposeScope();
    }
}

/// <summary>
/// CLIP text encoder as configured by the constructor: a 49408 x 768 token
/// table with 77 positions, 12 <see cref="CLIPLayer"/>s with 12 heads each,
/// and a final LayerNorm.
/// </summary>
internal class CLIP : Module<Tensor, Tensor>
{
    internal readonly CLIPEmbedding embedding;
    internal readonly ModuleList<Module<Tensor, Tensor>> layers;
    internal readonly LayerNorm layernorm;

    internal CLIP(bool useFlashAttention = false) : base("CLIP")
    {
        embedding = new CLIPEmbedding(49408, 768, 77);
        layers = nn.ModuleList<Module<Tensor, Tensor>>();
        for (int i = 0; i < 12; i++)
        {
            layers.Add(new CLIPLayer(12, 768, useFlashAttention: useFlashAttention));
        }
        layernorm = LayerNorm(768);
        RegisterComponents();
    }

    /// <summary>Runs the embedding, every layer in order, then the final LayerNorm.</summary>
    public override Tensor forward(Tensor token)
    {
        using var _ = NewDisposeScope();
        var state = embedding.forward(token);
        foreach (var layer in layers)
        {
            state = layer.forward(state);
        }
        var output = layernorm.forward(state);
        return output.MoveToOuterDisposeScope();
    }

    /// <summary>
    /// Loads the encoder weights from <paramref name="filename"/>.
    /// An extension containing "dat" is forwarded to the base BinaryReader
    /// loader; one containing "safetensor" or "pt" (which also matches
    /// ".ckpt") is read through the matching <c>ModelLoader</c>, using tensor
    /// names under <c>cond_stage_model.transformer.text_model.</c>.
    /// </summary>
    /// <returns>This module.</returns>
    /// <exception cref="NotSupportedException">The extension matches no known format.</exception>
    /// <exception cref="InvalidOperationException">A required tensor is missing from the file.</exception>
    public override Module load(string filename, bool strict = true, IList<string> skip = null, Dictionary<string, bool> loadedParameters = null)
    {
        string extension = Path.GetExtension(filename);

        if (extension.Contains("dat", StringComparison.OrdinalIgnoreCase))
        {
            // Read-only open: the weights are never written back.
            using FileStream fileStream = File.OpenRead(filename);
            using BinaryReader binaryReader = new BinaryReader(fileStream);
            return load(binaryReader, strict, skip, loadedParameters);
        }

        ModelLoader.IModelLoader modelLoader;
        if (extension.Contains("safetensor", StringComparison.OrdinalIgnoreCase))
        {
            modelLoader = new ModelLoader.SafetensorsLoader();
        }
        else if (extension.Contains("pt", StringComparison.OrdinalIgnoreCase))
        {
            modelLoader = new ModelLoader.PickleLoader();
        }
        else
        {
            // Previously fell through with a null loader and crashed on first use.
            throw new NotSupportedException($"Unsupported checkpoint extension '{extension}' for file '{filename}'.");
        }

        List<ModelLoader.Tensor> tensors = modelLoader.ReadTensorsInfoFromFile(filename);

        ModelLoader.Tensor Find(string name)
        {
            return tensors.FirstOrDefault(a => a.Name == name)
                ?? throw new InvalidOperationException($"Tensor '{name}' was not found in '{filename}'.");
        }

        // Reads one tensor, or several tensors concatenated in the given order.
        byte[] Read(params string[] names)
        {
            if (names.Length == 1)
            {
                return modelLoader.ReadByteFromFile(Find(names[0]));
            }

            byte[][] parts = new byte[names.Length][];
            int total = 0;
            for (int n = 0; n < names.Length; n++)
            {
                parts[n] = modelLoader.ReadByteFromFile(Find(names[n]));
                total += parts[n].Length;
            }

            byte[] joined = new byte[total];
            int at = 0;
            foreach (byte[] part in parts)
            {
                Buffer.BlockCopy(part, 0, joined, at, part.Length);
                at += part.Length;
            }
            return joined;
        }

        const string root = "cond_stage_model.transformer.text_model.";

        // Switch the whole module to the checkpoint's dtype before copying raw bytes in.
        this.to(Find(root + "embeddings.token_embedding.weight").Type);

        embedding.token_embedding.weight.bytes = Read(root + "embeddings.token_embedding.weight");
        embedding.position_value.bytes = Read(root + "embeddings.position_embedding.weight");
        layernorm.bias.bytes = Read(root + "final_layer_norm.bias");
        layernorm.weight.bytes = Read(root + "final_layer_norm.weight");

        for (int i = 0; i < layers.Count; i++)
        {
            CLIPLayer layer = (CLIPLayer)layers[i];
            string prefix = root + "encoder.layers." + i;

            layer.layernorm_1.weight.bytes = Read(prefix + ".layer_norm1.weight");
            layer.layernorm_1.bias.bytes = Read(prefix + ".layer_norm1.bias");
            layer.layernorm_2.weight.bytes = Read(prefix + ".layer_norm2.weight");
            layer.layernorm_2.bias.bytes = Read(prefix + ".layer_norm2.bias");

            layer.linear_1.weight.bytes = Read(prefix + ".mlp.fc1.weight");
            layer.linear_1.bias.bytes = Read(prefix + ".mlp.fc1.bias");
            layer.linear_2.weight.bytes = Read(prefix + ".mlp.fc2.weight");
            layer.linear_2.bias.bytes = Read(prefix + ".mlp.fc2.bias");

            // The checkpoint stores q/k/v separately; in_proj receives them back to back.
            layer.attention.in_proj.weight.bytes = Read(
                prefix + ".self_attn.q_proj.weight",
                prefix + ".self_attn.k_proj.weight",
                prefix + ".self_attn.v_proj.weight");
            layer.attention.in_proj.bias.bytes = Read(
                prefix + ".self_attn.q_proj.bias",
                prefix + ".self_attn.k_proj.bias",
                prefix + ".self_attn.v_proj.bias");

            layer.attention.out_proj.weight.bytes = Read(prefix + ".self_attn.out_proj.weight");
            layer.attention.out_proj.bias.bytes = Read(prefix + ".self_attn.out_proj.bias");
        }

        return this;
    }
}
/// <summary>
/// Reads tensor metadata and raw tensor bytes from a zip archive that
/// contains a protocol-2 pickle named <c>data.pkl</c> plus one zip entry per
/// tensor storage. The pickle is scanned opcode by opcode; it is never
/// executed.
/// </summary>
public class PickleLoader : IModelLoader, IDisposable
{
    private ZipArchive zip;
    private ReadOnlyCollection<ZipArchiveEntry> entries;

    /// <summary>
    /// Opens <paramref name="fileName"/> and returns one <see cref="Tensor"/>
    /// description (name, dtype, shape, stride, storage entry) per tensor
    /// found in <c>data.pkl</c>. The archive stays open for later
    /// <see cref="ReadByteFromFile"/> calls; call <see cref="Dispose"/> when done.
    /// </summary>
    /// <exception cref="ArgumentException">The header is not a protocol-2 pickle or is truncated.</exception>
    public List<Tensor> ReadTensorsInfoFromFile(string fileName)
    {
        List<Tensor> tensors = new List<Tensor>();

        // Release an archive left open by a previous call before replacing it.
        Dispose();
        zip = ZipFile.OpenRead(fileName);
        entries = zip.Entries;
        ZipArchiveEntry headerEntry = entries.First(e => e.Name == "data.pkl");
        // Header is always small enough to fit in memory, so we can read it all at once
        byte[] headerBytes = ReadEntry(headerEntry);

        if (headerBytes.Length < 2 || headerBytes[0] != 0x80 || headerBytes[1] != 0x02)
        {
            throw new ArgumentException("Not a valid pickle file");
        }

        int index = 1;
        bool finished = false;
        bool readStrides = false;
        bool binPersid = false;

        Tensor tensor = new Tensor() { FileName = fileName, Offset = { 0 } };

        int deepth = 0;

        // Memo id -> GLOBAL text, so a later BINGET can recover the storage class.
        Dictionary<int, string> BinPut = new Dictionary<int, string>();

        while (index < headerBytes.Length && !finished)
        {
            byte opcode = headerBytes[index];
            switch (opcode)
            {
                case (byte)'}':  // EMPTY_DICT
                    break;
                case (byte)']':  // EMPTY_LIST
                    break;
                case (byte)'h':  // BINGET, 1-byte memo id
                    {
                        int id = headerBytes[index + 1];
                        if (BinPut.TryGetValue(id, out string precision))
                        {
                            ApplyStorageType(tensor, precision);
                        }
                        index++;
                        break;
                    }
                case (byte)'j':  // LONG_BINGET, 4-byte memo id
                    {
                        // The 4 argument bytes must be skipped or they get parsed as opcodes.
                        int id = ReadInt32(headerBytes, index + 1);
                        if (BinPut.TryGetValue(id, out string precision))
                        {
                            ApplyStorageType(tensor, precision);
                        }
                        index += 4;
                        break;
                    }
                case (byte)'q':  // BINPUT, 1-byte memo id
                    {
                        index++;
                        break;
                    }
                case (byte)'Q':  // BINPERSID
                    binPersid = true;
                    break;
                case (byte)'r':  // LONG_BINPUT, 4-byte memo id
                    index += 4;
                    break;
                case 0x95:  // FRAME, 8-byte length
                    index += 8;
                    break;
                case 0x94:  // MEMOIZE
                    break;
                case (byte)'(':  // MARK
                    deepth++;
                    break;
                case (byte)'K':  // BININT1, 1-byte unsigned int
                    {
                        int value = headerBytes[index + 1];
                        index++;

                        if (deepth > 1 && value != 0 && binPersid)
                        {
                            if (readStrides)
                            {
                                tensor.Stride.Add((ulong)value);
                            }
                            else
                            {
                                tensor.Shape.Add(value);
                            }
                        }
                    }
                    break;
                case (byte)'M':  // BININT2, 2-byte unsigned int
                    {
                        UInt16 value = ReadUInt16(headerBytes, index + 1);
                        index += 2;

                        if (deepth > 1 && value != 0 && binPersid)
                        {
                            if (readStrides)
                            {
                                tensor.Stride.Add(value);
                            }
                            else
                            {
                                tensor.Shape.Add(value);
                            }
                        }
                    }
                    break;
                case (byte)'J':  // BININT, 4-byte signed int
                    {
                        int value = ReadInt32(headerBytes, index + 1);
                        index += 4;

                        if (deepth > 1 && value != 0 && binPersid)
                        {
                            if (readStrides)
                            {
                                tensor.Stride.Add((ulong)value);
                            }
                            else
                            {
                                tensor.Shape.Add(value);
                            }
                        }
                    }
                    break;

                case (byte)'X':  // BINUNICODE, 4-byte little-endian length + UTF-8 bytes
                    {
                        // The length is a full 32-bit value; reading only its low
                        // byte truncates long names and desynchronises the scan.
                        int length = ReadInt32(headerBytes, index + 1);
                        int start = index + 5;
                        string name = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
                        index = index + 4 + length;

                        if (deepth == 1)
                        {
                            tensor.Name = name;
                        }
                        else if (deepth == 3)
                        {
                            if ("cpu" != name && !name.Contains("cuda"))
                            {
                                tensor.DataNameInZipFile = name;
                            }
                        }
                    }
                    break;
                case 0x8C:  // SHORT_BINUNICODE (a protocol-4 opcode; argument is not consumed here)
                    {

                    }
                    break;
                case (byte)'c':  // GLOBAL, two newline-terminated strings (module, name)
                    {
                        int start = index + 1;
                        int end = start;
                        int newlines = 0;
                        while (end < headerBytes.Length && newlines < 2)
                        {
                            if (headerBytes[end] == (byte)'\n')
                            {
                                newlines++;
                            }
                            end++;
                        }
                        if (newlines < 2)
                        {
                            throw new ArgumentException("Not a valid pickle file");
                        }

                        // Text includes both trailing newlines.
                        string global = System.Text.Encoding.UTF8.GetString(headerBytes, start, end - start);

                        // Stay on the second newline; the loop's index++ moves to the next opcode.
                        index = end - 1;

                        // Remember the memo slot the pickler assigns to this global, if any.
                        if (end + 1 < headerBytes.Length && headerBytes[end] == (byte)'q')
                        {
                            BinPut[headerBytes[end + 1]] = global;
                        }
                        else if (end + 4 < headerBytes.Length && headerBytes[end] == (byte)'r')
                        {
                            BinPut[ReadInt32(headerBytes, end + 1)] = global;
                        }

                        // precision is stored in the global variable
                        // next tensor will read the precision
                        // so we can set the Type here
                        ApplyStorageType(tensor, global);
                        break;
                    }
                case 0x86:  // TUPLE2
                    {
                        if (binPersid)
                        {
                            readStrides = true;
                        }
                        break;
                    }
                case 0x85:  // TUPLE1
                    if (binPersid)
                    {
                        readStrides = true;
                    }
                    break;
                case (byte)'t':  // TUPLE, closes the innermost MARK
                    deepth--;
                    if (binPersid)
                    {
                        readStrides = true;
                    }
                    break;
                case (byte)'R':  // REDUCE
                    if (deepth == 1)
                    {
                        if (tensor.Name.Contains("metadata"))
                        {
                            break;
                        }

                        if (string.IsNullOrEmpty(tensor.DataNameInZipFile))
                        {
                            // No storage name seen: reuse the previous tensor's storage and
                            // turn the first collected number into a byte offset.
                            tensor.DataNameInZipFile = tensors.Last().DataNameInZipFile;
                            tensor.Offset = new List<ulong> { (ulong)(tensor.Shape[0] * tensor.Type.ElementSize()) };
                            tensor.Shape.RemoveAt(0);
                        }
                        tensors.Add(tensor);

                        tensor = new Tensor() { FileName = fileName, Offset = { 0 } };
                        readStrides = false;
                        binPersid = false;
                    }
                    break;
                case (byte)'.':  // STOP
                    finished = true;
                    break;
                default:
                    break;
            }
            index++;
        }

        Tensor metaTensor = tensors.Find(x => x.Name != null && x.Name.Contains("_metadata"));
        if (metaTensor != null)
        {
            tensors.Remove(metaTensor);
        }
        return tensors;
    }

    /// <summary>
    /// Returns the raw bytes of <paramref name="tensor"/>: product of its
    /// shape times the element size, starting at <c>tensor.Offset[0]</c>
    /// inside its storage entry.
    /// </summary>
    /// <exception cref="ArgumentNullException">No archive has been opened yet.</exception>
    /// <exception cref="InvalidDataException">The requested range lies outside the storage entry.</exception>
    public byte[] ReadByteFromFile(Tensor tensor)
    {
        if (entries is null)
        {
            throw new ArgumentNullException(nameof(entries));
        }

        ZipArchiveEntry dataEntry = entries.First(e => e.Name == tensor.DataNameInZipFile);
        long count = 1;
        foreach (var ne in tensor.Shape)
        {
            count *= ne;
        }
        long length = tensor.Type.ElementSize() * count;
        long offset = (long)tensor.Offset[0];

        byte[] data = ReadEntry(dataEntry);

        if (offset < 0 || length < 0 || offset + length > data.Length)
        {
            throw new InvalidDataException(
                $"Tensor '{tensor.Name}' needs bytes [{offset}, {offset + length}) but entry '{tensor.DataNameInZipFile}' has {data.Length}.");
        }

        if (offset == 0 && length == data.Length)
        {
            return data;
        }

        byte[] result = new byte[length];
        Array.Copy(data, offset, result, 0L, length);
        return result;
    }

    /// <summary>Closes the zip archive opened by <see cref="ReadTensorsInfoFromFile"/>.</summary>
    public void Dispose()
    {
        zip?.Dispose();
        zip = null;
        entries = null;
    }

    // Reads a whole zip entry. Stream.Read may return fewer bytes than asked
    // for, so keep reading until the buffer is full.
    private static byte[] ReadEntry(ZipArchiveEntry entry)
    {
        byte[] buffer = new byte[entry.Length];
        using (Stream stream = entry.Open())
        {
            int filled = 0;
            while (filled < buffer.Length)
            {
                int read = stream.Read(buffer, filled, buffer.Length - filled);
                if (read == 0)
                {
                    throw new EndOfStreamException($"Zip entry '{entry.FullName}' ended after {filled} of {buffer.Length} bytes.");
                }
                filled += read;
            }
        }
        return buffer;
    }

    // Pickle integers are little-endian regardless of the host CPU.
    private static int ReadInt32(byte[] bytes, int offset)
    {
        return System.Buffers.Binary.BinaryPrimitives.ReadInt32LittleEndian(bytes.AsSpan(offset));
    }

    private static UInt16 ReadUInt16(byte[] bytes, int offset)
    {
        return System.Buffers.Binary.BinaryPrimitives.ReadUInt16LittleEndian(bytes.AsSpan(offset));
    }

    // Maps a storage class name found in a GLOBAL to the tensor dtype; other text leaves the dtype alone.
    private static void ApplyStorageType(Tensor tensor, string global)
    {
        if (global.Contains("FloatStorage"))
        {
            tensor.Type = TorchSharp.torch.ScalarType.Float32;
        }
        else if (global.Contains("HalfStorage"))
        {
            tensor.Type = TorchSharp.torch.ScalarType.Float16;
        }
        else if (global.Contains("BFloat16Storage"))
        {
            tensor.Type = TorchSharp.torch.ScalarType.BFloat16;
        }
    }
}
/// <summary>
/// Copies checkpoint tensors into the parameters of the UNet and VAE building
/// blocks. Each method maps checkpoint names of the form
/// <c>name + suffix</c> onto one module's weights by assigning raw bytes.
/// </summary>
internal class LoadData
{
    /// <summary>Loads a UNet <c>ResidualBlock</c> from tensors under <paramref name="name"/>.</summary>
    public static void LoadResidualBlock(IModelLoader modelLoader, List<Tensor> tensorlist, ResidualBlock modules, string name)
    {
        void Set(TorchSharp.torch.Tensor target, params string[] suffixes) => Assign(modelLoader, tensorlist, target, name, suffixes);

        Set(modules.groupnorm_feature.weight, ".in_layers.0.weight");
        Set(modules.groupnorm_feature.bias, ".in_layers.0.bias");
        Set(modules.conv_feature.weight, ".in_layers.2.weight");
        Set(modules.conv_feature.bias, ".in_layers.2.bias");
        Set(modules.linear_time.weight, ".emb_layers.1.weight");
        Set(modules.linear_time.bias, ".emb_layers.1.bias");
        Set(modules.groupnorm_merged.weight, ".out_layers.0.weight");
        Set(modules.groupnorm_merged.bias, ".out_layers.0.bias");
        Set(modules.conv_merged.weight, ".out_layers.3.weight");
        Set(modules.conv_merged.bias, ".out_layers.3.bias");

        // The skip connection has weights only when it is a Conv2d rather than an identity.
        if (!modules.identity)
        {
            Set(((Conv2d)modules.residual_layer).weight, ".skip_connection.weight");
            Set(((Conv2d)modules.residual_layer).bias, ".skip_connection.bias");
        }
    }

    /// <summary>Loads a UNet <c>AttentionBlock</c> from tensors under <paramref name="name"/>.</summary>
    public static void LoadAttentionBlock(IModelLoader modelLoader, List<Tensor> tensorlist, AttentionBlock modules, string name)
    {
        void Set(TorchSharp.torch.Tensor target, params string[] suffixes) => Assign(modelLoader, tensorlist, target, name, suffixes);

        Set(modules.groupnorm.weight, ".norm.weight");
        Set(modules.groupnorm.bias, ".norm.bias");

        Set(modules.conv_input.weight, ".proj_in.weight");
        Set(modules.conv_input.bias, ".proj_in.bias");

        Set(modules.layernorm_1.weight, ".transformer_blocks.0.norm1.weight");
        Set(modules.layernorm_1.bias, ".transformer_blocks.0.norm1.bias");
        Set(modules.layernorm_2.weight, ".transformer_blocks.0.norm2.weight");
        Set(modules.layernorm_2.bias, ".transformer_blocks.0.norm2.bias");
        Set(modules.layernorm_3.weight, ".transformer_blocks.0.norm3.weight");
        Set(modules.layernorm_3.bias, ".transformer_blocks.0.norm3.bias");

        // attn1: q/k/v weights are stored separately and packed back to back into in_proj.
        Set(modules.attention_1.in_proj.weight,
            ".transformer_blocks.0.attn1.to_q.weight",
            ".transformer_blocks.0.attn1.to_k.weight",
            ".transformer_blocks.0.attn1.to_v.weight");
        Set(modules.attention_1.out_proj.weight, ".transformer_blocks.0.attn1.to_out.0.weight");
        Set(modules.attention_1.out_proj.bias, ".transformer_blocks.0.attn1.to_out.0.bias");

        // attn2: q/k/v stay separate projections.
        Set(modules.attention_2.q_proj.weight, ".transformer_blocks.0.attn2.to_q.weight");
        Set(modules.attention_2.k_proj.weight, ".transformer_blocks.0.attn2.to_k.weight");
        Set(modules.attention_2.v_proj.weight, ".transformer_blocks.0.attn2.to_v.weight");
        Set(modules.attention_2.out_proj.weight, ".transformer_blocks.0.attn2.to_out.0.weight");
        Set(modules.attention_2.out_proj.bias, ".transformer_blocks.0.attn2.to_out.0.bias");

        Set(modules.linear_geglu_1.weight, ".transformer_blocks.0.ff.net.0.proj.weight");
        Set(modules.linear_geglu_1.bias, ".transformer_blocks.0.ff.net.0.proj.bias");
        Set(modules.linear_geglu_2.weight, ".transformer_blocks.0.ff.net.2.weight");
        Set(modules.linear_geglu_2.bias, ".transformer_blocks.0.ff.net.2.bias");

        Set(modules.conv_output.weight, ".proj_out.weight");
        Set(modules.conv_output.bias, ".proj_out.bias");
    }

    /// <summary>Loads a VAE <c>ResidualBlockA</c> from tensors under <paramref name="name"/>.</summary>
    public static void LoadResidualBlockA(IModelLoader modelLoader, List<Tensor> tensorlist, ResidualBlockA modules, string name)
    {
        void Set(TorchSharp.torch.Tensor target, params string[] suffixes) => Assign(modelLoader, tensorlist, target, name, suffixes);

        Set(modules.conv_1.weight, ".conv1.weight");
        Set(modules.conv_1.bias, ".conv1.bias");
        Set(modules.conv_2.weight, ".conv2.weight");
        Set(modules.conv_2.bias, ".conv2.bias");
        Set(modules.groupnorm_1.weight, ".norm1.weight");
        Set(modules.groupnorm_1.bias, ".norm1.bias");
        Set(modules.groupnorm_2.weight, ".norm2.weight");
        Set(modules.groupnorm_2.bias, ".norm2.bias");

        // The shortcut has weights only when it is a Conv2d rather than an identity.
        if (!modules.identity)
        {
            Set(((Conv2d)modules.residual_layer).weight, ".nin_shortcut.weight");
            Set(((Conv2d)modules.residual_layer).bias, ".nin_shortcut.bias");
        }
    }

    /// <summary>Loads a VAE <c>AttentionBlockA</c> from tensors under <paramref name="name"/>.</summary>
    public static void LoadAttentionBlockA(IModelLoader modelLoader, List<Tensor> tensorlist, AttentionBlockA modules, string name)
    {
        void Set(TorchSharp.torch.Tensor target, params string[] suffixes) => Assign(modelLoader, tensorlist, target, name, suffixes);

        Set(modules.groupnorm.weight, ".attn_1.norm.weight");
        Set(modules.groupnorm.bias, ".attn_1.norm.bias");

        // q/k/v are stored separately and packed back to back into in_proj.
        Set(modules.attention.in_proj.weight, ".attn_1.q.weight", ".attn_1.k.weight", ".attn_1.v.weight");
        Set(modules.attention.in_proj.bias, ".attn_1.q.bias", ".attn_1.k.bias", ".attn_1.v.bias");

        Set(modules.attention.out_proj.weight, ".attn_1.proj_out.weight");
        Set(modules.attention.out_proj.bias, ".attn_1.proj_out.bias");
    }

    // Reads every tensor named prefix + suffix, in order, and assigns the
    // concatenated bytes to the target parameter.
    private static void Assign(IModelLoader modelLoader, List<Tensor> tensorlist, TorchSharp.torch.Tensor target, string prefix, string[] suffixes)
    {
        if (suffixes.Length == 1)
        {
            target.bytes = ReadOne(modelLoader, tensorlist, prefix + suffixes[0]);
            return;
        }

        byte[][] parts = new byte[suffixes.Length][];
        int total = 0;
        for (int i = 0; i < suffixes.Length; i++)
        {
            parts[i] = ReadOne(modelLoader, tensorlist, prefix + suffixes[i]);
            total += parts[i].Length;
        }

        byte[] joined = new byte[total];
        int at = 0;
        foreach (byte[] part in parts)
        {
            Buffer.BlockCopy(part, 0, joined, at, part.Length);
            at += part.Length;
        }
        target.bytes = joined;
    }

    // Looks a tensor up by exact name (first match wins) and reads its bytes.
    // Throws InvalidOperationException naming the tensor when it is absent.
    private static byte[] ReadOne(IModelLoader modelLoader, List<Tensor> tensorlist, string fullName)
    {
        Tensor info = tensorlist.FirstOrDefault(a => a.Name == fullName)
            ?? throw new InvalidOperationException($"Tensor '{fullName}' was not found in the checkpoint.");
        return modelLoader.ReadByteFromFile(info);
    }
}
/// <summary>
/// VAE attention block: GroupNorm, then self-attention over the flattened
/// spatial positions, with the block input added back at the end.
/// </summary>
public class AttentionBlockA : Module<Tensor, Tensor>
{
    internal readonly GroupNorm groupnorm;
    internal readonly SelfAttention attention;

    public AttentionBlockA(long channels, bool useFlashAttention = false) : base("AttentionBlockA")
    {
        groupnorm = GroupNorm(32, channels);
        // NOTE(review): 32 heads with causal_mask: true is unusual for an image
        // attention block; confirm against SelfAttention before relying on it.
        attention = new SelfAttention(32, channels, causal_mask: true, useFlashAtten: useFlashAttention);
        RegisterComponents();
    }

    /// <summary>
    /// Normalises the input, views it as (n, c, h * w), swaps the last two
    /// axes for attention, then restores (n, c, h, w) and adds the input.
    /// </summary>
    public override Tensor forward(Tensor x)
    {
        using var scope = NewDisposeScope();

        Tensor normed = groupnorm.forward(x);
        long batch = normed.shape[0];
        long channels = normed.shape[1];
        long height = normed.shape[2];
        long width = normed.shape[3];

        Tensor sequence = normed.view(batch, channels, height * width).transpose(-1, -2);
        Tensor attended = attention.forward(sequence);
        Tensor restored = attended.transpose(-1, -2).view(batch, channels, height, width);

        restored += x;
        return restored.MoveToOuterDisposeScope();
    }
}

/// <summary>
/// VAE residual block: two GroupNorm -> SiLU -> 3x3 Conv2d stages, plus a
/// shortcut that is an identity when the channel counts match and a 1x1
/// Conv2d otherwise.
/// </summary>
public class ResidualBlockA : Module<Tensor, Tensor>
{
    internal readonly GroupNorm groupnorm_1;
    internal readonly Conv2d conv_1;
    internal readonly GroupNorm groupnorm_2;
    internal readonly Conv2d conv_2;
    internal readonly Module<Tensor, Tensor> residual_layer;
    internal readonly bool identity;

    public ResidualBlockA(long in_channels, long out_channels) : base("ResidualBlockA")
    {
        // Construction order is kept as-is so parameter initialisation order does not change.
        groupnorm_1 = GroupNorm(32, in_channels);
        conv_1 = Conv2d(in_channels, out_channels, kernelSize: 3, padding: 1);
        groupnorm_2 = GroupNorm(32, out_channels);
        conv_2 = Conv2d(out_channels, out_channels, kernelSize: 3, padding: 1);

        identity = (in_channels == out_channels);
        if (identity)
        {
            residual_layer = nn.Identity();
        }
        else
        {
            residual_layer = Conv2d(in_channels, out_channels, kernelSize: 1);
        }
        RegisterComponents();
    }

    /// <summary>Returns the two-stage convolution path plus the shortcut of the input.</summary>
    public override Tensor forward(Tensor x)
    {
        using var scope = NewDisposeScope();

        Tensor hidden = groupnorm_1.forward(x);
        hidden = conv_1.forward(torch.nn.functional.silu(hidden));
        hidden = groupnorm_2.forward(hidden);
        hidden = conv_2.forward(torch.nn.functional.silu(hidden));

        Tensor shortcut = residual_layer.forward(x);
        Tensor output = hidden + shortcut;
        return output.MoveToOuterDisposeScope();
    }
}
public override Module load(string filename, bool strict = true, IList skip = null, Dictionary loadedParameters = null) 136 | { 137 | string extension = Path.GetExtension(filename); 138 | 139 | if (extension.Contains("dat")) 140 | { 141 | using (FileStream fileStream = new FileStream(filename, FileMode.Open)) 142 | { 143 | using (BinaryReader binaryReader = new BinaryReader(fileStream)) 144 | { 145 | return load(binaryReader, strict, skip, loadedParameters); 146 | } 147 | } 148 | } 149 | 150 | ModelLoader.IModelLoader modelLoader = null; 151 | if (extension.Contains("safetensor")) 152 | { 153 | modelLoader = new ModelLoader.SafetensorsLoader(); 154 | } 155 | else if (extension.Contains("pt")) 156 | { 157 | modelLoader = new ModelLoader.PickleLoader(); 158 | } 159 | 160 | List tensors = modelLoader.ReadTensorsInfoFromFile(filename); 161 | 162 | var t = tensors.First(a => a.Name == "first_stage_model.post_quant_conv.weight"); 163 | this.to(t.Type); 164 | 165 | byte[] data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.post_quant_conv.weight")); 166 | ((Conv2d)children().ToArray()[0]).weight.bytes = data; 167 | 168 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.post_quant_conv.bias")); 169 | ((Conv2d)children().ToArray()[0]).bias.bytes = data; 170 | 171 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.conv_in.weight")); 172 | ((Conv2d)children().ToArray()[1]).weight.bytes = data; 173 | 174 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.conv_in.bias")); 175 | ((Conv2d)children().ToArray()[1]).bias.bytes = data; 176 | 177 | // mid 178 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[2]), "first_stage_model.decoder.mid.block_1"); 179 | ModelLoader.LoadData.LoadAttentionBlockA(modelLoader, tensors, (AttentionBlockA)(children().ToArray()[3]), 
"first_stage_model.decoder.mid"); 180 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[4]), "first_stage_model.decoder.mid.block_2"); 181 | 182 | 183 | // first_stage_model.decoder.up.3 184 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[5]), "first_stage_model.decoder.up.3.block.0"); 185 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[6]), "first_stage_model.decoder.up.3.block.1"); 186 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[7]), "first_stage_model.decoder.up.3.block.2"); 187 | 188 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.3.upsample.conv.weight")); 189 | ((Conv2d)children().ToArray()[9]).weight.bytes = data; 190 | 191 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.3.upsample.conv.bias")); 192 | ((Conv2d)children().ToArray()[9]).bias.bytes = data; 193 | 194 | 195 | // first_stage_model.decoder.up.2 196 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[10]), "first_stage_model.decoder.up.2.block.0"); 197 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[11]), "first_stage_model.decoder.up.2.block.1"); 198 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[12]), "first_stage_model.decoder.up.2.block.2"); 199 | 200 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.2.upsample.conv.weight")); 201 | ((Conv2d)children().ToArray()[14]).weight.bytes = data; 202 | 203 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.2.upsample.conv.bias")); 204 | ((Conv2d)children().ToArray()[14]).bias.bytes = data; 205 | 
206 | 207 | // first_stage_model.decoder.up.1 208 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[15]), "first_stage_model.decoder.up.1.block.0"); 209 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[16]), "first_stage_model.decoder.up.1.block.1"); 210 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[17]), "first_stage_model.decoder.up.1.block.2"); 211 | 212 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.1.upsample.conv.weight")); 213 | ((Conv2d)children().ToArray()[19]).weight.bytes = data; 214 | 215 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.up.1.upsample.conv.bias")); 216 | ((Conv2d)children().ToArray()[19]).bias.bytes = data; 217 | 218 | 219 | // first_stage_model.decoder.up.0 220 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[20]), "first_stage_model.decoder.up.0.block.0"); 221 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[21]), "first_stage_model.decoder.up.0.block.1"); 222 | ModelLoader.LoadData.LoadResidualBlockA(modelLoader, tensors, (ResidualBlockA)(children().ToArray()[22]), "first_stage_model.decoder.up.0.block.2"); 223 | 224 | 225 | 226 | // out 227 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.norm_out.weight")); 228 | ((GroupNorm)children().ToArray()[23]).weight.bytes = data; 229 | 230 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.norm_out.bias")); 231 | ((GroupNorm)children().ToArray()[23]).bias.bytes = data; 232 | 233 | 234 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.conv_out.weight")); 235 | ((Conv2d)children().ToArray()[25]).weight.bytes 
= data; 236 | 237 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "first_stage_model.decoder.conv_out.bias")); 238 | ((Conv2d)children().ToArray()[25]).bias.bytes = data; 239 | 240 | return this; 241 | } 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /StableDiffusionTorchSharp/Diffusion.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch.nn; 3 | using static TorchSharp.torch; 4 | using TorchSharp.Modules; 5 | 6 | namespace StableDiffusionTorchSharp 7 | { 8 | public class TimeEmbedding : Module 9 | { 10 | internal readonly Linear linear_1; 11 | internal readonly Linear linear_2; 12 | 13 | public TimeEmbedding(long n_embd) : base("TimeEmbedding") 14 | { 15 | linear_1 = Linear(n_embd, 4 * n_embd); 16 | linear_2 = Linear(4 * n_embd, 4 * n_embd); 17 | RegisterComponents(); 18 | } 19 | 20 | public override Tensor forward(Tensor x) 21 | { 22 | using var _ = NewDisposeScope(); 23 | x = linear_1.forward(x); 24 | x = torch.nn.functional.silu(x); 25 | x = linear_2.forward(x); 26 | return x.MoveToOuterDisposeScope(); 27 | } 28 | } 29 | 30 | public class ResidualBlock : Module 31 | { 32 | internal readonly GroupNorm groupnorm_feature; 33 | internal readonly Conv2d conv_feature; 34 | internal readonly Linear linear_time; 35 | internal readonly GroupNorm groupnorm_merged; 36 | internal readonly Conv2d conv_merged; 37 | internal readonly Module residual_layer; 38 | internal readonly bool identity; 39 | 40 | public ResidualBlock(long in_channels, long out_channels, long n_time = 1280) : base("ResidualBlock") 41 | { 42 | groupnorm_feature = GroupNorm(32, in_channels); 43 | conv_feature = Conv2d(in_channels, out_channels, kernelSize: 3, padding: 1); 44 | linear_time = Linear(n_time, out_channels); 45 | groupnorm_merged = GroupNorm(32, out_channels); 46 | conv_merged = Conv2d(out_channels, out_channels, kernelSize: 3, 
padding: 1); 47 | identity = (in_channels == out_channels); 48 | if (in_channels == out_channels) 49 | { 50 | residual_layer = nn.Identity(); 51 | } 52 | else 53 | { 54 | residual_layer = Conv2d(in_channels, out_channels, kernelSize: 1); 55 | } 56 | RegisterComponents(); 57 | } 58 | 59 | public override Tensor forward(Tensor feature, Tensor time) 60 | { 61 | using var _ = NewDisposeScope(); 62 | var residue = feature; 63 | feature = groupnorm_feature.forward(feature); 64 | feature = torch.nn.functional.silu(feature); 65 | feature = conv_feature.forward(feature); 66 | 67 | time = torch.nn.functional.silu(time); 68 | time = linear_time.forward(time); 69 | 70 | var merged = feature + time.unsqueeze(-1).unsqueeze(-1); 71 | merged = groupnorm_merged.forward(merged); 72 | merged = torch.nn.functional.silu(merged); 73 | merged = conv_merged.forward(merged); 74 | 75 | var output = merged + residual_layer.forward(residue); 76 | return output.MoveToOuterDisposeScope(); 77 | } 78 | } 79 | 80 | public class AttentionBlock : Module 81 | { 82 | internal readonly GroupNorm groupnorm; 83 | internal readonly Conv2d conv_input; 84 | internal readonly LayerNorm layernorm_1; 85 | internal readonly SelfAttention attention_1; 86 | internal readonly LayerNorm layernorm_2; 87 | internal readonly CrossAttention attention_2; 88 | internal readonly LayerNorm layernorm_3; 89 | internal readonly Linear linear_geglu_1; 90 | internal readonly Linear linear_geglu_2; 91 | internal readonly Conv2d conv_output; 92 | 93 | public AttentionBlock(long n_head, long n_embd, long d_context = 768, bool useFlashAttention = false) : base("AttentionBlock") 94 | { 95 | var channels = n_head * n_embd; 96 | groupnorm = GroupNorm(32, channels); 97 | conv_input = Conv2d(channels, channels, kernelSize: 1); 98 | layernorm_1 = LayerNorm(channels); 99 | attention_1 = new SelfAttention(n_head, channels, in_proj_bias: false, useFlashAtten: useFlashAttention); 100 | layernorm_2 = LayerNorm(channels); 101 | attention_2 = 
new CrossAttention(n_head, channels, d_context, in_proj_bias: false); 102 | layernorm_3 = LayerNorm(channels); 103 | linear_geglu_1 = Linear(channels, 4 * channels * 2); 104 | linear_geglu_2 = Linear(4 * channels, channels); 105 | conv_output = Conv2d(channels, channels, kernelSize: 1); 106 | RegisterComponents(); 107 | } 108 | 109 | public override Tensor forward(Tensor x, Tensor context) 110 | { 111 | using var _ = NewDisposeScope(); 112 | var residue_long = x; 113 | x = groupnorm.forward(x); 114 | x = conv_input.forward(x); 115 | var n = x.shape[0]; 116 | var c = x.shape[1]; 117 | var h = x.shape[2]; 118 | var w = x.shape[3]; 119 | x = x.view(new long[] { n, c, h * w }); 120 | x = x.transpose(-1, -2); 121 | var residue_short = x; 122 | x = layernorm_1.forward(x); 123 | x = attention_1.forward(x); 124 | x += residue_short; 125 | residue_short = x; 126 | x = layernorm_2.forward(x); 127 | x = attention_2.forward(x, context); 128 | x += residue_short; 129 | residue_short = x; 130 | x = layernorm_3.forward(x); 131 | var ret = linear_geglu_1.forward(x).chunk(2, -1); 132 | x = ret[0]; 133 | var gate = ret[1]; 134 | x = x * torch.nn.functional.gelu(gate); 135 | x = linear_geglu_2.forward(x); 136 | x += residue_short; 137 | x = x.transpose(-1, -2); 138 | x = x.view(new long[] { n, c, h, w }); 139 | var output = conv_output.forward(x) + residue_long; 140 | return output.MoveToOuterDisposeScope(); 141 | } 142 | } 143 | 144 | public class Upsample : Module 145 | { 146 | internal readonly Conv2d conv; 147 | public Upsample(long channels) : base("Upsample") 148 | { 149 | conv = Conv2d(channels, channels, kernelSize: 3, padding: 1); 150 | RegisterComponents(); 151 | } 152 | 153 | public override Tensor forward(Tensor x) 154 | { 155 | using var _ = NewDisposeScope(); 156 | x = torch.nn.functional.interpolate(x, scale_factor: new double[] { 2, 2 }); 157 | var output = conv.forward(x); 158 | return output.MoveToOuterDisposeScope(); 159 | } 160 | } 161 | 162 | class 
SwitchSequential : Sequential 163 | { 164 | internal SwitchSequential(params (string name, torch.nn.Module)[] modules) : base(modules) 165 | { 166 | } 167 | 168 | internal SwitchSequential(params torch.nn.Module[] modules) : base(modules) 169 | { 170 | } 171 | 172 | public override torch.Tensor forward(torch.Tensor x, torch.Tensor context, torch.Tensor time) 173 | { 174 | using var _ = torch.NewDisposeScope(); 175 | foreach (var layer in children()) 176 | { 177 | switch (layer) 178 | { 179 | case AttentionBlock abl: 180 | x = abl.call(x, context); 181 | break; 182 | case ResidualBlock rbl: 183 | x = rbl.call(x, time); 184 | break; 185 | case torch.nn.Module m: 186 | x = m.call(x); 187 | break; 188 | } 189 | } 190 | return x.MoveToOuterDisposeScope(); 191 | } 192 | } 193 | 194 | public class UNet : Module 195 | { 196 | internal readonly ModuleList encoders; 197 | internal readonly SwitchSequential bottleneck; 198 | internal readonly ModuleList decoders; 199 | 200 | public UNet(bool useFlashAttention = false) : base("UNet") 201 | { 202 | encoders = new ModuleList( 203 | new SwitchSequential(Conv2d(4, 320, 3, padding: 1)), 204 | new SwitchSequential(new ResidualBlock(320, 320), new AttentionBlock(8, 40, useFlashAttention: useFlashAttention)), 205 | new SwitchSequential(new ResidualBlock(320, 320), new AttentionBlock(8, 40, useFlashAttention: useFlashAttention)), 206 | new SwitchSequential(Conv2d(320, 320, 3, stride: 2, padding: 1)), 207 | new SwitchSequential(new ResidualBlock(320, 640), new AttentionBlock(8, 80, useFlashAttention: useFlashAttention)), 208 | new SwitchSequential(new ResidualBlock(640, 640), new AttentionBlock(8, 80, useFlashAttention: useFlashAttention)), 209 | new SwitchSequential(Conv2d(640, 640, 3, stride: 2, padding: 1)), 210 | new SwitchSequential(new ResidualBlock(640, 1280), new AttentionBlock(8, 160, useFlashAttention: useFlashAttention)), 211 | new SwitchSequential(new ResidualBlock(1280, 1280), new AttentionBlock(8, 160, useFlashAttention: 
useFlashAttention)), 212 | new SwitchSequential(Conv2d(1280, 1280, 3, stride: 2, padding: 1)), 213 | new SwitchSequential(new ResidualBlock(1280, 1280)), 214 | new SwitchSequential(new ResidualBlock(1280, 1280)) 215 | ); 216 | bottleneck = new SwitchSequential( 217 | new ResidualBlock(1280, 1280), 218 | new AttentionBlock(8, 160, useFlashAttention: useFlashAttention), 219 | new ResidualBlock(1280, 1280) 220 | ); 221 | decoders = new ModuleList( 222 | new SwitchSequential(new ResidualBlock(2560, 1280)), 223 | new SwitchSequential(new ResidualBlock(2560, 1280)), 224 | new SwitchSequential(new ResidualBlock(2560, 1280), new Upsample(1280)), 225 | new SwitchSequential(new ResidualBlock(2560, 1280), new AttentionBlock(8, 160, useFlashAttention: useFlashAttention)), 226 | new SwitchSequential(new ResidualBlock(2560, 1280), new AttentionBlock(8, 160, useFlashAttention: useFlashAttention)), 227 | new SwitchSequential(new ResidualBlock(1920, 1280), new AttentionBlock(8, 160, useFlashAttention: useFlashAttention), new Upsample(1280)), 228 | new SwitchSequential(new ResidualBlock(1920, 640), new AttentionBlock(8, 80, useFlashAttention: useFlashAttention)), 229 | new SwitchSequential(new ResidualBlock(1280, 640), new AttentionBlock(8, 80, useFlashAttention: useFlashAttention)), 230 | new SwitchSequential(new ResidualBlock(960, 640), new AttentionBlock(8, 80, useFlashAttention: useFlashAttention), new Upsample(640)), 231 | new SwitchSequential(new ResidualBlock(960, 320), new AttentionBlock(8, 40, useFlashAttention: useFlashAttention)), 232 | new SwitchSequential(new ResidualBlock(640, 320), new AttentionBlock(8, 40, useFlashAttention: useFlashAttention)), 233 | new SwitchSequential(new ResidualBlock(640, 320), new AttentionBlock(8, 40, useFlashAttention: useFlashAttention)) 234 | ); 235 | RegisterComponents(); 236 | } 237 | 238 | public override Tensor forward(Tensor x, Tensor context, Tensor time) 239 | { 240 | using var _ = NewDisposeScope(); 241 | List skip_connections = 
new List(); 242 | foreach (var layers in encoders) 243 | { 244 | x = layers.forward(x, context, time); 245 | skip_connections.Add(x); 246 | } 247 | x = bottleneck.forward(x, context, time); 248 | foreach (var layers in decoders) 249 | { 250 | var index = skip_connections.Last(); 251 | x = torch.cat(new Tensor[] { x, index }, 1); 252 | skip_connections.RemoveAt(skip_connections.Count - 1); 253 | x = layers.forward(x, context, time); 254 | } 255 | return x.MoveToOuterDisposeScope(); 256 | } 257 | } 258 | 259 | public class FinalLayer : Module 260 | { 261 | internal readonly Conv2d conv; 262 | internal readonly GroupNorm groupnorm; 263 | 264 | public FinalLayer(long in_channels, long out_channels) : base("FinalLayer") 265 | { 266 | groupnorm = GroupNorm(32, in_channels); 267 | conv = Conv2d(in_channels, out_channels, kernelSize: 3, padding: 1); 268 | RegisterComponents(); 269 | } 270 | 271 | public override Tensor forward(Tensor x) 272 | { 273 | using var _ = NewDisposeScope(); 274 | x = groupnorm.forward(x); 275 | x = torch.nn.functional.silu(x); 276 | x = conv.forward(x); 277 | return x.MoveToOuterDisposeScope(); 278 | } 279 | } 280 | 281 | public class Diffusion : Module 282 | { 283 | internal readonly TimeEmbedding time_embedding; 284 | internal readonly UNet unet; 285 | internal readonly FinalLayer final; 286 | 287 | public Diffusion(bool useFlashAttention = false) : base("Diffusion") 288 | { 289 | time_embedding = new TimeEmbedding(320); 290 | unet = new UNet(useFlashAttention); 291 | final = new FinalLayer(320, 4); 292 | RegisterComponents(); 293 | } 294 | 295 | public override Tensor forward(Tensor latent, Tensor context, Tensor time) 296 | { 297 | using var _ = NewDisposeScope(); 298 | time = time_embedding.forward(time); 299 | var output = unet.forward(latent, context, time); 300 | output = final.forward(output); 301 | return output.MoveToOuterDisposeScope(); 302 | } 303 | 304 | public override Module load(string filename, bool strict = true, IList skip = 
null, Dictionary loadedParameters = null) 305 | { 306 | string extension = Path.GetExtension(filename); 307 | 308 | if (extension.Contains("dat")) 309 | { 310 | using (FileStream fileStream = new FileStream(filename, FileMode.Open)) 311 | { 312 | using (BinaryReader binaryReader = new BinaryReader(fileStream)) 313 | { 314 | return load(binaryReader, strict, skip, loadedParameters); 315 | } 316 | } 317 | } 318 | 319 | ModelLoader.IModelLoader modelLoader = null; 320 | if (extension.Contains("safetensor")) 321 | { 322 | modelLoader = new ModelLoader.SafetensorsLoader(); 323 | } 324 | else if (extension.Contains("pt")) 325 | { 326 | modelLoader = new ModelLoader.PickleLoader(); 327 | } 328 | 329 | List tensors = modelLoader.ReadTensorsInfoFromFile(filename); 330 | 331 | // Load Encoders 332 | var t = tensors.First(a => a.Name == "model.diffusion_model.time_embed.0.weight"); 333 | this.to(t.Type); 334 | 335 | byte[] data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.time_embed.0.weight")); 336 | time_embedding.linear_1.weight.bytes = data; 337 | 338 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.time_embed.0.bias")); 339 | time_embedding.linear_1.bias.bytes = data; 340 | 341 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.time_embed.2.weight")); 342 | time_embedding.linear_2.weight.bytes = data; 343 | 344 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.time_embed.2.bias")); 345 | time_embedding.linear_2.bias.bytes = data; 346 | 347 | 348 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.0.0.weight")); 349 | ((Conv2d)unet.encoders[0][0]).weight.bytes = data; 350 | 351 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.0.0.bias")); 352 | ((Conv2d)unet.encoders[0][0]).bias.bytes = data; 353 | 354 | 355 | 
data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.3.0.op.weight")); 356 | ((Conv2d)unet.encoders[3][0]).weight.bytes = data; 357 | 358 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.3.0.op.bias")); 359 | ((Conv2d)unet.encoders[3][0]).bias.bytes = data; 360 | 361 | 362 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.6.0.op.weight")); 363 | ((Conv2d)unet.encoders[6][0]).weight.bytes = data; 364 | 365 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.6.0.op.bias")); 366 | ((Conv2d)unet.encoders[6][0]).bias.bytes = data; 367 | 368 | 369 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.9.0.op.weight")); 370 | ((Conv2d)unet.encoders[9][0]).weight.bytes = data; 371 | 372 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.input_blocks.9.0.op.bias")); 373 | ((Conv2d)unet.encoders[9][0]).bias.bytes = data; 374 | 375 | 376 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[1][0], "model.diffusion_model.input_blocks.1.0"); 377 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[1][1], "model.diffusion_model.input_blocks.1.1"); 378 | 379 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[2][0], "model.diffusion_model.input_blocks.2.0"); 380 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[2][1], "model.diffusion_model.input_blocks.2.1"); 381 | 382 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[4][0], "model.diffusion_model.input_blocks.4.0"); 383 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[4][1], 
"model.diffusion_model.input_blocks.4.1"); 384 | 385 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[5][0], "model.diffusion_model.input_blocks.5.0"); 386 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[5][1], "model.diffusion_model.input_blocks.5.1"); 387 | 388 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[7][0], "model.diffusion_model.input_blocks.7.0"); 389 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[7][1], "model.diffusion_model.input_blocks.7.1"); 390 | 391 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[8][0], "model.diffusion_model.input_blocks.8.0"); 392 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.encoders[8][1], "model.diffusion_model.input_blocks.8.1"); 393 | 394 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[10][0], "model.diffusion_model.input_blocks.10.0"); 395 | 396 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.encoders[11][0], "model.diffusion_model.input_blocks.11.0"); 397 | 398 | // Load bottleneck 399 | 400 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.bottleneck[0], "model.diffusion_model.middle_block.0"); 401 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.bottleneck[1], "model.diffusion_model.middle_block.1"); 402 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.bottleneck[2], "model.diffusion_model.middle_block.2"); 403 | 404 | 405 | // Load decoders 406 | 407 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[0][0], "model.diffusion_model.output_blocks.0.0"); 408 | 409 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, 
(ResidualBlock)unet.decoders[1][0], "model.diffusion_model.output_blocks.1.0"); 410 | 411 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[2][0], "model.diffusion_model.output_blocks.2.0"); 412 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.2.1.conv.weight")); 413 | ((Upsample)unet.decoders[2][1]).conv.weight.bytes = data; 414 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.2.1.conv.bias")); 415 | ((Upsample)unet.decoders[2][1]).conv.bias.bytes = data; 416 | 417 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[3][0], "model.diffusion_model.output_blocks.3.0"); 418 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[3][1], "model.diffusion_model.output_blocks.3.1"); 419 | 420 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[4][0], "model.diffusion_model.output_blocks.4.0"); 421 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[4][1], "model.diffusion_model.output_blocks.4.1"); 422 | 423 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[5][0], "model.diffusion_model.output_blocks.5.0"); 424 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[5][1], "model.diffusion_model.output_blocks.5.1"); 425 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.5.2.conv.weight")); 426 | ((Upsample)unet.decoders[5][2]).conv.weight.bytes = data; 427 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.5.2.conv.bias")); 428 | ((Upsample)unet.decoders[5][2]).conv.bias.bytes = data; 429 | 430 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, 
(ResidualBlock)unet.decoders[6][0], "model.diffusion_model.output_blocks.6.0"); 431 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[6][1], "model.diffusion_model.output_blocks.6.1"); 432 | 433 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[7][0], "model.diffusion_model.output_blocks.7.0"); 434 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[7][1], "model.diffusion_model.output_blocks.7.1"); 435 | 436 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[8][0], "model.diffusion_model.output_blocks.8.0"); 437 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[8][1], "model.diffusion_model.output_blocks.8.1"); 438 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.8.2.conv.weight")); 439 | ((Upsample)unet.decoders[8][2]).conv.weight.bytes = data; 440 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.output_blocks.8.2.conv.bias")); 441 | ((Upsample)unet.decoders[8][2]).conv.bias.bytes = data; 442 | 443 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[9][0], "model.diffusion_model.output_blocks.9.0"); 444 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[9][1], "model.diffusion_model.output_blocks.9.1"); 445 | 446 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[10][0], "model.diffusion_model.output_blocks.10.0"); 447 | ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[10][1], "model.diffusion_model.output_blocks.10.1"); 448 | 449 | ModelLoader.LoadData.LoadResidualBlock(modelLoader, tensors, (ResidualBlock)unet.decoders[11][0], "model.diffusion_model.output_blocks.11.0"); 450 | 
ModelLoader.LoadData.LoadAttentionBlock(modelLoader, tensors, (AttentionBlock)unet.decoders[11][1], "model.diffusion_model.output_blocks.11.1"); 451 | 452 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.out.2.weight")); 453 | final.conv.weight.bytes = data; 454 | 455 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.out.2.bias")); 456 | final.conv.bias.bytes = data; 457 | 458 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.out.0.weight")); 459 | final.groupnorm.weight.bytes = data; 460 | 461 | data = modelLoader.ReadByteFromFile(tensors.First(a => a.Name == "model.diffusion_model.out.0.bias")); 462 | final.groupnorm.bias.bytes = data; 463 | 464 | 465 | return this; 466 | } 467 | 468 | } 469 | 470 | 471 | 472 | } 473 | --------------------------------------------------------------------------------