├── images
│   ├── ddpm_cars.gif
│   ├── ddpm_celeba.gif
│   ├── ddim_celeba_hq.gif
│   ├── ddpm_diagram.png
│   ├── ddpm_ema_cars.gif
│   ├── ddpm_ema_celeba.gif
│   ├── ddim_celeba_hq_ema_1.gif
│   ├── ddim_celeba_hq_ema_2.gif
│   └── ddim_celeba_hq_ema_3.gif
├── LICENSE
├── README.md
├── ddim.py
└── ddpm.py
/images/ddpm_cars.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddpm_cars.gif
--------------------------------------------------------------------------------
/images/ddpm_celeba.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddpm_celeba.gif
--------------------------------------------------------------------------------
/images/ddim_celeba_hq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddim_celeba_hq.gif
--------------------------------------------------------------------------------
/images/ddpm_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddpm_diagram.png
--------------------------------------------------------------------------------
/images/ddpm_ema_cars.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddpm_ema_cars.gif
--------------------------------------------------------------------------------
/images/ddpm_ema_celeba.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddpm_ema_celeba.gif
--------------------------------------------------------------------------------
/images/ddim_celeba_hq_ema_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddim_celeba_hq_ema_1.gif
--------------------------------------------------------------------------------
/images/ddim_celeba_hq_ema_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddim_celeba_hq_ema_2.gif
--------------------------------------------------------------------------------
/images/ddim_celeba_hq_ema_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quickgrid/pytorch-diffusion/HEAD/images/ddim_celeba_hq_ema_3.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Asif Ahmed
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### New version, https://github.com/quickgrid/text-to-image-diffusion.
2 | 
3 | # Pytorch Diffusion
4 | 
5 | Implementation of diffusion models in PyTorch for custom training. This code is mainly based on [this repo](https://github.com/dome272/Diffusion-Models-pytorch).
6 | 
7 | Models are implemented for `64 x 64` resolution output, which is scaled 2x by nearest-neighbor sampling to `128 x 128` resolution. In DDPM both training and reverse sampling require around `T` steps. In DDIM reverse sampling can be done in a small number of steps.
8 | 
9 | 
10 | ## Results
11 | 
12 | Results were upsampled from the `64 x 64` trained model output to `128 x 128` by nearest interpolation.
13 | 
14 | ### DDPM
15 | 
16 | Stanford Cars and CelebA HQ datasets with 500 reverse diffusion steps. GIFs were generated by keeping every 20th frame of the reverse process.
17 | 
18 | ![ddpm_cars](images/ddpm_cars.gif "ddpm_cars")
19 | ![ddpm_ema_cars](images/ddpm_ema_cars.gif "ddpm_ema_cars")
20 | ![ddpm_celeba](images/ddpm_celeba.gif "ddpm_celeba")
21 | ![ddpm_ema_celeba](images/ddpm_ema_celeba.gif "ddpm_ema_celeba")
22 | 
23 | ### DDIM
24 | 
25 | CelebA HQ dataset with 30-50 reverse diffusion steps. No frames were skipped during GIF generation.
26 | 
27 | ![ddim_celeba_hq](images/ddim_celeba_hq.gif "ddim_celeba")
28 | ![ddim_celeba_hq_ema_1](images/ddim_celeba_hq_ema_1.gif "ddim_celeba_hq_ema_1")
29 | ![ddim_celeba_hq_ema_2](images/ddim_celeba_hq_ema_2.gif "ddim_celeba_hq_ema_2")
30 | ![ddim_celeba_hq_ema_3](images/ddim_celeba_hq_ema_3.gif "ddim_celeba_hq_ema_3")
31 | 
32 | 
33 | ## Instructions
34 | 
35 | The parent folder path should be provided in `dataset_path`. It must contain one or more folders with images; these folder names are used as class information.
36 | 
37 | For fast training it is best to first resize images to the expected size and remove corrupted, low-resolution images with the tools in this repo.
38 | 
39 | **Large Minibatch Training**
40 | 
41 | With gradient accumulation, `batch_size * accumulation_iters` is the effective minibatch size. If `batch_size = 2` and `accumulation_iters = 16` in code, then the minibatch size for the gradient update is 32.
42 | 
43 | If the required minibatch size is 64 and `batch_size = 8` fits in memory, then `accumulation_iters` should be 8.
44 | 
45 | **Resume Training**
46 | 
47 | To resume training, `checkpoint_path` and `checkpoint_path_ema` should be provided.
48 | 
49 | **Sample Images**
50 | 
51 | This will generate 4 images each with the regular and the EMA model.
52 | 
53 | ```
54 | trainer.sample(output_name='output', sample_count=4)
55 | ```
56 | 
57 | **Sample Gif**
58 | 
59 | The following will generate `out.gif` in the chosen directory. The pretrained checkpoint paths must be provided for sampling.
60 | 
61 | ```
62 | trainer.sample_gif(
63 |     output_name='out',
64 |     sample_count=2,
65 |     save_path=r'C:\computer_vision\ddpm'
66 | )
67 | ```
68 | 
69 | ### Codes
70 | 
71 | | Name | Description |
72 | | ----------- | ----------- |
73 | | `ddpm.py` | DDPM implementation for testing new features. |
74 | | `ddim.py` | DDIM implementation for testing new features. |
75 | 
76 | ### Pretrained Checkpoints
77 | 
78 | 
79 | Models are available at https://huggingface.co/quickgrid/pytorch-diffusion.
80 | 
81 | #### DDPM
82 | 
83 | Trained with a linear noise schedule and `T = 500` noise steps. Trained for only 1 day without waiting for further improvement.
84 | 
85 | | Dataset | Download Link |
86 | | ----------- | ----------- |
87 | | [Stanford Cars]() | https://huggingface.co/quickgrid/pytorch-diffusion/blob/main/cars_61_4000.pt |
88 | | | https://huggingface.co/quickgrid/pytorch-diffusion/blob/main/cars_ema_61_4000.pt |
89 | | | |
90 | | [CelebA HQ]() | https://huggingface.co/quickgrid/pytorch-diffusion/blob/main/celeba_147_0.pt |
91 | | | https://huggingface.co/quickgrid/pytorch-diffusion/blob/main/celeba_ema_147_0.pt |
92 | 
93 | 
94 | ## Todo
95 | 
96 | - Match ddpm, ddim variable names and functions, and merge code.
97 | - Class conditional generation.
98 | - Classifier Free Guidance (CFG).
99 | - Save EMA step number with checkpoint.
100 | - Add super resolution with a UNet like Imagen for 4x upsampling, `64x64 => 256x256 => 1024x1024`.
101 | - Train and test with SWA EMA model.
102 | - Add loss to tensorboard.
103 | - Check if overfitting, add validation.
104 | - Convert to channels-last mode.
105 | - Transformer encoder block is missing layer norm after MHA.
106 | - Move test class to separate file.
107 | 
108 | ## Issues
109 | 
110 | - Logging does not print in Kaggle.
111 | 
112 | ## References
113 | 
114 | - Annotated DDPM implementation, https://github.com/quickgrid/paper-implementations/tree/main/pytorch/ddpm.
115 | - DDIM implementation, https://github.com/quickgrid/paper-implementations/tree/main/pytorch/ddim.
116 | - DDPM implementation, https://www.youtube.com/watch?v=TBCRlnwJtZU.
117 | - DDPM implementation, https://github.com/dome272/Diffusion-Models-pytorch.
118 | - DDPM paper, https://arxiv.org/pdf/2006.11239.pdf.
119 | - DDIM paper, https://arxiv.org/pdf/2010.02502.pdf.
120 | - Improved DDPM, https://arxiv.org/pdf/2102.09672.pdf.
121 | - Annotated Diffusion, https://huggingface.co/blog/annotated-diffusion.
122 | - Keras DDIM, https://keras.io/examples/generative/ddim/.
123 | - Positional embedding, http://nlp.seas.harvard.edu/annotated-transformer/.
124 | - Attention paper, https://arxiv.org/pdf/1706.03762.pdf.
125 | - Transformers, https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
126 | - Transformer encoder architecture, https://arxiv.org/pdf/2010.11929.pdf.
127 | - UNet architecture, https://arxiv.org/pdf/1505.04597.pdf.
128 | 
--------------------------------------------------------------------------------
/ddim.py:
--------------------------------------------------------------------------------
1 | """Implementation of DDIM.
2 | 
3 | References:
4 |     - Annotated DDPM implementation,
5 |     https://github.com/quickgrid/paper-implementations/tree/main/pytorch/denoising-diffusion.
6 |     - Keras DDIM,
7 |     https://keras.io/examples/generative/ddim/.
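
    Example usage (illustrative sketch only; the dataset and save paths are
    placeholders and must point to real directories):

        ddim = DDIM(dataset_path='/path/to/dataset', save_path='/path/to/output')
        ddim.train()
        ddim.sample(output_name='output', sample_count=2, diffusion_steps=40)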
8 | """ 9 | import copy 10 | import math 11 | import os 12 | import logging 13 | import pathlib 14 | from typing import Tuple, Union, List 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torchvision.utils 19 | from PIL import Image 20 | from torch.cuda.amp import GradScaler 21 | from torch.utils.checkpoint import checkpoint 22 | from torch.utils.data import Dataset, DataLoader 23 | from torchvision.transforms import transforms 24 | from tqdm import tqdm 25 | from torch import optim 26 | from torch.functional import F 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S") 30 | 31 | 32 | class Diffusion: 33 | def __init__( 34 | self, 35 | device: str, 36 | img_size: int, 37 | noise_steps: int, 38 | min_signal_rate: int = 0.02, 39 | max_signal_rate: int = 0.95, 40 | ): 41 | self.max_signal_rate = max_signal_rate 42 | self.min_signal_rate = min_signal_rate 43 | self.device = device 44 | self.noise_steps = noise_steps 45 | self.img_size = img_size 46 | 47 | def diffusion_schedule( 48 | self, 49 | diffusion_times, 50 | ) -> Tuple[torch.Tensor, torch.Tensor]: 51 | max_signal_rate = torch.tensor(self.max_signal_rate, dtype=torch.float, device=self.device) 52 | min_signal_rate = torch.tensor(self.min_signal_rate, dtype=torch.float, device=self.device) 53 | 54 | start_angle = torch.acos(max_signal_rate) 55 | end_angle = torch.acos(min_signal_rate) 56 | 57 | diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle) 58 | 59 | signal_rates = torch.cos(diffusion_angles) 60 | noise_rates = torch.sin(diffusion_angles) 61 | 62 | return noise_rates, signal_rates 63 | 64 | @staticmethod 65 | def denoise( 66 | eps_model: nn.Module, 67 | noisy_images: torch.Tensor, 68 | noise_rates: torch.Tensor, 69 | signal_rates: torch.Tensor, 70 | training: bool = True 71 | ) -> Tuple[torch.Tensor, torch.Tensor]: 72 | """Predict noise component and calculate the image component using it. 
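        The image estimate inverts the forward mixing used during training,
        noisy_images = signal_rates * images + noise_rates * noises, giving
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates.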
73 | """ 74 | if training: 75 | pred_noises = eps_model(noisy_images, noise_rates.to(dtype=torch.long) ** 2) 76 | pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates 77 | return pred_noises, pred_images 78 | 79 | with torch.no_grad(): 80 | pred_noises = eps_model(noisy_images, noise_rates.to(dtype=torch.long) ** 2) 81 | pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates 82 | 83 | return pred_noises, pred_images 84 | 85 | def reverse_diffusion( 86 | self, 87 | num_images: int, 88 | diffusion_steps: int, 89 | eps_model: nn.Module, 90 | scale_factor: int = 2, 91 | sample_gif: bool = False, 92 | ) -> Union[torch.Tensor, List[Image.Image]]: 93 | eps_model.eval() 94 | 95 | frames_list = [] 96 | pred_images = None 97 | initial_noise = torch.randn((num_images, 3, self.img_size, self.img_size), device=self.device) 98 | 99 | step_size = 1.0 / diffusion_steps 100 | 101 | next_noisy_images = initial_noise 102 | for step in range(diffusion_steps): 103 | noisy_images = next_noisy_images 104 | 105 | diffusion_times = torch.ones((num_images, 1, 1, 1), device=self.device) - step * step_size 106 | noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) 107 | pred_noises, pred_images = self.denoise( 108 | eps_model, noisy_images, noise_rates, signal_rates, training=False 109 | ) 110 | 111 | if sample_gif: 112 | output = ((pred_images.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 113 | output = F.interpolate(input=output, scale_factor=scale_factor, mode='nearest-exact') 114 | grid = torchvision.utils.make_grid(output) 115 | img_arr = grid.permute(1, 2, 0).cpu().numpy() 116 | img = Image.fromarray(img_arr) 117 | frames_list.append(img) 118 | 119 | next_diffusion_times = diffusion_times - step_size 120 | next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times) 121 | next_noisy_images = (next_signal_rates * pred_images + next_noise_rates * pred_noises) 122 | 123 | eps_model.train() 124 | 125 | if sample_gif: 126 | return frames_list 127 | 128 | pred_images = ((pred_images.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 129 | pred_images = F.interpolate(input=pred_images, scale_factor=scale_factor, mode='nearest-exact') 130 | return pred_images 131 | 132 | 133 | class PositionalEncoding(nn.Module): 134 | def __init__( 135 | self, 136 | embedding_dim: int, 137 | dropout: float = 0.1, 138 | max_len: int = 1000, 139 | apply_dropout: bool = True, 140 | ): 141 | """Section 3.5 of attention is all you need paper. 142 | 143 | Extended slicing method is used to fill even and odd position of sin, cos with increment of 2. 144 | Ex, `[sin, cos, sin, cos, sin, cos]` for `embedding_dim = 6`. 145 | 146 | `max_len` is equivalent to number of noise steps or patches. `embedding_dim` must same as image 147 | embedding dimension of the model. 148 | 149 | Args: 150 | embedding_dim: `d_model` in given positional encoding formula. 151 | dropout: Dropout amount. 152 | max_len: Number of embeddings to generate. Here, equivalent to total noise steps. 
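
        Example: with `embedding_dim = 4`, timestep `t` is encoded as
        `[sin(t), cos(t), sin(t / 10000^(2/4)), cos(t / 10000^(2/4))]`.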
153 | """ 154 | super(PositionalEncoding, self).__init__() 155 | self.dropout = nn.Dropout(p=dropout) 156 | self.apply_dropout = apply_dropout 157 | 158 | pos_encoding = torch.zeros(max_len, embedding_dim) 159 | position = torch.arange(start=0, end=max_len).unsqueeze(1) 160 | div_term = torch.exp(-math.log(10000.0) * torch.arange(0, embedding_dim, 2).float() / embedding_dim) 161 | 162 | pos_encoding[:, 0::2] = torch.sin(position * div_term) 163 | pos_encoding[:, 1::2] = torch.cos(position * div_term) 164 | self.register_buffer(name='pos_encoding', tensor=pos_encoding, persistent=False) 165 | 166 | def forward(self, t: torch.LongTensor) -> torch.Tensor: 167 | """Get precalculated positional embedding at timestep t. Outputs same as video implementation 168 | code but embeddings are in [sin, cos, sin, cos] format instead of [sin, sin, cos, cos] in that code. 169 | Also batch dimension is added to final output. 170 | """ 171 | positional_encoding = self.pos_encoding[t].squeeze(1) 172 | if self.apply_dropout: 173 | return self.dropout(positional_encoding) 174 | return positional_encoding 175 | 176 | 177 | class DoubleConv(nn.Module): 178 | def __init__( 179 | self, 180 | in_channels: int, 181 | out_channels: int, 182 | mid_channels: int = None, 183 | residual: bool = False 184 | ): 185 | """Double convolutions as applied in the unet paper architecture. 186 | """ 187 | super(DoubleConv, self).__init__() 188 | self.residual = residual 189 | if not mid_channels: 190 | mid_channels = out_channels 191 | 192 | self.double_conv = nn.Sequential( 193 | nn.Conv2d( 194 | in_channels=in_channels, out_channels=mid_channels, kernel_size=(3, 3), padding=(1, 1), bias=False 195 | ), 196 | nn.GroupNorm(num_groups=1, num_channels=mid_channels), 197 | nn.GELU(), 198 | nn.Conv2d( 199 | in_channels=mid_channels, out_channels=out_channels, kernel_size=(3, 3), padding=(1, 1), bias=False, 200 | ), 201 | nn.GroupNorm(num_groups=1, num_channels=out_channels), 202 | ) 203 | 204 | def forward(self, x: torch.Tensor) -> torch.Tensor: 205 | if self.residual: 206 | return F.gelu(x + self.double_conv(x)) 207 | 208 | return self.double_conv(x) 209 | 210 | 211 | class Down(nn.Module): 212 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 213 | super(Down, self).__init__() 214 | self.maxpool_conv = nn.Sequential( 215 | nn.MaxPool2d(kernel_size=(2, 2)), 216 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 217 | DoubleConv(in_channels=in_channels, out_channels=out_channels), 218 | ) 219 | 220 | self.out_channels = out_channels 221 | 222 | self.emb_layer = nn.Sequential( 223 | nn.SiLU(), 224 | nn.Linear(in_features=emb_dim, out_features=out_channels), 225 | ) 226 | 227 | def forward(self, x: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 228 | x = self.maxpool_conv(x) 229 | emb = self.emb_layer(t_embedding) 230 | emb = emb.permute(0, 3, 1, 2).repeat(1, 1, x.shape[-2], x.shape[-1]) 231 | return x + emb 232 | 233 | 234 | class Up(nn.Module): 235 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 236 | super(Up, self).__init__() 237 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 238 | self.conv = nn.Sequential( 239 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 240 | DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2), 241 | ) 242 | 243 | self.emb_layer = nn.Sequential( 244 | nn.SiLU(), 245 | nn.Linear(in_features=emb_dim, 
out_features=out_channels), 246 | ) 247 | 248 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 249 | x = self.up(x) 250 | x = torch.cat([x_skip, x], dim=1) 251 | x = self.conv(x) 252 | emb = self.emb_layer(t_embedding) 253 | emb = emb.permute(0, 3, 1, 2).repeat(1, 1, x.shape[-2], x.shape[-1]) 254 | return x + emb 255 | 256 | 257 | class MLP(nn.Module): 258 | def __init__(self, dim: int, hidden_dim: int = None, dropout: float = 0.): 259 | super(MLP, self).__init__() 260 | hidden_dim = hidden_dim or dim 261 | self.net = nn.Sequential( 262 | nn.Linear(in_features=dim, out_features=hidden_dim), 263 | nn.GELU(), 264 | nn.Dropout(p=dropout), 265 | nn.Linear(in_features=hidden_dim, out_features=dim), 266 | nn.GELU(), 267 | ) 268 | 269 | def forward(self, x: torch.Tensor) -> torch.Tensor: 270 | return self.net(x) 271 | 272 | 273 | class TransformerEncoderSA(nn.Module): 274 | def __init__(self, num_channels: int, size: int, num_heads: int = 4, hidden_dim: int = 1024, dropout: int = 0.0): 275 | """A block of transformer encoder with mutli head self attention from vision transformers paper, 276 | https://arxiv.org/pdf/2010.11929.pdf. 277 | """ 278 | super(TransformerEncoderSA, self).__init__() 279 | self.num_channels = num_channels 280 | self.size = size 281 | self.mha = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True) 282 | self.ln_1 = nn.LayerNorm([num_channels]) 283 | self.ln_2 = nn.LayerNorm([num_channels]) 284 | self.mlp = MLP(dim=num_channels, hidden_dim=hidden_dim, dropout=dropout) 285 | 286 | def forward(self, x: torch.Tensor) -> torch.Tensor: 287 | x = x.view(-1, self.num_channels, self.size * self.size).permute(0, 2, 1) 288 | x_ln = self.ln_1(x) 289 | attention_value, _ = self.mha(query=x_ln, key=x_ln, value=x_ln) 290 | x = attention_value + x 291 | x = self.mlp(self.ln_2(x)) + x 292 | return x.permute(0, 2, 1).view(-1, self.num_channels, self.size, self.size) 293 | 294 | 295 | class UNet(nn.Module): 296 | def __init__( 297 | self, 298 | noise_steps: int, 299 | in_channels: int = 3, 300 | out_channels: int = 3, 301 | time_dim: int = 256, 302 | ): 303 | super(UNet, self).__init__() 304 | self.time_dim = time_dim 305 | self.pos_encoding = PositionalEncoding(embedding_dim=time_dim, max_len=noise_steps) 306 | 307 | self.input_conv = DoubleConv(in_channels, 64) 308 | self.down1 = Down(64, 128) 309 | self.sa1 = TransformerEncoderSA(128, 32) 310 | self.down2 = Down(128, 256) 311 | self.sa2 = TransformerEncoderSA(256, 16) 312 | self.down3 = Down(256, 256) 313 | self.sa3 = TransformerEncoderSA(256, 8) 314 | 315 | self.bottleneck1 = DoubleConv(256, 512) 316 | self.bottleneck2 = DoubleConv(512, 512) 317 | self.bottleneck3 = DoubleConv(512, 256) 318 | 319 | self.up1 = Up(512, 128) 320 | self.sa4 = TransformerEncoderSA(128, 16) 321 | self.up2 = Up(256, 64) 322 | self.sa5 = TransformerEncoderSA(64, 32) 323 | self.up3 = Up(128, 64) 324 | self.sa6 = TransformerEncoderSA(64, 64) 325 | self.out_conv = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=(1, 1)) 326 | 327 | def forward(self, x: torch.Tensor, t: torch.LongTensor) -> torch.Tensor: 328 | """Forward pass with image tensor and timestep reduce noise. 329 | 330 | Args: 331 | x: Image tensor of shape, [batch_size, channels, height, width]. 332 | t: Time step defined as long integer. 
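
        Returns:
            Predicted noise tensor with the same shape as the input image tensor,
            [batch_size, channels, height, width].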
333 | """ 334 | t = self.pos_encoding(t) 335 | 336 | x1 = self.input_conv(x) 337 | x2 = self.down1(x1, t) 338 | x2 = self.sa1(x2) 339 | x3 = self.down2(x2, t) 340 | x3 = self.sa2(x3) 341 | x4 = self.down3(x3, t) 342 | x4 = self.sa3(x4) 343 | 344 | x4 = self.bottleneck1(x4) 345 | x4 = self.bottleneck2(x4) 346 | x4 = self.bottleneck3(x4) 347 | 348 | x = self.up1(x4, x3, t) 349 | x = self.sa4(x) 350 | x = self.up2(x, x2, t) 351 | x = self.sa5(x) 352 | x = self.up3(x, x1, t) 353 | 354 | # x = checkpoint(self.sa6, x) 355 | x = self.sa6(x) 356 | 357 | return self.out_conv(x) 358 | 359 | 360 | class EMA: 361 | def __init__(self, beta): 362 | """Modifies exponential moving average model. 363 | """ 364 | self.beta = beta 365 | self.step = 0 366 | 367 | def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None: 368 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): 369 | old_weights, new_weights = ema_params.data, current_params.data 370 | ema_params.data = self.update_average(old_weights=old_weights, new_weights=new_weights) 371 | 372 | def update_average(self, old_weights: torch.Tensor, new_weights: torch.Tensor) -> torch.Tensor: 373 | if old_weights is None: 374 | return new_weights 375 | return old_weights * self.beta + (1 - self.beta) * new_weights 376 | 377 | def ema_step(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None: 378 | if self.step < step_start_ema: 379 | self.reset_parameters(ema_model=ema_model, model=model) 380 | self.step += 1 381 | return 382 | self.update_model_average(ema_model=ema_model, current_model=model) 383 | self.step += 1 384 | 385 | @staticmethod 386 | def reset_parameters(ema_model: nn.Module, model: nn.Module) -> None: 387 | ema_model.load_state_dict(model.state_dict()) 388 | 389 | 390 | class CustomImageClassDataset(Dataset): 391 | def __init__( 392 | self, 393 | root_dir: str, 394 | image_size: int, 395 | image_channels: int 396 | ): 397 | super(CustomImageClassDataset, self).__init__() 398 | self.root_dir = root_dir 399 | self.class_list = os.listdir(root_dir) 400 | 401 | self.transform = transforms.Compose([ 402 | transforms.Resize((image_size, image_size)), 403 | transforms.ToTensor(), 404 | transforms.Normalize( 405 | mean=[0.5 for _ in range(image_channels)], 406 | std=[0.5 for _ in range(image_channels)], 407 | ) 408 | ]) 409 | 410 | self.image_labels_files_list = list() 411 | for idx, class_name_folder in enumerate(self.class_list): 412 | class_path = os.path.join(root_dir, class_name_folder) 413 | files_list = os.listdir(class_path) 414 | for image_file in files_list: 415 | self.image_labels_files_list.append( 416 | ( 417 | os.path.join(class_path, image_file), 418 | idx, 419 | ) 420 | ) 421 | 422 | self.image_files_list_len = len(self.image_labels_files_list) 423 | 424 | def __len__(self) -> int: 425 | return self.image_files_list_len 426 | 427 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 428 | image_path, class_label = self.image_labels_files_list[idx] 429 | image = Image.open(image_path) 430 | image = image.convert('RGB') 431 | return self.transform(image), class_label 432 | 433 | 434 | class Utils: 435 | def __init__(self): 436 | super(Utils, self).__init__() 437 | 438 | @staticmethod 439 | def collate_fn(batch): 440 | """Discard none samples. 
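        Lets the DataLoader skip samples for which `__getitem__` returned `None`,
        for example unreadable image files.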
441 | """ 442 | batch = list(filter(lambda x: x is not None, batch)) 443 | return torch.utils.data.dataloader.default_collate(batch) 444 | 445 | @staticmethod 446 | def save_images(images: torch.Tensor, save_path: str, nrow: int = 8) -> None: 447 | grid = torchvision.utils.make_grid(images, nrow=nrow) 448 | img_arr = grid.permute(1, 2, 0).cpu().numpy() 449 | img = Image.fromarray(img_arr) 450 | img.save(save_path) 451 | 452 | @staticmethod 453 | def save_checkpoint( 454 | epoch: int, 455 | model: nn.Module, 456 | filename: str, 457 | optimizer: optim.Optimizer = None, 458 | scheduler: optim.lr_scheduler = None, 459 | grad_scaler: GradScaler = None, 460 | ) -> None: 461 | checkpoint_dict = { 462 | 'epoch': epoch, 463 | 'state_dict': model.state_dict(), 464 | } 465 | if optimizer: 466 | checkpoint_dict['optimizer'] = optimizer.state_dict() 467 | if scheduler: 468 | checkpoint_dict['scheduler'] = scheduler.state_dict() 469 | if scheduler: 470 | checkpoint_dict['grad_scaler'] = grad_scaler.state_dict() 471 | 472 | torch.save(checkpoint_dict, filename) 473 | logging.info("=> Saving checkpoint complete.") 474 | 475 | @staticmethod 476 | def load_checkpoint( 477 | model: nn.Module, 478 | filename: str, 479 | enable_train_mode: bool, 480 | optimizer: optim.Optimizer = None, 481 | scheduler: optim.lr_scheduler = None, 482 | grad_scaler: GradScaler = None, 483 | ) -> int: 484 | logging.info("=> Loading checkpoint") 485 | saved_model = torch.load(filename, map_location="cuda") 486 | model.load_state_dict(saved_model['state_dict'], strict=False) 487 | if 'optimizer' in saved_model and enable_train_mode: 488 | optimizer.load_state_dict(saved_model['optimizer']) 489 | if 'scheduler' in saved_model and enable_train_mode: 490 | scheduler.load_state_dict(saved_model['scheduler']) 491 | if 'grad_scaler' in saved_model and enable_train_mode: 492 | grad_scaler.load_state_dict(saved_model['grad_scaler']) 493 | return saved_model['epoch'] 494 | 495 | 496 | class DDIM: 497 | def __init__( 498 | self, 499 | dataset_path: str, 500 | save_path: str = None, 501 | checkpoint_path: str = None, 502 | checkpoint_path_ema: str = None, 503 | run_name: str = 'ddpm', 504 | image_size: int = 64, 505 | image_channels: int = 3, 506 | accumulation_batch_size: int = 2, 507 | accumulation_iters: int = 16, 508 | sample_count: int = 1, 509 | num_workers: int = 0, 510 | device: str = 'cuda', 511 | num_epochs: int = 10000, 512 | fp16: bool = False, 513 | save_every: int = 500, 514 | learning_rate: float = 2e-4, 515 | noise_steps: int = 500, 516 | enable_train_mode: bool = True, 517 | ): 518 | self.num_epochs = num_epochs 519 | self.device = device 520 | self.fp16 = fp16 521 | self.save_every = save_every 522 | self.accumulation_iters = accumulation_iters 523 | self.sample_count = sample_count 524 | self.accumulation_batch_size = accumulation_batch_size 525 | self.enable_train_mode = enable_train_mode 526 | 527 | base_path = save_path if save_path is not None else os.getcwd() 528 | self.save_path = os.path.join(base_path, run_name) 529 | pathlib.Path(self.save_path).mkdir(parents=True, exist_ok=True) 530 | self.logger = SummaryWriter(log_dir=os.path.join(self.save_path, 'logs')) 531 | 532 | diffusion_dataset = CustomImageClassDataset( 533 | root_dir=dataset_path, 534 | image_size=image_size, 535 | image_channels=image_channels 536 | ) 537 | self.train_loader = DataLoader( 538 | diffusion_dataset, 539 | batch_size=accumulation_batch_size, 540 | shuffle=True, 541 | pin_memory=True, 542 | num_workers=num_workers, 543 | drop_last=False, 
544 | collate_fn=Utils.collate_fn, 545 | ) 546 | 547 | self.unet_model = UNet(noise_steps=noise_steps).to(device) 548 | self.diffusion = Diffusion(img_size=image_size, device=self.device, noise_steps=noise_steps) 549 | self.optimizer = optim.Adam( 550 | params=self.unet_model.parameters(), lr=learning_rate, # betas=(0.9, 0.999) 551 | ) 552 | self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=300) 553 | self.grad_scaler = GradScaler() 554 | 555 | self.ema = EMA(beta=0.95) 556 | self.ema_model = copy.deepcopy(self.unet_model).eval().requires_grad_(False) 557 | 558 | # ema_avg = lambda avg_model_param, model_param, num_averaged: 0.1 * avg_model_param + 0.9 * model_param 559 | # self.swa_model = optim.swa_utils.AveragedModel(model=self.unet_model, avg_fn=ema_avg).to(self.device) 560 | # self.swa_start = 10 561 | # self.swa_scheduler = optim.swa_utils.SWALR( 562 | # optimizer=self.optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy='cos' 563 | # ) 564 | 565 | self.start_epoch = 0 566 | if checkpoint_path: 567 | logging.info(f'Loading model weights...') 568 | self.start_epoch = Utils.load_checkpoint( 569 | model=self.unet_model, 570 | optimizer=self.optimizer, 571 | scheduler=self.scheduler, 572 | grad_scaler=self.grad_scaler, 573 | filename=checkpoint_path, 574 | enable_train_mode=enable_train_mode, 575 | ) 576 | if checkpoint_path_ema: 577 | logging.info(f'Loading EMA model weights...') 578 | _ = Utils.load_checkpoint( 579 | model=self.ema_model, 580 | filename=checkpoint_path_ema, 581 | enable_train_mode=enable_train_mode, 582 | ) 583 | 584 | def sample( 585 | self, 586 | epoch: int = None, 587 | batch_idx: int = None, 588 | sample_count: int = 1, 589 | output_name: str = None, 590 | diffusion_steps: int = 40, 591 | ) -> None: 592 | """Generates images with reverse process based on sampling method with both training model and ema model. 593 | """ 594 | sampled_images = self.diffusion.reverse_diffusion( 595 | eps_model=self.unet_model, num_images=sample_count, diffusion_steps=diffusion_steps, 596 | ) 597 | ema_sampled_images = self.diffusion.reverse_diffusion( 598 | eps_model=self.ema_model, num_images=sample_count, diffusion_steps=diffusion_steps, 599 | ) 600 | 601 | model_name = f'model_{epoch}_{batch_idx}.jpg' 602 | ema_model_name = f'model_ema_{epoch}_{batch_idx}.jpg' 603 | 604 | if output_name: 605 | model_name = f'{output_name}.jpg' 606 | ema_model_name = f'{output_name}_ema.jpg' 607 | 608 | Utils.save_images( 609 | images=sampled_images, 610 | save_path=os.path.join(self.save_path, model_name) 611 | ) 612 | Utils.save_images( 613 | images=ema_sampled_images, 614 | save_path=os.path.join(self.save_path, ema_model_name) 615 | ) 616 | 617 | def sample_gif( 618 | self, 619 | output_name: str, 620 | save_path: str = '', 621 | sample_count: int = 1, 622 | diffusion_steps: int = 40, 623 | optimize: bool = False, 624 | ) -> None: 625 | """Generates images with reverse process based on sampling method with both training model and ema model. 
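
        Example call (illustrative; the save path is a placeholder):

            ddim.sample_gif(output_name='out', sample_count=2, save_path='/path/to/output', diffusion_steps=40)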
626 | """ 627 | sampled_images = self.diffusion.reverse_diffusion( 628 | eps_model=self.unet_model, num_images=sample_count, sample_gif=True, diffusion_steps=diffusion_steps, 629 | ) 630 | ema_sampled_images = self.diffusion.reverse_diffusion( 631 | eps_model=self.ema_model, num_images=sample_count, sample_gif=True, diffusion_steps=diffusion_steps, 632 | ) 633 | 634 | model_name = f'{output_name}.gif' 635 | sampled_images[0].save( 636 | os.path.join(save_path, model_name), 637 | save_all=True, 638 | append_images=sampled_images[1:], 639 | optimize=optimize, 640 | duration=80, 641 | loop=0 642 | ) 643 | 644 | ema_model_name = f'{output_name}_ema.gif' 645 | ema_sampled_images[0].save( 646 | os.path.join(save_path, ema_model_name), 647 | save_all=True, 648 | append_images=ema_sampled_images[1:], 649 | optimize=optimize, 650 | duration=80, 651 | loop=0 652 | ) 653 | 654 | def train(self) -> None: 655 | assert self.enable_train_mode, 'Cannot train when enable_train_mode flag disabled.' 656 | 657 | logging.info(f'Training started....') 658 | for epoch in range(self.start_epoch, self.num_epochs): 659 | accumulated_minibatch_loss = 0.0 660 | accumulated_image_loss = 0.0 661 | # accumulated_image_ema_loss = 0.0 662 | 663 | with tqdm(self.train_loader) as pbar: 664 | for batch_idx, (real_images, _) in enumerate(pbar): 665 | real_images = real_images.to(self.device) 666 | current_batch_size = real_images.shape[0] 667 | 668 | noises = torch.randn(size=(current_batch_size, 3, 64, 64), device=self.device) 669 | 670 | # sample uniform random diffusion times 671 | diffusion_times = torch.rand(size=(current_batch_size, 1, 1, 1), device=self.device) 672 | 673 | noise_rates, signal_rates = self.diffusion.diffusion_schedule(diffusion_times) 674 | # mix the images with noises accordingly 675 | noisy_images = signal_rates * real_images + noise_rates * noises 676 | 677 | with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=self.fp16): 678 | 679 | pred_noises, pred_images = self.diffusion.denoise( 680 | self.unet_model, noisy_images, noise_rates, signal_rates, training=True 681 | ) 682 | 683 | # pred_noises_ema, pred_images_ema = self.diffusion.denoise( 684 | # self.ema_model, noisy_images, noise_rates, signal_rates, training=True 685 | # ) 686 | 687 | loss = F.smooth_l1_loss(pred_noises, noises) 688 | loss /= self.accumulation_iters 689 | 690 | accumulated_minibatch_loss += float(loss) 691 | accumulated_image_loss += (F.smooth_l1_loss(pred_images, real_images) / self.accumulation_iters) 692 | # accumulated_image_ema_loss += (F.smooth_l1_loss(pred_images_ema, real_images) / self.accumulation_iters) 693 | 694 | self.grad_scaler.scale(loss).backward() 695 | 696 | # if ((batch_idx + 1) % self.accumulation_iters == 0) or ((batch_idx + 1) == len(self.train_loader)): 697 | if (batch_idx + 1) % self.accumulation_iters == 0: 698 | self.grad_scaler.step(self.optimizer) 699 | self.grad_scaler.update() 700 | self.optimizer.zero_grad(set_to_none=True) 701 | self.ema.ema_step(ema_model=self.ema_model, model=self.unet_model) 702 | 703 | # if epoch > self.swa_start: 704 | # self.swa_model.update_parameters(model=self.unet_model) 705 | # self.swa_scheduler.step() 706 | # else: 707 | # self.scheduler.step() 708 | 709 | pbar.set_description( 710 | f'Loss => ' 711 | f'Noise: {float(accumulated_minibatch_loss):.4f}, ' 712 | f'Image: {accumulated_image_loss:.4f} ' 713 | # f'Image EMA: {accumulated_image_ema_loss:.4f} ' 714 | ) 715 | accumulated_minibatch_loss = 0.0 716 | accumulated_image_loss = 0.0 717 | # 
accumulated_image_ema_loss = 0.0 718 | 719 | if not batch_idx % self.save_every: 720 | real_images_out = ((real_images.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 721 | noisy_images_out = ((noisy_images.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 722 | pred_images_out = ((pred_images.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 723 | images_out = torch.cat([real_images_out, noisy_images_out, pred_images_out], dim=0) 724 | images_out = F.interpolate(input=images_out, scale_factor=2, mode='nearest-exact') 725 | 726 | Utils.save_images( 727 | images=images_out, 728 | save_path=os.path.join(self.save_path, 'real_noised_denoised.jpg'), 729 | nrow=self.accumulation_batch_size, 730 | ) 731 | 732 | self.sample(epoch=epoch, batch_idx=batch_idx, sample_count=self.sample_count) 733 | 734 | Utils.save_checkpoint( 735 | epoch=epoch, 736 | model=self.unet_model, 737 | optimizer=self.optimizer, 738 | scheduler=self.scheduler, 739 | grad_scaler=self.grad_scaler, 740 | filename=os.path.join(self.save_path, f'model_{epoch}_{batch_idx}.pt') 741 | ) 742 | Utils.save_checkpoint( 743 | epoch=epoch, 744 | model=self.ema_model, 745 | filename=os.path.join(self.save_path, f'model_ema_{epoch}_{batch_idx}.pt') 746 | ) 747 | 748 | self.scheduler.step() 749 | 750 | 751 | if __name__ == '__main__': 752 | ddim = DDIM( 753 | dataset_path=r'C:\computer_vision\celeba', 754 | save_path=r'C:\computer_vision\ddim', 755 | checkpoint_path=r'C:\computer_vision\ddim\ddim_celeba_66_0.pt', 756 | checkpoint_path_ema=r'C:\computer_vision\ddim\ddim_celeba_ema_66_0.pt', 757 | # enable_train_mode=False, 758 | ) 759 | ddim.train() 760 | 761 | # ddim.sample(output_name='output9', sample_count=2, diffusion_steps=40) 762 | 763 | # ddim.sample_gif( 764 | # output_name='output8', 765 | # sample_count=1, 766 | # save_path=r'C:\computer_vision\ddim', 767 | # diffusion_steps=40, 768 | # ) 769 | -------------------------------------------------------------------------------- /ddpm.py: -------------------------------------------------------------------------------- 1 | """Implementation of DDPM. 2 | 3 | Best to use corrupted, low res image mover script first then use batch image resizer to resize image 4 | to expected format before using this. 5 | 6 | References 7 | - DDPM paper, https://arxiv.org/pdf/2006.11239.pdf. 8 | - DDIM paper, https://arxiv.org/pdf/2010.02502.pdf. 9 | - Annotated Diffusion, https://huggingface.co/blog/annotated-diffusion. 10 | - Keras DDIM, https://keras.io/examples/generative/ddim/. 11 | - Implementation, https://www.youtube.com/watch?v=TBCRlnwJtZU. 12 | - Implementation, https://github.com/dome272/Diffusion-Models-pytorch. 13 | - Postional embedding, http://nlp.seas.harvard.edu/annotated-transformer/. 14 | - Attention paper, https://arxiv.org/pdf/1706.03762.pdf. 15 | - Transformers, https://pytorch.org/tutorials/beginner/transformer_tutorial.html. 16 | - Transformer encoder architecture, https://arxiv.org/pdf/2010.11929.pdf. 17 | - UNet architecture, https://arxiv.org/pdf/1505.04597.pdf. 
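
Example usage (illustrative sketch mirroring the README; the dataset path is a
placeholder and checkpoint arguments are optional):

    trainer = Trainer(dataset_path='/path/to/dataset')
    trainer.sample(output_name='output', sample_count=4)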
18 | """ 19 | import copy 20 | import math 21 | import os 22 | import logging 23 | import pathlib 24 | from typing import Tuple 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torchvision.utils 29 | from PIL import Image 30 | from torch.cuda.amp import GradScaler 31 | from torch.utils.data import Dataset, DataLoader 32 | from torchvision.transforms import transforms 33 | from tqdm import tqdm 34 | from torch import optim 35 | from torch.functional import F 36 | from torch.utils.tensorboard import SummaryWriter 37 | 38 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S") 39 | 40 | 41 | class Diffusion: 42 | def __init__( 43 | self, 44 | device: str, 45 | img_size: int, 46 | noise_steps: int = 1000, 47 | beta_start: float = 1e-4, 48 | beta_end: float = 0.02, 49 | ): 50 | self.device = device 51 | self.noise_steps = noise_steps 52 | self.beta_start = beta_start 53 | self.beta_end = beta_end 54 | self.img_size = img_size 55 | 56 | # Section 2, equation 4 and near explation for alpha, alpha hat, beta. 57 | self.beta = self.linear_noise_schedule() 58 | # self.beta = self.cosine_beta_schedule() 59 | self.alpha = 1 - self.beta 60 | self.alpha_hat = torch.cumprod(self.alpha, dim=0) 61 | 62 | # Section 3.2, algorithm 1 formula implementation. Generate values early reuse later. 63 | self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat) 64 | self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat) 65 | 66 | # Section 3.2, equation 2 precalculation values. 67 | self.sqrt_alpha = torch.sqrt(self.alpha) 68 | self.std_beta = torch.sqrt(self.beta) 69 | 70 | # Clean up unnecessary values. 71 | del self.alpha 72 | del self.alpha_hat 73 | 74 | def linear_noise_schedule(self) -> torch.Tensor: 75 | """Same amount of noise is applied each step. Weakness is near end steps image is so noisy it is hard make 76 | out information. So noise removal is also very small amount, so it takes more steps to generate clear image. 77 | """ 78 | return torch.linspace(start=self.beta_start, end=self.beta_end, steps=self.noise_steps, device=self.device) 79 | 80 | def cosine_beta_schedule(self, s=0.008): 81 | """Cosine schedule from annotated transformers. 82 | """ 83 | steps = self.noise_steps + 1 84 | x = torch.linspace(0, self.noise_steps, steps, device=self.device) 85 | alphas_cumprod = torch.cos(((x / self.noise_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2 86 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 87 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 88 | return torch.clip(betas, 0.0001, 0.9999) 89 | 90 | def q_sample(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 91 | """Section 3.2, algorithm 1 formula implementation. Forward process, defined by `q`. 92 | 93 | Found in section 2. `q` gradually adds gaussian noise according to variance schedule. Also, 94 | can be seen on figure 2. 95 | """ 96 | sqrt_alpha_hat = self.sqrt_alpha_hat[t].view(-1, 1, 1, 1) 97 | sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat[t].view(-1, 1, 1, 1) 98 | epsilon = torch.randn_like(x, device=self.device) 99 | return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon 100 | 101 | def sample_timesteps(self, batch_size: int) -> torch.Tensor: 102 | """Random timestep for each sample in a batch. Timesteps selected from [1, noise_steps]. 
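        For example, with `noise_steps = 500` and `batch_size = 4`, a possible draw
        is `tensor([10, 26, 460, 231])`.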
103 | """ 104 | return torch.randint(low=1, high=self.noise_steps, size=(batch_size, ), device=self.device) 105 | 106 | def p_sample(self, eps_model: nn.Module, n: int, scale_factor: int = 2) -> torch.Tensor: 107 | """Implementation of algorithm 2 sampling. Reverse process, defined by `p` in section 2. Short 108 | formula is defined in equation 11 of section 3.2. 109 | 110 | From noise generates image step by step. From noise_steps, (noise_steps - 1), ...., 2, 1. 111 | Here, alpha = 1 - beta. So, beta = 1 - alpha. 112 | 113 | Sample noise from normal distribution of timestep t > 1, else noise is 0. Before returning values 114 | are clamped to [-1, 1] and converted to pixel values [0, 255]. 115 | 116 | Args: 117 | scale_factor: Scales the output image by the factor. 118 | eps_model: Noise prediction model. `eps_theta(x_t, t)` in paper. Theta is the model parameters. 119 | n: Number of samples to process. 120 | 121 | Returns: 122 | Generated denoised image. 123 | """ 124 | logging.info(f'Sampling {n} images....') 125 | 126 | eps_model.eval() 127 | with torch.no_grad(): 128 | x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device) 129 | 130 | for i in tqdm(reversed(range(1, self.noise_steps)), position=0): 131 | t = torch.ones(n, dtype=torch.long, device=self.device) * i 132 | 133 | sqrt_alpha_t = self.sqrt_alpha[t].view(-1, 1, 1, 1) 134 | beta_t = self.beta[t].view(-1, 1, 1, 1) 135 | sqrt_one_minus_alpha_hat_t = self.sqrt_one_minus_alpha_hat[t].view(-1, 1, 1, 1) 136 | epsilon_t = self.std_beta[t].view(-1, 1, 1, 1) 137 | 138 | random_noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x) 139 | 140 | x = ((1 / sqrt_alpha_t) * (x - ((beta_t / sqrt_one_minus_alpha_hat_t) * eps_model(x, t)))) +\ 141 | (epsilon_t * random_noise) 142 | 143 | eps_model.train() 144 | 145 | x = ((x.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 146 | x = F.interpolate(input=x, scale_factor=scale_factor, mode='nearest-exact') 147 | return x 148 | 149 | def generate_gif( 150 | self, 151 | eps_model: nn.Module, 152 | n: int = 1, 153 | save_path: str = '', 154 | output_name: str = None, 155 | skip_steps: int = 20, 156 | scale_factor: int = 2, 157 | ) -> None: 158 | logging.info(f'Generating gif....') 159 | frames_list = [] 160 | 161 | eps_model.eval() 162 | with torch.no_grad(): 163 | x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device) 164 | 165 | for i in tqdm(reversed(range(1, self.noise_steps)), position=0): 166 | t = torch.ones(n, dtype=torch.long, device=self.device) * i 167 | 168 | sqrt_alpha_t = self.sqrt_alpha[t].view(-1, 1, 1, 1) 169 | beta_t = self.beta[t].view(-1, 1, 1, 1) 170 | sqrt_one_minus_alpha_hat_t = self.sqrt_one_minus_alpha_hat[t].view(-1, 1, 1, 1) 171 | epsilon_t = self.std_beta[t].view(-1, 1, 1, 1) 172 | 173 | random_noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x) 174 | 175 | x = ((1 / sqrt_alpha_t) * (x - ((beta_t / sqrt_one_minus_alpha_hat_t) * eps_model(x, t)))) +\ 176 | (epsilon_t * random_noise) 177 | 178 | if i % skip_steps == 0: 179 | x_img = F.interpolate(input=x, scale_factor=scale_factor, mode='nearest-exact') 180 | x_img = ((x_img.clamp(-1, 1) + 1) * 127.5).type(torch.uint8) 181 | grid = torchvision.utils.make_grid(x_img) 182 | img_arr = grid.permute(1, 2, 0).cpu().numpy() 183 | img = Image.fromarray(img_arr) 184 | frames_list.append(img) 185 | 186 | eps_model.train() 187 | 188 | output_name = output_name if output_name else 'output' 189 | frames_list[0].save( 190 | os.path.join(save_path, f'{output_name}.gif'), 191 | 
save_all=True, 192 | append_images=frames_list[1:], 193 | optimize=False, 194 | duration=80, 195 | loop=0 196 | ) 197 | 198 | 199 | class PositionalEncoding(nn.Module): 200 | def __init__( 201 | self, 202 | embedding_dim: int, 203 | dropout: float = 0.1, 204 | max_len: int = 1000, 205 | apply_dropout: bool = True, 206 | ): 207 | """Section 3.5 of attention is all you need paper. 208 | 209 | Extended slicing method is used to fill even and odd position of sin, cos with increment of 2. 210 | Ex, `[sin, cos, sin, cos, sin, cos]` for `embedding_dim = 6`. 211 | 212 | `max_len` is equivalent to number of noise steps or patches. `embedding_dim` must same as image 213 | embedding dimension of the model. 214 | 215 | Args: 216 | embedding_dim: `d_model` in given positional encoding formula. 217 | dropout: Dropout amount. 218 | max_len: Number of embeddings to generate. Here, equivalent to total noise steps. 219 | """ 220 | super(PositionalEncoding, self).__init__() 221 | self.dropout = nn.Dropout(p=dropout) 222 | self.apply_dropout = apply_dropout 223 | 224 | pos_encoding = torch.zeros(max_len, embedding_dim) 225 | position = torch.arange(start=0, end=max_len).unsqueeze(1) 226 | div_term = torch.exp(-math.log(10000.0) * torch.arange(0, embedding_dim, 2).float() / embedding_dim) 227 | 228 | pos_encoding[:, 0::2] = torch.sin(position * div_term) 229 | pos_encoding[:, 1::2] = torch.cos(position * div_term) 230 | self.register_buffer(name='pos_encoding', tensor=pos_encoding, persistent=False) 231 | 232 | def forward(self, t: torch.LongTensor) -> torch.Tensor: 233 | """Get precalculated positional embedding at timestep t. Outputs same as video implementation 234 | code but embeddings are in [sin, cos, sin, cos] format instead of [sin, sin, cos, cos] in that code. 235 | Also batch dimension is added to final output. 236 | """ 237 | positional_encoding = self.pos_encoding[t].squeeze(1) 238 | if self.apply_dropout: 239 | return self.dropout(positional_encoding) 240 | return positional_encoding 241 | 242 | 243 | class DoubleConv(nn.Module): 244 | def __init__( 245 | self, 246 | in_channels: int, 247 | out_channels: int, 248 | mid_channels: int = None, 249 | residual: bool = False 250 | ): 251 | """Double convolutions as applied in the unet paper architecture. 
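
        When `residual=True`, the input is added to the block output and passed
        through GELU, which requires `in_channels == out_channels`.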
252 | """ 253 | super(DoubleConv, self).__init__() 254 | self.residual = residual 255 | if not mid_channels: 256 | mid_channels = out_channels 257 | 258 | self.double_conv = nn.Sequential( 259 | nn.Conv2d( 260 | in_channels=in_channels, out_channels=mid_channels, kernel_size=(3, 3), padding=(1, 1), bias=False 261 | ), 262 | nn.GroupNorm(num_groups=1, num_channels=mid_channels), 263 | nn.GELU(), 264 | nn.Conv2d( 265 | in_channels=mid_channels, out_channels=out_channels, kernel_size=(3, 3), padding=(1, 1), bias=False, 266 | ), 267 | nn.GroupNorm(num_groups=1, num_channels=out_channels), 268 | ) 269 | 270 | def forward(self, x: torch.Tensor) -> torch.Tensor: 271 | if self.residual: 272 | return F.gelu(x + self.double_conv(x)) 273 | 274 | return self.double_conv(x) 275 | 276 | 277 | class Down(nn.Module): 278 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 279 | super(Down, self).__init__() 280 | self.maxpool_conv = nn.Sequential( 281 | nn.MaxPool2d(kernel_size=(2, 2)), 282 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 283 | DoubleConv(in_channels=in_channels, out_channels=out_channels), 284 | ) 285 | 286 | self.emb_layer = nn.Sequential( 287 | nn.SiLU(), 288 | nn.Linear(in_features=emb_dim, out_features=out_channels), 289 | ) 290 | 291 | def forward(self, x: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 292 | """Downsamples input tensor, calculates embedding and adds embedding channel wise. 293 | 294 | If, `x.shape == [4, 64, 64, 64]` and `out_channels = 128`, then max_conv outputs [4, 128, 32, 32] by 295 | downsampling in h, w and outputting specified amount of feature maps/channels. 296 | 297 | `t_embedding` is embedding of timestep of shape [batch, time_dim]. It is passed through embedding layer 298 | to output channel dimentsion equivalent to channel dimension of x tensor, so they can be summbed elementwise. 299 | 300 | Since emb_layer output needs to be summed its output is also `emb.shape == [4, 128]`. It needs to be converted 301 | to 4D tensor, [4, 128, 1, 1]. Then the channel dimension is duplicated in all of `H x W` dimension to get 302 | shape of [4, 128, 32, 32]. 128D vector is sample for each pixel position is image. Now the emb_layer output 303 | is summed with max_conv output. 
304 | """ 305 | x = self.maxpool_conv(x) 306 | emb = self.emb_layer(t_embedding) 307 | emb = emb.view(emb.shape[0], emb.shape[1], 1, 1).repeat(1, 1, x.shape[-2], x.shape[-1]) 308 | return x + emb 309 | 310 | 311 | class Up(nn.Module): 312 | def __init__(self, in_channels: int, out_channels: int, emb_dim: int = 256): 313 | super(Up, self).__init__() 314 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 315 | self.conv = nn.Sequential( 316 | DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True), 317 | DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2), 318 | ) 319 | 320 | self.emb_layer = nn.Sequential( 321 | nn.SiLU(), 322 | nn.Linear(in_features=emb_dim, out_features=out_channels), 323 | ) 324 | 325 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor, t_embedding: torch.Tensor) -> torch.Tensor: 326 | x = self.up(x) 327 | x = torch.cat([x_skip, x], dim=1) 328 | x = self.conv(x) 329 | emb = self.emb_layer(t_embedding) 330 | emb = emb.view(emb.shape[0], emb.shape[1], 1, 1).repeat(1, 1, x.shape[-2], x.shape[-1]) 331 | return x + emb 332 | 333 | 334 | class TransformerEncoderSA(nn.Module): 335 | def __init__(self, num_channels: int, size: int, num_heads: int = 4): 336 | """A block of transformer encoder with mutli head self attention from vision transformers paper, 337 | https://arxiv.org/pdf/2010.11929.pdf. 338 | """ 339 | super(TransformerEncoderSA, self).__init__() 340 | self.num_channels = num_channels 341 | self.size = size 342 | self.mha = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True) 343 | self.ln = nn.LayerNorm([num_channels]) 344 | self.ff_self = nn.Sequential( 345 | nn.LayerNorm([num_channels]), 346 | nn.Linear(in_features=num_channels, out_features=num_channels), 347 | nn.LayerNorm([num_channels]), 348 | nn.Linear(in_features=num_channels, out_features=num_channels) 349 | ) 350 | 351 | def forward(self, x: torch.Tensor) -> torch.Tensor: 352 | """Self attention. 353 | 354 | Input feature map [4, 128, 32, 32], flattened to [4, 128, 32 x 32]. Which is reshaped to per pixel 355 | feature map order, [4, 1024, 128]. 356 | 357 | Attention output is same shape as input feature map to multihead attention module which are added element wise. 358 | Before returning attention output is converted back input feature map x shape. Opposite of feature map to 359 | mha input is done which gives output [4, 128, 32, 32]. 
360 | """ 361 | x = x.view(-1, self.num_channels, self.size * self.size).permute(0, 2, 1) 362 | x_ln = self.ln(x) 363 | attention_value, _ = self.mha(query=x_ln, key=x_ln, value=x_ln) 364 | attention_value = attention_value + x 365 | attention_value = self.ff_self(attention_value) + attention_value 366 | return attention_value.permute(0, 2, 1).view(-1, self.num_channels, self.size, self.size) 367 | 368 | 369 | class UNet(nn.Module): 370 | def __init__( 371 | self, 372 | in_channels: int = 3, 373 | out_channels: int = 3, 374 | noise_steps: int = 1000, 375 | time_dim: int = 256, 376 | features: list = None, 377 | ): 378 | super(UNet, self).__init__() 379 | if features is None: 380 | features = [64, 128, 256, 512] 381 | self.time_dim = time_dim 382 | self.pos_encoding = PositionalEncoding(embedding_dim=time_dim, max_len=noise_steps) 383 | 384 | self.input_conv = DoubleConv(in_channels, 64) 385 | self.down1 = Down(64, 128) 386 | self.sa1 = TransformerEncoderSA(128, 32) 387 | self.down2 = Down(128, 256) 388 | self.sa2 = TransformerEncoderSA(256, 16) 389 | self.down3 = Down(256, 256) 390 | self.sa3 = TransformerEncoderSA(256, 8) 391 | 392 | self.bottleneck1 = DoubleConv(256, 512) 393 | self.bottleneck2 = DoubleConv(512, 512) 394 | self.bottleneck3 = DoubleConv(512, 256) 395 | 396 | self.up1 = Up(512, 128) 397 | self.sa4 = TransformerEncoderSA(128, 16) 398 | self.up2 = Up(256, 64) 399 | self.sa5 = TransformerEncoderSA(64, 32) 400 | self.up3 = Up(128, 64) 401 | self.sa6 = TransformerEncoderSA(64, 64) 402 | self.out_conv = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=(1, 1)) 403 | 404 | def forward(self, x: torch.Tensor, t: torch.LongTensor) -> torch.Tensor: 405 | """Forward pass with image tensor and timestep reduce noise. 406 | 407 | Args: 408 | x: Image tensor of shape, [batch_size, channels, height, width]. 409 | t: Time step defined as long integer. If batch size is 4, noise step 500, then random timesteps t = [10, 26, 460, 231]. 410 | """ 411 | t = self.pos_encoding(t) 412 | 413 | x1 = self.input_conv(x) 414 | x2 = self.down1(x1, t) 415 | x2 = self.sa1(x2) 416 | x3 = self.down2(x2, t) 417 | x3 = self.sa2(x3) 418 | x4 = self.down3(x3, t) 419 | x4 = self.sa3(x4) 420 | 421 | x4 = self.bottleneck1(x4) 422 | x4 = self.bottleneck2(x4) 423 | x4 = self.bottleneck3(x4) 424 | 425 | x = self.up1(x4, x3, t) 426 | x = self.sa4(x) 427 | x = self.up2(x, x2, t) 428 | x = self.sa5(x) 429 | x = self.up3(x, x1, t) 430 | x = self.sa6(x) 431 | 432 | return self.out_conv(x) 433 | 434 | 435 | class EMA: 436 | def __init__(self, beta): 437 | """Modifies exponential moving average model. 
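
        After `step_start_ema` steps the weights are updated as
        `ema = beta * ema + (1 - beta) * current`; before that the EMA model is
        simply reset to a copy of the current model weights.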
438 | """ 439 | self.beta = beta 440 | self.step = 0 441 | 442 | def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None: 443 | for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): 444 | old_weights, new_weights = ema_params.data, current_params.data 445 | ema_params.data = self.update_average(old_weights=old_weights, new_weights=new_weights) 446 | 447 | def update_average(self, old_weights: torch.Tensor, new_weights: torch.Tensor) -> torch.Tensor: 448 | if old_weights is None: 449 | return new_weights 450 | return old_weights * self.beta + (1 - self.beta) * new_weights 451 | 452 | def ema_step(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None: 453 | if self.step < step_start_ema: 454 | self.reset_parameters(ema_model=ema_model, model=model) 455 | self.step += 1 456 | return 457 | self.update_model_average(ema_model=ema_model, current_model=model) 458 | self.step += 1 459 | 460 | @staticmethod 461 | def reset_parameters(ema_model: nn.Module, model: nn.Module) -> None: 462 | ema_model.load_state_dict(model.state_dict()) 463 | 464 | 465 | class CustomImageClassDataset(Dataset): 466 | def __init__( 467 | self, 468 | root_dir: str, 469 | image_size: int, 470 | image_channels: int 471 | ): 472 | super(CustomImageClassDataset, self).__init__() 473 | self.root_dir = root_dir 474 | self.class_list = os.listdir(root_dir) 475 | 476 | self.transform = transforms.Compose([ 477 | transforms.Resize((image_size, image_size)), 478 | transforms.ToTensor(), 479 | transforms.Normalize( 480 | mean=[0.5 for _ in range(image_channels)], 481 | std=[0.5 for _ in range(image_channels)], 482 | ) 483 | ]) 484 | 485 | self.image_labels_files_list = list() 486 | for idx, class_name_folder in enumerate(self.class_list): 487 | class_path = os.path.join(root_dir, class_name_folder) 488 | files_list = os.listdir(class_path) 489 | for image_file in files_list: 490 | self.image_labels_files_list.append( 491 | ( 492 | os.path.join(class_path, image_file), 493 | idx, 494 | ) 495 | ) 496 | 497 | self.image_files_list_len = len(self.image_labels_files_list) 498 | 499 | def __len__(self) -> int: 500 | return self.image_files_list_len 501 | 502 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 503 | image_path, class_label = self.image_labels_files_list[idx] 504 | image = Image.open(image_path) 505 | image = image.convert('RGB') 506 | return self.transform(image), class_label 507 | 508 | 509 | class Tester: 510 | def __init__(self, device: str = 'cuda', batch_size: int = 4, image_size: int = 64,): 511 | self.device = device 512 | self.batch_size = batch_size 513 | self.image_size = image_size 514 | 515 | def test_unet(self) -> None: 516 | net = UNet().to(self.device) 517 | print(f'Param count: {sum([p.numel() for p in net.parameters()])}') 518 | x = torch.randn(self.batch_size, 3, self.image_size, self.image_size) 519 | current_timestep = 500 520 | t = x.new_tensor([current_timestep] * x.shape[0]).long() 521 | output = net(x, t) 522 | assert x.shape == output.shape, 'Input image tensor and output image tensor of network should be same.' 523 | print(f'UNet input shape: {x.shape}') 524 | print(f'UNet output shape: {output.shape}') 525 | 526 | def test_attention(self) -> None: 527 | x = torch.randn(size=(4, 128, 32, 32)) 528 | sa1 = TransformerEncoderSA(128, 32) 529 | output = sa1(x) 530 | assert x.shape == output.shape, 'Shape of output of feature map x and self attention output should be same.' 
531 | print(f'Self attention input shape: {x.shape}') 532 | print(f'Self attention output shape: {output.shape}') 533 | 534 | def test_jit(self) -> None: 535 | net = torch.jit.script(UNet().to(self.device)) 536 | print(f'Param count: {sum([p.numel() for p in net.parameters()])}') 537 | x = torch.randn(self.batch_size, 3, self.image_size, self.image_size).to(self.device) 538 | current_timestep = 500 539 | t = x.new_tensor([current_timestep] * x.shape[0]).long() 540 | output = net(x, t) 541 | assert x.shape == output.shape, 'Input image tensor and output image tensor of network should be same.' 542 | print(f'UNet input shape: {x.shape}') 543 | print(f'UNet output shape: {output.shape}') 544 | 545 | 546 | class Utils: 547 | def __init__(self): 548 | super(Utils, self).__init__() 549 | 550 | @staticmethod 551 | def collate_fn(batch): 552 | """Discard None samples from the batch. 553 | """ 554 | batch = list(filter(lambda x: x is not None, batch)) 555 | return torch.utils.data.dataloader.default_collate(batch) 556 | 557 | @staticmethod 558 | def save_images(images: torch.Tensor, save_path: str) -> None: 559 | grid = torchvision.utils.make_grid(images) 560 | img_arr = grid.permute(1, 2, 0).cpu().numpy() 561 | img = Image.fromarray(img_arr) 562 | img.save(save_path) 563 | 564 | @staticmethod 565 | def save_checkpoint( 566 | epoch: int, 567 | model: nn.Module, 568 | filename: str, 569 | optimizer: optim.Optimizer = None, 570 | scheduler: optim.lr_scheduler = None, 571 | grad_scaler: GradScaler = None, 572 | ) -> None: 573 | checkpoint = { 574 | 'epoch': epoch, 575 | 'state_dict': model.state_dict(), 576 | } 577 | if optimizer: 578 | checkpoint['optimizer'] = optimizer.state_dict() 579 | if scheduler: 580 | checkpoint['scheduler'] = scheduler.state_dict() 581 | if grad_scaler: 582 | checkpoint['grad_scaler'] = grad_scaler.state_dict() 583 | 584 | torch.save(checkpoint, filename) 585 | logging.info("=> Saving checkpoint complete.") 586 | 587 | @staticmethod 588 | def load_checkpoint( 589 | model: nn.Module, 590 | filename: str, 591 | optimizer: optim.Optimizer = None, 592 | scheduler: optim.lr_scheduler = None, 593 | grad_scaler: GradScaler = None, 594 | ) -> int: 595 | logging.info("=> Loading checkpoint") 596 | checkpoint = torch.load(filename, map_location="cuda") 597 | model.load_state_dict(checkpoint['state_dict'], strict=False) 598 | if 'optimizer' in checkpoint: 599 | optimizer.load_state_dict(checkpoint['optimizer']) 600 | if 'scheduler' in checkpoint: 601 | scheduler.load_state_dict(checkpoint['scheduler']) 602 | if 'grad_scaler' in checkpoint: 603 | grad_scaler.load_state_dict(checkpoint['grad_scaler']) 604 | return checkpoint['epoch'] 605 | 606 | 607 | class Trainer: 608 | def __init__( 609 | self, 610 | dataset_path: str, 611 | save_path: str = None, 612 | checkpoint_path: str = None, 613 | checkpoint_path_ema: str = None, 614 | run_name: str = 'ddpm', 615 | image_size: int = 64, 616 | image_channels: int = 3, 617 | batch_size: int = 2, 618 | accumulation_iters: int = 16, 619 | sample_count: int = 1, 620 | num_workers: int = 0, 621 | device: str = 'cuda', 622 | num_epochs: int = 10000, 623 | fp16: bool = False, 624 | save_every: int = 2000, 625 | learning_rate: float = 2e-4, 626 | noise_steps: int = 500, 627 | enable_train_mode: bool = True, 628 | ): 629 | self.num_epochs = num_epochs 630 | self.device = device 631 | self.fp16 = fp16 632 | self.save_every = save_every 633 | self.accumulation_iters = accumulation_iters 634 | self.sample_count = sample_count 635 | 636 | base_path = save_path if save_path is not None else os.getcwd()
637 | self.save_path = os.path.join(base_path, run_name) 638 | pathlib.Path(self.save_path).mkdir(parents=True, exist_ok=True) 639 | self.logger = SummaryWriter(log_dir=os.path.join(self.save_path, 'logs')) 640 | 641 | if enable_train_mode: 642 | diffusion_dataset = CustomImageClassDataset( 643 | root_dir=dataset_path, 644 | image_size=image_size, 645 | image_channels=image_channels 646 | ) 647 | self.train_loader = DataLoader( 648 | diffusion_dataset, 649 | batch_size=batch_size, 650 | shuffle=True, 651 | pin_memory=True, 652 | num_workers=num_workers, 653 | drop_last=False, 654 | collate_fn=Utils.collate_fn, 655 | ) 656 | 657 | self.unet_model = UNet().to(device) 658 | self.diffusion = Diffusion(img_size=image_size, device=self.device, noise_steps=noise_steps) 659 | self.optimizer = optim.Adam( 660 | params=self.unet_model.parameters(), lr=learning_rate, # betas=(0.9, 0.999) 661 | ) 662 | self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=300) 663 | # self.loss_fn = nn.MSELoss().to(self.device) 664 | self.grad_scaler = GradScaler() 665 | 666 | self.ema = EMA(beta=0.95) 667 | self.ema_model = copy.deepcopy(self.unet_model).eval().requires_grad_(False) 668 | 669 | # ema_avg = lambda avg_model_param, model_param, num_averaged: 0.1 * avg_model_param + 0.9 * model_param 670 | # self.swa_model = optim.swa_utils.AveragedModel(model=self.unet_model, avg_fn=ema_avg).to(self.device) 671 | # self.swa_start = 10 672 | # self.swa_scheduler = optim.swa_utils.SWALR( 673 | # optimizer=self.optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy='cos' 674 | # ) 675 | 676 | self.start_epoch = 0 677 | if checkpoint_path: 678 | logging.info(f'Loading model weights...') 679 | self.start_epoch = Utils.load_checkpoint( 680 | model=self.unet_model, 681 | optimizer=self.optimizer, 682 | scheduler=self.scheduler, 683 | grad_scaler=self.grad_scaler, 684 | filename=checkpoint_path, 685 | ) 686 | if checkpoint_path_ema: 687 | logging.info(f'Loading EMA model weights...') 688 | _ = Utils.load_checkpoint( 689 | model=self.ema_model, 690 | filename=checkpoint_path_ema, 691 | ) 692 | 693 | def sample( 694 | self, 695 | epoch: int = None, 696 | batch_idx: int = None, 697 | sample_count: int = 1, 698 | output_name: str = None 699 | ) -> None: 700 | """Generates images with the reverse diffusion process using both the training model and the EMA model. 701 | """ 702 | sampled_images = self.diffusion.p_sample(eps_model=self.unet_model, n=sample_count) 703 | ema_sampled_images = self.diffusion.p_sample(eps_model=self.ema_model, n=sample_count) 704 | 705 | model_name = f'model_{epoch}_{batch_idx}.jpg' 706 | ema_model_name = f'model_ema_{epoch}_{batch_idx}.jpg' 707 | 708 | if output_name: 709 | model_name = f'{output_name}.jpg' 710 | ema_model_name = f'{output_name}_ema.jpg' 711 | 712 | Utils.save_images( 713 | images=sampled_images, 714 | save_path=os.path.join(self.save_path, model_name) 715 | ) 716 | Utils.save_images( 717 | images=ema_sampled_images, 718 | save_path=os.path.join(self.save_path, ema_model_name) 719 | ) 720 | 721 | def sample_gif( 722 | self, 723 | save_path: str = '', 724 | sample_count: int = 1, 725 | output_name: str = None, 726 | ) -> None: 727 | """Generates a GIF of the reverse diffusion process using both the training model and the EMA model.
728 | """ 729 | self.diffusion.generate_gif( 730 | eps_model=self.unet_model, 731 | n=sample_count, 732 | save_path=save_path, 733 | output_name=output_name, 734 | ) 735 | self.diffusion.generate_gif( 736 | eps_model=self.ema_model, 737 | n=sample_count, 738 | save_path=save_path, 739 | output_name=f'{output_name}_ema', 740 | ) 741 | 742 | def train(self) -> None: 743 | logging.info(f'Training started....') 744 | for epoch in range(self.start_epoch, self.num_epochs): 745 | # total_loss = 0.0 746 | accumulated_minibatch_loss = 0.0 747 | 748 | with tqdm(self.train_loader) as pbar: 749 | for batch_idx, (real_images, _) in enumerate(pbar): 750 | real_images = real_images.to(self.device) 751 | current_batch_size = real_images.shape[0] 752 | t = self.diffusion.sample_timesteps(batch_size=current_batch_size) 753 | x_t, noise = self.diffusion.q_sample(x=real_images, t=t) 754 | 755 | with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=self.fp16): 756 | predicted_noise = self.unet_model(x=x_t, t=t) 757 | 758 | loss = F.smooth_l1_loss(predicted_noise, noise) 759 | loss /= self.accumulation_iters 760 | 761 | accumulated_minibatch_loss += float(loss) 762 | 763 | self.grad_scaler.scale(loss).backward() 764 | 765 | # if ((batch_idx + 1) % self.accumulation_iters == 0) or ((batch_idx + 1) == len(self.train_loader)): 766 | if (batch_idx + 1) % self.accumulation_iters == 0: 767 | self.grad_scaler.step(self.optimizer) 768 | self.grad_scaler.update() 769 | self.optimizer.zero_grad(set_to_none=True) 770 | self.ema.ema_step(ema_model=self.ema_model, model=self.unet_model) 771 | 772 | # if epoch > self.swa_start: 773 | # self.swa_model.update_parameters(model=self.unet_model) 774 | # self.swa_scheduler.step() 775 | # else: 776 | # self.scheduler.step() 777 | 778 | # total_loss += (float(accumulated_minibatch_loss) / len(self.train_loader) * self.accumulation_iters) 779 | pbar.set_description( 780 | # f'Loss minibatch: {float(accumulated_minibatch_loss):.4f}, total: {total_loss:.4f}' 781 | f'Loss minibatch: {float(accumulated_minibatch_loss):.4f}' 782 | ) 783 | accumulated_minibatch_loss = 0.0 784 | 785 | if not batch_idx % self.save_every: 786 | self.sample(epoch=epoch, batch_idx=batch_idx, sample_count=self.sample_count) 787 | 788 | Utils.save_checkpoint( 789 | epoch=epoch, 790 | model=self.unet_model, 791 | optimizer=self.optimizer, 792 | scheduler=self.scheduler, 793 | grad_scaler=self.grad_scaler, 794 | filename=os.path.join(self.save_path, f'model_{epoch}_{batch_idx}.pt') 795 | ) 796 | Utils.save_checkpoint( 797 | epoch=epoch, 798 | model=self.ema_model, 799 | filename=os.path.join(self.save_path, f'model_ema_{epoch}_{batch_idx}.pt') 800 | ) 801 | 802 | self.scheduler.step() 803 | 804 | 805 | if __name__ == '__main__': 806 | trainer = Trainer( 807 | dataset_path=r'C:\datasets\cars', 808 | save_path=r'C:\DeepLearningPytorch\ddpm', 809 | # checkpoint_path=r'C:\DeepLearningPytorch\ddpm\model_126_0.pt', 810 | # checkpoint_path_ema=r'C:\DeepLearningPytorch\ddpm\model_ema_126_0.pt', 811 | # enable_train_mode=False, 812 | ) 813 | trainer.train() 814 | 815 | # trainer.sample(output_name='output6', sample_count=4) 816 | 817 | # trainer.sample_gif( 818 | # output_name='output8', 819 | # sample_count=1, 820 | # save_path=r'C:\DeepLearningPytorch\ddpm' 821 | # ) 822 | 823 | # tester = Tester(device='cuda') 824 | # tester.test_unet() 825 | # tester.test_attention() 826 | # tester.test_jit() 827 | --------------------------------------------------------------------------------