├── .gitignore ├── LICENSE ├── README.md ├── loss └── loss.py ├── network └── stylegan2.py ├── opts └── opts.py ├── torchvision_sunner ├── __init__.py ├── constant.py ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── image_dataset.py │ ├── loader.py │ └── video_dataset.py ├── read.py ├── transforms │ ├── __init__.py │ ├── base.py │ ├── categorical.py │ ├── complex.py │ ├── function.py │ └── simple.py └── utils.py ├── train.py └── utils ├── libs.py ├── stylegan-teaser.png └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 tomguluson92
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A PyTorch Implementation of StyleGAN2 (Unofficial)
2 |
3 | ![Github](https://img.shields.io/badge/PyTorch-v1.0.1-green.svg?style=for-the-badge&logo=data:image/png)
4 | ![Github](https://img.shields.io/badge/python-3.6-green.svg?style=for-the-badge&logo=python)
5 | ![Github](https://img.shields.io/badge/status-AlmostFinished-blue.svg?style=for-the-badge&logo=fire)
6 | ![Github](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=for-the-badge&logo=fire)
7 |
8 | This repository contains a PyTorch implementation of the following paper:
9 | > **Analyzing and Improving the Image Quality of StyleGAN** (StyleGAN2)
10 | > Authors: Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila
11 |
12 | > Paper: http://arxiv.org/abs/1912.04958
13 | > Video: https://youtu.be/c-NJtV9Jvp0
14 |
15 |
16 | ## Motivation
17 | To the best of my knowledge, there is still no comparable PyTorch 1.0 implementation of StyleGAN2 besides the official TensorFlow release by NVlabs,
18 | so I implemented it on PyTorch 1.0.1 to extend its usage in the PyTorch community.
19 |
20 | ## Notice
21 | @date: 2019.12.16
22 |
23 | @info: Settings are in `opts/opts.py`. You can switch to your own dataset and choose a resolution of 64, 128, 256 and so on.
24 |
25 | @date: 2019.12.27
26 |
27 | @info: **Need Help!**
28 | \
29 | After about two weeks of experiments, this version is still hard to converge.
I am pretty confident that
30 | my G & D definitions strictly follow the original [stylegan2](https://github.com/NVlabs/stylegan2).
31 | So if you are willing to help this project converge, please feel free to change it!
32 | **Especially the training paradigm!**
33 |
34 | ## Author
35 |
36 | - [Samuel Ko](https://blog.csdn.net/g11d111)
37 |
38 | ## Training
39 |
40 | ``` python
41 | # ① Set your own training dataset, batch size and common settings in TrainOpts of `opts.py`.
42 |
43 | # ② run train.py
44 | python3 train.py --path `your_own_dataset_path`
45 |
46 | # ③ intermediate images generated by the style generator are saved in `opts.det/images/`
47 | ```
48 |
49 | ## Project
50 | > We follow the released code of StyleGAN2 carefully, so if you find any bug or mistake in the implementation,
51 | > please tell us so we can improve it. Thank you very much!
52 |
53 |
54 | ## Related
55 | [1. StyleGAN - Official TensorFlow Implementation](https://github.com/NVlabs/stylegan)
56 |
57 | [2. The re-implementation of style-based generator idea](https://github.com/SunnerLi/StyleGAN_demo)
58 |
59 | [3. ptrblck_styleGAN](https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb)
60 |
61 | ## System Requirements
62 | - Ubuntu 18.04
63 | - PyTorch 1.0.1
64 | - NumPy 1.13.3
65 | - torchvision 0.2.2
66 | - scikit-image 0.15.0
67 | - tqdm
68 | - GTX 1080Ti or above
69 |
70 | ## Q&A
71 |
72 | ## Acknowledgements
73 |
74 | My email is **samuel.gao023@gmail.com**. If you have any questions or want to open a PR, please let me know. Thank you!
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 |
6 | from torch.autograd import Variable
7 | from torch.autograd import grad
8 | import torch.autograd as autograd
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch
12 |
13 | import numpy as np
14 |
15 |
16 | def D_logistic_r1(real_img, D, gamma=10.0):
17 |     # R1 gradient penalty on real images.
18 |     reals = Variable(real_img, requires_grad=True).to(real_img.device)
19 |     real_logit = D(reals)
20 |
21 |     real_grads = grad(torch.sum(real_logit), reals, create_graph=True)[0]  # create_graph=True keeps the penalty differentiable w.r.t. D's weights
22 |     gradient_penalty = torch.sum(torch.mul(real_grads, real_grads), dim=[1, 2, 3])
23 |     return gradient_penalty * (gamma * 0.5)
24 |
25 |
26 |
27 | # ==============================================================================
28 | # R1 and R2 regularizers from the paper
29 | # "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
30 | # ==============================================================================
31 |
32 | # def D_logistic_r1(fake_img, real_img, D, gamma=10.0):
33 | #     real_img = Variable(real_img, requires_grad=True).to(real_img.device)
34 | #     fake_img = Variable(fake_img, requires_grad=True).to(fake_img.device)
35 | #
36 | #     real_score = D(real_img)
37 | #     fake_score = D(fake_img)
38 | #
39 | #     loss = F.softplus(fake_score)
40 | #     loss = loss + F.softplus(-real_score)
41 | #
42 | #     # GradientPenalty
43 | #     # One of the differentiated Tensors does not require grad?
44 | #     # https://discuss.pytorch.org/t/one-of-the-differentiated-tensors-does-not-require-grad/54694
45 | #     real_grads = grad(torch.sum(real_score), real_img)[0]
46 | #     gradient_penalty = torch.sum(torch.mul(real_grads, real_grads), dim=[1, 2, 3])
47 | #     reg = gradient_penalty * (gamma * 0.5)
48 | #
49 | #     # fixme: only support non-lazy mode
50 | #     return loss + reg
51 |
52 |
53 | def D_logistic_r2(fake_img, real_img, D, gamma=10.0):
54 |     real_img = Variable(real_img, requires_grad=True).to(real_img.device)
55 |     fake_img = Variable(fake_img, requires_grad=True).to(fake_img.device)
56 |
57 |     real_score = D(real_img)
58 |     fake_score = D(fake_img)
59 |
60 |     loss = F.softplus(fake_score)
61 |     loss = loss + F.softplus(-real_score)
62 |
63 |     # R2 gradient penalty (on fake images).
64 |     # One of the differentiated Tensors does not require grad?
65 |     # https://discuss.pytorch.org/t/one-of-the-differentiated-tensors-does-not-require-grad/54694
66 |     fake_grads = grad(torch.sum(fake_score), fake_img, create_graph=True)[0]  # create_graph=True so the penalty backpropagates into D
67 |     gradient_penalty = torch.sum(torch.mul(fake_grads, fake_grads), dim=[1, 2, 3])  # torch.mul matches D_logistic_r1 and works on torch < 1.5 (torch.square does not)
68 |     reg = gradient_penalty * (gamma * 0.5)
69 |
70 |     # fixme: only support non-lazy mode
71 |     return loss + reg
72 |
73 |
74 | # ==============================================================================
75 | # Non-saturating logistic loss with path length regularizer from the paper
76 | # "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
77 | # ==============================================================================
78 |
79 |
80 | def G_logistic_ns_pathreg(x, D, opts, pl_decay=0.01, pl_weight=2.0):
81 |
82 |     fake_images_out, fake_dlatents_out = x
83 |
84 |     fake_images_out = Variable(fake_images_out, requires_grad=True).to(fake_images_out.device)
85 |     fake_scores_out = D(fake_images_out)
86 |     loss = F.softplus(-fake_scores_out)
87 |
88 |     fake_dlatents_out = Variable(fake_dlatents_out, requires_grad=True).to(fake_dlatents_out.device)
89 |     # Compute |J*y|.
90 |     pl_noise = torch.randn(fake_images_out.shape) / np.sqrt(fake_images_out.shape[2] * fake_images_out.shape[3])
91 |     pl_noise = pl_noise.to(fake_images_out.device)
92 |     pl_grads = grad(torch.sum(fake_images_out * pl_noise), fake_dlatents_out, create_graph=True, retain_graph=True)[0]  # create_graph=True so the path-length term is trainable
93 |     pl_lengths = torch.sqrt(torch.sum(torch.sum(torch.mul(pl_grads, pl_grads), dim=2), dim=1))
94 |     pl_mean = pl_decay * torch.mean(pl_lengths)  # note: the official repo tracks pl_mean as an exponential moving average across iterations; this single-batch value only approximates it
95 |
96 |     # Calculate (|J*y|-a)^2.
97 |     # Computes square of x element-wise
98 |     # https://discuss.pytorch.org/t/computes-square-of-x-element-wise/9079
99 |     pl_penalty = torch.mul(pl_lengths - pl_mean, pl_lengths - pl_mean)
100 |
101 |     # Apply weight.
102 |     # Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
103 |     # in pl_lengths decreases it by num_affine_layers.
The effective weight then becomes: 104 | # 105 | # gamma_pl = pl_weight / num_pixels / num_affine_layers 106 | # = 2 / (r^2) / (log2(r) * 2 - 2) 107 | # = 1 / (r^2 * (log2(r) - 1)) 108 | # = ln(2) / (r^2 * (ln(r) - ln(2)) 109 | # 110 | reg = pl_penalty * pl_weight 111 | 112 | # fixme: only support non-lazy mode 113 | return loss + reg 114 | -------------------------------------------------------------------------------- /network/stylegan2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | StyleGAN2 pytorch 5 | 6 | @author: samuel ko 7 | @date: 2019.12.13 8 | 9 | @notice: 1) fused_conv: unsupport 10 | 2) 4x4 upfirdn kernel: (transfer to 3x3 upfirdn kernel) 11 | 12 | @date: 2019.12.18 13 | 14 | @update: 1) fix upfirdn2d [1, 3, 1] to original [1, 3, 3, 1] in D_stylegan2 and Upsample2d. 15 | 2) use_wscale = True (default), gain = 1 (default). 16 | 3) update he_std calculation method to coordinate with the original repo. 17 | 18 | 19 | @date: 2019.12.20 20 | 21 | @update: 1) split ModulatedConv2d into 2 part. 22 | 23 | @date: 2020.01.02 24 | 25 | @update: 1) refact initialization part. 26 | @ref: https://stackoverflow.com/questions/51136581/how-to-create-a-normal-distribution-in-pytorch 27 | """ 28 | 29 | import os 30 | import torch.nn.functional as F 31 | import torch.nn as nn 32 | import numpy as np 33 | import torch 34 | from torch.nn import ModuleList 35 | 36 | from utils.libs import _setup_kernel, _approximate_size 37 | from opts.opts import TrainOptions, INFO 38 | 39 | import copy 40 | from tqdm import tqdm 41 | 42 | from torchvision.utils import save_image 43 | from matplotlib import pyplot as plt 44 | from utils.utils import plotLossCurve 45 | 46 | 47 | # from utils.libs import ShrinkFun 48 | 49 | # shrink_fun = ShrinkFun.apply 50 | 51 | 52 | # ========================================================================= 53 | # Define components for G_mapping & G_synthesis_stylegan2 & D_stylegan2 54 | # ========================================================================= 55 | 56 | class PixelNorm(nn.Module): 57 | def __init__(self, epsilon=1e-8): 58 | """ 59 | @notice: avoid in-place ops. 60 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 61 | """ 62 | super(PixelNorm, self).__init__() 63 | self.epsilon = epsilon 64 | 65 | def forward(self, x): 66 | tmp = torch.mul(x, x) # or x ** 2 67 | tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon) 68 | 69 | return x * tmp1 70 | 71 | 72 | class BiasAdd(nn.Module): 73 | 74 | def __init__(self, 75 | channels, 76 | opts, 77 | act='linear', alpha=None, gain=None, lrmul=1): 78 | """ 79 | BiasAdd 80 | """ 81 | super(BiasAdd, self).__init__() 82 | 83 | self.opts = opts 84 | self.bias = torch.nn.Parameter((torch.zeros(channels, 1, 1) * lrmul)) 85 | 86 | self.act = act 87 | self.alpha = alpha if alpha is not None else 0.2 88 | self.gain = gain if gain is not None else 1.0 89 | 90 | def forward(self, x): 91 | # Pass Add bias. 92 | x += self.bias 93 | 94 | # Evaluate activation function. 95 | if self.act == "linear": 96 | pass 97 | elif self.act == 'lrelu': 98 | x = F.leaky_relu(x, self.alpha, inplace=True) 99 | x = x * np.sqrt(2) # original repo def_gain=np.sqrt(2). 100 | 101 | # Scale by gain. 
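        # (Added note, a sketch not in the original file: with the default
        # gain=1.0 the branch below is a no-op, while gain=np.sqrt(2) would
        # rescale activations by ~1.414. The bias above has shape (C, 1, 1),
        # so it broadcasts over an (N, C, H, W) input; with act='lrelu' an
        # input value of -1.0 maps to -0.2 * sqrt(2) ≈ -0.283.)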
102 | if self.gain != 1: 103 | x = x * self.gain 104 | 105 | return x 106 | 107 | 108 | class FC(nn.Module): 109 | def __init__(self, 110 | in_channels, 111 | out_channels, 112 | gain=1, 113 | use_wscale=True, 114 | lrmul=1.0, 115 | bias=True, 116 | act='lrelu', 117 | mode='normal'): 118 | """ 119 | The complete conversion of Dense/FC/Linear Layer of original Tensorflow version. 120 | """ 121 | super(FC, self).__init__() 122 | he_std = gain / np.sqrt((in_channels * out_channels)) # He init 123 | if use_wscale: 124 | init_std = 1.0 / lrmul 125 | self.w_lrmul = he_std * lrmul 126 | else: 127 | init_std = he_std / lrmul 128 | self.w_lrmul = lrmul 129 | 130 | self.weight = torch.nn.Parameter(torch.empty(out_channels, in_channels).normal_(0, init_std)) 131 | if bias: 132 | self.bias = torch.nn.Parameter(torch.zeros(out_channels)) 133 | self.b_lrmul = lrmul 134 | else: 135 | self.bias = None 136 | 137 | self.act = act 138 | self.mode = mode 139 | 140 | def forward(self, x): 141 | if self.bias is not None and self.mode != 'modulate': 142 | out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) 143 | elif self.bias is not None and self.mode == 'modulate': 144 | # original 145 | # out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) + 1 146 | # re-implement 147 | out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) 148 | else: 149 | out = F.linear(x, self.weight * self.w_lrmul) 150 | 151 | if self.act == 'lrelu': 152 | out = F.leaky_relu(out, 0.2, inplace=True) 153 | out = out * np.sqrt(2) # original repo def_gain=np.sqrt(2). 154 | return out 155 | elif self.act == 'linear': 156 | return out 157 | 158 | return out 159 | 160 | 161 | class Conv2d(nn.Module): 162 | def __init__(self, 163 | input_channels, 164 | output_channels, 165 | kernel_size, 166 | gain=1, 167 | use_wscale=True, 168 | lrmul=1, 169 | bias=True, 170 | act='linear'): 171 | super().__init__() 172 | 173 | assert kernel_size >= 1 and kernel_size % 2 == 1 174 | he_std = gain / np.sqrt((input_channels * output_channels * kernel_size * kernel_size)) # He init 175 | self.kernel_size = kernel_size 176 | self.act = act 177 | 178 | if use_wscale: 179 | init_std = 1.0 / lrmul 180 | self.w_lrmul = he_std * lrmul 181 | else: 182 | init_std = he_std / lrmul 183 | self.w_lrmul = lrmul 184 | 185 | self.weight = torch.nn.Parameter(torch.empty(output_channels, input_channels, kernel_size, kernel_size).normal_(0, init_std)) 186 | if bias: 187 | self.bias = torch.nn.Parameter(torch.zeros(output_channels)) 188 | self.b_lrmul = lrmul 189 | else: 190 | self.bias = None 191 | 192 | def forward(self, x): 193 | if self.bias is not None: 194 | out = F.conv2d(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul, padding=self.kernel_size // 2) 195 | else: 196 | out = F.conv2d(x, self.weight * self.w_lrmul, padding=self.kernel_size // 2) 197 | 198 | if self.act == 'lrelu': 199 | out = F.leaky_relu(out, 0.2, inplace=True) 200 | out = out * np.sqrt(2) # original repo def_gain=np.sqrt(2). 201 | return out 202 | elif self.act == 'linear': 203 | return out 204 | 205 | 206 | class FromRGB(nn.Module): 207 | """ 208 | default non_linearity: LeakyReLU(0.2), def_gain= np.sqrt(2). 
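    FromRGB projects an RGB input `y` to `output_channels` feature maps with a
    1x1 conv, applies LeakyReLU(0.2) scaled by sqrt(2), and adds the existing
    feature branch `x` when one is present.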
209 | """ 210 | 211 | def __init__(self, input_channels, output_channels, use_wscale=True, lrmul=1): 212 | super().__init__() 213 | self.conv = Conv2d(input_channels=input_channels, 214 | output_channels=output_channels, 215 | kernel_size=1, use_wscale=use_wscale, lrmul=lrmul) 216 | 217 | def forward(self, x): 218 | x, y = x 219 | y1 = self.conv(y) 220 | out = F.leaky_relu(y1, 0.2, inplace=True) 221 | out = out * np.sqrt(2) # original repo def_gain=np.sqrt(2). 222 | return out if x is None else out + x 223 | 224 | 225 | class ModulatedConv2d(nn.Module): 226 | """ 227 | Modulated convolution layer for G_synthesis_stylegan2. 228 | 229 | @date: 2019.12.19 230 | @update: 1) initialization update (He init). 231 | 2) refact ModulatedConv2d. 232 | """ 233 | 234 | def __init__(self, input_channels, output_channels, 235 | kernel_size, 236 | opts, 237 | k=[1, 3, 3, 1], 238 | dlatent_size=512, 239 | up=False, 240 | down=False, 241 | demodulate=True, 242 | gain=1, 243 | use_wscale=True, 244 | lrmul=1, 245 | fused_modconv=True): 246 | super().__init__() 247 | assert kernel_size >= 1 and kernel_size % 2 == 1 248 | 249 | self.demodulate = demodulate 250 | self.fused_modconv = fused_modconv 251 | self.up, self.down = up, down 252 | self.fmaps = output_channels 253 | self.opts = opts 254 | 255 | self.conv = Conv2d(input_channels=input_channels, 256 | output_channels=output_channels, 257 | kernel_size=1, use_wscale=use_wscale, lrmul=lrmul) 258 | 259 | he_std = gain / np.sqrt((input_channels * output_channels * kernel_size * kernel_size)) # He init 260 | self.kernel_size = kernel_size 261 | if use_wscale: 262 | init_std = 1.0 / lrmul 263 | self.w_lrmul = he_std * lrmul 264 | else: 265 | init_std = he_std / lrmul 266 | self.w_lrmul = lrmul 267 | 268 | self.w = torch.nn.Parameter(torch.empty(output_channels, input_channels, kernel_size, kernel_size).normal_(0, init_std)) 269 | self.convH, self.convW = self.w.shape[2:] 270 | 271 | self.dense = FC(dlatent_size, input_channels, gain, lrmul=lrmul, use_wscale=use_wscale, mode='modulate', 272 | act='linear') 273 | 274 | if self.up: 275 | factor = 2 276 | self.k = _setup_kernel(k) * (gain * (factor ** 2)) # 4 x 4 277 | self.k = torch.FloatTensor(self.k).unsqueeze(0).unsqueeze(0) 278 | self.k = torch.flip(self.k, [2, 3]) 279 | self.k = torch.nn.Parameter(self.k, requires_grad=False) 280 | 281 | self.p = self.k.shape[0] - factor - (kernel_size - 1) 282 | 283 | self.padx0, self.pady0 = (self.p + 1) // 2 + factor - 1, (self.p + 1) // 2 + factor - 1 284 | self.padx1, self.pady1 = self.p // 2 + 1, self.p // 2 + 1 285 | 286 | self.kernelH, self.kernelW = self.k.shape[2:] 287 | 288 | def forward(self, x): 289 | x, y = x 290 | if len(y.shape) > 2: 291 | # y is dlatent in ToRGB. 292 | y = y.squeeze(1) 293 | # x Input: N, C, H, W (NxCx4x4, NxCx8x8, NxCx16x16, ...) 294 | # y Input: Disentangled latents(W) [minibatch, 1, dlatent_size]. 295 | 296 | # Modulate. 297 | s = self.dense(y) # [BI] Transform incoming W to style. 298 | 299 | # OIkk ---> BkkOI ---> BkkIO 300 | self.ww = (self.w * self.w_lrmul).unsqueeze(0) 301 | self.ww = self.ww.repeat(s.shape[0], 1, 1, 1, 1) 302 | self.ww = self.ww.permute(0, 3, 4, 1, 2) 303 | self.ww = self.ww * s.unsqueeze(1).unsqueeze(1).unsqueeze(1) # [BkkOI] Scale input feature maps. 304 | self.ww = self.ww.permute(0, 1, 2, 4, 3) # [BkkIO] 305 | 306 | # Demodulate. 307 | if self.demodulate: 308 | d = torch.mul(self.ww, self.ww) 309 | d = torch.rsqrt(torch.sum(d, dim=[1, 2, 3]) + 1e-8) # [BO] Scaling factor. 
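            # In equation form (Karras et al. 2019, Eq. 3): for batch item b and
            # output map o, d[b, o] = 1 / sqrt(sum over kx, ky, i of ww[b, kx, ky, i, o]^2 + eps),
            # so the multiply below rescales each output feature map back to unit
            # expected magnitude, replacing StyleGAN1's explicit instance normalization.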
310 | self.ww = self.ww * (d.unsqueeze(1).unsqueeze(1).unsqueeze(1)) # [BkkIO] Scale output feature maps. 311 | 312 | # Reshape/scale input. 313 | if self.fused_modconv: 314 | x = x.view(1, -1, x.shape[2], x.shape[3]) # Fused => reshape minibatch to convolution groups. 315 | self.w_new = torch.reshape(self.ww.permute(0, 4, 3, 1, 2), 316 | (-1, x.shape[1], self.ww.shape[1], self.ww.shape[2])) 317 | else: 318 | x = x * (s.unsqueeze(-1).unsqueeze(-1)) # [BIhw] Not fused => scale input activations. 319 | self.w_new = self.w * self.w_lrmul 320 | 321 | # Convolution with optional up/downsampling. 322 | if self.up: 323 | outC, inC, convH, convW = self.w_new.shape[0], self.w_new.shape[1], self.w_new.shape[2], self.w_new.shape[3] 324 | 325 | # Transpose Weight 326 | num_groups = x.shape[1] // inC if (x.shape[1] // inC) >= 1 else 1 327 | self.w_new = self.w_new.reshape(-1, num_groups, inC, convH, convW) 328 | self.w_new = self.w_new.flip([3, 4]) 329 | self.w_new = self.w_new.permute(2, 1, 0, 3, 4) 330 | self.w_new = self.w_new.reshape(inC, outC, convH, convW) 331 | 332 | x = F.conv_transpose2d(x, self.w_new, stride=2) 333 | 334 | # step 2: upfirdn2d 335 | y = x.clone() 336 | y = y.reshape([-1, x.shape[2], x.shape[3], 1]) # N C H W ---> N*C H W 1 337 | 338 | inC, inH, inW = x.shape[1:] 339 | # 1) Upsample 340 | y = y.reshape(-1, inH, inW, 1) 341 | 342 | # 2) Pad (crop if negative). 343 | y = F.pad(y, (0, 0, 344 | max(self.pady0, 0), max(self.pady1, 0), 345 | max(self.padx0, 0), max(self.padx1, 0), 346 | 0, 0 347 | )) 348 | y = y[:, 349 | max(-self.pady0, 0): y.shape[1] - max(-self.pady1, 0), 350 | max(-self.padx0, 0): y.shape[2] - max(-self.padx1, 0), 351 | :] 352 | 353 | # 3) Convolve with filter. 354 | y = y.permute(0, 3, 1, 2) # N*C H W 1 --> N*C 1 H W 355 | y = y.reshape(-1, 1, inH + self.pady0 + self.pady1, inW + self.padx0 + self.padx1) 356 | y = F.conv2d(y, self.k) 357 | y = y.view(-1, 1, inH + self.pady0 + self.pady1 - self.kernelH + 1, 358 | inW + self.padx0 + self.padx1 - self.kernelW + 1) 359 | 360 | # 4) Downsample (throw away pixels). 361 | if inH != y.shape[1] or inH % 2 != 0: 362 | inH = inW = _approximate_size(inH) 363 | y = F.interpolate(y, size=(inH, inW), mode='bilinear') 364 | y = y.permute(0, 2, 3, 1) 365 | x = y.reshape(-1, inC, inH, inW) 366 | 367 | elif self.down: 368 | pass 369 | else: 370 | x = F.conv2d(x, 371 | self.w_new, 372 | padding=self.w_new.shape[2] // 2) 373 | # Reshape/scale output. 374 | if self.fused_modconv: 375 | x = x.reshape(-1, self.fmaps, x.shape[2], x.shape[3]) # Fused => reshape convolution groups back to minibatch. 376 | elif self.demodulate: 377 | x = x * d.unsqueeze(-1).unsqueeze(-1) # [BOhw] Not fused => scale output activations. 378 | 379 | return x 380 | 381 | 382 | class ToRGB(nn.Module): 383 | """ 384 | default non_linearity: LeakyReLU(0.2), def_gain= np.sqrt(2). 385 | 386 | 2019.12.18 fix. 
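    ToRGB maps the current feature maps to `output_channels` image channels via
    a 1x1 modulated conv with demodulate=False (as in the paper), adds a bias,
    and sums the result with the incoming skip branch `y` when it exists.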
387 | """ 388 | 389 | def __init__(self, input_channels, output_channels, 390 | res, 391 | opts, 392 | use_wscale=True, 393 | lrmul=1, 394 | gain=1, 395 | fused_modconv=True): 396 | super().__init__() 397 | assert res >= 2 398 | 399 | self.modulated_conv2d = ModulatedConv2d(input_channels=input_channels, 400 | output_channels=output_channels, 401 | kernel_size=1, 402 | up=False, 403 | use_wscale=use_wscale, 404 | lrmul=lrmul, 405 | gain=gain, 406 | demodulate=False, 407 | fused_modconv=fused_modconv, 408 | opts=opts) 409 | self.biasAdd = BiasAdd(opts=opts, 410 | act='linear', 411 | channels=output_channels) 412 | 413 | self.res = res 414 | self.opts = opts 415 | 416 | def forward(self, x): 417 | x, y, dlatent = x 418 | dlatent = dlatent[:, self.res * 2 - 3] 419 | 420 | x = self.modulated_conv2d([x, dlatent]) 421 | t = self.biasAdd(x) 422 | 423 | return t if y is None else y + t 424 | 425 | 426 | class GLayer(nn.Module): 427 | """ 428 | GLayer. 429 | """ 430 | 431 | def __init__(self, input_channels, output_channels, 432 | layer_idx, 433 | opts, 434 | k=[1, 3, 3, 1], 435 | randomize_noise=True, 436 | up=False, 437 | use_wscale=True, 438 | lrmul=1, 439 | fused_modconv=True, 440 | act='lrelu'): 441 | super().__init__() 442 | 443 | self.randomize_noise = randomize_noise 444 | self.opts = opts 445 | self.up = up 446 | self.layer_idx = layer_idx 447 | 448 | self.modulated_conv2d = ModulatedConv2d(input_channels=input_channels, 449 | output_channels=output_channels, 450 | kernel_size=3, 451 | k=k, 452 | use_wscale=use_wscale, 453 | lrmul=lrmul, 454 | demodulate=True, 455 | fused_modconv=fused_modconv, 456 | up=up, 457 | opts=opts) 458 | 459 | # fixme: when you calling .to() on the parameter, it means that you are creating a non-leaf variable! 460 | self.noise_strength = torch.nn.Parameter(torch.zeros(1)) 461 | self.biasAdd = BiasAdd(act=act, 462 | channels=output_channels, 463 | opts=opts) 464 | 465 | def forward(self, x): 466 | x, dlatent = x 467 | if len(dlatent.shape) > 2: 468 | dlatent = dlatent[:, self.layer_idx] 469 | 470 | x = self.modulated_conv2d([x, dlatent]) 471 | 472 | noise = 0 473 | if self.randomize_noise: 474 | noise = torch.randn(x.shape[0], 1, x.shape[2], x.shape[3]).to(self.opts.device) 475 | 476 | x = x + noise * self.noise_strength 477 | x = self.biasAdd(x) 478 | 479 | return x 480 | 481 | 482 | class Upsample2d(nn.Module): 483 | def __init__(self, 484 | opts, 485 | k=[1, 3, 3, 1], 486 | factor=2, 487 | down=1, 488 | gain=1): 489 | """ 490 | Upsample2d method in G_synthesis_stylegan2. 491 | :param k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). 492 | The default is `[1] * factor`, which corresponds to average pooling. 493 | :param factor: Integer downsampling factor (default: 2). 494 | :param gain: Scaling factor for signal magnitude (default: 1.0). 495 | 496 | Returns: Tensor of the shape `[N, C, H // factor, W // factor]` 497 | """ 498 | super().__init__() 499 | assert isinstance(factor, int) and factor >= 1, "factor must be larger than 1! 
(default: 2)" 500 | 501 | self.gain = gain 502 | self.factor = factor 503 | self.opts = opts 504 | 505 | self.k = _setup_kernel(k) * (self.gain * (factor ** 2)) # 4 x 4 506 | self.k = torch.FloatTensor(self.k).unsqueeze(0).unsqueeze(0) 507 | self.k = torch.flip(self.k, [2, 3]) 508 | self.k = nn.Parameter(self.k, requires_grad=False) 509 | 510 | self.p = self.k.shape[0] - self.factor 511 | 512 | self.padx0, self.pady0 = (self.p + 1) // 2 + factor - 1, (self.p + 1) // 2 + factor - 1 513 | self.padx1, self.pady1 = self.p // 2, self.p // 2 514 | 515 | self.kernelH, self.kernelW = self.k.shape[2:] 516 | self.down = down 517 | 518 | def forward(self, x): 519 | y = x.clone() 520 | y = y.reshape([-1, x.shape[2], x.shape[3], 1]) # N C H W ---> N*C H W 1 521 | 522 | inC, inH, inW = x.shape[1:] 523 | # step 1: upfirdn2d 524 | 525 | # 1) Upsample 526 | y = torch.reshape(y, (-1, inH, 1, inW, 1, 1)) 527 | y = F.pad(y, (0, 0, self.factor - 1, 0, 0, 0, self.factor - 1, 0, 0, 0, 0, 0)) 528 | y = torch.reshape(y, (-1, 1, inH * self.factor, inW * self.factor)) 529 | 530 | # 2) Pad (crop if negative). 531 | y = F.pad(y, (0, 0, 532 | max(self.pady0, 0), max(self.pady1, 0), 533 | max(self.padx0, 0), max(self.padx1, 0), 534 | 0, 0 535 | )) 536 | y = y[:, 537 | max(-self.pady0, 0): y.shape[1] - max(-self.pady1, 0), 538 | max(-self.padx0, 0): y.shape[2] - max(-self.padx1, 0), 539 | :] 540 | 541 | # 3) Convolve with filter. 542 | y = y.permute(0, 3, 1, 2) # N*C H W 1 --> N*C 1 H W 543 | y = y.reshape(-1, 1, inH * self.factor + self.pady0 + self.pady1, inW * self.factor + self.padx0 + self.padx1) 544 | y = F.conv2d(y, self.k) 545 | y = y.view(-1, 1, 546 | inH * self.factor + self.pady0 + self.pady1 - self.kernelH + 1, 547 | inW * self.factor + self.padx0 + self.padx1 - self.kernelW + 1) 548 | 549 | # 4) Downsample (throw away pixels). 550 | if inH * self.factor != y.shape[1]: 551 | y = F.interpolate(y, size=(inH * self.factor, inW * self.factor), mode='bilinear') 552 | y = y.permute(0, 2, 3, 1) 553 | y = y.reshape(-1, inC, inH * self.factor, inW * self.factor) 554 | 555 | return y 556 | 557 | 558 | class ConvDownsample2d(nn.Module): 559 | def __init__(self, 560 | kernel_size, 561 | input_channels, 562 | output_channels, 563 | k=[1, 3, 3, 1], 564 | factor=2, 565 | gain=1, 566 | use_wscale=True, 567 | lrmul=1, 568 | bias=False, 569 | act='linear'): 570 | """ 571 | ConvDownsample2D method in D_stylegan2. 572 | :param k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). 573 | The default is `[1] * factor`, which corresponds to average pooling. 574 | :param factor: Integer downsampling factor (default: 2). 575 | :param gain: Scaling factor for signal magnitude (default: 1.0). 576 | 577 | Returns: Tensor of the shape `[N, C, H // factor, W // factor]` 578 | """ 579 | super().__init__() 580 | assert isinstance(factor, int) and factor >= 1, "factor must be larger than 1! 
(default: 2)" 581 | assert kernel_size >= 1 and kernel_size % 2 == 1 582 | 583 | he_std = gain / np.sqrt((input_channels * output_channels * kernel_size * kernel_size)) # He init 584 | self.kernel_size = kernel_size 585 | if use_wscale: 586 | init_std = 1.0 / lrmul 587 | self.w_lrmul = he_std * lrmul 588 | else: 589 | init_std = he_std / lrmul 590 | self.w_lrmul = lrmul 591 | 592 | # https://discuss.pytorch.org/t/gaussian-distribution/35031 593 | self.weight = torch.nn.Parameter(torch.empty(output_channels, input_channels, kernel_size, kernel_size).normal_(0, init_std)) 594 | self.convH, self.convW = self.weight.shape[2:] 595 | 596 | if bias: 597 | self.bias = torch.nn.Parameter(torch.zeros(output_channels)) 598 | self.b_lrmul = lrmul 599 | else: 600 | self.bias = None 601 | 602 | self.gain = gain 603 | self.factor = factor 604 | self.act = act 605 | 606 | self.k = _setup_kernel(k) * self.gain # 3 x 3. (original 4 x 4). 607 | self.k = torch.FloatTensor(self.k).unsqueeze(0).unsqueeze(0) 608 | self.k = torch.flip(self.k, [2, 3]) 609 | self.k = nn.Parameter(self.k, requires_grad=False) 610 | 611 | self.p = (self.k.shape[-1] - self.factor) + (self.convW - 1) 612 | 613 | self.padx0, self.pady0 = (self.p + 1) // 2, (self.p + 1) // 2 614 | self.padx1, self.pady1 = self.p // 2, self.p // 2 615 | 616 | self.kernelH, self.kernelW = self.k.shape[2:] 617 | 618 | def forward(self, x): 619 | 620 | y = x.clone() 621 | y = y.reshape([-1, x.shape[2], x.shape[3], 1]) # N C H W ---> N*C H W 1 622 | 623 | inC, inH, inW = x.shape[1:] 624 | # step 1: upfirdn2d 625 | # 1) Upsample 626 | y = torch.reshape(y, (-1, inH, inW, 1)) 627 | 628 | # 2) Pad (crop if negative). 629 | y = F.pad(y, (0, 0, 630 | max(self.pady0, 0), max(self.pady1, 0), 631 | max(self.padx0, 0), max(self.padx1, 0), 632 | 0, 0 633 | )) 634 | y = y[:, 635 | max(-self.pady0, 0): y.shape[1] - max(-self.pady1, 0), 636 | max(-self.padx0, 0): y.shape[2] - max(-self.padx1, 0), 637 | :] 638 | 639 | # 3) Convolve with filter. 640 | y = y.permute(0, 3, 1, 2) # N*C H W 1 --> N*C 1 H W 641 | y = y.reshape(-1, 1, inH + self.pady0 + self.pady1, inW + self.padx0 + self.padx1) 642 | y = F.conv2d(y, self.k) 643 | y = y.view(-1, 1, inH + self.pady0 + self.pady1 - self.kernelH + 1, 644 | inW + self.padx0 + self.padx1 - self.kernelW + 1) 645 | 646 | # 4) Downsample (throw away pixels). 647 | if inH != y.shape[1]: 648 | y = F.interpolate(y, size=(inH, inW), mode='bilinear') 649 | y = y.permute(0, 2, 3, 1) 650 | y = y.reshape(-1, inC, inH, inW) 651 | 652 | # step 2: downsample (in general, stride = self.factor = 2) 653 | if self.bias is not None: 654 | x1 = F.conv2d(y, 655 | self.weight * self.w_lrmul, 656 | self.bias * self.b_lrmul, 657 | stride=self.factor, 658 | padding=self.convW // 2) 659 | else: 660 | x1 = F.conv2d(y, 661 | self.weight * self.w_lrmul, 662 | stride=self.factor, 663 | padding=self.convW // 2) 664 | # step 3: non-linearity. 665 | if self.act == 'lrelu': 666 | out = F.leaky_relu(x1, 0.2, inplace=True) 667 | out = out * np.sqrt(2) # original repo def_gain=np.sqrt(2). 668 | else: 669 | out = x1 670 | 671 | return out 672 | 673 | 674 | class GBlock(nn.Module): 675 | """ 676 | G_stylegan2 Basic Block. 
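    Each block consumes two consecutive dlatent layers: conv0up is a modulated
    3x3 conv with up=True (2x upsampling), conv1 keeps the size, so the block
    maps (N, input_channels, H, W) -> (N, output_channels, 2H, 2W).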
677 | """ 678 | 679 | def __init__(self, 680 | input_channels, 681 | output_channels, 682 | layer_idx, 683 | opts, 684 | k=[1, 3, 3, 1], 685 | use_wscale=True, 686 | lrmul=1, 687 | architecture='skip'): 688 | super().__init__() 689 | 690 | self.arch = architecture 691 | 692 | self.conv0up = GLayer(input_channels, output_channels, 693 | layer_idx, 694 | up=True, 695 | k=k, 696 | use_wscale=use_wscale, 697 | lrmul=lrmul, 698 | opts=opts) 699 | self.conv1 = GLayer(output_channels, output_channels, 700 | layer_idx + 1, 701 | up=False, 702 | k=k, 703 | use_wscale=use_wscale, 704 | lrmul=lrmul, 705 | opts=opts) 706 | 707 | def forward(self, x): 708 | x, dlatent = x 709 | x = self.conv0up([x, dlatent]) 710 | x = self.conv1([x, dlatent]) 711 | 712 | if self.arch == 'resnet': 713 | raise Exception("unsupported resnet architecture yet~") 714 | 715 | return x 716 | 717 | 718 | class DBlock(nn.Module): 719 | """ 720 | D_stylegan2 Basic Block. 721 | """ 722 | 723 | def __init__(self, in1, in2, out3, 724 | use_wscale=True, 725 | lrmul=1, 726 | resample_kernel=[1, 3, 3, 1], 727 | architecture='resnet'): 728 | super().__init__() 729 | 730 | self.arch = architecture 731 | 732 | self.conv0 = Conv2d(input_channels=in1, 733 | output_channels=in2, 734 | kernel_size=3, 735 | use_wscale=use_wscale, 736 | lrmul=lrmul, 737 | bias=True, 738 | act='lrelu') 739 | 740 | self.conv1_down = ConvDownsample2d(kernel_size=3, 741 | input_channels=in2, 742 | output_channels=out3, 743 | k=resample_kernel, 744 | bias=True, 745 | act='lrelu') 746 | 747 | self.res_conv2_down = ConvDownsample2d(kernel_size=1, 748 | input_channels=in1, 749 | output_channels=out3, 750 | k=resample_kernel, 751 | bias=False) 752 | 753 | def forward(self, x): 754 | t = x.clone() 755 | 756 | x = self.conv0(x) 757 | x = self.conv1_down(x) 758 | 759 | if self.arch == 'resnet': 760 | t = self.res_conv2_down(t) 761 | x = (x + t) * (1 / np.sqrt(2)) 762 | return x 763 | 764 | 765 | # ========================================================================= 766 | # Minibatch standard deviation layer. (D_stylegan2) 767 | # ========================================================================= 768 | 769 | class Minibatch_stddev_layer(nn.Module): 770 | """ 771 | Minibatch standard deviation layer. (D_stylegan2) 772 | """ 773 | 774 | def __init__(self, group_size=4, num_new_features=1): 775 | super().__init__() 776 | 777 | self.group_size = group_size 778 | self.num_new_features = num_new_features 779 | 780 | def forward(self, x): 781 | n, c, h, w = x.shape 782 | 783 | group_size = min(n, self.group_size) # Minibatch must be divisible by (or smaller than) group_size. 784 | y = x.view(group_size, -1, 785 | self.num_new_features, 786 | c // self.num_new_features, 787 | h, w) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. 788 | y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW] Subtract mean over group. 789 | y = torch.mean(y ** 2, dim=0) # [MncHW] Calc variance over group. 790 | y = torch.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. 791 | y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111] Take average over fmaps and pixels. 792 | y = torch.mean(y, dim=2) # [Mn11] Split channels into c channel groups 793 | # How to tile a tensor? https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853 794 | y = y.repeat(group_size, 1, h, w) # [NnHW] Replicate over group and pixels. 795 | 796 | return torch.cat([x, y], 1) # [NCHW] Append as new fmap. 
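# Illustrative shape check for Minibatch_stddev_layer (an added sketch, not part
# of the original file):
#
#     layer = Minibatch_stddev_layer(group_size=4, num_new_features=1)
#     x = torch.randn(8, 512, 4, 4)
#     layer(x).shape  # torch.Size([8, 513, 4, 4]): one stddev feature map,
#                     # shared within each group of 4 samples, is appended.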
797 | 798 | 799 | # ========================================================================= 800 | # Define G_mapping 801 | # 1) support variant mapping_fmaps in G_mapping. 802 | # ========================================================================= 803 | 804 | class G_mapping(nn.Module): 805 | def __init__(self, 806 | mapping_fmaps=512, 807 | dlatent_size=512, 808 | resolution=1024, 809 | label_size=0, # Label dimensionality, 0 if no labels. (non-support) 810 | mapping_layers=8, # Number of mapping layers. 811 | normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers? 812 | use_wscale=True, # Enable equalized learning rate? 813 | lrmul=0.01, # Learning rate multiplier for the mapping layers. 814 | gain=1 # original gain in tensorflow. 815 | ): 816 | super(G_mapping, self).__init__() 817 | self.mapping_fmaps = mapping_fmaps 818 | self.mapping_layers = mapping_layers 819 | 820 | self.fc1 = FC(self.mapping_fmaps, dlatent_size, gain=gain, lrmul=lrmul, use_wscale=use_wscale) 821 | self.fc_layers = ModuleList([]) 822 | for _ in range(2, mapping_layers + 1): 823 | self.fc_layers.append(FC(dlatent_size, dlatent_size, gain=gain, lrmul=lrmul, use_wscale=use_wscale)) 824 | 825 | self.normalize_latents = normalize_latents 826 | self.resolution_log2 = int(np.log2(resolution)) 827 | self.num_layers = self.resolution_log2 * 2 - 2 828 | self.pixel_norm = PixelNorm() 829 | 830 | def forward(self, x): 831 | if self.normalize_latents: 832 | x = self.pixel_norm(x) 833 | 834 | out = self.fc1(x) 835 | for fc in self.fc_layers: 836 | out = fc(out) 837 | 838 | out = out.unsqueeze(1) 839 | out = out.repeat(1, self.num_layers, 1) 840 | 841 | return out 842 | 843 | 844 | # ========================================================================= 845 | # Define G_synthesis_stylegan2 846 | # ========================================================================= 847 | 848 | class G_synthesis_stylegan2(nn.Module): 849 | def __init__(self, 850 | opts, 851 | fmap_base=8 << 10, # stylegan1 8192 (8 << 10), stylegan2 16384 (16 << 10) 852 | num_channels=3, # Number of output color channels. 853 | dlatent_size=512, # Disentangled latent (W) dimensionality. 854 | resolution=1024, # Output resolution. 855 | randomize_noise=True, 856 | # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. 857 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 858 | fmap_min=1, # Minimum number of feature maps in any layer. 859 | fmap_max=512, # Maximum number of feature maps in any layer. 860 | architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. 861 | use_wscale=True, # Enable equalized learning rate? 862 | lrmul=1, # Learning rate multiplier for the mapping layers. 863 | gain=1, # original gain in tensorflow. 864 | act='lrelu', # Activation function: 'linear', 'lrelu'. 865 | resample_kernel=[1, 3, 3, 1], 866 | # Low-pass filter to apply when resampling activations. None = no filtering. 867 | fused_modconv=True, # Implement modulated_conv2d_layer() as a single fused op? 
868 | ): 869 | super(G_synthesis_stylegan2, self).__init__() 870 | 871 | resolution_log2 = int(np.log2(resolution)) 872 | assert resolution == 2 ** resolution_log2 and resolution >= 4 873 | self.nf = lambda stage: np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max) 874 | assert architecture in ['orig', 'skip', 'resnet'] 875 | num_layers = resolution_log2 * 2 - 2 876 | 877 | self.arch = architecture 878 | self.act = act 879 | self.resolution_log2 = resolution_log2 880 | self.opts = opts 881 | 882 | # Primary inputs. 883 | self.x = torch.nn.Parameter(torch.randn(1, self.nf(1), 4, 4)) 884 | # self.x = torch.randn(1, self.nf(1), 4, 4).to(self.opts.device) 885 | 886 | # layer0 887 | self.rgb0 = ToRGB(input_channels=self.nf(1), 888 | output_channels=num_channels, 889 | res=2, 890 | opts=opts) 891 | self.glayer0 = GLayer(input_channels=self.nf(1), 892 | output_channels=self.nf(1), 893 | layer_idx=0, 894 | k=resample_kernel, 895 | randomize_noise=randomize_noise, 896 | act=self.act, 897 | up=False, 898 | opts=opts) 899 | 900 | # rgb layers & block layers. 901 | self.rgb_layers = ModuleList([ToRGB(input_channels=self.nf(3), 902 | output_channels=num_channels, 903 | res=3, 904 | opts=opts, 905 | fused_modconv=fused_modconv)]) 906 | self.block_layers = ModuleList([GBlock(input_channels=self.nf(2), 907 | output_channels=self.nf(3), 908 | layer_idx=1, 909 | opts=opts)]) 910 | 911 | for res in range(4, self.resolution_log2 + 1): 912 | self.rgb_layers.append(ToRGB(input_channels=self.nf(res), 913 | output_channels=num_channels, 914 | res=res, 915 | opts=opts, 916 | fused_modconv=fused_modconv)) 917 | self.block_layers.append(GBlock(input_channels=self.nf(res - 1), 918 | output_channels=self.nf(res), 919 | layer_idx=(res - 2) * 2 - 1, 920 | opts=opts)) 921 | 922 | # upsample layer 923 | self.upsample2d = Upsample2d(opts=opts) 924 | 925 | self.tanh = torch.nn.Tanh() 926 | 927 | def forward(self, dlatent): 928 | # Early Layers 929 | y = None 930 | x = self.x.repeat(dlatent.shape[0], 1, 1, 1) 931 | x = self.glayer0([x, dlatent[:, 0]]) 932 | 933 | if self.arch == 'skip': 934 | y = self.rgb0([x, y, dlatent]) 935 | 936 | # Main layers. 937 | for res, (rgb, block) in enumerate(zip(self.rgb_layers, self.block_layers)): 938 | x = block([x, dlatent]) 939 | if self.arch == 'skip': 940 | y = self.upsample2d(y) 941 | if self.arch == 'skip' or (res + 3) == self.resolution_log2: 942 | y = rgb([x, y, dlatent]) 943 | 944 | # [-1, 1] 945 | # y = shrink_fun(y) 946 | # y = y / torch.max(torch.abs(y)) 947 | # y = self.tanh(y) 948 | return y 949 | 950 | 951 | # ========================================================================= 952 | # Combine G_mapping & G_synthesis_stylegan2 in G_stylegan2 953 | # ========================================================================= 954 | 955 | class G_stylegan2(nn.Module): 956 | def __init__(self, 957 | opts, 958 | return_dlatents=True, 959 | fmap_base=8 << 10, # stylegan1 8192 (8 << 10), stylegan2 16384 (16 << 10) 960 | num_channels=3, # Number of output color channels. 961 | mapping_fmaps=512, 962 | dlatent_size=512, # Disentangled latent (W) dimensionality. 963 | resolution=1024, # Output resolution. 964 | mapping_layers=8, # Number of mapping layers. 965 | randomize_noise=True, 966 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 967 | fmap_min=1, # Minimum number of feature maps in any layer. 968 | fmap_max=512, # Maximum number of feature maps in any layer. 969 | architecture='skip', # Architecture: 'orig', 'skip'. 
970 |                  act='lrelu',  # Activation function: 'linear', 'lrelu'.
971 |                  lrmul=0.01,  # Learning rate multiplier for the mapping layers.
972 |                  gain=1,  # original gain in tensorflow.
973 |                  truncation_psi=0.7,  # Style strength multiplier for the truncation trick. None = disable.
974 |                  truncation_cutoff=8,  # Number of layers for which to apply the truncation trick. None = disable.
975 |                  ):
976 |         super().__init__()
977 |         assert architecture in ['orig', 'skip']
978 |
979 |         self.return_dlatents = return_dlatents
980 |         self.num_channels = num_channels
981 |
982 |         self.g_mapping = G_mapping(mapping_fmaps=mapping_fmaps,
983 |                                    dlatent_size=dlatent_size,
984 |                                    resolution=resolution,
985 |                                    mapping_layers=mapping_layers,
986 |                                    lrmul=lrmul,
987 |                                    gain=gain)
988 |
989 |         self.g_synthesis = G_synthesis_stylegan2(resolution=resolution,
990 |                                                  architecture=architecture,
991 |                                                  randomize_noise=randomize_noise,
992 |                                                  fmap_base=fmap_base,
993 |                                                  fmap_min=fmap_min,
994 |                                                  fmap_max=fmap_max,
995 |                                                  fmap_decay=fmap_decay,
996 |                                                  act=act,
997 |                                                  opts=opts)
998 |
999 |         self.truncation_cutoff = truncation_cutoff
1000 |         self.truncation_psi = truncation_psi
1001 |
1002 |     def forward(self, x):
1003 |         dlatents1 = self.g_mapping(x)
1004 |         num_layers = dlatents1.shape[1]
1005 |
1006 |         # Apply truncation trick.
1007 |         if self.truncation_psi and self.truncation_cutoff:
1008 |             batch_avg = torch.mean(dlatents1, dim=1, keepdim=True)
1009 |             coefs = np.ones([1, num_layers, 1], dtype=np.float32)
1010 |             for i in range(num_layers):
1011 |                 coefs[:, i, :] *= self.truncation_psi if i < self.truncation_cutoff else 1.0  # apply psi only to the first truncation_cutoff layers
1012 |             """Linear interpolation.
1013 |                a + (b - a) * t
1014 |             """
1015 |             dlatents1 = batch_avg + (dlatents1 - batch_avg) * torch.Tensor(coefs).to(dlatents1.device)
1016 |
1017 |         out = self.g_synthesis(dlatents1)
1018 |
1019 |         if self.return_dlatents:
1020 |             return out, dlatents1
1021 |         else:
1022 |             return out
1023 |
1024 |
1025 | # =========================================================================
1026 | # Define D_stylegan2
1027 | # 1) supports structure == 'orig' & 'resnet' ('skip' is unsupported here).
1028 | # 2) multi-label is unsupported.
1029 | # 3) Almost identical to the original! (2019.12.18)
1030 | # =========================================================================
1031 |
1032 | class D_stylegan2(nn.Module):
1033 |     def __init__(self,
1034 |                  resolution=1024,
1035 |                  fmap_base=8 << 10,  # stylegan1 8192 (8 << 10), stylegan2 16384 (16 << 10)
1036 |                  num_channels=3,
1037 |                  label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
1038 |                  structure='resnet',  # Architecture: 'orig', 'resnet' (skip unsupported).
1039 |                  fmap_max=512,
1040 |                  fmap_min=1,
1041 |                  fmap_decay=1.0,
1042 |                  mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
1043 |                  mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
1044 |                  resample_kernel=[1, 3, 3, 1]
1045 |                  # Low-pass filter to apply when resampling activations. None = no filtering.
1046 |                  ):
1047 |         """
1048 |         Notice: we only support square input images (height == width).
1049 |
1050 |         If H or W >= 128, we use avgpooling2d for feature map shrinkage;
1051 |         otherwise we use an ordinary conv2d.
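        With the default fmaps and resolution=256, the forward pass is:
        FromRGB -> 6 DBlocks (256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4)
        -> minibatch stddev -> 3x3 conv -> FC(8192 -> 512) -> FC(512 -> 1 logit).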
1052 | """ 1053 | super().__init__() 1054 | self.resolution_log2 = int(np.log2(resolution)) 1055 | assert resolution == 2 ** self.resolution_log2 and resolution >= 4 and self.resolution_log2 >= 4 1056 | self.nf = lambda stage: np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max) 1057 | 1058 | assert structure in ['orig', 'skip', 'resnet'] 1059 | 1060 | if structure == 'skip': 1061 | raise Exception("skip in Discriminator is unsupported yet~") 1062 | 1063 | self.structure = structure 1064 | self.label_size = label_size 1065 | self.mbstd_group_size = mbstd_group_size 1066 | 1067 | # sub network 1068 | self.fromrgb = FromRGB(input_channels=3, 1069 | output_channels=self.nf(self.resolution_log2 - 1), 1070 | use_wscale=True) 1071 | 1072 | # dblock layers 1073 | self.block_layers = ModuleList([]) 1074 | 1075 | for res in range(self.resolution_log2, 4, -1): 1076 | self.block_layers.append(DBlock(in1=self.nf(res - 1), in2=self.nf(res - 1), out3=self.nf(res - 2), 1077 | resample_kernel=resample_kernel)) 1078 | 1079 | for res in range(4, 2, -1): 1080 | self.block_layers.append( 1081 | DBlock(in1=self.nf(res), in2=self.nf(res - 1), out3=self.nf(res - 2), resample_kernel=resample_kernel)) 1082 | 1083 | # 4x4 1084 | self.minibatch_stddev = Minibatch_stddev_layer(mbstd_group_size, mbstd_num_features) 1085 | self.conv_last = Conv2d(input_channels=self.nf(2) + mbstd_num_features, 1086 | output_channels=self.nf(1), 1087 | kernel_size=3, 1088 | act='lrelu') 1089 | self.fc_last1 = FC(in_channels=fmap_base, 1090 | out_channels=self.nf(0), 1091 | act='lrelu') 1092 | self.fc_last2 = FC(in_channels=self.nf(0), 1093 | out_channels=1, 1094 | act='linear') 1095 | 1096 | def forward(self, input): 1097 | 1098 | x_origin = None 1099 | y = input 1100 | # 1) Main Layers. 1101 | x = self.fromrgb([x_origin, y]) 1102 | for dblock in self.block_layers: 1103 | x = dblock(x) 1104 | 1105 | # 2) Final layers (4 x 4). 
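        # (Added note: here x is (N, nf(2), 4, 4) with nf(2)=512 at the default
        # fmap_base; minibatch stddev appends one channel, conv_last maps back to
        # nf(1)=512 channels, and flattening yields 512 * 4 * 4 = 8192 features,
        # which is why fc_last1 takes in_channels=fmap_base.)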
1106 | if self.mbstd_group_size > 1: 1107 | x = self.minibatch_stddev(x) 1108 | x = self.conv_last(x) 1109 | 1110 | out = x 1111 | 1112 | _, c, h, w = out.shape 1113 | out = out.view(-1, h * w * c) 1114 | 1115 | out = self.fc_last1(out) 1116 | 1117 | # 3) Output 1118 | if self.label_size == 0: 1119 | out = self.fc_last2(out) 1120 | return out 1121 | 1122 | 1123 | # ========================================================================= 1124 | # Define StyleGAN2 1125 | # ========================================================================= 1126 | 1127 | class StyleGAN2: 1128 | """ Unconditional StyleGAN2 1129 | """ 1130 | 1131 | def __init__(self, 1132 | opts, 1133 | use_ema=True, 1134 | ema_decay=0.999): 1135 | """ constructor for the class """ 1136 | 1137 | self.start_epoch = 0 1138 | self.opts = opts 1139 | # Create the model 1140 | self.G = G_stylegan2(opts=opts, 1141 | fmap_base=opts.fmap_base, 1142 | resolution=opts.resolution, 1143 | mapping_layers=opts.mapping_layers, 1144 | return_dlatents=opts.return_latents, 1145 | architecture='skip') 1146 | 1147 | self.D = D_stylegan2(fmap_base=opts.fmap_base, 1148 | resolution=opts.resolution, 1149 | structure='resnet') 1150 | 1151 | # Load the pre-trained weight 1152 | if os.path.exists(opts.resume): 1153 | INFO("Load the pre-trained weight!") 1154 | state = torch.load(opts.resume) 1155 | self.G.load_state_dict(state['G']) 1156 | self.D.load_state_dict(state['D']) 1157 | self.start_epoch = state['start_epoch'] 1158 | else: 1159 | INFO("Pre-trained weight cannot load successfully, train from scratch!") 1160 | 1161 | # Multi-GPU support 1162 | if torch.cuda.device_count() > 1: 1163 | INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs") 1164 | self.G = torch.nn.DataParallel(self.G) 1165 | self.D = torch.nn.DataParallel(self.D) 1166 | self.G.to(opts.device) 1167 | self.D.to(opts.device) 1168 | 1169 | # state of the object 1170 | self.use_ema = use_ema 1171 | self.ema_decay = ema_decay 1172 | 1173 | if self.use_ema: 1174 | from utils.libs import update_average 1175 | 1176 | # create a shadow copy of the generator 1177 | self.Gs = copy.deepcopy(self.G) 1178 | 1179 | # updater function: 1180 | self.ema_updater = update_average 1181 | 1182 | # initialize the gen_shadow weights equal to the 1183 | # weights of gen 1184 | Gs_beta = 0.99 1185 | self.ema_updater(self.Gs, self.G, beta=Gs_beta) 1186 | 1187 | # by default the generator and discriminator are in eval mode 1188 | self.G.eval() 1189 | self.D.eval() 1190 | if self.use_ema: 1191 | self.Gs.eval() 1192 | 1193 | def optimize_G(self, 1194 | gen_optim, 1195 | dlatent, 1196 | real_batch, 1197 | loss_fn): 1198 | """ 1199 | performs one step of weight update on generator using the batch of data 1200 | :param gen_optim: generator optimizer 1201 | :param dlatent: input noise of sample generation 1202 | :param real_batch: real samples batch 1203 | should contain a list of tensors at different scales 1204 | :param loss_fn: loss function to be used (object of GANLoss) 1205 | :return: current loss 1206 | """ 1207 | 1208 | # generate a batch of samples 1209 | fake_samples = self.G(dlatent) 1210 | loss = loss_fn.gen_loss(real_batch, fake_samples) 1211 | 1212 | # optimize discriminator 1213 | gen_optim.zero_grad() 1214 | loss.backward() 1215 | gen_optim.step() 1216 | 1217 | # if self.use_ema is true, apply the moving average here: 1218 | if self.use_ema: 1219 | self.ema_updater(self.Gs, self.G, self.ema_decay) 1220 | 1221 | return loss.mean().item() 1222 | 1223 | def optimize_D(self, 1224 | 
dis_optim, 1225 | dlatent, 1226 | real_batch, 1227 | loss_fn): 1228 | """ 1229 | performs one step of weight update on discriminator using the batch of data 1230 | :param dis_optim: discriminator optimizer 1231 | :param dlatent: input noise of sample generation 1232 | :param real_batch: real samples batch 1233 | should contain a list of tensors at different scales 1234 | :param loss_fn: loss function to be used (object of GANLoss) 1235 | :return: current loss 1236 | """ 1237 | # generate a batch of samples 1238 | fake_samples = self.G(dlatent) 1239 | fake_samples = fake_samples.detach() 1240 | 1241 | loss = loss_fn.dis_loss(real_batch, fake_samples) 1242 | 1243 | # optimize discriminator 1244 | dis_optim.zero_grad() 1245 | loss.backward() 1246 | dis_optim.step() 1247 | 1248 | return loss.mean().item() 1249 | 1250 | def train(self, 1251 | data_loader, 1252 | gen_optim, 1253 | dis_optim, 1254 | loss_fn, 1255 | scheduler_gen, 1256 | scheduler_dis 1257 | ): 1258 | """ 1259 | Method for training the network 1260 | 1) data_loader. Dataloader in PyTorch. 1261 | 2) gen_optim. torch.optim.Optimizer for Generator. 1262 | 3) dis_optim. torch.optim.Optimizer for Discriminator. 1263 | 4) loss_fn. loss/ganloss.py StyleLoss. 1264 | 5) scheduler_gen. scheduler_gen. 1265 | 6) scheduler_dis. scheduler_dis. 1266 | """ 1267 | 1268 | # turn the generator and discriminator into train mode 1269 | self.G.train() 1270 | self.D.train() 1271 | 1272 | # Train 1273 | fix_z = torch.randn([self.opts.batch_size, 512]).to(self.opts.device) 1274 | softplus = torch.nn.Softplus() 1275 | Loss_D_list = [0.0] 1276 | Loss_G_list = [0.0] 1277 | for ep in range(self.start_epoch, self.opts.epoch): 1278 | bar = tqdm(data_loader) 1279 | loss_D_list = [] 1280 | loss_G_list = [] 1281 | for i, (real_img,) in enumerate(bar): 1282 | real_img = real_img.to(self.opts.device) 1283 | latents = torch.randn([real_img.size(0), 512]).to(self.opts.device) 1284 | 1285 | # optimize the discriminator: 1286 | d_loss = self.optimize_D(dis_optim, latents, 1287 | real_img, loss_fn) 1288 | 1289 | # optimize the generator: 1290 | g_loss = self.optimize_G(gen_optim, latents, 1291 | real_img, loss_fn) 1292 | 1293 | loss_G_list.append(g_loss) 1294 | loss_D_list.append(d_loss) 1295 | 1296 | # Output training stats 1297 | bar.set_description( 1298 | "Epoch {} [{}, {}] [G]: {} [D]: {}". 1299 | format(ep, i + 1, len(data_loader), loss_G_list[-1], loss_D_list[-1])) 1300 | 1301 | # Save the result 1302 | Loss_G_list.append(np.mean(loss_G_list)) 1303 | Loss_D_list.append(np.mean(loss_D_list)) 1304 | 1305 | # Check how the generator is doing by saving G's output on fixed_noise 1306 | with torch.no_grad(): 1307 | if self.opts.return_latents: 1308 | fake_img = self.G(fix_z)[0].detach().cpu() 1309 | else: 1310 | fake_img = self.G(fix_z).detach().cpu() 1311 | save_image(fake_img, os.path.join(self.opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True) 1312 | 1313 | # Save model 1314 | state = { 1315 | 'G': self.G.state_dict(), 1316 | 'D': self.D.state_dict(), 1317 | 'Loss_G': Loss_G_list, 1318 | 'Loss_D': Loss_D_list, 1319 | 'start_epoch': ep, 1320 | } 1321 | torch.save(state, os.path.join(self.opts.det, 'models', 'latest.pth')) 1322 | 1323 | scheduler_gen.step() 1324 | scheduler_dis.step() 1325 | 1326 | # Plot the total loss curve 1327 | Loss_D_list = Loss_D_list[1:] 1328 | Loss_G_list = Loss_G_list[1:] 1329 | plotLossCurve(self.opts, Loss_D_list, Loss_G_list) 1330 | 1331 | 1332 | if __name__ == "__main__": 1333 | # 1) G_mapping ok. 
1334 | # data = torch.randn([1, 512]) 1335 | # g = G_mapping() 1336 | # print(g(data).shape) 1337 | 1338 | # 2) D_stylegan2 ok (upfirdn). 1339 | from loss.loss import D_logistic_r1 1340 | 1341 | data = torch.randn(1, 3, 256, 256).cuda() 1342 | print(torch.max(data)) 1343 | print(torch.min(data)) 1344 | # fake = torch.randn(1, 3, 256, 256).cuda() 1345 | d = D_stylegan2(resolution=256, 1346 | structure='resnet', 1347 | resample_kernel=[1, 3, 3, 1]).cuda() 1348 | print(d(data)) 1349 | # https://discuss.pytorch.org/t/one-of-the-differentiated-tensors-does-not-require-grad/54694 1350 | # D_logistic_r1(fake, data, d) 1351 | 1352 | # 3) G_synthesis_stylegan2 Early Layers 1353 | # data = torch.randn([5, 18, 512]) 1354 | # g_syn = G_synthesis_stylegan2(resolution=128) 1355 | # print(g_syn(data).shape) 1356 | 1357 | # 3.1) ModulatedConv2d(up=True) UpConvsample2d 1358 | # y = torch.randn([1, 512]) 1359 | # x = torch.randn(1, 3, 128, 128) 1360 | # up1 = ModulatedConv2d(up=True, 1361 | # input_channels=3, 1362 | # output_channels=6, 1363 | # kernel_size=3) 1364 | # print(up1([x, y]).shape) 1365 | 1366 | # 3.2) Upsample2d 1367 | # data = torch.randn(1, 3, 128, 128) 1368 | # up = Upsample2d() 1369 | # print(up(data).shape) 1370 | 1371 | # 4) G_combination 1372 | # opts = TrainOptions().parse() 1373 | # data = torch.randn([5, 512]) 1374 | # g = G_stylegan2(resolution=256, 1375 | # mapping_layers=5, 1376 | # opts=opts) 1377 | # print(g(data).shape) 1378 | # print(g(data).shape) 1379 | # torch.Size([5, 3, 256, 256]) 1380 | -------------------------------------------------------------------------------- /opts/opts.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | import argparse 6 | import torch 7 | import os 8 | 9 | 10 | def INFO(inputs): 11 | print("[ Style GAN2 ] %s" % (inputs)) 12 | 13 | 14 | def presentParameters(args_dict): 15 | """ 16 | Print the parameters setting line by line 17 | 18 | Arg: args_dict - The dict object which is transferred from argparse Namespace object 19 | """ 20 | INFO("========== Parameters ==========") 21 | for key in sorted(args_dict.keys()): 22 | INFO("{:>15} : {}".format(key, args_dict[key])) 23 | INFO("===============================") 24 | 25 | 26 | class TrainOptions(): 27 | def __init__(self): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--path', type=str, default='/home/samuel/gaodaiheng/生成模型/face_dataset/') 30 | parser.add_argument('--epoch', type=int, default=500) 31 | parser.add_argument('--fmap_base', type=int, default=8 << 10) 32 | parser.add_argument('--resolution', type=int, default=128) 33 | parser.add_argument('--mapping_layers', type=int, default=8) 34 | parser.add_argument('--batch_size', type=int, default=8) 35 | parser.add_argument('--type', type=str, default='style') 36 | parser.add_argument('--resume', type=str, default='train_result/models/latest.pth') 37 | parser.add_argument('--det', type=str, default='train_result') 38 | self.opts = parser.parse_args() 39 | 40 | def parse(self): 41 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu' 42 | 43 | # Check if the parameter is valid 44 | if self.opts.type not in ['style', 'origin']: 45 | raise Exception( 46 | "Unknown type: {} You should assign one of them ['style', 'origin']...".format(self.opts.type)) 47 | 48 | # Create the destination folder 49 | if not os.path.exists(self.opts.det): 50 | os.mkdir(self.opts.det) 51 | if not os.path.exists(os.path.join(self.opts.det, 
'images')): 52 | os.mkdir(os.path.join(self.opts.det, 'images')) 53 | if not os.path.exists(os.path.join(self.opts.det, 'models')): 54 | os.mkdir(os.path.join(self.opts.det, 'models')) 55 | 56 | # Print the options 57 | presentParameters(vars(self.opts)) 58 | return self.opts 59 | 60 | 61 | class InferenceOptions(): 62 | def __init__(self): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--resume', type=str, default='train_result/model/latest.pth') 65 | parser.add_argument('--type', type=str, default='style') 66 | parser.add_argument('--num_face', type=int, default=32) 67 | parser.add_argument('--det', type=str, default='result.png') 68 | self.opts = parser.parse_args() 69 | 70 | def parse(self): 71 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu' 72 | 73 | # Print the options 74 | presentParameters(vars(self.opts)) 75 | return self.opts 76 | -------------------------------------------------------------------------------- /torchvision_sunner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN2_PyTorch/4ab7354c85cb986d2b77f5238c4a18c5efd1db1b/torchvision_sunner/__init__.py -------------------------------------------------------------------------------- /torchvision_sunner/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script defines the constant which will be used in Torchvision_sunner package 3 | 4 | Author: SunnerLi 5 | """ 6 | 7 | # Constant 8 | UNDER_SAMPLING = 0 9 | OVER_SAMPLING = 1 10 | BCHW2BHWC = 0 11 | BHWC2BCHW = 1 12 | 13 | # Categorical constant 14 | ONEHOT2INDEX = 'one_hot_to_index' 15 | INDEX2ONEHOT = 'index_to_one_hot' 16 | ONEHOT2COLOR = 'one_hot_to_color' 17 | COLOR2ONEHOT = 'color_to_one_hot' 18 | INDEX2COLOR = 'index_to_color' 19 | COLOR2INDEX = 'color_to_index' 20 | 21 | # Verbose flag 22 | verbose = True -------------------------------------------------------------------------------- /torchvision_sunner/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script define the wrapper of the Torchvision_sunner.data 3 | 4 | Author: SunnerLi 5 | """ 6 | from torch.utils.data import DataLoader 7 | 8 | from torchvision_sunner.data.image_dataset import ImageDataset 9 | from torchvision_sunner.data.video_dataset import VideoDataset 10 | from torchvision_sunner.data.loader import * 11 | from torchvision_sunner.constant import * 12 | from torchvision_sunner.utils import * -------------------------------------------------------------------------------- /torchvision_sunner/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | from torchvision_sunner.utils import INFO 3 | import torch.utils.data as Data 4 | 5 | import pickle 6 | import random 7 | import os 8 | 9 | """ 10 | This script define the parent class to deal with some common function for Dataset 11 | 12 | Author: SunnerLi 13 | """ 14 | 15 | class BaseDataset(Data.Dataset): 16 | def __init__(self): 17 | self.save_file = False 18 | self.files = None 19 | self.split_files = None 20 | 21 | def generateIndexList(self, a, size): 22 | """ 23 | Generate the list of index which will be picked 24 | This function will be used as train-test-split 25 | 26 | Arg: a - The list of images 27 | size - Int, the length of list you want to create 28 | Ret: The index list 29 | """ 30 | result 
= set() 31 | while len(result) != size: 32 | result.add(random.randint(0, len(a) - 1)) 33 | return list(result) 34 | 35 | def loadFromFile(self, file_name, check_type = 'image'): 36 | """ 37 | Load the root and files information from .pkl record file 38 | This function will return False if the record file format is invalid 39 | 40 | Arg: file_name - The name of record file 41 | check_type - Str. The type of the record file you want to check 42 | Ret: If the loading procedure are successful or not 43 | """ 44 | with open(file_name, 'rb') as f: 45 | obj = pickle.load(f) 46 | self.type = obj['type'] 47 | if self.type == check_type: 48 | INFO("Load from file: {}".format(file_name)) 49 | self.root = obj['root'] 50 | self.files = obj['files'] 51 | return True 52 | else: 53 | INFO("Record file type: {}\tFail to load...".format(self.type)) 54 | INFO("Form the contain from scratch...") 55 | return False 56 | 57 | def save(self, remain_file_name, split_ratio, split_file_name = ".split.pkl", save_type = 'image'): 58 | """ 59 | Save the information into record file 60 | 61 | Arg: remain_file_name - The path of record file which store the information of remain data 62 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 63 | split_file_name - The path of record file which store the information of split data 64 | save_type - Str. The type of the record file you want to save 65 | """ 66 | if self.save_file: 67 | if not os.path.exists(remain_file_name): 68 | with open(remain_file_name, 'wb') as f: 69 | pickle.dump({ 70 | 'type': save_type, 71 | 'root': self.root, 72 | 'files': self.files 73 | }, f) 74 | if split_ratio: 75 | INFO("Split the dataset, and save as {}".format(split_file_name)) 76 | with open(split_file_name, 'wb') as f: 77 | pickle.dump({ 78 | 'type': save_type, 79 | 'root': self.root, 80 | 'files': self.split_files 81 | }, f) -------------------------------------------------------------------------------- /torchvision_sunner/data/image_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.data.base_dataset import BaseDataset 2 | from torchvision_sunner.read import readContain, readItem 3 | from torchvision_sunner.constant import * 4 | from torchvision_sunner.utils import INFO 5 | 6 | from skimage import io as io 7 | # from PIL import Image 8 | from glob import glob 9 | 10 | import torch.utils.data as Data 11 | 12 | import pickle 13 | import math 14 | import os 15 | 16 | """ 17 | This script define the structure of image dataset 18 | 19 | ======================================================================================= 20 | In the new version, we accept the form that the combination of image and folder: 21 | e.g. [[image1.jpg, image_folder]] 22 | On the other hand, the root can only be 'the list of list' 23 | You should use double list to represent different image domain. 24 | For example: 25 | [[image1.jpg], [image2.jpg]] => valid 26 | [[image1.jpg], [image_folder]] => valid 27 | [[image1.jpg, image2.jpg], [image_folder1, image_folder2]] => valid 28 | [image1.jpg, image2.jpg] => invalid! 
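(to turn that last form into a valid one, wrap it once more: [[image1.jpg, image2.jpg]] is a single domain containing two images)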
29 | Also, the triple of nested list is not allow 30 | ======================================================================================= 31 | 32 | Author: SunnerLi 33 | """ 34 | 35 | class ImageDataset(BaseDataset): 36 | def __init__(self, root = None, file_name = '.remain.pkl', sample_method = UNDER_SAMPLING, transform = None, 37 | split_ratio = 0.0, save_file = False): 38 | """ 39 | The constructor of ImageDataset 40 | 41 | Arg: root - The list object. The image set 42 | file_name - The str. The name of record file. 43 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use down sampling or over sampling to deal with data unbalance problem. 44 | (default is sunnerData.OVER_SAMPLING) 45 | transform - transform.Compose object. You can declare some pre-process toward the image 46 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 47 | save_file - Bool. If storing the record file or not. Default is False 48 | """ 49 | super().__init__() 50 | # Record the parameter 51 | self.root = root 52 | self.file_name = file_name 53 | self.sample_method = sample_method 54 | self.transform = transform 55 | self.split_ratio = split_ratio 56 | self.save_file = save_file 57 | self.img_num = -1 58 | INFO() 59 | 60 | # Substitude the contain of record file if the record file is exist 61 | if os.path.exists(file_name) and self.loadFromFile(file_name): 62 | self.getImgNum() 63 | elif not os.path.exists(file_name) and root is None: 64 | raise Exception("Record file {} not found. You should assign 'root' parameter!".format(file_name)) 65 | else: 66 | # Extend the images of folder into domain list 67 | self.getFiles() 68 | 69 | # Change root obj as the index format 70 | self.root = range(len(self.root)) 71 | 72 | # Adjust the image number 73 | self.getImgNum() 74 | 75 | # Split the files if split_ratio is more than 0.0 76 | self.split() 77 | 78 | # Save the split information 79 | self.save() 80 | 81 | # Print the domain information 82 | self.print() 83 | 84 | # =========================================================================================== 85 | # Define IO function 86 | # =========================================================================================== 87 | def loadFromFile(self, file_name): 88 | """ 89 | Load the root and files information from .pkl record file 90 | This function will return False if the record file format is invalid 91 | 92 | Arg: file_name - The name of record file 93 | Ret: If the loading procedure are successful or not 94 | """ 95 | return super().loadFromFile(file_name, 'image') 96 | 97 | def save(self, split_file_name = ".split.pkl"): 98 | """ 99 | Save the information into record file 100 | 101 | Arg: split_file_name - The path of record file which store the information of split data 102 | """ 103 | super().save(self.file_name, self.split_ratio, split_file_name, 'image') 104 | 105 | # =========================================================================================== 106 | # Define main function 107 | # =========================================================================================== 108 | def getFiles(self): 109 | """ 110 | Construct the files object for the assigned root 111 | We accept the user to mix folder with image 112 | This function can extract whole image in the folder 113 | The element in the files will all become image 114 | 115 | ******************************************************* 116 | * This function only work if the files object is None * 117 | 
******************************************************* 118 | """ 119 | if not self.files: 120 | self.files = {} 121 | for domain_idx, domain in enumerate(self.root): 122 | images = [] 123 | for img in domain: 124 | if os.path.exists(img): 125 | if os.path.isdir(img): 126 | images += readContain(img) 127 | else: 128 | images.append(img) 129 | else: 130 | raise Exception("The path {} does not exist".format(img)) 131 | self.files[domain_idx] = sorted(images) 132 | 133 | def getImgNum(self): 134 | """ 135 | Obtain the image number in the loader for the specific sample method 136 | The function will check if the folder has been extracted 137 | """ 138 | if self.img_num == -1: 139 | # Check if the folder has been extracted 140 | for domain in self.root: 141 | for img in self.files[domain]: 142 | if os.path.isdir(img): 143 | raise Exception("You should extend the images in the folder {} first!".format(img)) 144 | 145 | # Count the image number 146 | for domain in self.root: 147 | if domain == 0: 148 | self.img_num = len(self.files[domain]) 149 | else: 150 | if self.sample_method == OVER_SAMPLING: 151 | self.img_num = max(self.img_num, len(self.files[domain])) 152 | elif self.sample_method == UNDER_SAMPLING: 153 | self.img_num = min(self.img_num, len(self.files[domain])) 154 | return self.img_num 155 | 156 | def split(self): 157 | """ 158 | Split the files object into the split_files object 159 | The original files object will shrink 160 | 161 | We also consider the case of paired images 162 | Thus we check whether the number of images in each domain is the same 163 | If so, we only generate the index list once 164 | """ 165 | # Check if the number of images in different domains is the same 166 | if not self.files: 167 | self.getFiles() 168 | pairImage = True 169 | for domain in range(len(self.root) - 1): 170 | if len(self.files[domain]) != len(self.files[domain + 1]): 171 | pairImage = False 172 | 173 | # Split the files 174 | self.split_files = {} 175 | if pairImage: 176 | split_img_num = math.floor(len(self.files[0]) * self.split_ratio) 177 | choice_index_list = self.generateIndexList(range(len(self.files[0])), size = split_img_num) 178 | for domain in range(len(self.root)): 179 | # determine the index list 180 | if not pairImage: 181 | split_img_num = math.floor(len(self.files[domain]) * self.split_ratio) 182 | choice_index_list = self.generateIndexList(range(len(self.files[domain])), size = split_img_num) 183 | # remove the corresponding term and add into new list 184 | split_img_list = [] 185 | remain_img_list = self.files[domain].copy() 186 | for j in choice_index_list: 187 | split_img_list.append(self.files[domain][j]) 188 | for j in choice_index_list: 189 | self.files[domain].remove(remain_img_list[j]) 190 | self.split_files[domain] = sorted(split_img_list) 191 | 192 | def print(self): 193 | """ 194 | Print the information for each image domain 195 | """ 196 | INFO() 197 | for domain in range(len(self.root)): 198 | INFO("domain index: %d \timage number: %d" % (domain, len(self.files[domain]))) 199 | INFO() 200 | 201 | def __len__(self): 202 | return self.img_num 203 | 204 | def __getitem__(self, index): 205 | return_list = [] 206 | for domain in self.root: 207 | img_path = self.files[domain][index] 208 | img = readItem(img_path) 209 | if self.transform: 210 | img = self.transform(img) 211 | return_list.append(img) 212 | return return_list -------------------------------------------------------------------------------- /torchvision_sunner/data/loader.py:
-------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | from collections.abc import Iterator # 'collections.Iterator' was removed in Python 3.10 3 | import torch.utils.data as data 4 | 5 | """ 6 | This script defines the extra loaders, which can be used flexibly. The loaders include: 7 | 1. ImageLoader (kept from the old version) 8 | 2. MultiLoader 9 | 3. IterationLoader 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | class ImageLoader(data.DataLoader): 15 | def __init__(self, dataset, batch_size=1, shuffle=False, num_workers = 1): 16 | """ 17 | The DataLoader object which can deal with the ImageDataset object. 18 | 19 | Arg: dataset - ImageDataset. You should use sunnerData.ImageDataset to generate the instance first 20 | batch_size - Int. 21 | shuffle - Bool. Shuffle the data or not 22 | num_workers - Int. How many threads you want to use to read the batch data 23 | """ 24 | super(ImageLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers = num_workers) 25 | self.dataset = dataset 26 | self.iter_num = self.__len__() 27 | 28 | def __len__(self): 29 | return round(self.dataset.img_num / self.batch_size) 30 | 31 | def getImageNumber(self): 32 | return self.dataset.img_num 33 | 34 | class MultiLoader(Iterator): 35 | def __init__(self, datasets, batch_size=1, shuffle=False, num_workers = 1): 36 | """ 37 | This class can deal with multiple dataset objects 38 | 39 | Arg: datasets - The list of ImageDataset. 40 | batch_size - Int. 41 | shuffle - Bool. Shuffle the data or not 42 | num_workers - Int. How many threads you want to use to read the batch data 43 | """ 44 | # Create loaders 45 | self.loaders = [] 46 | for dataset in datasets: 47 | self.loaders.append( 48 | data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers) 49 | ) 50 | 51 | # Check the sample method 52 | self.sample_method = None 53 | for dataset in datasets: 54 | if self.sample_method is None: 55 | self.sample_method = dataset.sample_method 56 | else: 57 | if self.sample_method != dataset.sample_method: 58 | raise Exception("Sample methods are not consistent, {} <=> {}".format( 59 | self.sample_method, dataset.sample_method 60 | )) 61 | 62 | # Check the iteration number 63 | self.iter_num = 0 64 | for i, dataset in enumerate(datasets): 65 | if i == 0: 66 | self.iter_num = len(dataset) 67 | else: 68 | if self.sample_method == UNDER_SAMPLING: 69 | self.iter_num = min(self.iter_num, len(dataset)) 70 | else: 71 | self.iter_num = max(self.iter_num, len(dataset)) 72 | self.iter_num = round(self.iter_num / batch_size) 73 | 74 | def __len__(self): 75 | return self.iter_num 76 | 77 | def __iter__(self): 78 | self.iter_loaders = [] 79 | for loader in self.loaders: 80 | self.iter_loaders.append(iter(loader)) 81 | return self 82 | 83 | def __next__(self): 84 | result = [] 85 | for loader in self.iter_loaders: 86 | for _ in loader.__next__(): 87 | result.append(_) 88 | return tuple(result) 89 | 90 | class IterationLoader(Iterator): 91 | def __init__(self, loader, max_iter = 1): 92 | """ 93 | Constructor of the loader with a specific number of iterations (not epochs) 94 | The underlying iterator will be recreated whenever it reaches the end 95 | 96 | Arg: loader - The torch.utils.data.DataLoader object 97 | max_iter - The maximum iteration 98 | """ 99 | super().__init__() 100 | self.loader = loader 101 | self.loader_iter = iter(self.loader) 102 | self.iter = 0 103 | self.max_iter = max_iter 104 | 105 | def __next__(self): 106 | try: 107 | result_tuple = next(self.loader_iter) 108 | except StopIteration: 109 |
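            # the wrapped loader is exhausted at this point: rebuild the
            # iterator and fetch again, so the number of yielded batches is
            # bounded by max_iter rather than by the length of one epoch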
self.loader_iter = iter(self.loader) 110 | result_tuple = next(self.loader_iter) 111 | self.iter += 1 112 | if self.iter <= self.max_iter: 113 | return result_tuple 114 | else: 115 | print("", end='') 116 | raise StopIteration() 117 | 118 | def __len__(self): 119 | return self.max_iter -------------------------------------------------------------------------------- /torchvision_sunner/data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.data.base_dataset import BaseDataset 2 | from torchvision_sunner.read import readContain, readItem 3 | from torchvision_sunner.constant import * 4 | from torchvision_sunner.utils import INFO 5 | 6 | import torch.utils.data as Data 7 | 8 | from PIL import Image 9 | from glob import glob 10 | import numpy as np 11 | import subprocess 12 | import random 13 | import pickle 14 | import torch 15 | import math 16 | import os 17 | 18 | """ 19 | This script define the structure of video dataset 20 | 21 | ======================================================================================= 22 | In the new version, we accept the form that the combination of video and folder: 23 | e.g. [[video1.mp4, image_folder]] 24 | On the other hand, the root can only be 'the list of list' 25 | You should use double list to represent different image domain. 26 | For example: 27 | [[video1.mp4], [video2.mp4]] => valid 28 | [[video1.mp4], [video_folder]] => valid 29 | [[video1.mp4, video2.mp4], [video_folder1, video_folder2]] => valid 30 | [video1.mp4, video2.mp4] => invalid! 31 | Also, the triple of nested list is not allow 32 | ======================================================================================= 33 | 34 | Author: SunnerLi 35 | """ 36 | 37 | class VideoDataset(BaseDataset): 38 | def __init__(self, root = None, file_name = '.remain.pkl', T = 10, sample_method = UNDER_SAMPLING, transform = None, 39 | split_ratio = 0.0, decode_root = './.decode', save_file = False): 40 | """ 41 | The constructor of VideoDataset 42 | 43 | Arg: root - The list object. The image set 44 | file_name - Str. The name of record file. 45 | T - Int. The maximun length of small video sequence 46 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use down sampling or over sampling to deal with data unbalance problem. 47 | (default is sunnerData.OVER_SAMPLING) 48 | transform - transform.Compose object. You can declare some pre-process toward the image 49 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 50 | decode_root - Str. The path to store the ffmpeg decode result. 51 | save_file - Bool. If storing the record file or not. Default is False 52 | """ 53 | super().__init__() 54 | 55 | # Record the parameter 56 | self.root = root 57 | self.file_name = file_name 58 | self.T = T 59 | self.sample_method = sample_method 60 | self.transform = transform 61 | self.split_ratio = split_ratio 62 | self.decode_root = decode_root 63 | self.video_num = -1 64 | self.split_root = None 65 | INFO() 66 | 67 | # Substitude the contain of record file if the record file is exist 68 | if not os.path.exists(file_name) and root is None: 69 | raise Exception("Record file {} not found. 
You should assign 'root' parameter!".format(file_name)) 70 | elif os.path.exists(file_name): 71 | INFO("Load from file: {}".format(file_name)) 72 | self.loadFromFile(file_name) 73 | 74 | # Extend the images of folder into domain list 75 | self.extendFolder() 76 | 77 | # Split the image 78 | self.split() 79 | 80 | # Form the files obj 81 | self.getFiles() 82 | 83 | # Adjust the image number 84 | self.getVideoNum() 85 | 86 | # Save the split information 87 | self.save() 88 | 89 | # Print the domain information 90 | self.print() 91 | 92 | # =========================================================================================== 93 | # Define IO function 94 | # =========================================================================================== 95 | def loadFromFile(self, file_name): 96 | """ 97 | Load the root and files information from .pkl record file 98 | This function will return False if the record file format is invalid 99 | 100 | Arg: file_name - The name of record file 101 | Ret: If the loading procedure are successful or not 102 | """ 103 | return super().loadFromFile(file_name, 'video') 104 | 105 | def save(self, split_file_name = ".split.pkl"): 106 | """ 107 | Save the information into record file 108 | 109 | Arg: split_file_name - The path of record file which store the information of split data 110 | """ 111 | super().save(self.file_name, self.split_ratio, split_file_name, 'video') 112 | 113 | # =========================================================================================== 114 | # Define main function 115 | # =========================================================================================== 116 | def to_folder(self, name): 117 | """ 118 | Transfer the name into the folder format 119 | e.g. 120 | '/home/Dataset/video1_folder' => 'home_Dataset_video1_folder' 121 | '/home/Dataset/video1.mp4' => 'home_Dataset_video1' 122 | 123 | Arg: name - Str. The path of file or original folder 124 | Ret: The new (encoded) folder name 125 | """ 126 | if not os.path.isdir(name): 127 | name = '_'.join(name.split('.')[:-1]) 128 | domain_list = name.split('/') 129 | while True: 130 | if '.' in domain_list: 131 | domain_list.remove('.') 132 | elif '..' 
in domain_list: 133 | domain_list.remove('..') 134 | else: 135 | break 136 | return '_'.join(domain_list) 137 | 138 | def extendFolder(self): 139 | """ 140 | Extend the video folder in root obj 141 | """ 142 | if not self.files: 143 | # Extend the folder of video and replace as new root obj 144 | extend_root = [] 145 | for domain in self.root: 146 | videos = [] 147 | for video in domain: 148 | if os.path.exists(video): 149 | if os.path.isdir(video): 150 | videos += readContain(video) 151 | else: 152 | videos.append(video) 153 | else: 154 | raise Exception("The path {} is not exist".format(videos)) 155 | extend_root.append(videos) 156 | self.root = extend_root 157 | 158 | def split(self): 159 | """ 160 | Split the root object into split_root object 161 | The original root object will shrink 162 | 163 | We also consider the case of pair image 164 | Thus we will check if the number of image in each domain is the same 165 | If it does, then we only generate the list once 166 | """ 167 | # Check if the number of video in different domain is the same 168 | pairImage = True 169 | for domain_idx in range(len(self.root) - 1): 170 | if len(self.root[domain_idx]) != len(self.root[domain_idx + 1]): 171 | pairImage = False 172 | 173 | # Split the files 174 | self.split_root = [] 175 | if pairImage: 176 | split_img_num = math.floor(len(self.root[0]) * self.split_ratio) 177 | choice_index_list = self.generateIndexList(range(len(self.root[0])), size = split_img_num) 178 | for domain_idx in range(len(self.root)): 179 | # determine the index list 180 | if not pairImage: 181 | split_img_num = math.floor(len(self.root[domain_idx]) * self.split_ratio) 182 | choice_index_list = self.generateIndexList(range(len(self.root[domain_idx])), size = split_img_num) 183 | # remove the corresponding term and add into new list 184 | split_img_list = [] 185 | remain_img_list = self.root[domain_idx].copy() 186 | for j in choice_index_list: 187 | split_img_list.append(self.root[domain_idx][j]) 188 | for j in choice_index_list: 189 | self.root[domain_idx].remove(remain_img_list[j]) 190 | self.split_root.append(sorted(split_img_list)) 191 | 192 | def getFiles(self): 193 | """ 194 | Construct the files object for the assigned root 195 | We accept the user to mix folder with image 196 | This function can extract whole image in the folder 197 | 198 | However, unlike the setting in ImageDataset, we store the video result in root obj. 199 | Also, the 'images' name will be store in files obj 200 | 201 | The following list the progress of this function: 202 | 1. check if we need to decode again 203 | 2. decode if needed 204 | 3. 
form the files obj 205 | """ 206 | if not self.files: 207 | # Check if the decode process should be conducted again 208 | should_decode = not os.path.exists(self.decode_root) 209 | if not should_decode: 210 | for domain_idx, domain in enumerate(self.root): 211 | for video in domain: 212 | if not os.path.exists(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video))): 213 | should_decode = True 214 | break 215 | 216 | # Decode the video if needed 217 | if should_decode: 218 | INFO("Decode from scratch...") 219 | if os.path.exists(self.decode_root): 220 | subprocess.call(['rm', '-rf', self.decode_root]) 221 | os.mkdir(self.decode_root) 222 | self.decodeVideo() 223 | else: 224 | INFO("Skip the decode process!") 225 | 226 | # Form the files object 227 | self.files = {} 228 | for domain_idx, domain in enumerate(os.listdir(self.decode_root)): 229 | self.files[domain_idx] = [] 230 | for video in os.listdir(os.path.join(self.decode_root, domain)): 231 | self.files[domain_idx] += [ 232 | sorted(glob(os.path.join(self.decode_root, domain, video, "*"))) 233 | ] 234 | 235 | def decodeVideo(self): 236 | """ 237 | Decode the single video into a series of images, and store into particular folder 238 | """ 239 | for domain_idx, domain in enumerate(self.root): 240 | decode_domain_folder = os.path.join(self.decode_root, str(domain_idx)) 241 | os.mkdir(decode_domain_folder) 242 | for video in domain: 243 | os.mkdir(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video))) 244 | source = os.path.join(domain, video) 245 | target = os.path.join(decode_domain_folder, self.to_folder(video), "%5d.png") 246 | subprocess.call(['ffmpeg', '-i', source, target]) 247 | 248 | def getVideoNum(self): 249 | """ 250 | Obtain the video number in the loader for the specific sample method 251 | The function will check if the folder has been extracted 252 | """ 253 | if self.video_num == -1: 254 | # Check if the folder has been extracted 255 | for domain in self.root: 256 | for video in domain: 257 | if os.path.isdir(video): 258 | raise Exception("You should extend the image in the folder {} first!" 
.format(video)) 259 | 260 | # Count the video number 261 | for i, domain in enumerate(self.root): 262 | if i == 0: 263 | self.video_num = len(domain) 264 | else: 265 | if self.sample_method == OVER_SAMPLING: 266 | self.video_num = max(self.video_num, len(domain)) 267 | elif self.sample_method == UNDER_SAMPLING: 268 | self.video_num = min(self.video_num, len(domain)) 269 | return self.video_num 270 | 271 | def print(self): 272 | """ 273 | Print the information for each video domain 274 | """ 275 | INFO() 276 | for domain in range(len(self.root)): 277 | total_frame = 0 278 | for video in self.files[domain]: 279 | total_frame += len(video) 280 | INFO("domain index: %d \tvideo number: %d\tframe total: %d" % (domain, len(self.root[domain]), total_frame)) 281 | INFO() 282 | 283 | def __len__(self): 284 | return self.video_num 285 | 286 | def __getitem__(self, index): 287 | """ 288 | Return a single sample of data, whose rank format is BTCHW 289 | """ 290 | result = [] 291 | for domain_idx in range(len(self.root)): 292 | 293 | # Form the sequence in single domain 294 | film_sequence = [] 295 | max_init_frame_idx = len(self.files[domain_idx][index]) - self.T 296 | start_pos = random.randint(0, max_init_frame_idx) 297 | for i in range(self.T): 298 | img_path = self.files[domain_idx][index][start_pos + i] 299 | img = readItem(img_path) 300 | film_sequence.append(img) 301 | 302 | # Transform the film sequence 303 | film_sequence = np.asarray(film_sequence) 304 | if self.transform: 305 | film_sequence = self.transform(film_sequence) 306 | result.append(film_sequence) 307 | return result -------------------------------------------------------------------------------- /torchvision_sunner/read.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from PIL import Image 3 | from glob import glob 4 | import numpy as np 5 | import os 6 | 7 | """ 8 | This script defines the functions to list the contents of a folder and to read a single file 9 | You should customize them if your data type is not covered by torchvision_sunner 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | def readContain(folder_name): 15 | """ 16 | List the contents of the particular folder 17 | 18 | ================================================================== 19 | You should customize this function if your data is not considered 20 | ================================================================== 21 | 22 | Arg: folder_name - The path of folder 23 | Ret: The list of file paths 24 | """ 25 | # Check the common type in the folder 26 | common_type = Counter() 27 | for name in os.listdir(folder_name): 28 | common_type[name.split('.')[-1]] += 1 29 | common_type = common_type.most_common()[0][0] 30 | 31 | # Deal with the type 32 | if common_type == 'jpg': 33 | name_list = glob(os.path.join(folder_name, '*.jpg')) 34 | elif common_type == 'png': 35 | name_list = glob(os.path.join(folder_name, '*.png')) 36 | elif common_type == 'mp4': 37 | name_list = glob(os.path.join(folder_name, '*.mp4')) 38 | else: 39 | raise Exception("Unknown type {}, You should customize in read.py".format(common_type)) 40 | return name_list 41 | 42 | def readItem(item_name): 43 | """ 44 | Read the file for the given item name 45 | 46 | ================================================================== 47 | You should customize this function if your data is not considered 48 | ================================================================== 49 | 50 | Arg: item_name - The path of the file 51 | Ret: The item
you read 52 | """ 53 | file_type = item_name.split('.')[-1] 54 | if file_type == "png" or file_type == 'jpg': 55 | file_obj = np.asarray(Image.open(item_name)) 56 | 57 | if len(file_obj.shape) == 3: 58 | # Ignore the 4th dim (RGB only) 59 | file_obj = file_obj[:, :, :3] 60 | elif len(file_obj.shape) == 2: 61 | # Make the rank of gray-scale image as 3 62 | file_obj = np.expand_dims(file_obj, axis = -1) 63 | return file_obj -------------------------------------------------------------------------------- /torchvision_sunner/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script define the wrapper of the Torchvision_sunner.transforms 3 | 4 | Author: SunnerLi 5 | """ 6 | from torchvision_sunner.constant import * 7 | from torchvision_sunner.utils import * 8 | 9 | from torchvision_sunner.transforms.base import * 10 | from torchvision_sunner.transforms.simple import * 11 | from torchvision_sunner.transforms.complex import * 12 | from torchvision_sunner.transforms.categorical import * 13 | from torchvision_sunner.transforms.function import * -------------------------------------------------------------------------------- /torchvision_sunner/transforms/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | """ 5 | This class define the parent class of operation 6 | 7 | Author: SunnerLi 8 | """ 9 | 10 | class OP(): 11 | """ 12 | The parent class of each operation 13 | The goal of this class is to adapting with different input format 14 | """ 15 | def work(self, tensor): 16 | """ 17 | The virtual function to define the process in child class 18 | 19 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 20 | """ 21 | raise NotImplementedError("You should define your own function in the class!") 22 | 23 | def __call__(self, tensor): 24 | """ 25 | This function define the proceeding of the operation 26 | There are different choice toward the tensor parameter 27 | 1. torch.Tensor and rank is CHW 28 | 2. np.ndarray and rank is CHW 29 | 3. torch.Tensor and rank is TCHW 30 | 4. np.ndarray and rank is TCHW 31 | 32 | Arg: tensor - The tensor you want to operate 33 | Ret: The operated tensor 34 | """ 35 | isTensor = type(tensor) == torch.Tensor 36 | if isTensor: 37 | tensor_type = tensor.type() 38 | tensor = tensor.cpu().data.numpy() 39 | if len(tensor.shape) == 3: 40 | tensor = self.work(tensor) 41 | elif len(tensor.shape) == 4: 42 | tensor = np.asarray([self.work(_) for _ in tensor]) 43 | else: 44 | raise Exception("We dont support the rank format {}".format(tensor.shape), 45 | "If the rank of the tensor shape is only 2, you can call 'GrayStack()'") 46 | if isTensor: 47 | tensor = torch.from_numpy(tensor) 48 | tensor = tensor.type(tensor_type) 49 | return tensor -------------------------------------------------------------------------------- /torchvision_sunner/transforms/categorical.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.utils import INFO 2 | from torchvision_sunner.constant import * 3 | from torchvision_sunner.transforms.simple import Transpose 4 | 5 | from collections import OrderedDict 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pickle 9 | import torch 10 | import json 11 | import os 12 | 13 | """ 14 | This script define the categorical-related operations, including: 15 | 1. getCategoricalMapping 16 | 2. 
CategoricalTranspose 17 | 18 | Author: SunnerLi 19 | """ 20 | 21 | # ---------------------------------------------------------------------------------------- 22 | # Define the IO function toward pallete file 23 | # ---------------------------------------------------------------------------------------- 24 | 25 | def load_pallete(file_name): 26 | """ 27 | Load the pallete object from file 28 | 29 | Arg: file_name - The name of pallete .json file 30 | Ret: The list of pallete objects 31 | """ 32 | # Load the list of dict from files (key is str) 33 | palletes_str_key = None 34 | with open(file_name, 'r') as f: 35 | palletes_str_key = json.load(f) 36 | 37 | # Change the key into color tuple 38 | palletes = [OrderedDict() for _ in range(len(palletes_str_key))] # independent dicts, not aliased copies 39 | for folder in range(len(palletes_str_key)): 40 | for key in palletes_str_key[folder].keys(): 41 | tuple_key = list() 42 | for v in key.split('_'): 43 | tuple_key.append(int(v)) 44 | palletes[folder][tuple(tuple_key)] = palletes_str_key[folder][key] 45 | return palletes 46 | 47 | def save_pallete(pallete, file_name): 48 | """ 49 | Save the pallete object into file 50 | 51 | Arg: pallete - The list of OrderedDict objects 52 | file_name - The name of pallete .json file 53 | """ 54 | # Change the key into str 55 | pallete_str_key = [dict() for _ in range(len(pallete))] # independent dicts, not aliased copies 56 | for folder in range(len(pallete)): 57 | for key in pallete[folder].keys(): 58 | 59 | str_key = '_'.join([str(_) for _ in key]) 60 | pallete_str_key[folder][str_key] = pallete[folder][key] 61 | 62 | # Save into file 63 | with open(file_name, 'w') as f: 64 | json.dump(pallete_str_key, f) 65 | 66 | # ---------------------------------------------------------------------------------------- 67 | # Define the categorical-related operations 68 | # ---------------------------------------------------------------------------------------- 69 | 70 | def getCategoricalMapping(loader = None, path = 'torchvision_sunner_categories_pallete.json'): 71 | """ 72 | This function collects the distinct category colors in the loader 73 | And returns the list of the mapping OrderedDict objects 74 | 75 | Arg: loader - The ImageLoader object 76 | path - The path of pallete file 77 | Ret: The list of OrderedDict objects (palletes object) 78 | """ 79 | INFO("Applied << %15s >>" % getCategoricalMapping.__name__) 80 | INFO("* Notice: the rank format of input tensor should be 'BHWC'") 81 | INFO("* Notice: The range of tensor should be in [0, 255]") 82 | if os.path.exists(path): 83 | palletes = load_pallete(path) 84 | else: 85 | INFO(">> Load from scratch, please wait...") 86 | 87 | # Get the number of folders 88 | folder_num = 0 89 | for img_list in loader: 90 | folder_num = len(img_list) 91 | break 92 | 93 | # Initialize the pallete list 94 | palletes = [OrderedDict() for _ in range(folder_num)] # independent dicts per folder 95 | color_sets = [set() for _ in range(folder_num)] # independent sets per folder 96 | 97 | # Work 98 | for img_list in tqdm(loader): 99 | for folder_idx in range(folder_num): 100 | img = img_list[folder_idx] 101 | if torch.max(img) > 255 or torch.min(img) < 0: 102 | raise Exception('tensor value out of range...\t range is [' + str(torch.min(img)) + ' ~ ' + str(torch.max(img))) 103 | img = img.cpu().data.numpy().astype(np.uint8) 104 | img = np.reshape(img, [-1, 3]) 105 | color_sets[folder_idx] |= set([tuple(_) for _ in img]) 106 | 107 | # Merge the color 108 | for i in range(folder_num): 109 | for color in color_sets[i]: 110 | if color not in palletes[i].keys(): 111 | palletes[i][color] = len(palletes[i]) 112 | save_pallete(palletes, path) 113 | 114 | return palletes 115 | 116 | class
CategoricalTranspose(): 117 | def __init__(self, pallete = None, direction = COLOR2INDEX, index_default = 0): 118 | """ 119 | Transform the tensor into the particular format 120 | We support for 3 different kinds of format: 121 | 1. one hot image 122 | 2. index image 123 | 3. color 124 | 125 | Arg: pallete - The pallete object (default is None) 126 | direction - The direction you want to change 127 | index_default - The default index if the color cannot be found in the pallete 128 | """ 129 | self.pallete = pallete 130 | self.direction = direction 131 | self.index_default = index_default 132 | INFO("Applied << %15s >> , direction: %s" % (self.__class__.__name__, self.direction)) 133 | INFO("* Notice: The range of tensor should be in [-1, 1]") 134 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 135 | 136 | def fn_color_to_index(self, tensor): 137 | """ 138 | Transfer the tensor from the RGB colorful format into the index format 139 | 140 | Arg: tensor - The tensor obj. The tensor you want to deal with 141 | Ret: The tensor with index format 142 | """ 143 | if self.pallete is None: 144 | raise Exception("The direction << %s >> need the pallete object" % self.direction) 145 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy() 146 | size_tuple = list(np.shape(tensor)) 147 | tensor = (tensor * 127.5 + 127.5).astype(np.uint8) 148 | tensor = np.reshape(tensor, [-1, 3]) 149 | tensor = [tuple(_) for _ in tensor] 150 | tensor = [self.pallete.get(_, self.index_default) for _ in tensor] 151 | tensor = np.asarray(tensor) 152 | size_tuple[-1] = 1 153 | tensor = np.reshape(tensor, size_tuple) 154 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3) 155 | return tensor 156 | 157 | def fn_index_to_one_hot(self, tensor): 158 | """ 159 | Transfer the tensor from the index format into the one-hot format 160 | 161 | Arg: tensor - The tensor obj. The tensor you want to deal with 162 | Ret: The tensor with one-hot format 163 | """ 164 | # Get the number of classes 165 | tensor = tensor.transpose(-3, -2).transpose(-2, -1) 166 | size_tuple = list(np.shape(tensor)) 167 | tensor = tensor.view(-1).cpu().data.numpy() 168 | channel = np.amax(tensor) + 1 169 | 170 | # Get the total number of pixel 171 | num_of_pixel = 1 172 | for i in range(len(size_tuple) - 1): 173 | num_of_pixel *= size_tuple[i] 174 | 175 | # Transfer as ont-hot format 176 | one_hot_tensor = np.zeros([num_of_pixel, channel]) 177 | for i in range(channel): 178 | one_hot_tensor[tensor == i, i] = 1 179 | 180 | # Recover to origin rank format and shape 181 | size_tuple[-1] = channel 182 | tensor = np.reshape(one_hot_tensor, size_tuple) 183 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3) 184 | return tensor 185 | 186 | def fn_one_hot_to_index(self, tensor): 187 | """ 188 | Transfer the tensor from the one-hot format into the index format 189 | 190 | Arg: tensor - The tensor obj. The tensor you want to deal with 191 | Ret: The tensor with index format 192 | """ 193 | _, tensor = torch.max(tensor, dim = 1) 194 | tensor = tensor.unsqueeze(1) 195 | return tensor 196 | 197 | def fn_index_to_color(self, tensor): 198 | """ 199 | Transfer the tensor from the index format into the RGB colorful format 200 | 201 | Arg: tensor - The tensor obj. 
The tensor you want to deal with 202 | Ret: The tensor with RGB colorful format 203 | """ 204 | if self.pallete is None: 205 | raise Exception("The direction << %s >> need the pallete object" % self.direction) 206 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy() 207 | reverse_pallete = {self.pallete[x]: x for x in self.pallete} 208 | batch, height, width, channel = np.shape(tensor) 209 | tensor = np.reshape(tensor, [-1]) 210 | tensor = np.round(tensor, decimals=0) 211 | tensor = np.vectorize(reverse_pallete.get)(tensor) 212 | tensor = np.reshape(np.asarray(tensor).T, [batch, height, width, len(reverse_pallete[0])]) 213 | tensor = torch.from_numpy((tensor - 127.5) / 127.5).transpose(-1, -2).transpose(-2, -3) 214 | return tensor 215 | 216 | def __call__(self, tensor): 217 | if self.direction == COLOR2INDEX: 218 | return self.fn_color_to_index(tensor) 219 | elif self.direction == INDEX2COLOR: 220 | return self.fn_index_to_color(tensor) 221 | elif self.direction == ONEHOT2INDEX: 222 | return self.fn_one_hot_to_index(tensor) 223 | elif self.direction == INDEX2ONEHOT: 224 | return self.fn_index_to_one_hot(tensor) 225 | elif self.direction == ONEHOT2COLOR: 226 | return self.fn_index_to_color(self.fn_one_hot_to_index(tensor)) 227 | elif self.direction == COLOR2ONEHOT: 228 | return self.fn_index_to_one_hot(self.fn_color_to_index(tensor)) 229 | else: 230 | raise Exception("Unknown direction: {}".format(self.direction)) -------------------------------------------------------------------------------- /torchvision_sunner/transforms/complex.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.transforms.base import OP 2 | from torchvision_sunner.utils import INFO 3 | from skimage import transform 4 | import numpy as np 5 | import torch 6 | 7 | """ 8 | This script define some complex operations 9 | These kind of operations should conduct work iteratively (with inherit OP class) 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | class Resize(OP): 15 | def __init__(self, output_size): 16 | """ 17 | Resize the tensor to the desired size 18 | This function only support for nearest-neighbor interpolation 19 | Since this mechanism can also deal with categorical data 20 | 21 | Arg: output_size - The tuple (H, W) 22 | """ 23 | self.output_size = output_size 24 | INFO("Applied << %15s >>" % self.__class__.__name__) 25 | INFO("* Notice: the rank format of input tensor should be 'BHWC'") 26 | 27 | def work(self, tensor): 28 | """ 29 | Resize the tensor 30 | If the tensor is not in the range of [-1, 1], we will do the normalization automatically 31 | 32 | Arg: tensor - The np.ndarray object. 
The tensor you want to deal with 33 | Ret: The resized tensor 34 | """ 35 | # Normalize the tensor if needed 36 | mean, std = -1, -1 37 | min_v = np.min(tensor) 38 | max_v = np.max(tensor) 39 | if not (max_v <= 1 and min_v >= -1): 40 | mean = 0.5 * max_v + 0.5 * min_v 41 | std = 0.5 * max_v - 0.5 * min_v 42 | # print(max_v, min_v, mean, std) 43 | tensor = (tensor - mean) / std 44 | 45 | # Work 46 | tensor = transform.resize(tensor, self.output_size, mode = 'constant', order = 0) 47 | 48 | # De-normalize the tensor 49 | if mean != -1 and std != -1: 50 | tensor = tensor * std + mean 51 | return tensor 52 | 53 | class Normalize(OP): 54 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]): 55 | """ 56 | Normalize the tensor with given mean and standard deviation 57 | * Notice: If you didn't give mean and std, the result will locate in [-1, 1] 58 | 59 | Args: 60 | mean - The mean of the result tensor 61 | std - The standard deviation 62 | """ 63 | self.mean = mean 64 | self.std = std 65 | INFO("Applied << %15s >>" % self.__class__.__name__) 66 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 67 | INFO("*****************************************************************") 68 | INFO("* Notice: You should must call 'ToFloat' before normalization") 69 | INFO("*****************************************************************") 70 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]: 71 | INFO("* Notice: The result will locate in [-1, 1]") 72 | 73 | def work(self, tensor): 74 | """ 75 | Normalize the tensor 76 | 77 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 78 | Ret: The normalized tensor 79 | """ 80 | if tensor.shape[0] != len(self.mean): 81 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 82 | result = [] 83 | for t, m, s in zip(tensor, self.mean, self.std): 84 | result.append((t - m) / s) 85 | tensor = np.asarray(result) 86 | 87 | # Check if the normalization can really work 88 | if np.min(tensor) < -1 or np.max(tensor) > 1: 89 | raise Exception("Normalize can only work with float tensor", 90 | "Try to call 'ToFloat()' before normalization") 91 | return tensor 92 | 93 | class UnNormalize(OP): 94 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]): 95 | """ 96 | Unnormalize the tensor with given mean and standard deviation 97 | * Notice: If you didn't give mean and std, the function will assume that the original distribution locates in [-1, 1] 98 | 99 | Args: 100 | mean - The mean of the result tensor 101 | std - The standard deviation 102 | """ 103 | self.mean = mean 104 | self.std = std 105 | INFO("Applied << %15s >>" % self.__class__.__name__) 106 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 107 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]: 108 | INFO("* Notice: The function assume that the input range is [-1, 1]") 109 | 110 | def work(self, tensor): 111 | """ 112 | Un-normalize the tensor 113 | 114 | Arg: tensor - The np.ndarray object. 
The tensor you want to deal with 115 | Ret: The un-normalized tensor 116 | """ 117 | if tensor.shape[0] != len(self.mean): 118 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 119 | result = [] 120 | for t, m, s in zip(tensor, self.mean, self.std): 121 | result.append(t * s + m) 122 | tensor = np.asarray(result) 123 | return tensor 124 | 125 | class ToGray(OP): 126 | def __init__(self): 127 | """ 128 | Change the tensor as the gray scale 129 | The function will turn the BCHW tensor into B1HW gray-scaled tensor 130 | """ 131 | INFO("Applied << %15s >>" % self.__class__.__name__) 132 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 133 | 134 | def work(self, tensor): 135 | """ 136 | Make the tensor into gray-scale 137 | 138 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 139 | Ret: The gray-scale tensor, and the rank of the tensor is B1HW 140 | """ 141 | if tensor.shape[0] == 3: 142 | result = 0.299 * tensor[0] + 0.587 * tensor[1] + 0.114 * tensor[2] 143 | result = np.expand_dims(result, axis = 0) 144 | elif tensor.shape[0] != 4: 145 | result = 0.299 * tensor[:, 0] + 0.587 * tensor[:, 1] + 0.114 * tensor[:, 2] 146 | result = np.expand_dims(result, axis = 1) 147 | else: 148 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 149 | return result -------------------------------------------------------------------------------- /torchvision_sunner/transforms/function.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.transforms.simple import Transpose 2 | from torchvision_sunner.constant import BCHW2BHWC 3 | 4 | from skimage import transform 5 | from skimage import io 6 | import numpy as np 7 | import torch 8 | 9 | """ 10 | This script define the transform function which can be called directly 11 | 12 | Author: SunnerLi 13 | """ 14 | 15 | channel_op = None # Define the channel op which will be used in 'asImg' function 16 | 17 | def asImg(tensor, size = None): 18 | """ 19 | This function provides fast approach to transfer the image into numpy.ndarray 20 | This function only accept the output from sigmoid layer or hyperbolic tangent output 21 | 22 | Arg: tensor - The torch.Variable object, the rank format is BCHW or BHW 23 | size - The tuple object, and the format is (height, width) 24 | Ret: The numpy image, the rank format is BHWC 25 | """ 26 | global channel_op 27 | result = tensor.detach() 28 | 29 | # 1. Judge the rank first 30 | if len(tensor.size()) == 3: 31 | result = torch.stack([result, result, result], 1) 32 | 33 | # 2. Judge the range of tensor (sigmoid output or hyperbolic tangent output) 34 | min_v = torch.min(result).cpu().data.numpy() 35 | max_v = torch.max(result).cpu().data.numpy() 36 | if max_v > 1.0 or min_v < -1.0: 37 | raise Exception('tensor value out of range...\t range is [' + str(min_v) + ' ~ ' + str(max_v)) 38 | if min_v < 0: 39 | result = (result + 1) / 2 40 | 41 | # 3. Define the BCHW -> BHWC operation 42 | if channel_op is None: 43 | channel_op = Transpose(BCHW2BHWC) 44 | 45 | # 3. Rest 46 | result = channel_op(result) 47 | result = result.cpu().data.numpy() 48 | if size is not None: 49 | result_list = [] 50 | for img in result: 51 | result_list.append(transform.resize(img, (size[0], size[1]), mode = 'constant', order = 0) * 255) 52 | result = np.stack(result_list, axis = 0) 53 | else: 54 | result *= 255. 
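    # result is a float array in [0, 255] at this point; the uint8 cast on
    # the next line cannot wrap around because of the range check performed
    # earlier in this function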
55 | result = result.astype(np.uint8) 56 | return result -------------------------------------------------------------------------------- /torchvision_sunner/transforms/simple.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.utils import INFO 2 | from torchvision_sunner.constant import * 3 | import numpy as np 4 | import torch 5 | 6 | """ 7 | This script define some operation which are rather simple 8 | The operation only need to call function once (without inherit OP class) 9 | 10 | Author: SunnerLi 11 | """ 12 | 13 | class ToTensor(): 14 | def __init__(self): 15 | """ 16 | Change the tensor into torch.Tensor type 17 | """ 18 | INFO("Applied << %15s >>" % self.__class__.__name__) 19 | 20 | def __call__(self, tensor): 21 | """ 22 | Arg: tensor - The torch.Tensor or other type. The tensor you want to deal with 23 | """ 24 | if type(tensor) == np.ndarray: 25 | tensor = torch.from_numpy(tensor) 26 | return tensor 27 | 28 | class ToFloat(): 29 | def __init__(self): 30 | """ 31 | Change the tensor into torch.FloatTensor 32 | """ 33 | INFO("Applied << %15s >>" % self.__class__.__name__) 34 | 35 | def __call__(self, tensor): 36 | """ 37 | Arg: tensor - The torch.Tensor object. The tensor you want to deal with 38 | """ 39 | return tensor.float() 40 | 41 | class Transpose(): 42 | def __init__(self, direction = BHWC2BCHW): 43 | """ 44 | Transfer the rank of tensor into target one 45 | 46 | Arg: direction - The direction you want to do the transpose 47 | """ 48 | self.direction = direction 49 | if self.direction == BHWC2BCHW: 50 | INFO("Applied << %15s >>, The rank format is BCHW" % self.__class__.__name__) 51 | elif self.direction == BCHW2BHWC: 52 | INFO("Applied << %15s >>, The rank format is BHWC" % self.__class__.__name__) 53 | else: 54 | raise Exception("Unknown direction symbol: {}".format(self.direction)) 55 | 56 | def __call__(self, tensor): 57 | """ 58 | Arg: tensor - The torch.Tensor object. 
The tensor you want to deal with 59 | """ 60 | if self.direction == BHWC2BCHW: 61 | tensor = tensor.transpose(-1, -2).transpose(-2, -3) 62 | else: 63 | tensor = tensor.transpose(-3, -2).transpose(-2, -1) 64 | return tensor -------------------------------------------------------------------------------- /torchvision_sunner/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | 3 | """ 4 | This script defines the function which are widely used in the whole package 5 | 6 | Author: SunnerLi 7 | """ 8 | 9 | def quiet(): 10 | """ 11 | Mute the information toward the whole log in the toolkit 12 | """ 13 | global verbose 14 | verbose = False 15 | 16 | def INFO(string = None): 17 | """ 18 | Print the information with prefix 19 | 20 | Arg: string - The string you want to print 21 | """ 22 | if verbose: 23 | if string: 24 | print("[ Torchvision_sunner ] %s" % (string)) 25 | else: 26 | print("[ Torchvision_sunner ] " + '=' * 50) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | @readme: StyleGAN2 PyTorch 5 | """ 6 | import torchvision_sunner.transforms as sunnertransforms 7 | import torchvision_sunner.data as sunnerData 8 | import torchvision.transforms as transforms 9 | 10 | from torch.autograd import grad 11 | 12 | from network.stylegan2 import G_stylegan2, D_stylegan2 13 | from utils.utils import plotLossCurve 14 | from loss.loss import D_logistic_r1, D_logistic_r2, G_logistic_ns_pathreg 15 | from opts.opts import TrainOptions, INFO 16 | 17 | from torchvision.utils import save_image 18 | from tqdm import tqdm 19 | from matplotlib import pyplot as plt 20 | import torch.optim as optim 21 | import numpy as np 22 | import random 23 | import torch 24 | import os 25 | 26 | 27 | # Set random seem for reproducibility 28 | # manualSeed = 999 29 | #manualSeed = random.randint(1, 10000) # use if you want new results 30 | # print("Random Seed: ", manualSeed) 31 | # random.seed(manualSeed) 32 | # torch.manual_seed(manualSeed) 33 | 34 | # Hyper-parameters 35 | CRITIC_ITER = 3 36 | PL_DECAY = 0.01 37 | PL_WEIGHT = 2.0 38 | 39 | 40 | def main(opts): 41 | # Create the data loader 42 | loader = sunnerData.DataLoader(sunnerData.ImageDataset( 43 | root=[[opts.path]], 44 | transform=transforms.Compose([ 45 | sunnertransforms.Resize((opts.resolution, opts.resolution)), 46 | sunnertransforms.ToTensor(), 47 | sunnertransforms.ToFloat(), 48 | sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW), 49 | sunnertransforms.Normalize(), 50 | ])), 51 | batch_size=opts.batch_size, 52 | shuffle=True, 53 | drop_last=True 54 | ) 55 | 56 | # Create the model 57 | start_epoch = 0 58 | G = G_stylegan2(fmap_base=opts.fmap_base, 59 | resolution=opts.resolution, 60 | mapping_layers=opts.mapping_layers, 61 | opts=opts, 62 | return_dlatents=True) 63 | D = D_stylegan2(fmap_base=opts.fmap_base, 64 | resolution=opts.resolution, 65 | structure='resnet') 66 | 67 | # Load the pre-trained weight 68 | if os.path.exists(opts.resume): 69 | INFO("Load the pre-trained weight!") 70 | state = torch.load(opts.resume) 71 | G.load_state_dict(state['G']) 72 | D.load_state_dict(state['D']) 73 | start_epoch = state['start_epoch'] 74 | else: 75 | INFO("Pre-trained weight cannot load successfully, train from scratch!") 76 | 77 | # Multi-GPU support 78 | if torch.cuda.device_count() > 1: 79 | 
INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs") 80 | G = torch.nn.DataParallel(G) 81 | D = torch.nn.DataParallel(D) 82 | G.to(opts.device) 83 | D.to(opts.device) 84 | 85 | # Create the criterion, optimizer and scheduler 86 | lr_D = 0.0015 87 | lr_G = 0.0015 88 | optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.9, 0.999)) 89 | # g_mapping has 100x lower learning rate 90 | params_G = [{"params": G.g_synthesis.parameters()}, 91 | {"params": G.g_mapping.parameters(), "lr": lr_G * 0.01}] 92 | optim_G = torch.optim.Adam(params_G, lr=lr_G, betas=(0.9, 0.999)) 93 | scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99) 94 | scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99) 95 | 96 | # Train 97 | fix_z = torch.randn([opts.batch_size, 512]).to(opts.device) 98 | softplus = torch.nn.Softplus() 99 | Loss_D_list = [0.0] 100 | Loss_G_list = [0.0] 101 | for ep in range(start_epoch, opts.epoch): 102 | bar = tqdm(loader) 103 | loss_D_list = [] 104 | loss_G_list = [] 105 | for i, (real_img,) in enumerate(bar): 106 | 107 | real_img = real_img.to(opts.device) 108 | latents = torch.randn([real_img.size(0), 512]).to(opts.device) 109 | 110 | # ======================================================================================================= 111 | # (1) Update D network: D_logistic_r1(default) 112 | # ======================================================================================================= 113 | # Compute adversarial loss toward discriminator 114 | real_img = real_img.to(opts.device) 115 | real_logit = D(real_img) 116 | fake_img, fake_dlatent = G(latents) 117 | fake_logit = D(fake_img.detach()) 118 | 119 | d_loss = softplus(fake_logit) 120 | d_loss = d_loss + softplus(-real_logit) 121 | 122 | # original 123 | r1_penalty = D_logistic_r1(real_img.detach(), D) 124 | d_loss = (d_loss + r1_penalty).mean() 125 | # lite 126 | # d_loss = d_loss.mean() 127 | 128 | loss_D_list.append(d_loss.mean().item()) 129 | 130 | # Update discriminator 131 | optim_D.zero_grad() 132 | d_loss.backward() 133 | optim_D.step() 134 | 135 | # ======================================================================================================= 136 | # (2) Update G network: G_logistic_ns_pathreg(default) 137 | # ======================================================================================================= 138 | # if i % CRITIC_ITER == 0: 139 | G.zero_grad() 140 | fake_scores_out = D(fake_img) 141 | _g_loss = softplus(-fake_scores_out) 142 | 143 | g_loss = _g_loss.mean() 144 | loss_G_list.append(g_loss.mean().item()) 145 | 146 | # Update generator 147 | g_loss.backward() 148 | optim_G.step() 149 | 150 | # Output training stats 151 | bar.set_description( 152 | "Epoch {} [{}, {}] [G]: {} [D]: {}".format(ep, i + 1, len(loader), loss_G_list[-1], loss_D_list[-1])) 153 | 154 | # Save the result 155 | Loss_G_list.append(np.mean(loss_G_list)) 156 | Loss_D_list.append(np.mean(loss_D_list)) 157 | 158 | # Check how the generator is doing by saving G's output on fixed_noise 159 | with torch.no_grad(): 160 | fake_img = G(fix_z)[0].detach().cpu() 161 | save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True) 162 | 163 | # Save model 164 | state = { 165 | 'G': G.state_dict(), 166 | 'D': D.state_dict(), 167 | 'Loss_G': Loss_G_list, 168 | 'Loss_D': Loss_D_list, 169 | 'start_epoch': ep, 170 | } 171 | torch.save(state, os.path.join(opts.det, 'models', 'latest.pth')) 172 | 173 | scheduler_D.step() 174 | scheduler_G.step() 175 | 176 | # 
--------------------------------------------------------------------------------
/utils/libs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 |     Miscellaneous utility classes and functions for the StyleGAN2 network.
6 | """
7 | import torch
8 | import numpy as np
9 |
10 | # TWO = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4, 32: 5,
11 | #        64: 6, 128: 7, 256: 8, 512: 9, 1024: 10}
12 |
13 | TWO = [pow(2, _) for _ in range(11)]
14 |
15 |
16 | def _setup_kernel(k):
17 |     k = np.asarray(k, dtype=np.float32)
18 |     if k.ndim == 1:
19 |         k = np.outer(k, k)
20 |     k /= np.sum(k)
21 |     assert k.ndim == 2
22 |     assert k.shape[0] == k.shape[1]
23 |     return k
24 |
25 |
26 | def _approximate_size(feature_size):
27 |     """
28 |     Return the nearest power of two, e.g. _approximate_size(100) -> 128.
29 |     :param feature_size (int): feature height (feature width == feature height)
30 |     :return: the power of two closest to feature_size
31 |     """
32 |
33 |     tmp = map(lambda x: abs(x - int(feature_size)), TWO)
34 |     tmp = list(tmp)
35 |
36 |     idxs = tmp.index(min(tmp))
37 |     return pow(2, idxs)
38 |
39 |
40 | # Calculate the exponential moving average (EMA) of the generator weights.
41 | # This function updates the exponential average weights from the current training snapshot.
42 | def update_average(model_tgt, model_src, beta):
43 |     """
44 |     update model_tgt using exponential moving averages
45 |     :param model_tgt: target model
46 |     :param model_src: source model
47 |     :param beta: decay value of the moving average
48 |     :return: None (updates the target model in place)
49 |     """
50 |
51 |     # utility function for toggling the gradient requirements of the models
52 |     def toggle_grad(model, requires_grad):
53 |         for p in model.parameters():
54 |             p.requires_grad_(requires_grad)
55 |
56 |     # turn off gradient calculation
57 |     toggle_grad(model_tgt, False)
58 |     toggle_grad(model_src, False)
59 |
60 |     param_dict_src = dict(model_src.named_parameters())
61 |
62 |     for p_name, p_tgt in model_tgt.named_parameters():
63 |         p_src = param_dict_src[p_name]
64 |         assert (p_src is not p_tgt)
65 |         p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)
66 |
67 |     # turn the gradient calculation back on
68 |     toggle_grad(model_tgt, True)
69 |     toggle_grad(model_src, True)
70 |
71 |
72 | class ShrinkFun(torch.autograd.Function):
73 |     # Rescale the input to [-1, 1] by its maximum absolute value; gradients pass straight through.
74 |
75 |     @staticmethod
76 |     def forward(ctx, input):
77 |         ctx.save_for_backward(input)
78 |         input_x = input.clone()
79 |         input_x = input_x / torch.max(torch.abs(input_x))
80 |         return input_x
81 |
82 |     @staticmethod
83 |     def backward(ctx, grad_output):
84 |         # straight-through estimator: pass the gradient through unchanged
85 |         grad_input = grad_output.clone()
86 |         return grad_input
87 |
88 |
89 | def weights_init(m):
90 |     classname = m.__class__.__name__
91 |     if classname == 'Conv2d':
92 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
93 |
94 |
95 | if __name__ == "__main__":
96 |     # k = _setup_kernel([1, 3, 3, 1])
97 |     # print(k)
98 |     # print(k[::-1, ::-1])
99 |
100 |     _approximate_size(0)
101 |
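Two of these helpers are easier to follow with a usage illustration. `update_average` maintains a time-averaged ("shadow") copy of a network, the standard trick for smoothing generator weights for evaluation. A hypothetical sketch — the names `G` and `G_shadow` are assumed for illustration and are not taken from train.py, which does not call this helper in the excerpt above:

```python
import copy

G_shadow = copy.deepcopy(G)              # shadow generator, updated slowly
update_average(G_shadow, G, beta=0.0)    # beta=0 copies G's weights exactly once

# ... then, after every generator update inside the training loop:
update_average(G_shadow, G, beta=0.999)  # keep 99.9% of the old running average
```

`ShrinkFun` rescales a tensor into [-1, 1] by its maximum absolute value in the forward pass while acting as the identity for gradients (a straight-through estimator). Custom `autograd.Function` subclasses are invoked through `.apply`:

```python
import torch

shrink = ShrinkFun.apply
x = torch.tensor([2.0, -4.0, 1.0], requires_grad=True)
y = shrink(x)            # tensor([ 0.5000, -1.0000,  0.2500])
y.sum().backward()
print(x.grad)            # tensor([1., 1., 1.]) -- gradient untouched by the rescaling
```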
--------------------------------------------------------------------------------
/utils/stylegan-teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN2_PyTorch/4ab7354c85cb986d2b77f5238c4a18c5efd1db1b/utils/stylegan-teaser.png
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 |     @author: samuel ko
4 |     @date: 2019.12.13
5 |     @readme: Miscellaneous utility classes and functions.
6 | """
7 |
8 | import re
9 | import importlib
10 | from matplotlib import pyplot as plt
11 | import os
12 | import sys
13 | import types
14 | from typing import Any, List, Tuple, Union
15 |
16 |
17 | def plotLossCurve(opts, Loss_D_list, Loss_G_list):
18 |     plt.figure()
19 |     plt.plot(Loss_D_list, '-')
20 |     plt.title("Loss curve (Discriminator)")
21 |     plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_discriminator.png'))
22 |
23 |     plt.figure()
24 |     plt.plot(Loss_G_list, '-o')
25 |     plt.title("Loss curve (Generator)")
26 |     plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_generator.png'))
27 |
28 |
29 | def get_top_level_function_name(obj: Any) -> str:
30 |     """Return the fully-qualified name of a top-level function."""
31 |     return obj.__module__ + "." + obj.__name__
32 |
33 |
34 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
35 |     """Searches for the module that defines the named python object.
36 |     Returns the module and the object name (the original name with the module part removed)."""
37 |
38 |     # allow convenience shorthands, substitute them by full names
39 |     obj_name = re.sub(r"^np\.", "numpy.", obj_name)
40 |     obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)
41 |
42 |     # list alternatives for (module_name, local_obj_name)
43 |     parts = obj_name.split(".")
44 |     name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
45 |
46 |     # try each alternative in turn
47 |     for module_name, local_obj_name in name_pairs:
48 |         try:
49 |             module = importlib.import_module(module_name)  # may raise ImportError
50 |             get_obj_from_module(module, local_obj_name)  # may raise AttributeError
51 |             return module, local_obj_name
52 |         except (ImportError, AttributeError):
53 |             pass
54 |
55 |     # maybe some of the modules themselves contain errors?
56 |     for module_name, _local_obj_name in name_pairs:
57 |         try:
58 |             importlib.import_module(module_name)  # may raise ImportError
59 |         except ImportError:
60 |             if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
61 |                 raise
62 |
63 |     # maybe the requested attribute is missing?
64 |     for module_name, local_obj_name in name_pairs:
65 |         try:
66 |             module = importlib.import_module(module_name)  # may raise ImportError
67 |             get_obj_from_module(module, local_obj_name)  # may raise AttributeError
68 |         except ImportError:
69 |             pass
70 |
71 |     # we are out of luck, but we have no idea why
72 |     raise ImportError(obj_name)
73 |
74 |
75 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
76 |     """Traverses the object name and returns the last (rightmost) python object."""
77 |     if obj_name == '':
78 |         return module
79 |     obj = module
80 |     for part in obj_name.split("."):
81 |         obj = getattr(obj, part)
82 |     return obj
83 |
84 |
85 | def get_obj_by_name(name: str) -> Any:
86 |     """Finds the python object with the given name."""
87 |     module, obj_name = get_module_from_obj_name(name)
88 |     return get_obj_from_module(module, obj_name)
89 |
90 |
91 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
92 |     """Finds the python object with the given name and calls it as a function."""
93 |     assert func_name is not None
94 |     func_obj = get_obj_by_name(func_name)
95 |     assert callable(func_obj)
96 |     return func_obj(*args, **kwargs)
97 |
98 |
99 | if __name__ == "__main__":
100 |     def a():
101 |         print("gaga")
102 |
103 |     b = globals()['a']
104 |     b = get_top_level_function_name(b)
105 |     module, xxx = get_module_from_obj_name(b)
106 |     _build_func = get_obj_from_module(module, xxx)
107 |     _build_func()
108 |
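These name-resolution helpers are adapted from the `dnnlib.util` module of the official NVlabs code: they let configuration strings name the functions to instantiate. A hypothetical interactive illustration, assuming the repo root is on `sys.path` and numpy is installed (outputs shown as comments):

```python
from utils.utils import get_obj_by_name, call_func_by_name

# "np." is expanded to "numpy." by the shorthand substitution above
mean_fn = get_obj_by_name("np.mean")
print(mean_fn([1, 2, 3]))                                   # 2.0

# resolve and call in one step
print(call_func_by_name([1, 2, 3], func_name="numpy.sum"))  # 6
```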
--------------------------------------------------------------------------------