├── .gitignore
├── README.md
├── amp_support.py
├── config.py
├── config
│   └── ffhq.yaml
├── custom_layers.py
├── dataloader.py
├── images
│   ├── 128x128.png
│   ├── 16x16.png
│   ├── 256x256.png
│   ├── 32x32.png
│   ├── 64x64.png
│   ├── 8x8.png
│   ├── fp32_loss_d.png
│   ├── fp32_score.png
│   ├── mixed_loss_d.png
│   ├── mixed_score.png
│   ├── precision_speed.png
│   └── structure.png
├── inferencer.py
├── main.py
├── networks.py
├── tf_recorder.py
└── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

repo/
checkpoints/
*.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch implementation of [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948)

## Requirements

- Python 3
- PyTorch >= 1.0.0
- TensorBoardX
- fire
- apex [optional]
- pyyaml

## Usage

train
```
python main.py \
--config_file=path_to_config_file \
--checkpoint=path_to_checkpoint_file [default: '']
```

inference
```
python main.py \
--config_file=path_to_config_file \
--run_type=inference
```

The default configuration file is located in the config directory.

## Currently completed tasks

* [x] Progressive method
* [x] Tuning
* [x] Add mapping and styles
* [x] Remove traditional input
* [x] Add noise inputs
* [x] Mixing regularization

## Fake and real image score graphs

### fp32 precision
![fp32_score](images/fp32_score.png)

### mixed precision
![mixed_score](images/mixed_score.png)

There seems to be no difference in score between the two precisions.
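
For reference, the plotted scores are the mean raw discriminator outputs, and the losses use the non-saturating logistic GAN formulation. A minimal sketch of how `trainer.py` computes these values (`discriminator`, `real`, and `fake` stand in for the trainer's module and batches):

```python
import torch.nn.functional as F

d_real = discriminator(real, alpha=alpha)  # raw discriminator outputs (logits)
d_fake = discriminator(fake, alpha=alpha)

loss_d = F.softplus(-d_real).mean() + F.softplus(d_fake).mean()  # logistic loss
loss_g = F.softplus(-d_fake).mean()                              # non-saturating

real_score = d_real.mean().item()  # the values plotted above
fake_score = d_fake.mean().item()
```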

## Discriminator loss

### fp32 precision
![fp32_dloss](images/fp32_loss_d.png)

### mixed precision
![mixed_dloss](images/mixed_loss_d.png)

Under mixed precision there is a problem with R1 regularization, so training does not work properly. This also degrades the image samples, so it is better not to use mixed precision for now.

## Train speed

![precision_speed](images/precision_speed.png)

There is a clear speed difference between the two precisions, but the comparison is not yet meaningful because mixed-precision training does not run correctly.

## Inference Images

### 8x8 images
![8x8](images/8x8.png)
### 16x16 images
![16x16](images/16x16.png)
### 32x32 images
![32x32](images/32x32.png)
### 64x64 images
![64x64](images/64x64.png)
### 128x128 images
![128x128](images/128x128.png)
### 256x256 images
![256x256](images/256x256.png)

## Pretrained checkpoint

~[256x256](https://drive.google.com/file/d/1YDNeDD5G-BI5Zx5RGnlggBMFinp2z8OH/view?usp=sharing)~ See [#1](https://github.com/caffeinism/StyleGAN-pytorch/issues/1)
--------------------------------------------------------------------------------
/amp_support.py:
--------------------------------------------------------------------------------
import contextlib

# Thin compatibility layer: use NVIDIA apex.amp when it is installed,
# otherwise fall back to no-op implementations so fp32 training still works.
try:
    from apex import amp
except ImportError:
    amp = None

def scale_loss(*args, **kwargs):
    if amp:
        return amp.scale_loss(*args, **kwargs)
    else:
        return dummy_scale_loss(*args, **kwargs)

def initialize(models, optimizers, *args, **kwargs):
    if amp:
        return amp.initialize(models, optimizers, *args, **kwargs)
    else:
        return models, optimizers

@contextlib.contextmanager
def dummy_scale_loss(loss, *args, **kwargs):
    yield loss
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import yaml
from argparse import Namespace

class Config(Namespace):
    def __init__(self, filename):
        # safe_load avoids executing arbitrary YAML tags, and the context
        # manager closes the file handle.
        with open(filename, 'r') as f:
            config = yaml.safe_load(f)
        super(Config, self).__init__(**config)
--------------------------------------------------------------------------------
/config/ffhq.yaml:
--------------------------------------------------------------------------------
dataset_dir: ../ffhq
log_dir: ../log
nz: 512
style_depth: 8
const_channel: 512
lr: 0.001
betas: [0.0, 0.99]
eps: 1.0e-8
image_size: 4
batch_size: {'8':128, '16':64, '32':32, '64':16, '128':8, '256':4, '512':2, '1024':1}
lrs: {'128':0.0015, '256':0.002, '512':0.003, '1024':0.003}
generator_channels: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16]
discriminator_channels: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16]
log_iter: 100
phase_iter: 600000
n_cpu: {'8': 32, '16':32, '32':16, '64': 8, '128':4, '256':2, '512':0, '1024':0}
opt_level: O0
--------------------------------------------------------------------------------
/custom_layers.py:
--------------------------------------------------------------------------------
# some code copied from https://github.com/nashory/pggan-pytorch

import torch
import torch.nn as nn
from math import sqrt

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.eps = 1e-8

    def forward(self, x):
        return x / (torch.mean(x**2, dim=1, keepdim=True) + self.eps) ** 0.5


# for equalized learning rate
class EqualizedConv2d(nn.Module):
    def __init__(self, c_in, c_out, k_size, stride, pad):
        super(EqualizedConv2d, self).__init__()
        conv = nn.Conv2d(c_in, c_out, k_size, stride, pad)

        conv.weight.data.normal_()
        conv.bias.data.zero_()

        self.conv = equal_lr(conv)

    def forward(self, x):
        return self.conv(x)


class EqualizedLinear(nn.Module):
    def __init__(self, c_in, c_out):
        super(EqualizedLinear, self).__init__()
        linear = nn.Linear(c_in, c_out)

        linear.weight.data.normal_()
        linear.bias.data.zero_()

        self.linear = equal_lr(linear)

    def forward(self, x):
        return self.linear(x)


class AdaIn(nn.Module):
    def __init__(self, style_dim, channel):
        super(AdaIn, self).__init__()

        self.channel = channel

        self.instance_norm = nn.InstanceNorm2d(channel)
        self.linear = EqualizedLinear(style_dim, channel * 2)

    def forward(self, x, style):
        mu, sig = self.linear(style).chunk(2, dim=1)

        x = self.instance_norm(x)

        x = x * (sig.view(-1, self.channel, 1, 1) + 1) + mu.view(-1, self.channel, 1, 1)  # affine transform

        return x

class NoiseInjection_(nn.Module):
    def __init__(self, channel):
        super(NoiseInjection_, self).__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, x, noise):
        return x + self.weight * noise

class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super(NoiseInjection, self).__init__()

        injection = NoiseInjection_(channel)
        self.injection = equal_lr(injection)

    def forward(self, x, noise):
        return self.injection(x, noise)

class minibatch_stddev_layer(nn.Module):
    def __init__(self, group_size=4, num_new_features=1):
        super(minibatch_stddev_layer, self).__init__()
        self.group_size = group_size
        self.num_new_features = num_new_features

    def forward(self, x):
        group_size = min(self.group_size, x.size(0))
        origin_shape = x.shape

        # split the batch B into groups of size G
        y = x.view(
            group_size,
            -1,
            self.num_new_features,
            origin_shape[1] // self.num_new_features,
            origin_shape[2],
            origin_shape[3]
        )

        # calculate stddev over each group
        y = torch.sqrt(torch.mean((y - torch.mean(y, dim=0, keepdim=True)) ** 2, dim=0) + 1e-8)
        # [B/G, F, C/F, H, W]
        y = torch.mean(y, dim=[2, 3, 4], keepdim=True)
        # [B/G, F, 1, 1, 1]
        y = torch.squeeze(y, dim=2)
        # [B/G, F, 1, 1]
        y = y.repeat(group_size, 1, origin_shape[2], origin_shape[3])
        # [B, F, H, W]

        return torch.cat([x, y], dim=1)

class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)

    return module
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

class Dataloader:
    # Note: grow() must be called once before iterating, since it creates the
    # dataset and sets batch_size / n_cpu for the current resolution.
    def __init__(self, dataset_dir, batch_sizes, max_tick, n_cpu):
        self.dataset_dir = dataset_dir
        self.batch_sizes = batch_sizes
        self.img_size = 4
        self.max_tick = max_tick
        self.checkpoint = 0
        self.n_cpus = n_cpu

    def __iter__(self):
        return DataIter(self.dataset, self.batch_size, self.max_tick, self.checkpoint, self.n_cpu)

    def set_checkpoint(self, checkpoint_tick):
        self.checkpoint = checkpoint_tick

    def grow(self):
        self.checkpoint = 0
        self.img_size *= 2
        self.batch_size = self.batch_sizes[str(self.img_size)]
        self.n_cpu = self.n_cpus[str(self.img_size)]

        self.dataset = ImageFolder(root=self.dataset_dir, transform=transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))

    def __len__(self):
        return (self.max_tick - self.checkpoint) // self.batch_size

class DataIter:
    def __init__(self, dataset, batch_size, max_tick, checkpoint, n_cpu):
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size,
            shuffle=True, drop_last=True, num_workers=n_cpu,
        )
        self.iter = iter(self.dataloader)
        self.tick = self.checkpoint = checkpoint
        self.batch_size = batch_size
        self.max_tick = max_tick

    def __next__(self):
        if self.tick >= self.max_tick:
            raise StopIteration

        # restart the underlying loader when an epoch ends
        try:
            data = next(self.iter)
        except StopIteration:
            self.iter = iter(self.dataloader)
            data = next(self.iter)

        self.tick += self.batch_size

        return data, self.tick

    def __len__(self):
        return (self.max_tick - self.checkpoint) // self.batch_size
--------------------------------------------------------------------------------
/images/128x128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/128x128.png
--------------------------------------------------------------------------------
/images/16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/16x16.png
--------------------------------------------------------------------------------
/images/256x256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/256x256.png
--------------------------------------------------------------------------------
/images/32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/32x32.png
--------------------------------------------------------------------------------
/images/64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/64x64.png
--------------------------------------------------------------------------------
/images/8x8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/8x8.png
--------------------------------------------------------------------------------
/images/fp32_loss_d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/fp32_loss_d.png
--------------------------------------------------------------------------------
/images/fp32_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/fp32_score.png
--------------------------------------------------------------------------------
/images/mixed_loss_d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/mixed_loss_d.png
--------------------------------------------------------------------------------
/images/mixed_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/mixed_score.png
--------------------------------------------------------------------------------
/images/precision_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/precision_speed.png
--------------------------------------------------------------------------------
/images/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caffeinism/StyleGAN-pytorch/62686f015d6e831aa72118d252910bb874e672d2/images/structure.png
--------------------------------------------------------------------------------
/inferencer.py:
--------------------------------------------------------------------------------
from networks import Generator
import torch
import os.path
import torchvision.utils as vutils
import torch.nn.functional as F

class Inferencer:
    def __init__(self, generator_channels, nz, style_depth):
        self.nz = nz
        self.generator = Generator(generator_channels, nz, style_depth).cuda()

    def inference(self, n):
        test_z = torch.randn(n, self.nz).cuda()
        with torch.no_grad():
            # grow to 8x8 first, then load each saved checkpoint in turn
            self.grow()
            img_size = 8
            filename = 'checkpoints/{}x{}_last.pth'.format(img_size, img_size)
            while os.path.isfile(filename):
                self.load_checkpoint(img_size, filename)

                self.generator.eval()
                fake = self.generator(test_z, alpha=1)
                fake = (fake + 1) * 0.5
                fake = torch.clamp(fake, min=0.0, max=1.0)
                fake = F.interpolate(fake, size=(256, 256))

                vutils.save_image(fake, 'images/{}x{}.png'.format(img_size, img_size))

                self.grow()
                img_size *= 2
                filename = 'checkpoints/{}x{}_last.pth'.format(img_size, img_size)


    def grow(self):
        self.generator.grow()
        self.generator.cuda()

    def load_checkpoint(self, img_size, filename):
        checkpoint = torch.load(filename)

        print('load {}x{} checkpoint'.format(checkpoint['img_size'], checkpoint['img_size']))
        while img_size < checkpoint['img_size']:
            self.grow()
            img_size *= 2

        self.generator.load_state_dict(checkpoint['generator'])
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from config import Config
import fire

def main(config_file, run_type='train', checkpoint=''):
    # pylint: disable=no-member
    config = Config(config_file)

    print(config)

    if run_type == 'train':
        from trainer import Trainer
        trainer = Trainer(
            dataset_dir=config.dataset_dir,
            log_dir=config.log_dir,
            generator_channels=config.generator_channels,
            discriminator_channels=config.discriminator_channels,
            nz=config.nz,
            style_depth=config.style_depth,
            lrs=config.lrs,
            betas=config.betas,
            eps=config.eps,
            phase_iter=config.phase_iter,
            batch_size=config.batch_size,
            n_cpu=config.n_cpu,
            opt_level=config.opt_level
        )
        trainer.run(
            log_iter=config.log_iter,
            checkpoint=checkpoint
        )
    elif run_type == 'inference':
        from inferencer import Inferencer
        inferencer = Inferencer(
            generator_channels=config.generator_channels,
            nz=config.nz,
            style_depth=config.style_depth,
        )
        inferencer.inference(n=8)
    else:
        raise NotImplementedError

if __name__ == '__main__':
    fire.Fire(main)
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from custom_layers import EqualizedConv2d, EqualizedLinear, AdaIn, minibatch_stddev_layer, PixelNorm, NoiseInjection
import random

class Generator(nn.Module):
    def __init__(self, channels, style_dim, style_depth):
        super(Generator, self).__init__()

        self.style_dim = style_dim
        self.now_growth = 1
        self.channels = channels

        self.model = UpBlock(channels[0], channels[1], style_dim, prev=None)

        layers = [PixelNorm()]
        for _ in range(style_depth):
            layers.append(EqualizedLinear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))

        self.style_mapper = nn.Sequential(*layers)

    def forward(self, z, alpha):
        if type(z) not in (tuple, list):
            w = self.style_mapper(z)
            w = [w for _ in range(self.now_growth)]
        else:
            assert len(z) == 2  # for now, only mixing of two styles is supported
            w1, w2 = self.style_mapper(z[0]), self.style_mapper(z[1])
            point = random.randint(1, self.now_growth - 1)  # requires now_growth >= 2
            # layer_0 ~ layer_p: style from w1
            # layer_p ~ layer_n: style from w2
            w = [w1 for _ in range(point)] + [w2 for _ in range(point, self.now_growth)]

        x = self.model(x=None, style=w, alpha=alpha)
        return x

    def grow(self):
        in_c, out_c = self.channels[self.now_growth], self.channels[self.now_growth + 1]
        self.model = UpBlock(in_c, out_c, self.style_dim, prev=self.model)
        self.now_growth += 1

        return self


class Discriminator(nn.Module):
    def __init__(self, channels):
        super(Discriminator, self).__init__()

        self.now_growth = 1
        self.channels = channels

        self.model = DownBlock(channels[1], channels[0], next=None)

    def forward(self, x, alpha):
        return self.model(x=x, alpha=alpha)

    def grow(self):
        in_c, out_c = self.channels[self.now_growth + 1], self.channels[self.now_growth]
        self.model = DownBlock(in_c, out_c, next=self.model)
        self.now_growth += 1

        return self


class UpBlock(nn.Module):
    def __init__(self, in_channel, out_channel, style_dim, prev=None):
        super(UpBlock, self).__init__()

        self.prev = prev

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        if prev:
            self.conv1 = EqualizedConv2d(in_channel, out_channel, 3, 1, 1)
        else:
            self.input = nn.Parameter(torch.randn(1, out_channel, 4, 4))

        self.noisein1 = NoiseInjection(out_channel)
        self.lrelu1 = nn.LeakyReLU(0.2)
        self.adain1 = AdaIn(style_dim, out_channel)

        self.conv2 = EqualizedConv2d(out_channel, out_channel, 3, 1, 1)
        self.noisein2 = NoiseInjection(out_channel)
        self.lrelu2 = nn.LeakyReLU(0.2)
        self.adain2 = AdaIn(style_dim, out_channel)

        self.to_rgb = EqualizedConv2d(out_channel, 3, 1, 1, 0)

    # if this is the last block (0 <= alpha <= 1), return an RGB image (3 channels);
    # otherwise return the feature map for the next block
    def forward(self, x, style, alpha=-1.0, noise=None):
        if self.prev:  # if this module has a prev block, forward through it first
            w, style = style[-1], style[:-1]  # pop the last style for this block
            prev_x = x = self.prev(x, style)

            x = self.upsample(x)

            x = self.conv1(x)
        else:  # else start from the learned constant input
            w = style[0]
            x = self.input.repeat(w.size(0), 1, 1, 1)

        # NOTE: the model figure in the paper injects different noise into the
        # noise1 and noise2 layers, but this model uses just one noise tensor per two layers
        noise = noise if noise is not None else torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device)

        x = self.noisein1(x, noise)
        x = self.lrelu1(x)
        x = self.adain1(x, w)

        x = self.conv2(x)
        x = self.noisein2(x, noise)
        x = self.lrelu2(x)
        x = self.adain2(x, w)

        if 0.0 <= alpha < 1.0:
            prev_rgb = self.prev.to_rgb(self.upsample(prev_x))
            x = alpha * self.to_rgb(x) + (1 - alpha) * prev_rgb
        elif alpha == 1:
            x = self.to_rgb(x)

        return x


class DownBlock(nn.Module):
    def __init__(self, in_channel, out_channel, next=None):
        super(DownBlock, self).__init__()

        self.next = next

        self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        # same as avgpool with kernel size 2

        if next:
            self.conv1 = EqualizedConv2d(in_channel, out_channel, 3, 1, 1)
            self.conv2 = EqualizedConv2d(out_channel, out_channel, 3, 1, 1)
        else:
            self.conv1 = nn.Sequential(
                minibatch_stddev_layer(),
                EqualizedConv2d(in_channel + 1, out_channel, 3, 1, 1),
            )
            self.conv2 = EqualizedConv2d(out_channel, out_channel, 4, 1, 0)

            self.linear = EqualizedLinear(out_channel, 1)

        self.lrelu1 = nn.LeakyReLU(0.2)
        self.lrelu2 = nn.LeakyReLU(0.2)

        self.from_rgb = EqualizedConv2d(3, in_channel, 1, 1, 0)

    def forward(self, x, alpha=-1.0):
        input = x

        if 0 <= alpha:
            x = self.from_rgb(x)

        x = self.conv1(x)
        x = self.lrelu1(x)

        x = self.conv2(x)
        x = self.lrelu2(x)

        if self.next:
            x = self.downsample(x)

            # fade the new block in during the transition phase
            if 0.0 <= alpha < 1.0:
                input = self.downsample(input)
                x = alpha * x + (1 - alpha) * self.next.from_rgb(input)

            x = self.next(x)
        else:
            x = x.view(x.size(0), -1)
            x = self.linear(x)

        return x
--------------------------------------------------------------------------------
/tf_recorder.py:
--------------------------------------------------------------------------------
from tensorboardX import SummaryWriter
import os
import os.path

# https://github.com/nashory/pggan-pytorch/blob/master/tf_recorder.py
class tf_recorder:
    def __init__(self, network_name, log_dir):
        os.makedirs(log_dir, exist_ok=True)
        # pick the first unused run directory: <log_dir>/<network_name>_<i>
        for i in range(1000):
            self.targ = os.path.join(log_dir, '{}_{}'.format(network_name, i))
            if not os.path.exists(self.targ):
                self.writer = SummaryWriter(self.targ)
                break

    def renew(self, subname):
        self.writer = SummaryWriter('{}_{}'.format(self.targ, subname))
        self.niter = 0

    def add_scalar(self, index, val):
        self.writer.add_scalar(index, val, self.niter)

    def add_scalars(self, index, group_dict):
        self.writer.add_scalars(index, group_dict, self.niter)

    def add_images(self, tag, images):
        self.writer.add_images(tag, images, self.niter)

    def iter(self, tick=1):
        self.niter += tick
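
# A usage sketch, mirroring how trainer.py drives this class:
#   tb = tf_recorder('StyleGAN', log_dir)  # writes to log_dir/StyleGAN_0, _1, ...
#   tb.renew('8x8')                        # fresh writer for each resolution phase
#   tb.add_scalar('loss_d', loss_d)
#   tb.add_images('fake', fake_images)
#   tb.iter(batch_size)                    # advance the global step by one batch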
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
from networks import Generator, Discriminator
import torch
import torch.optim as optim
import torch.nn.functional as F
import tf_recorder as tensorboard
from tqdm import tqdm
from dataloader import Dataloader
from torch.autograd import grad
import amp_support as amp
import random

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

class DummyDataParallel(torch.nn.Module):
    def __init__(self, module):
        super(DummyDataParallel, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def DataParallel(module):
    if torch.cuda.device_count() > 1:
        return torch.nn.DataParallel(module)
    else:
        return DummyDataParallel(module)  # for a consistent model structure

def cuda(module):
    if torch.cuda.device_count() > 0:
        return module.cuda()
    else:
        print('Warning: CUDA is not available; running on CPU.')
        return module

class Trainer:
    def __init__(self, dataset_dir, log_dir, generator_channels, discriminator_channels, nz, style_depth, lrs, betas, eps,
                 phase_iter, batch_size, n_cpu, opt_level):
        self.nz = nz
        # each resolution phase sees phase_iter fade-in samples plus phase_iter stabilization samples
        self.dataloader = Dataloader(dataset_dir, batch_size, phase_iter * 2, n_cpu)

        self.generator = cuda(DataParallel(Generator(generator_channels, nz, style_depth)))
        self.discriminator = cuda(DataParallel(Discriminator(discriminator_channels)))

        self.tb = tensorboard.tf_recorder('StyleGAN', log_dir)

        self.phase_iter = phase_iter
        self.lrs = lrs
        self.betas = betas
        self.eps = eps

        self.opt_level = opt_level

    def generator_trainloop(self, batch_size, alpha):
        requires_grad(self.generator, True)
        requires_grad(self.discriminator, False)

        # mixing regularization: with probability 0.9, feed two latent codes
        if random.random() < 0.9:
            z = [torch.randn(batch_size, self.nz).cuda(),
                 torch.randn(batch_size, self.nz).cuda()]
        else:
            z = torch.randn(batch_size, self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss = F.softplus(-d_fake).mean()  # non-saturating logistic loss

        self.optimizer_g.zero_grad()
        with amp.scale_loss(loss, self.optimizer_g, loss_id=0) as scaled_loss:
            scaled_loss.backward()
        self.optimizer_g.step()

        return loss.item()

    def discriminator_trainloop(self, real, alpha):
        requires_grad(self.generator, False)
        requires_grad(self.discriminator, True)

        real.requires_grad = True
        self.optimizer_d.zero_grad()

        d_real = self.discriminator(real, alpha=alpha)
        loss_real = F.softplus(-d_real).mean()
        with amp.scale_loss(loss_real, self.optimizer_d, loss_id=1) as scaled_loss_real:
            scaled_loss_real.backward(retain_graph=True)

        # R1 regularization: penalize the gradient of d_real w.r.t. the real images
        grad_real = grad(
            outputs=d_real.sum(), inputs=real, create_graph=True
        )[0]
        grad_penalty = (
            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
        ).mean()
        grad_penalty = 10 / 2 * grad_penalty  # gamma = 10
        with amp.scale_loss(grad_penalty, self.optimizer_d, loss_id=1) as scaled_grad_penalty:
            scaled_grad_penalty.backward()

        # mixing regularization, as in generator_trainloop
        if random.random() < 0.9:
            z = [torch.randn(real.size(0), self.nz).cuda(),
                 torch.randn(real.size(0), self.nz).cuda()]
        else:
            z = torch.randn(real.size(0), self.nz).cuda()

        fake = self.generator(z, alpha=alpha)
        d_fake = self.discriminator(fake, alpha=alpha)
        loss_fake = F.softplus(d_fake).mean()
        with amp.scale_loss(loss_fake, self.optimizer_d, loss_id=1) as scaled_loss_fake:
            scaled_loss_fake.backward()

        loss = loss_real + loss_fake + grad_penalty  # report the unscaled total

        self.optimizer_d.step()

        return loss.item(), (d_real.mean().item(), d_fake.mean().item())

    def run(self, log_iter, checkpoint):
        global_iter = 0

        test_z = torch.randn(4, self.nz).cuda()

        if checkpoint:
            self.load_checkpoint(checkpoint)
        else:
            self.grow()

        while True:
            print('train {}x{} images...'.format(self.dataloader.img_size, self.dataloader.img_size))
            for iter, ((data, _), n_trained_samples) in enumerate(tqdm(self.dataloader), 1):
                real = data.cuda()
                # fade the new block in over the first phase_iter samples (no fade-in at 8x8)
                alpha = min(1, n_trained_samples / self.phase_iter) if self.dataloader.img_size > 8 else 1

                loss_d, (real_score, fake_score) = self.discriminator_trainloop(real, alpha)
                loss_g = self.generator_trainloop(real.size(0), alpha)

                if global_iter % log_iter == 0:
                    self.log(loss_d, loss_g, real_score, fake_score, test_z, alpha)

                # save a checkpoint roughly every quarter of the phase
                if iter % (len(self.dataloader) // 4 + 1) == 0:
                    self.save_checkpoint(n_trained_samples)

                global_iter += 1
                self.tb.iter(data.size(0))

            self.save_checkpoint()
            self.grow()


    def log(self, loss_d, loss_g, real_score, fake_score, test_z, alpha):
        with torch.no_grad():
            fake = self.generator.module(test_z, alpha=alpha)
            fake = (fake + 1) * 0.5
            fake = torch.clamp(fake, min=0.0, max=1.0)

        self.tb.add_scalar('loss_d', loss_d)
        self.tb.add_scalar('loss_g', loss_g)
        self.tb.add_scalar('real_score', real_score)
        self.tb.add_scalar('fake_score', fake_score)
        self.tb.add_images('fake', fake)

    def grow(self):
        self.discriminator = cuda(DataParallel(self.discriminator.module.grow()))
        self.generator = cuda(DataParallel(self.generator.module.grow()))
        self.dataloader.grow()
        self.tb.renew('{}x{}'.format(self.dataloader.img_size, self.dataloader.img_size))

        self.lr = self.lrs.get(str(self.dataloader.img_size), 0.001)
        self.style_lr = self.lr * 0.01  # the style mapper uses a 100x smaller learning rate

        self.optimizer_d = optim.Adam(params=self.discriminator.parameters(), lr=self.lr, betas=self.betas, eps=self.eps)
        self.optimizer_g = optim.Adam([
                {'params': self.generator.module.model.parameters(), 'lr': self.lr},
                {'params': self.generator.module.style_mapper.parameters(), 'lr': self.style_lr},
            ],
            betas=self.betas,
            eps=self.eps,
        )

        [self.generator, self.discriminator], [self.optimizer_g, self.optimizer_d] = amp.initialize(
            [self.generator, self.discriminator],
            [self.optimizer_g, self.optimizer_d],
            opt_level=self.opt_level,
            num_losses=2,
        )

    def save_checkpoint(self, tick='last'):
        torch.save({
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'generator_optimizer': self.optimizer_g.state_dict(),
            'discriminator_optimizer': self.optimizer_d.state_dict(),
            'img_size': self.dataloader.img_size,
            'tick': tick,
        }, 'checkpoints/{}x{}_{}.pth'.format(self.dataloader.img_size, self.dataloader.img_size, tick))
    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        print('load {}x{} checkpoint'.format(checkpoint['img_size'], checkpoint['img_size']))
        while self.dataloader.img_size < checkpoint['img_size']:
            self.grow()

        self.generator.load_state_dict(checkpoint['generator'])
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.optimizer_g.load_state_dict(checkpoint['generator_optimizer'])
        self.optimizer_d.load_state_dict(checkpoint['discriminator_optimizer'])

        # a 'last' tick means the phase finished; move on to the next resolution
        if checkpoint['tick'] == 'last':
            self.grow()
        else:
            self.dataloader.set_checkpoint(checkpoint['tick'])
            self.tb.iter(checkpoint['tick'])
--------------------------------------------------------------------------------