├── .gitignore
├── LICENSE
├── data_loader.py
├── docker
│   └── Dockerfile
├── image_converter.py
├── loss.py
├── network.py
├── readme.md
├── settings.json
├── train_gan.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
output/*
__pycache__

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2019 Takehiro Araki.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python


import random
from pathlib import Path
import numpy as np
from PIL import Image
import concurrent.futures as futures


def chunks(lst, chunk_size):
    num_chunks = len(lst) // chunk_size
    for i in range(num_chunks):
        start = i * chunk_size
        yield lst[start:start + chunk_size]


class TrainDataLoader:
    def __init__(self, paths, settings):
        self.paths = paths

        self.flip = settings["flip"]
        self.color_shift = settings["color_shift"]

    def generate(self, batch_size, width, height):
        num_workers = 5

        with futures.ThreadPoolExecutor(num_workers) as executor:
            tasks = []
            while True:
                np.random.shuffle(self.paths)
                for paths in chunks(self.paths, batch_size):
                    newtask = executor.submit(self.load_images, paths, width, height)
                    tasks.append(newtask)
                    if len(tasks) < num_workers:
                        continue
                    task = tasks.pop(0)
                    result = task.result(100)
                    if result is not None:
                        yield result

    def load_images(self, paths, width, height):
        array = np.empty([len(paths), height, width, 3])
        for i, path in enumerate(paths):
            try:
                # convert() guards against grayscale/RGBA inputs
                image = Image.open(str(path)).convert("RGB")
            except OSError:
                # drop the whole batch if any image is unreadable
                return None

            image = image.resize((width, height), Image.LANCZOS)
            image = np.array(image, dtype=float)
            image /= 255

            if self.flip:
                if bool(random.getrandbits(1)):
                    image = np.flip(image, 1)

            if self.color_shift:
                v = np.random.choice([0.7, 1, 1.3])
                image = image ** v

            array[i] = image

        return array
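
# Usage sketch for TrainDataLoader (the paths and settings dict here are
# hypothetical; the training script actually uses LabeledDataLoader below):
#
#   paths = list(Path("../images/ffhq").glob("**/*.png"))
#   loader = TrainDataLoader(paths, {"flip": True, "color_shift": False})
#   for batch in loader.generate(batch_size=16, width=64, height=64):
#       ...  # batch: float ndarray [16, 64, 64, 3] in [0, 1]
#
# generate() keeps up to num_workers decode tasks in flight so training does
# not stall on image I/O, and silently drops any batch whose images failed
# to load.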

class LabeledDataLoader:

    def __init__(self, settings):
        super().__init__()
        self.flip = settings["data_augmentation"]["flip"]
        self.color_shift = settings["data_augmentation"]["color_shift"]

        self.labels = settings["labels"]
        self.label_size = len(self.labels)

        img_root = Path(__file__).parent.joinpath("../images")
        self.data = []
        for i, chara in enumerate(self.labels):
            chara_root = img_root.joinpath(chara)
            for path in chara_root.glob("**/*.png"):
                self.data.append((path, i))

    def generate(self, batch_size, width, height):
        num_workers = 5

        with futures.ThreadPoolExecutor(num_workers) as executor:
            tasks = []
            while True:
                np.random.shuffle(self.data)
                for data in chunks(self.data, batch_size):
                    newtask = executor.submit(self.load_images, data, width, height)
                    tasks.append(newtask)
                    if len(tasks) < num_workers:
                        continue
                    task = tasks.pop(0)
                    result = task.result(100)
                    if result is not None:
                        yield result

    def load_images(self, data, width, height):
        images = np.empty([len(data), height, width, 3])
        labels = np.array([label for _, label in data])
        for i, row in enumerate(data):
            path, label = row
            try:
                # convert() guards against grayscale/RGBA inputs
                image = Image.open(str(path)).convert("RGB")
            except OSError:
                return None

            image = image.resize((width, height), Image.LANCZOS)
            image = np.array(image, dtype=float)
            image /= 255

            if self.flip:
                if bool(random.getrandbits(1)):
                    image = np.flip(image, 1)

            if self.color_shift:
                v = np.random.choice([0.7, 1, 1.3])
                image = image ** v

            images[i] = image

        return images, labels

--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/pytorch:19.01-py3

RUN pip install pillow scikit-image tensorboard tensorboardx matplotlib

CMD /bin/bash

--------------------------------------------------------------------------------
/image_converter.py:
--------------------------------------------------------------------------------
import numpy as np


class RGBConverter:
    def to_train_data(self, images):
        # images is in [0, 1]
        return images * 2 - 1

    def from_generator_output(self, images):
        # images is in (about) [-1, 1]
        return np.clip((images + 1) / 2, 0, 1)


class YUVConverter:
    # ITU-R BT.601 YCbCr
    def to_train_data(self, images):
        # images is in [0, 1]
        yuv = np.zeros_like(images, dtype=float)
        yuv[:, 0] = 0.299 * images[:, 0] + 0.587 * images[:, 1] + 0.114 * images[:, 2]
        yuv[:, 1] = -0.168736 * images[:, 0] - 0.331264 * images[:, 1] + 0.5 * images[:, 2]
        yuv[:, 2] = 0.5 * images[:, 0] - 0.418688 * images[:, 1] - 0.081312 * images[:, 2]
        yuv[:, 0] = yuv[:, 0] * 2 - 1
        yuv[:, 1:] *= 2
        return yuv

    def from_generator_output(self, images):
        # images is in [-1, 1]
        images = images.copy()
        images = np.clip(images, -1, 1)
        images[:, 0] = (images[:, 0] + 1) / 2
        images[:, 1:] /= 2

        rgb = np.zeros_like(images, dtype=float)
        rgb[:, 0] = images[:, 0] + 1.402 * images[:, 2]
        rgb[:, 1] = images[:, 0] - 0.344136 * images[:, 1] - 0.714136 * images[:, 2]
        rgb[:, 2] = images[:, 0] + 1.772 * images[:, 1]

        return np.clip(rgb, 0, 1)
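
# Round-trip sanity check (illustrative, not part of training; tolerance is
# an assumption to absorb the rounding of the matrix constants):
#
#   conv = YUVConverter()
#   rgb = np.random.rand(2, 3, 8, 8)          # [B, C, H, W] in [0, 1]
#   back = conv.from_generator_output(conv.to_train_data(rgb))
#   assert np.allclose(rgb, back, atol=1e-3)
#
# to_train_data maps Y from [0, 1] to [-1, 1] and doubles the chroma
# channels (from [-0.5, 0.5] to [-1, 1]) so all three channels roughly fill
# the generator's output range; from_generator_output inverts both steps
# before converting back to RGB.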

--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F


def d_lsgan_loss(discriminator, trues, fakes, labels, alpha):
    d_trues = discriminator.forward(trues, labels, alpha)
    d_fakes = discriminator.forward(fakes, labels, alpha)

    loss = F.mse_loss(d_trues, torch.ones_like(d_trues)) + F.mse_loss(d_fakes, torch.zeros_like(d_fakes))
    loss /= 2
    return loss


def g_lsgan_loss(discriminator, fakes, labels, alpha):
    d_fakes = discriminator.forward(fakes, labels, alpha)
    loss = F.mse_loss(d_fakes, torch.ones_like(d_fakes)) / 2
    return loss


def d_wgan_loss(discriminator, trues, fakes, labels, alpha):
    epsilon_drift = 1e-3
    lambda_gp = 10

    batch_size = fakes.size()[0]
    d_trues = discriminator.forward(trues, labels, alpha)
    d_fakes = discriminator.forward(fakes, labels, alpha)

    # Wasserstein distance estimate
    loss_wd = d_trues.mean() - d_fakes.mean()

    # gradient penalty
    epsilon = torch.rand(batch_size, 1, 1, 1, dtype=fakes.dtype, device=fakes.device)
    intpl = epsilon * fakes + (1 - epsilon) * trues
    intpl.requires_grad_()
    f = discriminator.forward(intpl, labels, alpha)
    grad = torch.autograd.grad(f.sum(), intpl, create_graph=True)[0]
    grad_norm = grad.view(batch_size, -1).norm(dim=1)
    loss_gp = lambda_gp * ((grad_norm - 1) ** 2).mean()

    # drift
    loss_drift = epsilon_drift * (d_trues ** 2).mean()

    loss = -loss_wd + loss_gp + loss_drift

    wd = loss_wd.item()

    return loss, wd


def g_wgan_loss(discriminator, fakes, labels, alpha):
    d_fakes = discriminator.forward(fakes, labels, alpha)
    loss = -d_fakes.mean()
    return loss


def d_logistic_loss(discriminator, trues, fakes, labels, alpha, r1gamma=10):
    d_fakes = discriminator.forward(fakes, labels, alpha)
    trues.requires_grad_()
    d_trues = discriminator.forward(trues, labels, alpha)
    loss = F.softplus(d_fakes).mean() + F.softplus(-d_trues).mean()

    # R1 regularization: penalize the discriminator's gradient on real samples
    if r1gamma > 0:
        grad = torch.autograd.grad(d_trues.sum(), trues, create_graph=True)[0]
        loss += r1gamma / 2 * (grad ** 2).sum(dim=(1, 2, 3)).mean()

    return loss


def g_logistic_loss(discriminator, fakes, labels, alpha):
    d_fakes = discriminator.forward(fakes, labels, alpha)
    return F.softplus(-d_fakes).mean()
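
# Typical call pattern (simplified from train_gan.py, which additionally
# scales these losses through apex amp):
#
#   d_loss, wd = d_wgan_loss(discriminator, trues, fakes.detach(), labels, alpha)
#   d_loss.backward(); d_opt.step()
#
#   g_loss = g_wgan_loss(discriminator, fakes, labels, alpha)
#   g_loss.backward(); g_opt.step()
#
# The discriminator sees detached fakes so its update does not backpropagate
# into the generator; the generator loss reuses the same forward interface
# with gradients flowing through `fakes`.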
settings["epsilon"] 30 | 31 | def forward(self, x): 32 | if self.num_concat == 0: 33 | return x 34 | 35 | group_size = self.group_size 36 | # x is [B, C, H, W] 37 | size = x.size() 38 | assert(size[0] % group_size == 0) 39 | M = size[0]//group_size 40 | 41 | x32 = x.to(torch.float32) 42 | 43 | y = x32.view(group_size, M, -1) # [group_size, M, -1] 44 | mean = y.mean(0, keepdim=True) # [1, M, -1] 45 | y = ((y - mean)**2).mean(0) # [M, -1] 46 | if not self.use_variance: 47 | y = (y + self.epsilon).sqrt() 48 | y = y.mean(1) # [M] 49 | y = y.repeat(group_size, 1) # [group_size, M] 50 | y = y.view(-1, 1, 1, 1) 51 | y1 = y.expand([size[0], 1, size[2], size[3]]) 52 | y1 = y1.to(x.dtype) 53 | 54 | if self.num_concat == 1: 55 | return torch.cat([x, y1], 1) 56 | 57 | # self.num_concat == 2 58 | y = x32.view(M, group_size, -1) # [M, group_size, -1] 59 | mean = y.mean(1, keepdim=True) # [M, 1, -1] 60 | y = ((y - mean) ** 2).mean(1) # [M, -1] 61 | if self.use_variance: 62 | y = (y + 1e-8).sqrt() 63 | y = y.mean(1, keepdim=True) # [M, 1] 64 | y = y.repeat(1, group_size) # [M, group_size] 65 | y = y.view(-1, 1, 1, 1) # [B, 1, 1, 1] 66 | y2 = y.expand([size[0], 1, size[2], size[3]]) 67 | y2 = y2.to(x.dtype) 68 | 69 | return torch.cat([x, y1, y2], 1) 70 | 71 | 72 | class Blur3x3(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | 76 | f = np.array([1, 2, 1], dtype=np.float32) 77 | f = f[None, :] * f[:, None] 78 | f /= np.sum(f) 79 | f = f.reshape([1, 1, 3, 3]) 80 | self.register_buffer("filter", torch.from_numpy(f)) 81 | 82 | def forward(self, x): 83 | ch = x.size(1) 84 | return F.conv2d(x, self.filter.expand(ch, -1, -1, -1), padding=1, groups=ch) 85 | 86 | 87 | class WSConv2d(nn.Module): 88 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, gain=np.sqrt(2)): 89 | super().__init__() 90 | weight = torch.empty(out_channels, in_channels, kernel_size, kernel_size) 91 | init.normal_(weight) 92 | self.weight = nn.Parameter(weight) 93 | scale = gain / np.sqrt(in_channels * kernel_size * kernel_size) 94 | self.register_buffer("scale", torch.tensor(scale)) 95 | 96 | self.bias = nn.Parameter(torch.zeros(out_channels)) 97 | self.stride = stride 98 | self.padding = padding 99 | 100 | def forward(self, x): 101 | scaled_weight = self.weight * self.scale 102 | return F.conv2d(x, scaled_weight, self.bias, self.stride, self.padding) 103 | 104 | 105 | class WSConvTranspose2d(nn.Module): 106 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, gain=np.sqrt(2)): 107 | super().__init__() 108 | weight = torch.empty(in_channels, out_channels, kernel_size, kernel_size) 109 | init.normal_(weight) 110 | self.weight = nn.Parameter(weight) 111 | scale = gain / np.sqrt(in_channels * kernel_size * kernel_size) 112 | self.register_buffer("scale", torch.tensor(scale)) 113 | 114 | self.bias = nn.Parameter(torch.zeros(out_channels)) 115 | self.stride = stride 116 | self.padding = padding 117 | 118 | def forward(self, x): 119 | scaled_weight = self.weight * self.scale 120 | return F.conv_transpose2d(x, scaled_weight, self.bias, self.stride, self.padding) 121 | 122 | 123 | class AdaIN(nn.Module): 124 | def __init__(self, dim, w_dim): 125 | super().__init__() 126 | self.dim = dim 127 | self.epsilon = 1e-8 128 | self.scale_transform = WSConv2d(w_dim, dim, 1, 1, 0, gain=1) 129 | self.bias_transform = WSConv2d(w_dim, dim, 1, 1, 0, gain=1) 130 | 131 | def forward(self, x, w): 132 | x = F.instance_norm(x, eps=self.epsilon) 133 | 134 | # scale 135 | scale = self.scale_transform(w) 136 


class NoiseLayer(nn.Module):
    def __init__(self, dim, size):
        super().__init__()

        self.fixed = False

        self.size = size
        self.register_buffer("fixed_noise", torch.randn([1, 1, size, size]))

        self.noise_scale = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        batch_size = x.size()[0]
        if self.fixed:
            noise = self.fixed_noise.expand(batch_size, -1, -1, -1)
        else:
            noise = torch.randn([batch_size, 1, self.size, self.size], dtype=x.dtype, device=x.device)
        return x + noise * self.noise_scale


class LatentTransformation(nn.Module):
    def __init__(self, settings, label_size):
        super().__init__()

        self.z_dim = settings["z_dim"]
        self.w_dim = settings["w_dim"]
        self.latent_normalization = PixelNormalizationLayer(settings) if settings["normalize_latents"] else None
        activation = nn.LeakyReLU(negative_slope=0.2)

        use_labels = settings["use_labels"]

        self.latent_transform = nn.Sequential(
            WSConv2d(self.z_dim * 2 if use_labels else self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.z_dim, 1, 1, 0),
            activation,
            WSConv2d(self.z_dim, self.w_dim, 1, 1, 0),
            activation
        )

        if use_labels:
            self.label_embed = nn.Embedding(label_size, self.z_dim)
        else:
            self.label_embed = None

    def forward(self, latent, labels):
        latent = latent.view([-1, self.z_dim, 1, 1])
        if self.label_embed is not None:
            labels = self.label_embed(labels).view([-1, self.z_dim, 1, 1])
            latent = torch.cat([latent, labels], dim=1)

        if self.latent_normalization is not None:
            latent = self.latent_normalization(latent)

        return self.latent_transform(latent)
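
# LatentTransformation is StyleGAN's 8-layer mapping network f: z -> w,
# implemented with 1x1 convolutions on a [B, z_dim, 1, 1] tensor, which is
# equivalent to fully connected layers. Usage sketch (shapes assumed from
# settings.json; labels may be None only when use_labels is false):
#
#   lt = LatentTransformation(settings["network"], label_size=1)
#   z = torch.randn(4, settings["network"]["z_dim"])
#   w = lt(z, None)    # w: [4, w_dim, 1, 1]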


class SynthFirstBlock(nn.Module):
    def __init__(self, start_dim, output_dim, w_dim, base_image_init, use_noise):
        super().__init__()

        # the learned constant input ("base image") of StyleGAN
        self.base_image = nn.Parameter(torch.empty(1, start_dim, 4, 4))
        if base_image_init == "zeros":
            nn.init.zeros_(self.base_image)
        elif base_image_init == "ones":
            nn.init.ones_(self.base_image)
        elif base_image_init == "zero_normal":
            nn.init.normal_(self.base_image, 0, 1)
        elif base_image_init == "one_normal":
            nn.init.normal_(self.base_image, 1, 1)
        else:
            raise ValueError(f"Invalid base_image_init: {base_image_init}")

        self.conv = WSConv2d(start_dim, output_dim, 3, 1, 1)

        self.noise1 = NoiseLayer(start_dim, 4)
        self.noise2 = NoiseLayer(output_dim, 4)
        if not use_noise:
            self.noise1.noise_scale.data.zero_()
            self.noise1.fixed = True
            self.noise2.noise_scale.data.zero_()
            self.noise2.fixed = True

        self.adain1 = AdaIN(start_dim, w_dim)
        self.adain2 = AdaIN(output_dim, w_dim)

        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, w1, w2):
        batch_size = w1.size()[0]

        x = self.base_image.expand(batch_size, -1, -1, -1)
        x = self.noise1(x)
        x = self.activation(x)
        x = self.adain1(x, w1)

        x = self.conv(x)
        x = self.noise2(x)
        x = self.activation(x)
        x = self.adain2(x, w2)

        return x


class SynthBlock(nn.Module):
    def __init__(self, input_dim, output_dim, output_size, w_dim, upsample_mode, use_blur, use_noise):
        super().__init__()

        self.conv1 = WSConv2d(input_dim, output_dim, 3, 1, 1)
        self.conv2 = WSConv2d(output_dim, output_dim, 3, 1, 1)
        if use_blur:
            self.blur = Blur3x3()
        else:
            self.blur = None

        self.noise1 = NoiseLayer(output_dim, output_size)
        self.noise2 = NoiseLayer(output_dim, output_size)
        if not use_noise:
            self.noise1.noise_scale.data.zero_()
            self.noise1.fixed = True
            self.noise2.noise_scale.data.zero_()
            self.noise2.fixed = True

        self.adain1 = AdaIN(output_dim, w_dim)
        self.adain2 = AdaIN(output_dim, w_dim)

        self.activation = nn.LeakyReLU(negative_slope=0.2)

        self.upsample_mode = upsample_mode

    def forward(self, x, w1, w2):
        x = F.interpolate(x, scale_factor=2, mode=self.upsample_mode)
        x = self.conv1(x)
        if self.blur is not None:
            x = self.blur(x)
        x = self.noise1(x)
        x = self.activation(x)
        x = self.adain1(x, w1)

        x = self.conv2(x)
        x = self.noise2(x)
        x = self.activation(x)
        x = self.adain2(x, w2)

        return x


class SynthesisModule(nn.Module):
    def __init__(self, settings):
        super().__init__()

        self.w_dim = settings["w_dim"]
        self.upsample_mode = settings["upsample_mode"]
        use_blur = settings["use_blur"]
        use_noise = settings["use_noise"]
        base_image_init = settings["base_image_init"]

        self.blocks = nn.ModuleList([
            SynthFirstBlock(256, 256, self.w_dim, base_image_init, use_noise),
            SynthBlock(256, 256, 8, self.w_dim, self.upsample_mode, use_blur, use_noise),
            SynthBlock(256, 256, 16, self.w_dim, self.upsample_mode, use_blur, use_noise),
            SynthBlock(256, 128, 32, self.w_dim, self.upsample_mode, use_blur, use_noise),
            SynthBlock(128, 64, 64, self.w_dim, self.upsample_mode, use_blur, use_noise),
            SynthBlock(64, 32, 128, self.w_dim, self.upsample_mode, use_blur, use_noise),
            SynthBlock(32, 16, 256, self.w_dim, self.upsample_mode, use_blur, use_noise)
        ])

        self.to_rgbs = nn.ModuleList([
            WSConv2d(256, 3, 1, 1, 0, gain=1),
            WSConv2d(256, 3, 1, 1, 0, gain=1),
            WSConv2d(256, 3, 1, 1, 0, gain=1),
            WSConv2d(128, 3, 1, 1, 0, gain=1),
            WSConv2d(64, 3, 1, 1, 0, gain=1),
            WSConv2d(32, 3, 1, 1, 0, gain=1),
            WSConv2d(16, 3, 1, 1, 0, gain=1)
        ])

        self.register_buffer("level", torch.tensor(1, dtype=torch.int32))

    def set_noise_fixed(self, fixed):
        for module in self.modules():
            if isinstance(module, NoiseLayer):
                module.fixed = fixed
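
    # forward() implements progressive growing: the blocks up to `level` are
    # active, and while a new level fades in the output is a lerp between the
    # upsampled RGB image of the previous level and the new level's RGB image:
    #
    #   out = (1 - alpha) * upsample(to_rgbs[level-2](x)) + alpha * to_rgbs[level-1](x2)
    #
    # alpha ramps from 0 to 1 during the fading stage (see train_gan.py).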
    def forward(self, w, alpha):
        # w is [batch_size, level*2, w_dim, 1, 1]
        level = self.level.item()

        x = self.blocks[0](w[:, 0], w[:, 1])

        if level == 1:
            x = self.to_rgbs[0](x)
            return x

        for i in range(1, level - 1):
            x = self.blocks[i](x, w[:, i * 2], w[:, i * 2 + 1])

        x2 = x
        x2 = self.blocks[level - 1](x2, w[:, level * 2 - 2], w[:, level * 2 - 1])
        x2 = self.to_rgbs[level - 1](x2)

        if alpha == 1:
            x = x2
        else:
            x1 = self.to_rgbs[level - 2](x)
            x1 = F.interpolate(x1, scale_factor=2, mode=self.upsample_mode)
            x = torch.lerp(x1, x2, alpha)

        return x

    def write_histogram(self, writer, step):
        for lv in range(self.level.item()):
            block = self.blocks[lv]
            for name, param in block.named_parameters():
                writer.add_histogram(f"g_synth_block{lv}/{name}", param.cpu().data.numpy(), step)

        for name, param in self.to_rgbs.named_parameters():
            writer.add_histogram(f"g_synth_block.torgb/{name}", param.cpu().data.numpy(), step)


class Generator(nn.Module):
    def __init__(self, settings, label_size):
        super().__init__()

        self.latent_transform = LatentTransformation(settings, label_size)
        self.synthesis_module = SynthesisModule(settings)
        self.style_mixing_prob = settings["style_mixing_prob"]

        # Truncation trick
        self.register_buffer("w_average", torch.zeros(1, settings["z_dim"], 1, 1))
        self.w_average_beta = 0.995
        self.trunc_w_layers = 8
        self.trunc_w_psi = 0.8

    def set_level(self, level):
        self.synthesis_module.level.fill_(level)

    def forward(self, z, labels, alpha):
        batch_size = z.size()[0]
        level = self.synthesis_module.level.item()

        w = self.latent_transform(z, labels)

        # update w_average
        if self.training:
            self.w_average = torch.lerp(w.mean(0, keepdim=True).detach(), self.w_average, self.w_average_beta)

        # w becomes [B, level*2, w_dim, 1, 1]; level*2 because each synthesis
        # block takes two style inputs. contiguous() materializes the
        # broadcast so the in-place writes below (style mixing, truncation)
        # do not alias memory.
        w = w.view(batch_size, 1, -1, 1, 1)\
            .expand(-1, level * 2, -1, -1, -1)\
            .contiguous()

        # style mixing
        if self.training and level >= 2:
            z_mix = torch.randn_like(z)
            w_mix = self.latent_transform(z_mix, labels)
            for batch_index in range(batch_size):
                if np.random.uniform(0, 1) < self.style_mixing_prob:
                    cross_point = np.random.randint(1, level * 2)
                    w[batch_index, cross_point:] = w_mix[batch_index]

        # Truncation trick
        if not self.training:
            w[:, self.trunc_w_layers:] = torch.lerp(self.w_average,
                                                    w[:, self.trunc_w_layers:],
                                                    self.trunc_w_psi)

        fakes = self.synthesis_module(w, alpha)

        return fakes

    def write_histogram(self, writer, step):
        for name, param in self.latent_transform.named_parameters():
            writer.add_histogram(f"g_lt/{name}", param.cpu().data.numpy(), step)
        self.synthesis_module.write_histogram(writer, step)
        writer.add_histogram("w_average", self.w_average.cpu().data.numpy(), step)
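
# Inference sketch (the checkpoint path is hypothetical; eval() disables
# style mixing and enables the truncation trick, while the per-layer noise
# stays random unless set_noise_fixed(True) is called):
#
#   gen = network.Generator(settings["network"], label_size)
#   gen.load_state_dict(torch.load("gs.pth"))
#   gen.set_level(7)                                  # 256x256
#   gen.eval()
#   z = torch.randn(1, settings["network"]["z_dim"])
#   labels = torch.zeros(1, dtype=torch.long)
#   with torch.no_grad():
#       fake = gen(z, labels, alpha=1)                # [1, 3, 256, 256]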

class DBlock(nn.Module):
    def __init__(self, input_dim, output_dim, use_blur):
        super().__init__()

        self.conv1 = WSConv2d(input_dim, output_dim, 3, 1, 1)
        self.conv2 = WSConv2d(output_dim, output_dim, 3, 1, 1)
        if use_blur:
            self.blur = Blur3x3()
        else:
            self.blur = None
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.activation(x)
        if self.blur is not None:
            x = self.blur(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = F.avg_pool2d(x, kernel_size=2)
        return x


class DLastBlock(nn.Module):
    def __init__(self, input_dim, label_size):
        super().__init__()

        self.conv1 = WSConv2d(input_dim, input_dim, 3, 1, 1)
        self.conv2 = WSConv2d(input_dim, input_dim, 4, 1, 0)
        self.conv3 = WSConv2d(input_dim, label_size, 1, 1, 0, gain=1)
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.conv3(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, settings, label_size):
        super().__init__()

        use_blur = settings["use_blur"]
        # reuse the generator's resampling mode for the fade-in downsample
        self.downsample_mode = settings["upsample_mode"]

        self.from_rgbs = nn.ModuleList([
            WSConv2d(3, 16, 1, 1, 0),
            WSConv2d(3, 32, 1, 1, 0),
            WSConv2d(3, 64, 1, 1, 0),
            WSConv2d(3, 128, 1, 1, 0),
            WSConv2d(3, 256, 1, 1, 0),
            WSConv2d(3, 256, 1, 1, 0),
            WSConv2d(3, 256, 1, 1, 0)
        ])

        self.use_labels = settings["use_labels"]
        if self.use_labels:
            self.label_size = label_size
        else:
            self.label_size = 1

        self.blocks = nn.ModuleList([
            DBlock(16, 32, use_blur),
            DBlock(32, 64, use_blur),
            DBlock(64, 128, use_blur),
            DBlock(128, 256, use_blur),
            DBlock(256, 256, use_blur),
            DBlock(256, 256, use_blur),
            DLastBlock(256, self.label_size)
        ])

        self.activation = nn.LeakyReLU(negative_slope=0.2)

        self.register_buffer("level", torch.tensor(1, dtype=torch.int32))

    def set_level(self, level):
        self.level.fill_(level)

    def forward(self, x, labels, alpha):
        level = self.level.item()

        if level == 1:
            x = self.from_rgbs[-1](x)
            x = self.activation(x)
            x = self.blocks[-1](x)
        else:
            x2 = self.from_rgbs[-level](x)
            x2 = self.activation(x2)
            x2 = self.blocks[-level](x2)

            if alpha == 1:
                x = x2
            else:
                x1 = F.interpolate(x, scale_factor=0.5, mode=self.downsample_mode)
                x1 = self.from_rgbs[-level + 1](x1)
                x1 = self.activation(x1)

                x = torch.lerp(x1, x2, alpha)

            for l in range(1, level):
                x = self.blocks[-level + l](x)

        if self.use_labels:
            x = x.view([-1, self.label_size])
            return torch.gather(x, 1, labels.view(-1, 1))
        else:
            return x.view([-1, 1])
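
# Resolution schedule implied by the block lists above (and by
# size = 2 ** (level + 1) in train_gan.py):
#
#   level:      1   2   3    4    5    6     7
#   resolution: 4   8   16   32   64   128   256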

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# StyleGAN on PyTorch

My PyTorch implementation of StyleGAN, with a generator of up to 256x256.

- [Original Paper](https://arxiv.org/abs/1812.04948)
- [Official implementation (TensorFlow)](https://github.com/NVlabs/stylegan)

## Run

### Docker image

There is a Dockerfile that includes all requirements. An NGC account is required to pull the base image.

https://www.nvidia.com/en-us/gpu-cloud/

### Settings

By default, the networks can generate up to 256x256 images.

1. Place images under `../images/{label}`
1. Edit settings.json
1. Run `python train_gan.py`

The directory structure must look like this:

```
├─ images
|  ├─ ffhq
|  |  ├─ image1.png
|  |  └─ ...
|  ├─ your custom label1
|  |  ├─ image1.png
|  |  └─ ...
|  ├─ your custom label2
|  |  ├─ image1.png
|  |  └─ ...
|  └─ ...
└─ stylegan (this repository)
```

## Result

![face](https://user-images.githubusercontent.com/12446914/56738130-f12cfc00-67a6-11e9-93ea-95abd08d5418.png)

## Training time

Training the 256x256 generator took about 3 days on a single RTX 2080 machine.
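
## Output

Each run writes to `output/{timestamp}/`: a copy of `settings.json`, TensorBoard event files, and (from level 5 up, between fading stages) model snapshots under `weights/{step}_lv{level}/` as `gen.pth`, `gs.pth` and `disc.pth`. To watch training:

```
tensorboard --logdir output
```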
len(settings["labels"]) 65 | generator = network.Generator(settings["network"], label_size).to(device, dtype) 66 | discriminator = network.Discriminator(settings["network"], label_size).to(device, dtype) 67 | 68 | # long-term average 69 | gs = network.Generator(settings["network"], label_size).to(device, dtype) 70 | gs.load_state_dict(generator.state_dict()) 71 | gs_beta = settings["gs_beta"] 72 | 73 | lt_learning_rate = settings["learning_rates"]["latent_transformation"] 74 | g_learning_rate = settings["learning_rates"]["generator"] 75 | d_learning_rate = settings["learning_rates"]["discriminator"] 76 | g_opt = optim.Adam([ 77 | {"params": generator.latent_transform.parameters(), "lr": lt_learning_rate}, 78 | {"params": generator.synthesis_module.parameters()} 79 | ], lr=g_learning_rate, betas=(0.0, 0.99), eps=1e-8) 80 | d_opt = optim.Adam(discriminator.parameters(), 81 | lr=d_learning_rate, betas=(0.0, 0.99), eps=1e-8) 82 | 83 | # train data 84 | loader = data_loader.LabeledDataLoader(settings) 85 | 86 | if settings["use_yuv"]: 87 | converter = image_converter.YUVConverter() 88 | else: 89 | converter = image_converter.RGBConverter() 90 | 91 | # parameters 92 | level = settings["start_level"] 93 | generator.set_level(level) 94 | discriminator.set_level(level) 95 | gs.set_level(level) 96 | fading = False 97 | alpha = 1 98 | step = 0 99 | 100 | # log 101 | writer = SummaryWriter(str(output_root)) 102 | test_rows = 12 103 | test_cols = 6 104 | test_zs = utils.create_test_z(test_rows, test_cols, z_dim) 105 | test_z0 = torch.from_numpy(test_zs[0]).to(test_device, test_dtype) 106 | test_z1 = torch.from_numpy(test_zs[1]).to(test_device, test_dtype) 107 | test_labels0 = torch.randint(0, loader.label_size, (1, test_cols)) 108 | test_labels0 = test_labels0.repeat(test_rows, 1).to(device) 109 | test_labels1 = torch.randint(0, loader.label_size, (test_rows, test_cols), device=test_device).view(-1) 110 | 111 | for loop in range(9999999): 112 | size = 2 ** (level+1) 113 | 114 | batch_size = settings["batch_sizes"][level-1] 115 | alpha_delta = batch_size / settings["num_images_in_stage"] 116 | 117 | image_count = 0 118 | 119 | for batch, labels in loader.generate(batch_size, size, size): 120 | # pre train 121 | step += 1 122 | image_count += batch_size 123 | if fading: 124 | alpha = min(1.0, alpha + alpha_delta) 125 | 126 | # data 127 | batch = batch.transpose([0, 3, 1, 2]) 128 | batch = converter.to_train_data(batch) 129 | trues = torch.from_numpy(batch).to(device, dtype) 130 | labels = torch.from_numpy(labels).to(device) 131 | 132 | # reset 133 | g_opt.zero_grad() 134 | d_opt.zero_grad() 135 | 136 | # === train discriminator === 137 | z = utils.create_z(batch_size, z_dim) 138 | z = torch.from_numpy(z).to(device, dtype) 139 | fakes = generator.forward(z, labels, alpha) 140 | fakes_nograd = fakes.detach() 141 | 142 | for param in discriminator.parameters(): 143 | param.requires_grad_(True) 144 | if loss_type == "wgan": 145 | d_loss, wd = loss.d_wgan_loss(discriminator, trues, fakes_nograd, labels, alpha) 146 | elif loss_type == "lsgan": 147 | d_loss = loss.d_lsgan_loss(discriminator, trues, fakes_nograd, labels, alpha) 148 | elif loss_type == "logistic": 149 | d_loss = loss.d_logistic_loss(discriminator, trues, fakes_nograd, labels, alpha) 150 | else: 151 | raise Exception(f"Invalid loss: {loss_type}") 152 | 153 | with amp_handle.scale_loss(d_loss, d_opt) as scaled_loss: 154 | scaled_loss.backward() 155 | d_opt.step() 156 | 157 | # === train generator === 158 | z = utils.create_z(batch_size, z_dim) 159 
            # === train generator ===
            z = utils.create_z(batch_size, z_dim)
            z = torch.from_numpy(z).to(device, dtype)
            fakes = generator.forward(z, labels, alpha)

            for param in discriminator.parameters():
                param.requires_grad_(False)
            if loss_type == "wgan":
                g_loss = loss.g_wgan_loss(discriminator, fakes, labels, alpha)
            elif loss_type == "lsgan":
                g_loss = loss.g_lsgan_loss(discriminator, fakes, labels, alpha)
            elif loss_type == "logistic":
                g_loss = loss.g_logistic_loss(discriminator, fakes, labels, alpha)
            else:
                raise Exception(f"Invalid loss: {loss_type}")

            with amp_handle.scale_loss(g_loss, g_opt) as scaled_loss:
                scaled_loss.backward()
                del scaled_loss
            g_opt.step()

            del trues, fakes, fakes_nograd

            # update gs
            for gparam, gsparam in zip(generator.parameters(), gs.parameters()):
                gsparam.data = (1 - gs_beta) * gsparam.data + gs_beta * gparam.data
            gs.w_average.data = (1 - gs_beta) * gs.w_average.data + gs_beta * generator.w_average.data

            # log every step
            print(f"lv{level}-{step}: "
                  f"a: {alpha:.5f} "
                  f"g: {g_loss.item():.7f} "
                  f"d: {d_loss.item():.7f} ")

            writer.add_scalar(f"lv{level}/loss_gen", g_loss.item(), global_step=step)
            writer.add_scalar(f"lv{level}/loss_disc", d_loss.item(), global_step=step)
            if loss_type == "wgan":
                writer.add_scalar(f"lv{level}/wd", wd, global_step=step)

            del d_loss, g_loss

            # histogram
            if settings["save_steps"]["histogram"] > 0 and step % settings["save_steps"]["histogram"] == 0:
                gs.write_histogram(writer, step)
                for name, param in discriminator.named_parameters():
                    writer.add_histogram(f"disc/{name}", param.cpu().data.numpy(), step)

            # image
            if step % settings["save_steps"]["image"] == 0 or alpha == 0:
                fading_text = "fading" if fading else "stabilizing"
                with torch.no_grad():
                    eval_gen = network.Generator(settings["network"], label_size).to(test_device, test_dtype).eval()
                    eval_gen.load_state_dict(gs.state_dict())
                    eval_gen.synthesis_module.set_noise_fixed(True)
                    fakes = eval_gen.forward(test_z0, test_labels0, alpha)
                    fakes = torchvision.utils.make_grid(fakes, nrow=test_cols, padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = converter.from_generator_output(fakes)
                    writer.add_image(f"lv{level}_{fading_text}/intpl", torch.from_numpy(fakes), step)
                    fakes = eval_gen.forward(test_z1, test_labels1, alpha)
                    fakes = torchvision.utils.make_grid(fakes, nrow=test_cols, padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = converter.from_generator_output(fakes)
                    writer.add_image(f"lv{level}_{fading_text}/random", torch.from_numpy(fakes), step)
                    del eval_gen
                # memory usage
                writer.add_scalar("memory_allocated(MB)", torch.cuda.memory_allocated() / (1024 * 1024), global_step=step)

            # model save
            if step % settings["save_steps"]["model"] == 0 and level >= 5 and not fading:
                savedir = weights_root.joinpath(f"{step}_lv{level}")
                savedir.mkdir()
                torch.save(generator.state_dict(), savedir.joinpath("gen.pth"))
                torch.save(gs.state_dict(), savedir.joinpath("gs.pth"))
                torch.save(discriminator.state_dict(), savedir.joinpath("disc.pth"))
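
            # Each level alternates two stages of num_images_in_stage images:
            # "fading", where alpha ramps from 0 to 1 as the new resolution
            # block is blended in, then "stabilizing", where the grown
            # network trains at alpha == 1 before the next level up.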
            # switch fading/stabilizing
            if image_count > settings["num_images_in_stage"]:
                if fading:
                    print("start stabilizing")
                    fading = False
                    alpha = 1
                    image_count = 0
                elif level < settings["max_level"]:
                    print(f"end lv: {level}")
                    break

        # level up
        if level < settings["max_level"]:
            level = level + 1
            generator.set_level(level)
            discriminator.set_level(level)
            gs.set_level(level)
            fading = True
            alpha = 0
            print(f"lv up: {level}")

            if settings["reset_optimizer"]:
                g_opt = optim.Adam([
                    {"params": generator.latent_transform.parameters(), "lr": lt_learning_rate},
                    {"params": generator.synthesis_module.parameters()}
                ], lr=g_learning_rate, betas=(0.0, 0.99), eps=1e-8)
                d_opt = optim.Adam(discriminator.parameters(),
                                   lr=d_learning_rate, betas=(0.0, 0.99), eps=1e-8)


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    main()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np


def slerp(val, low, high):
    # spherical linear interpolation between two latent vectors
    omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)))
    so = np.sin(omega)
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high


def create_z(size, dim):
    z = np.random.normal(0, 1, [size, dim])
    return z


def create_test_z(rows, cols, dim):
    # interpolation
    z1 = np.zeros([rows, cols, dim])
    z1_start = create_z(cols, dim)
    z1_end = create_z(cols, dim)
    for i in range(rows):
        val = i / (rows - 1)
        for j in range(cols):
            z1[i, j] = slerp(val, z1_start[j], z1_end[j])
    z1 = z1.reshape([-1, dim])

    # random
    z2 = create_z(rows * cols, dim)

    return z1, z2
--------------------------------------------------------------------------------