├── .gitignore
├── README.hf.md
├── README.md
├── ae.py
├── challenge.ipynb
├── contents
│   ├── appg.png
│   ├── buildings.jpg
│   ├── chinatown.jpg
│   ├── cosplayers.jpg
│   ├── flowers.jpg
│   ├── lavender.jpg
│   ├── logo.png
│   ├── magazines.jpg
│   ├── origin.png
│   ├── randomwoman.jpeg
│   ├── recon.png
│   ├── ti2.png
│   └── ti2_mask.png
├── inference.ipynb
├── launcher.sh
├── scripts
│   └── launch_hdr.sh
├── sweep.sh
├── tae.py
├── tester.py
├── tester_upload.sh
├── unit_activation_reinitializer.py
├── utils.py
└── vae_trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb
2 | *.torrent
3 | celebvq
4 | __pycache__
5 | vgg.pth
6 | *_tester.ipynb
7 | ckpt
8 | *.pt
9 | *.pth
10 | data
11 | mnist_data
--------------------------------------------------------------------------------
/README.hf.md:
--------------------------------------------------------------------------------
1 | ---
2 | license: apache-2.0
3 | ---
4 |
5 | # Equivariant 16ch, f8 VAE
6 |
7 |
8 |
9 | AuraEquiVAE is a novel autoencoder that addresses multiple problems of conventional VAEs. First, unlike traditional VAEs, whose learned log-variance tends to collapse to very small values, this model admits large noise in the latent space.
10 | Additionally, unlike traditional VAEs, the latent space is equivariant under `Z_2 x Z_2` group operations (horizontal / vertical flip).
11 |
12 | To understand the equivariance, we apply suitable group actions to the latent space both globally and locally. The latent is represented as `Z = (z_1, ..., z_n)`, and we perform a global permutation action `g_global` on this tuple, where the acting group is isomorphic to `Z_2 x Z_2`.
13 | We also apply a local action `g_local` to the individual `z_i` elements, whose acting group is likewise isomorphic to `Z_2 x Z_2`.
14 |
15 | In our specific case, `g_global` corresponds to spatial flips of the latent grid, while `g_local` corresponds to sign flips on specific latent channels. Using 2 channels of sign flips for each of the horizontal and vertical directions was chosen empirically.
16 |
17 | The model has been trained using the approach described in [Mastering VAE Training](https://github.com/cloneofsimo/vqgan-training), where detailed explanations for the training process can be found.
18 |
19 | ## How to use
20 |
21 | To use the weights, copy paste the [VAE](https://github.com/cloneofsimo/vqgan-training/blob/03e04401cf49fe55be612d1f568be0110aa0fad1/ae.py) implementation.
22 |
23 | ```python
24 | from ae import VAE
25 | import torch
26 | from PIL import Image
27 | from torchvision import transforms
28 | vae = VAE(
29 | resolution=256,
30 | in_channels=3,
31 | ch=256,
32 | out_ch=3,
33 | ch_mult=[1, 2, 4, 4],
34 | num_res_blocks=2,
35 | z_channels=16,
36 | # NOTE: the three flags below are required by this version of ae.py; their values here are assumptions -- set them to match the released checkpoint.
37 | use_attn=True, decoder_also_perform_hr=False, use_wavelet=False,
38 | ).cuda().bfloat16()
37 |
38 | from safetensors.torch import load_file
39 | state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt")
40 | vae.load_state_dict(state_dict)
41 |
42 | imgpath = 'contents/lavender.jpg'
43 |
44 | img_orig = Image.open(imgpath).convert("RGB")
45 | offset = 128
46 | W = 768
47 | img_orig = img_orig.crop((offset, offset, W + offset, W + offset))
48 | img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda()
49 | img = ((img - 0.5) / 0.5).bfloat16()  # normalize to [-1, 1] and match the model's dtype
50 |
51 | with torch.no_grad():
52 | z = vae.encoder(img)
53 | z = z.clamp(-8.0, 8.0) # this is latent!!
54 |
55 | # flip horizontal
56 | z = torch.flip(z, [-1]) # this corresponds to g_global
57 | z[:, -4:-2] = -z[:, -4:-2] # this corresponds to g_local
58 |
59 | # flip vertical
60 | z = torch.flip(z, [-2])
61 | z[:, -2:] = -z[:, -2:]
62 |
63 |
64 | with torch.no_grad():
65 | decz = vae.decoder(z) # this is image!
66 |
67 | decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
68 | decimg = (decimg * 255).astype('uint8')
69 | decimg = Image.fromarray(decimg) # PIL image.
70 |
71 | ```
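
To sanity-check the claimed equivariance, you can compare the latent of a flipped image against the group action applied to the latent of the original (reusing `vae` and `img` from the snippet above; this is an illustrative check, not part of the repo, and the match is only approximate):

```python
# g_global followed by g_local for a horizontal flip, as above
def apply_group_action(z):
    z = torch.flip(z, [-1])      # g_global: flip the latent grid horizontally
    z[:, -4:-2] = -z[:, -4:-2]   # g_local: sign-flip the designated channels
    return z

with torch.no_grad():
    z_orig = vae.encoder(img).clamp(-8.0, 8.0)
    z_flip = vae.encoder(torch.flip(img, [-1])).clamp(-8.0, 8.0)
    err = ((apply_group_action(z_orig) - z_flip) ** 2).mean()
    print(f"equivariance error: {err.item():.5f}")  # should be small
```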
72 |
73 | ## Citation
74 |
75 | If you find this model useful, please cite:
76 |
77 | ```
78 | @misc{ryu2024vqgantraining,
79 | author = {Simo Ryu},
80 | title = {Training VQGAN and VAE, with detailed explanation},
81 | year = {2024},
82 | publisher = {GitHub},
83 | journal = {GitHub repository},
84 | howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
85 | }
86 | ```
87 |
88 |
89 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VAE Trainer
2 |
3 | The famous VAEs of latent diffusion models, such as Stable Diffusion, FLUX, SORA, etc. How are they trained? This is my attempt to write a distributed VAE trainer. It is largely based on [LDM's VAE](https://arxiv.org/abs/2112.10752).
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Details
11 |
12 | The VAE architecture is based on the one used in latent diffusion models, with some modifications for improved training stability and performance.
13 |
14 | 1. **Distributed Training**: Utilizes PyTorch's DistributedDataParallel (DDP) for efficient multi-GPU training.
15 |
16 | 2. **GAN Loss**: Incorporates a GAN loss for enhanced image quality. A pretrained VGG16 backbone with simple convolutional classifier heads on its intermediate feature maps serves as the discriminator. A hinge-style loss with a threshold is employed for the GAN loss to stabilize training:
17 |
18 | $$L_{GAN} = \max(0, D(x) - D(\hat{x}) - 0.1)$$
19 |
20 | where $x$ is the input and $\hat{x} = \text{Dec}(\text{Enc}(x))$ is the reconstructed image. This way, the generator (encoder/decoder) is only trained when the discriminator is good enough, which keeps early-stage training more stable.
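
A minimal sketch of this thresholded objective (the function name and signature are illustrative, not the actual API in `vae_trainer.py`):

```python
import torch

def thresholded_gan_loss(d_real: torch.Tensor, d_fake: torch.Tensor, margin: float = 0.1) -> torch.Tensor:
    # Penalize the generator only once the discriminator separates real
    # inputs from reconstructions by more than `margin`.
    return torch.relu(d_real - d_fake - margin).mean()
```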
21 |
22 |
23 | 3. **Perceptual Loss**: Utilizes LPIPS for reconstruction loss.
24 |
25 | 4. **Gradient Normalization**: Implements gradient normalization for stable training, a simpler alternative to adaptive loss-weight rebalancing. This approach normalizes gradients to address potential imbalances between loss components:
26 |
27 | $$ \nabla_{\theta} L = \frac{1}{|\nabla_{X} L_{GAN}|} \nabla_{\theta} L_{GAN} + \frac{1}{|\nabla_{X} L_{percep}|} \nabla_{\theta} L_{percep} $$
28 |
29 | where $\nabla_{X}$ denotes the gradient with respect to input. This method is implemented using a custom autograd function, modifying the backward pass while preserving the forward pass. You can use it as follows:
30 |
31 | ```python
32 | class GradNormFunction(torch.autograd.Function):
33 | @staticmethod
34 | def forward(ctx, x):
35 |
36 | return x.clone()
37 |
38 | @staticmethod
39 | def backward(ctx, grad_output):
40 |
41 | grad_output_norm = torch.linalg.norm(grad_output, dim=0, keepdim=True)
42 |
43 | grad_output_normalized = grad_output / (grad_output_norm + 1e-8)
44 |
45 | return grad_output_normalized
46 |
47 | def gradnorm(x):
48 | return GradNormFunction.apply(x)
49 |
50 | # example:
51 |
52 | x = gradnorm(x)  # gradients flowing back through x are normalized to (approximately) unit norm.
53 |
54 | ```
55 |
56 | 5. **Constant Variance**: We implemented a fixed variance of 0.1, deviating from the conventional learnable variance in VAE architectures. This modification addresses the tendency of VAEs to indiscriminately decrease log variance, regardless of KL divergence penalty coefficients.
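
A minimal sketch of fixed-variance sampling (names are illustrative; the `DiagonalGaussian` module in this snapshot of `ae.py` uses a multiplicative variant with its own constant):

```python
import torch
import torch.nn as nn

class FixedVarianceGaussian(nn.Module):
    """Sample around the encoder output with a fixed, non-learned std."""

    def __init__(self, std: float = 0.1):
        super().__init__()
        self.std = std

    def forward(self, mean: torch.Tensor) -> torch.Tensor:
        # No learned log-variance: the noise scale is a constant hyperparameter.
        return mean + self.std * torch.randn_like(mean)
```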
57 |
58 | 6. **Pooled Mean Squared Error (MSE)**: A pooled MSE was introduced to modify the original VAE reconstruction loss, reducing the model's sensitivity to high-frequency input details.
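
A minimal sketch of a pooled MSE (the pooling factor is an assumption; the actual loss is defined in `vae_trainer.py`):

```python
import torch.nn.functional as F

def pooled_mse(x_hat, x, pool: int = 8):
    # Compare area-pooled versions of the reconstruction and the target, so
    # per-pixel high-frequency mismatches do not dominate the loss.
    return F.mse_loss(F.avg_pool2d(x_hat, pool), F.avg_pool2d(x, pool))
```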
59 |
60 | 7. **Low-Pass Reconstruction Loss**: Our loss function combines LPIPS and MSE components to balance competing behaviors. While MSE and L1 losses often result in blurred outputs due to the inherent risk in predicting high-intensity signals at precise locations, retaining some MSE loss is crucial for accurate color tone calibration. To address this, we found using MSE loss on a low-pass filtered version of the image is helpful.
61 |
62 | This method starts by applying a Laplacian-like filter to the grayscale version of the image to detect high-frequency regions; the MSE loss is then applied only outside those regions. The actual low-pass detection filter is a bit more sophisticated: it first detects the high-frequency regions, blurs the result, and then truncates it into a mask.
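
A rough sketch of such a mask-based loss (the kernel, blur, and threshold values here are illustrative assumptions, not the repo's exact filter):

```python
import torch
import torch.nn.functional as F

def lowpass_masked_mse(x_hat, x, threshold: float = 0.1):
    # Build a mask that is 0 in high-frequency regions and 1 elsewhere,
    # then apply the MSE only where the mask is 1 (the "white" region).
    gray = x.mean(dim=1, keepdim=True)                           # grayscale
    lap = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]],
                       device=x.device, dtype=x.dtype).view(1, 1, 3, 3)
    high_freq = F.conv2d(gray, lap, padding=1).abs()             # edge response
    high_freq = F.avg_pool2d(high_freq, 5, stride=1, padding=2)  # blur
    mask = (high_freq < threshold).float()                       # truncate to a mask
    return ((x_hat - x) ** 2 * mask).mean()
```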
63 |
64 | The following images illustrate an example of the filtered result:
65 |
66 |
67 |
68 |
69 |
70 |
71 | Notice how the high-frequency details are colored black: the white region is where the MSE loss is applied, making the conflict between the LPIPS and MSE losses less severe.
72 |
73 |
74 | ## Files
75 |
76 | - `ae.py`: Contains the VAE architecture implementation. Most of it is based on FLUX's VAE, with some modifications such as constant variance and multi-head attention.
77 |
78 | - `vae_trainer.py`: Main training script with DDP setup.
79 |
80 | - `utils.py`: Utility functions and classes, including LPIPS and PatchDiscriminator.
81 |
82 | ## Usage
83 |
84 | You need a dataset of PNG images in [webdataset](https://github.com/tmbdev/webdataset) format. You can use [img2dataset](https://github.com/rom1504/img2dataset) to download and shard your dataset; a rough sketch of loading such shards is shown below.
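
A rough sketch of loading such shards with the `webdataset` library (the shard pattern and transforms are placeholders; the repo's actual loader is `create_dataloader` in `vae_trainer.py`):

```python
import webdataset as wds
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Placeholder shard pattern; point this at your own img2dataset output.
shards = "data/shard-{00000..00099}.tar"

dataset = (
    wds.WebDataset(shards)
    .decode("pil")        # decode stored images to PIL
    .to_tuple("png")      # each sample carries a single png image
    .map_tuple(transform)
)
loader = wds.WebLoader(dataset, batch_size=32, num_workers=4)
```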
85 |
86 |
87 | To start training, use the following command:
88 |
89 | ```bash
90 | torchrun --nproc_per_node=8 vae_trainer.py
91 | ```
92 |
93 | This will initiate training on 8 GPUs. Adjust the number based on your available hardware.
94 |
95 | ## Configuration
96 |
97 | The trainer supports various configuration options through command-line arguments. Some key parameters include:
98 |
99 | - `--learning_rate_vae`: Learning rate for the VAE.
100 | - `--vae_ch`: Base channel size for the VAE.
101 | - `--vae_ch_mult`: Channel multipliers for the VAE.
102 | - `--do_ganloss`: Flag to enable GAN loss.
103 |
104 | For a full list of options, refer to the `train_ddp` function in `vae_trainer.py`.
105 |
106 | ## Citation
107 |
108 | If you find this project useful, please cite:
109 |
110 | ```
111 | @misc{ryu2024vqgantraining,
112 | author = {Simo Ryu},
113 | title = {Training VQGAN and VAE, with detailed explanation},
114 | year = {2024},
115 | publisher = {GitHub},
116 | journal = {GitHub repository},
117 | howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
118 | }
119 | ```
120 |
121 | ---
122 | The above README was written by Claude.
123 |
--------------------------------------------------------------------------------
/ae.py:
--------------------------------------------------------------------------------
1 | # Taken from FLUX
2 |
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 | from torch import Tensor, nn
10 | from utils import wavelet_transform_multi_channel
11 |
12 |
13 | def swish(x) -> Tensor:
14 | return x * torch.sigmoid(x)
15 |
16 |
17 | # class StandardizedC2d(nn.Conv2d):
18 | # def __init__(self, *args, **kwargs):
19 | # super().__init__(*args, **kwargs)
20 | # self.step = 0
21 |
22 | # def forward(self, input):
23 | # output = super().forward(input)
24 | # # normalize the weights
25 | # if self.step < 1000:
26 | # with torch.no_grad():
27 | # std = output.std().item()
28 | # normalize_term = (std + 1e-6)**(100/(self.step + 100))
29 | # self.step += 1
30 | # self.weight.data.div_(normalize_term)
31 | # self.bias.data.div_(normalize_term)
32 | # output.div_(normalize_term)
33 | #             # sync the weights, broadcast
34 | # torch.distributed.broadcast(self.weight.data, 0)
35 | # torch.distributed.broadcast(self.bias.data, 0)
36 |
37 | # return output
38 | StandardizedC2d = nn.Conv2d
39 |
40 |
41 | class FP32GroupNorm(nn.GroupNorm):
42 | def __init__(self, *args, **kwargs):
43 | super().__init__(*args, **kwargs)
44 |
45 | def forward(self, input):
46 | output = F.group_norm(
47 | input.float(),
48 | self.num_groups,
49 | self.weight.float() if self.weight is not None else None,
50 | self.bias.float() if self.bias is not None else None,
51 | self.eps,
52 | )
53 | return output.type_as(input)
54 |
55 |
56 | class AttnBlock(nn.Module):
57 | def __init__(self, in_channels: int):
58 | super().__init__()
59 | self.in_channels = in_channels
60 |
61 | self.head_dim = 64
62 | self.num_heads = in_channels // self.head_dim
63 | self.norm = FP32GroupNorm(
64 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
65 | )
66 | self.qkv = StandardizedC2d(
67 | in_channels, in_channels * 3, kernel_size=1, bias=False
68 | )
69 | self.proj_out = StandardizedC2d(
70 | in_channels, in_channels, kernel_size=1, bias=False
71 | )
72 | nn.init.normal_(self.proj_out.weight, std=0.2 / math.sqrt(in_channels))
73 |
74 | def attention(self, h_) -> Tensor:
75 | h_ = self.norm(h_)
76 | qkv = self.qkv(h_)
77 | q, k, v = qkv.chunk(3, dim=1)
78 | b, c, h, w = q.shape
79 | q = rearrange(
80 | q, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim
81 | )
82 | k = rearrange(
83 | k, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim
84 | )
85 | v = rearrange(
86 | v, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim
87 | )
88 | h_ = F.scaled_dot_product_attention(q, k, v)
89 | h_ = rearrange(h_, "b h (x y) d -> b (h d) x y", x=h, y=w)
90 | return h_
91 |
92 | def forward(self, x) -> Tensor:
93 | return x + self.proj_out(self.attention(x))
94 |
95 |
96 | class ResnetBlock(nn.Module):
97 |     def __init__(self, in_channels: int, out_channels: int = None):
98 | super().__init__()
99 | self.in_channels = in_channels
100 | out_channels = in_channels if out_channels is None else out_channels
101 | self.out_channels = out_channels
102 | self.norm1 = FP32GroupNorm(
103 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
104 | )
105 | self.conv1 = StandardizedC2d(
106 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
107 | )
108 | self.norm2 = FP32GroupNorm(
109 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
110 | )
111 | self.conv2 = StandardizedC2d(
112 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
113 | )
114 | if self.in_channels != self.out_channels:
115 | self.nin_shortcut = StandardizedC2d(
116 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
117 | )
118 |
119 | # init conv2 as very small number
120 | nn.init.normal_(self.conv2.weight, std=0.0001 / self.out_channels)
121 | nn.init.zeros_(self.conv2.bias)
122 | self.counter = 0
123 |
124 | def forward(self, x):
125 |
126 | # if self.counter < 5000:
127 | # self.counter += 1
128 | # h = 0
129 | # else:
130 | h = x
131 | h = self.norm1(h)
132 | h = swish(h)
133 | h = self.conv1(h)
134 | h = self.norm2(h)
135 | h = swish(h)
136 | h = self.conv2(h)
137 |
138 | if self.in_channels != self.out_channels:
139 | x = self.nin_shortcut(x)
140 | return x + h
141 |
142 |
143 | class Downsample(nn.Module):
144 | def __init__(self, in_channels: int):
145 | super().__init__()
146 | self.conv = StandardizedC2d(
147 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
148 | )
149 |
150 | def forward(self, x):
151 | pad = (0, 1, 0, 1)
152 | x = nn.functional.pad(x, pad, mode="constant", value=0)
153 | x = self.conv(x)
154 | return x
155 |
156 |
157 | class Upsample(nn.Module):
158 | def __init__(self, in_channels: int):
159 | super().__init__()
160 | self.conv = StandardizedC2d(
161 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
162 | )
163 |
164 | def forward(self, x):
165 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
166 | x = self.conv(x)
167 | return x
168 |
169 |
170 | class Encoder(nn.Module):
171 | def __init__(
172 | self,
173 | resolution: int,
174 | in_channels: int,
175 | ch: int,
176 | ch_mult: list[int],
177 | num_res_blocks: int,
178 | z_channels: int,
179 | use_attn: bool = True,
180 | use_wavelet: bool = False,
181 | ):
182 | super().__init__()
183 | self.ch = ch
184 | self.num_resolutions = len(ch_mult)
185 | self.num_res_blocks = num_res_blocks
186 | self.resolution = resolution
187 | self.in_channels = in_channels
188 | self.use_wavelet = use_wavelet
189 | if self.use_wavelet:
190 | self.wavelet_transform = wavelet_transform_multi_channel
191 | self.conv_in = StandardizedC2d(
192 | 4 * in_channels, self.ch * 2, kernel_size=3, stride=1, padding=1
193 | )
194 | ch_mult[0] *= 2
195 | else:
196 | self.wavelet_transform = nn.Identity()
197 | self.conv_in = StandardizedC2d(
198 | in_channels, self.ch, kernel_size=3, stride=1, padding=1
199 | )
200 |
201 | curr_res = resolution
202 | in_ch_mult = (2 if self.use_wavelet else 1,) + tuple(ch_mult)
203 | self.in_ch_mult = in_ch_mult
204 | self.down = nn.ModuleList()
205 | block_in = self.ch
206 | for i_level in range(self.num_resolutions):
207 | block = nn.ModuleList()
208 | attn = nn.ModuleList()
209 | block_in = ch * in_ch_mult[i_level]
210 | block_out = ch * ch_mult[i_level]
211 | for _ in range(self.num_res_blocks):
212 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
213 | block_in = block_out
214 | down = nn.Module()
215 | down.block = block
216 | down.attn = attn
217 | if i_level != self.num_resolutions - 1 and not (
218 | self.use_wavelet and i_level == 0
219 | ):
220 | down.downsample = Downsample(block_in)
221 | curr_res = curr_res // 2
222 | self.down.append(down)
223 | self.mid = nn.Module()
224 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
225 | self.mid.attn_1 = AttnBlock(block_in) if use_attn else nn.Identity()
226 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
227 | self.norm_out = FP32GroupNorm(
228 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True
229 | )
230 | self.conv_out = StandardizedC2d(
231 | block_in, z_channels, kernel_size=3, stride=1, padding=1
232 | )
233 | for module in self.modules():
234 | if isinstance(module, StandardizedC2d):
235 | nn.init.zeros_(module.bias)
236 | if isinstance(module, nn.GroupNorm):
237 | nn.init.zeros_(module.bias)
238 |
239 | def forward(self, x) -> Tensor:
240 | h = self.wavelet_transform(x)
241 | h = self.conv_in(h)
242 | for i_level in range(self.num_resolutions):
243 | for i_block in range(self.num_res_blocks):
244 | h = self.down[i_level].block[i_block](h)
245 | if len(self.down[i_level].attn) > 0:
246 | h = self.down[i_level].attn[i_block](h)
247 | if i_level != self.num_resolutions - 1 and not (
248 | self.use_wavelet and i_level == 0
249 | ):
250 | h = self.down[i_level].downsample(h)
251 | h = self.mid.block_1(h)
252 | h = self.mid.attn_1(h)
253 | h = self.mid.block_2(h)
254 | h = self.norm_out(h)
255 | h = swish(h)
256 | h = self.conv_out(h)
257 | return h
258 |
259 |
260 | class Decoder(nn.Module):
261 | def __init__(
262 | self,
263 | ch: int,
264 | out_ch: int,
265 | ch_mult: list[int],
266 | num_res_blocks: int,
267 | in_channels: int,
268 | resolution: int,
269 | z_channels: int,
270 | use_attn: bool = True,
271 | ):
272 | super().__init__()
273 | self.ch = ch
274 | self.num_resolutions = len(ch_mult)
275 | self.num_res_blocks = num_res_blocks
276 | self.resolution = resolution
277 | self.in_channels = in_channels
278 | self.ffactor = 2 ** (self.num_resolutions - 1)
279 | block_in = ch * ch_mult[self.num_resolutions - 1]
280 | curr_res = resolution // 2 ** (self.num_resolutions - 1)
281 | self.z_shape = (1, z_channels, curr_res, curr_res)
282 | self.conv_in = StandardizedC2d(
283 | z_channels, block_in, kernel_size=3, stride=1, padding=1
284 | )
285 | self.mid = nn.Module()
286 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
287 | self.mid.attn_1 = AttnBlock(block_in) if use_attn else nn.Identity()
288 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
289 | self.up = nn.ModuleList()
290 | for i_level in reversed(range(self.num_resolutions)):
291 | block = nn.ModuleList()
292 | attn = nn.ModuleList()
293 | block_out = ch * ch_mult[i_level]
294 | for _ in range(self.num_res_blocks + 1):
295 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
296 | block_in = block_out
297 | up = nn.Module()
298 | up.block = block
299 | up.attn = attn
300 | if i_level != 0:
301 | up.upsample = Upsample(block_in)
302 | curr_res = curr_res * 2
303 | self.up.insert(0, up)
304 | self.norm_out = FP32GroupNorm(
305 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True
306 | )
307 | self.conv_out = StandardizedC2d(
308 | block_in, out_ch, kernel_size=3, stride=1, padding=1
309 | )
310 |
311 | # initialize all bias to zero
312 | for module in self.modules():
313 | if isinstance(module, StandardizedC2d):
314 | nn.init.zeros_(module.bias)
315 | if isinstance(module, nn.GroupNorm):
316 | nn.init.zeros_(module.bias)
317 |
318 | def forward(self, z) -> Tensor:
319 | h = self.conv_in(z)
320 | h = self.mid.block_1(h)
321 | h = self.mid.attn_1(h)
322 | h = self.mid.block_2(h)
323 | for i_level in reversed(range(self.num_resolutions)):
324 | for i_block in range(self.num_res_blocks + 1):
325 | h = self.up[i_level].block[i_block](h)
326 | if len(self.up[i_level].attn) > 0:
327 | h = self.up[i_level].attn[i_block](h)
328 | if i_level != 0:
329 | h = self.up[i_level].upsample(h)
330 | h = self.norm_out(h)
331 | h = swish(h)
332 | h = self.conv_out(h)
333 | return h
334 |
335 |
336 | class DiagonalGaussian(nn.Module):
337 | def __init__(self, sample: bool = True, chunk_dim: int = 1):
338 | super().__init__()
339 | self.sample = sample
340 | self.chunk_dim = chunk_dim
341 |
342 | def forward(self, z) -> Tensor:
343 | mean = z
344 | if self.sample:
345 |             std = 0.00  # constant (non-learned) noise scale; 0.0 means no noise is added here
346 | return mean * (1 + std * torch.randn_like(mean))
347 | else:
348 | return mean
349 |
350 |
351 | class VAE(nn.Module):
352 | def __init__(
353 | self,
354 | resolution,
355 | in_channels,
356 | ch,
357 | out_ch,
358 | ch_mult,
359 | num_res_blocks,
360 | z_channels,
361 | use_attn,
362 | decoder_also_perform_hr,
363 | use_wavelet,
364 | ):
365 | super().__init__()
366 | self.encoder = Encoder(
367 | resolution=resolution,
368 | in_channels=in_channels,
369 | ch=ch,
370 | ch_mult=ch_mult,
371 | num_res_blocks=num_res_blocks,
372 | z_channels=z_channels,
373 | use_attn=use_attn,
374 | use_wavelet=use_wavelet,
375 | )
376 | self.decoder = Decoder(
377 | resolution=resolution,
378 | in_channels=in_channels,
379 | ch=ch,
380 | out_ch=out_ch,
381 | ch_mult=ch_mult + [4] if decoder_also_perform_hr else ch_mult,
382 | num_res_blocks=num_res_blocks,
383 | z_channels=z_channels,
384 | use_attn=use_attn,
385 | )
386 | self.reg = DiagonalGaussian()
387 |
388 | def forward(self, x) -> Tensor:
389 | z = self.encoder(x)
390 | z_s = self.reg(z)
391 | decz = self.decoder(z_s)
392 | return decz, z
393 |
394 |
395 | if __name__ == "__main__":
396 | from utils import prepare_filter
397 |
398 | prepare_filter("cuda")
399 | vae = VAE(
400 | resolution=256,
401 | in_channels=3,
402 | ch=64,
403 | out_ch=3,
404 | ch_mult=[1, 2, 4, 4, 4],
405 | num_res_blocks=2,
406 | z_channels=16 * 4,
407 | use_attn=False,
408 | decoder_also_perform_hr=False,
409 | use_wavelet=False,
410 | )
411 | vae.eval().to("cuda")
412 | x = torch.randn(1, 3, 256, 256).to("cuda")
413 | decz, z = vae(x)
414 | print(decz.shape, z.shape)
415 |
416 |
417 |
418 | # from unit_activation_reinitializer import adjust_weight_init
419 | # from torchvision import transforms
420 | # import torchvision
421 |
422 | # train_dataset = torchvision.datasets.CIFAR10(
423 | # root="./data",
424 | # train=True,
425 | # download=True,
426 | # transform=transforms.Compose(
427 | # [
428 | # transforms.Resize((256, 256)),
429 | # transforms.ToTensor(),
430 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
431 | # ]
432 | # ),
433 | # )
434 |
435 | # initial_std, layer_weight_std = adjust_weight_init(
436 | # vae,
437 | # dataset=train_dataset,
438 | # device="cuda:0",
439 | # batch_size=64,
440 | # num_workers=0,
441 | # tol=0.1,
442 | # max_iters=10,
443 | # exclude_layers=[FP32GroupNorm, nn.LayerNorm],
444 | # )
445 |
446 | # # save initial_std and layer_weight_std
447 | # torch.save(initial_std, "initial_std.pth")
448 | # torch.save(layer_weight_std, "layer_weight_std.pth")
449 |
450 | # print("\nAdjusted Weight Standard Deviations. Before -> After:")
451 | # for layer_name, std in layer_weight_std.items():
452 | # print(
453 | # f"Layer {layer_name}, Changed STD from \n {initial_std[layer_name]:.4f} -> STD {std:.4f}\n"
454 | # )
455 |
456 | # print(layer_weight_std)
457 |
--------------------------------------------------------------------------------
/contents/appg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/appg.png
--------------------------------------------------------------------------------
/contents/buildings.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/buildings.jpg
--------------------------------------------------------------------------------
/contents/chinatown.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/chinatown.jpg
--------------------------------------------------------------------------------
/contents/cosplayers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/cosplayers.jpg
--------------------------------------------------------------------------------
/contents/flowers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/flowers.jpg
--------------------------------------------------------------------------------
/contents/lavender.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/lavender.jpg
--------------------------------------------------------------------------------
/contents/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/logo.png
--------------------------------------------------------------------------------
/contents/magazines.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/magazines.jpg
--------------------------------------------------------------------------------
/contents/origin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/origin.png
--------------------------------------------------------------------------------
/contents/randomwoman.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/randomwoman.jpeg
--------------------------------------------------------------------------------
/contents/recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/recon.png
--------------------------------------------------------------------------------
/contents/ti2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/ti2.png
--------------------------------------------------------------------------------
/contents/ti2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/ti2_mask.png
--------------------------------------------------------------------------------
/launcher.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | loglr=-7
4 | width=64
5 | lr=$(python -c "import math; print(2**${loglr})")
6 | run_name="stage_4_msepool-cont-512-1.0-1.0-batch-gradnorm_make_deterministic"
7 | echo "Running ${run_name}"
8 |
9 | torchrun --nproc_per_node=8 vae_trainer.py \
10 | --learning_rate_vae ${lr} \
11 | --vae_ch ${width} \
12 | --run_name ${run_name} \
13 | --num_epochs 20 \
14 | --max_steps 100000 \
15 | --evaluate_every_n_steps 500 \
16 | --learning_rate_disc 1e-5 \
17 | --batch_size 12 \
18 | --do_clamp \
19 | --do_ganloss \
20 | --project_name "HrDecoderAE" \
21 | --decoder_also_perform_hr True
22 | #--load_path "/home/ubuntu/auravasa/ckpt/stage_3_msepool-cont-512-1.0-1.0-batch-gradnorm/vae_epoch_1_step_23501.pt"
23 | #--load_path "/home/ubuntu/auravasa/ckpt/stage2_msepool-cont-512-1.0-1.0-batch-gradnorm/vae_epoch_0_step_28501.pt"
24 | # --load_path "/home/ubuntu/auravasa/ckpt/exp_vae_ch_256_lr_0.0078125_weighted_percep+f8areapool_l2_0.0/vae_epoch_1_step_27001.pt"
--------------------------------------------------------------------------------
/scripts/launch_hdr.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | loglr=-7
4 | width=128
5 | lr=$(python -c "import math; print(2**${loglr})")
6 | run_name="stage_4_cont_with_lecam_hinge"
7 | echo "Running ${run_name}"
8 |
9 | torchrun --nproc_per_node=8 vae_trainer.py \
10 | --learning_rate_vae ${lr} \
11 | --vae_ch ${width} \
12 | --run_name ${run_name} \
13 | --num_epochs 20 \
14 | --max_steps 100000 \
15 | --evaluate_every_n_steps 1000 \
16 | --learning_rate_disc 3e-5 \
17 | --batch_size 4 \
18 | --do_clamp \
19 | --do_ganloss \
20 | --project_name "HrDecoderAE" \
21 | --decoder_also_perform_hr True \
22 | --do_compile False \
23 | --crop_invariance True \
24 | --flip_invariance False \
25 | --use_wavelet True \
26 | --vae_z_channels 64 \
27 | --vae_ch_mult 1,2,4,4,4 \
28 | --use_lecam True \
29 | --disc_type "hinge" \
30 | --load_path "/home/ubuntu/auravasa/ckpt/stage_3_hdr_z64_f16_add_flip_lr_disc_1e-4/vae_epoch_1_step_98001.pt"
--------------------------------------------------------------------------------
/sweep.sh:
--------------------------------------------------------------------------------
1 | ## Sweep 1. is attention useful?
2 |
3 | loglrs=(-8 -7 -6 -5 -4 -3 -2)
4 | MODEL_WIDTHS=(32 64 128)
5 |
6 | for loglr in "${loglrs[@]}"; do
7 | for attn in "True" "False"; do
8 | for width in "${MODEL_WIDTHS[@]}"; do
9 | lr=$(python -c "import math; print(2**${loglr})")
10 | run_name="exp_vae_ch_${width}_lr_${lr}_attn_${attn}"
11 |
12 | echo "Running ${run_name}"
13 |
14 | torchrun --nproc_per_node=8 vae_trainer.py \
15 | --learning_rate_vae ${lr} \
16 | --vae_ch ${width} \
17 | --run_name ${run_name} \
18 | --num_epochs 20 \
19 | --max_steps 2000 \
20 | --evaluate_every_n_steps 250 \
21 | --batch_size 32 \
22 | --do_clamp \
23 | --do_attn ${attn} \
24 | --project_name "vae_sweep_attn_lr_width"
25 |
26 | done
27 | done
28 | done
29 |
30 | ## Sweep 2. Can we initialize better?
31 |
32 | loglrs=(-8 -7 -6 -5 -4 -3 -2)
33 | MODEL_WIDTHS=(64)
34 |
35 | for loglr in "${loglrs[@]}"; do
36 | for attn in "True" "False"; do
37 | for width in "${MODEL_WIDTHS[@]}"; do
38 | lr=$(python -c "import math; print(2**${loglr})")
39 | run_name="exp_vae_ch_${width}_lr_${lr}_attn_${attn}"
40 |
41 | echo "Running ${run_name}"
42 |
43 | torchrun --nproc_per_node=8 vae_trainer.py \
44 | --learning_rate_vae ${lr} \
45 | --vae_ch ${width} \
46 | --run_name ${run_name} \
47 | --num_epochs 20 \
48 | --max_steps 2000 \
49 | --evaluate_every_n_steps 250 \
50 | --batch_size 32 \
51 | --do_clamp \
52 | --do_attn ${attn} \
53 | --project_name "vae_sweep_attn_lr_width"
54 |
55 | done
56 | done
57 | done
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/tae.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch import Tensor, nn
7 |
8 |
9 | def swish(x: Tensor) -> Tensor:
10 | return x * torch.sigmoid(x)
11 |
12 |
13 | class AttnBlock(nn.Module):
14 | def __init__(self, in_channels: int):
15 | super().__init__()
16 | self.in_channels = in_channels
17 | self.num_heads = 8
18 | self.head_dim = in_channels // self.num_heads
19 | self.norm = nn.GroupNorm(
20 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
21 | )
22 | self.qkv = nn.Conv3d(in_channels, in_channels * 3, kernel_size=1, bias=False)
23 | self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1, bias=False)
24 | nn.init.normal_(self.proj_out.weight, std=0.2 / math.sqrt(in_channels))
25 |
26 | def attention(self, h_: Tensor) -> Tensor:
27 | h_ = self.norm(h_)
28 | qkv = self.qkv(h_)
29 | q, k, v = qkv.chunk(3, dim=1)
30 | b, c, t, h, w = q.shape
31 | q = rearrange(
32 | q,
33 | "b (head d) t h w -> b head (t h w) d",
34 | head=self.num_heads,
35 | d=self.head_dim,
36 | )
37 | k = rearrange(
38 | k,
39 | "b (head d) t h w -> b head (t h w) d",
40 | head=self.num_heads,
41 | d=self.head_dim,
42 | )
43 | v = rearrange(
44 | v,
45 | "b (head d) t h w -> b head (t h w) d",
46 | head=self.num_heads,
47 | d=self.head_dim,
48 | )
49 | h_ = F.scaled_dot_product_attention(q, k, v)
50 | h_ = rearrange(h_, "b head (t h w) d -> b (head d) t h w", t=t, h=h, w=w)
51 | return h_
52 |
53 | def forward(self, x: Tensor) -> Tensor:
54 | return x + self.proj_out(self.attention(x))
55 |
56 |
57 | class ResnetBlock(nn.Module):
58 | def __init__(self, in_channels: int, out_channels: int = None):
59 | super().__init__()
60 | self.in_channels = in_channels
61 | out_channels = in_channels if out_channels is None else out_channels
62 | self.out_channels = out_channels
63 | self.norm1 = nn.GroupNorm(
64 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
65 | )
66 | self.conv1 = nn.Conv3d(
67 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
68 | )
69 | self.norm2 = nn.GroupNorm(
70 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
71 | )
72 | self.conv2 = nn.Conv3d(
73 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
74 | )
75 | if self.in_channels != self.out_channels:
76 | self.nin_shortcut = nn.Conv3d(
77 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
78 | )
79 |
80 | def forward(self, x):
81 | h = x
82 | h = self.norm1(h)
83 | h = swish(h)
84 | h = self.conv1(h)
85 | h = self.norm2(h)
86 | h = swish(h)
87 | h = self.conv2(h)
88 | if self.in_channels != self.out_channels:
89 | x = self.nin_shortcut(x)
90 | return x + h
91 |
92 |
93 | class Downsample(nn.Module):
94 | def __init__(self, in_channels: int):
95 | super().__init__()
96 | self.conv = nn.Conv3d(
97 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
98 | )
99 |
100 | def forward(self, x: Tensor):
101 |         pad = (0, 1, 0, 1, 0, 1)  # pad W, H, and T by one on the right / bottom / back
102 | x = nn.functional.pad(x, pad, mode="constant", value=0)
103 | x = self.conv(x)
104 | return x
105 |
106 |
107 | class Upsample(nn.Module):
108 | def __init__(self, in_channels: int):
109 | super().__init__()
110 | self.conv = nn.Conv3d(
111 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
112 | )
113 |
114 | def forward(self, x: Tensor):
115 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
116 | x = self.conv(x)
117 | return x
118 |
119 |
120 | class Encoder(nn.Module):
121 | def __init__(
122 | self,
123 | resolution: int,
124 | in_channels: int,
125 | ch: int,
126 | ch_mult: list[int],
127 | num_res_blocks: int,
128 | z_channels: int,
129 | ):
130 | super().__init__()
131 | self.ch = ch
132 | self.num_resolutions = len(ch_mult)
133 | self.num_res_blocks = num_res_blocks
134 | self.resolution = resolution
135 | self.in_channels = in_channels
136 | self.conv_in = nn.Conv3d(
137 | in_channels, self.ch, kernel_size=3, stride=1, padding=1
138 | )
139 | curr_res = resolution
140 | in_ch_mult = (1,) + tuple(ch_mult)
141 | self.down = nn.ModuleList()
142 | block_in = self.ch
143 | for i_level in range(self.num_resolutions):
144 | block = nn.ModuleList()
145 | attn = nn.ModuleList()
146 | block_in = ch * in_ch_mult[i_level]
147 | block_out = ch * ch_mult[i_level]
148 | for _ in range(self.num_res_blocks):
149 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
150 | block_in = block_out
151 | down = nn.Module()
152 | down.block = block
153 | down.attn = attn
154 | if i_level != self.num_resolutions - 1:
155 | down.downsample = Downsample(block_in)
156 | curr_res = curr_res // 2
157 | self.down.append(down)
158 | self.mid = nn.Module()
159 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
160 | self.mid.attn_1 = AttnBlock(block_in)
161 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
162 | self.norm_out = nn.GroupNorm(
163 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True
164 | )
165 | self.conv_out = nn.Conv3d(
166 | block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | def forward(self, x: Tensor) -> Tensor:
170 | h = self.conv_in(x)
171 | for i_level in range(self.num_resolutions):
172 | for i_block in range(self.num_res_blocks):
173 | h = self.down[i_level].block[i_block](h)
174 | if len(self.down[i_level].attn) > 0:
175 | h = self.down[i_level].attn[i_block](h)
176 | if i_level != self.num_resolutions - 1:
177 | h = self.down[i_level].downsample(h)
178 | h = self.mid.block_1(h)
179 | h = self.mid.attn_1(h)
180 | h = self.mid.block_2(h)
181 | h = self.norm_out(h)
182 | h = swish(h)
183 | h = self.conv_out(h)
184 | return h
185 |
186 |
187 | class Decoder(nn.Module):
188 | def __init__(
189 | self,
190 | ch: int,
191 | out_ch: int,
192 | ch_mult: list[int],
193 | num_res_blocks: int,
194 | in_channels: int,
195 | resolution: int,
196 | z_channels: int,
197 | ):
198 | super().__init__()
199 | self.ch = ch
200 | self.num_resolutions = len(ch_mult)
201 | self.num_res_blocks = num_res_blocks
202 | self.resolution = resolution
203 | self.in_channels = in_channels
204 | self.ffactor = 2 ** (self.num_resolutions - 1)
205 | block_in = ch * ch_mult[self.num_resolutions - 1]
206 | curr_res = resolution // (2 ** (self.num_resolutions - 1))
207 | self.z_shape = (1, z_channels, curr_res, curr_res, curr_res)
208 | self.conv_in = nn.Conv3d(
209 | z_channels, block_in, kernel_size=3, stride=1, padding=1
210 | )
211 | self.mid = nn.Module()
212 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213 | self.mid.attn_1 = AttnBlock(block_in)
214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215 | self.up = nn.ModuleList()
216 | for i_level in reversed(range(self.num_resolutions)):
217 | block = nn.ModuleList()
218 | attn = nn.ModuleList()
219 | block_out = ch * ch_mult[i_level]
220 | for _ in range(self.num_res_blocks + 1):
221 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
222 | block_in = block_out
223 | up = nn.Module()
224 | up.block = block
225 | up.attn = attn
226 | if i_level != 0:
227 | up.upsample = Upsample(block_in)
228 | curr_res = curr_res * 2
229 | self.up.insert(0, up)
230 | self.norm_out = nn.GroupNorm(
231 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True
232 | )
233 | self.conv_out = nn.Conv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
234 |
235 | def forward(self, z: Tensor) -> Tensor:
236 | h = self.conv_in(z)
237 | h = self.mid.block_1(h)
238 | h = self.mid.attn_1(h)
239 | h = self.mid.block_2(h)
240 | for i_level in reversed(range(self.num_resolutions)):
241 | for i_block in range(self.num_res_blocks + 1):
242 | h = self.up[i_level].block[i_block](h)
243 | if len(self.up[i_level].attn) > 0:
244 | h = self.up[i_level].attn[i_block](h)
245 | if i_level != 0:
246 | h = self.up[i_level].upsample(h)
247 | h = self.norm_out(h)
248 | h = swish(h)
249 | h = self.conv_out(h)
250 | return h
251 |
252 |
253 | class DiagonalGaussian(nn.Module):
254 | def __init__(self, sample: bool = True, chunk_dim: int = 1):
255 | super().__init__()
256 | self.sample = sample
257 | self.chunk_dim = chunk_dim
258 |
259 | def forward(self, z: Tensor) -> Tensor:
260 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
261 | if self.sample:
262 | logvar = logvar.clamp(min=-3)
263 | std = torch.exp(0.5 * logvar)
264 | return mean + std * torch.randn_like(mean)
265 | else:
266 | return mean
267 |
268 |
269 | class TVAE(nn.Module):
270 | def __init__(
271 | self, resolution, in_channels, ch, out_ch, ch_mult, num_res_blocks, z_channels
272 | ):
273 | super().__init__()
274 | self.encoder = Encoder(
275 | resolution=resolution,
276 | in_channels=in_channels,
277 | ch=ch,
278 | ch_mult=ch_mult,
279 | num_res_blocks=num_res_blocks,
280 | z_channels=z_channels,
281 | )
282 | self.decoder = Decoder(
283 | resolution=resolution,
284 | in_channels=in_channels,
285 | ch=ch,
286 | out_ch=out_ch,
287 | ch_mult=ch_mult,
288 | num_res_blocks=num_res_blocks,
289 | z_channels=z_channels,
290 | )
291 | self.reg = DiagonalGaussian()
292 |
293 | def forward(self, x: Tensor) -> Tensor:
294 | z = self.encoder(x)
295 | z_s = self.reg(z)
296 | decz = self.decoder(z_s)
297 | return decz, z
298 |
299 |
300 | if __name__ == "__main__":
301 | vae = TVAE(
302 | resolution=256,
303 | in_channels=3,
304 | ch=64,
305 | out_ch=3,
306 | ch_mult=[1, 2, 4, 4],
307 | num_res_blocks=2,
308 | z_channels=16,
309 | )
310 | with torch.no_grad():
311 | vae.eval().to("cpu")
312 | x = torch.randn(1, 3, 48, 256, 256).to("cpu") # [B, C, T, H, W]
313 | decz, z = vae(x)
314 | print(decz.shape, z.shape)
315 |
--------------------------------------------------------------------------------
/tester.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.optim as optim
5 | import torchvision.models as models
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 |
9 | from vae_trainer import (
10 | VAE,
11 | PatchDiscriminator,
12 | create_dataloader,
13 | gan_loss,
14 | perceptual_loss,
15 | )
16 |
17 |
18 | # Sample Image for Testing
19 | def create_sample_image(size=(128, 128), color=(255, 0, 0)):
20 | img = Image.new("RGB", size, color)
21 | transform = transforms.ToTensor()
22 | img_tensor = transform(img).unsqueeze(0)
23 | return img_tensor
24 |
25 |
26 | # Test the VAE architecture
27 | def test_vae():
28 | print("Testing VAE Architecture...")
29 | # Instantiate the model
30 | vae = VAE(width_mult=1.0)
31 |
32 | # Create a sample image
33 | img = create_sample_image()
34 |
35 | # Forward pass through the model
36 | reconstructed, latent = vae(img)
37 |
38 | # Check output shapes
39 | print(f"Input shape: {img.shape}")
40 | print(f"Latent shape: {latent.shape}")
41 | print(f"Reconstructed shape: {reconstructed.shape}")
42 |
43 |
44 | # Test the Patch Discriminator
45 | def test_discriminator():
46 | print("Testing Patch Discriminator...")
47 | # Instantiate the discriminator
48 | discriminator = PatchDiscriminator()
49 |
50 | # Create a sample image
51 | img = create_sample_image()
52 |
53 | # Forward pass through the discriminator
54 | output = discriminator(img)
55 |
56 | # Check output shape
57 | print(f"Input shape: {img.shape}")
58 | print(f"Discriminator output shape: {output.shape}")
59 |
60 |
61 | # Test GAN loss function
62 | def test_gan_loss():
63 | print("Testing GAN Loss...")
64 | # Create sample real and fake predictions
65 | real_preds = torch.randn(1, 1, 8, 8).abs()
66 | fake_preds = torch.randn(1, 1, 8, 8).abs()
67 |
68 | # Calculate loss
69 | loss = gan_loss(real_preds, fake_preds)
70 | print(f"GAN Loss: {loss.item()}")
71 |
72 |
73 | # Test perceptual loss function
74 | def test_perceptual_loss():
75 | print("Testing Perceptual Loss...")
76 | # Instantiate VGG model for perceptual loss
77 | vgg_model = models.vgg16(pretrained=True).features[:9]
78 |
79 | # Create two sample images (real and reconstructed)
80 | real_img = create_sample_image()
81 | reconstructed_img = create_sample_image(color=(0, 255, 0))
82 |
83 | # Calculate perceptual loss
84 | loss = perceptual_loss(reconstructed_img, real_img, vgg_model)
85 | print(f"Perceptual Loss: {loss.item()}")
86 |
87 |
88 | # Test WebDataset DataLoader
89 | def test_webdataset_dataloader():
90 | print("Testing WebDataset DataLoader...")
91 |
92 | # Dummy URL for WebDataset (replace with actual path or URL in practice)
93 | dataset_url = "path/to/your/webdataset/shards"
94 |
95 | # Create a dummy dataset with ToTensor transforms
96 | transform = transforms.Compose(
97 | [
98 | transforms.Resize((128, 128)),
99 | transforms.ToTensor(),
100 | ]
101 | )
102 |
103 | # Set up dummy dataloader
104 | dataloader = create_dataloader(
105 | dataset_url, batch_size=2, num_workers=4, world_size=1, rank=0
106 | )
107 |
108 | # Check the output from the dataloader
109 | for imgs, labels in dataloader:
110 | print(f"Batch of images shape: {imgs.shape}")
111 | print(f"Batch of labels shape: {labels.shape}")
112 | break
113 |
114 |
115 | def test_train_loop():
116 | print("Testing Train Loop...")
117 |
118 | vae = VAE(width_mult=1.0)
119 | discriminator = PatchDiscriminator()
120 |
121 | img = create_sample_image()
122 |
123 | optimizer_G = optim.Adam(vae.parameters(), lr=2e-4)
124 | optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4)
125 |
126 | vgg_model = models.vgg16(pretrained=True).features[:9]
127 | vgg_model.eval()
128 |
129 | reconstructed, latent = vae(img)
130 |
131 | print(f"Input image shape: {img.shape} (Expected: [1, 3, 128, 128])")
132 | print(f"Latent shape: {latent.shape} (Expected: [1, 256, 8, 8])")
133 | print(
134 | f"Reconstructed image shape: {reconstructed.shape} (Expected: [1, 3, 128, 128])"
135 | )
136 |
137 | real_preds = discriminator(img)
138 | fake_preds = discriminator(reconstructed.detach())
139 |
140 | print(f"Real predictions shape: {real_preds.shape} (Expected: [1, 1, 8, 8])")
141 | print(f"Fake predictions shape: {fake_preds.shape} (Expected: [1, 1, 8, 8])")
142 |
143 | gan_loss_value = gan_loss(real_preds, fake_preds)
144 | rec_loss_value = perceptual_loss(reconstructed, img, vgg_model)
145 |
146 | lambda_val = rec_loss_value / (gan_loss_value + 1e-8)
147 | g_loss = rec_loss_value + lambda_val * gan_loss_value
148 | g_loss.backward(retain_graph=True)
149 |
150 | d_loss = gan_loss(real_preds, fake_preds)
151 | d_loss.backward()
152 |
153 | optimizer_D.step()
154 | optimizer_G.step()
155 |
156 | optimizer_G.zero_grad()
157 | optimizer_D.zero_grad()
158 |
159 | print(f"Discriminator Loss: {d_loss.item()} (Expected: > 0)")
160 | print(f"Generator Loss: {g_loss.item()} (Expected: > 0)")
161 | print(
162 | f"Lambda value (weight): {lambda_val.item()} (Expected: ~1.0 depending on losses)"
163 | )
164 |
165 |
166 | # Run the dummy train loop test
167 | if __name__ == "__main__":
168 |
169 | # download lpips from here
170 | # https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth
171 | os.system(
172 | "wget https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth"
173 | )
174 | test_vae()
175 | test_discriminator()
176 | test_gan_loss()
177 | test_perceptual_loss()
178 | # test_webdataset_dataloader()
179 |
180 | # Dummy train loop for input-output
181 | test_train_loop()
182 |
--------------------------------------------------------------------------------
/tester_upload.sh:
--------------------------------------------------------------------------------
1 | export HF_HUB_ENABLE_HF_TRANSFER=1
2 | huggingface-cli upload fal/AuraEquiVAE ./vae_epoch_3_step_49501_bf16.pt
--------------------------------------------------------------------------------
/unit_activation_reinitializer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import datasets, transforms
5 | import numpy as np
6 | import math
7 |
8 |
9 | def compute_activation_std(
10 | model, dataset, device="cpu", batch_size=32, num_workers=0, layer_names=None
11 | ):
12 | activations = {}
13 | handles = []
14 |
15 | def save_activation(name):
16 | def hook(module, input, output):
17 | if isinstance(output, tuple):
18 | output = output[0]
19 | activations[name].append(output.detach())
20 |
21 | return hook
22 |
23 | for name, module in model.named_modules():
24 | if name in layer_names:
25 | activations[name] = []
26 | handle = module.register_forward_hook(save_activation(name))
27 | handles.append(handle)
28 |
29 | loader = torch.utils.data.DataLoader(
30 | dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
31 | )
32 | model.to(device)
33 | model.eval()
34 |
35 | with torch.no_grad():
36 | for batch in loader:
37 | if isinstance(batch, (list, tuple)):
38 | inputs = batch[0].to(device)
39 | else:
40 | inputs = batch.to(device)
41 | _ = model(inputs)
42 | break
43 |
44 | layer_activation_std = {}
45 | for name in layer_names:
46 | try:
47 | act = torch.cat(activations[name], dim=0)
48 |         except Exception:
49 | print(activations[name])
50 | break
51 | act_std = act.std().item()
52 | layer_activation_std[name] = act_std
53 |
54 | for handle in handles:
55 | handle.remove()
56 |
57 | return layer_activation_std
58 |
59 |
60 | def adjust_weight_init(
61 | model,
62 | dataset,
63 | device="cpu",
64 | batch_size=32,
65 | num_workers=0,
66 | tol=0.2,
67 | max_iters=10,
68 | exclude_layers=None,
69 | ):
70 | if exclude_layers is None:
71 | exclude_layers = []
72 |
73 | layers_to_adjust = []
74 | for name, module in model.named_modules():
75 | if isinstance(module, (nn.Linear, nn.Conv2d)) and not isinstance(
76 | module, tuple(exclude_layers)
77 | ):
78 | layers_to_adjust.append((name, module))
79 |
80 | print(f"Layers to adjust: {layers_to_adjust}")
81 | initial_std = {}
82 | layer_weight_std = {}
83 |
84 | for name, module in layers_to_adjust:
85 | print(f"Adjusting layer: {name}")
86 | initial_std[name] = module.weight.std().item()
87 | fan_in = np.prod(module.weight.shape[1:])
88 | weight_std = np.sqrt(1 / fan_in) # use muP for initialization.
89 |
90 | for i in range(max_iters):
91 | nn.init.normal_(module.weight, std=weight_std)
92 |
93 | activation_std = compute_activation_std(
94 | model, dataset, device, batch_size, num_workers, layer_names=[name]
95 | )[name]
96 | print(f"Iteration {i+1}: Activation std = {activation_std:.4f}")
97 |
98 | if abs(activation_std - 1.0) < tol:
99 | print(
100 | f"Layer {name} achieved near unit activation of {activation_std:.4f} with weight std = {weight_std:.4f}"
101 | )
102 | layer_weight_std[name] = weight_std / activation_std
103 | break
104 | else:
105 | weight_std = weight_std / activation_std
106 | else:
107 | print(f"Layer {name} did not converge within {max_iters} iterations.")
108 | layer_weight_std[name] = weight_std
109 |
110 | return initial_std, layer_weight_std
111 |
112 |
113 | #### HOW TO USE
114 | # 1. define dataset
115 | # 2. define model
116 | # 3. launch.
117 |
118 | transform = transforms.Compose(
119 | [
120 | transforms.ToTensor(),
121 | ]
122 | )
123 |
124 | train_dataset = datasets.MNIST(
125 | root="mnist_data", train=True, transform=transform, download=True
126 | )
127 |
128 |
129 | class CustomActivation(nn.Module):
130 | def forward(self, x):
131 | return x * torch.sigmoid(x)
132 |
133 |
134 | class ResBlock(nn.Module):
135 | def __init__(self, in_channels, out_channels, reduction_ratio=2):
136 | super(ResBlock, self).__init__()
137 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
138 | self.bn1 = nn.BatchNorm2d(out_channels)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
141 | self.bn2 = nn.BatchNorm2d(out_channels)
142 |
143 | def forward(self, x):
144 | identity = x
145 | out = self.conv1(x)
146 | out = self.bn1(out)
147 | out = self.relu(out)
148 | out = self.conv2(out)
149 | out = self.bn2(out)
150 | out = out.mean(dim=[-1, -2])
151 | return out
152 |
153 |
154 | class MLPModel(nn.Module):
155 | def __init__(self):
156 | super(MLPModel, self).__init__()
157 | self.block1 = ResBlock(1, 256)
158 |
159 | self.fc1 = nn.Linear(256, 256)
160 | self.act1 = CustomActivation()
161 | self.ln1 = nn.LayerNorm(256)
162 |
163 | self.fc2 = nn.Linear(256, 128)
164 | self.act2 = nn.ReLU()
165 | self.ln2 = nn.LayerNorm(128)
166 |
167 | self.fc3 = nn.Linear(128, 64)
168 | self.act3 = nn.Tanh()
169 |
170 | self.fc_residual = nn.Linear(256, 64)
171 |
172 | self.fc4 = nn.Linear(64, 10)
173 |
174 | def forward(self, x):
175 | out1 = self.block1(x)
176 | out1 = out1.view(out1.shape[0], -1)
177 | out1 = self.act1(out1)
178 | out1 = self.fc1(out1)
179 | out1 = self.ln1(out1)
180 |
181 | out2 = self.act2(self.fc2(out1))
182 | out2 = self.ln2(out2)
183 |
184 | out3 = self.act3(self.fc3(out2))
185 |
186 | res = self.fc_residual(out1)
187 | out3 += res
188 |
189 | logits = self.fc4(out3)
190 | return logits
191 |
192 |
193 | model = MLPModel()
194 |
195 |
196 | exclude_layers = [nn.LayerNorm]
197 |
198 | initial_std, layer_weight_std = adjust_weight_init(
199 | model,
200 | dataset=train_dataset,
201 | device="cuda:0",
202 | batch_size=64,
203 | num_workers=0,
204 | tol=0.1,
205 | max_iters=10,
206 | exclude_layers=exclude_layers,
207 | )
208 |
209 | print("\nAdjusted Weight Standard Deviations. Before -> After:")
210 | for layer_name, std in layer_weight_std.items():
211 | print(
212 | f"Layer {layer_name}, Changed STD from \n {initial_std[layer_name]:.4f} -> STD {std:.4f}\n"
213 | )
214 |
215 | print(layer_weight_std)
216 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import namedtuple
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 |
7 |
8 | class LPIPS(nn.Module):
9 | # Learned perceptual metric
10 | def __init__(self, use_dropout=True):
11 | super().__init__()
12 | self.scaling_layer = ScalingLayer()
13 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
14 | self.net = vgg16(pretrained=True, requires_grad=False)
15 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
16 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
17 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
18 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
19 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
20 | self.load_from_pretrained()
21 | for param in self.parameters():
22 | param.requires_grad = False
23 |
24 | def load_from_pretrained(self, name="vgg_lpips"):
25 | try:
26 | data = torch.load("vgg.pth", map_location=torch.device("cpu"))
27 |         except Exception:
28 | print("Failed to load vgg.pth, downloading...")
29 | os.system(
30 | "wget https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth"
31 | )
32 | data = torch.load("vgg.pth", map_location=torch.device("cpu"))
33 |
34 | self.load_state_dict(
35 | data,
36 | strict=False,
37 | )
38 |
39 | def forward(self, input, target):
40 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
41 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
42 | feats0, feats1, diffs = {}, {}, {}
43 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
44 | for kk in range(len(self.chns)):
45 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
46 | outs1[kk]
47 | )
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [
51 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
52 | for kk in range(len(self.chns))
53 | ]
54 | val = res[0]
55 | for l in range(1, len(self.chns)):
56 | val += res[l]
57 | return val
58 |
59 |
60 | class ScalingLayer(nn.Module):
61 | def __init__(self):
62 | super(ScalingLayer, self).__init__()
63 | self.register_buffer(
64 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
65 | )
66 | self.register_buffer(
67 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
68 | )
69 |
70 | def forward(self, inp):
71 | return (inp - self.shift) / self.scale
72 |
73 |
74 | class NetLinLayer(nn.Module):
75 | """A single linear layer which does a 1x1 conv"""
76 |
77 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
78 | super(NetLinLayer, self).__init__()
79 | layers = (
80 | [
81 | nn.Dropout(),
82 | ]
83 | if (use_dropout)
84 | else []
85 | )
86 | layers += [
87 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
88 | ]
89 | self.model = nn.Sequential(*layers)
90 |
91 |
92 | class vgg16(torch.nn.Module):
93 | def __init__(self, requires_grad=False, pretrained=True):
94 | super(vgg16, self).__init__()
95 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
96 | self.slice1 = torch.nn.Sequential()
97 | self.slice2 = torch.nn.Sequential()
98 | self.slice3 = torch.nn.Sequential()
99 | self.slice4 = torch.nn.Sequential()
100 | self.slice5 = torch.nn.Sequential()
101 | self.N_slices = 5
102 | for x in range(4):
103 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
104 | for x in range(4, 9):
105 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
106 | for x in range(9, 16):
107 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
108 | for x in range(16, 23):
109 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
110 | for x in range(23, 30):
111 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
112 | if not requires_grad:
113 | for param in self.parameters():
114 | param.requires_grad = False
115 |
116 | def forward(self, X):
117 | h = self.slice1(X)
118 | h_relu1_2 = h
119 | h = self.slice2(h)
120 | h_relu2_2 = h
121 | h = self.slice3(h)
122 | h_relu3_3 = h
123 | h = self.slice4(h)
124 | h_relu4_3 = h
125 | h = self.slice5(h)
126 | h_relu5_3 = h
127 | vgg_outputs = namedtuple(
128 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
129 | )
130 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
131 | return out
132 |
133 |
134 | def normalize_tensor(x, eps=1e-10):
135 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
136 | return x / (norm_factor + eps)
137 |
138 |
139 | def spatial_average(x, keepdim=True):
140 | return x.mean([2, 3], keepdim=keepdim)
141 |
142 |
143 | class PatchDiscriminator(nn.Module):
144 | def __init__(self):
145 | super(PatchDiscriminator, self).__init__()
146 | self.scaling_layer = ScalingLayer()
147 |
148 | _vgg = models.vgg16(pretrained=True)
149 |
150 | self.slice1 = nn.Sequential(_vgg.features[:4])
151 | self.slice2 = nn.Sequential(_vgg.features[4:9])
152 | self.slice3 = nn.Sequential(_vgg.features[9:16])
153 | self.slice4 = nn.Sequential(_vgg.features[16:23])
154 | self.slice5 = nn.Sequential(_vgg.features[23:30])
155 |
156 | self.binary_classifier1 = nn.Sequential(
157 | nn.Conv2d(64, 32, kernel_size=4, stride=4, padding=0, bias=True),
158 | nn.ReLU(),
159 | nn.Conv2d(32, 1, kernel_size=4, stride=4, padding=0, bias=True),
160 | )
161 | nn.init.zeros_(self.binary_classifier1[-1].weight)
162 |
163 | self.binary_classifier2 = nn.Sequential(
164 | nn.Conv2d(128, 64, kernel_size=4, stride=4, padding=0, bias=True),
165 | nn.ReLU(),
166 | nn.Conv2d(64, 1, kernel_size=2, stride=2, padding=0, bias=True),
167 | )
168 | nn.init.zeros_(self.binary_classifier2[-1].weight)
169 |
170 | self.binary_classifier3 = nn.Sequential(
171 | nn.Conv2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True),
172 | nn.ReLU(),
173 | nn.Conv2d(128, 1, kernel_size=2, stride=2, padding=0, bias=True),
174 | )
175 | nn.init.zeros_(self.binary_classifier3[-1].weight)
176 |
177 | self.binary_classifier4 = nn.Sequential(
178 | nn.Conv2d(512, 1, kernel_size=2, stride=2, padding=0, bias=True),
179 | )
180 | nn.init.zeros_(self.binary_classifier4[-1].weight)
181 |
182 | self.binary_classifier5 = nn.Sequential(
183 | nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=True),
184 | )
185 | nn.init.zeros_(self.binary_classifier5[-1].weight)
186 |
187 | def forward(self, x):
188 | x = self.scaling_layer(x)
189 | features1 = self.slice1(x)
190 | features2 = self.slice2(features1)
191 | features3 = self.slice3(features2)
192 | features4 = self.slice4(features3)
193 | features5 = self.slice5(features4)
194 |
195 | # torch.Size([1, 64, 256, 256]) torch.Size([1, 128, 128, 128]) torch.Size([1, 256, 64, 64]) torch.Size([1, 512, 32, 32]) torch.Size([1, 512, 16, 16])
196 |
197 | bc1 = self.binary_classifier1(features1).flatten(1)
198 | bc2 = self.binary_classifier2(features2).flatten(1)
199 | bc3 = self.binary_classifier3(features3).flatten(1)
200 | bc4 = self.binary_classifier4(features4).flatten(1)
201 | bc5 = self.binary_classifier5(features5).flatten(1)
202 |
203 | return bc1 + bc2 + bc3 + bc4 + bc5
204 |
205 |
206 | dec_lo, dec_hi = (
207 | torch.Tensor([-0.1768, 0.3536, 1.0607, 0.3536, -0.1768, 0.0000]),
208 | torch.Tensor([0.0000, -0.0000, 0.3536, -0.7071, 0.3536, -0.0000]),
209 | )
210 |
211 | filters = torch.stack(
212 | [
213 | dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
214 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
215 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
216 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1),
217 | ],
218 | dim=0,
219 | )
220 |
221 | filters_expanded = filters.unsqueeze(1)
222 |
223 |
224 | def prepare_filter(device):
225 | global filters_expanded
226 | filters_expanded = filters_expanded.to(device)
227 |
228 |
229 | def wavelet_transform_multi_channel(x, levels=4):
230 | B, C, H, W = x.shape
231 | padded = torch.nn.functional.pad(x, (2, 2, 2, 2))
232 |
233 | # use predefined filters
234 | global filters_expanded
235 |
236 | ress = []
237 | for ch in range(C):
238 | res = torch.nn.functional.conv2d(
239 | padded[:, ch : ch + 1], filters_expanded, stride=2
240 | )
241 | ress.append(res)
242 |
243 | res = torch.cat(ress, dim=1)
244 | H_out, W_out = res.shape[2], res.shape[3]
245 | res = res.view(B, C, 4, H_out, W_out)
246 | res = res.view(B, 4 * C, H_out, W_out)
247 | return res
248 |
249 |
250 | def test_patch_discriminator():
251 | vggDiscriminator = PatchDiscriminator().cuda()
252 | x = vggDiscriminator(torch.randn(1, 3, 256, 256).cuda())
253 | print(x.shape)
254 |
255 |
256 | if __name__ == "__main__":
257 |     # smoke test: run a random image through the discriminator
258 |     test_patch_discriminator()
259 | 
260 |
--------------------------------------------------------------------------------
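The pieces in `utils.py` can be smoke-tested in isolation. A minimal sketch, assuming the file is importable as `utils` and that `vgg.pth` (or the download URL) and torchvision's VGG16 weights are available; LPIPS expects inputs roughly in `[-1, 1]` and returns one distance per image, and the wavelet decomposition used by the encoder maps `(B, C, H, W)` to `(B, 4C, H/2, W/2)`:

```python
import torch

from utils import LPIPS, prepare_filter, wavelet_transform_multi_channel

device = "cuda" if torch.cuda.is_available() else "cpu"

# LPIPS: lower means more perceptually similar; inputs normalized to [-1, 1].
lpips = LPIPS().to(device).eval()
a = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
b = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
print(lpips(a, b).shape)  # torch.Size([1, 1, 1, 1]), one distance per image

# Wavelet block: (B, C, H, W) -> (B, 4C, H/2, W/2).
prepare_filter(device)
x = torch.randn(2, 3, 256, 256, device=device)
print(wavelet_transform_multi_channel(x).shape)  # torch.Size([2, 12, 128, 128])
```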
/vae_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 |
5 | import click
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | import torchvision.transforms as transforms
13 | import webdataset as wds
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | from torchvision.transforms import GaussianBlur
16 | from transformers import get_cosine_schedule_with_warmup
17 |
18 | torch.backends.cuda.matmul.allow_tf32 = True
19 | torch.backends.cudnn.allow_tf32 = True
20 |
21 | import wandb
22 | from ae import VAE
23 | from utils import LPIPS, PatchDiscriminator, prepare_filter
24 | import time
25 |
26 |
27 | class GradNormFunction(torch.autograd.Function):
28 | @staticmethod
29 | def forward(ctx, x, weight):
30 | ctx.save_for_backward(weight)
31 | return x.clone()
32 |
33 | @staticmethod
34 | def backward(ctx, grad_output):
35 | weight = ctx.saved_tensors[0]
36 |
37 | # grad_output_norm = torch.linalg.vector_norm(
38 | # grad_output, dim=list(range(1, len(grad_output.shape))), keepdim=True
39 | # ).mean()
40 | grad_output_norm = torch.norm(grad_output).mean().item()
41 | # nccl over all nodes
42 | grad_output_norm = avg_scalar_over_nodes(
43 | grad_output_norm, device=grad_output.device
44 | )
45 |
46 | grad_output_normalized = weight * grad_output / (grad_output_norm + 1e-8)
47 |
48 | return grad_output_normalized, None
49 |
50 |
51 | def gradnorm(x, weight=1.0):
52 | weight = torch.tensor(weight, device=x.device)
53 | return GradNormFunction.apply(x, weight)
54 |
55 |
56 | @torch.no_grad()
57 | def avg_scalar_over_nodes(value: float, device):
58 | value = torch.tensor(value, device=device)
59 | dist.all_reduce(value, op=dist.ReduceOp.AVG)
60 | return value.item()
61 |
62 |
63 | def gan_disc_loss(real_preds, fake_preds, disc_type="bce"):
64 | if disc_type == "bce":
65 | real_loss = nn.functional.binary_cross_entropy_with_logits(
66 | real_preds, torch.ones_like(real_preds)
67 | )
68 | fake_loss = nn.functional.binary_cross_entropy_with_logits(
69 | fake_preds, torch.zeros_like(fake_preds)
70 | )
71 | # eval its online performance
72 | avg_real_preds = real_preds.mean().item()
73 | avg_fake_preds = fake_preds.mean().item()
74 |
75 | with torch.no_grad():
76 | acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item()
77 | acc = acc / (real_preds.numel() + fake_preds.numel())
78 |
79 | if disc_type == "hinge":
80 | real_loss = nn.functional.relu(1 - real_preds).mean()
81 | fake_loss = nn.functional.relu(1 + fake_preds).mean()
82 |
83 | with torch.no_grad():
84 | acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item()
85 | acc = acc / (real_preds.numel() + fake_preds.numel())
86 |
87 | avg_real_preds = real_preds.mean().item()
88 | avg_fake_preds = fake_preds.mean().item()
89 |
90 | return (real_loss + fake_loss) * 0.5, avg_real_preds, avg_fake_preds, acc
91 |
92 |
93 | MAX_WIDTH = 512
94 |
95 | this_transform = transforms.Compose(
96 | [
97 | transforms.ToTensor(),
98 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
99 | transforms.CenterCrop(512),
100 | transforms.Resize(MAX_WIDTH),
101 | ]
102 | )
103 |
104 |
105 | def this_transform_random_crop_resize(x, width=MAX_WIDTH):
106 |
107 | x = transforms.ToTensor()(x)
108 | x = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(x)
109 |
110 | if random.random() < 0.5:
111 | x = transforms.RandomCrop(width)(x)
112 | else:
113 | x = transforms.Resize(width)(x)
114 | x = transforms.RandomCrop(width)(x)
115 |
116 | return x
117 |
118 |
119 | def create_dataloader(url, batch_size, num_workers, do_shuffle=True, just_resize=False):
120 | dataset = wds.WebDataset(
121 | url, nodesplitter=wds.split_by_node, workersplitter=wds.split_by_worker
122 | )
123 | dataset = dataset.shuffle(1000) if do_shuffle else dataset
124 |
125 | dataset = (
126 | dataset.decode("rgb")
127 | .to_tuple("jpg;png")
128 | .map_tuple(
129 | this_transform_random_crop_resize if not just_resize else this_transform
130 | )
131 | )
132 |
133 | loader = wds.WebLoader(
134 | dataset,
135 | batch_size=batch_size,
136 | shuffle=False,
137 | num_workers=num_workers,
138 | pin_memory=True,
139 | )
140 | return loader
141 |
142 |
143 | def blurriness_heatmap(input_image):
144 | grayscale_image = input_image.mean(dim=1, keepdim=True)
145 |
146 | laplacian_kernel = torch.tensor(
147 | [
148 | [0, 1, 1, 1, 0],
149 | [1, 1, 1, 1, 1],
150 | [1, 1, -20, 1, 1],
151 | [1, 1, 1, 1, 1],
152 | [0, 1, 1, 1, 0],
153 | ],
154 | dtype=torch.float32,
155 | )
156 | laplacian_kernel = laplacian_kernel.view(1, 1, 5, 5)
157 |
158 | laplacian_kernel = laplacian_kernel.to(input_image.device)
159 |
160 | edge_response = F.conv2d(grayscale_image, laplacian_kernel, padding=2)
161 |
162 | edge_magnitude = GaussianBlur(kernel_size=(13, 13), sigma=(2.0, 2.0))(
163 | edge_response.abs()
164 | )
165 |
166 | edge_magnitude = (edge_magnitude - edge_magnitude.min()) / (
167 | edge_magnitude.max() - edge_magnitude.min() + 1e-8
168 | )
169 |
170 | blurriness_map = 1 - edge_magnitude
171 |
172 | blurriness_map = torch.where(
173 | blurriness_map < 0.8, torch.zeros_like(blurriness_map), blurriness_map
174 | )
175 |
176 | return blurriness_map.repeat(1, 3, 1, 1)
177 |
178 |
179 | def vae_loss_function(x, x_reconstructed, z, do_pool=True, do_recon=False):
180 |     # optionally downsample images by a factor of 16 before computing the recon loss
181 | if do_recon:
182 | if do_pool:
183 | x_reconstructed_down = F.interpolate(
184 | x_reconstructed, scale_factor=1 / 16, mode="area"
185 | )
186 | x_down = F.interpolate(x, scale_factor=1 / 16, mode="area")
187 | recon_loss = ((x_reconstructed_down - x_down)).abs().mean()
188 | else:
189 | x_reconstructed_down = x_reconstructed
190 | x_down = x
191 |
192 | recon_loss = (
193 | ((x_reconstructed_down - x_down) * blurriness_heatmap(x_down))
194 | .abs()
195 | .mean()
196 | )
197 | recon_loss_item = recon_loss.item()
198 | else:
199 | recon_loss = 0
200 | recon_loss_item = 0
201 |
202 | elewise_mean_loss = z.pow(2)
203 | zloss = elewise_mean_loss.mean()
204 |
205 | with torch.no_grad():
206 | actual_mean_loss = elewise_mean_loss.mean()
207 | actual_ks_loss = actual_mean_loss.mean()
208 |
209 | vae_loss = recon_loss * 0.0 + zloss * 0.1
210 | return vae_loss, {
211 | "recon_loss": recon_loss_item,
212 | "kl_loss": actual_ks_loss.item(),
213 | "average_of_abs_z": z.abs().mean().item(),
214 | "std_of_abs_z": z.abs().std().item(),
215 | "average_of_logvar": 0.0,
216 | "std_of_logvar": 0.0,
217 | }
218 |
219 |
220 | def cleanup():
221 | dist.destroy_process_group()
222 |
223 |
224 | @click.command()
225 | @click.option(
226 | "--dataset_url", type=str, default="", help="URL for the training dataset"
227 | )
228 | @click.option(
229 | "--test_dataset_url", type=str, default="", help="URL for the test dataset"
230 | )
231 | @click.option("--num_epochs", type=int, default=2, help="Number of training epochs")
232 | @click.option("--batch_size", type=int, default=8, help="Batch size for training")
233 | @click.option("--do_ganloss", is_flag=True, help="Whether to use GAN loss")
234 | @click.option(
235 | "--learning_rate_vae", type=float, default=1e-5, help="Learning rate for VAE"
236 | )
237 | @click.option(
238 | "--learning_rate_disc",
239 | type=float,
240 | default=2e-4,
241 | help="Learning rate for discriminator",
242 | )
243 | @click.option("--vae_resolution", type=int, default=256, help="Resolution for VAE")
244 | @click.option("--vae_in_channels", type=int, default=3, help="Input channels for VAE")
245 | @click.option("--vae_ch", type=int, default=256, help="Base channel size for VAE")
246 | @click.option(
247 | "--vae_ch_mult", type=str, default="1,2,4,4", help="Channel multipliers for VAE"
248 | )
249 | @click.option(
250 | "--vae_num_res_blocks",
251 | type=int,
252 | default=2,
253 | help="Number of residual blocks for VAE",
254 | )
255 | @click.option(
256 | "--vae_z_channels", type=int, default=16, help="Number of latent channels for VAE"
257 | )
258 | @click.option("--run_name", type=str, default="run", help="Name of the run for wandb")
259 | @click.option(
260 | "--max_steps", type=int, default=1000, help="Maximum number of steps to train for"
261 | )
262 | @click.option(
263 | "--evaluate_every_n_steps", type=int, default=250, help="Evaluate every n steps"
264 | )
265 | @click.option("--load_path", type=str, default=None, help="Path to load the model from")
266 | @click.option("--do_clamp", is_flag=True, help="Whether to clamp the latent codes")
267 | @click.option(
268 | "--clamp_th", type=float, default=8.0, help="Clamp threshold for the latent codes"
269 | )
270 | @click.option(
271 | "--max_spatial_dim",
272 | type=int,
273 | default=256,
274 | help="Maximum spatial dimension for overall training",
275 | )
276 | @click.option(
277 | "--do_attn", type=bool, default=False, help="Whether to use attention in the VAE"
278 | )
279 | @click.option(
280 | "--decoder_also_perform_hr",
281 | type=bool,
282 | default=False,
283 | help="Whether to perform HR decoding in the decoder",
284 | )
285 | @click.option(
286 | "--project_name",
287 | type=str,
288 | default="vae_sweep_attn_lr_width",
289 | help="Project name for wandb",
290 | )
291 | @click.option(
292 | "--crop_invariance",
293 | type=bool,
294 | default=False,
295 | help="Whether to perform crop invariance",
296 | )
297 | @click.option(
298 | "--flip_invariance",
299 | type=bool,
300 | default=False,
301 | help="Whether to perform flip invariance",
302 | )
303 | @click.option(
304 | "--do_compile",
305 | type=bool,
306 | default=False,
307 | help="Whether to compile the model",
308 | )
309 | @click.option(
310 | "--use_wavelet",
311 | type=bool,
312 | default=False,
313 | help="Whether to use wavelet transform in the encoder",
314 | )
315 | @click.option(
316 | "--augment_before_perceptual_loss",
317 | type=bool,
318 | default=False,
319 | help="Whether to augment the images before the perceptual loss",
320 | )
321 | @click.option(
322 | "--downscale_factor",
323 | type=int,
324 | default=16,
325 | help="Downscale factor for the latent space",
326 | )
327 | @click.option(
328 | "--use_lecam",
329 | type=bool,
330 | default=False,
331 | help="Whether to use Lecam",
332 | )
333 | @click.option(
334 | "--disc_type",
335 | type=str,
336 | default="bce",
337 | help="Discriminator type",
338 | )
339 | def train_ddp(
340 | dataset_url,
341 | test_dataset_url,
342 | num_epochs,
343 | batch_size,
344 | do_ganloss,
345 | learning_rate_vae,
346 | learning_rate_disc,
347 | vae_resolution,
348 | vae_in_channels,
349 | vae_ch,
350 | vae_ch_mult,
351 | vae_num_res_blocks,
352 | vae_z_channels,
353 | run_name,
354 | max_steps,
355 | evaluate_every_n_steps,
356 | load_path,
357 | do_clamp,
358 | clamp_th,
359 | max_spatial_dim,
360 | do_attn,
361 | decoder_also_perform_hr,
362 | project_name,
363 | crop_invariance,
364 | flip_invariance,
365 | do_compile,
366 | use_wavelet,
367 | augment_before_perceptual_loss,
368 | downscale_factor,
369 | use_lecam,
370 | disc_type,
371 | ):
372 |
373 | # fix random seed
374 | torch.manual_seed(42)
375 | torch.cuda.manual_seed(42)
376 | torch.cuda.manual_seed_all(42)
377 | np.random.seed(42)
378 | random.seed(42)
379 |
380 | start_train = 0
381 | end_train = 128 * 16
382 |
383 | start_test = end_train + 1
384 | end_test = start_test + 8
385 |     # NOTE: these hard-coded shard paths override the --dataset_url / --test_dataset_url options
386 | dataset_url = f"/home/ubuntu/ultimate_pipe/flux_ipadapter_trainer/dataset/art_webdataset/{{{start_train:05d}..{end_train:05d}}}.tar"
387 | test_dataset_url = f"/home/ubuntu/ultimate_pipe/flux_ipadapter_trainer/dataset/art_webdataset/{{{start_test:05d}..{end_test:05d}}}.tar"
388 |
389 | assert torch.cuda.is_available(), "CUDA is required for DDP"
390 |
391 | dist.init_process_group(backend="nccl")
392 | ddp_rank = int(os.environ["RANK"])
393 | ddp_local_rank = int(os.environ["LOCAL_RANK"])
394 | world_size = int(os.environ["WORLD_SIZE"])
395 | device = f"cuda:{ddp_local_rank}"
396 | torch.cuda.set_device(device)
397 | master_process = ddp_rank == 0
398 | print(f"using device: {device}")
399 |
400 | if master_process:
401 | wandb.init(
402 | project=project_name,
403 | entity="simo",
404 | name=run_name,
405 | config={
406 | "learning_rate_vae": learning_rate_vae,
407 | "learning_rate_disc": learning_rate_disc,
408 | "vae_ch": vae_ch,
409 | "vae_resolution": vae_resolution,
410 | "vae_in_channels": vae_in_channels,
411 | "vae_ch_mult": vae_ch_mult,
412 | "vae_num_res_blocks": vae_num_res_blocks,
413 | "vae_z_channels": vae_z_channels,
414 | "batch_size": batch_size,
415 | "num_epochs": num_epochs,
416 | "do_ganloss": do_ganloss,
417 | "do_attn": do_attn,
418 | "use_wavelet": use_wavelet,
419 | },
420 | )
421 |
422 | vae = VAE(
423 | resolution=vae_resolution,
424 | in_channels=vae_in_channels,
425 | ch=vae_ch,
426 | out_ch=vae_in_channels,
427 | ch_mult=[int(x) for x in vae_ch_mult.split(",")],
428 | num_res_blocks=vae_num_res_blocks,
429 | z_channels=vae_z_channels,
430 | use_attn=do_attn,
431 | decoder_also_perform_hr=decoder_also_perform_hr,
432 | use_wavelet=use_wavelet,
433 | ).cuda()
434 |
435 | discriminator = PatchDiscriminator().cuda()
436 | discriminator.requires_grad_(True)
437 |
438 |     vae = DDP(vae, device_ids=[ddp_local_rank])
439 |
440 | prepare_filter(device)
441 |
442 | if do_compile:
443 | vae.module.encoder = torch.compile(
444 | vae.module.encoder, fullgraph=False, mode="max-autotune"
445 | )
446 | vae.module.decoder = torch.compile(
447 | vae.module.decoder, fullgraph=False, mode="max-autotune"
448 | )
449 |
450 |     discriminator = DDP(discriminator, device_ids=[ddp_local_rank])
451 |
452 | # context
453 | ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
454 |
455 | optimizer_G = optim.AdamW(
456 | [
457 | {
458 | "params": [p for n, p in vae.named_parameters() if "conv_in" not in n],
459 | "lr": learning_rate_vae / vae_ch,
460 | },
461 | {
462 | "params": [p for n, p in vae.named_parameters() if "conv_in" in n],
463 | "lr": 1e-4,
464 | },
465 | ],
466 | weight_decay=1e-3,
467 | betas=(0.9, 0.95),
468 | )
469 |
470 | optimizer_D = optim.AdamW(
471 | discriminator.parameters(),
472 | lr=learning_rate_disc,
473 | weight_decay=1e-3,
474 | betas=(0.9, 0.95),
475 | )
476 |
477 | lpips = LPIPS().cuda()
478 |
479 | dataloader = create_dataloader(
480 | dataset_url, batch_size, num_workers=4, do_shuffle=True
481 | )
482 | test_dataloader = create_dataloader(
483 | test_dataset_url, batch_size, num_workers=4, do_shuffle=False, just_resize=True
484 | )
485 |
486 | num_training_steps = max_steps
487 | num_warmup_steps = 200
488 | lr_scheduler = get_cosine_schedule_with_warmup(
489 | optimizer_G, num_warmup_steps, num_training_steps
490 | )
491 |
492 | # Setup logger
493 | logger = logging.getLogger(__name__)
494 | logger.setLevel(logging.INFO)
495 | if master_process:
496 | handler = logging.StreamHandler()
497 | formatter = logging.Formatter(
498 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
499 | )
500 | handler.setFormatter(formatter)
501 | logger.addHandler(handler)
502 |
503 | global_step = 0
504 |
505 | if load_path is not None:
506 | state_dict = torch.load(load_path, map_location="cpu")
507 | try:
508 | status = vae.load_state_dict(state_dict, strict=True)
509 | except Exception as e:
510 | print(e)
511 | state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
512 | status = vae.load_state_dict(state_dict, strict=True)
513 | print(status)
514 |
515 | t0 = time.time()
516 |
517 |     # lecam variables
518 |     lecam_loss_item = 0.0  # referenced in logging even when use_lecam is off
519 | lecam_loss_weight = 0.1
520 | lecam_anchor_real_logits = 0.0
521 | lecam_anchor_fake_logits = 0.0
522 | lecam_beta = 0.9
523 |
524 | for epoch in range(num_epochs):
525 | for i, real_images_hr in enumerate(dataloader):
526 | time_taken_till_load = time.time() - t0
527 |
528 | t0 = time.time()
529 | # resize real image to 256
530 | real_images_hr = real_images_hr[0].to(device)
531 | real_images_for_enc = F.interpolate(
532 | real_images_hr, size=(256, 256), mode="area"
533 | )
534 | if random.random() < 0.5:
535 | real_images_for_enc = torch.flip(real_images_for_enc, [-1])
536 | real_images_hr = torch.flip(real_images_hr, [-1])
537 |
538 | z = vae.module.encoder(real_images_for_enc)
539 |
540 | # z distribution
541 | with ctx:
542 | z_dist_value: torch.Tensor = z.detach().cpu().reshape(-1)
543 |
544 | def kurtosis(x):
545 | return ((x - x.mean()) ** 4).mean() / (x.std() ** 4)
546 |
547 | def skew(x):
548 | return ((x - x.mean()) ** 3).mean() / (x.std() ** 3)
549 |
550 | z_quantiles = {
551 | "0.0": z_dist_value.quantile(0.0),
552 | "0.2": z_dist_value.quantile(0.2),
553 | "0.4": z_dist_value.quantile(0.4),
554 | "0.6": z_dist_value.quantile(0.6),
555 | "0.8": z_dist_value.quantile(0.8),
556 | "1.0": z_dist_value.quantile(1.0),
557 | "kurtosis": kurtosis(z_dist_value),
558 | "skewness": skew(z_dist_value),
559 | }
560 |
561 | if do_clamp:
562 | z = z.clamp(-clamp_th, clamp_th)
563 | z_s = vae.module.reg(z)
564 |
565 | #### do aug
566 |
567 | if random.random() < 0.5 and flip_invariance:
568 | z_s = torch.flip(z_s, [-1])
569 | z_s[:, -4:-2] = -z_s[:, -4:-2]
570 | real_images_hr = torch.flip(real_images_hr, [-1])
571 |
572 | if random.random() < 0.5 and flip_invariance:
573 | z_s = torch.flip(z_s, [-2])
574 | z_s[:, -2:] = -z_s[:, -2:]
575 | real_images_hr = torch.flip(real_images_hr, [-2])
576 |
577 | if random.random() < 0.5 and crop_invariance:
578 | # crop image and latent.'
579 |
580 | # new_z_h, new_z_w, offset_z_h, offset_z_w
581 | z_h, z_w = z.shape[-2:]
582 | new_z_h = random.randint(12, z_h - 1)
583 | new_z_w = random.randint(12, z_w - 1)
584 | offset_z_h = random.randint(0, z_h - new_z_h - 1)
585 | offset_z_w = random.randint(0, z_w - new_z_w - 1)
586 |
587 | new_h = (
588 | new_z_h * downscale_factor * 2
589 | if decoder_also_perform_hr
590 | else new_z_h * downscale_factor
591 | )
592 | new_w = (
593 | new_z_w * downscale_factor * 2
594 | if decoder_also_perform_hr
595 | else new_z_w * downscale_factor
596 | )
597 | offset_h = (
598 | offset_z_h * downscale_factor * 2
599 | if decoder_also_perform_hr
600 | else offset_z_h * downscale_factor
601 | )
602 | offset_w = (
603 | offset_z_w * downscale_factor * 2
604 | if decoder_also_perform_hr
605 | else offset_z_w * downscale_factor
606 | )
607 |
608 | real_images_hr = real_images_hr[
609 | :, :, offset_h : offset_h + new_h, offset_w : offset_w + new_w
610 | ]
611 | z_s = z_s[
612 | :,
613 | :,
614 | offset_z_h : offset_z_h + new_z_h,
615 | offset_z_w : offset_z_w + new_z_w,
616 | ]
617 |
618 | assert real_images_hr.shape[-2] == new_h
619 | assert real_images_hr.shape[-1] == new_w
620 | assert z_s.shape[-2] == new_z_h
621 | assert z_s.shape[-1] == new_z_w
622 |
623 | with ctx:
624 | reconstructed = vae.module.decoder(z_s)
625 |
626 | if global_step >= max_steps:
627 | break
628 |
629 | if do_ganloss:
630 | real_preds = discriminator(real_images_hr)
631 | fake_preds = discriminator(reconstructed.detach())
632 | d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss(
633 | real_preds, fake_preds, disc_type
634 | )
635 |
636 | avg_real_logits = avg_scalar_over_nodes(avg_real_logits, device)
637 | avg_fake_logits = avg_scalar_over_nodes(avg_fake_logits, device)
638 |
639 | lecam_anchor_real_logits = (
640 | lecam_beta * lecam_anchor_real_logits
641 | + (1 - lecam_beta) * avg_real_logits
642 | )
643 | lecam_anchor_fake_logits = (
644 | lecam_beta * lecam_anchor_fake_logits
645 | + (1 - lecam_beta) * avg_fake_logits
646 | )
647 | total_d_loss = d_loss.mean()
648 | d_loss_item = total_d_loss.item()
649 | if use_lecam:
650 | # penalize the real logits to fake and fake logits to real.
651 | lecam_loss = (real_preds - lecam_anchor_fake_logits).pow(
652 | 2
653 | ).mean() + (fake_preds - lecam_anchor_real_logits).pow(2).mean()
654 | lecam_loss_item = lecam_loss.item()
655 | total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight
656 |
657 | optimizer_D.zero_grad()
658 | total_d_loss.backward(retain_graph=True)
659 | optimizer_D.step()
660 |
661 | # unnormalize the images, and perceptual loss
662 | _recon_for_perceptual = gradnorm(reconstructed)
663 |
664 | if augment_before_perceptual_loss:
665 | real_images_hr_aug = real_images_hr.clone()
666 | if random.random() < 0.5:
667 | _recon_for_perceptual = torch.flip(_recon_for_perceptual, [-1])
668 | real_images_hr_aug = torch.flip(real_images_hr_aug, [-1])
669 | if random.random() < 0.5:
670 | _recon_for_perceptual = torch.flip(_recon_for_perceptual, [-2])
671 | real_images_hr_aug = torch.flip(real_images_hr_aug, [-2])
672 |
673 | else:
674 | real_images_hr_aug = real_images_hr
675 |
676 | percep_rec_loss = lpips(_recon_for_perceptual, real_images_hr_aug).mean()
677 |
678 | # mse, vae loss.
679 | recon_for_mse = gradnorm(reconstructed, weight=0.001)
680 | vae_loss, loss_data = vae_loss_function(real_images_hr, recon_for_mse, z)
681 | # gan loss
682 | if do_ganloss and global_step >= 0:
683 | recon_for_gan = gradnorm(reconstructed, weight=1.0)
684 | fake_preds = discriminator(recon_for_gan)
685 | real_preds_const = real_preds.clone().detach()
686 | # loss where (real > fake + 0.01)
687 | # g_gan_loss = (real_preds_const - fake_preds - 0.1).relu().mean()
688 | if disc_type == "bce":
689 | g_gan_loss = nn.functional.binary_cross_entropy_with_logits(
690 | fake_preds, torch.ones_like(fake_preds)
691 | )
692 | elif disc_type == "hinge":
693 | g_gan_loss = -fake_preds.mean()
694 |
695 | overall_vae_loss = percep_rec_loss + g_gan_loss + vae_loss
696 | g_gan_loss = g_gan_loss.item()
697 | else:
698 | overall_vae_loss = percep_rec_loss + vae_loss
699 | g_gan_loss = 0.0
700 |
701 | overall_vae_loss.backward()
702 | optimizer_G.step()
703 | optimizer_G.zero_grad()
704 | lr_scheduler.step()
705 |
706 | if do_ganloss:
707 |
708 | optimizer_D.zero_grad()
709 |
710 | time_taken_till_step = time.time() - t0
711 |
712 | if master_process:
713 | if global_step % 5 == 0:
714 | wandb.log(
715 | {
716 | "epoch": epoch,
717 | "batch": i,
718 | "overall_vae_loss": overall_vae_loss.item(),
719 | "mse_loss": loss_data["recon_loss"],
720 | "kl_loss": loss_data["kl_loss"],
721 | "perceptual_loss": percep_rec_loss.item(),
722 | "gan/generator_gan_loss": (
723 | g_gan_loss if do_ganloss else None
724 | ),
725 | "z_quantiles/abs_z": loss_data["average_of_abs_z"],
726 | "z_quantiles/std_z": loss_data["std_of_abs_z"],
727 | "z_quantiles/logvar": loss_data["average_of_logvar"],
728 | "gan/avg_real_logits": (
729 | avg_real_logits if do_ganloss else None
730 | ),
731 | "gan/avg_fake_logits": (
732 | avg_fake_logits if do_ganloss else None
733 | ),
734 | "gan/discriminator_loss": (
735 | d_loss_item if do_ganloss else None
736 | ),
737 | "gan/discriminator_accuracy": (
738 | disc_acc if do_ganloss else None
739 | ),
740 | "gan/lecam_loss": lecam_loss_item if do_ganloss else None,
741 | "gan/lecam_anchor_real_logits": (
742 | lecam_anchor_real_logits if do_ganloss else None
743 | ),
744 | "gan/lecam_anchor_fake_logits": (
745 | lecam_anchor_fake_logits if do_ganloss else None
746 | ),
747 | "z_quantiles/qs": z_quantiles,
748 | "time_taken_till_step": time_taken_till_step,
749 | "time_taken_till_load": time_taken_till_load,
750 | }
751 | )
752 |
753 | if global_step % 200 == 0:
754 |
755 | wandb.log(
756 | {
757 | f"loss_stepwise/mse_loss_{global_step}": loss_data[
758 | "recon_loss"
759 | ],
760 | f"loss_stepwise/kl_loss_{global_step}": loss_data[
761 | "kl_loss"
762 | ],
763 | f"loss_stepwise/overall_vae_loss_{global_step}": overall_vae_loss.item(),
764 | }
765 | )
766 |
767 | log_message = f"Epoch [{epoch}/{num_epochs}] - "
768 | log_items = [
769 | ("perceptual_loss", percep_rec_loss.item()),
770 | ("mse_loss", loss_data["recon_loss"]),
771 | ("kl_loss", loss_data["kl_loss"]),
772 | ("overall_vae_loss", overall_vae_loss.item()),
773 | ("ABS mu (0.0): average_of_abs_z", loss_data["average_of_abs_z"]),
774 | ("STD mu : std_of_abs_z", loss_data["std_of_abs_z"]),
775 | (
776 | "ABS logvar (0.0) : average_of_logvar",
777 | loss_data["average_of_logvar"],
778 | ),
779 | ("STD logvar : std_of_logvar", loss_data["std_of_logvar"]),
780 | *[(f"z_quantiles/{q}", v) for q, v in z_quantiles.items()],
781 | ("time_taken_till_step", time_taken_till_step),
782 | ("time_taken_till_load", time_taken_till_load),
783 | ]
784 |
785 | if do_ganloss:
786 | log_items = [
787 | ("d_loss", d_loss_item),
788 | ("gan_loss", g_gan_loss),
789 | ("avg_real_logits", avg_real_logits),
790 | ("avg_fake_logits", avg_fake_logits),
791 | ("discriminator_accuracy", disc_acc),
792 | ("lecam_loss", lecam_loss_item),
793 | ("lecam_anchor_real_logits", lecam_anchor_real_logits),
794 | ("lecam_anchor_fake_logits", lecam_anchor_fake_logits),
795 | ] + log_items
796 |
797 | log_message += "\n\t".join(
798 | [f"{key}: {value:.4f}" for key, value in log_items]
799 | )
800 | logger.info(log_message)
801 |
802 | global_step += 1
803 | t0 = time.time()
804 |
805 | if (
806 | evaluate_every_n_steps > 0
807 | and global_step % evaluate_every_n_steps == 1
808 | and master_process
809 | ):
810 |
811 | with torch.no_grad():
812 | all_test_images = []
813 | all_reconstructed_test = []
814 |
815 | for test_images in test_dataloader:
816 | test_images_ori = test_images[0].to(device)
817 | # resize to 256
818 | test_images = F.interpolate(
819 | test_images_ori, size=(256, 256), mode="area"
820 | )
821 | with ctx:
822 | z = vae.module.encoder(test_images)
823 |
824 | if do_clamp:
825 | z = z.clamp(-clamp_th, clamp_th)
826 |
827 | z_s = vae.module.reg(z)
828 |
829 | # [1, 2]
830 | # [3, 4]
831 | # ->
832 | # [3, 4]
833 | # [1, 2]
834 | # ->
835 | # [4, 3]
836 | # [2, 1]
837 | if flip_invariance:
838 | z_s = torch.flip(z_s, [-1, -2])
839 | z_s[:, -4:] = -z_s[:, -4:]
840 |
841 | with ctx:
842 | reconstructed_test = vae.module.decoder(z_s)
843 |
844 | # unnormalize the images
845 | test_images_ori = test_images_ori * 0.5 + 0.5
846 | reconstructed_test = reconstructed_test * 0.5 + 0.5
847 | # clamp
848 | test_images_ori = test_images_ori.clamp(0, 1)
849 | reconstructed_test = reconstructed_test.clamp(0, 1)
850 |
851 | # flip twice
852 | if flip_invariance:
853 | reconstructed_test = torch.flip(
854 | reconstructed_test, [-1, -2]
855 | )
856 |
857 | all_test_images.append(test_images_ori)
858 | all_reconstructed_test.append(reconstructed_test)
859 |
860 | if len(all_test_images) >= 2:
861 | break
862 |
863 | test_images = torch.cat(all_test_images, dim=0)
864 | reconstructed_test = torch.cat(all_reconstructed_test, dim=0)
865 |
866 | logger.info(f"Epoch [{epoch}/{num_epochs}] - Logging test images")
867 |
868 |                     # crop test and recon to D x D
869 | D = 512 if decoder_also_perform_hr else 256
870 | offset = 0
871 | test_images = test_images[
872 | :, :, offset : offset + D, offset : offset + D
873 | ].cpu()
874 | reconstructed_test = reconstructed_test[
875 | :, :, offset : offset + D, offset : offset + D
876 | ].cpu()
877 |
878 |                     # tile the images into one large grid image
879 |                     # (2 rows x 4 cols, i.e. (D * 2) x (D * 4))
880 |                     recon_all_image = torch.zeros((3, D * 2, D * 4))
881 |                     test_all_image = torch.zeros((3, D * 2, D * 4))
882 | 
883 |                     for row in range(2):
884 |                         for col in range(4):
885 |                             recon_all_image[
886 |                                 :, row * D : (row + 1) * D, col * D : (col + 1) * D
887 |                             ] = reconstructed_test[row * 4 + col]
888 |                             test_all_image[
889 |                                 :, row * D : (row + 1) * D, col * D : (col + 1) * D
890 |                             ] = test_images[row * 4 + col]
891 |
892 | wandb.log(
893 | {
894 | "reconstructed_test_images": [
895 | wandb.Image(recon_all_image),
896 | ],
897 | "test_images": [
898 | wandb.Image(test_all_image),
899 | ],
900 | }
901 | )
902 |
903 | os.makedirs(f"./ckpt/{run_name}", exist_ok=True)
904 | torch.save(
905 | vae.state_dict(),
906 | f"./ckpt/{run_name}/vae_epoch_{epoch}_step_{global_step}.pt",
907 | )
908 | print(
909 | f"Saved checkpoint to ./ckpt/{run_name}/vae_epoch_{epoch}_step_{global_step}.pt"
910 | )
911 |
912 | cleanup()
913 |
914 |
915 | if __name__ == "__main__":
916 |
917 | # Example: torchrun --nproc_per_node=8 vae_trainer.py
918 | train_ddp()
919 |
--------------------------------------------------------------------------------
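The `GradNormFunction` / `gradnorm` trick in `vae_trainer.py` is an identity in the forward pass that rescales the incoming gradient to roughly unit norm (times a weight) in the backward pass, so the relative contribution of the perceptual, MSE, and GAN branches to the decoder's gradient is set by the `weight` arguments rather than by raw loss magnitudes. A single-process sketch of the same idea (the cross-node all-reduce is dropped, and the class name here is illustrative):

```python
import torch


class GradNorm(torch.autograd.Function):
    """Identity forward; backward rescales the gradient to norm ~= weight."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        norm = torch.norm(grad_output)
        return weight * grad_output / (norm + 1e-8), None


x = torch.randn(4, 3, 8, 8, requires_grad=True)
y = GradNorm.apply(x, torch.tensor(0.5))
(1000.0 * y).pow(2).sum().backward()
print(x.grad.norm())  # ~0.5, independent of the 1000x loss scaling
```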