├── .github
│   └── FUNDING.yml
├── .idea
│   └── vcs.xml
├── README.md
├── loss
│   └── loss.py
├── networks_gan.py
├── networks_stylegan.py
├── opts
│   └── opts.py
├── star
│   ├── 00793.png
│   ├── 00794.png
│   ├── 00795.png
│   ├── 00796.png
│   ├── 00797.png
│   ├── 00798.png
│   ├── 00799.png
│   ├── 00800.png
│   ├── 00801.png
│   ├── 00802.png
│   ├── 00803.png
│   ├── 00804.png
│   ├── 00805.png
│   ├── 00806.png
│   ├── 00807.png
│   ├── 00808.png
│   ├── 00809.png
│   ├── 00810.png
│   ├── 00811.png
│   ├── 00812.png
│   ├── 00813.png
│   ├── 00814.png
│   ├── 00815.png
│   ├── 00816.png
│   ├── 00817.png
│   ├── 00818.png
│   ├── 00819.png
│   ├── 00820.png
│   ├── 00821.png
│   ├── 00822.png
│   ├── 00823.png
│   ├── 00824.png
│   ├── 00825.png
│   ├── 00826.png
│   ├── 00827.png
│   ├── 00828.png
│   ├── 00829.png
│   ├── 00830.png
│   ├── 00831.png
│   └── 00832.png
├── torchvision_sunner
│   ├── __init__.py
│   ├── constant.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── image_dataset.py
│   │   ├── loader.py
│   │   └── video_dataset.py
│   ├── read.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── categorical.py
│   │   ├── complex.py
│   │   ├── function.py
│   │   └── simple.py
│   └── utils.py
├── train.py
├── train_stylegan.py
└── utils
    ├── stylegan-teaser.png
    └── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # A PyTorch Implementation of StyleGAN (Unofficial)
6 |
7 | 
8 | 
9 | 
10 | 
11 |
12 | This repository contains a PyTorch implementation of the following paper:
13 | > **A Style-Based Generator Architecture for Generative Adversarial Networks**
14 | > Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
15 | > http://stylegan.xyz/paper
16 | >
17 | > **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.*
18 |
19 |
20 | 
21 | Picture: These people are not real – they were produced by our generator that allows control over different aspects of the image.
22 |
23 | ## Motivation
24 | To the best of my knowledge, there is still no comparable PyTorch 1.0 implementation of StyleGAN matching the official NvLabs (TensorFlow) release,
25 | so I implemented it on PyTorch 1.0.1 to extend its usage in the PyTorch community.
26 |
27 | ## Notice
28 | @date: 2019.10.21
29 |
30 | @info: One thing I forgot to highlight: **you need to change the default `star` dataset to your own dataset** (such as FFHQ or others) in `opts.py`. Sorry for my carelessness.
31 |
32 |
33 | ## Author
34 |
35 | - [Samuel Ko](https://blog.csdn.net/g11d111)
36 | - [Sunner Li](https://github.com/SunnerLi)
37 |
38 | ## Training
39 |
40 | ``` python
41 | # ① set your own training dataset path, batch size and other common settings in TrainOptions of `opts.py`.
42 |
43 | # ② run train_stylegan.py
44 | python3 train_stylegan.py
45 |
46 | # ③ intermediate images generated by StyleGenerator are saved under `opts.det/images/` (default: `train_result/images/`)
47 | ```
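For reference, these settings live in `TrainOptions` inside `opts.py`. Below is a minimal sketch of the defaults you would typically edit; the dataset path is a placeholder for your own data (such as FFHQ):

``` python
import argparse

# Sketch mirroring the arguments defined in TrainOptions (opts/opts.py).
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, default='./star/')      # replace with your own dataset folder
parser.add_argument('--epoch', type=int, default=500)
parser.add_argument('--batch_size', type=int, default=2)        # raise it if your GPU memory allows
parser.add_argument('--det', type=str, default='train_result')  # intermediate images go to <det>/images/
opts = parser.parse_args([])                                     # pass [] just to inspect the defaults
print(opts)
```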
48 |
49 | ## Project
50 | > We follow the official StyleGAN release code carefully. If you find any bug or mistake in the implementation,
51 | > please tell us so we can improve it. Thank you very much!
52 | #### Finished
53 | * `blur2d` mechanism. (This step takes a lot of GPU memory; if you don't have enough resources, please set it to `None`.)
54 | * `truncation` tricks.
55 | * Two kinds of `upsample` methods in `G_synthesis`.
56 | * Two kinds of `downsample` methods in `StyleDiscriminator`.
57 | * `PixelNorm` and `InstanceNorm`.
58 | * `Noise` mechanism.
59 | * `styleMixed` mechanism.
60 | * add `Multi-GPU` support.
61 |
62 | #### Unfinished
63 | * Inference code.
64 |
65 | ## Related
66 | [1. StyleGAN - Official TensorFlow Implementation](https://github.com/NVlabs/stylegan)
67 |
68 | [2. The re-implementation of style-based generator idea](https://github.com/SunnerLi/StyleGAN_demo)
69 |
70 | [3. ptrblck_styleGAN](https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb)
71 |
72 | ## System Requirements
73 | - Ubuntu 18.04
74 | - PyTorch 1.0.1
75 | - Numpy 1.13.3
76 | - torchvision 0.2.1
77 | - scikit-image 0.15.0
78 | - tqdm
79 | - GTX 1080Ti or above
80 |
81 | ## Q&A
82 |
83 | ## Acknowledgements
84 | Our code can run the `1024 x 1024` image generation task on a 1080Ti. If you have a stronger graphics card or GPU,
85 | you may train your model with a larger batch size and define your own multi-GPU version of this code.
86 |
87 | My email is **samuel.gao023@gmail.com**. If you have any questions or want to open a PR, please let me know. Thank you.
88 |
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 |
6 | from torch.autograd import Variable
7 | from torch.autograd import grad
8 | import torch.autograd as autograd
9 | import torch.nn as nn
10 | import torch
11 |
12 | import numpy as np
13 |
14 | def gradient_penalty(x, y, f):
15 | # interpolation
16 | shape = [x.size(0)] + [1] * (x.dim() - 1)
17 | alpha = torch.rand(shape).to(x.device)
18 | z = x + alpha * (y - x)
19 |
20 | # gradient penalty
21 | z = Variable(z, requires_grad=True).to(x.device)
22 | o = f(z)
23 | g = grad(o, z, grad_outputs=torch.ones(o.size()).to(z.device), create_graph=True)[0].view(z.size(0), -1)
24 | gp = ((g.norm(p=2, dim=1) - 1)**2).mean()
25 | return gp
26 |
27 |
28 | def R1Penalty(real_img, f):
29 | # gradient penalty
30 | reals = Variable(real_img, requires_grad=True).to(real_img.device)
31 | real_logit = f(reals)
32 | apply_loss_scaling = lambda x: x * torch.exp(x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device))
33 | undo_loss_scaling = lambda x: x * torch.exp(-x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device))
34 |
35 | real_logit = apply_loss_scaling(torch.sum(real_logit))
36 | real_grads = grad(real_logit, reals, grad_outputs=torch.ones(real_logit.size()).to(reals.device), create_graph=True)[0].view(reals.size(0), -1)
37 | real_grads = undo_loss_scaling(real_grads)
38 | r1_penalty = torch.sum(torch.mul(real_grads, real_grads))
39 | return r1_penalty
40 |
41 |
42 | def R2Penalty(fake_img, f):
43 | # gradient penalty
44 | fakes = Variable(fake_img, requires_grad=True).to(fake_img.device)
45 | fake_logit = f(fakes)
46 | apply_loss_scaling = lambda x: x * torch.exp(x * torch.Tensor([np.float32(np.log(2.0))]).to(fake_img.device))
47 | undo_loss_scaling = lambda x: x * torch.exp(-x * torch.Tensor([np.float32(np.log(2.0))]).to(fake_img.device))
48 |
49 | fake_logit = apply_loss_scaling(torch.sum(fake_logit))
50 | fake_grads = grad(fake_logit, fakes, grad_outputs=torch.ones(fake_logit.size()).to(fakes.device), create_graph=True)[0].view(fakes.size(0), -1)
51 | fake_grads = undo_loss_scaling(fake_grads)
52 | r2_penalty = torch.sum(torch.mul(fake_grads, fake_grads))
53 | return r2_penalty
54 |
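# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file): how R1Penalty is
# typically combined with the non-saturating GAN loss on the discriminator
# side. `D`, `real_imgs` and `fake_imgs` are placeholders; the actual training
# loop lives in train_stylegan.py, and r1_gamma matches the default in opts.py.
def d_loss_with_r1(D, real_imgs, fake_imgs, r1_gamma=10.0):
    import torch.nn.functional as F
    loss_real = F.softplus(-D(real_imgs)).mean()           # push real logits up
    loss_fake = F.softplus(D(fake_imgs.detach())).mean()   # push fake logits down
    return loss_real + loss_fake + (r1_gamma * 0.5) * R1Penalty(real_imgs, D)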
--------------------------------------------------------------------------------
/networks_gan.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import torch
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, z_dims=512, d=64):
12 | super().__init__()
13 | self.deconv1 = nn.utils.spectral_norm(nn.ConvTranspose2d(z_dims, d * 8, 4, 1, 0))
14 | self.deconv2 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 8, d * 8, 4, 2, 1))
15 | self.deconv3 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1))
16 | self.deconv4 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1))
17 | self.deconv5 = nn.utils.spectral_norm(nn.ConvTranspose2d(d * 2, d, 4, 2, 1))
18 | self.deconv6 = nn.ConvTranspose2d(d, 3, 4, 2, 1)
19 |
20 | def forward(self, input):
21 | input = input.view(input.size(0), input.size(1), 1, 1) # 1 x 1
22 | x = F.relu(self.deconv1(input)) # 4 x 4
23 | x = F.relu(self.deconv2(x)) # 8 x 8
24 | x = F.relu(self.deconv3(x)) # 16 x 16
25 | x = F.relu(self.deconv4(x)) # 32 x 32
26 | x = F.relu(self.deconv5(x)) # 64 x 64
27 | x = torch.tanh(self.deconv6(x)) # 128 x 128 (torch.tanh; F.tanh is deprecated)
28 | return x
29 |
30 |
31 | class Discriminator(nn.Module):
32 | def __init__(self, nc=3, ndf=64):
33 | super().__init__()
34 | self.layer1 = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)
35 | self.layer2 = nn.utils.spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False))
36 | self.layer3 = nn.utils.spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False))
37 | self.layer4 = nn.utils.spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False))
38 | self.layer5 = nn.utils.spectral_norm(nn.Conv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False))
39 | self.layer6 = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
40 |
41 | def forward(self, input):
42 | out = F.leaky_relu(self.layer1(input), 0.2, inplace=True) # 64 x 64
43 | out = F.leaky_relu(self.layer2(out), 0.2, inplace=True) # 32 x 32
44 | out = F.leaky_relu(self.layer3(out), 0.2, inplace=True) # 16 x 16
45 | out = F.leaky_relu(self.layer4(out), 0.2, inplace=True) # 8 x 8
46 | out = F.leaky_relu(self.layer5(out), 0.2, inplace=True) # 4 x 4
47 | out = F.leaky_relu(self.layer6(out), 0.2, inplace=True) # 1 x 1
48 | return out.view(-1, 1)
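# ---------------------------------------------------------------------------
# Editorial smoke-test sketch (not part of the original file): checks the
# shapes promised by the comments above -- a 512-d noise vector becomes a
# 128 x 128 RGB image, and the discriminator maps it back to a single logit.
if __name__ == "__main__":
    z = torch.randn(4, 512)                  # batch of latent vectors
    G, D = Generator(z_dims=512), Discriminator()
    fake = G(z)                              # -> [4, 3, 128, 128]
    logit = D(fake)                          # -> [4, 1]
    print(fake.shape, logit.shape)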
--------------------------------------------------------------------------------
/networks_stylegan.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | @date: 2019.04.11
5 | @notice:
6 | 1) refactor the module of Gsynthesis with
7 | - LayerEpilogue.
8 | - Upsample2d.
9 | - GBlock.
10 | and etc.
11 | 2) the initialization of every layer we use follows the original NvLabs released code.
12 | 3) the Discriminator is a simplified PyTorch version.
13 | 4) fix bug: default settings of batchsize.
14 |
15 | """
16 | import torch.nn.functional as F
17 | import torch.nn as nn
18 | import numpy as np
19 | import torch
20 | import os
21 | from collections import OrderedDict
22 | from torch.nn.init import kaiming_normal_
23 |
24 |
25 | class ApplyNoise(nn.Module):
26 | def __init__(self, channels):
27 | super().__init__()
28 | self.weight = nn.Parameter(torch.zeros(channels))
29 |
30 | def forward(self, x, noise):
31 | if noise is None:
32 | noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
33 | return x + self.weight.view(1, -1, 1, 1) * noise.to(x.device)
34 |
35 |
36 | class ApplyStyle(nn.Module):
37 | """
38 | @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
39 | """
40 | def __init__(self, latent_size, channels, use_wscale):
41 | super(ApplyStyle, self).__init__()
42 | self.linear = FC(latent_size,
43 | channels * 2,
44 | gain=1.0,
45 | use_wscale=use_wscale)
46 |
47 | def forward(self, x, latent):
48 | style = self.linear(latent) # style => [batch_size, n_channels*2]
49 | shape = [-1, 2, x.size(1), 1, 1]
50 | style = style.view(shape) # [batch_size, 2, n_channels, ...]
51 | x = x * (style[:, 0] + 1.) + style[:, 1]
52 | return x
53 |
54 |
55 | class FC(nn.Module):
56 | def __init__(self,
57 | in_channels,
58 | out_channels,
59 | gain=2**(0.5),
60 | use_wscale=False,
61 | lrmul=1.0,
62 | bias=True):
63 | """
64 | The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
65 | """
66 | super(FC, self).__init__()
67 | he_std = gain * in_channels ** (-0.5) # He init
68 | if use_wscale:
69 | init_std = 1.0 / lrmul
70 | self.w_lrmul = he_std * lrmul
71 | else:
72 | init_std = he_std / lrmul
73 | self.w_lrmul = lrmul
74 |
75 | self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
76 | if bias:
77 | self.bias = torch.nn.Parameter(torch.zeros(out_channels))
78 | self.b_lrmul = lrmul
79 | else:
80 | self.bias = None
81 |
82 | def forward(self, x):
83 | if self.bias is not None:
84 | out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
85 | else:
86 | out = F.linear(x, self.weight * self.w_lrmul)
87 | out = F.leaky_relu(out, 0.2, inplace=True)
88 | return out
89 |
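# Editorial note: with use_wscale=True the weights above are stored at unit scale
# (init_std = 1.0 / lrmul) and multiplied by w_lrmul = he_std * lrmul at run time --
# the "equalized learning rate" trick from the original StyleGAN code -- so the He
# scaling is applied dynamically rather than baked into the stored parameters.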
90 |
91 | class Blur2d(nn.Module):
92 | def __init__(self, f=[1,2,1], normalize=True, flip=False, stride=1):
93 | """
94 | depthwise_conv2d:
95 | https://blog.csdn.net/mao_xiao_feng/article/details/78003476
96 | """
97 | super(Blur2d, self).__init__()
98 | assert isinstance(f, list) or f is None, "kernel f must be an instance of python built_in type list!"
99 |
100 | if f is not None:
101 | f = torch.tensor(f, dtype=torch.float32)
102 | f = f[:, None] * f[None, :]
103 | f = f[None, None]
104 | if normalize:
105 | f = f / f.sum()
106 | if flip:
107 | # f = f[:, :, ::-1, ::-1]
108 | f = torch.flip(f, [2, 3])
109 | self.f = f
110 | else:
111 | self.f = None
112 | self.stride = stride
113 |
114 | def forward(self, x):
115 | if self.f is not None:
116 | # expand kernel channels
117 | kernel = self.f.expand(x.size(1), -1, -1, -1).to(x.device)
118 | x = F.conv2d(
119 | x,
120 | kernel,
121 | stride=self.stride,
122 | padding=int((self.f.size(2)-1)/2),
123 | groups=x.size(1)
124 | )
125 | return x
126 | else:
127 | return x
128 |
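# Editorial note: with the default f = [1, 2, 1], Blur2d builds the separable 2-D
# binomial kernel (the outer product of f with itself, normalized to sum to 1) and
# applies it per channel through a grouped conv2d (groups = channels), i.e. a
# depthwise low-pass blur with "same" padding.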
129 |
130 | class Conv2d(nn.Module):
131 | def __init__(self,
132 | input_channels,
133 | output_channels,
134 | kernel_size,
135 | gain=2 ** (0.5),
136 | use_wscale=False,
137 | lrmul=1,
138 | bias=True):
139 | super().__init__()
140 | he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
141 | self.kernel_size = kernel_size
142 | if use_wscale:
143 | init_std = 1.0 / lrmul
144 | self.w_lrmul = he_std * lrmul
145 | else:
146 | init_std = he_std / lrmul
147 | self.w_lrmul = lrmul
148 |
149 | self.weight = torch.nn.Parameter(
150 | torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
151 | if bias:
152 | self.bias = torch.nn.Parameter(torch.zeros(output_channels))
153 | self.b_lrmul = lrmul
154 | else:
155 | self.bias = None
156 |
157 | def forward(self, x):
158 | if self.bias is not None:
159 | return F.conv2d(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul, padding=self.kernel_size // 2)
160 | else:
161 | return F.conv2d(x, self.weight * self.w_lrmul, padding=self.kernel_size // 2)
162 |
163 |
164 | class Upscale2d(nn.Module):
165 | def __init__(self, factor=2, gain=1):
166 | """
167 | the first upsample method in G_synthesis.
168 | :param factor:
169 | :param gain:
170 | """
171 | super().__init__()
172 | self.gain = gain
173 | self.factor = factor
174 |
175 | def forward(self, x):
176 | if self.gain != 1:
177 | x = x * self.gain
178 | if self.factor > 1:
179 | shape = x.shape
180 | x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, self.factor, -1, self.factor)
181 | x = x.contiguous().view(shape[0], shape[1], self.factor * shape[2], self.factor * shape[3])
182 | return x
183 |
184 |
185 | class PixelNorm(nn.Module):
186 | def __init__(self, epsilon=1e-8):
187 | """
188 | @notice: avoid in-place ops.
189 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
190 | """
191 | super(PixelNorm, self).__init__()
192 | self.epsilon = epsilon
193 |
194 | def forward(self, x):
195 | tmp = torch.mul(x, x) # or x ** 2
196 | tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
197 |
198 | return x * tmp1
199 |
200 |
201 | class InstanceNorm(nn.Module):
202 | def __init__(self, epsilon=1e-8):
203 | """
204 | @notice: avoid in-place ops.
205 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
206 | """
207 | super(InstanceNorm, self).__init__()
208 | self.epsilon = epsilon
209 |
210 | def forward(self, x):
211 | x = x - torch.mean(x, (2, 3), True)
212 | tmp = torch.mul(x, x) # or x ** 2
213 | tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
214 | return x * tmp
215 |
216 |
217 | class LayerEpilogue(nn.Module):
218 | def __init__(self,
219 | channels,
220 | dlatent_size,
221 | use_wscale,
222 | use_noise,
223 | use_pixel_norm,
224 | use_instance_norm,
225 | use_styles):
226 | super(LayerEpilogue, self).__init__()
227 |
228 | if use_noise:
229 | self.noise = ApplyNoise(channels)
230 | self.act = nn.LeakyReLU(negative_slope=0.2)
231 |
232 | if use_pixel_norm:
233 | self.pixel_norm = PixelNorm()
234 | else:
235 | self.pixel_norm = None
236 |
237 | if use_instance_norm:
238 | self.instance_norm = InstanceNorm()
239 | else:
240 | self.instance_norm = None
241 |
242 | if use_styles:
243 | self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale)
244 | else:
245 | self.style_mod = None
246 |
247 | def forward(self, x, noise, dlatents_in_slice=None):
248 | x = self.noise(x, noise)
249 | x = self.act(x)
250 | if self.pixel_norm is not None:
251 | x = self.pixel_norm(x)
252 | if self.instance_norm is not None:
253 | x = self.instance_norm(x)
254 | if self.style_mod is not None:
255 | x = self.style_mod(x, dlatents_in_slice)
256 |
257 | return x
258 |
259 |
260 | class GBlock(nn.Module):
261 | def __init__(self,
262 | res,
263 | use_wscale,
264 | use_noise,
265 | use_pixel_norm,
266 | use_instance_norm,
267 | noise_input, # noise
268 | dlatent_size=512, # Disentangled latent (W) dimensionality.
269 | use_style=True, # Enable style inputs?
270 | f=None, # (Huge overhead: if you don't have enough resources, please pass `f = None`.) Low-pass filter to apply when resampling activations. None = no filtering.
271 | factor=2, # upsample factor.
272 | fmap_base=8192, # Overall multiplier for the number of feature maps.
273 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution.
274 | fmap_max=512, # Maximum number of feature maps in any layer.
275 | ):
276 | super(GBlock, self).__init__()
277 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
278 |
279 | # res
280 | self.res = res
281 |
282 | # blur2d
283 | self.blur = Blur2d(f)
284 |
285 | # noise
286 | self.noise_input = noise_input
287 |
288 | if res < 7:
289 | # upsample method 1
290 | self.up_sample = Upscale2d(factor)
291 | else:
292 | # upsample method 2
293 | self.up_sample = nn.ConvTranspose2d(self.nf(res-3), self.nf(res-2), 4, stride=2, padding=1)
294 |
295 | # A Composition of LayerEpilogue and Conv2d.
296 | self.adaIn1 = LayerEpilogue(self.nf(res-2), dlatent_size, use_wscale, use_noise,
297 | use_pixel_norm, use_instance_norm, use_style)
298 | self.conv1 = Conv2d(input_channels=self.nf(res-2), output_channels=self.nf(res-2),
299 | kernel_size=3, use_wscale=use_wscale)
300 | self.adaIn2 = LayerEpilogue(self.nf(res-2), dlatent_size, use_wscale, use_noise,
301 | use_pixel_norm, use_instance_norm, use_style)
302 |
303 | def forward(self, x, dlatent):
304 | x = self.up_sample(x)
305 | x = self.adaIn1(x, self.noise_input[self.res*2-4], dlatent[:, self.res*2-4])
306 | x = self.conv1(x)
307 | x = self.adaIn2(x, self.noise_input[self.res*2-3], dlatent[:, self.res*2-3])
308 | return x
309 |
310 | #model.apply(weights_init)
311 |
312 |
313 | # =========================================================================
314 | # Define sub-network
315 | # 2019.3.31
316 | # FC
317 | # =========================================================================
318 | class G_mapping(nn.Module):
319 | def __init__(self,
320 | mapping_fmaps=512,
321 | dlatent_size=512,
322 | resolution=1024,
323 | normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
324 | use_wscale=True, # Enable equalized learning rate?
325 | lrmul=0.01, # Learning rate multiplier for the mapping layers.
326 | gain=2**(0.5) # original gain in tensorflow.
327 | ):
328 | super(G_mapping, self).__init__()
329 | self.mapping_fmaps = mapping_fmaps
330 | self.func = nn.Sequential(
331 | FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
332 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
333 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
334 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
335 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
336 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
337 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
338 | FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale)
339 | )
340 |
341 | self.normalize_latents = normalize_latents
342 | self.resolution_log2 = int(np.log2(resolution))
343 | self.num_layers = self.resolution_log2 * 2 - 2
344 | self.pixel_norm = PixelNorm()
345 | # "- 2" because we start from a feature map whose height and width equal 4.
346 | # for the default resolution (1024), this gives num_layers = 18.
347 |
348 | def forward(self, x):
349 | if self.normalize_latents:
350 | x = self.pixel_norm(x)
351 | out = self.func(x)
352 | return out, self.num_layers
353 |
354 |
355 | class G_synthesis(nn.Module):
356 | def __init__(self,
357 | dlatent_size, # Disentangled latent (W) dimensionality.
358 | resolution=1024, # Output resolution (1024 x 1024 by default).
359 | fmap_base=8192, # Overall multiplier for the number of feature maps.
360 | num_channels=3, # Number of output color channels.
361 | structure='fixed', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
362 | fmap_max=512, # Maximum number of feature maps in any layer.
363 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution.
364 | f=None, # (Huge overhead: if you don't have enough resources, please pass `f = None`.) Low-pass filter to apply when resampling activations. None = no filtering.
365 | use_pixel_norm = False, # Enable pixelwise feature vector normalization?
366 | use_instance_norm = True, # Enable instance normalization?
367 | use_wscale = True, # Enable equalized learning rate?
368 | use_noise = True, # Enable noise inputs?
369 | use_style = True # Enable style inputs?
370 | ): # batch size.
371 | """
372 | 2019.3.31
373 | :param dlatent_size: 512 Disentangled latent(W) dimensionality.
374 | :param resolution: 1024 x 1024.
375 | :param fmap_base:
376 | :param num_channels:
377 | :param structure: only support 'fixed' mode.
378 | :param fmap_max:
379 | """
380 | super(G_synthesis, self).__init__()
381 |
382 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
383 | self.structure = structure
384 | self.resolution_log2 = int(np.log2(resolution))
385 | # "- 2" because we start from a feature map whose height and width equal 4.
386 | # for the default resolution (1024), this gives num_layers = 18.
387 | num_layers = self.resolution_log2 * 2 - 2
388 | self.num_layers = num_layers
389 |
390 | # Noise inputs.
391 | self.noise_inputs = []
392 | for layer_idx in range(num_layers):
393 | res = layer_idx // 2 + 2
394 | shape = [1, 1, 2 ** res, 2 ** res]
395 | self.noise_inputs.append(torch.randn(*shape).to("cuda"))
396 |
397 | # Blur2d
398 | self.blur = Blur2d(f)
399 |
400 | # torgb: fixed mode
401 | self.channel_shrinkage = Conv2d(input_channels=self.nf(self.resolution_log2-2),
402 | output_channels=self.nf(self.resolution_log2),
403 | kernel_size=3,
404 | use_wscale=use_wscale)
405 | self.torgb = Conv2d(self.nf(self.resolution_log2), num_channels, kernel_size=1, gain=1, use_wscale=use_wscale)
406 |
407 | # Initial Input Block
408 | self.const_input = nn.Parameter(torch.ones(1, self.nf(1), 4, 4))
409 | self.bias = nn.Parameter(torch.ones(self.nf(1)))
410 | self.adaIn1 = LayerEpilogue(self.nf(1), dlatent_size, use_wscale, use_noise,
411 | use_pixel_norm, use_instance_norm, use_style)
412 | self.conv1 = Conv2d(input_channels=self.nf(1), output_channels=self.nf(1), kernel_size=3, use_wscale=use_wscale)
413 | self.adaIn2 = LayerEpilogue(self.nf(1), dlatent_size, use_wscale, use_noise, use_pixel_norm,
414 | use_instance_norm, use_style)
415 |
416 | # Common Block
417 | # 4 x 4 -> 8 x 8
418 | res = 3
419 | self.GBlock1 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
420 | self.noise_inputs)
421 |
422 | # 8 x 8 -> 16 x 16
423 | res = 4
424 | self.GBlock2 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
425 | self.noise_inputs)
426 |
427 | # 16 x 16 -> 32 x 32
428 | res = 5
429 | self.GBlock3 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
430 | self.noise_inputs)
431 |
432 | # 32 x 32 -> 64 x 64
433 | res = 6
434 | self.GBlock4 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
435 | self.noise_inputs)
436 |
437 | # 64 x 64 -> 128 x 128
438 | res = 7
439 | self.GBlock5 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
440 | self.noise_inputs)
441 |
442 | # 128 x 128 -> 256 x 256
443 | res = 8
444 | self.GBlock6 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
445 | self.noise_inputs)
446 |
447 | # 256 x 256 -> 512 x 512
448 | res = 9
449 | self.GBlock7 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
450 | self.noise_inputs)
451 |
452 | # 512 x 512 -> 1024 x 1024
453 | res = 10
454 | self.GBlock8 = GBlock(res, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
455 | self.noise_inputs)
456 |
457 | def forward(self, dlatent):
458 | """
459 | dlatent: Disentangled latents (W), with shape [minibatch, num_layers, dlatent_size].
460 | :param dlatent:
461 | :return:
462 | """
463 | images_out = None
464 | # Fixed structure: simple and efficient, but does not support progressive growing.
465 | if self.structure == 'fixed':
466 | # initial block 0:
467 | x = self.const_input.expand(dlatent.size(0), -1, -1, -1)
468 | x = x + self.bias.view(1, -1, 1, 1)
469 | x = self.adaIn1(x, self.noise_inputs[0], dlatent[:, 0])
470 | x = self.conv1(x)
471 | x = self.adaIn2(x, self.noise_inputs[1], dlatent[:, 1])
472 |
473 | # block 1:
474 | # 4 x 4 -> 8 x 8
475 | x = self.GBlock1(x, dlatent)
476 |
477 | # block 2:
478 | # 8 x 8 -> 16 x 16
479 | x = self.GBlock2(x, dlatent)
480 |
481 | # block 3:
482 | # 16 x 16 -> 32 x 32
483 | x = self.GBlock3(x, dlatent)
484 |
485 | # block 4:
486 | # 32 x 32 -> 64 x 64
487 | x = self.GBlock4(x, dlatent)
488 |
489 | # block 5:
490 | # 64 x 64 -> 128 x 128
491 | x = self.GBlock5(x, dlatent)
492 |
493 | # block 6:
494 | # 128 x 128 -> 256 x 256
495 | x = self.GBlock6(x, dlatent)
496 |
497 | # block 7:
498 | # 256 x 256 -> 512 x 512
499 | x = self.GBlock7(x, dlatent)
500 |
501 | # block 8:
502 | # 512 x 512 -> 1024 x 1024
503 | x = self.GBlock8(x, dlatent)
504 |
505 | x = self.channel_shrinkage(x)
506 | images_out = self.torgb(x)
507 | return images_out
508 |
509 |
510 | class StyleGenerator(nn.Module):
511 | def __init__(self,
512 | mapping_fmaps=512,
513 | style_mixing_prob=0.9, # Probability of mixing styles during training. None = disable.
514 | truncation_psi=0.7, # Style strength multiplier for the truncation trick. None = disable.
515 | truncation_cutoff=8, # Number of layers for which to apply the truncation trick. None = disable.
516 | **kwargs
517 | ):
518 | super(StyleGenerator, self).__init__()
519 | self.mapping_fmaps = mapping_fmaps
520 | self.style_mixing_prob = style_mixing_prob
521 | self.truncation_psi = truncation_psi
522 | self.truncation_cutoff = truncation_cutoff
523 |
524 | self.mapping = G_mapping(self.mapping_fmaps, **kwargs)
525 | self.synthesis = G_synthesis(self.mapping_fmaps, **kwargs)
526 |
527 | def forward(self, latents1):
528 | dlatents1, num_layers = self.mapping(latents1)
529 | # let [N, O] -> [N, num_layers, O]
530 | # the unsqueeze here must not be done in-place; otherwise the backprop chain would be broken.
531 | dlatents1 = dlatents1.unsqueeze(1)
532 | dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
533 |
534 | # Add mixing style mechanism.
535 | # with torch.no_grad():
536 | # latents2 = torch.randn(latents1.shape).to(latents1.device)
537 | # dlatents2, num_layers = self.mapping(latents2)
538 | # dlatents2 = dlatents2.unsqueeze(1)
539 | # dlatents2 = dlatents2.expand(-1, int(num_layers), -1)
540 | #
541 | # # TODO: original NvLABs produce a placeholder "lod", this mechanism was not added here.
542 | # cur_layers = num_layers
543 | # mix_layers = num_layers
544 | # if np.random.random() < self.style_mixing_prob:
545 | # mix_layers = np.random.randint(1, cur_layers)
546 | #
547 | # # NvLABs: dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)
548 | # for i in range(num_layers):
549 | # if i >= mix_layers:
550 | # dlatents1[:, i, :] = dlatents2[:, i, :]
551 |
552 | # Apply truncation trick.
553 | if self.truncation_psi and self.truncation_cutoff:
554 | coefs = np.ones([1, num_layers, 1], dtype=np.float32)
555 | for i in range(num_layers):
556 | if i < self.truncation_cutoff:
557 | coefs[:, i, :] *= self.truncation_psi
558 | """Linear interpolation.
559 | a + (b - a) * t (a = 0)
560 | reduce to
561 | b * t
562 | """
563 |
564 | dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
565 |
566 | img = self.synthesis(dlatents1)
567 | return img
568 |
569 |
570 | class StyleDiscriminator(nn.Module):
571 | def __init__(self,
572 | resolution=1024,
573 | fmap_base=8192,
574 | num_channels=3,
575 | structure='fixed', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, only support 'fixed' mode now.
576 | fmap_max=512,
577 | fmap_decay=1.0,
578 | # f=[1, 2, 1] # (Huge overhead: if you don't have enough resources, please pass `f = None`.) Low-pass filter to apply when resampling activations. None = no filtering.
579 | f=None # (Huge overhead: if you don't have enough resources, please pass `f = None`.) Low-pass filter to apply when resampling activations. None = no filtering.
580 | ):
581 | """
582 | Notice: we only support input images with height == width.
583 |
584 | if H or W >= 128, we use avgpooling2d to do feature map shrinkage.
585 | else: we use ordinary conv2d.
586 | """
587 | super().__init__()
588 | self.resolution_log2 = int(np.log2(resolution))
589 | assert resolution == 2 ** self.resolution_log2 and resolution >= 4
590 | self.nf = lambda stage: min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
591 | # fromrgb: fixed mode
592 | self.fromrgb = nn.Conv2d(num_channels, self.nf(self.resolution_log2-1), kernel_size=1)
593 | self.structure = structure
594 |
595 | # blur2d
596 | self.blur2d = Blur2d(f)
597 |
598 | # down_sample
599 | self.down1 = nn.AvgPool2d(2)
600 | self.down21 = nn.Conv2d(self.nf(self.resolution_log2-5), self.nf(self.resolution_log2-5), kernel_size=2, stride=2)
601 | self.down22 = nn.Conv2d(self.nf(self.resolution_log2-6), self.nf(self.resolution_log2-6), kernel_size=2, stride=2)
602 | self.down23 = nn.Conv2d(self.nf(self.resolution_log2-7), self.nf(self.resolution_log2-7), kernel_size=2, stride=2)
603 | self.down24 = nn.Conv2d(self.nf(self.resolution_log2-8), self.nf(self.resolution_log2-8), kernel_size=2, stride=2)
604 |
605 | # conv1: padding=same
606 | self.conv1 = nn.Conv2d(self.nf(self.resolution_log2-1), self.nf(self.resolution_log2-1), kernel_size=3, padding=(1, 1))
607 | self.conv2 = nn.Conv2d(self.nf(self.resolution_log2-1), self.nf(self.resolution_log2-2), kernel_size=3, padding=(1, 1))
608 | self.conv3 = nn.Conv2d(self.nf(self.resolution_log2-2), self.nf(self.resolution_log2-3), kernel_size=3, padding=(1, 1))
609 | self.conv4 = nn.Conv2d(self.nf(self.resolution_log2-3), self.nf(self.resolution_log2-4), kernel_size=3, padding=(1, 1))
610 | self.conv5 = nn.Conv2d(self.nf(self.resolution_log2-4), self.nf(self.resolution_log2-5), kernel_size=3, padding=(1, 1))
611 | self.conv6 = nn.Conv2d(self.nf(self.resolution_log2-5), self.nf(self.resolution_log2-6), kernel_size=3, padding=(1, 1))
612 | self.conv7 = nn.Conv2d(self.nf(self.resolution_log2-6), self.nf(self.resolution_log2-7), kernel_size=3, padding=(1, 1))
613 | self.conv8 = nn.Conv2d(self.nf(self.resolution_log2-7), self.nf(self.resolution_log2-8), kernel_size=3, padding=(1, 1))
614 |
615 | # calculate point:
616 | self.conv_last = nn.Conv2d(self.nf(self.resolution_log2-8), self.nf(1), kernel_size=3, padding=(1, 1))
617 | self.dense0 = nn.Linear(fmap_base, self.nf(0))
618 | self.dense1 = nn.Linear(self.nf(0), 1)
619 | self.sigmoid = nn.Sigmoid()
620 |
621 | def forward(self, input):
622 | if self.structure == 'fixed':
623 | x = F.leaky_relu(self.fromrgb(input), 0.2, inplace=True)
624 | # 1. 1024 x 1024 x nf(9)(16) -> 512 x 512
625 | res = self.resolution_log2
626 | x = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
627 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True)
628 |
629 | # 2. 512 x 512 -> 256 x 256
630 | res -= 1
631 | x = F.leaky_relu(self.conv2(x), 0.2, inplace=True)
632 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True)
633 |
634 | # 3. 256 x 256 -> 128 x 128
635 | res -= 1
636 | x = F.leaky_relu(self.conv3(x), 0.2, inplace=True)
637 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True)
638 |
639 | # 4. 128 x 128 -> 64 x 64
640 | res -= 1
641 | x = F.leaky_relu(self.conv4(x), 0.2, inplace=True)
642 | x = F.leaky_relu(self.down1(self.blur2d(x)), 0.2, inplace=True)
643 |
644 | # 5. 64 x 64 -> 32 x 32
645 | res -= 1
646 | x = F.leaky_relu(self.conv5(x), 0.2, inplace=True)
647 | x = F.leaky_relu(self.down21(self.blur2d(x)), 0.2, inplace=True)
648 |
649 | # 6. 32 x 32 -> 16 x 16
650 | res -= 1
651 | x = F.leaky_relu(self.conv6(x), 0.2, inplace=True)
652 | x = F.leaky_relu(self.down22(self.blur2d(x)), 0.2, inplace=True)
653 |
654 | # 7. 16 x 16 -> 8 x 8
655 | res -= 1
656 | x = F.leaky_relu(self.conv7(x), 0.2, inplace=True)
657 | x = F.leaky_relu(self.down23(self.blur2d(x)), 0.2, inplace=True)
658 |
659 | # 8. 8 x 8 -> 4 x 4
660 | res -= 1
661 | x = F.leaky_relu(self.conv8(x), 0.2, inplace=True)
662 | x = F.leaky_relu(self.down24(self.blur2d(x)), 0.2, inplace=True)
663 |
664 | # 9. 4 x 4 -> point
665 | x = F.leaky_relu(self.conv_last(x), 0.2, inplace=True)
666 | # N x 8192(4 x 4 x nf(1)).
667 | x = x.view(x.size(0), -1)
668 | x = F.leaky_relu(self.dense0(x), 0.2, inplace=True)
669 | # N x 1
670 | x = F.leaky_relu(self.dense1(x), 0.2, inplace=True)
671 | return x
672 |
673 |
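# ---------------------------------------------------------------------------
# Editorial generation sketch (not part of the original file): StyleGenerator
# maps 512-d latents Z through G_mapping to W, broadcasts W across the 18
# synthesis layers, applies the truncation trick, and feeds the result to
# G_synthesis. A CUDA device is assumed because G_synthesis allocates its
# per-layer noise inputs on "cuda".
if __name__ == "__main__" and torch.cuda.is_available():
    G = StyleGenerator().cuda()
    z = torch.randn(2, 512).cuda()
    img = G(z)                               # -> [2, 3, 1024, 1024] with the defaults
    print(img.shape)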
--------------------------------------------------------------------------------
/opts/opts.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 | import argparse
6 | import torch
7 | import os
8 |
9 |
10 | def INFO(inputs):
11 | print("[ Style GAN ] %s" % (inputs))
12 |
13 |
14 | def presentParameters(args_dict):
15 | """
16 | Print the parameters setting line by line
17 |
18 | Arg: args_dict - The dict object which is transferred from argparse Namespace object
19 | """
20 | INFO("========== Parameters ==========")
21 | for key in sorted(args_dict.keys()):
22 | INFO("{:>15} : {}".format(key, args_dict[key]))
23 | INFO("===============================")
24 |
25 |
26 | class TrainOptions():
27 | def __init__(self):
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--path', type=str, default='./star/')
30 | parser.add_argument('--epoch', type=int, default=500)
31 | parser.add_argument('--batch_size', type=int, default=2)
32 | parser.add_argument('--type', type=str, default='style')
33 | parser.add_argument('--resume', type=str, default='train_result/models/latest.pth')
34 | parser.add_argument('--det', type=str, default='train_result')
35 | parser.add_argument('--r1_gamma', type=float, default=10.0)
36 | parser.add_argument('--r2_gamma', type=float, default=0.0)
37 | self.opts = parser.parse_args()
38 |
39 | def parse(self):
40 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu'
41 |
42 | # Check if the parameter is valid
43 | if self.opts.type not in ['style', 'origin']:
44 | raise Exception(
45 | "Unknown type: {} You should assign one of them ['style', 'origin']...".format(self.opts.type))
46 |
47 | # Create the destination folder
48 | if not os.path.exists(self.opts.det):
49 | os.mkdir(self.opts.det)
50 | if not os.path.exists(os.path.join(self.opts.det, 'images')):
51 | os.mkdir(os.path.join(self.opts.det, 'images'))
52 | if not os.path.exists(os.path.join(self.opts.det, 'models')):
53 | os.mkdir(os.path.join(self.opts.det, 'models'))
54 |
55 | # Print the options
56 | presentParameters(vars(self.opts))
57 | return self.opts
58 |
59 |
60 | class InferenceOptions():
61 | def __init__(self):
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--resume', type=str, default='train_result/model/latest.pth')
64 | parser.add_argument('--type', type=str, default='style')
65 | parser.add_argument('--num_face', type=int, default=32)
66 | parser.add_argument('--det', type=str, default='result.png')
67 | self.opts = parser.parse_args()
68 |
69 | def parse(self):
70 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu'
71 |
72 | # Print the options
73 | presentParameters(vars(self.opts))
74 | return self.opts
75 |
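# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file): train_stylegan.py is
# expected to consume the options roughly like this; parse() also creates the
# <det>/images and <det>/models folders and picks the device automatically.
if __name__ == "__main__":
    opts = TrainOptions().parse()            # e.g. python3 opts.py --path ./star/ --batch_size 2
    print(opts.device, opts.det)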
--------------------------------------------------------------------------------
/star/00793.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00793.png
--------------------------------------------------------------------------------
/star/00794.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00794.png
--------------------------------------------------------------------------------
/star/00795.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00795.png
--------------------------------------------------------------------------------
/star/00796.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00796.png
--------------------------------------------------------------------------------
/star/00797.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00797.png
--------------------------------------------------------------------------------
/star/00798.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00798.png
--------------------------------------------------------------------------------
/star/00799.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00799.png
--------------------------------------------------------------------------------
/star/00800.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00800.png
--------------------------------------------------------------------------------
/star/00801.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00801.png
--------------------------------------------------------------------------------
/star/00802.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00802.png
--------------------------------------------------------------------------------
/star/00803.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00803.png
--------------------------------------------------------------------------------
/star/00804.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00804.png
--------------------------------------------------------------------------------
/star/00805.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00805.png
--------------------------------------------------------------------------------
/star/00806.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00806.png
--------------------------------------------------------------------------------
/star/00807.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00807.png
--------------------------------------------------------------------------------
/star/00808.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00808.png
--------------------------------------------------------------------------------
/star/00809.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00809.png
--------------------------------------------------------------------------------
/star/00810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00810.png
--------------------------------------------------------------------------------
/star/00811.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00811.png
--------------------------------------------------------------------------------
/star/00812.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00812.png
--------------------------------------------------------------------------------
/star/00813.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00813.png
--------------------------------------------------------------------------------
/star/00814.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00814.png
--------------------------------------------------------------------------------
/star/00815.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00815.png
--------------------------------------------------------------------------------
/star/00816.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00816.png
--------------------------------------------------------------------------------
/star/00817.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00817.png
--------------------------------------------------------------------------------
/star/00818.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00818.png
--------------------------------------------------------------------------------
/star/00819.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00819.png
--------------------------------------------------------------------------------
/star/00820.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00820.png
--------------------------------------------------------------------------------
/star/00821.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00821.png
--------------------------------------------------------------------------------
/star/00822.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00822.png
--------------------------------------------------------------------------------
/star/00823.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00823.png
--------------------------------------------------------------------------------
/star/00824.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00824.png
--------------------------------------------------------------------------------
/star/00825.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00825.png
--------------------------------------------------------------------------------
/star/00826.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00826.png
--------------------------------------------------------------------------------
/star/00827.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00827.png
--------------------------------------------------------------------------------
/star/00828.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00828.png
--------------------------------------------------------------------------------
/star/00829.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00829.png
--------------------------------------------------------------------------------
/star/00830.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00830.png
--------------------------------------------------------------------------------
/star/00831.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00831.png
--------------------------------------------------------------------------------
/star/00832.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/star/00832.png
--------------------------------------------------------------------------------
/torchvision_sunner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/torchvision_sunner/__init__.py
--------------------------------------------------------------------------------
/torchvision_sunner/constant.py:
--------------------------------------------------------------------------------
1 | """
2 | This script defines the constant which will be used in Torchvision_sunner package
3 |
4 | Author: SunnerLi
5 | """
6 |
7 | # Constant
8 | UNDER_SAMPLING = 0
9 | OVER_SAMPLING = 1
10 | BCHW2BHWC = 0
11 | BHWC2BCHW = 1
12 |
13 | # Categorical constant
14 | ONEHOT2INDEX = 'one_hot_to_index'
15 | INDEX2ONEHOT = 'index_to_one_hot'
16 | ONEHOT2COLOR = 'one_hot_to_color'
17 | COLOR2ONEHOT = 'color_to_one_hot'
18 | INDEX2COLOR = 'index_to_color'
19 | COLOR2INDEX = 'color_to_index'
20 |
21 | # Verbose flag
22 | verbose = True
--------------------------------------------------------------------------------
/torchvision_sunner/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This script defines the wrapper of the Torchvision_sunner.data
3 |
4 | Author: SunnerLi
5 | """
6 | from torch.utils.data import DataLoader
7 |
8 | from torchvision_sunner.data.image_dataset import ImageDataset
9 | from torchvision_sunner.data.video_dataset import VideoDataset
10 | from torchvision_sunner.data.loader import *
11 | from torchvision_sunner.constant import *
12 | from torchvision_sunner.utils import *
--------------------------------------------------------------------------------
/torchvision_sunner/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.constant import *
2 | from torchvision_sunner.utils import INFO
3 | import torch.utils.data as Data
4 |
5 | import pickle
6 | import random
7 | import os
8 |
9 | """
10 | This script defines the parent class that handles some common functions for Dataset
11 |
12 | Author: SunnerLi
13 | """
14 |
15 | class BaseDataset(Data.Dataset):
16 | def __init__(self):
17 | self.save_file = False
18 | self.files = None
19 | self.split_files = None
20 |
21 | def generateIndexList(self, a, size):
22 | """
23 | Generate the list of index which will be picked
24 | This function will be used as train-test-split
25 |
26 | Arg: a - The list of images
27 | size - Int, the length of list you want to create
28 | Ret: The index list
29 | """
30 | result = set()
31 | while len(result) != size:
32 | result.add(random.randint(0, len(a) - 1))
33 | return list(result)
34 |
35 | def loadFromFile(self, file_name, check_type = 'image'):
36 | """
37 | Load the root and files information from .pkl record file
38 | This function will return False if the record file format is invalid
39 |
40 | Arg: file_name - The name of record file
41 | check_type - Str. The type of the record file you want to check
42 | Ret: Whether the loading procedure was successful or not
43 | """
44 | with open(file_name, 'rb') as f:
45 | obj = pickle.load(f)
46 | self.type = obj['type']
47 | if self.type == check_type:
48 | INFO("Load from file: {}".format(file_name))
49 | self.root = obj['root']
50 | self.files = obj['files']
51 | return True
52 | else:
53 | INFO("Record file type: {}\tFail to load...".format(self.type))
54 | INFO("Form the content from scratch...")
55 | return False
56 |
57 | def save(self, remain_file_name, split_ratio, split_file_name = ".split.pkl", save_type = 'image'):
58 | """
59 | Save the information into record file
60 |
61 | Arg: remain_file_name - The path of record file which store the information of remain data
62 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data
63 | split_file_name - The path of record file which store the information of split data
64 | save_type - Str. The type of the record file you want to save
65 | """
66 | if self.save_file:
67 | if not os.path.exists(remain_file_name):
68 | with open(remain_file_name, 'wb') as f:
69 | pickle.dump({
70 | 'type': save_type,
71 | 'root': self.root,
72 | 'files': self.files
73 | }, f)
74 | if split_ratio:
75 | INFO("Split the dataset, and save as {}".format(split_file_name))
76 | with open(split_file_name, 'wb') as f:
77 | pickle.dump({
78 | 'type': save_type,
79 | 'root': self.root,
80 | 'files': self.split_files
81 | }, f)
--------------------------------------------------------------------------------
/torchvision_sunner/data/image_dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.data.base_dataset import BaseDataset
2 | from torchvision_sunner.read import readContain, readItem
3 | from torchvision_sunner.constant import *
4 | from torchvision_sunner.utils import INFO
5 |
6 | from skimage import io as io
7 | # from PIL import Image
8 | from glob import glob
9 |
10 | import torch.utils.data as Data
11 |
12 | import pickle
13 | import math
14 | import os
15 |
16 | """
17 | This script defines the structure of the image dataset
18 |
19 | =======================================================================================
20 | In the new version, we accept a mixed form of images and folders:
21 | e.g. [[image1.jpg, image_folder]]
22 | On the other hand, the root can only be a 'list of lists'.
23 | You should use a double list to represent different image domains.
24 | For example:
25 | [[image1.jpg], [image2.jpg]] => valid
26 | [[image1.jpg], [image_folder]] => valid
27 | [[image1.jpg, image2.jpg], [image_folder1, image_folder2]] => valid
28 | [image1.jpg, image2.jpg] => invalid!
29 | Also, triple-nested lists are not allowed
30 | =======================================================================================
31 |
32 | Author: SunnerLi
33 | """
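# Editorial note: a hypothetical usage sketch following the rules above (the
# folder path is a placeholder, and transform can be a torchvision_sunner
# transforms.Compose object):
#
#   dataset = ImageDataset(root=[['./star/']], transform=None)
#   loader  = Data.DataLoader(dataset, batch_size=2, shuffle=True)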
34 |
35 | class ImageDataset(BaseDataset):
36 | def __init__(self, root = None, file_name = '.remain.pkl', sample_method = UNDER_SAMPLING, transform = None,
37 | split_ratio = 0.0, save_file = False):
38 | """
39 | The constructor of ImageDataset
40 |
41 | Arg: root - The list object. The image set
42 | file_name - The str. The name of record file.
43 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use under-sampling or over-sampling to deal with the data imbalance problem.
44 | (default is sunnerData.UNDER_SAMPLING)
45 | transform - transform.Compose object. You can declare some pre-processing for the images
46 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data
47 | save_file - Bool. Whether to store the record file or not. Default is False
48 | """
49 | super().__init__()
50 | # Record the parameter
51 | self.root = root
52 | self.file_name = file_name
53 | self.sample_method = sample_method
54 | self.transform = transform
55 | self.split_ratio = split_ratio
56 | self.save_file = save_file
57 | self.img_num = -1
58 | INFO()
59 |
60 | # Load the content from the record file if the record file exists
61 | if os.path.exists(file_name) and self.loadFromFile(file_name):
62 | self.getImgNum()
63 | elif not os.path.exists(file_name) and root is None:
64 | raise Exception("Record file {} not found. You should assign 'root' parameter!".format(file_name))
65 | else:
66 | # Extend the images of folder into domain list
67 | self.getFiles()
68 |
69 | # Change root obj as the index format
70 | self.root = range(len(self.root))
71 |
72 | # Adjust the image number
73 | self.getImgNum()
74 |
75 | # Split the files if split_ratio is more than 0.0
76 | self.split()
77 |
78 | # Save the split information
79 | self.save()
80 |
81 | # Print the domain information
82 | self.print()
83 |
84 | # ===========================================================================================
85 | # Define IO function
86 | # ===========================================================================================
87 | def loadFromFile(self, file_name):
88 | """
89 | Load the root and files information from .pkl record file
90 | This function will return False if the record file format is invalid
91 |
92 | Arg: file_name - The name of record file
93 | Ret: Whether the loading procedure is successful or not
94 | """
95 | return super().loadFromFile(file_name, 'image')
96 |
97 | def save(self, split_file_name = ".split.pkl"):
98 | """
99 | Save the information into record file
100 |
101 | Arg: split_file_name - The path of record file which store the information of split data
102 | """
103 | super().save(self.file_name, self.split_ratio, split_file_name, 'image')
104 |
105 | # ===========================================================================================
106 | # Define main function
107 | # ===========================================================================================
108 | def getFiles(self):
109 | """
110 | Construct the files object for the assigned root
111 | The user is allowed to mix folders with images
112 | This function extracts all the images in each folder
113 | so that every element in the files object becomes an image path
114 |
115 | ********************************************************
116 | * This function only works if the files object is None *
117 | ********************************************************
118 | """
119 | if not self.files:
120 | self.files = {}
121 | for domain_idx, domain in enumerate(self.root):
122 | images = []
123 | for img in domain:
124 | if os.path.exists(img):
125 | if os.path.isdir(img):
126 | images += readContain(img)
127 | else:
128 | images.append(img)
129 | else:
130 | raise Exception("The path {} does not exist".format(img))
131 | self.files[domain_idx] = sorted(images)
132 |
133 | def getImgNum(self):
134 | """
135 | Obtain the image number in the loader for the specific sample method
136 | The function will check if the folder has been extracted
137 | """
138 | if self.img_num == -1:
139 | # Check if the folder has been extracted
140 | for domain in self.root:
141 | for img in self.files[domain]:
142 | if os.path.isdir(img):
143 | raise Exception("You should extend the images in the folder {} first!".format(img))
144 |
145 | # Statistic the image number
146 | for domain in self.root:
147 | if domain == 0:
148 | self.img_num = len(self.files[domain])
149 | else:
150 | if self.sample_method == OVER_SAMPLING:
151 | self.img_num = max(self.img_num, len(self.files[domain]))
152 | elif self.sample_method == UNDER_SAMPLING:
153 | self.img_num = min(self.img_num, len(self.files[domain]))
154 | return self.img_num
155 |
156 | def split(self):
157 | """
158 | Split the files object into split_files object
159 | The original files object will shrink
160 |
161 | We also consider the case of paired images
162 | Thus we check whether the number of images in each domain is the same
163 | If so, we only generate the index list once
164 | """
165 | # Check if the number of image in different domain is the same
166 | if not self.files:
167 | self.getFiles()
168 | pairImage = True
169 | for domain in range(len(self.root) - 1):
170 | if len(self.files[domain]) != len(self.files[domain + 1]):
171 | pairImage = False
172 |
173 | # Split the files
174 | self.split_files = {}
175 | if pairImage:
176 | split_img_num = math.floor(len(self.files[0]) * self.split_ratio)
177 | choice_index_list = self.generateIndexList(range(len(self.files[0])), size = split_img_num)
178 | for domain in range(len(self.root)):
179 | # determine the index list
180 | if not pairImage:
181 | split_img_num = math.floor(len(self.files[domain]) * self.split_ratio)
182 | choice_index_list = self.generateIndexList(range(len(self.files[domain])), size = split_img_num)
183 | # remove the corresponding term and add into new list
184 | split_img_list = []
185 | remain_img_list = self.files[domain].copy()
186 | for j in choice_index_list:
187 | split_img_list.append(self.files[domain][j])
188 | for j in choice_index_list:
189 | self.files[domain].remove(remain_img_list[j])
190 | self.split_files[domain] = sorted(split_img_list)
191 |
192 | def print(self):
193 | """
194 | Print the information for each image domain
195 | """
196 | INFO()
197 | for domain in range(len(self.root)):
198 | INFO("domain index: %d \timage number: %d" % (domain, len(self.files[domain])))
199 | INFO()
200 |
201 | def __len__(self):
202 | return self.img_num
203 |
204 | def __getitem__(self, index):
205 | return_list = []
206 | for domain in self.root:
207 | img_path = self.files[domain][index]
208 | img = readItem(img_path)
209 | if self.transform:
210 | img = self.transform(img)
211 | return_list.append(img)
212 | return return_list
--------------------------------------------------------------------------------
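A minimal usage sketch for ImageDataset, mirroring the transform chain used in train.py; the domain paths ('domainA/', 'photo_b.png', 'domainB/') are placeholders rather than files shipped with the repository:

import torchvision.transforms as transforms
import torchvision_sunner.transforms as sunnertransforms
from torchvision_sunner.data.image_dataset import ImageDataset
from torchvision_sunner.data.loader import ImageLoader

dataset = ImageDataset(
    root=[['domainA/'], ['photo_b.png', 'domainB/']],        # double list: one inner list per domain
    transform=transforms.Compose([
        sunnertransforms.Resize((256, 256)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),                         # maps [0, 255] to [-1, 1]
    ]),
    split_ratio=0.1,                                          # hold out 10% into .split.pkl
    save_file=True,
)
loader = ImageLoader(dataset, batch_size=2, shuffle=True, num_workers=1)
for img_a, img_b in loader:                                   # one batch per domain, in root order
    pass
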
/torchvision_sunner/data/loader.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.constant import *
2 | from collections.abc import Iterator
3 | import torch.utils.data as data
4 |
5 | """
6 | This script define the extra loader, and it can be used in flexibility. The loaders include:
7 | 1. ImageLoader (The old version exist)
8 | 2. MultiLoader
9 | 3. IterationLoader
10 |
11 | Author: SunnerLi
12 | """
13 |
14 | class ImageLoader(data.DataLoader):
15 | def __init__(self, dataset, batch_size=1, shuffle=False, num_workers = 1):
16 | """
17 | The DataLoader object which can deal with ImageDataset object.
18 |
19 | Arg: dataset - ImageDataset. You should use sunnerData.ImageDataset to generate the instance first
20 | batch_size - Int.
21 | shuffle - Bool. Shuffle the data or not
22 | num_workers - Int. How many threads you want to use to read the batch data
23 | """
24 | super(ImageLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers = num_workers)
25 | self.dataset = dataset
26 | self.iter_num = self.__len__()
27 |
28 | def __len__(self):
29 | return round(self.dataset.img_num / self.batch_size)
30 |
31 | def getImageNumber(self):
32 | return self.dataset.img_num
33 |
34 | class MultiLoader(Iterator):
35 | def __init__(self, datasets, batch_size=1, shuffle=False, num_workers = 1):
36 | """
37 | This class can deal with multiple dataset object
38 |
39 | Arg: datasets - The list of ImageDataset.
40 | batch_size - Int.
41 | shuffle - Bool. Shuffle the data or not
42 | num_workers - Int. How many threads you want to use to read the batch data
43 | """
44 | # Create loaders
45 | self.loaders = []
46 | for dataset in datasets:
47 | self.loaders.append(
48 | data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers)
49 | )
50 |
51 | # Check the sample method
52 | self.sample_method = None
53 | for dataset in datasets:
54 | if self.sample_method is None:
55 | self.sample_method = dataset.sample_method
56 | else:
57 | if self.sample_method != dataset.sample_method:
58 | raise Exception("Sample methods are not consistant, {} <=> {}".format(
59 | self.sample_method, dataset.sample_method
60 | ))
61 |
62 | # Check the iteration number
63 | self.iter_num = 0
64 | for i, dataset in enumerate(datasets):
65 | if i == 0:
66 | self.iter_num = len(dataset)
67 | else:
68 | if self.sample_method == UNDER_SAMPLING:
69 | self.iter_num = min(self.iter_num, len(dataset))
70 | else:
71 | self.iter_num = max(self.iter_num, len(dataset))
72 | self.iter_num = round(self.iter_num / batch_size)
73 |
74 | def __len__(self):
75 | return self.iter_num
76 |
77 | def __iter__(self):
78 | self.iter_loaders = []
79 | for loader in self.loaders:
80 | self.iter_loaders.append(iter(loader))
81 | return self
82 |
83 | def __next__(self):
84 | result = []
85 | for loader in self.iter_loaders:
86 | for _ in loader.__next__():
87 | result.append(_)
88 | return tuple(result)
89 |
90 | class IterationLoader(Iterator):
91 | def __init__(self, loader, max_iter = 1):
92 | """
93 | Constructor of the loader with a specific number of iterations (not epochs)
94 | The underlying iterator is re-created when it reaches the end
95 |
96 | Arg: loader - The torch.data.DataLoader object
97 | max_iter - The maximum number of iterations
98 | """
99 | super().__init__()
100 | self.loader = loader
101 | self.loader_iter = iter(self.loader)
102 | self.iter = 0
103 | self.max_iter = max_iter
104 |
105 | def __next__(self):
106 | try:
107 | result_tuple = next(self.loader_iter)
108 | except StopIteration:
109 | self.loader_iter = iter(self.loader)
110 | result_tuple = next(self.loader_iter)
111 | self.iter += 1
112 | if self.iter <= self.max_iter:
113 | return result_tuple
114 | else:
115 | print("", end='')
116 | raise StopIteration()
117 |
118 | def __len__(self):
119 | return self.max_iter
--------------------------------------------------------------------------------
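A small sketch of how MultiLoader and IterationLoader can be combined; `dataset_a` and `dataset_b` are hypothetical ImageDataset instances that share the same sample_method:

from torchvision_sunner.data.loader import MultiLoader, IterationLoader

multi = MultiLoader([dataset_a, dataset_b], batch_size=4, shuffle=True, num_workers=2)
stepper = IterationLoader(multi, max_iter=1000)      # counts iterations instead of epochs
for batch in stepper:                                # the wrapped loader is re-created when exhausted
    img_a, img_b = batch                             # batches from every dataset are flattened into one tuple
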
/torchvision_sunner/data/video_dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.data.base_dataset import BaseDataset
2 | from torchvision_sunner.read import readContain, readItem
3 | from torchvision_sunner.constant import *
4 | from torchvision_sunner.utils import INFO
5 |
6 | import torch.utils.data as Data
7 |
8 | from PIL import Image
9 | from glob import glob
10 | import numpy as np
11 | import subprocess
12 | import random
13 | import pickle
14 | import torch
15 | import math
16 | import os
17 |
18 | """
19 | This script defines the structure of the video dataset
20 |
21 | =======================================================================================
22 | In the new version, we accept a root which mixes videos and folders:
23 | e.g. [[video1.mp4, video_folder]]
24 | On the other hand, the root can only be a 'list of lists'
25 | You should use a double list to represent the different video domains.
26 | For example:
27 | [[video1.mp4], [video2.mp4]] => valid
28 | [[video1.mp4], [video_folder]] => valid
29 | [[video1.mp4, video2.mp4], [video_folder1, video_folder2]] => valid
30 | [video1.mp4, video2.mp4] => invalid!
31 | Also, triple-nested lists are not allowed
32 | =======================================================================================
33 |
34 | Author: SunnerLi
35 | """
36 |
37 | class VideoDataset(BaseDataset):
38 | def __init__(self, root = None, file_name = '.remain.pkl', T = 10, sample_method = UNDER_SAMPLING, transform = None,
39 | split_ratio = 0.0, decode_root = './.decode', save_file = False):
40 | """
41 | The constructor of VideoDataset
42 |
43 | Arg: root - The list object. The image set
44 | file_name - Str. The name of record file.
45 | T - Int. The maximum length of each short video sequence
46 | sample_method - sunnerData.UNDER_SAMPLING or sunnerData.OVER_SAMPLING. Use under-sampling or over-sampling to deal with the data imbalance problem.
47 | (default is sunnerData.UNDER_SAMPLING)
48 | transform - transform.Compose object. You can declare some pre-processing for the images
49 | split_ratio - Float. The proportion to split the data. Usually used to split the testing data
50 | decode_root - Str. The path to store the ffmpeg decode result.
51 | save_file - Bool. Whether to store the record file or not. Default is False
52 | """
53 | super().__init__()
54 |
55 | # Record the parameter
56 | self.root = root
57 | self.file_name = file_name
58 | self.T = T
59 | self.sample_method = sample_method
60 | self.transform = transform
61 | self.split_ratio = split_ratio
62 | self.decode_root = decode_root
63 | self.video_num = -1
64 | self.split_root = None
65 | INFO()
66 |
67 | # Load the content from the record file if the record file exists
68 | if not os.path.exists(file_name) and root is None:
69 | raise Exception("Record file {} not found. You should assign 'root' parameter!".format(file_name))
70 | elif os.path.exists(file_name):
71 | INFO("Load from file: {}".format(file_name))
72 | self.loadFromFile(file_name)
73 |
74 | # Extend the images of folder into domain list
75 | self.extendFolder()
76 |
77 | # Split the image
78 | self.split()
79 |
80 | # Form the files obj
81 | self.getFiles()
82 |
83 | # Adjust the image number
84 | self.getVideoNum()
85 |
86 | # Save the split information
87 | self.save()
88 |
89 | # Print the domain information
90 | self.print()
91 |
92 | # ===========================================================================================
93 | # Define IO function
94 | # ===========================================================================================
95 | def loadFromFile(self, file_name):
96 | """
97 | Load the root and files information from .pkl record file
98 | This function will return False if the record file format is invalid
99 |
100 | Arg: file_name - The name of record file
101 | Ret: Whether the loading procedure is successful or not
102 | """
103 | return super().loadFromFile(file_name, 'video')
104 |
105 | def save(self, split_file_name = ".split.pkl"):
106 | """
107 | Save the information into record file
108 |
109 | Arg: split_file_name - The path of record file which store the information of split data
110 | """
111 | super().save(self.file_name, self.split_ratio, split_file_name, 'video')
112 |
113 | # ===========================================================================================
114 | # Define main function
115 | # ===========================================================================================
116 | def to_folder(self, name):
117 | """
118 | Transfer the name into the folder format
119 | e.g.
120 | '/home/Dataset/video1_folder' => 'home_Dataset_video1_folder'
121 | '/home/Dataset/video1.mp4' => 'home_Dataset_video1'
122 |
123 | Arg: name - Str. The path of file or original folder
124 | Ret: The new (encoded) folder name
125 | """
126 | if not os.path.isdir(name):
127 | name = '_'.join(name.split('.')[:-1])
128 | domain_list = name.split('/')
129 | while True:
130 | if '.' in domain_list:
131 | domain_list.remove('.')
132 | elif '..' in domain_list:
133 | domain_list.remove('..')
134 | else:
135 | break
136 | return '_'.join(domain_list)
137 |
138 | def extendFolder(self):
139 | """
140 | Extend the video folder in root obj
141 | """
142 | if not self.files:
143 | # Extend the folder of video and replace as new root obj
144 | extend_root = []
145 | for domain in self.root:
146 | videos = []
147 | for video in domain:
148 | if os.path.exists(video):
149 | if os.path.isdir(video):
150 | videos += readContain(video)
151 | else:
152 | videos.append(video)
153 | else:
154 | raise Exception("The path {} does not exist".format(video))
155 | extend_root.append(videos)
156 | self.root = extend_root
157 |
158 | def split(self):
159 | """
160 | Split the root object into split_root object
161 | The original root object will shrink
162 |
163 | We also consider the case of paired videos
164 | Thus we check whether the number of videos in each domain is the same
165 | If so, we only generate the index list once
166 | """
167 | # Check if the number of video in different domain is the same
168 | pairImage = True
169 | for domain_idx in range(len(self.root) - 1):
170 | if len(self.root[domain_idx]) != len(self.root[domain_idx + 1]):
171 | pairImage = False
172 |
173 | # Split the files
174 | self.split_root = []
175 | if pairImage:
176 | split_img_num = math.floor(len(self.root[0]) * self.split_ratio)
177 | choice_index_list = self.generateIndexList(range(len(self.root[0])), size = split_img_num)
178 | for domain_idx in range(len(self.root)):
179 | # determine the index list
180 | if not pairImage:
181 | split_img_num = math.floor(len(self.root[domain_idx]) * self.split_ratio)
182 | choice_index_list = self.generateIndexList(range(len(self.root[domain_idx])), size = split_img_num)
183 | # remove the corresponding term and add into new list
184 | split_img_list = []
185 | remain_img_list = self.root[domain_idx].copy()
186 | for j in choice_index_list:
187 | split_img_list.append(self.root[domain_idx][j])
188 | for j in choice_index_list:
189 | self.root[domain_idx].remove(remain_img_list[j])
190 | self.split_root.append(sorted(split_img_list))
191 |
192 | def getFiles(self):
193 | """
194 | Construct the files object for the assigned root
195 | The user is allowed to mix folders with videos
196 | This function collects all the videos in each folder
197 |
198 | However, unlike the setting in ImageDataset, we store the video paths in the root obj.
199 | Also, the image names will be stored in the files obj
200 |
201 | The following list the progress of this function:
202 | 1. check if we need to decode again
203 | 2. decode if needed
204 | 3. form the files obj
205 | """
206 | if not self.files:
207 | # Check if the decode process should be conducted again
208 | should_decode = not os.path.exists(self.decode_root)
209 | if not should_decode:
210 | for domain_idx, domain in enumerate(self.root):
211 | for video in domain:
212 | if not os.path.exists(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video))):
213 | should_decode = True
214 | break
215 |
216 | # Decode the video if needed
217 | if should_decode:
218 | INFO("Decode from scratch...")
219 | if os.path.exists(self.decode_root):
220 | subprocess.call(['rm', '-rf', self.decode_root])
221 | os.mkdir(self.decode_root)
222 | self.decodeVideo()
223 | else:
224 | INFO("Skip the decode process!")
225 |
226 | # Form the files object
227 | self.files = {}
228 | for domain_idx, domain in enumerate(os.listdir(self.decode_root)):
229 | self.files[domain_idx] = []
230 | for video in os.listdir(os.path.join(self.decode_root, domain)):
231 | self.files[domain_idx] += [
232 | sorted(glob(os.path.join(self.decode_root, domain, video, "*")))
233 | ]
234 |
235 | def decodeVideo(self):
236 | """
237 | Decode the single video into a series of images, and store into particular folder
238 | """
239 | for domain_idx, domain in enumerate(self.root):
240 | decode_domain_folder = os.path.join(self.decode_root, str(domain_idx))
241 | os.mkdir(decode_domain_folder)
242 | for video in domain:
243 | os.mkdir(os.path.join(self.decode_root, str(domain_idx), self.to_folder(video)))
244 | source = video  # 'video' already holds the full path of the clip
245 | target = os.path.join(decode_domain_folder, self.to_folder(video), "%5d.png")
246 | subprocess.call(['ffmpeg', '-i', source, target])
247 |
248 | def getVideoNum(self):
249 | """
250 | Obtain the video number in the loader for the specific sample method
251 | The function will check if the folder has been extracted
252 | """
253 | if self.video_num == -1:
254 | # Check if the folder has been extracted
255 | for domain in self.root:
256 | for video in domain:
257 | if os.path.isdir(video):
258 | raise Exception("You should extend the videos in the folder {} first!".format(video))
259 |
260 | # Statistic the image number
261 | for i, domain in enumerate(self.root):
262 | if i == 0:
263 | self.video_num = len(domain)
264 | else:
265 | if self.sample_method == OVER_SAMPLING:
266 | self.video_num = max(self.video_num, len(domain))
267 | elif self.sample_method == UNDER_SAMPLING:
268 | self.video_num = min(self.video_num, len(domain))
269 | return self.video_num
270 |
271 | def print(self):
272 | """
273 | Print the information for each image domain
274 | """
275 | INFO()
276 | for domain in range(len(self.root)):
277 | total_frame = 0
278 | for video in self.files[domain]:
279 | total_frame += len(video)
280 | INFO("domain index: %d \tvideo number: %d\tframe total: %d" % (domain, len(self.root[domain]), total_frame))
281 | INFO()
282 |
283 | def __len__(self):
284 | return self.video_num
285 |
286 | def __getitem__(self, index):
287 | """
288 | Return single batch of data, and the rank is BTCHW
289 | """
290 | result = []
291 | for domain_idx in range(len(self.root)):
292 |
293 | # Form the sequence in single domain
294 | film_sequence = []
295 | max_init_frame_idx = len(self.files[domain_idx][index]) - self.T
296 | start_pos = random.randint(0, max_init_frame_idx)
297 | for i in range(self.T):
298 | img_path = self.files[domain_idx][index][start_pos + i]
299 | img = readItem(img_path)
300 | film_sequence.append(img)
301 |
302 | # Transform the film sequence
303 | film_sequence = np.asarray(film_sequence)
304 | if self.transform:
305 | film_sequence = self.transform(film_sequence)
306 | result.append(film_sequence)
307 | return result
--------------------------------------------------------------------------------
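A usage sketch for VideoDataset; the paths 'videosA/' and 'clip_b.mp4' are placeholders, and ffmpeg must be available on the PATH since decodeVideo() shells out to it:

import torch.utils.data as Data
from torchvision_sunner.data.video_dataset import VideoDataset

dataset = VideoDataset(
    root=[['videosA/'], ['clip_b.mp4']],    # double list: one inner list per video domain
    T=8,                                    # sample clips of 8 consecutive frames
    decode_root='./.decode',                # where the ffmpeg frame dumps are cached
)
loader = Data.DataLoader(dataset, batch_size=1, shuffle=True)
for seq_a, seq_b in loader:                 # without a transform each sequence is a (B, T, H, W, C) batch
    pass
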
/torchvision_sunner/read.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from PIL import Image
3 | from glob import glob
4 | import numpy as np
5 | import os
6 |
7 | """
8 | This script defines the functions to read the contents of a folder and to read a single file
9 | You should customize them if your data type is not already handled by torchvision_sunner
10 |
11 | Author: SunnerLi
12 | """
13 |
14 | def readContain(folder_name):
15 | """
16 | Read the contents of the particular folder
17 |
18 | ==================================================================
19 | You should customize this function if your data is not considered
20 | ==================================================================
21 |
22 | Arg: folder_name - The path of folder
23 | Ret: The list of contents
24 | """
25 | # Check the common type in the folder
26 | common_type = Counter()
27 | for name in os.listdir(folder_name):
28 | common_type[name.split('.')[-1]] += 1
29 | common_type = common_type.most_common()[0][0]
30 |
31 | # Deal with the type
32 | if common_type == 'jpg':
33 | name_list = glob(os.path.join(folder_name, '*.jpg'))
34 | elif common_type == 'png':
35 | name_list = glob(os.path.join(folder_name, '*.png'))
36 | elif common_type == 'mp4':
37 | name_list = glob(os.path.join(folder_name, '*.mp4'))
38 | else:
39 | raise Exception("Unknown type {}, You should customize in read.py".format(common_type))
40 | return name_list
41 |
42 | def readItem(item_name):
43 | """
44 | Read the file for the given item name
45 |
46 | ==================================================================
47 | You should customize this function if your data is not considered
48 | ==================================================================
49 |
50 | Arg: item_name - The path of the file
51 | Ret: The item you read
52 | """
53 | file_type = item_name.split('.')[-1]
54 | if file_type == "png" or file_type == 'jpg':
55 | file_obj = np.asarray(Image.open(item_name))
56 |
57 | if len(file_obj.shape) == 3:
58 | # Drop the alpha channel (keep RGB only)
59 | file_obj = file_obj[:, :, :3]
60 | elif len(file_obj.shape) == 2:
61 | # Make the rank of gray-scale image as 3
62 | file_obj = np.expand_dims(file_obj, axis = -1)
63 | return file_obj
--------------------------------------------------------------------------------
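If your data uses an extension that readContain does not know about, the docstrings above suggest customizing read.py. A hedged sketch of a small wrapper (the '.tif' case is an assumption, not something the repository supports out of the box):

import os
from glob import glob
from torchvision_sunner.read import readContain

def readContainWithTif(folder_name):
    """List *.tif files first, otherwise fall back to the stock jpg/png/mp4 handling."""
    name_list = glob(os.path.join(folder_name, '*.tif'))
    if not name_list:
        name_list = readContain(folder_name)
    return name_list
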
/torchvision_sunner/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This script defines the wrapper of the Torchvision_sunner.transforms
3 |
4 | Author: SunnerLi
5 | """
6 | from torchvision_sunner.constant import *
7 | from torchvision_sunner.utils import *
8 |
9 | from torchvision_sunner.transforms.base import *
10 | from torchvision_sunner.transforms.simple import *
11 | from torchvision_sunner.transforms.complex import *
12 | from torchvision_sunner.transforms.categorical import *
13 | from torchvision_sunner.transforms.function import *
--------------------------------------------------------------------------------
/torchvision_sunner/transforms/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | """
5 | This script defines the parent class of the operations
6 |
7 | Author: SunnerLi
8 | """
9 |
10 | class OP():
11 | """
12 | The parent class of each operation
13 | The goal of this class is to adapt to different input formats
14 | """
15 | def work(self, tensor):
16 | """
17 | The virtual function to define the process in child class
18 |
19 | Arg: tensor - The np.ndarray object. The tensor you want to deal with
20 | """
21 | raise NotImplementedError("You should define your own function in the class!")
22 |
23 | def __call__(self, tensor):
24 | """
25 | This function defines how the operation proceeds
26 | The tensor parameter can take several different forms:
27 | 1. torch.Tensor and rank is CHW
28 | 2. np.ndarray and rank is CHW
29 | 3. torch.Tensor and rank is TCHW
30 | 4. np.ndarray and rank is TCHW
31 |
32 | Arg: tensor - The tensor you want to operate
33 | Ret: The operated tensor
34 | """
35 | isTensor = type(tensor) == torch.Tensor
36 | if isTensor:
37 | tensor_type = tensor.type()
38 | tensor = tensor.cpu().data.numpy()
39 | if len(tensor.shape) == 3:
40 | tensor = self.work(tensor)
41 | elif len(tensor.shape) == 4:
42 | tensor = np.asarray([self.work(_) for _ in tensor])
43 | else:
44 | raise Exception("We dont support the rank format {}".format(tensor.shape),
45 | "If the rank of the tensor shape is only 2, you can call 'GrayStack()'")
46 | if isTensor:
47 | tensor = torch.from_numpy(tensor)
48 | tensor = tensor.type(tensor_type)
49 | return tensor
--------------------------------------------------------------------------------
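The OP base class above dispatches work() over CHW or TCHW inputs and restores the torch tensor type afterwards, so a new operation only has to implement work() for a single CHW ndarray. A minimal sketch of a custom operation (HorizontalFlip is my own example, not part of the package):

import numpy as np
from torchvision_sunner.transforms.base import OP

class HorizontalFlip(OP):
    def work(self, tensor):
        # tensor is a CHW np.ndarray; OP.__call__ loops over the T axis for TCHW input
        # np.ascontiguousarray avoids negative strides, which torch.from_numpy cannot handle
        return np.ascontiguousarray(tensor[:, :, ::-1])
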
/torchvision_sunner/transforms/categorical.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.utils import INFO
2 | from torchvision_sunner.constant import *
3 | from torchvision_sunner.transforms.simple import Transpose
4 |
5 | from collections import OrderedDict
6 | from tqdm import tqdm
7 | import numpy as np
8 | import pickle
9 | import torch
10 | import json
11 | import os
12 |
13 | """
14 | This script defines the categorical-related operations, including:
15 | 1. getCategoricalMapping
16 | 2. CategoricalTranspose
17 |
18 | Author: SunnerLi
19 | """
20 |
21 | # ----------------------------------------------------------------------------------------
22 | # Define the IO function toward pallete file
23 | # ----------------------------------------------------------------------------------------
24 |
25 | def load_pallete(file_name):
26 | """
27 | Load the pallete object from file
28 |
29 | Arg: file_name - The name of pallete .json file
30 | Ret: The list of pallete object
31 | """
32 | # Load the list of dict from files (key is str)
33 | palletes_str_key = None
34 | with open(file_name, 'r') as f:
35 | palletes_str_key = json.load(f)
36 |
37 | # Change the key into color tuple
38 | palletes = [OrderedDict() for _ in range(len(palletes_str_key))]  # avoid aliasing the same dict
39 | for folder in range(len(palletes_str_key)):
40 | for key in palletes_str_key[folder].keys():
41 | tuple_key = list()
42 | for v in key.split('_'):
43 | tuple_key.append(int(v))
44 | palletes[folder][tuple(tuple_key)] = palletes_str_key[folder][key]
45 | return palletes
46 |
47 | def save_pallete(pallete, file_name):
48 | """
49 | Save the pallete object into the file
50 |
51 | Arg: pallete - The list of OrderDict objects
52 | file_name - The name of pallete .json file
53 | """
54 | # Change the key into str
55 | pallete_str_key = [dict() for _ in range(len(pallete))]  # avoid aliasing the same dict
56 | for folder in range(len(pallete)):
57 | for key in pallete[folder].keys():
58 |
59 | str_key = '_'.join([str(_) for _ in key])
60 | pallete_str_key[folder][str_key] = pallete[folder][key]
61 |
62 | # Save into file
63 | with open(file_name, 'w') as f:
64 | json.dump(pallete_str_key, f)
65 |
66 | # ----------------------------------------------------------------------------------------
67 | # Define the categorical-related operations
68 | # ----------------------------------------------------------------------------------------
69 |
70 | def getCategoricalMapping(loader = None, path = 'torchvision_sunner_categories_pallete.json'):
71 | """
72 | This function collects the statistics of the different categories (colors)
73 | and returns the list of mapping OrderedDict objects
74 |
75 | Arg: loader - The ImageLoader object
76 | path - The path of pallete file
77 | Ret: The list of OrderDict object (palletes object)
78 | """
79 | INFO("Applied << %15s >>" % getCategoricalMapping.__name__)
80 | INFO("* Notice: the rank format of input tensor should be 'BHWC'")
81 | INFO("* Notice: The range of tensor should be in [0, 255]")
82 | if os.path.exists(path):
83 | palletes = load_pallete(path)
84 | else:
85 | INFO(">> Load from scratch, please wait...")
86 |
87 | # Get the number of folder
88 | folder_num = 0
89 | for img_list in loader:
90 | folder_num = len(img_list)
91 | break
92 |
93 | # Initialize the pallete list
94 | palletes = [OrderedDict() for _ in range(folder_num)]   # avoid aliasing the same dict
95 | color_sets = [set() for _ in range(folder_num)]         # avoid aliasing the same set
96 |
97 | # Work
98 | for img_list in tqdm(loader):
99 | for folder_idx in range(folder_num):
100 | img = img_list[folder_idx]
101 | if torch.max(img) > 255 or torch.min(img) < 0:
102 | raise Exception('tensor value out of range...\t range is [' + str(torch.min(img)) + ' ~ ' + str(torch.max(img)) + ']')
103 | img = img.cpu().data.numpy().astype(np.uint8)
104 | img = np.reshape(img, [-1, 3])
105 | color_sets[folder_idx] |= set([tuple(_) for _ in img])
106 |
107 | # Merge the color
108 | for i in range(folder_num):
109 | for color in color_sets[i]:
110 | if color not in palletes[i].keys():
111 | palletes[i][color] = len(palletes[i])
112 | save_pallete(palletes, path)
113 |
114 | return palletes
115 |
116 | class CategoricalTranspose():
117 | def __init__(self, pallete = None, direction = COLOR2INDEX, index_default = 0):
118 | """
119 | Transform the tensor into the particular format
120 | We support 3 different kinds of formats:
121 | 1. one hot image
122 | 2. index image
123 | 3. color
124 |
125 | Arg: pallete - The pallete object (default is None)
126 | direction - The direction you want to change
127 | index_default - The default index if the color cannot be found in the pallete
128 | """
129 | self.pallete = pallete
130 | self.direction = direction
131 | self.index_default = index_default
132 | INFO("Applied << %15s >> , direction: %s" % (self.__class__.__name__, self.direction))
133 | INFO("* Notice: The range of tensor should be in [-1, 1]")
134 | INFO("* Notice: the rank format of input tensor should be 'BCHW'")
135 |
136 | def fn_color_to_index(self, tensor):
137 | """
138 | Transfer the tensor from the RGB colorful format into the index format
139 |
140 | Arg: tensor - The tensor obj. The tensor you want to deal with
141 | Ret: The tensor with index format
142 | """
143 | if self.pallete is None:
144 | raise Exception("The direction << %s >> need the pallete object" % self.direction)
145 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy()
146 | size_tuple = list(np.shape(tensor))
147 | tensor = (tensor * 127.5 + 127.5).astype(np.uint8)
148 | tensor = np.reshape(tensor, [-1, 3])
149 | tensor = [tuple(_) for _ in tensor]
150 | tensor = [self.pallete.get(_, self.index_default) for _ in tensor]
151 | tensor = np.asarray(tensor)
152 | size_tuple[-1] = 1
153 | tensor = np.reshape(tensor, size_tuple)
154 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3)
155 | return tensor
156 |
157 | def fn_index_to_one_hot(self, tensor):
158 | """
159 | Transfer the tensor from the index format into the one-hot format
160 |
161 | Arg: tensor - The tensor obj. The tensor you want to deal with
162 | Ret: The tensor with one-hot format
163 | """
164 | # Get the number of classes
165 | tensor = tensor.transpose(-3, -2).transpose(-2, -1)
166 | size_tuple = list(np.shape(tensor))
167 | tensor = tensor.view(-1).cpu().data.numpy()
168 | channel = np.amax(tensor) + 1
169 |
170 | # Get the total number of pixel
171 | num_of_pixel = 1
172 | for i in range(len(size_tuple) - 1):
173 | num_of_pixel *= size_tuple[i]
174 |
175 | # Transfer into one-hot format
176 | one_hot_tensor = np.zeros([num_of_pixel, channel])
177 | for i in range(channel):
178 | one_hot_tensor[tensor == i, i] = 1
179 |
180 | # Recover to origin rank format and shape
181 | size_tuple[-1] = channel
182 | tensor = np.reshape(one_hot_tensor, size_tuple)
183 | tensor = torch.from_numpy(tensor).transpose(-1, -2).transpose(-2, -3)
184 | return tensor
185 |
186 | def fn_one_hot_to_index(self, tensor):
187 | """
188 | Transfer the tensor from the one-hot format into the index format
189 |
190 | Arg: tensor - The tensor obj. The tensor you want to deal with
191 | Ret: The tensor with index format
192 | """
193 | _, tensor = torch.max(tensor, dim = 1)
194 | tensor = tensor.unsqueeze(1)
195 | return tensor
196 |
197 | def fn_index_to_color(self, tensor):
198 | """
199 | Transfer the tensor from the index format into the RGB colorful format
200 |
201 | Arg: tensor - The tensor obj. The tensor you want to deal with
202 | Ret: The tensor with RGB colorful format
203 | """
204 | if self.pallete is None:
205 | raise Exception("The direction << %s >> need the pallete object" % self.direction)
206 | tensor = tensor.transpose(-3, -2).transpose(-2, -1).cpu().data.numpy()
207 | reverse_pallete = {self.pallete[x]: x for x in self.pallete}
208 | batch, height, width, channel = np.shape(tensor)
209 | tensor = np.reshape(tensor, [-1])
210 | tensor = np.round(tensor, decimals=0)
211 | tensor = np.vectorize(reverse_pallete.get)(tensor)
212 | tensor = np.reshape(np.asarray(tensor).T, [batch, height, width, len(reverse_pallete[0])])
213 | tensor = torch.from_numpy((tensor - 127.5) / 127.5).transpose(-1, -2).transpose(-2, -3)
214 | return tensor
215 |
216 | def __call__(self, tensor):
217 | if self.direction == COLOR2INDEX:
218 | return self.fn_color_to_index(tensor)
219 | elif self.direction == INDEX2COLOR:
220 | return self.fn_index_to_color(tensor)
221 | elif self.direction == ONEHOT2INDEX:
222 | return self.fn_one_hot_to_index(tensor)
223 | elif self.direction == INDEX2ONEHOT:
224 | return self.fn_index_to_one_hot(tensor)
225 | elif self.direction == ONEHOT2COLOR:
226 | return self.fn_index_to_color(self.fn_one_hot_to_index(tensor))
227 | elif self.direction == COLOR2ONEHOT:
228 | return self.fn_index_to_one_hot(self.fn_color_to_index(tensor))
229 | else:
230 | raise Exception("Unknown direction: {}".format(self.direction))
--------------------------------------------------------------------------------
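A hedged sketch of how the two pieces above are usually combined: build (or load) the palette from a loader of label images, then transpose between color, index and one-hot formats. `label_loader` and `label_batch` are hypothetical placeholders:

from torchvision_sunner.transforms.categorical import getCategoricalMapping, CategoricalTranspose
from torchvision_sunner.constant import COLOR2INDEX, INDEX2ONEHOT

# label_loader yields BHWC label images in [0, 255]
palletes = getCategoricalMapping(loader=label_loader, path='torchvision_sunner_categories_pallete.json')
to_index = CategoricalTranspose(pallete=palletes[0], direction=COLOR2INDEX)
to_onehot = CategoricalTranspose(direction=INDEX2ONEHOT)

# label_batch is a BCHW tensor normalized to [-1, 1]
index_map = to_index(label_batch)        # B x 1 x H x W index map
onehot_map = to_onehot(index_map)        # B x num_classes x H x W one-hot map
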
/torchvision_sunner/transforms/complex.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.transforms.base import OP
2 | from torchvision_sunner.utils import INFO
3 | from skimage import transform
4 | import numpy as np
5 | import torch
6 |
7 | """
8 | This script define some complex operations
9 | These kind of operations should conduct work iteratively (with inherit OP class)
10 |
11 | Author: SunnerLi
12 | """
13 |
14 | class Resize(OP):
15 | def __init__(self, output_size):
16 | """
17 | Resize the tensor to the desired size
18 | This function only supports nearest-neighbor interpolation
19 | since this mechanism can also deal with categorical data
20 |
21 | Arg: output_size - The tuple (H, W)
22 | """
23 | self.output_size = output_size
24 | INFO("Applied << %15s >>" % self.__class__.__name__)
25 | INFO("* Notice: the rank format of input tensor should be 'BHWC'")
26 |
27 | def work(self, tensor):
28 | """
29 | Resize the tensor
30 | If the tensor is not in the range of [-1, 1], we will do the normalization automatically
31 |
32 | Arg: tensor - The np.ndarray object. The tensor you want to deal with
33 | Ret: The resized tensor
34 | """
35 | # Normalize the tensor if needed
36 | mean, std = -1, -1
37 | min_v = np.min(tensor)
38 | max_v = np.max(tensor)
39 | if not (max_v <= 1 and min_v >= -1):
40 | mean = 0.5 * max_v + 0.5 * min_v
41 | std = 0.5 * max_v - 0.5 * min_v
42 | # print(max_v, min_v, mean, std)
43 | tensor = (tensor - mean) / std
44 |
45 | # Work
46 | tensor = transform.resize(tensor, self.output_size, mode = 'constant', order = 0)
47 |
48 | # De-normalize the tensor
49 | if mean != -1 and std != -1:
50 | tensor = tensor * std + mean
51 | return tensor
52 |
53 | class Normalize(OP):
54 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]):
55 | """
56 | Normalize the tensor with given mean and standard deviation
57 | * Notice: If you didn't give mean and std, the result will locate in [-1, 1]
58 |
59 | Args:
60 | mean - The mean of the result tensor
61 | std - The standard deviation
62 | """
63 | self.mean = mean
64 | self.std = std
65 | INFO("Applied << %15s >>" % self.__class__.__name__)
66 | INFO("* Notice: the rank format of input tensor should be 'BCHW'")
67 | INFO("*****************************************************************")
68 | INFO("* Notice: You should must call 'ToFloat' before normalization")
69 | INFO("*****************************************************************")
70 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]:
71 | INFO("* Notice: The result will locate in [-1, 1]")
72 |
73 | def work(self, tensor):
74 | """
75 | Normalize the tensor
76 |
77 | Arg: tensor - The np.ndarray object. The tensor you want to deal with
78 | Ret: The normalized tensor
79 | """
80 | if tensor.shape[0] != len(self.mean):
81 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape))
82 | result = []
83 | for t, m, s in zip(tensor, self.mean, self.std):
84 | result.append((t - m) / s)
85 | tensor = np.asarray(result)
86 |
87 | # Check if the normalization can really work
88 | if np.min(tensor) < -1 or np.max(tensor) > 1:
89 | raise Exception("Normalize can only work with float tensor",
90 | "Try to call 'ToFloat()' before normalization")
91 | return tensor
92 |
93 | class UnNormalize(OP):
94 | def __init__(self, mean = [127.5, 127.5, 127.5], std = [127.5, 127.5, 127.5]):
95 | """
96 | Unnormalize the tensor with given mean and standard deviation
97 | * Notice: If you didn't give mean and std, the function will assume that the original distribution locates in [-1, 1]
98 |
99 | Args:
100 | mean - The mean of the result tensor
101 | std - The standard deviation
102 | """
103 | self.mean = mean
104 | self.std = std
105 | INFO("Applied << %15s >>" % self.__class__.__name__)
106 | INFO("* Notice: the rank format of input tensor should be 'BCHW'")
107 | if self.mean == [127.5, 127.5, 127.5] and self.std == [127.5, 127.5, 127.5]:
108 | INFO("* Notice: The function assume that the input range is [-1, 1]")
109 |
110 | def work(self, tensor):
111 | """
112 | Un-normalize the tensor
113 |
114 | Arg: tensor - The np.ndarray object. The tensor you want to deal with
115 | Ret: The un-normalized tensor
116 | """
117 | if tensor.shape[0] != len(self.mean):
118 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape))
119 | result = []
120 | for t, m, s in zip(tensor, self.mean, self.std):
121 | result.append(t * s + m)
122 | tensor = np.asarray(result)
123 | return tensor
124 |
125 | class ToGray(OP):
126 | def __init__(self):
127 | """
128 | Change the tensor into gray scale
129 | The function will turn the BCHW tensor into B1HW gray-scaled tensor
130 | """
131 | INFO("Applied << %15s >>" % self.__class__.__name__)
132 | INFO("* Notice: the rank format of input tensor should be 'BCHW'")
133 |
134 | def work(self, tensor):
135 | """
136 | Make the tensor into gray-scale
137 |
138 | Arg: tensor - The np.ndarray object. The tensor you want to deal with
139 | Ret: The gray-scale tensor, and the rank of the tensor is B1HW
140 | """
141 | if tensor.shape[0] == 3:
142 | result = 0.299 * tensor[0] + 0.587 * tensor[1] + 0.114 * tensor[2]
143 | result = np.expand_dims(result, axis = 0)
144 | elif len(tensor.shape) == 4:  # BCHW input
145 | result = 0.299 * tensor[:, 0] + 0.587 * tensor[:, 1] + 0.114 * tensor[:, 2]
146 | result = np.expand_dims(result, axis = 1)
147 | else:
148 | raise Exception("The rank format should be BCHW, but the shape is {}".format(tensor.shape))
149 | return result
--------------------------------------------------------------------------------
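A short sketch of the pre/post-processing round trip these operations are built for; the random image below stands in for real data:

import numpy as np
from torchvision_sunner.transforms.complex import Resize, Normalize, UnNormalize

resize = Resize((256, 256))                      # nearest-neighbour, works on HWC / BHWC ndarrays
normalize = Normalize()                          # defaults map [0, 255] to [-1, 1] on channel-first input
unnormalize = UnNormalize()

img = np.random.randint(0, 256, size=(512, 512, 3)).astype(np.float32)   # fake HWC image
img = resize(img)                                # still HWC, re-scaled back to its original range
chw = np.transpose(img, (2, 0, 1))               # Normalize expects channel-first input
out = normalize(chw)                             # now in [-1, 1]
restored = unnormalize(out)                      # back to roughly [0, 255]
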
/torchvision_sunner/transforms/function.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.transforms.simple import Transpose
2 | from torchvision_sunner.constant import BCHW2BHWC
3 |
4 | from skimage import transform
5 | from skimage import io
6 | import numpy as np
7 | import torch
8 |
9 | """
10 | This script defines the transform functions which can be called directly
11 |
12 | Author: SunnerLi
13 | """
14 |
15 | channel_op = None # Define the channel op which will be used in 'asImg' function
16 |
17 | def asImg(tensor, size = None):
18 | """
19 | This function provides a fast approach to transfer the image into a numpy.ndarray
20 | This function only accepts the output of a sigmoid or hyperbolic tangent layer
21 |
22 | Arg: tensor - The torch.Variable object, the rank format is BCHW or BHW
23 | size - The tuple object, and the format is (height, width)
24 | Ret: The numpy image, the rank format is BHWC
25 | """
26 | global channel_op
27 | result = tensor.detach()
28 |
29 | # 1. Judge the rank first
30 | if len(tensor.size()) == 3:
31 | result = torch.stack([result, result, result], 1)
32 |
33 | # 2. Judge the range of tensor (sigmoid output or hyperbolic tangent output)
34 | min_v = torch.min(result).cpu().data.numpy()
35 | max_v = torch.max(result).cpu().data.numpy()
36 | if max_v > 1.0 or min_v < -1.0:
37 | raise Exception('tensor value out of range...\t range is [' + str(min_v) + ' ~ ' + str(max_v) + ']')
38 | if min_v < 0:
39 | result = (result + 1) / 2
40 |
41 | # 3. Define the BCHW -> BHWC operation
42 | if channel_op is None:
43 | channel_op = Transpose(BCHW2BHWC)
44 |
45 | # 4. Rest of the conversion
46 | result = channel_op(result)
47 | result = result.cpu().data.numpy()
48 | if size is not None:
49 | result_list = []
50 | for img in result:
51 | result_list.append(transform.resize(img, (size[0], size[1]), mode = 'constant', order = 0) * 255)
52 | result = np.stack(result_list, axis = 0)
53 | else:
54 | result *= 255.
55 | result = result.astype(np.uint8)
56 | return result
--------------------------------------------------------------------------------
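A small sketch of asImg() on a fake tanh-range batch; the file name 'sample.png' is just a placeholder:

import torch
from skimage import io
from torchvision_sunner.transforms.function import asImg

fake_batch = torch.tanh(torch.randn(4, 3, 64, 64))    # stand-in for generator output in [-1, 1]
images = asImg(fake_batch, size=(128, 128))            # BHWC uint8 array, resized to 128x128
io.imsave('sample.png', images[0])
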
/torchvision_sunner/transforms/simple.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.utils import INFO
2 | from torchvision_sunner.constant import *
3 | import numpy as np
4 | import torch
5 |
6 | """
7 | This script defines some operations which are rather simple
8 | These operations only need to be called once (without inheriting the OP class)
9 |
10 | Author: SunnerLi
11 | """
12 |
13 | class ToTensor():
14 | def __init__(self):
15 | """
16 | Change the tensor into torch.Tensor type
17 | """
18 | INFO("Applied << %15s >>" % self.__class__.__name__)
19 |
20 | def __call__(self, tensor):
21 | """
22 | Arg: tensor - The torch.Tensor or other type. The tensor you want to deal with
23 | """
24 | if type(tensor) == np.ndarray:
25 | tensor = torch.from_numpy(tensor)
26 | return tensor
27 |
28 | class ToFloat():
29 | def __init__(self):
30 | """
31 | Change the tensor into torch.FloatTensor
32 | """
33 | INFO("Applied << %15s >>" % self.__class__.__name__)
34 |
35 | def __call__(self, tensor):
36 | """
37 | Arg: tensor - The torch.Tensor object. The tensor you want to deal with
38 | """
39 | return tensor.float()
40 |
41 | class Transpose():
42 | def __init__(self, direction = BHWC2BCHW):
43 | """
44 | Transfer the rank of tensor into target one
45 |
46 | Arg: direction - The direction you want to do the transpose
47 | """
48 | self.direction = direction
49 | if self.direction == BHWC2BCHW:
50 | INFO("Applied << %15s >>, The rank format is BCHW" % self.__class__.__name__)
51 | elif self.direction == BCHW2BHWC:
52 | INFO("Applied << %15s >>, The rank format is BHWC" % self.__class__.__name__)
53 | else:
54 | raise Exception("Unknown direction symbol: {}".format(self.direction))
55 |
56 | def __call__(self, tensor):
57 | """
58 | Arg: tensor - The torch.Tensor object. The tensor you want to deal with
59 | """
60 | if self.direction == BHWC2BCHW:
61 | tensor = tensor.transpose(-1, -2).transpose(-2, -3)
62 | else:
63 | tensor = tensor.transpose(-3, -2).transpose(-2, -1)
64 | return tensor
--------------------------------------------------------------------------------
/torchvision_sunner/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision_sunner.constant import *
2 |
3 | """
4 | This script defines the functions which are widely used in the whole package
5 |
6 | Author: SunnerLi
7 | """
8 |
9 | def quiet():
10 | """
11 | Mute the information toward the whole log in the toolkit
12 | """
13 | global verbose
14 | verbose = False
15 |
16 | def INFO(string = None):
17 | """
18 | Print the information with prefix
19 |
20 | Arg: string - The string you want to print
21 | """
22 | if verbose:
23 | if string:
24 | print("[ Torchvision_sunner ] %s" % (string))
25 | else:
26 | print("[ Torchvision_sunner ] " + '=' * 50)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 | import torchvision_sunner.transforms as sunnertransforms
6 | import torchvision_sunner.data as sunnerData
7 | import torchvision.transforms as transforms
8 |
9 |
10 | from networks_stylegan import StyleGenerator, StyleDiscriminator
11 | from networks_gan import Generator, Discriminator
12 | from utils.utils import plotLossCurve
13 | from loss.loss import gradient_penalty
14 | from opts.opts import TrainOptions
15 |
16 | from torchvision.utils import save_image
17 | from tqdm import tqdm
18 | from matplotlib import pyplot as plt
19 | import torch.optim as optim
20 | import numpy as np
21 | import torch
22 | import os
23 |
24 | # Hyper-parameters
25 | CRITIC_ITER = 5
26 |
27 |
28 | def main(opts):
29 | # Create the data loader
30 | loader = sunnerData.DataLoader(sunnerData.ImageDataset(
31 | root=[[opts.path]],
32 | transform=transforms.Compose([
33 | sunnertransforms.Resize((1024, 1024)),
34 | sunnertransforms.ToTensor(),
35 | sunnertransforms.ToFloat(),
36 | sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
37 | sunnertransforms.Normalize(),
38 | ])),
39 | batch_size=opts.batch_size,
40 | shuffle=True,
41 | )
42 |
43 | # Create the model
44 | G = StyleGenerator(bs=opts.batch_size).to(opts.device)
45 | D = StyleDiscriminator().to(opts.device)
46 |
47 | # G = Generator().to(opts.device)
48 | # D = Discriminator().to(opts.device)
49 |
50 | # Create the criterion, optimizer and scheduler
51 | optim_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
52 | optim_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
53 | scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
54 | scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)
55 |
56 | # Train
57 | fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
58 | Loss_D_list = [0.0]
59 | Loss_G_list = [0.0]
60 | for ep in range(opts.epoch):
61 | bar = tqdm(loader)
62 | loss_D_list = []
63 | loss_G_list = []
64 | for i, (real_img,) in enumerate(bar):
65 | # =======================================================================================================
66 | # Update discriminator
67 | # =======================================================================================================
68 | # Compute adversarial loss toward discriminator
69 | real_img = real_img.to(opts.device)
70 | real_logit = D(real_img)
71 | fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
72 | fake_logit = D(fake_img.detach())
73 | d_loss = -(real_logit.mean() - fake_logit.mean()) + gradient_penalty(real_img.data, fake_img.data, D) * 10.0
74 | loss_D_list.append(d_loss.item())
75 |
76 | # Update discriminator
77 | optim_D.zero_grad()
78 | d_loss.backward()
79 | optim_D.step()
80 |
81 | # =======================================================================================================
82 | # Update generator
83 | # =======================================================================================================
84 | if i % CRITIC_ITER == 0:
85 | # Compute adversarial loss toward generator
86 | fake_img = G(torch.randn([opts.batch_size, 512]).to(opts.device))
87 | fake_logit = D(fake_img)
88 | g_loss = -fake_logit.mean()
89 | loss_G_list.append(g_loss.item())
90 |
91 | # Update generator
92 | D.zero_grad()
93 | optim_G.zero_grad()
94 | g_loss.backward()
95 | optim_G.step()
96 | bar.set_description(" {} [G]: {} [D]: {}".format(ep, loss_G_list[-1], loss_D_list[-1]))
97 |
98 | # Save the result
99 | Loss_G_list.append(np.mean(loss_G_list))
100 | Loss_D_list.append(np.mean(loss_D_list))
101 | fake_img = G(fix_z)
102 | save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True)
103 | state = {
104 | 'G': G.state_dict(),
105 | 'D': D.state_dict(),
106 | 'Loss_G': Loss_G_list,
107 | 'Loss_D': Loss_D_list,
108 | }
109 | torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))
110 |
111 | scheduler_D.step()
112 | scheduler_G.step()
113 |
114 | # Plot the total loss curve
115 | Loss_D_list = Loss_D_list[1:]
116 | Loss_G_list = Loss_G_list[1:]
117 | plotLossCurve(opts, Loss_D_list, Loss_G_list)
118 |
119 |
120 | if __name__ == '__main__':
121 | opts = TrainOptions().parse()
122 | main(opts)
--------------------------------------------------------------------------------
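train.py above relies on gradient_penalty() from loss/loss.py, which is not shown in this excerpt. For reference, a generic WGAN-GP penalty is sketched below; this is my own illustration of the term and may differ from the repository's implementation:

import torch

def gradient_penalty_sketch(real, fake, D):
    # interpolate between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    inter = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    logit = D(inter)
    # gradient of the critic output w.r.t. the interpolated images
    grads = torch.autograd.grad(outputs=logit.sum(), inputs=inter, create_graph=True)[0]
    # penalize deviation of the gradient norm from 1
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
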
/train_stylegan.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 | import torchvision_sunner.transforms as sunnertransforms
6 | import torchvision_sunner.data as sunnerData
7 | import torchvision.transforms as transforms
8 |
9 |
10 | from networks_stylegan import StyleGenerator, StyleDiscriminator
11 | from networks_gan import Generator, Discriminator
12 | from utils.utils import plotLossCurve
13 | from loss.loss import gradient_penalty, R1Penalty, R2Penalty
14 | from opts.opts import TrainOptions, INFO
15 |
16 | from torchvision.utils import save_image
17 | from tqdm import tqdm
18 | from matplotlib import pyplot as plt
19 | from torch import nn
20 | import torch.optim as optim
21 | from torch.autograd import Variable
22 | import numpy as np
23 | import random
24 | import torch
25 | import os
26 |
27 | # Set random seed for reproducibility
28 | manualSeed = 999
29 | #manualSeed = random.randint(1, 10000) # use if you want new results
30 | print("Random Seed: ", manualSeed)
31 | random.seed(manualSeed)
32 | torch.manual_seed(manualSeed)
33 |
34 | # Hyper-parameters
35 | CRITIC_ITER = 5
36 |
37 |
38 | def main(opts):
39 | # Create the data loader
40 | loader = sunnerData.DataLoader(sunnerData.ImageDataset(
41 | root=[[opts.path]],
42 | transform=transforms.Compose([
43 | sunnertransforms.Resize((1024, 1024)),
44 | sunnertransforms.ToTensor(),
45 | sunnertransforms.ToFloat(),
46 | sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
47 | sunnertransforms.Normalize(),
48 | ])),
49 | batch_size=opts.batch_size,
50 | shuffle=True,
51 | )
52 |
53 | # Create the model
54 | start_epoch = 0
55 | G = StyleGenerator()
56 | D = StyleDiscriminator()
57 |
58 | # Load the pre-trained weight
59 | if os.path.exists(opts.resume):
60 | INFO("Load the pre-trained weight!")
61 | state = torch.load(opts.resume)
62 | G.load_state_dict(state['G'])
63 | D.load_state_dict(state['D'])
64 | start_epoch = state['start_epoch']
65 | else:
66 | INFO("Pre-trained weight cannot load successfully, train from scratch!")
67 |
68 | # Multi-GPU support
69 | if torch.cuda.device_count() > 1:
70 | INFO("Multiple GPU:" + str(torch.cuda.device_count()) + "\t GPUs")
71 | G = nn.DataParallel(G)
72 | D = nn.DataParallel(D)
73 | G.to(opts.device)
74 | D.to(opts.device)
75 |
76 | # Create the criterion, optimizer and scheduler
77 | optim_D = optim.Adam(D.parameters(), lr=0.00001, betas=(0.5, 0.999))
78 | optim_G = optim.Adam(G.parameters(), lr=0.00001, betas=(0.5, 0.999))
79 | scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
80 | scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)
81 |
82 | # Train
83 | fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
84 | softplus = nn.Softplus()
85 | Loss_D_list = [0.0]
86 | Loss_G_list = [0.0]
87 | for ep in range(start_epoch, opts.epoch):
88 | bar = tqdm(loader)
89 | loss_D_list = []
90 | loss_G_list = []
91 | for i, (real_img,) in enumerate(bar):
92 | # =======================================================================================================
93 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
94 | # =======================================================================================================
95 | # Compute adversarial loss toward discriminator
96 | D.zero_grad()
97 | real_img = real_img.to(opts.device)
98 | real_logit = D(real_img)
99 | fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
100 | fake_logit = D(fake_img.detach())
101 | d_loss = softplus(fake_logit).mean()
102 | d_loss = d_loss + softplus(-real_logit).mean()
103 |
104 | if opts.r1_gamma != 0.0:
105 | r1_penalty = R1Penalty(real_img.detach(), D)
106 | d_loss = d_loss + r1_penalty * (opts.r1_gamma * 0.5)
107 |
108 | if opts.r2_gamma != 0.0:
109 | r2_penalty = R2Penalty(fake_img.detach(), D)
110 | d_loss = d_loss + r2_penalty * (opts.r2_gamma * 0.5)
111 |
112 | loss_D_list.append(d_loss.item())
113 |
114 | # Update discriminator
115 | d_loss.backward()
116 | optim_D.step()
117 |
118 | # =======================================================================================================
119 | # (2) Update G network: maximize log(D(G(z)))
120 | # =======================================================================================================
121 | if i % CRITIC_ITER == 0:
122 | G.zero_grad()
123 | fake_logit = D(fake_img)
124 | g_loss = softplus(-fake_logit).mean()
125 | loss_G_list.append(g_loss.item())
126 |
127 | # Update generator
128 | g_loss.backward()
129 | optim_G.step()
130 |
131 | # Output training stats
132 | bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(ep, i+1, len(loader), loss_G_list[-1], loss_D_list[-1]))
133 |
134 | # Save the result
135 | Loss_G_list.append(np.mean(loss_G_list))
136 | Loss_D_list.append(np.mean(loss_D_list))
137 |
138 | # Check how the generator is doing by saving G's output on fixed_noise
139 | with torch.no_grad():
140 | fake_img = G(fix_z).detach().cpu()
141 | save_image(fake_img, os.path.join(opts.det, 'images', str(ep) + '.png'), nrow=4, normalize=True)
142 |
143 | # Save model
144 | state = {
145 | 'G': G.state_dict(),
146 | 'D': D.state_dict(),
147 | 'Loss_G': Loss_G_list,
148 | 'Loss_D': Loss_D_list,
149 | 'start_epoch': ep,
150 | }
151 | torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))
152 |
153 | scheduler_D.step()
154 | scheduler_G.step()
155 |
156 | # Plot the total loss curve
157 | Loss_D_list = Loss_D_list[1:]
158 | Loss_G_list = Loss_G_list[1:]
159 | plotLossCurve(opts, Loss_D_list, Loss_G_list)
160 |
161 |
162 | if __name__ == '__main__':
163 | opts = TrainOptions().parse()
164 | main(opts)
165 |
--------------------------------------------------------------------------------
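train_stylegan.py uses R1Penalty / R2Penalty from loss/loss.py, which is not included in this excerpt. As a reference, a generic R1 regularizer (squared gradient norm of the real logits w.r.t. the real images) is sketched below; it is my own illustration and may differ from the repository's code:

import torch

def r1_penalty_sketch(real_img, D):
    real_img = real_img.detach().requires_grad_(True)
    real_logit = D(real_img)
    grads = torch.autograd.grad(outputs=real_logit.sum(), inputs=real_img, create_graph=True)[0]
    # mean over the batch of the squared gradient norm
    return grads.view(grads.size(0), -1).pow(2).sum(dim=1).mean()
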
/utils/stylegan-teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomguluson92/StyleGAN_PyTorch/4fd0711f560b9b080fca3df2448822835226ba02/utils/stylegan-teaser.png
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | """
3 | @author: samuel ko
4 | """
5 |
6 | from matplotlib import pyplot as plt
7 | import os
8 |
9 | def plotLossCurve(opts, Loss_D_list, Loss_G_list):
10 | plt.figure()
11 | plt.plot(Loss_D_list, '-')
12 | plt.title("Loss curve (Discriminator)")
13 | plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_discriminator.png'))
14 |
15 | plt.figure()
16 | plt.plot(Loss_G_list, '-o')
17 | plt.title("Loss curve (Generator)")
18 | plt.savefig(os.path.join(opts.det, 'images', 'loss_curve_generator.png'))
--------------------------------------------------------------------------------