├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── .gitignore ├── doc ├── celeba_570600.png └── lsun_600000.png ├── model.py ├── sample └── .gitignore └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Progressive GAN in PyTorch 2 | Implementation of Progressive Growing of GANs (https://arxiv.org/abs/1710.10196) in PyTorch 3 | 4 | Currently implemented and tested up to 128x128 images. 5 | 6 | Usage: 7 | 8 | > python train.py -d {celeba, lsun} PATH 9 | 10 | Currently the CelebA and LSUN datasets are supported. (Warning: Using the LSUN dataset requires a vast amount of time to create the index cache.) 
# --- repository dump: non-code text that shared these lines, kept verbatim ---
# /README.md (continued):
#   ## Sample
#
#   * Sample from the model trained on CelebA
#
#   ![Sample of the model trained on CelebA](doc/celeba_570600.png)
#
#   * Sample from the model trained on LSUN (dog)
#
#   ![Sample of the model trained using LSUN (dog)](doc/lsun_600000.png)
# /checkpoint/.gitignore:
#   https://raw.githubusercontent.com/rosinality/progressive-gan-pytorch/672f27ef154944284ee6077c856fb98671ae0bea/checkpoint/.gitignore
# /doc/celeba_570600.png:
#   https://raw.githubusercontent.com/rosinality/progressive-gan-pytorch/672f27ef154944284ee6077c856fb98671ae0bea/doc/celeba_570600.png
# /doc/lsun_600000.png:
#   https://raw.githubusercontent.com/rosinality/progressive-gan-pytorch/672f27ef154944284ee6077c856fb98671ae0bea/doc/lsun_600000.png
# --- /model.py ----------------------------------------------------------------
"""Generator and discriminator networks for progressive GAN training.

Both networks are stacks of six ``ConvBlock`` stages.  ``step`` selects how
many stages are active (output/input resolution is ``4 * 2 ** step``) and
``alpha`` in ``[0, 1)`` blends the newest stage with the previous resolution.
"""
import torch

from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.autograd import Variable  # noqa: F401  (unused; kept so `model.Variable` still resolves)

from math import sqrt


def init_linear(linear):
    """Xavier-normal initialise ``linear.weight`` and zero its bias, if any."""
    init.xavier_normal_(linear.weight)
    if linear.bias is not None:
        linear.bias.data.zero_()


def init_conv(conv, glu=True):
    """Kaiming-normal initialise ``conv.weight`` and zero its bias, if any.

    ``glu`` is accepted for interface compatibility and is not used.
    """
    init.kaiming_normal_(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()


class SpectralNorm:
    """Forward pre-hook dividing a weight by an estimate of its top singular value.

    The raw parameter lives in ``<name>_orig`` and the running left singular
    vector estimate in ``<name>_u``; ``<name>`` itself is recomputed before
    every forward pass with one power-iteration step.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return ``(normalised_weight, updated_u)`` for ``module``."""
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        weight_mat = weight.contiguous().view(size[0], -1)
        # `u` is a plain attribute, so it does not follow `module.to(...)`;
        # align it with the weight here (the original only handled CUDA).
        u = u.to(device=weight_mat.device, dtype=weight_mat.dtype)
        # One power-iteration step: v <- W^T u, u <- W v (both normalised).
        v = weight_mat.t() @ u
        v = v / v.norm()
        u = weight_mat @ v
        u = u / u.norm()
        # sigma = u^T W v has shape (1, 1) and broadcasts over the matrix.
        # NOTE(review): as in the original, u and v stay in the autograd
        # graph when sigma is formed; only the stored `u` is detached.
        weight_sn = weight_mat / (u.t() @ weight_mat @ v)
        weight_sn = weight_sn.view(*size)

        return weight_sn, u.detach()

    @staticmethod
    def apply(module, name):
        """Install the hook on ``module`` for the parameter called ``name``."""
        fn = SpectralNorm(name)

        weight = getattr(module, name)
        # Replace the parameter `<name>` by `<name>_orig`; `<name>` becomes a
        # plain tensor attribute rewritten by the hook.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        input_size = weight.size(0)
        u = torch.randn(input_size, 1) * 0.1
        setattr(module, name + '_u', u)
        setattr(module, name, fn.compute_weight(module)[0])

        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight_sn, u = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)


def spectral_norm(module, name='weight'):
    """Apply :class:`SpectralNorm` to ``module`` in place and return it."""
    SpectralNorm.apply(module, name)

    return module


class EqualLR:
    """Forward pre-hook that rescales a weight by ``sqrt(2 / fan_in)``.

    The raw parameter lives in ``<name>_orig``; ``<name>`` is recomputed
    before every forward pass (it does not exist until the first forward).
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return ``<name>_orig`` scaled by ``sqrt(2 / fan_in)``."""
        weight = getattr(module, self.name + '_orig')
        # Elements feeding one output unit: in_channels * kernel elements.
        fan_in = weight[0].numel()

        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Install the hook on ``module`` for the parameter called ``name``."""
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        setattr(module, self.name, self.compute_weight(module))


def equal_lr(module, name='weight'):
    """Apply :class:`EqualLR` to ``module`` in place and return it."""
    EqualLR.apply(module, name)

    return module


class PixelNorm(nn.Module):
    """Normalise each spatial position to unit mean square over dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        # 1e-8 keeps the division finite for all-zero feature vectors.
        return input / torch.sqrt(mean_square + 1e-8)


class SpectralNormConv2d(nn.Module):
    """``nn.Conv2d`` (same arguments) with Kaiming init and spectral norm."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        init_conv(conv)  # also copes with bias=False
        self.conv = spectral_norm(conv)

    def forward(self, input):
        return self.conv(input)


class EqualConv2d(nn.Module):
    """``nn.Conv2d`` (same arguments) with N(0, 1) weights and :class:`EqualLR`."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class ConvBlock(nn.Module):
    """Two convolutions, each followed by ``LeakyReLU(0.2)``.

    Args:
        in_channel: input channels of the first convolution.
        out_channel: output channels of both convolutions.
        kernel_size, padding: geometry of the first convolution.
        kernel_size2, padding2: geometry of the second convolution; each
            defaults to the corresponding first-convolution value.
        pixel_norm: insert ``PixelNorm`` after each convolution (ignored
            when ``spectral_norm`` is true).
        spectral_norm: use ``SpectralNormConv2d`` instead of ``EqualConv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding,
                 kernel_size2=None, padding2=None,
                 pixel_norm=True, spectral_norm=False):
        super().__init__()

        kernel2 = kernel_size if kernel_size2 is None else kernel_size2
        pad2 = padding if padding2 is None else padding2

        conv_cls = SpectralNormConv2d if spectral_norm else EqualConv2d
        use_pixel_norm = pixel_norm and not spectral_norm

        # Layer order (and therefore nn.Sequential indices / state_dict keys)
        # matches the original hand-written variants.
        layers = []
        for channels, kernel, pad in ((in_channel, kernel_size, padding),
                                      (out_channel, kernel2, pad2)):
            layers.append(conv_cls(channels, out_channel, kernel, padding=pad))
            if use_pixel_norm:
                layers.append(PixelNorm())
            layers.append(nn.LeakyReLU(0.2))

        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)


class Generator(nn.Module):
    """Label-conditioned progressive generator.

    Args:
        code_dim: size of the latent code.  It is not used to size any
            layer: the first block is hard-wired to 512 input channels, so
            ``code_dim + n_label`` must equal 512.
        n_label: number of classes; labels are embedded into ``n_label`` dims.
    """

    def __init__(self, code_dim=512 - 10, n_label=10):
        super().__init__()

        self.label_embed = nn.Embedding(n_label, n_label)
        self.code_norm = PixelNorm()
        self.label_embed.weight.data.normal_()
        # First block: 1x1 -> 4x4 (kernel 4, padding 3), then a 3x3 conv.
        self.progression = nn.ModuleList([ConvBlock(512, 512, 4, 3, 3, 1),
                                          ConvBlock(512, 512, 3, 1),
                                          ConvBlock(512, 512, 3, 1),
                                          ConvBlock(512, 512, 3, 1),
                                          ConvBlock(512, 256, 3, 1),
                                          ConvBlock(256, 128, 3, 1)])

        self.to_rgb = nn.ModuleList([nn.Conv2d(512, 3, 1),
                                     nn.Conv2d(512, 3, 1),
                                     nn.Conv2d(512, 3, 1),
                                     nn.Conv2d(512, 3, 1),
                                     nn.Conv2d(256, 3, 1),
                                     nn.Conv2d(128, 3, 1)])

    def forward(self, input, label, step=0, alpha=-1):
        """Generate images of side ``4 * 2 ** step``.

        Args:
            input: latent codes, shape ``(N, code)``.
            label: integer class indices, shape ``(N,)``.
            step: index of the last active block (0..5).
            alpha: when ``0 <= alpha < 1`` and ``step > 0``, the output is
                ``(1 - alpha) * upsampled previous-stage RGB + alpha * new RGB``.
                Any other value disables blending.

        Returns:
            Tensor of shape ``(N, 3, 4 * 2 ** step, 4 * 2 ** step)``.
        """
        code = self.code_norm(input)
        embedded = self.label_embed(label)
        out = torch.cat([code, embedded], 1).unsqueeze(2).unsqueeze(3)

        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            if i > 0 and step > 0:
                # Nearest-neighbour 2x upsampling (default mode).
                upsample = F.interpolate(out, scale_factor=2)
                out = conv(upsample)

            else:
                out = conv(out)

            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](upsample)
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out


class Discriminator(nn.Module):
    """Progressive critic with an auxiliary class head.

    Blocks are stored from highest resolution (index 0) to the final 4x4
    block (last index), i.e. in the reverse order of the generator.
    """

    def __init__(self, n_label=10):
        super().__init__()

        self.progression = nn.ModuleList([ConvBlock(128, 256, 3, 1,
                                                    pixel_norm=False,
                                                    spectral_norm=False),
                                          ConvBlock(256, 512, 3, 1,
                                                    pixel_norm=False,
                                                    spectral_norm=False),
                                          ConvBlock(512, 512, 3, 1,
                                                    pixel_norm=False,
                                                    spectral_norm=False),
                                          ConvBlock(512, 512, 3, 1,
                                                    pixel_norm=False,
                                                    spectral_norm=False),
                                          ConvBlock(512, 512, 3, 1,
                                                    pixel_norm=False,
                                                    spectral_norm=False),
                                          # 513 = 512 features + 1 stddev map
                                          ConvBlock(513, 512, 3, 1, 4, 0,
                                                    pixel_norm=False,
                                                    spectral_norm=False)])

        self.from_rgb = nn.ModuleList([nn.Conv2d(3, 128, 1),
                                       nn.Conv2d(3, 256, 1),
                                       nn.Conv2d(3, 512, 1),
                                       nn.Conv2d(3, 512, 1),
                                       nn.Conv2d(3, 512, 1),
                                       nn.Conv2d(3, 512, 1)])

        self.n_layer = len(self.progression)

        self.linear = nn.Linear(512, 1 + n_label)

    def forward(self, input, step=0, alpha=-1):
        """Score images of side ``4 * 2 ** step``.

        Args:
            input: images, shape ``(N, 3, H, W)``.
            step: resolution step the images were produced at (0..5).
            alpha: when ``0 <= alpha < 1`` and ``step > 0``, the output of
                the newest block is blended with features taken from the
                2x-downsampled image.

        Returns:
            ``(score, class_logits)`` with shapes ``(N,)`` and ``(N, n_label)``.
        """
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                # Minibatch standard deviation over the *feature map*.  The
                # biased estimator plus epsilon stays finite for a batch of
                # one, where the previous `input.std(0)` produced NaN.
                centered = out - out.mean(0, keepdim=True)
                out_std = torch.sqrt((centered ** 2).mean(0) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(
                    out.size(0), 1, out.size(2), out.size(3))
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                out = F.avg_pool2d(out, 2)

                if i == step and 0 <= alpha < 1:
                    skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)
                    out = (1 - alpha) * skip_rgb + alpha * out

        # (N, 512, 1, 1) -> (N, 512)
        out = out.squeeze(2).squeeze(2)
        out = self.linear(out)

        return out[:, 0], out[:, 1:]


# --- repository dump: non-code text that shared these lines, kept verbatim ---
# /sample/.gitignore:
#   https://raw.githubusercontent.com/rosinality/progressive-gan-pytorch/672f27ef154944284ee6077c856fb98671ae0bea/sample/.gitignore
# /train.py:
#   1 | from tqdm import tqdm 2 | import
"""Training script for the progressive GAN defined in ``model.py``.

Usage: ``python train.py -d {celeba, lsun} PATH``
"""
from tqdm import tqdm
import numpy as np
from PIL import Image
import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from model import Generator, Discriminator


n_label = 1
code_size = 512 - n_label
batch_size = 16
n_critic = 1

# Use the GPU when there is one; fall back to CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(description='Progressive Growing of GANs')
parser.add_argument('path', type=str, help='path of specified dataset')
parser.add_argument('-d', '--data', default='celeba', type=str,
                    choices=['celeba', 'lsun'],
                    help=('Specify dataset. '
                          'Currently CelebA and LSUN is supported'))

generator = Generator(code_size, n_label).to(device)
discriminator = Discriminator(n_label).to(device)
# Exponential moving average of the generator weights, used for sampling.
g_running = Generator(code_size, n_label).to(device)
g_running.train(False)

class_loss = nn.CrossEntropyLoss()


g_optimizer = optim.Adam(generator.parameters(), lr=0.001, betas=(0.0, 0.99))
d_optimizer = optim.Adam(
    discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))


def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    """Update ``model1`` in place: ``p1 = decay * p1 + (1 - decay) * p2``.

    Parameters are matched by name; ``decay=0`` copies ``model2`` into
    ``model1``.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    with torch.no_grad():
        for name, target in par1.items():
            target.mul_(decay).add_(par2[name].detach(), alpha=1 - decay)


def lsun_loader(path):
    """Return ``loader(transform) -> DataLoader`` over an LSUN class at ``path``.

    Every target is mapped to label 0.
    """
    def loader(transform):
        data = datasets.LSUNClass(
            path, transform=transform,
            target_transform=lambda x: 0)
        data_loader = DataLoader(data, shuffle=False, batch_size=batch_size,
                                 num_workers=4)

        return data_loader

    return loader


def celeba_loader(path):
    """Return ``loader(transform) -> DataLoader`` over an image folder at ``path``."""
    def loader(transform):
        data = datasets.ImageFolder(path, transform=transform)
        data_loader = DataLoader(data, shuffle=True, batch_size=batch_size,
                                 num_workers=4)

        return data_loader

    return loader


def sample_data(dataloader, image_size=4):
    """Yield ``(image, label)`` batches resized and cropped to ``image_size``.

    Images are scaled to ``[-1, 1]`` per channel.  The generator is
    exhausted after one pass over the data.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    yield from dataloader(transform)


def _save_image_grid(images, path, nrow):
    """Save ``images`` as a grid normalised from ``[-1, 1]``."""
    try:
        utils.save_image(images, path, nrow=nrow, normalize=True,
                         value_range=(-1, 1))
    except TypeError:
        # Older torchvision releases only accept the keyword `range`.
        utils.save_image(images, path, nrow=nrow, normalize=True,
                         range=(-1, 1))


def train(generator, discriminator, loader):
    """Run 600000 iterations of progressive WGAN-GP training.

    Schedule: each resolution step lasts until its iteration counter exceeds
    100000; ``alpha`` grows as ``0.00002 * iteration`` (capped at 1) and the
    data iterator is re-created once the counter passes 50000.  Uses the
    module-level ``g_running``, ``g_optimizer`` and ``d_optimizer``.

    Side effects: writes a sample grid to ``sample/`` every 100 iterations
    and pickles ``g_running`` to ``checkpoint/`` every 10000 iterations.
    """
    step = 0
    max_step = len(generator.progression) - 1
    dataset = sample_data(loader, 4 * 2 ** step)
    pbar = tqdm(range(600000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    iteration = 0
    stabilize = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 0.00002 * iteration)

        if not stabilize and iteration > 50000:
            dataset = sample_data(loader, 4 * 2 ** step)
            stabilize = True

        # Grow to the next resolution.  At the final step nothing is reset,
        # so alpha stays at 1 instead of restarting the fade-in.
        if iteration > 100000 and step < max_step:
            alpha = 0
            iteration = 0
            step += 1
            stabilize = False
            dataset = sample_data(loader, 4 * 2 ** step)

        try:
            real_image, label = next(dataset)

        except (OSError, StopIteration):
            # Epoch finished (or a read failed): start a fresh pass.
            dataset = sample_data(loader, 4 * 2 ** step)
            real_image, label = next(dataset)

        iteration += 1

        # ---- critic update -------------------------------------------------
        b_size = real_image.size(0)
        real_image = real_image.to(device)
        label = label.to(device)
        real_predict, real_class_predict = discriminator(
            real_image, step, alpha)
        # Critic score on real data minus a small penalty on its magnitude.
        real_predict = real_predict.mean() \
            - 0.001 * (real_predict ** 2).mean()
        # Maximise the real score == minimise its negation.
        (-real_predict).backward()

        fake_image = generator(
            torch.randn(b_size, code_size).to(device),
            label, step, alpha)
        fake_predict, fake_class_predict = discriminator(
            fake_image, step, alpha)
        fake_predict = fake_predict.mean()
        fake_predict.backward()

        # Gradient penalty on random interpolates of real and fake images.
        eps = torch.rand(b_size, 1, 1, 1).to(device)
        x_hat = eps * real_image.detach() + (1 - eps) * fake_image.detach()
        x_hat.requires_grad_(True)
        hat_predict, _ = discriminator(x_hat, step, alpha)
        grad_x_hat = grad(
            outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
        grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1)
                         .norm(2, dim=1) - 1) ** 2).mean()
        grad_penalty = 10 * grad_penalty
        grad_penalty.backward()
        grad_loss_val = grad_penalty.item()
        disc_loss_val = (real_predict - fake_predict).item()

        d_optimizer.step()

        # ---- generator update ----------------------------------------------
        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            input_class = torch.multinomial(
                torch.ones(n_label), batch_size, replacement=True).to(device)
            fake_image = generator(
                torch.randn(batch_size, code_size).to(device),
                input_class, step, alpha)

            predict, class_predict = discriminator(fake_image, step, alpha)

            loss = -predict.mean()
            gen_loss_val = loss.item()

            loss.backward()
            g_optimizer.step()
            accumulate(g_running, generator)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        # ---- periodic outputs ----------------------------------------------
        if (i + 1) % 100 == 0:
            images = []
            # One label per generated code (equals all-zero labels when
            # n_label == 1, as before).
            input_class = torch.arange(n_label).long().repeat(10).to(device)
            with torch.no_grad():
                for _ in range(5):
                    images.append(g_running(
                        torch.randn(n_label * 10, code_size).to(device),
                        input_class, step, alpha).cpu())
            _save_image_grid(
                torch.cat(images, 0),
                f'sample/{str(i + 1).zfill(6)}.png',
                nrow=n_label * 10)

        if (i + 1) % 10000 == 0:
            torch.save(g_running, f'checkpoint/{str(i + 1).zfill(6)}.model')

        pbar.set_description(
            (f'{i + 1}; G: {gen_loss_val:.5f}; D: {disc_loss_val:.5f};'
             f' Grad: {grad_loss_val:.5f}; Alpha: {alpha:.3f}'))


if __name__ == '__main__':
    args = parser.parse_args()

    # Start the running average as an exact copy of the generator.
    accumulate(g_running, generator, 0)

    loader_factories = {'celeba': celeba_loader, 'lsun': lsun_loader}
    loader = loader_factories[args.data](args.path)

    train(generator, discriminator, loader)