├── .gitignore
├── README.md
├── checkpoints
│   └── readme.txt
├── dagan_trainer.py
├── dataset.py
├── datasets
│   └── readme.txt
├── discriminator.py
├── generator.py
├── notebook.ipynb
├── resources
│   ├── dagan_tracking_images.png
│   └── dagan_training_progress.gif
├── train_dagan.py
├── train_omniglot_classifier.py
└── utils
    ├── classifier_utils.py
    ├── gif_maker.py
    └── parser.py

/.gitignore:
--------------------------------------------------------------------------------
datasets/*
checkpoints/*
!*readme.txt
*.DS_Store
.ipynb_checkpoints*
__pycache__*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Data Augmentation GAN in PyTorch

Time-lapse of DAGAN generations on the omniglot dataset over the course of the training process.


## Table of Contents
1. [Intro](#intro)
2. [Background](#background)
3. [Results](#results)
4. [Training your own DAGAN](#train)
5. [Technical Details](#details)
    1. [DAGAN Training Process](#dagan_train)
    2. [Classifier Training Process](#classifier_train)
    3. [Architectures](#architectures)
6. [Acknowledgements](#acknowledgements)

## 1. Intro

This is a PyTorch implementation of Data Augmentation GAN (DAGAN), which was first proposed in [this paper](https://arxiv.org/abs/1711.04340) with a [corresponding TensorFlow implementation](https://github.com/AntreasAntoniou/DAGAN).

This repo uses the same generator and discriminator architectures as the original TF implementation. It also includes a classifier script for the omniglot dataset to test the quality of a trained DAGAN.

## 2. Background

The motivation for this work is to train a [Generative Adversarial Network (GAN)](https://en.wikipedia.org/wiki/Generative_adversarial_network) that takes an image of a given class (e.g. a specific letter in an alphabet) and outputs another image of the same class that looks sufficiently different from the input. This GAN is then used as a data augmentation tool when training an image classifier.

Standard data augmentation includes methods such as adding noise to, rotating, or cropping images, which increase variation in the training samples and improve the robustness of the trained classifier. Randomly passing some images through the DAGAN generator before using them in training serves a similar purpose.

## 3. Results

To measure the quality of the DAGAN, classifiers were trained both with and without DAGAN augmentations to see whether the augmentations improved classifier accuracy. The original paper showed improvement on the omniglot dataset using 5, 10, and 15 images per class to train the classifier. As expected, the fewer samples used, the more impactful the augmentations were.

This PyTorch implementation showed a statistically significant improvement on the omniglot dataset with 1-4 samples per class, but negligible gains with 5 or more samples per class. The table below shows classifier accuracy with and without DAGAN augmentations, along with the confidence level that the augmentations are in fact better. (More details on the significance-testing methodology [can be found here](#classifier_train).)


| Samples per Class | 1 | 2 | 3 | 4 | 5 |
|----------------------------------------------|-------|-------|-------|-------|-------|
| Acc. w/o DAGAN | 16.4% | 29.0% | 46.5% | 57.8% | 67.1% |
| Acc. w/ DAGAN | 19.2% | 39.4% | 52.0% | 65.3% | 69.5% |
| Confidence level that augmentations are better | 97.6% | 99.9% | 97.4% | 97.8% | 60.7% |


## 4. Training your own DAGAN

The easiest way to train your own DAGAN or augmented omniglot classifier is through Google Colab. The Colab notebooks used to produce the results shown here can be found below:
- [Train omniglot DAGAN](https://colab.research.google.com/drive/1U-twOEiguyIgiL6h9H6130tF-O_g-b-u)
- [Train omniglot classifier with DAGAN augmentations](https://colab.research.google.com/drive/1oJggcS6-3x_chbEfahSJCsy19kWBxWeE)

Running those notebooks as-is should reproduce the results presented in this README. One advantage of PyTorch relative to TensorFlow is how easy it is to modify and test changes to the training process, particularly to the network architecture. To test changes, fork this repo, make the necessary changes, and re-run the Colab notebook using the forked repo.

## 5. Technical Details

### 5.1 DAGAN Training Process
Recall the procedure for training a traditional GAN:
- Create 2 networks, a generator (G) and a discriminator (D)
- To train D
    - Randomly sample images from G and train D to recognize them as fake
    - Randomly sample real images and train D to recognize them as real
- To train G
    - Sample images from G and pass them through D
    - Train/modify G to increase the likelihood that D classifies the given samples as real
- Train G and D alternately
    - This iteratively makes D better at distinguishing real from fake, while making G better at producing realistic images

Training a DAGAN requires a slight twist:
- To train D
    - Randomly sample pairs of real images (source, target), where both items in a pair belong to the same class
    - Pass (source, target) to D to recognize as real
    - Pass source to G to produce a realistic-looking target
    - Pass (source, generated target) to D to recognize as fake
- To train G
    - Sample real images (source) and pass them through G to produce generated targets
    - Train/modify G to increase the likelihood that D classifies (source, generated target) pairs as real
- D learns to distinguish real and fake targets for a given source image
- G learns to produce images that belong to the same class as the source, while not being too similar to it (being too similar would give D a simple way to recognize fake targets)
- Thus, G provides varied images that are somewhat similar to the source, which is our ultimate goal

The omniglot DAGAN was trained on all the examples in the first 1200 classes of the dataset. The generator was validated on the next 200 classes by visual inspection. Training ran for 50 epochs, which took 3.3 hours on a Tesla T4 GPU.
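The pairing step above (source and target drawn from the same class) can be sketched in plain Python. This is a minimal illustration that mirrors the shuffle-within-class idea of `create_dagan_dataloader` in `dataset.py`; the helper name `make_dagan_pairs`, the `seed` parameter, and the string "images" are hypothetical stand-ins, not part of the repo.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def make_dagan_pairs(
    raw_data: Sequence[Sequence[T]], num_classes: int, seed: int = 0
) -> List[Tuple[T, T]]:
    """Build (source, target) pairs where both items come from the same class.

    Hypothetical helper: every image appears exactly once as a source and
    once as a target; targets are a shuffled copy of the same class's images.
    """
    rng = random.Random(seed)
    pairs: List[Tuple[T, T]] = []
    for class_idx in range(num_classes):
        sources = list(raw_data[class_idx])
        targets = list(sources)
        rng.shuffle(targets)  # shuffle only within this class
        pairs.extend(zip(sources, targets))
    return pairs


# Toy usage: strings stand in for images; the first letter encodes the class.
toy_data = [["a0", "a1", "a2"], ["b0", "b1"]]
pairs = make_dagan_pairs(toy_data, num_classes=2)
for source, target in pairs:
    print(source, "->", target)
```

Because targets come from a plain shuffle, some sources can be paired with themselves; a random permutation leaves on average one element in place, so each class gets about one such self-pair regardless of its size.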

The network was trained using the Adam optimizer and the Improved Wasserstein loss (WGAN with gradient penalty), which has useful properties that let more signal pass from D to G during generator training. More details can be found in the [Improved Wasserstein GAN paper](https://arxiv.org/abs/1704.00028).

### 5.2 Omniglot Classifier Training Process

Omniglot classifiers were trained on classes #1420-1519 (100 classes) of the dataset for 200 epochs, both with and without augmentations. When training with augmentations, every other batch was passed through the DAGAN, so the total number of steps was the same in both configurations.

To estimate the accuracy of each configuration more robustly, 10 classifiers were trained for each configuration, each on a slightly different dataset: out of the 20 samples available for each class, a different subset of k images was chosen for each of the 10 classifiers. A two-sample t-test was then used to compute the confidence level that the two distributions of accuracies differ (i.e. the statistical significance of the accuracy improvement from augmentation).

This exercise was repeated with k = 1, 2, 3, 4, and 5 samples per class. Training used the Adam optimizer and standard cross-entropy loss.

### 5.3 Architectures

The DAGAN architectures are described in detail in the paper and can also be seen in the PyTorch implementation of the [generator](https://github.com/amurthy1/dagan_torch/blob/master/generator.py) and [discriminator](https://github.com/amurthy1/dagan_torch/blob/master/discriminator.py).

In a nutshell, the generator is a UNet of dense convolutional blocks. Each block has 4 conv layers, and the UNet itself is 4 blocks deep on each side.

The discriminator is a DenseNet with 4 blocks, each containing 4 conv layers.
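Each encoder level halves the spatial size with a stride-2 convolution under TensorFlow-style "SAME" padding, so a side of length d becomes ceil(d / 2); `generator.py` and `discriminator.py` compute this as `(d + 1) // 2`. The sketch below reproduces that bookkeeping and the resulting input width of the discriminator's final linear layer (`dense2`). The function names are hypothetical, and the 28×28 input is only an example, not necessarily the size used by the training script.

```python
from typing import List


def level_sizes(dim: int, depth: int) -> List[int]:
    """Spatial side length at each level, input level first.

    A stride-2 conv with TF-style "SAME" padding maps a side of length d to
    ceil(d / 2), written here as (d + 1) // 2 as in the repo's dim_arr.
    """
    sizes = [dim]
    for _ in range(depth):
        sizes.append((sizes[-1] + 1) // 2)
    return sizes


def discriminator_head_inputs(dim: int, layer_sizes: List[int], hidden: int = 1024) -> int:
    """Input width of the discriminator's final linear layer: the flattened
    last feature map concatenated with the `hidden`-unit dense branch."""
    final_side = level_sizes(dim, len(layer_sizes))[-1]
    return layer_sizes[-1] * final_side ** 2 + hidden


print(level_sizes(28, 4))  # [28, 14, 7, 4, 2]
print(discriminator_head_inputs(28, [64, 64, 128, 128]))  # 1536
```

With a 28×28 input the final feature map is 2×2, so the head sees 128 · 2 · 2 = 512 flattened features plus the 1,024-unit dense branch. Odd sizes round up (7 becomes 4), which is why the padding pads only the right and bottom edges when the input side is even.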

The omniglot classifier uses the [standard PyTorch DenseNet implementation](https://pytorch.org/hub/pytorch_vision_densenet/) with 4 blocks, each having 3 conv layers. The output of the DenseNet's last layer was concatenated with a 0/1 flag indicating whether a given input was real or generated, followed by 2 dense layers that output the final classification probabilities. This lets the image features from the DenseNet interact with the real/generated flag, which helped produce more accurate predictions.


## 6. Acknowledgements

- As mentioned earlier, this work was adapted from [this paper](https://arxiv.org/abs/1711.04340) and [this repo](https://github.com/AntreasAntoniou/DAGAN) by A. Antoniou et al.

- The omniglot dataset was originally sourced from [this GitHub repo](https://github.com/brendenlake/omniglot/) by user [brendenlake](https://github.com/brendenlake).

- The PyTorch Wasserstein GAN (WGAN) implementation in this repo was closely adapted from [this repo](https://github.com/EmilienDupont/wgan-gp) by user [EmilienDupont](https://github.com/EmilienDupont/).



--------------------------------------------------------------------------------
/checkpoints/readme.txt:
--------------------------------------------------------------------------------
Store intermediate model checkpoints here

--------------------------------------------------------------------------------
/dagan_trainer.py:
--------------------------------------------------------------------------------
import imageio
import numpy as np
import torch
import time
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.autograd import grad as torch_grad
from IPython.display import display  # used by render_img to show images inline
from PIL import Image
import PIL
import warnings


class DaganTrainer:
    def __init__(
        self,
        generator,
        discriminator,
        gen_optimizer,
        dis_optimizer,
        batch_size,
        device="cpu",
        gp_weight=10,
        critic_iterations=5,
        print_every=50,
        num_tracking_images=0,
        save_checkpoint_path=None,
        load_checkpoint_path=None,
        display_transform=None,
        should_display_generations=True,
    ):
        self.device = device
        self.g = generator.to(device)
        self.g_opt = gen_optimizer
        self.d = discriminator.to(device)
        self.d_opt = dis_optimizer
        self.losses = {"G": [0.0], "D": [0.0], "GP": [0.0], "gradient_norm": [0.0]}
        self.num_steps = 0
        self.epoch = 0
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every
        self.num_tracking_images = num_tracking_images
        self.display_transform = display_transform or transforms.ToTensor()
        self.checkpoint_path = save_checkpoint_path
        self.should_display_generations = should_display_generations

        # Track progress of fixed images throughout the training
        self.tracking_images = None
        self.tracking_z = None
        self.tracking_images_gens = None

        if load_checkpoint_path:
            self.hydrate_checkpoint(load_checkpoint_path)

    def _critic_train_iteration(self, x1, x2):
        """One WGAN-GP critic update on real (x1, x2) and generated (x1, G(x1)) pairs."""
        # Get generated data
        generated_data = self.sample_generator(x1)

        d_real = self.d(x1, x2)
        d_generated = self.d(x1, generated_data)

        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(x1, x2, generated_data)
        self.losses["GP"].append(gradient_penalty.item())

        # Create total loss and optimize
        self.d_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()

        self.d_opt.step()

        # Record loss
        self.losses["D"].append(d_loss.item())

    def _generator_train_iteration(self, x1):
        """One generator update that raises the critic's score on (x1, G(x1)) pairs."""
        self.g_opt.zero_grad()

        # Get generated data
        generated_data = self.sample_generator(x1)

        # Calculate loss and optimize
        d_generated = self.d(x1, generated_data)
        g_loss = -d_generated.mean()
        g_loss.backward()
        self.g_opt.step()

        # Record loss
        self.losses["G"].append(g_loss.item())

    def _gradient_penalty(self, x1, x2, generated_data):
        # Interpolate between real and generated targets
        alpha = torch.rand(x1.shape[0], 1, 1, 1)
        alpha = alpha.expand_as(x2).to(self.device)
        interpolated = alpha * x2.data + (1 - alpha) * generated_data.data
        interpolated.requires_grad_(True)

        # Critic scores of the interpolated examples
        prob_interpolated = self.d(x1, interpolated)

        # Calculate gradients of the scores with respect to the interpolated examples
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones(prob_interpolated.size()).to(self.device),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(x1.shape[0], -1)
        self.losses["gradient_norm"].append(gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _train_epoch(self, data_loader, val_images):
        for i, data in enumerate(data_loader):
            if i % self.print_every == 0:
                print("Iteration {}".format(i))
                self.print_progress(data_loader, val_images)
            self.num_steps += 1
            x1, x2 = data[0].to(self.device), data[1].to(self.device)
            self._critic_train_iteration(x1, x2)
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(x1)

    def train(self, data_loader, epochs, val_images=None, save_training_gif=True):
        if self.tracking_images is None and self.num_tracking_images > 0:
            self.tracking_images = self.sample_val_images(
                self.num_tracking_images // 2, val_images
            )
            self.tracking_images.extend(
                self.sample_train_images(
                    self.num_tracking_images - len(self.tracking_images), data_loader
                )
            )
            self.tracking_images = torch.stack(self.tracking_images).to(self.device)
            self.tracking_z = torch.randn((self.num_tracking_images, self.g.z_dim)).to(
                self.device
            )
            self.tracking_images_gens = []

        # Save checkpoint once before training to catch errors
        self._save_checkpoint()

        start_time = int(time.time())

        while self.epoch < epochs:
            print("\nEpoch {}".format(self.epoch))
            print(f"Elapsed time: {(time.time() - start_time) / 60:.2f} minutes\n")

            self._train_epoch(data_loader, val_images)
            self.epoch += 1
            self._save_checkpoint()

    def sample_generator(self, input_images, z=None):
        if z is None:
            z = torch.randn((input_images.shape[0], self.g.z_dim)).to(self.device)
        return self.g(input_images, z)

    def render_img(self, arr):
        arr = (arr * 0.5) + 0.5
        arr = np.uint8(arr * 255)
        display(Image.fromarray(arr, mode="L").transpose(PIL.Image.TRANSPOSE))

    def sample_train_images(self, n, data_loader):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return [
                self.display_transform(data_loader.dataset.x1_examples[idx])
                for idx in torch.randint(0, len(data_loader.dataset), (n,))
            ]

    def sample_val_images(self, n, val_images):
        if val_images is None:
            return []

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return [
                self.display_transform(val_images[idx])
                for idx in torch.randint(0, len(val_images), (n,))
            ]

    def display_generations(self, data_loader, val_images):
        n = 5
        images = self.sample_train_images(n, data_loader) + self.sample_val_images(
            n, val_images
        )
        img_size = images[0].shape[-1]
        images.append(torch.tensor(np.ones((1, img_size, img_size))).float())
        images.append(torch.tensor(np.ones((1, img_size, img_size))).float() * -1)
        self.render_img(torch.cat(images, 1)[0])
        z = torch.randn((len(images), self.g.z_dim)).to(self.device)
        inp = torch.stack(images).to(self.device)
        train_gen = self.g(inp, z).cpu()
        self.render_img(train_gen.reshape(-1, train_gen.shape[-1]))

    def print_progress(self, data_loader, val_images):
        self.g.eval()
        with torch.no_grad():
            if self.should_display_generations:
                self.display_generations(data_loader, val_images)
            if self.num_tracking_images > 0:
                self.tracking_images_gens.append(
                    self.g(self.tracking_images, self.tracking_z).cpu()
                )
        self.g.train()
        print("D: {}".format(self.losses["D"][-1]))
        print("Raw D: {}".format(self.losses["D"][-1] - self.losses["GP"][-1]))
        print("GP: {}".format(self.losses["GP"][-1]))
        print("Gradient norm: {}".format(self.losses["gradient_norm"][-1]))
        if self.num_steps > self.critic_iterations:
            print("G: {}".format(self.losses["G"][-1]))

    def _save_checkpoint(self):
        if self.checkpoint_path is None:
            return
        checkpoint = {
            "epoch": self.epoch,
            "num_steps": self.num_steps,
            "g": self.g.state_dict(),
            "g_opt": self.g_opt.state_dict(),
            "d": self.d.state_dict(),
            "d_opt": self.d_opt.state_dict(),
            "tracking_images": self.tracking_images,
            "tracking_z": self.tracking_z,
            "tracking_images_gens": self.tracking_images_gens,
        }
        torch.save(checkpoint, self.checkpoint_path)

    def hydrate_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.epoch = checkpoint["epoch"]
        self.num_steps = checkpoint["num_steps"]

        self.g.load_state_dict(checkpoint["g"])
        self.g_opt.load_state_dict(checkpoint["g_opt"])
        self.d.load_state_dict(checkpoint["d"])
        self.d_opt.load_state_dict(checkpoint["d_opt"])

        self.tracking_images = checkpoint["tracking_images"]
        self.tracking_z = checkpoint["tracking_z"]
        self.tracking_images_gens = checkpoint["tracking_images_gens"]

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings


class DaganDataset(Dataset):
    """Pairs of same-class images: x1 is the source and x2 is the target."""

    def __init__(self, x1_examples, x2_examples, transform=None):
        assert len(x1_examples) == len(x2_examples)
        self.x1_examples = x1_examples
        self.x2_examples = x2_examples
        self.transform = transform

    def __len__(self):
        return len(self.x1_examples)

    def __getitem__(self, idx):
        x1, x2 = self.x1_examples[idx], self.x2_examples[idx]
        # transform defaults to None, so only apply it when one was given
        if self.transform is None:
            return x1, x2
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return self.transform(x1), self.transform(x2)


def create_dagan_dataloader(raw_data, num_classes, transform, batch_size):
    train_x1 = []
    train_x2 = []

    for i in range(num_classes):
        x2_data = list(raw_data[i])
        np.random.shuffle(x2_data)
        train_x1.extend(raw_data[i])
        train_x2.extend(x2_data)

    train_dataset = DaganDataset(train_x1, train_x2, transform)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

--------------------------------------------------------------------------------
/datasets/readme.txt:
--------------------------------------------------------------------------------
Store datasets here

--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F


class _LayerNorm(nn.Module):
    def __init__(self, num_features, img_size):
        """
        Normalizes each sample over all channels and spatial positions, then
        applies a learned per-channel scale and bias
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(
            (num_features, img_size, img_size), elementwise_affine=False, eps=1e-12
        )
        self.weight = torch.nn.Parameter(
            torch.ones(num_features).float().unsqueeze(-1).unsqueeze(-1),
            requires_grad=True,
        )
        self.bias = torch.nn.Parameter(
            torch.zeros(num_features).float().unsqueeze(-1).unsqueeze(-1),
            requires_grad=True,
        )

    def forward(self, x):
        out = self.layer_norm(x)
        out = out * self.weight + self.bias
        return out


class _SamePad(nn.Module):
    """
    Pads equivalent to the behavior of tensorflow "SAME"
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 2 and x.shape[2] % 2 == 0:
            return F.pad(x, (0, 1, 0, 1))
        return F.pad(x, (1, 1, 1, 1))


def _conv2d(
    in_channels,
    out_channels,
    kernel_size,
    stride,
    out_size=None,
    activate=True,
    dropout=0.0,
):
    layers = OrderedDict()
    layers["pad"] = _SamePad(stride)
    layers["conv"] = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
    if activate:
        if out_size is None:
            raise ValueError("Must provide out_size if activate is True")
        layers["relu"] = nn.LeakyReLU(0.2)
        layers["norm"] = _LayerNorm(out_channels, out_size)

    if dropout > 0.0:
        layers["dropout"] = nn.Dropout(dropout)
    return nn.Sequential(layers)


class _EncoderBlock(nn.Module):
    def __init__(
        self,
        pre_channels,
        in_channels,
        out_channels,
        num_layers,
        out_size,
        dropout_rate=0.0,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.pre_conv = _conv2d(
            in_channels=pre_channels,
            out_channels=pre_channels,
            kernel_size=3,
            stride=2,
            activate=False,
        )

        self.conv0 = _conv2d(
            in_channels=in_channels + pre_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            out_size=out_size,
        )
        total_channels = in_channels + out_channels
        for i in range(1, num_layers):
            self.add_module(
                "conv%d" % i,
                _conv2d(
                    in_channels=total_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    out_size=out_size,
                ),
            )
            total_channels += out_channels
        self.add_module(
            "conv%d" % num_layers,
            _conv2d(
                in_channels=total_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                out_size=(out_size + 1) // 2,
                dropout=dropout_rate,
            ),
        )

    def forward(self, inp):
        pre_input, x = inp
        pre_input = self.pre_conv(pre_input)
        out = self.conv0(torch.cat([x, pre_input], 1))

        all_outputs = [x, out]
        for i in range(1, self.num_layers + 1):
            input_features = torch.cat(
                [all_outputs[-1], all_outputs[-2]] + all_outputs[:-2], 1
            )
            module = self._modules["conv%d" % i]
            out = module(input_features)
            all_outputs.append(out)
        return all_outputs[-2], all_outputs[-1]


class Discriminator(nn.Module):
    def __init__(self, dim, channels, dropout_rate=0.0, z_dim=100):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.channels = channels
        self.layer_sizes = [64, 64, 128, 128]
        self.num_inner_layers = 5

        # Number of times dimension is halved
        self.depth = len(self.layer_sizes)

        # Spatial size after each downsampling stage
        self.dim_arr = [dim]
        for i in range(self.depth):
            self.dim_arr.append((self.dim_arr[-1] + 1) // 2)

        # Encoders
        self.encode0 = _conv2d(
            in_channels=self.channels,
            out_channels=self.layer_sizes[0],
            kernel_size=3,
            stride=2,
            out_size=self.dim_arr[1],
        )
        for i in range(1, self.depth):
            self.add_module(
                "encode%d" % i,
                _EncoderBlock(
                    pre_channels=self.channels if i == 1 else self.layer_sizes[i - 1],
                    in_channels=self.layer_sizes[i - 1],
                    out_channels=self.layer_sizes[i],
                    num_layers=self.num_inner_layers,
                    out_size=self.dim_arr[i],
                    dropout_rate=dropout_rate,
                ),
            )
        self.dense1 = nn.Linear(self.layer_sizes[-1], 1024)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dense2 = nn.Linear(self.layer_sizes[-1] * self.dim_arr[-1] ** 2 + 1024, 1)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        out = [x, self.encode0(x)]
        for i in range(1, len(self.layer_sizes)):
            out = self._modules["encode%d" % i](out)
        out = out[1]

        out_mean = out.mean([2, 3])
        out_flat = torch.flatten(out, 1)

        out = self.dense1(out_mean)
        out = self.leaky_relu(out)
        out = self.dense2(torch.cat([out, out_flat], 1))

        return out

--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F


class _SamePad(nn.Module):
    """
    Pads equivalent to the behavior of tensorflow "SAME"
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 2 and x.shape[2] % 2 == 0:
            return F.pad(x, (0, 1, 0, 1))
        return F.pad(x, (1, 1, 1, 1))


def _conv2d(in_channels, out_channels, kernel_size, stride, activate=True, dropout=0.0):
    layers = OrderedDict()
    layers["pad"] = _SamePad(stride)
    layers["conv"] = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
    if activate:
        layers["relu"] = nn.LeakyReLU(0.2)
        layers["batchnorm"] = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)

    if dropout > 0.0:
        layers["dropout"] = nn.Dropout(dropout)
    return nn.Sequential(layers)


def _conv2d_transpose(
    in_channels, out_channels, kernel_size, upscale_size, activate=True, dropout=0.0
):
    layers = OrderedDict()
    layers["upsample"] = nn.Upsample(upscale_size)
41 | layers["conv"] = nn.ConvTranspose2d( 42 | in_channels, out_channels, kernel_size, stride=1, padding=1 43 | ) 44 | if activate: 45 | layers["relu"] = nn.LeakyReLU(0.2) 46 | layers["batchnorm"] = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01) 47 | 48 | if dropout > 0.0: 49 | layers["dropout"] = nn.Dropout(dropout) 50 | return nn.Sequential(layers) 51 | 52 | 53 | class _EncoderBlock(nn.Module): 54 | def __init__( 55 | self, pre_channels, in_channels, out_channels, num_layers, dropout_rate=0.0 56 | ): 57 | super().__init__() 58 | self.num_layers = num_layers 59 | self.pre_conv = _conv2d( 60 | in_channels=pre_channels, 61 | out_channels=pre_channels, 62 | kernel_size=3, 63 | stride=2, 64 | activate=False, 65 | ) 66 | 67 | self.conv0 = _conv2d( 68 | in_channels=in_channels + pre_channels, 69 | out_channels=out_channels, 70 | kernel_size=3, 71 | stride=1, 72 | ) 73 | total_channels = in_channels + out_channels 74 | for i in range(1, num_layers): 75 | self.add_module( 76 | "conv%d" % i, 77 | _conv2d( 78 | in_channels=total_channels, 79 | out_channels=out_channels, 80 | kernel_size=3, 81 | stride=1, 82 | ), 83 | ) 84 | total_channels += out_channels 85 | self.add_module( 86 | "conv%d" % num_layers, 87 | _conv2d( 88 | in_channels=total_channels, 89 | out_channels=out_channels, 90 | kernel_size=3, 91 | stride=2, 92 | dropout=dropout_rate, 93 | ), 94 | ) 95 | 96 | def forward(self, inp): 97 | pre_input, x = inp 98 | pre_input = self.pre_conv(pre_input) 99 | out = self.conv0(torch.cat([x, pre_input], 1)) 100 | 101 | all_outputs = [x, out] 102 | for i in range(1, self.num_layers + 1): 103 | input_features = torch.cat(all_outputs, 1) 104 | module = self._modules["conv%d" % i] 105 | out = module(input_features) 106 | all_outputs.append(out) 107 | return all_outputs[-2], all_outputs[-1] 108 | 109 | 110 | class _DecoderBlock(nn.Module): 111 | def __init__( 112 | self, 113 | pre_channels, 114 | in_channels, 115 | out_channels, 116 | num_layers, 117 | curr_size, 118 | 
upscale_size=None, 119 | dropout_rate=0.0, 120 | ): 121 | super().__init__() 122 | self.num_layers = num_layers 123 | self.should_upscale = upscale_size is not None 124 | self.should_pre_conv = pre_channels > 0 125 | 126 | total_channels = pre_channels + in_channels 127 | for i in range(num_layers): 128 | if self.should_pre_conv: 129 | self.add_module( 130 | "pre_conv_t%d" % i, 131 | _conv2d_transpose( 132 | in_channels=pre_channels, 133 | out_channels=pre_channels, 134 | kernel_size=3, 135 | upscale_size=curr_size, 136 | activate=False, 137 | ), 138 | ) 139 | self.add_module( 140 | "conv%d" % i, 141 | _conv2d( 142 | in_channels=total_channels, 143 | out_channels=out_channels, 144 | kernel_size=3, 145 | stride=1, 146 | ), 147 | ) 148 | total_channels += out_channels 149 | 150 | if self.should_upscale: 151 | total_channels -= pre_channels 152 | self.add_module( 153 | "conv_t%d" % num_layers, 154 | _conv2d_transpose( 155 | in_channels=total_channels, 156 | out_channels=out_channels, 157 | kernel_size=3, 158 | upscale_size=upscale_size, 159 | dropout=dropout_rate, 160 | ), 161 | ) 162 | 163 | def forward(self, inp): 164 | pre_input, x = inp 165 | all_outputs = [x] 166 | for i in range(self.num_layers): 167 | curr_input = all_outputs[-1] 168 | if self.should_pre_conv: 169 | pre_conv_output = self._modules["pre_conv_t%d" % i](pre_input) 170 | curr_input = torch.cat([curr_input, pre_conv_output], 1) 171 | input_features = torch.cat([curr_input] + all_outputs[:-1], 1) 172 | module = self._modules["conv%d" % i] 173 | out = module(input_features) 174 | all_outputs.append(out) 175 | 176 | if self.should_upscale: 177 | module = self._modules["conv_t%d" % self.num_layers] 178 | input_features = torch.cat(all_outputs, 1) 179 | out = module(input_features) 180 | all_outputs.append(out) 181 | return all_outputs[-2], all_outputs[-1] 182 | 183 | 184 | class Generator(nn.Module): 185 | def __init__(self, dim, channels, dropout_rate=0.0, z_dim=100): 186 | super().__init__() 187 | 
self.dim = dim 188 | self.z_dim = z_dim 189 | self.channels = channels 190 | self.layer_sizes = [64, 64, 128, 128] 191 | self.num_inner_layers = 3 192 | 193 | # Number of times dimension is halved 194 | self.U_depth = len(self.layer_sizes) 195 | 196 | # dimension at each level of U-net 197 | self.dim_arr = [dim] 198 | for i in range(self.U_depth): 199 | self.dim_arr.append((self.dim_arr[-1] + 1) // 2) 200 | 201 | # Encoders 202 | self.encode0 = _conv2d( 203 | in_channels=1, out_channels=self.layer_sizes[0], kernel_size=3, stride=2 204 | ) 205 | for i in range(1, self.U_depth): 206 | self.add_module( 207 | "encode%d" % i, 208 | _EncoderBlock( 209 | pre_channels=self.channels if i == 1 else self.layer_sizes[i - 1], 210 | in_channels=self.layer_sizes[i - 1], 211 | out_channels=self.layer_sizes[i], 212 | num_layers=self.num_inner_layers, 213 | dropout_rate=dropout_rate, 214 | ), 215 | ) 216 | 217 | # Noise encoders 218 | self.noise_encoders = 3 219 | num_noise_filters = 8 220 | self.z_channels = [] 221 | for i in range(self.noise_encoders): 222 | curr_dim = self.dim_arr[-1 - i] # Iterate dim from back 223 | self.add_module( 224 | "z_reshape%d" % i, 225 | nn.Linear(self.z_dim, curr_dim * curr_dim * num_noise_filters), 226 | ) 227 | self.z_channels.append(num_noise_filters) 228 | num_noise_filters //= 2 229 | 230 | # Decoders 231 | for i in range(self.U_depth + 1): 232 | # Input from previous decoder 233 | in_channels = 0 if i == 0 else self.layer_sizes[-i] 234 | # Input from encoder across the "U" 235 | in_channels += ( 236 | self.channels if i == self.U_depth else self.layer_sizes[-i - 1] 237 | ) 238 | # Input from injected noise 239 | if i < self.noise_encoders: 240 | in_channels += self.z_channels[i] 241 | 242 | self.add_module( 243 | "decode%d" % i, 244 | _DecoderBlock( 245 | pre_channels=0 if i == 0 else self.layer_sizes[-i], 246 | in_channels=in_channels, 247 | out_channels=self.layer_sizes[0] 248 | if i == self.U_depth 249 | else self.layer_sizes[-i - 1], 250 | 
num_layers=self.num_inner_layers, 251 | curr_size=self.dim_arr[-i - 1], 252 | upscale_size=None if i == self.U_depth else self.dim_arr[-i - 2], 253 | dropout_rate=dropout_rate, 254 | ), 255 | ) 256 | 257 | # Final conv 258 | self.num_final_conv = 3 259 | for i in range(self.num_final_conv - 1): 260 | self.add_module( 261 | "final_conv%d" % i, 262 | _conv2d( 263 | in_channels=self.layer_sizes[0], 264 | out_channels=self.layer_sizes[0], 265 | kernel_size=3, 266 | stride=1, 267 | ), 268 | ) 269 | self.add_module( 270 | "final_conv%d" % (self.num_final_conv - 1), 271 | _conv2d( 272 | in_channels=self.layer_sizes[0], 273 | out_channels=self.channels, 274 | kernel_size=3, 275 | stride=1, 276 | activate=False, 277 | ), 278 | ) 279 | self.tanh = nn.Tanh() 280 | 281 | def forward(self, x, z): 282 | # Final output of every encoding block 283 | all_outputs = [x, self.encode0(x)] 284 | 285 | # Last 2 layer outputs (reuse the encode0 result instead of running encode0 a second time) 286 | out = [x, all_outputs[1]] 287 | for i in range(1, len(self.layer_sizes)): 288 | out = self._modules["encode%d" % i](out) 289 | all_outputs.append(out[1]) 290 | 291 | pre_input, curr_input = None, out[1] 292 | for i in range(self.U_depth + 1): 293 | if i > 0: 294 | curr_input = torch.cat([curr_input, all_outputs[-i - 1]], 1) 295 | if i < self.noise_encoders: 296 | z_out = self._modules["z_reshape%d" % i](z) 297 | 298 | curr_dim = self.dim_arr[-i - 1] 299 | z_out = z_out.view(-1, self.z_channels[i], curr_dim, curr_dim) 300 | curr_input = torch.cat([z_out, curr_input], 1) 301 | 302 | pre_input, curr_input = self._modules["decode%d" % i]( 303 | [pre_input, curr_input] 304 | ) 305 | 306 | for i in range(self.num_final_conv): 307 | curr_input = self._modules["final_conv%d" % i](curr_input) 308 | return self.tanh(curr_input) 309 | -------------------------------------------------------------------------------- /notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 |
"execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": { 17 | "scrolled": true 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "from torch import nn\n", 23 | "import torch.nn.functional as F\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import os\n", 27 | "from PIL import Image\n", 28 | "from collections import OrderedDict\n", 29 | "from generator import Generator\n", 30 | "import torchvision.transforms as transforms" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 24, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/plain": [ 41 | "Generator(\n", 42 | " (encode0): Sequential(\n", 43 | " (pad): _SamePad()\n", 44 | " (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))\n", 45 | " (relu): LeakyReLU(negative_slope=0.2)\n", 46 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 47 | " )\n", 48 | " (encode1): _EncoderBlock(\n", 49 | " (pre_conv): Sequential(\n", 50 | " (pad): _SamePad()\n", 51 | " (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(2, 2))\n", 52 | " )\n", 53 | " (conv0): Sequential(\n", 54 | " (pad): _SamePad()\n", 55 | " (conv): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1))\n", 56 | " (relu): LeakyReLU(negative_slope=0.2)\n", 57 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 58 | " )\n", 59 | " (conv1): Sequential(\n", 60 | " (pad): _SamePad()\n", 61 | " (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))\n", 62 | " (relu): LeakyReLU(negative_slope=0.2)\n", 63 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 64 | " )\n", 65 | " (conv2): Sequential(\n", 66 | " (pad): _SamePad()\n", 67 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n", 
68 | " (relu): LeakyReLU(negative_slope=0.2)\n", 69 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 70 | " )\n", 71 | " (conv3): Sequential(\n", 72 | " (pad): _SamePad()\n", 73 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(2, 2))\n", 74 | " (relu): LeakyReLU(negative_slope=0.2)\n", 75 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 76 | " )\n", 77 | " )\n", 78 | " (encode2): _EncoderBlock(\n", 79 | " (pre_conv): Sequential(\n", 80 | " (pad): _SamePad()\n", 81 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n", 82 | " )\n", 83 | " (conv0): Sequential(\n", 84 | " (pad): _SamePad()\n", 85 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", 86 | " (relu): LeakyReLU(negative_slope=0.2)\n", 87 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 88 | " )\n", 89 | " (conv1): Sequential(\n", 90 | " (pad): _SamePad()\n", 91 | " (conv): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1))\n", 92 | " (relu): LeakyReLU(negative_slope=0.2)\n", 93 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 94 | " )\n", 95 | " (conv2): Sequential(\n", 96 | " (pad): _SamePad()\n", 97 | " (conv): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1))\n", 98 | " (relu): LeakyReLU(negative_slope=0.2)\n", 99 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 100 | " )\n", 101 | " (conv3): Sequential(\n", 102 | " (pad): _SamePad()\n", 103 | " (conv): Conv2d(448, 128, kernel_size=(3, 3), stride=(2, 2))\n", 104 | " (relu): LeakyReLU(negative_slope=0.2)\n", 105 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 106 | " )\n", 107 | " )\n", 108 | " (encode3): _EncoderBlock(\n", 109 | " (pre_conv): Sequential(\n", 110 | " (pad): _SamePad()\n", 111 | " 
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", 112 | " )\n", 113 | " (conv0): Sequential(\n", 114 | " (pad): _SamePad()\n", 115 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n", 116 | " (relu): LeakyReLU(negative_slope=0.2)\n", 117 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 118 | " )\n", 119 | " (conv1): Sequential(\n", 120 | " (pad): _SamePad()\n", 121 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n", 122 | " (relu): LeakyReLU(negative_slope=0.2)\n", 123 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 124 | " )\n", 125 | " (conv2): Sequential(\n", 126 | " (pad): _SamePad()\n", 127 | " (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1))\n", 128 | " (relu): LeakyReLU(negative_slope=0.2)\n", 129 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 130 | " )\n", 131 | " (conv3): Sequential(\n", 132 | " (pad): _SamePad()\n", 133 | " (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(2, 2))\n", 134 | " (relu): LeakyReLU(negative_slope=0.2)\n", 135 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 136 | " )\n", 137 | " )\n", 138 | " (z_reshape0): Linear(in_features=100, out_features=128, bias=True)\n", 139 | " (z_reshape1): Linear(in_features=100, out_features=256, bias=True)\n", 140 | " (z_reshape2): Linear(in_features=100, out_features=512, bias=True)\n", 141 | " (decode0): _DecoderBlock(\n", 142 | " (conv0): Sequential(\n", 143 | " (pad): _SamePad()\n", 144 | " (conv): Conv2d(136, 128, kernel_size=(3, 3), stride=(1, 1))\n", 145 | " (relu): LeakyReLU(negative_slope=0.2)\n", 146 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 147 | " )\n", 148 | " (conv1): Sequential(\n", 149 | " (pad): _SamePad()\n", 150 | " (conv): Conv2d(264, 128, 
kernel_size=(3, 3), stride=(1, 1))\n", 151 | " (relu): LeakyReLU(negative_slope=0.2)\n", 152 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 153 | " )\n", 154 | " (conv2): Sequential(\n", 155 | " (pad): _SamePad()\n", 156 | " (conv): Conv2d(392, 128, kernel_size=(3, 3), stride=(1, 1))\n", 157 | " (relu): LeakyReLU(negative_slope=0.2)\n", 158 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 159 | " )\n", 160 | " (conv_t3): Sequential(\n", 161 | " (upsample): Upsample(size=8, mode=nearest)\n", 162 | " (conv): ConvTranspose2d(520, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 163 | " (relu): LeakyReLU(negative_slope=0.2)\n", 164 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 165 | " )\n", 166 | " )\n", 167 | " (decode1): _DecoderBlock(\n", 168 | " (pre_conv_t0): Sequential(\n", 169 | " (upsample): Upsample(size=8, mode=nearest)\n", 170 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 171 | " )\n", 172 | " (conv0): Sequential(\n", 173 | " (pad): _SamePad()\n", 174 | " (conv): Conv2d(388, 128, kernel_size=(3, 3), stride=(1, 1))\n", 175 | " (relu): LeakyReLU(negative_slope=0.2)\n", 176 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 177 | " )\n", 178 | " (pre_conv_t1): Sequential(\n", 179 | " (upsample): Upsample(size=8, mode=nearest)\n", 180 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 181 | " )\n", 182 | " (conv1): Sequential(\n", 183 | " (pad): _SamePad()\n", 184 | " (conv): Conv2d(516, 128, kernel_size=(3, 3), stride=(1, 1))\n", 185 | " (relu): LeakyReLU(negative_slope=0.2)\n", 186 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 187 | " )\n", 188 | " (pre_conv_t2): Sequential(\n", 189 | " 
(upsample): Upsample(size=8, mode=nearest)\n", 190 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 191 | " )\n", 192 | " (conv2): Sequential(\n", 193 | " (pad): _SamePad()\n", 194 | " (conv): Conv2d(644, 128, kernel_size=(3, 3), stride=(1, 1))\n", 195 | " (relu): LeakyReLU(negative_slope=0.2)\n", 196 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 197 | " )\n", 198 | " (conv_t3): Sequential(\n", 199 | " (upsample): Upsample(size=16, mode=nearest)\n", 200 | " (conv): ConvTranspose2d(644, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 201 | " (relu): LeakyReLU(negative_slope=0.2)\n", 202 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 203 | " )\n", 204 | " )\n", 205 | " (decode2): _DecoderBlock(\n", 206 | " (pre_conv_t0): Sequential(\n", 207 | " (upsample): Upsample(size=16, mode=nearest)\n", 208 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 209 | " )\n", 210 | " (conv0): Sequential(\n", 211 | " (pad): _SamePad()\n", 212 | " (conv): Conv2d(322, 64, kernel_size=(3, 3), stride=(1, 1))\n", 213 | " (relu): LeakyReLU(negative_slope=0.2)\n", 214 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 215 | " )\n", 216 | " (pre_conv_t1): Sequential(\n", 217 | " (upsample): Upsample(size=16, mode=nearest)\n", 218 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 219 | " )\n", 220 | " (conv1): Sequential(\n", 221 | " (pad): _SamePad()\n", 222 | " (conv): Conv2d(386, 64, kernel_size=(3, 3), stride=(1, 1))\n", 223 | " (relu): LeakyReLU(negative_slope=0.2)\n", 224 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 225 | " )\n", 226 | " (pre_conv_t2): Sequential(\n", 227 | " (upsample): Upsample(size=16, mode=nearest)\n", 228 | " 
(conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 229 | " )\n", 230 | " (conv2): Sequential(\n", 231 | " (pad): _SamePad()\n", 232 | " (conv): Conv2d(450, 64, kernel_size=(3, 3), stride=(1, 1))\n", 233 | " (relu): LeakyReLU(negative_slope=0.2)\n", 234 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 235 | " )\n", 236 | " (conv_t3): Sequential(\n", 237 | " (upsample): Upsample(size=32, mode=nearest)\n", 238 | " (conv): ConvTranspose2d(386, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 239 | " (relu): LeakyReLU(negative_slope=0.2)\n", 240 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 241 | " )\n", 242 | " )\n", 243 | " (decode3): _DecoderBlock(\n", 244 | " (pre_conv_t0): Sequential(\n", 245 | " (upsample): Upsample(size=32, mode=nearest)\n", 246 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 247 | " )\n", 248 | " (conv0): Sequential(\n", 249 | " (pad): _SamePad()\n", 250 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n", 251 | " (relu): LeakyReLU(negative_slope=0.2)\n", 252 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 253 | " )\n", 254 | " (pre_conv_t1): Sequential(\n", 255 | " (upsample): Upsample(size=32, mode=nearest)\n", 256 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 257 | " )\n", 258 | " (conv1): Sequential(\n", 259 | " (pad): _SamePad()\n", 260 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))\n", 261 | " (relu): LeakyReLU(negative_slope=0.2)\n", 262 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 263 | " )\n", 264 | " (pre_conv_t2): Sequential(\n", 265 | " (upsample): Upsample(size=32, mode=nearest)\n", 266 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 
1), padding=(1, 1))\n", 267 | " )\n", 268 | " (conv2): Sequential(\n", 269 | " (pad): _SamePad()\n", 270 | " (conv): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1))\n", 271 | " (relu): LeakyReLU(negative_slope=0.2)\n", 272 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 273 | " )\n", 274 | " (conv_t3): Sequential(\n", 275 | " (upsample): Upsample(size=64, mode=nearest)\n", 276 | " (conv): ConvTranspose2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 277 | " (relu): LeakyReLU(negative_slope=0.2)\n", 278 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 279 | " )\n", 280 | " )\n", 281 | " (decode4): _DecoderBlock(\n", 282 | " (pre_conv_t0): Sequential(\n", 283 | " (upsample): Upsample(size=64, mode=nearest)\n", 284 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 285 | " )\n", 286 | " (conv0): Sequential(\n", 287 | " (pad): _SamePad()\n", 288 | " (conv): Conv2d(129, 64, kernel_size=(3, 3), stride=(1, 1))\n", 289 | " (relu): LeakyReLU(negative_slope=0.2)\n", 290 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 291 | " )\n", 292 | " (pre_conv_t1): Sequential(\n", 293 | " (upsample): Upsample(size=64, mode=nearest)\n", 294 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 295 | " )\n", 296 | " (conv1): Sequential(\n", 297 | " (pad): _SamePad()\n", 298 | " (conv): Conv2d(193, 64, kernel_size=(3, 3), stride=(1, 1))\n", 299 | " (relu): LeakyReLU(negative_slope=0.2)\n", 300 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 301 | " )\n", 302 | " (pre_conv_t2): Sequential(\n", 303 | " (upsample): Upsample(size=64, mode=nearest)\n", 304 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 305 | " )\n", 306 | " (conv2): 
Sequential(\n", 307 | " (pad): _SamePad()\n", 308 | " (conv): Conv2d(257, 64, kernel_size=(3, 3), stride=(1, 1))\n", 309 | " (relu): LeakyReLU(negative_slope=0.2)\n", 310 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 311 | " )\n", 312 | " )\n", 313 | " (final_conv0): Sequential(\n", 314 | " (pad): _SamePad()\n", 315 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 316 | " (relu): LeakyReLU(negative_slope=0.2)\n", 317 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 318 | " )\n", 319 | " (final_conv1): Sequential(\n", 320 | " (pad): _SamePad()\n", 321 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 322 | " (relu): LeakyReLU(negative_slope=0.2)\n", 323 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 324 | " )\n", 325 | " (final_conv2): Sequential(\n", 326 | " (pad): _SamePad()\n", 327 | " (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1))\n", 328 | " )\n", 329 | " (tanh): Tanh()\n", 330 | ")" 331 | ] 332 | }, 333 | "execution_count": 24, 334 | "metadata": {}, 335 | "output_type": "execute_result" 336 | } 337 | ], 338 | "source": [ 339 | "g = torch.load(\"checkpoints/g_87.pt\", map_location=torch.device('cpu'))\n", 340 | "g.eval()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 8, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "raw_data = np.load(\"datasets/omniglot_data.npy\")" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 13, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "display_transform = transforms.Compose([\n", 359 | " transforms.ToPILImage(),\n", 360 | " transforms.Resize(g.dim),\n", 361 | " transforms.ToTensor(),\n", 362 | " transforms.Normalize((0.5), (0.5))\n", 363 | "])" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 22, 369 | 
"metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "def render_img(arr):\n", 373 | " arr = (arr * 0.5) + 0.5\n", 374 | " arr = np.uint8(arr * 255)\n", 375 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))\n", 376 | "\n", 377 | "def display_generations(g, data_loader, device='cpu'):\n", 378 | " train_idx = torch.randint(0, len(data_loader.dataset), (1,))[0]\n", 379 | " train_img = display_transform(data_loader.dataset.x1_examples[train_idx])\n", 380 | " render_img(train_img[0])\n", 381 | "\n", 382 | " z = torch.randn((1, g.z_dim)).to(device)\n", 383 | " inp = train_img.unsqueeze(0).to(device)\n", 384 | " train_gen = g(inp, z).detach().cpu()[0]\n", 385 | " render_img(train_gen[0])" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 25, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "z = torch.randn((1, g.z_dim))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 47, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "i = 301\n", 404 | "j = 0" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 48, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "data": { 414 | "image/png": 
"<base64-encoded PNG output (224x224 grayscale) omitted>\n", 415 | "text/plain": [ 416 | "" 417 | ] 418 | }, 419 | "metadata": {}, 420 | "output_type": "display_data" 421 | }, 422 | { 423 | "data": { 424 | "image/png": "<base64-encoded PNG output (224x224 grayscale) omitted>
Q42kOBH26B7C+FNNtKowkHHUenGeegSwOcF0r90g++LQ7EgxQRAis1X3f2wPxBMYubJayFHEpNkPs/R7BO4uS07AP13S6clJh4EtiICKkFEJf1jumdlxq2PlJkUZ2OY0Qugh3d7TqQ2toAoDRDsyKKM9dEWWk4xGRs3HMXR6854K/eZSTGfr4wd9kBPOazCzoIhKGYtseZVFA4rKsS7FSjEuZ8UUc7A5OHg4JDAtxB+JQwVcEyUQVABJ4txQjm3R1pXm1ZXIYMZxiZnIRJvQwFI8lnoKC67IUK4tRdH4czeswyjMRD4DZaCT9HRgZmoOIUY15+HwA3e1CBKAjCDg8p4HSMwc0P7cu3czkUoTRPQIKN/icSQWQeFmWZbFlMZo6rmaPNoAoSQqAU0LtZhym/g0FR5A024Q32AbE6Xv5JPPCTC4phBFQPEpBKetaipWiVAaPhgAO5zpcSrs7p/lJX799O4FHOsLc+mtJQWZmAhFhBNe67Q8G9b0z3lJsSRYtg4Bd4wMlBcGhd1GdcCSHHp5evtnXeofOxwNi59mDt/urAyX4KGPplh44eFQjL/OokpFl7QCXbJo748smXgkQu/War3JDvjc8enCp364YXScir76VWeMBjxgktIy2fa8t0iqfSOEVizrJJIKpO8ajjgnc59B3WPShce9LPhxpREw/hpmBZfgyTR+yExLdjNel2LIAyZIUPNgNADB5NCj4nggSXX2pT/SN63UP3UT9A2S+nw4uM5Mzh/uvWWvxOUIBIFnWdSlWFqBS5EB4BAQ9JOhuCdzGa9NUKdduroAdKfrJAsDV96Z4++riPfQOGYwAznIRzT7FFzK4rksRLwvipGFovi1mVOd+c/8DY6fhlyiI743jXeqhDRF5NM4Ft4eXmFIGwxcFLEV49Mi8ikuHDL4Tfv2aEN4Hdv3uiNySgozouYj2mAwSl3URdimIUtLJDkbLWw53hXJt+p4ShRFMzDM9vLV30N2F2a8RAhjJLCYi59gjAxbdkR8DiFKWImIJkFl6n8x+/ylcAOj/HXIYAcJs0NKLm3nTJ990qg6ZfbG3vIkj8BY+/iQEi6Jj956fcrj7DI9oQhilIPbO6u9wywfM9zex6DRGWqE3fOduLxA8KoIf92QKM4ogsvQW1lfZ+ZFYuc9rg4OvAb7RLsfa+/jeHIUQEgJ1u545sEwnR8UfpMFHy4rLx5QMSxEmZEFkodFuNrO7OTWkaF49tgtczRIgUqcxU0DswXsKoKftTphZ7AThox/8idh79wdGIUn6hW8MqWzQTR+otegUpAgjWRC7hb8N+pKC717sVjvAAW4abyk4BwvZ5ho6BQ/+HL4xJG1xrKJ8CA4glAwnQEZkIiK4BncF4f0r/YIETt8bXYqRovF0HJsxUgw8U5ARsjL388sLAhELEQCLIuddkj0PezBHEDOIaA4zaYZriFeZwhFcdCJe6ZPOOMRECS67+MduZuoyyEwY1QiP2AnBtKTuRFPa81aX4OGF3urXXEoMZutfPRJIM7p8JZXV+F6Gs0nBTr9AWLL7OHUWJSZEiC1+D9pB4liBYDwU4ZsP/lEWDboEwJ7Sy39HdBNHQQwW9YdZNPgDDYkARpdvuFGWt0HufXSze3Drg/pMwUP34fA2M6cemoVFWApJj98mChJRrxp8CGAo9+vyJejmbiwzR+xC7tZ3SxzQoRfLBQfEa95RxYfSQAy0xwJgjxaYO4OOowmEs6VAj8A9Q1/E2OnxiLOdD9A96lMCYDpeM4lgRv9mDF1xLN8/TME05wGQ6bAMki0Fjvh0rKl+hYKISNwRBYb73+sA37vqr8tgMCUfrnW2/Zdk1lwkOYQQo+blwaQTEiFZ9KDOgKgXAvQ5xP8SPb7Vo2OJEPtFhtqcPzhTcErGdL2ZAKWf6sLClPiSbkdmq+8S+hxg3AEHqltwBwyHY/bHnI/3D4zQM2oTRO92cAD0UcoRAZEQITH1Ph4RQXQvrbuA6fEApAx2Q/su28h
gLB+x0Dsf7cb/b2fRLnZEKDPA7tMMDh0mBW9TIu8POR57AvwAQXjcg2rXn8Mp7dTf9l7mCzBl5Od5xbdYSpFCRCwk/aw2Pk4DG6rluIW7P5bZnh77kWOCW1rm/rrAH9Hu3L+vk388q+TR+ACOaojDt/ERaMT6XSmlhI82ZdYp1eqRpe3Lz9QFGj6rdQotepiFj1a8UwHhuxQ8ktvXb39IwSTgUgpf2cDowZJJw3HQCOFgksdZtIsN3s75LcSPnhZO4/MbT18Lj3MpMgMkYhoyNxzkfKGnwB+4vODwoPH4cWcaANCzTyPLOV6Pgw5GjnaCiJ2f4xveC7KPcASZpUhZllKOM5bScwMEvLpkZxKYV7o/BgiH1/nxcz+MoedBLrPnlSrqUDNz4DEOj4I7zz0KFZe1lMI8gsBxkWnd//rvL1LwK+MjTwdpXna7peC9b/VCzHUCSBLuMcBcuAKTGDw8Wfna57sumnPy15m+oWf6fMY3b0fo0CTgsiwHwOxjBZELGhe8ylA9OGkcGyTf+Ci3CByGN+PRR6fLYLwzs+gRNiNZnLzRXdlj31j02yBOgKUsPVHR1Qi4dwpS+CIZUj1AxpyczzL4yAPJNMbtqkq8OWUApyxZ+BzvUDCMxLIsZSmlx/FEUSHXAdKsRa8e4ecTfrPF9Wvj2g9E7CH3Ue+VO7zemQumkVjWdS3HsZdIocfMJ6kb+aAvWaMbOzgmfW8yfUoAePjMh/Hs6RUe+QdECt8li83fuIFR6RoiuAyAFIVR7k4OU1K/J8IOtfMAwscpiIf67IbCj30e3e3COEc4k2QE/cyUMZkDYspo1Osv61IyeRZNVnuOCnsIjT3RgEMMqC+b/C0A748bcqeQjPAUybMK8+63sW+5WNd1KaWvsWDuGe5XnIk1LsxZy/rJmAA+eEoE9sqEa4BdNaYh7v4/5j7/+wgRKWo8hpmIFnkADkbQKTh2XiS2GL0/wYMAY99//Da7QbcXGC2NR9/YG98bh5QQERhCL+O9D7BbwaWUhYkZst+o9zzxtZKBOLKP+iGpH4cH8IsUhMlSTCT0CeLBol113bk6IhElviXipVEbfwCkLobXlbdRUfI4BX99HI403K6swHjrbvyNOBy1yF9THG+YScXJRZpxYGapROQBIZQrH+Ymjns7Mi5Oxzc5eaRA3X3sdDyCv+zxlBATZ+4EJmY+Dq/N5FBaeIj8Xl8puaJgPBa505rgDcCrkODO+JhvJ7LM8HycSJgv9WZOB5W9W0EexUd0pTRSl41UbYox4hDcIg+0qvxbWXTQ8Oj5FB1xexcav35gGI7MyFHMzvlHLNq9O+GHWPT30I3FsKu2Qn3oANhJmyugecRpagsZSy9TGqib9skQ9piXiy3Lsjwogw9jgcl2IOT26YQYe/AnZMcw09ayi5/3bHvfbEGUSyuIfTNGT2RNFBysm96PLHBa10Vix9WjZuJXxlhwwEHBjmr8Ea0ee8uxQUGApGAWPsyr5mFbaADsb3QKloVOp/XvpuAHGCHOTB49+47WgDo6kXZ08xh1hIfh9FuK3MogInI56dNpfVwGHwyO3yrUVCSZp9d2dCVsR2/AoxVbT2V6F8FR2XMIGkLE06lCoccQMYVwtqWc7Om0FP7UzuOdg6Vu1kv6B/OvLncHs2XDbXDTptG9vzdT1eOs+rGtdPIFIpigHt9Osxo/BsxMHUREJovDt6f1P2AmAqbmidHaskFe9mw+LP/xQK6eXnLo1T7R64EzxnyJWFaip1MRfoD1fgOgz4e1m5lpq9u2bfte29Ayh2MzW/mElyHBFMB2TuyfG9C6EgUkkgKFvp0WeSReugtwDiqOjNMtvN7If0hba/u+bdu219ayTeVbN7SnAcJGdAKObYk9teW30jEGclF2eX5a5JHc2u9SMP0wNTVtNfD1PqN+rLgcT+UqqOrLRjc8ijd55QNcWnlHeV4L/2EWBXD3sHLWrGkQ8BK9+T7LOiecaxPxLj18+hZLAaDlaZX
PlSh8FWA/lmGEbNn2NazDvu/7ZbuMrX3zxO+FSiP4x86g9yY8rYDEYmlRQlnWhccKAbxl5AOgz3I/kkofYOy3gmBS1da0Hr2Mty32uE/OFV4tuk3w5pxnxkLv3zgBEssqWJawgn/ck+kQ6173fdv3bK74hoBvJoK3797/3NucFouBUFmkPODGwK+cIDkHMRD1Klr3bd+2ba97bdk8MhTC2I59D9/Vcum7MjizKJADF3SWReSBfAVkA9W77/iwQlfuYW8blvNJOdz37XLZ6h5emvlRgkhGY0vX4cJAz/PjCAdwlHfcJoenxwIAguSZkLl5+z7AX6YgDdq4mdbe+rNZ5tvJyd+jYNIvNcvx2nuzPSgIjpj7Qx7IV8CvyiDeyBCOnnp+K14fiWA385/f7lBu7GSQFdyPDHnDoB8SFGHaM0XUKwRGyENOWWNB5HS05LpBmdmY+P78tD5EhxAJqdS+H2dEs2w1ujR3aL3u45MREQ2SE7GyiYuW0kTUHdEt+gS4g1nn4xA6n+V37ImgXqF4B1533DCb0UH0KQEah3d/On7RTOQyD5uYu5fogu5IqtF2mcAB7KpR8JtsE/dmBJ8s9UUQeLjdCPAJ9a4AflnJjLuSe5RhZn7GqTbV7JyNiGBu6WsevAqHiehbPj5azryKBydmxXdrIt8AhO7d3x9pha6MUXhsiOQMCL0ZDVJJgK21nRDiYC0jIjKkw8Ub+LKqME56G1kzcJzz4724Yr79yJk+AvCXKQgUm2yViQiZSZYAqK1WQQxf1aOpU9qL+Go4abkuGK1l6f5CEeJ7AB8e8kvwkl0YyEKVIjOVvTVVt9bqLiOvmDHhZA5DACXwLUsW1b/HoncBfmHIzYKJX7kb78NzQERDIyNWJioi2SBate67dAOeXTOJncYSYiwOxcJ8rM0LEQ035rrhZC8i6nbwqxC/TsEwL3lPcnPjRtxk6We+aN03zvbjmC373YjA89SpXjqynk5Pp3VdltL9kjHxLvOH7vw6tAT41S/MtyF0JzdDImlamqmqadtLFFVDaNhMbvRl3lhwCXyn0+m0riUKtq89gf6ze6hfRtYBIn7iENxwx/yGR/clRCIjk8yBNtmpnyyVkzZTyqNkfOiX09PT8/PTaV2WkkvzfnsXhAPgLyoL6QKcydVZP98Hfv2c3aM9nJGbqauqViICN0i3NPv1uzlF2pBJSvDn8/PT09NpLb1Z/jEBBLil2y+S8JaCb/x+uKXgmxvF7nNyd3NVbUSEni3xw1jEcR0JEDkUzOnp+fn5+em0rks0Ib+9PL5/y68AvAPozvjgTgghjA5uwKrEROjaRqI+jlFxdwd0B2IuZV1Pp6enp6fTKQh4mx/7Ham7HjLbpztz/5CCjj5OsEKPnuNR869Li+1hlmfB5SEcDkAiy5L4np5O67IE/d57eL8L8M3lJkl/wyRX4OBYaMDsXQyIBAhuIks/QCqWBwOgOQCLLOt6SninUDE4PDX/WxjzPsCvXdez1GOsNyAAADsgmEnpx8/17vR5VhYgy7KmhTidTutS5LEE5y+O+wAfVcmepWiY4VoW6YmVBBbedGmtmAOiORCXsq7DBg4r/4tW4NMhb1clPwPVH7ZDFlOM9BSGD84mmrXzzNS7jpoDmiOJhIl/Oq1r2sCRboK/izFxLLAEwHd2V18tmXTKOgA4jqUxdwCPJfbUNkisLKO2IBwXNQMkA6T00Z5Op9O6LEs5zmb9I0PuLCzfhTjAGTjkKa5xMhkCylE4BnMl0tGHTswcyABFShDwKSWQ/yB/AoD080X8bTmZT4tYATG50iKd3aqaASBzWRYRAQp70b3N2qtDkESZHcwhOXTYCGF+OL3yiwDj6FWInFPyYaAaG+EyynawPKBWtba97nuzYLr16bSWJYKekMlMSHDvjsAsDu5IUsJGnMJLk3snxfz+OHS7xHbmY/l90DItQFYYQFactX7cZd33fWtmiLwsp+fn03oqRTiK0GHUF+QKLjObox8
cekoOfWiN73eG5AomGMQJXHCYt1FGl9VKrbVWW62t7bEYv2vS5PTt+fnb89MaZyRmsTz3fg1Zf8x46NCnpzCCDy4w/A5AcNNaFRxMxwbrUT0QS5tVW6t5Wuu273vdWx6dDJAAv33/8de3b89P68KMGf2NMq3QpY5AJMtyWp+eTqen09MaR9P+YYDuprU2MLA2AMZZgXFUZWt1r7VusXp7uWzbVvfa4owwQGSRdX3+9v0f//jx1/dvz+sihOZWo93LtJOCAbqffXp6Op3WNc5t/bNDALzV7QQO2iwAmmucZFprrfte961uddu3bdvirM9atZllI0BmWZbT07d//78fP358f35eV2ZAN63R2y3l0QARmJf004aG+cP4QNytbpeVHXTPY53yhKok2bZte/Bl/tNqdjrxI8GyLKf/+/T8/dv3b9++ndbCEo2latXkVXMkROZlXU+pRtcij1SV/y5ABGvbeRX0thkAguXBrJfL+fx6Pp8vl0s/Crr181tz/2wUqoY6KUtZT8/fvn3//vy8LosQgVp0sHNAIkFHZFkS4el0Ktx3mP2J0Q2FAHjbL5eCXjcFxDjaaNsur+fXny/n1/P5ctn3uuehsW5mw6RklTEhATGxrOvz0/cfz0+n9VSEycFNtbaqBojUC1nT014eqkj+7SGIrvt2MbR9a0BoVmvdLufX15eXl5fz6+WybXWPgMfcoB+D4MP1hGyEQVTKy/r/n9enNRLWkUiyONIPCYlFlmVZ1nBDg0H/mJeWYYAQgNXLamjbpQGht7pvl9eXl5eXn68vl0sULsUhDWCj580oWfJe+gmEbbtIWaQUKdHmpvTCeATqp8pHvnApElsI/ggN/fBbhAi07puBXS4NGK3t+/n15eXnz58vr+fzvtfaei1r5uGvrnSkOgwUkJCRmbiILEuJfATGdiaCPBxmWU/LIo+uQf8qxg6QCb1tm4JeLg0Zre6X88vPnz9//nw5n4M7zQx6peedaPjIW+RmACJi4aUs67IsETkVYUAHYpHYqPTYzqPfQwjYAe5bc71cFAStDgY9ny97HlNlia0vbR1z8/EjXHNHAEKkyjW3fIiUspRlAWRz5FKWKLf+gyq01yEDgDCjt7qh62VTVNS6nV9fX19fz+dti0Pu+2fjP59gwe3vCd0dzLRGRFFKWdblpICsDhQ7Hv4g/XKaQUEXSYCm265ooPt2uZzP58tl22ptqmrH43hwGICDYfdGS1nqyRyRiwLGRpA/mWcCiDrBOH+QmVyrgFlrig4aNZ+hPLP7nB2bI+6cpnYXuQMCkGEcWh+eD3JRRxrbJH7fQuSGudvzaj1WCwABQYQQrFUwa2rgUPc9a3rimL/s2eI9L/MVUrqTAWTFrAPy0gx6Z7TfthCTxNyusYeLAYSDgoRmrRmg11F01g+dn3I2X3jmPYtsGIElAHLZq/rRbeO3wY0tj9dVBpF60AaCiNH0uJKbtd0AvW6Xw7p/Oad4ZyKWxhK5LFttWbj2OxfNK3sKe6zE5ss9366tVgAgEmby1tDV9qqAVrfL5XLZ9lCgcY7hlzTMPAkEAEc3VKoUvKEGnx5jcsUq9z7ZjbI5GABQnF4UC+axpKdadwg7yOhtN9e2b5oUPF+2vR8l+isU7IcKjSrXvOO+XfJQr/efmCNMb79Ns+eCXM+oeG7zRYqiWwRwA7VWtx2QwIUQdL+wadu2Buht27bE9+v8idcAEcBd4zih+tkRpQ4f4Y83LeHFuatT4TAgoJmrtu1yIWR0IbB6Jjatl60CdjNR43y4X2TNVLmZBh40bHXfLpdL1XEA5FuTE/6g+4hV7gOMXUOqrg6Ue/cpnCM3a207/7wQCLIA2P7qZLpfzg0ArbZwsHuH4K8tx8wjnDZEAkIECISX82WrakbeRWbCmcxnfWEBwdGjYVSH3/PqqqpVm6nDlIONImRt+3b++cq8soubbqLoup/PFRBdtR9B+ZCyO5BPy3pZf9c7zUZmgxDc6n657K2ZGdL
xnWHMumRlOiS5bjZ27n0zQz81GoiFi5kwxUHT1uq+nV9eyqkuIAatXhRB98ulAiIE5eFKm191nL4mUl9XGkWf1PsQE4ZD7WDggMhLYYLcbHdbn+MAALEgrMfpgH1/+URBd3dTyyRtU3Vgc89dAAkwMkquZizquKMSWq177bUEAa/vgfePajhGOWtUxhIzSW8LE4cX992DJMtaBDFKEmBe0urKPQTL8vlmvrGTcaKw9s0M2UQ1yt8UgFKd5UmvgChqgGhMbmoavVFS55phaqtjkeJYt5jUZPQ/OzryLlLKUkSIEd217nWvqgZUlkUI3FRbMZoaIXkqDtWmTaNCIyr8xtbQDjAytqotXFz1OHNUh2/kZq3F7ikARHGY34XU6VEJaQbvAcxCQDw6SwiXaNpXlrKULC1At7Zf9m3fmxpS4TjicitMOHW4cHewyMW2pqru0ZBZ2IjfBZgJvlxNMbOj++9YNAdxcDMjGN2ECPP0sdj2mL7owaNRHEeIyNlWIvhSuLCUDAClLMxEAK5t37bztu21ORJY286vPxnd9SicjxKbPCS5qSo4ARNJy77mB0CzONVZVXPvpTsi1d5rBwHMWt1ri4dA2TCHmEGK54awtAvWszDXrlOyZW+XEQIX7SWjn4+IkAgRIsRZ7pdyFiZWR9f98vpvQTDXXoKXK1eqrdaQHQMgJGYVZZI0NSmCuafPTJPJLPiZmSRK3N0sz8tWVRIBKcu6rKwi9egm3HMQE192M4CIGJ1cmaMfaC4iCRKntsH+mFx5VFaSObru23kVYXLLkwccXAf9amtm1qNOQgR1REQ7AAbGI8BxcwdiohbUBnM1BwyscqJyen56eiq+bZUOgIDg2VBlAOzmN/JKJH1Vhfoe+F4jamCUjr1ZlFaYo5qxMIG1Wmtr8Rj7VRHj+Lg8/Cium+cmpEXOgo6w/WCOUS4dimaoZAQCAFufaF0YUX7I6fs/f3z/fqJWjfqe4TCv3cT61U/v5iGzRo6xTmrQt1oPJRvypdoMxUndSNa1CGG26M2KS6CoQVS1NILDml6d5JjreubmvUogRNHTleBsrupNgPfybRGS7+X5X//7X3/99VzcPLZqDMPTzaAf/8sYpW/HNfOsobRm2lRbOMDpaY6OCQoklLZwET6qnSH9hGMzrAdvAiLmIU/DNI0PDUG0DhCiUycRUkg+SZXnwiRref7+j3/985/fFgQ4SvzvAYwb2JCCQNeatlaxxUbz2gu4urMQBHBAJncn6avysyOPmVY5nHTE3kwYJ9bpnswEkMgoNxTRBJAIoNAihMLCUbiy0kMAE5m65po9UQO3PKbbva9FBKNFf/K8FvooqIHrSDMZGlO5ECA6wuRlX3lS2ZbPKakK4PMpj3EnFmBCxP8CtX/E2qAwe7MAAAAASUVORK5CYII=\n", 425 | "text/plain": [ 426 | "" 427 | ] 428 | }, 429 | "metadata": {}, 430 | "output_type": "display_data" 431 | } 432 | ], 433 | "source": [ 434 | "raw_inp = raw_data[i][j]\n", 435 | "inp = display_transform(raw_inp)\n", 436 | "with torch.no_grad():\n", 437 | " res = g(inp.unsqueeze(0), z)[0]\n", 438 | "render_img(inp[0])\n", 439 | "render_img(res[0])" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 6, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "inp = np.array([raw_data[1200][0]], dtype=\"float64\")" 449 | ] 450 | }, 
451 | { 452 | "cell_type": "code", 453 | "execution_count": 7, 454 | "metadata": { 455 | "scrolled": true 456 | }, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "Generator(\n", 462 | " (encode0): Sequential(\n", 463 | " (pad): _SamePad()\n", 464 | " (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))\n", 465 | " (relu): LeakyReLU(negative_slope=0.2)\n", 466 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 467 | " )\n", 468 | " (encode1): _EncoderBlock(\n", 469 | " (pre_conv): Sequential(\n", 470 | " (pad): _SamePad()\n", 471 | " (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(2, 2))\n", 472 | " )\n", 473 | " (conv0): Sequential(\n", 474 | " (pad): _SamePad()\n", 475 | " (conv): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1))\n", 476 | " (relu): LeakyReLU(negative_slope=0.2)\n", 477 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 478 | " )\n", 479 | " (conv1): Sequential(\n", 480 | " (pad): _SamePad()\n", 481 | " (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))\n", 482 | " (relu): LeakyReLU(negative_slope=0.2)\n", 483 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 484 | " )\n", 485 | " (conv2): Sequential(\n", 486 | " (pad): _SamePad()\n", 487 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n", 488 | " (relu): LeakyReLU(negative_slope=0.2)\n", 489 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 490 | " )\n", 491 | " (conv3): Sequential(\n", 492 | " (pad): _SamePad()\n", 493 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(2, 2))\n", 494 | " (relu): LeakyReLU(negative_slope=0.2)\n", 495 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 496 | " )\n", 497 | " )\n", 498 | " (encode2): _EncoderBlock(\n", 499 | " (pre_conv): Sequential(\n", 500 
| " (pad): _SamePad()\n", 501 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n", 502 | " )\n", 503 | " (conv0): Sequential(\n", 504 | " (pad): _SamePad()\n", 505 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", 506 | " (relu): LeakyReLU(negative_slope=0.2)\n", 507 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 508 | " )\n", 509 | " (conv1): Sequential(\n", 510 | " (pad): _SamePad()\n", 511 | " (conv): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1))\n", 512 | " (relu): LeakyReLU(negative_slope=0.2)\n", 513 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 514 | " )\n", 515 | " (conv2): Sequential(\n", 516 | " (pad): _SamePad()\n", 517 | " (conv): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1))\n", 518 | " (relu): LeakyReLU(negative_slope=0.2)\n", 519 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 520 | " )\n", 521 | " (conv3): Sequential(\n", 522 | " (pad): _SamePad()\n", 523 | " (conv): Conv2d(448, 128, kernel_size=(3, 3), stride=(2, 2))\n", 524 | " (relu): LeakyReLU(negative_slope=0.2)\n", 525 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 526 | " )\n", 527 | " )\n", 528 | " (encode3): _EncoderBlock(\n", 529 | " (pre_conv): Sequential(\n", 530 | " (pad): _SamePad()\n", 531 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", 532 | " )\n", 533 | " (conv0): Sequential(\n", 534 | " (pad): _SamePad()\n", 535 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n", 536 | " (relu): LeakyReLU(negative_slope=0.2)\n", 537 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 538 | " )\n", 539 | " (conv1): Sequential(\n", 540 | " (pad): _SamePad()\n", 541 | " (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))\n", 542 | " (relu): 
LeakyReLU(negative_slope=0.2)\n", 543 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 544 | " )\n", 545 | " (conv2): Sequential(\n", 546 | " (pad): _SamePad()\n", 547 | " (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1))\n", 548 | " (relu): LeakyReLU(negative_slope=0.2)\n", 549 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 550 | " )\n", 551 | " (conv3): Sequential(\n", 552 | " (pad): _SamePad()\n", 553 | " (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(2, 2))\n", 554 | " (relu): LeakyReLU(negative_slope=0.2)\n", 555 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 556 | " )\n", 557 | " )\n", 558 | " (z_reshape0): Linear(in_features=100, out_features=32, bias=True)\n", 559 | " (z_reshape1): Linear(in_features=100, out_features=64, bias=True)\n", 560 | " (z_reshape2): Linear(in_features=100, out_features=98, bias=True)\n", 561 | " (decode0): _DecoderBlock(\n", 562 | " (conv0): Sequential(\n", 563 | " (pad): _SamePad()\n", 564 | " (conv): Conv2d(136, 128, kernel_size=(3, 3), stride=(1, 1))\n", 565 | " (relu): LeakyReLU(negative_slope=0.2)\n", 566 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 567 | " )\n", 568 | " (conv1): Sequential(\n", 569 | " (pad): _SamePad()\n", 570 | " (conv): Conv2d(264, 128, kernel_size=(3, 3), stride=(1, 1))\n", 571 | " (relu): LeakyReLU(negative_slope=0.2)\n", 572 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 573 | " )\n", 574 | " (conv2): Sequential(\n", 575 | " (pad): _SamePad()\n", 576 | " (conv): Conv2d(392, 128, kernel_size=(3, 3), stride=(1, 1))\n", 577 | " (relu): LeakyReLU(negative_slope=0.2)\n", 578 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 579 | " )\n", 580 | " (conv_t3): 
Sequential(\n", 581 | " (upsample): Upsample(size=4, mode=nearest)\n", 582 | " (conv): ConvTranspose2d(520, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 583 | " (relu): LeakyReLU(negative_slope=0.2)\n", 584 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 585 | " )\n", 586 | " )\n", 587 | " (decode1): _DecoderBlock(\n", 588 | " (pre_conv_t0): Sequential(\n", 589 | " (upsample): Upsample(size=4, mode=nearest)\n", 590 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 591 | " )\n", 592 | " (conv0): Sequential(\n", 593 | " (pad): _SamePad()\n", 594 | " (conv): Conv2d(388, 128, kernel_size=(3, 3), stride=(1, 1))\n", 595 | " (relu): LeakyReLU(negative_slope=0.2)\n", 596 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 597 | " )\n", 598 | " (pre_conv_t1): Sequential(\n", 599 | " (upsample): Upsample(size=4, mode=nearest)\n", 600 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 601 | " )\n", 602 | " (conv1): Sequential(\n", 603 | " (pad): _SamePad()\n", 604 | " (conv): Conv2d(516, 128, kernel_size=(3, 3), stride=(1, 1))\n", 605 | " (relu): LeakyReLU(negative_slope=0.2)\n", 606 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 607 | " )\n", 608 | " (pre_conv_t2): Sequential(\n", 609 | " (upsample): Upsample(size=4, mode=nearest)\n", 610 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 611 | " )\n", 612 | " (conv2): Sequential(\n", 613 | " (pad): _SamePad()\n", 614 | " (conv): Conv2d(644, 128, kernel_size=(3, 3), stride=(1, 1))\n", 615 | " (relu): LeakyReLU(negative_slope=0.2)\n", 616 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 617 | " )\n", 618 | " (conv_t3): Sequential(\n", 619 | " (upsample): Upsample(size=7, 
mode=nearest)\n", 620 | " (conv): ConvTranspose2d(644, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 621 | " (relu): LeakyReLU(negative_slope=0.2)\n", 622 | " (batchnorm): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 623 | " )\n", 624 | " )\n", 625 | " (decode2): _DecoderBlock(\n", 626 | " (pre_conv_t0): Sequential(\n", 627 | " (upsample): Upsample(size=7, mode=nearest)\n", 628 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 629 | " )\n", 630 | " (conv0): Sequential(\n", 631 | " (pad): _SamePad()\n", 632 | " (conv): Conv2d(322, 64, kernel_size=(3, 3), stride=(1, 1))\n", 633 | " (relu): LeakyReLU(negative_slope=0.2)\n", 634 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 635 | " )\n", 636 | " (pre_conv_t1): Sequential(\n", 637 | " (upsample): Upsample(size=7, mode=nearest)\n", 638 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 639 | " )\n", 640 | " (conv1): Sequential(\n", 641 | " (pad): _SamePad()\n", 642 | " (conv): Conv2d(386, 64, kernel_size=(3, 3), stride=(1, 1))\n", 643 | " (relu): LeakyReLU(negative_slope=0.2)\n", 644 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 645 | " )\n", 646 | " (pre_conv_t2): Sequential(\n", 647 | " (upsample): Upsample(size=7, mode=nearest)\n", 648 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 649 | " )\n", 650 | " (conv2): Sequential(\n", 651 | " (pad): _SamePad()\n", 652 | " (conv): Conv2d(450, 64, kernel_size=(3, 3), stride=(1, 1))\n", 653 | " (relu): LeakyReLU(negative_slope=0.2)\n", 654 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 655 | " )\n", 656 | " (conv_t3): Sequential(\n", 657 | " (upsample): Upsample(size=14, mode=nearest)\n", 658 | " (conv): ConvTranspose2d(386, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 659 | " (relu): LeakyReLU(negative_slope=0.2)\n", 660 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 661 | " )\n", 662 | " )\n", 663 | " (decode3): _DecoderBlock(\n", 664 | " (pre_conv_t0): Sequential(\n", 665 | " (upsample): Upsample(size=14, mode=nearest)\n", 666 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 667 | " )\n", 668 | " (conv0): Sequential(\n", 669 | " (pad): _SamePad()\n", 670 | " (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1))\n", 671 | " (relu): LeakyReLU(negative_slope=0.2)\n", 672 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 673 | " )\n", 674 | " (pre_conv_t1): Sequential(\n", 675 | " (upsample): Upsample(size=14, mode=nearest)\n", 676 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 677 | " )\n", 678 | " (conv1): Sequential(\n", 679 | " (pad): _SamePad()\n", 680 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))\n", 681 | " (relu): LeakyReLU(negative_slope=0.2)\n", 682 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 683 | " )\n", 684 | " (pre_conv_t2): Sequential(\n", 685 | " (upsample): Upsample(size=14, mode=nearest)\n", 686 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 687 | " )\n", 688 | " (conv2): Sequential(\n", 689 | " (pad): _SamePad()\n", 690 | " (conv): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1))\n", 691 | " (relu): LeakyReLU(negative_slope=0.2)\n", 692 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 693 | " )\n", 694 | " (conv_t3): Sequential(\n", 695 | " (upsample): Upsample(size=28, mode=nearest)\n", 696 | " (conv): ConvTranspose2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 697 | " 
(relu): LeakyReLU(negative_slope=0.2)\n", 698 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 699 | " )\n", 700 | " )\n", 701 | " (decode4): _DecoderBlock(\n", 702 | " (pre_conv_t0): Sequential(\n", 703 | " (upsample): Upsample(size=28, mode=nearest)\n", 704 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 705 | " )\n", 706 | " (conv0): Sequential(\n", 707 | " (pad): _SamePad()\n", 708 | " (conv): Conv2d(129, 64, kernel_size=(3, 3), stride=(1, 1))\n", 709 | " (relu): LeakyReLU(negative_slope=0.2)\n", 710 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 711 | " )\n", 712 | " (pre_conv_t1): Sequential(\n", 713 | " (upsample): Upsample(size=28, mode=nearest)\n", 714 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 715 | " )\n", 716 | " (conv1): Sequential(\n", 717 | " (pad): _SamePad()\n", 718 | " (conv): Conv2d(193, 64, kernel_size=(3, 3), stride=(1, 1))\n", 719 | " (relu): LeakyReLU(negative_slope=0.2)\n", 720 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 721 | " )\n", 722 | " (pre_conv_t2): Sequential(\n", 723 | " (upsample): Upsample(size=28, mode=nearest)\n", 724 | " (conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 725 | " )\n", 726 | " (conv2): Sequential(\n", 727 | " (pad): _SamePad()\n", 728 | " (conv): Conv2d(257, 64, kernel_size=(3, 3), stride=(1, 1))\n", 729 | " (relu): LeakyReLU(negative_slope=0.2)\n", 730 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 731 | " )\n", 732 | " )\n", 733 | " (final_conv0): Sequential(\n", 734 | " (pad): _SamePad()\n", 735 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 736 | " (relu): LeakyReLU(negative_slope=0.2)\n", 737 | " (batchnorm): BatchNorm2d(64, eps=0.001, 
momentum=0.01, affine=True, track_running_stats=True)\n", 738 | " )\n", 739 | " (final_conv1): Sequential(\n", 740 | " (pad): _SamePad()\n", 741 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", 742 | " (relu): LeakyReLU(negative_slope=0.2)\n", 743 | " (batchnorm): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)\n", 744 | " )\n", 745 | " (final_conv2): Sequential(\n", 746 | " (pad): _SamePad()\n", 747 | " (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1))\n", 748 | " )\n", 749 | " (tanh): Tanh()\n", 750 | ")" 751 | ] 752 | }, 753 | "execution_count": 7, 754 | "metadata": {}, 755 | "output_type": "execute_result" 756 | } 757 | ], 758 | "source": [ 759 | "torch_g = Generator(dim=28, channels=1)\n", 760 | "torch_g.eval()" 761 | ] 762 | }, 763 | { 764 | "cell_type": "code", 765 | "execution_count": 10, 766 | "metadata": {}, 767 | "outputs": [], 768 | "source": [ 769 | "input_images = torch.ones((32, 1, 28, 28))\n", 770 | "z = torch.randn((32, 100))" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": 70, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "from discriminator import create_d" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 72, 785 | "metadata": {}, 786 | "outputs": [ 787 | { 788 | "data": { 789 | "text/plain": [ 790 | "DenseNet(\n", 791 | " (features): Sequential(\n", 792 | " (conv0): Conv2d(56, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 793 | " (norm0): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 794 | " (relu0): LeakyReLU(negative_slope=0.2)\n", 795 | " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 796 | " (denseblock1): _DenseBlock(\n", 797 | " (denselayer1): _DenseLayer(\n", 798 | " (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 799 | " (relu1): 
LeakyReLU(negative_slope=0.2)\n", 800 | " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 801 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 802 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 803 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 804 | " )\n", 805 | " (denselayer2): _DenseLayer(\n", 806 | " (norm1): InstanceNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 807 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 808 | " (conv1): Conv2d(80, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 809 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 810 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 811 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 812 | " )\n", 813 | " (denselayer3): _DenseLayer(\n", 814 | " (norm1): InstanceNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 815 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 816 | " (conv1): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 817 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 818 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 819 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 820 | " )\n", 821 | " )\n", 822 | " (transition1): _Transition(\n", 823 | " (norm): InstanceNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 824 | " (relu): LeakyReLU(negative_slope=0.2)\n", 825 | " (conv): Conv2d(112, 56, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 826 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 827 | " )\n", 828 | " (denseblock2): _DenseBlock(\n", 829 | " (denselayer1): _DenseLayer(\n", 830 | " (norm1): InstanceNorm2d(56, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=False)\n", 831 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 832 | " (conv1): Conv2d(56, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 833 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 834 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 835 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 836 | " )\n", 837 | " (denselayer2): _DenseLayer(\n", 838 | " (norm1): InstanceNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 839 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 840 | " (conv1): Conv2d(72, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 841 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 842 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 843 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 844 | " )\n", 845 | " (denselayer3): _DenseLayer(\n", 846 | " (norm1): InstanceNorm2d(88, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 847 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 848 | " (conv1): Conv2d(88, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 849 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 850 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 851 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 852 | " )\n", 853 | " )\n", 854 | " (transition2): _Transition(\n", 855 | " (norm): InstanceNorm2d(104, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 856 | " (relu): LeakyReLU(negative_slope=0.2)\n", 857 | " (conv): Conv2d(104, 52, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 858 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 859 | " )\n", 860 | " (denseblock3): _DenseBlock(\n", 861 | " (denselayer1): _DenseLayer(\n", 862 
| " (norm1): InstanceNorm2d(52, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 863 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 864 | " (conv1): Conv2d(52, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 865 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 866 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 867 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 868 | " )\n", 869 | " (denselayer2): _DenseLayer(\n", 870 | " (norm1): InstanceNorm2d(68, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 871 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 872 | " (conv1): Conv2d(68, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 873 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 874 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 875 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 876 | " )\n", 877 | " (denselayer3): _DenseLayer(\n", 878 | " (norm1): InstanceNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 879 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 880 | " (conv1): Conv2d(84, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 881 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 882 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 883 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 884 | " )\n", 885 | " )\n", 886 | " (transition3): _Transition(\n", 887 | " (norm): InstanceNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 888 | " (relu): LeakyReLU(negative_slope=0.2)\n", 889 | " (conv): Conv2d(100, 50, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 890 | " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 891 | " )\n", 892 | " (denseblock4): 
_DenseBlock(\n", 893 | " (denselayer1): _DenseLayer(\n", 894 | " (norm1): InstanceNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 895 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 896 | " (conv1): Conv2d(50, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 897 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 898 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 899 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 900 | " )\n", 901 | " (denselayer2): _DenseLayer(\n", 902 | " (norm1): InstanceNorm2d(66, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 903 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 904 | " (conv1): Conv2d(66, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 905 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 906 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 907 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 908 | " )\n", 909 | " (denselayer3): _DenseLayer(\n", 910 | " (norm1): InstanceNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 911 | " (relu1): LeakyReLU(negative_slope=0.2)\n", 912 | " (conv1): Conv2d(82, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 913 | " (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 914 | " (relu2): LeakyReLU(negative_slope=0.2)\n", 915 | " (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 916 | " )\n", 917 | " )\n", 918 | " (norm5): InstanceNorm2d(98, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", 919 | " )\n", 920 | " (classifier): Linear(in_features=98, out_features=1, bias=True)\n", 921 | ")" 922 | ] 923 | }, 924 | "execution_count": 72, 925 | "metadata": {}, 926 | "output_type": "execute_result" 927 | } 928 | 
], 929 | "source": [ 930 | "create_d(28)" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": 31, 936 | "metadata": {}, 937 | "outputs": [ 938 | { 939 | "data": { 940 | "text/plain": [ 941 | "torch.Size([32, 1, 28, 28])" 942 | ] 943 | }, 944 | "execution_count": 31, 945 | "metadata": {}, 946 | "output_type": "execute_result" 947 | } 948 | ], 949 | "source": [ 950 | "torch_g.sample(input_images).shape" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "metadata": {}, 957 | "outputs": [], 958 | "source": [ 959 | "torch_inp = torch.tensor(inp).transpose(1, 3).transpose(2, 3).float()\n", 960 | "torch_out = torch_g(torch_inp, torch.tensor(z).float())\n", 961 | "torch_out[0][0][0]" 962 | ] 963 | }, 964 | { 965 | "cell_type": "code", 966 | "execution_count": null, 967 | "metadata": {}, 968 | "outputs": [], 969 | "source": [ 970 | "def render_image_arr(arr):\n", 971 | " arr = np.uint8(arr * 255)\n", 972 | " arr = arr.reshape(arr.shape[:-1])\n", 973 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))\n", 974 | " \n", 975 | "def render_torch(arr):\n", 976 | " arr = np.uint8(arr * 255)\n", 977 | " display(Image.fromarray(arr, mode='L').resize((224, 224)))" 978 | ] 979 | }, 980 | { 981 | "cell_type": "code", 982 | "execution_count": null, 983 | "metadata": {}, 984 | "outputs": [], 985 | "source": [ 986 | "# Render torch\n", 987 | "refined_torch_out = torch_out.detach().numpy() * 0.5 + 0.5\n", 988 | "\n", 989 | "render_torch(refined_torch_out[0][0])" 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "execution_count": null, 995 | "metadata": {}, 996 | "outputs": [], 997 | "source": [ 998 | "from torchvision.models.densenet import DenseNet\n", 999 | "net = DenseNet(growth_rate=16, block_config=(3,3,3,3), num_classes=10, drop_rate=0.0)\n", 1000 | "net" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": null, 1006 | "metadata": {}, 1007 | "outputs": [], 1008 | 
"source": [ 1009 | "def convert_bn(net):\n", 1010 | " for name in net._modules.keys():\n", 1011 | " if \"norm\" in name:\n", 1012 | " net._modules[name] = nn.Identity()\n", 1013 | " else:\n", 1014 | " convert_bn(net._modules[name])" 1015 | ] 1016 | }, 1017 | { 1018 | "cell_type": "code", 1019 | "execution_count": null, 1020 | "metadata": {}, 1021 | "outputs": [], 1022 | "source": [ 1023 | "convert_bn(net)\n", 1024 | "net" 1025 | ] 1026 | }, 1027 | { 1028 | "cell_type": "code", 1029 | "execution_count": 54, 1030 | "metadata": {}, 1031 | "outputs": [], 1032 | "source": [ 1033 | "mod = nn.BatchNorm2d(64, affine=True)" 1034 | ] 1035 | }, 1036 | { 1037 | "cell_type": "code", 1038 | "execution_count": 64, 1039 | "metadata": {}, 1040 | "outputs": [], 1041 | "source": [ 1042 | "import tensorflow as tf" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": null, 1048 | "metadata": {}, 1049 | "outputs": [], 1050 | "source": [ 1051 | "a = tf.constant(0.)\n", 1052 | "b = 2 * a\n", 1053 | "g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])\n" 1054 | ] 1055 | }, 1056 | { 1057 | "cell_type": "code", 1058 | "execution_count": 6, 1059 | "metadata": {}, 1060 | "outputs": [], 1061 | "source": [ 1062 | "norm = nn.InstanceNorm1d(3, affine=True, track_running_stats=False)" 1063 | ] 1064 | }, 1065 | { 1066 | "cell_type": "code", 1067 | "execution_count": 7, 1068 | "metadata": {}, 1069 | "outputs": [], 1070 | "source": [ 1071 | "tensor = torch.zeros((1, 3, 5, 5))\n", 1072 | "for i in range(3):\n", 1073 | " tensor[0][i] = torch.randn((5, 5)) + i" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "code", 1078 | "execution_count": 8, 1079 | "metadata": {}, 1080 | "outputs": [ 1081 | { 1082 | "ename": "ValueError", 1083 | "evalue": "expected 3D input (got 4D input)", 1084 | "output_type": "error", 1085 | "traceback": [ 1086 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 1087 | "\u001b[0;31mValueError\u001b[0m 
Traceback (most recent call last)", 1088 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 1089 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 723\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 724\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1090 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/instancenorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_input_dim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m return F.instance_norm(\n", 1091 | "\u001b[0;32m~/Workspace/pytorch/torch/nn/modules/instancenorm.py\u001b[0m in \u001b[0;36m_check_input_dim\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m raise ValueError('expected 3D input (got {}D input)'\n\u001b[0;32m--> 138\u001b[0;31m .format(input.dim()))\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 1092 | "\u001b[0;31mValueError\u001b[0m: expected 3D input (got 4D input)" 1093 | ] 1094 | } 1095 | ], 1096 | "source": [ 1097 | "norm(tensor)" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 17, 1103 | "metadata": {}, 1104 | "outputs": [ 1105 | { 1106 | "data": { 1107 | "text/plain": [ 1108 | "tensor([[[[-0.2375, 0.9305, 1.4677, 0.1904, 2.1872],\n", 1109 | " [ 1.1405, 0.1877, -0.2945, -1.8864, -0.8841],\n", 1110 | " [-1.9455, 1.5563, 1.7558, -0.5022, -0.2945],\n", 1111 | " [-0.0302, -0.1682, 1.5877, 1.6745, -0.8496],\n", 1112 | " [ 0.1775, 0.2997, 2.0681, 0.9333, -0.7345]],\n", 1113 | "\n", 1114 | " [[ 2.0193, 0.9958, 2.0291, 0.1691, 0.7073],\n", 1115 | " [ 1.1761, 1.0602, 0.8592, 0.7109, 0.4420],\n", 1116 | " [ 1.3823, 2.0038, 0.9473, 
0.3570, 1.9287],\n", 1117 | " [ 0.4505, 2.2994, 0.7116, 2.1269, -0.8457],\n", 1118 | " [-0.5620, 2.2178, 0.5716, 0.3277, 1.9907]],\n", 1119 | "\n", 1120 | " [[ 3.8420, 0.8646, 2.5898, 2.0150, 1.1855],\n", 1121 | " [ 1.9689, 1.1003, 1.8354, 1.3174, 1.2613],\n", 1122 | " [ 4.2256, 2.5718, 1.5704, 1.7245, 0.9842],\n", 1123 | " [ 2.0691, -0.0515, 1.2107, 3.6537, 1.2279],\n", 1124 | " [ 0.5973, 2.1711, 1.0104, 2.9567, 1.1152]]]])" 1125 | ] 1126 | }, 1127 | "execution_count": 17, 1128 | "metadata": {}, 1129 | "output_type": "execute_result" 1130 | } 1131 | ], 1132 | "source": [ 1133 | "tensor" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": 20, 1139 | "metadata": {}, 1140 | "outputs": [ 1141 | { 1142 | "data": { 1143 | "text/plain": [ 1144 | "[tensor(0.3332), tensor(1.0431), tensor(1.8007)]" 1145 | ] 1146 | }, 1147 | "execution_count": 20, 1148 | "metadata": {}, 1149 | "output_type": "execute_result" 1150 | } 1151 | ], 1152 | "source": [ 1153 | "[tensor[0][i].mean() for i in range(3)]" 1154 | ] 1155 | }, 1156 | { 1157 | "cell_type": "code", 1158 | "execution_count": 21, 1159 | "metadata": {}, 1160 | "outputs": [ 1161 | { 1162 | "data": { 1163 | "text/plain": [ 1164 | "[tensor(-2.8610e-08, grad_fn=),\n", 1165 | " tensor(-7.1526e-08, grad_fn=),\n", 1166 | " tensor(0., grad_fn=)]" 1167 | ] 1168 | }, 1169 | "execution_count": 21, 1170 | "metadata": {}, 1171 | "output_type": "execute_result" 1172 | } 1173 | ], 1174 | "source": [ 1175 | "[norm(tensor)[0][i].mean() for i in range(3)]" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "code", 1180 | "execution_count": 50, 1181 | "metadata": {}, 1182 | "outputs": [], 1183 | "source": [ 1184 | "class Layer(nn.Module):\n", 1185 | " def __init__(self):\n", 1186 | " super().__init__()\n", 1187 | " self.gamma = torch.nn.Parameter(torch.tensor(range(5)).float().unsqueeze(-1).unsqueeze(-1), requires_grad=True)\n", 1188 | " self.beta = torch.nn.Parameter(torch.tensor([0] * 
5).float().unsqueeze(-1).unsqueeze(-1), requires_grad=True)\n", 1189 | " \n", 1190 | " def forward(self, x):\n", 1191 | " mid = x * self.gamma + self.beta\n", 1192 | " print(mid)\n", 1193 | " return mid.mean()" 1194 | ] 1195 | }, 1196 | { 1197 | "cell_type": "code", 1198 | "execution_count": 51, 1199 | "metadata": {}, 1200 | "outputs": [], 1201 | "source": [ 1202 | "xx = torch.autograd.Variable(torch.ones((5,2,2)), requires_grad = True)\n", 1203 | "l = Layer()" 1204 | ] 1205 | }, 1206 | { 1207 | "cell_type": "code", 1208 | "execution_count": 53, 1209 | "metadata": {}, 1210 | "outputs": [ 1211 | { 1212 | "name": "stdout", 1213 | "output_type": "stream", 1214 | "text": [ 1215 | "tensor([[[0., 0.],\n", 1216 | " [0., 0.]],\n", 1217 | "\n", 1218 | " [[1., 1.],\n", 1219 | " [1., 1.]],\n", 1220 | "\n", 1221 | " [[2., 2.],\n", 1222 | " [2., 2.]],\n", 1223 | "\n", 1224 | " [[3., 3.],\n", 1225 | " [3., 3.]],\n", 1226 | "\n", 1227 | " [[4., 4.],\n", 1228 | " [4., 4.]]], grad_fn=)\n" 1229 | ] 1230 | } 1231 | ], 1232 | "source": [ 1233 | "result = l(xx)" 1234 | ] 1235 | }, 1236 | { 1237 | "cell_type": "code", 1238 | "execution_count": 54, 1239 | "metadata": {}, 1240 | "outputs": [], 1241 | "source": [ 1242 | "result.backward()" 1243 | ] 1244 | }, 1245 | { 1246 | "cell_type": "code", 1247 | "execution_count": 55, 1248 | "metadata": {}, 1249 | "outputs": [ 1250 | { 1251 | "data": { 1252 | "text/plain": [ 1253 | "tensor([[[0.2000]],\n", 1254 | "\n", 1255 | " [[0.2000]],\n", 1256 | "\n", 1257 | " [[0.2000]],\n", 1258 | "\n", 1259 | " [[0.2000]],\n", 1260 | "\n", 1261 | " [[0.2000]]])" 1262 | ] 1263 | }, 1264 | "execution_count": 55, 1265 | "metadata": {}, 1266 | "output_type": "execute_result" 1267 | } 1268 | ], 1269 | "source": [ 1270 | "l.beta.grad" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": null, 1276 | "metadata": {}, 1277 | "outputs": [], 1278 | "source": [] 1279 | } 1280 | ], 1281 | "metadata": { 1282 | "kernelspec": { 1283 | 
"display_name": "Python 3", 1284 | "language": "python", 1285 | "name": "python3" 1286 | }, 1287 | "language_info": { 1288 | "codemirror_mode": { 1289 | "name": "ipython", 1290 | "version": 3 1291 | }, 1292 | "file_extension": ".py", 1293 | "mimetype": "text/x-python", 1294 | "name": "python", 1295 | "nbconvert_exporter": "python", 1296 | "pygments_lexer": "ipython3", 1297 | "version": "3.7.6" 1298 | } 1299 | }, 1300 | "nbformat": 4, 1301 | "nbformat_minor": 4 1302 | } 1303 | -------------------------------------------------------------------------------- /resources/dagan_tracking_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amurthy1/dagan_torch/66d29d3ba655575398199d6e0e8170ac3be143f5/resources/dagan_tracking_images.png -------------------------------------------------------------------------------- /resources/dagan_training_progress.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amurthy1/dagan_torch/66d29d3ba655575398199d6e0e8170ac3be143f5/resources/dagan_training_progress.gif -------------------------------------------------------------------------------- /train_dagan.py: -------------------------------------------------------------------------------- 1 | from dagan_trainer import DaganTrainer 2 | from discriminator import Discriminator 3 | from generator import Generator 4 | from dataset import create_dagan_dataloader 5 | from utils.parser import get_dagan_args 6 | import torchvision.transforms as transforms 7 | import torch 8 | import os 9 | import torch.optim as optim 10 | import numpy as np 11 | 12 | 13 | # To maintain reproducibility 14 | torch.manual_seed(0) 15 | np.random.seed(0) 16 | 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | # Load input args 21 | args = get_dagan_args() 22 | 23 | dataset_path = args.dataset_path 24 | raw_data = 
np.load(dataset_path).copy() 25 | 26 | final_generator_path = args.final_model_path 27 | save_checkpoint_path = args.save_checkpoint_path 28 | load_checkpoint_path = args.load_checkpoint_path 29 | in_channels = raw_data.shape[-1] 30 | img_size = args.img_size or raw_data.shape[2] 31 | num_training_classes = args.num_training_classes 32 | num_val_classes = args.num_val_classes 33 | batch_size = args.batch_size 34 | epochs = args.epochs 35 | dropout_rate = args.dropout_rate 36 | max_pixel_value = args.max_pixel_value 37 | should_display_generations = not args.suppress_generations 38 | 39 | # Input sanity checks 40 | final_generator_dir = os.path.dirname(final_generator_path) or os.getcwd() 41 | if not os.access(final_generator_dir, os.W_OK): 42 | raise ValueError(final_generator_path + " is not a valid filepath.") 43 | 44 | if num_training_classes + num_val_classes > raw_data.shape[0]: 45 | raise ValueError( 46 | "Expected at least %d classes but only had %d." 47 | % (num_training_classes + num_val_classes, raw_data.shape[0]) 48 | ) 49 | 50 | 51 | g = Generator(dim=img_size, channels=in_channels, dropout_rate=dropout_rate) 52 | d = Discriminator(dim=img_size, channels=in_channels * 2, dropout_rate=dropout_rate) 53 | 54 | mid_pixel_value = max_pixel_value / 2 55 | train_transform = transforms.Compose( 56 | [ 57 | transforms.ToPILImage(), 58 | transforms.Resize(img_size), 59 | transforms.ToTensor(), 60 | transforms.Normalize( 61 | (mid_pixel_value,) * in_channels, (mid_pixel_value,) * in_channels 62 | ), 63 | ] 64 | ) 65 | 66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 67 | train_dataloader = create_dagan_dataloader( 68 | raw_data, num_training_classes, train_transform, batch_size 69 | ) 70 | 71 | g_opt = optim.Adam(g.parameters(), lr=0.0001, betas=(0.0, 0.9)) 72 | d_opt = optim.Adam(d.parameters(), lr=0.0001, betas=(0.0, 0.9)) 73 | 74 | val_data = raw_data[num_training_classes : num_training_classes + num_val_classes] 75 | flat_val_data = 
val_data.reshape( 76 | (val_data.shape[0] * val_data.shape[1], *val_data.shape[2:]) 77 | ) 78 | 79 | display_transform = train_transform 80 | 81 | trainer = DaganTrainer( 82 | generator=g, 83 | discriminator=d, 84 | gen_optimizer=g_opt, 85 | dis_optimizer=d_opt, 86 | batch_size=batch_size, 87 | device=device, 88 | critic_iterations=5, 89 | print_every=75, 90 | num_tracking_images=10, 91 | save_checkpoint_path=save_checkpoint_path, 92 | load_checkpoint_path=load_checkpoint_path, 93 | display_transform=display_transform, 94 | should_display_generations=should_display_generations, 95 | ) 96 | trainer.train(data_loader=train_dataloader, epochs=epochs, val_images=flat_val_data) 97 | 98 | # Save final generator model 99 | torch.save(trainer.g, final_generator_path) 100 | -------------------------------------------------------------------------------- /train_omniglot_classifier.py: -------------------------------------------------------------------------------- 1 | from utils.classifier_utils import ( 2 | create_classifier, 3 | create_classifier_dataloaders, 4 | compute_val_accuracy, 5 | perform_train_step, 6 | create_generated_batch, 7 | create_real_batch, 8 | ) 9 | from generator import Generator 10 | from utils.parser import get_omniglot_classifier_args 11 | import torch.nn as nn 12 | import torch 13 | import os 14 | import torch.optim as optim 15 | import numpy as np 16 | import scipy.stats 17 | import time 18 | 19 | 20 | # To maintain reproducibility 21 | torch.manual_seed(0) 22 | np.random.seed(0) 23 | 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | # Load input args 28 | args = get_omniglot_classifier_args() 29 | 30 | generator_path = args.generator_path 31 | dataset_path = args.dataset_path 32 | data_start_index = args.data_start_index 33 | num_classes = args.num_training_classes 34 | generated_batches_per_real = args.generated_batches_per_real 35 | num_epochs = args.epochs 36 | num_val = 
args.val_samples_per_class 37 | num_train = args.train_samples_per_class 38 | num_bootstrap_samples = args.num_bootstrap_samples 39 | progress_frequency = args.progress_frequency 40 | batch_size = args.batch_size 41 | 42 | 43 | raw_data = np.load(dataset_path) 44 | raw_data = raw_data[data_start_index:] 45 | 46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 47 | 48 | dagan_generator = torch.load(generator_path, map_location=device) 49 | dagan_generator.eval() 50 | 51 | loss_function = nn.CrossEntropyLoss() 52 | 53 | start_time = int(time.time()) 54 | last_progress_time = start_time 55 | 56 | 57 | val_accuracies_list = [] 58 | 59 | # Run routine both with and without augmentation 60 | for real_batch_rate in (1, generated_batches_per_real + 1): 61 | print( 62 | "Training %d classifiers with %d generated batches per real" 63 | % (num_bootstrap_samples, real_batch_rate - 1) 64 | ) 65 | val_accuracies = [] 66 | 67 | # Train {num_bootstrap_samples} classifiers each with different samples of data 68 | for bootstrap_sample in range(num_bootstrap_samples): 69 | print("\nBootstrap #%d" % (bootstrap_sample + 1)) 70 | classifier = create_classifier(num_classes).to(device) 71 | train_dataloader, val_dataloader = create_classifier_dataloaders( 72 | raw_data, num_classes, num_train, num_val, batch_size 73 | ) 74 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.99)) 75 | 76 | # Train on the dataset {num_epochs} times 77 | for epoch in range(num_epochs): 78 | training_loss = perform_train_step( 79 | classifier, 80 | train_dataloader, 81 | optimizer, 82 | loss_function, 83 | device, 84 | real_batch_rate, 85 | dagan_generator, 86 | ) 87 | 88 | if epoch % progress_frequency == 0: 89 | print("[%d] train loss: %.5f" % (epoch + 1, training_loss)) 90 | compute_val_accuracy(val_dataloader, classifier, device, loss_function) 91 | 92 | last_progress_time = int(time.time()) 93 | print( 94 | f"Elapsed time: {(last_progress_time - start_time) 
/ 60:.2f} minutes\n" 95 | ) 96 | 97 | print("[%d] train loss: %.5f" % (num_epochs, training_loss)) 98 | val_accuracies.append( 99 | compute_val_accuracy(val_dataloader, classifier, device, loss_function) 100 | ) 101 | 102 | # Remove current net from GPU memory 103 | del classifier 104 | val_accuracies_list.append(val_accuracies) 105 | 106 | # Summarize and print results 107 | pvalue = scipy.stats.ttest_ind( 108 | val_accuracies_list[0], val_accuracies_list[1], equal_var=False 109 | ).pvalue 110 | print( 111 | "Trained %d classifiers with and without augmentation using %d samples per class" 112 | % (num_bootstrap_samples, num_train) 113 | ) 114 | print("Average accuracy without augmentation: %.5f" % (np.mean(val_accuracies_list[0]))) 115 | print("Average accuracy with augmentation: %.5f" % (np.mean(val_accuracies_list[1]))) 116 | print( 117 | "Confidence level that augmentation has higher accuracy, using 2 sample t-test: %.3f%%" 118 | % (100 * (1 - pvalue)) 119 | ) 120 | -------------------------------------------------------------------------------- /utils/classifier_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.densenet import DenseNet 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import Dataset, DataLoader 8 | import warnings 9 | import torch 10 | 11 | 12 | class ModifiedDenseNet(DenseNet): 13 | """ 14 | DenseNet architecture modified to accept a flag for whether 15 | a given input image is a real or generated sample. 16 | 17 | Extra fully connected layers are also added at the end to allow the input flag 18 | to interact with the extracted image features before outputting a prediction.
19 | """ 20 | 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | self.classifier = nn.Sequential( 24 | OrderedDict( 25 | [ 26 | ("linear1", nn.Linear(self.classifier.in_features + 1, 20)), 27 | ("norm1", nn.BatchNorm1d(20)), 28 | ("linear2", nn.Linear(20, self.classifier.out_features)), 29 | ] 30 | ) 31 | ) 32 | 33 | def forward(self, x, flags): 34 | features = self.features(x) 35 | out = F.relu(features, inplace=True) 36 | out = F.adaptive_avg_pool2d(out, (1, 1)) 37 | out = torch.flatten(out, 1) 38 | out = torch.cat((out, flags.float()), dim=1) 39 | out = self.classifier(out) 40 | return out 41 | 42 | 43 | class AddGaussianNoiseTransform: 44 | """ 45 | Adds random gaussian noise to an image. 46 | 47 | Useful for data augmentation during the training phase. 48 | """ 49 | 50 | def __init__(self, mean=0.0, std=1.0): 51 | self.std = std 52 | self.mean = mean 53 | 54 | def __call__(self, arr): 55 | return ( 56 | np.array(arr) 57 | + np.random.randn(*arr.shape).astype("float32") * self.std 58 | + self.mean 59 | ) 60 | 61 | def __repr__(self): 62 | return self.__class__.__name__ + "(mean={0}, std={1})".format( 63 | self.mean, self.std 64 | ) 65 | 66 | 67 | class ClassifierDataset(Dataset): 68 | def __init__(self, examples, labels, transform=None): 69 | self.examples = examples 70 | self.labels = labels 71 | self.transform = transform 72 | 73 | def __len__(self): 74 | return len(self.examples) 75 | 76 | def __getitem__(self, idx): 77 | sample = self.examples[idx] 78 | with warnings.catch_warnings(): 79 | warnings.simplefilter("ignore") 80 | if self.transform: 81 | sample = self.transform(sample) 82 | 83 | return sample, self.labels[idx] 84 | 85 | 86 | pre_train_transform = transforms.Compose( 87 | [ 88 | transforms.ToTensor(), 89 | transforms.Normalize((0.5), (0.5)), 90 | ] 91 | ) 92 | 93 | train_transform = transforms.Compose( 94 | [ 95 | AddGaussianNoiseTransform(0, 0.1), 96 | transforms.ToPILImage(), 97 | transforms.RandomHorizontalFlip(), 98 | 
transforms.Resize(224), 99 | transforms.ToTensor(), 100 | transforms.Normalize((0.5), (0.5)), 101 | ] 102 | ) 103 | 104 | val_transform = transforms.Compose( 105 | [ 106 | transforms.ToPILImage(), 107 | transforms.Resize(224), 108 | transforms.ToTensor(), 109 | transforms.Normalize((0.5), (0.5)), 110 | ] 111 | ) 112 | 113 | 114 | def gen_out_to_numpy(tensor): 115 | """ 116 | Convert a batch of (N, 1, H, W) image tensors normalized to [-1, 1] into an (N, H, W, 1) numpy array in [0, 1]. 117 | """ 118 | return ((tensor * 0.5) + 0.5).squeeze(1).unsqueeze(-1).numpy() 119 | 120 | 121 | def create_real_batch(samples): 122 | """ 123 | Given a batch of images, apply the train transform. 124 | """ 125 | np_samples = gen_out_to_numpy(samples) 126 | new_samples = [] 127 | with warnings.catch_warnings(): 128 | warnings.simplefilter("ignore") 129 | for i in range(np_samples.shape[0]): 130 | new_samples.append(train_transform(np_samples[i])) 131 | return torch.stack(new_samples) 132 | 133 | 134 | def create_generated_batch(samples, generator, device="cpu"): 135 | """ 136 | Given a batch of images, run them through the DAGAN generator 137 | and apply the train transform. 138 | """ 139 | z = torch.randn((samples.shape[0], generator.z_dim)) 140 | with torch.no_grad(): 141 | g_out = generator(samples.to(device), z.to(device)).cpu() 142 | np_out = gen_out_to_numpy(g_out) 143 | 144 | new_samples = [] 145 | with warnings.catch_warnings(): 146 | warnings.simplefilter("ignore") 147 | for i in range(np_out.shape[0]): 148 | new_samples.append(train_transform(np_out[i])) 149 | return torch.stack(new_samples) 150 | 151 | 152 | def perform_train_step( 153 | net, train_dataloader, optimizer, loss_function, device, real_batch_rate, g 154 | ): 155 | """ 156 | Perform one epoch of training using the full dataset.
157 | """ 158 | running_loss = 0.0 159 | net.train() 160 | for i, data in enumerate(train_dataloader): 161 | # get the inputs; data is a list of [inputs, labels] 162 | inputs, labels = data[0], data[1].to(device) 163 | if i % real_batch_rate == 0: 164 | inputs = create_real_batch(inputs).to(device) 165 | flags = torch.ones((inputs.shape[0], 1)).to(device) 166 | else: 167 | inputs = create_generated_batch(inputs, g, device).to(device) 168 | flags = torch.zeros((inputs.shape[0], 1)).to(device) 169 | 170 | # zero the parameter gradients 171 | optimizer.zero_grad() 172 | # forward + backward + optimize 173 | outputs = net(inputs, flags) 174 | loss = loss_function(outputs, labels) 175 | loss.backward() 176 | optimizer.step() 177 | 178 | # accumulate the loss for reporting 179 | running_loss += loss.item() 180 | return running_loss / len(train_dataloader.dataset) 181 | 182 | 183 | def compute_val_accuracy(val_dataloader, net, device, loss_function): 184 | """ 185 | Print loss, top-1 and top-5 accuracy on the given validation set, and return the top-1 accuracy.
186 | """ 187 | val_loss = 0.0 188 | success = 0 189 | success_topk = 0 190 | k = 5 191 | net.eval() 192 | val_dataset = val_dataloader.dataset 193 | for i, data in enumerate(val_dataloader): 194 | with torch.no_grad(): 195 | inputs, labels = data[0].to(device), data[1].to(device) 196 | flags = torch.ones((inputs.shape[0], 1)).to(device) 197 | outputs = net(inputs, flags) 198 | _, predicted = torch.max(outputs, 1) 199 | _, predicted_topk = torch.topk(outputs, axis=1, k=k) 200 | success += sum( 201 | [int(predicted[i] == labels[i]) for i in range(len(predicted))] 202 | ) 203 | success_topk += sum( 204 | [ 205 | int(labels[i] in predicted_topk[i]) 206 | for i in range(len(predicted_topk)) 207 | ] 208 | ) 209 | loss = loss_function(outputs, labels) 210 | val_loss += loss.item() 211 | print("val loss: %.5f" % (val_loss / len(val_dataset))) 212 | print("val acc: %.5f" % (success / len(val_dataset))) 213 | print("val acc top%d: %.5f" % (k, success_topk / len(val_dataset))) 214 | return success / len(val_dataset) 215 | 216 | 217 | def create_classifier(num_classes): 218 | """ 219 | Create a modified densenet with the given number of classes. 220 | """ 221 | classifier = ModifiedDenseNet( 222 | growth_rate=16, 223 | block_config=(3, 3, 3, 3), 224 | num_classes=num_classes, 225 | drop_rate=0.3, 226 | ) 227 | classifier.features[0] = nn.Conv2d( 228 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 229 | ) 230 | return classifier 231 | 232 | 233 | def create_classifier_dataloaders( 234 | raw_data, num_classes, num_train, num_val, batch_size 235 | ): 236 | """ 237 | Create train and validation dataloaders with the given raw data. 
238 | """ 239 | train_X = [] 240 | train_y = [] 241 | val_X = [] 242 | val_y = [] 243 | 244 | for i in range(num_classes): 245 | # Shuffle data so different examples are chosen each time 246 | class_data = list(raw_data[i]) 247 | np.random.shuffle(class_data) 248 | 249 | train_X.extend(class_data[:num_train]) 250 | train_y.extend([i] * num_train) 251 | val_X.extend(class_data[-num_val:]) 252 | val_y.extend([i] * num_val) 253 | 254 | train_dataloader = DataLoader( 255 | ClassifierDataset(train_X, train_y, pre_train_transform), 256 | batch_size=batch_size, 257 | shuffle=True, 258 | num_workers=1, 259 | ) 260 | val_dataloader = DataLoader( 261 | ClassifierDataset(val_X, val_y, val_transform), 262 | batch_size=batch_size, 263 | shuffle=False, 264 | num_workers=1, 265 | ) 266 | 267 | return train_dataloader, val_dataloader 268 | -------------------------------------------------------------------------------- /utils/gif_maker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | from PIL import Image 6 | import PIL 7 | 8 | 9 | def convert_to_pil(arr, scale=1.0): 10 | arr = arr.reshape(-1, arr.shape[-1]) 11 | arr = (arr * 0.5) + 0.5 12 | arr = np.uint8(arr * 255) 13 | h, w = arr.shape 14 | return ( 15 | Image.fromarray(arr, mode="L") 16 | .resize((int(w * scale), int(h * scale))) 17 | .transpose(PIL.Image.TRANSPOSE) 18 | ) 19 | 20 | 21 | def create_gif( 22 | checkpoint_path, 23 | gif_path, 24 | fps=15, 25 | scale=1.0, 26 | sampling_rate=1, 27 | tracking_images_path=None, 28 | ): 29 | checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) 30 | 31 | # Sample generations and append tracking images on top 32 | gens = checkpoint["tracking_images_gens"][0::sampling_rate] 33 | pil_images = [convert_to_pil(tensor, scale=scale) for tensor in gens] 34 | 35 | # Arbitrarily freeze last frame for half the length of existing gif 36 | pil_images += 
[pil_images[-1]] * (len(pil_images) // 2) 37 | 38 | imageio.mimsave(gif_path, [np.array(img) for img in pil_images], fps=fps) 39 | 40 | if tracking_images_path is not None: 41 | tracking_images = checkpoint["tracking_images"] 42 | imageio.imsave(tracking_images_path, np.array(convert_to_pil(tracking_images))) -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_dagan_args(): 5 | parser = argparse.ArgumentParser( 6 | description="Use this script to train a dagan.", 7 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 8 | ) 9 | parser.add_argument( 10 | "dataset_path", 11 | type=str, 12 | help="Filepath for dataset on which to train dagan. File should be .npy format with shape " 13 | "(num_classes, samples_per_class, height, width, channels).", 14 | ) 15 | parser.add_argument( 16 | "final_model_path", type=str, help="Filepath to save final dagan model." 17 | ) 18 | parser.add_argument( 19 | "--batch_size", 20 | nargs="?", 21 | type=int, 22 | default=32, 23 | help="batch_size for experiment", 24 | ) 25 | parser.add_argument( 26 | "--img_size", 27 | nargs="?", 28 | type=int, 29 | help="Dimension to scale images when training. " 30 | "Useful when model architecture expects specific input size. 
" 31 | "If not specified, uses img_size of data as passed.", 32 | ) 33 | parser.add_argument( 34 | "--num_training_classes", 35 | nargs="?", 36 | type=int, 37 | default=1200, 38 | help="Number of classes to use for training.", 39 | ) 40 | parser.add_argument( 41 | "--num_val_classes", 42 | nargs="?", 43 | type=int, 44 | default=200, 45 | help="Number of classes to use for validation.", 46 | ) 47 | parser.add_argument( 48 | "--epochs", 49 | nargs="?", 50 | type=int, 51 | default=50, 52 | help="Number of epochs to run training.", 53 | ) 54 | parser.add_argument( 55 | "--max_pixel_value", 56 | nargs="?", 57 | type=float, 58 | default=1.0, 59 | help="Range of values used to represent pixel values in input data. " 60 | "Assumes lower bound is 0 (i.e. range of values is [0, max_pixel_value]).", 61 | ) 62 | parser.add_argument( 63 | "--save_checkpoint_path", 64 | nargs="?", 65 | type=str, 66 | help="Filepath to save intermediate training checkpoints.", 67 | ) 68 | parser.add_argument( 69 | "--load_checkpoint_path", 70 | nargs="?", 71 | type=str, 72 | help="Filepath of intermediate checkpoint from which to resume training.", 73 | ) 74 | parser.add_argument( 75 | "--dropout_rate", 76 | type=float, 77 | nargs="?", 78 | default=0.5, 79 | help="Dropout rate to use within network architecture.", 80 | ) 81 | parser.add_argument( 82 | "--suppress_generations", 83 | action="store_true", 84 | help="If specified, does not show intermediate progress images.", 85 | ) 86 | return parser.parse_args() 87 | 88 | 89 | def get_omniglot_classifier_args(): 90 | parser = argparse.ArgumentParser( 91 | description="Use this script to train an omniglot classifier " 92 | "with and without augmentations to compare the results.", 93 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 94 | ) 95 | parser.add_argument( 96 | "generator_path", 97 | type=str, 98 | help="Filepath for dagan generator to use for augmentations.", 99 | ) 100 | parser.add_argument( 101 | "--dataset_path", 102 | type=str, 
103 | default="datasets/omniglot_data.npy", 104 | help="Filepath for omniglot data.", 105 | ) 106 | parser.add_argument( 107 | "--batch_size", 108 | nargs="?", 109 | type=int, 110 | default=32, 111 | help="batch_size for experiment", 112 | ) 113 | parser.add_argument( 114 | "--data_start_index", 115 | nargs="?", 116 | type=int, 117 | default=1420, 118 | help="Only uses classes after the given index for training. " 119 | "Useful to isolate data that wasn't used during dagan training.", 120 | ) 121 | parser.add_argument( 122 | "--num_training_classes", 123 | nargs="?", 124 | type=int, 125 | default=100, 126 | help="Number of classes to use for training.", 127 | ) 128 | parser.add_argument( 129 | "--train_samples_per_class", 130 | nargs="?", 131 | type=int, 132 | default=5, 133 | help="Number of samples to use per class during training.", 134 | ) 135 | parser.add_argument( 136 | "--val_samples_per_class", 137 | nargs="?", 138 | type=int, 139 | default=5, 140 | help="Number of samples to use per class during validation.", 141 | ) 142 | parser.add_argument( 143 | "--epochs", 144 | nargs="?", 145 | type=int, 146 | default=200, 147 | help="Number of epochs to run training.", 148 | ) 149 | parser.add_argument( 150 | "--progress_frequency", 151 | nargs="?", 152 | type=int, 153 | default=50, 154 | help="Number of epochs between printing intermediate train/val loss.", 155 | ) 156 | parser.add_argument( 157 | "--generated_batches_per_real", 158 | nargs="?", 159 | type=int, 160 | default=1, 161 | help="Number of augmented batches per real batch during " 162 | "augmented training phase.", 163 | ) 164 | parser.add_argument( 165 | "--num_bootstrap_samples", 166 | nargs="?", 167 | type=int, 168 | default=10, 169 | help="Number of classifiers to train with slightly different data " 170 | "in order to get a more accurate measurement.", 171 | ) 172 | return parser.parse_args() 173 | --------------------------------------------------------------------------------
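The per-class split performed by `create_classifier_dataloaders` (utils/classifier_utils.py) can be restated in plain Python as a minimal sketch. The helper name `split_per_class`, its `seed` parameter, and the use of `random.Random` in place of `np.random.shuffle` are assumptions made for illustration and are not part of this repo. The sketch also guards against two edge cases of naive per-class slicing: a `[-num_val:]` slice returns the whole list when `num_val` is 0, and train and validation samples overlap whenever `num_train + num_val` exceeds the samples available for a class.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_per_class(
    raw_data: Sequence[Sequence[T]],
    num_classes: int,
    num_train: int,
    num_val: int,
    seed: int = 0,
) -> Tuple[List[T], List[int], List[T], List[int]]:
    """Hypothetical helper: shuffle each class, then take the first
    num_train samples for training and the last num_val for validation."""
    rng = random.Random(seed)  # local RNG so the split is reproducible
    train_X: List[T] = []
    train_y: List[int] = []
    val_X: List[T] = []
    val_y: List[int] = []
    for label in range(num_classes):
        class_data = list(raw_data[label])
        # Refuse requests that would make train and validation overlap.
        if num_train + num_val > len(class_data):
            raise ValueError(
                f"class {label}: needs {num_train + num_val} samples, "
                f"has {len(class_data)}"
            )
        rng.shuffle(class_data)
        train_X.extend(class_data[:num_train])
        train_y.extend([label] * num_train)
        # Explicit start index, so num_val == 0 yields an empty slice.
        val_X.extend(class_data[len(class_data) - num_val:])
        val_y.extend([label] * num_val)
    return train_X, train_y, val_X, val_y


# Toy stand-in for Omniglot: 3 classes x 20 samples each.
toy = [[f"c{c}_s{s}" for s in range(20)] for c in range(3)]
train_X, train_y, val_X, val_y = split_per_class(toy, 3, num_train=5, num_val=5)
print(len(train_X), len(val_X))  # 15 15
```

With Omniglot's 20 drawings per character and the default of 5 training and 5 validation samples per class, the two splits never overlap; the check only matters for larger requests.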