├── models ├── __init__.py ├── stylegan2_discriminator.py ├── cnn_discriminator.py ├── cnn_generator.py ├── multi_head.py ├── utils.py ├── diffaugment.py ├── vitgan_discriminator.py └── vitgan_generator.py ├── LICENSE ├── README.md ├── main.py └── ViTGAN_pytorch.ipynb /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn_discriminator import CNN 2 | from .cnn_generator import CNNGenerator 3 | from .stylegan2_discriminator import StyleGanDiscriminator 4 | from .vitgan_discriminator import ViT 5 | from .vitgan_generator import GeneratorViT 6 | from .diffaugment import DiffAugment 7 | from .utils import spectral_norm 8 | 9 | __all__ = [CNN, CNNGenerator, StyleGanDiscriminator, ViT, GeneratorViT, DiffAugment, spectral_norm] 10 | -------------------------------------------------------------------------------- /models/stylegan2_discriminator.py: -------------------------------------------------------------------------------- 1 | from stylegan2_pytorch import stylegan2_pytorch 2 | from .diffaugment import DiffAugment 3 | 4 | # StyleGAN2 Discriminator 5 | 6 | class StyleGanDiscriminator(stylegan2_pytorch.Discriminator): 7 | def __init__(self, 8 | diffaugment='color,translation,cutout', 9 | **kwargs): 10 | self.diffaugment = diffaugment 11 | super().__init__(**kwargs) 12 | def forward(self, img, do_augment=True): 13 | if do_augment: 14 | img = DiffAugment(img, policy=self.diffaugment) 15 | out, _ = super().forward(img) 16 | return out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Teodor Toshkov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/cnn_discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .diffaugment import DiffAugment 3 | 4 | # CNN Discriminator 5 | 6 | class CNN(nn.Sequential): 7 | def __init__(self, 8 | diffaugment='color,translation,cutout', 9 | **kwargs): 10 | self.diffaugment = diffaugment 11 | super().__init__( 12 | nn.Conv2d(3,32,kernel_size=3,padding=1), 13 | nn.ReLU(), 14 | nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1), 15 | nn.ReLU(), 16 | nn.MaxPool2d(2,2), 17 | 18 | nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1), 19 | nn.ReLU(), 20 | nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1), 21 | nn.ReLU(), 22 | nn.MaxPool2d(2,2), 23 | 24 | nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1), 25 | nn.ReLU(), 26 | nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1), 27 | nn.ReLU(), 28 | nn.MaxPool2d(2,2), 29 | 30 | nn.Flatten(), 31 | nn.Linear(256*4*4,1024), 32 | nn.ReLU(), 33 | nn.Linear(1024,512), 34 | nn.ReLU(), 35 | nn.Linear(512,1) 36 | ) 37 | 38 | def forward(self, img, do_augment=True): 39 | if do_augment: 40 | img = DiffAugment(img, policy=self.diffaugment) 41 | return super().forward(img) -------------------------------------------------------------------------------- /models/cnn_generator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | # CNN Generator 4 | 5 | class CNNGenerator(nn.Module): 6 | def __init__(self, hidden_size, latent_dim): 7 | super(CNNGenerator, self).__init__() 8 | self.hidden_size = hidden_size 9 | self.w = nn.Linear(latent_dim, hidden_size * 2 * 4 * 4, bias=False) 10 | self.main = nn.Sequential( 11 | # input is Z, going into a convolution 12 | nn.BatchNorm2d(hidden_size * 2), 13 | nn.ReLU(True), 14 | # state size. (ngf*8) x 4 x 4 15 | nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False), 16 | nn.BatchNorm2d(hidden_size), 17 | nn.ReLU(True), 18 | # state size. (ngf*4) x 8 x 8 19 | nn.ConvTranspose2d( hidden_size, hidden_size // 2, 4, 2, 1, bias=False), 20 | nn.BatchNorm2d(hidden_size // 2), 21 | nn.ReLU(True), 22 | # state size. (ngf*2) x 16 x 16 23 | nn.ConvTranspose2d( hidden_size // 2, hidden_size // 4, 4, 2, 1, bias=False), 24 | nn.BatchNorm2d(hidden_size // 4), 25 | nn.ReLU(True), 26 | # state size. (ngf*2) x 32 x 32 27 | nn.ConvTranspose2d( hidden_size // 4, 3, 3, 1, 1, bias=False), 28 | nn.Tanh(), 29 | # state size. 
(nc) x 64 x 64 30 | ) 31 | 32 | def forward(self, input): 33 | input = self.w(input).view((-1, self.hidden_size * 2, 4, 4)) 34 | return self.main(input) 35 | -------------------------------------------------------------------------------- /models/multi_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | import torch.nn.functional as F 5 | 6 | from .utils import spectral_norm 7 | 8 | class MultiHeadAttention(nn.Module): 9 | def __init__(self, emb_size=384, num_heads=4, dropout=0, discriminator=False, **kwargs): 10 | super().__init__() 11 | self.emb_size = emb_size 12 | self.num_heads = num_heads 13 | self.discriminator = discriminator 14 | # fuse the queries, keys and values in one matrix 15 | self.qkv = nn.Linear(emb_size, emb_size * 3) 16 | self.att_drop = nn.Dropout(dropout) 17 | self.projection = nn.Linear(emb_size, emb_size) 18 | if self.discriminator: 19 | self.qkv = spectral_norm(self.qkv) 20 | self.projection = spectral_norm(self.projection) 21 | 22 | def forward(self, x, mask=None): 23 | # split keys, queries and values in num_heads 24 | qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3) 25 | queries, keys, values = qkv[0], qkv[1], qkv[2] 26 | if self.discriminator: 27 | # calculate L2-distances 28 | energy = torch.cdist(queries.contiguous(), keys.contiguous(), p=2) 29 | else: 30 | # sum up over the last axis 31 | energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len 32 | 33 | if mask is not None: 34 | fill_value = torch.finfo(torch.float32).min 35 | energy.mask_fill(~mask, fill_value) 36 | 37 | scaling = self.emb_size ** (1/2) 38 | att = F.softmax(energy, dim=-1) / scaling 39 | att = self.att_drop(att) 40 | # sum up over the third axis 41 | out = torch.einsum('bhal, bhlv -> bhav ', att, values) 42 | out = rearrange(out, "b h n d -> b n (h d)") 43 | out = self.projection(out) 44 | return out -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import numpy as np 5 | 6 | # Improved Spectral Normalization (ISN) 7 | # Reference code: https://github.com/koshian2/SNGAN 8 | # When updating the weights, normalize the weights' norm to its norm at initialization. 
9 | 10 | def l2normalize(v, eps=1e-4): 11 | return v / (v.norm() + eps) 12 | 13 | class spectral_norm(nn.Module): 14 | def __init__(self, module, name='weight', power_iterations=1): 15 | super().__init__() 16 | self.module = module 17 | self.name = name 18 | self.power_iterations = power_iterations 19 | if not self._made_params(): 20 | self._make_params() 21 | self.w_init_sigma = None 22 | self.w_initalized = False 23 | 24 | def _update_u_v(self): 25 | u = getattr(self.module, self.name + "_u") 26 | v = getattr(self.module, self.name + "_v") 27 | w = getattr(self.module, self.name + "_bar") 28 | 29 | height = w.data.shape[0] 30 | _w = w.view(height, -1) 31 | for _ in range(self.power_iterations): 32 | v = l2normalize(torch.matmul(_w.t(), u)) 33 | u = l2normalize(torch.matmul(_w, v)) 34 | 35 | sigma = u.dot((_w).mv(v)) 36 | if not self.w_initalized: 37 | self.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu()) 38 | self.w_initalized = True 39 | setattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(sigma.device) * w / sigma.expand_as(w)) 40 | 41 | def _made_params(self): 42 | try: 43 | getattr(self.module, self.name + "_u") 44 | getattr(self.module, self.name + "_v") 45 | getattr(self.module, self.name + "_bar") 46 | return True 47 | except AttributeError: 48 | return False 49 | 50 | def _make_params(self): 51 | w = getattr(self.module, self.name) 52 | 53 | height = w.data.shape[0] 54 | width = w.view(height, -1).data.shape[1] 55 | 56 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 57 | v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 58 | u.data = l2normalize(u.data) 59 | v.data = l2normalize(v.data) 60 | w_bar = Parameter(w.data) 61 | 62 | del self.module._parameters[self.name] 63 | self.module.register_parameter(self.name + "_u", u) 64 | self.module.register_parameter(self.name + "_v", v) 65 | self.module.register_parameter(self.name + "_bar", w_bar) 66 | 67 | def forward(self, *args): 68 | self._update_u_v() 69 | return self.module.forward(*args) -------------------------------------------------------------------------------- /models/diffaugment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py 5 | 6 | def DiffAugment(x, policy='', channels_first=True): 7 | if policy: 8 | if not channels_first: 9 | x = x.permute(0, 3, 1, 2) 10 | for p in policy.split(','): 11 | for f in AUGMENT_FNS[p]: 12 | x = f(x) 13 | if not channels_first: 14 | x = x.permute(0, 2, 3, 1) 15 | x = x.contiguous() 16 | return x 17 | 18 | 19 | def rand_brightness(x): 20 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 21 | return x 22 | 23 | 24 | def rand_saturation(x): 25 | x_mean = x.mean(dim=1, keepdim=True) 26 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 27 | return x 28 | 29 | 30 | def rand_contrast(x): 31 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 32 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 33 | return x 34 | 35 | 36 | def rand_translation(x, ratio=0.1): 37 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 38 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 39 | translation_y = torch.randint(-shift_y, shift_y + 1, 
size=[x.size(0), 1, 1], device=x.device) 40 | grid_batch, grid_x, grid_y = torch.meshgrid( 41 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 42 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 43 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 44 | ) 45 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 46 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 47 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 48 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 49 | return x 50 | 51 | 52 | def rand_cutout(x, ratio=0.3): 53 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 54 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 55 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 56 | grid_batch, grid_x, grid_y = torch.meshgrid( 57 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 58 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 59 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 60 | ) 61 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 62 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 63 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 64 | mask[grid_batch, grid_x, grid_y] = 0 65 | x = x * mask.unsqueeze(1) 66 | return x 67 | 68 | 69 | AUGMENT_FNS = { 70 | 'color': [rand_brightness, rand_saturation, rand_contrast], 71 | 'translation': [rand_translation], 72 | 'cutout': [rand_cutout], 73 | } -------------------------------------------------------------------------------- /models/vitgan_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import repeat 4 | from einops.layers.torch import Rearrange 5 | 6 | from .diffaugment import DiffAugment 7 | from .multi_head import MultiHeadAttention 8 | from .utils import spectral_norm 9 | 10 | # Discriminator 11 | 12 | class PatchEmbedding(nn.Module): 13 | def __init__(self, in_channels=3, patch_size=4, stride_size=4, emb_size=384, image_size=32, batch_size=64): 14 | super().__init__() 15 | self.patch_size = patch_size 16 | self.projection = nn.Sequential( 17 | # using a conv layer instead of a linear one -> performance gains 18 | spectral_norm(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)), 19 | Rearrange('b e (h) (w) -> b (h w) e'), 20 | ) 21 | self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size)) 22 | num_patches = ((image_size-patch_size+stride_size) // stride_size) **2 + 1 23 | self.positions = nn.Parameter(torch.randn(num_patches, emb_size)) 24 | self.batch_size = batch_size 25 | 26 | def forward(self, x): 27 | b, _, _, _ = x.shape 28 | x = self.projection(x) 29 | cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b) 30 | # prepend the cls token to the input 31 | x = torch.cat([cls_tokens, x], dim=1) 32 | # add position embedding 33 | x += torch.sin(self.positions) 34 | return x 35 | 36 | class ResidualAdd(nn.Module): 37 | def __init__(self, fn): 38 | super().__init__() 39 | self.fn = fn 40 | 41 | def forward(self, x, **kwargs): 42 | res = x 43 | x = self.fn(x, **kwargs) 44 | x += res 45 | return x 46 | 47 | class DiscriminatorTransformerEncoderBlock(nn.Sequential): 48 | 
def __init__(self, 49 | emb_size=384, 50 | drop_p=0., 51 | forward_expansion=4, 52 | forward_drop_p=0., 53 | **kwargs): 54 | super().__init__( 55 | ResidualAdd(nn.Sequential( 56 | nn.LayerNorm(emb_size), 57 | MultiHeadAttention(emb_size, **kwargs), 58 | nn.Dropout(drop_p) 59 | )), 60 | ResidualAdd(nn.Sequential( 61 | nn.LayerNorm(emb_size), 62 | nn.Sequential( 63 | spectral_norm(nn.Linear(emb_size, forward_expansion * emb_size)), 64 | nn.GELU(), 65 | nn.Dropout(forward_drop_p), 66 | spectral_norm(nn.Linear(forward_expansion * emb_size, emb_size)), 67 | ), 68 | nn.Dropout(drop_p) 69 | ) 70 | )) 71 | 72 | class DiscriminatorTransformerEncoder(nn.Sequential): 73 | def __init__(self, depth=4, **kwargs): 74 | super().__init__(*[DiscriminatorTransformerEncoderBlock(**kwargs) for _ in range(depth)]) 75 | 76 | class ClassificationHead(nn.Sequential): 77 | def __init__(self, emb_size=384, class_size_1=4098, class_size_2=1024, class_size_3=512, n_classes=10): 78 | super().__init__( 79 | nn.LayerNorm(emb_size), 80 | spectral_norm(nn.Linear(emb_size, class_size_1)), 81 | nn.GELU(), 82 | spectral_norm(nn.Linear(class_size_1, class_size_2)), 83 | nn.GELU(), 84 | spectral_norm(nn.Linear(class_size_2, class_size_3)), 85 | nn.GELU(), 86 | spectral_norm(nn.Linear(class_size_3, n_classes)), 87 | nn.GELU(), 88 | ) 89 | 90 | def forward(self, x): 91 | # Take only the cls token outputs 92 | x = x[:, 0, :] 93 | return super().forward(x) 94 | 95 | class ViT(nn.Sequential): 96 | def __init__(self, 97 | in_channels=3, 98 | patch_size=4, 99 | stride_size=4, 100 | emb_size=384, 101 | image_size=32, 102 | depth=4, 103 | n_classes=1, 104 | diffaugment='color,translation,cutout', 105 | **kwargs): 106 | self.diffaugment = diffaugment 107 | super().__init__( 108 | PatchEmbedding(in_channels, patch_size, stride_size, emb_size, image_size), 109 | DiscriminatorTransformerEncoder(depth, emb_size=emb_size, **kwargs), 110 | ClassificationHead(emb_size, n_classes=n_classes) 111 | ) 112 | 113 | def forward(self, img, do_augment=True): 114 | if do_augment: 115 | img = DiffAugment(img, policy=self.diffaugment) 116 | return super().forward(img) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViTGAN-pytorch 2 | A PyTorch implementation of [VITGAN: Training GANs with Vision Transformers](https://arxiv.org/pdf/2107.04589v1.pdf) 3 | 4 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kJJw6BYW01HgooCZ2zUDt54e1mXqITXH?usp=sharing) 5 | 6 | ## TODO: 7 | 1. [x] Use vectorized L2 distance in attention for **Discriminator** 8 | 2. [x] Overlapping Image Patches 9 | 2. [x] DiffAugment 10 | 3. [x] Self-modulated LayerNorm (SLN) 11 | 4. [x] Implicit Neural Representation for Patch Generation 12 | 5. [x] ExponentialMovingAverage (EMA) 13 | 6. [x] Balanced Consistency Regularization (bCR) 14 | 7. [x] Improved Spectral Normalization (ISN) 15 | 8. [x] Equalized Learning Rate 16 | 9. 
[x] Weight Modulation 17 | 18 | ## Dependencies 19 | 20 | - Python3 21 | - einops 22 | - pytorch_ema 23 | - stylegan2-pytorch 24 | - tensorboard 25 | - wandb 26 | 27 | ``` bash 28 | pip install einops git+https://github.com/fadel/pytorch_ema stylegan2-pytorch tensorboard wandb 29 | ``` 30 | 31 | ## **TLDR:** 32 | 33 | Train the model with the proposed parameters: 34 | 35 | ``` bash 36 | python main.py 37 | ``` 38 | 39 | Launch TensorBoard: 40 | 41 | ``` bash 42 | tensorboard --logdir runs/ 43 | ``` 44 | 45 | *** 46 | 47 | The following command uses the parameters proposed in the paper for the CIFAR-10 dataset (they are the defaults): 48 | 49 | ``` bash 50 | python main.py 51 | ``` 52 | 53 | ## Implementation Details 54 | 55 | ### Generator 56 | 57 | The Generator has the following architecture: 58 | 59 | ![ViTGAN Generator architecture](https://drive.google.com/uc?export=view&id=1XaCVOLq8Bvg-I3qM-bugNZcjIW5L7XTO) 60 | 61 | For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model. 62 | 63 | Given a seed, the dimensionality of which is controlled by ```latent_dim```, the ViT model creates an embedding for each patch of the final image. Those embeddings, combined with a Fourier Position Encoding \([Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\), are fed to a SIREN network, which outputs the patches of the image; the patches are then stitched together. 64 | 65 | The ViT part of the Generator differs from a standard Vision Transformer in the following ways: 66 | - The input to the Transformer consists only of the position embeddings 67 | - Self-Modulated Layer Norm (SLN) is used in place of LayerNorm 68 | - There is no classification head 69 | 70 | SLN is the only place where the seed is fed into the network.<br>
71 | SLN applies a regular LayerNorm and then scales the result by ```gamma``` and shifts it by ```beta```.<br>
72 | Both ```gamma``` and ```beta``` are calculated using a fully connected layer, a different one for each place where SLN is applied; a minimal sketch is shown below.<br>
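A simplified, hypothetical version of the idea in PyTorch (the repository's actual implementation is the ```SLN``` class in ```models/vitgan_generator.py```, which additionally supports a scalar ```parameter_size``` of 1):

``` python
import torch
from torch import nn

class SelfModulatedLayerNorm(nn.Module):
    # Sketch of SLN: a LayerNorm whose scale (gamma) and shift (beta)
    # are predicted from the seed embedding w.
    def __init__(self, hidden_size):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size)
        # One pair of fully connected layers per place where SLN is applied.
        self.gamma = nn.Linear(hidden_size, hidden_size, bias=False)
        self.beta = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden, w):
        # hidden: [batch, num_patches, hidden_size]; w: [batch, hidden_size]
        gamma = self.gamma(w).unsqueeze(1)
        beta = self.beta(w).unsqueeze(1)
        return gamma * self.ln(hidden) + beta
```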
73 | The input dimension of each of those fully connected layers is equal to ```hidden_dimension```, and so is the output dimension. 74 | 75 | #### SIREN 76 | 77 | A description of SIREN: 78 | \[[Blog Post](https://tech.fusic.co.jp/posts/2021-08-03-what-are-sirens/)\] \[[Paper](https://arxiv.org/pdf/2006.09661.pdf)\] \[[Colab Notebook](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)\] 79 | 80 | In contrast to a regular SIREN, the desired output is not a single image. For this purpose, the patch embedding is combined with a position embedding. 81 | 82 | The positional encoding used in ViTGAN is the Fourier Position Encoding; the code for it was taken from here: \([Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\) 83 | 84 | In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding. 85 | 86 | #### Weight Modulation 87 | 88 | Weight Modulation usually consists of a modulation and a demodulation module. After testing the network, I concluded that **demodulation is not used in ViTGAN**. 89 | 90 | My implementation of the weight modulation is heavily based on [CIPS](https://github.com/saic-mdal/CIPS/blob/main/model/blocks.py#L173). I have adjusted it to work for a fully-connected network, using a 1D convolution. The reason for using a 1D convolution instead of a linear layer is its ```groups``` argument, which improves performance by a factor of the batch size. 91 | 92 | Each SIREN layer consists of a sine activation applied to a weight modulation layer. The sizes of the input, hidden and output layers of a SIREN network can vary. Thus, in case the input size differs from the size of the patch embedding, I define an additional fully-connected layer, which converts the patch embedding to the appropriate size. 93 | 94 | ### Discriminator 95 | 96 | The Discriminator has the following architecture: 97 | 98 | ![ViTGAN Discriminator architecture](https://drive.google.com/uc?export=view&id=1LK-WLwNGXqAhJ44MAexSHOyPkyiGapys) 99 | 100 | The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications: 101 | - DiffAugment 102 | - Overlapping Image Patches 103 | - Use vectorized L2 distance in attention for **Discriminator** 104 | - Improved Spectral Normalization (ISN) 105 | - Balanced Consistency Regularization (bCR) 106 | 107 | #### DiffAugment 108 | 109 | For implementing DiffAugment, I used the reference code linked below; a short usage sketch comes first.<br>
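A short, hypothetical usage sketch (the ```DiffAugment``` function and the policy string are the ones defined in ```models/diffaugment.py```; the random tensor is just a stand-in for a batch of real or generated images):

``` python
import torch
from models.diffaugment import DiffAugment

imgs = torch.randn(8, 3, 32, 32)  # NCHW batch of CIFAR-10-sized images
# Differentiable color/translation/cutout augmentations are applied before the
# discriminator forward pass, so gradients still flow back to the generator.
augmented = DiffAugment(imgs, policy='color,translation,cutout')
assert augmented.shape == imgs.shape
```

The reference code and paper:<br>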
110 | \[[GitHub](https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py)\] \[[Paper](https://arxiv.org/pdf/2006.10738.pdf)\] 111 | 112 | #### Overlapping Image Patches 113 | 114 | Creation of the overlapping image patches is implemented with the use of a convolution layer. 115 | 116 | #### Use vectorized L2 distance in attention for **Discriminator** 117 | 118 | \[[Paper](https://arxiv.org/pdf/2006.04710.pdf)\] 119 | 120 | #### Improved Spectral Normalization (ISN) 121 | 122 | The ISN implementation is based on the following implementation of Spectral Normalization:
123 | \[[GitHub](https://github.com/koshian2/SNGAN/blob/117fbb19ac79bbc561c3ccfe285d6890ea0971f9/models/core_layers.py#L9)\] 124 | \[[Paper](https://arxiv.org/abs/1802.05957)\] 125 | 126 | #### Balanced Consistency Regularization (bCR) 127 | 128 | Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 129 | \[[Paper](https://arxiv.org/pdf/2002.04724.pdf)\] 130 | 131 | ## References 132 | SIREN: [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/pdf/2006.09661.pdf)
133 | Vision Transformer: \[[Blog Post](https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632)\]
134 | L2 distance attention: [The Lipschitz Constant of Self-Attention](https://arxiv.org/pdf/2006.04710.pdf)
135 | Spectral Normalization reference code: \[[GitHub](https://github.com/koshian2/SNGAN/blob/117fbb19ac79bbc561c3ccfe285d6890ea0971f9/models/core_layers.py#L9)\] \[[Paper](https://arxiv.org/abs/1802.05957)\]
136 | DiffAugment: \[[GitHub](https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py)\] \[[Paper](https://arxiv.org/pdf/2006.10738.pdf)\]<br>
137 | Fourier Position Embedding: \[[Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\]
138 | Exponential Moving Average: \[[GitHub](https://github.com/fadel/pytorch_ema)\]
139 | Balanced Consistency Regularization (bCR): \[[Paper](https://arxiv.org/pdf/2002.04724.pdf)\]<br>
140 | StyleGAN2 Discriminator: \[[GitHub](https://github.com/lucidrains/stylegan2-pytorch/blob/1a789d186b9697571bd6bbfa8bb1b9735bb42a0c/stylegan2_pytorch/stylegan2_pytorch.py#L627)\]<br>
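
## Minimal usage sketch

For completeness, a minimal, hypothetical smoke-test showing how the pieces above fit together. The constructor arguments mirror the CIFAR-10 defaults used in ```main.py``` (including the ```sln_paremeter_size``` spelling used by the repo); the batch size and the reshape/permute mirror the training script, and this snippet is only an illustration, not part of it:

``` python
import torch
from models import GeneratorViT, ViT

# Generator: maps a latent vector to per-pixel RGB values, patch by patch.
generator = GeneratorViT(latent_dim=32, hidden_size=384, image_size=32,
                         patch_size=4, depth=4, sln_paremeter_size=384)
# Discriminator: ViT with overlapping patches (stride < patch size) and L2 attention.
discriminator = ViT(discriminator=True, patch_size=8, stride_size=4, n_classes=1)

z = torch.randn(4, 32)                                    # batch of latent vectors
pixels = generator(z)                                     # [4, 32*32, 3]
fake = pixels.reshape(-1, 32, 32, 3).permute(0, 3, 1, 2)  # NCHW images, as in main.py
logits = discriminator(fake)                              # [4, 1] real/fake logits
print(fake.shape, logits.shape)
```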
141 | -------------------------------------------------------------------------------- /models/vitgan_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import repeat 6 | 7 | from .multi_head import MultiHeadAttention 8 | 9 | # ViTGAN Generator 10 | 11 | class FullyConnectedLayer(nn.Module): 12 | def __init__(self, 13 | in_features, # Number of input features. 14 | out_features, # Number of output features. 15 | bias = True, # Apply additive bias before the activation function? 16 | activation = 'linear', # Activation function: 'relu', 'lrelu', etc. 17 | lr_multiplier = 1, # Learning rate multiplier. 18 | bias_init = 0, # Initial value for the additive bias. 19 | **kwargs 20 | ): 21 | super().__init__() 22 | self.activation = activation 23 | if activation == 'lrelu': 24 | self.activation = nn.LeakyReLU(0.2) 25 | elif activation == 'relu': 26 | self.activation = nn.ReLU() 27 | elif activation == 'gelu': 28 | self.activation = nn.GELU() 29 | self.weight = nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 30 | self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 31 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 32 | self.bias_gain = lr_multiplier 33 | 34 | def forward(self, x): 35 | w = self.weight.to(x.dtype) * self.weight_gain 36 | b = self.bias 37 | if b is not None: 38 | b = b.to(x.dtype) 39 | if self.bias_gain != 1: 40 | b = b * self.bias_gain 41 | 42 | if self.activation == 'linear' and b is not None: 43 | x = torch.addmm(b.unsqueeze(0), x, w.t()) 44 | else: 45 | x = x.matmul(w.t()) 46 | if b is not None: 47 | x = x + b 48 | if self.activation != 'linear': 49 | x = self.activation(x) 50 | return x 51 | 52 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 53 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 54 | 55 | class MappingNetwork(nn.Module): 56 | def __init__(self, 57 | z_dim, # Input latent (Z) dimensionality, 0 = no latent. 58 | c_dim, # Conditioning label (C) dimensionality, 0 = no label. 59 | w_dim, # Intermediate latent (W) dimensionality. 60 | num_ws = None, # Number of intermediate latents to output, None = do not broadcast. 61 | num_layers = 8, # Number of mapping layers. 62 | embed_features = None, # Label embedding dimensionality, None = same as w_dim. 63 | layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. 64 | activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. 65 | lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. 66 | w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track. 
67 | **kwargs 68 | ): 69 | super().__init__() 70 | self.z_dim = z_dim 71 | self.c_dim = c_dim 72 | self.w_dim = w_dim 73 | self.num_ws = num_ws 74 | self.num_layers = num_layers 75 | self.w_avg_beta = w_avg_beta 76 | 77 | if embed_features is None: 78 | embed_features = w_dim 79 | if c_dim == 0: 80 | embed_features = 0 81 | if layer_features is None: 82 | layer_features = w_dim 83 | features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] 84 | 85 | if c_dim > 0: 86 | self.embed = FullyConnectedLayer(c_dim, embed_features) 87 | for idx in range(num_layers): 88 | in_features = features_list[idx] 89 | out_features = features_list[idx + 1] 90 | layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) 91 | setattr(self, f'fc{idx}', layer) 92 | 93 | if num_ws is not None and w_avg_beta is not None: 94 | self.register_buffer('w_avg', torch.zeros([w_dim])) 95 | 96 | def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): 97 | # Embed, normalize, and concat inputs. 98 | x = None 99 | with torch.autograd.profiler.record_function('input'): 100 | if self.z_dim > 0: 101 | assert z.shape[1] == self.z_dim 102 | x = normalize_2nd_moment(z.to(torch.float32)) 103 | if self.c_dim > 0: 104 | assert c.shape[1] == self.c_dim 105 | y = normalize_2nd_moment(self.embed(c.to(torch.float32))) 106 | x = torch.cat([x, y], dim=1) if x is not None else y 107 | 108 | # Main layers. 109 | for idx in range(self.num_layers): 110 | layer = getattr(self, f'fc{idx}') 111 | x = layer(x) 112 | 113 | # Update moving average of W. 114 | if self.w_avg_beta is not None and self.training and not skip_w_avg_update: 115 | with torch.autograd.profiler.record_function('update_w_avg'): 116 | self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) 117 | 118 | # Broadcast. 119 | if self.num_ws is not None: 120 | with torch.autograd.profiler.record_function('broadcast'): 121 | x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) 122 | 123 | # Apply truncation. 
124 | if truncation_psi != 1: 125 | with torch.autograd.profiler.record_function('truncate'): 126 | assert self.w_avg_beta is not None 127 | if self.num_ws is None or truncation_cutoff is None: 128 | x = self.w_avg.lerp(x, truncation_psi) 129 | else: 130 | x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) 131 | return x 132 | 133 | class FeedForwardBlock(nn.Sequential): 134 | def __init__(self, emb_size, expansion=4, drop_p=0., bias=False): 135 | super().__init__( 136 | FullyConnectedLayer(emb_size, expansion * emb_size, activation='gelu', bias=False), 137 | nn.Dropout(drop_p), 138 | FullyConnectedLayer(expansion * emb_size, emb_size, bias=False), 139 | ) 140 | 141 | # Self-Modulated LayerNorm 142 | 143 | class SLN(nn.Module): 144 | def __init__(self, input_size, parameter_size=None, **kwargs): 145 | super().__init__() 146 | if parameter_size == None: 147 | parameter_size = input_size 148 | assert(input_size == parameter_size or parameter_size == 1) 149 | self.input_size = input_size 150 | self.parameter_size = parameter_size 151 | self.ln = nn.LayerNorm(input_size) 152 | self.gamma = FullyConnectedLayer(input_size, parameter_size, bias=False) 153 | self.beta = FullyConnectedLayer(input_size, parameter_size, bias=False) 154 | # self.gamma = nn.Linear(input_size, parameter_size, bias=False) 155 | # self.beta = nn.Linear(input_size, parameter_size, bias=False) 156 | 157 | def forward(self, hidden, w): 158 | assert(hidden.size(-1) == self.parameter_size and w.size(-1) == self.parameter_size) 159 | gamma = self.gamma(w).unsqueeze(1) 160 | beta = self.beta(w).unsqueeze(1) 161 | ln = self.ln(hidden) 162 | return gamma * ln + beta 163 | 164 | class GeneratorTransformerEncoderBlock(nn.Module): 165 | def __init__(self, 166 | hidden_size=384, 167 | sln_paremeter_size=384, 168 | drop_p=0., 169 | forward_expansion=4, 170 | forward_drop_p=0., 171 | **kwargs): 172 | super().__init__() 173 | self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size) 174 | self.msa = MultiHeadAttention(hidden_size, **kwargs) 175 | self.dropout = nn.Dropout(drop_p) 176 | self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, drop_p=forward_drop_p) 177 | 178 | def forward(self, hidden, w): 179 | res = hidden 180 | hidden = self.sln(hidden, w) 181 | hidden = self.msa(hidden) 182 | hidden = self.dropout(hidden) 183 | hidden += res 184 | 185 | res = hidden 186 | hidden = self.sln(hidden, w) 187 | self.feed_forward(hidden) 188 | hidden = self.dropout(hidden) 189 | hidden += res 190 | return hidden 191 | 192 | class GeneratorTransformerEncoder(nn.Module): 193 | def __init__(self, depth=4, **kwargs): 194 | super().__init__() 195 | self.depth = depth 196 | self.blocks = nn.ModuleList([GeneratorTransformerEncoderBlock(**kwargs) for _ in range(depth)]) 197 | 198 | def forward(self, hidden, w): 199 | for i in range(self.depth): 200 | hidden = self.blocks[i](hidden, w) 201 | return hidden 202 | 203 | # SIREN 204 | 205 | #Code for SIREN is taken from https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb 206 | 207 | class ModulatedLinear(nn.Module): 208 | def __init__(self, in_channels, out_channels, style_size, bias=False, demodulation=True, **kwargs): 209 | super().__init__() 210 | self.in_channels = in_channels 211 | self.out_channels = out_channels 212 | self.style_size = style_size 213 | self.scale = 1 / np.sqrt(in_channels) 214 | self.weight = nn.Parameter( 215 | torch.randn(1, out_channels, in_channels, 1) 216 | ) 217 | 
self.modulation = None 218 | if self.style_size != self.in_channels: 219 | self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False) 220 | self.demodulation = demodulation 221 | 222 | def forward(self, input, style): 223 | batch_size = input.shape[0] 224 | 225 | if self.style_size != self.in_channels: 226 | style = self.modulation(style) 227 | style = style.view(batch_size, 1, self.in_channels, 1) 228 | weight = self.scale * self.weight * style 229 | 230 | if self.demodulation: 231 | demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8) 232 | weight = weight * demod.view(batch_size, self.out_channels, 1, 1) 233 | 234 | weight = weight.view( 235 | batch_size * self.out_channels, self.in_channels, 1 236 | ) 237 | 238 | img_size = input.size(1) 239 | input = input.reshape(1, batch_size * self.in_channels, img_size) 240 | out = F.conv1d(input, weight, groups=batch_size) 241 | out = out.view(batch_size, img_size, self.out_channels) 242 | 243 | return out 244 | 245 | class ResLinear(nn.Module): 246 | def __init__(self, in_channels, out_channels, style_size, bias=False, **kwargs): 247 | super().__init__() 248 | self.linear = FullyConnectedLayer(in_channels, out_channels, bias=False) 249 | self.style = FullyConnectedLayer(style_size, in_channels, bias=False) 250 | self.in_channels = in_channels 251 | self.out_channels = out_channels 252 | self.style_size = style_size 253 | 254 | def forward(self, input, style): 255 | x = input + self.style(style).unsqueeze(1) 256 | x = self.linear(x) 257 | return x 258 | 259 | class ConLinear(nn.Module): 260 | def __init__(self, ch_in, ch_out, is_first=False, bias=True, **kwargs): 261 | super().__init__() 262 | self.conv = nn.Linear(ch_in, ch_out, bias=bias) 263 | if is_first: 264 | nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in)) 265 | else: 266 | nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in)) 267 | 268 | def forward(self, x): 269 | return self.conv(x) 270 | 271 | class SinActivation(nn.Module): 272 | def __init__(self): 273 | super().__init__() 274 | 275 | def forward(self, x): 276 | return torch.sin(x) 277 | 278 | class LFF(nn.Module): 279 | def __init__(self, hidden_size, **kwargs): 280 | super().__init__() 281 | self.ffm = ConLinear(2, hidden_size, is_first=True) 282 | self.activation = SinActivation() 283 | 284 | def forward(self, x): 285 | x = x 286 | x = self.ffm(x) 287 | x = self.activation(x) 288 | return x 289 | 290 | class SineLayer(nn.Module): 291 | def __init__(self, in_features, out_features, style_size, bias=False, 292 | is_first=False, omega_0=30, weight_modulation=True, **kwargs): 293 | super().__init__() 294 | self.omega_0 = omega_0 295 | self.is_first = is_first 296 | 297 | self.in_features = in_features 298 | self.weight_modulation = weight_modulation 299 | if weight_modulation: 300 | self.linear = ModulatedLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs) 301 | else: 302 | self.linear = ResLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs) 303 | self.init_weights() 304 | 305 | def init_weights(self): 306 | with torch.no_grad(): 307 | if self.is_first: 308 | if self.weight_modulation: 309 | self.linear.weight.uniform_(-1 / self.in_features, 310 | 1 / self.in_features) 311 | else: 312 | self.linear.linear.weight.uniform_(-1 / self.in_features, 313 | 1 / self.in_features) 314 | else: 315 | if self.weight_modulation: 316 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 317 | np.sqrt(6 / 
self.in_features) / self.omega_0) 318 | else: 319 | self.linear.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 320 | np.sqrt(6 / self.in_features) / self.omega_0) 321 | 322 | def forward(self, input, style): 323 | return torch.sin(self.omega_0 * self.linear(input, style)) 324 | 325 | class Siren(nn.Module): 326 | def __init__(self, in_features, hidden_size, hidden_layers, out_features, style_size, outermost_linear=False, 327 | first_omega_0=30, hidden_omega_0=30., weight_modulation=True, bias=False, **kwargs): 328 | super().__init__() 329 | 330 | self.net = [] 331 | self.net.append(SineLayer(in_features, hidden_size, style_size, 332 | is_first=True, omega_0=first_omega_0, 333 | weight_modulation=weight_modulation, **kwargs)) 334 | 335 | for i in range(hidden_layers): 336 | self.net.append(SineLayer(hidden_size, hidden_size, style_size, 337 | is_first=False, omega_0=hidden_omega_0, 338 | weight_modulation=weight_modulation, **kwargs)) 339 | 340 | if outermost_linear: 341 | if weight_modulation: 342 | final_linear = ModulatedLinear(hidden_size, out_features, 343 | style_size=style_size, bias=bias, **kwargs) 344 | else: 345 | final_linear = ResLinear(hidden_size, out_features, style_size=style_size, bias=bias, **kwargs) 346 | 347 | with torch.no_grad(): 348 | if weight_modulation: 349 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, 350 | np.sqrt(6 / hidden_size) / hidden_omega_0) 351 | else: 352 | final_linear.linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, 353 | np.sqrt(6 / hidden_size) / hidden_omega_0) 354 | 355 | self.net.append(final_linear) 356 | else: 357 | self.net.append(SineLayer(hidden_size, out_features, 358 | is_first=False, omega_0=hidden_omega_0, 359 | weight_modulation=weight_modulation, **kwargs)) 360 | 361 | self.net = nn.Sequential(*self.net) 362 | 363 | def forward(self, coords, style): 364 | coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. 
input 365 | # output = self.net(coords, style) 366 | output = coords 367 | for layer in self.net: 368 | output = layer(output, style) 369 | return output 370 | 371 | class GeneratorViT(nn.Module): 372 | def __init__(self, 373 | style_mlp_layers=8, 374 | patch_size=4, 375 | latent_dim=32, 376 | hidden_size=384, 377 | sln_paremeter_size=1, 378 | image_size=32, 379 | depth=4, 380 | combine_patch_embeddings=False, 381 | combined_embedding_size=1024, 382 | forward_drop_p=0., 383 | bias=False, 384 | out_features=3, 385 | out_patch_size=4, 386 | weight_modulation=True, 387 | siren_hidden_layers=1, 388 | **kwargs): 389 | super().__init__() 390 | self.hidden_size = hidden_size 391 | 392 | self.mlp = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=style_mlp_layers, w_avg_beta=None) 393 | 394 | num_patches = int(image_size//patch_size)**2 395 | self.patch_size = patch_size 396 | self.num_patches = num_patches 397 | self.image_size = image_size 398 | self.combine_patch_embeddings = combine_patch_embeddings 399 | self.combined_embedding_size = combined_embedding_size 400 | self.out_patch_size = out_patch_size 401 | self.out_features = out_features 402 | 403 | self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size)) 404 | self.transformer_encoder = GeneratorTransformerEncoder(depth, 405 | hidden_size=hidden_size, 406 | sln_paremeter_size=sln_paremeter_size, 407 | drop_p=forward_drop_p, 408 | forward_drop_p=forward_drop_p, 409 | **kwargs) 410 | self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size) 411 | if combine_patch_embeddings: 412 | self.to_single_emb = nn.Sequential( 413 | FullyConnectedLayer(num_patches*hidden_size, combined_embedding_size, bias=bias, activation='gelu'), 414 | nn.Dropout(forward_drop_p), 415 | ) 416 | 417 | self.lff = LFF(self.hidden_size) 418 | 419 | self.siren_in_features = combined_embedding_size if combine_patch_embeddings else self.hidden_size 420 | self.siren = Siren(in_features=self.siren_in_features, out_features=out_features, 421 | style_size=self.siren_in_features, hidden_size=self.hidden_size, bias=bias, 422 | hidden_layers=siren_hidden_layers, outermost_linear=True, weight_modulation=weight_modulation, **kwargs) 423 | 424 | self.num_patches_x = int(image_size//self.out_patch_size) 425 | 426 | def fourier_input_mapping(self, x): 427 | return self.lff(x) 428 | 429 | def fourier_pos_embedding(self, device): 430 | # Create input pixel coordinates in the unit square 431 | coords = np.linspace(-1, 1, self.out_patch_size, endpoint=True) 432 | pos = np.stack(np.meshgrid(coords, coords), -1) 433 | pos = torch.tensor(pos, dtype=torch.float, device=device) 434 | result = self.fourier_input_mapping(pos).reshape([self.out_patch_size**2, self.hidden_size]) 435 | return result.to(device) 436 | 437 | def repeat_pos(self, hidden): 438 | pos = self.fourier_pos_embedding(hidden.device) 439 | result = repeat(pos, 'p h -> n p h', n = hidden.shape[0]) 440 | 441 | return result 442 | 443 | def forward(self, z): 444 | w = self.mlp(z) 445 | pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0]) 446 | hidden = self.transformer_encoder(pos, w) 447 | 448 | if self.combine_patch_embeddings: 449 | # Output [batch_size, combined_embedding_size] 450 | hidden = self.sln(hidden, w).view((z.shape[0], -1)) 451 | hidden = self.to_single_emb(hidden) 452 | else: 453 | # Output [batch_size*num_patches, hidden_size] 454 | hidden = self.sln(hidden, w).view((-1, self.hidden_size)) 455 | 456 | pos = self.repeat_pos(hidden) 457 | 458 | result = 
self.siren(pos, hidden) 459 | 460 | model_output_1 = result.view([-1, self.num_patches_x, self.num_patches_x, self.out_patch_size, self.out_patch_size, self.out_features]) 461 | model_output_2 = model_output_1.permute([0, 1, 3, 2, 4, 5]) 462 | model_output = model_output_2.reshape([-1, self.image_size**2, self.out_features]) 463 | 464 | return model_output 465 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import os 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torch.nn.functional as F 12 | import torchvision.transforms as T 13 | from torchvision.utils import make_grid 14 | 15 | from torch_ema import ExponentialMovingAverage 16 | 17 | from torch.utils.tensorboard import SummaryWriter 18 | import wandb 19 | 20 | from models import CNN, ViT, StyleGanDiscriminator, GeneratorViT, CNNGenerator 21 | 22 | import argparse 23 | 24 | def get_parser(): 25 | # parse parameters 26 | parser = argparse.ArgumentParser(description='ViTGAN') 27 | parser.add_argument("--image_size", type=int, default=32, 28 | help="Image Size") 29 | parser.add_argument("--style_mlp_layers", type=int, default=8, 30 | help="Style Mapping network depth") 31 | parser.add_argument("--patch_size", type=int, default=4, 32 | help="Patch Size") 33 | parser.add_argument("--latent_dim", type=int, default=32, 34 | help="Dimensions of the seed") 35 | parser.add_argument("--dropout_p", type=float, default=0., 36 | help="Dropout rate") 37 | parser.add_argument("--bias", type=bool, default=True, 38 | help="Whether to use bias or not") 39 | parser.add_argument("--weight_modulation", type=bool, default=True, 40 | help="Whether to use weight modulation or not") 41 | parser.add_argument("--demodulation", type=bool, default=False, 42 | help="Whether to use weight demodulation or not") 43 | parser.add_argument("--siren_hidden_layers", type=int, default=1, 44 | help="Number of hidden layers for the SIREN network") 45 | parser.add_argument("--hidden_features", type=int, default=384, 46 | help="Image Size") 47 | parser.add_argument("--sln_paremeter_size", type=int, default=384, 48 | help="Either equal to --hidden_features of 1") 49 | parser.add_argument("--depth", type=int, default=4, 50 | help="Number of Transformer Block layers for both the Generator and the Discriminator") 51 | parser.add_argument("--num_heads", type=int, default=4, 52 | help="Number of Attention heads for every Transformer Block layer") 53 | parser.add_argument("--combine_patch_embeddings", type=bool, default=False, 54 | help="Generate an image from a single SIREN, instead of patch-by-patch") 55 | parser.add_argument("--combine_patch_embeddings_size", type=int, default=384, 56 | help="Size of the combined image embedding") 57 | parser.add_argument("--batch_size", type=int, default=128, 58 | help="Batch size") 59 | parser.add_argument("--generator_type", type=str, default="vitgan", 60 | help="\"vitgan\", \"cnn\"") 61 | parser.add_argument("--discriminator_type", type=str, default="vitgan", 62 | help="\"vitgan\", \"cnn\", \"stylegan2\"") 63 | parser.add_argument("--batch_size_history_discriminator", type=bool, default=True, 64 | help="Whether to use a loss, which tracks one sample from last batch_size number of batches") 65 | parser.add_argument("--lr", type=float, 
default=0.001, 66 | help="Learning Rate for the Generator") 67 | parser.add_argument("--lr_dis", type=float, default=0.001, 68 | help="Learning Rate for the Discriminator") 69 | parser.add_argument("--beta1", type=float, default=0, 70 | help="Adam beta1 parameter") 71 | parser.add_argument("--beta2", type=float, default=0.99, 72 | help="Adam beta2 parameter") 73 | parser.add_argument("--epochs", type=int, default=400, 74 | help="Number of epocks") 75 | parser.add_argument("--lambda_bCR_real", type=int, default=10, 76 | help="lambda_bCR_real") 77 | parser.add_argument("--lambda_bCR_fake", type=int, default=10, 78 | help="lambda_bCR_fake") 79 | parser.add_argument("--lambda_lossD_noise", type=int, default=0, 80 | help="lambda_lossD_noise") 81 | parser.add_argument("--lambda_lossD_history", type=int, default=0, 82 | help="lambda_lossD_history") 83 | parser.add_argument("--lambda_diversity_penalty", type=int, default=0, 84 | help="lambda_diversity_penalty") 85 | parser.add_argument("--device", type=str, default='cuda', 86 | help="device") 87 | return parser 88 | 89 | parser = get_parser() 90 | params = parser.parse_args() 91 | 92 | image_size = params.image_size 93 | style_mlp_layers = params.style_mlp_layers 94 | patch_size = params.patch_size 95 | latent_dim = params.latent_dim # Size of z 96 | hidden_size = params.hidden_features 97 | depth = params.depth 98 | num_heads = params.num_heads 99 | 100 | dropout_p = params.dropout_p 101 | bias = params.bias 102 | weight_modulation = params.weight_modulation 103 | demodulation = params.demodulation 104 | siren_hidden_layers = params.siren_hidden_layers 105 | 106 | combine_patch_embeddings = params.combine_patch_embeddings # Generate an image from a single SIREN, instead of patch-by-patch 107 | combine_patch_embeddings_size = params.combine_patch_embeddings_size 108 | 109 | sln_paremeter_size = params.sln_paremeter_size # either hidden_size or 1 110 | 111 | batch_size = params.batch_size 112 | device = params.device 113 | out_features = 3 # The number of color channels 114 | 115 | generator_type = params.generator_type # "vitgan", "cnn" 116 | discriminator_type = params.discriminator_type # "vitgan", "cnn", "stylegan2" 117 | 118 | lr = params.lr # Learning rate 119 | lr_dis = params.lr_dis # Learning rate 120 | beta = (params.beta1, params.beta2) # Adam oprimizer parameters for both the generator and the discriminator 121 | batch_size_history_discriminator = params.batch_size_history_discriminator # Whether to use a loss, which tracks one sample from last batch_size number of batches 122 | epochs = params.epochs # Number of epochs 123 | lambda_bCR_real = params.lambda_bCR_real 124 | lambda_bCR_fake = params.lambda_bCR_fake 125 | lambda_lossD_noise = params.lambda_lossD_noise 126 | lambda_lossD_history = params.lambda_lossD_history 127 | lambda_diversity_penalty = params.lambda_diversity_penalty 128 | 129 | experiment_folder_name = f'runs/lr-{lr}_\ 130 | lr_dis-{lr_dis}_\ 131 | bias-{bias}_\ 132 | demod-{demodulation}_\ 133 | sir_n_layer-{siren_hidden_layers}_\ 134 | w_mod-{weight_modulation}_\ 135 | patch_s-{patch_size}_\ 136 | st_mlp_l-{style_mlp_layers}_\ 137 | hid_size-{hidden_size}_\ 138 | comb_patch_emb-{combine_patch_embeddings}_\ 139 | sln_par_s-{sln_paremeter_size}_\ 140 | dis_type-{discriminator_type}_\ 141 | gen_type-{generator_type}_\ 142 | n_head-{num_heads}_\ 143 | depth-{depth}_\ 144 | drop_p-{dropout_p}_\ 145 | l_bCR_r-{lambda_bCR_real}_\ 146 | l_bCR_f-{lambda_bCR_fake}_\ 147 | l_D_noise-{lambda_lossD_noise}_\ 148 | 
l_D_his-{lambda_lossD_history}\ 149 | ' 150 | writer = SummaryWriter(log_dir=experiment_folder_name) 151 | 152 | wandb.init(project='ViTGAN-pytorch') 153 | config = wandb.config 154 | config.image_size = image_size 155 | config.bias = bias 156 | config.demodulation = demodulation 157 | config.siren_hidden_layers = siren_hidden_layers 158 | config.weight_modulation = weight_modulation 159 | config.style_mlp_layers = style_mlp_layers 160 | config.patch_size = patch_size 161 | config.latent_dim = latent_dim 162 | config.hidden_size = hidden_size 163 | config.depth = depth 164 | config.num_heads = num_heads 165 | 166 | config.dropout_p = dropout_p 167 | 168 | config.combine_patch_embeddings = combine_patch_embeddings 169 | config.combine_patch_embeddings_size = combine_patch_embeddings_size 170 | 171 | config.sln_paremeter_size = sln_paremeter_size 172 | 173 | config.batch_size = batch_size 174 | config.device = device 175 | config.out_features = out_features 176 | 177 | config.generator_type = generator_type 178 | config.discriminator_type = discriminator_type 179 | 180 | config.lr = lr 181 | config.lr_dis = lr_dis 182 | config.beta1 = beta[0] 183 | config.beta2 = beta[1] 184 | config.batch_size_history_discriminator = batch_size_history_discriminator 185 | config.epochs = epochs 186 | config.lambda_bCR_real = lambda_bCR_real 187 | config.lambda_bCR_fake = lambda_bCR_fake 188 | config.lambda_lossD_noise = lambda_lossD_noise 189 | config.lambda_lossD_history = lambda_lossD_history 190 | config.lambda_diversity_penalty = lambda_diversity_penalty 191 | 192 | if combine_patch_embeddings: 193 | out_patch_size = image_size 194 | combined_embedding_size = combine_patch_embeddings_size 195 | else: 196 | out_patch_size = patch_size 197 | combined_embedding_size = hidden_size 198 | 199 | siren_in_features = combined_embedding_size 200 | 201 | transform = transforms.Compose( 202 | [transforms.ToTensor(), 203 | transforms.Normalize((0., 0., 0.), (1., 1., 1.)) 204 | ]) 205 | 206 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 207 | download=True, transform=transform) 208 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 209 | shuffle=True, num_workers=2) 210 | 211 | # Diversity Loss 212 | 213 | def diversity_loss(images): 214 | num_images_to_calculate_on = 10 215 | num_pairs = num_images_to_calculate_on * (num_images_to_calculate_on - 1) // 2 216 | 217 | scale_factor = 5 218 | 219 | loss = torch.zeros(1, dtype=torch.float, device=device, requires_grad=True) 220 | i = 0 221 | for a_id in range(num_images_to_calculate_on): 222 | for b_id in range(a_id+1, num_images_to_calculate_on): 223 | img_a = images[a_id] 224 | img_b = images[b_id] 225 | img_a_l2 = torch.norm(img_a) 226 | img_b_l2 = torch.norm(img_b) 227 | img_a, img_b = img_a.flatten(), img_b.flatten() 228 | 229 | a_b_loss = scale_factor * (img_a.t() @ img_b) / (img_a_l2 * img_b_l2) 230 | loss = loss + torch.sigmoid(a_b_loss) 231 | i += 1 232 | loss = loss.sum() / num_pairs 233 | return loss 234 | 235 | # Normal distribution init weight 236 | 237 | def init_normal(m): 238 | if type(m) == nn.Linear: 239 | # y = m.in_features 240 | # m.weight.data.normal_(0.0,1/np.sqrt(y)) 241 | if 'weight' in m.__dict__.keys(): 242 | m.weight.data.normal_(0.0,1) 243 | # m.bias.data.fill_(0) 244 | 245 | # Experiments 246 | 247 | if generator_type == "vitgan": 248 | # Create the Generator 249 | Generator = GeneratorViT( patch_size=patch_size, 250 | image_size=image_size, 251 | style_mlp_layers=style_mlp_layers, 252 | 
latent_dim=latent_dim, 253 | hidden_size=hidden_size, 254 | combine_patch_embeddings=combine_patch_embeddings, 255 | combined_embedding_size=combined_embedding_size, 256 | sln_paremeter_size=sln_paremeter_size, 257 | num_heads=num_heads, 258 | depth=depth, 259 | forward_drop_p=dropout_p, 260 | bias=bias, 261 | weight_modulation=weight_modulation, 262 | siren_hidden_layers=siren_hidden_layers, 263 | demodulation=demodulation, 264 | out_patch_size=out_patch_size, 265 | ).to(device) 266 | 267 | print(Generator) 268 | 269 | # use the modules apply function to recursively apply the initialization 270 | Generator.apply(init_normal) 271 | 272 | num_patches_x = int(image_size//out_patch_size) 273 | 274 | if os.path.exists(f'{experiment_folder_name}/weights/Generator.pth'): 275 | Generator = torch.load(f'{experiment_folder_name}/weights/Generator.pth') 276 | 277 | wandb.watch(Generator) 278 | 279 | elif generator_type == "cnn": 280 | cnn_generator = CNNGenerator(hidden_size=hidden_size, latent_dim=latent_dim).to(device) 281 | 282 | print(cnn_generator) 283 | 284 | cnn_generator.apply(init_normal) 285 | 286 | if os.path.exists(f'{experiment_folder_name}/weights/cnn_generator.pth'): 287 | cnn_generator = torch.load(f'{experiment_folder_name}/weights/cnn_generator.pth') 288 | 289 | wandb.watch(cnn_generator) 290 | 291 | # Create the three types of discriminators 292 | if discriminator_type == "vitgan": 293 | Discriminator = ViT(discriminator=True, 294 | patch_size=patch_size*2, 295 | stride_size=patch_size, 296 | n_classes=1, 297 | num_heads=num_heads, 298 | depth=depth, 299 | forward_drop_p=dropout_p, 300 | ).to(device) 301 | 302 | print(Discriminator) 303 | 304 | Discriminator.apply(init_normal) 305 | 306 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'): 307 | Discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth') 308 | 309 | wandb.watch(Discriminator) 310 | 311 | elif discriminator_type == "cnn": 312 | cnn_discriminator = CNN().to(device) 313 | 314 | print(cnn_discriminator) 315 | 316 | cnn_discriminator.apply(init_normal) 317 | 318 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'): 319 | cnn_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth') 320 | 321 | wandb.watch(cnn_discriminator) 322 | 323 | elif discriminator_type == "stylegan2": 324 | stylegan2_discriminator = StyleGanDiscriminator(image_size=32).to(device) 325 | 326 | print(stylegan2_discriminator) 327 | 328 | # stylegan2_discriminator.apply(init_normal) 329 | 330 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'): 331 | stylegan2_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth') 332 | 333 | wandb.watch(stylegan2_discriminator) 334 | 335 | # Training 336 | 337 | os.makedirs(f"{experiment_folder_name}/weights", exist_ok = True) 338 | os.makedirs(f"{experiment_folder_name}/samples", exist_ok = True) 339 | 340 | # Loss function 341 | criterion = nn.BCEWithLogitsLoss() 342 | 343 | if discriminator_type == "cnn": discriminator = cnn_discriminator 344 | elif discriminator_type == "stylegan2": discriminator = stylegan2_discriminator 345 | elif discriminator_type == "vitgan": discriminator = Discriminator 346 | 347 | if generator_type == "cnn": 348 | params = cnn_generator.parameters() 349 | else: 350 | params = Generator.parameters() 351 | optim_g = torch.optim.Adam(lr=lr, params=params, betas=beta) 352 | optim_d = torch.optim.Adam(lr=lr_dis, params=discriminator.parameters(), 
betas=beta) 353 | ema = ExponentialMovingAverage(params, decay=0.995) 354 | 355 | fixed_noise = torch.FloatTensor(np.random.normal(0, 1, (16, latent_dim))).to(device) 356 | 357 | discriminator_f_img = torch.zeros([batch_size, 3, image_size, image_size]).to(device) 358 | 359 | trainset_len = len(trainloader.dataset) 360 | 361 | step = 0 362 | for epoch in range(epochs): 363 | for batch_id, batch in enumerate(trainloader): 364 | step += 1 365 | 366 | # Train discriminator 367 | 368 | # Forward + Backward with real images 369 | r_img = batch[0].to(device) 370 | r_logit = discriminator(r_img).flatten() 371 | r_label = torch.ones(r_logit.shape[0]).to(device) 372 | 373 | lossD_real = criterion(r_logit, r_label) 374 | 375 | lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False)) 376 | 377 | # Forward + Backward with fake images 378 | latent_vector = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))).to(device) 379 | 380 | if generator_type == "vitgan": 381 | f_img = Generator(latent_vector) 382 | f_img = f_img.reshape([-1, image_size, image_size, out_features]) 383 | f_img = f_img.permute(0, 3, 1, 2) 384 | else: 385 | model_output = cnn_generator(latent_vector) 386 | f_img = model_output 387 | 388 | assert f_img.size(0) == batch_size, f_img.shape 389 | assert f_img.size(1) == out_features, f_img.shape 390 | assert f_img.size(2) == image_size, f_img.shape 391 | assert f_img.size(3) == image_size, f_img.shape 392 | 393 | f_label = torch.zeros(batch_size).to(device) 394 | # Save the a single generated image to the discriminator training data 395 | if batch_size_history_discriminator: 396 | discriminator_f_img[step % batch_size] = f_img[0].detach() 397 | f_logit_history = discriminator(discriminator_f_img).flatten() 398 | lossD_fake_history = criterion(f_logit_history, f_label) 399 | else: lossD_fake_history = 0 400 | # Train the discriminator on the images, generated only from this batch 401 | f_logit = discriminator(f_img.detach()).flatten() 402 | lossD_fake = criterion(f_logit, f_label) 403 | 404 | lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img, do_augment=False)) 405 | 406 | f_noise_input = torch.FloatTensor(np.random.rand(*f_img.shape)*2 - 1).to(device) 407 | f_noise_logit = discriminator(f_noise_input).flatten() 408 | lossD_noise = criterion(f_noise_logit, f_label) 409 | 410 | lossD = lossD_real * 0.5 +\ 411 | lossD_fake * 0.5 +\ 412 | lossD_fake_history * lambda_lossD_history +\ 413 | lossD_noise * lambda_lossD_noise +\ 414 | lossD_bCR_real * lambda_bCR_real +\ 415 | lossD_bCR_fake * lambda_bCR_fake 416 | 417 | optim_d.zero_grad() 418 | lossD.backward() 419 | optim_d.step() 420 | 421 | # Train Generator 422 | 423 | if generator_type == "vitgan": 424 | f_img = Generator(latent_vector) 425 | f_img = f_img.reshape([-1, image_size, image_size, out_features]) 426 | f_img = f_img.permute(0, 3, 1, 2) 427 | else: 428 | model_output = cnn_generator(latent_vector) 429 | f_img = model_output 430 | 431 | assert f_img.size(0) == batch_size 432 | assert f_img.size(1) == out_features 433 | assert f_img.size(2) == image_size 434 | assert f_img.size(3) == image_size 435 | 436 | f_logit = discriminator(f_img).flatten() 437 | r_label = torch.ones(batch_size).to(device) 438 | lossG_main = criterion(f_logit, r_label) 439 | 440 | lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty 441 | lossG = lossG_main + lossG_diversity 442 | 443 | optim_g.zero_grad() 444 | lossG.backward() 445 | optim_g.step() 446 | ema.update() 447 | 448 | 
writer.add_scalar("Loss/Generator", lossG_main, step) 449 | writer.add_scalar("Loss/Gen(diversity)", lossG_diversity, step) 450 | writer.add_scalar("Loss/Dis(real)", lossD_real, step) 451 | writer.add_scalar("Loss/Dis(fake)", lossD_fake, step) 452 | writer.add_scalar("Loss/Dis(fake_history)", lossD_fake_history, step) 453 | writer.add_scalar("Loss/Dis(noise)", lossD_noise, step) 454 | writer.add_scalar("Loss/Dis(bCR_fake)", lossD_bCR_fake * lambda_bCR_fake, step) 455 | writer.add_scalar("Loss/Dis(bCR_real)", lossD_bCR_real * lambda_bCR_real, step) 456 | writer.flush() 457 | 458 | wandb.log({ 459 | 'Generator': lossG_main, 460 | 'Gen(diversity)': lossG_diversity, 461 | 'Dis(real)': lossD_real, 462 | 'Dis(fake)': lossD_fake, 463 | 'Dis(fake_history)': lossD_fake_history, 464 | 'Dis(noise)': lossD_noise, 465 | 'Dis(bCR_fake)': lossD_bCR_fake * lambda_bCR_fake, 466 | 'Dis(bCR_real)': lossD_bCR_real * lambda_bCR_real 467 | }) 468 | 469 | if batch_id%20 == 0: 470 | print(f'epoch {epoch}/{epochs}; batch {batch_id}/{int(trainset_len/batch_size)}') 471 | print(f'Generator: {"{:.3f}".format(float(lossG_main))}, '+\ 472 | f'Gen(diversity): {"{:.3f}".format(float(lossG_diversity))}, '+\ 473 | f'Dis(real): {"{:.3f}".format(float(lossD_real))}, '+\ 474 | f'Dis(fake): {"{:.3f}".format(float(lossD_fake))}, '+\ 475 | f'Dis(fake_history): {"{:.3f}".format(float(lossD_fake_history))}, '+\ 476 | f'Dis(noise) {"{:.3f}".format(float(lossD_noise))}, '+\ 477 | f'Dis(bCR_fake): {"{:.3f}".format(float(lossD_bCR_fake * lambda_bCR_fake))}, '+\ 478 | f'Dis(bCR_real): {"{:.3f}".format(float(lossD_bCR_real * lambda_bCR_real))}') 479 | 480 | # Plot 8 randomly selected samples 481 | fig, axes = plt.subplots(1,8, figsize=(24,3)) 482 | output = f_img.permute(0, 2, 3, 1) 483 | for i in range(8): 484 | j = np.random.randint(0, batch_size-1) 485 | img = output[j].cpu().view(32,32,3).detach().numpy() 486 | img -= img.min() 487 | img /= img.max() 488 | axes[i].imshow(img) 489 | plt.show() 490 | 491 | # if step % sample_interval == 0: 492 | if generator_type == "vitgan": 493 | Generator.eval() 494 | # img_siren.eval() 495 | vis = Generator(fixed_noise) 496 | vis = vis.reshape([-1, image_size, image_size, out_features]) 497 | vis = vis.permute(0, 3, 1, 2) 498 | else: 499 | model_output = cnn_generator(fixed_noise) 500 | vis = model_output 501 | 502 | assert vis.shape[0] == fixed_noise.shape[0], f'vis.shape[0] is {vis.shape[0]}, but should be {fixed_noise.shape[0]}' 503 | assert vis.shape[1] == out_features, f'vis.shape[1] is {vis.shape[1]}, but should be {out_features}' 504 | assert vis.shape[2] == image_size, f'vis.shape[2] is {vis.shape[2]}, but should be {image_size}' 505 | assert vis.shape[3] == image_size, f'vis.shape[3] is {vis.shape[3]}, but should be {image_size}' 506 | 507 | vis.detach().cpu() 508 | vis = make_grid(vis, nrow = 4, padding = 5, normalize = True) 509 | writer.add_image(f'Generated/epoch_{epoch}', vis) 510 | wandb.log({'examples': wandb.Image(vis)}) 511 | 512 | vis = T.ToPILImage()(vis) 513 | vis.save(f'{experiment_folder_name}/samples/vis{epoch}.jpg') 514 | if generator_type == "vitgan": 515 | Generator.train() 516 | # img_siren.train() 517 | else: 518 | cnn_generator.train() 519 | print(f"Save sample to {experiment_folder_name}/samples/vis{epoch}.jpg") 520 | 521 | # Save the checkpoints. 
522 | if generator_type == "vitgan": 523 | torch.save(Generator, f'{experiment_folder_name}/weights/Generator.pth') 524 | # torch.save(img_siren, f'{experiment_folder_name}/weights/img_siren.pth') 525 | elif generator_type == "cnn": 526 | torch.save(cnn_generator, f'{experiment_folder_name}/weights/cnn_generator.pth') 527 | torch.save(discriminator, f'{experiment_folder_name}/weights/discriminator.pth') 528 | print("Save model state.") 529 | 530 | writer.close() 531 | -------------------------------------------------------------------------------- /ViTGAN_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "ZiVyqRSgDTpU" 7 | }, 8 | "source": [ 9 | "# ViTGAN PyTorch implementation\n", 10 | "\n", 11 | "This notebook is a PyTorch implementation of [VITGAN: Training GANs with Vision Transformers](https://arxiv.org/pdf/2107.04589v1.pdf)\n", 12 | "\n", 13 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kJJw6BYW01HgooCZ2zUDt54e1mXqITXH?usp=sharing)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "cjz0JmQyD4JJ" 20 | }, 21 | "source": [ 22 | "The model consists of a Vision Transformer Generator and a Vision Transformer Discriminator.\n", 23 | "\n", 24 | "It is adversarially trained to map latent vectors to images that closely resemble the images from a given dataset. In this implementation, the dataset used is CIFAR-10.\n", 25 | "\n", 26 | "The Generator takes a latent vector $z$ as input, which is integrated into a Vision Transformer Encoder. The output for each patch of the image is fed to a SIREN network, in combination with a Fourier Embedding ($E_{fou}$).\n", 27 | "\n", 28 | "![ViTGAN Generator architecture](https://drive.google.com/uc?export=view&id=1XaCVOLq8Bvg-I3qM-bugNZcjIW5L7XTO)\n", 29 | "\n", 30 | "This implementation separates the Generator into Vision Transformer and SIREN networks for debugging purposes." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "6ZeCyoxmxwud" 37 | }, 38 | "source": [ 39 | "1. [x] Use vectorized L2 distance in attention for **Discriminator**\n", 40 | "2. [x] Overlapping Image Patches\n", 41 | "3. [x] DiffAugment\n", 42 | "4. [x] Self-modulated LayerNorm (SLN)\n", 43 | "5. [x] Implicit Neural Representation for Patch Generation\n", 44 | "6. [x] ExponentialMovingAverage (EMA)\n", 45 | "7. [x] Balanced Consistency Regularization (bCR)\n", 46 | "8. [x] Improved Spectral Normalization" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "id": "Z_gqyo3DZA3J" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "! pip install einops\n", 58 | "! pip install git+https://github.com/fadel/pytorch_ema\n", 59 | "! pip install stylegan2-pytorch\n", 60 | "! pip install tensorboard\n", 61 | "! pip install wandb" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "id": "TufykVSpr-QU" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "! 
wandb login" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "id": "cXCe_q68481Q" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import torch\n", 84 | "from torch import nn\n", 85 | "import torch.nn.functional as F\n", 86 | "from torch.nn import Parameter\n", 87 | "import os\n", 88 | "\n", 89 | "import numpy as np\n", 90 | "import matplotlib.pyplot as plt\n", 91 | "\n", 92 | "import time\n", 93 | "\n", 94 | "import torchvision\n", 95 | "import torchvision.transforms as transforms\n", 96 | "from einops import rearrange, repeat\n", 97 | "from einops.layers.torch import Rearrange\n", 98 | "import torch.nn.functional as F\n", 99 | "import torchvision.transforms as T\n", 100 | "from torchvision.utils import make_grid\n", 101 | "\n", 102 | "from torch_ema import ExponentialMovingAverage\n", 103 | "\n", 104 | "from stylegan2_pytorch import stylegan2_pytorch\n", 105 | "\n", 106 | "from torch.utils.tensorboard import SummaryWriter\n", 107 | "import wandb" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": { 113 | "id": "g-JOToTeB9uE" 114 | }, 115 | "source": [ 116 | "Hyperparameters" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "id": "aeQq8IuCAPmZ" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "image_size = 32\n", 128 | "style_mlp_layers = 8\n", 129 | "patch_size = 4\n", 130 | "latent_dim = 512 # Size of z\n", 131 | "hidden_size = 384\n", 132 | "depth = 4\n", 133 | "num_heads = 4\n", 134 | "\n", 135 | "dropout_p = 0.\n", 136 | "bias = True\n", 137 | "weight_modulation = True\n", 138 | "demodulation = False\n", 139 | "siren_hidden_layers = 1\n", 140 | "\n", 141 | "combine_patch_embeddings = False # Generate an image from a single SIREN, instead of patch-by-patch\n", 142 | "combine_patch_embeddings_size = hidden_size * 4\n", 143 | "\n", 144 | "sln_paremeter_size = hidden_size # either hidden_size or 1\n", 145 | "\n", 146 | "batch_size = 50\n", 147 | "device = \"cuda\"\n", 148 | "out_features = 3 # The number of color channels\n", 149 | "\n", 150 | "generator_type = \"vitgan\" # \"vitgan\", \"cnn\"\n", 151 | "discriminator_type = \"vitgan\" # \"vitgan\", \"cnn\", \"stylegan2\"\n", 152 | "\n", 153 | "lr = 7e-4 # Learning rate\n", 154 | "lr_dis = 7e-4 # Learning rate\n", 155 | "beta = (0., 0.99) # Adam oprimizer parameters for both the generator and the discriminator\n", 156 | "batch_size_history_discriminator = False # Whether to use a loss, which tracks one sample from last batch_size number of batches\n", 157 | "epochs = 400 # Number of epochs\n", 158 | "lambda_bCR_real = 10\n", 159 | "lambda_bCR_fake = 10\n", 160 | "lambda_lossD_noise = 0.0\n", 161 | "lambda_lossD_history = 0.0\n", 162 | "lambda_diversity_penalty = 0.0\n", 163 | "\n", 164 | "experiment_folder_name = f'lr-{lr}_\\\n", 165 | "lr_dis-{lr_dis}_\\\n", 166 | "bias-{bias}_\\\n", 167 | "demod-{demodulation}_\\\n", 168 | "sir_n_layer-{siren_hidden_layers}_\\\n", 169 | "w_mod-{weight_modulation}_\\\n", 170 | "patch_s-{patch_size}_\\\n", 171 | "st_mlp_l-{style_mlp_layers}_\\\n", 172 | "hid_size-{hidden_size}_\\\n", 173 | "comb_patch_emb-{combine_patch_embeddings}_\\\n", 174 | "sln_par_s-{sln_paremeter_size}_\\\n", 175 | "dis_type-{discriminator_type}_\\\n", 176 | "gen_type-{generator_type}_\\\n", 177 | "n_head-{num_heads}_\\\n", 178 | "depth-{depth}_\\\n", 179 | "drop_p-{dropout_p}_\\\n", 180 | "l_bCR_r-{lambda_bCR_real}_\\\n", 181 | "l_bCR_f-{lambda_bCR_fake}_\\\n", 182 | 
"l_D_noise-{lambda_lossD_noise}_\\\n", 183 | "l_D_his-{lambda_lossD_history}\\\n", 184 | "'\n", 185 | "writer = SummaryWriter(log_dir=experiment_folder_name)\n", 186 | "\n", 187 | "wandb.init(project='ViTGAN-pytorch')\n", 188 | "config = wandb.config\n", 189 | "config.image_size = image_size\n", 190 | "config.bias = bias\n", 191 | "config.demodulation = demodulation\n", 192 | "config.siren_hidden_layers = siren_hidden_layers\n", 193 | "config.weight_modulation = weight_modulation\n", 194 | "config.style_mlp_layers = style_mlp_layers\n", 195 | "config.patch_size = patch_size\n", 196 | "config.latent_dim = latent_dim\n", 197 | "config.hidden_size = hidden_size\n", 198 | "config.depth = depth\n", 199 | "config.num_heads = num_heads\n", 200 | "\n", 201 | "config.dropout_p = dropout_p\n", 202 | "\n", 203 | "config.combine_patch_embeddings = combine_patch_embeddings\n", 204 | "config.combine_patch_embeddings_size = combine_patch_embeddings_size\n", 205 | "\n", 206 | "config.sln_paremeter_size = sln_paremeter_size\n", 207 | "\n", 208 | "config.batch_size = batch_size\n", 209 | "config.device = device\n", 210 | "config.out_features = out_features\n", 211 | "\n", 212 | "config.generator_type = generator_type\n", 213 | "config.discriminator_type = discriminator_type\n", 214 | "\n", 215 | "config.lr = lr\n", 216 | "config.lr_dis = lr_dis\n", 217 | "config.beta1 = beta[0]\n", 218 | "config.beta2 = beta[1]\n", 219 | "config.batch_size_history_discriminator = batch_size_history_discriminator\n", 220 | "config.epochs = epochs\n", 221 | "config.lambda_bCR_real = lambda_bCR_real\n", 222 | "config.lambda_bCR_fake = lambda_bCR_fake\n", 223 | "config.lambda_lossD_noise = lambda_lossD_noise\n", 224 | "config.lambda_lossD_history = lambda_lossD_history\n", 225 | "config.lambda_diversity_penalty = lambda_diversity_penalty" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "id": "V3dGermsa1bG" 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "if combine_patch_embeddings:\n", 237 | " out_patch_size = image_size\n", 238 | " combined_embedding_size = combine_patch_embeddings_size\n", 239 | "else:\n", 240 | " out_patch_size = patch_size\n", 241 | " combined_embedding_size = hidden_size\n", 242 | "\n", 243 | "siren_in_features = combined_embedding_size" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": { 249 | "id": "2M4De21AB8lE" 250 | }, 251 | "source": [ 252 | "\n", 253 | "\n", 254 | "https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "id": "XsaytJSj0inJ" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "def DiffAugment(x, policy='', channels_first=True):\n", 266 | " if policy:\n", 267 | " if not channels_first:\n", 268 | " x = x.permute(0, 3, 1, 2)\n", 269 | " for p in policy.split(','):\n", 270 | " for f in AUGMENT_FNS[p]:\n", 271 | " x = f(x)\n", 272 | " if not channels_first:\n", 273 | " x = x.permute(0, 2, 3, 1)\n", 274 | " x = x.contiguous()\n", 275 | " return x\n", 276 | "\n", 277 | "\n", 278 | "def rand_brightness(x):\n", 279 | " x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)\n", 280 | " return x\n", 281 | "\n", 282 | "\n", 283 | "def rand_saturation(x):\n", 284 | " x_mean = x.mean(dim=1, keepdim=True)\n", 285 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean\n", 
286 | " return x\n", 287 | "\n", 288 | "\n", 289 | "def rand_contrast(x):\n", 290 | " x_mean = x.mean(dim=[1, 2, 3], keepdim=True)\n", 291 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean\n", 292 | " return x\n", 293 | "\n", 294 | "\n", 295 | "def rand_translation(x, ratio=0.1):\n", 296 | " shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n", 297 | " translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)\n", 298 | " translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)\n", 299 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n", 300 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n", 301 | " torch.arange(x.size(2), dtype=torch.long, device=x.device),\n", 302 | " torch.arange(x.size(3), dtype=torch.long, device=x.device),\n", 303 | " )\n", 304 | " grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)\n", 305 | " grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)\n", 306 | " x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])\n", 307 | " x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()\n", 308 | " return x\n", 309 | "\n", 310 | "\n", 311 | "def rand_cutout(x, ratio=0.3):\n", 312 | " cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n", 313 | " offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)\n", 314 | " offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)\n", 315 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n", 316 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n", 317 | " torch.arange(cutout_size[0], dtype=torch.long, device=x.device),\n", 318 | " torch.arange(cutout_size[1], dtype=torch.long, device=x.device),\n", 319 | " )\n", 320 | " grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)\n", 321 | " grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)\n", 322 | " mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)\n", 323 | " mask[grid_batch, grid_x, grid_y] = 0\n", 324 | " x = x * mask.unsqueeze(1)\n", 325 | " return x\n", 326 | "\n", 327 | "\n", 328 | "AUGMENT_FNS = {\n", 329 | " 'color': [rand_brightness, rand_saturation, rand_contrast],\n", 330 | " 'translation': [rand_translation],\n", 331 | " 'cutout': [rand_cutout],\n", 332 | "}" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": { 339 | "id": "VA-uwpp7b1QK" 340 | }, 341 | "outputs": [], 342 | "source": [ 343 | "transform = transforms.Compose(\n", 344 | " [transforms.ToTensor(),\n", 345 | " transforms.Normalize((0., 0., 0.), (1., 1., 1.))\n", 346 | " ])\n", 347 | "\n", 348 | "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", 349 | " download=True, transform=transform)\n", 350 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", 351 | " shuffle=True, num_workers=2)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "id": "63v16amHIG2L" 358 | }, 359 | "source": [ 360 | "Visualize the effects of the DiffAugment" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "id": "fnQW4nBraOJb" 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "img = 
next(iter(trainloader))[0]\n", 372 | "img = DiffAugment(img, policy='color,translation,cutout', channels_first=True)\n", 373 | "img = img.permute(0,2,3,1)[0]\n", 374 | "img -= img.min()\n", 375 | "img /= img.max()\n", 376 | "plt.imshow(img)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": { 382 | "id": "8QKaT8pk9vD_" 383 | }, 384 | "source": [ 385 | "# Improved Spectral Normalization (ISN)\n", 386 | "\n", 387 | "$$\n", 388 | "\\bar{W}_{ISN}(W):=\\sigma(W_{init})\\cdot W/\\sigma(W)\n", 389 | "$$\n", 390 | "\n", 391 | "Reference code: https://github.com/koshian2/SNGAN\n", 392 | "\n", 393 | "When updating the weights, normalize the weights' norm to its norm at initialization." 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": { 400 | "id": "BL2ZLEnM-xct" 401 | }, 402 | "outputs": [], 403 | "source": [ 404 | "def l2normalize(v, eps=1e-4):\n", 405 | "\treturn v / (v.norm() + eps)\n", 406 | "\n", 407 | "class spectral_norm(nn.Module):\n", 408 | "\tdef __init__(self, module, name='weight', power_iterations=1):\n", 409 | "\t\tsuper().__init__()\n", 410 | "\t\tself.module = module\n", 411 | "\t\tself.name = name\n", 412 | "\t\tself.power_iterations = power_iterations\n", 413 | "\t\tif not self._made_params():\n", 414 | "\t\t\tself._make_params()\n", 415 | "\t\tself.w_init_sigma = None\n", 416 | "\t\tself.w_initalized = False\n", 417 | "\n", 418 | "\tdef _update_u_v(self):\n", 419 | "\t\tu = getattr(self.module, self.name + \"_u\")\n", 420 | "\t\tv = getattr(self.module, self.name + \"_v\")\n", 421 | "\t\tw = getattr(self.module, self.name + \"_bar\")\n", 422 | "\n", 423 | "\t\theight = w.data.shape[0]\n", 424 | "\t\t_w = w.view(height, -1)\n", 425 | "\t\tfor _ in range(self.power_iterations):\n", 426 | "\t\t\tv = l2normalize(torch.matmul(_w.t(), u))\n", 427 | "\t\t\tu = l2normalize(torch.matmul(_w, v))\n", 428 | "\n", 429 | "\t\tsigma = u.dot((_w).mv(v))\n", 430 | "\t\tif not self.w_initalized:\n", 431 | "\t\t\tself.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())\n", 432 | "\t\t\tself.w_initalized = True\n", 433 | "\t\tsetattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(device) * w / sigma.expand_as(w))\n", 434 | "\n", 435 | "\tdef _made_params(self):\n", 436 | "\t\ttry:\n", 437 | "\t\t\tgetattr(self.module, self.name + \"_u\")\n", 438 | "\t\t\tgetattr(self.module, self.name + \"_v\")\n", 439 | "\t\t\tgetattr(self.module, self.name + \"_bar\")\n", 440 | "\t\t\treturn True\n", 441 | "\t\texcept AttributeError:\n", 442 | "\t\t\treturn False\n", 443 | "\n", 444 | "\tdef _make_params(self):\n", 445 | "\t\tw = getattr(self.module, self.name)\n", 446 | "\n", 447 | "\t\theight = w.data.shape[0]\n", 448 | "\t\twidth = w.view(height, -1).data.shape[1]\n", 449 | "\n", 450 | "\t\tu = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n", 451 | "\t\tv = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n", 452 | "\t\tu.data = l2normalize(u.data)\n", 453 | "\t\tv.data = l2normalize(v.data)\n", 454 | "\t\tw_bar = Parameter(w.data)\n", 455 | "\n", 456 | "\t\tdel self.module._parameters[self.name]\n", 457 | "\t\tself.module.register_parameter(self.name + \"_u\", u)\n", 458 | "\t\tself.module.register_parameter(self.name + \"_v\", v)\n", 459 | "\t\tself.module.register_parameter(self.name + \"_bar\", w_bar)\n", 460 | "\n", 461 | "\tdef forward(self, *args):\n", 462 | "\t\tself._update_u_v()\n", 463 | "\t\treturn self.module.forward(*args)" 464 | ] 465 | }, 466 | { 467 | "cell_type": 
"markdown", 468 | "metadata": { 469 | "id": "Ttk-7hLuIZ4o" 470 | }, 471 | "source": [ 472 | "Vision Transformer reference code: \\[[Blog Post](https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632)\\]\n", 473 | "\n", 474 | "Normal Attention Mechanism\n", 475 | "\n", 476 | "$$\n", 477 | "Attention_h(X) = softmax \\bigg ( \\frac{QK^T}{\\sqrt{d_h}} V \\bigg )\n", 478 | "$$\n", 479 | "\n", 480 | "Lipschitz Attention Mechanism\n", 481 | "\n", 482 | "$$\n", 483 | "Attention_h(X) = softmax \\bigg ( \\frac{d(Q,K)}{\\sqrt{d_h}} V \\bigg )\n", 484 | "$$\n", 485 | "\n", 486 | "where $d(Q,K)$ is L2-distance.\n", 487 | "\n", 488 | "https://arxiv.org/pdf/2006.04710.pdf" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": { 495 | "id": "F1VgcIu6MFow" 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "class MultiHeadAttention(nn.Module):\n", 500 | " def __init__(self, emb_size=384, num_heads=4, dropout=0, discriminator=False, **kwargs):\n", 501 | " super().__init__()\n", 502 | " self.emb_size = emb_size\n", 503 | " self.num_heads = num_heads\n", 504 | " self.discriminator = discriminator\n", 505 | " # fuse the queries, keys and values in one matrix\n", 506 | " self.qkv = nn.Linear(emb_size, emb_size * 3)\n", 507 | " self.att_drop = nn.Dropout(dropout)\n", 508 | " self.projection = nn.Linear(emb_size, emb_size)\n", 509 | " if self.discriminator:\n", 510 | " self.qkv = spectral_norm(self.qkv)\n", 511 | " self.projection = spectral_norm(self.projection)\n", 512 | " \n", 513 | " def forward(self, x, mask=None):\n", 514 | " # split keys, queries and values in num_heads\n", 515 | " qkv = rearrange(self.qkv(x), \"b n (h d qkv) -> (qkv) b h n d\", h=self.num_heads, qkv=3)\n", 516 | " queries, keys, values = qkv[0], qkv[1], qkv[2]\n", 517 | " if self.discriminator:\n", 518 | " # calculate L2-distances\n", 519 | " energy = torch.cdist(queries.contiguous(), keys.contiguous(), p=2)\n", 520 | " else:\n", 521 | " # sum up over the last axis\n", 522 | " energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len\n", 523 | "\n", 524 | " if mask is not None:\n", 525 | " fill_value = torch.finfo(torch.float32).min\n", 526 | " energy.mask_fill(~mask, fill_value)\n", 527 | " \n", 528 | " scaling = self.emb_size ** (1/2)\n", 529 | " att = F.softmax(energy, dim=-1) / scaling\n", 530 | " att = self.att_drop(att)\n", 531 | " # sum up over the third axis\n", 532 | " out = torch.einsum('bhal, bhlv -> bhav ', att, values)\n", 533 | " out = rearrange(out, \"b h n d -> b n (h d)\")\n", 534 | " out = self.projection(out)\n", 535 | " return out" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": { 541 | "id": "hpUuzMsS9G-t" 542 | }, 543 | "source": [ 544 | "# Generator" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": { 551 | "id": "9xwQ0P_83N6L" 552 | }, 553 | "outputs": [], 554 | "source": [ 555 | "class FullyConnectedLayer(nn.Module):\n", 556 | " def __init__(self,\n", 557 | " in_features, # Number of input features.\n", 558 | " out_features, # Number of output features.\n", 559 | " bias = True, # Apply additive bias before the activation function?\n", 560 | " activation = 'linear', # Activation function: 'relu', 'lrelu', etc.\n", 561 | " lr_multiplier = 1, # Learning rate multiplier.\n", 562 | " bias_init = 0, # Initial value for the additive bias.\n", 563 | " **kwargs\n", 564 | " ):\n", 565 | " super().__init__()\n", 566 | " 
self.activation = activation\n", 567 | " if activation == 'lrelu':\n", 568 | " self.activation = nn.LeakyReLU(0.2)\n", 569 | " elif activation == 'relu':\n", 570 | " self.activation = nn.ReLU()\n", 571 | " elif activation == 'gelu':\n", 572 | " self.activation = nn.GELU()\n", 573 | " self.weight = nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)\n", 574 | " self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None\n", 575 | " self.weight_gain = lr_multiplier / np.sqrt(in_features)\n", 576 | " self.bias_gain = lr_multiplier\n", 577 | "\n", 578 | " def forward(self, x):\n", 579 | " w = self.weight.to(x.dtype) * self.weight_gain\n", 580 | " b = self.bias\n", 581 | " if b is not None:\n", 582 | " b = b.to(x.dtype)\n", 583 | " if self.bias_gain != 1:\n", 584 | " b = b * self.bias_gain\n", 585 | "\n", 586 | " if self.activation == 'linear' and b is not None:\n", 587 | " # print(b.shape, x.shape, w.t().shape)\n", 588 | " x = torch.addmm(b.unsqueeze(0), x, w.t())\n", 589 | " else:\n", 590 | " x = x.matmul(w.t())\n", 591 | " if b is not None:\n", 592 | " x = x + b\n", 593 | " if self.activation != 'linear':\n", 594 | " x = self.activation(x)\n", 595 | " return x" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "metadata": { 602 | "id": "jx-xGRh_61Qz" 603 | }, 604 | "outputs": [], 605 | "source": [ 606 | "def normalize_2nd_moment(x, dim=1, eps=1e-8):\n", 607 | " return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": null, 613 | "metadata": { 614 | "id": "YCHXNmaF5C6m" 615 | }, 616 | "outputs": [], 617 | "source": [ 618 | "class MappingNetwork(nn.Module):\n", 619 | " def __init__(self,\n", 620 | " z_dim, # Input latent (Z) dimensionality, 0 = no latent.\n", 621 | " c_dim, # Conditioning label (C) dimensionality, 0 = no label.\n", 622 | " w_dim, # Intermediate latent (W) dimensionality.\n", 623 | " num_ws = None, # Number of intermediate latents to output, None = do not broadcast.\n", 624 | " num_layers = 8, # Number of mapping layers.\n", 625 | " embed_features = None, # Label embedding dimensionality, None = same as w_dim.\n", 626 | " layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.\n", 627 | " activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.\n", 628 | " lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.\n", 629 | " w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.\n", 630 | " **kwargs\n", 631 | " ):\n", 632 | " super().__init__()\n", 633 | " self.z_dim = z_dim\n", 634 | " self.c_dim = c_dim\n", 635 | " self.w_dim = w_dim\n", 636 | " self.num_ws = num_ws\n", 637 | " self.num_layers = num_layers\n", 638 | " self.w_avg_beta = w_avg_beta\n", 639 | "\n", 640 | " if embed_features is None:\n", 641 | " embed_features = w_dim\n", 642 | " if c_dim == 0:\n", 643 | " embed_features = 0\n", 644 | " if layer_features is None:\n", 645 | " layer_features = w_dim\n", 646 | " features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]\n", 647 | "\n", 648 | " if c_dim > 0:\n", 649 | " self.embed = FullyConnectedLayer(c_dim, embed_features)\n", 650 | " for idx in range(num_layers):\n", 651 | " in_features = features_list[idx]\n", 652 | " out_features = features_list[idx + 1]\n", 653 | " layer = FullyConnectedLayer(in_features, out_features, 
activation=activation, lr_multiplier=lr_multiplier)\n", 654 | " setattr(self, f'fc{idx}', layer)\n", 655 | "\n", 656 | " if num_ws is not None and w_avg_beta is not None:\n", 657 | " self.register_buffer('w_avg', torch.zeros([w_dim]))\n", 658 | "\n", 659 | " def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):\n", 660 | " # Embed, normalize, and concat inputs.\n", 661 | " x = None\n", 662 | " with torch.autograd.profiler.record_function('input'):\n", 663 | " if self.z_dim > 0:\n", 664 | " assert z.shape[1] == self.z_dim\n", 665 | " x = normalize_2nd_moment(z.to(torch.float32))\n", 666 | " if self.c_dim > 0:\n", 667 | " assert c.shape[1] == self.c_dim\n", 668 | " y = normalize_2nd_moment(self.embed(c.to(torch.float32)))\n", 669 | " x = torch.cat([x, y], dim=1) if x is not None else y\n", 670 | "\n", 671 | " # Main layers.\n", 672 | " for idx in range(self.num_layers):\n", 673 | " layer = getattr(self, f'fc{idx}')\n", 674 | " x = layer(x)\n", 675 | "\n", 676 | " # Update moving average of W.\n", 677 | " if self.w_avg_beta is not None and self.training and not skip_w_avg_update:\n", 678 | " with torch.autograd.profiler.record_function('update_w_avg'):\n", 679 | " self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))\n", 680 | "\n", 681 | " # Broadcast.\n", 682 | " if self.num_ws is not None:\n", 683 | " with torch.autograd.profiler.record_function('broadcast'):\n", 684 | " x = x.unsqueeze(1).repeat([1, self.num_ws, 1])\n", 685 | "\n", 686 | " # Apply truncation.\n", 687 | " if truncation_psi != 1:\n", 688 | " with torch.autograd.profiler.record_function('truncate'):\n", 689 | " assert self.w_avg_beta is not None\n", 690 | " if self.num_ws is None or truncation_cutoff is None:\n", 691 | " x = self.w_avg.lerp(x, truncation_psi)\n", 692 | " else:\n", 693 | " x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)\n", 694 | " return x" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": null, 700 | "metadata": { 701 | "id": "2GJiP09nL_Lh" 702 | }, 703 | "outputs": [], 704 | "source": [ 705 | "class FeedForwardBlock(nn.Sequential):\n", 706 | " def __init__(self, emb_size, expansion=4, drop_p=0., bias=False):\n", 707 | " super().__init__(\n", 708 | " FullyConnectedLayer(emb_size, expansion * emb_size, activation='gelu', bias=False),\n", 709 | " nn.Dropout(drop_p),\n", 710 | " FullyConnectedLayer(expansion * emb_size, emb_size, bias=False),\n", 711 | " )" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": { 717 | "id": "4713MviFIl8R" 718 | }, 719 | "source": [ 720 | "Self-Modulated LayerNorm\n", 721 | "$$\n", 722 | "SLN(h_{\\ell},w)=\\gamma_{\\ell}(w)\\odot\\frac{h_{\\ell}-\\mu}{\\sigma}+\\beta_{\\ell}(w)\n", 723 | "$$\n", 724 | "\n", 725 | "where $\\gamma_{\\ell}, \\beta_{\\ell}\\in \\mathbb{R}^D$ or $\\gamma_{\\ell}, \\beta_{\\ell}\\in \\mathbb{R}^1$" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": null, 731 | "metadata": { 732 | "id": "Wdd9c_sk9GTC" 733 | }, 734 | "outputs": [], 735 | "source": [ 736 | "class SLN(nn.Module):\n", 737 | " def __init__(self, input_size, parameter_size=None, **kwargs):\n", 738 | " super().__init__()\n", 739 | " if parameter_size == None:\n", 740 | " parameter_size = input_size\n", 741 | " assert(input_size == parameter_size or parameter_size == 1)\n", 742 | " self.input_size = input_size\n", 743 | " self.parameter_size = parameter_size\n", 744 | " self.ln = nn.LayerNorm(input_size)\n", 745 | "        
self.gamma = FullyConnectedLayer(input_size, parameter_size, bias=False)\n", 746 | " self.beta = FullyConnectedLayer(input_size, parameter_size, bias=False)\n", 747 | " # self.gamma = nn.Linear(input_size, parameter_size, bias=False)\n", 748 | " # self.beta = nn.Linear(input_size, parameter_size, bias=False)\n", 749 | "\n", 750 | " def forward(self, hidden, w):\n", 751 | " assert(hidden.size(-1) == self.parameter_size and w.size(-1) == self.parameter_size)\n", 752 | " gamma = self.gamma(w).unsqueeze(1)\n", 753 | " beta = self.beta(w).unsqueeze(1)\n", 754 | " ln = self.ln(hidden)\n", 755 | " return gamma * ln + beta" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": { 762 | "id": "YvIYYkHe9J3f" 763 | }, 764 | "outputs": [], 765 | "source": [ 766 | "class GeneratorTransformerEncoderBlock(nn.Module):\n", 767 | " def __init__(self,\n", 768 | " hidden_size=384,\n", 769 | " sln_paremeter_size=384,\n", 770 | " drop_p=0.,\n", 771 | " forward_expansion=4,\n", 772 | " forward_drop_p=0.,\n", 773 | " **kwargs):\n", 774 | " super().__init__()\n", 775 | " self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)\n", 776 | " self.msa = MultiHeadAttention(hidden_size, **kwargs)\n", 777 | " self.dropout = nn.Dropout(drop_p)\n", 778 | " self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, drop_p=forward_drop_p)\n", 779 | "\n", 780 | " def forward(self, hidden, w):\n", 781 | " res = hidden\n", 782 | " hidden = self.sln(hidden, w)\n", 783 | " hidden = self.msa(hidden)\n", 784 | " hidden = self.dropout(hidden)\n", 785 | " hidden += res\n", 786 | "\n", 787 | " res = hidden\n", 788 | " hidden = self.sln(hidden, w)\n", 789 | " hidden = self.feed_forward(hidden)\n", 790 | " hidden = self.dropout(hidden)\n", 791 | " hidden += res\n", 792 | " return hidden" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": null, 798 | "metadata": { 799 | "id": "xXueAJqx9LsE" 800 | }, 801 | "outputs": [], 802 | "source": [ 803 | "class GeneratorTransformerEncoder(nn.Module):\n", 804 | " def __init__(self, depth=4, **kwargs):\n", 805 | " super().__init__()\n", 806 | " self.depth = depth\n", 807 | " self.blocks = nn.ModuleList([GeneratorTransformerEncoderBlock(**kwargs) for _ in range(depth)])\n", 808 | " \n", 809 | " def forward(self, hidden, w):\n", 810 | " for i in range(self.depth):\n", 811 | " hidden = self.blocks[i](hidden, w)\n", 812 | " return hidden" 813 | ] 814 | }, 815 | { 816 | "cell_type": "markdown", 817 | "metadata": { 818 | "id": "BiQtj0Qa9b7L" 819 | }, 820 | "source": [ 821 | "# SIREN" 822 | ] 823 | }, 824 | { 825 | "cell_type": "markdown", 826 | "metadata": { 827 | "id": "qH09oAfX481V" 828 | }, 829 | "source": [ 830 | "Code for SIREN is taken from [SIREN reference colab notebook](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)" 831 | ] 832 | }, 833 | { 834 | "cell_type": "markdown", 835 | "metadata": { 836 | "id": "EOD98u3DzTow" 837 | }, 838 | "source": [ 839 | "$$\n", 840 | "w^{'}_{ijk}=s_i\\cdot w_{ijk}\n", 841 | "$$\n", 842 | "\n", 843 | "$$\n", 844 | "w^{''}_{ijk}=\\frac{w^{'}_{ijk}}{\\sqrt{\\sum_{i,k}{w^{'}_{ijk}}^2+\\epsilon}}\n", 845 | "$$" 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": null, 851 | "metadata": { 852 | "id": "bm7OIWD5Y7Up" 853 | }, 854 | "outputs": [], 855 | "source": [ 856 | "class ModulatedLinear(nn.Module):\n", 857 | " def __init__(self, in_channels, out_channels, style_size, bias=False, demodulation=True, **kwargs):\n",
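"        # StyleGAN2-style weight modulation (cf. the w' and w'' equations above): a per-sample style\n",
"        # vector scales the input channels of the shared weight; demodulation then renormalizes each output channel\n",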
858 | " super().__init__()\n", 859 | " self.in_channels = in_channels\n", 860 | " self.out_channels = out_channels\n", 861 | " self.style_size = style_size\n", 862 | " self.scale = 1 / np.sqrt(in_channels)\n", 863 | " self.weight = nn.Parameter(\n", 864 | " torch.randn(1, out_channels, in_channels, 1)\n", 865 | " )\n", 866 | " self.modulation = None\n", 867 | " if self.style_size != self.in_channels:\n", 868 | " self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False)\n", 869 | " self.demodulation = demodulation\n", 870 | "\n", 871 | " def forward(self, input, style):\n", 872 | " batch_size = input.shape[0]\n", 873 | "\n", 874 | " if self.style_size != self.in_channels:\n", 875 | " style = self.modulation(style)\n", 876 | " style = style.view(batch_size, 1, self.in_channels, 1)\n", 877 | " # print('self.scale, self.weight.shape, style.shape', self.scale, self.weight.shape, style.shape)\n", 878 | " weight = self.scale * self.weight * style\n", 879 | "\n", 880 | " if self.demodulation:\n", 881 | " demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)\n", 882 | " weight = weight * demod.view(batch_size, self.out_channels, 1, 1)\n", 883 | "\n", 884 | " weight = weight.view(\n", 885 | " batch_size * self.out_channels, self.in_channels, 1\n", 886 | " )\n", 887 | " \n", 888 | " img_size = input.size(1)\n", 889 | " input = input.reshape(1, batch_size * self.in_channels, img_size)\n", 890 | " out = F.conv1d(input, weight, groups=batch_size)\n", 891 | " out = out.view(batch_size, img_size, self.out_channels)\n", 892 | "\n", 893 | " return out" 894 | ] 895 | }, 896 | { 897 | "cell_type": "code", 898 | "execution_count": null, 899 | "metadata": { 900 | "id": "gJv0hR38RBFC" 901 | }, 902 | "outputs": [], 903 | "source": [ 904 | "class ResLinear(nn.Module):\n", 905 | " def __init__(self, in_channels, out_channels, style_size, bias=False, **kwargs):\n", 906 | " super().__init__()\n", 907 | " self.linear = FullyConnectedLayer(in_channels, out_channels, bias=False)\n", 908 | " self.style = FullyConnectedLayer(style_size, in_channels, bias=False)\n", 909 | " self.in_channels = in_channels\n", 910 | " self.out_channels = out_channels\n", 911 | " self.style_size = style_size\n", 912 | " # print('style_size, in_channels, out_channels', style_size, in_channels, out_channels)\n", 913 | "\n", 914 | " def forward(self, input, style):\n", 915 | " x = input + self.style(style).unsqueeze(1)\n", 916 | " x = self.linear(x)\n", 917 | " return x" 918 | ] 919 | }, 920 | { 921 | "cell_type": "code", 922 | "execution_count": null, 923 | "metadata": { 924 | "id": "930yhZ6zgPI0" 925 | }, 926 | "outputs": [], 927 | "source": [ 928 | "class ConLinear(nn.Module):\n", 929 | " def __init__(self, ch_in, ch_out, is_first=False, bias=True, **kwargs):\n", 930 | " super(ConLinear, self).__init__()\n", 931 | " self.conv = nn.Linear(ch_in, ch_out, bias=bias)\n", 932 | " if is_first:\n", 933 | " nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in))\n", 934 | " else:\n", 935 | " nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in))\n", 936 | "\n", 937 | " def forward(self, x):\n", 938 | " return self.conv(x)\n", 939 | "\n", 940 | "class SinActivation(nn.Module):\n", 941 | " def __init__(self):\n", 942 | " super(SinActivation, self).__init__()\n", 943 | "\n", 944 | " def forward(self, x):\n", 945 | " return torch.sin(x)\n", 946 | "\n", 947 | "class LFF(nn.Module):\n", 948 | " def __init__(self, hidden_size, **kwargs):\n", 949 | " super(LFF, self).__init__()\n", 950 | " self.ffm = 
ConLinear(2, hidden_size, is_first=True)\n", 951 | " self.activation = SinActivation()\n", 952 | "\n", 953 | " def forward(self, x):\n", 954 | " x = x\n", 955 | " x = self.ffm(x)\n", 956 | " x = self.activation(x)\n", 957 | " return x" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": null, 963 | "metadata": { 964 | "id": "wgjY_FpH481X" 965 | }, 966 | "outputs": [], 967 | "source": [ 968 | "class SineLayer(nn.Module):\n", 969 | " def __init__(self, in_features, out_features, style_size, bias=False,\n", 970 | " is_first=False, omega_0=30, weight_modulation=True, **kwargs):\n", 971 | " super().__init__()\n", 972 | " self.omega_0 = omega_0\n", 973 | " self.is_first = is_first\n", 974 | " \n", 975 | " self.in_features = in_features\n", 976 | " self.weight_modulation = weight_modulation\n", 977 | " if weight_modulation:\n", 978 | " self.linear = ModulatedLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)\n", 979 | " else:\n", 980 | " self.linear = ResLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)\n", 981 | " # print('in_features, out_features, style_size', in_features, out_features, style_size)\n", 982 | " self.init_weights()\n", 983 | " \n", 984 | " def init_weights(self):\n", 985 | " with torch.no_grad():\n", 986 | " if self.is_first:\n", 987 | " if self.weight_modulation:\n", 988 | " self.linear.weight.uniform_(-1 / self.in_features, \n", 989 | " 1 / self.in_features)\n", 990 | " else:\n", 991 | " self.linear.linear.weight.uniform_(-1 / self.in_features, \n", 992 | " 1 / self.in_features) \n", 993 | " else:\n", 994 | " if self.weight_modulation:\n", 995 | " self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n", 996 | " np.sqrt(6 / self.in_features) / self.omega_0)\n", 997 | " else:\n", 998 | " self.linear.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n", 999 | " np.sqrt(6 / self.in_features) / self.omega_0) \n", 1000 | " \n", 1001 | " def forward(self, input, style):\n", 1002 | " return torch.sin(self.omega_0 * self.linear(input, style))\n", 1003 | " \n", 1004 | "class Siren(nn.Module):\n", 1005 | " def __init__(self, in_features, hidden_size, hidden_layers, out_features, style_size, outermost_linear=False, \n", 1006 | " first_omega_0=30, hidden_omega_0=30., weight_modulation=True, bias=False, **kwargs):\n", 1007 | " super().__init__()\n", 1008 | " \n", 1009 | " self.net = []\n", 1010 | " self.net.append(SineLayer(in_features, hidden_size, style_size,\n", 1011 | " is_first=True, omega_0=first_omega_0,\n", 1012 | " weight_modulation=weight_modulation, **kwargs))\n", 1013 | "\n", 1014 | " for i in range(hidden_layers):\n", 1015 | " self.net.append(SineLayer(hidden_size, hidden_size, style_size,\n", 1016 | " is_first=False, omega_0=hidden_omega_0,\n", 1017 | " weight_modulation=weight_modulation, **kwargs))\n", 1018 | "\n", 1019 | " if outermost_linear:\n", 1020 | " if weight_modulation:\n", 1021 | " final_linear = ModulatedLinear(hidden_size, out_features,\n", 1022 | " style_size=style_size, bias=bias, **kwargs)\n", 1023 | " else:\n", 1024 | " final_linear = ResLinear(hidden_size, out_features, style_size=style_size, bias=bias, **kwargs)\n", 1025 | " # FullyConnectedLayer(hidden_size, out_features, bias=False)\n", 1026 | " # final_linear = nn.Linear(hidden_size, out_features)\n", 1027 | " \n", 1028 | " with torch.no_grad():\n", 1029 | " if weight_modulation:\n", 1030 | " final_linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, \n", 1031 | " 
np.sqrt(6 / hidden_size) / hidden_omega_0)\n", 1032 | " else:\n", 1033 | " final_linear.linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, \n", 1034 | " np.sqrt(6 / hidden_size) / hidden_omega_0)\n", 1035 | " \n", 1036 | " self.net.append(final_linear)\n", 1037 | " else:\n", 1038 | " self.net.append(SineLayer(hidden_size, out_features, \n", 1039 | " is_first=False, omega_0=hidden_omega_0,\n", 1040 | " weight_modulation=weight_modulation, **kwargs))\n", 1041 | " \n", 1042 | " self.net = nn.Sequential(*self.net)\n", 1043 | " \n", 1044 | " def forward(self, coords, style):\n", 1045 | " coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input\n", 1046 | " # output = self.net(coords, style)\n", 1047 | " output = coords\n", 1048 | " for layer in self.net:\n", 1049 | " output = layer(output, style)\n", 1050 | " return output" 1051 | ] 1052 | }, 1053 | { 1054 | "cell_type": "markdown", 1055 | "metadata": { 1056 | "id": "V8DB73-TBGEs" 1057 | }, 1058 | "source": [ 1059 | "$$\n", 1060 | "Fou(\\mathbf{v})= \\left[ \\cos(2 \\pi \\mathbf B \\mathbf{v}), \\sin(2 \\pi \\mathbf B \\mathbf{v}) \\right]^\\mathrm{T}\n", 1061 | "$$" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "code", 1066 | "execution_count": null, 1067 | "metadata": { 1068 | "id": "5ru0trNs9N3F" 1069 | }, 1070 | "outputs": [], 1071 | "source": [ 1072 | "class GeneratorViT(nn.Module):\n", 1073 | " def __init__(self,\n", 1074 | " style_mlp_layers=8,\n", 1075 | " patch_size=4,\n", 1076 | " latent_dim=32,\n", 1077 | " hidden_size=384,\n", 1078 | " sln_paremeter_size=1,\n", 1079 | " image_size=32,\n", 1080 | " depth=4,\n", 1081 | " combine_patch_embeddings=False,\n", 1082 | " combined_embedding_size=1024,\n", 1083 | " forward_drop_p=0.,\n", 1084 | " bias=False,\n", 1085 | " out_features=3,\n", 1086 | " weight_modulation=True,\n", 1087 | " siren_hidden_layers=1,\n", 1088 | " **kwargs):\n", 1089 | " super().__init__()\n", 1090 | " self.hidden_size = hidden_size\n", 1091 | "\n", 1092 | " self.mlp = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=style_mlp_layers, w_avg_beta=None)\n", 1093 | "\n", 1094 | " num_patches = int(image_size//patch_size)**2\n", 1095 | " self.patch_size = patch_size\n", 1096 | " self.num_patches = num_patches\n", 1097 | " self.image_size = image_size\n", 1098 | " self.combine_patch_embeddings = combine_patch_embeddings\n", 1099 | " self.combined_embedding_size = combined_embedding_size\n", 1100 | "\n", 1101 | " self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))\n", 1102 | " self.transformer_encoder = GeneratorTransformerEncoder(depth,\n", 1103 | " hidden_size=hidden_size,\n", 1104 | " sln_paremeter_size=sln_paremeter_size,\n", 1105 | " drop_p=forward_drop_p,\n", 1106 | " forward_drop_p=forward_drop_p,\n", 1107 | " **kwargs)\n", 1108 | " self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)\n", 1109 | " if combine_patch_embeddings:\n", 1110 | " self.to_single_emb = nn.Sequential(\n", 1111 | " FullyConnectedLayer(num_patches*hidden_size, combined_embedding_size, bias=bias, activation='gelu'),\n", 1112 | " nn.Dropout(forward_drop_p),\n", 1113 | " )\n", 1114 | "\n", 1115 | " self.lff = LFF(self.hidden_size)\n", 1116 | "\n", 1117 | " self.siren_in_features = combined_embedding_size if combine_patch_embeddings else self.hidden_size\n", 1118 | " self.siren = Siren(in_features=self.siren_in_features, out_features=out_features,\n", 1119 | " style_size=self.siren_in_features, hidden_size=self.hidden_size, bias=bias,\n", 1120 | " 
hidden_layers=siren_hidden_layers, outermost_linear=True, weight_modulation=weight_modulation, **kwargs)\n", 1121 | "\n", 1122 | " self.num_patches_x = int(image_size//out_patch_size)\n", 1123 | "\n", 1124 | "\n", 1125 | " def fourier_input_mapping(self, x):\n", 1126 | " return self.lff(x)\n", 1127 | "\n", 1128 | " def fourier_pos_embedding(self):\n", 1129 | " # Create input pixel coordinates in the unit square\n", 1130 | " coords = np.linspace(-1, 1, out_patch_size, endpoint=True)\n", 1131 | " pos = np.stack(np.meshgrid(coords, coords), -1)\n", 1132 | " pos = torch.tensor(pos, dtype=torch.float).to(device)\n", 1133 | " result = self.fourier_input_mapping(pos).reshape([out_patch_size**2, self.hidden_size])\n", 1134 | " return result.to(device)\n", 1135 | "\n", 1136 | " def mix_hidden_and_pos(self, hidden):\n", 1137 | " pos = self.fourier_pos_embedding()\n", 1138 | "\n", 1139 | " pos = repeat(pos, 'p h -> n p h', n = hidden.shape[0])\n", 1140 | "\n", 1141 | " return pos # the per-sample Fourier positional features; 'hidden' is passed to the SIREN separately as the style\n", 1142 | "\n", 1143 | " def forward(self, z):\n", 1144 | " w = self.mlp(z)\n", 1145 | " pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])\n", 1146 | " hidden = self.transformer_encoder(pos, w)\n", 1147 | "\n", 1148 | " if self.combine_patch_embeddings:\n", 1149 | " # Output [batch_size, combined_embedding_size]\n", 1150 | " hidden = self.sln(hidden, w).view((z.shape[0], -1))\n", 1151 | " hidden = self.to_single_emb(hidden)\n", 1152 | " else:\n", 1153 | " # Output [batch_size*num_patches, hidden_size]\n", 1154 | " hidden = self.sln(hidden, w).view((-1, self.hidden_size))\n", 1155 | " \n", 1156 | " pos = self.mix_hidden_and_pos(hidden)\n", 1157 | "\n", 1158 | " # hidden = repeat(hidden, 'n h -> n p h', p = out_patch_size**2)\n", 1159 | "\n", 1160 | " result = self.siren(pos, hidden)\n", 1161 | "\n", 1162 | " model_output_1 = result.view([-1, self.num_patches_x, self.num_patches_x, out_patch_size, out_patch_size, out_features])\n", 1163 | " model_output_2 = model_output_1.permute([0, 1, 3, 2, 4, 5])\n", 1164 | " model_output = model_output_2.reshape([-1, image_size**2, out_features])\n", 1165 | " \n", 1166 | " return model_output\n", 1167 | "\n", 1168 | "\n", 1169 | "Generator = GeneratorViT( patch_size=patch_size,\n", 1170 | " image_size=image_size,\n", 1171 | " style_mlp_layers=style_mlp_layers,\n", 1172 | " latent_dim=latent_dim,\n", 1173 | " hidden_size=hidden_size,\n", 1174 | " combine_patch_embeddings=combine_patch_embeddings,\n", 1175 | " combined_embedding_size=combined_embedding_size,\n", 1176 | " sln_paremeter_size=sln_paremeter_size,\n", 1177 | " num_heads=num_heads,\n", 1178 | " depth=depth,\n", 1179 | " forward_drop_p=dropout_p,\n", 1180 | " bias=bias,\n", 1181 | " weight_modulation=weight_modulation,\n", 1182 | " siren_hidden_layers=siren_hidden_layers,\n", 1183 | " demodulation=demodulation,\n", 1184 | " ).to(device)\n", 1185 | "print(Generator(torch.randn([batch_size, latent_dim]).to(device)).shape)\n", 1186 | "# print(Generator)\n", 1187 | "del Generator" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "markdown", 1192 | "metadata": { 1193 | "id": "NqMxvqCjb2fp" 1194 | }, 1195 | "source": [ 1196 | "# CNN Generator" 1197 | ] 1198 | }, 1199 | { 1200 | "cell_type": "code", 1201 | "execution_count": null, 1202 | "metadata": { 1203 | "id": "D7e7XAjvb5pX" 1204 | }, 1205 | "outputs": [], 1206 | "source": [ 1207 | "class CNNGenerator(nn.Module):\n", 1208 | " def __init__(self):\n", 1209 | " super(CNNGenerator, self).__init__()\n", 1210 | " self.w = nn.Linear(latent_dim, hidden_size * 2 * 4 * 4, 
bias=False)\n", 1211 | " self.main = nn.Sequential(\n", 1212 | " # input is Z, going into a convolution\n", 1213 | " nn.BatchNorm2d(hidden_size * 2),\n", 1214 | " nn.ReLU(True),\n", 1215 | " # state size. (ngf*8) x 4 x 4\n", 1216 | " nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),\n", 1217 | " nn.BatchNorm2d(hidden_size),\n", 1218 | " nn.ReLU(True),\n", 1219 | " # state size. (ngf*4) x 8 x 8\n", 1220 | " nn.ConvTranspose2d( hidden_size, hidden_size // 2, 4, 2, 1, bias=False),\n", 1221 | " nn.BatchNorm2d(hidden_size // 2),\n", 1222 | " nn.ReLU(True),\n", 1223 | " # state size. (ngf*2) x 16 x 16\n", 1224 | " nn.ConvTranspose2d( hidden_size // 2, hidden_size // 4, 4, 2, 1, bias=False),\n", 1225 | " nn.BatchNorm2d(hidden_size // 4),\n", 1226 | " nn.ReLU(True),\n", 1227 | " # state size. (ngf*2) x 32 x 32\n", 1228 | " nn.ConvTranspose2d( hidden_size // 4, 3, 3, 1, 1, bias=False),\n", 1229 | " nn.Tanh(),\n", 1230 | " # state size. (nc) x 64 x 64\n", 1231 | " )\n", 1232 | "\n", 1233 | " def forward(self, input):\n", 1234 | " input = self.w(input).view((-1, hidden_size * 2, 4, 4))\n", 1235 | " return self.main(input)" 1236 | ] 1237 | }, 1238 | { 1239 | "cell_type": "markdown", 1240 | "metadata": { 1241 | "id": "t1Bz035Ugroo" 1242 | }, 1243 | "source": [ 1244 | "# Discriminator" 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "execution_count": null, 1250 | "metadata": { 1251 | "id": "wY_LoVXOgtao" 1252 | }, 1253 | "outputs": [], 1254 | "source": [ 1255 | "class PatchEmbedding(nn.Module):\n", 1256 | " def __init__(self, in_channels=3, patch_size=4, stride_size=4, emb_size=384, image_size=32, batch_size=64):\n", 1257 | " super().__init__()\n", 1258 | " self.patch_size = patch_size\n", 1259 | " self.projection = nn.Sequential(\n", 1260 | " # using a conv layer instead of a linear one -> performance gains\n", 1261 | " spectral_norm(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)).to(device),\n", 1262 | " Rearrange('b e (h) (w) -> b (h w) e'),\n", 1263 | " )\n", 1264 | " self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))\n", 1265 | " num_patches = ((image_size-patch_size+stride_size) // stride_size) **2 + 1\n", 1266 | " self.positions = nn.Parameter(torch.randn(num_patches, emb_size))\n", 1267 | " self.batch_size = batch_size\n", 1268 | "\n", 1269 | " def forward(self, x):\n", 1270 | " b, _, _, _ = x.shape\n", 1271 | " x = self.projection(x)\n", 1272 | " cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)\n", 1273 | " # prepend the cls token to the input\n", 1274 | " x = torch.cat([cls_tokens, x], dim=1)\n", 1275 | " # add position embedding\n", 1276 | " x += torch.sin(self.positions)\n", 1277 | " return x" 1278 | ] 1279 | }, 1280 | { 1281 | "cell_type": "code", 1282 | "execution_count": null, 1283 | "metadata": { 1284 | "id": "D8DZy_pOg0ID" 1285 | }, 1286 | "outputs": [], 1287 | "source": [ 1288 | "class ResidualAdd(nn.Module):\n", 1289 | " def __init__(self, fn):\n", 1290 | " super().__init__()\n", 1291 | " self.fn = fn\n", 1292 | " \n", 1293 | " def forward(self, x, **kwargs):\n", 1294 | " res = x\n", 1295 | " x = self.fn(x, **kwargs)\n", 1296 | " x += res\n", 1297 | " return x" 1298 | ] 1299 | }, 1300 | { 1301 | "cell_type": "code", 1302 | "execution_count": null, 1303 | "metadata": { 1304 | "id": "BkfQ73pSgwQo" 1305 | }, 1306 | "outputs": [], 1307 | "source": [ 1308 | "class DiscriminatorTransformerEncoderBlock(nn.Sequential):\n", 1309 | " def __init__(self,\n", 1310 | " emb_size=384,\n", 1311 | " drop_p=0.,\n", 
1312 | " forward_expansion=4,\n", 1313 | " forward_drop_p=0.,\n", 1314 | " **kwargs):\n", 1315 | " super().__init__(\n", 1316 | " ResidualAdd(nn.Sequential(\n", 1317 | " nn.LayerNorm(emb_size),\n", 1318 | " MultiHeadAttention(emb_size, **kwargs),\n", 1319 | " nn.Dropout(drop_p)\n", 1320 | " )),\n", 1321 | " ResidualAdd(nn.Sequential(\n", 1322 | " nn.LayerNorm(emb_size),\n", 1323 | " nn.Sequential(\n", 1324 | " spectral_norm(nn.Linear(emb_size, forward_expansion * emb_size)),\n", 1325 | " nn.GELU(),\n", 1326 | " nn.Dropout(forward_drop_p),\n", 1327 | " spectral_norm(nn.Linear(forward_expansion * emb_size, emb_size)),\n", 1328 | " ),\n", 1329 | " nn.Dropout(drop_p)\n", 1330 | " )\n", 1331 | " ))" 1332 | ] 1333 | }, 1334 | { 1335 | "cell_type": "code", 1336 | "execution_count": null, 1337 | "metadata": { 1338 | "id": "F6IYQuoNg9EA" 1339 | }, 1340 | "outputs": [], 1341 | "source": [ 1342 | "class DiscriminatorTransformerEncoder(nn.Sequential):\n", 1343 | " def __init__(self, depth=4, **kwargs):\n", 1344 | " super().__init__(*[DiscriminatorTransformerEncoderBlock(**kwargs) for _ in range(depth)])\n", 1345 | "\n", 1346 | "class ClassificationHead(nn.Sequential):\n", 1347 | " def __init__(self, emb_size=384, class_size_1=4098, class_size_2=1024, class_size_3=512, n_classes=10):\n", 1348 | " super().__init__(\n", 1349 | " nn.LayerNorm(emb_size),\n", 1350 | " spectral_norm(nn.Linear(emb_size, class_size_1)),\n", 1351 | " nn.GELU(),\n", 1352 | " spectral_norm(nn.Linear(class_size_1, class_size_2)),\n", 1353 | " nn.GELU(),\n", 1354 | " spectral_norm(nn.Linear(class_size_2, class_size_3)),\n", 1355 | " nn.GELU(),\n", 1356 | " spectral_norm(nn.Linear(class_size_3, n_classes)),\n", 1357 | " nn.GELU(),\n", 1358 | " )\n", 1359 | "\n", 1360 | " def forward(self, x):\n", 1361 | " # Take only the cls token outputs\n", 1362 | " x = x[:, 0, :]\n", 1363 | " return super().forward(x)" 1364 | ] 1365 | }, 1366 | { 1367 | "cell_type": "code", 1368 | "execution_count": null, 1369 | "metadata": { 1370 | "id": "5oALpYkMhCGV" 1371 | }, 1372 | "outputs": [], 1373 | "source": [ 1374 | "class ViT(nn.Sequential):\n", 1375 | " def __init__(self, \n", 1376 | " in_channels=3,\n", 1377 | " patch_size=4,\n", 1378 | " stride_size=4,\n", 1379 | " emb_size=384,\n", 1380 | " image_size=32,\n", 1381 | " depth=4,\n", 1382 | " n_classes=1,\n", 1383 | " diffaugment='color,translation,cutout',\n", 1384 | " **kwargs):\n", 1385 | " self.diffaugment = diffaugment\n", 1386 | " super().__init__(\n", 1387 | " PatchEmbedding(in_channels, patch_size, stride_size, emb_size, image_size),\n", 1388 | " DiscriminatorTransformerEncoder(depth, emb_size=emb_size, **kwargs),\n", 1389 | " ClassificationHead(emb_size, n_classes=n_classes)\n", 1390 | " )\n", 1391 | " \n", 1392 | " def forward(self, img, do_augment=True):\n", 1393 | " if do_augment:\n", 1394 | " img = DiffAugment(img, policy=self.diffaugment)\n", 1395 | " return super().forward(img)" 1396 | ] 1397 | }, 1398 | { 1399 | "cell_type": "markdown", 1400 | "metadata": { 1401 | "id": "Wu4HNFgpaLZ2" 1402 | }, 1403 | "source": [ 1404 | "# CNN Discriminator" 1405 | ] 1406 | }, 1407 | { 1408 | "cell_type": "code", 1409 | "execution_count": null, 1410 | "metadata": { 1411 | "id": "uyj4W0kQLJ2j" 1412 | }, 1413 | "outputs": [], 1414 | "source": [ 1415 | "class CNN(nn.Sequential):\n", 1416 | " def __init__(self,\n", 1417 | " diffaugment='color,translation,cutout',\n", 1418 | " **kwargs):\n", 1419 | " self.diffaugment = diffaugment\n", 1420 | " super().__init__(\n", 1421 | " 
nn.Conv2d(3,32,kernel_size=3,padding=1),\n", 1422 | " nn.ReLU(),\n", 1423 | " nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),\n", 1424 | " nn.ReLU(),\n", 1425 | " nn.MaxPool2d(2,2),\n", 1426 | "\n", 1427 | " nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),\n", 1428 | " nn.ReLU(),\n", 1429 | " nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),\n", 1430 | " nn.ReLU(),\n", 1431 | " nn.MaxPool2d(2,2),\n", 1432 | "\n", 1433 | " nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),\n", 1434 | " nn.ReLU(),\n", 1435 | " nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),\n", 1436 | " nn.ReLU(),\n", 1437 | " nn.MaxPool2d(2,2),\n", 1438 | "\n", 1439 | " nn.Flatten(),\n", 1440 | " nn.Linear(256*4*4,1024),\n", 1441 | " nn.ReLU(),\n", 1442 | " nn.Linear(1024,512),\n", 1443 | " nn.ReLU(),\n", 1444 | " nn.Linear(512,1)\n", 1445 | " )\n", 1446 | " \n", 1447 | " def forward(self, img, do_augment=True):\n", 1448 | " if do_augment:\n", 1449 | " img = DiffAugment(img, policy=self.diffaugment)\n", 1450 | " return super().forward(img)" 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "markdown", 1455 | "metadata": { 1456 | "id": "wbniFkThhR5T" 1457 | }, 1458 | "source": [ 1459 | "# StyleGAN2 Discriminator" 1460 | ] 1461 | }, 1462 | { 1463 | "cell_type": "code", 1464 | "execution_count": null, 1465 | "metadata": { 1466 | "id": "ZW_m-p4pg34m" 1467 | }, 1468 | "outputs": [], 1469 | "source": [ 1470 | "class StyleGanDiscriminator(stylegan2_pytorch.Discriminator):\n", 1471 | " def __init__(self,\n", 1472 | " diffaugment='color,translation,cutout',\n", 1473 | " **kwargs):\n", 1474 | " self.diffaugment = diffaugment\n", 1475 | " super().__init__(**kwargs)\n", 1476 | " def forward(self, img, do_augment=True):\n", 1477 | " if do_augment:\n", 1478 | " img = DiffAugment(img, policy=self.diffaugment)\n", 1479 | " out, _ = super().forward(img)\n", 1480 | " return out" 1481 | ] 1482 | }, 1483 | { 1484 | "cell_type": "markdown", 1485 | "metadata": { 1486 | "id": "WsKPASQ3XtSU" 1487 | }, 1488 | "source": [ 1489 | "# Diversity Loss" 1490 | ] 1491 | }, 1492 | { 1493 | "cell_type": "code", 1494 | "execution_count": null, 1495 | "metadata": { 1496 | "id": "SB52QuhSWhRk" 1497 | }, 1498 | "outputs": [], 1499 | "source": [ 1500 | "def diversity_loss(images):\n", 1501 | " num_images_to_calculate_on = 10\n", 1502 | " num_pairs = num_images_to_calculate_on * (num_images_to_calculate_on - 1) // 2\n", 1503 | "\n", 1504 | " scale_factor = 5\n", 1505 | "\n", 1506 | " loss = torch.zeros(1, dtype=torch.float, device=device, requires_grad=True)\n", 1507 | " i = 0\n", 1508 | " for a_id in range(num_images_to_calculate_on):\n", 1509 | " for b_id in range(a_id+1, num_images_to_calculate_on):\n", 1510 | " img_a = images[a_id]\n", 1511 | " img_b = images[b_id]\n", 1512 | " img_a_l2 = torch.norm(img_a)\n", 1513 | " img_b_l2 = torch.norm(img_b)\n", 1514 | " img_a, img_b = img_a.flatten(), img_b.flatten()\n", 1515 | "\n", 1516 | " # print(img_a_l2, img_b_l2, img_a.shape, img_b.shape)\n", 1517 | "\n", 1518 | " a_b_loss = scale_factor * (img_a.t() @ img_b) / (img_a_l2 * img_b_l2)\n", 1519 | " # print(a_b_loss)\n", 1520 | " loss = loss + torch.sigmoid(a_b_loss)\n", 1521 | " i += 1\n", 1522 | " loss = loss.sum() / num_pairs\n", 1523 | " return loss" 1524 | ] 1525 | }, 1526 | { 1527 | "cell_type": "markdown", 1528 | "metadata": { 1529 | "id": "RFkNgPe5J4WZ" 1530 | }, 1531 | "source": [ 1532 | "# Normal distribution init weight" 1533 | ] 1534 | }, 1535 | { 1536 | "cell_type": "code", 1537 | "execution_count": null, 1538 | "metadata": { 1539 | "id": 
"kg23vO0CGoPr" 1540 | }, 1541 | "outputs": [], 1542 | "source": [ 1543 | "def init_normal(m):\n", 1544 | " if type(m) == nn.Linear:\n", 1545 | " if 'weight' in m.__dict__.keys():\n", 1546 | " m.weight.data.normal_(0.0,1)" 1547 | ] 1548 | }, 1549 | { 1550 | "cell_type": "markdown", 1551 | "metadata": { 1552 | "id": "9Akp8Ybo483q" 1553 | }, 1554 | "source": [ 1555 | "# Experiments" 1556 | ] 1557 | }, 1558 | { 1559 | "cell_type": "code", 1560 | "execution_count": null, 1561 | "metadata": { 1562 | "id": "CS5sfmIuLj8P" 1563 | }, 1564 | "outputs": [], 1565 | "source": [ 1566 | "if generator_type == \"vitgan\":\n", 1567 | " # Create the Generator\n", 1568 | " Generator = GeneratorViT( patch_size=patch_size,\n", 1569 | " image_size=image_size,\n", 1570 | " style_mlp_layers=style_mlp_layers,\n", 1571 | " latent_dim=latent_dim,\n", 1572 | " hidden_size=hidden_size,\n", 1573 | " combine_patch_embeddings=combine_patch_embeddings,\n", 1574 | " combined_embedding_size=combined_embedding_size,\n", 1575 | " sln_paremeter_size=sln_paremeter_size,\n", 1576 | " num_heads=num_heads,\n", 1577 | " depth=depth,\n", 1578 | " forward_drop_p=dropout_p,\n", 1579 | " bias=bias,\n", 1580 | " weight_modulation=weight_modulation,\n", 1581 | " siren_hidden_layers=siren_hidden_layers,\n", 1582 | " demodulation=demodulation,\n", 1583 | " ).to(device)\n", 1584 | " \n", 1585 | " # use the modules apply function to recursively apply the initialization\n", 1586 | " Generator.apply(init_normal)\n", 1587 | "\n", 1588 | " num_patches_x = int(image_size//out_patch_size)\n", 1589 | "\n", 1590 | " if os.path.exists(f'{experiment_folder_name}/weights/Generator.pth'):\n", 1591 | " Generator = torch.load(f'{experiment_folder_name}/weights/Generator.pth')\n", 1592 | "\n", 1593 | " wandb.watch(Generator)\n", 1594 | "\n", 1595 | "elif generator_type == \"cnn\":\n", 1596 | " cnn_generator = CNNGenerator().to(device)\n", 1597 | "\n", 1598 | " cnn_generator.apply(init_normal)\n", 1599 | "\n", 1600 | " if os.path.exists(f'{experiment_folder_name}/weights/cnn_generator.pth'):\n", 1601 | " cnn_generator = torch.load(f'{experiment_folder_name}/weights/cnn_generator.pth')\n", 1602 | "\n", 1603 | " wandb.watch(cnn_generator)\n", 1604 | "\n", 1605 | "# Create the three types of discriminators\n", 1606 | "if discriminator_type == \"vitgan\":\n", 1607 | " Discriminator = ViT(discriminator=True,\n", 1608 | " patch_size=patch_size*2,\n", 1609 | " stride_size=patch_size,\n", 1610 | " n_classes=1, \n", 1611 | " num_heads=num_heads,\n", 1612 | " depth=depth,\n", 1613 | " forward_drop_p=dropout_p,\n", 1614 | " ).to(device)\n", 1615 | " \n", 1616 | " Discriminator.apply(init_normal)\n", 1617 | " \n", 1618 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n", 1619 | " Discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n", 1620 | "\n", 1621 | " wandb.watch(Discriminator)\n", 1622 | "\n", 1623 | "elif discriminator_type == \"cnn\":\n", 1624 | " cnn_discriminator = CNN().to(device)\n", 1625 | "\n", 1626 | " cnn_discriminator.apply(init_normal)\n", 1627 | "\n", 1628 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n", 1629 | " cnn_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n", 1630 | "\n", 1631 | " wandb.watch(cnn_discriminator)\n", 1632 | "\n", 1633 | "elif discriminator_type == \"stylegan2\":\n", 1634 | " stylegan2_discriminator = StyleGanDiscriminator(image_size=32).to(device)\n", 1635 | "\n", 1636 | " # 
stylegan2_discriminator.apply(init_normal)\n", 1637 | "\n", 1638 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n", 1639 | " stylegan2_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n", 1640 | "\n", 1641 | " wandb.watch(stylegan2_discriminator)" 1642 | ] 1643 | }, 1644 | { 1645 | "cell_type": "markdown", 1646 | "metadata": { 1647 | "id": "j9EdXO0K8qFW" 1648 | }, 1649 | "source": [ 1650 | "# Testing the generator\n", 1651 | "\n", 1652 | "Train to match fixed latent values to fixed images" 1653 | ] 1654 | }, 1655 | { 1656 | "cell_type": "code", 1657 | "execution_count": null, 1658 | "metadata": { 1659 | "id": "1BsgyaZMMi2f" 1660 | }, 1661 | "outputs": [], 1662 | "source": [ 1663 | "total_steps = 0 # Since the whole image is our dataset, this just means 500 gradient descent steps.\n", 1664 | "steps_til_summary = 50\n", 1665 | "\n", 1666 | "if generator_type == \"vitgan\":\n", 1667 | " params = Generator.parameters()\n", 1668 | "else:\n", 1669 | " # z = torch.randn(batch_size, latent_dim, 1, 1, device=device)\n", 1670 | " params = list(cnn_generator.parameters())\n", 1671 | "optim = torch.optim.Adam(lr=lr, params=params)\n", 1672 | "ema = ExponentialMovingAverage(params, decay=0.995)\n", 1673 | "\n", 1674 | "ground_truth, _ = next(iter(trainloader))\n", 1675 | "ground_truth = ground_truth.permute(0, 2, 3, 1).view((-1, image_size**2, out_features))\n", 1676 | "ground_truth = ground_truth.to(device)\n", 1677 | "\n", 1678 | "z = torch.randn([batch_size, latent_dim]).to(device)" 1679 | ] 1680 | }, 1681 | { 1682 | "cell_type": "code", 1683 | "execution_count": null, 1684 | "metadata": { 1685 | "id": "Swv4ZQt94833" 1686 | }, 1687 | "outputs": [], 1688 | "source": [ 1689 | "for step in range(total_steps):\n", 1690 | " if generator_type == \"vitgan\":\n", 1691 | " model_output = Generator(z)\n", 1692 | " elif generator_type == \"cnn\":\n", 1693 | " model_output = cnn_generator(z)\n", 1694 | " model_output = model_output.permute([0, 2, 3, 1]).view([-1, image_size**2, out_features])\n", 1695 | " loss = ((model_output - ground_truth)**2).mean()\n", 1696 | " \n", 1697 | " if not step % steps_til_summary:\n", 1698 | " print(\"Step %d, Total loss %0.6f\" % (step, loss))\n", 1699 | "\n", 1700 | " fig, axes = plt.subplots(2,8, figsize=(24,6))\n", 1701 | " for i in range(8):\n", 1702 | " j = np.random.randint(0, batch_size-1)\n", 1703 | " img = model_output[j].cpu().view(32,32,3).detach().numpy()\n", 1704 | " img -= img.min()\n", 1705 | " img /= img.max()\n", 1706 | " axes[0,i].imshow(img)\n", 1707 | " g_img = ground_truth[j].cpu().view(32,32,3).detach().numpy()\n", 1708 | " g_img -= g_img.min()\n", 1709 | " g_img /= g_img.max()\n", 1710 | " axes[1,i].imshow(g_img)\n", 1711 | "\n", 1712 | " plt.show()\n", 1713 | "\n", 1714 | " optim.zero_grad()\n", 1715 | " loss.backward()\n", 1716 | " optim.step()\n", 1717 | " ema.update()" 1718 | ] 1719 | }, 1720 | { 1721 | "cell_type": "markdown", 1722 | "metadata": { 1723 | "id": "pcWiPCochjJ7" 1724 | }, 1725 | "source": [ 1726 | "# Training" 1727 | ] 1728 | }, 1729 | { 1730 | "cell_type": "code", 1731 | "execution_count": null, 1732 | "metadata": { 1733 | "id": "Ad_UXnQJCOeN" 1734 | }, 1735 | "outputs": [], 1736 | "source": [ 1737 | "os.makedirs(f\"{experiment_folder_name}/weights\", exist_ok = True)\n", 1738 | "os.makedirs(f\"{experiment_folder_name}/samples\", exist_ok = True)\n", 1739 | "\n", 1740 | "# Loss function\n", 1741 | "criterion = nn.BCEWithLogitsLoss()\n", 1742 | "\n", 1743 | "if 
discriminator_type == \"cnn\": discriminator = cnn_discriminator\n", 1744 | "elif discriminator_type == \"stylegan2\": discriminator = stylegan2_discriminator\n", 1745 | "elif discriminator_type == \"vitgan\": discriminator = Discriminator\n", 1746 | "\n", 1747 | "if generator_type == \"cnn\":\n", 1748 | " params = cnn_generator.parameters()\n", 1749 | "else:\n", 1750 | " params = Generator.parameters()\n", 1751 | "optim_g = torch.optim.Adam(lr=lr, params=params, betas=beta)\n", 1752 | "optim_d = torch.optim.Adam(lr=lr_dis, params=discriminator.parameters(), betas=beta)\n", 1753 | "ema = ExponentialMovingAverage(params, decay=0.995)\n", 1754 | "\n", 1755 | "fixed_noise = torch.FloatTensor(np.random.normal(0, 1, (16, latent_dim))).to(device)\n", 1756 | "\n", 1757 | "discriminator_f_img = torch.zeros([batch_size, 3, image_size, image_size]).to(device)\n", 1758 | "\n", 1759 | "trainset_len = len(trainloader.dataset)\n", 1760 | "\n", 1761 | "step = 0\n", 1762 | "for epoch in range(epochs):\n", 1763 | " for batch_id, batch in enumerate(trainloader):\n", 1764 | " step += 1\n", 1765 | "\n", 1766 | " # Train discriminator\n", 1767 | "\n", 1768 | " # Forward + Backward with real images\n", 1769 | " r_img = batch[0].to(device)\n", 1770 | " r_logit = discriminator(r_img).flatten()\n", 1771 | " r_label = torch.ones(r_logit.shape[0]).to(device)\n", 1772 | "\n", 1773 | " lossD_real = criterion(r_logit, r_label)\n", 1774 | " \n", 1775 | " lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False))\n", 1776 | "\n", 1777 | " # Forward + Backward with fake images\n", 1778 | " latent_vector = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))).to(device)\n", 1779 | "\n", 1780 | " if generator_type == \"vitgan\":\n", 1781 | " f_img = Generator(latent_vector)\n", 1782 | " f_img = f_img.reshape([-1, image_size, image_size, out_features])\n", 1783 | " f_img = f_img.permute(0, 3, 1, 2)\n", 1784 | " else:\n", 1785 | " model_output = cnn_generator(latent_vector)\n", 1786 | " f_img = model_output\n", 1787 | " \n", 1788 | " assert f_img.size(0) == batch_size, f_img.shape\n", 1789 | " assert f_img.size(1) == out_features, f_img.shape\n", 1790 | " assert f_img.size(2) == image_size, f_img.shape\n", 1791 | " assert f_img.size(3) == image_size, f_img.shape\n", 1792 | "\n", 1793 | " f_label = torch.zeros(batch_size).to(device)\n", 1794 | " # Save the a single generated image to the discriminator training data\n", 1795 | " if batch_size_history_discriminator:\n", 1796 | " discriminator_f_img[step % batch_size] = f_img[0].detach()\n", 1797 | " f_logit_history = discriminator(discriminator_f_img).flatten()\n", 1798 | " lossD_fake_history = criterion(f_logit_history, f_label)\n", 1799 | " else: lossD_fake_history = 0\n", 1800 | " # Train the discriminator on the images, generated only from this batch\n", 1801 | " f_logit = discriminator(f_img.detach()).flatten()\n", 1802 | " lossD_fake = criterion(f_logit, f_label)\n", 1803 | " \n", 1804 | " lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img, do_augment=False))\n", 1805 | " \n", 1806 | " f_noise_input = torch.FloatTensor(np.random.rand(*f_img.shape)*2 - 1).to(device)\n", 1807 | " f_noise_logit = discriminator(f_noise_input).flatten()\n", 1808 | " lossD_noise = criterion(f_noise_logit, f_label)\n", 1809 | "\n", 1810 | " lossD = lossD_real * 0.5 +\\\n", 1811 | " lossD_fake * 0.5 +\\\n", 1812 | " lossD_fake_history * lambda_lossD_history +\\\n", 1813 | " lossD_noise * lambda_lossD_noise +\\\n", 1814 | " lossD_bCR_real * lambda_bCR_real 
+\\\n", 1815 | " lossD_bCR_fake * lambda_bCR_fake\n", 1816 | "\n", 1817 | " optim_d.zero_grad()\n", 1818 | " lossD.backward()\n", 1819 | " optim_d.step()\n", 1820 | " \n", 1821 | " # Train Generator\n", 1822 | "\n", 1823 | " if generator_type == \"vitgan\":\n", 1824 | " f_img = Generator(latent_vector)\n", 1825 | " f_img = f_img.reshape([-1, image_size, image_size, out_features])\n", 1826 | " f_img = f_img.permute(0, 3, 1, 2)\n", 1827 | " else:\n", 1828 | " model_output = cnn_generator(latent_vector)\n", 1829 | " f_img = model_output\n", 1830 | " \n", 1831 | " assert f_img.size(0) == batch_size\n", 1832 | " assert f_img.size(1) == out_features\n", 1833 | " assert f_img.size(2) == image_size\n", 1834 | " assert f_img.size(3) == image_size\n", 1835 | "\n", 1836 | " f_logit = discriminator(f_img).flatten()\n", 1837 | " r_label = torch.ones(batch_size).to(device)\n", 1838 | " lossG_main = criterion(f_logit, r_label)\n", 1839 | " \n", 1840 | " lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty\n", 1841 | " lossG = lossG_main + lossG_diversity\n", 1842 | " \n", 1843 | " optim_g.zero_grad()\n", 1844 | " lossG.backward()\n", 1845 | " optim_g.step()\n", 1846 | " ema.update()\n", 1847 | "\n", 1848 | " writer.add_scalar(\"Loss/Generator\", lossG_main, step)\n", 1849 | " writer.add_scalar(\"Loss/Gen(diversity)\", lossG_diversity, step)\n", 1850 | " writer.add_scalar(\"Loss/Dis(real)\", lossD_real, step)\n", 1851 | " writer.add_scalar(\"Loss/Dis(fake)\", lossD_fake, step)\n", 1852 | " writer.add_scalar(\"Loss/Dis(fake_history)\", lossD_fake_history, step)\n", 1853 | " writer.add_scalar(\"Loss/Dis(noise)\", lossD_noise, step)\n", 1854 | " writer.add_scalar(\"Loss/Dis(bCR_fake)\", lossD_bCR_fake * lambda_bCR_fake, step)\n", 1855 | " writer.add_scalar(\"Loss/Dis(bCR_real)\", lossD_bCR_real * lambda_bCR_real, step)\n", 1856 | " writer.flush()\n", 1857 | "\n", 1858 | " wandb.log({\n", 1859 | " 'Generator': lossG_main,\n", 1860 | " 'Gen(diversity)': lossG_diversity,\n", 1861 | " 'Dis(real)': lossD_real,\n", 1862 | " 'Dis(fake)': lossD_fake,\n", 1863 | " 'Dis(fake_history)': lossD_fake_history,\n", 1864 | " 'Dis(noise)': lossD_noise,\n", 1865 | " 'Dis(bCR_fake)': lossD_bCR_fake * lambda_bCR_fake,\n", 1866 | " 'Dis(bCR_real)': lossD_bCR_real * lambda_bCR_real\n", 1867 | " })\n", 1868 | "\n", 1869 | " if batch_id%20 == 0:\n", 1870 | " print(f'epoch {epoch}/{epochs}; batch {batch_id}/{int(trainset_len/batch_size)}')\n", 1871 | " print(f'Generator: {\"{:.3f}\".format(float(lossG_main))}, '+\\\n", 1872 | " f'Gen(diversity): {\"{:.3f}\".format(float(lossG_diversity))}, '+\\\n", 1873 | " f'Dis(real): {\"{:.3f}\".format(float(lossD_real))}, '+\\\n", 1874 | " f'Dis(fake): {\"{:.3f}\".format(float(lossD_fake))}, '+\\\n", 1875 | " f'Dis(fake_history): {\"{:.3f}\".format(float(lossD_fake_history))}, '+\\\n", 1876 | " f'Dis(noise) {\"{:.3f}\".format(float(lossD_noise))}, '+\\\n", 1877 | " f'Dis(bCR_fake): {\"{:.3f}\".format(float(lossD_bCR_fake * lambda_bCR_fake))}, '+\\\n", 1878 | " f'Dis(bCR_real): {\"{:.3f}\".format(float(lossD_bCR_real * lambda_bCR_real))}')\n", 1879 | "\n", 1880 | " # Plot 8 randomly selected samples\n", 1881 | " fig, axes = plt.subplots(1,8, figsize=(24,3))\n", 1882 | " output = f_img.permute(0, 2, 3, 1)\n", 1883 | " for i in range(8):\n", 1884 | " j = np.random.randint(0, batch_size-1)\n", 1885 | " img = output[j].cpu().view(32,32,3).detach().numpy()\n", 1886 | " img -= img.min()\n", 1887 | " img /= img.max()\n", 1888 | " axes[i].imshow(img)\n", 1889 | " plt.show()\n", 1890 | "\n", 
1891 | " # if step % sample_interval == 0:\n", 1892 | " if generator_type == \"vitgan\":\n", 1893 | " Generator.eval()\n", 1894 | " vis = Generator(fixed_noise)\n", 1895 | " vis = vis.reshape([-1, image_size, image_size, out_features])\n", 1896 | " vis = vis.permute(0, 3, 1, 2)\n", 1897 | " else:\n", 1898 | " model_output = cnn_generator(fixed_noise)\n", 1899 | " vis = model_output\n", 1900 | "\n", 1901 | " assert vis.shape[0] == fixed_noise.shape[0], f'vis.shape[0] is {vis.shape[0]}, but should be {fixed_noise.shape[0]}'\n", 1902 | " assert vis.shape[1] == out_features, f'vis.shape[1] is {vis.shape[1]}, but should be {out_features}'\n", 1903 | " assert vis.shape[2] == image_size, f'vis.shape[2] is {vis.shape[2]}, but should be {image_size}'\n", 1904 | " assert vis.shape[3] == image_size, f'vis.shape[3] is {vis.shape[3]}, but should be {image_size}'\n", 1905 | " \n", 1906 | " vis.detach().cpu()\n", 1907 | " vis = make_grid(vis, nrow = 4, padding = 5, normalize = True)\n", 1908 | " writer.add_image(f'Generated/epoch_{epoch}', vis)\n", 1909 | " wandb.log({'examples': wandb.Image(vis)})\n", 1910 | "\n", 1911 | " vis = T.ToPILImage()(vis)\n", 1912 | " vis.save(f'{experiment_folder_name}/samples/vis{epoch}.jpg')\n", 1913 | " if generator_type == \"vitgan\":\n", 1914 | " Generator.train()\n", 1915 | " else:\n", 1916 | " cnn_generator.train()\n", 1917 | " print(f\"Save sample to {experiment_folder_name}/samples/vis{epoch}.jpg\")\n", 1918 | "\n", 1919 | " # Save the checkpoints.\n", 1920 | " if generator_type == \"vitgan\":\n", 1921 | " torch.save(Generator, f'{experiment_folder_name}/weights/Generator.pth')\n", 1922 | " elif generator_type == \"cnn\":\n", 1923 | " torch.save(cnn_generator, f'{experiment_folder_name}/weights/cnn_generator.pth')\n", 1924 | " torch.save(discriminator, f'{experiment_folder_name}/weights/discriminator.pth')\n", 1925 | " print(\"Save model state.\")\n", 1926 | "\n", 1927 | "writer.close()" 1928 | ] 1929 | }, 1930 | { 1931 | "cell_type": "code", 1932 | "execution_count": null, 1933 | "metadata": { 1934 | "id": "LWDnspPtsot5" 1935 | }, 1936 | "outputs": [], 1937 | "source": [] 1938 | } 1939 | ], 1940 | "metadata": { 1941 | "accelerator": "GPU", 1942 | "colab": { 1943 | "collapsed_sections": [], 1944 | "name": "ViTGAN-pytorch.ipynb", 1945 | "private_outputs": true, 1946 | "provenance": [] 1947 | }, 1948 | "kernelspec": { 1949 | "display_name": "Python 3", 1950 | "language": "python", 1951 | "name": "python3" 1952 | }, 1953 | "language_info": { 1954 | "codemirror_mode": { 1955 | "name": "ipython", 1956 | "version": 3 1957 | }, 1958 | "file_extension": ".py", 1959 | "mimetype": "text/x-python", 1960 | "name": "python", 1961 | "nbconvert_exporter": "python", 1962 | "pygments_lexer": "ipython3", 1963 | "version": "3.7.4" 1964 | }, 1965 | "pycharm": { 1966 | "stem_cell": { 1967 | "cell_type": "raw", 1968 | "metadata": { 1969 | "collapsed": false 1970 | }, 1971 | "source": [] 1972 | } 1973 | } 1974 | }, 1975 | "nbformat": 4, 1976 | "nbformat_minor": 0 1977 | } 1978 | --------------------------------------------------------------------------------
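Note on the `diversity_loss` cell in ViTGAN_pytorch.ipynb above: the double loop scores the first `num_images_to_calculate_on` generated images pairwise, taking `sigmoid(scale_factor * cosine_similarity)` of each flattened pair and averaging over the unordered pairs. The sketch below is an equivalent vectorised form, offered only as a clarification of what that loop computes; the function name and the small norm guard are additions, not part of the notebook.

```python
import torch

def diversity_loss_vectorized(images, num_images=10, scale_factor=5.0):
    # Mean over unordered image pairs of sigmoid(scale * cosine similarity),
    # matching the pairwise loop in the diversity_loss cell (up to the eps guard).
    x = images[:num_images].flatten(start_dim=1)              # (n, C*H*W)
    x = x / x.norm(dim=1, keepdim=True).clamp_min(1e-8)       # row-normalise each image
    cos = x @ x.t()                                           # (n, n) cosine similarities
    iu = torch.triu_indices(x.size(0), x.size(0), offset=1)   # unordered pairs, i < j
    return torch.sigmoid(scale_factor * cos[iu[0], iu[1]]).mean()
```

The result stays differentiable with respect to `images`, so it could stand in for the loop behind `lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty`, up to the epsilon guard on the norms.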
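Note on the `lossD_bCR_real` / `lossD_bCR_fake` terms in the training cell: they are balanced consistency regularization, i.e. an MSE penalty between the discriminator's logits on a DiffAugment-ed batch and on the same batch with `do_augment=False`, computed separately for real and generated images and weighted by `lambda_bCR_real` / `lambda_bCR_fake`. The helper below is a self-contained sketch of just these two terms, assuming a discriminator with the `do_augment` flag exposed by the ViT, CNN and StyleGAN2 discriminators above; the function name and the default weights are illustrative, not the notebook's settings.

```python
import torch
import torch.nn.functional as F

def bcr_penalty(discriminator, real_imgs, fake_imgs, lambda_real=1.0, lambda_fake=1.0):
    # Balanced consistency regularization: the discriminator should assign the same
    # logit to an image whether or not DiffAugment is applied to it.
    r_aug   = discriminator(real_imgs, do_augment=True).flatten()
    r_clean = discriminator(real_imgs, do_augment=False).flatten()
    f_aug   = discriminator(fake_imgs.detach(), do_augment=True).flatten()
    f_clean = discriminator(fake_imgs.detach(), do_augment=False).flatten()
    return lambda_real * F.mse_loss(r_aug, r_clean) + lambda_fake * F.mse_loss(f_aug, f_clean)
```

In the training cell the augmented logits (`r_logit`, `f_logit`) are reused from the main real/fake loss terms rather than recomputed, so the consistency penalty there also feeds gradients back through those same forward passes.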