├── LICENSE
├── README.md
├── models.py
├── requirements.txt
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 OvJat
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Models Tutorials
2 | 
3 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
4 | 
5 | ## Description
6 | This is a PyTorch-based tutorial on diffusion models.
7 | 
8 | ## Setup environment
9 | 
10 | ### Setup environment (step by step)
11 | 
12 | ```bash
13 | # step 1. create the anaconda environment
14 | conda create -n DiffusionModels python=3.8
15 | 
16 | # step 2. activate the environment
17 | conda activate DiffusionModels
18 | 
19 | # step 3. install PyTorch
20 | # if on macOS
21 | pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1
22 | # if on Linux/Windows, CUDA 11.6
23 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
24 | # if on Linux/Windows, CUDA 11.7
25 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
26 | # if on Linux/Windows, CPU only
27 | pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
28 | 
29 | # step 4. install other packages
30 | pip install diffusers==0.14.0
31 | 
32 | ```
33 | 
34 | ### Setup environment (on Linux/Windows, CUDA 11.7)
35 | ```shell
36 | # step 1. create the anaconda environment
37 | conda create -n DiffusionModels python=3.8
38 | 
39 | # step 2. activate the environment
40 | conda activate DiffusionModels
41 | 
42 | # step 3. install everything from requirements.txt
43 | pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117
44 | 
45 | ```
46 | 
47 | ## Files
48 | 
49 | * `models.py` defines the neural network building blocks (autoencoders, PatchGAN discriminator, U-Nets).
50 | * `train.py` contains the training and sampling demos (a quick usage check follows this list):
51 |   * function `train_vae` shows how to train AutoEncoderKL or AutoEncoderVQ.
52 |   * function `make_conditions` shows how to build the timestep and condition embeddings for diffusion.
53 |   * function `train_diffusion` shows how to train a U-Net for diffusion.
54 |   * function `sampling_diffusion` shows how to sample using a pretrained U-Net.
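## Quick check

Each demo is wired into `main()` at the bottom of `train.py`; uncomment the one you want and run `python train.py`. Before that, the classes in `models.py` can be exercised on random tensors. The snippet below is a minimal sketch, mirroring the `debug()` helper in `models.py` and the model sizes used by `train_vae` and `train_diffusion`; it only checks that shapes line up and does not train anything.

```python
import torch

from models import AutoEncoderKL, ConditionalUNet

# AutoEncoderKL: 17-channel inputs -> 128-dim latent map at 1/8 resolution, and back.
vae = AutoEncoderKL(17, latent_dim=128)
x = torch.rand(4, 17, 128, 128)
z, kl_loss = vae.encode(x, return_loss=True)   # z: [4, 128, 16, 16]
x_hat = vae.decode(z)                          # x_hat: [4, 17, 128, 128]

# ConditionalUNet: denoises a 3-channel input given a [B, L, 512] condition sequence.
unet = ConditionalUNet(3, condition_dim=512)
eps = unet(torch.rand(4, 3, 128, 128), torch.rand(4, 128, 512))
print(z.shape, x_hat.shape, eps.shape)
```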
55 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | """ 5 | Support Python 3.8 6 | @author: Lou Xiao(louxiao@i32n.com) 7 | @maintainer: Lou Xiao(louxiao@i32n.com) 8 | @copyright: Copyright 2018~2023 9 | @created time: 2023-04-05 15:05:52 CST 10 | @updated time: 2023-04-05 15:05:52 CST 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | from diffusers.models.vae import DiagonalGaussianDistribution 17 | from diffusers.models.vae import VectorQuantizer 18 | 19 | 20 | class ConvBlock(nn.Module): 21 | 22 | def __init__(self, num_channels: int): 23 | super().__init__() 24 | self.layers = nn.Sequential( 25 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm2d(num_channels), 27 | nn.ReLU(), 28 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1), 29 | nn.BatchNorm2d(num_channels), 30 | nn.ReLU(), 31 | ) 32 | 33 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 34 | h = self.layers(inputs) 35 | return h 36 | 37 | 38 | class ResBlock(nn.Module): 39 | 40 | def __init__(self, num_channels: int): 41 | super().__init__() 42 | self.residual = nn.Sequential( 43 | nn.BatchNorm2d(num_channels), 44 | nn.ReLU(), 45 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1), 46 | nn.BatchNorm2d(num_channels), 47 | nn.ReLU(), 48 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1), 49 | ) 50 | 51 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 52 | h = inputs + self.residual(inputs) 53 | return h 54 | 55 | 56 | class AutoEncoder(nn.Module): 57 | 58 | def __init__(self, num_channels: int, base_channels: int = 64): 59 | super().__init__() 60 | self.num_channels = num_channels 61 | self.base_channels = base_channels 62 | 63 | self.conv_in = nn.Conv2d(num_channels, base_channels, kernel_size=3, stride=1, padding=1) 64 | self.encoder = nn.Sequential( 65 | # stage 1 66 | nn.Sequential( 67 | ConvBlock(base_channels), 68 | ConvBlock(base_channels), 69 | ), 70 | # stage 2 71 | nn.Conv2d(base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 72 | nn.Sequential( 73 | ConvBlock(2 * base_channels), 74 | ConvBlock(2 * base_channels), 75 | ), 76 | # stage 3 77 | nn.Conv2d(2 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 78 | nn.Sequential( 79 | ConvBlock(4 * base_channels), 80 | ConvBlock(4 * base_channels), 81 | ), 82 | # stage 4 83 | nn.Conv2d(4 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 84 | nn.Sequential( 85 | ConvBlock(8 * base_channels), 86 | ConvBlock(8 * base_channels), 87 | ), 88 | ) 89 | self.decoder = nn.Sequential( 90 | # stage 4 91 | nn.Sequential( 92 | ConvBlock(8 * base_channels), 93 | ConvBlock(8 * base_channels), 94 | ), 95 | # stage 3 96 | nn.ConvTranspose2d(8 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 97 | nn.Sequential( 98 | ConvBlock(4 * base_channels), 99 | ConvBlock(4 * base_channels), 100 | ), 101 | # stage 2 102 | nn.ConvTranspose2d(4 * base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 103 | nn.Sequential( 104 | ConvBlock(2 * base_channels), 105 | ConvBlock(2 * base_channels), 106 | ), 107 | # stage 1 108 | nn.ConvTranspose2d(2 * base_channels, 1 * base_channels, kernel_size=2, stride=2, padding=0), 109 | nn.Sequential( 110 | ConvBlock(1 * base_channels), 
111 | ConvBlock(1 * base_channels), 112 | ), 113 | ) 114 | self.conv_out = nn.Conv2d(base_channels, num_channels, kernel_size=3, stride=1, padding=1) 115 | 116 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 117 | h = self.conv_in(inputs) 118 | h = self.encoder(h) 119 | h = self.decoder(h) 120 | h = self.conv_out(h) 121 | return h 122 | 123 | 124 | class AutoEncoderKL(nn.Module): 125 | 126 | def __init__( 127 | self, 128 | num_channels: int, 129 | latent_dim: int, 130 | base_channels: int = 64 131 | ): 132 | super().__init__() 133 | self.num_channels = num_channels 134 | self.latent_dim = latent_dim 135 | self.base_channels = base_channels 136 | 137 | self.conv_in = nn.Conv2d(num_channels, base_channels, kernel_size=3, stride=1, padding=1) 138 | self.encoder = nn.Sequential( 139 | # stage 1 140 | nn.Sequential( 141 | ConvBlock(base_channels), 142 | ConvBlock(base_channels), 143 | ), 144 | # stage 2 145 | nn.Conv2d(base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 146 | nn.Sequential( 147 | ConvBlock(2 * base_channels), 148 | ConvBlock(2 * base_channels), 149 | ), 150 | # stage 3 151 | nn.Conv2d(2 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 152 | nn.Sequential( 153 | ConvBlock(4 * base_channels), 154 | ConvBlock(4 * base_channels), 155 | ), 156 | # stage 4 157 | nn.Conv2d(4 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 158 | nn.Sequential( 159 | ConvBlock(8 * base_channels), 160 | ConvBlock(8 * base_channels), 161 | ), 162 | ) 163 | 164 | self.encode_latent = nn.Conv2d(8 * base_channels, 2 * latent_dim, kernel_size=1, stride=1, padding=0) 165 | # KL sampling 166 | self.decode_latent = nn.Conv2d(latent_dim, 8 * base_channels, kernel_size=1, stride=1, padding=0) 167 | 168 | self.decoder = nn.Sequential( 169 | # stage 4 170 | nn.Sequential( 171 | ConvBlock(8 * base_channels), 172 | ConvBlock(8 * base_channels), 173 | ), 174 | # stage 3 175 | nn.ConvTranspose2d(8 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 176 | nn.Sequential( 177 | ConvBlock(4 * base_channels), 178 | ConvBlock(4 * base_channels), 179 | ), 180 | # stage 2 181 | nn.ConvTranspose2d(4 * base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 182 | nn.Sequential( 183 | ConvBlock(2 * base_channels), 184 | ConvBlock(2 * base_channels), 185 | ), 186 | # stage 1 187 | nn.ConvTranspose2d(2 * base_channels, 1 * base_channels, kernel_size=2, stride=2, padding=0), 188 | nn.Sequential( 189 | ConvBlock(1 * base_channels), 190 | ConvBlock(1 * base_channels), 191 | ), 192 | ) 193 | self.conv_out = nn.Conv2d(base_channels, num_channels, kernel_size=3, stride=1, padding=1) 194 | 195 | def encode(self, inputs: torch.Tensor, sampling: bool = False, return_loss: bool = False): 196 | h = self.conv_in(inputs) 197 | h = self.encoder(h) 198 | h = self.encode_latent(h) # avg, std 199 | dist = DiagonalGaussianDistribution(h) 200 | if sampling: 201 | return dist.sample() 202 | elif return_loss: 203 | kl_loss = dist.kl().mean() 204 | return dist.sample(), kl_loss 205 | else: 206 | return dist.mode() 207 | 208 | def decode(self, latent: torch.Tensor) -> torch.Tensor: 209 | h = self.decode_latent(latent) 210 | h = self.decoder(h) 211 | h = self.conv_out(h) 212 | return h 213 | 214 | 215 | class AutoEncoderVQ(nn.Module): 216 | 217 | def __init__( 218 | self, 219 | num_channels: int, 220 | latent_dim: int, 221 | base_channels: int = 64, 222 | num_vq_embeddings: int = 8192, 223 | ): 224 | super().__init__() 225 | self.num_channels 
= num_channels 226 | self.latent_dim = latent_dim 227 | self.base_channels = base_channels 228 | self.num_vq_embeddings = num_vq_embeddings 229 | 230 | self.conv_in = nn.Conv2d(num_channels, base_channels, kernel_size=3, stride=1, padding=1) 231 | self.encoder = nn.Sequential( 232 | # stage 1 233 | nn.Sequential( 234 | ConvBlock(base_channels), 235 | ConvBlock(base_channels), 236 | ), 237 | # stage 2 238 | nn.Conv2d(base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 239 | nn.Sequential( 240 | ConvBlock(2 * base_channels), 241 | ConvBlock(2 * base_channels), 242 | ), 243 | # stage 3 244 | nn.Conv2d(2 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 245 | nn.Sequential( 246 | ConvBlock(4 * base_channels), 247 | ConvBlock(4 * base_channels), 248 | ), 249 | # stage 4 250 | nn.Conv2d(4 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 251 | nn.Sequential( 252 | ConvBlock(8 * base_channels), 253 | ConvBlock(8 * base_channels), 254 | ), 255 | ) 256 | 257 | self.encode_latent = nn.Conv2d(8 * base_channels, latent_dim, kernel_size=1, stride=1, padding=0) 258 | # VQ 259 | self.vq = VectorQuantizer(num_vq_embeddings, latent_dim, beta=0.25, sane_index_shape=True, legacy=False) 260 | self.decode_latent = nn.Conv2d(latent_dim, 8 * base_channels, kernel_size=1, stride=1, padding=0) 261 | 262 | self.decoder = nn.Sequential( 263 | # stage 4 264 | nn.Sequential( 265 | ConvBlock(8 * base_channels), 266 | ConvBlock(8 * base_channels), 267 | ), 268 | # stage 3 269 | nn.ConvTranspose2d(8 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 270 | nn.Sequential( 271 | ConvBlock(4 * base_channels), 272 | ConvBlock(4 * base_channels), 273 | ), 274 | # stage 2 275 | nn.ConvTranspose2d(4 * base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 276 | nn.Sequential( 277 | ConvBlock(2 * base_channels), 278 | ConvBlock(2 * base_channels), 279 | ), 280 | # stage 1 281 | nn.ConvTranspose2d(2 * base_channels, 1 * base_channels, kernel_size=2, stride=2, padding=0), 282 | nn.Sequential( 283 | ConvBlock(1 * base_channels), 284 | ConvBlock(1 * base_channels), 285 | ), 286 | ) 287 | self.conv_out = nn.Conv2d(base_channels, num_channels, kernel_size=3, stride=1, padding=1) 288 | 289 | def encode(self, inputs: torch.Tensor, sampling: bool = False, return_loss: bool = False): 290 | h = self.conv_in(inputs) 291 | h = self.encoder(h) 292 | h = self.encode_latent(h) # avg, std 293 | z_q, loss, _ = self.vq(h) 294 | if sampling: 295 | return z_q 296 | elif return_loss: # train 297 | return z_q, loss 298 | else: 299 | return h 300 | 301 | def decode(self, latent: torch.Tensor) -> torch.Tensor: 302 | h = self.decode_latent(latent) 303 | h = self.decoder(h) 304 | h = self.conv_out(h) 305 | return h 306 | 307 | 308 | class PatchGANDiscriminator(nn.Module): 309 | 310 | def __init__(self, in_channels: int, num_channels: int = 64): 311 | super().__init__() 312 | self.layers = nn.Sequential( 313 | nn.Sequential( 314 | nn.Conv2d(2 * in_channels, num_channels, kernel_size=4, stride=2, padding=1), 315 | # nn.BatchNorm2d(num_channels), 316 | nn.LeakyReLU(0.2), 317 | ), 318 | nn.Sequential( 319 | nn.Conv2d(1 * num_channels, 2 * num_channels, kernel_size=4, stride=2, padding=1), 320 | nn.BatchNorm2d(2 * num_channels), 321 | nn.LeakyReLU(0.2), 322 | ), 323 | nn.Sequential( 324 | nn.Conv2d(2 * num_channels, 4 * num_channels, kernel_size=4, stride=2, padding=1), 325 | nn.BatchNorm2d(4 * num_channels), 326 | nn.LeakyReLU(0.2), 327 | ), 328 | 
nn.Sequential( 329 | nn.Conv2d(4 * num_channels, 8 * num_channels, kernel_size=4, stride=1, padding=1), 330 | nn.BatchNorm2d(8 * num_channels), 331 | nn.LeakyReLU(0.2), 332 | ), 333 | nn.Sequential( 334 | nn.Conv2d(8 * num_channels, 1, kernel_size=4, stride=1, padding=1), 335 | nn.Sigmoid(), 336 | ), 337 | ) 338 | 339 | # forward method 340 | def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 341 | h = torch.cat([outputs, targets], 1) 342 | h = self.layers(h) 343 | return h 344 | 345 | 346 | class UNet(nn.Module): 347 | 348 | def __init__( 349 | self, 350 | in_channels: int, 351 | out_channels: int, 352 | base_channels: int = 64 353 | ): 354 | super().__init__() 355 | self.in_channels = in_channels 356 | self.out_channels = out_channels 357 | 358 | self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) 359 | self.encoder_list = nn.ModuleList([ 360 | # stage 1 361 | nn.Sequential( 362 | ResBlock(1 * base_channels), 363 | ResBlock(1 * base_channels), 364 | ), 365 | # stage 2 366 | nn.Sequential( 367 | nn.Conv2d(base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 368 | ResBlock(2 * base_channels), 369 | ResBlock(2 * base_channels), 370 | ), 371 | # stage 3 372 | nn.Sequential( 373 | nn.Conv2d(2 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 374 | ResBlock(4 * base_channels), 375 | ResBlock(4 * base_channels), 376 | ), 377 | # stage 4 378 | nn.Sequential( 379 | nn.Conv2d(4 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 380 | ResBlock(8 * base_channels), 381 | ResBlock(8 * base_channels), 382 | ), 383 | ]) 384 | self.middle = nn.Sequential( 385 | nn.Conv2d(8 * base_channels, 32 * base_channels, kernel_size=2, stride=2, padding=0), 386 | ResBlock(32 * base_channels), 387 | ResBlock(32 * base_channels), 388 | nn.ConvTranspose2d(32 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 389 | ) 390 | self.decoder_list = nn.Sequential( 391 | # stage 4 392 | nn.Sequential( 393 | nn.Conv2d(2 * 8 * base_channels, 8 * base_channels, kernel_size=1, stride=1, padding=0), 394 | ResBlock(8 * base_channels), 395 | ResBlock(8 * base_channels), 396 | nn.ConvTranspose2d(8 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 397 | ), 398 | # stage 3 399 | nn.Sequential( 400 | nn.Conv2d(2 * 4 * base_channels, 4 * base_channels, kernel_size=1, stride=1, padding=0), 401 | ResBlock(4 * base_channels), 402 | ResBlock(4 * base_channels), 403 | nn.ConvTranspose2d(4 * base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 404 | ), 405 | # stage 2 406 | nn.Sequential( 407 | nn.Conv2d(2 * 2 * base_channels, 2 * base_channels, kernel_size=1, stride=1, padding=0), 408 | ResBlock(2 * base_channels), 409 | ResBlock(2 * base_channels), 410 | nn.ConvTranspose2d(2 * base_channels, 1 * base_channels, kernel_size=2, stride=2, padding=0), 411 | ), 412 | # stage 1 413 | nn.Sequential( 414 | nn.Conv2d(2 * 1 * base_channels, 1 * base_channels, kernel_size=1, stride=1, padding=0), 415 | ResBlock(1 * base_channels), 416 | ResBlock(1 * base_channels), 417 | ), 418 | ) 419 | self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=3, stride=1, padding=1) 420 | 421 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 422 | h = self.conv_in(inputs) 423 | skip_list = [] 424 | for m in self.encoder_list: 425 | h = m(h) 426 | skip_list.insert(0, h) 427 | h = self.middle(h) 428 | for m, skip in zip(self.decoder_list, skip_list): 429 | h = 
torch.concat([skip, h], dim=1) 430 | h = m(h) 431 | h = self.conv_out(h) 432 | return h 433 | 434 | 435 | class TwoWaysModule(object): 436 | pass 437 | 438 | 439 | class TwoWaysSequential(nn.Module): 440 | 441 | def __init__(self, *modules): 442 | super().__init__() 443 | self.module_list = nn.ModuleList(modules) 444 | 445 | def forward(self, inputs: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor: 446 | h = inputs 447 | for m in self.module_list: 448 | if isinstance(m, TwoWaysModule): 449 | h = m(h, conditions) 450 | else: 451 | h = m(h) 452 | return h 453 | 454 | 455 | class CrossAttentionBlock(nn.Module, TwoWaysModule): 456 | 457 | def __init__( 458 | self, 459 | num_channels: int, 460 | condition_dim: int, 461 | num_heads: int = 8, 462 | layer_scale_init: float = 1e-6, 463 | ): 464 | super().__init__() 465 | self.layer_norm = nn.GroupNorm(1, num_channels) 466 | self.attention = nn.MultiheadAttention( 467 | embed_dim=num_channels, 468 | kdim=condition_dim, 469 | vdim=condition_dim, 470 | num_heads=num_heads, 471 | batch_first=True, 472 | ) 473 | self.layer_scale = nn.Parameter(torch.full([num_channels, 1, 1], layer_scale_init, dtype=torch.float)) 474 | 475 | def forward(self, inputs: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor: 476 | # inputs: shape[B, C, H, W] 477 | # conditions: shape[B, L, C'] 478 | h = self.layer_norm(inputs) 479 | # cross attention 480 | bb, cc, hh, ww = h.shape 481 | h = h.reshape([bb, cc, hh * ww]) # [B, C, L] 482 | h = torch.swapdims(h, 1, 2) # [B, L, C] 483 | h, _ = self.attention(h, conditions, conditions) # Q, K, V 484 | h = torch.swapdims(h, 1, 2) 485 | h = h.reshape([bb, cc, hh, ww]) 486 | h = inputs + self.layer_scale * h # residual 487 | return h 488 | 489 | 490 | class ConditionalUNet(nn.Module): 491 | 492 | def __init__( 493 | self, 494 | num_channels: int, 495 | condition_dim: int = 512, 496 | base_channels: int = 64, 497 | ): 498 | super().__init__() 499 | self.num_channels = num_channels 500 | self.base_channels = base_channels 501 | self.condition_dim = condition_dim 502 | 503 | self.conv_in = nn.Conv2d(num_channels, base_channels, kernel_size=3, stride=1, padding=1) 504 | self.encoder_list = nn.ModuleList([ 505 | # stage 1 506 | TwoWaysSequential( 507 | ResBlock(1 * base_channels), 508 | CrossAttentionBlock(1 * base_channels, condition_dim), 509 | ResBlock(1 * base_channels), 510 | CrossAttentionBlock(1 * base_channels, condition_dim), 511 | ), 512 | # stage 2 513 | TwoWaysSequential( 514 | nn.Conv2d(base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 515 | ResBlock(2 * base_channels), 516 | CrossAttentionBlock(2 * base_channels, condition_dim), 517 | ResBlock(2 * base_channels), 518 | CrossAttentionBlock(2 * base_channels, condition_dim), 519 | ), 520 | # stage 3 521 | TwoWaysSequential( 522 | nn.Conv2d(2 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 523 | ResBlock(4 * base_channels), 524 | CrossAttentionBlock(4 * base_channels, condition_dim), 525 | ResBlock(4 * base_channels), 526 | CrossAttentionBlock(4 * base_channels, condition_dim), 527 | ), 528 | # stage 4 529 | TwoWaysSequential( 530 | nn.Conv2d(4 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 531 | ResBlock(8 * base_channels), 532 | ResBlock(8 * base_channels), 533 | ), 534 | ]) 535 | self.middle = TwoWaysSequential( 536 | nn.Conv2d(8 * base_channels, 32 * base_channels, kernel_size=2, stride=2, padding=0), 537 | ResBlock(32 * base_channels), 538 | CrossAttentionBlock(32 * base_channels, 
condition_dim), 539 | ResBlock(32 * base_channels), 540 | CrossAttentionBlock(32 * base_channels, condition_dim), 541 | nn.ConvTranspose2d(32 * base_channels, 8 * base_channels, kernel_size=2, stride=2, padding=0), 542 | ) 543 | self.decoder_list = nn.ModuleList([ 544 | # stage 4 545 | TwoWaysSequential( 546 | nn.Conv2d(2 * 8 * base_channels, 8 * base_channels, kernel_size=1, stride=1, padding=0), 547 | ResBlock(8 * base_channels), 548 | CrossAttentionBlock(8 * base_channels, condition_dim), 549 | ResBlock(8 * base_channels), 550 | CrossAttentionBlock(8 * base_channels, condition_dim), 551 | nn.ConvTranspose2d(8 * base_channels, 4 * base_channels, kernel_size=2, stride=2, padding=0), 552 | ), 553 | # stage 3 554 | TwoWaysSequential( 555 | nn.Conv2d(2 * 4 * base_channels, 4 * base_channels, kernel_size=1, stride=1, padding=0), 556 | ResBlock(4 * base_channels), 557 | CrossAttentionBlock(4 * base_channels, condition_dim), 558 | ResBlock(4 * base_channels), 559 | CrossAttentionBlock(4 * base_channels, condition_dim), 560 | nn.ConvTranspose2d(4 * base_channels, 2 * base_channels, kernel_size=2, stride=2, padding=0), 561 | ), 562 | # stage 2 563 | TwoWaysSequential( 564 | nn.Conv2d(2 * 2 * base_channels, 2 * base_channels, kernel_size=1, stride=1, padding=0), 565 | ResBlock(2 * base_channels), 566 | CrossAttentionBlock(2 * base_channels, condition_dim), 567 | ResBlock(2 * base_channels), 568 | CrossAttentionBlock(2 * base_channels, condition_dim), 569 | nn.ConvTranspose2d(2 * base_channels, 1 * base_channels, kernel_size=2, stride=2, padding=0), 570 | ), 571 | # stage 1 572 | TwoWaysSequential( 573 | nn.Conv2d(2 * 1 * base_channels, 1 * base_channels, kernel_size=1, stride=1, padding=0), 574 | ResBlock(1 * base_channels), 575 | CrossAttentionBlock(1 * base_channels, condition_dim), 576 | ResBlock(1 * base_channels), 577 | CrossAttentionBlock(1 * base_channels, condition_dim), 578 | ), 579 | ]) 580 | self.conv_out = nn.Conv2d(base_channels, num_channels, kernel_size=3, stride=1, padding=1) 581 | 582 | def forward(self, inputs: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor: 583 | h = self.conv_in(inputs) 584 | skip_list = [] 585 | for m in self.encoder_list: 586 | h = m(h, conditions) 587 | skip_list.insert(0, h) 588 | h = self.middle(h, conditions) 589 | for m, skip in zip(self.decoder_list, skip_list): 590 | h = torch.concat([skip, h], dim=1) 591 | h = m(h, conditions) 592 | h = self.conv_out(h) 593 | return h 594 | 595 | 596 | def debug(): 597 | net = ConditionalUNet(3, condition_dim=512) 598 | xx = torch.rand([4, 3, 128, 128]) 599 | cc = torch.rand([4, 128, 512]) 600 | yy = net(xx, cc) 601 | print(yy.shape) 602 | 603 | 604 | if __name__ == '__main__': 605 | debug() 606 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.17.1 2 | diffusers==0.14.0 3 | filelock==3.10.0 4 | huggingface-hub==0.13.2 5 | packaging==23.0 6 | Pillow==9.4.0 7 | psutil==5.9.4 8 | PyYAML==6.0 9 | regex==2022.10.31 10 | requests==2.28.2 11 | sympy==1.11.1 12 | torch==1.13.1+cu117 13 | torchaudio==0.13.1+cu117 14 | torchvision==0.14.1+cu117 15 | tqdm==4.65.0 16 | typing_extensions==4.5.0 17 | urllib3==1.26.15 18 | zipp==3.15.0 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | """ 
5 | Support Python 3.8
6 | @author: Lou Xiao(louxiao@i32n.com)
7 | @maintainer: Lou Xiao(louxiao@i32n.com)
8 | @copyright: Copyright 2018~2023
9 | @created time: 2023-04-05 18:19:12 CST
10 | @updated time: 2023-04-05 18:19:12 CST
11 | """
12 | 
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as tnf
16 | from torch.optim import AdamW
17 | import torch.utils.data as tud
18 | 
19 | from diffusers.models.embeddings import get_timestep_embedding
20 | from diffusers.schedulers import DDPMScheduler
21 | 
22 | from models import AutoEncoderKL
23 | from models import AutoEncoderVQ
24 | from models import PatchGANDiscriminator
25 | from models import ConditionalUNet
26 | 
27 | 
28 | # Fake Dataset, just for demo.
29 | class FakeDataset(tud.Dataset):
30 | 
31 |     def __init__(self, src_shape=(17, 128, 128), dst_shape=(1, 128, 128)):
32 |         self.src_shape = src_shape
33 |         self.dst_shape = dst_shape
34 |         self.sample_count = 10000
35 | 
36 |     def __len__(self):
37 |         return self.sample_count
38 | 
39 |     def __getitem__(self, index: int):
40 |         xx = torch.rand(self.src_shape, dtype=torch.float32)
41 |         yy = torch.rand(self.dst_shape, dtype=torch.float32)
42 |         return xx, yy
43 | 
44 | 
45 | def train_vae():
46 |     default_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
47 |     default_type = torch.float32
48 | 
49 |     # init model VAE or VQ-VAE or VQ-GAN
50 |     vae = AutoEncoderKL(17, latent_dim=128)
51 |     # vae = AutoEncoderVQ(17, latent_dim=128)
52 |     # init weight (optional)
53 |     vae.to(device=default_device, dtype=default_type)
54 |     optimizer = AdamW(vae.parameters(), lr=1e-4, weight_decay=0.05)
55 | 
56 |     discriminator = PatchGANDiscriminator(17)
57 |     discriminator.to(device=default_device, dtype=default_type)
58 |     discriminator_optimizer = AdamW(discriminator.parameters(), lr=2e-4, weight_decay=0.05)
59 | 
60 |     # init dataset
61 |     ds = FakeDataset()
62 |     dl = tud.DataLoader(ds, batch_size=32, shuffle=True, drop_last=True, num_workers=0)
63 | 
64 |     vae.train()
65 |     discriminator.train()
66 |     for batch, (xx, yy) in enumerate(dl):
67 |         xx = xx.to(device=default_device, dtype=default_type)
68 |         # yy = yy.to(device=default_device, dtype=default_type)
69 | 
70 |         # train discriminator
71 |         discriminator_optimizer.zero_grad()
72 |         with torch.no_grad():
73 |             z = vae.encode(xx)
74 |             fake = vae.decode(z)
75 |         d_fake = discriminator(fake, xx).reshape(-1, 1)
76 |         d_fake_loss = tnf.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
77 |         d_real = discriminator(xx, xx).reshape(-1, 1)
78 |         d_real_loss = tnf.binary_cross_entropy(d_real, torch.ones_like(d_real))
79 |         d_loss = d_real_loss + d_fake_loss
80 |         d_loss.backward()
81 |         discriminator_optimizer.step()
82 | 
83 |         # train step
84 |         optimizer.zero_grad()
85 |         z, kl_loss = vae.encode(xx, return_loss=True)
86 |         # print("kl_loss:",kl_loss)
87 |         x_hat = vae.decode(z)
88 |         d_real = discriminator(x_hat, xx)
89 |         gan_loss = tnf.binary_cross_entropy(d_real, torch.ones_like(d_real))
90 |         mse_loss = tnf.mse_loss(x_hat, xx)
91 |         mse_loss = mse_loss / (mse_loss.detach() + 1e-6)  # rescale so the reconstruction term contributes ~1 regardless of its magnitude
92 |         loss = mse_loss + gan_loss + kl_loss
93 |         loss.backward()
94 |         optimizer.step()
95 |         loss = loss.detach().cpu().numpy()
96 |         print("Batch {:10d} | Loss: {:8.4f}".format(batch, loss))
97 | 
98 | 
99 | def make_conditions(timesteps: torch.Tensor, images: torch.Tensor = None, embedding_dim: int = 128) -> torch.Tensor:
100 |     assert timesteps.ndim == 1
101 | 
102 |     timestep_embedding = get_timestep_embedding(timesteps, embedding_dim, max_period=10000)  # [B, C]
103 | 
timestep_embedding = timestep_embedding[:, None, :] # [B, 1, C] 104 | 105 | if images is not None: 106 | assert images.shape[1] == embedding_dim 107 | img_embed = torch.flatten(images, 2) # [B, C, H*W] 108 | img_embed = torch.swapdims(img_embed, 1, 2) # [B, H*W, C] 109 | condition_embedding = torch.cat([timestep_embedding, img_embed], dim=1) # [B, 1+L, C] 110 | else: 111 | condition_embedding = timestep_embedding 112 | 113 | length = condition_embedding.shape[1] 114 | positions = torch.arange(0, length, device=condition_embedding.device, dtype=condition_embedding.dtype) 115 | position_embedding = get_timestep_embedding(positions, embedding_dim, max_period=10000) # [1+L, C] 116 | position_embedding = position_embedding[None, ...] # [1, L, C] 117 | condition_embedding += position_embedding 118 | 119 | return condition_embedding 120 | 121 | 122 | def train_diffusion(): 123 | default_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 124 | default_type = torch.float32 125 | 126 | # init model 127 | unet = ConditionalUNet(128, condition_dim=512) 128 | unet.to(device=default_device, dtype=default_type) 129 | optimizer = AdamW(unet.parameters(), lr=1e-4, weight_decay=0.05) 130 | 131 | # load encoder 132 | # Encoder Source domain 133 | src_vae = AutoEncoderVQ(17, latent_dim=512) # as condition 134 | src_vae.to(device=default_device, dtype=default_type) 135 | # load from checkpoint 136 | src_vae.requires_grad_(False) 137 | src_vae.eval() 138 | 139 | # Encoder Target domain 140 | tgt_vae = AutoEncoderVQ(1, latent_dim=128) 141 | tgt_vae.to(device=default_device, dtype=default_type) 142 | # load from checkpoint 143 | tgt_vae.requires_grad_(False) 144 | tgt_vae.eval() 145 | 146 | # init dataset 147 | ds = FakeDataset() 148 | dl = tud.DataLoader(ds, batch_size=32, shuffle=True, drop_last=True, num_workers=0) 149 | 150 | # init noise scheduler 151 | noise_scheduler = DDPMScheduler( 152 | num_train_timesteps=1000, 153 | beta_start=0.0001, 154 | beta_end=0.02, 155 | prediction_type="epsilon", 156 | clip_sample=False, 157 | ) 158 | 159 | unet.train() 160 | for batch, (xx, yy) in enumerate(dl): 161 | xx = xx.to(device=default_device, dtype=default_type) 162 | yy = yy.to(device=default_device, dtype=default_type) 163 | 164 | # train step 165 | optimizer.zero_grad() 166 | with torch.no_grad(): 167 | src_latent = src_vae.encode(xx) 168 | tgt_latent = tgt_vae.encode(yy) # x_0 169 | # make condition 170 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (tgt_latent.shape[0],), device=default_device, dtype=torch.long) 171 | if batch % 2 == 0: 172 | conditions = make_conditions(timesteps, src_latent, embedding_dim=512) 173 | else: 174 | conditions = make_conditions(timesteps, None, embedding_dim=512) 175 | # add noise 176 | noise = torch.randn_like(tgt_latent) 177 | tgt_latent = noise_scheduler.add_noise(tgt_latent, noise, timesteps) 178 | # learning noise 179 | outputs = unet(tgt_latent, conditions) 180 | loss = tnf.mse_loss(outputs, noise) 181 | loss.backward() 182 | optimizer.step() 183 | loss = loss.detach().cpu().numpy() 184 | print("Batch {:10d} | Loss: {:8.4f}".format(batch, loss)) 185 | 186 | 187 | def sampling_diffusion(): 188 | default_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 189 | default_type = torch.float32 190 | 191 | # loading model 192 | unet = ConditionalUNet(128, condition_dim=512) 193 | unet.to(device=default_device, dtype=default_type) 194 | # load from checkpoint 195 | unet.requires_grad_(False) 196 | 
unet.eval() 197 | 198 | # Encoder Source domain 199 | src_vae = AutoEncoderVQ(17, latent_dim=512) # as condition 200 | src_vae.to(device=default_device, dtype=default_type) 201 | # load from checkpoint 202 | src_vae.requires_grad_(False) 203 | src_vae.eval() 204 | 205 | # Encoder Target domain 206 | tgt_vae = AutoEncoderVQ(1, latent_dim=128) 207 | tgt_vae.to(device=default_device, dtype=default_type) 208 | # load from checkpoint 209 | tgt_vae.requires_grad_(False) 210 | tgt_vae.eval() 211 | 212 | # init dataset 213 | ds = FakeDataset() 214 | dl = tud.DataLoader(ds, batch_size=32, shuffle=True, drop_last=True, num_workers=0) 215 | 216 | # init noise scheduler 217 | noise_scheduler = DDPMScheduler( 218 | num_train_timesteps=1000, 219 | beta_start=0.0001, 220 | beta_end=0.02, 221 | prediction_type="epsilon", 222 | clip_sample=False, 223 | ) 224 | 225 | guidance_scale = 7.2 226 | noise_scheduler.set_timesteps(100) 227 | timesteps = noise_scheduler.timesteps 228 | print("sampling timesteps:", timesteps) 229 | with torch.no_grad(): 230 | for batch, (xx, _) in enumerate(dl): 231 | xx = xx.to(device=default_device, dtype=default_type) 232 | src_latent = src_vae.encode(xx) 233 | tgt_latent = torch.randn([xx.shape[0], 128, 16, 16], device=default_device, dtype=default_type) 234 | # sampling steps 235 | for t in timesteps.tolist(): 236 | ts = torch.full([xx.shape[0]], t, device=default_device, dtype=default_type) 237 | # conditional sampling 238 | conditions = make_conditions(ts, src_latent, embedding_dim=512) 239 | conditional_noise = unet(tgt_latent, conditions) 240 | # unconditional sampling 241 | conditions = make_conditions(ts, None, embedding_dim=512) 242 | unconditional_noise = unet(tgt_latent, conditions) 243 | # Classifier-Free Guidance 244 | noise = guidance_scale * conditional_noise + (1 - guidance_scale) * unconditional_noise 245 | tgt_latent = noise_scheduler.step(noise, t, tgt_latent).prev_sample 246 | # decode, get target sample. 247 | tgt = tgt_vae.decode(tgt_latent) 248 | print("target sample:", tgt.shape) 249 | 250 | 251 | def main(): 252 | # train_vae() 253 | # train_diffusion() 254 | sampling_diffusion() 255 | 256 | 257 | if __name__ == '__main__': 258 | main() 259 | --------------------------------------------------------------------------------
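A closing note on the `# load from checkpoint` placeholders that appear throughout `train.py`: the scripts assume pretrained weights for the VAEs and the U-Net but leave the persistence code out. Below is a minimal sketch of one way to fill those placeholders with plain `state_dict` files; the file name `unet.pt` is a made-up example, not a path that exists in this repository.

```python
import torch

from models import ConditionalUNet

CKPT_PATH = "unet.pt"  # hypothetical file name, used for illustration only

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# after training: save only the weights
unet = ConditionalUNet(128, condition_dim=512)
torch.save(unet.state_dict(), CKPT_PATH)

# before sampling: rebuild the same architecture, then restore the weights
unet = ConditionalUNet(128, condition_dim=512)
unet.load_state_dict(torch.load(CKPT_PATH, map_location=device))
unet.to(device=device, dtype=torch.float32)
unet.requires_grad_(False)
unet.eval()
```

The same pattern applies to `src_vae` and `tgt_vae` in `train_diffusion` and `sampling_diffusion`.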