├── models
│   ├── __init__.py
│   ├── stylegan2_discriminator.py
│   ├── cnn_discriminator.py
│   ├── cnn_generator.py
│   ├── multi_head.py
│   ├── utils.py
│   ├── diffaugment.py
│   ├── vitgan_discriminator.py
│   └── vitgan_generator.py
├── LICENSE
├── README.md
├── main.py
└── ViTGAN_pytorch.ipynb
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cnn_discriminator import CNN
2 | from .cnn_generator import CNNGenerator
3 | from .stylegan2_discriminator import StyleGanDiscriminator
4 | from .vitgan_discriminator import ViT
5 | from .vitgan_generator import GeneratorViT
6 | from .diffaugment import DiffAugment
7 | from .utils import spectral_norm
8 |
9 | __all__ = ['CNN', 'CNNGenerator', 'StyleGanDiscriminator', 'ViT', 'GeneratorViT', 'DiffAugment', 'spectral_norm']
10 |
--------------------------------------------------------------------------------
/models/stylegan2_discriminator.py:
--------------------------------------------------------------------------------
1 | from stylegan2_pytorch import stylegan2_pytorch
2 | from .diffaugment import DiffAugment
3 |
4 | # StyleGAN2 Discriminator
5 |
6 | class StyleGanDiscriminator(stylegan2_pytorch.Discriminator):
7 | def __init__(self,
8 | diffaugment='color,translation,cutout',
9 | **kwargs):
10 | self.diffaugment = diffaugment
11 | super().__init__(**kwargs)
12 | def forward(self, img, do_augment=True):
13 | if do_augment:
14 | img = DiffAugment(img, policy=self.diffaugment)
15 | out, _ = super().forward(img)
16 | return out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Teodor Toshkov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/cnn_discriminator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .diffaugment import DiffAugment
3 |
4 | # CNN Discriminator
5 |
6 | class CNN(nn.Sequential):
7 | def __init__(self,
8 | diffaugment='color,translation,cutout',
9 | **kwargs):
10 | self.diffaugment = diffaugment
11 | super().__init__(
12 | nn.Conv2d(3,32,kernel_size=3,padding=1),
13 | nn.ReLU(),
14 | nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),
15 | nn.ReLU(),
16 | nn.MaxPool2d(2,2),
17 |
18 | nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
19 | nn.ReLU(),
20 | nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
21 | nn.ReLU(),
22 | nn.MaxPool2d(2,2),
23 |
24 | nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
25 | nn.ReLU(),
26 | nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
27 | nn.ReLU(),
28 | nn.MaxPool2d(2,2),
29 |
30 | nn.Flatten(),
31 | nn.Linear(256*4*4,1024),
32 | nn.ReLU(),
33 | nn.Linear(1024,512),
34 | nn.ReLU(),
35 | nn.Linear(512,1)
36 | )
37 |
38 | def forward(self, img, do_augment=True):
39 | if do_augment:
40 | img = DiffAugment(img, policy=self.diffaugment)
41 | return super().forward(img)
--------------------------------------------------------------------------------
/models/cnn_generator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | # CNN Generator
4 |
5 | class CNNGenerator(nn.Module):
6 | def __init__(self, hidden_size, latent_dim):
7 | super(CNNGenerator, self).__init__()
8 | self.hidden_size = hidden_size
9 | self.w = nn.Linear(latent_dim, hidden_size * 2 * 4 * 4, bias=False)
10 | self.main = nn.Sequential(
11 | # input is Z, going into a convolution
12 | nn.BatchNorm2d(hidden_size * 2),
13 | nn.ReLU(True),
14 |             # state size. (hidden_size*2) x 4 x 4
15 | nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
16 | nn.BatchNorm2d(hidden_size),
17 | nn.ReLU(True),
18 |             # state size. (hidden_size) x 8 x 8
19 | nn.ConvTranspose2d( hidden_size, hidden_size // 2, 4, 2, 1, bias=False),
20 | nn.BatchNorm2d(hidden_size // 2),
21 | nn.ReLU(True),
22 |             # state size. (hidden_size//2) x 16 x 16
23 | nn.ConvTranspose2d( hidden_size // 2, hidden_size // 4, 4, 2, 1, bias=False),
24 | nn.BatchNorm2d(hidden_size // 4),
25 | nn.ReLU(True),
26 |             # state size. (hidden_size//4) x 32 x 32
27 | nn.ConvTranspose2d( hidden_size // 4, 3, 3, 1, 1, bias=False),
28 | nn.Tanh(),
29 |             # state size. (3) x 32 x 32
30 | )
31 |
32 | def forward(self, input):
33 | input = self.w(input).view((-1, self.hidden_size * 2, 4, 4))
34 | return self.main(input)
35 |
--------------------------------------------------------------------------------
/models/multi_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from einops import rearrange
4 | import torch.nn.functional as F
5 |
6 | from .utils import spectral_norm
7 |
8 | class MultiHeadAttention(nn.Module):
9 | def __init__(self, emb_size=384, num_heads=4, dropout=0, discriminator=False, **kwargs):
10 | super().__init__()
11 | self.emb_size = emb_size
12 | self.num_heads = num_heads
13 | self.discriminator = discriminator
14 | # fuse the queries, keys and values in one matrix
15 | self.qkv = nn.Linear(emb_size, emb_size * 3)
16 | self.att_drop = nn.Dropout(dropout)
17 | self.projection = nn.Linear(emb_size, emb_size)
18 | if self.discriminator:
19 | self.qkv = spectral_norm(self.qkv)
20 | self.projection = spectral_norm(self.projection)
21 |
22 | def forward(self, x, mask=None):
23 | # split keys, queries and values in num_heads
24 | qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
25 | queries, keys, values = qkv[0], qkv[1], qkv[2]
26 | if self.discriminator:
27 |             # negative L2 distances, so that closer query-key pairs get higher attention
28 |             energy = -torch.cdist(queries.contiguous(), keys.contiguous(), p=2)
29 | else:
30 |             # dot product over the feature dimension
31 |             energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
32 |
33 | if mask is not None:
34 | fill_value = torch.finfo(torch.float32).min
35 |             energy = energy.masked_fill(~mask, fill_value)
36 |
37 |         scaling = self.emb_size ** (1/2)
38 |         att = F.softmax(energy / scaling, dim=-1)  # scale before the softmax, not after
39 | att = self.att_drop(att)
40 | # sum up over the third axis
41 | out = torch.einsum('bhal, bhlv -> bhav ', att, values)
42 | out = rearrange(out, "b h n d -> b n (h d)")
43 | out = self.projection(out)
44 | return out
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | import numpy as np
5 |
6 | # Improved Spectral Normalization (ISN)
7 | # Reference code: https://github.com/koshian2/SNGAN
8 | # When updating the weights, normalize the weights' norm to its norm at initialization.
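# Usage sketch (not from the original file): wrapping a layer applies the
# normalization on every forward pass, e.g.
#   fc = spectral_norm(nn.Linear(384, 384))
#   out = fc(torch.randn(8, 384))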
9 |
10 | def l2normalize(v, eps=1e-4):
11 | return v / (v.norm() + eps)
12 |
13 | class spectral_norm(nn.Module):
14 | def __init__(self, module, name='weight', power_iterations=1):
15 | super().__init__()
16 | self.module = module
17 | self.name = name
18 | self.power_iterations = power_iterations
19 | if not self._made_params():
20 | self._make_params()
21 | self.w_init_sigma = None
22 |         self.w_initialized = False
23 |
24 | def _update_u_v(self):
25 | u = getattr(self.module, self.name + "_u")
26 | v = getattr(self.module, self.name + "_v")
27 | w = getattr(self.module, self.name + "_bar")
28 |
29 | height = w.data.shape[0]
30 | _w = w.view(height, -1)
31 | for _ in range(self.power_iterations):
32 | v = l2normalize(torch.matmul(_w.t(), u))
33 | u = l2normalize(torch.matmul(_w, v))
34 |
35 | sigma = u.dot((_w).mv(v))
36 |         if not self.w_initialized:
37 |             self.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())
38 |             self.w_initialized = True
39 | setattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(sigma.device) * w / sigma.expand_as(w))
40 |
41 | def _made_params(self):
42 | try:
43 | getattr(self.module, self.name + "_u")
44 | getattr(self.module, self.name + "_v")
45 | getattr(self.module, self.name + "_bar")
46 | return True
47 | except AttributeError:
48 | return False
49 |
50 | def _make_params(self):
51 | w = getattr(self.module, self.name)
52 |
53 | height = w.data.shape[0]
54 | width = w.view(height, -1).data.shape[1]
55 |
56 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
57 |         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)  # v matches the flattened weight's width
58 | u.data = l2normalize(u.data)
59 | v.data = l2normalize(v.data)
60 | w_bar = Parameter(w.data)
61 |
62 | del self.module._parameters[self.name]
63 | self.module.register_parameter(self.name + "_u", u)
64 | self.module.register_parameter(self.name + "_v", v)
65 | self.module.register_parameter(self.name + "_bar", w_bar)
66 |
67 | def forward(self, *args):
68 | self._update_u_v()
69 | return self.module.forward(*args)
--------------------------------------------------------------------------------
/models/diffaugment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | # https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py
5 |
6 | def DiffAugment(x, policy='', channels_first=True):
7 | if policy:
8 | if not channels_first:
9 | x = x.permute(0, 3, 1, 2)
10 | for p in policy.split(','):
11 | for f in AUGMENT_FNS[p]:
12 | x = f(x)
13 | if not channels_first:
14 | x = x.permute(0, 2, 3, 1)
15 | x = x.contiguous()
16 | return x
17 |
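# Usage sketch (illustrative): augment a batch of NCHW images with all three policies:
#   x = DiffAugment(x, policy='color,translation,cutout')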
18 |
19 | def rand_brightness(x):
20 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
21 | return x
22 |
23 |
24 | def rand_saturation(x):
25 | x_mean = x.mean(dim=1, keepdim=True)
26 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
27 | return x
28 |
29 |
30 | def rand_contrast(x):
31 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
32 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
33 | return x
34 |
35 |
36 | def rand_translation(x, ratio=0.1):
37 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
38 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
39 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
40 | grid_batch, grid_x, grid_y = torch.meshgrid(
41 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
42 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
43 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
44 | )
45 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
46 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
47 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
48 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
49 | return x
50 |
51 |
52 | def rand_cutout(x, ratio=0.3):
53 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
54 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
55 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
56 | grid_batch, grid_x, grid_y = torch.meshgrid(
57 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
58 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
59 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
60 | )
61 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
62 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
63 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
64 | mask[grid_batch, grid_x, grid_y] = 0
65 | x = x * mask.unsqueeze(1)
66 | return x
67 |
68 |
69 | AUGMENT_FNS = {
70 | 'color': [rand_brightness, rand_saturation, rand_contrast],
71 | 'translation': [rand_translation],
72 | 'cutout': [rand_cutout],
73 | }
--------------------------------------------------------------------------------
/models/vitgan_discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from einops import repeat
4 | from einops.layers.torch import Rearrange
5 |
6 | from .diffaugment import DiffAugment
7 | from .multi_head import MultiHeadAttention
8 | from .utils import spectral_norm
9 |
10 | # Discriminator
11 |
12 | class PatchEmbedding(nn.Module):
13 | def __init__(self, in_channels=3, patch_size=4, stride_size=4, emb_size=384, image_size=32, batch_size=64):
14 | super().__init__()
15 | self.patch_size = patch_size
16 | self.projection = nn.Sequential(
17 | # using a conv layer instead of a linear one -> performance gains
18 | spectral_norm(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)),
19 | Rearrange('b e (h) (w) -> b (h w) e'),
20 | )
21 | self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
22 | num_patches = ((image_size-patch_size+stride_size) // stride_size) **2 + 1
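        # e.g. with the values used in main.py (image_size=32, patch_size=8, stride_size=4):
        # (32 - 8 + 4) // 4 = 7 patches per side -> 49 patches + 1 cls token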
23 | self.positions = nn.Parameter(torch.randn(num_patches, emb_size))
24 | self.batch_size = batch_size
25 |
26 | def forward(self, x):
27 | b, _, _, _ = x.shape
28 | x = self.projection(x)
29 | cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
30 | # prepend the cls token to the input
31 | x = torch.cat([cls_tokens, x], dim=1)
32 | # add position embedding
33 | x += torch.sin(self.positions)
34 | return x
35 |
36 | class ResidualAdd(nn.Module):
37 | def __init__(self, fn):
38 | super().__init__()
39 | self.fn = fn
40 |
41 | def forward(self, x, **kwargs):
42 | res = x
43 | x = self.fn(x, **kwargs)
44 | x += res
45 | return x
46 |
47 | class DiscriminatorTransformerEncoderBlock(nn.Sequential):
48 | def __init__(self,
49 | emb_size=384,
50 | drop_p=0.,
51 | forward_expansion=4,
52 | forward_drop_p=0.,
53 | **kwargs):
54 | super().__init__(
55 | ResidualAdd(nn.Sequential(
56 | nn.LayerNorm(emb_size),
57 | MultiHeadAttention(emb_size, **kwargs),
58 | nn.Dropout(drop_p)
59 | )),
60 | ResidualAdd(nn.Sequential(
61 | nn.LayerNorm(emb_size),
62 | nn.Sequential(
63 | spectral_norm(nn.Linear(emb_size, forward_expansion * emb_size)),
64 | nn.GELU(),
65 | nn.Dropout(forward_drop_p),
66 | spectral_norm(nn.Linear(forward_expansion * emb_size, emb_size)),
67 | ),
68 | nn.Dropout(drop_p)
69 | )
70 | ))
71 |
72 | class DiscriminatorTransformerEncoder(nn.Sequential):
73 | def __init__(self, depth=4, **kwargs):
74 | super().__init__(*[DiscriminatorTransformerEncoderBlock(**kwargs) for _ in range(depth)])
75 |
76 | class ClassificationHead(nn.Sequential):
77 | def __init__(self, emb_size=384, class_size_1=4098, class_size_2=1024, class_size_3=512, n_classes=10):
78 | super().__init__(
79 | nn.LayerNorm(emb_size),
80 | spectral_norm(nn.Linear(emb_size, class_size_1)),
81 | nn.GELU(),
82 | spectral_norm(nn.Linear(class_size_1, class_size_2)),
83 | nn.GELU(),
84 | spectral_norm(nn.Linear(class_size_2, class_size_3)),
85 | nn.GELU(),
86 | spectral_norm(nn.Linear(class_size_3, n_classes)),
87 |             # no activation after the final projection: the output is a raw logit for BCEWithLogitsLoss
88 |         )
89 |
90 | def forward(self, x):
91 | # Take only the cls token outputs
92 | x = x[:, 0, :]
93 | return super().forward(x)
94 |
95 | class ViT(nn.Sequential):
96 | def __init__(self,
97 | in_channels=3,
98 | patch_size=4,
99 | stride_size=4,
100 | emb_size=384,
101 | image_size=32,
102 | depth=4,
103 | n_classes=1,
104 | diffaugment='color,translation,cutout',
105 | **kwargs):
106 | self.diffaugment = diffaugment
107 | super().__init__(
108 | PatchEmbedding(in_channels, patch_size, stride_size, emb_size, image_size),
109 | DiscriminatorTransformerEncoder(depth, emb_size=emb_size, **kwargs),
110 | ClassificationHead(emb_size, n_classes=n_classes)
111 | )
112 |
113 | def forward(self, img, do_augment=True):
114 | if do_augment:
115 | img = DiffAugment(img, policy=self.diffaugment)
116 | return super().forward(img)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ViTGAN-pytorch
2 | A PyTorch implementation of [ViTGAN: Training GANs with Vision Transformers](https://arxiv.org/pdf/2107.04589v1.pdf)
3 |
4 | [](https://colab.research.google.com/drive/1kJJw6BYW01HgooCZ2zUDt54e1mXqITXH?usp=sharing)
5 |
6 | ## TODO:
7 | 1. [x] Use vectorized L2 distance in attention for the **Discriminator**
8 | 2. [x] Overlapping Image Patches
9 | 3. [x] DiffAugment
10 | 4. [x] Self-modulated LayerNorm (SLN)
11 | 5. [x] Implicit Neural Representation for Patch Generation
12 | 6. [x] Exponential Moving Average (EMA)
13 | 7. [x] Balanced Consistency Regularization (bCR)
14 | 8. [x] Improved Spectral Normalization (ISN)
15 | 9. [x] Equalized Learning Rate
16 | 10. [x] Weight Modulation
17 |
18 | ## Dependencies
19 |
20 | - Python3
21 | - einops
22 | - pytorch_ema
23 | - stylegan2-pytorch
24 | - tensorboard
25 | - wandb
26 |
27 | ``` bash
28 | pip install einops git+https://github.com/fadel/pytorch_ema stylegan2-pytorch tensorboard wandb
29 | ```
30 |
31 | ## **TLDR:**
32 |
33 | Train the model with the proposed parameters:
34 |
35 | ``` bash
36 | python main.py
37 | ```
38 |
39 | Monitor training with TensorBoard:
40 |
41 | ``` bash
42 | tensorboard --logdir runs/
43 | ```
44 |
45 | ***
46 |
47 | Running ```main.py``` without arguments uses the parameters proposed in the paper for the CIFAR-10 dataset:
48 |
49 | ``` bash
50 | python main.py
51 | ```
52 |
53 | ## Implementation Details
54 |
55 | ### Generator
56 |
57 | The Generator has the following architecture:
58 |
59 | 
60 |
61 | For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model.
62 |
63 | Given a seed, whose dimensionality is controlled by ```latent_dim```, the ViT model creates an embedding for each patch of the final image. Those embeddings, combined with a Fourier position encoding \([Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\), are fed to a SIREN network, which outputs the patches of the image; the patches are then stitched together.
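
A minimal usage sketch, assuming the defaults in ```main.py``` (the reshape mirrors the training loop):

``` python
import torch
from models import GeneratorViT

G = GeneratorViT(latent_dim=32, image_size=32, patch_size=4)
z = torch.randn(8, 32)                               # a batch of seeds
out = G(z)                                           # [8, 32*32, 3] flattened pixels
img = out.reshape(8, 32, 32, 3).permute(0, 3, 1, 2)  # NCHW images, as in main.py
```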
64 |
65 | The ViT part of the Generator differs from a standard Vision Transformer in the following ways:
66 | - The input to the Transformer consists only of the position embeddings
67 | - Self-Modulated Layer Norm (SLN) is used in place of LayerNorm
68 | - There is no classification head
69 |
70 | SLN is the only place where the seed enters the network.
71 | SLN consists of a regular LayerNorm, the result of which is multiplied by ```gamma``` and added to ```beta```.
72 | Both ```gamma``` and ```beta``` are computed from the seed by a fully connected layer, a separate one for each place SLN is applied.
73 | The input and output dimensions of each of those fully connected layers are equal to ```hidden_dimension```.
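
A minimal sketch of SLN, following the description above (the actual implementation in ```models/vitgan_generator.py``` also supports a scalar ```gamma```/```beta```):

``` python
import torch
from torch import nn

class SLNSketch(nn.Module):
    def __init__(self, hidden_dimension):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dimension)
        self.gamma = nn.Linear(hidden_dimension, hidden_dimension, bias=False)
        self.beta = nn.Linear(hidden_dimension, hidden_dimension, bias=False)

    def forward(self, hidden, w):
        # w is the style vector computed from the seed by the mapping network
        return self.gamma(w).unsqueeze(1) * self.ln(hidden) + self.beta(w).unsqueeze(1)
```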
74 |
75 | #### SIREN
76 |
77 | A description of SIREN:
78 | \[[Blog Post](https://tech.fusic.co.jp/posts/2021-08-03-what-are-sirens/)\] \[[Paper](https://arxiv.org/pdf/2006.09661.pdf)\] \[[Colab Notebook](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)\]
79 |
80 | In contrast to a regular SIREN, the desired output is not a single image but one patch per embedding. For this purpose, each patch embedding is combined with a position embedding.
81 |
82 | The positional encoding used in ViTGAN is the Fourier position encoding; the code for it was taken from here: \([Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\)
83 |
84 | In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding.
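
A minimal sketch of that position encoding for a ```p```×```p``` patch (simplified from ```LFF``` and ```fourier_pos_embedding``` in ```models/vitgan_generator.py```; the real layer uses a special initialization):

``` python
import numpy as np
import torch
from torch import nn

def fourier_pos_encoding(p, hidden_size):
    coords = np.linspace(-1, 1, p, endpoint=True)  # pixel coordinates in [-1, 1]
    grid = torch.tensor(np.stack(np.meshgrid(coords, coords), -1), dtype=torch.float)  # [p, p, 2]
    ffm = nn.Linear(2, hidden_size)                # learned Fourier feature mapping
    return torch.sin(ffm(grid)).reshape(p * p, hidden_size)  # one feature vector per pixel
```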
85 |
86 | #### Weight Modulation
87 |
88 | Weight Modulation usually consists of a modulation and a demodulation module. After testing the network, I concluded that **demodulation is not used in ViTGAN**.
89 |
90 | My implementation of the weight modulation is heavily based on [CIPS](https://github.com/saic-mdal/CIPS/blob/main/model/blocks.py#L173). I have adjusted it to work for a fully connected network, using a 1D convolution. The reason for using a 1D convolution instead of a linear layer is its ```groups``` argument, which lets every sample in the batch be processed with its own modulated weights in a single call.
91 |
92 | Each SIREN layer consists of a sin activation, applied to a weight modulation layer. The sizes of the input, hidden and output layers in a SIREN network can vary. Thus, in case the input size differs from the size of the patch embedding, I define an additional fully connected layer, which converts the patch embedding to the appropriate size.
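
A condensed sketch of that trick (simplified from ```ModulatedLinear``` in ```models/vitgan_generator.py```, without demodulation):

``` python
import torch
import torch.nn.functional as F

def modulated_linear(x, weight, style):
    # x: [batch, n, in_f], weight: [out_f, in_f], style: [batch, in_f]
    b, n, in_f = x.shape
    out_f = weight.shape[0]
    w = weight.unsqueeze(0) * style.unsqueeze(1)     # per-sample modulated weights [b, out_f, in_f]
    w = w.reshape(b * out_f, in_f, 1)                # kernels of a grouped 1D convolution
    x = x.transpose(1, 2).reshape(1, b * in_f, n)    # fold the batch into the channel dimension
    out = F.conv1d(x, w, groups=b)                   # one group per sample, in a single call
    return out.reshape(b, out_f, n).transpose(1, 2)  # back to [batch, n, out_f]
```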
93 |
94 | ### Discriminator
95 |
96 | The Discriminator has the following architecture:
97 |
98 | 
99 |
100 | The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:
101 | - DiffAugment
102 | - Overlapping Image Patches
103 | - Use vectorized L2 distance in attention for **Discriminator**
104 | - Improved Spectral Normalization (ISN)
105 | - Balanced Consistency Regularization (bCR)
106 |
107 | #### DiffAugment
108 |
109 | For implementing DiffAugment, I used the code below:
110 | \[[GitHub](https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py)\] \[[Paper](https://arxiv.org/pdf/2006.10738.pdf)\]
111 |
112 | #### Overlapping Image Patches
113 |
114 | Creation of the overlapping image patches is implemented with the use of a convolution layer.
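
With the values used in ```main.py``` (kernel size twice the stride), this is simply:

``` python
from torch import nn

# kernel_size > stride -> each 8x8 patch overlaps its neighbours by 4 pixels
to_patches = nn.Conv2d(3, 384, kernel_size=8, stride=4)
```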
115 |
116 | #### Use vectorized L2 distance in attention for **Discriminator**
117 |
118 | \[[Paper](https://arxiv.org/pdf/2006.04710.pdf)\]
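
A sketch of the distance computation with illustrative tensor sizes, mirroring ```models/multi_head.py```:

``` python
import torch

queries = torch.randn(2, 4, 65, 96)               # [batch, heads, tokens, head_dim]
keys = torch.randn(2, 4, 65, 96)
energy = -torch.cdist(queries, keys, p=2)         # closer query-key pairs -> higher attention
att = torch.softmax(energy / 384 ** 0.5, dim=-1)  # scaled by sqrt(emb_size), as in the code
```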
119 |
120 | #### Improved Spectral Normalization (ISN)
121 |
122 | The ISN implementation is based on the following implementation of Spectral Normalization:
123 | \[[GitHub](https://github.com/koshian2/SNGAN/blob/117fbb19ac79bbc561c3ccfe285d6890ea0971f9/models/core_layers.py#L9)\]
124 | \[[Paper](https://arxiv.org/abs/1802.05957)\]
125 |
126 | #### Balanced Consistency Regularization (bCR)
127 |
128 | Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021
129 | \[[Paper](https://arxiv.org/pdf/2002.04724.pdf)\]
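
In ```main.py```, the bCR terms penalize the Discriminator for changing its prediction when DiffAugment is applied (sketch, reusing the names from the training loop):

``` python
lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False).flatten())
lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img.detach(), do_augment=False).flatten())
lossD = lossD_real * 0.5 + lossD_fake * 0.5 \
      + lossD_bCR_real * lambda_bCR_real + lossD_bCR_fake * lambda_bCR_fake
```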
130 |
131 | ## References
132 | SIREN: [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/pdf/2006.09661.pdf)
133 | Vision Transformer: \[[Blog Post](https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632)\]
134 | L2 distance attention: [The Lipschitz Constant of Self-Attention](https://arxiv.org/pdf/2006.04710.pdf)
135 | Spectral Normalization reference code: \[[GitHub](https://github.com/koshian2/SNGAN/blob/117fbb19ac79bbc561c3ccfe285d6890ea0971f9/models/core_layers.py#L9)\] \[[Paper](https://arxiv.org/abs/1802.05957)\]
136 | Diff Augment: \[[GitHub](https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py)\] \[[Paper](https://arxiv.org/pdf/2006.10738.pdf)\]
137 | Fourier Position Embedding: \[[Jupyter Notebook](https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb)\]
138 | Exponential Moving Average: \[[GitHub](https://github.com/fadel/pytorch_ema)\]
139 | Balanced Consistency Regularization (bCR): \[[Paper](https://arxiv.org/pdf/2002.04724.pdf)\]
140 | StyleGAN2 Discriminator: \[[GitHub](https://github.com/lucidrains/stylegan2-pytorch/blob/1a789d186b9697571bd6bbfa8bb1b9735bb42a0c/stylegan2_pytorch/stylegan2_pytorch.py#L627)\]
141 |
--------------------------------------------------------------------------------
/models/vitgan_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from einops import repeat
6 |
7 | from .multi_head import MultiHeadAttention
8 |
9 | # ViTGAN Generator
10 |
11 | class FullyConnectedLayer(nn.Module):
12 | def __init__(self,
13 | in_features, # Number of input features.
14 | out_features, # Number of output features.
15 | bias = True, # Apply additive bias before the activation function?
16 | activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
17 | lr_multiplier = 1, # Learning rate multiplier.
18 | bias_init = 0, # Initial value for the additive bias.
19 | **kwargs
20 | ):
21 | super().__init__()
22 | self.activation = activation
23 | if activation == 'lrelu':
24 | self.activation = nn.LeakyReLU(0.2)
25 | elif activation == 'relu':
26 | self.activation = nn.ReLU()
27 | elif activation == 'gelu':
28 | self.activation = nn.GELU()
29 | self.weight = nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
30 | self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
31 | self.weight_gain = lr_multiplier / np.sqrt(in_features)
32 | self.bias_gain = lr_multiplier
33 |
34 | def forward(self, x):
35 | w = self.weight.to(x.dtype) * self.weight_gain
36 | b = self.bias
37 | if b is not None:
38 | b = b.to(x.dtype)
39 | if self.bias_gain != 1:
40 | b = b * self.bias_gain
41 |
42 | if self.activation == 'linear' and b is not None:
43 | x = torch.addmm(b.unsqueeze(0), x, w.t())
44 | else:
45 | x = x.matmul(w.t())
46 | if b is not None:
47 | x = x + b
48 | if self.activation != 'linear':
49 | x = self.activation(x)
50 | return x
51 |
52 | def normalize_2nd_moment(x, dim=1, eps=1e-8):
53 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
54 |
55 | class MappingNetwork(nn.Module):
56 | def __init__(self,
57 | z_dim, # Input latent (Z) dimensionality, 0 = no latent.
58 | c_dim, # Conditioning label (C) dimensionality, 0 = no label.
59 | w_dim, # Intermediate latent (W) dimensionality.
60 | num_ws = None, # Number of intermediate latents to output, None = do not broadcast.
61 | num_layers = 8, # Number of mapping layers.
62 | embed_features = None, # Label embedding dimensionality, None = same as w_dim.
63 | layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
64 | activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
65 | lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
66 | w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
67 | **kwargs
68 | ):
69 | super().__init__()
70 | self.z_dim = z_dim
71 | self.c_dim = c_dim
72 | self.w_dim = w_dim
73 | self.num_ws = num_ws
74 | self.num_layers = num_layers
75 | self.w_avg_beta = w_avg_beta
76 |
77 | if embed_features is None:
78 | embed_features = w_dim
79 | if c_dim == 0:
80 | embed_features = 0
81 | if layer_features is None:
82 | layer_features = w_dim
83 | features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
84 |
85 | if c_dim > 0:
86 | self.embed = FullyConnectedLayer(c_dim, embed_features)
87 | for idx in range(num_layers):
88 | in_features = features_list[idx]
89 | out_features = features_list[idx + 1]
90 | layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
91 | setattr(self, f'fc{idx}', layer)
92 |
93 | if num_ws is not None and w_avg_beta is not None:
94 | self.register_buffer('w_avg', torch.zeros([w_dim]))
95 |
96 | def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
97 | # Embed, normalize, and concat inputs.
98 | x = None
99 | with torch.autograd.profiler.record_function('input'):
100 | if self.z_dim > 0:
101 | assert z.shape[1] == self.z_dim
102 | x = normalize_2nd_moment(z.to(torch.float32))
103 | if self.c_dim > 0:
104 | assert c.shape[1] == self.c_dim
105 | y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
106 | x = torch.cat([x, y], dim=1) if x is not None else y
107 |
108 | # Main layers.
109 | for idx in range(self.num_layers):
110 | layer = getattr(self, f'fc{idx}')
111 | x = layer(x)
112 |
113 | # Update moving average of W.
114 | if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
115 | with torch.autograd.profiler.record_function('update_w_avg'):
116 | self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
117 |
118 | # Broadcast.
119 | if self.num_ws is not None:
120 | with torch.autograd.profiler.record_function('broadcast'):
121 | x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
122 |
123 | # Apply truncation.
124 | if truncation_psi != 1:
125 | with torch.autograd.profiler.record_function('truncate'):
126 | assert self.w_avg_beta is not None
127 | if self.num_ws is None or truncation_cutoff is None:
128 | x = self.w_avg.lerp(x, truncation_psi)
129 | else:
130 | x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
131 | return x
132 |
133 | class FeedForwardBlock(nn.Sequential):
134 | def __init__(self, emb_size, expansion=4, drop_p=0., bias=False):
135 | super().__init__(
136 |             FullyConnectedLayer(emb_size, expansion * emb_size, activation='gelu', bias=bias),
137 |             nn.Dropout(drop_p),
138 |             FullyConnectedLayer(expansion * emb_size, emb_size, bias=bias),
139 | )
140 |
141 | # Self-Modulated LayerNorm
142 |
143 | class SLN(nn.Module):
144 | def __init__(self, input_size, parameter_size=None, **kwargs):
145 | super().__init__()
146 |         if parameter_size is None:
147 | parameter_size = input_size
148 | assert(input_size == parameter_size or parameter_size == 1)
149 | self.input_size = input_size
150 | self.parameter_size = parameter_size
151 | self.ln = nn.LayerNorm(input_size)
152 | self.gamma = FullyConnectedLayer(input_size, parameter_size, bias=False)
153 | self.beta = FullyConnectedLayer(input_size, parameter_size, bias=False)
154 | # self.gamma = nn.Linear(input_size, parameter_size, bias=False)
155 | # self.beta = nn.Linear(input_size, parameter_size, bias=False)
156 |
157 | def forward(self, hidden, w):
158 |         assert(hidden.size(-1) == self.input_size and w.size(-1) == self.input_size)
159 | gamma = self.gamma(w).unsqueeze(1)
160 | beta = self.beta(w).unsqueeze(1)
161 | ln = self.ln(hidden)
162 | return gamma * ln + beta
163 |
164 | class GeneratorTransformerEncoderBlock(nn.Module):
165 | def __init__(self,
166 | hidden_size=384,
167 | sln_paremeter_size=384,
168 | drop_p=0.,
169 | forward_expansion=4,
170 | forward_drop_p=0.,
171 | **kwargs):
172 | super().__init__()
173 | self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)
174 | self.msa = MultiHeadAttention(hidden_size, **kwargs)
175 | self.dropout = nn.Dropout(drop_p)
176 | self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, drop_p=forward_drop_p)
177 |
178 | def forward(self, hidden, w):
179 | res = hidden
180 | hidden = self.sln(hidden, w)
181 | hidden = self.msa(hidden)
182 | hidden = self.dropout(hidden)
183 | hidden += res
184 |
185 | res = hidden
186 | hidden = self.sln(hidden, w)
187 |         hidden = self.feed_forward(hidden)
188 | hidden = self.dropout(hidden)
189 | hidden += res
190 | return hidden
191 |
192 | class GeneratorTransformerEncoder(nn.Module):
193 | def __init__(self, depth=4, **kwargs):
194 | super().__init__()
195 | self.depth = depth
196 | self.blocks = nn.ModuleList([GeneratorTransformerEncoderBlock(**kwargs) for _ in range(depth)])
197 |
198 | def forward(self, hidden, w):
199 | for i in range(self.depth):
200 | hidden = self.blocks[i](hidden, w)
201 | return hidden
202 |
203 | # SIREN
204 |
205 | #Code for SIREN is taken from https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
206 |
207 | class ModulatedLinear(nn.Module):
208 | def __init__(self, in_channels, out_channels, style_size, bias=False, demodulation=True, **kwargs):
209 | super().__init__()
210 | self.in_channels = in_channels
211 | self.out_channels = out_channels
212 | self.style_size = style_size
213 | self.scale = 1 / np.sqrt(in_channels)
214 | self.weight = nn.Parameter(
215 | torch.randn(1, out_channels, in_channels, 1)
216 | )
217 | self.modulation = None
218 | if self.style_size != self.in_channels:
219 | self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
220 | self.demodulation = demodulation
221 |
222 | def forward(self, input, style):
223 | batch_size = input.shape[0]
224 |
225 | if self.style_size != self.in_channels:
226 | style = self.modulation(style)
227 | style = style.view(batch_size, 1, self.in_channels, 1)
228 | weight = self.scale * self.weight * style
229 |
230 | if self.demodulation:
231 | demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)
232 | weight = weight * demod.view(batch_size, self.out_channels, 1, 1)
233 |
234 | weight = weight.view(
235 | batch_size * self.out_channels, self.in_channels, 1
236 | )
237 |
238 | img_size = input.size(1)
239 | input = input.reshape(1, batch_size * self.in_channels, img_size)
240 | out = F.conv1d(input, weight, groups=batch_size)
241 | out = out.view(batch_size, img_size, self.out_channels)
242 |
243 | return out
244 |
245 | class ResLinear(nn.Module):
246 | def __init__(self, in_channels, out_channels, style_size, bias=False, **kwargs):
247 | super().__init__()
248 | self.linear = FullyConnectedLayer(in_channels, out_channels, bias=False)
249 | self.style = FullyConnectedLayer(style_size, in_channels, bias=False)
250 | self.in_channels = in_channels
251 | self.out_channels = out_channels
252 | self.style_size = style_size
253 |
254 | def forward(self, input, style):
255 | x = input + self.style(style).unsqueeze(1)
256 | x = self.linear(x)
257 | return x
258 |
259 | class ConLinear(nn.Module):
260 | def __init__(self, ch_in, ch_out, is_first=False, bias=True, **kwargs):
261 | super().__init__()
262 | self.conv = nn.Linear(ch_in, ch_out, bias=bias)
263 | if is_first:
264 | nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in))
265 | else:
266 | nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in))
267 |
268 | def forward(self, x):
269 | return self.conv(x)
270 |
271 | class SinActivation(nn.Module):
272 | def __init__(self):
273 | super().__init__()
274 |
275 | def forward(self, x):
276 | return torch.sin(x)
277 |
278 | class LFF(nn.Module):
279 | def __init__(self, hidden_size, **kwargs):
280 | super().__init__()
281 | self.ffm = ConLinear(2, hidden_size, is_first=True)
282 | self.activation = SinActivation()
283 |
284 | def forward(self, x):
286 | x = self.ffm(x)
287 | x = self.activation(x)
288 | return x
289 |
290 | class SineLayer(nn.Module):
291 | def __init__(self, in_features, out_features, style_size, bias=False,
292 | is_first=False, omega_0=30, weight_modulation=True, **kwargs):
293 | super().__init__()
294 | self.omega_0 = omega_0
295 | self.is_first = is_first
296 |
297 | self.in_features = in_features
298 | self.weight_modulation = weight_modulation
299 | if weight_modulation:
300 | self.linear = ModulatedLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)
301 | else:
302 | self.linear = ResLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)
303 | self.init_weights()
304 |
305 | def init_weights(self):
306 | with torch.no_grad():
307 | if self.is_first:
308 | if self.weight_modulation:
309 | self.linear.weight.uniform_(-1 / self.in_features,
310 | 1 / self.in_features)
311 | else:
312 | self.linear.linear.weight.uniform_(-1 / self.in_features,
313 | 1 / self.in_features)
314 | else:
315 | if self.weight_modulation:
316 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
317 | np.sqrt(6 / self.in_features) / self.omega_0)
318 | else:
319 | self.linear.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
320 | np.sqrt(6 / self.in_features) / self.omega_0)
321 |
322 | def forward(self, input, style):
323 | return torch.sin(self.omega_0 * self.linear(input, style))
324 |
325 | class Siren(nn.Module):
326 | def __init__(self, in_features, hidden_size, hidden_layers, out_features, style_size, outermost_linear=False,
327 | first_omega_0=30, hidden_omega_0=30., weight_modulation=True, bias=False, **kwargs):
328 | super().__init__()
329 |
330 | self.net = []
331 | self.net.append(SineLayer(in_features, hidden_size, style_size,
332 | is_first=True, omega_0=first_omega_0,
333 | weight_modulation=weight_modulation, **kwargs))
334 |
335 | for i in range(hidden_layers):
336 | self.net.append(SineLayer(hidden_size, hidden_size, style_size,
337 | is_first=False, omega_0=hidden_omega_0,
338 | weight_modulation=weight_modulation, **kwargs))
339 |
340 | if outermost_linear:
341 | if weight_modulation:
342 | final_linear = ModulatedLinear(hidden_size, out_features,
343 | style_size=style_size, bias=bias, **kwargs)
344 | else:
345 | final_linear = ResLinear(hidden_size, out_features, style_size=style_size, bias=bias, **kwargs)
346 |
347 | with torch.no_grad():
348 | if weight_modulation:
349 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0,
350 | np.sqrt(6 / hidden_size) / hidden_omega_0)
351 | else:
352 | final_linear.linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0,
353 | np.sqrt(6 / hidden_size) / hidden_omega_0)
354 |
355 | self.net.append(final_linear)
356 | else:
357 |             self.net.append(SineLayer(hidden_size, out_features, style_size,
358 |                                       is_first=False, omega_0=hidden_omega_0,
359 |                                       weight_modulation=weight_modulation, **kwargs))
360 |
361 | self.net = nn.Sequential(*self.net)
362 |
363 | def forward(self, coords, style):
364 | coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
365 | # output = self.net(coords, style)
366 | output = coords
367 | for layer in self.net:
368 | output = layer(output, style)
369 | return output
370 |
371 | class GeneratorViT(nn.Module):
372 | def __init__(self,
373 | style_mlp_layers=8,
374 | patch_size=4,
375 | latent_dim=32,
376 | hidden_size=384,
377 | sln_paremeter_size=1,
378 | image_size=32,
379 | depth=4,
380 | combine_patch_embeddings=False,
381 | combined_embedding_size=1024,
382 | forward_drop_p=0.,
383 | bias=False,
384 | out_features=3,
385 | out_patch_size=4,
386 | weight_modulation=True,
387 | siren_hidden_layers=1,
388 | **kwargs):
389 | super().__init__()
390 | self.hidden_size = hidden_size
391 |
392 | self.mlp = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=style_mlp_layers, w_avg_beta=None)
393 |
394 | num_patches = int(image_size//patch_size)**2
395 | self.patch_size = patch_size
396 | self.num_patches = num_patches
397 | self.image_size = image_size
398 | self.combine_patch_embeddings = combine_patch_embeddings
399 | self.combined_embedding_size = combined_embedding_size
400 | self.out_patch_size = out_patch_size
401 | self.out_features = out_features
402 |
403 | self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))
404 | self.transformer_encoder = GeneratorTransformerEncoder(depth,
405 | hidden_size=hidden_size,
406 | sln_paremeter_size=sln_paremeter_size,
407 | drop_p=forward_drop_p,
408 | forward_drop_p=forward_drop_p,
409 | **kwargs)
410 | self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)
411 | if combine_patch_embeddings:
412 | self.to_single_emb = nn.Sequential(
413 | FullyConnectedLayer(num_patches*hidden_size, combined_embedding_size, bias=bias, activation='gelu'),
414 | nn.Dropout(forward_drop_p),
415 | )
416 |
417 | self.lff = LFF(self.hidden_size)
418 |
419 | self.siren_in_features = combined_embedding_size if combine_patch_embeddings else self.hidden_size
420 | self.siren = Siren(in_features=self.siren_in_features, out_features=out_features,
421 | style_size=self.siren_in_features, hidden_size=self.hidden_size, bias=bias,
422 | hidden_layers=siren_hidden_layers, outermost_linear=True, weight_modulation=weight_modulation, **kwargs)
423 |
424 | self.num_patches_x = int(image_size//self.out_patch_size)
425 |
426 | def fourier_input_mapping(self, x):
427 | return self.lff(x)
428 |
429 | def fourier_pos_embedding(self, device):
430 | # Create input pixel coordinates in the unit square
431 | coords = np.linspace(-1, 1, self.out_patch_size, endpoint=True)
432 | pos = np.stack(np.meshgrid(coords, coords), -1)
433 | pos = torch.tensor(pos, dtype=torch.float, device=device)
434 | result = self.fourier_input_mapping(pos).reshape([self.out_patch_size**2, self.hidden_size])
435 | return result.to(device)
436 |
437 | def repeat_pos(self, hidden):
438 | pos = self.fourier_pos_embedding(hidden.device)
439 | result = repeat(pos, 'p h -> n p h', n = hidden.shape[0])
440 |
441 | return result
442 |
443 | def forward(self, z):
444 | w = self.mlp(z)
445 | pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])
446 | hidden = self.transformer_encoder(pos, w)
447 |
448 | if self.combine_patch_embeddings:
449 | # Output [batch_size, combined_embedding_size]
450 | hidden = self.sln(hidden, w).view((z.shape[0], -1))
451 | hidden = self.to_single_emb(hidden)
452 | else:
453 | # Output [batch_size*num_patches, hidden_size]
454 | hidden = self.sln(hidden, w).view((-1, self.hidden_size))
455 |
456 | pos = self.repeat_pos(hidden)
457 |
458 | result = self.siren(pos, hidden)
459 |
460 | model_output_1 = result.view([-1, self.num_patches_x, self.num_patches_x, self.out_patch_size, self.out_patch_size, self.out_features])
461 | model_output_2 = model_output_1.permute([0, 1, 3, 2, 4, 5])
462 | model_output = model_output_2.reshape([-1, self.image_size**2, self.out_features])
463 |
464 | return model_output
465 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import os
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | import torchvision
10 | import torchvision.transforms as transforms
13 | from torchvision.utils import make_grid
14 |
15 | from torch_ema import ExponentialMovingAverage
16 |
17 | from torch.utils.tensorboard import SummaryWriter
18 | import wandb
19 |
20 | from models import CNN, ViT, StyleGanDiscriminator, GeneratorViT, CNNGenerator
21 |
22 | import argparse
23 |
24 | def get_parser():
25 | # parse parameters
26 | parser = argparse.ArgumentParser(description='ViTGAN')
27 | parser.add_argument("--image_size", type=int, default=32,
28 | help="Image Size")
29 | parser.add_argument("--style_mlp_layers", type=int, default=8,
30 | help="Style Mapping network depth")
31 | parser.add_argument("--patch_size", type=int, default=4,
32 | help="Patch Size")
33 | parser.add_argument("--latent_dim", type=int, default=32,
34 | help="Dimensions of the seed")
35 | parser.add_argument("--dropout_p", type=float, default=0.,
36 | help="Dropout rate")
37 | parser.add_argument("--bias", type=bool, default=True,
38 | help="Whether to use bias or not")
39 | parser.add_argument("--weight_modulation", type=bool, default=True,
40 | help="Whether to use weight modulation or not")
41 | parser.add_argument("--demodulation", type=bool, default=False,
42 | help="Whether to use weight demodulation or not")
43 | parser.add_argument("--siren_hidden_layers", type=int, default=1,
44 | help="Number of hidden layers for the SIREN network")
45 | parser.add_argument("--hidden_features", type=int, default=384,
46 | help="Image Size")
47 | parser.add_argument("--sln_paremeter_size", type=int, default=384,
48 |                         help="Either equal to --hidden_features or 1")
49 | parser.add_argument("--depth", type=int, default=4,
50 | help="Number of Transformer Block layers for both the Generator and the Discriminator")
51 | parser.add_argument("--num_heads", type=int, default=4,
52 | help="Number of Attention heads for every Transformer Block layer")
53 | parser.add_argument("--combine_patch_embeddings", type=bool, default=False,
54 | help="Generate an image from a single SIREN, instead of patch-by-patch")
55 | parser.add_argument("--combine_patch_embeddings_size", type=int, default=384,
56 | help="Size of the combined image embedding")
57 | parser.add_argument("--batch_size", type=int, default=128,
58 | help="Batch size")
59 | parser.add_argument("--generator_type", type=str, default="vitgan",
60 | help="\"vitgan\", \"cnn\"")
61 | parser.add_argument("--discriminator_type", type=str, default="vitgan",
62 | help="\"vitgan\", \"cnn\", \"stylegan2\"")
63 | parser.add_argument("--batch_size_history_discriminator", type=bool, default=True,
64 | help="Whether to use a loss, which tracks one sample from last batch_size number of batches")
65 | parser.add_argument("--lr", type=float, default=0.001,
66 | help="Learning Rate for the Generator")
67 | parser.add_argument("--lr_dis", type=float, default=0.001,
68 | help="Learning Rate for the Discriminator")
69 | parser.add_argument("--beta1", type=float, default=0,
70 | help="Adam beta1 parameter")
71 | parser.add_argument("--beta2", type=float, default=0.99,
72 | help="Adam beta2 parameter")
73 | parser.add_argument("--epochs", type=int, default=400,
74 |                         help="Number of epochs")
75 | parser.add_argument("--lambda_bCR_real", type=int, default=10,
76 | help="lambda_bCR_real")
77 | parser.add_argument("--lambda_bCR_fake", type=int, default=10,
78 | help="lambda_bCR_fake")
79 | parser.add_argument("--lambda_lossD_noise", type=int, default=0,
80 | help="lambda_lossD_noise")
81 | parser.add_argument("--lambda_lossD_history", type=int, default=0,
82 | help="lambda_lossD_history")
83 | parser.add_argument("--lambda_diversity_penalty", type=int, default=0,
84 | help="lambda_diversity_penalty")
85 | parser.add_argument("--device", type=str, default='cuda',
86 | help="device")
87 | return parser
88 |
89 | parser = get_parser()
90 | params = parser.parse_args()
91 |
92 | image_size = params.image_size
93 | style_mlp_layers = params.style_mlp_layers
94 | patch_size = params.patch_size
95 | latent_dim = params.latent_dim # Size of z
96 | hidden_size = params.hidden_features
97 | depth = params.depth
98 | num_heads = params.num_heads
99 |
100 | dropout_p = params.dropout_p
101 | bias = params.bias
102 | weight_modulation = params.weight_modulation
103 | demodulation = params.demodulation
104 | siren_hidden_layers = params.siren_hidden_layers
105 |
106 | combine_patch_embeddings = params.combine_patch_embeddings # Generate an image from a single SIREN, instead of patch-by-patch
107 | combine_patch_embeddings_size = params.combine_patch_embeddings_size
108 |
109 | sln_paremeter_size = params.sln_paremeter_size # either hidden_size or 1
110 |
111 | batch_size = params.batch_size
112 | device = params.device
113 | out_features = 3 # The number of color channels
114 |
115 | generator_type = params.generator_type # "vitgan", "cnn"
116 | discriminator_type = params.discriminator_type # "vitgan", "cnn", "stylegan2"
117 |
118 | lr = params.lr # Learning rate
119 | lr_dis = params.lr_dis # Learning rate
120 | beta = (params.beta1, params.beta2) # Adam optimizer parameters for both the generator and the discriminator
121 | batch_size_history_discriminator = params.batch_size_history_discriminator # Whether to use a loss, which tracks one sample from last batch_size number of batches
122 | epochs = params.epochs # Number of epochs
123 | lambda_bCR_real = params.lambda_bCR_real
124 | lambda_bCR_fake = params.lambda_bCR_fake
125 | lambda_lossD_noise = params.lambda_lossD_noise
126 | lambda_lossD_history = params.lambda_lossD_history
127 | lambda_diversity_penalty = params.lambda_diversity_penalty
128 |
129 | experiment_folder_name = f'runs/lr-{lr}_\
130 | lr_dis-{lr_dis}_\
131 | bias-{bias}_\
132 | demod-{demodulation}_\
133 | sir_n_layer-{siren_hidden_layers}_\
134 | w_mod-{weight_modulation}_\
135 | patch_s-{patch_size}_\
136 | st_mlp_l-{style_mlp_layers}_\
137 | hid_size-{hidden_size}_\
138 | comb_patch_emb-{combine_patch_embeddings}_\
139 | sln_par_s-{sln_paremeter_size}_\
140 | dis_type-{discriminator_type}_\
141 | gen_type-{generator_type}_\
142 | n_head-{num_heads}_\
143 | depth-{depth}_\
144 | drop_p-{dropout_p}_\
145 | l_bCR_r-{lambda_bCR_real}_\
146 | l_bCR_f-{lambda_bCR_fake}_\
147 | l_D_noise-{lambda_lossD_noise}_\
148 | l_D_his-{lambda_lossD_history}\
149 | '
150 | writer = SummaryWriter(log_dir=experiment_folder_name)
151 |
152 | wandb.init(project='ViTGAN-pytorch')
153 | config = wandb.config
154 | config.image_size = image_size
155 | config.bias = bias
156 | config.demodulation = demodulation
157 | config.siren_hidden_layers = siren_hidden_layers
158 | config.weight_modulation = weight_modulation
159 | config.style_mlp_layers = style_mlp_layers
160 | config.patch_size = patch_size
161 | config.latent_dim = latent_dim
162 | config.hidden_size = hidden_size
163 | config.depth = depth
164 | config.num_heads = num_heads
165 |
166 | config.dropout_p = dropout_p
167 |
168 | config.combine_patch_embeddings = combine_patch_embeddings
169 | config.combine_patch_embeddings_size = combine_patch_embeddings_size
170 |
171 | config.sln_paremeter_size = sln_paremeter_size
172 |
173 | config.batch_size = batch_size
174 | config.device = device
175 | config.out_features = out_features
176 |
177 | config.generator_type = generator_type
178 | config.discriminator_type = discriminator_type
179 |
180 | config.lr = lr
181 | config.lr_dis = lr_dis
182 | config.beta1 = beta[0]
183 | config.beta2 = beta[1]
184 | config.batch_size_history_discriminator = batch_size_history_discriminator
185 | config.epochs = epochs
186 | config.lambda_bCR_real = lambda_bCR_real
187 | config.lambda_bCR_fake = lambda_bCR_fake
188 | config.lambda_lossD_noise = lambda_lossD_noise
189 | config.lambda_lossD_history = lambda_lossD_history
190 | config.lambda_diversity_penalty = lambda_diversity_penalty
191 |
192 | if combine_patch_embeddings:
193 | out_patch_size = image_size
194 | combined_embedding_size = combine_patch_embeddings_size
195 | else:
196 | out_patch_size = patch_size
197 | combined_embedding_size = hidden_size
198 |
199 | siren_in_features = combined_embedding_size
200 |
201 | transform = transforms.Compose(
202 | [transforms.ToTensor(),
203 | transforms.Normalize((0., 0., 0.), (1., 1., 1.))
204 | ])
205 |
206 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
207 | download=True, transform=transform)
208 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
209 | shuffle=True, num_workers=2)
210 |
211 | # Diversity Loss
212 |
213 | def diversity_loss(images):
214 | num_images_to_calculate_on = 10
215 | num_pairs = num_images_to_calculate_on * (num_images_to_calculate_on - 1) // 2
216 |
217 | scale_factor = 5
218 |
219 | loss = torch.zeros(1, dtype=torch.float, device=device, requires_grad=True)
220 | i = 0
221 | for a_id in range(num_images_to_calculate_on):
222 | for b_id in range(a_id+1, num_images_to_calculate_on):
223 | img_a = images[a_id]
224 | img_b = images[b_id]
225 | img_a_l2 = torch.norm(img_a)
226 | img_b_l2 = torch.norm(img_b)
227 | img_a, img_b = img_a.flatten(), img_b.flatten()
228 |
229 | a_b_loss = scale_factor * (img_a.t() @ img_b) / (img_a_l2 * img_b_l2)
230 | loss = loss + torch.sigmoid(a_b_loss)
231 | i += 1
232 | loss = loss.sum() / num_pairs
233 | return loss
234 |
235 | # Normal distribution init weight
236 |
237 | def init_normal(m):
238 |     # note: Parameters live in m._parameters, not m.__dict__, so checking
239 |     # `'weight' in m.__dict__` would always fail and skip this initialization
240 |     if isinstance(m, nn.Linear):
241 |         # m.weight.data.normal_(0.0, 1/np.sqrt(m.in_features))
242 |         m.weight.data.normal_(0.0, 1)
243 |         # m.bias.data.fill_(0)
244 |
245 | # Experiments
246 |
247 | if generator_type == "vitgan":
248 | # Create the Generator
249 | Generator = GeneratorViT( patch_size=patch_size,
250 | image_size=image_size,
251 | style_mlp_layers=style_mlp_layers,
252 | latent_dim=latent_dim,
253 | hidden_size=hidden_size,
254 | combine_patch_embeddings=combine_patch_embeddings,
255 | combined_embedding_size=combined_embedding_size,
256 | sln_paremeter_size=sln_paremeter_size,
257 | num_heads=num_heads,
258 | depth=depth,
259 | forward_drop_p=dropout_p,
260 | bias=bias,
261 | weight_modulation=weight_modulation,
262 | siren_hidden_layers=siren_hidden_layers,
263 | demodulation=demodulation,
264 | out_patch_size=out_patch_size,
265 | ).to(device)
266 |
267 | print(Generator)
268 |
269 | # use the modules apply function to recursively apply the initialization
270 | Generator.apply(init_normal)
271 |
272 | num_patches_x = int(image_size//out_patch_size)
273 |
274 | if os.path.exists(f'{experiment_folder_name}/weights/Generator.pth'):
275 | Generator = torch.load(f'{experiment_folder_name}/weights/Generator.pth')
276 |
277 | wandb.watch(Generator)
278 |
279 | elif generator_type == "cnn":
280 | cnn_generator = CNNGenerator(hidden_size=hidden_size, latent_dim=latent_dim).to(device)
281 |
282 | print(cnn_generator)
283 |
284 | cnn_generator.apply(init_normal)
285 |
286 | if os.path.exists(f'{experiment_folder_name}/weights/cnn_generator.pth'):
287 | cnn_generator = torch.load(f'{experiment_folder_name}/weights/cnn_generator.pth')
288 |
289 | wandb.watch(cnn_generator)
290 |
291 | # Create the three types of discriminators
292 | if discriminator_type == "vitgan":
293 | Discriminator = ViT(discriminator=True,
294 | patch_size=patch_size*2,
295 | stride_size=patch_size,
296 | n_classes=1,
297 | num_heads=num_heads,
298 | depth=depth,
299 | forward_drop_p=dropout_p,
300 | ).to(device)
301 |
302 | print(Discriminator)
303 |
304 | Discriminator.apply(init_normal)
305 |
306 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):
307 | Discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')
308 |
309 | wandb.watch(Discriminator)
310 |
311 | elif discriminator_type == "cnn":
312 | cnn_discriminator = CNN().to(device)
313 |
314 | print(cnn_discriminator)
315 |
316 | cnn_discriminator.apply(init_normal)
317 |
318 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):
319 | cnn_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')
320 |
321 | wandb.watch(cnn_discriminator)
322 |
323 | elif discriminator_type == "stylegan2":
324 | stylegan2_discriminator = StyleGanDiscriminator(image_size=32).to(device)
325 |
326 | print(stylegan2_discriminator)
327 |
328 | # stylegan2_discriminator.apply(init_normal)
329 |
330 | if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):
331 | stylegan2_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')
332 |
333 | wandb.watch(stylegan2_discriminator)
334 |
335 | # Training
336 |
337 | os.makedirs(f"{experiment_folder_name}/weights", exist_ok = True)
338 | os.makedirs(f"{experiment_folder_name}/samples", exist_ok = True)
339 |
340 | # Loss function
341 | criterion = nn.BCEWithLogitsLoss()
342 |
343 | if discriminator_type == "cnn": discriminator = cnn_discriminator
344 | elif discriminator_type == "stylegan2": discriminator = stylegan2_discriminator
345 | elif discriminator_type == "vitgan": discriminator = Discriminator
346 |
347 | if generator_type == "cnn":
348 |     gen_params = list(cnn_generator.parameters())  # materialize: .parameters() is a one-shot generator
349 | else:
350 |     gen_params = list(Generator.parameters())
351 | optim_g = torch.optim.Adam(lr=lr, params=gen_params, betas=beta)
352 | optim_d = torch.optim.Adam(lr=lr_dis, params=discriminator.parameters(), betas=beta)
353 | ema = ExponentialMovingAverage(gen_params, decay=0.995)  # must see the same (non-exhausted) parameter list
354 |
355 | fixed_noise = torch.FloatTensor(np.random.normal(0, 1, (16, latent_dim))).to(device)
356 |
357 | discriminator_f_img = torch.zeros([batch_size, 3, image_size, image_size]).to(device)
358 |
359 | trainset_len = len(trainloader.dataset)
360 |
361 | step = 0
362 | for epoch in range(epochs):
363 | for batch_id, batch in enumerate(trainloader):
364 | step += 1
365 |
366 | # Train discriminator
367 |
368 | # Forward + Backward with real images
369 | r_img = batch[0].to(device)
370 | r_logit = discriminator(r_img).flatten()
371 | r_label = torch.ones(r_logit.shape[0]).to(device)
372 |
373 | lossD_real = criterion(r_logit, r_label)
374 |
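    |             # Balanced consistency regularization (bCR): penalize the discriminator
    |             # for changing its logits between the DiffAugment-ed and the clean image.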
375 |             lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False).flatten())
376 |
377 | # Forward + Backward with fake images
378 | latent_vector = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))).to(device)
379 |
380 | if generator_type == "vitgan":
381 | f_img = Generator(latent_vector)
382 | f_img = f_img.reshape([-1, image_size, image_size, out_features])
383 | f_img = f_img.permute(0, 3, 1, 2)
384 | else:
385 | model_output = cnn_generator(latent_vector)
386 | f_img = model_output
387 |
388 | assert f_img.size(0) == batch_size, f_img.shape
389 | assert f_img.size(1) == out_features, f_img.shape
390 | assert f_img.size(2) == image_size, f_img.shape
391 | assert f_img.size(3) == image_size, f_img.shape
392 |
393 | f_label = torch.zeros(batch_size).to(device)
394 |             # Save a single generated image into the discriminator's history buffer
395 | if batch_size_history_discriminator:
396 | discriminator_f_img[step % batch_size] = f_img[0].detach()
397 | f_logit_history = discriminator(discriminator_f_img).flatten()
398 | lossD_fake_history = criterion(f_logit_history, f_label)
399 | else: lossD_fake_history = 0
400 |             # Train the discriminator on the images generated from the current batch
401 | f_logit = discriminator(f_img.detach()).flatten()
402 | lossD_fake = criterion(f_logit, f_label)
403 |
404 |             lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img.detach(), do_augment=False).flatten())
405 |
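    |             # Optional extra term: uniform-noise images in [-1, 1] are also trained
    |             # to be classified as fake (weighted by lambda_lossD_noise below).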
406 | f_noise_input = torch.FloatTensor(np.random.rand(*f_img.shape)*2 - 1).to(device)
407 | f_noise_logit = discriminator(f_noise_input).flatten()
408 | lossD_noise = criterion(f_noise_logit, f_label)
409 |
410 | lossD = lossD_real * 0.5 +\
411 | lossD_fake * 0.5 +\
412 | lossD_fake_history * lambda_lossD_history +\
413 | lossD_noise * lambda_lossD_noise +\
414 | lossD_bCR_real * lambda_bCR_real +\
415 | lossD_bCR_fake * lambda_bCR_fake
416 |
417 | optim_d.zero_grad()
418 | lossD.backward()
419 | optim_d.step()
420 |
421 | # Train Generator
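    |             # Generate a fresh batch of fakes (no detach this time) and update the
    |             # generator so the discriminator's logits move toward the real label.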
422 |
423 | if generator_type == "vitgan":
424 | f_img = Generator(latent_vector)
425 | f_img = f_img.reshape([-1, image_size, image_size, out_features])
426 | f_img = f_img.permute(0, 3, 1, 2)
427 | else:
428 | model_output = cnn_generator(latent_vector)
429 | f_img = model_output
430 |
431 | assert f_img.size(0) == batch_size
432 | assert f_img.size(1) == out_features
433 | assert f_img.size(2) == image_size
434 | assert f_img.size(3) == image_size
435 |
436 | f_logit = discriminator(f_img).flatten()
437 | r_label = torch.ones(batch_size).to(device)
438 | lossG_main = criterion(f_logit, r_label)
439 |
440 | lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty
441 | lossG = lossG_main + lossG_diversity
442 |
443 | optim_g.zero_grad()
444 | lossG.backward()
445 | optim_g.step()
446 | ema.update()
447 |
448 | writer.add_scalar("Loss/Generator", lossG_main, step)
449 | writer.add_scalar("Loss/Gen(diversity)", lossG_diversity, step)
450 | writer.add_scalar("Loss/Dis(real)", lossD_real, step)
451 | writer.add_scalar("Loss/Dis(fake)", lossD_fake, step)
452 | writer.add_scalar("Loss/Dis(fake_history)", lossD_fake_history, step)
453 | writer.add_scalar("Loss/Dis(noise)", lossD_noise, step)
454 | writer.add_scalar("Loss/Dis(bCR_fake)", lossD_bCR_fake * lambda_bCR_fake, step)
455 | writer.add_scalar("Loss/Dis(bCR_real)", lossD_bCR_real * lambda_bCR_real, step)
456 | writer.flush()
457 |
458 | wandb.log({
459 | 'Generator': lossG_main,
460 | 'Gen(diversity)': lossG_diversity,
461 | 'Dis(real)': lossD_real,
462 | 'Dis(fake)': lossD_fake,
463 | 'Dis(fake_history)': lossD_fake_history,
464 | 'Dis(noise)': lossD_noise,
465 | 'Dis(bCR_fake)': lossD_bCR_fake * lambda_bCR_fake,
466 | 'Dis(bCR_real)': lossD_bCR_real * lambda_bCR_real
467 | })
468 |
469 | if batch_id%20 == 0:
470 | print(f'epoch {epoch}/{epochs}; batch {batch_id}/{int(trainset_len/batch_size)}')
471 |                 print(f'Generator: {float(lossG_main):.3f}, '+\
472 |                     f'Gen(diversity): {float(lossG_diversity):.3f}, '+\
473 |                     f'Dis(real): {float(lossD_real):.3f}, '+\
474 |                     f'Dis(fake): {float(lossD_fake):.3f}, '+\
475 |                     f'Dis(fake_history): {float(lossD_fake_history):.3f}, '+\
476 |                     f'Dis(noise): {float(lossD_noise):.3f}, '+\
477 |                     f'Dis(bCR_fake): {float(lossD_bCR_fake * lambda_bCR_fake):.3f}, '+\
478 |                     f'Dis(bCR_real): {float(lossD_bCR_real * lambda_bCR_real):.3f}')
479 |
480 | # Plot 8 randomly selected samples
481 | fig, axes = plt.subplots(1,8, figsize=(24,3))
482 | output = f_img.permute(0, 2, 3, 1)
483 | for i in range(8):
484 |                 j = np.random.randint(0, batch_size)  # randint's upper bound is exclusive
485 |                 img = output[j].detach().cpu().view(image_size, image_size, out_features).numpy()
486 | img -= img.min()
487 | img /= img.max()
488 | axes[i].imshow(img)
489 | plt.show()
490 |
491 | # if step % sample_interval == 0:
492 | if generator_type == "vitgan":
493 | Generator.eval()
494 | # img_siren.eval()
495 | vis = Generator(fixed_noise)
496 | vis = vis.reshape([-1, image_size, image_size, out_features])
497 | vis = vis.permute(0, 3, 1, 2)
498 | else:
499 | model_output = cnn_generator(fixed_noise)
500 | vis = model_output
501 |
502 | assert vis.shape[0] == fixed_noise.shape[0], f'vis.shape[0] is {vis.shape[0]}, but should be {fixed_noise.shape[0]}'
503 | assert vis.shape[1] == out_features, f'vis.shape[1] is {vis.shape[1]}, but should be {out_features}'
504 | assert vis.shape[2] == image_size, f'vis.shape[2] is {vis.shape[2]}, but should be {image_size}'
505 | assert vis.shape[3] == image_size, f'vis.shape[3] is {vis.shape[3]}, but should be {image_size}'
506 |
507 |             vis = vis.detach().cpu()
508 | vis = make_grid(vis, nrow = 4, padding = 5, normalize = True)
509 | writer.add_image(f'Generated/epoch_{epoch}', vis)
510 | wandb.log({'examples': wandb.Image(vis)})
511 |
512 | vis = T.ToPILImage()(vis)
513 | vis.save(f'{experiment_folder_name}/samples/vis{epoch}.jpg')
514 | if generator_type == "vitgan":
515 | Generator.train()
516 | # img_siren.train()
517 | else:
518 | cnn_generator.train()
519 | print(f"Save sample to {experiment_folder_name}/samples/vis{epoch}.jpg")
520 |
521 | # Save the checkpoints.
522 | if generator_type == "vitgan":
523 | torch.save(Generator, f'{experiment_folder_name}/weights/Generator.pth')
524 | # torch.save(img_siren, f'{experiment_folder_name}/weights/img_siren.pth')
525 | elif generator_type == "cnn":
526 | torch.save(cnn_generator, f'{experiment_folder_name}/weights/cnn_generator.pth')
527 | torch.save(discriminator, f'{experiment_folder_name}/weights/discriminator.pth')
528 | print("Save model state.")
529 |
530 | writer.close()
531 |
--------------------------------------------------------------------------------
/ViTGAN_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "ZiVyqRSgDTpU"
7 | },
8 | "source": [
9 | "# ViTGAN pytorch implementation\n",
10 | "\n",
11 | "This notebook is a pytorch implementation of [VITGAN: Training GANs with Vision Transformers](https://arxiv.org/pdf/2107.04589v1.pdf)\n",
12 | "\n",
13 | "[](https://colab.research.google.com/drive/1kJJw6BYW01HgooCZ2zUDt54e1mXqITXH?usp=sharing)"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "cjz0JmQyD4JJ"
20 | },
21 | "source": [
22 | "The model consists of a Vision Transformer Generator and a Vision Transformer Discriminator.\n",
23 | "\n",
24 | "It is adversarially trained to map latent vectors to images, which closely resemble the images from a given dataset. In this implementation, the dataset used is CIFAR-10.\n",
25 | "\n",
26 | "The Generator takes latent values $z$ as input, which is integrated in a Vision Transformer Encoder. The output for each patch of the image is fed to a SIREN network, in combination with a Fourier Embedding ($E_{fou}$)\n",
27 | "\n",
28 | "\n",
29 | "\n",
30 | "This implementation separates the Generator in Vision Transformer and SIREN networks for debugging purposes."
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "6ZeCyoxmxwud"
37 | },
38 | "source": [
39 | "1. [x] Use vectorized L2 distance in attention for **Discriminator**\n",
40 | "2. [x] Overlapping Image Patches\n",
41 | "2. [x] DiffAugment\n",
42 | "3. [x] Self-modulated LayerNorm (SLN)\n",
43 | "4. [x] Implicit Neural Representation for Patch Generation\n",
44 | "5. [x] ExponentialMovingAverage (EMA)\n",
45 | "6. [x] Balanced Consistency Regularization (bCR)\n",
46 | "7. [x] Improved Spectral Normalization"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {
53 | "id": "Z_gqyo3DZA3J"
54 | },
55 | "outputs": [],
56 | "source": [
57 | "! pip install einops\n",
58 | "! pip install git+https://github.com/fadel/pytorch_ema\n",
59 | "! pip install stylegan2-pytorch\n",
60 | "! pip install tensorboard\n",
61 | "! pip install wandb"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {
68 | "id": "TufykVSpr-QU"
69 | },
70 | "outputs": [],
71 | "source": [
72 | "! wandb login"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {
79 | "id": "cXCe_q68481Q"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "import torch\n",
84 | "from torch import nn\n",
85 | "import torch.nn.functional as F\n",
86 | "from torch.nn import Parameter\n",
87 | "import os\n",
88 | "\n",
89 | "import numpy as np\n",
90 | "import matplotlib.pyplot as plt\n",
91 | "\n",
92 | "import time\n",
93 | "\n",
94 | "import torchvision\n",
95 | "import torchvision.transforms as transforms\n",
96 | "from einops import rearrange, repeat\n",
97 | "from einops.layers.torch import Rearrange\n",
98 | "import torch.nn.functional as F\n",
99 | "import torchvision.transforms as T\n",
100 | "from torchvision.utils import make_grid\n",
101 | "\n",
102 | "from torch_ema import ExponentialMovingAverage\n",
103 | "\n",
104 | "from stylegan2_pytorch import stylegan2_pytorch\n",
105 | "\n",
106 | "from torch.utils.tensorboard import SummaryWriter\n",
107 | "import wandb"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {
113 | "id": "g-JOToTeB9uE"
114 | },
115 | "source": [
116 | "Hyperparameters"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {
123 | "id": "aeQq8IuCAPmZ"
124 | },
125 | "outputs": [],
126 | "source": [
127 | "image_size = 32\n",
128 | "style_mlp_layers = 8\n",
129 | "patch_size = 4\n",
130 | "latent_dim = 512 # Size of z\n",
131 | "hidden_size = 384\n",
132 | "depth = 4\n",
133 | "num_heads = 4\n",
134 | "\n",
135 | "dropout_p = 0.\n",
136 | "bias = True\n",
137 | "weight_modulation = True\n",
138 | "demodulation = False\n",
139 | "siren_hidden_layers = 1\n",
140 | "\n",
141 | "combine_patch_embeddings = False # Generate an image from a single SIREN, instead of patch-by-patch\n",
142 | "combine_patch_embeddings_size = hidden_size * 4\n",
143 | "\n",
144 | "sln_paremeter_size = hidden_size # either hidden_size or 1\n",
145 | "\n",
146 | "batch_size = 50\n",
147 | "device = \"cuda\"\n",
148 | "out_features = 3 # The number of color channels\n",
149 | "\n",
150 | "generator_type = \"vitgan\" # \"vitgan\", \"cnn\"\n",
151 | "discriminator_type = \"vitgan\" # \"vitgan\", \"cnn\", \"stylegan2\"\n",
152 | "\n",
153 | "lr = 7e-4 # Learning rate\n",
154 | "lr_dis = 7e-4 # Learning rate\n",
155 | "beta = (0., 0.99) # Adam oprimizer parameters for both the generator and the discriminator\n",
156 | "batch_size_history_discriminator = False # Whether to use a loss, which tracks one sample from last batch_size number of batches\n",
157 | "epochs = 400 # Number of epochs\n",
158 | "lambda_bCR_real = 10\n",
159 | "lambda_bCR_fake = 10\n",
160 | "lambda_lossD_noise = 0.0\n",
161 | "lambda_lossD_history = 0.0\n",
162 | "lambda_diversity_penalty = 0.0\n",
163 | "\n",
164 | "experiment_folder_name = f'lr-{lr}_\\\n",
165 | "lr_dis-{lr_dis}_\\\n",
166 | "bias-{bias}_\\\n",
167 | "demod-{demodulation}_\\\n",
168 | "sir_n_layer-{siren_hidden_layers}_\\\n",
169 | "w_mod-{weight_modulation}_\\\n",
170 | "patch_s-{patch_size}_\\\n",
171 | "st_mlp_l-{style_mlp_layers}_\\\n",
172 | "hid_size-{hidden_size}_\\\n",
173 | "comb_patch_emb-{combine_patch_embeddings}_\\\n",
174 | "sln_par_s-{sln_paremeter_size}_\\\n",
175 | "dis_type-{discriminator_type}_\\\n",
176 | "gen_type-{generator_type}_\\\n",
177 | "n_head-{num_heads}_\\\n",
178 | "depth-{depth}_\\\n",
179 | "drop_p-{dropout_p}_\\\n",
180 | "l_bCR_r-{lambda_bCR_real}_\\\n",
181 | "l_bCR_f-{lambda_bCR_fake}_\\\n",
182 | "l_D_noise-{lambda_lossD_noise}_\\\n",
183 | "l_D_his-{lambda_lossD_history}\\\n",
184 | "'\n",
185 | "writer = SummaryWriter(log_dir=experiment_folder_name)\n",
186 | "\n",
187 | "wandb.init(project='ViTGAN-pytorch')\n",
188 | "config = wandb.config\n",
189 | "config.image_size = image_size\n",
190 | "config.bias = bias\n",
191 | "config.demodulation = demodulation\n",
192 | "config.siren_hidden_layers = siren_hidden_layers\n",
193 | "config.weight_modulation = weight_modulation\n",
194 | "config.style_mlp_layers = style_mlp_layers\n",
195 | "config.patch_size = patch_size\n",
196 | "config.latent_dim = latent_dim\n",
197 | "config.hidden_size = hidden_size\n",
198 | "config.depth = depth\n",
199 | "config.num_heads = num_heads\n",
200 | "\n",
201 | "config.dropout_p = dropout_p\n",
202 | "\n",
203 | "config.combine_patch_embeddings = combine_patch_embeddings\n",
204 | "config.combine_patch_embeddings_size = combine_patch_embeddings_size\n",
205 | "\n",
206 | "config.sln_paremeter_size = sln_paremeter_size\n",
207 | "\n",
208 | "config.batch_size = batch_size\n",
209 | "config.device = device\n",
210 | "config.out_features = out_features\n",
211 | "\n",
212 | "config.generator_type = generator_type\n",
213 | "config.discriminator_type = discriminator_type\n",
214 | "\n",
215 | "config.lr = lr\n",
216 | "config.lr_dis = lr_dis\n",
217 | "config.beta1 = beta[0]\n",
218 | "config.beta2 = beta[1]\n",
219 | "config.batch_size_history_discriminator = batch_size_history_discriminator\n",
220 | "config.epochs = epochs\n",
221 | "config.lambda_bCR_real = lambda_bCR_real\n",
222 | "config.lambda_bCR_fake = lambda_bCR_fake\n",
223 | "config.lambda_lossD_noise = lambda_lossD_noise\n",
224 | "config.lambda_lossD_history = lambda_lossD_history\n",
225 | "config.lambda_diversity_penalty = lambda_diversity_penalty"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "metadata": {
232 | "id": "V3dGermsa1bG"
233 | },
234 | "outputs": [],
235 | "source": [
236 | "if combine_patch_embeddings:\n",
237 | " out_patch_size = image_size\n",
238 | " combined_embedding_size = combine_patch_embeddings_size\n",
239 | "else:\n",
240 | " out_patch_size = patch_size\n",
241 | " combined_embedding_size = hidden_size\n",
242 | "\n",
243 | "siren_in_features = combined_embedding_size"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {
249 | "id": "2M4De21AB8lE"
250 | },
251 | "source": [
252 | "\n",
253 | "\n",
254 | "https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment-stylegan2-pytorch/DiffAugment_pytorch.py"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {
261 | "id": "XsaytJSj0inJ"
262 | },
263 | "outputs": [],
264 | "source": [
265 | "def DiffAugment(x, policy='', channels_first=True):\n",
266 | " if policy:\n",
267 | " if not channels_first:\n",
268 | " x = x.permute(0, 3, 1, 2)\n",
269 | " for p in policy.split(','):\n",
270 | " for f in AUGMENT_FNS[p]:\n",
271 | " x = f(x)\n",
272 | " if not channels_first:\n",
273 | " x = x.permute(0, 2, 3, 1)\n",
274 | " x = x.contiguous()\n",
275 | " return x\n",
276 | "\n",
277 | "\n",
278 | "def rand_brightness(x):\n",
279 | " x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)\n",
280 | " return x\n",
281 | "\n",
282 | "\n",
283 | "def rand_saturation(x):\n",
284 | " x_mean = x.mean(dim=1, keepdim=True)\n",
285 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean\n",
286 | " return x\n",
287 | "\n",
288 | "\n",
289 | "def rand_contrast(x):\n",
290 | " x_mean = x.mean(dim=[1, 2, 3], keepdim=True)\n",
291 | " x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean\n",
292 | " return x\n",
293 | "\n",
294 | "\n",
295 | "def rand_translation(x, ratio=0.1):\n",
296 | " shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n",
297 | " translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)\n",
298 | " translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)\n",
299 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n",
300 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n",
301 | " torch.arange(x.size(2), dtype=torch.long, device=x.device),\n",
302 | " torch.arange(x.size(3), dtype=torch.long, device=x.device),\n",
303 | " )\n",
304 | " grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)\n",
305 | " grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)\n",
306 | " x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])\n",
307 | " x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()\n",
308 | " return x\n",
309 | "\n",
310 | "\n",
311 | "def rand_cutout(x, ratio=0.3):\n",
312 | " cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n",
313 | " offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)\n",
314 | " offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)\n",
315 | " grid_batch, grid_x, grid_y = torch.meshgrid(\n",
316 | " torch.arange(x.size(0), dtype=torch.long, device=x.device),\n",
317 | " torch.arange(cutout_size[0], dtype=torch.long, device=x.device),\n",
318 | " torch.arange(cutout_size[1], dtype=torch.long, device=x.device),\n",
319 | " )\n",
320 | " grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)\n",
321 | " grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)\n",
322 | " mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)\n",
323 | " mask[grid_batch, grid_x, grid_y] = 0\n",
324 | " x = x * mask.unsqueeze(1)\n",
325 | " return x\n",
326 | "\n",
327 | "\n",
328 | "AUGMENT_FNS = {\n",
329 | " 'color': [rand_brightness, rand_saturation, rand_contrast],\n",
330 | " 'translation': [rand_translation],\n",
331 | " 'cutout': [rand_cutout],\n",
332 | "}"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": null,
338 | "metadata": {
339 | "id": "VA-uwpp7b1QK"
340 | },
341 | "outputs": [],
342 | "source": [
343 | "transform = transforms.Compose(\n",
344 | " [transforms.ToTensor(),\n",
345 | " transforms.Normalize((0., 0., 0.), (1., 1., 1.))\n",
346 | " ])\n",
347 | "\n",
348 | "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
349 | " download=True, transform=transform)\n",
350 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
351 | " shuffle=True, num_workers=2)"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {
357 | "id": "63v16amHIG2L"
358 | },
359 | "source": [
360 | "Visualize the effects of the DiffAugment"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {
367 | "id": "fnQW4nBraOJb"
368 | },
369 | "outputs": [],
370 | "source": [
371 | "img = next(iter(trainloader))[0]\n",
372 | "img = DiffAugment(img, policy='color,translation,cutout', channels_first=True)\n",
373 | "img = img.permute(0,2,3,1)[0]\n",
374 | "img -= img.min()\n",
375 | "img /= img.max()\n",
376 | "plt.imshow(img)"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {
382 | "id": "8QKaT8pk9vD_"
383 | },
384 | "source": [
385 | "# Improved Spectral Normalization (ISN)\n",
386 | "\n",
387 | "$$\n",
388 | "\\bar{W}_{ISN}(W):=\\sigma(W_{init})\\cdot W/\\sigma(W)\n",
389 | "$$\n",
390 | "\n",
391 | "Reference code: https://github.com/koshian2/SNGAN\n",
392 | "\n",
393 | "When updating the weights, normalize the weights' norm to its norm at initialization."
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "metadata": {
400 | "id": "BL2ZLEnM-xct"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "def l2normalize(v, eps=1e-4):\n",
405 | "\treturn v / (v.norm() + eps)\n",
406 | "\n",
407 | "class spectral_norm(nn.Module):\n",
408 | "\tdef __init__(self, module, name='weight', power_iterations=1):\n",
409 | "\t\tsuper().__init__()\n",
410 | "\t\tself.module = module\n",
411 | "\t\tself.name = name\n",
412 | "\t\tself.power_iterations = power_iterations\n",
413 | "\t\tif not self._made_params():\n",
414 | "\t\t\tself._make_params()\n",
415 | "\t\tself.w_init_sigma = None\n",
416 | "\t\tself.w_initalized = False\n",
417 | "\n",
418 | "\tdef _update_u_v(self):\n",
419 | "\t\tu = getattr(self.module, self.name + \"_u\")\n",
420 | "\t\tv = getattr(self.module, self.name + \"_v\")\n",
421 | "\t\tw = getattr(self.module, self.name + \"_bar\")\n",
422 | "\n",
423 | "\t\theight = w.data.shape[0]\n",
424 | "\t\t_w = w.view(height, -1)\n",
425 | "\t\tfor _ in range(self.power_iterations):\n",
426 | "\t\t\tv = l2normalize(torch.matmul(_w.t(), u))\n",
427 | "\t\t\tu = l2normalize(torch.matmul(_w, v))\n",
428 | "\n",
429 | "\t\tsigma = u.dot((_w).mv(v))\n",
430 | "\t\tif not self.w_initalized:\n",
431 | "\t\t\tself.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())\n",
432 | "\t\t\tself.w_initalized = True\n",
433 | "\t\tsetattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(device) * w / sigma.expand_as(w))\n",
434 | "\n",
435 | "\tdef _made_params(self):\n",
436 | "\t\ttry:\n",
437 | "\t\t\tgetattr(self.module, self.name + \"_u\")\n",
438 | "\t\t\tgetattr(self.module, self.name + \"_v\")\n",
439 | "\t\t\tgetattr(self.module, self.name + \"_bar\")\n",
440 | "\t\t\treturn True\n",
441 | "\t\texcept AttributeError:\n",
442 | "\t\t\treturn False\n",
443 | "\n",
444 | "\tdef _make_params(self):\n",
445 | "\t\tw = getattr(self.module, self.name)\n",
446 | "\n",
447 | "\t\theight = w.data.shape[0]\n",
448 | "\t\twidth = w.view(height, -1).data.shape[1]\n",
449 | "\n",
450 | "\t\tu = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n",
451 | "\t\tv = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n",
452 | "\t\tu.data = l2normalize(u.data)\n",
453 | "\t\tv.data = l2normalize(v.data)\n",
454 | "\t\tw_bar = Parameter(w.data)\n",
455 | "\n",
456 | "\t\tdel self.module._parameters[self.name]\n",
457 | "\t\tself.module.register_parameter(self.name + \"_u\", u)\n",
458 | "\t\tself.module.register_parameter(self.name + \"_v\", v)\n",
459 | "\t\tself.module.register_parameter(self.name + \"_bar\", w_bar)\n",
460 | "\n",
461 | "\tdef forward(self, *args):\n",
462 | "\t\tself._update_u_v()\n",
463 | "\t\treturn self.module.forward(*args)"
464 | ]
465 | },
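    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# Minimal sanity check of the ISN wrapper (`isn_demo` is a throwaway layer):\n",
    | "# every forward pass rescales the weight to sigma(W_init) * W / sigma(W),\n",
    | "# so the output shape is unchanged while the spectral norm stays fixed.\n",
    | "isn_demo = spectral_norm(nn.Linear(16, 16)).to(device)\n",
    | "print(isn_demo(torch.randn(4, 16).to(device)).shape)  # torch.Size([4, 16])\n",
    | "del isn_demo"
    | ]
    | },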
466 | {
467 | "cell_type": "markdown",
468 | "metadata": {
469 | "id": "Ttk-7hLuIZ4o"
470 | },
471 | "source": [
472 | "Vision Transformer reference code: \\[[Blog Post](https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632)\\]\n",
473 | "\n",
474 | "Normal Attention Mechanism\n",
475 | "\n",
476 | "$$\n",
477 | "Attention_h(X) = softmax \\bigg ( \\frac{QK^T}{\\sqrt{d_h}} V \\bigg )\n",
478 | "$$\n",
479 | "\n",
480 | "Lipschitz Attention Mechanism\n",
481 | "\n",
482 | "$$\n",
483 | "Attention_h(X) = softmax \\bigg ( \\frac{d(Q,K)}{\\sqrt{d_h}} V \\bigg )\n",
484 | "$$\n",
485 | "\n",
486 | "where $d(Q,K)$ is L2-distance.\n",
487 | "\n",
488 | "https://arxiv.org/pdf/2006.04710.pdf"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": null,
494 | "metadata": {
495 | "id": "F1VgcIu6MFow"
496 | },
497 | "outputs": [],
498 | "source": [
499 | "class MultiHeadAttention(nn.Module):\n",
500 | " def __init__(self, emb_size=384, num_heads=4, dropout=0, discriminator=False, **kwargs):\n",
501 | " super().__init__()\n",
502 | " self.emb_size = emb_size\n",
503 | " self.num_heads = num_heads\n",
504 | " self.discriminator = discriminator\n",
505 | " # fuse the queries, keys and values in one matrix\n",
506 | " self.qkv = nn.Linear(emb_size, emb_size * 3)\n",
507 | " self.att_drop = nn.Dropout(dropout)\n",
508 | " self.projection = nn.Linear(emb_size, emb_size)\n",
509 | " if self.discriminator:\n",
510 | " self.qkv = spectral_norm(self.qkv)\n",
511 | " self.projection = spectral_norm(self.projection)\n",
512 | " \n",
513 | " def forward(self, x, mask=None):\n",
514 | " # split keys, queries and values in num_heads\n",
515 | " qkv = rearrange(self.qkv(x), \"b n (h d qkv) -> (qkv) b h n d\", h=self.num_heads, qkv=3)\n",
516 | " queries, keys, values = qkv[0], qkv[1], qkv[2]\n",
517 | " if self.discriminator:\n",
518 | " # calculate L2-distances\n",
519 | " energy = torch.cdist(queries.contiguous(), keys.contiguous(), p=2)\n",
520 | " else:\n",
521 | " # sum up over the last axis\n",
522 | " energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len\n",
523 | "\n",
524 | " if mask is not None:\n",
525 | " fill_value = torch.finfo(torch.float32).min\n",
526 | " energy.mask_fill(~mask, fill_value)\n",
527 | " \n",
528 | " scaling = self.emb_size ** (1/2)\n",
529 | " att = F.softmax(energy, dim=-1) / scaling\n",
530 | " att = self.att_drop(att)\n",
531 | " # sum up over the third axis\n",
532 | " out = torch.einsum('bhal, bhlv -> bhav ', att, values)\n",
533 | " out = rearrange(out, \"b h n d -> b n (h d)\")\n",
534 | " out = self.projection(out)\n",
535 | " return out"
536 | ]
537 | },
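    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# Shape check for the discriminator variant (`mha_demo` is a throwaway module):\n",
    | "# with discriminator=True the dot-product scores are replaced by pairwise L2\n",
    | "# distances (torch.cdist) and the qkv/projection layers are spectrally normalized.\n",
    | "mha_demo = MultiHeadAttention(emb_size=hidden_size, num_heads=num_heads, discriminator=True).to(device)\n",
    | "print(mha_demo(torch.randn(2, 65, hidden_size).to(device)).shape)  # torch.Size([2, 65, 384])\n",
    | "del mha_demo"
    | ]
    | },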
538 | {
539 | "cell_type": "markdown",
540 | "metadata": {
541 | "id": "hpUuzMsS9G-t"
542 | },
543 | "source": [
544 | "# Generator"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {
551 | "id": "9xwQ0P_83N6L"
552 | },
553 | "outputs": [],
554 | "source": [
555 | "class FullyConnectedLayer(nn.Module):\n",
556 | " def __init__(self,\n",
557 | " in_features, # Number of input features.\n",
558 | " out_features, # Number of output features.\n",
559 | " bias = True, # Apply additive bias before the activation function?\n",
560 | " activation = 'linear', # Activation function: 'relu', 'lrelu', etc.\n",
561 | " lr_multiplier = 1, # Learning rate multiplier.\n",
562 | " bias_init = 0, # Initial value for the additive bias.\n",
563 | " **kwargs\n",
564 | " ):\n",
565 | " super().__init__()\n",
566 | " self.activation = activation\n",
567 | " if activation == 'lrelu':\n",
568 | " self.activation = nn.LeakyReLU(0.2)\n",
569 | " elif activation == 'relu':\n",
570 | " self.activation = nn.ReLU()\n",
571 | " elif activation == 'gelu':\n",
572 | " self.activation = nn.GELU()\n",
573 | " self.weight = nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)\n",
574 | " self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None\n",
575 | " self.weight_gain = lr_multiplier / np.sqrt(in_features)\n",
576 | " self.bias_gain = lr_multiplier\n",
577 | "\n",
578 | " def forward(self, x):\n",
579 | " w = self.weight.to(x.dtype) * self.weight_gain\n",
580 | " b = self.bias\n",
581 | " if b is not None:\n",
582 | " b = b.to(x.dtype)\n",
583 | " if self.bias_gain != 1:\n",
584 | " b = b * self.bias_gain\n",
585 | "\n",
586 | " if self.activation == 'linear' and b is not None:\n",
587 | " # print(b.shape, x.shape, w.t().shape)\n",
588 | " x = torch.addmm(b.unsqueeze(0), x, w.t())\n",
589 | " else:\n",
590 | " x = x.matmul(w.t())\n",
591 | " if b is not None:\n",
592 | " x = x + b\n",
593 | " if self.activation != 'linear':\n",
594 | " x = self.activation(x)\n",
595 | " return x"
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "execution_count": null,
601 | "metadata": {
602 | "id": "jx-xGRh_61Qz"
603 | },
604 | "outputs": [],
605 | "source": [
606 | "def normalize_2nd_moment(x, dim=1, eps=1e-8):\n",
607 | " return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": null,
613 | "metadata": {
614 | "id": "YCHXNmaF5C6m"
615 | },
616 | "outputs": [],
617 | "source": [
618 | "class MappingNetwork(nn.Module):\n",
619 | " def __init__(self,\n",
620 | " z_dim, # Input latent (Z) dimensionality, 0 = no latent.\n",
621 | " c_dim, # Conditioning label (C) dimensionality, 0 = no label.\n",
622 | " w_dim, # Intermediate latent (W) dimensionality.\n",
623 | " num_ws = None, # Number of intermediate latents to output, None = do not broadcast.\n",
624 | " num_layers = 8, # Number of mapping layers.\n",
625 | " embed_features = None, # Label embedding dimensionality, None = same as w_dim.\n",
626 | " layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.\n",
627 | " activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.\n",
628 | " lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.\n",
629 | " w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.\n",
630 | " **kwargs\n",
631 | " ):\n",
632 | " super().__init__()\n",
633 | " self.z_dim = z_dim\n",
634 | " self.c_dim = c_dim\n",
635 | " self.w_dim = w_dim\n",
636 | " self.num_ws = num_ws\n",
637 | " self.num_layers = num_layers\n",
638 | " self.w_avg_beta = w_avg_beta\n",
639 | "\n",
640 | " if embed_features is None:\n",
641 | " embed_features = w_dim\n",
642 | " if c_dim == 0:\n",
643 | " embed_features = 0\n",
644 | " if layer_features is None:\n",
645 | " layer_features = w_dim\n",
646 | " features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]\n",
647 | "\n",
648 | " if c_dim > 0:\n",
649 | " self.embed = FullyConnectedLayer(c_dim, embed_features)\n",
650 | " for idx in range(num_layers):\n",
651 | " in_features = features_list[idx]\n",
652 | " out_features = features_list[idx + 1]\n",
653 | " layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)\n",
654 | " setattr(self, f'fc{idx}', layer)\n",
655 | "\n",
656 | " if num_ws is not None and w_avg_beta is not None:\n",
657 | " self.register_buffer('w_avg', torch.zeros([w_dim]))\n",
658 | "\n",
659 | " def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):\n",
660 | " # Embed, normalize, and concat inputs.\n",
661 | " x = None\n",
662 | " with torch.autograd.profiler.record_function('input'):\n",
663 | " if self.z_dim > 0:\n",
664 | " assert z.shape[1] == self.z_dim\n",
665 | " x = normalize_2nd_moment(z.to(torch.float32))\n",
666 | " if self.c_dim > 0:\n",
667 | " assert c.shape[1] == self.c_dim\n",
668 | " y = normalize_2nd_moment(self.embed(c.to(torch.float32)))\n",
669 | " x = torch.cat([x, y], dim=1) if x is not None else y\n",
670 | "\n",
671 | " # Main layers.\n",
672 | " for idx in range(self.num_layers):\n",
673 | " layer = getattr(self, f'fc{idx}')\n",
674 | " x = layer(x)\n",
675 | "\n",
676 | " # Update moving average of W.\n",
677 | " if self.w_avg_beta is not None and self.training and not skip_w_avg_update:\n",
678 | " with torch.autograd.profiler.record_function('update_w_avg'):\n",
679 | " self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))\n",
680 | "\n",
681 | " # Broadcast.\n",
682 | " if self.num_ws is not None:\n",
683 | " with torch.autograd.profiler.record_function('broadcast'):\n",
684 | " x = x.unsqueeze(1).repeat([1, self.num_ws, 1])\n",
685 | "\n",
686 | " # Apply truncation.\n",
687 | " if truncation_psi != 1:\n",
688 | " with torch.autograd.profiler.record_function('truncate'):\n",
689 | " assert self.w_avg_beta is not None\n",
690 | " if self.num_ws is None or truncation_cutoff is None:\n",
691 | " x = self.w_avg.lerp(x, truncation_psi)\n",
692 | " else:\n",
693 | " x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)\n",
694 | " return x"
695 | ]
696 | },
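    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# Quick check of the style mapping z -> w (`mn_demo` is a throwaway network);\n",
    | "# w_avg_beta=None, as used by the generator below, disables w_avg tracking.\n",
    | "mn_demo = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=2, w_avg_beta=None).to(device)\n",
    | "print(mn_demo(torch.randn(2, latent_dim).to(device)).shape)  # torch.Size([2, 384])\n",
    | "del mn_demo"
    | ]
    | },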
697 | {
698 | "cell_type": "code",
699 | "execution_count": null,
700 | "metadata": {
701 | "id": "2GJiP09nL_Lh"
702 | },
703 | "outputs": [],
704 | "source": [
705 | "class FeedForwardBlock(nn.Sequential):\n",
706 | " def __init__(self, emb_size, expansion=4, drop_p=0., bias=False):\n",
707 | " super().__init__(\n",
708 | " FullyConnectedLayer(expansion, emb_size * emb_size, activation='gelu', bias=False),\n",
709 | " nn.Dropout(drop_p),\n",
710 | " FullyConnectedLayer(expansion * emb_size, emb_size, bias=False),\n",
711 | " )"
712 | ]
713 | },
714 | {
715 | "cell_type": "markdown",
716 | "metadata": {
717 | "id": "4713MviFIl8R"
718 | },
719 | "source": [
720 | "Self-Modulated LayerNorm\n",
721 | "$$\n",
722 | "SLN(h_{\\ell},w)=\\gamma_{\\ell}(w)\\odot\\frac{h_{\\ell}-\\mu}{\\sigma}+\\beta_{\\ell}(w)\n",
723 | "$$\n",
724 | "\n",
725 | "where $\\gamma_{\\ell}, \\beta_{\\ell}\\in \\mathbb{R}^D$ or $\\gamma_{\\ell}, \\beta_{\\ell}\\in \\mathbb{R}^1$"
726 | ]
727 | },
728 | {
729 | "cell_type": "code",
730 | "execution_count": null,
731 | "metadata": {
732 | "id": "Wdd9c_sk9GTC"
733 | },
734 | "outputs": [],
735 | "source": [
736 | "class SLN(nn.Module):\n",
737 | " def __init__(self, input_size, parameter_size=None, **kwargs):\n",
738 | " super().__init__()\n",
739 | " if parameter_size == None:\n",
740 | " parameter_size = input_size\n",
741 | " assert(input_size == parameter_size or parameter_size == 1)\n",
742 | " self.input_size = input_size\n",
743 | " self.parameter_size = parameter_size\n",
744 | " self.ln = nn.LayerNorm(input_size)\n",
745 | " self.gamma = FullyConnectedLayer(input_size, parameter_size, bias=False)\n",
746 | " self.beta = FullyConnectedLayer(input_size, parameter_size, bias=False)\n",
747 | " # self.gamma = nn.Linear(input_size, parameter_size, bias=False)\n",
748 | " # self.beta = nn.Linear(input_size, parameter_size, bias=False)\n",
749 | "\n",
750 | " def forward(self, hidden, w):\n",
751 | " assert(hidden.size(-1) == self.parameter_size and w.size(-1) == self.parameter_size)\n",
752 | " gamma = self.gamma(w).unsqueeze(1)\n",
753 | " beta = self.beta(w).unsqueeze(1)\n",
754 | " ln = self.ln(hidden)\n",
755 | " return gamma * ln + beta"
756 | ]
757 | },
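    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# SLN shape check (`sln_demo` is a throwaway module): the style vector w produces\n",
    | "# per-sample gamma and beta that modulate the LayerNorm-ed hidden states.\n",
    | "sln_demo = SLN(hidden_size).to(device)\n",
    | "h_demo = torch.randn(2, 64, hidden_size).to(device)\n",
    | "w_demo = torch.randn(2, hidden_size).to(device)\n",
    | "print(sln_demo(h_demo, w_demo).shape)  # torch.Size([2, 64, 384])\n",
    | "del sln_demo, h_demo, w_demo"
    | ]
    | },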
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "metadata": {
762 | "id": "YvIYYkHe9J3f"
763 | },
764 | "outputs": [],
765 | "source": [
766 | "class GeneratorTransformerEncoderBlock(nn.Module):\n",
767 | " def __init__(self,\n",
768 | " hidden_size=384,\n",
769 | " sln_paremeter_size=384,\n",
770 | " drop_p=0.,\n",
771 | " forward_expansion=4,\n",
772 | " forward_drop_p=0.,\n",
773 | " **kwargs):\n",
774 | " super().__init__()\n",
775 | " self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)\n",
776 | " self.msa = MultiHeadAttention(hidden_size, **kwargs)\n",
777 | " self.dropout = nn.Dropout(drop_p)\n",
778 | " self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, drop_p=forward_drop_p)\n",
779 | "\n",
780 | " def forward(self, hidden, w):\n",
781 | " res = hidden\n",
782 | " hidden = self.sln(hidden, w)\n",
783 | " hidden = self.msa(hidden)\n",
784 | " hidden = self.dropout(hidden)\n",
785 | " hidden += res\n",
786 | "\n",
787 | " res = hidden\n",
788 | " hidden = self.sln(hidden, w)\n",
789 | " self.feed_forward(hidden)\n",
790 | " hidden = self.dropout(hidden)\n",
791 | " hidden += res\n",
792 | " return hidden"
793 | ]
794 | },
795 | {
796 | "cell_type": "code",
797 | "execution_count": null,
798 | "metadata": {
799 | "id": "xXueAJqx9LsE"
800 | },
801 | "outputs": [],
802 | "source": [
803 | "class GeneratorTransformerEncoder(nn.Module):\n",
804 | " def __init__(self, depth=4, **kwargs):\n",
805 | " super().__init__()\n",
806 | " self.depth = depth\n",
807 | " self.blocks = nn.ModuleList([GeneratorTransformerEncoderBlock(**kwargs) for _ in range(depth)])\n",
808 | " \n",
809 | " def forward(self, hidden, w):\n",
810 | " for i in range(self.depth):\n",
811 | " hidden = self.blocks[i](hidden, w)\n",
812 | " return hidden"
813 | ]
814 | },
815 | {
816 | "cell_type": "markdown",
817 | "metadata": {
818 | "id": "BiQtj0Qa9b7L"
819 | },
820 | "source": [
821 | "# SIREN"
822 | ]
823 | },
824 | {
825 | "cell_type": "markdown",
826 | "metadata": {
827 | "id": "qH09oAfX481V"
828 | },
829 | "source": [
830 | "Code for SIREN is taken from [SIREN reference colab notebook](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)"
831 | ]
832 | },
833 | {
834 | "cell_type": "markdown",
835 | "metadata": {
836 | "id": "EOD98u3DzTow"
837 | },
838 | "source": [
839 | "$$\n",
840 | "w^{'}_{ijk}=s_i\\cdot w_{ijk}\n",
841 | "$$\n",
842 | "\n",
843 | "$$\n",
844 | "w^{''}_{ijk}=\\frac{w^{'}_{ijk}}{\\sqrt{\\sum_{i,k}{w^{'}_{ijk}}^2+\\epsilon}}\n",
845 | "$$"
846 | ]
847 | },
848 | {
849 | "cell_type": "code",
850 | "execution_count": null,
851 | "metadata": {
852 | "id": "bm7OIWD5Y7Up"
853 | },
854 | "outputs": [],
855 | "source": [
856 | "class ModulatedLinear(nn.Module):\n",
857 | " def __init__(self, in_channels, out_channels, style_size, bias=False, demodulation=True, **kwargs):\n",
858 | " super().__init__()\n",
859 | " self.in_channels = in_channels\n",
860 | " self.out_channels = out_channels\n",
861 | " self.style_size = style_size\n",
862 | " self.scale = 1 / np.sqrt(in_channels)\n",
863 | " self.weight = nn.Parameter(\n",
864 | " torch.randn(1, out_channels, in_channels, 1)\n",
865 | " )\n",
866 | " self.modulation = None\n",
867 | " if self.style_size != self.in_channels:\n",
868 | " self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False)\n",
869 | " self.demodulation = demodulation\n",
870 | "\n",
871 | " def forward(self, input, style):\n",
872 | " batch_size = input.shape[0]\n",
873 | "\n",
874 | " if self.style_size != self.in_channels:\n",
875 | " style = self.modulation(style)\n",
876 | " style = style.view(batch_size, 1, self.in_channels, 1)\n",
877 | " # print('self.scale, self.weight.shape, style.shape', self.scale, self.weight.shape, style.shape)\n",
878 | " weight = self.scale * self.weight * style\n",
879 | "\n",
880 | " if self.demodulation:\n",
881 | " demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)\n",
882 | " weight = weight * demod.view(batch_size, self.out_channels, 1, 1)\n",
883 | "\n",
884 | " weight = weight.view(\n",
885 | " batch_size * self.out_channels, self.in_channels, 1\n",
886 | " )\n",
887 | " \n",
888 | " img_size = input.size(1)\n",
889 | " input = input.reshape(1, batch_size * self.in_channels, img_size)\n",
890 | " out = F.conv1d(input, weight, groups=batch_size)\n",
891 | " out = out.view(batch_size, img_size, self.out_channels)\n",
892 | "\n",
893 | " return out"
894 | ]
895 | },
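    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# Weight-modulation shape check (`mod_demo` is a throwaway layer): a per-sample\n",
    | "# style scales the shared weight (s_i * w_ijk) and is then demodulated as above;\n",
    | "# inputs and outputs are (batch, positions, channels).\n",
    | "mod_demo = ModulatedLinear(in_channels=8, out_channels=8, style_size=8)\n",
    | "print(mod_demo(torch.randn(2, 16, 8), torch.randn(2, 8)).shape)  # torch.Size([2, 16, 8])\n",
    | "del mod_demo"
    | ]
    | },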
896 | {
897 | "cell_type": "code",
898 | "execution_count": null,
899 | "metadata": {
900 | "id": "gJv0hR38RBFC"
901 | },
902 | "outputs": [],
903 | "source": [
904 | "class ResLinear(nn.Module):\n",
905 | " def __init__(self, in_channels, out_channels, style_size, bias=False, **kwargs):\n",
906 | " super().__init__()\n",
907 | " self.linear = FullyConnectedLayer(in_channels, out_channels, bias=False)\n",
908 | " self.style = FullyConnectedLayer(style_size, in_channels, bias=False)\n",
909 | " self.in_channels = in_channels\n",
910 | " self.out_channels = out_channels\n",
911 | " self.style_size = style_size\n",
912 | " # print('style_size, in_channels, out_channels', style_size, in_channels, out_channels)\n",
913 | "\n",
914 | " def forward(self, input, style):\n",
915 | " x = input + self.style(style).unsqueeze(1)\n",
916 | " x = self.linear(x)\n",
917 | " return x"
918 | ]
919 | },
920 | {
921 | "cell_type": "code",
922 | "execution_count": null,
923 | "metadata": {
924 | "id": "930yhZ6zgPI0"
925 | },
926 | "outputs": [],
927 | "source": [
928 | "class ConLinear(nn.Module):\n",
929 | " def __init__(self, ch_in, ch_out, is_first=False, bias=True, **kwargs):\n",
930 | " super(ConLinear, self).__init__()\n",
931 | " self.conv = nn.Linear(ch_in, ch_out, bias=bias)\n",
932 | " if is_first:\n",
933 | " nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in))\n",
934 | " else:\n",
935 | " nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in))\n",
936 | "\n",
937 | " def forward(self, x):\n",
938 | " return self.conv(x)\n",
939 | "\n",
940 | "class SinActivation(nn.Module):\n",
941 | " def __init__(self):\n",
942 | " super(SinActivation, self).__init__()\n",
943 | "\n",
944 | " def forward(self, x):\n",
945 | " return torch.sin(x)\n",
946 | "\n",
947 | "class LFF(nn.Module):\n",
948 | " def __init__(self, hidden_size, **kwargs):\n",
949 | " super(LFF, self).__init__()\n",
950 | " self.ffm = ConLinear(2, hidden_size, is_first=True)\n",
951 | " self.activation = SinActivation()\n",
952 | "\n",
953 | " def forward(self, x):\n",
954 | " x = x\n",
955 | " x = self.ffm(x)\n",
956 | " x = self.activation(x)\n",
957 | " return x"
958 | ]
959 | },
960 | {
961 | "cell_type": "code",
962 | "execution_count": null,
963 | "metadata": {
964 | "id": "wgjY_FpH481X"
965 | },
966 | "outputs": [],
967 | "source": [
968 | "class SineLayer(nn.Module):\n",
969 | " def __init__(self, in_features, out_features, style_size, bias=False,\n",
970 | " is_first=False, omega_0=30, weight_modulation=True, **kwargs):\n",
971 | " super().__init__()\n",
972 | " self.omega_0 = omega_0\n",
973 | " self.is_first = is_first\n",
974 | " \n",
975 | " self.in_features = in_features\n",
976 | " self.weight_modulation = weight_modulation\n",
977 | " if weight_modulation:\n",
978 | " self.linear = ModulatedLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)\n",
979 | " else:\n",
980 | " self.linear = ResLinear(in_features, out_features, style_size=style_size, bias=bias, **kwargs)\n",
981 | " # print('in_features, out_features, style_size', in_features, out_features, style_size)\n",
982 | " self.init_weights()\n",
983 | " \n",
984 | " def init_weights(self):\n",
985 | " with torch.no_grad():\n",
986 | " if self.is_first:\n",
987 | " if self.weight_modulation:\n",
988 | " self.linear.weight.uniform_(-1 / self.in_features, \n",
989 | " 1 / self.in_features)\n",
990 | " else:\n",
991 | " self.linear.linear.weight.uniform_(-1 / self.in_features, \n",
992 | " 1 / self.in_features) \n",
993 | " else:\n",
994 | " if self.weight_modulation:\n",
995 | " self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n",
996 | " np.sqrt(6 / self.in_features) / self.omega_0)\n",
997 | " else:\n",
998 | " self.linear.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n",
999 | " np.sqrt(6 / self.in_features) / self.omega_0) \n",
1000 | " \n",
1001 | " def forward(self, input, style):\n",
1002 | " return torch.sin(self.omega_0 * self.linear(input, style))\n",
1003 | " \n",
1004 | "class Siren(nn.Module):\n",
1005 | " def __init__(self, in_features, hidden_size, hidden_layers, out_features, style_size, outermost_linear=False, \n",
1006 | " first_omega_0=30, hidden_omega_0=30., weight_modulation=True, bias=False, **kwargs):\n",
1007 | " super().__init__()\n",
1008 | " \n",
1009 | " self.net = []\n",
1010 | " self.net.append(SineLayer(in_features, hidden_size, style_size,\n",
1011 | " is_first=True, omega_0=first_omega_0,\n",
1012 | " weight_modulation=weight_modulation, **kwargs))\n",
1013 | "\n",
1014 | " for i in range(hidden_layers):\n",
1015 | " self.net.append(SineLayer(hidden_size, hidden_size, style_size,\n",
1016 | " is_first=False, omega_0=hidden_omega_0,\n",
1017 | " weight_modulation=weight_modulation, **kwargs))\n",
1018 | "\n",
1019 | " if outermost_linear:\n",
1020 | " if weight_modulation:\n",
1021 | " final_linear = ModulatedLinear(hidden_size, out_features,\n",
1022 | " style_size=style_size, bias=bias, **kwargs)\n",
1023 | " else:\n",
1024 | " final_linear = ResLinear(hidden_size, out_features, style_size=style_size, bias=bias, **kwargs)\n",
1025 | " # FullyConnectedLayer(hidden_size, out_features, bias=False)\n",
1026 | " # final_linear = nn.Linear(hidden_size, out_features)\n",
1027 | " \n",
1028 | " with torch.no_grad():\n",
1029 | " if weight_modulation:\n",
1030 | " final_linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, \n",
1031 | " np.sqrt(6 / hidden_size) / hidden_omega_0)\n",
1032 | " else:\n",
1033 | " final_linear.linear.weight.uniform_(-np.sqrt(6 / hidden_size) / hidden_omega_0, \n",
1034 | " np.sqrt(6 / hidden_size) / hidden_omega_0)\n",
1035 | " \n",
1036 | " self.net.append(final_linear)\n",
1037 | " else:\n",
1038 | " self.net.append(SineLayer(hidden_size, out_features, \n",
1039 | " is_first=False, omega_0=hidden_omega_0,\n",
1040 | " weight_modulation=weight_modulation, **kwargs))\n",
1041 | " \n",
1042 | " self.net = nn.Sequential(*self.net)\n",
1043 | " \n",
1044 | " def forward(self, coords, style):\n",
1045 | " coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input\n",
1046 | " # output = self.net(coords, style)\n",
1047 | " output = coords\n",
1048 | " for layer in self.net:\n",
1049 | " output = layer(output, style)\n",
1050 | " return output"
1051 | ]
1052 | },
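    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# End-to-end SIREN check (`siren_demo` and its inputs are throwaway): coordinate\n",
    | "# embeddings go in, a per-sample style vector modulates the layers, RGB comes out.\n",
    | "siren_demo = Siren(in_features=hidden_size, hidden_size=hidden_size, hidden_layers=1,\n",
    | "                   out_features=3, style_size=hidden_size, outermost_linear=True)\n",
    | "print(siren_demo(torch.randn(2, 16, hidden_size), torch.randn(2, hidden_size)).shape)  # torch.Size([2, 16, 3])\n",
    | "del siren_demo"
    | ]
    | },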
1053 | {
1054 | "cell_type": "markdown",
1055 | "metadata": {
1056 | "id": "V8DB73-TBGEs"
1057 | },
1058 | "source": [
1059 | "$$\n",
1060 | "Fou(\\mathbf{v})= \\left[ \\cos(2 \\pi \\mathbf B \\mathbf{v}), \\sin(2 \\pi \\mathbf B \\mathbf{v}) \\right]^\\mathrm{T}\n",
1061 | "$$"
1062 | ]
1063 | },
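    | {
    | "cell_type": "code",
    | "execution_count": null,
    | "metadata": {},
    | "outputs": [],
    | "source": [
    | "# The generator realizes this embedding with LFF: a learned affine map of the\n",
    | "# (x, y) grid followed by sin (a single sin branch rather than an explicit\n",
    | "# cos/sin pair). Quick shape check on a small 4x4 grid:\n",
    | "grid_demo = np.stack(np.meshgrid(np.linspace(-1, 1, 4), np.linspace(-1, 1, 4)), -1)\n",
    | "grid_demo = torch.tensor(grid_demo, dtype=torch.float)\n",
    | "print(LFF(hidden_size)(grid_demo).shape)  # torch.Size([4, 4, 384])\n",
    | "del grid_demo"
    | ]
    | },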
1064 | {
1065 | "cell_type": "code",
1066 | "execution_count": null,
1067 | "metadata": {
1068 | "id": "5ru0trNs9N3F"
1069 | },
1070 | "outputs": [],
1071 | "source": [
1072 | "class GeneratorViT(nn.Module):\n",
1073 | " def __init__(self,\n",
1074 | " style_mlp_layers=8,\n",
1075 | " patch_size=4,\n",
1076 | " latent_dim=32,\n",
1077 | " hidden_size=384,\n",
1078 | " sln_paremeter_size=1,\n",
1079 | " image_size=32,\n",
1080 | " depth=4,\n",
1081 | " combine_patch_embeddings=False,\n",
1082 | " combined_embedding_size=1024,\n",
1083 | " forward_drop_p=0.,\n",
1084 | " bias=False,\n",
1085 | " out_features=3,\n",
1086 | " weight_modulation=True,\n",
1087 | " siren_hidden_layers=1,\n",
1088 | " **kwargs):\n",
1089 | " super().__init__()\n",
1090 | " self.hidden_size = hidden_size\n",
1091 | "\n",
1092 | " self.mlp = MappingNetwork(z_dim=latent_dim, c_dim=0, w_dim=hidden_size, num_layers=style_mlp_layers, w_avg_beta=None)\n",
1093 | "\n",
1094 | " num_patches = int(image_size//patch_size)**2\n",
1095 | " self.patch_size = patch_size\n",
1096 | " self.num_patches = num_patches\n",
1097 | " self.image_size = image_size\n",
1098 | " self.combine_patch_embeddings = combine_patch_embeddings\n",
1099 | " self.combined_embedding_size = combined_embedding_size\n",
1100 | "\n",
1101 | " self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))\n",
1102 | " self.transformer_encoder = GeneratorTransformerEncoder(depth,\n",
1103 | " hidden_size=hidden_size,\n",
1104 | " sln_paremeter_size=sln_paremeter_size,\n",
1105 | " drop_p=forward_drop_p,\n",
1106 | " forward_drop_p=forward_drop_p,\n",
1107 | " **kwargs)\n",
1108 | " self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)\n",
1109 | " if combine_patch_embeddings:\n",
1110 | " self.to_single_emb = nn.Sequential(\n",
1111 | " FullyConnectedLayer(num_patches*hidden_size, combined_embedding_size, bias=bias, activation='gelu'),\n",
1112 | " nn.Dropout(forward_drop_p),\n",
1113 | " )\n",
1114 | "\n",
1115 | " self.lff = LFF(self.hidden_size)\n",
1116 | "\n",
1117 | " self.siren_in_features = combined_embedding_size if combine_patch_embeddings else self.hidden_size\n",
1118 | " self.siren = Siren(in_features=self.siren_in_features, out_features=out_features,\n",
1119 | " style_size=self.siren_in_features, hidden_size=self.hidden_size, bias=bias,\n",
1120 | " hidden_layers=siren_hidden_layers, outermost_linear=True, weight_modulation=weight_modulation, **kwargs)\n",
1121 | "\n",
1122 | " self.num_patches_x = int(image_size//out_patch_size)\n",
1123 | "\n",
1124 | "\n",
1125 | " def fourier_input_mapping(self, x):\n",
1126 | " return self.lff(x)\n",
1127 | "\n",
1128 | " def fourier_pos_embedding(self):\n",
1129 | " # Create input pixel coordinates in the unit square\n",
1130 | " coords = np.linspace(-1, 1, out_patch_size, endpoint=True)\n",
1131 | " pos = np.stack(np.meshgrid(coords, coords), -1)\n",
1132 | " pos = torch.tensor(pos, dtype=torch.float).to(device)\n",
1133 | " result = self.fourier_input_mapping(pos).reshape([out_patch_size**2, self.hidden_size])\n",
1134 | " return result.to(device)\n",
1135 | "\n",
1136 | " def mix_hidden_and_pos(self, hidden):\n",
1137 | " pos = self.fourier_pos_embedding()\n",
1138 | "\n",
1139 | " pos = repeat(pos, 'p h -> n p h', n = hidden.shape[0])\n",
1140 | "\n",
1141 | " return result\n",
1142 | "\n",
1143 | " def forward(self, z):\n",
1144 | " w = self.mlp(z)\n",
1145 | " pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])\n",
1146 | " hidden = self.transformer_encoder(pos, w)\n",
1147 | "\n",
1148 | " if self.combine_patch_embeddings:\n",
1149 | " # Output [batch_size, combined_embedding_size]\n",
1150 | " hidden = self.sln(hidden, w).view((z.shape[0], -1))\n",
1151 | " hidden = self.to_single_emb(hidden)\n",
1152 | " else:\n",
1153 | " # Output [batch_size*num_patches, hidden_size]\n",
1154 | " hidden = self.sln(hidden, w).view((-1, self.hidden_size))\n",
1155 | " \n",
1156 | " pos = self.mix_hidden_and_pos(hidden)\n",
1157 | "\n",
1158 | " # hidden = repeat(hidden, 'n h -> n p h', p = out_patch_size**2)\n",
1159 | "\n",
1160 | " result = self.siren(pos, hidden)\n",
1161 | "\n",
1162 | " model_output_1 = result.view([-1, self.num_patches_x, self.num_patches_x, out_patch_size, out_patch_size, out_features])\n",
1163 | " model_output_2 = model_output_1.permute([0, 1, 3, 2, 4, 5])\n",
1164 | " model_output = model_output_2.reshape([-1, image_size**2, out_features])\n",
1165 | " \n",
1166 | " return model_output\n",
1167 | "\n",
1168 | "\n",
1169 | "Generator = GeneratorViT( patch_size=patch_size,\n",
1170 | " image_size=image_size,\n",
1171 | " style_mlp_layers=style_mlp_layers,\n",
1172 | " latent_dim=latent_dim,\n",
1173 | " hidden_size=hidden_size,\n",
1174 | " combine_patch_embeddings=combine_patch_embeddings,\n",
1175 | " combined_embedding_size=combined_embedding_size,\n",
1176 | " sln_paremeter_size=sln_paremeter_size,\n",
1177 | " num_heads=num_heads,\n",
1178 | " depth=depth,\n",
1179 | " forward_drop_p=dropout_p,\n",
1180 | " bias=bias,\n",
1181 | " weight_modulation=weight_modulation,\n",
1182 | " siren_hidden_layers=siren_hidden_layers,\n",
1183 | " demodulation=demodulation,\n",
1184 | " ).to(device)\n",
1185 | "print(Generator(torch.randn([batch_size, latent_dim]).to(device)).shape)\n",
1186 | "# print(Generator)\n",
1187 | "del Generator"
1188 | ]
1189 | },
1190 | {
1191 | "cell_type": "markdown",
1192 | "metadata": {
1193 | "id": "NqMxvqCjb2fp"
1194 | },
1195 | "source": [
1196 | "# CNN Generator"
1197 | ]
1198 | },
1199 | {
1200 | "cell_type": "code",
1201 | "execution_count": null,
1202 | "metadata": {
1203 | "id": "D7e7XAjvb5pX"
1204 | },
1205 | "outputs": [],
1206 | "source": [
1207 | "class CNNGenerator(nn.Module):\n",
1208 | " def __init__(self):\n",
1209 | " super(CNNGenerator, self).__init__()\n",
1210 | " self.w = nn.Linear(latent_dim, hidden_size * 2 * 4 * 4, bias=False)\n",
1211 | " self.main = nn.Sequential(\n",
1212 | " # input is Z, going into a convolution\n",
1213 | " nn.BatchNorm2d(hidden_size * 2),\n",
1214 | " nn.ReLU(True),\n",
1215 | " # state size. (ngf*8) x 4 x 4\n",
1216 | " nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),\n",
1217 | " nn.BatchNorm2d(hidden_size),\n",
1218 | " nn.ReLU(True),\n",
1219 | " # state size. (ngf*4) x 8 x 8\n",
1220 | " nn.ConvTranspose2d( hidden_size, hidden_size // 2, 4, 2, 1, bias=False),\n",
1221 | " nn.BatchNorm2d(hidden_size // 2),\n",
1222 | " nn.ReLU(True),\n",
1223 | " # state size. (ngf*2) x 16 x 16\n",
1224 | " nn.ConvTranspose2d( hidden_size // 2, hidden_size // 4, 4, 2, 1, bias=False),\n",
1225 | " nn.BatchNorm2d(hidden_size // 4),\n",
1226 | " nn.ReLU(True),\n",
1227 | " # state size. (ngf*2) x 32 x 32\n",
1228 | " nn.ConvTranspose2d( hidden_size // 4, 3, 3, 1, 1, bias=False),\n",
1229 | " nn.Tanh(),\n",
1230 | " # state size. (nc) x 64 x 64\n",
1231 | " )\n",
1232 | "\n",
1233 | " def forward(self, input):\n",
1234 | " input = self.w(input).view((-1, hidden_size * 2, 4, 4))\n",
1235 | " return self.main(input)"
1236 | ]
1237 | },
1238 | {
1239 | "cell_type": "markdown",
1240 | "metadata": {
1241 | "id": "t1Bz035Ugroo"
1242 | },
1243 | "source": [
1244 | "# Discriminator"
1245 | ]
1246 | },
1247 | {
1248 | "cell_type": "code",
1249 | "execution_count": null,
1250 | "metadata": {
1251 | "id": "wY_LoVXOgtao"
1252 | },
1253 | "outputs": [],
1254 | "source": [
1255 | "class PatchEmbedding(nn.Module):\n",
1256 | " def __init__(self, in_channels=3, patch_size=4, stride_size=4, emb_size=384, image_size=32, batch_size=64):\n",
1257 | " super().__init__()\n",
1258 | " self.patch_size = patch_size\n",
1259 | " self.projection = nn.Sequential(\n",
1260 | " # using a conv layer instead of a linear one -> performance gains\n",
1261 | " spectral_norm(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)).to(device),\n",
1262 | " Rearrange('b e (h) (w) -> b (h w) e'),\n",
1263 | " )\n",
1264 | " self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))\n",
1265 | " num_patches = ((image_size-patch_size+stride_size) // stride_size) **2 + 1\n",
1266 | " self.positions = nn.Parameter(torch.randn(num_patches, emb_size))\n",
1267 | " self.batch_size = batch_size\n",
1268 | "\n",
1269 | " def forward(self, x):\n",
1270 | " b, _, _, _ = x.shape\n",
1271 | " x = self.projection(x)\n",
1272 | " cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)\n",
1273 | " # prepend the cls token to the input\n",
1274 | " x = torch.cat([cls_tokens, x], dim=1)\n",
1275 | " # add position embedding\n",
1276 | " x += torch.sin(self.positions)\n",
1277 | " return x"
1278 | ]
1279 | },
1280 | {
1281 | "cell_type": "code",
1282 | "execution_count": null,
1283 | "metadata": {
1284 | "id": "D8DZy_pOg0ID"
1285 | },
1286 | "outputs": [],
1287 | "source": [
1288 | "class ResidualAdd(nn.Module):\n",
1289 | " def __init__(self, fn):\n",
1290 | " super().__init__()\n",
1291 | " self.fn = fn\n",
1292 | " \n",
1293 | " def forward(self, x, **kwargs):\n",
1294 | " res = x\n",
1295 | " x = self.fn(x, **kwargs)\n",
1296 | " x += res\n",
1297 | " return x"
1298 | ]
1299 | },
1300 | {
1301 | "cell_type": "code",
1302 | "execution_count": null,
1303 | "metadata": {
1304 | "id": "BkfQ73pSgwQo"
1305 | },
1306 | "outputs": [],
1307 | "source": [
1308 | "class DiscriminatorTransformerEncoderBlock(nn.Sequential):\n",
1309 | " def __init__(self,\n",
1310 | " emb_size=384,\n",
1311 | " drop_p=0.,\n",
1312 | " forward_expansion=4,\n",
1313 | " forward_drop_p=0.,\n",
1314 | " **kwargs):\n",
1315 | " super().__init__(\n",
1316 | " ResidualAdd(nn.Sequential(\n",
1317 | " nn.LayerNorm(emb_size),\n",
1318 | " MultiHeadAttention(emb_size, **kwargs),\n",
1319 | " nn.Dropout(drop_p)\n",
1320 | " )),\n",
1321 | " ResidualAdd(nn.Sequential(\n",
1322 | " nn.LayerNorm(emb_size),\n",
1323 | " nn.Sequential(\n",
1324 | " spectral_norm(nn.Linear(emb_size, forward_expansion * emb_size)),\n",
1325 | " nn.GELU(),\n",
1326 | " nn.Dropout(forward_drop_p),\n",
1327 | " spectral_norm(nn.Linear(forward_expansion * emb_size, emb_size)),\n",
1328 | " ),\n",
1329 | " nn.Dropout(drop_p)\n",
1330 | " )\n",
1331 | " ))"
1332 | ]
1333 | },
1334 | {
1335 | "cell_type": "code",
1336 | "execution_count": null,
1337 | "metadata": {
1338 | "id": "F6IYQuoNg9EA"
1339 | },
1340 | "outputs": [],
1341 | "source": [
1342 | "class DiscriminatorTransformerEncoder(nn.Sequential):\n",
1343 | " def __init__(self, depth=4, **kwargs):\n",
1344 | " super().__init__(*[DiscriminatorTransformerEncoderBlock(**kwargs) for _ in range(depth)])\n",
1345 | "\n",
1346 | "class ClassificationHead(nn.Sequential):\n",
1347 | " def __init__(self, emb_size=384, class_size_1=4098, class_size_2=1024, class_size_3=512, n_classes=10):\n",
1348 | " super().__init__(\n",
1349 | " nn.LayerNorm(emb_size),\n",
1350 | " spectral_norm(nn.Linear(emb_size, class_size_1)),\n",
1351 | " nn.GELU(),\n",
1352 | " spectral_norm(nn.Linear(class_size_1, class_size_2)),\n",
1353 | " nn.GELU(),\n",
1354 | " spectral_norm(nn.Linear(class_size_2, class_size_3)),\n",
1355 | " nn.GELU(),\n",
1356 | " spectral_norm(nn.Linear(class_size_3, n_classes)),\n",
1357 | " nn.GELU(),\n",
1358 | " )\n",
1359 | "\n",
1360 | " def forward(self, x):\n",
1361 | " # Take only the cls token outputs\n",
1362 | " x = x[:, 0, :]\n",
1363 | " return super().forward(x)"
1364 | ]
1365 | },
1366 | {
1367 | "cell_type": "code",
1368 | "execution_count": null,
1369 | "metadata": {
1370 | "id": "5oALpYkMhCGV"
1371 | },
1372 | "outputs": [],
1373 | "source": [
1374 | "class ViT(nn.Sequential):\n",
1375 | " def __init__(self, \n",
1376 | " in_channels=3,\n",
1377 | " patch_size=4,\n",
1378 | " stride_size=4,\n",
1379 | " emb_size=384,\n",
1380 | " image_size=32,\n",
1381 | " depth=4,\n",
1382 | " n_classes=1,\n",
1383 | " diffaugment='color,translation,cutout',\n",
1384 | " **kwargs):\n",
1385 | " self.diffaugment = diffaugment\n",
1386 | " super().__init__(\n",
1387 | " PatchEmbedding(in_channels, patch_size, stride_size, emb_size, image_size),\n",
1388 | " DiscriminatorTransformerEncoder(depth, emb_size=emb_size, **kwargs),\n",
1389 | " ClassificationHead(emb_size, n_classes=n_classes)\n",
1390 | " )\n",
1391 | " \n",
1392 | " def forward(self, img, do_augment=True):\n",
1393 | " if do_augment:\n",
1394 | " img = DiffAugment(img, policy=self.diffaugment)\n",
1395 | " return super().forward(img)"
1396 | ]
1397 | },
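1398 | {
1399 | "cell_type": "code",
1400 | "execution_count": null,
1401 | "metadata": {},
1402 | "outputs": [],
1403 | "source": [
1404 | "# A minimal usage sketch of the ViT discriminator (assumes `device` from the\n",
1405 | "# setup cells; do_augment=False skips DiffAugment): one logit per image.\n",
1406 | "vit_d = ViT(patch_size=8, stride_size=4, n_classes=1, num_heads=4, depth=4).to(device)\n",
1407 | "logits = vit_d(torch.randn(4, 3, 32, 32, device=device), do_augment=False)\n",
1408 | "assert logits.shape == (4, 1), logits.shape"
1409 | ]
1410 | },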
1398 | {
1399 | "cell_type": "markdown",
1400 | "metadata": {
1401 | "id": "Wu4HNFgpaLZ2"
1402 | },
1403 | "source": [
1404 | "# CNN Discriminator"
1405 | ]
1406 | },
1407 | {
1408 | "cell_type": "code",
1409 | "execution_count": null,
1410 | "metadata": {
1411 | "id": "uyj4W0kQLJ2j"
1412 | },
1413 | "outputs": [],
1414 | "source": [
1415 | "class CNN(nn.Sequential):\n",
1416 | " def __init__(self,\n",
1417 | " diffaugment='color,translation,cutout',\n",
1418 | " **kwargs):\n",
1419 | " self.diffaugment = diffaugment\n",
1420 | " super().__init__(\n",
1421 | " nn.Conv2d(3,32,kernel_size=3,padding=1),\n",
1422 | " nn.ReLU(),\n",
1423 | " nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),\n",
1424 | " nn.ReLU(),\n",
1425 | " nn.MaxPool2d(2,2),\n",
1426 | "\n",
1427 | " nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),\n",
1428 | " nn.ReLU(),\n",
1429 | " nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),\n",
1430 | " nn.ReLU(),\n",
1431 | " nn.MaxPool2d(2,2),\n",
1432 | "\n",
1433 | " nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),\n",
1434 | " nn.ReLU(),\n",
1435 | " nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),\n",
1436 | " nn.ReLU(),\n",
1437 | " nn.MaxPool2d(2,2),\n",
1438 | "\n",
1439 | " nn.Flatten(),\n",
1440 | " nn.Linear(256*4*4,1024),\n",
1441 | " nn.ReLU(),\n",
1442 | " nn.Linear(1024,512),\n",
1443 | " nn.ReLU(),\n",
1444 | " nn.Linear(512,1)\n",
1445 | " )\n",
1446 | " \n",
1447 | " def forward(self, img, do_augment=True):\n",
1448 | " if do_augment:\n",
1449 | " img = DiffAugment(img, policy=self.diffaugment)\n",
1450 | " return super().forward(img)"
1451 | ]
1452 | },
1453 | {
1454 | "cell_type": "markdown",
1455 | "metadata": {
1456 | "id": "wbniFkThhR5T"
1457 | },
1458 | "source": [
1459 | "# StyleGAN2 Discriminator"
1460 | ]
1461 | },
1462 | {
1463 | "cell_type": "code",
1464 | "execution_count": null,
1465 | "metadata": {
1466 | "id": "ZW_m-p4pg34m"
1467 | },
1468 | "outputs": [],
1469 | "source": [
1470 | "class StyleGanDiscriminator(stylegan2_pytorch.Discriminator):\n",
1471 | " def __init__(self,\n",
1472 | " diffaugment='color,translation,cutout',\n",
1473 | " **kwargs):\n",
1474 | " self.diffaugment = diffaugment\n",
1475 | " super().__init__(**kwargs)\n",
1476 | " def forward(self, img, do_augment=True):\n",
1477 | " if do_augment:\n",
1478 | " img = DiffAugment(img, policy=self.diffaugment)\n",
1479 | " out, _ = super().forward(img)\n",
1480 | " return out"
1481 | ]
1482 | },
1483 | {
1484 | "cell_type": "markdown",
1485 | "metadata": {
1486 | "id": "WsKPASQ3XtSU"
1487 | },
1488 | "source": [
1489 | "# Diversity Loss"
1490 | ]
1491 | },
1492 | {
1493 | "cell_type": "code",
1494 | "execution_count": null,
1495 | "metadata": {
1496 | "id": "SB52QuhSWhRk"
1497 | },
1498 | "outputs": [],
1499 | "source": [
1500 | "def diversity_loss(images):\n",
1501 | " num_images_to_calculate_on = 10\n",
1502 | " num_pairs = num_images_to_calculate_on * (num_images_to_calculate_on - 1) // 2\n",
1503 | "\n",
1504 | " scale_factor = 5\n",
1505 | "\n",
1506 | " loss = torch.zeros(1, dtype=torch.float, device=device, requires_grad=True)\n",
1507 | " i = 0\n",
1508 | " for a_id in range(num_images_to_calculate_on):\n",
1509 | " for b_id in range(a_id+1, num_images_to_calculate_on):\n",
1510 | " img_a = images[a_id]\n",
1511 | " img_b = images[b_id]\n",
1512 | " img_a_l2 = torch.norm(img_a)\n",
1513 | " img_b_l2 = torch.norm(img_b)\n",
1514 | " img_a, img_b = img_a.flatten(), img_b.flatten()\n",
1515 | "\n",
1516 | " # print(img_a_l2, img_b_l2, img_a.shape, img_b.shape)\n",
1517 | "\n",
1518 | " a_b_loss = scale_factor * (img_a.t() @ img_b) / (img_a_l2 * img_b_l2)\n",
1519 | " # print(a_b_loss)\n",
1520 | " loss = loss + torch.sigmoid(a_b_loss)\n",
1521 | " i += 1\n",
1522 | " loss = loss.sum() / num_pairs\n",
1523 | " return loss"
1524 | ]
1525 | },
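1526 | {
1527 | "cell_type": "code",
1528 | "execution_count": null,
1529 | "metadata": {},
1530 | "outputs": [],
1531 | "source": [
1532 | "# A vectorized equivalent of the pairwise loop above (a sketch, not wired into\n",
1533 | "# the training loop; assumes the batch holds at least n images): sigmoid of\n",
1534 | "# scaled cosine similarity, averaged over all unordered pairs.\n",
1535 | "def diversity_loss_vectorized(images, n=10, scale_factor=5):\n",
1536 | "    flat = images[:n].reshape(n, -1)\n",
1537 | "    flat = flat / flat.norm(dim=1, keepdim=True)\n",
1538 | "    sims = flat @ flat.t()  # n x n cosine similarities\n",
1539 | "    iu = torch.triu_indices(n, n, offset=1)  # the n*(n-1)/2 unordered pairs\n",
1540 | "    return torch.sigmoid(scale_factor * sims[iu[0], iu[1]]).mean()"
1541 | ]
1542 | },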
1526 | {
1527 | "cell_type": "markdown",
1528 | "metadata": {
1529 | "id": "RFkNgPe5J4WZ"
1530 | },
1531 | "source": [
1532 | "# Normal distribution init weight"
1533 | ]
1534 | },
1535 | {
1536 | "cell_type": "code",
1537 | "execution_count": null,
1538 | "metadata": {
1539 | "id": "kg23vO0CGoPr"
1540 | },
1541 | "outputs": [],
1542 | "source": [
1543 | "def init_normal(m):\n",
1544 | " if type(m) == nn.Linear:\n",
1545 | " if 'weight' in m.__dict__.keys():\n",
1546 | " m.weight.data.normal_(0.0,1)"
1547 | ]
1548 | },
1549 | {
1550 | "cell_type": "markdown",
1551 | "metadata": {
1552 | "id": "9Akp8Ybo483q"
1553 | },
1554 | "source": [
1555 | "# Experiments"
1556 | ]
1557 | },
1558 | {
1559 | "cell_type": "code",
1560 | "execution_count": null,
1561 | "metadata": {
1562 | "id": "CS5sfmIuLj8P"
1563 | },
1564 | "outputs": [],
1565 | "source": [
1566 | "if generator_type == \"vitgan\":\n",
1567 | " # Create the Generator\n",
1568 | " Generator = GeneratorViT( patch_size=patch_size,\n",
1569 | " image_size=image_size,\n",
1570 | " style_mlp_layers=style_mlp_layers,\n",
1571 | " latent_dim=latent_dim,\n",
1572 | " hidden_size=hidden_size,\n",
1573 | " combine_patch_embeddings=combine_patch_embeddings,\n",
1574 | " combined_embedding_size=combined_embedding_size,\n",
1575 | " sln_paremeter_size=sln_paremeter_size,\n",
1576 | " num_heads=num_heads,\n",
1577 | " depth=depth,\n",
1578 | " forward_drop_p=dropout_p,\n",
1579 | " bias=bias,\n",
1580 | " weight_modulation=weight_modulation,\n",
1581 | " siren_hidden_layers=siren_hidden_layers,\n",
1582 | " demodulation=demodulation,\n",
1583 | " ).to(device)\n",
1584 | " \n",
1585 | " # use the modules apply function to recursively apply the initialization\n",
1586 | " Generator.apply(init_normal)\n",
1587 | "\n",
1588 | " num_patches_x = int(image_size//out_patch_size)\n",
1589 | "\n",
1590 | " if os.path.exists(f'{experiment_folder_name}/weights/Generator.pth'):\n",
1591 | " Generator = torch.load(f'{experiment_folder_name}/weights/Generator.pth')\n",
1592 | "\n",
1593 | " wandb.watch(Generator)\n",
1594 | "\n",
1595 | "elif generator_type == \"cnn\":\n",
1596 | " cnn_generator = CNNGenerator().to(device)\n",
1597 | "\n",
1598 | " cnn_generator.apply(init_normal)\n",
1599 | "\n",
1600 | " if os.path.exists(f'{experiment_folder_name}/weights/cnn_generator.pth'):\n",
1601 | " cnn_generator = torch.load(f'{experiment_folder_name}/weights/cnn_generator.pth')\n",
1602 | "\n",
1603 | " wandb.watch(cnn_generator)\n",
1604 | "\n",
1605 | "# Create the three types of discriminators\n",
1606 | "if discriminator_type == \"vitgan\":\n",
1607 | " Discriminator = ViT(discriminator=True,\n",
1608 | " patch_size=patch_size*2,\n",
1609 | " stride_size=patch_size,\n",
1610 | " n_classes=1, \n",
1611 | " num_heads=num_heads,\n",
1612 | " depth=depth,\n",
1613 | " forward_drop_p=dropout_p,\n",
1614 | " ).to(device)\n",
1615 | " \n",
1616 | " Discriminator.apply(init_normal)\n",
1617 | " \n",
1618 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n",
1619 | " Discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n",
1620 | "\n",
1621 | " wandb.watch(Discriminator)\n",
1622 | "\n",
1623 | "elif discriminator_type == \"cnn\":\n",
1624 | " cnn_discriminator = CNN().to(device)\n",
1625 | "\n",
1626 | " cnn_discriminator.apply(init_normal)\n",
1627 | "\n",
1628 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n",
1629 | " cnn_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n",
1630 | "\n",
1631 | " wandb.watch(cnn_discriminator)\n",
1632 | "\n",
1633 | "elif discriminator_type == \"stylegan2\":\n",
1634 | " stylegan2_discriminator = StyleGanDiscriminator(image_size=32).to(device)\n",
1635 | "\n",
1636 | " # stylegan2_discriminator.apply(init_normal)\n",
1637 | "\n",
1638 | " if os.path.exists(f'{experiment_folder_name}/weights/discriminator.pth'):\n",
1639 | " stylegan2_discriminator = torch.load(f'{experiment_folder_name}/weights/discriminator.pth')\n",
1640 | "\n",
1641 | " wandb.watch(stylegan2_discriminator)"
1642 | ]
1643 | },
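1644 | {
1645 | "cell_type": "code",
1646 | "execution_count": null,
1647 | "metadata": {},
1648 | "outputs": [],
1649 | "source": [
1650 | "# Note on checkpointing: torch.load on a whole pickled nn.Module (as above) is\n",
1651 | "# brittle across code changes. A state_dict-based variant would look like this\n",
1652 | "# (sketch only; `Generator_state.pth` is a hypothetical filename this notebook\n",
1653 | "# does not write):\n",
1654 | "# torch.save(Generator.state_dict(), f'{experiment_folder_name}/weights/Generator_state.pth')\n",
1655 | "# Generator.load_state_dict(torch.load(f'{experiment_folder_name}/weights/Generator_state.pth'))"
1656 | ]
1657 | },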
1644 | {
1645 | "cell_type": "markdown",
1646 | "metadata": {
1647 | "id": "j9EdXO0K8qFW"
1648 | },
1649 | "source": [
1650 | "# Testing the generator\n",
1651 | "\n",
1652 | "Train to match fixed latent values to fixed images"
1653 | ]
1654 | },
1655 | {
1656 | "cell_type": "code",
1657 | "execution_count": null,
1658 | "metadata": {
1659 | "id": "1BsgyaZMMi2f"
1660 | },
1661 | "outputs": [],
1662 | "source": [
1663 | "total_steps = 0 # Since the whole image is our dataset, this just means 500 gradient descent steps.\n",
1664 | "steps_til_summary = 50\n",
1665 | "\n",
1666 | "if generator_type == \"vitgan\":\n",
1667 | " params = Generator.parameters()\n",
1668 | "else:\n",
1669 | " # z = torch.randn(batch_size, latent_dim, 1, 1, device=device)\n",
1670 | " params = list(cnn_generator.parameters())\n",
1671 | "optim = torch.optim.Adam(lr=lr, params=params)\n",
1672 | "ema = ExponentialMovingAverage(params, decay=0.995)\n",
1673 | "\n",
1674 | "ground_truth, _ = next(iter(trainloader))\n",
1675 | "ground_truth = ground_truth.permute(0, 2, 3, 1).view((-1, image_size**2, out_features))\n",
1676 | "ground_truth = ground_truth.to(device)\n",
1677 | "\n",
1678 | "z = torch.randn([batch_size, latent_dim]).to(device)"
1679 | ]
1680 | },
1681 | {
1682 | "cell_type": "code",
1683 | "execution_count": null,
1684 | "metadata": {
1685 | "id": "Swv4ZQt94833"
1686 | },
1687 | "outputs": [],
1688 | "source": [
1689 | "for step in range(total_steps):\n",
1690 | " if generator_type == \"vitgan\":\n",
1691 | " model_output = Generator(z)\n",
1692 | " elif generator_type == \"cnn\":\n",
1693 | " model_output = cnn_generator(z)\n",
1694 | " model_output = model_output.permute([0, 2, 3, 1]).view([-1, image_size**2, out_features])\n",
1695 | " loss = ((model_output - ground_truth)**2).mean()\n",
1696 | " \n",
1697 | " if not step % steps_til_summary:\n",
1698 | " print(\"Step %d, Total loss %0.6f\" % (step, loss))\n",
1699 | "\n",
1700 | " fig, axes = plt.subplots(2,8, figsize=(24,6))\n",
1701 | " for i in range(8):\n",
1702 | " j = np.random.randint(0, batch_size-1)\n",
1703 | " img = model_output[j].cpu().view(32,32,3).detach().numpy()\n",
1704 | " img -= img.min()\n",
1705 | " img /= img.max()\n",
1706 | " axes[0,i].imshow(img)\n",
1707 | " g_img = ground_truth[j].cpu().view(32,32,3).detach().numpy()\n",
1708 | " g_img -= g_img.min()\n",
1709 | " g_img /= g_img.max()\n",
1710 | " axes[1,i].imshow(g_img)\n",
1711 | "\n",
1712 | " plt.show()\n",
1713 | "\n",
1714 | " optim.zero_grad()\n",
1715 | " loss.backward()\n",
1716 | " optim.step()\n",
1717 | " ema.update()"
1718 | ]
1719 | },
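1720 | {
1721 | "cell_type": "code",
1722 | "execution_count": null,
1723 | "metadata": {},
1724 | "outputs": [],
1725 | "source": [
1726 | "# Optional: sample with the EMA weights swapped in temporarily (a sketch using\n",
1727 | "# torch_ema's context manager; the loops in this notebook log with raw weights).\n",
1728 | "with ema.average_parameters():\n",
1729 | "    ema_output = Generator(z) if generator_type == \"vitgan\" else cnn_generator(z)"
1730 | ]
1731 | },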
1720 | {
1721 | "cell_type": "markdown",
1722 | "metadata": {
1723 | "id": "pcWiPCochjJ7"
1724 | },
1725 | "source": [
1726 | "# Training"
1727 | ]
1728 | },
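1729 | {
1730 | "cell_type": "markdown",
1731 | "metadata": {},
1732 | "source": [
1733 | "The discriminator objective assembled below combines the usual real/fake BCE terms with three regularizers: a rolling history buffer of past fakes, pure-noise inputs labelled fake, and balanced consistency regularization (bCR) between augmented and non-augmented logits:\n",
1734 | "\n",
1735 | "$$L_D = \\tfrac{1}{2}L_{real} + \\tfrac{1}{2}L_{fake} + \\lambda_{hist}L_{hist} + \\lambda_{noise}L_{noise} + \\lambda_{bCR}^{real}L_{bCR}^{real} + \\lambda_{bCR}^{fake}L_{bCR}^{fake}$$"
1736 | ]
1737 | },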
1729 | {
1730 | "cell_type": "code",
1731 | "execution_count": null,
1732 | "metadata": {
1733 | "id": "Ad_UXnQJCOeN"
1734 | },
1735 | "outputs": [],
1736 | "source": [
1737 | "os.makedirs(f\"{experiment_folder_name}/weights\", exist_ok = True)\n",
1738 | "os.makedirs(f\"{experiment_folder_name}/samples\", exist_ok = True)\n",
1739 | "\n",
1740 | "# Loss function\n",
1741 | "criterion = nn.BCEWithLogitsLoss()\n",
1742 | "\n",
1743 | "if discriminator_type == \"cnn\": discriminator = cnn_discriminator\n",
1744 | "elif discriminator_type == \"stylegan2\": discriminator = stylegan2_discriminator\n",
1745 | "elif discriminator_type == \"vitgan\": discriminator = Discriminator\n",
1746 | "\n",
1747 | "if generator_type == \"cnn\":\n",
1748 | " params = cnn_generator.parameters()\n",
1749 | "else:\n",
1750 | " params = Generator.parameters()\n",
1751 | "optim_g = torch.optim.Adam(lr=lr, params=params, betas=beta)\n",
1752 | "optim_d = torch.optim.Adam(lr=lr_dis, params=discriminator.parameters(), betas=beta)\n",
1753 | "ema = ExponentialMovingAverage(params, decay=0.995)\n",
1754 | "\n",
1755 | "fixed_noise = torch.FloatTensor(np.random.normal(0, 1, (16, latent_dim))).to(device)\n",
1756 | "\n",
1757 | "discriminator_f_img = torch.zeros([batch_size, 3, image_size, image_size]).to(device)\n",
1758 | "\n",
1759 | "trainset_len = len(trainloader.dataset)\n",
1760 | "\n",
1761 | "step = 0\n",
1762 | "for epoch in range(epochs):\n",
1763 | " for batch_id, batch in enumerate(trainloader):\n",
1764 | " step += 1\n",
1765 | "\n",
1766 | " # Train discriminator\n",
1767 | "\n",
1768 | " # Forward + Backward with real images\n",
1769 | " r_img = batch[0].to(device)\n",
1770 | " r_logit = discriminator(r_img).flatten()\n",
1771 | " r_label = torch.ones(r_logit.shape[0]).to(device)\n",
1772 | "\n",
1773 | " lossD_real = criterion(r_logit, r_label)\n",
1774 | " \n",
1775 | " lossD_bCR_real = F.mse_loss(r_logit, discriminator(r_img, do_augment=False))\n",
1776 | "\n",
1777 | " # Forward + Backward with fake images\n",
1778 | " latent_vector = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim))).to(device)\n",
1779 | "\n",
1780 | " if generator_type == \"vitgan\":\n",
1781 | " f_img = Generator(latent_vector)\n",
1782 | " f_img = f_img.reshape([-1, image_size, image_size, out_features])\n",
1783 | " f_img = f_img.permute(0, 3, 1, 2)\n",
1784 | " else:\n",
1785 | " model_output = cnn_generator(latent_vector)\n",
1786 | " f_img = model_output\n",
1787 | " \n",
1788 | " assert f_img.size(0) == batch_size, f_img.shape\n",
1789 | " assert f_img.size(1) == out_features, f_img.shape\n",
1790 | " assert f_img.size(2) == image_size, f_img.shape\n",
1791 | " assert f_img.size(3) == image_size, f_img.shape\n",
1792 | "\n",
1793 | " f_label = torch.zeros(batch_size).to(device)\n",
1794 | " # Save the a single generated image to the discriminator training data\n",
1795 | " if batch_size_history_discriminator:\n",
1796 | " discriminator_f_img[step % batch_size] = f_img[0].detach()\n",
1797 | " f_logit_history = discriminator(discriminator_f_img).flatten()\n",
1798 | " lossD_fake_history = criterion(f_logit_history, f_label)\n",
1799 | " else: lossD_fake_history = 0\n",
1800 | " # Train the discriminator on the images, generated only from this batch\n",
1801 | " f_logit = discriminator(f_img.detach()).flatten()\n",
1802 | " lossD_fake = criterion(f_logit, f_label)\n",
1803 | " \n",
1804 | " lossD_bCR_fake = F.mse_loss(f_logit, discriminator(f_img, do_augment=False))\n",
1805 | " \n",
1806 | " f_noise_input = torch.FloatTensor(np.random.rand(*f_img.shape)*2 - 1).to(device)\n",
1807 | " f_noise_logit = discriminator(f_noise_input).flatten()\n",
1808 | " lossD_noise = criterion(f_noise_logit, f_label)\n",
1809 | "\n",
1810 | " lossD = lossD_real * 0.5 +\\\n",
1811 | " lossD_fake * 0.5 +\\\n",
1812 | " lossD_fake_history * lambda_lossD_history +\\\n",
1813 | " lossD_noise * lambda_lossD_noise +\\\n",
1814 | " lossD_bCR_real * lambda_bCR_real +\\\n",
1815 | " lossD_bCR_fake * lambda_bCR_fake\n",
1816 | "\n",
1817 | " optim_d.zero_grad()\n",
1818 | " lossD.backward()\n",
1819 | " optim_d.step()\n",
1820 | " \n",
1821 | " # Train Generator\n",
1822 | "\n",
1823 | " if generator_type == \"vitgan\":\n",
1824 | " f_img = Generator(latent_vector)\n",
1825 | " f_img = f_img.reshape([-1, image_size, image_size, out_features])\n",
1826 | " f_img = f_img.permute(0, 3, 1, 2)\n",
1827 | " else:\n",
1828 | " model_output = cnn_generator(latent_vector)\n",
1829 | " f_img = model_output\n",
1830 | " \n",
1831 | " assert f_img.size(0) == batch_size\n",
1832 | " assert f_img.size(1) == out_features\n",
1833 | " assert f_img.size(2) == image_size\n",
1834 | " assert f_img.size(3) == image_size\n",
1835 | "\n",
1836 | " f_logit = discriminator(f_img).flatten()\n",
1837 | " r_label = torch.ones(batch_size).to(device)\n",
1838 | " lossG_main = criterion(f_logit, r_label)\n",
1839 | " \n",
1840 | " lossG_diversity = diversity_loss(f_img) * lambda_diversity_penalty\n",
1841 | " lossG = lossG_main + lossG_diversity\n",
1842 | " \n",
1843 | " optim_g.zero_grad()\n",
1844 | " lossG.backward()\n",
1845 | " optim_g.step()\n",
1846 | " ema.update()\n",
1847 | "\n",
1848 | " writer.add_scalar(\"Loss/Generator\", lossG_main, step)\n",
1849 | " writer.add_scalar(\"Loss/Gen(diversity)\", lossG_diversity, step)\n",
1850 | " writer.add_scalar(\"Loss/Dis(real)\", lossD_real, step)\n",
1851 | " writer.add_scalar(\"Loss/Dis(fake)\", lossD_fake, step)\n",
1852 | " writer.add_scalar(\"Loss/Dis(fake_history)\", lossD_fake_history, step)\n",
1853 | " writer.add_scalar(\"Loss/Dis(noise)\", lossD_noise, step)\n",
1854 | " writer.add_scalar(\"Loss/Dis(bCR_fake)\", lossD_bCR_fake * lambda_bCR_fake, step)\n",
1855 | " writer.add_scalar(\"Loss/Dis(bCR_real)\", lossD_bCR_real * lambda_bCR_real, step)\n",
1856 | " writer.flush()\n",
1857 | "\n",
1858 | " wandb.log({\n",
1859 | " 'Generator': lossG_main,\n",
1860 | " 'Gen(diversity)': lossG_diversity,\n",
1861 | " 'Dis(real)': lossD_real,\n",
1862 | " 'Dis(fake)': lossD_fake,\n",
1863 | " 'Dis(fake_history)': lossD_fake_history,\n",
1864 | " 'Dis(noise)': lossD_noise,\n",
1865 | " 'Dis(bCR_fake)': lossD_bCR_fake * lambda_bCR_fake,\n",
1866 | " 'Dis(bCR_real)': lossD_bCR_real * lambda_bCR_real\n",
1867 | " })\n",
1868 | "\n",
1869 | " if batch_id%20 == 0:\n",
1870 | " print(f'epoch {epoch}/{epochs}; batch {batch_id}/{int(trainset_len/batch_size)}')\n",
1871 | " print(f'Generator: {\"{:.3f}\".format(float(lossG_main))}, '+\\\n",
1872 | " f'Gen(diversity): {\"{:.3f}\".format(float(lossG_diversity))}, '+\\\n",
1873 | " f'Dis(real): {\"{:.3f}\".format(float(lossD_real))}, '+\\\n",
1874 | " f'Dis(fake): {\"{:.3f}\".format(float(lossD_fake))}, '+\\\n",
1875 | " f'Dis(fake_history): {\"{:.3f}\".format(float(lossD_fake_history))}, '+\\\n",
1876 | " f'Dis(noise) {\"{:.3f}\".format(float(lossD_noise))}, '+\\\n",
1877 | " f'Dis(bCR_fake): {\"{:.3f}\".format(float(lossD_bCR_fake * lambda_bCR_fake))}, '+\\\n",
1878 | " f'Dis(bCR_real): {\"{:.3f}\".format(float(lossD_bCR_real * lambda_bCR_real))}')\n",
1879 | "\n",
1880 | " # Plot 8 randomly selected samples\n",
1881 | " fig, axes = plt.subplots(1,8, figsize=(24,3))\n",
1882 | " output = f_img.permute(0, 2, 3, 1)\n",
1883 | " for i in range(8):\n",
1884 | " j = np.random.randint(0, batch_size-1)\n",
1885 | " img = output[j].cpu().view(32,32,3).detach().numpy()\n",
1886 | " img -= img.min()\n",
1887 | " img /= img.max()\n",
1888 | " axes[i].imshow(img)\n",
1889 | " plt.show()\n",
1890 | "\n",
1891 | " # if step % sample_interval == 0:\n",
1892 | " if generator_type == \"vitgan\":\n",
1893 | " Generator.eval()\n",
1894 | " vis = Generator(fixed_noise)\n",
1895 | " vis = vis.reshape([-1, image_size, image_size, out_features])\n",
1896 | " vis = vis.permute(0, 3, 1, 2)\n",
1897 | " else:\n",
1898 | " model_output = cnn_generator(fixed_noise)\n",
1899 | " vis = model_output\n",
1900 | "\n",
1901 | " assert vis.shape[0] == fixed_noise.shape[0], f'vis.shape[0] is {vis.shape[0]}, but should be {fixed_noise.shape[0]}'\n",
1902 | " assert vis.shape[1] == out_features, f'vis.shape[1] is {vis.shape[1]}, but should be {out_features}'\n",
1903 | " assert vis.shape[2] == image_size, f'vis.shape[2] is {vis.shape[2]}, but should be {image_size}'\n",
1904 | " assert vis.shape[3] == image_size, f'vis.shape[3] is {vis.shape[3]}, but should be {image_size}'\n",
1905 | " \n",
1906 | " vis.detach().cpu()\n",
1907 | " vis = make_grid(vis, nrow = 4, padding = 5, normalize = True)\n",
1908 | " writer.add_image(f'Generated/epoch_{epoch}', vis)\n",
1909 | " wandb.log({'examples': wandb.Image(vis)})\n",
1910 | "\n",
1911 | " vis = T.ToPILImage()(vis)\n",
1912 | " vis.save(f'{experiment_folder_name}/samples/vis{epoch}.jpg')\n",
1913 | " if generator_type == \"vitgan\":\n",
1914 | " Generator.train()\n",
1915 | " else:\n",
1916 | " cnn_generator.train()\n",
1917 | " print(f\"Save sample to {experiment_folder_name}/samples/vis{epoch}.jpg\")\n",
1918 | "\n",
1919 | " # Save the checkpoints.\n",
1920 | " if generator_type == \"vitgan\":\n",
1921 | " torch.save(Generator, f'{experiment_folder_name}/weights/Generator.pth')\n",
1922 | " elif generator_type == \"cnn\":\n",
1923 | " torch.save(cnn_generator, f'{experiment_folder_name}/weights/cnn_generator.pth')\n",
1924 | " torch.save(discriminator, f'{experiment_folder_name}/weights/discriminator.pth')\n",
1925 | " print(\"Save model state.\")\n",
1926 | "\n",
1927 | "writer.close()"
1928 | ]
1929 | },
1930 | {
1931 | "cell_type": "code",
1932 | "execution_count": null,
1933 | "metadata": {
1934 | "id": "LWDnspPtsot5"
1935 | },
1936 | "outputs": [],
1937 | "source": []
1938 | }
1939 | ],
1940 | "metadata": {
1941 | "accelerator": "GPU",
1942 | "colab": {
1943 | "collapsed_sections": [],
1944 | "name": "ViTGAN-pytorch.ipynb",
1945 | "private_outputs": true,
1946 | "provenance": []
1947 | },
1948 | "kernelspec": {
1949 | "display_name": "Python 3",
1950 | "language": "python",
1951 | "name": "python3"
1952 | },
1953 | "language_info": {
1954 | "codemirror_mode": {
1955 | "name": "ipython",
1956 | "version": 3
1957 | },
1958 | "file_extension": ".py",
1959 | "mimetype": "text/x-python",
1960 | "name": "python",
1961 | "nbconvert_exporter": "python",
1962 | "pygments_lexer": "ipython3",
1963 | "version": "3.7.4"
1964 | },
1965 | "pycharm": {
1966 | "stem_cell": {
1967 | "cell_type": "raw",
1968 | "metadata": {
1969 | "collapsed": false
1970 | },
1971 | "source": []
1972 | }
1973 | }
1974 | },
1975 | "nbformat": 4,
1976 | "nbformat_minor": 0
1977 | }
1978 |
--------------------------------------------------------------------------------