├── .gitignore ├── LICENSE.txt ├── README.md ├── _config.yml ├── cp_data.py ├── data ├── __init__.py ├── aligned_dataset.py ├── aligned_dataset2.py ├── base_data_loader.py ├── base_dataset.py ├── con_dataset_data_loader.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── face_con_dataset_data_loader.py ├── face_dataset.py ├── image_folder.py ├── keypoint2img.py ├── pose_con_dataset_data_loader.py └── pose_dataset.py ├── encode_features.py ├── expand_val_test.py ├── generate_data_face_forensics.py ├── img ├── dance.png ├── face.png ├── fig.png ├── scene.png ├── small_dance.png ├── small_face.png └── small_scene.png ├── infer_face.py ├── inference ├── data │ ├── img │ │ ├── __Mr_0__0__a.0.jpg │ │ ├── __Mr_0__0__a.1.jpg │ │ ├── __Mr_1__0__a.0.jpg │ │ └── __Mr_1__0__a.1.jpg │ ├── infer_list.txt │ └── keypoints │ │ ├── __Mr_0__0__a.0.txt │ │ ├── __Mr_0__0__a.1.txt │ │ ├── __Mr_1__0__a.0.txt │ │ └── __Mr_1__0__a.1.txt ├── infer_list.txt └── test_imgs │ ├── 1.jpg │ └── 2.jpg ├── judge_face.py ├── models ├── __init__.py ├── base_model.py ├── c_pix2pixHD_model.py ├── cm_pix2pixHD_model.py ├── models.py ├── networks.py ├── pix2pixHD_model.py └── ui_model.py ├── new_scripts ├── infer_face.sh ├── table_face.sh ├── table_pose.sh ├── table_scene.sh ├── test_face.sh ├── test_pose.sh ├── test_scene.sh ├── train_face.sh ├── train_pose.sh └── train_scene.sh ├── options ├── __init__.py ├── base_options.py ├── generate_data_options.py ├── test_options.py └── train_options.py ├── precompute_feature_maps.py ├── process.py ├── run_engine.py ├── temp_test.py ├── test.py ├── test_all.py ├── test_con.py ├── test_con_bak.py ├── test_delta.txt ├── test_face.py ├── test_mface.py ├── test_pose.py ├── test_seg.py ├── train.py ├── train_con.py ├── train_face.py ├── train_mface.py ├── train_pose.py ├── train_seg.py ├── util ├── __init__.py ├── html.py ├── image_pool.py ├── label.py ├── util.py └── visualizer.py ├── val_gen.py ├── vis_bdd.py ├── vis_delta.txt ├── vis_face.py └── vis_pose.py /.gitignore: -------------------------------------------------------------------------------- 1 | debug* 2 | datasets/ 3 | checkpoints/ 4 | results/ 5 | build/ 6 | dist/ 7 | logs/ 8 | old_logs/ 9 | torch.egg-info/ 10 | */**/__pycache__ 11 | torch/version.py 12 | torch/csrc/generic/TensorMethods.cpp 13 | torch/lib/*.so* 14 | torch/lib/*.dylib* 15 | torch/lib/*.h 16 | torch/lib/build 17 | torch/lib/tmp_install 18 | torch/lib/include 19 | torch/lib/torch_shm_manager 20 | torch/csrc/cudnn/cuDNN.cpp 21 | torch/csrc/nn/THNN.cwrap 22 | torch/csrc/nn/THNN.cpp 23 | torch/csrc/nn/THCUNN.cwrap 24 | torch/csrc/nn/THCUNN.cpp 25 | torch/csrc/nn/THNN_generic.cwrap 26 | torch/csrc/nn/THNN_generic.cpp 27 | torch/csrc/nn/THNN_generic.h 28 | docs/src/**/* 29 | test/data/legacy_modules.t7 30 | test/data/gpu_tensors.pt 31 | test/htmlcov 32 | test/.coverage 33 | */*.pyc 34 | */**/*.pyc 35 | */**/**/*.pyc 36 | */**/**/**/*.pyc 37 | */**/**/**/**/*.pyc 38 | */*.so* 39 | */**/*.so* 40 | */**/*.dylib* 41 | test/data/legacy_serialized.pt 42 | *.DS_Store 43 | *~ 44 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. 2 | All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | 5 | Permission to use, copy, modify, and distribute this software and its documentation 6 | for any non-commercial purpose is hereby granted without fee, provided that the above 7 | copyright notice appear in all copies and that both that copyright notice and this 8 | permission notice appear in supporting documentation, and that the name of the author 9 | not be used in advertising or publicity pertaining to distribution of the software 10 | without specific, written prior permission. 11 | 12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 | 19 | 20 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------------- 21 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 22 | All rights reserved. 23 | 24 | Redistribution and use in source and binary forms, with or without 25 | modification, are permitted provided that the following conditions are met: 26 | 27 | * Redistributions of source code must retain the above copyright notice, this 28 | list of conditions and the following disclaimer. 29 | 30 | * Redistributions in binary form must reproduce the above copyright notice, 31 | this list of conditions and the following disclaimer in the documentation 32 | and/or other materials provided with the distribution. 33 | 34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 37 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 38 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 39 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 40 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 41 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 42 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 43 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Example-Guided Style-Consistent Image Synthesis from Semantic Labeling 2 |

3 | 4 | ## Paper 5 | Example-Guided Style-Consistent Image Synthesis from Semantic Labeling
6 | Miao Wang¹, Guo-Ye Yang², Ruilong Li², Run-Ze Liang², Song-Hai Zhang², Peter M. Hall³ and Shi-Min Hu²,¹
7 | ¹State Key Laboratory of Virtual Reality Technology and Systems, Beihang University
8 | ²Department of Computer Science and Technology, Tsinghua University, Beijing
9 | ³University of Bath
10 | *IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019* 11 | 12 | 13 | ## Prerequisites 14 | - Linux 15 | - Python 3 16 | - NVIDIA GPU (12 GB or 24 GB memory) + CUDA and cuDNN 17 | - pytorch==0.4.1 18 | - numpy 19 | - ... 20 | 21 | ## Tasks 22 | ### Sketch2Face 23 | Task name: face 24 | 25 | We use the real videos in the FaceForensics dataset, which contains 854 videos of reporters broadcasting news. We localize facial landmarks, crop the facial regions and resize them to 256×256. The detected facial landmarks are connected to create face sketches. 26 |
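The landmark-to-sketch step can be reproduced with dlib's 68-point predictor (the repository ships detection code along these lines in `generate_data_face_forensics.py` and draws the landmark edges in `data/face_dataset.py`). The sketch below is a minimal, illustrative version only; the part grouping and the `face_sketch` helper are our assumptions, not the exact code used to build the released dataset.

```python
# Minimal sketch: detect 68 facial landmarks with dlib and connect them into a
# line-drawing "face sketch" (assumes dlib + OpenCV are installed and the
# 68-point model file is available, as in generate_data_face_forensics.py).
import cv2
import dlib
import numpy as np

PREDICTOR_PATH = './datasets/shape_predictor_68_face_landmarks.dat'
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(PREDICTOR_PATH)

# 68-point landmark indices grouped by facial part (illustrative grouping)
PARTS = [range(0, 17),   # jaw line
         range(17, 22),  # right eyebrow
         range(22, 27),  # left eyebrow
         range(27, 36),  # nose
         range(36, 42),  # right eye
         range(42, 48),  # left eye
         range(48, 60),  # outer mouth
         range(60, 68)]  # inner mouth

def face_sketch(img):
    """img: uint8 RGB image -> 256x256 line drawing, or None if no face is found."""
    dets = detector(img, 1)
    if len(dets) == 0:
        return None
    shape = predictor(img, dets[0])
    pts = [(shape.part(i).x, shape.part(i).y) for i in range(68)]
    sketch = np.zeros(img.shape[:2], np.uint8)
    for part in PARTS:
        idx = list(part)
        for a, b in zip(idx[:-1], idx[1:]):      # connect consecutive keypoints
            cv2.line(sketch, pts[a], pts[b], color=255, thickness=1)
    return cv2.resize(sketch, (256, 256), interpolation=cv2.INTER_NEAREST)
```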

27 | 28 | ### Pose2Dance 29 | Task name: pose 30 | 31 | We download 150 solo dance videos from YouTube, crop out the central body regions and resize them to 256×256. We evenly split each video into a first part and a second part along the timeline, then sample training data only from the first parts and testing data only from the second parts of all the videos. The labels are created using concatenated pre-trained DensePose and OpenPose detection results. 32 |
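Two details of this pipeline are worth spelling out: the split is purely temporal (the first half of every video goes to the training pool, the second half to the test pool), and at load time the DensePose and OpenPose renderings are simply concatenated channel-wise (`torch.cat([Di, Oi])` in `data/pose_dataset.py`). Below is a minimal sketch of the temporal split, assuming one folder of JPEG frames per video; the folder layout and helper names are hypothetical, and the shipped `*_list.txt` files are produced by the dataset-preparation scripts.

```python
# Minimal sketch of the temporal split: for every dance video, the first half
# of its frames is eligible for training and the second half for testing
# (hypothetical helpers; paths and layout are assumptions).
import glob
import os

def split_video_frames(video_dir):
    """Split one video's frames into (train_frames, test_frames)."""
    frames = sorted(glob.glob(os.path.join(video_dir, '*.jpg')))
    mid = len(frames) // 2
    return frames[:mid], frames[mid:]

def split_dataset(root):
    train, test = [], []
    for video_dir in sorted(glob.glob(os.path.join(root, '*'))):
        first_half, second_half = split_video_frames(video_dir)
        train += first_half    # sample training pairs only from first halves
        test += second_half    # sample testing pairs only from second halves
    return train, test
```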

33 | 34 | ### SceneParsing2StreetView 35 | Task name: scene 36 | 37 | We use the BDD100k dataset to synthesize street view images from pixelwise semantic labels (i.e. scene parsing maps). We use the state-of-the-art scene parsing network DANet to create labels. 38 |
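The data loader expects one DANet prediction per frame, stored as a `*.label.png` image. Below is a minimal sketch of the image-to-label path mapping, mirroring `transfer_path()` in `data/aligned_dataset2.py`; the helper name is ours, while the `./datasets/bdd2/danet_vis/` directory and the `'__'`-based naming convention come from that file.

```python
# Minimal sketch of how a frame's image path is mapped to its DANet label map,
# following transfer_path() in data/aligned_dataset2.py: the last '__'-separated
# token of the image path selects a *.label.png file under danet_vis/.
DANET_DIR = './datasets/bdd2/danet_vis/'

def danet_label_path(image_path):
    name = image_path.split('__')[-1]
    return DANET_DIR + name + '.label.png'
```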

39 | 40 | ## Getting Started 41 | ### Installation 42 | ```bash 43 | git clone [this project] 44 | cd pix2pixSC 45 | # download datas.zip at https://drive.google.com/drive/folders/1O94UcCXONq7p2ZiPcfi-dldjREQ-GsJK or https://share.weiyun.com/5lHBkE0 46 | unzip datas.zip 47 | mv datas/checkpoints ./ 48 | mv datas/datasets ./ 49 | 50 | # the scripts below are optional 51 | mkdir ../FaceForensics 52 | # download the FaceForensics dataset to ../FaceForensics/datas 53 | python process.py 54 | python generate_data_face_forensics.py --source_path '../FaceForensics/out_data' --target_path './datasets/FaceForensics3/' --same_style_rate 0.3 --neighbor_size 10 --A_repeat_num 50 --copy_data 55 | ``` 56 | 57 | ### Training 58 | ```bash 59 | new_scripts/train_[Task name].sh 60 | ``` 61 | 62 | ### Testing 63 | ```bash 64 | new_scripts/test_[Task name].sh 65 | ``` 66 | 67 | ### Inference Face 68 | Edit inference/infer_list.txt with one test case per line; the outputs of each test case will be written to ./results. 69 | ```bash 70 | new_scripts/infer_face.sh 71 | ``` 72 | Inference code for the other tasks will come later. 73 | 74 | ## Results 75 | ### Face 76 |

77 | 78 | ### Dance 79 |

80 | 81 | ### Scene 82 |

83 | 84 | 85 | ## Citation 86 | 87 | If you find this useful for your research, please cite the following paper. 88 | 89 | ``` 90 | @InProceedings{pix2pixSC2019, 91 | author = {Wang, Miao and Yang, Guo-Ye and Li, Ruilong and Liang, Run-Ze and Zhang, Song-Hai and Hall, Peter. M and Hu, Shi-Min}, 92 | title = {Example-Guided Style-Consistent Image Synthesis from Semantic Labeling}, 93 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 94 | month = {June}, 95 | year = {2019} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /cp_data.py: -------------------------------------------------------------------------------- 1 | from shutil import copyfile 2 | tp = 'datasets/cityscapes/val_' 3 | pre = 'aachen_000018_000019' 4 | pre = 'bremen_000108_000019' 5 | pre = 'frankfurt_000001_008688' 6 | pre = 'frankfurt_000001_060906' 7 | s_img_path = tp + 'img/' + pre + '_leftImg8bit.png' 8 | t_img_path = 'datasets/test/img.png' 9 | s_inst_path = tp + 'inst/' + pre + '_gtFine_instanceIds.png' 10 | t_inst_path = 'datasets/test/inst.png' 11 | s_label_path = tp + 'label/' + pre + '_gtFine_labelIds.png' 12 | t_label_path = 'datasets/test/label.png' 13 | copyfile(s_img_path, t_img_path) 14 | copyfile(s_inst_path, t_inst_path) 15 | copyfile(s_label_path, t_label_path) 16 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/data/__init__.py -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os.path 4 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | 8 | class AlignedDataset(BaseDataset): 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.root = opt.dataroot 12 | 13 | ### input A (label maps) 14 | dir_A = '_A' if self.opt.label_nc == 0 else '_label' 15 | self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) 16 | self.A_paths = sorted(make_dataset(self.dir_A)) 17 | 18 | ### input B (real images) 19 | if opt.isTrain: 20 | dir_B = '_B' if self.opt.label_nc == 0 else '_img' 21 | self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B) 22 | self.B_paths = sorted(make_dataset(self.dir_B)) 23 | 24 | ### instance maps 25 | if not opt.no_instance: 26 | self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst') 27 | self.inst_paths = sorted(make_dataset(self.dir_inst)) 28 | 29 | ### load precomputed instance-wise encoded features 30 | if opt.load_features: 31 | self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat') 32 | print('----------- loading features from %s ----------' % self.dir_feat) 33 | self.feat_paths = sorted(make_dataset(self.dir_feat)) 34 | 35 | self.dataset_size = len(self.A_paths) 36 | 37 | def __getitem__(self, index): 38 | ### input A (label maps) 39 | A_path = self.A_paths[index] 40 | A = Image.open(A_path) 41 | params = get_params(self.opt, A.size) 42 | if self.opt.label_nc == 0: 43 | transform_A = get_transform(self.opt, params) 44 | A_tensor = transform_A(A.convert('RGB')) 45 | else: 46 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 47 | A_tensor = transform_A(A) * 255.0 48 | 49 | B_tensor = inst_tensor = feat_tensor = 0 50 | ### input B (real images) 51 | if self.opt.isTrain: 52 | B_path = self.B_paths[index] 53 | B = Image.open(B_path).convert('RGB') 54 | transform_B = get_transform(self.opt, params) 55 | B_tensor = transform_B(B) 56 | 57 | ### if using instance maps 58 | if not self.opt.no_instance: 59 | inst_path = self.inst_paths[index] 60 | inst = Image.open(inst_path) 61 | inst_tensor = transform_A(inst) 62 | 63 | if self.opt.load_features: 64 | feat_path = self.feat_paths[index] 65 | feat = Image.open(feat_path).convert('RGB') 66 | norm = normalize() 67 | feat_tensor = norm(transform_A(feat)) 68 | 69 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 70 | 'feat': feat_tensor, 'path': A_path} 71 | 72 | return input_dict 73 | 74 | def __len__(self): 75 | return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize 76 | 77 | def name(self): 78 | return 'AlignedDataset' -------------------------------------------------------------------------------- /data/aligned_dataset2.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os.path 4 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | 8 | class AlignedDataset2(BaseDataset): 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.root = opt.dataroot 12 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 13 | f = open(data_list_path, 'r') 14 | self.all_paths = f.readlines() 15 | self.dataset_size = len(self.all_paths) 16 | 17 | def transfer_path(self, path): 18 | temp = path.split('__')[-1] 19 | temp = './datasets/bdd2/danet_vis/' + temp + '.label.png' 20 | return temp 21 | 22 | def same_style(self, path1, path2): 23 | t1 = path1.split('__') 24 | t2 = path2.split('__') 25 | p1 = t1[-3] + '_' + t1[-2] 26 | p2 = t2[-3] + '_' + t2[-2] 27 | if (p1 == p2): 28 | return 1 29 | else: 30 | return 0 31 | 32 | def get_X(self, path, params, do_transfer = False): 33 | ### input A (label maps) 34 | if self.opt.use_new_label and do_transfer: 35 | path = self.transfer_path(path) 36 | A = Image.open(path) 37 | if self.opt.label_nc == 0: 38 | transform_A = get_transform(self.opt, params) 39 | A_tensor = transform_A(A.convert('RGB')) 40 | else: 41 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 42 | A_tensor = transform_A(A) * 255.0 43 | return A_tensor 44 | 45 | def get_X2(self, path, params): 46 | B = Image.open(path).convert('RGB') 47 | transform_B = get_transform(self.opt, params) 48 | B_tensor = transform_B(B) 49 | return B_tensor 50 | 51 | def __getitem__(self, index): 52 | paths = self.all_paths[index].rstrip('\n').split('&') 53 | A = Image.open(paths[0]) 54 | params = get_params(self.opt, A.size) 55 | 56 | A_tensor = self.get_X(paths[0], params) 57 | B_tensor = self.get_X(paths[1], params, True) 58 | B2_tensor = self.get_X2(paths[2], params) 59 | A2_tensor = self.get_X2(paths[3], params) 60 | C_tensor = C2_tensor = D_tensor = D2_tensor = 0 61 | if (self.opt.isTrain): 62 | C_tensor = self.get_X(paths[4], params) 63 | C2_tensor = self.get_X2(paths[5], params) 64 | D_tensor = self.get_X(paths[6], params) 65 | D2_tensor = self.get_X2(paths[7], params) 66 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor, 67 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])} 68 | return input_dict 69 | 70 | def __len__(self): 71 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize 72 | 73 | def name(self): 74 | return 'AlignedDataset2' 75 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import torch.utils.data as data 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import random 8 | 9 | class BaseDataset(data.Dataset): 10 | def __init__(self): 11 | super(BaseDataset, self).__init__() 12 | 13 | def name(self): 14 | return 'BaseDataset' 15 | 16 | def initialize(self, opt): 17 | pass 18 | 19 | def get_params(opt, size): 20 | w, h = size 21 | new_h = h 22 | new_w = w 23 | if opt.resize_or_crop == 'resize_and_crop': 24 | new_h = new_w = opt.loadSize 25 | elif opt.resize_or_crop == 'scale_width_and_crop': 26 | new_w = opt.loadSize 27 | new_h = opt.loadSize * h // w 28 | 29 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 30 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 31 | if (not opt.isTrain) and opt.resize_or_crop == 'scale_width_and_crop': 32 | x = np.maximum(0, new_w - opt.fineSize) / 2 33 | y = np.maximum(0, new_h - opt.fineSize) / 2 34 | 35 | flip = random.random() > 0.5 36 | return {'crop_pos': (x, y), 'flip': flip} 37 | 38 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 39 | transform_list = [] 40 | if 'resize' in opt.resize_or_crop: 41 | osize = [opt.loadSize, opt.loadSize] 42 | transform_list.append(transforms.Scale(osize, method)) 43 | elif 'scale_width' in opt.resize_or_crop: 44 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 45 | 46 | if 'crop' in opt.resize_or_crop: 47 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 48 | 49 | if opt.resize_or_crop == 'none': 50 | base = float(2 ** opt.n_downsample_global) 51 | if opt.netG == 'local': 52 | base *= (2 ** opt.n_local_enhancers) 53 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 54 | 55 | if opt.isTrain and not opt.no_flip: 56 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 57 | 58 | transform_list += [transforms.ToTensor()] 59 | 60 | if normalize: 61 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 62 | (0.5, 0.5, 0.5))] 63 | return transforms.Compose(transform_list) 64 | 65 | def normalize(): 66 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 67 | 68 | def __make_power_2(img, base, method=Image.BICUBIC): 69 | ow, oh = img.size 70 | h = int(round(oh / base) * base) 71 | w = int(round(ow / base) * base) 72 | if (h == oh) and (w == ow): 73 | return img 74 | return img.resize((w, h), method) 75 | 76 | def __scale_width(img, target_width, method=Image.BICUBIC): 77 | ow, oh = img.size 78 | if (ow == target_width): 79 | return img 80 | w = target_width 81 | h = int(target_width * oh / ow) 82 | return img.resize((w, h), method) 83 | 84 | def __crop(img, pos, size): 85 | ow, oh = img.size 86 | x1, y1 = pos 87 | tw = th = size 88 | if (ow > tw or oh > th): 89 | return img.crop((x1, y1, x1 + tw, y1 + th)) 90 | return img 91 | 92 | def __flip(img, flip): 93 | if flip: 94 | return img.transpose(Image.FLIP_LEFT_RIGHT) 95 | return img 96 | 97 | -------------------------------------------------------------------------------- /data/con_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset2 import AlignedDataset2 8 | dataset = AlignedDataset2() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | 
dataset.initialize(opt) 12 | return dataset 13 | 14 | class ConDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'ConDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset import AlignedDataset 8 | dataset = AlignedDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class CustomDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'CustomDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | def CreateConDataLoader(opt): 9 | from data.con_dataset_data_loader import ConDatasetDataLoader 10 | data_loader = ConDatasetDataLoader() 11 | print(data_loader.name()) 12 | data_loader.initialize(opt) 13 | return data_loader 14 | def CreateFaceConDataLoader(opt): 15 | from data.face_con_dataset_data_loader import FaceConDatasetDataLoader 16 | data_loader = FaceConDatasetDataLoader() 17 | print(data_loader.name()) 18 | data_loader.initialize(opt) 19 | return data_loader 20 | def CreatePoseConDataLoader(opt): 21 | from data.pose_con_dataset_data_loader import PoseConDatasetDataLoader 22 | data_loader = PoseConDatasetDataLoader() 23 | print(data_loader.name()) 24 | data_loader.initialize(opt) 25 | return data_loader 26 | -------------------------------------------------------------------------------- /data/face_con_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.face_dataset import FaceDataset 8 | dataset = FaceDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class FaceConDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'FaceConDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | 
BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /data/face_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import cv2 7 | from skimage import feature 8 | 9 | from data.base_dataset import BaseDataset, get_transform, get_params 10 | from data.keypoint2img import interpPoints, drawEdge 11 | import random 12 | 13 | class FaceDataset(BaseDataset): 14 | def initialize(self, opt): 15 | self.opt = opt 16 | self.root = opt.dataroot 17 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 18 | f = open(data_list_path, 'r') 19 | self.all_paths = f.readlines() 20 | self.dataset_size = len(self.all_paths) 21 | if (not opt.isTrain and opt.serial_batches): 22 | ff = open(opt.test_delta_path, "r") 23 | t = int(ff.readlines()[0]) 24 | self.delta = t 25 | else: 26 | self.delta = 0 27 | 28 | def get_X(self, path, params, B_size, B_img): 29 | transform_scaleA = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False) 30 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 31 | Ai, Li = self.get_face_image(path, transform_scaleA, transform_label, B_size, B_img) 32 | return Ai 33 | 34 | def get_X2(self, path, params): 35 | B = Image.open(path).convert('RGB') 36 | transform_B = get_transform(self.opt, params) 37 | B_tensor = transform_B(B) 38 | return B_tensor, B 39 | 40 | def same_style(self, path1, path2): 41 | t1 = path1.split('__') 42 | t2 = path2.split('__') 43 | p1 = t1[-3] + '_' + t1[-2] 44 | p2 = t2[-3] + '_' + t2[-2] 45 | if (p1 == p2): 46 | return 1 47 | else: 48 | return 0 49 | 50 | def __getitem__(self, index): 51 | index = (index + self.delta) % len(self.all_paths) 52 | paths = self.all_paths[index].rstrip('\n').split('&') 53 | A = Image.open(paths[2]) 54 | params = get_params(self.opt, A.size) 55 | 56 | B2_tensor, B = self.get_X2(paths[2], params) 57 | A2_tensor, A = self.get_X2(paths[3], params) 58 | A_tensor = self.get_X(paths[0], params, A.size, A) 59 | B_tensor = self.get_X(paths[1], params, A.size, B) 60 | C_tensor = C2_tensor = D_tensor = D2_tensor = 0 61 | if (self.opt.isTrain): 62 | C2_tensor, C = self.get_X2(paths[5], params) 63 | C_tensor = self.get_X(paths[4], params, A.size, C) 64 | D2_tensor, D = self.get_X2(paths[7], params) 65 | D_tensor = self.get_X(paths[6], params, A.size, D) 66 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor, 67 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])} 68 | return input_dict 69 | 70 | ''' 71 | def __getitem__(self, index): 72 | A, B, I, seq_idx = self.update_frame_idx(self.A_paths, index) 73 | A_paths = self.A_paths[seq_idx] 74 | B_paths = self.B_paths[seq_idx] 75 | n_frames_total, start_idx, t_step = get_video_params(self.opt, self.n_frames_total, len(A_paths), self.frame_idx) 76 | 77 | B_img = Image.open(B_paths[0]).convert('RGB') 78 | B_size 
= B_img.size 79 | points = np.loadtxt(A_paths[0], delimiter=',') 80 | is_first_frame = self.opt.isTrain or not hasattr(self, 'min_x') 81 | if is_first_frame: # crop only the face region 82 | self.get_crop_coords(points, B_size) 83 | params = get_img_params(self.opt, self.crop(B_img).size) 84 | transform_scaleA = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False) 85 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 86 | transform_scaleB = get_transform(self.opt, params) 87 | 88 | # read in images 89 | frame_range = list(range(n_frames_total)) if self.A is None else [self.opt.n_frames_G-1] 90 | for i in frame_range: 91 | A_path = A_paths[start_idx + i * t_step] 92 | B_path = B_paths[start_idx + i * t_step] 93 | B_img = Image.open(B_path) 94 | Ai, Li = self.get_face_image(A_path, transform_scaleA, transform_label, B_size, B_img) 95 | Bi = transform_scaleB(self.crop(B_img)) 96 | A = concat_frame(A, Ai, n_frames_total) 97 | B = concat_frame(B, Bi, n_frames_total) 98 | I = concat_frame(I, Li, n_frames_total) 99 | 100 | if not self.opt.isTrain: 101 | self.A, self.B, self.I = A, B, I 102 | self.frame_idx += 1 103 | change_seq = False if self.opt.isTrain else self.change_seq 104 | return_list = {'A': A, 'B': B, 'inst': I, 'A_path': A_path, 'change_seq': change_seq} 105 | 106 | return return_list 107 | ''' 108 | def get_image(self, A_path, transform_scaleA): 109 | A_img = Image.open(A_path) 110 | A_scaled = transform_scaleA(self.crop(A_img)) 111 | return A_scaled 112 | 113 | def get_face_image(self, A_path, transform_A, transform_L, size, img): 114 | # read face keypoints from path and crop face region 115 | keypoints, part_list, part_labels = self.read_keypoints(A_path, size) 116 | 117 | # draw edges and possibly add distance transform maps 118 | add_dist_map = not self.opt.no_dist_map 119 | im_edges, dist_tensor = self.draw_face_edges(keypoints, part_list, transform_A, size, add_dist_map) 120 | 121 | # canny edge for background 122 | if not self.opt.no_canny_edge: 123 | edges = feature.canny(np.array(img.convert('L'))) 124 | edges = edges * (part_labels == 0) # remove edges within face 125 | im_edges += (edges * 255).astype(np.uint8) 126 | edge_tensor = transform_A(Image.fromarray(self.crop(im_edges))) 127 | 128 | # final input tensor 129 | input_tensor = torch.cat([edge_tensor, dist_tensor]) if add_dist_map else edge_tensor 130 | label_tensor = transform_L(Image.fromarray(self.crop(part_labels.astype(np.uint8)))) * 255.0 131 | return input_tensor, label_tensor 132 | 133 | def read_keypoints(self, A_path, size): 134 | # mapping from keypoints to face part 135 | part_list = [[list(range(0, 17)) + list(range(68, 83)) + [0]], # face 136 | [range(17, 22)], # right eyebrow 137 | [range(22, 27)], # left eyebrow 138 | [[28, 31], range(31, 36), [35, 28]], # nose 139 | [[36,37,38,39], [39,40,41,36]], # right eye 140 | [[42,43,44,45], [45,46,47,42]], # left eye 141 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth 142 | [range(60, 65), [64,65,66,67,60]] # tongue 143 | ] 144 | label_list = [1, 2, 2, 3, 4, 4, 5, 6] # labeling for different facial parts 145 | keypoints = np.loadtxt(A_path, delimiter=',') 146 | 147 | # add upper half face by symmetry 148 | pts = keypoints[:17, :].astype(np.int32) 149 | baseline_y = (pts[0,1] + pts[-1,1]) / 2 150 | upper_pts = pts[1:-1,:].copy() 151 | upper_pts[:,1] = baseline_y + (baseline_y-upper_pts[:,1]) * 2 // 3 152 | keypoints = np.vstack((keypoints, upper_pts[::-1,:])) 153 | 154 | # label map for facial part 
155 | w, h = size 156 | part_labels = np.zeros((h, w), np.uint8) 157 | for p, edge_list in enumerate(part_list): 158 | indices = [item for sublist in edge_list for item in sublist] 159 | pts = keypoints[indices, :].astype(np.int32) 160 | cv2.fillPoly(part_labels, pts=[pts], color=label_list[p]) 161 | 162 | return keypoints, part_list, part_labels 163 | 164 | def draw_face_edges(self, keypoints, part_list, transform_A, size, add_dist_map): 165 | w, h = size 166 | edge_len = 3 # interpolate 3 keypoints to form a curve when drawing edges 167 | # edge map for face region from keypoints 168 | im_edges = np.zeros((h, w), np.uint8) # edge map for all edges 169 | dist_tensor = 0 170 | e = 1 171 | for edge_list in part_list: 172 | for edge in edge_list: 173 | im_edge = np.zeros((h, w), np.uint8) # edge map for the current edge 174 | for i in range(0, max(1, len(edge)-1), edge_len-1): # divide a long edge into multiple small edges when drawing 175 | sub_edge = edge[i:i+edge_len] 176 | x = keypoints[sub_edge, 0] 177 | y = keypoints[sub_edge, 1] 178 | 179 | curve_x, curve_y = interpPoints(x, y) # interp keypoints to get the curve shape 180 | drawEdge(im_edges, curve_x, curve_y) 181 | if add_dist_map: 182 | drawEdge(im_edge, curve_x, curve_y) 183 | 184 | if add_dist_map: # add distance transform map on each facial part 185 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3) 186 | im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8) 187 | im_dist = Image.fromarray(im_dist) 188 | tensor_cropped = transform_A(self.crop(im_dist)) 189 | dist_tensor = tensor_cropped if e == 1 else torch.cat([dist_tensor, tensor_cropped]) 190 | e += 1 191 | 192 | return im_edges, dist_tensor 193 | 194 | def get_crop_coords(self, keypoints, size): 195 | min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max() 196 | min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max() 197 | offset = (max_x - min_x) // 2 198 | min_y = max(0, min_y - offset*2) 199 | min_x = max(0, min_x - offset) 200 | max_x = min(size[0], max_x + offset) 201 | max_y = min(size[1], max_y + offset) 202 | self.min_y, self.max_y, self.min_x, self.max_x = int(min_y), int(max_y), int(min_x), int(max_x) 203 | 204 | def crop(self, img): 205 | return img 206 | #??? 
207 | if isinstance(img, np.ndarray): 208 | return img[self.min_y:self.max_y, self.min_x:self.max_x] 209 | else: 210 | return img.crop((self.min_x, self.min_y, self.max_x, self.max_y)) 211 | 212 | def __len__(self): 213 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize 214 | 215 | def name(self): 216 | return 'FaceDataset' 217 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import os 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 14 | ] 15 | 16 | 17 | def is_image_file(filename): 18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 19 | 20 | 21 | def make_dataset(dir): 22 | images = [] 23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 24 | 25 | for root, _, fnames in sorted(os.walk(dir)): 26 | for fname in fnames: 27 | if is_image_file(fname): 28 | path = os.path.join(root, fname) 29 | images.append(path) 30 | 31 | return images 32 | 33 | 34 | def default_loader(path): 35 | return Image.open(path).convert('RGB') 36 | 37 | 38 | class ImageFolder(data.Dataset): 39 | 40 | def __init__(self, root, transform=None, return_paths=False, 41 | loader=default_loader): 42 | imgs = make_dataset(root) 43 | if len(imgs) == 0: 44 | raise(RuntimeError("Found 0 images in: " + root + "\n" 45 | "Supported image extensions are: " + 46 | ",".join(IMG_EXTENSIONS))) 47 | 48 | self.root = root 49 | self.imgs = imgs 50 | self.transform = transform 51 | self.return_paths = return_paths 52 | self.loader = loader 53 | 54 | def __getitem__(self, index): 55 | path = self.imgs[index] 56 | img = self.loader(path) 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | if self.return_paths: 60 | return img, path 61 | else: 62 | return img 63 | 64 | def __len__(self): 65 | return len(self.imgs) 66 | -------------------------------------------------------------------------------- /data/keypoint2img.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from PIL import Image 3 | import numpy as np 4 | import json 5 | import glob 6 | from scipy.optimize import curve_fit 7 | import cv2 8 | import torch 9 | import warnings 10 | 11 | def func(x, a, b, c): 12 | return a * x**2 + b * x + c 13 | 14 | def linear(x, a, b): 15 | return a * x + b 16 | 17 | def setColor(im, yy, xx, color): 18 | if len(im.shape) == 3: 19 | if (im[yy, xx] == 0).all(): 20 | im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = color[0], color[1], color[2] 21 | else: 22 | im[yy, xx, 0] = ((im[yy, xx, 0].astype(float) + color[0]) / 2).astype(np.uint8) 23 | im[yy, xx, 1] = ((im[yy, xx, 1].astype(float) + color[1]) / 2).astype(np.uint8) 24 | im[yy, xx, 2] = ((im[yy, xx, 2].astype(float) + color[2]) / 2).astype(np.uint8) 25 | else: 26 | im[yy, xx] = color[0] 27 | 28 | def drawEdge(im, x, y, bw=1, color=(255,255,255), draw_end_points=False): 29 | if x is not None and 
x.size: 30 | h, w = im.shape[0], im.shape[1] 31 | # edge 32 | for i in range(-bw, bw): 33 | for j in range(-bw, bw): 34 | yy = np.maximum(0, np.minimum(h-1, y+i)) 35 | xx = np.maximum(0, np.minimum(w-1, x+j)) 36 | setColor(im, yy, xx, color) 37 | 38 | # edge endpoints 39 | if draw_end_points: 40 | for i in range(-bw*2, bw*2): 41 | for j in range(-bw*2, bw*2): 42 | if (i**2) + (j**2) < (4 * bw**2): 43 | yy = np.maximum(0, np.minimum(h-1, np.array([y[0], y[-1]])+i)) 44 | xx = np.maximum(0, np.minimum(w-1, np.array([x[0], x[-1]])+j)) 45 | setColor(im, yy, xx, color) 46 | 47 | def interpPoints(x, y): 48 | if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max(): 49 | curve_y, curve_x = interpPoints(y, x) 50 | if curve_y is None: 51 | return None, None 52 | else: 53 | with warnings.catch_warnings(): 54 | warnings.simplefilter("ignore") 55 | if len(x) < 3: 56 | popt, _ = curve_fit(linear, x, y) 57 | else: 58 | popt, _ = curve_fit(func, x, y) 59 | if abs(popt[0]) > 1: 60 | return None, None 61 | if x[0] > x[-1]: 62 | x = list(reversed(x)) 63 | y = list(reversed(y)) 64 | curve_x = np.linspace(x[0], x[-1], (x[-1]-x[0])) 65 | if len(x) < 3: 66 | curve_y = linear(curve_x, *popt) 67 | else: 68 | curve_y = func(curve_x, *popt) 69 | return curve_x.astype(int), curve_y.astype(int) 70 | 71 | def read_keypoints(json_input, size, random_drop_prob=0, remove_face_labels=False): 72 | with open(json_input, encoding='utf-8') as f: 73 | keypoint_dicts = json.loads(f.read())["people"] 74 | 75 | edge_lists = define_edge_lists() 76 | w, h = size 77 | pose_img = np.zeros((h, w, 3), np.uint8) 78 | for keypoint_dict in keypoint_dicts: 79 | pose_pts = np.array(keypoint_dict["pose_keypoints_2d"]).reshape(25, 3) 80 | face_pts = np.array(keypoint_dict["face_keypoints_2d"]).reshape(70, 3) 81 | hand_pts_l = np.array(keypoint_dict["hand_left_keypoints_2d"]).reshape(21, 3) 82 | hand_pts_r = np.array(keypoint_dict["hand_right_keypoints_2d"]).reshape(21, 3) 83 | pts = [extract_valid_keypoints(pts, edge_lists) for pts in [pose_pts, face_pts, hand_pts_l, hand_pts_r]] 84 | pose_img += connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels) 85 | return pose_img 86 | 87 | def read_keypoints2(json_input, size, transform_A, random_drop_prob=0, remove_face_labels=False): 88 | with open(json_input) as f: 89 | keypoint_dicts = json.loads(f.read()) 90 | 91 | edge_lists = define_edge_lists2() 92 | w, h = size 93 | pose_pts = np.array(keypoint_dicts).reshape(18, 3) 94 | pts = extract_valid_keypoints2(pose_pts, edge_lists) 95 | pose_img, dist_tensor = connect_keypoints2(pts, edge_lists, size, random_drop_prob, remove_face_labels, transform_A) 96 | return pose_img, dist_tensor 97 | 98 | def extract_valid_keypoints2(pts, edge_lists): 99 | pose_edge_list, _ = edge_lists 100 | p = pts.shape[0] 101 | thre = 0.1 if p == 70 else 0.01 102 | output = np.zeros((p, 2)) 103 | 104 | valid = (pts[:, 2] > thre) 105 | output[valid, :] = pts[valid, :2] 106 | 107 | return output 108 | 109 | def extract_valid_keypoints(pts, edge_lists): 110 | pose_edge_list, _, hand_edge_list, _, face_list = edge_lists 111 | p = pts.shape[0] 112 | thre = 0.1 if p == 70 else 0.01 113 | output = np.zeros((p, 2)) 114 | 115 | if p == 70: # face 116 | for edge_list in face_list: 117 | for edge in edge_list: 118 | if (pts[edge, 2] > thre).all(): 119 | output[edge, :] = pts[edge, :2] 120 | elif p == 21: # hand 121 | for edge in hand_edge_list: 122 | if (pts[edge, 2] > thre).all(): 123 | output[edge, :] = pts[edge, :2] 124 | else: # pose 125 | valid = (pts[:, 2] 
> thre) 126 | output[valid, :] = pts[valid, :2] 127 | 128 | return output 129 | 130 | def connect_keypoints2(pts, edge_lists, size, random_drop_prob, remove_face_labels, transform_A): 131 | pose_pts = pts 132 | w, h = size 133 | output_edges = np.zeros((h, w, 3), np.uint8) 134 | pose_edge_list, pose_color_list = edge_lists 135 | dist_tensor = 0 136 | e = 1 137 | 138 | ### pose 139 | for i, edge in enumerate(pose_edge_list): 140 | im_edge = np.zeros((h, w), np.uint8) 141 | x, y = pose_pts[edge, 0], pose_pts[edge, 1] 142 | if (np.random.rand() > random_drop_prob) and (0 not in x): 143 | curve_x, curve_y = interpPoints(x, y) 144 | drawEdge(output_edges, curve_x, curve_y, bw=2, color=pose_color_list[i], draw_end_points=True) 145 | drawEdge(im_edge, curve_x, curve_y) 146 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3) 147 | im_dist = np.clip((im_dist / 2), 0, 255).astype(np.uint8) 148 | im_dist = Image.fromarray(im_dist) 149 | tensor_cropped = transform_A(im_dist) 150 | dist_tensor = tensor_cropped if e == 1 else torch.cat((dist_tensor, tensor_cropped), 0) 151 | e += 1 152 | return Image.fromarray(output_edges), dist_tensor 153 | 154 | def connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels): 155 | pose_pts, face_pts, hand_pts_l, hand_pts_r = pts 156 | w, h = size 157 | output_edges = np.zeros((h, w, 3), np.uint8) 158 | pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list = edge_lists 159 | 160 | if random_drop_prob > 0 and remove_face_labels: 161 | # add random noise to keypoints 162 | pose_pts[[0,15,16,17,18], :] += 5 * np.random.randn(5,2) 163 | face_pts[:,0] += 2 * np.random.randn() 164 | face_pts[:,1] += 2 * np.random.randn() 165 | 166 | ### pose 167 | for i, edge in enumerate(pose_edge_list): 168 | x, y = pose_pts[edge, 0], pose_pts[edge, 1] 169 | if (np.random.rand() > random_drop_prob) and (0 not in x): 170 | curve_x, curve_y = interpPoints(x, y) 171 | drawEdge(output_edges, curve_x, curve_y, bw=3, color=pose_color_list[i], draw_end_points=True) 172 | 173 | ### hand 174 | for hand_pts in [hand_pts_l, hand_pts_r]: # for left and right hand 175 | if np.random.rand() > random_drop_prob: 176 | for i, edge in enumerate(hand_edge_list): # for each finger 177 | for j in range(0, len(edge)-1): # for each part of the finger 178 | sub_edge = edge[j:j+2] 179 | x, y = hand_pts[sub_edge, 0], hand_pts[sub_edge, 1] 180 | if 0 not in x: 181 | line_x, line_y = interpPoints(x, y) 182 | drawEdge(output_edges, line_x, line_y, bw=1, color=hand_color_list[i], draw_end_points=True) 183 | 184 | ### face 185 | edge_len = 2 186 | if (np.random.rand() > random_drop_prob): 187 | for edge_list in face_list: 188 | for edge in edge_list: 189 | for i in range(0, max(1, len(edge)-1), edge_len-1): 190 | sub_edge = edge[i:i+edge_len] 191 | x, y = face_pts[sub_edge, 0], face_pts[sub_edge, 1] 192 | if 0 not in x: 193 | curve_x, curve_y = interpPoints(x, y) 194 | drawEdge(output_edges, curve_x, curve_y, draw_end_points=True) 195 | 196 | return output_edges 197 | 198 | def define_edge_lists2(): 199 | ### pose 200 | pose_edge_list = [ 201 | [0, 1], 202 | [14, 16], [15, 17], 203 | [14, 0], [15, 0], 204 | [1, 2], [1, 5], 205 | [1, 8], [1, 11], 206 | [8, 9], [9, 10], 207 | [11, 12], [12, 13], 208 | [2, 3], [3, 4], 209 | [5, 6], [6, 7] 210 | ] 211 | pose_edge_list = [ 212 | [0, 1], 213 | [2, 3], [3, 4], 214 | [5, 6], [6, 7], 215 | [8, 9], [9, 10], 216 | [11, 12], [12, 13], 217 | [14, 16], [15, 17], 218 | [14, 0], [15, 0], 219 | [1, 2], [1, 5], 220 | [1, 8], [1, 11] 
221 | ] 222 | pose_color_list = [ 223 | [153, 0,153], [153, 0,102], [102, 0,153], [ 51, 0,153], [153, 0, 51], 224 | [153, 0, 0], 225 | [153, 51, 0], [153,102, 0], [153,153, 0], 226 | [102,153, 0], [ 51,153, 0], [ 0,153, 0], 227 | [ 0,153, 51], [ 0,153,102], [ 0,153,153], [ 0,102,153], [ 0,51,153], [ 0,0,153], 228 | [ 0,102,153], [ 0, 51,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153] 229 | ] 230 | 231 | return pose_edge_list, pose_color_list 232 | 233 | def define_edge_lists(): 234 | ### pose 235 | pose_edge_list = [ 236 | [17, 15], [15, 0], [ 0, 16], [16, 18], [ 0, 1], # head 237 | [ 1, 8], # body 238 | [ 1, 2], [ 2, 3], [ 3, 4], # right arm 239 | [ 1, 5], [ 5, 6], [ 6, 7], # left arm 240 | [ 8, 9], [ 9, 10], [10, 11], [11, 24], [11, 22], [22, 23], # right leg 241 | [ 8, 12], [12, 13], [13, 14], [14, 21], [14, 19], [19, 20] # left leg 242 | ] 243 | pose_color_list = [ 244 | [153, 0,153], [153, 0,102], [102, 0,153], [ 51, 0,153], [153, 0, 51], 245 | [153, 0, 0], 246 | [153, 51, 0], [153,102, 0], [153,153, 0], 247 | [102,153, 0], [ 51,153, 0], [ 0,153, 0], 248 | [ 0,153, 51], [ 0,153,102], [ 0,153,153], [ 0,153,153], [ 0,153,153], [ 0,153,153], 249 | [ 0,102,153], [ 0, 51,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153] 250 | ] 251 | 252 | ### hand 253 | hand_edge_list = [ 254 | [0, 1, 2, 3, 4], 255 | [0, 5, 6, 7, 8], 256 | [0, 9, 10, 11, 12], 257 | [0, 13, 14, 15, 16], 258 | [0, 17, 18, 19, 20] 259 | ] 260 | hand_color_list = [ 261 | [204,0,0], [163,204,0], [0,204,82], [0,82,204], [163,0,204] 262 | ] 263 | 264 | ### face 265 | face_list = [ 266 | #[range(0, 17)], # face 267 | [range(17, 22)], # left eyebrow 268 | [range(22, 27)], # right eyebrow 269 | [range(27, 31), range(31, 36)], # nose 270 | [[36,37,38,39], [39,40,41,36]], # left eye 271 | [[42,43,44,45], [45,46,47,42]], # right eye 272 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth 273 | ] 274 | return pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list 275 | -------------------------------------------------------------------------------- /data/pose_con_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.pose_dataset import PoseDataset 8 | dataset = PoseDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class PoseConDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'PoseConDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /data/pose_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import cv2 7 | from skimage import feature 8 | 9 | from data.base_dataset import BaseDataset, get_transform, get_params 10 | from data.keypoint2img 
import interpPoints, drawEdge, read_keypoints2 11 | import random 12 | 13 | class PoseDataset(BaseDataset): 14 | def initialize(self, opt): 15 | self.opt = opt 16 | self.root = opt.dataroot 17 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 18 | f = open(data_list_path, 'r') 19 | self.all_paths = f.readlines() 20 | self.dataset_size = len(self.all_paths) 21 | if (not opt.isTrain and opt.serial_batches): 22 | ff = open(opt.test_delta_path, "r") 23 | t = int(ff.readlines()[0]) 24 | self.delta = t 25 | else: 26 | self.delta = 0 27 | 28 | 29 | def get_image2(self, A_path, size, params, input_type): 30 | is_img = input_type == 'img' 31 | method = Image.BICUBIC if is_img else Image.NEAREST 32 | transform_scaleA = get_transform(self.opt, params, normalize=is_img, method=method) 33 | if input_type != 'openpose': 34 | A_img = Image.open(A_path).convert('RGB') 35 | else: 36 | random_drop_prob = self.opt.random_drop_prob if self.opt.isTrain else 0 37 | A_img, dist_tensor = read_keypoints2(A_path, size, transform_scaleA, random_drop_prob, False) 38 | 39 | if input_type == 'densepose' and self.opt.isTrain: 40 | # randomly remove labels 41 | A_np = np.array(A_img) 42 | part_labels = A_np[:,:,2] 43 | for part_id in range(1, 25): 44 | if (np.random.rand() < self.opt.random_drop_prob): 45 | A_np[(part_labels == part_id), :] = 0 46 | if self.opt.remove_face_labels: 47 | A_np[(part_labels == 23) | (part_labels == 24), :] = 0 48 | A_img = Image.fromarray(A_np) 49 | 50 | A_scaled = transform_scaleA(A_img) 51 | if (input_type == 'openpose' and self.opt.do_pose_dist_map): 52 | A_scaled = torch.cat([A_scaled, dist_tensor]) 53 | return A_scaled 54 | 55 | def get_X(self, path, params, B_size, B_tensor): 56 | paths = path.split('|') 57 | if not self.opt.openpose_only: 58 | Di = self.get_image2(paths[0], B_size, params, input_type='densepose') 59 | Di[2,:,:] = Di[2,:,:] * 255 / 24 60 | if not self.opt.densepose_only: 61 | Oi = self.get_image2(paths[1], B_size, params, input_type='openpose') 62 | 63 | if self.opt.openpose_only: 64 | Ai = Oi 65 | elif self.opt.densepose_only: 66 | Ai = Di 67 | else: 68 | Ai = torch.cat([Di, Oi]) 69 | valid = (Di[0, :, :] > 0) + (Di[1, :, :] > 0) + (Di[2, :, :] > 0) 70 | valid_ = valid == 0 71 | if (self.opt.add_mask): 72 | B_tensor[:, valid_] = 0 73 | ''' 74 | valid_ = (valid > 0)[np.newaxis, :, :] 75 | valid__ = torch.cat([valid_, valid_, valid_]) 76 | print(valid__[:, 128, 128]) 77 | B_out = B_tensor * valid__ 78 | ''' 79 | #Ai, Bi = self.crop(Ai), self.crop(Bi) # only crop the central half region to save time 80 | return Ai, B_tensor 81 | 82 | def same_style(self, path1, path2): 83 | t1 = path1.split('__') 84 | t2 = path2.split('__') 85 | p1 = t1[-3] + '_' + t1[-2] 86 | p2 = t2[-3] + '_' + t2[-2] 87 | if (p1 == p2): 88 | return 1 89 | else: 90 | return 0 91 | 92 | def get_X2(self, path, params): 93 | B = Image.open(path).convert('RGB') 94 | transform_B = get_transform(self.opt, params) 95 | B_tensor = transform_B(B) 96 | return B_tensor, B 97 | 98 | def __getitem__(self, index): 99 | index = (index + self.delta) % len(self.all_paths) 100 | paths = self.all_paths[index].rstrip('\n').split('&') 101 | A = Image.open(paths[2]) 102 | params = get_params(self.opt, A.size) 103 | 104 | B2_tensor_, B = self.get_X2(paths[2], params) 105 | A2_tensor_, A = self.get_X2(paths[3], params) 106 | A_tensor, A2_tensor = self.get_X(paths[0], params, A.size, A2_tensor_) 107 | B_tensor, B2_tensor = self.get_X(paths[1], params, A.size, B2_tensor_) 108 | C_tensor = C2_tensor = 
D_tensor = D2_tensor = 0 109 | if (self.opt.isTrain): 110 | C2_tensor_, C = self.get_X2(paths[5], params) 111 | C_tensor, C2_tensor = self.get_X(paths[4], params, A.size, C2_tensor_) 112 | D2_tensor_, D = self.get_X2(paths[7], params) 113 | D_tensor, D2_tensor = self.get_X(paths[6], params, A.size, D2_tensor_) 114 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor, 115 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])} 116 | return input_dict 117 | 118 | def get_image(self, A_path, transform_scaleA): 119 | A_img = Image.open(A_path) 120 | A_scaled = transform_scaleA(self.crop(A_img)) 121 | return A_scaled 122 | 123 | def get_face_image(self, A_path, transform_A, transform_L, size, img): 124 | # read face keypoints from path and crop face region 125 | keypoints, part_list, part_labels = self.read_keypoints(A_path, size) 126 | 127 | # draw edges and possibly add distance transform maps 128 | add_dist_map = not self.opt.no_dist_map 129 | im_edges, dist_tensor = self.draw_face_edges(keypoints, part_list, transform_A, size, add_dist_map) 130 | 131 | # canny edge for background 132 | if not self.opt.no_canny_edge: 133 | edges = feature.canny(np.array(img.convert('L'))) 134 | edges = edges * (part_labels == 0) # remove edges within face 135 | im_edges += (edges * 255).astype(np.uint8) 136 | edge_tensor = transform_A(Image.fromarray(self.crop(im_edges))) 137 | 138 | # final input tensor 139 | input_tensor = torch.cat([edge_tensor, dist_tensor]) if add_dist_map else edge_tensor 140 | label_tensor = transform_L(Image.fromarray(self.crop(part_labels.astype(np.uint8)))) * 255.0 141 | return input_tensor, label_tensor 142 | 143 | def read_keypoints(self, A_path, size): 144 | # mapping from keypoints to face part 145 | part_list = [[list(range(0, 17)) + list(range(68, 83)) + [0]], # face 146 | [range(17, 22)], # right eyebrow 147 | [range(22, 27)], # left eyebrow 148 | [[28, 31], range(31, 36), [35, 28]], # nose 149 | [[36,37,38,39], [39,40,41,36]], # right eye 150 | [[42,43,44,45], [45,46,47,42]], # left eye 151 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth 152 | [range(60, 65), [64,65,66,67,60]] # tongue 153 | ] 154 | label_list = [1, 2, 2, 3, 4, 4, 5, 6] # labeling for different facial parts 155 | keypoints = np.loadtxt(A_path, delimiter=',') 156 | 157 | # add upper half face by symmetry 158 | pts = keypoints[:17, :].astype(np.int32) 159 | baseline_y = (pts[0,1] + pts[-1,1]) / 2 160 | upper_pts = pts[1:-1,:].copy() 161 | upper_pts[:,1] = baseline_y + (baseline_y-upper_pts[:,1]) * 2 // 3 162 | keypoints = np.vstack((keypoints, upper_pts[::-1,:])) 163 | 164 | # label map for facial part 165 | w, h = size 166 | part_labels = np.zeros((h, w), np.uint8) 167 | for p, edge_list in enumerate(part_list): 168 | indices = [item for sublist in edge_list for item in sublist] 169 | pts = keypoints[indices, :].astype(np.int32) 170 | cv2.fillPoly(part_labels, pts=[pts], color=label_list[p]) 171 | 172 | return keypoints, part_list, part_labels 173 | 174 | def draw_face_edges(self, keypoints, part_list, transform_A, size, add_dist_map): 175 | w, h = size 176 | edge_len = 3 # interpolate 3 keypoints to form a curve when drawing edges 177 | # edge map for face region from keypoints 178 | im_edges = np.zeros((h, w), np.uint8) # edge map for all edges 179 | dist_tensor = 0 180 | e = 1 181 | for edge_list in part_list: 182 | for edge in edge_list: 183 | im_edge = np.zeros((h, w), np.uint8) # 
edge map for the current edge 184 | for i in range(0, max(1, len(edge)-1), edge_len-1): # divide a long edge into multiple small edges when drawing 185 | sub_edge = edge[i:i+edge_len] 186 | x = keypoints[sub_edge, 0] 187 | y = keypoints[sub_edge, 1] 188 | 189 | curve_x, curve_y = interpPoints(x, y) # interp keypoints to get the curve shape 190 | drawEdge(im_edges, curve_x, curve_y) 191 | if add_dist_map: 192 | drawEdge(im_edge, curve_x, curve_y) 193 | 194 | if add_dist_map: # add distance transform map on each facial part 195 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3) 196 | im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8) 197 | im_dist = Image.fromarray(im_dist) 198 | tensor_cropped = transform_A(self.crop(im_dist)) 199 | dist_tensor = tensor_cropped if e == 1 else torch.cat([dist_tensor, tensor_cropped]) 200 | e += 1 201 | 202 | return im_edges, dist_tensor 203 | 204 | def get_crop_coords(self, keypoints, size): 205 | min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max() 206 | min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max() 207 | offset = (max_x - min_x) // 2 208 | min_y = max(0, min_y - offset*2) 209 | min_x = max(0, min_x - offset) 210 | max_x = min(size[0], max_x + offset) 211 | max_y = min(size[1], max_y + offset) 212 | self.min_y, self.max_y, self.min_x, self.max_x = int(min_y), int(max_y), int(min_x), int(max_x) 213 | 214 | def crop(self, img): 215 | return img 216 | #??? 217 | if isinstance(img, np.ndarray): 218 | return img[self.min_y:self.max_y, self.min_x:self.max_x] 219 | else: 220 | return img.crop((self.min_x, self.min_y, self.max_x, self.max_y)) 221 | 222 | def __len__(self): 223 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize 224 | 225 | def name(self): 226 | return 'PoseDataset' 227 | -------------------------------------------------------------------------------- /encode_features.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | from options.train_options import TrainOptions 4 | from data.data_loader import CreateDataLoader 5 | from models.models import create_model 6 | import numpy as np 7 | import os 8 | 9 | opt = TrainOptions().parse() 10 | opt.nThreads = 1 11 | opt.batchSize = 1 12 | opt.serial_batches = True 13 | opt.no_flip = True 14 | opt.instance_feat = True 15 | 16 | name = 'features' 17 | save_path = os.path.join(opt.checkpoints_dir, opt.name) 18 | 19 | ############ Initialize ######### 20 | data_loader = CreateDataLoader(opt) 21 | dataset = data_loader.load_data() 22 | dataset_size = len(data_loader) 23 | model = create_model(opt) 24 | 25 | ########### Encode features ########### 26 | reencode = True 27 | if reencode: 28 | features = {} 29 | for label in range(opt.label_nc): 30 | features[label] = np.zeros((0, opt.feat_num+1)) 31 | for i, data in enumerate(dataset): 32 | feat = model.module.encode_features(data['image'], data['inst']) 33 | for label in range(opt.label_nc): 34 | features[label] = np.append(features[label], feat[label], axis=0) 35 | 36 | print('%d / %d images' % (i+1, dataset_size)) 37 | save_name = os.path.join(save_path, name + '.npy') 38 | np.save(save_name, features) 39 | 40 | ############## Clustering ########### 41 | n_clusters = opt.n_clusters 42 | load_name = os.path.join(save_path, name + '.npy') 43 | features = np.load(load_name).item() 44 | from sklearn.cluster import KMeans 45 | centers = {} 46 | for label in range(opt.label_nc): 47 | feat = features[label] 48 | feat = feat[feat[:,-1] > 0.5, :-1] 49 | if feat.shape[0]: 50 | n_clusters = min(feat.shape[0], opt.n_clusters) 51 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat) 52 | centers[label] = kmeans.cluster_centers_ 53 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters) 54 | np.save(save_name, centers) 55 | print('saving to %s' % save_name) -------------------------------------------------------------------------------- /expand_val_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | def expand(path): 3 | f = open(path, "r") 4 | fs = f.readlines() 5 | ans = [] 6 | for line in fs: 7 | paths = line.rstrip('\n').split('&') 8 | new = [paths[0], paths[1], paths[2], paths[3], paths[1], paths[2], paths[1], paths[2]] 9 | ans.append('&'.join(new) + '\n') 10 | fw = open(path + '_', "w") 11 | fw.writelines(ans) 12 | path = 'datasets/YouTubeFaces_/' 13 | expand(os.path.join(path, 'val_list.txt')) 14 | expand(os.path.join(path, 'test_list.txt')) 15 | -------------------------------------------------------------------------------- /generate_data_face_forensics.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
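### Builds the face dataset: detects faces and 68-point landmarks with dlib, rejects
### side-view faces, crops and resizes every 10th frame of each clip to 512x512, saves
### images and landmark .txt files under target_path, and writes '&'-joined
### train/val/test list files pairing each frame with same-style (nearby) and
### different-style reference frames.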
3 | import os 4 | from collections import OrderedDict 5 | from options.generate_data_options import GenerateDataOptions 6 | from data.data_loader import CreateDataLoader 7 | from PIL import Image 8 | import copy 9 | import numpy as np 10 | from shutil import copyfile 11 | import random 12 | import glob 13 | from skimage import transform,io 14 | import dlib 15 | import sys 16 | import time 17 | 18 | threshold = 130 19 | crop_rate = 0.6 20 | face_rate = 0.5 21 | up_center_rate = 0.2 22 | out_size = 512 23 | side_face_threshold = 0.7 24 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat') 25 | detector = dlib.get_frontal_face_detector() 26 | predictor = dlib.shape_predictor(predictor_path) 27 | time_last = time.time() 28 | FPS = 30 29 | 30 | def get_keys(img, get_point = True): 31 | global time_last 32 | time_start = time.time() 33 | #print("other:" + str(time_start - time_last)) 34 | dets = detector(img, 1) 35 | #print(time.time() - time_start) 36 | time_last = time.time() 37 | detected = False 38 | points = [] 39 | left_up_x = 10000000 40 | left_up_y = 10000000 41 | right_down_x = -1 42 | right_down_y = -1 43 | 44 | if len(dets) > 0: 45 | detected = True 46 | if (get_point == False): 47 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()] 48 | else: 49 | shape = predictor(img, dets[0]) 50 | points = np.empty([68, 2], dtype=int) 51 | for b in range(68): 52 | points[b,0] = shape.part(b).x 53 | points[b,1] = shape.part(b).y 54 | left_up_x = min(left_up_x, points[b, 0]) 55 | left_up_y = min(left_up_y, points[b, 1]) 56 | right_down_x = max(right_down_x, points[b, 0]) 57 | right_down_y = max(right_down_y, points[b, 1]) 58 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0): 59 | detected = False 60 | else: 61 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0]) 62 | r = min(r, 1 / r) 63 | if (r < side_face_threshold): 64 | detected = False 65 | 66 | return detected, points, [left_up_x, left_up_y, right_down_x, right_down_y] 67 | 68 | def get_img(path, target_path): 69 | #print("233") 70 | img = io.imread(path) 71 | detected, _, box = get_keys(img) 72 | if (detected): 73 | h, w, c = img.shape 74 | b_h = box[3] - box[1] 75 | b_w = box[2] - box[0] 76 | dh = int((b_h / face_rate - b_h) / 2) 77 | dw = int((b_w / face_rate - b_w) / 2) 78 | ddh = int(b_h * up_center_rate) 79 | if (box[1] - dh - ddh < 0 or box[3] + dh - ddh >= h or box[0] - dw < 0 or box[2] + dw >= w): 80 | return False, None 81 | else: 82 | img_ = img[box[1] - dh - ddh : box[3] + dh - ddh, box[0] - dw : box[2] + dw, :].copy() 83 | if (img_.shape[0] < out_size / 4 or img_.shape[1] < out_size / 4): 84 | return False, None 85 | #dh = int(h * (1 - crop_rate) / 2) 86 | #dw = int(w * (1 - crop_rate) / 2) 87 | #img_ = img[dh:h-dh, dw:w-dw, :].copy() 88 | img_ = transform.resize(img_, (out_size, out_size)) 89 | #time_s = time.time() 90 | #io.imsave(target_path, img_) 91 | #print("writetime:" + str(time.time() - time_s)) 92 | img_ = (np.maximum(np.minimum(255, img_ * 256), 0)).astype(np.uint8) 93 | 94 | return True, img_ 95 | else: 96 | return False, None 97 | 98 | def deal(datas, rootpath): 99 | paths = [] 100 | path = [] 101 | last = '' 102 | min_hw = [1000, 1000] 103 | for data in datas: 104 | paras = data.split(',') 105 | names = paras[0].split('\\') 106 | if (last != names[0] + '\\' + names[1]): 107 | if min_hw[0] > threshold and min_hw[1] > threshold: 108 | paths.append(path) 109 | path = [] 110 | min_hw = 
[1000, 1000] 111 | last = names[0] + '\\' + names[1] 112 | min_hw[0] = min(min_hw[0], int(paras[4])) 113 | min_hw[1] = min(min_hw[1], int(paras[5])) 114 | path.append(os.path.join(rootpath, 'aligned_images_DB', names[0], names[1], 'aligned_detect_' + names[2])) 115 | 116 | if min_hw[0] > threshold and min_hw[1] > threshold: 117 | paths.append(path) 118 | return paths 119 | 120 | def get_num(path): 121 | t = path.find("a.") 122 | return int(path[t + 2 : -4]) 123 | 124 | def deal_vid(paths): 125 | paths = sorted(paths, key=lambda path: get_num(path)) 126 | ans = [] 127 | for i in range(len(paths)): 128 | if (i % 10 == 0): 129 | ans.append(paths[i]) 130 | return ans 131 | 132 | def get_paths(rootpath): 133 | data_path = os.path.join(rootpath, 'frame_images_DB') 134 | data_path = rootpath 135 | paths = [] 136 | for root, _, fnames in os.walk(data_path): 137 | temp = [] 138 | for fname in fnames: 139 | if not fname.endswith('.jpg'): 140 | continue 141 | temp.append(os.path.join(root, fname)) 142 | temp = deal_vid(temp) 143 | print(len(temp)) 144 | paths.append([temp]) 145 | random.shuffle(paths) 146 | paths_ = [] 147 | for human in paths: 148 | for vid in human: 149 | paths_.append(vid) 150 | return paths_ 151 | 152 | def transform_name(name, root): 153 | name = name[len(root):] 154 | name = name.replace('/', '__').replace(' ', '_') 155 | return name 156 | 157 | def neighbor_index(index, length, n_size): 158 | left = max(0, index[1] - n_size) 159 | right = min(length - 1, index[1] + n_size) 160 | return [index[0], np.random.randint(left, right + 1)] 161 | 162 | def make_list(paths, phase, opt): 163 | indexs = [] 164 | for i in range(len(paths)): 165 | for j in range(paths[i]['len']): 166 | indexs.append([i, j]) 167 | 168 | ans_list = [] 169 | for i in range(len(paths)): 170 | for j in range(paths[i]['len']): 171 | for k in range(opt.A_repeat_num): 172 | ''' 173 | if (phase == 'val' or phase == 'test'): 174 | index = indexs[np.random.randint(len(indexs))] 175 | context = paths[i]['label_names'][j] \ 176 | + '&' + paths[index[0]]['label_names'][index[1]] + '&' + paths[index[0]]['img_names'][index[1]] \ 177 | + '&' + paths[i]['img_names'][j] \ 178 | + '\n' 179 | ans_list.append(context) 180 | else: 181 | ''' 182 | if (random.random() < opt.same_style_rate): 183 | index_B = neighbor_index([i, j], paths[i]['len'], opt.neighbor_size) 184 | else: 185 | index_B = indexs[np.random.randint(len(indexs))] 186 | index_C = neighbor_index(index_B, paths[index_B[0]]['len'], opt.neighbor_size) 187 | index_D = index_C 188 | while (index_D[0] == index_C[0]): 189 | index_D = indexs[np.random.randint(len(indexs))] 190 | context = paths[i]['label_names'][j] \ 191 | + '&' + paths[index_B[0]]['label_names'][index_B[1]] + '&' + paths[index_B[0]]['img_names'][index_B[1]] \ 192 | + '&' + paths[i]['img_names'][j] \ 193 | + '&' + paths[index_C[0]]['label_names'][index_C[1]] + '&' + paths[index_C[0]]['img_names'][index_C[1]] \ 194 | + '&' + paths[index_D[0]]['label_names'][index_D[1]] + '&' + paths[index_D[0]]['img_names'][index_D[1]] \ 195 | + '\n' 196 | ans_list.append(context) 197 | 198 | f = open(os.path.join(opt.target_path, phase + '_list.txt'), 'w') 199 | f.writelines(ans_list) 200 | 201 | opt = GenerateDataOptions().parse(save=False) 202 | opt.neighbor_size *= FPS 203 | paths = get_paths(opt.source_path) 204 | label_path = os.path.join(opt.target_path, 'keypoints/') 205 | img_path = os.path.join(opt.target_path, 'img/') 206 | if (not os.path.exists(label_path)): 207 | os.makedirs(label_path) 208 | if (not 
os.path.exists(img_path)): 209 | os.makedirs(img_path) 210 | paths_ = [] 211 | ans = 0 212 | for i in range(len(paths)): 213 | ans += len(paths[i]) 214 | print(ans) 215 | tot = 0 216 | for i in range(len(paths)): 217 | if (opt.copy_data): 218 | print(str(tot) + ' ' + str(len(paths[i]))) 219 | img_names = [] 220 | label_names = [] 221 | for j in range(len(paths[i])): 222 | img_name = transform_name(paths[i][j], opt.source_path) 223 | label_name = img_name[: -3] + 'txt' 224 | if (opt.copy_data): 225 | useable, img = get_img(paths[i][j], 'datasets/YouTubeFaces/temp.jpg') 226 | if not useable: 227 | continue 228 | #img = io.imread('datasets/YouTubeFaces/temp.jpg') 229 | detected, keys, box = get_keys(img) 230 | if detected: 231 | tot += 1 232 | io.imsave(os.path.join(img_path, img_name), img) 233 | np.savetxt(os.path.join(label_path, label_name), keys, fmt='%d', delimiter=',') 234 | img_names.append(img_path + img_name) 235 | label_names.append(label_path + label_name) 236 | else: 237 | if os.path.exists(label_path + label_name): 238 | img_names.append(img_path + img_name) 239 | label_names.append(label_path + label_name) 240 | paths_.append({'label_names': label_names, 'img_names':img_names, 'len': len(img_names)}) 241 | paths = paths_ 242 | img_num = 0 243 | for i in range(len(paths)): 244 | img_num += len(paths[i]['img_names']) 245 | print(img_num) 246 | val_size = int(img_num * opt.val_ratio) 247 | test_size = int(img_num * opt.test_ratio) 248 | did_val = False 249 | tot = 0 250 | last = 0 251 | last_name = '' 252 | for i in range(len(paths)): 253 | tot += len(paths[i]['img_names']) 254 | if (len(paths[i]['img_names']) > 0): 255 | name = paths[i]['img_names'][0].split('__')[1] 256 | if ((not did_val) and tot >= val_size and last_name != name): 257 | print("=============") 258 | did_val = True 259 | tot = 0 260 | last = i + 1 261 | make_list(paths[0:i], 'val', opt) 262 | elif (did_val and tot >= test_size and last_name != name): 263 | print("=============") 264 | make_list(paths[last:i], 'test', opt) 265 | last = i + 1 266 | break 267 | print(name) 268 | last_name = name 269 | make_list(paths[last:], 'train', opt) 270 | 271 | -------------------------------------------------------------------------------- /img/dance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/dance.png -------------------------------------------------------------------------------- /img/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/face.png -------------------------------------------------------------------------------- /img/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/fig.png -------------------------------------------------------------------------------- /img/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/scene.png -------------------------------------------------------------------------------- /img/small_dance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_dance.png -------------------------------------------------------------------------------- /img/small_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_face.png -------------------------------------------------------------------------------- /img/small_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_scene.png -------------------------------------------------------------------------------- /infer_face.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateFaceConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import random 13 | import torch 14 | from skimage import transform,io 15 | import dlib 16 | import shutil 17 | import numpy as np 18 | 19 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat') 20 | detector = dlib.get_frontal_face_detector() 21 | predictor = dlib.shape_predictor(predictor_path) 22 | 23 | def get_keys(img, get_point = True): 24 | dets = detector(img, 1) 25 | detected = False 26 | points = [] 27 | left_up_x = 10000000 28 | left_up_y = 10000000 29 | right_down_x = -1 30 | right_down_y = -1 31 | 32 | if len(dets) > 0: 33 | detected = True 34 | if (get_point == False): 35 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()] 36 | else: 37 | shape = predictor(img, dets[0]) 38 | points = np.empty([68, 2], dtype=int) 39 | for b in range(68): 40 | points[b,0] = shape.part(b).x 41 | points[b,1] = shape.part(b).y 42 | left_up_x = min(left_up_x, points[b, 0]) 43 | left_up_y = min(left_up_y, points[b, 1]) 44 | right_down_x = max(right_down_x, points[b, 0]) 45 | right_down_y = max(right_down_y, points[b, 1]) 46 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0): 47 | detected = False 48 | else: 49 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0]) 50 | r = min(r, 1 / r) 51 | 52 | return detected, points, [left_up_x, left_up_y, right_down_x, right_down_y] 53 | 54 | def get_img(path): 55 | face_rate = 0.5 56 | up_center_rate = 0.2 57 | out_size = 512 58 | img = io.imread(path) 59 | detected, _, box = get_keys(img) 60 | if (detected): 61 | h, w, c = img.shape 62 | b_h = box[3] - box[1] 63 | b_w = box[2] - box[0] 64 | dh = int((b_h / face_rate - b_h) / 2) 65 | dw = int((b_w / face_rate - b_w) / 2) 66 | ddh = int(b_h * up_center_rate) 67 | if (box[1] - dh - ddh < 0 or box[3] + dh - ddh >= h or box[0] - dw < 0 or box[2] + dw >= w): 68 | return False, None 69 | else: 70 | img_ = img[box[1] - dh - ddh : box[3] + dh - ddh, box[0] - dw : box[2] + dw, :].copy() 71 | if (img_.shape[0] < out_size / 4 or img_.shape[1] < out_size / 4): 72 | return False, None 73 | img_ = 
transform.resize(img_, (out_size, out_size)) 74 | img_ = (np.maximum(np.minimum(255, img_ * 256), 0)).astype(np.uint8) 75 | return True, img_ 76 | else: 77 | return False, None 78 | 79 | def get_info(path, img_path, key_path): 80 | useable, img = get_img(path) 81 | assert useable 82 | detected, keys, box = get_keys(img) 83 | assert detected 84 | io.imsave(img_path, img) 85 | np.savetxt(key_path, keys, fmt='%d', delimiter=',') 86 | return img_path, key_path 87 | 88 | 89 | def make_data(): 90 | assert os.path.exists('inference/infer_list.txt') 91 | with open('inference/infer_list.txt', 'r') as f: 92 | tasks = f.readlines() 93 | if os.path.exists('inference/data'): 94 | shutil.rmtree('inference/data') 95 | if os.path.exists('inference/output'): 96 | shutil.rmtree('inference/output') 97 | os.makedirs('inference/data') 98 | os.makedirs('inference/output') 99 | img_path = 'inference/data/img' 100 | key_path = 'inference/data/keypoints' 101 | os.makedirs(img_path) 102 | os.makedirs(key_path) 103 | 104 | ans_list = [] 105 | for i in range(len(tasks)): 106 | name = '__Mr_'+str(i)+'__0__a.' 107 | img1, key1 = get_info(tasks[i].split(' ')[0], os.path.join(img_path, name + '0.jpg'), os.path.join(key_path, name + '0.txt')) 108 | img2, key2 = get_info(tasks[i].split(' ')[1].rstrip('\n'), os.path.join(img_path, name + '1.jpg'), os.path.join(key_path, name + '1.txt')) 109 | 110 | context = key1 \ 111 | + '&' + key2 + '&' + img2 \ 112 | + '&' + img1 \ 113 | + '&' + key1 + '&' + img1 \ 114 | + '&' + key1 + '&' + img1 \ 115 | + '\n' 116 | ans_list.append(context) 117 | f = open('inference/data/infer_list.txt', 'w') 118 | f.writelines(ans_list) 119 | 120 | make_data() 121 | opt = TestOptions().parse(save=False) 122 | opt.nThreads = 1 # test code only supports nThreads = 1 123 | opt.batchSize = 1 # test code only supports batchSize = 1 124 | opt.serial_batches = True # no shuffle 125 | opt.no_flip = True # no flip 126 | 127 | data_loader = CreateFaceConDataLoader(opt) 128 | dataset = data_loader.load_data() 129 | visualizer = Visualizer(opt) 130 | # create website 131 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches))) 132 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 133 | 134 | # test 135 | if not opt.engine and not opt.onnx: 136 | model = create_model(opt) 137 | if opt.data_type == 16: 138 | model.half() 139 | elif opt.data_type == 8: 140 | model.type(torch.uint8) 141 | 142 | if opt.verbose: 143 | print(model) 144 | else: 145 | from run_engine import run_trt_engine, run_onnx 146 | 147 | for i, data in enumerate(dataset, start=0): 148 | if opt.data_type == 16: 149 | data['A'] = data['A'].half() 150 | data['A2'] = data['A2'].half() 151 | data['B'] = data['B'].half() 152 | data['B2'] = data['B2'].half() 153 | elif opt.data_type == 8: 154 | data['A'] = data['A'].uint8() 155 | data['A2'] = data['A2'].uint8() 156 | data['B'] = data['B'].uint8() 157 | data['B2'] = data['B2'].uint8() 158 | if opt.export_onnx: 159 | print ("Exporting to ONNX: ", opt.export_onnx) 160 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 161 | torch.onnx.export(model, [data['label'], data['inst']], 162 | opt.export_onnx, verbose=True) 163 | exit(0) 164 | minibatch = 1 165 | if opt.engine: 166 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 167 | elif opt.onnx: 168 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, 
[data['label'], data['inst']]) 169 | else: 170 | generated = model.inference(data['A'], data['B'], data['B2']) 171 | 172 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)), 173 | ('real_image', util.tensor2im(data['A2'][0])), 174 | ('synthesized_image', util.tensor2im(generated.data[0])), 175 | ('B', util.tensor2label(data['B'][0], 0)), 176 | ('B2', util.tensor2im(data['B2'][0]))]) 177 | img_path = data['path'] 178 | img_path[0] = str(i) 179 | print('process image... %s' % img_path) 180 | visualizer.save_images(webpage, visuals, img_path) 181 | 182 | webpage.save() 183 | -------------------------------------------------------------------------------- /inference/data/img/__Mr_0__0__a.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_0__0__a.0.jpg -------------------------------------------------------------------------------- /inference/data/img/__Mr_0__0__a.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_0__0__a.1.jpg -------------------------------------------------------------------------------- /inference/data/img/__Mr_1__0__a.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_1__0__a.0.jpg -------------------------------------------------------------------------------- /inference/data/img/__Mr_1__0__a.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_1__0__a.1.jpg -------------------------------------------------------------------------------- /inference/data/infer_list.txt: -------------------------------------------------------------------------------- 1 | inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/keypoints/__Mr_0__0__a.1.txt&inference/data/img/__Mr_0__0__a.1.jpg&inference/data/img/__Mr_0__0__a.0.jpg&inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/img/__Mr_0__0__a.0.jpg&inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/img/__Mr_0__0__a.0.jpg 2 | inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/keypoints/__Mr_1__0__a.1.txt&inference/data/img/__Mr_1__0__a.1.jpg&inference/data/img/__Mr_1__0__a.0.jpg&inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/img/__Mr_1__0__a.0.jpg&inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/img/__Mr_1__0__a.0.jpg 3 | -------------------------------------------------------------------------------- /inference/data/keypoints/__Mr_0__0__a.0.txt: -------------------------------------------------------------------------------- 1 | 126,246 2 | 130,279 3 | 134,312 4 | 140,345 5 | 153,374 6 | 175,397 7 | 202,414 8 | 231,427 9 | 264,430 10 | 295,426 11 | 322,412 12 | 345,394 13 | 362,370 14 | 371,340 15 | 376,310 16 | 381,279 17 | 384,249 18 | 162,209 19 | 176,190 20 | 198,180 21 | 222,181 22 | 244,190 23 | 284,192 24 | 307,185 25 | 330,185 26 | 350,196 27 | 362,215 28 | 265,228 29 | 266,247 30 | 267,266 31 | 268,285 32 | 240,308 33 | 253,310 34 | 265,313 35 | 277,311 36 | 288,309 37 | 187,233 38 | 201,222 39 | 218,223 40 | 231,237 41 | 217,239 42 | 199,238 43 | 
295,239 44 | 309,226 45 | 325,226 46 | 338,238 47 | 326,242 48 | 309,242 49 | 216,348 50 | 237,337 51 | 254,331 52 | 264,333 53 | 274,331 54 | 289,339 55 | 306,350 56 | 289,356 57 | 275,358 58 | 263,358 59 | 252,357 60 | 237,355 61 | 224,347 62 | 253,342 63 | 264,343 64 | 274,343 65 | 299,349 66 | 274,344 67 | 264,344 68 | 253,343 69 | -------------------------------------------------------------------------------- /inference/data/keypoints/__Mr_0__0__a.1.txt: -------------------------------------------------------------------------------- 1 | 129,217 2 | 127,253 3 | 128,289 4 | 132,326 5 | 147,359 6 | 172,387 7 | 199,410 8 | 226,429 9 | 255,435 10 | 283,429 11 | 305,405 12 | 328,383 13 | 348,359 14 | 365,332 15 | 378,304 16 | 383,274 17 | 385,244 18 | 161,188 19 | 185,179 20 | 211,181 21 | 236,188 22 | 260,199 23 | 304,201 24 | 326,194 25 | 349,191 26 | 370,194 27 | 383,208 28 | 281,231 29 | 279,256 30 | 278,281 31 | 277,306 32 | 241,313 33 | 256,319 34 | 271,325 35 | 285,323 36 | 298,320 37 | 188,217 38 | 205,211 39 | 223,214 40 | 239,230 41 | 220,228 42 | 202,226 43 | 309,237 44 | 326,224 45 | 344,225 46 | 357,235 47 | 345,240 48 | 327,239 49 | 193,332 50 | 223,335 51 | 251,338 52 | 266,343 53 | 282,342 54 | 302,345 55 | 321,346 56 | 298,373 57 | 277,385 58 | 260,385 59 | 243,381 60 | 217,364 61 | 200,336 62 | 249,346 63 | 265,350 64 | 281,349 65 | 312,349 66 | 279,369 67 | 262,370 68 | 246,365 69 | -------------------------------------------------------------------------------- /inference/data/keypoints/__Mr_1__0__a.0.txt: -------------------------------------------------------------------------------- 1 | 129,217 2 | 127,253 3 | 128,289 4 | 132,326 5 | 147,359 6 | 172,387 7 | 199,410 8 | 226,429 9 | 255,435 10 | 283,429 11 | 305,405 12 | 328,383 13 | 348,359 14 | 365,332 15 | 378,304 16 | 383,274 17 | 385,244 18 | 161,188 19 | 185,179 20 | 211,181 21 | 236,188 22 | 260,199 23 | 304,201 24 | 326,194 25 | 349,191 26 | 370,194 27 | 383,208 28 | 281,231 29 | 279,256 30 | 278,281 31 | 277,306 32 | 241,313 33 | 256,319 34 | 271,325 35 | 285,323 36 | 298,320 37 | 188,217 38 | 205,211 39 | 223,214 40 | 239,230 41 | 220,228 42 | 202,226 43 | 309,237 44 | 326,224 45 | 344,225 46 | 357,235 47 | 345,240 48 | 327,239 49 | 193,332 50 | 223,335 51 | 251,338 52 | 266,343 53 | 282,342 54 | 302,345 55 | 321,346 56 | 298,373 57 | 277,385 58 | 260,385 59 | 243,381 60 | 217,364 61 | 200,336 62 | 249,346 63 | 265,350 64 | 281,349 65 | 312,349 66 | 279,369 67 | 262,370 68 | 246,365 69 | -------------------------------------------------------------------------------- /inference/data/keypoints/__Mr_1__0__a.1.txt: -------------------------------------------------------------------------------- 1 | 126,246 2 | 130,279 3 | 134,312 4 | 140,345 5 | 153,374 6 | 175,397 7 | 202,414 8 | 231,427 9 | 264,430 10 | 295,426 11 | 322,412 12 | 345,394 13 | 362,370 14 | 371,340 15 | 376,310 16 | 381,279 17 | 384,249 18 | 162,209 19 | 176,190 20 | 198,180 21 | 222,181 22 | 244,190 23 | 284,192 24 | 307,185 25 | 330,185 26 | 350,196 27 | 362,215 28 | 265,228 29 | 266,247 30 | 267,266 31 | 268,285 32 | 240,308 33 | 253,310 34 | 265,313 35 | 277,311 36 | 288,309 37 | 187,233 38 | 201,222 39 | 218,223 40 | 231,237 41 | 217,239 42 | 199,238 43 | 295,239 44 | 309,226 45 | 325,226 46 | 338,238 47 | 326,242 48 | 309,242 49 | 216,348 50 | 237,337 51 | 254,331 52 | 264,333 53 | 274,331 54 | 289,339 55 | 306,350 56 | 289,356 57 | 275,358 58 | 263,358 59 | 252,357 60 | 237,355 61 | 224,347 62 | 253,342 63 | 264,343 64 | 274,343 65 | 
299,349 66 | 274,344 67 | 264,344 68 | 253,343 69 | -------------------------------------------------------------------------------- /inference/infer_list.txt: -------------------------------------------------------------------------------- 1 | inference/test_imgs/1.jpg inference/test_imgs/2.jpg 2 | inference/test_imgs/2.jpg inference/test_imgs/1.jpg 3 | -------------------------------------------------------------------------------- /inference/test_imgs/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/test_imgs/1.jpg -------------------------------------------------------------------------------- /inference/test_imgs/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/test_imgs/2.jpg -------------------------------------------------------------------------------- /judge_face.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import os 4 | from collections import OrderedDict 5 | from options.generate_data_options import GenerateDataOptions 6 | from data.data_loader import CreateDataLoader 7 | from PIL import Image 8 | import copy 9 | import numpy as np 10 | from shutil import copyfile 11 | import random 12 | import glob 13 | from skimage import transform,io 14 | import dlib 15 | import sys 16 | import time 17 | import math 18 | 19 | threshold = 130 20 | crop_rate = 0.6 21 | face_rate = 0.5 22 | up_center_rate = 0.2 23 | out_size = 512 24 | side_face_threshold = 0.7 25 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat') 26 | detector = dlib.get_frontal_face_detector() 27 | predictor = dlib.shape_predictor(predictor_path) 28 | time_last = time.time() 29 | 30 | pts_max_score = 1000#.1 * 256 31 | pic_max_score = 10#0.007 * 256 32 | #pic_size = 256.0 33 | pic_size = 1.0 34 | 35 | def get_keys(img, get_point = True): 36 | global time_last 37 | time_start = time.time() 38 | #print("other:" + str(time_start - time_last)) 39 | dets = detector(img, 1) 40 | #print(time.time() - time_start) 41 | time_last = time.time() 42 | detected = False 43 | points = [] 44 | left_up_x = 10000000 45 | left_up_y = 10000000 46 | right_down_x = -1 47 | right_down_y = -1 48 | 49 | if len(dets) > 0: 50 | detected = True 51 | if (get_point == False): 52 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()] 53 | else: 54 | shape = predictor(img, dets[0]) 55 | points = np.empty([68, 2], dtype=int) 56 | for b in range(68): 57 | points[b,0] = shape.part(b).x 58 | points[b,1] = shape.part(b).y 59 | left_up_x = min(left_up_x, points[b, 0]) 60 | left_up_y = min(left_up_y, points[b, 1]) 61 | right_down_x = max(right_down_x, points[b, 0]) 62 | right_down_y = max(right_down_y, points[b, 1]) 63 | ''' 64 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0): 65 | detected = False 66 | else: 67 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0]) 68 | r = min(r, 1 / r) 69 | if (r < side_face_threshold): 70 | detected = False 71 | ''' 72 | 73 | return detected, points, [left_up_x, left_up_y, right_down_x, 
right_down_y] 74 | def transfer_list(list1): 75 | list2 = [] 76 | for v in list1: 77 | list2.append(v[:-len('synthesized_image.jpg')] + 'real_image.jpg') 78 | #list2.append(v[:-len('synthesized_image.jpg')] + 'B2.jpg') 79 | return list2 80 | ''' 81 | 82 | def transfer_list(list1): 83 | list2 = [] 84 | for v in list1: 85 | id=v.split('_')[-1][:-4] 86 | t = 'results/label2city_256p_face_102/test_final_latest_True/images/'+id+'_real_image.jpg' 87 | list2.append(t) 88 | return list2 89 | ''' 90 | 91 | def dist(p1, p2): 92 | return math.sqrt((p1[0] - p2[0]) * (p1[0] - p2[0]) + (p1[1] - p2[1]) * (p1[1] - p2[1])) 93 | 94 | path = 'results/label2city_256p_face_104/test_final_latest_True/images' 95 | 96 | list1 = [] 97 | for root, _, fnames in os.walk(path): 98 | for fname in fnames: 99 | if (fname.endswith('_synthesized_image.jpg')): 100 | list1.append(os.path.join(root, fname)) 101 | list2 = transfer_list(list1) 102 | score = 0 103 | print(len(list1)) 104 | for i in range(len(list1)): 105 | img1 = io.imread(list1[i]) 106 | img2 = io.imread(list2[i]) 107 | detected1, points1, _ = get_keys(img1) 108 | detected2, points2, _ = get_keys(img2) 109 | if (not detected1) or (not detected2): 110 | temp_score = pic_max_score 111 | else: 112 | points1 = points1 / pic_size 113 | points2 = points2 / pic_size 114 | temp_score = 0 115 | for j in range(68): 116 | d = dist(points1[j, :], points2[j, :]) 117 | d = min(pts_max_score, d) 118 | temp_score += d 119 | temp_score = min(pic_max_score, temp_score / 68) 120 | score += temp_score 121 | if (i % 10 == 0): 122 | print(i) 123 | print(score / (i + 1)/256) 124 | print(temp_score/256) 125 | score /= len(list1) 126 | print(score/256) 127 | 128 | 129 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/models/__init__.py -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
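### BaseModel: shared boilerplate for all models. Provides save_network/load_network
### helpers; load_network falls back to a partial state_dict load (and reports the
### layers that could not be initialized) when the checkpoint does not match the network.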
3 | import os 4 | import torch 5 | import sys 6 | 7 | class BaseModel(torch.nn.Module): 8 | def name(self): 9 | return 'BaseModel' 10 | 11 | def initialize(self, opt): 12 | self.opt = opt 13 | self.gpu_ids = opt.gpu_ids 14 | self.isTrain = opt.isTrain 15 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 16 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 17 | 18 | def set_input(self, input): 19 | self.input = input 20 | 21 | def forward(self): 22 | pass 23 | 24 | # used in test time, no backprop 25 | def test(self): 26 | pass 27 | 28 | def get_image_paths(self): 29 | pass 30 | 31 | def optimize_parameters(self): 32 | pass 33 | 34 | def get_current_visuals(self): 35 | return self.input 36 | 37 | def get_current_errors(self): 38 | return {} 39 | 40 | def save(self, label): 41 | pass 42 | 43 | # helper saving function that can be used by subclasses 44 | def save_network(self, network, network_label, epoch_label, gpu_ids): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | torch.save(network.cpu().state_dict(), save_path) 48 | if len(gpu_ids) and torch.cuda.is_available(): 49 | network.cuda() 50 | 51 | # helper loading function that can be used by subclasses 52 | def load_network(self, network, network_label, epoch_label, save_dir=''): 53 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 54 | if not save_dir: 55 | save_dir = self.save_dir 56 | save_path = os.path.join(save_dir, save_filename) 57 | if not os.path.isfile(save_path): 58 | print('%s not exists yet!' % save_path) 59 | if network_label == 'G': 60 | raise('Generator must exist!') 61 | else: 62 | #network.load_state_dict(torch.load(save_path)) 63 | try: 64 | network.load_state_dict(torch.load(save_path)) 65 | except: 66 | pretrained_dict = torch.load(save_path) 67 | model_dict = network.state_dict() 68 | try: 69 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 70 | network.load_state_dict(pretrained_dict) 71 | if self.opt.verbose: 72 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 73 | except: 74 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) 75 | for k, v in pretrained_dict.items(): 76 | if v.size() == model_dict[k].size(): 77 | model_dict[k] = v 78 | 79 | if sys.version_info >= (3,0): 80 | not_initialized = set() 81 | else: 82 | from sets import Set 83 | not_initialized = Set() 84 | 85 | for k, v in model_dict.items(): 86 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 87 | not_initialized.add(k.split('.')[0]) 88 | 89 | print(sorted(not_initialized)) 90 | network.load_state_dict(model_dict) 91 | 92 | def update_learning_rate(): 93 | pass 94 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
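### Factory that instantiates the model selected by opt.model (pix2pixHD, c_pix2pixHD,
### cm_pix2pixHD, or the UI model), initializes it with opt, and wraps it in
### DataParallel when training with GPUs.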
3 | import torch 4 | 5 | def create_model(opt): 6 | if opt.model == 'pix2pixHD': 7 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel 8 | if opt.isTrain: 9 | model = Pix2PixHDModel() 10 | else: 11 | model = InferenceModel() 12 | elif opt.model == 'c_pix2pixHD': 13 | from .c_pix2pixHD_model import CPix2PixHDModel, InferenceModel 14 | if opt.isTrain: 15 | model = CPix2PixHDModel() 16 | else: 17 | model = InferenceModel() 18 | elif opt.model == 'cm_pix2pixHD': 19 | from .cm_pix2pixHD_model import CMPix2PixHDModel, InferenceModel 20 | if opt.isTrain: 21 | model = CMPix2PixHDModel() 22 | else: 23 | model = InferenceModel() 24 | else: 25 | from .ui_model import UIModel 26 | model = UIModel() 27 | model.initialize(opt) 28 | if opt.verbose: 29 | print("model [%s] was created" % (model.name())) 30 | 31 | if opt.isTrain and len(opt.gpu_ids): 32 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 33 | 34 | return model 35 | -------------------------------------------------------------------------------- /new_scripts/infer_face.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python infer_face.py --name label2city_256p_face_102 --label_nc 0 --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './inference/data/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'infer' --gpu_ids 0; 3 | -------------------------------------------------------------------------------- /new_scripts/table_face.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python gen_vis_face.py --name label2city_256p_face_102 --label_nc 0 --test_delta_path 'vis_delta.txt' --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'val' --gpu_ids 0; 3 | -------------------------------------------------------------------------------- /new_scripts/table_pose.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python gen_vis_pose.py --name label2city_256p_pose_106 --label_nc 0 --test_delta_path "vis_delta.txt" --no_canny_edge --input_nc 23 --do_pose_dist_map --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'val' --gpu_ids 1; 3 | -------------------------------------------------------------------------------- /new_scripts/table_scene.sh: -------------------------------------------------------------------------------- 1 | ################################ Testing ################################ 2 | # first precompute and cluster all features 3 | #python encode_features.py --name label2city_256p_feat; 4 | # use instance-wise features 5 | 6 | python gen_vis_bdd.py --name label2city_256p_face_108 --label_nc 42 --use_new_label --resize_or_crop 'scale_width_and_crop' --fineSize 256 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255' --dataroot './datasets/bdd2/' --model 'c_pix2pixHD' --gpu_ids 2 --no_instance --loadSize 256 --phase 'val'; 7 | -------------------------------------------------------------------------------- /new_scripts/test_face.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python test_face.py --name label2city_256p_face_102 --how_many 
2500 --label_nc 0 --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'test_final' --gpu_ids 0; 3 | -------------------------------------------------------------------------------- /new_scripts/test_pose.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python test_pose.py --name label2city_256p_pose_106 --how_many 2500 --label_nc 0 --test_delta_path "vis_delta.txt" --no_canny_edge --input_nc 23 --do_pose_dist_map --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'test_final' --gpu_ids 3; 3 | -------------------------------------------------------------------------------- /new_scripts/test_scene.sh: -------------------------------------------------------------------------------- 1 | ################################ Testing ################################ 2 | # first precompute and cluster all features 3 | #python encode_features.py --name label2city_256p_feat; 4 | # use instance-wise features 5 | 6 | python test_seg.py --name label2city_256p_face_108 --use_new_label --how_many 2500 --label_nc 42 --resize_or_crop 'scale_width_and_crop' --fineSize 256 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255' --dataroot './datasets/bdd2/' --model 'c_pix2pixHD' --gpu_ids 3 --no_instance --loadSize 256 --phase 'test_final'; 7 | -------------------------------------------------------------------------------- /new_scripts/train_face.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python train_face.py --name label2city_256p_face_1 --use_self_loss --self_vgg_mul 20 --no_D_label --no_ganFeat_loss --label_nc 0 --no_canny_edge --style_stage_mul "0:0,500000:0.1,550000:0.3,600000:0.5,700000:1" --niter_iter 1500000 --use_style_iter 0 --niter_decay_iter 200000 --input_nc 15 --display_freq 300 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --gpu_ids=1 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1; 3 | -------------------------------------------------------------------------------- /new_scripts/train_pose.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python train_pose.py --name label2city_256p_pose_1 --use_self_loss --self_vgg_mul 20 --no_D_label --do_pose_dist_map --label_nc 0 --style_stage_mul "0:0,600000:0.1,650000:0.5,700000:1" --random_drop_prob 0 --no_ganFeat_loss --no_canny_edge --niter_iter 1100000 --use_style_iter -1 --niter_decay_iter 200000 --input_nc 23 --display_freq 300 --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --gpu_ids=2 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1; 3 | -------------------------------------------------------------------------------- /new_scripts/train_scene.sh: -------------------------------------------------------------------------------- 1 | ### Adding instances and encoded features 2 | python train_seg.py --name label2city_256p_scene_1 --label_nc 42 --no_ganFeat_loss --use_new_label --no_D_label --no_canny_edge --use_self_loss --self_vgg_mul 20 --niter_iter 200000 --use_style_iter -1 --niter_decay_iter 200000 --display_freq 300 --fineSize 256 --dataroot './datasets/bdd2/' 
--model 'c_pix2pixHD' --no_instance --gpu_ids=3 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255'; 3 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import argparse 4 | import os 5 | from util import util 6 | import torch 7 | 8 | class BaseOptions(): 9 | def __init__(self): 10 | self.parser = argparse.ArgumentParser() 11 | self.initialized = False 12 | 13 | def initialize(self): 14 | # experiment specifics 15 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models') 16 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 17 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 18 | self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use') 19 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 20 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 21 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 
8, 16, 32 bit") 22 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 23 | 24 | # input/output sizes 25 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 26 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size') 27 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 28 | self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels') 29 | self.parser.add_argument('--label_indexs', type=str, default='', help='label index ids, empty means 0..label_nc-1') 30 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 32 | 33 | # for setting inputs 34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/') 35 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 37 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 40 | 41 | # for displays 42 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 43 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. 
Requires tensorflow installed') 44 | 45 | # for generator 46 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 47 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 48 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 49 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network') 50 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network') 51 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 52 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer') 53 | 54 | # for instance-wise features 55 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') 56 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input') 57 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input') 58 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features') 59 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps') 60 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') 61 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 62 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') 63 | 64 | self.parser.add_argument('--no_canny_edge', action='store_true', help='no canny edge for background') 65 | self.parser.add_argument('--no_dist_map', action='store_true', help='no dist map for background') 66 | self.parser.add_argument('--do_pose_dist_map', action='store_true', help='do pose dist map for background') 67 | 68 | self.parser.add_argument('--remove_face_labels', action='store_true', help='remove face labels to better adapt to different face shapes') 69 | self.parser.add_argument('--random_drop_prob', type=float, default=0.2, help='the probability to randomly drop each pose segment during training') 70 | 71 | self.parser.add_argument('--densepose_only', action='store_true', help='use only densepose as input') 72 | self.parser.add_argument('--openpose_only', action='store_true', help='use only openpose as input') 73 | self.parser.add_argument('--add_mask', action='store_true', help='mask of background') 74 | 75 | self.parser.add_argument('--vgg_weights', type=str, default='1,1,1,1,1', help='vgg weights of ans&guidence loss') 76 | self.parser.add_argument('--gram_weights', type=str, default='1,1,1,1,1', help='gram weights of ans&guidence loss') 77 | self.parser.add_argument('--guide_vgg_mul', type=float, default=0.0, help='') 78 | self.parser.add_argument('--guide_gram_mul', type=float, default=0.0, help='') 79 | self.parser.add_argument('--use_self_loss', action='store_true', help='mask of background') 80 | self.parser.add_argument('--self_vgg_weights', type=str, default='1,1,1,1,1', help='vgg weights of ans&guidence loss') 81 | self.parser.add_argument('--self_gram_weights', type=str, 
default='1,1,1,1,1', help='gram weights of ans&guidence loss') 82 | self.parser.add_argument('--self_vgg_mul', type=float, default=0.0, help='') 83 | self.parser.add_argument('--self_gram_mul', type=float, default=0.0, help='') 84 | 85 | self.parser.add_argument('--style_stage_mul', type=str, default='0:1', help='') 86 | self.parser.add_argument('--real_stage_mul', type=str, default='0:1', help='') 87 | self.parser.add_argument('--train_val_list', type=str, default='0,1,2,3,4,100000,200000,300000,400000,500000', help='') 88 | 89 | self.parser.add_argument('--no_D_label', action='store_true', help='remove label channel of input of D') 90 | self.parser.add_argument('--no_G_label', action='store_true', help='remove label channel of input of G') 91 | self.parser.add_argument('--use_new_label', action='store_true', help='use generated seg map') 92 | self.initialized = True 93 | 94 | def get_list_int(self, s): 95 | str_indexs = s.split(',') 96 | ans = [] 97 | if (str_indexs[0] != ''): 98 | for str_index in str_indexs: 99 | ans.append(int(str_index)) 100 | return ans 101 | def get_list_float(self, s): 102 | str_indexs = s.split(',') 103 | ans = [] 104 | if (str_indexs[0] != ''): 105 | for str_index in str_indexs: 106 | ans.append(float(str_index)) 107 | return ans 108 | def get_list_dict(self, s): 109 | str_indexs = s.split(',') 110 | ans = [] 111 | if (str_indexs[0] != ''): 112 | for str_index in str_indexs: 113 | temp = str_index.split(':') 114 | ans.append([int(temp[0]), float(temp[1])]) 115 | return ans 116 | 117 | def parse(self, save=True): 118 | if not self.initialized: 119 | self.initialize() 120 | self.opt = self.parser.parse_args() 121 | self.opt.isTrain = self.isTrain # train or test 122 | 123 | self.opt.label_indexs = self.get_list_int(self.opt.label_indexs) 124 | self.opt.gram_weights = self.get_list_float(self.opt.gram_weights) 125 | self.opt.vgg_weights = self.get_list_float(self.opt.vgg_weights) 126 | self.opt.self_gram_weights = self.get_list_float(self.opt.self_gram_weights) 127 | self.opt.self_vgg_weights = self.get_list_float(self.opt.self_vgg_weights) 128 | self.opt.style_stage_mul = self.get_list_dict(self.opt.style_stage_mul) 129 | self.opt.real_stage_mul = self.get_list_dict(self.opt.real_stage_mul) 130 | self.opt.train_val_list = self.get_list_int(self.opt.train_val_list) 131 | 132 | str_ids = self.opt.gpu_ids.split(',') 133 | self.opt.gpu_ids = [] 134 | for str_id in str_ids: 135 | id = int(str_id) 136 | if id >= 0: 137 | self.opt.gpu_ids.append(id) 138 | 139 | # set gpu ids 140 | if len(self.opt.gpu_ids) > 0: 141 | torch.cuda.set_device(self.opt.gpu_ids[0]) 142 | 143 | args = vars(self.opt) 144 | 145 | print('------------ Options -------------') 146 | for k, v in sorted(args.items()): 147 | print('%s: %s' % (str(k), str(v))) 148 | print('-------------- End ----------------') 149 | 150 | # save to the disk 151 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 152 | util.mkdirs(expr_dir) 153 | if save and not self.opt.continue_train: 154 | file_name = os.path.join(expr_dir, 'opt.txt') 155 | with open(file_name, 'wt') as opt_file: 156 | opt_file.write('------------ Options -------------\n') 157 | for k, v in sorted(args.items()): 158 | opt_file.write('%s: %s\n' % (str(k), str(v))) 159 | opt_file.write('-------------- End ----------------\n') 160 | return self.opt 161 | -------------------------------------------------------------------------------- /options/generate_data_options.py: 
-------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | from .base_options import BaseOptions 4 | 5 | class GenerateDataOptions(BaseOptions): 6 | def initialize(self): 7 | BaseOptions.initialize(self) 8 | self.parser.add_argument('--neighbor_size', type=int, default=int(30), help='max distance between two same style frame.') 9 | self.parser.add_argument('--same_style_rate', type=float, default=float(0.5), help='rate of A&B are neighbors') 10 | self.parser.add_argument('--copy_data', action='store_true', default=False, help='copy img datas') 11 | self.parser.add_argument('--target_path', type=str, default='./datasets/apollo/', help='target path') 12 | self.parser.add_argument('--source_path', type=str, default='../datas/road02/', help='target path') 13 | self.parser.add_argument('--val_ratio', type=float, default=float(0.2), help='val data ratio') 14 | self.parser.add_argument('--test_ratio', type=float, default=float(0.2), help='test data ratio') 15 | self.parser.add_argument('--A_repeat_num', type=int, default=int(10), help='# of same A repeats in dataset.') 16 | self.isTrain = False 17 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | from .base_options import BaseOptions 4 | 5 | class TestOptions(BaseOptions): 6 | def initialize(self): 7 | BaseOptions.initialize(self) 8 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 11 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 12 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 13 | self.parser.add_argument('--how_many', type=int, default=100, help='how many test images to run') 14 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 15 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") 16 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") 17 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") 18 | self.parser.add_argument("--conditioned", action='store_true', help="use conditioned") 19 | self.parser.add_argument("--test_delta_path", type=str, default="test_delta.txt", help="test_delta_path") 20 | self.isTrain = False 21 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
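### TrainOptions: extends BaseOptions with display/checkpoint frequencies, optimizer
### settings (Adam lr/beta1, epoch- or iteration-based decay), and the discriminator
### and loss-weight options used during training.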
3 | from .base_options import BaseOptions
4 | 
5 | class TrainOptions(BaseOptions):
6 |     def initialize(self):
7 |         BaseOptions.initialize(self)
8 |         # for displays
9 |         self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
10 |         self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11 |         self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
12 |         self.parser.add_argument('--save_iter_freq', type=int, default=150000, help='frequency (in iterations) of saving an iteration-numbered checkpoint')
13 |         self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
14 |         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
15 |         self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
16 |         self.parser.add_argument('--val_freq', type=int, default=500, help='frequency of running validation')
17 | 
18 |         # for training
19 |         self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
20 |         self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
21 |         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
22 |         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
23 |         self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
24 |         self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
25 |         self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
26 |         self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
27 | 
28 |         self.parser.add_argument('--use_iter_decay', action='store_true', help='use iteration-based learning-rate decay')
29 |         self.parser.add_argument('--niter_iter', type=int, default=200000, help='# of iterations at the starting learning rate')
30 |         self.parser.add_argument('--niter_decay_iter', type=int, default=200000, help='# of iterations to linearly decay the learning rate to zero')
31 |         self.parser.add_argument('--use_style_iter', type=int, default=-1, help='iteration after which the style discriminator is enabled')
32 | 
33 | 
34 |         # for discriminators
35 |         self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
36 |         self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
37 |         self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
38 |         self.parser.add_argument('--nsdf', type=int, default=128, help='# of style-discriminator filters in first conv layer')
39 |         self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
40 |         self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
41 |         self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
42 |         self.parser.add_argument('--no_lsgan', action='store_true', help='do
*not* use least square GAN, if false, use vanilla GAN') 43 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 44 | self.parser.add_argument('--SD_mul', type=float, default=10, help='style discriminator mul') 45 | self.parser.add_argument('--GAN_Feat_mul', type=float, default=1, help='GAN Feat mul') 46 | self.parser.add_argument('--G_confidence_mul', type=float, default=1, help='G confidence mul') 47 | self.parser.add_argument('--FG_GAN_mul', type=float, default=10, help='G confidence mul') 48 | self.parser.add_argument('--val_n_everytime', type=int, default=10, help='val num everytime') 49 | 50 | self.parser.add_argument('--no_SD_false_pair', action='store_true', help='') 51 | 52 | self.parser.add_argument('--use_stage_lr', action='store_true', help='') 53 | self.parser.add_argument('--stage_lr_decay_iter', type=int, default=50000, help='') 54 | self.parser.add_argument('--stage_lr_decay_rate', type=float, default=0.3, help='') 55 | self.isTrain = True 56 | -------------------------------------------------------------------------------- /precompute_feature_maps.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | from options.train_options import TrainOptions 4 | from data.data_loader import CreateDataLoader 5 | from models.models import create_model 6 | import os 7 | import util.util as util 8 | from torch.autograd import Variable 9 | import torch.nn as nn 10 | 11 | opt = TrainOptions().parse() 12 | opt.nThreads = 1 13 | opt.batchSize = 1 14 | opt.serial_batches = True 15 | opt.no_flip = True 16 | opt.instance_feat = True 17 | 18 | name = 'features' 19 | save_path = os.path.join(opt.checkpoints_dir, opt.name) 20 | 21 | ############ Initialize ######### 22 | data_loader = CreateDataLoader(opt) 23 | dataset = data_loader.load_data() 24 | dataset_size = len(data_loader) 25 | model = create_model(opt) 26 | util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat')) 27 | 28 | ######## Save precomputed feature maps for 1024p training ####### 29 | for i, data in enumerate(dataset): 30 | print('%d / %d images' % (i+1, dataset_size)) 31 | feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda()) 32 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map) 33 | image_numpy = util.tensor2im(feat_map.data[0]) 34 | save_path = data['path'][0].replace('/train_label/', '/train_feat/') 35 | util.save_image(image_numpy, save_path) -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import cv2 # pip install opencv-python 2 | import os 3 | from os.path import join 4 | import argparse 5 | import random 6 | import numpy as np 7 | 8 | in_path = "../FaceForensics/datas/" 9 | out_path = "../FaceForensics/out_data/" 10 | 11 | def process(vid_path, out_path): 12 | print(vid_path) 13 | print(out_path) 14 | os.makedirs(out_path) 15 | video = cv2.VideoCapture(vid_path) 16 | n = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 17 | for i in range(n): 18 | _, image = video.read() 19 | path = os.path.join(out_path, 'a.' 
+ str(i) + '.jpg') 20 | cv2.imwrite(path, image) 21 | 22 | tot = 0 23 | for root, _, fnames in os.walk(in_path): 24 | for fname in fnames: 25 | if not fname.endswith('.avi'): 26 | continue 27 | process(os.path.join(root, fname), os.path.join(out_path, "Mr_"+str(tot), "1")) 28 | tot += 1 29 | -------------------------------------------------------------------------------- /run_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from random import randint 4 | import numpy as np 5 | import tensorrt 6 | 7 | try: 8 | from PIL import Image 9 | import pycuda.driver as cuda 10 | import pycuda.gpuarray as gpuarray 11 | import pycuda.autoinit 12 | import argparse 13 | except ImportError as err: 14 | sys.stderr.write("""ERROR: failed to import module ({}) 15 | Please make sure you have pycuda and the example dependencies installed. 16 | https://wiki.tiker.net/PyCuda/Installation/Linux 17 | pip(3) install tensorrt[examples] 18 | """.format(err)) 19 | exit(1) 20 | 21 | try: 22 | import tensorrt as trt 23 | from tensorrt.parsers import caffeparser 24 | from tensorrt.parsers import onnxparser 25 | except ImportError as err: 26 | sys.stderr.write("""ERROR: failed to import module ({}) 27 | Please make sure you have the TensorRT Library installed 28 | and accessible in your LD_LIBRARY_PATH 29 | """.format(err)) 30 | exit(1) 31 | 32 | 33 | G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO) 34 | 35 | class Profiler(trt.infer.Profiler): 36 | """ 37 | Example Implimentation of a Profiler 38 | Is identical to the Profiler class in trt.infer so it is possible 39 | to just use that instead of implementing this if further 40 | functionality is not needed 41 | """ 42 | def __init__(self, timing_iter): 43 | trt.infer.Profiler.__init__(self) 44 | self.timing_iterations = timing_iter 45 | self.profile = [] 46 | 47 | def report_layer_time(self, layerName, ms): 48 | record = next((r for r in self.profile if r[0] == layerName), (None, None)) 49 | if record == (None, None): 50 | self.profile.append((layerName, ms)) 51 | else: 52 | self.profile[self.profile.index(record)] = (record[0], record[1] + ms) 53 | 54 | def print_layer_times(self): 55 | totalTime = 0 56 | for i in range(len(self.profile)): 57 | print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations)) 58 | totalTime += self.profile[i][1] 59 | print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations)) 60 | 61 | 62 | def get_input_output_names(trt_engine): 63 | nbindings = trt_engine.get_nb_bindings(); 64 | maps = [] 65 | 66 | for b in range(0, nbindings): 67 | dims = trt_engine.get_binding_dimensions(b).to_DimsCHW() 68 | name = trt_engine.get_binding_name(b) 69 | type = trt_engine.get_binding_data_type(b) 70 | 71 | if (trt_engine.binding_is_input(b)): 72 | maps.append(name) 73 | print("Found input: ", name) 74 | else: 75 | maps.append(name) 76 | print("Found output: ", name) 77 | 78 | print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W())) 79 | print("dtype=" + str(type)) 80 | return maps 81 | 82 | def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx): 83 | binding_idx = engine.get_binding_index(name) 84 | if binding_idx == -1: 85 | raise AttributeError("Not a valid binding") 86 | print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx))) 87 | dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW() 88 | eltCount = dims.C() * dims.H() * dims.W() 
* batchsize 89 | 90 | if engine.binding_is_input(binding_idx): 91 | h_mem = inp[inp_idx] 92 | inp_idx = inp_idx + 1 93 | else: 94 | h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4')) 95 | 96 | d_mem = cuda.mem_alloc(eltCount * 4) 97 | cuda.memcpy_htod(d_mem, h_mem) 98 | buf.insert(binding_idx, int(d_mem)) 99 | mem.append(d_mem) 100 | return inp_idx 101 | 102 | 103 | #Run inference on device 104 | def time_inference(engine, batch_size, inp): 105 | bindings = [] 106 | mem = [] 107 | inp_idx = 0 108 | for io in get_input_output_names(engine): 109 | inp_idx = create_memory(engine, io, bindings, mem, 110 | batch_size, inp, inp_idx) 111 | 112 | context = engine.create_execution_context() 113 | g_prof = Profiler(500) 114 | context.set_profiler(g_prof) 115 | for i in range(iter): 116 | context.execute(batch_size, bindings) 117 | g_prof.print_layer_times() 118 | 119 | context.destroy() 120 | return 121 | 122 | 123 | def convert_to_datatype(v): 124 | if v==8: 125 | return trt.infer.DataType.INT8 126 | elif v==16: 127 | return trt.infer.DataType.HALF 128 | elif v==32: 129 | return trt.infer.DataType.FLOAT 130 | else: 131 | print("ERROR: Invalid model data type bit depth: " + str(v)) 132 | return trt.infer.DataType.INT8 133 | 134 | def run_trt_engine(engine_file, bs, it): 135 | engine = trt.utils.load_engine(G_LOGGER, engine_file) 136 | time_inference(engine, bs, it) 137 | 138 | def run_onnx(onnx_file, data_type, bs, inp): 139 | # Create onnx_config 140 | apex = onnxparser.create_onnxconfig() 141 | apex.set_model_file_name(onnx_file) 142 | apex.set_model_dtype(convert_to_datatype(data_type)) 143 | 144 | # create parser 145 | trt_parser = onnxparser.create_onnxparser(apex) 146 | assert(trt_parser) 147 | data_type = apex.get_model_dtype() 148 | onnx_filename = apex.get_model_file_name() 149 | trt_parser.parse(onnx_filename, data_type) 150 | trt_parser.report_parsing_info() 151 | trt_parser.convert_to_trtnetwork() 152 | trt_network = trt_parser.get_trtnetwork() 153 | assert(trt_network) 154 | 155 | # create infer builder 156 | trt_builder = trt.infer.create_infer_builder(G_LOGGER) 157 | trt_builder.set_max_batch_size(max_batch_size) 158 | trt_builder.set_max_workspace_size(max_workspace_size) 159 | 160 | if (apex.get_model_dtype() == trt.infer.DataType_kHALF): 161 | print("------------------- Running FP16 -----------------------------") 162 | trt_builder.set_half2_mode(True) 163 | elif (apex.get_model_dtype() == trt.infer.DataType_kINT8): 164 | print("------------------- Running INT8 -----------------------------") 165 | trt_builder.set_int8_mode(True) 166 | else: 167 | print("------------------- Running FP32 -----------------------------") 168 | 169 | print("----- Builder is Done -----") 170 | print("----- Creating Engine -----") 171 | trt_engine = trt_builder.build_cuda_engine(trt_network) 172 | print("----- Engine is built -----") 173 | time_inference(engine, bs, inp) 174 | -------------------------------------------------------------------------------- /temp_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.autograd import Variable 4 | w1 = Variable(torch.Tensor([1.0,2.0,3.0]),requires_grad=True) 5 | 6 | 7 | optimizer = optim.SGD(w1.parameters(), lr = 0.01) 8 | d = torch.mean(w1) 9 | d.backward() 10 | print(w1.grad) 11 | -------------------------------------------------------------------------------- /test.py: 
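Editor's note: temp_test.py above will not run as written, because w1 is a plain Variable/tensor and has no .parameters() method (that method belongs to nn.Module). A corrected minimal sketch of the same gradient check, assuming the intent is simply to verify that torch.mean backpropagates a uniform gradient:

import torch
import torch.optim as optim
from torch.autograd import Variable

w1 = Variable(torch.Tensor([1.0, 2.0, 3.0]), requires_grad=True)
optimizer = optim.SGD([w1], lr=0.01)   # pass the tensor in a list; a tensor has no .parameters()

d = torch.mean(w1)
d.backward()
print(w1.grad)      # each entry is 1/3, the gradient of the mean w.r.t. each element
optimizer.step()    # one SGD update: w1 <- w1 - 0.01 * grad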
-------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import torch 13 | 14 | opt = TestOptions().parse(save=False) 15 | opt.nThreads = 1 # test code only supports nThreads = 1 16 | opt.batchSize = 1 # test code only supports batchSize = 1 17 | opt.serial_batches = True # no shuffle 18 | opt.no_flip = True # no flip 19 | 20 | data_loader = CreateDataLoader(opt) 21 | dataset = data_loader.load_data() 22 | visualizer = Visualizer(opt) 23 | # create website 24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 26 | 27 | # test 28 | if not opt.engine and not opt.onnx: 29 | model = create_model(opt) 30 | if opt.data_type == 16: 31 | model.half() 32 | elif opt.data_type == 8: 33 | model.type(torch.uint8) 34 | 35 | if opt.verbose: 36 | print(model) 37 | else: 38 | from run_engine import run_trt_engine, run_onnx 39 | 40 | for i, data in enumerate(dataset): 41 | if i >= opt.how_many: 42 | break 43 | if opt.data_type == 16: 44 | data['label'] = data['label'].half() 45 | data['inst'] = data['inst'].half() 46 | elif opt.data_type == 8: 47 | data['label'] = data['label'].uint8() 48 | data['inst'] = data['inst'].uint8() 49 | if opt.export_onnx: 50 | print ("Exporting to ONNX: ", opt.export_onnx) 51 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 52 | torch.onnx.export(model, [data['label'], data['inst']], 53 | opt.export_onnx, verbose=True) 54 | exit(0) 55 | minibatch = 1 56 | if opt.engine: 57 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 58 | elif opt.onnx: 59 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 60 | else: 61 | generated = model.inference(data['label'], data['inst']) 62 | 63 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 64 | ('synthesized_image', util.tensor2im(generated.data[0]))]) 65 | img_path = data['path'] 66 | print('process image... %s' % img_path) 67 | visualizer.save_images(webpage, visuals, img_path) 68 | 69 | webpage.save() 70 | -------------------------------------------------------------------------------- /test_all.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
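Editor's note: a caveat on the opt.data_type == 8 branches in test.py above (and repeated in the test scripts below): PyTorch tensors expose .half() but no .uint8() method, so data['label'].uint8() would raise an AttributeError if that branch were ever taken. If 8-bit inputs were actually wanted, the conversion would look roughly like this (a sketch, not the repository's code):

import torch

def to_uint8(t):
    # Tensor has no .uint8() method; .byte() (or .to(torch.uint8)) performs the cast
    return t.byte()

t = torch.randn(1, 3, 4, 4)
print(to_uint8(t).dtype)   # torch.uint8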
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | from PIL import Image 13 | import torch 14 | from data.base_dataset import get_params, get_transform, normalize 15 | import copy 16 | from models import networks 17 | import numpy as np 18 | import torch.nn as nn 19 | 20 | opt = TestOptions().parse(save=False) 21 | opt.nThreads = 1 # test code only supports nThreads = 1 22 | opt.batchSize = 1 # test code only supports batchSize = 1 23 | opt.serial_batches = True # no shuffle 24 | opt.no_flip = True # no flip 25 | 26 | data_loader = CreateDataLoader(opt) 27 | dataset = data_loader.load_data() 28 | visualizer = Visualizer(opt) 29 | def get_features(inst, feat): 30 | feat_num = opt.feat_num 31 | h, w = inst.size()[1], inst.size()[2] 32 | block_num = 32 33 | feature = {} 34 | max_v = {} 35 | for i in range(opt.label_nc): 36 | feature[i] = np.zeros((0, feat_num+1)) 37 | max_v[i] = 0 38 | for i in np.unique(inst): 39 | label = i if i < 1000 else i//1000 40 | idx = (inst == int(i)).nonzero() 41 | num = idx.size()[0] 42 | idx = idx[num//2,:] 43 | val = np.zeros((feat_num)) 44 | for k in range(feat_num): 45 | val[k] = feat[idx[0] + k, idx[1], idx[2]].data[0] 46 | temp = float(num) / (h * w // block_num) 47 | if (temp > max_v[label]): 48 | max_v[label] = temp 49 | feature[label] = val 50 | return feature 51 | 52 | def getitem(A_path, B_path, inst_path, feat_path): 53 | ### input A (label maps) 54 | A = Image.open(A_path) 55 | params = get_params(opt, A.size) 56 | if opt.label_nc == 0: 57 | transform_A = get_transform(opt, params) 58 | A_tensor = transform_A(A.convert('RGB')) 59 | else: 60 | transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False) 61 | A_tensor = transform_A(A) * 255.0 62 | 63 | B_tensor = inst_tensor = feat_tensor = 0 64 | ### input B (real images) 65 | B = Image.open(B_path).convert('RGB') 66 | transform_B = get_transform(opt, params) 67 | B_tensor = transform_B(B) 68 | 69 | ### if using instance maps 70 | inst = Image.open(inst_path) 71 | inst_tensor = transform_A(inst) 72 | 73 | #get feat 74 | netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 75 | opt.n_downsample_E, norm=opt.norm, gpu_ids=opt.gpu_ids) 76 | feat_map = netE.forward(Variable(B_tensor[np.newaxis, :].cuda(), volatile=True), inst_tensor[np.newaxis, :].cuda()) 77 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map) 78 | image_numpy = util.tensor2im(feat_map.data[0]) 79 | util.save_image(image_numpy, feat_path) 80 | 81 | feat = Image.open(feat_path).convert('RGB') 82 | norm = normalize() 83 | feat_tensor = norm(transform_A(feat)) 84 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 85 | 'feat': feat_tensor, 'path': A_path} 86 | 87 | return get_features(input_dict['inst'], input_dict['feat']) 88 | 89 | # test 90 | if not opt.engine and not opt.onnx: 91 | model = create_model(opt) 92 | if opt.data_type == 16: 93 | model.half() 94 | elif opt.data_type == 8: 95 | model.type(torch.uint8) 96 | 97 | if opt.verbose: 98 | print(model) 99 | else: 100 | from run_engine import run_trt_engine, run_onnx 101 | 102 | label_path = 'datasets/test/label.png' 103 | img_path = 'datasets/test/img.png' 104 | inst_path = 'datasets/test/inst.png' 105 | feat_path = 
'datasets/test/feat.png' 106 | con = getitem(label_path, img_path, inst_path, feat_path) 107 | for k in range(opt.n_clusters): 108 | # create website 109 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%d' % (opt.phase, opt.which_epoch, k)) 110 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 111 | 112 | 113 | for i, data in enumerate(dataset): 114 | if i >= opt.how_many: 115 | break 116 | if opt.data_type == 16: 117 | data['label'] = data['label'].half() 118 | data['inst'] = data['inst'].half() 119 | elif opt.data_type == 8: 120 | data['label'] = data['label'].uint8() 121 | data['inst'] = data['inst'].uint8() 122 | if opt.export_onnx: 123 | print ("Exporting to ONNX: ", opt.export_onnx) 124 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 125 | torch.onnx.export(model, [data['label'], data['inst']], 126 | opt.export_onnx, verbose=True) 127 | exit(0) 128 | minibatch = 1 129 | if opt.engine: 130 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 131 | elif opt.onnx: 132 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 133 | elif opt.conditioned: 134 | generated = model.inference_conditioned(data['label'], data['inst'], con, k) 135 | else: 136 | generated = model.inference(data['label'], data['inst']) 137 | 138 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 139 | ('synthesized_image', util.tensor2im(generated.data[0]))]) 140 | img_path = data['path'] 141 | print('process image... %s' % img_path) 142 | visualizer.save_images(webpage, visuals, img_path) 143 | 144 | webpage.save() 145 | -------------------------------------------------------------------------------- /test_con.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import torch 13 | 14 | opt = TestOptions().parse(save=False) 15 | opt.nThreads = 1 # test code only supports nThreads = 1 16 | opt.batchSize = 1 # test code only supports batchSize = 1 17 | opt.serial_batches = True # no shuffle 18 | opt.no_flip = True # no flip 19 | 20 | data_loader = CreateConDataLoader(opt) 21 | dataset = data_loader.load_data() 22 | visualizer = Visualizer(opt) 23 | # create website 24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 26 | 27 | # test 28 | if not opt.engine and not opt.onnx: 29 | model = create_model(opt) 30 | if opt.data_type == 16: 31 | model.half() 32 | elif opt.data_type == 8: 33 | model.type(torch.uint8) 34 | 35 | if opt.verbose: 36 | print(model) 37 | else: 38 | from run_engine import run_trt_engine, run_onnx 39 | 40 | for i, data in enumerate(dataset): 41 | if i >= opt.how_many: 42 | break 43 | if opt.data_type == 16: 44 | data['A'] = data['A'].half() 45 | data['A2'] = data['A2'].half() 46 | data['B'] = data['B'].half() 47 | data['B2'] = data['B2'].half() 48 | elif opt.data_type == 8: 49 | data['A'] = data['A'].uint8() 50 | data['A2'] = data['A2'].uint8() 51 | data['B'] = data['B'].uint8() 52 | data['B2'] = data['B2'].uint8() 53 | if opt.export_onnx: 54 | print ("Exporting to ONNX: ", opt.export_onnx) 55 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 56 | torch.onnx.export(model, [data['label'], data['inst']], 57 | opt.export_onnx, verbose=True) 58 | exit(0) 59 | minibatch = 1 60 | if opt.engine: 61 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 62 | elif opt.onnx: 63 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 64 | else: 65 | generated = model.inference(data['A'], data['B'], data['B2']) 66 | 67 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)), 68 | ('real_image', util.tensor2im(data['A2'][0])), 69 | ('synthesized_image', util.tensor2im(generated.data[0])), 70 | ('B', util.tensor2label(data['B'][0], 256)), 71 | ('B2', util.tensor2im(data['B2'][0]))]) 72 | img_path = data['path'] 73 | print('process image... %s' % img_path) 74 | visualizer.save_images(webpage, visuals, img_path) 75 | 76 | webpage.save() 77 | -------------------------------------------------------------------------------- /test_con_bak.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | from PIL import Image 13 | import torch 14 | from data.base_dataset import get_params, get_transform, normalize 15 | import copy 16 | from models import networks 17 | import numpy as np 18 | import torch.nn as nn 19 | 20 | opt = TestOptions().parse(save=False) 21 | opt.nThreads = 1 # test code only supports nThreads = 1 22 | opt.batchSize = 1 # test code only supports batchSize = 1 23 | opt.serial_batches = True # no shuffle 24 | opt.no_flip = True # no flip 25 | 26 | data_loader = CreateDataLoader(opt) 27 | dataset = data_loader.load_data() 28 | visualizer = Visualizer(opt) 29 | # create website 30 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 31 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 32 | 33 | def get_features(inst, feat): 34 | feat_num = opt.feat_num 35 | h, w = inst.size()[1], inst.size()[2] 36 | block_num = 32 37 | feature = {} 38 | max_v = {} 39 | for i in range(opt.label_nc): 40 | feature[i] = np.zeros((0, feat_num+1)) 41 | max_v[i] = 0 42 | for i in np.unique(inst): 43 | label = i if i < 1000 else i//1000 44 | idx = (inst == int(i)).nonzero() 45 | num = idx.size()[0] 46 | idx = idx[num//2,:] 47 | val = np.zeros((feat_num)) 48 | for k in range(feat_num): 49 | val[k] = feat[0, idx[0] + k, idx[1], idx[2]].data[0] 50 | temp = float(num) / (h * w // block_num) 51 | if (temp > max_v[label]): 52 | max_v[label] = temp 53 | feature[label] = val 54 | return feature 55 | 56 | def getitem(A_path, B_path, inst_path, feat_path): 57 | ### input A (label maps) 58 | A = Image.open(A_path) 59 | params = get_params(opt, A.size) 60 | if opt.label_nc == 0: 61 | transform_A = get_transform(opt, params) 62 | A_tensor = transform_A(A.convert('RGB')) 63 | else: 64 | transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False) 65 | A_tensor = transform_A(A) * 255.0 66 | 67 | B_tensor = inst_tensor = feat_tensor = 0 68 | ### input B (real images) 69 | B = Image.open(B_path).convert('RGB') 70 | transform_B = get_transform(opt, params) 71 | B_tensor = transform_B(B) 72 | 73 | ### if using instance maps 74 | inst = Image.open(inst_path) 75 | inst_tensor = transform_A(inst) 76 | 77 | #get feat 78 | netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 79 | opt.n_downsample_E, norm=opt.norm, gpu_ids=opt.gpu_ids) 80 | feat_map = netE.forward(Variable(B_tensor[np.newaxis, :].cuda(), volatile=True), inst_tensor[np.newaxis, :].cuda()) 81 | ''' 82 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map) 83 | image_numpy = util.tensor2im(feat_map.data[0]) 84 | util.save_image(image_numpy, feat_path) 85 | 86 | feat = Image.open(feat_path).convert('RGB') 87 | norm = normalize() 88 | feat_tensor = norm(transform_A(feat)) 89 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 90 | 'feat': feat_tensor, 'path': A_path} 91 | ''' 92 | 93 | return get_features(inst_tensor, feat_map) 94 | 95 | # test 96 | if not opt.engine and not opt.onnx: 97 | model = create_model(opt) 98 | if opt.data_type == 16: 99 | model.half() 100 | elif opt.data_type == 8: 101 | model.type(torch.uint8) 102 | 103 | if opt.verbose: 
104 | print(model) 105 | else: 106 | from run_engine import run_trt_engine, run_onnx 107 | 108 | label_path = 'datasets/test/label.png' 109 | img_path = 'datasets/test/img.png' 110 | inst_path = 'datasets/test/inst.png' 111 | feat_path = 'datasets/test/feat.png' 112 | con = getitem(label_path, img_path, inst_path, feat_path) 113 | 114 | for i, data in enumerate(dataset): 115 | if i >= opt.how_many: 116 | break 117 | if opt.data_type == 16: 118 | data['label'] = data['label'].half() 119 | data['inst'] = data['inst'].half() 120 | elif opt.data_type == 8: 121 | data['label'] = data['label'].uint8() 122 | data['inst'] = data['inst'].uint8() 123 | if opt.export_onnx: 124 | print ("Exporting to ONNX: ", opt.export_onnx) 125 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 126 | torch.onnx.export(model, [data['label'], data['inst']], 127 | opt.export_onnx, verbose=True) 128 | exit(0) 129 | minibatch = 1 130 | if opt.engine: 131 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 132 | elif opt.onnx: 133 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 134 | elif opt.conditioned: 135 | generated = model.inference_conditioned(data['label'], data['inst'], con) 136 | else: 137 | generated = model.inference(data['label'], data['inst']) 138 | 139 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 140 | ('synthesized_image', util.tensor2im(generated.data[0]))]) 141 | img_path = data['path'] 142 | print('process image... %s' % img_path) 143 | visualizer.save_images(webpage, visuals, img_path) 144 | 145 | webpage.save() 146 | -------------------------------------------------------------------------------- /test_delta.txt: -------------------------------------------------------------------------------- 1 | 0 2 | -------------------------------------------------------------------------------- /test_face.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateFaceConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import random 13 | import torch 14 | 15 | opt = TestOptions().parse(save=False) 16 | opt.nThreads = 1 # test code only supports nThreads = 1 17 | opt.batchSize = 1 # test code only supports batchSize = 1 18 | opt.serial_batches = True # no shuffle 19 | opt.no_flip = True # no flip 20 | 21 | data_loader = CreateFaceConDataLoader(opt) 22 | dataset = data_loader.load_data() 23 | visualizer = Visualizer(opt) 24 | # create website 25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches))) 26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 27 | 28 | # test 29 | if not opt.engine and not opt.onnx: 30 | model = create_model(opt) 31 | if opt.data_type == 16: 32 | model.half() 33 | elif opt.data_type == 8: 34 | model.type(torch.uint8) 35 | 36 | if opt.verbose: 37 | print(model) 38 | else: 39 | from run_engine import run_trt_engine, run_onnx 40 | 41 | tot = 0 42 | for i, data in enumerate(dataset, start=random.randint(0, len(dataset) - opt.how_many)): 43 | tot += 1 44 | if tot > opt.how_many: 45 | break 46 | if opt.data_type == 16: 47 | data['A'] = data['A'].half() 48 | data['A2'] = data['A2'].half() 49 | data['B'] = data['B'].half() 50 | data['B2'] = data['B2'].half() 51 | elif opt.data_type == 8: 52 | data['A'] = data['A'].uint8() 53 | data['A2'] = data['A2'].uint8() 54 | data['B'] = data['B'].uint8() 55 | data['B2'] = data['B2'].uint8() 56 | if opt.export_onnx: 57 | print ("Exporting to ONNX: ", opt.export_onnx) 58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 59 | torch.onnx.export(model, [data['label'], data['inst']], 60 | opt.export_onnx, verbose=True) 61 | exit(0) 62 | minibatch = 1 63 | if opt.engine: 64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 65 | elif opt.onnx: 66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 67 | else: 68 | generated = model.inference(data['A'], data['B'], data['B2']) 69 | 70 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)), 71 | ('real_image', util.tensor2im(data['A2'][0])), 72 | ('synthesized_image', util.tensor2im(generated.data[0])), 73 | ('B', util.tensor2label(data['B'][0], 0)), 74 | ('B2', util.tensor2im(data['B2'][0]))]) 75 | img_path = data['path'] 76 | img_path[0] = str(i) 77 | print('process image... %s' % img_path) 78 | visualizer.save_images(webpage, visuals, img_path) 79 | 80 | webpage.save() 81 | -------------------------------------------------------------------------------- /test_mface.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
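Editor's note on test_face.py above (test_seg.py further down uses the same pattern): enumerate(dataset, start=random.randint(...)) only offsets the counter i, it does not skip ahead in the data loader, so iteration still begins at the first batch and only the index used to name the saved results is shifted. A tiny demonstration, plus one way the loop could genuinely start at a random offset (illustrative only):

data = ['a', 'b', 'c']
print(list(enumerate(data, start=10)))    # [(10, 'a'), (11, 'b'), (12, 'c')] -- same items, shifted index

import itertools, random
offset = random.randint(0, len(data) - 1)
for i, item in enumerate(itertools.islice(data, offset, None), start=offset):
    print(i, item)                         # iteration now really begins at `offset`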
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateFaceConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import torch 13 | 14 | opt = TestOptions().parse(save=False) 15 | opt.nThreads = 1 # test code only supports nThreads = 1 16 | opt.batchSize = 1 # test code only supports batchSize = 1 17 | opt.serial_batches = False #True # no shuffle 18 | opt.no_flip = True # no flip 19 | 20 | data_loader = CreateFaceConDataLoader(opt) 21 | dataset = data_loader.load_data() 22 | visualizer = Visualizer(opt) 23 | # create website 24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 26 | 27 | # test 28 | model = create_model(opt) 29 | ''' 30 | if not opt.engine and not opt.onnx: 31 | model = create_model(opt) 32 | if opt.data_type == 16: 33 | model.half() 34 | elif opt.data_type == 8: 35 | model.type(torch.uint8) 36 | 37 | if opt.verbose: 38 | print(model) 39 | else: 40 | from run_engine import run_trt_engine, run_onnx 41 | ''' 42 | for i, data in enumerate(dataset): 43 | if i >= opt.how_many: 44 | break 45 | ''' 46 | if opt.data_type == 16: 47 | data['A'] = data['A'].half() 48 | data['A2'] = data['A2'].half() 49 | data['B'] = data['B'].half() 50 | data['B2'] = data['B2'].half() 51 | elif opt.data_type == 8: 52 | data['A'] = data['A'].uint8() 53 | data['A2'] = data['A2'].uint8() 54 | data['B'] = data['B'].uint8() 55 | data['B2'] = data['B2'].uint8() 56 | ''' 57 | if opt.export_onnx: 58 | print ("Exporting to ONNX: ", opt.export_onnx) 59 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 60 | torch.onnx.export(model, [data['label'], data['inst']], 61 | opt.export_onnx, verbose=True) 62 | exit(0) 63 | minibatch = 1 64 | if opt.engine: 65 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 66 | elif opt.onnx: 67 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 68 | else: 69 | generated1, mask, generated = model.inference(Variable(data['A']), Variable(data['B']), Variable(data['B2'])) 70 | 71 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)), 72 | ('real_image', util.tensor2im(data['A2'][0])), 73 | ('synthesized_image_1', util.tensor2im(generated1.data[0])), 74 | ('mask', util.tensor2im(mask.data[0])), 75 | ('synthesized_image', util.tensor2im(generated.data[0])), 76 | ('B', util.tensor2label(data['B'][0], 0)), 77 | ('B2', util.tensor2im(data['B2'][0]))]) 78 | img_path = data['path'] 79 | print('process image... %s' % img_path) 80 | visualizer.save_images(webpage, visuals, img_path) 81 | 82 | webpage.save() 83 | -------------------------------------------------------------------------------- /test_pose.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreatePoseConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import random 13 | import torch 14 | 15 | opt = TestOptions().parse(save=False) 16 | opt.nThreads = 1 # test code only supports nThreads = 1 17 | opt.batchSize = 1 # test code only supports batchSize = 1 18 | opt.serial_batches = True # no shuffle 19 | opt.no_flip = True # no flip 20 | 21 | data_loader = CreatePoseConDataLoader(opt) 22 | dataset = data_loader.load_data() 23 | visualizer = Visualizer(opt) 24 | # create website 25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches))) 26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 27 | 28 | # test 29 | if not opt.engine and not opt.onnx: 30 | model = create_model(opt) 31 | if opt.data_type == 16: 32 | model.half() 33 | elif opt.data_type == 8: 34 | model.type(torch.uint8) 35 | 36 | if opt.verbose: 37 | print(model) 38 | else: 39 | from run_engine import run_trt_engine, run_onnx 40 | 41 | tot = 0 42 | for i, data in enumerate(dataset, start=0): 43 | tot += 1 44 | if tot > opt.how_many: 45 | break 46 | if opt.data_type == 16: 47 | data['A'] = data['A'].half() 48 | data['A2'] = data['A2'].half() 49 | data['B'] = data['B'].half() 50 | data['B2'] = data['B2'].half() 51 | elif opt.data_type == 8: 52 | data['A'] = data['A'].uint8() 53 | data['A2'] = data['A2'].uint8() 54 | data['B'] = data['B'].uint8() 55 | data['B2'] = data['B2'].uint8() 56 | if opt.export_onnx: 57 | print ("Exporting to ONNX: ", opt.export_onnx) 58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 59 | torch.onnx.export(model, [data['label'], data['inst']], 60 | opt.export_onnx, verbose=True) 61 | exit(0) 62 | minibatch = 1 63 | if opt.engine: 64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 65 | elif opt.onnx: 66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 67 | else: 68 | generated = model.inference(data['A'], data['B'], data['B2']) 69 | 70 | visuals = OrderedDict([('input_label', util.tensor2im(data['A'][0])), 71 | ('real_image', util.tensor2im(data['A2'][0])), 72 | ('synthesized_image', util.tensor2im(generated.data[0])), 73 | ('B', util.tensor2im(data['B'][0])), 74 | ('B2', util.tensor2im(data['B2'][0]))]) 75 | img_path = data['path'] 76 | img_path = [str(tot-1)] 77 | print('process image... %s' % img_path) 78 | visualizer.save_images(webpage, visuals, img_path) 79 | 80 | webpage.save() 81 | -------------------------------------------------------------------------------- /test_seg.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from options.test_options import TestOptions 7 | from data.data_loader import CreateConDataLoader 8 | from models.models import create_model 9 | import util.util as util 10 | from util.visualizer import Visualizer 11 | from util import html 12 | import random 13 | import torch 14 | 15 | opt = TestOptions().parse(save=False) 16 | opt.nThreads = 1 # test code only supports nThreads = 1 17 | opt.batchSize = 1 # test code only supports batchSize = 1 18 | opt.serial_batches = True # no shuffle 19 | opt.no_flip = True # no flip 20 | 21 | data_loader = CreateConDataLoader(opt) 22 | dataset = data_loader.load_data() 23 | visualizer = Visualizer(opt) 24 | # create website 25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches))) 26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 27 | 28 | # test 29 | if not opt.engine and not opt.onnx: 30 | model = create_model(opt) 31 | if opt.data_type == 16: 32 | model.half() 33 | elif opt.data_type == 8: 34 | model.type(torch.uint8) 35 | 36 | if opt.verbose: 37 | print(model) 38 | else: 39 | from run_engine import run_trt_engine, run_onnx 40 | 41 | tot = 0 42 | for i, data in enumerate(dataset, start=random.randint(0, len(dataset) - opt.how_many)): 43 | tot += 1 44 | if tot > opt.how_many: 45 | break 46 | if opt.data_type == 16: 47 | data['A'] = data['A'].half() 48 | data['A2'] = data['A2'].half() 49 | data['B'] = data['B'].half() 50 | data['B2'] = data['B2'].half() 51 | elif opt.data_type == 8: 52 | data['A'] = data['A'].uint8() 53 | data['A2'] = data['A2'].uint8() 54 | data['B'] = data['B'].uint8() 55 | data['B2'] = data['B2'].uint8() 56 | if opt.export_onnx: 57 | print ("Exporting to ONNX: ", opt.export_onnx) 58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 59 | torch.onnx.export(model, [data['label'], data['inst']], 60 | opt.export_onnx, verbose=True) 61 | exit(0) 62 | minibatch = 1 63 | if opt.engine: 64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 65 | elif opt.onnx: 66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 67 | else: 68 | generated = model.inference(data['A'], data['B'], data['B2']) 69 | 70 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)), 71 | ('real_image', util.tensor2im(data['A2'][0])), 72 | ('synthesized_image', util.tensor2im(generated.data[0])), 73 | ('B', util.tensor2label(data['B'][0], 0)), 74 | ('B2', util.tensor2im(data['B2'][0]))]) 75 | img_path = data['path'] 76 | img_path[0] = str(i) 77 | print('process image... %s' % img_path) 78 | visualizer.save_images(webpage, visuals, img_path) 79 | 80 | webpage.save() 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
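Editor's note: train.py (whose listing begins just above) and train_con.py resume training by reading (start_epoch, epoch_iter) back from iter.txt with np.loadtxt; the file is just two integers. A minimal round-trip sketch mirroring that bookkeeping (the path and the values 12 / 3400 are illustrative):

import numpy as np

iter_path = 'iter.txt'   # in the repo this lives at [checkpoints_dir]/[name]/iter.txt
np.savetxt(iter_path, (12, 3400), delimiter=',', fmt='%d')                   # written whenever 'latest' is saved
start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)    # restored under --continue_train
print(start_epoch, epoch_iter)   # 12 3400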
3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreateDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | 15 | opt = TrainOptions().parse() 16 | print(opt.gpu_ids) 17 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 18 | if opt.continue_train: 19 | try: 20 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) 21 | except: 22 | start_epoch, epoch_iter = 1, 0 23 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 24 | else: 25 | start_epoch, epoch_iter = 1, 0 26 | 27 | if opt.debug: 28 | opt.display_freq = 1 29 | opt.print_freq = 1 30 | opt.niter = 1 31 | opt.niter_decay = 0 32 | opt.max_dataset_size = 10 33 | 34 | data_loader = CreateDataLoader(opt) 35 | dataset = data_loader.load_data() 36 | dataset_size = len(data_loader) 37 | print('#training images = %d' % dataset_size) 38 | 39 | model = create_model(opt) 40 | visualizer = Visualizer(opt) 41 | 42 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 43 | 44 | display_delta = total_steps % opt.display_freq 45 | print_delta = total_steps % opt.print_freq 46 | save_delta = total_steps % opt.save_latest_freq 47 | 48 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 49 | epoch_start_time = time.time() 50 | if epoch != start_epoch: 51 | epoch_iter = epoch_iter % dataset_size 52 | for i, data in enumerate(dataset, start=epoch_iter): 53 | iter_start_time = time.time() 54 | total_steps += opt.batchSize 55 | epoch_iter += opt.batchSize 56 | 57 | # whether to collect output images 58 | save_fake = total_steps % opt.display_freq == display_delta 59 | 60 | ############## Forward Pass ###################### 61 | losses, generated = model(Variable(data['label']), Variable(data['inst']), 62 | Variable(data['image']), Variable(data['feat']), infer=save_fake) 63 | 64 | # sum per device losses 65 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] 66 | loss_dict = dict(zip(model.module.loss_names, losses)) 67 | 68 | # calculate final loss scalar 69 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 70 | loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) 71 | 72 | ############### Backward Pass #################### 73 | # update generator weights 74 | model.module.optimizer_G.zero_grad() 75 | loss_G.backward() 76 | model.module.optimizer_G.step() 77 | 78 | # update discriminator weights 79 | model.module.optimizer_D.zero_grad() 80 | loss_D.backward() 81 | model.module.optimizer_D.step() 82 | 83 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 84 | 85 | ############## Display results and errors ########## 86 | ### print out errors 87 | if total_steps % opt.print_freq == print_delta: 88 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()} 89 | t = (time.time() - iter_start_time) / opt.batchSize 90 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 91 | visualizer.plot_current_errors(errors, total_steps) 92 | 93 | ### display output images 94 | if save_fake: 95 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 96 | ('synthesized_image', util.tensor2im(generated.data[0])), 97 | ('real_image', 
util.tensor2im(data['image'][0]))]) 98 | visualizer.display_current_results(visuals, epoch, total_steps) 99 | 100 | ### save latest model 101 | if total_steps % opt.save_latest_freq == save_delta: 102 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 103 | model.module.save('latest') 104 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 105 | 106 | if epoch_iter >= dataset_size: 107 | break 108 | 109 | # end of epoch 110 | iter_end_time = time.time() 111 | print('End of epoch %d / %d \t Time Taken: %d sec' % 112 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 113 | 114 | ### save model for this epoch 115 | if epoch % opt.save_epoch_freq == 0: 116 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 117 | model.module.save('latest') 118 | model.module.save(epoch) 119 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 120 | 121 | ### instead of only training the local enhancer, train the entire network after certain iterations 122 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 123 | model.module.update_fixed_params() 124 | 125 | ### linearly decay learning rate after certain iterations 126 | if epoch > opt.niter: 127 | model.module.update_learning_rate() 128 | -------------------------------------------------------------------------------- /train_con.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreateConDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | 15 | opt = TrainOptions().parse() 16 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 17 | if opt.continue_train: 18 | try: 19 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) 20 | except: 21 | start_epoch, epoch_iter = 1, 0 22 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 23 | else: 24 | start_epoch, epoch_iter = 1, 0 25 | 26 | if opt.debug: 27 | opt.display_freq = 1 28 | opt.print_freq = 1 29 | opt.niter = 1 30 | opt.niter_decay = 0 31 | opt.max_dataset_size = 10 32 | 33 | data_loader = CreateConDataLoader(opt) 34 | dataset = data_loader.load_data() 35 | dataset_size = len(data_loader) 36 | print('#training images = %d' % dataset_size) 37 | 38 | model = create_model(opt) 39 | visualizer = Visualizer(opt) 40 | 41 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 42 | 43 | display_delta = total_steps % opt.display_freq 44 | print_delta = total_steps % opt.print_freq 45 | save_delta = total_steps % opt.save_latest_freq 46 | 47 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 48 | epoch_start_time = time.time() 49 | if epoch != start_epoch: 50 | epoch_iter = epoch_iter % dataset_size 51 | for i, data in enumerate(dataset, start=epoch_iter): 52 | iter_start_time = time.time() 53 | total_steps += opt.batchSize 54 | epoch_iter += opt.batchSize 55 | 56 | # whether to collect output images 57 | save_fake = total_steps % opt.display_freq == display_delta 58 | 59 | 
############## Forward Pass ###################### 60 | losses, generated = model( 61 | Variable(data['A']), Variable(data['A2']), 62 | Variable(data['B']), Variable(data['B2']), 63 | Variable(data['C']), Variable(data['C2']), 64 | Variable(data['D']), Variable(data['D2']), infer=save_fake) 65 | 66 | # sum per device losses 67 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] 68 | loss_dict = dict(zip(model.module.loss_names, losses)) 69 | 70 | if (total_steps <= opt.use_style_iter): 71 | use_style = 0 72 | else: 73 | use_style = 1 74 | loss_dict['G_GAN_style'] *= 10 75 | # calculate final loss scalar 76 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 77 | loss_SD = (loss_dict['SD_fake1'] + loss_dict['SD_fake2'] + loss_dict['SD_real']) * 0.5 78 | loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_style'] * use_style + (loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)) 79 | 80 | ############### Backward Pass #################### 81 | # update generator weights 82 | model.module.optimizer_G.zero_grad() 83 | loss_G.backward() 84 | model.module.optimizer_G.step() 85 | 86 | # update discriminator weights 87 | model.module.optimizer_D.zero_grad() 88 | loss_D.backward() 89 | model.module.optimizer_D.step() 90 | 91 | # update discriminator weights 92 | if (total_steps > opt.use_style_iter): 93 | model.module.optimizer_SD.zero_grad() 94 | loss_SD.backward() 95 | model.module.optimizer_SD.step() 96 | 97 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 98 | 99 | ############## Display results and errors ########## 100 | ### print out errors 101 | if total_steps % opt.print_freq == print_delta: 102 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()} 103 | errors['loss_G'] = loss_G 104 | errors['loss_D'] = loss_D 105 | errors['loss_SD'] = loss_SD 106 | t = (time.time() - iter_start_time) / opt.batchSize 107 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 108 | visualizer.plot_current_errors(errors, total_steps) 109 | 110 | 111 | ### display output images 112 | if save_fake: 113 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)), 114 | ('real_image', util.tensor2im(data['A2'][0])), 115 | ('synthesized_image', util.tensor2im(generated.data[0])), 116 | ('B', util.tensor2label(data['B'][0], 256)), 117 | ('B2', util.tensor2im(data['B2'][0]))]) 118 | visualizer.display_current_results2(visuals, epoch, total_steps) 119 | 120 | ### save latest model 121 | if total_steps % opt.save_latest_freq == save_delta: 122 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 123 | model.module.save('latest') 124 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 125 | 126 | if opt.use_iter_decay and total_steps > opt.niter_iter: 127 | for temp in range(opt.batchSize): 128 | model.module.update_learning_rate() 129 | 130 | if opt.use_iter_decay and total_steps >= opt.niter_iter + opt.niter_decay_iter: 131 | break 132 | if epoch_iter >= dataset_size: 133 | break 134 | 135 | # end of epoch 136 | iter_end_time = time.time() 137 | print('End of epoch %d / %d \t Time Taken: %d sec' % 138 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 139 | 140 | ### save model for this epoch 141 | if epoch % opt.save_epoch_freq == 0: 142 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 143 | model.module.save('latest') 144 | model.module.save(epoch) 145 | np.savetxt(iter_path, 
(epoch+1, 0), delimiter=',', fmt='%d') 146 | 147 | ### instead of only training the local enhancer, train the entire network after certain iterations 148 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 149 | model.module.update_fixed_params() 150 | 151 | ### linearly decay learning rate after certain iterations 152 | if epoch > opt.niter and not opt.use_iter_decay: 153 | model.module.update_learning_rate() 154 | 155 | if opt.use_iter_decay and total_steps > opt.niter_iter + opt.niter_decay_iter: 156 | break 157 | 158 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | 45 | def add_images2(self, ims, links, txts, disp_title, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 50 | with p(): 51 | t = txts[0] 52 | while (len(t) < 10): 53 | t = '_' + t 54 | p(t) 55 | for im, txt, link in zip(ims, txts, links): 56 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 57 | with p(): 58 | if (disp_title): 59 | p(txt) 60 | with a(href=os.path.join('images', link)): 61 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 62 | 63 | def save(self): 64 | html_file = '%s/index.html' % self.web_dir 65 | f = open(html_file, 'wt') 66 | f.write(self.doc.render()) 67 | f.close() 68 | 69 | 70 | if __name__ == '__main__': 71 | html = HTML('web/', 'test_html') 72 | html.add_header('hello world') 73 | 74 | ims = [] 75 | txts = [] 76 | links = [] 77 | for n in range(4): 78 | ims.append('image_%d.jpg' % n) 79 | txts.append('text_%d' % n) 80 | links.append('image_%d.jpg' % n) 81 | html.add_images(ims, txts, links) 82 | html.save() 83 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from 
torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /util/label.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | # a label and all meta information 4 | # Code inspired by Cityscapes https://github.com/mcordts/cityscapesScripts 5 | Label = namedtuple('Label', [ 6 | 7 | 'name', # The identifier of this label, e.g. 'car', 'person', ... . 8 | # We use them to uniquely name a class 9 | 10 | 'id', # An integer ID that is associated with this label. 11 | # The IDs are used to represent the label in ground truth images 12 | # An ID of -1 means that this label does not have an ID and thus 13 | # is ignored when creating ground truth images (e.g. license plate). 14 | # Do not modify these IDs, since exactly these IDs are expected by the 15 | # evaluation server. 16 | 17 | 'trainId', 18 | # Feel free to modify these IDs as suitable for your method. Then create 19 | # ground truth images with train IDs, using the tools provided in the 20 | # 'preparation' folder. However, make sure to validate or submit results 21 | # to our evaluation server using the regular IDs above! 22 | # For trainIds, multiple labels might have the same ID. Then, these labels 23 | # are mapped to the same class in the ground truth images. For the inverse 24 | # mapping, we use the label that is defined first in the list below. 25 | # For example, mapping all void-type classes to the same ID in training, 26 | # might make sense for some approaches. 27 | # Max value is 255! 28 | 29 | 'category', # The name of the category that this label belongs to 30 | 31 | 'categoryId', 32 | # The ID of this category. Used to create ground truth images 33 | # on category level. 34 | 35 | 'hasInstances', 36 | # Whether this label distinguishes between single instances or not 37 | 38 | 'ignoreInEval', 39 | # Whether pixels having this class as ground truth label are ignored 40 | # during evaluations or not 41 | 42 | 'color', # The color of this label 43 | ]) 44 | 45 | 46 | # Our extended list of label types. 
Our train id is compatible with Cityscapes 47 | labels = [ 48 | # name id trainId category catId hasInstances ignoreInEval color 49 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 50 | Label( 'dynamic' , 1 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 51 | Label( 'ego vehicle' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 52 | Label( 'ground' , 3 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 53 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 54 | Label( 'parking' , 5 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 55 | Label( 'rail track' , 6 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 56 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 57 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 58 | Label( 'bridge' , 9 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 59 | Label( 'building' , 10 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 60 | Label( 'fence' , 11 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 61 | Label( 'garage' , 12 , 255 , 'construction' , 2 , False , True , (180,100,180) ), 62 | Label( 'guard rail' , 13 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 63 | Label( 'tunnel' , 14 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 64 | Label( 'wall' , 15 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 65 | Label( 'banner' , 16 , 255 , 'object' , 3 , False , True , (250,170,100) ), 66 | Label( 'billboard' , 17 , 255 , 'object' , 3 , False , True , (220,220,250) ), 67 | Label( 'lane divider' , 18 , 255 , 'object' , 3 , False , True , (255, 165, 0) ), 68 | Label( 'parking sign' , 19 , 255 , 'object' , 3 , False , False , (220, 20, 60) ), 69 | Label( 'pole' , 20 , 5 , 'object' , 3 , False , False , (153,153,153) ), 70 | Label( 'polegroup' , 21 , 255 , 'object' , 3 , False , True , (153,153,153) ), 71 | Label( 'street light' , 22 , 255 , 'object' , 3 , False , True , (220,220,100) ), 72 | Label( 'traffic cone' , 23 , 255 , 'object' , 3 , False , True , (255, 70, 0) ), 73 | Label( 'traffic device' , 24 , 255 , 'object' , 3 , False , True , (220,220,220) ), 74 | Label( 'traffic light' , 25 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 75 | Label( 'traffic sign' , 26 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 76 | Label( 'traffic sign frame' , 27 , 255 , 'object' , 3 , False , True , (250,170,250) ), 77 | Label( 'terrain' , 28 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 78 | Label( 'vegetation' , 29 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 79 | Label( 'sky' , 30 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 80 | Label( 'person' , 31 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 81 | Label( 'rider' , 32 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 82 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 83 | Label( 'bus' , 34 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 84 | Label( 'car' , 35 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 85 | Label( 'caravan' , 36 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 86 | Label( 'motorcycle' , 37 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 87 | Label( 'trailer' , 38 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 88 | Label( 'train' , 39 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 89 | Label( 'truck' , 40 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 90 | ] 91 | 
-------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import os 7 | from util.label import labels 8 | 9 | def get_big_img(ans): 10 | ret = [] 11 | for a in ans: 12 | temp = [] 13 | for b in a: 14 | if (len(b.shape) == 2): 15 | t = b[np.newaxis, :, :] 16 | c = np.concatenate((t,t,t), axis=0) 17 | else: 18 | c = b 19 | temp.append(c) 20 | t = np.concatenate(temp, axis=1) 21 | ret.append(t) 22 | final = np.concatenate(ret, axis=2) 23 | print(final.shape) 24 | return final 25 | 26 | # Converts a Tensor into a Numpy array 27 | # |imtype|: the desired type of the converted numpy array 28 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 29 | if isinstance(image_tensor, list): 30 | image_numpy = [] 31 | for i in range(len(image_tensor)): 32 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 33 | return image_numpy 34 | image_numpy = image_tensor.cpu().float().numpy() 35 | if normalize: 36 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 37 | else: 38 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 39 | image_numpy = np.clip(image_numpy, 0, 255) 40 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 41 | image_numpy = image_numpy[:,:,0] 42 | return image_numpy.astype(imtype) 43 | 44 | # Converts a Tensor into a Numpy array 45 | # |imtype|: the desired type of the converted numpy array 46 | def tensor2im2(image_tensor, imtype=np.uint8, normalize=True): 47 | if isinstance(image_tensor, list): 48 | image_numpy = [] 49 | for i in range(len(image_tensor)): 50 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 51 | return image_numpy 52 | image_numpy = image_tensor.cpu().float().numpy() 53 | if normalize: 54 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 55 | else: 56 | image_numpy = image_numpy * 255.0 57 | image_numpy = np.clip(image_numpy, 0, 255) 58 | if image_numpy.shape[0] == 23: 59 | image_numpy = np.concatenate([image_numpy[3:6, :, :], image_numpy[0:3, :, :]], axis=1) 60 | if image_numpy.shape[0] == 1 or image_numpy.shape[0] > 3: 61 | image_numpy = image_numpy[0, :,:] 62 | return image_numpy.astype(imtype) 63 | 64 | # Converts a one-hot tensor into a colorful label map 65 | def tensor2label2(label_tensor, n_label, imtype=np.uint8): 66 | if n_label == 0: 67 | return label_tensor.cpu().float().numpy().astype(imtype) 68 | return tensor2im2(label_tensor, imtype) # early return: the colorization code below is never reached 69 | label_tensor = label_tensor.cpu().float() 70 | if label_tensor.size()[0] > 1: 71 | label_tensor = label_tensor.max(0, keepdim=True)[1] 72 | label_tensor = Colorize(n_label)(label_tensor) 73 | #label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 74 | #print(label_numpy.shape) 75 | return label_tensor.numpy().astype(imtype) 76 | 77 | # Converts a one-hot tensor into a colorful label map 78 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 79 | if n_label == 0: 80 | return tensor2im(label_tensor, imtype) 81 | label_tensor = label_tensor.cpu().float() 82 | if label_tensor.size()[0] > 1: 83 | label_tensor = label_tensor.max(0, keepdim=True)[1] 84 | label_tensor = Colorize(n_label)(label_tensor) 85 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 86 | #print(label_numpy.shape) 87 | return label_numpy.astype(imtype) 88 | 89 | def save_image(image_numpy, image_path): 
90 | image_pil = Image.fromarray(image_numpy) 91 | image_pil.save(image_path) 92 | 93 | def mkdirs(paths): 94 | if isinstance(paths, list) and not isinstance(paths, str): 95 | for path in paths: 96 | mkdir(path) 97 | else: 98 | mkdir(paths) 99 | 100 | def mkdir(path): 101 | if not os.path.exists(path): 102 | os.makedirs(path) 103 | 104 | ############################################################################### 105 | # Code from 106 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 107 | # Modified so it complies with the Citscape label map colors 108 | ############################################################################### 109 | def uint82bin(n, count=8): 110 | """returns the binary of integer n, count refers to amount of bits""" 111 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 112 | 113 | def labelcolormap(N): 114 | if N == 35: # cityscape 115 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 116 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 117 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 118 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 119 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 120 | dtype=np.uint8) 121 | elif N == 256: 122 | t = [] 123 | for i in range(256): 124 | t.append((0, 0, 0)) 125 | for i in range(len(labels)): 126 | t[labels[i].trainId] = labels[i].color 127 | cmap = np.array(t, dtype=np.uint8) 128 | else: 129 | cmap = np.zeros((N, 3), dtype=np.uint8) 130 | for i in range(N): 131 | r, g, b = 0, 0, 0 132 | id = i 133 | for j in range(7): 134 | str_id = uint82bin(id) 135 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 136 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 137 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 138 | id = id >> 3 139 | cmap[i, 0] = r 140 | cmap[i, 1] = g 141 | cmap[i, 2] = b 142 | return cmap 143 | 144 | class Colorize(object): 145 | def __init__(self, n=35): 146 | self.cmap = labelcolormap(n) 147 | self.cmap = torch.from_numpy(self.cmap[:n]) 148 | 149 | def __call__(self, gray_image): 150 | size = gray_image.size() 151 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 152 | 153 | for label in range(0, len(self.cmap)): 154 | mask = (label == gray_image[0]).cpu() 155 | color_image[0][mask] = self.cmap[label][0] 156 | color_image[1][mask] = self.cmap[label][1] 157 | color_image[2][mask] = self.cmap[label][2] 158 | 159 | return color_image 160 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import numpy as np 4 | import os 5 | import ntpath 6 | import time 7 | from . import util 8 | from . 
import html 9 | import scipy.misc 10 | try: 11 | from StringIO import StringIO # Python 2.7 12 | except ImportError: 13 | from io import BytesIO # Python 3.x 14 | 15 | class Visualizer(): 16 | def __init__(self, opt): 17 | # self.opt = opt 18 | self.tf_log = opt.tf_log 19 | self.use_html = opt.isTrain and not opt.no_html 20 | self.win_size = opt.display_winsize 21 | self.name = opt.name 22 | if self.tf_log: 23 | import tensorflow as tf 24 | self.tf = tf 25 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 26 | self.writer = tf.summary.FileWriter(self.log_dir) 27 | 28 | if self.use_html: 29 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 30 | self.img_dir = os.path.join(self.web_dir, 'images') 31 | print('create web directory %s...' % self.web_dir) 32 | util.mkdirs([self.web_dir, self.img_dir]) 33 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 34 | with open(self.log_name, "a") as log_file: 35 | now = time.strftime("%c") 36 | log_file.write('================ Training Loss (%s) ================\n' % now) 37 | 38 | def get_names(self, path): 39 | names = [] 40 | for root, _, fnames in os.walk(path): 41 | for fname in fnames: 42 | t = fname.find('input_label') 43 | if (t != -1): 44 | names.append([fname[:t], fname[t + 12:-4]]) 45 | names = sorted(names, key=lambda x:-int(x[1])) 46 | return names 47 | 48 | # |visuals|: dictionary of images to display or save 49 | def display_current_results2(self, visuals, epoch, step): 50 | if self.tf_log: # show images in tensorboard output 51 | img_summaries = [] 52 | for label, image_numpy in visuals.items(): 53 | # Write the image to a string 54 | try: 55 | s = StringIO() 56 | except: 57 | s = BytesIO() 58 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 59 | # Create an Image object 60 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 61 | # Create a Summary value 62 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 63 | 64 | # Create and write Summary 65 | summary = self.tf.Summary(value=img_summaries) 66 | self.writer.add_summary(summary, step) 67 | 68 | if self.use_html: # save images to a html file 69 | for label, image_numpy in visuals.items(): 70 | if isinstance(image_numpy, list): 71 | for i in range(len(image_numpy)): 72 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) 73 | util.save_image(image_numpy[i], img_path) 74 | else: 75 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%.3d.jpg' % (epoch, label, step)) 76 | util.save_image(image_numpy, img_path) 77 | 78 | names = self.get_names(self.web_dir) 79 | # update website 80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 81 | for name in names: 82 | webpage.add_header('%s %s' % (name[0], name[1])) 83 | ims = [] 84 | txts = [] 85 | links = [] 86 | 87 | for label, image_numpy in visuals.items(): 88 | img_path = '%s%s_%s.jpg' % (name[0], label, name[1]) 89 | ims.append(img_path) 90 | txts.append(label) 91 | links.append(img_path) 92 | if len(ims) < 10: 93 | webpage.add_images(ims, txts, links, width=self.win_size) 94 | else: 95 | num = int(round(len(ims)/2.0)) 96 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 97 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 98 | webpage.save() 99 | 100 | 101 | # |visuals|: dictionary of images to display or save 102 | def display_current_results(self, 
visuals, epoch, step): 103 | if self.tf_log: # show images in tensorboard output 104 | img_summaries = [] 105 | for label, image_numpy in visuals.items(): 106 | # Write the image to a string 107 | try: 108 | s = StringIO() 109 | except: 110 | s = BytesIO() 111 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 112 | # Create an Image object 113 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 114 | # Create a Summary value 115 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 116 | 117 | # Create and write Summary 118 | summary = self.tf.Summary(value=img_summaries) 119 | self.writer.add_summary(summary, step) 120 | 121 | if self.use_html: # save images to a html file 122 | for label, image_numpy in visuals.items(): 123 | if isinstance(image_numpy, list): 124 | for i in range(len(image_numpy)): 125 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) 126 | util.save_image(image_numpy[i], img_path) 127 | else: 128 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%.3d.jpg' % (epoch, label, step)) 129 | util.save_image(image_numpy, img_path) 130 | 131 | # update website 132 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 133 | for n in range(epoch, 0, -1): 134 | webpage.add_header('epoch [%d]' % n) 135 | ims = [] 136 | txts = [] 137 | links = [] 138 | 139 | for label, image_numpy in visuals.items(): 140 | if isinstance(image_numpy, list): 141 | for i in range(len(image_numpy)): 142 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) 143 | ims.append(img_path) 144 | txts.append(label+str(i)) 145 | links.append(img_path) 146 | else: 147 | img_path = 'epoch%.3d_%s_%.3d.jpg' % (n, label, step) 148 | ims.append(img_path) 149 | txts.append(label) 150 | links.append(img_path) 151 | if len(ims) < 10: 152 | webpage.add_images(ims, txts, links, width=self.win_size) 153 | else: 154 | num = int(round(len(ims)/2.0)) 155 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 156 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 157 | webpage.save() 158 | 159 | # errors: dictionary of error labels and values 160 | def plot_current_errors(self, errors, step): 161 | if self.tf_log: 162 | for tag, value in errors.items(): 163 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 164 | self.writer.add_summary(summary, step) 165 | 166 | # errors: same format as |errors| of plotCurrentErrors 167 | def print_current_errors(self, epoch, i, errors, t): 168 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 169 | temp = errors.items() 170 | temp = sorted(temp, key=lambda x:x[0]) 171 | for k, v in temp: 172 | if v != 0: 173 | message += '%s: %.3f ' % (k, v) 174 | 175 | print(message) 176 | with open(self.log_name, "a") as log_file: 177 | log_file.write('%s\n' % message) 178 | 179 | # save image to the disk 180 | def save_images(self, webpage, visuals, image_path): 181 | image_dir = webpage.get_image_dir() 182 | short_path = ntpath.basename(image_path[0]) 183 | name = os.path.splitext(short_path)[0] 184 | 185 | webpage.add_header(name) 186 | ims = [] 187 | txts = [] 188 | links = [] 189 | 190 | for label, image_numpy in visuals.items(): 191 | image_name = '%s_%s.jpg' % (name, label) 192 | save_path = os.path.join(image_dir, image_name) 193 | util.save_image(image_numpy, save_path) 194 | 195 | ims.append(image_name) 196 | txts.append(label) 197 | 
links.append(image_name) 198 | webpage.add_images(ims, txts, links, width=self.win_size) 199 | # save images to disk and add them to the webpage 200 | def save_images2(self, webpage, visuals, image_path, disp_title=False): 201 | image_dir = webpage.get_image_dir() 202 | short_path = ntpath.basename(image_path[0]) 203 | name = os.path.splitext(short_path)[0] 204 | 205 | ims = [] 206 | links = [] 207 | txts = [] 208 | 209 | for label, image_numpy in visuals.items(): 210 | image_name = '%s_%s.jpg' % (name, label) 211 | save_path = os.path.join(image_dir, image_name) 212 | util.save_image(image_numpy, save_path) 213 | 214 | ims.append(image_name) 215 | links.append(image_name) 216 | txts.append(label) 217 | webpage.add_images2(ims, links, txts, disp_title, width=self.win_size) 218 | -------------------------------------------------------------------------------- /val_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | A_list = ['__Record024__Camera_5__170927_071032417_Camera_5'] 3 | A_list = ['_a'] # NOTE: this overrides the A_list defined on the previous line 4 | B_list = ['__Record024__Camera_6__170927_070953498_Camera_6', '__Record031__Camera_5__170927_071846100_Camera_5', '__Record035__Camera_5__170927_072446711_Camera_5'] 5 | data_path = './datasets/apollo/' 6 | 7 | def label(A): 8 | return os.path.join(data_path, 'label', 'Label' + A + '_bin.png') 9 | def img(A): 10 | return os.path.join(data_path, 'img', 'ColorImage' + A + '.jpg') 11 | 12 | lines = [] 13 | for A in A_list: 14 | for B in B_list: 15 | lines.append(label(A) + '&' + label(B) + '&' + img(B) + '&' + img(A) + '\n') 16 | f = open(os.path.join(data_path, 'debug_list.txt'), 'w') 17 | f.writelines(lines) 18 | -------------------------------------------------------------------------------- /vis_bdd.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreateConDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | import tensorboardX 15 | import random 16 | 17 | GAP = 1 18 | 19 | def write_temp(opt, target_phase): 20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt') 22 | f = open(source_path, 'r') 23 | all_paths = f.readlines() 24 | ans = [] 25 | last = '' 26 | tot = 0 27 | print(len(all_paths)) 28 | for path in all_paths: 29 | path_ = path.split('&')[0] 30 | if (path_ != last): 31 | last = path_ 32 | tot = tot + 1 33 | if tot % GAP == 0: 34 | ans.append(path) 35 | f_w = open(target_path, 'w') 36 | f_w.writelines(ans) 37 | 38 | opt = TrainOptions().parse() 39 | opt.phase = 'val' 40 | write_temp(opt, "temp") 41 | opt.phase = "temp" 42 | opt.serial_batches = True 43 | 44 | data_loader = CreateConDataLoader(opt) 45 | dataset = data_loader.load_data() 46 | dataset_size = len(data_loader) 47 | 48 | visualizer = Visualizer(opt) 49 | 50 | total_steps = 0# (start_epoch-1) * dataset_size + epoch_iter 51 | 52 | display_delta = total_steps % opt.display_freq 53 | print_delta = total_steps % opt.print_freq 54 | save_delta = total_steps % opt.save_latest_freq 55 | 56 | for i, data in enumerate(dataset): 57 | if (i % 100 == 0): 58 | print((i, dataset_size)) 59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)), 60 | ('real_image', util.tensor2im(data['A2'][0]))]) 61 | visualizer.display_current_results2(visuals, 0, i) 62 | 63 | -------------------------------------------------------------------------------- /vis_delta.txt: -------------------------------------------------------------------------------- 1 | 0 2 | -------------------------------------------------------------------------------- /vis_face.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreateFaceConDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | import tensorboardX 15 | import random 16 | 17 | GAP = 10 18 | 19 | def write_temp(opt, target_phase): 20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt') 22 | f = open(source_path, 'r') 23 | all_paths = f.readlines() 24 | ans = [] 25 | last = '' 26 | tot = 0 27 | print(len(all_paths)) 28 | for path in all_paths: 29 | path_ = path.split('&')[0] 30 | if (path_ != last): 31 | last = path_ 32 | tot = tot + 1 33 | if tot % GAP == 0: 34 | ans.append(path) 35 | f_w = open(target_path, 'w') 36 | f_w.writelines(ans) 37 | 38 | opt = TrainOptions().parse() 39 | opt.phase = 'val' 40 | write_temp(opt, "temp") 41 | opt.phase = "temp" 42 | opt.serial_batches = True 43 | 44 | data_loader = CreateFaceConDataLoader(opt) 45 | dataset = data_loader.load_data() 46 | dataset_size = len(data_loader) 47 | 48 | visualizer = Visualizer(opt) 49 | 50 | total_steps = 0# (start_epoch-1) * dataset_size + epoch_iter 51 | 52 | display_delta = total_steps % opt.display_freq 53 | print_delta = total_steps % opt.print_freq 54 | save_delta = total_steps % opt.save_latest_freq 55 | 56 | for i, data in enumerate(dataset): 57 | if (i % 100 == 0): 58 | print((i, dataset_size)) 59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)), 60 | ('real_image', util.tensor2im(data['A2'][0]))]) 61 | visualizer.display_current_results2(visuals, 0, i) 62 | 63 | -------------------------------------------------------------------------------- /vis_pose.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreatePoseConDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | import tensorboardX 15 | import random 16 | 17 | GAP = 10 18 | 19 | def write_temp(opt, target_phase): 20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt') 21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt') 22 | f = open(source_path, 'r') 23 | all_paths = f.readlines() 24 | ans = [] 25 | last = '' 26 | tot = 0 27 | print(len(all_paths)) 28 | for path in all_paths: 29 | path_ = path.split('&')[0] 30 | if (path_ != last): 31 | last = path_ 32 | tot = tot + 1 33 | if tot % GAP == 0: 34 | ans.append(path) 35 | f_w = open(target_path, 'w') 36 | f_w.writelines(ans) 37 | 38 | opt = TrainOptions().parse() 39 | opt.phase = 'val' 40 | write_temp(opt, "temp") 41 | opt.phase = "temp" 42 | opt.serial_batches = True 43 | 44 | data_loader = CreatePoseConDataLoader(opt) 45 | dataset = data_loader.load_data() 46 | dataset_size = len(data_loader) 47 | 48 | visualizer = Visualizer(opt) 49 | 50 | total_steps = 0# (start_epoch-1) * dataset_size + epoch_iter 51 | 52 | display_delta = total_steps % opt.display_freq 53 | print_delta = total_steps % opt.print_freq 54 | save_delta = total_steps % opt.save_latest_freq 55 | 56 | for i, data in enumerate(dataset): 57 | if (i % 100 == 0): 58 | print((i, dataset_size)) 59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0][3:6, :, :], 0)), 60 | ('real_image', util.tensor2im(data['A2'][0]))]) 61 | visualizer.display_current_results2(visuals, 0, i) 62 | 63 | --------------------------------------------------------------------------------