├── .gitignore
├── README.md
├── checkpoints
│   └── readme.txt
├── dagan_trainer.py
├── dataset.py
├── datasets
│   └── readme.txt
├── discriminator.py
├── generator.py
├── notebook.ipynb
├── resources
│   ├── dagan_tracking_images.png
│   └── dagan_training_progress.gif
├── train_dagan.py
├── train_omniglot_classifier.py
└── utils
    ├── classifier_utils.py
    ├── gif_maker.py
    └── parser.py
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/*
2 | checkpoints/*
3 | !*readme.txt
4 | *.DS_Store
5 | .ipynb_checkpoints*
6 | __pycache__*
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data Augmentation GAN in PyTorch
2 |
3 |
4 |
5 |
6 | Time-lapse of DAGAN generations on the omniglot dataset over the course of the training process.
7 |
8 |
9 | ## Table of Contents
10 | 1. [Intro](#1-intro)
11 | 2. [Background](#2-background)
12 | 3. [Results](#3-results)
13 | 4. [Training your own DAGAN](#4-training-your-own-dagan)
14 | 5. [Technical Details](#5-technical-details)
15 |    1. [DAGAN Training Process](#51-dagan-training-process)
16 |    2. [Classifier Training Process](#52-omniglot-classifier-training-process)
17 |    3. [Architectures](#53-architectures)
18 | 6. [Acknowledgements](#6-acknowledgements)
19 |
20 | ## 1. Intro
21 |
22 | This is a PyTorch implementation of Data Augmentation GAN (DAGAN), which was first proposed in [this paper](https://arxiv.org/abs/1711.04340) with a [corresponding TensorFlow implementation](https://github.com/AntreasAntoniou/DAGAN).
23 |
24 | This repo uses the same generator and discriminator architectures as the original TF implementation, and also includes a classifier script for the omniglot dataset to evaluate the quality of a trained DAGAN.
25 |
26 | ## 2. Background
27 |
28 | The motivation for this work is to train a [Generative Adversarial Network (GAN)](https://en.wikipedia.org/wiki/Generative_adversarial_network) that takes in an image of a given class (e.g. a specific letter in an alphabet) and outputs another image of the same class that looks sufficiently different from the input. This GAN is then used as a tool for data augmentation when training an image classifier.
29 |
30 | Standard data augmentation includes methods such as adding noise to, rotating, or cropping images, which increases variation in the training samples and improves the robustness of the trained classifier. Randomly passing some images through the DAGAN generator before using them in training serves a similar purpose.
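
A minimal sketch of the idea, assuming a generator that is called as `generator(images, z)`, the way `DaganTrainer.sample_generator` calls it. `ToyGenerator`, `dagan_augment` and the probability `p` are hypothetical names for illustration, not code from this repo:

```python
import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in for the DAGAN generator: maps (images, z) to images."""

    def __init__(self, img_size: int = 28, z_dim: int = 100):
        super().__init__()
        self.z_dim = z_dim
        self.z_proj = nn.Linear(z_dim, img_size * img_size)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # Perturb each single-channel image with a noise pattern derived from z
        return torch.tanh(x + self.z_proj(z).view_as(x))


def dagan_augment(batch: torch.Tensor, generator: nn.Module, p: float = 0.5):
    """Replace each image in the batch with a generated variant with probability p.

    Returns the (possibly) augmented batch and a boolean mask of generated images.
    """
    mask = torch.rand(batch.shape[0], device=batch.device) < p
    out = batch.clone()
    if mask.any():
        z = torch.randn(int(mask.sum()), generator.z_dim, device=batch.device)
        with torch.no_grad():  # the generator is only used for inference here
            out[mask] = generator(batch[mask], z)
    return out, mask
```

Returning the mask makes it easy to tell the classifier which inputs were generated.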
31 |
32 | ## 3. Results
33 |
34 | To measure the quality of the DAGAN, classifiers were trained both with and without DAGAN augmentations, and their accuracies were compared. The original paper showed improvement on the omniglot dataset using 5, 10, and 15 images per class to train the classifier. As expected, the fewer samples used, the more impactful the augmentations were.
35 |
36 | This PyTorch implementation showed a statistically significant improvement on the omniglot dataset with 1-4 samples per class but negligible gains with 5+ samples per class. The table below shows the classifier accuracy with and without DAGAN augmentations, along with the confidence level that the augmentations actually improve accuracy. (More details on the significance-testing methodology [can be found here](#52-omniglot-classifier-training-process).)
37 |
38 |
39 | | Samples per Class | 1 | 2 | 3 | 4 | 5 |
40 | |----------------------------------------------|-------|-------|-------|-------|-------|
41 | | Acc. w/o DAGAN | 16.4% | 29.0% | 46.5% | 57.8% | 67.1% |
42 | | Acc. w/ DAGAN | 19.2% | 39.4% | 52.0% | 65.3% | 69.5% |
43 | | Confidence level that augmentations are better | 97.6% | 99.9% | 97.4% | 97.8% | 60.7% |
44 |
45 |
46 | ## 4. Training your own DAGAN
47 |
48 | The easiest way to train your own DAGAN or augmented omniglot classifier is through Google Colab. The Colab notebooks used to produce the results shown here can be found below:
49 | - [Train omniglot DAGAN](https://colab.research.google.com/drive/1U-twOEiguyIgiL6h9H6130tF-O_g-b-u)
50 | - [Train omniglot classifier with DAGAN augmentations](https://colab.research.google.com/drive/1oJggcS6-3x_chbEfahSJCsy19kWBxWeE)
51 |
52 | Running those notebooks as is should reproduce the results presented in this readme. One of the advantages of PyTorch relative to TensorFlow is the ease of modifying and testing out changes in the training process, particularly to the network architecture. To test out changes, you can fork this repo, make the necessary changes, and re-run the Colab notebook using the forked repo.
53 |
54 | ## 5. Technical Details
55 |
56 | ### 5.1 DAGAN Training Process
57 | Recall the procedure for training a traditional GAN:
58 | - Create 2 networks, a generator (G) and a discriminator (D)
59 | - To train D
60 |     - Randomly sample images from G and train D to recognize them as fake
61 |     - Randomly sample real images and train D to recognize them as real
62 | - To train G
63 |     - Sample images from G and pass them through D
64 |     - Update G to increase the likelihood that D classifies these samples as real
65 | - Train G and D alternately
66 |     - This iteratively makes D better at distinguishing real from fake, while making G better at producing realistic images
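
For reference, the alternating procedure above can be sketched on toy 1-D data. This is a generic GAN with the standard binary cross-entropy loss, not the Wasserstein loss this repo uses; the network sizes, learning rates and data distribution are arbitrary assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy 1-D problem: "real" samples come from N(3, 0.5); all sizes are hypothetical
z_dim, batch_size = 8, 64
G = nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
g_opt = torch.optim.Adam(G.parameters(), lr=1e-3)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()  # D outputs logits; label 1 = real, 0 = fake
ones = torch.ones(batch_size, 1)
zeros = torch.zeros(batch_size, 1)


def train_step():
    # Train D: real samples -> "real", generated samples -> "fake"
    real = 3 + 0.5 * torch.randn(batch_size, 1)
    fake = G(torch.randn(batch_size, z_dim)).detach()  # no gradient into G here
    d_loss = bce(D(real), ones) + bce(D(fake), zeros)
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Train G: push D to classify generated samples as real
    fake = G(torch.randn(batch_size, z_dim))
    g_loss = bce(D(fake), ones)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()


# Alternate D and G updates
for step in range(200):
    d_loss, g_loss = train_step()
```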
67 |
68 | Training a DAGAN requires a slight twist:
69 | - To train D
70 |     - Randomly sample pairs of real images (source, target), where both images in a pair belong to the same class
71 |     - Pass each (source, target) pair to D, which should recognize it as real
72 |     - Pass the source image to G to produce a realistic-looking target
73 |     - Pass each (source, generated target) pair to D, which should recognize it as fake
74 | - To train G
75 |     - Sample real images (source) and pass them through G to produce generated targets
76 |     - Update G to increase the likelihood that D classifies (source, generated target) pairs as real
77 | - D learns to distinguish real and fake targets for a given source image
78 | - G learns to produce images that belong to the same class as the source without being too similar to it (a near-copy of the source would give D an easy way to recognize fake targets)
79 | - Thus, G provides varied images that are somewhat similar to the source, which is our ultimate goal
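
A sketch of one DAGAN update on toy data, assuming tiny linear stand-ins (`PairCritic` and `CondGenerator`, both hypothetical) for the real networks. Pairs are built the way `create_dagan_dataloader` builds them, by pairing each class's images with a shuffled copy of the same class. The Wasserstein losses mirror `DaganTrainer`, except that the gradient penalty is left out for brevity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, per_class, img, z_dim = 4, 5, 8, 16  # hypothetical toy sizes

# Toy dataset: data[c] holds the images of class c, shape (per_class, 1, img, img)
data = torch.randn(num_classes, per_class, 1, img, img)


def make_pairs(data):
    """Pair every image (source) with a random image of the same class (target)."""
    sources, targets, labels = [], [], []
    for c, class_images in enumerate(data):
        perm = torch.randperm(class_images.shape[0])
        sources.append(class_images)
        targets.append(class_images[perm])
        labels.append(torch.full((class_images.shape[0],), c))
    return torch.cat(sources), torch.cat(targets), torch.cat(labels)


class PairCritic(nn.Module):
    """Hypothetical critic that scores a (source, target) pair stacked on channels."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(2 * img * img, 1))

    def forward(self, x1, x2):
        return self.net(torch.cat([x1, x2], 1))


class CondGenerator(nn.Module):
    """Hypothetical generator that maps (source image, noise) to a target image."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(img * img + z_dim, img * img)

    def forward(self, x1, z):
        h = torch.cat([x1.flatten(1), z], 1)
        return torch.tanh(self.net(h)).view_as(x1)


G, D = CondGenerator(), PairCritic()
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4)

x1, x2, labels = make_pairs(data)
z = torch.randn(x1.shape[0], z_dim)

# Critic step: real pairs (x1, x2) should score higher than fake pairs (x1, G(x1, z))
fake = G(x1, z).detach()
d_loss = D(x1, fake).mean() - D(x1, x2).mean()
d_opt.zero_grad()
d_loss.backward()
d_opt.step()

# Generator step: raise the critic's score on (x1, G(x1, z)) pairs
g_loss = -D(x1, G(x1, z)).mean()
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
```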
80 |
81 | The omniglot DAGAN was trained on all the examples in the first 1200 classes of the dataset. The generator was validated on the next 200 classes by visual inspection. Training was done for 50 epochs, which took 3.3 hours on a Tesla T4 GPU.
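
The class split can be sketched as follows, assuming the dataset is stored as an array indexed by class. The loading code is not part of this section, so the placeholder array below stands in for it. Training only on the first 1200 classes is consistent with `create_dagan_dataloader`, which pairs images from `raw_data[i]` for `i in range(num_classes)`:

```python
import numpy as np

# Placeholder with omniglot's 1623 classes x 20 samples per class;
# the 8x8 image size is a hypothetical stand-in to keep memory small.
raw_data = np.zeros((1623, 20, 8, 8), dtype=np.uint8)

train_classes = raw_data[:1200]    # first 1200 classes: DAGAN training
val_classes = raw_data[1200:1400]  # next 200 classes: visual validation

# 1200 classes * 20 samples = 24000 training images
num_train_images = train_classes.shape[0] * train_classes.shape[1]

# 3.3 hours for 50 epochs is roughly 3.96 minutes per epoch
minutes_per_epoch = 3.3 * 60 / 50
```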
82 |
83 | The network was trained using the Adam optimizer and the Wasserstein loss with gradient penalty (WGAN-GP, also called the improved Wasserstein loss). Compared with the original GAN loss, it gives G more informative gradients from D, which tend to stay useful even when D easily separates real from generated images. More details can be found in the [Improved Wasserstein GAN paper](https://arxiv.org/abs/1704.00028).
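
Reading `_critic_train_iteration`, `_generator_train_iteration` and `_gradient_penalty` in `dagan_trainer.py`, the losses appear to be the following, with $x_1$ the source, $x_2$ the real target, expectations taken as batch means, $\lambda$ = `gp_weight` (default 10), and $\alpha$ drawn independently for each example:

```math
\begin{aligned}
\hat{x} &= \alpha\, x_2 + (1 - \alpha)\, G(x_1, z), \qquad \alpha \sim \mathcal{U}[0, 1] \\
\mathcal{L}_D &= \mathbb{E}\big[D(x_1, G(x_1, z))\big] - \mathbb{E}\big[D(x_1, x_2)\big] + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(x_1, \hat{x}) \rVert_2 - 1\big)^2\Big] \\
\mathcal{L}_G &= -\,\mathbb{E}\big[D(x_1, G(x_1, z))\big]
\end{aligned}
```

In the code, $10^{-12}$ is added inside the square root of the gradient norm for numerical stability, and G is updated once every `critic_iterations` (default 5) critic steps.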
84 |
85 | ### 5.2 Omniglot Classifier Training Process
86 |
87 | Omniglot classifiers were trained on classes #1420-1519 (100 classes) of the dataset for 200 epochs. Classifiers were trained with and without augmentations. When trained with augmentations, every other batch was passed through the DAGAN, so the total number of steps was the same in both configurations.
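
A sketch of the every-other-batch schedule, assuming a standard `(images, labels)` data loader and a generator called as `generator(images, z)`. This section does not say whether the odd or the even batches were augmented, so augmenting odd steps (and the name `augmented_batches`) is an assumption. The sketch also yields the 0/1 real/generated flag that the classifier consumes:

```python
import torch


def augmented_batches(loader, generator, z_dim):
    """Yield (images, labels, is_generated_flag), passing every other batch through the generator.

    If the real generator uses batch norm, call generator.eval() before using this.
    """
    for step, (x, y) in enumerate(loader):
        is_generated = step % 2 == 1  # assumption: odd-numbered batches are augmented
        if is_generated:
            with torch.no_grad():
                z = torch.randn(x.shape[0], z_dim)
                x = generator(x, z)
        # One flag per example: 1.0 for generated inputs, 0.0 for real ones
        flag = torch.full((x.shape[0], 1), float(is_generated))
        yield x, y, flag
```

Because augmentation replaces batches instead of adding new ones, the number of training steps is the same with and without the DAGAN.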
88 |
89 | To estimate accuracy more robustly in each configuration, 10 classifiers were trained, each on a slightly different dataset: out of the 20 samples available for each class, a different subset of k images was chosen for each of the 10 classifiers. A two-sample t-test was then used to determine the confidence level that the two distributions of accuracies differ (i.e. the statistical significance of the accuracy improvement from augmentation).
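
A sketch of this evaluation protocol, assuming SciPy is available. The section does not say whether the t-test was one- or two-sided or whether equal variances were assumed; this sketch uses a one-sided Student's t-test (SciPy's default `equal_var=True`) and reports `1 - p` as the confidence level. `choose_subsets` and `confidence_augmentation_better` are hypothetical helper names:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)


def choose_subsets(num_runs=10, samples_available=20, k=3):
    """Pick a different random subset of k sample indices (out of 20) for each classifier run."""
    return [rng.choice(samples_available, size=k, replace=False) for _ in range(num_runs)]


def confidence_augmentation_better(acc_without, acc_with):
    """One-sided two-sample t-test; returns 1 - p for H1: mean(acc_with) > mean(acc_without)."""
    result = stats.ttest_ind(acc_with, acc_without, alternative="greater")
    return 1.0 - result.pvalue
```

Each argument to `confidence_augmentation_better` would be the 10 per-run test accuracies for one configuration.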
90 |
91 | This exercise was repeated with 1, 2, 3, 4, and 5 samples per class. Training was done using the Adam optimizer and the standard cross-entropy loss.
92 |
93 | ### 5.3 Architectures
94 |
95 | The DAGAN architectures are described in detail in the paper and can also be seen in the PyTorch implementation of the [generator](https://github.com/amurthy1/dagan_torch/blob/master/generator.py) and [discriminator](https://github.com/amurthy1/dagan_torch/blob/master/discriminator.py).
96 |
97 | In a nutshell, the generator is a UNet of dense convolutional blocks. Each block has 4 conv layers, while the UNet itself is 4 blocks deep on each side.
98 |
99 | The discriminator is a DenseNet with 4 blocks, each containing 4 conv layers.
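
The dense connectivity shared by both networks can be illustrated with a simplified block. This sketch omits the repo's extra strided `pre_conv` input, the final stride-2 conv, dropout and normalization, so it is not a drop-in replacement for `_EncoderBlock`. `level_sizes` reproduces the `dim_arr` computation from `generator.py` and `discriminator.py`; the 28x28 input in the test is only an example:

```python
import torch
import torch.nn as nn


class TinyDenseBlock(nn.Module):
    """Simplified dense block: each conv sees the block input plus all earlier conv outputs."""

    def __init__(self, in_channels: int, growth: int, num_layers: int):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(in_channels + i * growth, growth, kernel_size=3, padding=1),
                    nn.LeakyReLU(0.2),
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x]
        for conv in self.convs:
            # Concatenate everything produced so far along the channel axis
            features.append(conv(torch.cat(features, 1)))
        return torch.cat(features, 1)  # in_channels + num_layers * growth channels


def level_sizes(dim: int, depth: int = 4):
    """Spatial size at each level when every level halves the size, rounding up."""
    sizes = [dim]
    for _ in range(depth):
        sizes.append((sizes[-1] + 1) // 2)
    return sizes
```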
100 |
101 | The omniglot classifier uses the [standard PyTorch DenseNet implementation](https://pytorch.org/hub/pytorch_vision_densenet/) with 4 blocks, each having 3 conv layers. The output of the DenseNet's last layer was concatenated with a 0/1 flag indicating whether a given input was real or generated, and this was followed by 2 dense layers that output the final classification probabilities. This lets the image features from the DenseNet interact with the real/generated flag, which helps the classifier produce more accurate predictions.
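
A sketch of such a head, assuming the DenseNet features have already been pooled into a `(batch, feature_dim)` tensor. `FlaggedHead`, the hidden width of 256 and the ReLU between the two dense layers are assumptions, not details taken from this repo's classifier code:

```python
import torch
import torch.nn as nn


class FlaggedHead(nn.Module):
    """Hypothetical head: backbone features + real/generated flag -> 2 dense layers -> class logits."""

    def __init__(self, feature_dim: int, num_classes: int, hidden: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(feature_dim + 1, hidden)  # +1 input for the flag
        self.fc2 = nn.Linear(hidden, num_classes)
        self.act = nn.ReLU()

    def forward(self, features: torch.Tensor, is_generated: torch.Tensor) -> torch.Tensor:
        # features: (batch, feature_dim); is_generated: (batch, 1) float tensor of 0s and 1s
        h = torch.cat([features, is_generated], dim=1)
        # Returns logits; nn.CrossEntropyLoss takes these directly, softmax gives probabilities
        return self.fc2(self.act(self.fc1(h)))
```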
102 |
103 |
104 | ## 6. Acknowledgements
105 |
106 | - As mentioned earlier, this work was adapted from [this paper](https://arxiv.org/abs/1711.04340) and [this repo](https://github.com/AntreasAntoniou/DAGAN) by A. Antoniou et al.
107 |
108 | - The omniglot dataset was originally sourced from [this github repo](https://github.com/brendenlake/omniglot/) by user [brendenlake](https://github.com/brendenlake).
109 |
110 | - The PyTorch Wasserstein GAN (WGAN) implementation in this repo was closely adapted from [this repo](https://github.com/EmilienDupont/wgan-gp) by user [EmilienDupont](https://github.com/EmilienDupont/).
111 |
112 |
113 |
--------------------------------------------------------------------------------
/checkpoints/readme.txt:
--------------------------------------------------------------------------------
1 | Store intermediate model checkpoints here
2 |
--------------------------------------------------------------------------------
/dagan_trainer.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import numpy as np
3 | import torch
4 | import time
5 | import torch.nn as nn
6 | import torchvision.transforms as transforms
7 | from torchvision.utils import make_grid
8 | from torch.autograd import Variable
9 | from torch.autograd import grad as torch_grad
10 | from PIL import Image
11 | import PIL
12 | import warnings
13 | from IPython.display import display
14 |
15 | class DaganTrainer:
16 | def __init__(
17 | self,
18 | generator,
19 | discriminator,
20 | gen_optimizer,
21 | dis_optimizer,
22 | batch_size,
23 | device="cpu",
24 | gp_weight=10,
25 | critic_iterations=5,
26 | print_every=50,
27 | num_tracking_images=0,
28 | save_checkpoint_path=None,
29 | load_checkpoint_path=None,
30 | display_transform=None,
31 | should_display_generations=True,
32 | ):
33 | self.device = device
34 | self.g = generator.to(device)
35 | self.g_opt = gen_optimizer
36 | self.d = discriminator.to(device)
37 | self.d_opt = dis_optimizer
38 | self.losses = {"G": [0.0], "D": [0.0], "GP": [0.0], "gradient_norm": [0.0]}
39 | self.num_steps = 0
40 | self.epoch = 0
41 | self.gp_weight = gp_weight
42 | self.critic_iterations = critic_iterations
43 | self.print_every = print_every
44 | self.num_tracking_images = num_tracking_images
45 | self.display_transform = display_transform or transforms.ToTensor()
46 | self.checkpoint_path = save_checkpoint_path
47 | self.should_display_generations = should_display_generations
48 |
49 | # Track progress of fixed images throughout the training
50 | self.tracking_images = None
51 | self.tracking_z = None
52 | self.tracking_images_gens = None
53 |
54 | if load_checkpoint_path:
55 | self.hydrate_checkpoint(load_checkpoint_path)
56 |
57 | def _critic_train_iteration(self, x1, x2):
58 | """Runs one critic (discriminator) update using the WGAN-GP loss."""
59 | # Get generated data
60 | generated_data = self.sample_generator(x1)
61 |
62 | d_real = self.d(x1, x2)
63 | d_generated = self.d(x1, generated_data)
64 |
65 | # Get gradient penalty
66 | gradient_penalty = self._gradient_penalty(x1, x2, generated_data)
67 | self.losses["GP"].append(gradient_penalty.item())
68 |
69 | # Create total loss and optimize
70 | self.d_opt.zero_grad()
71 | d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
72 | d_loss.backward()
73 |
74 | self.d_opt.step()
75 |
76 | # Record loss
77 | self.losses["D"].append(d_loss.item())
78 |
79 | def _generator_train_iteration(self, x1):
80 | """Runs one generator update by maximizing the critic's score on (x1, G(x1)) pairs."""
81 | self.g_opt.zero_grad()
82 |
83 | # Get generated data
84 | generated_data = self.sample_generator(x1)
85 |
86 | # Calculate loss and optimize
87 | d_generated = self.d(x1, generated_data)
88 | g_loss = -d_generated.mean()
89 | g_loss.backward()
90 | self.g_opt.step()
91 |
92 | # Record loss
93 | self.losses["G"].append(g_loss.item())
94 |
95 | def _gradient_penalty(self, x1, x2, generated_data):
96 | # Calculate interpolation
97 | alpha = torch.rand(x1.shape[0], 1, 1, 1)
98 | alpha = alpha.expand_as(x2).to(self.device)
99 | interpolated = alpha * x2.detach() + (1 - alpha) * generated_data.detach()
100 | interpolated.requires_grad_(True)
101 |
102 | # Critic scores of the interpolated examples (unbounded scores, not probabilities)
103 | prob_interpolated = self.d(x1, interpolated)
104 |
105 | # Gradients of the critic scores with respect to the interpolated examples
106 | gradients = torch_grad(
107 | outputs=prob_interpolated,
108 | inputs=interpolated,
109 | grad_outputs=torch.ones(prob_interpolated.size()).to(self.device),
110 | create_graph=True,
111 | retain_graph=True,
112 | )[0]
113 |
114 | # Gradients have shape (batch_size, num_channels, img_height, img_width),
115 | # so flatten to easily take norm per example in batch
116 | gradients = gradients.view(x1.shape[0], -1)
117 | self.losses["gradient_norm"].append(gradients.norm(2, dim=1).mean().item())
118 |
119 | # Derivatives of the gradient close to 0 can cause problems because of
120 | # the square root, so manually calculate norm and add epsilon
121 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
122 |
123 | # Return gradient penalty
124 | return self.gp_weight * ((gradients_norm - 1) ** 2).mean()
125 |
126 | def _train_epoch(self, data_loader, val_images):
127 | for i, data in enumerate(data_loader):
128 | if i % self.print_every == 0:
129 | print("Iteration {}".format(i))
130 | self.print_progress(data_loader, val_images)
131 | self.num_steps += 1
132 | x1, x2 = data[0].to(self.device), data[1].to(self.device)
133 | self._critic_train_iteration(x1, x2)
134 | # Only update generator every |critic_iterations| iterations
135 | if self.num_steps % self.critic_iterations == 0:
136 | self._generator_train_iteration(x1)
137 |
138 | def train(self, data_loader, epochs, val_images=None, save_training_gif=True):
139 | if self.tracking_images is None and self.num_tracking_images > 0:
140 | self.tracking_images = self.sample_val_images(
141 | self.num_tracking_images // 2, val_images
142 | )
143 | self.tracking_images.extend(
144 | self.sample_train_images(
145 | self.num_tracking_images - len(self.tracking_images), data_loader
146 | )
147 | )
148 | self.tracking_images = torch.stack(self.tracking_images).to(self.device)
149 | self.tracking_z = torch.randn((self.num_tracking_images, self.g.z_dim)).to(
150 | self.device
151 | )
152 | self.tracking_images_gens = []
153 |
154 | # Save checkpoint once before training to catch errors
155 | self._save_checkpoint()
156 |
157 | start_time = int(time.time())
158 |
159 | while self.epoch < epochs:
160 | print("\nEpoch {}".format(self.epoch))
161 | print(f"Elapsed time: {(time.time() - start_time) / 60:.2f} minutes\n")
162 |
163 | self._train_epoch(data_loader, val_images)
164 | self.epoch += 1
165 | self._save_checkpoint()
166 |
167 | def sample_generator(self, input_images, z=None):
168 | if z is None:
169 | z = torch.randn((input_images.shape[0], self.g.z_dim)).to(self.device)
170 | return self.g(input_images, z)
171 |
172 | def render_img(self, arr):
173 | arr = (arr * 0.5) + 0.5
174 | arr = np.uint8(arr * 255)
175 | display(Image.fromarray(arr, mode="L").transpose(PIL.Image.TRANSPOSE))
176 |
177 | def sample_train_images(self, n, data_loader):
178 | with warnings.catch_warnings():
179 | warnings.simplefilter("ignore", category=UserWarning)
180 | return [
181 | self.display_transform(data_loader.dataset.x1_examples[idx])
182 | for idx in torch.randint(0, len(data_loader.dataset), (n,))
183 | ]
184 |
185 | def sample_val_images(self, n, val_images):
186 | if val_images is None:
187 | return []
188 |
189 | with warnings.catch_warnings():
190 | warnings.simplefilter("ignore", category=UserWarning)
191 | return [
192 | self.display_transform(val_images[idx])
193 | for idx in torch.randint(0, len(val_images), (n,))
194 | ]
195 |
196 | def display_generations(self, data_loader, val_images):
197 | n = 5
198 | images = self.sample_train_images(n, data_loader) + self.sample_val_images(
199 | n, val_images
200 | )
201 | img_size = images[0].shape[-1]
202 | images.append(torch.tensor(np.ones((1, img_size, img_size))).float())
203 | images.append(torch.tensor(np.ones((1, img_size, img_size))).float() * -1)
204 | self.render_img(torch.cat(images, 1)[0])
205 | z = torch.randn((len(images), self.g.z_dim)).to(self.device)
206 | inp = torch.stack(images).to(self.device)
207 | train_gen = self.g(inp, z).cpu()
208 | self.render_img(train_gen.reshape(-1, train_gen.shape[-1]))
209 |
210 | def print_progress(self, data_loader, val_images):
211 | self.g.eval()
212 | with torch.no_grad():
213 | if self.should_display_generations:
214 | self.display_generations(data_loader, val_images)
215 | if self.num_tracking_images > 0:
216 | self.tracking_images_gens.append(
217 | self.g(self.tracking_images, self.tracking_z).cpu()
218 | )
219 | self.g.train()
220 | print("D: {}".format(self.losses["D"][-1]))
221 | print("Raw D: {}".format(self.losses["D"][-1] - self.losses["GP"][-1]))
222 | print("GP: {}".format(self.losses["GP"][-1]))
223 | print("Gradient norm: {}".format(self.losses["gradient_norm"][-1]))
224 | if self.num_steps > self.critic_iterations:
225 | print("G: {}".format(self.losses["G"][-1]))
226 |
227 | def _save_checkpoint(self):
228 | if self.checkpoint_path is None:
229 | return
230 | checkpoint = {
231 | "epoch": self.epoch,
232 | "num_steps": self.num_steps,
233 | "g": self.g.state_dict(),
234 | "g_opt": self.g_opt.state_dict(),
235 | "d": self.d.state_dict(),
236 | "d_opt": self.d_opt.state_dict(),
237 | "tracking_images": self.tracking_images,
238 | "tracking_z": self.tracking_z,
239 | "tracking_images_gens": self.tracking_images_gens,
240 | }
241 | torch.save(checkpoint, self.checkpoint_path)
242 |
243 | def hydrate_checkpoint(self, checkpoint_path):
244 | checkpoint = torch.load(checkpoint_path, map_location=self.device)
245 | self.epoch = checkpoint["epoch"]
246 | self.num_steps = checkpoint["num_steps"]
247 |
248 | self.g.load_state_dict(checkpoint["g"])
249 | self.g_opt.load_state_dict(checkpoint["g_opt"])
250 | self.d.load_state_dict(checkpoint["d"])
251 | self.d_opt.load_state_dict(checkpoint["d_opt"])
252 |
253 | self.tracking_images = checkpoint["tracking_images"]
254 | self.tracking_z = checkpoint["tracking_z"]
255 | self.tracking_images_gens = checkpoint["tracking_images_gens"]
256 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import numpy as np
3 | import warnings
4 |
5 |
6 | class DaganDataset(Dataset):
7 | """Dataset of (x1, x2) image pairs where both images in a pair belong to the same class."""
8 |
9 | def __init__(self, x1_examples, x2_examples, transform=None):
10 | assert len(x1_examples) == len(x2_examples)
11 | self.x1_examples = x1_examples
12 | self.x2_examples = x2_examples
13 | self.transform = transform
14 |
15 | def __len__(self):
16 | return len(self.x1_examples)
17 |
18 | def __getitem__(self, idx):
19 | with warnings.catch_warnings():
20 | warnings.simplefilter("ignore", category=UserWarning)
21 | return self.transform(self.x1_examples[idx]), self.transform(
22 | self.x2_examples[idx]
23 | )
24 |
25 |
26 | def create_dagan_dataloader(raw_data, num_classes, transform, batch_size):
27 | train_x1 = []
28 | train_x2 = []
29 |
30 | for i in range(num_classes):
31 | x2_data = list(raw_data[i])
32 | np.random.shuffle(x2_data)
33 | train_x1.extend(raw_data[i])
34 | train_x2.extend(x2_data)
35 |
36 | train_dataset = DaganDataset(train_x1, train_x2, transform)
37 | return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
38 |
--------------------------------------------------------------------------------
/datasets/readme.txt:
--------------------------------------------------------------------------------
1 | Store datasets here
2 |
--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class _LayerNorm(nn.Module):
9 | def __init__(self, num_features, img_size):
10 | """
11 | Normalizes over the entire (C, H, W) feature map, then applies a learned per-channel scale and bias
12 | """
13 | super().__init__()
14 | self.layer_norm = nn.LayerNorm(
15 | (num_features, img_size, img_size), elementwise_affine=False, eps=1e-12
16 | )
17 | self.weight = torch.nn.Parameter(
18 | torch.ones(num_features).float().unsqueeze(-1).unsqueeze(-1),
19 | requires_grad=True,
20 | )
21 | self.bias = torch.nn.Parameter(
22 | torch.zeros(num_features).float().unsqueeze(-1).unsqueeze(-1),
23 | requires_grad=True,
24 | )
25 |
26 | def forward(self, x):
27 | out = self.layer_norm(x)
28 | out = out * self.weight + self.bias
29 | return out
30 |
31 |
32 | class _SamePad(nn.Module):
33 | """
34 | Pads so that a following 3x3 conv matches TensorFlow "SAME" padding (assumes square inputs)
35 | """
36 |
37 | def __init__(self, stride):
38 | super().__init__()
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | if self.stride == 2 and x.shape[2] % 2 == 0:
43 | return F.pad(x, (0, 1, 0, 1))
44 | return F.pad(x, (1, 1, 1, 1))
45 |
46 |
47 | def _conv2d(
48 | in_channels,
49 | out_channels,
50 | kernel_size,
51 | stride,
52 | out_size=None,
53 | activate=True,
54 | dropout=0.0,
55 | ):
56 | layers = OrderedDict()
57 | layers["pad"] = _SamePad(stride)
58 | layers["conv"] = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
59 | if activate:
60 | if out_size is None:
61 | raise ValueError("Must provide out_size if activate is True")
62 | layers["relu"] = nn.LeakyReLU(0.2)
63 | layers["norm"] = _LayerNorm(out_channels, out_size)
64 |
65 | if dropout > 0.0:
66 | layers["dropout"] = nn.Dropout(dropout)
67 | return nn.Sequential(layers)
68 |
69 |
70 | class _EncoderBlock(nn.Module):
71 | def __init__(
72 | self,
73 | pre_channels,
74 | in_channels,
75 | out_channels,
76 | num_layers,
77 | out_size,
78 | dropout_rate=0.0,
79 | ):
80 | super().__init__()
81 | self.num_layers = num_layers
82 | self.pre_conv = _conv2d(
83 | in_channels=pre_channels,
84 | out_channels=pre_channels,
85 | kernel_size=3,
86 | stride=2,
87 | activate=False,
88 | )
89 |
90 | self.conv0 = _conv2d(
91 | in_channels=in_channels + pre_channels,
92 | out_channels=out_channels,
93 | kernel_size=3,
94 | stride=1,
95 | out_size=out_size,
96 | )
97 | total_channels = in_channels + out_channels
98 | for i in range(1, num_layers):
99 | self.add_module(
100 | "conv%d" % i,
101 | _conv2d(
102 | in_channels=total_channels,
103 | out_channels=out_channels,
104 | kernel_size=3,
105 | stride=1,
106 | out_size=out_size,
107 | ),
108 | )
109 | total_channels += out_channels
110 | self.add_module(
111 | "conv%d" % num_layers,
112 | _conv2d(
113 | in_channels=total_channels,
114 | out_channels=out_channels,
115 | kernel_size=3,
116 | stride=2,
117 | out_size=(out_size + 1) // 2,
118 | dropout=dropout_rate,
119 | ),
120 | )
121 |
122 | def forward(self, inp):
123 | pre_input, x = inp
124 | pre_input = self.pre_conv(pre_input)
125 | out = self.conv0(torch.cat([x, pre_input], 1))
126 |
127 | all_outputs = [x, out]
128 | for i in range(1, self.num_layers + 1):
129 | input_features = torch.cat(
130 | [all_outputs[-1], all_outputs[-2]] + all_outputs[:-2], 1
131 | )
132 | module = self._modules["conv%d" % i]
133 | out = module(input_features)
134 | all_outputs.append(out)
135 | return all_outputs[-2], all_outputs[-1]
136 |
137 |
138 | class Discriminator(nn.Module):
139 | def __init__(self, dim, channels, dropout_rate=0.0, z_dim=100):
140 | super().__init__()
141 | self.dim = dim
142 | self.z_dim = z_dim
143 | self.channels = channels  # channels of the concatenated (x1, x2) input, i.e. twice the image channels
144 | self.layer_sizes = [64, 64, 128, 128]
145 | self.num_inner_layers = 5
146 |
147 | # Number of times dimension is halved
148 | self.depth = len(self.layer_sizes)
149 |
150 | # dimension at each level of U-net
151 | self.dim_arr = [dim]
152 | for i in range(self.depth):
153 | self.dim_arr.append((self.dim_arr[-1] + 1) // 2)
154 |
155 | # Encoders
156 | self.encode0 = _conv2d(
157 | in_channels=self.channels,
158 | out_channels=self.layer_sizes[0],
159 | kernel_size=3,
160 | stride=2,
161 | out_size=self.dim_arr[1],
162 | )
163 | for i in range(1, self.depth):
164 | self.add_module(
165 | "encode%d" % i,
166 | _EncoderBlock(
167 | pre_channels=self.channels if i == 1 else self.layer_sizes[i - 1],
168 | in_channels=self.layer_sizes[i - 1],
169 | out_channels=self.layer_sizes[i],
170 | num_layers=self.num_inner_layers,
171 | out_size=self.dim_arr[i],
172 | dropout_rate=dropout_rate,
173 | ),
174 | )
175 | self.dense1 = nn.Linear(self.layer_sizes[-1], 1024)
176 | self.leaky_relu = nn.LeakyReLU(0.2)
177 | self.dense2 = nn.Linear(self.layer_sizes[-1] * self.dim_arr[-1] ** 2 + 1024, 1)
178 |
179 | def forward(self, x1, x2):
180 | x = torch.cat([x1, x2], 1)
181 | out = [x, self.encode0(x)]
182 | for i in range(1, len(self.layer_sizes)):
183 | out = self._modules["encode%d" % i](out)
184 | out = out[1]
185 |
186 | out_mean = out.mean([2, 3])
187 | out_flat = torch.flatten(out, 1)
188 |
189 | out = self.dense1(out_mean)
190 | out = self.leaky_relu(out)
191 | out = self.dense2(torch.cat([out, out_flat], 1))
192 |
193 | return out
194 |
--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class _SamePad(nn.Module):
9 | """
10 | Pads so that a following 3x3 conv matches TensorFlow "SAME" padding (assumes square inputs)
11 | """
12 |
13 | def __init__(self, stride):
14 | super().__init__()
15 | self.stride = stride
16 |
17 | def forward(self, x):
18 | if self.stride == 2 and x.shape[2] % 2 == 0:
19 | return F.pad(x, (0, 1, 0, 1))
20 | return F.pad(x, (1, 1, 1, 1))
21 |
22 |
23 | def _conv2d(in_channels, out_channels, kernel_size, stride, activate=True, dropout=0.0):
24 | layers = OrderedDict()
25 | layers["pad"] = _SamePad(stride)
26 | layers["conv"] = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
27 | if activate:
28 | layers["relu"] = nn.LeakyReLU(0.2)
29 | layers["batchnorm"] = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)
30 |
31 | if dropout > 0.0:
32 | layers["dropout"] = nn.Dropout(dropout)
33 | return nn.Sequential(layers)
34 |
35 |
36 | def _conv2d_transpose(
37 | in_channels, out_channels, kernel_size, upscale_size, activate=True, dropout=0.0
38 | ):
39 | layers = OrderedDict()
40 | layers["upsample"] = nn.Upsample(upscale_size)
41 | layers["conv"] = nn.ConvTranspose2d(
42 | in_channels, out_channels, kernel_size, stride=1, padding=1
43 | )
44 | if activate:
45 | layers["relu"] = nn.LeakyReLU(0.2)
46 | layers["batchnorm"] = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)
47 |
48 | if dropout > 0.0:
49 | layers["dropout"] = nn.Dropout(dropout)
50 | return nn.Sequential(layers)
51 |
52 |
53 | class _EncoderBlock(nn.Module):
54 | def __init__(
55 | self, pre_channels, in_channels, out_channels, num_layers, dropout_rate=0.0
56 | ):
57 | super().__init__()
58 | self.num_layers = num_layers
59 | self.pre_conv = _conv2d(
60 | in_channels=pre_channels,
61 | out_channels=pre_channels,
62 | kernel_size=3,
63 | stride=2,
64 | activate=False,
65 | )
66 |
67 | self.conv0 = _conv2d(
68 | in_channels=in_channels + pre_channels,
69 | out_channels=out_channels,
70 | kernel_size=3,
71 | stride=1,
72 | )
73 | total_channels = in_channels + out_channels
74 | for i in range(1, num_layers):
75 | self.add_module(
76 | "conv%d" % i,
77 | _conv2d(
78 | in_channels=total_channels,
79 | out_channels=out_channels,
80 | kernel_size=3,
81 | stride=1,
82 | ),
83 | )
84 | total_channels += out_channels
85 | self.add_module(
86 | "conv%d" % num_layers,
87 | _conv2d(
88 | in_channels=total_channels,
89 | out_channels=out_channels,
90 | kernel_size=3,
91 | stride=2,
92 | dropout=dropout_rate,
93 | ),
94 | )
95 |
96 | def forward(self, inp):
97 | pre_input, x = inp
98 | pre_input = self.pre_conv(pre_input)
99 | out = self.conv0(torch.cat([x, pre_input], 1))
100 |
101 | all_outputs = [x, out]
102 | for i in range(1, self.num_layers + 1):
103 | input_features = torch.cat(all_outputs, 1)
104 | module = self._modules["conv%d" % i]
105 | out = module(input_features)
106 | all_outputs.append(out)
107 | return all_outputs[-2], all_outputs[-1]
108 |
109 |
110 | class _DecoderBlock(nn.Module):
111 | def __init__(
112 | self,
113 | pre_channels,
114 | in_channels,
115 | out_channels,
116 | num_layers,
117 | curr_size,
118 | upscale_size=None,
119 | dropout_rate=0.0,
120 | ):
121 | super().__init__()
122 | self.num_layers = num_layers
123 | self.should_upscale = upscale_size is not None
124 | self.should_pre_conv = pre_channels > 0
125 |
126 | total_channels = pre_channels + in_channels
127 | for i in range(num_layers):
128 | if self.should_pre_conv:
129 | self.add_module(
130 | "pre_conv_t%d" % i,
131 | _conv2d_transpose(
132 | in_channels=pre_channels,
133 | out_channels=pre_channels,
134 | kernel_size=3,
135 | upscale_size=curr_size,
136 | activate=False,
137 | ),
138 | )
139 | self.add_module(
140 | "conv%d" % i,
141 | _conv2d(
142 | in_channels=total_channels,
143 | out_channels=out_channels,
144 | kernel_size=3,
145 | stride=1,
146 | ),
147 | )
148 | total_channels += out_channels
149 |
150 | if self.should_upscale:
151 | total_channels -= pre_channels
152 | self.add_module(
153 | "conv_t%d" % num_layers,
154 | _conv2d_transpose(
155 | in_channels=total_channels,
156 | out_channels=out_channels,
157 | kernel_size=3,
158 | upscale_size=upscale_size,
159 | dropout=dropout_rate,
160 | ),
161 | )
162 |
163 | def forward(self, inp):
164 | pre_input, x = inp
165 | all_outputs = [x]
166 | for i in range(self.num_layers):
167 | curr_input = all_outputs[-1]
168 | if self.should_pre_conv:
169 | pre_conv_output = self._modules["pre_conv_t%d" % i](pre_input)
170 | curr_input = torch.cat([curr_input, pre_conv_output], 1)
171 | input_features = torch.cat([curr_input] + all_outputs[:-1], 1)
172 | module = self._modules["conv%d" % i]
173 | out = module(input_features)
174 | all_outputs.append(out)
175 |
176 | if self.should_upscale:
177 | module = self._modules["conv_t%d" % self.num_layers]
178 | input_features = torch.cat(all_outputs, 1)
179 | out = module(input_features)
180 | all_outputs.append(out)
181 | return all_outputs[-2], all_outputs[-1]
182 |
183 |
184 | class Generator(nn.Module):
185 | def __init__(self, dim, channels, dropout_rate=0.0, z_dim=100):
186 | super().__init__()
187 | self.dim = dim
188 | self.z_dim = z_dim
189 | self.channels = channels
190 | self.layer_sizes = [64, 64, 128, 128]
191 | self.num_inner_layers = 3
192 |
193 | # Number of times dimension is halved
194 | self.U_depth = len(self.layer_sizes)
195 |
196 | # dimension at each level of U-net
197 | self.dim_arr = [dim]
198 | for i in range(self.U_depth):
199 | self.dim_arr.append((self.dim_arr[-1] + 1) // 2)
200 |
201 | # Encoders
202 | self.encode0 = _conv2d(
203 | in_channels=self.channels, out_channels=self.layer_sizes[0], kernel_size=3, stride=2
204 | )
205 | for i in range(1, self.U_depth):
206 | self.add_module(
207 | "encode%d" % i,
208 | _EncoderBlock(
209 | pre_channels=self.channels if i == 1 else self.layer_sizes[i - 1],
210 | in_channels=self.layer_sizes[i - 1],
211 | out_channels=self.layer_sizes[i],
212 | num_layers=self.num_inner_layers,
213 | dropout_rate=dropout_rate,
214 | ),
215 | )
216 |
217 | # Noise encoders
218 | self.noise_encoders = 3
219 | num_noise_filters = 8
220 | self.z_channels = []
221 | for i in range(self.noise_encoders):
222 | curr_dim = self.dim_arr[-1 - i] # Deepest (smallest) levels first, matching decode0, decode1, ...
223 | self.add_module(
224 | "z_reshape%d" % i,
225 | nn.Linear(self.z_dim, curr_dim * curr_dim * num_noise_filters),
226 | )
227 | self.z_channels.append(num_noise_filters)
228 | num_noise_filters //= 2
229 |
230 | # Decoders
231 | for i in range(self.U_depth + 1):
232 | # Input from previous decoder
233 | in_channels = 0 if i == 0 else self.layer_sizes[-i]
234 | # Input from encoder across the "U"
235 | in_channels += (
236 | self.channels if i == self.U_depth else self.layer_sizes[-i - 1]
237 | )
238 | # Input from injected noise
239 | if i < self.noise_encoders:
240 | in_channels += self.z_channels[i]
241 |
242 | self.add_module(
243 | "decode%d" % i,
244 | _DecoderBlock(
245 | pre_channels=0 if i == 0 else self.layer_sizes[-i],
246 | in_channels=in_channels,
247 | out_channels=self.layer_sizes[0]
248 | if i == self.U_depth
249 | else self.layer_sizes[-i - 1],
250 | num_layers=self.num_inner_layers,
251 | curr_size=self.dim_arr[-i - 1],
252 | upscale_size=None if i == self.U_depth else self.dim_arr[-i - 2],
253 | dropout_rate=dropout_rate,
254 | ),
255 | )
256 |
257 | # Final conv
258 | self.num_final_conv = 3
259 | for i in range(self.num_final_conv - 1):
260 | self.add_module(
261 | "final_conv%d" % i,
262 | _conv2d(
263 | in_channels=self.layer_sizes[0],
264 | out_channels=self.layer_sizes[0],
265 | kernel_size=3,
266 | stride=1,
267 | ),
268 | )
269 | self.add_module(
270 | "final_conv%d" % (self.num_final_conv - 1),
271 | _conv2d(
272 | in_channels=self.layer_sizes[0],
273 | out_channels=self.channels,
274 | kernel_size=3,
275 | stride=1,
276 | activate=False,
277 | ),
278 | )
279 | self.tanh = nn.Tanh()
280 |
281 | def forward(self, x, z):
282 | # Final output of every encoding block
283 | all_outputs = [x, self.encode0(x)]
284 |
285 | # Last 2 layer outputs (reuse encode0's result instead of running encode0 a second time)
286 | out = [x, all_outputs[1]]
287 | for i in range(1, len(self.layer_sizes)):
288 | out = self._modules["encode%d" % i](out)
289 | all_outputs.append(out[1])
290 |
291 | pre_input, curr_input = None, out[1]
292 | for i in range(self.U_depth + 1):
293 | if i > 0:
294 | curr_input = torch.cat([curr_input, all_outputs[-i - 1]], 1)
295 | if i < self.noise_encoders:
296 | z_out = self._modules["z_reshape%d" % i](z)
297 |
298 | curr_dim = self.dim_arr[-i - 1]
299 | z_out = z_out.view(-1, self.z_channels[i], curr_dim, curr_dim)
300 | curr_input = torch.cat([z_out, curr_input], 1)
301 |
302 | pre_input, curr_input = self._modules["decode%d" % i](
303 | [pre_input, curr_input]
304 | )
305 |
306 | for i in range(self.num_final_conv):
307 | curr_input = self._modules["final_conv%d" % i](curr_input)
308 | return self.tanh(curr_input)
309 |
--------------------------------------------------------------------------------
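The size and channel bookkeeping in `Generator.__init__` and `_DecoderBlock` is easy to lose track of. The sketch below recomputes it in plain Python so the numbers can be checked without building the model. It assumes the defaults shown above (`layer_sizes = [64, 64, 128, 128]`, three inner layers, three noise encoders starting at 8 filters, one image channel) and that a decoder block uses its `pre_conv_t` path whenever `pre_channels > 0`. The helper names are hypothetical and are not part of `generator.py`.

```python
# Hypothetical helpers that re-derive the shape bookkeeping of Generator and
# _DecoderBlock; they are NOT part of generator.py. Defaults assume
# layer_sizes = [64, 64, 128, 128], 3 inner layers, 3 noise encoders, 1 channel.

LAYER_SIZES = (64, 64, 128, 128)


def u_net_dims(dim, depth=len(LAYER_SIZES)):
    """Spatial side length at each U-net level (stride-2 'same' conv rounds up)."""
    dims = [dim]
    for _ in range(depth):
        dims.append((dims[-1] + 1) // 2)
    return dims


def noise_channels(noise_encoders=3, num_noise_filters=8):
    """Channels contributed by each injected noise map; halved at every level."""
    chans = []
    for _ in range(noise_encoders):
        chans.append(num_noise_filters)
        num_noise_filters //= 2
    return chans


def noise_features(dim, noise_encoders=3, num_noise_filters=8):
    """out_features of each z_reshape Linear layer: side * side * filters."""
    dims = u_net_dims(dim)
    return [
        dims[-1 - i] ** 2 * c  # deepest (smallest) level first
        for i, c in enumerate(noise_channels(noise_encoders, num_noise_filters))
    ]


def decoder_in_channels(i, channels=1, layer_sizes=LAYER_SIZES):
    """in_channels passed to decode<i>: previous decoder + skip + noise."""
    depth = len(layer_sizes)
    z_chans = noise_channels()
    c = 0 if i == 0 else layer_sizes[-i]  # previous decoder output
    c += channels if i == depth else layer_sizes[-i - 1]  # skip across the "U"
    if i < len(z_chans):
        c += z_chans[i]  # injected noise
    return c


def decoder_conv_inputs(in_channels, pre_channels, out_channels, num_layers=3, upscale=True):
    """Input channels of conv0..conv{n-1} (and conv_t{n}) inside a _DecoderBlock."""
    # conv_i sees the block input, the pre_conv output and all i earlier conv outputs
    convs = [in_channels + pre_channels + i * out_channels for i in range(num_layers)]
    if upscale:
        # conv_t sees the block input and every conv output, without the pre_conv output
        convs.append(in_channels + num_layers * out_channels)
    return convs


if __name__ == "__main__":
    print("dims:", u_net_dims(28))
    print("z_reshape out_features:", noise_features(28))
    depth = len(LAYER_SIZES)
    for i in range(depth + 1):
        pre = 0 if i == 0 else LAYER_SIZES[-i]
        out = LAYER_SIZES[0] if i == depth else LAYER_SIZES[-i - 1]
        ins = decoder_in_channels(i)
        print(f"decode{i}:", decoder_conv_inputs(ins, pre, out, upscale=i < depth))
```

For 28x28 inputs this gives side lengths `[28, 14, 7, 4, 2]` and `z_reshape` sizes `[32, 64, 98]`, and the per-block lists line up with the `Conv2d`/`ConvTranspose2d` input channels in the second model summary printed in the notebook (for example 322, 386 and 450 in `decode2`). The `total_channels -= pre_channels` step before `conv_t` corresponds to the final term dropping `pre_channels`.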
/notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {
17 | "scrolled": true
18 | },
19 | "outputs": [],
20 | "source": [
21 | "import torch\n",
22 | "from torch import nn\n",
23 | "import torch.nn.functional as F\n",
24 | "\n",
25 | "import numpy as np\n",
26 | "import os\n",
27 | "from PIL import Image\n",
28 | "from collections import OrderedDict\n",
29 | "from generator import Generator\n",
30 | "import torchvision.transforms as transforms"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 24,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "data": {
40 | "text/plain": [
41 | "Generator(\n",
42 | " (encode0): Sequential(\n",
43 | " (pad): _SamePad()\n",
44 | " (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))\n",
45 | " (relu): LeakyReLU(negative_slope=0.2)\n",
46 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
47 | " )\n",
48 | " (encode1): _EncoderBlock(\n",
49 | " (pre_conv): Sequential(\n",
50 | " (pad): _SamePad()\n",
51 | " (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(2, 2))\n",
52 | " )\n",
53 | " (conv0): Sequential(\n",
54 | " (pad): _SamePad()\n",
55 | " (conv): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1))\n",
56 | " (relu): LeakyReLU(negative_slope=0.2)\n",
57 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
58 | " )\n",
59 | " (conv1): Sequential(\n",
60 | " (pad): _SamePad()\n",
61 | " (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))\n",
62 | " (relu): LeakyReLU(negative_slope=0.2)\n",
63 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
64 | " )\n",
65 | " (conv2): Sequential(\n",
66 | " (pad): _SamePad()\n",
67 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n",
68 | " (relu): LeakyReLU(negative_slope=0.2)\n",
69 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
70 | " )\n",
71 | " (conv3): Sequential(\n",
72 | " (pad): _SamePad()\n",
73 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(2, 2))\n",
74 | " (relu): LeakyReLU(negative_slope=0.2)\n",
75 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
76 | " )\n",
77 | " )\n",
78 | " (encode2): _EncoderBlock(\n",
79 | " (pre_conv): Sequential(\n",
80 | " (pad): _SamePad()\n",
81 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n",
82 | " )\n",
83 | " (conv0): Sequential(\n",
84 | " (pad): _SamePad()\n",
85 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n",
86 | " (relu): LeakyReLU(negative_slope=0.2)\n",
87 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
88 | " )\n",
89 | " (conv1): Sequential(\n",
90 | " (pad): _SamePad()\n",
91 | " (conv): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1))\n",
92 | " (relu): LeakyReLU(negative_slope=0.2)\n",
93 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
94 | " )\n",
95 | " (conv2): Sequential(\n",
96 | " (pad): _SamePad()\n",
97 | " (conv): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1))\n",
98 | " (relu): LeakyReLU(negative_slope=0.2)\n",
99 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
100 | " )\n",
101 | " (conv3): Sequential(\n",
102 | " (pad): _SamePad()\n",
103 | " (conv): Conv2d(448, 128, kernel_size=(3, 3), stride=(2, 2))\n",
104 | " (relu): LeakyReLU(negative_slope=0.2)\n",
105 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
106 | " )\n",
107 | " )\n",
108 | " (encode3): _EncoderBlock(\n",
109 | " (pre_conv): Sequential(\n",
110 | " (pad): _SamePad()\n",
111 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
112 | " )\n",
113 | " (conv0): Sequential(\n",
114 | " (pad): _SamePad()\n",
115 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n",
116 | " (relu): LeakyReLU(negative_slope=0.2)\n",
117 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
118 | " )\n",
119 | " (conv1): Sequential(\n",
120 | " (pad): _SamePad()\n",
121 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n",
122 | " (relu): LeakyReLU(negative_slope=0.2)\n",
123 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
124 | " )\n",
125 | " (conv2): Sequential(\n",
126 | " (pad): _SamePad()\n",
127 | " (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1))\n",
128 | " (relu): LeakyReLU(negative_slope=0.2)\n",
129 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
130 | " )\n",
131 | " (conv3): Sequential(\n",
132 | " (pad): _SamePad()\n",
133 | " (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(2, 2))\n",
134 | " (relu): LeakyReLU(negative_slope=0.2)\n",
135 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
136 | " )\n",
137 | " )\n",
138 | " (z_reshape0): Linear(in_features=100, out_features=128, bias=True)\n",
139 | " (z_reshape1): Linear(in_features=100, out_features=256, bias=True)\n",
140 | " (z_reshape2): Linear(in_features=100, out_features=512, bias=True)\n",
141 | " (decode0): _DecoderBlock(\n",
142 | " (conv0): Sequential(\n",
143 | " (pad): _SamePad()\n",
144 | " (conv): Conv2d(136, 128, kernel_size=(3, 3), stride=(1, 1))\n",
145 | " (relu): LeakyReLU(negative_slope=0.2)\n",
146 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
147 | " )\n",
148 | " (conv1): Sequential(\n",
149 | " (pad): _SamePad()\n",
150 | " (conv): Conv2d(264, 128, kernel_size=(3, 3), stride=(1, 1))\n",
151 | " (relu): LeakyReLU(negative_slope=0.2)\n",
152 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
153 | " )\n",
154 | " (conv2): Sequential(\n",
155 | " (pad): _SamePad()\n",
156 | " (conv): Conv2d(392, 128, kernel_size=(3, 3), stride=(1, 1))\n",
157 | " (relu): LeakyReLU(negative_slope=0.2)\n",
158 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
159 | " )\n",
160 | " (conv_t3): Sequential(\n",
161 | " (upsample): Upsample(size=8, mode=nearest)\n",
162 | " (conv): ConvTranspose2d(520, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
163 | " (relu): LeakyReLU(negative_slope=0.2)\n",
164 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
165 | " )\n",
166 | " )\n",
167 | " (decode1): _DecoderBlock(\n",
168 | " (pre_conv_t0): Sequential(\n",
169 | " (upsample): Upsample(size=8, mode=nearest)\n",
170 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
171 | " )\n",
172 | " (conv0): Sequential(\n",
173 | " (pad): _SamePad()\n",
174 | " (conv): Conv2d(388, 128, kernel_size=(3, 3), stride=(1, 1))\n",
175 | " (relu): LeakyReLU(negative_slope=0.2)\n",
176 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
177 | " )\n",
178 | " (pre_conv_t1): Sequential(\n",
179 | " (upsample): Upsample(size=8, mode=nearest)\n",
180 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
181 | " )\n",
182 | " (conv1): Sequential(\n",
183 | " (pad): _SamePad()\n",
184 | " (conv): Conv2d(516, 128, kernel_size=(3, 3), stride=(1, 1))\n",
185 | " (relu): LeakyReLU(negative_slope=0.2)\n",
186 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
187 | " )\n",
188 | " (pre_conv_t2): Sequential(\n",
189 | " (upsample): Upsample(size=8, mode=nearest)\n",
190 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
191 | " )\n",
192 | " (conv2): Sequential(\n",
193 | " (pad): _SamePad()\n",
194 | " (conv): Conv2d(644, 128, kernel_size=(3, 3), stride=(1, 1))\n",
195 | " (relu): LeakyReLU(negative_slope=0.2)\n",
196 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
197 | " )\n",
198 | " (conv_t3): Sequential(\n",
199 | " (upsample): Upsample(size=16, mode=nearest)\n",
200 | " (conv): ConvTranspose2d(644, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
201 | " (relu): LeakyReLU(negative_slope=0.2)\n",
202 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
203 | " )\n",
204 | " )\n",
205 | " (decode2): _DecoderBlock(\n",
206 | " (pre_conv_t0): Sequential(\n",
207 | " (upsample): Upsample(size=16, mode=nearest)\n",
208 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
209 | " )\n",
210 | " (conv0): Sequential(\n",
211 | " (pad): _SamePad()\n",
212 | " (conv): Conv2d(322, 64, kernel_size=(3, 3), stride=(1, 1))\n",
213 | " (relu): LeakyReLU(negative_slope=0.2)\n",
214 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
215 | " )\n",
216 | " (pre_conv_t1): Sequential(\n",
217 | " (upsample): Upsample(size=16, mode=nearest)\n",
218 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
219 | " )\n",
220 | " (conv1): Sequential(\n",
221 | " (pad): _SamePad()\n",
222 | " (conv): Conv2d(386, 64, kernel_size=(3, 3), stride=(1, 1))\n",
223 | " (relu): LeakyReLU(negative_slope=0.2)\n",
224 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
225 | " )\n",
226 | " (pre_conv_t2): Sequential(\n",
227 | " (upsample): Upsample(size=16, mode=nearest)\n",
228 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
229 | " )\n",
230 | " (conv2): Sequential(\n",
231 | " (pad): _SamePad()\n",
232 | " (conv): Conv2d(450, 64, kernel_size=(3, 3), stride=(1, 1))\n",
233 | " (relu): LeakyReLU(negative_slope=0.2)\n",
234 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
235 | " )\n",
236 | " (conv_t3): Sequential(\n",
237 | " (upsample): Upsample(size=32, mode=nearest)\n",
238 | " (conv): ConvTranspose2d(386, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
239 | " (relu): LeakyReLU(negative_slope=0.2)\n",
240 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
241 | " )\n",
242 | " )\n",
243 | " (decode3): _DecoderBlock(\n",
244 | " (pre_conv_t0): Sequential(\n",
245 | " (upsample): Upsample(size=32, mode=nearest)\n",
246 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
247 | " )\n",
248 | " (conv0): Sequential(\n",
249 | " (pad): _SamePad()\n",
250 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n",
251 | " (relu): LeakyReLU(negative_slope=0.2)\n",
252 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
253 | " )\n",
254 | " (pre_conv_t1): Sequential(\n",
255 | " (upsample): Upsample(size=32, mode=nearest)\n",
256 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
257 | " )\n",
258 | " (conv1): Sequential(\n",
259 | " (pad): _SamePad()\n",
260 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))\n",
261 | " (relu): LeakyReLU(negative_slope=0.2)\n",
262 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
263 | " )\n",
264 | " (pre_conv_t2): Sequential(\n",
265 | " (upsample): Upsample(size=32, mode=nearest)\n",
266 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
267 | " )\n",
268 | " (conv2): Sequential(\n",
269 | " (pad): _SamePad()\n",
270 | " (conv): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1))\n",
271 | " (relu): LeakyReLU(negative_slope=0.2)\n",
272 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
273 | " )\n",
274 | " (conv_t3): Sequential(\n",
275 | " (upsample): Upsample(size=64, mode=nearest)\n",
276 | " (conv): ConvTranspose2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
277 | " (relu): LeakyReLU(negative_slope=0.2)\n",
278 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
279 | " )\n",
280 | " )\n",
281 | " (decode4): _DecoderBlock(\n",
282 | " (pre_conv_t0): Sequential(\n",
283 | " (upsample): Upsample(size=64, mode=nearest)\n",
284 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
285 | " )\n",
286 | " (conv0): Sequential(\n",
287 | " (pad): _SamePad()\n",
288 | " (conv): Conv2d(129, 64, kernel_size=(3, 3), stride=(1, 1))\n",
289 | " (relu): LeakyReLU(negative_slope=0.2)\n",
290 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
291 | " )\n",
292 | " (pre_conv_t1): Sequential(\n",
293 | " (upsample): Upsample(size=64, mode=nearest)\n",
294 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
295 | " )\n",
296 | " (conv1): Sequential(\n",
297 | " (pad): _SamePad()\n",
298 | " (conv): Conv2d(193, 64, kernel_size=(3, 3), stride=(1, 1))\n",
299 | " (relu): LeakyReLU(negative_slope=0.2)\n",
300 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
301 | " )\n",
302 | " (pre_conv_t2): Sequential(\n",
303 | " (upsample): Upsample(size=64, mode=nearest)\n",
304 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
305 | " )\n",
306 | " (conv2): Sequential(\n",
307 | " (pad): _SamePad()\n",
308 | " (conv): Conv2d(257, 64, kernel_size=(3, 3), stride=(1, 1))\n",
309 | " (relu): LeakyReLU(negative_slope=0.2)\n",
310 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
311 | " )\n",
312 | " )\n",
313 | " (final_conv0): Sequential(\n",
314 | " (pad): _SamePad()\n",
315 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
316 | " (relu): LeakyReLU(negative_slope=0.2)\n",
317 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
318 | " )\n",
319 | " (final_conv1): Sequential(\n",
320 | " (pad): _SamePad()\n",
321 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
322 | " (relu): LeakyReLU(negative_slope=0.2)\n",
323 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
324 | " )\n",
325 | " (final_conv2): Sequential(\n",
326 | " (pad): _SamePad()\n",
327 | " (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1))\n",
328 | " )\n",
329 | " (tanh): Tanh()\n",
330 | ")"
331 | ]
332 | },
333 | "execution_count": 24,
334 | "metadata": {},
335 | "output_type": "execute_result"
336 | }
337 | ],
338 | "source": [
339 | "g = torch.load(\"checkpoints/g_87.pt\", map_location=torch.device('cpu'), weights_only=False)\n",
340 | "g.eval()"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 8,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "raw_data = np.load(\"datasets/omniglot_data.npy\")"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 13,
355 | "metadata": {},
356 | "outputs": [],
357 | "source": [
358 | "display_transform = transforms.Compose([\n",
359 | " transforms.ToPILImage(),\n",
360 | " transforms.Resize(g.dim),\n",
361 | " transforms.ToTensor(),\n",
362 | " transforms.Normalize((0.5,), (0.5,))\n",
363 | "])"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 22,
369 | "metadata": {},
370 | "outputs": [],
371 | "source": [
372 | "def render_img(arr):\n",
373 | " arr = (arr * 0.5) + 0.5\n",
374 | " arr = np.uint8(arr * 255)\n",
375 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))\n",
376 | "\n",
377 | "def display_generations(g, data_loader, device=\"cpu\"):\n",
378 | " train_idx = torch.randint(0, len(data_loader.dataset), (1,))[0]\n",
379 | " train_img = display_transform(data_loader.dataset.x1_examples[train_idx])\n",
380 | " render_img(train_img[0])\n",
381 | "\n",
382 | " z = torch.randn((1, g.z_dim)).to(device)\n",
383 | " inp = train_img.unsqueeze(0).to(device)\n",
384 | " train_gen = g(inp, z).detach().cpu()[0]\n",
385 | " render_img(train_gen[0])"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 25,
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "z = torch.randn((1, g.z_dim))"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 47,
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "i = 301\n",
404 | "j = 0"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 48,
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "data": {
414 | "image/png": "\n",
415 | "text/plain": [
416 | ""
417 | ]
418 | },
419 | "metadata": {},
420 | "output_type": "display_data"
421 | },
422 | {
423 | "data": {
424 | "image/png": "\n",
425 | "text/plain": [
426 | ""
427 | ]
428 | },
429 | "metadata": {},
430 | "output_type": "display_data"
431 | }
432 | ],
433 | "source": [
434 | "raw_inp = raw_data[i][j]\n",
435 | "inp = display_transform(raw_inp)\n",
436 | "with torch.no_grad():\n",
437 | " res = g(inp.unsqueeze(0), z)[0]\n",
438 | "render_img(inp[0])\n",
439 | "render_img(res[0])"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": 6,
445 | "metadata": {},
446 | "outputs": [],
447 | "source": [
448 | "inp = np.array([raw_data[1200][0]], dtype=\"float64\")"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": 7,
454 | "metadata": {
455 | "scrolled": true
456 | },
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/plain": [
461 | "Generator(\n",
462 | " (encode0): Sequential(\n",
463 | " (pad): _SamePad()\n",
464 | " (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))\n",
465 | " (relu): LeakyReLU(negative_slope=0.2)\n",
466 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
467 | " )\n",
468 | " (encode1): _EncoderBlock(\n",
469 | " (pre_conv): Sequential(\n",
470 | " (pad): _SamePad()\n",
471 | " (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(2, 2))\n",
472 | " )\n",
473 | " (conv0): Sequential(\n",
474 | " (pad): _SamePad()\n",
475 | " (conv): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1))\n",
476 | " (relu): LeakyReLU(negative_slope=0.2)\n",
477 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
478 | " )\n",
479 | " (conv1): Sequential(\n",
480 | " (pad): _SamePad()\n",
481 | " (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))\n",
482 | " (relu): LeakyReLU(negative_slope=0.2)\n",
483 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
484 | " )\n",
485 | " (conv2): Sequential(\n",
486 | " (pad): _SamePad()\n",
487 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n",
488 | " (relu): LeakyReLU(negative_slope=0.2)\n",
489 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
490 | " )\n",
491 | " (conv3): Sequential(\n",
492 | " (pad): _SamePad()\n",
493 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(2, 2))\n",
494 | " (relu): LeakyReLU(negative_slope=0.2)\n",
495 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
496 | " )\n",
497 | " )\n",
498 | " (encode2): _EncoderBlock(\n",
499 | " (pre_conv): Sequential(\n",
500 | " (pad): _SamePad()\n",
501 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n",
502 | " )\n",
503 | " (conv0): Sequential(\n",
504 | " (pad): _SamePad()\n",
505 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n",
506 | " (relu): LeakyReLU(negative_slope=0.2)\n",
507 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
508 | " )\n",
509 | " (conv1): Sequential(\n",
510 | " (pad): _SamePad()\n",
511 | " (conv): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1))\n",
512 | " (relu): LeakyReLU(negative_slope=0.2)\n",
513 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
514 | " )\n",
515 | " (conv2): Sequential(\n",
516 | " (pad): _SamePad()\n",
517 | " (conv): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1))\n",
518 | " (relu): LeakyReLU(negative_slope=0.2)\n",
519 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
520 | " )\n",
521 | " (conv3): Sequential(\n",
522 | " (pad): _SamePad()\n",
523 | " (conv): Conv2d(448, 128, kernel_size=(3, 3), stride=(2, 2))\n",
524 | " (relu): LeakyReLU(negative_slope=0.2)\n",
525 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
526 | " )\n",
527 | " )\n",
528 | " (encode3): _EncoderBlock(\n",
529 | " (pre_conv): Sequential(\n",
530 | " (pad): _SamePad()\n",
531 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
532 | " )\n",
533 | " (conv0): Sequential(\n",
534 | " (pad): _SamePad()\n",
535 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n",
536 | " (relu): LeakyReLU(negative_slope=0.2)\n",
537 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
538 | " )\n",
539 | " (conv1): Sequential(\n",
540 | " (pad): _SamePad()\n",
541 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n",
542 | " (relu): LeakyReLU(negative_slope=0.2)\n",
543 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
544 | " )\n",
545 | " (conv2): Sequential(\n",
546 | " (pad): _SamePad()\n",
547 | " (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1))\n",
548 | " (relu): LeakyReLU(negative_slope=0.2)\n",
549 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
550 | " )\n",
551 | " (conv3): Sequential(\n",
552 | " (pad): _SamePad()\n",
553 | " (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(2, 2))\n",
554 | " (relu): LeakyReLU(negative_slope=0.2)\n",
555 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
556 | " )\n",
557 | " )\n",
558 | " (z_reshape0): Linear(in_features=100, out_features=32, bias=True)\n",
559 | " (z_reshape1): Linear(in_features=100, out_features=64, bias=True)\n",
560 | " (z_reshape2): Linear(in_features=100, out_features=98, bias=True)\n",
561 | " (decode0): _DecoderBlock(\n",
562 | " (conv0): Sequential(\n",
563 | " (pad): _SamePad()\n",
564 | " (conv): Conv2d(136, 128, kernel_size=(3, 3), stride=(1, 1))\n",
565 | " (relu): LeakyReLU(negative_slope=0.2)\n",
566 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
567 | " )\n",
568 | " (conv1): Sequential(\n",
569 | " (pad): _SamePad()\n",
570 | " (conv): Conv2d(264, 128, kernel_size=(3, 3), stride=(1, 1))\n",
571 | " (relu): LeakyReLU(negative_slope=0.2)\n",
572 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
573 | " )\n",
574 | " (conv2): Sequential(\n",
575 | " (pad): _SamePad()\n",
576 | " (conv): Conv2d(392, 128, kernel_size=(3, 3), stride=(1, 1))\n",
577 | " (relu): LeakyReLU(negative_slope=0.2)\n",
578 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
579 | " )\n",
580 | " (conv_t3): Sequential(\n",
581 | " (upsample): Upsample(size=4, mode=nearest)\n",
582 | " (conv): ConvTranspose2d(520, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
583 | " (relu): LeakyReLU(negative_slope=0.2)\n",
584 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
585 | " )\n",
586 | " )\n",
587 | " (decode1): _DecoderBlock(\n",
588 | " (pre_conv_t0): Sequential(\n",
589 | " (upsample): Upsample(size=4, mode=nearest)\n",
590 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
591 | " )\n",
592 | " (conv0): Sequential(\n",
593 | " (pad): _SamePad()\n",
594 | " (conv): Conv2d(388, 128, kernel_size=(3, 3), stride=(1, 1))\n",
595 | " (relu): LeakyReLU(negative_slope=0.2)\n",
596 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
597 | " )\n",
598 | " (pre_conv_t1): Sequential(\n",
599 | " (upsample): Upsample(size=4, mode=nearest)\n",
600 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
601 | " )\n",
602 | " (conv1): Sequential(\n",
603 | " (pad): _SamePad()\n",
604 | " (conv): Conv2d(516, 128, kernel_size=(3, 3), stride=(1, 1))\n",
605 | " (relu): LeakyReLU(negative_slope=0.2)\n",
606 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
607 | " )\n",
608 | " (pre_conv_t2): Sequential(\n",
609 | " (upsample): Upsample(size=4, mode=nearest)\n",
610 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
611 | " )\n",
612 | " (conv2): Sequential(\n",
613 | " (pad): _SamePad()\n",
614 | " (conv): Conv2d(644, 128, kernel_size=(3, 3), stride=(1, 1))\n",
615 | " (relu): LeakyReLU(negative_slope=0.2)\n",
616 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
617 | " )\n",
618 | " (conv_t3): Sequential(\n",
619 | " (upsample): Upsample(size=7, mode=nearest)\n",
620 | " (conv): ConvTranspose2d(644, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
621 | " (relu): LeakyReLU(negative_slope=0.2)\n",
622 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
623 | " )\n",
624 | " )\n",
625 | " (decode2): _DecoderBlock(\n",
626 | " (pre_conv_t0): Sequential(\n",
627 | " (upsample): Upsample(size=7, mode=nearest)\n",
628 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
629 | " )\n",
630 | " (conv0): Sequential(\n",
631 | " (pad): _SamePad()\n",
632 | " (conv): Conv2d(322, 64, kernel_size=(3, 3), stride=(1, 1))\n",
633 | " (relu): LeakyReLU(negative_slope=0.2)\n",
634 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
635 | " )\n",
636 | " (pre_conv_t1): Sequential(\n",
637 | " (upsample): Upsample(size=7, mode=nearest)\n",
638 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
639 | " )\n",
640 | " (conv1): Sequential(\n",
641 | " (pad): _SamePad()\n",
642 | " (conv): Conv2d(386, 64, kernel_size=(3, 3), stride=(1, 1))\n",
643 | " (relu): LeakyReLU(negative_slope=0.2)\n",
644 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
645 | " )\n",
646 | " (pre_conv_t2): Sequential(\n",
647 | " (upsample): Upsample(size=7, mode=nearest)\n",
648 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
649 | " )\n",
650 | " (conv2): Sequential(\n",
651 | " (pad): _SamePad()\n",
652 | " (conv): Conv2d(450, 64, kernel_size=(3, 3), stride=(1, 1))\n",
653 | " (relu): LeakyReLU(negative_slope=0.2)\n",
654 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
655 | " )\n",
656 | " (conv_t3): Sequential(\n",
657 | " (upsample): Upsample(size=14, mode=nearest)\n",
658 | " (conv): ConvTranspose2d(386, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
659 | " (relu): LeakyReLU(negative_slope=0.2)\n",
660 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
661 | " )\n",
662 | " )\n",
663 | " (decode3): _DecoderBlock(\n",
664 | " (pre_conv_t0): Sequential(\n",
665 | " (upsample): Upsample(size=14, mode=nearest)\n",
666 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
667 | " )\n",
668 | " (conv0): Sequential(\n",
669 | " (pad): _SamePad()\n",
670 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n",
671 | " (relu): LeakyReLU(negative_slope=0.2)\n",
672 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
673 | " )\n",
674 | " (pre_conv_t1): Sequential(\n",
675 | " (upsample): Upsample(size=14, mode=nearest)\n",
676 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
677 | " )\n",
678 | " (conv1): Sequential(\n",
679 | " (pad): _SamePad()\n",
680 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))\n",
681 | " (relu): LeakyReLU(negative_slope=0.2)\n",
682 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
683 | " )\n",
684 | " (pre_conv_t2): Sequential(\n",
685 | " (upsample): Upsample(size=14, mode=nearest)\n",
686 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
687 | " )\n",
688 | " (conv2): Sequential(\n",
689 | " (pad): _SamePad()\n",
690 | " (conv): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1))\n",
691 | " (relu): LeakyReLU(negative_slope=0.2)\n",
692 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
693 | " )\n",
694 | " (conv_t3): Sequential(\n",
695 | " (upsample): Upsample(size=28, mode=nearest)\n",
696 | " (conv): ConvTranspose2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
697 | " (relu): LeakyReLU(negative_slope=0.2)\n",
698 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
699 | " )\n",
700 | " )\n",
701 | " (decode4): _DecoderBlock(\n",
702 | " (pre_conv_t0): Sequential(\n",
703 | " (upsample): Upsample(size=28, mode=nearest)\n",
704 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
705 | " )\n",
706 | " (conv0): Sequential(\n",
707 | " (pad): _SamePad()\n",
708 | " (conv): Conv2d(129, 64, kernel_size=(3, 3), stride=(1, 1))\n",
709 | " (relu): LeakyReLU(negative_slope=0.2)\n",
710 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
711 | " )\n",
712 | " (pre_conv_t1): Sequential(\n",
713 | " (upsample): Upsample(size=28, mode=nearest)\n",
714 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
715 | " )\n",
716 | " (conv1): Sequential(\n",
717 | " (pad): _SamePad()\n",
718 | " (conv): Conv2d(193, 64, kernel_size=(3, 3), stride=(1, 1))\n",
719 | " (relu): LeakyReLU(negative_slope=0.2)\n",
720 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
721 | " )\n",
722 | " (pre_conv_t2): Sequential(\n",
723 | " (upsample): Upsample(size=28, mode=nearest)\n",
724 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
725 | " )\n",
726 | " (conv2): Sequential(\n",
727 | " (pad): _SamePad()\n",
728 | " (conv): Conv2d(257, 64, kernel_size=(3, 3), stride=(1, 1))\n",
729 | " (relu): LeakyReLU(negative_slope=0.2)\n",
730 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
731 | " )\n",
732 | " )\n",
733 | " (final_conv0): Sequential(\n",
734 | " (pad): _SamePad()\n",
735 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
736 | " (relu): LeakyReLU(negative_slope=0.2)\n",
737 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
738 | " )\n",
739 | " (final_conv1): Sequential(\n",
740 | " (pad): _SamePad()\n",
741 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
742 | " (relu): LeakyReLU(negative_slope=0.2)\n",
743 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n",
744 | " )\n",
745 | " (final_conv2): Sequential(\n",
746 | " (pad): _SamePad()\n",
747 | " (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1))\n",
748 | " )\n",
749 | " (tanh): Tanh()\n",
750 | ")"
751 | ]
752 | },
753 | "execution_count": 7,
754 | "metadata": {},
755 | "output_type": "execute_result"
756 | }
757 | ],
758 | "source": [
759 | "torch_g = Generator(dim=28, channels=1)\n",
760 | "torch_g.eval()"
761 | ]
762 | },
763 | {
764 | "cell_type": "code",
765 | "execution_count": 10,
766 | "metadata": {},
767 | "outputs": [],
768 | "source": [
769 | "input_images = torch.ones((32, 1, 28, 28))\n",
770 | "z = torch.randn((32, 100))"
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "execution_count": 70,
776 | "metadata": {},
777 | "outputs": [],
778 | "source": [
779 | "from discriminator import create_d"
780 | ]
781 | },
782 | {
783 | "cell_type": "code",
784 | "execution_count": 72,
785 | "metadata": {},
786 | "outputs": [
787 | {
788 | "data": {
789 | "text/plain": [
790 | "DenseNet(\n",
791 | " (features): Sequential(\n",
792 | " (conv0): Conv2d(56, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
793 | " (norm0): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
794 | " (relu0): LeakyReLU(negative_slope=0.2)\n",
795 | " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
796 | " (denseblock1): _DenseBlock(\n",
797 | " (denselayer1): _DenseLayer(\n",
798 | " (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
799 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
800 | " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
801 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
802 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
803 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
804 | " )\n",
805 | " (denselayer2): _DenseLayer(\n",
806 | " (norm1): InstanceNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
807 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
808 | " (conv1): Conv2d(80, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
809 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
810 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
811 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
812 | " )\n",
813 | " (denselayer3): _DenseLayer(\n",
814 | " (norm1): InstanceNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
815 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
816 | " (conv1): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
817 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
818 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
819 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
820 | " )\n",
821 | " )\n",
822 | " (transition1): _Transition(\n",
823 | " (norm): InstanceNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
824 | " (relu): LeakyReLU(negative_slope=0.2)\n",
825 | " (conv): Conv2d(112, 56, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
826 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
827 | " )\n",
828 | " (denseblock2): _DenseBlock(\n",
829 | " (denselayer1): _DenseLayer(\n",
830 | " (norm1): InstanceNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
831 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
832 | " (conv1): Conv2d(56, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
833 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
834 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
835 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
836 | " )\n",
837 | " (denselayer2): _DenseLayer(\n",
838 | " (norm1): InstanceNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
839 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
840 | " (conv1): Conv2d(72, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
841 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
842 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
843 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
844 | " )\n",
845 | " (denselayer3): _DenseLayer(\n",
846 | " (norm1): InstanceNorm2d(88, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
847 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
848 | " (conv1): Conv2d(88, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
849 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
850 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
851 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
852 | " )\n",
853 | " )\n",
854 | " (transition2): _Transition(\n",
855 | " (norm): InstanceNorm2d(104, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
856 | " (relu): LeakyReLU(negative_slope=0.2)\n",
857 | " (conv): Conv2d(104, 52, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
858 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
859 | " )\n",
860 | " (denseblock3): _DenseBlock(\n",
861 | " (denselayer1): _DenseLayer(\n",
862 | " (norm1): InstanceNorm2d(52, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
863 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
864 | " (conv1): Conv2d(52, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
865 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
866 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
867 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
868 | " )\n",
869 | " (denselayer2): _DenseLayer(\n",
870 | " (norm1): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
871 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
872 | " (conv1): Conv2d(68, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
873 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
874 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
875 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
876 | " )\n",
877 | " (denselayer3): _DenseLayer(\n",
878 | " (norm1): InstanceNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
879 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
880 | " (conv1): Conv2d(84, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
881 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
882 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
883 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
884 | " )\n",
885 | " )\n",
886 | " (transition3): _Transition(\n",
887 | " (norm): InstanceNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
888 | " (relu): LeakyReLU(negative_slope=0.2)\n",
889 | " (conv): Conv2d(100, 50, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
890 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
891 | " )\n",
892 | " (denseblock4): _DenseBlock(\n",
893 | " (denselayer1): _DenseLayer(\n",
894 | " (norm1): InstanceNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
895 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
896 | " (conv1): Conv2d(50, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
897 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
898 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
899 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
900 | " )\n",
901 | " (denselayer2): _DenseLayer(\n",
902 | " (norm1): InstanceNorm2d(66, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
903 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
904 | " (conv1): Conv2d(66, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
905 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
906 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
907 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
908 | " )\n",
909 | " (denselayer3): _DenseLayer(\n",
910 | " (norm1): InstanceNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
911 | " (relu1): LeakyReLU(negative_slope=0.2)\n",
912 | " (conv1): Conv2d(82, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
913 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
914 | " (relu2): LeakyReLU(negative_slope=0.2)\n",
915 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
916 | " )\n",
917 | " )\n",
918 | " (norm5): InstanceNorm2d(98, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
919 | " )\n",
920 | " (classifier): Linear(in_features=98, out_features=1, bias=True)\n",
921 | ")"
922 | ]
923 | },
924 | "execution_count": 72,
925 | "metadata": {},
926 | "output_type": "execute_result"
927 | }
928 | ],
929 | "source": [
930 | "create_d(28)"
931 | ]
932 | },
933 | {
934 | "cell_type": "code",
935 | "execution_count": 31,
936 | "metadata": {},
937 | "outputs": [
938 | {
939 | "data": {
940 | "text/plain": [
941 | "torch.Size([32, 1, 28, 28])"
942 | ]
943 | },
944 | "execution_count": 31,
945 | "metadata": {},
946 | "output_type": "execute_result"
947 | }
948 | ],
949 | "source": [
950 | "torch_g.sample(input_images).shape"
951 | ]
952 | },
953 | {
954 | "cell_type": "code",
955 | "execution_count": null,
956 | "metadata": {},
957 | "outputs": [],
958 | "source": [
959 | "torch_inp = torch.tensor(inp).transpose(1, 3).transpose(2, 3).float()\n",
960 | "torch_out = torch_g(torch_inp, torch.tensor(z).float())\n",
961 | "torch_out[0][0][0]"
962 | ]
963 | },
964 | {
965 | "cell_type": "code",
966 | "execution_count": null,
967 | "metadata": {},
968 | "outputs": [],
969 | "source": [
970 | "def render_image_arr(arr):\n",
971 | " arr = np.uint8(arr * 255)\n",
972 | " arr = arr.reshape(arr.shape[:-1])\n",
973 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))\n",
974 | " \n",
975 | "def render_torch(arr):\n",
976 | " arr = np.uint8(arr * 255)\n",
977 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))"
978 | ]
979 | },
980 | {
981 | "cell_type": "code",
982 | "execution_count": null,
983 | "metadata": {},
984 | "outputs": [],
985 | "source": [
986 | "# Render torch\n",
987 | "refined_torch_out = torch_out.detach().numpy() * 0.5 + 0.5\n",
988 | "\n",
989 | "render_torch(refined_torch_out[0][0])"
990 | ]
991 | },
992 | {
993 | "cell_type": "code",
994 | "execution_count": null,
995 | "metadata": {},
996 | "outputs": [],
997 | "source": [
998 | "from torchvision.models.densenet import DenseNet\n",
999 | "net = DenseNet(growth_rate=16, block_config=(3,3,3,3), num_classes=10, drop_rate=0.0)\n",
1000 | "net"
1001 | ]
1002 | },
1003 | {
1004 | "cell_type": "code",
1005 | "execution_count": null,
1006 | "metadata": {},
1007 | "outputs": [],
1008 | "source": [
1009 | "def convert_bn(net):\n",
1010 | " for name in net._modules.keys():\n",
1011 | " if \"norm\" in name:\n",
1012 | " net._modules[name] = nn.Identity()\n",
1013 | " else:\n",
1014 | " convert_bn(net._modules[name])"
1015 | ]
1016 | },
1017 | {
1018 | "cell_type": "code",
1019 | "execution_count": null,
1020 | "metadata": {},
1021 | "outputs": [],
1022 | "source": [
1023 | "convert_bn(net)\n",
1024 | "net"
1025 | ]
1026 | },
1027 | {
1028 | "cell_type": "code",
1029 | "execution_count": 54,
1030 | "metadata": {},
1031 | "outputs": [],
1032 | "source": [
1033 | "mod = nn.BatchNorm2d(64, affine=True)"
1034 | ]
1035 | },
1036 | {
1037 | "cell_type": "code",
1038 | "execution_count": 64,
1039 | "metadata": {},
1040 | "outputs": [],
1041 | "source": [
1042 | "import tensorflow as tf"
1043 | ]
1044 | },
1045 | {
1046 | "cell_type": "code",
1047 | "execution_count": null,
1048 | "metadata": {},
1049 | "outputs": [],
1050 | "source": [
1051 | "a = tf.constant(0.)\n",
1052 | "b = 2 * a\n",
1053 | "g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])\n"
1054 | ]
1055 | },
1056 | {
1057 | "cell_type": "code",
1058 | "execution_count": 6,
1059 | "metadata": {},
1060 | "outputs": [],
1061 | "source": [
1062 | "norm = nn.InstanceNorm1d(3, affine=True, track_running_stats=False)"
1063 | ]
1064 | },
1065 | {
1066 | "cell_type": "code",
1067 | "execution_count": 7,
1068 | "metadata": {},
1069 | "outputs": [],
1070 | "source": [
1071 | "tensor = torch.zeros((1, 3, 5, 5))\n",
1072 | "for i in range(3):\n",
1073 | " tensor[0][i] = torch.randn((5, 5)) + i"
1074 | ]
1075 | },
1076 | {
1077 | "cell_type": "code",
1078 | "execution_count": 8,
1079 | "metadata": {},
1080 | "outputs": [
1081 | {
1082 | "ename": "ValueError",
1083 | "evalue": "expected 3D input (got 4D input)",
1084 | "output_type": "error",
1085 | "traceback": [
1086 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1087 | "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
1088 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
1089 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 723\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 724\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
1090 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/instancenorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_input_dim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m return F.instance_norm(\n",
1091 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/instancenorm.py\u001b[0m in \u001b[0;36m_check_input_dim\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m raise ValueError('expected 3D input (got {}D input)'\n\u001b[0;32m--> 138\u001b[0;31m .format(input.dim()))\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
1092 | "\u001b[0;31mValueError\u001b[0m: expected 3D input (got 4D input)"
1093 | ]
1094 | }
1095 | ],
1096 | "source": [
1097 | "norm(tensor)"
1098 | ]
1099 | },
1100 | {
1101 | "cell_type": "code",
1102 | "execution_count": 17,
1103 | "metadata": {},
1104 | "outputs": [
1105 | {
1106 | "data": {
1107 | "text/plain": [
1108 | "tensor([[[[-0.2375, 0.9305, 1.4677, 0.1904, 2.1872],\n",
1109 | " [ 1.1405, 0.1877, -0.2945, -1.8864, -0.8841],\n",
1110 | " [-1.9455, 1.5563, 1.7558, -0.5022, -0.2945],\n",
1111 | " [-0.0302, -0.1682, 1.5877, 1.6745, -0.8496],\n",
1112 | " [ 0.1775, 0.2997, 2.0681, 0.9333, -0.7345]],\n",
1113 | "\n",
1114 | " [[ 2.0193, 0.9958, 2.0291, 0.1691, 0.7073],\n",
1115 | " [ 1.1761, 1.0602, 0.8592, 0.7109, 0.4420],\n",
1116 | " [ 1.3823, 2.0038, 0.9473, 0.3570, 1.9287],\n",
1117 | " [ 0.4505, 2.2994, 0.7116, 2.1269, -0.8457],\n",
1118 | " [-0.5620, 2.2178, 0.5716, 0.3277, 1.9907]],\n",
1119 | "\n",
1120 | " [[ 3.8420, 0.8646, 2.5898, 2.0150, 1.1855],\n",
1121 | " [ 1.9689, 1.1003, 1.8354, 1.3174, 1.2613],\n",
1122 | " [ 4.2256, 2.5718, 1.5704, 1.7245, 0.9842],\n",
1123 | " [ 2.0691, -0.0515, 1.2107, 3.6537, 1.2279],\n",
1124 | " [ 0.5973, 2.1711, 1.0104, 2.9567, 1.1152]]]])"
1125 | ]
1126 | },
1127 | "execution_count": 17,
1128 | "metadata": {},
1129 | "output_type": "execute_result"
1130 | }
1131 | ],
1132 | "source": [
1133 | "tensor"
1134 | ]
1135 | },
1136 | {
1137 | "cell_type": "code",
1138 | "execution_count": 20,
1139 | "metadata": {},
1140 | "outputs": [
1141 | {
1142 | "data": {
1143 | "text/plain": [
1144 | "[tensor(0.3332), tensor(1.0431), tensor(1.8007)]"
1145 | ]
1146 | },
1147 | "execution_count": 20,
1148 | "metadata": {},
1149 | "output_type": "execute_result"
1150 | }
1151 | ],
1152 | "source": [
1153 | "[tensor[0][i].mean() for i in range(3)]"
1154 | ]
1155 | },
1156 | {
1157 | "cell_type": "code",
1158 | "execution_count": 21,
1159 | "metadata": {},
1160 | "outputs": [
1161 | {
1162 | "data": {
1163 | "text/plain": [
1164 | "[tensor(-2.8610e-08, grad_fn=<MeanBackward0>),\n",
1165 | " tensor(-7.1526e-08, grad_fn=<MeanBackward0>),\n",
1166 | " tensor(0., grad_fn=<MeanBackward0>)]"
1167 | ]
1168 | },
1169 | "execution_count": 21,
1170 | "metadata": {},
1171 | "output_type": "execute_result"
1172 | }
1173 | ],
1174 | "source": [
1175 | "[norm(tensor)[0][i].mean() for i in range(3)]"
1176 | ]
1177 | },
1178 | {
1179 | "cell_type": "code",
1180 | "execution_count": 50,
1181 | "metadata": {},
1182 | "outputs": [],
1183 | "source": [
1184 | "class Layer(nn.Module):\n",
1185 | " def __init__(self):\n",
1186 | " super().__init__()\n",
1187 | " self.gamma = torch.nn.Parameter(torch.tensor(range(5)).float().unsqueeze(-1).unsqueeze(-1), requires_grad=True)\n",
1188 | " self.beta = torch.nn.Parameter(torch.tensor([0] * 5).float().unsqueeze(-1).unsqueeze(-1), requires_grad=True)\n",
1189 | " \n",
1190 | " def forward(self, x):\n",
1191 | " mid = x * self.gamma + self.beta\n",
1192 | " print(mid)\n",
1193 | " return mid.mean()"
1194 | ]
1195 | },
1196 | {
1197 | "cell_type": "code",
1198 | "execution_count": 51,
1199 | "metadata": {},
1200 | "outputs": [],
1201 | "source": [
1202 | "xx = torch.ones((5, 2, 2), requires_grad=True)\n",
1203 | "l = Layer()"
1204 | ]
1205 | },
1206 | {
1207 | "cell_type": "code",
1208 | "execution_count": 53,
1209 | "metadata": {},
1210 | "outputs": [
1211 | {
1212 | "name": "stdout",
1213 | "output_type": "stream",
1214 | "text": [
1215 | "tensor([[[0., 0.],\n",
1216 | " [0., 0.]],\n",
1217 | "\n",
1218 | " [[1., 1.],\n",
1219 | " [1., 1.]],\n",
1220 | "\n",
1221 | " [[2., 2.],\n",
1222 | " [2., 2.]],\n",
1223 | "\n",
1224 | " [[3., 3.],\n",
1225 | " [3., 3.]],\n",
1226 | "\n",
1227 | " [[4., 4.],\n",
1228 | " [4., 4.]]], grad_fn=<AddBackward0>)\n"
1229 | ]
1230 | }
1231 | ],
1232 | "source": [
1233 | "result = l(xx)"
1234 | ]
1235 | },
1236 | {
1237 | "cell_type": "code",
1238 | "execution_count": 54,
1239 | "metadata": {},
1240 | "outputs": [],
1241 | "source": [
1242 | "result.backward()"
1243 | ]
1244 | },
1245 | {
1246 | "cell_type": "code",
1247 | "execution_count": 55,
1248 | "metadata": {},
1249 | "outputs": [
1250 | {
1251 | "data": {
1252 | "text/plain": [
1253 | "tensor([[[0.2000]],\n",
1254 | "\n",
1255 | " [[0.2000]],\n",
1256 | "\n",
1257 | " [[0.2000]],\n",
1258 | "\n",
1259 | " [[0.2000]],\n",
1260 | "\n",
1261 | " [[0.2000]]])"
1262 | ]
1263 | },
1264 | "execution_count": 55,
1265 | "metadata": {},
1266 | "output_type": "execute_result"
1267 | }
1268 | ],
1269 | "source": [
1270 | "l.beta.grad"
1271 | ]
1272 | },
1273 | {
1274 | "cell_type": "code",
1275 | "execution_count": null,
1276 | "metadata": {},
1277 | "outputs": [],
1278 | "source": []
1279 | }
1280 | ],
1281 | "metadata": {
1282 | "kernelspec": {
1283 | "display_name": "Python 3",
1284 | "language": "python",
1285 | "name": "python3"
1286 | },
1287 | "language_info": {
1288 | "codemirror_mode": {
1289 | "name": "ipython",
1290 | "version": 3
1291 | },
1292 | "file_extension": ".py",
1293 | "mimetype": "text/x-python",
1294 | "name": "python",
1295 | "nbconvert_exporter": "python",
1296 | "pygments_lexer": "ipython3",
1297 | "version": "3.7.6"
1298 | }
1299 | },
1300 | "nbformat": 4,
1301 | "nbformat_minor": 4
1302 | }
1303 |
--------------------------------------------------------------------------------
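The scratch cells near the end of notebook.ipynb probe instance normalization, which the discriminator uses throughout (every `InstanceNorm2d(..., affine=True, track_running_stats=False)` in the `DenseNet` printout). In the PyTorch build used there, `nn.InstanceNorm1d` only accepts 3D `(N, C, L)` input, which is why applying it to the 4D `(1, 3, 5, 5)` tensor raised `expected 3D input (got 4D input)`; `nn.InstanceNorm2d` is the variant for `(N, C, H, W)` input. The NumPy sketch below is an illustrative re-implementation, not code from this repo (`instance_norm` is a hypothetical helper): it normalizes each (sample, channel) plane over its own spatial statistics, which is why the notebook's per-channel means collapse to roughly zero even though the channels were shifted by 0, 1 and 2.

```python
import numpy as np


def instance_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize every (sample, channel) plane of an NCHW array on its own.

    Statistics are taken over the spatial axes (H, W) only, never across the
    batch, using the biased variance, as nn.InstanceNorm2d does when
    track_running_stats=False. gamma and beta are optional per-channel
    affine parameters (the affine=True case).
    """
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)  # ddof=0, i.e. biased variance
    out = (x - mean) / np.sqrt(var + eps)
    if gamma is not None:
        out = out * np.asarray(gamma).reshape(1, -1, 1, 1)
    if beta is not None:
        out = out + np.asarray(beta).reshape(1, -1, 1, 1)
    return out


# Same setup as the notebook: one sample, three 5x5 channels shifted by 0, 1, 2
rng = np.random.default_rng(0)
tensor = np.zeros((1, 3, 5, 5))
for i in range(3):
    tensor[0, i] = rng.standard_normal((5, 5)) + i

normed = instance_norm(tensor)
# Per-channel means should now be ~0 regardless of the original shift
print(normed.mean(axis=(2, 3)))
```

Because the statistics never mix samples, the result for one image does not depend on the rest of the batch, which suits a critic that scores images individually.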
/resources/dagan_tracking_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amurthy1/dagan_torch/66d29d3ba655575398199d6e0e8170ac3be143f5/resources/dagan_tracking_images.png
--------------------------------------------------------------------------------
/resources/dagan_training_progress.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amurthy1/dagan_torch/66d29d3ba655575398199d6e0e8170ac3be143f5/resources/dagan_training_progress.gif
--------------------------------------------------------------------------------
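The last executed cells of notebook.ipynb backpropagate through a toy `Layer` that computes `mean(x * gamma + beta)` on a `(5, 2, 2)` tensor of ones and report `l.beta.grad` as 0.2 for every channel. The sketch below reproduces that number in NumPy under the same shapes, once analytically and once with central finite differences; `layer_forward`, `analytic_grads` and `numeric_grad` are hypothetical names chosen for illustration. Each `beta_c` is broadcast over the 2 x 2 = 4 elements of its channel, and the mean divides by all 20 elements, so its gradient is 4 / 20.

```python
import numpy as np


def layer_forward(x, gamma, beta):
    # Same computation as the notebook's Layer.forward, minus the print:
    # a per-channel scale and shift followed by a mean over all elements.
    return float((x * gamma + beta).mean())


def analytic_grads(x, gamma, beta):
    # For L = mean(x * gamma + beta) over n = C * H * W elements:
    #   dL/dbeta_c  = (H * W) / n               (beta_c appears H * W times)
    #   dL/dgamma_c = sum_{h,w} x[c, h, w] / n
    n = x.size
    grad_gamma = x.sum(axis=(1, 2), keepdims=True) / n
    grad_beta = np.full(beta.shape, x[0].size / n)
    return grad_gamma, grad_beta


def numeric_grad(f, param, eps=1e-6):
    # Central finite differences, perturbing `param` in place one entry at a time
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        original = param[idx]
        param[idx] = original + eps
        plus = f()
        param[idx] = original - eps
        minus = f()
        param[idx] = original
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


# Same shapes and values as the notebook cell
x = np.ones((5, 2, 2))
gamma = np.arange(5, dtype=float).reshape(5, 1, 1)
beta = np.zeros((5, 1, 1))

grad_gamma, grad_beta = analytic_grads(x, gamma, beta)
num_beta = numeric_grad(lambda: layer_forward(x, gamma, beta), beta)
print(grad_beta.ravel())  # each entry is 4 / 20 = 0.2
```

Since the loss is linear in both parameters, the finite-difference estimate should agree with the analytic value up to floating-point rounding.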
/train_dagan.py:
--------------------------------------------------------------------------------
1 | from dagan_trainer import DaganTrainer
2 | from discriminator import Discriminator
3 | from generator import Generator
4 | from dataset import create_dagan_dataloader
5 | from utils.parser import get_dagan_args
6 | import torchvision.transforms as transforms
7 | import torch
8 | import os
9 | import torch.optim as optim
10 | import numpy as np
11 |
12 |
13 | # To maintain reproducibility
14 | torch.manual_seed(0)
15 | np.random.seed(0)
16 |
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 |
20 | # Load input args
21 | args = get_dagan_args()
22 |
23 | dataset_path = args.dataset_path
24 | raw_data = np.load(dataset_path).copy()
25 |
26 | final_generator_path = args.final_model_path
27 | save_checkpoint_path = args.save_checkpoint_path
28 | load_checkpoint_path = args.load_checkpoint_path
29 | in_channels = raw_data.shape[-1]
30 | img_size = args.img_size or raw_data.shape[2]
31 | num_training_classes = args.num_training_classes
32 | num_val_classes = args.num_val_classes
33 | batch_size = args.batch_size
34 | epochs = args.epochs
35 | dropout_rate = args.dropout_rate
36 | max_pixel_value = args.max_pixel_value
37 | should_display_generations = not args.suppress_generations
38 |
39 | # Input sanity checks
40 | final_generator_dir = os.path.dirname(final_generator_path) or os.getcwd()
41 | if not os.access(final_generator_dir, os.W_OK):
42 | raise ValueError(final_generator_path + " is not a valid filepath.")
43 |
44 | if num_training_classes + num_val_classes > raw_data.shape[0]:
45 | raise ValueError(
46 | "Expected at least %d classes but only had %d."
47 | % (num_training_classes + num_val_classes, raw_data.shape[0])
48 | )
49 |
50 |
51 | g = Generator(dim=img_size, channels=in_channels, dropout_rate=dropout_rate)
52 | d = Discriminator(dim=img_size, channels=in_channels * 2, dropout_rate=dropout_rate)
53 |
54 | mid_pixel_value = max_pixel_value / 2
55 | train_transform = transforms.Compose(
56 | [
57 | transforms.ToPILImage(),
58 | transforms.Resize(img_size),
59 | transforms.ToTensor(),
60 | transforms.Normalize(
61 | (mid_pixel_value,) * in_channels, (mid_pixel_value,) * in_channels
62 | ),
63 | ]
64 | )
65 |
66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
67 | train_dataloader = create_dagan_dataloader(
68 | raw_data, num_training_classes, train_transform, batch_size
69 | )
70 |
71 | g_opt = optim.Adam(g.parameters(), lr=0.0001, betas=(0.0, 0.9))
72 | d_opt = optim.Adam(d.parameters(), lr=0.0001, betas=(0.0, 0.9))
73 |
74 | val_data = raw_data[num_training_classes : num_training_classes + num_val_classes]
75 | flat_val_data = val_data.reshape(
76 | (val_data.shape[0] * val_data.shape[1], *val_data.shape[2:])
77 | )
78 |
79 | display_transform = train_transform
80 |
81 | trainer = DaganTrainer(
82 | generator=g,
83 | discriminator=d,
84 | gen_optimizer=g_opt,
85 | dis_optimizer=d_opt,
86 | batch_size=batch_size,
87 | device=device,
88 | critic_iterations=5,
89 | print_every=75,
90 | num_tracking_images=10,
91 | save_checkpoint_path=save_checkpoint_path,
92 | load_checkpoint_path=load_checkpoint_path,
93 | display_transform=display_transform,
94 | should_display_generations=should_display_generations,
95 | )
96 | trainer.train(data_loader=train_dataloader, epochs=epochs, val_images=flat_val_data)
97 |
98 | # Save final generator model
99 | torch.save(trainer.g, final_generator_path)
100 |
--------------------------------------------------------------------------------
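train_dagan.py normalizes images with `transforms.Normalize` using `mid_pixel_value = max_pixel_value / 2` as both mean and standard deviation, i.e. `(x - mid) / mid` per channel. Assuming `ToTensor` yields values in `[0, max_pixel_value]` (with `max_pixel_value = 1.0` for ordinary 8-bit images), this maps pixels into `[-1, 1]`, the same range as the generator's final `Tanh`. The sketch below shows that arithmetic and its inverse in NumPy; `normalize`, `denormalize` and `to_uint8` are hypothetical helpers, not functions from this repo. The notebook's `* 0.5 + 0.5` rendering step is the `max_pixel_value = 1.0` case of `denormalize`.

```python
import numpy as np


def normalize(x, max_pixel_value=1.0):
    # Same arithmetic as transforms.Normalize((mid,) * C, (mid,) * C) in
    # train_dagan.py, where mid = max_pixel_value / 2: (x - mid) / mid
    mid = max_pixel_value / 2
    return (x - mid) / mid


def denormalize(y, max_pixel_value=1.0):
    # Inverse mapping, e.g. for generator (Tanh) outputs in [-1, 1]
    mid = max_pixel_value / 2
    return y * mid + mid


def to_uint8(img01):
    # Convert a [0, 1] float image to 8-bit pixels. Clipping first and scaling
    # by 255 (not 256) keeps a value of 1.0 inside the uint8 range.
    return np.round(np.clip(img01, 0.0, 1.0) * 255).astype(np.uint8)


pixels = np.array([0.0, 0.25, 0.5, 1.0])
scaled = normalize(pixels)        # -1.0, -0.5, 0.0, 1.0
restored = denormalize(scaled)    # back to 0.0, 0.25, 0.5, 1.0
print(to_uint8(restored))         # pixel values 0, 64, 128, 255
```

Scaling by 256 instead of 255 would turn a value of 1.0 into 256, which does not fit in a `uint8`.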
/train_omniglot_classifier.py:
--------------------------------------------------------------------------------
1 | from utils.classifier_utils import (
2 | create_classifier,
3 | create_classifier_dataloaders,
4 | compute_val_accuracy,
5 | perform_train_step,
6 | create_generated_batch,
7 | create_real_batch,
8 | )
9 | from generator import Generator
10 | from utils.parser import get_omniglot_classifier_args
11 | import torch.nn as nn
12 | import torch
13 | import os
14 | import torch.optim as optim
15 | import numpy as np
16 | import scipy.stats
17 | import time
18 |
19 |
20 | # To maintain reproducibility
21 | torch.manual_seed(0)
22 | np.random.seed(0)
23 |
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 | # Load input args
28 | args = get_omniglot_classifier_args()
29 |
30 | generator_path = args.generator_path
31 | dataset_path = args.dataset_path
32 | data_start_index = args.data_start_index
33 | num_classes = args.num_training_classes
34 | generated_batches_per_real = args.generated_batches_per_real
35 | num_epochs = args.epochs
36 | num_val = args.val_samples_per_class
37 | num_train = args.train_samples_per_class
38 | num_bootstrap_samples = args.num_bootstrap_samples
39 | progress_frequency = args.progress_frequency
40 | batch_size = args.batch_size
41 |
42 |
43 | raw_data = np.load(dataset_path)
44 | raw_data = raw_data[data_start_index:]
45 |
46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47 |
48 | dagan_generator = torch.load(generator_path, map_location=device)
49 | dagan_generator.eval()
50 |
51 | loss_function = nn.CrossEntropyLoss()
52 |
53 | start_time = int(time.time())
54 | last_progress_time = start_time
55 |
56 |
57 | val_accuracies_list = []
58 |
59 | # Run routine both with and without augmentation
60 | for real_batch_rate in (1, generated_batches_per_real + 1):
61 | print(
62 | "Training %d classifiers with %d generated batches per real"
63 | % (num_bootstrap_samples, real_batch_rate - 1)
64 | )
65 | val_accuracies = []
66 |
67 | # Train {num_bootstrap_samples} classifiers each with different samples of data
68 | for bootstrap_sample in range(num_bootstrap_samples):
69 | print("\nBootstrap #%d" % (bootstrap_sample + 1))
70 | classifier = create_classifier(num_classes).to(device)
71 | train_dataloader, val_dataloader = create_classifier_dataloaders(
72 | raw_data, num_classes, num_train, num_val, batch_size
73 | )
74 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.99))
75 |
76 | # Train on the dataset {num_epochs} times
77 | for epoch in range(num_epochs):
78 | training_loss = perform_train_step(
79 | classifier,
80 | train_dataloader,
81 | optimizer,
82 | loss_function,
83 | device,
84 | real_batch_rate,
85 | dagan_generator,
86 | )
87 |
88 | if epoch % progress_frequency == 0:
89 | print("[%d] train loss: %.5f" % (epoch + 1, training_loss))
90 | compute_val_accuracy(val_dataloader, classifier, device, loss_function)
91 |
92 | last_progress_time = int(time.time())
93 | print(
94 | f"Elapsed time: {(last_progress_time - start_time) / 60:.2f} minutes\n"
95 | )
96 |
97 | print("[%d] train loss: %.5f" % (num_epochs, training_loss))
98 | val_accuracies.append(
99 | compute_val_accuracy(val_dataloader, classifier, device, loss_function)
100 | )
101 |
102 |         # Drop references to the current net and its optimizer so their GPU memory can be reused
103 |         del classifier, optimizer
104 | val_accuracies_list.append(val_accuracies)
105 |
106 | # Summarize and print results
107 | pvalue = scipy.stats.ttest_ind(
108 |     val_accuracies_list[1], val_accuracies_list[0], equal_var=False, alternative="greater"
109 | ).pvalue
110 | print(
111 |     "Trained %d classifiers with and without augmentation using %d samples per class"
112 |     % (num_bootstrap_samples, num_train)
113 | )
114 | print("Average accuracy without augmentation: %.5f" % (np.mean(val_accuracies_list[0])))
115 | print("Average accuracy with augmentation: %.5f" % (np.mean(val_accuracies_list[1])))
116 | print(
117 |     "Confidence that augmentation has higher accuracy (1 - p, one-sided Welch's t-test): %.3f%%"
118 |     % (100 * (1 - pvalue))
119 | )
120 |
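The summary above relies on `scipy.stats.ttest_ind(..., equal_var=False)`, which is Welch's two-sample t-test: it compares the mean validation accuracy of the bootstrap runs with and without augmentation without assuming the two groups have equal variance. As a rough illustration, the standard-library sketch below computes the same t statistic and Welch-Satterthwaite degrees of freedom; `welch_t_test_stats` is a hypothetical helper, not part of this repo. Turning t into a p-value still needs the t distribution (SciPy handles that), and passing `alternative="greater"` (available in recent SciPy releases) makes the test one-sided, which is what a claim like "augmentation has higher accuracy" calls for.

```python
import math
import statistics


def welch_t_test_stats(augmented, baseline):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom.

    `augmented` and `baseline` are lists of per-run validation accuracies.
    A positive t means the augmented runs have the higher mean.
    """
    n1, n2 = len(augmented), len(baseline)
    # Squared standard errors of each mean, using sample variances (ddof=1)
    se1 = statistics.variance(augmented) / n1
    se2 = statistics.variance(baseline) / n2
    t = (statistics.mean(augmented) - statistics.mean(baseline)) / math.sqrt(se1 + se2)
    # Welch-Satterthwaite approximation of the degrees of freedom
    df = (se1 + se2) ** 2 / (se1 ** 2 / (n1 - 1) + se2 ** 2 / (n2 - 1))
    return t, df
```

For instance, `welch_t_test_stats([2, 4, 6, 8, 10], [1, 2, 3, 4, 5])` gives t of about 1.90 with about 5.88 degrees of freedom; swapping the arguments flips the sign of t, which is why the argument order matters for a one-sided test.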
--------------------------------------------------------------------------------
/utils/classifier_utils.py:
--------------------------------------------------------------------------------
1 | from torchvision.models.densenet import DenseNet
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | from torch.utils.data import Dataset, DataLoader
8 | import warnings
9 | import torch
10 |
11 |
12 | class ModifiedDenseNet(DenseNet):
13 | """
14 |     DenseNet architecture modified to accept a flag indicating whether
15 | a given input image is a real or generated sample.
16 |
17 | Extra dense layers are also added at the end to allow the input flag
18 | to interact with the extracted image features before outputting a prediction.
19 | """
20 |
21 | def __init__(self, **kwargs):
22 | super().__init__(**kwargs)
23 | self.classifier = nn.Sequential(
24 | OrderedDict(
25 | [
26 | ("linear1", nn.Linear(self.classifier.in_features + 1, 20)),
27 | ("norm1", nn.BatchNorm1d(20)),
28 | ("linear2", nn.Linear(20, self.classifier.out_features)),
29 | ]
30 | )
31 | )
32 |
33 | def forward(self, x, flags):
34 | features = self.features(x)
35 | out = F.relu(features, inplace=True)
36 | out = F.adaptive_avg_pool2d(out, (1, 1))
37 | out = torch.flatten(out, 1)
38 | out = torch.cat((out, flags.float()), dim=1)
39 | out = self.classifier(out)
40 | return out
41 |
42 |
43 | class AddGaussianNoiseTransform:
44 | """
45 |     Adds random Gaussian noise to an image.
46 |
47 | Useful for data augmentation during the training phase.
48 | """
49 |
50 | def __init__(self, mean=0.0, std=1.0):
51 | self.std = std
52 | self.mean = mean
53 |
54 | def __call__(self, arr):
55 | return (
56 | np.array(arr)
57 | + np.random.randn(*arr.shape).astype("float32") * self.std
58 | + self.mean
59 | )
60 |
61 | def __repr__(self):
62 | return self.__class__.__name__ + "(mean={0}, std={1})".format(
63 | self.mean, self.std
64 | )
65 |
66 |
67 | class ClassifierDataset(Dataset):
68 | def __init__(self, examples, labels, transform=None):
69 | self.examples = examples
70 | self.labels = labels
71 | self.transform = transform
72 |
73 | def __len__(self):
74 | return len(self.examples)
75 |
76 | def __getitem__(self, idx):
77 | sample = self.examples[idx]
78 | with warnings.catch_warnings():
79 | warnings.simplefilter("ignore")
80 | if self.transform:
81 | sample = self.transform(sample)
82 |
83 | return sample, self.labels[idx]
84 |
85 |
86 | pre_train_transform = transforms.Compose(
87 | [
88 | transforms.ToTensor(),
89 |         transforms.Normalize((0.5,), (0.5,)),
90 | ]
91 | )
92 |
93 | train_transform = transforms.Compose(
94 | [
95 | AddGaussianNoiseTransform(0, 0.1),
96 | transforms.ToPILImage(),
97 | transforms.RandomHorizontalFlip(),
98 | transforms.Resize(224),
99 | transforms.ToTensor(),
100 |         transforms.Normalize((0.5,), (0.5,)),
101 | ]
102 | )
103 |
104 | val_transform = transforms.Compose(
105 | [
106 | transforms.ToPILImage(),
107 | transforms.Resize(224),
108 | transforms.ToTensor(),
109 |         transforms.Normalize((0.5,), (0.5,)),
110 | ]
111 | )
112 |
113 |
114 | def gen_out_to_numpy(tensor):
115 | """
116 | Convert tensor of images to numpy format.
117 |     Convert a (N, 1, H, W) tensor in [-1, 1] to an (N, H, W, 1) numpy array in [0, 1].
118 | return ((tensor * 0.5) + 0.5).squeeze(1).unsqueeze(-1).numpy()
119 |
120 |
121 | def create_real_batch(samples):
122 | """
123 | Given a batch of images, apply the train transform.
124 | """
125 | np_samples = gen_out_to_numpy(samples)
126 | new_samples = []
127 | with warnings.catch_warnings():
128 | warnings.simplefilter("ignore")
129 | for i in range(np_samples.shape[0]):
130 | new_samples.append(train_transform(np_samples[i]))
131 | return torch.stack(new_samples)
132 |
133 |
134 | def create_generated_batch(samples, generator, device="cpu"):
135 | """
136 | Given a batch of images, run them through a dagan
137 | and apply the train transform.
138 | """
139 | z = torch.randn((samples.shape[0], generator.z_dim))
140 | with torch.no_grad():
141 | g_out = generator(samples.to(device), z.to(device)).cpu()
142 | np_out = gen_out_to_numpy(g_out)
143 |
144 | new_samples = []
145 | with warnings.catch_warnings():
146 | warnings.simplefilter("ignore")
147 | for i in range(np_out.shape[0]):
148 | new_samples.append(train_transform(np_out[i]))
149 | return torch.stack(new_samples)
150 |
151 |
152 | def perform_train_step(
153 | net, train_dataloader, optimizer, loss_function, device, real_batch_rate, g
154 | ):
155 | """
156 | Perform one epoch of training using the full dataset.
157 | """
158 | running_loss = 0.0
159 | net.train()
160 | for i, data in enumerate(train_dataloader):
161 | # get the inputs; data is a list of [inputs, labels]
162 | inputs, labels = data[0], data[1].to(device)
163 | if i % real_batch_rate == 0:
164 | inputs = create_real_batch(inputs).to(device)
165 | flags = torch.ones((inputs.shape[0], 1)).to(device)
166 | else:
167 | inputs = create_generated_batch(inputs, g, device).to(device)
168 | flags = torch.zeros((inputs.shape[0], 1)).to(device)
169 |
170 | # zero the parameter gradients
171 | optimizer.zero_grad()
172 | # forward + backward + optimize
173 | outputs = net(inputs, flags)
174 | loss = loss_function(outputs, labels)
175 | loss.backward()
176 | optimizer.step()
177 |
178 |         # accumulate the per-batch mean loss
179 |         running_loss += loss.item()
180 |     return running_loss / len(train_dataloader)
181 |
182 |
183 | def compute_val_accuracy(val_dataloader, net, device, loss_function):
184 | """
185 | Compute accuracy on the given validation set.
186 | """
187 | val_loss = 0.0
188 | success = 0
189 | success_topk = 0
190 | k = 5
191 | net.eval()
192 | val_dataset = val_dataloader.dataset
193 |     for data in val_dataloader:
194 |         with torch.no_grad():
195 |             inputs, labels = data[0].to(device), data[1].to(device)
196 |             flags = torch.ones((inputs.shape[0], 1)).to(device)
197 |             outputs = net(inputs, flags)
198 |             _, predicted = torch.max(outputs, 1)
199 |             _, predicted_topk = torch.topk(outputs, k=k, dim=1)
200 |             # top-1 hits: the argmax prediction equals the label
201 |             success += (predicted == labels).sum().item()
202 |             # top-k hits: the label appears among the k highest logits
203 |             success_topk += (
204 |                 (predicted_topk == labels.unsqueeze(1))
205 |                 .any(dim=1)
206 |                 .sum()
207 |                 .item()
208 |             )
209 | loss = loss_function(outputs, labels)
210 | val_loss += loss.item()
211 |     print("val loss: %.5f" % (val_loss / len(val_dataloader)))
212 | print("val acc: %.5f" % (success / len(val_dataset)))
213 | print("val acc top%d: %.5f" % (k, success_topk / len(val_dataset)))
214 | return success / len(val_dataset)
215 |
216 |
217 | def create_classifier(num_classes):
218 | """
219 | Create a modified densenet with the given number of classes.
220 | """
221 | classifier = ModifiedDenseNet(
222 | growth_rate=16,
223 | block_config=(3, 3, 3, 3),
224 | num_classes=num_classes,
225 | drop_rate=0.3,
226 | )
227 | classifier.features[0] = nn.Conv2d(
228 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
229 | )
230 | return classifier
231 |
232 |
233 | def create_classifier_dataloaders(
234 | raw_data, num_classes, num_train, num_val, batch_size
235 | ):
236 | """
237 | Create train and validation dataloaders with the given raw data.
238 | """
239 | train_X = []
240 | train_y = []
241 | val_X = []
242 | val_y = []
243 |
244 | for i in range(num_classes):
245 | # Shuffle data so different examples are chosen each time
246 | class_data = list(raw_data[i])
247 | np.random.shuffle(class_data)
248 |
249 | train_X.extend(class_data[:num_train])
250 | train_y.extend([i] * num_train)
251 | val_X.extend(class_data[-num_val:])
252 | val_y.extend([i] * num_val)
253 |
254 | train_dataloader = DataLoader(
255 | ClassifierDataset(train_X, train_y, pre_train_transform),
256 | batch_size=batch_size,
257 | shuffle=True,
258 | num_workers=1,
259 | )
260 | val_dataloader = DataLoader(
261 | ClassifierDataset(val_X, val_y, val_transform),
262 | batch_size=batch_size,
263 | shuffle=False,
264 | num_workers=1,
265 | )
266 |
267 | return train_dataloader, val_dataloader
268 |
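`perform_train_step` decides per batch whether to train on the real images or on DAGAN generations seeded by them: with `real_batch_rate = generated_batches_per_real + 1`, batch `i` is real when `i % real_batch_rate == 0`. The stdlib sketch below (the helper name `batch_sources` is hypothetical) reproduces only that scheduling rule, so the resulting mix can be checked without a GPU or a trained generator.

```python
def batch_sources(num_batches, generated_batches_per_real):
    """Label each batch index as "real" or "generated".

    Mirrors the rule in perform_train_step: with
    real_batch_rate = generated_batches_per_real + 1, batch i is real
    when i % real_batch_rate == 0 and generated otherwise.
    """
    real_batch_rate = generated_batches_per_real + 1
    return [
        "real" if i % real_batch_rate == 0 else "generated"
        for i in range(num_batches)
    ]


# Baseline run (real_batch_rate = 1): every batch is real.
print(batch_sources(4, 0))  # ['real', 'real', 'real', 'real']
# Augmented run with two generated batches per real batch.
print(batch_sources(6, 2))
# ['real', 'generated', 'generated', 'real', 'generated', 'generated']
```

Generated batches replace real ones rather than being added: the number of optimizer steps per epoch is the same with and without augmentation, and only a `1 / real_batch_rate` share of batches shows the (noise- and flip-transformed) real images directly. The remaining batches still draw real images from the dataloader, but only as generator input.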
--------------------------------------------------------------------------------
/utils/gif_maker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import imageio
4 | import numpy as np
5 | from PIL import Image
6 | import PIL
7 |
8 |
9 | def convert_to_pil(arr, scale=1.0):
10 | arr = arr.reshape(-1, arr.shape[-1])
11 | arr = (arr * 0.5) + 0.5
12 | arr = np.uint8(arr * 255)
13 | h, w = arr.shape
14 | return (
15 | Image.fromarray(arr, mode="L")
16 | .resize((int(w * scale), int(h * scale)))
17 | .transpose(PIL.Image.TRANSPOSE)
18 | )
19 |
20 |
21 | def create_gif(
22 | checkpoint_path,
23 | gif_path,
24 | fps=15,
25 | scale=1.0,
26 | sampling_rate=1,
27 | tracking_images_path=None,
28 | ):
29 | checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
30 |
31 | # Sample generations and append tracking images on top
32 | gens = checkpoint["tracking_images_gens"][0::sampling_rate]
33 | pil_images = [convert_to_pil(tensor, scale=scale) for tensor in gens]
34 |
35 | # Arbitrarily freeze last frame for half the length of existing gif
36 | pil_images += [pil_images[-1]] * (len(pil_images) // 2)
37 |
38 | imageio.mimsave(gif_path, [np.array(img) for img in pil_images], fps=fps)
39 |
40 | if tracking_images_path is not None:
41 | tracking_images = checkpoint["tracking_images"]
42 | imageio.imsave(tracking_images_path, np.array(convert_to_pil(tracking_images)))
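`create_gif` keeps every `sampling_rate`-th entry of the checkpoint's `tracking_images_gens`, then repeats the last kept frame `len(frames) // 2` more times so the animation pauses on the final generations. The sketch below (hypothetical helper `gif_frame_indices`, stdlib only) applies the same slicing to plain indices, which makes it easy to predict how many frames a GIF will have before loading a checkpoint.

```python
def gif_frame_indices(num_snapshots, sampling_rate=1):
    """Indices of the stored generations that end up as GIF frames.

    Mirrors create_gif: keep every `sampling_rate`-th snapshot, then
    repeat the last kept one len(frames) // 2 more times.
    Like create_gif, this raises IndexError when there are no snapshots.
    """
    frames = list(range(num_snapshots))[0::sampling_rate]
    # Freeze on the final frame for half the length of the sampled sequence
    frames += [frames[-1]] * (len(frames) // 2)
    return frames


print(gif_frame_indices(10, sampling_rate=3))  # [0, 3, 6, 9, 9, 9]
```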
--------------------------------------------------------------------------------
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_dagan_args():
5 | parser = argparse.ArgumentParser(
6 | description="Use this script to train a dagan.",
7 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
8 | )
9 | parser.add_argument(
10 | "dataset_path",
11 | type=str,
12 | help="Filepath for dataset on which to train dagan. File should be .npy format with shape "
13 | "(num_classes, samples_per_class, height, width, channels).",
14 | )
15 | parser.add_argument(
16 | "final_model_path", type=str, help="Filepath to save final dagan model."
17 | )
18 | parser.add_argument(
19 | "--batch_size",
20 | nargs="?",
21 | type=int,
22 | default=32,
23 |         help="Batch size for experiment.",
24 | )
25 | parser.add_argument(
26 | "--img_size",
27 | nargs="?",
28 | type=int,
29 | help="Dimension to scale images when training. "
30 | "Useful when model architecture expects specific input size. "
31 | "If not specified, uses img_size of data as passed.",
32 | )
33 | parser.add_argument(
34 | "--num_training_classes",
35 | nargs="?",
36 | type=int,
37 | default=1200,
38 | help="Number of classes to use for training.",
39 | )
40 | parser.add_argument(
41 | "--num_val_classes",
42 | nargs="?",
43 | type=int,
44 | default=200,
45 | help="Number of classes to use for validation.",
46 | )
47 | parser.add_argument(
48 | "--epochs",
49 | nargs="?",
50 | type=int,
51 | default=50,
52 | help="Number of epochs to run training.",
53 | )
54 | parser.add_argument(
55 | "--max_pixel_value",
56 | nargs="?",
57 | type=float,
58 | default=1.0,
59 | help="Range of values used to represent pixel values in input data. "
60 | "Assumes lower bound is 0 (i.e. range of values is [0, max_pixel_value]).",
61 | )
62 | parser.add_argument(
63 | "--save_checkpoint_path",
64 | nargs="?",
65 | type=str,
66 | help="Filepath to save intermediate training checkpoints.",
67 | )
68 | parser.add_argument(
69 | "--load_checkpoint_path",
70 | nargs="?",
71 | type=str,
72 | help="Filepath of intermediate checkpoint from which to resume training.",
73 | )
74 | parser.add_argument(
75 | "--dropout_rate",
76 | type=float,
77 | nargs="?",
78 | default=0.5,
79 | help="Dropout rate to use within network architecture.",
80 | )
81 | parser.add_argument(
82 | "--suppress_generations",
83 | action="store_true",
84 | help="If specified, does not show intermediate progress images.",
85 | )
86 | return parser.parse_args()
87 |
88 |
89 | def get_omniglot_classifier_args():
90 | parser = argparse.ArgumentParser(
91 | description="Use this script to train an omniglot classifier "
92 | "with and without augmentations to compare the results.",
93 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
94 | )
95 | parser.add_argument(
96 | "generator_path",
97 | type=str,
98 | help="Filepath for dagan generator to use for augmentations.",
99 | )
100 | parser.add_argument(
101 | "--dataset_path",
102 | type=str,
103 | default="datasets/omniglot_data.npy",
104 | help="Filepath for omniglot data.",
105 | )
106 | parser.add_argument(
107 | "--batch_size",
108 | nargs="?",
109 | type=int,
110 | default=32,
111 |         help="Batch size for experiment.",
112 | )
113 | parser.add_argument(
114 | "--data_start_index",
115 | nargs="?",
116 | type=int,
117 | default=1420,
118 | help="Only uses classes after the given index for training. "
119 | "Useful to isolate data that wasn't used during dagan training.",
120 | )
121 | parser.add_argument(
122 | "--num_training_classes",
123 | nargs="?",
124 | type=int,
125 | default=100,
126 | help="Number of classes to use for training.",
127 | )
128 | parser.add_argument(
129 | "--train_samples_per_class",
130 | nargs="?",
131 | type=int,
132 | default=5,
133 | help="Number of samples to use per class during training.",
134 | )
135 | parser.add_argument(
136 | "--val_samples_per_class",
137 | nargs="?",
138 | type=int,
139 | default=5,
140 | help="Number of samples to use per class during validation.",
141 | )
142 | parser.add_argument(
143 | "--epochs",
144 | nargs="?",
145 | type=int,
146 | default=200,
147 | help="Number of epochs to run training.",
148 | )
149 | parser.add_argument(
150 | "--progress_frequency",
151 | nargs="?",
152 | type=int,
153 | default=50,
154 | help="Number of epochs between printing intermediate train/val loss.",
155 | )
156 | parser.add_argument(
157 | "--generated_batches_per_real",
158 | nargs="?",
159 | type=int,
160 | default=1,
161 | help="Number of augmented batches per real batch during "
162 | "augmented training phase.",
163 | )
164 | parser.add_argument(
165 | "--num_bootstrap_samples",
166 | nargs="?",
167 | type=int,
168 | default=10,
169 | help="Number of classifiers to train with slightly different data "
170 |         "in order to get a more accurate measurement.",
171 | )
172 | return parser.parse_args()
173 |
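Most optional flags in both parsers are declared with `nargs="?"` but no `const`. In `argparse`, an omitted flag takes `default`, while a flag given with no value takes `const`, which is `None` here, so passing `--batch_size` on its own silently yields `None` rather than 32. The sketch below uses a hypothetical, trimmed-down parser (`build_parser`, with a made-up `gen.pt` path) that mirrors a few of `get_omniglot_classifier_args`'s arguments to show this.

```python
import argparse


def build_parser():
    # Hypothetical trimmed-down copy of get_omniglot_classifier_args,
    # keeping only enough arguments to show the nargs="?" behavior.
    parser = argparse.ArgumentParser()
    parser.add_argument("generator_path", type=str)
    parser.add_argument("--batch_size", nargs="?", type=int, default=32)
    parser.add_argument("--epochs", nargs="?", type=int, default=200)
    return parser


parser = build_parser()

# Flag omitted: argparse uses `default`.
args = parser.parse_args(["gen.pt"])
print(args.batch_size, args.epochs)  # 32 200

# Flag followed by a value: the string is converted with type=int.
args = parser.parse_args(["gen.pt", "--epochs", "50"])
print(args.epochs)  # 50

# Flag with no value: nargs="?" falls back to `const`, which is None here.
args = parser.parse_args(["gen.pt", "--batch_size"])
print(args.batch_size)  # None
```

If a bare flag should be rejected instead, dropping `nargs="?"` (so argparse expects exactly one value) turns it into a usage error.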
--------------------------------------------------------------------------------