├── README.md
├── fig
│   ├── fid_lpips.png
│   ├── performance.png
│   ├── sketch.png
│   └── styleme.png
├── sketch_generation
│   ├── README.md
│   ├── benchmark.py
│   ├── evaluate.py
│   ├── metrics.py
│   ├── models.py
│   ├── operation.py
│   ├── requirement.txt
│   ├── train.py
│   ├── utils.py
│   ├── vgg-feature-weights.z01
│   ├── vgg-feature-weights.z02
│   └── vgg-feature-weights.zip
└── styleme
    ├── benchmark.py
    ├── calculate.py
    ├── config.py
    ├── datasets.py
    ├── framework.png
    ├── generate_matrix.py
    ├── lpips
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── __init__.cpython-38.pyc
    │   │   ├── base_model.cpython-37.pyc
    │   │   ├── base_model.cpython-38.pyc
    │   │   ├── dist_model.cpython-37.pyc
    │   │   ├── dist_model.cpython-38.pyc
    │   │   ├── networks_basic.cpython-37.pyc
    │   │   ├── networks_basic.cpython-38.pyc
    │   │   ├── pretrained_networks.cpython-37.pyc
    │   │   └── pretrained_networks.cpython-38.pyc
    │   ├── base_model.py
    │   ├── dist_model.py
    │   ├── networks_basic.py
    │   ├── pretrained_networks.py
    │   └── weights
    │       ├── v0.0
    │       │   ├── alex.pth
    │       │   ├── squeeze.pth
    │       │   └── vgg.pth
    │       └── v0.1
    │           ├── alex.pth
    │           ├── squeeze.pth
    │           └── vgg.pth
    ├── models.py
    ├── readme.md
    ├── style_transform.py
    ├── train.py
    ├── train_step_1.py
    ├── train_step_2.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # StyleMe: Towards Intelligent Fashion Generation with Designer Style
2 |
3 | Proceedings of the 2023 CHI Conference on Human Factors in Computing Systems (**CHI 2023**) | [**Paper**](https://dl.acm.org/doi/fullHtml/10.1145/3544548.3581377)
4 |
5 |
6 | Our model consists of the following two modules, and the dataset is available:
7 | - **image to sketch module** : [ [**sketch_generation**](https://github.com/ExponentiAI/StyleMe/tree/main/sketch_generation) ]
8 | - **sketch to image module** : [ [**style_transform**](https://github.com/ExponentiAI/StyleMe/tree/main/styleme) ]
9 | - **available dataset** : [ [**clothdataset**](https://drive.google.com/drive/folders/1tAHeblEon0Awb3QchTlLq9Knyc443i3x) ]
10 |
11 |
12 | ## 1. Video
13 |
14 |
15 |
16 |
17 |
18 | - Video link: **[StyleMe Demonstration](https://user-images.githubusercontent.com/43172916/218964923-1f99907c-4841-4cca-a961-fc771f22834f.mp4)**
19 |
20 |
21 | ## 2. Performance
22 | - Here is our model's performance:
23 |
24 | - Sketch Generation
25 |
26 |
27 |
28 |
29 | - Style Transfer
30 |
31 |
32 |
33 |
34 | - and the FID and LPIPS curves during training (a minimal FID sketch is given at the end of this section):
35 |
36 |
37 |
38 |
39 |
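The FID numbers above are computed from Inception-v3 feature statistics (see `sketch_generation/benchmark.py` and `sketch_generation/metrics.py`), and LPIPS is computed with the `lpips` package bundled under `styleme/lpips`. For reference, here is a minimal sketch of the FID computation (not the exact training-time code), assuming `feat_real` and `feat_fake` are `N x 2048` NumPy arrays of Inception features:

```python
import numpy as np
from scipy import linalg

def fid(feat_real, feat_fake):
    # fit a Gaussian to each feature set and compute the Frechet distance between them
    mu_r, mu_f = feat_real.mean(0), feat_fake.mean(0)
    cov_r = np.cov(feat_real, rowvar=False)
    cov_f = np.cov(feat_fake, rowvar=False)
    cov_sqrt, _ = linalg.sqrtm(cov_f @ cov_r, disp=False)
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real  # drop negligible imaginary parts from numerical error
    diff = mu_f - mu_r
    return diff @ diff + np.trace(cov_f) + np.trace(cov_r) - 2 * np.trace(cov_sqrt)
```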
40 | ## 3. Reference
41 |
42 | If you find our code or dataset useful for your research, please cite our paper.
43 |
44 | BibTex :
45 | ```
46 | @inproceedings{wu2023styleme,
47 | title={StyleMe: Towards Intelligent Fashion Generation with Designer Style},
48 | author={Wu, Di and Yu, Zhiwang and Ma, Nan and Jiang, Jianan and Wang, Yuetian and Zhou, Guixiang and Deng, Hanhui and Li, Yi},
49 | booktitle={Proceedings of the 2023 CHI Conference on Human Factors in Computing Systems},
50 | pages={1--16},
51 | year={2023}
52 | }
53 | ```
54 |
55 | Or :
56 | ```
57 | Di Wu, Zhiwang Yu, Nan Ma, Jianan Jiang, Yuetian Wang, Guixiang Zhou, Hanhui Deng, Yi Li: StyleMe: Towards Intelligent Fashion Generation with Designer Style. CHI 2023: 23:1-23:16
58 | ```
59 |
60 |
61 |
--------------------------------------------------------------------------------
/fig/fid_lpips.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/fid_lpips.png
--------------------------------------------------------------------------------
/fig/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/performance.png
--------------------------------------------------------------------------------
/fig/sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/sketch.png
--------------------------------------------------------------------------------
/fig/styleme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/styleme.png
--------------------------------------------------------------------------------
/sketch_generation/README.md:
--------------------------------------------------------------------------------
1 | # StyleMe - PyTorch
2 | A PyTorch implementation of the image-to-sketch model.
3 | Running environment: Python 3.7.0, PyTorch 1.12.1
4 | ## Data
5 | Includes RGB images and the corresponding sketch images of clothes in various styles.
6 |
7 | ## Description
8 | Related files:
9 | * models.py: structure definitions of all related models, including the generator and discriminator
10 | * train.py: trains the whole model
11 | * evaluate.py: tests the model (a minimal inference sketch is shown below)
12 | * vgg-feature-weights.pth: pretrained VGG feature weights
13 |
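A minimal single-image inference sketch, closely following `evaluate.py` (the checkpoint and image paths below are placeholders, not files shipped with the repository):

```python
import torch
import torchvision.utils as vutils
from PIL import Image

from models import Generator, VGGSimple
from operation import trans_maker_testing

device = torch.device("cuda:0")

# frozen VGG feature extractor (extract vgg-feature-weights.pth from the split zip archive first)
vgg = VGGSimple()
vgg.load_state_dict(torch.load('./vgg-feature-weights.pth', map_location='cpu'))
vgg.to(device).eval()

# sketch generator restored from a training checkpoint
net_g = Generator(infc=256, nfc=128)
ckpt = torch.load('./checkpoint.pth', map_location='cpu')  # placeholder path
net_g.load_state_dict(ckpt['g'])
net_g.to(device).eval()

# load one RGB clothing image, resize to 256x256 and rescale to [-1, 1]
img = trans_maker_testing(size=256)(Image.open('cloth.jpg').convert('RGB'))
img = img.unsqueeze(0).to(device)

with torch.no_grad():
    feat = vgg(img, base=8)[2]                            # 256-channel feature map fed to the generator
    g_img = net_g(feat)                                   # generated sketch in [-1, 1]
    g_img = g_img.mean(1, keepdim=True).add(1).mul(0.5)   # to grayscale in [0, 1]
    g_img = (g_img > 0.7).float()                         # binarize, as in evaluate.py
vutils.save_image(g_img, 'cloth_sketch.jpg')
```

For batch generation over a whole folder, run `evaluate.py` directly with `--path_content`, `--path_result`, `--im_size` and `--checkpoint`.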
--------------------------------------------------------------------------------
/sketch_generation/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from torchvision.models import inception_v3, Inception3
6 | from torchvision.utils import save_image
7 | from torchvision import utils as vutils
8 | from torch.utils.data import DataLoader
9 |
10 | try:
11 | from torchvision.models.utils import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 | import numpy as np
16 | from scipy import linalg
17 | from tqdm import tqdm
18 | import pickle
19 | import os
20 | from utils import true_randperm
21 |
22 | # Inception weights ported to Pytorch from
23 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
24 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
25 |
26 |
27 | class InceptionV3(nn.Module):
28 | """Pretrained InceptionV3 network returning feature maps"""
29 |
30 | # Index of default block of inception to return,
31 | # corresponds to output of final average pooling
32 | DEFAULT_BLOCK_INDEX = 3
33 |
34 | # Maps feature dimensionality to their output blocks indices
35 | BLOCK_INDEX_BY_DIM = {
36 | 64: 0, # First max pooling features
37 |         192: 1,   # Second max pooling features
38 | 768: 2, # Pre-aux classifier features
39 | 2048: 3 # Final average pooling features
40 | }
41 |
42 | def __init__(self,
43 | output_blocks=[DEFAULT_BLOCK_INDEX],
44 | resize_input=True,
45 | normalize_input=True,
46 | requires_grad=False,
47 | use_fid_inception=True):
48 | """Build pretrained InceptionV3
49 | Parameters
50 | ----------
51 | output_blocks : list of int
52 | Indices of blocks to return features of. Possible values are:
53 | - 0: corresponds to output of first max pooling
54 | - 1: corresponds to output of second max pooling
55 | - 2: corresponds to output which is fed to aux classifier
56 | - 3: corresponds to output of final average pooling
57 | resize_input : bool
58 | If true, bilinearly resizes input to width and height 299 before
59 | feeding input to model. As the network without fully connected
60 | layers is fully convolutional, it should be able to handle inputs
61 | of arbitrary size, so resizing might not be strictly needed
62 | normalize_input : bool
63 | If true, scales the input from range (0, 1) to the range the
64 | pretrained Inception network expects, namely (-1, 1)
65 | requires_grad : bool
66 | If true, parameters of the model require gradients. Possibly useful
67 | for finetuning the network
68 | use_fid_inception : bool
69 | If true, uses the pretrained Inception model used in Tensorflow's
70 | FID implementation. If false, uses the pretrained Inception model
71 | available in torchvision. The FID Inception model has different
72 | weights and a slightly different structure from torchvision's
73 | Inception model. If you want to compute FID scores, you are
74 | strongly advised to set this parameter to true to get comparable
75 | results.
76 | """
77 | super(InceptionV3, self).__init__()
78 |
79 | self.resize_input = resize_input
80 | self.normalize_input = normalize_input
81 | self.output_blocks = sorted(output_blocks)
82 | self.last_needed_block = max(output_blocks)
83 |
84 | assert self.last_needed_block <= 3, \
85 | 'Last possible output block index is 3'
86 |
87 | self.blocks = nn.ModuleList()
88 |
89 | if use_fid_inception:
90 | inception = fid_inception_v3()
91 | else:
92 | inception = models.inception_v3(pretrained=True)
93 |
94 | # Block 0: input to maxpool1
95 | block0 = [
96 | inception.Conv2d_1a_3x3,
97 | inception.Conv2d_2a_3x3,
98 | inception.Conv2d_2b_3x3,
99 | nn.MaxPool2d(kernel_size=3, stride=2)
100 | ]
101 | self.blocks.append(nn.Sequential(*block0))
102 |
103 | # Block 1: maxpool1 to maxpool2
104 | if self.last_needed_block >= 1:
105 | block1 = [
106 | inception.Conv2d_3b_1x1,
107 | inception.Conv2d_4a_3x3,
108 | nn.MaxPool2d(kernel_size=3, stride=2)
109 | ]
110 | self.blocks.append(nn.Sequential(*block1))
111 |
112 | # Block 2: maxpool2 to aux classifier
113 | if self.last_needed_block >= 2:
114 | block2 = [
115 | inception.Mixed_5b,
116 | inception.Mixed_5c,
117 | inception.Mixed_5d,
118 | inception.Mixed_6a,
119 | inception.Mixed_6b,
120 | inception.Mixed_6c,
121 | inception.Mixed_6d,
122 | inception.Mixed_6e,
123 | ]
124 | self.blocks.append(nn.Sequential(*block2))
125 |
126 | # Block 3: aux classifier to final avgpool
127 | if self.last_needed_block >= 3:
128 | block3 = [
129 | inception.Mixed_7a,
130 | inception.Mixed_7b,
131 | inception.Mixed_7c,
132 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
133 | ]
134 | self.blocks.append(nn.Sequential(*block3))
135 |
136 | for param in self.parameters():
137 | param.requires_grad = requires_grad
138 |
139 | def forward(self, inp):
140 | """Get Inception feature maps
141 | Parameters
142 | ----------
143 | inp : torch.autograd.Variable
144 | Input tensor of shape Bx3xHxW. Values are expected to be in
145 | range (0, 1)
146 | Returns
147 | -------
148 | List of torch.autograd.Variable, corresponding to the selected output
149 | block, sorted ascending by index
150 | """
151 | outp = []
152 | x = inp
153 |
154 | if self.resize_input:
155 | x = F.interpolate(x,
156 | size=(299, 299),
157 | mode='bilinear',
158 | align_corners=False)
159 |
160 | if self.normalize_input:
161 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
162 |
163 | for idx, block in enumerate(self.blocks):
164 | x = block(x)
165 | if idx in self.output_blocks:
166 | outp.append(x)
167 |
168 | if idx == self.last_needed_block:
169 | break
170 |
171 | return outp
172 |
173 |
174 | def fid_inception_v3():
175 | """Build pretrained Inception model for FID computation
176 | The Inception model for FID computation uses a different set of weights
177 | and has a slightly different structure than torchvision's Inception.
178 | This method first constructs torchvision's Inception and then patches the
179 | necessary parts that are different in the FID Inception model.
180 | """
181 | inception = models.inception_v3(num_classes=1008,
182 | aux_logits=False,
183 | pretrained=False)
184 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
185 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
186 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
187 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
188 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
189 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
190 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
191 | inception.Mixed_7b = FIDInceptionE_1(1280)
192 | inception.Mixed_7c = FIDInceptionE_2(2048)
193 |
194 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
195 | inception.load_state_dict(state_dict)
196 | return inception
197 |
198 |
199 | class FIDInceptionA(models.inception.InceptionA):
200 | """InceptionA block patched for FID computation"""
201 | def __init__(self, in_channels, pool_features):
202 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
203 |
204 | def forward(self, x):
205 | branch1x1 = self.branch1x1(x)
206 |
207 | branch5x5 = self.branch5x5_1(x)
208 | branch5x5 = self.branch5x5_2(branch5x5)
209 |
210 | branch3x3dbl = self.branch3x3dbl_1(x)
211 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
212 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
213 |
214 | # Patch: Tensorflow's average pool does not use the padded zero's in
215 | # its average calculation
216 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
217 | count_include_pad=False)
218 | branch_pool = self.branch_pool(branch_pool)
219 |
220 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
221 | return torch.cat(outputs, 1)
222 |
223 |
224 | class FIDInceptionC(models.inception.InceptionC):
225 | """InceptionC block patched for FID computation"""
226 | def __init__(self, in_channels, channels_7x7):
227 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
228 |
229 | def forward(self, x):
230 | branch1x1 = self.branch1x1(x)
231 |
232 | branch7x7 = self.branch7x7_1(x)
233 | branch7x7 = self.branch7x7_2(branch7x7)
234 | branch7x7 = self.branch7x7_3(branch7x7)
235 |
236 | branch7x7dbl = self.branch7x7dbl_1(x)
237 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
238 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
239 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
240 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
241 |
242 | # Patch: Tensorflow's average pool does not use the padded zero's in
243 | # its average calculation
244 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
245 | count_include_pad=False)
246 | branch_pool = self.branch_pool(branch_pool)
247 |
248 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
249 | return torch.cat(outputs, 1)
250 |
251 |
252 | class FIDInceptionE_1(models.inception.InceptionE):
253 | """First InceptionE block patched for FID computation"""
254 | def __init__(self, in_channels):
255 | super(FIDInceptionE_1, self).__init__(in_channels)
256 |
257 | def forward(self, x):
258 | branch1x1 = self.branch1x1(x)
259 |
260 | branch3x3 = self.branch3x3_1(x)
261 | branch3x3 = [
262 | self.branch3x3_2a(branch3x3),
263 | self.branch3x3_2b(branch3x3),
264 | ]
265 | branch3x3 = torch.cat(branch3x3, 1)
266 |
267 | branch3x3dbl = self.branch3x3dbl_1(x)
268 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
269 | branch3x3dbl = [
270 | self.branch3x3dbl_3a(branch3x3dbl),
271 | self.branch3x3dbl_3b(branch3x3dbl),
272 | ]
273 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
274 |
275 | # Patch: Tensorflow's average pool does not use the padded zero's in
276 | # its average calculation
277 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
278 | count_include_pad=False)
279 | branch_pool = self.branch_pool(branch_pool)
280 |
281 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
282 | return torch.cat(outputs, 1)
283 |
284 |
285 | class FIDInceptionE_2(models.inception.InceptionE):
286 | """Second InceptionE block patched for FID computation"""
287 | def __init__(self, in_channels):
288 | super(FIDInceptionE_2, self).__init__(in_channels)
289 |
290 | def forward(self, x):
291 | branch1x1 = self.branch1x1(x)
292 |
293 | branch3x3 = self.branch3x3_1(x)
294 | branch3x3 = [
295 | self.branch3x3_2a(branch3x3),
296 | self.branch3x3_2b(branch3x3),
297 | ]
298 | branch3x3 = torch.cat(branch3x3, 1)
299 |
300 | branch3x3dbl = self.branch3x3dbl_1(x)
301 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
302 | branch3x3dbl = [
303 | self.branch3x3dbl_3a(branch3x3dbl),
304 | self.branch3x3dbl_3b(branch3x3dbl),
305 | ]
306 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
307 |
308 | # Patch: The FID Inception model uses max pooling instead of average
309 | # pooling. This is likely an error in this specific Inception
310 | # implementation, as other Inception models use average pooling here
311 | # (which matches the description in the paper).
312 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
313 | branch_pool = self.branch_pool(branch_pool)
314 |
315 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
316 | return torch.cat(outputs, 1)
317 |
318 |
319 | class Inception3Feature(Inception3):
320 | def forward(self, x):
321 | if x.shape[2] != 299 or x.shape[3] != 299:
322 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
323 |
324 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
325 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
326 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
327 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
328 |
329 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
330 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
331 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
332 |
333 | x = self.Mixed_5b(x) # 35 x 35 x 192
334 | x = self.Mixed_5c(x) # 35 x 35 x 256
335 | x = self.Mixed_5d(x) # 35 x 35 x 288
336 |
337 | x = self.Mixed_6a(x) # 35 x 35 x 288
338 | x = self.Mixed_6b(x) # 17 x 17 x 768
339 | x = self.Mixed_6c(x) # 17 x 17 x 768
340 | x = self.Mixed_6d(x) # 17 x 17 x 768
341 | x = self.Mixed_6e(x) # 17 x 17 x 768
342 |
343 | x = self.Mixed_7a(x) # 17 x 17 x 768
344 | x = self.Mixed_7b(x) # 8 x 8 x 1280
345 | x = self.Mixed_7c(x) # 8 x 8 x 2048
346 |
347 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
348 |
349 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
350 |
351 |
352 | def load_patched_inception_v3():
353 | # inception = inception_v3(pretrained=True)
354 | # inception_feat = Inception3Feature()
355 | # inception_feat.load_state_dict(inception.state_dict())
356 | inception_feat = InceptionV3([3], normalize_input=False)
357 |
358 | return inception_feat
359 |
360 |
361 | @torch.no_grad()
362 | def extract_features(loader, inception, device):
363 | pbar = tqdm(loader)
364 |
365 | feature_list = []
366 |
367 | for img in pbar:
368 | img = img.to(device)
369 | feature = inception(img)[0].view(img.shape[0], -1)
370 | feature_list.append(feature.to('cpu'))
371 |
372 | features = torch.cat(feature_list, 0)
373 |
374 | return features
375 |
376 |
377 |
378 |
379 |
380 |
381 | @torch.no_grad()
382 | def extract_feature_from_generator_fn(generator_fn, inception, device='cuda', total=1000):
383 | features = []
384 |
385 | for batch in tqdm(generator_fn, total=total):
386 | try:
387 | feat = inception(batch)[0].view(batch.shape[0], -1)
388 | features.append(feat.to('cpu'))
389 | except:
390 | break
391 | features = torch.cat(features, 0).detach()
392 | return features.numpy()
393 |
394 |
395 | def calc_fid(sample_features, real_features=None, real_mean=None, real_cov=None, eps=1e-6):
396 | sample_mean = np.mean(sample_features, 0)
397 | sample_cov = np.cov(sample_features, rowvar=False)
398 |
399 | if real_features is not None:
400 | real_mean = np.mean(real_features, 0)
401 | real_cov = np.cov(real_features, rowvar=False)
402 |
403 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
404 |
405 | if not np.isfinite(cov_sqrt).all():
406 | print('product of cov matrices is singular')
407 | offset = np.eye(sample_cov.shape[0]) * eps
408 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
409 |
410 | if np.iscomplexobj(cov_sqrt):
411 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
412 | m = np.max(np.abs(cov_sqrt.imag))
413 |
414 | raise ValueError(f'Imaginary component {m}')
415 |
416 | cov_sqrt = cov_sqrt.real
417 |
418 | mean_diff = sample_mean - real_mean
419 | mean_norm = mean_diff @ mean_diff
420 |
421 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
422 |
423 | fid = mean_norm + trace
424 |
425 | return fid
426 |
427 | def real_image_loader(dataloader, n_batches=10):
428 | counter = 0
429 | while counter < n_batches:
430 | counter += 1
431 | rgb_img = next(dataloader)[0]
432 | if counter == 1:
433 | vutils.save_image(0.5*(rgb_img+1), 'tmp_real.jpg')
434 | yield rgb_img.cuda()
435 |
436 |
437 |
438 |
439 | @torch.no_grad()
440 | def image_generator(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500):
441 | counter = 0
442 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=4, pin_memory=False))
443 | n_batches = min( n_batches, len(dataset)//BATCH_SIZE-1 )
444 | while counter < n_batches:
445 | counter += 1
446 | rgb_img, _, _, skt_img = next(dataloader)
447 | rgb_img = F.interpolate( rgb_img, size=512 ).cuda()
448 | skt_img = F.interpolate( skt_img, size=512 ).cuda()
449 |
450 | gimg_ae, style_feat = net_ae(skt_img, rgb_img)
451 | g_image = net_ig(gimg_ae, style_feat)
452 | if counter == 1:
453 | vutils.save_image(0.5*(g_image+1), 'tmp.jpg')
454 | yield g_image
455 |
456 |
457 | @torch.no_grad()
458 | def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500):
459 | counter = 0
460 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False))
461 | n_batches = min( n_batches, len(dataset)//BATCH_SIZE-1 )
462 | while counter < n_batches:
463 | counter += 1
464 | rgb_img, _, _, skt_img = next(dataloader)
465 | rgb_img = F.interpolate( rgb_img, size=512 ).cuda()
466 | skt_img = F.interpolate( skt_img, size=512 ).cuda()
467 |
468 | perm = true_randperm(rgb_img.shape[0], device=rgb_img.device)
469 |
470 | gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm])
471 | g_image = net_ig(gimg_ae, style_feat)
472 | if counter == 1:
473 | vutils.save_image(0.5*(g_image+1), 'tmp.jpg')
474 | yield g_image
475 |
476 |
477 |
478 | if __name__ == "__main__":
479 | from utils import PairedMultiDataset, InfiniteSamplerWrapper, make_folders, AverageMeter
480 | from torch.utils.data import DataLoader
481 | from torchvision import utils as vutils
482 | IM_SIZE = 512
483 | BATCH_SIZE = 8
484 | DATALOADER_WORKERS = 8
485 | NBR_CLS = 2000
486 | TRIAL_NAME = 'trial_vae_512_1'
487 | SAVE_FOLDER = './'
488 |
489 | data_root_colorful = '../images/celebA/CelebA_512_test/img'
490 | data_root_sketch_1 = './sketch_simplification/vggadin_iter_700_test'
491 | data_root_sketch_2 = './sketch_simplification/vggadin_iter_1900_test'
492 | data_root_sketch_3 = './sketch_simplification/vggadin_iter_2300_test'
493 |
494 | dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3, im_size=IM_SIZE, rand_crop=False)
495 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=DATALOADER_WORKERS, pin_memory=True))
496 |
497 |
498 | from models import StyleEncoder, ContentEncoder, Decoder
499 | import pickle
500 | from models import AE, RefineGenerator
501 | from utils import load_params
502 |
503 | net_ig = RefineGenerator().cuda()
504 | net_ig = nn.DataParallel(net_ig)
505 |
506 | ckpt = './train_results/trial_refine_ae_as_gan_1024_2/models/4.pth'
507 | if ckpt is not None:
508 | ckpt = torch.load(ckpt)
509 | #net_ig.load_state_dict(ckpt['ig'])
510 | #net_id.load_state_dict(ckpt['id'])
511 | net_ig_ema = ckpt['ig_ema']
512 | load_params(net_ig, net_ig_ema)
513 | net_ig = net_ig.module
514 | #net_ig.eval()
515 |
516 | net_ae = AE()
517 | net_ae.load_state_dicts('./train_results/trial_vae_512_1/models/176000.pth')
518 | net_ae.cuda()
519 | net_ae.eval()
520 |
521 | inception = load_patched_inception_v3().cuda()
522 | inception.eval()
523 |
524 | '''
525 | real_features = extract_feature_from_generator_fn(
526 | real_image_loader(dataloader, n_batches=1000), inception )
527 | real_mean = np.mean(real_features, 0)
528 | real_cov = np.cov(real_features, rowvar=False)
529 | '''
530 | #pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov}, open('celeba_fid_feats.npy','wb') )
531 |
532 | real_features = pickle.load( open('celeba_fid_feats.npy', 'rb') )
533 | real_mean = real_features['mean']
534 | real_cov = real_features['cov']
535 | #sample_features = extract_feature_from_generator_fn( real_image_loader(dataloader, n_batches=100), inception )
536 | for it in range(1):
537 | itx = it * 8000
538 | '''
539 | ckpt = torch.load('./train_results/%s/models/%d.pth'%(TRIAL_NAME, itx))
540 |
541 | style_encoder.load_state_dict(ckpt['e'])
542 | content_encoder.load_state_dict(ckpt['c'])
543 | decoder.load_state_dict(ckpt['d'])
544 |
545 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True))
546 | '''
547 |
548 | sample_features = extract_feature_from_generator_fn(
549 | image_generator(dataset, net_ae, net_ig, n_batches=1800), inception,
550 | total=1800 )
551 |
552 | #fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov'])
553 | fid = calc_fid(sample_features, real_mean=real_mean, real_cov=real_cov)
554 | print(it, fid)
555 |
556 |     # Optional: recompute and cache the real-image feature statistics
557 |     # (requires `fid_batch_images` and `DATA_NAME` to be defined for the target dataset).
558 |     # real_features = extract_feature_from_generator_fn(
559 |     #     real_image_loader(dataloader, n_batches=fid_batch_images), inception)
560 |     # real_mean = np.mean(real_features, 0)
561 |     # real_cov = np.cov(real_features, rowvar=False)
562 |     # pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov},
563 |     #             open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
564 |     # real_features = pickle.load(open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
--------------------------------------------------------------------------------
/sketch_generation/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torchvision.datasets as Dataset
5 | import torchvision.utils as vutils
6 | from torch import nn
7 |
8 |
9 | from models import Generator, VGGSimple
10 | from operation import trans_maker_testing
11 |
12 | import argparse
13 |
14 |
15 | if __name__ == '__main__':
16 |
17 |     parser = argparse.ArgumentParser(description='Style transfer GAN: during training, the model learns to take an image from one specific category and transform it into another style domain')
18 |
19 | parser.add_argument('--path_content', type=str, help='path of resource dataset, should be a folder that has one or many sub image folders inside')
20 | parser.add_argument('--path_result', type=str, help='path to save the result images')
21 | parser.add_argument('--im_size', type=int, default=256, help='resolution of the generated images')
22 |
23 | parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.')
24 | parser.add_argument('--norm_layer', type=str, default="instance", help='can choose between [batch, instance]')
25 | parser.add_argument('--checkpoint', type=str, help='specify the path of the pre-trained model')
26 |
27 | args = parser.parse_args()
28 |
29 | print(str(args))
30 |
31 | device = torch.device("cuda:%d"%(args.gpu_id))
32 |
33 | im_size = args.im_size
34 |     if im_size not in [128, 256, 512, 1024]:
35 |         raise ValueError("im_size must be one of [128, 256, 512, 1024]")
36 |     if im_size == 128:
37 |         base = 4
38 |     elif im_size == 256:
39 |         base = 8
40 |     elif im_size == 512:
41 |         base = 16
42 |     else:  # im_size == 1024
43 |         base = 32
44 |
45 | vgg = VGGSimple()
46 | vgg.load_state_dict(torch.load('./vgg-feature-weights.pth', map_location=lambda a,b:a))
47 | vgg.to(device)
48 | vgg.eval()
49 | for p in vgg.parameters():
50 | p.requires_grad = False
51 |
52 | dataset = Dataset.ImageFolder(root=args.path_content, transform=trans_maker_testing(size=args.im_size))
53 |
54 | net_g = Generator(infc=256, nfc=128)
55 |
56 |     if args.checkpoint != 'None':
57 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
58 | net_g.load_state_dict(checkpoint['g'])
59 | print("saved model loaded")
60 |
61 | net_g.to(device)
62 | net_g.eval()
63 |
64 | dist_path = args.path_result
65 | if not os.path.exists(dist_path):
66 | os.mkdir(dist_path)
67 |
68 |
69 | print("begin generating images ...")
70 | with torch.no_grad():
71 | for i in range(len(dataset)):
72 | print("generating the %dth image"%(i))
73 | img = dataset[i][0].to(device)
74 | feat = vgg(img, base=base)[2]
75 | g_img = net_g(feat)
76 |
77 | g_img = g_img.mean(1).unsqueeze(1).detach().add(1).mul(0.5)
78 | g_img = (g_img > 0.7).float()
79 | vutils.save_image(g_img, os.path.join(dist_path, '%d.jpg'%(i)))
--------------------------------------------------------------------------------
/sketch_generation/metrics.py:
--------------------------------------------------------------------------------
1 | # example of calculating the frechet inception distance in Keras
2 | import numpy
3 | import os
4 | import cv2
5 | import argparse
6 | import torch
7 | import numpy as np
8 | from scipy.linalg import sqrtm
9 | from keras.applications.inception_v3 import InceptionV3
10 | from keras.applications.inception_v3 import preprocess_input
11 |
12 |
13 | # os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'  # show only warnings and errors
14 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'  # show only errors
15 |
16 | # calculate frechet inception distance
17 | def calculate_fid(model, images1, images2):
18 | # calculate activations
19 | act1 = model.predict(images1)
20 | act2 = model.predict(images2)
21 | # calculate mean and covariance statistics
22 | mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
23 | mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
24 | # calculate sum squared difference between means
25 | ssdiff = numpy.sum((mu1 - mu2)**2.0)
26 | # calculate sqrt of product between cov
27 | covmean = sqrtm(np.dot(sigma1, sigma2))
28 | # check and correct imaginary numbers from sqrt
29 | if np.iscomplexobj(covmean):
30 | covmean = covmean.real
31 | # calculate score
32 | fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
33 | return fid
34 |
35 | #act1 =generatedImg ,act2 = realImg
36 | def calculate_fid_modify(act1,act2):
37 | # calculate activations
38 | # act1 = model.predict(images1)
39 | # act2 = model.predict(images2)
40 | # calculate mean and covariance statistics
41 | mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
42 | mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
43 | # calculate sum squared difference between means
44 | ssdiff = numpy.sum((mu1 - mu2)**2.0)
45 | # calculate sqrt of product between cov
46 | covmean = sqrtm(np.dot(sigma1, sigma2))
47 | # check and correct imaginary numbers from sqrt
48 | if np.iscomplexobj(covmean):
49 | covmean = covmean.real
50 | # calculate score
51 | fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
52 | return fid
53 |
54 | def data_list(dirPath):
55 | generated_Dataset = []
56 | real_Dataset = []
57 | for root, dirs, files in os.walk(dirPath):
58 |         for filename in sorted(files):  # iterate over a sorted copy of the file list
59 |             # check whether this file is a generated (target) image
60 | if "generated" in filename:
61 | generatedPath = root + '/' + filename
62 | generatedImg = cv2.imread(generatedPath).astype('float32')
63 | generated_Dataset.append(generatedImg)
64 |                 # path of the corresponding real image
65 | realPath = root + '/' + filename.replace('generated', 'real')
66 | realImg = cv2.imread(realPath).astype('float32')
67 | real_Dataset.append(realImg)
68 | return generated_Dataset, real_Dataset
69 |
70 | if __name__ == '__main__':
71 |     ### argument settings
72 | parser = argparse.ArgumentParser()
73 | # parser.add_argument('--dataset_dir', type=str, default='./results/hrnet/', help='results')
74 | parser.add_argument('--dataset_dir', type=str, default='./results/ssngan/', help='results')
75 | parser.add_argument('--name', type=str, default='sketch', help='name of dataset')
76 | opt = parser.parse_args()
77 |
78 |     # dataset
79 | dirPath = os.path.join(opt.dataset_dir, opt.name)
80 | generatedImg, realImg = data_list(dirPath)
81 | dataset_size = len(generatedImg)
82 |     print("dataset size:", dataset_size)
83 |
84 | images1 = torch.Tensor(generatedImg)
85 | images2 = torch.Tensor(realImg)
86 | print('shape: ', images1.shape, images2.shape)
87 |
88 |     # load the entire dataset at once
89 | # prepare the inception v3 model
90 | model = InceptionV3(include_top=False, pooling='avg')
91 |
92 |     # pre-process images (normalization)
93 | images1 = preprocess_input(images1)
94 | images2 = preprocess_input(images2)
95 |
96 | # fid between images1 and images2
97 | fid = calculate_fid(model, images1, images2)
98 | print('FID : %.3f' % fid)
99 | print('FID_average : %.3f' % (fid / dataset_size))
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/sketch_generation/models.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from math import sqrt
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch import cat, sigmoid
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter, init
9 | from torch.nn.utils import spectral_norm
10 | import torch.nn.functional as F
11 |
12 | from torch.jit import ScriptModule, script_method, trace
13 |
14 | #####################################################################
15 | ##### functions
16 | #####################################################################
17 |
18 | def calc_mean_std(feat, eps=1e-5):
19 | # eps is a small value added to the variance to avoid divide-by-zero.
20 | size = feat.size()
21 | assert (len(size) == 4)
22 | N, C = size[:2]
23 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
24 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
25 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
26 | return feat_mean, feat_std
27 |
28 |
29 | # def adain(content_feat, style_feat):
30 | # assert (content_feat.size()[:2] == style_feat.size()[:2])
31 | # size = content_feat.size()
32 | # style_mean, style_std = calc_mean_std(style_feat)
33 | # content_mean, content_std = calc_mean_std(content_feat)
34 | #
35 | # normalized_feat = (content_feat - content_mean.expand(
36 | # size)) / content_std.expand(size)
37 | # return normalized_feat * style_std.expand(size) + style_mean.expand(size)
38 | def AdaLIN(content_feat,style_feat):
39 |
40 | assert (content_feat.size()[:2]==style_feat.size()[:2])
41 |
42 |     rho=Parameter(torch.Tensor(4,256,32,32,))  # dimensions changed; originally rho=Parameter(torch.Tensor(1,512,1,1,))
43 | rho=rho.data.fill_(0.9)
44 |
45 | size=content_feat.size()
46 | style_mean,style_std=calc_mean_std(style_feat)
47 | content_mean,content_std=calc_mean_std(content_feat)
48 | out_style=(style_feat-style_mean.expand(size))/style_std.expand(size)
49 | out_content=(content_feat-content_mean.expand(size))/content_std.expand(size)
50 | out=rho.expand(size)*out_style+(1-rho.expand(size))*out_content
51 | return out
52 |
53 | def adain(content_feat, style_feat):
54 | assert (content_feat.size()[:2] == style_feat.size()[:2])
55 | size = content_feat.size()
56 | style_mean, style_std = calc_mean_std(style_feat)
57 | content_mean, content_std = calc_mean_std(content_feat)
58 |
59 | normalized_feat = (content_feat - content_mean.expand(
60 | size)) / content_std.expand(size)
61 | normalized_features = normalized_feat * style_std.expand(size) + style_mean.expand(size)
62 | return normalized_features #torch.Size([4, 256, 32, 32])
63 |
64 | def get_batched_gram_matrix(input):
65 | # take a batch of features: B X C X H X W
66 | # return gram of each image: B x C x C
67 | a, b, c, d = input.size()
68 | features = input.view(a, b, c * d)
69 | G = torch.bmm(features, features.transpose(2,1))
70 | return G.div(b * c * d)
71 |
72 | class Adaptive_pool(nn.Module):
73 | '''
74 | take a input tensor of size: B x C' X C'
75 | output a maxpooled tensor of size: B x C x H x W
76 | '''
77 | def __init__(self, channel_out, hw_out):
78 | super().__init__()
79 | self.channel_out = channel_out
80 | self.hw_out = hw_out
81 | self.pool = nn.AdaptiveAvgPool2d((channel_out, hw_out**2))
82 | def forward(self, input):
83 | if len(input.shape) == 3:
84 | input.unsqueeze_(1)
85 | return self.pool(input).view(-1, self.channel_out, self.hw_out, self.hw_out)
86 | ### new function
87 |
88 | #####################################################################
89 | ##### models
90 | #####################################################################
91 | class VGGSimple(nn.Module):
92 | def __init__(self):
93 | super(VGGSimple, self).__init__()
94 |
95 | self.features = self.make_layers()
96 |
97 | self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
98 | self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
99 |
100 | def forward(self, img, after_relu=True, base=4):
101 | # re-normalize from [-1, 1] to [0, 1] then to the range used for vgg
102 | feat = (((img+1)*0.5) - self.norm_mean.to(img.device)) / self.norm_std.to(img.device)
103 | # the layer numbers used to extract features
104 | cut_points = [2, 7, 14, 21, 28]
105 | if after_relu:
106 | cut_points = [c+2 for c in cut_points]
107 | for i in range(31):
108 | feat = self.features[i](feat)
109 | if i == cut_points[0]:
110 | feat_64 = F.adaptive_avg_pool2d(feat, base*16)
111 | if i == cut_points[1]:
112 | feat_32 = F.adaptive_avg_pool2d(feat, base*8)
113 | if i == cut_points[2]:
114 | feat_16 = F.adaptive_avg_pool2d(feat, base*4)
115 | if i == cut_points[3]:
116 | feat_8 = F.adaptive_avg_pool2d(feat, base*2)
117 | if i == cut_points[4]:
118 | feat_4 = F.adaptive_avg_pool2d(feat, base)
119 |
120 | return feat_64, feat_32, feat_16, feat_8, feat_4
121 |
122 | def make_layers(self, cfg="D", batch_norm=False):
123 | cfg_dic = {
124 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
125 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
126 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
127 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
128 | }
129 | cfg = cfg_dic[cfg]
130 | layers = []
131 | in_channels = 3
132 | for v in cfg:
133 | if v == 'M':
134 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
135 | else:
136 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
137 | if batch_norm:
138 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)]
139 | else:
140 | layers += [conv2d, nn.ReLU(inplace=False)]
141 | in_channels = v
142 | return nn.Sequential(*layers)
143 |
144 |
145 | # this model is used for pre-training
146 | class VGG_3label(nn.Module):
147 | def __init__(self, nclass_artist=1117, nclass_style=55, nclass_genre=26):
148 | super(VGG_3label, self).__init__()
149 | self.features = self.make_layers()
150 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
151 |
152 | self.classifier_feat = self.classifier = nn.Sequential(
153 | nn.Linear(512 * 7 * 7, 4096),
154 | nn.ReLU(),
155 | nn.Dropout(),
156 | nn.Linear(4096, 4096),
157 | nn.ReLU(),
158 | nn.Dropout(),
159 | nn.Linear(4096, 512))
160 |
161 | self.classifier_style = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_style))
162 | self.classifier_genre = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_genre))
163 | self.classifier_artist = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_artist))
164 |
165 | self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
166 | self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
167 |
168 | self.avgpool_4 = nn.AdaptiveAvgPool2d((4, 4))
169 | self.avgpool_8 = nn.AdaptiveAvgPool2d((8, 8))
170 | self.avgpool_16 = nn.AdaptiveAvgPool2d((16, 16))
171 |
172 | def get_features(self, img, after_relu=True, base=4):
173 | feat = (((img+1)*0.5) - self.norm_mean.to(img.device)) / self.norm_std.to(img.device)
174 | cut_points = [2, 7, 14, 21, 28]
175 | if after_relu:
176 | cut_points = [4, 9, 16, 23, 30]
177 | for i in range(31):
178 | feat = self.features[i](feat)
179 | if i == cut_points[0]:
180 | feat_64 = F.adaptive_avg_pool2d(feat, base*16)
181 | if i == cut_points[1]:
182 | feat_32 = F.adaptive_avg_pool2d(feat, base*8)
183 | if i == cut_points[2]:
184 | feat_16 = F.adaptive_avg_pool2d(feat, base*4)
185 | if i == cut_points[3]:
186 | feat_8 = F.adaptive_avg_pool2d(feat, base*2)
187 | if i == cut_points[4]:
188 | feat_4 = F.adaptive_avg_pool2d(feat, base)
189 | #feat_code = self.classifier_feat(self.avgpool(feat).view(img.size(0), -1))
190 | return feat_64, feat_32, feat_16, feat_8, feat_4#, feat_code
191 |
192 |
193 | def load_pretrain_weights(self):
194 | pretrained_vgg16 = vgg.vgg16(pretrained=True)
195 | self.features.load_state_dict(pretrained_vgg16.features.state_dict())
196 | self.classifier_feat[0] = pretrained_vgg16.classifier[0]
197 | self.classifier_feat[3] = pretrained_vgg16.classifier[3]
198 | for m in self.modules():
199 | if isinstance(m, nn.Linear):
200 | nn.init.normal_(m.weight, 0, 0.01)
201 | nn.init.constant_(m.bias, 0)
202 |
203 | def make_layers(self, cfg="D", batch_norm=False):
204 | cfg_dic = {
205 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
206 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
207 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
208 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
209 | }
210 | cfg = cfg_dic[cfg]
211 | layers = []
212 | in_channels = 3
213 | for v in cfg:
214 | if v == 'M':
215 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
216 | else:
217 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
218 | if batch_norm:
219 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)]
220 | else:
221 | layers += [conv2d, nn.ReLU(inplace=False)]
222 | in_channels = v
223 | return nn.Sequential(*layers)
224 |
225 | def forward(self, img):
226 | feature = self.classifier_feat( self.avgpool(self.features(img)).view(img.size(0), -1) )
227 | pred_style = self.classifier_style(feature)
228 | pred_genre = self.classifier_genre(feature)
229 | pred_artist = self.classifier_artist(feature)
230 | return pred_style, pred_genre, pred_artist
231 |
232 |
233 | class UnFlatten(nn.Module):
234 | def __init__(self, block_size):
235 | super(UnFlatten, self).__init__()
236 | self.block_size = block_size
237 |
238 | def forward(self, x):
239 | return x.view(x.size(0), -1, self.block_size, self.block_size)
240 |
241 |
242 | class Flatten(nn.Module):
243 | def __init__(self):
244 | super(Flatten, self).__init__()
245 |
246 | def forward(self, x):
247 | return x.view(x.size(0), -1)
248 |
249 | #batchNorm2d-->InstanceNorm2d
250 | class UpConvBlock(nn.Module):
251 | def __init__(self, in_channel, out_channel, norm_layer=nn.InstanceNorm2d):
252 | super().__init__()
253 |
254 | self.main = nn.Sequential(
255 | nn.ReflectionPad2d(1),
256 | spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 0, bias=True)),
257 | norm_layer(out_channel),
258 | nn.LeakyReLU(0.01),
259 | )
260 |
261 | def forward(self, x):
262 | y = F.interpolate(x, scale_factor=2)
263 | return self.main(y)
264 |
265 | #batchNorm2d-->InstanceNorm2d
266 | class DownConvBlock(nn.Module):
267 | def __init__(self, in_channel, out_channel, norm_layer=nn.InstanceNorm2d, down=True):
268 | super().__init__()
269 |
270 | m = [ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=True)),
271 | norm_layer(out_channel),
272 | nn.LeakyReLU(0.1) ]
273 | if down:
274 | m.append(nn.AvgPool2d(2, 2))
275 | self.main = nn.Sequential(*m)
276 |
277 | def forward(self, x):
278 | return self.main(x)
279 |
280 |
281 | class ResNetBlock(nn.Module):
282 | def __init__(self, dim):
283 | super(ResNetBlock, self).__init__()
284 | conv_block = []
285 | conv_block += [nn.ReflectionPad2d(1),
286 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False),
287 | nn.InstanceNorm2d(dim),
288 | nn.ReLU(True)]
289 |
290 | conv_block += [nn.ReflectionPad2d(1),
291 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False),
292 | nn.InstanceNorm2d(dim)]
293 |
294 | self.conv_block = nn.Sequential(*conv_block)
295 |
296 | def forward(self, x):
297 | out = x + self.conv_block(x)
298 | return out
299 |
300 | class Generator(nn.Module):
301 | def __init__(self, infc=512, nfc=64, nc_out=3):
302 | super(Generator, self).__init__()
303 |
304 | self.decode_32 = UpConvBlock(infc, nfc*4) #32
305 | self.decode_64 = UpConvBlock(nfc*4, nfc*4) #64
306 | self.decode_128 = UpConvBlock(nfc*4, nfc*2) #128
307 | self.gap_fc=nn.Linear(512,1,bias=False)
308 | self.gmp_fc=nn.Linear(512,1,bias=False)
309 | self.gamma = nn.Linear(512, 256, bias=False) #(256,256)
310 | self.beta = nn.Linear(512, 256, bias=False) #
311 | self.conv1x1 = nn.Conv2d(512, 256, 1, 1, bias=True)
312 | self.relu = nn.ReLU(inplace=True)
313 | self.final = nn.Sequential(
314 | spectral_norm( nn.Conv2d(nfc*2, nc_out, 3, 1, 1, bias=True) ),
315 | nn.Tanh())
316 | self.netG_A2B = Generator_UGATIT(image_size=256)
317 | def forward(self, input):
318 |
319 | decode_32 = self.decode_32(input) # input torch.Size([8, 256, 32, 32])
320 | decode_64 = self.decode_64(decode_32)
321 | decode_128 = self.decode_128(decode_64)
322 |
323 | output = self.final(decode_128) #output torch.Size([8, 3, 256, 256])
324 |         output=self.netG_A2B(output)[0]  # after decoding, pass the result through Generator_UGATIT before returning it
325 | return output
326 |
327 | class Generator_UGATIT(nn.Module):
328 | def __init__(self, image_size=256):
329 | super(Generator_UGATIT, self).__init__()
330 | down_layer = [
331 | nn.ReflectionPad2d(3),
332 | nn.Conv2d(3, 64, 7, 1, 0, bias=False),
333 | nn.InstanceNorm2d(64),
334 | nn.ReLU(inplace=True),
335 |
336 | # Down-Sampling
337 | nn.ReflectionPad2d(1),
338 | nn.Conv2d(64, 128, 3, 2, 0, bias=False),
339 | nn.InstanceNorm2d(256),
340 | nn.ReLU(inplace=True),
341 | nn.ReflectionPad2d(1),
342 | nn.Conv2d(128, 256, 3, 2, 0, bias=False),
343 | nn.InstanceNorm2d(256),
344 | nn.ReLU(inplace=True),
345 |
346 | # Down-Sampling Bottleneck
347 | ResNetBlock(256),
348 | ResNetBlock(256),
349 | ResNetBlock(256),
350 | ResNetBlock(256),
351 | ]
352 |
353 | # Class Activation Map
354 | self.gap_fc = nn.Linear(256, 1, bias=False)
355 | self.gmp_fc = nn.Linear(256, 1, bias=False)
356 | self.conv1x1 = nn.Conv2d(512, 256, 1, 1, bias=True)
357 | self.relu = nn.ReLU(inplace=True)
358 |
359 | # Gamma, Beta block
360 | fc = [
361 | nn.Linear(image_size * image_size * 16, 256, bias=False),
362 | nn.ReLU(inplace=True),
363 | nn.Linear(256, 256, bias=False),
364 | nn.ReLU(inplace=True)
365 | ]
366 |
367 | self.gamma = nn.Linear(256, 256, bias=False)
368 | self.beta = nn.Linear(256, 256, bias=False)
369 |
370 | # Up-Sampling Bottleneck
371 | for i in range(4):
372 | setattr(self, "ResNetAdaILNBlock_" + str(i + 1), ResNetAdaILNBlock(256))
373 |
374 | up_layer = [
375 | nn.Upsample(scale_factor=2, mode="nearest"),
376 | nn.ReflectionPad2d(1),
377 | nn.Conv2d(256, 128, 3, 1, 0, bias=False),
378 | ILN(128),
379 | nn.ReLU(inplace=True),
380 |
381 | nn.Upsample(scale_factor=2, mode="nearest"),
382 | nn.ReflectionPad2d(1),
383 | nn.Conv2d(128, 64, 3, 1, 0, bias=False),
384 | ILN(64),
385 | nn.ReLU(inplace=True),
386 |
387 | nn.ReflectionPad2d(3),
388 | nn.Conv2d(64, 3, 7, 1, 0, bias=False),
389 | nn.Tanh()
390 | ]
391 |
392 | self.down_layer = nn.Sequential(*down_layer)
393 | self.fc = nn.Sequential(*fc)
394 | self.up_layer = nn.Sequential(*up_layer)
395 |
396 | def forward(self, inputs):
397 | x = self.down_layer(inputs)
398 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
399 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
400 | gap_weight = list(self.gap_fc.parameters())[0]
401 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
402 |
403 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
404 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
405 | gmp_weight = list(self.gmp_fc.parameters())[0]
406 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
407 |
408 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
409 | x = torch.cat([gap, gmp], 1)
410 | x = self.relu(self.conv1x1(x))
411 |
412 | x_ = self.fc(x.view(x.shape[0], -1))
413 | gamma, beta = self.gamma(x_), self.beta(x_)
414 |
415 | for i in range(4):
416 | x = getattr(self, "ResNetAdaILNBlock_" + str(i + 1))(x, gamma, beta)
417 | out = self.up_layer(x)
418 |
419 | return out, cam_logit
420 |
421 |
422 |
423 |
424 | class ResNetAdaILNBlock(nn.Module):
425 | def __init__(self, dim):
426 | super(ResNetAdaILNBlock, self).__init__()
427 | self.pad1 = nn.ReflectionPad2d(1)
428 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 0, bias=False)
429 | self.norm1 = AdaILN(dim)
430 | self.relu1 = nn.ReLU(True)
431 |
432 | self.pad2 = nn.ReflectionPad2d(1)
433 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 0, bias=False)
434 | self.norm2 = AdaILN(dim)
435 |
436 | def forward(self, x, gamma, beta):
437 | out = self.pad1(x)
438 | out = self.conv1(out)
439 | out = self.norm1(out, gamma, beta)
440 | out = self.relu1(out)
441 | out = self.pad2(out)
442 | out = self.conv2(out)
443 | out = self.norm2(out, gamma, beta)
444 |
445 | return out + x
446 |
447 | class ILN(nn.Module):
448 | def __init__(self, num_features, eps=1e-5):
449 | super(ILN, self).__init__()
450 | self.eps = eps
451 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
452 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
453 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
454 | self.rho.data.fill_(0.0)
455 | self.gamma.data.fill_(1.0)
456 | self.beta.data.fill_(0.0)
457 |
458 | def forward(self, x):
459 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
460 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)
461 | ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
462 | out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
463 | out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln
464 | out = out * self.gamma.expand(x.shape[0], -1, -1, -1) + self.beta.expand(x.shape[0], -1, -1, -1)
465 |
466 | return out
467 |
468 | class AdaILN(nn.Module):
469 | def __init__(self, num_features, eps=1e-5):
470 | super(AdaILN, self).__init__()
471 | self.eps = eps
472 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
473 | self.rho.data.fill_(0.9)
474 |
475 | def forward(self, x, gamma, beta):
476 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
477 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)
478 | ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
479 | out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
480 | out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln
481 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
482 |
483 | return out
484 |
485 | class ResNetBlock(nn.Module):
486 | def __init__(self, dim):
487 | super(ResNetBlock, self).__init__()
488 | conv_block = []
489 | conv_block += [nn.ReflectionPad2d(1),
490 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False),
491 | nn.InstanceNorm2d(dim),
492 | nn.ReLU(True)]
493 |
494 | conv_block += [nn.ReflectionPad2d(1),
495 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False),
496 | nn.InstanceNorm2d(dim)]
497 |
498 | self.conv_block = nn.Sequential(*conv_block)
499 |
500 | def forward(self, x):
501 | out = x + self.conv_block(x)
502 | return out
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 | class Discriminator(nn.Module):
511 | def __init__(self, nfc=512, norm_layer=nn.InstanceNorm2d):
512 | super(Discriminator, self).__init__()
513 |         self.gap_fc = nn.utils.spectral_norm(nn.Linear(1, 4, bias=False))  # dimensions changed here; originally nn.Linear(64 * 8, 1)
514 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(1, 4, bias=False))
515 | self.conv1x1 = nn.Conv2d(2, 4, 3, 3, bias=True)
516 | self.leaky_relu = nn.LeakyReLU(0.2, True)
517 |
518 | self.pad = nn.ReflectionPad2d(1)
519 | self.conv = nn.utils.spectral_norm(nn.Conv2d(4, 4, 1, 1, 0, bias=False))
520 |
521 | self.main = nn.Sequential(
522 | DownConvBlock(nfc, nfc // 2, norm_layer=norm_layer, down=False),
523 | DownConvBlock(nfc // 2, nfc // 4, norm_layer=norm_layer), # 4x4
524 | spectral_norm(nn.Conv2d(nfc // 4, 1, 4, 2, 0))
525 | )
526 |
527 | def forward(self, input):
528 | x = self.main(input)
529 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) #x torch.Size([4, 1, 3, 3])
530 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
531 | gap_weight = list(self.gap_fc.parameters())[0]
532 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
533 |
534 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
535 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
536 | gmp_weight = list(self.gmp_fc.parameters())[0]
537 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
538 |
539 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
540 | x = torch.cat([gap, gmp], 1)
541 | x = self.leaky_relu(self.conv1x1(x))
542 | # x = self.pad(x)
543 | out = self.conv(x)
544 |
545 | return out.view(-1)
546 |
547 | class Discriminator_UGATIT(nn.Module):
548 | def __init__(self, input_nc, ndf=64, n_layers=5):
549 | super(Discriminator_UGATIT, self).__init__()
550 | model = [nn.ReflectionPad2d(1),
551 | nn.utils.spectral_norm(
552 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
553 | nn.LeakyReLU(0.2, True)]
554 |
555 | for i in range(1, n_layers - 2):
556 | mult = 2 ** (i - 1)
557 | model += [nn.ReflectionPad2d(1),
558 | nn.utils.spectral_norm(
559 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
560 | nn.LeakyReLU(0.2, True)]
561 |
562 | mult = 2 ** (n_layers - 2 - 1)
563 | model += [nn.ReflectionPad2d(1),
564 | nn.utils.spectral_norm(
565 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
566 | nn.LeakyReLU(0.2, True)]
567 |
568 | # Class Activation Map
569 | mult = 2 ** (n_layers - 2)
570 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
571 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
572 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
573 | self.leaky_relu = nn.LeakyReLU(0.2, True)
574 |
575 | self.pad = nn.ReflectionPad2d(1)
576 | self.conv = nn.utils.spectral_norm(
577 | nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
578 |
579 | self.model = nn.Sequential(*model)
580 |
581 | def forward(self, input):
582 | x = self.model(input) #input torch.Size([1, 3, 256, 256])
583 |
584 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) #x torch.Size([1, 2048, 7, 7])
585 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
586 | gap_weight = list(self.gap_fc.parameters())[0]
587 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
588 |
589 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
590 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
591 | gmp_weight = list(self.gmp_fc.parameters())[0]
592 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
593 |
594 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
595 | x = torch.cat([gap, gmp], 1)
596 | x = self.leaky_relu(self.conv1x1(x))
597 |
598 | heatmap = torch.sum(x, dim=1, keepdim=True)
599 |
600 | x = self.pad(x)
601 | out = self.conv(x) #out.shape torch.Size([1, 1, 6, 6])
602 |
603 | return out, cam_logit, heatmap
604 |
605 |
--------------------------------------------------------------------------------
/sketch_generation/operation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pandas as pd
4 | # from skimage import io, transform
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | #from torch.utils.data import Dataset, DataLoader
8 | from torchvision import transforms, utils
9 | import subprocess as sp
10 | from PIL import Image
11 | import time
12 | import torch.utils.data as data
13 |
14 |
15 |
16 | ### model math functions
17 |
18 | # from skimage.color import hsv2rgb
19 | import torch.nn.functional as F
20 | import torch.nn as nn
21 |
22 | eps = 1e-7
23 | class HSV_Loss(nn.Module):
24 | def __init__(self, h=0, s=1, v=0.7):
25 | super(HSV_Loss, self).__init__()
26 | self.hsv = [h, s, v]
27 | self.l1 = nn.L1Loss()
28 | self.mse = nn.MSELoss()
29 |
30 | @staticmethod
31 | def get_h(im):
32 | img = im * 0.5 + 0.5
33 | b, c, h, w = img.shape
34 | hue = torch.Tensor(im.shape[0], im.shape[2], im.shape[3]).to(im.device)
35 | hue[img[:,2]==img.max(1)[0]] = 4.0+((img[:,0]-img[:,1])/(img.max(1)[0] - img.min(1)[0]))[img[:,2]==img.max(1)[0]]
36 | hue[img[:,1]==img.max(1)[0]] = 2.0+((img[:,2]-img[:,0])/(img.max(1)[0] - img.min(1)[0]))[img[:,1]==img.max(1)[0]]
37 | hue[img[:,0]==img.max(1)[0]] = ((img[:,1]-img[:,2])/(img.max(1)[0] - img.min(1)[0]))[img[:,0]==img.max(1)[0]]
38 | hue = (hue/6.0) % 1.0
39 | hue[img.min(1)[0]==img.max(1)[0]] = 0.0
40 | return hue
41 |
42 | @staticmethod
43 | def get_v(im):
44 | img = im * 0.5 + 0.5
45 | b, c, h, w = img.shape
46 | it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c)
47 | value = F.max_pool1d(it, c).view(b, h, w)
48 | return value
49 |
50 | @staticmethod
51 | def get_s(im):
52 | img = im * 0.5 + 0.5
53 | b, c, h, w = img.shape
54 | it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c)
55 | max_v = F.max_pool1d(it, c).view(b, h, w)
56 | min_v = F.max_pool1d(it*-1, c).view(b, h, w)
57 | satur = (max_v + min_v) / (max_v+eps)
58 | return satur
59 |
60 | def forward(self, input):
61 | h = self.get_h(input)
62 | s = self.get_s(input)
63 | v = self.get_v(input)
64 | target_h = torch.Tensor(h.shape).fill_(self.hsv[0]).to(input.device).type_as(h)
65 | target_s = torch.Tensor(s.shape).fill_(self.hsv[1]).to(input.device)
66 | target_v = torch.Tensor(v.shape).fill_(self.hsv[2]).to(input.device)
67 | return self.mse(h, target_h) #+ 0.4*self.mse(v, target_v)
68 |
69 |
70 |
71 | ### data loading functions
72 | def InfiniteSampler(n):
73 | # i = 0
74 | i = n - 1
75 | order = np.random.permutation(n)
76 | while True:
77 | yield order[i]
78 | i += 1
79 | if i >= n:
80 | np.random.seed()
81 | order = np.random.permutation(n)
82 | i = 0
83 |
84 | class InfiniteSamplerWrapper(data.sampler.Sampler):
85 | def __init__(self, data_source):
86 | self.num_samples = len(data_source)
87 |
88 | def __iter__(self):
89 | return iter(InfiniteSampler(self.num_samples))
90 |
91 | def __len__(self):
92 | return 2 ** 31
93 |
94 |
95 | def _rescale(img):
96 | return img * 2.0 - 1.0
97 |
98 | def trans_maker(size=256):
99 | trans = transforms.Compose([
100 | transforms.Resize((size+36, size+36)),
101 | transforms.RandomHorizontalFlip(),
102 | transforms.RandomCrop((size, size)),
103 | transforms.ToTensor(),
104 | _rescale
105 | ])
106 | return trans
107 |
108 | def trans_maker_testing(size=256):
109 | trans = transforms.Compose([
110 | transforms.Resize((size, size)),
111 | transforms.ToTensor(),
112 | _rescale
113 | ])
114 | return trans
115 | transform_gan = trans_maker(size=128)
116 |
117 | import torchvision.utils as vutils
118 | import logging
119 | logger = logging.getLogger(__name__)
120 |
121 |
122 |
123 | ### during training util functions
124 | def save_image(net, dataloader_A, device, cur_iter, trial, save_path):
125 | """Save imag output from net"""
126 | logger.info('Saving gan epoch {} images: {}'.format(cur_iter, save_path))
127 |
128 | # Set net to evaluation mode
129 | net.eval()
130 | for p in net.parameters():
131 | data_type = p.type()
132 | break
133 | with torch.no_grad():
134 | for itx, data in enumerate(dataloader_A):
135 | g_img = net.gen_a2b(data[0].to(device).type(data_type))
136 | for i in range(g_img.size(0)):
137 | vutils.save_image(
138 | g_img.cpu().float().add_(1).mul_(0.5),
139 | os.path.join(save_path, "{}_gan_epoch_{}_iter_{}_{}.jpg".format(trial, cur_iter, itx, i)),)
140 | # Set net to train mode
141 | net.train()
142 | return save_path
143 |
144 | def save_model(net, save_folder, cuda_device, if_multi_gpu, trial, cur_iter):
145 | """ Save current model and delete previous model, keep the saved model!"""
146 | save_name = "{}_gan_epoch_{}.pth".format(trial, cur_iter)
147 | save_path = os.path.join(save_folder, save_name)
148 | logger.info('Saving gan model: {}'.format(save_path))
149 |
150 | net.save(save_path)
151 |
152 | for fname in os.listdir(save_folder):
153 | if fname.endswith('.pth') and fname != save_name:
154 | delete_path = os.path.join(save_folder, fname)
155 | os.remove(delete_path)
156 | logger.info('Deleted previous gan model: {}'.format(delete_path))
157 |
158 | return save_path
--------------------------------------------------------------------------------
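The `HSV_Loss` defined in `operation.py` above recovers per-pixel hue, saturation, and value from a tensor scaled to [-1, 1], but its `forward` only returns the hue term. A minimal usage sketch, assuming `sketch_generation/` is on `PYTHONPATH` and using a random tensor as a stand-in for generator output:

```
import torch
from operation import HSV_Loss

# stand-in for a generator output scaled to [-1, 1]: a batch of 4 RGB images
fake_img = torch.rand(4, 3, 64, 64) * 2 - 1

# target hue h=0 (red); s and v targets are built inside forward() but, as
# written above, only the hue MSE contributes to the returned loss
criterion = HSV_Loss(h=0.0, s=1.0, v=0.7)
hue_loss = criterion(fake_img)
print(hue_loss.item())
```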
/sketch_generation/requirement.txt:
--------------------------------------------------------------------------------
1 | # tested with Python 3.7.10
2 | torch==1.12.1
3 |
--------------------------------------------------------------------------------
/sketch_generation/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch,gc
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import torchvision.datasets as Dataset
8 | import torchvision.utils as vutils
9 | from torch import nn
10 | from torch.utils.data import DataLoader
11 | import matplotlib.pyplot as plt
12 | from models import Generator, Discriminator, VGGSimple, Adaptive_pool,AdaLIN, get_batched_gram_matrix,Generator_UGATIT,adain
13 | from operation import InfiniteSamplerWrapper, trans_maker
14 | import numpy
15 | import argparse
16 | import tqdm
17 | from metrics import calculate_fid_modify
18 |
19 | torch.backends.cudnn.benchmark = True
20 |
21 |
22 |
23 | def creat_folder(save_folder, trial_name):
24 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models'%trial_name)
25 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images'%trial_name)
26 | folders = [os.path.join(save_folder, 'train_results'), os.path.join(save_folder, 'train_results/%s'%trial_name),
27 | os.path.join(save_folder, 'train_results/%s/images'%trial_name), os.path.join(save_folder, 'train_results/%s/models'%trial_name)]
28 |
29 | for folder in folders:
30 | if not os.path.exists(folder):
31 | os.mkdir(folder)
32 | return saved_model_folder, saved_image_folder
33 |
34 | def train_d(net, data, label="real"):
35 | pred = net(data)
36 | if label=="real":
37 | err = F.relu(1-pred).mean()
38 | else:
39 | err = F.relu(1+pred).mean()
40 |
41 | err.backward()
42 | return torch.sigmoid(pred).mean().item()
43 |
44 | def gram_matrix(input):
45 | a, b, c, d = input.size() # a = batch size
46 | # b = number of feature maps
47 | # (c, d) = dimensions of a feature map (N = c * d)
48 | features = input.view(a * b, c * d) # reshape F_XL into \hat F_XL
49 | G = torch.mm(features, features.t()) # compute the gram product
50 | # we 'normalize' the values of the gram matrix
51 | # by dividing by the number of elements in each feature map.
52 | return G.div(a * b * c * d)
53 |
54 | def gram_loss(input, target):
55 | in_gram = gram_matrix(input)
56 | tar_gram = gram_matrix(target.detach())
57 | return F.mse_loss(in_gram, tar_gram)
58 |
59 | def save_image(net_g, dataloader, saved_image_folder, n_iter):
60 | net_g.eval()
61 | with torch.no_grad():
62 | imgs = []
63 | real = []
64 | for i, d in enumerate(dataloader):
65 | if i < 2:
66 | # net_f=netG_A2B(d[0].to(device))[0]
67 | # f_3 = vgg(d[0].to(device), base=base)[2]
68 | f_3 = vgg(d[0].to(device), base=base)[2]
69 | imgs.append(net_g(f_3).cpu())
70 | real.append(d[0])
71 | gc.collect()
72 | torch.cuda.empty_cache()
73 | else:
74 | break
75 | imgs = torch.cat(imgs, dim=0)
76 | real = torch.cat(real, dim=0)
77 | sss = torch.cat([imgs, real], dim=0)
78 | # compute the FID metric
79 | fid = calculate_fid_modify(imgs[0,0,:,:].detach().numpy(), real[0,0,:,:].detach().numpy())
80 | print('fid-------------', fid)
81 | vutils.save_image( sss, "%s/iter_%d.jpg"%(saved_image_folder, n_iter), range=(-1,1), normalize=True)
82 | del imgs
83 | net_g.train()
84 |
85 | def train(net_g, net_d_style, max_iteration):
86 | print('training begin ... ')
87 | titles = ['D_r', 'D_f', 'G', 'G_rec']
88 | losses = {title: 0.0 for title in titles}
89 |
90 | saved_model_folder, saved_image_folder = creat_folder(save_folder, trial_name)
91 |
92 | for n_iter in tqdm.tqdm(range(max_iteration+1)):
93 | if (n_iter+1)%(100)==0:
94 | try:
95 | model_dict = {'g': net_g.state_dict(), 'ds':net_d_style.state_dict()}
96 | torch.save(model_dict, os.path.join(saved_model_folder, '%d_model.pth'%(n_iter)))
97 | opt_dict = {'g': optG.state_dict(), 'ds':optDS.state_dict()}
98 | torch.save(opt_dict, os.path.join(saved_model_folder, '%d_opt.pth'%(n_iter)))
99 | except:
100 | print("models not properly saved")
101 | if n_iter%100==0:
102 | save_image(net_g, dataloader_A_fixed, saved_image_folder, n_iter)
103 |
104 | ## 1. prepare data
105 | real_style = next(dataloader_B)[0].to(device)
106 | real_content = next(dataloader_A)[0].to(device)
107 |
108 | cf_1, cf_2, cf_3, cf_4, cf_5 = vgg(real_content, base=base)
109 | sf_1, sf_2, sf_3, sf_4, sf_5 = vgg(real_style, base=base)
110 |
111 | fake_img = net_g(cf_3)
112 | tf_1, tf_2, tf_3, tf_4, tf_5 = vgg(fake_img, base=base)
113 | target_3 = adain(cf_3, sf_3) #torch.Size([4, 256, 32, 32])
114 | # target_3 = AdaLIN(cf_3, sf_3)  # alternative: replace adain with AdaLIN
115 | gram_sf_4 = gram_reshape(get_batched_gram_matrix(sf_4))
116 | gram_sf_3 = gram_reshape(get_batched_gram_matrix(sf_3))
117 | gram_sf_2 = gram_reshape(get_batched_gram_matrix(sf_2))
118 | real_style_sample = torch.cat([gram_sf_2, gram_sf_3, gram_sf_4], dim=1)
119 |
120 | gram_tf_4 = gram_reshape(get_batched_gram_matrix(tf_4))
121 | gram_tf_3 = gram_reshape(get_batched_gram_matrix(tf_3))
122 | gram_tf_2 = gram_reshape(get_batched_gram_matrix(tf_2))
123 | fake_style_sample = torch.cat([gram_tf_2, gram_tf_3, gram_tf_4], dim=1)
124 |
125 | ## 2. train Discriminator
126 | net_d_style.zero_grad()
127 |
128 | ### 2.1. train D_style on real data
129 | D_R = train_d(net_d_style, real_style_sample, label="real")
130 | ### 2.2. train D_style on fake data
131 | D_F = train_d(net_d_style, fake_style_sample.detach(), label="fake")
132 |
133 | optDS.step()
134 |
135 | ## 3. train Generator
136 | net_g.zero_grad()
137 | ### 3.1. train G so its output is scored as real by the style discriminator
138 | pred_gs = net_d_style(fake_style_sample)
139 | err_gs = -pred_gs.mean()
140 | G_B = torch.sigmoid(pred_gs).mean().item() #+ torch.sigmoid(pred_gc).mean().item()
141 |
142 | err_rec = F.mse_loss(tf_3, target_3)
143 | err_gram = 2000*(
144 | gram_loss(tf_4, sf_4) + \
145 | gram_loss(tf_3, sf_3) + \
146 | gram_loss(tf_2, sf_2))
147 |
148 | G_rec = err_gram.item()
149 |
150 |
151 |
152 | err = err_gs + mse_weight*err_rec + gram_weight*err_gram
153 | err.backward()
154 |
155 | optG.step()
156 |
157 | ## logging ~
158 | loss_values = [D_R, D_F, G_B, G_rec]
159 | for i, term in enumerate(titles):
160 | losses[term] += loss_values[i]
161 |
162 | if n_iter > 0 and n_iter % log_interval == 0:
163 | log_line = ""
164 | for key, value in losses.items():
165 | log_line += "%s: %.5f "%(key, value/log_interval)
166 | losses[key] = 0
167 | print(log_line)
168 |
169 |
170 |
171 | if __name__ == '__main__':
172 |
173 | parser = argparse.ArgumentParser(description='Style transfer GAN: during training, the model learns to take an image from one specific category and transform it into another style domain')
174 | print(os.path.join(os.getcwd(),"art-landscape-rgb-512"))
175 | patha="RGB/"
176 | pathb="Sketch/"
177 | parser.add_argument('--path_a', type=str, default=patha, help='path of the source dataset; should be a folder containing one or more image sub-folders')
178 | parser.add_argument('--path_b', type=str, default=pathb, help='path of the target dataset; should be a folder containing one or more image sub-folders')
179 | parser.add_argument('--im_size', type=int, default=256, help='resolution of the generated images')
180 | parser.add_argument('--trial_name', type=str, default="test2", help='a brief description of the training trial')
181 | parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.')
182 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate; the default of 2e-4 usually works, but a smaller value such as 1e-4 can be tried')
183 | parser.add_argument('--batch_size', type=int, default=4, help='how many images to train together at one iteration')
184 | parser.add_argument('--total_iter', type=int, default=7000, help='total number of training iterations, assuming the initial step is 1')
185 | parser.add_argument('--mse_weight', default=0.2, type=float, help='weight encouraging G to preserve the content of images from set A')
186 | parser.add_argument('--gram_weight', default=1, type=float, help='weight encouraging G to match the style of images from set B')
187 | parser.add_argument('--checkpoint', default='None', type=str, help='specify the path of the pre-trained model')
188 |
189 | args = parser.parse_args()
190 |
191 | print(str(args))
192 |
193 | trial_name = args.trial_name
194 | data_root_A = args.path_a
195 | data_root_B = args.path_b
196 | mse_weight = args.mse_weight
197 | gram_weight = args.gram_weight
198 | max_iteration = args.total_iter
199 | # device = torch.device("cuda:%d"%(args.gpu_id))
200 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
201 | print(torch.cuda.is_available())
202 | # print(torch.cuda.current_device())
203 | im_size = args.im_size
204 | if im_size == 128:
205 | base = 4
206 | elif im_size == 256:
207 | base = 8
208 | elif im_size == 512:
209 | base = 16
210 | else:
211 | raise ValueError("im_size must be one of 128, 256, or 512")
212 |
213 |
214 | log_interval = 100
215 | save_folder = './model'
216 | number_model_to_save = 30
217 |
218 | vgg = VGGSimple()
219 | root_path=os.getcwd()
220 | vgg.load_state_dict(torch.load(os.path.join(root_path,'vgg-feature-weights.pth'), map_location=lambda a,b:a))
221 | vgg.to(device)
222 | vgg.eval()
223 |
224 |
225 |
226 | for p in vgg.parameters():
227 | p.requires_grad = False
228 |
229 | dataset_A = Dataset.ImageFolder(root=data_root_A, transform=trans_maker(args.im_size))
230 | dataloader_A_fixed = DataLoader(dataset_A, 8, shuffle=False, num_workers=0)
231 | dataloader_A = iter(DataLoader(dataset_A, args.batch_size, shuffle=False,\
232 | sampler=InfiniteSamplerWrapper(dataset_A), num_workers=4, pin_memory=False))
233 |
234 | dataset_B = Dataset.ImageFolder(root=data_root_B, transform=trans_maker(args.im_size))
235 | dataloader_B = iter(DataLoader(dataset_B, args.batch_size, shuffle=False,\
236 | sampler=InfiniteSamplerWrapper(dataset_B), num_workers=0, pin_memory=False))
237 |
238 | net_g = Generator(infc=256, nfc=128)
239 | netG_A2B = Generator_UGATIT(image_size=256).to(device)
240 | net_d_style = Discriminator(nfc=128*3, norm_layer=nn.BatchNorm2d)
241 | gram_reshape = Adaptive_pool(128, 16)
242 | # this style discriminator takes as input a 512x512 gram matrix from a 512x8x8 vgg feature;
243 | # the reshaped, pooled input size is 256x16x16
244 |
245 | if args.checkpoint != 'None':
246 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
247 | net_g.load_state_dict(checkpoint['g'])
248 | net_d_style.load_state_dict(checkpoint['ds'])
249 | print("saved model loaded")
250 |
251 | net_d_style.to(device)
252 | net_g.to(device)
253 |
254 | optG = optim.Adam(net_g.parameters(), lr=args.lr, betas=(0.5, 0.99))
255 | optDS = optim.Adam(net_d_style.parameters(), lr=args.lr, betas=(0.5, 0.99))
256 |
257 | if args.checkpoint != 'None':
258 | opt_path = args.checkpoint.replace("_model.pth", "_opt.pth")
259 | try:
260 | opt_weights = torch.load(opt_path, map_location=lambda a, b: a)
261 | optG.load_state_dict(opt_weights['g'])
262 | optDS.load_state_dict(opt_weights['ds'])
263 | print("saved optimizer loaded")
264 | except:
265 | print("no optimizer weights detected, resuming a training without optimizer weights may not let the model converge as desired")
266 | pass
267 |
268 |
269 | train(net_g, net_d_style, max_iteration)
270 |
271 |
--------------------------------------------------------------------------------
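The style objective in `train.py` above matches Gram matrices of VGG features between the generated image and the style image (`gram_loss`, scaled by 2000 and `--gram_weight`), while the content term is an MSE against an AdaIN-aligned target. A quick sanity-check sketch of the Gram loss using random tensors in place of VGG features, assuming `sketch_generation/` is on `PYTHONPATH` and its dependencies are installed:

```
import torch
from train import gram_matrix, gram_loss

# toy stand-ins for VGG feature maps: batch 2, 8 channels, 16x16 spatial
feat_fake = torch.randn(2, 8, 16, 16)
feat_style = torch.randn(2, 8, 16, 16)

G = gram_matrix(feat_fake)                     # (2*8) x (2*8), normalised by a*b*c*d
style_term = gram_loss(feat_fake, feat_style)  # MSE between the two Gram matrices
print(G.shape, style_term.item())

# identical feature maps give exactly zero style loss
assert gram_loss(feat_style, feat_style).item() == 0.0
```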
/sketch_generation/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from copy import deepcopy
4 | from random import shuffle
5 | import torch.nn.functional as F
6 |
7 | def d_hinge_loss(real_pred, fake_pred):
8 | real_loss = F.relu(1-real_pred)
9 | fake_loss = F.relu(1+fake_pred)
10 |
11 | return real_loss.mean() + fake_loss.mean()
12 |
13 |
14 | def g_hinge_loss(pred):
15 | return -pred.mean()
16 |
17 |
18 | class AverageMeter(object):
19 |
20 | def __init__(self):
21 | self.reset()
22 |
23 | def reset(self):
24 | self.val = 0
25 | self.avg = 0
26 | self.sum = 0
27 | self.count = 0
28 |
29 | def update(self, val, n=1):
30 | self.val = val
31 | self.sum += val * n
32 | self.count += n
33 | self.avg = self.sum / self.count
34 |
35 |
36 | def true_randperm(size, device='cuda'):
37 | def unmatched_randperm(size):
38 | l1 = [i for i in range(size)]
39 | l2 = []
40 | for j in range(size):
41 | deleted = False
42 | if j in l1:
43 | deleted = True
44 | del l1[l1.index(j)]
45 | shuffle(l1)
46 | if len(l1) == 0:
47 | return 0, False
48 | l2.append(l1[0])
49 | del l1[0]
50 | if deleted:
51 | l1.append(j)
52 | return l2, True
53 | flag = False
54 | l = torch.zeros(size).long()
55 | while not flag:
56 | l, flag = unmatched_randperm(size)
57 | return torch.LongTensor(l).to(device)
58 |
59 |
60 | def copy_G_params(model):
61 | flatten = deepcopy(list(p.data for p in model.parameters()))
62 | return flatten
63 |
64 |
65 | def load_params(model, new_param):
66 | for p, new_p in zip(model.parameters(), new_param):
67 | p.data.copy_(new_p)
68 |
69 |
70 | def make_folders(save_folder, trial_name):
71 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models'%trial_name)
72 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images'%trial_name)
73 | folders = [os.path.join(save_folder, 'train_results'),
74 | os.path.join(save_folder, 'train_results/%s'%trial_name),
75 | os.path.join(save_folder, 'train_results/%s/images'%trial_name),
76 | os.path.join(save_folder, 'train_results/%s/models'%trial_name)]
77 | for folder in folders:
78 | if not os.path.exists(folder):
79 | os.mkdir(folder)
80 |
81 | from shutil import copy
82 | try:
83 | for f in os.listdir('.'):
84 | if '.py' in f:
85 | copy(f, os.path.join(save_folder, 'train_results/%s'%trial_name)+'/'+f)
86 | except:
87 | pass
88 | return saved_image_folder, saved_model_folder
89 |
90 |
91 |
92 | import cv2
93 | import numpy as np
94 | import math
95 |
96 | #####################
97 | # Both horizontal and vertical
98 | def warp(img, mag=10, freq=100):
99 | rows, cols = img.shape
100 |
101 | img_output = np.zeros(img.shape, dtype=img.dtype)
102 |
103 | for i in range(rows):
104 | for j in range(cols):
105 | offset_x = int(mag * math.sin(2 * 3.14 * i / freq))
106 | offset_y = int(mag * math.cos(2 * 3.14 * j / freq))
107 | if i+offset_y < rows and j+offset_x < cols:
108 | img_output[i,j] = img[(i+offset_y)%rows,(j+offset_x)%cols]
109 | else:
110 | img_output[i,j] = 0
111 |
112 | return img_output
113 |
114 | #img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE)
115 | #img_output = warp(img, mag=10, freq=200)
116 | #cv2.imwrite('Multidirectional_wave.jpg', img_output)
117 |
--------------------------------------------------------------------------------
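`d_hinge_loss` and `g_hinge_loss` in `utils.py` above implement the standard hinge GAN objective: the discriminator is pushed to score real samples above +1 and fakes below -1, while the generator maximises the discriminator's score on fakes. A small illustration with hypothetical logits, assuming `sketch_generation/` is on `PYTHONPATH` (importing `utils` also requires cv2):

```
import torch
from utils import d_hinge_loss, g_hinge_loss

# hypothetical discriminator logits for a batch of 4 real and 4 fake samples
real_pred = torch.randn(4, 1)
fake_pred = torch.randn(4, 1)

d_loss = d_hinge_loss(real_pred, fake_pred)  # mean(relu(1 - real)) + mean(relu(1 + fake))
g_loss = g_hinge_loss(fake_pred)             # -mean(fake logits)
print(d_loss.item(), g_loss.item())
```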
/sketch_generation/vgg-feature-weights.z01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.z01
--------------------------------------------------------------------------------
/sketch_generation/vgg-feature-weights.z02:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.z02
--------------------------------------------------------------------------------
/sketch_generation/vgg-feature-weights.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.zip
--------------------------------------------------------------------------------
/styleme/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from torchvision.models import inception_v3, Inception3
6 | from torchvision.utils import save_image
7 | from torchvision import utils as vutils
8 | from torch.utils.data import DataLoader
9 |
10 | try:
11 | from torchvision.models.utils import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 | import numpy as np
16 | from scipy import linalg
17 | from tqdm import tqdm
18 | import pickle
19 | import os
20 | from utils import true_randperm
21 | from datasets import InfiniteSamplerWrapper
22 |
23 | # Inception weights ported to Pytorch from
24 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
25 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
26 |
27 |
28 | class InceptionV3(nn.Module):
29 | """Pretrained InceptionV3 network returning feature maps"""
30 |
31 | # Index of default block of inception to return,
32 | # corresponds to output of final average pooling
33 | DEFAULT_BLOCK_INDEX = 3
34 |
35 | # Maps feature dimensionality to their output blocks indices
36 | BLOCK_INDEX_BY_DIM = {
37 | 64: 0, # First max pooling features
38 | 192: 1, # Second max pooling features
39 | 768: 2, # Pre-aux classifier features
40 | 2048: 3 # Final average pooling features
41 | }
42 |
43 | def __init__(self,
44 | output_blocks=[DEFAULT_BLOCK_INDEX],
45 | resize_input=True,
46 | normalize_input=True,
47 | requires_grad=False,
48 | use_fid_inception=True):
49 | """Build pretrained InceptionV3
50 | Parameters
51 | ----------
52 | output_blocks : list of int
53 | Indices of blocks to return features of. Possible values are:
54 | - 0: corresponds to output of first max pooling
55 | - 1: corresponds to output of second max pooling
56 | - 2: corresponds to output which is fed to aux classifier
57 | - 3: corresponds to output of final average pooling
58 | resize_input : bool
59 | If true, bilinearly resizes input to width and height 299 before
60 | feeding input to model. As the network without fully connected
61 | layers is fully convolutional, it should be able to handle inputs
62 | of arbitrary size, so resizing might not be strictly needed
63 | normalize_input : bool
64 | If true, scales the input from range (0, 1) to the range the
65 | pretrained Inception network expects, namely (-1, 1)
66 | requires_grad : bool
67 | If true, parameters of the model require gradients. Possibly useful
68 | for finetuning the network
69 | use_fid_inception : bool
70 | If true, uses the pretrained Inception model used in Tensorflow's
71 | FID implementation. If false, uses the pretrained Inception model
72 | available in torchvision. The FID Inception model has different
73 | weights and a slightly different structure from torchvision's
74 | Inception model. If you want to compute FID scores, you are
75 | strongly advised to set this parameter to true to get comparable
76 | results.
77 | """
78 | super(InceptionV3, self).__init__()
79 |
80 | self.resize_input = resize_input
81 | self.normalize_input = normalize_input
82 | self.output_blocks = sorted(output_blocks)
83 | self.last_needed_block = max(output_blocks)
84 |
85 | assert self.last_needed_block <= 3, \
86 | 'Last possible output block index is 3'
87 |
88 | self.blocks = nn.ModuleList()
89 |
90 | if use_fid_inception:
91 | inception = fid_inception_v3()
92 | else:
93 | inception = models.inception_v3(pretrained=True)
94 |
95 | # Block 0: input to maxpool1
96 | block0 = [
97 | inception.Conv2d_1a_3x3,
98 | inception.Conv2d_2a_3x3,
99 | inception.Conv2d_2b_3x3,
100 | nn.MaxPool2d(kernel_size=3, stride=2)
101 | ]
102 | self.blocks.append(nn.Sequential(*block0))
103 |
104 | # Block 1: maxpool1 to maxpool2
105 | if self.last_needed_block >= 1:
106 | block1 = [
107 | inception.Conv2d_3b_1x1,
108 | inception.Conv2d_4a_3x3,
109 | nn.MaxPool2d(kernel_size=3, stride=2)
110 | ]
111 | self.blocks.append(nn.Sequential(*block1))
112 |
113 | # Block 2: maxpool2 to aux classifier
114 | if self.last_needed_block >= 2:
115 | block2 = [
116 | inception.Mixed_5b,
117 | inception.Mixed_5c,
118 | inception.Mixed_5d,
119 | inception.Mixed_6a,
120 | inception.Mixed_6b,
121 | inception.Mixed_6c,
122 | inception.Mixed_6d,
123 | inception.Mixed_6e,
124 | ]
125 | self.blocks.append(nn.Sequential(*block2))
126 |
127 | # Block 3: aux classifier to final avgpool
128 | if self.last_needed_block >= 3:
129 | block3 = [
130 | inception.Mixed_7a,
131 | inception.Mixed_7b,
132 | inception.Mixed_7c,
133 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
134 | ]
135 | self.blocks.append(nn.Sequential(*block3))
136 |
137 | for param in self.parameters():
138 | param.requires_grad = requires_grad
139 |
140 | def forward(self, inp):
141 | """Get Inception feature maps
142 | Parameters
143 | ----------
144 | inp : torch.autograd.Variable
145 | Input tensor of shape Bx3xHxW. Values are expected to be in
146 | range (0, 1)
147 | Returns
148 | -------
149 | List of torch.autograd.Variable, corresponding to the selected output
150 | block, sorted ascending by index
151 | """
152 | outp = []
153 | x = inp
154 |
155 | if self.resize_input:
156 | x = F.interpolate(x,
157 | size=(299, 299),
158 | mode='bilinear',
159 | align_corners=False)
160 |
161 | if self.normalize_input:
162 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
163 |
164 | for idx, block in enumerate(self.blocks):
165 | x = block(x)
166 | if idx in self.output_blocks:
167 | outp.append(x)
168 |
169 | if idx == self.last_needed_block:
170 | break
171 |
172 | return outp
173 |
174 |
175 | def fid_inception_v3():
176 | """Build pretrained Inception model for FID computation
177 | The Inception model for FID computation uses a different set of weights
178 | and has a slightly different structure than torchvision's Inception.
179 | This method first constructs torchvision's Inception and then patches the
180 | necessary parts that are different in the FID Inception model.
181 | """
182 | inception = models.inception_v3(num_classes=1008,
183 | aux_logits=False,
184 | pretrained=False)
185 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
186 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
187 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
188 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
189 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
190 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
191 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
192 | inception.Mixed_7b = FIDInceptionE_1(1280)
193 | inception.Mixed_7c = FIDInceptionE_2(2048)
194 |
195 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
196 | inception.load_state_dict(state_dict)
197 | return inception
198 |
199 |
200 | class FIDInceptionA(models.inception.InceptionA):
201 | """InceptionA block patched for FID computation"""
202 |
203 | def __init__(self, in_channels, pool_features):
204 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
205 |
206 | def forward(self, x):
207 | branch1x1 = self.branch1x1(x)
208 |
209 | branch5x5 = self.branch5x5_1(x)
210 | branch5x5 = self.branch5x5_2(branch5x5)
211 |
212 | branch3x3dbl = self.branch3x3dbl_1(x)
213 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
214 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
215 |
216 | # Patch: Tensorflow's average pool does not use the padded zeros in
217 | # its average calculation
218 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
219 | count_include_pad=False)
220 | branch_pool = self.branch_pool(branch_pool)
221 |
222 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
223 | return torch.cat(outputs, 1)
224 |
225 |
226 | class FIDInceptionC(models.inception.InceptionC):
227 | """InceptionC block patched for FID computation"""
228 |
229 | def __init__(self, in_channels, channels_7x7):
230 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
231 |
232 | def forward(self, x):
233 | branch1x1 = self.branch1x1(x)
234 |
235 | branch7x7 = self.branch7x7_1(x)
236 | branch7x7 = self.branch7x7_2(branch7x7)
237 | branch7x7 = self.branch7x7_3(branch7x7)
238 |
239 | branch7x7dbl = self.branch7x7dbl_1(x)
240 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
241 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
242 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
243 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
244 |
245 | # Patch: Tensorflow's average pool does not use the padded zeros in
246 | # its average calculation
247 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
248 | count_include_pad=False)
249 | branch_pool = self.branch_pool(branch_pool)
250 |
251 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
252 | return torch.cat(outputs, 1)
253 |
254 |
255 | class FIDInceptionE_1(models.inception.InceptionE):
256 | """First InceptionE block patched for FID computation"""
257 |
258 | def __init__(self, in_channels):
259 | super(FIDInceptionE_1, self).__init__(in_channels)
260 |
261 | def forward(self, x):
262 | branch1x1 = self.branch1x1(x)
263 |
264 | branch3x3 = self.branch3x3_1(x)
265 | branch3x3 = [
266 | self.branch3x3_2a(branch3x3),
267 | self.branch3x3_2b(branch3x3),
268 | ]
269 | branch3x3 = torch.cat(branch3x3, 1)
270 |
271 | branch3x3dbl = self.branch3x3dbl_1(x)
272 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
273 | branch3x3dbl = [
274 | self.branch3x3dbl_3a(branch3x3dbl),
275 | self.branch3x3dbl_3b(branch3x3dbl),
276 | ]
277 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
278 |
279 | # Patch: Tensorflow's average pool does not use the padded zeros in
280 | # its average calculation
281 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
282 | count_include_pad=False)
283 | branch_pool = self.branch_pool(branch_pool)
284 |
285 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
286 | return torch.cat(outputs, 1)
287 |
288 |
289 | class FIDInceptionE_2(models.inception.InceptionE):
290 | """Second InceptionE block patched for FID computation"""
291 |
292 | def __init__(self, in_channels):
293 | super(FIDInceptionE_2, self).__init__(in_channels)
294 |
295 | def forward(self, x):
296 | branch1x1 = self.branch1x1(x)
297 |
298 | branch3x3 = self.branch3x3_1(x)
299 | branch3x3 = [
300 | self.branch3x3_2a(branch3x3),
301 | self.branch3x3_2b(branch3x3),
302 | ]
303 | branch3x3 = torch.cat(branch3x3, 1)
304 |
305 | branch3x3dbl = self.branch3x3dbl_1(x)
306 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
307 | branch3x3dbl = [
308 | self.branch3x3dbl_3a(branch3x3dbl),
309 | self.branch3x3dbl_3b(branch3x3dbl),
310 | ]
311 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
312 |
313 | # Patch: The FID Inception model uses max pooling instead of average
314 | # pooling. This is likely an error in this specific Inception
315 | # implementation, as other Inception models use average pooling here
316 | # (which matches the description in the paper).
317 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
318 | branch_pool = self.branch_pool(branch_pool)
319 |
320 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
321 | return torch.cat(outputs, 1)
322 |
323 |
324 | class Inception3Feature(Inception3):
325 | def forward(self, x):
326 | if x.shape[2] != 299 or x.shape[3] != 299:
327 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
328 |
329 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
330 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
331 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
332 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
333 |
334 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
335 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
336 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
337 |
338 | x = self.Mixed_5b(x) # 35 x 35 x 192
339 | x = self.Mixed_5c(x) # 35 x 35 x 256
340 | x = self.Mixed_5d(x) # 35 x 35 x 288
341 |
342 | x = self.Mixed_6a(x) # 35 x 35 x 288
343 | x = self.Mixed_6b(x) # 17 x 17 x 768
344 | x = self.Mixed_6c(x) # 17 x 17 x 768
345 | x = self.Mixed_6d(x) # 17 x 17 x 768
346 | x = self.Mixed_6e(x) # 17 x 17 x 768
347 |
348 | x = self.Mixed_7a(x) # 17 x 17 x 768
349 | x = self.Mixed_7b(x) # 8 x 8 x 1280
350 | x = self.Mixed_7c(x) # 8 x 8 x 2048
351 |
352 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
353 |
354 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
355 |
356 |
357 | def load_patched_inception_v3():
358 | # inception = inception_v3(pretrained=True)
359 | # inception_feat = Inception3Feature()
360 | # inception_feat.load_state_dict(inception.state_dict())
361 | inception_feat = InceptionV3([3], normalize_input=False)
362 |
363 | return inception_feat
364 |
365 |
366 | @torch.no_grad()
367 | def extract_features(loader, inception, device):
368 | pbar = tqdm(loader)
369 |
370 | feature_list = []
371 |
372 | for img in pbar:
373 | img = img.to(device)
374 | feature = inception(img)[0].view(img.shape[0], -1)
375 | feature_list.append(feature.to('cpu'))
376 |
377 | features = torch.cat(feature_list, 0)
378 |
379 | return features
380 |
381 |
382 | @torch.no_grad()
383 | def extract_feature_from_generator_fn(generator_fn, inception, device='cuda', total=1000):
384 | features = []
385 |
386 | for batch in tqdm(generator_fn, total=total):
387 | try:
388 | feat = inception(batch)[0].view(batch.shape[0], -1)
389 | features.append(feat.to('cpu'))
390 | except:
391 | break
392 | features = torch.cat(features, 0).detach()
393 | return features.numpy()
394 |
395 |
396 | def calc_fid(sample_features, real_features=None, real_mean=None, real_cov=None, eps=1e-6):
397 | sample_mean = np.mean(sample_features, 0)
398 | sample_cov = np.cov(sample_features, rowvar=False)
399 |
400 | if real_features is not None:
401 | real_mean = np.mean(real_features, 0)
402 | real_cov = np.cov(real_features, rowvar=False)
403 |
404 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
405 |
406 | if not np.isfinite(cov_sqrt).all():
407 | print('product of cov matrices is singular')
408 | offset = np.eye(sample_cov.shape[0]) * eps
409 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
410 |
411 | if np.iscomplexobj(cov_sqrt):
412 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
413 | m = np.max(np.abs(cov_sqrt.imag))
414 |
415 | raise ValueError(f'Imaginary component {m}')
416 |
417 | cov_sqrt = cov_sqrt.real
418 |
419 | mean_diff = sample_mean - real_mean
420 | mean_norm = mean_diff @ mean_diff
421 |
422 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
423 |
424 | fid = mean_norm + trace
425 |
426 | return fid
427 |
428 |
429 | def real_image_loader(dataloader, n_batches=10):
430 | counter = 0
431 | while counter < n_batches:
432 | counter += 1
433 | rgb_img = next(dataloader)[0]
434 | if counter == 1:
435 | vutils.save_image(0.5 * (rgb_img + 1), './checkpoint/tmp_real.jpg')
436 | yield rgb_img.cuda()
437 |
438 |
439 | @torch.no_grad()
440 | def image_generator(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500):
441 | counter = 0
442 | dataloader = iter(
443 | DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=4, pin_memory=False))
444 | n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1)
445 | while counter < n_batches:
446 | counter += 1
447 | rgb_img, skt_img = next(dataloader)
448 | rgb_img = F.interpolate(rgb_img, size=256).cuda()
449 | skt_img = F.interpolate(skt_img, size=256).cuda()
450 |
451 | gimg_ae, style_feat = net_ae(skt_img, rgb_img)
452 | # g_image = gimg_ae
453 | g_image = net_ig(gimg_ae, style_feat)
454 | if counter == 1:
455 | vutils.save_image(0.5 * (g_image + 1), './checkpoint/tmp.jpg')
456 | yield g_image
457 |
458 |
459 | @torch.no_grad()
460 | def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500):
461 | counter = 0
462 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False))
463 | n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1)
464 | while counter < n_batches:
465 | counter += 1
466 | rgb_img, skt_img = next(dataloader)
467 | rgb_img = F.interpolate(rgb_img, size=256).cuda()
468 | skt_img = F.interpolate(skt_img, size=256).cuda()
469 |
470 | perm = true_randperm(rgb_img.shape[0], device=rgb_img.device)
471 |
472 | gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm])
473 | # g_image = gimg_ae
474 | g_image = net_ig(gimg_ae, style_feat)
475 | if counter == 1:
476 | vutils.save_image(0.5 * (g_image + 1), './checkpoint/tmp.jpg')
477 | yield g_image
478 |
479 |
480 | if __name__ == "__main__":
481 | from utils import PairedMultiDataset, InfiniteSamplerWrapper, make_folders, AverageMeter
482 | from torch.utils.data import DataLoader
483 | from torchvision import utils as vutils
484 |
485 | IM_SIZE = 1024
486 | BATCH_SIZE = 8
487 | DATALOADER_WORKERS = 8
488 | NBR_CLS = 2000
489 | TRIAL_NAME = 'trial_vae_512_1'
490 | SAVE_FOLDER = './'
491 |
492 | data_root_colorful = '../images/celebA/CelebA_512_test/img'
493 | data_root_sketch_1 = './sketch_simplification/vggadin_iter_700_test'
494 | data_root_sketch_2 = './sketch_simplification/vggadin_iter_1900_test'
495 | data_root_sketch_3 = './sketch_simplification/vggadin_iter_2300_test'
496 |
497 | dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3,
498 | im_size=IM_SIZE, rand_crop=False)
499 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=DATALOADER_WORKERS, pin_memory=True))
500 |
501 | from models import StyleEncoder, ContentEncoder, Decoder
502 | import pickle
503 | from models import AE, RefineGenerator
504 | from utils import load_params
505 |
506 | net_ig = RefineGenerator().cuda()
507 | net_ig = nn.DataParallel(net_ig)
508 |
509 | ckpt = './train_results/trial_refine_ae_as_gan_1024_2/models/4.pth'
510 | if ckpt is not None:
511 | ckpt = torch.load(ckpt)
512 | # net_ig.load_state_dict(ckpt['ig'])
513 | # net_id.load_state_dict(ckpt['id'])
514 | net_ig_ema = ckpt['ig_ema']
515 | load_params(net_ig, net_ig_ema)
516 | net_ig = net_ig.module
517 | # net_ig.eval()
518 |
519 | net_ae = AE()
520 | net_ae.load_state_dicts('./train_results/trial_vae_512_1/models/176000.pth')
521 | net_ae.cuda()
522 | net_ae.eval()
523 |
524 | inception = load_patched_inception_v3().cuda()
525 | inception.eval()
526 |
527 | '''
528 | real_features = extract_feature_from_generator_fn(
529 | real_image_loader(dataloader, n_batches=1000), inception )
530 | real_mean = np.mean(real_features, 0)
531 | real_cov = np.cov(real_features, rowvar=False)
532 | '''
533 | # pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov}, open('celeba_fid_feats.npy','wb') )
534 |
535 | real_features = pickle.load(open('celeba_fid_feats.npy', 'rb'))
536 | real_mean = real_features['mean']
537 | real_cov = real_features['cov']
538 | # sample_features = extract_feature_from_generator_fn( real_image_loader(dataloader, n_batches=100), inception )
539 | for it in range(1):
540 | itx = it * 8000
541 | '''
542 | ckpt = torch.load('./train_results/%s/models/%d.pth'%(TRIAL_NAME, itx))
543 |
544 | style_encoder.load_state_dict(ckpt['e'])
545 | content_encoder.load_state_dict(ckpt['c'])
546 | decoder.load_state_dict(ckpt['d'])
547 |
548 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True))
549 | '''
550 |
551 | sample_features = extract_feature_from_generator_fn(
552 | image_generator(dataset, net_ae, net_ig, n_batches=1800), inception,
553 | total=1800)
554 |
555 | # fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov'])
556 | fid = calc_fid(sample_features, real_mean=real_mean, real_cov=real_cov)
557 |
558 | print(it, fid)
559 |
--------------------------------------------------------------------------------
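`calc_fid` in `benchmark.py` above computes the Fréchet distance between Gaussians fitted to two feature sets, ||μ_s − μ_r||² + Tr(Σ_s + Σ_r − 2(Σ_sΣ_r)^½). A quick numerical sanity check with synthetic low-dimensional features standing in for Inception activations, assuming `styleme/` is on `PYTHONPATH`:

```
import numpy as np
from benchmark import calc_fid

# synthetic 64-d "features" (kept small so the sample covariances are well conditioned)
rng = np.random.default_rng(0)
real_feats = rng.normal(0.0, 1.0, size=(2000, 64))
shifted_feats = rng.normal(0.5, 1.0, size=(2000, 64))

# identical feature sets give a FID near 0; a mean shift of 0.5 in every
# dimension raises it to roughly 64 * 0.5**2 = 16
print(calc_fid(real_feats, real_features=real_feats))
print(calc_fid(shifted_feats, real_features=real_feats))
```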
/styleme/calculate.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # calculate FID and LPIPS #
3 | ######################################
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | from tqdm import tqdm
10 | from datasets import PairedDataset, InfiniteSamplerWrapper
11 | from utils import AverageMeter
12 |
13 |
14 | def calculate_Lpips(data_root_colorful, data_root_sketch, model):
15 | import lpips
16 | from models import AE
17 | from models import RefineGenerator as Generator
18 |
19 | CHANNEL = 32
20 | NBR_CLS = 50
21 | IM_SIZE = 256
22 | BATCH_SIZE = 6
23 | DATALOADER_WORKERS = 2
24 |
25 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)
26 |
27 | # load dataset
28 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE)
29 | print('the dataset contains %d images.' % len(dataset))
30 |
31 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset),
32 | num_workers=DATALOADER_WORKERS, pin_memory=True))
33 |
34 | # load ae model
35 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS)
36 | net_ae.style_encoder.reset_cls()
37 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE)
38 |
39 | PRETRAINED_PATH = './checkpoint/GAN.pth'.format(str(model))
40 | print('Pre-trained path : ', PRETRAINED_PATH)
41 | ckpt = torch.load(PRETRAINED_PATH)
42 |
43 | net_ae.load_state_dict(ckpt['ae'])
44 | net_ig.load_state_dict(ckpt['ig'])
45 |
46 | net_ae.cuda()
47 | net_ig.cuda()
48 | net_ae.eval()
49 | net_ig.eval()
50 |
51 | # lpips
52 | get_lpips = AverageMeter()
53 | lpips_list = []
54 |
55 | # Network
56 | for iter_data in tqdm(range(1000)):
57 | rgb_img, skt_img = next(dataloader)
58 |
59 | rgb_img = rgb_img.cuda()
60 | skt_img = skt_img.cuda()
61 |
62 | gimg_ae, style_feats = net_ae(skt_img, rgb_img)
63 | g_image = net_ig(gimg_ae, style_feats)
64 |
65 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256),
66 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
67 | get_lpips.update(loss_mse.item() / BATCH_SIZE, BATCH_SIZE)
68 |
69 | lpips_list.append(get_lpips.avg)
70 |
71 | if (iter_data + 1) % 100 == 0:
72 | # print('avg : ', get_lpips.avg)
73 | print('LPIPS : ', sum(lpips_list) / len(lpips_list))
74 |
75 | print('LPIPS : ', sum(lpips_list) / len(lpips_list))
76 |
77 |
78 | def calculate_fid(data_root_colorful, data_root_sketch, model):
79 | from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, \
80 | image_generator, image_generator_perm
81 | from models import AE
82 | from models import RefineGenerator as Generator
83 | import numpy as np
84 |
85 | CHANNEL = 32
86 | NBR_CLS = 50
87 | IM_SIZE = 256
88 | BATCH_SIZE = 8
89 | DATALOADER_WORKERS = 2
90 | fid_batch_images = 119
91 | fid_iters = 100
92 | inception = load_patched_inception_v3().cuda()
93 | inception.eval()
94 |
95 | fid = []
96 | fid_perm = []
97 |
98 | # load dataset
99 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE)
100 | print('the dataset contains %d images.' % len(dataset))
101 |
102 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset),
103 | num_workers=DATALOADER_WORKERS, pin_memory=True))
104 |
105 | # load ae model
106 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS)
107 | net_ae.style_encoder.reset_cls()
108 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE)
109 |
110 | PRETRAINED_PATH = './checkpoint/GAN.pth'.format(str(model))
111 | print('Pre-trained path : ', PRETRAINED_PATH)
112 | ckpt = torch.load(PRETRAINED_PATH)
113 |
114 | net_ae.load_state_dict(ckpt['ae'])
115 | net_ig.load_state_dict(ckpt['ig'])
116 |
117 | net_ae.cuda()
118 | net_ig.cuda()
119 | net_ae.eval()
120 | net_ig.eval()
121 |
122 | print("calculating FID ...")
123 |
124 | real_features = extract_feature_from_generator_fn(
125 | real_image_loader(dataloader, n_batches=fid_batch_images), inception)
126 | real_mean = np.mean(real_features, 0)
127 | real_cov = np.cov(real_features, rowvar=False)
128 | real_features = {'feats': real_features, 'mean': real_mean, 'cov': real_cov}
129 |
130 | for iter_fid in range(fid_iters):
131 | sample_features = extract_feature_from_generator_fn(
132 | image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images),
133 | inception, total=fid_batch_images // BATCH_SIZE - 1)
134 | cur_fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov'])
135 |
136 | sample_features_perm = extract_feature_from_generator_fn(
137 | image_generator_perm(dataset, net_ae, net_ig, n_batches=fid_batch_images),
138 | inception, total=fid_batch_images // BATCH_SIZE - 1)
139 | cur_fid_perm = calc_fid(sample_features_perm, real_mean=real_features['mean'],
140 | real_cov=real_features['cov'])
141 |
142 | print('FID[{}]: '.format(iter_fid), [cur_fid, cur_fid_perm])
143 | fid.append(cur_fid)
144 | fid_perm.append(cur_fid_perm)
145 |
146 | print('FID: ', sum(fid) / len(fid))
147 | print('FID perm: ', sum(fid_perm) / len(fid_perm))
148 |
149 |
150 | if __name__ == "__main__":
151 | model = 'styleme'
152 | data_root_colorful = './train_data/rgb/'
153 | data_root_sketch = './train_data/sketch/'
154 | # data_root_colorful = './train_data/comparison/rgb/'
155 | # data_root_sketch = './train_data/comparison/sketch_styleme/'
156 | # data_root_sketch = './train_data/comparison/sketch_cam/'
157 | # data_root_sketch = './train_data/comparison/sketch_adalin/'
158 | # data_root_sketch = './train_data/comparison/sketch_wo_camada/'
159 |
160 | calculate_Lpips(data_root_colorful, data_root_sketch, model)
161 | # calculate_fid(data_root_colorful, data_root_sketch, model)
162 | # recorded results: model | LPIPS | FID
163 | # styleme | 0.13515148047968645 | 16.034930465842525
164 | # styleme_wo | 0.4334833870760152 | 32.5567679015783
165 | # cam | 0.1373054370310368 | 17.165196809300138
166 | # adalin | 0.31896749291615123 | 28.387120218137913
167 | # camada | 0.36015568705948886 | 29.75984833745646
168 |
--------------------------------------------------------------------------------
/styleme/config.py:
--------------------------------------------------------------------------------
1 | #################################
2 | # training parameter #
3 | #################################
4 |
5 | DATALOADER_WORKERS = 2
6 | NBR_CLS = 50
7 |
8 | EPOCH_GAN = 100
9 | ITERATION_GAN = 2000
10 |
11 | SAVE_IMAGE_INTERVAL = 100
12 | SAVE_MODEL_INTERVAL = 200
13 | LOG_INTERVAL = 200
14 | FID_INTERVAL = 100
15 | FID_BATCH_NBR = 100
16 |
17 | ITERATION_AE = 20000
18 |
19 | CHANNEL = 32
20 | MULTI_GPU = False
21 |
22 | IM_SIZE_GAN = 256
23 | BATCH_SIZE_GAN = 8
24 |
25 | IM_SIZE_AE = 256
26 | BATCH_SIZE_AE = 8
27 |
28 | SAVE_FOLDER = './checkpoint/'
29 |
30 | # PRETRAINED_AE_PATH = './checkpoint/models/AE_20000.pth'
31 | PRETRAINED_AE_PATH = None
32 |
33 | # GAN_CKECKPOINT = './checkpoint/models/9.pth'
34 | GAN_CKECKPOINT = None
35 |
36 | TRAIN_AE_ONLY = False
37 | TRAIN_GAN_ONLY = False
38 |
39 | data_root_colorful = './train_data/rgb/'
40 | data_root_sketch = './train_data/sketch_styleme/'
41 | # data_root_sketch = './train_data/sketchgen_wo_cam/'
42 | # data_root_sketch = './train_data/sketchgen_wo_adalin/'
43 | # data_root_sketch = './train_data/sketchgen_wo_camada/'
44 |
--------------------------------------------------------------------------------
/styleme/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image, ImageFilter
5 | from PIL import ImageFile
6 |
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | import torch.utils.data as data
10 |
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 |
13 |
14 | def _rescale(img):
15 | return img * 2.0 - 1.0
16 |
17 |
18 | def transform_data(im_size=256):
19 | trans = transforms.Compose([
20 | transforms.Resize((im_size, im_size)),
21 | transforms.ToTensor(),
22 | _rescale
23 | ])
24 | return trans
25 |
26 |
27 | class TransformData(Dataset):
28 | def __init__(self, data_rgb, data_sketch, im_size=256, nbr_cls=100):
29 | super(TransformData, self).__init__()
30 | self.rgb_root = data_rgb
31 | self.skt_root = data_sketch
32 |
33 | self.frame = self._parse_frame()
34 | random.shuffle(self.frame)
35 |
36 | self.nbr_cls = nbr_cls
37 | self.set_offset = 0
38 | self.im_size = im_size
39 |
40 | self.transform = transforms.Compose([
41 | transforms.Resize((im_size, im_size)),
42 | transforms.ToTensor(),
43 | _rescale
44 | ])
45 |
46 | self.transform_rd = transforms.Compose([
47 | transforms.Resize((int(im_size * 1.3), int(im_size * 1.3))),
48 | transforms.RandomCrop((int(im_size), int(im_size))),
49 | transforms.RandomRotation(30),
50 | transforms.RandomHorizontalFlip(p=1),
51 | transforms.Resize((im_size, im_size)),
52 | transforms.ToTensor(),
53 | _rescale
54 | ])
55 |
56 | self.transform_flip = transforms.Compose([
57 | transforms.RandomHorizontalFlip(p=0.8),
58 | transforms.RandomVerticalFlip(p=0.8),
59 | transforms.Resize((im_size, im_size)),
60 | transforms.ToTensor(),
61 | _rescale
62 | ])
63 |
64 | self.transform_erase = transforms.Compose([
65 | transforms.Resize((im_size, im_size)),
66 | transforms.ToTensor(),
67 | _rescale,
68 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1),
69 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1),
70 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1)])
71 |
72 | self.transform_bold = transforms.Compose([
73 | transforms.Resize((int(im_size * 1.1), int(im_size * 1.1))),
74 | transforms.Resize((im_size, im_size)),
75 | transforms.ToTensor(),
76 | _rescale
77 | ])
78 |
79 | def _parse_frame(self):
80 | frame = []
81 | img_names = os.listdir(self.rgb_root)
82 | img_names.sort()
83 | for i in range(len(img_names)):
84 | img_name = img_names[i].zfill(len(str(len(img_names))))
85 | rgb_path = os.path.join(self.rgb_root, img_name)
86 | skt_path = os.path.join(self.skt_root, img_name)
87 | if os.path.exists(rgb_path) and os.path.exists(skt_path):
88 | frame.append((rgb_path, skt_path))
89 |
90 | return frame
91 |
92 | def __len__(self):
93 | return self.nbr_cls
94 |
95 | def _next_set(self):
96 | self.set_offset += self.nbr_cls
97 | if self.set_offset > (len(self.frame) - self.nbr_cls):
98 | random.shuffle(self.frame)
99 | self.set_offset = 0
100 |
101 | def __getitem__(self, idx):
102 | file, skt_path = self.frame[idx + self.set_offset]
103 | rgb = Image.open(file).convert('RGB')
104 | skt = Image.open(skt_path).convert('L')
105 |
106 | img_normal = self.transform(rgb)
107 | img_rd = self.transform_rd(rgb)
108 | img_flip = self.transform_flip(rgb)
109 |
110 | skt_normal = self.transform(skt)
111 | skt_erase = self.transform_erase(skt)
112 | bold_factor = 3
113 | skt_bold = skt.filter(ImageFilter.MinFilter(size=bold_factor))
114 | skt_bold = self.transform_bold(skt_bold)
115 |
116 | return img_normal, img_rd, img_flip, skt_normal, skt_erase, skt_bold, idx
117 |
118 |
119 | def InfiniteSampler(n):
120 | i = n - 1
121 | order = np.random.permutation(n)
122 | while True:
123 | yield order[i]
124 | i += 1
125 | if i >= n:
126 | np.random.seed()
127 | order = np.random.permutation(n)
128 | i = 0
129 |
130 |
131 | class InfiniteSamplerWrapper(data.sampler.Sampler):
132 | def __init__(self, data_source):
133 | self.num_samples = len(data_source)
134 |
135 | def __iter__(self):
136 | return iter(InfiniteSampler(self.num_samples))
137 |
138 | def __len__(self):
139 | return 2 ** 31
140 |
141 |
142 | class PairedDataset(Dataset):
143 | def __init__(self, data_root_1, data_root_2, im_size=256):
144 | super(PairedDataset, self).__init__()
145 | self.root_a = data_root_1
146 | self.root_b = data_root_2
147 |
148 | self.frame = self._parse_frame()
149 | self.transform = transform_data(im_size)
150 |
151 | def _parse_frame(self):
152 | frame = []
153 | img_names = os.listdir(self.root_a)
154 | img_names.sort()
155 | for i in range(len(img_names)):
156 | img_name = '%s.jpg' % str(i).zfill(len(str(len(img_names))))
157 | image_a_path = os.path.join(self.root_a, img_names[i])
158 | if ('.jpg' in image_a_path) or ('.png' in image_a_path):
159 | image_b_path = os.path.join(self.root_b, img_name)
160 | if os.path.exists(image_b_path):
161 | frame.append((image_a_path, image_b_path))
162 |
163 | return frame
164 |
165 | def __len__(self):
166 | return len(self.frame)
167 |
168 | def __getitem__(self, idx):
169 | file_a, file_b = self.frame[idx]
170 | img_a = Image.open(file_a).convert('RGB')
171 | img_b = Image.open(file_b).convert('L')
172 |
173 | if self.transform:
174 | img_a = self.transform(img_a)
175 | img_b = self.transform(img_b)
176 |
177 | return (img_a, img_b)
178 |
179 |
180 | class ImageFolder(Dataset):
181 | def __init__(self, data_root, transform=transform_data(256)):
182 | super(ImageFolder, self).__init__()
183 | self.root = data_root
184 |
185 | self.frame = self._parse_frame()
186 | self.transform = transform
187 |
188 | def _parse_frame(self):
189 | frame = []
190 | img_names = os.listdir(self.root)
191 | img_names.sort()
192 | for i in range(len(img_names)):
193 | image_path = os.path.join(self.root, img_names[i])
194 | if ('.jpg' in image_path) or ('.png' in image_path):
195 | frame.append(image_path)
196 |
197 | return frame
198 |
199 | def __len__(self):
200 | return len(self.frame)
201 |
202 | def __getitem__(self, idx):
203 | file = self.frame[idx]
204 | img = Image.open(file).convert('RGB')
205 |
206 | if self.transform:
207 | img = self.transform(img)
208 | return img
209 |
--------------------------------------------------------------------------------
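`PairedDataset` in `datasets.py` above pairs the i-th RGB file in one folder with a sketch named by its zero-padded index in the other, resizes both to `im_size`, and rescales them to [-1, 1]; wrapped with `InfiniteSamplerWrapper`, it yields an endless stream of (RGB, sketch) batches. A minimal sketch, assuming the `./train_data/` layout used by `calculate.py`:

```
from torch.utils.data import DataLoader
from datasets import PairedDataset, InfiniteSamplerWrapper

dataset = PairedDataset('./train_data/rgb/', './train_data/sketch/', im_size=256)
loader = iter(DataLoader(dataset, batch_size=8,
                         sampler=InfiniteSamplerWrapper(dataset),
                         num_workers=2, pin_memory=True))

# each RGB batch is (8, 3, 256, 256); each sketch batch is (8, 1, 256, 256)
rgb_batch, skt_batch = next(loader)
```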
/styleme/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/framework.png
--------------------------------------------------------------------------------
/styleme/generate_matrix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision import utils as vutils
4 |
5 |
6 | def make_matrix(dataset_rgb, dataset_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name):
7 | dataloader_rgb = iter(DataLoader(dataset_rgb, BATCH_SIZE, shuffle=True))
8 | dataloader_skt = iter(DataLoader(dataset_skt, BATCH_SIZE, shuffle=True))
9 |
10 | rgb_img = next(dataloader_rgb)
11 | skt_img = next(dataloader_skt)
12 |
13 | skt_img = skt_img.mean(dim=1, keepdim=True)
14 |
15 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)]
16 | image_matrix.append(rgb_img.clone())
17 | with torch.no_grad():
18 | rgb_img = rgb_img.cuda()
19 | for skt in skt_img:
20 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda()
21 |
22 | gimg_ae, style_feats = net_ae(input_skts, rgb_img)
23 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone())
24 | image_matrix.append(gimg_ae.cpu())
25 |
26 | g_images = net_ig(gimg_ae, style_feats).cpu()
27 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1))
28 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE))
29 |
30 | image_matrix = torch.cat(image_matrix)
31 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1)
32 |
--------------------------------------------------------------------------------
/styleme/lpips/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | from skimage.metrics import structural_similarity as compare_ssim
7 | import torch
8 | from torch.autograd import Variable
9 |
10 | from lpips import dist_model
11 |
12 |
13 | class PerceptualLoss(torch.nn.Module):
14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True,
15 | gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
17 | super(PerceptualLoss, self).__init__()
18 | print('Setting up Perceptual loss...')
19 | self.use_gpu = use_gpu
20 | self.spatial = spatial
21 | self.gpu_ids = gpu_ids
22 | self.model = dist_model.DistModel()
23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial,
24 | gpu_ids=gpu_ids)
25 | print('...[%s] initialized' % self.model.name())
26 | print('...Done')
27 |
28 | def forward(self, pred, target, normalize=False):
29 | """
30 | Pred and target are Variables.
31 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
32 | If normalize is False, assumes the images are already between [-1,+1]
33 |
34 | Inputs pred and target are Nx3xHxW
35 | Output pytorch Variable N long
36 | """
37 |
38 | if normalize:
39 | target = 2 * target - 1
40 | pred = 2 * pred - 1
41 |
42 | return self.model.forward(target, pred)
43 |
44 |
45 | def normalize_tensor(in_feat, eps=1e-10):
46 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
47 | return in_feat / (norm_factor + eps)
48 |
49 |
50 | def l2(p0, p1, range=255.):
51 | return .5 * np.mean((p0 / range - p1 / range) ** 2)
52 |
53 |
54 | def psnr(p0, p1, peak=255.):
55 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
56 |
57 |
58 | def dssim(p0, p1, range=255.):
59 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
60 |
61 |
62 | def rgb2lab(in_img, mean_cent=False):
63 | from skimage import color
64 | img_lab = color.rgb2lab(in_img)
65 | if (mean_cent):
66 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
67 | return img_lab
68 |
69 |
70 | def tensor2np(tensor_obj):
71 | # change dimension of a tensor object into a numpy array
72 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
73 |
74 |
75 | def np2tensor(np_obj):
76 | # change dimension of an np array into a tensor
77 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
78 |
79 |
80 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
81 | # image tensor to lab tensor
82 | from skimage import color
83 |
84 | img = tensor2im(image_tensor)
85 | img_lab = color.rgb2lab(img)
86 | if (mc_only):
87 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
88 | if (to_norm and not mc_only):
89 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
90 | img_lab = img_lab / 100.
91 |
92 | return np2tensor(img_lab)
93 |
94 |
95 | def tensorlab2tensor(lab_tensor, return_inbnd=False):
96 | from skimage import color
97 | import warnings
98 | warnings.filterwarnings("ignore")
99 |
100 | lab = tensor2np(lab_tensor) * 100.
101 | lab[:, :, 0] = lab[:, :, 0] + 50
102 |
103 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
104 | if (return_inbnd):
105 | # convert back to lab, see if we match
106 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
107 | mask = 1. * np.isclose(lab_back, lab, atol=2.)
108 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
109 | return (im2tensor(rgb_back), mask)
110 | else:
111 | return im2tensor(rgb_back)
112 |
113 |
114 | def rgb2lab(input):
115 | from skimage import color
116 | return color.rgb2lab(input / 255.)
117 |
118 |
119 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
120 | image_numpy = image_tensor[0].cpu().float().numpy()
121 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
122 | return image_numpy.astype(imtype)
123 |
124 |
125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
126 | return torch.Tensor((image / factor - cent)
127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128 |
129 |
130 | def tensor2vec(vector_tensor):
131 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
132 |
133 |
134 | def voc_ap(rec, prec, use_07_metric=False):
135 | """ ap = voc_ap(rec, prec, [use_07_metric])
136 | Compute VOC AP given precision and recall.
137 | If use_07_metric is true, uses the
138 | VOC 07 11 point method (default:False).
139 | """
140 | if use_07_metric:
141 | # 11 point metric
142 | ap = 0.
143 | for t in np.arange(0., 1.1, 0.1):
144 | if np.sum(rec >= t) == 0:
145 | p = 0
146 | else:
147 | p = np.max(prec[rec >= t])
148 | ap = ap + p / 11.
149 | else:
150 | # correct AP calculation
151 | # first append sentinel values at the end
152 | mrec = np.concatenate(([0.], rec, [1.]))
153 | mpre = np.concatenate(([0.], prec, [0.]))
154 |
155 | # compute the precision envelope
156 | for i in range(mpre.size - 1, 0, -1):
157 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
158 |
159 | # to calculate area under PR curve, look for points
160 | # where X axis (recall) changes value
161 | i = np.where(mrec[1:] != mrec[:-1])[0]
162 |
163 | # and sum (\Delta recall) * prec
164 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
165 | return ap
166 |
167 |
168 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
169 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
170 | image_numpy = image_tensor[0].cpu().float().numpy()
171 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
172 | return image_numpy.astype(imtype)
173 |
174 |
175 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
176 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
177 | return torch.Tensor((image / factor - cent)
178 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
179 |
--------------------------------------------------------------------------------
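
The training scripts in `styleme/` use this vendored package as `lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)` and add `percept(fake, real).sum()` to their losses. Below is a minimal usage sketch, assuming it is run from the `styleme/` directory (so the vendored `lpips` package and its `weights/` folder are importable) and that the torchvision VGG backbone can be downloaded; the random tensors are only stand-ins for generated and reference images.

```
# Minimal LPIPS usage sketch (run from the styleme/ directory).
import torch
import lpips

# Same configuration as train_step_1.py / train_step_2.py, but on CPU.
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=False)

fake = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for generated images, scaled to [-1, 1]
real = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for reference images

dist = percept(fake, real)                  # one perceptual distance per image pair
print(dist.view(-1))                        # the training loops call .sum() on this
# pass normalize=True instead if your images are in [0, 1]
```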
/styleme/lpips/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/base_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/base_model.cpython-37.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/dist_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/dist_model.cpython-37.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/dist_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/dist_model.cpython-38.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/networks_basic.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/networks_basic.cpython-37.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/networks_basic.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/networks_basic.cpython-38.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/pretrained_networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/pretrained_networks.cpython-37.pyc
--------------------------------------------------------------------------------
/styleme/lpips/__pycache__/pretrained_networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/pretrained_networks.cpython-38.pyc
--------------------------------------------------------------------------------
/styleme/lpips/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.autograd import Variable
4 | from pdb import set_trace as st
5 | from IPython import embed
6 |
7 | class BaseModel():
8 | def __init__(self):
9 | pass
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, use_gpu=True, gpu_ids=[0]):
15 | self.use_gpu = use_gpu
16 | self.gpu_ids = gpu_ids
17 |
18 | def forward(self):
19 | pass
20 |
21 | def get_image_paths(self):
22 | pass
23 |
24 | def optimize_parameters(self):
25 | pass
26 |
27 | def get_current_visuals(self):
28 | return self.input
29 |
30 | def get_current_errors(self):
31 | return {}
32 |
33 | def save(self, label):
34 | pass
35 |
36 | # helper saving function that can be used by subclasses
37 | def save_network(self, network, path, network_label, epoch_label):
38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
39 | save_path = os.path.join(path, save_filename)
40 | torch.save(network.state_dict(), save_path)
41 |
42 | # helper loading function that can be used by subclasses
43 | def load_network(self, network, network_label, epoch_label):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | print('Loading network from %s'%save_path)
47 | network.load_state_dict(torch.load(save_path))
48 |
49 | def update_learning_rate(self):
50 | pass
51 |
52 | def get_image_paths(self):
53 | return self.image_paths
54 |
55 | def save_done(self, flag=False):
56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58 |
59 |
--------------------------------------------------------------------------------
/styleme/lpips/dist_model.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | import os
9 | from collections import OrderedDict
10 | from torch.autograd import Variable
11 | import itertools
12 | from .base_model import BaseModel
13 | from scipy.ndimage import zoom
14 | import fractions
15 | import functools
16 | import skimage.transform
17 | from tqdm import tqdm
18 |
19 | from IPython import embed
20 |
21 | from . import networks_basic as networks
22 | import lpips as util
23 |
24 | class DistModel(BaseModel):
25 | def name(self):
26 | return self.model_name
27 |
28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29 | use_gpu=True, printNet=False, spatial=False,
30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31 | '''
32 | INPUTS
33 | model - ['net-lin'] for linearly calibrated network
34 | ['net'] for off-the-shelf network
35 | ['L2'] for L2 distance in Lab colorspace
36 | ['SSIM'] for ssim in RGB colorspace
37 | net - ['squeeze','alex','vgg']
38 | model_path - if None, will look in weights/[NET_NAME].pth
39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40 | use_gpu - bool - whether or not to use a GPU
41 | printNet - bool - whether or not to print network architecture out
42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions
43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46 | is_train - bool - [True] for training mode
47 | lr - float - initial learning rate
48 | beta1 - float - initial momentum term for adam
49 | version - 0.1 for latest, 0.0 was original (with a bug)
50 | gpu_ids - int array - [0] by default, gpus to use
51 | '''
52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53 |
54 | self.model = model
55 | self.net = net
56 | self.is_train = is_train
57 | self.spatial = spatial
58 | self.gpu_ids = gpu_ids
59 | self.model_name = '%s [%s]'%(model,net)
60 |
61 | if(self.model == 'net-lin'): # pretrained net + linear layer
62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63 | use_dropout=True, spatial=spatial, version=version, lpips=True)
64 | kw = {}
65 | if not use_gpu:
66 | kw['map_location'] = 'cpu'
67 | if(model_path is None):
68 | import inspect
69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70 |
71 | if(not is_train):
72 | print('Loading model from: %s'%model_path)
73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74 |
75 | elif(self.model=='net'): # pretrained network
76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77 | elif(self.model in ['L2','l2']):
78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79 | self.model_name = 'L2'
80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82 | self.model_name = 'SSIM'
83 | else:
84 | raise ValueError("Model [%s] not recognized." % self.model)
85 |
86 | self.parameters = list(self.net.parameters())
87 |
88 | if self.is_train: # training mode
89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90 | self.rankLoss = networks.BCERankingLoss()
91 | self.parameters += list(self.rankLoss.net.parameters())
92 | self.lr = lr
93 | self.old_lr = lr
94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95 | else: # test mode
96 | self.net.eval()
97 |
98 | if(use_gpu):
99 | self.net.to(gpu_ids[0])
100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101 | if(self.is_train):
102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103 |
104 | if(printNet):
105 | print('---------- Networks initialized -------------')
106 | networks.print_network(self.net)
107 | print('-----------------------------------------------')
108 |
109 | def forward(self, in0, in1, retPerLayer=False):
110 | ''' Function computes the distance between image patches in0 and in1
111 | INPUTS
112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113 | OUTPUT
114 | computed distances between in0 and in1
115 | '''
116 |
117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118 |
119 | # ***** TRAINING FUNCTIONS *****
120 | def optimize_parameters(self):
121 | self.forward_train()
122 | self.optimizer_net.zero_grad()
123 | self.backward_train()
124 | self.optimizer_net.step()
125 | self.clamp_weights()
126 |
127 | def clamp_weights(self):
128 | for module in self.net.modules():
129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130 | module.weight.data = torch.clamp(module.weight.data,min=0)
131 |
132 | def set_input(self, data):
133 | self.input_ref = data['ref']
134 | self.input_p0 = data['p0']
135 | self.input_p1 = data['p1']
136 | self.input_judge = data['judge']
137 |
138 | if(self.use_gpu):
139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143 |
144 | self.var_ref = Variable(self.input_ref,requires_grad=True)
145 | self.var_p0 = Variable(self.input_p0,requires_grad=True)
146 | self.var_p1 = Variable(self.input_p1,requires_grad=True)
147 |
148 | def forward_train(self): # run forward pass
149 | # print(self.net.module.scaling_layer.shift)
150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151 |
152 | self.d0 = self.forward(self.var_ref, self.var_p0)
153 | self.d1 = self.forward(self.var_ref, self.var_p1)
154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155 |
156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157 |
158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159 |
160 | return self.loss_total
161 |
162 | def backward_train(self):
163 | torch.mean(self.loss_total).backward()
164 |
165 | def compute_accuracy(self,d0,d1,judge):
166 | ''' d0, d1 are Variables, judge is a Tensor '''
167 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
209 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
210 | self.old_lr = lr
211 |
212 | def score_2afc_dataset(data_loader, func, name=''):
213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using
214 | distance function 'func' in dataset 'data_loader'
215 | INPUTS
216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217 | func - callable distance function - calling d=func(in0,in1) should take 2
218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219 | OUTPUTS
220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221 | [1] - dictionary with following elements
222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223 | gts - N array in [0,1], preferred patch selected by human evaluators
224 | (closer to "0" for left patch p0, "1" for right patch p1,
225 | "0.6" means 60pct people preferred right patch, 40pct preferred left)
226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans
227 | CONSTS
228 | N - number of test triplets in data_loader
229 | '''
230 |
231 | d0s = []
232 | d1s = []
233 | gts = []
234 |
235 | for data in tqdm(data_loader.load_data(), desc=name):
236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238 | gts+=data['judge'].cpu().numpy().flatten().tolist()
239 |
240 | d0s = np.array(d0s)
241 | d1s = np.array(d1s)
242 | gts = np.array(gts)
243 | scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
--------------------------------------------------------------------------------
/styleme/models.py:
--------------------------------------------------------------------------------
45 | if feat.shape[1] > mask.shape[1]:
46 | channel_scale = feat.shape[1] // mask.shape[1]
47 | mask = mask.repeat(1, channel_scale, 1, 1)
48 |
49 | mask = F.interpolate(mask, size=feat.shape[2])
50 | feat_a = self.weight_a * feat * mask + self.bias_a
51 | feat_b = self.weight_b * feat * (1 - mask) + self.bias_b
52 | return feat_a + feat_b
53 |
54 |
55 | class Swish(nn.Module):
56 | def forward(self, feat):
57 | return feat * torch.sigmoid(feat)
58 |
59 |
60 | class Squeeze(nn.Module):
61 | def forward(self, feat):
62 | return feat.squeeze(-1).squeeze(-1)
63 |
64 |
65 | class UnSqueeze(nn.Module):
66 | def forward(self, feat):
67 | return feat.unsqueeze(-1).unsqueeze(-1)
68 |
69 |
70 | class ECAModule(nn.Module):
71 | def __init__(self, c, b=1, gamma=2):
72 | super(ECAModule, self).__init__()
73 | t = int(abs((math.log(c, 2) + b) / gamma))
74 | k = t if t % 2 else t + 1
75 |
76 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
77 | self.conv1 = spectral_norm(nn.Conv1d(1, 1, k, 1, int(k / 2), bias=False))
78 | self.sigmoid = nn.Sigmoid()
79 |
80 | def forward(self, x):
81 | x = self.avg_pool(x)
82 | x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
83 | out = self.sigmoid(x)
84 | return x * out
85 |
86 |
87 | class ResBlock(nn.Module):
88 | def __init__(self, ch, expansion=2):
89 | super(ResBlock, self).__init__()
90 | self.main = nn.Sequential(spectral_norm(nn.Conv2d(ch, ch * expansion, 1, 1, 0, bias=False)),
91 | spectral_norm(nn.BatchNorm2d(ch * expansion)), Swish(),
92 | spectral_norm(DepthwiseConv2d(ch * expansion, ch * expansion, 3, 1, 1)),
93 | spectral_norm(nn.BatchNorm2d(ch * expansion)), Swish(),
94 | spectral_norm(nn.Conv2d(ch * expansion, ch, 1, 1, 0, bias=False)),
95 | spectral_norm(nn.BatchNorm2d(ch)), Swish(),
96 | ECAModule(ch))
97 |
98 | def forward(self, x):
99 | return x + self.main(x)
100 |
101 |
102 | def base_block(ch_in, ch_out):
103 | return nn.Sequential(nn.Conv2d(ch_in, ch_out, 3, 1, 1, bias=False),
104 | nn.BatchNorm2d(ch_out),
105 | nn.LeakyReLU(0.2, inplace=True))
106 |
107 |
108 | def down_block(ch_in, ch_out):
109 | return nn.Sequential(nn.Conv2d(ch_in, ch_out, 4, 2, 1, bias=False),
110 | nn.BatchNorm2d(ch_out),
111 | nn.LeakyReLU(0.1, inplace=True))
112 |
113 |
114 | ################################
115 | # style encode #
116 | ################################
117 |
118 | class StyleEncoder(nn.Module):
119 | def __init__(self, ch=32, nbr_cls=100):
120 | super().__init__()
121 |
122 | self.sf_256 = base_block(3, ch // 2)
123 | self.sf_128 = down_block(ch // 2, ch)
124 | self.sf_64 = down_block(ch, ch * 2)
125 |
126 | self.sf_32 = nn.Sequential(down_block(ch * 2, ch * 4),
127 | ResBlock(ch * 4))
128 | self.sf_16 = nn.Sequential(down_block(ch * 4, ch * 8),
129 | ResBlock(ch * 8))
130 | self.sf_8 = nn.Sequential(down_block(ch * 8, ch * 16),
131 | ResBlock(ch * 16))
132 |
133 | self.sfv_32 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4),
134 | nn.Conv2d(ch * 4, ch * 2, 4, 1, 0, bias=False),
135 | Squeeze())
136 | self.sfv_16 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4),
137 | nn.Conv2d(ch * 8, ch * 4, 4, 1, 0, bias=False),
138 | Squeeze())
139 | self.sfv_8 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4),
140 | nn.Conv2d(ch * 16, ch * 8, 4, 1, 0, bias=False),
141 | Squeeze())
142 |
143 | self.ch = ch
144 | self.nbr_cls = nbr_cls
145 | self.final_cls = None
146 |
147 | def reset_cls(self):
148 | if self.final_cls is None:
149 | self.final_cls = nn.Sequential(nn.LeakyReLU(0.1), nn.Linear(self.ch * 8, self.nbr_cls))
150 | stdv = 1. / math.sqrt(self.final_cls[1].weight.size(1))
151 | self.final_cls[1].weight.data.uniform_(-stdv, stdv)
152 | if self.final_cls[1].bias is not None:
153 | self.final_cls[1].bias.data.uniform_(-0.1 * stdv, 0.1 * stdv)
154 |
155 | def get_feats(self, image):
156 | feat = self.sf_256(image)
157 | feat = self.sf_128(feat)
158 | feat = self.sf_64(feat)
159 | feat_32 = self.sf_32(feat)
160 | feat_16 = self.sf_16(feat_32)
161 | feat_8 = self.sf_8(feat_16)
162 |
163 | feat_32 = self.sfv_32(feat_32)
164 | feat_16 = self.sfv_16(feat_16)
165 | feat_8 = self.sfv_8(feat_8)
166 |
167 | return feat_32, feat_16, feat_8
168 |
169 | def forward(self, image):
170 | feat_32, feat_16, feat_8 = self.get_feats(image)
171 | pred_cls = self.final_cls(feat_8)
172 |
173 | return [feat_32, feat_16, feat_8], pred_cls
174 | # [1, 64] [1, 128] [1, 256]
175 |
176 |
177 | ################################
178 | # content encode #
179 | ################################
180 |
181 | class ContentEncoder(nn.Module):
182 | def __init__(self, ch=32):
183 | super().__init__()
184 |
185 | self.feat_256 = base_block(1, ch // 4)
186 | self.feat_128 = down_block(ch // 4, ch // 2)
187 | self.feat_64 = down_block(ch // 2, ch)
188 |
189 | self.feat_32 = nn.Sequential(down_block(ch, ch * 2),
190 | ResBlock(ch * 2))
191 | self.feat_16 = nn.Sequential(down_block(ch * 2, ch * 4),
192 | ResBlock(ch * 4))
193 | self.feat_8 = nn.Sequential(down_block(ch * 4, ch * 8),
194 | ResBlock(ch * 8))
195 |
196 | def forward(self, image):
197 | feat = self.feat_256(image)
198 | feat = self.feat_128(feat)
199 | feat = self.feat_64(feat)
200 |
201 | feat_32 = self.feat_32(feat)
202 | feat_16 = self.feat_16(feat_32)
203 | feat_8 = self.feat_8(feat_16)
204 |
205 | return [feat_32, feat_16, feat_8]
206 | # [1, 64, 32, 32]
207 | # [1, 128, 16, 16]
208 | # [1, 256, 8, 8]
209 |
210 |
211 | def for_decoder(ch_in, ch_out):
212 | return nn.Sequential(
213 | nn.UpsamplingNearest2d(scale_factor=2),
214 | nn.Conv2d(ch_in, ch_out * 2, 3, 1, 1, bias=False),
215 | nn.InstanceNorm2d(ch_out * 2),
216 | GLU())
217 |
218 |
219 | def style_decode(ch_in, ch_out):
220 | return nn.Sequential(nn.Linear(ch_in, ch_out), nn.ReLU(),
221 | nn.Linear(ch_out, ch_out), nn.Sigmoid(),
222 | UnSqueeze())
223 |
224 |
225 | ################################
226 | # decode #
227 | ################################
228 |
229 | class Decoder(nn.Module):
230 | def __init__(self, ch=32):
231 | super().__init__()
232 |
233 | self.base_feat = nn.Parameter(torch.randn(1, ch * 8, 8, 8).normal_(0, 1), requires_grad=True)
234 |
235 | self.dmi_8 = DMI(ch * 8)
236 | self.dmi_16 = DMI(ch * 4)
237 |
238 | self.feat_8_1 = nn.Sequential(ResBlock(ch * 16), nn.LeakyReLU(0.1, inplace=True),
239 | nn.Conv2d(ch * 16, ch * 8, 3, 1, 1, bias=False),
240 | nn.InstanceNorm2d(ch * 8))
241 | self.feat_8_2 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True), ResBlock(ch * 8))
242 |
243 | self.feat_16 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True),
244 | for_decoder(ch * 8, ch * 4), ResBlock(ch * 4))
245 | self.feat_32 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True),
246 | for_decoder(ch * 8, ch * 2), ResBlock(ch * 2))
247 |
248 | self.feat_64 = for_decoder(ch * 4, ch)
249 | self.feat_128 = for_decoder(ch, ch // 2)
250 | self.feat_256 = for_decoder(ch // 2, ch // 4)
251 |
252 | self.to_rgb = nn.Sequential(nn.Conv2d(ch // 4, 3, 3, 1, 1, bias=False),
253 | nn.Tanh())
254 |
255 | self.style_8 = style_decode(ch * 8, ch * 8)
256 | self.style_64 = style_decode(ch * 8, ch)
257 | self.style_128 = style_decode(ch * 4, ch // 2)
258 | self.style_256 = style_decode(ch * 2, ch // 4)
259 |
260 | def forward(self, content_feats, style_vectors):
261 | feat_8 = self.feat_8_1(torch.cat([content_feats[2],
262 | self.base_feat.repeat(style_vectors[0].shape[0], 1, 1, 1)], dim=1))
263 | feat_8 = self.dmi_8(feat_8, content_feats[2])
264 |
265 | feat_8 = feat_8 * self.style_8(style_vectors[2])
266 | feat_8 = self.feat_8_2(feat_8)
267 |
268 | feat_16 = self.feat_16(feat_8)
269 | feat_16 = self.dmi_16(feat_16, content_feats[1])
270 | feat_16 = torch.cat([feat_16, content_feats[1]], dim=1)
271 |
272 | feat_32 = self.feat_32(feat_16)
273 | feat_32 = torch.cat([feat_32, content_feats[0]], dim=1)
274 |
275 | feat_64 = self.feat_64(feat_32) * self.style_64(style_vectors[2])
276 | feat_128 = self.feat_128(feat_64) * self.style_128(style_vectors[1])
277 | feat_256 = self.feat_256(feat_128) * self.style_256(style_vectors[0])
278 |
279 | return self.to_rgb(feat_256)
280 |
281 |
282 | ################################
283 | # AE Module #
284 | ################################
285 |
286 | class AE(nn.Module):
287 | def __init__(self, ch, nbr_cls=100):
288 | super().__init__()
289 |
290 | self.style_encoder = StyleEncoder(ch, nbr_cls=nbr_cls)
291 | self.content_encoder = ContentEncoder(ch)
292 | self.decoder = Decoder(ch)
293 |
294 | @torch.no_grad()
295 | def forward(self, skt_img, style_img):
296 | style_feats = self.style_encoder.get_feats(F.interpolate(style_img, size=256))
297 | content_feats = self.content_encoder(F.interpolate(skt_img, size=256))
298 | gimg = self.decoder(content_feats, style_feats)
299 | return gimg, style_feats
300 |
301 | def load_state_dicts(self, path):
302 | ckpt = torch.load(path)
303 | self.style_encoder.reset_cls()
304 | self.style_encoder.load_state_dict(ckpt['s'])
305 | self.content_encoder.load_state_dict(ckpt['c'])
306 | self.decoder.load_state_dict(ckpt['d'])
307 | print('AE model load success')
308 |
309 |
310 | def down_gan(ch_in, ch_out):
311 | return nn.Sequential(
312 | spectral_norm(nn.Conv2d(ch_in, ch_out, 4, 2, 1, bias=False)),
313 | nn.BatchNorm2d(ch_out),
314 | nn.LeakyReLU(0.1, inplace=True))
315 |
316 |
317 | def up_gan(ch_in, ch_out):
318 | return nn.Sequential(
319 | nn.UpsamplingNearest2d(scale_factor=2),
320 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, 1, bias=False)),
321 | nn.BatchNorm2d(ch_out),
322 | nn.LeakyReLU(0.1, inplace=True))
323 |
324 |
325 | def style_gan(ch_in, ch_out):
326 | return nn.Sequential(
327 | spectral_norm(nn.Linear(ch_in, ch_out)), nn.ReLU(),
328 | nn.Linear(ch_out, ch_out),
329 | nn.Sigmoid(), UnSqueeze())
330 |
331 |
332 | ################################
333 | # GAN #
334 | ################################
335 |
336 | class RefineGenerator(nn.Module):
337 | def __init__(self, ch=32, im_size=256):
338 | super().__init__()
339 |
340 | self.im_size = im_size
341 |
342 | self.from_noise_32 = nn.Sequential(UnSqueeze(),
343 | spectral_norm(nn.ConvTranspose2d(ch * 8, ch * 8, 4, 1, 0, bias=False)),
344 | nn.BatchNorm2d(ch * 8),
345 | nn.Sigmoid(),
346 | up_gan(ch * 8, ch * 4),
347 | up_gan(ch * 4, ch * 2),
348 | up_gan(ch * 2, ch * 1))
349 |
350 | self.from_style = nn.Sequential(UnSqueeze(),
351 | spectral_norm(
352 | nn.ConvTranspose2d(ch * (8 + 4 + 2), ch * 16, 4, 1, 0, bias=False)),
353 | nn.BatchNorm2d(ch * 16),
354 | GLU(),
355 | up_gan(ch * 8, ch * 4))
356 |
357 | self.encode_256 = nn.Sequential(spectral_norm(nn.Conv2d(3, ch, 3, 1, 1, bias=False)),
358 | nn.LeakyReLU(0.2, inplace=True))
359 | self.encode_128 = nn.Sequential(ResBlock(ch),
360 | down_gan(ch, ch * 2))
361 | self.encode_64 = nn.Sequential(ResBlock(ch * 2),
362 | down_gan(ch * 2, ch * 4))
363 | self.encode_32 = nn.Sequential(ResBlock(ch * 4),
364 | down_gan(ch * 4, ch * 8))
365 |
366 | self.encode_16 = nn.Sequential(ResBlock(ch * 8),
367 | down_gan(ch * 8, ch * 16))
368 |
369 | self.decode_32 = nn.Sequential(ResBlock(ch * 16),
370 | up_gan(ch * 16, ch * 8))
371 | self.decode_64 = nn.Sequential(ResBlock(ch * 8 + ch),
372 | up_gan(ch * 8 + ch, ch * 4))
373 | self.decode_128 = nn.Sequential(ResBlock(ch * 4),
374 | up_gan(ch * 4, ch * 2))
375 | self.decode_256 = nn.Sequential(ResBlock(ch * 2),
376 | up_gan(ch * 2, ch))
377 |
378 | self.style_64 = style_gan(ch * 8, ch * 4)
379 | self.style_128 = style_gan(ch * 4, ch * 2)
380 | self.style_256 = style_gan(ch * 2, ch)
381 |
382 | self.to_rgb = nn.Sequential(nn.Conv2d(ch, 3, 3, 1, 1, bias=False), nn.Tanh())
383 |
384 | def forward(self, image, style_vectors):
385 | n_32 = self.from_noise_32(torch.randn_like(style_vectors[2])) # [8, 32, 32, 32]
386 |
387 | e_256 = self.encode_256(image) # [8, 3, 256, 256] [8, 32, 256, 256]
388 | e_128 = self.encode_128(e_256) # [8, 64, 128, 128]
389 | e_64 = self.encode_64(e_128) # [8, 128, 64, 64]
390 | e_32 = self.encode_32(e_64) # [8, 256, 32, 32]
391 |
392 | e_16 = self.encode_16(e_32) # [8, 256, 16, 16]
393 |
394 | d_32 = self.decode_32(e_16) # [8, 256, 32, 32]
395 | d_64 = self.decode_64(torch.cat([d_32, n_32], dim=1)) # [8, 128, 64, 64]
396 | d_64 = self.style_64(style_vectors[2]) * d_64 # [8, 128, 64, 64]
397 |
398 | d_128 = self.decode_128(d_64 + e_64) # [8, 64, 128, 128]
399 | d_128 = self.style_128(style_vectors[1]) * d_128 # [8, 64, 128, 128]
400 |
401 | d_256 = self.decode_256(d_128 + e_128) # [8, 32, 256, 256]
402 | d_256 = self.style_256(style_vectors[0]) * d_256 # [8, 32, 256, 256]
403 |
404 | d_final = self.to_rgb(d_256)
405 |
406 | return d_final
407 |
408 |
409 | class DownBlock(nn.Module):
410 | def __init__(self, ch_in, ch_out):
411 | super().__init__()
412 |
413 | self.ch_out = ch_out
414 | self.down_main = nn.Sequential(
415 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 2, 1, bias=False)),
416 | nn.BatchNorm2d(ch_out),
417 | nn.LeakyReLU(0.1, inplace=True),
418 | spectral_norm(nn.Conv2d(ch_out, ch_out, 3, 1, 1, bias=False)),
419 | nn.BatchNorm2d(ch_out),
420 | nn.LeakyReLU(0.1, inplace=True)
421 | )
422 |
423 | def forward(self, feat):
424 | feat_out = self.down_main(feat)
425 |
426 | return feat_out
427 |
428 |
429 | class Discriminator(nn.Module):
430 | def __init__(self, ch=64, nc=3, im_size=256):
431 | super(Discriminator, self).__init__()
432 | self.ch = ch
433 | self.im_size = im_size
434 |
435 | self.f_256 = nn.Sequential(spectral_norm(nn.Conv2d(nc, ch // 8, 3, 1, 1, bias=False)),
436 | nn.LeakyReLU(0.2, inplace=True))
437 |
438 | self.f_128 = DownBlock(ch // 8, ch // 4)
439 | self.f_64 = DownBlock(ch // 4, ch // 2)
440 | self.f_32 = DownBlock(ch // 2, ch)
441 | self.f_16 = DownBlock(ch, ch * 2)
442 | self.f_8 = DownBlock(ch * 2, ch * 4)
443 | self.f = nn.Sequential(spectral_norm(nn.Conv2d(ch * 4, ch * 8, 1, 1, 0, bias=False)),
444 | nn.BatchNorm2d(ch * 8),
445 | nn.LeakyReLU(0.1, inplace=True))
446 |
447 | self.flatten = spectral_norm(nn.Conv2d(ch * 8, 1, 3, 1, 1, bias=False))
448 |
449 | self.apply(weights_init)
450 |
451 | def forward(self, x):
452 | feat_256 = self.f_256(x)
453 | feat_128 = self.f_128(feat_256)
454 | feat_64 = self.f_64(feat_128)
455 | feat_32 = self.f_32(feat_64)
456 | feat_16 = self.f_16(feat_32)
457 | feat_8 = self.f_8(feat_16)
458 | feat_f = self.f(feat_8)
459 | feat_out = self.flatten(feat_f)
460 |
461 | return feat_out
462 |
--------------------------------------------------------------------------------
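
A hedged smoke test of how the modules above fit together, mirroring the wiring used in `style_transform.py` and `train_step_2.py`: `AE` produces a coarse image plus style vectors, `RefineGenerator` refines it, and `Discriminator` scores the result. Random tensors stand in for real data, everything stays on the CPU, and the expected output shapes are inferred from the layer definitions above rather than quoted from the paper.

```
# Hedged smoke test for models.py (random inputs, CPU only).
import torch
from models import AE, RefineGenerator, Discriminator

net_ae = AE(ch=32, nbr_cls=50)
net_ig = RefineGenerator(ch=32, im_size=256)
net_id = Discriminator(ch=64, nc=3, im_size=256)

skt = torch.randn(2, 1, 256, 256)   # single-channel sketch batch
rgb = torch.randn(2, 3, 256, 256)   # RGB style-reference batch

with torch.no_grad():
    coarse, style_feats = net_ae(skt, rgb)    # AE.forward is already wrapped in @torch.no_grad()
    refined = net_ig(coarse, style_feats)     # GAN refinement stage
    score = net_id(refined)                   # patch-style realism map

print(coarse.shape, refined.shape, score.shape)
# expected: [2, 3, 256, 256], [2, 3, 256, 256], [2, 1, 8, 8]
```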
/styleme/readme.md:
--------------------------------------------------------------------------------
1 | # Environment :
2 |
3 | - python 3.8.0
4 | - pytorch 1.12.1
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | Fig. 1: An overview of the style transform network of StyleMe
13 |
14 |
15 |
16 | - You can download our dataset, which includes 119 RGB images and 119 sketches, here: [**styleme datasets**](https://drive.google.com/drive/folders/1UycahUifPoc0n6pyP92bWC07BlJETwRR)
17 |
18 |
19 |
20 | - We provide a pretrained model trained for 30,000 iterations here: [**styleme model**](https://drive.google.com/drive/folders/1JHmDdsV6OS0sf6v-OhwkpbkDPn7Co2HW)
21 |
22 |
23 |
24 |
25 | ## 1. Description
26 | Description of the related code files:
27 |
28 | * train.py: trains the whole model; you can also choose to train only the AE module or only the GAN module.
29 | * models.py: all model architectures, including the encoders (style and content), the decoder (which turns style features and content features back into an image), the generator, and the discriminator.
30 | * datasets.py: data pre-processing and loading methods.
31 | * train_step_1.py: AE module training.
32 | * train_step_2.py: GAN module training.
33 | * config.py: all hyper-parameter settings.
34 | * calculate.py: calculates the FID and LPIPS of the model.
35 | * benchmark.py: the FID functions; the Inception model they require is downloaded automatically.
36 | * lpips: the LPIPS functions; the backbone weights they require are likewise downloaded automatically.
37 | * style_transform.py: feed in your sketch and RGB images to transform the style.
38 |
39 |
40 | ## 2. Training
41 |
42 | - first, prepare your dataset as follows:
43 |
44 | ```
45 | train_data/
46 | -./rgb/
47 | -000.png
48 | -001.png
49 | -...
50 | -./sketch/
51 | -000.png
52 | -001.png
53 | -...
54 | ```
55 |
56 |
57 |
58 | - and then train your model:
59 |
60 | ```
61 | python train.py
62 | ```
63 |
64 |
65 |
66 | ## 3. Evaluate
67 |
68 | - You can run the following program to see the performance of our model:
69 |
70 | ```
71 | python style_transform.py
72 | ```
73 |
74 | - or you can compute the FID and LPIPS:
75 |
76 | ```
77 | python calculate.py
78 | ```
79 |
--------------------------------------------------------------------------------
/styleme/style_transform.py:
--------------------------------------------------------------------------------
1 | ##############################
2 | # style transform #
3 | ##############################
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torchvision import utils as vutils
8 | from datasets import ImageFolder, transform_data
9 | from models import AE, RefineGenerator
10 |
11 |
12 | def make_matrix(dataloader_rgb, dataloader_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name):
13 | rgb_img = next(dataloader_rgb)
14 | skt_img = next(dataloader_skt)
15 |
16 | skt_img = skt_img.mean(dim=1, keepdim=True)
17 |
18 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)]
19 | image_matrix.append(rgb_img.clone())
20 | with torch.no_grad():
21 | rgb_img = rgb_img.cuda()
22 | for skt in skt_img:
23 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda()
24 |
25 | gimg_ae, style_feats = net_ae(input_skts, rgb_img)
26 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone())
27 | image_matrix.append(gimg_ae.cpu())
28 |
29 | g_images = net_ig(gimg_ae, style_feats).cpu()
30 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1))
31 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE))
32 |
33 | image_matrix = torch.cat(image_matrix)
34 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1)
35 |
36 |
37 | if __name__ == "__main__":
38 | device = 'cuda'
39 | batch_size = 5
40 | img_size = 256
41 | num_workers = 2
42 | trans_iter = 20
43 | data_root_colorful = './train_data/rgb/'
44 | data_root_sketch = './train_data/sketch/'
45 |
46 | net_ae = AE(ch=32, nbr_cls=50)
47 | net_ae.style_encoder.reset_cls()
48 | net_ig = RefineGenerator()
49 |
50 | ckpt = torch.load('./checkpoint/GAN.pth')
51 |
52 | net_ae.load_state_dict(ckpt['ae'])
53 | net_ae.style_encoder.reset_cls()
54 | net_ig.load_state_dict(ckpt['ig'])
55 |
56 | net_ae.to(device)
57 | net_ig.to(device)
58 | net_ae.eval()
59 | net_ig.eval()
60 |
61 | dataset_rgb = ImageFolder(data_root_colorful, transform_data(img_size))
62 | dataloader_rgb = iter(DataLoader(dataset_rgb, batch_size, shuffle=False, num_workers=num_workers))
63 |
64 | dataset_skt = ImageFolder(data_root_sketch, transform_data(img_size))
65 | dataloader_skt = iter(DataLoader(dataset_skt, batch_size, shuffle=False, num_workers=num_workers))
66 |
67 | for idx in range(trans_iter):
68 | print(idx)
69 | make_matrix(dataloader_rgb, dataloader_skt, net_ae, net_ig, batch_size, img_size,
70 | './trans_data/transform/%d.jpg' % idx)
71 |
--------------------------------------------------------------------------------
/styleme/train.py:
--------------------------------------------------------------------------------
1 | ############################
2 | # main training #
3 | ############################
4 |
5 | import train_step_1
6 | import train_step_2
7 | from config import TRAIN_AE_ONLY, TRAIN_GAN_ONLY
8 |
9 |
10 | if __name__ == "__main__":
11 | if TRAIN_GAN_ONLY:
12 | print('train gan only !')
13 | train_step_2.train()
14 | else:
15 | print('train ae first !')
16 | train_step_1.train()
17 | if not TRAIN_AE_ONLY:
18 | train_step_2.train()
19 |
--------------------------------------------------------------------------------
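
`train.py` and the two stage scripts read every setting from `config.py`, which the readme mentions but this listing does not include. Below is a hedged sketch of such a config: only the variable names are taken from the imports in `train.py`, `train_step_1.py` and `train_step_2.py`; all values are placeholder assumptions to be replaced with your own settings.

```
# config.py (hypothetical sketch): only the names below are taken from the
# imports in train.py, train_step_1.py and train_step_2.py; every value is a
# placeholder assumption.

# stage selection (used by train.py)
TRAIN_AE_ONLY = False
TRAIN_GAN_ONLY = False

# data
data_root_colorful = './train_data/rgb/'
data_root_sketch = './train_data/sketch/'
DATALOADER_WORKERS = 2

# model width and number of style classes
CHANNEL = 32
NBR_CLS = 50

# stage 1: AE training
IM_SIZE_AE = 256
BATCH_SIZE_AE = 8
ITERATION_AE = 30000
PRETRAINED_AE_PATH = None       # set to a .pth file to resume from / reuse an AE

# stage 2: GAN training
IM_SIZE_GAN = 256
BATCH_SIZE_GAN = 8
EPOCH_GAN = 1
ITERATION_GAN = 30000
GAN_CKECKPOINT = None           # spelling follows the import in train_step_2.py
MULTI_GPU = False

# logging and checkpoints (SAVE_FOLDER must end with a slash,
# see the path concatenation in train_step_2.py)
SAVE_FOLDER = './'
LOG_INTERVAL = 100
SAVE_IMAGE_INTERVAL = 1000
SAVE_MODEL_INTERVAL = 5000
```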
/styleme/train_step_1.py:
--------------------------------------------------------------------------------
1 | #############################
2 | # train_step_1 #
3 | # #
4 | # transform images #
5 | #############################
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import optim
10 | from torch.utils.data import DataLoader
11 | from torchvision import utils as vutils
12 |
13 | import random
14 | from tqdm import tqdm
15 |
16 | from datasets import TransformData, InfiniteSamplerWrapper
17 | from utils import make_folders, AverageMeter
18 | from models import StyleEncoder, ContentEncoder, Decoder
19 |
20 |
21 | def loss_for_style(style, style_org, batch_size):
22 | loss_result = 0
23 | for loss_idx in range(len(style)):
24 | loss_result += - F.cosine_similarity(style[loss_idx],
25 | style_org[loss_idx].detach()).mean() + \
26 | F.cosine_similarity(style[loss_idx],
27 | style_org[loss_idx][torch.randperm(batch_size)]
28 | .detach()).mean()
29 | return loss_result / len(style)
30 |
31 |
32 | def loss_for_content(loss, fl1, fl2):
33 | loss_result = 0
34 | for f_idx in range(len(fl1)):
35 | loss_result += loss(fl1[f_idx], fl2[f_idx].detach())
36 | return loss_result * 2
37 |
38 |
39 | def train():
40 | from config import IM_SIZE_AE, BATCH_SIZE_AE, CHANNEL, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
41 | from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, LOG_INTERVAL
42 | from config import data_root_colorful, data_root_sketch
43 |
44 | dataset_trans = TransformData(data_root_colorful, data_root_sketch, im_size=IM_SIZE_AE, nbr_cls=NBR_CLS)
45 | print('Num classes:', len(dataset_trans), ' Data nums:', len(dataset_trans.frame))
46 | dataloader_trans = iter(DataLoader(dataset_trans, BATCH_SIZE_AE,
47 | sampler=InfiniteSamplerWrapper(dataset_trans),
48 | num_workers=DATALOADER_WORKERS, pin_memory=True))
49 |
50 | style_encoder = StyleEncoder(ch=CHANNEL, nbr_cls=NBR_CLS).cuda()
51 | content_encoder = ContentEncoder(ch=CHANNEL).cuda()
52 | decoder = Decoder(ch=CHANNEL).cuda()
53 |
54 | opt_content = optim.Adam(content_encoder.parameters(), lr=1e-4, betas=(0.9, 0.999))
55 | opt_style = optim.Adam(style_encoder.parameters(), lr=1e-4, betas=(0.9, 0.999))
56 | opt_decode = optim.Adam(decoder.parameters(), lr=1e-4, betas=(0.9, 0.999))
57 |
58 | style_encoder.reset_cls()
59 | style_encoder.final_cls.cuda()
60 |
61 | # load model
62 | from config import PRETRAINED_AE_PATH
63 | if PRETRAINED_AE_PATH is not None:
64 | ckpt = torch.load(PRETRAINED_AE_PATH)
65 |
66 | print('Pre-trained AE path : ', PRETRAINED_AE_PATH)
67 |
68 | style_encoder.load_state_dict(ckpt['s'])
69 | content_encoder.load_state_dict(ckpt['c'])
70 | decoder.load_state_dict(ckpt['d'])
71 |
72 | opt_style.load_state_dict(ckpt['opt_s'])
73 | opt_content.load_state_dict(ckpt['opt_c'])
74 | opt_decode.load_state_dict(ckpt['opt_d'])
75 | print('loaded pre-trained AE')
76 |
77 | style_encoder.reset_cls()
78 | style_encoder.final_cls.cuda()
79 | opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=1e-4, betas=(0.9, 0.999))
80 |
81 | # save path
82 | saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'Train_step_1')
83 |
84 | # loss log
85 | losses_style_feat = AverageMeter()
86 | losses_content_feat = AverageMeter()
87 | losses_cls = AverageMeter()
88 | losses_org = AverageMeter()
89 | losses_rd = AverageMeter()
90 | losses_flip = AverageMeter()
91 |
92 | import lpips
93 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)
94 |
95 | for iteration in tqdm(range(ITERATION_AE)):
96 | if iteration % ((NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
97 | dataset_trans._next_set()
98 | dataloader_trans = iter(DataLoader(dataset_trans, BATCH_SIZE_AE,
99 | sampler=InfiniteSamplerWrapper(dataset_trans),
100 | num_workers=DATALOADER_WORKERS, pin_memory=True))
101 | style_encoder.reset_cls()
102 | opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=1e-4, betas=(0.9, 0.999))
103 |
104 | opt_style.param_groups[0]['lr'] = 1e-4
105 | opt_decode.param_groups[0]['lr'] = 1e-4
106 |
107 | # 1. training for encode & decode
108 | # 1.1 prepare data
109 | rgb_img_org, rgb_img_rd, rgb_img_flip, skt_org, skt_erased, skt_bold, img_idx = next(dataloader_trans)
110 | rgb_img_org = rgb_img_org.cuda()
111 | rgb_img_rd = rgb_img_rd.cuda()
112 | rgb_img_flip = rgb_img_flip.cuda()
113 |
114 | skt_org = F.interpolate(skt_org, size=IM_SIZE_AE).cuda()
115 | skt_erased = F.interpolate(skt_erased, size=IM_SIZE_AE).cuda()
116 | skt_bold = F.interpolate(skt_bold, size=IM_SIZE_AE).cuda()
117 |
118 | img_idx = img_idx.long().cuda()
119 |
120 | # 1.2 model grad zero
121 | style_encoder.zero_grad()
122 | content_encoder.zero_grad()
123 | decoder.zero_grad()
124 |
125 | ################
126 | # encode #
127 | ################
128 | # 1.3 for style
129 | style_vector_org, pred_cls_org = style_encoder(rgb_img_org)
130 | style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
131 | style_vector_flip, pred_cls_flip = style_encoder(rgb_img_flip)
132 |
133 | # 1.4 for content
134 | content_feats_org = content_encoder(skt_org)
135 | content_feats_erased = content_encoder(skt_erased)
136 | content_feats_bold = content_encoder(skt_bold)
137 |
138 | # 1.5 encode loss
139 | loss_style_feat = loss_for_style(style_vector_rd, style_vector_org, BATCH_SIZE_AE) + \
140 | loss_for_style(style_vector_flip, style_vector_org, BATCH_SIZE_AE)
141 |
142 | loss_content_feat = loss_for_content(F.mse_loss, content_feats_bold, content_feats_org) + \
143 | loss_for_content(F.mse_loss, content_feats_erased, content_feats_org)
144 |
145 | loss_cls = F.cross_entropy(pred_cls_org, img_idx) + \
146 | F.cross_entropy(pred_cls_rd, img_idx) + \
147 | F.cross_entropy(pred_cls_flip, img_idx)
148 |
149 | ################
150 | # decode #
151 | ################
152 | org = random.randint(0, 2)
153 | gimg_org = None
154 | if org == 0:
155 | gimg_org = decoder(content_feats_org, style_vector_org)
156 | elif org == 1:
157 | gimg_org = decoder(content_feats_erased, style_vector_org)
158 | elif org == 2:
159 | gimg_org = decoder(content_feats_bold, style_vector_org)
160 |
161 | rd = random.randint(0, 2)
162 | gimg_rd = None
163 | if rd == 0:
164 | gimg_rd = decoder(content_feats_org, style_vector_rd)
165 | elif rd == 1:
166 | gimg_rd = decoder(content_feats_erased, style_vector_rd)
167 | elif rd == 2:
168 | gimg_rd = decoder(content_feats_bold, style_vector_rd)
169 |
170 | flip = random.randint(0, 2)
171 | gimg_flip = None
172 | if flip == 0:
173 | gimg_flip = decoder(content_feats_org, style_vector_flip)
174 | elif flip == 1:
175 | gimg_flip = decoder(content_feats_erased, style_vector_flip)
176 | elif flip == 2:
177 | gimg_flip = decoder(content_feats_bold, style_vector_flip)
178 |
179 | # 1.6 decode loss
180 | loss_org = F.mse_loss(gimg_org, rgb_img_org) + \
181 | percept(F.adaptive_avg_pool2d(gimg_org, output_size=256),
182 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
183 |
184 | loss_rd = F.mse_loss(gimg_rd, rgb_img_org) + \
185 | percept(F.adaptive_avg_pool2d(gimg_rd, output_size=256),
186 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
187 |
188 | loss_flip = F.mse_loss(gimg_flip, rgb_img_org) + \
189 | percept(F.adaptive_avg_pool2d(gimg_flip, output_size=256),
190 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
191 |
192 | loss_total = loss_style_feat + loss_content_feat + loss_cls + loss_org + loss_rd + loss_flip
193 | loss_total.backward()
194 |
195 | opt_style.step()
196 | opt_content.step()
197 | opt_s_cls.step()
198 | opt_decode.step()
199 |
200 | # 1.7 update log
201 | losses_style_feat.update(loss_style_feat.mean().item(), BATCH_SIZE_AE)
202 | losses_content_feat.update(loss_content_feat.mean().item(), BATCH_SIZE_AE)
203 | losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
204 | losses_org.update(loss_org.item(), BATCH_SIZE_AE)
205 | losses_rd.update(loss_rd.item(), BATCH_SIZE_AE)
206 | losses_flip.update(loss_flip.item(), BATCH_SIZE_AE)
207 |
208 | # 1.8 print log
209 | if iteration % LOG_INTERVAL == 0:
210 | log_msg = '\nTrain Stage 1 (encode and decode): \n' \
211 | 'loss_encode_style: %.4f loss_encode_content: %.4f loss_encode_class: %.4f \n' \
212 | 'loss_decode_org: %.4f loss_decode_rd: %.4f loss_decode_flip: %.4f' % (
213 | losses_style_feat.avg, losses_content_feat.avg, losses_cls.avg,
214 | losses_org.avg, losses_rd.avg, losses_flip.avg)
215 | print(log_msg)
216 |
217 | losses_style_feat.reset()
218 | losses_content_feat.reset()
219 | losses_cls.reset()
220 | losses_org.reset()
221 | losses_rd.reset()
222 | losses_flip.reset()
223 |
224 | if iteration % SAVE_IMAGE_INTERVAL == 0:
225 | vutils.save_image(torch.cat([rgb_img_org,
226 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE),
227 | gimg_org]),
228 | '%s/%d_org.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1))
229 | vutils.save_image(torch.cat([rgb_img_rd,
230 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE),
231 | gimg_rd]),
232 | '%s/%d_rd.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1))
233 | vutils.save_image(torch.cat([rgb_img_flip,
234 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE),
235 | gimg_flip]),
236 | '%s/%d_flip.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1))
237 |
238 | if iteration % SAVE_MODEL_INTERVAL == 0:
239 | print('Saving history model')
240 | torch.save({'s': style_encoder.state_dict(),
241 | 'c': content_encoder.state_dict(),
242 | 'd': decoder.state_dict(),
243 | 'opt_s': opt_style.state_dict(),
244 | 'opt_c': opt_content.state_dict(),
245 | 'opt_s_cls': opt_s_cls.state_dict(),
246 | 'opt_d': opt_decode.state_dict(),
247 | }, '%s/%d.pth' % (saved_model_folder, iteration))
248 |
249 | torch.save({'s': style_encoder.state_dict(),
250 | 'c': content_encoder.state_dict(),
251 | 'd': decoder.state_dict(),
252 | 'opt_s': opt_style.state_dict(),
253 | 'opt_c': opt_content.state_dict(),
254 | 'opt_s_cls': opt_s_cls.state_dict(),
255 | 'opt_d': opt_decode.state_dict(),
256 | }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
257 |
258 |
259 | if __name__ == "__main__":
260 | train()
261 |
--------------------------------------------------------------------------------
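
`train_step_1.py` imports `make_folders` and `AverageMeter` from `utils.py`, and `train_step_2.py` additionally imports `copy_G_params`, `d_hinge_loss` and `g_hinge_loss`; `utils.py` itself is not part of this listing. The sketch below shows conventional implementations that match how these helpers are called in the two scripts; the repository's actual `utils.py` may differ.

```
# utils.py (hypothetical sketch): conventional implementations matching how the
# helpers are called in train_step_1.py and train_step_2.py.
import os
import copy
import torch.nn.functional as F


class AverageMeter:
    """Keeps a running average of a scalar, used for loss logging."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def make_folders(save_folder, trial_name):
    # creates <save_folder>train_results/<trial_name>/{images,models}, matching
    # the checkpoint path reconstructed in train_step_2.py
    saved_image_folder = os.path.join(save_folder, 'train_results', trial_name, 'images')
    saved_model_folder = os.path.join(save_folder, 'train_results', trial_name, 'models')
    os.makedirs(saved_image_folder, exist_ok=True)
    os.makedirs(saved_model_folder, exist_ok=True)
    return saved_image_folder, saved_model_folder


def copy_G_params(model):
    # detached copy of the generator parameters, used as the EMA weights
    return copy.deepcopy([p.data for p in model.parameters()])


def d_hinge_loss(real_pred, fake_pred):
    # standard hinge loss for the discriminator
    return F.relu(1.0 - real_pred).mean() + F.relu(1.0 + fake_pred).mean()


def g_hinge_loss(fake_pred):
    # standard hinge loss for the generator
    return -fake_pred.mean()
```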
/styleme/train_step_2.py:
--------------------------------------------------------------------------------
1 | #################################
2 | # train_step_2 #
3 | # #
4 | # optimize transform images #
5 | #################################
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from torch import optim
11 |
12 | from torch.utils.data import DataLoader
13 | from torchvision import utils as vutils
14 |
15 | import os
16 | from tqdm import tqdm
17 | from datetime import datetime
18 | import pandas as pd
19 |
20 | from datasets import PairedDataset, InfiniteSamplerWrapper
21 | from utils import copy_G_params, make_folders, AverageMeter, d_hinge_loss, g_hinge_loss
22 | from models import AE, Discriminator
23 |
24 |
25 | def make_matrix(dataset_rgb, dataset_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name):
26 | dataloader_rgb = iter(DataLoader(dataset_rgb, BATCH_SIZE, shuffle=True))
27 | dataloader_skt = iter(DataLoader(dataset_skt, BATCH_SIZE, shuffle=True))
28 |
29 | rgb_img = next(dataloader_rgb)
30 | skt_img = next(dataloader_skt)
31 |
32 | skt_img = skt_img.mean(dim=1, keepdim=True)
33 |
34 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)]
35 | image_matrix.append(rgb_img.clone())
36 | with torch.no_grad():
37 | rgb_img = rgb_img.cuda()
38 | for skt in skt_img:
39 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda()
40 |
41 | gimg_ae, style_feats = net_ae(input_skts, rgb_img)
42 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone())
43 | image_matrix.append(gimg_ae.cpu())
44 |
45 | g_images = net_ig(gimg_ae, style_feats).cpu()
46 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1))
47 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE))
48 |
49 | image_matrix = torch.cat(image_matrix)
50 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1)
51 |
52 |
53 | def save_csv(save_csv_path, iters, MSE, FID):
54 | time = '{}'.format(datetime.now())
55 | iters = '{}'.format(iters)
56 | MSE = '{:.5f}'.format(MSE)
57 | FID = '{:.5f}'.format(FID)
58 |
59 | print('------ Saving csv ------')
60 | list = [time, iters, MSE, FID]
61 |
62 | data = pd.DataFrame([list])
63 | data.to_csv(save_csv_path, mode='a', header=False, index=False)
64 |
65 |
66 | def train():
67 | from benchmark import load_patched_inception_v3
68 | import lpips
69 |
70 | from config import IM_SIZE_GAN, BATCH_SIZE_GAN, CHANNEL, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_GAN, \
71 | ITERATION_AE, GAN_CKECKPOINT
72 | from config import SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, MULTI_GPU
73 | from config import PRETRAINED_AE_PATH
74 | from config import data_root_colorful, data_root_sketch
75 |
76 | inception = load_patched_inception_v3().cuda()
77 | inception.eval()
78 |
79 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)
80 |
81 | # save path
82 | save_csv_path = './checkpoint/train_results.csv'
83 |
84 | saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'Train_step_2')
85 |
86 | if not os.path.exists(save_csv_path):
87 | df = pd.DataFrame(columns=['time', 'iters', 'Lpips', 'FID'])
88 | df.to_csv(save_csv_path, index=False)
89 | print('csv created successfully!')
90 | else:
91 | print('csv already exists!')
92 |
93 | # load dataset
94 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE_GAN)
95 | print('the dataset contains %d images.' % len(dataset))
96 | dataloader = iter(DataLoader(dataset, BATCH_SIZE_GAN, sampler=InfiniteSamplerWrapper(dataset),
97 | num_workers=DATALOADER_WORKERS, pin_memory=True))
98 |
99 | # load ae model
100 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS)
101 |
102 | if PRETRAINED_AE_PATH is None:
103 | PRETRAINED_AE_PATH = SAVE_FOLDER + 'train_results/Train_step_1/' + 'models/%d.pth' % ITERATION_AE
104 | else:
105 | PRETRAINED_AE_PATH = PRETRAINED_AE_PATH
106 |
107 | print('Pre-trained AE path : ', PRETRAINED_AE_PATH)
108 |
109 | net_ae.load_state_dicts(PRETRAINED_AE_PATH)
110 | net_ae.cuda()
111 | net_ae.eval()
112 |
113 | from models import RefineGenerator as Generator
114 |
115 | # load generator & discriminator
116 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE_GAN).cuda()
117 | net_id = Discriminator(nc=3).cuda()
118 |
119 | if MULTI_GPU:
120 | net_ae = nn.DataParallel(net_ae)
121 | net_ig = nn.DataParallel(net_ig)
122 | net_id = nn.DataParallel(net_id)
123 |
124 | net_ig_ema = copy_G_params(net_ig)
125 |
126 | opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.8, 0.999))
127 | opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.8, 0.999))
128 |
129 | if GAN_CKECKPOINT is not None:
130 | ckpt = torch.load(GAN_CKECKPOINT)
131 | net_ig.load_state_dict(ckpt['ig'])
132 | net_id.load_state_dict(ckpt['id'])
133 | net_ig_ema = ckpt['ig_ema']
134 | opt_ig.load_state_dict(ckpt['opt_ig'])
135 | opt_id.load_state_dict(ckpt['opt_id'])
136 | print('Pre-trained GAN path : ', GAN_CKECKPOINT)
137 |
138 | # loss log
139 | losses_g_img = AverageMeter()
140 | losses_d_img = AverageMeter()
141 | losses_mse = AverageMeter()
142 | losses_style = AverageMeter()
143 | losses_content = AverageMeter()
144 | losses_rec_ae = AverageMeter()
145 |
146 | fid_init = 1000.0
147 |
148 | ###################
149 | # train gan #
150 | ###################
151 | for epoch in range(EPOCH_GAN):
152 | for iteration in tqdm(range(ITERATION_GAN)):
153 | rgb_img, skt_img = next(dataloader)
154 |
155 | rgb_img = rgb_img.cuda()
156 | skt_img = skt_img.cuda()
157 |
158 | # 1. train Discriminator
159 | gimg_ae, style_feats = net_ae(skt_img, rgb_img)
160 | g_image = net_ig(gimg_ae, style_feats)
161 |
162 | real = net_id(rgb_img)
163 | fake = net_id(g_image.detach())
164 |
165 | loss_d = d_hinge_loss(real, fake)
166 |
167 | net_id.zero_grad()
168 | loss_d.backward()
169 | opt_id.step()
170 |
171 | # log ae loss
172 | loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(gimg_ae, rgb_img)
173 | losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)
174 |
175 | # 2. train Generator
176 | pred_g = net_id(g_image)
177 | loss_g = g_hinge_loss(pred_g)
178 |
179 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256),
180 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
181 | losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)
182 |
183 | _, g_style_feats = net_ae(skt_img, g_image)
184 |
185 | loss_style = 0
186 | for loss_idx in range(3):
187 | loss_style += - F.cosine_similarity(g_style_feats[loss_idx],
188 | style_feats[loss_idx].detach()).mean() + \
189 | F.cosine_similarity(g_style_feats[loss_idx],
190 | style_feats[loss_idx][torch.randperm(BATCH_SIZE_GAN)]
191 | .detach()).mean()
192 | losses_style.update(loss_style.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)
193 |
194 | loss_all = loss_g + loss_mse + loss_style
195 |
196 | net_ig.zero_grad()
197 | loss_all.backward()
198 | opt_ig.step()
199 |
200 | for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
201 | avg_p.mul_(0.999).add_(p.data, alpha=0.001)
202 |
203 | # 3. logging
204 | losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
205 | losses_d_img.update(real.mean().item(), BATCH_SIZE_GAN)
206 |
207 | # 4. save model
208 | if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
209 | print('Saving history model')
210 | torch.save({'ig': net_ig.state_dict(),
211 | 'id': net_id.state_dict(),
212 | 'ae': net_ae.state_dict(),
213 | 'ig_ema': net_ig_ema,
214 | 'opt_ig': opt_ig.state_dict(),
215 | 'opt_id': opt_id.state_dict(),
216 | }, '%s/%d.pth' % (saved_model_folder, epoch))
217 |
218 | # 5. print log
219 | if iteration % LOG_INTERVAL == 0:
220 | # calculate LPIPS and FID
221 | cal_lpips = calculate_Lpips(data_root_colorful, data_root_sketch, net_ae, net_ig)
222 | cal_fid = calculate_fid(data_root_colorful, data_root_sketch, net_ae, net_ig)
223 |
224 | log_msg = ' \nGAN_Iter: [{0}/{1}] AE_loss: {ae_loss: .5f} \n' \
225 | 'Generator: {losses_g_img.avg:.4f} Discriminator: {losses_d_img.avg:.4f} \n' \
226 | 'Style: {losses_style.avg:.5f} Content: {losses_content.avg:.5f} \n' \
227 | 'Lpips: {lpips:.4f} FID: {fid:.4f}\n'.format(
228 | epoch, iteration, ae_loss=losses_rec_ae.avg, losses_g_img=losses_g_img,
229 | losses_d_img=losses_d_img, losses_style=losses_style, losses_content=losses_content,
230 | lpips=cal_lpips, fid=cal_fid)
231 |
232 | print(log_msg)
233 |
234 | save_csv(save_csv_path, epoch * ITERATION_GAN + iteration, cal_lpips, cal_fid)
235 |
236 |                 # save the model whenever the FID improves
237 | if cal_fid < fid_init:
238 | fid_init = cal_fid
239 | print('Saving history model')
240 | torch.save({'ig': net_ig.state_dict(),
241 | 'id': net_id.state_dict(),
242 | 'ae': net_ae.state_dict(),
243 | 'ig_ema': net_ig_ema,
244 | 'opt_ig': opt_ig.state_dict(),
245 | 'opt_id': opt_id.state_dict(),
246 | }, '%s/%d_%d.pth' % (saved_model_folder, epoch, iteration))
247 |
248 | losses_g_img.reset()
249 | losses_d_img.reset()
250 | losses_mse.reset()
251 | losses_style.reset()
252 | losses_content.reset()
253 | losses_rec_ae.reset()
254 |
255 |
256 | def calculate_Lpips(data_root_colorful, data_root_sketch, net_ae, net_ig):
257 | import lpips
258 |
259 | IM_SIZE = 256
260 | BATCH_SIZE = 6
261 | DATALOADER_WORKERS = 0
262 |
263 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)
264 |
265 | # load dataset
266 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE)
267 | print('the dataset contains %d images.' % len(dataset))
268 |
269 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset),
270 | num_workers=DATALOADER_WORKERS, pin_memory=True))
271 |
272 | net_ae.eval()
273 | net_ig.eval()
274 |
275 |     # running average of the (scaled) LPIPS perceptual score
276 | get_lpips = AverageMeter()
277 | lpips_list = []
278 |
279 |     # evaluate LPIPS over 100 randomly sampled batches
280 | for iter_data in tqdm(range(100)):
281 | rgb_img, skt_img = next(dataloader)
282 |
283 | rgb_img = rgb_img.cuda()
284 | skt_img = skt_img.cuda()
285 |
286 | gimg_ae, style_feats = net_ae(skt_img, rgb_img)
287 | g_image = net_ig(gimg_ae, style_feats)
288 |
289 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256),
290 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
291 | get_lpips.update(loss_mse.item() / BATCH_SIZE, BATCH_SIZE)
292 |
293 | lpips_list.append(get_lpips.avg)
294 |
295 | print('LPIPS : ', sum(lpips_list) / len(lpips_list))
296 |
297 | return sum(lpips_list) / len(lpips_list)
298 |
299 |
300 | def calculate_fid(data_root_colorful, data_root_sketch, net_ae, net_ig):
301 | from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, \
302 | real_image_loader, image_generator
303 | import numpy as np
304 |
305 | IM_SIZE = 256
306 | BATCH_SIZE = 8
307 | DATALOADER_WORKERS = 0
308 | fid_batch_images = 119
309 | fid_iters = 10
310 | inception = load_patched_inception_v3().cuda()
311 | inception.eval()
312 |
313 | fid = []
314 |
315 | # load dataset
316 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE)
317 | print('the dataset contains %d images.' % len(dataset))
318 |
319 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset),
320 | num_workers=DATALOADER_WORKERS, pin_memory=True))
321 |
322 | net_ae.eval()
323 | net_ig.eval()
324 |
325 | print("calculating FID ...")
326 |
327 | real_features = extract_feature_from_generator_fn(
328 | real_image_loader(dataloader, n_batches=fid_batch_images), inception)
329 | real_mean = np.mean(real_features, 0)
330 | real_cov = np.cov(real_features, rowvar=False)
331 | real_features = {'feats': real_features, 'mean': real_mean, 'cov': real_cov}
332 |
333 | for iter_fid in range(fid_iters):
334 | sample_features = extract_feature_from_generator_fn(
335 | image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images),
336 | inception, total=fid_batch_images // BATCH_SIZE - 1)
337 | cur_fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov'])
338 |
339 | print('FID[{}]: '.format(iter_fid), cur_fid)
340 | fid.append(cur_fid)
341 |
342 | print('FID: ', sum(fid) / len(fid))
343 |
344 | return sum(fid) / len(fid)
345 |
346 |
347 | if __name__ == "__main__":
348 | train()
349 |
--------------------------------------------------------------------------------
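
Note: the FID values logged by `calculate_fid()` above come from `calc_fid` in `benchmark.py`. For readers who want the formula in one place, below is a minimal, illustrative sketch of the standard Fréchet Inception Distance computed from the same mean/covariance statistics built above; `fid_from_stats` is a hypothetical helper, not the repository's own implementation, whose details may differ.

```python
# Illustrative only: standard FID between two Gaussians fitted to Inception features.
import numpy as np
from scipy import linalg


def fid_from_stats(mean1, cov1, mean2, cov2, eps=1e-6):
    """FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrt(C1 @ C2))."""
    diff = mean1 - mean2
    covmean, _ = linalg.sqrtm(cov1.dot(cov2), disp=False)
    if not np.isfinite(covmean).all():
        # fall back to a slightly regularised product if it is near-singular
        offset = np.eye(cov1.shape[0]) * eps
        covmean = linalg.sqrtm((cov1 + offset).dot(cov2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean)


# e.g. fid_from_stats(sample_mean, sample_cov, real_features['mean'], real_features['cov'])
```
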
/styleme/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from copy import deepcopy
4 | from random import shuffle
5 | import torch.nn.functional as F
6 |
7 |
8 | def d_hinge_loss(real_pred, fake_pred):  # hinge loss for the discriminator
9 | real_loss = F.relu(1 - real_pred)
10 | fake_loss = F.relu(1 + fake_pred)
11 |
12 | return real_loss.mean() + fake_loss.mean()
13 |
14 |
15 | def g_hinge_loss(pred):  # hinge loss for the generator: -E[D(G(z))]
16 | return -pred.mean()
17 |
18 |
19 | class AverageMeter(object):
20 |
21 | def __init__(self):
22 | self.reset()
23 |
24 | def reset(self):
25 | self.val = 0
26 | self.avg = 0
27 | self.sum = 0
28 | self.count = 0
29 |
30 | def update(self, val, n=1):
31 | self.val = val
32 | self.sum += val * n
33 | self.count += n
34 | self.avg = self.sum / self.count
35 |
36 |
37 | def true_randperm(size, device='cuda'):  # random derangement: a permutation with no fixed points
38 | def unmatched_randperm(size):
39 | l1 = [i for i in range(size)]
40 | l2 = []
41 | for j in range(size):
42 | deleted = False
43 | if j in l1:
44 | deleted = True
45 | del l1[l1.index(j)]
46 | shuffle(l1)
47 | if len(l1) == 0:
48 | return 0, False
49 | l2.append(l1[0])
50 | del l1[0]
51 | if deleted:
52 | l1.append(j)
53 | return l2, True
54 |
55 | flag = False
56 | l = torch.zeros(size).long()
57 | while not flag:
58 | l, flag = unmatched_randperm(size)
59 | return torch.LongTensor(l).to(device)
60 |
61 |
62 | def copy_G_params(model):  # deep-copied parameter snapshot, used as the EMA buffer
63 | flatten = deepcopy(list(p.data for p in model.parameters()))
64 | return flatten
65 |
66 |
67 | def load_params(model, new_param):  # copy (e.g. EMA) parameters back into the model
68 | for p, new_p in zip(model.parameters(), new_param):
69 | p.data.copy_(new_p)
70 |
71 |
72 | def make_folders(save_folder, trial_name):
73 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models' % trial_name)
74 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images' % trial_name)
75 | folders = [os.path.join(save_folder, 'train_results'),
76 | os.path.join(save_folder, 'train_results/%s' % trial_name),
77 | os.path.join(save_folder, 'train_results/%s/images' % trial_name),
78 | os.path.join(save_folder, 'train_results/%s/models' % trial_name)]
79 | for folder in folders:
80 | if not os.path.exists(folder):
81 | os.mkdir(folder)
82 |
83 | return saved_image_folder, saved_model_folder
84 |
--------------------------------------------------------------------------------
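
The helpers in `utils.py` are easiest to read together. The sketch below is an illustrative, CPU-only toy (assuming it is run from the `styleme/` directory so `utils.py` is importable) showing how the hinge losses, `AverageMeter`, `copy_G_params`, and `load_params` combine in a GAN training step; the `nn.Linear` generator/discriminator are hypothetical stand-ins for the real networks in `models.py`.

```python
# Toy example only: wiring of the utils.py helpers, not the StyleMe training loop itself.
import torch
import torch.nn as nn

from utils import AverageMeter, copy_G_params, d_hinge_loss, g_hinge_loss, load_params

G = nn.Linear(16, 16)   # stand-in generator
D = nn.Linear(16, 1)    # stand-in discriminator
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.8, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.8, 0.999))

g_ema = copy_G_params(G)                    # EMA buffer: a deep copy of G's parameters
d_meter, g_meter = AverageMeter(), AverageMeter()

for step in range(10):
    z = torch.randn(4, 16)
    real = torch.randn(4, 16)

    # 1. discriminator step with the hinge loss
    fake = G(z)
    loss_d = d_hinge_loss(D(real), D(fake.detach()))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # 2. generator step with the hinge loss
    loss_g = g_hinge_loss(D(G(z)))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()

    # 3. EMA update, mirroring the 0.999 / 0.001 decay used in the training script
    for p, avg_p in zip(G.parameters(), g_ema):
        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

    d_meter.update(loss_d.item(), n=4)
    g_meter.update(loss_g.item(), n=4)

load_params(G, g_ema)   # swap the averaged weights into G before evaluation
print('avg D loss: %.4f | avg G loss: %.4f' % (d_meter.avg, g_meter.avg))
```
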