├── README.md
├── cyclegan
│   ├── loaders.py
│   ├── main.py
│   ├── models.py
│   ├── script.sh
│   └── solver.py
├── images
│   ├── cyclegan.jpg
│   ├── pixel2pixel.jpg
│   └── stargan.jpg
├── pixel2pixel
│   ├── loaders.py
│   ├── main.py
│   ├── models.py
│   ├── script.sh
│   └── solver.py
└── stargan
    ├── loaders.py
    ├── main.py
    ├── models.py
    ├── script.sh
    ├── solver.py
    └── tools.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pytorch-Image-Translation-GANs
PyTorch implementations of the most popular image-translation GANs: [Pixel2Pixel](https://arxiv.org/pdf/1611.07004.pdf), [CycleGAN](https://arxiv.org/pdf/1703.10593.pdf) and [StarGAN](https://arxiv.org/pdf/1711.09020.pdf).

## Dependencies
* [Anaconda (Python 2.7)](https://www.anaconda.com/download/)
* [PyTorch 0.4.0](http://pytorch.org/)

## Prepare Datasets
### Datasets
* [Pixel2Pixel (Facade)](http://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_base.zip).
* [CycleGAN (Apple2Orange)](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip).
* [StarGAN (CelebA)](https://pan.baidu.com/s/1eSNpdRG#list/path=%2F): download 'CelebA/Img/img_align_celeba.zip' and 'CelebA/Anno/list_attr_celeba.txt'.
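
Judging from the default arguments in each `main.py` and from the loaders, the unpacked data is expected to sit next to the code roughly like this (the paths below are just the defaults and can be overridden with `--dataset_path` / `--attr_path`):

```
cyclegan/data/apple2orange/        # trainA/, trainB/, testA/, testB/
pixel2pixel/data/facades/          # train/, test/ (each file is a paired image)
stargan/data/img_align_celeba/     # unzipped CelebA images
stargan/data/list_attr_celeba.txt
```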

## Experimental Results
### Pixel2Pixel
![](https://github.com/wangguanan/Pytorch-Image-Translation-GANs/blob/master/images/pixel2pixel.jpg?raw=true)
### CycleGAN
![](https://github.com/wangguanan/Pytorch-Image-Translation-GANs/blob/master/images/cyclegan.jpg?raw=true)
### StarGAN
![](https://github.com/wangguanan/Pytorch-Image-Translation-GANs/blob/master/images/stargan.jpg?raw=true)


## Acknowledgement
This project accompanies the [GAN Theory and Practice](https://study.163.com/course/courseLearn.htm?courseId=1006498024&share=2&shareId=400000000681046#/learn/live?lessonId=1054152368&courseId=1006498024) part of the [Deep Learning Course: from Algorithm to Practice](https://study.163.com/course/courseMain.htm?share=2&shareId=400000000681046&courseId=1006498024&_trace_c_p_k2_=d197343763ee421eae96c4cdb1b129cb).

## Contacts
If you have any questions about the project, please feel free to contact me.

E-mail: guan.wang0706@gmail.com
--------------------------------------------------------------------------------
/cyclegan/loaders.py:
--------------------------------------------------------------------------------
import glob
import os

import torch.utils.data as data
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class ImageFolder(Dataset):
    '''Load all images under a directory.'''

    def __init__(self, path, transform):
        self.transform = transform
        self.samples = sorted(glob.glob(os.path.join(path, '*.*')))

    def __getitem__(self, index):
        sample = self.transform(Image.open(self.samples[index]))
        return sample

    def __len__(self):
        return len(self.samples)


class IterLoader:
    '''Wrap a DataLoader so that it restarts automatically once exhausted.'''

    def __init__(self, loader):
        self.loader = loader
        self.iter = iter(self.loader)

    def next_one(self):
        try:
            return next(self.iter)
        except StopIteration:
            self.iter = iter(self.loader)
            return next(self.iter)


class Loaders:

    def __init__(self, config):

        self.dataset_path = config.dataset_path
        self.image_size = config.image_size
        self.batch_size = config.batch_size

        self.transforms_train = transforms.Compose([transforms.RandomResizedCrop(self.image_size),
                                                    transforms.RandomHorizontalFlip(),
                                                    transforms.ToTensor(),
                                                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        self.transforms_test = transforms.Compose([transforms.Resize(self.image_size),
                                                   transforms.ToTensor(),
                                                   transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        train_a_set = ImageFolder(os.path.join(self.dataset_path, 'trainA/'), self.transforms_train)
        train_b_set = ImageFolder(os.path.join(self.dataset_path, 'trainB/'), self.transforms_train)
        test_a_set = ImageFolder(os.path.join(self.dataset_path, 'testA/'), self.transforms_test)
        test_b_set = ImageFolder(os.path.join(self.dataset_path, 'testB/'), self.transforms_test)

        self.train_a_loader = data.DataLoader(dataset=train_a_set, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)
        self.train_b_loader = data.DataLoader(dataset=train_b_set, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)
        self.test_a_loader = data.DataLoader(dataset=test_a_set, batch_size=self.batch_size, shuffle=False, num_workers=4, drop_last=False)
        self.test_b_loader = data.DataLoader(dataset=test_b_set, batch_size=self.batch_size, shuffle=False, num_workers=4, drop_last=False)

        self.train_a_iter = IterLoader(self.train_a_loader)
        self.train_b_iter = IterLoader(self.train_b_loader)
--------------------------------------------------------------------------------
/cyclegan/main.py:
--------------------------------------------------------------------------------
import argparse
import os

from solver import Solver
from loaders import Loaders


def main(config):

    # environments
    os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
    if not os.path.exists(config.output_path):
        os.makedirs(os.path.join(config.output_path, 'images/'))
        os.makedirs(os.path.join(config.output_path, 'models/'))

    # init dataset
    loaders = Loaders(config)

    # main
    solver = Solver(config, loaders)
    solver.train()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Environment Configuration
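    # Note: '--cuda' is written straight into CUDA_VISIBLE_DEVICES, so '-1'
    # hides all GPUs and makes the solver fall back to the CPU.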
    parser.add_argument('--cuda', type=str, default='-1', help="'-1' runs on the CPU; '2' uses GPU 2; '2,3,4' uses GPUs 2, 3 and 4")
    parser.add_argument('--output_path', type=str, default='out/apple2orange/')

    # Dataset Configuration
    parser.add_argument('--dataset_path', type=str, default='data/apple2orange/')
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=1)

    # Training Configuration
    parser.add_argument('--resume_iteration', type=int, default=-1, help='if -1, train from scratch; if >=0, resume from that checkpoint and continue training')

    # main function
    config = parser.parse_args()
    main(config)
--------------------------------------------------------------------------------
/cyclegan/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class ResidualBlock(nn.Module):
    '''Residual Block with Instance Normalization'''

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
        )

    def forward(self, x):
        return self.model(x) + x


class Generator(nn.Module):
    '''Generator with down-sampling, several ResBlocks and up-sampling.
    Down/up-sampling is used to reduce computation.
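    The solver builds it as Generator(64, 9); nine residual blocks matches
    the configuration the CycleGAN paper uses for 256x256 images.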
    '''

    def __init__(self, conv_dim, layer_num):
        super(Generator, self).__init__()

        layers = []

        # input layer
        layers.append(nn.Conv2d(in_channels=3, out_channels=conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # down sampling layers
        current_dims = conv_dim
        for i in xrange(2):
            layers.append(nn.Conv2d(current_dims, current_dims*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(current_dims*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            current_dims *= 2

        # residual layers
        for i in xrange(layer_num):
            layers.append(ResidualBlock(current_dims, current_dims))

        # up sampling layers
        for i in xrange(2):
            layers.append(nn.ConvTranspose2d(current_dims, current_dims//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(current_dims//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            current_dims = current_dims//2

        # output layer
        layers.append(nn.Conv2d(current_dims, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    '''Discriminator with PatchGAN'''

    def __init__(self, image_size, conv_dim, layer_num):
        super(Discriminator, self).__init__()

        layers = []

        # input layer
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        current_dim = conv_dim

        # hidden layers
        for i in xrange(layer_num):
            layers.append(nn.Conv2d(current_dim, current_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.InstanceNorm2d(current_dim*2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            current_dim *= 2

        self.model = nn.Sequential(*layers)

        # output layer
        self.conv_src = nn.Conv2d(current_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        x = self.model(x)
        out_src = self.conv_src(x)
        return out_src
--------------------------------------------------------------------------------
/cyclegan/script.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

python main.py --cuda 0
--------------------------------------------------------------------------------
/cyclegan/solver.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
from torchvision.utils import save_image
import os
import itertools

from models import Generator, Discriminator, weights_init_normal


class Solver:

    def __init__(self, config, loaders):

        self.config = config
        self.loaders = loaders

        self.save_images_path = os.path.join(self.config.output_path, 'images/')
        self.save_models_path = os.path.join(self.config.output_path, 'models/')

        if self.config.cuda != '-1':
            os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

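        # two translation directions: G_A2B is paired with D_B, G_B2A with D_A;
        # LSGAN-style MSE adversarial loss plus L1 cycle and identity losses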
        self._init_models()
        self._init_losses()
        self._init_optimizers()

        if self.config.resume_iteration == -1:
            self.start_iteration = 0
        elif self.config.resume_iteration >= 0:
            self.start_iteration = self.config.resume_iteration
            self._restore_model(self.config.resume_iteration)

    def _init_models(self):

        self.G_A2B = Generator(64, 9)
        self.D_B = Discriminator(self.config.image_size, 64, 4)
        self.G_A2B.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)
        self.G_A2B = torch.nn.DataParallel(self.G_A2B).to(self.device)
        self.D_B = torch.nn.DataParallel(self.D_B).to(self.device)

        self.G_B2A = Generator(64, 9)
        self.D_A = Discriminator(self.config.image_size, 64, 4)
        self.G_B2A.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.G_B2A = torch.nn.DataParallel(self.G_B2A).to(self.device)
        self.D_A = torch.nn.DataParallel(self.D_A).to(self.device)

    def _init_losses(self):
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

    def _init_optimizers(self):
        self.G_optimizer = optim.Adam(itertools.chain(self.G_B2A.parameters(), self.G_A2B.parameters()), lr=0.0002, betas=[0.5, 0.999])
        self.D_A_optimizer = optim.Adam(self.D_A.parameters(), lr=0.0002, betas=[0.5, 0.999])
        self.D_B_optimizer = optim.Adam(self.D_B.parameters(), lr=0.0002, betas=[0.5, 0.999])

        self.G_lr_decay = optim.lr_scheduler.StepLR(self.G_optimizer, step_size=50000, gamma=0.1)
        self.D_A_lr_decay = optim.lr_scheduler.StepLR(self.D_A_optimizer, step_size=50000, gamma=0.1)
        self.D_B_lr_decay = optim.lr_scheduler.StepLR(self.D_B_optimizer, step_size=50000, gamma=0.1)

    def _lr_decay_step(self, current_iteration):
        self.G_lr_decay.step(current_iteration)
        self.D_A_lr_decay.step(current_iteration)
        self.D_B_lr_decay.step(current_iteration)

    def _save_model(self, current_iteration):
        torch.save(self.G_A2B.state_dict(), os.path.join(self.save_models_path, 'G_A2B_{}.pkl'.format(current_iteration)))
        torch.save(self.G_B2A.state_dict(), os.path.join(self.save_models_path, 'G_B2A_{}.pkl'.format(current_iteration)))
        torch.save(self.D_A.state_dict(), os.path.join(self.save_models_path, 'D_A_{}.pkl'.format(current_iteration)))
        torch.save(self.D_B.state_dict(), os.path.join(self.save_models_path, 'D_B_{}.pkl'.format(current_iteration)))
        print 'Note: Successfully save model as {}'.format(current_iteration)

    def _restore_model(self, resume_iteration):
        self.G_A2B.load_state_dict(torch.load(os.path.join(self.save_models_path, 'G_A2B_{}.pkl'.format(resume_iteration))))
        self.G_B2A.load_state_dict(torch.load(os.path.join(self.save_models_path, 'G_B2A_{}.pkl'.format(resume_iteration))))
        self.D_A.load_state_dict(torch.load(os.path.join(self.save_models_path, 'D_A_{}.pkl'.format(resume_iteration))))
        self.D_B.load_state_dict(torch.load(os.path.join(self.save_models_path, 'D_B_{}.pkl'.format(resume_iteration))))
        print 'Note: Successfully resume model from {}'.format(resume_iteration)

    def train(self):

        fixed_real_a = self.loaders.train_a_iter.next_one()
        fixed_real_b = self.loaders.train_b_iter.next_one()
        # size the real/fake target maps with one dummy forward pass
        ones = torch.ones_like(self.D_A(fixed_real_a))
        zeros = torch.zeros_like(self.D_A(fixed_real_b))
        for ii in xrange(16/self.config.batch_size-1):
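            # keep drawing batches until 16 fixed samples per domain are
            # collected; they are only used for the periodic progress images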
            fixed_real_a_ = self.loaders.train_a_iter.next_one()
            fixed_real_b_ = self.loaders.train_b_iter.next_one()
            fixed_real_a = torch.cat([fixed_real_a, fixed_real_a_], dim=0)
            fixed_real_b = torch.cat([fixed_real_b, fixed_real_b_], dim=0)
        fixed_real_a, fixed_real_b = fixed_real_a.to(self.device), fixed_real_b.to(self.device)

        for iteration in range(self.start_iteration, 100000):

            # save images
            if iteration % 100 == 0:
                with torch.no_grad():
                    self.G_A2B = self.G_A2B.eval()
                    self.G_B2A = self.G_B2A.eval()
                    fake_a = self.G_B2A(fixed_real_b)
                    fake_b = self.G_A2B(fixed_real_a)
                    all_images = torch.cat([fixed_real_a, fake_b, fixed_real_b, fake_a], dim=0)
                    save_image((all_images.cpu()+1.0)/2.0,
                               os.path.join(self.save_images_path, 'images_{}.jpg'.format(iteration)), 16)
            # save model
            if iteration % 1000 == 0:
                self._save_model(iteration)

            # train
            self.G_A2B = self.G_A2B.train()
            self.G_B2A = self.G_B2A.train()
            self._lr_decay_step(iteration)

            #########################################################################################################
            #                                                 Data                                                  #
            #########################################################################################################
            real_a = self.loaders.train_a_iter.next_one()
            real_b = self.loaders.train_b_iter.next_one()
            real_a, real_b = real_a.to(self.device), real_b.to(self.device)

            #########################################################################################################
            #                                               Generator                                               #
            #########################################################################################################
            fake_a = self.G_B2A(real_b)
            fake_b = self.G_A2B(real_a)

            # gan loss
            gan_loss_a = self.criterion_GAN(self.D_A(fake_a), ones)
            gan_loss_b = self.criterion_GAN(self.D_B(fake_b), ones)
            gan_loss = (gan_loss_a + gan_loss_b) / 2.0

            # cycle loss
            cycle_loss_a = self.criterion_cycle(self.G_B2A(fake_b), real_a)
            cycle_loss_b = self.criterion_cycle(self.G_A2B(fake_a), real_b)
            cycle_loss = (cycle_loss_a + cycle_loss_b) / 2.0

            # identity loss
            identity_loss_a = self.criterion_identity(self.G_B2A(real_a), real_a)
            identity_loss_b = self.criterion_identity(self.G_A2B(real_b), real_b)
            identity_loss = (identity_loss_a + identity_loss_b) / 2.0

            # overall loss and optimize
            g_loss = gan_loss + 10 * cycle_loss + 5 * identity_loss

            self.G_optimizer.zero_grad()
            g_loss.backward(retain_graph=True)
            self.G_optimizer.step()

            #########################################################################################################
            #                                             Discriminator                                             #
            #########################################################################################################
            # discriminator a
            gan_loss_a_real = self.criterion_GAN(self.D_A(real_a), ones)
            gan_loss_a_fake = self.criterion_GAN(self.D_A(fake_a.detach()), zeros)
            gan_loss_a = (gan_loss_a_real + gan_loss_a_fake) / 2.0

            self.D_A_optimizer.zero_grad()
            gan_loss_a.backward()
            self.D_A_optimizer.step()

            # discriminator b
            gan_loss_b_real = self.criterion_GAN(self.D_B(real_b), ones)
            gan_loss_b_fake = self.criterion_GAN(self.D_B(fake_b.detach()), zeros)
            gan_loss_b = (gan_loss_b_real + gan_loss_b_fake) / 2.0

            self.D_B_optimizer.zero_grad()
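            # fake_b is detached above, so this backward pass only reaches D_B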
            gan_loss_b.backward()
            self.D_B_optimizer.step()

            print('[ITER:{}] [G_GAN:{}] [CYC:{}] [IDENT:{}] [D_GAN:{} {}]'.
                  format(iteration, gan_loss, cycle_loss, identity_loss, gan_loss_a, gan_loss_b))
--------------------------------------------------------------------------------
/images/cyclegan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangguanan/Pytorch-Image-Translation-GANs/19cd3e4f473b76f48f457b00d4a8efbedbbe59eb/images/cyclegan.jpg
--------------------------------------------------------------------------------
/images/pixel2pixel.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangguanan/Pytorch-Image-Translation-GANs/19cd3e4f473b76f48f457b00d4a8efbedbbe59eb/images/pixel2pixel.jpg
--------------------------------------------------------------------------------
/images/stargan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangguanan/Pytorch-Image-Translation-GANs/19cd3e4f473b76f48f457b00d4a8efbedbbe59eb/images/stargan.jpg
--------------------------------------------------------------------------------
/pixel2pixel/loaders.py:
--------------------------------------------------------------------------------
'''
Load train and test datasets.
'''

import glob
import os

import torch.utils.data as data
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class Loaders:
    '''
    Initialize dataloaders.
    '''

    def __init__(self, config):

        self.dataset_path = config.dataset_path
        self.image_size = config.image_size
        self.batch_size = config.batch_size

        self.transforms = transforms.Compose([transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        train_set = ImageFolder(os.path.join(self.dataset_path, 'train/'), self.transforms)
        test_set = ImageFolder(os.path.join(self.dataset_path, 'test/'), self.transforms)

        self.train_loader = data.DataLoader(dataset=train_set, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)
        self.test_loader = data.DataLoader(dataset=test_set, batch_size=self.batch_size, shuffle=False, num_workers=4, drop_last=False)


class ImageFolder(Dataset):
    '''
    Load paired images given the path: each file stores the target in its
    left half and the source in its right half.
    '''

    def __init__(self, path, transform):
        self.transform = transform
        self.samples = sorted(glob.glob(os.path.join(path, '*.*')))

    def __getitem__(self, index):

        sample = Image.open(self.samples[index])

        w, h = sample.size
        sample_target = sample.crop((0, 0, w/2, h))
        sample_source = sample.crop((w/2, 0, w, h))

        sample_source = self.transform(sample_source)
        sample_target = self.transform(sample_target)

        return sample_source, sample_target

    def __len__(self):
        return len(self.samples)
--------------------------------------------------------------------------------
/pixel2pixel/main.py:
--------------------------------------------------------------------------------
'''
Define hyper-parameters,
init dataset and model,
run.
'''

import argparse
import os

from solver import Solver
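# Loaders (loaders.py) yields (source, target) pairs by splitting each paired
# facade image into its right (source) and left (target) halves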
from loaders import Loaders


def main(config):

    # Environment Parameters
    os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
    if not os.path.exists(config.output_path):
        os.makedirs(os.path.join(config.output_path, 'images/'))
        os.makedirs(os.path.join(config.output_path, 'models/'))

    # Initialize Dataset
    loaders = Loaders(config)

    # Initialize Pixel2Pixel and train
    solver = Solver(config, loaders)
    solver.train()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Environment Configuration
    parser.add_argument('--cuda', type=str, default='-1', help="'-1' runs on the CPU; '2' uses GPU 2; '2,3,4' uses GPUs 2, 3 and 4")
    parser.add_argument('--output_path', type=str, default='out/facades/')

    # Dataset Configuration
    parser.add_argument('--dataset_path', type=str, default='data/facades/')
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=1)

    # Model Configuration
    parser.add_argument('--conv_dim', type=int, default=64)
    parser.add_argument('--layer_num', type=int, default=6)

    # Train Configuration
    parser.add_argument('--resume_epoch', type=int, default=-1, help='if -1, train from scratch; if >=0, resume from that checkpoint and continue training')

    # Test Configuration
    parser.add_argument('--test_epoch', type=int, default=100)
    parser.add_argument('--test_image', type=str, default='', help='if an image path, translate only that image; if a folder, translate every image in it')

    # main function
    config = parser.parse_args()
    main(config)
--------------------------------------------------------------------------------
/pixel2pixel/models.py:
--------------------------------------------------------------------------------
'''
Define models: Discriminator and Generator.
'''

import torch
import torch.nn as nn


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Discriminator(nn.Module):
    '''Conditional discriminator with PatchGAN: scores the generated/target
    image concatenated with its source image.'''

    def __init__(self, conv_dim, layer_num):
        super(Discriminator, self).__init__()

        layers = []

        # input layer
        layers.append(nn.Conv2d(3+3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        current_dim = conv_dim

        # hidden layers
        for i in xrange(1, layer_num):
            layers.append(nn.Conv2d(current_dim, current_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.InstanceNorm2d(current_dim*2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            current_dim *= 2

        self.model = nn.Sequential(*layers)

        # output layer
        self.conv_src = nn.Conv2d(current_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x, c):
        x = self.model(torch.cat([x, c], dim=1))
        out_src = self.conv_src(x)
        return out_src


class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()

        layers = []

        # each down block halves the spatial resolution with a stride-2 4x4 conv
        layers.append(nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False))
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()

        layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                  nn.InstanceNorm2d(out_size),
                  nn.ReLU(inplace=True)]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(Generator, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder

        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)
--------------------------------------------------------------------------------
/pixel2pixel/script.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

python main.py --cuda 0
--------------------------------------------------------------------------------
/pixel2pixel/solver.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
from torchvision.utils import save_image
import os

from models import Generator, Discriminator, weights_init_normal


class Solver:

    def __init__(self, config, loaders):

        # Parameters
        self.config = config
        self.loaders = loaders
        self.save_images_path = os.path.join(self.config.output_path, 'images/')
        self.save_models_path = os.path.join(self.config.output_path, 'models/')

        # Set Devices
        if self.config.cuda != '-1':
            os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        # Initialize
        self._init_models()
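        # adversarial loss is MSE (LSGAN-style); the L1 reconstruction term is
        # weighted 100x against the GAN term in train()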
        self._init_losses()
        self._init_optimizers()

        # Resume Model
        if self.config.resume_epoch == -1:
            self.start_epoch = 0
        elif self.config.resume_epoch >= 0:
            self.start_epoch = self.config.resume_epoch
            self._restore_model(self.config.resume_epoch)

    def _init_models(self):

        # Init Model
        self.generator = Generator()
        self.discriminator = Discriminator(self.config.conv_dim, self.config.layer_num)
        # Init Weights
        self.generator.apply(weights_init_normal)
        self.discriminator.apply(weights_init_normal)
        # Move model to device (GPU or CPU)
        self.generator = torch.nn.DataParallel(self.generator).to(self.device)
        self.discriminator = torch.nn.DataParallel(self.discriminator).to(self.device)

    def _init_losses(self):
        # Init GAN loss and reconstruction loss
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_recon = torch.nn.L1Loss()

    def _init_optimizers(self):
        # Init optimizers with the same hyper-parameters as DCGAN
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=0.0002, betas=[0.5, 0.999])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=[0.5, 0.999])
        # Set learning-rate decay
        self.g_lr_decay = optim.lr_scheduler.StepLR(self.g_optimizer, step_size=100, gamma=0.1)
        self.d_lr_decay = optim.lr_scheduler.StepLR(self.d_optimizer, step_size=100, gamma=0.1)

    def _lr_decay_step(self, current_epoch):
        self.g_lr_decay.step(current_epoch)
        self.d_lr_decay.step(current_epoch)

    def _save_model(self, current_epoch):
        # Save generator and discriminator
        torch.save(self.generator.state_dict(), os.path.join(self.save_models_path, 'G_{}.pkl'.format(current_epoch)))
        torch.save(self.discriminator.state_dict(), os.path.join(self.save_models_path, 'D_{}.pkl'.format(current_epoch)))
        print 'Note: Successfully save model as {}'.format(current_epoch)

    def _restore_model(self, resume_epoch):
        # Resume generator and discriminator
        self.discriminator.load_state_dict(torch.load(os.path.join(self.save_models_path, 'D_{}.pkl'.format(resume_epoch))))
        self.generator.load_state_dict(torch.load(os.path.join(self.save_models_path, 'G_{}.pkl'.format(resume_epoch))))
        print 'Note: Successfully resume model from {}'.format(resume_epoch)

    def train(self):

        # Load 16 images as a fixed batch for displaying and debugging
        fixed_source_images, fixed_target_images = next(iter(self.loaders.train_loader))
        ones = torch.ones_like(self.discriminator(fixed_source_images, fixed_source_images))
        zeros = torch.zeros_like(self.discriminator(fixed_source_images, fixed_source_images))
        for ii in xrange(16/self.config.batch_size-1):
            fixed_source_images_, fixed_target_images_ = next(iter(self.loaders.train_loader))
            fixed_source_images = torch.cat([fixed_source_images, fixed_source_images_], dim=0)
            fixed_target_images = torch.cat([fixed_target_images, fixed_target_images_], dim=0)
        fixed_source_images, fixed_target_images = fixed_source_images.to(self.device), fixed_target_images.to(self.device)

        # Train 200 epochs
        for epoch in xrange(self.start_epoch, 200):
            # Save images for debugging
            with torch.no_grad():
                self.generator = self.generator.eval()
                fake_images = self.generator(fixed_source_images)
                all_images = torch.cat([fixed_source_images, fake_images, fixed_target_images], dim=0)
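                # images were normalized to [-1, 1], so map back to [0, 1] before saving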
                save_image((all_images.cpu()+1.0)/2.0,
                           os.path.join(self.save_images_path, 'images_{}.jpg'.format(epoch)), 16)

            # Train
            self.generator = self.generator.train()
            self._lr_decay_step(epoch)
            for iteration, data in enumerate(self.loaders.train_loader):
                #########################################################################################################
                #                                          load a batch of data                                         #
                #########################################################################################################
                source_images, target_images = data
                source_images, target_images = source_images.to(self.device), target_images.to(self.device)

                #########################################################################################################
                #                                               Generator                                               #
                #########################################################################################################
                fake_images = self.generator(source_images)

                gan_loss = self.criterion_GAN(self.discriminator(fake_images, source_images), ones)
                recon_loss = self.criterion_recon(fake_images, target_images)
                g_loss = gan_loss + 100 * recon_loss

                self.g_optimizer.zero_grad()
                g_loss.backward(retain_graph=True)
                self.g_optimizer.step()

                #########################################################################################################
                #                                             Discriminator                                             #
                #########################################################################################################
                loss_real = self.criterion_GAN(self.discriminator(target_images, source_images), ones)
                loss_fake = self.criterion_GAN(self.discriminator(fake_images.detach(), source_images), zeros)
                d_loss = (loss_real + loss_fake) / 2.0

                self.d_optimizer.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                print('[EPOCH:{}/{}] [ITER:{}/{}] [D_GAN:{}] [G_GAN:{}] [RECON:{}]'.
                      format(epoch, 200, iteration, len(self.loaders.train_loader), d_loss, gan_loss, recon_loss))

            # Save model
            self._save_model(epoch)
--------------------------------------------------------------------------------
/stargan/loaders.py:
--------------------------------------------------------------------------------
from torch.utils import data
import torchvision.transforms as transforms
from PIL import Image
import torch
import os
import random


class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset."""

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Preprocess the CelebA attribute file."""
        lines = [line.rstrip() for line in open(self.attr_path, 'r')]
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(values[idx] == '1')

            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


class IterLoader:
    '''Wrap a DataLoader so that it restarts automatically once exhausted.'''

    def __init__(self, loader):
        self.loader = loader
        self.iter = iter(self.loader)

    def next_one(self):
        try:
            return next(self.iter)
        except StopIteration:
            self.iter = iter(self.loader)
            return next(self.iter)


class Loaders:

    def __init__(self, config):
        transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.CenterCrop(config.crop_size),
                                              transforms.Resize([config.image_size, config.image_size]),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        transform_test = transforms.Compose([transforms.CenterCrop(config.crop_size),
                                             transforms.Resize([config.image_size, config.image_size]),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        train_set = CelebA(config.dataset_path, config.attr_path, config.selected_attrs, transform_train, 'train')
        test_set = CelebA(config.dataset_path, config.attr_path, config.selected_attrs, transform_test, 'test')
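        # CelebA itself holds the train/test split: after shuffling with a
        # fixed seed (1234), the first ~2000 entries become the test set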

        self.train_loader = data.DataLoader(dataset=train_set, batch_size=config.batch_size, shuffle=True, num_workers=1)
        self.test_loader = data.DataLoader(dataset=test_set, batch_size=config.batch_size, shuffle=False, num_workers=1)

        self.train_iter = IterLoader(self.train_loader)
--------------------------------------------------------------------------------
/stargan/main.py:
--------------------------------------------------------------------------------
import argparse
import os

from solver import Solver
from loaders import Loaders


def main(config):

    # environments
    os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
    if not os.path.exists(config.output_path):
        os.makedirs(os.path.join(config.output_path, 'images/'))
        os.makedirs(os.path.join(config.output_path, 'models/'))

    # dataset
    loaders = Loaders(config)

    # main
    solver = Solver(config, loaders)
    solver.train()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Environment Configuration
    parser.add_argument('--cuda', type=str, default='4,5,6,7')
    parser.add_argument('--output_path', type=str, default='out/celeba')

    # Dataset Configuration
    parser.add_argument('--dataset_path', type=str, default='data/img_align_celeba/')
    parser.add_argument('--attr_path', type=str, default='data/list_attr_celeba.txt')
    parser.add_argument('--selected_attrs', nargs='+', type=str,
                        default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
    parser.add_argument('--crop_size', type=int, default=178)
    parser.add_argument('--image_size', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=16)

    # Model Configuration
    parser.add_argument('--class_num', type=int, default=5)
    parser.add_argument('--conv_dim', type=int, default=64)
    parser.add_argument('--layer_num', type=int, default=6)

    # Weights Configuration
    parser.add_argument('--lambda_cls', type=float, default=1)
    parser.add_argument('--lambda_gp', type=float, default=10)
    parser.add_argument('--lambda_rec', type=float, default=10)
    parser.add_argument('--n_critics', type=int, default=5)

    # main function
    config = parser.parse_args()
    main(config)
--------------------------------------------------------------------------------
/stargan/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


class ResidualBlock(nn.Module):
    '''Residual Block with Instance Normalization'''

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
        )

    def forward(self, x):
        return self.model(x) + x


class Generator(nn.Module):
    '''Generator with down-sampling, several ResBlocks and up-sampling.
    Down/up-sampling is used to reduce computation. The one-hot domain label c
    is tiled spatially and concatenated with the input image (see forward).
    '''

    def __init__(self, class_num, conv_dim, layer_num):
        super(Generator, self).__init__()

        self.class_num = class_num

        layers = []

        # input layer
        layers.append(nn.Conv2d(in_channels=3+class_num, out_channels=conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # down sampling layers
        current_dims = conv_dim
        for i in xrange(2):
            layers.append(nn.Conv2d(current_dims, current_dims*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(current_dims*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            current_dims *= 2

        # residual layers
        for i in xrange(layer_num):
            layers.append(ResidualBlock(current_dims, current_dims))

        # up sampling layers
        for i in xrange(2):
            layers.append(nn.ConvTranspose2d(current_dims, current_dims//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(current_dims//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            current_dims = current_dims//2

        # output layer
        layers.append(nn.Conv2d(current_dims, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x, c):
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat([1, 1, x.size(2), x.size(3)])
        x = torch.cat([x, c], dim=1)
        return self.model(x)


class Discriminator(nn.Module):
    '''Discriminator with PatchGAN'''

    def __init__(self, image_size, conv_dim, layer_num, class_num):
        super(Discriminator, self).__init__()

        layers = []

        # input layer
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))
        current_dim = conv_dim

        # hidden layers
        for i in xrange(1, layer_num):
            layers.append(nn.Conv2d(current_dim, current_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            current_dim *= 2

        self.model = nn.Sequential(*layers)

        # output layer
        # compute image source (real/fake) and image class
        kernel_size = int(image_size / 2**layer_num)
        self.conv_src = nn.Conv2d(current_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_cls = nn.Conv2d(current_dim, class_num, kernel_size=kernel_size, bias=False)

    def forward(self, x):
        x = self.model(x)
        out_src = self.conv_src(x)
        out_cls = self.conv_cls(x)
        out_cls = out_cls.view(out_cls.size(0), out_cls.size(1))
        return out_src, out_cls
--------------------------------------------------------------------------------
/stargan/script.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

python main.py --cuda 0
--------------------------------------------------------------------------------
/stargan/solver.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.utils import save_image
import os
import numpy as np

from models import Generator, Discriminator
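
# StarGAN couples a WGAN-GP critic with an auxiliary domain classifier: the
# discriminator maximizes mean(D(real)) - mean(D(fake)) under a gradient
# penalty, and the generator is updated once every n_critics iterations
# (see train() below).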


class Solver:

    def __init__(self, config, loaders):

        self.config = config
        self.loaders = loaders

        self.save_images_path = os.path.join(self.config.output_path, 'images/')
        self.save_models_path = os.path.join(self.config.output_path, 'models/')

        if self.config.cuda != '-1':
            os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
            self.device = torch.device('cuda')
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda
            self.device = torch.device('cpu')

        self._init_models()
        self._init_optimizers()

    def _init_models(self):

        self.generator = Generator(self.config.class_num, self.config.conv_dim, self.config.layer_num)
        self.discriminator = Discriminator(self.config.image_size, self.config.conv_dim, self.config.layer_num, self.config.class_num)

        self.generator = torch.nn.DataParallel(self.generator).to(self.device)
        self.discriminator = torch.nn.DataParallel(self.discriminator).to(self.device)

    def criterion_cls(self, logit, target):
        return nn.functional.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def _init_optimizers(self):
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=[0.5, 0.999])
        self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0001, betas=[0.5, 0.999])

    def _save_model(self, current_epoch):
        torch.save(self.generator.state_dict(), os.path.join(self.save_models_path, 'G_{}.pkl'.format(current_epoch)))
        torch.save(self.discriminator.state_dict(), os.path.join(self.save_models_path, 'D_{}.pkl'.format(current_epoch)))
        print 'Note: Successfully save model as {}'.format(current_epoch)

    def _restore_model(self, resume_epoch):
        self.discriminator.load_state_dict(torch.load(os.path.join(self.save_models_path, 'D_{}.pkl'.format(resume_epoch))))
        self.generator.load_state_dict(torch.load(os.path.join(self.save_models_path, 'G_{}.pkl'.format(resume_epoch))))
        print 'Note: Successfully resume model from {}'.format(resume_epoch)

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices.
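        # the hair colors are mutually exclusive: a target label switches
        # exactly one of them on and zeroes the others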
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def train(self):

        x_fixed, c_org = self.loaders.train_iter.next_one()
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, 5, 'CelebA', ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])

        for iteration in xrange(200000):

            #########################################################################################################
            #                                              Save Images                                              #
            #########################################################################################################
            if iteration % 100 == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.generator(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.save_images_path, '{}-images.jpg'.format(iteration+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            #########################################################################################################
            #                                            Load Batch Data                                            #
            #########################################################################################################
            self.generator = self.generator.train()
            real_images, real_labels = self.loaders.train_iter.next_one()
            target_labels = real_labels[torch.randperm(real_labels.size(0))]
            real_images, real_labels, target_labels = real_images.to(self.device), real_labels.to(self.device), target_labels.to(self.device)

            #########################################################################################################
            #                                             Discriminator                                             #
            #########################################################################################################
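            # the critic is trained every iteration on three terms: the
            # Wasserstein real/fake losses, a domain classification loss on
            # real images, and a gradient penalty on random interpolations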
            out_src, out_cls = self.discriminator(real_images)
            d_loss_real = -torch.mean(out_src)
            d_loss_cls = self.criterion_cls(out_cls, real_labels)

            fake_images = self.generator(real_images, target_labels)
            out_src, _ = self.discriminator(fake_images.detach())
            d_loss_fake = torch.mean(out_src)

            alpha = torch.rand(real_images.size(0), 1, 1, 1).to(self.device)
            images_hat = (alpha * real_images.data + (1 - alpha) * fake_images.data).requires_grad_(True)
            out_src, _ = self.discriminator(images_hat)
            d_loss_gp = self.gradient_penalty(out_src, images_hat)

            d_loss = d_loss_real + d_loss_fake + \
                     self.config.lambda_cls * d_loss_cls + \
                     self.config.lambda_gp * d_loss_gp

            self.d_optimizer.zero_grad()
            d_loss.backward()
            self.d_optimizer.step()

            #########################################################################################################
            #                                               Generator                                               #
            #########################################################################################################
            if iteration % self.config.n_critics == 0:
                fake_images = self.generator(real_images, target_labels)
                out_src, out_cls = self.discriminator(fake_images)
                g_loss_fake = -torch.mean(out_src)
                g_loss_cls = self.criterion_cls(out_cls, target_labels)

                rec_images = self.generator(fake_images, real_labels)
                g_loss_rec = torch.mean(torch.abs(real_images-rec_images))

                g_loss = g_loss_fake + self.config.lambda_rec * g_loss_rec + self.config.lambda_cls * g_loss_cls

                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

            print 'Iter: {}, [D] gan: {}, cls: {}, gp: {}; [G]: gan: {}, rec: {}, cls: {}'.\
                format(iteration, (d_loss_real+d_loss_fake).data, d_loss_cls.data, d_loss_gp.data,
                       g_loss_fake.data, g_loss_rec.data, g_loss_cls.data)

            if iteration % 1000 == 0:
                self._save_model(iteration)
--------------------------------------------------------------------------------
/stargan/tools.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


def label2onehot(labels, dim, cuda):
    """Convert label indices to one-hot vectors."""
    batch_size = labels.size(0)
    out = torch.zeros(batch_size, dim)
    out[np.arange(batch_size), labels.long()] = 1
    if cuda:
        out = out.to(torch.device('cuda'))
    return out
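
# A quick sanity check of label2onehot (hypothetical usage, CPU only):
#   >>> label2onehot(torch.tensor([1, 3]), dim=5, cuda=False)
#   tensor([[0., 1., 0., 0., 0.],
#           [0., 0., 0., 1., 0.]])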