├── .gitignore ├── README.hf.md ├── README.md ├── ae.py ├── challenge.ipynb ├── contents ├── appg.png ├── buildings.jpg ├── chinatown.jpg ├── cosplayers.jpg ├── flowers.jpg ├── lavender.jpg ├── logo.png ├── magazines.jpg ├── origin.png ├── randomwoman.jpeg ├── recon.png ├── ti2.png └── ti2_mask.png ├── inference.ipynb ├── launcher.sh ├── scripts └── launch_hdr.sh ├── sweep.sh ├── tae.py ├── tester.py ├── tester_upload.sh ├── unit_activation_reinitializer.py ├── utils.py └── vae_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | *.torrent 3 | celebvq 4 | __pycache__ 5 | vgg.pth 6 | *_tester.ipynb 7 | ckpt 8 | *.pt 9 | *.pth 10 | data 11 | mnist_data -------------------------------------------------------------------------------- /README.hf.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | --- 4 | 5 | # Equivariant 16ch, f8 VAE 6 | 7 | 8 | 9 | AuraEquiVAE is a novel autoencoder that addresses multiple problems of existing conventional VAEs. First, unlike traditional VAEs, whose learned log-variance is typically very small, this model admits large noise in the latent space. 10 | Additionally, unlike traditional VAEs, the latent space is equivariant under `Z_2 x Z_2` group operations (horizontal / vertical flips). 11 | 12 | To understand the equivariance, we apply suitable group actions to both the latent space globally and locally. The latent is represented as `Z = (z_1, ..., z_n)`, and we perform a global permutation group action `g_global` on the tuples such that `g_global` is isomorphic to the `Z_2 x Z_2` group. 13 | We also apply a local action `g_local` to individual `z_i` elements such that `g_local` is also isomorphic to the `Z_2 x Z_2` group. 14 | 15 | In our specific case, `g_global` corresponds to flips, while `g_local` corresponds to sign flips on specific latent dimensions. Using two channels for sign flips for each of the horizontal and vertical directions was chosen empirically. 16 | 17 | The model has been trained using the approach described in [Mastering VAE Training](https://github.com/cloneofsimo/vqgan-training), where detailed explanations of the training process can be found. 18 | 19 | ## How to use 20 | 21 | To use the weights, copy-paste the [VAE](https://github.com/cloneofsimo/vqgan-training/blob/03e04401cf49fe55be612d1f568be0110aa0fad1/ae.py) implementation. 22 | 23 | ```python 24 | from ae import VAE 25 | import torch 26 | from PIL import Image 27 | from torchvision import transforms 28 | vae = VAE( 29 | resolution=256, 30 | in_channels=3, 31 | ch=256, 32 | out_ch=3, 33 | ch_mult=[1, 2, 4, 4], 34 | num_res_blocks=2, 35 | z_channels=16, 36 | ).cuda().bfloat16() 37 | 38 | from safetensors.torch import load_file 39 | state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt") 40 | vae.load_state_dict(state_dict) 41 | 42 | imgpath = 'contents/lavender.jpg' 43 | 44 | img_orig = Image.open(imgpath).convert("RGB") 45 | offset = 128 46 | W = 768 47 | img_orig = img_orig.crop((offset, offset, W + offset, W + offset)) 48 | img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda() 49 | img = (img - 0.5) / 0.5 50 | 51 | with torch.no_grad(): 52 | z = vae.encoder(img) 53 | z = z.clamp(-8.0, 8.0) # this is latent!! 
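# for this 16-channel, f8 configuration and the 768x768 crop above, z should have shape [1, 16, 96, 96]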
54 | 55 | # flip horizontal 56 | z = torch.flip(z, [-1]) # this corresponds to g_global 57 | z[:, -4:-2] = -z[:, -4:-2] # this corresponds to g_local 58 | 59 | # flip vertical 60 | z = torch.flip(z, [-2]) # g_global again, now along the vertical axis 61 | z[:, -2:] = -z[:, -2:] # g_local: sign-flip the last two channels 62 | 63 | 64 | with torch.no_grad(): 65 | decz = vae.decoder(z) # this is image! 66 | 67 | decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0) 68 | decimg = (decimg * 255).astype('uint8') 69 | decimg = Image.fromarray(decimg) # PIL image. 70 | 71 | ``` 72 | 73 | ## Citation 74 | 75 | If you find this model useful, please cite: 76 | 77 | ``` 78 | @misc{ryu2024vqgantraining, 79 | author = {Simo Ryu}, 80 | title = {Training VQGAN and VAE, with detailed explanation}, 81 | year = {2024}, 82 | publisher = {GitHub}, 83 | journal = {GitHub repository}, 84 | howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}}, 85 | } 86 | ``` 87 | 88 | 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAE Trainer 2 | 3 | The famous VAEs of latent diffusion models, such as Stable Diffusion, FLUX, and SORA: how are they trained? This is my attempt to write a distributed VAE trainer. It is largely based on [LDM's VAE](https://arxiv.org/abs/2112.10752). 4 | 5 |

6 | <img src="contents/origin.png" alt="origin"> 7 | <img src="contents/recon.png" alt="recon"> 8 |

9 | 10 | ## Details 11 | 12 | The VAE architecture is based on the one used in latent diffusion models, with some modifications for improved training stability and performance. 13 | 14 | 1. **Distributed Training**: Utilizes PyTorch's DistributedDataParallel (DDP) for efficient multi-GPU training. 15 | 16 | 2. **GAN Loss**: Incorporates a GAN loss for enhanced image quality. The discriminator is a pretrained VGG16 backbone with simple convolutional classifier heads on top of its feature maps. The hinge loss is employed for the GAN loss, with a threshold to stabilize training: 17 | 18 | $$L_{GAN} = \max(0, D(x) - D(\hat{x}) - 0.1)$$ 19 | 20 | where $x$ is the input and $\hat{x} = \text{Dec}(\text{Enc}(x))$ is the reconstructed image. This way, the generator (Decoder/Encoder) is only trained when the discriminator is good enough, which keeps early-stage training more stable. 21 | 22 | 23 | 3. **Perceptual Loss**: Utilizes LPIPS for the reconstruction loss. 24 | 25 | 4. **Gradient Normalization**: Implements gradient normalization for stable training, offering a simpler alternative to adaptively reweighting the loss terms. This approach normalizes gradients to address potential imbalances between loss components: 26 | 27 | $$ \nabla_{\theta} L = \frac{1}{|\nabla_{X} L_{GAN}|} \nabla_{\theta} L_{GAN} + \frac{1}{|\nabla_{X} L_{percep}|} \nabla_{\theta} L_{percep} $$ 28 | 29 | where $\nabla_{X}$ denotes the gradient with respect to the input. This method is implemented using a custom autograd function that modifies the backward pass while preserving the forward pass. You can use it as follows: 30 | 31 | ```python 32 | class GradNormFunction(torch.autograd.Function): 33 | @staticmethod 34 | def forward(ctx, x): 35 | 36 | return x.clone() 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | 41 | grad_output_norm = torch.linalg.norm(grad_output, dim=0, keepdim=True) 42 | 43 | grad_output_normalized = grad_output / (grad_output_norm + 1e-8) 44 | 45 | return grad_output_normalized 46 | 47 | def gradnorm(x): 48 | return GradNormFunction.apply(x) 49 | 50 | # example: 51 | 52 | x = gradnorm(x) # gradients flowing back through x are now normalized to unit norm. 53 | 54 | ``` 55 | 56 | 5. **Constant Variance**: We implemented a fixed variance of 0.1, deviating from the conventional learnable variance in VAE architectures. This modification addresses the tendency of VAEs to indiscriminately decrease the log variance, regardless of the KL divergence penalty coefficient. 57 | 58 | 6. **Pooled Mean Squared Error (MSE)**: A pooled MSE was introduced to modify the original VAE reconstruction loss, reducing the model's sensitivity to high-frequency input details. 59 | 60 | 7. **Low-Pass Reconstruction Loss**: Our loss function combines LPIPS and MSE components to balance competing behaviors. While MSE and L1 losses often result in blurred outputs (predicting high-intensity signals at precise locations is inherently risky), retaining some MSE loss is crucial for accurate color-tone calibration. To address this, we found that applying the MSE loss to a low-pass-filtered version of the image is helpful. 61 | 62 | This method starts by applying a Laplacian-like filter to a grayscale copy of the original image, and the MSE loss is then applied based on the filtered result. The actual low-pass detection filter is a bit more sophisticated: it first detects high-frequency regions, blurs the detection map, and then truncates it. A rough sketch of this loss recipe is given at the end of this README. 63 | 64 | The following images illustrate an example of the filtered result: 65 | 66 |

67 | <img src="contents/ti2.png" alt="filtered"> 68 | <img src="contents/ti2_mask.png" alt="filtered"> 69 |

70 | 71 | Notice how the high-frequency details are colored black: the white region is where the MSE loss is applied, making the conflict between the LPIPS and MSE losses less severe. 72 | 73 | 74 | ## Files 75 | 76 | - `ae.py`: Contains the VAE architecture implementation. Most of it is based on FLUX's VAE, with some modifications such as constant variance and multi-head attention. 77 | 78 | - `vae_trainer.py`: Main training script with DDP setup. 79 | 80 | - `utils.py`: Utility functions and classes, including LPIPS and PatchDiscriminator. 81 | 82 | ## Usage 83 | 84 | You need a dataset of PNG images in [webdataset](https://github.com/tmbdev/webdataset) format. You can use [img2dataset](https://github.com/rom1504/img2dataset) to download your dataset. 85 | 86 | 87 | To start training, use the following command: 88 | 89 | ```bash 90 | torchrun --nproc_per_node=8 vae_trainer.py 91 | ``` 92 | 93 | This will initiate training on 8 GPUs. Adjust the number based on your available hardware. 94 | 95 | ## Configuration 96 | 97 | The trainer supports various configuration options through command-line arguments. Some key parameters include: 98 | 99 | - `--learning_rate_vae`: Learning rate for the VAE. 100 | - `--vae_ch`: Base channel size for the VAE. 101 | - `--vae_ch_mult`: Channel multipliers for the VAE. 102 | - `--do_ganloss`: Flag to enable GAN loss. 103 | 104 | For a full list of options, refer to the `train_ddp` function in `vae_trainer.py`. 105 | 106 | ## Citation 107 | 108 | If you find this repository useful, please cite: 109 | 110 | ``` 111 | @misc{ryu2024vqgantraining, 112 | author = {Simo Ryu}, 113 | title = {Training VQGAN and VAE, with detailed explanation}, 114 | year = {2024}, 115 | publisher = {GitHub}, 116 | journal = {GitHub repository}, 117 | howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}}, 118 | } 119 | ``` 120 | 121 | --- 122 | The README above was written by Claude.
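As a concrete companion to the Details section above, the following minimal sketch shows how the thresholded hinge GAN loss, the pooled MSE, and the low-pass-masked MSE could look in code. This is not the repository's implementation (that lives in `vae_trainer.py`); the helper names, the Laplacian kernel, and the blur/threshold constants are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def g_hinge_loss_with_threshold(d_real, d_fake, margin=0.1):
    # Generator-side hinge loss with the 0.1 threshold from the Details section:
    # it is only non-zero when the discriminator separates real from fake by
    # more than the margin, so the generator is left alone early in training.
    return torch.relu(d_real.mean() - d_fake.mean() - margin)


def pooled_mse(recon, target, pool=8):
    # Pooled MSE: compare area-downsampled images, which mostly calibrates
    # color tone instead of penalizing high-frequency detail.
    return F.mse_loss(F.avg_pool2d(recon, pool), F.avg_pool2d(target, pool))


def lowpass_masked_mse(recon, target, threshold=0.1):
    # Rough stand-in for the low-pass mask: detect high-frequency regions of a
    # grayscale copy with a Laplacian kernel, blur the detection map, truncate
    # it to a binary mask, and apply MSE only where the image is smooth
    # (the white region in the mask figure above).
    gray = target.mean(dim=1, keepdim=True)
    lap = torch.tensor(
        [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]],
        device=target.device, dtype=target.dtype,
    ).view(1, 1, 3, 3)
    high_freq = F.conv2d(gray, lap, padding=1).abs()
    high_freq = F.avg_pool2d(high_freq, kernel_size=5, stride=1, padding=2)  # blur
    mask = (high_freq < threshold).float().expand_as(recon)                  # truncate
    return ((recon - target) ** 2 * mask).sum() / mask.sum().clamp(min=1.0)


# Per the Gradient Normalization section, each term would see its own
# gradnorm'd copy of the reconstruction so that every component contributes a
# roughly unit-norm gradient, e.g.:
#   loss = lpips(gradnorm(recon), x) + lowpass_masked_mse(gradnorm(recon), x) \
#          + g_hinge_loss_with_threshold(disc(x), disc(gradnorm(recon)))
```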
123 | -------------------------------------------------------------------------------- /ae.py: -------------------------------------------------------------------------------- 1 | # Take from FLUX 2 | 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from torch import Tensor, nn 10 | from utils import wavelet_transform_multi_channel 11 | 12 | 13 | def swish(x) -> Tensor: 14 | return x * torch.sigmoid(x) 15 | 16 | 17 | # class StandardizedC2d(nn.Conv2d): 18 | # def __init__(self, *args, **kwargs): 19 | # super().__init__(*args, **kwargs) 20 | # self.step = 0 21 | 22 | # def forward(self, input): 23 | # output = super().forward(input) 24 | # # normalize the weights 25 | # if self.step < 1000: 26 | # with torch.no_grad(): 27 | # std = output.std().item() 28 | # normalize_term = (std + 1e-6)**(100/(self.step + 100)) 29 | # self.step += 1 30 | # self.weight.data.div_(normalize_term) 31 | # self.bias.data.div_(normalize_term) 32 | # output.div_(normalize_term) 33 | # # sync the weights, braodcast 34 | # torch.distributed.broadcast(self.weight.data, 0) 35 | # torch.distributed.broadcast(self.bias.data, 0) 36 | 37 | # return output 38 | StandardizedC2d = nn.Conv2d 39 | 40 | 41 | class FP32GroupNorm(nn.GroupNorm): 42 | def __init__(self, *args, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | 45 | def forward(self, input): 46 | output = F.group_norm( 47 | input.float(), 48 | self.num_groups, 49 | self.weight.float() if self.weight is not None else None, 50 | self.bias.float() if self.bias is not None else None, 51 | self.eps, 52 | ) 53 | return output.type_as(input) 54 | 55 | 56 | class AttnBlock(nn.Module): 57 | def __init__(self, in_channels: int): 58 | super().__init__() 59 | self.in_channels = in_channels 60 | 61 | self.head_dim = 64 62 | self.num_heads = in_channels // self.head_dim 63 | self.norm = FP32GroupNorm( 64 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 65 | ) 66 | self.qkv = StandardizedC2d( 67 | in_channels, in_channels * 3, kernel_size=1, bias=False 68 | ) 69 | self.proj_out = StandardizedC2d( 70 | in_channels, in_channels, kernel_size=1, bias=False 71 | ) 72 | nn.init.normal_(self.proj_out.weight, std=0.2 / math.sqrt(in_channels)) 73 | 74 | def attention(self, h_) -> Tensor: 75 | h_ = self.norm(h_) 76 | qkv = self.qkv(h_) 77 | q, k, v = qkv.chunk(3, dim=1) 78 | b, c, h, w = q.shape 79 | q = rearrange( 80 | q, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim 81 | ) 82 | k = rearrange( 83 | k, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim 84 | ) 85 | v = rearrange( 86 | v, "b (h d) x y -> b h (x y) d", h=self.num_heads, d=self.head_dim 87 | ) 88 | h_ = F.scaled_dot_product_attention(q, k, v) 89 | h_ = rearrange(h_, "b h (x y) d -> b (h d) x y", x=h, y=w) 90 | return h_ 91 | 92 | def forward(self, x) -> Tensor: 93 | return x + self.proj_out(self.attention(x)) 94 | 95 | 96 | class ResnetBlock(nn.Module): 97 | def __init__(self, in_channels: int, out_channels: int): 98 | super().__init__() 99 | self.in_channels = in_channels 100 | out_channels = in_channels if out_channels is None else out_channels 101 | self.out_channels = out_channels 102 | self.norm1 = FP32GroupNorm( 103 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 104 | ) 105 | self.conv1 = StandardizedC2d( 106 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 107 | ) 108 | self.norm2 = FP32GroupNorm( 109 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True 110 | ) 
111 | self.conv2 = StandardizedC2d( 112 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 113 | ) 114 | if self.in_channels != self.out_channels: 115 | self.nin_shortcut = StandardizedC2d( 116 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 117 | ) 118 | 119 | # init conv2 as very small number 120 | nn.init.normal_(self.conv2.weight, std=0.0001 / self.out_channels) 121 | nn.init.zeros_(self.conv2.bias) 122 | self.counter = 0 123 | 124 | def forward(self, x): 125 | 126 | # if self.counter < 5000: 127 | # self.counter += 1 128 | # h = 0 129 | # else: 130 | h = x 131 | h = self.norm1(h) 132 | h = swish(h) 133 | h = self.conv1(h) 134 | h = self.norm2(h) 135 | h = swish(h) 136 | h = self.conv2(h) 137 | 138 | if self.in_channels != self.out_channels: 139 | x = self.nin_shortcut(x) 140 | return x + h 141 | 142 | 143 | class Downsample(nn.Module): 144 | def __init__(self, in_channels: int): 145 | super().__init__() 146 | self.conv = StandardizedC2d( 147 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 148 | ) 149 | 150 | def forward(self, x): 151 | pad = (0, 1, 0, 1) 152 | x = nn.functional.pad(x, pad, mode="constant", value=0) 153 | x = self.conv(x) 154 | return x 155 | 156 | 157 | class Upsample(nn.Module): 158 | def __init__(self, in_channels: int): 159 | super().__init__() 160 | self.conv = StandardizedC2d( 161 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 162 | ) 163 | 164 | def forward(self, x): 165 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 166 | x = self.conv(x) 167 | return x 168 | 169 | 170 | class Encoder(nn.Module): 171 | def __init__( 172 | self, 173 | resolution: int, 174 | in_channels: int, 175 | ch: int, 176 | ch_mult: list[int], 177 | num_res_blocks: int, 178 | z_channels: int, 179 | use_attn: bool = True, 180 | use_wavelet: bool = False, 181 | ): 182 | super().__init__() 183 | self.ch = ch 184 | self.num_resolutions = len(ch_mult) 185 | self.num_res_blocks = num_res_blocks 186 | self.resolution = resolution 187 | self.in_channels = in_channels 188 | self.use_wavelet = use_wavelet 189 | if self.use_wavelet: 190 | self.wavelet_transform = wavelet_transform_multi_channel 191 | self.conv_in = StandardizedC2d( 192 | 4 * in_channels, self.ch * 2, kernel_size=3, stride=1, padding=1 193 | ) 194 | ch_mult[0] *= 2 195 | else: 196 | self.wavelet_transform = nn.Identity() 197 | self.conv_in = StandardizedC2d( 198 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 199 | ) 200 | 201 | curr_res = resolution 202 | in_ch_mult = (2 if self.use_wavelet else 1,) + tuple(ch_mult) 203 | self.in_ch_mult = in_ch_mult 204 | self.down = nn.ModuleList() 205 | block_in = self.ch 206 | for i_level in range(self.num_resolutions): 207 | block = nn.ModuleList() 208 | attn = nn.ModuleList() 209 | block_in = ch * in_ch_mult[i_level] 210 | block_out = ch * ch_mult[i_level] 211 | for _ in range(self.num_res_blocks): 212 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 213 | block_in = block_out 214 | down = nn.Module() 215 | down.block = block 216 | down.attn = attn 217 | if i_level != self.num_resolutions - 1 and not ( 218 | self.use_wavelet and i_level == 0 219 | ): 220 | down.downsample = Downsample(block_in) 221 | curr_res = curr_res // 2 222 | self.down.append(down) 223 | self.mid = nn.Module() 224 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 225 | self.mid.attn_1 = AttnBlock(block_in) if use_attn else nn.Identity() 226 | self.mid.block_2 = 
ResnetBlock(in_channels=block_in, out_channels=block_in) 227 | self.norm_out = FP32GroupNorm( 228 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 229 | ) 230 | self.conv_out = StandardizedC2d( 231 | block_in, z_channels, kernel_size=3, stride=1, padding=1 232 | ) 233 | for module in self.modules(): 234 | if isinstance(module, StandardizedC2d): 235 | nn.init.zeros_(module.bias) 236 | if isinstance(module, nn.GroupNorm): 237 | nn.init.zeros_(module.bias) 238 | 239 | def forward(self, x) -> Tensor: 240 | h = self.wavelet_transform(x) 241 | h = self.conv_in(h) 242 | for i_level in range(self.num_resolutions): 243 | for i_block in range(self.num_res_blocks): 244 | h = self.down[i_level].block[i_block](h) 245 | if len(self.down[i_level].attn) > 0: 246 | h = self.down[i_level].attn[i_block](h) 247 | if i_level != self.num_resolutions - 1 and not ( 248 | self.use_wavelet and i_level == 0 249 | ): 250 | h = self.down[i_level].downsample(h) 251 | h = self.mid.block_1(h) 252 | h = self.mid.attn_1(h) 253 | h = self.mid.block_2(h) 254 | h = self.norm_out(h) 255 | h = swish(h) 256 | h = self.conv_out(h) 257 | return h 258 | 259 | 260 | class Decoder(nn.Module): 261 | def __init__( 262 | self, 263 | ch: int, 264 | out_ch: int, 265 | ch_mult: list[int], 266 | num_res_blocks: int, 267 | in_channels: int, 268 | resolution: int, 269 | z_channels: int, 270 | use_attn: bool = True, 271 | ): 272 | super().__init__() 273 | self.ch = ch 274 | self.num_resolutions = len(ch_mult) 275 | self.num_res_blocks = num_res_blocks 276 | self.resolution = resolution 277 | self.in_channels = in_channels 278 | self.ffactor = 2 ** (self.num_resolutions - 1) 279 | block_in = ch * ch_mult[self.num_resolutions - 1] 280 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 281 | self.z_shape = (1, z_channels, curr_res, curr_res) 282 | self.conv_in = StandardizedC2d( 283 | z_channels, block_in, kernel_size=3, stride=1, padding=1 284 | ) 285 | self.mid = nn.Module() 286 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 287 | self.mid.attn_1 = AttnBlock(block_in) if use_attn else nn.Identity() 288 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 289 | self.up = nn.ModuleList() 290 | for i_level in reversed(range(self.num_resolutions)): 291 | block = nn.ModuleList() 292 | attn = nn.ModuleList() 293 | block_out = ch * ch_mult[i_level] 294 | for _ in range(self.num_res_blocks + 1): 295 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 296 | block_in = block_out 297 | up = nn.Module() 298 | up.block = block 299 | up.attn = attn 300 | if i_level != 0: 301 | up.upsample = Upsample(block_in) 302 | curr_res = curr_res * 2 303 | self.up.insert(0, up) 304 | self.norm_out = FP32GroupNorm( 305 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 306 | ) 307 | self.conv_out = StandardizedC2d( 308 | block_in, out_ch, kernel_size=3, stride=1, padding=1 309 | ) 310 | 311 | # initialize all bias to zero 312 | for module in self.modules(): 313 | if isinstance(module, StandardizedC2d): 314 | nn.init.zeros_(module.bias) 315 | if isinstance(module, nn.GroupNorm): 316 | nn.init.zeros_(module.bias) 317 | 318 | def forward(self, z) -> Tensor: 319 | h = self.conv_in(z) 320 | h = self.mid.block_1(h) 321 | h = self.mid.attn_1(h) 322 | h = self.mid.block_2(h) 323 | for i_level in reversed(range(self.num_resolutions)): 324 | for i_block in range(self.num_res_blocks + 1): 325 | h = self.up[i_level].block[i_block](h) 326 | if len(self.up[i_level].attn) > 0: 
327 | h = self.up[i_level].attn[i_block](h) 328 | if i_level != 0: 329 | h = self.up[i_level].upsample(h) 330 | h = self.norm_out(h) 331 | h = swish(h) 332 | h = self.conv_out(h) 333 | return h 334 | 335 | 336 | class DiagonalGaussian(nn.Module): 337 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 338 | super().__init__() 339 | self.sample = sample 340 | self.chunk_dim = chunk_dim 341 | 342 | def forward(self, z) -> Tensor: 343 | mean = z 344 | if self.sample: 345 | std = 0.00 346 | return mean * (1 + std * torch.randn_like(mean)) 347 | else: 348 | return mean 349 | 350 | 351 | class VAE(nn.Module): 352 | def __init__( 353 | self, 354 | resolution, 355 | in_channels, 356 | ch, 357 | out_ch, 358 | ch_mult, 359 | num_res_blocks, 360 | z_channels, 361 | use_attn, 362 | decoder_also_perform_hr, 363 | use_wavelet, 364 | ): 365 | super().__init__() 366 | self.encoder = Encoder( 367 | resolution=resolution, 368 | in_channels=in_channels, 369 | ch=ch, 370 | ch_mult=ch_mult, 371 | num_res_blocks=num_res_blocks, 372 | z_channels=z_channels, 373 | use_attn=use_attn, 374 | use_wavelet=use_wavelet, 375 | ) 376 | self.decoder = Decoder( 377 | resolution=resolution, 378 | in_channels=in_channels, 379 | ch=ch, 380 | out_ch=out_ch, 381 | ch_mult=ch_mult + [4] if decoder_also_perform_hr else ch_mult, 382 | num_res_blocks=num_res_blocks, 383 | z_channels=z_channels, 384 | use_attn=use_attn, 385 | ) 386 | self.reg = DiagonalGaussian() 387 | 388 | def forward(self, x) -> Tensor: 389 | z = self.encoder(x) 390 | z_s = self.reg(z) 391 | decz = self.decoder(z_s) 392 | return decz, z 393 | 394 | 395 | if __name__ == "__main__": 396 | from utils import prepare_filter 397 | 398 | prepare_filter("cuda") 399 | vae = VAE( 400 | resolution=256, 401 | in_channels=3, 402 | ch=64, 403 | out_ch=3, 404 | ch_mult=[1, 2, 4, 4, 4], 405 | num_res_blocks=2, 406 | z_channels=16 * 4, 407 | use_attn=False, 408 | decoder_also_perform_hr=False, 409 | use_wavelet=False, 410 | ) 411 | vae.eval().to("cuda") 412 | x = torch.randn(1, 3, 256, 256).to("cuda") 413 | decz, z = vae(x) 414 | print(decz.shape, z.shape) 415 | 416 | # do de 417 | 418 | # from unit_activation_reinitializer import adjust_weight_init 419 | # from torchvision import transforms 420 | # import torchvision 421 | 422 | # train_dataset = torchvision.datasets.CIFAR10( 423 | # root="./data", 424 | # train=True, 425 | # download=True, 426 | # transform=transforms.Compose( 427 | # [ 428 | # transforms.Resize((256, 256)), 429 | # transforms.ToTensor(), 430 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 431 | # ] 432 | # ), 433 | # ) 434 | 435 | # initial_std, layer_weight_std = adjust_weight_init( 436 | # vae, 437 | # dataset=train_dataset, 438 | # device="cuda:0", 439 | # batch_size=64, 440 | # num_workers=0, 441 | # tol=0.1, 442 | # max_iters=10, 443 | # exclude_layers=[FP32GroupNorm, nn.LayerNorm], 444 | # ) 445 | 446 | # # save initial_std and layer_weight_std 447 | # torch.save(initial_std, "initial_std.pth") 448 | # torch.save(layer_weight_std, "layer_weight_std.pth") 449 | 450 | # print("\nAdjusted Weight Standard Deviations. 
Before -> After:") 451 | # for layer_name, std in layer_weight_std.items(): 452 | # print( 453 | # f"Layer {layer_name}, Changed STD from \n {initial_std[layer_name]:.4f} -> STD {std:.4f}\n" 454 | # ) 455 | 456 | # print(layer_weight_std) 457 | -------------------------------------------------------------------------------- /contents/appg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/appg.png -------------------------------------------------------------------------------- /contents/buildings.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/buildings.jpg -------------------------------------------------------------------------------- /contents/chinatown.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/chinatown.jpg -------------------------------------------------------------------------------- /contents/cosplayers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/cosplayers.jpg -------------------------------------------------------------------------------- /contents/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/flowers.jpg -------------------------------------------------------------------------------- /contents/lavender.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/lavender.jpg -------------------------------------------------------------------------------- /contents/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/logo.png -------------------------------------------------------------------------------- /contents/magazines.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/magazines.jpg -------------------------------------------------------------------------------- /contents/origin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/origin.png -------------------------------------------------------------------------------- /contents/randomwoman.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/randomwoman.jpeg -------------------------------------------------------------------------------- /contents/recon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/recon.png -------------------------------------------------------------------------------- /contents/ti2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/ti2.png -------------------------------------------------------------------------------- /contents/ti2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloneofsimo/vqgan-training/379fe36eee4e90e01a4076b2815d49e6736db992/contents/ti2_mask.png -------------------------------------------------------------------------------- /launcher.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | loglr=-7 4 | width=64 5 | lr=$(python -c "import math; print(2**${loglr})") 6 | run_name="stage_4_msepool-cont-512-1.0-1.0-batch-gradnorm_make_deterministic" 7 | echo "Running ${run_name}" 8 | 9 | torchrun --nproc_per_node=8 vae_trainer.py \ 10 | --learning_rate_vae ${lr} \ 11 | --vae_ch ${width} \ 12 | --run_name ${run_name} \ 13 | --num_epochs 20 \ 14 | --max_steps 100000 \ 15 | --evaluate_every_n_steps 500 \ 16 | --learning_rate_disc 1e-5 \ 17 | --batch_size 12 \ 18 | --do_clamp \ 19 | --do_ganloss \ 20 | --project_name "HrDecoderAE" \ 21 | --decoder_also_perform_hr True 22 | #--load_path "/home/ubuntu/auravasa/ckpt/stage_3_msepool-cont-512-1.0-1.0-batch-gradnorm/vae_epoch_1_step_23501.pt" 23 | #--load_path "/home/ubuntu/auravasa/ckpt/stage2_msepool-cont-512-1.0-1.0-batch-gradnorm/vae_epoch_0_step_28501.pt" 24 | # --load_path "/home/ubuntu/auravasa/ckpt/exp_vae_ch_256_lr_0.0078125_weighted_percep+f8areapool_l2_0.0/vae_epoch_1_step_27001.pt" -------------------------------------------------------------------------------- /scripts/launch_hdr.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | loglr=-7 4 | width=128 5 | lr=$(python -c "import math; print(2**${loglr})") 6 | run_name="stage_4_cont_with_lecam_hinge" 7 | echo "Running ${run_name}" 8 | 9 | torchrun --nproc_per_node=8 vae_trainer.py \ 10 | --learning_rate_vae ${lr} \ 11 | --vae_ch ${width} \ 12 | --run_name ${run_name} \ 13 | --num_epochs 20 \ 14 | --max_steps 100000 \ 15 | --evaluate_every_n_steps 1000 \ 16 | --learning_rate_disc 3e-5 \ 17 | --batch_size 4 \ 18 | --do_clamp \ 19 | --do_ganloss \ 20 | --project_name "HrDecoderAE" \ 21 | --decoder_also_perform_hr True \ 22 | --do_compile False \ 23 | --crop_invariance True \ 24 | --flip_invariance False \ 25 | --use_wavelet True \ 26 | --vae_z_channels 64 \ 27 | --vae_ch_mult 1,2,4,4,4 \ 28 | --use_lecam True \ 29 | --disc_type "hinge" \ 30 | --load_path "/home/ubuntu/auravasa/ckpt/stage_3_hdr_z64_f16_add_flip_lr_disc_1e-4/vae_epoch_1_step_98001.pt" -------------------------------------------------------------------------------- /sweep.sh: -------------------------------------------------------------------------------- 1 | ## Sweep 1. is attention useful? 
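## Grid below: log2 learning rates from -8 to -2, model widths 32/64/128, attention on/off; each run is capped at 2000 steps.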
2 | 3 | loglrs=(-8 -7 -6 -5 -4 -3 -2) 4 | MODEL_WIDTHS=(32 64 128) 5 | 6 | for loglr in "${loglrs[@]}"; do 7 | for attn in "True" "False"; do 8 | for width in "${MODEL_WIDTHS[@]}"; do 9 | lr=$(python -c "import math; print(2**${loglr})") 10 | run_name="exp_vae_ch_${width}_lr_${lr}_attn_${attn}" 11 | 12 | echo "Running ${run_name}" 13 | 14 | torchrun --nproc_per_node=8 vae_trainer.py \ 15 | --learning_rate_vae ${lr} \ 16 | --vae_ch ${width} \ 17 | --run_name ${run_name} \ 18 | --num_epochs 20 \ 19 | --max_steps 2000 \ 20 | --evaluate_every_n_steps 250 \ 21 | --batch_size 32 \ 22 | --do_clamp \ 23 | --do_attn ${attn} \ 24 | --project_name "vae_sweep_attn_lr_width" 25 | 26 | done 27 | done 28 | done 29 | 30 | ## Sweep 2. Can we initialize better? 31 | 32 | loglrs=(-8 -7 -6 -5 -4 -3 -2) 33 | MODEL_WIDTHS=(64) 34 | 35 | for loglr in "${loglrs[@]}"; do 36 | for attn in "True" "False"; do 37 | for width in "${MODEL_WIDTHS[@]}"; do 38 | lr=$(python -c "import math; print(2**${loglr})") 39 | run_name="exp_vae_ch_${width}_lr_${lr}_attn_${attn}" 40 | 41 | echo "Running ${run_name}" 42 | 43 | torchrun --nproc_per_node=8 vae_trainer.py \ 44 | --learning_rate_vae ${lr} \ 45 | --vae_ch ${width} \ 46 | --run_name ${run_name} \ 47 | --num_epochs 20 \ 48 | --max_steps 2000 \ 49 | --evaluate_every_n_steps 250 \ 50 | --batch_size 32 \ 51 | --do_clamp \ 52 | --do_attn ${attn} \ 53 | --project_name "vae_sweep_attn_lr_width" 54 | 55 | done 56 | done 57 | done 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /tae.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | 9 | def swish(x: Tensor) -> Tensor: 10 | return x * torch.sigmoid(x) 11 | 12 | 13 | class AttnBlock(nn.Module): 14 | def __init__(self, in_channels: int): 15 | super().__init__() 16 | self.in_channels = in_channels 17 | self.num_heads = 8 18 | self.head_dim = in_channels // self.num_heads 19 | self.norm = nn.GroupNorm( 20 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 21 | ) 22 | self.qkv = nn.Conv3d(in_channels, in_channels * 3, kernel_size=1, bias=False) 23 | self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1, bias=False) 24 | nn.init.normal_(self.proj_out.weight, std=0.2 / math.sqrt(in_channels)) 25 | 26 | def attention(self, h_: Tensor) -> Tensor: 27 | h_ = self.norm(h_) 28 | qkv = self.qkv(h_) 29 | q, k, v = qkv.chunk(3, dim=1) 30 | b, c, t, h, w = q.shape 31 | q = rearrange( 32 | q, 33 | "b (head d) t h w -> b head (t h w) d", 34 | head=self.num_heads, 35 | d=self.head_dim, 36 | ) 37 | k = rearrange( 38 | k, 39 | "b (head d) t h w -> b head (t h w) d", 40 | head=self.num_heads, 41 | d=self.head_dim, 42 | ) 43 | v = rearrange( 44 | v, 45 | "b (head d) t h w -> b head (t h w) d", 46 | head=self.num_heads, 47 | d=self.head_dim, 48 | ) 49 | h_ = F.scaled_dot_product_attention(q, k, v) 50 | h_ = rearrange(h_, "b head (t h w) d -> b (head d) t h w", t=t, h=h, w=w) 51 | return h_ 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | return x + self.proj_out(self.attention(x)) 55 | 56 | 57 | class ResnetBlock(nn.Module): 58 | def __init__(self, in_channels: int, out_channels: int = None): 59 | super().__init__() 60 | self.in_channels = in_channels 61 | out_channels = in_channels if out_channels is None else out_channels 62 | self.out_channels = out_channels 63 | 
self.norm1 = nn.GroupNorm( 64 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 65 | ) 66 | self.conv1 = nn.Conv3d( 67 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 68 | ) 69 | self.norm2 = nn.GroupNorm( 70 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True 71 | ) 72 | self.conv2 = nn.Conv3d( 73 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 74 | ) 75 | if self.in_channels != self.out_channels: 76 | self.nin_shortcut = nn.Conv3d( 77 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 78 | ) 79 | 80 | def forward(self, x): 81 | h = x 82 | h = self.norm1(h) 83 | h = swish(h) 84 | h = self.conv1(h) 85 | h = self.norm2(h) 86 | h = swish(h) 87 | h = self.conv2(h) 88 | if self.in_channels != self.out_channels: 89 | x = self.nin_shortcut(x) 90 | return x + h 91 | 92 | 93 | class Downsample(nn.Module): 94 | def __init__(self, in_channels: int): 95 | super().__init__() 96 | self.conv = nn.Conv3d( 97 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 98 | ) 99 | 100 | def forward(self, x: Tensor): 101 | pad = (0, 1, 0, 1, 0, 1) # Pad depth (T), height (H), width (W) dimensions 102 | x = nn.functional.pad(x, pad, mode="constant", value=0) 103 | x = self.conv(x) 104 | return x 105 | 106 | 107 | class Upsample(nn.Module): 108 | def __init__(self, in_channels: int): 109 | super().__init__() 110 | self.conv = nn.Conv3d( 111 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 112 | ) 113 | 114 | def forward(self, x: Tensor): 115 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 116 | x = self.conv(x) 117 | return x 118 | 119 | 120 | class Encoder(nn.Module): 121 | def __init__( 122 | self, 123 | resolution: int, 124 | in_channels: int, 125 | ch: int, 126 | ch_mult: list[int], 127 | num_res_blocks: int, 128 | z_channels: int, 129 | ): 130 | super().__init__() 131 | self.ch = ch 132 | self.num_resolutions = len(ch_mult) 133 | self.num_res_blocks = num_res_blocks 134 | self.resolution = resolution 135 | self.in_channels = in_channels 136 | self.conv_in = nn.Conv3d( 137 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 138 | ) 139 | curr_res = resolution 140 | in_ch_mult = (1,) + tuple(ch_mult) 141 | self.down = nn.ModuleList() 142 | block_in = self.ch 143 | for i_level in range(self.num_resolutions): 144 | block = nn.ModuleList() 145 | attn = nn.ModuleList() 146 | block_in = ch * in_ch_mult[i_level] 147 | block_out = ch * ch_mult[i_level] 148 | for _ in range(self.num_res_blocks): 149 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 150 | block_in = block_out 151 | down = nn.Module() 152 | down.block = block 153 | down.attn = attn 154 | if i_level != self.num_resolutions - 1: 155 | down.downsample = Downsample(block_in) 156 | curr_res = curr_res // 2 157 | self.down.append(down) 158 | self.mid = nn.Module() 159 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 160 | self.mid.attn_1 = AttnBlock(block_in) 161 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 162 | self.norm_out = nn.GroupNorm( 163 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 164 | ) 165 | self.conv_out = nn.Conv3d( 166 | block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | def forward(self, x: Tensor) -> Tensor: 170 | h = self.conv_in(x) 171 | for i_level in range(self.num_resolutions): 172 | for i_block in range(self.num_res_blocks): 173 | h = self.down[i_level].block[i_block](h) 174 | if 
len(self.down[i_level].attn) > 0: 175 | h = self.down[i_level].attn[i_block](h) 176 | if i_level != self.num_resolutions - 1: 177 | h = self.down[i_level].downsample(h) 178 | h = self.mid.block_1(h) 179 | h = self.mid.attn_1(h) 180 | h = self.mid.block_2(h) 181 | h = self.norm_out(h) 182 | h = swish(h) 183 | h = self.conv_out(h) 184 | return h 185 | 186 | 187 | class Decoder(nn.Module): 188 | def __init__( 189 | self, 190 | ch: int, 191 | out_ch: int, 192 | ch_mult: list[int], 193 | num_res_blocks: int, 194 | in_channels: int, 195 | resolution: int, 196 | z_channels: int, 197 | ): 198 | super().__init__() 199 | self.ch = ch 200 | self.num_resolutions = len(ch_mult) 201 | self.num_res_blocks = num_res_blocks 202 | self.resolution = resolution 203 | self.in_channels = in_channels 204 | self.ffactor = 2 ** (self.num_resolutions - 1) 205 | block_in = ch * ch_mult[self.num_resolutions - 1] 206 | curr_res = resolution // (2 ** (self.num_resolutions - 1)) 207 | self.z_shape = (1, z_channels, curr_res, curr_res, curr_res) 208 | self.conv_in = nn.Conv3d( 209 | z_channels, block_in, kernel_size=3, stride=1, padding=1 210 | ) 211 | self.mid = nn.Module() 212 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 213 | self.mid.attn_1 = AttnBlock(block_in) 214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 215 | self.up = nn.ModuleList() 216 | for i_level in reversed(range(self.num_resolutions)): 217 | block = nn.ModuleList() 218 | attn = nn.ModuleList() 219 | block_out = ch * ch_mult[i_level] 220 | for _ in range(self.num_res_blocks + 1): 221 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 222 | block_in = block_out 223 | up = nn.Module() 224 | up.block = block 225 | up.attn = attn 226 | if i_level != 0: 227 | up.upsample = Upsample(block_in) 228 | curr_res = curr_res * 2 229 | self.up.insert(0, up) 230 | self.norm_out = nn.GroupNorm( 231 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 232 | ) 233 | self.conv_out = nn.Conv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 234 | 235 | def forward(self, z: Tensor) -> Tensor: 236 | h = self.conv_in(z) 237 | h = self.mid.block_1(h) 238 | h = self.mid.attn_1(h) 239 | h = self.mid.block_2(h) 240 | for i_level in reversed(range(self.num_resolutions)): 241 | for i_block in range(self.num_res_blocks + 1): 242 | h = self.up[i_level].block[i_block](h) 243 | if len(self.up[i_level].attn) > 0: 244 | h = self.up[i_level].attn[i_block](h) 245 | if i_level != 0: 246 | h = self.up[i_level].upsample(h) 247 | h = self.norm_out(h) 248 | h = swish(h) 249 | h = self.conv_out(h) 250 | return h 251 | 252 | 253 | class DiagonalGaussian(nn.Module): 254 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 255 | super().__init__() 256 | self.sample = sample 257 | self.chunk_dim = chunk_dim 258 | 259 | def forward(self, z: Tensor) -> Tensor: 260 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 261 | if self.sample: 262 | logvar = logvar.clamp(min=-3) 263 | std = torch.exp(0.5 * logvar) 264 | return mean + std * torch.randn_like(mean) 265 | else: 266 | return mean 267 | 268 | 269 | class TVAE(nn.Module): 270 | def __init__( 271 | self, resolution, in_channels, ch, out_ch, ch_mult, num_res_blocks, z_channels 272 | ): 273 | super().__init__() 274 | self.encoder = Encoder( 275 | resolution=resolution, 276 | in_channels=in_channels, 277 | ch=ch, 278 | ch_mult=ch_mult, 279 | num_res_blocks=num_res_blocks, 280 | z_channels=z_channels, 281 | ) 282 | self.decoder = Decoder( 
283 | resolution=resolution, 284 | in_channels=in_channels, 285 | ch=ch, 286 | out_ch=out_ch, 287 | ch_mult=ch_mult, 288 | num_res_blocks=num_res_blocks, 289 | z_channels=z_channels, 290 | ) 291 | self.reg = DiagonalGaussian() 292 | 293 | def forward(self, x: Tensor) -> Tensor: 294 | z = self.encoder(x) 295 | z_s = self.reg(z) 296 | decz = self.decoder(z_s) 297 | return decz, z 298 | 299 | 300 | if __name__ == "__main__": 301 | vae = TVAE( 302 | resolution=256, 303 | in_channels=3, 304 | ch=64, 305 | out_ch=3, 306 | ch_mult=[1, 2, 4, 4], 307 | num_res_blocks=2, 308 | z_channels=16, 309 | ) 310 | with torch.no_grad(): 311 | vae.eval().to("cpu") 312 | x = torch.randn(1, 3, 48, 256, 256).to("cpu") # [B, C, T, H, W] 313 | decz, z = vae(x) 314 | print(decz.shape, z.shape) 315 | -------------------------------------------------------------------------------- /tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.optim as optim 5 | import torchvision.models as models 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | 9 | from vae_trainer import ( 10 | VAE, 11 | PatchDiscriminator, 12 | create_dataloader, 13 | gan_loss, 14 | perceptual_loss, 15 | ) 16 | 17 | 18 | # Sample Image for Testing 19 | def create_sample_image(size=(128, 128), color=(255, 0, 0)): 20 | img = Image.new("RGB", size, color) 21 | transform = transforms.ToTensor() 22 | img_tensor = transform(img).unsqueeze(0) 23 | return img_tensor 24 | 25 | 26 | # Test the VAE architecture 27 | def test_vae(): 28 | print("Testing VAE Architecture...") 29 | # Instantiate the model 30 | vae = VAE(width_mult=1.0) 31 | 32 | # Create a sample image 33 | img = create_sample_image() 34 | 35 | # Forward pass through the model 36 | reconstructed, latent = vae(img) 37 | 38 | # Check output shapes 39 | print(f"Input shape: {img.shape}") 40 | print(f"Latent shape: {latent.shape}") 41 | print(f"Reconstructed shape: {reconstructed.shape}") 42 | 43 | 44 | # Test the Patch Discriminator 45 | def test_discriminator(): 46 | print("Testing Patch Discriminator...") 47 | # Instantiate the discriminator 48 | discriminator = PatchDiscriminator() 49 | 50 | # Create a sample image 51 | img = create_sample_image() 52 | 53 | # Forward pass through the discriminator 54 | output = discriminator(img) 55 | 56 | # Check output shape 57 | print(f"Input shape: {img.shape}") 58 | print(f"Discriminator output shape: {output.shape}") 59 | 60 | 61 | # Test GAN loss function 62 | def test_gan_loss(): 63 | print("Testing GAN Loss...") 64 | # Create sample real and fake predictions 65 | real_preds = torch.randn(1, 1, 8, 8).abs() 66 | fake_preds = torch.randn(1, 1, 8, 8).abs() 67 | 68 | # Calculate loss 69 | loss = gan_loss(real_preds, fake_preds) 70 | print(f"GAN Loss: {loss.item()}") 71 | 72 | 73 | # Test perceptual loss function 74 | def test_perceptual_loss(): 75 | print("Testing Perceptual Loss...") 76 | # Instantiate VGG model for perceptual loss 77 | vgg_model = models.vgg16(pretrained=True).features[:9] 78 | 79 | # Create two sample images (real and reconstructed) 80 | real_img = create_sample_image() 81 | reconstructed_img = create_sample_image(color=(0, 255, 0)) 82 | 83 | # Calculate perceptual loss 84 | loss = perceptual_loss(reconstructed_img, real_img, vgg_model) 85 | print(f"Perceptual Loss: {loss.item()}") 86 | 87 | 88 | # Test WebDataset DataLoader 89 | def test_webdataset_dataloader(): 90 | print("Testing WebDataset DataLoader...") 91 | 92 | # Dummy URL 
for WebDataset (replace with actual path or URL in practice) 93 | dataset_url = "path/to/your/webdataset/shards" 94 | 95 | # Create a dummy dataset with ToTensor transforms 96 | transform = transforms.Compose( 97 | [ 98 | transforms.Resize((128, 128)), 99 | transforms.ToTensor(), 100 | ] 101 | ) 102 | 103 | # Set up dummy dataloader 104 | dataloader = create_dataloader( 105 | dataset_url, batch_size=2, num_workers=4, world_size=1, rank=0 106 | ) 107 | 108 | # Check the output from the dataloader 109 | for imgs, labels in dataloader: 110 | print(f"Batch of images shape: {imgs.shape}") 111 | print(f"Batch of labels shape: {labels.shape}") 112 | break 113 | 114 | 115 | def test_train_loop(): 116 | print("Testing Train Loop...") 117 | 118 | vae = VAE(width_mult=1.0) 119 | discriminator = PatchDiscriminator() 120 | 121 | img = create_sample_image() 122 | 123 | optimizer_G = optim.Adam(vae.parameters(), lr=2e-4) 124 | optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4) 125 | 126 | vgg_model = models.vgg16(pretrained=True).features[:9] 127 | vgg_model.eval() 128 | 129 | reconstructed, latent = vae(img) 130 | 131 | print(f"Input image shape: {img.shape} (Expected: [1, 3, 128, 128])") 132 | print(f"Latent shape: {latent.shape} (Expected: [1, 256, 8, 8])") 133 | print( 134 | f"Reconstructed image shape: {reconstructed.shape} (Expected: [1, 3, 128, 128])" 135 | ) 136 | 137 | real_preds = discriminator(img) 138 | fake_preds = discriminator(reconstructed.detach()) 139 | 140 | print(f"Real predictions shape: {real_preds.shape} (Expected: [1, 1, 8, 8])") 141 | print(f"Fake predictions shape: {fake_preds.shape} (Expected: [1, 1, 8, 8])") 142 | 143 | gan_loss_value = gan_loss(real_preds, fake_preds) 144 | rec_loss_value = perceptual_loss(reconstructed, img, vgg_model) 145 | 146 | lambda_val = rec_loss_value / (gan_loss_value + 1e-8) 147 | g_loss = rec_loss_value + lambda_val * gan_loss_value 148 | g_loss.backward(retain_graph=True) 149 | 150 | d_loss = gan_loss(real_preds, fake_preds) 151 | d_loss.backward() 152 | 153 | optimizer_D.step() 154 | optimizer_G.step() 155 | 156 | optimizer_G.zero_grad() 157 | optimizer_D.zero_grad() 158 | 159 | print(f"Discriminator Loss: {d_loss.item()} (Expected: > 0)") 160 | print(f"Generator Loss: {g_loss.item()} (Expected: > 0)") 161 | print( 162 | f"Lambda value (weight): {lambda_val.item()} (Expected: ~1.0 depending on losses)" 163 | ) 164 | 165 | 166 | # Run the dummy train loop test 167 | if __name__ == "__main__": 168 | 169 | # download lpips from here 170 | # https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth 171 | os.system( 172 | "wget https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth" 173 | ) 174 | test_vae() 175 | test_discriminator() 176 | test_gan_loss() 177 | test_perceptual_loss() 178 | # test_webdataset_dataloader() 179 | 180 | # Dummy train loop for input-output 181 | test_train_loop() 182 | -------------------------------------------------------------------------------- /tester_upload.sh: -------------------------------------------------------------------------------- 1 | export HF_TRANSFER=True 2 | huggingface-cli upload fal/AuraEquiVAE ./vae_epoch_3_step_49501_bf16.pt -------------------------------------------------------------------------------- /unit_activation_reinitializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision 
import datasets, transforms 5 | import numpy as np 6 | import math 7 | 8 | 9 | def compute_activation_std( 10 | model, dataset, device="cpu", batch_size=32, num_workers=0, layer_names=None 11 | ): 12 | activations = {} 13 | handles = [] 14 | 15 | def save_activation(name): 16 | def hook(module, input, output): 17 | if isinstance(output, tuple): 18 | output = output[0] 19 | activations[name].append(output.detach()) 20 | 21 | return hook 22 | 23 | for name, module in model.named_modules(): 24 | if name in layer_names: 25 | activations[name] = [] 26 | handle = module.register_forward_hook(save_activation(name)) 27 | handles.append(handle) 28 | 29 | loader = torch.utils.data.DataLoader( 30 | dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 31 | ) 32 | model.to(device) 33 | model.eval() 34 | 35 | with torch.no_grad(): 36 | for batch in loader: 37 | if isinstance(batch, (list, tuple)): 38 | inputs = batch[0].to(device) 39 | else: 40 | inputs = batch.to(device) 41 | _ = model(inputs) 42 | break 43 | 44 | layer_activation_std = {} 45 | for name in layer_names: 46 | try: 47 | act = torch.cat(activations[name], dim=0) 48 | except: 49 | print(activations[name]) 50 | break 51 | act_std = act.std().item() 52 | layer_activation_std[name] = act_std 53 | 54 | for handle in handles: 55 | handle.remove() 56 | 57 | return layer_activation_std 58 | 59 | 60 | def adjust_weight_init( 61 | model, 62 | dataset, 63 | device="cpu", 64 | batch_size=32, 65 | num_workers=0, 66 | tol=0.2, 67 | max_iters=10, 68 | exclude_layers=None, 69 | ): 70 | if exclude_layers is None: 71 | exclude_layers = [] 72 | 73 | layers_to_adjust = [] 74 | for name, module in model.named_modules(): 75 | if isinstance(module, (nn.Linear, nn.Conv2d)) and not isinstance( 76 | module, tuple(exclude_layers) 77 | ): 78 | layers_to_adjust.append((name, module)) 79 | 80 | print(f"Layers to adjust: {layers_to_adjust}") 81 | initial_std = {} 82 | layer_weight_std = {} 83 | 84 | for name, module in layers_to_adjust: 85 | print(f"Adjusting layer: {name}") 86 | initial_std[name] = module.weight.std().item() 87 | fan_in = np.prod(module.weight.shape[1:]) 88 | weight_std = np.sqrt(1 / fan_in) # use muP for initialization. 89 | 90 | for i in range(max_iters): 91 | nn.init.normal_(module.weight, std=weight_std) 92 | 93 | activation_std = compute_activation_std( 94 | model, dataset, device, batch_size, num_workers, layer_names=[name] 95 | )[name] 96 | print(f"Iteration {i+1}: Activation std = {activation_std:.4f}") 97 | 98 | if abs(activation_std - 1.0) < tol: 99 | print( 100 | f"Layer {name} achieved near unit activation of {activation_std:.4f} with weight std = {weight_std:.4f}" 101 | ) 102 | layer_weight_std[name] = weight_std / activation_std 103 | break 104 | else: 105 | weight_std = weight_std / activation_std 106 | else: 107 | print(f"Layer {name} did not converge within {max_iters} iterations.") 108 | layer_weight_std[name] = weight_std 109 | 110 | return initial_std, layer_weight_std 111 | 112 | 113 | #### HOW TO USE 114 | # 1. define dataset 115 | # 2. define model 116 | # 3. launch. 
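# The demo below follows those three steps on MNIST: it builds a small ResBlock + MLP classifier, then adjust_weight_init rescales each Conv2d/Linear weight std (starting from a muP-style 1/sqrt(fan_in) init) until that layer's activations have roughly unit standard deviation.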
117 | 118 | transform = transforms.Compose( 119 | [ 120 | transforms.ToTensor(), 121 | ] 122 | ) 123 | 124 | train_dataset = datasets.MNIST( 125 | root="mnist_data", train=True, transform=transform, download=True 126 | ) 127 | 128 | 129 | class CustomActivation(nn.Module): 130 | def forward(self, x): 131 | return x * torch.sigmoid(x) 132 | 133 | 134 | class ResBlock(nn.Module): 135 | def __init__(self, in_channels, out_channels, reduction_ratio=2): 136 | super(ResBlock, self).__init__() 137 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 138 | self.bn1 = nn.BatchNorm2d(out_channels) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 141 | self.bn2 = nn.BatchNorm2d(out_channels) 142 | 143 | def forward(self, x): 144 | identity = x 145 | out = self.conv1(x) 146 | out = self.bn1(out) 147 | out = self.relu(out) 148 | out = self.conv2(out) 149 | out = self.bn2(out) 150 | out = out.mean(dim=[-1, -2]) 151 | return out 152 | 153 | 154 | class MLPModel(nn.Module): 155 | def __init__(self): 156 | super(MLPModel, self).__init__() 157 | self.block1 = ResBlock(1, 256) 158 | 159 | self.fc1 = nn.Linear(256, 256) 160 | self.act1 = CustomActivation() 161 | self.ln1 = nn.LayerNorm(256) 162 | 163 | self.fc2 = nn.Linear(256, 128) 164 | self.act2 = nn.ReLU() 165 | self.ln2 = nn.LayerNorm(128) 166 | 167 | self.fc3 = nn.Linear(128, 64) 168 | self.act3 = nn.Tanh() 169 | 170 | self.fc_residual = nn.Linear(256, 64) 171 | 172 | self.fc4 = nn.Linear(64, 10) 173 | 174 | def forward(self, x): 175 | out1 = self.block1(x) 176 | out1 = out1.view(out1.shape[0], -1) 177 | out1 = self.act1(out1) 178 | out1 = self.fc1(out1) 179 | out1 = self.ln1(out1) 180 | 181 | out2 = self.act2(self.fc2(out1)) 182 | out2 = self.ln2(out2) 183 | 184 | out3 = self.act3(self.fc3(out2)) 185 | 186 | res = self.fc_residual(out1) 187 | out3 += res 188 | 189 | logits = self.fc4(out3) 190 | return logits 191 | 192 | 193 | model = MLPModel() 194 | 195 | 196 | exclude_layers = [nn.LayerNorm] 197 | 198 | initial_std, layer_weight_std = adjust_weight_init( 199 | model, 200 | dataset=train_dataset, 201 | device="cuda:0", 202 | batch_size=64, 203 | num_workers=0, 204 | tol=0.1, 205 | max_iters=10, 206 | exclude_layers=exclude_layers, 207 | ) 208 | 209 | print("\nAdjusted Weight Standard Deviations. 
Before -> After:") 210 | for layer_name, std in layer_weight_std.items(): 211 | print( 212 | f"Layer {layer_name}, Changed STD from \n {initial_std[layer_name]:.4f} -> STD {std:.4f}\n" 213 | ) 214 | 215 | print(layer_weight_std) 216 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | 7 | 8 | class LPIPS(nn.Module): 9 | # Learned perceptual metric 10 | def __init__(self, use_dropout=True): 11 | super().__init__() 12 | self.scaling_layer = ScalingLayer() 13 | self.chns = [64, 128, 256, 512, 512] # vg16 features 14 | self.net = vgg16(pretrained=True, requires_grad=False) 15 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 16 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 17 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 18 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 19 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 20 | self.load_from_pretrained() 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | def load_from_pretrained(self, name="vgg_lpips"): 25 | try: 26 | data = torch.load("vgg.pth", map_location=torch.device("cpu")) 27 | except: 28 | print("Failed to load vgg.pth, downloading...") 29 | os.system( 30 | "wget https://heibox.uni-heidelberg.de/seafhttp/files/9535cbee-6558-4c0c-8743-78f5e56ea75e/vgg.pth" 31 | ) 32 | data = torch.load("vgg.pth", map_location=torch.device("cpu")) 33 | 34 | self.load_state_dict( 35 | data, 36 | strict=False, 37 | ) 38 | 39 | def forward(self, input, target): 40 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 41 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 42 | feats0, feats1, diffs = {}, {}, {} 43 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 44 | for kk in range(len(self.chns)): 45 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 46 | outs1[kk] 47 | ) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [ 51 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 52 | for kk in range(len(self.chns)) 53 | ] 54 | val = res[0] 55 | for l in range(1, len(self.chns)): 56 | val += res[l] 57 | return val 58 | 59 | 60 | class ScalingLayer(nn.Module): 61 | def __init__(self): 62 | super(ScalingLayer, self).__init__() 63 | self.register_buffer( 64 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 65 | ) 66 | self.register_buffer( 67 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 68 | ) 69 | 70 | def forward(self, inp): 71 | return (inp - self.shift) / self.scale 72 | 73 | 74 | class NetLinLayer(nn.Module): 75 | """A single linear layer which does a 1x1 conv""" 76 | 77 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 78 | super(NetLinLayer, self).__init__() 79 | layers = ( 80 | [ 81 | nn.Dropout(), 82 | ] 83 | if (use_dropout) 84 | else [] 85 | ) 86 | layers += [ 87 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 88 | ] 89 | self.model = nn.Sequential(*layers) 90 | 91 | 92 | class vgg16(torch.nn.Module): 93 | def __init__(self, requires_grad=False, pretrained=True): 94 | super(vgg16, self).__init__() 95 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 96 | self.slice1 = torch.nn.Sequential() 97 | self.slice2 = 
torch.nn.Sequential() 98 | self.slice3 = torch.nn.Sequential() 99 | self.slice4 = torch.nn.Sequential() 100 | self.slice5 = torch.nn.Sequential() 101 | self.N_slices = 5 102 | for x in range(4): 103 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 104 | for x in range(4, 9): 105 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 106 | for x in range(9, 16): 107 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(16, 23): 109 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(23, 30): 111 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 112 | if not requires_grad: 113 | for param in self.parameters(): 114 | param.requires_grad = False 115 | 116 | def forward(self, X): 117 | h = self.slice1(X) 118 | h_relu1_2 = h 119 | h = self.slice2(h) 120 | h_relu2_2 = h 121 | h = self.slice3(h) 122 | h_relu3_3 = h 123 | h = self.slice4(h) 124 | h_relu4_3 = h 125 | h = self.slice5(h) 126 | h_relu5_3 = h 127 | vgg_outputs = namedtuple( 128 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 129 | ) 130 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 131 | return out 132 | 133 | 134 | def normalize_tensor(x, eps=1e-10): 135 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 136 | return x / (norm_factor + eps) 137 | 138 | 139 | def spatial_average(x, keepdim=True): 140 | return x.mean([2, 3], keepdim=keepdim) 141 | 142 | 143 | class PatchDiscriminator(nn.Module): 144 | def __init__(self): 145 | super(PatchDiscriminator, self).__init__() 146 | self.scaling_layer = ScalingLayer() 147 | 148 | _vgg = models.vgg16(pretrained=True) 149 | 150 | self.slice1 = nn.Sequential(_vgg.features[:4]) 151 | self.slice2 = nn.Sequential(_vgg.features[4:9]) 152 | self.slice3 = nn.Sequential(_vgg.features[9:16]) 153 | self.slice4 = nn.Sequential(_vgg.features[16:23]) 154 | self.slice5 = nn.Sequential(_vgg.features[23:30]) 155 | 156 | self.binary_classifier1 = nn.Sequential( 157 | nn.Conv2d(64, 32, kernel_size=4, stride=4, padding=0, bias=True), 158 | nn.ReLU(), 159 | nn.Conv2d(32, 1, kernel_size=4, stride=4, padding=0, bias=True), 160 | ) 161 | nn.init.zeros_(self.binary_classifier1[-1].weight) 162 | 163 | self.binary_classifier2 = nn.Sequential( 164 | nn.Conv2d(128, 64, kernel_size=4, stride=4, padding=0, bias=True), 165 | nn.ReLU(), 166 | nn.Conv2d(64, 1, kernel_size=2, stride=2, padding=0, bias=True), 167 | ) 168 | nn.init.zeros_(self.binary_classifier2[-1].weight) 169 | 170 | self.binary_classifier3 = nn.Sequential( 171 | nn.Conv2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True), 172 | nn.ReLU(), 173 | nn.Conv2d(128, 1, kernel_size=2, stride=2, padding=0, bias=True), 174 | ) 175 | nn.init.zeros_(self.binary_classifier3[-1].weight) 176 | 177 | self.binary_classifier4 = nn.Sequential( 178 | nn.Conv2d(512, 1, kernel_size=2, stride=2, padding=0, bias=True), 179 | ) 180 | nn.init.zeros_(self.binary_classifier4[-1].weight) 181 | 182 | self.binary_classifier5 = nn.Sequential( 183 | nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=True), 184 | ) 185 | nn.init.zeros_(self.binary_classifier5[-1].weight) 186 | 187 | def forward(self, x): 188 | x = self.scaling_layer(x) 189 | features1 = self.slice1(x) 190 | features2 = self.slice2(features1) 191 | features3 = self.slice3(features2) 192 | features4 = self.slice4(features3) 193 | features5 = self.slice5(features4) 194 | 195 | # torch.Size([1, 64, 256, 256]) torch.Size([1, 128, 128, 128]) torch.Size([1, 256, 64, 64]) 
torch.Size([1, 512, 32, 32]) torch.Size([1, 512, 16, 16]) 196 | 197 | bc1 = self.binary_classifier1(features1).flatten(1) 198 | bc2 = self.binary_classifier2(features2).flatten(1) 199 | bc3 = self.binary_classifier3(features3).flatten(1) 200 | bc4 = self.binary_classifier4(features4).flatten(1) 201 | bc5 = self.binary_classifier5(features5).flatten(1) 202 | 203 | return bc1 + bc2 + bc3 + bc4 + bc5 204 | 205 | 206 | dec_lo, dec_hi = ( 207 | torch.Tensor([-0.1768, 0.3536, 1.0607, 0.3536, -0.1768, 0.0000]), 208 | torch.Tensor([0.0000, -0.0000, 0.3536, -0.7071, 0.3536, -0.0000]), 209 | ) 210 | 211 | filters = torch.stack( 212 | [ 213 | dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1), 214 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1), 215 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1), 216 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1), 217 | ], 218 | dim=0, 219 | ) 220 | 221 | filters_expanded = filters.unsqueeze(1) 222 | 223 | 224 | def prepare_filter(device): 225 | global filters_expanded 226 | filters_expanded = filters_expanded.to(device) 227 | 228 | 229 | def wavelet_transform_multi_channel(x, levels=4): 230 | B, C, H, W = x.shape 231 | padded = torch.nn.functional.pad(x, (2, 2, 2, 2)) 232 | 233 | # use predefined filters 234 | global filters_expanded 235 | 236 | ress = [] 237 | for ch in range(C): 238 | res = torch.nn.functional.conv2d( 239 | padded[:, ch : ch + 1], filters_expanded, stride=2 240 | ) 241 | ress.append(res) 242 | 243 | res = torch.cat(ress, dim=1) 244 | H_out, W_out = res.shape[2], res.shape[3] 245 | res = res.view(B, C, 4, H_out, W_out) 246 | res = res.view(B, 4 * C, H_out, W_out) 247 | return res 248 | 249 | 250 | def test_patch_discriminator(): 251 | vggDiscriminator = PatchDiscriminator().cuda() 252 | x = vggDiscriminator(torch.randn(1, 3, 256, 256).cuda()) 253 | print(x.shape) 254 | 255 | 256 | if __name__ == "__main__": 257 | vggDiscriminator = PatchDiscriminator().cuda() 258 | x = vggDiscriminator(torch.randn(1, 3, 256, 256).cuda()) 259 | print(x.shape) 260 | -------------------------------------------------------------------------------- /vae_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | 5 | import click 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torchvision.transforms as transforms 13 | import webdataset as wds 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from torchvision.transforms import GaussianBlur 16 | from transformers import get_cosine_schedule_with_warmup 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | 21 | import wandb 22 | from ae import VAE 23 | from utils import LPIPS, PatchDiscriminator, prepare_filter 24 | import time 25 | 26 | 27 | class GradNormFunction(torch.autograd.Function): 28 | @staticmethod 29 | def forward(ctx, x, weight): 30 | ctx.save_for_backward(weight) 31 | return x.clone() 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | weight = ctx.saved_tensors[0] 36 | 37 | # grad_output_norm = torch.linalg.vector_norm( 38 | # grad_output, dim=list(range(1, len(grad_output.shape))), keepdim=True 39 | # ).mean() 40 | grad_output_norm = torch.norm(grad_output).mean().item() 41 | # nccl over all nodes 42 | grad_output_norm = avg_scalar_over_nodes( 43 | grad_output_norm, device=grad_output.device 44 | ) 45 | 46 | 
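# Normalize the incoming gradient to roughly unit L2 norm (using the norm averaged across ranks),
# then rescale it by `weight`, so each loss branch that routes its gradient through gradnorm
# contributes a comparable gradient magnitude to the decoder.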
grad_output_normalized = weight * grad_output / (grad_output_norm + 1e-8) 47 | 48 | return grad_output_normalized, None 49 | 50 | 51 | def gradnorm(x, weight=1.0): 52 | weight = torch.tensor(weight, device=x.device) 53 | return GradNormFunction.apply(x, weight) 54 | 55 | 56 | @torch.no_grad() 57 | def avg_scalar_over_nodes(value: float, device): 58 | value = torch.tensor(value, device=device) 59 | dist.all_reduce(value, op=dist.ReduceOp.AVG) 60 | return value.item() 61 | 62 | 63 | def gan_disc_loss(real_preds, fake_preds, disc_type="bce"): 64 | if disc_type == "bce": 65 | real_loss = nn.functional.binary_cross_entropy_with_logits( 66 | real_preds, torch.ones_like(real_preds) 67 | ) 68 | fake_loss = nn.functional.binary_cross_entropy_with_logits( 69 | fake_preds, torch.zeros_like(fake_preds) 70 | ) 71 | # eval its online performance 72 | avg_real_preds = real_preds.mean().item() 73 | avg_fake_preds = fake_preds.mean().item() 74 | 75 | with torch.no_grad(): 76 | acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item() 77 | acc = acc / (real_preds.numel() + fake_preds.numel()) 78 | 79 | if disc_type == "hinge": 80 | real_loss = nn.functional.relu(1 - real_preds).mean() 81 | fake_loss = nn.functional.relu(1 + fake_preds).mean() 82 | 83 | with torch.no_grad(): 84 | acc = (real_preds > 0).sum().item() + (fake_preds < 0).sum().item() 85 | acc = acc / (real_preds.numel() + fake_preds.numel()) 86 | 87 | avg_real_preds = real_preds.mean().item() 88 | avg_fake_preds = fake_preds.mean().item() 89 | 90 | return (real_loss + fake_loss) * 0.5, avg_real_preds, avg_fake_preds, acc 91 | 92 | 93 | MAX_WIDTH = 512 94 | 95 | this_transform = transforms.Compose( 96 | [ 97 | transforms.ToTensor(), 98 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 99 | transforms.CenterCrop(512), 100 | transforms.Resize(MAX_WIDTH), 101 | ] 102 | ) 103 | 104 | 105 | def this_transform_random_crop_resize(x, width=MAX_WIDTH): 106 | 107 | x = transforms.ToTensor()(x) 108 | x = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(x) 109 | 110 | if random.random() < 0.5: 111 | x = transforms.RandomCrop(width)(x) 112 | else: 113 | x = transforms.Resize(width)(x) 114 | x = transforms.RandomCrop(width)(x) 115 | 116 | return x 117 | 118 | 119 | def create_dataloader(url, batch_size, num_workers, do_shuffle=True, just_resize=False): 120 | dataset = wds.WebDataset( 121 | url, nodesplitter=wds.split_by_node, workersplitter=wds.split_by_worker 122 | ) 123 | dataset = dataset.shuffle(1000) if do_shuffle else dataset 124 | 125 | dataset = ( 126 | dataset.decode("rgb") 127 | .to_tuple("jpg;png") 128 | .map_tuple( 129 | this_transform_random_crop_resize if not just_resize else this_transform 130 | ) 131 | ) 132 | 133 | loader = wds.WebLoader( 134 | dataset, 135 | batch_size=batch_size, 136 | shuffle=False, 137 | num_workers=num_workers, 138 | pin_memory=True, 139 | ) 140 | return loader 141 | 142 | 143 | def blurriness_heatmap(input_image): 144 | grayscale_image = input_image.mean(dim=1, keepdim=True) 145 | 146 | laplacian_kernel = torch.tensor( 147 | [ 148 | [0, 1, 1, 1, 0], 149 | [1, 1, 1, 1, 1], 150 | [1, 1, -20, 1, 1], 151 | [1, 1, 1, 1, 1], 152 | [0, 1, 1, 1, 0], 153 | ], 154 | dtype=torch.float32, 155 | ) 156 | laplacian_kernel = laplacian_kernel.view(1, 1, 5, 5) 157 | 158 | laplacian_kernel = laplacian_kernel.to(input_image.device) 159 | 160 | edge_response = F.conv2d(grayscale_image, laplacian_kernel, padding=2) 161 | 162 | edge_magnitude = GaussianBlur(kernel_size=(13, 13), sigma=(2.0, 2.0))( 163 
| edge_response.abs() 164 | ) 165 | 166 | edge_magnitude = (edge_magnitude - edge_magnitude.min()) / ( 167 | edge_magnitude.max() - edge_magnitude.min() + 1e-8 168 | ) 169 | 170 | blurriness_map = 1 - edge_magnitude 171 | 172 | blurriness_map = torch.where( 173 | blurriness_map < 0.8, torch.zeros_like(blurriness_map), blurriness_map 174 | ) 175 | 176 | return blurriness_map.repeat(1, 3, 1, 1) 177 | 178 | 179 | def vae_loss_function(x, x_reconstructed, z, do_pool=True, do_recon=False): 180 | # downsample images by factor of 8 181 | if do_recon: 182 | if do_pool: 183 | x_reconstructed_down = F.interpolate( 184 | x_reconstructed, scale_factor=1 / 16, mode="area" 185 | ) 186 | x_down = F.interpolate(x, scale_factor=1 / 16, mode="area") 187 | recon_loss = ((x_reconstructed_down - x_down)).abs().mean() 188 | else: 189 | x_reconstructed_down = x_reconstructed 190 | x_down = x 191 | 192 | recon_loss = ( 193 | ((x_reconstructed_down - x_down) * blurriness_heatmap(x_down)) 194 | .abs() 195 | .mean() 196 | ) 197 | recon_loss_item = recon_loss.item() 198 | else: 199 | recon_loss = 0 200 | recon_loss_item = 0 201 | 202 | elewise_mean_loss = z.pow(2) 203 | zloss = elewise_mean_loss.mean() 204 | 205 | with torch.no_grad(): 206 | actual_mean_loss = elewise_mean_loss.mean() 207 | actual_ks_loss = actual_mean_loss.mean() 208 | 209 | vae_loss = recon_loss * 0.0 + zloss * 0.1 210 | return vae_loss, { 211 | "recon_loss": recon_loss_item, 212 | "kl_loss": actual_ks_loss.item(), 213 | "average_of_abs_z": z.abs().mean().item(), 214 | "std_of_abs_z": z.abs().std().item(), 215 | "average_of_logvar": 0.0, 216 | "std_of_logvar": 0.0, 217 | } 218 | 219 | 220 | def cleanup(): 221 | dist.destroy_process_group() 222 | 223 | 224 | @click.command() 225 | @click.option( 226 | "--dataset_url", type=str, default="", help="URL for the training dataset" 227 | ) 228 | @click.option( 229 | "--test_dataset_url", type=str, default="", help="URL for the test dataset" 230 | ) 231 | @click.option("--num_epochs", type=int, default=2, help="Number of training epochs") 232 | @click.option("--batch_size", type=int, default=8, help="Batch size for training") 233 | @click.option("--do_ganloss", is_flag=True, help="Whether to use GAN loss") 234 | @click.option( 235 | "--learning_rate_vae", type=float, default=1e-5, help="Learning rate for VAE" 236 | ) 237 | @click.option( 238 | "--learning_rate_disc", 239 | type=float, 240 | default=2e-4, 241 | help="Learning rate for discriminator", 242 | ) 243 | @click.option("--vae_resolution", type=int, default=256, help="Resolution for VAE") 244 | @click.option("--vae_in_channels", type=int, default=3, help="Input channels for VAE") 245 | @click.option("--vae_ch", type=int, default=256, help="Base channel size for VAE") 246 | @click.option( 247 | "--vae_ch_mult", type=str, default="1,2,4,4", help="Channel multipliers for VAE" 248 | ) 249 | @click.option( 250 | "--vae_num_res_blocks", 251 | type=int, 252 | default=2, 253 | help="Number of residual blocks for VAE", 254 | ) 255 | @click.option( 256 | "--vae_z_channels", type=int, default=16, help="Number of latent channels for VAE" 257 | ) 258 | @click.option("--run_name", type=str, default="run", help="Name of the run for wandb") 259 | @click.option( 260 | "--max_steps", type=int, default=1000, help="Maximum number of steps to train for" 261 | ) 262 | @click.option( 263 | "--evaluate_every_n_steps", type=int, default=250, help="Evaluate every n steps" 264 | ) 265 | @click.option("--load_path", type=str, default=None, help="Path to load the model from") 
266 | @click.option("--do_clamp", is_flag=True, help="Whether to clamp the latent codes") 267 | @click.option( 268 | "--clamp_th", type=float, default=8.0, help="Clamp threshold for the latent codes" 269 | ) 270 | @click.option( 271 | "--max_spatial_dim", 272 | type=int, 273 | default=256, 274 | help="Maximum spatial dimension for overall training", 275 | ) 276 | @click.option( 277 | "--do_attn", type=bool, default=False, help="Whether to use attention in the VAE" 278 | ) 279 | @click.option( 280 | "--decoder_also_perform_hr", 281 | type=bool, 282 | default=False, 283 | help="Whether to perform HR decoding in the decoder", 284 | ) 285 | @click.option( 286 | "--project_name", 287 | type=str, 288 | default="vae_sweep_attn_lr_width", 289 | help="Project name for wandb", 290 | ) 291 | @click.option( 292 | "--crop_invariance", 293 | type=bool, 294 | default=False, 295 | help="Whether to perform crop invariance", 296 | ) 297 | @click.option( 298 | "--flip_invariance", 299 | type=bool, 300 | default=False, 301 | help="Whether to perform flip invariance", 302 | ) 303 | @click.option( 304 | "--do_compile", 305 | type=bool, 306 | default=False, 307 | help="Whether to compile the model", 308 | ) 309 | @click.option( 310 | "--use_wavelet", 311 | type=bool, 312 | default=False, 313 | help="Whether to use wavelet transform in the encoder", 314 | ) 315 | @click.option( 316 | "--augment_before_perceptual_loss", 317 | type=bool, 318 | default=False, 319 | help="Whether to augment the images before the perceptual loss", 320 | ) 321 | @click.option( 322 | "--downscale_factor", 323 | type=int, 324 | default=16, 325 | help="Downscale factor for the latent space", 326 | ) 327 | @click.option( 328 | "--use_lecam", 329 | type=bool, 330 | default=False, 331 | help="Whether to use Lecam", 332 | ) 333 | @click.option( 334 | "--disc_type", 335 | type=str, 336 | default="bce", 337 | help="Discriminator type", 338 | ) 339 | def train_ddp( 340 | dataset_url, 341 | test_dataset_url, 342 | num_epochs, 343 | batch_size, 344 | do_ganloss, 345 | learning_rate_vae, 346 | learning_rate_disc, 347 | vae_resolution, 348 | vae_in_channels, 349 | vae_ch, 350 | vae_ch_mult, 351 | vae_num_res_blocks, 352 | vae_z_channels, 353 | run_name, 354 | max_steps, 355 | evaluate_every_n_steps, 356 | load_path, 357 | do_clamp, 358 | clamp_th, 359 | max_spatial_dim, 360 | do_attn, 361 | decoder_also_perform_hr, 362 | project_name, 363 | crop_invariance, 364 | flip_invariance, 365 | do_compile, 366 | use_wavelet, 367 | augment_before_perceptual_loss, 368 | downscale_factor, 369 | use_lecam, 370 | disc_type, 371 | ): 372 | 373 | # fix random seed 374 | torch.manual_seed(42) 375 | torch.cuda.manual_seed(42) 376 | torch.cuda.manual_seed_all(42) 377 | np.random.seed(42) 378 | random.seed(42) 379 | 380 | start_train = 0 381 | end_train = 128 * 16 382 | 383 | start_test = end_train + 1 384 | end_test = start_test + 8 385 | 386 | dataset_url = f"/home/ubuntu/ultimate_pipe/flux_ipadapter_trainer/dataset/art_webdataset/{{{start_train:05d}..{end_train:05d}}}.tar" 387 | test_dataset_url = f"/home/ubuntu/ultimate_pipe/flux_ipadapter_trainer/dataset/art_webdataset/{{{start_test:05d}..{end_test:05d}}}.tar" 388 | 389 | assert torch.cuda.is_available(), "CUDA is required for DDP" 390 | 391 | dist.init_process_group(backend="nccl") 392 | ddp_rank = int(os.environ["RANK"]) 393 | ddp_local_rank = int(os.environ["LOCAL_RANK"]) 394 | world_size = int(os.environ["WORLD_SIZE"]) 395 | device = f"cuda:{ddp_local_rank}" 396 | torch.cuda.set_device(device) 397 | 
master_process = ddp_rank == 0 398 | print(f"using device: {device}") 399 | 400 | if master_process: 401 | wandb.init( 402 | project=project_name, 403 | entity="simo", 404 | name=run_name, 405 | config={ 406 | "learning_rate_vae": learning_rate_vae, 407 | "learning_rate_disc": learning_rate_disc, 408 | "vae_ch": vae_ch, 409 | "vae_resolution": vae_resolution, 410 | "vae_in_channels": vae_in_channels, 411 | "vae_ch_mult": vae_ch_mult, 412 | "vae_num_res_blocks": vae_num_res_blocks, 413 | "vae_z_channels": vae_z_channels, 414 | "batch_size": batch_size, 415 | "num_epochs": num_epochs, 416 | "do_ganloss": do_ganloss, 417 | "do_attn": do_attn, 418 | "use_wavelet": use_wavelet, 419 | }, 420 | ) 421 | 422 | vae = VAE( 423 | resolution=vae_resolution, 424 | in_channels=vae_in_channels, 425 | ch=vae_ch, 426 | out_ch=vae_in_channels, 427 | ch_mult=[int(x) for x in vae_ch_mult.split(",")], 428 | num_res_blocks=vae_num_res_blocks, 429 | z_channels=vae_z_channels, 430 | use_attn=do_attn, 431 | decoder_also_perform_hr=decoder_also_perform_hr, 432 | use_wavelet=use_wavelet, 433 | ).cuda() 434 | 435 | discriminator = PatchDiscriminator().cuda() 436 | discriminator.requires_grad_(True) 437 | 438 | vae = DDP(vae, device_ids=[ddp_rank]) 439 | 440 | prepare_filter(device) 441 | 442 | if do_compile: 443 | vae.module.encoder = torch.compile( 444 | vae.module.encoder, fullgraph=False, mode="max-autotune" 445 | ) 446 | vae.module.decoder = torch.compile( 447 | vae.module.decoder, fullgraph=False, mode="max-autotune" 448 | ) 449 | 450 | discriminator = DDP(discriminator, device_ids=[ddp_rank]) 451 | 452 | # context 453 | ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) 454 | 455 | optimizer_G = optim.AdamW( 456 | [ 457 | { 458 | "params": [p for n, p in vae.named_parameters() if "conv_in" not in n], 459 | "lr": learning_rate_vae / vae_ch, 460 | }, 461 | { 462 | "params": [p for n, p in vae.named_parameters() if "conv_in" in n], 463 | "lr": 1e-4, 464 | }, 465 | ], 466 | weight_decay=1e-3, 467 | betas=(0.9, 0.95), 468 | ) 469 | 470 | optimizer_D = optim.AdamW( 471 | discriminator.parameters(), 472 | lr=learning_rate_disc, 473 | weight_decay=1e-3, 474 | betas=(0.9, 0.95), 475 | ) 476 | 477 | lpips = LPIPS().cuda() 478 | 479 | dataloader = create_dataloader( 480 | dataset_url, batch_size, num_workers=4, do_shuffle=True 481 | ) 482 | test_dataloader = create_dataloader( 483 | test_dataset_url, batch_size, num_workers=4, do_shuffle=False, just_resize=True 484 | ) 485 | 486 | num_training_steps = max_steps 487 | num_warmup_steps = 200 488 | lr_scheduler = get_cosine_schedule_with_warmup( 489 | optimizer_G, num_warmup_steps, num_training_steps 490 | ) 491 | 492 | # Setup logger 493 | logger = logging.getLogger(__name__) 494 | logger.setLevel(logging.INFO) 495 | if master_process: 496 | handler = logging.StreamHandler() 497 | formatter = logging.Formatter( 498 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 499 | ) 500 | handler.setFormatter(formatter) 501 | logger.addHandler(handler) 502 | 503 | global_step = 0 504 | 505 | if load_path is not None: 506 | state_dict = torch.load(load_path, map_location="cpu") 507 | try: 508 | status = vae.load_state_dict(state_dict, strict=True) 509 | except Exception as e: 510 | print(e) 511 | state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} 512 | status = vae.load_state_dict(state_dict, strict=True) 513 | print(status) 514 | 515 | t0 = time.time() 516 | 517 | # lecam variable 518 | 519 | lecam_loss_weight = 0.1 520 | 
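# LeCam-style regularization state: EMA anchors of the average real/fake logits. When use_lecam is set,
# the discriminator is penalized for logits that drift far from the opposite anchor, keeping it bounded.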
lecam_anchor_real_logits = 0.0 521 | lecam_anchor_fake_logits = 0.0 522 | lecam_beta = 0.9 523 | lecam_loss_item = 0.0  # initialized so logging works even when use_lecam is off 524 | for epoch in range(num_epochs): 525 | for i, real_images_hr in enumerate(dataloader): 526 | time_taken_till_load = time.time() - t0 527 | 528 | t0 = time.time() 529 | # resize real image to 256 530 | real_images_hr = real_images_hr[0].to(device) 531 | real_images_for_enc = F.interpolate( 532 | real_images_hr, size=(256, 256), mode="area" 533 | ) 534 | if random.random() < 0.5: 535 | real_images_for_enc = torch.flip(real_images_for_enc, [-1]) 536 | real_images_hr = torch.flip(real_images_hr, [-1]) 537 | 538 | z = vae.module.encoder(real_images_for_enc) 539 | 540 | # z distribution 541 | with ctx: 542 | z_dist_value: torch.Tensor = z.detach().cpu().reshape(-1) 543 | 544 | def kurtosis(x): 545 | return ((x - x.mean()) ** 4).mean() / (x.std() ** 4) 546 | 547 | def skew(x): 548 | return ((x - x.mean()) ** 3).mean() / (x.std() ** 3) 549 | 550 | z_quantiles = { 551 | "0.0": z_dist_value.quantile(0.0), 552 | "0.2": z_dist_value.quantile(0.2), 553 | "0.4": z_dist_value.quantile(0.4), 554 | "0.6": z_dist_value.quantile(0.6), 555 | "0.8": z_dist_value.quantile(0.8), 556 | "1.0": z_dist_value.quantile(1.0), 557 | "kurtosis": kurtosis(z_dist_value), 558 | "skewness": skew(z_dist_value), 559 | } 560 | 561 | if do_clamp: 562 | z = z.clamp(-clamp_th, clamp_th) 563 | z_s = vae.module.reg(z) 564 | 565 | #### do aug 566 | 567 | if random.random() < 0.5 and flip_invariance: 568 | z_s = torch.flip(z_s, [-1]) 569 | z_s[:, -4:-2] = -z_s[:, -4:-2] 570 | real_images_hr = torch.flip(real_images_hr, [-1]) 571 | 572 | if random.random() < 0.5 and flip_invariance: 573 | z_s = torch.flip(z_s, [-2]) 574 | z_s[:, -2:] = -z_s[:, -2:] 575 | real_images_hr = torch.flip(real_images_hr, [-2]) 576 | 577 | if random.random() < 0.5 and crop_invariance: 578 | # crop image and latent.
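# Crop-consistency augmentation: take a random window of the latent and the matching pixel window
# (latent coordinates scaled by downscale_factor, x2 when the decoder outputs HR), so decoding a
# cropped latent is trained to match the equally cropped target image.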
579 | 580 | # new_z_h, new_z_w, offset_z_h, offset_z_w 581 | z_h, z_w = z.shape[-2:] 582 | new_z_h = random.randint(12, z_h - 1) 583 | new_z_w = random.randint(12, z_w - 1) 584 | offset_z_h = random.randint(0, z_h - new_z_h - 1) 585 | offset_z_w = random.randint(0, z_w - new_z_w - 1) 586 | 587 | new_h = ( 588 | new_z_h * downscale_factor * 2 589 | if decoder_also_perform_hr 590 | else new_z_h * downscale_factor 591 | ) 592 | new_w = ( 593 | new_z_w * downscale_factor * 2 594 | if decoder_also_perform_hr 595 | else new_z_w * downscale_factor 596 | ) 597 | offset_h = ( 598 | offset_z_h * downscale_factor * 2 599 | if decoder_also_perform_hr 600 | else offset_z_h * downscale_factor 601 | ) 602 | offset_w = ( 603 | offset_z_w * downscale_factor * 2 604 | if decoder_also_perform_hr 605 | else offset_z_w * downscale_factor 606 | ) 607 | 608 | real_images_hr = real_images_hr[ 609 | :, :, offset_h : offset_h + new_h, offset_w : offset_w + new_w 610 | ] 611 | z_s = z_s[ 612 | :, 613 | :, 614 | offset_z_h : offset_z_h + new_z_h, 615 | offset_z_w : offset_z_w + new_z_w, 616 | ] 617 | 618 | assert real_images_hr.shape[-2] == new_h 619 | assert real_images_hr.shape[-1] == new_w 620 | assert z_s.shape[-2] == new_z_h 621 | assert z_s.shape[-1] == new_z_w 622 | 623 | with ctx: 624 | reconstructed = vae.module.decoder(z_s) 625 | 626 | if global_step >= max_steps: 627 | break 628 | 629 | if do_ganloss: 630 | real_preds = discriminator(real_images_hr) 631 | fake_preds = discriminator(reconstructed.detach()) 632 | d_loss, avg_real_logits, avg_fake_logits, disc_acc = gan_disc_loss( 633 | real_preds, fake_preds, disc_type 634 | ) 635 | 636 | avg_real_logits = avg_scalar_over_nodes(avg_real_logits, device) 637 | avg_fake_logits = avg_scalar_over_nodes(avg_fake_logits, device) 638 | 639 | lecam_anchor_real_logits = ( 640 | lecam_beta * lecam_anchor_real_logits 641 | + (1 - lecam_beta) * avg_real_logits 642 | ) 643 | lecam_anchor_fake_logits = ( 644 | lecam_beta * lecam_anchor_fake_logits 645 | + (1 - lecam_beta) * avg_fake_logits 646 | ) 647 | total_d_loss = d_loss.mean() 648 | d_loss_item = total_d_loss.item() 649 | if use_lecam: 650 | # penalize the real logits to fake and fake logits to real. 651 | lecam_loss = (real_preds - lecam_anchor_fake_logits).pow( 652 | 2 653 | ).mean() + (fake_preds - lecam_anchor_real_logits).pow(2).mean() 654 | lecam_loss_item = lecam_loss.item() 655 | total_d_loss = total_d_loss + lecam_loss * lecam_loss_weight 656 | 657 | optimizer_D.zero_grad() 658 | total_d_loss.backward(retain_graph=True) 659 | optimizer_D.step() 660 | 661 | # unnormalize the images, and perceptual loss 662 | _recon_for_perceptual = gradnorm(reconstructed) 663 | 664 | if augment_before_perceptual_loss: 665 | real_images_hr_aug = real_images_hr.clone() 666 | if random.random() < 0.5: 667 | _recon_for_perceptual = torch.flip(_recon_for_perceptual, [-1]) 668 | real_images_hr_aug = torch.flip(real_images_hr_aug, [-1]) 669 | if random.random() < 0.5: 670 | _recon_for_perceptual = torch.flip(_recon_for_perceptual, [-2]) 671 | real_images_hr_aug = torch.flip(real_images_hr_aug, [-2]) 672 | 673 | else: 674 | real_images_hr_aug = real_images_hr 675 | 676 | percep_rec_loss = lpips(_recon_for_perceptual, real_images_hr_aug).mean() 677 | 678 | # mse, vae loss. 
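# The same reconstruction is routed through gradnorm once per loss branch with its own weight
# (1.0 for the perceptual loss above, 0.001 here, 1.0 for the GAN term below), so each term's
# gradient is normalized and re-weighted independently before reaching the decoder.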
679 | recon_for_mse = gradnorm(reconstructed, weight=0.001) 680 | vae_loss, loss_data = vae_loss_function(real_images_hr, recon_for_mse, z) 681 | # gan loss 682 | if do_ganloss and global_step >= 0: 683 | recon_for_gan = gradnorm(reconstructed, weight=1.0) 684 | fake_preds = discriminator(recon_for_gan) 685 | real_preds_const = real_preds.clone().detach() 686 | # loss where (real > fake + 0.01) 687 | # g_gan_loss = (real_preds_const - fake_preds - 0.1).relu().mean() 688 | if disc_type == "bce": 689 | g_gan_loss = nn.functional.binary_cross_entropy_with_logits( 690 | fake_preds, torch.ones_like(fake_preds) 691 | ) 692 | elif disc_type == "hinge": 693 | g_gan_loss = -fake_preds.mean() 694 | 695 | overall_vae_loss = percep_rec_loss + g_gan_loss + vae_loss 696 | g_gan_loss = g_gan_loss.item() 697 | else: 698 | overall_vae_loss = percep_rec_loss + vae_loss 699 | g_gan_loss = 0.0 700 | 701 | overall_vae_loss.backward() 702 | optimizer_G.step() 703 | optimizer_G.zero_grad() 704 | lr_scheduler.step() 705 | 706 | if do_ganloss: 707 | 708 | optimizer_D.zero_grad() 709 | 710 | time_taken_till_step = time.time() - t0 711 | 712 | if master_process: 713 | if global_step % 5 == 0: 714 | wandb.log( 715 | { 716 | "epoch": epoch, 717 | "batch": i, 718 | "overall_vae_loss": overall_vae_loss.item(), 719 | "mse_loss": loss_data["recon_loss"], 720 | "kl_loss": loss_data["kl_loss"], 721 | "perceptual_loss": percep_rec_loss.item(), 722 | "gan/generator_gan_loss": ( 723 | g_gan_loss if do_ganloss else None 724 | ), 725 | "z_quantiles/abs_z": loss_data["average_of_abs_z"], 726 | "z_quantiles/std_z": loss_data["std_of_abs_z"], 727 | "z_quantiles/logvar": loss_data["average_of_logvar"], 728 | "gan/avg_real_logits": ( 729 | avg_real_logits if do_ganloss else None 730 | ), 731 | "gan/avg_fake_logits": ( 732 | avg_fake_logits if do_ganloss else None 733 | ), 734 | "gan/discriminator_loss": ( 735 | d_loss_item if do_ganloss else None 736 | ), 737 | "gan/discriminator_accuracy": ( 738 | disc_acc if do_ganloss else None 739 | ), 740 | "gan/lecam_loss": lecam_loss_item if do_ganloss else None, 741 | "gan/lecam_anchor_real_logits": ( 742 | lecam_anchor_real_logits if do_ganloss else None 743 | ), 744 | "gan/lecam_anchor_fake_logits": ( 745 | lecam_anchor_fake_logits if do_ganloss else None 746 | ), 747 | "z_quantiles/qs": z_quantiles, 748 | "time_taken_till_step": time_taken_till_step, 749 | "time_taken_till_load": time_taken_till_load, 750 | } 751 | ) 752 | 753 | if global_step % 200 == 0: 754 | 755 | wandb.log( 756 | { 757 | f"loss_stepwise/mse_loss_{global_step}": loss_data[ 758 | "recon_loss" 759 | ], 760 | f"loss_stepwise/kl_loss_{global_step}": loss_data[ 761 | "kl_loss" 762 | ], 763 | f"loss_stepwise/overall_vae_loss_{global_step}": overall_vae_loss.item(), 764 | } 765 | ) 766 | 767 | log_message = f"Epoch [{epoch}/{num_epochs}] - " 768 | log_items = [ 769 | ("perceptual_loss", percep_rec_loss.item()), 770 | ("mse_loss", loss_data["recon_loss"]), 771 | ("kl_loss", loss_data["kl_loss"]), 772 | ("overall_vae_loss", overall_vae_loss.item()), 773 | ("ABS mu (0.0): average_of_abs_z", loss_data["average_of_abs_z"]), 774 | ("STD mu : std_of_abs_z", loss_data["std_of_abs_z"]), 775 | ( 776 | "ABS logvar (0.0) : average_of_logvar", 777 | loss_data["average_of_logvar"], 778 | ), 779 | ("STD logvar : std_of_logvar", loss_data["std_of_logvar"]), 780 | *[(f"z_quantiles/{q}", v) for q, v in z_quantiles.items()], 781 | ("time_taken_till_step", time_taken_till_step), 782 | ("time_taken_till_load", time_taken_till_load), 783 | ] 784 | 
785 | if do_ganloss: 786 | log_items = [ 787 | ("d_loss", d_loss_item), 788 | ("gan_loss", g_gan_loss), 789 | ("avg_real_logits", avg_real_logits), 790 | ("avg_fake_logits", avg_fake_logits), 791 | ("discriminator_accuracy", disc_acc), 792 | ("lecam_loss", lecam_loss_item), 793 | ("lecam_anchor_real_logits", lecam_anchor_real_logits), 794 | ("lecam_anchor_fake_logits", lecam_anchor_fake_logits), 795 | ] + log_items 796 | 797 | log_message += "\n\t".join( 798 | [f"{key}: {value:.4f}" for key, value in log_items] 799 | ) 800 | logger.info(log_message) 801 | 802 | global_step += 1 803 | t0 = time.time() 804 | 805 | if ( 806 | evaluate_every_n_steps > 0 807 | and global_step % evaluate_every_n_steps == 1 808 | and master_process 809 | ): 810 | 811 | with torch.no_grad(): 812 | all_test_images = [] 813 | all_reconstructed_test = [] 814 | 815 | for test_images in test_dataloader: 816 | test_images_ori = test_images[0].to(device) 817 | # resize to 256 818 | test_images = F.interpolate( 819 | test_images_ori, size=(256, 256), mode="area" 820 | ) 821 | with ctx: 822 | z = vae.module.encoder(test_images) 823 | 824 | if do_clamp: 825 | z = z.clamp(-clamp_th, clamp_th) 826 | 827 | z_s = vae.module.reg(z) 828 | 829 | # [1, 2] 830 | # [3, 4] 831 | # -> 832 | # [3, 4] 833 | # [1, 2] 834 | # -> 835 | # [4, 3] 836 | # [2, 1] 837 | if flip_invariance: 838 | z_s = torch.flip(z_s, [-1, -2]) 839 | z_s[:, -4:] = -z_s[:, -4:] 840 | 841 | with ctx: 842 | reconstructed_test = vae.module.decoder(z_s) 843 | 844 | # unnormalize the images 845 | test_images_ori = test_images_ori * 0.5 + 0.5 846 | reconstructed_test = reconstructed_test * 0.5 + 0.5 847 | # clamp 848 | test_images_ori = test_images_ori.clamp(0, 1) 849 | reconstructed_test = reconstructed_test.clamp(0, 1) 850 | 851 | # flip twice 852 | if flip_invariance: 853 | reconstructed_test = torch.flip( 854 | reconstructed_test, [-1, -2] 855 | ) 856 | 857 | all_test_images.append(test_images_ori) 858 | all_reconstructed_test.append(reconstructed_test) 859 | 860 | if len(all_test_images) >= 2: 861 | break 862 | 863 | test_images = torch.cat(all_test_images, dim=0) 864 | reconstructed_test = torch.cat(all_reconstructed_test, dim=0) 865 | 866 | logger.info(f"Epoch [{epoch}/{num_epochs}] - Logging test images") 867 | 868 | # crop test and recon to 64 x 64 869 | D = 512 if decoder_also_perform_hr else 256 870 | offset = 0 871 | test_images = test_images[ 872 | :, :, offset : offset + D, offset : offset + D 873 | ].cpu() 874 | reconstructed_test = reconstructed_test[ 875 | :, :, offset : offset + D, offset : offset + D 876 | ].cpu() 877 | 878 | # concat the images into one large image. 
879 | # tile the first 8 test / reconstruction pairs into a 2 x 4 grid of size (D * 2) x (D * 4) 880 | recon_all_image = torch.zeros((3, D * 2, D * 4)) 881 | test_all_image = torch.zeros((3, D * 2, D * 4)) 882 | 883 | for i in range(2): 884 | for j in range(4): 885 | recon_all_image[ 886 | :, i * D : (i + 1) * D, j * D : (j + 1) * D 887 | ] = reconstructed_test[i * 4 + j] 888 | test_all_image[ 889 | :, i * D : (i + 1) * D, j * D : (j + 1) * D 890 | ] = test_images[i * 4 + j] 891 | 892 | wandb.log( 893 | { 894 | "reconstructed_test_images": [ 895 | wandb.Image(recon_all_image), 896 | ], 897 | "test_images": [ 898 | wandb.Image(test_all_image), 899 | ], 900 | } 901 | ) 902 | 903 | os.makedirs(f"./ckpt/{run_name}", exist_ok=True) 904 | torch.save( 905 | vae.state_dict(), 906 | f"./ckpt/{run_name}/vae_epoch_{epoch}_step_{global_step}.pt", 907 | ) 908 | print( 909 | f"Saved checkpoint to ./ckpt/{run_name}/vae_epoch_{epoch}_step_{global_step}.pt" 910 | ) 911 | 912 | cleanup() 913 | 914 | 915 | if __name__ == "__main__": 916 | 917 | # Example: torchrun --nproc_per_node=8 vae_trainer.py 918 | train_ddp() 919 | --------------------------------------------------------------------------------