├── models
│   ├── content_predictor.py
│   ├── spatial_prediction.py
│   ├── discriminator.py
│   └── generator.py
├── configs
│   ├── CelebA_64x64_N2M2S32.yaml
│   └── LSUN_64x64_N2M2S32.yaml
├── LICENSE
├── utils.py
├── README.md
├── img_utils.py
├── ops.py
├── patch_handler.py
├── logger.py
├── main.py
├── coord_handler.py
└── trainer.py

/models/content_predictor.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from ops import lrelu


class ContentPredictorBuilder(nn.Module):
    """Predicts the latent content vector z back from the discriminator's features."""

    def __init__(self, config):
        super(ContentPredictorBuilder, self).__init__()

        self.config = config
        self.z_dim = config["model_params"]["z_dim"]
        self.aux_dim = config["model_params"]["aux_dim"]

        # Input arrives as (N, 1, 512): the discriminator's pooled features
        # with an extra channel dimension.
        self.fc1 = nn.Linear(512, self.aux_dim)
        self.bn1 = nn.BatchNorm1d(1, eps=1e-05, momentum=0.9)
        self.fc2 = nn.Linear(self.aux_dim, self.z_dim)

    def forward(self, h, is_training):
        h = self.fc1(h)
        h = self.bn1(h)
        h = lrelu(h)
        h = self.fc2(h)
        return torch.tanh(h)
--------------------------------------------------------------------------------
/models/spatial_prediction.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from ops import lrelu


class SpatialPredictorBuilder(nn.Module):
    """Predicts the spatial coordinate back from the discriminator's features."""

    def __init__(self, config):
        super(SpatialPredictorBuilder, self).__init__()

        self.config = config
        self.aux_dim = config["model_params"]["aux_dim"]
        self.spatial_dim = config["model_params"]["spatial_dim"]

        self.fc1 = nn.Linear(512, self.aux_dim)
        self.bn1 = nn.BatchNorm1d(1, eps=1e-05, momentum=0.9)
        self.fc2 = nn.Linear(self.aux_dim, self.spatial_dim)

    def forward(self, h, is_training):
        h = self.fc1(h)
        h = self.bn1(h)
        h = lrelu(h)
        h = self.fc2(h)
        return torch.tanh(h)
--------------------------------------------------------------------------------
/configs/CelebA_64x64_N2M2S32.yaml:
--------------------------------------------------------------------------------
data_params:
  # dataset options: "celeba", "lsun", "mnist"
  dataset: "celeba"
  c_dim: 3
  full_image_size: [64, 64]
  macro_patch_size: [32, 32]
  micro_patch_size: [16, 16]
  num_train_samples: 1000
  num_test_samples: 0
  coordinate_system: "euclidean"

train_params:
  epochs: inf # No need to specify; longer is usually better and eventually saturates
  batch_size: 100
  G_update_period: 1
  D_update_period: 1
  Q_update_period: 0
  beta1: 0.0
  beta2: 0.999
  glr: 0.0001
  dlr: 0.0004
  qlr: 0.0001

loss_params:
  gp_lambda: 10
  coord_loss_w: 100
  code_loss_w: 0

model_params:
  z_dim: 128
  spatial_dim: 2
  g_extra_layers: 0
  d_extra_layers: 0
  ngf_base: 64
  ndf_base: 64
  aux_dim: 128

log_params:
  exp_name: "CelebA_64x64_N2M2S32"
  log_dir: "./logs/"

  # Use inf to disable
  img_step: 5 # Consumes quite a lot of disk space
  dump_img_step: 5 # Consumes LOTS of disk space
--------------------------------------------------------------------------------
/configs/LSUN_64x64_N2M2S32.yaml:
--------------------------------------------------------------------------------
data_params:
  # dataset options: "celeba", "lsun", "mnist"
  dataset: "lsun"
  c_dim: 3
  full_image_size: [64, 64]
  macro_patch_size: [32, 32]
  micro_patch_size: [16, 16]
  num_train_samples: 2983042
  num_test_samples: 50000
  coordinate_system: "euclidean"

train_params:
  epochs: inf # No need to specify; longer is usually better and eventually saturates
  batch_size: 128
  G_update_period: 1
  D_update_period: 1
  Q_update_period: 0
  beta1: 0.0
  beta2: 0.999
  glr: 0.0001
  dlr: 0.0004
  qlr: 0.0001

loss_params:
  gp_lambda: 10
  coord_loss_w: 100
  code_loss_w: 0

model_params:
  z_dim: 128
  spatial_dim: 2
  g_extra_layers: 0
  d_extra_layers: 0
  ngf_base: 64
  ndf_base: 64
  aux_dim: 128

log_params:
  exp_name: "LSUN_64x64_N2M2S32"
  log_dir: "./logs/"

  # Use inf to disable
  img_step: 1000 # Consumes quite a lot of disk space
  dump_img_step: 2000 # Consumes LOTS of disk space
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Mohsen Fayyaz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torchvision.utils import save_image


def save_manifold_images(images, size, image_path):
    # Map images from [-1, 1] back to [0, 1] before saving
    images = (images + 1) / 2
    manifold_image = np.squeeze(compose_manifold_images(images, size))
    return save_image(torch.tensor(manifold_image, dtype=torch.float64), image_path)


def compose_manifold_images(images, size):
    # `images` is (N, C, H, W); tile the first size[0]*size[1] images into a grid
    h, w = images.shape[2], images.shape[3]
    if images.shape[1] in (3, 4):
        c = images.shape[1]
        img = np.zeros((c, h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[:, j * h:j * h + h, i * w:i * w + w] = image
        return img
    elif images.shape[1] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[0, :, :]
        return img
    else:
        raise ValueError('in compose_manifold_images(images, size), images must have '
                         'shape (N, C, H, W) with C in (1, 3, 4), got {}'.format(images.shape))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# COCO_GAN

Implementation of the paper "COCO-GAN: Generation by Parts via Conditional Coordinating".
This code is a PyTorch port of the authors' TensorFlow implementation: https://github.com/hubert0527/COCO-GAN

(The structure of the code is kept very similar to the authors' repository; refer to https://hubert0527.github.io/COCO-GAN/ for detailed descriptions.)

# Prerequisites

```
torch
torchvision
numpy
matplotlib
```

# Train the Network

To train the network, e.g. with COCO-GAN on CelebA, execute the following command:

```python
python main.py
```

The config file can be edited as needed.
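
For example, to train on LSUN instead, point `main.py` at the LSUN config. A minimal sketch of how that could look (the current `main.py` hard-codes the CelebA config path, so this `--config` flag is an assumed addition, not existing behavior):

```python
# Hypothetical CLI wrapper: select the config file instead of hard-coding it.
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--config", default="./configs/CelebA_64x64_N2M2S32.yaml")
args = parser.parse_args()

with open(args.config) as f:  # e.g. --config ./configs/LSUN_64x64_N2M2S32.yaml
    config = yaml.safe_load(f)
```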

# Dataset

"celeb_data" can be downloaded from this link: https://drive.google.com/open?id=1PceubRgNbDhTExEFSfHuB4fKsOibDKQy
It contains only 1000 randomly sampled images.

For MNIST, loading code is already present in `main.py` under `load_dataset` (commented out).

# Reference

```
@inproceedings{lin2019cocogan,
  author    = {Chieh Hubert Lin and
               Chia{-}Che Chang and
               Yu{-}Sheng Chen and
               Da{-}Cheng Juan and
               Wei Wei and
               Hwann{-}Tzong Chen},
  title     = {{COCO-GAN:} Generation by Parts via Conditional Coordinating},
  booktitle = {IEEE International Conference on Computer Vision (ICCV)},
  year      = {2019},
}
```
--------------------------------------------------------------------------------
/img_utils.py:
--------------------------------------------------------------------------------
# Note: scipy.misc.imread/imresize/imsave are deprecated and were removed from
# recent SciPy releases; this module needs an old scipy (<1.2) together with
# Pillow, or the calls can be swapped for imageio/PIL equivalents.
import scipy.misc
import numpy as np

def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, byte_image=True):
    image = imread(image_path)
    return transform(image, input_height, input_width, resize_height, resize_width, byte_image=byte_image)

def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)

def imread(path, grayscale=False):
    if grayscale:
        return scipy.misc.imread(path, flatten=True).astype(float)
    else:
        return scipy.misc.imread(path).astype(float)

def merge_images(images, size):
    return inverse_transform(images)

def merge(images, size):
    # `images` is (N, H, W, C); tile the first size[0]*size[1] images into a grid
    h, w = images.shape[1], images.shape[2]
    if images.shape[3] in (3, 4):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
        return img
    else:
        raise ValueError('in merge(images, size) the images parameter '
                         'must have dimensions: HxW, HxWx3 or HxWx4, got {}'.format(images.shape))

def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)

def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w], [resize_h, resize_w])

def transform(image, input_height, input_width, resize_height=64, resize_width=64, byte_image=True):
    # Note: despite the name, this only resizes; input_height/input_width are unused.
    cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
    if byte_image:
        return np.array(cropped_image).astype(np.uint8)
    else:
        return np.array(cropped_image) / 127.5 - 1.

def inverse_transform(images):
    return (images + 1.) / 2.
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
"""
Most of the codes are from:
1. https://github.com/carpedm20/DCGAN-tensorflow
2. https://github.com/minhnhat93/tf-SNDCGAN
"""
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

def _l2normalize(v, eps=1e-12):
    return v / (torch.sum(v ** 2) ** 0.5 + eps)


def lrelu(x, leak=0.2, name="lrelu"):
    return torch.max(x, leak * x)

def upscale(x, scale):
    return nn.Upsample(scale_factor=scale, mode='nearest')(x)


def pad(x, p):
    # Reflection-pad an NCHW tensor by p pixels on each spatial side
    return F.pad(x, (p, p, p, p), mode='reflect')

def tile(a, dim, n_tile):
    # Nearest-neighbour duplication along `dim` (each element repeated n_tile times)
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index)


def add_coords(input_tensor, x_dim=64, y_dim=64, with_r=False):
    """
    For CoordConv.

    Add coordinate channels to a channels-last tensor.
    input_tensor: (batch, x_dim, y_dim, c)
    """
    batch_size_tensor = input_tensor.size(0)

    xx_ones = torch.ones((batch_size_tensor, x_dim), dtype=torch.float32)
    xx_ones = xx_ones.unsqueeze(-1)                                              # (B, x, 1)
    xx_range = tile(torch.arange(x_dim, dtype=torch.float32).unsqueeze(0), 0, batch_size_tensor)
    xx_range = xx_range.unsqueeze(1)                                             # (B, 1, x)
    xx_channel = torch.matmul(xx_ones, xx_range)                                 # (B, x, x)
    xx_channel = xx_channel.unsqueeze(-1)

    yy_ones = torch.ones((batch_size_tensor, y_dim), dtype=torch.float32)
    yy_ones = yy_ones.unsqueeze(1)                                               # (B, 1, y)
    yy_range = tile(torch.arange(y_dim, dtype=torch.float32).unsqueeze(0), 0, batch_size_tensor)
    yy_range = yy_range.unsqueeze(-1)                                            # (B, y, 1)
    yy_channel = torch.matmul(yy_range, yy_ones)                                 # (B, y, y)
    yy_channel = yy_channel.unsqueeze(-1)

    # Normalize coordinates to [-1, 1]
    xx_channel = xx_channel / (x_dim - 1)
    yy_channel = yy_channel / (y_dim - 1)
    xx_channel = xx_channel * 2 - 1
    yy_channel = yy_channel * 2 - 1

    ret = torch.cat((input_tensor,
                     xx_channel,
                     yy_channel), -1)

    if with_r:
        rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)
                        + torch.pow(yy_channel - 0.5, 2))
        ret = torch.cat((ret, rr), -1)
    return ret
--------------------------------------------------------------------------------
/patch_handler.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

class PatchHandler():

    def __init__(self, config):
        self.config = config

        self.batch_size = self.config["train_params"]["batch_size"]
        self.micro_patch_size = self.config["data_params"]["micro_patch_size"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]
        self.full_image_size = self.config["data_params"]["full_image_size"]
        self.coordinate_system = self.config["data_params"]["coordinate_system"]
        self.c_dim = self.config["data_params"]["c_dim"]

        self.num_micro_compose_macro = config["data_params"]["num_micro_compose_macro"]


    def reord_patches_cpu(self, x, batch_size, patch_count):
        # Reorganize image order from [a0, b0, c0, a1, b1, c1, ...] to
        # [a0, a1, ..., b0, b1, ..., c0, c1, ...] (letters = samples, digits = patches)
        select = np.hstack([[i * batch_size + j for i in range(patch_count)] for j in range(batch_size)])
        x_reord = np.take(x, select, axis=0)
        return x_reord
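
    # Example for reord_patches_cpu: batch_size=3, patch_count=2 gives
    # select = [0, 3, 1, 4, 2, 5], i.e.
    # [a0, b0, c0, a1, b1, c1] -> [a0, a1, b0, b1, c0, c1].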

    def concat_micro_patches_cpu(self, generated_patches, ratio_over_micro):
        # Stitch micro patches back into macro patches. The stride-4 indexing
        # below hard-codes a 2x2 micro-per-macro layout, which matches the
        # provided N2M2 configs (ratio_macro_to_micro == [2, 2]).
        patch_count = ratio_over_micro[0] * ratio_over_micro[1]
        macro_patches = torch.zeros(int(generated_patches.shape[0] / patch_count), self.c_dim, self.macro_patch_size[0], self.macro_patch_size[1])

        # Patch order within a macro patch is (0, 0), (0, 1), (1, 0), (1, 1),
        # matching the order produced by CoordHandler._euclidean_sample_coord.
        idx = list(range(0, generated_patches.shape[0], 4))
        micro_patch = generated_patches[idx]
        macro_patches[:, :, 0:self.micro_patch_size[0], 0:self.micro_patch_size[1]] = micro_patch

        idx = list(range(2, generated_patches.shape[0], 4))
        micro_patch = generated_patches[idx]
        macro_patches[:, :, 0:self.micro_patch_size[0], self.micro_patch_size[1]:2 * self.micro_patch_size[1]] = micro_patch

        idx = list(range(1, generated_patches.shape[0], 4))
        micro_patch = generated_patches[idx]
        macro_patches[:, :, self.micro_patch_size[0]:2 * self.micro_patch_size[0], 0:self.micro_patch_size[1]] = micro_patch

        idx = list(range(3, generated_patches.shape[0], 4))
        micro_patch = generated_patches[idx]
        macro_patches[:, :, self.micro_patch_size[0]:2 * self.micro_patch_size[0], self.micro_patch_size[1]:2 * self.micro_patch_size[1]] = micro_patch

        return macro_patches


    def crop_micro_from_full_cpu(self, imgs, crop_pos_x, crop_pos_y):

        ps_x, ps_y = self.micro_patch_size  # i.e. patch size

        valid_area_x = self.full_image_size[0] - self.micro_patch_size[0]
        valid_area_y = self.full_image_size[1] - self.micro_patch_size[1]

        crop_result = []
        batch_size = imgs.shape[0]
        for i in range(batch_size * self.num_micro_compose_macro):
            i_idx = i // self.num_micro_compose_macro
            # Map coordinates from [-1, 1] to pixel offsets in [0, valid_area]
            x_idx = np.round((crop_pos_x[i, 0] + 1) / 2 * valid_area_x).astype(int)
            y_idx = np.round((crop_pos_y[i, 0] + 1) / 2 * valid_area_y).astype(int)
            t = imgs[i_idx, :, x_idx:x_idx + ps_x, y_idx:y_idx + ps_y]
            crop_result.append(t)
        return torch.stack(crop_result)
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from utils import save_manifold_images


class Logger():
    def __init__(self, config, patch_handler):
        self.config = config

        self.batch_size = self.config["train_params"]["batch_size"]
        self.num_micro_compose_full = self.config["data_params"]["num_micro_compose_full"]

        self.full_shape = [
            None,
            self.config["data_params"]["full_image_size"][0],
            self.config["data_params"]["full_image_size"][1],
            self.config["data_params"]["c_dim"],
        ]

        self.exp_name = config["log_params"]["exp_name"]
        self.log_dir = self._check_folder(os.path.join(config["log_params"]["log_dir"], self.exp_name))
        self.img_dir = self._check_folder(os.path.join(self.log_dir, "images"))

        # Use float to parse "inf"
        self.img_step = float(config["log_params"]["img_step"])
        self.dump_img_step = float(config["log_params"]["dump_img_step"])


    def _check_folder(self, folder):
        if not os.path.exists(folder):
            os.makedirs(folder)
        return folder


    def _check_step(self, step, step_config):
        if step_config is None:
            return False
        elif step == 0:
            return False
        return (step % step_config) == 0


    def log_iter(self, trainer, epoch, iter_, global_step, z_iter, z_fixed):

        # We use a set of fixed z here to better monitor the changes through time.
        if self._check_step(global_step, self.dump_img_step):

            fixed_patch, fixed_full = trainer.generate_full_image_cpu(z_fixed)
            _, sampled_full = trainer.generate_full_image_cpu(z_iter)

            num_full = self.batch_size
            num_patches = self.batch_size * self.num_micro_compose_full
            manifold_h_f, manifold_w_f = int(np.sqrt(num_full)), int(np.sqrt(num_full))
            manifold_h_p, manifold_w_p = int(np.sqrt(num_patches)), int(np.sqrt(num_patches))

            # Save fixed micro patches
            save_name = 'fixed_patch_{:02d}_{:04d}.png'.format(epoch, iter_)
            save_manifold_images(fixed_patch[:manifold_h_p * manifold_w_p, :, :, :],
                                 [manifold_h_p, manifold_w_p],
                                 os.path.join(self.img_dir, save_name))

            # Save fixed full images
            save_name = 'fixed_full_{:02d}_{:04d}.png'.format(epoch, iter_)
            save_manifold_images(fixed_full[:manifold_h_f * manifold_w_f, :, :, :],
                                 [manifold_h_f, manifold_w_f],
                                 os.path.join(self.img_dir, save_name))

            # Save sampled full images
            save_name = 'sampled_full_{:02d}_{:04d}.png'.format(epoch, iter_)
            save_manifold_images(sampled_full[:manifold_h_f * manifold_w_f, :, :, :],
                                 [manifold_h_f, manifold_w_f],
                                 os.path.join(self.img_dir, save_name))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import yaml

from models.generator import GeneratorBuilder
from models.discriminator import DiscriminatorBuilder
from models.spatial_prediction import SpatialPredictorBuilder
from models.content_predictor import ContentPredictorBuilder

from coord_handler import CoordHandler
from patch_handler import PatchHandler
from logger import Logger

from trainer import Trainer

import torchvision
import torchvision.datasets as dataset
import torch


def precompute_parameters(config):
    full_image_size = config["data_params"]["full_image_size"]
    micro_patch_size = config["data_params"]["micro_patch_size"]
    macro_patch_size = config["data_params"]["macro_patch_size"]

    # Let NxM micro patches compose a macro patch;
    # `ratio_macro_to_micro` is [N, M]
    ratio_macro_to_micro = [
        macro_patch_size[0] // micro_patch_size[0],
        macro_patch_size[1] // micro_patch_size[1],
    ]
    num_micro_compose_macro = ratio_macro_to_micro[0] * ratio_macro_to_micro[1]

    # Let NxM micro patches compose a full image;
    # `ratio_full_to_micro` is [N, M]
    ratio_full_to_micro = [
        full_image_size[0] // micro_patch_size[0],
        full_image_size[1] // micro_patch_size[1],
    ]
    num_micro_compose_full = ratio_full_to_micro[0] * ratio_full_to_micro[1]

    config["data_params"]["ratio_macro_to_micro"] = ratio_macro_to_micro
    config["data_params"]["ratio_full_to_micro"] = ratio_full_to_micro
    config["data_params"]["num_micro_compose_macro"] = num_micro_compose_macro
    config["data_params"]["num_micro_compose_full"] = num_micro_compose_full
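
# For the provided N2M2S32 configs (micro 16x16, macro 32x32, full 64x64) the
# values computed above are:
#   ratio_macro_to_micro = [2, 2]  ->  num_micro_compose_macro = 4
#   ratio_full_to_micro  = [4, 4]  ->  num_micro_compose_full  = 16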

def load_dataset(config):
    # To train on MNIST instead, swap in the following:
    # data_path = 'mnist/'
    # train_dataset = dataset.MNIST(
    #     root=data_path, train=True, download=True,
    #     transform=torchvision.transforms.Compose([
    #         torchvision.transforms.Resize((config["data_params"]["full_image_size"][0],
    #                                        config["data_params"]["full_image_size"][1])),
    #         torchvision.transforms.ToTensor()]))

    data_path = 'celeb_data/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((config["data_params"]["full_image_size"][0],
                                           config["data_params"]["full_image_size"][1])),
            torchvision.transforms.ToTensor()])
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['train_params']['batch_size'],
        num_workers=0,
        shuffle=True
    )
    return train_loader


with open('./configs/CelebA_64x64_N2M2S32.yaml') as f:
    config = yaml.safe_load(f)  # safe_load avoids yaml.load's unsafe default loader
micro_size = config["data_params"]['micro_patch_size']
macro_size = config["data_params"]['macro_patch_size']
full_size = config["data_params"]['full_image_size']
assert macro_size[0] % micro_size[0] == 0
assert macro_size[1] % micro_size[1] == 0
assert full_size[0] % micro_size[0] == 0
assert full_size[1] % micro_size[1] == 0

# Pre-compute some frequently used parameters
precompute_parameters(config)

# Create model builders
coord_handler = CoordHandler(config)
patch_handler = PatchHandler(config)

d_builder = DiscriminatorBuilder(config)
g_builder = GeneratorBuilder(config)
cp_builder = SpatialPredictorBuilder(config)
zp_builder = ContentPredictorBuilder(config)

real_images = load_dataset(config)

# Create controllers
logger = Logger(config, patch_handler)
trainer = Trainer(config, g_builder, d_builder, cp_builder, zp_builder, coord_handler, patch_handler)
trainer.train(logger, real_images)
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
import math
_EPS = 1e-5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Learnable function for the discriminator
class DiscriminatorBuilder(nn.Module):
    def __init__(self, config):
        super(DiscriminatorBuilder, self).__init__()

        self.config = config
        self.ndf_base = self.config["model_params"]["ndf_base"]
        self.num_extra_layers = self.config["model_params"]["d_extra_layers"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]

        self.residual_block_main = nn.ModuleList()
        self.residual_block_residue = nn.ModuleList()

        out_ch = self.config["data_params"]["c_dim"]
        ndf_mult_list = self._get_ndf_mult_list()

        # Residual blocks
        for idx, ndf_mult in enumerate(ndf_mult_list):
            n_ch = self.ndf_base * ndf_mult
            # Head is fixed and goes first
            if idx == 0:
                resize = True
            # Extra layers before standard layers
            elif idx <= self.num_extra_layers:
                resize = False
            # Last standard layer has no resize
            elif idx == len(ndf_mult_list) - 1:
                resize = False
            # Standard layers
            else:
                resize = True

            self.residual_block_main.append(nn.Conv2d(out_ch, n_ch, 3, 1, padding=1))
            self.residual_block_main.append(nn.Conv2d(n_ch, n_ch, 3, 1, padding=1))
            if resize:
                self.residual_block_main.append(nn.MaxPool2d(2, 2))
                self.residual_block_residue.append(nn.MaxPool2d(2, 2))
                self.residual_block_residue.append(nn.Conv2d(out_ch, n_ch, 1, 1, padding=0))
            out_ch = n_ch

        # Projection layer for the coordinate condition, with zero-initialized bias
        proj_out_ch = self.ndf_base * ndf_mult_list[-1]
        proj_in_ch = self.config["model_params"]["spatial_dim"]
        self.projection = nn.Linear(proj_in_ch, proj_out_ch, bias=True).double()
        stddev = np.sqrt(2. / proj_in_ch)
        self.projection.weight.data.uniform_(-stddev, stddev)
        nn.init.zeros_(self.projection.bias)

        self.global_linear = nn.Linear(proj_out_ch, 1, bias=True)
        stddev = np.sqrt(2. / proj_out_ch)
        self.global_linear.weight.data.uniform_(-stddev, stddev)
        nn.init.zeros_(self.global_linear.bias)

    def _get_ndf_mult_list(self):
        # Channel multipliers per residual block, padded with 1s when the
        # macro patch size requires more resize layers than `basic_layers`
        num_resize_layers = int(math.log(min(self.macro_patch_size), 2) - 1)
        num_total_layers = num_resize_layers + self.num_extra_layers
        basic_layers = [2, 4, 8, 8]
        if num_total_layers >= len(basic_layers):
            num_replicate_layers = num_total_layers - len(basic_layers)
            ndf_mult_list = [1, ] * num_replicate_layers + basic_layers
        else:
            ndf_mult_list = basic_layers[-num_total_layers:]
        return ndf_mult_list

    # forward method
    def forward(self, x, y=None, is_training=True):

        ndf_mult_list = self._get_ndf_mult_list()

        X = x
        residual_main_idx = 0
        residual_residue_idx = 0
        for idx, ndf_mult in enumerate(ndf_mult_list):
            # Head is fixed and goes first
            if idx == 0:
                is_head, resize = True, True
            # Extra layers before standard layers
            elif idx <= self.num_extra_layers:
                is_head, resize = False, False
            # Last standard layer has no resize
            elif idx == len(ndf_mult_list) - 1:
                is_head, resize = False, False
            # Standard layers
            else:
                is_head, resize = False, True

            h = X
            if not is_head:
                h = F.relu(h)
            h = self.residual_block_main[residual_main_idx](h)
            residual_main_idx += 1
            h = F.relu(h)
            h = self.residual_block_main[residual_main_idx](h)
            residual_main_idx += 1
            if resize:
                h = self.residual_block_main[residual_main_idx](h)
                residual_main_idx += 1

            # Shortcut (pool + 1x1 conv only when the main path resizes;
            # otherwise the channel counts already match)
            s = X
            if resize:
                s = self.residual_block_residue[residual_residue_idx](s)
                residual_residue_idx += 1
                s = self.residual_block_residue[residual_residue_idx](s)
                residual_residue_idx += 1

            X = h + s

        X = F.relu(X)
        X = torch.sum(X, (2, 3))  # Global pooling
        last_feature_map = X
        adv_out = self.global_linear(X).type(torch.float64)

        # Projection discriminator: add the inner product between the embedded
        # condition y and the pooled features
        if y is not None:
            Y = torch.tensor(np.expand_dims(y, 1), dtype=torch.float64)
            y_emb = self.projection(Y).squeeze(1)                      # (N, C)
            proj_out = torch.sum(y_emb * X.type(torch.float64), 1, keepdim=True)
            out = adv_out + proj_out
        else:
            out = adv_out

        return out, last_feature_map
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
import math

from ops import upscale

import torch
import torch.nn as nn

_EPS = 1e-5

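# Data flow of GeneratorBuilder (the class below):
#   cond = [z ; micro coordinate]                    (z_dim + spatial_dim)
#   linear -> 2x2 feature map with ngf_base*16 channels
#   residual blocks with conditional batch norm on `cond`,
#     upsampling 2x per standard block up to micro_patch_size
#   BN -> ReLU -> 3x3 conv -> tanh                   (c_dim x micro_h x micro_w)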
class GeneratorBuilder(nn.Module):
    def __init__(self, config):
        super(GeneratorBuilder, self).__init__()

        self.config = config
        self.ngf_base = self.config["model_params"]["ngf_base"]
        self.num_extra_layers = self.config["model_params"]["g_extra_layers"]
        self.micro_patch_size = self.config["data_params"]["micro_patch_size"]
        self.c_dim = self.config["data_params"]["c_dim"]
        self.spatial_dim = self.config["model_params"]["spatial_dim"]

        init_sp = 2
        init_ngf_mult = 16
        in_ch = self.config["model_params"]["z_dim"]
        out_ch = self.ngf_base * init_ngf_mult
        self.initial_layer = nn.Linear(in_ch + self.spatial_dim, out_ch * init_sp * init_sp)
        self.residual_block_main = nn.ModuleList()
        self.residual_block_residue = nn.ModuleList()
        self.residual_block_cbn = nn.ModuleList()

        # Stacking residual blocks; each block owns two (gamma, beta) CBN pairs
        # conditioned on [z ; coord]
        ngf_mult_list = self._get_ngf_mult_list()
        for idx, ngf_mult in enumerate(ngf_mult_list):
            n_ch = self.ngf_base * ngf_mult
            self.residual_block_cbn.append(nn.Linear(in_ch + self.spatial_dim, out_ch))
            self.residual_block_cbn.append(nn.Linear(in_ch + self.spatial_dim, out_ch))
            self.residual_block_main.append(nn.Conv2d(out_ch, n_ch, 3, 1, padding=1))
            self.residual_block_cbn.append(nn.Linear(in_ch + self.spatial_dim, n_ch))
            self.residual_block_cbn.append(nn.Linear(in_ch + self.spatial_dim, n_ch))
            self.residual_block_main.append(nn.Conv2d(n_ch, n_ch, 3, 1, padding=1))
            self.residual_block_residue.append(nn.Conv2d(out_ch, n_ch, 1, 1, padding=0))
            out_ch = n_ch

        self.bn = nn.BatchNorm2d(out_ch, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        self.final_conv = nn.Conv2d(out_ch, self.c_dim, 3, 1, padding=1)

    def _get_ngf_mult_list(self):
        num_resize_layers = int(math.log(min(self.micro_patch_size), 2) - 1)
        num_total_layers = num_resize_layers + self.num_extra_layers
        basic_layers = [8, 4, 2]
        if num_total_layers >= len(basic_layers):
            num_replicate_layers = num_total_layers - len(basic_layers)
            ngf_mult_list = basic_layers + [1, ] * num_replicate_layers
        else:
            ngf_mult_list = basic_layers[:num_total_layers]
        return ngf_mult_list

    def _cbn(self, x, y, residual_main_cbn_idx, is_training):
        # Conditional Batch Normalization: gamma/beta are predicted from the
        # condition y. Normalization uses the current batch statistics;
        # persistent running averages would require buffers registered in
        # __init__.
        ch = x.size(1)
        gamma = self.residual_block_cbn[residual_main_cbn_idx](y)
        beta = self.residual_block_cbn[residual_main_cbn_idx + 1](y)

        mean = torch.mean(x, [0, 2, 3]).view(1, ch, 1, 1)
        var = torch.var(x, [0, 2, 3]).view(1, ch, 1, 1)
        gamma = gamma.view(-1, ch, 1, 1)
        beta = beta.view(-1, ch, 1, 1)

        return (x - mean) / torch.sqrt(var + _EPS) * gamma + beta

    def forward(self, z, coord, is_training):

        init_sp = 2
        init_ngf_mult = 16
        cond = torch.cat((z.float(), torch.tensor(coord).float()), 1)
        h = self.initial_layer(cond)
        h = h.view(-1, self.ngf_base * init_ngf_mult, init_sp, init_sp)

        # Stacking residual blocks
        num_resize_layers = int(math.log(min(self.micro_patch_size), 2) - 1)
        ngf_mult_list = self._get_ngf_mult_list()

        X = h
        residual_main_idx = 0
        residual_residue_idx = 0
        residual_main_cbn_idx = 0
        for idx, ngf_mult in enumerate(ngf_mult_list):
            # Standard layers first
            if idx < num_resize_layers:
                resize = True
            # Extra layers do not resize spatial size
            else:
                resize = False

            h = X
            h = self._cbn(h, cond, residual_main_cbn_idx, is_training)
            residual_main_cbn_idx += 2

            h = nn.functional.relu(h)
            if resize:
                h = upscale(h, 2)

            h = self.residual_block_main[residual_main_idx](h)
            residual_main_idx += 1

            h = self._cbn(h, cond, residual_main_cbn_idx, is_training)
            residual_main_cbn_idx += 2

            h = nn.functional.relu(h)
            h = self.residual_block_main[residual_main_idx](h)
            residual_main_idx += 1

            if resize:
                sc = upscale(X, 2)
            else:
                sc = X
            sc = self.residual_block_residue[residual_residue_idx](sc)
            residual_residue_idx += 1

            X = h + sc

        X = self.bn(X)
        X = nn.functional.relu(X)
        X = self.final_conv(X)
        return torch.tanh(X)
--------------------------------------------------------------------------------
/coord_handler.py:
--------------------------------------------------------------------------------
import numpy as np

class CoordHandler():

    def __init__(self, config):
        self.config = config

        self.batch_size = self.config["train_params"]["batch_size"]
        self.micro_patch_size = self.config["data_params"]["micro_patch_size"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]
        self.full_image_size = self.config["data_params"]["full_image_size"]
        self.coordinate_system = self.config["data_params"]["coordinate_system"]
        self.c_dim = self.config["data_params"]["c_dim"]

        self.ratio_macro_to_micro = self.config["data_params"]["ratio_macro_to_micro"]
        self.ratio_full_to_micro = self.config["data_params"]["ratio_full_to_micro"]
        self.num_micro_compose_macro = self.config["data_params"]["num_micro_compose_macro"]

        self.cache = {
            "const_centroid": {},
        }

    def sample_coord(self, num_extrap_steps=0):
        if self.coordinate_system == "euclidean":
            return self._euclidean_sample_coord(num_extrap_steps=num_extrap_steps)
        elif self.coordinate_system == "cylindrical":
            assert num_extrap_steps == 0
            return self._cylindrical_sample_coord()
        else:
            raise NotImplementedError()


    def euclidean_coord_int_full_to_float_micro(self, i, ratio_full_to_micro, extrap_steps=0):
        # Map the integer micro-patch index i to a float coordinate in [-1, 1]
        if extrap_steps > 0:  # Extrapolation training
            ratio_original = ratio_full_to_micro - extrap_steps * 2
            return -1 + (i - extrap_steps) * 2 / (ratio_original - 1)
        else:
            return -1 + i * 2 / (ratio_full_to_micro - 1)
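
    # Example with the provided configs (ratio_full_to_micro = 4, extrap_steps = 0):
    #   i = 0, 1, 2, 3  ->  -1, -1/3, 1/3, 1
    # i.e. micro-patch centroids are spaced evenly across the [-1, 1] range.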

    # def hyperbolic_coord_int_full_to_float_micro(self, i, ratio_full_to_micro):
    #     return -1 + i * 2 / ratio_full_to_micro
    #
    #
    # def hyperbolic_theta_to_euclidean(self, angle_ratio, proj_func):
    #     p = proj_func(np.pi * angle_ratio)
    #     if (type(p) is float) or (len(p.shape) == 0):
    #         p = p if abs(p) > 1e-6 else 0
    #     else:
    #         p[np.abs(p) < 1e-6] = 0
    #     return p


    def _gen_const_centroids(self, target, dim, num_extrap_steps=0, is_hyperbolic=False):
        assert target in {"full", "macro"}

        # Check if in cache
        cache_key = (target, dim, num_extrap_steps, is_hyperbolic)
        if cache_key in self.cache["const_centroid"]:
            return self.cache["const_centroid"][cache_key]

        const_centroid = []
        if target == "full" and is_hyperbolic:
            assert dim == 1
            num_patches = self.ratio_full_to_micro[dim]  # Warn: -1 is the same location as 1 in 3D, so ignore 1 here
            ratio_over_micro = self.ratio_full_to_micro[dim]
            for i in range(num_patches):
                coord = -1 + i * 2 / ratio_over_micro
                const_centroid.append(coord)
        elif target == "full":
            num_pad_patch = (self.ratio_macro_to_micro[dim] - 1)
            num_patches = self.ratio_full_to_micro[dim] - num_pad_patch + num_extrap_steps * 2
            ratio_over_micro = self.ratio_full_to_micro[dim]
            for i in range(num_patches):
                coord = -1 + (i - num_extrap_steps) * 2 / (ratio_over_micro - 1 - num_pad_patch)
                const_centroid.append(coord)
        else:
            num_patches = self.ratio_macro_to_micro[dim]
            ratio_over_micro = self.ratio_macro_to_micro[dim]
            for i in range(num_patches):
                coord = i / (ratio_over_micro - 1)
                const_centroid.append(coord)
        const_centroid = np.array(const_centroid)
        self.cache["const_centroid"][cache_key] = const_centroid
        return const_centroid


    def _euclidean_sample_coord(self, num_extrap_steps=0):

        const_centroid_x = self._gen_const_centroids(target="full", dim=0, num_extrap_steps=num_extrap_steps)
        const_centroid_y = self._gen_const_centroids(target="full", dim=1, num_extrap_steps=num_extrap_steps)
        const_micro_in_macro_x = self._gen_const_centroids(target="macro", dim=0)
        const_micro_in_macro_y = self._gen_const_centroids(target="macro", dim=1)

        # Random cropping position
        ps = self.micro_patch_size
        gs = self.macro_patch_size
        m_ratio = self.ratio_macro_to_micro
        valid_crop_size_x = self.full_image_size[0] - ps[0]
        valid_crop_size_y = self.full_image_size[1] - ps[1]
        macro_patch_center_range_x = self.full_image_size[0] - self.macro_patch_size[0]
        macro_patch_center_range_y = self.full_image_size[1] - self.macro_patch_size[1]

        num_pad_patch_x = (self.ratio_macro_to_micro[0] - 1)
        num_pad_patch_y = (self.ratio_macro_to_micro[1] - 1)
        d_macro_center_idx_x = np.random.randint(0, self.ratio_full_to_micro[0] - num_pad_patch_x + num_extrap_steps * 2, self.batch_size)
        d_macro_center_idx_y = np.random.randint(0, self.ratio_full_to_micro[1] - num_pad_patch_y + num_extrap_steps * 2, self.batch_size)

        d_macro_pos_x = np.array([const_centroid_x[i] for i in d_macro_center_idx_x]).reshape(-1, 1)
        d_macro_pos_y = np.array([const_centroid_y[i] for i in d_macro_center_idx_y]).reshape(-1, 1)

        # Clip values to avoid numerical issues (e.g., 1.000001)
        if num_extrap_steps == 0:
            d_macro_pos_x = np.clip(d_macro_pos_x, -1, 1)
            d_macro_pos_y = np.clip(d_macro_pos_y, -1, 1)

        d_macro_coord = np.concatenate([d_macro_pos_x, d_macro_pos_y], axis=1)

        # Transform the macro-patch position ([-1, 1]) into per-micro-patch positions ([-1, 1])
        g_micro_pos_x_proto = np.tile(np.expand_dims(d_macro_pos_x, 1), [1, self.num_micro_compose_macro, 1])
        g_micro_pos_y_proto = np.tile(np.expand_dims(d_macro_pos_y, 1), [1, self.num_micro_compose_macro, 1])
        g_micro_pos_x_l, g_micro_pos_y_l = [], []
        gpc_x = macro_patch_center_range_x
        gpc_y = macro_patch_center_range_y
        for yy in range(self.ratio_macro_to_micro[1]):
            for xx in range(self.ratio_macro_to_micro[0]):
                idx = yy * m_ratio[0] + xx
                T_x = const_micro_in_macro_x[xx]
                T_y = const_micro_in_macro_y[yy]
                g_micro_pos_x = ((g_micro_pos_x_proto[:, idx] + 1) / 2 * gpc_x + (gs[0] / 2) + (T_x * (gs[0] - ps[0])) - (m_ratio[0] / 2) * ps[0]) / valid_crop_size_x * 2 - 1
                g_micro_pos_y = ((g_micro_pos_y_proto[:, idx] + 1) / 2 * gpc_y + (gs[1] / 2) + (T_y * (gs[1] - ps[1])) - (m_ratio[1] / 2) * ps[1]) / valid_crop_size_y * 2 - 1
                g_micro_pos_x_l.append(g_micro_pos_x)
                g_micro_pos_y_l.append(g_micro_pos_y)
        g_micro_pos_x = np.concatenate(g_micro_pos_x_l, axis=1).reshape(-1, 1)
        g_micro_pos_y = np.concatenate(g_micro_pos_y_l, axis=1).reshape(-1, 1)
        g_micro_coord = np.concatenate([g_micro_pos_x, g_micro_pos_y], axis=1)

        # Unused placeholder (the angle ratio is only meaningful for cylindrical coordinates)
        c_angle_ratio = np.zeros_like(g_micro_pos_x)

        return d_macro_coord, g_micro_coord, c_angle_ratio
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

class Trainer():
    def __init__(self, config,
                 g_builder, d_builder, cp_builder, zp_builder,
                 coord_handler, patch_handler):
        self.config = config
        self.g_builder = g_builder
        self.d_builder = d_builder
        self.cp_builder = cp_builder
        self.zp_builder = zp_builder
        self.coord_handler = coord_handler
        self.patch_handler = patch_handler

        # Vars for graph building
        self.batch_size = self.config["train_params"]["batch_size"]
        self.z_dim = self.config["model_params"]["z_dim"]
        self.spatial_dim = self.config["model_params"]["spatial_dim"]
        self.micro_patch_size = self.config["data_params"]["micro_patch_size"]
        self.macro_patch_size = self.config["data_params"]["macro_patch_size"]

        self.ratio_macro_to_micro = self.config["data_params"]["ratio_macro_to_micro"]
        self.ratio_full_to_micro = self.config["data_params"]["ratio_full_to_micro"]
        self.num_micro_compose_macro = self.config["data_params"]["num_micro_compose_macro"]

        # Vars for training loop
        self.exp_name = config["log_params"]["exp_name"]
        self.epochs = float(self.config["train_params"]["epochs"])
        self.num_batches = self.config["data_params"]["num_train_samples"] // self.batch_size
        self.coordinate_system = self.config["data_params"]["coordinate_system"]
        self.G_update_period = self.config["train_params"]["G_update_period"]
        self.D_update_period = self.config["train_params"]["D_update_period"]
        self.Q_update_period = self.config["train_params"]["Q_update_period"]

        # Loss weights
        self.code_loss_w = self.config["loss_params"]["code_loss_w"]
        self.coord_loss_w = self.config["loss_params"]["coord_loss_w"]
        self.gp_lambda = self.config["loss_params"]["gp_lambda"]

        # Extrapolation parameters handling
        self.num_extrap_steps = 0
        self.weight_cliping_limit = 0.01


    def _train_content_prediction_model(self):
        return (self.Q_update_period > 0) and (self.config["train_params"]["qlr"] > 0)


    def sample_prior(self):
        return np.random.uniform(-1., 1., [self.batch_size, self.z_dim]).astype(np.float32)
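
    # `tile` / `_dup_z_for_macro` below duplicate each z with nearest-neighbour
    # ordering, e.g. for num_micro_compose_macro = 4:
    #   [z0, z1] -> [z0, z0, z0, z0, z1, z1, z1, z1]
    # so all micro patches of one macro patch share the same latent code.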
    def tile(self, a, dim, n_tile):
        # Same helper as ops.tile
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*repeat_idx)
        order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
        return torch.index_select(a, dim, order_index)

    def _dup_z_for_macro(self, z):
        # Duplicate with nearest neighbour, unlike `tf.tile`.
        ch = z.shape[-1]
        repeat = self.num_micro_compose_macro
        extend = torch.unsqueeze(z, 1)

        extend_dup = self.tile(extend, 1, repeat)
        return extend_dup.view(-1, ch)

    def calc_gradient_penalty(self):
        """Gradient penalty for the patch discriminator (WGAN-GP-style
        interpolation between real and generated macro patches)."""
        # Adapted from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb
        alpha = torch.FloatTensor(self.real_macro.size(0), 1, 1, 1).uniform_(0., 1.)
        alpha = alpha.expand_as(self.real_macro)
        # requires_grad_ must be set *before* the D forward pass so that
        # autograd can differentiate w.r.t. the interpolates
        interpolates = (alpha * self.real_macro + (1 - alpha) * self.gen_macro).detach()
        interpolates.requires_grad_(True)

        disc_inter, _ = self.d_builder(interpolates, None, is_training=True)

        gradients = torch.autograd.grad(outputs=disc_inter, inputs=interpolates,
                                        grad_outputs=torch.ones_like(disc_inter),
                                        create_graph=True, retain_graph=True)[0]
        # L2 norm over all non-batch dimensions
        slopes = torch.sqrt(torch.sum(torch.pow(gradients, 2), (1, 2, 3)))
        gradient_penalty = torch.mean((slopes - 1.) ** 2).type(torch.DoubleTensor) * self.config["loss_params"]["gp_lambda"]
        return gradient_penalty


    def generate_full_image_cpu(self, z):
        all_micro_patches = []
        all_micro_coord = []
        num_patches_x = self.ratio_full_to_micro[0] + self.num_extrap_steps * 2
        num_patches_y = self.ratio_full_to_micro[1] + self.num_extrap_steps * 2

        full_image = np.empty((z.shape[0], self.config["data_params"]["c_dim"], self.micro_patch_size[0] * num_patches_x, 0))
        for yy in range(num_patches_y):
            rows_data = np.empty((z.shape[0], self.config["data_params"]["c_dim"], 0, self.micro_patch_size[1]))
            for xx in range(num_patches_x):
                micro_coord_single = np.array([
                    self.coord_handler.euclidean_coord_int_full_to_float_micro(xx, num_patches_x, extrap_steps=self.num_extrap_steps),
                    self.coord_handler.euclidean_coord_int_full_to_float_micro(yy, num_patches_y, extrap_steps=self.num_extrap_steps),
                ])
                micro_coord = np.tile(np.expand_dims(micro_coord_single, 0), [z.shape[0], 1])
                generated_patch = self.g_builder(torch.tensor(z, dtype=torch.float64), torch.tensor(micro_coord, dtype=torch.float64), is_training=False)
                rows_data = np.concatenate((rows_data, generated_patch.detach().numpy()), axis=2)
                all_micro_patches.append(torch.Tensor.cpu(generated_patch).detach().numpy())
                all_micro_coord.append(micro_coord)
            full_image = np.concatenate((full_image, rows_data), axis=3)

        all_micro_patches = np.concatenate(all_micro_patches, 0)

        return all_micro_patches, full_image


    def train(self, logger, real_images):

        # Optimizers
        self.g_optim = optim.Adam(self.g_builder.parameters(), lr=self.config["train_params"]["glr"], betas=[self.config["train_params"]["beta1"], self.config["train_params"]["beta2"]])
        self.d_optim = optim.Adam(self.d_builder.parameters(), lr=self.config["train_params"]["dlr"], betas=[self.config["train_params"]["beta1"], self.config["train_params"]["beta2"]])
        self.q_optim = optim.Adam(self.zp_builder.parameters(), lr=self.config["train_params"]["qlr"], betas=[self.config["train_params"]["beta1"], self.config["train_params"]["beta2"]])

        z_fixed = self.sample_prior()
        global_step = 0
        start_time = time.time()
        self.g_loss, self.d_loss, self.q_loss = 0, 0, 0

        cur_epoch = int(global_step / self.num_batches)
        cur_iter = global_step - cur_epoch * self.num_batches

        while cur_epoch < self.epochs:  # "epochs: inf" in the configs means: train until manually stopped
            for cur_iter, data in enumerate(real_images):
                print('iter : ' + str(cur_iter))
                # for p in self.d_builder.parameters():
                #     p.data.clamp_(-self.weight_cliping_limit, self.weight_cliping_limit)

                image = data[0]

                # Create data
                z = torch.tensor(self.sample_prior())
                macro_coord, micro_coord, y_angle_ratio = self.coord_handler.sample_coord()

                micro_coord_fake = micro_coord
                macro_coord_fake = macro_coord
                micro_coord_real = micro_coord
                macro_coord_real = macro_coord

                # Crop real micro patches and assemble them into macro patches
                self.real_micro = self.patch_handler.crop_micro_from_full_cpu(image, micro_coord_real[:, 0:1], micro_coord_real[:, 1:2])
                self.real_macro = self.patch_handler.concat_micro_patches_cpu(self.real_micro, ratio_over_micro=self.ratio_macro_to_micro)

                (self.disc_real, disc_real_h) = self.d_builder(self.real_macro, macro_coord_real, is_training=True)
                self.c_real_pred = self.cp_builder(torch.unsqueeze(disc_real_h, 1), is_training=True)
                self.z_real_pred = self.zp_builder(torch.unsqueeze(disc_real_h, 1), is_training=True)

                # Fake part
                z_dup_macro = self._dup_z_for_macro(z).clone().detach()
                self.gen_micro = self.g_builder(z_dup_macro, micro_coord_fake, is_training=True)
                self.gen_macro = self.patch_handler.concat_micro_patches_cpu(self.gen_micro, ratio_over_micro=self.ratio_macro_to_micro)
                (self.disc_fake, disc_fake_h) = self.d_builder(self.gen_macro, macro_coord_fake, is_training=True)
                self.c_fake_pred = self.cp_builder(torch.unsqueeze(disc_fake_h, 1), is_training=True)
                self.z_fake_pred = self.zp_builder(torch.unsqueeze(disc_fake_h, 1), is_training=True)

                self.macro_error = nn.MSELoss()(self.real_macro.type(torch.float64), self.gen_macro.type(torch.float64))

                # Spatial consistency loss
                self.coord_mse_real = self.coord_loss_w * nn.MSELoss()(torch.squeeze(torch.tensor(macro_coord_real).double()), torch.squeeze(self.c_real_pred).double())
                self.coord_mse_fake = self.coord_loss_w * nn.MSELoss()(torch.squeeze(torch.tensor(macro_coord_fake).double()), torch.squeeze(self.c_fake_pred).double())

                self.coord_mse_real = torch.mean(self.coord_mse_real)
                self.coord_mse_fake = torch.mean(self.coord_mse_fake)
                self.coord_loss = self.coord_mse_real + self.coord_mse_fake

                # Content consistency loss (note: both sides are detached, and
                # code_loss_w is 0 in the provided configs, so this term is inert)
                z_real = z.clone().detach().requires_grad_(True).type(torch.DoubleTensor).unsqueeze(1)
                z_fake = self.z_fake_pred.clone().detach().requires_grad_(True).type(torch.DoubleTensor)
                self.code_loss = self.code_loss_w * torch.mean(nn.L1Loss()(z_real, z_fake))

                # Gradient penalty loss of WGAN-GP
                gradient_penalty = self.calc_gradient_penalty()
                self.gp_loss = gradient_penalty
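
                # Objective being assembled below (WGAN-GP plus COCO-GAN's auxiliary terms):
                #   L_D = E[D(fake)] - E[D(real)] + gp_lambda * E[(||grad D(x_hat)|| - 1)^2]
                #         + coord_loss + code_loss
                #   L_G = -E[D(fake)] + coord_loss + code_loss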

                # WGAN loss
                self.adv_real = torch.mean(self.disc_real)
                self.adv_fake = torch.mean(self.disc_fake)

                self.d_adv_loss = -self.adv_real + self.adv_fake
                self.g_adv_loss = -self.adv_fake  # + self.macro_error.type(torch.float64)

                # Total losses
                self.d_loss = self.d_adv_loss + self.gp_loss + self.coord_loss + self.code_loss
                self.g_loss = self.g_adv_loss + self.coord_loss + self.code_loss
                self.q_loss = self.g_adv_loss + self.code_loss

                self.w_dist = torch.abs(self.adv_real - self.adv_fake)
                print('Wasserstein Distance ' + str(self.w_dist.item()))

                # Update D, then G. Zeroing each optimizer's gradients right
                # before its own backward/step keeps the two losses (which share
                # one computation graph) from leaking gradients into each other.
                self.d_optim.zero_grad()
                self.d_loss.float().backward(retain_graph=True)
                self.d_optim.step()

                self.g_optim.zero_grad()
                self.g_loss.float().backward()
                self.g_optim.step()

                # Log
                time_elapsed = time.time() - start_time
                print("[{}] [Epoch: {}; {:4d}/{:4d}; global_step:{}] elapsed: {:.4f}".format(
                    self.exp_name, cur_epoch, cur_iter, self.num_batches, global_step, time_elapsed))

                logger.log_iter(self, cur_epoch, cur_iter, global_step, z, z_fixed)

                global_step += 1

            cur_epoch += 1
--------------------------------------------------------------------------------