├── LICENSE
├── README.md
├── SGAN.ipynb
├── dataset_preview.py
├── evaluate.py
├── model.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Siskon
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StyleGAN-PyTorch
2 | 
3 | This is a simple but complete PyTorch implementation of Nvidia's Style-based GAN[3]. We have trained this model on [our new anime face dataset](https://github.com/SiskonEmilia/Anime-Wifu-Dataset) and a subset of FFHQ; you can download our pre-trained models to evaluate them or to continue training yourself.
4 | 
5 | ## Preview
6 | 
7 | *Not available yet.*
8 | 
9 | ### Overview
10 | 
11 | ### With and without noise
12 | 
13 | ### Style-mixing
14 | 
15 | ## Versions
16 | 
17 | We provide two versions of the implementation: `SGAN.ipynb` for Jupyter Notebook (with inline previews) and the `.py` files for CLI-only use.
18 | 
19 | With `SGAN.ipynb`, you can view the images generated by the model every `n_show_loss` iterations, while the `.py` version only saves them to the folder you specify. Apart from that, there is no difference between the two versions.
20 | 
21 | If you want to understand how the model works, we recommend reading the `.py` version, as we have refined its code structure and comments to make it more readable.
22 | 
23 | ## Parameters
24 | 
25 | As we do not provide any command-line options, parameters can only be changed inside the code to match your requirements.
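As a rough illustration (not an exact copy of the defaults), editing the hyperparameter block near the top of `train.py` — or the matching notebook cell — looks like the sketch below; the values shown are placeholders, and the table that follows describes each parameter:

```python
# Illustrative only: edit the existing constants near the top of train.py
# (or the corresponding cell of SGAN.ipynb). Values here are placeholders.
import os
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0'    # indexes of the GPUs to use
n_gpu  = 1                                  # keep consistent with the line above
device = torch.device('cuda:0')             # default device for tensors

learning_rate = {128: 0.0015, 256: 0.002}   # per-resolution learning rates
batch_size    = {4: 256, 8: 256, 16: 128}   # per-resolution batch sizes
mini_batch_size = 8                         # fallback for resolutions not listed

step     = 1       # start training at 8x8 (resolution = 2 ** (step + 2))
max_step = 7       # grow up to 2 ** (7 + 2) = 512x512
image_folder_path = './dataset/'            # ImageFolder-style dataset root
save_folder_path  = './results/'            # generated samples are written here

is_train    = True    # train (True) or evaluate only (False)
is_continue = True    # resume from ./checkpoint/trained.pth if it exists
```

Remember to pick the `batch_size` / `mini_batch_size` variants that match your GPU count, as noted under the table.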
26 | 
27 | |Parameter|Description|
28 | |:-:|:-:|
29 | |n_gpu|number of GPUs used to train the model|
30 | |device|default device on which tensors are created and stored|
31 | |learning_rate|a dict giving the learning rate at each stage (resolution) of training|
32 | |batch_size*|a dict giving the batch size at each stage (resolution) of training|
33 | |mini_batch_size*|minimum batch size, used when a resolution is not listed in `batch_size`|
34 | |n_fc|number of layers in the fully-connected mapping network|
35 | |dim_latent|dimension of the latent space|
36 | |dim_input|spatial size of the generator's first layer|
37 | |n_sample|number of samples used to train a single layer|
38 | |n_sample_total|number of samples used to train the whole model|
39 | |DGR|how many times the discriminator is trained before the generator is trained once|
40 | |n_show_loss|loss is recorded every `n_show_loss` iterations|
41 | |step|layer (resolution level) at which training starts|
42 | |max_step|maximum resolution of images is 2 ^ (max_step + 2)|
43 | |style_mixing|layers that use the second style vector during evaluation|
44 | |image_folder_path|path to the dataset folder that contains the images|
45 | |save_folder_path|path to the folder where generated images are saved|
46 | |is_train|set to `True` if you want to train the model|
47 | |is_continue|set to `True` if you want to load a pre-trained model|
48 | |CUDA_VISIBLE_DEVICES|indexes of the GPUs visible to the process|
49 | 
50 | \*Parameters with this mark have suffixed variants (e.g. `_4gpus`); use the variant matching your GPU count (with the suffix removed) when training on that many GPUs.
51 | 
52 | ## Checkpoint
53 | 
54 | Our implementation supports saving the trained model and loading an existing checkpoint to continue training.
55 | 
56 | ### Pre-trained model
57 | 
58 | *Not available yet.*
59 | 
60 | ### Save model & continue training
61 | 
62 | When you train the model yourself, its parameters are saved to `./checkpoint/trained.pth` every 1000 iterations. Set `is_continue` to `True` to continue training from this checkpoint.
63 | 
64 | ## Performance
65 | 
66 | ### Loss curve
67 | 
68 | *Not available yet.*
69 | 
70 | ### [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500)
71 | 
72 | We use the Fréchet Inception Distance (FID) to estimate the performance of our implementation, computed with a lightly edited version of [mseitzer's pytorch-fid](https://github.com/mseitzer/pytorch-fid) (our edits do not affect the scores it produces).
73 | 
74 | *Not available yet.*
75 | 
76 | ## Changelog
77 | 
78 | - PLAN: Nanami Iteration
79 | - TODO: Estimate the performance of the model with FID
80 | - TODO: Generate preview images
81 | - TODO: Upload pre-trained models and the dataset
82 | - 5/28: Support the truncation trick in W during evaluation (usable in evaluate.py)
83 | - 5/25: Allow users to change the maximum resolution (step).
84 | - 5/20-5/23: Umi Iteration
85 | - 5/23: Split the code into separate files
86 | - 5/23: Support an evaluate-only mode
87 | - 5/23: DEBUG: Fix a VRAM leak
88 | - 5/23: DEBUG: Make alpha (which controls the degree of cross-fade between resolutions) change linearly[2].
89 | - 5/22: DEBUG: The model can now be trained on multiple GPUs.
90 | - 5/22: DEBUG: Fix a bug where the adaptive instance normalization module did not participate in the forward computation and back-propagation.
91 | - 5/16 - 5/19: Shiroha Iteration 92 | - 5/19: Construct [a new anime face dataset](https://github.com/SiskonEmilia/Anime-Wifu-Dataset) 93 | - 5/16: Able to continue training from a historic checkpoint 94 | - 5/16: Introduce style-mixing feature 95 | - 5/16: DEBUG: Fix the bug that the full connected mapping layer does not participate in calculation and back-propagation. 96 | - 5/13 - 5/15: Kamome Iteration 97 | - 5/15: DEBUG: VRAM leak and shared memory conflict 98 | - 5/14: DEBUG: Parallel conflict on Windows (Due to the speed limit, we migrate to Linux platform) 99 | - 5/13: Introduce complete Progressive GAN[2] and its training method. 100 | - 5/12: Introduce R1 regularization[1] and constant input vector. 101 | - 5/12: Early implementation of Style-based GAN. 102 | 103 | ## References 104 | 105 | [1] Mescheder, L., Geiger, A., & Nowozin, S. (2018). Which Training Methods for GANs do actually Converge? Retrieved from http://arxiv.org/abs/1801.04406 106 | 107 | [2] Karras, T., Aila, T., Laine, S., & Lehtinen, J. (2017). Progressive Growing of GANs for Improved Quality, Stability, and Variation. 1–26. Retrieved from http://arxiv.org/abs/1710.10196 108 | 109 | [3] Karras, T., Laine, S., & Aila, T. (2018). A Style-Based Generator Architecture for Generative Adversarial Networks. Retrieved from http://arxiv.org/abs/1812.04948 -------------------------------------------------------------------------------- /SGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Import necessary modules\n", 10 | "import torch\n", 11 | "from tqdm import tqdm\n", 12 | "import numpy as np\n", 13 | "import torch.nn as nn\n", 14 | "import torch.optim as optim\n", 15 | "import torchvision\n", 16 | "import torchvision.transforms as transforms\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "from PIL import Image\n", 19 | "import math\n", 20 | "\n", 21 | "from torch.utils.data import DataLoader\n", 22 | "from torchvision import datasets, transforms, utils\n", 23 | "\n", 24 | "%matplotlib inline" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# 5/15: No using shared memory\n", 34 | "import sys\n", 35 | "import torch\n", 36 | "from torch.utils.data import dataloader\n", 37 | "from torch.multiprocessing import reductions\n", 38 | "from multiprocessing.reduction import ForkingPickler\n", 39 | "\n", 40 | "default_collate_func = dataloader.default_collate\n", 41 | "\n", 42 | "\n", 43 | "def default_collate_override(batch):\n", 44 | " dataloader._use_shared_memory = False\n", 45 | " return default_collate_func(batch)\n", 46 | "\n", 47 | "setattr(dataloader, 'default_collate', default_collate_override)\n", 48 | "\n", 49 | "for t in torch._storage_classes:\n", 50 | " if sys.version_info[0] == 2:\n", 51 | " if t in ForkingPickler.dispatch:\n", 52 | " del ForkingPickler.dispatch[t]\n", 53 | " else:\n", 54 | " if t in ForkingPickler._extra_reducers:\n", 55 | " del ForkingPickler._extra_reducers[t]" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Constraints\n", 65 | "# Input: [batch_size, in_channels, height, width]\n", 66 | "\n", 67 | "# Scaled weight - He initialization\n", 68 | "# \"explicitly scale the weights at runtime\"\n", 69 | "class ScaleW:\n", 70 
| " '''\n", 71 | " Constructor: name - name of attribute to be scaled\n", 72 | " '''\n", 73 | " def __init__(self, name):\n", 74 | " self.name = name\n", 75 | " \n", 76 | " def scale(self, module):\n", 77 | " weight = getattr(module, self.name + '_orig')\n", 78 | " fan_in = weight.data.size(1) * weight.data[0][0].numel()\n", 79 | " \n", 80 | " return weight * math.sqrt(2 / fan_in)\n", 81 | " \n", 82 | " @staticmethod\n", 83 | " def apply(module, name):\n", 84 | " '''\n", 85 | " Apply runtime scaling to specific module\n", 86 | " '''\n", 87 | " hook = ScaleW(name)\n", 88 | " weight = getattr(module, name)\n", 89 | " module.register_parameter(name + '_orig', nn.Parameter(weight.data))\n", 90 | " del module._parameters[name]\n", 91 | " module.register_forward_pre_hook(hook)\n", 92 | " \n", 93 | " def __call__(self, module, whatever):\n", 94 | " weight = self.scale(module)\n", 95 | " setattr(module, self.name, weight)\n", 96 | "\n", 97 | "# Quick apply for scaled weight\n", 98 | "def quick_scale(module, name='weight'):\n", 99 | " ScaleW.apply(module, name)\n", 100 | " return module\n", 101 | "\n", 102 | "# Uniformly set the hyperparameters of Linears\n", 103 | "# \"We initialize all weights of the convolutional, fully-connected, and affine transform layers using N(0, 1)\"\n", 104 | "# 5/13: Apply scaled weights\n", 105 | "class SLinear(nn.Module):\n", 106 | " def __init__(self, dim_in, dim_out):\n", 107 | " super().__init__()\n", 108 | "\n", 109 | " linear = nn.Linear(dim_in, dim_out)\n", 110 | " linear.weight.data.normal_()\n", 111 | " linear.bias.data.zero_()\n", 112 | " \n", 113 | " self.linear = quick_scale(linear)\n", 114 | "\n", 115 | " def forward(self, x):\n", 116 | " return self.linear(x)\n", 117 | "\n", 118 | "# Uniformly set the hyperparameters of Conv2d\n", 119 | "# \"We initialize all weights of the convolutional, fully-connected, and affine transform layers using N(0, 1)\"\n", 120 | "# 5/13: Apply scaled weights\n", 121 | "class SConv2d(nn.Module):\n", 122 | " def __init__(self, *args, **kwargs):\n", 123 | " super().__init__()\n", 124 | "\n", 125 | " conv = nn.Conv2d(*args, **kwargs)\n", 126 | " conv.weight.data.normal_()\n", 127 | " conv.bias.data.zero_()\n", 128 | " \n", 129 | " self.conv = quick_scale(conv)\n", 130 | "\n", 131 | " def forward(self, x):\n", 132 | " return self.conv(x)\n", 133 | "\n", 134 | "# Normalization on every element of input vector\n", 135 | "class PixelNorm(nn.Module):\n", 136 | " def __init__(self):\n", 137 | " super().__init__()\n", 138 | "\n", 139 | " def forward(self, x):\n", 140 | " return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)\n", 141 | "\n", 142 | "# \"learned affine transform\" A\n", 143 | "class FC_A(nn.Module):\n", 144 | " '''\n", 145 | " Learned affine transform A, this module is used to transform\n", 146 | " midiate vector w into a style vector\n", 147 | " '''\n", 148 | " def __init__(self, dim_latent, n_channel):\n", 149 | " super().__init__()\n", 150 | " self.transform = SLinear(dim_latent, n_channel * 2)\n", 151 | " # \"the biases associated with ys that we initialize to one\"\n", 152 | " self.transform.linear.bias.data[:n_channel] = 1\n", 153 | " self.transform.linear.bias.data[n_channel:] = 0\n", 154 | "\n", 155 | " def forward(self, w):\n", 156 | " # Gain scale factor and bias with:\n", 157 | " style = self.transform(w).unsqueeze(2).unsqueeze(3)\n", 158 | " return style\n", 159 | " \n", 160 | "# AdaIn (AdaptiveInstanceNorm)\n", 161 | "class AdaIn(nn.Module):\n", 162 | " '''\n", 163 | " adaptive instance 
normalization\n", 164 | " '''\n", 165 | " def __init__(self, n_channel):\n", 166 | " super().__init__()\n", 167 | " self.norm = nn.InstanceNorm2d(n_channel)\n", 168 | " \n", 169 | " def forward(self, image, style):\n", 170 | " factor, bias = style.chunk(2, 1)\n", 171 | " result = self.norm(image)\n", 172 | " result = result * factor + bias \n", 173 | " return result\n", 174 | "\n", 175 | "# \"learned per-channel scaling factors\" B\n", 176 | "# 5/13: Debug - tensor -> nn.Parameter\n", 177 | "class Scale_B(nn.Module):\n", 178 | " '''\n", 179 | " Learned per-channel scale factor, used to scale the noise\n", 180 | " '''\n", 181 | " def __init__(self, n_channel):\n", 182 | " super().__init__()\n", 183 | " self.weight = nn.Parameter(torch.zeros((1, n_channel, 1, 1)))\n", 184 | " \n", 185 | " def forward(self, noise):\n", 186 | " result = noise * self.weight\n", 187 | " return result \n", 188 | "\n", 189 | "# Early convolutional block\n", 190 | "# 5/13: Debug - tensor -> nn.Parameter\n", 191 | "# 5/13: Remove noise generating module\n", 192 | "# TODO: Remove upsample\n", 193 | "class Early_StyleConv_Block(nn.Module):\n", 194 | " '''\n", 195 | " This is the very first block of generator that get the constant value as input\n", 196 | " '''\n", 197 | " def __init__ (self, n_channel, dim_latent, dim_input):\n", 198 | " super().__init__()\n", 199 | " # Constant input\n", 200 | " self.constant = nn.Parameter(torch.randn(1, n_channel, dim_input, dim_input))\n", 201 | " # Style generators\n", 202 | " self.style1 = FC_A(dim_latent, n_channel)\n", 203 | " self.style2 = FC_A(dim_latent, n_channel)\n", 204 | " # Noise processing modules\n", 205 | " self.noise1 = quick_scale(Scale_B(n_channel))\n", 206 | " self.noise2 = quick_scale(Scale_B(n_channel))\n", 207 | " # AdaIn\n", 208 | " self.adain = AdaIn(n_channel)\n", 209 | " self.lrelu = nn.LeakyReLU(0.2)\n", 210 | " # Convolutional layer\n", 211 | " self.conv = SConv2d(n_channel, n_channel, 3, padding=1)\n", 212 | " \n", 213 | " def forward(self, latent_w, noise):\n", 214 | " # Gaussian Noise: Proxyed by generator\n", 215 | " # noise1 = torch.normal(mean=0,std=torch.ones(self.constant.shape)).cuda()\n", 216 | " # noise2 = torch.normal(mean=0,std=torch.ones(self.constant.shape)).cuda()\n", 217 | " result = self.constant.repeat(noise.shape[0], 1, 1, 1)\n", 218 | " result = result + self.noise1(noise)\n", 219 | " result = self.adain(result, self.style1(latent_w))\n", 220 | " result = self.lrelu(result)\n", 221 | " result = self.conv(result)\n", 222 | " result = result + self.noise2(noise)\n", 223 | " result = self.adain(result, self.style2(latent_w))\n", 224 | " result = self.lrelu(result)\n", 225 | " \n", 226 | " return result\n", 227 | " \n", 228 | "# General convolutional blocks\n", 229 | "# 5/13: Remove upsampling\n", 230 | "# 5/13: Remove noise generating\n", 231 | "class StyleConv_Block(nn.Module):\n", 232 | " '''\n", 233 | " This is the general class of style-based convolutional blocks\n", 234 | " '''\n", 235 | " def __init__ (self, in_channel, out_channel, dim_latent):\n", 236 | " super().__init__()\n", 237 | " # Style generators\n", 238 | " self.style1 = FC_A(dim_latent, out_channel)\n", 239 | " self.style2 = FC_A(dim_latent, out_channel)\n", 240 | " # Noise processing modules\n", 241 | " self.noise1 = quick_scale(Scale_B(out_channel))\n", 242 | " self.noise2 = quick_scale(Scale_B(out_channel))\n", 243 | " # AdaIn\n", 244 | " self.adain = AdaIn(out_channel)\n", 245 | " self.lrelu = nn.LeakyReLU(0.2)\n", 246 | " # Convolutional layers\n", 247 | " 
self.conv1 = SConv2d(in_channel, out_channel, 3, padding=1)\n", 248 | " self.conv2 = SConv2d(out_channel, out_channel, 3, padding=1)\n", 249 | " \n", 250 | " def forward(self, previous_result, latent_w, noise):\n", 251 | " # Upsample: Proxyed by generator\n", 252 | " # result = nn.functional.interpolate(previous_result, scale_factor=2, mode='bilinear',\n", 253 | " # align_corners=False)\n", 254 | " # Conv 3*3\n", 255 | " result = self.conv1(previous_result)\n", 256 | " # Gaussian Noise: Proxyed by generator\n", 257 | " # noise1 = torch.normal(mean=0,std=torch.ones(result.shape)).cuda()\n", 258 | " # noise2 = torch.normal(mean=0,std=torch.ones(result.shape)).cuda()\n", 259 | " # Conv & Norm\n", 260 | " result = result + self.noise1(noise)\n", 261 | " result = self.adain(result, self.style1(latent_w))\n", 262 | " result = self.lrelu(result)\n", 263 | " result = self.conv2(result)\n", 264 | " result = result + self.noise2(noise)\n", 265 | " result = self.adain(result, self.style2(latent_w))\n", 266 | " result = self.lrelu(result)\n", 267 | " \n", 268 | " return result \n", 269 | "\n", 270 | "# Very First Convolutional Block\n", 271 | "# 5/13: No more downsample, this block is the same sa general ones\n", 272 | "# class Early_ConvBlock(nn.Module):\n", 273 | "# '''\n", 274 | "# Used to construct progressive discriminator\n", 275 | "# '''\n", 276 | "# def __init__(self, in_channel, out_channel, size_kernel, padding):\n", 277 | "# super().__init__()\n", 278 | "# self.conv = nn.Sequential(\n", 279 | "# SConv2d(in_channel, out_channel, size_kernel, padding=padding),\n", 280 | "# nn.LeakyReLU(0.2),\n", 281 | "# SConv2d(out_channel, out_channel, size_kernel, padding=padding),\n", 282 | "# nn.LeakyReLU(0.2)\n", 283 | "# )\n", 284 | " \n", 285 | "# def forward(self, image):\n", 286 | "# result = self.conv(image)\n", 287 | "# return result\n", 288 | " \n", 289 | "# General Convolutional Block\n", 290 | "# 5/13: Downsample is now removed from block module\n", 291 | "class ConvBlock(nn.Module):\n", 292 | " '''\n", 293 | " Used to construct progressive discriminator\n", 294 | " '''\n", 295 | " def __init__(self, in_channel, out_channel, size_kernel1, padding1, \n", 296 | " size_kernel2 = None, padding2 = None):\n", 297 | " super().__init__()\n", 298 | " \n", 299 | " if size_kernel2 == None:\n", 300 | " size_kernel2 = size_kernel1\n", 301 | " if padding2 == None:\n", 302 | " padding2 = padding1\n", 303 | " \n", 304 | " self.conv = nn.Sequential(\n", 305 | " SConv2d(in_channel, out_channel, size_kernel1, padding=padding1),\n", 306 | " nn.LeakyReLU(0.2),\n", 307 | " SConv2d(out_channel, out_channel, size_kernel2, padding=padding2),\n", 308 | " nn.LeakyReLU(0.2)\n", 309 | " )\n", 310 | " \n", 311 | " def forward(self, image):\n", 312 | " # Downsample now proxyed by discriminator\n", 313 | " # result = nn.functional.interpolate(image, scale_factor=0.5, mode=\"bilinear\", align_corners=False)\n", 314 | " # Conv\n", 315 | " result = self.conv(image)\n", 316 | " return result\n", 317 | " \n", 318 | " \n", 319 | "# Main components\n", 320 | "class Intermediate_Generator(nn.Module):\n", 321 | " '''\n", 322 | " A mapping consists of multiple fully connected layers.\n", 323 | " Used to map the input to an intermediate latent space W.\n", 324 | " '''\n", 325 | " def __init__(self, n_fc, dim_latent):\n", 326 | " super().__init__()\n", 327 | " layers = [PixelNorm()]\n", 328 | " for i in range(n_fc):\n", 329 | " layers.append(SLinear(dim_latent, dim_latent))\n", 330 | " layers.append(nn.LeakyReLU(0.2))\n", 331 | " \n", 
332 | " self.mapping = nn.Sequential(*layers)\n", 333 | " \n", 334 | " def forward(self, latent_z):\n", 335 | " latent_w = self.mapping(latent_z)\n", 336 | " return latent_w \n", 337 | "\n", 338 | "# Generator\n", 339 | "# 5/13: Support progressive training\n", 340 | "# 5/13: Proxy noise generating\n", 341 | "# 5/13: Proxy upsampling\n", 342 | "# TODO: style mixing\n", 343 | "class StyleBased_Generator(nn.Module):\n", 344 | " '''\n", 345 | " Main Module\n", 346 | " '''\n", 347 | " def __init__(self, n_fc, dim_latent, dim_input):\n", 348 | " super().__init__()\n", 349 | " # Waiting to adjust the size\n", 350 | " self.fcs = Intermediate_Generator(n_fc, dim_latent)\n", 351 | " self.convs = nn.ModuleList([\n", 352 | " Early_StyleConv_Block(512, dim_latent, dim_input),\n", 353 | " StyleConv_Block(512, 512, dim_latent),\n", 354 | " StyleConv_Block(512, 512, dim_latent),\n", 355 | " StyleConv_Block(512, 512, dim_latent),\n", 356 | " StyleConv_Block(512, 256, dim_latent),\n", 357 | " StyleConv_Block(256, 128, dim_latent),\n", 358 | " StyleConv_Block(128, 64, dim_latent),\n", 359 | " StyleConv_Block(64, 32, dim_latent),\n", 360 | " StyleConv_Block(32, 16, dim_latent)\n", 361 | " ])\n", 362 | " self.to_rgbs = nn.ModuleList([\n", 363 | " SConv2d(512, 3, 1),\n", 364 | " SConv2d(512, 3, 1),\n", 365 | " SConv2d(512, 3, 1),\n", 366 | " SConv2d(512, 3, 1),\n", 367 | " SConv2d(256, 3, 1),\n", 368 | " SConv2d(128, 3, 1),\n", 369 | " SConv2d(64, 3, 1),\n", 370 | " SConv2d(32, 3, 1),\n", 371 | " SConv2d(16, 3, 1)\n", 372 | " ])\n", 373 | " def forward(self, latent_z, \n", 374 | " step = 0, # Step means how many layers (count from 4 x 4) are used to train\n", 375 | " alpha=-1, # Alpha is the parameter of smooth conversion of resolution):\n", 376 | " noise=None, # TODO: support input noise\n", 377 | " mix_steps=[]): # steps inside will use latent_z[1], else latent_z[0]\n", 378 | " if type(latent_z) != type([]):\n", 379 | " print('You should use list to package your latent_z')\n", 380 | " latent_z = [latent_z]\n", 381 | " if (len(latent_z) != 2 and len(mix_steps) > 0) or type(mix_steps) != type([]):\n", 382 | " print('Warning: Style mixing disabled, possible reasons:')\n", 383 | " print('- Invalid number of latent vectors')\n", 384 | " print('- Invalid parameter type: mix_steps')\n", 385 | " mix_steps = []\n", 386 | " \n", 387 | " latent_w = [self.fcs(latent) for latent in latent_z]\n", 388 | " batch_size = latent_w[0].size(0)\n", 389 | " \n", 390 | " # Generate needed Gaussian noise\n", 391 | " # 5/22: Noise is now generated by outer module\n", 392 | " # noise = []\n", 393 | " result = 0\n", 394 | " current_latent = 0\n", 395 | " # for i in range(step + 1):\n", 396 | " # size = 4 * 2 ** i # Due to the upsampling, size of noise will grow\n", 397 | " # noise.append(torch.randn((batch_size, 1, size, size), device=torch.device('cuda:0')))\n", 398 | " \n", 399 | " for i, conv in enumerate(self.convs):\n", 400 | " # Choose current latent_w\n", 401 | " if i in mix_steps:\n", 402 | " current_latent = latent_w[1]\n", 403 | " else:\n", 404 | " current_latent = latent_w[0]\n", 405 | " \n", 406 | " # Not the first layer, need to upsample\n", 407 | " if i > 0 and step > 0:\n", 408 | " result_upsample = nn.functional.interpolate(result, scale_factor=2, mode='bilinear',\n", 409 | " align_corners=False)\n", 410 | " result = conv(result_upsample, current_latent, noise[i])\n", 411 | " else:\n", 412 | " result = conv(current_latent, noise[i])\n", 413 | " \n", 414 | " # Final layer, output rgb image\n", 415 | " if i == step:\n", 
416 | " result = self.to_rgbs[i](result)\n", 417 | " \n", 418 | " if i > 0 and 0 <= alpha < 1:\n", 419 | " result_prev = self.to_rgbs[i - 1](result_upsample)\n", 420 | " result = alpha * result + (1 - alpha) * result_prev\n", 421 | " \n", 422 | " # Finish and break\n", 423 | " break\n", 424 | " \n", 425 | " return result\n", 426 | "\n", 427 | "# Discriminator\n", 428 | "# 5/13: Support progressive training\n", 429 | "# 5/13: Add downsample module\n", 430 | "# Component of Progressive GAN\n", 431 | "# Reference: Karras, T., Aila, T., Laine, S., & Lehtinen, J. (2017).\n", 432 | "# Progressive Growing of GANs for Improved Quality, Stability, and Variation, 1–26.\n", 433 | "# Retrieved from http://arxiv.org/abs/1710.10196\n", 434 | "class Discriminator(nn.Module):\n", 435 | " '''\n", 436 | " Main Module\n", 437 | " '''\n", 438 | " def __init__(self):\n", 439 | " super().__init__()\n", 440 | " # Waiting to adjust the size\n", 441 | " self.from_rgbs = nn.ModuleList([\n", 442 | " SConv2d(3, 16, 1),\n", 443 | " SConv2d(3, 32, 1),\n", 444 | " SConv2d(3, 64, 1),\n", 445 | " SConv2d(3, 128, 1),\n", 446 | " SConv2d(3, 256, 1),\n", 447 | " SConv2d(3, 512, 1),\n", 448 | " SConv2d(3, 512, 1),\n", 449 | " SConv2d(3, 512, 1),\n", 450 | " SConv2d(3, 512, 1)\n", 451 | " ])\n", 452 | " self.convs = nn.ModuleList([\n", 453 | " ConvBlock(16, 32, 3, 1),\n", 454 | " ConvBlock(32, 64, 3, 1),\n", 455 | " ConvBlock(64, 128, 3, 1),\n", 456 | " ConvBlock(128, 256, 3, 1),\n", 457 | " ConvBlock(256, 512, 3, 1),\n", 458 | " ConvBlock(512, 512, 3, 1),\n", 459 | " ConvBlock(512, 512, 3, 1),\n", 460 | " ConvBlock(512, 512, 3, 1),\n", 461 | " ConvBlock(513, 512, 3, 1, 4, 0)\n", 462 | " ])\n", 463 | " self.fc = SLinear(512, 1)\n", 464 | " \n", 465 | " self.n_layer = 9 # 9 layers network\n", 466 | " \n", 467 | " def forward(self, image, \n", 468 | " step = 0, # Step means how many layers (count from 4 x 4) are used to train\n", 469 | " alpha=-1): # Alpha is the parameter of smooth conversion of resolution):\n", 470 | " for i in range(step, -1, -1):\n", 471 | " # Get the index of current layer\n", 472 | " # Count from the bottom layer (4 * 4)\n", 473 | " layer_index = self.n_layer - i - 1 \n", 474 | " \n", 475 | " # First layer, need to use from_rgb to convert to n_channel data\n", 476 | " if i == step: \n", 477 | " result = self.from_rgbs[layer_index](image)\n", 478 | " \n", 479 | " # Before final layer, do minibatch stddev\n", 480 | " if i == 0:\n", 481 | " # In dim: [batch, channel(512), 4, 4]\n", 482 | " res_var = result.var(0, unbiased=False) + 1e-8 # Avoid zero\n", 483 | " # Out dim: [channel(512), 4, 4]\n", 484 | " res_std = torch.sqrt(res_var)\n", 485 | " # Out dim: [channel(512), 4, 4]\n", 486 | " mean_std = res_std.mean().expand(result.size(0), 1, 4, 4)\n", 487 | " # Out dim: [1] -> [batch, 1, 4, 4]\n", 488 | " result = torch.cat([result, mean_std], 1)\n", 489 | " # Out dim: [batch, 512 + 1, 4, 4]\n", 490 | " \n", 491 | " # Conv\n", 492 | " result = self.convs[layer_index](result)\n", 493 | " \n", 494 | " # Not the final layer\n", 495 | " if i > 0:\n", 496 | " # Downsample for further usage\n", 497 | " result = nn.functional.interpolate(result, scale_factor=0.5, mode='bilinear',\n", 498 | " align_corners=False)\n", 499 | " # Alpha set, combine the result of different layers when input\n", 500 | " if i == step and 0 <= alpha < 1:\n", 501 | " result_next = self.from_rgbs[layer_index + 1](image)\n", 502 | " result_next = nn.functional.interpolate(result_next, scale_factor=0.5,\n", 503 | " mode = 'bilinear', 
align_corners=False)\n", 504 | " \n", 505 | " result = alpha * result + (1 - alpha) * result_next\n", 506 | " \n", 507 | " # Now, result is [batch, channel(512), 1, 1]\n", 508 | " # Convert it into [batch, channel(512)], so the fully-connetced layer \n", 509 | " # could process it.\n", 510 | " result = result.squeeze(2).squeeze(2)\n", 511 | " result = self.fc(result)\n", 512 | " return result" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": { 519 | "scrolled": false 520 | }, 521 | "outputs": [], 522 | "source": [ 523 | "# use idel gpu\n", 524 | "# it's better to use enviroment variable\n", 525 | "# if you want to use multiple gpus, please\n", 526 | "# modify hyperparameters at the same time\n", 527 | "# And Make Sure Your Pytorch Version >= 1.0.1\n", 528 | "import os\n", 529 | "os.environ['CUDA_VISIBLE_DEVICES']='1, 2'\n", 530 | "n_gpu = 2\n", 531 | "device = torch.device('cuda:0')\n", 532 | "\n", 533 | "learning_rate = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}\n", 534 | "batch_size_1gpu = {4: 128, 8: 128, 16: 64, 32: 32, 64: 16, 128: 16}\n", 535 | "mini_batch_size_1 = 8\n", 536 | "batch_size = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}\n", 537 | "mini_batch_size = 8\n", 538 | "batch_size_4gpus = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}\n", 539 | "mini_batch_size_4 = 16\n", 540 | "batch_size_8gpus = {4: 512, 8: 256, 16: 128, 32: 64}\n", 541 | "mini_batch_size_8 = 32\n", 542 | "n_fc = 8\n", 543 | "dim_latent = 512\n", 544 | "dim_input = 4\n", 545 | "n_sample = 120000\n", 546 | "DGR = 1\n", 547 | "n_show_loss = 40\n", 548 | "step = 1 # Train from (8 * 8)\n", 549 | "max_step = 8 # Maximum step (8 for 1024^2)\n", 550 | "style_mixing = [] # Waiting to implement\n", 551 | "image_folder_path = './dataset/'\n", 552 | "save_folder_path = './results/'\n", 553 | "\n", 554 | "low_steps = [0, 1, 2]\n", 555 | "# style_mixing += low_steps\n", 556 | "mid_steps = [3, 4, 5]\n", 557 | "# style_mixing += mid_steps\n", 558 | "hig_steps = [6, 7, 8]\n", 559 | "# style_mixing += hig_steps\n", 560 | "\n", 561 | "# Used to continue training from last checkpoint\n", 562 | "startpoint = 0\n", 563 | "used_sample = 0\n", 564 | "alpha = 0\n", 565 | "\n", 566 | "# Mode: Evaluate? 
Train?\n", 567 | "is_train = True\n", 568 | "\n", 569 | "# How to start training?\n", 570 | "# True for start from saved model\n", 571 | "# False for retrain from the very beginning\n", 572 | "is_continue = True\n", 573 | "d_losses = [float('inf')]\n", 574 | "g_losses = [float('inf')]\n", 575 | "inputs, outputs = [], []\n", 576 | "\n", 577 | "def set_grad_flag(module, flag):\n", 578 | " for p in module.parameters():\n", 579 | " p.requires_grad = flag\n", 580 | "\n", 581 | "def reset_LR(optimizer, lr):\n", 582 | " for pam_group in optimizer.param_groups:\n", 583 | " mul = pam_group.get('mul', 1)\n", 584 | " pam_group['lr'] = lr * mul\n", 585 | " \n", 586 | "# Gain sample\n", 587 | "def gain_sample(dataset, batch_size, image_size=4):\n", 588 | " transform = transforms.Compose([\n", 589 | " transforms.Resize(image_size), # Resize to the same size\n", 590 | " transforms.CenterCrop(image_size), # Crop to get square area\n", 591 | " transforms.RandomHorizontalFlip(), # Increase number of samples\n", 592 | " transforms.ToTensor(), \n", 593 | " transforms.Normalize((0.5, 0.5, 0.5),\n", 594 | " (0.5, 0.5, 0.5))])\n", 595 | "\n", 596 | " dataset.transform = transform\n", 597 | " loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=8)\n", 598 | "\n", 599 | " return loader\n", 600 | "\n", 601 | "def imshow(tensor, i):\n", 602 | " grid = tensor[0]\n", 603 | " grid.clamp_(-1, 1).add_(1).div_(2)\n", 604 | " # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer\n", 605 | " ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()\n", 606 | " img = Image.fromarray(ndarr)\n", 607 | " img.save(f'{save_folder_path}sample-iter{i}.png')\n", 608 | " plt.imshow(img)\n", 609 | " plt.show()\n", 610 | " \n", 611 | "# Train function\n", 612 | "def train(generator, discriminator, g_optim, d_optim, dataset, step, startpoint=0, used_sample=0,\n", 613 | " d_losses = [], g_losses = [], alpha=0):\n", 614 | " \n", 615 | " resolution = 4 * 2 ** step\n", 616 | " \n", 617 | " origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)\n", 618 | " data_loader = iter(origin_loader)\n", 619 | " \n", 620 | " reset_LR(g_optim, learning_rate.get(resolution, 0.001))\n", 621 | " reset_LR(d_optim, learning_rate.get(resolution, 0.001))\n", 622 | " \n", 623 | " progress_bar = tqdm(range(startpoint + 1, n_sample * 5))\n", 624 | " # Train\n", 625 | " for i in progress_bar:\n", 626 | " alpha = min(1, alpha + batch_size.get(resolution, mini_batch_size) / (n_sample * 2))\n", 627 | " \n", 628 | " if used_sample > n_sample * 2 and step < max_step: \n", 629 | " step += 1\n", 630 | " \n", 631 | " alpha = 0\n", 632 | " used_sample = 0\n", 633 | " \n", 634 | " resolution = 4 * 2 ** step\n", 635 | " \n", 636 | " # Avoid possble memory leak\n", 637 | " del origin_loader\n", 638 | " del data_loader\n", 639 | " \n", 640 | " # Change batch size\n", 641 | " origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)\n", 642 | " data_loader = iter(origin_loader)\n", 643 | " \n", 644 | "# torch.save({\n", 645 | "# 'generator' : generator.module.state_dict(),\n", 646 | "# 'discriminator': discriminator.module.state_dict(),\n", 647 | "# 'g_optim' : g_optim.state_dict(),\n", 648 | "# 'd_optim' : d_optim.state_dict()\n", 649 | "# }, f'checkpoint/train.pth')\n", 650 | " \n", 651 | " reset_LR(g_optim, learning_rate.get(resolution, 0.001))\n", 652 | " reset_LR(d_optim, learning_rate.get(resolution, 0.001))\n", 653 | " 
\n", 654 | " \n", 655 | " try:\n", 656 | " # Try to read next image\n", 657 | " real_image, label = next(data_loader)\n", 658 | "\n", 659 | " except (OSError, StopIteration):\n", 660 | " # Dataset exhausted, train from the first image\n", 661 | " data_loader = iter(origin_loader)\n", 662 | " real_image, label = next(data_loader)\n", 663 | " \n", 664 | " # Count used sample\n", 665 | " used_sample += real_image.shape[0]\n", 666 | " \n", 667 | " # Send image to GPU\n", 668 | " real_image = real_image.to(device)\n", 669 | " \n", 670 | " # D Module ---\n", 671 | " # Train discriminator first\n", 672 | " discriminator.zero_grad()\n", 673 | " set_grad_flag(discriminator, True)\n", 674 | " set_grad_flag(generator, False)\n", 675 | " \n", 676 | " # Real image predict & backward\n", 677 | " # We only implement non-saturating loss with R1 regularization loss\n", 678 | " real_image.requires_grad = True\n", 679 | " if n_gpu > 1:\n", 680 | " real_predict = nn.parallel.data_parallel(discriminator, (real_image, step, alpha), range(n_gpu))\n", 681 | " else:\n", 682 | " real_predict = discriminator(real_image, step, alpha)\n", 683 | " real_predict = nn.functional.softplus(-real_predict).mean()\n", 684 | " real_predict.backward(retain_graph=True)\n", 685 | "\n", 686 | " grad_real = torch.autograd.grad(outputs=real_predict.sum(), inputs=real_image, create_graph=True)[0]\n", 687 | " grad_penalty_real = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()\n", 688 | " grad_penalty_real = 10 / 2 * grad_penalty_real\n", 689 | " grad_penalty_real.backward()\n", 690 | " \n", 691 | " # Generate latent code\n", 692 | " latent_w1 = [torch.randn((batch_size.get(resolution, mini_batch_size), dim_latent), device=device)]\n", 693 | " latent_w2 = [torch.randn((batch_size.get(resolution, mini_batch_size), dim_latent), device=device)]\n", 694 | "\n", 695 | " noise_1 = []\n", 696 | " noise_2 = []\n", 697 | " for m in range(step + 1):\n", 698 | " size = 4 * 2 ** m # Due to the upsampling, size of noise will grow\n", 699 | " noise_1.append(torch.randn((batch_size.get(resolution, mini_batch_size), 1, size, size), device=device))\n", 700 | " noise_2.append(torch.randn((batch_size.get(resolution, mini_batch_size), 1, size, size), device=device))\n", 701 | " \n", 702 | " # Generate fake image & backward\n", 703 | " if n_gpu > 1:\n", 704 | " fake_image = nn.parallel.data_parallel(generator, (latent_w1, step, alpha, noise_1), range(n_gpu))\n", 705 | " fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))\n", 706 | " else:\n", 707 | " fake_image = generator(latent_w1, step, alpha, noise_1)\n", 708 | " fake_predict = discriminator(fake_image, step, alpha)\n", 709 | "\n", 710 | " fake_predict = nn.functional.softplus(fake_predict).mean()\n", 711 | " fake_predict.backward()\n", 712 | " \n", 713 | " if i % n_show_loss == 0:\n", 714 | " d_losses.append((real_predict + fake_predict).item())\n", 715 | " \n", 716 | " # D optimizer step\n", 717 | " d_optim.step()\n", 718 | " \n", 719 | " # Avoid possible memory leak\n", 720 | " del grad_penalty_real, grad_real, fake_predict, real_predict, fake_image, real_image, latent_w1\n", 721 | " \n", 722 | " # G module ---\n", 723 | " if i % DGR != 0: continue\n", 724 | " # Due to DGR, train generator\n", 725 | " generator.zero_grad()\n", 726 | " set_grad_flag(discriminator, False)\n", 727 | " set_grad_flag(generator, True)\n", 728 | " \n", 729 | " if n_gpu > 1:\n", 730 | " fake_image = nn.parallel.data_parallel(generator, (latent_w2, step, alpha, 
noise_2), range(n_gpu))\n", 731 | " fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))\n", 732 | " else: \n", 733 | " fake_image = generator(latent_w2, step, alpha, noise_2)\n", 734 | " fake_predict = discriminator(fake_image, step, alpha)\n", 735 | " fake_predict = nn.functional.softplus(-fake_predict).mean()\n", 736 | " fake_predict.backward()\n", 737 | " g_optim.step()\n", 738 | "\n", 739 | " if i % n_show_loss == 0:\n", 740 | " g_losses.append(fake_predict.item())\n", 741 | " imshow(fake_image.data.cpu(), i)\n", 742 | " \n", 743 | " # Avoid possible memory leak\n", 744 | " del fake_predict, fake_image, latent_w2\n", 745 | " \n", 746 | " if (i + 1) % 1000 == 0:\n", 747 | " # Save the model every 1000 iterations\n", 748 | " torch.save({\n", 749 | " 'generator' : generator.state_dict(),\n", 750 | " 'discriminator': discriminator.state_dict(),\n", 751 | " 'g_optim' : g_optim.state_dict(),\n", 752 | " 'd_optim' : d_optim.state_dict(),\n", 753 | " 'parameters' : (step, i, used_sample, alpha),\n", 754 | " 'd_losses' : d_losses,\n", 755 | " 'g_losses' : g_losses\n", 756 | " }, 'checkpoint/trained.pth')\n", 757 | " print(f'Iteration {i} successfully saved.')\n", 758 | " \n", 759 | " progress_bar.set_description((f'Resolution: {resolution}*{resolution} D_Loss: {d_losses[-1]:.4f} G_Loss: {g_losses[-1]:.4f} Alpha: {alpha:.4f}'))\n", 760 | " \n", 761 | " return d_losses, g_losses\n", 762 | "\n", 763 | "\n", 764 | "# generator = nn.DataParallel(StyleBased_Generator(n_fc, dim_latent, dim_input)).cuda()\n", 765 | "# discriminator = nn.DataParallel(Discriminator()).cuda() \n", 766 | "# g_optim = optim.Adam([{\n", 767 | "# 'params': generator.module.convs.parameters(),\n", 768 | "# 'lr' : 0.001\n", 769 | "# }, {\n", 770 | "# 'params': generator.module.to_rgbs.parameters(),\n", 771 | "# 'lr' : 0.001\n", 772 | "# }], lr=0.001, betas=(0.0, 0.99))\n", 773 | "# g_optim.add_param_group({\n", 774 | "# 'params': generator.module.fcs.parameters(),\n", 775 | "# 'lr' : 0.001 * 0.01,\n", 776 | "# 'mul' : 0.01\n", 777 | "# })\n", 778 | "\n", 779 | "generator = StyleBased_Generator(n_fc, dim_latent, dim_input).to(device)\n", 780 | "discriminator = Discriminator().to(device)\n", 781 | "g_optim = optim.Adam([{\n", 782 | " 'params': generator.convs.parameters(),\n", 783 | " 'lr' : 0.001\n", 784 | "}, {\n", 785 | " 'params': generator.to_rgbs.parameters(),\n", 786 | " 'lr' : 0.001\n", 787 | "}], lr=0.001, betas=(0.0, 0.99))\n", 788 | "g_optim.add_param_group({\n", 789 | " 'params': generator.fcs.parameters(),\n", 790 | " 'lr' : 0.001 * 0.01,\n", 791 | " 'mul' : 0.01\n", 792 | "})\n", 793 | "d_optim = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))\n", 794 | "dataset = datasets.ImageFolder(image_folder_path)\n", 795 | "\n", 796 | "if is_continue:\n", 797 | " if os.path.exists('checkpoint/trained.pth'):\n", 798 | " # Load data from last checkpoint\n", 799 | " print('Loading pre-trained model...')\n", 800 | " checkpoint = torch.load('checkpoint/trained.pth')\n", 801 | " generator.load_state_dict(checkpoint['generator'])\n", 802 | " discriminator.load_state_dict(checkpoint['discriminator'])\n", 803 | " g_optim.load_state_dict(checkpoint['g_optim'])\n", 804 | " d_optim.load_state_dict(checkpoint['d_optim'])\n", 805 | " step, startpoint, used_sample, alpha = checkpoint['parameters']\n", 806 | " d_losses = checkpoint.get('d_losses', [float('inf')])\n", 807 | " g_losses = checkpoint.get('g_losses', [float('inf')])\n", 808 | " else:\n", 809 | " print('No 
pre-trained model detected, restart training...')\n", 810 | " \n", 811 | "if is_train:\n", 812 | " generator.train()\n", 813 | " discriminator.train() \n", 814 | " d_losses, g_losses = train(generator, discriminator, g_optim, d_optim, dataset, step, startpoint, used_sample, d_losses, g_losses, alpha)\n", 815 | "else:\n", 816 | " # Do some evaluation here\n", 817 | " pass" 818 | ] 819 | } 820 | ], 821 | "metadata": { 822 | "kernelspec": { 823 | "display_name": "Python 3", 824 | "language": "python", 825 | "name": "python3" 826 | }, 827 | "language_info": { 828 | "codemirror_mode": { 829 | "name": "ipython", 830 | "version": 3 831 | }, 832 | "file_extension": ".py", 833 | "mimetype": "text/x-python", 834 | "name": "python", 835 | "nbconvert_exporter": "python", 836 | "pygments_lexer": "ipython3", 837 | "version": "3.7.2" 838 | } 839 | }, 840 | "nbformat": 4, 841 | "nbformat_minor": 2 842 | } 843 | -------------------------------------------------------------------------------- /dataset_preview.py: -------------------------------------------------------------------------------- 1 | # Import necessary modules 2 | import torch 3 | from tqdm import tqdm 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | 8 | from torch.utils.data import DataLoader 9 | from torchvision import datasets, transforms, utils 10 | 11 | image_folder_path = './dataset/' 12 | 13 | def gain_sample(dataset, batch_size, image_size=4): 14 | transform = transforms.Compose([ 15 | transforms.Resize(image_size), # Resize to the same size 16 | transforms.CenterCrop(image_size), # Crop to get square area 17 | transforms.RandomHorizontalFlip(), # Increase number of samples 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.5, 0.5, 0.5), 20 | (0.5, 0.5, 0.5))]) 21 | 22 | dataset.transform = transform 23 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4) 24 | 25 | return loader 26 | 27 | dataset = datasets.ImageFolder(image_folder_path) 28 | origin_loader = gain_sample(dataset, 240, 64) 29 | data_loader = iter(origin_loader) 30 | 31 | for i in range(10): 32 | real_image, label = next(data_loader) 33 | torchvision.utils.save_image(real_image, f'./previews/preview{i}.png', nrow=24, padding=2, normalize=True, range=(-1,1)) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Import SGAN models 3 | from model import * 4 | 5 | # use idel gpu 6 | # it's better to use environment variable 7 | # if you want to use multiple GPUs, please 8 | # modify hyperparameters at the same time 9 | import os 10 | os.environ['CUDA_VISIBLE_DEVICES']='1, 2' 11 | n_gpu = 2 12 | device = torch.device('cuda:0') 13 | 14 | # Hyper-parameters 15 | n_fc = 8 16 | dim_latent = 512 17 | dim_input = 4 18 | step = 7 19 | resolution = 2 ** (step + 2) 20 | save_folder_path = './results/' 21 | 22 | # Style mixing setting 23 | style_mixing = [] 24 | low_steps = [0, 1, 2] 25 | # style_mixing += low_steps 26 | mid_steps = [3, 4, 5] 27 | # style_mixing += mid_steps 28 | hig_steps = [6, 7, 8] 29 | # style_mixing += hig_steps 30 | 31 | generator = StyleBased_Generator(n_fc, dim_latent, dim_input).to(device) 32 | if os.path.exists('checkpoint/trained.pth'): 33 | checkpoint = torch.load('checkpoint/trained.pth') 34 | generator.load_state_dict(checkpoint['generator']) 35 | else: 36 | raise IOError('No checkpoint file found at 
./checkpoint/trained.pth') 37 | generator.eval() 38 | # No computing gradients 39 | for param in generator.parameters(): 40 | param.requires_grad = False 41 | 42 | def compute_latent_cernter(batch_size=1024, multimes=10): 43 | appro_latent_center = None 44 | for i in range(multimes): 45 | if appro_latent_center is None: 46 | appro_latent_center = generator.center_w(torch.randn((batch_size, dim_latent)).to(device)) 47 | else: 48 | appro_latent_center += generator.center_w(torch.randn((batch_size, dim_latent)).to(device)) 49 | appro_latent_center /= multimes 50 | return appro_latent_center 51 | 52 | def evaluate(latent_code, noise, latent_w_center=None, psi=0, style_mixing=[]): 53 | if n_gpu > 1: 54 | return nn.parallel.data_parallel(generator, (latent_code, step, 1, noise, style_mixing, 55 | latent_w_center, psi), range(n_gpu)) 56 | else: 57 | return generator(latent_code, step, 1, noise, style_mixing, latent_w_center, psi) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Import necessary modules 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | # Constraints 7 | # Input: [batch_size, in_channels, height, width] 8 | 9 | # Scaled weight - He initialization 10 | # "explicitly scale the weights at runtime" 11 | class ScaleW: 12 | ''' 13 | Constructor: name - name of attribute to be scaled 14 | ''' 15 | def __init__(self, name): 16 | self.name = name 17 | 18 | def scale(self, module): 19 | weight = getattr(module, self.name + '_orig') 20 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 21 | 22 | return weight * math.sqrt(2 / fan_in) 23 | 24 | @staticmethod 25 | def apply(module, name): 26 | ''' 27 | Apply runtime scaling to specific module 28 | ''' 29 | hook = ScaleW(name) 30 | weight = getattr(module, name) 31 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 32 | del module._parameters[name] 33 | module.register_forward_pre_hook(hook) 34 | 35 | def __call__(self, module, whatever): 36 | weight = self.scale(module) 37 | setattr(module, self.name, weight) 38 | 39 | # Quick apply for scaled weight 40 | def quick_scale(module, name='weight'): 41 | ScaleW.apply(module, name) 42 | return module 43 | 44 | # Uniformly set the hyperparameters of Linears 45 | # "We initialize all weights of the convolutional, fully-connected, and affine transform layers using N(0, 1)" 46 | # 5/13: Apply scaled weights 47 | class SLinear(nn.Module): 48 | def __init__(self, dim_in, dim_out): 49 | super().__init__() 50 | 51 | linear = nn.Linear(dim_in, dim_out) 52 | linear.weight.data.normal_() 53 | linear.bias.data.zero_() 54 | 55 | self.linear = quick_scale(linear) 56 | 57 | def forward(self, x): 58 | return self.linear(x) 59 | 60 | # Uniformly set the hyperparameters of Conv2d 61 | # "We initialize all weights of the convolutional, fully-connected, and affine transform layers using N(0, 1)" 62 | # 5/13: Apply scaled weights 63 | class SConv2d(nn.Module): 64 | def __init__(self, *args, **kwargs): 65 | super().__init__() 66 | 67 | conv = nn.Conv2d(*args, **kwargs) 68 | conv.weight.data.normal_() 69 | conv.bias.data.zero_() 70 | 71 | self.conv = quick_scale(conv) 72 | 73 | def forward(self, x): 74 | return self.conv(x) 75 | 76 | # Normalization on every element of input vector 77 | class PixelNorm(nn.Module): 78 | def __init__(self): 79 | super().__init__() 80 | 81 | def forward(self, x): 82 | return x / torch.sqrt(torch.mean(x ** 2, dim=1, 
keepdim=True) + 1e-8) 83 | 84 | # "learned affine transform" A 85 | class FC_A(nn.Module): 86 | ''' 87 | Learned affine transform A, this module is used to transform 88 | midiate vector w into a style vector 89 | ''' 90 | def __init__(self, dim_latent, n_channel): 91 | super().__init__() 92 | self.transform = SLinear(dim_latent, n_channel * 2) 93 | # "the biases associated with ys that we initialize to one" 94 | self.transform.linear.bias.data[:n_channel] = 1 95 | self.transform.linear.bias.data[n_channel:] = 0 96 | 97 | def forward(self, w): 98 | # Gain scale factor and bias with: 99 | style = self.transform(w).unsqueeze(2).unsqueeze(3) 100 | return style 101 | 102 | # AdaIn (AdaptiveInstanceNorm) 103 | class AdaIn(nn.Module): 104 | ''' 105 | adaptive instance normalization 106 | ''' 107 | def __init__(self, n_channel): 108 | super().__init__() 109 | self.norm = nn.InstanceNorm2d(n_channel) 110 | 111 | def forward(self, image, style): 112 | factor, bias = style.chunk(2, 1) 113 | result = self.norm(image) 114 | result = result * factor + bias 115 | return result 116 | 117 | # "learned per-channel scaling factors" B 118 | # 5/13: Debug - tensor -> nn.Parameter 119 | class Scale_B(nn.Module): 120 | ''' 121 | Learned per-channel scale factor, used to scale the noise 122 | ''' 123 | def __init__(self, n_channel): 124 | super().__init__() 125 | self.weight = nn.Parameter(torch.zeros((1, n_channel, 1, 1))) 126 | 127 | def forward(self, noise): 128 | result = noise * self.weight 129 | return result 130 | 131 | # Early convolutional block 132 | # 5/13: Debug - tensor -> nn.Parameter 133 | # 5/13: Remove noise generating module 134 | class Early_StyleConv_Block(nn.Module): 135 | ''' 136 | This is the very first block of generator that get the constant value as input 137 | ''' 138 | def __init__ (self, n_channel, dim_latent, dim_input): 139 | super().__init__() 140 | # Constant input 141 | self.constant = nn.Parameter(torch.randn(1, n_channel, dim_input, dim_input)) 142 | # Style generators 143 | self.style1 = FC_A(dim_latent, n_channel) 144 | self.style2 = FC_A(dim_latent, n_channel) 145 | # Noise processing modules 146 | self.noise1 = quick_scale(Scale_B(n_channel)) 147 | self.noise2 = quick_scale(Scale_B(n_channel)) 148 | # AdaIn 149 | self.adain = AdaIn(n_channel) 150 | self.lrelu = nn.LeakyReLU(0.2) 151 | # Convolutional layer 152 | self.conv = SConv2d(n_channel, n_channel, 3, padding=1) 153 | 154 | def forward(self, latent_w, noise): 155 | # Gaussian Noise: Proxyed by generator 156 | # noise1 = torch.normal(mean=0,std=torch.ones(self.constant.shape)).cuda() 157 | # noise2 = torch.normal(mean=0,std=torch.ones(self.constant.shape)).cuda() 158 | result = self.constant.repeat(noise.shape[0], 1, 1, 1) 159 | result = result + self.noise1(noise) 160 | result = self.adain(result, self.style1(latent_w)) 161 | result = self.lrelu(result) 162 | result = self.conv(result) 163 | result = result + self.noise2(noise) 164 | result = self.adain(result, self.style2(latent_w)) 165 | result = self.lrelu(result) 166 | 167 | return result 168 | 169 | # General convolutional blocks 170 | # 5/13: Remove upsampling 171 | # 5/13: Remove noise generating 172 | class StyleConv_Block(nn.Module): 173 | ''' 174 | This is the general class of style-based convolutional blocks 175 | ''' 176 | def __init__ (self, in_channel, out_channel, dim_latent): 177 | super().__init__() 178 | # Style generators 179 | self.style1 = FC_A(dim_latent, out_channel) 180 | self.style2 = FC_A(dim_latent, out_channel) 181 | # Noise processing 
modules 182 | self.noise1 = quick_scale(Scale_B(out_channel)) 183 | self.noise2 = quick_scale(Scale_B(out_channel)) 184 | # AdaIn 185 | self.adain = AdaIn(out_channel) 186 | self.lrelu = nn.LeakyReLU(0.2) 187 | # Convolutional layers 188 | self.conv1 = SConv2d(in_channel, out_channel, 3, padding=1) 189 | self.conv2 = SConv2d(out_channel, out_channel, 3, padding=1) 190 | 191 | def forward(self, previous_result, latent_w, noise): 192 | # Upsample: Proxyed by generator 193 | # result = nn.functional.interpolate(previous_result, scale_factor=2, mode='bilinear', 194 | # align_corners=False) 195 | # Conv 3*3 196 | result = self.conv1(previous_result) 197 | # Gaussian Noise: Proxyed by generator 198 | # noise1 = torch.normal(mean=0,std=torch.ones(result.shape)).cuda() 199 | # noise2 = torch.normal(mean=0,std=torch.ones(result.shape)).cuda() 200 | # Conv & Norm 201 | result = result + self.noise1(noise) 202 | result = self.adain(result, self.style1(latent_w)) 203 | result = self.lrelu(result) 204 | result = self.conv2(result) 205 | result = result + self.noise2(noise) 206 | result = self.adain(result, self.style2(latent_w)) 207 | result = self.lrelu(result) 208 | 209 | return result 210 | 211 | # Very First Convolutional Block 212 | # 5/13: No more downsample, this block is the same sa general ones 213 | # class Early_ConvBlock(nn.Module): 214 | # ''' 215 | # Used to construct progressive discriminator 216 | # ''' 217 | # def __init__(self, in_channel, out_channel, size_kernel, padding): 218 | # super().__init__() 219 | # self.conv = nn.Sequential( 220 | # SConv2d(in_channel, out_channel, size_kernel, padding=padding), 221 | # nn.LeakyReLU(0.2), 222 | # SConv2d(out_channel, out_channel, size_kernel, padding=padding), 223 | # nn.LeakyReLU(0.2) 224 | # ) 225 | 226 | # def forward(self, image): 227 | # result = self.conv(image) 228 | # return result 229 | 230 | # General Convolutional Block 231 | # 5/13: Downsample is now removed from block module 232 | class ConvBlock(nn.Module): 233 | ''' 234 | Used to construct progressive discriminator 235 | ''' 236 | def __init__(self, in_channel, out_channel, size_kernel1, padding1, 237 | size_kernel2 = None, padding2 = None): 238 | super().__init__() 239 | 240 | if size_kernel2 == None: 241 | size_kernel2 = size_kernel1 242 | if padding2 == None: 243 | padding2 = padding1 244 | 245 | self.conv = nn.Sequential( 246 | SConv2d(in_channel, out_channel, size_kernel1, padding=padding1), 247 | nn.LeakyReLU(0.2), 248 | SConv2d(out_channel, out_channel, size_kernel2, padding=padding2), 249 | nn.LeakyReLU(0.2) 250 | ) 251 | 252 | def forward(self, image): 253 | # Downsample now proxyed by discriminator 254 | # result = nn.functional.interpolate(image, scale_factor=0.5, mode="bilinear", align_corners=False) 255 | # Conv 256 | result = self.conv(image) 257 | return result 258 | 259 | 260 | # Main components 261 | class Intermediate_Generator(nn.Module): 262 | ''' 263 | A mapping consists of multiple fully connected layers. 264 | Used to map the input to an intermediate latent space W. 
265 | ''' 266 | def __init__(self, n_fc, dim_latent): 267 | super().__init__() 268 | layers = [PixelNorm()] 269 | for i in range(n_fc): 270 | layers.append(SLinear(dim_latent, dim_latent)) 271 | layers.append(nn.LeakyReLU(0.2)) 272 | 273 | self.mapping = nn.Sequential(*layers) 274 | 275 | def forward(self, latent_z): 276 | latent_w = self.mapping(latent_z) 277 | return latent_w 278 | 279 | # Generator 280 | # 5/13: Support progressive training 281 | # 5/13: Proxy noise generating 282 | # 5/13: Proxy upsampling 283 | class StyleBased_Generator(nn.Module): 284 | ''' 285 | Main Module 286 | ''' 287 | def __init__(self, n_fc, dim_latent, dim_input): 288 | super().__init__() 289 | # Waiting to adjust the size 290 | self.fcs = Intermediate_Generator(n_fc, dim_latent) 291 | self.convs = nn.ModuleList([ 292 | Early_StyleConv_Block(512, dim_latent, dim_input), 293 | StyleConv_Block(512, 512, dim_latent), 294 | StyleConv_Block(512, 512, dim_latent), 295 | StyleConv_Block(512, 512, dim_latent), 296 | StyleConv_Block(512, 256, dim_latent), 297 | StyleConv_Block(256, 128, dim_latent), 298 | StyleConv_Block(128, 64, dim_latent), 299 | StyleConv_Block(64, 32, dim_latent), 300 | StyleConv_Block(32, 16, dim_latent) 301 | ]) 302 | self.to_rgbs = nn.ModuleList([ 303 | SConv2d(512, 3, 1), 304 | SConv2d(512, 3, 1), 305 | SConv2d(512, 3, 1), 306 | SConv2d(512, 3, 1), 307 | SConv2d(256, 3, 1), 308 | SConv2d(128, 3, 1), 309 | SConv2d(64, 3, 1), 310 | SConv2d(32, 3, 1), 311 | SConv2d(16, 3, 1) 312 | ]) 313 | def forward(self, latent_z, 314 | step = 0, # Step means how many layers (count from 4 x 4) are used to train 315 | alpha=-1, # Alpha is the parameter of smooth conversion of resolution): 316 | noise=None, # TODO: support none noise 317 | mix_steps=[], # steps inside will use latent_z[1], else latent_z[0] 318 | latent_w_center=None, # Truncation trick in W 319 | psi=0): # parameter of truncation 320 | if type(latent_z) != type([]): 321 | print('You should use list to package your latent_z') 322 | latent_z = [latent_z] 323 | if (len(latent_z) != 2 and len(mix_steps) > 0) or type(mix_steps) != type([]): 324 | print('Warning: Style mixing disabled, possible reasons:') 325 | print('- Invalid number of latent vectors') 326 | print('- Invalid parameter type: mix_steps') 327 | mix_steps = [] 328 | 329 | latent_w = [self.fcs(latent) for latent in latent_z] 330 | batch_size = latent_w[0].size(0) 331 | 332 | # Truncation trick in W 333 | # Only usable in estimation 334 | if latent_w_center is not None: 335 | latent_w = [latent_w_center + psi * (unscaled_latent_w - latent_w_center) 336 | for unscaled_latent_w in latent_w] 337 | 338 | # Generate needed Gaussian noise 339 | # 5/22: Noise is now generated by outer module 340 | # noise = [] 341 | result = 0 342 | current_latent = 0 343 | # for i in range(step + 1): 344 | # size = 4 * 2 ** i # Due to the upsampling, size of noise will grow 345 | # noise.append(torch.randn((batch_size, 1, size, size), device=torch.device('cuda:0'))) 346 | 347 | for i, conv in enumerate(self.convs): 348 | # Choose current latent_w 349 | if i in mix_steps: 350 | current_latent = latent_w[1] 351 | else: 352 | current_latent = latent_w[0] 353 | 354 | # Not the first layer, need to upsample 355 | if i > 0 and step > 0: 356 | result_upsample = nn.functional.interpolate(result, scale_factor=2, mode='bilinear', 357 | align_corners=False) 358 | result = conv(result_upsample, current_latent, noise[i]) 359 | else: 360 | result = conv(current_latent, noise[i]) 361 | 362 | # Final layer, output rgb image 
363 |             if i == step:
364 |                 result = self.to_rgbs[i](result)
365 | 
366 |                 if i > 0 and 0 <= alpha < 1:
367 |                     result_prev = self.to_rgbs[i - 1](result_upsample)
368 |                     result = alpha * result + (1 - alpha) * result_prev
369 | 
370 |                 # Finish and break
371 |                 break
372 | 
373 |         return result
374 | 
375 |     def center_w(self, zs):
376 |         '''
377 |         To begin, we compute the center of mass of W
378 |         '''
379 |         latent_w_center = self.fcs(zs).mean(0, keepdim=True)
380 |         return latent_w_center
381 | 
382 | 
383 | # Discriminator
384 | # 5/13: Support progressive training
385 | # 5/13: Add downsample module
386 | # Component of Progressive GAN
387 | # Reference: Karras, T., Aila, T., Laine, S., & Lehtinen, J. (2017).
388 | # Progressive Growing of GANs for Improved Quality, Stability, and Variation, 1–26.
389 | # Retrieved from http://arxiv.org/abs/1710.10196
390 | class Discriminator(nn.Module):
391 |     '''
392 |     Main Module
393 |     '''
394 |     def __init__(self):
395 |         super().__init__()
396 |         # Waiting to adjust the size
397 |         self.from_rgbs = nn.ModuleList([
398 |             SConv2d(3, 16, 1),
399 |             SConv2d(3, 32, 1),
400 |             SConv2d(3, 64, 1),
401 |             SConv2d(3, 128, 1),
402 |             SConv2d(3, 256, 1),
403 |             SConv2d(3, 512, 1),
404 |             SConv2d(3, 512, 1),
405 |             SConv2d(3, 512, 1),
406 |             SConv2d(3, 512, 1)
407 |         ])
408 |         self.convs = nn.ModuleList([
409 |             ConvBlock(16, 32, 3, 1),
410 |             ConvBlock(32, 64, 3, 1),
411 |             ConvBlock(64, 128, 3, 1),
412 |             ConvBlock(128, 256, 3, 1),
413 |             ConvBlock(256, 512, 3, 1),
414 |             ConvBlock(512, 512, 3, 1),
415 |             ConvBlock(512, 512, 3, 1),
416 |             ConvBlock(512, 512, 3, 1),
417 |             ConvBlock(513, 512, 3, 1, 4, 0)
418 |         ])
419 |         self.fc = SLinear(512, 1)
420 | 
421 |         self.n_layer = 9  # 9 layers network
422 | 
423 |     def forward(self, image,
424 |                 step=0,     # Step means how many layers (count from 4 x 4) are used to train
425 |                 alpha=-1):  # Alpha is the parameter of smooth conversion of resolution
426 |         for i in range(step, -1, -1):
427 |             # Get the index of current layer
428 |             # Count from the bottom layer (4 * 4)
429 |             layer_index = self.n_layer - i - 1
430 | 
431 |             # First layer, need to use from_rgb to convert to n_channel data
432 |             if i == step:
433 |                 result = self.from_rgbs[layer_index](image)
434 | 
435 |             # Before final layer, do minibatch stddev
436 |             if i == 0:
437 |                 # In dim: [batch, channel(512), 4, 4]
438 |                 res_var = result.var(0, unbiased=False) + 1e-8  # Avoid zero
439 |                 # Out dim: [channel(512), 4, 4]
440 |                 res_std = torch.sqrt(res_var)
441 |                 # Out dim: [channel(512), 4, 4]
442 |                 mean_std = res_std.mean().expand(result.size(0), 1, 4, 4)
443 |                 # Out dim: [1] -> [batch, 1, 4, 4]
444 |                 result = torch.cat([result, mean_std], 1)
445 |                 # Out dim: [batch, 512 + 1, 4, 4]
446 | 
447 |             # Conv
448 |             result = self.convs[layer_index](result)
449 | 
450 |             # Not the final layer
451 |             if i > 0:
452 |                 # Downsample for further usage
453 |                 result = nn.functional.interpolate(result, scale_factor=0.5, mode='bilinear',
454 |                                                    align_corners=False)
455 |                 # Alpha set, combine the result of different layers when input
456 |                 if i == step and 0 <= alpha < 1:
457 |                     result_next = self.from_rgbs[layer_index + 1](image)
458 |                     result_next = nn.functional.interpolate(result_next, scale_factor=0.5,
459 |                                                             mode='bilinear', align_corners=False)
460 | 
461 |                     result = alpha * result + (1 - alpha) * result_next
462 | 
463 |         # Now, result is [batch, channel(512), 1, 1]
464 |         # Convert it into [batch, channel(512)], so the fully-connected layer
465 |         # could process it.
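  |         # (The 4 x 4 map collapses to 1 x 1 because the last ConvBlock's second conv
  |         # uses a 4 x 4 kernel with no padding.)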
466 |         result = result.squeeze(2).squeeze(2)
467 |         result = self.fc(result)
468 |         return result
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Use idle GPUs
2 | # It's better to use the environment variable
3 | # If you want to use multiple GPUs, please
4 | # modify the hyperparameters at the same time
5 | # And make sure your PyTorch version >= 1.0.1
6 | import os
7 | os.environ['CUDA_VISIBLE_DEVICES'] = '3, 2'
  | # Imports required by the rest of this file (CUDA_VISIBLE_DEVICES is set before importing torch)
  | import torch
  | from torch import nn, optim
  | from torch.utils.data import DataLoader
  | from torchvision import datasets, transforms
  | from PIL import Image
  | from tqdm import tqdm
  | from model import StyleBased_Generator, Discriminator
8 | n_gpu = 2
9 | device = torch.device('cuda:0')
10 | 
11 | # Original Learning Rate
12 | learning_rate = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
13 | # For anime only
14 | # learning_rate = {512: 0.0015, 1024: 0.002}
15 | batch_size_1gpu = {4: 128, 8: 128, 16: 64, 32: 32, 64: 16, 128: 16}
16 | mini_batch_size_1 = 8
17 | batch_size = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
18 | mini_batch_size = 8
19 | batch_size_4gpus = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}
20 | mini_batch_size_4 = 16
21 | batch_size_8gpus = {4: 512, 8: 256, 16: 128, 32: 64}
22 | mini_batch_size_8 = 32
23 | # Comment out the line below if you don't run into the 'shared memory conflict' problem
24 | num_workers = {128: 8, 256: 4, 512: 2}
25 | max_workers = 16
26 | n_fc = 8
27 | dim_latent = 512
28 | dim_input = 4
29 | # Number of samples to show before doubling the resolution (also the length of each alpha fade-in)
30 | n_sample = 600_000
31 | # Total number of samples used to train the model
32 | n_sample_total = 10_000_000
33 | DGR = 1
34 | n_show_loss = 360
35 | step = 1  # Train from (8 * 8)
36 | max_step = 7  # Maximum resolution is 4 * 2 ** max_step = 512
37 | style_mixing = []  # Waiting to implement
38 | image_folder_path = './dataset/'
39 | save_folder_path = './results/'
  | # Make sure the output folders exist; imsave and torch.save fail otherwise
  | os.makedirs(save_folder_path, exist_ok=True)
  | os.makedirs('./checkpoint/', exist_ok=True)
40 | 
41 | low_steps = [0, 1, 2]
42 | # style_mixing += low_steps
43 | mid_steps = [3, 4, 5]
44 | # style_mixing += mid_steps
45 | hig_steps = [6, 7, 8]
46 | # style_mixing += hig_steps
47 | 
48 | # Used to continue training from last checkpoint
49 | iteration = 0
50 | startpoint = 0
51 | used_sample = 0
52 | alpha = 0
53 | 
54 | # How to start training?
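  | # If is_continue is True and checkpoint/trained.pth exists, the code at the bottom of this
  | # file restores the model, both optimizers and (step, iteration, startpoint, used_sample,
  | # alpha), so training resumes exactly where it stopped; otherwise it starts from scratch.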
55 | # True for start from saved model
56 | # False for retrain from the very beginning
57 | is_continue = True
58 | d_losses = [float('inf')]
59 | g_losses = [float('inf')]
60 | 
61 | def set_grad_flag(module, flag):
62 |     for p in module.parameters():
63 |         p.requires_grad = flag
64 | 
65 | def reset_LR(optimizer, lr):
66 |     for pam_group in optimizer.param_groups:
67 |         mul = pam_group.get('mul', 1)
68 |         pam_group['lr'] = lr * mul
69 | 
70 | # Build a DataLoader that yields images at the given resolution
71 | def gain_sample(dataset, batch_size, image_size=4):
72 |     transform = transforms.Compose([
73 |         transforms.Resize(image_size),        # Resize to the same size
74 |         transforms.CenterCrop(image_size),    # Crop to get square area
75 |         transforms.RandomHorizontalFlip(),    # Augmentation: increase number of samples
76 |         transforms.ToTensor(),
77 |         transforms.Normalize((0.5, 0.5, 0.5),
78 |                              (0.5, 0.5, 0.5))])
79 | 
80 |     dataset.transform = transform
81 |     loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers.get(image_size, max_workers))
82 | 
83 |     return loader
84 | 
85 | def imsave(tensor, i):
86 |     grid = tensor[0]
87 |     grid.clamp_(-1, 1).add_(1).div_(2)
88 |     # Add 0.5 after normalizing to [0, 255] to round to nearest integer
89 |     ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
90 |     img = Image.fromarray(ndarr)
91 |     img.save(f'{save_folder_path}sample-iter{i}.png')
92 | 
93 | # Train function
94 | def train(generator, discriminator, g_optim, d_optim, dataset, step, iteration=0, startpoint=0, used_sample=0,
95 |           d_losses=[], g_losses=[], alpha=0):
96 | 
97 |     resolution = 4 * 2 ** step
98 | 
99 |     origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)
100 |     data_loader = iter(origin_loader)
101 | 
102 |     reset_LR(g_optim, learning_rate.get(resolution, 0.001))
103 |     reset_LR(d_optim, learning_rate.get(resolution, 0.001))
104 |     progress_bar = tqdm(total=n_sample_total, initial=used_sample)
105 |     # Train
106 |     while used_sample < n_sample_total:
107 |         iteration += 1
108 |         alpha = min(1, alpha + batch_size.get(resolution, mini_batch_size) / n_sample)
109 | 
110 |         if (used_sample - startpoint) > n_sample and step < max_step:
111 |             step += 1
112 |             alpha = 0
113 |             startpoint = used_sample
114 | 
115 |             resolution = 4 * 2 ** step
116 | 
117 |             # Avoid possible memory leak
118 |             del origin_loader
119 |             del data_loader
120 | 
121 |             # Change batch size
122 |             origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)
123 |             data_loader = iter(origin_loader)
124 | 
125 |             reset_LR(g_optim, learning_rate.get(resolution, 0.001))
126 |             reset_LR(d_optim, learning_rate.get(resolution, 0.001))
127 | 
128 | 
129 |         try:
130 |             # Try to read next image
131 |             real_image, label = next(data_loader)
132 | 
133 |         except (OSError, StopIteration):
134 |             # Dataset exhausted, train from the first image
135 |             data_loader = iter(origin_loader)
136 |             real_image, label = next(data_loader)
137 | 
138 |         # Count used sample
139 |         used_sample += real_image.shape[0]
140 |         progress_bar.update(real_image.shape[0])
141 | 
142 |         # Send image to GPU
143 |         real_image = real_image.to(device)
144 | 
145 |         # D Module ---
146 |         # Train discriminator first
147 |         discriminator.zero_grad()
148 |         set_grad_flag(discriminator, True)
149 |         set_grad_flag(generator, False)
150 | 
151 |         # Real image predict & backward
152 |         # We only implement the non-saturating loss with R1 regularization
153 |         real_image.requires_grad = True
154 |         if n_gpu > 1:
155 |             real_predict = nn.parallel.data_parallel(discriminator, (real_image, step, alpha), range(n_gpu))
156 |         else:
157 |             real_predict = discriminator(real_image, step, alpha)
158 |         real_predict = nn.functional.softplus(-real_predict).mean()
159 |         real_predict.backward(retain_graph=True)
160 | 
161 |         grad_real = torch.autograd.grad(outputs=real_predict.sum(), inputs=real_image, create_graph=True)[0]
162 |         grad_penalty_real = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
163 |         grad_penalty_real = 10 / 2 * grad_penalty_real
164 |         grad_penalty_real.backward()
165 | 
166 |         # Generate latent code
167 |         latent_w1 = [torch.randn((batch_size.get(resolution, mini_batch_size), dim_latent), device=device)]
168 |         latent_w2 = [torch.randn((batch_size.get(resolution, mini_batch_size), dim_latent), device=device)]
169 | 
170 |         noise_1 = []
171 |         noise_2 = []
172 |         for m in range(step + 1):
173 |             size = 4 * 2 ** m  # Due to the upsampling, size of noise will grow
174 |             noise_1.append(torch.randn((batch_size.get(resolution, mini_batch_size), 1, size, size), device=device))
175 |             noise_2.append(torch.randn((batch_size.get(resolution, mini_batch_size), 1, size, size), device=device))
176 | 
177 |         # Generate fake image & backward
178 |         if n_gpu > 1:
179 |             fake_image = nn.parallel.data_parallel(generator, (latent_w1, step, alpha, noise_1), range(n_gpu))
180 |             fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))
181 |         else:
182 |             fake_image = generator(latent_w1, step, alpha, noise_1)
183 |             fake_predict = discriminator(fake_image, step, alpha)
184 | 
185 |         fake_predict = nn.functional.softplus(fake_predict).mean()
186 |         fake_predict.backward()
187 | 
188 |         if iteration % n_show_loss == 0:
189 |             d_losses.append((real_predict + fake_predict).item())
190 | 
191 |         # D optimizer step
192 |         d_optim.step()
193 | 
194 |         # Avoid possible memory leak
195 |         del grad_penalty_real, grad_real, fake_predict, real_predict, fake_image, real_image, latent_w1
196 | 
197 |         # G module ---
198 |         if iteration % DGR != 0: continue
199 |         # Train generator (only on every DGR-th iteration)
200 |         generator.zero_grad()
201 |         set_grad_flag(discriminator, False)
202 |         set_grad_flag(generator, True)
203 | 
204 |         if n_gpu > 1:
205 |             fake_image = nn.parallel.data_parallel(generator, (latent_w2, step, alpha, noise_2), range(n_gpu))
206 |             fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))
207 |         else:
208 |             fake_image = generator(latent_w2, step, alpha, noise_2)
209 |             fake_predict = discriminator(fake_image, step, alpha)
210 |         fake_predict = nn.functional.softplus(-fake_predict).mean()
211 |         fake_predict.backward()
212 |         g_optim.step()
213 | 
214 |         if iteration % n_show_loss == 0:
215 |             g_losses.append(fake_predict.item())
216 |             imsave(fake_image.data.cpu(), iteration)
217 | 
218 |         # Avoid possible memory leak
219 |         del fake_predict, fake_image, latent_w2
220 | 
221 |         if iteration % 1000 == 0:
222 |             # Save the model every 1000 iterations
223 |             torch.save({
224 |                 'generator'    : generator.state_dict(),
225 |                 'discriminator': discriminator.state_dict(),
226 |                 'g_optim'      : g_optim.state_dict(),
227 |                 'd_optim'      : d_optim.state_dict(),
228 |                 'parameters'   : (step, iteration, startpoint, used_sample, alpha),
229 |                 'd_losses'     : d_losses,
230 |                 'g_losses'     : g_losses
231 |             }, 'checkpoint/trained.pth')
232 |             print('Model successfully saved.')
233 | 
234 |         progress_bar.set_description(f'Resolution: {resolution}*{resolution} D_Loss: {d_losses[-1]:.4f} G_Loss: {g_losses[-1]:.4f} Alpha: {alpha:.4f}')
235 |     torch.save({
236 |         'generator'    : generator.state_dict(),
237 |         'discriminator': discriminator.state_dict(),
238 |         'g_optim'      : g_optim.state_dict(),
239 |         'd_optim'      : d_optim.state_dict(),
240 |         'parameters'   : (step, iteration, startpoint, used_sample, alpha),
241 |         'd_losses'     : d_losses,
242 |         'g_losses'     : g_losses
243 |     }, 'checkpoint/trained.pth')
244 |     print('Final model successfully saved.')
245 |     return d_losses, g_losses
246 | 
247 | 
248 | # generator = nn.DataParallel(StyleBased_Generator(n_fc, dim_latent, dim_input)).cuda()
249 | # discriminator = nn.DataParallel(Discriminator()).cuda()
250 | # g_optim = optim.Adam([{
251 | #     'params': generator.module.convs.parameters(),
252 | #     'lr'    : 0.001
253 | # }, {
254 | #     'params': generator.module.to_rgbs.parameters(),
255 | #     'lr'    : 0.001
256 | # }], lr=0.001, betas=(0.0, 0.99))
257 | # g_optim.add_param_group({
258 | #     'params': generator.module.fcs.parameters(),
259 | #     'lr'    : 0.001 * 0.01,
260 | #     'mul'   : 0.01
261 | # })
262 | 
263 | # Create models
264 | generator = StyleBased_Generator(n_fc, dim_latent, dim_input).to(device)
265 | discriminator = Discriminator().to(device)
266 | 
267 | # Optimizers
268 | g_optim = optim.Adam([{
269 |     'params': generator.convs.parameters(),
270 |     'lr'    : 0.001
271 | }, {
272 |     'params': generator.to_rgbs.parameters(),
273 |     'lr'    : 0.001
274 | }], lr=0.001, betas=(0.0, 0.99))
275 | g_optim.add_param_group({
276 |     'params': generator.fcs.parameters(),
277 |     'lr'    : 0.001 * 0.01,
278 |     'mul'   : 0.01
279 | })
280 | d_optim = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))
281 | dataset = datasets.ImageFolder(image_folder_path)
282 | 
283 | if is_continue:
284 |     if os.path.exists('checkpoint/trained.pth'):
285 |         # Load data from last checkpoint
286 |         print('Loading pre-trained model...')
287 |         checkpoint = torch.load('checkpoint/trained.pth')
288 |         generator.load_state_dict(checkpoint['generator'])
289 |         discriminator.load_state_dict(checkpoint['discriminator'])
290 |         g_optim.load_state_dict(checkpoint['g_optim'])
291 |         d_optim.load_state_dict(checkpoint['d_optim'])
292 |         step, iteration, startpoint, used_sample, alpha = checkpoint['parameters']
293 |         d_losses = checkpoint.get('d_losses', [float('inf')])
294 |         g_losses = checkpoint.get('g_losses', [float('inf')])
295 |         print('Start training from loaded model...')
296 |     else:
297 |         print('No pre-trained model detected, restart training...')
298 | 
299 | generator.train()
300 | discriminator.train()
301 | d_losses, g_losses = train(generator, discriminator, g_optim, d_optim, dataset, step, iteration, startpoint, used_sample, d_losses, g_losses, alpha)
--------------------------------------------------------------------------------
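For reference, the sketch below shows one way the checkpoint written by `train.py` could be loaded to generate images with the truncation trick implemented in `model.py`. It is a minimal illustration only, not a substitute for the repository's own `evaluate.py`; the sample count, `psi` value, and output file name are arbitrary choices, and it assumes `model.py` is importable from the working directory.

```python
import torch
from torchvision import utils

from model import StyleBased_Generator

n_fc, dim_latent, dim_input = 8, 512, 4
step, alpha = 7, 1  # step 7 -> 4 * 2 ** 7 = 512 x 512 images; alpha = 1 means no fade-in blending
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Restore the trained generator from the checkpoint format used by train.py
generator = StyleBased_Generator(n_fc, dim_latent, dim_input).to(device)
checkpoint = torch.load('checkpoint/trained.pth', map_location=device)
generator.load_state_dict(checkpoint['generator'])
generator.eval()

n_images = 8  # illustrative value
with torch.no_grad():
    # Center of mass of W, used by the truncation trick (psi < 1 trades variety for fidelity)
    latent_w_center = generator.center_w(torch.randn(10_000, dim_latent, device=device))

    # The generator expects latent_z wrapped in a list and one noise map per active layer
    latent_z = [torch.randn(n_images, dim_latent, device=device)]
    noise = [torch.randn(n_images, 1, 4 * 2 ** i, 4 * 2 ** i, device=device)
             for i in range(step + 1)]

    images = generator(latent_z, step=step, alpha=alpha, noise=noise,
                       latent_w_center=latent_w_center, psi=0.7)

# Images are in [-1, 1]; rescale to [0, 1] before saving
utils.save_image((images.clamp(-1, 1) + 1) / 2 .__class__ if False else ((images.clamp(-1, 1) + 1) / 2).cpu(),
                 'preview.png', nrow=4)
```

Smaller `psi` values pull each generated w toward the estimated center of W, which is exactly what the `latent_w_center`/`psi` arguments do inside `StyleBased_Generator.forward`.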