├── .github └── FUNDING.yml ├── .idea └── vcs.xml ├── README.md ├── loss └── loss.py ├── networks_gan.py ├── networks_stylegan.py ├── opts └── opts.py ├── star ├── 00793.png ├── 00794.png ├── 00795.png ├── 00796.png ├── 00797.png ├── 00798.png ├── 00799.png ├── 00800.png ├── 00801.png ├── 00802.png ├── 00803.png ├── 00804.png ├── 00805.png ├── 00806.png ├── 00807.png ├── 00808.png ├── 00809.png ├── 00810.png ├── 00811.png ├── 00812.png ├── 00813.png ├── 00814.png ├── 00815.png ├── 00816.png ├── 00817.png ├── 00818.png ├── 00819.png ├── 00820.png ├── 00821.png ├── 00822.png ├── 00823.png ├── 00824.png ├── 00825.png ├── 00826.png ├── 00827.png ├── 00828.png ├── 00829.png ├── 00830.png ├── 00831.png └── 00832.png ├── torchvision_sunner ├── __init__.py ├── constant.py ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── image_dataset.py │ ├── loader.py │ └── video_dataset.py ├── read.py ├── transforms │ ├── __init__.py │ ├── base.py │ ├── categorical.py │ ├── complex.py │ ├── function.py │ └── simple.py └── utils.py ├── train.py ├── train_stylegan.py └── utils ├── stylegan-teaser.png └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | “GitAds”/ 2 | 3 |
4 | 5 | # A PyTorch Implementation of StyleGAN (Unofficial) 6 | 7 | ![Github](https://img.shields.io/badge/PyTorch-v1.0.1-green.svg?style=for-the-badge&logo=data:image/png) 8 | ![Github](https://img.shields.io/badge/python-3.6-green.svg?style=for-the-badge&logo=python) 9 | ![Github](https://img.shields.io/badge/status-AlmostFinished-blue.svg?style=for-the-badge&logo=fire) 10 | ![Github](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=for-the-badge&logo=fire) 11 | 12 | This repository contains a PyTorch implementation of the following paper: 13 | > **A Style-Based Generator Architecture for Generative Adversarial Networks**
14 | > Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
15 | > http://stylegan.xyz/paper 16 | > 17 | > **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.* 18 | 19 | 20 | ![Teaser image](utils/stylegan-teaser.png) 21 | Picture: These people are not real – they were produced by our generator that allows control over different aspects of the image. 22 | 23 | ## Motivation 24 | To the best of my knowledge, there is still no comparable PyTorch 1.0 implementation of StyleGAN besides the official TensorFlow release from NVlabs, 25 | so I re-implemented it in PyTorch 1.0.1 to make StyleGAN easier to use within the PyTorch community. 26 | 27 | ## Notice 28 | @date: 2019.10.21 29 | 30 | @info: One thing I forgot to highlight earlier: **you need to change the default `Star` dataset to your own dataset** (such as FFHQ) in `opts.py`. Sorry for the oversight. 31 | 32 | 33 | ## Author 34 | 35 | - [Samuel Ko](https://blog.csdn.net/g11d111) 36 | - [Sunner Li](https://github.com/SunnerLi) 37 | 38 | ## Training 39 | 40 | ``` python 41 | # ① set your own training dataset path, batch size and other common settings in TrainOpts of `opts.py`. 42 | 43 | # ② run train_stylegan.py 44 | python3 train_stylegan.py 45 | 46 | # ③ intermediate images generated by the style generator are saved in `opts.det/images/` 47 | ``` 48 | 49 | ## Project 50 | > We follow the official StyleGAN release code closely. If you find any bug or mistake in this implementation, 51 | > please let us know so we can improve it. Thank you very much! 52 | #### Finished 53 | * `blur2d` mechanism. (This step takes a lot of GPU memory; if you don't have enough resources, set the filter `f` to `None`.) 54 | * `truncation` tricks. 55 | * Two kinds of `upsample` methods in `G_synthesis`. 56 | * Two kinds of `downsample` methods in `StyleDiscriminator`. 57 | * `PixelNorm` and `InstanceNorm`. 58 | * `Noise` mechanism. 59 | * `styleMixed` mechanism. 60 | * `Multi-GPU` support. 61 | 62 | #### Unfinished 63 | * Inference code (a minimal usage sketch is given near the end of this README). 64 | 65 | ## Related 66 | [1. StyleGAN - Official TensorFlow Implementation](https://github.com/NVlabs/stylegan) 67 | 68 | [2. The re-implementation of style-based generator idea](https://github.com/SunnerLi/StyleGAN_demo) 69 | 70 | [3. ptrblck_styleGAN](https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb) 71 | 72 | ## System Requirements 73 | - Ubuntu 18.04 74 | - PyTorch 1.0.1 75 | - Numpy 1.13.3 76 | - torchvision 0.2.1 77 | - scikit-image 0.15.0 78 | - tqdm 79 | - GTX 1080Ti or above 80 | 81 | ## Q&A 82 | 83 | ## Acknowledgements 84 | Our code can run the `1024 x 1024` image generation task on a single 1080Ti. If you have a stronger graphics card or GPU, 85 | you can train with a larger batch size and write your own multi-GPU version of this code. 
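
Since the inference code is still unfinished, here is a minimal sketch of how sampling could look with the `StyleGenerator` defined in `networks_stylegan.py`. The checkpoint path is the `TrainOptions` default from `opts.py`; the `"G"` key used to pull the generator weights out of the checkpoint is only an assumption and may not match what `train_stylegan.py` actually writes, so inspect your checkpoint first. Also note that `G_synthesis` creates its noise inputs on CUDA, so this sketch assumes a GPU is available.

``` python
# Minimal inference sketch (the official inference code is still unfinished).
# Assumptions: checkpoint path and "G" key layout; adjust both to your setup.
import torch
from torchvision.utils import save_image

from networks_stylegan import StyleGenerator

device = torch.device('cuda')   # G_synthesis allocates its noise tensors on CUDA

G = StyleGenerator().to(device)

# TrainOptions saves checkpoints under `train_result/models/` by default.
state = torch.load('train_result/models/latest.pth', map_location=device)
# Assumption: the generator state_dict sits under a "G" key; fall back to loading
# the file directly if it is already a plain state_dict.
G.load_state_dict(state['G'] if isinstance(state, dict) and 'G' in state else state)
G.eval()

with torch.no_grad():
    z = torch.randn(4, 512, device=device)    # [N, 512] latent codes
    imgs = G(z)                                # [N, 3, 1024, 1024] synthesized images
    save_image(imgs, 'samples.png', nrow=2, normalize=True)
```

If loading fails, print `state.keys()` and point `load_state_dict` at whichever entry actually holds the generator weights.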
86 | 87 | My Email is **samuel.gao023@gmail.com**, if you have any question and wanna to PR, please let me know, thank you. 88 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | 6 | from torch.autograd import Variable 7 | from torch.autograd import grad 8 | import torch.autograd as autograd 9 | import torch.nn as nn 10 | import torch 11 | 12 | import numpy as np 13 | 14 | def gradient_penalty(x, y, f): 15 | # interpolation 16 | shape = [x.size(0)] + [1] * (x.dim() - 1) 17 | alpha = torch.rand(shape).to(x.device) 18 | z = x + alpha * (y - x) 19 | 20 | # gradient penalty 21 | z = Variable(z, requires_grad=True).to(x.device) 22 | o = f(z) 23 | g = grad(o, z, grad_outputs=torch.ones(o.size()).to(z.device), create_graph=True)[0].view(z.size(0), -1) 24 | gp = ((g.norm(p=2, dim=1) - 1)**2).mean() 25 | return gp 26 | 27 | 28 | def R1Penalty(real_img, f): 29 | # gradient penalty 30 | reals = Variable(real_img, requires_grad=True).to(real_img.device) 31 | real_logit = f(reals) 32 | apply_loss_scaling = lambda x: x * torch.exp(x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device)) 33 | undo_loss_scaling = lambda x: x * torch.exp(-x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device)) 34 | 35 | real_logit = apply_loss_scaling(torch.sum(real_logit)) 36 | real_grads = grad(real_logit, reals, grad_outputs=torch.ones(real_logit.size()).to(reals.device), create_graph=True)[0].view(reals.size(0), -1) 37 | real_grads = undo_loss_scaling(real_grads) 38 | r1_penalty = torch.sum(torch.mul(real_grads, real_grads)) 39 | return r1_penalty 40 | 41 | 42 | def R2Penalty(fake_img, f): 43 | # gradient penalty 44 | fakes = Variable(fake_img, requires_grad=True).to(fake_img.device) 45 | fake_logit = f(fakes) 46 | apply_loss_scaling = lambda x: x * torch.exp(x * torch.Tensor([np.float32(np.log(2.0))]).to(fake_img.device)) 47 | undo_loss_scaling = lambda x: x * torch.exp(-x * torch.Tensor([np.float32(np.log(2.0))]).to(fake_img.device)) 48 | 49 | fake_logit = apply_loss_scaling(torch.sum(fake_logit)) 50 | fake_grads = grad(fake_logit, fakes, grad_outputs=torch.ones(fake_logit.size()).to(fakes.device), create_graph=True)[0].view(fakes.size(0), -1) 51 | fake_grads = undo_loss_scaling(fake_grads) 52 | r2_penalty = torch.sum(torch.mul(fake_grads, fake_grads)) 53 | return r2_penalty 54 | -------------------------------------------------------------------------------- /networks_gan.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import torch 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, z_dims=512, d=64): 12 | super().__init__() 13 | self.deconv1 = nn.utils.spectral_norm(nn.ConvTranspose2d(z_dims, d * 8, 4, 1, 0)) 14 | self.deconv2 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 8, d * 8, 4, 2, 1)) 15 | self.deconv3 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1)) 16 | self.deconv4 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)) 17 | self.deconv5 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 2, d, 4, 2, 1)) 18 | self.deconv6 = nn.ConvTranspose2d(d, 3, 4, 2, 1) 19 | 20 | def forward(self, input): 21 | input = input.view(input.size(0), input.size(1), 1, 1) # 1 x 1 22 | x = F.relu(self.deconv1(input)) # 4 x 
4 23 | x = F.relu(self.deconv2(x)) # 8 x 8 24 | x = F.relu(self.deconv3(x)) # 16 x 16 25 | x = F.relu(self.deconv4(x)) # 32 x 32 26 | x = F.relu(self.deconv5(x)) # 64 x 64 27 | x = F.tanh(self.deconv6(x)) # 128 x 128 28 | return x 29 | 30 | 31 | class Discriminator(nn.Module): 32 | def __init__(self, nc=3, ndf=64): 33 | super().__init__() 34 | self.layer1 = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False) 35 | self.layer2 = nn.utils.spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)) 36 | self.layer3 = nn.utils.spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)) 37 | self.layer4 = nn.utils.spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)) 38 | self.layer5 = nn.utils.spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False)) 39 | self.layer6 = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False) 40 | 41 | def forward(self, input): 42 | out = F.leaky_relu(self.layer1(input), 0.2, inplace=True) # 64 x 64 43 | out = F.leaky_relu(self.layer2(out), 0.2, inplace=True) # 32 x 32 44 | out = F.leaky_relu(self.layer3(out), 0.2, inplace=True) # 16 x 16 45 | out = F.leaky_relu(self.layer4(out), 0.2, inplace=True) # 8 x 8 46 | out = F.leaky_relu(self.layer5(out), 0.2, inplace=True) # 4 x 4 47 | out = F.leaky_relu(self.layer6(out), 0.2, inplace=True) # 1 x 1 48 | return out.view(-1, 1) -------------------------------------------------------------------------------- /networks_stylegan.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | @date: 2019.04.11 5 | @notice: 6 | 1) refactor the module of Gsynthesis with 7 | - LayerEpilogue. 8 | - Upsample2d. 9 | - GBlock. 10 | and etc. 11 | 2) the initialization of every patch we use are all abided by the original NvLabs released code. 12 | 3) Discriminator is a simplicity version of PyTorch. 13 | 4) fix bug: default settings of batchsize. 14 | 15 | """ 16 | import torch.nn.functional as F 17 | import torch.nn as nn 18 | import numpy as np 19 | import torch 20 | import os 21 | from collections import OrderedDict 22 | from torch.nn.init import kaiming_normal_ 23 | 24 | 25 | class ApplyNoise(nn.Module): 26 | def __init__(self, channels): 27 | super().__init__() 28 | self.weight = nn.Parameter(torch.zeros(channels)) 29 | 30 | def forward(self, x, noise): 31 | if noise is None: 32 | noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype) 33 | return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device) 34 | 35 | 36 | class ApplyStyle(nn.Module): 37 | """ 38 | @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb 39 | """ 40 | def __init__(self, latent_size, channels, use_wscale): 41 | super(ApplyStyle, self).__init__() 42 | self.linear = FC(latent_size, 43 | channels * 2, 44 | gain=1.0, 45 | use_wscale=use_wscale) 46 | 47 | def forward(self, x, latent): 48 | style = self.linear(latent) # style => [batch_size, n_channels*2] 49 | shape = [-1, 2, x.size(1), 1, 1] 50 | style = style.view(shape) # [batch_size, 2, n_channels, ...] 51 | x = x * (style[:, 0] + 1.) + style[:, 1] 52 | return x 53 | 54 | 55 | class FC(nn.Module): 56 | def __init__(self, 57 | in_channels, 58 | out_channels, 59 | gain=2**(0.5), 60 | use_wscale=False, 61 | lrmul=1.0, 62 | bias=True): 63 | """ 64 | The complete conversion of Dense/FC/Linear Layer of original Tensorflow version. 
65 | """ 66 | super(FC, self).__init__() 67 | he_std = gain * in_channels ** (-0.5) # He init 68 | if use_wscale: 69 | init_std = 1.0 / lrmul 70 | self.w_lrmul = he_std * lrmul 71 | else: 72 | init_std = he_std / lrmul 73 | self.w_lrmul = lrmul 74 | 75 | self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std) 76 | if bias: 77 | self.bias = torch.nn.Parameter(torch.zeros(out_channels)) 78 | self.b_lrmul = lrmul 79 | else: 80 | self.bias = None 81 | 82 | def forward(self, x): 83 | if self.bias is not None: 84 | out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) 85 | else: 86 | out = F.linear(x, self.weight * self.w_lrmul) 87 | out = F.leaky_relu(out, 0.2, inplace=True) 88 | return out 89 | 90 | 91 | class Blur2d(nn.Module): 92 | def __init__(self, f=[1,2,1], normalize=True, flip=False, stride=1): 93 | """ 94 | depthwise_conv2d: 95 | https://blog.csdn.net/mao_xiao_feng/article/details/78003476 96 | """ 97 | super(Blur2d, self).__init__() 98 | assert isinstance(f, list) or f is None, "kernel f must be an instance of python built_in type list!" 99 | 100 | if f is not None: 101 | f = torch.tensor(f, dtype=torch.float32) 102 | f = f[:, None] * f[None, :] 103 | f = f[None, None] 104 | if normalize: 105 | f = f / f.sum() 106 | if flip: 107 | # f = f[:, :, ::-1, ::-1] 108 | f = torch.flip(f, [2, 3]) 109 | self.f = f 110 | else: 111 | self.f = None 112 | self.stride = stride 113 | 114 | def forward(self, x): 115 | if self.f is not None: 116 | # expand kernel channels 117 | kernel = self.f.expand(x.size(1), -1, -1, -1).to(x.device) 118 | x = F.conv2d( 119 | x, 120 | kernel, 121 | stride=self.stride, 122 | padding=int((self.f.size(2)-1)/2), 123 | groups=x.size(1) 124 | ) 125 | return x 126 | else: 127 | return x 128 | 129 | 130 | class Conv2d(nn.Module): 131 | def __init__(self, 132 | input_channels, 133 | output_channels, 134 | kernel_size, 135 | gain=2 ** (0.5), 136 | use_wscale=False, 137 | lrmul=1, 138 | bias=True): 139 | super().__init__() 140 | he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init 141 | self.kernel_size = kernel_size 142 | if use_wscale: 143 | init_std = 1.0 / lrmul 144 | self.w_lrmul = he_std * lrmul 145 | else: 146 | init_std = he_std / lrmul 147 | self.w_lrmul = lrmul 148 | 149 | self.weight = torch.nn.Parameter( 150 | torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std) 151 | if bias: 152 | self.bias = torch.nn.Parameter(torch.zeros(output_channels)) 153 | self.b_lrmul = lrmul 154 | else: 155 | self.bias = None 156 | 157 | def forward(self, x): 158 | if self.bias is not None: 159 | return F.conv2d(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul, padding=self.kernel_size // 2) 160 | else: 161 | return F.conv2d(x, self.weight * self.w_lrmul, padding=self.kernel_size // 2) 162 | 163 | 164 | class Upscale2d(nn.Module): 165 | def __init__(self, factor=2, gain=1): 166 | """ 167 | the first upsample method in G_synthesis. 
168 | :param factor: 169 | :param gain: 170 | """ 171 | super().__init__() 172 | self.gain = gain 173 | self.factor = factor 174 | 175 | def forward(self, x): 176 | if self.gain != 1: 177 | x = x * self.gain 178 | if self.factor > 1: 179 | shape = x.shape 180 | x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, self.factor, -1, self.factor) 181 | x = x.contiguous().view(shape[0], shape[1], self.factor * shape[2], self.factor * shape[3]) 182 | return x 183 | 184 | 185 | class PixelNorm(nn.Module): 186 | def __init__(self, epsilon=1e-8): 187 | """ 188 | @notice: avoid in-place ops. 189 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 190 | """ 191 | super(PixelNorm, self).__init__() 192 | self.epsilon = epsilon 193 | 194 | def forward(self, x): 195 | tmp = torch.mul(x, x) # or x ** 2 196 | tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon) 197 | 198 | return x * tmp1 199 | 200 | 201 | class InstanceNorm(nn.Module): 202 | def __init__(self, epsilon=1e-8): 203 | """ 204 | @notice: avoid in-place ops. 205 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 206 | """ 207 | super(InstanceNorm, self).__init__() 208 | self.epsilon = epsilon 209 | 210 | def forward(self, x): 211 | x = x - torch.mean(x, (2, 3), True) 212 | tmp = torch.mul(x, x) # or x ** 2 213 | tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) 214 | return x * tmp 215 | 216 | 217 | class LayerEpilogue(nn.Module): 218 | def __init__(self, 219 | channels, 220 | dlatent_size, 221 | use_wscale, 222 | use_noise, 223 | use_pixel_norm, 224 | use_instance_norm, 225 | use_styles): 226 | super(LayerEpilogue, self).__init__() 227 | 228 | if use_noise: 229 | self.noise = ApplyNoise(channels) 230 | self.act = nn.LeakyReLU(negative_slope=0.2) 231 | 232 | if use_pixel_norm: 233 | self.pixel_norm = PixelNorm() 234 | else: 235 | self.pixel_norm = None 236 | 237 | if use_instance_norm: 238 | self.instance_norm = InstanceNorm() 239 | else: 240 | self.instance_norm = None 241 | 242 | if use_styles: 243 | self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale) 244 | else: 245 | self.style_mod = None 246 | 247 | def forward(self, x, noise, dlatents_in_slice=None): 248 | x = self.noise(x, noise) 249 | x = self.act(x) 250 | if self.pixel_norm is not None: 251 | x = self.pixel_norm(x) 252 | if self.instance_norm is not None: 253 | x = self.instance_norm(x) 254 | if self.style_mod is not None: 255 | x = self.style_mod(x, dlatents_in_slice) 256 | 257 | return x 258 | 259 | 260 | class GBlock(nn.Module): 261 | def __init__(self, 262 | res, 263 | use_wscale, 264 | use_noise, 265 | use_pixel_norm, 266 | use_instance_norm, 267 | noise_input, # noise 268 | dlatent_size=512, # Disentangled latent (W) dimensionality. 269 | use_style=True, # Enable style inputs? 270 | f=None, # (Huge overload, if you dont have enough resouces, please pass it as `f = None`)Low-pass filter to apply when resampling activations. None = no filtering. 271 | factor=2, # upsample factor. 272 | fmap_base=8192, # Overall multiplier for the number of feature maps. 273 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 274 | fmap_max=512, # Maximum number of feature maps in any layer. 
275 | ): 276 | super(GBlock, self).__init__() 277 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 278 | 279 | # res 280 | self.res = res 281 | 282 | # blur2d 283 | self.blur = Blur2d(f) 284 | 285 | # noise 286 | self.noise_input = noise_input 287 | 288 | if res < 7: 289 | # upsample method 1 290 | self.up_sample = Upscale2d(factor) 291 | else: 292 | # upsample method 2 293 | self.up_sample = nn.ConvTranspose2d(self.nf(res-3), self.nf(res-2), 4, stride=2, padding=1) 294 | 295 | # A Composition of LayerEpilogue and Conv2d. 296 | self.adaIn1 = LayerEpilogue(self.nf(res-2), dlatent_size, use_wscale, use_noise, 297 | use_pixel_norm, use_instance_norm, use_style) 298 | self.conv1 = Conv2d(input_channels=self.nf(res-2), output_channels=self.nf(res-2), 299 | kernel_size=3, use_wscale=use_wscale) 300 | self.adaIn2 = LayerEpilogue(self.nf(res-2), dlatent_size, use_wscale, use_noise, 301 | use_pixel_norm, use_instance_norm, use_style) 302 | 303 | def forward(self, x, dlatent): 304 | x = self.up_sample(x) 305 | x = self.adaIn1(x, self.noise_input[self.res*2-4], dlatent[:, self.res*2-4]) 306 | x = self.conv1(x) 307 | x = self.adaIn2(x, self.noise_input[self.res*2-3], dlatent[:, self.res*2-3]) 308 | return x 309 | 310 | #model.apply(weights_init) 311 | 312 | 313 | # ========================================================================= 314 | # Define sub-network 315 | # 2019.3.31 316 | # FC 317 | # ========================================================================= 318 | class G_mapping(nn.Module): 319 | def __init__(self, 320 | mapping_fmaps=512, 321 | dlatent_size=512, 322 | resolution=1024, 323 | normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers? 324 | use_wscale=True, # Enable equalized learning rate? 325 | lrmul=0.01, # Learning rate multiplier for the mapping layers. 326 | gain=2**(0.5) # original gain in tensorflow. 327 | ): 328 | super(G_mapping, self).__init__() 329 | self.mapping_fmaps = mapping_fmaps 330 | self.func = nn.Sequential( 331 | FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 332 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 333 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 334 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 335 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 336 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 337 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale), 338 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale) 339 | ) 340 | 341 | self.normalize_latents = normalize_latents 342 | self.resolution_log2 = int(np.log2(resolution)) 343 | self.num_layers = self.resolution_log2 * 2 - 2 344 | self.pixel_norm = PixelNorm() 345 | # - 2 means we start from feature map with height and width equals 4. 346 | # as this example, we get num_layers = 18. 347 | 348 | def forward(self, x): 349 | if self.normalize_latents: 350 | x = self.pixel_norm(x) 351 | out = self.func(x) 352 | return out, self.num_layers 353 | 354 | 355 | class G_synthesis(nn.Module): 356 | def __init__(self, 357 | dlatent_size, # Disentangled latent (W) dimensionality. 358 | resolution=1024, # Output resolution (1024 x 1024 by default). 359 | fmap_base=8192, # Overall multiplier for the number of feature maps. 360 | num_channels=3, # Number of output color channels. 
361 | structure='fixed', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically. 362 | fmap_max=512, # Maximum number of feature maps in any layer. 363 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 364 | f=None, # (Huge overload, if you dont have enough resouces, please pass it as `f = None`)Low-pass filter to apply when resampling activations. None = no filtering. 365 | use_pixel_norm = False, # Enable pixelwise feature vector normalization? 366 | use_instance_norm = True, # Enable instance normalization? 367 | use_wscale = True, # Enable equalized learning rate? 368 | use_noise = True, # Enable noise inputs? 369 | use_style = True # Enable style inputs? 370 | ): # batch size. 371 | """ 372 | 2019.3.31 373 | :param dlatent_size: 512 Disentangled latent(W) dimensionality. 374 | :param resolution: 1024 x 1024. 375 | :param fmap_base: 376 | :param num_channels: 377 | :param structure: only support 'fixed' mode. 378 | :param fmap_max: 379 | """ 380 | super(G_synthesis, self).__init__() 381 | 382 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 383 | self.structure = structure 384 | self.resolution_log2 = int(np.log2(resolution)) 385 | # - 2 means we start from feature map with height and width equals 4. 386 | # as this example, we get num_layers = 18. 387 | num_layers = self.resolution_log2 * 2 - 2 388 | self.num_layers = num_layers 389 | 390 | # Noise inputs. 391 | self.noise_inputs = [] 392 | for layer_idx in range(num_layers): 393 | res = layer_idx // 2 + 2 394 | shape = [1, 1, 2 ** res, 2 ** res] 395 | self.noise_inputs.append(torch.randn(*shape).to("cuda")) 396 | 397 | # Blur2d 398 | self.blur = Blur2d(f) 399 | 400 | # torgb: fixed mode 401 | self.channel_shrinkage = Conv2d(input_channels=self.nf(self.resolution_log2-2), 402 | output_channels=self.nf(self.resolution_log2), 403 | kernel_size=3, 404 | use_wscale=use_wscale) 405 | self.torgb = Conv2d(self.nf(self.resolution_log2), num_channels, kernel_size=1, gain=1, use_wscale=use_wscale) 406 | 407 | # Initial Input Block 408 | self.const_input = nn.Parameter(torch.ones(1, self.nf(1), 4, 4)) 409 | self.bias = nn.Parameter(torch.ones(self.nf(1))) 410 | self.adaIn1 = LayerEpilogue(self.nf(1), dlatent_size, use_wscale, use_noise, 411 | use_pixel_norm, use_instance_norm, use_style) 412 | self.conv1 = Conv2d(input_channels=self.nf(1), output_channels=self.nf(1), kernel_size=3, use_wscale=use_wscale) 413 | self.adaIn2 = LayerEpilogue(self.nf(1), dlatent_size, use_wscale, use_noise, use_pixel_norm, 414 | use_instance_norm, use_style) 415 | 416 | # Common Block 417 | # 4 x 4 -> 8 x 8 418 | res = 3 419 | self.GBlock1 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 420 | self.noise_inputs) 421 | 422 | # 8 x 8 -> 16 x 16 423 | res = 4 424 | self.GBlock2 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 425 | self.noise_inputs) 426 | 427 | # 16 x 16 -> 32 x 32 428 | res = 5 429 | self.GBlock3 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 430 | self.noise_inputs) 431 | 432 | # 32 x 32 -> 64 x 64 433 | res = 6 434 | self.GBlock4 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 435 | self.noise_inputs) 436 | 437 | # 64 x 64 -> 128 x 128 438 | res = 7 439 | self.GBlock5 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 440 | self.noise_inputs) 441 | 442 | # 128 x 128 -> 256 x 256 443 | res = 8 444 | 
self.GBlock6 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 445 | self.noise_inputs) 446 | 447 | # 256 x 256 -> 512 x 512 448 | res = 9 449 | self.GBlock7 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 450 | self.noise_inputs) 451 | 452 | # 512 x 512 -> 1024 x 1024 453 | res = 10 454 | self.GBlock8 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm, 455 | self.noise_inputs) 456 | 457 | def forward(self, dlatent): 458 | """ 459 | dlatent: Disentangled latents (W), shape为[minibatch, num_layers, dlatent_size]. 460 | :param dlatent: 461 | :return: 462 | """ 463 | images_out = None 464 | # Fixed structure: simple and efficient, but does not support progressive growing. 465 | if self.structure == 'fixed': 466 | # initial block 0: 467 | x = self.const_input.expand(dlatent.size(0), -1, -1, -1) 468 | x = x + self.bias.view(1, -1, 1, 1) 469 | x = self.adaIn1(x, self.noise_inputs[0], dlatent[:, 0]) 470 | x = self.conv1(x) 471 | x = self.adaIn2(x, self.noise_inputs[1], dlatent[:, 1]) 472 | 473 | # block 1: 474 | # 4 x 4 -> 8 x 8 475 | x = self.GBlock1(x, dlatent) 476 | 477 | # block 2: 478 | # 8 x 8 -> 16 x 16 479 | x = self.GBlock2(x, dlatent) 480 | 481 | # block 3: 482 | # 16 x 16 -> 32 x 32 483 | x = self.GBlock3(x, dlatent) 484 | 485 | # block 4: 486 | # 32 x 32 -> 64 x 64 487 | x = self.GBlock4(x, dlatent) 488 | 489 | # block 5: 490 | # 64 x 64 -> 128 x 128 491 | x = self.GBlock5(x, dlatent) 492 | 493 | # block 6: 494 | # 128 x 128 -> 256 x 256 495 | x = self.GBlock6(x, dlatent) 496 | 497 | # block 7: 498 | # 256 x 256 -> 512 x 512 499 | x = self.GBlock7(x, dlatent) 500 | 501 | # block 8: 502 | # 512 x 512 -> 1024 x 1024 503 | x = self.GBlock8(x, dlatent) 504 | 505 | x = self.channel_shrinkage(x) 506 | images_out = self.torgb(x) 507 | return images_out 508 | 509 | 510 | class StyleGenerator(nn.Module): 511 | def __init__(self, 512 | mapping_fmaps=512, 513 | style_mixing_prob=0.9, # Probability of mixing styles during training. None = disable. 514 | truncation_psi=0.7, # Style strength multiplier for the truncation trick. None = disable. 515 | truncation_cutoff=8, # Number of layers for which to apply the truncation trick. None = disable. 516 | **kwargs 517 | ): 518 | super(StyleGenerator, self).__init__() 519 | self.mapping_fmaps = mapping_fmaps 520 | self.style_mixing_prob = style_mixing_prob 521 | self.truncation_psi = truncation_psi 522 | self.truncation_cutoff = truncation_cutoff 523 | 524 | self.mapping = G_mapping(self.mapping_fmaps, **kwargs) 525 | self.synthesis = G_synthesis(self.mapping_fmaps, **kwargs) 526 | 527 | def forward(self, latents1): 528 | dlatents1, num_layers = self.mapping(latents1) 529 | # let [N, O] -> [N, num_layers, O] 530 | # 这里的unsqueeze不能使用inplace操作, 如果这样的话, 反向传播的链条会断掉的. 531 | dlatents1 = dlatents1.unsqueeze(1) 532 | dlatents1 = dlatents1.expand(-1, int(num_layers), -1) 533 | 534 | # Add mixing style mechanism. 535 | # with torch.no_grad(): 536 | # latents2 = torch.randn(latents1.shape).to(latents1.device) 537 | # dlatents2, num_layers = self.mapping(latents2) 538 | # dlatents2 = dlatents2.unsqueeze(1) 539 | # dlatents2 = dlatents2.expand(-1, int(num_layers), -1) 540 | # 541 | # # TODO: original NvLABs produce a placeholder "lod", this mechanism was not added here. 
542 | # cur_layers = num_layers 543 | # mix_layers = num_layers 544 | # if np.random.random() < self.style_mixing_prob: 545 | # mix_layers = np.random.randint(1, cur_layers) 546 | # 547 | # # NvLABs: dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2) 548 | # for i in range(num_layers): 549 | # if i >= mix_layers: 550 | # dlatents1[:, i, :] = dlatents2[:, i, :] 551 | 552 | # Apply truncation trick. 553 | if self.truncation_psi and self.truncation_cutoff: 554 | coefs = np.ones([1, num_layers, 1], dtype=np.float32) 555 | for i in range(num_layers): 556 | if i < self.truncation_cutoff: 557 | coefs[:, i, :] *= self.truncation_psi 558 | """Linear interpolation. 559 | a + (b - a) * t (a = 0) 560 | reduce to 561 | b * t 562 | """ 563 | 564 | dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device) 565 | 566 | img = self.synthesis(dlatents1) 567 | return img 568 | 569 | 570 | class StyleDiscriminator(nn.Module): 571 | def __init__(self, 572 | resolution=1024, 573 | fmap_base=8192, 574 | num_channels=3, 575 | structure='fixed', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, only support 'fixed' mode now. 576 | fmap_max=512, 577 | fmap_decay=1.0, 578 | # f=[1, 2, 1] # (Huge overload, if you dont have enough resouces, please pass it as `f = None`)Low-pass filter to apply when resampling activations. None = no filtering. 579 | f=None # (Huge overload, if you dont have enough resouces, please pass it as `f = None`)Low-pass filter to apply when resampling activations. None = no filtering. 580 | ): 581 | """ 582 | Noitce: we only support input pic with height == width. 583 | 584 | if H or W >= 128, we use avgpooling2d to do feature map shrinkage. 585 | else: we use ordinary conv2d. 
586 | """ 587 | super().__init__() 588 | self.resolution_log2 = int(np.log2(resolution)) 589 | assert resolution == 2 ** self.resolution_log2 and resolution >= 4 590 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 591 | # fromrgb: fixed mode 592 | self.fromrgb = nn.Conv2d(num_channels, self.nf(self.resolution_log2-1), kernel_size=1) 593 | self.structure = structure 594 | 595 | # blur2d 596 | self.blur2d = Blur2d(f) 597 | 598 | # down_sample 599 | self.down1 = nn.AvgPool2d(2) 600 | self.down21 = nn.Conv2d(self.nf(self.resolution_log2-5), self.nf(self.resolution_log2-5), kernel_size=2, stride=2) 601 | self.down22 = nn.Conv2d(self.nf(self.resolution_log2-6), self.nf(self.resolution_log2-6), kernel_size=2, stride=2) 602 | self.down23 = nn.Conv2d(self.nf(self.resolution_log2-7), self.nf(self.resolution_log2-7), kernel_size=2, stride=2) 603 | self.down24 = nn.Conv2d(self.nf(self.resolution_log2-8), self.nf(self.resolution_log2-8), kernel_size=2, stride=2) 604 | 605 | # conv1: padding=same 606 | self.conv1 = nn.Conv2d(self.nf(self.resolution_log2-1), self.nf(self.resolution_log2-1), kernel_size=3, padding=(1, 1)) 607 | self.conv2 = nn.Conv2d(self.nf(self.resolution_log2-1), self.nf(self.resolution_log2-2), kernel_size=3, padding=(1, 1)) 608 | self.conv3 = nn.Conv2d(self.nf(self.resolution_log2-2), self.nf(self.resolution_log2-3), kernel_size=3, padding=(1, 1)) 609 | self.conv4 = nn.Conv2d(self.nf(self.resolution_log2-3), self.nf(self.resolution_log2-4), kernel_size=3, padding=(1, 1)) 610 | self.conv5 = nn.Conv2d(self.nf(self.resolution_log2-4), self.nf(self.resolution_log2-5), kernel_size=3, padding=(1, 1)) 611 | self.conv6 = nn.Conv2d(self.nf(self.resolution_log2-5), self.nf(self.resolution_log2-6), kernel_size=3, padding=(1, 1)) 612 | self.conv7 = nn.Conv2d(self.nf(self.resolution_log2-6), self.nf(self.resolution_log2-7), kernel_size=3, padding=(1, 1)) 613 | self.conv8 = nn.Conv2d(self.nf(self.resolution_log2-7), self.nf(self.resolution_log2-8), kernel_size=3, padding=(1, 1)) 614 | 615 | # calculate point: 616 | self.conv_last = nn.Conv2d(self.nf(self.resolution_log2-8), self.nf(1), kernel_size=3, padding=(1, 1)) 617 | self.dense0 = nn.Linear(fmap_base, self.nf(0)) 618 | self.dense1 = nn.Linear(self.nf(0), 1) 619 | self.sigmoid = nn.Sigmoid() 620 | 621 | def forward(self, input): 622 | if self.structure == 'fixed': 623 | x = F.leaky_relu(self.fromrgb(input), 0.2, inplace=True) 624 | # 1. 1024 x 1024 x nf(9)(16) -> 512 x 512 625 | res = self.resolution_log2 626 | x = F.leaky_relu(self.conv1(x), 0.2, inplace=True) 627 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True) 628 | 629 | # 2. 512 x 512 -> 256 x 256 630 | res -= 1 631 | x = F.leaky_relu(self.conv2(x), 0.2, inplace=True) 632 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True) 633 | 634 | # 3. 256 x 256 -> 128 x 128 635 | res -= 1 636 | x = F.leaky_relu(self.conv3(x), 0.2, inplace=True) 637 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True) 638 | 639 | # 4. 128 x 128 -> 64 x 64 640 | res -= 1 641 | x = F.leaky_relu(self.conv4(x), 0.2, inplace=True) 642 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True) 643 | 644 | # 5. 64 x 64 -> 32 x 32 645 | res -= 1 646 | x = F.leaky_relu(self.conv5(x), 0.2, inplace=True) 647 | x = F.leaky_relu(self.down21(self.blur2d(x)), 0.2, inplace=True) 648 | 649 | # 6. 
32 x 32 -> 16 x 16 650 | res -= 1 651 | x = F.leaky_relu(self.conv6(x), 0.2, inplace=True) 652 | x = F.leaky_relu(self.down22(self.blur2d(x)), 0.2, inplace=True) 653 | 654 | # 7. 16 x 16 -> 8 x 8 655 | res -= 1 656 | x = F.leaky_relu(self.conv7(x), 0.2, inplace=True) 657 | x = F.leaky_relu(self.down23(self.blur2d(x)), 0.2, inplace=True) 658 | 659 | # 8. 8 x 8 -> 4 x 4 660 | res -= 1 661 | x = F.leaky_relu(self.conv8(x), 0.2, inplace=True) 662 | x = F.leaky_relu(self.down24(self.blur2d(x)), 0.2, inplace=True) 663 | 664 | # 9. 4 x 4 -> point 665 | x = F.leaky_relu(self.conv_last(x), 0.2, inplace=True) 666 | # N x 8192(4 x 4 x nf(1)). 667 | x = x.view(x.size(0), -1) 668 | x = F.leaky_relu(self.dense0(x), 0.2, inplace=True) 669 | # N x 1 670 | x = F.leaky_relu(self.dense1(x), 0.2, inplace=True) 671 | return x 672 | 673 | -------------------------------------------------------------------------------- /opts/opts.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | import argparse 6 | import torch 7 | import os 8 | 9 | 10 | def INFO(inputs): 11 | print("[ Style GAN ] %s" % (inputs)) 12 | 13 | 14 | def presentParameters(args_dict): 15 | """ 16 | Print the parameters setting line by line 17 | 18 | Arg: args_dict - The dict object which is transferred from argparse Namespace object 19 | """ 20 | INFO("========== Parameters ==========") 21 | for key in sorted(args_dict.keys()): 22 | INFO("{:>15} : {}".format(key, args_dict[key])) 23 | INFO("===============================") 24 | 25 | 26 | class TrainOptions(): 27 | def __init__(self): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--path', type=str, default='./star/') 30 | parser.add_argument('--epoch', type=int, default=500) 31 | parser.add_argument('--batch_size', type=int, default=2) 32 | parser.add_argument('--type', type=str, default='style') 33 | parser.add_argument('--resume', type=str, default='train_result/models/latest.pth') 34 | parser.add_argument('--det', type=str, default='train_result') 35 | parser.add_argument('--r1_gamma', type=float, default=10.0) 36 | parser.add_argument('--r2_gamma', type=float, default=0.0) 37 | self.opts = parser.parse_args() 38 | 39 | def parse(self): 40 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu' 41 | 42 | # Check if the parameter is valid 43 | if self.opts.type not in ['style', 'origin']: 44 | raise Exception( 45 | "Unknown type: {} You should assign one of them ['style', 'origin']...".format(self.opts.type)) 46 | 47 | # Create the destination folder 48 | if not os.path.exists(self.opts.det): 49 | os.mkdir(self.opts.det) 50 | if not os.path.exists(os.path.join(self.opts.det, 'images')): 51 | os.mkdir(os.path.join(self.opts.det, 'images')) 52 | if not os.path.exists(os.path.join(self.opts.det, 'models')): 53 | os.mkdir(os.path.join(self.opts.det, 'models')) 54 | 55 | # Print the options 56 | presentParameters(vars(self.opts)) 57 | return self.opts 58 | 59 | 60 | class InferenceOptions(): 61 | def __init__(self): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--resume', type=str, default='train_result/model/latest.pth') 64 | parser.add_argument('--type', type=str, default='style') 65 | parser.add_argument('--num_face', type=int, default=32) 66 | parser.add_argument('--det', type=str, default='result.png') 67 | self.opts = parser.parse_args() 68 | 69 | def parse(self): 70 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu' 71 | 72 | # Print 
the options 73 | presentParameters(vars(self.opts)) 74 | return self.opts 75 | -------------------------------------------------------------------------------- /star/00793.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00793.png -------------------------------------------------------------------------------- /star/00794.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00794.png -------------------------------------------------------------------------------- /star/00795.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00795.png -------------------------------------------------------------------------------- /star/00796.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00796.png -------------------------------------------------------------------------------- /star/00797.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00797.png -------------------------------------------------------------------------------- /star/00798.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00798.png -------------------------------------------------------------------------------- /star/00799.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00799.png -------------------------------------------------------------------------------- /star/00800.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00800.png -------------------------------------------------------------------------------- /star/00801.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00801.png -------------------------------------------------------------------------------- /star/00802.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00802.png -------------------------------------------------------------------------------- /star/00803.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00803.png -------------------------------------------------------------------------------- /star/00804.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00804.png -------------------------------------------------------------------------------- /star/00805.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00805.png -------------------------------------------------------------------------------- /star/00806.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00806.png -------------------------------------------------------------------------------- /star/00807.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00807.png -------------------------------------------------------------------------------- /star/00808.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00808.png -------------------------------------------------------------------------------- /star/00809.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00809.png -------------------------------------------------------------------------------- /star/00810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00810.png -------------------------------------------------------------------------------- /star/00811.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00811.png -------------------------------------------------------------------------------- /star/00812.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00812.png -------------------------------------------------------------------------------- /star/00813.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00813.png -------------------------------------------------------------------------------- /star/00814.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00814.png -------------------------------------------------------------------------------- /star/00815.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00815.png -------------------------------------------------------------------------------- /star/00816.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00816.png -------------------------------------------------------------------------------- /star/00817.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00817.png -------------------------------------------------------------------------------- /star/00818.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00818.png -------------------------------------------------------------------------------- /star/00819.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00819.png -------------------------------------------------------------------------------- /star/00820.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00820.png -------------------------------------------------------------------------------- /star/00821.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00821.png -------------------------------------------------------------------------------- /star/00822.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00822.png -------------------------------------------------------------------------------- /star/00823.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00823.png -------------------------------------------------------------------------------- /star/00824.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00824.png -------------------------------------------------------------------------------- /star/00825.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00825.png -------------------------------------------------------------------------------- /star/00826.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00826.png -------------------------------------------------------------------------------- /star/00827.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00827.png 
-------------------------------------------------------------------------------- /star/00828.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00828.png -------------------------------------------------------------------------------- /star/00829.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00829.png -------------------------------------------------------------------------------- /star/00830.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00830.png -------------------------------------------------------------------------------- /star/00831.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00831.png -------------------------------------------------------------------------------- /star/00832.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00832.png -------------------------------------------------------------------------------- /torchvision_sunner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/torchvision_sunner/__init__.py -------------------------------------------------------------------------------- /torchvision_sunner/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script defines the constant which will be used in Torchvision_sunner package 3 | 4 | Author: SunnerLi 5 | """ 6 | 7 | # Constant 8 | UNDER_SAMPLING = 0 9 | OVER_SAMPLING = 1 10 | BCHW2BHWC = 0 11 | BHWC2BCHW = 1 12 | 13 | # Categorical constant 14 | ONEHOT2INDEX = 'one_hot_to_index' 15 | INDEX2ONEHOT = 'index_to_one_hot' 16 | ONEHOT2COLOR = 'one_hot_to_color' 17 | COLOR2ONEHOT = 'color_to_one_hot' 18 | INDEX2COLOR = 'index_to_color' 19 | COLOR2INDEX = 'color_to_index' 20 | 21 | # Verbose flag 22 | verbose = True -------------------------------------------------------------------------------- /torchvision_sunner/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script define the wrapper of the Torchvision_sunner.data 3 | 4 | Author: SunnerLi 5 | """ 6 | from torch.utils.data import DataLoader 7 | 8 | from torchvision_sunner.data.image_dataset import ImageDataset 9 | from torchvision_sunner.data.video_dataset import VideoDataset 10 | from torchvision_sunner.data.loader import * 11 | from torchvision_sunner.constant import * 12 | from torchvision_sunner.utils import * -------------------------------------------------------------------------------- /torchvision_sunner/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | from torchvision_sunner.utils import INFO 3 | import torch.utils.data as Data 4 | 5 | import pickle 6 | import 
random 7 | import os 8 | 9 | """ 10 | This script define the parent class to deal with some common function for Dataset 11 | 12 | Author: SunnerLi 13 | """ 14 | 15 | class BaseDataset(Data.Dataset): 16 | def __init__(self): 17 | self.save_file = False 18 | self.files = None 19 | self.split_files = None 20 | 21 | def generateIndexList(self, a, size): 22 | """ 23 | Generate the list of index which will be picked 24 | This function will be used as train-test-split 25 | 26 | Arg: a - The list of images 27 | size - Int, the length of list you want to create 28 | Ret: The index list 29 | """ 30 | result = set() 31 | while len(result) != size: 32 | result.add(random.randint(0, len(a) - 1)) 33 | return list(result) 34 | 35 | def loadFromFile(self, file_name, check_type = 'image'): 36 | """ 37 | Load the root and files information from .pkl record file 38 | This function will return False if the record file format is invalid 39 | 40 | Arg: file_name - The name of record file 41 | check_type - Str. The type of the record file you want to check 42 | Ret: If the loading procedure are successful or not 43 | """ 44 | with open(file_name, 'rb') as f: 45 | obj = pickle.load(f) 46 | self.type = obj['type'] 47 | if self.type == check_type: 48 | INFO("Load from file: {}".format(file_name)) 49 | self.root = obj['root'] 50 | self.files = obj['files'] 51 | return True 52 | else: 53 | INFO("Record file type: {}\tFail to load...".format(self.type)) 54 | INFO("Form the contain from scratch...") 55 | return False 56 | 57 | def save(self, remain_file_name, split_ratio, split_file_name = ".split.pkl", save_type = 'image'): 58 | """ 59 | Save the information into record file 60 | 61 | Arg: remain_file_name - The path of record file which store the information of remain data 62 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 63 | split_file_name - The path of record file which store the information of split data 64 | save_type - Str. The type of the record file you want to save 65 | """ 66 | if self.save_file: 67 | if not os.path.exists(remain_file_name): 68 | with open(remain_file_name, 'wb') as f: 69 | pickle.dump({ 70 | 'type': save_type, 71 | 'root': self.root, 72 | 'files': self.files 73 | }, f) 74 | if split_ratio: 75 | INFO("Split the dataset, and save as {}".format(split_file_name)) 76 | with open(split_file_name, 'wb') as f: 77 | pickle.dump({ 78 | 'type': save_type, 79 | 'root': self.root, 80 | 'files': self.split_files 81 | }, f) -------------------------------------------------------------------------------- /torchvision_sunner/data/image_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.data.base_dataset import BaseDataset 2 | from torchvision_sunner.read import readContain, readItem 3 | from torchvision_sunner.constant import * 4 | from torchvision_sunner.utils import INFO 5 | 6 | from skimage import io as io 7 | # from PIL import Image 8 | from glob import glob 9 | 10 | import torch.utils.data as Data 11 | 12 | import pickle 13 | import math 14 | import os 15 | 16 | """ 17 | This script define the structure of image dataset 18 | 19 | ======================================================================================= 20 | In the new version, we accept the form that the combination of image and folder: 21 | e.g. [[image1.jpg, image_folder]] 22 | On the other hand, the root can only be 'the list of list' 23 | You should use double list to represent different image domain. 
24 | For example: 25 | [[image1.jpg], [image2.jpg]] => valid 26 | [[image1.jpg], [image_folder]] => valid 27 | [[image1.jpg, image2.jpg], [image_folder1, image_folder2]] => valid 28 | [image1.jpg, image2.jpg] => invalid! 29 | Also, the triple of nested list is not allow 30 | ======================================================================================= 31 | 32 | Author: SunnerLi 33 | """ 34 | 35 | class ImageDataset(BaseDataset): 36 | def __init__(self, root = None, file_name = '.remain.pkl', sample_method = UNDER_SAMPLING, transform = None, 37 | split_ratio = 0.0, save_file = False): 38 | """ 39 | The constructor of ImageDataset 40 | 41 | Arg: root - The list object. The image set 42 | file_name - The str. The name of record file. 43 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use down sampling or over sampling to deal with data unbalance problem. 44 | (default is sunnerData.OVER_SAMPLING) 45 | transform - transform.Compose object. You can declare some pre-process toward the image 46 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 47 | save_file - Bool. If storing the record file or not. Default is False 48 | """ 49 | super().__init__() 50 | # Record the parameter 51 | self.root = root 52 | self.file_name = file_name 53 | self.sample_method = sample_method 54 | self.transform = transform 55 | self.split_ratio = split_ratio 56 | self.save_file = save_file 57 | self.img_num = -1 58 | INFO() 59 | 60 | # Substitude the contain of record file if the record file is exist 61 | if os.path.exists(file_name) and self.loadFromFile(file_name): 62 | self.getImgNum() 63 | elif not os.path.exists(file_name) and root is None: 64 | raise Exception("Record file {} not found. You should assign 'root' parameter!".format(file_name)) 65 | else: 66 | # Extend the images of folder into domain list 67 | self.getFiles() 68 | 69 | # Change root obj as the index format 70 | self.root = range(len(self.root)) 71 | 72 | # Adjust the image number 73 | self.getImgNum() 74 | 75 | # Split the files if split_ratio is more than 0.0 76 | self.split() 77 | 78 | # Save the split information 79 | self.save() 80 | 81 | # Print the domain information 82 | self.print() 83 | 84 | # =========================================================================================== 85 | # Define IO function 86 | # =========================================================================================== 87 | def loadFromFile(self, file_name): 88 | """ 89 | Load the root and files information from .pkl record file 90 | This function will return False if the record file format is invalid 91 | 92 | Arg: file_name - The name of record file 93 | Ret: If the loading procedure are successful or not 94 | """ 95 | return super().loadFromFile(file_name, 'image') 96 | 97 | def save(self, split_file_name = ".split.pkl"): 98 | """ 99 | Save the information into record file 100 | 101 | Arg: split_file_name - The path of record file which store the information of split data 102 | """ 103 | super().save(self.file_name, self.split_ratio, split_file_name, 'image') 104 | 105 | # =========================================================================================== 106 | # Define main function 107 | # =========================================================================================== 108 | def getFiles(self): 109 | """ 110 | Construct the files object for the assigned root 111 | We accept the user to mix folder with image 112 | This function can extract whole 
image in the folder 113 | The element in the files will all become image 114 | 115 | ******************************************************* 116 | * This function only work if the files object is None * 117 | ******************************************************* 118 | """ 119 | if not self.files: 120 | self.files = {} 121 | for domain_idx, domain in enumerate(self.root): 122 | images = [] 123 | for img in domain: 124 | if os.path.exists(img): 125 | if os.path.isdir(img): 126 | images += readContain(img) 127 | else: 128 | images.append(img) 129 | else: 130 | raise Exception("The path {} is not exist".format(img)) 131 | self.files[domain_idx] = sorted(images) 132 | 133 | def getImgNum(self): 134 | """ 135 | Obtain the image number in the loader for the specific sample method 136 | The function will check if the folder has been extracted 137 | """ 138 | if self.img_num == -1: 139 | # Check if the folder has been extracted 140 | for domain in self.root: 141 | for img in self.files[domain]: 142 | if os.path.isdir(img): 143 | raise Exception("You should extend the image in the folder {} first!" % img) 144 | 145 | # Statistic the image number 146 | for domain in self.root: 147 | if domain == 0: 148 | self.img_num = len(self.files[domain]) 149 | else: 150 | if self.sample_method == OVER_SAMPLING: 151 | self.img_num = max(self.img_num, len(self.files[domain])) 152 | elif self.sample_method == UNDER_SAMPLING: 153 | self.img_num = min(self.img_num, len(self.files[domain])) 154 | return self.img_num 155 | 156 | def split(self): 157 | """ 158 | Split the files object into split_files object 159 | The original files object will shrink 160 | 161 | We also consider the case of pair image 162 | Thus we will check if the number of image in each domain is the same 163 | If it does, then we only generate the list once 164 | """ 165 | # Check if the number of image in different domain is the same 166 | if not self.files: 167 | self.getFiles() 168 | pairImage = True 169 | for domain in range(len(self.root) - 1): 170 | if len(self.files[domain]) != len(self.files[domain + 1]): 171 | pairImage = False 172 | 173 | # Split the files 174 | self.split_files = {} 175 | if pairImage: 176 | split_img_num = math.floor(len(self.files[0]) * self.split_ratio) 177 | choice_index_list = self.generateIndexList(range(len(self.files[0])), size = split_img_num) 178 | for domain in range(len(self.root)): 179 | # determine the index list 180 | if not pairImage: 181 | split_img_num = math.floor(len(self.files[domain]) * self.split_ratio) 182 | choice_index_list = self.generateIndexList(range(len(self.files[domain])), size = split_img_num) 183 | # remove the corresponding term and add into new list 184 | split_img_list = [] 185 | remain_img_list = self.files[domain].copy() 186 | for j in choice_index_list: 187 | split_img_list.append(self.files[domain][j]) 188 | for j in choice_index_list: 189 | self.files[domain].remove(remain_img_list[j]) 190 | self.split_files[domain] = sorted(split_img_list) 191 | 192 | def print(self): 193 | """ 194 | Print the information for each image domain 195 | """ 196 | INFO() 197 | for domain in range(len(self.root)): 198 | INFO("domain index: %d \timage number: %d" % (domain, len(self.files[domain]))) 199 | INFO() 200 | 201 | def __len__(self): 202 | return self.img_num 203 | 204 | def __getitem__(self, index): 205 | return_list = [] 206 | for domain in self.root: 207 | img_path = self.files[domain][index] 208 | img = readItem(img_path) 209 | if self.transform: 210 | img = self.transform(img) 211 | 
return_list.append(img) 212 | return return_list -------------------------------------------------------------------------------- /torchvision_sunner/data/loader.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | from collections import Iterator 3 | import torch.utils.data as data 4 | 5 | """ 6 | This script define the extra loader, and it can be used in flexibility. The loaders include: 7 | 1. ImageLoader (The old version exist) 8 | 2. MultiLoader 9 | 3. IterationLoader 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | class ImageLoader(data.DataLoader): 15 | def __init__(self, dataset, batch_size=1, shuffle=False, num_workers = 1): 16 | """ 17 | The DataLoader object which can deal with ImageDataset object. 18 | 19 | Arg: dataset - ImageDataset. You should use sunnerData.ImageDataset to generate the instance first 20 | batch_size - Int. 21 | shuffle - Bool. Shuffle the data or not 22 | num_workers - Int. How many thread you want to use to read the batch data 23 | """ 24 | super(ImageLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers = num_workers) 25 | self.dataset = dataset 26 | self.iter_num = self.__len__() 27 | 28 | def __len__(self): 29 | return round(self.dataset.img_num / self.batch_size) 30 | 31 | def getImageNumber(self): 32 | return self.dataset.img_num 33 | 34 | class MultiLoader(Iterator): 35 | def __init__(self, datasets, batch_size=1, shuffle=False, num_workers = 1): 36 | """ 37 | This class can deal with multiple dataset object 38 | 39 | Arg: datasets - The list of ImageDataset. 40 | batch_size - Int. 41 | shuffle - Bool. Shuffle the data or not 42 | num_workers - Int. How many thread you want to use to read the batch data 43 | """ 44 | # Create loaders 45 | self.loaders = [] 46 | for dataset in datasets: 47 | self.loaders.append( 48 | data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers) 49 | ) 50 | 51 | # Check the sample method 52 | self.sample_method = None 53 | for dataset in datasets: 54 | if self.sample_method is None: 55 | self.sample_method = dataset.sample_method 56 | else: 57 | if self.sample_method != dataset.sample_method: 58 | raise Exception("Sample methods are not consistant, {} <=> {}".format( 59 | self.sample_method, dataset.sample_method 60 | )) 61 | 62 | # Check the iteration number 63 | self.iter_num = 0 64 | for i, dataset in enumerate(datasets): 65 | if i == 0: 66 | self.iter_num = len(dataset) 67 | else: 68 | if self.sample_method == UNDER_SAMPLING: 69 | self.iter_num = min(self.iter_num, len(dataset)) 70 | else: 71 | self.iter_num = max(self.iter_num, len(dataset)) 72 | self.iter_num = round(self.iter_num / batch_size) 73 | 74 | def __len__(self): 75 | return self.iter_num 76 | 77 | def __iter__(self): 78 | self.iter_loaders = [] 79 | for loader in self.loaders: 80 | self.iter_loaders.append(iter(loader)) 81 | return self 82 | 83 | def __next__(self): 84 | result = [] 85 | for loader in self.iter_loaders: 86 | for _ in loader.__next__(): 87 | result.append(_) 88 | return tuple(result) 89 | 90 | class IterationLoader(Iterator): 91 | def __init__(self, loader, max_iter = 1): 92 | """ 93 | Constructor of the loader with specific iteration (not epoch) 94 | The iteration object will create again while getting end 95 | 96 | Arg: loader - The torch.data.DataLoader object 97 | max_iter - The maximun iteration 98 | """ 99 | super().__init__() 100 | self.loader = loader 101 | self.loader_iter = 
iter(self.loader) 102 | self.iter = 0 103 | self.max_iter = max_iter 104 | 105 | def __next__(self): 106 | try: 107 | result_tuple = next(self.loader_iter) 108 | except: 109 | self.loader_iter = iter(self.loader) 110 | result_tuple = next(self.loader_iter) 111 | self.iter += 1 112 | if self.iter <= self.max_iter: 113 | return result_tuple 114 | else: 115 | print("", end='') 116 | raise StopIteration() 117 | 118 | def __len__(self): 119 | return self.max_iter -------------------------------------------------------------------------------- /torchvision_sunner/data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.data.base_dataset import BaseDataset 2 | from torchvision_sunner.read import readContain, readItem 3 | from torchvision_sunner.constant import * 4 | from torchvision_sunner.utils import INFO 5 | 6 | import torch.utils.data as Data 7 | 8 | from PIL import Image 9 | from glob import glob 10 | import numpy as np 11 | import subprocess 12 | import random 13 | import pickle 14 | import torch 15 | import math 16 | import os 17 | 18 | """ 19 | This script define the structure of video dataset 20 | 21 | ======================================================================================= 22 | In the new version, we accept the form that the combination of video and folder: 23 | e.g. [[video1.mp4, image_folder]] 24 | On the other hand, the root can only be 'the list of list' 25 | You should use double list to represent different image domain. 26 | For example: 27 | [[video1.mp4], [video2.mp4]] => valid 28 | [[video1.mp4], [video_folder]] => valid 29 | [[video1.mp4, video2.mp4], [video_folder1, video_folder2]] => valid 30 | [video1.mp4, video2.mp4] => invalid! 31 | Also, the triple of nested list is not allow 32 | ======================================================================================= 33 | 34 | Author: SunnerLi 35 | """ 36 | 37 | class VideoDataset(BaseDataset): 38 | def __init__(self, root = None, file_name = '.remain.pkl', T = 10, sample_method = UNDER_SAMPLING, transform = None, 39 | split_ratio = 0.0, decode_root = './.decode', save_file = False): 40 | """ 41 | The constructor of VideoDataset 42 | 43 | Arg: root - The list object. The image set 44 | file_name - Str. The name of record file. 45 | T - Int. The maximun length of small video sequence 46 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use down sampling or over sampling to deal with data unbalance problem. 47 | (default is sunnerData.OVER_SAMPLING) 48 | transform - transform.Compose object. You can declare some pre-process toward the image 49 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data 50 | decode_root - Str. The path to store the ffmpeg decode result. 51 | save_file - Bool. If storing the record file or not. Default is False 52 | """ 53 | super().__init__() 54 | 55 | # Record the parameter 56 | self.root = root 57 | self.file_name = file_name 58 | self.T = T 59 | self.sample_method = sample_method 60 | self.transform = transform 61 | self.split_ratio = split_ratio 62 | self.decode_root = decode_root 63 | self.video_num = -1 64 | self.split_root = None 65 | INFO() 66 | 67 | # Substitude the contain of record file if the record file is exist 68 | if not os.path.exists(file_name) and root is None: 69 | raise Exception("Record file {} not found. 
You should assign 'root' parameter!".format(file_name)) 70 | elif os.path.exists(file_name): 71 | INFO("Load from file: {}".format(file_name)) 72 | self.loadFromFile(file_name) 73 | 74 | # Extend the images of folder into domain list 75 | self.extendFolder() 76 | 77 | # Split the image 78 | self.split() 79 | 80 | # Form the files obj 81 | self.getFiles() 82 | 83 | # Adjust the image number 84 | self.getVideoNum() 85 | 86 | # Save the split information 87 | self.save() 88 | 89 | # Print the domain information 90 | self.print() 91 | 92 | # =========================================================================================== 93 | # Define IO function 94 | # =========================================================================================== 95 | def loadFromFile(self, file_name): 96 | """ 97 | Load the root and files information from .pkl record file 98 | This function will return False if the record file format is invalid 99 | 100 | Arg: file_name - The name of record file 101 | Ret: If the loading procedure are successful or not 102 | """ 103 | return super().loadFromFile(file_name, 'video') 104 | 105 | def save(self, split_file_name = ".split.pkl"): 106 | """ 107 | Save the information into record file 108 | 109 | Arg: split_file_name - The path of record file which store the information of split data 110 | """ 111 | super().save(self.file_name, self.split_ratio, split_file_name, 'video') 112 | 113 | # =========================================================================================== 114 | # Define main function 115 | # =========================================================================================== 116 | def to_folder(self, name): 117 | """ 118 | Transfer the name into the folder format 119 | e.g. 120 | '/home/Dataset/video1_folder' => 'home_Dataset_video1_folder' 121 | '/home/Dataset/video1.mp4' => 'home_Dataset_video1' 122 | 123 | Arg: name - Str. The path of file or original folder 124 | Ret: The new (encoded) folder name 125 | """ 126 | if not os.path.isdir(name): 127 | name = '_'.join(name.split('.')[:-1]) 128 | domain_list = name.split('/') 129 | while True: 130 | if '.' in domain_list: 131 | domain_list.remove('.') 132 | elif '..' 
in domain_list: 133 | domain_list.remove('..') 134 | else: 135 | break 136 | return '_'.join(domain_list) 137 | 138 | def extendFolder(self): 139 | """ 140 | Extend the video folder in root obj 141 | """ 142 | if not self.files: 143 | # Extend the folder of video and replace as new root obj 144 | extend_root = [] 145 | for domain in self.root: 146 | videos = [] 147 | for video in domain: 148 | if os.path.exists(video): 149 | if os.path.isdir(video): 150 | videos += readContain(video) 151 | else: 152 | videos.append(video) 153 | else: 154 | raise Exception("The path {} is not exist".format(videos)) 155 | extend_root.append(videos) 156 | self.root = extend_root 157 | 158 | def split(self): 159 | """ 160 | Split the root object into split_root object 161 | The original root object will shrink 162 | 163 | We also consider the case of pair image 164 | Thus we will check if the number of image in each domain is the same 165 | If it does, then we only generate the list once 166 | """ 167 | # Check if the number of video in different domain is the same 168 | pairImage = True 169 | for domain_idx in range(len(self.root) - 1): 170 | if len(self.root[domain_idx]) != len(self.root[domain_idx + 1]): 171 | pairImage = False 172 | 173 | # Split the files 174 | self.split_root = [] 175 | if pairImage: 176 | split_img_num = math.floor(len(self.root[0]) * self.split_ratio) 177 | choice_index_list = self.generateIndexList(range(len(self.root[0])), size = split_img_num) 178 | for domain_idx in range(len(self.root)): 179 | # determine the index list 180 | if not pairImage: 181 | split_img_num = math.floor(len(self.root[domain_idx]) * self.split_ratio) 182 | choice_index_list = self.generateIndexList(range(len(self.root[domain_idx])), size = split_img_num) 183 | # remove the corresponding term and add into new list 184 | split_img_list = [] 185 | remain_img_list = self.root[domain_idx].copy() 186 | for j in choice_index_list: 187 | split_img_list.append(self.root[domain_idx][j]) 188 | for j in choice_index_list: 189 | self.root[domain_idx].remove(remain_img_list[j]) 190 | self.split_root.append(sorted(split_img_list)) 191 | 192 | def getFiles(self): 193 | """ 194 | Construct the files object for the assigned root 195 | We accept the user to mix folder with image 196 | This function can extract whole image in the folder 197 | 198 | However, unlike the setting in ImageDataset, we store the video result in root obj. 199 | Also, the 'images' name will be store in files obj 200 | 201 | The following list the progress of this function: 202 | 1. check if we need to decode again 203 | 2. decode if needed 204 | 3. 
form the files obj 205 | """ 206 | if not self.files: 207 | # Check if the decode process should be conducted again 208 | should_decode = not os.path.exists(self.decode_root) 209 | if not should_decode: 210 | for domain_idx, domain in enumerate(self.root): 211 | for video in domain: 212 | if not os.path.exists(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video))): 213 | should_decode = True 214 | break 215 | 216 | # Decode the video if needed 217 | if should_decode: 218 | INFO("Decode from scratch...") 219 | if os.path.exists(self.decode_root): 220 | subprocess.call(['rm', '-rf', self.decode_root]) 221 | os.mkdir(self.decode_root) 222 | self.decodeVideo() 223 | else: 224 | INFO("Skip the decode process!") 225 | 226 | # Form the files object 227 | self.files = {} 228 | for domain_idx, domain in enumerate(os.listdir(self.decode_root)): 229 | self.files[domain_idx] = [] 230 | for video in os.listdir(os.path.join(self.decode_root, domain)): 231 | self.files[domain_idx] += [ 232 | sorted(glob(os.path.join(self.decode_root, domain, video, "*"))) 233 | ] 234 | 235 | def decodeVideo(self): 236 | """ 237 | Decode the single video into a series of images, and store into particular folder 238 | """ 239 | for domain_idx, domain in enumerate(self.root): 240 | decode_domain_folder = os.path.join(self.decode_root, str(domain_idx)) 241 | os.mkdir(decode_domain_folder) 242 | for video in domain: 243 | os.mkdir(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video))) 244 | source = os.path.join(domain, video) 245 | target = os.path.join(decode_domain_folder, self.to_folder(video), "%5d.png") 246 | subprocess.call(['ffmpeg', '-i', source, target]) 247 | 248 | def getVideoNum(self): 249 | """ 250 | Obtain the video number in the loader for the specific sample method 251 | The function will check if the folder has been extracted 252 | """ 253 | if self.video_num == -1: 254 | # Check if the folder has been extracted 255 | for domain in self.root: 256 | for video in domain: 257 | if os.path.isdir(video): 258 | raise Exception("You should extend the image in the folder {} first!" 
% video) 259 | 260 | # Statistic the image number 261 | for i, domain in enumerate(self.root): 262 | if i == 0: 263 | self.video_num = len(domain) 264 | else: 265 | if self.sample_method == OVER_SAMPLING: 266 | self.video_num = max(self.video_num, len(domain)) 267 | elif self.sample_method == UNDER_SAMPLING: 268 | self.video_num = min(self.video_num, len(domain)) 269 | return self.video_num 270 | 271 | def print(self): 272 | """ 273 | Print the information for each image domain 274 | """ 275 | INFO() 276 | for domain in range(len(self.root)): 277 | total_frame = 0 278 | for video in self.files[domain]: 279 | total_frame += len(video) 280 | INFO("domain index: %d \tvideo number: %d\tframe total: %d" % (domain, len(self.root[domain]), total_frame)) 281 | INFO() 282 | 283 | def __len__(self): 284 | return self.video_num 285 | 286 | def __getitem__(self, index): 287 | """ 288 | Return single batch of data, and the rank is BTCHW 289 | """ 290 | result = [] 291 | for domain_idx in range(len(self.root)): 292 | 293 | # Form the sequence in single domain 294 | film_sequence = [] 295 | max_init_frame_idx = len(self.files[domain_idx][index]) - self.T 296 | start_pos = random.randint(0, max_init_frame_idx) 297 | for i in range(self.T): 298 | img_path = self.files[domain_idx][index][start_pos + i] 299 | img = readItem(img_path) 300 | film_sequence.append(img) 301 | 302 | # Transform the film sequence 303 | film_sequence = np.asarray(film_sequence) 304 | if self.transform: 305 | film_sequence = self.transform(film_sequence) 306 | result.append(film_sequence) 307 | return result -------------------------------------------------------------------------------- /torchvision_sunner/read.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from PIL import Image 3 | from glob import glob 4 | import numpy as np 5 | import os 6 | 7 | """ 8 | This script defines the function to read the containing of folder and read the file 9 | You should customize if your data is not considered in the torchvision_sunner previously 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | def readContain(folder_name): 15 | """ 16 | Read the containing in the particular folder 17 | 18 | ================================================================== 19 | You should customize this function if your data is not considered 20 | ================================================================== 21 | 22 | Arg: folder_name - The path of folder 23 | Ret: The list of containing 24 | """ 25 | # Check the common type in the folder 26 | common_type = Counter() 27 | for name in os.listdir(folder_name): 28 | common_type[name.split('.')[-1]] += 1 29 | common_type = common_type.most_common()[0][0] 30 | 31 | # Deal with the type 32 | if common_type == 'jpg': 33 | name_list = glob(os.path.join(folder_name, '*.jpg')) 34 | elif common_type == 'png': 35 | name_list = glob(os.path.join(folder_name, '*.png')) 36 | elif common_type == 'mp4': 37 | name_list = glob(os.path.join(folder_name, '*.mp4')) 38 | else: 39 | raise Exception("Unknown type {}, You should customize in read.py".format(common_type)) 40 | return name_list 41 | 42 | def readItem(item_name): 43 | """ 44 | Read the file for the given item name 45 | 46 | ================================================================== 47 | You should customize this function if your data is not considered 48 | ================================================================== 49 | 50 | Arg: item_name - The path of the file 51 | Ret: The item 
you read 52 | """ 53 | file_type = item_name.split('.')[-1] 54 | if file_type == "png" or file_type == 'jpg': 55 | file_obj = np.asarray(Image.open(item_name)) 56 | 57 | if len(file_obj.shape) == 3: 58 | # Ignore the 4th dim (RGB only) 59 | file_obj = file_obj[:, :, :3] 60 | elif len(file_obj.shape) == 2: 61 | # Make the rank of gray-scale image as 3 62 | file_obj = np.expand_dims(file_obj, axis = -1) 63 | return file_obj -------------------------------------------------------------------------------- /torchvision_sunner/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script define the wrapper of the Torchvision_sunner.transforms 3 | 4 | Author: SunnerLi 5 | """ 6 | from torchvision_sunner.constant import * 7 | from torchvision_sunner.utils import * 8 | 9 | from torchvision_sunner.transforms.base import * 10 | from torchvision_sunner.transforms.simple import * 11 | from torchvision_sunner.transforms.complex import * 12 | from torchvision_sunner.transforms.categorical import * 13 | from torchvision_sunner.transforms.function import * -------------------------------------------------------------------------------- /torchvision_sunner/transforms/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | """ 5 | This class define the parent class of operation 6 | 7 | Author: SunnerLi 8 | """ 9 | 10 | class OP(): 11 | """ 12 | The parent class of each operation 13 | The goal of this class is to adapting with different input format 14 | """ 15 | def work(self, tensor): 16 | """ 17 | The virtual function to define the process in child class 18 | 19 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 20 | """ 21 | raise NotImplementedError("You should define your own function in the class!") 22 | 23 | def __call__(self, tensor): 24 | """ 25 | This function define the proceeding of the operation 26 | There are different choice toward the tensor parameter 27 | 1. torch.Tensor and rank is CHW 28 | 2. np.ndarray and rank is CHW 29 | 3. torch.Tensor and rank is TCHW 30 | 4. np.ndarray and rank is TCHW 31 | 32 | Arg: tensor - The tensor you want to operate 33 | Ret: The operated tensor 34 | """ 35 | isTensor = type(tensor) == torch.Tensor 36 | if isTensor: 37 | tensor_type = tensor.type() 38 | tensor = tensor.cpu().data.numpy() 39 | if len(tensor.shape) == 3: 40 | tensor = self.work(tensor) 41 | elif len(tensor.shape) == 4: 42 | tensor = np.asarray([self.work(_) for _ in tensor]) 43 | else: 44 | raise Exception("We dont support the rank format {}".format(tensor.shape), 45 | "If the rank of the tensor shape is only 2, you can call 'GrayStack()'") 46 | if isTensor: 47 | tensor = torch.from_numpy(tensor) 48 | tensor = tensor.type(tensor_type) 49 | return tensor -------------------------------------------------------------------------------- /torchvision_sunner/transforms/categorical.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.utils import INFO 2 | from torchvision_sunner.constant import * 3 | from torchvision_sunner.transforms.simple import Transpose 4 | 5 | from collections import OrderedDict 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pickle 9 | import torch 10 | import json 11 | import os 12 | 13 | """ 14 | This script define the categorical-related operations, including: 15 | 1. getCategoricalMapping 16 | 2. 
CategoricalTranspose 17 | 18 | Author: SunnerLi 19 | """ 20 | 21 | # ---------------------------------------------------------------------------------------- 22 | # Define the IO function toward pallete file 23 | # ---------------------------------------------------------------------------------------- 24 | 25 | def load_pallete(file_name): 26 | """ 27 | Load the pallete object from file 28 | 29 | Arg: file_name - The name of pallete .json file 30 | Ret: The list of pallete object 31 | """ 32 | # Load the list of dict from files (key is str) 33 | palletes_str_key = None 34 | with open(file_name, 'r') as f: 35 | palletes_str_key = json.load(f) 36 | 37 | # Change the key into color tuple 38 | palletes = [OrderedDict()] * len(palletes_str_key) 39 | for folder in range(len(palletes_str_key)): 40 | for key in palletes_str_key[folder].keys(): 41 | tuple_key = list() 42 | for v in key.split('_'): 43 | tuple_key.append(int(v)) 44 | palletes[folder][tuple(tuple_key)] = palletes_str_key[folder][key] 45 | return palletes 46 | 47 | def save_pallete(pallete, file_name): 48 | """ 49 | Load the pallete object from file 50 | 51 | Arg: pallete - The list of OrderDict objects 52 | file_name - The name of pallete .json file 53 | """ 54 | # Change the key into str 55 | pallete_str_key = [dict()] * len(pallete) 56 | for folder in range(len(pallete)): 57 | for key in pallete[folder].keys(): 58 | 59 | str_key = '_'.join([str(_) for _ in key]) 60 | pallete_str_key[folder][str_key] = pallete[folder][key] 61 | 62 | # Save into file 63 | with open(file_name, 'w') as f: 64 | json.dump(pallete_str_key, f) 65 | 66 | # ---------------------------------------------------------------------------------------- 67 | # Define the categorical-related operations 68 | # ---------------------------------------------------------------------------------------- 69 | 70 | def getCategoricalMapping(loader = None, path = 'torchvision_sunner_categories_pallete.json'): 71 | """ 72 | This function can statistic the different category with color 73 | And return the list of the mapping OrderedDict object 74 | 75 | Arg: loader - The ImageLoader object 76 | path - The path of pallete file 77 | Ret: The list of OrderDict object (palletes object) 78 | """ 79 | INFO("Applied << %15s >>" % getCategoricalMapping.__name__) 80 | INFO("* Notice: the rank format of input tensor should be 'BHWC'") 81 | INFO("* Notice: The range of tensor should be in [0, 255]") 82 | if os.path.exists(path): 83 | palletes = load_pallete(path) 84 | else: 85 | INFO(">> Load from scratch, please wait...") 86 | 87 | # Get the number of folder 88 | folder_num = 0 89 | for img_list in loader: 90 | folder_num = len(img_list) 91 | break 92 | 93 | # Initialize the pallete list 94 | palletes = [OrderedDict()] * folder_num 95 | color_sets = [set()] * folder_num 96 | 97 | # Work 98 | for img_list in tqdm(loader): 99 | for folder_idx in range(folder_num): 100 | img = img_list[folder_idx] 101 | if torch.max(img) > 255 or torch.min(img) < 0: 102 | raise Exception('tensor value out of range...\t range is [' + str(torch.min(img)) + ' ~ ' + str(torch.max(img))) 103 | img = img.cpu().data.numpy().astype(np.uint8) 104 | img = np.reshape(img, [-1, 3]) 105 | color_sets[folder_idx] |= set([tuple(_) for _ in img]) 106 | 107 | # Merge the color 108 | for i in range(folder_num): 109 | for color in color_sets[i]: 110 | if color not in palletes[i].keys(): 111 | palletes[i][color] = len(palletes[i]) 112 | save_pallete(palletes, path) 113 | 114 | return palletes 115 | 116 | class 
CategoricalTranspose(): 117 | def __init__(self, pallete = None, direction = COLOR2INDEX, index_default = 0): 118 | """ 119 | Transform the tensor into the particular format 120 | We support for 3 different kinds of format: 121 | 1. one hot image 122 | 2. index image 123 | 3. color 124 | 125 | Arg: pallete - The pallete object (default is None) 126 | direction - The direction you want to change 127 | index_default - The default index if the color cannot be found in the pallete 128 | """ 129 | self.pallete = pallete 130 | self.direction = direction 131 | self.index_default = index_default 132 | INFO("Applied << %15s >> , direction: %s" % (self.__class__.__name__, self.direction)) 133 | INFO("* Notice: The range of tensor should be in [-1, 1]") 134 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 135 | 136 | def fn_color_to_index(self, tensor): 137 | """ 138 | Transfer the tensor from the RGB colorful format into the index format 139 | 140 | Arg: tensor - The tensor obj. The tensor you want to deal with 141 | Ret: The tensor with index format 142 | """ 143 | if self.pallete is None: 144 | raise Exception("The direction << %s >> need the pallete object" % self.direction) 145 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy() 146 | size_tuple = list(np.shape(tensor)) 147 | tensor = (tensor * 127.5 + 127.5).astype(np.uint8) 148 | tensor = np.reshape(tensor, [-1, 3]) 149 | tensor = [tuple(_) for _ in tensor] 150 | tensor = [self.pallete.get(_, self.index_default) for _ in tensor] 151 | tensor = np.asarray(tensor) 152 | size_tuple[-1] = 1 153 | tensor = np.reshape(tensor, size_tuple) 154 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3) 155 | return tensor 156 | 157 | def fn_index_to_one_hot(self, tensor): 158 | """ 159 | Transfer the tensor from the index format into the one-hot format 160 | 161 | Arg: tensor - The tensor obj. The tensor you want to deal with 162 | Ret: The tensor with one-hot format 163 | """ 164 | # Get the number of classes 165 | tensor = tensor.transpose(-3, -2).transpose(-2, -1) 166 | size_tuple = list(np.shape(tensor)) 167 | tensor = tensor.view(-1).cpu().data.numpy() 168 | channel = np.amax(tensor) + 1 169 | 170 | # Get the total number of pixel 171 | num_of_pixel = 1 172 | for i in range(len(size_tuple) - 1): 173 | num_of_pixel *= size_tuple[i] 174 | 175 | # Transfer as ont-hot format 176 | one_hot_tensor = np.zeros([num_of_pixel, channel]) 177 | for i in range(channel): 178 | one_hot_tensor[tensor == i, i] = 1 179 | 180 | # Recover to origin rank format and shape 181 | size_tuple[-1] = channel 182 | tensor = np.reshape(one_hot_tensor, size_tuple) 183 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3) 184 | return tensor 185 | 186 | def fn_one_hot_to_index(self, tensor): 187 | """ 188 | Transfer the tensor from the one-hot format into the index format 189 | 190 | Arg: tensor - The tensor obj. The tensor you want to deal with 191 | Ret: The tensor with index format 192 | """ 193 | _, tensor = torch.max(tensor, dim = 1) 194 | tensor = tensor.unsqueeze(1) 195 | return tensor 196 | 197 | def fn_index_to_color(self, tensor): 198 | """ 199 | Transfer the tensor from the index format into the RGB colorful format 200 | 201 | Arg: tensor - The tensor obj. 
The tensor you want to deal with 202 | Ret: The tensor with RGB colorful format 203 | """ 204 | if self.pallete is None: 205 | raise Exception("The direction << %s >> need the pallete object" % self.direction) 206 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy() 207 | reverse_pallete = {self.pallete[x]: x for x in self.pallete} 208 | batch, height, width, channel = np.shape(tensor) 209 | tensor = np.reshape(tensor, [-1]) 210 | tensor = np.round(tensor, decimals=0) 211 | tensor = np.vectorize(reverse_pallete.get)(tensor) 212 | tensor = np.reshape(np.asarray(tensor).T, [batch, height, width, len(reverse_pallete[0])]) 213 | tensor = torch.from_numpy((tensor - 127.5) / 127.5).transpose(-1, -2).transpose(-2, -3) 214 | return tensor 215 | 216 | def __call__(self, tensor): 217 | if self.direction == COLOR2INDEX: 218 | return self.fn_color_to_index(tensor) 219 | elif self.direction == INDEX2COLOR: 220 | return self.fn_index_to_color(tensor) 221 | elif self.direction == ONEHOT2INDEX: 222 | return self.fn_one_hot_to_index(tensor) 223 | elif self.direction == INDEX2ONEHOT: 224 | return self.fn_index_to_one_hot(tensor) 225 | elif self.direction == ONEHOT2COLOR: 226 | return self.fn_index_to_color(self.fn_one_hot_to_index(tensor)) 227 | elif self.direction == COLOR2ONEHOT: 228 | return self.fn_index_to_one_hot(self.fn_color_to_index(tensor)) 229 | else: 230 | raise Exception("Unknown direction: {}".format(self.direction)) -------------------------------------------------------------------------------- /torchvision_sunner/transforms/complex.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.transforms.base import OP 2 | from torchvision_sunner.utils import INFO 3 | from skimage import transform 4 | import numpy as np 5 | import torch 6 | 7 | """ 8 | This script define some complex operations 9 | These kind of operations should conduct work iteratively (with inherit OP class) 10 | 11 | Author: SunnerLi 12 | """ 13 | 14 | class Resize(OP): 15 | def __init__(self, output_size): 16 | """ 17 | Resize the tensor to the desired size 18 | This function only support for nearest-neighbor interpolation 19 | Since this mechanism can also deal with categorical data 20 | 21 | Arg: output_size - The tuple (H, W) 22 | """ 23 | self.output_size = output_size 24 | INFO("Applied << %15s >>" % self.__class__.__name__) 25 | INFO("* Notice: the rank format of input tensor should be 'BHWC'") 26 | 27 | def work(self, tensor): 28 | """ 29 | Resize the tensor 30 | If the tensor is not in the range of [-1, 1], we will do the normalization automatically 31 | 32 | Arg: tensor - The np.ndarray object. 
The tensor you want to deal with 33 | Ret: The resized tensor 34 | """ 35 | # Normalize the tensor if needed 36 | mean, std = -1, -1 37 | min_v = np.min(tensor) 38 | max_v = np.max(tensor) 39 | if not (max_v <= 1 and min_v >= -1): 40 | mean = 0.5 * max_v + 0.5 * min_v 41 | std = 0.5 * max_v - 0.5 * min_v 42 | # print(max_v, min_v, mean, std) 43 | tensor = (tensor - mean) / std 44 | 45 | # Work 46 | tensor = transform.resize(tensor, self.output_size, mode = 'constant', order = 0) 47 | 48 | # De-normalize the tensor 49 | if mean != -1 and std != -1: 50 | tensor = tensor * std + mean 51 | return tensor 52 | 53 | class Normalize(OP): 54 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]): 55 | """ 56 | Normalize the tensor with given mean and standard deviation 57 | * Notice: If you didn't give mean and std, the result will locate in [-1, 1] 58 | 59 | Args: 60 | mean - The mean of the result tensor 61 | std - The standard deviation 62 | """ 63 | self.mean = mean 64 | self.std = std 65 | INFO("Applied << %15s >>" % self.__class__.__name__) 66 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 67 | INFO("*****************************************************************") 68 | INFO("* Notice: You should must call 'ToFloat' before normalization") 69 | INFO("*****************************************************************") 70 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]: 71 | INFO("* Notice: The result will locate in [-1, 1]") 72 | 73 | def work(self, tensor): 74 | """ 75 | Normalize the tensor 76 | 77 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 78 | Ret: The normalized tensor 79 | """ 80 | if tensor.shape[0] != len(self.mean): 81 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 82 | result = [] 83 | for t, m, s in zip(tensor, self.mean, self.std): 84 | result.append((t - m) / s) 85 | tensor = np.asarray(result) 86 | 87 | # Check if the normalization can really work 88 | if np.min(tensor) < -1 or np.max(tensor) > 1: 89 | raise Exception("Normalize can only work with float tensor", 90 | "Try to call 'ToFloat()' before normalization") 91 | return tensor 92 | 93 | class UnNormalize(OP): 94 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]): 95 | """ 96 | Unnormalize the tensor with given mean and standard deviation 97 | * Notice: If you didn't give mean and std, the function will assume that the original distribution locates in [-1, 1] 98 | 99 | Args: 100 | mean - The mean of the result tensor 101 | std - The standard deviation 102 | """ 103 | self.mean = mean 104 | self.std = std 105 | INFO("Applied << %15s >>" % self.__class__.__name__) 106 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 107 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]: 108 | INFO("* Notice: The function assume that the input range is [-1, 1]") 109 | 110 | def work(self, tensor): 111 | """ 112 | Un-normalize the tensor 113 | 114 | Arg: tensor - The np.ndarray object. 
The tensor you want to deal with 115 | Ret: The un-normalized tensor 116 | """ 117 | if tensor.shape[0] != len(self.mean): 118 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 119 | result = [] 120 | for t, m, s in zip(tensor, self.mean, self.std): 121 | result.append(t * s + m) 122 | tensor = np.asarray(result) 123 | return tensor 124 | 125 | class ToGray(OP): 126 | def __init__(self): 127 | """ 128 | Change the tensor as the gray scale 129 | The function will turn the BCHW tensor into B1HW gray-scaled tensor 130 | """ 131 | INFO("Applied << %15s >>" % self.__class__.__name__) 132 | INFO("* Notice: the rank format of input tensor should be 'BCHW'") 133 | 134 | def work(self, tensor): 135 | """ 136 | Make the tensor into gray-scale 137 | 138 | Arg: tensor - The np.ndarray object. The tensor you want to deal with 139 | Ret: The gray-scale tensor, and the rank of the tensor is B1HW 140 | """ 141 | if tensor.shape[0] == 3: 142 | result = 0.299 * tensor[0] + 0.587 * tensor[1] + 0.114 * tensor[2] 143 | result = np.expand_dims(result, axis = 0) 144 | elif tensor.shape[0] != 4: 145 | result = 0.299 * tensor[:, 0] + 0.587 * tensor[:, 1] + 0.114 * tensor[:, 2] 146 | result = np.expand_dims(result, axis = 1) 147 | else: 148 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape)) 149 | return result -------------------------------------------------------------------------------- /torchvision_sunner/transforms/function.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.transforms.simple import Transpose 2 | from torchvision_sunner.constant import BCHW2BHWC 3 | 4 | from skimage import transform 5 | from skimage import io 6 | import numpy as np 7 | import torch 8 | 9 | """ 10 | This script define the transform function which can be called directly 11 | 12 | Author: SunnerLi 13 | """ 14 | 15 | channel_op = None # Define the channel op which will be used in 'asImg' function 16 | 17 | def asImg(tensor, size = None): 18 | """ 19 | This function provides fast approach to transfer the image into numpy.ndarray 20 | This function only accept the output from sigmoid layer or hyperbolic tangent output 21 | 22 | Arg: tensor - The torch.Variable object, the rank format is BCHW or BHW 23 | size - The tuple object, and the format is (height, width) 24 | Ret: The numpy image, the rank format is BHWC 25 | """ 26 | global channel_op 27 | result = tensor.detach() 28 | 29 | # 1. Judge the rank first 30 | if len(tensor.size()) == 3: 31 | result = torch.stack([result, result, result], 1) 32 | 33 | # 2. Judge the range of tensor (sigmoid output or hyperbolic tangent output) 34 | min_v = torch.min(result).cpu().data.numpy() 35 | max_v = torch.max(result).cpu().data.numpy() 36 | if max_v > 1.0 or min_v < -1.0: 37 | raise Exception('tensor value out of range...\t range is [' + str(min_v) + ' ~ ' + str(max_v)) 38 | if min_v < 0: 39 | result = (result + 1) / 2 40 | 41 | # 3. Define the BCHW -> BHWC operation 42 | if channel_op is None: 43 | channel_op = Transpose(BCHW2BHWC) 44 | 45 | # 3. Rest 46 | result = channel_op(result) 47 | result = result.cpu().data.numpy() 48 | if size is not None: 49 | result_list = [] 50 | for img in result: 51 | result_list.append(transform.resize(img, (size[0], size[1]), mode = 'constant', order = 0) * 255) 52 | result = np.stack(result_list, axis = 0) 53 | else: 54 | result *= 255. 
55 | result = result.astype(np.uint8) 56 | return result -------------------------------------------------------------------------------- /torchvision_sunner/transforms/simple.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.utils import INFO 2 | from torchvision_sunner.constant import * 3 | import numpy as np 4 | import torch 5 | 6 | """ 7 | This script define some operation which are rather simple 8 | The operation only need to call function once (without inherit OP class) 9 | 10 | Author: SunnerLi 11 | """ 12 | 13 | class ToTensor(): 14 | def __init__(self): 15 | """ 16 | Change the tensor into torch.Tensor type 17 | """ 18 | INFO("Applied << %15s >>" % self.__class__.__name__) 19 | 20 | def __call__(self, tensor): 21 | """ 22 | Arg: tensor - The torch.Tensor or other type. The tensor you want to deal with 23 | """ 24 | if type(tensor) == np.ndarray: 25 | tensor = torch.from_numpy(tensor) 26 | return tensor 27 | 28 | class ToFloat(): 29 | def __init__(self): 30 | """ 31 | Change the tensor into torch.FloatTensor 32 | """ 33 | INFO("Applied << %15s >>" % self.__class__.__name__) 34 | 35 | def __call__(self, tensor): 36 | """ 37 | Arg: tensor - The torch.Tensor object. The tensor you want to deal with 38 | """ 39 | return tensor.float() 40 | 41 | class Transpose(): 42 | def __init__(self, direction = BHWC2BCHW): 43 | """ 44 | Transfer the rank of tensor into target one 45 | 46 | Arg: direction - The direction you want to do the transpose 47 | """ 48 | self.direction = direction 49 | if self.direction == BHWC2BCHW: 50 | INFO("Applied << %15s >>, The rank format is BCHW" % self.__class__.__name__) 51 | elif self.direction == BCHW2BHWC: 52 | INFO("Applied << %15s >>, The rank format is BHWC" % self.__class__.__name__) 53 | else: 54 | raise Exception("Unknown direction symbol: {}".format(self.direction)) 55 | 56 | def __call__(self, tensor): 57 | """ 58 | Arg: tensor - The torch.Tensor object. 
The tensor you want to deal with 59 | """ 60 | if self.direction == BHWC2BCHW: 61 | tensor = tensor.transpose(-1, -2).transpose(-2, -3) 62 | else: 63 | tensor = tensor.transpose(-3, -2).transpose(-2, -1) 64 | return tensor -------------------------------------------------------------------------------- /torchvision_sunner/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision_sunner.constant import * 2 | 3 | """ 4 | This script defines the function which are widely used in the whole package 5 | 6 | Author: SunnerLi 7 | """ 8 | 9 | def quiet(): 10 | """ 11 | Mute the information toward the whole log in the toolkit 12 | """ 13 | global verbose 14 | verbose = False 15 | 16 | def INFO(string = None): 17 | """ 18 | Print the information with prefix 19 | 20 | Arg: string - The string you want to print 21 | """ 22 | if verbose: 23 | if string: 24 | print("[ Torchvision_sunner ] %s" % (string)) 25 | else: 26 | print("[ Torchvision_sunner ] " + '=' * 50) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | import torchvision_sunner.transforms as sunnertransforms 6 | import torchvision_sunner.data as sunnerData 7 | import torchvision.transforms as transforms 8 | 9 | 10 | from networks_stylegan import StyleGenerator, StyleDiscriminator 11 | from networks_gan import Generator, Discriminator 12 | from utils import plotLossCurve 13 | from loss import gradient_penalty 14 | from opts import TrainOptions 15 | 16 | from torchvision.utils import save_image 17 | from tqdm import tqdm 18 | from matplotlib import pyplot as plt 19 | import torch.optim as optim 20 | import numpy as np 21 | import torch 22 | import os 23 | 24 | # Hyper-parameters 25 | CRITIC_ITER = 5 26 | 27 | 28 | def main(opts): 29 | # Create the data loader 30 | loader = sunnerData.DataLoader(sunnerData.ImageDataset( 31 | root=[[opts.path]], 32 | transform=transforms.Compose([ 33 | sunnertransforms.Resize((1024, 1024)), 34 | sunnertransforms.ToTensor(), 35 | sunnertransforms.ToFloat(), 36 | sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW), 37 | sunnertransforms.Normalize(), 38 | ])), 39 | batch_size=opts.batch_size, 40 | shuffle=True, 41 | ) 42 | 43 | # Create the model 44 | G = StyleGenerator(bs=opts.batch_size).to(opts.device) 45 | D = StyleDiscriminator().to(opts.device) 46 | 47 | # G = Generator().to(opts.device) 48 | # D = Discriminator().to(opts.device) 49 | 50 | # Create the criterion, optimizer and scheduler 51 | optim_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999)) 52 | optim_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999)) 53 | scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99) 54 | scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99) 55 | 56 | # Train 57 | fix_z = torch.randn([opts.batch_size, 512]).to(opts.device) 58 | Loss_D_list = [0.0] 59 | Loss_G_list = [0.0] 60 | for ep in range(opts.epoch): 61 | bar = tqdm(loader) 62 | loss_D_list = [] 63 | loss_G_list = [] 64 | for i, (real_img,) in enumerate(bar): 65 | # ======================================================================================================= 66 | # Update discriminator 67 | # ======================================================================================================= 68 | # Compute adversarial loss toward discriminator 69 | real_img = 
real_img.to(opts.device) 70 | real_logit = D(real_img) 71 | fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device)) 72 | fake_logit = D(fake_img.detach()) 73 | d_loss = -(real_logit.mean() - fake_logit.mean()) + gradient_penalty(real_img.data, fake_img.data, D) * 10.0 74 | loss_D_list.append(d_loss.item()) 75 | 76 | # Update discriminator 77 | optim_D.zero_grad() 78 | d_loss.backward() 79 | optim_D.step() 80 | 81 | # ======================================================================================================= 82 | # Update generator 83 | # ======================================================================================================= 84 | if i % CRITIC_ITER == 0: 85 | # Compute adversarial loss toward generator 86 | fake_img = G(torch.randn([opts.batch_size, 512]).to(opts.device)) 87 | fake_logit = D(fake_img) 88 | g_loss = -fake_logit.mean() 89 | loss_G_list.append(g_loss.item()) 90 | 91 | # Update generator 92 | D.zero_grad() 93 | optim_G.zero_grad() 94 | g_loss.backward() 95 | optim_G.step() 96 | bar.set_description(" {} [G]: {} [D]: {}".format(ep, loss_G_list[-1], loss_D_list[-1])) 97 | 98 | # Save the result 99 | Loss_G_list.append(np.mean(loss_G_list)) 100 | Loss_D_list.append(np.mean(loss_D_list)) 101 | fake_img = G(fix_z) 102 | save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True) 103 | state = { 104 | 'G': G.state_dict(), 105 | 'D': D.state_dict(), 106 | 'Loss_G': Loss_G_list, 107 | 'Loss_D': Loss_D_list, 108 | } 109 | torch.save(state, os.path.join(opts.det, 'models', 'latest.pth')) 110 | 111 | scheduler_D.step() 112 | scheduler_G.step() 113 | 114 | # Plot the total loss curve 115 | Loss_D_list = Loss_D_list[1:] 116 | Loss_G_list = Loss_G_list[1:] 117 | plotLossCurve(opts, Loss_D_list, Loss_G_list) 118 | 119 | 120 | if __name__ == '__main__': 121 | opts = TrainOptions().parse() 122 | main(opts) -------------------------------------------------------------------------------- /train_stylegan.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | import torchvision_sunner.transforms as sunnertransforms 6 | import torchvision_sunner.data as sunnerData 7 | import torchvision.transforms as transforms 8 | 9 | 10 | from networks_stylegan import StyleGenerator, StyleDiscriminator 11 | from networks_gan import Generator, Discriminator 12 | from utils.utils import plotLossCurve 13 | from loss.loss import gradient_penalty, R1Penalty, R2Penalty 14 | from opts.opts import TrainOptions, INFO 15 | 16 | from torchvision.utils import save_image 17 | from tqdm import tqdm 18 | from matplotlib import pyplot as plt 19 | from torch import nn 20 | import torch.optim as optim 21 | from torch.autograd import Variable 22 | import numpy as np 23 | import random 24 | import torch 25 | import os 26 | 27 | # Set random seem for reproducibility 28 | manualSeed = 999 29 | #manualSeed = random.randint(1, 10000) # use if you want new results 30 | print("Random Seed: ", manualSeed) 31 | random.seed(manualSeed) 32 | torch.manual_seed(manualSeed) 33 | 34 | # Hyper-parameters 35 | CRITIC_ITER = 5 36 | 37 | 38 | def main(opts): 39 | # Create the data loader 40 | loader = sunnerData.DataLoader(sunnerData.ImageDataset( 41 | root=[[opts.path]], 42 | transform=transforms.Compose([ 43 | sunnertransforms.Resize((1024, 1024)), 44 | sunnertransforms.ToTensor(), 45 | sunnertransforms.ToFloat(), 46 | sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW), 47 | 
sunnertransforms.Normalize(), 48 | ])), 49 | batch_size=opts.batch_size, 50 | shuffle=True, 51 | ) 52 | 53 | # Create the model 54 | start_epoch = 0 55 | G = StyleGenerator() 56 | D = StyleDiscriminator() 57 | 58 | # Load the pre-trained weight 59 | if os.path.exists(opts.resume): 60 | INFO("Load the pre-trained weight!") 61 | state = torch.load(opts.resume) 62 | G.load_state_dict(state['G']) 63 | D.load_state_dict(state['D']) 64 | start_epoch = state['start_epoch'] 65 | else: 66 | INFO("Pre-trained weight cannot load successfully, train from scratch!") 67 | 68 | # Multi-GPU support 69 | if torch.cuda.device_count() > 1: 70 | INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs") 71 | G = nn.DataParallel(G) 72 | D = nn.DataParallel(D) 73 | G.to(opts.device) 74 | D.to(opts.device) 75 | 76 | # Create the criterion, optimizer and scheduler 77 | optim_D = optim.Adam(D.parameters(), lr=0.00001, betas=(0.5, 0.999)) 78 | optim_G = optim.Adam(G.parameters(), lr=0.00001, betas=(0.5, 0.999)) 79 | scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99) 80 | scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99) 81 | 82 | # Train 83 | fix_z = torch.randn([opts.batch_size, 512]).to(opts.device) 84 | softplus = nn.Softplus() 85 | Loss_D_list = [0.0] 86 | Loss_G_list = [0.0] 87 | for ep in range(start_epoch, opts.epoch): 88 | bar = tqdm(loader) 89 | loss_D_list = [] 90 | loss_G_list = [] 91 | for i, (real_img,) in enumerate(bar): 92 | # ======================================================================================================= 93 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 94 | # ======================================================================================================= 95 | # Compute adversarial loss toward discriminator 96 | D.zero_grad() 97 | real_img = real_img.to(opts.device) 98 | real_logit = D(real_img) 99 | fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device)) 100 | fake_logit = D(fake_img.detach()) 101 | d_loss = softplus(fake_logit).mean() 102 | d_loss = d_loss + softplus(-real_logit).mean() 103 | 104 | if opts.r1_gamma != 0.0: 105 | r1_penalty = R1Penalty(real_img.detach(), D) 106 | d_loss = d_loss + r1_penalty * (opts.r1_gamma * 0.5) 107 | 108 | if opts.r2_gamma != 0.0: 109 | r2_penalty = R2Penalty(fake_img.detach(), D) 110 | d_loss = d_loss + r2_penalty * (opts.r2_gamma * 0.5) 111 | 112 | loss_D_list.append(d_loss.item()) 113 | 114 | # Update discriminator 115 | d_loss.backward() 116 | optim_D.step() 117 | 118 | # ======================================================================================================= 119 | # (2) Update G network: maximize log(D(G(z))) 120 | # ======================================================================================================= 121 | if i % CRITIC_ITER == 0: 122 | G.zero_grad() 123 | fake_logit = D(fake_img) 124 | g_loss = softplus(-fake_logit).mean() 125 | loss_G_list.append(g_loss.item()) 126 | 127 | # Update generator 128 | g_loss.backward() 129 | optim_G.step() 130 | 131 | # Output training stats 132 | bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(ep, i+1, len(loader), loss_G_list[-1], loss_D_list[-1])) 133 | 134 | # Save the result 135 | Loss_G_list.append(np.mean(loss_G_list)) 136 | Loss_D_list.append(np.mean(loss_D_list)) 137 | 138 | # Check how the generator is doing by saving G's output on fixed_noise 139 | with torch.no_grad(): 140 | fake_img = G(fix_z).detach().cpu() 141 | save_image(fake_img, os.path.join(opts.det, 
'images', str(ep) + '.png'), nrow=4, normalize=True) 142 | 143 | # Save model 144 | state = { 145 | 'G': G.state_dict(), 146 | 'D': D.state_dict(), 147 | 'Loss_G': Loss_G_list, 148 | 'Loss_D': Loss_D_list, 149 | 'start_epoch': ep, 150 | } 151 | torch.save(state, os.path.join(opts.det, 'models', 'latest.pth')) 152 | 153 | scheduler_D.step() 154 | scheduler_G.step() 155 | 156 | # Plot the total loss curve 157 | Loss_D_list = Loss_D_list[1:] 158 | Loss_G_list = Loss_G_list[1:] 159 | plotLossCurve(opts, Loss_D_list, Loss_G_list) 160 | 161 | 162 | if __name__ == '__main__': 163 | opts = TrainOptions().parse() 164 | main(opts) 165 | -------------------------------------------------------------------------------- /utils/stylegan-teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/utils/stylegan-teaser.png -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | """ 3 | @author: samuel ko 4 | """ 5 | 6 | from matplotlib import pyplot as plt 7 | import os 8 | 9 | def plotLossCurve(opts, Loss_D_list, Loss_G_list): 10 | plt.figure() 11 | plt.plot(Loss_D_list, '-') 12 | plt.title("Loss curve (Discriminator)") 13 | plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_discriminator.png')) 14 | 15 | plt.figure() 16 | plt.plot(Loss_G_list, '-o') 17 | plt.title("Loss curve (Generator)") 18 | plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_generator.png')) --------------------------------------------------------------------------------
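
The double-list `root` convention and the `split_ratio` / record-file behaviour documented in `torchvision_sunner/data/image_dataset.py` are easiest to see in a short usage sketch. The snippet below is illustrative only: the `./domainA` and `./domainB` folders are hypothetical placeholders, the transform order mirrors the one used in `train.py`, and `ImageLoader` is imported straight from `torchvision_sunner/data/loader.py` rather than through the package `__init__` (whose re-exports are not shown here).

```python
# Illustrative sketch of the torchvision_sunner dataset/loader usage documented above.
# './domainA' and './domainB' are hypothetical placeholder folders.
import torchvision.transforms as transforms
import torchvision_sunner.transforms as sunnertransforms
from torchvision_sunner.data.image_dataset import ImageDataset
from torchvision_sunner.data.loader import ImageLoader

dataset = ImageDataset(
    root=[['./domainA'], ['./domainB']],       # double list: one inner list per image domain
    transform=transforms.Compose([
        sunnertransforms.Resize((256, 256)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),          # defaults map [0, 255] into [-1, 1]
    ]),
    split_ratio=0.1,                           # 10% of each domain is moved into .split.pkl
    save_file=True,                            # write the .remain.pkl / .split.pkl record files
)

loader = ImageLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for batch_a, batch_b in loader:                # one BCHW tensor per domain
    pass                                       # training / preprocessing goes here
```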
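
Each epoch, `train_stylegan.py` saves `{'G', 'D', 'Loss_G', 'Loss_D', 'start_epoch'}` into `opts.det/models/latest.pth`. Reloading that checkpoint to sample images takes only a few lines; the sketch below is illustrative rather than repository code. It assumes the default `StyleGenerator()` construction and the `[N, 512]` latent shape used above, and the `det/models/latest.pth` path stands in for whatever `opts.det` was set to.

```python
# Minimal, hedged sketch: reload the checkpoint written by train_stylegan.py and
# sample a grid of images. 'det/models/latest.pth' is a stand-in for the actual
# opts.det output directory used during training.
import os
import torch
from torchvision.utils import save_image
from networks_stylegan import StyleGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = StyleGenerator().to(device)
state = torch.load(os.path.join("det", "models", "latest.pth"), map_location=device)
g_weights = state['G']
# If training ran under nn.DataParallel, the saved keys carry a 'module.' prefix.
g_weights = {k.replace('module.', '', 1): v for k, v in g_weights.items()}
G.load_state_dict(g_weights)
G.eval()

with torch.no_grad():
    z = torch.randn([4, 512]).to(device)       # same latent shape as in training
    fake_img = G(z).detach().cpu()
save_image(fake_img, "sample.png", nrow=2, normalize=True)
```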