├── LICENSE ├── README.md ├── configs ├── simple_human36m.yaml └── test_simple_human.yaml ├── data ├── __init__.py ├── base_data_loader.py ├── base_dataset.py ├── human36m_skeleton.py ├── image_folder.py ├── simplehuman36m_dataset.py └── utils.py ├── models ├── __init__.py ├── base_model.py ├── keypoint_gan_model.py ├── networks.py ├── perceptual_loss.py └── utils.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── test_pose.py ├── train.py └── util ├── __init__.py ├── html.py ├── image_pool.py ├── plotting.py ├── skeleton.py ├── tps_sampler.py ├── util.py └── visualizer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Tomas Jakab 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | --------------------------- LICENSE FOR CycleGAN -------------------------------- 26 | For CycleGAN software 27 | 28 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 29 | All rights reserved. 30 | 31 | Redistribution and use in source and binary forms, with or without 32 | modification, are permitted provided that the following conditions are met: 33 | 34 | * Redistributions of source code must retain the above copyright notice, this 35 | list of conditions and the following disclaimer. 36 | 37 | * Redistributions in binary form must reproduce the above copyright notice, 38 | this list of conditions and the following disclaimer in the documentation 39 | and/or other materials provided with the distribution. 40 | 41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 42 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 43 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 44 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 45 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 46 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 47 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 48 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 49 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 50 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 51 | 52 | 53 | --------------------------- LICENSE FOR pix2pix -------------------------------- 54 | BSD License 55 | 56 | For pix2pix software 57 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 58 | All rights reserved. 59 | 60 | Redistribution and use in source and binary forms, with or without 61 | modification, are permitted provided that the following conditions are met: 62 | 63 | * Redistributions of source code must retain the above copyright notice, this 64 | list of conditions and the following disclaimer. 65 | 66 | * Redistributions in binary form must reproduce the above copyright notice, 67 | this list of conditions and the following disclaimer in the documentation 68 | and/or other materials provided with the distribution. 69 | 70 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 71 | BSD License 72 | 73 | For dcgan.torch software 74 | 75 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 76 | 77 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 78 | 79 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 80 | 81 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 82 | 83 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 84 | 85 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [KeypointGAN: Self-supervised Learning of Interpretable Keypoints from Unlabelled Videos](https://www.robots.ox.ac.uk/~vgg/research/unsupervised_pose/) 2 | 3 | [Tomas Jakab](http://www.robots.ox.ac.uk/~tomj), Ankush Gupta, Hakan Bilen, Andrea Vedaldi. 4 | CVPR, 2020 (Oral presentation). 
5 | 6 | ## Quick start 7 | Download the Simplified Human3.6M dataset from `http://fy.z-yt.net/files.ytzhang.net/lmdis-rep/release-v1/human3.6m/human_images.tar.gz` into `./datasets/simple_human36m/human_images`. 8 | 9 | Download the VGG-19 network used for the perceptual loss from `http://www.robots.ox.ac.uk/~vgg/research/unsupervised_pose/resources/imagenet-vgg-verydeep-19.mat` into `./networks/imagenet-vgg-verydeep-19.mat`. 10 | 11 | Paths to datasets and checkpoints can also be customized in `configs/simple_human36m.yaml` and `configs/test_simple_human.yaml` (both listed below; a sketch of the download commands follows them). 12 | 13 | ### Training 14 | Training requires a pre-trained keypoint regressor. See below for instructions on how to do the pre-training. 15 | 16 | A pre-trained regressor can also be downloaded from `http://www.robots.ox.ac.uk/~vgg/research/unsupervised_pose/resources/simple_human36m_regressor/580000_net_regressor.pth`. Save the regressor into the `./checkpoints/simple_human36m_regressor` directory unless you specified a different path in the config above. 17 | 18 | Train a model on the Simplified Human3.6M dataset: 19 | ``` 20 | python2.7 train.py -c configs/simple_human36m.yaml 21 | ``` 22 | 23 | ### Testing 24 | Test a model on the Simplified Human3.6M dataset: 25 | ``` 26 | python2.7 test_pose.py --test_config configs/test_simple_human.yaml -c configs/simple_human36m.yaml --iteration 27 | ``` 28 | 29 | ## Pre-training regressor 30 | Coming soon. 31 | 32 | ## Acknowledgments 33 | Parts of the code are based on [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 34 | -------------------------------------------------------------------------------- /configs/simple_human36m.yaml: -------------------------------------------------------------------------------- 1 | # paths 2 | dataroot: ./datasets/simple_human36m/human_images 3 | perceptual_net: ./networks/imagenet-vgg-verydeep-19.mat 4 | checkpoints_dir: ./checkpoints 5 | nets_paths: [regressor, ./checkpoints/simple_human36m_regressor/580000_net_regressor.pth] 6 | 7 | # for testing 8 | nets_paths: [offline_regressor, ./checkpoints/simple_human36m_regressor/580000_net_regressor.pth] 9 | 10 | 11 | # experiment name 12 | name: simple_human36m 13 | 14 | 15 | model: keypoint_gan 16 | display_id: -1 17 | dataset_mode: simplehuman36m 18 | resize_or_crop: scale_width 19 | no_flip: True 20 | display_freq: 10 21 | multi_ganA: True 22 | print_freq: 10 23 | loadSize: 128 24 | fineSize: 128 25 | output_nc: 1 26 | cycle_loss: perceptual 27 | netG_A: skip_nips 28 | netG_B: nips 29 | netDA: basic 30 | batch_size: 16 31 | num_threads: 16 32 | save_latest_freq: 5000 33 | save_iters_freq: 5000 34 | clip_grad: 1.0 35 | lambda_gan_A: 10.0 36 | skeleton_type: human36m_simple2 37 | paired_skeleton_type: human36m_simple2 38 | prior_skeleton_type: human36m_simple2 39 | 40 | sample_window: [0, 1000] 41 | 42 | finetune_regressor: True 43 | regressor_real_loss: 0.9 44 | regressor_fake_loss: 0.1 45 | 46 | sigma: 0.2 47 | -------------------------------------------------------------------------------- /configs/test_simple_human.yaml: -------------------------------------------------------------------------------- 1 | # paths 2 | results_dir: ./results 3 | 4 | phase: test 5 | 6 | shuffle: false 7 | no_flip: True 8 | display_id: -1 9 | print_freq: 5 10 | 11 | eval: True 12 | 13 | num_threads: 16 14 | batch_size: 64 15 | sample_window: [0, 1] # the second parameter is sampling frequency 16 | 17 | used_points: all 18 | error_form: image_size 19 | 20 | num_test_save: 60 21 | skeleton_subset_size: 0 22 | 23 | offline_regressor: True 24 |
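The Quick start above boils down to fetching three files into fixed locations (or into whatever paths you set in the two configs). A possible shell sketch is given below; only the URLs and target paths come from the README and configs, while the use of `wget`/`tar` and the assumption that the archive unpacks into `human_images/` are mine, not the repository's.

```
# create the default locations used by the configs above
mkdir -p datasets/simple_human36m networks checkpoints/simple_human36m_regressor

# Simplified Human3.6M images (assumed to unpack into human_images/)
wget -P datasets/simple_human36m http://fy.z-yt.net/files.ytzhang.net/lmdis-rep/release-v1/human3.6m/human_images.tar.gz
tar -xzf datasets/simple_human36m/human_images.tar.gz -C datasets/simple_human36m

# VGG-19 weights for the perceptual (cycle) loss
wget -P networks http://www.robots.ox.ac.uk/~vgg/research/unsupervised_pose/resources/imagenet-vgg-verydeep-19.mat

# pre-trained keypoint regressor expected by nets_paths in simple_human36m.yaml
wget -P checkpoints/simple_human36m_regressor http://www.robots.ox.ac.uk/~vgg/research/unsupervised_pose/resources/simple_human36m_regressor/580000_net_regressor.pth
```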
-------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import numpy as np 4 | import torch.utils.data 5 | 6 | from data.base_data_loader import BaseDataLoader 7 | from data.base_dataset import BaseDataset 8 | 9 | 10 | def find_dataset_using_name(dataset_name): 11 | # Given the option --dataset_mode [datasetname], 12 | # the file "data/datasetname_dataset.py" 13 | # will be imported. 14 | dataset_filename = "data." + dataset_name + "_dataset" 15 | datasetlib = importlib.import_module(dataset_filename) 16 | 17 | # In the file, the class called DatasetNameDataset() will 18 | # be instantiated. It has to be a subclass of BaseDataset, 19 | # and it is case-insensitive. 20 | dataset = None 21 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 22 | for name, cls in datasetlib.__dict__.items(): 23 | if name.lower() == target_dataset_name.lower() \ 24 | and issubclass(cls, BaseDataset): 25 | dataset = cls 26 | 27 | if dataset is None: 28 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 29 | exit(0) 30 | 31 | return dataset 32 | 33 | 34 | def get_option_setter(dataset_name): 35 | dataset_class = find_dataset_using_name(dataset_name) 36 | return dataset_class.modify_commandline_options 37 | 38 | 39 | def create_dataset(opt): 40 | dataset = find_dataset_using_name(opt.dataset_mode) 41 | instance = dataset() 42 | instance.initialize(opt) 43 | print("dataset [%s] was created" % (instance.name())) 44 | return instance 45 | 46 | 47 | def CreateDataLoader(opt): 48 | data_loader = CustomDatasetDataLoader() 49 | data_loader.initialize(opt) 50 | return data_loader 51 | 52 | 53 | # Wrapper class of Dataset class that performs 54 | # multi-threaded data loading 55 | class CustomDatasetDataLoader(BaseDataLoader): 56 | def name(self): 57 | return 'CustomDatasetDataLoader' 58 | 59 | def initialize(self, opt): 60 | BaseDataLoader.initialize(self, opt) 61 | self.dataset = create_dataset(opt) 62 | self.dataloader = torch.utils.data.DataLoader( 63 | self.dataset, 64 | batch_size=opt.batch_size, 65 | shuffle=opt.shuffle, 66 | num_workers=int(opt.num_threads), 67 | pin_memory=True, 68 | worker_init_fn=worker_init_fn) 69 | 70 | def load_data(self): 71 | return self 72 | 73 | def __len__(self): 74 | return min(len(self.dataset), self.opt.max_dataset_size) 75 | 76 | def __iter__(self): 77 | for i, data in enumerate(self.dataloader): 78 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 79 | break 80 | yield data 81 | 82 | 83 | def worker_init_fn(worker_id): 84 | np.random.seed(np.random.get_state()[1][0] + worker_id) 85 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | 
super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.loadSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | 54 | # just modify the width and height to be multiple of 4 55 | def __adjust(img): 56 | ow, oh = img.size 57 | 58 | # the size needs to be a multiple of this number, 59 | # because going through generator network may change img size 60 | # and eventually cause size mismatch error 61 | mult = 4 62 | if ow % mult == 0 and oh % mult == 0: 63 | return img 64 | w = (ow - 1) // mult 65 | w = (w + 1) * mult 66 | h = (oh - 1) // mult 67 | h = (h + 1) * mult 68 | 69 | if ow != w or oh != h: 70 | __print_size_warning(ow, oh, w, h) 71 | 72 | return img.resize((w, h), Image.BICUBIC) 73 | 74 | 75 | def __scale_width(img, target_width): 76 | ow, oh = img.size 77 | 78 | # the size needs to be a multiple of this number, 79 | # because going through generator network may change img size 80 | # and eventually cause size mismatch error 81 | mult = 4 82 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 83 | if (ow == target_width and oh % mult == 0): 84 | return img 85 | w = target_width 86 | target_height = int(target_width * oh / ow) 87 | m = (target_height - 1) // mult 88 | h = (m + 1) * mult 89 | 90 | if target_height != h: 91 | __print_size_warning(target_width, target_height, w, h) 92 | 93 | return img.resize((w, h), Image.BICUBIC) 94 | 95 | 96 | def __print_size_warning(ow, oh, w, h): 97 | if not hasattr(__print_size_warning, 'has_printed'): 98 | print("The image size needs to be a multiple of 4. " 99 | "The loaded image size was (%d, %d), so it was adjusted to " 100 | "(%d, %d). 
This adjustment will be done to all images " 101 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 102 | __print_size_warning.has_printed = True 103 | -------------------------------------------------------------------------------- /data/human36m_skeleton.py: -------------------------------------------------------------------------------- 1 | hip = 'hip' 2 | thorax = 'thorax' 3 | r_hip = 'r_hip' 4 | r_knee = 'r_knee' 5 | r_ankle = 'r_ankle' 6 | r_ball = 'r_ball' 7 | r_toes = 'r_toes' 8 | l_hip = 'l_hip' 9 | l_knee = 'l_knee' 10 | l_ankle = 'l_ankle' 11 | l_ball = 'l_ball' 12 | l_toes = 'l_toes' 13 | neck_base = 'neck' 14 | head_center = 'head-center' 15 | head_back = 'head-back' 16 | l_uknown = 'l_uknown' 17 | l_shoulder = 'l_shoulder' 18 | l_elbow = 'l_elbow' 19 | l_wrist = 'l_wrist' 20 | l_wrist_2 = 'l_wrist_2' 21 | l_thumb = 'l_thumb' 22 | l_little = 'l_little' 23 | l_little_2 = 'l_little_2' 24 | r_uknown = 'r_uknown' 25 | r_shoulder = 'r_shoulder' 26 | r_elbow = 'r_elbow' 27 | r_wrist = 'r_wrist' 28 | r_wrist_2 = 'r_wrist_2' 29 | r_thumb = 'r_thumb' 30 | r_little = 'r_little' 31 | r_little_2 = 'r_little_2' 32 | pelvis = 'pelvis' 33 | 34 | links = ( 35 | (r_hip, thorax), 36 | # (r_hip, pelvis), 37 | (r_knee, r_hip), 38 | (r_ankle, r_knee), 39 | (r_ball, r_ankle), 40 | (r_toes, r_ball), 41 | (l_hip, thorax), 42 | # (l_hip, pelvis), 43 | (l_knee, l_hip), 44 | (l_ankle, l_knee), 45 | (l_ball, l_ankle), 46 | (l_toes, l_ball), 47 | (neck_base, thorax), 48 | # (head_center, head_back), 49 | # (head_back, neck_base), 50 | # (head_back, head_center), 51 | # (head_center, neck_base), 52 | (head_back, neck_base), 53 | (head_center, head_back), 54 | 55 | (l_shoulder, neck_base), 56 | (l_elbow, l_shoulder), 57 | (l_wrist, l_elbow), 58 | (l_thumb, l_wrist), 59 | (l_little, l_wrist), 60 | (r_shoulder, neck_base), 61 | (r_elbow, r_shoulder), 62 | (r_wrist, r_elbow), 63 | (r_thumb, r_wrist), 64 | (r_little, r_wrist), 65 | # (pelvis, thorax), 66 | ) 67 | 68 | links_simple = ( 69 | (r_hip, thorax), 70 | # (r_hip, pelvis), 71 | (r_knee, r_hip), 72 | (r_ankle, r_knee), 73 | (r_ball, r_ankle), 74 | (r_toes, r_ball), 75 | (l_hip, thorax), 76 | # (l_hip, pelvis), 77 | (l_knee, l_hip), 78 | (l_ankle, l_knee), 79 | (l_ball, l_ankle), 80 | (l_toes, l_ball), 81 | (neck_base, thorax), 82 | # (head_center, head_back), 83 | # (head_back, neck_base), 84 | # (head_back, head_center), 85 | # (head_center, neck_base), 86 | (head_back, neck_base), 87 | (head_center, head_back), 88 | 89 | (l_shoulder, neck_base), 90 | (l_elbow, l_shoulder), 91 | (l_wrist, l_elbow), 92 | (r_shoulder, neck_base), 93 | (r_elbow, r_shoulder), 94 | (r_wrist, r_elbow), 95 | # (pelvis, thorax), 96 | ) 97 | 98 | links_simple2 = ( 99 | (r_hip, pelvis), 100 | (r_knee, r_hip), 101 | (r_ankle, r_knee), 102 | (r_toes, r_ankle), 103 | 104 | (l_hip, pelvis), 105 | (l_knee, l_hip), 106 | (l_ankle, l_knee), 107 | (l_toes, l_ankle), 108 | 109 | (neck_base, pelvis), 110 | (head_back, neck_base), 111 | 112 | (l_shoulder, neck_base), 113 | (l_elbow, l_shoulder), 114 | (l_wrist, l_elbow), 115 | 116 | (r_shoulder, neck_base), 117 | (r_elbow, r_shoulder), 118 | (r_wrist, r_elbow), 119 | ) 120 | 121 | joint_indices = { 122 | hip: 0, 123 | thorax: 12, 124 | r_hip: 1, 125 | r_knee: 2, 126 | r_ankle: 3, 127 | r_ball: 4, 128 | r_toes: 5, 129 | 130 | l_hip: 6, 131 | l_knee: 7, 132 | l_ankle: 8, 133 | l_ball: 9, 134 | l_toes: 10, 135 | 136 | neck_base: 13, 137 | head_center: 14, 138 | head_back: 15, 139 | 140 | l_uknown: 16, 141 | l_shoulder: 17, 142 | l_elbow: 
18, 143 | l_wrist: 19, 144 | l_wrist_2: 20, 145 | l_thumb: 21, 146 | l_little: 22, 147 | l_little_2: 23, 148 | 149 | r_uknown: 24, 150 | r_shoulder: 25, 151 | r_elbow: 26, 152 | r_wrist: 27, 153 | r_wrist_2: 28, 154 | r_thumb: 29, 155 | r_little: 30, 156 | r_little_2: 31, 157 | pelvis: 11 158 | } 159 | 160 | joints_eval_martinez = { 161 | 'Hip': 0, 162 | 'RHip': 1, 163 | 'RKnee': 2, 164 | 'RFoot': 3, 165 | 'LHip': 6, 166 | 'LKnee': 7, 167 | 'LFoot': 8, 168 | 'Spine': 12, 169 | 'Thorax': 13, 170 | 'Neck/Nose': 14, 171 | 'Head': 15, 172 | 'LShoulder': 17, 173 | 'LElbow': 18, 174 | 'LWrist': 19, 175 | 'RShoulder': 25, 176 | 'RElbow': 26, 177 | 'RWrist': 27 178 | } 179 | 180 | 181 | official_eval = { 182 | 'Pelvis': (pelvis), 183 | 'RHip': (r_hip), 184 | 'RKnee': (r_knee), 185 | 'RAnkle': (r_ankle), 186 | 'LHip': (l_hip), 187 | 'LKnee': (l_knee), 188 | 'LAnkle': (l_ankle), 189 | 'Spine1': (thorax), 190 | 'Neck': (head_center), 191 | 'Head': (head_back), 192 | 'Site': (neck_base), 193 | 'LShoulder': (l_shoulder), 194 | 'LElbow': (l_elbow), 195 | 'LWrist': (l_wrist), 196 | 'RShoulder': (r_shoulder), 197 | 'RElbow': (r_elbow), 198 | 'RWrist': (r_wrist)} 199 | 200 | 201 | official_eval_indices = {k: joint_indices[v] for k, v in official_eval.items()} 202 | 203 | 204 | 205 | def get_link_indices(links): 206 | return [(joint_indices[x], joint_indices[y]) for x, y in links] 207 | 208 | simple_link_indices = get_link_indices(links_simple) 209 | simple2_link_indices = get_link_indices(links_simple2) 210 | link_indices = get_link_indices(links) 211 | 212 | 213 | def get_lr_correspondences(): 214 | paired = [] 215 | for limb in joint_indices.keys(): 216 | if limb[:2] == 'l_': 217 | paired.append(limb[2:]) 218 | correspond = [] 219 | for limb in paired: 220 | correspond.append((joint_indices['l_' + limb], joint_indices['r_' + limb])) 221 | return correspond 222 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | 
",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/simplehuman36m_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset 3 | from data.image_folder import make_dataset 4 | from data.unaligned_dataset import UnalignedDataset 5 | import human36m_skeleton 6 | from PIL import Image 7 | import torchvision.transforms as transforms 8 | import matplotlib.pyplot as plt 9 | import skimage.io 10 | import skimage.transform 11 | import random 12 | import numpy as np 13 | from data import utils 14 | import scipy.io 15 | 16 | 17 | 18 | def proc_im(image, mask, apply_mask=True): 19 | # read image 20 | image = skimage.io.imread(image) 21 | image = skimage.img_as_float(image).astype(np.float32) 22 | if not apply_mask: 23 | return image 24 | 25 | mask = skimage.io.imread(mask) 26 | mask = skimage.img_as_float(mask).astype(np.float32) 27 | 28 | return image * mask[..., None] 29 | 30 | 31 | def get_transform(opt, channels=3): 32 | mean = 0.5 33 | std = 0.5 34 | transform_list = [transforms.ToTensor(), 35 | transforms.Normalize([mean] * channels, 36 | [std] * channels)] 37 | return transforms.Compose(transform_list) 38 | 39 | 40 | 41 | class SimpleHuman36mDatasetSingle(object): 42 | def __init__(self, root, sample_window=[5, 30], activities=None, 43 | actors=None, split_sequence='full', subsampled_size=None, subsample_seed=None): 44 | 45 | self.root = root 46 | self.sample_window = sample_window 47 | 48 | self.ordered_stream = None 49 | 50 | # load dataset 51 | self.sequences = [] 52 | for actor in actors: 53 | sequences = os.listdir(os.path.join(root, actor, 'BackgroudMask')) 54 | sequences = sorted(sequences) 55 | for activity in activities: 56 | activity_sequences = [s for s in sequences if s.lower().startswith(activity.lower())] 57 | for seq in activity_sequences: 58 | frames = os.listdir(os.path.join(root, actor, 'BackgroudMask', seq)) 59 | frames = [int(os.path.splitext(x)[0]) for x in frames] 60 | frames = sorted(frames) 61 | if split_sequence == 'full': 62 | pass 63 | elif split_sequence == 'first_half': 64 | frames = frames[:len(frames) // 2] 65 | elif split_sequence == 'second_half': 66 | frames = frames[len(frames) // 2:] 67 | else: 68 | raise ValueError() 69 | self.sequences.append({'frames': frames, 'actor': actor, 'activity_sequence': seq}) 70 | if subsampled_size: 71 | sequences_ = [] 72 | rnd = random.Random(subsample_seed) 73 | for _ in range(subsampled_size): 74 | seq = rnd.choice(self.sequences).copy() 75 | seq['frames'] = [rnd.choice(seq['frames'])] 76 | sequences_.append(seq) 77 | self.sequences = sequences_ 78 | 79 | 80 | def get_pair(self, sequence, frame1, frame2): 81 | def get_single(sequence, frame): 82 | mat_file = os.path.join(self.root, sequence['actor'], 'Landmarks', sequence['activity_sequence'], str(frame) + '.mat') 83 | mat = scipy.io.loadmat(mat_file) 84 | landmarks = mat['keypoints_2d'] * 128.0 85 | return { 86 | 'image': os.path.join(self.root, sequence['actor'], 
'WithBackground', sequence['activity_sequence'], str(frame) + '.jpg'), 87 | 'mask': os.path.join(self.root, sequence['actor'], 'BackgroudMask', sequence['activity_sequence'], str(frame) + '.png'), 88 | 'landmarks': landmarks 89 | } 90 | return get_single(sequence, frame1), get_single(sequence, frame2) 91 | 92 | 93 | def get_ordered_stream(self): 94 | if self.ordered_stream is None: 95 | self.ordered_stream = [] 96 | for sequence in self.sequences: 97 | step = self.sample_window[1] 98 | for i in range(0, len(sequence['frames']), step): 99 | frame = sequence['frames'][i] 100 | self.ordered_stream.append((sequence, frame)) 101 | return self.ordered_stream 102 | 103 | 104 | def get_item(self, index): 105 | ordered_stream = self.get_ordered_stream() 106 | sequence, frame = ordered_stream[index] 107 | return self.get_pair(sequence, frame, frame) 108 | 109 | 110 | def sample_item(self): 111 | sequence = random.choice(self.sequences) 112 | length = len(sequence['frames']) 113 | start = random.randint(0, length - self.sample_window[0] - 1) 114 | end = random.randint( 115 | start + self.sample_window[0], 116 | min(start + self.sample_window[1], length - 1)) 117 | return self.get_pair( 118 | sequence, sequence['frames'][start], 119 | sequence['frames'][end]) 120 | 121 | 122 | def num_samples(self): 123 | return len(self.get_ordered_stream()) 124 | 125 | 126 | 127 | class SimpleHuman36mDataset(BaseDataset): 128 | 129 | @staticmethod 130 | def modify_commandline_options(parser, is_train): 131 | parser.add_argument('--sample_window', type=int, default=[5, 30], nargs=2, help='') 132 | parser.add_argument('--no_mask', action='store_true', help='') 133 | parser.add_argument('--skeleton_subset_size', type=int, default=0, help='') 134 | parser.add_argument('--skeleton_subset_seed', type=int, default=None, help='') 135 | return parser 136 | 137 | def initialize(self, opt): 138 | self.opt = opt 139 | self.root = opt.dataroot 140 | self.load_images = True 141 | if hasattr(opt, 'load_images'): 142 | self.load_images = opt.load_images 143 | 144 | self.use_mask = not self.opt.no_mask 145 | 146 | activities = ['directions', 'discussion', 'greeting', 'posing', 147 | 'waiting', 'walking'] 148 | train_actors = ['S%d' % i for i in [1, 5, 6, 7, 8, 9]] 149 | val_actors = ['S%d' % i for i in [11]] 150 | test_actors = val_actors 151 | 152 | if opt.subset == 'train': 153 | actors = train_actors 154 | elif opt.subset == 'val': 155 | actors = val_actors 156 | elif opt.subset == 'test': 157 | actors = test_actors 158 | else: 159 | raise ValueError() 160 | 161 | if 'train' in opt.phase: 162 | order_stream = False 163 | split_sequence = 'first_half' 164 | sample_window = opt.sample_window 165 | elif opt.phase == 'val': 166 | order_stream = True 167 | split_sequence = 'full' 168 | order_stream = True 169 | sample_window = opt.sample_window 170 | elif opt.phase == 'test': 171 | order_stream = True 172 | split_sequence = 'full' 173 | sample_window = opt.sample_window 174 | else: 175 | raise ValueError() 176 | 177 | self.dataset = SimpleHuman36mDatasetSingle( 178 | self.root, sample_window=sample_window, 179 | activities=activities, actors=actors, 180 | split_sequence=split_sequence) 181 | 182 | if 'train' in opt.phase: 183 | self.skeleton_dataset = SimpleHuman36mDatasetSingle( 184 | self.root, sample_window=[0, 0], 185 | activities=activities, actors=actors, 186 | split_sequence='second_half', 187 | subsampled_size=opt.skeleton_subset_size, 188 | subsample_seed=opt.skeleton_subset_seed) 189 | else: 190 | self.skeleton_dataset =
self.dataset 191 | 192 | if opt.phase == 'train': 193 | self.len = int(10e7) 194 | else: 195 | self.len = self.dataset.num_samples() 196 | self.ordered_stream = order_stream 197 | 198 | self.A_transform = get_transform(opt) 199 | self.B_transform = get_transform(opt, channels=opt.output_nc) 200 | 201 | def _get_sample(self, dataset, index, load_image=True): 202 | if self.ordered_stream: 203 | source, target = dataset.get_item(index) 204 | else: 205 | source, target = dataset.sample_item() 206 | landmarks = utils.swap_xy_points(source['landmarks']) 207 | future_landmarks = utils.swap_xy_points(target['landmarks']) 208 | 209 | landmarks = landmarks.astype('float32') 210 | future_landmarks = future_landmarks.astype('float32') 211 | 212 | if load_image: 213 | future_image = proc_im(source['image'], source['mask'], apply_mask=self.use_mask) 214 | source_image = proc_im(target['image'], target['mask'], apply_mask=self.use_mask) 215 | else: 216 | future_image = None 217 | source_image = None 218 | 219 | 220 | return source_image, future_image, source['image'], target['image'], landmarks, future_landmarks, 221 | 222 | 223 | def __getitem__(self, index): 224 | # sample 225 | cond_A_img, A_img, cond_A_path, A_paths, paired_cond_B, paired_B = self._get_sample(self.dataset, index) 226 | 227 | # sample B 228 | _, _, _, _, _, B = self._get_sample(self.skeleton_dataset, index, load_image=False) 229 | 230 | # normalize keypoints 231 | paired_cond_B = utils.normalize_points( 232 | paired_cond_B, self.opt.fineSize, self.opt.fineSize) 233 | paired_B = utils.normalize_points( 234 | paired_B, self.opt.fineSize, self.opt.fineSize) 235 | B = utils.normalize_points( 236 | B, self.opt.fineSize, self.opt.fineSize) 237 | 238 | if self.load_images: 239 | A = self.A_transform(A_img) 240 | cond_A = self.A_transform(cond_A_img) 241 | 242 | data = {'B': B, 'paired_cond_B': paired_cond_B, 'paired_B': paired_B, 243 | 'A_paths': A_paths, 'cond_A_path': cond_A_path} 244 | if self.load_images: 245 | data.update({'A': A, 'cond_A': cond_A}) 246 | return data 247 | 248 | 249 | def __len__(self): 250 | return self.len 251 | 252 | def name(self): 253 | return 'SimpleHuman36mDataset' 254 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage 3 | import torchvision.transforms as transforms 4 | 5 | 6 | 7 | def find_common_box(boxes): 8 | """ 9 | Finds the union of boxes, represented as [xmin, ymin, xmax, ymax]. 10 | """ 11 | boxes = np.stack(boxes, axis=0) 12 | box = np.concatenate([np.min(boxes[:, :2], axis=0), 13 | np.max(boxes[:, 2:], axis=0)], axis=0) 14 | return box 15 | 16 | 17 | def fit_box(box, width, height): 18 | """ 19 | Adjusts box size to have the same aspect ratio as the target image 20 | while preserving the centre.
21 | """ 22 | box = box.astype('float32') 23 | im_w, im_h = float(width), float(height) 24 | w, h = box[2] - box[0], box[3] - box[1] 25 | 26 | # r_im - image aspect ratio, r - box aspect ratio 27 | r_im = im_w / im_h 28 | r = w / h 29 | 30 | centre = [box[0] + w / 2, box[1] + h / 2] 31 | 32 | if r < r_im: 33 | h, w = h, r_im * h 34 | else: 35 | h, w = (1 / r_im) * w, w 36 | 37 | box = [centre[0] - w / 2, centre[1] - h / 2, 38 | centre[0] + w / 2, centre[1] + h / 2] 39 | 40 | box = np.array(box, dtype='int32') 41 | return box 42 | 43 | 44 | def crop_to_box(image, bbox, pad=True): 45 | bbox = bbox.astype('int32') 46 | if pad: 47 | sz = image.shape[:2] 48 | pad_top = -min(0, bbox[1]) 49 | pad_left = -min(0, bbox[0]) 50 | pad_bottom = -min(0, sz[0] - bbox[3]) 51 | pad_right = -min(0, sz[1] - bbox[2]) 52 | image = np.pad( 53 | image, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 54 | 'constant') 55 | bbox[1], bbox[3] = bbox[1] + pad_top, bbox[3] + pad_top 56 | bbox[0], bbox[2] = bbox[0] + pad_left, bbox[2] + pad_left 57 | image = image[bbox[1]:bbox[3], bbox[0]:bbox[2]] 58 | return image 59 | 60 | 61 | def get_crop_size(box, pad=True): 62 | box = box.copy() 63 | if pad: 64 | pad_top = -min(0, box[1]) 65 | pad_left = -min(0, box[0]) 66 | box[1], box[3] = box[1] + pad_top, box[3] + pad_top 67 | box[0], box[2] = box[0] + pad_left, box[2] + pad_left 68 | return box[2] - box[0], box[3] - box[1] 69 | 70 | 71 | def resize_points(points, width, height, target_width, target_height): 72 | dtype = points.dtype 73 | ratio = np.array([target_width, target_height], dtype='float32') / \ 74 | np.array([width, height], dtype='float32') 75 | points = (points.astype('float32') * ratio[None]).astype(dtype) 76 | return points 77 | 78 | 79 | def box_from_points(points): 80 | min_xy = np.min(points, axis=0) 81 | max_xy = np.max(points, axis=0) 82 | return np.concatenate([min_xy, max_xy], axis=0) 83 | 84 | 85 | def swap_xy_box(box): 86 | box[:] = box[[1, 0, 3, 2]] 87 | return box 88 | 89 | 90 | def swap_xy_points(points): 91 | points[:, :] = points[:, [1, 0]] 92 | return points 93 | 94 | 95 | def normalize_points(points, width, height): 96 | return 2.0 * points / np.array([width, height], dtype='float32') - 1.0 97 | 98 | 99 | def render_gaussian_maps(mu, shape_hw, inv_std, mode='rot'): 100 | """ 101 | Generates [B,SHAPE_H,SHAPE_W,NMAPS] tensor of 2D gaussians, 102 | given the gaussian centers: MU [B, NMAPS, 2] tensor. 103 | 104 | STD: is the fixed standard dev. 
105 | """ 106 | mu_y, mu_x = mu[:, :, 0:1], mu[:, :, 1:2] 107 | 108 | y = np.linspace(-1.0, 1.0, shape_hw[0]).astype('float32') 109 | 110 | x = np.linspace(-1.0, 1.0, shape_hw[1]).astype('float32') 111 | 112 | mu_y, mu_x = mu_y[..., None], mu_x[..., None] 113 | 114 | y = np.reshape(y, [1, 1, shape_hw[0], 1]) 115 | x = np.reshape(x, [1, 1, 1, shape_hw[1]]) 116 | 117 | g_y = np.square(y - mu_y) 118 | g_x = np.square(x - mu_x) 119 | dist = (g_y + g_x) * inv_std**2 120 | 121 | if mode == 'rot': 122 | g_yx = np.exp(-dist) 123 | else: 124 | g_yx = np.exp(-np.power(dist + 1e-5, 0.25)) 125 | 126 | g_yx = np.transpose(g_yx, axes=[0, 2, 3, 1]) 127 | return g_yx 128 | 129 | 130 | def render_points(points, width, height): 131 | points = normalize_points(points, width, height) 132 | maps = render_gaussian_maps( 133 | swap_xy_points(points)[None], [height, width], 50) 134 | maps = maps[0] 135 | maps *= np.max(maps) 136 | return maps 137 | 138 | 139 | def render_line_segment(s1, s2, size, distance='gauss', discrete=False): 140 | def sumprod(x, y): 141 | return np.sum(x * y, axis=-1, keepdims=True) 142 | 143 | x = np.linspace(-1.0, 1.0, size).astype('float32') 144 | y = np.linspace(-1.0, 1.0, size).astype('float32') 145 | 146 | xv, yv = np.meshgrid(x, y) 147 | m = np.concatenate([xv[..., None], yv[..., None]], axis=-1) 148 | 149 | s1, s2 = s1[None, None], s2[None, None] 150 | t_min = sumprod(m - s1, s2 - s1) / \ 151 | np.maximum(sumprod(s2 - s1, s2 - s1), 1e-6) 152 | t_line = np.minimum(np.maximum(t_min, 0.0), 1.0) 153 | 154 | s = s1 + t_line * (s2 - s1) 155 | d = np.sqrt(sumprod(s - m, s - m)) 156 | 157 | if discrete: 158 | distance = 'norm' 159 | 160 | # normalize distance 161 | if distance == 'gauss': 162 | d_norm = np.exp(-d / (0.2 ** 2)) 163 | elif distance == 'norm': 164 | d_max = np.sqrt(8) 165 | d_norm = (d_max - d) / d_max 166 | else: 167 | raise ValueError() 168 | 169 | thick = 0.9925 170 | if discrete: 171 | d_norm[d_norm >= thick] = 1.0 172 | d_norm[d_norm < thick] = 0.0 173 | 174 | return d_norm 175 | 176 | 177 | def render_skeleton(points, connections, width, height, colored=False): 178 | assert width == height 179 | maps = [] 180 | numbers = np.linspace(0.2, 1.0, len(connections)) 181 | discrete = False 182 | if colored: 183 | discrete = True 184 | for (a, b), number in zip(connections, numbers): 185 | render = render_line_segment( 186 | points[a], points[b], width, discrete=discrete) 187 | if colored: 188 | render *= number 189 | maps.append(render) 190 | maps = np.concatenate(maps, axis=-1) 191 | return maps 192 | 193 | 194 | def proc_im(image, box, landmarks, target_width, target_height, keep_aspect=True, load_image=True): 195 | # read image 196 | if load_image: 197 | image = skimage.io.imread(image) 198 | if len(image.shape) == 2: 199 | image = np.tile(image[..., None], (1, 1, 3)) 200 | 201 | # crop to bounding box 202 | if keep_aspect: 203 | box = fit_box(box, target_width, target_height) 204 | if load_image: 205 | image = crop_to_box(image, box) 206 | else: 207 | width, height = get_crop_size(box) 208 | if landmarks is not None: 209 | landmarks = landmarks - box[:2][None].astype(landmarks.dtype) 210 | 211 | # resize 212 | if landmarks is not None: 213 | if load_image: 214 | height, width = image.shape[:2] 215 | landmarks = resize_points( 216 | landmarks, height, width, target_width, target_height) 217 | 218 | if load_image: 219 | image = skimage.transform.resize(image, [target_height, target_width]) 220 | image = skimage.img_as_float(image).astype(np.float32) 221 | 222 | height_ratio, 
width_ratio = float(target_height) / height, float(target_width) / width 223 | 224 | return image, landmarks, height_ratio, width_ratio 225 | 226 | 227 | def get_transform(opt, channels=3, normalize=True): 228 | mean = 0.5 229 | std = 0.5 230 | transform_list = [transforms.ToTensor()] 231 | if normalize: 232 | transform_list += [transforms.Normalize([mean] * channels, 233 | [std] * channels)] 234 | return transforms.Compose(transform_list) 235 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . 
import networks 5 | 6 | 7 | class BaseModel(): 8 | 9 | # modify parser to add command line options, 10 | # and also change the default values if needed 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def name(self): 16 | return 'BaseModel' 17 | 18 | def initialize(self, opt): 19 | self.opt = opt 20 | self.gpu_ids = opt.gpu_ids 21 | self.isTrain = opt.isTrain 22 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 23 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 24 | self.nets_paths = {} 25 | if opt.nets_paths: 26 | self.nets_paths = dict( 27 | zip(opt.nets_paths[::2], opt.nets_paths[1::2])) 28 | if opt.resize_or_crop != 'scale_width': 29 | torch.backends.cudnn.benchmark = True 30 | self.loss_names = [] 31 | self.load_model_names = [] 32 | self.save_model_names = [] 33 | self.visual_names = [] 34 | self.image_paths = [] 35 | 36 | def set_input(self, input): 37 | self.input = input 38 | 39 | def forward(self): 40 | pass 41 | 42 | # load and print networks; create schedulers 43 | def setup(self, opt, parser=None): 44 | if self.isTrain: 45 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 46 | 47 | if not self.isTrain or opt.continue_train or opt.resume_from_name is not None: 48 | resume_from_dir = None 49 | if opt.resume_from_name is not None: 50 | resume_from_dir = os.path.join(opt.checkpoints_dir, opt.resume_from_name) 51 | self.load_networks(opt.iteration, resume_from_dir=resume_from_dir) 52 | for name, path in self.nets_paths.items(): 53 | self.load_network(name, path) 54 | self.print_networks(opt.verbose) 55 | 56 | # make models eval mode during test time 57 | def eval(self): 58 | for name in self.load_model_names: 59 | if isinstance(name, str): 60 | net = getattr(self, 'net' + name) 61 | net.eval() 62 | 63 | # used in test time, wrapping `forward` in no_grad() so we don't save 64 | # intermediate steps for backprop 65 | def test(self): 66 | with torch.no_grad(): 67 | self.forward() 68 | 69 | # get image paths 70 | def get_image_paths(self): 71 | return self.image_paths 72 | 73 | def optimize_parameters(self): 74 | pass 75 | 76 | # FIXME: not called anywhere because we got rid of epochs 77 | # update learning rate (called once every epoch) 78 | def update_learning_rate(self): 79 | for scheduler in self.schedulers: 80 | scheduler.step() 81 | lr = self.optimizers[0].param_groups[0]['lr'] 82 | print('learning rate = %.7f' % lr) 83 | 84 | # return visualization images. train.py will display these images, and save the images to a html 85 | def get_current_visuals(self): 86 | self.compute_visuals() 87 | visual_ret = OrderedDict() 88 | for name in self.visual_names: 89 | if isinstance(name, str): 90 | visual_ret[name] = getattr(self, name) 91 | return visual_ret 92 | 93 | def compute_visuals(self): 94 | pass 95 | 96 | # return traning losses/errors. train.py will print out these errors as debugging information 97 | def get_current_losses(self): 98 | errors_ret = OrderedDict() 99 | for name in self.loss_names: 100 | if isinstance(name, str): 101 | # float(...) 
works for both scalar tensor and float number 102 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 103 | return errors_ret 104 | 105 | # save models to the disk 106 | def save_networks(self, iteration): 107 | for name in self.save_model_names: 108 | if isinstance(name, str): 109 | save_filename = '%s_net_%s.pth' % (iteration, name) 110 | save_path = os.path.join(self.save_dir, save_filename) 111 | net = getattr(self, 'net' + name) 112 | 113 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 114 | torch.save(net.module.cpu().state_dict(), save_path) 115 | net.cuda(self.gpu_ids[0]) 116 | else: 117 | torch.save(net.cpu().state_dict(), save_path) 118 | 119 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 120 | key = keys[i] 121 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 122 | if module.__class__.__name__.startswith('InstanceNorm') and \ 123 | (key == 'running_mean' or key == 'running_var'): 124 | if getattr(module, key) is None: 125 | state_dict.pop('.'.join(keys)) 126 | if module.__class__.__name__.startswith('InstanceNorm') and \ 127 | (key == 'num_batches_tracked'): 128 | state_dict.pop('.'.join(keys)) 129 | else: 130 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 131 | 132 | # load models from the disk 133 | def load_networks(self, iteration, resume_from_dir=None): 134 | for name in self.load_model_names: 135 | if isinstance(name, str): 136 | if name in self.nets_paths: 137 | load_path = self.nets_paths[name] 138 | else: 139 | if resume_from_dir is None: 140 | save_dir = self.save_dir 141 | else: 142 | save_dir = resume_from_dir 143 | load_filename = '%s_net_%s.pth' % (iteration, name) 144 | load_path = os.path.join(save_dir, load_filename) 145 | self.load_network(name, load_path) 146 | 147 | 148 | def load_network(self, name, load_path): 149 | if isinstance(name, str): 150 | net = getattr(self, 'net' + name) 151 | if isinstance(net, torch.nn.DataParallel): 152 | net = net.module 153 | print('loading the model from %s' % load_path) 154 | # if you are using PyTorch newer than 0.4 (e.g., built from 155 | # GitHub source), you can remove str() on self.device 156 | state_dict = torch.load(load_path, map_location=str(self.device)) 157 | if hasattr(state_dict, '_metadata'): 158 | del state_dict._metadata 159 | 160 | # patch InstanceNorm checkpoints prior to 0.4 161 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 162 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 163 | net.load_state_dict(state_dict) 164 | 165 | 166 | # print network information 167 | def print_networks(self, verbose): 168 | print('---------- Networks initialized -------------') 169 | for name in list(set(self.load_model_names + self.save_model_names)): 170 | if isinstance(name, str): 171 | net = getattr(self, 'net' + name) 172 | num_params = 0 173 | for param in net.parameters(): 174 | num_params += param.numel() 175 | if verbose: 176 | print(net) 177 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 178 | print('-----------------------------------------------') 179 | 180 | # set requies_grad=Fasle to avoid computation 181 | def set_requires_grad(self, nets, requires_grad=False): 182 | if not isinstance(nets, list): 183 | nets = [nets] 184 | for net in nets: 185 | if net is not None: 186 | for param in net.parameters(): 187 | param.requires_grad = requires_grad 188 | 
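The `set_requires_grad` helper at the end of `BaseModel` above exists to support the usual GAN alternation: freeze the discriminator while the generator is optimised, then unfreeze it for its own update. The snippet below is a minimal, self-contained sketch of that pattern with toy networks and random data; it is an illustration only and is not taken from this repository.

```
import torch
import torch.nn as nn


def set_requires_grad(nets, requires_grad=False):
    # mirrors BaseModel.set_requires_grad above
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


# toy stand-ins for a generator and a discriminator
G = nn.Linear(8, 8)
D = nn.Linear(8, 1)
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(4, 8)
fake = G(torch.randn(4, 8))

# generator step: D is frozen, so no gradients are accumulated for it
set_requires_grad(D, False)
opt_G.zero_grad()
loss_G = criterion(D(fake), torch.ones(4, 1))
loss_G.backward()
opt_G.step()

# discriminator step: D is unfrozen, fake is detached so G is untouched
set_requires_grad(D, True)
opt_D.zero_grad()
loss_D = 0.5 * (criterion(D(real), torch.ones(4, 1)) +
                criterion(D(fake.detach()), torch.zeros(4, 1)))
loss_D.backward()
opt_D.step()
```

In `keypoint_gan_model.py` below, the same idea shows up as separate Adam optimizers `optimizer_G` and `optimizer_D` and `backward_D_*` methods that score `fake.detach()` with `netD_A`.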
-------------------------------------------------------------------------------- /models/keypoint_gan_model.py: -------------------------------------------------------------------------------- 1 | import cPickle as cp 2 | import itertools 3 | import os 4 | import pickle 5 | from argparse import Namespace 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | from torch.nn import functional as F 10 | 11 | import data.human36m_skeleton 12 | from data import CreateDataLoader 13 | from data.human36m_skeleton import simple_link_indices as human36m_link_indices 14 | from util import plotting, util 15 | from util.image_pool import ImagePool 16 | from util.tps_sampler import TPSRandomSampler 17 | 18 | from . import networks, utils 19 | from .base_model import BaseModel 20 | from .perceptual_loss import PerceptualLoss 21 | 22 | 23 | class KeypointGANModel(BaseModel): 24 | def name(self): 25 | return 'KeypointGANModel' 26 | 27 | @staticmethod 28 | def modify_commandline_options(parser, is_train=True): 29 | # default CycleGAN did not use dropout 30 | parser.set_defaults(no_dropout=True) 31 | if is_train: 32 | parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') 33 | parser.add_argument('--lambda_gan_A', type=float, default=1.0, help='weight for gan loss') 34 | return parser 35 | 36 | def initialize(self, opt): 37 | BaseModel.initialize(self, opt) 38 | 39 | # opt.phase in ['train', 'train_regressor'] 40 | self.mode = opt.phase 41 | self.no_grad = opt.phase == 'test' 42 | 43 | # ----------------------- Losses to print------------------------------- 44 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 45 | if opt.phase == 'train': 46 | self.loss_names = [] 47 | 48 | if not opt.not_optimize_G: 49 | self.loss_names += ['cycle_A'] 50 | self.loss_names += ['G_A'] 51 | if opt.lambda_render_consistency > 0: 52 | self.loss_names += ['render_consistency'] 53 | 54 | if not opt.not_optimize_D: 55 | self.loss_names += ['D_A'] 56 | 57 | if opt.finetune_regressor: 58 | self.loss_names += ['regressor'] 59 | 60 | elif opt.phase == 'train_regressor': 61 | self.loss_names = ['regressor'] 62 | 63 | # ----------------------- Visualizations ------------------------------- 64 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 65 | if opt.phase in ['train', 'test']: 66 | if opt.eval_pose_prediction_only: 67 | visual_names_A = ['real_A', 68 | 'fake_B', 'fake_B_regress'] 69 | else: 70 | visual_names_A = ['real_A', 71 | 'fake_B', 'fake_B_regress', 'rec_A'] 72 | if opt.phase != 'test': 73 | visual_names_A.append('real_B') 74 | 75 | visual_names_A.insert(1, 'real_cond_A') 76 | 77 | if self.opt.offline_regressor: 78 | visual_names_A += ['offline_regress'] 79 | 80 | visual_names_B = [] 81 | 82 | self.visual_names = visual_names_A + visual_names_B 83 | if opt.phase in ['train_regressor']: 84 | self.visual_names = ['real_B', 'fake_B_regress'] 85 | 86 | # ----------------------- Networks save/load---------------------------- 87 | # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 88 | model_names = ['G_B'] 89 | if self.isTrain: 90 | model_names += ['G_A'] 91 | model_names += ['D_A'] 92 | else: # during test time, only load Gs 93 | model_names += ['G_A'] 94 | 95 | if opt.phase == 'train_regressor': 96 | self.load_model_names = ['regressor'] 97 | self.save_model_names = ['regressor'] 98 | else: 99 | self.load_model_names = model_names 100 | self.load_model_names += ['regressor'] 101 | if self.opt.offline_regressor: 102 | self.load_model_names += ['offline_regressor'] 103 | self.save_model_names = model_names 104 | 105 | if opt.finetune_regressor: 106 | self.save_model_names += ['regressor'] 107 | 108 | # ----------------------- Define networks ------------------------------ 109 | # load/define networks 110 | self.netregressor = networks.define_regressor( 111 | 1, self.opt.n_points, norm=opt.regressor_norm, init_type=opt.init_type, 112 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, 113 | net_type=opt.net_regressor, n_channels=opt.net_regressor_channels) 114 | 115 | if self.opt.offline_regressor: 116 | self.netoffline_regressor = networks.define_regressor( 117 | 1, self.opt.n_points, norm=opt.regressor_norm, init_type=opt.init_type, 118 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, 119 | net_type=opt.net_regressor, n_channels=opt.net_regressor_channels) 120 | 121 | self.netG_A = networks.define_G( 122 | opt.input_nc, opt.output_nc, opt.netG_A, norm=opt.generators_norm, 123 | init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, 124 | n_blocks=opt.netG_A_blocks) 125 | 126 | netG_input_nc = opt.output_nc 127 | self.netG_B = networks.define_G_cond( 128 | netG_input_nc, opt.input_nc, opt.input_nc, opt.netG_B, norm=opt.generators_norm, 129 | init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, 130 | avg_pool_cond=opt.avg_pool_style) 131 | 132 | if self.isTrain: 133 | use_sigmoid = opt.no_lsgan 134 | self.netD_A = networks.define_D( 135 | opt.output_nc, opt.ndf, opt.netDA, multi_gan=opt.multi_ganA, 136 | n_layers_D=opt.n_layers_D, norm=opt.discriminators_norm, 137 | use_sigmoid=use_sigmoid, init_type=opt.init_type, init_gain=opt.init_gain, 138 | gpu_ids=self.gpu_ids) 139 | 140 | if self.isTrain: 141 | # ------------------------- Criterions ----------------------------- 142 | self.fake_A_pool = ImagePool(opt.pool_size) 143 | self.fake_B_pool = ImagePool(opt.pool_size) 144 | # define loss functions 145 | gan_loss = networks.MultiGANLoss if opt.multi_ganA else networks.GANLoss 146 | self.criterionGAN = gan_loss(use_lsgan=not opt.no_lsgan).to(self.device) 147 | self.single_scale_criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) 148 | if self.opt.cycle_loss == 'l1': 149 | self.criterionCycle = torch.nn.L1Loss() 150 | elif self.opt.cycle_loss == 'perceptual': 151 | self.criterionCycle = PerceptualLoss(self.opt.perceptual_net) 152 | # '/scratch/local/hdd/ankush/minmaxinfo/data/models/imagenet-vgg-verydeep-19.mat' 153 | else: 154 | raise ValueError('Unknown cycle loss: %s' % self.opt.cycle_loss) 155 | self.criterionIdt = torch.nn.L1Loss() 156 | self.criterion_regressor = torch.nn.MSELoss() 157 | 158 | # -------------------------- Optimizers ---------------------------- 159 | # initialize optimizers 160 | G_params = [self.netG_B.parameters()] 161 | G_params += [self.netG_A.parameters()] 162 | self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), 163 | lr=opt.lr, betas=(opt.beta1, 0.999)) 164 | D_params = [] 165 | D_params += 
[self.netD_A.parameters()] 166 | self.optimizer_D = torch.optim.Adam(itertools.chain(*D_params), lr=opt.lr, betas=(opt.beta1, 0.999)) 167 | 168 | self.optimizer_regressor = torch.optim.Adam( 169 | self.netregressor.parameters(), lr=opt.lr, 170 | betas=(opt.beta1, 0.999)) 171 | 172 | self.optimizers = [] 173 | self.optimizers.append(self.optimizer_G) 174 | self.optimizers.append(self.optimizer_D) 175 | self.optimizers.append(self.optimizer_regressor) 176 | 177 | # ------------------------------ TPS ----------------------------------- 178 | if self.opt.tps: 179 | self.tps_sampler = TPSRandomSampler( 180 | opt.fineSize, opt.fineSize, rotsd=5.0, scalesd=0.05, transsd=0.05, 181 | warpsd=(0.0005, 0.005)) 182 | self.tps_sampler_target = TPSRandomSampler( 183 | opt.fineSize, opt.fineSize, rotsd=5.0, scalesd=0.05, transsd=0.05, 184 | warpsd=(0.0, 0.0)) 185 | 186 | 187 | def set_input(self, input): 188 | self.input = input 189 | 190 | if 'A' in input: 191 | self.real_A = input['A'].to(self.device) 192 | if 'cond_A' in input: 193 | self.real_cond_A = input['cond_A'].to(self.device) 194 | self.real_B_points = input['B'].to(self.device) 195 | 196 | if 'paired_cond_B' in input: 197 | self.paired_cond_B_points = input['paired_cond_B'].to(self.device) 198 | self.paired_cond_B = self.render_skeleton( 199 | self.paired_cond_B_points, skeleton_type=self.opt.paired_skeleton_type) 200 | 201 | if 'paired_B' in input: 202 | self.paired_B = input['paired_B'].to(self.device) 203 | else: 204 | self.paired_B = self.real_B_points 205 | 206 | if 'B_visible' in input: 207 | self.B_visible = input['B_visible'].to(self.device) 208 | if 'paired_B_visible' in input: 209 | self.paired_B_visible = input['paired_B_visible'].to(self.device) 210 | 211 | if 'A_paths' in input: 212 | self.image_paths = input['A_paths'] 213 | # warp input 214 | if self.opt.tps: 215 | self.real_cond_A = self.tps_sampler(self.real_cond_A) 216 | if self.opt.tps_target: 217 | self.real_A = self.tps_sampler_target(self.real_A) 218 | 219 | # shuffle real_cond_A 220 | if self.opt.shuffle_identities: 221 | torch.manual_seed(0) 222 | self.real_cond_A = self.real_cond_A[torch.randperm(self.real_cond_A.shape[0])] 223 | 224 | self.paired_B_points = self.paired_B 225 | 226 | skeleton_type = self.opt.skeleton_type 227 | if self.opt.prior_skeleton_type is not None: 228 | skeleton_type = self.opt.prior_skeleton_type 229 | self.real_B = self.render_skeleton( 230 | self.real_B_points, skeleton_type=skeleton_type, reduce=self.opt.reduce_rendering_mode) 231 | 232 | def forward(self): 233 | if self.mode == 'train_regressor': 234 | maps = self.netregressor(self.real_B) 235 | self.regressed_points = utils.extract_points(maps) 236 | skeleton_type = self.opt.skeleton_type 237 | if self.opt.prior_skeleton_type is not None: 238 | skeleton_type = self.opt.prior_skeleton_type 239 | self.fake_B_regress = self.render_skeleton( 240 | self.regressed_points, skeleton_type=skeleton_type, 241 | reduce='max') 242 | 243 | elif not self.opt.eval_pose_prediction_only: 244 | if self.opt.finetune_regressor: 245 | maps = self.netregressor(self.real_B) 246 | self.real_B_regressed_points = utils.extract_points(maps) 247 | self.real_B_regress = self.render_skeleton( 248 | self.real_B_regressed_points, 249 | skeleton_type=self.opt.skeleton_type) 250 | 251 | self.fake_B = self.netG_A(self.real_A) 252 | maps = self.netregressor(self.fake_B) 253 | 254 | self.regressed_points = utils.extract_points(maps) 255 | 256 | skeleton_type = self.opt.skeleton_type 257 | if self.opt.prior_skeleton_type is 
not None: 258 | skeleton_type = self.opt.prior_skeleton_type 259 | fake_B_regress_multi_ch = self.render_skeleton( 260 | self.regressed_points, reduce=None, 261 | skeleton_type=skeleton_type) 262 | self.fake_B_regress = self.reduce_renderings( 263 | fake_B_regress_multi_ch, reduce=self.opt.reduce_rendering_mode, keepdim=True) 264 | 265 | if self.opt.offline_regressor: 266 | # offline_regressor_input 267 | maps = self.netoffline_regressor(self.fake_B_regress) 268 | self.offline_regressed_points = utils.extract_points(maps) 269 | self.offline_regress = self.render_skeleton( 270 | self.offline_regressed_points, colored=True, 271 | skeleton_type=self.opt.skeleton_type) 272 | 273 | netG_B_input = self.fake_B_regress 274 | 275 | self.rec_A = self.netG_B(netG_B_input, self.real_cond_A) 276 | 277 | 278 | def backward_D_basic(self, netD, real, fake): 279 | # Real 280 | pred_real = netD(real) 281 | loss_D_real = self.criterionGAN(pred_real, True) 282 | # Fake 283 | pred_fake = netD(fake.detach()) 284 | loss_D_fake = self.criterionGAN(pred_fake, False) 285 | # Combined loss 286 | loss_D = (loss_D_real + loss_D_fake) * 0.5 287 | # backward 288 | loss_D.backward() 289 | return loss_D 290 | 291 | def backward_D_single_scale(self, netD, real, fake): 292 | # Real 293 | pred_real = netD(real) 294 | loss_D_real = self.single_scale_criterionGAN(pred_real, True) 295 | # Fake 296 | pred_fake = netD(fake.detach()) 297 | loss_D_fake = self.single_scale_criterionGAN(pred_fake, False) 298 | # Combined loss 299 | loss_D = (loss_D_real + loss_D_fake) * 0.5 300 | # backward 301 | loss_D.backward() 302 | return loss_D 303 | 304 | 305 | def backward_D_A(self): 306 | fake_B = self.fake_B 307 | real_B = self.real_B 308 | fake_B = self.fake_B_pool.query(fake_B) 309 | self.loss_D_A = self.backward_D_basic(self.netD_A, real_B, fake_B) 310 | 311 | 312 | def backward_G(self, retain_graph=False): 313 | lambda_A = self.opt.lambda_A 314 | lambda_gan_A = self.opt.lambda_gan_A 315 | 316 | # GAN loss D_A(G_A(A)) 317 | fake_B = self.fake_B 318 | self.loss_G_A = self.criterionGAN(self.netD_A(fake_B), True) * lambda_gan_A 319 | 320 | # Forward cycle loss 321 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 322 | 323 | self.loss_render_consistency = 0 324 | if self.opt.lambda_render_consistency > 0: 325 | self.loss_render_consistency = self.criterion_regressor( 326 | self.fake_B, self.fake_B_regress.detach()) * self.opt.lambda_render_consistency 327 | 328 | # combined loss 329 | self.loss_G = self.loss_G_A + self.loss_cycle_A + self.loss_render_consistency 330 | self.loss_G.backward(retain_graph=retain_graph) 331 | 332 | 333 | def backward_regressor(self): 334 | regressed_points = self.regressed_points 335 | real_B_points = self.real_B_points 336 | if self.opt.only_visible_points_loss: 337 | regressed_points = regressed_points * self.B_visible[:, :, None].type(regressed_points.dtype) 338 | real_B_points = real_B_points * self.B_visible[:, :, None].type(real_B_points.dtype) 339 | self.loss_regressor = self.criterion_regressor(regressed_points, real_B_points) 340 | if self.opt.regressor_im_loss > 0: 341 | self.loss_regressor += self.opt.regressor_im_loss * self.criterion_regressor( 342 | self.fake_B_regress, self.real_B) 343 | self.loss_regressor.backward() 344 | 345 | 346 | def backward_regressor_finetune(self): 347 | self.loss_regressor = self.opt.regressor_fake_loss * self.criterion_regressor( 348 | self.fake_B_regress, self.fake_B.detach()) 349 | real_B_regressed_points = self.real_B_regressed_points 350 
| real_B_points = self.real_B_points 351 | if self.opt.only_visible_points_loss: 352 | real_B_regressed_points = real_B_regressed_points * self.B_visible[:, :, None].type(real_B_regressed_points.dtype) 353 | real_B_points = real_B_points * self.B_visible[:, :, None].type(real_B_points.dtype) 354 | self.loss_regressor += self.opt.regressor_real_loss * self.criterion_regressor( 355 | real_B_regressed_points, real_B_points) 356 | self.loss_regressor.backward() 357 | 358 | 359 | def optimize_parameters(self): 360 | # forward 361 | all_nets_but_regressor = [self.netG_B] 362 | all_nets_but_regressor += [self.netG_A] 363 | all_nets_but_regressor += [self.netD_A] 364 | 365 | self.forward() 366 | 367 | if self.mode == 'train_regressor': 368 | self.set_requires_grad(all_nets_but_regressor, False) 369 | self.optimizer_regressor.zero_grad() 370 | self.backward_regressor() 371 | if self.opt.clip_grad < float('inf'): 372 | self.clip_gradient(self.optimizer_regressor, self.opt.clip_grad) 373 | self.optimizer_regressor.step() 374 | 375 | else: 376 | D = [] 377 | D += [self.netD_A] 378 | 379 | self.set_requires_grad(self.netregressor, False) 380 | retain_graph = self.opt.finetune_regressor 381 | 382 | # G_A and G_B 383 | if not self.opt.not_optimize_G: 384 | self.set_requires_grad(D, False) 385 | self.optimizer_G.zero_grad() 386 | self.backward_G(retain_graph=retain_graph) 387 | if self.opt.clip_grad < float('inf'): 388 | self.clip_gradient(self.optimizer_G, self.opt.clip_grad) 389 | self.optimizer_G.step() 390 | 391 | # D_A 392 | if not self.opt.not_optimize_D and len(D) > 0: 393 | self.set_requires_grad(D, True) 394 | self.optimizer_D.zero_grad() 395 | self.backward_D_A() 396 | if self.opt.clip_grad < float('inf'): 397 | self.clip_gradient(self.optimizer_D, self.opt.clip_grad) 398 | self.optimizer_D.step() 399 | 400 | # finetune regressor 401 | if self.opt.finetune_regressor: 402 | self.set_requires_grad(all_nets_but_regressor, False) 403 | self.set_requires_grad(self.netregressor, True) 404 | self.optimizer_regressor.zero_grad() 405 | self.backward_regressor_finetune() 406 | if self.opt.clip_grad < float('inf'): 407 | self.clip_gradient(self.optimizer_regressor, self.opt.clip_grad) 408 | self.optimizer_regressor.step() 409 | self.set_requires_grad(all_nets_but_regressor, True) 410 | 411 | 412 | 413 | def reduce_renderings(self, render, reduce='max', keepdim=True): 414 | if reduce == 'softmax': 415 | weights = F.softmax(render, dim=1) 416 | render = torch.sum(render * weights, dim=1, keepdim=keepdim) 417 | elif reduce == 'mean': 418 | render = torch.mean(render, dim=1, keepdim=keepdim) 419 | elif reduce == 'sum': 420 | render = torch.sum(render, dim=1, keepdim=keepdim) 421 | elif reduce == 'max': 422 | render, _ = torch.max(render, dim=1, keepdim=keepdim) 423 | elif reduce is None: 424 | pass 425 | else: 426 | raise ValueError() 427 | return render 428 | 429 | 430 | def get_link_indices(self, skeleton_type): 431 | if skeleton_type == 'human36m': 432 | link_indices = human36m_link_indices 433 | elif skeleton_type == 'human36m_simple2': 434 | link_indices = data.human36m_skeleton.simple2_link_indices 435 | elif skeleton_type == 'disconnected': 436 | link_indices = self.get_disconnected_links(self.opt.n_points) 437 | else: 438 | raise ValueError() 439 | return link_indices 440 | 441 | 442 | def render_skeleton(self, points, reduce='max', colored=None, 443 | skeleton_type='human36m', colors=None, centre=True, 444 | normalize=False, widths=None, size=None): 445 | if size is None: 446 | size = (self.opt.fineSize, 
self.opt.fineSize) 447 | 448 | if skeleton_type != 'points': 449 | link_indices = self.get_link_indices(skeleton_type) 450 | render_fn = utils.render_skeleton 451 | render = render_fn( 452 | points, link_indices, 453 | size[0], size[1], colored=colored, 454 | colors=colors, 455 | normalize=normalize, widths=widths, 456 | sigma=self.opt.sigma) 457 | elif skeleton_type == 'points': 458 | render = utils.render_points( 459 | points, size[0], size[1]) 460 | render = render[:, :, None] 461 | else: 462 | raise ValueError() 463 | 464 | render = self.reduce_renderings(render, reduce=reduce) 465 | 466 | if colors is None: 467 | render = torch.mean(render, dim=2) 468 | 469 | if centre: 470 | render = utils.normalize_im(render) 471 | 472 | return render 473 | 474 | 475 | def get_limb_points(self, points, skeleton_type): 476 | connections = self.get_link_indices(skeleton_type) 477 | return utils.get_line_points(points, connections) 478 | 479 | 480 | def clip_gradient(self, optimizer, clip): 481 | for param_group in optimizer.param_groups: 482 | torch.nn.utils.clip_grad_norm_(param_group['params'], clip) 483 | 484 | 485 | def compute_visuals(self): 486 | if 'train' in self.opt.phase: 487 | return 488 | 489 | self.visual_names.append('real_A_paired_B') 490 | self.visual_names.append('real_A_fake_B') 491 | 492 | if self.opt.plot_skeleton_type is not None: 493 | skeleton_type = self.opt.plot_skeleton_type 494 | else: 495 | skeleton_type = self.opt.skeleton_type 496 | 497 | if skeleton_type != 'points': 498 | link_indices = self.get_link_indices(skeleton_type) 499 | 500 | points = self.regressed_points 501 | real_points = self.paired_B_points 502 | 503 | if self.opt.offline_regressor: 504 | corrected_points = self.correct_flips(points, self.offline_regressed_points) 505 | 506 | if skeleton_type != 'points': 507 | points, new_link_indices = utils.parse_auxiliary_links(points, link_indices) 508 | real_points, _ = utils.parse_auxiliary_links(real_points, link_indices) 509 | 510 | if self.opt.offline_regressor: 511 | self.visual_names.append('real_A_fake_B_offline') 512 | offline_points = self.offline_regressed_points 513 | offline_points, _ = utils.parse_auxiliary_links(offline_points, link_indices) 514 | 515 | self.visual_names.append('real_A_fake_B_offline_adj') 516 | corrected_points, _ = utils.parse_auxiliary_links(corrected_points, link_indices) 517 | 518 | if skeleton_type != 'points': 519 | link_indices = new_link_indices 520 | 521 | if skeleton_type != 'points': 522 | self.real_A_paired_B = plotting.plot_in_image( 523 | util.tensor2im(self.real_A), real_points[0].cpu().numpy(), 524 | color='navy', style='skeleton', connections=link_indices) 525 | self.real_A_fake_B = plotting.plot_in_image( 526 | util.tensor2im( 527 | self.real_A), points[0].cpu().numpy(), 528 | color='limegreen', style='skeleton', connections=link_indices) 529 | if self.opt.offline_regressor: 530 | self.real_A_fake_B_offline = plotting.plot_in_image( 531 | util.tensor2im(self.real_A), offline_points[0].cpu().numpy(), 532 | color='navy', style='skeleton', connections=link_indices) 533 | self.real_A_fake_B_offline_adj = plotting.plot_in_image( 534 | util.tensor2im(self.real_A), corrected_points[0].cpu().numpy(), 535 | color='navy', style='skeleton', connections=link_indices) 536 | elif skeleton_type == 'points': 537 | self.real_A_paired_B = plotting.plot_in_image( 538 | util.tensor2im(self.real_A), real_points[0].cpu().numpy(), 539 | color='cyan', landmark_size=self.opt.plot_landmark_size) 540 | self.real_A_fake_B = 
plotting.plot_in_image( 541 | util.tensor2im( 542 | self.real_A), points[0].cpu().numpy(), 543 | color='limegreen', landmark_size=self.opt.plot_landmark_size) 544 | else: 545 | raise ValueError() 546 | 547 | 548 | def get_disconnected_links(self, n_points): 549 | return [(i, i + 1) for i in range(0, n_points, 2)] 550 | 551 | 552 | def normalize_points(self, landmarks): 553 | # put in the corner 554 | minv, _ = torch.min(landmarks, dim=1, keepdim=True) 555 | landmarks = landmarks - minv 556 | # normalize between -1, 1 557 | height_width, _ = torch.max(landmarks, dim=1, keepdim=True) 558 | size, _ = torch.max(height_width, dim=2, keepdim=True) 559 | landmarks = 2.0 * (landmarks / size) - 1.0 560 | # centre 561 | maxv, _ = torch.max(landmarks, dim=1, keepdim=True) 562 | landmarks = landmarks + (1.0 - maxv) / 2.0 563 | return landmarks 564 | 565 | 566 | def correct_flips(self, input, offline_prediction): 567 | if self.opt.skeleton_type in ['human36m', 'human36m_simple2']: 568 | correspondences = data.human36m_skeleton.get_lr_correspondences() 569 | else: 570 | raise ValueError() 571 | 572 | input_swapped = utils.swap_points(input, correspondences) 573 | 574 | distance = utils.mean_l2_distance(offline_prediction, input) 575 | swapped_distance = utils.mean_l2_distance(offline_prediction, input_swapped) 576 | min_idx = distance > swapped_distance 577 | corrected_input = torch.zeros_like(input) 578 | for i in range(len(min_idx)): 579 | if min_idx[i]: 580 | corrected_input[i] = input_swapped[i] 581 | else: 582 | corrected_input[i] = input[i] 583 | return corrected_input 584 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | ############################################################################### 11 | # Helper Functions 12 | ############################################################################### 13 | class Interpolate(nn.Module): 14 | def __init__(self, size=None, scale=None, mode='nearest'): 15 | super(Interpolate, self).__init__() 16 | self.interp = nn.functional.interpolate 17 | self.size = size 18 | self.scale = scale 19 | self.mode = mode 20 | self.align_corners = None if mode == 'nearest' else False 21 | 22 | def forward(self, x): 23 | x = self.interp(x, size=self.size, scale_factor=self.scale, 24 | mode=self.mode, align_corners=self.align_corners) 25 | return x 26 | 27 | 28 | def get_norm_layer(norm_type='instance', dims=2): 29 | if norm_type == 'batch': 30 | if dims == 1: 31 | layer = nn.BatchNorm1d 32 | elif dims == 2: 33 | layer = nn.BatchNorm2d 34 | else: 35 | raise NotImplementedError('unsupported dim: %d' % dims) 36 | norm_layer = functools.partial(layer, affine=True) 37 | elif norm_type == 'instance': 38 | if dims == 1: 39 | layer = nn.InstanceNorm1d 40 | elif dims == 2: 41 | layer = nn.InstanceNorm2d 42 | else: 43 | raise NotImplementedError('unsupported dim: %d' % dims) 44 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 45 | elif norm_type == 'none': 46 | norm_layer = None 47 | else: 48 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 49 | return norm_layer 50 | 51 | # TODO: needs to be adapted to iterations 52 | def get_scheduler(optimizer, opt): 53 | assert opt.lr_policy == 
'none' 54 | if opt.lr_policy == 'lambda': 55 | def lambda_rule(epoch): 56 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 57 | return lr_l 58 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 59 | elif opt.lr_policy == 'step': 60 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 61 | elif opt.lr_policy == 'plateau': 62 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 63 | elif opt.lr_policy == 'cosine': 64 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 65 | elif opt.lr_policy == 'none': 66 | scheduler = lr_scheduler.StepLR( 67 | optimizer, step_size=1000000000000, gamma=1) 68 | else: 69 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 70 | return scheduler 71 | 72 | 73 | def init_weights(net, init_type='normal', gain=0.02): 74 | def init_func(m): 75 | classname = m.__class__.__name__ 76 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 77 | if init_type == 'normal': 78 | init.normal_(m.weight.data, 0.0, gain) 79 | elif init_type == 'xavier': 80 | init.xavier_normal_(m.weight.data, gain=gain) 81 | elif init_type == 'kaiming': 82 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 83 | elif init_type == 'orthogonal': 84 | init.orthogonal_(m.weight.data, gain=gain) 85 | else: 86 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 87 | if hasattr(m, 'bias') and m.bias is not None: 88 | init.constant_(m.bias.data, 0.0) 89 | elif classname.find('BatchNorm2d') != -1: 90 | init.normal_(m.weight.data, 1.0, gain) 91 | init.constant_(m.bias.data, 0.0) 92 | 93 | print('initialize network with %s' % init_type) 94 | net.apply(init_func) 95 | 96 | 97 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 98 | if len(gpu_ids) > 0: 99 | assert(torch.cuda.is_available()) 100 | net.to(gpu_ids[0]) 101 | net = torch.nn.DataParallel(net, gpu_ids) 102 | init_weights(net, init_type, gain=init_gain) 103 | return net 104 | 105 | 106 | def define_G(input_nc, output_nc, netG, norm='batch', 107 | init_type='normal', init_gain=0.02, gpu_ids=[], n_blocks=4): 108 | net = None 109 | norm_layer = get_norm_layer(norm_type=norm) 110 | 111 | if netG == 'skip_nips': 112 | net = SkipNipsGenerator(input_nc, output_nc, n_blocks=n_blocks) 113 | else: 114 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 115 | return init_net(net, init_type, init_gain, gpu_ids) 116 | 117 | 118 | def define_G_cond(input_nc, cond_nc, output_nc, netG, norm='batch', 119 | init_type='normal', init_gain=0.02, gpu_ids=[], avg_pool_cond=False): 120 | net = None 121 | norm_layer = get_norm_layer(norm_type=norm) 122 | 123 | if netG == 'nips': 124 | net = CondNipsGenerator(input_nc, cond_nc, output_nc, avg_pool_cond=avg_pool_cond) 125 | else: 126 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 127 | return init_net(net, init_type, init_gain, gpu_ids) 128 | 129 | 130 | def define_D(input_nc, ndf, netD, multi_gan=False, 131 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 132 | net = None 133 | norm_layer = get_norm_layer(norm_type=norm) 134 | 135 | def net_factory(net_class, *args, **kwargs): 136 | return lambda: net_class(*args, **kwargs) 137 | 138 | if netD == 'basic': 139 | net = net_factory( 140 | 
NLayerDiscriminator, input_nc, ndf, n_layers=3, 141 | norm_layer=norm_layer, use_sigmoid=use_sigmoid) 142 | else: 143 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) 144 | if multi_gan: 145 | net = MultiDiscriminator(net) 146 | else: 147 | net = net() 148 | return init_net(net, init_type, init_gain, gpu_ids) 149 | 150 | 151 | def define_regressor(input_nc, output_nc, norm='batch', init_type='normal', 152 | init_gain=0.02, gpu_ids=[], net_type='nips_encoder', 153 | n_channels=32): 154 | norm_layer = get_norm_layer(norm_type=norm) 155 | if net_type == 'nips_encoder': 156 | net = NipsEncoder(input_nc, output_nc, norm_layer=norm_layer, 157 | channels=n_channels) 158 | else: 159 | raise NotImplementedError('Regressor model name [%s] is not recognized' % net) 160 | return init_net(net, init_type, init_gain, gpu_ids) 161 | 162 | 163 | class GANLoss(nn.Module): 164 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 165 | super(GANLoss, self).__init__() 166 | self.register_buffer('real_label', torch.tensor(target_real_label)) 167 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 168 | if use_lsgan: 169 | self.loss = nn.MSELoss() 170 | else: 171 | self.loss = nn.BCELoss() 172 | 173 | def get_target_tensor(self, input, target_is_real): 174 | if target_is_real: 175 | target_tensor = self.real_label 176 | else: 177 | target_tensor = self.fake_label 178 | return target_tensor.expand_as(input) 179 | 180 | def __call__(self, input, target_is_real): 181 | target_tensor = self.get_target_tensor(input, target_is_real) 182 | return self.loss(input, target_tensor) 183 | 184 | 185 | class MultiGANLoss(GANLoss): 186 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 187 | super(MultiGANLoss, self).__init__( 188 | use_lsgan=use_lsgan, target_real_label=target_real_label, 189 | target_fake_label=target_fake_label) 190 | 191 | def __call__(self, inputs, target_is_real): 192 | losses = [] 193 | for input in inputs: 194 | loss = (super(MultiGANLoss, self).__call__(input, target_is_real)) 195 | losses.append(loss) 196 | return torch.mean(torch.stack(losses)) 197 | 198 | 199 | class NormalizedLoss(nn.Module): 200 | def __init__(self, base_loss, mu=0.999, init_val=None): 201 | super(NormalizedLoss, self).__init__() 202 | self.add_module('base_loss', base_loss) 203 | # self.base_loss = base_loss 204 | self.mu = mu 205 | self.init_val = init_val 206 | self.register_buffer('running_mean', torch.tensor(init_val or -1.0)) 207 | 208 | def __call__(self, *args, **kw_args): 209 | curr_loss = self.base_loss(*args, **kw_args) 210 | # update the moving average: 211 | if self.running_mean.device != curr_loss.device: 212 | self.running_mean = self.running_mean.to(curr_loss.device) 213 | if self.running_mean == -1.0: 214 | self.running_mean = curr_loss.detach() 215 | else: 216 | self.running_mean = self.mu * self.running_mean + (1.0 - self.mu) * curr_loss.detach() 217 | loss_val = curr_loss / (self.running_mean + 1e-8) 218 | return loss_val 219 | 220 | 221 | 222 | class SkipNipsGenerator(nn.Module): 223 | def __init__(self, input_nc, output_nc, n_blocks=4, norm_layer=nn.BatchNorm2d): 224 | super(SkipNipsGenerator, self).__init__() 225 | self.input_nc = input_nc 226 | self.output_nc = output_nc 227 | 228 | decoder_inputs = [256, 128, 64, 32][4-n_blocks:] 229 | bottleneck_nc = decoder_inputs[0] 230 | 231 | self.encoder = SkipNipsEncoder(input_nc, bottleneck_nc, n_blocks=n_blocks, norm_layer=norm_layer) 232 | 
self.decoder = SkipNipsDecoder(decoder_inputs, output_nc, n_blocks=n_blocks, norm_layer=norm_layer) 233 | 234 | def forward(self, input): 235 | outputs = self.encoder(input) 236 | return self.decoder(outputs[::-1]) 237 | 238 | 239 | 240 | class CondNipsGenerator(nn.Module): 241 | def __init__(self, input_nc, cond_nc, output_nc, norm_layer=nn.BatchNorm2d, 242 | avg_pool_cond=False): 243 | super(CondNipsGenerator, self).__init__() 244 | self.input_nc = input_nc 245 | self.output_nc = output_nc 246 | self.avg_pool_cond = avg_pool_cond 247 | 248 | bottleneck_nc = 256 249 | 250 | self.encoder = NipsEncoder(input_nc, bottleneck_nc, norm_layer=norm_layer) 251 | self.encoder_cond = NipsEncoder(cond_nc, bottleneck_nc, norm_layer=norm_layer) 252 | self.decoder = NipsDecoder(2 * bottleneck_nc, output_nc, norm_layer=norm_layer) 253 | 254 | def forward(self, input, cond): 255 | input = self.encoder(input) 256 | cond = self.encoder_cond(cond) 257 | if self.avg_pool_cond: 258 | spatial_size = list(cond.shape[2:]) 259 | cond = torch.mean(cond, dim=(2, 3), keepdim=True) 260 | cond = cond.repeat([1, 1] + spatial_size) 261 | decoder_input = torch.cat([input, cond], dim=1) 262 | output = self.decoder(decoder_input) 263 | return output 264 | 265 | 266 | 267 | class NipsEncoder(nn.Module): 268 | def __init__(self, input_nc, output_nc, channels=32, 269 | norm_layer=nn.BatchNorm2d): 270 | super(NipsEncoder, self).__init__() 271 | 272 | self.input_nc = input_nc 273 | self.output_nc = output_nc 274 | 275 | if type(norm_layer) == functools.partial: 276 | use_bias = norm_layer.func == nn.InstanceNorm2d 277 | else: 278 | use_bias = norm_layer == nn.InstanceNorm2d 279 | 280 | model = [] 281 | 282 | def conv(channels_in, channels_out, kernel_size=3, stride=1, padding=1): 283 | return [nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, 284 | stride=stride, padding=padding, bias=use_bias), 285 | norm_layer(channels_out), 286 | nn.ReLU(True)] 287 | 288 | channels = 32 289 | model += conv(input_nc, channels, kernel_size=7, stride=1, padding=3) 290 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 291 | 292 | channels_in = channels 293 | channels *= 2 294 | model += conv(channels_in, channels, kernel_size=3, stride=2, padding=1) 295 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 296 | 297 | channels_in = channels 298 | channels *= 2 299 | model += conv(channels_in, channels, kernel_size=3, stride=2, padding=1) 300 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 301 | 302 | channels_in = channels 303 | channels *= 2 304 | model += conv(channels_in, channels, kernel_size=3, stride=2, padding=1) 305 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 306 | 307 | model += [nn.Conv2d(channels, output_nc, kernel_size=1, stride=1, padding=0, bias=use_bias)] 308 | 309 | self.model = nn.Sequential(*model) 310 | 311 | def forward(self, input): 312 | return self.model(input) 313 | 314 | 315 | class SkipNipsEncoder(nn.Module): 316 | def __init__(self, input_nc, output_nc, channels=32, n_blocks=4, 317 | norm_layer=nn.BatchNorm2d): 318 | super(SkipNipsEncoder, self).__init__() 319 | 320 | self.input_nc = input_nc 321 | self.output_nc = output_nc 322 | 323 | if type(norm_layer) == functools.partial: 324 | use_bias = norm_layer.func == nn.InstanceNorm2d 325 | else: 326 | use_bias = norm_layer == nn.InstanceNorm2d 327 | 328 | 329 | def conv(channels_in, channels_out, kernel_size=3, stride=1, padding=1): 330 | return [nn.Conv2d(channels_in, 
channels_out, kernel_size=kernel_size, 331 | stride=stride, padding=padding, bias=use_bias), 332 | norm_layer(channels_out), 333 | nn.ReLU(True)] 334 | 335 | self.n_blocks = n_blocks 336 | 337 | model = [] 338 | channels = 32 339 | model += conv(input_nc, channels, kernel_size=7, stride=1, padding=3) 340 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 341 | self.blocks = nn.ModuleList([nn.Sequential(*model)]) 342 | 343 | for _ in range(n_blocks - 2): 344 | model = [] 345 | channels_in = channels 346 | channels *= 2 347 | model += conv(channels_in, channels, kernel_size=3, stride=2, padding=1) 348 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 349 | self.blocks.append(nn.Sequential(*model)) 350 | 351 | model = [] 352 | channels_in = channels 353 | channels *= 2 354 | model += conv(channels_in, channels, kernel_size=3, stride=2, padding=1) 355 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 356 | 357 | model += [nn.Conv2d(channels, output_nc, kernel_size=1, stride=1, padding=0, bias=use_bias)] 358 | self.blocks.append(nn.Sequential(*model)) 359 | 360 | 361 | def forward(self, input): 362 | outputs = [] 363 | for module in self.blocks: 364 | input = module(input) 365 | outputs.append(input) 366 | return outputs 367 | 368 | 369 | 370 | class SkipNipsDecoder(nn.Module): 371 | def __init__(self, input_nc, output_nc, n_blocks=4, norm_layer=nn.BatchNorm2d): 372 | super(SkipNipsDecoder, self).__init__() 373 | 374 | self.input_nc = input_nc 375 | self.output_nc = output_nc 376 | 377 | if type(norm_layer) == functools.partial: 378 | use_bias = norm_layer.func == nn.InstanceNorm2d 379 | else: 380 | use_bias = norm_layer == nn.InstanceNorm2d 381 | 382 | 383 | def conv(channels_in, channels_out, kernel_size=3, stride=1, padding=1): 384 | return [nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, 385 | stride=stride, padding=padding, bias=use_bias), 386 | norm_layer(channels_out), 387 | nn.ReLU(True)] 388 | 389 | upsampling = 'bilinear' 390 | 391 | self.n_blocks = n_blocks 392 | 393 | model = [] 394 | channels = 32 * (2**(self.n_blocks-1)) 395 | model += conv(input_nc[0], channels, kernel_size=3, stride=1, padding=1) 396 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 397 | model += [Interpolate(scale=2, mode=upsampling)] 398 | self.blocks = nn.ModuleList([nn.Sequential(*model)]) 399 | 400 | for i in range(self.n_blocks - 2): 401 | model = [] 402 | channels_in = channels 403 | channels //= 2 404 | model += [nn.Conv2d(channels_in + input_nc[i + 1], channels_in, kernel_size=1, stride=1, padding=0, bias=False)] 405 | model += conv(channels_in, channels, kernel_size=3, stride=1, padding=1) 406 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 407 | model += [Interpolate(scale=2, mode=upsampling)] 408 | self.blocks.append(nn.Sequential(*model)) 409 | 410 | model = [] 411 | channels_in = channels 412 | channels //= 2 413 | model += [nn.Conv2d(channels_in + input_nc[-1], channels_in, kernel_size=1, stride=1, padding=0, bias=False)] 414 | model += conv(channels_in, channels, kernel_size=3, stride=1, padding=1) 415 | model += [nn.Conv2d(channels, output_nc, kernel_size=3, stride=1, padding=1, bias=True)] 416 | self.blocks.append(nn.Sequential(*model)) 417 | 418 | 419 | def forward(self, inputs): 420 | output = self.blocks[0](inputs[0]) 421 | for i in range(1, self.n_blocks): 422 | output = self.blocks[i]( 423 | torch.cat([output, inputs[i]], dim=1)) 424 | return output 425 | 426 | 427 | 
class NipsDecoder(nn.Module): 428 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d): 429 | super(NipsDecoder, self).__init__() 430 | 431 | self.input_nc = input_nc 432 | self.output_nc = output_nc 433 | 434 | if type(norm_layer) == functools.partial: 435 | use_bias = norm_layer.func == nn.InstanceNorm2d 436 | else: 437 | use_bias = norm_layer == nn.InstanceNorm2d 438 | 439 | model = [] 440 | 441 | def conv(channels_in, channels_out, kernel_size=3, stride=1, padding=1): 442 | return [nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, 443 | stride=stride, padding=padding, bias=use_bias), 444 | norm_layer(channels_out), 445 | nn.ReLU(True)] 446 | 447 | upsampling = 'bilinear' 448 | 449 | channels = 256 450 | model += conv(input_nc, channels, kernel_size=3, stride=1, padding=1) 451 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 452 | 453 | model += [Interpolate(scale=2, mode=upsampling)] 454 | channels_in = channels 455 | channels //= 2 456 | model += conv(channels_in, channels, kernel_size=3, stride=1, padding=1) 457 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 458 | 459 | model += [Interpolate(scale=2, mode=upsampling)] 460 | channels_in = channels 461 | channels //= 2 462 | model += conv(channels_in, channels, kernel_size=3, stride=1, padding=1) 463 | model += conv(channels, channels, kernel_size=3, stride=1, padding=1) 464 | 465 | model += [Interpolate(scale=2, mode=upsampling)] 466 | channels_in = channels 467 | channels //= 2 468 | model += conv(channels_in, channels, kernel_size=3, stride=1, padding=1) 469 | model += [nn.Conv2d(channels, output_nc, kernel_size=3, stride=1, padding=1, bias=True)] 470 | 471 | self.model = nn.Sequential(*model) 472 | 473 | def forward(self, input): 474 | return self.model(input) 475 | 476 | 477 | 478 | 479 | # Defines the PatchGAN discriminator with the specified arguments. 
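# --- Editor's note: illustrative sketch, not part of the original repository ---
# The PatchGAN discriminator defined just below scores overlapping image patches
# rather than producing a single score per image: its output is a 2D grid of
# logits, one per receptive field. A minimal smoke test, assuming a single-channel
# 256x256 rendering as input (the channel count and image size are example values,
# not taken from the repository configs):
def _patchgan_shape_demo():
    # With the default kw=4, padw=1 and n_layers=3, the three stride-2 convolutions
    # reduce 256 -> 128 -> 64 -> 32, and the two trailing stride-1 convolutions
    # yield a 30x30 map, i.e. one logit per roughly 70x70 input patch.
    netD = NLayerDiscriminator(input_nc=1, ndf=64, n_layers=3)
    x = torch.randn(2, 1, 256, 256)
    out = netD(x)
    assert tuple(out.shape) == (2, 1, 30, 30)
    return out.shape
# --- end editor's note ---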
480 | class NLayerDiscriminator(nn.Module): 481 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 482 | super(NLayerDiscriminator, self).__init__() 483 | if type(norm_layer) == functools.partial: 484 | use_bias = norm_layer.func == nn.InstanceNorm2d 485 | else: 486 | use_bias = norm_layer == nn.InstanceNorm2d 487 | 488 | kw = 4 489 | padw = 1 490 | sequence = [ 491 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 492 | nn.LeakyReLU(0.2, True) 493 | ] 494 | 495 | nf_mult = 1 496 | nf_mult_prev = 1 497 | for n in range(1, n_layers): 498 | nf_mult_prev = nf_mult 499 | nf_mult = min(2**n, 8) 500 | sequence += [ 501 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 502 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 503 | norm_layer(ndf * nf_mult), 504 | nn.LeakyReLU(0.2, True) 505 | ] 506 | 507 | nf_mult_prev = nf_mult 508 | nf_mult = min(2**n_layers, 8) 509 | sequence += [ 510 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 511 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 512 | norm_layer(ndf * nf_mult), 513 | nn.LeakyReLU(0.2, True) 514 | ] 515 | 516 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 517 | 518 | if use_sigmoid: 519 | sequence += [nn.Sigmoid()] 520 | 521 | self.model = nn.Sequential(*sequence) 522 | 523 | def forward(self, input): 524 | return self.model(input) 525 | 526 | 527 | 528 | class MultiDiscriminator(nn.Module): 529 | def __init__(self, net): 530 | super(MultiDiscriminator, self).__init__() 531 | self.scales = [1, 1./2, 1./4] 532 | for i in range(len(self.scales)): 533 | self.add_module('net_%d' % i, net()) 534 | 535 | def forward(self, input): 536 | outputs = [] 537 | for i, scale in enumerate(self.scales): 538 | input_ = torch.nn.functional.interpolate( 539 | input, scale_factor=scale, mode='bilinear') 540 | output = getattr(self, 'net_%d' % i)(input_) 541 | outputs.append(output) 542 | return outputs 543 | -------------------------------------------------------------------------------- /models/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import scipy.misc 6 | import scipy.io 7 | from networks import NormalizedLoss 8 | 9 | 10 | 11 | def conv(inputs, filters): 12 | return nn.Conv2d( 13 | inputs, filters, kernel_size=3, stride=1, padding=1, bias=True) 14 | 15 | def pool(): 16 | return nn.MaxPool2d(kernel_size=2, stride=2) 17 | 18 | 19 | class VGG19(nn.Module): 20 | 21 | def __init__(self): 22 | super(VGG19, self).__init__() 23 | self.conv1_1 = conv(3, 64) 24 | self.conv1_2 = conv(64, 128) 25 | self.pool1 = pool() 26 | self.conv2_1 = conv(128, 128) 27 | self.conv2_2 = conv(128, 256) 28 | self.pool2 = pool() 29 | self.conv3_1 = conv(256, 256) 30 | self.conv3_2 = conv(256, 256) 31 | self.conv3_3 = conv(256, 256) 32 | self.conv3_4 = conv(256, 512) 33 | self.pool3 = pool() 34 | self.conv4_1 = conv(512, 512) 35 | self.conv4_2 = conv(512, 512) 36 | self.conv4_3 = conv(512, 512) 37 | self.conv4_4 = conv(512, 512) 38 | self.pool4 = pool() 39 | self.conv5_1 = conv(512, 512) 40 | self.conv5_2 = conv(512, 512) 41 | self.conv5_3 = conv(512, 512) 42 | self.conv5_4 = conv(512, 512) 43 | self.pool5 = pool() 44 | 45 | 46 | def forward(self, x): 47 | conv1_1 = self.conv1_1(x) 48 | conv1_2 = self.conv1_2(conv1_1) 49 | pool1 = self.pool1(conv1_2) 50 | conv2_1 = self.conv2_1(pool1) 51 | conv2_2 = 
self.conv2_2(conv2_1) 52 | pool2 = self.pool2(conv2_2) 53 | conv3_1 = self.conv3_1(pool2) 54 | conv3_2 = self.conv3_2(conv3_1) 55 | conv3_3 = self.conv3_3(conv3_2) 56 | conv3_4 = self.conv3_4(conv3_3) 57 | pool3 = self.pool3(conv3_4) 58 | conv4_1 = self.conv4_1(pool3) 59 | conv4_2 = self.conv4_2(conv4_1) 60 | conv4_3 = self.conv4_3(conv4_2) 61 | conv4_4 = self.conv4_4(conv4_3) 62 | pool4 = self.pool4(conv4_4) 63 | conv5_1 = self.conv5_1(pool4) 64 | conv5_2 = self.conv5_2(conv5_1) 65 | conv5_3 = self.conv5_3(conv5_2) 66 | conv5_4 = self.conv5_4(conv5_3) 67 | pool5 = self.pool5(conv5_4) 68 | 69 | return x, conv1_2, conv2_2, conv3_2, conv4_2, conv5_2 70 | 71 | 72 | class PerceptualLoss(nn.Module): 73 | def __init__(self, vgg19_path): 74 | super(PerceptualLoss, self).__init__() 75 | 76 | net = VGG19() 77 | net = net.cuda() 78 | 79 | vgg_rawnet = scipy.io.loadmat(vgg19_path) 80 | vgg_layers = vgg_rawnet['layers'][0] 81 | 82 | #Weight initialization according to the pretrained VGG Very deep 19 network Network weights 83 | layers = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34] 84 | att = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 85 | 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 86 | 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'] 87 | filt = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512] 88 | for l in range(len(layers)): 89 | getattr(net, att[l]).weight = nn.Parameter(torch.from_numpy( 90 | vgg_layers[layers[l]][0][0][2][0][0]).permute(3, 2, 0, 1).cuda(), requires_grad=False) 91 | getattr(net, att[l]).bias = nn.Parameter(torch.from_numpy( 92 | vgg_layers[layers[l]][0][0][2][0][1]).view(filt[l]).cuda(), requires_grad=False) 93 | 94 | self.net = net 95 | 96 | self.n_layers = 6 97 | self.losses = [NormalizedLoss(nn.MSELoss(), mu=0.99) for _ in range(self.n_layers)] 98 | 99 | 100 | def forward(self, input, target): 101 | # FIXME: how to handle normalized inputs 102 | input = ((input + 1.0) / 2.0) * 255.0 103 | target = ((target + 1.0) / 2.0) * 255.0 104 | 105 | mean = np.array([123.6800, 116.7790, 103.9390]).reshape((1,1,1,3)) 106 | mean = torch.from_numpy(mean).float().permute(0,3,1,2).cuda() 107 | 108 | input_f = self.net(input - mean) 109 | target_f = self.net(target - mean) 110 | 111 | # normalize 112 | # layer_w = [1.0, 1.6, 2.3, 1.8, 2.8, 0.008] 113 | # input_f = [f / torch.norm(f.view(f.shape[0], -1, 1, 1), p=2, dim=1, keepdim=True) for f in input_f] 114 | # target_f = [f / torch.norm(f.view(f.shape[0], -1, 1, 1), p=2, dim=1, keepdim=True) for f in target_f] 115 | 116 | losses = [] 117 | for x, y, loss_fn in zip(input_f, target_f, self.losses): 118 | losses.append(loss_fn(x, y)) 119 | loss = torch.mean(torch.stack(losses)) 120 | 121 | return loss 122 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Iterable 3 | 4 | import matplotlib.cm as mpl_color_map 5 | import numpy as np 6 | import PIL.Image 7 | import PIL.ImageDraw 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Function 11 | 12 | import pygame 13 | 14 | 15 | def extract_points(input): 16 | """ 17 | input is N x K x H x W "heat map" tensor, N is batch size 18 | """ 19 | def get_coord(input, other_axis, axis_size): 20 | # get "x-y" coordinates: 21 | marg = torch.mean(input, dim=other_axis) # B,W,NMAP 22 | prob = F.softmax(marg, dim=2) # 
B,W,NMAP 23 | grid = torch.linspace(-1.0, 1.0, axis_size, dtype=torch.float32, device=input.device) # W 24 | grid = grid[None, None] 25 | point = torch.sum(prob * grid, dim=2) 26 | return point, prob 27 | 28 | x, _ = get_coord(input, 2, input.shape[3]) # B,NMAP 29 | y, _ = get_coord(input, 3, input.shape[2]) # B,NMAP 30 | points = torch.stack([x, y], dim=2) 31 | 32 | return points 33 | 34 | 35 | def render_points(points, width, height, inv_std=50): 36 | device = points.device 37 | mu_x, mu_y = points[:, :, 0:1], points[:, :, 1:2] 38 | 39 | y = torch.linspace(-1.0, 1.0, height, dtype=torch.float32, device=device) 40 | x = torch.linspace(-1.0, 1.0, width, dtype=torch.float32, device=device) 41 | 42 | mu_y, mu_x = mu_y[..., None], mu_x[..., None] 43 | 44 | y = torch.reshape(y, [1, 1, height, 1]) 45 | x = torch.reshape(x, [1, 1, 1, width]) 46 | 47 | g_y = (y - mu_y) ** 2 48 | g_x = (x - mu_x) ** 2 49 | dist = (g_y + g_x) * inv_std**2 50 | 51 | g_yx = torch.exp(-dist) 52 | return g_yx 53 | 54 | 55 | def get_perpendicular_unit_vector(u): 56 | negative = torch.Tensor([1, -1]).to(u.device) 57 | # add dimensions to match u 58 | negative = torch.reshape(negative, [1] * (u.dim() - 1) + [2]) 59 | u = u[..., [1, 0]] * negative 60 | u = u / torch.norm(u, dim=-1, keepdim=True) 61 | return u 62 | 63 | 64 | def render_line_segment(a, b, size, distance='gauss', sigma=0.2, 65 | normalize=False, widths=None): 66 | """ 67 | a, b points defining the line segment 68 | widths: B x N 69 | outputs B x N x H x W 70 | """ 71 | def sumprod(x, y, keepdim=True): 72 | return torch.sum(x * y, dim=-1, keepdim=keepdim) 73 | 74 | grid = torch.linspace(-1.0, 1.0, size, dtype=torch.float32, device=a.device) 75 | 76 | # FIXME: api different from numpy 77 | yv, xv = torch.meshgrid([grid, grid]) 78 | # 1 x H x W x 2 79 | m = torch.cat([xv[..., None], yv[..., None]], dim=-1)[None, None] 80 | 81 | # B x N x 1 x 1 x 2 82 | a, b = a[:, :, None, None, :], b[:, :, None, None, :] 83 | t_min = sumprod(m - a, b - a) / \ 84 | torch.max(sumprod(b - a, b - a), torch.tensor(1e-6, device=a.device)) 85 | t_line = torch.clamp(t_min, 0.0, 1.0) 86 | 87 | # closest points on the line to every image pixel 88 | s = a + t_line * (b - a) 89 | 90 | # for rectangle 91 | if widths is not None: 92 | # get perpendicular unit vector 93 | u = b - a 94 | u = u[..., [1, 0]] * torch.Tensor([[[[[1, -1]]]]]).to(u.device) 95 | u = u / torch.norm(u, dim=-1, keepdim=True) 96 | 97 | t_min = sumprod(m - s, u) / \ 98 | torch.max(sumprod(u, u), torch.tensor(1e-6, device=a.device)) 99 | 100 | w = widths[..., None, None, None] / 2.0 101 | t_line = clamp(t_min, -w, w) 102 | 103 | s = s + t_line * u 104 | 105 | d = sumprod(s - m, s - m, keepdim=False) 106 | 107 | # normalize distance 108 | if distance == 'gauss': 109 | d = torch.sqrt(d + 1e-6) 110 | d_norm = torch.exp(-d / (sigma ** 2)) 111 | elif distance == 'norm': 112 | d = torch.sqrt(d + 1e-6) 113 | d_max = math.sqrt(8) 114 | d_norm = (d_max - d) / d_max 115 | else: 116 | raise ValueError() 117 | 118 | if normalize: 119 | d_norm = d_norm / torch.sum(d_norm, (2, 3), keepdim=True) 120 | 121 | return d_norm 122 | 123 | 124 | def get_line_points(points, connections): 125 | # gather points for lines 126 | a_points = torch.zeros( 127 | [points.shape[0], len(connections), points.shape[2]], 128 | dtype=points.dtype, device=points.device) 129 | b_points = torch.zeros_like(a_points) 130 | for i, (a, b) in enumerate(connections): 131 | a_points[:, i] = _mean_point(points, a) 132 | b_points[:, i] = _mean_point(points, b) 133 | 
return a_points, b_points 134 | 135 | 136 | def get_polygons_points(points, polygons): 137 | polygon_points = [] 138 | for polygon in polygons: 139 | polygon_points += [get_polygon_points(points, polygon)] 140 | return polygon_points 141 | 142 | 143 | def get_polygon_points(points, polygon): 144 | point_ids = [x for x, _ in polygon] 145 | polygon_points = points[:, point_ids] 146 | return polygon_points 147 | 148 | 149 | def render_skeleton(points, connections, width, height, colored=False, 150 | auxilary_links=None, colors=None, reduce=None, 151 | sigma=0.2, normalize=False, widths=None): 152 | """ 153 | colors: B x N x C 154 | returns: B x N x C x H x W or B x C x H x W if using reduce 155 | """ 156 | assert width == height 157 | 158 | batch_size = points.shape[0] 159 | if auxilary_links is not None: 160 | points, connections = add_auxiliary_links( 161 | points, connections, auxilary_links) 162 | 163 | # create colors if required 164 | if colors is None: 165 | if colored: 166 | colors = torch.linspace(0.2, 1.0, len(connections), dtype=torch.float32, device=points.device) 167 | colors = colors[None, :, None].repeat([batch_size, 1, 1]) 168 | else: 169 | colors = torch.ones([batch_size, len(connections), 1], dtype=torch.float32, device=points.device) 170 | 171 | # parse auxiliary links contained in connectios 172 | points, connections = parse_auxiliary_links(points, connections) 173 | 174 | # gather points for lines 175 | a, b = zip(*connections) 176 | a, b = list(a), list(b) 177 | a_points = points[:, a] 178 | b_points = points[:, b] 179 | 180 | renderings = render_line_segment(a_points, b_points, width, sigma=sigma, 181 | normalize=normalize, widths=widths) 182 | # add an axis for colors, renderings has B x N x 1 x H x W 183 | renderings = renderings[:, :, None] 184 | renderings = renderings * colors[..., None, None] 185 | 186 | # renderings has B x N x C x H x W 187 | return renderings 188 | 189 | 190 | def add_auxiliary_links(points, connections, auxilary_links): 191 | def mean_point(points, indices): 192 | if isinstance(indices, Iterable): 193 | return torch.mean(points[:, indices], dim=1) 194 | else: 195 | return points[:, indices] 196 | 197 | # add auxilary points and links 198 | connections = connections[:] 199 | n_aux_points = 2 * len(auxilary_links) 200 | index_offset = points.shape[1] 201 | zeros = torch.zeros([points.shape[0], n_aux_points, 2], 202 | dtype=points.dtype, device=points.device) 203 | points = torch.cat([points, zeros], dim=1) 204 | for (a, b), i in zip(auxilary_links, range(index_offset, index_offset + n_aux_points, 2)): 205 | points[:, i] = mean_point(points, a) 206 | points[:, i + 1] = mean_point(points, b) 207 | connections.append([i, i + 1]) 208 | 209 | return points, connections 210 | 211 | 212 | def parse_auxiliary_links(points, connections): 213 | # add auxilary points and links 214 | new_connections = [] 215 | n_points = 2 * len(connections) 216 | new_points = torch.zeros([points.shape[0], n_points, 2], 217 | dtype=points.dtype, device=points.device) 218 | for (a, b), i in zip(connections, range(0, n_points, 2)): 219 | new_points[:, i] = _mean_point(points, a) 220 | new_points[:, i + 1] = _mean_point(points, b) 221 | new_connections.append([i, i + 1]) 222 | 223 | return new_points, new_connections 224 | 225 | 226 | def _mean_point(points, indices): 227 | if isinstance(indices, Iterable): 228 | return torch.mean(points[:, indices], dim=1) 229 | else: 230 | return points[:, indices] 231 | 232 | 233 | def normalize_im(im): 234 | return 2.0 * im - 1.0 235 | 236 
| 237 | def l2_distance(x, y): 238 | """ 239 | x: B x N x D 240 | y: B x N x D 241 | """ 242 | return torch.sqrt(torch.sum((x - y) ** 2, dim=2)) 243 | 244 | 245 | def mean_l2_distance(x, y): 246 | """ 247 | x: B x N x D 248 | y: B x N x D 249 | """ 250 | return torch.mean(l2_distance(x, y), dim=1) 251 | 252 | 253 | def mean_l2_distance_norm(predict, target, norm_points): 254 | """ 255 | x: B x N x D 256 | y: B x N x D 257 | """ 258 | dists = l2_distance(predict, target) 259 | dists_points = l2_distance( 260 | target[:, norm_points[0]][:, None], 261 | target[:, norm_points[1]][:, None]) 262 | norm_dists = dists / dists_points 263 | return torch.mean(norm_dists, dim = 1) 264 | 265 | 266 | def swap_points(points, correspondences): 267 | """ 268 | points: B x N x D 269 | """ 270 | permutation = list(range((points.shape[1]))) 271 | for a, b in correspondences: 272 | permutation[a] = b 273 | permutation[b] = a 274 | new_points = points[:, permutation, :] 275 | return new_points 276 | 277 | 278 | def normalize_image_tensor(tensor): 279 | minimum, _ = torch.min(tensor, dim=2, keepdim=True) 280 | minimum, _ = torch.min(minimum, dim=3, keepdim=True) 281 | tensor -= minimum 282 | maximum, _ = torch.max(tensor, dim=2, keepdim=True) 283 | maximum, _ = torch.max(maximum, dim=3, keepdim=True) 284 | return tensor / maximum 285 | 286 | 287 | def clamp(tensor, minimum, maximum): 288 | tensor = torch.max(tensor, minimum) 289 | tensor = torch.min(tensor, maximum) 290 | return tensor 291 | 292 | 293 | def lstsq(input, A): 294 | solution = [] 295 | QR = [] 296 | for i, a in zip(input, A): 297 | s, q = torch.lstsq(i, a) 298 | solution += [s] 299 | QR += [q] 300 | solution = torch.stack(solution) 301 | QR = torch.stack(QR) 302 | return solution, QR 303 | 304 | 305 | def rollout(tensor): 306 | """ 307 | tensor: B x C x .... -> B * C x ... 308 | """ 309 | shape = tensor.shape 310 | new_shape = [shape[0] * shape[1]] 311 | if len(shape) > 2: 312 | new_shape += shape[2:] 313 | return torch.reshape(tensor, new_shape) 314 | 315 | 316 | def unrollout(tensor, n_channels): 317 | """ 318 | tensor: B * C x ... -> B x C x ... 
319 | """ 320 | shape = tensor.shape 321 | new_shape = [shape[0] // n_channels, n_channels] 322 | if len(shape) > 1: 323 | new_shape += shape[1:] 324 | return torch.reshape(tensor, new_shape) 325 | 326 | 327 | def apply_colormap_on_tensor(tensor, colormap_name='jet'): 328 | """ 329 | """ 330 | # Get colormap 331 | assert tensor.shape[1] == 1 332 | 333 | color_map = mpl_color_map.get_cmap(colormap_name) 334 | 335 | tensor = normalize_tensor_image(tensor, dim=(2, 3)) 336 | tensor_np = tensor.detach().cpu().numpy() 337 | heatmap = color_map(tensor_np[:, 0]) 338 | # remove alpha 339 | heatmap = heatmap[:, :, :, :3] 340 | heatmap = torch.from_numpy(heatmap).to(tensor.device).type(tensor.dtype) 341 | heatmap = heatmap.permute(0, 3, 1, 2) 342 | 343 | return heatmap 344 | 345 | 346 | def normalize_tensor_image(tensor, dim): 347 | minv = multi_min(tensor, dim, keepdim=True) 348 | maxv = multi_max(tensor, dim, keepdim=True) 349 | return (tensor - minv) / (maxv - minv) 350 | 351 | 352 | 353 | def multi_min(input, dim, keepdim=False): 354 | return _multi_minmax(input, dim, torch.min, keepdim=keepdim) 355 | 356 | 357 | def multi_max(input, dim, keepdim=False): 358 | return _multi_minmax(input, dim, torch.max, keepdim=keepdim) 359 | 360 | 361 | def _multi_minmax(input, dim, operator, keepdim=False): 362 | dim = sorted(dim) 363 | reduced = input 364 | for axis in reversed(dim): 365 | reduced, _ = operator(reduced, axis, keepdim=keepdim) 366 | return reduced 367 | 368 | 369 | # https://gist.github.com/bobchennan/a865b153c6835a3a6a5c628213766150 370 | class gels(Function): 371 | """ Efficient implementation of gels from 372 | Nanxin Chen 373 | bobchennan@gmail.com 374 | """ 375 | @staticmethod 376 | def forward(ctx, A, b): 377 | # A: (..., M, N) 378 | # b: (..., M, K) 379 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267 380 | u = torch.cholesky(torch.matmul(A.transpose(-1, -2), A), upper=True) 381 | ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), u, upper=True) 382 | ctx.save_for_backward(u, ret, A, b) 383 | return ret 384 | 385 | @staticmethod 386 | def backward(ctx, grad_output): 387 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223 388 | chol, x, a, b = ctx.saved_tensors 389 | z = torch.cholesky_solve(grad_output, chol, upper=True) 390 | xzt = torch.matmul(x, z.transpose(-1,-2)) 391 | zx_sym = xzt + xzt.transpose(-1, -2) 392 | grad_A = - torch.matmul(a, zx_sym) + torch.matmul(b, z.transpose(-1, -2)) 393 | grad_b = torch.matmul(a, z) 394 | return grad_A, grad_b 395 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomasjakab/keypointgan/541b769d536dc113fcf6da271ed72ae9d963dbb4/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') 15 | parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders 
trainA, trainB, valA, valB, etc)') 16 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 17 | parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') 18 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 19 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 20 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 21 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 22 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 23 | parser.add_argument('--netDA', type=str, default='basic', help='discriminator for fwd arm') 24 | parser.add_argument('--netG_A', type=str, default='resnet_9blocks', help='selects model to use for netG_A') 25 | parser.add_argument('--netG_B', type=str, default='nips', help='selects model to use for netG_B') 26 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 27 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 28 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 29 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses dataset') 30 | parser.add_argument('--model', type=str, default='keypoint_gan', help='chooses which model to use') 31 | parser.add_argument('--iteration', type=str, default='latest', help='which iter to load? set to latest to use latest cached model') 32 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 33 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 35 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 36 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') 37 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 38 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 39 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 40 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 41 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 42 | parser.add_argument('--tps', action='store_true', help='random TPS on input') 43 | parser.add_argument('--tps_target', action='store_true', help='random TPS on input') 44 | parser.add_argument('--multi_ganA', action='store_true', help='mulit GAN') 45 | parser.add_argument('--perceptual_net', type=str, default='', help='path to perceptual net') 46 | parser.add_argument('--nets_paths', type=str, nargs='+', help='list of pairs inits nets from the specified checkpoints') 47 | parser.add_argument('--upsampling_G_A', type=str, default='transpose', help='') 48 | parser.add_argument('--skeleton_type', type=str, default='human36m', help='') 49 | parser.add_argument('--paired_skeleton_type', type=str, default='human36m', help='') 50 | parser.add_argument('--n_points', type=int, default='32', help='') 51 | parser.add_argument('--subset', type=str, default='train', help='') 52 | parser.add_argument('--allow_unknown_options', action='store_true', help='') 53 | parser.add_argument('--shuffle', type=str, default='false', help='') 54 | parser.add_argument('--shuffle_identities', action='store_true', help='') 55 | parser.add_argument('--regressor_norm', type=str, default='instance', help='') 56 | parser.add_argument('--discriminators_norm', type=str, default='instance', help='') 57 | parser.add_argument('--generators_norm', type=str, default='batch', help='') 58 | parser.add_argument('--regressor_im_loss', type=float, default=0, help='') 59 | parser.add_argument('--finetune_regressor', action='store_true', help='') 60 | parser.add_argument('--reduce_rendering_mode', type=str, default='max', help='') 61 | parser.add_argument('--net_regressor', type=str, default='nips_encoder', help='') 62 | parser.add_argument('--net_regressor_channels', type=int, default=32, help='') 63 | parser.add_argument('--offline_regressor', action='store_true', help='') 64 | parser.add_argument('--eval_pose_prediction_only', action='store_true', help='') 65 | parser.add_argument('--prior_skeleton_type', type=str, default=None, help='') 66 | parser.add_argument('--plot_skeleton_type', type=str, default=None, help='') 67 | parser.add_argument('--sigma', type=float, default=0.4, help='') 68 | parser.add_argument('--avg_pool_style', action='store_true', help='') 69 | parser.add_argument('--netG_A_blocks', type=int, default=4, help='') 70 | parser.add_argument('--source_tps_params', type=float, default=[5.0, 0.05, 0.05, 0.0005, 0.005], nargs=5, help='') 71 | parser.add_argument('--target_tps_params', type=float, default=[5.0, 0.05, 0.05, 0.0, 0.0], nargs=5, help='') 72 | parser.add_argument('--plot_landmark_size', 
type=float, default=1.3, help='') 73 | parser.add_argument('--resume_from_name', type=str, default=None, help='') 74 | 75 | self.initialized = True 76 | return parser 77 | 78 | def gather_options(self): 79 | # initialize parser with basic options 80 | if not self.initialized: 81 | parser = configargparse.ArgumentParser( 82 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter) 83 | parser = self.initialize(parser) 84 | 85 | # get the basic options 86 | opt, _ = parser.parse_known_args() 87 | 88 | # modify model-related parser options 89 | model_name = opt.model 90 | model_option_setter = models.get_option_setter(model_name) 91 | parser = model_option_setter(parser, self.isTrain) 92 | opt, _ = parser.parse_known_args() # parse again with the new defaults 93 | 94 | # modify dataset-related parser options 95 | dataset_name = opt.dataset_mode 96 | dataset_option_setter = data.get_option_setter(dataset_name) 97 | parser = dataset_option_setter(parser, self.isTrain) 98 | 99 | self.parser = parser 100 | 101 | if hasattr(opt, 'allow_unknown_options') and opt.allow_unknown_options: 102 | opt, unknown = parser.parse_known_args() 103 | else: 104 | opt = parser.parse_args() 105 | unknown = [] 106 | 107 | return opt, unknown 108 | 109 | def print_options(self, opt): 110 | message = '' 111 | message += '----------------- Options ---------------\n' 112 | for k, v in sorted(vars(opt).items()): 113 | comment = '' 114 | default = self.parser.get_default(k) 115 | if v != default: 116 | comment = '\t[default: %s]' % str(default) 117 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 118 | message += '----------------- End -------------------' 119 | print(message) 120 | 121 | # save to the disk 122 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 123 | util.mkdirs(expr_dir) 124 | file_name = os.path.join(expr_dir, 'opt.txt') 125 | with open(file_name, 'wt') as opt_file: 126 | opt_file.write(message) 127 | opt_file.write('\n') 128 | 129 | def print_unknown(self, unknown): 130 | message = '' 131 | message += '----------------- Unknown options ---------------\n' 132 | for item in unknown: 133 | if item.startswith('-'): 134 | message += '%s, ' % item 135 | message += '\n' 136 | message += '----------------- End -------------------' 137 | print(message) 138 | 139 | def parse(self): 140 | 141 | opt, unknown = self.gather_options() 142 | opt.isTrain = self.isTrain # train or test 143 | 144 | # process opt.suffix 145 | if opt.suffix: 146 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 147 | opt.name = opt.name + suffix 148 | 149 | self.print_options(opt) 150 | if unknown: 151 | self.print_unknown(unknown) 152 | 153 | # set gpu ids 154 | str_ids = opt.gpu_ids.split(',') 155 | opt.gpu_ids = [] 156 | for str_id in str_ids: 157 | id = int(str_id) 158 | if id >= 0: 159 | opt.gpu_ids.append(id) 160 | if len(opt.gpu_ids) > 0: 161 | torch.cuda.set_device(opt.gpu_ids[0]) 162 | 163 | self.opt = opt 164 | return self.opt 165 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results 
here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | # Dropout and Batchnorm has different behavioir during training and test. 12 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 13 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 14 | parser.add_argument('--tune_steps', type=int, default=10, help='tune_steps') 15 | parser.add_argument('--tune_lr', type=float, default=0.01, help='tune_lr') 16 | parser.add_argument('--test_config', required=False, is_config_file=True, help='test config file path') 17 | 18 | parser.add_argument('--used_points', type=str, required=True, help='all|original') 19 | parser.add_argument('--error_form', type=str, required=True, help='all|image_size') 20 | parser.add_argument('--num_test_save', type=int, default=30, help='') 21 | parser.add_argument('--print_freq', type=int, default=5, help='') 22 | 23 | parser.set_defaults(subset='test') 24 | parser.set_defaults(model='test') 25 | parser.set_defaults(allow_unknown_options=True) 26 | # To avoid cropping, the loadSize should be the same as fineSize 27 | parser.set_defaults(loadSize=parser.get_default('fineSize')) 28 | self.isTrain = False 29 | return parser 30 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 10 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 11 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 12 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 13 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=20000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_iters_freq', type=int, default=20000, help='frequency of saving checkpoints') 17 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 18 | parser.add_argument('--iters_count', type=int, default=1, help='the starting iters count, we save the model by , +, ...') 19 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 20 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 21 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 22 | parser.add_argument('--no_lsgan', 
action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 23 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 24 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 25 | # TODO: lr_policy fixed to none, needs to be adapted for iterations (was using epochs) 26 | parser.add_argument('--lr_policy', type=str, default='none', help='learning rate policy: lambda|step|plateau|cosine') 27 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 28 | parser.add_argument('--cycle_loss', type=str, default='l1', help='cycle loss: l1|perceptual') 29 | parser.add_argument('--clip_grad', type=float, default=float('inf'), help='') 30 | parser.add_argument('--not_optimize_G', action='store_true', help='') 31 | parser.add_argument('--not_optimize_D', action='store_true', help='') 32 | parser.add_argument('--regressor_fake_loss', type=float, default=0.0, help='') 33 | parser.add_argument('--regressor_real_loss', type=float, default=0.0, help='') 34 | parser.add_argument('--lambda_render_consistency', type=float, default=0.0, help='') 35 | parser.add_argument('--only_visible_points_loss', action='store_true', help='') 36 | 37 | self.isTrain = True 38 | return parser 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | dominate>=2.3.1 4 | visdom>=0.1.8.3 5 | h5py 6 | matplotlib==2.1.0 7 | scikit-image 8 | configargparse 9 | -------------------------------------------------------------------------------- /test_pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | from util import util 8 | import torch 9 | from models import utils as mutils 10 | from data import human36m_skeleton 11 | import math 12 | import itertools 13 | import torch 14 | from collections import defaultdict 15 | import numpy as np 16 | import time 17 | import re 18 | 19 | 20 | if __name__ == '__main__': 21 | opt = TestOptions().parse() 22 | opt.num_test = 1000000 23 | num_save = opt.num_test_save 24 | 25 | # load data 26 | data_loader = CreateDataLoader(opt) 27 | dataset = data_loader.load_data() 28 | print('Created dataset with %d samples' % len(dataset)) 29 | 30 | # setup model 31 | model = create_model(opt) 32 | model.setup(opt) 33 | 34 | # create a website 35 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.iteration)) 36 | webpage = html.HTML(web_dir, '%s, Phase = %s, Iteration = %s' % ( 37 | opt.name, opt.phase, opt.iteration)) 38 | 39 | if opt.eval: 40 | model.eval() 41 | 42 | ############################################################################ 43 | def format_results_per_activity(results): 44 | """ 45 | results dict {activity: result} 46 | """ 47 | s = '' 48 | for activity, result in sorted(results.items()): 49 | s += '%s: %.4f ' % (activity, result) 50 | return s 51 | 52 | 53 | def format_results_per_activity2(results): 54 | """ 55 | results dict {activity: result} 56 | """ 57 | order_full = ['waiting', 'posing', 
'greeting', 'directions', 'discussion', 'walking', 58 | 'eating', 'phone_call', 'purchases', 'sitting', 'sitting_down', 'smoking', 59 | 'taking_photo', 'walking_dog', 'walking_together'] 60 | order_yutig = ['Waiting', 'Posing', 'Greeting', 'Directions', 'Discussion', 'Walking'] 61 | 62 | if set(order_full) == set(results.keys()): 63 | order = order_full 64 | elif set(order_yutig) == set(results.keys()): 65 | order = order_yutig 66 | else: 67 | raise ValueError() 68 | 69 | numbers = ['%.4f' % results[k] for k in order] 70 | return '\t'.join(order), '\t'.join(numbers) 71 | 72 | 73 | def mean_per_activity(distances, paths, path_fn): 74 | activities = [path_fn(p) for p in paths] 75 | return mean_distance_per_activity(distances, activities) 76 | 77 | 78 | def human36m_path_to_activity(path): 79 | return path.split(os.path.sep)[-5] 80 | 81 | 82 | def y_human36m_path_to_activity(path): 83 | return re.split('\s|\.', path.split(os.path.sep)[-2])[0] 84 | 85 | 86 | def mean_distance_per_activity(distances, activities): 87 | d = defaultdict(list) 88 | for distance, activity in zip(distances, activities): 89 | d[activity].append(distance) 90 | means = {} 91 | for activity, values in d.items(): 92 | means[activity] = np.mean(values) 93 | return means 94 | 95 | 96 | def compute_mean_distance(input, target, correspondeces=None, 97 | target_correspondeces=None, used_points=None, 98 | offline_prediction=None): 99 | if target_correspondeces is not None: 100 | target_swapped = mutils.swap_points(target, target_correspondeces) 101 | else: 102 | target_swapped = target.clone() 103 | if correspondeces is not None: 104 | input_swapped = mutils.swap_points(input, correspondeces) 105 | if offline_prediction is not None: 106 | offline_prediction_swapped = mutils.swap_points(offline_prediction, correspondeces) 107 | else: 108 | input_swapped = input.clone() 109 | if offline_prediction is not None: 110 | offline_prediction_swapped = offline_prediction.clone() 111 | 112 | if used_points is not None: 113 | input = input[:, used_points] 114 | input_swapped = input_swapped[:, used_points] 115 | if offline_prediction is not None: 116 | offline_prediction = offline_prediction[:, used_points] 117 | target = target[:, used_points] 118 | target_swapped = target_swapped[:, used_points] 119 | 120 | # offline 121 | if offline_prediction is not None: 122 | distance = mutils.mean_l2_distance(offline_prediction, input) 123 | swapped_distance = mutils.mean_l2_distance(offline_prediction, input_swapped) 124 | min_idx = distance > swapped_distance 125 | for i in range(len(min_idx)): 126 | if min_idx[i]: 127 | input[i] = input_swapped[i] 128 | 129 | distance = mutils.mean_l2_distance(target, input) 130 | swapped_distance = mutils.mean_l2_distance(target_swapped, input) 131 | correct_flip = distance < swapped_distance 132 | min_distance = torch.min(distance, swapped_distance) 133 | 134 | return distance, min_distance, correct_flip 135 | 136 | def normalize_points(points): 137 | return (points + 1) / 2.0 138 | 139 | def points_to_original(points, height, width, height_ratio, width_ratio): 140 | """ 141 | points: B x N x 2 142 | """ 143 | points *= torch.tensor([[[height, width]]], dtype=torch.float32, device=points.device) 144 | points /= torch.stack([height_ratio, width_ratio], dim=-1, )[:, None].to(points.device) 145 | return points 146 | 147 | if opt.used_points == 'simple_links': 148 | used_points = set() 149 | for a, b in human36m_skeleton.simple_link_indices: 150 | used_points.add(a) 151 | used_points.add(b) 152 | used_points = 
list(used_points) 153 | used_points = sorted(used_points) 154 | elif opt.used_points == 'original': 155 | used_points = sorted(list(human36m_skeleton.official_eval_indices.values())) 156 | elif opt.used_points in ['all']: 157 | used_points = None 158 | else: 159 | raise ValueError() 160 | 161 | ############################################################################ 162 | 163 | distances = [] 164 | distances_min = [] 165 | correct_flips = [] 166 | paths = [] 167 | 168 | if opt.paired_skeleton_type in ['human36m', 'human36m_simple2']: 169 | target_correspondeces = human36m_skeleton.get_lr_correspondences() 170 | else: 171 | target_correspondeces = None 172 | 173 | if opt.skeleton_type in ['human36m', 'human36m_simple2']: 174 | correspondeces = human36m_skeleton.get_lr_correspondences() 175 | else: 176 | correspondeces = None 177 | 178 | n_batches = int(math.ceil(float(len(dataset)) / opt.batch_size)) 179 | 180 | save_frq = int(math.ceil(float(min(n_batches, opt.num_test)) / num_save)) 181 | 182 | avg_time = [] 183 | 184 | path_fn = human36m_path_to_activity 185 | if opt.dataset_mode == 'simplehuman36m': 186 | path_fn = y_human36m_path_to_activity 187 | 188 | iter_start_time = time.time() 189 | for i, data in enumerate(dataset): 190 | if i >= opt.num_test: 191 | break 192 | 193 | model.set_input(data) 194 | model.test() 195 | 196 | img_path = model.get_image_paths() 197 | 198 | prediction = model.regressed_points 199 | if hasattr(model, 'offline_regressed_points'): 200 | offline_prediction = model.offline_regressed_points 201 | else: 202 | offline_prediction = None 203 | target = model.paired_B_points 204 | 205 | prediction = normalize_points(prediction) 206 | if offline_prediction is not None: 207 | offline_prediction = normalize_points(offline_prediction) 208 | target = normalize_points(target) 209 | 210 | if opt.error_form == 'original': 211 | height_ratio = model.input['height_ratio'] 212 | width_ratio = model.input['width_ratio'] 213 | original_landmarks = model.input['landmarks'] 214 | prediction = points_to_original( 215 | prediction, opt.fineSize, opt.fineSize, height_ratio, width_ratio) 216 | if offline_prediction is not None: 217 | offline_prediction = points_to_original( 218 | offline_prediction, opt.fineSize, opt.fineSize, height_ratio, width_ratio) 219 | target = points_to_original( 220 | target, opt.fineSize, opt.fineSize, height_ratio, width_ratio) 221 | elif opt.error_form == 'image_size': 222 | pass 223 | else: 224 | raise ValueError() 225 | 226 | # compute distance error 227 | dist, dist_min, correct_flip = compute_mean_distance( 228 | prediction, target, correspondeces=correspondeces, 229 | target_correspondeces=target_correspondeces, 230 | used_points=used_points, offline_prediction=offline_prediction) 231 | 232 | # log results 233 | distances.extend(dist.cpu().numpy()) 234 | distances_min.extend(dist_min.cpu().numpy()) 235 | correct_flips.extend(correct_flip.cpu().numpy()) 236 | paths.extend(data['A_paths']) 237 | 238 | t = (time.time() - iter_start_time) 239 | iter_start_time = time.time() 240 | avg_time.append(t) 241 | if i % opt.print_freq == 0: 242 | samples_frq = float(opt.batch_size) / t 243 | samples_frq_avg = float(opt.batch_size) / np.mean(avg_time) 244 | time_str = '%.1f samples/sec %.1f samples/sec (avg)' % ( 245 | samples_frq, samples_frq_avg) 246 | 247 | print('processing (%d/%d)-th batch %s' % (i, n_batches, time_str)) 248 | print(np.random.choice(img_path, 1)) 249 | 250 | mean_distances = mean_per_activity(distances, paths, path_fn) 251 | mean_distance 
= np.mean(mean_distances.values()) 252 | mean_min_distances = mean_per_activity(distances_min, paths, path_fn) 253 | mean_min_distance = np.mean(mean_min_distances.values()) 254 | mean_correct_flips = mean_per_activity(correct_flips, paths, path_fn) 255 | mean_correct_flip = np.mean(mean_correct_flips.values()) 256 | 257 | results_str = 'mean distance %.4f\n' % mean_distance 258 | results_str += 'mean min distance %.4f\n' % mean_min_distance 259 | results_str += 'mean correct flips %.4f\n' % mean_correct_flip 260 | results_str += '%s\n' % format_results_per_activity(mean_distances) 261 | results_str += '%s\n' % format_results_per_activity(mean_min_distances) 262 | results_str += '%s\n' % format_results_per_activity(mean_correct_flips) 263 | print(results_str) 264 | 265 | if i % save_frq == 0: 266 | visuals = model.get_current_visuals() 267 | webpage.add_text(data['A_paths'][0]) 268 | if opt.dataset_mode == 'simplehuman36m': 269 | data['image_name'] = ['-'.join(x.split(os.path.sep)[-2:]) for x in img_path] 270 | if 'image_name' in data: 271 | img_names = [x.replace(os.path.sep, '-') for x in data['image_name']] 272 | take_basename = False 273 | else: 274 | img_names = img_path 275 | take_basename = True 276 | save_images( 277 | webpage, visuals, img_names, aspect_ratio=opt.aspect_ratio, 278 | width=opt.display_winsize, basename=take_basename) 279 | webpage.save() 280 | 281 | mean_distances = mean_per_activity(distances, paths, path_fn) 282 | mean_distance = np.mean(mean_distances.values()) 283 | mean_min_distances = mean_per_activity(distances_min, paths, path_fn) 284 | mean_min_distance = np.mean(mean_min_distances.values()) 285 | mean_correct_flips = mean_per_activity(correct_flips, paths, path_fn) 286 | mean_correct_flip = np.mean(mean_correct_flips.values()) 287 | 288 | results_str = 'mean distance %.4f\n' % mean_distance 289 | results_str += 'mean min distance %.4f\n' % mean_min_distance 290 | results_str += 'mean correct flips %.4f\n' % mean_correct_flip 291 | results_str += '%s\n' % format_results_per_activity(mean_distances) 292 | results_str += '%s\n' % format_results_per_activity(mean_min_distances) 293 | results_str += '%s\n' % format_results_per_activity(mean_min_distances) 294 | results_str += '%s\n' % format_results_per_activity(mean_correct_flips) 295 | 296 | print(results_str) 297 | webpage.add_text(results_str) 298 | 299 | results_str = format_results_per_activity2(mean_distances) 300 | print(results_str[0]) 301 | webpage.add_text(results_str[0]) 302 | print(results_str[1]) 303 | webpage.add_text(results_str[1]) 304 | 305 | results_str = format_results_per_activity2(mean_min_distances) 306 | print(results_str[0]) 307 | webpage.add_text(results_str[0]) 308 | print(results_str[1]) 309 | webpage.add_text(results_str[1]) 310 | 311 | # save the website 312 | webpage.save() 313 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import collections 3 | import numpy as np 4 | from options.train_options import TrainOptions 5 | from data import CreateDataLoader 6 | from models import create_model 7 | from util.visualizer import Visualizer 8 | from util.util import Timer 9 | 10 | 11 | if __name__ == '__main__': 12 | opt = TrainOptions().parse() 13 | data_loader = CreateDataLoader(opt) 14 | dataset = data_loader.load_data() 15 | dataset_size = len(data_loader) 16 | print('#training images = %d' % dataset_size) 17 | 18 | model = create_model(opt) 
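# --- Editorial sketch (not part of this repository): a minimal, self-contained
# illustration of the flip-aware keypoint error computed in test_pose.py above,
# assuming predictions and targets are (B, N, 2) NumPy arrays in normalized
# coordinates; the helper names swap_lr and flip_aware_distance are hypothetical.
import numpy as np

def swap_lr(points, correspondences):
    # swap every (left, right) keypoint pair, e.g. correspondences = [(1, 4), (2, 5)]
    swapped = points.copy()
    for a, b in correspondences:
        swapped[:, [a, b]] = points[:, [b, a]]
    return swapped

def flip_aware_distance(pred, target, correspondences):
    # per-sample mean L2 error, taking the smaller of the plain and the
    # left/right-swapped comparison (mirrors compute_mean_distance above)
    d = np.linalg.norm(pred - target, axis=-1).mean(axis=-1)
    d_swap = np.linalg.norm(pred - swap_lr(target, correspondences), axis=-1).mean(axis=-1)
    return np.minimum(d, d_swap)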
19 | model.setup(opt) 20 | visualizer = Visualizer(opt) 21 | iterations = opt.iters_count 22 | 23 | avg_time = collections.deque(maxlen=100) 24 | 25 | while True: 26 | iter_start_time = time.time() 27 | iter_data_time = time.time() 28 | for i, data in enumerate(dataset): 29 | iterations += 1 30 | visualizer.reset() 31 | 32 | model.set_input(data) 33 | t_data = time.time() - iter_data_time 34 | 35 | optim_time = time.time() 36 | model.optimize_parameters() 37 | t_optim = time.time() - optim_time 38 | 39 | if iterations % opt.display_freq == 0: 40 | save_result = iterations % opt.update_html_freq == 0 41 | visualizer.display_current_results( 42 | model.get_current_visuals(), iterations, save_result) 43 | 44 | t = (time.time() - iter_start_time) 45 | iter_start_time = time.time() 46 | avg_time.append(t) 47 | if iterations % opt.print_freq == 0: 48 | losses = model.get_current_losses() 49 | samples_frq = float(opt.batch_size) / t 50 | samples_frq_avg = float(opt.batch_size) / np.mean(avg_time) 51 | prefix = '%.1f samples/sec %.1f samples/sec (avg) %.2f optim ' % ( 52 | samples_frq, samples_frq_avg, t_optim) 53 | visualizer.print_current_losses( 54 | iterations, losses, t, t_data, prefix=prefix) 55 | if opt.display_id > 0: 56 | visualizer.plot_current_losses(iterations, opt, losses) 57 | 58 | if iterations % opt.save_latest_freq == 0: 59 | print('saving the latest model (iterations %d)' % (iterations)) 60 | model.save_networks('latest') 61 | 62 | if iterations % opt.save_iters_freq == 0: 63 | print('saving the model at iters %d' % (iterations)) 64 | model.save_networks('latest') 65 | model.save_networks(iterations) 66 | 67 | iter_data_time = time.time() 68 | 69 | # FIXME: should be called at the end of an epoch 70 | # model.update_learning_rate() 71 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomasjakab/keypointgan/541b769d536dc113fcf6da271ed72ae9d963dbb4/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | from dominate.util import raw 4 | import os 5 | 6 | 7 | 8 | refresh_js = ''' 9 | function refresh_single(refresh_id) { 10 | var source = document.getElementById(refresh_id).src; 11 | var timestamp = (new Date()).getTime(); 12 | var newUrl = source + '?_=' + timestamp; 13 | document.getElementById(refresh_id).src = newUrl; 14 | } 15 | 16 | function refresh(refresh_ids) { 17 | var arrayLength = refresh_ids.length; 18 | for (var i = 0; i < arrayLength; i++) { 19 | console.log(refresh_ids[i]); 20 | refresh_single(refresh_ids[i]); 21 | } 22 | } 23 | ''' 24 | 25 | 26 | class HTML: 27 | def __init__(self, web_dir, title, reflesh=0): 28 | self.title = title 29 | self.web_dir = web_dir 30 | self.img_dir = os.path.join(self.web_dir, 'images') 31 | if not os.path.exists(self.web_dir): 32 | os.makedirs(self.web_dir) 33 | if not os.path.exists(self.img_dir): 34 | os.makedirs(self.img_dir) 35 | # print(self.img_dir) 36 | 37 | self.doc = dominate.document(title=title) 38 | 39 | with self.doc.head: 40 | script(raw(refresh_js)) 41 | 42 | if reflesh > 0: 43 | with self.doc.head: 44 | meta(http_equiv="refresh", content=str(reflesh)) 45 | 46 | with self.doc: 47 | h1(title, style="font-size: x-large;") 48 | 49 | def get_image_dir(self): 50 | return 
self.img_dir 51 | 52 | def add_header(self, str): 53 | with self.doc: 54 | h3(str, style="margin: 0; font-size: medium;") 55 | 56 | def add_text(self, str): 57 | with self.doc: 58 | p(str) 59 | 60 | def add_table(self, border=1): 61 | self.t = table(border=border, style="table-layout: fixed; border: 0px solid black;") 62 | self.doc.add(self.t) 63 | 64 | 65 | def add_media(self, ims, txts, links, type, width=400, title=None): 66 | # add refresh button 67 | with self.doc: 68 | with p(style="margin: 2px 0 0 0;"): 69 | if title is not None: 70 | span(title, style="margin: 0; font-size: medium;") 71 | button("refresh", onclick="refresh(" + str(ims) + ")") 72 | self.add_table() 73 | with self.t: 74 | with tr(): 75 | for im, txt, link in zip(ims, txts, links): 76 | with td(style="word-wrap: break-word; border: 0px solid black;", halign="center", valign="top"): 77 | with a(href=os.path.join('images', link)): 78 | if type == 'image': 79 | self.add_image(im, width) 80 | elif type == 'video': 81 | self.add_video(im, width) 82 | else: 83 | raise ValueError() 84 | p(txt, style="margin: 0;") 85 | 86 | 87 | def add_image(self, im, width): 88 | img(style="width:%dpx" % width, 89 | src=os.path.join('images', im), 90 | id=im) 91 | 92 | def add_video(self, im, width): 93 | with video(style="width:%dpx" % width, id=im, autoplay="", loop="", muted="", inline="", playsinline=""): 94 | source(src=os.path.join('images', im), type="video/mp4") 95 | 96 | def save(self): 97 | html_file = '%s/index.html' % self.web_dir 98 | f = open(html_file, 'wt') 99 | f.write(self.doc.render()) 100 | f.close() 101 | 102 | 103 | if __name__ == '__main__': 104 | html = HTML('web/', 'test_html') 105 | html.add_header('hello world') 106 | 107 | ims = [] 108 | txts = [] 109 | links = [] 110 | for n in range(4): 111 | ims.append('image_%d.png' % n) 112 | txts.append('text_%d' % n) 113 | links.append('image_%d.png' % n) 114 | html.add_images(ims, txts, links) 115 | html.save() 116 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | def __init__(self, pool_size): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images: 17 | image = torch.unsqueeze(image.data, 0) 18 | if self.num_imgs < self.pool_size: 19 | self.num_imgs = self.num_imgs + 1 20 | self.images.append(image) 21 | return_images.append(image) 22 | else: 23 | p = random.uniform(0, 1) 24 | if p > 0.5: 25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 26 | tmp = self.images[random_id].clone() 27 | self.images[random_id] = image 28 | return_images.append(tmp) 29 | else: 30 | return_images.append(image) 31 | return_images = torch.cat(return_images, 0) 32 | return return_images 33 | -------------------------------------------------------------------------------- /util/plotting.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import matplotlib as mpl 4 | mpl.use('Agg') 5 | 6 | import matplotlib.transforms 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import tempfile 10 | import os 11 | import shutil 12 | 13 | from PIL import Image 14 | 15 | 16 | def get_marker_style(i, cmap='Dark2'): 17 | cmap = plt.get_cmap(cmap) 18 | 
colors = [cmap(c) for c in np.linspace(0., 1., 8)] 19 | markers = ['v', 'o', 's', 'd', '^', 'x', '+'] 20 | max_i = len(colors) * len(markers) - 1 21 | if i > max_i: 22 | raise ValueError('Exceeded maximum (' + str(max_i) + ') index for styles.') 23 | c = i % len(colors) 24 | m = int(i / len(colors)) 25 | return colors[c], markers[m] 26 | 27 | 28 | def single_marker_style(color, marker): 29 | return lambda _: (color, marker) 30 | 31 | 32 | def plot_line(ax, a, b, k, size=1.5, zorder=2, cmap='Dark2', 33 | style_fn=None): 34 | if style_fn is None: 35 | c, _ = get_marker_style(k, cmap=cmap) 36 | else: 37 | c, _ = style_fn(k) 38 | line = ax.plot([a[0], b[0]], [a[1], b[1]], c=c, zorder=zorder, linewidth=10) 39 | plt.setp(line, linewidth=5) 40 | 41 | 42 | def plot_lines(ax, lines, size=1.5, zorder=2, cmap='Dark2', style_fn=None): 43 | # TODO: avoid for loop if possible 44 | for k, (a, b) in enumerate(lines): 45 | plot_line(ax, a, b, k, size=size, zorder=zorder, 46 | cmap=cmap, style_fn=style_fn) 47 | 48 | def plot_landmark(ax, landmark, k, size=1.5, zorder=2, cmap='Dark2', 49 | style_fn=None): 50 | if style_fn is None: 51 | c, m = get_marker_style(k, cmap=cmap) 52 | else: 53 | c, m = style_fn(k) 54 | ax.scatter(landmark[1], landmark[0], c=c, marker=m, 55 | s=(size * mpl.rcParams['lines.markersize']) ** 2, 56 | zorder=zorder) 57 | 58 | 59 | def plot_landmarks(ax, landmarks, size=1.5, zorder=2, cmap='Dark2', style_fn=None): 60 | # TODO: avoid for loop if possible 61 | for k, landmark in enumerate(landmarks): 62 | plot_landmark(ax, landmark, k, size=size, zorder=zorder, 63 | cmap=cmap, style_fn=style_fn) 64 | 65 | 66 | def show_landmarks(image, landmarks, save_path, landmark_size=1.5, 67 | style='uniform', color='limegreen', connections=None): 68 | def plt_start(): 69 | fig = plt.figure(figsize=(4, 4), dpi=80) 70 | ax = plt.gca() 71 | return fig, ax 72 | 73 | def plt_finish(ax, fig, path): 74 | ax.set_ylim([sz[0], 0]) 75 | ax.set_xlim([0, sz[1]]) 76 | plt.tight_layout() 77 | plt.autoscale(tight=True) 78 | ax.set_xticklabels([]) 79 | ax.set_yticklabels([]) 80 | ax.set_xticklabels([]) 81 | ax.set_axis_off() 82 | ax.set_aspect('equal') 83 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 84 | plt.savefig(path, bbox_inches=matplotlib.transforms.Bbox.from_extents( 85 | [0, 0, 4, 4]), pad_inches=0) 86 | plt.close(fig.number) 87 | 88 | sz = image.shape[:2] 89 | landmarks_scaled = ((landmarks + 1) / 2.0) * sz 90 | landmarks_scaled = np.clip(landmarks_scaled, 6, sz[0] - 6) 91 | 92 | save_dir = [] 93 | fig, ax = plt_start() 94 | ax.imshow(image) 95 | if style == 'uniform': 96 | ax.scatter(landmarks_scaled[:, 0].T, landmarks_scaled[:, 1].T, c=color, 97 | s=(landmark_size * mpl.rcParams['lines.markersize']) ** 2) 98 | elif style == 'skeleton': 99 | lines = [] 100 | for a, b in connections: 101 | lines.append((landmarks_scaled[a], landmarks_scaled[b])) 102 | plot_lines(ax, lines, size=landmark_size) 103 | else: 104 | plot_landmarks(ax, landmarks_scaled, size=landmark_size) 105 | plt_finish(ax, fig, save_path) 106 | 107 | 108 | def plot_in_image(image, landmarks, landmark_size=1.3, color='limegreen', 109 | style='uniform', connections=None): 110 | tempdir_path = tempfile.mkdtemp() 111 | tempfile_path = os.path.join(tempdir_path, 'im.png') 112 | try: 113 | show_landmarks(image, landmarks, tempfile_path, 114 | landmark_size=landmark_size, color=color, 115 | connections=connections, style=style) 116 | plot = np.array(Image.open(tempfile_path)) 117 | finally: 118 | shutil.rmtree(tempdir_path) 119 | return plot 
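# --- Editorial sketch (not part of plotting.py): a hedged usage example for
# plot_in_image above. The image and landmarks are dummy data; landmarks are
# assumed to lie in [-1, 1], since show_landmarks rescales them to pixel coordinates.
import numpy as np

dummy_image = np.zeros((128, 128, 3), dtype=np.uint8)         # blank RGB canvas
dummy_landmarks = np.random.uniform(-0.8, 0.8, size=(16, 2))  # 16 normalized points
overlay = plot_in_image(dummy_image, dummy_landmarks)         # default 'uniform' scatter style
print(overlay.shape)  # an H x W x C array rendered by matplotlib and re-read via PIL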
120 | 121 | 122 | -------------------------------------------------------------------------------- /util/skeleton.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import truncnorm 3 | 4 | 5 | def render_line_segment(s1, s2, size, distance='gauss', discrete=False): 6 | def sumprod(x, y): 7 | return np.sum(x * y, axis=-1, keepdims=True) 8 | 9 | x = np.linspace(-1.0, 1.0, size).astype('float32') 10 | y = np.linspace(-1.0, 1.0, size).astype('float32') 11 | 12 | xv, yv = np.meshgrid(x, y) 13 | m = np.concatenate([xv[..., None], yv[..., None]], axis=-1) 14 | 15 | s1, s2 = s1[None, None], s2[None, None] 16 | t_min = sumprod(m - s1, s2 - s1) / np.maximum(sumprod(s2 - s1, s2 - s1), 1e-6) 17 | t_line = np.minimum(np.maximum(t_min, 0.0), 1.0) 18 | 19 | s = s1 + t_line * (s2 - s1) 20 | d = np.sqrt(sumprod(s - m, s - m)) 21 | 22 | if discrete: 23 | distance = 'norm' 24 | 25 | # normalize distance 26 | if distance == 'gauss': 27 | d_norm = np.exp(-d / (0.2 ** 2)) 28 | elif distance == 'norm': 29 | d_max = np.sqrt(8) 30 | d_norm = (d_max - d) / d_max 31 | else: 32 | raise ValueError() 33 | 34 | thick = 0.9925 35 | if discrete: 36 | d_norm[d_norm >= thick] = 1.0 37 | d_norm[d_norm < thick] = 0.0 38 | 39 | return d_norm 40 | 41 | 42 | def render_skeleton(points, connections, width, height, colored=False): 43 | assert width == height 44 | maps = [] 45 | numbers = np.linspace(0.2, 1.0, len(connections)) 46 | discrete = False 47 | if colored: 48 | discrete = True 49 | for (a, b), number in zip(connections, numbers): 50 | render = render_line_segment( 51 | points[a], points[b], width, discrete=discrete) 52 | if colored: 53 | render *= number 54 | maps.append(render) 55 | maps = np.concatenate(maps, axis=-1) 56 | return maps 57 | 58 | 59 | def normalize_landmarks(landmarks): 60 | """ 61 | Centre and stretch landmarks, preserve aspect ratio 62 | Landmarks are [[y_0, x_0], [y_1, x_1], ...] 
63 | """ 64 | ymin, ymax = min(landmarks[:, 0]), max(landmarks[:, 0]) 65 | xmin, xmax = min(landmarks[:, 1]), max(landmarks[:, 1]) 66 | 67 | # put in the corner 68 | landmarks -= np.min(landmarks, axis=0, keepdims=True) 69 | # normalize between -1, 1 70 | height, width = np.max(landmarks, axis=0) 71 | landmarks = 2.0 * (landmarks / max(height, width)) - 1.0 72 | # centre 73 | landmarks += (1.0 - np.max(landmarks, axis=0, keepdims=True)) / 2.0 74 | return landmarks 75 | 76 | 77 | def rotate_points(points, angle): 78 | rot = np.deg2rad(angle) 79 | af = [[ np.cos(rot), np.sin(rot), 0], 80 | [-np.sin(rot), np.cos(rot), 0]] 81 | af = np.array(af, dtype=np.float32) 82 | ones = np.ones((points.shape[0], 1), dtype=np.float32) 83 | points = np.concatenate([points, ones], axis=1) 84 | points = np.matmul(points, af.T) 85 | return points 86 | 87 | 88 | def jitter_landmarks(landmarks, zoom=[0.5, 1.0], aspect_ratio=[1.0, 1.0], 89 | shift=True, rotate=[0.0, 0.0]): 90 | """ 91 | expects points normalized in [-1, 1] 92 | """ 93 | # rotate and refit to the canvas 94 | if rotate != [0.0, 0.0]: 95 | angle = np.random.uniform(rotate[0], rotate[1]) 96 | landmarks = rotate_points(landmarks, angle) 97 | landmarks = normalize_landmarks(landmarks) 98 | 99 | # zoom 100 | if zoom != [1.0, 1.0]: 101 | # generate random number between 0.0 and 1.0 (1.0 has higher probability) 102 | rand = 1 + truncnorm.rvs(-1.0, 0.0) 103 | zoom_ratio = zoom[0] + (zoom[1] - zoom[0]) * rand 104 | landmarks *= zoom_ratio 105 | 106 | # aspect ratio 107 | if aspect_ratio != [1.0, 1.0]: 108 | landmarks[:, 0] *= np.random.uniform(aspect_ratio[0], aspect_ratio[1]) 109 | 110 | # shift but keep all in the canvas 111 | if shift: 112 | shift_y = [-1 - min(landmarks[:, 0]), 1 - max(landmarks[:, 0])] 113 | shift_x = [-1 - min(landmarks[:, 1]), 1 - max(landmarks[:, 1])] 114 | landmarks[:, 0] += np.random.uniform(shift_y[0], shift_y[1]) 115 | landmarks[:, 1] += np.random.uniform(shift_x[0], shift_x[1]) 116 | 117 | return landmarks 118 | 119 | 120 | def pad_landmarks(landmarks, ratio): 121 | """ 122 | landmarks normalize in [-1, 1] 123 | ratio in [0, 1] 124 | """ 125 | return landmarks * (1.0 / (1 + 2 * ratio)) 126 | 127 | 128 | def landmarks_to_image_space(landmarks, height, width): 129 | landmarks = ((landmarks + 1.0) / 2) * min(height, width) 130 | return landmarks 131 | -------------------------------------------------------------------------------- /util/tps_sampler.py: -------------------------------------------------------------------------------- 1 | # ========================================================== 2 | # Author: Ankush Gupta, Tomas Jakab 3 | # ========================================================== 4 | import scipy.spatial.distance as ssd 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import random 10 | 11 | 12 | class TPSRandomSampler(nn.Module): 13 | 14 | def __init__(self, height, width, vertical_points=10, horizontal_points=10, 15 | rotsd=0.0, scalesd=0.0, transsd=0.1, warpsd=(0.001, 0.005), 16 | cache_size=1000, cache_evict_prob=0.01, pad=True, device=None): 17 | super(TPSRandomSampler, self).__init__() 18 | 19 | self.input_height = height 20 | self.input_width = width 21 | 22 | self.h_pad = 0 23 | self.w_pad = 0 24 | if pad: 25 | self.h_pad = self.input_height // 2 26 | self.w_pad = self.input_width // 2 27 | 28 | self.height = self.input_height + 2 * self.h_pad 29 | self.width = self.input_width + 2 * self.w_pad 30 | 31 | self.vertical_points = vertical_points 32 | 
self.horizontal_points = horizontal_points 33 | 34 | self.rotsd = rotsd 35 | self.scalesd = scalesd 36 | self.transsd = transsd 37 | self.warpsd = warpsd 38 | self.cache_size = cache_size 39 | self.cache_evict_prob = cache_evict_prob 40 | 41 | self.tps = TPSGridGen( 42 | self.height, self.width, vertical_points, horizontal_points) 43 | 44 | self.cache = [None] * self.cache_size 45 | 46 | self.pad = pad 47 | 48 | self.device = device 49 | 50 | 51 | def _sample_grid(self): 52 | W = sample_tps_w( 53 | self.vertical_points, self.horizontal_points, self.warpsd, 54 | self.rotsd, self.scalesd, self.transsd) 55 | W = torch.from_numpy(W.astype(np.float32)) 56 | # generate grid 57 | grid = self.tps(W[None]) 58 | return grid 59 | 60 | 61 | def _get_grids(self, batch_size): 62 | grids = [] 63 | for i in range(batch_size): 64 | entry = random.randint(0, self.cache_size - 1) 65 | if self.cache[entry] is None or random.random() < self.cache_evict_prob: 66 | grid = self._sample_grid() 67 | if self.device is not None: 68 | grid = grid.to(self.device) 69 | self.cache[entry] = grid 70 | else: 71 | grid = self.cache[entry] 72 | grids.append(grid) 73 | grids = torch.cat(grids) 74 | return grids 75 | 76 | 77 | def forward(self, input): 78 | if self.device is not None: 79 | input_device = input.device 80 | input = input.to(self.device) 81 | 82 | # get TPS grids 83 | batch_size = input.size(0) 84 | grids = self._get_grids(batch_size) 85 | 86 | if self.device is None: 87 | grids = grids.to(input.device) 88 | 89 | input = F.pad(input, (self.h_pad, self.h_pad, self.w_pad, 90 | self.w_pad), mode='replicate') 91 | input = F.grid_sample(input, grids) 92 | input = F.pad(input, (-self.h_pad, -self.h_pad, -self.w_pad, -self.w_pad)) 93 | 94 | if self.device is not None: 95 | input = input.to(input_device) 96 | 97 | return input 98 | 99 | 100 | def forward_py(self, input): 101 | with torch.no_grad(): 102 | input = torch.from_numpy(input) 103 | input = input.permute([0, 3, 1, 2]) 104 | input = self.forward(input) 105 | input = input.permute([0, 2, 3, 1]) 106 | input = input.numpy() 107 | return input 108 | 109 | 110 | 111 | class TPSGridGen(nn.Module): 112 | 113 | def __init__(self, Ho, Wo, Hc, Wc): 114 | """ 115 | Ho,Wo: height/width of the output tensor (grid dimensions). 116 | Hc,Wc: height/width of the control-point grid. 117 | 118 | Assumes for simplicity that the control points lie on a regular grid. 119 | Can be made more general. 
120 | """ 121 | super(TPSGridGen, self).__init__() 122 | 123 | self._grid_hw = (Ho, Wo) 124 | self._cp_hw = (Hc, Wc) 125 | 126 | # initialize the grid: 127 | xx, yy = np.meshgrid(np.linspace(-1, 1, Wo), np.linspace(-1, 1, Ho)) 128 | self._grid = np.c_[xx.flatten(), yy.flatten()].astype(np.float32) # Nx2 129 | self._n_grid = self._grid.shape[0] 130 | 131 | # initialize the control points: 132 | xx, yy = np.meshgrid(np.linspace(-1, 1, Wc), np.linspace(-1, 1, Hc)) 133 | self._control_pts = np.c_[ 134 | xx.flatten(), yy.flatten()].astype(np.float32) # Mx2 135 | self._n_cp = self._control_pts.shape[0] 136 | 137 | # compute the pair-wise distances b/w control-points and grid-points: 138 | Dx = ssd.cdist(self._grid, self._control_pts, metric='sqeuclidean') # NxM 139 | 140 | # create the tps kernel: 141 | # real_min = 100 * np.finfo(np.float32).min 142 | real_min = 1e-8 143 | Dx = np.clip(Dx, real_min, None) # avoid log(0) 144 | Kp = np.log(Dx) * Dx 145 | Os = np.ones((self._grid.shape[0])) 146 | L = np.c_[Kp, np.ones((self._n_grid, 1), dtype=np.float32), 147 | self._grid] # Nx(M+3) 148 | self._L = torch.from_numpy(L.astype(np.float32)) # Nx(M+3) 149 | 150 | 151 | def forward(self, w_tps): 152 | """ 153 | W_TPS: Bx(M+3)x2 sized tensor of tps-transformation params. 154 | here `M` is the number of control-points. 155 | `B` is the batch-size. 156 | 157 | Returns an BxHoxWox2 tensor of grid coordinates. 158 | """ 159 | assert w_tps.shape[1] - 3 == self._n_cp 160 | batch_size = w_tps.shape[0] 161 | tfm_grid = torch.matmul(self._L, w_tps) 162 | tfm_grid = tfm_grid.reshape( 163 | (batch_size, self._grid_hw[0], self._grid_hw[1], 2)) 164 | return tfm_grid 165 | 166 | 167 | 168 | def sample_tps_w(Hc, Wc, warpsd, rotsd, scalesd, transsd): 169 | """ 170 | Returns randomly sampled TPS-grid params of size (Hc*Wc+3)x2. 171 | 172 | Params: 173 | WARPSD: 2-tuple 174 | {ROT/SCALE/TRANS}-SD: 1-tuple of standard devs. 
175 | """ 176 | Nc = Hc * Wc # no of control-pots 177 | # non-linear component: 178 | mask = (np.random.rand(Nc, 2) > 0.5).astype(np.float32) 179 | W = warpsd[0] * np.random.randn(Nc, 2) + \ 180 | warpsd[1] * (mask * np.random.randn(Nc, 2)) 181 | # affine component: 182 | rnd = np.random.randn 183 | rot = np.deg2rad(rnd() * rotsd) 184 | sc = 1.0 + rnd() * scalesd 185 | aff = [[transsd*rnd(), transsd*rnd()], 186 | [sc * np.cos(rot), sc * -np.sin(rot)], 187 | [sc * np.sin(rot), sc * np.cos(rot)]] 188 | W = np.r_[W, aff] 189 | return W 190 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | import time 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = (image_numpy - np.min(image_numpy)) / (np.max(image_numpy) - np.min(image_numpy)) 18 | image_numpy = image_numpy * 2 - 1 19 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 21 | image_numpy = np.clip(image_numpy, 0.0, 255.0) 22 | return image_numpy.astype(imtype) 23 | 24 | 25 | def diagnose_network(net, name='network'): 26 | mean = 0.0 27 | count = 0 28 | for param in net.parameters(): 29 | if param.grad is not None: 30 | mean += torch.mean(torch.abs(param.grad.data)) 31 | count += 1 32 | if count > 0: 33 | mean = mean / count 34 | print(name) 35 | print(mean) 36 | 37 | 38 | def save_image(image_numpy, image_path): 39 | image_pil = Image.fromarray(image_numpy) 40 | image_pil.save(image_path) 41 | 42 | 43 | def print_numpy(x, val=True, shp=False): 44 | x = x.astype(np.float64) 45 | if shp: 46 | print('shape,', x.shape) 47 | if val: 48 | x = x.flatten() 49 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 50 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 51 | 52 | 53 | def mkdirs(paths): 54 | if isinstance(paths, list) and not isinstance(paths, str): 55 | for path in paths: 56 | mkdir(path) 57 | else: 58 | mkdir(paths) 59 | 60 | 61 | def mkdir(path): 62 | if not os.path.exists(path): 63 | os.makedirs(path) 64 | 65 | 66 | def isclose(a, b, rel_tol=1e-09, abs_tol=0.0): 67 | return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) 68 | 69 | class Timer(object): 70 | def __init__(self, name=None, acc=False, avg=False): 71 | self.name = name 72 | self.acc = acc 73 | self.avg = avg 74 | self.total = 0.0 75 | self.iters = 0 76 | 77 | def __enter__(self): 78 | self.start() 79 | 80 | def __exit__(self, type, value, traceback): 81 | self.stop() 82 | 83 | def start(self): 84 | self.tstart = time.time() 85 | 86 | def stop(self): 87 | self.iters += 1 88 | self.total += time.time() - self.tstart 89 | if not self.acc: 90 | self.reset() 91 | 92 | def reset(self): 93 | name_string = '' 94 | if self.name: 95 | name_string = '[' + self.name + '] ' 96 | value = self.total 97 | msg = 'Elapsed' 98 | if self.avg: 99 | value /= self.iters 100 | msg = 'Avg Elapsed' 101 | print('%s%s: %.4f' % (name_string, msg, value)) 102 | self.total = 0.0 103 | 
-------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | import torch 7 | import re 8 | from . import util 9 | from . import html 10 | from skimage.transform import resize 11 | from PIL import Image 12 | from models.utils import normalize_image_tensor 13 | from collections import OrderedDict 14 | import skvideo.io 15 | 16 | 17 | if sys.version_info[0] == 2: 18 | VisdomExceptionBase = Exception 19 | else: 20 | VisdomExceptionBase = ConnectionError 21 | 22 | 23 | # save image to the disk 24 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, 25 | basename=True): 26 | # preprocess multichannel maps 27 | for label, image in visuals.items(): 28 | if image.shape[1] not in [1, 3] and isinstance(image, torch.Tensor): 29 | max_frames = 20 30 | for i in range(min(max_frames, image.shape[1])): 31 | visuals[label + str(i)] = image[:, i][:, None] 32 | visuals[label], _ = torch.max(image, dim=1, keepdim=True) 33 | 34 | image_dir = webpage.get_image_dir() 35 | if basename: 36 | short_path = ntpath.basename(image_path[0]) 37 | name = os.path.splitext(short_path)[0] 38 | else: 39 | name = image_path[0] 40 | 41 | ims, txts, links = [], [], [] 42 | 43 | for label, im_data in visuals.items(): 44 | if label == 'fake_B': 45 | tensor = np.transpose(im_data[0].cpu().numpy(), (1, 2, 0)) 46 | tensor_name = '%s-%s.npy' % (name, label) 47 | np.save(os.path.join(image_dir, tensor_name), tensor) 48 | 49 | im = util.tensor2im(im_data) 50 | image_name = '%s-%s.png' % (name, label) 51 | save_path = os.path.join(image_dir, image_name) 52 | h, w, _ = im.shape 53 | if aspect_ratio > 1.0: 54 | im = resize(im, (h, int(w * aspect_ratio)), interp='bicubic') 55 | if aspect_ratio < 1.0: 56 | im = resize(im, (int(h / aspect_ratio), w), interp='bicubic') 57 | util.save_image(im, save_path) 58 | 59 | ims.append(image_name) 60 | txts.append(label) 61 | links.append(image_name) 62 | webpage.add_media(ims, txts, links, 'image', width=width, title=name) 63 | 64 | 65 | def save_videos(webpage, visuals_log, image_path, width=256, basename=True): 66 | image_dir = webpage.get_image_dir() 67 | if basename: 68 | short_path = ntpath.basename(image_path[0]) 69 | name = os.path.splitext(short_path)[0] 70 | else: 71 | name = image_path[0] 72 | 73 | ims, txts, links = [], [], [] 74 | 75 | for label in visuals_log.keys(): 76 | video_name = '%s-%s.mp4' % (name, label) 77 | save_path = os.path.join(image_dir, video_name) 78 | visuals_log.save_as_video(label, save_path) 79 | ims.append(video_name) 80 | txts.append(label) 81 | links.append(video_name) 82 | 83 | webpage.add_media(ims, txts, links, 'video', width=width, title=name) 84 | 85 | 86 | class Visualizer(): 87 | def __init__(self, opt): 88 | self.display_id = opt.display_id 89 | self.use_html = opt.isTrain and not opt.no_html 90 | self.win_size = opt.display_winsize 91 | self.name = opt.name 92 | self.opt = opt 93 | self.saved = False 94 | if self.display_id > 0: 95 | import visdom 96 | self.ncols = opt.display_ncols 97 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 98 | 99 | if self.use_html: 100 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 101 | self.img_dir = os.path.join(self.web_dir, 'images') 102 | print('create web directory %s...' 
% self.web_dir) 103 | util.mkdirs([self.web_dir, self.img_dir]) 104 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 105 | with open(self.log_name, "a") as log_file: 106 | now = time.strftime("%c") 107 | log_file.write('================ Training Loss (%s) ================\n' % now) 108 | 109 | def reset(self): 110 | self.saved = False 111 | 112 | def throw_visdom_connection_error(self): 113 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') 114 | exit(1) 115 | 116 | # |visuals|: dictionary of images to display or save 117 | def display_current_results(self, visuals, iteration, save_result): 118 | visuals = preprocess_multi_channel(visuals) 119 | 120 | if self.display_id > 0: # show images in the browser 121 | ncols = self.ncols 122 | if ncols > 0: 123 | ncols = min(ncols, len(visuals)) 124 | h, w = next(iter(visuals.values())).shape[:2] 125 | table_css = """""" % (w, h) 129 | title = self.name 130 | label_html = '' 131 | label_html_row = '' 132 | images = [] 133 | idx = 0 134 | for label, image in visuals.items(): 135 | image_numpy = util.tensor2im(image) 136 | label_html_row += '%s' % label 137 | images.append(image_numpy.transpose([2, 0, 1])) 138 | idx += 1 139 | if idx % ncols == 0: 140 | label_html += '%s' % label_html_row 141 | label_html_row = '' 142 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 143 | while idx % ncols != 0: 144 | images.append(white_image) 145 | label_html_row += '' 146 | idx += 1 147 | if label_html_row != '': 148 | label_html += '%s' % label_html_row 149 | # pane col = image row 150 | try: 151 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 152 | padding=2, opts=dict(title=title + ' images')) 153 | label_html = '%s
' % label_html 154 | self.vis.text(table_css + label_html, win=self.display_id + 2, 155 | opts=dict(title=title + ' labels')) 156 | except VisdomExceptionBase: 157 | self.throw_visdom_connection_error() 158 | 159 | else: 160 | idx = 1 161 | for label, image in visuals.items(): 162 | image_numpy = util.tensor2im(image) 163 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 164 | win=self.display_id + idx) 165 | idx += 1 166 | 167 | if self.use_html and (save_result or not self.saved): # save images to a html file 168 | self.saved = True 169 | for label, image in visuals.items(): 170 | image_numpy = util.tensor2im(image) 171 | img_path = os.path.join(self.img_dir, 'latest_%s.png' % (label)) 172 | util.save_image(image_numpy, img_path) 173 | if iteration % self.opt.save_iters_freq == 0: 174 | img_path = os.path.join( 175 | self.img_dir, 'iteration%.7d_%s.png' % (iteration, label)) 176 | util.save_image(image_numpy, img_path) 177 | 178 | # find saved images 179 | img = os.listdir(self.img_dir) 180 | reg = re.compile('(iteration0*[0-9]+)*') 181 | saved_iterations = set() 182 | for img in os.listdir(self.img_dir): 183 | match = reg.match(img).group(1) 184 | if match: 185 | saved_iterations.add(match) 186 | saved_iterations = list(saved_iterations) 187 | saved_iterations = sorted(saved_iterations, reverse=True) 188 | prefixes = ['latest'] + saved_iterations 189 | # update website 190 | webpage = html.HTML(self.web_dir, '%s' % self.name, reflesh=0) 191 | for prefix in prefixes: 192 | ims, txts, links = [], [], [] 193 | for label, _ in visuals.items(): 194 | # convert images to png 195 | full_img_path_png = os.path.join( 196 | webpage.get_image_dir(), '%s_%s.png' % (prefix, label)) 197 | full_img_path_jpg = os.path.join( 198 | webpage.get_image_dir(), '%s_%s.jpg' % (prefix, label)) 199 | if os.path.isfile(full_img_path_jpg) and not os.path.isfile(full_img_path_png): 200 | image = Image.open(full_img_path_jpg) 201 | image.save(full_img_path_png) 202 | img_path = '%s_%s.png' % (prefix, label) 203 | ims.append(img_path) 204 | txts.append(label) 205 | links.append(img_path) 206 | webpage.add_media(ims, txts, links, 'image', width=self.win_size, title='%s' % prefix) 207 | webpage.save() 208 | 209 | # losses: dictionary of error labels and values 210 | def plot_current_losses(self, iteration, opt, losses): 211 | if not hasattr(self, 'plot_data'): 212 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 213 | self.plot_data['X'].append(iteration) 214 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 215 | try: 216 | self.vis.line( 217 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 218 | Y=np.array(self.plot_data['Y']), 219 | opts={ 220 | 'title': self.name + ' loss over time', 221 | 'legend': self.plot_data['legend'], 222 | 'xlabel': 'epoch', 223 | 'ylabel': 'loss'}, 224 | win=self.display_id) 225 | except VisdomExceptionBase: 226 | self.throw_visdom_connection_error() 227 | 228 | # losses: same format as |losses| of plot_current_losses 229 | def print_current_losses(self, iteration, losses, t, t_data, prefix='', txt=None): 230 | message = prefix + '(iters: %d, time: %.3f, data: %.3f) ' % ( 231 | iteration, t, t_data) 232 | for k, v in losses.items(): 233 | message += '%s: %.3f ' % (k, v) 234 | if txt is not None: 235 | message += ' ' + txt 236 | print(message) 237 | with open(self.log_name, "a") as log_file: 238 | log_file.write('%s\n' % message) 239 | 240 | 241 | def preprocess_multi_channel(visuals): 
242 | # preprocess multichannel maps 243 | new_visuals = {} 244 | for label, image in visuals.items(): 245 | if image.shape[1] not in [1, 3] or len(image.shape) == 5: 246 | max_frames = 20 247 | for i in range(min(max_frames, image.shape[1])): 248 | if len(image.shape) == 5: 249 | frame = image[:, i] 250 | else: 251 | frame = normalize_image_tensor(image[:, i][:, None]) 252 | new_visuals[label + '_' + str(i)] = frame 253 | # visuals[label], _ = torch.max(image, dim=1, keepdim=True) 254 | else: 255 | new_visuals[label] = image 256 | return new_visuals 257 | 258 | 259 | class VisualsLog(object): 260 | 261 | def __init__(self): 262 | self.log = OrderedDict() 263 | 264 | 265 | def append(self, visuals): 266 | visuals = preprocess_multi_channel(visuals) 267 | for name, visual in visuals.items(): 268 | visual = visual.to('cpu') 269 | if name not in self.log: 270 | self.log[name] = [] 271 | self.log[name] += [visual] 272 | 273 | 274 | def save_as_video(self, visual_name, path): 275 | images = self.log[visual_name] 276 | images = [util.tensor2im(x) for x in images] 277 | skvideo.io.vwrite(path, images, outputdict={'-crf': '1', '-pix_fmt': 'yuv420p'}) 278 | 279 | def keys(self): 280 | return self.log.keys() 281 | --------------------------------------------------------------------------------
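As a closing illustration (an editorial addition, not code from the repository), the sketch below shows one plausible way to apply the TPSRandomSampler from util/tps_sampler.py to a batch of images; the resolution and warp parameters are illustrative assumptions rather than values taken from the provided configs.

import torch
from util.tps_sampler import TPSRandomSampler

# build a sampler that pads, warps with a random thin-plate spline, then crops back
sampler = TPSRandomSampler(
    height=128, width=128,
    rotsd=5.0, scalesd=0.05, transsd=0.05,  # small random rotation/scale/translation
    warpsd=(0.0005, 0.005),                 # magnitude of the non-linear warp
    pad=True)

images = torch.rand(4, 3, 128, 128)         # dummy NCHW batch in [0, 1]
warped = sampler(images)                    # randomly warped, same shape as the input
print(warped.shape)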