├── .gitignore
├── LICENSE.txt
├── README.md
├── _config.yml
├── cp_data.py
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── aligned_dataset2.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── con_dataset_data_loader.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── face_con_dataset_data_loader.py
│   ├── face_dataset.py
│   ├── image_folder.py
│   ├── keypoint2img.py
│   ├── pose_con_dataset_data_loader.py
│   └── pose_dataset.py
├── encode_features.py
├── expand_val_test.py
├── generate_data_face_forensics.py
├── img
│   ├── dance.png
│   ├── face.png
│   ├── fig.png
│   ├── scene.png
│   ├── small_dance.png
│   ├── small_face.png
│   └── small_scene.png
├── infer_face.py
├── inference
│   ├── data
│   │   ├── img
│   │   │   ├── __Mr_0__0__a.0.jpg
│   │   │   ├── __Mr_0__0__a.1.jpg
│   │   │   ├── __Mr_1__0__a.0.jpg
│   │   │   └── __Mr_1__0__a.1.jpg
│   │   ├── infer_list.txt
│   │   └── keypoints
│   │       ├── __Mr_0__0__a.0.txt
│   │       ├── __Mr_0__0__a.1.txt
│   │       ├── __Mr_1__0__a.0.txt
│   │       └── __Mr_1__0__a.1.txt
│   ├── infer_list.txt
│   └── test_imgs
│       ├── 1.jpg
│       └── 2.jpg
├── judge_face.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── c_pix2pixHD_model.py
│   ├── cm_pix2pixHD_model.py
│   ├── models.py
│   ├── networks.py
│   ├── pix2pixHD_model.py
│   └── ui_model.py
├── new_scripts
│   ├── infer_face.sh
│   ├── table_face.sh
│   ├── table_pose.sh
│   ├── table_scene.sh
│   ├── test_face.sh
│   ├── test_pose.sh
│   ├── test_scene.sh
│   ├── train_face.sh
│   ├── train_pose.sh
│   └── train_scene.sh
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── generate_data_options.py
│   ├── test_options.py
│   └── train_options.py
├── precompute_feature_maps.py
├── process.py
├── run_engine.py
├── temp_test.py
├── test.py
├── test_all.py
├── test_con.py
├── test_con_bak.py
├── test_delta.txt
├── test_face.py
├── test_mface.py
├── test_pose.py
├── test_seg.py
├── train.py
├── train_con.py
├── train_face.py
├── train_mface.py
├── train_pose.py
├── train_seg.py
├── util
│   ├── __init__.py
│   ├── html.py
│   ├── image_pool.py
│   ├── label.py
│   ├── util.py
│   └── visualizer.py
├── val_gen.py
├── vis_bdd.py
├── vis_delta.txt
├── vis_face.py
└── vis_pose.py
/.gitignore:
--------------------------------------------------------------------------------
1 | debug*
2 | datasets/
3 | checkpoints/
4 | results/
5 | build/
6 | dist/
7 | logs/
8 | old_logs/
9 | torch.egg-info/
10 | */**/__pycache__
11 | torch/version.py
12 | torch/csrc/generic/TensorMethods.cpp
13 | torch/lib/*.so*
14 | torch/lib/*.dylib*
15 | torch/lib/*.h
16 | torch/lib/build
17 | torch/lib/tmp_install
18 | torch/lib/include
19 | torch/lib/torch_shm_manager
20 | torch/csrc/cudnn/cuDNN.cpp
21 | torch/csrc/nn/THNN.cwrap
22 | torch/csrc/nn/THNN.cpp
23 | torch/csrc/nn/THCUNN.cwrap
24 | torch/csrc/nn/THCUNN.cpp
25 | torch/csrc/nn/THNN_generic.cwrap
26 | torch/csrc/nn/THNN_generic.cpp
27 | torch/csrc/nn/THNN_generic.h
28 | docs/src/**/*
29 | test/data/legacy_modules.t7
30 | test/data/gpu_tensors.pt
31 | test/htmlcov
32 | test/.coverage
33 | */*.pyc
34 | */**/*.pyc
35 | */**/**/*.pyc
36 | */**/**/**/*.pyc
37 | */**/**/**/**/*.pyc
38 | */*.so*
39 | */**/*.so*
40 | */**/*.dylib*
41 | test/data/legacy_serialized.pt
42 | *.DS_Store
43 | *~
44 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
2 | All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | Permission to use, copy, modify, and distribute this software and its documentation
6 | for any non-commercial purpose is hereby granted without fee, provided that the above
7 | copyright notice appear in all copies and that both that copyright notice and this
8 | permission notice appear in supporting documentation, and that the name of the author
9 | not be used in advertising or publicity pertaining to distribution of the software
10 | without specific, written prior permission.
11 |
12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18 |
19 |
20 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
21 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
22 | All rights reserved.
23 |
24 | Redistribution and use in source and binary forms, with or without
25 | modification, are permitted provided that the following conditions are met:
26 |
27 | * Redistributions of source code must retain the above copyright notice, this
28 | list of conditions and the following disclaimer.
29 |
30 | * Redistributions in binary form must reproduce the above copyright notice,
31 | this list of conditions and the following disclaimer in the documentation
32 | and/or other materials provided with the distribution.
33 |
34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Example-Guided Style-Consistent Image Synthesis from Semantic Labeling
2 |

3 |
4 | ## Paper
5 | Example-Guided Style-Consistent Image Synthesis from Semantic Labeling
6 | Miao Wang<sup>1</sup>, Guo-Ye Yang<sup>2</sup>, Ruilong Li<sup>2</sup>, Run-Ze Liang<sup>2</sup>, Song-Hai Zhang<sup>2</sup>, Peter M. Hall<sup>3</sup> and Shi-Min Hu<sup>2,1</sup>
7 | <sup>1</sup>State Key Laboratory of Virtual Reality Technology and Systems, Beihang University
8 | <sup>2</sup>Department of Computer Science and Technology, Tsinghua University, Beijing
9 | <sup>3</sup>University of Bath
10 | *IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019*
11 |
12 |
13 | ## Prerequisites
14 | - Linux
15 | - Python 3
16 | - NVIDIA GPU (12 GB or 24 GB memory) + CUDA and cuDNN
17 | - pytorch==0.4.1
18 | - numpy
19 | - ...
20 |
21 | ## Tasks
22 | ### Sketch2Face
23 | Task name: face
24 |
25 | We use the real videos in the FaceForensics dataset, which contains 854 videos of reporters broadcasting news. We localize facial landmarks, crop the facial regions and resize them to 256×256. The detected facial landmarks are then connected to create face sketches (see the minimal drawing sketch below).
26 | 
27 |
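
The sketch-drawing step reuses the helpers in `data/keypoint2img.py`. Below is a minimal, hand-written illustration of how a stored landmark file is turned into edge strokes; the file path and the jawline-only example are assumptions for illustration, and the full logic (eyes, nose, mouth, plus distance-transform channels) lives in `data/face_dataset.py`:

```python
import numpy as np
from PIL import Image
from data.keypoint2img import interpPoints, drawEdge

# landmarks are stored as comma-separated "x,y" rows (see FaceDataset.read_keypoints)
keypoints = np.loadtxt('inference/data/keypoints/__Mr_0__0__a.0.txt', delimiter=',')

h, w = 256, 256                       # face crops are resized to 256x256
sketch = np.zeros((h, w), np.uint8)   # single-channel edge map

# connect the jawline (landmarks 0-16) three points at a time, as in FaceDataset.draw_face_edges
jaw = list(range(0, 17))
for i in range(0, len(jaw) - 1, 2):
    sub = jaw[i:i + 3]
    curve_x, curve_y = interpPoints(keypoints[sub, 0], keypoints[sub, 1])
    drawEdge(sketch, curve_x, curve_y)

Image.fromarray(sketch).save('jawline_sketch.png')
```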
28 | ### Pose2Dance
29 | Task name: pose
30 |
31 | We download 150 solo dance videos from YouTube, crop out the central body regions and resize them to 256×256. We split each video evenly into two halves along the timeline, then sample training data only from the first halves and testing data only from the second halves of all videos. The labels are created by concatenating the detection results of pre-trained DensePose and OpenPose models.
32 | 
33 |
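
In code, the two detector outputs are simply stacked along the channel dimension before being fed to the generator. A minimal sketch of that step follows (the file names are placeholders; the real logic is `PoseDataset.get_X` in `data/pose_dataset.py`):

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from data.keypoint2img import read_keypoints2

to_tensor = transforms.ToTensor()
size = (256, 256)

# DensePose result stored as an IUV-style image; channel 2 holds the 24 body-part ids
densepose = to_tensor(Image.open('example_densepose.png').convert('RGB'))
densepose[2] = densepose[2] * 255 / 24    # rescale part ids, as in PoseDataset.get_X

# OpenPose keypoints stored as 18x3 (x, y, confidence) values, drawn as colored limb edges
openpose_img, _ = read_keypoints2('example_openpose.json', size, to_tensor)
openpose = to_tensor(openpose_img)

label = torch.cat([densepose, openpose])  # 6-channel conditioning input
print(label.shape)                        # torch.Size([6, 256, 256])
```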
34 | ### SceneParsing2StreetView
35 | Task name: scene
36 |
37 | We use the BDD100k dataset to synthesize street view images from pixelwise semantic labels (i.e. scene parsing maps). We use the state-of-the-art scene parsing network DANet to create labels.
38 | 
39 |
40 | ## Getting Started
41 | ### Installation
42 | ```bash
43 | git clone [this project]
44 | cd pix2pixSC
45 | # download datas.zip at https://drive.google.com/drive/folders/1O94UcCXONq7p2ZiPcfi-dldjREQ-GsJK or https://share.weiyun.com/5lHBkE0
46 | unzip datas.zip
47 | mv datas/checkpoints ./
48 | mv datas/datasets ./
49 |
50 | # the steps below are optional
51 | mkdir ../FaceForensics
52 | # download the FaceForensics dataset to ../FaceForensics/datas
53 | python process.py
54 | python generate_data_face_forensics.py --source_path '../FaceForensics/out_data' --target_path './datasets/FaceForensics3/' --same_style_rate 0.3 --neighbor_size 10 --A_repeat_num 50 --copy_data
55 | ```
56 |
57 | ### Training
58 | ```bash
59 | new_scripts/train_[Task name].sh
60 | ```
61 |
62 | ### Testing
63 | ```bash
64 | new_scripts/test_[Task name].sh
65 | ```
66 |
67 | ### Inference Face
68 | Edit inference/infer_list.txt with one test case per line; the outputs for each test case will be written to ./results.
69 | ```bash
70 | new_scripts/infer_face.sh
71 | ```
72 | Inference code for the other tasks will be released later.
73 |
74 | ## Results
75 | ### Face
76 | 
77 |
78 | ### Dance
79 | 
80 |
81 | ### Scene
82 | 
83 |
84 |
85 | ## Citation
86 |
87 | If you find this useful for your research, please cite the following paper.
88 |
89 | ```
90 | @InProceedings{pix2pixSC2019,
91 | author = {Wang, Miao and Yang, Guo-Ye and Li, Ruilong and Liang, Run-Ze and Zhang, Song-Hai and Hall, Peter M. and Hu, Shi-Min},
92 | title = {Example-Guided Style-Consistent Image Synthesis from Semantic Labeling},
93 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
94 | month = {June},
95 | year = {2019}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-minimal
--------------------------------------------------------------------------------
/cp_data.py:
--------------------------------------------------------------------------------
1 | from shutil import copyfile
2 | tp = 'datasets/cityscapes/val_'  # Cityscapes validation split prefix
3 | pre = 'aachen_000018_000019'  # candidate image prefixes; only the last assignment below takes effect
4 | pre = 'bremen_000108_000019'
5 | pre = 'frankfurt_000001_008688'
6 | pre = 'frankfurt_000001_060906'
7 | s_img_path = tp + 'img/' + pre + '_leftImg8bit.png'
8 | t_img_path = 'datasets/test/img.png'
9 | s_inst_path = tp + 'inst/' + pre + '_gtFine_instanceIds.png'
10 | t_inst_path = 'datasets/test/inst.png'
11 | s_label_path = tp + 'label/' + pre + '_gtFine_labelIds.png'
12 | t_label_path = 'datasets/test/label.png'
13 | copyfile(s_img_path, t_img_path)
14 | copyfile(s_inst_path, t_inst_path)
15 | copyfile(s_label_path, t_label_path)
16 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/data/__init__.py
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os.path
4 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
5 | from data.image_folder import make_dataset
6 | from PIL import Image
7 |
8 | class AlignedDataset(BaseDataset):
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.root = opt.dataroot
12 |
13 | ### input A (label maps)
14 | dir_A = '_A' if self.opt.label_nc == 0 else '_label'
15 | self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
16 | self.A_paths = sorted(make_dataset(self.dir_A))
17 |
18 | ### input B (real images)
19 | if opt.isTrain:
20 | dir_B = '_B' if self.opt.label_nc == 0 else '_img'
21 | self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
22 | self.B_paths = sorted(make_dataset(self.dir_B))
23 |
24 | ### instance maps
25 | if not opt.no_instance:
26 | self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
27 | self.inst_paths = sorted(make_dataset(self.dir_inst))
28 |
29 | ### load precomputed instance-wise encoded features
30 | if opt.load_features:
31 | self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
32 | print('----------- loading features from %s ----------' % self.dir_feat)
33 | self.feat_paths = sorted(make_dataset(self.dir_feat))
34 |
35 | self.dataset_size = len(self.A_paths)
36 |
37 | def __getitem__(self, index):
38 | ### input A (label maps)
39 | A_path = self.A_paths[index]
40 | A = Image.open(A_path)
41 | params = get_params(self.opt, A.size)
42 | if self.opt.label_nc == 0:
43 | transform_A = get_transform(self.opt, params)
44 | A_tensor = transform_A(A.convert('RGB'))
45 | else:
46 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
47 | A_tensor = transform_A(A) * 255.0
48 |
49 | B_tensor = inst_tensor = feat_tensor = 0
50 | ### input B (real images)
51 | if self.opt.isTrain:
52 | B_path = self.B_paths[index]
53 | B = Image.open(B_path).convert('RGB')
54 | transform_B = get_transform(self.opt, params)
55 | B_tensor = transform_B(B)
56 |
57 | ### if using instance maps
58 | if not self.opt.no_instance:
59 | inst_path = self.inst_paths[index]
60 | inst = Image.open(inst_path)
61 | inst_tensor = transform_A(inst)
62 |
63 | if self.opt.load_features:
64 | feat_path = self.feat_paths[index]
65 | feat = Image.open(feat_path).convert('RGB')
66 | norm = normalize()
67 | feat_tensor = norm(transform_A(feat))
68 |
69 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
70 | 'feat': feat_tensor, 'path': A_path}
71 |
72 | return input_dict
73 |
74 | def __len__(self):
75 | return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
76 |
77 | def name(self):
78 | return 'AlignedDataset'
--------------------------------------------------------------------------------
/data/aligned_dataset2.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os.path
4 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
5 | from data.image_folder import make_dataset
6 | from PIL import Image
7 |
8 | class AlignedDataset2(BaseDataset):
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.root = opt.dataroot
12 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
13 | f = open(data_list_path, 'r')
14 | self.all_paths = f.readlines()
15 | self.dataset_size = len(self.all_paths)
16 |
17 | def transfer_path(self, path):
18 | temp = path.split('__')[-1]
19 | temp = './datasets/bdd2/danet_vis/' + temp + '.label.png'
20 | return temp
21 |
22 | def same_style(self, path1, path2):
23 | t1 = path1.split('__')
24 | t2 = path2.split('__')
25 | p1 = t1[-3] + '_' + t1[-2]
26 | p2 = t2[-3] + '_' + t2[-2]
27 | if (p1 == p2):
28 | return 1
29 | else:
30 | return 0
31 |
32 | def get_X(self, path, params, do_transfer = False):
33 | ### input A (label maps)
34 | if self.opt.use_new_label and do_transfer:
35 | path = self.transfer_path(path)
36 | A = Image.open(path)
37 | if self.opt.label_nc == 0:
38 | transform_A = get_transform(self.opt, params)
39 | A_tensor = transform_A(A.convert('RGB'))
40 | else:
41 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
42 | A_tensor = transform_A(A) * 255.0
43 | return A_tensor
44 |
45 | def get_X2(self, path, params):
46 | B = Image.open(path).convert('RGB')
47 | transform_B = get_transform(self.opt, params)
48 | B_tensor = transform_B(B)
49 | return B_tensor
50 |
51 | def __getitem__(self, index):
52 | paths = self.all_paths[index].rstrip('\n').split('&')
53 | A = Image.open(paths[0])
54 | params = get_params(self.opt, A.size)
55 |
56 | A_tensor = self.get_X(paths[0], params)
57 | B_tensor = self.get_X(paths[1], params, True)
58 | B2_tensor = self.get_X2(paths[2], params)
59 | A2_tensor = self.get_X2(paths[3], params)
60 | C_tensor = C2_tensor = D_tensor = D2_tensor = 0
61 | if (self.opt.isTrain):
62 | C_tensor = self.get_X(paths[4], params)
63 | C2_tensor = self.get_X2(paths[5], params)
64 | D_tensor = self.get_X(paths[6], params)
65 | D2_tensor = self.get_X2(paths[7], params)
66 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor,
67 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])}
68 | return input_dict
69 |
70 | def __len__(self):
71 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize
72 |
73 | def name(self):
74 | return 'AlignedDataset2'
75 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 | def load_data():
11 | return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch.utils.data as data
4 | from PIL import Image
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import random
8 |
9 | class BaseDataset(data.Dataset):
10 | def __init__(self):
11 | super(BaseDataset, self).__init__()
12 |
13 | def name(self):
14 | return 'BaseDataset'
15 |
16 | def initialize(self, opt):
17 | pass
18 |
19 | def get_params(opt, size):
20 | w, h = size
21 | new_h = h
22 | new_w = w
23 | if opt.resize_or_crop == 'resize_and_crop':
24 | new_h = new_w = opt.loadSize
25 | elif opt.resize_or_crop == 'scale_width_and_crop':
26 | new_w = opt.loadSize
27 | new_h = opt.loadSize * h // w
28 |
29 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
30 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
31 | if (not opt.isTrain) and opt.resize_or_crop == 'scale_width_and_crop':
32 | x = np.maximum(0, new_w - opt.fineSize) / 2
33 | y = np.maximum(0, new_h - opt.fineSize) / 2
34 |
35 | flip = random.random() > 0.5
36 | return {'crop_pos': (x, y), 'flip': flip}
37 |
38 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
39 | transform_list = []
40 | if 'resize' in opt.resize_or_crop:
41 | osize = [opt.loadSize, opt.loadSize]
42 | transform_list.append(transforms.Scale(osize, method))
43 | elif 'scale_width' in opt.resize_or_crop:
44 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
45 |
46 | if 'crop' in opt.resize_or_crop:
47 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
48 |
49 | if opt.resize_or_crop == 'none':
50 | base = float(2 ** opt.n_downsample_global)
51 | if opt.netG == 'local':
52 | base *= (2 ** opt.n_local_enhancers)
53 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
54 |
55 | if opt.isTrain and not opt.no_flip:
56 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
57 |
58 | transform_list += [transforms.ToTensor()]
59 |
60 | if normalize:
61 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
62 | (0.5, 0.5, 0.5))]
63 | return transforms.Compose(transform_list)
64 |
65 | def normalize():
66 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
67 |
68 | def __make_power_2(img, base, method=Image.BICUBIC):
69 | ow, oh = img.size
70 | h = int(round(oh / base) * base)
71 | w = int(round(ow / base) * base)
72 | if (h == oh) and (w == ow):
73 | return img
74 | return img.resize((w, h), method)
75 |
76 | def __scale_width(img, target_width, method=Image.BICUBIC):
77 | ow, oh = img.size
78 | if (ow == target_width):
79 | return img
80 | w = target_width
81 | h = int(target_width * oh / ow)
82 | return img.resize((w, h), method)
83 |
84 | def __crop(img, pos, size):
85 | ow, oh = img.size
86 | x1, y1 = pos
87 | tw = th = size
88 | if (ow > tw or oh > th):
89 | return img.crop((x1, y1, x1 + tw, y1 + th))
90 | return img
91 |
92 | def __flip(img, flip):
93 | if flip:
94 | return img.transpose(Image.FLIP_LEFT_RIGHT)
95 | return img
96 |
97 |
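
The `get_params` / `get_transform` pair above is what keeps paired inputs aligned: the random crop/flip decision is drawn once per sample and then reused for every image belonging to that sample. A minimal usage sketch (the stand-in `Opt` object and the sample path are assumptions; real options come from `options/base_options.py`):

```python
from PIL import Image
from data.base_dataset import get_params, get_transform

class Opt:                      # stand-in for the parsed options object
    resize_or_crop = 'resize_and_crop'
    loadSize = 286
    fineSize = 256
    isTrain = True
    no_flip = False

opt = Opt()
img = Image.open('inference/data/img/__Mr_0__0__a.0.jpg').convert('RGB')
params = get_params(opt, img.size)        # one random crop position + flip decision
tensor = get_transform(opt, params)(img)  # reuse params so paired images stay aligned
print(tensor.shape)                       # torch.Size([3, 256, 256])
```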
--------------------------------------------------------------------------------
/data/con_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.aligned_dataset2 import AlignedDataset2
8 | dataset = AlignedDataset2()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class ConDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'ConDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.aligned_dataset import AlignedDataset
8 | dataset = AlignedDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class CustomDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'CustomDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 | def CreateConDataLoader(opt):
9 | from data.con_dataset_data_loader import ConDatasetDataLoader
10 | data_loader = ConDatasetDataLoader()
11 | print(data_loader.name())
12 | data_loader.initialize(opt)
13 | return data_loader
14 | def CreateFaceConDataLoader(opt):
15 | from data.face_con_dataset_data_loader import FaceConDatasetDataLoader
16 | data_loader = FaceConDatasetDataLoader()
17 | print(data_loader.name())
18 | data_loader.initialize(opt)
19 | return data_loader
20 | def CreatePoseConDataLoader(opt):
21 | from data.pose_con_dataset_data_loader import PoseConDatasetDataLoader
22 | data_loader = PoseConDatasetDataLoader()
23 | print(data_loader.name())
24 | data_loader.initialize(opt)
25 | return data_loader
26 |
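
These factory functions are the entry points the training and test scripts use to obtain a dataset wrapped in a `torch.utils.data.DataLoader`. A minimal usage sketch, mirroring the pattern in `encode_features.py` (option values are left to whatever defaults `options/` provides):

```python
from options.train_options import TrainOptions
from data.data_loader import CreateConDataLoader

opt = TrainOptions().parse()
data_loader = CreateConDataLoader(opt)   # builds AlignedDataset2 behind a DataLoader
dataset = data_loader.load_data()

for i, data in enumerate(dataset):
    # each batch is a dict of tensors: label maps, guidance images, and a same-style flag
    print(i, data['A'].shape, data['same_style'])
    break
```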
--------------------------------------------------------------------------------
/data/face_con_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.face_dataset import FaceDataset
8 | dataset = FaceDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class FaceConDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'FaceConDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/data/face_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import cv2
7 | from skimage import feature
8 |
9 | from data.base_dataset import BaseDataset, get_transform, get_params
10 | from data.keypoint2img import interpPoints, drawEdge
11 | import random
12 |
13 | class FaceDataset(BaseDataset):
14 | def initialize(self, opt):
15 | self.opt = opt
16 | self.root = opt.dataroot
17 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
18 | f = open(data_list_path, 'r')
19 | self.all_paths = f.readlines()
20 | self.dataset_size = len(self.all_paths)
21 | if (not opt.isTrain and opt.serial_batches):
22 | ff = open(opt.test_delta_path, "r")
23 | t = int(ff.readlines()[0])
24 | self.delta = t
25 | else:
26 | self.delta = 0
27 |
28 | def get_X(self, path, params, B_size, B_img):
29 | transform_scaleA = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
30 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
31 | Ai, Li = self.get_face_image(path, transform_scaleA, transform_label, B_size, B_img)
32 | return Ai
33 |
34 | def get_X2(self, path, params):
35 | B = Image.open(path).convert('RGB')
36 | transform_B = get_transform(self.opt, params)
37 | B_tensor = transform_B(B)
38 | return B_tensor, B
39 |
40 | def same_style(self, path1, path2):
41 | t1 = path1.split('__')
42 | t2 = path2.split('__')
43 | p1 = t1[-3] + '_' + t1[-2]
44 | p2 = t2[-3] + '_' + t2[-2]
45 | if (p1 == p2):
46 | return 1
47 | else:
48 | return 0
49 |
50 | def __getitem__(self, index):
51 | index = (index + self.delta) % len(self.all_paths)
52 | paths = self.all_paths[index].rstrip('\n').split('&')
53 | A = Image.open(paths[2])
54 | params = get_params(self.opt, A.size)
55 |
56 | B2_tensor, B = self.get_X2(paths[2], params)
57 | A2_tensor, A = self.get_X2(paths[3], params)
58 | A_tensor = self.get_X(paths[0], params, A.size, A)
59 | B_tensor = self.get_X(paths[1], params, A.size, B)
60 | C_tensor = C2_tensor = D_tensor = D2_tensor = 0
61 | if (self.opt.isTrain):
62 | C2_tensor, C = self.get_X2(paths[5], params)
63 | C_tensor = self.get_X(paths[4], params, A.size, C)
64 | D2_tensor, D = self.get_X2(paths[7], params)
65 | D_tensor = self.get_X(paths[6], params, A.size, D)
66 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor,
67 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])}
68 | return input_dict
69 |
70 | '''
71 | def __getitem__(self, index):
72 | A, B, I, seq_idx = self.update_frame_idx(self.A_paths, index)
73 | A_paths = self.A_paths[seq_idx]
74 | B_paths = self.B_paths[seq_idx]
75 | n_frames_total, start_idx, t_step = get_video_params(self.opt, self.n_frames_total, len(A_paths), self.frame_idx)
76 |
77 | B_img = Image.open(B_paths[0]).convert('RGB')
78 | B_size = B_img.size
79 | points = np.loadtxt(A_paths[0], delimiter=',')
80 | is_first_frame = self.opt.isTrain or not hasattr(self, 'min_x')
81 | if is_first_frame: # crop only the face region
82 | self.get_crop_coords(points, B_size)
83 | params = get_img_params(self.opt, self.crop(B_img).size)
84 | transform_scaleA = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
85 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
86 | transform_scaleB = get_transform(self.opt, params)
87 |
88 | # read in images
89 | frame_range = list(range(n_frames_total)) if self.A is None else [self.opt.n_frames_G-1]
90 | for i in frame_range:
91 | A_path = A_paths[start_idx + i * t_step]
92 | B_path = B_paths[start_idx + i * t_step]
93 | B_img = Image.open(B_path)
94 | Ai, Li = self.get_face_image(A_path, transform_scaleA, transform_label, B_size, B_img)
95 | Bi = transform_scaleB(self.crop(B_img))
96 | A = concat_frame(A, Ai, n_frames_total)
97 | B = concat_frame(B, Bi, n_frames_total)
98 | I = concat_frame(I, Li, n_frames_total)
99 |
100 | if not self.opt.isTrain:
101 | self.A, self.B, self.I = A, B, I
102 | self.frame_idx += 1
103 | change_seq = False if self.opt.isTrain else self.change_seq
104 | return_list = {'A': A, 'B': B, 'inst': I, 'A_path': A_path, 'change_seq': change_seq}
105 |
106 | return return_list
107 | '''
108 | def get_image(self, A_path, transform_scaleA):
109 | A_img = Image.open(A_path)
110 | A_scaled = transform_scaleA(self.crop(A_img))
111 | return A_scaled
112 |
113 | def get_face_image(self, A_path, transform_A, transform_L, size, img):
114 | # read face keypoints from path and crop face region
115 | keypoints, part_list, part_labels = self.read_keypoints(A_path, size)
116 |
117 | # draw edges and possibly add distance transform maps
118 | add_dist_map = not self.opt.no_dist_map
119 | im_edges, dist_tensor = self.draw_face_edges(keypoints, part_list, transform_A, size, add_dist_map)
120 |
121 | # canny edge for background
122 | if not self.opt.no_canny_edge:
123 | edges = feature.canny(np.array(img.convert('L')))
124 | edges = edges * (part_labels == 0) # remove edges within face
125 | im_edges += (edges * 255).astype(np.uint8)
126 | edge_tensor = transform_A(Image.fromarray(self.crop(im_edges)))
127 |
128 | # final input tensor
129 | input_tensor = torch.cat([edge_tensor, dist_tensor]) if add_dist_map else edge_tensor
130 | label_tensor = transform_L(Image.fromarray(self.crop(part_labels.astype(np.uint8)))) * 255.0
131 | return input_tensor, label_tensor
132 |
133 | def read_keypoints(self, A_path, size):
134 | # mapping from keypoints to face part
135 | part_list = [[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
136 | [range(17, 22)], # right eyebrow
137 | [range(22, 27)], # left eyebrow
138 | [[28, 31], range(31, 36), [35, 28]], # nose
139 | [[36,37,38,39], [39,40,41,36]], # right eye
140 | [[42,43,44,45], [45,46,47,42]], # left eye
141 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth
142 | [range(60, 65), [64,65,66,67,60]] # tongue
143 | ]
144 | label_list = [1, 2, 2, 3, 4, 4, 5, 6] # labeling for different facial parts
145 | keypoints = np.loadtxt(A_path, delimiter=',')
146 |
147 | # add upper half face by symmetry
148 | pts = keypoints[:17, :].astype(np.int32)
149 | baseline_y = (pts[0,1] + pts[-1,1]) / 2
150 | upper_pts = pts[1:-1,:].copy()
151 | upper_pts[:,1] = baseline_y + (baseline_y-upper_pts[:,1]) * 2 // 3
152 | keypoints = np.vstack((keypoints, upper_pts[::-1,:]))
153 |
154 | # label map for facial part
155 | w, h = size
156 | part_labels = np.zeros((h, w), np.uint8)
157 | for p, edge_list in enumerate(part_list):
158 | indices = [item for sublist in edge_list for item in sublist]
159 | pts = keypoints[indices, :].astype(np.int32)
160 | cv2.fillPoly(part_labels, pts=[pts], color=label_list[p])
161 |
162 | return keypoints, part_list, part_labels
163 |
164 | def draw_face_edges(self, keypoints, part_list, transform_A, size, add_dist_map):
165 | w, h = size
166 | edge_len = 3 # interpolate 3 keypoints to form a curve when drawing edges
167 | # edge map for face region from keypoints
168 | im_edges = np.zeros((h, w), np.uint8) # edge map for all edges
169 | dist_tensor = 0
170 | e = 1
171 | for edge_list in part_list:
172 | for edge in edge_list:
173 | im_edge = np.zeros((h, w), np.uint8) # edge map for the current edge
174 | for i in range(0, max(1, len(edge)-1), edge_len-1): # divide a long edge into multiple small edges when drawing
175 | sub_edge = edge[i:i+edge_len]
176 | x = keypoints[sub_edge, 0]
177 | y = keypoints[sub_edge, 1]
178 |
179 | curve_x, curve_y = interpPoints(x, y) # interp keypoints to get the curve shape
180 | drawEdge(im_edges, curve_x, curve_y)
181 | if add_dist_map:
182 | drawEdge(im_edge, curve_x, curve_y)
183 |
184 | if add_dist_map: # add distance transform map on each facial part
185 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3)
186 | im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
187 | im_dist = Image.fromarray(im_dist)
188 | tensor_cropped = transform_A(self.crop(im_dist))
189 | dist_tensor = tensor_cropped if e == 1 else torch.cat([dist_tensor, tensor_cropped])
190 | e += 1
191 |
192 | return im_edges, dist_tensor
193 |
194 | def get_crop_coords(self, keypoints, size):
195 | min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max()
196 | min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max()
197 | offset = (max_x - min_x) // 2
198 | min_y = max(0, min_y - offset*2)
199 | min_x = max(0, min_x - offset)
200 | max_x = min(size[0], max_x + offset)
201 | max_y = min(size[1], max_y + offset)
202 | self.min_y, self.max_y, self.min_x, self.max_x = int(min_y), int(max_y), int(min_x), int(max_x)
203 |
204 | def crop(self, img):
205 | return img
206 | #???
207 | if isinstance(img, np.ndarray):
208 | return img[self.min_y:self.max_y, self.min_x:self.max_x]
209 | else:
210 | return img.crop((self.min_x, self.min_y, self.max_x, self.max_y))
211 |
212 | def __len__(self):
213 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize
214 |
215 | def name(self):
216 | return 'FaceDataset'
217 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import os
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
14 | ]
15 |
16 |
17 | def is_image_file(filename):
18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19 |
20 |
21 | def make_dataset(dir):
22 | images = []
23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
24 |
25 | for root, _, fnames in sorted(os.walk(dir)):
26 | for fname in fnames:
27 | if is_image_file(fname):
28 | path = os.path.join(root, fname)
29 | images.append(path)
30 |
31 | return images
32 |
33 |
34 | def default_loader(path):
35 | return Image.open(path).convert('RGB')
36 |
37 |
38 | class ImageFolder(data.Dataset):
39 |
40 | def __init__(self, root, transform=None, return_paths=False,
41 | loader=default_loader):
42 | imgs = make_dataset(root)
43 | if len(imgs) == 0:
44 | raise(RuntimeError("Found 0 images in: " + root + "\n"
45 | "Supported image extensions are: " +
46 | ",".join(IMG_EXTENSIONS)))
47 |
48 | self.root = root
49 | self.imgs = imgs
50 | self.transform = transform
51 | self.return_paths = return_paths
52 | self.loader = loader
53 |
54 | def __getitem__(self, index):
55 | path = self.imgs[index]
56 | img = self.loader(path)
57 | if self.transform is not None:
58 | img = self.transform(img)
59 | if self.return_paths:
60 | return img, path
61 | else:
62 | return img
63 |
64 | def __len__(self):
65 | return len(self.imgs)
66 |
--------------------------------------------------------------------------------
/data/keypoint2img.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from PIL import Image
3 | import numpy as np
4 | import json
5 | import glob
6 | from scipy.optimize import curve_fit
7 | import cv2
8 | import torch
9 | import warnings
10 |
11 | def func(x, a, b, c):
12 | return a * x**2 + b * x + c
13 |
14 | def linear(x, a, b):
15 | return a * x + b
16 |
17 | def setColor(im, yy, xx, color):
18 | if len(im.shape) == 3:
19 | if (im[yy, xx] == 0).all():
20 | im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = color[0], color[1], color[2]
21 | else:
22 | im[yy, xx, 0] = ((im[yy, xx, 0].astype(float) + color[0]) / 2).astype(np.uint8)
23 | im[yy, xx, 1] = ((im[yy, xx, 1].astype(float) + color[1]) / 2).astype(np.uint8)
24 | im[yy, xx, 2] = ((im[yy, xx, 2].astype(float) + color[2]) / 2).astype(np.uint8)
25 | else:
26 | im[yy, xx] = color[0]
27 |
28 | def drawEdge(im, x, y, bw=1, color=(255,255,255), draw_end_points=False):
29 | if x is not None and x.size:
30 | h, w = im.shape[0], im.shape[1]
31 | # edge
32 | for i in range(-bw, bw):
33 | for j in range(-bw, bw):
34 | yy = np.maximum(0, np.minimum(h-1, y+i))
35 | xx = np.maximum(0, np.minimum(w-1, x+j))
36 | setColor(im, yy, xx, color)
37 |
38 | # edge endpoints
39 | if draw_end_points:
40 | for i in range(-bw*2, bw*2):
41 | for j in range(-bw*2, bw*2):
42 | if (i**2) + (j**2) < (4 * bw**2):
43 | yy = np.maximum(0, np.minimum(h-1, np.array([y[0], y[-1]])+i))
44 | xx = np.maximum(0, np.minimum(w-1, np.array([x[0], x[-1]])+j))
45 | setColor(im, yy, xx, color)
46 |
47 | def interpPoints(x, y):
48 | if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max():
49 | curve_y, curve_x = interpPoints(y, x)
50 | if curve_y is None:
51 | return None, None
52 | else:
53 | with warnings.catch_warnings():
54 | warnings.simplefilter("ignore")
55 | if len(x) < 3:
56 | popt, _ = curve_fit(linear, x, y)
57 | else:
58 | popt, _ = curve_fit(func, x, y)
59 | if abs(popt[0]) > 1:
60 | return None, None
61 | if x[0] > x[-1]:
62 | x = list(reversed(x))
63 | y = list(reversed(y))
64 | curve_x = np.linspace(x[0], x[-1], (x[-1]-x[0]))
65 | if len(x) < 3:
66 | curve_y = linear(curve_x, *popt)
67 | else:
68 | curve_y = func(curve_x, *popt)
69 | return curve_x.astype(int), curve_y.astype(int)
70 |
71 | def read_keypoints(json_input, size, random_drop_prob=0, remove_face_labels=False):
72 | with open(json_input, encoding='utf-8') as f:
73 | keypoint_dicts = json.loads(f.read())["people"]
74 |
75 | edge_lists = define_edge_lists()
76 | w, h = size
77 | pose_img = np.zeros((h, w, 3), np.uint8)
78 | for keypoint_dict in keypoint_dicts:
79 | pose_pts = np.array(keypoint_dict["pose_keypoints_2d"]).reshape(25, 3)
80 | face_pts = np.array(keypoint_dict["face_keypoints_2d"]).reshape(70, 3)
81 | hand_pts_l = np.array(keypoint_dict["hand_left_keypoints_2d"]).reshape(21, 3)
82 | hand_pts_r = np.array(keypoint_dict["hand_right_keypoints_2d"]).reshape(21, 3)
83 | pts = [extract_valid_keypoints(pts, edge_lists) for pts in [pose_pts, face_pts, hand_pts_l, hand_pts_r]]
84 | pose_img += connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels)
85 | return pose_img
86 |
87 | def read_keypoints2(json_input, size, transform_A, random_drop_prob=0, remove_face_labels=False):
88 | with open(json_input) as f:
89 | keypoint_dicts = json.loads(f.read())
90 |
91 | edge_lists = define_edge_lists2()
92 | w, h = size
93 | pose_pts = np.array(keypoint_dicts).reshape(18, 3)
94 | pts = extract_valid_keypoints2(pose_pts, edge_lists)
95 | pose_img, dist_tensor = connect_keypoints2(pts, edge_lists, size, random_drop_prob, remove_face_labels, transform_A)
96 | return pose_img, dist_tensor
97 |
98 | def extract_valid_keypoints2(pts, edge_lists):
99 | pose_edge_list, _ = edge_lists
100 | p = pts.shape[0]
101 | thre = 0.1 if p == 70 else 0.01
102 | output = np.zeros((p, 2))
103 |
104 | valid = (pts[:, 2] > thre)
105 | output[valid, :] = pts[valid, :2]
106 |
107 | return output
108 |
109 | def extract_valid_keypoints(pts, edge_lists):
110 | pose_edge_list, _, hand_edge_list, _, face_list = edge_lists
111 | p = pts.shape[0]
112 | thre = 0.1 if p == 70 else 0.01
113 | output = np.zeros((p, 2))
114 |
115 | if p == 70: # face
116 | for edge_list in face_list:
117 | for edge in edge_list:
118 | if (pts[edge, 2] > thre).all():
119 | output[edge, :] = pts[edge, :2]
120 | elif p == 21: # hand
121 | for edge in hand_edge_list:
122 | if (pts[edge, 2] > thre).all():
123 | output[edge, :] = pts[edge, :2]
124 | else: # pose
125 | valid = (pts[:, 2] > thre)
126 | output[valid, :] = pts[valid, :2]
127 |
128 | return output
129 |
130 | def connect_keypoints2(pts, edge_lists, size, random_drop_prob, remove_face_labels, transform_A):
131 | pose_pts = pts
132 | w, h = size
133 | output_edges = np.zeros((h, w, 3), np.uint8)
134 | pose_edge_list, pose_color_list = edge_lists
135 | dist_tensor = 0
136 | e = 1
137 |
138 | ### pose
139 | for i, edge in enumerate(pose_edge_list):
140 | im_edge = np.zeros((h, w), np.uint8)
141 | x, y = pose_pts[edge, 0], pose_pts[edge, 1]
142 | if (np.random.rand() > random_drop_prob) and (0 not in x):
143 | curve_x, curve_y = interpPoints(x, y)
144 | drawEdge(output_edges, curve_x, curve_y, bw=2, color=pose_color_list[i], draw_end_points=True)
145 | drawEdge(im_edge, curve_x, curve_y)
146 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3)
147 | im_dist = np.clip((im_dist / 2), 0, 255).astype(np.uint8)
148 | im_dist = Image.fromarray(im_dist)
149 | tensor_cropped = transform_A(im_dist)
150 | dist_tensor = tensor_cropped if e == 1 else torch.cat((dist_tensor, tensor_cropped), 0)
151 | e += 1
152 | return Image.fromarray(output_edges), dist_tensor
153 |
154 | def connect_keypoints(pts, edge_lists, size, random_drop_prob, remove_face_labels):
155 | pose_pts, face_pts, hand_pts_l, hand_pts_r = pts
156 | w, h = size
157 | output_edges = np.zeros((h, w, 3), np.uint8)
158 | pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list = edge_lists
159 |
160 | if random_drop_prob > 0 and remove_face_labels:
161 | # add random noise to keypoints
162 | pose_pts[[0,15,16,17,18], :] += 5 * np.random.randn(5,2)
163 | face_pts[:,0] += 2 * np.random.randn()
164 | face_pts[:,1] += 2 * np.random.randn()
165 |
166 | ### pose
167 | for i, edge in enumerate(pose_edge_list):
168 | x, y = pose_pts[edge, 0], pose_pts[edge, 1]
169 | if (np.random.rand() > random_drop_prob) and (0 not in x):
170 | curve_x, curve_y = interpPoints(x, y)
171 | drawEdge(output_edges, curve_x, curve_y, bw=3, color=pose_color_list[i], draw_end_points=True)
172 |
173 | ### hand
174 | for hand_pts in [hand_pts_l, hand_pts_r]: # for left and right hand
175 | if np.random.rand() > random_drop_prob:
176 | for i, edge in enumerate(hand_edge_list): # for each finger
177 | for j in range(0, len(edge)-1): # for each part of the finger
178 | sub_edge = edge[j:j+2]
179 | x, y = hand_pts[sub_edge, 0], hand_pts[sub_edge, 1]
180 | if 0 not in x:
181 | line_x, line_y = interpPoints(x, y)
182 | drawEdge(output_edges, line_x, line_y, bw=1, color=hand_color_list[i], draw_end_points=True)
183 |
184 | ### face
185 | edge_len = 2
186 | if (np.random.rand() > random_drop_prob):
187 | for edge_list in face_list:
188 | for edge in edge_list:
189 | for i in range(0, max(1, len(edge)-1), edge_len-1):
190 | sub_edge = edge[i:i+edge_len]
191 | x, y = face_pts[sub_edge, 0], face_pts[sub_edge, 1]
192 | if 0 not in x:
193 | curve_x, curve_y = interpPoints(x, y)
194 | drawEdge(output_edges, curve_x, curve_y, draw_end_points=True)
195 |
196 | return output_edges
197 |
198 | def define_edge_lists2():
199 | ### pose
200 | pose_edge_list = [
201 | [0, 1],
202 | [14, 16], [15, 17],
203 | [14, 0], [15, 0],
204 | [1, 2], [1, 5],
205 | [1, 8], [1, 11],
206 | [8, 9], [9, 10],
207 | [11, 12], [12, 13],
208 | [2, 3], [3, 4],
209 | [5, 6], [6, 7]
210 | ]
211 | pose_edge_list = [
212 | [0, 1],
213 | [2, 3], [3, 4],
214 | [5, 6], [6, 7],
215 | [8, 9], [9, 10],
216 | [11, 12], [12, 13],
217 | [14, 16], [15, 17],
218 | [14, 0], [15, 0],
219 | [1, 2], [1, 5],
220 | [1, 8], [1, 11]
221 | ]
222 | pose_color_list = [
223 | [153, 0,153], [153, 0,102], [102, 0,153], [ 51, 0,153], [153, 0, 51],
224 | [153, 0, 0],
225 | [153, 51, 0], [153,102, 0], [153,153, 0],
226 | [102,153, 0], [ 51,153, 0], [ 0,153, 0],
227 | [ 0,153, 51], [ 0,153,102], [ 0,153,153], [ 0,102,153], [ 0,51,153], [ 0,0,153],
228 | [ 0,102,153], [ 0, 51,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153]
229 | ]
230 |
231 | return pose_edge_list, pose_color_list
232 |
233 | def define_edge_lists():
234 | ### pose
235 | pose_edge_list = [
236 | [17, 15], [15, 0], [ 0, 16], [16, 18], [ 0, 1], # head
237 | [ 1, 8], # body
238 | [ 1, 2], [ 2, 3], [ 3, 4], # right arm
239 | [ 1, 5], [ 5, 6], [ 6, 7], # left arm
240 | [ 8, 9], [ 9, 10], [10, 11], [11, 24], [11, 22], [22, 23], # right leg
241 | [ 8, 12], [12, 13], [13, 14], [14, 21], [14, 19], [19, 20] # left leg
242 | ]
243 | pose_color_list = [
244 | [153, 0,153], [153, 0,102], [102, 0,153], [ 51, 0,153], [153, 0, 51],
245 | [153, 0, 0],
246 | [153, 51, 0], [153,102, 0], [153,153, 0],
247 | [102,153, 0], [ 51,153, 0], [ 0,153, 0],
248 | [ 0,153, 51], [ 0,153,102], [ 0,153,153], [ 0,153,153], [ 0,153,153], [ 0,153,153],
249 | [ 0,102,153], [ 0, 51,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153], [ 0, 0,153]
250 | ]
251 |
252 | ### hand
253 | hand_edge_list = [
254 | [0, 1, 2, 3, 4],
255 | [0, 5, 6, 7, 8],
256 | [0, 9, 10, 11, 12],
257 | [0, 13, 14, 15, 16],
258 | [0, 17, 18, 19, 20]
259 | ]
260 | hand_color_list = [
261 | [204,0,0], [163,204,0], [0,204,82], [0,82,204], [163,0,204]
262 | ]
263 |
264 | ### face
265 | face_list = [
266 | #[range(0, 17)], # face
267 | [range(17, 22)], # left eyebrow
268 | [range(22, 27)], # right eyebrow
269 | [range(27, 31), range(31, 36)], # nose
270 | [[36,37,38,39], [39,40,41,36]], # left eye
271 | [[42,43,44,45], [45,46,47,42]], # right eye
272 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth
273 | ]
274 | return pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, face_list
275 |
--------------------------------------------------------------------------------
/data/pose_con_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.pose_dataset import PoseDataset
8 | dataset = PoseDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class PoseConDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'PoseConDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/data/pose_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import cv2
7 | from skimage import feature
8 |
9 | from data.base_dataset import BaseDataset, get_transform, get_params
10 | from data.keypoint2img import interpPoints, drawEdge, read_keypoints2
11 | import random
12 |
13 | class PoseDataset(BaseDataset):
14 | def initialize(self, opt):
15 | self.opt = opt
16 | self.root = opt.dataroot
17 | data_list_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
18 | f = open(data_list_path, 'r')
19 | self.all_paths = f.readlines()
20 | self.dataset_size = len(self.all_paths)
21 | if (not opt.isTrain and opt.serial_batches):
22 | ff = open(opt.test_delta_path, "r")
23 | t = int(ff.readlines()[0])
24 | self.delta = t
25 | else:
26 | self.delta = 0
27 |
28 |
29 | def get_image2(self, A_path, size, params, input_type):
30 | is_img = input_type == 'img'
31 | method = Image.BICUBIC if is_img else Image.NEAREST
32 | transform_scaleA = get_transform(self.opt, params, normalize=is_img, method=method)
33 | if input_type != 'openpose':
34 | A_img = Image.open(A_path).convert('RGB')
35 | else:
36 | random_drop_prob = self.opt.random_drop_prob if self.opt.isTrain else 0
37 | A_img, dist_tensor = read_keypoints2(A_path, size, transform_scaleA, random_drop_prob, False)
38 |
39 | if input_type == 'densepose' and self.opt.isTrain:
40 | # randomly remove labels
41 | A_np = np.array(A_img)
42 | part_labels = A_np[:,:,2]
43 | for part_id in range(1, 25):
44 | if (np.random.rand() < self.opt.random_drop_prob):
45 | A_np[(part_labels == part_id), :] = 0
46 | if self.opt.remove_face_labels:
47 | A_np[(part_labels == 23) | (part_labels == 24), :] = 0
48 | A_img = Image.fromarray(A_np)
49 |
50 | A_scaled = transform_scaleA(A_img)
51 | if (input_type == 'openpose' and self.opt.do_pose_dist_map):
52 | A_scaled = torch.cat([A_scaled, dist_tensor])
53 | return A_scaled
54 |
55 | def get_X(self, path, params, B_size, B_tensor):
56 | paths = path.split('|')
57 | if not self.opt.openpose_only:
58 | Di = self.get_image2(paths[0], B_size, params, input_type='densepose')
59 | Di[2,:,:] = Di[2,:,:] * 255 / 24
60 | if not self.opt.densepose_only:
61 | Oi = self.get_image2(paths[1], B_size, params, input_type='openpose')
62 |
63 | if self.opt.openpose_only:
64 | Ai = Oi
65 | elif self.opt.densepose_only:
66 | Ai = Di
67 | else:
68 | Ai = torch.cat([Di, Oi])
69 | valid = (Di[0, :, :] > 0) + (Di[1, :, :] > 0) + (Di[2, :, :] > 0)
70 | valid_ = valid == 0
71 | if (self.opt.add_mask):
72 | B_tensor[:, valid_] = 0
73 | '''
74 | valid_ = (valid > 0)[np.newaxis, :, :]
75 | valid__ = torch.cat([valid_, valid_, valid_])
76 | print(valid__[:, 128, 128])
77 | B_out = B_tensor * valid__
78 | '''
79 | #Ai, Bi = self.crop(Ai), self.crop(Bi) # only crop the central half region to save time
80 | return Ai, B_tensor
81 |
82 | def same_style(self, path1, path2):
83 | t1 = path1.split('__')
84 | t2 = path2.split('__')
85 | p1 = t1[-3] + '_' + t1[-2]
86 | p2 = t2[-3] + '_' + t2[-2]
87 | if (p1 == p2):
88 | return 1
89 | else:
90 | return 0
91 |
92 | def get_X2(self, path, params):
93 | B = Image.open(path).convert('RGB')
94 | transform_B = get_transform(self.opt, params)
95 | B_tensor = transform_B(B)
96 | return B_tensor, B
97 |
98 | def __getitem__(self, index):
99 | index = (index + self.delta) % len(self.all_paths)
100 | paths = self.all_paths[index].rstrip('\n').split('&')
101 | A = Image.open(paths[2])
102 | params = get_params(self.opt, A.size)
103 |
104 | B2_tensor_, B = self.get_X2(paths[2], params)
105 | A2_tensor_, A = self.get_X2(paths[3], params)
106 | A_tensor, A2_tensor = self.get_X(paths[0], params, A.size, A2_tensor_)
107 | B_tensor, B2_tensor = self.get_X(paths[1], params, A.size, B2_tensor_)
108 | C_tensor = C2_tensor = D_tensor = D2_tensor = 0
109 | if (self.opt.isTrain):
110 | C2_tensor_, C = self.get_X2(paths[5], params)
111 | C_tensor, C2_tensor = self.get_X(paths[4], params, A.size, C2_tensor_)
112 | D2_tensor_, D = self.get_X2(paths[7], params)
113 | D_tensor, D2_tensor = self.get_X(paths[6], params, A.size, D2_tensor_)
114 | input_dict = {'A': A_tensor, 'A2': A2_tensor, 'B': B_tensor, 'B2': B2_tensor, 'C': C_tensor, 'C2': C2_tensor,
115 | 'D': D_tensor, 'D2': D2_tensor, 'path': paths[0] + '_' + paths[1], 'same_style': self.same_style(paths[2], paths[3])}
116 | return input_dict
117 |
118 | def get_image(self, A_path, transform_scaleA):
119 | A_img = Image.open(A_path)
120 | A_scaled = transform_scaleA(self.crop(A_img))
121 | return A_scaled
122 |
123 | def get_face_image(self, A_path, transform_A, transform_L, size, img):
124 | # read face keypoints from path and crop face region
125 | keypoints, part_list, part_labels = self.read_keypoints(A_path, size)
126 |
127 | # draw edges and possibly add distance transform maps
128 | add_dist_map = not self.opt.no_dist_map
129 | im_edges, dist_tensor = self.draw_face_edges(keypoints, part_list, transform_A, size, add_dist_map)
130 |
131 | # canny edge for background
132 | if not self.opt.no_canny_edge:
133 | edges = feature.canny(np.array(img.convert('L')))
134 | edges = edges * (part_labels == 0) # remove edges within face
135 | im_edges += (edges * 255).astype(np.uint8)
136 | edge_tensor = transform_A(Image.fromarray(self.crop(im_edges)))
137 |
138 | # final input tensor
139 | input_tensor = torch.cat([edge_tensor, dist_tensor]) if add_dist_map else edge_tensor
140 | label_tensor = transform_L(Image.fromarray(self.crop(part_labels.astype(np.uint8)))) * 255.0
141 | return input_tensor, label_tensor
142 |
143 | def read_keypoints(self, A_path, size):
144 | # mapping from keypoints to face part
145 | part_list = [[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
146 | [range(17, 22)], # right eyebrow
147 | [range(22, 27)], # left eyebrow
148 | [[28, 31], range(31, 36), [35, 28]], # nose
149 | [[36,37,38,39], [39,40,41,36]], # right eye
150 | [[42,43,44,45], [45,46,47,42]], # left eye
151 | [range(48, 55), [54,55,56,57,58,59,48]], # mouth
152 | [range(60, 65), [64,65,66,67,60]] # tongue
153 | ]
154 | label_list = [1, 2, 2, 3, 4, 4, 5, 6] # labeling for different facial parts
155 | keypoints = np.loadtxt(A_path, delimiter=',')
156 |
157 | # add upper half face by symmetry
158 | pts = keypoints[:17, :].astype(np.int32)
159 | baseline_y = (pts[0,1] + pts[-1,1]) / 2
160 | upper_pts = pts[1:-1,:].copy()
161 | upper_pts[:,1] = baseline_y + (baseline_y-upper_pts[:,1]) * 2 // 3
162 | keypoints = np.vstack((keypoints, upper_pts[::-1,:]))
163 |
164 | # label map for facial part
165 | w, h = size
166 | part_labels = np.zeros((h, w), np.uint8)
167 | for p, edge_list in enumerate(part_list):
168 | indices = [item for sublist in edge_list for item in sublist]
169 | pts = keypoints[indices, :].astype(np.int32)
170 | cv2.fillPoly(part_labels, pts=[pts], color=label_list[p])
171 |
172 | return keypoints, part_list, part_labels
173 |
174 | def draw_face_edges(self, keypoints, part_list, transform_A, size, add_dist_map):
175 | w, h = size
176 | edge_len = 3 # interpolate 3 keypoints to form a curve when drawing edges
177 | # edge map for face region from keypoints
178 | im_edges = np.zeros((h, w), np.uint8) # edge map for all edges
179 | dist_tensor = 0
180 | e = 1
181 | for edge_list in part_list:
182 | for edge in edge_list:
183 | im_edge = np.zeros((h, w), np.uint8) # edge map for the current edge
184 | for i in range(0, max(1, len(edge)-1), edge_len-1): # divide a long edge into multiple small edges when drawing
185 | sub_edge = edge[i:i+edge_len]
186 | x = keypoints[sub_edge, 0]
187 | y = keypoints[sub_edge, 1]
188 |
189 | curve_x, curve_y = interpPoints(x, y) # interp keypoints to get the curve shape
190 | drawEdge(im_edges, curve_x, curve_y)
191 | if add_dist_map:
192 | drawEdge(im_edge, curve_x, curve_y)
193 |
194 | if add_dist_map: # add distance transform map on each facial part
195 | im_dist = cv2.distanceTransform(255-im_edge, cv2.DIST_L1, 3)
196 | im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
197 | im_dist = Image.fromarray(im_dist)
198 | tensor_cropped = transform_A(self.crop(im_dist))
199 | dist_tensor = tensor_cropped if e == 1 else torch.cat([dist_tensor, tensor_cropped])
200 | e += 1
201 |
202 | return im_edges, dist_tensor
203 |
204 | def get_crop_coords(self, keypoints, size):
205 | min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max()
206 | min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max()
207 | offset = (max_x - min_x) // 2
208 | min_y = max(0, min_y - offset*2)
209 | min_x = max(0, min_x - offset)
210 | max_x = min(size[0], max_x + offset)
211 | max_y = min(size[1], max_y + offset)
212 | self.min_y, self.max_y, self.min_x, self.max_x = int(min_y), int(max_y), int(min_x), int(max_x)
213 |
214 | def crop(self, img):
215 | # cropping is currently disabled: return the image unchanged (the code below is unreachable)
216 | return img
217 | if isinstance(img, np.ndarray):
218 | return img[self.min_y:self.max_y, self.min_x:self.max_x]
219 | else:
220 | return img.crop((self.min_x, self.min_y, self.max_x, self.max_y))
221 |
222 | def __len__(self):
223 | return len(self.all_paths) // self.opt.batchSize * self.opt.batchSize
224 |
225 | def name(self):
226 | return 'PoseDataset'
227 |
--------------------------------------------------------------------------------
/encode_features.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from options.train_options import TrainOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | import numpy as np
7 | import os
8 |
9 | opt = TrainOptions().parse()
10 | opt.nThreads = 1
11 | opt.batchSize = 1
12 | opt.serial_batches = True
13 | opt.no_flip = True
14 | opt.instance_feat = True
15 |
16 | name = 'features'
17 | save_path = os.path.join(opt.checkpoints_dir, opt.name)
18 |
19 | ############ Initialize #########
20 | data_loader = CreateDataLoader(opt)
21 | dataset = data_loader.load_data()
22 | dataset_size = len(data_loader)
23 | model = create_model(opt)
24 |
25 | ########### Encode features ###########
26 | reencode = True
27 | if reencode:
28 | features = {}
29 | for label in range(opt.label_nc):
30 | features[label] = np.zeros((0, opt.feat_num+1))
31 | for i, data in enumerate(dataset):
32 | feat = model.module.encode_features(data['image'], data['inst'])
33 | for label in range(opt.label_nc):
34 | features[label] = np.append(features[label], feat[label], axis=0)
35 |
36 | print('%d / %d images' % (i+1, dataset_size))
37 | save_name = os.path.join(save_path, name + '.npy')
38 | np.save(save_name, features)
39 |
40 | ############## Clustering ###########
41 | n_clusters = opt.n_clusters
42 | load_name = os.path.join(save_path, name + '.npy')
43 | features = np.load(load_name, allow_pickle=True).item()  # allow_pickle is required to load the pickled dict on newer NumPy
44 | from sklearn.cluster import KMeans
45 | centers = {}
46 | for label in range(opt.label_nc):
47 | feat = features[label]
48 | feat = feat[feat[:,-1] > 0.5, :-1]
49 | if feat.shape[0]:
50 | n_clusters = min(feat.shape[0], opt.n_clusters)
51 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
52 | centers[label] = kmeans.cluster_centers_
53 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
54 | np.save(save_name, centers)
55 | print('saving to %s' % save_name)
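56 | # Note: with the default name ('features') and n_clusters=10, this writes
57 | # 'features_clustered_010.npy', which matches the default --cluster_path in options/test_options.py.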
--------------------------------------------------------------------------------
/expand_val_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | def expand(path):
3 | f = open(path, "r")
4 | fs = f.readlines()
5 | ans = []
6 | for line in fs:
7 | paths = line.rstrip('\n').split('&')
8 | new = [paths[0], paths[1], paths[2], paths[3], paths[1], paths[2], paths[1], paths[2]]
9 | ans.append('&'.join(new) + '\n')
10 | fw = open(path + '_', "w")
11 | fw.writelines(ans)
12 | path = 'datasets/YouTubeFaces_/'
13 | expand(os.path.join(path, 'val_list.txt'))
14 | expand(os.path.join(path, 'test_list.txt'))
15 |
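16 | # Example of the expansion (field names are only illustrative): an input line
17 | # "A_key&B_key&B_img&A_img" is rewritten as "A_key&B_key&B_img&A_img&B_key&B_img&B_key&B_img",
18 | # i.e. the second and third fields are appended twice.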
--------------------------------------------------------------------------------
/generate_data_face_forensics.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from options.generate_data_options import GenerateDataOptions
6 | from data.data_loader import CreateDataLoader
7 | from PIL import Image
8 | import copy
9 | import numpy as np
10 | from shutil import copyfile
11 | import random
12 | import glob
13 | from skimage import transform,io
14 | import dlib
15 | import sys
16 | import time
17 |
18 | threshold = 130
19 | crop_rate = 0.6
20 | face_rate = 0.5
21 | up_center_rate = 0.2
22 | out_size = 512
23 | side_face_threshold = 0.7
24 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat')
25 | detector = dlib.get_frontal_face_detector()
26 | predictor = dlib.shape_predictor(predictor_path)
27 | time_last = time.time()
28 | FPS = 30
29 |
30 | def get_keys(img, get_point = True):
31 | global time_last
32 | time_start = time.time()
33 | #print("other:" + str(time_start - time_last))
34 | dets = detector(img, 1)
35 | #print(time.time() - time_start)
36 | time_last = time.time()
37 | detected = False
38 | points = []
39 | left_up_x = 10000000
40 | left_up_y = 10000000
41 | right_down_x = -1
42 | right_down_y = -1
43 |
44 | if len(dets) > 0:
45 | detected = True
46 | if (get_point == False):
47 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()]
48 | else:
49 | shape = predictor(img, dets[0])
50 | points = np.empty([68, 2], dtype=int)
51 | for b in range(68):
52 | points[b,0] = shape.part(b).x
53 | points[b,1] = shape.part(b).y
54 | left_up_x = min(left_up_x, points[b, 0])
55 | left_up_y = min(left_up_y, points[b, 1])
56 | right_down_x = max(right_down_x, points[b, 0])
57 | right_down_y = max(right_down_y, points[b, 1])
58 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0):
59 | detected = False
60 | else:
61 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0])
62 | r = min(r, 1 / r)
63 | if (r < side_face_threshold):
64 | detected = False
65 |
66 | return detected, points, [left_up_x, left_up_y, right_down_x, right_down_y]
67 |
68 | def get_img(path, target_path):
69 | #print("233")
70 | img = io.imread(path)
71 | detected, _, box = get_keys(img)
72 | if (detected):
73 | h, w, c = img.shape
74 | b_h = box[3] - box[1]
75 | b_w = box[2] - box[0]
76 | dh = int((b_h / face_rate - b_h) / 2)
77 | dw = int((b_w / face_rate - b_w) / 2)
78 | ddh = int(b_h * up_center_rate)
79 | if (box[1] - dh - ddh < 0 or box[3] + dh - ddh >= h or box[0] - dw < 0 or box[2] + dw >= w):
80 | return False, None
81 | else:
82 | img_ = img[box[1] - dh - ddh : box[3] + dh - ddh, box[0] - dw : box[2] + dw, :].copy()
83 | if (img_.shape[0] < out_size / 4 or img_.shape[1] < out_size / 4):
84 | return False, None
85 | #dh = int(h * (1 - crop_rate) / 2)
86 | #dw = int(w * (1 - crop_rate) / 2)
87 | #img_ = img[dh:h-dh, dw:w-dw, :].copy()
88 | img_ = transform.resize(img_, (out_size, out_size))
89 | #time_s = time.time()
90 | #io.imsave(target_path, img_)
91 | #print("writetime:" + str(time.time() - time_s))
92 | img_ = (np.maximum(np.minimum(255, img_ * 256), 0)).astype(np.uint8)
93 |
94 | return True, img_
95 | else:
96 | return False, None
97 |
98 | def deal(datas, rootpath):
99 | paths = []
100 | path = []
101 | last = ''
102 | min_hw = [1000, 1000]
103 | for data in datas:
104 | paras = data.split(',')
105 | names = paras[0].split('\\')
106 | if (last != names[0] + '\\' + names[1]):
107 | if min_hw[0] > threshold and min_hw[1] > threshold:
108 | paths.append(path)
109 | path = []
110 | min_hw = [1000, 1000]
111 | last = names[0] + '\\' + names[1]
112 | min_hw[0] = min(min_hw[0], int(paras[4]))
113 | min_hw[1] = min(min_hw[1], int(paras[5]))
114 | path.append(os.path.join(rootpath, 'aligned_images_DB', names[0], names[1], 'aligned_detect_' + names[2]))
115 |
116 | if min_hw[0] > threshold and min_hw[1] > threshold:
117 | paths.append(path)
118 | return paths
119 |
120 | def get_num(path):
121 | t = path.find("a.")
122 | return int(path[t + 2 : -4])
123 |
124 | def deal_vid(paths):
125 | paths = sorted(paths, key=lambda path: get_num(path))
126 | ans = []
127 | for i in range(len(paths)):
128 | if (i % 10 == 0):
129 | ans.append(paths[i])
130 | return ans
131 |
132 | def get_paths(rootpath):
133 | data_path = os.path.join(rootpath, 'frame_images_DB')
134 | data_path = rootpath  # overrides the line above; frames are read directly from rootpath
135 | paths = []
136 | for root, _, fnames in os.walk(data_path):
137 | temp = []
138 | for fname in fnames:
139 | if not fname.endswith('.jpg'):
140 | continue
141 | temp.append(os.path.join(root, fname))
142 | temp = deal_vid(temp)
143 | print(len(temp))
144 | paths.append([temp])
145 | random.shuffle(paths)
146 | paths_ = []
147 | for human in paths:
148 | for vid in human:
149 | paths_.append(vid)
150 | return paths_
151 |
152 | def transform_name(name, root):
153 | name = name[len(root):]
154 | name = name.replace('/', '__').replace(' ', '_')
155 | return name
156 |
157 | def neighbor_index(index, length, n_size):
158 | left = max(0, index[1] - n_size)
159 | right = min(length - 1, index[1] + n_size)
160 | return [index[0], np.random.randint(left, right + 1)]
161 |
162 | def make_list(paths, phase, opt):
163 | indexs = []
164 | for i in range(len(paths)):
165 | for j in range(paths[i]['len']):
166 | indexs.append([i, j])
167 |
168 | ans_list = []
169 | for i in range(len(paths)):
170 | for j in range(paths[i]['len']):
171 | for k in range(opt.A_repeat_num):
172 | '''
173 | if (phase == 'val' or phase == 'test'):
174 | index = indexs[np.random.randint(len(indexs))]
175 | context = paths[i]['label_names'][j] \
176 | + '&' + paths[index[0]]['label_names'][index[1]] + '&' + paths[index[0]]['img_names'][index[1]] \
177 | + '&' + paths[i]['img_names'][j] \
178 | + '\n'
179 | ans_list.append(context)
180 | else:
181 | '''
182 | if (random.random() < opt.same_style_rate):
183 | index_B = neighbor_index([i, j], paths[i]['len'], opt.neighbor_size)
184 | else:
185 | index_B = indexs[np.random.randint(len(indexs))]
186 | index_C = neighbor_index(index_B, paths[index_B[0]]['len'], opt.neighbor_size)
187 | index_D = index_C
188 | while (index_D[0] == index_C[0]):
189 | index_D = indexs[np.random.randint(len(indexs))]
190 | context = paths[i]['label_names'][j] \
191 | + '&' + paths[index_B[0]]['label_names'][index_B[1]] + '&' + paths[index_B[0]]['img_names'][index_B[1]] \
192 | + '&' + paths[i]['img_names'][j] \
193 | + '&' + paths[index_C[0]]['label_names'][index_C[1]] + '&' + paths[index_C[0]]['img_names'][index_C[1]] \
194 | + '&' + paths[index_D[0]]['label_names'][index_D[1]] + '&' + paths[index_D[0]]['img_names'][index_D[1]] \
195 | + '\n'
196 | ans_list.append(context)
197 |
198 | f = open(os.path.join(opt.target_path, phase + '_list.txt'), 'w')
199 | f.writelines(ans_list)
200 |
201 | opt = GenerateDataOptions().parse(save=False)
202 | opt.neighbor_size *= FPS
203 | paths = get_paths(opt.source_path)
204 | label_path = os.path.join(opt.target_path, 'keypoints/')
205 | img_path = os.path.join(opt.target_path, 'img/')
206 | if (not os.path.exists(label_path)):
207 | os.makedirs(label_path)
208 | if (not os.path.exists(img_path)):
209 | os.makedirs(img_path)
210 | paths_ = []
211 | ans = 0
212 | for i in range(len(paths)):
213 | ans += len(paths[i])
214 | print(ans)
215 | tot = 0
216 | for i in range(len(paths)):
217 | if (opt.copy_data):
218 | print(str(tot) + ' ' + str(len(paths[i])))
219 | img_names = []
220 | label_names = []
221 | for j in range(len(paths[i])):
222 | img_name = transform_name(paths[i][j], opt.source_path)
223 | label_name = img_name[: -3] + 'txt'
224 | if (opt.copy_data):
225 | useable, img = get_img(paths[i][j], 'datasets/YouTubeFaces/temp.jpg')
226 | if not useable:
227 | continue
228 | #img = io.imread('datasets/YouTubeFaces/temp.jpg')
229 | detected, keys, box = get_keys(img)
230 | if detected:
231 | tot += 1
232 | io.imsave(os.path.join(img_path, img_name), img)
233 | np.savetxt(os.path.join(label_path, label_name), keys, fmt='%d', delimiter=',')
234 | img_names.append(img_path + img_name)
235 | label_names.append(label_path + label_name)
236 | else:
237 | if os.path.exists(label_path + label_name):
238 | img_names.append(img_path + img_name)
239 | label_names.append(label_path + label_name)
240 | paths_.append({'label_names': label_names, 'img_names':img_names, 'len': len(img_names)})
241 | paths = paths_
242 | img_num = 0
243 | for i in range(len(paths)):
244 | img_num += len(paths[i]['img_names'])
245 | print(img_num)
246 | val_size = int(img_num * opt.val_ratio)
247 | test_size = int(img_num * opt.test_ratio)
248 | did_val = False
249 | tot = 0
250 | last = 0
251 | last_name = ''
252 | for i in range(len(paths)):
253 | tot += len(paths[i]['img_names'])
254 | if (len(paths[i]['img_names']) > 0):
255 | name = paths[i]['img_names'][0].split('__')[1]
256 | if ((not did_val) and tot >= val_size and last_name != name):
257 | print("=============")
258 | did_val = True
259 | tot = 0
260 | last = i + 1
261 | make_list(paths[0:i], 'val', opt)
262 | elif (did_val and tot >= test_size and last_name != name):
263 | print("=============")
264 | make_list(paths[last:i], 'test', opt)
265 | last = i + 1
266 | break
267 | print(name)
268 | last_name = name
269 | make_list(paths[last:], 'train', opt)
270 |
271 |
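272 | # Rough pipeline sketch: frames are gathered under --source_path (every 10th frame kept),
273 | # faces are detected and cropped with dlib and saved with their 68-point keypoints under --target_path,
274 | # and val/test/train lists are then written, switching subsets only at identity boundaries.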
--------------------------------------------------------------------------------
/img/dance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/dance.png
--------------------------------------------------------------------------------
/img/face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/face.png
--------------------------------------------------------------------------------
/img/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/fig.png
--------------------------------------------------------------------------------
/img/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/scene.png
--------------------------------------------------------------------------------
/img/small_dance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_dance.png
--------------------------------------------------------------------------------
/img/small_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_face.png
--------------------------------------------------------------------------------
/img/small_scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/img/small_scene.png
--------------------------------------------------------------------------------
/infer_face.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateFaceConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import random
13 | import torch
14 | from skimage import transform,io
15 | import dlib
16 | import shutil
17 | import numpy as np
18 |
19 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat')
20 | detector = dlib.get_frontal_face_detector()
21 | predictor = dlib.shape_predictor(predictor_path)
22 |
23 | def get_keys(img, get_point = True):
24 | dets = detector(img, 1)
25 | detected = False
26 | points = []
27 | left_up_x = 10000000
28 | left_up_y = 10000000
29 | right_down_x = -1
30 | right_down_y = -1
31 |
32 | if len(dets) > 0:
33 | detected = True
34 | if (get_point == False):
35 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()]
36 | else:
37 | shape = predictor(img, dets[0])
38 | points = np.empty([68, 2], dtype=int)
39 | for b in range(68):
40 | points[b,0] = shape.part(b).x
41 | points[b,1] = shape.part(b).y
42 | left_up_x = min(left_up_x, points[b, 0])
43 | left_up_y = min(left_up_y, points[b, 1])
44 | right_down_x = max(right_down_x, points[b, 0])
45 | right_down_y = max(right_down_y, points[b, 1])
46 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0):
47 | detected = False
48 | else:
49 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0])
50 | r = min(r, 1 / r)
51 |
52 | return detected, points, [left_up_x, left_up_y, right_down_x, right_down_y]
53 |
54 | def get_img(path):
55 | face_rate = 0.5
56 | up_center_rate = 0.2
57 | out_size = 512
58 | img = io.imread(path)
59 | detected, _, box = get_keys(img)
60 | if (detected):
61 | h, w, c = img.shape
62 | b_h = box[3] - box[1]
63 | b_w = box[2] - box[0]
64 | dh = int((b_h / face_rate - b_h) / 2)
65 | dw = int((b_w / face_rate - b_w) / 2)
66 | ddh = int(b_h * up_center_rate)
67 | if (box[1] - dh - ddh < 0 or box[3] + dh - ddh >= h or box[0] - dw < 0 or box[2] + dw >= w):
68 | return False, None
69 | else:
70 | img_ = img[box[1] - dh - ddh : box[3] + dh - ddh, box[0] - dw : box[2] + dw, :].copy()
71 | if (img_.shape[0] < out_size / 4 or img_.shape[1] < out_size / 4):
72 | return False, None
73 | img_ = transform.resize(img_, (out_size, out_size))
74 | img_ = (np.maximum(np.minimum(255, img_ * 256), 0)).astype(np.uint8)
75 | return True, img_
76 | else:
77 | return False, None
78 |
79 | def get_info(path, img_path, key_path):
80 | useable, img = get_img(path)
81 | assert useable
82 | detected, keys, box = get_keys(img)
83 | assert detected
84 | io.imsave(img_path, img)
85 | np.savetxt(key_path, keys, fmt='%d', delimiter=',')
86 | return img_path, key_path
87 |
88 |
89 | def make_data():
90 | assert os.path.exists('inference/infer_list.txt')
91 | with open('inference/infer_list.txt', 'r') as f:
92 | tasks = f.readlines()
93 | if os.path.exists('inference/data'):
94 | shutil.rmtree('inference/data')
95 | if os.path.exists('inference/output'):
96 | shutil.rmtree('inference/output')
97 | os.makedirs('inference/data')
98 | os.makedirs('inference/output')
99 | img_path = 'inference/data/img'
100 | key_path = 'inference/data/keypoints'
101 | os.makedirs(img_path)
102 | os.makedirs(key_path)
103 |
104 | ans_list = []
105 | for i in range(len(tasks)):
106 | name = '__Mr_'+str(i)+'__0__a.'
107 | img1, key1 = get_info(tasks[i].split(' ')[0], os.path.join(img_path, name + '0.jpg'), os.path.join(key_path, name + '0.txt'))
108 | img2, key2 = get_info(tasks[i].split(' ')[1].rstrip('\n'), os.path.join(img_path, name + '1.jpg'), os.path.join(key_path, name + '1.txt'))
109 |
110 | context = key1 \
111 | + '&' + key2 + '&' + img2 \
112 | + '&' + img1 \
113 | + '&' + key1 + '&' + img1 \
114 | + '&' + key1 + '&' + img1 \
115 | + '\n'
116 | ans_list.append(context)
117 | f = open('inference/data/infer_list.txt', 'w')
118 | f.writelines(ans_list)
119 |
120 | make_data()
121 | opt = TestOptions().parse(save=False)
122 | opt.nThreads = 1 # test code only supports nThreads = 1
123 | opt.batchSize = 1 # test code only supports batchSize = 1
124 | opt.serial_batches = True # no shuffle
125 | opt.no_flip = True # no flip
126 |
127 | data_loader = CreateFaceConDataLoader(opt)
128 | dataset = data_loader.load_data()
129 | visualizer = Visualizer(opt)
130 | # create website
131 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches)))
132 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
133 |
134 | # test
135 | if not opt.engine and not opt.onnx:
136 | model = create_model(opt)
137 | if opt.data_type == 16:
138 | model.half()
139 | elif opt.data_type == 8:
140 | model.type(torch.uint8)
141 |
142 | if opt.verbose:
143 | print(model)
144 | else:
145 | from run_engine import run_trt_engine, run_onnx
146 |
147 | for i, data in enumerate(dataset, start=0):
148 | if opt.data_type == 16:
149 | data['A'] = data['A'].half()
150 | data['A2'] = data['A2'].half()
151 | data['B'] = data['B'].half()
152 | data['B2'] = data['B2'].half()
153 | elif opt.data_type == 8:
154 | data['A'] = data['A'].byte()   # torch.Tensor has no .uint8() method; .byte() converts to uint8
155 | data['A2'] = data['A2'].byte()
156 | data['B'] = data['B'].byte()
157 | data['B2'] = data['B2'].byte()
158 | if opt.export_onnx:
159 | print ("Exporting to ONNX: ", opt.export_onnx)
160 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
161 | torch.onnx.export(model, [data['label'], data['inst']],
162 | opt.export_onnx, verbose=True)
163 | exit(0)
164 | minibatch = 1
165 | if opt.engine:
166 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
167 | elif opt.onnx:
168 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
169 | else:
170 | generated = model.inference(data['A'], data['B'], data['B2'])
171 |
172 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
173 | ('real_image', util.tensor2im(data['A2'][0])),
174 | ('synthesized_image', util.tensor2im(generated.data[0])),
175 | ('B', util.tensor2label(data['B'][0], 0)),
176 | ('B2', util.tensor2im(data['B2'][0]))])
177 | img_path = data['path']
178 | img_path[0] = str(i)
179 | print('process image... %s' % img_path)
180 | visualizer.save_images(webpage, visuals, img_path)
181 |
182 | webpage.save()
183 |
--------------------------------------------------------------------------------
/inference/data/img/__Mr_0__0__a.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_0__0__a.0.jpg
--------------------------------------------------------------------------------
/inference/data/img/__Mr_0__0__a.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_0__0__a.1.jpg
--------------------------------------------------------------------------------
/inference/data/img/__Mr_1__0__a.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_1__0__a.0.jpg
--------------------------------------------------------------------------------
/inference/data/img/__Mr_1__0__a.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/data/img/__Mr_1__0__a.1.jpg
--------------------------------------------------------------------------------
/inference/data/infer_list.txt:
--------------------------------------------------------------------------------
1 | inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/keypoints/__Mr_0__0__a.1.txt&inference/data/img/__Mr_0__0__a.1.jpg&inference/data/img/__Mr_0__0__a.0.jpg&inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/img/__Mr_0__0__a.0.jpg&inference/data/keypoints/__Mr_0__0__a.0.txt&inference/data/img/__Mr_0__0__a.0.jpg
2 | inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/keypoints/__Mr_1__0__a.1.txt&inference/data/img/__Mr_1__0__a.1.jpg&inference/data/img/__Mr_1__0__a.0.jpg&inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/img/__Mr_1__0__a.0.jpg&inference/data/keypoints/__Mr_1__0__a.0.txt&inference/data/img/__Mr_1__0__a.0.jpg
3 |
--------------------------------------------------------------------------------
/inference/data/keypoints/__Mr_0__0__a.0.txt:
--------------------------------------------------------------------------------
1 | 126,246
2 | 130,279
3 | 134,312
4 | 140,345
5 | 153,374
6 | 175,397
7 | 202,414
8 | 231,427
9 | 264,430
10 | 295,426
11 | 322,412
12 | 345,394
13 | 362,370
14 | 371,340
15 | 376,310
16 | 381,279
17 | 384,249
18 | 162,209
19 | 176,190
20 | 198,180
21 | 222,181
22 | 244,190
23 | 284,192
24 | 307,185
25 | 330,185
26 | 350,196
27 | 362,215
28 | 265,228
29 | 266,247
30 | 267,266
31 | 268,285
32 | 240,308
33 | 253,310
34 | 265,313
35 | 277,311
36 | 288,309
37 | 187,233
38 | 201,222
39 | 218,223
40 | 231,237
41 | 217,239
42 | 199,238
43 | 295,239
44 | 309,226
45 | 325,226
46 | 338,238
47 | 326,242
48 | 309,242
49 | 216,348
50 | 237,337
51 | 254,331
52 | 264,333
53 | 274,331
54 | 289,339
55 | 306,350
56 | 289,356
57 | 275,358
58 | 263,358
59 | 252,357
60 | 237,355
61 | 224,347
62 | 253,342
63 | 264,343
64 | 274,343
65 | 299,349
66 | 274,344
67 | 264,344
68 | 253,343
69 |
--------------------------------------------------------------------------------
/inference/data/keypoints/__Mr_0__0__a.1.txt:
--------------------------------------------------------------------------------
1 | 129,217
2 | 127,253
3 | 128,289
4 | 132,326
5 | 147,359
6 | 172,387
7 | 199,410
8 | 226,429
9 | 255,435
10 | 283,429
11 | 305,405
12 | 328,383
13 | 348,359
14 | 365,332
15 | 378,304
16 | 383,274
17 | 385,244
18 | 161,188
19 | 185,179
20 | 211,181
21 | 236,188
22 | 260,199
23 | 304,201
24 | 326,194
25 | 349,191
26 | 370,194
27 | 383,208
28 | 281,231
29 | 279,256
30 | 278,281
31 | 277,306
32 | 241,313
33 | 256,319
34 | 271,325
35 | 285,323
36 | 298,320
37 | 188,217
38 | 205,211
39 | 223,214
40 | 239,230
41 | 220,228
42 | 202,226
43 | 309,237
44 | 326,224
45 | 344,225
46 | 357,235
47 | 345,240
48 | 327,239
49 | 193,332
50 | 223,335
51 | 251,338
52 | 266,343
53 | 282,342
54 | 302,345
55 | 321,346
56 | 298,373
57 | 277,385
58 | 260,385
59 | 243,381
60 | 217,364
61 | 200,336
62 | 249,346
63 | 265,350
64 | 281,349
65 | 312,349
66 | 279,369
67 | 262,370
68 | 246,365
69 |
--------------------------------------------------------------------------------
/inference/data/keypoints/__Mr_1__0__a.0.txt:
--------------------------------------------------------------------------------
1 | 129,217
2 | 127,253
3 | 128,289
4 | 132,326
5 | 147,359
6 | 172,387
7 | 199,410
8 | 226,429
9 | 255,435
10 | 283,429
11 | 305,405
12 | 328,383
13 | 348,359
14 | 365,332
15 | 378,304
16 | 383,274
17 | 385,244
18 | 161,188
19 | 185,179
20 | 211,181
21 | 236,188
22 | 260,199
23 | 304,201
24 | 326,194
25 | 349,191
26 | 370,194
27 | 383,208
28 | 281,231
29 | 279,256
30 | 278,281
31 | 277,306
32 | 241,313
33 | 256,319
34 | 271,325
35 | 285,323
36 | 298,320
37 | 188,217
38 | 205,211
39 | 223,214
40 | 239,230
41 | 220,228
42 | 202,226
43 | 309,237
44 | 326,224
45 | 344,225
46 | 357,235
47 | 345,240
48 | 327,239
49 | 193,332
50 | 223,335
51 | 251,338
52 | 266,343
53 | 282,342
54 | 302,345
55 | 321,346
56 | 298,373
57 | 277,385
58 | 260,385
59 | 243,381
60 | 217,364
61 | 200,336
62 | 249,346
63 | 265,350
64 | 281,349
65 | 312,349
66 | 279,369
67 | 262,370
68 | 246,365
69 |
--------------------------------------------------------------------------------
/inference/data/keypoints/__Mr_1__0__a.1.txt:
--------------------------------------------------------------------------------
1 | 126,246
2 | 130,279
3 | 134,312
4 | 140,345
5 | 153,374
6 | 175,397
7 | 202,414
8 | 231,427
9 | 264,430
10 | 295,426
11 | 322,412
12 | 345,394
13 | 362,370
14 | 371,340
15 | 376,310
16 | 381,279
17 | 384,249
18 | 162,209
19 | 176,190
20 | 198,180
21 | 222,181
22 | 244,190
23 | 284,192
24 | 307,185
25 | 330,185
26 | 350,196
27 | 362,215
28 | 265,228
29 | 266,247
30 | 267,266
31 | 268,285
32 | 240,308
33 | 253,310
34 | 265,313
35 | 277,311
36 | 288,309
37 | 187,233
38 | 201,222
39 | 218,223
40 | 231,237
41 | 217,239
42 | 199,238
43 | 295,239
44 | 309,226
45 | 325,226
46 | 338,238
47 | 326,242
48 | 309,242
49 | 216,348
50 | 237,337
51 | 254,331
52 | 264,333
53 | 274,331
54 | 289,339
55 | 306,350
56 | 289,356
57 | 275,358
58 | 263,358
59 | 252,357
60 | 237,355
61 | 224,347
62 | 253,342
63 | 264,343
64 | 274,343
65 | 299,349
66 | 274,344
67 | 264,344
68 | 253,343
69 |
--------------------------------------------------------------------------------
/inference/infer_list.txt:
--------------------------------------------------------------------------------
1 | inference/test_imgs/1.jpg inference/test_imgs/2.jpg
2 | inference/test_imgs/2.jpg inference/test_imgs/1.jpg
3 |
--------------------------------------------------------------------------------
/inference/test_imgs/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/test_imgs/1.jpg
--------------------------------------------------------------------------------
/inference/test_imgs/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/inference/test_imgs/2.jpg
--------------------------------------------------------------------------------
/judge_face.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from options.generate_data_options import GenerateDataOptions
6 | from data.data_loader import CreateDataLoader
7 | from PIL import Image
8 | import copy
9 | import numpy as np
10 | from shutil import copyfile
11 | import random
12 | import glob
13 | from skimage import transform,io
14 | import dlib
15 | import sys
16 | import time
17 | import math
18 |
19 | threshold = 130
20 | crop_rate = 0.6
21 | face_rate = 0.5
22 | up_center_rate = 0.2
23 | out_size = 512
24 | side_face_threshold = 0.7
25 | predictor_path = os.path.join('./datasets/', 'shape_predictor_68_face_landmarks.dat')
26 | detector = dlib.get_frontal_face_detector()
27 | predictor = dlib.shape_predictor(predictor_path)
28 | time_last = time.time()
29 |
30 | pts_max_score = 1000  # .1 * 256
31 | pic_max_score = 10  # 0.007 * 256
32 | #pic_size = 256.0
33 | pic_size = 1.0
34 |
35 | def get_keys(img, get_point = True):
36 | global time_last
37 | time_start = time.time()
38 | #print("other:" + str(time_start - time_last))
39 | dets = detector(img, 1)
40 | #print(time.time() - time_start)
41 | time_last = time.time()
42 | detected = False
43 | points = []
44 | left_up_x = 10000000
45 | left_up_y = 10000000
46 | right_down_x = -1
47 | right_down_y = -1
48 |
49 | if len(dets) > 0:
50 | detected = True
51 | if (get_point == False):
52 | return True, [], [dets[0].left(), dets[0].top(), dets[0].right(), dets[0].bottom()]
53 | else:
54 | shape = predictor(img, dets[0])
55 | points = np.empty([68, 2], dtype=int)
56 | for b in range(68):
57 | points[b,0] = shape.part(b).x
58 | points[b,1] = shape.part(b).y
59 | left_up_x = min(left_up_x, points[b, 0])
60 | left_up_y = min(left_up_y, points[b, 1])
61 | right_down_x = max(right_down_x, points[b, 0])
62 | right_down_y = max(right_down_y, points[b, 1])
63 | '''
64 | if (abs(points[8, 0] - points[1, 0]) == 0 or abs(points[8, 0] - points[15, 0]) == 0):
65 | detected = False
66 | else:
67 | r = float(abs(points[8, 0] - points[1, 0])) / abs(points[8, 0] - points[15, 0])
68 | r = min(r, 1 / r)
69 | if (r < side_face_threshold):
70 | detected = False
71 | '''
72 |
73 | return detected, points, [left_up_x, left_up_y, right_down_x, right_down_y]
74 | def transfer_list(list1):
75 | list2 = []
76 | for v in list1:
77 | list2.append(v[:-len('synthesized_image.jpg')] + 'real_image.jpg')
78 | #list2.append(v[:-len('synthesized_image.jpg')] + 'B2.jpg')
79 | return list2
80 | '''
81 |
82 | def transfer_list(list1):
83 | list2 = []
84 | for v in list1:
85 | id=v.split('_')[-1][:-4]
86 | t = 'results/label2city_256p_face_102/test_final_latest_True/images/'+id+'_real_image.jpg'
87 | list2.append(t)
88 | return list2
89 | '''
90 |
91 | def dist(p1, p2):
92 | return math.sqrt((p1[0] - p2[0]) * (p1[0] - p2[0]) + (p1[1] - p2[1]) * (p1[1] - p2[1]))
93 |
94 | path = 'results/label2city_256p_face_104/test_final_latest_True/images'
95 |
96 | list1 = []
97 | for root, _, fnames in os.walk(path):
98 | for fname in fnames:
99 | if (fname.endswith('_synthesized_image.jpg')):
100 | list1.append(os.path.join(root, fname))
101 | list2 = transfer_list(list1)
102 | score = 0
103 | print(len(list1))
104 | for i in range(len(list1)):
105 | img1 = io.imread(list1[i])
106 | img2 = io.imread(list2[i])
107 | detected1, points1, _ = get_keys(img1)
108 | detected2, points2, _ = get_keys(img2)
109 | if (not detected1) or (not detected2):
110 | temp_score = pic_max_score
111 | else:
112 | points1 = points1 / pic_size
113 | points2 = points2 / pic_size
114 | temp_score = 0
115 | for j in range(68):
116 | d = dist(points1[j, :], points2[j, :])
117 | d = min(pts_max_score, d)
118 | temp_score += d
119 | temp_score = min(pic_max_score, temp_score / 68)
120 | score += temp_score
121 | if (i % 10 == 0):
122 | print(i)
123 | print(score / (i + 1)/256)
124 | print(temp_score/256)
125 | score /= len(list1)
126 | print(score/256)
127 |
128 |
129 |
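130 | # Rough reading of the score (assuming 256x256 outputs, hence the division by 256 above):
131 | # it is the mean per-landmark distance between each synthesized image and its real counterpart,
132 | # with pic_max_score charged whenever a face is not detected in either image.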
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/models/__init__.py
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | import torch
5 | import sys
6 |
7 | class BaseModel(torch.nn.Module):
8 | def name(self):
9 | return 'BaseModel'
10 |
11 | def initialize(self, opt):
12 | self.opt = opt
13 | self.gpu_ids = opt.gpu_ids
14 | self.isTrain = opt.isTrain
15 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
16 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
17 |
18 | def set_input(self, input):
19 | self.input = input
20 |
21 | def forward(self):
22 | pass
23 |
24 | # used in test time, no backprop
25 | def test(self):
26 | pass
27 |
28 | def get_image_paths(self):
29 | pass
30 |
31 | def optimize_parameters(self):
32 | pass
33 |
34 | def get_current_visuals(self):
35 | return self.input
36 |
37 | def get_current_errors(self):
38 | return {}
39 |
40 | def save(self, label):
41 | pass
42 |
43 | # helper saving function that can be used by subclasses
44 | def save_network(self, network, network_label, epoch_label, gpu_ids):
45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46 | save_path = os.path.join(self.save_dir, save_filename)
47 | torch.save(network.cpu().state_dict(), save_path)
48 | if len(gpu_ids) and torch.cuda.is_available():
49 | network.cuda()
50 |
51 | # helper loading function that can be used by subclasses
52 | def load_network(self, network, network_label, epoch_label, save_dir=''):
53 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
54 | if not save_dir:
55 | save_dir = self.save_dir
56 | save_path = os.path.join(save_dir, save_filename)
57 | if not os.path.isfile(save_path):
58 | print('%s does not exist yet!' % save_path)
59 | if network_label == 'G':
60 | raise Exception('Generator must exist!')
61 | else:
62 | #network.load_state_dict(torch.load(save_path))
63 | try:
64 | network.load_state_dict(torch.load(save_path))
65 | except:
66 | pretrained_dict = torch.load(save_path)
67 | model_dict = network.state_dict()
68 | try:
69 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
70 | network.load_state_dict(pretrained_dict)
71 | if self.opt.verbose:
72 | print('Pretrained network %s has extra layers; only loading the layers that are used' % network_label)
73 | except:
74 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
75 | for k, v in pretrained_dict.items():
76 | if v.size() == model_dict[k].size():
77 | model_dict[k] = v
78 |
79 | if sys.version_info >= (3,0):
80 | not_initialized = set()
81 | else:
82 | from sets import Set
83 | not_initialized = Set()
84 |
85 | for k, v in model_dict.items():
86 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
87 | not_initialized.add(k.split('.')[0])
88 |
89 | print(sorted(not_initialized))
90 | network.load_state_dict(model_dict)
91 |
92 | def update_learning_rate(self):
93 | pass
94 |
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch
4 |
5 | def create_model(opt):
6 | if opt.model == 'pix2pixHD':
7 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
8 | if opt.isTrain:
9 | model = Pix2PixHDModel()
10 | else:
11 | model = InferenceModel()
12 | elif opt.model == 'c_pix2pixHD':
13 | from .c_pix2pixHD_model import CPix2PixHDModel, InferenceModel
14 | if opt.isTrain:
15 | model = CPix2PixHDModel()
16 | else:
17 | model = InferenceModel()
18 | elif opt.model == 'cm_pix2pixHD':
19 | from .cm_pix2pixHD_model import CMPix2PixHDModel, InferenceModel
20 | if opt.isTrain:
21 | model = CMPix2PixHDModel()
22 | else:
23 | model = InferenceModel()
24 | else:
25 | from .ui_model import UIModel
26 | model = UIModel()
27 | model.initialize(opt)
28 | if opt.verbose:
29 | print("model [%s] was created" % (model.name()))
30 |
31 | if opt.isTrain and len(opt.gpu_ids):
32 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
33 |
34 | return model
35 |
--------------------------------------------------------------------------------
/new_scripts/infer_face.sh:
--------------------------------------------------------------------------------
1 | ### Run face inference on the image pairs listed in inference/infer_list.txt
2 | python infer_face.py --name label2city_256p_face_102 --label_nc 0 --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './inference/data/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'infer' --gpu_ids 0;
3 |
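4 | # inference/infer_list.txt is expected to hold one space-separated pair of image paths per line;
5 | # the first image supplies the target keypoints (pose) and the second the appearance reference,
6 | # with crops and keypoint files generated by infer_face.py itself.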
--------------------------------------------------------------------------------
/new_scripts/table_face.sh:
--------------------------------------------------------------------------------
1 | ### Generate the face visualization table on the validation set
2 | python gen_vis_face.py --name label2city_256p_face_102 --label_nc 0 --test_delta_path 'vis_delta.txt' --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'val' --gpu_ids 0;
3 |
--------------------------------------------------------------------------------
/new_scripts/table_pose.sh:
--------------------------------------------------------------------------------
1 | ### Generate the pose visualization table on the validation set
2 | python gen_vis_pose.py --name label2city_256p_pose_106 --label_nc 0 --test_delta_path "vis_delta.txt" --no_canny_edge --input_nc 23 --do_pose_dist_map --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'val' --gpu_ids 1;
3 |
--------------------------------------------------------------------------------
/new_scripts/table_scene.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | #python encode_features.py --name label2city_256p_feat;
4 | # use instance-wise features
5 |
6 | python gen_vis_bdd.py --name label2city_256p_face_108 --label_nc 42 --use_new_label --resize_or_crop 'scale_width_and_crop' --fineSize 256 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255' --dataroot './datasets/bdd2/' --model 'c_pix2pixHD' --gpu_ids 2 --no_instance --loadSize 256 --phase 'val';
7 |
--------------------------------------------------------------------------------
/new_scripts/test_face.sh:
--------------------------------------------------------------------------------
1 | ### Test the face model
2 | python test_face.py --name label2city_256p_face_102 --how_many 2500 --label_nc 0 --no_canny_edge --input_nc 15 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'test_final' --gpu_ids 0;
3 |
--------------------------------------------------------------------------------
/new_scripts/test_pose.sh:
--------------------------------------------------------------------------------
1 | ### Test the pose model
2 | python test_pose.py --name label2city_256p_pose_106 --how_many 2500 --label_nc 0 --test_delta_path "vis_delta.txt" --no_canny_edge --input_nc 23 --do_pose_dist_map --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --loadSize 256 --phase 'test_final' --gpu_ids 3;
3 |
--------------------------------------------------------------------------------
/new_scripts/test_scene.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | #python encode_features.py --name label2city_256p_feat;
4 | # use instance-wise features
5 |
6 | python test_seg.py --name label2city_256p_face_108 --use_new_label --how_many 2500 --label_nc 42 --resize_or_crop 'scale_width_and_crop' --fineSize 256 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255' --dataroot './datasets/bdd2/' --model 'c_pix2pixHD' --gpu_ids 3 --no_instance --loadSize 256 --phase 'test_final';
7 |
--------------------------------------------------------------------------------
/new_scripts/train_face.sh:
--------------------------------------------------------------------------------
1 | ### Train the face model
2 | python train_face.py --name label2city_256p_face_1 --use_self_loss --self_vgg_mul 20 --no_D_label --no_ganFeat_loss --label_nc 0 --no_canny_edge --style_stage_mul "0:0,500000:0.1,550000:0.3,600000:0.5,700000:1" --niter_iter 1500000 --use_style_iter 0 --niter_decay_iter 200000 --input_nc 15 --display_freq 300 --fineSize 256 --dataroot './datasets/FaceForensics3/' --model 'c_pix2pixHD' --no_instance --gpu_ids=1 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1;
3 |
--------------------------------------------------------------------------------
/new_scripts/train_pose.sh:
--------------------------------------------------------------------------------
1 | ### Train the pose model
2 | python train_pose.py --name label2city_256p_pose_1 --use_self_loss --self_vgg_mul 20 --no_D_label --do_pose_dist_map --label_nc 0 --style_stage_mul "0:0,600000:0.1,650000:0.5,700000:1" --random_drop_prob 0 --no_ganFeat_loss --no_canny_edge --niter_iter 1100000 --use_style_iter -1 --niter_decay_iter 200000 --input_nc 23 --display_freq 300 --fineSize 256 --dataroot './datasets/YouTubePose2/' --model 'c_pix2pixHD' --no_instance --gpu_ids=2 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1;
3 |
--------------------------------------------------------------------------------
/new_scripts/train_scene.sh:
--------------------------------------------------------------------------------
1 | ### Train the scene model
2 | python train_seg.py --name label2city_256p_scene_1 --label_nc 42 --no_ganFeat_loss --use_new_label --no_D_label --no_canny_edge --use_self_loss --self_vgg_mul 20 --niter_iter 200000 --use_style_iter -1 --niter_decay_iter 200000 --display_freq 300 --fineSize 256 --dataroot './datasets/bdd2/' --model 'c_pix2pixHD' --no_instance --gpu_ids=3 --batchSize=1 --use_iter_decay --loadSize 256 --num_D 1 --label_indexs '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,255';
3 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import argparse
4 | import os
5 | from util import util
6 | import torch
7 |
8 | class BaseOptions():
9 | def __init__(self):
10 | self.parser = argparse.ArgumentParser()
11 | self.initialized = False
12 |
13 | def initialize(self):
14 | # experiment specifics
15 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
16 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
17 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 | self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
19 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
20 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
21 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
22 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
23 |
24 | # input/output sizes
25 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
26 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
27 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
28 | self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
29 | self.parser.add_argument('--label_indexs', type=str, default='', help='label index ids, empty means 0..label_nc-1')
30 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
32 |
33 | # for setting inputs
34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
35 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
37 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
40 |
41 | # for displays
42 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
43 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
44 |
45 | # for generator
46 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
47 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
48 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
49 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
50 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
51 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
52 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which only the outermost local enhancer is trained')
53 |
54 | # for instance-wise features
55 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
56 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
57 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
58 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
59 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
60 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
61 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
62 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
63 |
64 | self.parser.add_argument('--no_canny_edge', action='store_true', help='no canny edge for background')
65 | self.parser.add_argument('--no_dist_map', action='store_true', help='no dist map for background')
66 | self.parser.add_argument('--do_pose_dist_map', action='store_true', help='do pose dist map for background')
67 |
68 | self.parser.add_argument('--remove_face_labels', action='store_true', help='remove face labels to better adapt to different face shapes')
69 | self.parser.add_argument('--random_drop_prob', type=float, default=0.2, help='the probability to randomly drop each pose segment during training')
70 |
71 | self.parser.add_argument('--densepose_only', action='store_true', help='use only densepose as input')
72 | self.parser.add_argument('--openpose_only', action='store_true', help='use only openpose as input')
73 | self.parser.add_argument('--add_mask', action='store_true', help='mask of background')
74 |
75 | self.parser.add_argument('--vgg_weights', type=str, default='1,1,1,1,1', help='per-layer VGG weights for the output & guidance losses')
76 | self.parser.add_argument('--gram_weights', type=str, default='1,1,1,1,1', help='per-layer Gram weights for the output & guidance losses')
77 | self.parser.add_argument('--guide_vgg_mul', type=float, default=0.0, help='weight multiplier for the guidance VGG loss')
78 | self.parser.add_argument('--guide_gram_mul', type=float, default=0.0, help='weight multiplier for the guidance Gram (style) loss')
79 | self.parser.add_argument('--use_self_loss', action='store_true', help='if specified, add the self (reconstruction) VGG/Gram loss')
80 | self.parser.add_argument('--self_vgg_weights', type=str, default='1,1,1,1,1', help='per-layer VGG weights for the self loss')
81 | self.parser.add_argument('--self_gram_weights', type=str, default='1,1,1,1,1', help='per-layer Gram weights for the self loss')
82 | self.parser.add_argument('--self_vgg_mul', type=float, default=0.0, help='weight multiplier for the self VGG loss')
83 | self.parser.add_argument('--self_gram_mul', type=float, default=0.0, help='weight multiplier for the self Gram (style) loss')
84 |
85 | self.parser.add_argument('--style_stage_mul', type=str, default='0:1', help='style loss multiplier schedule, as comma-separated "iteration:multiplier" pairs')
86 | self.parser.add_argument('--real_stage_mul', type=str, default='0:1', help='multiplier schedule for the real stage, as comma-separated "iteration:multiplier" pairs')
87 | self.parser.add_argument('--train_val_list', type=str, default='0,1,2,3,4,100000,200000,300000,400000,500000', help='')
88 |
89 | self.parser.add_argument('--no_D_label', action='store_true', help='remove label channel of input of D')
90 | self.parser.add_argument('--no_G_label', action='store_true', help='remove label channel of input of G')
91 | self.parser.add_argument('--use_new_label', action='store_true', help='use generated seg map')
92 | self.initialized = True
93 |
94 | def get_list_int(self, s):
95 | str_indexs = s.split(',')
96 | ans = []
97 | if (str_indexs[0] != ''):
98 | for str_index in str_indexs:
99 | ans.append(int(str_index))
100 | return ans
101 | def get_list_float(self, s):
102 | str_indexs = s.split(',')
103 | ans = []
104 | if (str_indexs[0] != ''):
105 | for str_index in str_indexs:
106 | ans.append(float(str_index))
107 | return ans
108 | def get_list_dict(self, s):
109 | str_indexs = s.split(',')
110 | ans = []
111 | if (str_indexs[0] != ''):
112 | for str_index in str_indexs:
113 | temp = str_index.split(':')
114 | ans.append([int(temp[0]), float(temp[1])])
115 | return ans
116 |
117 | def parse(self, save=True):
118 | if not self.initialized:
119 | self.initialize()
120 | self.opt = self.parser.parse_args()
121 | self.opt.isTrain = self.isTrain # train or test
122 |
123 | self.opt.label_indexs = self.get_list_int(self.opt.label_indexs)
124 | self.opt.gram_weights = self.get_list_float(self.opt.gram_weights)
125 | self.opt.vgg_weights = self.get_list_float(self.opt.vgg_weights)
126 | self.opt.self_gram_weights = self.get_list_float(self.opt.self_gram_weights)
127 | self.opt.self_vgg_weights = self.get_list_float(self.opt.self_vgg_weights)
128 | self.opt.style_stage_mul = self.get_list_dict(self.opt.style_stage_mul)
129 | self.opt.real_stage_mul = self.get_list_dict(self.opt.real_stage_mul)
130 | self.opt.train_val_list = self.get_list_int(self.opt.train_val_list)
131 |
132 | str_ids = self.opt.gpu_ids.split(',')
133 | self.opt.gpu_ids = []
134 | for str_id in str_ids:
135 | id = int(str_id)
136 | if id >= 0:
137 | self.opt.gpu_ids.append(id)
138 |
139 | # set gpu ids
140 | if len(self.opt.gpu_ids) > 0:
141 | torch.cuda.set_device(self.opt.gpu_ids[0])
142 |
143 | args = vars(self.opt)
144 |
145 | print('------------ Options -------------')
146 | for k, v in sorted(args.items()):
147 | print('%s: %s' % (str(k), str(v)))
148 | print('-------------- End ----------------')
149 |
150 | # save to the disk
151 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
152 | util.mkdirs(expr_dir)
153 | if save and not self.opt.continue_train:
154 | file_name = os.path.join(expr_dir, 'opt.txt')
155 | with open(file_name, 'wt') as opt_file:
156 | opt_file.write('------------ Options -------------\n')
157 | for k, v in sorted(args.items()):
158 | opt_file.write('%s: %s\n' % (str(k), str(v)))
159 | opt_file.write('-------------- End ----------------\n')
160 | return self.opt
161 |
--------------------------------------------------------------------------------
/options/generate_data_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from .base_options import BaseOptions
4 |
5 | class GenerateDataOptions(BaseOptions):
6 | def initialize(self):
7 | BaseOptions.initialize(self)
8 | self.parser.add_argument('--neighbor_size', type=int, default=int(30), help='maximum distance (in frames) between two frames treated as having the same style.')
9 | self.parser.add_argument('--same_style_rate', type=float, default=float(0.5), help='probability that A and B are sampled as neighbors')
10 | self.parser.add_argument('--copy_data', action='store_true', default=False, help='crop and copy image data to the target path')
11 | self.parser.add_argument('--target_path', type=str, default='./datasets/apollo/', help='target path')
12 | self.parser.add_argument('--source_path', type=str, default='../datas/road02/', help='source path')
13 | self.parser.add_argument('--val_ratio', type=float, default=float(0.2), help='val data ratio')
14 | self.parser.add_argument('--test_ratio', type=float, default=float(0.2), help='test data ratio')
15 | self.parser.add_argument('--A_repeat_num', type=int, default=int(10), help='# of same A repeats in dataset.')
16 | self.isTrain = False
17 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from .base_options import BaseOptions
4 |
5 | class TestOptions(BaseOptions):
6 | def initialize(self):
7 | BaseOptions.initialize(self)
8 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
11 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
12 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
13 | self.parser.add_argument('--how_many', type=int, default=100, help='how many test images to run')
14 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
15 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
16 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
17 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
18 | self.parser.add_argument("--conditioned", action='store_true', help="use conditioned")
19 |         self.parser.add_argument("--test_delta_path", type=str, default="test_delta.txt", help="path to the test delta file")
20 | self.isTrain = False
21 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from .base_options import BaseOptions
4 |
5 | class TrainOptions(BaseOptions):
6 | def initialize(self):
7 | BaseOptions.initialize(self)
8 | # for displays
9 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
12 |         self.parser.add_argument('--save_iter_freq', type=int, default=150000, help='frequency (in iterations) of saving checkpoints')
13 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
14 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
15 |         self.parser.add_argument('--debug', action='store_true', help='only run one epoch and display results at every iteration')
16 |         self.parser.add_argument('--val_freq', type=int, default=500, help='frequency of running validation')
17 |
18 | # for training
19 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
20 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
21 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
22 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
23 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
24 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
25 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
26 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
27 |
28 |         self.parser.add_argument('--use_iter_decay', action='store_true', help='use iteration-based learning rate decay')
29 | self.parser.add_argument('--niter_iter', type=int, default=200000, help='# of iter at starting learning rate')
30 | self.parser.add_argument('--niter_decay_iter', type=int, default=200000, help='# of iter to linearly decay learning rate to zero')
31 |         self.parser.add_argument('--use_style_iter', type=int, default=-1, help='iteration after which the style discriminator is used')
32 |
33 |
34 | # for discriminators
35 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
36 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
37 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
38 |         self.parser.add_argument('--nsdf', type=int, default=128, help='# of style discriminator filters in first conv layer')
39 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
40 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
41 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
42 |         self.parser.add_argument('--no_lsgan', action='store_true', help='if specified, do *not* use least squares GAN (use vanilla GAN instead)')
43 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
44 |         self.parser.add_argument('--SD_mul', type=float, default=10, help='style discriminator loss multiplier')
45 |         self.parser.add_argument('--GAN_Feat_mul', type=float, default=1, help='GAN feature matching loss multiplier')
46 |         self.parser.add_argument('--G_confidence_mul', type=float, default=1, help='generator confidence loss multiplier')
47 |         self.parser.add_argument('--FG_GAN_mul', type=float, default=10, help='FG GAN loss multiplier')
48 |         self.parser.add_argument('--val_n_everytime', type=int, default=10, help='number of validation samples per validation run')
49 |
50 | self.parser.add_argument('--no_SD_false_pair', action='store_true', help='')
51 |
52 | self.parser.add_argument('--use_stage_lr', action='store_true', help='')
53 | self.parser.add_argument('--stage_lr_decay_iter', type=int, default=50000, help='')
54 | self.parser.add_argument('--stage_lr_decay_rate', type=float, default=0.3, help='')
55 | self.isTrain = True
56 |
--------------------------------------------------------------------------------
/precompute_feature_maps.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from options.train_options import TrainOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | import os
7 | import util.util as util
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 |
11 | opt = TrainOptions().parse()
12 | opt.nThreads = 1
13 | opt.batchSize = 1
14 | opt.serial_batches = True
15 | opt.no_flip = True
16 | opt.instance_feat = True
17 |
18 | name = 'features'
19 | save_path = os.path.join(opt.checkpoints_dir, opt.name)
20 |
21 | ############ Initialize #########
22 | data_loader = CreateDataLoader(opt)
23 | dataset = data_loader.load_data()
24 | dataset_size = len(data_loader)
25 | model = create_model(opt)
26 | util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))
27 |
28 | ######## Save precomputed feature maps for 1024p training #######
29 | for i, data in enumerate(dataset):
30 | print('%d / %d images' % (i+1, dataset_size))
31 | feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
32 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
33 | image_numpy = util.tensor2im(feat_map.data[0])
34 | save_path = data['path'][0].replace('/train_label/', '/train_feat/')
35 | util.save_image(image_numpy, save_path)
--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
1 | import cv2 # pip install opencv-python
2 | import os
3 | from os.path import join
4 | import argparse
5 | import random
6 | import numpy as np
7 |
8 | in_path = "../FaceForensics/datas/"
9 | out_path = "../FaceForensics/out_data/"
10 |
11 | def process(vid_path, out_path):
12 | print(vid_path)
13 | print(out_path)
14 |     os.makedirs(out_path, exist_ok=True)
15 | video = cv2.VideoCapture(vid_path)
16 | n = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
17 | for i in range(n):
18 |         ret, image = video.read()
19 |         path = os.path.join(out_path, 'a.' + str(i) + '.jpg')
20 |         if ret: cv2.imwrite(path, image)  # skip frames that fail to decode
21 |
22 | tot = 0
23 | for root, _, fnames in os.walk(in_path):
24 | for fname in fnames:
25 | if not fname.endswith('.avi'):
26 | continue
27 | process(os.path.join(root, fname), os.path.join(out_path, "Mr_"+str(tot), "1"))
28 | tot += 1
29 |
--------------------------------------------------------------------------------
/run_engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from random import randint
4 | import numpy as np
5 | import tensorrt
6 |
7 | try:
8 | from PIL import Image
9 | import pycuda.driver as cuda
10 | import pycuda.gpuarray as gpuarray
11 | import pycuda.autoinit
12 | import argparse
13 | except ImportError as err:
14 | sys.stderr.write("""ERROR: failed to import module ({})
15 | Please make sure you have pycuda and the example dependencies installed.
16 | https://wiki.tiker.net/PyCuda/Installation/Linux
17 | pip(3) install tensorrt[examples]
18 | """.format(err))
19 | exit(1)
20 |
21 | try:
22 | import tensorrt as trt
23 | from tensorrt.parsers import caffeparser
24 | from tensorrt.parsers import onnxparser
25 | except ImportError as err:
26 | sys.stderr.write("""ERROR: failed to import module ({})
27 | Please make sure you have the TensorRT Library installed
28 | and accessible in your LD_LIBRARY_PATH
29 | """.format(err))
30 | exit(1)
31 |
32 |
33 | G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
34 |
35 | class Profiler(trt.infer.Profiler):
36 | """
37 |     Example Implementation of a Profiler
38 | Is identical to the Profiler class in trt.infer so it is possible
39 | to just use that instead of implementing this if further
40 | functionality is not needed
41 | """
42 | def __init__(self, timing_iter):
43 | trt.infer.Profiler.__init__(self)
44 | self.timing_iterations = timing_iter
45 | self.profile = []
46 |
47 | def report_layer_time(self, layerName, ms):
48 | record = next((r for r in self.profile if r[0] == layerName), (None, None))
49 | if record == (None, None):
50 | self.profile.append((layerName, ms))
51 | else:
52 | self.profile[self.profile.index(record)] = (record[0], record[1] + ms)
53 |
54 | def print_layer_times(self):
55 | totalTime = 0
56 | for i in range(len(self.profile)):
57 | print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations))
58 | totalTime += self.profile[i][1]
59 | print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations))
60 |
61 |
62 | def get_input_output_names(trt_engine):
63 |     nbindings = trt_engine.get_nb_bindings()
64 | maps = []
65 |
66 | for b in range(0, nbindings):
67 | dims = trt_engine.get_binding_dimensions(b).to_DimsCHW()
68 | name = trt_engine.get_binding_name(b)
69 | type = trt_engine.get_binding_data_type(b)
70 |
71 | if (trt_engine.binding_is_input(b)):
72 | maps.append(name)
73 | print("Found input: ", name)
74 | else:
75 | maps.append(name)
76 | print("Found output: ", name)
77 |
78 | print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
79 | print("dtype=" + str(type))
80 | return maps
81 |
82 | def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx):
83 | binding_idx = engine.get_binding_index(name)
84 | if binding_idx == -1:
85 | raise AttributeError("Not a valid binding")
86 | print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx)))
87 | dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW()
88 | eltCount = dims.C() * dims.H() * dims.W() * batchsize
89 |
90 | if engine.binding_is_input(binding_idx):
91 | h_mem = inp[inp_idx]
92 | inp_idx = inp_idx + 1
93 | else:
94 | h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4'))
95 |
96 | d_mem = cuda.mem_alloc(eltCount * 4)
97 | cuda.memcpy_htod(d_mem, h_mem)
98 | buf.insert(binding_idx, int(d_mem))
99 | mem.append(d_mem)
100 | return inp_idx
101 |
102 |
103 | #Run inference on device
104 | def time_inference(engine, batch_size, inp):
105 | bindings = []
106 | mem = []
107 | inp_idx = 0
108 | for io in get_input_output_names(engine):
109 | inp_idx = create_memory(engine, io, bindings, mem,
110 | batch_size, inp, inp_idx)
111 |
112 | context = engine.create_execution_context()
113 | g_prof = Profiler(500)
114 | context.set_profiler(g_prof)
115 |     for i in range(g_prof.timing_iterations):  # run the timed inference passes
116 | context.execute(batch_size, bindings)
117 | g_prof.print_layer_times()
118 |
119 | context.destroy()
120 | return
121 |
122 |
123 | def convert_to_datatype(v):
124 | if v==8:
125 | return trt.infer.DataType.INT8
126 | elif v==16:
127 | return trt.infer.DataType.HALF
128 | elif v==32:
129 | return trt.infer.DataType.FLOAT
130 | else:
131 | print("ERROR: Invalid model data type bit depth: " + str(v))
132 | return trt.infer.DataType.INT8
133 |
134 | def run_trt_engine(engine_file, bs, it):
135 | engine = trt.utils.load_engine(G_LOGGER, engine_file)
136 | time_inference(engine, bs, it)
137 |
138 | def run_onnx(onnx_file, data_type, bs, inp):
139 | # Create onnx_config
140 | apex = onnxparser.create_onnxconfig()
141 | apex.set_model_file_name(onnx_file)
142 | apex.set_model_dtype(convert_to_datatype(data_type))
143 |
144 | # create parser
145 | trt_parser = onnxparser.create_onnxparser(apex)
146 | assert(trt_parser)
147 | data_type = apex.get_model_dtype()
148 | onnx_filename = apex.get_model_file_name()
149 | trt_parser.parse(onnx_filename, data_type)
150 | trt_parser.report_parsing_info()
151 | trt_parser.convert_to_trtnetwork()
152 | trt_network = trt_parser.get_trtnetwork()
153 | assert(trt_network)
154 |
155 | # create infer builder
156 | trt_builder = trt.infer.create_infer_builder(G_LOGGER)
157 |     trt_builder.set_max_batch_size(bs)
158 |     trt_builder.set_max_workspace_size(1 << 30)  # 1 GiB workspace (assumed default)
159 |
160 | if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
161 | print("------------------- Running FP16 -----------------------------")
162 | trt_builder.set_half2_mode(True)
163 | elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
164 | print("------------------- Running INT8 -----------------------------")
165 | trt_builder.set_int8_mode(True)
166 | else:
167 | print("------------------- Running FP32 -----------------------------")
168 |
169 | print("----- Builder is Done -----")
170 | print("----- Creating Engine -----")
171 | trt_engine = trt_builder.build_cuda_engine(trt_network)
172 | print("----- Engine is built -----")
173 |     time_inference(trt_engine, bs, inp)
174 |
--------------------------------------------------------------------------------
/temp_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.autograd import Variable
4 | w1 = Variable(torch.Tensor([1.0,2.0,3.0]),requires_grad=True)
5 |
6 |
7 | optimizer = optim.SGD([w1], lr=0.01)  # w1 is a leaf tensor, not a module, so pass it as a list of parameters
8 | d = torch.mean(w1)
9 | d.backward()
10 | print(w1.grad)
11 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import torch
13 |
14 | opt = TestOptions().parse(save=False)
15 | opt.nThreads = 1 # test code only supports nThreads = 1
16 | opt.batchSize = 1 # test code only supports batchSize = 1
17 | opt.serial_batches = True # no shuffle
18 | opt.no_flip = True # no flip
19 |
20 | data_loader = CreateDataLoader(opt)
21 | dataset = data_loader.load_data()
22 | visualizer = Visualizer(opt)
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
26 |
27 | # test
28 | if not opt.engine and not opt.onnx:
29 | model = create_model(opt)
30 | if opt.data_type == 16:
31 | model.half()
32 | elif opt.data_type == 8:
33 | model.type(torch.uint8)
34 |
35 | if opt.verbose:
36 | print(model)
37 | else:
38 | from run_engine import run_trt_engine, run_onnx
39 |
40 | for i, data in enumerate(dataset):
41 | if i >= opt.how_many:
42 | break
43 | if opt.data_type == 16:
44 | data['label'] = data['label'].half()
45 | data['inst'] = data['inst'].half()
46 | elif opt.data_type == 8:
47 | data['label'] = data['label'].uint8()
48 | data['inst'] = data['inst'].uint8()
49 | if opt.export_onnx:
50 | print ("Exporting to ONNX: ", opt.export_onnx)
51 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
52 | torch.onnx.export(model, [data['label'], data['inst']],
53 | opt.export_onnx, verbose=True)
54 | exit(0)
55 | minibatch = 1
56 | if opt.engine:
57 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
58 | elif opt.onnx:
59 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
60 | else:
61 | generated = model.inference(data['label'], data['inst'])
62 |
63 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
64 | ('synthesized_image', util.tensor2im(generated.data[0]))])
65 | img_path = data['path']
66 | print('process image... %s' % img_path)
67 | visualizer.save_images(webpage, visuals, img_path)
68 |
69 | webpage.save()
70 |
--------------------------------------------------------------------------------
/test_all.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | from PIL import Image
13 | import torch
14 | from data.base_dataset import get_params, get_transform, normalize
15 | import copy
16 | from models import networks
17 | import numpy as np
18 | import torch.nn as nn
19 |
20 | opt = TestOptions().parse(save=False)
21 | opt.nThreads = 1 # test code only supports nThreads = 1
22 | opt.batchSize = 1 # test code only supports batchSize = 1
23 | opt.serial_batches = True # no shuffle
24 | opt.no_flip = True # no flip
25 |
26 | data_loader = CreateDataLoader(opt)
27 | dataset = data_loader.load_data()
28 | visualizer = Visualizer(opt)
29 | def get_features(inst, feat):
30 | feat_num = opt.feat_num
31 | h, w = inst.size()[1], inst.size()[2]
32 | block_num = 32
33 | feature = {}
34 | max_v = {}
35 | for i in range(opt.label_nc):
36 | feature[i] = np.zeros((0, feat_num+1))
37 | max_v[i] = 0
38 | for i in np.unique(inst):
39 | label = i if i < 1000 else i//1000
40 | idx = (inst == int(i)).nonzero()
41 | num = idx.size()[0]
42 | idx = idx[num//2,:]
43 | val = np.zeros((feat_num))
44 | for k in range(feat_num):
45 | val[k] = feat[idx[0] + k, idx[1], idx[2]].data[0]
46 | temp = float(num) / (h * w // block_num)
47 | if (temp > max_v[label]):
48 | max_v[label] = temp
49 | feature[label] = val
50 | return feature
51 |
52 | def getitem(A_path, B_path, inst_path, feat_path):
53 | ### input A (label maps)
54 | A = Image.open(A_path)
55 | params = get_params(opt, A.size)
56 | if opt.label_nc == 0:
57 | transform_A = get_transform(opt, params)
58 | A_tensor = transform_A(A.convert('RGB'))
59 | else:
60 | transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False)
61 | A_tensor = transform_A(A) * 255.0
62 |
63 | B_tensor = inst_tensor = feat_tensor = 0
64 | ### input B (real images)
65 | B = Image.open(B_path).convert('RGB')
66 | transform_B = get_transform(opt, params)
67 | B_tensor = transform_B(B)
68 |
69 | ### if using instance maps
70 | inst = Image.open(inst_path)
71 | inst_tensor = transform_A(inst)
72 |
73 | #get feat
74 | netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
75 | opt.n_downsample_E, norm=opt.norm, gpu_ids=opt.gpu_ids)
76 | feat_map = netE.forward(Variable(B_tensor[np.newaxis, :].cuda(), volatile=True), inst_tensor[np.newaxis, :].cuda())
77 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
78 | image_numpy = util.tensor2im(feat_map.data[0])
79 | util.save_image(image_numpy, feat_path)
80 |
81 | feat = Image.open(feat_path).convert('RGB')
82 | norm = normalize()
83 | feat_tensor = norm(transform_A(feat))
84 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
85 | 'feat': feat_tensor, 'path': A_path}
86 |
87 | return get_features(input_dict['inst'], input_dict['feat'])
88 |
89 | # test
90 | if not opt.engine and not opt.onnx:
91 | model = create_model(opt)
92 | if opt.data_type == 16:
93 | model.half()
94 | elif opt.data_type == 8:
95 | model.type(torch.uint8)
96 |
97 | if opt.verbose:
98 | print(model)
99 | else:
100 | from run_engine import run_trt_engine, run_onnx
101 |
102 | label_path = 'datasets/test/label.png'
103 | img_path = 'datasets/test/img.png'
104 | inst_path = 'datasets/test/inst.png'
105 | feat_path = 'datasets/test/feat.png'
106 | con = getitem(label_path, img_path, inst_path, feat_path)
107 | for k in range(opt.n_clusters):
108 | # create website
109 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%d' % (opt.phase, opt.which_epoch, k))
110 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
111 |
112 |
113 | for i, data in enumerate(dataset):
114 | if i >= opt.how_many:
115 | break
116 | if opt.data_type == 16:
117 | data['label'] = data['label'].half()
118 | data['inst'] = data['inst'].half()
119 | elif opt.data_type == 8:
120 | data['label'] = data['label'].uint8()
121 | data['inst'] = data['inst'].uint8()
122 | if opt.export_onnx:
123 | print ("Exporting to ONNX: ", opt.export_onnx)
124 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
125 | torch.onnx.export(model, [data['label'], data['inst']],
126 | opt.export_onnx, verbose=True)
127 | exit(0)
128 | minibatch = 1
129 | if opt.engine:
130 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
131 | elif opt.onnx:
132 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
133 | elif opt.conditioned:
134 | generated = model.inference_conditioned(data['label'], data['inst'], con, k)
135 | else:
136 | generated = model.inference(data['label'], data['inst'])
137 |
138 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
139 | ('synthesized_image', util.tensor2im(generated.data[0]))])
140 | img_path = data['path']
141 | print('process image... %s' % img_path)
142 | visualizer.save_images(webpage, visuals, img_path)
143 |
144 | webpage.save()
145 |
--------------------------------------------------------------------------------
/test_con.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import torch
13 |
14 | opt = TestOptions().parse(save=False)
15 | opt.nThreads = 1 # test code only supports nThreads = 1
16 | opt.batchSize = 1 # test code only supports batchSize = 1
17 | opt.serial_batches = True # no shuffle
18 | opt.no_flip = True # no flip
19 |
20 | data_loader = CreateConDataLoader(opt)
21 | dataset = data_loader.load_data()
22 | visualizer = Visualizer(opt)
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
26 |
27 | # test
28 | if not opt.engine and not opt.onnx:
29 | model = create_model(opt)
30 | if opt.data_type == 16:
31 | model.half()
32 | elif opt.data_type == 8:
33 | model.type(torch.uint8)
34 |
35 | if opt.verbose:
36 | print(model)
37 | else:
38 | from run_engine import run_trt_engine, run_onnx
39 |
40 | for i, data in enumerate(dataset):
41 | if i >= opt.how_many:
42 | break
43 | if opt.data_type == 16:
44 | data['A'] = data['A'].half()
45 | data['A2'] = data['A2'].half()
46 | data['B'] = data['B'].half()
47 | data['B2'] = data['B2'].half()
48 | elif opt.data_type == 8:
49 | data['A'] = data['A'].uint8()
50 | data['A2'] = data['A2'].uint8()
51 | data['B'] = data['B'].uint8()
52 | data['B2'] = data['B2'].uint8()
53 | if opt.export_onnx:
54 | print ("Exporting to ONNX: ", opt.export_onnx)
55 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
56 | torch.onnx.export(model, [data['label'], data['inst']],
57 | opt.export_onnx, verbose=True)
58 | exit(0)
59 | minibatch = 1
60 | if opt.engine:
61 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
62 | elif opt.onnx:
63 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
64 | else:
65 | generated = model.inference(data['A'], data['B'], data['B2'])
66 |
67 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)),
68 | ('real_image', util.tensor2im(data['A2'][0])),
69 | ('synthesized_image', util.tensor2im(generated.data[0])),
70 | ('B', util.tensor2label(data['B'][0], 256)),
71 | ('B2', util.tensor2im(data['B2'][0]))])
72 | img_path = data['path']
73 | print('process image... %s' % img_path)
74 | visualizer.save_images(webpage, visuals, img_path)
75 |
76 | webpage.save()
77 |
--------------------------------------------------------------------------------
/test_con_bak.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | from PIL import Image
13 | import torch
14 | from data.base_dataset import get_params, get_transform, normalize
15 | import copy
16 | from models import networks
17 | import numpy as np
18 | import torch.nn as nn
19 |
20 | opt = TestOptions().parse(save=False)
21 | opt.nThreads = 1 # test code only supports nThreads = 1
22 | opt.batchSize = 1 # test code only supports batchSize = 1
23 | opt.serial_batches = True # no shuffle
24 | opt.no_flip = True # no flip
25 |
26 | data_loader = CreateDataLoader(opt)
27 | dataset = data_loader.load_data()
28 | visualizer = Visualizer(opt)
29 | # create website
30 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
31 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
32 |
33 | def get_features(inst, feat):
34 | feat_num = opt.feat_num
35 | h, w = inst.size()[1], inst.size()[2]
36 | block_num = 32
37 | feature = {}
38 | max_v = {}
39 | for i in range(opt.label_nc):
40 | feature[i] = np.zeros((0, feat_num+1))
41 | max_v[i] = 0
42 | for i in np.unique(inst):
43 | label = i if i < 1000 else i//1000
44 | idx = (inst == int(i)).nonzero()
45 | num = idx.size()[0]
46 | idx = idx[num//2,:]
47 | val = np.zeros((feat_num))
48 | for k in range(feat_num):
49 | val[k] = feat[0, idx[0] + k, idx[1], idx[2]].data[0]
50 | temp = float(num) / (h * w // block_num)
51 | if (temp > max_v[label]):
52 | max_v[label] = temp
53 | feature[label] = val
54 | return feature
55 |
56 | def getitem(A_path, B_path, inst_path, feat_path):
57 | ### input A (label maps)
58 | A = Image.open(A_path)
59 | params = get_params(opt, A.size)
60 | if opt.label_nc == 0:
61 | transform_A = get_transform(opt, params)
62 | A_tensor = transform_A(A.convert('RGB'))
63 | else:
64 | transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False)
65 | A_tensor = transform_A(A) * 255.0
66 |
67 | B_tensor = inst_tensor = feat_tensor = 0
68 | ### input B (real images)
69 | B = Image.open(B_path).convert('RGB')
70 | transform_B = get_transform(opt, params)
71 | B_tensor = transform_B(B)
72 |
73 | ### if using instance maps
74 | inst = Image.open(inst_path)
75 | inst_tensor = transform_A(inst)
76 |
77 | #get feat
78 | netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
79 | opt.n_downsample_E, norm=opt.norm, gpu_ids=opt.gpu_ids)
80 | feat_map = netE.forward(Variable(B_tensor[np.newaxis, :].cuda(), volatile=True), inst_tensor[np.newaxis, :].cuda())
81 | '''
82 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
83 | image_numpy = util.tensor2im(feat_map.data[0])
84 | util.save_image(image_numpy, feat_path)
85 |
86 | feat = Image.open(feat_path).convert('RGB')
87 | norm = normalize()
88 | feat_tensor = norm(transform_A(feat))
89 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
90 | 'feat': feat_tensor, 'path': A_path}
91 | '''
92 |
93 | return get_features(inst_tensor, feat_map)
94 |
95 | # test
96 | if not opt.engine and not opt.onnx:
97 | model = create_model(opt)
98 | if opt.data_type == 16:
99 | model.half()
100 | elif opt.data_type == 8:
101 | model.type(torch.uint8)
102 |
103 | if opt.verbose:
104 | print(model)
105 | else:
106 | from run_engine import run_trt_engine, run_onnx
107 |
108 | label_path = 'datasets/test/label.png'
109 | img_path = 'datasets/test/img.png'
110 | inst_path = 'datasets/test/inst.png'
111 | feat_path = 'datasets/test/feat.png'
112 | con = getitem(label_path, img_path, inst_path, feat_path)
113 |
114 | for i, data in enumerate(dataset):
115 | if i >= opt.how_many:
116 | break
117 | if opt.data_type == 16:
118 | data['label'] = data['label'].half()
119 | data['inst'] = data['inst'].half()
120 | elif opt.data_type == 8:
121 | data['label'] = data['label'].uint8()
122 | data['inst'] = data['inst'].uint8()
123 | if opt.export_onnx:
124 | print ("Exporting to ONNX: ", opt.export_onnx)
125 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
126 | torch.onnx.export(model, [data['label'], data['inst']],
127 | opt.export_onnx, verbose=True)
128 | exit(0)
129 | minibatch = 1
130 | if opt.engine:
131 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
132 | elif opt.onnx:
133 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
134 | elif opt.conditioned:
135 | generated = model.inference_conditioned(data['label'], data['inst'], con)
136 | else:
137 | generated = model.inference(data['label'], data['inst'])
138 |
139 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
140 | ('synthesized_image', util.tensor2im(generated.data[0]))])
141 | img_path = data['path']
142 | print('process image... %s' % img_path)
143 | visualizer.save_images(webpage, visuals, img_path)
144 |
145 | webpage.save()
146 |
--------------------------------------------------------------------------------
/test_delta.txt:
--------------------------------------------------------------------------------
1 | 0
2 |
--------------------------------------------------------------------------------
/test_face.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateFaceConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import random
13 | import torch
14 |
15 | opt = TestOptions().parse(save=False)
16 | opt.nThreads = 1 # test code only supports nThreads = 1
17 | opt.batchSize = 1 # test code only supports batchSize = 1
18 | opt.serial_batches = True # no shuffle
19 | opt.no_flip = True # no flip
20 |
21 | data_loader = CreateFaceConDataLoader(opt)
22 | dataset = data_loader.load_data()
23 | visualizer = Visualizer(opt)
24 | # create website
25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches)))
26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
27 |
28 | # test
29 | if not opt.engine and not opt.onnx:
30 | model = create_model(opt)
31 | if opt.data_type == 16:
32 | model.half()
33 | elif opt.data_type == 8:
34 | model.type(torch.uint8)
35 |
36 | if opt.verbose:
37 | print(model)
38 | else:
39 | from run_engine import run_trt_engine, run_onnx
40 |
41 | tot = 0
42 | for i, data in enumerate(dataset, start=random.randint(0, len(dataset) - opt.how_many)):
43 | tot += 1
44 | if tot > opt.how_many:
45 | break
46 | if opt.data_type == 16:
47 | data['A'] = data['A'].half()
48 | data['A2'] = data['A2'].half()
49 | data['B'] = data['B'].half()
50 | data['B2'] = data['B2'].half()
51 | elif opt.data_type == 8:
52 | data['A'] = data['A'].uint8()
53 | data['A2'] = data['A2'].uint8()
54 | data['B'] = data['B'].uint8()
55 | data['B2'] = data['B2'].uint8()
56 | if opt.export_onnx:
57 | print ("Exporting to ONNX: ", opt.export_onnx)
58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
59 | torch.onnx.export(model, [data['label'], data['inst']],
60 | opt.export_onnx, verbose=True)
61 | exit(0)
62 | minibatch = 1
63 | if opt.engine:
64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
65 | elif opt.onnx:
66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
67 | else:
68 | generated = model.inference(data['A'], data['B'], data['B2'])
69 |
70 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
71 | ('real_image', util.tensor2im(data['A2'][0])),
72 | ('synthesized_image', util.tensor2im(generated.data[0])),
73 | ('B', util.tensor2label(data['B'][0], 0)),
74 | ('B2', util.tensor2im(data['B2'][0]))])
75 | img_path = data['path']
76 | img_path[0] = str(i)
77 | print('process image... %s' % img_path)
78 | visualizer.save_images(webpage, visuals, img_path)
79 |
80 | webpage.save()
81 |
--------------------------------------------------------------------------------
/test_mface.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateFaceConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import torch
13 |
14 | opt = TestOptions().parse(save=False)
15 | opt.nThreads = 1 # test code only supports nThreads = 1
16 | opt.batchSize = 1 # test code only supports batchSize = 1
17 | opt.serial_batches = False  # shuffle the test data
18 | opt.no_flip = True # no flip
19 |
20 | data_loader = CreateFaceConDataLoader(opt)
21 | dataset = data_loader.load_data()
22 | visualizer = Visualizer(opt)
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
25 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
26 |
27 | # test
28 | model = create_model(opt)
29 | '''
30 | if not opt.engine and not opt.onnx:
31 | model = create_model(opt)
32 | if opt.data_type == 16:
33 | model.half()
34 | elif opt.data_type == 8:
35 | model.type(torch.uint8)
36 |
37 | if opt.verbose:
38 | print(model)
39 | else:
40 | from run_engine import run_trt_engine, run_onnx
41 | '''
42 | for i, data in enumerate(dataset):
43 | if i >= opt.how_many:
44 | break
45 | '''
46 | if opt.data_type == 16:
47 | data['A'] = data['A'].half()
48 | data['A2'] = data['A2'].half()
49 | data['B'] = data['B'].half()
50 | data['B2'] = data['B2'].half()
51 | elif opt.data_type == 8:
52 | data['A'] = data['A'].uint8()
53 | data['A2'] = data['A2'].uint8()
54 | data['B'] = data['B'].uint8()
55 | data['B2'] = data['B2'].uint8()
56 | '''
57 | if opt.export_onnx:
58 | print ("Exporting to ONNX: ", opt.export_onnx)
59 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
60 | torch.onnx.export(model, [data['label'], data['inst']],
61 | opt.export_onnx, verbose=True)
62 | exit(0)
63 | minibatch = 1
64 | if opt.engine:
65 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
66 | elif opt.onnx:
67 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
68 | else:
69 | generated1, mask, generated = model.inference(Variable(data['A']), Variable(data['B']), Variable(data['B2']))
70 |
71 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
72 | ('real_image', util.tensor2im(data['A2'][0])),
73 | ('synthesized_image_1', util.tensor2im(generated1.data[0])),
74 | ('mask', util.tensor2im(mask.data[0])),
75 | ('synthesized_image', util.tensor2im(generated.data[0])),
76 | ('B', util.tensor2label(data['B'][0], 0)),
77 | ('B2', util.tensor2im(data['B2'][0]))])
78 | img_path = data['path']
79 | print('process image... %s' % img_path)
80 | visualizer.save_images(webpage, visuals, img_path)
81 |
82 | webpage.save()
83 |
--------------------------------------------------------------------------------
/test_pose.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreatePoseConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import random
13 | import torch
14 |
15 | opt = TestOptions().parse(save=False)
16 | opt.nThreads = 1 # test code only supports nThreads = 1
17 | opt.batchSize = 1 # test code only supports batchSize = 1
18 | opt.serial_batches = True # no shuffle
19 | opt.no_flip = True # no flip
20 |
21 | data_loader = CreatePoseConDataLoader(opt)
22 | dataset = data_loader.load_data()
23 | visualizer = Visualizer(opt)
24 | # create website
25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches)))
26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
27 |
28 | # test
29 | if not opt.engine and not opt.onnx:
30 | model = create_model(opt)
31 | if opt.data_type == 16:
32 | model.half()
33 | elif opt.data_type == 8:
34 | model.type(torch.uint8)
35 |
36 | if opt.verbose:
37 | print(model)
38 | else:
39 | from run_engine import run_trt_engine, run_onnx
40 |
41 | tot = 0
42 | for i, data in enumerate(dataset, start=0):
43 | tot += 1
44 | if tot > opt.how_many:
45 | break
46 | if opt.data_type == 16:
47 | data['A'] = data['A'].half()
48 | data['A2'] = data['A2'].half()
49 | data['B'] = data['B'].half()
50 | data['B2'] = data['B2'].half()
51 | elif opt.data_type == 8:
52 | data['A'] = data['A'].uint8()
53 | data['A2'] = data['A2'].uint8()
54 | data['B'] = data['B'].uint8()
55 | data['B2'] = data['B2'].uint8()
56 | if opt.export_onnx:
57 | print ("Exporting to ONNX: ", opt.export_onnx)
58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
59 | torch.onnx.export(model, [data['label'], data['inst']],
60 | opt.export_onnx, verbose=True)
61 | exit(0)
62 | minibatch = 1
63 | if opt.engine:
64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
65 | elif opt.onnx:
66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
67 | else:
68 | generated = model.inference(data['A'], data['B'], data['B2'])
69 |
70 | visuals = OrderedDict([('input_label', util.tensor2im(data['A'][0])),
71 | ('real_image', util.tensor2im(data['A2'][0])),
72 | ('synthesized_image', util.tensor2im(generated.data[0])),
73 | ('B', util.tensor2im(data['B'][0])),
74 | ('B2', util.tensor2im(data['B2'][0]))])
75 | img_path = data['path']
76 | img_path = [str(tot-1)]
77 | print('process image... %s' % img_path)
78 | visualizer.save_images(webpage, visuals, img_path)
79 |
80 | webpage.save()
81 |
--------------------------------------------------------------------------------
/test_seg.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | from options.test_options import TestOptions
7 | from data.data_loader import CreateConDataLoader
8 | from models.models import create_model
9 | import util.util as util
10 | from util.visualizer import Visualizer
11 | from util import html
12 | import random
13 | import torch
14 |
15 | opt = TestOptions().parse(save=False)
16 | opt.nThreads = 1 # test code only supports nThreads = 1
17 | opt.batchSize = 1 # test code only supports batchSize = 1
18 | opt.serial_batches = True # no shuffle
19 | opt.no_flip = True # no flip
20 |
21 | data_loader = CreateConDataLoader(opt)
22 | dataset = data_loader.load_data()
23 | visualizer = Visualizer(opt)
24 | # create website
25 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s_%s' % (opt.phase, opt.which_epoch, str(opt.serial_batches)))
26 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
27 |
28 | # test
29 | if not opt.engine and not opt.onnx:
30 | model = create_model(opt)
31 | if opt.data_type == 16:
32 | model.half()
33 | elif opt.data_type == 8:
34 | model.type(torch.uint8)
35 |
36 | if opt.verbose:
37 | print(model)
38 | else:
39 | from run_engine import run_trt_engine, run_onnx
40 |
41 | tot = 0
42 | for i, data in enumerate(dataset, start=random.randint(0, len(dataset) - opt.how_many)):
43 | tot += 1
44 | if tot > opt.how_many:
45 | break
46 | if opt.data_type == 16:
47 | data['A'] = data['A'].half()
48 | data['A2'] = data['A2'].half()
49 | data['B'] = data['B'].half()
50 | data['B2'] = data['B2'].half()
51 | elif opt.data_type == 8:
52 | data['A'] = data['A'].uint8()
53 | data['A2'] = data['A2'].uint8()
54 | data['B'] = data['B'].uint8()
55 | data['B2'] = data['B2'].uint8()
56 | if opt.export_onnx:
57 | print ("Exporting to ONNX: ", opt.export_onnx)
58 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
59 | torch.onnx.export(model, [data['label'], data['inst']],
60 | opt.export_onnx, verbose=True)
61 | exit(0)
62 | minibatch = 1
63 | if opt.engine:
64 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
65 | elif opt.onnx:
66 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
67 | else:
68 | generated = model.inference(data['A'], data['B'], data['B2'])
69 |
70 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
71 | ('real_image', util.tensor2im(data['A2'][0])),
72 | ('synthesized_image', util.tensor2im(generated.data[0])),
73 | ('B', util.tensor2label(data['B'][0], 0)),
74 | ('B2', util.tensor2im(data['B2'][0]))])
75 | img_path = data['path']
76 | img_path[0] = str(i)
77 | print('process image... %s' % img_path)
78 | visualizer.save_images(webpage, visuals, img_path)
79 |
80 | webpage.save()
81 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreateDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 |
15 | opt = TrainOptions().parse()
16 | print(opt.gpu_ids)
17 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
18 | if opt.continue_train:
19 | try:
20 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
21 | except:
22 | start_epoch, epoch_iter = 1, 0
23 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
24 | else:
25 | start_epoch, epoch_iter = 1, 0
26 |
27 | if opt.debug:
28 | opt.display_freq = 1
29 | opt.print_freq = 1
30 | opt.niter = 1
31 | opt.niter_decay = 0
32 | opt.max_dataset_size = 10
33 |
34 | data_loader = CreateDataLoader(opt)
35 | dataset = data_loader.load_data()
36 | dataset_size = len(data_loader)
37 | print('#training images = %d' % dataset_size)
38 |
39 | model = create_model(opt)
40 | visualizer = Visualizer(opt)
41 |
42 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
43 |
44 | display_delta = total_steps % opt.display_freq
45 | print_delta = total_steps % opt.print_freq
46 | save_delta = total_steps % opt.save_latest_freq
47 |
48 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
49 | epoch_start_time = time.time()
50 | if epoch != start_epoch:
51 | epoch_iter = epoch_iter % dataset_size
52 | for i, data in enumerate(dataset, start=epoch_iter):
53 | iter_start_time = time.time()
54 | total_steps += opt.batchSize
55 | epoch_iter += opt.batchSize
56 |
57 | # whether to collect output images
58 | save_fake = total_steps % opt.display_freq == display_delta
59 |
60 | ############## Forward Pass ######################
61 | losses, generated = model(Variable(data['label']), Variable(data['inst']),
62 | Variable(data['image']), Variable(data['feat']), infer=save_fake)
63 |
64 | # sum per device losses
65 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
66 | loss_dict = dict(zip(model.module.loss_names, losses))
67 |
68 | # calculate final loss scalar
69 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
70 | loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)
71 |
72 | ############### Backward Pass ####################
73 | # update generator weights
74 | model.module.optimizer_G.zero_grad()
75 | loss_G.backward()
76 | model.module.optimizer_G.step()
77 |
78 | # update discriminator weights
79 | model.module.optimizer_D.zero_grad()
80 | loss_D.backward()
81 | model.module.optimizer_D.step()
82 |
83 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
84 |
85 | ############## Display results and errors ##########
86 | ### print out errors
87 | if total_steps % opt.print_freq == print_delta:
88 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
89 | t = (time.time() - iter_start_time) / opt.batchSize
90 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
91 | visualizer.plot_current_errors(errors, total_steps)
92 |
93 | ### display output images
94 | if save_fake:
95 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
96 | ('synthesized_image', util.tensor2im(generated.data[0])),
97 | ('real_image', util.tensor2im(data['image'][0]))])
98 | visualizer.display_current_results(visuals, epoch, total_steps)
99 |
100 | ### save latest model
101 | if total_steps % opt.save_latest_freq == save_delta:
102 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
103 | model.module.save('latest')
104 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
105 |
106 | if epoch_iter >= dataset_size:
107 | break
108 |
109 | # end of epoch
110 | iter_end_time = time.time()
111 | print('End of epoch %d / %d \t Time Taken: %d sec' %
112 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
113 |
114 | ### save model for this epoch
115 | if epoch % opt.save_epoch_freq == 0:
116 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
117 | model.module.save('latest')
118 | model.module.save(epoch)
119 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
120 |
121 | ### instead of only training the local enhancer, train the entire network after certain iterations
122 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
123 | model.module.update_fixed_params()
124 |
125 | ### linearly decay learning rate after certain iterations
126 | if epoch > opt.niter:
127 | model.module.update_learning_rate()
128 |
--------------------------------------------------------------------------------
/train_con.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreateConDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 |
15 | opt = TrainOptions().parse()
16 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
17 | if opt.continue_train:
18 | try:
19 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
20 | except:
21 | start_epoch, epoch_iter = 1, 0
22 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
23 | else:
24 | start_epoch, epoch_iter = 1, 0
25 |
26 | if opt.debug:
27 | opt.display_freq = 1
28 | opt.print_freq = 1
29 | opt.niter = 1
30 | opt.niter_decay = 0
31 | opt.max_dataset_size = 10
32 |
33 | data_loader = CreateConDataLoader(opt)
34 | dataset = data_loader.load_data()
35 | dataset_size = len(data_loader)
36 | print('#training images = %d' % dataset_size)
37 |
38 | model = create_model(opt)
39 | visualizer = Visualizer(opt)
40 |
41 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
42 |
43 | display_delta = total_steps % opt.display_freq
44 | print_delta = total_steps % opt.print_freq
45 | save_delta = total_steps % opt.save_latest_freq
46 |
47 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
48 | epoch_start_time = time.time()
49 | if epoch != start_epoch:
50 | epoch_iter = epoch_iter % dataset_size
51 | for i, data in enumerate(dataset, start=epoch_iter):
52 | iter_start_time = time.time()
53 | total_steps += opt.batchSize
54 | epoch_iter += opt.batchSize
55 |
56 | # whether to collect output images
57 | save_fake = total_steps % opt.display_freq == display_delta
58 |
59 | ############## Forward Pass ######################
60 | losses, generated = model(
61 | Variable(data['A']), Variable(data['A2']),
62 | Variable(data['B']), Variable(data['B2']),
63 | Variable(data['C']), Variable(data['C2']),
64 | Variable(data['D']), Variable(data['D2']), infer=save_fake)
65 |
66 | # sum per device losses
67 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
68 | loss_dict = dict(zip(model.module.loss_names, losses))
69 |
70 | if (total_steps <= opt.use_style_iter):
71 | use_style = 0
72 | else:
73 | use_style = 1
74 | loss_dict['G_GAN_style'] *= 10
75 | # calculate final loss scalar
76 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
77 | loss_SD = (loss_dict['SD_fake1'] + loss_dict['SD_fake2'] + loss_dict['SD_real']) * 0.5
78 | loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_style'] * use_style + (loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0))
79 |
80 | ############### Backward Pass ####################
81 | # update generator weights
82 | model.module.optimizer_G.zero_grad()
83 | loss_G.backward()
84 | model.module.optimizer_G.step()
85 |
86 | # update discriminator weights
87 | model.module.optimizer_D.zero_grad()
88 | loss_D.backward()
89 | model.module.optimizer_D.step()
90 |
91 | # update discriminator weights
92 | if (total_steps > opt.use_style_iter):
93 | model.module.optimizer_SD.zero_grad()
94 | loss_SD.backward()
95 | model.module.optimizer_SD.step()
96 |
97 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
98 |
99 | ############## Display results and errors ##########
100 | ### print out errors
101 | if total_steps % opt.print_freq == print_delta:
102 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
103 | errors['loss_G'] = loss_G
104 | errors['loss_D'] = loss_D
105 | errors['loss_SD'] = loss_SD
106 | t = (time.time() - iter_start_time) / opt.batchSize
107 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
108 | visualizer.plot_current_errors(errors, total_steps)
109 |
110 |
111 | ### display output images
112 | if save_fake:
113 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)),
114 | ('real_image', util.tensor2im(data['A2'][0])),
115 | ('synthesized_image', util.tensor2im(generated.data[0])),
116 | ('B', util.tensor2label(data['B'][0], 256)),
117 | ('B2', util.tensor2im(data['B2'][0]))])
118 | visualizer.display_current_results2(visuals, epoch, total_steps)
119 |
120 | ### save latest model
121 | if total_steps % opt.save_latest_freq == save_delta:
122 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
123 | model.module.save('latest')
124 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
125 |
126 | if opt.use_iter_decay and total_steps > opt.niter_iter:
127 | for temp in range(opt.batchSize):
128 | model.module.update_learning_rate()
129 |
130 | if opt.use_iter_decay and total_steps >= opt.niter_iter + opt.niter_decay_iter:
131 | break
132 | if epoch_iter >= dataset_size:
133 | break
134 |
135 | # end of epoch
136 | iter_end_time = time.time()
137 | print('End of epoch %d / %d \t Time Taken: %d sec' %
138 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
139 |
140 | ### save model for this epoch
141 | if epoch % opt.save_epoch_freq == 0:
142 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
143 | model.module.save('latest')
144 | model.module.save(epoch)
145 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
146 |
147 | ### instead of only training the local enhancer, train the entire network after certain iterations
148 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
149 | model.module.update_fixed_params()
150 |
151 | ### linearly decay learning rate after certain iterations
152 | if epoch > opt.niter and not opt.use_iter_decay:
153 | model.module.update_learning_rate()
154 |
155 | if opt.use_iter_decay and total_steps > opt.niter_iter + opt.niter_decay_iter:
156 | break
157 |
158 |
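The np.savetxt(iter_path, ...) calls above are the save side of this script's resume logic (start_epoch, epoch_iter and total_steps near the top). A minimal sketch of the matching load side, assuming the same two-integer format; the checkpoint path and the continue_train flag are illustrative stand-ins, not options defined in this excerpt:

    import os
    import numpy as np

    def load_iter_state(iter_path, continue_train):
        # mirrors np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') in the loop above
        if continue_train and os.path.exists(iter_path):
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
            return int(start_epoch), int(epoch_iter)
        return 1, 0  # fresh run starts at epoch 1, iteration 0

    # hypothetical location; the real script builds the path from opt.checkpoints_dir and opt.name
    start_epoch, epoch_iter = load_iter_state('./checkpoints/experiment/iter.txt', continue_train=True)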
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxjyxxme/pix2pixSC/1851e934b333e89b98b3c0df9962912201fcb376/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, str):
25 | with self.doc:
26 | h3(str)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 |
45 | def add_images2(self, ims, links, txts, disp_title, width=512):
46 | self.add_table()
47 | with self.t:
48 | with tr():
49 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
50 | with p():
51 | t = txts[0]
52 | while (len(t) < 10):
53 | t = '_' + t
54 | p(t)
55 | for im, txt, link in zip(ims, txts, links):
56 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
57 | with p():
58 | if (disp_title):
59 | p(txt)
60 | with a(href=os.path.join('images', link)):
61 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
62 |
63 | def save(self):
64 | html_file = '%s/index.html' % self.web_dir
65 | f = open(html_file, 'wt')
66 | f.write(self.doc.render())
67 | f.close()
68 |
69 |
70 | if __name__ == '__main__':
71 | html = HTML('web/', 'test_html')
72 | html.add_header('hello world')
73 |
74 | ims = []
75 | txts = []
76 | links = []
77 | for n in range(4):
78 | ims.append('image_%d.jpg' % n)
79 | txts.append('text_%d' % n)
80 | links.append('image_%d.jpg' % n)
81 | html.add_images(ims, txts, links)
82 | html.save()
83 |
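The __main__ block above exercises add_images; add_images2 (used by Visualizer.save_images2 further below) renders one row with a padded name cell and optional per-column titles. A small sketch of driving it directly; the file names are placeholders:

    from util.html import HTML

    page = HTML('web2/', 'add_images2 demo')
    ims = ['a.jpg', 'b.jpg']            # expected to live under web2/images/
    links = ims                          # click-through targets
    txts = ['label', 'synthesized']      # per-column captions (txts[0] also fills the name cell)
    page.add_images2(ims, links, txts, disp_title=True, width=256)
    page.save()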
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
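ImagePool implements the pix2pix/CycleGAN-style history buffer: the discriminator is shown a mix of the current fakes and fakes stored from earlier iterations. A minimal usage sketch, assuming a batch of generator outputs of shape (N, C, H, W):

    import torch
    from torch.autograd import Variable
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)                 # pool_size=0 disables the buffer (query returns its input unchanged)
    fake_B = Variable(torch.randn(4, 3, 64, 64))   # stand-in for a generator output batch

    # Once the buffer is full, each returned image is swapped with probability 0.5
    # against an older stored image, so D trains on fresh and historical fakes.
    fake_for_D = pool.query(fake_B)
    print(fake_for_D.size())                       # torch.Size([4, 3, 64, 64])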
--------------------------------------------------------------------------------
/util/label.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | # a label and all meta information
4 | # Code inspired by Cityscapes https://github.com/mcordts/cityscapesScripts
5 | Label = namedtuple('Label', [
6 |
7 | 'name', # The identifier of this label, e.g. 'car', 'person', ... .
8 | # We use them to uniquely name a class
9 |
10 | 'id', # An integer ID that is associated with this label.
11 | # The IDs are used to represent the label in ground truth images
12 | # An ID of -1 means that this label does not have an ID and thus
13 | # is ignored when creating ground truth images (e.g. license plate).
14 | # Do not modify these IDs, since exactly these IDs are expected by the
15 | # evaluation server.
16 |
17 | 'trainId',
18 | # Feel free to modify these IDs as suitable for your method. Then create
19 | # ground truth images with train IDs, using the tools provided in the
20 | # 'preparation' folder. However, make sure to validate or submit results
21 | # to our evaluation server using the regular IDs above!
22 | # For trainIds, multiple labels might have the same ID. Then, these labels
23 | # are mapped to the same class in the ground truth images. For the inverse
24 | # mapping, we use the label that is defined first in the list below.
25 | # For example, mapping all void-type classes to the same ID in training,
26 | # might make sense for some approaches.
27 | # Max value is 255!
28 |
29 | 'category', # The name of the category that this label belongs to
30 |
31 | 'categoryId',
32 | # The ID of this category. Used to create ground truth images
33 | # on category level.
34 |
35 | 'hasInstances',
36 | # Whether this label distinguishes between single instances or not
37 |
38 | 'ignoreInEval',
39 | # Whether pixels having this class as ground truth label are ignored
40 | # during evaluations or not
41 |
42 | 'color', # The color of this label
43 | ])
44 |
45 |
46 | # Our extended list of label types. Our train IDs are compatible with Cityscapes
47 | labels = [
48 | # name id trainId category catId hasInstances ignoreInEval color
49 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
50 | Label( 'dynamic' , 1 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
51 | Label( 'ego vehicle' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
52 | Label( 'ground' , 3 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
53 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
54 | Label( 'parking' , 5 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
55 | Label( 'rail track' , 6 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
56 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
57 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
58 | Label( 'bridge' , 9 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
59 | Label( 'building' , 10 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
60 | Label( 'fence' , 11 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
61 | Label( 'garage' , 12 , 255 , 'construction' , 2 , False , True , (180,100,180) ),
62 | Label( 'guard rail' , 13 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
63 | Label( 'tunnel' , 14 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
64 | Label( 'wall' , 15 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
65 | Label( 'banner' , 16 , 255 , 'object' , 3 , False , True , (250,170,100) ),
66 | Label( 'billboard' , 17 , 255 , 'object' , 3 , False , True , (220,220,250) ),
67 | Label( 'lane divider' , 18 , 255 , 'object' , 3 , False , True , (255, 165, 0) ),
68 | Label( 'parking sign' , 19 , 255 , 'object' , 3 , False , False , (220, 20, 60) ),
69 | Label( 'pole' , 20 , 5 , 'object' , 3 , False , False , (153,153,153) ),
70 | Label( 'polegroup' , 21 , 255 , 'object' , 3 , False , True , (153,153,153) ),
71 | Label( 'street light' , 22 , 255 , 'object' , 3 , False , True , (220,220,100) ),
72 | Label( 'traffic cone' , 23 , 255 , 'object' , 3 , False , True , (255, 70, 0) ),
73 | Label( 'traffic device' , 24 , 255 , 'object' , 3 , False , True , (220,220,220) ),
74 | Label( 'traffic light' , 25 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
75 | Label( 'traffic sign' , 26 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
76 | Label( 'traffic sign frame' , 27 , 255 , 'object' , 3 , False , True , (250,170,250) ),
77 | Label( 'terrain' , 28 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
78 | Label( 'vegetation' , 29 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
79 | Label( 'sky' , 30 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
80 | Label( 'person' , 31 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
81 | Label( 'rider' , 32 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
82 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
83 | Label( 'bus' , 34 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
84 | Label( 'car' , 35 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
85 | Label( 'caravan' , 36 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
86 | Label( 'motorcycle' , 37 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
87 | Label( 'trailer' , 38 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
88 | Label( 'train' , 39 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
89 | Label( 'truck' , 40 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
90 | ]
91 |
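Downstream code (e.g. labelcolormap in util/util.py) walks this list to build lookups; the same idea in a few lines, with the dict names being illustrative rather than part of this repo:

    from util.label import labels

    name2label    = {lbl.name: lbl for lbl in labels}
    id2label      = {lbl.id: lbl for lbl in labels}
    trainId2color = {lbl.trainId: lbl.color for lbl in labels if lbl.trainId != 255}   # skip ignored classes

    print(name2label['car'].trainId)   # 13
    print(id2label[7].name)            # 'road'
    print(trainId2color[0])            # (128, 64, 128)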
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import os
7 | from util.label import labels
8 |
9 | def get_big_img(ans):
10 | ret = []
11 | for a in ans:
12 | temp = []
13 | for b in a:
14 | if (len(b.shape) == 2):
15 | t = b[np.newaxis, :, :]
16 | c = np.concatenate((t,t,t), axis=0)
17 | else:
18 | c = b
19 | temp.append(c)
20 | t = np.concatenate(temp, axis=1)
21 | ret.append(t)
22 | final = np.concatenate(ret, axis=2)
23 | print(final.shape)
24 | return final
25 |
26 | # Converts a Tensor into a Numpy array
27 | # |imtype|: the desired type of the converted numpy array
28 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
29 | if isinstance(image_tensor, list):
30 | image_numpy = []
31 | for i in range(len(image_tensor)):
32 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
33 | return image_numpy
34 | image_numpy = image_tensor.cpu().float().numpy()
35 | if normalize:
36 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
37 | else:
38 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
39 | image_numpy = np.clip(image_numpy, 0, 255)
40 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
41 | image_numpy = image_numpy[:,:,0]
42 | return image_numpy.astype(imtype)
43 |
44 | # Converts a Tensor into a Numpy array
45 | # |imtype|: the desired type of the converted numpy array
46 | def tensor2im2(image_tensor, imtype=np.uint8, normalize=True):
47 | if isinstance(image_tensor, list):
48 | image_numpy = []
49 | for i in range(len(image_tensor)):
50 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
51 | return image_numpy
52 | image_numpy = image_tensor.cpu().float().numpy()
53 | if normalize:
54 | image_numpy = (image_numpy + 1) / 2.0 * 255.0
55 | else:
56 | image_numpy = image_numpy * 255.0
57 | image_numpy = np.clip(image_numpy, 0, 255)
58 | if image_numpy.shape[0] == 23:
59 | image_numpy = np.concatenate([image_numpy[3:6, :, :], image_numpy[0:3, :, :]], axis=1)
60 | if image_numpy.shape[0] == 1 or image_numpy.shape[0] > 3:
61 | image_numpy = image_numpy[0, :,:]
62 | return image_numpy.astype(imtype)
63 |
64 | # Converts a one-hot tensor into a colorful label map
65 | def tensor2label2(label_tensor, n_label, imtype=np.uint8):
66 | if n_label == 0:
67 | return label_tensor.cpu().float().numpy().astype(imtype)
68 | return tensor2im2(label_tensor, imtype)  # NOTE: unconditional return; the Colorize path below is unreachable
69 | label_tensor = label_tensor.cpu().float()
70 | if label_tensor.size()[0] > 1:
71 | label_tensor = label_tensor.max(0, keepdim=True)[1]
72 | label_tensor = Colorize(n_label)(label_tensor)
73 | #label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
74 | #print(label_numpy.shape)
75 | return label_tensor.numpy().astype(imtype)
76 |
77 | # Converts a one-hot tensor into a colorful label map
78 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
79 | if n_label == 0:
80 | return tensor2im(label_tensor, imtype)
81 | label_tensor = label_tensor.cpu().float()
82 | if label_tensor.size()[0] > 1:
83 | label_tensor = label_tensor.max(0, keepdim=True)[1]
84 | label_tensor = Colorize(n_label)(label_tensor)
85 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
86 | #print(label_numpy.shape)
87 | return label_numpy.astype(imtype)
88 |
89 | def save_image(image_numpy, image_path):
90 | image_pil = Image.fromarray(image_numpy)
91 | image_pil.save(image_path)
92 |
93 | def mkdirs(paths):
94 | if isinstance(paths, list) and not isinstance(paths, str):
95 | for path in paths:
96 | mkdir(path)
97 | else:
98 | mkdir(paths)
99 |
100 | def mkdir(path):
101 | if not os.path.exists(path):
102 | os.makedirs(path)
103 |
104 | ###############################################################################
105 | # Code from
106 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
107 | # Modified so it complies with the Cityscapes label map colors
108 | ###############################################################################
109 | def uint82bin(n, count=8):
110 | """returns the binary of integer n, count refers to amount of bits"""
111 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
112 |
113 | def labelcolormap(N):
114 | if N == 35: # cityscape
115 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
116 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
117 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
118 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
119 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
120 | dtype=np.uint8)
121 | elif N == 256:
122 | t = []
123 | for i in range(256):
124 | t.append((0, 0, 0))
125 | for i in range(len(labels)):
126 | t[labels[i].trainId] = labels[i].color
127 | cmap = np.array(t, dtype=np.uint8)
128 | else:
129 | cmap = np.zeros((N, 3), dtype=np.uint8)
130 | for i in range(N):
131 | r, g, b = 0, 0, 0
132 | id = i
133 | for j in range(7):
134 | str_id = uint82bin(id)
135 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
136 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
137 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
138 | id = id >> 3
139 | cmap[i, 0] = r
140 | cmap[i, 1] = g
141 | cmap[i, 2] = b
142 | return cmap
143 |
144 | class Colorize(object):
145 | def __init__(self, n=35):
146 | self.cmap = labelcolormap(n)
147 | self.cmap = torch.from_numpy(self.cmap[:n])
148 |
149 | def __call__(self, gray_image):
150 | size = gray_image.size()
151 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
152 |
153 | for label in range(0, len(self.cmap)):
154 | mask = (label == gray_image[0]).cpu()
155 | color_image[0][mask] = self.cmap[label][0]
156 | color_image[1][mask] = self.cmap[label][1]
157 | color_image[2][mask] = self.cmap[label][2]
158 |
159 | return color_image
160 |
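tensor2label followed by save_image is the usual route from a raw label map to a colourised image (it is what the training and vis_* scripts do via the Visualizer). A minimal sketch, assuming a single-channel map of Cityscapes-style train IDs; the random data is only for illustration:

    import torch
    import util.util as util

    # fake (1, H, W) label map holding integer train IDs; n_label=256 selects the
    # labelcolormap branch that is filled from util.label.labels
    label_map = (torch.rand(1, 128, 128) * 19).floor()

    color = util.tensor2label(label_map, 256)   # (H, W, 3) uint8 numpy array
    util.save_image(color, 'label_vis.png')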
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import numpy as np
4 | import os
5 | import ntpath
6 | import time
7 | from . import util
8 | from . import html
9 | import scipy.misc
10 | try:
11 | from StringIO import StringIO # Python 2.7
12 | except ImportError:
13 | from io import BytesIO # Python 3.x
14 |
15 | class Visualizer():
16 | def __init__(self, opt):
17 | # self.opt = opt
18 | self.tf_log = opt.tf_log
19 | self.use_html = opt.isTrain and not opt.no_html
20 | self.win_size = opt.display_winsize
21 | self.name = opt.name
22 | if self.tf_log:
23 | import tensorflow as tf
24 | self.tf = tf
25 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
26 | self.writer = tf.summary.FileWriter(self.log_dir)
27 |
28 | if self.use_html:
29 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
30 | self.img_dir = os.path.join(self.web_dir, 'images')
31 | print('create web directory %s...' % self.web_dir)
32 | util.mkdirs([self.web_dir, self.img_dir])
33 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
34 | with open(self.log_name, "a") as log_file:
35 | now = time.strftime("%c")
36 | log_file.write('================ Training Loss (%s) ================\n' % now)
37 |
38 | def get_names(self, path):
39 | names = []
40 | for root, _, fnames in os.walk(path):
41 | for fname in fnames:
42 | t = fname.find('input_label')
43 | if (t != -1):
44 | names.append([fname[:t], fname[t + 12:-4]])
45 | names = sorted(names, key=lambda x:-int(x[1]))
46 | return names
47 |
48 | # |visuals|: dictionary of images to display or save
49 | def display_current_results2(self, visuals, epoch, step):
50 | if self.tf_log: # show images in tensorboard output
51 | img_summaries = []
52 | for label, image_numpy in visuals.items():
53 | # Write the image to a string
54 | try:
55 | s = StringIO()
56 | except:
57 | s = BytesIO()
58 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
59 | # Create an Image object
60 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
61 | # Create a Summary value
62 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
63 |
64 | # Create and write Summary
65 | summary = self.tf.Summary(value=img_summaries)
66 | self.writer.add_summary(summary, step)
67 |
68 | if self.use_html: # save images to an HTML file
69 | for label, image_numpy in visuals.items():
70 | if isinstance(image_numpy, list):
71 | for i in range(len(image_numpy)):
72 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
73 | util.save_image(image_numpy[i], img_path)
74 | else:
75 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%.3d.jpg' % (epoch, label, step))
76 | util.save_image(image_numpy, img_path)
77 |
78 | names = self.get_names(self.web_dir)
79 | # update website
80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
81 | for name in names:
82 | webpage.add_header('%s %s' % (name[0], name[1]))
83 | ims = []
84 | txts = []
85 | links = []
86 |
87 | for label, image_numpy in visuals.items():
88 | img_path = '%s%s_%s.jpg' % (name[0], label, name[1])
89 | ims.append(img_path)
90 | txts.append(label)
91 | links.append(img_path)
92 | if len(ims) < 10:
93 | webpage.add_images(ims, txts, links, width=self.win_size)
94 | else:
95 | num = int(round(len(ims)/2.0))
96 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
97 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
98 | webpage.save()
99 |
100 |
101 | # |visuals|: dictionary of images to display or save
102 | def display_current_results(self, visuals, epoch, step):
103 | if self.tf_log: # show images in tensorboard output
104 | img_summaries = []
105 | for label, image_numpy in visuals.items():
106 | # Write the image to a string
107 | try:
108 | s = StringIO()
109 | except:
110 | s = BytesIO()
111 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
112 | # Create an Image object
113 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
114 | # Create a Summary value
115 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
116 |
117 | # Create and write Summary
118 | summary = self.tf.Summary(value=img_summaries)
119 | self.writer.add_summary(summary, step)
120 |
121 | if self.use_html: # save images to an HTML file
122 | for label, image_numpy in visuals.items():
123 | if isinstance(image_numpy, list):
124 | for i in range(len(image_numpy)):
125 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
126 | util.save_image(image_numpy[i], img_path)
127 | else:
128 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%.3d.jpg' % (epoch, label, step))
129 | util.save_image(image_numpy, img_path)
130 |
131 | # update website
132 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
133 | for n in range(epoch, 0, -1):
134 | webpage.add_header('epoch [%d]' % n)
135 | ims = []
136 | txts = []
137 | links = []
138 |
139 | for label, image_numpy in visuals.items():
140 | if isinstance(image_numpy, list):
141 | for i in range(len(image_numpy)):
142 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
143 | ims.append(img_path)
144 | txts.append(label+str(i))
145 | links.append(img_path)
146 | else:
147 | img_path = 'epoch%.3d_%s_%.3d.jpg' % (n, label, step)
148 | ims.append(img_path)
149 | txts.append(label)
150 | links.append(img_path)
151 | if len(ims) < 10:
152 | webpage.add_images(ims, txts, links, width=self.win_size)
153 | else:
154 | num = int(round(len(ims)/2.0))
155 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
156 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
157 | webpage.save()
158 |
159 | # errors: dictionary of error labels and values
160 | def plot_current_errors(self, errors, step):
161 | if self.tf_log:
162 | for tag, value in errors.items():
163 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
164 | self.writer.add_summary(summary, step)
165 |
166 | # errors: same format as |errors| of plotCurrentErrors
167 | def print_current_errors(self, epoch, i, errors, t):
168 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
169 | temp = errors.items()
170 | temp = sorted(temp, key=lambda x:x[0])
171 | for k, v in temp:
172 | if v != 0:
173 | message += '%s: %.3f ' % (k, v)
174 |
175 | print(message)
176 | with open(self.log_name, "a") as log_file:
177 | log_file.write('%s\n' % message)
178 |
179 | # save image to the disk
180 | def save_images(self, webpage, visuals, image_path):
181 | image_dir = webpage.get_image_dir()
182 | short_path = ntpath.basename(image_path[0])
183 | name = os.path.splitext(short_path)[0]
184 |
185 | webpage.add_header(name)
186 | ims = []
187 | txts = []
188 | links = []
189 |
190 | for label, image_numpy in visuals.items():
191 | image_name = '%s_%s.jpg' % (name, label)
192 | save_path = os.path.join(image_dir, image_name)
193 | util.save_image(image_numpy, save_path)
194 |
195 | ims.append(image_name)
196 | txts.append(label)
197 | links.append(image_name)
198 | webpage.add_images(ims, txts, links, width=self.win_size)
199 | # save image to the disk
200 | def save_images2(self, webpage, visuals, image_path, disp_title=False):
201 | image_dir = webpage.get_image_dir()
202 | short_path = ntpath.basename(image_path[0])
203 | name = os.path.splitext(short_path)[0]
204 |
205 | ims = []
206 | links = []
207 | txts = []
208 |
209 | for label, image_numpy in visuals.items():
210 | image_name = '%s_%s.jpg' % (name, label)
211 | save_path = os.path.join(image_dir, image_name)
212 | util.save_image(image_numpy, save_path)
213 |
214 | ims.append(image_name)
215 | links.append(image_name)
216 | txts.append(label)
217 | webpage.add_images2(ims, links, txts, disp_title, width=self.win_size)
218 |
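Visualizer only reads a handful of option fields, so it can also be driven outside the training scripts. A minimal sketch that appends one line to checkpoints/demo_run/loss_log.txt; the SimpleNamespace stands in for the parsed TrainOptions, and every field value shown is an assumption rather than a repo default:

    from types import SimpleNamespace
    from util.visualizer import Visualizer

    opt = SimpleNamespace(tf_log=False, isTrain=True, no_html=False,
                          display_winsize=512, name='demo_run',
                          checkpoints_dir='./checkpoints')

    vis = Visualizer(opt)
    # epoch 1, iteration 100, 0.45 s per image; zero-valued losses are skipped in the printout
    vis.print_current_errors(1, 100, {'G_GAN': 0.812, 'D_real': 0.305, 'D_fake': 0.287}, 0.45)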
--------------------------------------------------------------------------------
/val_gen.py:
--------------------------------------------------------------------------------
1 | import os
2 | A_list = ['__Record024__Camera_5__170927_071032417_Camera_5']  # unused: overridden on the next line
3 | A_list = ['_a']
4 | B_list = ['__Record024__Camera_6__170927_070953498_Camera_6', '__Record031__Camera_5__170927_071846100_Camera_5', '__Record035__Camera_5__170927_072446711_Camera_5']
5 | data_path = './datasets/apollo/'
6 |
7 | def label(A):
8 | return os.path.join(data_path, 'label', 'Label' + A + '_bin.png')
9 | def img(A):
10 | return os.path.join(data_path, 'img', 'ColorImage' + A + '.jpg')
11 |
12 | lines = []
13 | for A in A_list:
14 | for B in B_list:
15 | lines.append(label(A) + '&' + label(B) + '&' + img(B) + '&' + img(A) + '\n')
16 | f = open(os.path.join(data_path, 'debug_list.txt'), 'w')
17 | f.writelines(lines)
18 | f.close()
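The '&'-joined lines written here use the same convention the vis_* scripts below rely on when they call path.split('&'); reading one entry back looks like this (the file name is the debug list produced above):

    import os

    data_path = './datasets/apollo/'

    # each line is: label(A) & label(B) & img(B) & img(A), as assembled above
    with open(os.path.join(data_path, 'debug_list.txt')) as f:
        for line in f:
            label_A, label_B, img_B, img_A = line.rstrip('\n').split('&')
            print(label_A, img_A)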
--------------------------------------------------------------------------------
/vis_bdd.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreateConDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 | import tensorboardX
15 | import random
16 |
17 | GAP = 1
18 |
19 | def write_temp(opt, target_phase):
20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt')
22 | f = open(source_path, 'r')
23 | all_paths = f.readlines()
24 | ans = []
25 | last = ''
26 | tot = 0
27 | print(len(all_paths))
28 | for path in all_paths:
29 | path_ = path.split('&')[0]
30 | if (path_ != last):
31 | last = path_
32 | tot = tot + 1
33 | if tot % GAP == 0:
34 | ans.append(path)
35 | f_w = open(target_path, 'w')
36 | f_w.writelines(ans)
37 |
38 | opt = TrainOptions().parse()
39 | opt.phase = 'val'
40 | write_temp(opt, "temp")
41 | opt.phase = "temp"
42 | opt.serial_batches = True
43 |
44 | data_loader = CreateConDataLoader(opt)
45 | dataset = data_loader.load_data()
46 | dataset_size = len(data_loader)
47 |
48 | visualizer = Visualizer(opt)
49 |
50 | total_steps = 0  # (start_epoch-1) * dataset_size + epoch_iter
51 |
52 | display_delta = total_steps % opt.display_freq
53 | print_delta = total_steps % opt.print_freq
54 | save_delta = total_steps % opt.save_latest_freq
55 |
56 | for i, data in enumerate(dataset):
57 | if (i % 100 == 0):
58 | print((i, dataset_size))
59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 256)),
60 | ('real_image', util.tensor2im(data['A2'][0]))])
61 | visualizer.display_current_results2(visuals, 0, i)
62 |
63 |
--------------------------------------------------------------------------------
/vis_delta.txt:
--------------------------------------------------------------------------------
1 | 0
2 |
--------------------------------------------------------------------------------
/vis_face.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreateFaceConDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 | import tensorboardX
15 | import random
16 |
17 | GAP = 10
18 |
19 | def write_temp(opt, target_phase):
20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt')
22 | f = open(source_path, 'r')
23 | all_paths = f.readlines()
24 | ans = []
25 | last = ''
26 | tot = 0
27 | print(len(all_paths))
28 | for path in all_paths:
29 | path_ = path.split('&')[0]
30 | if (path_ != last):
31 | last = path_
32 | tot = tot + 1
33 | if tot % GAP == 0:
34 | ans.append(path)
35 | f_w = open(target_path, 'w')
36 | f_w.writelines(ans)
37 |
38 | opt = TrainOptions().parse()
39 | opt.phase = 'val'
40 | write_temp(opt, "temp")
41 | opt.phase = "temp"
42 | opt.serial_batches = True
43 |
44 | data_loader = CreateFaceConDataLoader(opt)
45 | dataset = data_loader.load_data()
46 | dataset_size = len(data_loader)
47 |
48 | visualizer = Visualizer(opt)
49 |
50 | total_steps = 0  # (start_epoch-1) * dataset_size + epoch_iter
51 |
52 | display_delta = total_steps % opt.display_freq
53 | print_delta = total_steps % opt.print_freq
54 | save_delta = total_steps % opt.save_latest_freq
55 |
56 | for i, data in enumerate(dataset):
57 | if (i % 100 == 0):
58 | print((i, dataset_size))
59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
60 | ('real_image', util.tensor2im(data['A2'][0]))])
61 | visualizer.display_current_results2(visuals, 0, i)
62 |
63 |
--------------------------------------------------------------------------------
/vis_pose.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreatePoseConDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 | import tensorboardX
15 | import random
16 |
17 | GAP = 10
18 |
19 | def write_temp(opt, target_phase):
20 | source_path = os.path.join(opt.dataroot, opt.phase + '_list.txt')
21 | target_path = os.path.join(opt.dataroot, target_phase + '_list.txt')
22 | f = open(source_path, 'r')
23 | all_paths = f.readlines()
24 | ans = []
25 | last = ''
26 | tot = 0
27 | print(len(all_paths))
28 | for path in all_paths:
29 | path_ = path.split('&')[0]
30 | if (path_ != last):
31 | last = path_
32 | tot = tot + 1
33 | if tot % GAP == 0:
34 | ans.append(path)
35 | f_w = open(target_path, 'w')
36 | f_w.writelines(ans)
37 |
38 | opt = TrainOptions().parse()
39 | opt.phase = 'val'
40 | write_temp(opt, "temp")
41 | opt.phase = "temp"
42 | opt.serial_batches = True
43 |
44 | data_loader = CreatePoseConDataLoader(opt)
45 | dataset = data_loader.load_data()
46 | dataset_size = len(data_loader)
47 |
48 | visualizer = Visualizer(opt)
49 |
50 | total_steps = 0  # (start_epoch-1) * dataset_size + epoch_iter
51 |
52 | display_delta = total_steps % opt.display_freq
53 | print_delta = total_steps % opt.print_freq
54 | save_delta = total_steps % opt.save_latest_freq
55 |
56 | for i, data in enumerate(dataset):
57 | if (i % 100 == 0):
58 | print((i, dataset_size))
59 | visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0][3:6, :, :], 0)),
60 | ('real_image', util.tensor2im(data['A2'][0]))])
61 | visualizer.display_current_results2(visuals, 0, i)
62 |
63 |
--------------------------------------------------------------------------------