├── .gitignore
├── Activation
│   ├── GEGLU.cs
│   └── GELU.cs
├── Attention
│   ├── AttnProcessor.cs
│   ├── BasicTransformerBlock.cs
│   ├── CrossAttention.cs
│   └── FeedForward.cs
├── AttentionMaskConverter.cs
├── Clip
│   ├── CLIPAttention.cs
│   ├── CLIPEncoder.cs
│   ├── CLIPEncoderLayer.cs
│   ├── CLIPMLP.cs
│   ├── CLIPTextConfig.cs
│   ├── CLIPTextEmbeddings.cs
│   ├── CLIPTextModel.cs
│   └── CLIPTextTransformer.cs
├── Embedding
│   ├── ImagePositionalEmbeddings.cs
│   └── TimestepEmbedding.cs
├── Extension.cs
├── Globalusing.cs
├── IModelConfig.cs
├── Pipelines
│   └── StableDiffusionPipeline.cs
├── Program.cs
├── README.md
├── Scheduler
│   ├── DDIMScheduler.cs
│   └── DDIMSchedulerConfig.cs
├── Tests
│   ├── Approvals
│   │   ├── AutoEncoderKLTest.DecoderForwardTest.approved.txt
│   │   ├── AutoEncoderKLTest.DecoderShapeTest.approved.txt
│   │   ├── AutoEncoderKLTest.EncoderForwardTest.approved.txt
│   │   ├── AutoEncoderKLTest.Fp16ShapeTest.approved.txt
│   │   ├── AutoEncoderKLTest.ShapeTest.approved.txt
│   │   ├── CLIPTextModelTest.Fp16ShapeTest.approved.txt
│   │   ├── CLIPTextModelTest.Fp16TextModelForwardTest.approved.txt
│   │   ├── CLIPTextModelTest.ShapeTest.approved.txt
│   │   ├── CLIPTextModelTest.TextModelForwardTest.approved.txt
│   │   ├── DDIMSchedulerTest.StepTest.approved.txt
│   │   ├── StableDiffusionPipelineTest.GenerateCatImageTest.approved.txt
│   │   ├── UNet2DConditionModelTest.ForwardTest.approved.txt
│   │   ├── UNet2DConditionModelTest.Fp16ShapeTest.approved.txt
│   │   └── UNet2DConditionModelTest.ShapeTest.approved.txt
│   ├── AutoEncoderKL.test.cs
│   ├── CLIPTextModel.test.cs
│   ├── DDIMScheduler.test.cs
│   ├── StableDiffusionPipeline.test.cs
│   ├── Tokenizer.test.cs
│   └── UNet2DConditionModel.test.cs
├── Tokenizer.cs
├── Torchsharp-stable-diffusion-2.csproj
├── Torchsharp-stable-diffusion-2.sln
├── UNet
│   ├── AdaGroupNorm.cs
│   ├── CrossAttnDownBlock2D.cs
│   ├── CrossAttnUpBlock2D.cs
│   ├── DownBlock2D.cs
│   ├── DownEncoderBlock2D.cs
│   ├── Downsample2D.cs
│   ├── DualTransformer2DModel.cs
│   ├── ResnetBlock2D.cs
│   ├── ResnetBlockCondNorm2D.cs
│   ├── SpatialNorm.cs
│   ├── Timesteps.cs
│   ├── Transformer2DModel.cs
│   ├── UNet2DConditionModel.cs
│   ├── UNet2DConditionModelConfig.cs
│   ├── UNetMidBlock2D.cs
│   ├── UNetMidBlock2DCrossAttn.cs
│   ├── UpBlock2D.cs
│   ├── UpDecoderBlock2D.cs
│   └── Upsample2D.cs
├── Utils.cs
├── VAE
│   ├── AutoencoderKL.cs
│   ├── Config.cs
│   ├── Decoder.cs
│   ├── DiagonalGaussianDistribution.cs
│   └── Encoder.cs
└── img
    └── a photo of an astronaut riding a horse on mars.png
/.gitignore: -------------------------------------------------------------------------------- 1 | obj/ 2 | bin/ 3 | *.received.txt 4 | .mono/ -------------------------------------------------------------------------------- /Activation/GEGLU.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | 3 | public class GEGLU : Module 4 | { 5 | private readonly Linear proj; 6 | 7 | public GEGLU(int dim_in, int dim_out, bool bias = true, ScalarType dtype = ScalarType.Float32) 8 | : base("GEGLU") 9 | { 10 | this.proj = Linear(dim_in, dim_out * 2, bias, dtype: dtype); 11 | } 12 | 13 | public Tensor gelu(Tensor gate) 14 | { 15 | return functional.gelu(gate); 16 | } 17 | 18 | public override Tensor forward(Tensor hidden_states) 19 | { 20 | var chunks = this.proj.forward(hidden_states).chunk(2, -1); 21 | hidden_states = chunks[0]; 22 | var gate = chunks[1]; 23 | return hidden_states * this.gelu(gate); 24 | } 25 | } -------------------------------------------------------------------------------- /Activation/GELU.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | public class GELU : Module 3 | { 4 |
private readonly Linear proj; 5 | private readonly string approximate; 6 | 7 | public GELU( 8 | int dim_in, 9 | int dim_out, 10 | string approximate = "none", 11 | bool bias = true, 12 | ScalarType dtype = ScalarType.Float32) 13 | : base("GELU") 14 | { 15 | this.proj = Linear(dim_in, dim_out, bias, dtype: dtype); 16 | this.approximate = approximate; 17 | } 18 | 19 | public Tensor gelu(Tensor gate) 20 | { 21 | // todo 22 | // support approximate 23 | return functional.gelu(gate); 24 | } 25 | 26 | public override Tensor forward(Tensor hidden_states) 27 | { 28 | hidden_states = this.proj.forward(hidden_states); 29 | hidden_states = this.gelu(hidden_states); 30 | return hidden_states; 31 | } 32 | } -------------------------------------------------------------------------------- /Attention/AttnProcessor.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using SD; 3 | using static TorchSharp.torch; 4 | 5 | public abstract class AttnProcessorBase 6 | { 7 | abstract public Tensor Process( 8 | Attention attn, 9 | Tensor hidden_states, 10 | Tensor? encoder_hidden_states = null, 11 | Tensor? attention_mask = null, 12 | Tensor? temb = null); 13 | } 14 | 15 | public class AttnProcessor2_0 : AttnProcessorBase 16 | { 17 | public override Tensor Process( 18 | Attention attn, 19 | Tensor hidden_states, 20 | Tensor? encoder_hidden_states = null, 21 | Tensor? attention_mask = null, 22 | Tensor? temb = null) 23 | { 24 | var residual = hidden_states; 25 | if (attn.SpatialNorm is not null){ 26 | hidden_states = attn.SpatialNorm.forward(hidden_states, temb); 27 | } 28 | 29 | var input_ndim = hidden_states.ndim; 30 | int batch_size; 31 | long channel = 0; 32 | long height = 0; 33 | long width = 0; 34 | if (input_ndim == 4){ 35 | batch_size = (int)hidden_states.shape[0]; 36 | channel = hidden_states.shape[1]; 37 | height = hidden_states.shape[2]; 38 | width = hidden_states.shape[3]; 39 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2); 40 | } 41 | 42 | int sequence_length; 43 | if (encoder_hidden_states is not null){ 44 | batch_size = (int)encoder_hidden_states.shape[0]; 45 | sequence_length = (int)encoder_hidden_states.shape[1]; 46 | } 47 | else{ 48 | batch_size = (int)hidden_states.shape[0]; 49 | sequence_length = (int)hidden_states.shape[1]; 50 | } 51 | 52 | if (attention_mask is not null) 53 | { 54 | attention_mask = attn.PrepareAttentionMask(attention_mask, sequence_length, batch_size); 55 | attention_mask = attention_mask!.view(batch_size, attn.Heads, -1, attention_mask.shape[^1]); 56 | } 57 | 58 | if (attn.GroupNorm is not null) 59 | { 60 | hidden_states = attn.GroupNorm.forward(hidden_states.transpose(1, 2)).transpose(1, 2); 61 | } 62 | 63 | var query = attn.ToQ.forward(hidden_states); 64 | 65 | if (encoder_hidden_states is null) 66 | { 67 | encoder_hidden_states = hidden_states; 68 | } 69 | else if (attn.NormCross is not null) 70 | { 71 | encoder_hidden_states = attn.NormEncoderHiddenStates(encoder_hidden_states); 72 | } 73 | 74 | var key = attn.ToK.forward(encoder_hidden_states); 75 | var value = attn.ToV.forward(encoder_hidden_states); 76 | var inner_dim = key.shape[^1]; 77 | var head_dim = inner_dim / attn.Heads; 78 | query = query.view(batch_size, -1, attn.Heads, head_dim).transpose(1, 2); 79 | key = key.view(batch_size, -1, attn.Heads, head_dim).transpose(1, 2); 80 | value = value.view(batch_size, -1, attn.Heads, head_dim).transpose(1, 2); 81 | hidden_states = 
nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: attention_mask, p: 0, is_casual: false); 82 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.Heads * head_dim); 83 | hidden_states = hidden_states.to(query.dtype); 84 | 85 | // linear proj 86 | hidden_states = attn.ToOut[0].forward(hidden_states); 87 | // dropout 88 | hidden_states = attn.ToOut[1].forward(hidden_states); 89 | 90 | if (input_ndim == 4) 91 | { 92 | hidden_states = hidden_states.transpose(-1, -2).view(batch_size, channel, height, width); 93 | } 94 | 95 | if (attn.ResidualConnection) 96 | { 97 | hidden_states = hidden_states + residual; 98 | } 99 | 100 | hidden_states = hidden_states / attn.RescaleOutputFactor; 101 | 102 | return hidden_states; 103 | } 104 | } -------------------------------------------------------------------------------- /Attention/BasicTransformerBlock.cs: -------------------------------------------------------------------------------- 1 | using SD; 2 | 3 | public class BasicTransformerBlock : Module 4 | { 5 | private readonly int dim; 6 | private readonly int num_attention_heads; 7 | private readonly int attention_head_dim; 8 | private readonly double dropout; 9 | private readonly int? cross_attention_dim; 10 | private readonly string activation_fn; 11 | private readonly int? num_embeds_ada_norm; 12 | private readonly bool attention_bias; 13 | private readonly bool only_cross_attention; 14 | private readonly bool double_self_attention; 15 | private readonly bool upcast_attention; 16 | private readonly bool norm_elementwise_affine; 17 | private readonly string norm_type; 18 | private readonly double norm_eps; 19 | private readonly bool final_dropout; 20 | private readonly string attention_type; 21 | private readonly string? positional_embeddings; 22 | private readonly int? num_positional_embeddings; 23 | private readonly int? ada_norm_continous_conditioning_embedding_dim; 24 | private readonly int? ada_norm_bias; 25 | private readonly int? ff_inner_dim; 26 | private readonly bool ff_bias; 27 | private readonly bool attention_out_bias; 28 | private readonly bool use_ada_layer_norm_zero; 29 | private readonly bool use_ada_layer_norm; 30 | private readonly bool use_ada_layer_norm_single; 31 | private readonly bool use_layer_norm; 32 | private readonly bool use_ada_layer_norm_conitnuous; 33 | 34 | private readonly int? _chunk_size; 35 | private readonly int _chunk_dim; 36 | 37 | private readonly Module norm1; 38 | private readonly Attention attn1; 39 | 40 | private readonly Module? norm2 = null; 41 | private readonly Attention? attn2 = null; 42 | 43 | private readonly Module? norm3 = null; 44 | 45 | private readonly FeedForward ff; 46 | 47 | public BasicTransformerBlock( 48 | int dim, 49 | int num_attention_heads, 50 | int attention_head_dim, 51 | double dropout = 0.0, 52 | int? cross_attention_dim = null, 53 | string activation_fn = "geglu", 54 | int? num_embeds_ada_norm = null, 55 | bool attention_bias = false, 56 | bool only_cross_attention = false, 57 | bool double_self_attention = false, 58 | bool upcast_attention = false, 59 | bool norm_elementwise_affine = true, 60 | string norm_type = "layer_norm", 61 | double norm_eps = 1e-5, 62 | bool final_dropout = false, 63 | string attention_type = "default", 64 | string? positional_embeddings = null, 65 | int? num_positional_embeddings = null, 66 | int? ada_norm_continous_conditioning_embedding_dim = null, 67 | int? ada_norm_bias = null, 68 | int? 
ff_inner_dim = null, 69 | bool ff_bias = true, 70 | bool attention_out_bias = true, 71 | ScalarType dtype = ScalarType.Float32 72 | ) : base(nameof(BasicTransformerBlock)) 73 | { 74 | this.dim = dim; 75 | this.num_attention_heads = num_attention_heads; 76 | this.attention_head_dim = attention_head_dim; 77 | this.dropout = dropout; 78 | this.cross_attention_dim = cross_attention_dim; 79 | this.activation_fn = activation_fn; 80 | this.num_embeds_ada_norm = num_embeds_ada_norm; 81 | this.attention_bias = attention_bias; 82 | this.only_cross_attention = only_cross_attention; 83 | this.double_self_attention = double_self_attention; 84 | this.upcast_attention = upcast_attention; 85 | this.norm_elementwise_affine = norm_elementwise_affine; 86 | this.norm_type = norm_type; 87 | this.norm_eps = norm_eps; 88 | this.final_dropout = final_dropout; 89 | this.attention_type = attention_type; 90 | this.positional_embeddings = positional_embeddings; 91 | this.num_positional_embeddings = num_positional_embeddings; 92 | this.ada_norm_continous_conditioning_embedding_dim = ada_norm_continous_conditioning_embedding_dim; 93 | this.ada_norm_bias = ada_norm_bias; 94 | this.ff_inner_dim = ff_inner_dim; 95 | this.ff_bias = ff_bias; 96 | this.attention_out_bias = attention_out_bias; 97 | 98 | if (norm_type != "layer_norm") 99 | { 100 | throw new NotImplementedException("Only layer_norm is supported for now"); 101 | } 102 | 103 | this.use_ada_layer_norm_zero = false; 104 | this.use_ada_layer_norm = false; 105 | this.use_ada_layer_norm_single = false; 106 | this.use_layer_norm = true; 107 | this.use_ada_layer_norm_conitnuous = false; 108 | 109 | if (this.positional_embeddings is not null) 110 | { 111 | throw new NotImplementedException("Positional embeddings are not supported for now"); 112 | } 113 | 114 | this.norm1 = LayerNorm(dim, elementwise_affine: norm_elementwise_affine, eps: norm_eps, dtype: dtype); 115 | this.attn1 = new Attention( 116 | query_dim: dim, 117 | heads: num_attention_heads, 118 | dim_head: attention_head_dim, 119 | dropout: (float)dropout, 120 | bias: attention_bias, 121 | cross_attention_dim: only_cross_attention ? cross_attention_dim : null, 122 | upcast_attention: upcast_attention, 123 | out_bias: attention_out_bias, 124 | dtype: dtype); 125 | 126 | if (cross_attention_dim is not null || double_self_attention) 127 | { 128 | this.norm2 = LayerNorm(dim, elementwise_affine: norm_elementwise_affine, eps: norm_eps, dtype: dtype); 129 | this.attn2 = new Attention( 130 | query_dim: dim, 131 | cross_attention_dim: double_self_attention ? null : cross_attention_dim, 132 | heads: num_attention_heads, 133 | dim_head: attention_head_dim, 134 | dropout: (float)dropout, 135 | bias: attention_bias, 136 | upcast_attention: upcast_attention, 137 | out_bias: attention_out_bias, 138 | dtype: dtype); 139 | } 140 | 141 | if (norm_type == "layer_norm") 142 | { 143 | this.norm3 = LayerNorm(dim, elementwise_affine: norm_elementwise_affine, eps: norm_eps, dtype: dtype); 144 | } 145 | 146 | if (attention_type != "default") 147 | { 148 | throw new NotImplementedException("Only default attention is supported for now"); 149 | } 150 | 151 | this.ff = new FeedForward( 152 | dim: dim, 153 | dropout: dropout, 154 | activation_fn: activation_fn, 155 | final_dropout: final_dropout, 156 | inner_dim: ff_inner_dim, 157 | bias: ff_bias, 158 | dtype: dtype); 159 | 160 | this._chunk_size = null; 161 | this._chunk_dim = 0; 162 | } 163 | 164 | public override Tensor forward( 165 | Tensor hidden_states, 166 | Tensor? 
attention_mask = null, 167 | Tensor? encoder_hidden_states = null, 168 | Tensor? encoder_attention_mask = null, 169 | Tensor? timestep = null) 170 | { 171 | // self-attention 172 | var batch_size = hidden_states.shape[0]; 173 | 174 | var norm_hidden_states = this.norm1.forward(hidden_states); 175 | var attn_output = this.attn1.forward( 176 | norm_hidden_states, 177 | encoder_hidden_states: this.only_cross_attention ? encoder_hidden_states : null, 178 | attention_mask: attention_mask); 179 | 180 | hidden_states = hidden_states + attn_output; 181 | 182 | if (hidden_states.ndim == 4) 183 | { 184 | hidden_states = hidden_states.squeeze(1); 185 | } 186 | 187 | // cross-attention 188 | if (this.attn2 is not null) 189 | { 190 | norm_hidden_states = this.norm2!.forward(hidden_states); 191 | attn_output = this.attn2.forward( 192 | norm_hidden_states, 193 | encoder_hidden_states: encoder_hidden_states, 194 | attention_mask: encoder_attention_mask); 195 | 196 | hidden_states = hidden_states + attn_output; 197 | } 198 | 199 | // feed-forward 200 | norm_hidden_states = this.norm3!.forward(hidden_states); 201 | var ff_output = this.ff.forward(norm_hidden_states); 202 | 203 | hidden_states = hidden_states + ff_output; 204 | if(hidden_states.ndim == 4) 205 | { 206 | hidden_states = hidden_states.squeeze(1); 207 | } 208 | 209 | return hidden_states; 210 | } 211 | } -------------------------------------------------------------------------------- /Attention/FeedForward.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | public class FeedForward : Module 3 | { 4 | private readonly Module dropout; 5 | private readonly Linear linear_cls; 6 | private readonly ModuleList> net; 7 | 8 | public FeedForward( 9 | int dim, 10 | int? dim_out = null, 11 | int mult = 4, 12 | double dropout = 0.0, 13 | string activation_fn = "geglu", 14 | bool final_dropout = false, 15 | int? inner_dim = null, 16 | bool bias = true, 17 | ScalarType dtype = ScalarType.Float32) 18 | : base(nameof(FeedForward)) 19 | { 20 | inner_dim = inner_dim ?? (int)(dim * mult); 21 | dim_out = dim_out ?? dim; 22 | var act_fn = new GEGLU(dim, inner_dim.Value, bias, dtype: dtype); 23 | if (activation_fn == "geglu") 24 | { 25 | act_fn = new GEGLU(dim, inner_dim.Value, bias, dtype: dtype); 26 | } 27 | else 28 | { 29 | throw new NotImplementedException("Only GEGLU is supported for now"); 30 | } 31 | 32 | net = new ModuleList>(); 33 | // project in 34 | net.Add(act_fn); 35 | // project dropout 36 | net.Add(nn.Dropout(dropout)); 37 | // project out 38 | net.Add(Linear(inner_dim.Value, dim_out.Value, bias, dtype: dtype)); 39 | // FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 40 | if (final_dropout) 41 | { 42 | net.Add(nn.Dropout(dropout)); 43 | } 44 | 45 | RegisterComponents(); 46 | } 47 | 48 | public override Tensor forward(Tensor hidden_states) 49 | { 50 | foreach (var module in net) 51 | { 52 | hidden_states = module.forward(hidden_states); 53 | } 54 | return hidden_states; 55 | } 56 | } -------------------------------------------------------------------------------- /AttentionMaskConverter.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class AttentionMaskConverter 9 | { 10 | private readonly bool is_casual; 11 | private readonly int? 
sliding_window; 12 | 13 | public AttentionMaskConverter(bool is_casual, int? sliding_window) 14 | { 15 | this.is_casual = is_casual; 16 | this.sliding_window = sliding_window; 17 | } 18 | 19 | public Tensor? ToCasual4D( 20 | int batch_size, 21 | int query_length, 22 | int key_value_length, 23 | ScalarType dtype, 24 | Device device) 25 | { 26 | if (!is_casual) 27 | { 28 | throw new ArgumentException("This is not a casual mask"); 29 | } 30 | 31 | long[] input_shape = [batch_size, query_length]; 32 | var past_key_values_length = key_value_length - query_length; 33 | 34 | // create causal mask 35 | // [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 36 | Tensor? casual_4d_mask = null; 37 | if (query_length > 1 || this.sliding_window is int window) 38 | { 39 | casual_4d_mask = MakeCasualMask(input_shape, dtype, device, past_key_values_length, this.sliding_window); 40 | } 41 | 42 | return casual_4d_mask; 43 | } 44 | 45 | public static Tensor MakeCasualMask( 46 | long[] input_ids_shape, 47 | ScalarType dtype, 48 | Device device, 49 | int past_key_values_length = 0, 50 | int? sliding_window = null) 51 | { 52 | // Make causal mask used for bi-directional self-attention. 53 | var bsz = input_ids_shape[0]; 54 | var tgt_len = input_ids_shape[1]; 55 | var min = dtype switch 56 | { 57 | ScalarType.Float32 => torch.finfo(dtype).min, 58 | ScalarType.Float64 => torch.finfo(dtype).min, 59 | ScalarType.Float16 => -65504.0, 60 | _ => throw new ArgumentException("Invalid dtype"), 61 | }; 62 | var mask = torch.full([tgt_len, tgt_len], min, dtype: dtype, device: device); 63 | var mask_cond = torch.arange(tgt_len, device: device); 64 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0); 65 | mask = mask.to(dtype); 66 | 67 | 68 | if (past_key_values_length > 0) 69 | { 70 | mask = torch.cat([torch.zeros([tgt_len, past_key_values_length], dtype: dtype, device: device), mask], dim: -1); 71 | } 72 | 73 | if (sliding_window is int window) 74 | { 75 | var diagonal = past_key_values_length - window - 1; 76 | var context_mask = torch.tril(torch.ones([tgt_len, tgt_len], dtype: ScalarType.Bool, device: device), diagonal: diagonal); 77 | mask = mask.masked_fill(context_mask, min); 78 | } 79 | 80 | // return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 81 | 82 | return mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, tgt_len, tgt_len + past_key_values_length); 83 | } 84 | 85 | /// 86 | /// Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` 87 | /// 88 | /// The input shape should be a tuple that defines `(batch_size, query_length)`. 89 | public static Tensor? Create4DCasualAttentionMask( 90 | long[] input_shape, 91 | ScalarType dtype, 92 | Device device, 93 | int past_key_values_length = 0, 94 | int? sliding_window = null) 95 | { 96 | var batch_size = (int)input_shape[0]; 97 | var query_length = (int)input_shape[1]; 98 | var converter = new AttentionMaskConverter(is_casual: true, sliding_window: sliding_window); 99 | var key_value_length = past_key_values_length + query_length; 100 | return converter.ToCasual4D(batch_size, query_length, key_value_length, dtype, device); 101 | } 102 | 103 | public static Tensor ExpandMask( 104 | Tensor mask, 105 | ScalarType dtype, 106 | int? tgt_len = null) 107 | { 108 | var bsz = (int)mask.shape[0]; 109 | var src_len = (int)mask.shape[1]; 110 | tgt_len = tgt_len ?? 
src_len; 111 | 112 | var expanded_mask = mask.unsqueeze(1).unsqueeze(1).expand(bsz, 1, tgt_len.Value, src_len).to(dtype); 113 | var inverted_mask = 1.0 - expanded_mask; 114 | var min = dtype switch 115 | { 116 | ScalarType.Float32 => torch.finfo(dtype).min, 117 | ScalarType.Float64 => torch.finfo(dtype).min, 118 | ScalarType.Float16 => -65504.0, 119 | _ => throw new ArgumentException("Invalid dtype"), 120 | }; 121 | 122 | return inverted_mask.masked_fill(inverted_mask.to(ScalarType.Bool), min); 123 | } 124 | } -------------------------------------------------------------------------------- /Clip/CLIPAttention.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class CLIPAttention : Module 9 | { 10 | private readonly CLIPTextConfig config; 11 | private readonly int embed_dim; 12 | private readonly int num_heads; 13 | private readonly int head_dim; 14 | private readonly float scale; 15 | private readonly double dropout; 16 | private readonly Linear k_proj; 17 | private readonly Linear v_proj; 18 | private readonly Linear q_proj; 19 | private readonly Linear out_proj; 20 | private readonly ScalarType dtype; 21 | 22 | public CLIPAttention( 23 | CLIPTextConfig config) 24 | : base(nameof(CLIPAttention)) 25 | { 26 | this.config = config; 27 | this.embed_dim = config.HiddenSize; 28 | this.num_heads = config.NumAttentionHeads; 29 | this.head_dim = this.embed_dim / this.num_heads; 30 | this.dtype = config.DType; 31 | if (this.head_dim * this.num_heads != this.embed_dim) 32 | { 33 | throw new ArgumentException("embed_dim must be divisible by num_heads"); 34 | } 35 | 36 | this.scale = 1.0f / MathF.Sqrt(this.head_dim); 37 | this.dropout = config.AttentionDropout; 38 | 39 | this.k_proj = Linear(this.embed_dim, this.embed_dim, dtype: dtype); 40 | this.v_proj = Linear(this.embed_dim, this.embed_dim, dtype: dtype); 41 | this.q_proj = Linear(this.embed_dim, this.embed_dim, dtype: dtype); 42 | this.out_proj = Linear(this.embed_dim, this.embed_dim, dtype: dtype); 43 | 44 | RegisterComponents(); 45 | } 46 | 47 | public override (Tensor, Tensor?) forward( 48 | Tensor hidden_states, 49 | Tensor? attention_mask = null, 50 | Tensor? causal_attention_mask = null, 51 | bool? 
output_attentions = false) 52 | { 53 | // shape of hidden_states: (bsz, time, channel) 54 | var bsz = (int)hidden_states.shape[0]; 55 | var tgt_len = (int)hidden_states.shape[1]; 56 | var embed_dim = (int)hidden_states.shape[2]; 57 | 58 | // get query proj 59 | var query_states = this.q_proj.forward(hidden_states) * this.scale; 60 | var key_states = this._shape(this.k_proj.forward(hidden_states), -1, bsz); 61 | var value_states = this._shape(this.v_proj.forward(hidden_states), -1, bsz); 62 | 63 | long[] proj_shape = [bsz * this.num_heads, -1, this.head_dim]; 64 | query_states = this._shape(query_states, tgt_len, bsz).view(proj_shape); 65 | key_states = key_states.view(proj_shape); 66 | value_states = value_states.view(proj_shape); 67 | 68 | var src_len = key_states.shape[1]; 69 | var attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)); 70 | // attn_weights's shape: (bsz * num_heads, tgt_len, src_len) 71 | 72 | if (causal_attention_mask is not null) 73 | { 74 | // causal_attention_mask's shape: (bsz, 1, tgt_len, src_len) 75 | attn_weights = attn_weights.view(bsz, this.num_heads, tgt_len, src_len) + causal_attention_mask; 76 | attn_weights = attn_weights.view(bsz * this.num_heads, tgt_len, src_len); 77 | } 78 | 79 | if (attention_mask is not null) 80 | { 81 | // attention_mask's shape: (bsz, 1, tgt_len, src_len) 82 | attn_weights = attn_weights.view(bsz, this.num_heads, tgt_len, src_len) + attention_mask; 83 | attn_weights = attn_weights.view(bsz * this.num_heads, tgt_len, src_len); 84 | } 85 | 86 | attn_weights = attn_weights.softmax(-1, dtype: this.config.DType); 87 | Tensor? attn_weights_reshaped = null; 88 | 89 | if (output_attentions == true) 90 | { 91 | // this operation is a bit awkward, but it's required to 92 | // make sure that attn_weights keeps its gradient.
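// (Shape note, illustrative, assuming bsz = 2, num_heads = 8, tgt_len = src_len = 77:
//  attn_weights [16, 77, 77] is viewed as [2, 8, 77, 77] for the caller, then viewed back
//  to [16, 77, 77] for the bmm below, so the returned 4-D tensor stays in the autograd graph.)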
93 | // In order to do so, attn_weights have to reshaped 94 | // twice and have to be reused in the following 95 | attn_weights_reshaped = attn_weights.view(bsz, this.num_heads, tgt_len, src_len); 96 | attn_weights = attn_weights_reshaped.view(bsz * this.num_heads, tgt_len, src_len); 97 | } 98 | 99 | var attn_probs = nn.functional.dropout(attn_weights, this.dropout, this.training); 100 | var attn_output = torch.bmm(attn_probs, value_states); 101 | 102 | // attn_output's shape: (bsz * num_heads, tgt_len, head_dim) 103 | attn_output = attn_output.view(bsz, this.num_heads, tgt_len, this.head_dim); 104 | attn_output = attn_output.transpose(1, 2); 105 | attn_output.Peek("attn_output"); 106 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim); 107 | attn_output = this.out_proj.forward(attn_output); 108 | 109 | return (attn_output, attn_weights_reshaped); 110 | } 111 | 112 | private Tensor _shape(Tensor tensor, int seq_len, int bsz) 113 | { 114 | return tensor.view(bsz, seq_len, this.num_heads, this.head_dim).permute(0, 2, 1, 3).contiguous(); 115 | } 116 | } -------------------------------------------------------------------------------- /Clip/CLIPEncoder.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class CLIPEncoder : Module 9 | { 10 | private readonly CLIPTextConfig config; 11 | private readonly ModuleList layers; 12 | private readonly bool gradient_checkpointing = false; 13 | 14 | public CLIPEncoder(CLIPTextConfig config) 15 | : base(nameof(CLIPEncoder)) 16 | { 17 | this.config = config; 18 | this.layers = new ModuleList(Enumerable.Range(0, config.NumHiddenLayers).Select(_ => new CLIPEncoderLayer(config)).ToArray()); 19 | RegisterComponents(); 20 | } 21 | 22 | public override BaseModelOutput forward( 23 | Tensor inputs_embeds, 24 | Tensor? attention_mask = null, 25 | Tensor? casual_attention_mask = null, 26 | bool? output_attentions = false, 27 | bool? output_hidden_states = false) 28 | { 29 | // inputs_embeds: [batch_size, seq_length, hidden_size] 30 | output_hidden_states = output_hidden_states ?? false; 31 | output_attentions = output_attentions ?? false; 32 | 33 | List? encoder_states = null; 34 | List? 
all_attentions = null; 35 | 36 | if (output_hidden_states is true) 37 | { 38 | encoder_states = new List(); 39 | } 40 | 41 | if (output_attentions is true) 42 | { 43 | all_attentions = new List(); 44 | } 45 | 46 | var hidden_states = inputs_embeds; 47 | foreach (var layer in layers) 48 | { 49 | if (encoder_states is not null) 50 | { 51 | encoder_states.Add(hidden_states); 52 | } 53 | (hidden_states, var attention_weight) = layer.forward(hidden_states, attention_mask, casual_attention_mask, output_attentions); 54 | 55 | if (all_attentions is not null && attention_weight is not null) 56 | { 57 | all_attentions.Add(attention_weight); 58 | } 59 | } 60 | 61 | if (encoder_states is not null) 62 | { 63 | encoder_states.Add(hidden_states); 64 | } 65 | 66 | return new BaseModelOutput(lastHiddenState: hidden_states, hiddenStates: encoder_states?.ToArray(), attentions: all_attentions?.ToArray()); 67 | } 68 | } -------------------------------------------------------------------------------- /Clip/CLIPEncoderLayer.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class CLIPEncoderLayer : Module 9 | { 10 | private readonly int embed_dim; 11 | private readonly CLIPAttention self_attn; 12 | private readonly LayerNorm layer_norm1; 13 | private readonly CLIPMLP mlp; 14 | private readonly LayerNorm layer_norm2; 15 | 16 | public CLIPEncoderLayer(CLIPTextConfig config) 17 | : base(nameof(CLIPEncoderLayer)) 18 | { 19 | this.embed_dim = config.HiddenSize; 20 | this.self_attn = new CLIPAttention(config); 21 | this.layer_norm1 = LayerNorm(embed_dim, eps: config.LayerNormEps, dtype: config.DType); 22 | this.mlp = new CLIPMLP(config); 23 | this.layer_norm2 = LayerNorm(embed_dim, eps: config.LayerNormEps, dtype: config.DType); 24 | 25 | RegisterComponents(); 26 | } 27 | 28 | public override (Tensor, Tensor?) forward( 29 | Tensor hidden_states, 30 | Tensor? attention_mask = null, 31 | Tensor? causal_attention_mask = null, 32 | bool?
output_attentions = false) 33 | { 34 | var residual = hidden_states; 35 | hidden_states = this.layer_norm1.forward(hidden_states); 36 | (hidden_states, var attention_weights) = this.self_attn.forward(hidden_states, attention_mask, causal_attention_mask, output_attentions); 37 | hidden_states.Peek("clip_encoder_layer_hidden_states"); 38 | hidden_states = hidden_states + residual; 39 | residual = hidden_states; 40 | hidden_states = this.layer_norm2.forward(hidden_states); 41 | hidden_states = this.mlp.forward(hidden_states); 42 | hidden_states = hidden_states + residual; 43 | if (output_attentions == true) 44 | { 45 | return (hidden_states, attention_weights); 46 | } 47 | else 48 | { 49 | return (hidden_states, null); 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /Clip/CLIPMLP.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class CLIPMLP : Module 8 | { 9 | private readonly CLIPTextConfig config; 10 | 11 | private readonly Linear fc1; 12 | private readonly Linear fc2; 13 | private readonly Module activation_fn; 14 | 15 | public CLIPMLP(CLIPTextConfig config) 16 | : base(nameof(CLIPMLP)) 17 | { 18 | this.config = config; 19 | this.activation_fn = Utils.GetActivation(config.HiddenAct); 20 | this.fc1 = Linear(config.HiddenSize, config.IntermediateSize, dtype: config.DType); 21 | this.fc2 = Linear(config.IntermediateSize, config.HiddenSize, dtype: config.DType); 22 | RegisterComponents(); 23 | } 24 | 25 | public override Tensor forward(Tensor hidden_states) 26 | { 27 | hidden_states = this.fc1.forward(hidden_states); 28 | hidden_states = this.activation_fn.forward(hidden_states); 29 | hidden_states = this.fc2.forward(hidden_states); 30 | 31 | return hidden_states; 32 | } 33 | } -------------------------------------------------------------------------------- /Clip/CLIPTextConfig.cs: -------------------------------------------------------------------------------- 1 | using System.Text.Json.Serialization; 2 | 3 | namespace SD; 4 | 5 | public class CLIPTextConfig 6 | { 7 | [JsonPropertyName("vocab_size")] 8 | public int VocabSize { get; set; } = 49408; 9 | 10 | [JsonPropertyName("hidden_size")] 11 | public int HiddenSize { get; set; } = 512; 12 | 13 | [JsonPropertyName("intermediate_size")] 14 | public int IntermediateSize { get; set; } = 2048; 15 | 16 | [JsonPropertyName("projection_dim")] 17 | public int ProjectionDim { get; set; } = 512; 18 | 19 | [JsonPropertyName("num_hidden_layers")] 20 | public int NumHiddenLayers { get; set; } = 12; 21 | 22 | [JsonPropertyName("num_attention_heads")] 23 | public int NumAttentionHeads { get; set; } = 8; 24 | 25 | [JsonPropertyName("max_position_embeddings")] 26 | public int MaxPositionEmbeddings { get; set; } = 77; 27 | 28 | [JsonPropertyName("hidden_act")] 29 | public string HiddenAct { get; set; } = "quick_gelu"; 30 | 31 | [JsonPropertyName("layer_norm_eps")] 32 | public double LayerNormEps { get; set; } = 1e-5; 33 | 34 | [JsonPropertyName("attention_dropout")] 35 | public double AttentionDropout { get; set; } = 0.0; 36 | 37 | [JsonPropertyName("initializer_range")] 38 | public double InitializerRange { get; set; } = 0.02; 39 | 40 | [JsonPropertyName("initializer_factor")] 41 | public double InitializerFactor { get; set; } = 1.0; 42 | 43 | [JsonPropertyName("pad_token_id")] 44 | public int PadTokenId { get; set; } = 1; 45 | 46 | 
[JsonPropertyName("bos_token_id")] 47 | public int BosTokenId { get; set; } = 49406; 48 | 49 | [JsonPropertyName("eos_token_id")] 50 | public int EosTokenId { get; set; } = 49407; 51 | 52 | [JsonPropertyName("use_attention_mask")] 53 | public bool UseAttentionMask { get; set; } = false; 54 | 55 | [JsonPropertyName("dtype")] 56 | public ScalarType DType { get; set; } = ScalarType.Float32; 57 | } -------------------------------------------------------------------------------- /Clip/CLIPTextEmbeddings.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class CLIPTextEmbeddings : Module 8 | { 9 | private readonly CLIPTextConfig config; 10 | private readonly Embedding token_embedding; 11 | private readonly Embedding position_embedding; 12 | 13 | public CLIPTextEmbeddings(CLIPTextConfig config) 14 | : base(nameof(CLIPTextEmbeddings)) 15 | { 16 | this.config = config; 17 | var embed_dim = config.HiddenSize; 18 | token_embedding = Embedding(config.VocabSize, embed_dim, dtype: config.DType); 19 | position_embedding = Embedding(config.MaxPositionEmbeddings, embed_dim, dtype: config.DType); 20 | 21 | this.register_buffer("position_ids", arange(config.MaxPositionEmbeddings).expand(1, -1), persistent: false); 22 | 23 | RegisterComponents(); 24 | } 25 | 26 | public override Tensor forward( 27 | Tensor? input_ids = null, 28 | Tensor? position_ids = null, 29 | Tensor? inputs_embeds = null) 30 | { 31 | if (input_ids is null && position_ids is null && inputs_embeds is null) 32 | { 33 | throw new ArgumentException("You have to specify either input_ids or inputs_embeds"); 34 | } 35 | var seq_length = input_ids is not null ? input_ids.shape[^1] : inputs_embeds!.shape[^2]; 36 | var device = input_ids?.device ?? position_ids?.device ?? inputs_embeds?.device ?? throw new ArgumentException("You have to specify either input_ids or inputs_embeds"); 37 | if (position_ids is null) 38 | { 39 | position_ids = this.get_buffer("position_ids")[.., ..(int)seq_length]; 40 | position_ids = position_ids.to(device); 41 | } 42 | 43 | if (inputs_embeds is null) 44 | { 45 | inputs_embeds = this.token_embedding.forward(input_ids!); 46 | } 47 | 48 | var position_embeds = this.position_embedding.forward(position_ids); 49 | return inputs_embeds + position_embeds; 50 | } 51 | } -------------------------------------------------------------------------------- /Clip/CLIPTextModel.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | using System.Text.Json; 6 | using TorchSharp.PyBridge; 7 | 8 | namespace SD; 9 | 10 | public class CLIPTextModel : Module, IModelConfigLoader 11 | { 12 | private readonly CLIPTextConfig config; 13 | private readonly CLIPTextTransformer text_model; 14 | 15 | public CLIPTextModel(CLIPTextConfig config) 16 | : base(nameof(CLIPTextModel)) 17 | { 18 | this.config = config; 19 | this.text_model = new CLIPTextTransformer(config); 20 | this.PostInit(); 21 | RegisterComponents(); 22 | } 23 | 24 | private void PostInit() 25 | { 26 | var factor = this.config.InitializerFactor; 27 | } 28 | 29 | public override BaseModelOutputWithPooling forward( 30 | Tensor input_ids, 31 | Tensor? attention_mask = null, 32 | Tensor? position_ids = null, 33 | bool? output_hidden_states = false, 34 | bool? 
output_attentions = false) 35 | { 36 | return this.text_model.forward(input_ids, attention_mask, position_ids, output_hidden_states, output_attentions); 37 | } 38 | 39 | public static CLIPTextModel FromPretrained( 40 | string pretrainedModelNameOrPath, 41 | string configName = "config.json", 42 | string modelWeightName = "model", 43 | bool useSafeTensor = true, 44 | ScalarType torchDtype = ScalarType.Float32) 45 | { 46 | var configPath = Path.Combine(pretrainedModelNameOrPath, configName); 47 | var json = File.ReadAllText(configPath); 48 | var config = JsonSerializer.Deserialize(json) ?? throw new ArgumentNullException(nameof(CLIPTextConfig)); 49 | config.DType = torchDtype; 50 | 51 | var clipTextModel = new CLIPTextModel(config); 52 | 53 | modelWeightName = (useSafeTensor, torchDtype) switch 54 | { 55 | (true, ScalarType.Float32) => $"{modelWeightName}.safetensors", 56 | (true, ScalarType.Float16) => $"{modelWeightName}.fp16.safetensors", 57 | (false, ScalarType.Float32) => $"{modelWeightName}.bin", 58 | (false, ScalarType.Float16) => $"{modelWeightName}.fp16.bin", 59 | _ => throw new ArgumentException("Invalid arguments for useSafeTensor and torchDtype") 60 | }; 61 | 62 | var location = Path.Combine(pretrainedModelNameOrPath, modelWeightName); 63 | 64 | var loadedParameters = new Dictionary(); 65 | clipTextModel.load_safetensors(location, strict: false, loadedParameters: loadedParameters); 66 | 67 | return clipTextModel; 68 | } 69 | 70 | public CLIPTextModel LoadFromModelConfig( 71 | string pretrainedModelNameOrPath, 72 | string configName = "config.json", 73 | string modelWeightName = "diffusion_pytorch_model", 74 | bool useSafeTensor = true, 75 | ScalarType torchDtype = ScalarType.Float32) 76 | { 77 | return CLIPTextModel.FromPretrained(pretrainedModelNameOrPath, configName, modelWeightName, useSafeTensor, torchDtype); 78 | } 79 | } -------------------------------------------------------------------------------- /Clip/CLIPTextTransformer.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class BaseModelOutput 9 | { 10 | public BaseModelOutput( 11 | Tensor lastHiddenState, 12 | Tensor[]? hiddenStates = null, 13 | Tensor[]? attentions = null) 14 | { 15 | LastHiddenState = lastHiddenState; 16 | HiddenStates = hiddenStates; 17 | Attentions = attentions; 18 | } 19 | 20 | public Tensor LastHiddenState { get; } 21 | 22 | public Tensor[]? HiddenStates { get; } 23 | 24 | public Tensor[]? Attentions { get; } 25 | } 26 | 27 | public class BaseModelOutputWithPooling : BaseModelOutput 28 | { 29 | public BaseModelOutputWithPooling( 30 | Tensor lastHiddenState, 31 | Tensor poolerOutput, 32 | Tensor[]? hiddenStates = null, 33 | Tensor[]? 
attentions = null) 34 | : base(lastHiddenState, hiddenStates, attentions) 35 | { 36 | PoolerOutput = poolerOutput; 37 | } 38 | 39 | public Tensor PoolerOutput { get; } 40 | } 41 | public class CLIPTextTransformer : Module 42 | { 43 | private readonly CLIPTextConfig config; 44 | private readonly CLIPTextEmbeddings embeddings; 45 | private readonly CLIPEncoder encoder; 46 | private readonly LayerNorm final_layer_norm; 47 | private readonly int eos_token_id; 48 | 49 | public CLIPTextTransformer(CLIPTextConfig config) 50 | : base(nameof(CLIPTextTransformer)) 51 | { 52 | this.config = config; 53 | this.embeddings = new CLIPTextEmbeddings(config); 54 | this.encoder = new CLIPEncoder(config); 55 | this.final_layer_norm = LayerNorm(config.HiddenSize, eps: config.LayerNormEps, dtype: config.DType); 56 | this.eos_token_id = config.EosTokenId; 57 | 58 | RegisterComponents(); 59 | } 60 | 61 | public override BaseModelOutputWithPooling forward( 62 | Tensor input_ids, 63 | Tensor? attention_mask = null, 64 | Tensor? position_ids = null, 65 | bool? output_attentions = false, 66 | bool? output_hidden_states = false) 67 | { 68 | output_attentions = output_attentions ?? false; 69 | output_hidden_states = output_hidden_states ?? false; 70 | 71 | var input_shape = input_ids.shape; 72 | input_ids = input_ids.view(-1, input_shape[^1]); 73 | var hidden_states = this.embeddings.forward(input_ids: input_ids, position_ids: position_ids); 74 | var casual_attention_mask = AttentionMaskConverter.Create4DCasualAttentionMask(input_shape, hidden_states.dtype, hidden_states.device); 75 | if (attention_mask is not null) 76 | { 77 | attention_mask = AttentionMaskConverter.ExpandMask(attention_mask, hidden_states.dtype); 78 | } 79 | 80 | hidden_states.Peek("hidden_states"); 81 | attention_mask?.Peek("attention_mask"); 82 | casual_attention_mask.Peek("casual_attention_mask"); 83 | var encoder_outputs = this.encoder.forward(hidden_states, attention_mask, casual_attention_mask, output_attentions, output_hidden_states); 84 | 85 | var last_hidden_state = encoder_outputs.LastHiddenState; 86 | last_hidden_state.Peek("last_hidden_state"); 87 | last_hidden_state = this.final_layer_norm.forward(last_hidden_state); 88 | Tensor pooled_output; 89 | if (this.eos_token_id == 2) 90 | { 91 | // The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. 
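// (Illustrative: with eos_token_id == 2 this branch pools via argmax over the raw token
//  ids, e.g. input_ids = [49406, 320, 1125, 49407, 49407] -> index 3, the first EOT token,
//  which works because EOT (49407) is the largest id in the stock CLIP vocabulary.)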
92 | // A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added 93 | // ------------------------------------------------------------ 94 | // text_embeds.shape = [batch_size, sequence_length, transformer.width] 95 | // take features from the eot embedding (eot_token is the highest number in each sequence) 96 | // casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 97 | pooled_output = last_hidden_state[ 98 | torch.arange(last_hidden_state.shape[0], device: last_hidden_state.device), 99 | input_ids.to(ScalarType.Int32).to(last_hidden_state.device).argmax(dim: 1) 100 | ]; 101 | } 102 | else 103 | { 104 | pooled_output = last_hidden_state[ 105 | torch.arange(last_hidden_state.shape[0], device: last_hidden_state.device), 106 | (input_ids.to(ScalarType.Int32).to(last_hidden_state.device) == this.eos_token_id).to(ScalarType.Int32).argmax(dim: -1) 107 | ]; 108 | } 109 | 110 | return new BaseModelOutputWithPooling(last_hidden_state, pooled_output, encoder_outputs.HiddenStates, encoder_outputs.Attentions); 111 | } 112 | } -------------------------------------------------------------------------------- /Embedding/ImagePositionalEmbeddings.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class ImagePositionalEmbeddings : Module 9 | { 10 | private readonly Embedding emb; 11 | private readonly Embedding height_emb; 12 | private readonly Embedding width_emb; 13 | private readonly int height; 14 | private readonly int width; 15 | private readonly int num_embed; 16 | private readonly int embed_dim; 17 | 18 | public ImagePositionalEmbeddings( 19 | int num_embed, 20 | int height, 21 | int width, 22 | int embed_dim, 23 | ScalarType dtype = ScalarType.Float32 24 | ) : base(nameof(ImagePositionalEmbeddings)) 25 | { 26 | this.height = height; 27 | this.width = width; 28 | this.num_embed = num_embed; 29 | this.embed_dim = embed_dim; 30 | 31 | this.emb = Embedding(num_embed, embed_dim, dtype: dtype); 32 | this.height_emb = Embedding(height, embed_dim, dtype: dtype); 33 | this.width_emb = Embedding(width, embed_dim, dtype: dtype); 34 | 35 | RegisterComponents(); 36 | } 37 | 38 | public override Tensor forward(Tensor index) 39 | { 40 | var emb = this.emb.forward(index); 41 | 42 | var height_emb = this.height_emb.forward(torch.arange(this.height, device: index.device).view(1, this.height)); 43 | 44 | // 1 x H x D -> 1 x H x 1 x D 45 | height_emb = height_emb.unsqueeze(2); 46 | 47 | var width_emb = this.width_emb.forward(torch.arange(this.width, device: index.device).view(1, this.width)); 48 | 49 | // 1 x W x D -> 1 x 1 x W x D 50 | width_emb = width_emb.unsqueeze(1); 51 | 52 | var pos_emb = height_emb + width_emb; 53 | 54 | // 1 x H x W x D -> 1 x L xD 55 | pos_emb = pos_emb.view(1, this.height * this.width, -1); 56 | 57 | emb = emb + pos_emb[.., 0..(int)emb.shape[1], 0..]; 58 | 59 | return emb; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /Embedding/TimestepEmbedding.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class TimestepEmbedding : Module 9 | { 10 | private readonly Linear 
linear_1; 11 | private readonly Linear linear_2; 12 | private readonly Linear? cond_proj = null; 13 | private readonly Module act; 14 | private readonly Module? post_act = null; 15 | 16 | public TimestepEmbedding( 17 | int in_channels, 18 | int time_embed_dim, 19 | string act_fn = "silu", 20 | int? out_dim = null, 21 | string? post_act_fn = null, 22 | int? cond_proj_dim = null, 23 | bool sample_proj_bias = true, 24 | ScalarType dtype = ScalarType.Float32) 25 | : base(nameof(TimestepEmbedding)) 26 | { 27 | this.linear_1 = Linear(in_channels, time_embed_dim, sample_proj_bias, dtype: dtype); 28 | 29 | if (cond_proj_dim is int proj_dim) 30 | { 31 | this.cond_proj = Linear(proj_dim, time_embed_dim, false, dtype: dtype); 32 | } 33 | 34 | this.act = Utils.GetActivation(act_fn); 35 | 36 | var time_embed_dim_out = out_dim ?? time_embed_dim; 37 | 38 | this.linear_2 = Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype: dtype); 39 | 40 | if (post_act_fn is string post_act_fn_str) 41 | { 42 | this.post_act = Utils.GetActivation(post_act_fn_str); 43 | } 44 | 45 | RegisterComponents(); 46 | } 47 | 48 | public override Tensor forward(Tensor sample, Tensor? condition = null) 49 | { 50 | if (this.cond_proj is not null && condition is not null) 51 | { 52 | sample = sample + this.cond_proj.forward(condition); 53 | } 54 | 55 | sample = this.linear_1.forward(sample); 56 | 57 | if (this.act is not null) 58 | { 59 | sample = this.act.forward(sample); 60 | } 61 | 62 | sample = this.linear_2.forward(sample); 63 | 64 | if (this.post_act is not null) 65 | { 66 | sample = this.post_act.forward(sample); 67 | } 68 | 69 | return sample; 70 | } 71 | } -------------------------------------------------------------------------------- /Extension.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | using TorchSharp; 3 | using static TorchSharp.torch; 4 | 5 | public static class Extension 6 | { 7 | public static string Peek(this Tensor tensor, string id, int n = 10) 8 | { 9 | var device = tensor.device; 10 | var dtype = tensor.dtype; 11 | // if type is fp16, convert to fp32 12 | if (tensor.dtype == ScalarType.Float16) 13 | { 14 | tensor = tensor.to_type(ScalarType.Float32); 15 | } 16 | tensor = tensor.cpu(); 17 | var shapeString = string.Join(',', tensor.shape); 18 | var tensor_1d = tensor.reshape(-1); 19 | var tensor_index = torch.arange(tensor_1d.shape[0], dtype: ScalarType.Float32).to(tensor_1d.device).sqrt(); 20 | var avg = (tensor_1d * tensor_index).sum(); 21 | avg = avg / tensor_1d.sum(); 22 | // keep four decimal places 23 | avg = avg.round(4); 24 | var str = $"{id}: sum: {avg.ToSingle()} dtype: {dtype} shape: [{shapeString}]"; 25 | 26 | Console.WriteLine(str); 27 | 28 | return str; 29 | } 30 | 31 | public static string Peek(this nn.Module model) 32 | { 33 | var sb = new StringBuilder(); 34 | var state_dict = model.state_dict(); 35 | // preview state_dict 36 | int i = 0; 37 | foreach (var (key, value) in state_dict.OrderBy(x => x.Key, StringComparer.OrdinalIgnoreCase)) 38 | { 39 | var str = value.Peek(key); 40 | sb.AppendLine($"{i}: {str}"); 41 | i++; 42 | } 43 | 44 | var res = sb.ToString(); 45 | 46 | Console.WriteLine(res); 47 | 48 | return res; 49 | } 50 | 51 | public static string Peek_Shape(this nn.Module model) 52 | { 53 | var sb = new StringBuilder(); 54 | var state_dict = model.state_dict(); 55 | // preview state_dict 56 | int i = 0; 57 | foreach (var (key, value) in state_dict.OrderBy(x => x.Key, StringComparer.OrdinalIgnoreCase)) 58 | { 
59 | // shape str: [x, y, z] 60 | var shapeStr = string.Join(", ", value.shape); 61 | sb.AppendLine($"{i}: {key} shape: [{shapeStr}]"); 62 | i++; 63 | } 64 | 65 | var res = sb.ToString(); 66 | 67 | Console.WriteLine(res); 68 | 69 | return res; 70 | } 71 | 72 | public static void LoadStateDict(this Dictionary dict, string location) 73 | { 74 | using FileStream stream = File.OpenRead(location); 75 | using BinaryReader reader = new BinaryReader(stream); 76 | var num = reader.Decode(); 77 | Console.WriteLine($"num: {num}"); 78 | for (int i = 0; i < num; i++) 79 | { 80 | var key = reader.ReadString(); 81 | Tensor tensor = dict[key]; 82 | Console.WriteLine($"load key: {key} tensor: {tensor}"); 83 | 84 | var originalDevice = tensor.device; 85 | var originalType = tensor.dtype; 86 | if (tensor.dtype == ScalarType.BFloat16) 87 | { 88 | tensor = tensor.to_type(ScalarType.Float32); 89 | } 90 | 91 | TensorExtensionMethods.Load(ref tensor!, reader, skip: false); 92 | 93 | // convert type to bf16 if type is float 94 | tensor = tensor!.to_type(originalType); 95 | dict[key] = tensor; 96 | } 97 | } 98 | 99 | // 100 | // 摘要: 101 | // Decode a long value from a binary reader 102 | // 103 | // 参数: 104 | // reader: 105 | // A BinaryReader instance used for input. 106 | // 107 | // 返回结果: 108 | // The decoded value 109 | public static long Decode(this BinaryReader reader) 110 | { 111 | long num = 0L; 112 | int num2 = 0; 113 | while (true) 114 | { 115 | long num3 = reader.ReadByte(); 116 | num += (num3 & 0x7F) << num2 * 7; 117 | if ((num3 & 0x80) == 0L) 118 | { 119 | break; 120 | } 121 | 122 | num2++; 123 | } 124 | 125 | return num; 126 | } 127 | } -------------------------------------------------------------------------------- /Globalusing.cs: -------------------------------------------------------------------------------- 1 | global using static TorchSharp.torch.nn; 2 | global using static TorchSharp.torch; 3 | global using TorchSharp.Modules; 4 | global using TorchSharp; -------------------------------------------------------------------------------- /IModelConfig.cs: -------------------------------------------------------------------------------- 1 | using TorchSharp; 2 | using static TorchSharp.torch; 3 | 4 | public interface IModelConfigLoader 5 | { 6 | T LoadFromModelConfig( 7 | string pretrainedModelNameOrPath, 8 | string configName = "config.json", 9 | string modelWeightName = "diffusion_pytorch_model", 10 | bool useSafeTensor = true, 11 | ScalarType torchDtype = ScalarType.Float32); 12 | } -------------------------------------------------------------------------------- /Pipelines/StableDiffusionPipeline.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | public class StableDiffusionPipelineOutput 3 | { 4 | public StableDiffusionPipelineOutput(Tensor images) 5 | { 6 | Images = images; 7 | } 8 | 9 | /// 10 | /// The generated images. size (batch_size, ...). 
11 | /// 12 | public Tensor Images { get; } 13 | } 14 | public class StableDiffusionPipeline 15 | { 16 | private readonly ScalarType defaultDtype; 17 | private readonly int vae_scale_factor; 18 | private DeviceType device = DeviceType.CPU; 19 | 20 | public StableDiffusionPipeline( 21 | AutoencoderKL vae, 22 | CLIPTextModel text_encoder, 23 | BPETokenizer tokenizer, 24 | UNet2DConditionModel unet, 25 | DDIMScheduler scheduler, 26 | ScalarType dtype = ScalarType.Float32) // todo: safety checker, feature extractor and image encoder 27 | { 28 | this.defaultDtype = dtype; 29 | this.vae = vae; 30 | this.text_encoder = text_encoder; 31 | this.tokenizer = tokenizer; 32 | this.unet = unet; 33 | this.scheduler = scheduler; 34 | 35 | this.vae_scale_factor = Convert.ToInt32(Math.Pow(2, this.vae.Config.BlockOutChannels.Length - 1)); 36 | } 37 | 38 | public void To(DeviceType device) 39 | { 40 | if (device != this.device) 41 | { 42 | this.device = device; 43 | this.text_encoder.to(device); 44 | this.vae.to(device); 45 | this.unet.to(device); 46 | } 47 | } 48 | 49 | public AutoencoderKL vae { get; } 50 | public CLIPTextModel text_encoder { get; } 51 | public BPETokenizer tokenizer { get; } 52 | public UNet2DConditionModel unet { get; } 53 | public DDIMScheduler scheduler { get; } 54 | 55 | /// 56 | /// Run the stable diffusion pipeline. 57 | /// 58 | /// The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds` 59 | /// The height in pixels of the generated image. defaults to unet.sample_size * vae_scale_factor 60 | /// The width in pixels of the generated image. defaults to unet.sample_size * vae_scale_factor 61 | /// The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 62 | /// Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 63 | /// in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 64 | /// passed will be used. Must be in descending order. 65 | /// A higher guidance scale value encourages the model to generate images closely linked to the text 66 | /// `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 67 | /// The prompt or prompts to guide what not to include in image generation. Ignored when not using guidance (`guidance_scale < 1`). 68 | /// The number of images to generate per prompt. 69 | /// Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 70 | /// to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 71 | /// A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 72 | /// Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 73 | /// generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 74 | /// tensor is generated by sampling using the supplied random `generator`. 75 | /// Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 76 | /// provided, text embeddings are generated from the `prompt` input argument. 77 | /// Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 78 | /// If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
79 | /// Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 80 | /// Guidance rescale factor should fix overexposure when using zero terminal SNR. 81 | /// 82 | public StableDiffusionPipelineOutput Run( 83 | string? prompt = null, 84 | int? height = null, 85 | int? width = null, 86 | int num_inference_steps = 50, 87 | int[]? timesteps = null, 88 | float guidance_scale = 7.5f, 89 | string? negative_prompt = null, 90 | int num_images_per_prompt = 1, 91 | float eta = 0.0f, 92 | Generator? generator = null, 93 | Tensor? latents = null, 94 | Tensor? prompt_embeds = null, 95 | Tensor? negative_prompt_embeds = null, 96 | float guidance_rescale = 0.0f) 97 | { 98 | using var _ = torch.no_grad(); 99 | height = height ?? unet.Config.SampleSize * this.vae_scale_factor; 100 | width = width ?? unet.Config.SampleSize * this.vae_scale_factor; 101 | var do_classifier_free_guidance = guidance_scale > 1.0f && this.unet.Config.TimeCondProjDim is null; 102 | 103 | if (prompt is not null && prompt_embeds is not null) 104 | { 105 | throw new ArgumentException("Only one of `prompt` or `prompt_embeds` should be passed."); 106 | } 107 | 108 | if (negative_prompt is not null && negative_prompt_embeds is not null) 109 | { 110 | throw new ArgumentException("Only one of `negative_prompt` or `negative_prompt_embeds` should be passed."); 111 | } 112 | // todo 113 | // deal with lora 114 | 115 | int batch_size = 1; 116 | if (prompt_embeds is not null) 117 | { 118 | batch_size = prompt_embeds.IntShape()[0]; 119 | } 120 | 121 | prompt_embeds = this.EncodePrompt(batch_size, prompt, prompt_embeds, num_images_per_prompt); 122 | if (do_classifier_free_guidance) 123 | { 124 | if (negative_prompt is null && negative_prompt_embeds is null) 125 | { 126 | negative_prompt = ""; 127 | } 128 | 129 | negative_prompt_embeds = this.EncodePrompt(batch_size, negative_prompt, negative_prompt_embeds, num_images_per_prompt); 130 | 131 | // For classifier free guidance, we need to do two forward passes. 132 | // Here we concatenate the unconditional and text embeddings into a single batch 133 | // to avoid doing two forward passes 134 | 135 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]); 136 | } 137 | 138 | // prepare timesteps 139 | (var time_steps_tensor, num_inference_steps) = this.RetireveTimesteps(num_inference_steps, timesteps); 140 | 141 | // prepare latent variables 142 | var num_channels_latents = this.unet.Config.InChannels; 143 | latents = this.PrepareLatents( 144 | batch_size * num_images_per_prompt, 145 | num_channels_latents, 146 | width!.Value, 147 | height!.Value, 148 | device, 149 | dtype: this.defaultDtype, 150 | generator: generator, 151 | latents: latents); 152 | 153 | // denosing loop 154 | for(int i = 0; i!= num_inference_steps; i++) 155 | { 156 | var step = (int)time_steps_tensor[i].ToInt64(); 157 | Console.WriteLine($"Step {step}"); 158 | // expand the latents if we are doing classifier free guidance 159 | var latent_model_input = !do_classifier_free_guidance ? 
161 |             // predict noise residual
162 |             Tensor noise_pred;
163 |             using (var __ = NewDisposeScope())
164 |             {
165 |                 var unetInput = new UNet2DConditionModelInput(
166 |                     sample: latent_model_input,
167 |                     timestep: time_steps_tensor[i],
168 |                     encoderHiddenStates: prompt_embeds);
169 |                 noise_pred = this.unet.forward(unetInput).MoveToOuterDisposeScope();
170 |             }
171 |             latent_model_input.Peek("latent_model_input");
172 |             prompt_embeds.Peek("prompt_embeds");
173 |             noise_pred.Peek("noise_pred");
174 | 
175 |             if (do_classifier_free_guidance)
176 |             {
177 |                 var chunk = noise_pred.chunk(2, 0);
178 |                 var noise_pred_uncond = chunk[0];
179 |                 var noise_pred_text = chunk[1];
180 | 
181 |                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond);
182 |             }
183 | 
184 |             // compute the previous noisy sample x_t -> x_{t-1}
185 |             latents = this.scheduler.Step(
186 |                 noise_pred,
187 |                 step,
188 |                 latents).PrevSample;
189 |         }
190 | 
191 |         // decode to image tensor
192 |         var image_tensor = this.vae.decode(latents / this.vae.Config.ScalingFactor);
193 |         return new StableDiffusionPipelineOutput(image_tensor);
194 |     }
195 | 
196 |     public Tensor PrepareLatents(
197 |         int batch_size,
198 |         int num_channels_latents,
199 |         int width,
200 |         int height,
201 |         DeviceType device,
202 |         ScalarType dtype = ScalarType.Float32,
203 |         Generator? generator = null,
204 |         Tensor? latents = null)
205 |     {
206 |         long[] shape = [batch_size, num_channels_latents, height / this.vae_scale_factor, width / this.vae_scale_factor];
207 |         if (latents is null)
208 |         {
209 |             latents = torch.randn(shape, dtype: dtype, generator: generator).to(device);
210 |         }
211 |         else
212 |         {
213 |             latents = latents.to(dtype).to(device);
214 |         }
215 | 
216 |         latents = latents * this.scheduler.InitNoiseSigma;
217 |         return latents;
218 |     }
219 | 
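    // Illustrative note on PrepareLatents above, not in the original source: the image size
    // is mapped to the latent size by integer division with vae_scale_factor, e.g. a
    // 768 x 768 request with 4 latent channels yields a [1, 4, 96, 96] noise tensor. The
    // initial noise is then scaled by the scheduler's InitNoiseSigma, which is 1.0 for the
    // DDIM scheduler used here, so the multiplication is a no-op in this pipeline.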
220 |     /// <summary>
221 |     /// Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call
222 |     /// </summary>
223 |     /// <param name="num_inference_steps">The number of diffusion steps used when generating samples with a pre-trained model.
224 |     /// If used, `timesteps` must be `null`.</param>
225 |     /// <param name="timesteps">Custom timesteps used to support arbitrary spacing between timesteps. If `null`, then the default
226 |     /// timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
227 |     /// must be `null`.</param>
228 |     /// <returns>A tuple where the first element is the timestep schedule from the scheduler and the
229 |     /// second element is the number of inference steps.</returns>
230 |     public (Tensor, int) RetrieveTimesteps(
231 |         int? num_inference_steps = null,
232 |         int[]? timesteps = null)
233 |     {
234 |         if (num_inference_steps is not null && timesteps is not null)
235 |         {
236 |             throw new ArgumentException("Only one of `num_inference_steps` or `timesteps` should be passed.");
237 |         }
238 | 
239 |         if (num_inference_steps is null && timesteps is null)
240 |         {
241 |             throw new ArgumentException("Either `num_inference_steps` or `timesteps` must be passed.");
242 |         }
243 | 
244 |         if (num_inference_steps is not null)
245 |         {
246 |             this.scheduler.SetTimesteps(num_inference_steps.Value);
247 | 
248 |             return (this.scheduler.TimeSteps, this.scheduler.TimeSteps.IntShape()[0]);
249 |         }
250 |         else
251 |         {
252 |             this.scheduler.SetTimesteps(timesteps: timesteps);
253 |             return (this.scheduler.TimeSteps, timesteps!.Length);
254 |         }
255 |     }
256 | 
257 |     public Tensor EncodePrompt(
258 |         int batch_size,
259 |         string? prompt = null,
260 |         Tensor? prompt_embeds = null,
261 |         int num_images_per_prompt = 1)
262 |     {
263 |         if (prompt is null && prompt_embeds is null)
264 |         {
265 |             throw new ArgumentException("Either `prompt` or `prompt_embeds` must be passed.");
266 |         }
267 | 
268 |         if (prompt is not null && prompt_embeds is not null)
269 |         {
270 |             throw new ArgumentException("Only one of `prompt` or `prompt_embeds` should be passed.");
271 |         }
272 | 
273 |         if (prompt is string)
274 |         {
275 |             // todo
276 |             // enable attention_mask in tokenizer
277 | 
278 |             var text_inputs_id = this.tokenizer.Encode(prompt, true, true, padding: "max_length", maxLength: tokenizer.ModelMaxLength);
279 |             var id_tensor = torch.tensor(text_inputs_id, dtype: ScalarType.Int64).reshape(1, -1).to(this.device);
280 |             var output = this.text_encoder.forward(id_tensor, attention_mask: null);
281 |             prompt_embeds = output.LastHiddenState;
282 |         }
283 | 
284 |         var seq_len = prompt_embeds!.IntShape()[1];
285 |         // duplicate text embeddings for each generation per prompt, using mps friendly method
286 |         prompt_embeds = prompt_embeds!.repeat([1, num_images_per_prompt, 1]);
287 |         prompt_embeds = prompt_embeds.reshape(batch_size * num_images_per_prompt, seq_len, -1);
288 | 
289 |         return prompt_embeds;
290 |     }
291 | 
292 |     public static StableDiffusionPipeline FromPretrained(
293 |         string modelWeightFolder,
294 |         string vaeFolder = "vae",
295 |         string textModelFolder = "text_encoder",
296 |         string schedulerFolder = "scheduler",
297 |         string unetFolder = "unet",
298 |         string tokenizerFolder = "tokenizer",
299 |         ScalarType torchDtype = ScalarType.Float32)
300 |     {
301 |         var unetModelPath = Path.Join(modelWeightFolder, unetFolder);
302 |         var tokenizerModelPath = Path.Join(modelWeightFolder, tokenizerFolder);
303 |         var textModelPath = Path.Join(modelWeightFolder, textModelFolder);
304 |         var schedulerModelPath = Path.Join(modelWeightFolder, schedulerFolder);
305 |         var vaeModelPath = Path.Join(modelWeightFolder, vaeFolder);
306 |         var tokenizer = BPETokenizer.FromPretrained(tokenizerModelPath);
307 |         var clipTextModel = CLIPTextModel.FromPretrained(textModelPath, torchDtype: torchDtype);
308 |         var unet = UNet2DConditionModel.FromPretrained(unetModelPath, torchDtype: torchDtype);
309 |         var vae = AutoencoderKL.FromPretrained(vaeModelPath, torchDtype: torchDtype);
310 |         var scheduler = DDIMScheduler.FromPretrained(schedulerModelPath);
311 | 
312 |         var pipeline = new StableDiffusionPipeline(
313 |             vae: vae,
314 |             text_encoder: clipTextModel,
315 |             unet: unet,
316 |             tokenizer: tokenizer,
317 |             scheduler: scheduler,
318 |             dtype: torchDtype);
319 | 
320 |         return pipeline;
321 |     }
322 | }
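// Usage sketch (illustrative; not part of the original file). The folder path and the
// prompts are placeholders; see Program.cs below for the actual entry point:
//
//     var pipeline = StableDiffusionPipeline.FromPretrained("/path/to/stable-diffusion-2");
//     pipeline.To(DeviceType.CUDA);
//     var output = pipeline.Run(
//         prompt: "a photo of a cat",
//         negative_prompt: "blurry, low resolution",
//         num_inference_steps: 50,
//         guidance_scale: 7.5f);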
-------------------------------------------------------------------------------- /Program.cs: --------------------------------------------------------------------------------
1 | using System.Runtime.InteropServices;
2 | using TorchSharp;
3 | using SD;
4 | var dtype = ScalarType.Float16;
5 | var device = DeviceType.CUDA;
6 | var outputFolder = "img";
7 | torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
8 | 
9 | if (!Directory.Exists(outputFolder))
10 | {
11 |     Directory.CreateDirectory(outputFolder);
12 | }
13 | 
14 | // Uncomment the following two lines to load libtorch from a local PyTorch install, or install the torchsharp-cuda package if your machine supports CUDA 12
15 | // var libTorch = "/home/xiaoyuz/diffusers/venv/lib/python3.8/site-packages/torch/lib/libtorch.so";
16 | // NativeLibrary.Load(libTorch);
17 | torch.InitializeDeviceType(device);
18 | if (!torch.cuda.is_available())
19 | {
20 |     device = DeviceType.CPU;
21 | }
22 | 
23 | var input = "a photo of cat chasing after dog";
24 | var modelFolder = @"C:\Users\xiaoyuz\source\repos\stable-diffusion-2\";
25 | var pipeline = StableDiffusionPipeline.FromPretrained(modelFolder, torchDtype: dtype);
26 | pipeline.To(device);
27 | 
28 | var output = pipeline.Run(
29 |     prompt: input,
30 |     width: 1020,
31 |     height: 768,
32 |     num_inference_steps: 50
33 | );
34 | 
35 | var decoded_images = torch.clamp((output.Images + 1.0) / 2.0, 0.0, 1.0);
36 | 
37 | for (int i = 0; i != decoded_images.shape[0]; ++i)
38 | {
39 |     var savedPath = Path.Join(outputFolder, $"{i}.png");
40 |     var image = decoded_images[i];
41 |     image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
42 |     torchvision.io.write_image(image, savedPath, torchvision.ImageFormat.Png);
43 | 
44 |     Console.WriteLine($"saved image to {savedPath}, enjoy");
45 | }
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ## Torchsharp Stable Diffusion 2
2 | 
3 | This repo contains a torchsharp implementation of the [stable diffusion 2 model](https://github.com/Stability-AI/stablediffusion).
4 | 
5 | ## Quick Start
6 | To run the stable diffusion 2 model on your local machine, the following prerequisites are required:
7 | - dotnet 8
8 | - git lfs, which is used to download the model weights from Hugging Face
9 | 
10 | ### Step 1: Get the model weights from Hugging Face
11 | To get the stable-diffusion-2 model weights, run the following command to download them from Hugging Face. Be sure to have git lfs installed.
12 | ```bash
13 | git clone https://huggingface.co/stabilityai/stable-diffusion-2
14 | ```
15 | > [!Note]
16 | > To load the fp32 model weights into GPU memory, at least 16GB of GPU memory is recommended if you want to generate 768 * 768 images. Loading the fp16 model weights requires around 8GB of GPU memory.
17 | 
18 | ### Step 2: Run the model
19 | Clone this repo and replace the `modelFolder` value in [Program.cs](./Program.cs#L25) with the folder where you downloaded the Hugging Face model weights.
20 | 
21 | Then run the following command to start the model:
22 | ```bash
23 | dotnet run
24 | ```
25 | 
26 | ### Example output
27 | ![a photo of an astronaut riding a horse on mars](./img/a%20photo%20of%20an%20astronaut%20riding%20a%20horse%20on%20mars.png)
28 | (a photo of an astronaut riding a horse on mars)
29 | 
30 | ### Load fp16 model weight for faster and more GPU memory efficient inference
31 | You can load the fp16 model weights by setting `dtype` to `ScalarType.Float16` in [Program.cs](./Program.cs#L4). Inference with the fp16 weights is faster and more GPU memory efficient.
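For example, this is a one-line change; the snippet below mirrors line 4 of Program.cs:

```csharp
var dtype = ScalarType.Float16;
```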
32 | 
33 | > [!Note]
34 | > The fp16 model only works on GPU, because some operators don't support fp16 on CPU.
35 | 
36 | ### Update log
37 | #### Update on 2024/04/03
38 | - Add support for loading fp16 model weights
39 | ### See also
40 | - [Torchsharp-llama](https://github.com/LittleLittleCloud/Torchsharp-llama): A torchsharp implementation of the llama 2 model
41 | - [Torchsharp-phi](https://github.com/LittleLittleCloud/Torchsharp-phi): A torchsharp implementation of the phi model
42 | 
-------------------------------------------------------------------------------- /Scheduler/DDIMSchedulerConfig.cs: --------------------------------------------------------------------------------
1 | using System.Text.Json.Serialization;
2 | 
3 | namespace SD;
4 | 
5 | public class DDIMSchedulerConfig
6 | {
7 |     public DDIMSchedulerConfig(
8 |         int numTrainTimesteps = 1000,
9 |         float betaStart = 0.0001f,
10 |         float betaEnd = 0.02f,
11 |         string betaSchedule = "linear",
12 |         float[]? trainedBetas = null,
13 |         bool clipSample = true,
14 |         bool setAlphaToOne = true,
15 |         int stepsOffset = 0,
16 |         string predictionType = "epsilon",
17 |         bool thresholding = false,
18 |         float dynamicThresholdingRatio = 0.995f,
19 |         float clipSampleRange = 1.0f,
20 |         float sampleMaxValue = 1.0f,
21 |         string timestepSpacing = "leading",
22 |         bool rescaleBetasZeroSnr = false)
23 |     {
24 |         NumTrainTimesteps = numTrainTimesteps;
25 |         BetaStart = betaStart;
26 |         BetaEnd = betaEnd;
27 |         BetaSchedule = betaSchedule;
28 |         TrainedBetas = trainedBetas;
29 |         ClipSample = clipSample;
30 |         SetAlphaToOne = setAlphaToOne;
31 |         StepsOffset = stepsOffset;
32 |         PredictionType = predictionType;
33 |         Thresholding = thresholding;
34 |         DynamicThresholdingRatio = dynamicThresholdingRatio;
35 |         ClipSampleRange = clipSampleRange;
36 |         SampleMaxValue = sampleMaxValue;
37 |         TimestepSpacing = timestepSpacing;
38 |         RescaleBetasZeroSnr = rescaleBetasZeroSnr;
39 |     }
40 | 
41 |     [JsonPropertyName("num_train_timesteps")]
42 |     public int NumTrainTimesteps { get; set; } = 1000;
43 | 
44 |     [JsonPropertyName("beta_start")]
45 |     public float BetaStart { get; set; } = 0.0001f;
46 | 
47 |     [JsonPropertyName("beta_end")]
48 |     public float BetaEnd { get; set; } = 0.02f;
49 | 
50 |     [JsonPropertyName("beta_schedule")]
51 |     public string BetaSchedule { get; set; } = "linear";
52 | 
53 |     [JsonPropertyName("trained_betas")]
54 |     public float[]?
TrainedBetas { get; set; } 55 | 56 | [JsonPropertyName("clip_sample")] 57 | public bool ClipSample { get; set; } = true; 58 | 59 | [JsonPropertyName("set_alpha_to_one")] 60 | public bool SetAlphaToOne { get; set; } = true; 61 | 62 | [JsonPropertyName("steps_offset")] 63 | public int StepsOffset { get; set; } = 0; 64 | 65 | [JsonPropertyName("prediction_type")] 66 | public string PredictionType { get; set; } = "epsilon"; 67 | 68 | [JsonPropertyName("thresholding")] 69 | public bool Thresholding { get; set; } = false; 70 | 71 | [JsonPropertyName("dynamic_thresholding_ratio")] 72 | public float DynamicThresholdingRatio { get; set; } = 0.995f; 73 | 74 | [JsonPropertyName("clip_sample_range")] 75 | public float ClipSampleRange { get; set; } = 1.0f; 76 | 77 | [JsonPropertyName("sample_max_value")] 78 | public float SampleMaxValue { get; set; } = 1.0f; 79 | 80 | [JsonPropertyName("timestep_spacing")] 81 | public string TimestepSpacing { get; set; } = "leading"; 82 | 83 | [JsonPropertyName("rescale_betas_zero_snr")] 84 | public bool RescaleBetasZeroSnr { get; set; } = false; 85 | } -------------------------------------------------------------------------------- /Tests/Approvals/AutoEncoderKLTest.DecoderForwardTest.approved.txt: -------------------------------------------------------------------------------- 1 | autokl_decoder_forward: sum: 969.2767 dtype: Float32 shape: [1,3,768,768] -------------------------------------------------------------------------------- /Tests/Approvals/AutoEncoderKLTest.DecoderShapeTest.approved.txt: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /Tests/Approvals/AutoEncoderKLTest.EncoderForwardTest.approved.txt: -------------------------------------------------------------------------------- 1 | autokl_encoder_forward: sum: 129.3368 dtype: Float32 shape: [1,8,64,64] -------------------------------------------------------------------------------- /Tests/Approvals/CLIPTextModelTest.Fp16TextModelForwardTest.approved.txt: -------------------------------------------------------------------------------- 1 | clip_text_model_forward: sum: 79.3837 dtype: Float16 shape: [2,7,1024] 2 | clip_text_model_forward: sum: 29.3969 dtype: Float16 shape: [2,1024] 3 | -------------------------------------------------------------------------------- /Tests/Approvals/CLIPTextModelTest.TextModelForwardTest.approved.txt: -------------------------------------------------------------------------------- 1 | clip_text_model_forward: sum: 79.3837 dtype: Float32 shape: [2,7,1024] 2 | clip_text_model_forward: sum: 29.3946 dtype: Float32 shape: [2,1024] 3 | -------------------------------------------------------------------------------- /Tests/Approvals/DDIMSchedulerTest.StepTest.approved.txt: -------------------------------------------------------------------------------- 1 | step_output.PrevSample: sum: 102.3984 dtype: Float32 shape: [1,4,64,64] 2 | step_output.PredOriginalSample: sum: 102.3984 dtype: Float32 shape: [1,4,64,64] 3 | -------------------------------------------------------------------------------- /Tests/Approvals/StableDiffusionPipelineTest.GenerateCatImageTest.approved.txt: -------------------------------------------------------------------------------- 1 | images: sum: 976.4648 dtype: Float32 shape: [1,3,768,768] -------------------------------------------------------------------------------- /Tests/Approvals/UNet2DConditionModelTest.ForwardTest.approved.txt: 
-------------------------------------------------------------------------------- 1 | prompt_embeds: sum: 55.8623 dtype: Float32 shape: [1,7,1024] 2 | output: sum: 144.5073 dtype: Float32 shape: [1,4,96,96] 3 | output: sum: 144.5071 dtype: Float32 shape: [1,4,96,96] 4 | output: sum: 144.5068 dtype: Float32 shape: [1,4,96,96] 5 | output: sum: 144.5067 dtype: Float32 shape: [1,4,96,96] 6 | output: sum: 144.5066 dtype: Float32 shape: [1,4,96,96] 7 | -------------------------------------------------------------------------------- /Tests/AutoEncoderKL.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | 10 | namespace SD; 11 | 12 | public class AutoEncoderKLTest 13 | { 14 | [Fact] 15 | [UseReporter(typeof(DiffReporter))] 16 | [UseApprovalSubdirectory("Approvals")] 17 | public async Task ShapeTest() 18 | { 19 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/vae"; 20 | var autoKL = AutoencoderKL.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32); 21 | var state_dict_str = autoKL.Peek(); 22 | Approvals.Verify(state_dict_str); 23 | } 24 | 25 | [Fact] 26 | [UseReporter(typeof(DiffReporter))] 27 | [UseApprovalSubdirectory("Approvals")] 28 | public async Task Fp16ShapeTest() 29 | { 30 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/vae"; 31 | var autoKL = AutoencoderKL.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float16); 32 | var state_dict_str = autoKL.Peek(); 33 | Approvals.Verify(state_dict_str); 34 | } 35 | 36 | [Fact] 37 | [UseReporter(typeof(DiffReporter))] 38 | [UseApprovalSubdirectory("Approvals")] 39 | public async Task EncoderForwardTest() 40 | { 41 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/vae"; 42 | var autoKL = AutoencoderKL.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32); 43 | var latent = torch.arange(0, 1 * 3 * 512 * 512, dtype: ScalarType.Float32); 44 | latent = latent.reshape(1, 3, 512, 512); 45 | 46 | var result = autoKL.Encoder.forward(latent); 47 | var str = result.Peek("autokl_encoder_forward"); 48 | Approvals.Verify(str); 49 | } 50 | 51 | [Fact] 52 | [UseReporter(typeof(DiffReporter))] 53 | [UseApprovalSubdirectory("Approvals")] 54 | public async Task DecoderForwardTest() 55 | { 56 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/vae"; 57 | var autoKL = AutoencoderKL.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32); 58 | var latent = torch.arange(0, 1 * 4 * 96 * 96, dtype: ScalarType.Float32); 59 | latent = latent.reshape(1, 4, 96, 96); 60 | 61 | var result = autoKL.Decoder.forward(latent); 62 | var str = result.Peek("autokl_decoder_forward"); 63 | Approvals.Verify(str); 64 | } 65 | } -------------------------------------------------------------------------------- /Tests/CLIPTextModel.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | using System.Text; 10 | using System.Runtime.InteropServices; 11 | 12 | namespace SD; 13 | 14 | public class CLIPTextModelTest 15 | { 16 | [Fact] 17 | [UseReporter(typeof(DiffReporter))] 18 | 
[UseApprovalSubdirectory("Approvals")]
19 |     public async Task ShapeTest()
20 |     {
21 |         var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/text_encoder";
22 |         var clipTextModel = CLIPTextModel.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32);
23 |         var state_dict_str = clipTextModel.Peek();
24 |         Approvals.Verify(state_dict_str);
25 |     }
26 | 
27 |     [Fact]
28 |     [UseReporter(typeof(DiffReporter))]
29 |     [UseApprovalSubdirectory("Approvals")]
30 |     public async Task Fp16ShapeTest()
31 |     {
32 |         var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/text_encoder";
33 |         var clipTextModel = CLIPTextModel.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float16);
34 |         var state_dict_str = clipTextModel.Peek();
35 |         Approvals.Verify(state_dict_str);
36 |     }
37 | 
38 |     [Fact]
39 |     [UseReporter(typeof(DiffReporter))]
40 |     [UseApprovalSubdirectory("Approvals")]
41 |     public async Task TextModelForwardTest()
42 |     {
43 |         var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/text_encoder";
44 |         var clipTextModel = CLIPTextModel.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32);
45 |         long[] input_ids = [49406, 320, 1125, 539, 320, 2368, 49407, 49406, 320, 1125, 539, 320, 1929, 49407]; // a photo of a cat; a photo of a dog
46 | 
47 |         long[] attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
48 | 
49 |         var input_ids_tensor = input_ids.ToTensor([2, 7]);
50 |         var attention_mask_tensor = attention_mask.ToTensor([2, 7]);
51 | 
52 |         var result = clipTextModel.forward(input_ids_tensor, attention_mask_tensor);
53 |         var last_hidden_state = result.LastHiddenState;
54 |         var pooled_output = result.PoolerOutput;
55 | 
56 |         var last_hidden_state_str = last_hidden_state.Peek("clip_text_model_forward");
57 |         var pooled_output_str = pooled_output.Peek("clip_text_model_forward");
58 |         var sb = new StringBuilder();
59 |         sb.AppendLine(last_hidden_state_str);
60 |         sb.AppendLine(pooled_output_str);
61 | 
62 |         Approvals.Verify(sb.ToString());
63 |     }
64 | 
65 |     [Fact(Skip = "need cuda")]
66 |     [UseReporter(typeof(DiffReporter))]
67 |     [UseApprovalSubdirectory("Approvals")]
68 |     public async Task Fp16TextModelForwardTest()
69 |     {
70 |         // Comment out the following two lines and install the torchsharp-cuda package if your machine supports CUDA 12
71 |         var libTorch = "/home/xiaoyuz/diffusers/venv/lib/python3.8/site-packages/torch/lib/libtorch.so";
72 |         NativeLibrary.Load(libTorch);
73 |         var dtype = ScalarType.Float16;
74 |         var device = DeviceType.CUDA;
75 |         torch.InitializeDeviceType(device);
76 |         var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/text_encoder";
77 |         var clipTextModel = CLIPTextModel.FromPretrained(modelWeightFolder, torchDtype: dtype);
78 |         clipTextModel = clipTextModel.to(device);
79 |         long[] input_ids = [49406, 320, 1125, 539, 320, 2368, 49407, 49406, 320, 1125, 539, 320, 1929, 49407]; // a photo of a cat; a photo of a dog
80 |         long[] attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
81 | 
82 |         var input_ids_tensor = input_ids.ToTensor([2, 7]).to(device);
83 |         var attention_mask_tensor = attention_mask.ToTensor([2, 7]).to(device);
84 |         input_ids_tensor.Peek("input_ids_tensor");
85 |         attention_mask_tensor.Peek("attention_mask_tensor");
86 |         var result = clipTextModel.forward(input_ids_tensor, attention_mask_tensor);
87 |         var last_hidden_state = result.LastHiddenState;
88 |         var pooled_output = result.PoolerOutput;
89 | 
90 |         var last_hidden_state_str = last_hidden_state.Peek("clip_text_model_forward");
91 |         var pooled_output_str =
pooled_output.Peek("clip_text_model_forward"); 92 | var sb = new StringBuilder(); 93 | sb.AppendLine(last_hidden_state_str); 94 | sb.AppendLine(pooled_output_str); 95 | 96 | Approvals.Verify(sb.ToString()); 97 | } 98 | } -------------------------------------------------------------------------------- /Tests/DDIMScheduler.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | using System.Text; 10 | 11 | namespace SD; 12 | 13 | public class DDIMSchedulerTest 14 | { 15 | [Fact] 16 | [UseReporter(typeof(DiffReporter))] 17 | [UseApprovalSubdirectory("Approvals")] 18 | public async Task StepTest() 19 | { 20 | var dtype = ScalarType.Float32; 21 | var device = DeviceType.CPU; 22 | var path = "/home/xiaoyuz/stable-diffusion-2/scheduler"; 23 | var ddim = DDIMScheduler.FromPretrained(path); 24 | var timestep = 1; 25 | ddim.SetTimesteps(timestep); 26 | 27 | var latent = torch.arange(0, 1 * 4 * 64 * 64); 28 | latent = latent.reshape(1, 4, 64, 64).to(dtype).to(device); 29 | 30 | var step_output = ddim.Step(latent, timestep, latent); 31 | var sb = new StringBuilder(); 32 | sb.AppendLine(step_output.PrevSample.Peek("step_output.PrevSample")); 33 | sb.AppendLine(step_output.PredOriginalSample!.Peek("step_output.PredOriginalSample")); 34 | 35 | Approvals.Verify(sb.ToString()); 36 | } 37 | } -------------------------------------------------------------------------------- /Tests/StableDiffusionPipeline.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | using System.Text; 10 | 11 | namespace SD; 12 | 13 | public class StableDiffusionPipelineTest 14 | { 15 | [Fact] 16 | [UseReporter(typeof(DiffReporter))] 17 | [UseApprovalSubdirectory("Approvals")] 18 | public void GenerateCatImageTest() 19 | { 20 | var tokenzierModelPath = "/home/xiaoyuz/stable-diffusion-2/tokenizer"; 21 | var unetModelPath = "/home/xiaoyuz/stable-diffusion-2/unet"; 22 | var textModelPath = "/home/xiaoyuz/stable-diffusion-2/text_encoder"; 23 | var schedulerModelPath = "/home/xiaoyuz/stable-diffusion-2/scheduler"; 24 | var vaeModelPath = "/home/xiaoyuz/stable-diffusion-2/vae"; 25 | 26 | var tokenizer = BPETokenizer.FromPretrained(tokenzierModelPath); 27 | var unet = UNet2DConditionModel.FromPretrained(unetModelPath, torchDtype: ScalarType.Float32); 28 | var clipTextModel = CLIPTextModel.FromPretrained(textModelPath, torchDtype: ScalarType.Float32); 29 | var ddim = DDIMScheduler.FromPretrained(schedulerModelPath); 30 | var vae = AutoencoderKL.FromPretrained(vaeModelPath); 31 | var dtype = ScalarType.Float32; 32 | var device = DeviceType.CPU; 33 | var generator = torch.manual_seed(0); 34 | var input = "a photo of a cat"; 35 | var latent = torch.arange(0, 1 * 4 * 96 * 96); 36 | latent = latent.reshape(1, 4, 96, 96).to(dtype).to(device); 37 | var pipeline = new StableDiffusionPipeline( 38 | vae: vae, 39 | text_encoder: clipTextModel, 40 | unet: unet, 41 | tokenizer: tokenizer, 42 | scheduler: ddim); 43 | 44 | var images = pipeline.Run( 45 | prompt: input, 46 | num_inference_steps: 5, 47 | generator: 
generator, 48 | latents: latent); 49 | 50 | var sb = new StringBuilder(); 51 | sb.Append(images.Images.Peek("images")); 52 | 53 | Approvals.Verify(sb.ToString()); 54 | } 55 | } -------------------------------------------------------------------------------- /Tests/Tokenizer.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | using FluentAssertions; 10 | 11 | namespace SD; 12 | 13 | public class TokenizerTest 14 | { 15 | [Fact] 16 | public void TokenizerTest1() 17 | { 18 | var tokenizerFolder = "/home/xiaoyuz/stable-diffusion-2/tokenizer"; 19 | 20 | var tokenizer = BPETokenizer.FromPretrained(tokenizerFolder); 21 | var input = "a photo of a cat"; 22 | var output = tokenizer.Encode(input, true, true); 23 | output.Should().BeEquivalentTo([49406, 320, 1125, 539, 320, 2368, 49407]); 24 | 25 | // encode with max_length 26 | output = tokenizer.Encode(input, true, true, padding: "max_length", maxLength: 10); 27 | output.Should().BeEquivalentTo([49406, 320, 1125, 539, 320, 2368, 49407, 0, 0, 0]); 28 | } 29 | } -------------------------------------------------------------------------------- /Tests/UNet2DConditionModel.test.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using Xunit; 5 | using ApprovalTests; 6 | using ApprovalTests.Reporters; 7 | using ApprovalTests.Namers; 8 | using TorchSharp; 9 | using System.Text; 10 | 11 | namespace SD; 12 | 13 | public class UNet2DConditionModelTest 14 | { 15 | [Fact] 16 | [UseReporter(typeof(DiffReporter))] 17 | [UseApprovalSubdirectory("Approvals")] 18 | public async Task ShapeTest() 19 | { 20 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/unet"; 21 | var unet = UNet2DConditionModel.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float32); 22 | var state_dict_str = unet.Peek(); 23 | Approvals.Verify(state_dict_str); 24 | } 25 | 26 | [Fact] 27 | [UseReporter(typeof(DiffReporter))] 28 | [UseApprovalSubdirectory("Approvals")] 29 | public async Task Fp16ShapeTest() 30 | { 31 | var modelWeightFolder = "/home/xiaoyuz/stable-diffusion-2/unet"; 32 | var unet = UNet2DConditionModel.FromPretrained(modelWeightFolder, torchDtype: ScalarType.Float16); 33 | var state_dict_str = unet.Peek(); 34 | Approvals.Verify(state_dict_str); 35 | } 36 | 37 | [Fact] 38 | [UseReporter(typeof(DiffReporter))] 39 | [UseApprovalSubdirectory("Approvals")] 40 | public async Task ForwardTest() 41 | { 42 | var dtype = ScalarType.Float32; 43 | var device = DeviceType.CPU; 44 | var textModelPath = "/home/xiaoyuz/stable-diffusion-2/text_encoder"; 45 | var clipTextModel = CLIPTextModel.FromPretrained(textModelPath, torchDtype: ScalarType.Float32); 46 | 47 | var unetModelPath = "/home/xiaoyuz/stable-diffusion-2/unet"; 48 | var unet = UNet2DConditionModel.FromPretrained(unetModelPath, torchDtype: ScalarType.Float32); 49 | 50 | var tokenizerPath = "/home/xiaoyuz/stable-diffusion-2/tokenizer"; 51 | var tokenizer = BPETokenizer.FromPretrained(tokenizerPath); 52 | 53 | var latent = torch.arange(0, 1 * 4 * 96 * 96); 54 | latent = latent.reshape(1, 4, 96, 96).to(dtype).to(device); 55 | var input = "a photo of a cat"; 56 | var text = tokenizer.Encode(input, true, true); 57 | 
var textTensor = torch.tensor(text, dtype: ScalarType.Int64).reshape(1, text.Length).to(device); 58 | var outputs = clipTextModel.forward(textTensor); 59 | var prompt_embeds = outputs.LastHiddenState; 60 | prompt_embeds.Peek("prompt_embeds"); 61 | 62 | long[] t_candidates = [0L, 10L, 100L, 1000L, 2000L]; 63 | var sb = new StringBuilder(); 64 | sb.AppendLine(prompt_embeds.Peek("prompt_embeds")); 65 | 66 | foreach (var t_candidate in t_candidates) 67 | { 68 | var t = torch.tensor(t_candidate, dtype: ScalarType.Int64).to(device); 69 | var unetInput = new UNet2DConditionModelInput(latent, t, prompt_embeds); 70 | var output = unet.forward(unetInput); 71 | sb.AppendLine(output.Peek("output")); 72 | } 73 | 74 | Approvals.Verify(sb.ToString()); 75 | } 76 | } -------------------------------------------------------------------------------- /Tokenizer.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection.PortableExecutable; 2 | using System.Text.Json; 3 | using Microsoft.ML.Tokenizers; 4 | 5 | public class TokenizeDecoder : Microsoft.ML.Tokenizers.TokenizerDecoder 6 | { 7 | private const char spaceReplacement = 'Ġ'; 8 | 9 | private const char newlineReplacement = 'Ċ'; 10 | 11 | private const char carriageReturnReplacement = 'č'; 12 | private string bos = ""; 13 | private string eos = ""; 14 | 15 | public TokenizeDecoder(string bos = "", string eos = "") 16 | { 17 | this.bos = bos; 18 | this.eos = eos; 19 | } 20 | 21 | public override string Decode(IEnumerable tokens) 22 | { 23 | var str = string.Join("", tokens); 24 | str = str.Replace(spaceReplacement, ' '); 25 | str = str.Replace(newlineReplacement, '\n'); 26 | str = str.Replace(carriageReturnReplacement.ToString(), Environment.NewLine); 27 | 28 | if (str.StartsWith(bos)) 29 | { 30 | str = str.Substring(bos.Length); 31 | } 32 | 33 | if (str.EndsWith(eos)) 34 | { 35 | str = str.Substring(0, str.Length - eos.Length); 36 | } 37 | 38 | return str; 39 | } 40 | } 41 | 42 | public class BPETokenizer 43 | { 44 | private Tokenizer tokenizer; 45 | private bool addPrecedingSpace; 46 | 47 | public BPETokenizer( 48 | string vocabPath, 49 | string mergesPath, 50 | bool addPrecedingSpace, 51 | string uknToken, 52 | string bosToken, 53 | string eosToken) 54 | { 55 | this.addPrecedingSpace = addPrecedingSpace; 56 | var bpe = new Bpe(vocabPath, mergesPath, endOfWordSuffix: ""); 57 | this.tokenizer = new Tokenizer(bpe); 58 | this.BosId = this.tokenizer.Model.TokenToId(bosToken) ?? throw new Exception("Failed to get bos id"); 59 | this.EosId = this.tokenizer.Model.TokenToId(eosToken) ?? throw new Exception("Failed to get eos id"); 60 | var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!); 61 | this.tokenizer.Decoder = decoder; 62 | } 63 | 64 | public static BPETokenizer FromPretrained( 65 | string folder, 66 | string vocabFile = "vocab.json", 67 | string mergesFile = "merges.txt", 68 | string specialTokensFile = "special_tokens_map.json", 69 | bool addPrecedingSpace = false, 70 | string uknToken = "<|endoftext|>", 71 | string bosToken = "<|startoftext|>", 72 | string eosToken = "<|endoftext|>") 73 | { 74 | var vocabPath = Path.Combine(folder, vocabFile); 75 | var mergesPath = Path.Combine(folder, mergesFile); 76 | var specialTokenMapPath = Path.Combine(folder, specialTokensFile); 77 | 78 | Dictionary? 
specialTokenMap = null; 79 | // if (File.Exists(Path.Combine(folder, specialTokensFile))) 80 | // { 81 | // specialTokenMap = JsonSerializer.Deserialize>(File.ReadAllText(specialTokenMapPath)) ?? throw new Exception("Failed to load special token map"); 82 | // } 83 | 84 | bosToken = specialTokenMap?.GetValueOrDefault("bos_token") ?? bosToken; 85 | eosToken = specialTokenMap?.GetValueOrDefault("eos_token") ?? eosToken; 86 | uknToken = specialTokenMap?.GetValueOrDefault("unk_token") ?? uknToken; 87 | 88 | return new BPETokenizer(vocabPath, mergesPath, addPrecedingSpace, uknToken, bosToken, eosToken); 89 | } 90 | 91 | public int VocabSize => this.tokenizer.Model.GetVocabSize(); 92 | 93 | public int ModelMaxLength { get; } = 77; 94 | 95 | public int PadId { get; } 96 | 97 | public int BosId { get; } 98 | 99 | public int EosId { get; } 100 | 101 | public string Decode(int[] input) 102 | { 103 | var str = this.tokenizer.Decode(input) ?? throw new Exception("Failed to decode"); 104 | if (this.addPrecedingSpace) 105 | { 106 | str = str.TrimStart(); 107 | } 108 | 109 | return str; 110 | } 111 | 112 | public int TokenToId(string token) 113 | { 114 | return this.tokenizer.Model.TokenToId(token) ?? throw new Exception("Failed to get token id"); 115 | } 116 | 117 | public int[] Encode( 118 | string input, 119 | bool bos = false, 120 | bool eos = false, 121 | string? padding = null, 122 | int? maxLength = null) 123 | { 124 | if (this.addPrecedingSpace) 125 | { 126 | input = " " + input; 127 | } 128 | var tokens = this.tokenizer.Encode(input).Ids.ToArray(); 129 | if (bos) 130 | { 131 | tokens = new int[] { this.BosId }.Concat(tokens).ToArray(); 132 | } 133 | if (eos) 134 | { 135 | tokens = tokens.Concat(new int[] { this.EosId }).ToArray(); 136 | } 137 | 138 | if (padding == "max_length" && maxLength is int maxLen) 139 | { 140 | if (tokens.Length > maxLen) 141 | { 142 | tokens = tokens.Take(maxLen).ToArray(); 143 | } 144 | else if (tokens.Length < maxLen) 145 | { 146 | tokens = tokens.Concat(Enumerable.Repeat(this.PadId, maxLen - tokens.Length)).ToArray(); 147 | } 148 | } 149 | 150 | return tokens; 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /Torchsharp-stable-diffusion-2.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Exe 5 | net8.0 6 | SD 7 | enable 8 | enable 9 | preview 10 | CS0414 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /Torchsharp-stable-diffusion-2.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.5.002.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Torchsharp-stable-diffusion-2", "Torchsharp-stable-diffusion-2.csproj", "{82F77C34-779B-4D12-9B9C-B32DC14D77B4}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {82F77C34-779B-4D12-9B9C-B32DC14D77B4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {82F77C34-779B-4D12-9B9C-B32DC14D77B4}.Debug|Any CPU.Build.0 = Debug|Any CPU 16 | 
{82F77C34-779B-4D12-9B9C-B32DC14D77B4}.Release|Any CPU.ActiveCfg = Release|Any CPU 17 | {82F77C34-779B-4D12-9B9C-B32DC14D77B4}.Release|Any CPU.Build.0 = Release|Any CPU 18 | EndGlobalSection 19 | GlobalSection(SolutionProperties) = preSolution 20 | HideSolutionNode = FALSE 21 | EndGlobalSection 22 | GlobalSection(ExtensibilityGlobals) = postSolution 23 | SolutionGuid = {1711A5EC-30AC-44A6-9E38-638218240697} 24 | EndGlobalSection 25 | EndGlobal 26 | -------------------------------------------------------------------------------- /UNet/AdaGroupNorm.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class AdaGroupNorm : Module 8 | { 9 | private readonly int embedding_dim; 10 | private readonly int out_dim; 11 | private readonly int num_groups; 12 | private readonly string? act_fn; 13 | private readonly float eps = 1e-5f; 14 | private readonly Module? act; 15 | private readonly Linear linear; 16 | private ScalarType defaultDtype; 17 | public AdaGroupNorm( 18 | int embedding_dim, 19 | int out_dim, 20 | int num_groups, 21 | string? act_fn = null, 22 | float eps = 1e-5f, 23 | ScalarType dtype = ScalarType.Float32) 24 | : base(nameof(AdaGroupNorm)) 25 | { 26 | this.embedding_dim = embedding_dim; 27 | this.out_dim = out_dim; 28 | this.num_groups = num_groups; 29 | this.act_fn = act_fn; 30 | this.eps = eps; 31 | this.defaultDtype = dtype; 32 | 33 | this.act = act_fn != null ? Utils.GetActivation(act_fn) : null; 34 | this.linear = Linear(embedding_dim, out_dim * 2, dtype: dtype); 35 | } 36 | 37 | public override Tensor forward(Tensor x, Tensor emb) 38 | { 39 | if (this.act != null) 40 | { 41 | emb = this.act.forward(emb); 42 | } 43 | 44 | emb = this.linear.forward(emb); 45 | // emb = emb[:, :, None, None] 46 | emb = emb.unsqueeze(2).unsqueeze(3); 47 | // scale, shift = emb.chunk(2, dim=1) 48 | var chunks = emb.chunk(2, 1); 49 | var scale = chunks[0]; 50 | var shift = chunks[1]; 51 | 52 | x = nn.functional.group_norm(x, this.num_groups, eps: this.eps); 53 | x = x * (1+scale) + shift; 54 | 55 | return x; 56 | } 57 | } -------------------------------------------------------------------------------- /UNet/CrossAttnDownBlock2D.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | public class CrossAttnDownBlock2D : Module 3 | { 4 | private readonly bool has_cross_attention; 5 | private readonly int num_attention_heads; 6 | 7 | private readonly ModuleList resnets; 8 | private readonly ModuleList attentions; 9 | private readonly ModuleList? downsamplers = null; 10 | 11 | public CrossAttnDownBlock2D( 12 | int in_channels, 13 | int out_channels, 14 | int temb_channels, 15 | double dropout = 0.0, 16 | int num_layers = 1, 17 | int[]? transformer_layers_per_block = null, 18 | double resnet_eps = 1e-6, 19 | string resnet_time_scale_shift = "default", 20 | string resnet_act_fn = "swish", 21 | int? resnet_groups = 32, 22 | bool resnet_pre_norm = true, 23 | int? num_attention_heads = 1, 24 | int? cross_attention_dim = 1280, 25 | double output_scale_factor = 1.0, 26 | int? 
downsample_padding = 1, 27 | bool add_downsample = true, 28 | bool dual_cross_attention = false, 29 | bool use_linear_projection = false, 30 | bool only_cross_attention = false, 31 | bool upcast_attention = false, 32 | string attention_type = "default", 33 | ScalarType dtype = ScalarType.Float32 34 | ): base(nameof(CrossAttnDownBlock2D)) 35 | { 36 | ModuleList resnets = new ModuleList(); 37 | ModuleList attentions = new ModuleList(); 38 | 39 | this.has_cross_attention = true; 40 | this.num_attention_heads = num_attention_heads ?? 1; 41 | transformer_layers_per_block = transformer_layers_per_block ?? Enumerable.Repeat(1, num_layers).ToArray(); 42 | 43 | for(int i = 0; i != num_layers; ++i) 44 | { 45 | in_channels = i == 0 ? in_channels : out_channels; 46 | resnets.Add( 47 | new ResnetBlock2D( 48 | in_channels: in_channels, 49 | out_channels: out_channels, 50 | temb_channels: temb_channels, 51 | eps: (float)resnet_eps, 52 | groups: resnet_groups ?? 32, 53 | dropout: (float)dropout, 54 | time_embedding_norm: resnet_time_scale_shift, 55 | non_linearity: resnet_act_fn, 56 | output_scale_factor: (float)output_scale_factor, 57 | pre_norm: resnet_pre_norm, 58 | dtype: dtype)); 59 | 60 | if (!dual_cross_attention) 61 | { 62 | attentions.Add( 63 | new Transformer2DModel( 64 | num_attention_heads: num_attention_heads ?? 16, 65 | attention_head_dim: out_channels / num_attention_heads ?? throw new ArgumentNullException(nameof(num_attention_heads)), 66 | in_channels: out_channels, 67 | num_layers: transformer_layers_per_block[i], 68 | cross_attention_dim: cross_attention_dim, 69 | norm_num_groups: resnet_groups ?? 32, 70 | use_linear_projection: use_linear_projection, 71 | only_cross_attention: only_cross_attention, 72 | upcast_attention: upcast_attention, 73 | attention_type: attention_type, 74 | dtype: dtype)); 75 | } 76 | else 77 | { 78 | attentions.Add( 79 | new DualTransformer2DModel( 80 | num_attention_heads: num_attention_heads ?? 16, 81 | attention_head_dim: out_channels / num_attention_heads ?? 88, 82 | in_channels: out_channels, 83 | num_layers: 1, 84 | cross_attention_dim: cross_attention_dim, 85 | norm_num_groups: resnet_groups ?? 
32, 86 | dtype: dtype)); 87 | } 88 | } 89 | 90 | this.resnets = resnets; 91 | this.attentions = attentions; 92 | 93 | if (add_downsample) 94 | { 95 | this.downsamplers = new ModuleList(); 96 | this.downsamplers.Add( 97 | new Downsample2D( 98 | channels: out_channels, 99 | use_conv: true, 100 | out_channels: out_channels, 101 | padding: downsample_padding, 102 | name: "op", 103 | dtype: dtype)); 104 | } 105 | } 106 | 107 | public override DownBlock2DOutput forward(DownBlock2DInput input) 108 | { 109 | var hidden_states = input.HiddenStates; 110 | var temb = input.Temb; 111 | var encoder_hidden_states = input.EncoderHiddenStates; 112 | var attention_mask = input.AttentionMask; 113 | var encoder_attention_mask = input.EncoderAttentionMask; 114 | var additional_residuals = input.AdditionalResiduals; 115 | 116 | List output_states = new List(); 117 | 118 | var blocks = this.resnets.Zip(this.attentions, (resnet, attention) => (resnet, attention)).ToArray(); 119 | 120 | for(int i = 0; i !=blocks.Count(); ++i) 121 | { 122 | var (resnet, attention) = blocks[i]; 123 | hidden_states = resnet.forward(hidden_states, temb); 124 | if (attention is Transformer2DModel transformer) 125 | { 126 | hidden_states = transformer.forward( 127 | hidden_states, 128 | encoder_hidden_states, 129 | attention_mask, 130 | encoder_attention_mask, 131 | additional_residuals).Sample; 132 | } 133 | else if (attention is DualTransformer2DModel dual_transformer) 134 | { 135 | hidden_states = dual_transformer.forward( 136 | hidden_states, 137 | encoder_hidden_states ?? throw new ArgumentNullException(nameof(encoder_hidden_states)), 138 | attention_mask: attention_mask).Sample; 139 | } 140 | else 141 | { 142 | throw new NotImplementedException(); 143 | } 144 | 145 | if (i == blocks.Count() - 1 && additional_residuals is not null) 146 | { 147 | hidden_states = hidden_states + additional_residuals; 148 | } 149 | 150 | output_states.Add(hidden_states); 151 | } 152 | 153 | if (this.downsamplers is not null) 154 | { 155 | foreach (var downsample in this.downsamplers) 156 | { 157 | hidden_states = downsample.forward(hidden_states); 158 | } 159 | 160 | output_states.Add(hidden_states); 161 | } 162 | 163 | return new DownBlock2DOutput(hidden_states, output_states.ToArray()); 164 | } 165 | } -------------------------------------------------------------------------------- /UNet/CrossAttnUpBlock2D.cs: -------------------------------------------------------------------------------- 1 | 2 | namespace SD; 3 | public class UpBlock2DInput 4 | { 5 | public UpBlock2DInput( 6 | Tensor hiddenStates, 7 | Tensor[] resHiddenStatesTuple, 8 | Tensor? temb = null, 9 | Tensor? encoderHiddenStates = null, 10 | Dictionary? crossAttentionKwargs = null, 11 | long[]? upsampleSize = null, 12 | Tensor? attentionMask = null, 13 | Tensor? encoderAttentionMask = null) 14 | { 15 | HiddenStates = hiddenStates; 16 | ResHiddenStatesTuple = resHiddenStatesTuple; 17 | Temb = temb; 18 | EncoderHiddenStates = encoderHiddenStates; 19 | CrossAttentionKwargs = crossAttentionKwargs; 20 | UpsampleSize = upsampleSize; 21 | AttentionMask = attentionMask; 22 | EncoderAttentionMask = encoderAttentionMask; 23 | } 24 | public Tensor HiddenStates { get; } 25 | public Tensor[] ResHiddenStatesTuple { get; } 26 | public Tensor? Temb { get; } 27 | public Tensor? EncoderHiddenStates { get; } 28 | public Dictionary? CrossAttentionKwargs { get; } 29 | 30 | public long[]? UpsampleSize { get; } 31 | 32 | public Tensor? AttentionMask { get; } 33 | 34 | public Tensor? 
EncoderAttentionMask { get; } 35 | 36 | } 37 | 38 | public class CrossAttnUpBlock2D : Module 39 | { 40 | private readonly bool has_cross_attention; 41 | private readonly int num_attention_heads; 42 | private readonly ModuleList resnets; 43 | private readonly ModuleList attentions; 44 | private readonly ModuleList? upsamplers = null; 45 | private readonly int? resolution_idx; 46 | public CrossAttnUpBlock2D( 47 | int in_channels, 48 | int out_channels, 49 | int prev_output_channel, 50 | int temb_channels, 51 | int? resolution_idx = null, 52 | float dropout = 0.0f, 53 | int num_layers = 1, 54 | int[]? transformer_layers_per_block = null, 55 | float resnet_eps = 1e-6f, 56 | string resnet_time_scale_shift = "default", 57 | string resnet_act_fn = "swish", 58 | int resnet_groups = 32, 59 | bool resnet_pre_norm = true, 60 | int num_attention_heads = 1, 61 | int cross_attention_dim = 1280, 62 | float output_scale_factor = 1.0f, 63 | bool add_upsample = true, 64 | bool dual_cross_attention = false, 65 | bool use_linear_projection = false, 66 | bool only_cross_attention = false, 67 | bool upcast_attention = false, 68 | string attention_type = "default", 69 | ScalarType dtype = ScalarType.Float32) 70 | : base(nameof(CrossAttnUpBlock2D)) 71 | { 72 | ModuleList resnets = new ModuleList(); 73 | ModuleList attentions = new ModuleList(); 74 | 75 | this.has_cross_attention = true; 76 | this.num_attention_heads = num_attention_heads; 77 | transformer_layers_per_block = transformer_layers_per_block ?? Enumerable.Repeat(1, num_layers).ToArray(); 78 | 79 | for (int i = 0; i != num_layers; ++i) 80 | { 81 | var res_skip_channels = i == num_layers - 1 ? in_channels : out_channels; 82 | var resnet_in_channels = i == 0 ? prev_output_channel : out_channels; 83 | 84 | resnets.Add( 85 | new ResnetBlock2D( 86 | in_channels: resnet_in_channels + res_skip_channels, 87 | out_channels: out_channels, 88 | temb_channels: temb_channels, 89 | eps: resnet_eps, 90 | groups: resnet_groups, 91 | dropout: dropout, 92 | time_embedding_norm: resnet_time_scale_shift, 93 | non_linearity: resnet_act_fn, 94 | output_scale_factor: output_scale_factor, 95 | pre_norm: resnet_pre_norm, 96 | dtype: dtype)); 97 | 98 | if (!dual_cross_attention) 99 | { 100 | attentions.Add( 101 | new Transformer2DModel( 102 | num_attention_heads: num_attention_heads, 103 | attention_head_dim: out_channels / num_attention_heads, 104 | in_channels: out_channels, 105 | num_layers: transformer_layers_per_block[i], 106 | cross_attention_dim: cross_attention_dim, 107 | norm_num_groups: resnet_groups, 108 | use_linear_projection: use_linear_projection, 109 | only_cross_attention: only_cross_attention, 110 | upcast_attention: upcast_attention, 111 | attention_type: attention_type, 112 | dtype: dtype)); 113 | } 114 | else 115 | { 116 | attentions.Add( 117 | new DualTransformer2DModel( 118 | num_attention_heads: num_attention_heads, 119 | attention_head_dim: out_channels / num_attention_heads, 120 | in_channels: out_channels, 121 | num_layers: 1, 122 | cross_attention_dim: cross_attention_dim, 123 | norm_num_groups: resnet_groups, 124 | dtype: dtype)); 125 | 126 | } 127 | } 128 | 129 | this.resnets = resnets; 130 | this.attentions = attentions; 131 | 132 | if (add_upsample) 133 | { 134 | this.upsamplers = new ModuleList(); 135 | this.upsamplers.Add( 136 | new Upsample2D( 137 | channels: out_channels, 138 | use_conv: true, 139 | out_channels: out_channels, 140 | dtype: dtype)); 141 | } 142 | this.resolution_idx = resolution_idx; 143 | } 144 | 145 | public ModuleList 
Resnets => resnets; 146 | public override Tensor forward(UpBlock2DInput input) 147 | { 148 | var hiddenStates = input.HiddenStates; 149 | var resHiddenStatesTuple = input.ResHiddenStatesTuple; 150 | var temb = input.Temb; 151 | var encoderHiddenStates = input.EncoderHiddenStates; 152 | var crossAttentionKwargs = input.CrossAttentionKwargs; 153 | var upsampleSize = input.UpsampleSize; 154 | var attentionMask = input.AttentionMask; 155 | var encoderAttentionMask = input.EncoderAttentionMask; 156 | 157 | foreach(var (resnet, attention) in resnets.Zip(attentions)) 158 | { 159 | // pop res hidden states 160 | var res_hidden_states = resHiddenStatesTuple[^1]; 161 | resHiddenStatesTuple = resHiddenStatesTuple[..^1]; 162 | 163 | hiddenStates = torch.cat([hiddenStates, res_hidden_states], 1); 164 | hiddenStates = resnet.forward(hiddenStates, temb); 165 | if (attention is Transformer2DModel transformer) 166 | { 167 | hiddenStates = transformer.forward( 168 | hiddenStates, 169 | encoder_hidden_states: encoderHiddenStates, 170 | attention_mask: attentionMask, 171 | encoder_attention_mask: encoderAttentionMask).Sample; 172 | } 173 | else if (attention is DualTransformer2DModel dualTransformer) 174 | { 175 | hiddenStates = dualTransformer.forward( 176 | hiddenStates, 177 | encoder_hidden_states: encoderHiddenStates ?? throw new ArgumentNullException(nameof(encoderHiddenStates)), 178 | attention_mask: attentionMask).Sample; 179 | } 180 | } 181 | 182 | if (upsamplers != null) 183 | { 184 | foreach(var upsample in upsamplers) 185 | { 186 | hiddenStates = upsample.forward(hiddenStates, upsampleSize); 187 | } 188 | } 189 | 190 | return hiddenStates; 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /UNet/DownBlock2D.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | 3 | public class DownBlock2DInput 4 | { 5 | public Tensor HiddenStates {get;} 6 | public Tensor? Temb {get;} = null; 7 | 8 | public Tensor? EncoderHiddenStates {get;} = null; 9 | public Tensor? AttentionMask {get;} = null; 10 | public Tensor? EncoderAttentionMask {get;} = null; 11 | public Tensor? AdditionalResiduals {get;} = null; 12 | 13 | public DownBlock2DInput( 14 | Tensor hiddenStates, 15 | Tensor? temb = null, 16 | Tensor? encoderHiddenStates = null, 17 | Tensor? attentionMask = null, 18 | Tensor? encoderAttentionMask = null, 19 | Tensor? additionalResiduals = null) 20 | { 21 | HiddenStates = hiddenStates; 22 | Temb = temb; 23 | EncoderHiddenStates = encoderHiddenStates; 24 | AttentionMask = attentionMask; 25 | EncoderAttentionMask = encoderAttentionMask; 26 | AdditionalResiduals = additionalResiduals; 27 | } 28 | } 29 | 30 | public class DownBlock2DOutput 31 | { 32 | public Tensor HiddenStates {get;} 33 | public Tensor[]? OutputStates {get;} = null; 34 | 35 | public DownBlock2DOutput(Tensor hiddenStates, Tensor[]? outputStates = null) 36 | { 37 | HiddenStates = hiddenStates; 38 | OutputStates = outputStates; 39 | } 40 | } 41 | 42 | public class DownBlock2D: Module 43 | { 44 | private ModuleList resnets; 45 | private ModuleList>? attentions; 46 | private ModuleList>? downsamplers; 47 | 48 | public DownBlock2D( 49 | int in_channels, 50 | int out_channels, 51 | int temb_channels, 52 | float dropout = 0.0f, 53 | int num_layers = 1, 54 | float resnet_eps = 1e-6f, 55 | string resnet_time_scale_shift = "default", 56 | string resnet_act_fn = "swish", 57 | int? 
resnet_groups = 32, 58 | bool resnet_pre_norm = true, 59 | float output_scale_factor = 1.0f, 60 | bool? add_downsample = true, 61 | int? downsample_padding = 1, 62 | ScalarType dtype = ScalarType.Float32) 63 | : base(nameof(DownBlock2D)) 64 | { 65 | var resnets = new ModuleList(); 66 | for (int i = 0; i < num_layers; i++) 67 | { 68 | in_channels = i == 0 ? in_channels : out_channels; 69 | resnets.Add(new ResnetBlock2D( 70 | in_channels: in_channels, 71 | out_channels: out_channels, 72 | temb_channels: temb_channels, 73 | eps: resnet_eps, 74 | groups: resnet_groups ?? throw new ArgumentNullException(nameof(resnet_groups)), 75 | dropout: dropout, 76 | time_embedding_norm: resnet_time_scale_shift, 77 | non_linearity: resnet_act_fn, 78 | output_scale_factor: output_scale_factor, 79 | pre_norm: resnet_pre_norm, 80 | dtype: dtype)); 81 | } 82 | 83 | this.resnets = resnets; 84 | 85 | if (add_downsample is true) 86 | { 87 | this.downsamplers = new ModuleList>(); 88 | this.downsamplers.Add(new Downsample2D( 89 | channels: out_channels, 90 | use_conv: true, 91 | out_channels: out_channels, 92 | name: "op", 93 | padding: downsample_padding, 94 | dtype: dtype)); 95 | } 96 | } 97 | 98 | public override DownBlock2DOutput forward(DownBlock2DInput input) 99 | { 100 | var hiddenStates = input.HiddenStates; 101 | var temb = input.Temb; 102 | var encoderHiddenStates = input.EncoderHiddenStates; 103 | var attentionMask = input.AttentionMask; 104 | var encoderAttentionMask = input.EncoderAttentionMask; 105 | var additionalResiduals = input.AdditionalResiduals; 106 | 107 | var output_states = new List(); 108 | 109 | foreach (var resnet in resnets) 110 | { 111 | hiddenStates = resnet.forward(hiddenStates, temb); 112 | output_states.Add(hiddenStates); 113 | } 114 | 115 | if (downsamplers != null) 116 | { 117 | foreach (var downsample in downsamplers) 118 | { 119 | hiddenStates = downsample.forward(hiddenStates); 120 | } 121 | 122 | output_states.Add(hiddenStates); 123 | } 124 | 125 | return new DownBlock2DOutput(hiddenStates, output_states.ToArray()); 126 | } 127 | } -------------------------------------------------------------------------------- /UNet/DownEncoderBlock2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | public class DownEncoderBlock2D : Module 9 | { 10 | private readonly ScalarType dtype; 11 | private readonly ModuleList> resnets; 12 | private readonly ModuleList>? downsamplers; 13 | public DownEncoderBlock2D( 14 | int in_channels, 15 | int out_channels, 16 | float dropout = 0.0f, 17 | int num_layers = 1, 18 | float resnet_eps = 1e-6f, 19 | string resnet_time_scale_shift = "default", 20 | string resnet_act_fun = "swish", 21 | int resnet_groups = 32, 22 | bool resnet_pre_norm = true, 23 | float output_scale_factor = 1.0f, 24 | bool add_downsample = true, 25 | int downsample_padding = 1, 26 | ScalarType dtype = ScalarType.Float32) 27 | : base(nameof(DownEncoderBlock2D)) 28 | { 29 | this.dtype = dtype; 30 | this.resnets = new ModuleList>(); 31 | for (int i = 0; i < num_layers; i++) 32 | { 33 | in_channels = i == 0 ? 
in_channels : out_channels; 34 | if (resnet_time_scale_shift == "spatial") 35 | { 36 | this.resnets.Add(new ResnetBlockCondNorm2D( 37 | in_channels: in_channels, 38 | out_channels: out_channels, 39 | dropout: dropout, 40 | temb_channels: out_channels, 41 | groups: resnet_groups, 42 | eps: resnet_eps, 43 | non_linearity: resnet_act_fun, 44 | time_embedding_norm: resnet_time_scale_shift, 45 | output_scale_factor: output_scale_factor, 46 | up: false, 47 | down: false, 48 | conv_2d_out_channels: out_channels, 49 | conv_shortcut: false, 50 | conv_shortcut_bias: true, 51 | dtype: dtype 52 | )); 53 | } 54 | else 55 | { 56 | this.resnets.Add(new ResnetBlock2D( 57 | in_channels: in_channels, 58 | out_channels: out_channels, 59 | dropout: dropout, 60 | temb_channels: null, 61 | groups: resnet_groups, 62 | pre_norm: resnet_pre_norm, 63 | eps: resnet_eps, 64 | non_linearity: resnet_act_fun, 65 | time_embedding_norm: resnet_time_scale_shift, 66 | output_scale_factor: output_scale_factor, 67 | up: false, 68 | down: false, 69 | conv_2d_out_channels: out_channels, 70 | conv_shortcut: false, 71 | conv_shortcut_bias: true, 72 | dtype: dtype 73 | )); 74 | } 75 | } 76 | 77 | if (add_downsample) 78 | { 79 | this.downsamplers = new ModuleList>(); 80 | this.downsamplers.Add(new Downsample2D( 81 | channels: out_channels, 82 | use_conv: true, 83 | out_channels: out_channels, 84 | padding: downsample_padding, 85 | name: "op", 86 | dtype: dtype 87 | )); 88 | } 89 | } 90 | 91 | public override Tensor forward(Tensor hidden_states) 92 | { 93 | for (int i = 0; i < resnets.Count; i++) 94 | { 95 | hidden_states = resnets[i].forward(hidden_states, null); 96 | } 97 | 98 | if (downsamplers is not null) 99 | { 100 | for (int i = 0; i < downsamplers.Count; i++) 101 | { 102 | hidden_states = downsamplers[i].forward(hidden_states); 103 | } 104 | } 105 | 106 | return hidden_states; 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /UNet/Downsample2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class Downsample2D : Module 8 | { 9 | private readonly int channels; 10 | private readonly int out_channels; 11 | private readonly bool use_conv; 12 | private readonly int? padding; 13 | private readonly string conv_name; 14 | 15 | private readonly ScalarType defaultDtype; 16 | 17 | private readonly Module? conv; 18 | private readonly Module? Conv2d_0; 19 | private readonly Module? norm; 20 | public Downsample2D( 21 | int channels, 22 | bool use_conv = false, 23 | int? out_channels = null, 24 | int? padding = 1, 25 | string name = "conv", 26 | int kernel_size = 3, 27 | string? norm_type = null, 28 | float eps = 1e-5f, 29 | bool elementwise_affine = false, 30 | bool bias = true, 31 | ScalarType dtype = ScalarType.Float32) 32 | : base(nameof(Downsample2D)) 33 | { 34 | this.channels = channels; 35 | this.out_channels = out_channels ?? 
channels; 36 | this.use_conv = use_conv; 37 | this.padding = padding; 38 | this.conv_name = name; 39 | this.defaultDtype = dtype; 40 | 41 | if (norm_type is "ln_norm") 42 | { 43 | this.norm = nn.LayerNorm(normalized_shape: this.channels, eps: eps, elementwise_affine: elementwise_affine, dtype: dtype); 44 | } 45 | else if (norm_type is null) 46 | { 47 | this.norm = null; 48 | } 49 | else 50 | { 51 | throw new ArgumentException("Invalid norm type: " + norm_type); 52 | } 53 | 54 | Module conv; 55 | if (use_conv) 56 | { 57 | conv = nn.Conv2d(inputChannel: this.channels, outputChannel: this.out_channels, kernelSize: kernel_size, stride: 2, padding: padding ?? 1, bias: bias, dtype: dtype); 58 | } 59 | else 60 | { 61 | var stride = 2; 62 | conv = nn.AvgPool2d(kernel_size: 2, stride: stride); 63 | } 64 | 65 | if (name == "conv"){ 66 | this.Conv2d_0 = conv; 67 | this.conv = conv; 68 | } 69 | else if (name == "Conv2d_0"){ 70 | this.Conv2d_0 = conv; 71 | } 72 | else 73 | { 74 | this.conv = conv; 75 | } 76 | 77 | } 78 | 79 | public override Tensor forward(Tensor hidden_states) 80 | { 81 | if (this.norm is not null) 82 | { 83 | hidden_states = this.norm.forward(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2); 84 | } 85 | 86 | if (this.use_conv && this.padding == 0) 87 | { 88 | hidden_states = nn.functional.pad(hidden_states, pad: [0, 1, 0, 1], mode: TorchSharp.PaddingModes.Constant, value: 0); 89 | } 90 | 91 | hidden_states = this.conv!.forward(hidden_states); 92 | 93 | return hidden_states; 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /UNet/DualTransformer2DModel.cs: -------------------------------------------------------------------------------- 1 | public class DualTransformer2DModel : Module 2 | { 3 | private readonly float mix_ratio = 0.5f; 4 | private readonly int[] condition_lengths = [77, 257]; 5 | private readonly int[] transformer_index_for_condition = [1, 0]; 6 | private readonly ModuleList transformers; 7 | 8 | public DualTransformer2DModel( 9 | int num_attention_heads = 16, 10 | int attention_head_dim = 88, 11 | int? in_channels = null, 12 | int num_layers = 1, 13 | double dropout = 0.0, 14 | int norm_num_groups = 32, 15 | int? cross_attention_dim = null, 16 | bool attention_bias = false, 17 | int? sample_size = null, 18 | int? num_vector_embeds = null, 19 | string activation_fn = "geglu", 20 | int? num_embeds_ada_norm = null, 21 | ScalarType dtype = ScalarType.Float32 22 | ) : base(nameof(DualTransformer2DModel)) 23 | { 24 | this.transformers = new ModuleList(); 25 | for(int i = 0; i < 2; i++) 26 | { 27 | transformers.Add(new Transformer2DModel( 28 | num_attention_heads: num_attention_heads, 29 | attention_head_dim: attention_head_dim, 30 | in_channels: in_channels, 31 | num_layers: num_layers, 32 | dropout: dropout, 33 | norm_num_groups: norm_num_groups, 34 | cross_attention_dim: cross_attention_dim, 35 | attention_bias: attention_bias, 36 | sample_size: sample_size, 37 | num_vector_embeds: num_vector_embeds, 38 | activation_fn: activation_fn, 39 | num_embeds_ada_norm: num_embeds_ada_norm, 40 | dtype: dtype)); 41 | } 42 | } 43 | 44 | public override Transformer2DModelOutput forward( 45 | Tensor hidden_states, 46 | Tensor encoder_hidden_states, 47 | Tensor? timestep = null, 48 | Tensor? 
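// A minimal usage sketch for the Downsample2D block defined above; the
// shapes and argument values are illustrative assumptions, not calls taken
// from this repo's tests:
//   var down = new Downsample2D(channels: 320, use_conv: true, out_channels: 320);
//   var y = down.forward(torch.randn(1, 320, 64, 64)); // -> [1, 320, 32, 32]
// With use_conv=false the AvgPool2d path halves H and W at constant C.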
attention_mask = null) 49 | { 50 | var input_states = hidden_states; 51 | List encoded_states = []; 52 | var tokens_start = 0; 53 | 54 | for(int i = 0; i < 2; i++) 55 | { 56 | var condition_state = encoder_hidden_states[.., tokens_start..(tokens_start + condition_lengths[i])]; 57 | var transformer_index = transformer_index_for_condition[i]; 58 | var transformer = transformers[transformer_index]; 59 | var encoded_state = transformer.forward( 60 | input_states, condition_state, timestep); 61 | encoded_states.Add(encoded_state.Sample - input_states); 62 | tokens_start += condition_lengths[i]; 63 | } 64 | 65 | var output_states = encoded_states[0] * this.mix_ratio + encoded_states[1] * (1 - this.mix_ratio); 66 | output_states = input_states + output_states; 67 | 68 | return new Transformer2DModelOutput(output_states); 69 | } 70 | } -------------------------------------------------------------------------------- /UNet/ResnetBlock2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using System.Linq.Expressions; 5 | 6 | namespace SD; 7 | 8 | public class ResnetBlock2D : Module 9 | { 10 | private readonly bool pre_norm; 11 | private readonly int in_channels; 12 | private readonly int out_channels; 13 | private readonly bool use_conv_shortcut; 14 | private readonly bool up; 15 | private readonly bool down; 16 | private readonly float output_scale_factor; 17 | private readonly string time_embedding_norm; 18 | private bool skip_time_act; 19 | private readonly bool use_in_shortcut; 20 | private readonly ScalarType defaultDtype; 21 | 22 | private Module norm1; 23 | private Module conv1; 24 | private Module norm2; 25 | private Module conv2; 26 | private Module dropout; 27 | private Linear? time_emb_proj; 28 | private Module nonlinearity; 29 | private Module? upsample = null; 30 | private Module? downsample = null; 31 | private Module? conv_shortcut = null; 32 | public ResnetBlock2D( 33 | int in_channels, 34 | int? out_channels = null, 35 | bool conv_shortcut = false, 36 | float dropout = 0.0f, 37 | int? temb_channels = 512, 38 | int groups = 32, 39 | int? groups_out = null, 40 | bool pre_norm = true, 41 | float eps = 1e-6f, 42 | string non_linearity = "swish", 43 | bool skip_time_act = false, 44 | string time_embedding_norm = "default", // default, scale_shift, 45 | Tensor? kernel = null, 46 | float output_scale_factor = 1.0f, 47 | bool? use_in_shortcut = null, 48 | bool up = false, 49 | bool down = false, 50 | bool conv_shortcut_bias = true, 51 | int? conv_2d_out_channels = null, 52 | ScalarType dtype = ScalarType.Float32) 53 | : base(nameof(ResnetBlock2D)) 54 | { 55 | this.defaultDtype = dtype; 56 | if (time_embedding_norm == "ada_group" || time_embedding_norm == "spatial") 57 | { 58 | throw new ArgumentException("Invalid time_embedding_norm: " + time_embedding_norm); 59 | } 60 | 61 | this.pre_norm = pre_norm; 62 | this.in_channels = in_channels; 63 | this.out_channels = out_channels ?? in_channels; 64 | this.use_conv_shortcut = conv_shortcut; 65 | this.up = up; 66 | this.down = down; 67 | this.output_scale_factor = output_scale_factor; 68 | this.time_embedding_norm = time_embedding_norm; 69 | this.skip_time_act = skip_time_act; 70 | 71 | groups_out = groups_out ?? 
groups; 72 | 73 | this.norm1 = nn.GroupNorm(num_groups: groups, num_channels: in_channels, eps: eps, affine: true, dtype: dtype); 74 | this.conv1 = nn.Conv2d(in_channels, this.out_channels, kernelSize: 3, stride: 1, padding: 1, bias: true, dtype: dtype); 75 | 76 | if (temb_channels is not null) 77 | { 78 | if (this.time_embedding_norm == "default"){ 79 | this.time_emb_proj = nn.Linear(temb_channels.Value, this.out_channels, dtype: dtype); 80 | } 81 | else if (this.time_embedding_norm == "scale_shift") 82 | { 83 | this.time_emb_proj = nn.Linear(temb_channels.Value, this.out_channels * 2, dtype: dtype); 84 | } 85 | else{ 86 | throw new ArgumentException("Invalid time_embedding_norm: " + time_embedding_norm); 87 | } 88 | } 89 | else{ 90 | this.time_emb_proj = null; 91 | } 92 | 93 | this.norm2 = nn.GroupNorm(num_groups: groups_out.Value, num_channels: this.out_channels, eps: eps, affine: true, dtype: dtype); 94 | this.dropout = nn.Dropout(dropout); 95 | conv_2d_out_channels = conv_2d_out_channels ?? this.out_channels; 96 | this.conv2 = nn.Conv2d(this.out_channels, conv_2d_out_channels.Value, kernelSize: 3, stride: 1, padding: 1, bias: true, dtype: dtype); 97 | this.nonlinearity = Utils.GetActivation(non_linearity); 98 | if (this.up){ 99 | this.upsample = new Upsample2D(channels: in_channels, use_conv: false, dtype: dtype, padding: 1, name: "op"); 100 | } 101 | else if (this.down){ 102 | this.downsample = new Downsample2D(channels: in_channels, use_conv: false, padding: 1, name: "op", dtype: dtype); 103 | } 104 | 105 | this.use_in_shortcut = use_in_shortcut ?? this.in_channels != conv_2d_out_channels; 106 | 107 | if (this.use_in_shortcut) 108 | { 109 | this.conv_shortcut = nn.Conv2d(in_channels, this.out_channels, kernelSize: 1, stride: 1, padding: TorchSharp.Padding.Valid, bias: conv_shortcut_bias, dtype: dtype); 110 | } 111 | } 112 | 113 | public override Tensor forward(Tensor input_tensor, Tensor? 
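// forward contract, for reference: with time_embedding_norm == "default"
// the projected temb is added to the hidden states before norm2; with
// "scale_shift" it is chunked into (scale, shift) and applied as
// h = norm2(h) * (1 + scale) + shift, matching the branches below.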
temb) 114 | { 115 | using var _ = NewDisposeScope(); 116 | var hidden_states = input_tensor; 117 | hidden_states = this.norm1.forward(hidden_states); 118 | hidden_states = this.nonlinearity.forward(hidden_states); 119 | if (this.upsample is not null){ 120 | input_tensor = this.upsample.forward(input_tensor, null); 121 | hidden_states = this.upsample.forward(hidden_states, null); 122 | } 123 | else if (this.downsample is not null){ 124 | input_tensor = this.downsample.forward(input_tensor); 125 | hidden_states = this.downsample.forward(hidden_states); 126 | } 127 | hidden_states = this.conv1.forward(hidden_states); 128 | if (this.time_emb_proj is not null) 129 | { 130 | if (!this.skip_time_act){ 131 | temb = this.nonlinearity.forward(temb!); 132 | } 133 | temb = this.time_emb_proj.forward(temb!); 134 | // temb = self.time_emb_proj(temb)[:, :, None, None] 135 | temb = temb.unsqueeze(2).unsqueeze(3); 136 | } 137 | 138 | if (this.time_embedding_norm == "default"){ 139 | if (temb is not null){ 140 | hidden_states = hidden_states + temb; 141 | } 142 | hidden_states = this.norm2.forward(hidden_states); 143 | } 144 | else if (this.time_embedding_norm == "scale_shift") 145 | { 146 | if (temb is null){ 147 | throw new ArgumentException("Time embedding is None"); 148 | } 149 | 150 | var chunks = temb.chunk(2, 1); 151 | var time_scale = chunks[0]; 152 | var time_shift = chunks[1]; 153 | hidden_states = this.norm2.forward(hidden_states); 154 | hidden_states = hidden_states * (1 + time_scale) + time_shift; 155 | } 156 | else 157 | { 158 | hidden_states = this.norm2.forward(hidden_states); 159 | } 160 | 161 | hidden_states = this.nonlinearity.forward(hidden_states); 162 | hidden_states = this.dropout.forward(hidden_states); 163 | hidden_states = this.conv2.forward(hidden_states); 164 | if (this.conv_shortcut is not null) 165 | { 166 | input_tensor = this.conv_shortcut.forward(input_tensor); 167 | } 168 | 169 | 170 | var output = (input_tensor + hidden_states) / this.output_scale_factor; 171 | 172 | return output.MoveToOuterDisposeScope(); 173 | } 174 | } -------------------------------------------------------------------------------- /UNet/ResnetBlockCondNorm2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | 6 | namespace SD; 7 | 8 | 9 | public class ResnetBlockCondNorm2D : Module 10 | { 11 | private readonly int in_channels; 12 | private readonly int out_channels; 13 | private readonly bool use_conv_shortcut; 14 | private readonly bool up; 15 | private readonly bool down; 16 | private readonly float output_scale_factor; 17 | private readonly string time_embedding_norm; 18 | private ScalarType defaultDtype; 19 | 20 | private readonly Module norm1; 21 | private readonly Module conv1; 22 | private readonly Module norm2; 23 | private readonly Module dropout; 24 | private readonly Module conv2; 25 | private readonly Module nonlinearity; 26 | private readonly Module? upsample; 27 | private readonly Module? downsample; 28 | private readonly Module? conv_shortcut; 29 | 30 | public ResnetBlockCondNorm2D( 31 | int in_channels, 32 | int? out_channels = null, 33 | bool conv_shortcut = false, 34 | float dropout = 0.0f, 35 | int temb_channels = 512, 36 | int groups = 32, 37 | int? groups_out = null, 38 | float eps = 1e-6f, 39 | string non_linearity = "swish", 40 | string time_embedding_norm = "ada_group", 41 | float output_scale_factor = 1.0f, 42 | bool? 
use_in_shortcut = null, 43 | bool up = false, 44 | bool down = false, 45 | bool conv_shortcut_bias = true, 46 | int? conv_2d_out_channels = null, 47 | ScalarType dtype = ScalarType.Float32) 48 | : base(nameof(ResnetBlockCondNorm2D)) 49 | { 50 | this.in_channels = in_channels; 51 | this.out_channels = out_channels ?? in_channels; 52 | this.use_conv_shortcut = conv_shortcut; 53 | this.up = up; 54 | this.down = down; 55 | this.output_scale_factor = output_scale_factor; 56 | this.time_embedding_norm = time_embedding_norm; 57 | 58 | groups_out = groups_out ?? groups; 59 | 60 | if (this.time_embedding_norm == "ada_group") 61 | { 62 | this.norm1 = new AdaGroupNorm( 63 | embedding_dim: temb_channels, 64 | out_dim: this.in_channels, 65 | num_groups: groups, 66 | eps: eps, 67 | dtype: dtype); 68 | this.norm2 = new AdaGroupNorm( 69 | embedding_dim: temb_channels, 70 | out_dim: this.out_channels, 71 | num_groups: groups_out.Value, 72 | eps: eps, 73 | dtype: dtype); 74 | } 75 | else if (this.time_embedding_norm == "spatial") 76 | { 77 | this.norm1 = new SpatialNorm( 78 | f_channels: this.in_channels, 79 | zq_channels: temb_channels, 80 | dtype: dtype); 81 | this.norm2 = new SpatialNorm( 82 | f_channels: this.out_channels, 83 | zq_channels: temb_channels, 84 | dtype: dtype); 85 | } 86 | else 87 | { 88 | throw new ArgumentException("Invalid time_embedding_norm"); 89 | } 90 | 91 | this.conv1 = nn.Conv2d( 92 | inputChannel: this.in_channels, 93 | outputChannel: this.in_channels, 94 | kernelSize: 3, 95 | stride: 1, 96 | padding: 1, 97 | dtype: dtype); 98 | 99 | this.dropout = nn.Dropout(dropout); 100 | conv_2d_out_channels = conv_2d_out_channels ?? this.out_channels; 101 | this.conv2 = nn.Conv2d( 102 | inputChannel: this.in_channels, 103 | outputChannel: conv_2d_out_channels.Value, 104 | kernelSize: 3, 105 | stride: 1, 106 | padding: 1, 107 | dtype: dtype); 108 | this.nonlinearity = Utils.GetActivation(non_linearity); 109 | 110 | this.upsample = null; 111 | this.downsample = null; 112 | if (this.up) 113 | { 114 | this.upsample = new Upsample2D( 115 | channels: this.in_channels, 116 | use_conv: false, 117 | dtype: dtype); 118 | } 119 | else if (this.down) 120 | { 121 | this.downsample = new Downsample2D( 122 | channels: this.in_channels, 123 | use_conv: false, 124 | padding: 1, 125 | name: "op", 126 | dtype: dtype); 127 | } 128 | 129 | this.use_conv_shortcut = use_in_shortcut ?? this.in_channels != conv_2d_out_channels; 130 | 131 | if (this.use_conv_shortcut) 132 | { 133 | this.conv_shortcut = nn.Conv2d( 134 | inputChannel: this.in_channels, 135 | outputChannel: conv_2d_out_channels.Value, 136 | kernelSize: 1, 137 | stride: 1, 138 | padding: Padding.Valid, 139 | bias: conv_shortcut_bias, 140 | dtype: dtype); 141 | } 142 | } 143 | 144 | public override Tensor forward(Tensor input_tensor, Tensor? 
temb) 145 | { 146 | var hidden_states = input_tensor; 147 | hidden_states = this.norm1.forward(hidden_states, temb); 148 | hidden_states = this.nonlinearity.forward(hidden_states); 149 | 150 | if (this.up) 151 | { 152 | input_tensor = this.upsample!.forward(input_tensor, null); 153 | hidden_states = this.upsample!.forward(hidden_states, null); 154 | } 155 | else if (this.down) 156 | { 157 | input_tensor = this.downsample!.forward(input_tensor); 158 | hidden_states = this.downsample!.forward(hidden_states); 159 | } 160 | 161 | hidden_states = this.conv1.forward(hidden_states); 162 | hidden_states = this.norm2.forward(hidden_states, temb); 163 | hidden_states = this.nonlinearity.forward(hidden_states); 164 | 165 | hidden_states = this.dropout.forward(hidden_states); 166 | hidden_states = this.conv2.forward(hidden_states); 167 | 168 | if (this.use_conv_shortcut) 169 | { 170 | input_tensor = this.conv_shortcut!.forward(input_tensor); 171 | } 172 | 173 | return (hidden_states + input_tensor) / this.output_scale_factor; 174 | } 175 | } -------------------------------------------------------------------------------- /UNet/SpatialNorm.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class SpatialNorm : Module 8 | { 9 | private readonly Module norm_layer; 10 | private readonly Module conv_y; 11 | private readonly Module conv_b; 12 | 13 | private readonly ScalarType defaultDtype; 14 | 15 | public SpatialNorm( 16 | int f_channels, 17 | int zq_channels, 18 | ScalarType dtype = ScalarType.Float32) 19 | : base(nameof(SpatialNorm)) 20 | { 21 | this.defaultDtype = dtype; 22 | this.norm_layer = nn.GroupNorm(num_channels: f_channels, num_groups: 32, eps: 1e-6f, affine: true, dtype: dtype); 23 | this.conv_y = nn.Conv2d(inputChannel: zq_channels, outputChannel: f_channels, kernelSize: 1, stride: 1, padding: TorchSharp.Padding.Valid, dtype: dtype); 24 | this.conv_b = nn.Conv2d(inputChannel: zq_channels, outputChannel: f_channels, kernelSize: 1, stride: 1, padding: TorchSharp.Padding.Valid, dtype: dtype); 25 | 26 | RegisterComponents(); 27 | } 28 | 29 | public override Tensor forward(Tensor f, Tensor zq) 30 | { 31 | var f_size = f.shape[-2..]; 32 | zq = nn.functional.interpolate(zq, f_size, mode: InterpolationMode.Nearest); 33 | var norm_f = this.norm_layer.forward(f); 34 | var new_f = norm_f * this.conv_y.forward(zq) + this.conv_b.forward(zq); 35 | 36 | return new_f; 37 | } 38 | } -------------------------------------------------------------------------------- /UNet/Timesteps.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | 3 | public class Timesteps: Module 4 | { 5 | private readonly int num_channels; 6 | private readonly bool flip_sin_to_cos; 7 | private readonly float downscale_freq_shift; 8 | 9 | public Timesteps(int num_channels, bool flip_sin_to_cos, float downscale_freq_shift): base("Timesteps") 10 | { 11 | this.num_channels = num_channels; 12 | this.flip_sin_to_cos = flip_sin_to_cos; 13 | this.downscale_freq_shift = downscale_freq_shift; 14 | } 15 | 16 | public override Tensor forward(Tensor timesteps) 17 | { 18 | var t_emb = Utils.GetTimestepEmbedding( 19 | timesteps, 20 | num_channels, 21 | flip_sin_to_cos: flip_sin_to_cos, 22 | downscale_freq_shift: downscale_freq_shift 23 | ); 24 | return t_emb; 25 | } 26 | } 
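// A short sketch of how Timesteps is typically wired up to produce the
// `temb` tensor the resnet blocks consume. The 320-channel width and the
// follow-up projection to 1280 channels are assumptions based on the SD2
// defaults in UNet2DConditionModelConfig, not code from this file:
//   var time_proj = new Timesteps(num_channels: 320, flip_sin_to_cos: true, downscale_freq_shift: 0f);
//   var t_emb = time_proj.forward(torch.tensor(new long[] { 981 })); // -> [1, 320]
//   // t_emb would then go through TimestepEmbedding to reach 4 * 320 = 1280 channels.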
-------------------------------------------------------------------------------- /UNet/Transformer2DModel.cs: -------------------------------------------------------------------------------- 1 | public class Transformer2DModel : Module 2 | { 3 | private readonly int num_attention_heads; 4 | private readonly int attention_head_dim; 5 | private readonly int? in_channels; 6 | private readonly int? out_channels; 7 | private readonly int num_layers; 8 | private readonly double dropout; 9 | private readonly int norm_num_groups; 10 | private readonly int? cross_attention_dim; 11 | private readonly bool attention_bias; 12 | private readonly int? sample_size; 13 | private readonly int? num_vector_embeds; 14 | private readonly int? patch_size; 15 | private readonly string activation_fn; 16 | private readonly int? num_embeds_ada_norm; 17 | private readonly bool use_linear_projection; 18 | private readonly bool only_cross_attention; 19 | private readonly bool double_self_attention; 20 | private readonly bool upcast_attention; 21 | private readonly string norm_type; 22 | private readonly bool norm_elementwise_affine; 23 | private readonly double norm_eps; 24 | private readonly string attention_type; 25 | private readonly int? caption_channels; 26 | private readonly double? interpolation_scale; 27 | private readonly bool use_additional_conditions = false; 28 | 29 | private readonly bool is_input_continuous; 30 | private readonly bool is_input_vectorized; 31 | private readonly bool is_input_patches; 32 | 33 | private readonly GroupNorm norm; 34 | private readonly Module proj_in; 35 | private readonly ModuleList transformer_blocks; 36 | private readonly Module proj_out; 37 | 38 | public Transformer2DModel( 39 | int num_attention_heads = 16, 40 | int attention_head_dim = 88, 41 | int? in_channels = null, 42 | int? out_channels = null, 43 | int num_layers = 1, 44 | double dropout = 0.0, 45 | int norm_num_groups = 32, 46 | int? cross_attention_dim = null, 47 | bool attention_bias = false, 48 | int? sample_size = null, 49 | int? num_vector_embeds = null, 50 | int? patch_size = null, 51 | string activation_fn = "geglu", 52 | int? num_embeds_ada_norm = null, 53 | bool use_linear_projection = false, 54 | bool only_cross_attention = false, 55 | bool double_self_attention = false, 56 | bool upcast_attention = false, 57 | string norm_type = "layer_norm", 58 | bool norm_elementwise_affine = true, 59 | double norm_eps = 1e-5, 60 | string attention_type = "default", 61 | int? caption_channels = null, 62 | double? interpolation_scale = null, 63 | ScalarType dtype = ScalarType.Float32) 64 | : base(nameof(Transformer2DModel)) 65 | { 66 | this.num_attention_heads = num_attention_heads; 67 | this.attention_head_dim = attention_head_dim; 68 | this.in_channels = in_channels; 69 | this.out_channels = out_channels ?? 
in_channels; 70 | this.num_layers = num_layers; 71 | this.dropout = dropout; 72 | this.norm_num_groups = norm_num_groups; 73 | this.cross_attention_dim = cross_attention_dim; 74 | this.attention_bias = attention_bias; 75 | this.sample_size = sample_size; 76 | this.num_vector_embeds = num_vector_embeds; 77 | this.patch_size = patch_size; 78 | this.activation_fn = activation_fn; 79 | this.num_embeds_ada_norm = num_embeds_ada_norm; 80 | this.use_linear_projection = use_linear_projection; 81 | this.only_cross_attention = only_cross_attention; 82 | this.double_self_attention = double_self_attention; 83 | this.upcast_attention = upcast_attention; 84 | this.norm_type = norm_type; 85 | this.norm_elementwise_affine = norm_elementwise_affine; 86 | this.norm_eps = norm_eps; 87 | this.attention_type = attention_type; 88 | this.caption_channels = caption_channels; 89 | this.interpolation_scale = interpolation_scale; 90 | 91 | if (norm_type != "layer_norm") 92 | { 93 | throw new NotImplementedException("Only layer_norm is supported for now"); 94 | } 95 | 96 | if (patch_size is not null) 97 | { 98 | throw new ArgumentNullException("patch_size"); 99 | } 100 | 101 | if (num_embeds_ada_norm is not null) 102 | { 103 | throw new ArgumentNullException("num_embeds_ada_norm"); 104 | } 105 | var inner_dim = attention_head_dim * num_attention_heads; 106 | this.is_input_continuous = (in_channels is not null) && (patch_size is null); 107 | this.is_input_vectorized = false; 108 | this.is_input_patches = false; 109 | 110 | if (this.is_input_continuous) 111 | { 112 | this.norm = GroupNorm(num_groups: norm_num_groups, num_channels: in_channels!.Value, eps: 1e-6, affine: true, dtype: dtype); 113 | 114 | if (this.use_linear_projection) 115 | { 116 | this.proj_in = Linear(in_channels!.Value, inner_dim, dtype: dtype); 117 | } 118 | else 119 | { 120 | this.proj_in = Conv2d(in_channels!.Value, inner_dim, kernelSize: 1, stride: 1, padding: Padding.Valid, dtype: dtype); 121 | } 122 | } 123 | 124 | this.transformer_blocks = new ModuleList(); 125 | for (int i = 0; i < num_layers; i++) 126 | { 127 | this.transformer_blocks.Add(new BasicTransformerBlock( 128 | dim: inner_dim, 129 | num_attention_heads: num_attention_heads, 130 | attention_head_dim: attention_head_dim, 131 | dropout: dropout, 132 | cross_attention_dim: cross_attention_dim, 133 | activation_fn: activation_fn, 134 | num_embeds_ada_norm: num_embeds_ada_norm, 135 | attention_bias: attention_bias, 136 | only_cross_attention: only_cross_attention, 137 | double_self_attention: double_self_attention, 138 | upcast_attention: upcast_attention, 139 | norm_type: norm_type, 140 | norm_elementwise_affine: norm_elementwise_affine, 141 | norm_eps: norm_eps, 142 | attention_type: attention_type, 143 | dtype: dtype 144 | )); 145 | } 146 | 147 | // 4. Define output layers 148 | if (this.is_input_continuous) 149 | { 150 | if (this.use_linear_projection) 151 | { 152 | this.proj_out = Linear(inner_dim, in_channels!.Value, dtype: dtype); 153 | } 154 | else 155 | { 156 | this.proj_out = Conv2d(inner_dim, in_channels!.Value, kernelSize: 1, stride: 1, padding: Padding.Valid, dtype: dtype); 157 | } 158 | } 159 | } 160 | 161 | public override Transformer2DModelOutput forward( 162 | Tensor hidden_states, 163 | Tensor? encoder_hidden_states = null, 164 | Tensor? timestep = null, 165 | Tensor? class_labels = null, 166 | Tensor? attention_mask = null, 167 | Tensor? 
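// The forward body below converts 2-D boolean-style masks (1 = keep,
// 0 = discard) into additive biases: keep -> 0.0, discard -> -10000.0,
// so masked positions are suppressed by the attention softmax; e.g. a
// mask row [1, 0] becomes the bias row [0.0, -10000.0].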
encoder_attention_mask = null) 168 | { 169 | if (attention_mask is not null && attention_mask.ndim == 2) 170 | { 171 | // assume that mask is expressed as: 172 | // (1 = keep, 0 = discard) 173 | // convert mask into a bias that can be added to attention scores: 174 | // (keep = +0, discard = -10000.0) 175 | attention_mask = (1-attention_mask.to(hidden_states.dtype)) * -10000.0; 176 | attention_mask = attention_mask.unsqueeze(1); 177 | } 178 | 179 | // convert encoder_attention_mask to a bias the same way we do for attention_mask 180 | if (encoder_attention_mask is not null && encoder_attention_mask.ndim == 2) 181 | { 182 | encoder_attention_mask = (1-encoder_attention_mask.to(hidden_states.dtype)) * -10000.0; 183 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1); 184 | } 185 | var residual = hidden_states; 186 | var batch = hidden_states.shape[0]; 187 | var inner_dim = hidden_states.shape[1]; 188 | var height = hidden_states.shape[2]; 189 | var width = hidden_states.shape[3]; 190 | 191 | if (this.is_input_continuous) 192 | { 193 | hidden_states = this.norm.forward(hidden_states); 194 | if (this.use_linear_projection) 195 | { 196 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim); 197 | hidden_states = this.proj_in.forward(hidden_states); 198 | } 199 | else 200 | { 201 | hidden_states = this.proj_in.forward(hidden_states); 202 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim); 203 | } 204 | } 205 | 206 | // 2. Blocks 207 | foreach (var block in this.transformer_blocks) 208 | { 209 | hidden_states = block.forward( 210 | hidden_states, 211 | attention_mask: attention_mask, 212 | encoder_hidden_states: encoder_hidden_states, 213 | encoder_attention_mask: encoder_attention_mask, 214 | timestep: timestep); 215 | } 216 | 217 | // 3. Output 218 | if (this.is_input_continuous) 219 | { 220 | if (this.use_linear_projection) 221 | { 222 | hidden_states = this.proj_out.forward(hidden_states); 223 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous(); 224 | } 225 | else 226 | { 227 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous(); 228 | hidden_states = this.proj_out.forward(hidden_states); 229 | } 230 | 231 | hidden_states = hidden_states + residual; 232 | } 233 | 234 | return new Transformer2DModelOutput(hidden_states); 235 | } 236 | } 237 | 238 | public class Transformer2DModelOutput 239 | { 240 | public Transformer2DModelOutput(Tensor sample) 241 | { 242 | Sample = sample; 243 | } 244 | public Tensor Sample { get; } 245 | } -------------------------------------------------------------------------------- /UNet/UNet2DConditionModelConfig.cs: -------------------------------------------------------------------------------- 1 | using System.Text.Json.Serialization; 2 | 3 | namespace SD; 4 | public class UNet2DConditionModelConfig 5 | { 6 | [JsonPropertyName("sample_size")] 7 | public int? 
SampleSize {get; set;} = null; 8 | 9 | [JsonPropertyName("in_channels")] 10 | public int InChannels {get; set;} = 4; 11 | 12 | [JsonPropertyName("out_channels")] 13 | public int OutChannels {get; set;} = 4; 14 | 15 | [JsonPropertyName("center_input_sample")] 16 | public bool CenterInputSample {get; set;} = false; 17 | 18 | [JsonPropertyName("flip_sin_to_cos")] 19 | public bool FlipSinToCos {get; set;} = true; 20 | 21 | [JsonPropertyName("freq_shift")] 22 | public int FreqShift {get; set;} = 0; 23 | 24 | [JsonPropertyName("down_block_types")] 25 | public string[] DownBlockTypes {get; set;} = new string[] { 26 | "CrossAttnDownBlock2D", 27 | "CrossAttnDownBlock2D", 28 | "CrossAttnDownBlock2D", 29 | "DownBlock2D", 30 | }; 31 | 32 | [JsonPropertyName("mid_block_type")] 33 | public string MidBlockType {get; set;} = "UNetMidBlock2DCrossAttn"; 34 | 35 | [JsonPropertyName("up_block_types")] 36 | public string[] UpBlockTypes {get; set;} = new string[] { 37 | "UpBlock2D", 38 | "CrossAttnUpBlock2D", 39 | "CrossAttnUpBlock2D", 40 | "CrossAttnUpBlock2D", 41 | }; 42 | 43 | [JsonPropertyName("only_cross_attention")] 44 | public bool OnlyCrossAttention {get; set;} = false; 45 | 46 | [JsonPropertyName("block_out_channels")] 47 | public int[] BlockOutChannels {get; set;} = new int[] {320, 640, 1280, 1280}; 48 | 49 | [JsonPropertyName("layers_per_block")] 50 | public int LayersPerBlock {get; set;} = 2; 51 | 52 | [JsonPropertyName("downsample_padding")] 53 | public int DownsamplePadding {get; set;} = 1; 54 | 55 | [JsonPropertyName("mid_block_scale_factor")] 56 | public float MidBlockScaleFactor {get; set;} = 1; 57 | 58 | [JsonPropertyName("dropout")] 59 | public float Dropout {get; set;} = 0.0f; 60 | 61 | [JsonPropertyName("act_fn")] 62 | public string ActFn {get; set;} = "silu"; 63 | 64 | [JsonPropertyName("norm_num_groups")] 65 | public int? NormNumGroups {get; set;} = 32; 66 | 67 | [JsonPropertyName("norm_eps")] 68 | public float NormEps {get; set;} = 1e-5f; 69 | 70 | [JsonPropertyName("cross_attention_dim")] 71 | public int CrossAttentionDim {get; set;} = 1280; 72 | 73 | [JsonPropertyName("transformer_layers_per_block")] 74 | public int TransformerLayersPerBlock {get; set;} = 1; 75 | 76 | [JsonPropertyName("reverse_transformer_layers_per_block")] 77 | public int[]? ReverseTransformerLayersPerBlock {get; set;} = null; 78 | 79 | [JsonPropertyName("encoder_hid_dim")] 80 | public int? EncoderHidDim {get; set;} = null; 81 | 82 | [JsonPropertyName("encoder_hid_dim_type")] 83 | public string? EncoderHidDimType {get; set;} = null; 84 | 85 | [JsonPropertyName("attention_head_dim")] 86 | public int[] AttentionHeadDim {get; set;} = [5, 10, 20, 20]; 87 | 88 | [JsonPropertyName("num_attention_heads")] 89 | public int? NumAttentionHeads {get; set;} = null; 90 | 91 | [JsonPropertyName("dual_cross_attention")] 92 | public bool DualCrossAttention {get; set;} = false; 93 | 94 | [JsonPropertyName("use_linear_projection")] 95 | public bool UseLinearProjection {get; set;} = false; 96 | 97 | [JsonPropertyName("class_embed_type")] 98 | public string? ClassEmbedType {get; set;} = null; 99 | 100 | [JsonPropertyName("addition_embed_type")] 101 | public string? AdditionEmbedType {get; set;} = null; 102 | 103 | [JsonPropertyName("addition_time_embed_dim")] 104 | public int? AdditionTimeEmbedDim {get; set;} = null; 105 | 106 | [JsonPropertyName("num_class_embeds")] 107 | public int? 
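// For reference, a config.json fragment that deserializes onto these
// properties via the JsonPropertyName attributes above; the values shown
// are this class's own defaults, the fragment itself is illustrative:
//   { "in_channels": 4, "out_channels": 4,
//     "block_out_channels": [320, 640, 1280, 1280],
//     "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
//     "attention_head_dim": [5, 10, 20, 20] }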
NumClassEmbeds {get; set;} = null; 108 | 109 | [JsonPropertyName("upcast_attention")] 110 | public bool UpcastAttention {get; set;} = false; 111 | 112 | [JsonPropertyName("resnet_time_scale_shift")] 113 | public string ResnetTimeScaleShift {get; set;} = "default"; 114 | 115 | [JsonPropertyName("resnet_skip_time_act")] 116 | public bool ResnetSkipTimeAct {get; set;} = false; 117 | 118 | [JsonPropertyName("resnet_out_scale_factor")] 119 | public float ResnetOutScaleFactor {get; set;} = 1.0f; 120 | 121 | [JsonPropertyName("time_embedding_type")] 122 | public string TimeEmbeddingType {get; set;} = "positional"; 123 | 124 | [JsonPropertyName("time_embedding_dim")] 125 | public int? TimeEmbeddingDim {get; set;} = null; 126 | 127 | [JsonPropertyName("time_embedding_act_fn")] 128 | public string? TimeEmbeddingActFn {get; set;} = null; 129 | 130 | [JsonPropertyName("timestep_post_act")] 131 | public string? TimestepPostAct {get; set;} = null; 132 | 133 | [JsonPropertyName("time_cond_proj_dim")] 134 | public int? TimeCondProjDim {get; set;} = null; 135 | 136 | [JsonPropertyName("conv_in_kernel")] 137 | public int ConvInKernel {get; set;} = 3; 138 | 139 | [JsonPropertyName("conv_out_kernel")] 140 | public int ConvOutKernel {get; set;} = 3; 141 | 142 | [JsonPropertyName("projection_class_embeddings_input_dim")] 143 | public int? ProjectionClassEmbeddingsInputDim {get; set;} = null; 144 | 145 | [JsonPropertyName("attention_type")] 146 | public string AttentionType {get; set;} = "default"; 147 | 148 | [JsonPropertyName("class_embeddings_concat")] 149 | public bool ClassEmbeddingsConcat {get; set;} = false; 150 | 151 | [JsonPropertyName("mid_block_only_cross_attention")] 152 | public bool MidBlockOnlyCrossAttention {get; set;} = false; 153 | 154 | [JsonPropertyName("cross_attention_norm")] 155 | public string? CrossAttentionNorm {get; set;} = null; 156 | 157 | [JsonPropertyName("addition_embed_type_num_heads")] 158 | public int AdditionEmbedTypeNumHeads {get; set;} = 64; 159 | 160 | public ScalarType DType {get; set;} = ScalarType.Float32; 161 | } 162 | -------------------------------------------------------------------------------- /UNet/UNetMidBlock2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class UNetMidBlock2D : Module 8 | { 9 | private readonly ModuleList attentions; 10 | private readonly ModuleList> resnets; 11 | private readonly bool add_attention; 12 | private readonly ScalarType defaultDtype; 13 | 14 | public UNetMidBlock2D( 15 | int in_channels, 16 | int? temb_channels = null, 17 | float dropout = 0.0f, 18 | int num_layers = 1, 19 | float resnet_eps = 1e-6f, 20 | string resnet_time_scale_shift = "default", 21 | string resnet_act_fn = "swish", 22 | int? resnet_groups = 32, 23 | int? attn_groups = null, 24 | bool resnet_pre_norm = true, 25 | bool add_attention = true, 26 | bool from_deprecated_attn_block = true, 27 | int attention_head_dim = 1, 28 | float output_scale_factor = 1.0f, 29 | ScalarType dtype = ScalarType.Float32) 30 | : base(nameof(UNetMidBlock2D)) 31 | { 32 | resnet_groups = resnet_groups ?? Math.Min(in_channels / 4, 32); 33 | this.add_attention = add_attention; 34 | this.defaultDtype = dtype; 35 | 36 | if (attn_groups is null) 37 | { 38 | attn_groups = resnet_time_scale_shift == "default" ? 
resnet_groups : null; 39 | } 40 | 41 | this.resnets = new ModuleList>(); 42 | if (resnet_time_scale_shift == "spatial") 43 | { 44 | resnets.Add( 45 | new ResnetBlockCondNorm2D( 46 | in_channels: in_channels, 47 | out_channels: in_channels, 48 | temb_channels: temb_channels ?? 512, 49 | eps: resnet_eps, 50 | groups: resnet_groups.Value, 51 | dropout: dropout, 52 | time_embedding_norm: "spatial", 53 | non_linearity: resnet_act_fn, 54 | output_scale_factor: output_scale_factor, 55 | dtype: dtype) 56 | ); 57 | } 58 | else 59 | { 60 | resnets.Add( 61 | new ResnetBlock2D( 62 | in_channels: in_channels, 63 | out_channels: in_channels, 64 | temb_channels: temb_channels, 65 | eps: resnet_eps, 66 | groups: resnet_groups.Value, 67 | dropout: dropout, 68 | time_embedding_norm: resnet_time_scale_shift, 69 | non_linearity: resnet_act_fn, 70 | output_scale_factor: output_scale_factor, 71 | pre_norm: resnet_pre_norm, 72 | dtype: dtype) 73 | ); 74 | } 75 | 76 | var attentions = new ModuleList(); 77 | for(int i = 0; i!= num_layers; ++i) 78 | { 79 | if (add_attention) 80 | { 81 | attentions.Add( 82 | new Attention( 83 | query_dim: in_channels, 84 | heads: in_channels / attention_head_dim, 85 | dim_head: attention_head_dim, 86 | rescale_output_factor: output_scale_factor, 87 | eps: resnet_eps, 88 | norm_num_groups: attn_groups, 89 | spatial_norm_dim: resnet_time_scale_shift == "spatial" ? temb_channels : null, 90 | residual_connection: true, 91 | bias: true, 92 | upcast_softmax: true, 93 | _from_deprecated_attn_block: from_deprecated_attn_block, 94 | dtype: dtype) 95 | ); 96 | } 97 | else 98 | { 99 | attentions.Add(null); 100 | } 101 | 102 | if (resnet_time_scale_shift == "spatial") 103 | { 104 | resnets.Add( 105 | new ResnetBlockCondNorm2D( 106 | in_channels: in_channels, 107 | out_channels: in_channels, 108 | temb_channels: temb_channels ?? 512, 109 | eps: resnet_eps, 110 | groups: resnet_groups!.Value, 111 | dropout: dropout, 112 | time_embedding_norm: "spatial", 113 | non_linearity: resnet_act_fn, 114 | output_scale_factor: output_scale_factor, 115 | dtype: dtype) 116 | ); 117 | } 118 | else 119 | { 120 | resnets.Add( 121 | new ResnetBlock2D( 122 | in_channels: in_channels, 123 | out_channels: in_channels, 124 | temb_channels: temb_channels, 125 | eps: resnet_eps, 126 | groups: resnet_groups!.Value, 127 | dropout: dropout, 128 | time_embedding_norm: resnet_time_scale_shift, 129 | non_linearity: resnet_act_fn, 130 | output_scale_factor: output_scale_factor, 131 | pre_norm: resnet_pre_norm, 132 | dtype: dtype) 133 | ); 134 | } 135 | } 136 | 137 | this.attentions = attentions; 138 | RegisterComponents(); 139 | } 140 | 141 | public override Tensor forward(UNetMidBlock2DInput input) 142 | { 143 | var hidden_states = input.HiddenStates; 144 | var temb = input.Temb; 145 | hidden_states = resnets[0].forward(hidden_states, temb); 146 | foreach (var (attn, resnet) in Enumerable.Zip(attentions, resnets.Skip(1))) 147 | { 148 | if (attn is not null) 149 | { 150 | hidden_states = attn.forward(hidden_states, temb: temb); 151 | } 152 | hidden_states = resnet.forward(hidden_states, temb); 153 | } 154 | 155 | return hidden_states; 156 | } 157 | } -------------------------------------------------------------------------------- /UNet/UNetMidBlock2DCrossAttn.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | 3 | public class UNetMidBlock2DInput 4 | { 5 | public UNetMidBlock2DInput( 6 | Tensor hiddenStates, 7 | Tensor? temb = null, 8 | Tensor? 
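// Note on the mid-block layout built above: UNetMidBlock2D stacks one
// leading resnet followed by num_layers (attention, resnet) pairs, and
// the attention slot is null when add_attention is false, so forward
// degenerates to a pure resnet chain in that case.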
encoderHiddenStates = null, 9 | Tensor? attentionMask = null, 10 | Dictionary? crossAttentionKwargs = null, 11 | Tensor? encoderAttentionMask = null) 12 | { 13 | HiddenStates = hiddenStates; 14 | Temb = temb; 15 | EncoderHiddenStates = encoderHiddenStates; 16 | AttentionMask = attentionMask; 17 | CrossAttentionKwargs = crossAttentionKwargs; 18 | EncoderAttentionMask = encoderAttentionMask; 19 | } 20 | 21 | public Tensor HiddenStates { get; } 22 | 23 | public Tensor? Temb { get; } 24 | 25 | public Tensor? EncoderHiddenStates { get; } 26 | 27 | public Tensor? AttentionMask { get; } 28 | 29 | public Dictionary? CrossAttentionKwargs { get; } 30 | 31 | public Tensor? EncoderAttentionMask { get; } 32 | } 33 | 34 | public class UNetMidBlock2DCrossAttn: Module 35 | { 36 | private readonly bool has_cross_attention; 37 | private readonly int num_attention_heads; 38 | private readonly ModuleList resnets; 39 | private readonly ModuleList attentions; 40 | 41 | public UNetMidBlock2DCrossAttn( 42 | int in_channels, 43 | int temb_channels, 44 | float dropout = 0.0f, 45 | int num_layers = 1, 46 | int[]? transformer_layers_per_block = null, 47 | float resnet_eps = 1e-6f, 48 | string resnet_time_scale_shift = "default", 49 | string resnet_act_fn = "swish", 50 | int resnet_groups = 32, 51 | bool resnet_pre_norm = true, 52 | int num_attention_heads = 1, 53 | float output_scale_factor = 1.0f, 54 | int? cross_attention_dim = 1280, 55 | bool dual_cross_attention = false, 56 | bool use_linear_projection = false, 57 | bool upcast_attention = false, 58 | string attention_type = "default", 59 | ScalarType dtype = ScalarType.Float32 60 | ): base(nameof(UNetMidBlock2DCrossAttn)) 61 | { 62 | ModuleList resnets = new ModuleList(); 63 | ModuleList attentions = new ModuleList(); 64 | 65 | this.has_cross_attention = true; 66 | this.num_attention_heads = num_attention_heads; 67 | transformer_layers_per_block = transformer_layers_per_block ?? 
Enumerable.Repeat(num_layers, 1).ToArray(); 68 | 69 | resnets.Add( 70 | new ResnetBlock2D( 71 | in_channels: in_channels, 72 | out_channels: in_channels, 73 | temb_channels: temb_channels, 74 | eps: (float)resnet_eps, 75 | groups: resnet_groups, 76 | dropout: (float)dropout, 77 | time_embedding_norm: resnet_time_scale_shift, 78 | non_linearity: resnet_act_fn, 79 | output_scale_factor: (float)output_scale_factor, 80 | pre_norm: resnet_pre_norm, 81 | dtype: dtype)); 82 | for(int i = 0; i != num_layers; ++i) 83 | { 84 | resnets.Add( 85 | new ResnetBlock2D( 86 | in_channels: in_channels, 87 | out_channels: in_channels, 88 | temb_channels: temb_channels, 89 | eps: (float)resnet_eps, 90 | groups: resnet_groups, 91 | dropout: (float)dropout, 92 | time_embedding_norm: resnet_time_scale_shift, 93 | non_linearity: resnet_act_fn, 94 | output_scale_factor: (float)output_scale_factor, 95 | pre_norm: resnet_pre_norm, 96 | dtype: dtype)); 97 | 98 | if (!dual_cross_attention) 99 | { 100 | attentions.Add( 101 | new Transformer2DModel( 102 | num_attention_heads: num_attention_heads, 103 | attention_head_dim: in_channels / num_attention_heads, 104 | in_channels: in_channels, 105 | num_layers: transformer_layers_per_block[i], 106 | cross_attention_dim: cross_attention_dim, 107 | norm_num_groups: resnet_groups, 108 | use_linear_projection: use_linear_projection, 109 | upcast_attention: upcast_attention, 110 | attention_type: attention_type, 111 | dtype: dtype)); 112 | } 113 | else 114 | { 115 | attentions.Add( 116 | new DualTransformer2DModel( 117 | num_attention_heads: num_attention_heads, 118 | attention_head_dim: in_channels / num_attention_heads, 119 | in_channels: in_channels, 120 | num_layers: 1, 121 | cross_attention_dim: cross_attention_dim, 122 | norm_num_groups: resnet_groups, 123 | dtype: dtype)); 124 | } 125 | } 126 | 127 | this.resnets = resnets; 128 | this.attentions = attentions; 129 | } 130 | 131 | public override Tensor forward(UNetMidBlock2DInput input) 132 | { 133 | var hiddenStates = input.HiddenStates; 134 | var temb = input.Temb; 135 | var encoderHiddenStates = input.EncoderHiddenStates; 136 | var attentionMask = input.AttentionMask; 137 | var crossAttentionKwargs = input.CrossAttentionKwargs; 138 | var encoderAttentionMask = input.EncoderAttentionMask; 139 | 140 | hiddenStates = this.resnets[0].forward(hiddenStates, temb); 141 | 142 | foreach (var (resnet, attention) in this.resnets.Skip(1).Zip(this.attentions)) 143 | { 144 | if (attention is Transformer2DModel transformer) 145 | { 146 | hiddenStates = transformer.forward( 147 | hiddenStates, 148 | encoder_hidden_states: encoderHiddenStates, 149 | attention_mask: attentionMask, 150 | encoder_attention_mask: encoderAttentionMask).Sample; 151 | } 152 | else if (attention is DualTransformer2DModel dualTransformer) 153 | { 154 | hiddenStates = dualTransformer.forward( 155 | hiddenStates, 156 | encoder_hidden_states: encoderHiddenStates ?? throw new ArgumentNullException(nameof(encoderHiddenStates)), 157 | attention_mask: attentionMask).Sample; 158 | } 159 | 160 | hiddenStates = resnet.forward(hiddenStates, temb); 161 | } 162 | 163 | return hiddenStates; 164 | } 165 | } -------------------------------------------------------------------------------- /UNet/UpBlock2D.cs: -------------------------------------------------------------------------------- 1 | namespace SD; 2 | 3 | public class UpBlock2D : Module 4 | { 5 | public UpBlock2D( 6 | int in_channels, 7 | int prev_output_channel, 8 | int out_channels, 9 | int temb_channels, 10 | int? 
resolution_idx = null,
11 | float dropout = 0.0f,
12 | int num_layers = 1,
13 | float resnet_eps = 1e-6f,
14 | string resnet_time_scale_shift = "default",
15 | string resnet_act_fn = "swish",
16 | int? resnet_groups = 32,
17 | bool resnet_pre_norm = true,
18 | float output_scale_factor = 1.0f,
19 | bool add_upsample = true,
20 | ScalarType dtype = ScalarType.Float32
21 | ): base(nameof(UpBlock2D))
22 | {
23 | var resnets = new ModuleList<ResnetBlock2D>();
24 | for(int i = 0; i != num_layers; ++i)
25 | {
26 | var res_skip_channels = i == num_layers - 1 ? in_channels : out_channels;
27 | var resnet_in_channels = i == 0 ? prev_output_channel : out_channels;
28 | 
29 | resnets.Add(
30 | new ResnetBlock2D(
31 | in_channels: resnet_in_channels + res_skip_channels,
32 | out_channels: out_channels,
33 | temb_channels: temb_channels,
34 | eps: resnet_eps,
35 | groups: resnet_groups ?? 32,
36 | dropout: dropout,
37 | time_embedding_norm: resnet_time_scale_shift,
38 | non_linearity: resnet_act_fn,
39 | output_scale_factor: output_scale_factor,
40 | pre_norm: resnet_pre_norm,
41 | dtype: dtype)
42 | );
43 | }
44 | 
45 | this.resnets = resnets;
46 | 
47 | if (add_upsample)
48 | {
49 | this.upsamplers = new ModuleList<Upsample2D>();
50 | this.upsamplers.Add(new Upsample2D(
51 | channels: out_channels,
52 | use_conv: true,
53 | out_channels: out_channels,
54 | dtype: dtype
55 | ));
56 | }
57 | 
58 | this.resolution_idx = resolution_idx;
59 | }
60 | 
61 | private readonly ModuleList<ResnetBlock2D> resnets;
62 | private readonly ModuleList<Upsample2D>? upsamplers;
63 | private readonly int? resolution_idx;
64 | public ModuleList<ResnetBlock2D> Resnets => resnets;
65 | 
66 | public override Tensor forward(UpBlock2DInput x)
67 | {
68 | var hidden_states = x.HiddenStates;
69 | // Consume the skip connections from the matching down block in reverse
70 | // order. The working copy must shrink each iteration; re-reading
71 | // x.ResHiddenStatesTuple[^1] would hand every resnet the same residual.
72 | var res_hidden_states_tuple = x.ResHiddenStatesTuple;
73 | foreach (var resnet in resnets)
74 | {
75 | var res_hidden_states = res_hidden_states_tuple[^1];
76 | res_hidden_states_tuple = res_hidden_states_tuple[..^1];
77 | 
78 | hidden_states = torch.cat(new Tensor[] {hidden_states, res_hidden_states}, 1);
79 | hidden_states = resnet.forward(hidden_states, x.Temb);
80 | }
81 | 
82 | if (upsamplers is not null)
83 | {
84 | foreach (var upsample in upsamplers)
85 | {
86 | hidden_states = upsample.forward(hidden_states, x.UpsampleSize);
87 | }
88 | }
89 | 
90 | return hidden_states;
91 | }
92 | }
-------------------------------------------------------------------------------- /UNet/UpDecoderBlock2D.cs:
1 | using static TorchSharp.torch.nn;
2 | using static TorchSharp.torch;
3 | using TorchSharp.Modules;
4 | 
5 | namespace SD;
6 | 
7 | public class UpDecoderBlock2D : Module<Tensor, Tensor?, Tensor>
8 | {
9 | private readonly int in_channels;
10 | private readonly int out_channels;
11 | private readonly int? resolution_idx;
12 | private readonly float dropout;
13 | private readonly int num_layers;
14 | private readonly float resnet_eps;
15 | private readonly string resnet_time_scale_shift;
16 | private readonly string resnet_act_fn;
17 | private readonly int resnet_groups;
18 | private readonly bool resnet_pre_norm;
19 | private readonly float output_scale_factor;
20 | private readonly bool add_upsample;
21 | private readonly int? temb_channels;
22 | private readonly ScalarType dtype;
23 | 
24 | private readonly ModuleList<Module<Tensor, Tensor?, Tensor>> resnets;
25 | private readonly ModuleList<Module<Tensor, long[]?, Tensor>>? upsamplers = null;
26 | public UpDecoderBlock2D(
27 | int in_channels,
28 | int out_channels,
29 | int?
resolution_idx = null, 30 | float dropout = 0.0f, 31 | int num_layers = 1, 32 | float resnet_eps = 1e-6f, 33 | string resnet_time_scale_shift = "default", 34 | string resnet_act_fn = "swish", 35 | int resnet_groups = 32, 36 | bool resnet_pre_norm = true, 37 | float output_scale_factor = 1.0f, 38 | bool add_upsample = true, 39 | int? temb_channels = null, 40 | ScalarType dtype = ScalarType.Float32) 41 | : base(nameof(UpDecoderBlock2D)) 42 | { 43 | this.in_channels = in_channels; 44 | this.out_channels = out_channels; 45 | this.resolution_idx = resolution_idx; 46 | this.dropout = dropout; 47 | this.num_layers = num_layers; 48 | this.resnet_eps = resnet_eps; 49 | this.resnet_time_scale_shift = resnet_time_scale_shift; 50 | this.resnet_act_fn = resnet_act_fn; 51 | this.resnet_groups = resnet_groups; 52 | this.resnet_pre_norm = resnet_pre_norm; 53 | this.output_scale_factor = output_scale_factor; 54 | this.add_upsample = add_upsample; 55 | this.temb_channels = temb_channels; 56 | this.dtype = dtype; 57 | 58 | this.resnets = new ModuleList>(); 59 | for(int i = 0; i!= num_layers; ++i) 60 | { 61 | var input_channels = i == 0 ? in_channels : out_channels; 62 | if (resnet_time_scale_shift == "spatial") 63 | { 64 | resnets.Add( 65 | new ResnetBlockCondNorm2D( 66 | in_channels: input_channels, 67 | out_channels: out_channels, 68 | temb_channels: temb_channels ?? 512, 69 | eps: resnet_eps, 70 | groups: resnet_groups, 71 | time_embedding_norm: "spatial", 72 | non_linearity: resnet_act_fn, 73 | output_scale_factor: output_scale_factor, 74 | dtype: dtype) 75 | ); 76 | } 77 | else 78 | { 79 | resnets.Add( 80 | new ResnetBlock2D( 81 | in_channels: input_channels, 82 | out_channels: out_channels, 83 | temb_channels: temb_channels, 84 | groups: resnet_groups, 85 | pre_norm: resnet_pre_norm, 86 | eps: resnet_eps, 87 | non_linearity: resnet_act_fn, 88 | time_embedding_norm: resnet_time_scale_shift, 89 | output_scale_factor: output_scale_factor, 90 | dtype: dtype) 91 | ); 92 | } 93 | } 94 | 95 | if (add_upsample) 96 | { 97 | this.upsamplers = new ModuleList>(); 98 | this.upsamplers.Add(new Upsample2D( 99 | channels: out_channels, 100 | use_conv: true, 101 | out_channels: out_channels, 102 | dtype: dtype 103 | )); 104 | } 105 | 106 | this.resolution_idx = resolution_idx; 107 | } 108 | 109 | public override Tensor forward(Tensor hidden_states, Tensor? temb) 110 | { 111 | foreach (var resnet in resnets) 112 | { 113 | hidden_states = resnet.forward(hidden_states, temb); 114 | } 115 | 116 | if (upsamplers != null) 117 | { 118 | foreach (var upsample in upsamplers) 119 | { 120 | hidden_states = upsample.forward(hidden_states, null); 121 | } 122 | } 123 | 124 | return hidden_states; 125 | } 126 | } -------------------------------------------------------------------------------- /UNet/Upsample2D.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | 5 | namespace SD; 6 | 7 | public class Upsample2D : Module 8 | { 9 | private readonly int channels; 10 | private readonly bool use_conv; 11 | private readonly int out_channels; 12 | private readonly bool use_conv_transpose; 13 | private readonly string conv_name; 14 | private readonly bool interpolate; 15 | 16 | private readonly Module? norm; 17 | private readonly Module? conv; 18 | 19 | private readonly Module? 
Conv2d_0; 20 | private readonly ScalarType defaultDtype; 21 | 22 | public Upsample2D( 23 | int channels, 24 | bool use_conv = false, 25 | bool use_conv_transpose = false, 26 | int? out_channels = null, 27 | string name = "conv", 28 | int? kernel_size = null, 29 | int padding = 1, 30 | string norm_type = null, 31 | float? eps = null, 32 | bool? elementwise_affine = null, 33 | bool bias = true, 34 | bool interpolate = true, 35 | ScalarType dtype = ScalarType.Float32) 36 | : base(nameof(Upsample2D)) 37 | { 38 | this.channels = channels; 39 | this.out_channels = out_channels ?? channels; 40 | this.use_conv = use_conv; 41 | this.use_conv_transpose = use_conv_transpose; 42 | this.conv_name = name; 43 | this.interpolate = interpolate; 44 | this.defaultDtype = dtype; 45 | 46 | if (norm_type is "ln_norm") 47 | { 48 | this.norm = nn.LayerNorm(normalized_shape: this.channels, eps: eps ?? 1e-5, elementwise_affine: elementwise_affine?? true, dtype: dtype); 49 | } 50 | else if (norm_type is null) 51 | { 52 | this.norm = null; 53 | } 54 | else 55 | { 56 | throw new ArgumentException("Invalid norm type: " + norm_type); 57 | } 58 | 59 | Module? conv; 60 | if (use_conv_transpose) 61 | { 62 | conv = nn.ConvTranspose2d(inputChannel: this.channels, outputChannel: this.out_channels, kernelSize: kernel_size ?? 4, stride: 2, padding: padding, bias: bias, dtype: dtype); 63 | } 64 | else if (use_conv) 65 | { 66 | conv = nn.Conv2d(inputChannel: this.channels, outputChannel: this.out_channels, kernelSize: kernel_size ?? 3, stride: 1, padding: padding, bias: bias, dtype: dtype); 67 | } 68 | else 69 | { 70 | conv = null; 71 | } 72 | 73 | if (this.conv_name is "conv") 74 | { 75 | this.conv = conv ?? throw new ArgumentException("Invalid conv type: " + this.conv_name); 76 | } 77 | else 78 | { 79 | this.Conv2d_0 = conv ?? throw new ArgumentException("Invalid conv type: " + this.conv_name); 80 | } 81 | 82 | } 83 | 84 | public override Tensor forward(Tensor hidden_states, long[]? 
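// The forward body below temporarily upcasts (b)float16 activations to
// float32 around the nearest-neighbour interpolation and casts back
// afterwards; nearest upsampling has historically been unsupported in
// half precision on some backends, which is presumably why the round
// trip is kept here (an assumption about intent, not stated in this file).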
output_size) 85 | { 86 | if (this.norm != null) 87 | { 88 | hidden_states = this.norm.forward(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2); 89 | } 90 | 91 | if (this.use_conv_transpose) 92 | { 93 | return this.conv!.forward(hidden_states); 94 | } 95 | 96 | var dtype = hidden_states.dtype; 97 | if (dtype == ScalarType.BFloat16 || dtype == ScalarType.Float16){ 98 | hidden_states = hidden_states.to_type(ScalarType.Float32); 99 | } 100 | 101 | if (hidden_states.shape[0] >= 64){ 102 | hidden_states = hidden_states.contiguous(); 103 | } 104 | 105 | if (this.interpolate){ 106 | if (output_size is null){ 107 | hidden_states = nn.functional.interpolate(hidden_states, scale_factor: [2, 2], mode: InterpolationMode.Nearest); 108 | } 109 | else{ 110 | hidden_states = nn.functional.interpolate(hidden_states, size: output_size, mode: InterpolationMode.Nearest); 111 | } 112 | } 113 | 114 | if (dtype == ScalarType.BFloat16 || dtype == ScalarType.Float16){ 115 | hidden_states = hidden_states.to_type(dtype); 116 | } 117 | 118 | if (this.use_conv) 119 | { 120 | if (this.conv_name is "conv") 121 | { 122 | return this.conv!.forward(hidden_states); 123 | } 124 | else 125 | { 126 | return this.Conv2d_0!.forward(hidden_states); 127 | } 128 | } 129 | 130 | return hidden_states; 131 | } 132 | } -------------------------------------------------------------------------------- /VAE/AutoencoderKL.cs: -------------------------------------------------------------------------------- 1 | using static TorchSharp.torch.nn; 2 | using static TorchSharp.torch; 3 | using TorchSharp.Modules; 4 | using TorchSharp; 5 | using System.Text.Json; 6 | using TorchSharp.PyBridge; 7 | 8 | namespace SD; 9 | 10 | public class AutoencoderKL : Module, IModelConfigLoader 11 | { 12 | private readonly int in_channels; 13 | private readonly int out_channels; 14 | private readonly string[] down_block_types; 15 | private readonly string[] up_block_types; 16 | private readonly int[] block_out_channels; 17 | private readonly int layers_per_block; 18 | private readonly string act_fn; 19 | 20 | private readonly int latent_channels; 21 | private readonly int norm_num_groups; 22 | private readonly int sample_size; 23 | private readonly float scaling_factor; 24 | private readonly float[]? latents_mean; 25 | private readonly float[]? latents_std; 26 | private readonly bool force_upcast; 27 | 28 | private readonly Encoder encoder; 29 | 30 | private readonly Decoder decoder; 31 | 32 | private readonly Conv2d quant_conv; 33 | private readonly Conv2d post_quant_conv; 34 | private readonly ScalarType dtype; 35 | 36 | /// 37 | /// Create an AutoencoderKL model. 
38 | /// 39 | /// AutoencoderKL config 40 | /// the default dtype to use 41 | public AutoencoderKL( 42 | Config config, 43 | ScalarType dtype = ScalarType.Float32) 44 | : base(nameof(AutoencoderKL)) 45 | { 46 | this.in_channels = config!.InChannels; 47 | this.out_channels = config!.OutChannels; 48 | this.down_block_types = config!.DownBlockTypes; 49 | this.up_block_types = config!.UpBlockTypes; 50 | this.block_out_channels = config!.BlockOutChannels; 51 | this.layers_per_block = config!.LayersPerBlock; 52 | this.act_fn = config!.ActivationFunction; 53 | this.latent_channels = config!.LatentChannels; 54 | this.norm_num_groups = config!.NormNumGroups; 55 | this.sample_size = config!.SampleSize; 56 | this.scaling_factor = config!.ScalingFactor; 57 | this.latents_mean = config!.LatentsMean; 58 | this.latents_std = config!.LatentsStd; 59 | this.force_upcast = config!.ForceUpcast; 60 | this.dtype = dtype; 61 | 62 | 63 | 64 | this.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, kernelSize: 1, padding: Padding.Valid, dtype: this.dtype); 65 | this.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, kernelSize: 1, padding: Padding.Same, dtype: this.dtype); 66 | 67 | this.Config = config; 68 | 69 | this.encoder = new Encoder( 70 | inChannels: in_channels, 71 | outChannels: latent_channels, 72 | downBlockTypes: down_block_types, 73 | blockOutChannels: block_out_channels, 74 | layersPerBlock: layers_per_block, 75 | activationFunction: act_fn, 76 | mid_block_from_deprecated_attn_block: config.DecoderMidBlockFromDeprecatedAttnBlock, 77 | normNumGroups: norm_num_groups, 78 | doubleZ: true, 79 | dtype: this.dtype); 80 | 81 | this.decoder = new Decoder( 82 | in_channels: latent_channels, 83 | out_channels: out_channels, 84 | up_block_types: up_block_types, 85 | block_out_channels: block_out_channels, 86 | layers_per_block: layers_per_block, 87 | norm_num_groups: norm_num_groups, 88 | act_fn: act_fn, 89 | mid_block_from_deprecated_attn_block: config.DecoderMidBlockFromDeprecatedAttnBlock, 90 | mid_block_add_attention: true, 91 | dtype: this.dtype); 92 | 93 | RegisterComponents(); 94 | } 95 | 96 | public Decoder Decoder => this.decoder; 97 | 98 | public Encoder Encoder => this.encoder; 99 | 100 | public Config Config {get;} 101 | 102 | public DiagonalGaussianDistribution encode(Tensor x) 103 | { 104 | var h = this.encoder.forward(x); 105 | var moments = this.quant_conv.forward(h); 106 | var posterior = new DiagonalGaussianDistribution(moments); 107 | 108 | return posterior; 109 | } 110 | 111 | public Tensor _decode(Tensor z) 112 | { 113 | z = this.post_quant_conv.forward(z); 114 | var dec = this.decoder.forward(z); 115 | 116 | return dec; 117 | } 118 | 119 | public Tensor decode(Tensor z) 120 | { 121 | var dec = this._decode(z); 122 | return dec; 123 | } 124 | 125 | public override Tensor forward(Tensor sample, bool sample_posterior = false, Generator? 
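// A hypothetical encode/decode round trip using the members above; the
// image shape and the 0.18215 scaling factor (the Config default below)
// are illustrative assumptions, and the path reuses Config's default
// NameOrPath:
//   var vae = AutoencoderKL.FromPretrained("hf-models/stable-diffusion-v2-768x768/vae");
//   var posterior = vae.encode(image);                 // image: [1, 3, 768, 768]
//   var latents = posterior.Sample(null) * 0.18215f;   // latents: [1, 4, 96, 96]
//   var recon = vae.decode(latents / 0.18215f);        // recon: [1, 3, 768, 768]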
126 |     {
127 |         var x = sample;
128 |         var posterior = this.encode(x);
129 |         Tensor z;
130 |         if (sample_posterior)
131 |         {
132 |             z = posterior.Sample(generator);
133 |         }
134 |         else
135 |         {
136 |             z = posterior.Mode();
137 |         }
138 | 
139 |         var dec = this._decode(z);
140 | 
141 |         return dec;
142 |     }
143 | 
144 |     public static AutoencoderKL FromPretrained(
145 |         string pretrainedModelNameOrPath,
146 |         string configName = "config.json",
147 |         string modelWeightName = "diffusion_pytorch_model",
148 |         bool useSafeTensor = true,
149 |         ScalarType torchDtype = ScalarType.Float32
150 |     )
151 |     {
152 |         var configPath = Path.Combine(pretrainedModelNameOrPath, configName);
153 |         var json = File.ReadAllText(configPath);
154 |         var config = JsonSerializer.Deserialize<Config>(json) ?? throw new ArgumentNullException("config");
155 |         // if dtype is fp16, default FromDeprecatedAttnBlock to false
156 |         if (torchDtype == ScalarType.Float16)
157 |         {
158 |             config.DecoderMidBlockFromDeprecatedAttnBlock = false;
159 |             config.EncoderMidBlockAddAttention = false;
160 |         }
161 |         var autoEncoderKL = new AutoencoderKL(config, torchDtype);
162 | 
163 |         modelWeightName = (useSafeTensor, torchDtype) switch
164 |         {
165 |             (true, ScalarType.Float32) => $"{modelWeightName}.safetensors",
166 |             (true, ScalarType.Float16) => $"{modelWeightName}.fp16.safetensors",
167 |             (false, ScalarType.Float32) => $"{modelWeightName}.bin",
168 |             (false, ScalarType.Float16) => $"{modelWeightName}.fp16.bin",
169 |             _ => throw new ArgumentException("Invalid arguments for useSafeTensor and torchDtype")
170 |         };
171 | 
172 |         var location = Path.Combine(pretrainedModelNameOrPath, modelWeightName);
173 | 
174 |         var loadedParameters = new Dictionary<string, bool>();
175 |         autoEncoderKL.load_safetensors(location, strict: false, loadedParameters: loadedParameters);
176 | 
177 |         return autoEncoderKL;
178 |     }
179 | 
180 |     public AutoencoderKL LoadFromModelConfig(
181 |         string pretrainedModelNameOrPath,
182 |         string configName = "config.json",
183 |         string modelWeightName = "diffusion_pytorch_model",
184 |         bool useSafeTensor = true,
185 |         ScalarType torchDtype = ScalarType.Float32)
186 |     {
187 |         return AutoencoderKL.FromPretrained(pretrainedModelNameOrPath, configName, modelWeightName, useSafeTensor, torchDtype);
188 |     }
189 | }
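A minimal usage sketch for the class above. It is illustrative only: the folder path is a placeholder, the 0.18215f scale is Config.ScalingFactor, and it assumes the project's usual `using static TorchSharp.torch;` is in scope.

// Load fp32 safetensors weights from a local diffusers-style VAE folder (placeholder path).
var vae = AutoencoderKL.FromPretrained("path/to/vae");
// Encode a preprocessed image in [-1, 1] into a latent distribution, then sample and rescale.
var image = torch.randn(1, 3, 768, 768);          // stand-in for a real image tensor
var posterior = vae.encode(image);
var latents = posterior.Sample() * 0.18215f;      // diffusers-style latent scaling (Config.ScalingFactor)
// Decode back to pixel space, undoing the scaling first.
var reconstruction = vae.decode(latents / 0.18215f);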
--------------------------------------------------------------------------------
/VAE/Config.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Text.Json;
3 | using System.Text.Json.Serialization;
4 | 
5 | namespace SD
6 | {
7 |     public class Config
8 |     {
9 |         [JsonPropertyName("_class_name")]
10 |         public string ClassName { get; set; } = "AutoencoderKL";
11 | 
12 |         [JsonPropertyName("_diffusers_version")]
13 |         public string DiffusersVersion { get; set; } = "0.8.0";
14 | 
15 |         [JsonPropertyName("_name_or_path")]
16 |         public string NameOrPath { get; set; } = "hf-models/stable-diffusion-v2-768x768/vae";
17 | 
18 |         [JsonPropertyName("act_fn")]
19 |         public string ActivationFunction { get; set; } = "silu";
20 | 
21 |         [JsonPropertyName("block_out_channels")]
22 |         public int[] BlockOutChannels { get; set; } = { 128, 256, 512, 512 };
23 | 
24 |         [JsonPropertyName("down_block_types")]
25 |         public string[] DownBlockTypes { get; set; } = { "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D" };
26 | 
27 |         [JsonPropertyName("in_channels")]
28 |         public int InChannels { get; set; } = 3;
29 | 
30 |         [JsonPropertyName("latent_channels")]
31 |         public int LatentChannels { get; set; } = 4;
32 | 
33 |         [JsonPropertyName("layers_per_block")]
34 |         public int LayersPerBlock { get; set; } = 2;
35 | 
36 |         [JsonPropertyName("norm_num_groups")]
37 |         public int NormNumGroups { get; set; } = 32;
38 | 
39 |         [JsonPropertyName("out_channels")]
40 |         public int OutChannels { get; set; } = 3;
41 | 
42 |         [JsonPropertyName("sample_size")]
43 |         public int SampleSize { get; set; } = 768;
44 | 
45 |         [JsonPropertyName("up_block_types")]
46 |         public string[] UpBlockTypes { get; set; } = { "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D" };
47 | 
48 |         [JsonPropertyName("scaling_factor")]
49 |         public float ScalingFactor { get; set; } = 0.18215f;
50 | 
51 |         [JsonPropertyName("latents_mean")]
52 |         public float[]? LatentsMean { get; set; }
53 | 
54 |         [JsonPropertyName("latents_std")]
55 |         public float[]? LatentsStd { get; set; }
56 | 
57 |         [JsonPropertyName("force_upcast")]
58 |         public bool ForceUpcast { get; set; } = true;
59 | 
60 |         [JsonPropertyName("decoder_mid_block_from_deprecated_attn_block")]
61 |         public bool DecoderMidBlockFromDeprecatedAttnBlock { get; set; } = true;
62 | 
63 |         [JsonPropertyName("encoder_mid_block_add_attention")]
64 |         public bool EncoderMidBlockAddAttention { get; set; } = true;
65 | 
66 |         public override string ToString()
67 |         {
68 |             return JsonSerializer.Serialize(this, new JsonSerializerOptions { WriteIndented = true });
69 |         }
70 |     }
71 | }
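For reference, an abridged config.json that deserializes into the defaults above (the keys follow the JsonPropertyName attributes):

{
  "_class_name": "AutoencoderKL",
  "act_fn": "silu",
  "block_out_channels": [128, 256, 512, 512],
  "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 768,
  "scaling_factor": 0.18215,
  "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"]
}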
--------------------------------------------------------------------------------
/VAE/Decoder.cs:
--------------------------------------------------------------------------------
1 | using static TorchSharp.torch.nn;
2 | using static TorchSharp.torch;
3 | using TorchSharp.Modules;
4 | using TorchSharp;
5 | 
6 | namespace SD;
7 | 
8 | public class Decoder : Module<Tensor, Tensor?, Tensor>
9 | {
10 |     private readonly int in_channels;
11 |     private readonly int out_channels;
12 |     private readonly string[] up_block_types;
13 |     private readonly int[] block_out_channels;
14 |     private readonly int layers_per_block;
15 |     private readonly int norm_num_groups;
16 |     private readonly string act_fn;
17 |     private readonly string norm_type;
18 |     private readonly bool mid_block_add_attention;
19 |     private readonly ScalarType dtype;
20 | 
21 |     private readonly Conv2d conv_in;
22 |     private readonly Module conv_norm_out;
23 |     private readonly Module<Tensor, Tensor> conv_act;
24 |     private readonly Module<Tensor, Tensor> conv_out;
25 |     private readonly UNetMidBlock2D mid_block;
26 |     private readonly ModuleList<Module<Tensor, Tensor?, Tensor>> up_blocks;
27 |     public Decoder(
28 |         int in_channels = 3,
29 |         int out_channels = 3,
30 |         string[]? up_block_types = null,
31 |         int[]? block_out_channels = null,
32 |         int layers_per_block = 2,
33 |         int norm_num_groups = 32,
34 |         string act_fn = "silu",
35 |         string norm_type = "group",
36 |         bool mid_block_add_attention = true,
37 |         bool mid_block_from_deprecated_attn_block = true,
38 |         ScalarType dtype = ScalarType.Float32)
39 |         : base(nameof(Decoder))
40 |     {
41 |         up_block_types = up_block_types ?? new string[] { nameof(UpDecoderBlock2D) };
42 |         block_out_channels = block_out_channels ?? new int[] { 64 };
43 |         this.dtype = dtype;
44 |         this.in_channels = in_channels;
45 |         this.out_channels = out_channels;
46 |         this.up_block_types = up_block_types;
47 |         this.block_out_channels = block_out_channels;
48 |         this.layers_per_block = layers_per_block;
49 |         this.norm_num_groups = norm_num_groups;
50 |         this.act_fn = act_fn;
51 |         this.norm_type = norm_type;
52 |         this.mid_block_add_attention = mid_block_add_attention;
53 | 
54 |         this.conv_in = torch.nn.Conv2d(this.in_channels, this.block_out_channels[^1], kernelSize: 3, stride: 1, padding: 1, dtype: this.dtype);
55 |         int? temb_channels = norm_type == "spatial" ? in_channels : null;
56 | 
57 |         // mid
58 |         this.mid_block = new UNetMidBlock2D(
59 |             in_channels: this.block_out_channels[^1],
60 |             resnet_eps: 1e-6f,
61 |             resnet_act_fn: act_fn,
62 |             output_scale_factor: 1.0f,
63 |             resnet_time_scale_shift: norm_type == "group" ? "default" : norm_type,
64 |             attention_head_dim: this.block_out_channels[^1],
65 |             resnet_groups: norm_num_groups,
66 |             temb_channels: temb_channels,
67 |             add_attention: mid_block_add_attention,
68 |             from_deprecated_attn_block: mid_block_from_deprecated_attn_block,
69 |             dtype: this.dtype);
70 | 
71 |         // up
72 |         var reversed_block_out_channels = block_out_channels.Reverse().ToArray();
73 |         var output_channel = reversed_block_out_channels[0];
74 |         this.up_blocks = new ModuleList<Module<Tensor, Tensor?, Tensor>>();
75 |         for (int i = 0; i < up_block_types.Length; i++)
76 |         {
77 |             var prev_output_channel = output_channel;
78 |             output_channel = reversed_block_out_channels[i];
79 | 
80 |             var is_final_block = i == up_block_types.Length - 1;
81 |             var up_block = new UpDecoderBlock2D(
82 |                 in_channels: prev_output_channel,
83 |                 out_channels: output_channel,
84 |                 add_upsample: !is_final_block,
85 |                 num_layers: layers_per_block + 1,
86 |                 resnet_eps: 1e-6f,
87 |                 resnet_act_fn: act_fn,
88 |                 resnet_groups: norm_num_groups,
89 |                 temb_channels: temb_channels,
90 |                 dtype: this.dtype);
91 | 
92 |             this.up_blocks.Add(up_block);
93 |             prev_output_channel = output_channel;
94 |         }
95 | 
96 |         // out
97 |         if (norm_type == "spatial")
98 |         {
99 |             this.conv_norm_out = new SpatialNorm(block_out_channels[0], temb_channels ?? 512, dtype: this.dtype);
100 |         }
101 |         else
102 |         {
103 |             this.conv_norm_out = GroupNorm(num_channels: block_out_channels[0], num_groups: norm_num_groups, eps: 1e-6f, dtype: this.dtype);
104 |         }
105 | 
106 |         this.conv_act = nn.SiLU();
107 |         this.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernelSize: 3, padding: Padding.Same, dtype: this.dtype);
108 | 
109 |         RegisterComponents();
110 |     }
111 | 
112 |     public override Tensor forward(Tensor sample, Tensor? latent_embeds = null)
113 |     {
114 |         sample = this.conv_in.forward(sample);
115 |         var upscale_dtype = this.up_blocks[0].parameters().First().dtype;
116 | 
117 |         // middle
118 |         var input = new UNetMidBlock2DInput(sample, latent_embeds);
119 |         sample = this.mid_block.forward(input);
120 |         sample = sample.to(upscale_dtype);
121 | 
122 |         // up
123 |         foreach (var up_block in this.up_blocks)
124 |         {
125 |             sample = up_block.forward(sample, latent_embeds);
126 |         }
127 | 
128 |         // post-process: GroupNorm takes only the sample, SpatialNorm also takes the latent embeddings
129 |         if (latent_embeds is null && this.conv_norm_out is Module<Tensor, Tensor> norm)
130 |         {
131 |             sample = norm.forward(sample);
132 |         }
133 |         else if (this.conv_norm_out is Module<Tensor, Tensor, Tensor> norm1)
134 |         {
135 |             sample = norm1.forward(sample, latent_embeds!);
136 |         }
137 |         else
138 |         {
139 |             throw new ArgumentException("Invalid norm type: " + this.conv_norm_out.GetType().Name);
140 |         }
141 | 
142 |         sample = this.conv_act.forward(sample);
143 |         sample = this.conv_out.forward(sample);
144 | 
145 |         return sample;
146 |     }
147 | }
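To make the decoder's geometry concrete, here is a hypothetical shape check against the SD2 defaults (block_out_channels { 128, 256, 512, 512 }; only the final up block skips upsampling, so three x2 upsamples take 96 to 768):

// Sketch only: shapes follow from the config above, not from any repository test.
var decoder = new Decoder(
    in_channels: 4,
    out_channels: 3,
    up_block_types: new[] { "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D" },
    block_out_channels: new[] { 128, 256, 512, 512 });
var latents = torch.randn(1, 4, 96, 96);
var image = decoder.forward(latents);   // expected shape: [1, 3, 768, 768]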
--------------------------------------------------------------------------------
/VAE/DiagonalGaussianDistribution.cs:
--------------------------------------------------------------------------------
1 | using static TorchSharp.torch.nn;
2 | using static TorchSharp.torch;
3 | using TorchSharp.Modules;
4 | using TorchSharp;
5 | 
6 | namespace SD;
7 | 
8 | public class DiagonalGaussianDistribution
9 | {
10 |     private readonly Tensor parameters;
11 |     private readonly bool deterministic;
12 |     private readonly Tensor mean;
13 |     private readonly Tensor logvar;
14 |     private readonly Tensor std;
15 |     private readonly Tensor var;
16 |     public DiagonalGaussianDistribution(
17 |         Tensor parameters,
18 |         bool deterministic = false)
19 |     {
20 |         this.parameters = parameters;
21 |         this.deterministic = deterministic;
22 | 
23 |         var chunks = torch.chunk(parameters, 2, dim: 1); // split moments into mean and logvar along the channel dim, as diffusers does
24 |         this.mean = chunks[0];
25 |         this.logvar = chunks[1];
26 |         this.std = torch.exp(0.5f * this.logvar);
27 |         this.var = torch.exp(this.logvar);
28 | 
29 |         if (deterministic)
30 |         {
31 |             this.std = torch.zeros_like(this.std, device: this.parameters.device, dtype: this.parameters.dtype);
32 |             this.var = torch.zeros_like(this.var, device: this.parameters.device, dtype: this.parameters.dtype);
33 |         }
34 |     }
35 | 
36 |     public Tensor Sample(Generator? generator = null)
37 |     {
38 |         if (deterministic)
39 |         {
40 |             return mean;
41 |         }
42 | 
43 |         // honor a caller-supplied Generator so sampling is reproducible; randn_like has no generator parameter
44 |         var noise = generator is null
45 |             ? torch.randn_like(mean)
46 |             : torch.randn(mean.shape, dtype: mean.dtype, device: mean.device, generator: generator);
47 |         return mean + std * noise;
48 |     }
49 | 
50 |     public Tensor KL(DiagonalGaussianDistribution? other)
51 |     {
52 |         if (this.deterministic)
53 |         {
54 |             return torch.zeros_like(this.mean);
55 |         }
56 | 
57 |         if (other is null)
58 |         {
59 |             return 0.5 * torch.sum(
60 |                 this.var + this.mean * this.mean - 1.0 - this.logvar,
61 |                 dim: [1, 2, 3]
62 |             );
63 |         }
64 | 
65 |         return 0.5 * torch.sum(
66 |             this.var / other.var + (this.mean - other.mean).pow(2) / other.var - 1.0 - this.logvar + other.logvar,
67 |             dim: [1, 2, 3]
68 |         );
69 |     }
70 | 
71 |     public Tensor NLL(Tensor sample, long[]? dims = null)
72 |     {
73 |         dims = dims ?? new long[] { 1, 2, 3 };
74 | 
75 |         if (deterministic)
76 |         {
77 |             return torch.zeros_like(this.mean);
78 |         }
79 | 
80 |         var log2Pi = torch.tensor(Math.Log(2.0 * Math.PI), device: this.parameters.device, dtype: this.parameters.dtype); // log(2*pi), not 2*pi
81 |         var nll = 0.5 * torch.sum(
82 |             (sample - this.mean).pow(2) / this.var + this.logvar + log2Pi,
83 |             dim: dims
84 |         );
85 | 
86 |         return nll;
87 |     }
88 | 
89 |     public Tensor Mode()
90 |     {
91 |         return mean;
92 |     }
93 | }
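For reference, the math the class above implements for a diagonal Gaussian q = N(mu, sigma^2 I), written out in LaTeX so the tensor expressions are easy to check:

x = \mu + \sigma \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)
\mathrm{KL}\big(q \,\|\, \mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum\big(\sigma^2 + \mu^2 - 1 - \log\sigma^2\big)
\mathrm{KL}(q \,\|\, p) = \tfrac{1}{2}\sum\big(\sigma_q^2/\sigma_p^2 + (\mu_q - \mu_p)^2/\sigma_p^2 - 1 - \log\sigma_q^2 + \log\sigma_p^2\big)
\mathrm{NLL}(x) = \tfrac{1}{2}\sum\big(\log 2\pi + \log\sigma^2 + (x - \mu)^2/\sigma^2\big)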
["DownEncoderBlock2D"]; 60 | _layersPerBlock = layersPerBlock; 61 | _normNumGroups = normNumGroups; 62 | _activationFunction = activationFunction; 63 | this.dtype = dtype; 64 | 65 | this.conv_in = torch.nn.Conv2d(this._inChannels, this._blockOutChannels[0], kernelSize: 3, stride: 1, padding: 1, dtype: this.dtype); 66 | this.down_blocks = new ModuleList>(); 67 | 68 | // mid 69 | this.mid_block = new UNetMidBlock2D( 70 | in_channels: _blockOutChannels[^1], 71 | resnet_eps: 1e-6f, 72 | resnet_act_fn: activationFunction, 73 | output_scale_factor: 1, 74 | resnet_time_scale_shift: "default", 75 | attention_head_dim: _blockOutChannels[^1], 76 | resnet_groups: normNumGroups, 77 | add_attention: midBlockAddAttention, 78 | from_deprecated_attn_block: mid_block_from_deprecated_attn_block, 79 | dtype: this.dtype); 80 | 81 | var output_channel = _blockOutChannels[0]; 82 | for (int i = 0; i < _blockOutChannels.Length; i++) 83 | { 84 | var input_channel = output_channel; 85 | output_channel = _blockOutChannels[i]; 86 | var is_final_block = i == _blockOutChannels.Length - 1; 87 | var down_block = new DownEncoderBlock2D( 88 | in_channels: input_channel, 89 | out_channels: output_channel, 90 | add_downsample: !is_final_block, 91 | num_layers: _layersPerBlock, 92 | resnet_act_fun: _activationFunction, 93 | resnet_groups: _normNumGroups, 94 | downsample_padding: 0, 95 | dtype: this.dtype); 96 | 97 | this.down_blocks.Add(down_block); 98 | } 99 | // out 100 | this.conv_norm_out = nn.GroupNorm(num_groups: normNumGroups, num_channels: _blockOutChannels[^1], eps: 1e-6f, dtype: this.dtype); 101 | this.conv_act = nn.SiLU(); 102 | var conv_out_channels = doubleZ ? _outChannels * 2 : _outChannels; 103 | this.conv_out = nn.Conv2d(_blockOutChannels[^1], conv_out_channels, kernelSize: 3, padding: Padding.Same, dtype: this.dtype); 104 | 105 | } 106 | 107 | public override Tensor forward(Tensor sample) 108 | { 109 | sample = this.conv_in.forward(sample); 110 | // down 111 | foreach (var down_block in this.down_blocks) 112 | { 113 | sample = down_block.forward(sample); 114 | } 115 | // mid 116 | var input = new UNetMidBlock2DInput(sample); 117 | sample = this.mid_block.forward(input); 118 | // post-process 119 | sample = this.conv_norm_out.forward(sample); 120 | sample = this.conv_act.forward(sample); 121 | sample = this.conv_out.forward(sample); 122 | 123 | return sample; 124 | } 125 | } -------------------------------------------------------------------------------- /img/a photo of an astronaut riding a horse on mars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LittleLittleCloud/Torchsharp-stable-diffusion-2/385eda34d436741be22061d5c531f85ec94d8a03/img/a photo of an astronaut riding a horse on mars.png --------------------------------------------------------------------------------