├── .gitignore
├── Dockerfile
├── README.md
├── docker-compose.yaml
├── model.py
├── samples
│   ├── epoch:299_denoise_timesteps:1.png
│   ├── epoch:299_denoise_timesteps:128.png
│   ├── epoch:299_denoise_timesteps:16.png
│   ├── epoch:299_denoise_timesteps:2.png
│   ├── epoch:299_denoise_timesteps:4.png
│   └── epoch:299_denoise_timesteps:8.png
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | celeba-hq/
3 | log_images*
4 | log_images_tvanilla/
5 | lightning_logs/
6 | inception/
7 | tb_logs/
8 | wandb/
9 | dit_saved.pth
10 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # CUDA development base image (Ubuntu 22.04, CUDA 11.8)
2 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
3 |
4 | # Set environment variables to avoid interactive prompts
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | RUN apt-get update && \
7 | apt-get install -y --no-install-recommends \
8 | build-essential \
9 | cmake \
10 | curl \
11 | ffmpeg \
12 | git \
13 | wget \
14 | python3-pip \
15 | python3-dev
16 |
17 | # Install Python dependencies
18 | RUN pip install --upgrade pip
19 | RUN pip install accelerate==1.2.0 \
20 | && pip install diffusers==0.31.0 \
21 | && pip install torch==2.5.1 \
22 | && pip install timm==1.0.12 \
23 | && pip install torchmetrics[image] \
24 | && pip install matplotlib \
25 | && pip install pandas==2.2.3 \
26 | && pip install fastparquet==2024.11.0 \
27 | && pip install pytorch-lightning \
28 | && pip install tensorboard \
29 | && pip install wandb
30 |
31 | # 1. Visualize logs (on the server):
32 | #    tensorboard --logdir=tb_logs/shortcut_model --port 6006
33 | # 2. Forward the port to your own machine:
34 | #    ssh -N -f -L localhost:16006:localhost:6006 r.khafizov@10.16.88.93
35 | #    If the local port is already in use:
36 | #    1. sudo netstat -tulpn | grep :16006
37 | #    2. kill the process that owns it: kill <pid>
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## One-Step Diffusion via Shortcut Models
3 |
4 | This is an unofficial PyTorch implementation. The original JAX implementation can be found here: https://github.com/kvfrans/shortcut-models
5 |
6 | At the moment this is a fairly basic implementation with two training modes: naive (plain flow matching) and shortcut.
7 |
8 | The implementation supports multi-GPU training via PyTorch Lightning.
9 |
10 | ## Data
11 |
12 | I used the CelebA-HQ dataset from Hugging Face for the image generation task: https://huggingface.co/datasets/mattymchen/celeba-hq
13 |
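One possible way to fetch the parquet files locally, as a sketch assuming the `huggingface_hub` package (adjust the resulting directory layout to match the `celeba-hq/data/train` and `celeba-hq/data/val` paths expected by `train.py`):

```python
from huggingface_hub import snapshot_download

# download the dataset repo (parquet shards) into ./celeba-hq
snapshot_download(repo_id="mattymchen/celeba-hq", repo_type="dataset", local_dir="celeba-hq")
```
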
14 | ## Using the code
15 |
16 | The repository contains a Dockerfile and a docker-compose.yaml that install all necessary libraries.
17 |
18 | To start training, run:
19 |
20 | ```
21 | python train.py
22 | ```
23 |
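The training mode is selected when the model is constructed in `train.py` (see `main_lightning()`); a minimal sketch of the relevant call, with values taken from this repo:

```python
from model import DiT_B_2

dit = DiT_B_2(
    learn_sigma=False,
    num_classes=1,
    class_dropout_prob=1.0,
    lightning_mode=True,
    latent_shape=(64, 4, 32, 32),   # (BATCH_SIZE, latent channels, H/8, W/8) for 256x256 images
    training_type="shortcut",       # or "naive" for plain flow matching
)
```
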
24 | ## Results (300 epochs of shortcut training)
25 |
26 | 1 denoising step:
27 | 
28 | ![1 denoising step](samples/epoch:299_denoise_timesteps:1.png)
29 | 
30 | 
31 | 
32 | 2 denoising steps:
33 | 
34 | ![2 denoising steps](samples/epoch:299_denoise_timesteps:2.png)
35 | 
36 | 
37 | 
38 | 4 denoising steps:
39 | 
40 | ![4 denoising steps](samples/epoch:299_denoise_timesteps:4.png)
41 | 
42 | 
43 | 
44 | 8 denoising steps:
45 | 
46 | ![8 denoising steps](samples/epoch:299_denoise_timesteps:8.png)
47 | 
48 | 
49 | 
50 | 16 denoising steps:
51 | 
52 | ![16 denoising steps](samples/epoch:299_denoise_timesteps:16.png)
53 | 
54 | 
55 | 
56 | 128 denoising steps:
57 | 
58 | ![128 denoising steps](samples/epoch:299_denoise_timesteps:128.png)
59 | 
60 | 
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | shortcut_pytorch:
3 | image: rkhafizov.shortcut_pytorch_image
4 | build:
5 | context: .
6 | dockerfile: Dockerfile
7 | container_name: rkhafizov.shortcut_pytorch_container
8 | network_mode: host
9 | ipc: host
10 | volumes:
11 | - /home/r.khafizov/shortcut_pytorch:/workspace/shortcut_pytorch
12 | - /home/r.khafizov/shortcut_pytorch/inception:/root/.cache/torch/hub/checkpoints
13 | ports:
14 | - "6666:6666" # Adjust the port mapping as needed
15 | environment:
16 | - NVIDIA_VISIBLE_DEVICES=all # Adjust GPU visibility as needed
17 | command: "/bin/bash -c 'source /etc/bash.bashrc && tail -f /dev/null && /bin/bash'" # Keep container running
18 | deploy:
19 | resources:
20 | reservations:
21 | devices:
22 | - driver: nvidia
23 | # count: 1
24 |               device_ids: ["0", "2", "3", "5"]
25 | capabilities: [gpu]
26 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 | from os import makedirs
18 |
19 | import pytorch_lightning as pl
20 | from pytorch_lightning.utilities.types import STEP_OUTPUT
21 |
22 |
23 | from torchmetrics.image.fid import FrechetInceptionDistance
24 | import matplotlib.pyplot as plt
25 |
26 | from diffusers.models import AutoencoderKL
27 |
28 | from utils import create_targets, create_targets_naive
29 |
31 | from typing import Dict, Any
32 | from typing import Iterable, Optional
33 |
34 | import weakref
35 | import copy
36 | import contextlib
37 | import wandb
38 |
39 | def modulate(x, shift, scale):
40 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
41 |
42 | LEARNING_RATE = 0.0001
43 | EVAL_SIZE = 8
44 | NUM_CLASSES = 1
45 | CFG_SCALE = 0.0
46 |
49 |
50 | def to_float_maybe(x):
51 | return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x
52 |
53 |
54 | class ExponentialMovingAverage:
55 | """
56 | Maintains (exponential) moving average of a set of parameters.
57 | Args:
58 | parameters: Iterable of `torch.nn.Parameter` (typically from
59 | `model.parameters()`).
60 | decay: The exponential decay.
61 | use_num_updates: Whether to use number of updates when computing
62 | averages.
63 | """
64 | def __init__(
65 | self,
66 | parameters: Iterable[torch.nn.Parameter],
67 | decay: float,
68 | use_num_updates: bool = True
69 | ):
70 | if decay < 0.0 or decay > 1.0:
71 | raise ValueError('Decay must be between 0 and 1')
72 | self.decay = decay
73 | self.num_updates = 0 if use_num_updates else None
74 | parameters = list(parameters)
75 | self.shadow_params = [to_float_maybe(p.clone().detach())
76 | for p in parameters if p.requires_grad]
77 | self.collected_params = None
78 | # By maintaining only a weakref to each parameter,
79 | # we maintain the old GC behaviour of ExponentialMovingAverage:
80 | # if the model goes out of scope but the ExponentialMovingAverage
81 | # is kept, no references to the model or its parameters will be
82 | # maintained, and the model will be cleaned up.
83 | self._params_refs = [weakref.ref(p) for p in parameters]
84 |
85 | def _get_parameters(
86 | self,
87 | parameters: Optional[Iterable[torch.nn.Parameter]]
88 | ) -> Iterable[torch.nn.Parameter]:
89 | if parameters is None:
90 | parameters = [p() for p in self._params_refs]
91 | if any(p is None for p in parameters):
92 | raise ValueError(
93 | "(One of) the parameters with which this "
94 | "ExponentialMovingAverage "
95 | "was initialized no longer exists (was garbage collected);"
96 | " please either provide `parameters` explicitly or keep "
97 | "the model to which they belong from being garbage "
98 | "collected."
99 | )
100 | return parameters
101 | else:
102 | parameters = list(parameters)
103 | if len(parameters) != len(self.shadow_params):
104 | raise ValueError(
105 | "Number of parameters passed as argument is different "
106 | "from number of shadow parameters maintained by this "
107 | "ExponentialMovingAverage"
108 | )
109 | return parameters
110 |
111 | def update(
112 | self,
113 | parameters: Optional[Iterable[torch.nn.Parameter]] = None
114 | ) -> None:
115 | """
116 | Update currently maintained parameters.
117 | Call this every time the parameters are updated, such as the result of
118 | the `optimizer.step()` call.
119 | Args:
120 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of
121 | parameters used to initialize this object. If `None`, the
122 | parameters with which this `ExponentialMovingAverage` was
123 | initialized will be used.
124 | """
125 | parameters = self._get_parameters(parameters)
126 | decay = self.decay
127 | if self.num_updates is not None:
128 | self.num_updates += 1
129 | decay = min(
130 | decay,
131 | (1 + self.num_updates) / (10 + self.num_updates)
132 | )
133 | one_minus_decay = 1.0 - decay
134 | if parameters[0].device != self.shadow_params[0].device:
135 | self.to(device=parameters[0].device)
136 | with torch.no_grad():
137 | parameters = [p for p in parameters if p.requires_grad]
138 | for s_param, param in zip(self.shadow_params, parameters):
139 | torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param)
140 |
141 | def copy_to(
142 | self,
143 | parameters: Optional[Iterable[torch.nn.Parameter]] = None
144 | ) -> None:
145 | """
146 | Copy current averaged parameters into given collection of parameters.
147 | Args:
148 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
149 | updated with the stored moving averages. If `None`, the
150 | parameters with which this `ExponentialMovingAverage` was
151 | initialized will be used.
152 | """
153 | parameters = self._get_parameters(parameters)
154 | for s_param, param in zip(self.shadow_params, parameters):
155 | if param.requires_grad:
156 | param.data.copy_(s_param.data)
157 |
158 | def store(
159 | self,
160 | parameters: Optional[Iterable[torch.nn.Parameter]] = None
161 | ) -> None:
162 | """
163 | Save the current parameters for restoring later.
164 | Args:
165 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
166 | temporarily stored. If `None`, the parameters of with which this
167 | `ExponentialMovingAverage` was initialized will be used.
168 | """
169 | parameters = self._get_parameters(parameters)
170 | self.collected_params = [
171 | param.clone()
172 | for param in parameters
173 | if param.requires_grad
174 | ]
175 |
176 | def restore(
177 | self,
178 | parameters: Optional[Iterable[torch.nn.Parameter]] = None
179 | ) -> None:
180 | """
181 | Restore the parameters stored with the `store` method.
182 | Useful to validate the model with EMA parameters without affecting the
183 | original optimization process. Store the parameters before the
184 | `copy_to` method. After validation (or model saving), use this to
185 | restore the former parameters.
186 | Args:
187 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
188 | updated with the stored parameters. If `None`, the
189 | parameters with which this `ExponentialMovingAverage` was
190 | initialized will be used.
191 | """
192 | if self.collected_params is None:
193 | raise RuntimeError(
194 | "This ExponentialMovingAverage has no `store()`ed weights "
195 | "to `restore()`"
196 | )
197 | parameters = self._get_parameters(parameters)
198 | for c_param, param in zip(self.collected_params, parameters):
199 | if param.requires_grad:
200 | param.data.copy_(c_param.data)
201 |
202 | @contextlib.contextmanager
203 | def average_parameters(
204 | self,
205 | parameters: Optional[Iterable[torch.nn.Parameter]] = None
206 | ):
207 | r"""
208 | Context manager for validation/inference with averaged parameters.
209 | Equivalent to:
210 | ema.store()
211 | ema.copy_to()
212 | try:
213 | ...
214 | finally:
215 | ema.restore()
216 | Args:
217 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
218 | updated with the stored parameters. If `None`, the
219 | parameters with which this `ExponentialMovingAverage` was
220 | initialized will be used.
221 | """
222 | parameters = self._get_parameters(parameters)
223 | self.store(parameters)
224 | self.copy_to(parameters)
225 | try:
226 | yield
227 | finally:
228 | self.restore(parameters)
229 |
230 | def to(self, device=None, dtype=None) -> None:
231 | r"""Move internal buffers of the ExponentialMovingAverage to `device`.
232 | Args:
233 | device: like `device` argument to `torch.Tensor.to`
234 | """
235 | # .to() on the tensors handles None correctly
236 | self.shadow_params = [
237 | p.to(device=device, dtype=dtype)
238 | if p.is_floating_point()
239 | else p.to(device=device)
240 | for p in self.shadow_params
241 | ]
242 | if self.collected_params is not None:
243 | self.collected_params = [
244 | p.to(device=device, dtype=dtype)
245 | if p.is_floating_point()
246 | else p.to(device=device)
247 | for p in self.collected_params
248 | ]
249 | return
250 |
251 | def state_dict(self) -> dict:
252 | r"""Returns the state of the ExponentialMovingAverage as a dict."""
253 | # Following PyTorch conventions, references to tensors are returned:
254 | # "returns a reference to the state and not its copy!" -
255 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
256 | return {
257 | "decay": self.decay,
258 | "num_updates": self.num_updates,
259 | "shadow_params": self.shadow_params,
260 | "collected_params": self.collected_params
261 | }
262 |
263 | def load_state_dict(self, state_dict: dict) -> None:
264 | r"""Loads the ExponentialMovingAverage state.
265 | Args:
266 | state_dict (dict): EMA state. Should be an object returned
267 | from a call to :meth:`state_dict`.
268 | """
269 | # deepcopy, to be consistent with module API
270 | state_dict = copy.deepcopy(state_dict)
271 | self.decay = state_dict["decay"]
272 | if self.decay < 0.0 or self.decay > 1.0:
273 | raise ValueError('Decay must be between 0 and 1')
274 | self.num_updates = state_dict["num_updates"]
275 | assert self.num_updates is None or isinstance(self.num_updates, int), \
276 | "Invalid num_updates"
277 |
278 | self.shadow_params = state_dict["shadow_params"]
279 | assert isinstance(self.shadow_params, list), \
280 | "shadow_params must be a list"
281 | assert all(
282 | isinstance(p, torch.Tensor) for p in self.shadow_params
283 | ), "shadow_params must all be Tensors"
284 |
285 | self.collected_params = state_dict["collected_params"]
286 | if self.collected_params is not None:
287 | assert isinstance(self.collected_params, list), \
288 | "collected_params must be a list"
289 | assert all(
290 | isinstance(p, torch.Tensor) for p in self.collected_params
291 | ), "collected_params must all be Tensors"
292 | assert len(self.collected_params) == len(self.shadow_params), \
293 | "collected_params and shadow_params had different lengths"
294 |
295 | if len(self.shadow_params) == len(self._params_refs):
296 | # Consistent with torch.optim.Optimizer, cast things to consistent
297 | # device and dtype with the parameters
298 | params = [p() for p in self._params_refs]
299 | # If parameters have been garbage collected, just load the state
300 | # we were given without change.
301 | if not any(p is None for p in params):
302 | # ^ parameter references are still good
303 | for i, p in enumerate(params):
304 | self.shadow_params[i] = to_float_maybe(self.shadow_params[i].to(
305 | device=p.device, dtype=p.dtype
306 | ))
307 | if self.collected_params is not None:
308 | self.collected_params[i] = self.collected_params[i].to(
309 | device=p.device, dtype=p.dtype
310 | )
311 | else:
312 | raise ValueError(
313 | "Tried to `load_state_dict()` with the wrong number of "
314 | "parameters in the saved state."
315 | )
316 |
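# Illustrative usage sketch of ExponentialMovingAverage (hypothetical training loop;
# in this repo the EMACallback below wires it into pytorch-lightning instead):
#
#   ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
#   for batch in loader:
#       loss.backward(); optimizer.step()
#       ema.update()                      # track weights after each optimizer step
#   with ema.average_parameters():        # swap in EMA weights for evaluation
#       evaluate(model)                   # originals are restored on exit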
317 |
318 | class EMACallback(pl.Callback):
319 | """TD [2021-08-31]: saving and loading from checkpoint should work.
320 | """
321 | def __init__(self, decay: float, use_num_updates: bool = True):
322 | """
323 | decay: The exponential decay.
324 | use_num_updates: Whether to use number of updates when computing
325 | averages.
326 | """
327 | super().__init__()
328 | self.decay = decay
329 | self.use_num_updates = use_num_updates
330 | self.ema = None
331 |
332 | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
333 | # It's possible that we already loaded EMA from the checkpoint
334 | if self.ema is None:
335 | self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
336 | decay=self.decay, use_num_updates=self.use_num_updates)
337 |
338 | # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it
339 | # We only want to update when parameters are changing.
340 | # Because of gradient accumulation, this doesn't happen every training step.
341 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688
342 | def on_train_batch_end(
343 | self,
344 | trainer: "pl.Trainer",
345 | pl_module: "pl.LightningModule",
346 | outputs: STEP_OUTPUT,
347 | batch: Any,
348 | batch_idx: int,
349 | ) -> None:
350 | if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
351 | self.ema.update()
352 |
353 | def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
354 | # During the initial validation we don't have self.ema yet
355 | if self.ema is not None:
356 | self.ema.store()
357 | self.ema.copy_to()
358 |
359 | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
360 | if self.ema is not None:
361 | self.ema.restore()
362 |
363 | def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
364 | if self.ema is not None:
365 | self.ema.store()
366 | self.ema.copy_to()
367 |
368 | def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
369 | if self.ema is not None:
370 | self.ema.restore()
371 |
372 | def on_save_checkpoint(
373 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
374 | ) -> Dict[str, Any]:
375 | return self.ema.state_dict()
376 |
377 | def on_load_checkpoint(
378 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
379 | checkpoint: Dict[str, Any]
380 | ) -> None:
381 | if self.ema is None:
382 | self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
383 | decay=self.decay, use_num_updates=self.use_num_updates)
384 | self.ema.load_state_dict(checkpoint)
385 |
386 |
387 | #################################################################################
388 | # Embedding Layers for Timesteps and Class Labels #
389 | #################################################################################
390 |
391 | class TimestepEmbedder(pl.LightningModule):
392 | """
393 | Embeds scalar timesteps into vector representations.
394 | """
395 | def __init__(self, hidden_size, frequency_embedding_size=256):
396 | super().__init__()
397 | self.mlp = nn.Sequential(
398 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
399 | nn.SiLU(),
400 | nn.Linear(hidden_size, hidden_size, bias=True),
401 | )
402 | self.frequency_embedding_size = frequency_embedding_size
403 |
404 | @staticmethod
405 | def timestep_embedding(t, dim, max_period=10000):
406 | """
407 | Create sinusoidal timestep embeddings.
408 | :param t: a 1-D Tensor of N indices, one per batch element.
409 | These may be fractional.
410 | :param dim: the dimension of the output.
411 | :param max_period: controls the minimum frequency of the embeddings.
412 | :return: an (N, D) Tensor of positional embeddings.
413 | """
414 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
415 | half = dim // 2
416 | freqs = torch.exp(
417 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
418 | ).to(device=t.device)
419 |
420 | args = t[:, None].float() * freqs[None]
421 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
422 | if dim % 2:
423 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
424 | return embedding
425 |
426 | def forward(self, t):
427 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
428 | t_emb = self.mlp(t_freq)
429 | return t_emb
430 |
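# Shape sketch for TimestepEmbedder (assuming hidden_size=768 as in DiT-B):
# t of shape (N,) -> sinusoidal features (N, 256) via timestep_embedding()
# -> embedding (N, 768) via the two-layer MLP in forward().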
431 |
432 | class LabelEmbedder(pl.LightningModule):
433 | """
434 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
435 | """
436 | def __init__(self, num_classes, hidden_size, dropout_prob):
437 | super().__init__()
438 | use_cfg_embedding = dropout_prob > 0
439 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
440 | self.num_classes = num_classes
441 | self.dropout_prob = dropout_prob
442 |
443 | def token_drop(self, labels, force_drop_ids=None):
444 | """
445 | Drops labels to enable classifier-free guidance.
446 | """
447 | if force_drop_ids is None:
448 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
449 | else:
450 | drop_ids = force_drop_ids == 1
451 | labels = torch.where(drop_ids, self.num_classes, labels)
452 | return labels
453 |
454 | def forward(self, labels, train, force_drop_ids=None):
455 | use_dropout = self.dropout_prob > 0
456 | if (train and use_dropout) or (force_drop_ids is not None):
457 | labels = self.token_drop(labels, force_drop_ids)
458 | embeddings = self.embedding_table(labels)
459 | return embeddings
460 |
461 |
462 | #################################################################################
463 | # Core DiT Model #
464 | #################################################################################
465 |
466 | class DiTBlock(pl.LightningModule):
467 | """
468 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
469 | """
470 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
471 | super().__init__()
472 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
473 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
474 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
475 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
476 | approx_gelu = lambda: nn.GELU(approximate="tanh")
477 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
478 | self.adaLN_modulation = nn.Sequential(
479 | nn.SiLU(),
480 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
481 | )
482 |
483 | def forward(self, x, c):
484 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
485 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
486 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
487 | return x
488 |
489 |
490 | class FinalLayer(pl.LightningModule):
491 | """
492 | The final layer of DiT.
493 | """
494 | def __init__(self, hidden_size, patch_size, out_channels):
495 | super().__init__()
496 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
497 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
498 | self.adaLN_modulation = nn.Sequential(
499 | nn.SiLU(),
500 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
501 | )
502 |
503 | def forward(self, x, c):
504 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
505 | x = modulate(self.norm_final(x), shift, scale)
506 | x = self.linear(x)
507 | return x
508 |
509 | class DiT(pl.LightningModule):
510 | """
511 | Diffusion model with a Transformer backbone.
512 | """
513 | def __init__(
514 | self,
515 | input_size=32,
516 | patch_size=2,
517 | in_channels=4,
518 | hidden_size=1152,
519 | depth=28,
520 | num_heads=16,
521 | mlp_ratio=4.0,
522 | class_dropout_prob=0.1,
523 | num_classes=1000,
524 | learn_sigma=True,
525 | lightning_mode=False,
526 | latent_shape=None,
527 | training_type="shortcut"
528 | ):
529 | super().__init__()
530 | self.learn_sigma = learn_sigma
531 | self.in_channels = in_channels
532 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
533 | self.patch_size = patch_size
534 | self.num_heads = num_heads
535 |
536 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
537 | self.t_embedder = TimestepEmbedder(hidden_size)
538 | self.dt_embedder = TimestepEmbedder(hidden_size)
539 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
540 | num_patches = self.x_embedder.num_patches
541 | # Will use fixed sin-cos embedding:
542 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
543 |
544 | self.blocks = nn.ModuleList([
545 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
546 | ])
547 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
548 | self.initialize_weights()
549 |
550 | if lightning_mode:
551 | self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device)
552 | self.vae = self.vae.eval()
553 | self.vae.requires_grad_(False)
554 |
555 | self.eps = torch.randn(latent_shape).to(self.device)
556 |
557 | self.training_type = training_type
558 |
559 | # number of denoising steps to be applied
560 | self.denoise_timesteps = [1, 2, 4, 8, 16, 32, 128]
561 |
562 | # self.fid = FrechetInceptionDistance().to(self.device)
563 |
564 | self.fids = None
565 | self.validation_step_outputs = []
566 |
567 | makedirs("log_images3", exist_ok=True)
568 |
569 |
570 | def on_fit_start(self):
571 |
572 | self.fids = [FrechetInceptionDistance().to(self.device) for _ in range(len(self.denoise_timesteps))]
573 |
574 |
575 | def initialize_weights(self):
576 | # Initialize transformer layers:
577 | def _basic_init(module):
578 | if isinstance(module, nn.Linear):
579 | torch.nn.init.xavier_uniform_(module.weight)
580 | if module.bias is not None:
581 | nn.init.constant_(module.bias, 0)
582 | self.apply(_basic_init)
583 |
584 | # Initialize (and freeze) pos_embed by sin-cos embedding:
585 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
586 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
587 |
588 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
589 | w = self.x_embedder.proj.weight.data
590 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
591 | nn.init.constant_(self.x_embedder.proj.bias, 0)
592 |
593 | # Initialize label embedding table:
594 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
595 |
596 | # Initialize timestep embedding MLP:
597 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
598 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
599 |
600 | nn.init.normal_(self.dt_embedder.mlp[0].weight, std=0.02)
601 | nn.init.normal_(self.dt_embedder.mlp[2].weight, std=0.02)
602 |
603 | # Zero-out adaLN modulation layers in DiT blocks:
604 | for block in self.blocks:
605 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
606 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
607 |
608 | # Zero-out output layers:
609 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
610 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
611 | nn.init.constant_(self.final_layer.linear.weight, 0)
612 | nn.init.constant_(self.final_layer.linear.bias, 0)
613 |
614 | self.loss_fn = torch.nn.MSELoss()
615 |
616 | def unpatchify(self, x):
617 | """
618 | x: (N, T, patch_size**2 * C)
619 |         imgs: (N, C, H, W)
620 | """
621 | c = self.out_channels
622 | p = self.x_embedder.patch_size[0]
623 | h = w = int(x.shape[1] ** 0.5)
624 | assert h * w == x.shape[1]
625 |
626 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
627 | x = torch.einsum('nhwpqc->nchpwq', x)
628 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
629 | return imgs
630 |
631 | def forward(self, x, t, dt, y):
632 | """
633 | Forward pass of DiT.
634 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
635 |         t: (N,) tensor of diffusion timesteps; dt: (N,) tensor of shortcut step sizes (log2 of the denoising step count)
636 | y: (N,) tensor of class labels
637 | """
638 |
639 | if self.training_type=="naive":
640 | dt = torch.zeros_like(t)
641 |
642 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
643 | t = self.t_embedder(t) # (N, D)
644 | dt = self.dt_embedder(dt)
645 | y = self.y_embedder(y, self.training) # (N, D)
646 | # print(f"t.shape: {t.shape}")
647 | # print(f"y.shape: {y.shape}")
648 | # print(f"dt.shape: {dt.shape}")
649 | # print(f"x.shape: {x.shape}")
650 | c = t + y + dt # (N, D)
651 | for block in self.blocks:
652 | x = block(x, c) # (N, T, D)
653 |
654 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
655 | x = self.unpatchify(x) # (N, out_channels, H, W)
656 | return x
657 |
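    # Shape sketch for a forward call (assuming the DiT-B/2 setup from train.py:
    # 256x256 images encoded by the SD VAE into latents of shape (N, 4, 32, 32)):
    #   x: (N, 4, 32, 32) noisy latents, t: (N,) timesteps in [0, 1],
    #   dt: (N,) log2 step-size indices, y: (N,) class labels
    #   -> output: (N, out_channels, 32, 32) predicted velocity
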
658 | def training_step(self, batch, batch_idx):
659 |
660 | images, labels = batch
661 |
662 | labels = torch.ones_like(labels)
663 |
664 | with torch.no_grad():
665 | latents = self.vae.encode(images).latent_dist.sample()
666 | latents = latents * self.vae.config.scaling_factor
667 |
668 | if self.training_type=="naive":
669 | x_t, v_t, t, dt_base, labels_dropped = create_targets_naive(latents, labels, self)
670 | elif self.training_type=="shortcut":
671 | x_t, v_t, t, dt_base, labels_dropped = create_targets(latents, labels, self)
672 |
673 | v_prime = self.forward(x_t, t, dt_base, labels)
674 |
675 | loss = self.loss_fn(v_prime, v_t)
676 | self.log("train_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
677 |
678 | return loss
679 |
680 | def validation_step(self, batch, batch_idx):
681 |
682 | images, labels_real = batch
683 |
684 | labels_uncond = torch.ones_like(labels_real, dtype=torch.int32) * NUM_CLASSES
685 |
686 |
687 | with torch.no_grad():
688 | latents = self.vae.encode(images).latent_dist.sample()
689 | latents = latents * self.vae.config.scaling_factor
690 |
691 |
692 | # normalize to [0,255] range
693 | images = 255 * ((images - torch.min(images)) / (torch.max(images) - torch.min(images) + 1e-8))
694 |
695 | # sample noise
696 | eps_i = torch.randn_like(latents).to(self.device)
697 |
698 | # denoise_timesteps_list = [1, 2, 4, 8, 16, 32, 128]
699 |
700 | for i, denoise_timesteps in enumerate(self.denoise_timesteps):
701 |
702 | all_x = []
703 | delta_t = 1.0 / denoise_timesteps # i.e. step size
704 | # self.fid.reset()
705 |
706 | x = eps_i.to(self.device)
707 |
708 | for ti in range(denoise_timesteps):
709 |                 # t should be in range [0, 1]
710 | t = ti / denoise_timesteps
711 |
712 | t_vector = torch.full((eps_i.shape[0],), t).to(self.device)
713 | # t_vector = torch.full((eps.shape[0],), t)
714 | dt_base = torch.ones_like(t_vector).to(self.device) * math.log2(denoise_timesteps)
715 | # dt_base = torch.ones_like(t_vector) * math.log2(denoise_timesteps)
716 |
717 |
718 | # only on last step!
719 | # if i == len(denoise_timesteps_list)-1:
720 | # with torch.no_grad():
721 | # v_cond = self.forward(eps, t_vector, dt_base, labels_real)
722 | # v_uncond = self.forward(eps, t_vector, dt_base, labels_uncond)
723 |
724 | # v = v_uncond + CFG_SCALE * (v_cond - v_uncond)
725 | # else:
726 | # # t is same for all latents
727 | # with torch.no_grad():
728 | # v = self.forward(eps, t_vector, dt_base, labels_real)
729 |
730 | with torch.no_grad():
731 | v = self.forward(x, t_vector, dt_base, labels_real)
732 |
733 | x = x + v*delta_t
734 |
735 | # log 8 steps
736 | if denoise_timesteps <= 8 or ti % (denoise_timesteps//8) ==0 or ti == denoise_timesteps-1:
737 | with torch.no_grad():
738 | decoded = self.vae.decode(x/self.vae.config.scaling_factor)[0]
739 |
740 | decoded = decoded.to("cpu")
741 |
742 | all_x.append(decoded)
743 |
744 | if(len(all_x)==9):
745 | all_x = all_x[1:]
746 |
747 | # estimate FID metric
748 | decoded_denormalized = 255 * ((decoded - torch.min(decoded)) / (torch.max(decoded)-torch.min(decoded)+1e-8))
749 |
750 | # generated images
751 | self.fids[i].update(images.to(torch.uint8).to(self.device), real=True)
752 | self.fids[i].update(decoded_denormalized.to(torch.uint8).to(self.device), real=False)
753 |
754 |
755 | # fid_val = fid.compute()
756 |
757 | # log only a single batch of generated images and only on first device
758 | if self.trainer.is_global_zero and batch_idx == 0:
759 |
760 | all_x = torch.stack(all_x)
761 |
762 | def process_img(img):
763 | # normalize in range [0,1]
764 | img = img*0.5 + 0.5
765 | img = torch.clip(img, 0, 1)
766 | img = img.permute(1,2,0)
767 | return img
768 |
769 | fig, axs = plt.subplots(8, 8, figsize=(30,30))
770 | for t in range(min(8, all_x.shape[0])):
771 | for j in range(8):
772 | axs[t, j].imshow(process_img(all_x[t, j]), vmin=0, vmax=1)
773 |
774 |
775 | fig.savefig(f"log_images3/epoch:{self.trainer.current_epoch}_denoise_timesteps:{denoise_timesteps}.png")
776 |
777 | # self.logger.experiment.add_figure(f"epoch:{self.trainer.current_epoch}_denoise_timesteps:{denoise_timesteps}", fig, global_step=self.global_step)
778 | # self.logger.experiment.add_figure(f"denoise_timesteps:{denoise_timesteps}", fig, global_step=self.global_step)
779 |
780 | self.logger.experiment.log({f"denoise_timesteps:{denoise_timesteps}" : [wandb.Image(fig)]})
781 | # log_image(key=f"denoise_timesteps:{denoise_timesteps}", images=wandb.Image(fig))
782 |
783 | plt.close()
784 |
785 | return 0
786 |
787 |
788 | # def validation_epoch_start(self,):
789 |
790 | # if self.trainer.is_global_zero:
791 |
792 | # self.fid.reset()
793 |
794 | def on_validation_epoch_end(self):
795 | # if self.trainer.is_global_zero:
796 | for i in range(len(self.fids)):
797 | denoise_timesteps_i = self.denoise_timesteps[i]
798 |
799 | # Compute FID for the current timestep
800 | fid_val_i = self.fids[i].compute()
801 | self.fids[i].reset()
802 |
803 | # print(f"i: {i} | fid_val_i: {fid_val_i}")
804 |
805 | # Log the FID value
806 | self.log(f"[FID] denoise_steps: {denoise_timesteps_i}", fid_val_i, on_epoch=True, on_step=False, sync_dist=True)
807 |
808 |
809 | # def on_validation_epoch_end(self, outputs):
810 |
811 | # if self.trainer.is_global_zero:
812 |
813 | # for i in range(len(self.fids)):
814 |
815 | # denoise_timesteps_i = self.denoise_timesteps[i]
816 |
817 | # fid_val_i = self.fids[i].compute()
818 |
819 | # self.fids[i].reset()
820 |
821 | # self.log(f"steps: {denoise_timesteps_i}", fid_val_i, on_epoch=True, on_step=False)
822 |
823 | # fid_val = self.fid.compute()
824 | # self.fid.reset()
825 |
826 | def configure_optimizers(self):
827 |
828 | optimizer = torch.optim.AdamW(self.parameters(), lr=LEARNING_RATE, weight_decay=0.1)
829 |
830 | return optimizer
831 |
832 |
833 |
834 | def forward_with_cfg(self, x, t, y, cfg_scale):
835 | """
836 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
837 | """
838 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
839 | half = x[: len(x) // 2]
840 | combined = torch.cat([half, half], dim=0)
841 | model_out = self.forward(combined, t, y)
842 | # For exact reproducibility reasons, we apply classifier-free guidance on only
843 | # three channels by default. The standard approach to cfg applies it to all channels.
844 | # This can be done by uncommenting the following line and commenting-out the line following that.
845 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
846 | eps, rest = model_out[:, :3], model_out[:, 3:]
847 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
848 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
849 | eps = torch.cat([half_eps, half_eps], dim=0)
850 | return torch.cat([eps, rest], dim=1)
851 |
852 |
853 | #################################################################################
854 | # Sine/Cosine Positional Embedding Functions #
855 | #################################################################################
856 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
857 |
858 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
859 | """
860 | grid_size: int of the grid height and width
861 | return:
862 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
863 | """
864 | grid_h = np.arange(grid_size, dtype=np.float32)
865 | grid_w = np.arange(grid_size, dtype=np.float32)
866 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
867 | grid = np.stack(grid, axis=0)
868 |
869 | grid = grid.reshape([2, 1, grid_size, grid_size])
870 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
871 | if cls_token and extra_tokens > 0:
872 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
873 | return pos_embed
874 |
875 |
876 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
877 | assert embed_dim % 2 == 0
878 |
879 | # use half of dimensions to encode grid_h
880 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
881 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
882 |
883 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
884 | return emb
885 |
886 |
887 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
888 | """
889 | embed_dim: output dimension for each position
890 | pos: a list of positions to be encoded: size (M,)
891 | out: (M, D)
892 | """
893 | assert embed_dim % 2 == 0
894 | omega = np.arange(embed_dim // 2, dtype=np.float64)
895 | omega /= embed_dim / 2.
896 | omega = 1. / 10000**omega # (D/2,)
897 |
898 | pos = pos.reshape(-1) # (M,)
899 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
900 |
901 | emb_sin = np.sin(out) # (M, D/2)
902 | emb_cos = np.cos(out) # (M, D/2)
903 |
904 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
905 | return emb
906 |
907 |
908 | #################################################################################
909 | # DiT Configs #
910 | #################################################################################
911 |
912 | def DiT_XL_2(**kwargs):
913 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
914 |
915 | def DiT_XL_4(**kwargs):
916 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
917 |
918 | def DiT_XL_8(**kwargs):
919 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
920 |
921 | def DiT_L_2(**kwargs):
922 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
923 |
924 | def DiT_L_4(**kwargs):
925 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
926 |
927 | def DiT_L_8(**kwargs):
928 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
929 |
930 | # default model used in shortcut
931 | def DiT_B_2(**kwargs):
932 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
933 |
934 | def DiT_B_4(**kwargs):
935 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
936 |
937 | def DiT_B_8(**kwargs):
938 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
939 |
940 | def DiT_S_2(**kwargs):
941 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
942 |
943 | def DiT_S_4(**kwargs):
944 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
945 |
946 | def DiT_S_8(**kwargs):
947 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
948 |
949 |
950 | DiT_models = {
951 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
952 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
953 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
954 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
955 | }
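
# Usage sketch (constructor arguments as used in train.py; illustrative only):
#   model = DiT_models['DiT-B/2'](learn_sigma=False, num_classes=1,
#                                 class_dropout_prob=1.0, training_type="shortcut")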
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:1.png
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:128.png
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:16.png
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:2.png
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:4.png
--------------------------------------------------------------------------------
/samples/epoch:299_denoise_timesteps:8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smileyenot983/shortcut_pytorch/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/samples/epoch:299_denoise_timesteps:8.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import pandas as pd
4 | import io
5 | import os
6 | from copy import deepcopy
7 | from collections import OrderedDict
8 | import math
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torchvision
14 | from torch.utils.data import Dataset, DataLoader
15 | from torchvision import transforms
16 | from torchvision.utils import save_image
17 |
18 | from torchmetrics.image.fid import FrechetInceptionDistance
19 |
20 | from diffusers.models import AutoencoderKL
21 |
22 | import pytorch_lightning as pl
23 | from pytorch_lightning.callbacks import ModelCheckpoint
24 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
25 |
26 |
27 | from model import DiT_B_2, EMACallback
28 | from utils import create_targets, create_targets_naive
29 |
30 | def count_parameters(model):
31 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
32 |
33 | LEARNING_RATE = 0.0001
34 | BATCH_SIZE = 64
35 | EVAL_SIZE = 8
36 |
37 | NUM_CLASSES = 1
38 | CLASS_DROPOUT_PROB = 1.0
39 |
40 | N_EPOCHS = 100
41 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
42 |
43 | LOG_EVERY = 3
44 | CFG_SCALE = 0.0
45 |
46 |
47 | class CelebaHQDataset(Dataset):
48 | def __init__(self, parquet_path, transform=None, size=256):
49 |
50 |
51 | parquet_names = os.listdir(parquet_path)
52 | parquet_paths = [os.path.join(parquet_path, parquet_name) for parquet_name in parquet_names]
53 |
54 | if len(parquet_paths) < 1:
55 |             raise FileNotFoundError(f"no parquet files found in {parquet_path}")
56 |
57 | parquets = []
58 |
59 | for i in range(len(parquet_paths)):
60 | parquet_i = pd.read_parquet(parquet_paths[i])
61 | parquets.append(parquet_i)
62 |
63 | self.data = pd.concat(parquets, axis=0)
64 | self.size = size
65 | self.transform = transform
66 |
67 | # print(f"self.data: {self.data}")
68 |
69 | def __len__(self):
70 |
71 | return self.data.shape[0]
72 |
73 | def __getitem__(self, idx):
74 |
75 | data_i = self.data.iloc[[idx]]
76 |
77 | image_i = Image.open(io.BytesIO(data_i['image.bytes'].item())).resize((self.size, self.size))
78 |
79 | if self.transform is not None:
80 | image_i = self.transform(image_i)
81 |
82 | label_i = data_i['label'].item()
83 |
84 | return image_i, label_i
85 |
86 | @torch.no_grad()
87 | def update_ema(ema_model, model, decay=0.9999):
88 | """
89 | Step the EMA model towards the current model.
90 | """
91 | ema_params = OrderedDict(ema_model.named_parameters())
92 | model_params = OrderedDict(model.named_parameters())
93 |
94 | for name, param in model_params.items():
95 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
96 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
97 |
98 | def train_epoch(train_dataloader, dit, ema, vae, optimizer):
99 |
100 | loss_fn = torch.nn.MSELoss()
101 |
102 | dit.train()
103 |
104 | total_loss = 0.0
105 | for batch, (images, labels) in enumerate(tqdm(train_dataloader)):
106 |
107 | images, labels = images.to(DEVICE), labels.to(DEVICE)
108 |
109 | # print(f"torch.mean(images): {torch.mean(images)} | torch.std(images): {torch.std(images)}")
110 |
111 | with torch.no_grad():
112 |
113 | latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
114 |
115 |         # print(f"latents.shape: {latents.shape}")
116 | # print(f"labels.shape: {labels.shape}")
117 |
118 | x_t, v_t, t, dt_base, labels_dropped = create_targets(latents, labels, dit)
119 |
120 | # print(f"x_t.shape: {x_t.shape}")
121 | # print(f"v_t.shape: {v_t.shape}")
122 | # print(f"t.shape: {t.shape}")
123 | # print(f"dt_base.shape: {dt_base.shape}")
124 | # print(f"labels_dropped.shape: {labels_dropped.shape}")
125 |
126 | # exit()
127 |
128 | v_prime = dit(x_t, t, dt_base, labels)
129 |
130 | # print(f"v_prime.shape: {v_prime.shape}")
131 | # print(f"v_prime: {v_prime}")
132 |
133 | loss = loss_fn(v_prime, v_t)
134 |
135 | total_loss += loss.item()
136 | optimizer.zero_grad()
137 |
138 | loss.backward()
139 |
140 | optimizer.step()
141 |
142 | update_ema(ema, dit)
143 |
144 | return total_loss / len(train_dataloader)
145 |
146 | def evaluate(dit, vae, val_dataloader, epoch):
147 |
148 | dit.eval()
149 |
150 | images, labels_real = next(iter(val_dataloader))
151 | images, labels_real = images[:EVAL_SIZE].to(DEVICE), labels_real[:EVAL_SIZE].to(DEVICE)
152 | with torch.no_grad():
153 | latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
154 |
155 |
156 | labels_uncond = torch.ones_like(labels_real, dtype=torch.int32) * NUM_CLASSES
157 |
158 | fid = FrechetInceptionDistance().to(DEVICE)
159 |
160 |
161 | images = 255 * ((images - torch.min(images)) / (torch.max(images) - torch.min(images) + 1e-8))
162 |
163 | # all start from same noise
164 | eps = torch.randn_like(latents).to(DEVICE)
165 |
166 |
167 |
168 | denoise_timesteps_list = [1, 2, 4, 8, 16, 32, 128]
169 | # run at varying timesteps
170 | for i, denoise_timesteps in enumerate(denoise_timesteps_list):
171 | all_x = []
172 | delta_t = 1.0 / denoise_timesteps # i.e. step size
173 | fid.reset()
174 |
175 | x = eps
176 |
177 | for ti in range(denoise_timesteps):
178 |             # t should be in range [0, 1]
179 | t = ti / denoise_timesteps
180 |
181 | t_vector = torch.full((eps.shape[0],), t).to(DEVICE)
182 | dt_base = torch.ones_like(t_vector).to(DEVICE) * math.log2(denoise_timesteps)
183 |
184 |
185 | # if i == len(denoise_timesteps_list)-1:
186 | # with torch.no_grad():
187 | # v_cond = dit.forward(x, t_vector, dt_base, labels_real)
188 | # v_uncond = dit.forward(x, t_vector, dt_base, labels_uncond)
189 |
190 | # v = v_uncond + CFG_SCALE * (v_cond - v_uncond)
191 | # else:
192 | # # t is same for all latents
193 | # with torch.no_grad():
194 | # v = dit.forward(x, t_vector, dt_base, labels_real)
195 |
196 | with torch.no_grad():
197 | v = dit.forward(x, t_vector, dt_base, labels_real)
198 |
199 | x = x + v * delta_t
200 |
201 | if denoise_timesteps <= 8 or ti % (denoise_timesteps//8) == 0 or ti == denoise_timesteps-1:
202 |
203 | with torch.no_grad():
204 | decoded = vae.decode(x/vae.config.scaling_factor)[0]
205 |
206 | decoded = decoded.to("cpu")
207 |
208 | all_x.append(decoded)
209 |
210 |
211 | if(len(all_x)==9):
212 | all_x = all_x[1:]
213 |
214 | # estimate FID metric
215 | # images_fake = torch.randint(low=0, high=255, size=images.shape).to(torch.uint8).to(DEVICE)
216 | decoded_denormalized = 255 * ((decoded - torch.min(decoded)) / (torch.max(decoded)-torch.min(decoded)+1e-8))
217 |
218 | # generated images
219 | fid.update(images.to(torch.uint8), real=True)
220 | fid.update(decoded_denormalized.to(torch.uint8).to(DEVICE), real=False)
221 | fid_val = fid.compute()
222 | print(f"denoise_timesteps: {denoise_timesteps} | fid_val: {fid_val}")
223 |
224 | all_x = torch.stack(all_x)
225 |
226 | def process_img(img):
227 | # normalize in range [0,1]
228 | img = img * 0.5 + 0.5
229 | img = torch.clip(img, 0, 1)
230 | img = img.permute(1,2,0)
231 | return img
232 |
233 | fig, axs = plt.subplots(8, 8, figsize=(30,30))
234 | for t in range(min(8, all_x.shape[0])):
235 | for j in range(8):
236 | axs[t, j].imshow(process_img(all_x[t, j]), vmin=0, vmax=1)
237 |
238 | fig.savefig(f"log_images_tvanilla/epoch:{epoch}_denoise_timesteps:{denoise_timesteps}.png")
239 | # if i == len(denoise_timesteps_list)-1:
240 | # fig.savefig(f"log_images_tvanilla/epoch:{epoch}_cfg.png")
241 | # else:
242 | # fig.savefig(f"log_images_tvanilla/epoch:{epoch}_denoise_timesteps:{denoise_timesteps}.png")
243 |
244 | plt.close()
245 |
246 | def requires_grad(model, flag=True):
247 | """
248 | Set requires_grad flag for all parameters in a model.
249 | """
250 | for p in model.parameters():
251 | p.requires_grad = flag
252 |
253 |
254 | # single gpu: [00:55<10:24, 0.64it/s, v_num=5]
255 | # 3 gpus: [03:35<00:00, 0.68it/s, v_num=7]
256 | def main_lightning():
257 |
258 |     # parameter count should match the flax implementation (83,653,863 params)
259 |
260 |
261 |
262 |
263 |
264 |
265 | train_transform = transforms.Compose([
266 | transforms.ToTensor(),
267 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
268 | ])
269 |
270 | # 28000 images
271 | train_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/train', transform=train_transform)
272 | # 2000 images
273 | val_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/val', transform=train_transform)
274 |
275 |
276 |
277 | print(f"len(train_dataset): {len(train_dataset)}")
278 | print(f"len(val_dataset): {len(val_dataset)}")
279 |
280 |
281 |
282 | # good option is 2*num_gpus
283 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
284 | val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
285 |
286 |
287 | images_i, labels_i = next(iter(train_dataloader))
288 |
289 |     # 4 latent channels from the SD VAE; spatial size is H/8 x W/8
290 | latent_shape = (BATCH_SIZE, 4, images_i.shape[2]//8, images_i.shape[2]//8)
291 |
292 | dit = DiT_B_2(learn_sigma=False,
293 | num_classes=NUM_CLASSES,
294 | class_dropout_prob=CLASS_DROPOUT_PROB,
295 | lightning_mode=True,
296 | latent_shape=latent_shape,
297 | training_type="shortcut")
298 |
299 | # if load from checkpoint:
300 | # checkpoint_path = "/workspace/shortcut_pytorch/tb_logs/shortcut_model/version_0/checkpoints/last.ckpt"
301 | # checkpoint = torch.load(checkpoint_path)
302 | # dit.load_state_dict(checkpoint['state_dict'])
303 |
304 | print(f"count_parameters(dit): {count_parameters(dit)}")
305 |
306 | callbacks = []
307 |
308 | # Define a checkpoint callback
309 | checkpoint_callback = ModelCheckpoint(
310 | filename="model-{epoch:02d}",
311 | save_last=True,
312 | )
313 | ema_callback = EMACallback(decay=0.999)
314 |
315 | callbacks.append(checkpoint_callback)
316 | callbacks.append(ema_callback)
317 |
318 | # logger = TensorBoardLogger("tb_logs", name="shortcut_model")
319 | logger = WandbLogger("shortcut_model")
320 |
321 | trainer = pl.Trainer(max_epochs=1000,
322 | accelerator="gpu",
323 | num_sanity_val_steps=1,
324 | check_val_every_n_epoch=50,
325 | limit_val_batches=1.0,
326 | devices=[0, 1, 2, 3],
327 | strategy="ddp_find_unused_parameters_false",
328 | callbacks=callbacks,
329 | logger=logger)
330 |
331 | trainer.fit(model=dit, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
332 |
333 |
334 | # single epoch takes: [15:45<00:00, 2.16s/it]
335 | def main():
336 |
337 | train_transform = transforms.Compose([
338 | transforms.ToTensor(),
339 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
340 | ])
341 |
342 | train_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/train', transform=train_transform)
343 | val_dataset = CelebaHQDataset('/workspace/shortcut_pytorch/celeba-hq/data/val', transform=train_transform)
344 |
345 | print(f"len(train_dataset): {len(train_dataset)}")
346 | print(f"len(val_dataset): {len(val_dataset)}")
347 |
348 | # good option is 2*num_gpus
349 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
350 | val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
351 |
352 | images_i, labels_i = next(iter(train_dataloader))
353 |
354 |     # 4 latent channels from the SD VAE; spatial size is H/8 x W/8
355 | latent_shape = (BATCH_SIZE, 4, images_i.shape[2]//8, images_i.shape[2]//8)
356 |
357 | dit = DiT_B_2(learn_sigma=False,
358 | num_classes=NUM_CLASSES,
359 | class_dropout_prob=CLASS_DROPOUT_PROB,
360 | latent_shape=latent_shape,
361 | training_type="shortcut").to(DEVICE)
362 |
363 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
364 | vae = vae.eval()
365 | vae.requires_grad_(False)
366 |
367 | print(f"count_parameters(dit): {count_parameters(dit)}")
368 | ema = deepcopy(dit).to(DEVICE)
369 | ema.requires_grad_(False)
370 |
371 |     # optionally resume from a previously saved checkpoint
372 |     checkpoint_path = "dit_saved.pth"
373 |     if os.path.exists(checkpoint_path):
374 |         dit.load_state_dict(torch.load(checkpoint_path))
375 |
376 |
377 | optimizer = torch.optim.AdamW(dit.parameters(), lr=LEARNING_RATE, weight_decay=0.1)
378 |
379 | update_ema(ema, dit, decay=0)
380 | ema.eval()
381 |
382 | # evaluate(dit, vae, val_dataloader, 0)
383 |
384 | for i in range(N_EPOCHS):
385 |
386 |         epoch_loss = train_epoch(train_dataloader, dit, ema, vae, optimizer)
387 |         print(f"epoch_loss: {epoch_loss}")
388 | 
389 |
390 | if i%LOG_EVERY == 0 and i > 0:
391 | evaluate(dit, vae, val_dataloader, i)
392 |
393 | torch.save(dit.state_dict(), "dit_saved.pth")
394 |
395 |
396 | if __name__ == "__main__":
397 | # main()
398 | main_lightning()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import os
4 | import pytorch_lightning as pl
5 | from pytorch_lightning import Callback
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
8 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
9 | from pytorch_lightning.utilities.types import STEP_OUTPUT
10 |
11 | from typing import Any, Dict, List, Optional
12 |
13 |
14 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | BOOTSTRAP_EVERY = 8
16 | DENOISE_TIMESTEPS = 128
17 | CLASS_DROPOUT_PROB = 1.0
18 | NUM_CLASSES = 1
19 | BATCH_SIZE = 64
20 |
21 | # create a training batch that mixes bootstrap (self-consistency) targets with flow-matching targets, using different timesteps t and step sizes dt
22 | def create_targets(images, labels, model):
23 |
24 | model.eval()
25 |
26 | current_batch_size = images.shape[0]
27 |
28 | FORCE_T = -1
29 | FORCE_DT = -1
30 |
31 | # 1. create step sizes dt
32 | bootstrap_batch_size = current_batch_size // BOOTSTRAP_EVERY #=8
33 | log2_sections = int(math.log2(DENOISE_TIMESTEPS))
34 | # print(f"log2_sections: {log2_sections}")
35 | # print(f"bootstrap_batch_size: {bootstrap_batch_size}")
36 |
37 | dt_base = torch.repeat_interleave(log2_sections - 1 - torch.arange(log2_sections), bootstrap_batch_size // log2_sections)
38 | # print(f"dt_base: {dt_base}")
39 |
40 | dt_base = torch.cat([dt_base, torch.zeros(bootstrap_batch_size-dt_base.shape[0],)])
41 | # print(f"dt_base: {dt_base}")
42 |
43 |
44 |
45 | force_dt_vec = torch.ones(bootstrap_batch_size) * FORCE_DT
46 | dt_base = torch.where(force_dt_vec != -1, force_dt_vec, dt_base).to(model.device)
47 |     dt = 1 / (2 ** (dt_base)) # e.g. [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 1] for 128 denoise timesteps
48 | # print(f"dt: {dt}")
49 |
50 | dt_base_bootstrap = dt_base + 1
51 | dt_bootstrap = dt / 2 # [0.0078125 0.015625 0.03125 0.0625 0.125 0.25 0.5 0.5]
52 | # print(f"dt_bootstrap: {dt_bootstrap}")
53 |
54 | # 2. sample timesteps t
55 | dt_sections = 2**dt_base
56 |
57 | # print(f"dt_sections: {dt_sections}")
58 |
59 | t = torch.cat([
60 | torch.randint(low=0, high=int(val.item()), size=(1,)).float()
61 | for val in dt_sections
62 | ]).to(model.device)
63 |
64 | # print(f"t[randint]: {t}")
65 | t = t / dt_sections
66 | # print(f"t[normalized]: {t}")
67 |
68 | force_t_vec = torch.ones(bootstrap_batch_size, dtype=torch.float32).to(model.device) * FORCE_T
69 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device)
70 | t_full = t[:, None, None, None]
71 |
72 | # print(f"t_full: {t_full}")
73 |
74 | # 3. generate bootstrap targets:
75 | x_1 = images[:bootstrap_batch_size]
76 | x_0 = torch.randn_like(x_1)
77 |
78 |     # interpolate between noise x_0 and data x_1 to get x_t at timestep t
79 | x_t = (1 - (1-1e-5) * t_full)*x_0 + t_full*x_1
80 |
81 | bst_labels = labels[:bootstrap_batch_size]
82 |
83 |
84 | with torch.no_grad():
85 | v_b1 = model(x_t, t, dt_base_bootstrap, bst_labels)
86 |
87 | t2 = t + dt_bootstrap
88 | x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
89 | x_t2 = torch.clip(x_t2, -4, 4)
90 |
91 | with torch.no_grad():
92 | v_b2 = model(x_t2, t2, dt_base_bootstrap, bst_labels)
93 |
94 | v_target = (v_b1 + v_b2) / 2
95 |
96 | v_target = torch.clip(v_target, -4, 4)
97 |
98 | bst_v = v_target
99 | bst_dt = dt_base
100 | bst_t = t
101 | bst_xt = x_t
102 | bst_l = bst_labels
103 |
104 | # 4. generate flow-matching targets
105 |
106 | labels_dropout = torch.bernoulli(torch.full(labels.shape, CLASS_DROPOUT_PROB)).to(model.device)
107 | labels_dropped = torch.where(labels_dropout.bool(), NUM_CLASSES, labels)
108 |
109 | # sample t(normalized)
110 | t = torch.randint(low=0, high=DENOISE_TIMESTEPS, size=(images.shape[0],), dtype=torch.float32)
111 | # print(f"t: {t}")
112 | t /= DENOISE_TIMESTEPS
113 | # print(f"t: {t}")
114 | force_t_vec = torch.ones(images.shape[0]) * FORCE_T
115 | # force_t_vec = torch.full((images.shape[0],), FORCE_T, dtype=torch.float32)
116 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device)
117 | # t_full = t.view(-1, 1, 1, 1)
118 | t_full = t[:, None, None, None]
119 |
120 | # print(f"t_full: {t_full}")
121 |
122 | # sample flow pairs x_t, v_t
123 | x_0 = torch.randn_like(images).to(model.device)
124 | x_1 = images
125 | x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
126 | v_t = x_1 - (1 - 1e-5) * x_0
127 |
128 | dt_flow = int(math.log2(DENOISE_TIMESTEPS))
129 | dt_base = (torch.ones(images.shape[0], dtype=torch.int32) * dt_flow).to(model.device)
130 |
131 | # 5. merge flow and bootstrap
132 | bst_size = current_batch_size // BOOTSTRAP_EVERY
133 | bst_size_data = current_batch_size - bst_size
134 |
135 | # print(f"bst_size: {bst_size}")
136 | # print(f"bst_size_data: {bst_size_data}")
137 |
138 | x_t = torch.cat([bst_xt, x_t[:bst_size_data]], dim=0)
139 | t = torch.cat([bst_t, t[:bst_size_data]], dim=0)
140 |
141 | dt_base = torch.cat([bst_dt, dt_base[:bst_size_data]], dim=0)
142 | v_t = torch.cat([bst_v, v_t[:bst_size_data]], dim=0)
143 | labels_dropped = torch.cat([bst_l, labels_dropped[:bst_size_data]], dim=0)
144 |
145 | return x_t, v_t, t, dt_base, labels_dropped
146 |
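# Worked example of the split above with the defaults in this file (BATCH_SIZE=64,
# BOOTSTRAP_EVERY=8, DENOISE_TIMESTEPS=128): 64 // 8 = 8 samples receive bootstrap
# (self-consistency) targets built from two half-steps of the model itself, and the
# remaining 56 samples receive plain flow-matching targets v_t = x_1 - (1 - 1e-5) * x_0;
# the two groups are concatenated so every returned tensor has batch dimension 64.
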
147 | def create_targets_naive(images, labels, model):
148 |
149 | model.eval()
150 |
151 | current_batch_size = images.shape[0]
152 |
153 | FORCE_T = -1
154 | FORCE_DT = -1
155 |
156 | labels_dropout = torch.bernoulli(torch.full(labels.shape, CLASS_DROPOUT_PROB)).to(model.device)
157 | labels_dropped = torch.where(labels_dropout.bool(), NUM_CLASSES, labels)
158 |
159 | # sample t(normalized)
160 | t = torch.randint(low=0, high=DENOISE_TIMESTEPS, size=(images.shape[0],), dtype=torch.float32)
161 | # print(f"t: {t}")
162 | t /= DENOISE_TIMESTEPS
163 | # print(f"t: {t}")
164 | force_t_vec = torch.ones(images.shape[0]) * FORCE_T
165 | # force_t_vec = torch.full((images.shape[0],), FORCE_T, dtype=torch.float32)
166 | t = torch.where(force_t_vec != -1, force_t_vec, t).to(model.device)
167 | # t_full = t.view(-1, 1, 1, 1)
168 | t_full = t[:, None, None, None]
169 |
170 |
171 | x_0 = torch.randn_like(images).to(model.device)
172 | x_1 = images
173 | x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
174 | v_t = x_1 - (1 - 1e-5) * x_0
175 |
176 | dt_flow = int(math.log2(DENOISE_TIMESTEPS))
177 | dt_base = (torch.ones(images.shape[0], dtype=torch.int32) * dt_flow).to(model.device)
178 |
179 | return x_t, v_t, t, dt_base, labels_dropped
180 |
181 |
182 |
183 |
184 |
185 |
--------------------------------------------------------------------------------