├── .gitignore
├── Dockerfile
├── README.md
├── docker-compose.yaml
├── model.py
├── samples
│   ├── epoch:299_denoise_timesteps:1.png
│   ├── epoch:299_denoise_timesteps:128.png
│   ├── epoch:299_denoise_timesteps:16.png
│   ├── epoch:299_denoise_timesteps:2.png
│   ├── epoch:299_denoise_timesteps:4.png
│   └── epoch:299_denoise_timesteps:8.png
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | celeba-hq/
3 | log_images*
4 | log_images_tvanilla/
5 | lightning_logs/
6 | inception/
7 | tb_logs/
8 | wandb/
9 | dit_saved.pth
10 | 

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # CUDA 11.8 development image on Ubuntu 22.04
2 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
3 | 
4 | # Set environment variables to avoid interactive prompts
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | RUN apt-get update && \
7 |     apt-get install -y --no-install-recommends \
8 |     build-essential \
9 |     cmake \
10 |     curl \
11 |     ffmpeg \
12 |     git \
13 |     wget \
14 |     python3-pip \
15 |     python3-dev
16 | 
17 | # Install the Python dependencies
18 | RUN pip install --upgrade pip
19 | RUN pip install accelerate==1.2.0 \
20 |     && pip install diffusers==0.31.0 \
21 |     && pip install torch==2.5.1 \
22 |     && pip install timm==1.0.12 \
23 |     && pip install torchmetrics[image] \
24 |     && pip install matplotlib \
25 |     && pip install pandas==2.2.3 \
26 |     && pip install fastparquet==2024.11.0 \
27 |     && pip install pytorch-lightning \
28 |     && pip install tensorboard \
29 |     && pip install wandb
30 | 
31 | # 1. visualize logs (on the server):
32 | #    tensorboard --logdir=tb_logs/shortcut_model --port 6006
33 | # 2. on your own PC:
34 | #    ssh -N -f -L localhost:16006:localhost:6006 r.khafizov@10.16.88.93
35 | # if the port is already in use:
36 | #    1. sudo netstat -tulpn | grep :16006
37 | #    2. kill the observed process id: kill <pid>

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ## One-Step Diffusion via Shortcut Models
3 | 
4 | This is an unofficial PyTorch implementation. The original JAX implementation can be found here: https://github.com/kvfrans/shortcut-models
5 | 
6 | At the moment this is a fairly basic implementation with two training modes: naive (plain flow matching) and shortcut.
7 | 
8 | The implementation supports multi-GPU training via PyTorch Lightning.
9 | 
10 | ## Data
11 | 
12 | I used the CelebA-HQ dataset from Hugging Face for the image generation task: https://huggingface.co/datasets/mattymchen/celeba-hq
13 | 
14 | ## Using the code
15 | 
16 | There is a helpful Dockerfile and docker-compose file in this repository which install all necessary libraries.
17 | 
18 | To start training, run:
19 | 
20 | ```
21 | python train.py
22 | ```
23 | 
24 | ## Results (300 epochs of shortcut training)
25 | 
26 | 1 denoising step:
27 | 
28 | <div align="center">
29 |     <img src="samples/epoch:299_denoise_timesteps:1.png" alt="Showcase Figure"/>
30 | </div>
31 | 
32 | 2 denoising steps:
33 | 
34 | <div align="center">
35 |     <img src="samples/epoch:299_denoise_timesteps:2.png" alt="Showcase Figure"/>
36 | </div>
37 | 
38 | 4 denoising steps:
39 | 
40 | <div align="center">
41 |     <img src="samples/epoch:299_denoise_timesteps:4.png" alt="Showcase Figure"/>
42 | </div>
43 | 
44 | 8 denoising steps:
45 | 
46 | <div align="center">
47 |     <img src="samples/epoch:299_denoise_timesteps:8.png" alt="Showcase Figure"/>
48 | </div>
49 | 
50 | 16 denoising steps:
51 | 
52 | <div align="center">
53 |     <img src="samples/epoch:299_denoise_timesteps:16.png" alt="Showcase Figure"/>
54 | </div>
55 | 
56 | 128 denoising steps:
57 | 
58 | <div align="center">
59 |     <img src="samples/epoch:299_denoise_timesteps:128.png" alt="Showcase Figure"/>
60 | </div>
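
For reference, the training mode ("shortcut" vs. "naive" flow matching) is selected when the model is constructed; a minimal sketch, mirroring the defaults in `train.py`:

```python
from model import DiT_B_2

# "shortcut" trains on a mix of flow-matching and bootstrap self-consistency
# targets (utils.create_targets); "naive" uses plain flow matching only
# (utils.create_targets_naive).
dit = DiT_B_2(
    learn_sigma=False,
    num_classes=1,                  # NUM_CLASSES
    class_dropout_prob=1.0,         # CLASS_DROPOUT_PROB
    lightning_mode=True,
    latent_shape=(64, 4, 32, 32),   # (BATCH_SIZE, 4, H/8, W/8) for the SD VAE
    training_type="shortcut",       # or "naive"
)
```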
-------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | shortcut_pytorch: 3 | image: rkhafizov.shortcut_pytorch_image 4 | build: 5 | context: . 6 | dockerfile: Dockerfile 7 | container_name: rkhafizov.shortcut_pytorch_container 8 | network_mode: host 9 | ipc: host 10 | volumes: 11 | - /home/r.khafizov/shortcut_pytorch:/workspace/shortcut_pytorch 12 | - /home/r.khafizov/shortcut_pytorch/inception:/root/.cache/torch/hub/checkpoints 13 | ports: 14 | - "6666:6666" # Adjust the port mapping as needed 15 | environment: 16 | - NVIDIA_VISIBLE_DEVICES=all # Adjust GPU visibility as needed 17 | command: "/bin/bash -c 'source /etc/bash.bashrc && tail -f /dev/null && /bin/bash'" # Keep container running 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | # count: 1 24 | device_ids: ["0, 2, 3, 5"] 25 | capabilities: [gpu] 26 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | from os import makedirs 18 | 19 | import pytorch_lightning as pl 20 | from pytorch_lightning.utilities.types import STEP_OUTPUT 21 | 22 | 23 | from torchmetrics.image.fid import FrechetInceptionDistance 24 | import matplotlib.pyplot as plt 25 | 26 | from diffusers.models import AutoencoderKL 27 | 28 | from utils import create_targets, create_targets_naive 29 | 30 | from utils import create_targets 31 | from typing import Dict, Any 32 | from typing import Iterable, Optional 33 | 34 | import weakref 35 | import copy 36 | import contextlib 37 | import wandb 38 | 39 | def modulate(x, shift, scale): 40 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 41 | 42 | LEARNING_RATE = 0.0001 43 | EVAL_SIZE = 8 44 | NUM_CLASSES = 1 45 | CFG_SCALE = 0.0 46 | 47 | def modulate(x, shift, scale): 48 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 49 | 50 | def to_float_maybe(x): 51 | return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x 52 | 53 | 54 | class ExponentialMovingAverage: 55 | """ 56 | Maintains (exponential) moving average of a set of parameters. 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter` (typically from 59 | `model.parameters()`). 60 | decay: The exponential decay. 61 | use_num_updates: Whether to use number of updates when computing 62 | averages. 
63 | """ 64 | def __init__( 65 | self, 66 | parameters: Iterable[torch.nn.Parameter], 67 | decay: float, 68 | use_num_updates: bool = True 69 | ): 70 | if decay < 0.0 or decay > 1.0: 71 | raise ValueError('Decay must be between 0 and 1') 72 | self.decay = decay 73 | self.num_updates = 0 if use_num_updates else None 74 | parameters = list(parameters) 75 | self.shadow_params = [to_float_maybe(p.clone().detach()) 76 | for p in parameters if p.requires_grad] 77 | self.collected_params = None 78 | # By maintaining only a weakref to each parameter, 79 | # we maintain the old GC behaviour of ExponentialMovingAverage: 80 | # if the model goes out of scope but the ExponentialMovingAverage 81 | # is kept, no references to the model or its parameters will be 82 | # maintained, and the model will be cleaned up. 83 | self._params_refs = [weakref.ref(p) for p in parameters] 84 | 85 | def _get_parameters( 86 | self, 87 | parameters: Optional[Iterable[torch.nn.Parameter]] 88 | ) -> Iterable[torch.nn.Parameter]: 89 | if parameters is None: 90 | parameters = [p() for p in self._params_refs] 91 | if any(p is None for p in parameters): 92 | raise ValueError( 93 | "(One of) the parameters with which this " 94 | "ExponentialMovingAverage " 95 | "was initialized no longer exists (was garbage collected);" 96 | " please either provide `parameters` explicitly or keep " 97 | "the model to which they belong from being garbage " 98 | "collected." 99 | ) 100 | return parameters 101 | else: 102 | parameters = list(parameters) 103 | if len(parameters) != len(self.shadow_params): 104 | raise ValueError( 105 | "Number of parameters passed as argument is different " 106 | "from number of shadow parameters maintained by this " 107 | "ExponentialMovingAverage" 108 | ) 109 | return parameters 110 | 111 | def update( 112 | self, 113 | parameters: Optional[Iterable[torch.nn.Parameter]] = None 114 | ) -> None: 115 | """ 116 | Update currently maintained parameters. 117 | Call this every time the parameters are updated, such as the result of 118 | the `optimizer.step()` call. 119 | Args: 120 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 121 | parameters used to initialize this object. If `None`, the 122 | parameters with which this `ExponentialMovingAverage` was 123 | initialized will be used. 124 | """ 125 | parameters = self._get_parameters(parameters) 126 | decay = self.decay 127 | if self.num_updates is not None: 128 | self.num_updates += 1 129 | decay = min( 130 | decay, 131 | (1 + self.num_updates) / (10 + self.num_updates) 132 | ) 133 | one_minus_decay = 1.0 - decay 134 | if parameters[0].device != self.shadow_params[0].device: 135 | self.to(device=parameters[0].device) 136 | with torch.no_grad(): 137 | parameters = [p for p in parameters if p.requires_grad] 138 | for s_param, param in zip(self.shadow_params, parameters): 139 | torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param) 140 | 141 | def copy_to( 142 | self, 143 | parameters: Optional[Iterable[torch.nn.Parameter]] = None 144 | ) -> None: 145 | """ 146 | Copy current averaged parameters into given collection of parameters. 147 | Args: 148 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 149 | updated with the stored moving averages. If `None`, the 150 | parameters with which this `ExponentialMovingAverage` was 151 | initialized will be used. 
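        A typical evaluation pattern (sketch) is:
            ema.store(); ema.copy_to()    # swap averaged weights in
            ...                           # validate / save the model
            ema.restore()                 # swap the originals back
        or, equivalently, the `average_parameters` context manager below.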
152 | """ 153 | parameters = self._get_parameters(parameters) 154 | for s_param, param in zip(self.shadow_params, parameters): 155 | if param.requires_grad: 156 | param.data.copy_(s_param.data) 157 | 158 | def store( 159 | self, 160 | parameters: Optional[Iterable[torch.nn.Parameter]] = None 161 | ) -> None: 162 | """ 163 | Save the current parameters for restoring later. 164 | Args: 165 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 166 | temporarily stored. If `None`, the parameters of with which this 167 | `ExponentialMovingAverage` was initialized will be used. 168 | """ 169 | parameters = self._get_parameters(parameters) 170 | self.collected_params = [ 171 | param.clone() 172 | for param in parameters 173 | if param.requires_grad 174 | ] 175 | 176 | def restore( 177 | self, 178 | parameters: Optional[Iterable[torch.nn.Parameter]] = None 179 | ) -> None: 180 | """ 181 | Restore the parameters stored with the `store` method. 182 | Useful to validate the model with EMA parameters without affecting the 183 | original optimization process. Store the parameters before the 184 | `copy_to` method. After validation (or model saving), use this to 185 | restore the former parameters. 186 | Args: 187 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 188 | updated with the stored parameters. If `None`, the 189 | parameters with which this `ExponentialMovingAverage` was 190 | initialized will be used. 191 | """ 192 | if self.collected_params is None: 193 | raise RuntimeError( 194 | "This ExponentialMovingAverage has no `store()`ed weights " 195 | "to `restore()`" 196 | ) 197 | parameters = self._get_parameters(parameters) 198 | for c_param, param in zip(self.collected_params, parameters): 199 | if param.requires_grad: 200 | param.data.copy_(c_param.data) 201 | 202 | @contextlib.contextmanager 203 | def average_parameters( 204 | self, 205 | parameters: Optional[Iterable[torch.nn.Parameter]] = None 206 | ): 207 | r""" 208 | Context manager for validation/inference with averaged parameters. 209 | Equivalent to: 210 | ema.store() 211 | ema.copy_to() 212 | try: 213 | ... 214 | finally: 215 | ema.restore() 216 | Args: 217 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 218 | updated with the stored parameters. If `None`, the 219 | parameters with which this `ExponentialMovingAverage` was 220 | initialized will be used. 221 | """ 222 | parameters = self._get_parameters(parameters) 223 | self.store(parameters) 224 | self.copy_to(parameters) 225 | try: 226 | yield 227 | finally: 228 | self.restore(parameters) 229 | 230 | def to(self, device=None, dtype=None) -> None: 231 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 232 | Args: 233 | device: like `device` argument to `torch.Tensor.to` 234 | """ 235 | # .to() on the tensors handles None correctly 236 | self.shadow_params = [ 237 | p.to(device=device, dtype=dtype) 238 | if p.is_floating_point() 239 | else p.to(device=device) 240 | for p in self.shadow_params 241 | ] 242 | if self.collected_params is not None: 243 | self.collected_params = [ 244 | p.to(device=device, dtype=dtype) 245 | if p.is_floating_point() 246 | else p.to(device=device) 247 | for p in self.collected_params 248 | ] 249 | return 250 | 251 | def state_dict(self) -> dict: 252 | r"""Returns the state of the ExponentialMovingAverage as a dict.""" 253 | # Following PyTorch conventions, references to tensors are returned: 254 | # "returns a reference to the state and not its copy!" 
- 255 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 256 | return { 257 | "decay": self.decay, 258 | "num_updates": self.num_updates, 259 | "shadow_params": self.shadow_params, 260 | "collected_params": self.collected_params 261 | } 262 | 263 | def load_state_dict(self, state_dict: dict) -> None: 264 | r"""Loads the ExponentialMovingAverage state. 265 | Args: 266 | state_dict (dict): EMA state. Should be an object returned 267 | from a call to :meth:`state_dict`. 268 | """ 269 | # deepcopy, to be consistent with module API 270 | state_dict = copy.deepcopy(state_dict) 271 | self.decay = state_dict["decay"] 272 | if self.decay < 0.0 or self.decay > 1.0: 273 | raise ValueError('Decay must be between 0 and 1') 274 | self.num_updates = state_dict["num_updates"] 275 | assert self.num_updates is None or isinstance(self.num_updates, int), \ 276 | "Invalid num_updates" 277 | 278 | self.shadow_params = state_dict["shadow_params"] 279 | assert isinstance(self.shadow_params, list), \ 280 | "shadow_params must be a list" 281 | assert all( 282 | isinstance(p, torch.Tensor) for p in self.shadow_params 283 | ), "shadow_params must all be Tensors" 284 | 285 | self.collected_params = state_dict["collected_params"] 286 | if self.collected_params is not None: 287 | assert isinstance(self.collected_params, list), \ 288 | "collected_params must be a list" 289 | assert all( 290 | isinstance(p, torch.Tensor) for p in self.collected_params 291 | ), "collected_params must all be Tensors" 292 | assert len(self.collected_params) == len(self.shadow_params), \ 293 | "collected_params and shadow_params had different lengths" 294 | 295 | if len(self.shadow_params) == len(self._params_refs): 296 | # Consistent with torch.optim.Optimizer, cast things to consistent 297 | # device and dtype with the parameters 298 | params = [p() for p in self._params_refs] 299 | # If parameters have been garbage collected, just load the state 300 | # we were given without change. 301 | if not any(p is None for p in params): 302 | # ^ parameter references are still good 303 | for i, p in enumerate(params): 304 | self.shadow_params[i] = to_float_maybe(self.shadow_params[i].to( 305 | device=p.device, dtype=p.dtype 306 | )) 307 | if self.collected_params is not None: 308 | self.collected_params[i] = self.collected_params[i].to( 309 | device=p.device, dtype=p.dtype 310 | ) 311 | else: 312 | raise ValueError( 313 | "Tried to `load_state_dict()` with the wrong number of " 314 | "parameters in the saved state." 315 | ) 316 | 317 | 318 | class EMACallback(pl.Callback): 319 | """TD [2021-08-31]: saving and loading from checkpoint should work. 320 | """ 321 | def __init__(self, decay: float, use_num_updates: bool = True): 322 | """ 323 | decay: The exponential decay. 324 | use_num_updates: Whether to use number of updates when computing 325 | averages. 326 | """ 327 | super().__init__() 328 | self.decay = decay 329 | self.use_num_updates = use_num_updates 330 | self.ema = None 331 | 332 | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): 333 | # It's possible that we already loaded EMA from the checkpoint 334 | if self.ema is None: 335 | self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], 336 | decay=self.decay, use_num_updates=self.use_num_updates) 337 | 338 | # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it 339 | # We only want to update when parameters are changing. 
340 | # Because of gradient accumulation, this doesn't happen every training step. 341 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688 342 | def on_train_batch_end( 343 | self, 344 | trainer: "pl.Trainer", 345 | pl_module: "pl.LightningModule", 346 | outputs: STEP_OUTPUT, 347 | batch: Any, 348 | batch_idx: int, 349 | ) -> None: 350 | if (batch_idx + 1) % trainer.accumulate_grad_batches == 0: 351 | self.ema.update() 352 | 353 | def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 354 | # During the initial validation we don't have self.ema yet 355 | if self.ema is not None: 356 | self.ema.store() 357 | self.ema.copy_to() 358 | 359 | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 360 | if self.ema is not None: 361 | self.ema.restore() 362 | 363 | def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 364 | if self.ema is not None: 365 | self.ema.store() 366 | self.ema.copy_to() 367 | 368 | def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 369 | if self.ema is not None: 370 | self.ema.restore() 371 | 372 | def on_save_checkpoint( 373 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] 374 | ) -> Dict[str, Any]: 375 | return self.ema.state_dict() 376 | 377 | def on_load_checkpoint( 378 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 379 | checkpoint: Dict[str, Any] 380 | ) -> None: 381 | if self.ema is None: 382 | self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], 383 | decay=self.decay, use_num_updates=self.use_num_updates) 384 | self.ema.load_state_dict(checkpoint) 385 | 386 | 387 | ################################################################################# 388 | # Embedding Layers for Timesteps and Class Labels # 389 | ################################################################################# 390 | 391 | class TimestepEmbedder(pl.LightningModule): 392 | """ 393 | Embeds scalar timesteps into vector representations. 394 | """ 395 | def __init__(self, hidden_size, frequency_embedding_size=256): 396 | super().__init__() 397 | self.mlp = nn.Sequential( 398 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 399 | nn.SiLU(), 400 | nn.Linear(hidden_size, hidden_size, bias=True), 401 | ) 402 | self.frequency_embedding_size = frequency_embedding_size 403 | 404 | @staticmethod 405 | def timestep_embedding(t, dim, max_period=10000): 406 | """ 407 | Create sinusoidal timestep embeddings. 408 | :param t: a 1-D Tensor of N indices, one per batch element. 409 | These may be fractional. 410 | :param dim: the dimension of the output. 411 | :param max_period: controls the minimum frequency of the embeddings. 412 | :return: an (N, D) Tensor of positional embeddings. 
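        Example (illustrative): TimestepEmbedder.timestep_embedding(
            torch.tensor([0., 0.5, 1.]), dim=256) returns a (3, 256) tensor
            laid out as [cos(args) | sin(args)].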
413 | """ 414 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 415 | half = dim // 2 416 | freqs = torch.exp( 417 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 418 | ).to(device=t.device) 419 | 420 | args = t[:, None].float() * freqs[None] 421 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 422 | if dim % 2: 423 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 424 | return embedding 425 | 426 | def forward(self, t): 427 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 428 | t_emb = self.mlp(t_freq) 429 | return t_emb 430 | 431 | 432 | class LabelEmbedder(pl.LightningModule): 433 | """ 434 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 435 | """ 436 | def __init__(self, num_classes, hidden_size, dropout_prob): 437 | super().__init__() 438 | use_cfg_embedding = dropout_prob > 0 439 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 440 | self.num_classes = num_classes 441 | self.dropout_prob = dropout_prob 442 | 443 | def token_drop(self, labels, force_drop_ids=None): 444 | """ 445 | Drops labels to enable classifier-free guidance. 446 | """ 447 | if force_drop_ids is None: 448 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 449 | else: 450 | drop_ids = force_drop_ids == 1 451 | labels = torch.where(drop_ids, self.num_classes, labels) 452 | return labels 453 | 454 | def forward(self, labels, train, force_drop_ids=None): 455 | use_dropout = self.dropout_prob > 0 456 | if (train and use_dropout) or (force_drop_ids is not None): 457 | labels = self.token_drop(labels, force_drop_ids) 458 | embeddings = self.embedding_table(labels) 459 | return embeddings 460 | 461 | 462 | ################################################################################# 463 | # Core DiT Model # 464 | ################################################################################# 465 | 466 | class DiTBlock(pl.LightningModule): 467 | """ 468 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 469 | """ 470 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 471 | super().__init__() 472 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 473 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 474 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 475 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 476 | approx_gelu = lambda: nn.GELU(approximate="tanh") 477 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 478 | self.adaLN_modulation = nn.Sequential( 479 | nn.SiLU(), 480 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 481 | ) 482 | 483 | def forward(self, x, c): 484 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 485 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 486 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 487 | return x 488 | 489 | 490 | class FinalLayer(pl.LightningModule): 491 | """ 492 | The final layer of DiT. 
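    The conditioning vector produces (shift, scale) via adaLN_modulation,
        i.e. out = norm(x) * (1 + scale) + shift, followed by a linear map to
        patch_size**2 * out_channels per token; both layers are zero-initialized
        in initialize_weights(), so the model initially outputs zeros.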
493 | """ 494 | def __init__(self, hidden_size, patch_size, out_channels): 495 | super().__init__() 496 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 497 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 498 | self.adaLN_modulation = nn.Sequential( 499 | nn.SiLU(), 500 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 501 | ) 502 | 503 | def forward(self, x, c): 504 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 505 | x = modulate(self.norm_final(x), shift, scale) 506 | x = self.linear(x) 507 | return x 508 | 509 | class DiT(pl.LightningModule): 510 | """ 511 | Diffusion model with a Transformer backbone. 512 | """ 513 | def __init__( 514 | self, 515 | input_size=32, 516 | patch_size=2, 517 | in_channels=4, 518 | hidden_size=1152, 519 | depth=28, 520 | num_heads=16, 521 | mlp_ratio=4.0, 522 | class_dropout_prob=0.1, 523 | num_classes=1000, 524 | learn_sigma=True, 525 | lightning_mode=False, 526 | latent_shape=None, 527 | training_type="shortcut" 528 | ): 529 | super().__init__() 530 | self.learn_sigma = learn_sigma 531 | self.in_channels = in_channels 532 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 533 | self.patch_size = patch_size 534 | self.num_heads = num_heads 535 | 536 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 537 | self.t_embedder = TimestepEmbedder(hidden_size) 538 | self.dt_embedder = TimestepEmbedder(hidden_size) 539 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 540 | num_patches = self.x_embedder.num_patches 541 | # Will use fixed sin-cos embedding: 542 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 543 | 544 | self.blocks = nn.ModuleList([ 545 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 546 | ]) 547 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 548 | self.initialize_weights() 549 | 550 | if lightning_mode: 551 | self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device) 552 | self.vae = self.vae.eval() 553 | self.vae.requires_grad_(False) 554 | 555 | self.eps = torch.randn(latent_shape).to(self.device) 556 | 557 | self.training_type = training_type 558 | 559 | # number of denoising steps to be applied 560 | self.denoise_timesteps = [1, 2, 4, 8, 16, 32, 128] 561 | 562 | # self.fid = FrechetInceptionDistance().to(self.device) 563 | 564 | self.fids = None 565 | self.validation_step_outputs = [] 566 | 567 | makedirs("log_images3", exist_ok=True) 568 | 569 | 570 | def on_fit_start(self): 571 | 572 | self.fids = [FrechetInceptionDistance().to(self.device) for _ in range(len(self.denoise_timesteps))] 573 | 574 | 575 | def initialize_weights(self): 576 | # Initialize transformer layers: 577 | def _basic_init(module): 578 | if isinstance(module, nn.Linear): 579 | torch.nn.init.xavier_uniform_(module.weight) 580 | if module.bias is not None: 581 | nn.init.constant_(module.bias, 0) 582 | self.apply(_basic_init) 583 | 584 | # Initialize (and freeze) pos_embed by sin-cos embedding: 585 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 586 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 587 | 588 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 589 | w = self.x_embedder.proj.weight.data 590 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 591 | 
nn.init.constant_(self.x_embedder.proj.bias, 0) 592 | 593 | # Initialize label embedding table: 594 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 595 | 596 | # Initialize timestep embedding MLP: 597 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 598 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 599 | 600 | nn.init.normal_(self.dt_embedder.mlp[0].weight, std=0.02) 601 | nn.init.normal_(self.dt_embedder.mlp[2].weight, std=0.02) 602 | 603 | # Zero-out adaLN modulation layers in DiT blocks: 604 | for block in self.blocks: 605 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 606 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 607 | 608 | # Zero-out output layers: 609 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 610 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 611 | nn.init.constant_(self.final_layer.linear.weight, 0) 612 | nn.init.constant_(self.final_layer.linear.bias, 0) 613 | 614 | self.loss_fn = torch.nn.MSELoss() 615 | 616 | def unpatchify(self, x): 617 | """ 618 | x: (N, T, patch_size**2 * C) 619 | imgs: (N, H, W, C) 620 | """ 621 | c = self.out_channels 622 | p = self.x_embedder.patch_size[0] 623 | h = w = int(x.shape[1] ** 0.5) 624 | assert h * w == x.shape[1] 625 | 626 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 627 | x = torch.einsum('nhwpqc->nchpwq', x) 628 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 629 | return imgs 630 | 631 | def forward(self, x, t, dt, y): 632 | """ 633 | Forward pass of DiT. 634 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 635 | t: (N,) tensor of diffusion timesteps 636 | y: (N,) tensor of class labels 637 | """ 638 | 639 | if self.training_type=="naive": 640 | dt = torch.zeros_like(t) 641 | 642 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 643 | t = self.t_embedder(t) # (N, D) 644 | dt = self.dt_embedder(dt) 645 | y = self.y_embedder(y, self.training) # (N, D) 646 | # print(f"t.shape: {t.shape}") 647 | # print(f"y.shape: {y.shape}") 648 | # print(f"dt.shape: {dt.shape}") 649 | # print(f"x.shape: {x.shape}") 650 | c = t + y + dt # (N, D) 651 | for block in self.blocks: 652 | x = block(x, c) # (N, T, D) 653 | 654 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 655 | x = self.unpatchify(x) # (N, out_channels, H, W) 656 | return x 657 | 658 | def training_step(self, batch, batch_idx): 659 | 660 | images, labels = batch 661 | 662 | labels = torch.ones_like(labels) 663 | 664 | with torch.no_grad(): 665 | latents = self.vae.encode(images).latent_dist.sample() 666 | latents = latents * self.vae.config.scaling_factor 667 | 668 | if self.training_type=="naive": 669 | x_t, v_t, t, dt_base, labels_dropped = create_targets_naive(latents, labels, self) 670 | elif self.training_type=="shortcut": 671 | x_t, v_t, t, dt_base, labels_dropped = create_targets(latents, labels, self) 672 | 673 | v_prime = self.forward(x_t, t, dt_base, labels) 674 | 675 | loss = self.loss_fn(v_prime, v_t) 676 | self.log("train_loss", loss, on_epoch=True, on_step=False, sync_dist=True) 677 | 678 | return loss 679 | 680 | def validation_step(self, batch, batch_idx): 681 | 682 | images, labels_real = batch 683 | 684 | labels_uncond = torch.ones_like(labels_real, dtype=torch.int32) * NUM_CLASSES 685 | 686 | 687 | with torch.no_grad(): 688 | latents = self.vae.encode(images).latent_dist.sample() 689 | latents = latents * self.vae.config.scaling_factor 690 | 691 | 692 | # 
normalize to [0,255] range 693 | images = 255 * ((images - torch.min(images)) / (torch.max(images) - torch.min(images) + 1e-8)) 694 | 695 | # sample noise 696 | eps_i = torch.randn_like(latents).to(self.device) 697 | 698 | # denoise_timesteps_list = [1, 2, 4, 8, 16, 32, 128] 699 | 700 | for i, denoise_timesteps in enumerate(self.denoise_timesteps): 701 | 702 | all_x = [] 703 | delta_t = 1.0 / denoise_timesteps # i.e. step size 704 | # self.fid.reset() 705 | 706 | x = eps_i.to(self.device) 707 | 708 | for ti in range(denoise_timesteps): 709 | # t should in range [0,1] 710 | t = ti / denoise_timesteps 711 | 712 | t_vector = torch.full((eps_i.shape[0],), t).to(self.device) 713 | # t_vector = torch.full((eps.shape[0],), t) 714 | dt_base = torch.ones_like(t_vector).to(self.device) * math.log2(denoise_timesteps) 715 | # dt_base = torch.ones_like(t_vector) * math.log2(denoise_timesteps) 716 | 717 | 718 | # only on last step! 719 | # if i == len(denoise_timesteps_list)-1: 720 | # with torch.no_grad(): 721 | # v_cond = self.forward(eps, t_vector, dt_base, labels_real) 722 | # v_uncond = self.forward(eps, t_vector, dt_base, labels_uncond) 723 | 724 | # v = v_uncond + CFG_SCALE * (v_cond - v_uncond) 725 | # else: 726 | # # t is same for all latents 727 | # with torch.no_grad(): 728 | # v = self.forward(eps, t_vector, dt_base, labels_real) 729 | 730 | with torch.no_grad(): 731 | v = self.forward(x, t_vector, dt_base, labels_real) 732 | 733 | x = x + v*delta_t 734 | 735 | # log 8 steps 736 | if denoise_timesteps <= 8 or ti % (denoise_timesteps//8) ==0 or ti == denoise_timesteps-1: 737 | with torch.no_grad(): 738 | decoded = self.vae.decode(x/self.vae.config.scaling_factor)[0] 739 | 740 | decoded = decoded.to("cpu") 741 | 742 | all_x.append(decoded) 743 | 744 | if(len(all_x)==9): 745 | all_x = all_x[1:] 746 | 747 | # estimate FID metric 748 | decoded_denormalized = 255 * ((decoded - torch.min(decoded)) / (torch.max(decoded)-torch.min(decoded)+1e-8)) 749 | 750 | # generated images 751 | self.fids[i].update(images.to(torch.uint8).to(self.device), real=True) 752 | self.fids[i].update(decoded_denormalized.to(torch.uint8).to(self.device), real=False) 753 | 754 | 755 | # fid_val = fid.compute() 756 | 757 | # log only a single batch of generated images and only on first device 758 | if self.trainer.is_global_zero and batch_idx == 0: 759 | 760 | all_x = torch.stack(all_x) 761 | 762 | def process_img(img): 763 | # normalize in range [0,1] 764 | img = img*0.5 + 0.5 765 | img = torch.clip(img, 0, 1) 766 | img = img.permute(1,2,0) 767 | return img 768 | 769 | fig, axs = plt.subplots(8, 8, figsize=(30,30)) 770 | for t in range(min(8, all_x.shape[0])): 771 | for j in range(8): 772 | axs[t, j].imshow(process_img(all_x[t, j]), vmin=0, vmax=1) 773 | 774 | 775 | fig.savefig(f"log_images3/epoch:{self.trainer.current_epoch}_denoise_timesteps:{denoise_timesteps}.png") 776 | 777 | # self.logger.experiment.add_figure(f"epoch:{self.trainer.current_epoch}_denoise_timesteps:{denoise_timesteps}", fig, global_step=self.global_step) 778 | # self.logger.experiment.add_figure(f"denoise_timesteps:{denoise_timesteps}", fig, global_step=self.global_step) 779 | 780 | self.logger.experiment.log({f"denoise_timesteps:{denoise_timesteps}" : [wandb.Image(fig)]}) 781 | # log_image(key=f"denoise_timesteps:{denoise_timesteps}", images=wandb.Image(fig)) 782 | 783 | plt.close() 784 | 785 | return 0 786 | 787 | 788 | # def validation_epoch_start(self,): 789 | 790 | # if self.trainer.is_global_zero: 791 | 792 | # self.fid.reset() 793 | 794 | def 
on_validation_epoch_end(self):
795 |         # if self.trainer.is_global_zero:
796 |         for i in range(len(self.fids)):
797 |             denoise_timesteps_i = self.denoise_timesteps[i]
798 | 
799 |             # Compute FID for the current timestep
800 |             fid_val_i = self.fids[i].compute()
801 |             self.fids[i].reset()
802 | 
803 |             # print(f"i: {i} | fid_val_i: {fid_val_i}")
804 | 
805 |             # Log the FID value
806 |             self.log(f"[FID] denoise_steps: {denoise_timesteps_i}", fid_val_i, on_epoch=True, on_step=False, sync_dist=True)
807 | 
808 | 
809 |     # def on_validation_epoch_end(self, outputs):
810 | 
811 |     #     if self.trainer.is_global_zero:
812 | 
813 |     #         for i in range(len(self.fids)):
814 | 
815 |     #             denoise_timesteps_i = self.denoise_timesteps[i]
816 | 
817 |     #             fid_val_i = self.fids[i].compute()
818 | 
819 |     #             self.fids[i].reset()
820 | 
821 |     #             self.log(f"steps: {denoise_timesteps_i}", fid_val_i, on_epoch=True, on_step=False)
822 | 
823 |     # fid_val = self.fid.compute()
824 |     # self.fid.reset()
825 | 
826 |     def configure_optimizers(self):
827 | 
828 |         optimizer = torch.optim.AdamW(self.parameters(), lr=LEARNING_RATE, weight_decay=0.1)
829 | 
830 |         return optimizer
831 | 
832 | 
833 | 
834 |     def forward_with_cfg(self, x, t, dt, y, cfg_scale):
835 |         """
836 |         Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
837 |         """
838 |         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
839 |         half = x[: len(x) // 2]
840 |         combined = torch.cat([half, half], dim=0)
841 |         model_out = self.forward(combined, t, dt, y)
842 |         # For exact reproducibility reasons, we apply classifier-free guidance on only
843 |         # three channels by default. The standard approach to cfg applies it to all channels.
844 |         # To do that, swap which of the two lines below is commented out.
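        # Either way, the guided output is formed as
        #     eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond):
        # cfg_scale = 0 is fully unconditional, cfg_scale = 1 recovers the plain
        # conditional prediction, and larger values extrapolate beyond it.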
845 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 846 | eps, rest = model_out[:, :3], model_out[:, 3:] 847 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 848 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 849 | eps = torch.cat([half_eps, half_eps], dim=0) 850 | return torch.cat([eps, rest], dim=1) 851 | 852 | 853 | ################################################################################# 854 | # Sine/Cosine Positional Embedding Functions # 855 | ################################################################################# 856 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 857 | 858 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 859 | """ 860 | grid_size: int of the grid height and width 861 | return: 862 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 863 | """ 864 | grid_h = np.arange(grid_size, dtype=np.float32) 865 | grid_w = np.arange(grid_size, dtype=np.float32) 866 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 867 | grid = np.stack(grid, axis=0) 868 | 869 | grid = grid.reshape([2, 1, grid_size, grid_size]) 870 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 871 | if cls_token and extra_tokens > 0: 872 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 873 | return pos_embed 874 | 875 | 876 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 877 | assert embed_dim % 2 == 0 878 | 879 | # use half of dimensions to encode grid_h 880 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 881 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 882 | 883 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 884 | return emb 885 | 886 | 887 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 888 | """ 889 | embed_dim: output dimension for each position 890 | pos: a list of positions to be encoded: size (M,) 891 | out: (M, D) 892 | """ 893 | assert embed_dim % 2 == 0 894 | omega = np.arange(embed_dim // 2, dtype=np.float64) 895 | omega /= embed_dim / 2. 896 | omega = 1. 
/ 10000**omega # (D/2,) 897 | 898 | pos = pos.reshape(-1) # (M,) 899 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 900 | 901 | emb_sin = np.sin(out) # (M, D/2) 902 | emb_cos = np.cos(out) # (M, D/2) 903 | 904 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 905 | return emb 906 | 907 | 908 | ################################################################################# 909 | # DiT Configs # 910 | ################################################################################# 911 | 912 | def DiT_XL_2(**kwargs): 913 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 914 | 915 | def DiT_XL_4(**kwargs): 916 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 917 | 918 | def DiT_XL_8(**kwargs): 919 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 920 | 921 | def DiT_L_2(**kwargs): 922 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 923 | 924 | def DiT_L_4(**kwargs): 925 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 926 | 927 | def DiT_L_8(**kwargs): 928 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 929 | 930 | # default model used in shortcut 931 | def DiT_B_2(**kwargs): 932 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 933 | 934 | def DiT_B_4(**kwargs): 935 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 936 | 937 | def DiT_B_8(**kwargs): 938 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 939 | 940 | def DiT_S_2(**kwargs): 941 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 942 | 943 | def DiT_S_4(**kwargs): 944 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 945 | 946 | def DiT_S_8(**kwargs): 947 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 948 | 949 | 950 | DiT_models = { 951 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 952 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 953 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 954 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 955 | } -------------------------------------------------------------------------------- /samples/epoch:299_denoise_timesteps:1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:1.png -------------------------------------------------------------------------------- /samples/epoch:299_denoise_timesteps:128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:128.png -------------------------------------------------------------------------------- /samples/epoch:299_denoise_timesteps:16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:16.png -------------------------------------------------------------------------------- /samples/epoch:299_denoise_timesteps:2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:2.png

--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:4.png

--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:8.png

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import pandas as pd
4 | import io
5 | import os
6 | from copy import deepcopy
7 | from collections import OrderedDict
8 | import math
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | 
12 | import torch
13 | import torchvision
14 | from torch.utils.data import Dataset, DataLoader
15 | from torchvision import transforms
16 | from torchvision.utils import save_image
17 | 
18 | from torchmetrics.image.fid import FrechetInceptionDistance
19 | 
20 | from diffusers.models import AutoencoderKL
21 | 
22 | import pytorch_lightning as pl
23 | from pytorch_lightning.callbacks import ModelCheckpoint
24 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
25 | 
26 | 
27 | from model import DiT_B_2, EMACallback
28 | from utils import create_targets, create_targets_naive
29 | 
30 | def count_parameters(model):
31 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
32 | 
33 | LEARNING_RATE = 0.0001
34 | BATCH_SIZE = 64
35 | EVAL_SIZE = 8
36 | 
37 | NUM_CLASSES = 1
38 | CLASS_DROPOUT_PROB = 1.0
39 | 
40 | N_EPOCHS = 100
41 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
42 | 
43 | LOG_EVERY = 3
44 | CFG_SCALE = 0.0
45 | 
46 | 
47 | class CelebaHQDataset(Dataset):
48 |     def __init__(self, parquet_path, transform=None, size=256):
49 | 
50 | 
51 |         parquet_names = os.listdir(parquet_path)
52 |         parquet_paths = [os.path.join(parquet_path, parquet_name) for parquet_name in parquet_names]
53 | 
54 |         if len(parquet_paths) < 1:
55 |             raise FileNotFoundError(f"no parquet files found in {parquet_path}")
56 | 
57 |         parquets = []
58 | 
59 |         for i in range(len(parquet_paths)):
60 |             parquet_i = pd.read_parquet(parquet_paths[i])
61 |             parquets.append(parquet_i)
62 | 
63 |         self.data = pd.concat(parquets, axis=0)
64 |         self.size = size
65 |         self.transform = transform
66 | 
67 |         # print(f"self.data: {self.data}")
68 | 
69 |     def __len__(self):
70 | 
71 |         return self.data.shape[0]
72 | 
73 |     def __getitem__(self, idx):
74 | 
75 |         data_i = self.data.iloc[[idx]]
76 | 
77 |         image_i = Image.open(io.BytesIO(data_i['image.bytes'].item())).resize((self.size, self.size))
78 | 
79 |         if self.transform is not None:
80 |             image_i = self.transform(image_i)
81 | 
82 |         label_i = data_i['label'].item()
83 | 
84 |         return image_i, label_i
85 | 
86 | @torch.no_grad()
87 | def update_ema(ema_model, model, decay=0.9999):
88 |     """
89 |     Step the EMA model towards the current model.
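    The update rule is ema_param <- decay * ema_param + (1 - decay) * param,
        applied in-place to every parameter in `model.named_parameters()`.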
90 | """ 91 | ema_params = OrderedDict(ema_model.named_parameters()) 92 | model_params = OrderedDict(model.named_parameters()) 93 | 94 | for name, param in model_params.items(): 95 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 96 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 97 | 98 | def train_epoch(train_dataloader, dit, ema, vae, optimizer): 99 | 100 | loss_fn = torch.nn.MSELoss() 101 | 102 | dit.train() 103 | 104 | total_loss = 0.0 105 | for batch, (images, labels) in enumerate(tqdm(train_dataloader)): 106 | 107 | images, labels = images.to(DEVICE), labels.to(DEVICE) 108 | 109 | # print(f"torch.mean(images): {torch.mean(images)} | torch.std(images): {torch.std(images)}") 110 | 111 | with torch.no_grad(): 112 | 113 | latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor 114 | 115 | print(f"latents.shape: {latents.shape}") 116 | # print(f"labels.shape: {labels.shape}") 117 | 118 | x_t, v_t, t, dt_base, labels_dropped = create_targets(latents, labels, dit) 119 | 120 | # print(f"x_t.shape: {x_t.shape}") 121 | # print(f"v_t.shape: {v_t.shape}") 122 | # print(f"t.shape: {t.shape}") 123 | # print(f"dt_base.shape: {dt_base.shape}") 124 | # print(f"labels_dropped.shape: {labels_dropped.shape}") 125 | 126 | # exit() 127 | 128 | v_prime = dit(x_t, t, dt_base, labels) 129 | 130 | # print(f"v_prime.shape: {v_prime.shape}") 131 | # print(f"v_prime: {v_prime}") 132 | 133 | loss = loss_fn(v_prime, v_t) 134 | 135 | total_loss += loss.item() 136 | optimizer.zero_grad() 137 | 138 | loss.backward() 139 | 140 | optimizer.step() 141 | 142 | update_ema(ema, dit) 143 | 144 | return total_loss / len(train_dataloader) 145 | 146 | def evaluate(dit, vae, val_dataloader, epoch): 147 | 148 | dit.eval() 149 | 150 | images, labels_real = next(iter(val_dataloader)) 151 | images, labels_real = images[:EVAL_SIZE].to(DEVICE), labels_real[:EVAL_SIZE].to(DEVICE) 152 | with torch.no_grad(): 153 | latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor 154 | 155 | 156 | labels_uncond = torch.ones_like(labels_real, dtype=torch.int32) * NUM_CLASSES 157 | 158 | fid = FrechetInceptionDistance().to(DEVICE) 159 | 160 | 161 | images = 255 * ((images - torch.min(images)) / (torch.max(images) - torch.min(images) + 1e-8)) 162 | 163 | # all start from same noise 164 | eps = torch.randn_like(latents).to(DEVICE) 165 | 166 | 167 | 168 | denoise_timesteps_list = [1, 2, 4, 8, 16, 32, 128] 169 | # run at varying timesteps 170 | for i, denoise_timesteps in enumerate(denoise_timesteps_list): 171 | all_x = [] 172 | delta_t = 1.0 / denoise_timesteps # i.e. 
step size 173 | fid.reset() 174 | 175 | x = eps 176 | 177 | for ti in range(denoise_timesteps): 178 | # t should in range [0,1] 179 | t = ti / denoise_timesteps 180 | 181 | t_vector = torch.full((eps.shape[0],), t).to(DEVICE) 182 | dt_base = torch.ones_like(t_vector).to(DEVICE) * math.log2(denoise_timesteps) 183 | 184 | 185 | # if i == len(denoise_timesteps_list)-1: 186 | # with torch.no_grad(): 187 | # v_cond = dit.forward(x, t_vector, dt_base, labels_real) 188 | # v_uncond = dit.forward(x, t_vector, dt_base, labels_uncond) 189 | 190 | # v = v_uncond + CFG_SCALE * (v_cond - v_uncond) 191 | # else: 192 | # # t is same for all latents 193 | # with torch.no_grad(): 194 | # v = dit.forward(x, t_vector, dt_base, labels_real) 195 | 196 | with torch.no_grad(): 197 | v = dit.forward(x, t_vector, dt_base, labels_real) 198 | 199 | x = x + v * delta_t 200 | 201 | if denoise_timesteps <= 8 or ti % (denoise_timesteps//8) == 0 or ti == denoise_timesteps-1: 202 | 203 | with torch.no_grad(): 204 | decoded = vae.decode(x/vae.config.scaling_factor)[0] 205 | 206 | decoded = decoded.to("cpu") 207 | 208 | all_x.append(decoded) 209 | 210 | 211 | if(len(all_x)==9): 212 | all_x = all_x[1:] 213 | 214 | # estimate FID metric 215 | # images_fake = torch.randint(low=0, high=255, size=images.shape).to(torch.uint8).to(DEVICE) 216 | decoded_denormalized = 255 * ((decoded - torch.min(decoded)) / (torch.max(decoded)-torch.min(decoded)+1e-8)) 217 | 218 | # generated images 219 | fid.update(images.to(torch.uint8), real=True) 220 | fid.update(decoded_denormalized.to(torch.uint8).to(DEVICE), real=False) 221 | fid_val = fid.compute() 222 | print(f"denoise_timesteps: {denoise_timesteps} | fid_val: {fid_val}") 223 | 224 | all_x = torch.stack(all_x) 225 | 226 | def process_img(img): 227 | # normalize in range [0,1] 228 | img = img * 0.5 + 0.5 229 | img = torch.clip(img, 0, 1) 230 | img = img.permute(1,2,0) 231 | return img 232 | 233 | fig, axs = plt.subplots(8, 8, figsize=(30,30)) 234 | for t in range(min(8, all_x.shape[0])): 235 | for j in range(8): 236 | axs[t, j].imshow(process_img(all_x[t, j]), vmin=0, vmax=1) 237 | 238 | fig.savefig(f"log_images_tvanilla/epoch:{epoch}_denoise_timesteps:{denoise_timesteps}.png") 239 | # if i == len(denoise_timesteps_list)-1: 240 | # fig.savefig(f"log_images_tvanilla/epoch:{epoch}_cfg.png") 241 | # else: 242 | # fig.savefig(f"log_images_tvanilla/epoch:{epoch}_denoise_timesteps:{denoise_timesteps}.png") 243 | 244 | plt.close() 245 | 246 | def requires_grad(model, flag=True): 247 | """ 248 | Set requires_grad flag for all parameters in a model. 
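For reference, the sampling loop in `evaluate` above (and in `model.validation_step`) is plain Euler integration of the learned velocity field; a condensed sketch of the same loop, assuming a model with this repo's `(x, t, dt_base, y)` signature:

```python
import math
import torch

@torch.no_grad()
def euler_sample(model, x, y, steps: int):
    # integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) in `steps` steps
    dt = 1.0 / steps
    dt_base = torch.full((x.shape[0],), math.log2(steps), device=x.device)
    for ti in range(steps):
        t_vec = torch.full((x.shape[0],), ti / steps, device=x.device)
        x = x + model(x, t_vec, dt_base, y) * dt
    return x
```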
249 | """ 250 | for p in model.parameters(): 251 | p.requires_grad = flag 252 | 253 | 254 | # single gpu: [00:55<10:24, 0.64it/s, v_num=5] 255 | # 3 gpus: [03:35<00:00, 0.68it/s, v_num=7] 256 | def main_lightning(): 257 | 258 | # should be same as in flax implementation(83'653'863 params) 259 | 260 | 261 | 262 | 263 | 264 | 265 | train_transform = transforms.Compose([ 266 | transforms.ToTensor(), 267 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 268 | ]) 269 | 270 | # 28000 images 271 | train_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/train', transform=train_transform) 272 | # 2000 images 273 | val_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/val', transform=train_transform) 274 | 275 | 276 | 277 | print(f"len(train_dataset): {len(train_dataset)}") 278 | print(f"len(val_dataset): {len(val_dataset)}") 279 | 280 | 281 | 282 | # good option is 2*num_gpus 283 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True) 284 | val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True) 285 | 286 | 287 | images_i, labels_i = next(iter(train_dataloader)) 288 | 289 | # 4 - 2d mean, 2d std 290 | latent_shape = (BATCH_SIZE, 4, images_i.shape[2]//8, images_i.shape[2]//8) 291 | 292 | dit = DiT_B_2(learn_sigma=False, 293 | num_classes=NUM_CLASSES, 294 | class_dropout_prob=CLASS_DROPOUT_PROB, 295 | lightning_mode=True, 296 | latent_shape=latent_shape, 297 | training_type="shortcut") 298 | 299 | # if load from checkpoint: 300 | # checkpoint_path = "/workspace/shortcut_pytorch/tb_logs/shortcut_model/version_0/checkpoints/last.ckpt" 301 | # checkpoint = torch.load(checkpoint_path) 302 | # dit.load_state_dict(checkpoint['state_dict']) 303 | 304 | print(f"count_parameters(dit): {count_parameters(dit)}") 305 | 306 | callbacks = [] 307 | 308 | # Define a checkpoint callback 309 | checkpoint_callback = ModelCheckpoint( 310 | filename="model-{epoch:02d}", 311 | save_last=True, 312 | ) 313 | ema_callback = EMACallback(decay=0.999) 314 | 315 | callbacks.append(checkpoint_callback) 316 | callbacks.append(ema_callback) 317 | 318 | # logger = TensorBoardLogger("tb_logs", name="shortcut_model") 319 | logger = WandbLogger("shortcut_model") 320 | 321 | trainer = pl.Trainer(max_epochs=1000, 322 | accelerator="gpu", 323 | num_sanity_val_steps=1, 324 | check_val_every_n_epoch=50, 325 | limit_val_batches=1.0, 326 | devices=[0, 1, 2, 3], 327 | strategy="ddp_find_unused_parameters_false", 328 | callbacks=callbacks, 329 | logger=logger) 330 | 331 | trainer.fit(model=dit, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) 332 | 333 | 334 | # single epoch takes: [15:45<00:00, 2.16s/it] 335 | def main(): 336 | 337 | train_transform = transforms.Compose([ 338 | transforms.ToTensor(), 339 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 340 | ]) 341 | 342 | train_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/train', transform=train_transform) 343 | val_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/val', transform=train_transform) 344 | 345 | print(f"len(train_dataset): {len(train_dataset)}") 346 | print(f"len(val_dataset): {len(val_dataset)}") 347 | 348 | # good option is 2*num_gpus 349 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True) 350 | val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, 
drop_last=True)
351 | 
352 |     images_i, labels_i = next(iter(train_dataloader))
353 | 
354 |     # 4 - 2d mean, 2d std
355 |     latent_shape = (BATCH_SIZE, 4, images_i.shape[2]//8, images_i.shape[2]//8)
356 | 
357 |     dit = DiT_B_2(learn_sigma=False,
358 |                   num_classes=NUM_CLASSES,
359 |                   class_dropout_prob=CLASS_DROPOUT_PROB,
360 |                   latent_shape=latent_shape,
361 |                   training_type="shortcut").to(DEVICE)
362 | 
363 |     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
364 |     vae = vae.eval()
365 |     vae.requires_grad_(False)
366 | 
367 |     print(f"count_parameters(dit): {count_parameters(dit)}")
368 |     ema = deepcopy(dit).to(DEVICE)
369 |     ema.requires_grad_(False)
370 | 
371 |     # resume from a previous run if a checkpoint exists
372 |     checkpoint_path = "dit_saved.pth"
373 |     if os.path.exists(checkpoint_path):
374 |         dit.load_state_dict(torch.load(checkpoint_path))
375 | 
376 | 
377 |     optimizer = torch.optim.AdamW(dit.parameters(), lr=LEARNING_RATE, weight_decay=0.1)
378 | 
379 |     update_ema(ema, dit, decay=0)
380 |     ema.eval()
381 | 
382 |     # evaluate(dit, vae, val_dataloader, 0)
383 | 
384 |     for i in range(N_EPOCHS):
385 | 
386 |         epoch_loss = train_epoch(train_dataloader, dit, ema, vae, optimizer)
387 |         print(f"epoch_loss: {epoch_loss}")
388 | 
389 |         if i%LOG_EVERY == 0 and i > 0:
390 |             evaluate(dit, vae, val_dataloader, i)
391 | 
392 |         torch.save(dit.state_dict(), "dit_saved.pth")
393 | 
394 | 
395 | if __name__ == "__main__":
396 |     # main()
397 |     main_lightning()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import os
4 | import pytorch_lightning as pl
5 | from pytorch_lightning import Callback
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
8 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
9 | from pytorch_lightning.utilities.types import STEP_OUTPUT
10 | 
11 | from typing import Any, Dict, List, Optional
12 | 
13 | 
14 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | BOOTSTRAP_EVERY = 8
16 | DENOISE_TIMESTEPS = 128
17 | CLASS_DROPOUT_PROB = 1.0
18 | NUM_CLASSES = 1
19 | BATCH_SIZE = 64
20 | 
21 | # create batch, consisting of different timesteps and different dts (depending on total step sizes)
22 | def create_targets(images, labels, model):
23 | 
24 |     model.eval()
25 | 
26 |     current_batch_size = images.shape[0]
27 | 
28 |     FORCE_T = -1
29 |     FORCE_DT = -1
30 | 
31 |     # 1. create step sizes dt
32 |     bootstrap_batch_size = current_batch_size // BOOTSTRAP_EVERY  # = 8
33 |     log2_sections = int(math.log2(DENOISE_TIMESTEPS))
34 |     # print(f"log2_sections: {log2_sections}")
35 |     # print(f"bootstrap_batch_size: {bootstrap_batch_size}")
36 | 
37 |     dt_base = torch.repeat_interleave(log2_sections - 1 - torch.arange(log2_sections), bootstrap_batch_size // log2_sections)
38 |     # print(f"dt_base: {dt_base}")
39 | 
40 |     dt_base = torch.cat([dt_base, torch.zeros(bootstrap_batch_size-dt_base.shape[0],)])
41 |     # print(f"dt_base: {dt_base}")
42 | 
43 | 
44 | 
45 |     force_dt_vec = torch.ones(bootstrap_batch_size) * FORCE_DT
46 |     dt_base = torch.where(force_dt_vec != -1, force_dt_vec, dt_base).to(model.device)
47 |     dt = 1 / (2 ** (dt_base))  # [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 1]
48 |     # print(f"dt: {dt}")
49 | 
50 |     dt_base_bootstrap = dt_base + 1
51 |     dt_bootstrap = dt / 2  # [0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5 0.5]
52 |     # print(f"dt_bootstrap: {dt_bootstrap}")
53 | 
54 |     # 2.
sample timesteps t 55 | dt_sections = 2**dt_base 56 | 57 | # print(f"dt_sections: {dt_sections}") 58 | 59 | t = torch.cat([ 60 | torch.randint(low=0, high=int(val.item()), size=(1,)).float() 61 | for val in dt_sections 62 | ]).to(model.device) 63 | 64 | # print(f"t[randint]: {t}") 65 | t = t / dt_sections 66 | # print(f"t[normalized]: {t}") 67 | 68 | force_t_vec = torch.ones(bootstrap_batch_size, dtype=torch.float32).to(model.device) * FORCE_T 69 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device) 70 | t_full = t[:, None, None, None] 71 | 72 | # print(f"t_full: {t_full}") 73 | 74 | # 3. generate bootstrap targets: 75 | x_1 = images[:bootstrap_batch_size] 76 | x_0 = torch.randn_like(x_1) 77 | 78 | # get dx at timestep t 79 | x_t = (1 - (1-1e-5) * t_full)*x_0 + t_full*x_1 80 | 81 | bst_labels = labels[:bootstrap_batch_size] 82 | 83 | 84 | with torch.no_grad(): 85 | v_b1 = model(x_t, t, dt_base_bootstrap, bst_labels) 86 | 87 | t2 = t + dt_bootstrap 88 | x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1 89 | x_t2 = torch.clip(x_t2, -4, 4) 90 | 91 | with torch.no_grad(): 92 | v_b2 = model(x_t2, t2, dt_base_bootstrap, bst_labels) 93 | 94 | v_target = (v_b1 + v_b2) / 2 95 | 96 | v_target = torch.clip(v_target, -4, 4) 97 | 98 | bst_v = v_target 99 | bst_dt = dt_base 100 | bst_t = t 101 | bst_xt = x_t 102 | bst_l = bst_labels 103 | 104 | # 4. generate flow-matching targets 105 | 106 | labels_dropout = torch.bernoulli(torch.full(labels.shape, CLASS_DROPOUT_PROB)).to(model.device) 107 | labels_dropped = torch.where(labels_dropout.bool(), NUM_CLASSES, labels) 108 | 109 | # sample t(normalized) 110 | t = torch.randint(low=0, high=DENOISE_TIMESTEPS, size=(images.shape[0],), dtype=torch.float32) 111 | # print(f"t: {t}") 112 | t /= DENOISE_TIMESTEPS 113 | # print(f"t: {t}") 114 | force_t_vec = torch.ones(images.shape[0]) * FORCE_T 115 | # force_t_vec = torch.full((images.shape[0],), FORCE_T, dtype=torch.float32) 116 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device) 117 | # t_full = t.view(-1, 1, 1, 1) 118 | t_full = t[:, None, None, None] 119 | 120 | # print(f"t_full: {t_full}") 121 | 122 | # sample flow pairs x_t, v_t 123 | x_0 = torch.randn_like(images).to(model.device) 124 | x_1 = images 125 | x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1 126 | v_t = x_1 - (1 - 1e-5) * x_0 127 | 128 | dt_flow = int(math.log2(DENOISE_TIMESTEPS)) 129 | dt_base = (torch.ones(images.shape[0], dtype=torch.int32) * dt_flow).to(model.device) 130 | 131 | # 5. 
merge flow and bootstrap 132 | bst_size = current_batch_size // BOOTSTRAP_EVERY 133 | bst_size_data = current_batch_size - bst_size 134 | 135 | # print(f"bst_size: {bst_size}") 136 | # print(f"bst_size_data: {bst_size_data}") 137 | 138 | x_t = torch.cat([bst_xt, x_t[:bst_size_data]], dim=0) 139 | t = torch.cat([bst_t, t[:bst_size_data]], dim=0) 140 | 141 | dt_base = torch.cat([bst_dt, dt_base[:bst_size_data]], dim=0) 142 | v_t = torch.cat([bst_v, v_t[:bst_size_data]], dim=0) 143 | labels_dropped = torch.cat([bst_l, labels_dropped[:bst_size_data]], dim=0) 144 | 145 | return x_t, v_t, t, dt_base, labels_dropped 146 | 147 | def create_targets_naive(images, labels, model): 148 | 149 | model.eval() 150 | 151 | current_batch_size = images.shape[0] 152 | 153 | FORCE_T = -1 154 | FORCE_DT = -1 155 | 156 | labels_dropout = torch.bernoulli(torch.full(labels.shape, CLASS_DROPOUT_PROB)).to(model.device) 157 | labels_dropped = torch.where(labels_dropout.bool(), NUM_CLASSES, labels) 158 | 159 | # sample t(normalized) 160 | t = torch.randint(low=0, high=DENOISE_TIMESTEPS, size=(images.shape[0],), dtype=torch.float32) 161 | # print(f"t: {t}") 162 | t /= DENOISE_TIMESTEPS 163 | # print(f"t: {t}") 164 | force_t_vec = torch.ones(images.shape[0]) * FORCE_T 165 | # force_t_vec = torch.full((images.shape[0],), FORCE_T, dtype=torch.float32) 166 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device) 167 | # t_full = t.view(-1, 1, 1, 1) 168 | t_full = t[:, None, None, None] 169 | 170 | 171 | x_0 = torch.randn_like(images).to(model.device) 172 | x_1 = images 173 | x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1 174 | v_t = x_1 - (1 - 1e-5) * x_0 175 | 176 | dt_flow = int(math.log2(DENOISE_TIMESTEPS)) 177 | dt_base = (torch.ones(images.shape[0], dtype=torch.int32) * dt_flow).to(model.device) 178 | 179 | return x_t, v_t, t, dt_base, labels_dropped 180 | 181 | 182 | 183 | 184 | 185 | --------------------------------------------------------------------------------
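
For orientation, here is a small, self-contained sketch (not part of the repo) of what `create_targets` returns; the stub model stands in for the DiT and only provides the interface `create_targets` relies on, namely being callable as `model(x_t, t, dt_base, labels)` and exposing a `.device` attribute:

```python
import torch
from utils import create_targets

class StubModel(torch.nn.Module):
    """Minimal stand-in for the DiT used by create_targets."""
    @property
    def device(self):
        return torch.device("cpu")

    def forward(self, x, t, dt_base, y):
        # a real model would predict a velocity field; zeros suffice here
        return torch.zeros_like(x)

latents = torch.randn(64, 4, 32, 32)            # BATCH_SIZE=64, SD-VAE latents
labels = torch.zeros(64, dtype=torch.long)
x_t, v_t, t, dt_base, y = create_targets(latents, labels, StubModel())

# The first 64 // BOOTSTRAP_EVERY = 8 targets are bootstrap (self-consistency)
# targets built from two half-steps of the model itself; the remaining 56 are
# plain flow-matching targets with dt_base = log2(128) = 7.
print(x_t.shape, v_t.shape)    # torch.Size([64, 4, 32, 32]) twice
print(t.shape, dt_base.shape)  # torch.Size([64]) twice
```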