├── LICENSE
├── README.md
├── diffusion_from_scratch
│   ├── README.md
│   ├── diffusion_model.py
│   ├── diffusion_results
│   │   ├── img.png
│   │   ├── img_0.png
│   │   ├── img_1.png
│   │   ├── img_2.png
│   │   ├── img_3.png
│   │   ├── img_4.png
│   │   ├── img_5.png
│   │   ├── img_6.png
│   │   ├── img_7.png
│   │   └── img_8.png
│   ├── sample_diffusion.py
│   └── train_diffusion.py
├── stable_diffusion_from_scratch
│   ├── README.md
│   ├── diffusion_model.py
│   ├── sample_stable_diffusion.ipynb
│   ├── stable_diffusion_model.py
│   ├── stable_diffusion_results
│   │   ├── img.png
│   │   ├── img_0.png
│   │   ├── img_1.png
│   │   ├── img_2.png
│   │   └── img_3.png
│   ├── train_stable_diffusion.py
│   ├── vae_model.pth
│   └── vae_model.py
└── vae_from_scratch
    ├── README.md
    ├── pokemon_sample_test.png
    ├── sample_vae.py
    ├── train_vae.py
    ├── vae_model.pth
    ├── vae_model.py
    └── vae_results
        ├── kl_loss.png
        ├── latent_space.png
        ├── mse_loss.png
        ├── reconstruction_0.png
        ├── reconstruction_100.png
        ├── reconstruction_20.png
        ├── reconstruction_40.png
        ├── reconstruction_60.png
        ├── reconstruction_80.png
        ├── sampled.png
        └── train_loss.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 ZHUO ZHANG
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### A collection of teaching code on implementing Diffusion, VAE, DiT and related models
2 | 
3 | __The best way to read this codebase is alongside the videos I recorded; it is much easier to follow that way.__
4 | 
5 | Recordings of the videos can be found on these platforms:
6 | 
7 | - [BiliBili - LLM张老师](https://space.bilibili.com/3546611527453161/channel/collectiondetail?sid=3848231&ctype=0)
8 | - [抖音 - LLM张老师](https://www.douyin.com/user/self?from_tab_name=main&modal_id=7416739026914250024)
9 | - [Youtube - LLMZhang](https://www.youtube.com/playlist?list=PL95p-eWIbW1Eu8SQQ2a9zGd5t2fgX9H3I&playnext=1&index=1)
10 | 
11 | ### Code structure
12 | 
13 | - `/vae_from_scratch`: a from-scratch implementation of a VAE
14 | - `/diffusion_from_scratch`: a from-scratch implementation of DDPM
15 | - `/stable_diffusion_from_scratch`: (work in progress) an implementation of Stable Diffusion...
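16 | 
17 | ### The core idea in one sketch
18 | 
19 | All three folders revolve around the same training objective: add noise to a clean sample in closed form, then train a network to predict that noise. The snippet below is only a minimal, self-contained sketch of that objective; it uses a toy MLP and a linear beta schedule purely for illustration, while the actual code in this repo uses a UNet with attention, CLIP text conditioning, and a cosine schedule.
20 | 
21 | ```python
22 | import torch
23 | import torch.nn.functional as F
24 | 
25 | T = 1000
26 | betas = torch.linspace(1e-4, 0.02, T)               # linear schedule (the repo's NoiseScheduler uses a cosine schedule)
27 | alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of alphas, i.e. alpha_bar_t
28 | 
29 | def add_noise(x0, t):
30 |     """q(x_t given x_0): jump straight to noise level t in closed form."""
31 |     noise = torch.randn_like(x0)
32 |     a_bar = alphas_cumprod[t].view(-1, 1)
33 |     return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise, noise
34 | 
35 | # Toy stand-ins: 8-dimensional "images" and an MLP instead of a UNet.
36 | model = torch.nn.Sequential(torch.nn.Linear(9, 64), torch.nn.SiLU(), torch.nn.Linear(64, 8))
37 | x0 = torch.randn(16, 8)
38 | t = torch.randint(0, T, (16,))
39 | x_t, noise = add_noise(x0, t)
40 | noise_pred = model(torch.cat([x_t, t.float().unsqueeze(1) / T], dim=1))  # condition on (x_t, t)
41 | loss = F.mse_loss(noise_pred, noise)                # same MSE-on-noise objective as train_diffusion.py
42 | loss.backward()
43 | ```
44 | 
45 | Sampling then runs this in reverse: starting from pure noise, the trained network's noise prediction is applied step by step to recover a clean sample (see `sample_diffusion.py` and the CFG sampler in `diffusion_model.py`).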
--------------------------------------------------------------------------------
/diffusion_from_scratch/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # DDPM from scratch
3 | 
4 | This is the code I walk through in the video; it is mainly an implementation of a diffusion model.
5 | 
6 | It uses a small pokemon image dataset of about 800 pictures to demonstrate training from scratch.
7 | 
8 | Since it is trained from scratch without relying on any pretrained model, it shows Diffusion + Spatial Transformer + Attention + UNet (together these modules make up a fairly modern DDPM); the model is meant purely as a teaching example.
9 | 
10 | In the next video I will use VAE + DDPM to build a model architecture close to Stable Diffusion 2.
11 | 
12 | Training locally on a Mac M3 needs about 16 GB of memory and takes roughly 2~4 hours. For better results you would need a larger dataset and more compute.
13 | 
14 | #### Required libraries:
15 | ```
16 | numpy
17 | torch
18 | torchvision
19 | Pillow
20 | datasets
21 | transformers
22 | tqdm
23 | wandb
24 | 
25 | ```
26 | 
27 | #### Training image dataset:
28 | 
29 | Running `train_diffusion.py` downloads a small [pokemon](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh) image dataset from huggingface.
30 | 
31 | Of course, you can also swap in another local image dataset in the code.
32 | 
33 | #### Sample images generated during training epochs:
34 | 
35 | As training progresses over the epochs, the (text-conditioned) generated images become clearer and clearer.
36 | 
37 | > `text condition = "a water type pokemon"`
38 | 
39 | ![img](diffusion_results/img.png)
40 | 
41 | > `text condition = "a dragon character with tail"`
42 | 
43 | ![img](diffusion_results/img_1.png)
44 | 
45 | #### Generation after training:
46 | 
47 | Once the model is trained, run `sample_diffusion.py` to generate some images.
48 | 
49 | > 1. `text condition = "a cartoon pikachu with big eyes and big ears"`
50 | - plain DDPM sampling:
51 | 
52 | ![img](diffusion_results/img_3.png)
53 | 
54 | - Classifier-Free Guidance (CFG) sampling:
55 | 
56 | ![img](diffusion_results/img_4.png)
57 | 
58 | > 2. `text condition = "a red pokemon with a red fire tail"`
59 | 
60 | - plain DDPM sampling:
61 | 
62 | ![img](diffusion_results/img_5.png)
63 | 
64 | - Classifier-Free Guidance (CFG) sampling:
65 | 
66 | ![img](diffusion_results/img_6.png)
67 | 
68 | > 3. `text condition = "a green bird with a red tail and a black nose"`
69 | 
70 | - plain DDPM sampling:
71 | 
72 | ![img](diffusion_results/img_7.png)
73 | 
74 | - Classifier-Free Guidance (CFG) sampling:
75 | 
76 | All of the CFG samples look rather dark because the dataset was processed with data augmentation, which shifted image brightness. The model picked up this brightness shift during training, so the generated images show it as well.
77 | 
78 | ![img](diffusion_results/img_8.png)
79 | 
80 | 
81 | #### About the loss:
82 | 
83 | Mean squared error is used as the loss function.
84 | 
85 | Training images are resized to 64x64, so the loss is computed per pixel.
86 | 
87 | Because the pokemon dataset in this example is fairly small, 300 epochs and 2000 epochs give similar results. To get noticeably better generations, I believe the only real option is to add many more pokemon training images. Once text-embedding conditioning is added, 800 images are nowhere near enough for the generation to generalize (for reference, SD2 was trained on over 100 million images).
88 | 
89 | - Learning rate and training loss:
90 | ![train loss](diffusion_results/img_0.png)
91 | 
92 | 
93 | #### Notes on the results:
94 | 
95 | There is far too much to say about diffusion here; most of it is covered in the video.
96 | 
97 | Because this implementation adds a "text condition", a simple architecture and a small training set cannot reach good generalization.
98 | 
99 | - If you only want a DDPM to reproduce images it has already seen in the training set, that is easy;
100 | 
101 | - likewise, using a DDPM for image classification is relatively easy.
102 | 
103 | But there are already plenty of open-source implementations covering those two cases.
104 | 
105 | Teaching material for text-conditioned generation, however, is scarce. This implementation adds attention and skip connections so that the text information can be combined with pixel-level image structure more effectively.
106 | 
107 | This implementation is only a simple example; I hope it gives you some inspiration.
--------------------------------------------------------------------------------
/diffusion_from_scratch/diffusion_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from functorch.einops import rearrange
6 | 
7 | 
8 | class Attention(nn.Module):
9 |     def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=4):
10 |         super(Attention, self).__init__()
11 |         self.self_attn = context_dim is None
12 |         self.hidden_dim = hidden_dim
13 |         self.context_dim = context_dim if context_dim is not None else hidden_dim
14 |         self.embed_dim = embed_dim
15 |         self.num_heads = num_heads
16 |         self.head_dim = embed_dim // num_heads
17 | 
18 |         self.query = nn.Linear(hidden_dim, embed_dim,
bias=False) 19 | self.key = nn.Linear(self.context_dim, embed_dim, bias=False) 20 | self.value = nn.Linear(self.context_dim, embed_dim, bias=False) 21 | self.out_proj = nn.Linear(embed_dim, hidden_dim) 22 | 23 | def forward(self, tokens, t=None, context=None): 24 | B, T, C = tokens.shape 25 | H = self.num_heads 26 | 27 | Q = self.query(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 28 | 29 | if self.self_attn: 30 | K = self.key(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 31 | V = self.value(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 32 | else: 33 | _, context_len, context_C = context.shape 34 | if context_C != self.context_dim: 35 | context = nn.Linear(context_C, self.context_dim).to(context.device)(context) 36 | context_C = self.context_dim 37 | 38 | K = self.key(context).view(B, context_len, H, self.head_dim).transpose(1, 2) 39 | V = self.value(context).view(B, context_len, H, self.head_dim).transpose(1, 2) 40 | 41 | # 计算注意力分数 42 | attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim) 43 | attn_probs = F.softmax(attn_scores, dim=-1) 44 | 45 | out = torch.matmul(attn_probs, V) 46 | out = out.transpose(1, 2).contiguous().view(B, T, self.embed_dim) 47 | out = self.out_proj(out) 48 | 49 | return out 50 | 51 | 52 | class TransformerBlock(nn.Module): 53 | def __init__(self, hidden_dim, context_dim, num_heads, self_attn=False, cross_attn=False): 54 | """ 55 | Build a transformer block 56 | :param hidden_dim: 图像的隐藏维度(通道数) 57 | :param context_dim: 文本的隐藏维度 58 | :param num_heads: Attention中多头的数量 59 | :param self_attn: 是否使用自注意力 60 | :param cross_attn: 是否使用交叉注意力 61 | """ 62 | super(TransformerBlock, self).__init__() 63 | self.self_attn = self_attn 64 | self.cross_attn = cross_attn 65 | 66 | # Self-attention 自注意力 67 | self.attn_self = Attention(hidden_dim, hidden_dim, num_heads=num_heads) if self_attn else None 68 | 69 | # Cross-attention 交叉注意力 70 | self.attn_cross = Attention(hidden_dim, hidden_dim, context_dim=context_dim, num_heads=num_heads) if cross_attn else None 71 | 72 | self.norm1 = nn.LayerNorm(hidden_dim) 73 | self.norm2 = nn.LayerNorm(hidden_dim) 74 | self.norm3 = nn.LayerNorm(hidden_dim) 75 | self.norm4 = nn.LayerNorm(hidden_dim) 76 | 77 | self.ffn1 = nn.Sequential( 78 | nn.Linear(hidden_dim, 4 * hidden_dim), 79 | nn.GELU(), 80 | nn.Linear(4 * hidden_dim, hidden_dim), 81 | nn.Dropout(0.1) 82 | ) 83 | self.ffn2 = nn.Sequential( 84 | nn.Linear(hidden_dim, 4 * hidden_dim), 85 | nn.GELU(), 86 | nn.Linear(4 * hidden_dim, hidden_dim), 87 | nn.Dropout(0.1) 88 | ) 89 | 90 | def forward(self, x, t=None, context=None): 91 | if self.self_attn: 92 | x = self.attn_self(self.norm1(x)) + x 93 | x = self.ffn1(self.norm2(x)) + x 94 | 95 | if self.cross_attn: 96 | x = self.attn_cross(self.norm3(x), context=context) + x 97 | x = self.ffn2(self.norm4(x)) + x 98 | 99 | return x 100 | 101 | 102 | class SpatialTransformer(nn.Module): 103 | def __init__(self, hidden_dim, context_dim=512, num_heads=4, self_attn=False, cross_attn=False): 104 | super(SpatialTransformer, self).__init__() 105 | self.transformer = TransformerBlock(hidden_dim, context_dim, num_heads, self_attn, cross_attn) 106 | self.context_proj = nn.Linear(context_dim, hidden_dim) if context_dim != hidden_dim else nn.Identity() 107 | self.self_attn = self_attn 108 | self.cross_attn = cross_attn 109 | 110 | def forward(self, x, t=None, context=None): 111 | b, c, h, w = x.shape 112 | x_res = x # 用作残差连接 113 | x = rearrange(x, "b c h w -> b (h w) c") 114 | 115 | if context is not None: 116 | context = 
self.context_proj(context) 117 | 118 | x = self.transformer(x, t, context) 119 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 120 | return x + x_res 121 | 122 | 123 | class ResnetBlock(nn.Module): 124 | """ 125 | 抄自 Stable Diffusion 1.x. 126 | 源代码中有两个版本的实现: 127 | 1) https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L82 128 | 2) https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L163 129 | """ 130 | def __init__(self, in_channels, out_channels, time_dim): 131 | super().__init__() 132 | self.norm1 = nn.GroupNorm(4, in_channels, eps=1e-6) # SD1.x uses eps=1e-6 133 | self.norm2 = nn.GroupNorm(4, out_channels, eps=1e-6) 134 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 135 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 136 | self.activation = nn.SiLU() 137 | self.residual_conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0) if in_channels != out_channels else nn.Identity() 138 | self.dropout = nn.Dropout(0.1) 139 | self.time_proj = torch.nn.Linear(time_dim, out_channels) 140 | 141 | def forward(self, x, t): 142 | residual = self.residual_conv(x) 143 | 144 | x = self.conv1(self.activation(self.norm1(x))) 145 | x = x + self.time_proj(self.activation(t))[:, :, None, None] # 添加时间嵌入 146 | x = self.dropout(x) 147 | x = self.conv2(self.activation(self.norm2(x))) 148 | 149 | return x + residual 150 | 151 | 152 | class DownBlock(nn.Module): 153 | def __init__(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=512): 154 | super().__init__() 155 | self.resnet1 = ResnetBlock(in_channels, out_channels, time_dim) 156 | self.transformer1 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, self_attn=True) if self_attn else None 157 | self.resnet2 = ResnetBlock(out_channels, out_channels, time_dim) 158 | self.transformer2 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, cross_attn=True) if cross_attn else None 159 | self.downsample = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1) # 下采样 160 | 161 | def forward(self, x, t, y): 162 | x = self.resnet1(x, t) 163 | if self.transformer1: 164 | x = self.transformer1(x, t, y) 165 | x = self.resnet2(x, t) 166 | if self.transformer2: 167 | x = self.transformer2(x, t, y) 168 | x = self.downsample(x) 169 | return x 170 | 171 | 172 | class MiddleBlock(nn.Module): 173 | def __init__(self, channels, time_dim, context_dim): 174 | super().__init__() 175 | self.resnet1 = ResnetBlock(channels, channels, time_dim) 176 | self.attn1 = SpatialTransformer(channels, context_dim, num_heads=channels//64, self_attn=True, cross_attn=True) # 256/64=4 177 | self.resnet2 = ResnetBlock(channels, channels, time_dim) 178 | # 可选:添加第二个注意力层和resnet块 179 | # self.attn2 = SpatialTransformer(channels, context_dim, num_heads=channels//64, self_attn=True, cross_attn=True) # 256/64=4 180 | # self.resnet3 = ResnetBlock(channels, channels, time_dim) 181 | 182 | def forward(self, x, t, context): 183 | x = self.resnet1(x, t) 184 | x = self.attn1(x, t, context) 185 | x = self.resnet2(x, t) 186 | # x = self.attn2(x, context) 187 | # x = self.resnet3(x, t) 188 | return x 189 | 190 | 191 | class UpBlock(nn.Module): 192 | def __init__(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=512): 193 | super().__init__() 194 | # nn.Upsample(scale_factor=2, mode='nearest'), 195 
| # nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1), # or, kernel_size=5, padding=2 196 | self.upsample = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1) # 上采样 197 | self.resnet1 = ResnetBlock(out_channels, out_channels, time_dim) 198 | self.transformer1 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, self_attn=True) if self_attn else None 199 | self.resnet2 = ResnetBlock(out_channels, out_channels, time_dim) 200 | self.transformer2 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, cross_attn=True) if cross_attn else None 201 | self.resnet3 = ResnetBlock(out_channels, out_channels, time_dim) 202 | 203 | def forward(self, x, t, y): 204 | x = self.upsample(x) 205 | x = self.resnet1(x, t) 206 | if self.transformer1: 207 | x = self.transformer1(x, t, y) 208 | x = self.resnet2(x, t) 209 | if self.transformer2: 210 | x = self.transformer2(x, t, y) 211 | x = self.resnet3(x, t) 212 | return x 213 | 214 | 215 | class UNet_Transformer(nn.Module): 216 | def __init__(self, in_channels=3, time_dim=256, context_dim=512): 217 | super().__init__() 218 | 219 | self.time_dim = time_dim 220 | self.time_mlp = nn.Sequential( 221 | nn.Linear(time_dim, time_dim * 4), 222 | nn.SiLU(), 223 | nn.Linear(time_dim * 4, time_dim) 224 | ) 225 | self.context_dim = context_dim 226 | 227 | # 初始卷积 228 | self.init_conv = nn.Sequential( 229 | nn.Conv2d(in_channels, 64, 3, padding=1), 230 | nn.SiLU(), 231 | nn.Conv2d(64, 64, 3, padding=1) 232 | ) # 64 x H x W 233 | 234 | # 下采样 235 | self.down1 = self._down_block(64, 128, time_dim) # 128 x H/2 x W/2 236 | self.down2 = self._down_block(128, 256, time_dim, self_attn=True, cross_attn=False, num_heads=4, context_dim=context_dim) # 256 x H/4 x W/4 237 | self.down3 = self._down_block(256, 512, time_dim, self_attn=True, cross_attn=False, num_heads=8, context_dim=context_dim) # 512 x H/8 x W/8 238 | 239 | # 中间块 240 | self.middle_block = MiddleBlock(512, time_dim, context_dim) # 512 x H/8 x W/8 241 | 242 | # 上采样 243 | self.up1 = self._up_conv(512, 256, time_dim, self_attn=True, cross_attn=True, num_heads=8, context_dim=context_dim) # 256 x H/4 x W/4 244 | self.up2 = self._up_conv(256+256, 128, time_dim, self_attn=True, cross_attn=True, num_heads=4, context_dim=context_dim) # 128 x H/2 x W/2 245 | self.up3 = self._up_conv(128+128, 64, time_dim) # 64 x H x W 246 | 247 | # 最终卷积 248 | self.final_conv = nn.Sequential( 249 | ResnetBlock(64 * 2, 64, time_dim), 250 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 251 | nn.SiLU(), 252 | nn.Conv2d(64, in_channels, 3, stride=1, padding=1), 253 | ) 254 | 255 | def get_sinusoidal_position_embedding(self, timesteps, embedding_dim): 256 | half_dim = embedding_dim // 2 257 | emb = math.log(10000) / (half_dim - 1) 258 | emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb) 259 | emb = timesteps[:, None] * emb[None, :] 260 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 261 | if embedding_dim % 2 == 1: # zero pad if embedding_dim is odd 262 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 263 | return emb 264 | 265 | def _down_block(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=None): 266 | return DownBlock(in_channels, out_channels, time_dim, self_attn, cross_attn, num_heads, context_dim or self.context_dim) 267 | 268 | def _up_conv(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=None): 269 | return UpBlock(in_channels, 
out_channels, time_dim, self_attn, cross_attn, num_heads, context_dim or self.context_dim) 270 | 271 | def forward(self, x, t, y): 272 | # x: [batch, 3, H, W] 273 | # t: [batch, ] time embedding 274 | # y: [batch, 512] text embedding 275 | initial_x = x 276 | # Ensure y has the correct shape 277 | if y.dim() == 2: 278 | y = y.unsqueeze(1) # [batch, 1, context_dim] 279 | 280 | t = self.get_sinusoidal_position_embedding(t, self.time_dim) # [batch, 256] 281 | t = self.time_mlp(t) 282 | 283 | x1 = self.init_conv(x) 284 | 285 | x2 = self.down1(x1, t, y) 286 | x3 = self.down2(x2, t, y) 287 | x4 = self.down3(x3, t, y) 288 | 289 | x4 = self.middle_block(x4, t, y) 290 | 291 | x = self.up1(x4, t, y) 292 | x = torch.cat([x, x3], dim=1) # skip connection 跳跃连接 293 | x = self.up2(x, t, y) 294 | x = torch.cat([x, x2], dim=1) # skip connection 跳跃连接 295 | x = self.up3(x, t, y) 296 | x = torch.cat([x, x1], dim=1) # skip connection 跳跃连接 297 | 298 | x = self.final_conv[0](x, t) 299 | for layer in self.final_conv[1:]: 300 | x = layer(x) 301 | 302 | return x + initial_x # 全局残差连接 303 | 304 | 305 | """添加噪声过程:从原始图像开始,逐渐增加噪声,直到最终的噪声图像""" 306 | class NoiseScheduler: 307 | def __init__(self, num_timesteps, device): 308 | self.device = device 309 | self.num_timesteps = num_timesteps 310 | self.betas = self.cosine_beta_schedule(num_timesteps).to(device) # 这里我们使用余弦噪声调度。DDPM原始论文中使用的是线性调度。 311 | self.alphas = (1. - self.betas).to(device) 312 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(device) 313 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(device) 314 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod).to(device) 315 | 316 | def cosine_beta_schedule(self, timesteps, s=0.008): 317 | steps = timesteps + 1 318 | x = torch.linspace(0, timesteps, steps) 319 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 320 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 321 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 322 | return torch.clip(betas, 0.0001, 0.9999) 323 | 324 | def add_noise(self, x_start, t): 325 | """ 326 | 添加噪声到输入图像或潜在表示。 327 | :param x_start: 初始清晰图像或潜在表示 328 | :param t: 当前时间步 329 | :return: 添加噪声后的表示 330 | """ 331 | t = t.clone().detach().long().to(self.sqrt_alphas_cumprod.device) 332 | # 生成标准正态分布的噪声 333 | noise = torch.randn_like(x_start) 334 | # 获取所需的预计算值 335 | sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) 336 | sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) 337 | # 计算第t步、带噪声的图像 338 | x_t = sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 339 | 340 | return x_t, noise 341 | 342 | 343 | 344 | """CFG采样器(去噪)Classifier Guided Diffusion""" 345 | @torch.no_grad() 346 | def sample_cfg(model, noise_scheduler, n_samples, in_channels, text_embeddings, image_size=64, guidance_scale=3.0): 347 | """ 348 | 从噪声开始,逐渐减小噪声,直到最终的图像。 349 | :param model: UNet模型 350 | :param noise_scheduler: 噪声调度器 351 | :param n_samples: 生成的样本数量 352 | :param in_channels: 输入图像的通道数 353 | :param text_embeddings: 文本嵌入 354 | :param image_size: 图像的大小 355 | :param guidance_scale: 用于加权噪声预测的比例 356 | :return: 生成的图像 357 | """ 358 | model.eval() 359 | device = next(model.parameters()).device 360 | 361 | x = torch.randn(n_samples, in_channels, image_size, image_size).to(device) # 随机初始化噪声图像 362 | null_embeddings = torch.zeros_like(text_embeddings) # 用于无条件生成 363 | 364 | # 逐步去噪 365 | for t in reversed(range(noise_scheduler.num_timesteps)): 366 | t_batch = 
torch.full((n_samples,), t, device=device, dtype=torch.long) 367 | 368 | noise_pred_uncond = model(x, t_batch, y=null_embeddings) # 生成一个无条件的噪声预测 369 | noise_pred_cond = model(x, t_batch, y=text_embeddings) # 生成一个有条件的噪声预测 370 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # CFG:结果加权后的噪声预测 371 | 372 | # 采样器的去噪过程 373 | alpha_t = noise_scheduler.alphas[t] 374 | alpha_t_bar = noise_scheduler.alphas_cumprod[t] 375 | beta_t = noise_scheduler.betas[t] 376 | 377 | if t > 0: 378 | noise = torch.randn_like(x) 379 | else: 380 | noise = torch.zeros_like(x) 381 | 382 | # 去噪公式 383 | x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * noise_pred) + torch.sqrt(beta_t) * noise 384 | 385 | model.train() 386 | return x 387 | 388 | 389 | """普通采样器(去噪)""" 390 | @torch.no_grad() 391 | def sample(model, x_t, noise_scheduler, t, text_embeddings): 392 | """ 393 | 从噪声开始,逐渐减小噪声,直到最终的图像。 394 | 395 | 参数: 396 | - model: UNet模型用于预测噪声。 397 | - x_t: 当前时间步的噪声化表示(torch.Tensor)。 398 | - noise_scheduler: 噪声调度器,包含betas和其他预计算值。 399 | - t: 当前时间步(torch.Tensor)。 400 | - text_embeddings: 文本嵌入,用于条件生成(torch.Tensor)。 401 | 402 | 返回: 403 | - x: 去噪后的表示。 404 | """ 405 | t = t.to(x_t.device) 406 | 407 | # 获取当前时间步的beta和alpha值 408 | beta_t = noise_scheduler.betas[t] 409 | alpha_t = noise_scheduler.alphas[t] 410 | alpha_t_bar = noise_scheduler.alphas_cumprod[t] 411 | 412 | # 预测当前时间步的噪声 413 | predicted_noise = model(x_t, t, text_embeddings) 414 | 415 | # 计算去噪后的表示 416 | if t > 0: 417 | noise = torch.randn_like(x_t).to(x_t.device) 418 | else: 419 | noise = torch.zeros_like(x_t).to(x_t.device) 420 | 421 | # 去噪公式 422 | x = (1 / torch.sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * predicted_noise) + torch.sqrt(beta_t) * noise 423 | 424 | return x 425 | 426 | 427 | 428 | -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_0.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_1.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_2.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_3.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_4.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_5.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_6.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_7.png -------------------------------------------------------------------------------- /diffusion_from_scratch/diffusion_results/img_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/diffusion_from_scratch/diffusion_results/img_8.png -------------------------------------------------------------------------------- /diffusion_from_scratch/sample_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusion_model import UNet_Transformer, NoiseScheduler, sample_cfg, sample 3 | from transformers import CLIPTokenizer, CLIPTextModel 4 | from PIL import Image 5 | import numpy as np 6 | 7 | # 超参数 8 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 9 | image_size = 64 10 | in_channels = 3 11 | num_timesteps = 1000 12 | 13 | """从保存点加载""" 14 | checkpoint = torch.load('diffusion_model_checkpoint_epoch_500.pth', map_location=device, weights_only=True) 15 | diffusion_model = UNet_Transformer(in_channels=in_channels).to(device) 16 | diffusion_model.load_state_dict(checkpoint['model_state_dict']) 17 | diffusion_model.eval() 18 | 19 | """从最终模型加载""" 20 | # diffusion_model = UNet_Transformer(in_channels=in_channels).to(device) 21 | # diffusion_model.load_state_dict(torch.load('diffusion_model_final.pth')) 22 | # diffusion_model.eval() 23 | 24 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") 25 | text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device) 26 | noise_scheduler = NoiseScheduler(num_timesteps=num_timesteps, device=device) 27 | 28 | # 文本条件 29 | # condition = "a red pokemon with a red fire tail" 30 | # condition = "a blue rabbit with a yellow belly" 31 | # condition = "a 
cartoon pikachu with big eyes and big ears" 32 | condition = "a green bird with a red tail and a black nose" 33 | 34 | 35 | # 文本编码 36 | text_input = tokenizer([condition], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 37 | text_embeddings = text_encoder(text_input.input_ids.to(device)).last_hidden_state 38 | 39 | # DDPM采样(CFG) 40 | sampled_images = sample_cfg(diffusion_model, noise_scheduler, n_samples=1, in_channels=in_channels, text_embeddings=text_embeddings, image_size=image_size, guidance_scale=1.0) 41 | 42 | # 保存生成的图片 43 | img = sampled_images[0] * 0.5 + 0.5 # 缩放到 [0, 1] 44 | img = img.detach().cpu().permute(1, 2, 0).numpy() # [C, H, W] -> [H, W, C] 调整顺序以适应 PIL 画图 45 | img = (img * 255).astype(np.uint8) 46 | img_pil = Image.fromarray(img) 47 | img_pil.save('generated_image_pokemon_cfg.png') 48 | 49 | 50 | # DDPM采样(普通) 51 | x_t = torch.randn(1, in_channels, image_size, image_size).to(device) 52 | 53 | for t in reversed(range(num_timesteps)): 54 | t_tensor = torch.full((1,), t, device=device, dtype=torch.long) 55 | x_t = sample(diffusion_model, x_t, noise_scheduler, t_tensor, text_embeddings) 56 | 57 | img = x_t[0] * 0.5 + 0.5 # Rescale to [0, 1] 58 | img = img.detach().cpu().permute(1, 2, 0).numpy() # [C, H, W] -> [H, W, C] 59 | img = (img * 255).astype(np.uint8) 60 | img_pil = Image.fromarray(img) 61 | img_pil.save('generated_image_pokemon.png') -------------------------------------------------------------------------------- /diffusion_from_scratch/train_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import AdamW 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR 5 | from torchvision import transforms 6 | from datasets import load_dataset 7 | from diffusion_model import UNet_Transformer, NoiseScheduler, sample_cfg 8 | from transformers import CLIPTokenizer, CLIPTextModel 9 | from PIL import Image 10 | import numpy as np 11 | from tqdm.auto import tqdm 12 | import os 13 | import wandb 14 | 15 | # 超参数 16 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 17 | image_size = 64 18 | in_channels = 3 19 | n_epochs = 1000 20 | batch_size = 32 21 | lr = 1e-4 22 | num_timesteps = 1000 23 | save_checkpoint_interval = 100 24 | 25 | # WandB 初始化 26 | run = wandb.init( 27 | project="diffusion_from_scratch", 28 | config={ 29 | "batch_size": batch_size, 30 | "learning_rate": lr, 31 | "epochs": n_epochs, 32 | "image_size": image_size, 33 | "in_channels": in_channels, 34 | "num_timesteps": num_timesteps, 35 | }, 36 | ) 37 | 38 | # 初始化模型和噪声调度器 39 | diffusion_model = UNet_Transformer(in_channels=in_channels).to(device) 40 | noise_scheduler = NoiseScheduler(num_timesteps=num_timesteps, device=device) 41 | 42 | # 加载 CLIP 模型 43 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") 44 | text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device) 45 | 46 | # WandB 监控 47 | wandb.watch(diffusion_model, log_freq=100) 48 | 49 | # 加载数据集 50 | dataset = load_dataset("svjack/pokemon-blip-captions-en-zh", split="train") 51 | 52 | # 数据预处理 53 | preprocess = transforms.Compose( 54 | [ 55 | transforms.Resize((image_size, image_size)), 56 | # transforms.RandomHorizontalFlip(), 57 | # transforms.RandomRotation(10), 58 | # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), 
59 | transforms.ToTensor(), 60 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 61 | ] 62 | ) 63 | 64 | def transform(examples): 65 | images = [preprocess(image.convert("RGB")) for image in examples["image"]] 66 | return {"images": images, "text": examples["en_text"]} 67 | 68 | dataset.set_transform(transform) 69 | 70 | train_dataset = dataset.select(range(0, 600)) # 选择前 600 个样本作为训练集 71 | val_dataset = dataset.select(range(600, 800)) # 选择接下来的 200 个样本作为验证集 72 | 73 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) 74 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True) 75 | 76 | # 优化器和学习率调度器 77 | optimizer = AdamW(diffusion_model.parameters(), lr=lr, weight_decay=1e-4) # 可以考虑加入L2正则化:weight_decay=1e-4 78 | # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=5e-5) 79 | scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs) # 余弦退火学习率调度器 80 | # scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=n_epochs, steps_per_epoch=len(train_dataloader)) # OneCycleLR 学习率调度器 81 | 82 | # 创建保存生成测试图像的目录 83 | os.makedirs('diffusion_results', exist_ok=True) 84 | 85 | # 训练循环 86 | for epoch in range(n_epochs): 87 | diffusion_model.train() 88 | progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{n_epochs}") 89 | epoch_loss = 0.0 90 | 91 | # 训练模型 92 | for batch in train_dataloader: 93 | images = batch["images"].to(device) 94 | text = batch["text"] 95 | 96 | # 使用 CLIP 模型编码文本 97 | text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 98 | text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state 99 | 100 | timesteps = torch.randint(0, num_timesteps, (images.shape[0],), device=device).long() # 随机选择 timesteps 101 | noisy_images, noise = noise_scheduler.add_noise(images, timesteps) # 添加噪声 102 | noise_pred = diffusion_model(noisy_images, timesteps, text_embeddings) # 预测噪声 103 | loss = torch.nn.functional.mse_loss(noise_pred, noise) # 预测的噪声与真实噪声的均方误差 104 | 105 | optimizer.zero_grad() 106 | loss.backward() 107 | # torch.nn.utils.clip_grad_norm_(diffusion_model.parameters(), 1.0) # 梯度裁剪 108 | optimizer.step() 109 | # scheduler.step() # OneCycleLR 在每个批次后调用 110 | 111 | epoch_loss += loss.item() 112 | progress_bar.update(1) 113 | progress_bar.set_postfix({"loss": loss.item()}) 114 | 115 | # 验证集上评估模型 116 | diffusion_model.eval() 117 | val_loss = 0 118 | with torch.no_grad(): 119 | for batch in val_dataloader: 120 | images = batch["images"].to(device) 121 | text = batch["text"] 122 | text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 123 | text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state 124 | 125 | timesteps = torch.randint(0, num_timesteps, (images.shape[0],), device=device).long() 126 | noisy_images, noise = noise_scheduler.add_noise(images, timesteps) 127 | noise_pred = diffusion_model(noisy_images, timesteps, text_embeddings) 128 | loss = torch.nn.functional.mse_loss(noise_pred, noise) 129 | val_loss += loss.item() 130 | 131 | scheduler.step() # 除了 OneCycleLR 之外,其他调度器都需要在每个 epoch 结束时调用 132 | 133 | wandb.log({ 134 | "epoch": epoch, 135 | "train_loss": epoch_loss / len(train_dataloader), 136 | "val_loss": val_loss / len(val_dataloader), 137 | "learning_rate": scheduler.get_last_lr()[0] 138 | }) 139 | 140 | # 保存模型检查点 141 | if (epoch + 1) % 
save_checkpoint_interval == 0: 142 | torch.save({ 143 | 'epoch': epoch, 144 | 'model_state_dict': diffusion_model.state_dict(), 145 | 'optimizer_state_dict': optimizer.state_dict(), 146 | 'scheduler_state_dict': scheduler.state_dict(), 147 | 'train_loss': epoch_loss, 148 | 'val_loss': val_loss, 149 | }, f'diffusion_results/diffusion_model_checkpoint_epoch_{epoch+1}.pth') 150 | 151 | # 生成测试图像 152 | if (epoch + 1) % save_checkpoint_interval == 0: 153 | diffusion_model.eval() 154 | with torch.no_grad(): 155 | sample_text = ["a water type pokemon", "a red pokemon with a red fire tail"] 156 | text_input = tokenizer(sample_text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 157 | text_embeddings = text_encoder(text_input.input_ids.to(device)).last_hidden_state 158 | sampled_images = sample_cfg(diffusion_model, noise_scheduler, len(sample_text), in_channels, text_embeddings, image_size=image_size, guidance_scale=3.0) 159 | # 保存生成的图像 160 | for i, img in enumerate(sampled_images): 161 | img = img * 0.5 + 0.5 # Rescale to [0, 1] 162 | img = img.detach().cpu().permute(1, 2, 0).numpy() 163 | img = (img * 255).astype(np.uint8) 164 | img_pil = Image.fromarray(img) 165 | img_pil.save(f'diffusion_results/generated_image_epoch_{epoch+1}_sample_{i}.png') 166 | 167 | wandb.log({f"generated_image_{i}": wandb.Image(sampled_images[i]) for i in range(len(sample_text))}) 168 | 169 | torch.save(diffusion_model.state_dict(), 'diffusion_model_final.pth') 170 | wandb.finish() -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/README.md: -------------------------------------------------------------------------------- 1 | 2 | # 从零开始实现 Stable Diffusion 模型 3 | 4 | > 与 `vae_from_scratch`和`diffusion_from_scratch` 两文件夹下的训练不同,这次的`stable_diffusion_from_scratch`是在 A100 GPU上训练,大概需要30G内存。 5 | 6 | 这个目录是我在视频中讲解的从零复现 Stable Diffusion 模型结构的代码。使用了一个800张照片的 [pokemon](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh) 的小规模图片数据集。其中600张用作训练(通过数据增强)、200张用作验证。 7 | 8 | 由于是从零训练并没有依赖任何预训练的模型,数据集也较小,效果不用在乎,只需关注实现原理并作为教学示例。 9 | 10 | 教学视频是循序渐进的,这个文件夹下的代码内容已经整合包含了 `vae_from_scratch/` 和 `ddpm_from_scratch/` 两部分文件夹下的内容,实现了一个完整的SD2.x模型架构: 11 | 12 | - 通过VAE将原始图片(3通道彩色 x 512高 x 512宽)压缩成潜在空间表示 (Latent dimension:4 x 64 x 64); 13 | - 将潜在空间表示( 4 x 64 x 64 )传递给 DDPM 架构进行噪声预测的训练; 14 | - 训练完成的 DDPM 模型可生成带有文字条件的潜在空间表示; 15 | - 通过VAE将潜在空间表示解码恢复成像素空间图片。 16 | 17 | 训练大概需要2~4小时。如果要达到更好的效果,则需使用更大的数据集和算力。 18 | 19 | #### 需要安装的库: 20 | ``` 21 | numpy 22 | torch 23 | torchvision 24 | Pillow 25 | datasets 26 | transformers 27 | PIL 28 | tdqm 29 | datasets 30 | ``` 31 | 32 | #### 训练图片数据集: 33 | 34 | 运行`train_stable_diffusion.py`会从huggingface上下载一个[pokemon](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh)的小规模图片数据集。 35 | 36 | 当然,你也可以在代码中替换成本地的其他图片数据集。 37 | 38 | #### 训练Epoch过程中的样本图片生成: 39 | 40 | 随着训练epoch过程,通过潜在空间生成(带文字条件)的512x512图片会越来越清晰。 41 | 42 | > `文本条件 = "a dragon character with tail"` 43 | > 44 | > `epoch 1, 30, 60, 90,` 生成的图片如下(可以看到学习的过程): 45 | 46 | ![img](stable_diffusion_results/img_1.png) 47 | 48 | #### 训练完成后模型生成: 49 | 50 | 模型训练完成后,运行 `sample_stable_diffusion.ipynb` 可以生成一些图片。 51 | 52 | > 1. `文本条件` 53 | > - "a water type pokemon with a big ear", 54 | > - "a yellow bird", 55 | > - "a green sheep pokemon", 56 | > - "" (无文本条件) 57 | 58 | ![img](stable_diffusion_results/img_2.png) 59 | 60 | > 2. 
`文本条件` 61 | > - "一只大耳朵的小鸟", 62 | > - "a blue cartoon ball with a smile on it's face", 63 | > - "a yellow bird", 64 | > - "a fish with a horn on it's head" 65 | 66 | ![img](stable_diffusion_results/img_3.png) 67 | 68 | 69 | #### 关于损失值: 70 | 71 | 计算潜在空间的均方误差损失、多样性损失、一致性损失来协调总损失。 72 | 73 | 训练图片首先被我们之前视频里的预训练VAE模型(`var_from_scratch`文件夹下)压缩放成了4x64x64的潜在空间表示。再在潜在空间训练噪声预测。并通过VAE重构成3x512x512像素空间的图片。 74 | 75 | 由于这个例子中的pokemon数据集相对较小,模型的泛化能力较差。要想达到更完美的生成效果,个人能力认为智能增加更多pokemon训练集图片数量。因为加入文本嵌入条件之后做到生成的泛化,800张图片是远远不够的(参考SD2的训练集数量是1亿张以上)。 76 | 77 | - Learning rate 及 训练损失: 78 | 79 | - ![train loss](stable_diffusion_results/img_0.png) 80 | 81 | 82 | #### 效果说明: 83 | 84 | 由于这个实现是增加了“文本条件”的,所以简单的模型架构和少量的训练集无法达到很好的泛化效果。 85 | 86 | 如果要复现Stable Diffusion 2.x的效果,必然需要更大的数据集和更多的训练时间。这个实现只是一个简单的示例,希望能给大家一些启发。 87 | 88 | 89 | #### 待改进: 90 | 91 | 其实还可以提高效果的地方有很多,比如: 92 | 93 | - 学习率调整策略(我的学习率随着整体epoch数先上升后下降的,但由于我前期看到已经产生了效果就提前停止了); 94 | - 损失函数的权重调整; 95 | - 更多的数据增强策略; -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functorch.einops import rearrange 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=4): 10 | super(Attention, self).__init__() 11 | self.self_attn = context_dim is None 12 | self.hidden_dim = hidden_dim 13 | self.context_dim = context_dim if context_dim is not None else hidden_dim 14 | self.embed_dim = embed_dim 15 | self.num_heads = num_heads 16 | self.head_dim = embed_dim // num_heads 17 | 18 | self.query = nn.Linear(hidden_dim, embed_dim, bias=False) 19 | self.key = nn.Linear(self.context_dim, embed_dim, bias=False) 20 | self.value = nn.Linear(self.context_dim, embed_dim, bias=False) 21 | self.out_proj = nn.Linear(embed_dim, hidden_dim) 22 | 23 | def forward(self, tokens, t=None, context=None): 24 | B, T, C = tokens.shape 25 | H = self.num_heads 26 | 27 | Q = self.query(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 28 | 29 | if self.self_attn: 30 | K = self.key(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 31 | V = self.value(tokens).view(B, T, H, self.head_dim).transpose(1, 2) 32 | else: 33 | _, context_len, context_C = context.shape 34 | if context_C != self.context_dim: 35 | context = nn.Linear(context_C, self.context_dim).to(context.device)(context) 36 | context_C = self.context_dim 37 | 38 | K = self.key(context).view(B, context_len, H, self.head_dim).transpose(1, 2) 39 | V = self.value(context).view(B, context_len, H, self.head_dim).transpose(1, 2) 40 | 41 | # 计算注意力分数 42 | attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim) 43 | attn_probs = F.softmax(attn_scores, dim=-1) 44 | 45 | out = torch.matmul(attn_probs, V) 46 | out = out.transpose(1, 2).contiguous().view(B, T, self.embed_dim) 47 | out = self.out_proj(out) 48 | 49 | return out 50 | 51 | 52 | class TransformerBlock(nn.Module): 53 | def __init__(self, hidden_dim, context_dim, num_heads, self_attn=False, cross_attn=False): 54 | """ 55 | Build a transformer block 56 | :param hidden_dim: 图像的隐藏维度(通道数) 57 | :param context_dim: 文本的隐藏维度 58 | :param num_heads: Attention中多头的数量 59 | :param self_attn: 是否使用自注意力 60 | :param cross_attn: 是否使用交叉注意力 61 | """ 62 | super(TransformerBlock, self).__init__() 63 | self.self_attn = self_attn 64 | self.cross_attn = cross_attn 65 | 66 | # Self-attention 
自注意力 67 | self.attn_self = Attention(hidden_dim, hidden_dim, num_heads=num_heads) if self_attn else None 68 | 69 | # Cross-attention 交叉注意力 70 | self.attn_cross = Attention(hidden_dim, hidden_dim, context_dim=context_dim, num_heads=num_heads) if cross_attn else None 71 | 72 | self.norm1 = nn.LayerNorm(hidden_dim) 73 | self.norm2 = nn.LayerNorm(hidden_dim) 74 | self.norm3 = nn.LayerNorm(hidden_dim) 75 | self.norm4 = nn.LayerNorm(hidden_dim) 76 | 77 | self.ffn1 = nn.Sequential( 78 | nn.Linear(hidden_dim, 4 * hidden_dim), 79 | nn.GELU(), 80 | nn.Linear(4 * hidden_dim, hidden_dim), 81 | nn.Dropout(0.1) 82 | ) 83 | self.ffn2 = nn.Sequential( 84 | nn.Linear(hidden_dim, 4 * hidden_dim), 85 | nn.GELU(), 86 | nn.Linear(4 * hidden_dim, hidden_dim), 87 | nn.Dropout(0.1) 88 | ) 89 | 90 | def forward(self, x, t=None, context=None): 91 | if self.self_attn: 92 | x = self.attn_self(self.norm1(x)) + x 93 | x = self.ffn1(self.norm2(x)) + x 94 | 95 | if self.cross_attn: 96 | x = self.attn_cross(self.norm3(x), context=context) + x 97 | x = self.ffn2(self.norm4(x)) + x 98 | 99 | return x 100 | 101 | 102 | class SpatialTransformer(nn.Module): 103 | def __init__(self, hidden_dim, context_dim=512, num_heads=4, self_attn=False, cross_attn=False): 104 | super(SpatialTransformer, self).__init__() 105 | self.transformer = TransformerBlock(hidden_dim, context_dim, num_heads, self_attn, cross_attn) 106 | self.context_proj = nn.Linear(context_dim, hidden_dim) if context_dim != hidden_dim else nn.Identity() 107 | self.self_attn = self_attn 108 | self.cross_attn = cross_attn 109 | 110 | def forward(self, x, t=None, context=None): 111 | b, c, h, w = x.shape 112 | x_res = x # 用作残差连接 113 | x = rearrange(x, "b c h w -> b (h w) c") 114 | 115 | if context is not None: 116 | context = self.context_proj(context) 117 | 118 | x = self.transformer(x, t, context) 119 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 120 | return x + x_res 121 | 122 | 123 | class ResnetBlock(nn.Module): 124 | """ 125 | 抄自 Stable Diffusion 1.x. 
126 | 源代码中有两个版本的实现: 127 | 1) https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L82 128 | 2) https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L163 129 | """ 130 | def __init__(self, in_channels, out_channels, time_dim): 131 | super().__init__() 132 | self.norm1 = nn.GroupNorm(4, in_channels, eps=1e-6) # SD1.x uses eps=1e-6 133 | self.norm2 = nn.GroupNorm(4, out_channels, eps=1e-6) 134 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 135 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 136 | self.activation = nn.SiLU() 137 | self.residual_conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0) if in_channels != out_channels else nn.Identity() 138 | self.dropout = nn.Dropout(0.1) 139 | self.time_proj = torch.nn.Linear(time_dim, out_channels) 140 | 141 | def forward(self, x, t): 142 | residual = self.residual_conv(x) 143 | 144 | x = self.conv1(self.activation(self.norm1(x))) 145 | x = x + self.time_proj(self.activation(t))[:, :, None, None] # 添加时间嵌入 146 | x = self.dropout(x) 147 | x = self.conv2(self.activation(self.norm2(x))) 148 | 149 | return x + residual 150 | 151 | 152 | class DownBlock(nn.Module): 153 | def __init__(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=512): 154 | super().__init__() 155 | self.resnet1 = ResnetBlock(in_channels, out_channels, time_dim) 156 | self.transformer1 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, self_attn=True) if self_attn else None 157 | self.resnet2 = ResnetBlock(out_channels, out_channels, time_dim) 158 | self.transformer2 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, cross_attn=True) if cross_attn else None 159 | self.downsample = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1) # 下采样 160 | 161 | def forward(self, x, t, y): 162 | x = self.resnet1(x, t) 163 | if self.transformer1: 164 | x = self.transformer1(x, t, y) 165 | x = self.resnet2(x, t) 166 | if self.transformer2: 167 | x = self.transformer2(x, t, y) 168 | x = self.downsample(x) 169 | return x 170 | 171 | 172 | class MiddleBlock(nn.Module): 173 | def __init__(self, channels, time_dim, context_dim): 174 | super().__init__() 175 | self.resnet1 = ResnetBlock(channels, channels, time_dim) 176 | self.attn1 = SpatialTransformer(channels, context_dim, num_heads=channels//64, self_attn=True, cross_attn=True) # 256/64=4 177 | self.resnet2 = ResnetBlock(channels, channels, time_dim) 178 | # 可选:添加第二个注意力层和resnet块 179 | self.attn2 = SpatialTransformer(channels, context_dim, num_heads=channels//64, self_attn=True, cross_attn=True) # 256/64=4 180 | self.resnet3 = ResnetBlock(channels, channels, time_dim) 181 | 182 | def forward(self, x, t, context): 183 | x = self.resnet1(x, t) 184 | x = self.attn1(x, t, context) 185 | x = self.resnet2(x, t) 186 | x = self.attn2(x, t, context) 187 | x = self.resnet3(x, t) 188 | return x 189 | 190 | 191 | class UpBlock(nn.Module): 192 | def __init__(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=512): 193 | super().__init__() 194 | # nn.Upsample(scale_factor=2, mode='nearest'), 195 | # nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1), # or, kernel_size=5, padding=2 196 | self.upsample = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1) # 上采样 197 | self.resnet1 = 
ResnetBlock(out_channels, out_channels, time_dim) 198 | self.transformer1 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, self_attn=True) if self_attn else None 199 | self.resnet2 = ResnetBlock(out_channels, out_channels, time_dim) 200 | self.transformer2 = SpatialTransformer(out_channels, context_dim, num_heads=num_heads, cross_attn=True) if cross_attn else None 201 | self.resnet3 = ResnetBlock(out_channels, out_channels, time_dim) 202 | 203 | def forward(self, x, t, y): 204 | x = self.upsample(x) 205 | x = self.resnet1(x, t) 206 | if self.transformer1: 207 | x = self.transformer1(x, t, y) 208 | x = self.resnet2(x, t) 209 | if self.transformer2: 210 | x = self.transformer2(x, t, y) 211 | x = self.resnet3(x, t) 212 | return x 213 | 214 | 215 | class UNet_Transformer(nn.Module): 216 | def __init__(self, in_channels=3, time_dim=256, context_dim=512): 217 | super().__init__() 218 | 219 | self.time_dim = time_dim 220 | self.time_mlp = nn.Sequential( 221 | nn.Linear(time_dim, time_dim * 4), 222 | nn.SiLU(), 223 | nn.Linear(time_dim * 4, time_dim) 224 | ) 225 | self.context_dim = context_dim 226 | 227 | # 初始卷积 228 | self.init_conv = nn.Sequential( 229 | nn.Conv2d(in_channels, 64, 3, padding=1), 230 | nn.SiLU(), 231 | nn.Conv2d(64, 64, 3, padding=1) 232 | ) # 64 x H x W 233 | 234 | # 下采样 235 | self.down1 = self._down_block(64, 128, time_dim, self_attn=True, cross_attn=False, num_heads=4, context_dim=context_dim) # 128 x H/2 x W/2 236 | self.down2 = self._down_block(128, 256, time_dim, self_attn=True, cross_attn=False, num_heads=4, context_dim=context_dim) # 256 x H/4 x W/4 237 | self.down3 = self._down_block(256, 512, time_dim, self_attn=True, cross_attn=False, num_heads=8, context_dim=context_dim) # 512 x H/8 x W/8 238 | 239 | # 中间块 240 | self.middle_block = MiddleBlock(512, time_dim, context_dim) # 512 x H/8 x W/8 241 | 242 | # 上采样 243 | self.up1 = self._up_conv(512, 256, time_dim, self_attn=True, cross_attn=True, num_heads=8, context_dim=context_dim) # 256 x H/4 x W/4 244 | self.up2 = self._up_conv(256+256, 128, time_dim, self_attn=True, cross_attn=True, num_heads=4, context_dim=context_dim) # 128 x H/2 x W/2 245 | self.up3 = self._up_conv(128+128, 64, time_dim, self_attn=True, cross_attn=True, num_heads=4, context_dim=context_dim) # 64 x H x W 246 | 247 | # 最终卷积 248 | self.final_conv = nn.Sequential( 249 | ResnetBlock(64 * 2, 64, time_dim), 250 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 251 | nn.SiLU(), 252 | nn.Conv2d(64, in_channels, 3, stride=1, padding=1), 253 | ) 254 | 255 | def get_sinusoidal_position_embedding(self, timesteps, embedding_dim): 256 | half_dim = embedding_dim // 2 257 | emb = math.log(10000) / (half_dim - 1) 258 | emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb) 259 | emb = timesteps[:, None] * emb[None, :] 260 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 261 | if embedding_dim % 2 == 1: # zero pad if embedding_dim is odd 262 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 263 | return emb 264 | 265 | def _down_block(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=None): 266 | return DownBlock(in_channels, out_channels, time_dim, self_attn, cross_attn, num_heads, context_dim or self.context_dim) 267 | 268 | def _up_conv(self, in_channels, out_channels, time_dim, self_attn=False, cross_attn=False, num_heads=1, context_dim=None): 269 | return UpBlock(in_channels, out_channels, time_dim, self_attn, cross_attn, num_heads, context_dim or self.context_dim) 270 
| 271 | def forward(self, x, t, y): 272 | # x: [batch, 3, H, W] 273 | # t: [batch, ] time embedding 274 | # y: [batch, 512] text embedding 275 | initial_x = x 276 | # Ensure y has the correct shape 277 | if y.dim() == 2: 278 | y = y.unsqueeze(1) # [batch, 1, context_dim] 279 | 280 | t = self.get_sinusoidal_position_embedding(t, self.time_dim) # [batch, 256] 281 | t = self.time_mlp(t) 282 | 283 | x1 = self.init_conv(x) 284 | 285 | x2 = self.down1(x1, t, y) 286 | x3 = self.down2(x2, t, y) 287 | x4 = self.down3(x3, t, y) 288 | 289 | x4 = self.middle_block(x4, t, y) 290 | 291 | x = self.up1(x4, t, y) 292 | x = torch.cat([x, x3], dim=1) # skip connection 跳跃连接 293 | x = self.up2(x, t, y) 294 | x = torch.cat([x, x2], dim=1) # skip connection 跳跃连接 295 | x = self.up3(x, t, y) 296 | x = torch.cat([x, x1], dim=1) # skip connection 跳跃连接 297 | 298 | x = self.final_conv[0](x, t) 299 | for layer in self.final_conv[1:]: 300 | x = layer(x) 301 | 302 | return x + initial_x # 全局残差连接 303 | 304 | 305 | """添加噪声过程:从原始图像开始,逐渐增加噪声,直到最终的噪声图像""" 306 | class NoiseScheduler: 307 | def __init__(self, num_timesteps, device): 308 | self.device = device 309 | self.num_timesteps = num_timesteps 310 | self.betas = self.cosine_beta_schedule(num_timesteps).to(device) # 这里我们使用余弦噪声调度。DDPM原始论文中使用的是线性调度。 311 | self.alphas = (1. - self.betas).to(device) 312 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(device) 313 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(device) 314 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod).to(device) 315 | 316 | def cosine_beta_schedule(self, timesteps, s=0.008): 317 | steps = timesteps + 1 318 | x = torch.linspace(0, timesteps, steps) 319 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 320 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 321 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 322 | return torch.clip(betas, 0.0001, 0.9999) 323 | 324 | def add_noise(self, x_start, t): 325 | """ 326 | 添加噪声到输入图像或潜在表示。 327 | :param x_start: 初始清晰图像或潜在表示 328 | :param t: 当前时间步 329 | :return: 添加噪声后的表示 330 | """ 331 | t = t.clone().detach().long().to(self.sqrt_alphas_cumprod.device) 332 | # 生成标准正态分布的噪声 333 | noise = torch.randn_like(x_start) 334 | # 获取所需的预计算值 335 | sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) 336 | sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) 337 | # 计算第t步、带噪声的图像 338 | x_t = sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 339 | 340 | return x_t, noise 341 | 342 | 343 | 344 | """CFG采样器(去噪)Classifier Guided Diffusion""" 345 | @torch.no_grad() 346 | def sample_cfg(model, noise_scheduler, n_samples, in_channels, text_embeddings, image_size=64, guidance_scale=3.0): 347 | """ 348 | 从噪声开始,逐渐减小噪声,直到最终的图像。 349 | :param model: UNet模型 350 | :param noise_scheduler: 噪声调度器 351 | :param n_samples: 生成的样本数量 352 | :param in_channels: 输入图像的通道数 353 | :param text_embeddings: 文本嵌入 354 | :param image_size: 图像的大小 355 | :param guidance_scale: 用于加权噪声预测的比例 356 | :return: 生成的图像 357 | """ 358 | model.eval() 359 | device = next(model.parameters()).device 360 | 361 | x = torch.randn(n_samples, in_channels, image_size, image_size).to(device) # 随机初始化噪声图像 362 | null_embeddings = torch.zeros_like(text_embeddings) # 用于无条件生成 363 | 364 | # 逐步去噪 365 | for t in reversed(range(noise_scheduler.num_timesteps)): 366 | t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long) 367 | 368 | noise_pred_uncond = model(x, t_batch, 
y=null_embeddings) # 生成一个无条件的噪声预测 369 | noise_pred_cond = model(x, t_batch, y=text_embeddings) # 生成一个有条件的噪声预测 370 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # CFG:结果加权后的噪声预测 371 | 372 | # 采样器的去噪过程 373 | alpha_t = noise_scheduler.alphas[t] 374 | alpha_t_bar = noise_scheduler.alphas_cumprod[t] 375 | beta_t = noise_scheduler.betas[t] 376 | 377 | if t > 0: 378 | noise = torch.randn_like(x) 379 | else: 380 | noise = torch.zeros_like(x) 381 | 382 | # 去噪公式 383 | x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * noise_pred) + torch.sqrt(beta_t) * noise 384 | 385 | model.train() 386 | return x 387 | 388 | 389 | """普通采样器(去噪)""" 390 | @torch.no_grad() 391 | def sample(model, x_t, noise_scheduler, t, text_embeddings): 392 | """ 393 | 从噪声开始,逐渐减小噪声,直到最终的图像。 394 | 395 | 参数: 396 | - model: UNet模型用于预测噪声。 397 | - x_t: 当前时间步的噪声化表示(torch.Tensor)。 398 | - noise_scheduler: 噪声调度器,包含betas和其他预计算值。 399 | - t: 当前时间步(torch.Tensor)。 400 | - text_embeddings: 文本嵌入,用于条件生成(torch.Tensor)。 401 | 402 | 返回: 403 | - x: 去噪后的表示。 404 | """ 405 | t = t.to(x_t.device) 406 | 407 | # 获取当前时间步的beta和alpha值 408 | beta_t = noise_scheduler.betas[t] 409 | alpha_t = noise_scheduler.alphas[t] 410 | alpha_t_bar = noise_scheduler.alphas_cumprod[t] 411 | 412 | # 预测当前时间步的噪声 413 | predicted_noise = model(x_t, t, text_embeddings) 414 | 415 | # 计算去噪后的表示 416 | if t > 0: 417 | noise = torch.randn_like(x_t).to(x_t.device) 418 | else: 419 | noise = torch.zeros_like(x_t).to(x_t.device) 420 | 421 | # 去噪公式 422 | x = (1 / torch.sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * predicted_noise) + torch.sqrt(beta_t) * noise 423 | 424 | return x 425 | 426 | 427 | class DDIMSampler: 428 | def __init__(self, model, n_steps=50, device="cuda"): 429 | self.model = model 430 | self.n_steps = n_steps 431 | self.device = device 432 | 433 | @torch.no_grad() 434 | def sample(self, noise, context, guidance_scale=3.0): 435 | # Assuming your noise scheduler is accessible via model.noise_scheduler 436 | scheduler = self.model.noise_scheduler 437 | 438 | # Initialize x_t with pure noise 439 | x = noise 440 | 441 | for i in reversed(range(0, scheduler.num_timesteps, scheduler.num_timesteps // self.n_steps)): 442 | t = torch.full((noise.shape[0],), i, device=self.device, dtype=torch.long) 443 | 444 | # For classifier-free guidance 445 | noise_pred_uncond = self.model.unet(x, t, y=torch.zeros_like(context)) 446 | noise_pred_cond = self.model.unet(x, t, y=context) 447 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 448 | 449 | # DDIM update step 450 | alpha_prod_t = scheduler.alphas_cumprod[t] 451 | alpha_prod_t_prev = scheduler.alphas_cumprod[t-1] if i > 0 else torch.ones_like(alpha_prod_t) 452 | 453 | pred_x0 = (x - torch.sqrt(1 - alpha_prod_t) * noise_pred) / torch.sqrt(alpha_prod_t) 454 | dir_xt = torch.sqrt(1 - alpha_prod_t_prev - scheduler.betas[t]) * noise_pred 455 | x = torch.sqrt(alpha_prod_t_prev) * pred_x0 + dir_xt 456 | 457 | return x -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from diffusion_model import UNet_Transformer, NoiseScheduler 5 | from vae_model import VAE 6 | 7 | class StableDiffusion(nn.Module): 8 | def __init__(self, in_channels=3, latent_dim=4, 
image_size=512, diffusion_timesteps=1000, device="cuda"): 9 | super(StableDiffusion, self).__init__() 10 | 11 | # VAE 12 | self.vae = VAE(in_channels=in_channels, latent_dim=latent_dim, image_size=image_size) 13 | 14 | # Diffusion model (UNet) 15 | self.unet = UNet_Transformer(in_channels=latent_dim) 16 | 17 | # Noise scheduler 18 | self.noise_scheduler = NoiseScheduler(num_timesteps=diffusion_timesteps, device=device) 19 | 20 | def encode(self, x): 21 | return self.vae.encode(x)[0] 22 | 23 | def decode(self, z): 24 | return self.vae.decode(z) 25 | 26 | def diffuse(self, latents, t, context): 27 | return self.unet(latents, t, context) 28 | 29 | def forward(self, latents, t, context): 30 | 31 | noise_pred = self.diffuse(latents, t, context) 32 | 33 | return noise_pred 34 | 35 | def sample(self, context, latent_size=64, batch_size=1, guidance_scale=3.0, device="cuda"): 36 | # Generate initial random noise in the latent space 37 | latents = torch.randn(batch_size, self.vae.latent_dim, latent_size, latent_size).to(device) 38 | 39 | # Create unconditioned embedding for classifier-free guidance 40 | uncond_embeddings = torch.zeros_like(context) 41 | 42 | # Gradually denoise the latents 43 | for t in reversed(range(self.noise_scheduler.num_timesteps)): 44 | t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long) 45 | 46 | # Predict noise for both conditioned and unconditioned 47 | noise_pred_uncond = self.diffuse(latents, t_batch, uncond_embeddings) 48 | noise_pred_cond = self.diffuse(latents, t_batch, context) 49 | 50 | # Perform guidance 51 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 52 | 53 | # Get alpha and beta valuescpu 54 | alpha_t = self.noise_scheduler.alphas[t] 55 | alpha_t_bar = self.noise_scheduler.alphas_cumprod[t] 56 | beta_t = self.noise_scheduler.betas[t] 57 | 58 | # Compute "previous" noisy sample x_t -> x_t-1 59 | if t > 0: 60 | noise = torch.randn_like(latents) 61 | else: 62 | noise = torch.zeros_like(latents) 63 | 64 | latents = (1 / torch.sqrt(alpha_t)) * ( 65 | latents - ((1 - alpha_t) / (torch.sqrt(1 - alpha_t_bar))) * noise_pred 66 | ) + torch.sqrt(beta_t) * noise 67 | 68 | # Return the final latents instead of decoding them 69 | return latents 70 | 71 | def load_vae(self, vae_path): 72 | self.vae.load_state_dict(torch.load(vae_path, map_location=torch.device('cpu'))) 73 | 74 | def load_diffusion(self, diffusion_path): 75 | self.unet.load_state_dict(torch.load(diffusion_path)) 76 | 77 | class DDIMSampler: 78 | def __init__(self, model, n_steps=50, device="cuda"): 79 | self.model = model 80 | self.n_steps = n_steps 81 | self.device = device 82 | 83 | @torch.no_grad() 84 | def sample(self, noise, context, guidance_scale=3.0): 85 | 86 | scheduler = self.model.noise_scheduler 87 | 88 | # Initialize x_t with pure noise 89 | x = noise 90 | 91 | for i in reversed(range(0, scheduler.num_timesteps, scheduler.num_timesteps // self.n_steps)): 92 | t = torch.full((noise.shape[0],), i, device=self.device, dtype=torch.long) 93 | 94 | # For classifier-free guidance 95 | noise_pred_uncond = self.model.unet(x, t, y=torch.zeros_like(context)) 96 | noise_pred_cond = self.model.unet(x, t, y=context) 97 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 98 | 99 | # DDIM update step 100 | alpha_prod_t = scheduler.alphas_cumprod[t] 101 | alpha_prod_t_prev = scheduler.alphas_cumprod[t-1] if i > 0 else torch.ones_like(alpha_prod_t) 102 | 103 | pred_x0 = (x - torch.sqrt(1 - alpha_prod_t) * 
noise_pred) / torch.sqrt(alpha_prod_t) 104 | dir_xt = torch.sqrt(1 - alpha_prod_t_prev - scheduler.betas[t]) * noise_pred 105 | x = torch.sqrt(alpha_prod_t_prev) * pred_x0 + dir_xt 106 | 107 | return x 108 | 109 | def load_vae_diffusion_model(vae_path, in_channels=3, latent_dim=4, image_size=512, diffusion_timesteps=1000, device="cuda"): 110 | model = StableDiffusion(in_channels, latent_dim, image_size, diffusion_timesteps, device=device) 111 | model.load_vae(vae_path) 112 | return model 113 | 114 | def load_model_from_checkpoint(checkpoint_path, in_channels=3, latent_dim=4, image_size=512, diffusion_timesteps=1000, device="cuda"): 115 | model = StableDiffusion(in_channels=in_channels, latent_dim=latent_dim, image_size=image_size, diffusion_timesteps=diffusion_timesteps, device=device) 116 | checkpoint = torch.load(checkpoint_path, map_location=device) 117 | model.load_state_dict(checkpoint['model_state_dict']) 118 | print(f"Model loaded from checkpoint at epoch {checkpoint['epoch']}") 119 | return model -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_results/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/stable_diffusion_results/img.png -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_results/img_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/stable_diffusion_results/img_0.png -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_results/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/stable_diffusion_results/img_1.png -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_results/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/stable_diffusion_results/img_2.png -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/stable_diffusion_results/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/stable_diffusion_results/img_3.png -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/train_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.optim import AdamW 4 | from torch.utils.data import DataLoader, Dataset 5 | from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR 6 | from torchvision import transforms 7 | from datasets import load_dataset 8 | from stable_diffusion_model 
import load_vae_diffusion_model, StableDiffusion 9 | from transformers import CLIPTokenizer, CLIPTextModel 10 | from PIL import Image 11 | import numpy as np 12 | from tqdm.auto import tqdm 13 | import os 14 | import wandb 15 | 16 | # 超参数 17 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 18 | image_size = 512 19 | latent_size = 64 # 潜在表示的宽和高,用于生成图像 20 | n_epochs = 1000 21 | batch_size = 16 22 | lr = 1e-4 23 | num_timesteps = 1000 24 | save_checkpoint_interval = 50 25 | lambda_cons = 0.1 # 一致性损失的权重 26 | max_lambda_cons = 1.0 # 最大一致性损失权重 27 | epochs_to_max_lambda = n_epochs # 达到最大权重所需的epoch数 28 | 29 | # WandB 初始化 30 | run = wandb.init( 31 | project="stable_diffusion_from_scratch", 32 | config={ 33 | "batch_size": batch_size, 34 | "learning_rate": lr, 35 | "epochs": n_epochs, 36 | "num_timesteps": num_timesteps, 37 | }, 38 | ) 39 | 40 | class AugmentedLatentDataset(Dataset): 41 | def __init__(self, original_dataset, model, device, num_augmentations=5): 42 | self.original_dataset = original_dataset 43 | self.model = model 44 | self.device = device 45 | self.num_augmentations = num_augmentations 46 | 47 | self.augment = transforms.Compose([ 48 | transforms.RandomHorizontalFlip(), 49 | transforms.RandomRotation(10), 50 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), 51 | ]) 52 | 53 | def __len__(self): 54 | return len(self.original_dataset) * self.num_augmentations 55 | 56 | def __getitem__(self, idx): 57 | original_idx = idx // self.num_augmentations 58 | original_item = self.original_dataset[original_idx] 59 | 60 | image = original_item["images"] 61 | text = original_item["text"] 62 | 63 | # Apply augmentation 64 | augmented_image = self.augment(image) 65 | 66 | # Encode to latent space 67 | with torch.no_grad(): 68 | latent = self.model.encode(augmented_image.unsqueeze(0).to(self.device)) 69 | 70 | return {"latents": latent.squeeze(0).cpu(), "text": text} 71 | 72 | # 预加载数据集 73 | dataset = load_dataset("svjack/pokemon-blip-captions-en-zh", split="train") 74 | 75 | def transform(examples): 76 | images = [preprocess(image.convert("RGB")) for image in examples["image"]] 77 | return {"images": images, "text": examples["en_text"]} 78 | 79 | preprocess = transforms.Compose([ 80 | transforms.Resize((image_size, image_size)), 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 83 | ]) 84 | 85 | def transform(examples): 86 | images = [preprocess(image.convert("RGB")) for image in examples["image"]] 87 | return {"images": images, "text": examples["en_text"]} 88 | 89 | dataset.set_transform(transform) 90 | 91 | # 初始化合并的VAE+Diffusion模型 92 | model = StableDiffusion(in_channels=3, latent_dim=4, image_size=512, diffusion_timesteps=1000, device=device) 93 | checkpoint = torch.load('stable_diffusion_results/stable_diffusion_model_checkpoint_epoch_100.pth', map_location=device) 94 | model.load_state_dict(checkpoint['model_state_dict']) 95 | model.load_vae('vae_model.pth') 96 | # model = load_vae_diffusion_model('vae_model.pth', 97 | # in_channels=3, 98 | # latent_dim=4, 99 | # image_size=512, 100 | # diffusion_timesteps=1000, 101 | # device=device) 102 | model.to(device) 103 | 104 | # Create augmented datasets 105 | train_dataset = AugmentedLatentDataset(dataset.select(range(0, 600)), model, device, num_augmentations=5) 106 | val_dataset = dataset.select(range(600, 800)) 107 | 108 | # Create data loaders 109 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, 
shuffle=True, drop_last=True) 110 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True) 111 | 112 | # 加载 CLIP 模型 113 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") 114 | text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device) 115 | 116 | # WandB 监控 117 | wandb.watch(model, log_freq=100) 118 | 119 | # 冻结VAE参数 120 | for param in model.vae.parameters(): 121 | param.requires_grad = False 122 | # 确保UNet (diffusion model) 参数可训练 123 | for param in model.unet.parameters(): 124 | param.requires_grad = True 125 | 126 | # 优化器和学习率调度器 127 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-4) 128 | # scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs) 129 | scheduler = OneCycleLR(optimizer, max_lr=1e-4, epochs=n_epochs, steps_per_epoch=len(train_dataloader)) # OneCycleLR 学习率调度器 130 | 131 | # 创建保存生成测试图像的目录 132 | os.makedirs('stable_diffusion_results', exist_ok=True) 133 | 134 | # 辅助损失函数:多样性损失 135 | def diversity_loss(latents, use_cosine=False): 136 | """ 137 | 计算多样性损失,可选使用余弦相似度 138 | """ 139 | batch_size = latents.size(0) 140 | latents_flat = latents.view(batch_size, -1) 141 | 142 | if use_cosine: 143 | # 使用余弦相似度 144 | latents_norm = F.normalize(latents_flat, p=2, dim=1) 145 | similarity = torch.mm(latents_norm, latents_norm.t()) 146 | else: 147 | # 使用原始的点积相似度 148 | similarity = torch.mm(latents_flat, latents_flat.t()) 149 | 150 | # 移除对角线上的自相似度 151 | similarity = similarity - torch.eye(batch_size, device=latents.device) 152 | 153 | return similarity.sum() / (batch_size * (batch_size - 1)) 154 | 155 | diversity_weight = 0.01 # 多样性损失起始权重 156 | 157 | # 训练循环 158 | for epoch in range(n_epochs): 159 | model.train() 160 | progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{n_epochs}") 161 | epoch_loss = 0.0 162 | num_batches = 0 163 | 164 | # 更新一致性损失权重 165 | current_lambda_cons = min(lambda_cons * (epoch + 1) / epochs_to_max_lambda, max_lambda_cons) 166 | 167 | # 训练模型 168 | for batch in train_dataloader: 169 | latents = batch["latents"].to(device) 170 | text = batch["text"] 171 | 172 | # 添加噪声 173 | timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=device).long() 174 | noisy_latents, noise = model.noise_scheduler.add_noise(latents, timesteps) 175 | 176 | # 使用 CLIP 模型编码文本 177 | text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 178 | text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state 179 | 180 | # 预测噪声 181 | noise_pred = model(noisy_latents, timesteps, text_embeddings) 182 | mse_loss = F.mse_loss(noise_pred, noise) 183 | div_loss = diversity_loss(noisy_latents, use_cosine=True) 184 | 185 | # 计算去噪后的潜在表示 186 | alpha_t = model.noise_scheduler.alphas[timesteps][:, None, None, None] 187 | sqrt_alpha_t = torch.sqrt(alpha_t) 188 | sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t) 189 | predicted_latents = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t 190 | cons_loss = F.mse_loss(predicted_latents, latents) 191 | 192 | 193 | # 组合损失 194 | total_loss = mse_loss + diversity_weight * div_loss + cons_loss * current_lambda_cons 195 | epoch_loss += total_loss.item() 196 | num_batches += 1 197 | 198 | optimizer.zero_grad() 199 | total_loss.backward() 200 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 201 | optimizer.step() 202 | scheduler.step() # OneCycleLR 学习率调度器 203 | 204 | # 
动态调整多样性损失的权重 205 | if epoch % 10 == 0: 206 | diversity_weight = min(diversity_weight * 1.05, 0.1) # 逐渐增加权重,但设置上限 207 | 208 | progress_bar.update(1) 209 | progress_bar.set_postfix({"loss": epoch_loss / num_batches}) 210 | 211 | average_train_loss = epoch_loss / num_batches 212 | 213 | # 验证集上评估模型 214 | model.eval() 215 | val_loss = 0.0 216 | val_batches = 0 217 | with torch.no_grad(): 218 | for batch in val_dataloader: 219 | data = batch["images"].to(device) 220 | latents = model.encode(data) 221 | text = batch["text"] 222 | 223 | # 添加噪声 224 | timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=device).long() 225 | noisy_latents, noise = model.noise_scheduler.add_noise(latents, timesteps) 226 | 227 | # 使用 CLIP 模型编码文本 228 | text_inputs = tokenizer(text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 229 | text_embeddings = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state 230 | 231 | # 预测噪声 232 | noise_pred = model(noisy_latents, timesteps, text_embeddings) 233 | mse_loss = F.mse_loss(noise_pred, noise) 234 | 235 | val_loss += mse_loss.item() 236 | val_batches += 1 237 | 238 | average_val_loss = val_loss / val_batches 239 | 240 | # scheduler.step() 241 | 242 | wandb.log({ 243 | "epoch": epoch, 244 | "learning_rate": scheduler.get_last_lr()[0], 245 | "train_loss": average_train_loss, 246 | "val_loss": average_val_loss, 247 | }) 248 | 249 | # 保存模型检查点 250 | if (epoch + 1) % save_checkpoint_interval == 0: 251 | torch.save({ 252 | 'epoch': epoch, 253 | 'model_state_dict': model.state_dict(), 254 | 'optimizer_state_dict': optimizer.state_dict(), 255 | 'scheduler_state_dict': scheduler.state_dict(), 256 | 'train_loss': epoch_loss, 257 | 'val_loss': val_loss, 258 | }, f'stable_diffusion_results/stable_diffusion_model_checkpoint_epoch_{epoch+1}.pth') 259 | 260 | # 生成测试图像 261 | if (epoch + 1) % save_checkpoint_interval == 0: 262 | model.eval() 263 | with torch.no_grad(): 264 | sample_text = ["a water type pokemon", "a red pokemon with a red fire tail"] 265 | text_input = tokenizer(sample_text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 266 | text_embeddings = text_encoder(text_input.input_ids.to(device)).last_hidden_state 267 | 268 | # 使用模型的sample方法生成图像 269 | sampled_latents = model.sample(text_embeddings, latent_size=latent_size, batch_size=len(sample_text), guidance_scale=7.5, device=device) 270 | 271 | # 使用VAE解码器将潜在表示解码回像素空间 272 | sampled_images = model.decode(sampled_latents) 273 | 274 | # 保存生成的图像 275 | for i, img in enumerate(sampled_images): 276 | img = img * 0.5 + 0.5 # Rescale to [0, 1] 277 | img = img.detach().cpu().permute(1, 2, 0).numpy() 278 | img = (img * 255).astype(np.uint8) 279 | img_pil = Image.fromarray(img) 280 | img_pil.save(f'stable_diffusion_results/generated_image_epoch_{epoch+1}_sample_{i}.png') 281 | 282 | wandb.log({f"generated_image_{i}": wandb.Image(sampled_images[i]) for i in range(len(sample_text))}) 283 | 284 | torch.save(model.state_dict(), 'stable_diffusion_model_final.pth') 285 | wandb.finish() -------------------------------------------------------------------------------- /stable_diffusion_from_scratch/vae_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/stable_diffusion_from_scratch/vae_model.pth -------------------------------------------------------------------------------- 
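上面的 train_stable_diffusion.py 把"冻结 VAE、只训练 UNet"这一思路串成了完整循环:先把图像编码进 4x64x64 的潜在空间,再加噪、用 CLIP 文本嵌入做条件预测噪声、以 MSE 为主损失。如果只想先理解这条数据流,可以参考下面这个极简的单步示意。它假设在 stable_diffusion_from_scratch 目录下运行、沿用仓库默认配置(latent_dim=4、潜在尺寸 64、文本上下文维度 512 与 CLIP 一致,这一点未另行验证);其中的随机张量只是占位,分别代替 VAE 编码后的潜在表示和 CLIP 的 last_hidden_state,因此它只演示形状与调用关系,并不是可复现训练效果的实现。

```
import torch
import torch.nn.functional as F
from stable_diffusion_model import StableDiffusion

device = "cpu"  # 仅作示意;实际训练请按脚本使用 cuda / mps

# 构建 VAE + UNet + NoiseScheduler 的组合模型(权重随机,未加载 vae_model.pth)
model = StableDiffusion(in_channels=3, latent_dim=4, image_size=512,
                        diffusion_timesteps=1000, device=device)

batch_size = 2
latents = torch.randn(batch_size, 4, 64, 64)        # 占位:代替 model.encode(images) 得到的潜在表示
text_embeddings = torch.randn(batch_size, 77, 512)  # 占位:代替 CLIP 文本编码器的 last_hidden_state

# 1. 随机采样时间步,对潜在表示加噪(前向扩散)
timesteps = torch.randint(0, model.noise_scheduler.num_timesteps, (batch_size,))
noisy_latents, noise = model.noise_scheduler.add_noise(latents, timesteps)

# 2. UNet 在文本条件下预测噪声
noise_pred = model(noisy_latents, timesteps, text_embeddings)

# 3. 以真实噪声为回归目标计算 MSE(训练脚本在此基础上又加了多样性损失和一致性损失)
loss = F.mse_loss(noise_pred, noise)
print(loss.item())
```

对应到训练脚本,这三步分别是 model.noise_scheduler.add_noise(...)、model(noisy_latents, timesteps, text_embeddings) 和 F.mse_loss(...);由于 VAE 参数被冻结,反向传播只会更新 UNet。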
/stable_diffusion_from_scratch/vae_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | 一个非常简单的变分自编码器(VAE)模型教学,用于训练压缩和解压缩图像于潜在空间(Latent Space)。 3 | Encoder和Decoder都是简单的卷积神经网络。 4 | Encoder用于将图像压缩为潜在空间表示,Decoder用于将潜在空间表示解压缩还原到原始图像。 5 | 6 | 在这个例子中,我们将3x512x512的图像压缩到4x64x64的特征值,并进一步输出潜在空间表示向量 z。 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | # VAE model 12 | class VAE(nn.Module): 13 | def __init__(self, in_channels=3, latent_dim=4, image_size=512): 14 | super(VAE, self).__init__() 15 | self.in_channels = in_channels 16 | self.latent_dim = latent_dim 17 | self.image_size = image_size 18 | 19 | # Encoder 20 | # 3 x 512 x 512 -> 4 x 64 x 64 21 | self.encoder = nn.Sequential( 22 | self._conv_block(in_channels, 64), # 64 x 256 x 256 23 | self._conv_block(64, 128), # 128 x 128 x 128 24 | self._conv_block(128, 256), # 256 x 64 x 64 25 | ) 26 | 27 | # Encoder 的潜在空间输出 28 | self.fc_mu = nn.Conv2d(256, latent_dim, 1) # 4 x 64 x 64 <- Latent Space 29 | self.fc_var = nn.Conv2d(256, latent_dim, 1) # 4 x 64 x 64 <- Latent Space 30 | 31 | # Decoder 32 | # 4 x 64 x 64 -> 3 x 512 x 512 33 | self.decoder_input = nn.ConvTranspose2d(latent_dim, 256, 1) # 256 x 64 x 64 34 | self.decoder = nn.Sequential( 35 | self._conv_transpose_block(256, 128), # 128 x 128 x 128 36 | self._conv_transpose_block(128, 64), # 64 x 256 x 256 37 | self._conv_transpose_block(64, in_channels), # 3 x 512 x 512 38 | ) 39 | 40 | self.sigmoid = nn.Sigmoid() # [0, 1] 41 | self.tanh = nn.Tanh() # [-1, 1] 42 | 43 | def _conv_block(self, in_channels, out_channels): 44 | return nn.Sequential( 45 | nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1), 46 | # nn.GroupNorm(num_groups=1, num_channels=out_channels), 47 | nn.BatchNorm2d(out_channels), 48 | # nn.LeakyReLU(), 49 | nn.LeakyReLU(0.2) 50 | ) 51 | 52 | def _conv_transpose_block(self, in_channels, out_channels): 53 | return nn.Sequential( 54 | nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1), 55 | # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 56 | # nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1), 57 | # nn.GroupNorm(num_groups=1, num_channels=out_channels), 58 | nn.BatchNorm2d(out_channels), 59 | # nn.LeakyReLU(), 60 | nn.LeakyReLU(0.2) 61 | ) 62 | 63 | def encode(self, input): 64 | result = self.encoder(input) 65 | mu = self.fc_mu(result) 66 | log_var = self.fc_var(result) 67 | return mu, log_var 68 | 69 | def decode(self, z): 70 | result = self.decoder_input(z) 71 | result = self.decoder(result) 72 | # result = self.sigmoid(result) # 如果原始图像被归一化为[0, 1],则使用sigmoid 73 | result = self.tanh(result) # 如果原始图像被归一化为[-1, 1],则使用tanh 74 | # return result.view(-1, self.in_channels, self.image_size, self.image_size) 75 | return result 76 | 77 | def reparameterize(self, mu, logvar): 78 | std = torch.exp(0.5 * logvar) 79 | eps = torch.randn_like(std) 80 | return eps * std + mu 81 | 82 | def forward(self, input): 83 | """ 84 | 返回4个值: 85 | reconstruction, input, mu, log_var 86 | """ 87 | mu, log_var = self.encode(input) 88 | z = self.reparameterize(mu, log_var) # 潜在空间的向量表达 Latent Vector z 89 | return self.decode(z), input, mu, log_var -------------------------------------------------------------------------------- /vae_from_scratch/README.md: -------------------------------------------------------------------------------- 1 | 2 | # 从零开始实现VAE 3 | 4 | 这里是我在视频中讲解的代码,主要是关于VAE的实现。 5 | 6 | 使用了一个pokemon的小规模图片数据集,演示训练过程。 7 | 8 | #### 需要安装的库: 9 | ``` 10 | numpy 11 | torch 
12 | torchvision 13 | Pillow 14 | datasets 15 | matplotlib 16 | ``` 17 | 18 | #### 训练图片数据集: 19 | 20 | 运行`train_vae.py`会从huggingface上下载一个[pokemon](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh)的小规模图片数据集,然后训练VAE模型。 21 | 22 | 当然,你也可以在代码中替换成本地的其他图片数据集。 23 | 24 | #### 训练过程中的重构图片: 25 | 26 | 训练过程中将原始图片压缩到潜在空间,然后再从潜在空间解码还原成像素空间图片。 27 | 28 | Epoch 0: 29 | ![img](vae_results/reconstruction_0.png) 30 | 31 | Epoch 20: 32 | ![img](vae_results/reconstruction_20.png) 33 | 34 | Epoch 40: 35 | ![img](vae_results/reconstruction_40.png) 36 | 37 | Epoch 100: 38 | ![img](vae_results/reconstruction_100.png) 39 | 40 | #### 关于损失值: 41 | 42 | VAE的损失值是由两部分组成的,分别是重构损失和KL散度损失。两个损失的比重可以自定义。 43 | 44 | 由于这个例子中的pokemon数据集相对较小,100到200个epoch的结果差不多。要想达到更好的效果,可以增加训练集图片数量,或者减小图片尺寸。 45 | 46 | - 训练总损失: 47 | ![train loss](vae_results/train_loss.png) 48 | 49 | - MSE损失(像素): 50 | ![mse loss](vae_results/mse_loss.png) 51 | 52 | - KL散度损失(像素): 53 | ![kl loss](vae_results/kl_loss.png) 54 | 55 | 56 | #### 调用sample_vae.py尝试用训练好的VAE模型来压缩并恢复一张图片: 57 | 58 | ![sample](vae_results/sampled.png) 59 | 60 | #### 潜在空间的可视化: 61 | 62 | 将潜在空间的向量可视化,可以看到潜在空间的分布情况: 63 | 64 | ![latent space](vae_results/latent_space.png) -------------------------------------------------------------------------------- /vae_from_scratch/pokemon_sample_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/pokemon_sample_test.png -------------------------------------------------------------------------------- /vae_from_scratch/sample_vae.py: -------------------------------------------------------------------------------- 1 | """ 2 | 这段代码用于展示如何使用训练好的VAE模型对图像进行编码和解码。 3 | 用自己训练好的vae模型来压缩一张图片(pokemon_sample_test.png)到潜在空间,然后再还原到像素空间并可视化的过程。 4 | 需要通过train_vae.py训练好VAE模型并保存后,才能运行这段代码。 5 | """ 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from matplotlib import pyplot as plt 10 | from torchvision import transforms 11 | from vae_model import VAE 12 | 13 | 14 | # 超参数 15 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 16 | image_size = 512 17 | latent_dim = 4 18 | 19 | # 加载一个随机的原始图像 20 | image_path = "pokemon_sample_test.png" 21 | original_image = Image.open(image_path) 22 | 23 | preprocess = transforms.Compose( 24 | [ 25 | transforms.Resize((image_size, image_size)), # 图片大小调整为 512 x 512 26 | transforms.ToTensor(), # 转换为张量 27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将像素值从 [0, 1] 转换到 [-1, 1] 28 | ] 29 | ) 30 | 31 | def transform(examples): 32 | images = [preprocess(image.convert("RGB")) for image in examples["image"]] 33 | return {"images": images} 34 | 35 | # 处理图片到3通道的RGB格式(防止有时图片是RGBA的4通道) 36 | image_tensor = preprocess(original_image.convert("RGB")).unsqueeze(0).to(device) 37 | 38 | mean_value = image_tensor.mean().item() 39 | print(f"Mean value of image_tensor: {mean_value}") 40 | 41 | # 加载我们刚刚预训练好的VAE模型 42 | vae = VAE(in_channels=3, latent_dim=latent_dim, image_size=image_size).to(device) 43 | vae.load_state_dict(torch.load('vae_model.pth', map_location=torch.device('cpu'))) 44 | 45 | # 使用VAE的encoder压缩图像到潜在空间 46 | with torch.no_grad(): 47 | mu, log_var = vae.encode(image_tensor) 48 | latent = vae.reparameterize(mu, log_var) 49 | 50 | # 使用encoder的输出通过decoder重构图像 51 | with torch.no_grad(): 52 | reconstructed_image = vae.decode(latent) 53 | 54 | # 显示原始图像 55 | plt.figure(figsize=(10, 5)) 56 | plt.subplot(1, 
2, 1) 57 | plt.imshow(original_image) 58 | plt.title("Original Image") 59 | plt.axis('off') 60 | 61 | # 显示重构图像 62 | reconstructed_image = reconstructed_image.squeeze().cpu().numpy().transpose(1, 2, 0) 63 | reconstructed_image = (reconstructed_image + 1) / 2 # 从[-1, 1]转换到[0, 1] 64 | plt.subplot(1, 2, 2) 65 | plt.imshow(reconstructed_image) 66 | plt.title("Reconstructed Image") 67 | plt.axis('off') 68 | 69 | plt.show() 70 | 71 | # 将潜在向量转换为可视化的图像格式 72 | latent_image = latent.squeeze().cpu().numpy() 73 | 74 | # 检查潜在向量的形状 75 | if latent_image.ndim == 1: 76 | # 如果是1D的,将其reshape成2D图像 77 | side_length = int(np.ceil(np.sqrt(latent_image.size))) 78 | latent_image = np.pad(latent_image, (0, side_length**2 - latent_image.size), mode='constant') 79 | latent_image = latent_image.reshape((side_length, side_length)) 80 | elif latent_image.ndim == 3: 81 | # 如果是3D的,选择一个切片或进行平均 82 | latent_image = np.mean(latent_image, axis=0) 83 | 84 | # 显示潜在向量图像 85 | plt.imshow(latent_image, cmap='gray') 86 | plt.title("Latent Space Image") 87 | plt.axis('off') 88 | plt.colorbar() 89 | plt.show() 90 | -------------------------------------------------------------------------------- /vae_from_scratch/train_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms 6 | from torchvision.utils import save_image 7 | from vae_model import VAE 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR 9 | import os 10 | from datasets import load_dataset 11 | 12 | device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") 13 | 14 | # 超参数 15 | batch_size = 8 16 | learning_rate = 1e-3 17 | num_epochs = 200 18 | image_size = 512 19 | latent_dim = 4 20 | 21 | # 需要安装 wandb 库,如果要记录训练过程可以打开下面的注释 22 | # import wandb 23 | # wandb.init(project="vae_from_scratch") 24 | # wandb.config = { 25 | # "learning_rate": learning_rate, 26 | # "epochs": num_epochs, 27 | # "batch_size": batch_size, 28 | # "image_size": image_size, 29 | # "latent_dim": latent_dim 30 | # } 31 | 32 | # 加载数据集 33 | dataset = load_dataset("svjack/pokemon-blip-captions-en-zh", split="train") 34 | # dataset = load_dataset("imagefolder", split="train", data_dir="train_images/") # 也可以这样加载本地文件夹的图片数据集 35 | 36 | preprocess = transforms.Compose( 37 | [ 38 | transforms.Resize((image_size, image_size)), # 图片大小调整为 512 x 512 39 | transforms.RandomHorizontalFlip(), # 随机水平翻转 40 | transforms.RandomRotation(10), # 随机旋转 41 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 随机颜色调整 42 | transforms.ToTensor(), # 转换为张量 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将像素值从 [0, 1] 转换到 [-1, 1] 44 | ] 45 | ) 46 | 47 | def transform(examples): 48 | images = [preprocess(image.convert("RGB")) for image in examples["image"]] 49 | return {"images": images} 50 | 51 | 52 | dataset.set_transform(transform) 53 | 54 | train_dataset = dataset.select(range(0, 600)) 55 | val_dataset = dataset.select(range(600, 800)) 56 | 57 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) 58 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True) 59 | 60 | # 初始化模型 61 | vae = VAE(in_channels=3, latent_dim=latent_dim, image_size=image_size).to(device) 62 | 63 | # 优化器和学习率调度器 64 | optimizer = optim.AdamW(vae.parameters(), 
lr=learning_rate, weight_decay=1e-4) # 可以考虑加入L2正则化:weight_decay=1e-4 65 | # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=5e-5) 66 | # scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs) # 余弦退火学习率调度器 67 | scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=num_epochs, steps_per_epoch=len(train_dataloader)) 68 | 69 | 70 | # 自定义损失函数 71 | """ 72 | 这个损失函数是用于变分自编码器(VAE)的训练。它由两部分组成:重构误差(MSE)和KL散度(KLD)。 73 | 重构误差(MSE):衡量重构图像 recon_x 和原始图像 x 之间的差异。使用均方误差(MSE)作为度量标准,计算两个图像之间的像素差异的平方和。 74 | KL散度(KLD):衡量编码器输出的潜在分布 mu 和 logvar 与标准正态分布之间的差异。KL散度用于正则化潜在空间,使其接近标准正态分布。 75 | 76 | :param recon_x: 重构图像 77 | :param x: 原始图像 78 | :param mu: 编码器输出的均值 79 | :param logvar: 编码器输出的对数方差 80 | :return: 总损失值 =(重构误差 + KL散度) <- 也可以调整加法的比重 81 | """ 82 | 83 | def vae_loss_function(recon_x, x, mu, logvar, kld_weight=0.1): 84 | batch_size = x.size(0) 85 | mse = F.mse_loss(recon_x, x, reduction='sum') 86 | kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 87 | # 总损失 - 用于优化 88 | total_loss = mse + kld_weight * kld 89 | # 每像素指标 - 用于监控 90 | mse_per_pixel = mse / (batch_size * x.size(1) * x.size(2) * x.size(3)) 91 | kld_per_pixel = kld / (batch_size * x.size(1) * x.size(2) * x.size(3)) 92 | 93 | return total_loss, mse, kld_weight * kld, mse_per_pixel, kld_per_pixel 94 | 95 | # 创建保存生成测试图像的目录 96 | os.makedirs('vae_results', exist_ok=True) 97 | 98 | # 训练循环 99 | for epoch in range(num_epochs): 100 | vae.train() 101 | train_loss = 0 102 | mse_loss_total = 0 103 | kl_loss_total = 0 104 | mse_vs_kld = 0 105 | for batch_idx, batch in enumerate(train_dataloader): 106 | 107 | data = batch["images"].to(device) # [batch, 3, 512, 512] 的原始图像张量 108 | optimizer.zero_grad() 109 | 110 | recon_batch, _, mu, logvar = vae(data) # 传递给VAE模型,获取重构图像、均值和对数方差 111 | loss, mse, kld, mse_per_pixel, kld_per_pixel = vae_loss_function(recon_batch, data, mu, logvar) # 计算损失 112 | 113 | loss.backward() 114 | train_loss += loss.item() 115 | mse_vs_kld += mse_per_pixel / kld_per_pixel 116 | mse_loss_total += mse_per_pixel.item() 117 | kl_loss_total += kld_per_pixel.item() 118 | optimizer.step() 119 | scheduler.step() # OneCycleLR 在每个批次后调用 120 | 121 | # scheduler.step() # 除了 OneCycleLR 之外,其他调度器都需要在每个 epoch 结束时调用 122 | 123 | avg_train_loss = train_loss / len(train_dataloader.dataset) 124 | avg_mse_loss = mse_loss_total / len(train_dataloader.dataset) 125 | avg_kl_loss = kl_loss_total / len(train_dataloader.dataset) 126 | avg_mse_vs_kld = mse_vs_kld / len(train_dataloader) 127 | 128 | print(f'====> Epoch: {epoch} | Learning rate: {scheduler.get_last_lr()[0]:.6f}') 129 | print(f'Total loss: {avg_train_loss:.4f}') 130 | print(f'MSE loss (pixel): {avg_mse_loss:.6f} | KL loss (pixel): {avg_kl_loss:.6f}') 131 | 132 | # 验证集上的损失 133 | vae.eval() 134 | val_loss = 0 135 | with torch.no_grad(): 136 | for batch_idx, batch in enumerate(val_dataloader): 137 | data = batch["images"].to(device) 138 | recon_batch, _, mu, logvar = vae(data) 139 | loss,_,_,_,_ = vae_loss_function(recon_batch, data, mu, logvar) 140 | val_loss += loss.item() 141 | 142 | val_loss /= len(val_dataloader.dataset) 143 | print(f'Validation set loss: {val_loss:.4f}') 144 | 145 | # 需要安装 wandb 库,如果要记录训练过程可以打开下面的注释 146 | # wandb.log({ 147 | # "epoch": epoch, 148 | # "learning_rate": scheduler.get_last_lr()[0], 149 | # "train_loss": avg_train_loss, 150 | # "mse_per_pixel": avg_mse_loss, 151 | # "kl_per_pixel": avg_kl_loss, 152 | # "mse_vs_kld": avg_mse_vs_kld, 153 | # "val_loss": val_loss, 154 | # }) 155 | 156 | # 生成一些重构图像和可视化 157 | if epoch % 20 == 0: 158 | 
with torch.no_grad(): 159 | # 获取实际的批次大小 160 | actual_batch_size = data.size(0) 161 | # 重构图像 162 | n = min(actual_batch_size, 8) 163 | comparison = torch.cat([data[:n], recon_batch.view(actual_batch_size, 3, image_size, image_size)[:n]]) 164 | comparison = (comparison * 0.5) + 0.5 # 将 [-1, 1] 转换回 [0, 1] 165 | save_image(comparison.cpu(), f'vae_results/reconstruction_{epoch}.png', nrow=n) 166 | 167 | # 需要安装 wandb 库,如果要记录训练过程可以打开下面的注释 168 | # wandb.log({"reconstruction": wandb.Image(f'vae_results/reconstruction_{epoch}.png')}) 169 | 170 | torch.save(vae.state_dict(), 'vae_model.pth') 171 | print("Training completed.") 172 | # 需要安装 wandb 库,如果要记录训练过程可以打开下面的注释 173 | # wandb.finish() -------------------------------------------------------------------------------- /vae_from_scratch/vae_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_model.pth -------------------------------------------------------------------------------- /vae_from_scratch/vae_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | 一个非常简单的变分自编码器(VAE)模型教学,用于训练压缩和解压缩图像于潜在空间(Latent Space)。 3 | Encoder和Decoder都是简单的卷积神经网络。 4 | Encoder用于将图像压缩为潜在空间表示,Decoder用于将潜在空间表示解压缩还原到原始图像。 5 | 6 | 在这个例子中,我们将3x512x512的图像压缩到4x64x64的特征值,并进一步输出潜在空间表示向量 z。 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | # VAE model 12 | class VAE(nn.Module): 13 | def __init__(self, in_channels=3, latent_dim=4, image_size=512): 14 | super(VAE, self).__init__() 15 | self.in_channels = in_channels 16 | self.latent_dim = latent_dim 17 | self.image_size = image_size 18 | 19 | # Encoder 20 | # 3 x 512 x 512 -> 4 x 64 x 64 21 | self.encoder = nn.Sequential( 22 | self._conv_block(in_channels, 64), # 64 x 256 x 256 23 | self._conv_block(64, 128), # 128 x 128 x 128 24 | self._conv_block(128, 256), # 256 x 64 x 64 25 | ) 26 | 27 | # Encoder 的潜在空间输出 28 | self.fc_mu = nn.Conv2d(256, latent_dim, 1) # 4 x 64 x 64 <- Latent Space 29 | self.fc_var = nn.Conv2d(256, latent_dim, 1) # 4 x 64 x 64 <- Latent Space 30 | 31 | # Decoder 32 | # 4 x 64 x 64 -> 3 x 512 x 512 33 | self.decoder_input = nn.ConvTranspose2d(latent_dim, 256, 1) # 256 x 64 x 64 34 | self.decoder = nn.Sequential( 35 | self._conv_transpose_block(256, 128), # 128 x 128 x 128 36 | self._conv_transpose_block(128, 64), # 64 x 256 x 256 37 | self._conv_transpose_block(64, in_channels), # 3 x 512 x 512 38 | ) 39 | 40 | self.sigmoid = nn.Sigmoid() # [0, 1] 41 | self.tanh = nn.Tanh() # [-1, 1] 42 | 43 | def _conv_block(self, in_channels, out_channels): 44 | return nn.Sequential( 45 | nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1), 46 | # nn.GroupNorm(num_groups=1, num_channels=out_channels), 47 | nn.BatchNorm2d(out_channels), 48 | # nn.LeakyReLU(), 49 | nn.LeakyReLU(0.2) 50 | ) 51 | 52 | def _conv_transpose_block(self, in_channels, out_channels): 53 | return nn.Sequential( 54 | nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1), 55 | # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 56 | # nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1), 57 | # nn.GroupNorm(num_groups=1, num_channels=out_channels), 58 | nn.BatchNorm2d(out_channels), 59 | # nn.LeakyReLU(), 60 | nn.LeakyReLU(0.2) 61 | ) 62 | 63 | def encode(self, input): 64 | result = self.encoder(input) 65 | mu = self.fc_mu(result) 66 | log_var = self.fc_var(result) 67 | return 
mu, log_var 68 | 69 | def decode(self, z): 70 | result = self.decoder_input(z) 71 | result = self.decoder(result) 72 | # result = self.sigmoid(result) # 如果原始图像被归一化为[0, 1],则使用sigmoid 73 | result = self.tanh(result) # 如果原始图像被归一化为[-1, 1],则使用tanh 74 | # return result.view(-1, self.in_channels, self.image_size, self.image_size) 75 | return result 76 | 77 | def reparameterize(self, mu, logvar): 78 | std = torch.exp(0.5 * logvar) 79 | eps = torch.randn_like(std) 80 | return eps * std + mu 81 | 82 | def forward(self, input): 83 | """ 84 | 返回4个值: 85 | reconstruction, input, mu, log_var 86 | """ 87 | mu, log_var = self.encode(input) 88 | z = self.reparameterize(mu, log_var) # 潜在空间的向量表达 Latent Vector z 89 | return self.decode(z), input, mu, log_var -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/kl_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/kl_loss.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/latent_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/latent_space.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/mse_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/mse_loss.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_0.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_100.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_20.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_40.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_60.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_60.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/reconstruction_80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/reconstruction_80.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/sampled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/sampled.png -------------------------------------------------------------------------------- /vae_from_scratch/vae_results/train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waylandzhang/DiT_from_scratch/1f6b6c0549db56a100ecca6ccaebeba8371e3225/vae_from_scratch/vae_results/train_loss.png --------------------------------------------------------------------------------
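最后,针对上面 vae_model.py 注释中"把 3x512x512 的图像压缩到 4x64x64 潜在表示"的说法,下面给出一个快速的形状自查片段。它假设在 vae_from_scratch 目录下运行,用随机张量代替真实图像,只用于确认编码/解码前后的张量形状与 README 的描述一致。

```
import torch
from vae_model import VAE

vae = VAE(in_channels=3, latent_dim=4, image_size=512)
vae.eval()

x = torch.randn(2, 3, 512, 512)           # 占位:两张已归一化到 [-1, 1] 的"图像"
with torch.no_grad():
    recon, _, mu, log_var = vae(x)         # forward 返回: 重构图像, 输入, mu, log_var
    z = vae.reparameterize(mu, log_var)    # 潜在向量 z

print(mu.shape, log_var.shape, z.shape)    # 均应为 [2, 4, 64, 64] <- 潜在空间
print(recon.shape)                         # 应为 [2, 3, 512, 512] <- 还原回像素空间
```

压缩比约为 (3x512x512)/(4x64x64) = 48 倍,这也是 stable_diffusion_from_scratch 中扩散模型可以只在 4x64x64 的潜在空间上训练的原因。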