├── README.md ├── fig ├── fid_lpips.png ├── performance.png ├── sketch.png └── styleme.png ├── sketch_generation ├── README.md ├── benchmark.py ├── evaluate.py ├── metrics.py ├── models.py ├── operation.py ├── requirement.txt ├── train.py ├── utils.py ├── vgg-feature-weights.z01 ├── vgg-feature-weights.z02 └── vgg-feature-weights.zip └── styleme ├── benchmark.py ├── calculate.py ├── config.py ├── datasets.py ├── framework.png ├── generate_matrix.py ├── lpips ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── base_model.cpython-37.pyc │ ├── base_model.cpython-38.pyc │ ├── dist_model.cpython-37.pyc │ ├── dist_model.cpython-38.pyc │ ├── networks_basic.cpython-37.pyc │ ├── networks_basic.cpython-38.pyc │ ├── pretrained_networks.cpython-37.pyc │ └── pretrained_networks.cpython-38.pyc ├── base_model.py ├── dist_model.py ├── networks_basic.py ├── pretrained_networks.py └── weights │ ├── v0.0 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── models.py ├── readme.md ├── style_transform.py ├── train.py ├── train_step_1.py ├── train_step_2.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # StyleMe: Towards Intelligent Fashion Generation with Designer Style 2 | 3 | Proceedings of the 2023 CHI Conference on Human Factors in Computing Systems (**CHI 2023**) | [**Paper**](https://dl.acm.org/doi/fullHtml/10.1145/3544548.3581377) 4 | 5 | 6 | Our model contains the following two parts and datasets is available: 7 | - **image to sketch module** : [ [**sketch_generation**](https://github.com/ExponentiAI/StyleMe/tree/main/sketch_generation) ] 8 | - **sketch to image module** : [ [**style_transform**](https://github.com/ExponentiAI/StyleMe/tree/main/styleme) ] 9 | - **available dataset** : [ [**clothdataset**](https://drive.google.com/drive/folders/1tAHeblEon0Awb3QchTlLq9Knyc443i3x) ] 10 | 11 | 12 | ## 1. Video 13 | 14 |

15 | 16 |

17 | 18 | - The video link: **[StyleMe Demonstration](https://user-images.githubusercontent.com/43172916/218964923-1f99907c-4841-4cca-a961-fc771f22834f.mp4)** 19 | 20 | 21 | ## 2. Performance 22 | - Here is our model's performance: 23 | 24 | - Sketch Generation 25 |

26 | 27 |

28 | 29 | - Style Transfer 30 |

31 | 32 |

33 | 34 | - and the FID and LPIPS during training (a minimal FID computation sketch follows the figures below): 35 | 36 |

37 | 38 |

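As a rough guide to how these numbers are produced, the snippet below is a minimal sketch of an FID evaluation built on the helpers already provided in `sketch_generation/benchmark.py` (`load_patched_inception_v3`, `extract_feature_from_generator_fn`, `calc_fid`). The `real_batches` / `fake_batches` generators and the batch counts are placeholders, not part of the repository: any iterables that yield CUDA image batches in the `[-1, 1]` range used elsewhere in the code will do. LPIPS is measured with the `lpips` module bundled under `styleme/lpips`.

```python
# Minimal FID sketch built on sketch_generation/benchmark.py (run from inside
# sketch_generation/). real_batches and fake_batches are placeholder
# generators that yield CUDA image tensors scaled to [-1, 1].
import numpy as np
from benchmark import load_patched_inception_v3, extract_feature_from_generator_fn, calc_fid

inception = load_patched_inception_v3().cuda().eval()

# Pooled Inception features for real and generated images (numpy arrays).
real_feats = extract_feature_from_generator_fn(real_batches, inception, total=100)
fake_feats = extract_feature_from_generator_fn(fake_batches, inception, total=100)

# FID between the two feature distributions.
fid = calc_fid(fake_feats,
               real_mean=np.mean(real_feats, 0),
               real_cov=np.cov(real_feats, rowvar=False))
print('FID:', fid)
```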
39 | 40 | ## 3. Reference 41 | 42 | If you find our code or dataset useful for your research, please cite our paper. 43 | 44 | BibTeX: 45 | ``` 46 | @inproceedings{wu2023styleme, 47 | title={StyleMe: Towards Intelligent Fashion Generation with Designer Style}, 48 | author={Wu, Di and Yu, Zhiwang and Ma, Nan and Jiang, Jianan and Wang, Yuetian and Zhou, Guixiang and Deng, Hanhui and Li, Yi}, 49 | booktitle={Proceedings of the 2023 CHI Conference on Human Factors in Computing Systems}, 50 | pages={1--16}, 51 | year={2023} 52 | } 53 | ``` 54 | 55 | Or: 56 | ``` 57 | Di Wu, Zhiwang Yu, Nan Ma, Jianan Jiang, Yuetian Wang, Guixiang Zhou, Hanhui Deng, Yi Li: StyleMe: Towards Intelligent Fashion Generation with Designer Style. CHI 2023: 23:1-23:16 58 | ``` 59 | 60 | 61 | -------------------------------------------------------------------------------- /fig/fid_lpips.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/fid_lpips.png -------------------------------------------------------------------------------- /fig/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/performance.png -------------------------------------------------------------------------------- /fig/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/sketch.png -------------------------------------------------------------------------------- /fig/styleme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/fig/styleme.png -------------------------------------------------------------------------------- /sketch_generation/README.md: -------------------------------------------------------------------------------- 1 | # StyleMe - PyTorch 2 | A PyTorch implementation of the image-to-sketch model. 3 | Running environment: Python 3.7.0, PyTorch 1.12.1 4 | ## Data 5 | Includes RGB images and the corresponding sketch images of clothes in various styles.
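`evaluate.py` reads the RGB images with `torchvision.datasets.ImageFolder`, so the directory passed as `--path_content` is expected to contain one or more image sub-folders. The layout below is only an illustration; the folder names are placeholders, not fixed by the code:

```
clothdataset/
├── rgb/              # RGB clothing photos, grouped into one or more sub-folders
│   └── style_a/
│       ├── 0001.jpg
│       └── ...
└── sketch/           # paired sketch images used for training
    └── style_a/
        └── ...
```

With a trained checkpoint, sketch generation can then be run with something like `python evaluate.py --path_content ./clothdataset/rgb --path_result ./results --im_size 256 --checkpoint <path-to-checkpoint>.pth` (all paths here are examples); the unpacked `vgg-feature-weights.pth` from the split zip archive must sit in the working directory.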
6 | 7 | ## Description 8 | Related code comments: 9 | * models.py all the related models' structure definition, including generator and discriminator 10 | * train.py training the whole model, 11 | * evaluate.py test the model 12 | * vgg-feature-weights.pth pretrained model feature-weights 13 | -------------------------------------------------------------------------------- /sketch_generation/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from torchvision.models import inception_v3, Inception3 6 | from torchvision.utils import save_image 7 | from torchvision import utils as vutils 8 | from torch.utils.data import DataLoader 9 | 10 | try: 11 | from torchvision.models.utils import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | import numpy as np 16 | from scipy import linalg 17 | from tqdm import tqdm 18 | import pickle 19 | import os 20 | from utils import true_randperm 21 | 22 | # Inception weights ported to Pytorch from 23 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 24 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 25 | 26 | 27 | class InceptionV3(nn.Module): 28 | """Pretrained InceptionV3 network returning feature maps""" 29 | 30 | # Index of default block of inception to return, 31 | # corresponds to output of final average pooling 32 | DEFAULT_BLOCK_INDEX = 3 33 | 34 | # Maps feature dimensionality to their output blocks indices 35 | BLOCK_INDEX_BY_DIM = { 36 | 64: 0, # First max pooling features 37 | 192: 1, # Second max pooling featurs 38 | 768: 2, # Pre-aux classifier features 39 | 2048: 3 # Final average pooling features 40 | } 41 | 42 | def __init__(self, 43 | output_blocks=[DEFAULT_BLOCK_INDEX], 44 | resize_input=True, 45 | normalize_input=True, 46 | requires_grad=False, 47 | use_fid_inception=True): 48 | """Build pretrained InceptionV3 49 | Parameters 50 | ---------- 51 | output_blocks : list of int 52 | Indices of blocks to return features of. Possible values are: 53 | - 0: corresponds to output of first max pooling 54 | - 1: corresponds to output of second max pooling 55 | - 2: corresponds to output which is fed to aux classifier 56 | - 3: corresponds to output of final average pooling 57 | resize_input : bool 58 | If true, bilinearly resizes input to width and height 299 before 59 | feeding input to model. As the network without fully connected 60 | layers is fully convolutional, it should be able to handle inputs 61 | of arbitrary size, so resizing might not be strictly needed 62 | normalize_input : bool 63 | If true, scales the input from range (0, 1) to the range the 64 | pretrained Inception network expects, namely (-1, 1) 65 | requires_grad : bool 66 | If true, parameters of the model require gradients. Possibly useful 67 | for finetuning the network 68 | use_fid_inception : bool 69 | If true, uses the pretrained Inception model used in Tensorflow's 70 | FID implementation. If false, uses the pretrained Inception model 71 | available in torchvision. The FID Inception model has different 72 | weights and a slightly different structure from torchvision's 73 | Inception model. If you want to compute FID scores, you are 74 | strongly advised to set this parameter to true to get comparable 75 | results. 
76 | """ 77 | super(InceptionV3, self).__init__() 78 | 79 | self.resize_input = resize_input 80 | self.normalize_input = normalize_input 81 | self.output_blocks = sorted(output_blocks) 82 | self.last_needed_block = max(output_blocks) 83 | 84 | assert self.last_needed_block <= 3, \ 85 | 'Last possible output block index is 3' 86 | 87 | self.blocks = nn.ModuleList() 88 | 89 | if use_fid_inception: 90 | inception = fid_inception_v3() 91 | else: 92 | inception = models.inception_v3(pretrained=True) 93 | 94 | # Block 0: input to maxpool1 95 | block0 = [ 96 | inception.Conv2d_1a_3x3, 97 | inception.Conv2d_2a_3x3, 98 | inception.Conv2d_2b_3x3, 99 | nn.MaxPool2d(kernel_size=3, stride=2) 100 | ] 101 | self.blocks.append(nn.Sequential(*block0)) 102 | 103 | # Block 1: maxpool1 to maxpool2 104 | if self.last_needed_block >= 1: 105 | block1 = [ 106 | inception.Conv2d_3b_1x1, 107 | inception.Conv2d_4a_3x3, 108 | nn.MaxPool2d(kernel_size=3, stride=2) 109 | ] 110 | self.blocks.append(nn.Sequential(*block1)) 111 | 112 | # Block 2: maxpool2 to aux classifier 113 | if self.last_needed_block >= 2: 114 | block2 = [ 115 | inception.Mixed_5b, 116 | inception.Mixed_5c, 117 | inception.Mixed_5d, 118 | inception.Mixed_6a, 119 | inception.Mixed_6b, 120 | inception.Mixed_6c, 121 | inception.Mixed_6d, 122 | inception.Mixed_6e, 123 | ] 124 | self.blocks.append(nn.Sequential(*block2)) 125 | 126 | # Block 3: aux classifier to final avgpool 127 | if self.last_needed_block >= 3: 128 | block3 = [ 129 | inception.Mixed_7a, 130 | inception.Mixed_7b, 131 | inception.Mixed_7c, 132 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 133 | ] 134 | self.blocks.append(nn.Sequential(*block3)) 135 | 136 | for param in self.parameters(): 137 | param.requires_grad = requires_grad 138 | 139 | def forward(self, inp): 140 | """Get Inception feature maps 141 | Parameters 142 | ---------- 143 | inp : torch.autograd.Variable 144 | Input tensor of shape Bx3xHxW. Values are expected to be in 145 | range (0, 1) 146 | Returns 147 | ------- 148 | List of torch.autograd.Variable, corresponding to the selected output 149 | block, sorted ascending by index 150 | """ 151 | outp = [] 152 | x = inp 153 | 154 | if self.resize_input: 155 | x = F.interpolate(x, 156 | size=(299, 299), 157 | mode='bilinear', 158 | align_corners=False) 159 | 160 | if self.normalize_input: 161 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 162 | 163 | for idx, block in enumerate(self.blocks): 164 | x = block(x) 165 | if idx in self.output_blocks: 166 | outp.append(x) 167 | 168 | if idx == self.last_needed_block: 169 | break 170 | 171 | return outp 172 | 173 | 174 | def fid_inception_v3(): 175 | """Build pretrained Inception model for FID computation 176 | The Inception model for FID computation uses a different set of weights 177 | and has a slightly different structure than torchvision's Inception. 178 | This method first constructs torchvision's Inception and then patches the 179 | necessary parts that are different in the FID Inception model. 
180 | """ 181 | inception = models.inception_v3(num_classes=1008, 182 | aux_logits=False, 183 | pretrained=False) 184 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 185 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 186 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 187 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 188 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 189 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 190 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 191 | inception.Mixed_7b = FIDInceptionE_1(1280) 192 | inception.Mixed_7c = FIDInceptionE_2(2048) 193 | 194 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 195 | inception.load_state_dict(state_dict) 196 | return inception 197 | 198 | 199 | class FIDInceptionA(models.inception.InceptionA): 200 | """InceptionA block patched for FID computation""" 201 | def __init__(self, in_channels, pool_features): 202 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 203 | 204 | def forward(self, x): 205 | branch1x1 = self.branch1x1(x) 206 | 207 | branch5x5 = self.branch5x5_1(x) 208 | branch5x5 = self.branch5x5_2(branch5x5) 209 | 210 | branch3x3dbl = self.branch3x3dbl_1(x) 211 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 212 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 213 | 214 | # Patch: Tensorflow's average pool does not use the padded zero's in 215 | # its average calculation 216 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 217 | count_include_pad=False) 218 | branch_pool = self.branch_pool(branch_pool) 219 | 220 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 221 | return torch.cat(outputs, 1) 222 | 223 | 224 | class FIDInceptionC(models.inception.InceptionC): 225 | """InceptionC block patched for FID computation""" 226 | def __init__(self, in_channels, channels_7x7): 227 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 228 | 229 | def forward(self, x): 230 | branch1x1 = self.branch1x1(x) 231 | 232 | branch7x7 = self.branch7x7_1(x) 233 | branch7x7 = self.branch7x7_2(branch7x7) 234 | branch7x7 = self.branch7x7_3(branch7x7) 235 | 236 | branch7x7dbl = self.branch7x7dbl_1(x) 237 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 238 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 239 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 240 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 241 | 242 | # Patch: Tensorflow's average pool does not use the padded zero's in 243 | # its average calculation 244 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 245 | count_include_pad=False) 246 | branch_pool = self.branch_pool(branch_pool) 247 | 248 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 249 | return torch.cat(outputs, 1) 250 | 251 | 252 | class FIDInceptionE_1(models.inception.InceptionE): 253 | """First InceptionE block patched for FID computation""" 254 | def __init__(self, in_channels): 255 | super(FIDInceptionE_1, self).__init__(in_channels) 256 | 257 | def forward(self, x): 258 | branch1x1 = self.branch1x1(x) 259 | 260 | branch3x3 = self.branch3x3_1(x) 261 | branch3x3 = [ 262 | self.branch3x3_2a(branch3x3), 263 | self.branch3x3_2b(branch3x3), 264 | ] 265 | branch3x3 = torch.cat(branch3x3, 1) 266 | 267 | branch3x3dbl = self.branch3x3dbl_1(x) 268 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 269 | branch3x3dbl = [ 270 | self.branch3x3dbl_3a(branch3x3dbl), 271 | self.branch3x3dbl_3b(branch3x3dbl), 272 | ] 
273 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 274 | 275 | # Patch: Tensorflow's average pool does not use the padded zero's in 276 | # its average calculation 277 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 278 | count_include_pad=False) 279 | branch_pool = self.branch_pool(branch_pool) 280 | 281 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 282 | return torch.cat(outputs, 1) 283 | 284 | 285 | class FIDInceptionE_2(models.inception.InceptionE): 286 | """Second InceptionE block patched for FID computation""" 287 | def __init__(self, in_channels): 288 | super(FIDInceptionE_2, self).__init__(in_channels) 289 | 290 | def forward(self, x): 291 | branch1x1 = self.branch1x1(x) 292 | 293 | branch3x3 = self.branch3x3_1(x) 294 | branch3x3 = [ 295 | self.branch3x3_2a(branch3x3), 296 | self.branch3x3_2b(branch3x3), 297 | ] 298 | branch3x3 = torch.cat(branch3x3, 1) 299 | 300 | branch3x3dbl = self.branch3x3dbl_1(x) 301 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 302 | branch3x3dbl = [ 303 | self.branch3x3dbl_3a(branch3x3dbl), 304 | self.branch3x3dbl_3b(branch3x3dbl), 305 | ] 306 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 307 | 308 | # Patch: The FID Inception model uses max pooling instead of average 309 | # pooling. This is likely an error in this specific Inception 310 | # implementation, as other Inception models use average pooling here 311 | # (which matches the description in the paper). 312 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 313 | branch_pool = self.branch_pool(branch_pool) 314 | 315 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 316 | return torch.cat(outputs, 1) 317 | 318 | 319 | class Inception3Feature(Inception3): 320 | def forward(self, x): 321 | if x.shape[2] != 299 or x.shape[3] != 299: 322 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 323 | 324 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 325 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 326 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 327 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 328 | 329 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 330 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 331 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 332 | 333 | x = self.Mixed_5b(x) # 35 x 35 x 192 334 | x = self.Mixed_5c(x) # 35 x 35 x 256 335 | x = self.Mixed_5d(x) # 35 x 35 x 288 336 | 337 | x = self.Mixed_6a(x) # 35 x 35 x 288 338 | x = self.Mixed_6b(x) # 17 x 17 x 768 339 | x = self.Mixed_6c(x) # 17 x 17 x 768 340 | x = self.Mixed_6d(x) # 17 x 17 x 768 341 | x = self.Mixed_6e(x) # 17 x 17 x 768 342 | 343 | x = self.Mixed_7a(x) # 17 x 17 x 768 344 | x = self.Mixed_7b(x) # 8 x 8 x 1280 345 | x = self.Mixed_7c(x) # 8 x 8 x 2048 346 | 347 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 348 | 349 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 350 | 351 | 352 | def load_patched_inception_v3(): 353 | # inception = inception_v3(pretrained=True) 354 | # inception_feat = Inception3Feature() 355 | # inception_feat.load_state_dict(inception.state_dict()) 356 | inception_feat = InceptionV3([3], normalize_input=False) 357 | 358 | return inception_feat 359 | 360 | 361 | @torch.no_grad() 362 | def extract_features(loader, inception, device): 363 | pbar = tqdm(loader) 364 | 365 | feature_list = [] 366 | 367 | for img in pbar: 368 | img = img.to(device) 369 | feature = inception(img)[0].view(img.shape[0], -1) 370 | feature_list.append(feature.to('cpu')) 371 | 372 | features = torch.cat(feature_list, 
0) 373 | 374 | return features 375 | 376 | 377 | 378 | 379 | 380 | 381 | @torch.no_grad() 382 | def extract_feature_from_generator_fn(generator_fn, inception, device='cuda', total=1000): 383 | features = [] 384 | 385 | for batch in tqdm(generator_fn, total=total): 386 | try: 387 | feat = inception(batch)[0].view(batch.shape[0], -1) 388 | features.append(feat.to('cpu')) 389 | except: 390 | break 391 | features = torch.cat(features, 0).detach() 392 | return features.numpy() 393 | 394 | 395 | def calc_fid(sample_features, real_features=None, real_mean=None, real_cov=None, eps=1e-6): 396 | sample_mean = np.mean(sample_features, 0) 397 | sample_cov = np.cov(sample_features, rowvar=False) 398 | 399 | if real_features is not None: 400 | real_mean = np.mean(real_features, 0) 401 | real_cov = np.cov(real_features, rowvar=False) 402 | 403 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 404 | 405 | if not np.isfinite(cov_sqrt).all(): 406 | print('product of cov matrices is singular') 407 | offset = np.eye(sample_cov.shape[0]) * eps 408 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 409 | 410 | if np.iscomplexobj(cov_sqrt): 411 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 412 | m = np.max(np.abs(cov_sqrt.imag)) 413 | 414 | raise ValueError(f'Imaginary component {m}') 415 | 416 | cov_sqrt = cov_sqrt.real 417 | 418 | mean_diff = sample_mean - real_mean 419 | mean_norm = mean_diff @ mean_diff 420 | 421 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 422 | 423 | fid = mean_norm + trace 424 | 425 | return fid 426 | 427 | def real_image_loader(dataloader, n_batches=10): 428 | counter = 0 429 | while counter < n_batches: 430 | counter += 1 431 | rgb_img = next(dataloader)[0] 432 | if counter == 1: 433 | vutils.save_image(0.5*(rgb_img+1), 'tmp_real.jpg') 434 | yield rgb_img.cuda() 435 | 436 | 437 | 438 | 439 | @torch.no_grad() 440 | def image_generator(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500): 441 | counter = 0 442 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=4, pin_memory=False)) 443 | n_batches = min( n_batches, len(dataset)//BATCH_SIZE-1 ) 444 | while counter < n_batches: 445 | counter += 1 446 | rgb_img, _, _, skt_img = next(dataloader) 447 | rgb_img = F.interpolate( rgb_img, size=512 ).cuda() 448 | skt_img = F.interpolate( skt_img, size=512 ).cuda() 449 | 450 | gimg_ae, style_feat = net_ae(skt_img, rgb_img) 451 | g_image = net_ig(gimg_ae, style_feat) 452 | if counter == 1: 453 | vutils.save_image(0.5*(g_image+1), 'tmp.jpg') 454 | yield g_image 455 | 456 | 457 | @torch.no_grad() 458 | def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500): 459 | counter = 0 460 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False)) 461 | n_batches = min( n_batches, len(dataset)//BATCH_SIZE-1 ) 462 | while counter < n_batches: 463 | counter += 1 464 | rgb_img, _, _, skt_img = next(dataloader) 465 | rgb_img = F.interpolate( rgb_img, size=512 ).cuda() 466 | skt_img = F.interpolate( skt_img, size=512 ).cuda() 467 | 468 | perm = true_randperm(rgb_img.shape[0], device=rgb_img.device) 469 | 470 | gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm]) 471 | g_image = net_ig(gimg_ae, style_feat) 472 | if counter == 1: 473 | vutils.save_image(0.5*(g_image+1), 'tmp.jpg') 474 | yield g_image 475 | 476 | 477 | 478 | if __name__ == "__main__": 479 | from utils import PairedMultiDataset, InfiniteSamplerWrapper, 
make_folders, AverageMeter 480 | from torch.utils.data import DataLoader 481 | from torchvision import utils as vutils 482 | IM_SIZE = 512 483 | BATCH_SIZE = 8 484 | DATALOADER_WORKERS = 8 485 | NBR_CLS = 2000 486 | TRIAL_NAME = 'trial_vae_512_1' 487 | SAVE_FOLDER = './' 488 | 489 | data_root_colorful = '../images/celebA/CelebA_512_test/img' 490 | data_root_sketch_1 = './sketch_simplification/vggadin_iter_700_test' 491 | data_root_sketch_2 = './sketch_simplification/vggadin_iter_1900_test' 492 | data_root_sketch_3 = './sketch_simplification/vggadin_iter_2300_test' 493 | 494 | dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3, im_size=IM_SIZE, rand_crop=False) 495 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=DATALOADER_WORKERS, pin_memory=True)) 496 | 497 | 498 | from models import StyleEncoder, ContentEncoder, Decoder 499 | import pickle 500 | from models import AE, RefineGenerator 501 | from utils import load_params 502 | 503 | net_ig = RefineGenerator().cuda() 504 | net_ig = nn.DataParallel(net_ig) 505 | 506 | ckpt = './train_results/trial_refine_ae_as_gan_1024_2/models/4.pth' 507 | if ckpt is not None: 508 | ckpt = torch.load(ckpt) 509 | #net_ig.load_state_dict(ckpt['ig']) 510 | #net_id.load_state_dict(ckpt['id']) 511 | net_ig_ema = ckpt['ig_ema'] 512 | load_params(net_ig, net_ig_ema) 513 | net_ig = net_ig.module 514 | #net_ig.eval() 515 | 516 | net_ae = AE() 517 | net_ae.load_state_dicts('./train_results/trial_vae_512_1/models/176000.pth') 518 | net_ae.cuda() 519 | net_ae.eval() 520 | 521 | inception = load_patched_inception_v3().cuda() 522 | inception.eval() 523 | 524 | ''' 525 | real_features = extract_feature_from_generator_fn( 526 | real_image_loader(dataloader, n_batches=1000), inception ) 527 | real_mean = np.mean(real_features, 0) 528 | real_cov = np.cov(real_features, rowvar=False) 529 | ''' 530 | #pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov}, open('celeba_fid_feats.npy','wb') ) 531 | 532 | real_features = pickle.load( open('celeba_fid_feats.npy', 'rb') ) 533 | real_mean = real_features['mean'] 534 | real_cov = real_features['cov'] 535 | #sample_features = extract_feature_from_generator_fn( real_image_loader(dataloader, n_batches=100), inception ) 536 | for it in range(1): 537 | itx = it * 8000 538 | ''' 539 | ckpt = torch.load('./train_results/%s/models/%d.pth'%(TRIAL_NAME, itx)) 540 | 541 | style_encoder.load_state_dict(ckpt['e']) 542 | content_encoder.load_state_dict(ckpt['c']) 543 | decoder.load_state_dict(ckpt['d']) 544 | 545 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True)) 546 | ''' 547 | 548 | sample_features = extract_feature_from_generator_fn( 549 | image_generator(dataset, net_ae, net_ig, n_batches=1800), inception, 550 | total=1800 ) 551 | 552 | #fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov']) 553 | fid = calc_fid(sample_features, real_mean=real_mean, real_cov=real_cov) 554 | print(it, fid) 555 | 556 | real_features = extract_feature_from_generator_fn( 557 | real_image_loader(dataloader, n_batches=fid_batch_images), inception) 558 | real_mean = np.mean(real_features, 0) 559 | real_cov = np.cov(real_features, rowvar=False) 560 | pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov}, 561 | open('%s_fid_feats.npy' % (DATA_NAME), 'wb')) 562 | real_features = pickle.load(open('%s_fid_feats.npy' % 
(DATA_NAME), 'rb')) -------------------------------------------------------------------------------- /sketch_generation/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.datasets as Dataset 5 | import torchvision.utils as vutils 6 | from torch import nn 7 | 8 | 9 | from models import Generator, VGGSimple 10 | from operation import trans_maker_testing 11 | 12 | import argparse 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | parser = argparse.ArgumentParser(description='Style transfer GAN, during training, the model will learn to take a image from one specific catagory and transform it into another style domain') 18 | 19 | parser.add_argument('--path_content', type=str, help='path of resource dataset, should be a folder that has one or many sub image folders inside') 20 | parser.add_argument('--path_result', type=str, help='path to save the result images') 21 | parser.add_argument('--im_size', type=int, default=256, help='resolution of the generated images') 22 | 23 | parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.') 24 | parser.add_argument('--norm_layer', type=str, default="instance", help='can choose between [batch, instance]') 25 | parser.add_argument('--checkpoint', type=str, help='specify the path of the pre-trained model') 26 | 27 | args = parser.parse_args() 28 | 29 | print(str(args)) 30 | 31 | device = torch.device("cuda:%d"%(args.gpu_id)) 32 | 33 | im_size = args.im_size 34 | if im_size == 128: 35 | base = 4 36 | elif im_size == 256: 37 | base = 8 38 | elif im_size == 512: 39 | base = 16 40 | elif im_size == 1024: 41 | base = 32 42 | if im_size not in [128, 256, 512, 1024]: 43 | print("the size must be in [128, 256, 512, 1024]") 44 | 45 | vgg = VGGSimple() 46 | vgg.load_state_dict(torch.load('./vgg-feature-weights.pth', map_location=lambda a,b:a)) 47 | vgg.to(device) 48 | vgg.eval() 49 | for p in vgg.parameters(): 50 | p.requires_grad = False 51 | 52 | dataset = Dataset.ImageFolder(root=args.path_content, transform=trans_maker_testing(size=args.im_size)) 53 | 54 | net_g = Generator(infc=256, nfc=128) 55 | 56 | if args.checkpoint is not 'None': 57 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage) 58 | net_g.load_state_dict(checkpoint['g']) 59 | print("saved model loaded") 60 | 61 | net_g.to(device) 62 | net_g.eval() 63 | 64 | dist_path = args.path_result 65 | if not os.path.exists(dist_path): 66 | os.mkdir(dist_path) 67 | 68 | 69 | print("begin generating images ...") 70 | with torch.no_grad(): 71 | for i in range(len(dataset)): 72 | print("generating the %dth image"%(i)) 73 | img = dataset[i][0].to(device) 74 | feat = vgg(img, base=base)[2] 75 | g_img = net_g(feat) 76 | 77 | g_img = g_img.mean(1).unsqueeze(1).detach().add(1).mul(0.5) 78 | g_img = (g_img > 0.7).float() 79 | vutils.save_image(g_img, os.path.join(dist_path, '%d.jpg'%(i))) -------------------------------------------------------------------------------- /sketch_generation/metrics.py: -------------------------------------------------------------------------------- 1 | # example of calculating the frechet inception distance in Keras 2 | import numpy 3 | import os 4 | import cv2 5 | import argparse 6 | import torch 7 | import numpy as np 8 | from scipy.linalg import sqrtm 9 | from keras.applications.inception_v3 import InceptionV3 10 | from keras.applications.inception_v3 import preprocess_input 11 | 12 | 13 | # 
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error 14 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' # 只显示 Error 15 | 16 | # calculate frechet inception distance 17 | def calculate_fid(model, images1, images2): 18 | # calculate activations 19 | act1 = model.predict(images1) 20 | act2 = model.predict(images2) 21 | # calculate mean and covariance statistics 22 | mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False) 23 | mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False) 24 | # calculate sum squared difference between means 25 | ssdiff = numpy.sum((mu1 - mu2)**2.0) 26 | # calculate sqrt of product between cov 27 | covmean = sqrtm(np.dot(sigma1, sigma2)) 28 | # check and correct imaginary numbers from sqrt 29 | if np.iscomplexobj(covmean): 30 | covmean = covmean.real 31 | # calculate score 32 | fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) 33 | return fid 34 | 35 | #act1 =generatedImg ,act2 = realImg 36 | def calculate_fid_modify(act1,act2): 37 | # calculate activations 38 | # act1 = model.predict(images1) 39 | # act2 = model.predict(images2) 40 | # calculate mean and covariance statistics 41 | mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False) 42 | mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False) 43 | # calculate sum squared difference between means 44 | ssdiff = numpy.sum((mu1 - mu2)**2.0) 45 | # calculate sqrt of product between cov 46 | covmean = sqrtm(np.dot(sigma1, sigma2)) 47 | # check and correct imaginary numbers from sqrt 48 | if np.iscomplexobj(covmean): 49 | covmean = covmean.real 50 | # calculate score 51 | fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) 52 | return fid 53 | 54 | def data_list(dirPath): 55 | generated_Dataset = [] 56 | real_Dataset = [] 57 | for root, dirs, files in os.walk(dirPath): 58 | for filename in sorted(files): # sorted已排序的列表副本 59 | # 判断该文件是否是目标文件 60 | if "generated" in filename: 61 | generatedPath = root + '/' + filename 62 | generatedImg = cv2.imread(generatedPath).astype('float32') 63 | generated_Dataset.append(generatedImg) 64 | # 对比图片路径 65 | realPath = root + '/' + filename.replace('generated', 'real') 66 | realImg = cv2.imread(realPath).astype('float32') 67 | real_Dataset.append(realImg) 68 | return generated_Dataset, real_Dataset 69 | 70 | if __name__ == '__main__': 71 | ### 参数设定 72 | parser = argparse.ArgumentParser() 73 | # parser.add_argument('--dataset_dir', type=str, default='./results/hrnet/', help='results') 74 | parser.add_argument('--dataset_dir', type=str, default='./results/ssngan/', help='results') 75 | parser.add_argument('--name', type=str, default='sketch', help='name of dataset') 76 | opt = parser.parse_args() 77 | 78 | # 数据集 79 | dirPath = os.path.join(opt.dataset_dir, opt.name) 80 | generatedImg, realImg = data_list(dirPath) 81 | dataset_size = len(generatedImg) 82 | print("数据集:", dataset_size) 83 | 84 | images1 = torch.Tensor(generatedImg) 85 | images2 = torch.Tensor(realImg) 86 | print('shape: ', images1.shape, images2.shape) 87 | 88 | # 将全部数据集导入 89 | # prepare the inception v3 model 90 | model = InceptionV3(include_top=False, pooling='avg') 91 | 92 | # pre-process images(归一化) 93 | images1 = preprocess_input(images1) 94 | images2 = preprocess_input(images2) 95 | 96 | # fid between images1 and images2 97 | fid = calculate_fid(model, images1, images2) 98 | print('FID : %.3f' % fid) 99 | print('FID_average : %.3f' % (fid / dataset_size)) 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /sketch_generation/models.py: 
-------------------------------------------------------------------------------- 1 | import functools 2 | from math import sqrt 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch import cat, sigmoid 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter, init 9 | from torch.nn.utils import spectral_norm 10 | import torch.nn.functional as F 11 | 12 | from torch.jit import ScriptModule, script_method, trace 13 | 14 | ##################################################################### 15 | ##### functions 16 | ##################################################################### 17 | 18 | def calc_mean_std(feat, eps=1e-5): 19 | # eps is a small value added to the variance to avoid divide-by-zero. 20 | size = feat.size() 21 | assert (len(size) == 4) 22 | N, C = size[:2] 23 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 24 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 25 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 26 | return feat_mean, feat_std 27 | 28 | 29 | # def adain(content_feat, style_feat): 30 | # assert (content_feat.size()[:2] == style_feat.size()[:2]) 31 | # size = content_feat.size() 32 | # style_mean, style_std = calc_mean_std(style_feat) 33 | # content_mean, content_std = calc_mean_std(content_feat) 34 | # 35 | # normalized_feat = (content_feat - content_mean.expand( 36 | # size)) / content_std.expand(size) 37 | # return normalized_feat * style_std.expand(size) + style_mean.expand(size) 38 | def AdaLIN(content_feat,style_feat): 39 | 40 | assert (content_feat.size()[:2]==style_feat.size()[:2]) 41 | 42 | rho=Parameter(torch.Tensor(4,256,32,32,)) #维度修改了,原来是 rho=Parameter(torch.Tensor(1,512,1,1,)) 43 | rho=rho.data.fill_(0.9) 44 | 45 | size=content_feat.size() 46 | style_mean,style_std=calc_mean_std(style_feat) 47 | content_mean,content_std=calc_mean_std(content_feat) 48 | out_style=(style_feat-style_mean.expand(size))/style_std.expand(size) 49 | out_content=(content_feat-content_mean.expand(size))/content_std.expand(size) 50 | out=rho.expand(size)*out_style+(1-rho.expand(size))*out_content 51 | return out 52 | 53 | def adain(content_feat, style_feat): 54 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 55 | size = content_feat.size() 56 | style_mean, style_std = calc_mean_std(style_feat) 57 | content_mean, content_std = calc_mean_std(content_feat) 58 | 59 | normalized_feat = (content_feat - content_mean.expand( 60 | size)) / content_std.expand(size) 61 | normalized_features = normalized_feat * style_std.expand(size) + style_mean.expand(size) 62 | return normalized_features #torch.Size([4, 256, 32, 32]) 63 | 64 | def get_batched_gram_matrix(input): 65 | # take a batch of features: B X C X H X W 66 | # return gram of each image: B x C x C 67 | a, b, c, d = input.size() 68 | features = input.view(a, b, c * d) 69 | G = torch.bmm(features, features.transpose(2,1)) 70 | return G.div(b * c * d) 71 | 72 | class Adaptive_pool(nn.Module): 73 | ''' 74 | take a input tensor of size: B x C' X C' 75 | output a maxpooled tensor of size: B x C x H x W 76 | ''' 77 | def __init__(self, channel_out, hw_out): 78 | super().__init__() 79 | self.channel_out = channel_out 80 | self.hw_out = hw_out 81 | self.pool = nn.AdaptiveAvgPool2d((channel_out, hw_out**2)) 82 | def forward(self, input): 83 | if len(input.shape) == 3: 84 | input.unsqueeze_(1) 85 | return self.pool(input).view(-1, self.channel_out, self.hw_out, self.hw_out) 86 | ### new function 87 | 88 | 
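# NOTE (added explanatory comment, not part of the original pipeline): the
# helpers above implement feature-statistics matching. calc_mean_std returns
# the per-channel mean/std of an (N, C, H, W) feature map; adain re-normalizes
# a content feature map so that its channel-wise statistics match those of a
# style feature map, e.g. (shapes illustrative only):
#     content = torch.randn(4, 256, 32, 32)
#     style = torch.randn(4, 256, 32, 32)
#     out = adain(content, style)   # same shape as content
# AdaLIN instead normalizes the style and content features by their own
# statistics and blends them with a fixed weight rho=0.9; its rho tensor is
# re-created on every call with a hard-coded shape (4, 256, 32, 32), so it is
# not a learned parameter here.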
##################################################################### 89 | ##### models 90 | ##################################################################### 91 | class VGGSimple(nn.Module): 92 | def __init__(self): 93 | super(VGGSimple, self).__init__() 94 | 95 | self.features = self.make_layers() 96 | 97 | self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1) 98 | self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1) 99 | 100 | def forward(self, img, after_relu=True, base=4): 101 | # re-normalize from [-1, 1] to [0, 1] then to the range used for vgg 102 | feat = (((img+1)*0.5) - self.norm_mean.to(img.device)) / self.norm_std.to(img.device) 103 | # the layer numbers used to extract features 104 | cut_points = [2, 7, 14, 21, 28] 105 | if after_relu: 106 | cut_points = [c+2 for c in cut_points] 107 | for i in range(31): 108 | feat = self.features[i](feat) 109 | if i == cut_points[0]: 110 | feat_64 = F.adaptive_avg_pool2d(feat, base*16) 111 | if i == cut_points[1]: 112 | feat_32 = F.adaptive_avg_pool2d(feat, base*8) 113 | if i == cut_points[2]: 114 | feat_16 = F.adaptive_avg_pool2d(feat, base*4) 115 | if i == cut_points[3]: 116 | feat_8 = F.adaptive_avg_pool2d(feat, base*2) 117 | if i == cut_points[4]: 118 | feat_4 = F.adaptive_avg_pool2d(feat, base) 119 | 120 | return feat_64, feat_32, feat_16, feat_8, feat_4 121 | 122 | def make_layers(self, cfg="D", batch_norm=False): 123 | cfg_dic = { 124 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 125 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 126 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 127 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 128 | } 129 | cfg = cfg_dic[cfg] 130 | layers = [] 131 | in_channels = 3 132 | for v in cfg: 133 | if v == 'M': 134 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 135 | else: 136 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 137 | if batch_norm: 138 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)] 139 | else: 140 | layers += [conv2d, nn.ReLU(inplace=False)] 141 | in_channels = v 142 | return nn.Sequential(*layers) 143 | 144 | 145 | # this model is used for pre-training 146 | class VGG_3label(nn.Module): 147 | def __init__(self, nclass_artist=1117, nclass_style=55, nclass_genre=26): 148 | super(VGG_3label, self).__init__() 149 | self.features = self.make_layers() 150 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 151 | 152 | self.classifier_feat = self.classifier = nn.Sequential( 153 | nn.Linear(512 * 7 * 7, 4096), 154 | nn.ReLU(), 155 | nn.Dropout(), 156 | nn.Linear(4096, 4096), 157 | nn.ReLU(), 158 | nn.Dropout(), 159 | nn.Linear(4096, 512)) 160 | 161 | self.classifier_style = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_style)) 162 | self.classifier_genre = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_genre)) 163 | self.classifier_artist = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_artist)) 164 | 165 | self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1) 166 | self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1) 167 | 168 | self.avgpool_4 = nn.AdaptiveAvgPool2d((4, 4)) 169 | self.avgpool_8 = nn.AdaptiveAvgPool2d((8, 8)) 170 | self.avgpool_16 = nn.AdaptiveAvgPool2d((16, 16)) 171 | 172 | def get_features(self, img, after_relu=True, base=4): 173 | feat = (((img+1)*0.5) - 
self.norm_mean.to(img.device)) / self.norm_std.to(img.device) 174 | cut_points = [2, 7, 14, 21, 28] 175 | if after_relu: 176 | cut_points = [4, 9, 16, 23, 30] 177 | for i in range(31): 178 | feat = self.features[i](feat) 179 | if i == cut_points[0]: 180 | feat_64 = F.adaptive_avg_pool2d(feat, base*16) 181 | if i == cut_points[1]: 182 | feat_32 = F.adaptive_avg_pool2d(feat, base*8) 183 | if i == cut_points[2]: 184 | feat_16 = F.adaptive_avg_pool2d(feat, base*4) 185 | if i == cut_points[3]: 186 | feat_8 = F.adaptive_avg_pool2d(feat, base*2) 187 | if i == cut_points[4]: 188 | feat_4 = F.adaptive_avg_pool2d(feat, base) 189 | #feat_code = self.classifier_feat(self.avgpool(feat).view(img.size(0), -1)) 190 | return feat_64, feat_32, feat_16, feat_8, feat_4#, feat_code 191 | 192 | 193 | def load_pretrain_weights(self): 194 | pretrained_vgg16 = vgg.vgg16(pretrained=True) 195 | self.features.load_state_dict(pretrained_vgg16.features.state_dict()) 196 | self.classifier_feat[0] = pretrained_vgg16.classifier[0] 197 | self.classifier_feat[3] = pretrained_vgg16.classifier[3] 198 | for m in self.modules(): 199 | if isinstance(m, nn.Linear): 200 | nn.init.normal_(m.weight, 0, 0.01) 201 | nn.init.constant_(m.bias, 0) 202 | 203 | def make_layers(self, cfg="D", batch_norm=False): 204 | cfg_dic = { 205 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 206 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 207 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 208 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 209 | } 210 | cfg = cfg_dic[cfg] 211 | layers = [] 212 | in_channels = 3 213 | for v in cfg: 214 | if v == 'M': 215 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 216 | else: 217 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 218 | if batch_norm: 219 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)] 220 | else: 221 | layers += [conv2d, nn.ReLU(inplace=False)] 222 | in_channels = v 223 | return nn.Sequential(*layers) 224 | 225 | def forward(self, img): 226 | feature = self.classifier_feat( self.avgpool(self.features(img)).view(img.size(0), -1) ) 227 | pred_style = self.classifier_style(feature) 228 | pred_genre = self.classifier_genre(feature) 229 | pred_artist = self.classifier_artist(feature) 230 | return pred_style, pred_genre, pred_artist 231 | 232 | 233 | class UnFlatten(nn.Module): 234 | def __init__(self, block_size): 235 | super(UnFlatten, self).__init__() 236 | self.block_size = block_size 237 | 238 | def forward(self, x): 239 | return x.view(x.size(0), -1, self.block_size, self.block_size) 240 | 241 | 242 | class Flatten(nn.Module): 243 | def __init__(self): 244 | super(Flatten, self).__init__() 245 | 246 | def forward(self, x): 247 | return x.view(x.size(0), -1) 248 | 249 | #batchNorm2d-->InstanceNorm2d 250 | class UpConvBlock(nn.Module): 251 | def __init__(self, in_channel, out_channel, norm_layer=nn.InstanceNorm2d): 252 | super().__init__() 253 | 254 | self.main = nn.Sequential( 255 | nn.ReflectionPad2d(1), 256 | spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 0, bias=True)), 257 | norm_layer(out_channel), 258 | nn.LeakyReLU(0.01), 259 | ) 260 | 261 | def forward(self, x): 262 | y = F.interpolate(x, scale_factor=2) 263 | return self.main(y) 264 | 265 | #batchNorm2d-->InstanceNorm2d 266 | class DownConvBlock(nn.Module): 267 | def __init__(self, in_channel, out_channel, 
norm_layer=nn.InstanceNorm2d, down=True): 268 | super().__init__() 269 | 270 | m = [ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=True)), 271 | norm_layer(out_channel), 272 | nn.LeakyReLU(0.1) ] 273 | if down: 274 | m.append(nn.AvgPool2d(2, 2)) 275 | self.main = nn.Sequential(*m) 276 | 277 | def forward(self, x): 278 | return self.main(x) 279 | 280 | 281 | class ResNetBlock(nn.Module): 282 | def __init__(self, dim): 283 | super(ResNetBlock, self).__init__() 284 | conv_block = [] 285 | conv_block += [nn.ReflectionPad2d(1), 286 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False), 287 | nn.InstanceNorm2d(dim), 288 | nn.ReLU(True)] 289 | 290 | conv_block += [nn.ReflectionPad2d(1), 291 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False), 292 | nn.InstanceNorm2d(dim)] 293 | 294 | self.conv_block = nn.Sequential(*conv_block) 295 | 296 | def forward(self, x): 297 | out = x + self.conv_block(x) 298 | return out 299 | 300 | class Generator(nn.Module): 301 | def __init__(self, infc=512, nfc=64, nc_out=3): 302 | super(Generator, self).__init__() 303 | 304 | self.decode_32 = UpConvBlock(infc, nfc*4) #32 305 | self.decode_64 = UpConvBlock(nfc*4, nfc*4) #64 306 | self.decode_128 = UpConvBlock(nfc*4, nfc*2) #128 307 | self.gap_fc=nn.Linear(512,1,bias=False) 308 | self.gmp_fc=nn.Linear(512,1,bias=False) 309 | self.gamma = nn.Linear(512, 256, bias=False) #(256,256) 310 | self.beta = nn.Linear(512, 256, bias=False) # 311 | self.conv1x1 = nn.Conv2d(512, 256, 1, 1, bias=True) 312 | self.relu = nn.ReLU(inplace=True) 313 | self.final = nn.Sequential( 314 | spectral_norm( nn.Conv2d(nfc*2, nc_out, 3, 1, 1, bias=True) ), 315 | nn.Tanh()) 316 | self.netG_A2B = Generator_UGATIT(image_size=256) 317 | def forward(self, input): 318 | 319 | decode_32 = self.decode_32(input) # input torch.Size([8, 256, 32, 32]) 320 | decode_64 = self.decode_64(decode_32) 321 | decode_128 = self.decode_128(decode_64) 322 | 323 | output = self.final(decode_128) #output torch.Size([8, 3, 256, 256]) 324 | output=self.netG_A2B(output)[0] #此处解码后,再经过Generator_UGATIT 的处理后再输出 325 | return output 326 | 327 | class Generator_UGATIT(nn.Module): 328 | def __init__(self, image_size=256): 329 | super(Generator_UGATIT, self).__init__() 330 | down_layer = [ 331 | nn.ReflectionPad2d(3), 332 | nn.Conv2d(3, 64, 7, 1, 0, bias=False), 333 | nn.InstanceNorm2d(64), 334 | nn.ReLU(inplace=True), 335 | 336 | # Down-Sampling 337 | nn.ReflectionPad2d(1), 338 | nn.Conv2d(64, 128, 3, 2, 0, bias=False), 339 | nn.InstanceNorm2d(256), 340 | nn.ReLU(inplace=True), 341 | nn.ReflectionPad2d(1), 342 | nn.Conv2d(128, 256, 3, 2, 0, bias=False), 343 | nn.InstanceNorm2d(256), 344 | nn.ReLU(inplace=True), 345 | 346 | # Down-Sampling Bottleneck 347 | ResNetBlock(256), 348 | ResNetBlock(256), 349 | ResNetBlock(256), 350 | ResNetBlock(256), 351 | ] 352 | 353 | # Class Activation Map 354 | self.gap_fc = nn.Linear(256, 1, bias=False) 355 | self.gmp_fc = nn.Linear(256, 1, bias=False) 356 | self.conv1x1 = nn.Conv2d(512, 256, 1, 1, bias=True) 357 | self.relu = nn.ReLU(inplace=True) 358 | 359 | # Gamma, Beta block 360 | fc = [ 361 | nn.Linear(image_size * image_size * 16, 256, bias=False), 362 | nn.ReLU(inplace=True), 363 | nn.Linear(256, 256, bias=False), 364 | nn.ReLU(inplace=True) 365 | ] 366 | 367 | self.gamma = nn.Linear(256, 256, bias=False) 368 | self.beta = nn.Linear(256, 256, bias=False) 369 | 370 | # Up-Sampling Bottleneck 371 | for i in range(4): 372 | setattr(self, "ResNetAdaILNBlock_" + str(i + 1), ResNetAdaILNBlock(256)) 373 | 374 | up_layer = [ 375 | 
nn.Upsample(scale_factor=2, mode="nearest"), 376 | nn.ReflectionPad2d(1), 377 | nn.Conv2d(256, 128, 3, 1, 0, bias=False), 378 | ILN(128), 379 | nn.ReLU(inplace=True), 380 | 381 | nn.Upsample(scale_factor=2, mode="nearest"), 382 | nn.ReflectionPad2d(1), 383 | nn.Conv2d(128, 64, 3, 1, 0, bias=False), 384 | ILN(64), 385 | nn.ReLU(inplace=True), 386 | 387 | nn.ReflectionPad2d(3), 388 | nn.Conv2d(64, 3, 7, 1, 0, bias=False), 389 | nn.Tanh() 390 | ] 391 | 392 | self.down_layer = nn.Sequential(*down_layer) 393 | self.fc = nn.Sequential(*fc) 394 | self.up_layer = nn.Sequential(*up_layer) 395 | 396 | def forward(self, inputs): 397 | x = self.down_layer(inputs) 398 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 399 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 400 | gap_weight = list(self.gap_fc.parameters())[0] 401 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 402 | 403 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 404 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 405 | gmp_weight = list(self.gmp_fc.parameters())[0] 406 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 407 | 408 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 409 | x = torch.cat([gap, gmp], 1) 410 | x = self.relu(self.conv1x1(x)) 411 | 412 | x_ = self.fc(x.view(x.shape[0], -1)) 413 | gamma, beta = self.gamma(x_), self.beta(x_) 414 | 415 | for i in range(4): 416 | x = getattr(self, "ResNetAdaILNBlock_" + str(i + 1))(x, gamma, beta) 417 | out = self.up_layer(x) 418 | 419 | return out, cam_logit 420 | 421 | 422 | 423 | 424 | class ResNetAdaILNBlock(nn.Module): 425 | def __init__(self, dim): 426 | super(ResNetAdaILNBlock, self).__init__() 427 | self.pad1 = nn.ReflectionPad2d(1) 428 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 0, bias=False) 429 | self.norm1 = AdaILN(dim) 430 | self.relu1 = nn.ReLU(True) 431 | 432 | self.pad2 = nn.ReflectionPad2d(1) 433 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 0, bias=False) 434 | self.norm2 = AdaILN(dim) 435 | 436 | def forward(self, x, gamma, beta): 437 | out = self.pad1(x) 438 | out = self.conv1(out) 439 | out = self.norm1(out, gamma, beta) 440 | out = self.relu1(out) 441 | out = self.pad2(out) 442 | out = self.conv2(out) 443 | out = self.norm2(out, gamma, beta) 444 | 445 | return out + x 446 | 447 | class ILN(nn.Module): 448 | def __init__(self, num_features, eps=1e-5): 449 | super(ILN, self).__init__() 450 | self.eps = eps 451 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 452 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1)) 453 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1)) 454 | self.rho.data.fill_(0.0) 455 | self.gamma.data.fill_(1.0) 456 | self.beta.data.fill_(0.0) 457 | 458 | def forward(self, x): 459 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) 460 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps) 461 | ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) 462 | out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps) 463 | out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln 464 | out = out * self.gamma.expand(x.shape[0], -1, -1, -1) + self.beta.expand(x.shape[0], -1, -1, -1) 465 | 466 | return out 467 | 468 | class AdaILN(nn.Module): 469 | def __init__(self, num_features, eps=1e-5): 470 | super(AdaILN, self).__init__() 471 | self.eps = eps 472 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 473 | self.rho.data.fill_(0.9) 474 | 475 | def 
forward(self, x, gamma, beta): 476 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) 477 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps) 478 | ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) 479 | out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps) 480 | out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln 481 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) 482 | 483 | return out 484 | 485 | class ResNetBlock(nn.Module): 486 | def __init__(self, dim): 487 | super(ResNetBlock, self).__init__() 488 | conv_block = [] 489 | conv_block += [nn.ReflectionPad2d(1), 490 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False), 491 | nn.InstanceNorm2d(dim), 492 | nn.ReLU(True)] 493 | 494 | conv_block += [nn.ReflectionPad2d(1), 495 | nn.Conv2d(dim, dim, 3, 1, 0, bias=False), 496 | nn.InstanceNorm2d(dim)] 497 | 498 | self.conv_block = nn.Sequential(*conv_block) 499 | 500 | def forward(self, x): 501 | out = x + self.conv_block(x) 502 | return out 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | class Discriminator(nn.Module): 511 | def __init__(self, nfc=512, norm_layer=nn.InstanceNorm2d): 512 | super(Discriminator, self).__init__() 513 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(1, 4, bias=False)) #这里维度修改了原来是64 * 8, 1 514 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(1, 4, bias=False)) 515 | self.conv1x1 = nn.Conv2d(2, 4, 3, 3, bias=True) 516 | self.leaky_relu = nn.LeakyReLU(0.2, True) 517 | 518 | self.pad = nn.ReflectionPad2d(1) 519 | self.conv = nn.utils.spectral_norm(nn.Conv2d(4, 4, 1, 1, 0, bias=False)) 520 | 521 | self.main = nn.Sequential( 522 | DownConvBlock(nfc, nfc // 2, norm_layer=norm_layer, down=False), 523 | DownConvBlock(nfc // 2, nfc // 4, norm_layer=norm_layer), # 4x4 524 | spectral_norm(nn.Conv2d(nfc // 4, 1, 4, 2, 0)) 525 | ) 526 | 527 | def forward(self, input): 528 | x = self.main(input) 529 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) #x torch.Size([4, 1, 3, 3]) 530 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 531 | gap_weight = list(self.gap_fc.parameters())[0] 532 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 533 | 534 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 535 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 536 | gmp_weight = list(self.gmp_fc.parameters())[0] 537 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 538 | 539 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 540 | x = torch.cat([gap, gmp], 1) 541 | x = self.leaky_relu(self.conv1x1(x)) 542 | # x = self.pad(x) 543 | out = self.conv(x) 544 | 545 | return out.view(-1) 546 | 547 | class Discriminator_UGATIT(nn.Module): 548 | def __init__(self, input_nc, ndf=64, n_layers=5): 549 | super(Discriminator_UGATIT, self).__init__() 550 | model = [nn.ReflectionPad2d(1), 551 | nn.utils.spectral_norm( 552 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)), 553 | nn.LeakyReLU(0.2, True)] 554 | 555 | for i in range(1, n_layers - 2): 556 | mult = 2 ** (i - 1) 557 | model += [nn.ReflectionPad2d(1), 558 | nn.utils.spectral_norm( 559 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)), 560 | nn.LeakyReLU(0.2, True)] 561 | 562 | mult = 2 ** (n_layers - 2 - 1) 563 | model += [nn.ReflectionPad2d(1), 564 | nn.utils.spectral_norm( 565 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)), 566 | nn.LeakyReLU(0.2, 
True)] 567 | 568 | # Class Activation Map 569 | mult = 2 ** (n_layers - 2) 570 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 571 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 572 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True) 573 | self.leaky_relu = nn.LeakyReLU(0.2, True) 574 | 575 | self.pad = nn.ReflectionPad2d(1) 576 | self.conv = nn.utils.spectral_norm( 577 | nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False)) 578 | 579 | self.model = nn.Sequential(*model) 580 | 581 | def forward(self, input): 582 | x = self.model(input) #input torch.Size([1, 3, 256, 256]) 583 | 584 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) #x torch.Size([1, 2048, 7, 7]) 585 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 586 | gap_weight = list(self.gap_fc.parameters())[0] 587 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 588 | 589 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 590 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 591 | gmp_weight = list(self.gmp_fc.parameters())[0] 592 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 593 | 594 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 595 | x = torch.cat([gap, gmp], 1) 596 | x = self.leaky_relu(self.conv1x1(x)) 597 | 598 | heatmap = torch.sum(x, dim=1, keepdim=True) 599 | 600 | x = self.pad(x) 601 | out = self.conv(x) #out.shape torch.Size([1, 1, 6, 6]) 602 | 603 | return out, cam_logit, heatmap 604 | 605 | -------------------------------------------------------------------------------- /sketch_generation/operation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | # from skimage import io, transform 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | #from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, utils 9 | import subprocess as sp 10 | from PIL import Image 11 | import time 12 | import torch.utils.data as data 13 | 14 | 15 | 16 | ### model math functions 17 | 18 | # from skimage.color import hsv2rgb 19 | import torch.nn.functional as F 20 | import torch.nn as nn 21 | 22 | eps = 1e-7 23 | class HSV_Loss(nn.Module): 24 | def __init__(self, h=0, s=1, v=0.7): 25 | super(HSV_Loss, self).__init__() 26 | self.hsv = [h, s, v] 27 | self.l1 = nn.L1Loss() 28 | self.mse = nn.MSELoss() 29 | 30 | @staticmethod 31 | def get_h(im): 32 | img = im * 0.5 + 0.5 33 | b, c, h, w = img.shape 34 | hue = torch.Tensor(im.shape[0], im.shape[2], im.shape[3]).to(im.device) 35 | hue[img[:,2]==img.max(1)[0]] = 4.0+((img[:,0]-img[:,1])/(img.max(1)[0] - img.min(1)[0]))[img[:,2]==img.max(1)[0]] 36 | hue[img[:,1]==img.max(1)[0]] = 2.0+((img[:,2]-img[:,0])/(img.max(1)[0] - img.min(1)[0]))[img[:,1]==img.max(1)[0]] 37 | hue[img[:,0]==img.max(1)[0]] = ((img[:,1]-img[:,2])/(img.max(1)[0] - img.min(1)[0]))[img[:,0]==img.max(1)[0]] 38 | hue = (hue/6.0) % 1.0 39 | hue[img.min(1)[0]==img.max(1)[0]] = 0.0 40 | return hue 41 | 42 | @staticmethod 43 | def get_v(im): 44 | img = im * 0.5 + 0.5 45 | b, c, h, w = img.shape 46 | it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c) 47 | value = F.max_pool1d(it, c).view(b, h, w) 48 | return value 49 | 50 | @staticmethod 51 | def get_s(im): 52 | img = im * 0.5 + 0.5 53 | b, c, h, w = img.shape 54 | it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c) 55 | max_v = F.max_pool1d(it, c).view(b, h, w) 56 | min_v = F.max_pool1d(it*-1, 
c).view(b, h, w) 57 | satur = (max_v + min_v) / (max_v+eps) 58 | return satur 59 | 60 | def forward(self, input): 61 | h = self.get_h(input) 62 | s = self.get_s(input) 63 | v = self.get_v(input) 64 | target_h = torch.Tensor(h.shape).fill_(self.hsv[0]).to(input.device).type_as(h) 65 | target_s = torch.Tensor(s.shape).fill_(self.hsv[1]).to(input.device) 66 | target_v = torch.Tensor(v.shape).fill_(self.hsv[2]).to(input.device) 67 | return self.mse(h, target_h) #+ 0.4*self.mse(v, target_v) 68 | 69 | 70 | 71 | ### data loading functions 72 | def InfiniteSampler(n): 73 | # i = 0 74 | i = n - 1 75 | order = np.random.permutation(n) 76 | while True: 77 | yield order[i] 78 | i += 1 79 | if i >= n: 80 | np.random.seed() 81 | order = np.random.permutation(n) 82 | i = 0 83 | 84 | class InfiniteSamplerWrapper(data.sampler.Sampler): 85 | def __init__(self, data_source): 86 | self.num_samples = len(data_source) 87 | 88 | def __iter__(self): 89 | return iter(InfiniteSampler(self.num_samples)) 90 | 91 | def __len__(self): 92 | return 2 ** 31 93 | 94 | 95 | def _rescale(img): 96 | return img * 2.0 - 1.0 97 | 98 | def trans_maker(size=256): 99 | trans = transforms.Compose([ 100 | transforms.Resize((size+36, size+36)), 101 | transforms.RandomHorizontalFlip(), 102 | transforms.RandomCrop((size, size)), 103 | transforms.ToTensor(), 104 | _rescale 105 | ]) 106 | return trans 107 | 108 | def trans_maker_testing(size=256): 109 | trans = transforms.Compose([ 110 | transforms.Resize((size, size)), 111 | transforms.ToTensor(), 112 | _rescale 113 | ]) 114 | return trans 115 | transform_gan = trans_maker(size=128) 116 | 117 | import torchvision.utils as vutils 118 | import logging 119 | logger = logging.getLogger(__name__) 120 | 121 | 122 | 123 | ### during training util functions 124 | def save_image(net, dataloader_A, device, cur_iter, trial, save_path): 125 | """Save imag output from net""" 126 | logger.info('Saving gan epoch {} images: {}'.format(cur_iter, save_path)) 127 | 128 | # Set net to evaluation mode 129 | net.eval() 130 | for p in net.parameters(): 131 | data_type = p.type() 132 | break 133 | with torch.no_grad(): 134 | for itx, data in enumerate(dataloader_A): 135 | g_img = net.gen_a2b(data[0].to(device).type(data_type)) 136 | for i in range(g_img.size(0)): 137 | vutils.save_image( 138 | g_img.cpu().float().add_(1).mul_(0.5), 139 | os.path.join(save_path, "{}_gan_epoch_{}_iter_{}_{}.jpg".format(trial, cur_iter, itx, i)),) 140 | # Set net to train mode 141 | net.train() 142 | return save_path 143 | 144 | def save_model(net, save_folder, cuda_device, if_multi_gpu, trial, cur_iter): 145 | """ Save current model and delete previous model, keep the saved model!""" 146 | save_name = "{}_gan_epoch_{}.pth".format(trial, cur_iter) 147 | save_path = os.path.join(save_folder, save_name) 148 | logger.info('Saving gan model: {}'.format(save_path)) 149 | 150 | net.save(save_path) 151 | 152 | for fname in os.listdir(save_folder): 153 | if fname.endswith('.pth') and fname != save_name: 154 | delete_path = os.path.join(save_folder, fname) 155 | os.remove(delete_path) 156 | logger.info('Deleted previous gan model: {}'.format(delete_path)) 157 | 158 | return save_path -------------------------------------------------------------------------------- /sketch_generation/requirement.txt: -------------------------------------------------------------------------------- 1 | python 3.7.10 2 | torch 1.12.1 3 | -------------------------------------------------------------------------------- /sketch_generation/train.py: 
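Before the training script below, here is a minimal, self-contained sketch of the Gram-matrix style loss that the script builds its style objective from. It is illustrative only: the tensor shapes and batch size are assumptions, and the authoritative `gram_matrix` / `gram_loss` definitions are the ones inside `train.py` that follows.

```
import torch
import torch.nn.functional as F

def gram_matrix(feat):
    # feat: (B, C, H, W) feature map taken from a VGG layer
    b, c, h, w = feat.size()
    flat = feat.view(b * c, h * w)          # flatten the spatial dimensions
    gram = torch.mm(flat, flat.t())         # channel-correlation (Gram) matrix
    return gram.div(b * c * h * w)          # normalise by the number of elements

def gram_loss(fake_feat, style_feat):
    # style distance = MSE between the Gram matrices of the two feature maps
    return F.mse_loss(gram_matrix(fake_feat), gram_matrix(style_feat.detach()))

# toy check with assumed shapes: batch 4, 256 channels, 32x32 spatial grid
fake_f3 = torch.randn(4, 256, 32, 32)
style_f3 = torch.randn(4, 256, 32, 32)
print(gram_loss(fake_f3, style_f3).item())
```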
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch,gc 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torchvision.datasets as Dataset 8 | import torchvision.utils as vutils 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | import matplotlib.pyplot as plt 12 | from models import Generator, Discriminator, VGGSimple, Adaptive_pool,AdaLIN, get_batched_gram_matrix,Generator_UGATIT,adain 13 | from operation import InfiniteSamplerWrapper, trans_maker 14 | import numpy 15 | import argparse 16 | import tqdm 17 | from metrics import calculate_fid_modify 18 | 19 | torch.backends.cudnn.benchmark = True 20 | 21 | 22 | 23 | def creat_folder(save_folder, trial_name): 24 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models'%trial_name) 25 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images'%trial_name) 26 | folders = [os.path.join(save_folder, 'train_results'), os.path.join(save_folder, 'train_results/%s'%trial_name), 27 | os.path.join(save_folder, 'train_results/%s/images'%trial_name), os.path.join(save_folder, 'train_results/%s/models'%trial_name)] 28 | 29 | for folder in folders: 30 | if not os.path.exists(folder): 31 | os.mkdir(folder) 32 | return saved_model_folder, saved_image_folder 33 | 34 | def train_d(net, data, label="real"): 35 | pred = net(data) 36 | if label=="real": 37 | err = F.relu(1-pred).mean() 38 | else: 39 | err = F.relu(1+pred).mean() 40 | 41 | err.backward() 42 | return torch.sigmoid(pred).mean().item() 43 | 44 | def gram_matrix(input): 45 | a, b, c, d = input.size() # a=batch size(=1) 46 | # b=number of feature maps 47 | # (c,d)=dimensions of a f. map (N=c*d) 48 | features = input.view(a * b, c * d) # resise F_XL into \hat F_XL 49 | G = torch.mm(features, features.t()) # compute the gram product 50 | # we 'normalize' the values of the gram matrix 51 | # by dividing by the number of element in each feature maps. 52 | return G.div(a * b * c * d) 53 | 54 | def gram_loss(input, target): 55 | in_gram = gram_matrix(input) 56 | tar_gram = gram_matrix(target.detach()) 57 | return F.mse_loss(in_gram, tar_gram) 58 | 59 | def save_image(net_g, dataloader, saved_image_folder, n_iter): 60 | net_g.eval() 61 | with torch.no_grad(): 62 | imgs = [] 63 | real = [] 64 | for i, d in enumerate(dataloader): 65 | if i < 2: 66 | # net_f=netG_A2B(d[0].to(device))[0] 67 | # f_3 = vgg(d[0].to(device), base=base)[2] 68 | f_3 = vgg(d[0].to(device), base=base)[2] 69 | imgs.append(net_g(f_3).cpu()) 70 | real.append(d[0]) 71 | gc.collect() 72 | torch.cuda.empty_cache() 73 | else: 74 | break 75 | imgs = torch.cat(imgs, dim=0) 76 | real = torch.cat(real, dim=0) 77 | sss = torch.cat([imgs, real], dim=0) 78 | # 计算fid指标 79 | fid = calculate_fid_modify(imgs[0,0,:,:].detach().numpy(), real[0,0,:,:].detach().numpy()) 80 | print('fid-------------', fid) 81 | vutils.save_image( sss, "%s/iter_%d.jpg"%(saved_image_folder, n_iter), range=(-1,1), normalize=True) 82 | del imgs 83 | net_g.train() 84 | 85 | def train(net_g, net_d_style, max_iteration): 86 | print('training begin ... 
') 87 | titles = ['D_r', 'D_f', 'G', 'G_rec'] 88 | losses = {title: 0.0 for title in titles} 89 | 90 | saved_model_folder, saved_image_folder = creat_folder(save_folder, trial_name) 91 | 92 | for n_iter in tqdm.tqdm(range(max_iteration+1)): 93 | if (n_iter+1)%(100)==0: 94 | try: 95 | model_dict = {'g': net_g.state_dict(), 'ds':net_d_style.state_dict()} 96 | torch.save(model_dict, os.path.join(saved_model_folder, '%d_model.pth'%(n_iter))) 97 | opt_dict = {'g': optG.state_dict(), 'ds':optDS.state_dict()} 98 | torch.save(opt_dict, os.path.join(saved_model_folder, '%d_opt.pth'%(n_iter))) 99 | except: 100 | print("models not properly saved") 101 | if n_iter%100==0: 102 | save_image(net_g, dataloader_A_fixed, saved_image_folder, n_iter) 103 | 104 | ## 1. prepare data 105 | real_style = next(dataloader_B)[0].to(device) 106 | real_content = next(dataloader_A)[0].to(device) 107 | 108 | cf_1, cf_2, cf_3, cf_4, cf_5 = vgg(real_content, base=base) 109 | sf_1, sf_2, sf_3, sf_4, sf_5 = vgg(real_style, base=base) 110 | 111 | fake_img = net_g(cf_3) 112 | tf_1, tf_2, tf_3, tf_4, tf_5 = vgg(fake_img, base=base) 113 | target_3 = adain(cf_3, sf_3) #torch.Size([4, 256, 32, 32]) 114 | # target_3 = AdaLIN(cf_3, sf_3) # switched to AdaLIN 115 | gram_sf_4 = gram_reshape(get_batched_gram_matrix(sf_4)) 116 | gram_sf_3 = gram_reshape(get_batched_gram_matrix(sf_3)) 117 | gram_sf_2 = gram_reshape(get_batched_gram_matrix(sf_2)) 118 | real_style_sample = torch.cat([gram_sf_2, gram_sf_3, gram_sf_4], dim=1) 119 | 120 | gram_tf_4 = gram_reshape(get_batched_gram_matrix(tf_4)) 121 | gram_tf_3 = gram_reshape(get_batched_gram_matrix(tf_3)) 122 | gram_tf_2 = gram_reshape(get_batched_gram_matrix(tf_2)) 123 | fake_style_sample = torch.cat([gram_tf_2, gram_tf_3, gram_tf_4], dim=1) 124 | 125 | ## 2. train Discriminator 126 | net_d_style.zero_grad() 127 | 128 | ### 2.1. train D_style on real data 129 | D_R = train_d(net_d_style, real_style_sample, label="real") 130 | ### 2.2. train D_style on fake data 131 | D_F = train_d(net_d_style, fake_style_sample.detach(), label="fake") 132 | 133 | optDS.step() 134 | 135 | ## 3. train Generator 136 | net_g.zero_grad() 137 | ### 3.1. 
train G as real image 138 | pred_gs = net_d_style(fake_style_sample) 139 | err_gs = -pred_gs.mean() 140 | G_B = torch.sigmoid(pred_gs).mean().item() #+ torch.sigmoid(pred_gc).mean().item() 141 | 142 | err_rec = F.mse_loss(tf_3, target_3) 143 | err_gram = 2000*( 144 | gram_loss(tf_4, sf_4) + \ 145 | gram_loss(tf_3, sf_3) + \ 146 | gram_loss(tf_2, sf_2)) 147 | 148 | G_rec = err_gram.item() 149 | 150 | 151 | 152 | err = err_gs + mse_weight*err_rec + gram_weight*err_gram 153 | err.backward() 154 | 155 | optG.step() 156 | 157 | ## logging ~ 158 | loss_values = [D_R, D_F, G_B, G_rec] 159 | for i, term in enumerate(titles): 160 | losses[term] += loss_values[i] 161 | 162 | if n_iter > 0 and n_iter % log_interval == 0: 163 | log_line = "" 164 | for key, value in losses.items(): 165 | log_line += "%s: %.5f "%(key, value/log_interval) 166 | losses[key] = 0 167 | print(log_line) 168 | 169 | 170 | 171 | if __name__ == '__main__': 172 | 173 | parser = argparse.ArgumentParser(description='Style transfer GAN, during training, the model will learn to take a image from one specific catagory and transform it into another style domain') 174 | print(os.path.join(os.getcwd(),"art-landscape-rgb-512")) 175 | patha="RGB/" 176 | pathb="Sketch/" 177 | parser.add_argument('--path_a', type=str, default=patha, help='path of resource dataset, should be a folder that has one or many sub image folders inside') 178 | parser.add_argument('--path_b', type=str, default=pathb, help='path of target dataset, should be a folder that has one or many sub image folders inside') 179 | parser.add_argument('--im_size', type=int, default=256, help='resolution of the generated images') 180 | parser.add_argument('--trial_name', type=str, default="test2", help='a brief description of the training trial') 181 | parser.add_argument('--gpu_id', type=int, default=0, help='0 is the first gpu, 1 is the second gpu, etc.') 182 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate, default is 2e-4, usually dont need to change it, you can try make it smaller, such as 1e-4') 183 | parser.add_argument('--batch_size', type=int, default=4, help='how many images to train together at one iteration') 184 | parser.add_argument('--total_iter', type=int, default=7000, help='how many iterations to train in total, the value is in assumption that init step is 1') 185 | parser.add_argument('--mse_weight', default=0.2, type=float, help='let G generate images with content more like in set A') 186 | parser.add_argument('--gram_weight', default=1, type=float, help='let G generate images with style more like in set B') 187 | parser.add_argument('--checkpoint', default='None', type=str, help='specify the path of the pre-trained model') 188 | 189 | args = parser.parse_args() 190 | 191 | print(str(args)) 192 | 193 | trial_name = args.trial_name 194 | data_root_A = args.path_a 195 | data_root_B = args.path_b 196 | mse_weight = args.mse_weight 197 | gram_weight = args.gram_weight 198 | max_iteration = args.total_iter 199 | # device = torch.device("cuda:%d"%(args.gpu_id)) 200 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 201 | print(torch.cuda.is_available()) 202 | # print(torch.cuda.current_device()) 203 | im_size = args.im_size 204 | if im_size == 128: 205 | base = 4 206 | elif im_size == 256: 207 | base = 8 208 | elif im_size == 512: 209 | base = 16 210 | if im_size not in [128, 256, 512]: 211 | print("the size must be in [128, 256, 512]") 212 | 213 | 214 | log_interval = 100 215 | save_folder = './model' 216 | 
number_model_to_save = 30 217 | 218 | vgg = VGGSimple() 219 | root_path=os.getcwd() 220 | vgg.load_state_dict(torch.load(os.path.join(root_path,'vgg-feature-weights.pth'), map_location=lambda a,b:a)) 221 | vgg.to(device) 222 | vgg.eval() 223 | 224 | 225 | 226 | for p in vgg.parameters(): 227 | p.requires_grad = False 228 | 229 | dataset_A = Dataset.ImageFolder(root=data_root_A, transform=trans_maker(args.im_size)) 230 | dataloader_A_fixed = DataLoader(dataset_A, 8, shuffle=False, num_workers=0) 231 | dataloader_A = iter(DataLoader(dataset_A, args.batch_size, shuffle=False,\ 232 | sampler=InfiniteSamplerWrapper(dataset_A), num_workers=4, pin_memory=False)) 233 | 234 | dataset_B = Dataset.ImageFolder(root=data_root_B, transform=trans_maker(args.im_size)) 235 | dataloader_B = iter(DataLoader(dataset_B, args.batch_size, shuffle=False,\ 236 | sampler=InfiniteSamplerWrapper(dataset_B), num_workers=0, pin_memory=False)) 237 | 238 | net_g = Generator(infc=256, nfc=128) 239 | netG_A2B = Generator_UGATIT(image_size=256).to(device) 240 | net_d_style = Discriminator(nfc=128*3, norm_layer=nn.BatchNorm2d) 241 | gram_reshape = Adaptive_pool(128, 16) 242 | # this style discriminator take input: 512x512 gram matrix from 512x8x8 vgg feature, 243 | # the reshaped pooled input size is: 256x16x16 244 | 245 | if args.checkpoint != 'None': 246 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage) 247 | net_g.load_state_dict(checkpoint['g']) 248 | net_d_style.load_state_dict(checkpoint['ds']) 249 | print("saved model loaded") 250 | 251 | net_d_style.to(device) 252 | net_g.to(device) 253 | 254 | optG = optim.Adam(net_g.parameters(), lr=args.lr, betas=(0.5, 0.99)) 255 | optDS = optim.Adam(net_d_style.parameters(), lr=args.lr, betas=(0.5, 0.99)) 256 | 257 | if args.checkpoint != 'None': 258 | opt_path = args.checkpoint.replace("_model.pth", "_opt.pth") 259 | try: 260 | opt_weights = torch.load(opt_path, map_location=lambda a, b: a) 261 | optG.load_state_dict(opt_weights['g']) 262 | optDS.load_state_dict(opt_weights['ds']) 263 | print("saved optimizer loaded") 264 | except: 265 | print("no optimizer weights detected, resuming a training without optimizer weights may not let the model converge as desired") 266 | pass 267 | 268 | 269 | train(net_g, net_d_style, max_iteration) 270 | 271 | -------------------------------------------------------------------------------- /sketch_generation/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from copy import deepcopy 4 | from random import shuffle 5 | import torch.nn.functional as F 6 | 7 | def d_hinge_loss(real_pred, fake_pred): 8 | real_loss = F.relu(1-real_pred) 9 | fake_loss = F.relu(1+fake_pred) 10 | 11 | return real_loss.mean() + fake_loss.mean() 12 | 13 | 14 | def g_hinge_loss(pred): 15 | return -pred.mean() 16 | 17 | 18 | class AverageMeter(object): 19 | 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def true_randperm(size, device='cuda'): 37 | def unmatched_randperm(size): 38 | l1 = [i for i in range(size)] 39 | l2 = [] 40 | for j in range(size): 41 | deleted = False 42 | if j in l1: 43 | deleted = True 44 | del l1[l1.index(j)] 45 | shuffle(l1) 46 | if len(l1) == 0: 47 | return 0, False 48 | l2.append(l1[0]) 49 
| del l1[0] 50 | if deleted: 51 | l1.append(j) 52 | return l2, True 53 | flag = False 54 | l = torch.zeros(size).long() 55 | while not flag: 56 | l, flag = unmatched_randperm(size) 57 | return torch.LongTensor(l).to(device) 58 | 59 | 60 | def copy_G_params(model): 61 | flatten = deepcopy(list(p.data for p in model.parameters())) 62 | return flatten 63 | 64 | 65 | def load_params(model, new_param): 66 | for p, new_p in zip(model.parameters(), new_param): 67 | p.data.copy_(new_p) 68 | 69 | 70 | def make_folders(save_folder, trial_name): 71 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models'%trial_name) 72 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images'%trial_name) 73 | folders = [os.path.join(save_folder, 'train_results'), 74 | os.path.join(save_folder, 'train_results/%s'%trial_name), 75 | os.path.join(save_folder, 'train_results/%s/images'%trial_name), 76 | os.path.join(save_folder, 'train_results/%s/models'%trial_name)] 77 | for folder in folders: 78 | if not os.path.exists(folder): 79 | os.mkdir(folder) 80 | 81 | from shutil import copy 82 | try: 83 | for f in os.listdir('.'): 84 | if '.py' in f: 85 | copy(f, os.path.join(save_folder, 'train_results/%s'%trial_name)+'/'+f) 86 | except: 87 | pass 88 | return saved_image_folder, saved_model_folder 89 | 90 | 91 | 92 | import cv2 93 | import numpy as np 94 | import math 95 | 96 | ##################### 97 | # Both horizontal and vertical 98 | def warp(img, mag=10, freq=100): 99 | rows, cols = img.shape 100 | 101 | img_output = np.zeros(img.shape, dtype=img.dtype) 102 | 103 | for i in range(rows): 104 | for j in range(cols): 105 | offset_x = int(mag * math.sin(2 * 3.14 * i / freq)) 106 | offset_y = int(mag * math.cos(2 * 3.14 * j / freq)) 107 | if i+offset_y < rows and j+offset_x < cols: 108 | img_output[i,j] = img[(i+offset_y)%rows,(j+offset_x)%cols] 109 | else: 110 | img_output[i,j] = 0 111 | 112 | return img_output 113 | 114 | #img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE) 115 | #img_output = warp(img, mag=10, freq=200) 116 | #cv2.imwrite('Multidirectional_wave.jpg', img_output) 117 | -------------------------------------------------------------------------------- /sketch_generation/vgg-feature-weights.z01: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.z01 -------------------------------------------------------------------------------- /sketch_generation/vgg-feature-weights.z02: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.z02 -------------------------------------------------------------------------------- /sketch_generation/vgg-feature-weights.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/sketch_generation/vgg-feature-weights.zip -------------------------------------------------------------------------------- /styleme/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from torchvision.models import inception_v3, Inception3 6 | from torchvision.utils import save_image 7 
| from torchvision import utils as vutils 8 | from torch.utils.data import DataLoader 9 | 10 | try: 11 | from torchvision.models.utils import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | import numpy as np 16 | from scipy import linalg 17 | from tqdm import tqdm 18 | import pickle 19 | import os 20 | from utils import true_randperm 21 | from datasets import InfiniteSamplerWrapper 22 | 23 | # Inception weights ported to Pytorch from 24 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 25 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 26 | 27 | 28 | class InceptionV3(nn.Module): 29 | """Pretrained InceptionV3 network returning feature maps""" 30 | 31 | # Index of default block of inception to return, 32 | # corresponds to output of final average pooling 33 | DEFAULT_BLOCK_INDEX = 3 34 | 35 | # Maps feature dimensionality to their output blocks indices 36 | BLOCK_INDEX_BY_DIM = { 37 | 64: 0, # First max pooling features 38 | 192: 1, # Second max pooling featurs 39 | 768: 2, # Pre-aux classifier features 40 | 2048: 3 # Final average pooling features 41 | } 42 | 43 | def __init__(self, 44 | output_blocks=[DEFAULT_BLOCK_INDEX], 45 | resize_input=True, 46 | normalize_input=True, 47 | requires_grad=False, 48 | use_fid_inception=True): 49 | """Build pretrained InceptionV3 50 | Parameters 51 | ---------- 52 | output_blocks : list of int 53 | Indices of blocks to return features of. Possible values are: 54 | - 0: corresponds to output of first max pooling 55 | - 1: corresponds to output of second max pooling 56 | - 2: corresponds to output which is fed to aux classifier 57 | - 3: corresponds to output of final average pooling 58 | resize_input : bool 59 | If true, bilinearly resizes input to width and height 299 before 60 | feeding input to model. As the network without fully connected 61 | layers is fully convolutional, it should be able to handle inputs 62 | of arbitrary size, so resizing might not be strictly needed 63 | normalize_input : bool 64 | If true, scales the input from range (0, 1) to the range the 65 | pretrained Inception network expects, namely (-1, 1) 66 | requires_grad : bool 67 | If true, parameters of the model require gradients. Possibly useful 68 | for finetuning the network 69 | use_fid_inception : bool 70 | If true, uses the pretrained Inception model used in Tensorflow's 71 | FID implementation. If false, uses the pretrained Inception model 72 | available in torchvision. The FID Inception model has different 73 | weights and a slightly different structure from torchvision's 74 | Inception model. If you want to compute FID scores, you are 75 | strongly advised to set this parameter to true to get comparable 76 | results. 
77 | """ 78 | super(InceptionV3, self).__init__() 79 | 80 | self.resize_input = resize_input 81 | self.normalize_input = normalize_input 82 | self.output_blocks = sorted(output_blocks) 83 | self.last_needed_block = max(output_blocks) 84 | 85 | assert self.last_needed_block <= 3, \ 86 | 'Last possible output block index is 3' 87 | 88 | self.blocks = nn.ModuleList() 89 | 90 | if use_fid_inception: 91 | inception = fid_inception_v3() 92 | else: 93 | inception = models.inception_v3(pretrained=True) 94 | 95 | # Block 0: input to maxpool1 96 | block0 = [ 97 | inception.Conv2d_1a_3x3, 98 | inception.Conv2d_2a_3x3, 99 | inception.Conv2d_2b_3x3, 100 | nn.MaxPool2d(kernel_size=3, stride=2) 101 | ] 102 | self.blocks.append(nn.Sequential(*block0)) 103 | 104 | # Block 1: maxpool1 to maxpool2 105 | if self.last_needed_block >= 1: 106 | block1 = [ 107 | inception.Conv2d_3b_1x1, 108 | inception.Conv2d_4a_3x3, 109 | nn.MaxPool2d(kernel_size=3, stride=2) 110 | ] 111 | self.blocks.append(nn.Sequential(*block1)) 112 | 113 | # Block 2: maxpool2 to aux classifier 114 | if self.last_needed_block >= 2: 115 | block2 = [ 116 | inception.Mixed_5b, 117 | inception.Mixed_5c, 118 | inception.Mixed_5d, 119 | inception.Mixed_6a, 120 | inception.Mixed_6b, 121 | inception.Mixed_6c, 122 | inception.Mixed_6d, 123 | inception.Mixed_6e, 124 | ] 125 | self.blocks.append(nn.Sequential(*block2)) 126 | 127 | # Block 3: aux classifier to final avgpool 128 | if self.last_needed_block >= 3: 129 | block3 = [ 130 | inception.Mixed_7a, 131 | inception.Mixed_7b, 132 | inception.Mixed_7c, 133 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 134 | ] 135 | self.blocks.append(nn.Sequential(*block3)) 136 | 137 | for param in self.parameters(): 138 | param.requires_grad = requires_grad 139 | 140 | def forward(self, inp): 141 | """Get Inception feature maps 142 | Parameters 143 | ---------- 144 | inp : torch.autograd.Variable 145 | Input tensor of shape Bx3xHxW. Values are expected to be in 146 | range (0, 1) 147 | Returns 148 | ------- 149 | List of torch.autograd.Variable, corresponding to the selected output 150 | block, sorted ascending by index 151 | """ 152 | outp = [] 153 | x = inp 154 | 155 | if self.resize_input: 156 | x = F.interpolate(x, 157 | size=(299, 299), 158 | mode='bilinear', 159 | align_corners=False) 160 | 161 | if self.normalize_input: 162 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 163 | 164 | for idx, block in enumerate(self.blocks): 165 | x = block(x) 166 | if idx in self.output_blocks: 167 | outp.append(x) 168 | 169 | if idx == self.last_needed_block: 170 | break 171 | 172 | return outp 173 | 174 | 175 | def fid_inception_v3(): 176 | """Build pretrained Inception model for FID computation 177 | The Inception model for FID computation uses a different set of weights 178 | and has a slightly different structure than torchvision's Inception. 179 | This method first constructs torchvision's Inception and then patches the 180 | necessary parts that are different in the FID Inception model. 
181 | """ 182 | inception = models.inception_v3(num_classes=1008, 183 | aux_logits=False, 184 | pretrained=False) 185 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 186 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 187 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 188 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 189 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 190 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 191 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 192 | inception.Mixed_7b = FIDInceptionE_1(1280) 193 | inception.Mixed_7c = FIDInceptionE_2(2048) 194 | 195 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 196 | inception.load_state_dict(state_dict) 197 | return inception 198 | 199 | 200 | class FIDInceptionA(models.inception.InceptionA): 201 | """InceptionA block patched for FID computation""" 202 | 203 | def __init__(self, in_channels, pool_features): 204 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 205 | 206 | def forward(self, x): 207 | branch1x1 = self.branch1x1(x) 208 | 209 | branch5x5 = self.branch5x5_1(x) 210 | branch5x5 = self.branch5x5_2(branch5x5) 211 | 212 | branch3x3dbl = self.branch3x3dbl_1(x) 213 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 214 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 215 | 216 | # Patch: Tensorflow's average pool does not use the padded zero's in 217 | # its average calculation 218 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 219 | count_include_pad=False) 220 | branch_pool = self.branch_pool(branch_pool) 221 | 222 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 223 | return torch.cat(outputs, 1) 224 | 225 | 226 | class FIDInceptionC(models.inception.InceptionC): 227 | """InceptionC block patched for FID computation""" 228 | 229 | def __init__(self, in_channels, channels_7x7): 230 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 231 | 232 | def forward(self, x): 233 | branch1x1 = self.branch1x1(x) 234 | 235 | branch7x7 = self.branch7x7_1(x) 236 | branch7x7 = self.branch7x7_2(branch7x7) 237 | branch7x7 = self.branch7x7_3(branch7x7) 238 | 239 | branch7x7dbl = self.branch7x7dbl_1(x) 240 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 241 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 242 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 243 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 244 | 245 | # Patch: Tensorflow's average pool does not use the padded zero's in 246 | # its average calculation 247 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 248 | count_include_pad=False) 249 | branch_pool = self.branch_pool(branch_pool) 250 | 251 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 252 | return torch.cat(outputs, 1) 253 | 254 | 255 | class FIDInceptionE_1(models.inception.InceptionE): 256 | """First InceptionE block patched for FID computation""" 257 | 258 | def __init__(self, in_channels): 259 | super(FIDInceptionE_1, self).__init__(in_channels) 260 | 261 | def forward(self, x): 262 | branch1x1 = self.branch1x1(x) 263 | 264 | branch3x3 = self.branch3x3_1(x) 265 | branch3x3 = [ 266 | self.branch3x3_2a(branch3x3), 267 | self.branch3x3_2b(branch3x3), 268 | ] 269 | branch3x3 = torch.cat(branch3x3, 1) 270 | 271 | branch3x3dbl = self.branch3x3dbl_1(x) 272 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 273 | branch3x3dbl = [ 274 | self.branch3x3dbl_3a(branch3x3dbl), 275 | 
self.branch3x3dbl_3b(branch3x3dbl), 276 | ] 277 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 278 | 279 | # Patch: Tensorflow's average pool does not use the padded zero's in 280 | # its average calculation 281 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 282 | count_include_pad=False) 283 | branch_pool = self.branch_pool(branch_pool) 284 | 285 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 286 | return torch.cat(outputs, 1) 287 | 288 | 289 | class FIDInceptionE_2(models.inception.InceptionE): 290 | """Second InceptionE block patched for FID computation""" 291 | 292 | def __init__(self, in_channels): 293 | super(FIDInceptionE_2, self).__init__(in_channels) 294 | 295 | def forward(self, x): 296 | branch1x1 = self.branch1x1(x) 297 | 298 | branch3x3 = self.branch3x3_1(x) 299 | branch3x3 = [ 300 | self.branch3x3_2a(branch3x3), 301 | self.branch3x3_2b(branch3x3), 302 | ] 303 | branch3x3 = torch.cat(branch3x3, 1) 304 | 305 | branch3x3dbl = self.branch3x3dbl_1(x) 306 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 307 | branch3x3dbl = [ 308 | self.branch3x3dbl_3a(branch3x3dbl), 309 | self.branch3x3dbl_3b(branch3x3dbl), 310 | ] 311 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 312 | 313 | # Patch: The FID Inception model uses max pooling instead of average 314 | # pooling. This is likely an error in this specific Inception 315 | # implementation, as other Inception models use average pooling here 316 | # (which matches the description in the paper). 317 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 318 | branch_pool = self.branch_pool(branch_pool) 319 | 320 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 321 | return torch.cat(outputs, 1) 322 | 323 | 324 | class Inception3Feature(Inception3): 325 | def forward(self, x): 326 | if x.shape[2] != 299 or x.shape[3] != 299: 327 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 328 | 329 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 330 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 331 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 332 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 333 | 334 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 335 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 336 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 337 | 338 | x = self.Mixed_5b(x) # 35 x 35 x 192 339 | x = self.Mixed_5c(x) # 35 x 35 x 256 340 | x = self.Mixed_5d(x) # 35 x 35 x 288 341 | 342 | x = self.Mixed_6a(x) # 35 x 35 x 288 343 | x = self.Mixed_6b(x) # 17 x 17 x 768 344 | x = self.Mixed_6c(x) # 17 x 17 x 768 345 | x = self.Mixed_6d(x) # 17 x 17 x 768 346 | x = self.Mixed_6e(x) # 17 x 17 x 768 347 | 348 | x = self.Mixed_7a(x) # 17 x 17 x 768 349 | x = self.Mixed_7b(x) # 8 x 8 x 1280 350 | x = self.Mixed_7c(x) # 8 x 8 x 2048 351 | 352 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 353 | 354 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 355 | 356 | 357 | def load_patched_inception_v3(): 358 | # inception = inception_v3(pretrained=True) 359 | # inception_feat = Inception3Feature() 360 | # inception_feat.load_state_dict(inception.state_dict()) 361 | inception_feat = InceptionV3([3], normalize_input=False) 362 | 363 | return inception_feat 364 | 365 | 366 | @torch.no_grad() 367 | def extract_features(loader, inception, device): 368 | pbar = tqdm(loader) 369 | 370 | feature_list = [] 371 | 372 | for img in pbar: 373 | img = img.to(device) 374 | feature = inception(img)[0].view(img.shape[0], -1) 375 | 
feature_list.append(feature.to('cpu')) 376 | 377 | features = torch.cat(feature_list, 0) 378 | 379 | return features 380 | 381 | 382 | @torch.no_grad() 383 | def extract_feature_from_generator_fn(generator_fn, inception, device='cuda', total=1000): 384 | features = [] 385 | 386 | for batch in tqdm(generator_fn, total=total): 387 | try: 388 | feat = inception(batch)[0].view(batch.shape[0], -1) 389 | features.append(feat.to('cpu')) 390 | except: 391 | break 392 | features = torch.cat(features, 0).detach() 393 | return features.numpy() 394 | 395 | 396 | def calc_fid(sample_features, real_features=None, real_mean=None, real_cov=None, eps=1e-6): 397 | sample_mean = np.mean(sample_features, 0) 398 | sample_cov = np.cov(sample_features, rowvar=False) 399 | 400 | if real_features is not None: 401 | real_mean = np.mean(real_features, 0) 402 | real_cov = np.cov(real_features, rowvar=False) 403 | 404 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 405 | 406 | if not np.isfinite(cov_sqrt).all(): 407 | print('product of cov matrices is singular') 408 | offset = np.eye(sample_cov.shape[0]) * eps 409 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 410 | 411 | if np.iscomplexobj(cov_sqrt): 412 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 413 | m = np.max(np.abs(cov_sqrt.imag)) 414 | 415 | raise ValueError(f'Imaginary component {m}') 416 | 417 | cov_sqrt = cov_sqrt.real 418 | 419 | mean_diff = sample_mean - real_mean 420 | mean_norm = mean_diff @ mean_diff 421 | 422 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 423 | 424 | fid = mean_norm + trace 425 | 426 | return fid 427 | 428 | 429 | def real_image_loader(dataloader, n_batches=10): 430 | counter = 0 431 | while counter < n_batches: 432 | counter += 1 433 | rgb_img = next(dataloader)[0] 434 | if counter == 1: 435 | vutils.save_image(0.5 * (rgb_img + 1), './checkpoint/tmp_real.jpg') 436 | yield rgb_img.cuda() 437 | 438 | 439 | @torch.no_grad() 440 | def image_generator(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500): 441 | counter = 0 442 | dataloader = iter( 443 | DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=4, pin_memory=False)) 444 | n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1) 445 | while counter < n_batches: 446 | counter += 1 447 | rgb_img, skt_img = next(dataloader) 448 | rgb_img = F.interpolate(rgb_img, size=256).cuda() 449 | skt_img = F.interpolate(skt_img, size=256).cuda() 450 | 451 | gimg_ae, style_feat = net_ae(skt_img, rgb_img) 452 | # g_image = gimg_ae 453 | g_image = net_ig(gimg_ae, style_feat) 454 | if counter == 1: 455 | vutils.save_image(0.5 * (g_image + 1), './checkpoint/tmp.jpg') 456 | yield g_image 457 | 458 | 459 | @torch.no_grad() 460 | def image_generator_perm(dataset, net_ae, net_ig, BATCH_SIZE=8, n_batches=500): 461 | counter = 0 462 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False)) 463 | n_batches = min(n_batches, len(dataset) // BATCH_SIZE - 1) 464 | while counter < n_batches: 465 | counter += 1 466 | rgb_img, skt_img = next(dataloader) 467 | rgb_img = F.interpolate(rgb_img, size=256).cuda() 468 | skt_img = F.interpolate(skt_img, size=256).cuda() 469 | 470 | perm = true_randperm(rgb_img.shape[0], device=rgb_img.device) 471 | 472 | gimg_ae, style_feat = net_ae(skt_img, rgb_img[perm]) 473 | # g_image = gimg_ae 474 | g_image = net_ig(gimg_ae, style_feat) 475 | if counter == 1: 476 | vutils.save_image(0.5 * (g_image + 1), 
'./checkpoint/tmp.jpg') 477 | yield g_image 478 | 479 | 480 | if __name__ == "__main__": 481 | from utils import PairedMultiDataset, InfiniteSamplerWrapper, make_folders, AverageMeter 482 | from torch.utils.data import DataLoader 483 | from torchvision import utils as vutils 484 | 485 | IM_SIZE = 1024 486 | BATCH_SIZE = 8 487 | DATALOADER_WORKERS = 8 488 | NBR_CLS = 2000 489 | TRIAL_NAME = 'trial_vae_512_1' 490 | SAVE_FOLDER = './' 491 | 492 | data_root_colorful = '../images/celebA/CelebA_512_test/img' 493 | data_root_sketch_1 = './sketch_simplification/vggadin_iter_700_test' 494 | data_root_sketch_2 = './sketch_simplification/vggadin_iter_1900_test' 495 | data_root_sketch_3 = './sketch_simplification/vggadin_iter_2300_test' 496 | 497 | dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3, 498 | im_size=IM_SIZE, rand_crop=False) 499 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, shuffle=False, num_workers=DATALOADER_WORKERS, pin_memory=True)) 500 | 501 | from models import StyleEncoder, ContentEncoder, Decoder 502 | import pickle 503 | from models import AE, RefineGenerator 504 | from utils import load_params 505 | 506 | net_ig = RefineGenerator().cuda() 507 | net_ig = nn.DataParallel(net_ig) 508 | 509 | ckpt = './train_results/trial_refine_ae_as_gan_1024_2/models/4.pth' 510 | if ckpt is not None: 511 | ckpt = torch.load(ckpt) 512 | # net_ig.load_state_dict(ckpt['ig']) 513 | # net_id.load_state_dict(ckpt['id']) 514 | net_ig_ema = ckpt['ig_ema'] 515 | load_params(net_ig, net_ig_ema) 516 | net_ig = net_ig.module 517 | # net_ig.eval() 518 | 519 | net_ae = AE() 520 | net_ae.load_state_dicts('./train_results/trial_vae_512_1/models/176000.pth') 521 | net_ae.cuda() 522 | net_ae.eval() 523 | 524 | inception = load_patched_inception_v3().cuda() 525 | inception.eval() 526 | 527 | ''' 528 | real_features = extract_feature_from_generator_fn( 529 | real_image_loader(dataloader, n_batches=1000), inception ) 530 | real_mean = np.mean(real_features, 0) 531 | real_cov = np.cov(real_features, rowvar=False) 532 | ''' 533 | # pickle.dump({'feats': real_features, 'mean': real_mean, 'cov': real_cov}, open('celeba_fid_feats.npy','wb') ) 534 | 535 | real_features = pickle.load(open('celeba_fid_feats.npy', 'rb')) 536 | real_mean = real_features['mean'] 537 | real_cov = real_features['cov'] 538 | # sample_features = extract_feature_from_generator_fn( real_image_loader(dataloader, n_batches=100), inception ) 539 | for it in range(1): 540 | itx = it * 8000 541 | ''' 542 | ckpt = torch.load('./train_results/%s/models/%d.pth'%(TRIAL_NAME, itx)) 543 | 544 | style_encoder.load_state_dict(ckpt['e']) 545 | content_encoder.load_state_dict(ckpt['c']) 546 | decoder.load_state_dict(ckpt['d']) 547 | 548 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True)) 549 | ''' 550 | 551 | sample_features = extract_feature_from_generator_fn( 552 | image_generator(dataset, net_ae, net_ig, n_batches=1800), inception, 553 | total=1800) 554 | 555 | # fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov']) 556 | fid = calc_fid(sample_features, real_mean=real_mean, real_cov=real_cov) 557 | 558 | print(it, fid) 559 | -------------------------------------------------------------------------------- /styleme/calculate.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # calculate FID 
and LPIPS # 3 | ###################################### 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | from tqdm import tqdm 10 | from datasets import PairedDataset, InfiniteSamplerWrapper 11 | from utils import AverageMeter 12 | 13 | 14 | def calculate_Lpips(data_root_colorful, data_root_sketch, model): 15 | import lpips 16 | from models import AE 17 | from models import RefineGenerator as Generator 18 | 19 | CHANNEL = 32 20 | NBR_CLS = 50 21 | IM_SIZE = 256 22 | BATCH_SIZE = 6 23 | DATALOADER_WORKERS = 2 24 | 25 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) 26 | 27 | # load dataset 28 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE) 29 | print('the dataset contains %d images.' % len(dataset)) 30 | 31 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), 32 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 33 | 34 | # load ae model 35 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS) 36 | net_ae.style_encoder.reset_cls() 37 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE) 38 | 39 | PRETRAINED_PATH = './checkpoint/GAN.pth'.format(str(model)) 40 | print('Pre-trained path : ', PRETRAINED_PATH) 41 | ckpt = torch.load(PRETRAINED_PATH) 42 | 43 | net_ae.load_state_dict(ckpt['ae']) 44 | net_ig.load_state_dict(ckpt['ig']) 45 | 46 | net_ae.cuda() 47 | net_ig.cuda() 48 | net_ae.eval() 49 | net_ig.eval() 50 | 51 | # lpips 52 | get_lpips = AverageMeter() 53 | lpips_list = [] 54 | 55 | # Network 56 | for iter_data in tqdm(range(1000)): 57 | rgb_img, skt_img = next(dataloader) 58 | 59 | rgb_img = rgb_img.cuda() 60 | skt_img = skt_img.cuda() 61 | 62 | gimg_ae, style_feats = net_ae(skt_img, rgb_img) 63 | g_image = net_ig(gimg_ae, style_feats) 64 | 65 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256), 66 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum() 67 | get_lpips.update(loss_mse.item() / BATCH_SIZE, BATCH_SIZE) 68 | 69 | lpips_list.append(get_lpips.avg) 70 | 71 | if (iter_data + 1) % 100 == 0: 72 | # print('avg : ', get_lpips.avg) 73 | print('LPIPS : ', sum(lpips_list) / len(lpips_list)) 74 | 75 | print('LPIPS : ', sum(lpips_list) / len(lpips_list)) 76 | 77 | 78 | def calculate_fid(data_root_colorful, data_root_sketch, model): 79 | from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, \ 80 | image_generator, image_generator_perm 81 | from models import AE 82 | from models import RefineGenerator as Generator 83 | import numpy as np 84 | 85 | CHANNEL = 32 86 | NBR_CLS = 50 87 | IM_SIZE = 256 88 | BATCH_SIZE = 8 89 | DATALOADER_WORKERS = 2 90 | fid_batch_images = 119 91 | fid_iters = 100 92 | inception = load_patched_inception_v3().cuda() 93 | inception.eval() 94 | 95 | fid = [] 96 | fid_perm = [] 97 | 98 | # load dataset 99 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE) 100 | print('the dataset contains %d images.' 
% len(dataset)) 101 | 102 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), 103 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 104 | 105 | # load ae model 106 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS) 107 | net_ae.style_encoder.reset_cls() 108 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE) 109 | 110 | PRETRAINED_PATH = './checkpoint/GAN.pth'.format(str(model)) 111 | print('Pre-trained path : ', PRETRAINED_PATH) 112 | ckpt = torch.load(PRETRAINED_PATH) 113 | 114 | net_ae.load_state_dict(ckpt['ae']) 115 | net_ig.load_state_dict(ckpt['ig']) 116 | 117 | net_ae.cuda() 118 | net_ig.cuda() 119 | net_ae.eval() 120 | net_ig.eval() 121 | 122 | print("calculating FID ...") 123 | 124 | real_features = extract_feature_from_generator_fn( 125 | real_image_loader(dataloader, n_batches=fid_batch_images), inception) 126 | real_mean = np.mean(real_features, 0) 127 | real_cov = np.cov(real_features, rowvar=False) 128 | real_features = {'feats': real_features, 'mean': real_mean, 'cov': real_cov} 129 | 130 | for iter_fid in range(fid_iters): 131 | sample_features = extract_feature_from_generator_fn( 132 | image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images), 133 | inception, total=fid_batch_images // BATCH_SIZE - 1) 134 | cur_fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov']) 135 | 136 | sample_features_perm = extract_feature_from_generator_fn( 137 | image_generator_perm(dataset, net_ae, net_ig, n_batches=fid_batch_images), 138 | inception, total=fid_batch_images // BATCH_SIZE - 1) 139 | cur_fid_perm = calc_fid(sample_features_perm, real_mean=real_features['mean'], 140 | real_cov=real_features['cov']) 141 | 142 | print('FID[{}]: '.format(iter_fid), [cur_fid, cur_fid_perm]) 143 | fid.append(cur_fid) 144 | fid_perm.append(cur_fid_perm) 145 | 146 | print('FID: ', sum(fid) / len(fid)) 147 | print('FID perm: ', sum(fid_perm) / len(fid_perm)) 148 | 149 | 150 | if __name__ == "__main__": 151 | model = 'styleme' 152 | data_root_colorful = './train_data/rgb/' 153 | data_root_sketch = './train_data/sketch/' 154 | # data_root_colorful = './train_data/comparison/rgb/' 155 | # data_root_sketch = './train_data/comparison/sketch_styleme/' 156 | # data_root_sketch = './train_data/comparison/sketch_cam/' 157 | # data_root_sketch = './train_data/comparison/sketch_adalin/' 158 | # data_root_sketch = './train_data/comparison/sketch_wo_camada/' 159 | 160 | calculate_Lpips(data_root_colorful, data_root_sketch, model) 161 | # calculate_fid(data_root_colorful, data_root_sketch, model) 162 | 163 | # styleme | 0.13515148047968645 | 16.034930465842525 164 | # styleme_wo | 0.4334833870760152 | 32.5567679015783 165 | # cam | 0.1373054370310368 | 17.165196809300138 166 | # adalin | 0.31896749291615123 | 28.387120218137913 167 | # camada | 0.36015568705948886 | 29.75984833745646 168 | -------------------------------------------------------------------------------- /styleme/config.py: -------------------------------------------------------------------------------- 1 | ################################# 2 | # training parameter # 3 | ################################# 4 | 5 | DATALOADER_WORKERS = 2 6 | NBR_CLS = 50 7 | 8 | EPOCH_GAN = 100 9 | ITERATION_GAN = 2000 10 | 11 | SAVE_IMAGE_INTERVAL = 100 12 | SAVE_MODEL_INTERVAL = 200 13 | LOG_INTERVAL = 200 14 | FID_INTERVAL = 100 15 | FID_BATCH_NBR = 100 16 | 17 | ITERATION_AE = 20000 18 | 19 | CHANNEL = 32 20 | MULTI_GPU = False 21 | 22 | IM_SIZE_GAN = 256 23 | BATCH_SIZE_GAN = 8 24 | 25 
| IM_SIZE_AE = 256 26 | BATCH_SIZE_AE = 8 27 | 28 | SAVE_FOLDER = './checkpoint/' 29 | 30 | # PRETRAINED_AE_PATH = './checkpoint/models/AE_20000.pth' 31 | PRETRAINED_AE_PATH = None 32 | 33 | # GAN_CKECKPOINT = './checkpoint/models/9.pth' 34 | GAN_CKECKPOINT = None 35 | 36 | TRAIN_AE_ONLY = False 37 | TRAIN_GAN_ONLY = False 38 | 39 | data_root_colorful = './train_data/rgb/' 40 | data_root_sketch = './train_data/sketch_styleme/' 41 | # data_root_sketch = './train_data/sketchgen_wo_cam/' 42 | # data_root_sketch = './train_data/sketchgen_wo_adalin/' 43 | # data_root_sketch = './train_data/sketchgen_wo_camada/' 44 | -------------------------------------------------------------------------------- /styleme/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image, ImageFilter 5 | from PIL import ImageFile 6 | 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | import torch.utils.data as data 10 | 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | 14 | def _rescale(img): 15 | return img * 2.0 - 1.0 16 | 17 | 18 | def transform_data(im_size=256): 19 | trans = transforms.Compose([ 20 | transforms.Resize((im_size, im_size)), 21 | transforms.ToTensor(), 22 | _rescale 23 | ]) 24 | return trans 25 | 26 | 27 | class TransformData(Dataset): 28 | def __init__(self, data_rgb, data_sketch, im_size=256, nbr_cls=100): 29 | super(TransformData, self).__init__() 30 | self.rgb_root = data_rgb 31 | self.skt_root = data_sketch 32 | 33 | self.frame = self._parse_frame() 34 | random.shuffle(self.frame) 35 | 36 | self.nbr_cls = nbr_cls 37 | self.set_offset = 0 38 | self.im_size = im_size 39 | 40 | self.transform = transforms.Compose([ 41 | transforms.Resize((im_size, im_size)), 42 | transforms.ToTensor(), 43 | _rescale 44 | ]) 45 | 46 | self.transform_rd = transforms.Compose([ 47 | transforms.Resize((int(im_size * 1.3), int(im_size * 1.3))), 48 | transforms.RandomCrop((int(im_size), int(im_size))), 49 | transforms.RandomRotation(30), 50 | transforms.RandomHorizontalFlip(p=1), 51 | transforms.Resize((im_size, im_size)), 52 | transforms.ToTensor(), 53 | _rescale 54 | ]) 55 | 56 | self.transform_flip = transforms.Compose([ 57 | transforms.RandomHorizontalFlip(p=0.8), 58 | transforms.RandomVerticalFlip(p=0.8), 59 | transforms.Resize((im_size, im_size)), 60 | transforms.ToTensor(), 61 | _rescale 62 | ]) 63 | 64 | self.transform_erase = transforms.Compose([ 65 | transforms.Resize((im_size, im_size)), 66 | transforms.ToTensor(), 67 | _rescale, 68 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1), 69 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1), 70 | transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1)]) 71 | 72 | self.transform_bold = transforms.Compose([ 73 | transforms.Resize((int(im_size * 1.1), int(im_size * 1.1))), 74 | transforms.Resize((im_size, im_size)), 75 | transforms.ToTensor(), 76 | _rescale 77 | ]) 78 | 79 | def _parse_frame(self): 80 | frame = [] 81 | img_names = os.listdir(self.rgb_root) 82 | img_names.sort() 83 | for i in range(len(img_names)): 84 | img_name = img_names[i].zfill(len(str(len(img_names)))) 85 | rgb_path = os.path.join(self.rgb_root, img_name) 86 | skt_path = os.path.join(self.skt_root, img_name) 87 | if os.path.exists(rgb_path) and os.path.exists(skt_path): 88 | frame.append((rgb_path, skt_path)) 89 | 90 | return frame 91 | 92 | def __len__(self): 93 | return self.nbr_cls 94 | 95 | def _next_set(self): 96 | 
self.set_offset += self.nbr_cls 97 | if self.set_offset > (len(self.frame) - self.nbr_cls): 98 | random.shuffle(self.frame) 99 | self.set_offset = 0 100 | 101 | def __getitem__(self, idx): 102 | file, skt_path = self.frame[idx + self.set_offset] 103 | rgb = Image.open(file).convert('RGB') 104 | skt = Image.open(skt_path).convert('L') 105 | 106 | img_normal = self.transform(rgb) 107 | img_rd = self.transform_rd(rgb) 108 | img_flip = self.transform_flip(rgb) 109 | 110 | skt_normal = self.transform(skt) 111 | skt_erase = self.transform_erase(skt) 112 | bold_factor = 3 113 | skt_bold = skt.filter(ImageFilter.MinFilter(size=bold_factor)) 114 | skt_bold = self.transform_bold(skt_bold) 115 | 116 | return img_normal, img_rd, img_flip, skt_normal, skt_erase, skt_bold, idx 117 | 118 | 119 | def InfiniteSampler(n): 120 | i = n - 1 121 | order = np.random.permutation(n) 122 | while True: 123 | yield order[i] 124 | i += 1 125 | if i >= n: 126 | np.random.seed() 127 | order = np.random.permutation(n) 128 | i = 0 129 | 130 | 131 | class InfiniteSamplerWrapper(data.sampler.Sampler): 132 | def __init__(self, data_source): 133 | self.num_samples = len(data_source) 134 | 135 | def __iter__(self): 136 | return iter(InfiniteSampler(self.num_samples)) 137 | 138 | def __len__(self): 139 | return 2 ** 31 140 | 141 | 142 | class PairedDataset(Dataset): 143 | def __init__(self, data_root_1, data_root_2, im_size=256): 144 | super(PairedDataset, self).__init__() 145 | self.root_a = data_root_1 146 | self.root_b = data_root_2 147 | 148 | self.frame = self._parse_frame() 149 | self.transform = transform_data(im_size) 150 | 151 | def _parse_frame(self): 152 | frame = [] 153 | img_names = os.listdir(self.root_a) 154 | img_names.sort() 155 | for i in range(len(img_names)): 156 | img_name = '%s.jpg' % str(i).zfill(len(str(len(img_names)))) 157 | image_a_path = os.path.join(self.root_a, img_names[i]) 158 | if ('.jpg' in image_a_path) or ('.png' in image_a_path): 159 | image_b_path = os.path.join(self.root_b, img_name) 160 | if os.path.exists(image_b_path): 161 | frame.append((image_a_path, image_b_path)) 162 | 163 | return frame 164 | 165 | def __len__(self): 166 | return len(self.frame) 167 | 168 | def __getitem__(self, idx): 169 | file_a, file_b = self.frame[idx] 170 | img_a = Image.open(file_a).convert('RGB') 171 | img_b = Image.open(file_b).convert('L') 172 | 173 | if self.transform: 174 | img_a = self.transform(img_a) 175 | img_b = self.transform(img_b) 176 | 177 | return (img_a, img_b) 178 | 179 | 180 | class ImageFolder(Dataset): 181 | def __init__(self, data_root, transform=transform_data(256)): 182 | super(ImageFolder, self).__init__() 183 | self.root = data_root 184 | 185 | self.frame = self._parse_frame() 186 | self.transform = transform 187 | 188 | def _parse_frame(self): 189 | frame = [] 190 | img_names = os.listdir(self.root) 191 | img_names.sort() 192 | for i in range(len(img_names)): 193 | image_path = os.path.join(self.root, img_names[i]) 194 | if ('.jpg' in image_path) or ('.png' in image_path): 195 | frame.append(image_path) 196 | 197 | return frame 198 | 199 | def __len__(self): 200 | return len(self.frame) 201 | 202 | def __getitem__(self, idx): 203 | file = self.frame[idx] 204 | img = Image.open(file).convert('RGB') 205 | 206 | if self.transform: 207 | img = self.transform(img) 208 | return img 209 | -------------------------------------------------------------------------------- /styleme/framework.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/framework.png -------------------------------------------------------------------------------- /styleme/generate_matrix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import utils as vutils 4 | 5 | 6 | def make_matrix(dataset_rgb, dataset_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name): 7 | dataloader_rgb = iter(DataLoader(dataset_rgb, BATCH_SIZE, shuffle=True)) 8 | dataloader_skt = iter(DataLoader(dataset_skt, BATCH_SIZE, shuffle=True)) 9 | 10 | rgb_img = next(dataloader_rgb) 11 | skt_img = next(dataloader_skt) 12 | 13 | skt_img = skt_img.mean(dim=1, keepdim=True) 14 | 15 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)] 16 | image_matrix.append(rgb_img.clone()) 17 | with torch.no_grad(): 18 | rgb_img = rgb_img.cuda() 19 | for skt in skt_img: 20 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda() 21 | 22 | gimg_ae, style_feats = net_ae(input_skts, rgb_img) 23 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone()) 24 | image_matrix.append(gimg_ae.cpu()) 25 | 26 | g_images = net_ig(gimg_ae, style_feats).cpu() 27 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1)) 28 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE)) 29 | 30 | image_matrix = torch.cat(image_matrix) 31 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1) 32 | -------------------------------------------------------------------------------- /styleme/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from skimage.metrics import structural_similarity as compare_ssim 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from lpips import dist_model 11 | 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, 15 | gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 17 | super(PerceptualLoss, self).__init__() 18 | print('Setting up Perceptual loss...') 19 | self.use_gpu = use_gpu 20 | self.spatial = spatial 21 | self.gpu_ids = gpu_ids 22 | self.model = dist_model.DistModel() 23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, 24 | gpu_ids=gpu_ids) 25 | print('...[%s] initialized' % self.model.name()) 26 | print('...Done') 27 | 28 | def forward(self, pred, target, normalize=False): 29 | """ 30 | Pred and target are Variables. 
31 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 32 | If normalize is False, assumes the images are already between [-1,+1] 33 | 34 | Inputs pred and target are Nx3xHxW 35 | Output pytorch Variable N long 36 | """ 37 | 38 | if normalize: 39 | target = 2 * target - 1 40 | pred = 2 * pred - 1 41 | 42 | return self.model.forward(target, pred) 43 | 44 | 45 | def normalize_tensor(in_feat, eps=1e-10): 46 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 47 | return in_feat / (norm_factor + eps) 48 | 49 | 50 | def l2(p0, p1, range=255.): 51 | return .5 * np.mean((p0 / range - p1 / range) ** 2) 52 | 53 | 54 | def psnr(p0, p1, peak=255.): 55 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) 56 | 57 | 58 | def dssim(p0, p1, range=255.): 59 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 60 | 61 | 62 | def rgb2lab(in_img, mean_cent=False): 63 | from skimage import color 64 | img_lab = color.rgb2lab(in_img) 65 | if (mean_cent): 66 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 67 | return img_lab 68 | 69 | 70 | def tensor2np(tensor_obj): 71 | # change dimension of a tensor object into a numpy array 72 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 73 | 74 | 75 | def np2tensor(np_obj): 76 | # change dimenion of np array into tensor array 77 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 78 | 79 | 80 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 81 | # image tensor to lab tensor 82 | from skimage import color 83 | 84 | img = tensor2im(image_tensor) 85 | img_lab = color.rgb2lab(img) 86 | if (mc_only): 87 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 88 | if (to_norm and not mc_only): 89 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 90 | img_lab = img_lab / 100. 91 | 92 | return np2tensor(img_lab) 93 | 94 | 95 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 96 | from skimage import color 97 | import warnings 98 | warnings.filterwarnings("ignore") 99 | 100 | lab = tensor2np(lab_tensor) * 100. 101 | lab[:, :, 0] = lab[:, :, 0] + 50 102 | 103 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 104 | if (return_inbnd): 105 | # convert back to lab, see if we match 106 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 107 | mask = 1. * np.isclose(lab_back, lab, atol=2.) 108 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 109 | return (im2tensor(rgb_back), mask) 110 | else: 111 | return im2tensor(rgb_back) 112 | 113 | 114 | def rgb2lab(input): 115 | from skimage import color 116 | return color.rgb2lab(input / 255.) 117 | 118 | 119 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 120 | image_numpy = image_tensor[0].cpu().float().numpy() 121 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 122 | return image_numpy.astype(imtype) 123 | 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | 130 | def tensor2vec(vector_tensor): 131 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 132 | 133 | 134 | def voc_ap(rec, prec, use_07_metric=False): 135 | """ ap = voc_ap(rec, prec, [use_07_metric]) 136 | Compute VOC AP given precision and recall. 137 | If use_07_metric is true, uses the 138 | VOC 07 11 point method (default:False). 139 | """ 140 | if use_07_metric: 141 | # 11 point metric 142 | ap = 0. 
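# VOC07 11-point interpolation: take the maximum precision at recall thresholds 0.0, 0.1, ..., 1.0 and average the eleven values (each term is divided by 11 in the loop below).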
143 | for t in np.arange(0., 1.1, 0.1): 144 | if np.sum(rec >= t) == 0: 145 | p = 0 146 | else: 147 | p = np.max(prec[rec >= t]) 148 | ap = ap + p / 11. 149 | else: 150 | # correct AP calculation 151 | # first append sentinel values at the end 152 | mrec = np.concatenate(([0.], rec, [1.])) 153 | mpre = np.concatenate(([0.], prec, [0.])) 154 | 155 | # compute the precision envelope 156 | for i in range(mpre.size - 1, 0, -1): 157 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 158 | 159 | # to calculate area under PR curve, look for points 160 | # where X axis (recall) changes value 161 | i = np.where(mrec[1:] != mrec[:-1])[0] 162 | 163 | # and sum (\Delta recall) * prec 164 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 165 | return ap 166 | 167 | 168 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 169 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 170 | image_numpy = image_tensor[0].cpu().float().numpy() 171 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 172 | return image_numpy.astype(imtype) 173 | 174 | 175 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 176 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 177 | return torch.Tensor((image / factor - cent) 178 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 179 | -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/dist_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/dist_model.cpython-37.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/dist_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/dist_model.cpython-38.pyc 
-------------------------------------------------------------------------------- /styleme/lpips/__pycache__/networks_basic.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/networks_basic.cpython-37.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/networks_basic.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/networks_basic.cpython-38.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/pretrained_networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/pretrained_networks.cpython-37.pyc -------------------------------------------------------------------------------- /styleme/lpips/__pycache__/pretrained_networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExponentiAI/StyleMe/64bf50eba799feba7c077cbf4ef4507ddf5c81d9/styleme/lpips/__pycache__/pretrained_networks.cpython-38.pyc -------------------------------------------------------------------------------- /styleme/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /styleme/lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import 
absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s mask.shape[1]: 46 | channel_scale = feat.shape[1] // mask.shape[1] 47 | mask = mask.repeat(1, channel_scale, 1, 1) 48 | 49 | mask = F.interpolate(mask, size=feat.shape[2]) 50 | feat_a = self.weight_a * feat * mask + self.bias_a 51 | feat_b = self.weight_b * feat * (1 - mask) + self.bias_b 52 | return feat_a + feat_b 53 | 54 | 55 | class Swish(nn.Module): 56 | def forward(self, feat): 57 | return feat * torch.sigmoid(feat) 58 | 59 | 60 | class Squeeze(nn.Module): 61 | def forward(self, feat): 62 | return feat.squeeze(-1).squeeze(-1) 63 | 64 | 65 | class UnSqueeze(nn.Module): 66 | def forward(self, feat): 67 | return feat.unsqueeze(-1).unsqueeze(-1) 68 | 69 | 70 | class ECAModule(nn.Module): 71 | def __init__(self, c, b=1, gamma=2): 72 | super(ECAModule, self).__init__() 73 | t = int(abs((math.log(c, 2) + b) / gamma)) 74 | k = t if t % 2 else t + 1 75 | 76 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 77 | self.conv1 = spectral_norm(nn.Conv1d(1, 1, k, 1, int(k / 2), bias=False)) 78 | self.sigmoid = nn.Sigmoid() 79 | 80 | def forward(self, x): 81 | x = self.avg_pool(x) 82 | x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 83 | out = self.sigmoid(x) 84 | return x * out 85 | 86 | 87 | class ResBlock(nn.Module): 88 | def __init__(self, ch, expansion=2): 89 | super(ResBlock, self).__init__() 90 | self.main = nn.Sequential(spectral_norm(nn.Conv2d(ch, ch * expansion, 1, 1, 0, bias=False)), 91 | spectral_norm(nn.BatchNorm2d(ch * expansion)), Swish(), 92 | spectral_norm(DepthwiseConv2d(ch * expansion, ch * expansion, 3, 1, 1)), 93 | spectral_norm(nn.BatchNorm2d(ch * expansion)), Swish(), 94 | spectral_norm(nn.Conv2d(ch * expansion, ch, 1, 1, 0, bias=False)), 95 | spectral_norm(nn.BatchNorm2d(ch)), Swish(), 
96 | ECAModule(ch)) 97 | 98 | def forward(self, x): 99 | return x + self.main(x) 100 | 101 | 102 | def base_block(ch_in, ch_out): 103 | return nn.Sequential(nn.Conv2d(ch_in, ch_out, 3, 1, 1, bias=False), 104 | nn.BatchNorm2d(ch_out), 105 | nn.LeakyReLU(0.2, inplace=True)) 106 | 107 | 108 | def down_block(ch_in, ch_out): 109 | return nn.Sequential(nn.Conv2d(ch_in, ch_out, 4, 2, 1, bias=False), 110 | nn.BatchNorm2d(ch_out), 111 | nn.LeakyReLU(0.1, inplace=True)) 112 | 113 | 114 | ################################ 115 | # style encode # 116 | ################################ 117 | 118 | class StyleEncoder(nn.Module): 119 | def __init__(self, ch=32, nbr_cls=100): 120 | super().__init__() 121 | 122 | self.sf_256 = base_block(3, ch // 2) 123 | self.sf_128 = down_block(ch // 2, ch) 124 | self.sf_64 = down_block(ch, ch * 2) 125 | 126 | self.sf_32 = nn.Sequential(down_block(ch * 2, ch * 4), 127 | ResBlock(ch * 4)) 128 | self.sf_16 = nn.Sequential(down_block(ch * 4, ch * 8), 129 | ResBlock(ch * 8)) 130 | self.sf_8 = nn.Sequential(down_block(ch * 8, ch * 16), 131 | ResBlock(ch * 16)) 132 | 133 | self.sfv_32 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4), 134 | nn.Conv2d(ch * 4, ch * 2, 4, 1, 0, bias=False), 135 | Squeeze()) 136 | self.sfv_16 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4), 137 | nn.Conv2d(ch * 8, ch * 4, 4, 1, 0, bias=False), 138 | Squeeze()) 139 | self.sfv_8 = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=4), 140 | nn.Conv2d(ch * 16, ch * 8, 4, 1, 0, bias=False), 141 | Squeeze()) 142 | 143 | self.ch = ch 144 | self.nbr_cls = nbr_cls 145 | self.final_cls = None 146 | 147 | def reset_cls(self): 148 | if self.final_cls is None: 149 | self.final_cls = nn.Sequential(nn.LeakyReLU(0.1), nn.Linear(self.ch * 8, self.nbr_cls)) 150 | stdv = 1. 
/ math.sqrt(self.final_cls[1].weight.size(1)) 151 | self.final_cls[1].weight.data.uniform_(-stdv, stdv) 152 | if self.final_cls[1].bias is not None: 153 | self.final_cls[1].bias.data.uniform_(-0.1 * stdv, 0.1 * stdv) 154 | 155 | def get_feats(self, image): 156 | feat = self.sf_256(image) 157 | feat = self.sf_128(feat) 158 | feat = self.sf_64(feat) 159 | feat_32 = self.sf_32(feat) 160 | feat_16 = self.sf_16(feat_32) 161 | feat_8 = self.sf_8(feat_16) 162 | 163 | feat_32 = self.sfv_32(feat_32) 164 | feat_16 = self.sfv_16(feat_16) 165 | feat_8 = self.sfv_8(feat_8) 166 | 167 | return feat_32, feat_16, feat_8 168 | 169 | def forward(self, image): 170 | feat_32, feat_16, feat_8 = self.get_feats(image) 171 | pred_cls = self.final_cls(feat_8) 172 | 173 | return [feat_32, feat_16, feat_8], pred_cls 174 | # [1, 64] [1, 128] [1, 256] 175 | 176 | 177 | ################################ 178 | # content encode # 179 | ################################ 180 | 181 | class ContentEncoder(nn.Module): 182 | def __init__(self, ch=32): 183 | super().__init__() 184 | 185 | self.feat_256 = base_block(1, ch // 4) 186 | self.feat_128 = down_block(ch // 4, ch // 2) 187 | self.feat_64 = down_block(ch // 2, ch) 188 | 189 | self.feat_32 = nn.Sequential(down_block(ch, ch * 2), 190 | ResBlock(ch * 2)) 191 | self.feat_16 = nn.Sequential(down_block(ch * 2, ch * 4), 192 | ResBlock(ch * 4)) 193 | self.feat_8 = nn.Sequential(down_block(ch * 4, ch * 8), 194 | ResBlock(ch * 8)) 195 | 196 | def forward(self, image): 197 | feat = self.feat_256(image) 198 | feat = self.feat_128(feat) 199 | feat = self.feat_64(feat) 200 | 201 | feat_32 = self.feat_32(feat) 202 | feat_16 = self.feat_16(feat_32) 203 | feat_8 = self.feat_8(feat_16) 204 | 205 | return [feat_32, feat_16, feat_8] 206 | # [1, 64, 32, 32] 207 | # [1, 128, 16, 16] 208 | # [1, 256, 8, 8] 209 | 210 | 211 | def for_decoder(ch_in, ch_out): 212 | return nn.Sequential( 213 | nn.UpsamplingNearest2d(scale_factor=2), 214 | nn.Conv2d(ch_in, ch_out * 2, 3, 1, 1, bias=False), 215 | nn.InstanceNorm2d(ch_out * 2), 216 | GLU()) 217 | 218 | 219 | def style_decode(ch_in, ch_out): 220 | return nn.Sequential(nn.Linear(ch_in, ch_out), nn.ReLU(), 221 | nn.Linear(ch_out, ch_out), nn.Sigmoid(), 222 | UnSqueeze()) 223 | 224 | 225 | ################################ 226 | # decode # 227 | ################################ 228 | 229 | class Decoder(nn.Module): 230 | def __init__(self, ch=32): 231 | super().__init__() 232 | 233 | self.base_feat = nn.Parameter(torch.randn(1, ch * 8, 8, 8).normal_(0, 1), requires_grad=True) 234 | 235 | self.dmi_8 = DMI(ch * 8) 236 | self.dmi_16 = DMI(ch * 4) 237 | 238 | self.feat_8_1 = nn.Sequential(ResBlock(ch * 16), nn.LeakyReLU(0.1, inplace=True), 239 | nn.Conv2d(ch * 16, ch * 8, 3, 1, 1, bias=False), 240 | nn.InstanceNorm2d(ch * 8)) 241 | self.feat_8_2 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True), ResBlock(ch * 8)) 242 | 243 | self.feat_16 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True), 244 | for_decoder(ch * 8, ch * 4), ResBlock(ch * 4)) 245 | self.feat_32 = nn.Sequential(nn.LeakyReLU(0.1, inplace=True), 246 | for_decoder(ch * 8, ch * 2), ResBlock(ch * 2)) 247 | 248 | self.feat_64 = for_decoder(ch * 4, ch) 249 | self.feat_128 = for_decoder(ch, ch // 2) 250 | self.feat_256 = for_decoder(ch // 2, ch // 4) 251 | 252 | self.to_rgb = nn.Sequential(nn.Conv2d(ch // 4, 3, 3, 1, 1, bias=False), 253 | nn.Tanh()) 254 | 255 | self.style_8 = style_decode(ch * 8, ch * 8) 256 | self.style_64 = style_decode(ch * 8, ch) 257 | self.style_128 = style_decode(ch * 4, ch // 2) 258 | 
self.style_256 = style_decode(ch * 2, ch // 4) 259 | 260 | def forward(self, content_feats, style_vectors): 261 | feat_8 = self.feat_8_1(torch.cat([content_feats[2], 262 | self.base_feat.repeat(style_vectors[0].shape[0], 1, 1, 1)], dim=1)) 263 | feat_8 = self.dmi_8(feat_8, content_feats[2]) 264 | 265 | feat_8 = feat_8 * self.style_8(style_vectors[2]) 266 | feat_8 = self.feat_8_2(feat_8) 267 | 268 | feat_16 = self.feat_16(feat_8) 269 | feat_16 = self.dmi_16(feat_16, content_feats[1]) 270 | feat_16 = torch.cat([feat_16, content_feats[1]], dim=1) 271 | 272 | feat_32 = self.feat_32(feat_16) 273 | feat_32 = torch.cat([feat_32, content_feats[0]], dim=1) 274 | 275 | feat_64 = self.feat_64(feat_32) * self.style_64(style_vectors[2]) 276 | feat_128 = self.feat_128(feat_64) * self.style_128(style_vectors[1]) 277 | feat_256 = self.feat_256(feat_128) * self.style_256(style_vectors[0]) 278 | 279 | return self.to_rgb(feat_256) 280 | 281 | 282 | ################################ 283 | # AE Module # 284 | ################################ 285 | 286 | class AE(nn.Module): 287 | def __init__(self, ch, nbr_cls=100): 288 | super().__init__() 289 | 290 | self.style_encoder = StyleEncoder(ch, nbr_cls=nbr_cls) 291 | self.content_encoder = ContentEncoder(ch) 292 | self.decoder = Decoder(ch) 293 | 294 | @torch.no_grad() 295 | def forward(self, skt_img, style_img): 296 | style_feats = self.style_encoder.get_feats(F.interpolate(style_img, size=256)) 297 | content_feats = self.content_encoder(F.interpolate(skt_img, size=256)) 298 | gimg = self.decoder(content_feats, style_feats) 299 | return gimg, style_feats 300 | 301 | def load_state_dicts(self, path): 302 | ckpt = torch.load(path) 303 | self.style_encoder.reset_cls() 304 | self.style_encoder.load_state_dict(ckpt['s']) 305 | self.content_encoder.load_state_dict(ckpt['c']) 306 | self.decoder.load_state_dict(ckpt['d']) 307 | print('AE model load success') 308 | 309 | 310 | def down_gan(ch_in, ch_out): 311 | return nn.Sequential( 312 | spectral_norm(nn.Conv2d(ch_in, ch_out, 4, 2, 1, bias=False)), 313 | nn.BatchNorm2d(ch_out), 314 | nn.LeakyReLU(0.1, inplace=True)) 315 | 316 | 317 | def up_gan(ch_in, ch_out): 318 | return nn.Sequential( 319 | nn.UpsamplingNearest2d(scale_factor=2), 320 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 1, 1, bias=False)), 321 | nn.BatchNorm2d(ch_out), 322 | nn.LeakyReLU(0.1, inplace=True)) 323 | 324 | 325 | def style_gan(ch_in, ch_out): 326 | return nn.Sequential( 327 | spectral_norm(nn.Linear(ch_in, ch_out)), nn.ReLU(), 328 | nn.Linear(ch_out, ch_out), 329 | nn.Sigmoid(), UnSqueeze()) 330 | 331 | 332 | ################################ 333 | # GAN # 334 | ################################ 335 | 336 | class RefineGenerator(nn.Module): 337 | def __init__(self, ch=32, im_size=256): 338 | super().__init__() 339 | 340 | self.im_size = im_size 341 | 342 | self.from_noise_32 = nn.Sequential(UnSqueeze(), 343 | spectral_norm(nn.ConvTranspose2d(ch * 8, ch * 8, 4, 1, 0, bias=False)), 344 | nn.BatchNorm2d(ch * 8), 345 | nn.Sigmoid(), 346 | up_gan(ch * 8, ch * 4), 347 | up_gan(ch * 4, ch * 2), 348 | up_gan(ch * 2, ch * 1)) 349 | 350 | self.from_style = nn.Sequential(UnSqueeze(), 351 | spectral_norm( 352 | nn.ConvTranspose2d(ch * (8 + 4 + 2), ch * 16, 4, 1, 0, bias=False)), 353 | nn.BatchNorm2d(ch * 16), 354 | GLU(), 355 | up_gan(ch * 8, ch * 4)) 356 | 357 | self.encode_256 = nn.Sequential(spectral_norm(nn.Conv2d(3, ch, 3, 1, 1, bias=False)), 358 | nn.LeakyReLU(0.2, inplace=True)) 359 | self.encode_128 = nn.Sequential(ResBlock(ch), 360 | down_gan(ch, ch * 2)) 
361 | self.encode_64 = nn.Sequential(ResBlock(ch * 2), 362 | down_gan(ch * 2, ch * 4)) 363 | self.encode_32 = nn.Sequential(ResBlock(ch * 4), 364 | down_gan(ch * 4, ch * 8)) 365 | 366 | self.encode_16 = nn.Sequential(ResBlock(ch * 8), 367 | down_gan(ch * 8, ch * 16)) 368 | 369 | self.decode_32 = nn.Sequential(ResBlock(ch * 16), 370 | up_gan(ch * 16, ch * 8)) 371 | self.decode_64 = nn.Sequential(ResBlock(ch * 8 + ch), 372 | up_gan(ch * 8 + ch, ch * 4)) 373 | self.decode_128 = nn.Sequential(ResBlock(ch * 4), 374 | up_gan(ch * 4, ch * 2)) 375 | self.decode_256 = nn.Sequential(ResBlock(ch * 2), 376 | up_gan(ch * 2, ch)) 377 | 378 | self.style_64 = style_gan(ch * 8, ch * 4) 379 | self.style_128 = style_gan(ch * 4, ch * 2) 380 | self.style_256 = style_gan(ch * 2, ch) 381 | 382 | self.to_rgb = nn.Sequential(nn.Conv2d(ch, 3, 3, 1, 1, bias=False), nn.Tanh()) 383 | 384 | def forward(self, image, style_vectors): 385 | n_32 = self.from_noise_32(torch.randn_like(style_vectors[2])) # [8, 32, 32, 32] 386 | 387 | e_256 = self.encode_256(image) # [8, 3, 256, 256] [8, 32, 256, 256] 388 | e_128 = self.encode_128(e_256) # [8, 64, 128, 128] 389 | e_64 = self.encode_64(e_128) # [8, 128, 64, 64] 390 | e_32 = self.encode_32(e_64) # [8, 256, 32, 32] 391 | 392 | e_16 = self.encode_16(e_32) # [8, 256, 16, 16] 393 | 394 | d_32 = self.decode_32(e_16) # [8, 256, 32, 32] 395 | d_64 = self.decode_64(torch.cat([d_32, n_32], dim=1)) # [8, 128, 64, 64] 396 | d_64 = self.style_64(style_vectors[2]) * d_64 # [8, 128, 64, 64] 397 | 398 | d_128 = self.decode_128(d_64 + e_64) # [8, 64, 128, 128] 399 | d_128 = self.style_128(style_vectors[1]) * d_128 # [8, 64, 128, 128] 400 | 401 | d_256 = self.decode_256(d_128 + e_128) # [8, 32, 256, 256] 402 | d_256 = self.style_256(style_vectors[0]) * d_256 # [8, 32, 256, 256] 403 | 404 | d_final = self.to_rgb(d_256) 405 | 406 | return d_final 407 | 408 | 409 | class DownBlock(nn.Module): 410 | def __init__(self, ch_in, ch_out): 411 | super().__init__() 412 | 413 | self.ch_out = ch_out 414 | self.down_main = nn.Sequential( 415 | spectral_norm(nn.Conv2d(ch_in, ch_out, 3, 2, 1, bias=False)), 416 | nn.BatchNorm2d(ch_out), 417 | nn.LeakyReLU(0.1, inplace=True), 418 | spectral_norm(nn.Conv2d(ch_out, ch_out, 3, 1, 1, bias=False)), 419 | nn.BatchNorm2d(ch_out), 420 | nn.LeakyReLU(0.1, inplace=True) 421 | ) 422 | 423 | def forward(self, feat): 424 | feat_out = self.down_main(feat) 425 | 426 | return feat_out 427 | 428 | 429 | class Discriminator(nn.Module): 430 | def __init__(self, ch=64, nc=3, im_size=256): 431 | super(Discriminator, self).__init__() 432 | self.ch = ch 433 | self.im_size = im_size 434 | 435 | self.f_256 = nn.Sequential(spectral_norm(nn.Conv2d(nc, ch // 8, 3, 1, 1, bias=False)), 436 | nn.LeakyReLU(0.2, inplace=True)) 437 | 438 | self.f_128 = DownBlock(ch // 8, ch // 4) 439 | self.f_64 = DownBlock(ch // 4, ch // 2) 440 | self.f_32 = DownBlock(ch // 2, ch) 441 | self.f_16 = DownBlock(ch, ch * 2) 442 | self.f_8 = DownBlock(ch * 2, ch * 4) 443 | self.f = nn.Sequential(spectral_norm(nn.Conv2d(ch * 4, ch * 8, 1, 1, 0, bias=False)), 444 | nn.BatchNorm2d(ch * 8), 445 | nn.LeakyReLU(0.1, inplace=True)) 446 | 447 | self.flatten = spectral_norm(nn.Conv2d(ch * 8, 1, 3, 1, 1, bias=False)) 448 | 449 | self.apply(weights_init) 450 | 451 | def forward(self, x): 452 | feat_256 = self.f_256(x) 453 | feat_128 = self.f_128(feat_256) 454 | feat_64 = self.f_64(feat_128) 455 | feat_32 = self.f_32(feat_64) 456 | feat_16 = self.f_16(feat_32) 457 | feat_8 = self.f_8(feat_16) 458 | feat_f = self.f(feat_8) 459 | 
feat_out = self.flatten(feat_f) 460 | 461 | return feat_out 462 | -------------------------------------------------------------------------------- /styleme/readme.md: -------------------------------------------------------------------------------- 1 | # Environment : 2 | 3 | - python 3.8.0 4 | - pytorch 1.12.1 5 | 6 |
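- A quick way to confirm your environment matches the versions above (this check is illustrative and not part of the repository):

```
# illustrative environment check; the expected values come from the list above
import sys
import torch

print(sys.version.split()[0])     # expected: 3.8.x
print(torch.__version__)          # expected: 1.12.1
print(torch.cuda.is_available())  # the training scripts call .cuda(), so a CUDA build is expected
```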
7 | 
 8 | 
 9 | 
 10 | 
 11 | 
 12 | Fig.1 An overview of our style transform network of StyleMe
 13 | 
 14 | 
15 | 
 16 | - You can download our dataset, which includes 119 RGB images and 119 sketches, here: [**styleme datasets**](https://drive.google.com/drive/folders/1UycahUifPoc0n6pyP92bWC07BlJETwRR) (a quick filename check is sketched below) 17 | 
 18 | 
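- The RGB images and sketches are loaded as pairs by the training code, so a quick check after unpacking can catch problems early. The sketch below is illustrative only: the paths mirror the layout shown in section 2, and same-named RGB/sketch pairs are an assumption.

```
# illustrative sanity check, not part of the repository
import os

rgb_files = sorted(os.listdir('./train_data/rgb/'))      # assumed unpack location
skt_files = sorted(os.listdir('./train_data/sketch/'))

print(len(rgb_files), len(skt_files))  # 119 and 119 for the released dataset
assert rgb_files == skt_files, 'assumes every RGB image has a same-named sketch'
```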
19 | 
 20 | - We provide a pretrained model, trained for 30,000 iterations, here: [**styleme model**](https://drive.google.com/drive/folders/1JHmDdsV6OS0sf6v-OhwkpbkDPn7Co2HW) (a loading sketch follows below) 21 | 
 22 | 
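- How the downloaded checkpoint is consumed can be seen in style_transform.py later in this repository; a minimal loading sketch is shown below. It assumes the file is saved to ./checkpoint/GAN.pth, which is the path style_transform.py reads.

```
# minimal loading sketch mirroring style_transform.py; the checkpoint location is an assumption
import torch
from models import AE, RefineGenerator

net_ae = AE(ch=32, nbr_cls=50)
net_ae.style_encoder.reset_cls()   # build the classifier head before loading the weights
net_ig = RefineGenerator()

ckpt = torch.load('./checkpoint/GAN.pth', map_location='cpu')
net_ae.load_state_dict(ckpt['ae'])  # the checkpoint stores the AE under 'ae' and the generator under 'ig'
net_ig.load_state_dict(ckpt['ig'])

net_ae.eval()
net_ig.eval()
```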
23 | 
 24 | 
 25 | ## 1. Description
 26 | Related code comments:
 27 | 
 28 | * train.py: trains the whole model; you can also choose to train only the AE module or only the GAN module.
 29 | * models.py: all model structure definitions, including the encoders (style and content), the decoder (which decodes random style features and content features), the generator, and the discriminator.
 30 | * datasets.py: data pre-processing and loading methods.
 31 | * train_step_1.py: AE module training.
 32 | * train_step_2.py: GAN module training.
 33 | * config.py: all hyper-parameter settings.
 34 | * calculate.py: calculates the FID and LPIPS of the model.
 35 | * benchmark.py: the FID functions; the Inception model it uses is downloaded automatically.
 36 | * lpips: the LPIPS functions; the required pretrained networks are also downloaded automatically.
 37 | * style_transform.py: feed in your sketch and RGB images to transform the style.
 38 | 
 39 | 
 40 | ## 2. Training
 41 | 
 42 | - first, prepare your dataset as follows (a matching config sketch follows the layout):
 43 | 
 44 | ```
 45 | train_data/
 46 | -./rgb/
 47 | -000.png
 48 | -001.png
 49 | -...
 50 | -./sketch/
 51 | -000.png
 52 | -001.png
 53 | -...
 54 | ```
 55 | 
 56 | 
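- The layout above matches the data roots hard-coded in style_transform.py, and train_step_1.py / train_step_2.py read the same names from config.py. An illustrative excerpt is shown below; apart from the two paths, the values are assumptions rather than the repository's actual defaults.

```
# illustrative config.py excerpt; the names come from the imports in train_step_1.py / train_step_2.py
data_root_colorful = './train_data/rgb/'    # same default paths as style_transform.py
data_root_sketch = './train_data/sketch/'

IM_SIZE_AE = 256     # assumed; training and evaluation operate at 256x256
BATCH_SIZE_AE = 8    # assumed
NBR_CLS = 50         # assumed; style_transform.py builds the AE with nbr_cls=50
SAVE_FOLDER = './'   # assumed; checkpoints and samples are written under <SAVE_FOLDER>/train_results/
```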
57 | 
 58 | - and then train your model (the stage switches are sketched below):
 59 | 
 60 | ```
 61 | python train.py
 62 | ```
 63 | 
 64 | 
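- train.py (included later in this repository) picks the training stage through two switches imported from config.py. The sketch below is illustrative: the flag names come from train.py's imports, while the default values are assumptions.

```
# illustrative stage switches consumed by train.py; default values are assumptions
TRAIN_AE_ONLY = False    # True trains only the AE module (train_step_1)
TRAIN_GAN_ONLY = False   # True trains only the GAN module
```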
65 | 66 | ## 3. Evaluate 67 | 68 | - You can run the following program to see the performance of our model: 69 | 70 | ``` 71 | python style_transform.py 72 | ``` 73 | 74 | - or you can also get the FID and LPIPS: 75 | 76 | ``` 77 | python calculate.py 78 | ``` 79 | -------------------------------------------------------------------------------- /styleme/style_transform.py: -------------------------------------------------------------------------------- 1 | ############################## 2 | # style transform # 3 | ############################## 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torchvision import utils as vutils 8 | from datasets import ImageFolder, transform_data 9 | from models import AE, RefineGenerator 10 | 11 | 12 | def make_matrix(dataloader_rgb, dataloader_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name): 13 | rgb_img = next(dataloader_rgb) 14 | skt_img = next(dataloader_skt) 15 | 16 | skt_img = skt_img.mean(dim=1, keepdim=True) 17 | 18 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)] 19 | image_matrix.append(rgb_img.clone()) 20 | with torch.no_grad(): 21 | rgb_img = rgb_img.cuda() 22 | for skt in skt_img: 23 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda() 24 | 25 | gimg_ae, style_feats = net_ae(input_skts, rgb_img) 26 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone()) 27 | image_matrix.append(gimg_ae.cpu()) 28 | 29 | g_images = net_ig(gimg_ae, style_feats).cpu() 30 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1)) 31 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE)) 32 | 33 | image_matrix = torch.cat(image_matrix) 34 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1) 35 | 36 | 37 | if __name__ == "__main__": 38 | device = 'cuda' 39 | batch_size = 5 40 | img_size = 256 41 | num_workers = 2 42 | trans_iter = 20 43 | data_root_colorful = './train_data/rgb/' 44 | data_root_sketch = './train_data/sketch/' 45 | 46 | net_ae = AE(ch=32, nbr_cls=50) 47 | net_ae.style_encoder.reset_cls() 48 | net_ig = RefineGenerator() 49 | 50 | ckpt = torch.load('./checkpoint/GAN.pth') 51 | 52 | net_ae.load_state_dict(ckpt['ae']) 53 | net_ae.style_encoder.reset_cls() 54 | net_ig.load_state_dict(ckpt['ig']) 55 | 56 | net_ae.to(device) 57 | net_ig.to(device) 58 | net_ae.eval() 59 | net_ig.eval() 60 | 61 | dataset_rgb = ImageFolder(data_root_colorful, transform_data(img_size)) 62 | dataloader_rgb = iter(DataLoader(dataset_rgb, batch_size, shuffle=False, num_workers=num_workers)) 63 | 64 | dataset_skt = ImageFolder(data_root_sketch, transform_data(img_size)) 65 | dataloader_skt = iter(DataLoader(dataset_skt, batch_size, shuffle=False, num_workers=num_workers)) 66 | 67 | for idx in range(trans_iter): 68 | print(idx) 69 | make_matrix(dataloader_rgb, dataloader_skt, net_ae, net_ig, batch_size, img_size, 70 | './trans_data/transform/%d.jpg' % idx) 71 | -------------------------------------------------------------------------------- /styleme/train.py: -------------------------------------------------------------------------------- 1 | ############################ 2 | # main training # 3 | ############################ 4 | 5 | import train_step_1 6 | import train_step_2 7 | from config import TRAIN_AE_ONLY, TRAIN_GAN_ONLY 8 | 9 | 10 | if __name__ == "__main__": 11 | if TRAIN_GAN_ONLY: 12 | print('train gan only !') 13 | train_step_1.train() 14 | else: 15 | print('train ae first !') 16 | train_step_1.train() 17 | if not TRAIN_AE_ONLY: 18 | train_step_2.train() 19 | 
-------------------------------------------------------------------------------- /styleme/train_step_1.py: -------------------------------------------------------------------------------- 1 | ############################# 2 | # train_step_1 # 3 | # # 4 | # transform images # 5 | ############################# 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import optim 10 | from torch.utils.data import DataLoader 11 | from torchvision import utils as vutils 12 | 13 | import random 14 | from tqdm import tqdm 15 | 16 | from datasets import TransformData, InfiniteSamplerWrapper 17 | from utils import make_folders, AverageMeter 18 | from models import StyleEncoder, ContentEncoder, Decoder 19 | 20 | 21 | def loss_for_style(style, style_org, batch_size): 22 | loss_result = 0 23 | for loss_idx in range(len(style)): 24 | loss_result += - F.cosine_similarity(style[loss_idx], 25 | style_org[loss_idx].detach()).mean() + \ 26 | F.cosine_similarity(style[loss_idx], 27 | style_org[loss_idx][torch.randperm(batch_size)] 28 | .detach()).mean() 29 | return loss_result / len(style) 30 | 31 | 32 | def loss_for_content(loss, fl1, fl2): 33 | loss_result = 0 34 | for f_idx in range(len(fl1)): 35 | loss_result += loss(fl1[f_idx], fl2[f_idx].detach()) 36 | return loss_result * 2 37 | 38 | 39 | def train(): 40 | from config import IM_SIZE_AE, BATCH_SIZE_AE, CHANNEL, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE 41 | from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, LOG_INTERVAL 42 | from config import data_root_colorful, data_root_sketch 43 | 44 | dataset_trans = TransformData(data_root_colorful, data_root_sketch, im_size=IM_SIZE_AE, nbr_cls=NBR_CLS) 45 | print('Num classes:', len(dataset_trans), ' Data nums:', len(dataset_trans.frame)) 46 | dataloader_trans = iter(DataLoader(dataset_trans, BATCH_SIZE_AE, 47 | sampler=InfiniteSamplerWrapper(dataset_trans), 48 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 49 | 50 | style_encoder = StyleEncoder(ch=CHANNEL, nbr_cls=NBR_CLS).cuda() 51 | content_encoder = ContentEncoder(ch=CHANNEL).cuda() 52 | decoder = Decoder(ch=CHANNEL).cuda() 53 | 54 | opt_content = optim.Adam(content_encoder.parameters(), lr=1e-4, betas=(0.9, 0.999)) 55 | opt_style = optim.Adam(style_encoder.parameters(), lr=1e-4, betas=(0.9, 0.999)) 56 | opt_decode = optim.Adam(decoder.parameters(), lr=1e-4, betas=(0.9, 0.999)) 57 | 58 | style_encoder.reset_cls() 59 | style_encoder.final_cls.cuda() 60 | 61 | # load model 62 | from config import PRETRAINED_AE_PATH 63 | if PRETRAINED_AE_PATH is not None: 64 | ckpt = torch.load(PRETRAINED_AE_PATH) 65 | 66 | print('Pre-trained AE path : ', PRETRAINED_AE_PATH) 67 | 68 | style_encoder.load_state_dict(ckpt['s']) 69 | content_encoder.load_state_dict(ckpt['c']) 70 | decoder.load_state_dict(ckpt['d']) 71 | 72 | opt_style.load_state_dict(ckpt['opt_s']) 73 | opt_content.load_state_dict(ckpt['opt_c']) 74 | opt_decode.load_state_dict(ckpt['opt_d']) 75 | print('loaded pre-trained AE') 76 | 77 | style_encoder.reset_cls() 78 | style_encoder.final_cls.cuda() 79 | opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=1e-4, betas=(0.9, 0.999)) 80 | 81 | # save path 82 | saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'Train_step_1') 83 | 84 | # loss log 85 | losses_style_feat = AverageMeter() 86 | losses_content_feat = AverageMeter() 87 | losses_cls = AverageMeter() 88 | losses_org = AverageMeter() 89 | losses_rd = AverageMeter() 90 | losses_flip = AverageMeter() 91 | 92 | import lpips 93 | percept = 
lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) 94 | 95 | for iteration in tqdm(range(ITERATION_AE)): 96 | if iteration % ((NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1: 97 | dataset_trans._next_set() 98 | dataloader_trans = iter(DataLoader(dataset_trans, BATCH_SIZE_AE, 99 | sampler=InfiniteSamplerWrapper(dataset_trans), 100 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 101 | style_encoder.reset_cls() 102 | opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=1e-4, betas=(0.9, 0.999)) 103 | 104 | opt_style.param_groups[0]['lr'] = 1e-4 105 | opt_decode.param_groups[0]['lr'] = 1e-4 106 | 107 | # 1. training for encode & decode 108 | # 1.1 prepare data 109 | rgb_img_org, rgb_img_rd, rgb_img_flip, skt_org, skt_erased, skt_bold, img_idx = next(dataloader_trans) 110 | rgb_img_org = rgb_img_org.cuda() 111 | rgb_img_rd = rgb_img_rd.cuda() 112 | rgb_img_flip = rgb_img_flip.cuda() 113 | 114 | skt_org = F.interpolate(skt_org, size=IM_SIZE_AE).cuda() 115 | skt_erased = F.interpolate(skt_erased, size=IM_SIZE_AE).cuda() 116 | skt_bold = F.interpolate(skt_bold, size=IM_SIZE_AE).cuda() 117 | 118 | img_idx = img_idx.long().cuda() 119 | 120 | # 1.2 model grad zero 121 | style_encoder.zero_grad() 122 | content_encoder.zero_grad() 123 | decoder.zero_grad() 124 | 125 | ################ 126 | # encode # 127 | ################ 128 | # 1.3 for style 129 | style_vector_org, pred_cls_org = style_encoder(rgb_img_org) 130 | style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd) 131 | style_vector_flip, pred_cls_flip = style_encoder(rgb_img_flip) 132 | 133 | # 1.4 for content 134 | content_feats_org = content_encoder(skt_org) 135 | content_feats_erased = content_encoder(skt_erased) 136 | content_feats_bold = content_encoder(skt_bold) 137 | 138 | # 1.5 encode loss 139 | loss_style_feat = loss_for_style(style_vector_rd, style_vector_org, BATCH_SIZE_AE) + \ 140 | loss_for_style(style_vector_flip, style_vector_org, BATCH_SIZE_AE) 141 | 142 | loss_content_feat = loss_for_content(F.mse_loss, content_feats_bold, content_feats_org) + \ 143 | loss_for_content(F.mse_loss, content_feats_erased, content_feats_org) 144 | 145 | loss_cls = F.cross_entropy(pred_cls_org, img_idx) + \ 146 | F.cross_entropy(pred_cls_rd, img_idx) + \ 147 | F.cross_entropy(pred_cls_flip, img_idx) 148 | 149 | ################ 150 | # decode # 151 | ################ 152 | org = random.randint(0, 2) 153 | gimg_org = None 154 | if org == 0: 155 | gimg_org = decoder(content_feats_org, style_vector_org) 156 | elif org == 1: 157 | gimg_org = decoder(content_feats_erased, style_vector_org) 158 | elif org == 2: 159 | gimg_org = decoder(content_feats_bold, style_vector_org) 160 | 161 | rd = random.randint(0, 2) 162 | gimg_rd = None 163 | if rd == 0: 164 | gimg_rd = decoder(content_feats_org, style_vector_rd) 165 | elif rd == 1: 166 | gimg_rd = decoder(content_feats_erased, style_vector_rd) 167 | elif rd == 2: 168 | gimg_rd = decoder(content_feats_bold, style_vector_rd) 169 | 170 | flip = random.randint(0, 2) 171 | gimg_flip = None 172 | if flip == 0: 173 | gimg_flip = decoder(content_feats_org, style_vector_flip) 174 | elif flip == 1: 175 | gimg_flip = decoder(content_feats_erased, style_vector_flip) 176 | elif flip == 2: 177 | gimg_flip = decoder(content_feats_bold, style_vector_flip) 178 | 179 | # 1.6 decode loss 180 | loss_org = F.mse_loss(gimg_org, rgb_img_org) + \ 181 | percept(F.adaptive_avg_pool2d(gimg_org, output_size=256), 182 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum() 183 | 184 | loss_rd = 
F.mse_loss(gimg_rd, rgb_img_org) + \ 185 | percept(F.adaptive_avg_pool2d(gimg_rd, output_size=256), 186 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum() 187 | 188 | loss_flip = F.mse_loss(gimg_flip, rgb_img_org) + \ 189 | percept(F.adaptive_avg_pool2d(gimg_flip, output_size=256), 190 | F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum() 191 | 192 | loss_total = loss_style_feat + loss_content_feat + loss_cls + loss_org + loss_rd + loss_flip 193 | loss_total.backward() 194 | 195 | opt_style.step() 196 | opt_content.step() 197 | opt_s_cls.step() 198 | opt_decode.step() 199 | 200 | # 1.7 update log 201 | losses_style_feat.update(loss_style_feat.mean().item(), BATCH_SIZE_AE) 202 | losses_content_feat.update(loss_content_feat.mean().item(), BATCH_SIZE_AE) 203 | losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE) 204 | losses_org.update(loss_org.item(), BATCH_SIZE_AE) 205 | losses_rd.update(loss_rd.item(), BATCH_SIZE_AE) 206 | losses_flip.update(loss_flip.item(), BATCH_SIZE_AE) 207 | 208 | # 1.8 print log 209 | if iteration % LOG_INTERVAL == 0: 210 | log_msg = '\nTrain Stage 1 (encode and decode): \n' \ 211 | 'loss_encode_style: %.4f loss_encode_content: %.4f loss_encode_class: %.4f \n' \ 212 | 'loss_decode_org: %.4f loss_decode_rd: %.4f loss_decode_flip: %.4f' % ( 213 | losses_style_feat.avg, losses_content_feat.avg, losses_cls.avg, 214 | losses_org.avg, losses_rd.avg, losses_flip.avg) 215 | print(log_msg) 216 | 217 | losses_style_feat.reset() 218 | losses_content_feat.reset() 219 | losses_cls.reset() 220 | losses_org.reset() 221 | losses_rd.reset() 222 | losses_flip.reset() 223 | 224 | if iteration % SAVE_IMAGE_INTERVAL == 0: 225 | vutils.save_image(torch.cat([rgb_img_org, 226 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE), 227 | gimg_org]), 228 | '%s/%d_org.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1)) 229 | vutils.save_image(torch.cat([rgb_img_rd, 230 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE), 231 | gimg_rd]), 232 | '%s/%d_rd.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1)) 233 | vutils.save_image(torch.cat([rgb_img_flip, 234 | F.interpolate(skt_org.repeat(1, 3, 1, 1), size=IM_SIZE_AE), 235 | gimg_flip]), 236 | '%s/%d_flip.jpg' % (saved_image_folder, iteration), normalize=True, range=(-1, 1)) 237 | 238 | if iteration % SAVE_MODEL_INTERVAL == 0: 239 | print('Saving history model') 240 | torch.save({'s': style_encoder.state_dict(), 241 | 'c': content_encoder.state_dict(), 242 | 'd': decoder.state_dict(), 243 | 'opt_s': opt_style.state_dict(), 244 | 'opt_c': opt_content.state_dict(), 245 | 'opt_s_cls': opt_s_cls.state_dict(), 246 | 'opt_d': opt_decode.state_dict(), 247 | }, '%s/%d.pth' % (saved_model_folder, iteration)) 248 | 249 | torch.save({'s': style_encoder.state_dict(), 250 | 'c': content_encoder.state_dict(), 251 | 'd': decoder.state_dict(), 252 | 'opt_s': opt_style.state_dict(), 253 | 'opt_c': opt_content.state_dict(), 254 | 'opt_s_cls': opt_s_cls.state_dict(), 255 | 'opt_d': opt_decode.state_dict(), 256 | }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE)) 257 | 258 | 259 | if __name__ == "__main__": 260 | train() 261 | -------------------------------------------------------------------------------- /styleme/train_step_2.py: -------------------------------------------------------------------------------- 1 | ################################# 2 | # train_step_2 # 3 | # # 4 | # optimize transform images # 5 | ################################# 6 | 7 | import torch 8 | from torch import 
nn 9 | import torch.nn.functional as F 10 | from torch import optim 11 | 12 | from torch.utils.data import DataLoader 13 | from torchvision import utils as vutils 14 | 15 | import os 16 | from tqdm import tqdm 17 | from datetime import datetime 18 | import pandas as pd 19 | 20 | from datasets import PairedDataset, InfiniteSamplerWrapper 21 | from utils import copy_G_params, make_folders, AverageMeter, d_hinge_loss, g_hinge_loss 22 | from models import AE, Discriminator 23 | 24 | 25 | def make_matrix(dataset_rgb, dataset_skt, net_ae, net_ig, BATCH_SIZE, IM_SIZE, im_name): 26 | dataloader_rgb = iter(DataLoader(dataset_rgb, BATCH_SIZE, shuffle=True)) 27 | dataloader_skt = iter(DataLoader(dataset_skt, BATCH_SIZE, shuffle=True)) 28 | 29 | rgb_img = next(dataloader_rgb) 30 | skt_img = next(dataloader_skt) 31 | 32 | skt_img = skt_img.mean(dim=1, keepdim=True) 33 | 34 | image_matrix = [torch.ones(1, 3, IM_SIZE, IM_SIZE)] 35 | image_matrix.append(rgb_img.clone()) 36 | with torch.no_grad(): 37 | rgb_img = rgb_img.cuda() 38 | for skt in skt_img: 39 | input_skts = skt.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).cuda() 40 | 41 | gimg_ae, style_feats = net_ae(input_skts, rgb_img) 42 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone()) 43 | image_matrix.append(gimg_ae.cpu()) 44 | 45 | g_images = net_ig(gimg_ae, style_feats).cpu() 46 | image_matrix.append(skt.unsqueeze(0).repeat(1, 3, 1, 1).clone().fill_(1)) 47 | image_matrix.append(torch.nn.functional.interpolate(g_images, IM_SIZE)) 48 | 49 | image_matrix = torch.cat(image_matrix) 50 | vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1) 51 | 52 | 53 | def save_csv(save_csv_path, iters, MSE, FID): 54 | time = '{}'.format(datetime.now()) 55 | iters = '{}'.format(iters) 56 | MSE = '{:.5f}'.format(MSE) 57 | FID = '{:.5f}'.format(FID) 58 | 59 | print('------ Saving csv ------') 60 | list = [time, iters, MSE, FID] 61 | 62 | data = pd.DataFrame([list]) 63 | data.to_csv(save_csv_path, mode='a', header=False, index=False) 64 | 65 | 66 | def train(): 67 | from benchmark import load_patched_inception_v3 68 | import lpips 69 | 70 | from config import IM_SIZE_GAN, BATCH_SIZE_GAN, CHANNEL, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_GAN, \ 71 | ITERATION_AE, GAN_CKECKPOINT 72 | from config import SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, MULTI_GPU 73 | from config import PRETRAINED_AE_PATH 74 | from config import data_root_colorful, data_root_sketch 75 | 76 | inception = load_patched_inception_v3().cuda() 77 | inception.eval() 78 | 79 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) 80 | 81 | # save path 82 | save_csv_path = './checkpoint/train_results.csv' 83 | 84 | saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'Train_step_2') 85 | 86 | if not os.path.exists(save_csv_path): 87 | df = pd.DataFrame(columns=['time', 'iters', 'Lpips', 'FID']) 88 | df.to_csv(save_csv_path, index=False) 89 | print('make csv successful !') 90 | else: 91 | print('csv is exist !') 92 | 93 | # load dataset 94 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE_GAN) 95 | print('the dataset contains %d images.' 
% len(dataset)) 96 | dataloader = iter(DataLoader(dataset, BATCH_SIZE_GAN, sampler=InfiniteSamplerWrapper(dataset), 97 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 98 | 99 | # load ae model 100 | net_ae = AE(ch=CHANNEL, nbr_cls=NBR_CLS) 101 | 102 | if PRETRAINED_AE_PATH is None: 103 | PRETRAINED_AE_PATH = SAVE_FOLDER + 'train_results/Train_step_1/' + 'models/%d.pth' % ITERATION_AE 104 | else: 105 | PRETRAINED_AE_PATH = PRETRAINED_AE_PATH 106 | 107 | print('Pre-trained AE path : ', PRETRAINED_AE_PATH) 108 | 109 | net_ae.load_state_dicts(PRETRAINED_AE_PATH) 110 | net_ae.cuda() 111 | net_ae.eval() 112 | 113 | from models import RefineGenerator as Generator 114 | 115 | # load generator & discriminator 116 | net_ig = Generator(ch=CHANNEL, im_size=IM_SIZE_GAN).cuda() 117 | net_id = Discriminator(nc=3).cuda() 118 | 119 | if MULTI_GPU: 120 | net_ae = nn.DataParallel(net_ae) 121 | net_ig = nn.DataParallel(net_ig) 122 | net_id = nn.DataParallel(net_id) 123 | 124 | net_ig_ema = copy_G_params(net_ig) 125 | 126 | opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.8, 0.999)) 127 | opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.8, 0.999)) 128 | 129 | if GAN_CKECKPOINT is not None: 130 | ckpt = torch.load(GAN_CKECKPOINT) 131 | net_ig.load_state_dict(ckpt['ig']) 132 | net_id.load_state_dict(ckpt['id']) 133 | net_ig_ema = ckpt['ig_ema'] 134 | opt_ig.load_state_dict(ckpt['opt_ig']) 135 | opt_id.load_state_dict(ckpt['opt_id']) 136 | print('Pre-trained GAN path : ', GAN_CKECKPOINT) 137 | 138 | # loss log 139 | losses_g_img = AverageMeter() 140 | losses_d_img = AverageMeter() 141 | losses_mse = AverageMeter() 142 | losses_style = AverageMeter() 143 | losses_content = AverageMeter() 144 | losses_rec_ae = AverageMeter() 145 | 146 | fid_init = 1000.0 147 | 148 | ################### 149 | # train gan # 150 | ################### 151 | for epoch in range(EPOCH_GAN): 152 | for iteration in tqdm(range(ITERATION_GAN)): 153 | rgb_img, skt_img = next(dataloader) 154 | 155 | rgb_img = rgb_img.cuda() 156 | skt_img = skt_img.cuda() 157 | 158 | # 1. train Discriminator 159 | gimg_ae, style_feats = net_ae(skt_img, rgb_img) 160 | g_image = net_ig(gimg_ae, style_feats) 161 | 162 | real = net_id(rgb_img) 163 | fake = net_id(g_image.detach()) 164 | 165 | loss_d = d_hinge_loss(real, fake) 166 | 167 | net_id.zero_grad() 168 | loss_d.backward() 169 | opt_id.step() 170 | 171 | # log ae loss 172 | loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(gimg_ae, rgb_img) 173 | losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN) 174 | 175 | # 2. 
train Generator 176 | pred_g = net_id(g_image) 177 | loss_g = g_hinge_loss(pred_g) 178 | 179 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256), 180 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum() 181 | losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN) 182 | 183 | _, g_style_feats = net_ae(skt_img, g_image) 184 | 185 | loss_style = 0 186 | for loss_idx in range(3): 187 | loss_style += - F.cosine_similarity(g_style_feats[loss_idx], 188 | style_feats[loss_idx].detach()).mean() + \ 189 | F.cosine_similarity(g_style_feats[loss_idx], 190 | style_feats[loss_idx][torch.randperm(BATCH_SIZE_GAN)] 191 | .detach()).mean() 192 | losses_style.update(loss_style.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN) 193 | 194 | loss_all = loss_g + loss_mse + loss_style 195 | 196 | net_ig.zero_grad() 197 | loss_all.backward() 198 | opt_ig.step() 199 | 200 | for p, avg_p in zip(net_ig.parameters(), net_ig_ema): 201 | avg_p.mul_(0.999).add_(p.data, alpha=0.001) 202 | 203 | # 3. logging 204 | losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN) 205 | losses_d_img.update(real.mean().item(), BATCH_SIZE_GAN) 206 | 207 | # 4. save model 208 | if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000: 209 | print('Saving history model') 210 | torch.save({'ig': net_ig.state_dict(), 211 | 'id': net_id.state_dict(), 212 | 'ae': net_ae.state_dict(), 213 | 'ig_ema': net_ig_ema, 214 | 'opt_ig': opt_ig.state_dict(), 215 | 'opt_id': opt_id.state_dict(), 216 | }, '%s/%d.pth' % (saved_model_folder, epoch)) 217 | 218 | # 5. print log 219 | if iteration % LOG_INTERVAL == 0: 220 | # calcuate lpips and fid 221 | cal_lpips = calculate_Lpips(data_root_colorful, data_root_sketch, net_ae, net_ig) 222 | cal_fid = calculate_fid(data_root_colorful, data_root_sketch, net_ae, net_ig) 223 | 224 | log_msg = ' \nGAN_Iter: [{0}/{1}] AE_loss: {ae_loss: .5f} \n' \ 225 | 'Generator: {losses_g_img.avg:.4f} Discriminator: {losses_d_img.avg:.4f} \n' \ 226 | 'Style: {losses_style.avg:.5f} Content: {losses_content.avg:.5f} \n' \ 227 | 'Lpips: {lpips:.4f} FID: {fid:.4f}\n'.format( 228 | epoch, iteration, ae_loss=losses_rec_ae.avg, losses_g_img=losses_g_img, 229 | losses_d_img=losses_d_img, losses_style=losses_style, losses_content=losses_content, 230 | lpips=cal_lpips, fid=cal_fid) 231 | 232 | print(log_msg) 233 | 234 | save_csv(save_csv_path, epoch * ITERATION_GAN + iteration, cal_lpips, cal_fid) 235 | 236 | # save model 237 | if cal_fid < fid_init: 238 | fid_init = cal_fid 239 | print('Saving history model') 240 | torch.save({'ig': net_ig.state_dict(), 241 | 'id': net_id.state_dict(), 242 | 'ae': net_ae.state_dict(), 243 | 'ig_ema': net_ig_ema, 244 | 'opt_ig': opt_ig.state_dict(), 245 | 'opt_id': opt_id.state_dict(), 246 | }, '%s/%d_%d.pth' % (saved_model_folder, epoch, iteration)) 247 | 248 | losses_g_img.reset() 249 | losses_d_img.reset() 250 | losses_mse.reset() 251 | losses_style.reset() 252 | losses_content.reset() 253 | losses_rec_ae.reset() 254 | 255 | 256 | def calculate_Lpips(data_root_colorful, data_root_sketch, net_ae, net_ig): 257 | import lpips 258 | 259 | IM_SIZE = 256 260 | BATCH_SIZE = 6 261 | DATALOADER_WORKERS = 0 262 | 263 | percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) 264 | 265 | # load dataset 266 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE) 267 | print('the dataset contains %d images.' 
% len(dataset)) 268 | 269 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), 270 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 271 | 272 | net_ae.eval() 273 | net_ig.eval() 274 | 275 | # lpips 276 | get_lpips = AverageMeter() 277 | lpips_list = [] 278 | 279 | # Network 280 | for iter_data in tqdm(range(100)): 281 | rgb_img, skt_img = next(dataloader) 282 | 283 | rgb_img = rgb_img.cuda() 284 | skt_img = skt_img.cuda() 285 | 286 | gimg_ae, style_feats = net_ae(skt_img, rgb_img) 287 | g_image = net_ig(gimg_ae, style_feats) 288 | 289 | loss_mse = 10 * percept(F.adaptive_avg_pool2d(g_image, output_size=256), 290 | F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum() 291 | get_lpips.update(loss_mse.item() / BATCH_SIZE, BATCH_SIZE) 292 | 293 | lpips_list.append(get_lpips.avg) 294 | 295 | print('LPIPS : ', sum(lpips_list) / len(lpips_list)) 296 | 297 | return sum(lpips_list) / len(lpips_list) 298 | 299 | 300 | def calculate_fid(data_root_colorful, data_root_sketch, net_ae, net_ig): 301 | from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, \ 302 | real_image_loader, image_generator 303 | import numpy as np 304 | 305 | IM_SIZE = 256 306 | BATCH_SIZE = 8 307 | DATALOADER_WORKERS = 0 308 | fid_batch_images = 119 309 | fid_iters = 10 310 | inception = load_patched_inception_v3().cuda() 311 | inception.eval() 312 | 313 | fid = [] 314 | 315 | # load dataset 316 | dataset = PairedDataset(data_root_colorful, data_root_sketch, im_size=IM_SIZE) 317 | print('the dataset contains %d images.' % len(dataset)) 318 | 319 | dataloader = iter(DataLoader(dataset, BATCH_SIZE, sampler=InfiniteSamplerWrapper(dataset), 320 | num_workers=DATALOADER_WORKERS, pin_memory=True)) 321 | 322 | net_ae.eval() 323 | net_ig.eval() 324 | 325 | print("calculating FID ...") 326 | 327 | real_features = extract_feature_from_generator_fn( 328 | real_image_loader(dataloader, n_batches=fid_batch_images), inception) 329 | real_mean = np.mean(real_features, 0) 330 | real_cov = np.cov(real_features, rowvar=False) 331 | real_features = {'feats': real_features, 'mean': real_mean, 'cov': real_cov} 332 | 333 | for iter_fid in range(fid_iters): 334 | sample_features = extract_feature_from_generator_fn( 335 | image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images), 336 | inception, total=fid_batch_images // BATCH_SIZE - 1) 337 | cur_fid = calc_fid(sample_features, real_mean=real_features['mean'], real_cov=real_features['cov']) 338 | 339 | print('FID[{}]: '.format(iter_fid), cur_fid) 340 | fid.append(cur_fid) 341 | 342 | print('FID: ', sum(fid) / len(fid)) 343 | 344 | return sum(fid) / len(fid) 345 | 346 | 347 | if __name__ == "__main__": 348 | train() 349 | -------------------------------------------------------------------------------- /styleme/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from copy import deepcopy 4 | from random import shuffle 5 | import torch.nn.functional as F 6 | 7 | 8 | def d_hinge_loss(real_pred, fake_pred): 9 | real_loss = F.relu(1 - real_pred) 10 | fake_loss = F.relu(1 + fake_pred) 11 | 12 | return real_loss.mean() + fake_loss.mean() 13 | 14 | 15 | def g_hinge_loss(pred): 16 | return -pred.mean() 17 | 18 | 19 | class AverageMeter(object): 20 | 21 | def __init__(self): 22 | self.reset() 23 | 24 | def reset(self): 25 | self.val = 0 26 | self.avg = 0 27 | self.sum = 0 28 | self.count = 0 29 | 30 | def update(self, val, n=1): 31 | self.val = val 
32 | self.sum += val * n 33 | self.count += n 34 | self.avg = self.sum / self.count 35 | 36 | 37 | def true_randperm(size, device='cuda'): 38 | def unmatched_randperm(size): 39 | l1 = [i for i in range(size)] 40 | l2 = [] 41 | for j in range(size): 42 | deleted = False 43 | if j in l1: 44 | deleted = True 45 | del l1[l1.index(j)] 46 | shuffle(l1) 47 | if len(l1) == 0: 48 | return 0, False 49 | l2.append(l1[0]) 50 | del l1[0] 51 | if deleted: 52 | l1.append(j) 53 | return l2, True 54 | 55 | flag = False 56 | l = torch.zeros(size).long() 57 | while not flag: 58 | l, flag = unmatched_randperm(size) 59 | return torch.LongTensor(l).to(device) 60 | 61 | 62 | def copy_G_params(model): 63 | flatten = deepcopy(list(p.data for p in model.parameters())) 64 | return flatten 65 | 66 | 67 | def load_params(model, new_param): 68 | for p, new_p in zip(model.parameters(), new_param): 69 | p.data.copy_(new_p) 70 | 71 | 72 | def make_folders(save_folder, trial_name): 73 | saved_model_folder = os.path.join(save_folder, 'train_results/%s/models' % trial_name) 74 | saved_image_folder = os.path.join(save_folder, 'train_results/%s/images' % trial_name) 75 | folders = [os.path.join(save_folder, 'train_results'), 76 | os.path.join(save_folder, 'train_results/%s' % trial_name), 77 | os.path.join(save_folder, 'train_results/%s/images' % trial_name), 78 | os.path.join(save_folder, 'train_results/%s/models' % trial_name)] 79 | for folder in folders: 80 | if not os.path.exists(folder): 81 | os.mkdir(folder) 82 | 83 | return saved_image_folder, saved_model_folder 84 | --------------------------------------------------------------------------------