├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── base_dataset.cpython-37.pyc │ ├── data_loader.cpython-37.pyc │ ├── pose_transfer_parsing_dataset.cpython-37.pyc │ └── pose_transfer_parsing_market_dataset.cpython-37.pyc ├── base_dataset.py ├── data_loader.py ├── pose_transfer_parsing_dataset.py └── pose_transfer_parsing_market_dataset.py ├── eval_deepfashion.sh ├── eval_market.sh ├── imgs └── pipeline_all-1.png ├── models ├── SPG_net_deepfashion.py ├── SPG_net_market.py ├── __init__.py ├── __pycache__ │ ├── SPG_net_deepfashion.cpython-37.pyc │ ├── SPG_net_market.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── base_model.cpython-37.pyc │ ├── flow_regression_model.cpython-37.pyc │ ├── modules.cpython-37.pyc │ ├── networks.cpython-37.pyc │ ├── normalization.cpython-37.pyc │ └── pose_transfer_model.cpython-37.pyc ├── base_model.py ├── flow_regression_model.py ├── modules.py ├── networks.py ├── normalization.py └── pose_transfer_model.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── base_options.cpython-37.pyc │ └── pose_transfer_options.cpython-37.pyc ├── base_options.py └── pose_transfer_options.py ├── requirements.txt ├── scripts ├── test_pose_transfer_model.py └── train_pose_transfer_model.py ├── test_deepfashion.sh ├── test_market.sh ├── tools ├── __init__.py ├── __pycache__ │ └── cmd.cpython-37.pyc ├── calPCKH_fashion.py ├── calPCKH_market.py ├── cmd.py ├── compute_coordinates.py ├── metrics_deepfashion.py ├── metrics_market.py └── pose_utils.py ├── train_deepfashion.sh ├── train_market.sh └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── flow_util.cpython-37.pyc ├── io.cpython-37.pyc ├── loss_buffer.cpython-37.pyc ├── pose_util.cpython-37.pyc └── visualizer.cpython-37.pyc ├── flow_util.py ├── image_pool.py ├── io.py ├── loss_buffer.py ├── pose_util.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | # [Learning Semantic Person Image Generation by Region-Adaptive Normalization](https://arxiv.org/pdf/2104.06650.pdf) 2 | 3 | The source code for our paper "Learning Semantic Person Image Generation by Region-Adaptive Normalization" (CVPR 2021) 4 | 5 | ![network](./imgs/pipeline_all-1.png) 6 | 7 | ## Quick Start 8 | 9 | ### Installation 10 | 11 | **Clone this repo** 12 | 13 | ``` 14 | git clone https://github.com/cszy98/SPGNet.git. 15 | cd SPGNet 16 | ``` 17 | 18 | **Prerequisites** 19 | 20 | - python3.7 21 | - pytorch1.2.0 + torchvision0.4.0 22 | - numpy 23 | - opencv 24 | - tqdm 25 | 26 | Create environment and install dependencies: 27 | 28 | ``` 29 | # 1. Create a conda virtual environment. 30 | conda create -n spgnet python=3.7 anaconda 31 | source activate spgnet 32 | 33 | # 2. Install dependency 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ### Data Preparation 38 | 39 | The DeepFashion and Market-1501 datasets can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1TR9hcabKA94PZA7cj5g7nodKEbzpyg1e?usp=sharing). 40 | 41 | ### Testing and Evaluate 42 | 43 | The pretrained models can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1TR9hcabKA94PZA7cj5g7nodKEbzpyg1e?usp=sharing). 
44 | 
45 | **Test on DeepFashion**
46 | 
47 | ```
48 | python scripts/test_pose_transfer_model.py --id deepfashion --gpu_ids 0 --dataset_name deepfashion --which_model_G dual_unet --G_feat_warp 1 --G_vis_mode residual --pretrained_flow_id FlowReg_deepfashion --pretrained_flow_epoch best --dataset_type pose_transfer_parsing --which_epoch latest --batch_size 4 --save_output --output_dir output
49 | ```
50 | 
51 | **Test on Market-1501**
52 | 
53 | ```
54 | python scripts/test_pose_transfer_model.py --id market --gpu_ids 0 --dataset_name market --which_model_G dual_unet --G_feat_warp 1 --G_vis_mode residual --pretrained_flow_id FlowReg_market --pretrained_flow_epoch best --dataset_type pose_transfer_parsing_market --which_epoch latest --batch_size 1 --save_output --output_dir output
55 | ```
56 | 
57 | **Evaluate**
58 | 
59 | Run ``eval_deepfashion.sh`` and ``eval_market.sh`` to evaluate LPIPS and FID on DeepFashion and Market-1501, respectively.
60 | 
61 | To evaluate PCKh, download the pose estimator from [GoogleDrive](https://drive.google.com/file/d/1Y1WWYKUhCnei2dFxf8gj9lKJswR73arh/view?usp=sharing) and put it under the root folder. Then change the path in `tools/compute_coordinates.py` and launch ``python2 compute_coordinates.py``. After that, launch `python tools/calPCKH_market.py` or `python tools/calPCKH_fashion.py` to get the PCKh score (the full sequence is sketched in the code block at the end of this README). Please refer to [Pose-Transfer](https://github.com/tengteng95/Pose-Transfer#evaluation) for more details.
62 | 
63 | ### Training
64 | 
65 | **1. Train on DeepFashion**
66 | 
67 | ```
68 | python scripts/train_pose_transfer_model.py --id deepfashion --gpu_ids 0,1 --dataset_name deepfashion --which_model_G dual_unet --G_feat_warp 1 --G_vis_mode residual --pretrained_flow_id FlowReg_deepfashion --pretrained_flow_epoch best --dataset_type pose_transfer_parsing --check_grad_freq 3000 --batch_size 4 --n_epoch 45
69 | ```
70 | **2. Train on Market-1501**
71 | 
72 | ```
73 | python scripts/train_pose_transfer_model.py --id marketest --gpu_ids 0,1,2,3 --dataset_name market --which_model_G dual_unet --G_feat_warp 1 --G_vis_mode residual --pretrained_flow_id FlowReg_market --pretrained_flow_epoch best --dataset_type pose_transfer_parsing_market --check_grad_freq 3000 --batch_size 32 --n_epoch 10
74 | ```
75 | ## Citation
76 | 
77 | If you find our work useful in your research or publication, please cite:
78 | 
79 |     @inproceedings{lv2021learning,
80 |       title     = {Learning Semantic Person Image Generation by Region-Adaptive Normalization},
81 |       author    = {Lv, Zhengyao and Li, Xiaoming and Li, Xin and Li, Fu and Lin, Tianwei and He, Dongliang and Zuo, Wangmeng},
82 |       booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
83 |       year      = {2021}
84 |     }
85 | ## Acknowledgments
86 | 
87 | This code borrows heavily from [intrinsic_flow](https://github.com/ly015/intrinsic_flow) and [Pose-Transfer](https://github.com/tengteng95/Pose-Transfer).
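
For convenience, the PCKh evaluation steps from the **Evaluate** section above, collected into one shell sketch. This is only a sketch, not an official script: it assumes you run from the repository root, that the downloaded pose estimator has been placed in the root folder, and that the paths inside `tools/compute_coordinates.py` have already been edited to point at your generated images (the `tools/` prefix follows the directory layout shown at the top of this dump; adjust the invocation if the keypoint estimator must be run from inside `tools/`).

```
# 1. Estimate keypoints of the generated images
#    (the estimator runs under Python 2, as noted above)
python2 tools/compute_coordinates.py

# 2. Compute PCKh from the estimated keypoints
python tools/calPCKH_fashion.py   # DeepFashion
python tools/calPCKH_market.py    # Market-1501
```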
88 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__init__.py -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__pycache__/data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/pose_transfer_parsing_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__pycache__/pose_transfer_parsing_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/pose_transfer_parsing_market_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/data/__pycache__/pose_transfer_parsing_market_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import torch.utils.data as data 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 6 | 7 | ##################################### 8 | # BaseDataset Class 9 | ##################################### 10 | 11 | class BaseDataset(data.Dataset): 12 | def __init__(self): 13 | super(BaseDataset, self).__init__() 14 | 15 | def name(self): 16 | return 'BaseDataset' 17 | 18 | def initialize(self, opt): 19 | pass 20 | 21 | ##################################### 22 | # Image Transform Modules 23 | ##################################### 24 | def kp_to_map(img_sz, kps, mode='gaussian', radius=5): 25 | ''' 26 | Keypoint cordinates to heatmap map. 
27 | Input: 28 | img_size (w,h): size of heatmap 29 | kps (N,2): (x,y) cordinates of N keypoints 30 | mode: 'gaussian' or 'binary' 31 | radius: radius of each keypoints in heatmap 32 | Output: 33 | m (h,w,N): encoded heatmap 34 | ''' 35 | w, h = img_sz 36 | x_grid, y_grid = np.meshgrid(range(w), range(h), indexing = 'xy') 37 | m = [] 38 | for x, y in kps: 39 | if x == -1 or y == -1: 40 | m.append(np.zeros((h, w)).astype(np.float32)) 41 | else: 42 | if mode == 'gaussian': 43 | m.append(np.exp(-((x_grid - x)**2 + (y_grid - y)**2)/(radius**2)).astype(np.float32)) 44 | elif mode == 'binary': 45 | m.append(((x_grid-x)**2 + (y_grid-y)**2 <= radius**2).astype(np.float32)) 46 | else: 47 | raise NotImplementedError() 48 | m = np.stack(m, axis=2) 49 | return m 50 | 51 | 52 | def seg_label_to_map(seg_label, nc = 7, bin_size=1): 53 | ''' 54 | Input: 55 | seg_label: (H,W), 2D segmentation class label 56 | nc: number of classes 57 | bin_size: filter isolate pixels which is likely to be noise 58 | Output: 59 | seg_map: (H,W,nc) 60 | ''' 61 | seg_map = [(seg_label == i) for i in range(nc)] 62 | seg_map = np.concatenate(seg_map, axis=2).astype(np.float32) 63 | if bin_size > 1: 64 | h, w = seg_map.shape[0:2] 65 | dh, dw = h//bin_size, w//bin_size 66 | seg_map = cv2.resize(seg_map, dsize=(dw,dh), interpolation=cv2.INTER_LINEAR) 67 | seg_map = cv2.resize(seg_map, dsize=(w, h), interpolation=cv2.INTER_NEAREST) 68 | return seg_map 69 | ############################################################################### 70 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | 3 | # Todo: disentangle data-related parameters from model options 4 | 5 | def CreateDataLoader(opt, split='test'): 6 | # loader = CustomDataLoader() 7 | # loader.initialize(opt) 8 | # return loader 9 | 10 | dataset = CreateDataset(opt, split) 11 | shuffle = (split == 'train' and opt.is_train) 12 | drop_last = opt.is_train 13 | dataloader = torch.utils.data.DataLoader( 14 | dataset = dataset, 15 | batch_size = opt.batch_size, 16 | shuffle = shuffle, 17 | num_workers = 8, 18 | drop_last = drop_last, 19 | pin_memory = False) 20 | return dataloader 21 | 22 | def CreateDataset(opt, split): 23 | dataset = None 24 | if opt.dataset_type == 'pose_transfer_parsing': 25 | from data.pose_transfer_parsing_dataset import PoseTransferParsingDataset as DatasetClass 26 | elif opt.dataset_type == 'pose_transfer_parsing_market': 27 | from data.pose_transfer_parsing_market_dataset import PoseTransferParsingPredDataset as DatasetClass 28 | else: 29 | raise ValueError('Dataset mode [%s] not recognized.' % opt.dataset_type) 30 | 31 | dataset = DatasetClass() 32 | dataset.initialize(opt, split) 33 | print('Dataset [%s] was created (size: %d).' 
% (dataset.name(), len(dataset))) 34 | 35 | return dataset 36 | 37 | -------------------------------------------------------------------------------- /data/pose_transfer_parsing_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torchvision.transforms as transforms 4 | from .base_dataset import * 5 | import cv2 6 | import numpy as np 7 | import os 8 | import util.io as io 9 | 10 | class PoseTransferParsingDataset(BaseDataset): 11 | def name(self): 12 | return 'PoseTransferParsingDataset' 13 | 14 | def initialize(self, opt, split): 15 | self.opt = opt 16 | self.data_root = opt.data_root 17 | self.split = split 18 | ############################# 19 | # set path / load label 20 | ############################# 21 | data_split = io.load_json(os.path.join(opt.data_root, opt.fn_split)) 22 | self.img_dir = os.path.join(opt.data_root, opt.img_dir) 23 | self.seg_dir = os.path.join(opt.data_root, opt.seg_dir) 24 | self.pose_label = io.load_data(os.path.join(opt.data_root, opt.fn_pose)) 25 | 26 | self.seg_cihp_dir = os.path.join(opt.data_root, opt.seg_dir) 27 | self.seg_cihp_pred_dir = os.path.join(opt.data_root, opt.seg_pred_dir) 28 | 29 | ############################# 30 | # create index list 31 | ############################# 32 | self.id_list = data_split[split] if split in data_split.keys() else data_split['test'][:2000] 33 | self._len = len(self.id_list) 34 | ############################# 35 | # other 36 | ############################# 37 | # here set debug option 38 | if opt.debug: 39 | self.id_list = self.id_list[0:32] 40 | self.tensor_normalize_std = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 41 | self.to_pil_image = transforms.ToPILImage() 42 | self.pil_to_tensor = transforms.ToTensor() 43 | self.color_jitter = transforms.ColorJitter(brightness=0.0, contrast=0.0, saturation=0.0, hue=0.2) 44 | 45 | def set_len(self, n): 46 | self._len = n 47 | 48 | def __len__(self): 49 | if hasattr(self, '_len') and self._len > 0: 50 | return self._len 51 | else: 52 | return len(self.id_list) 53 | 54 | def to_tensor(self, np_data): 55 | return torch.Tensor(np_data.transpose((2, 0, 1))) 56 | 57 | def read_image(self, sid): 58 | fn = os.path.join(self.img_dir, sid + '.jpg') 59 | # print(fn) 60 | # print(os.path.exists(fn)) 61 | img = cv2.imread(fn).astype(np.float32) / 255. 
62 | img = img[..., [2, 1, 0]] 63 | return img 64 | 65 | def read_seg_pred_cihp(self, sid1, sid2): 66 | fn = os.path.join(self.seg_cihp_pred_dir, sid1 + '___' + sid2 + '.png') 67 | seg = cv2.imread(fn, cv2.IMREAD_GRAYSCALE).astype(np.float32)[...,np.newaxis] 68 | return seg 69 | 70 | def read_seg_cihp(self, sid): 71 | fn = os.path.join(self.seg_cihp_dir, sid + '.png') 72 | seg = cv2.imread(fn, cv2.IMREAD_GRAYSCALE).astype(np.float32)[..., np.newaxis] 73 | return seg 74 | 75 | def __getitem__(self, index): 76 | sid1, sid2 = self.id_list[index] 77 | ###################### 78 | # load data 79 | ###################### 80 | img_1 = self.read_image(sid1) 81 | img_2 = self.read_image(sid2) 82 | 83 | seg_cihp_label_1 = self.read_seg_cihp(sid1) 84 | seg_cihp_label_2 = self.read_seg_cihp(sid2) if self.split=='train' else self.read_seg_pred_cihp(sid1, sid2) 85 | joint_c_1 = np.array(self.pose_label[sid1]) 86 | joint_c_2 = np.array(self.pose_label[sid2]) 87 | h, w = self.opt.image_size 88 | ###################### 89 | # pack output data 90 | ###################### 91 | joint_1 = kp_to_map(img_sz=(w, h), kps=joint_c_1, mode=self.opt.joint_mode, radius=self.opt.joint_radius) 92 | joint_2 = kp_to_map(img_sz=(w, h), kps=joint_c_2, mode=self.opt.joint_mode, radius=self.opt.joint_radius) 93 | seg_cihp_1 = seg_label_to_map(seg_cihp_label_1, nc=20) 94 | seg_cihp_2 = seg_label_to_map(seg_cihp_label_2, nc=20) 95 | 96 | data = { 97 | 'img_1': self.tensor_normalize_std(self.to_tensor(img_1)), 98 | 'img_2': self.tensor_normalize_std(self.to_tensor(img_2)), 99 | 'joint_1': self.to_tensor(joint_1), 100 | 'joint_2': self.to_tensor(joint_2), 101 | 'seg_cihp_1': self.to_tensor(seg_cihp_1), 102 | 'seg_cihp_2': self.to_tensor(seg_cihp_2), 103 | 'id_1': sid1, 104 | 'id_2': sid2 105 | } 106 | return data -------------------------------------------------------------------------------- /data/pose_transfer_parsing_market_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torchvision.transforms as transforms 4 | from .base_dataset import * 5 | import cv2 6 | import numpy as np 7 | import os 8 | import util.io as io 9 | from util import flow_util 10 | 11 | class PoseTransferParsingPredDataset(BaseDataset): 12 | def name(self): 13 | return 'PoseTransferParsingPredDataset' 14 | 15 | def initialize(self, opt, split): 16 | self.opt = opt 17 | self.data_root = opt.data_root 18 | self.split = split 19 | ############################# 20 | # set path / load label 21 | ############################# 22 | data_split = io.load_json(os.path.join(opt.data_root, opt.fn_split)) 23 | self.img_dir = os.path.join(opt.data_root, opt.img_dir) 24 | self.seg_dir = os.path.join(opt.data_root, opt.seg_dir) 25 | print(opt.fn_pose) 26 | self.pose_label = io.load_data(os.path.join(opt.data_root, opt.fn_pose)) 27 | 28 | self.seg_cihp_dir = os.path.join(opt.data_root, opt.seg_dir) 29 | self.seg_cihp_pred_dir = os.path.join(opt.data_root, opt.seg_pred_dir) 30 | ############################# 31 | # create index list 32 | ############################# 33 | #self.id_list = data_split[split] 34 | if split=='test_small': 35 | self.id_list=data_split['test'][:2000] 36 | else: 37 | self.id_list = data_split[split] 38 | self._len = len(self.id_list) 39 | ############################# 40 | # other 41 | ############################# 42 | # here set debug option 43 | if opt.debug: 44 | self.id_list = self.id_list[0:32] 45 | self.tensor_normalize_std = 
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 46 | self.to_pil_image = transforms.ToPILImage() 47 | self.pil_to_tensor = transforms.ToTensor() 48 | self.color_jitter = transforms.ColorJitter(brightness=0.0, contrast=0.0, saturation=0.0, hue=0.2) 49 | 50 | def set_len(self, n): 51 | self._len = n 52 | 53 | def __len__(self): 54 | if hasattr(self, '_len') and self._len > 0: 55 | return self._len 56 | else: 57 | return len(self.id_list) 58 | 59 | def to_tensor(self, np_data): 60 | return torch.Tensor(np_data.transpose((2, 0, 1))) 61 | 62 | def read_image(self, sid): 63 | fn = os.path.join(self.img_dir, sid + '.jpg') 64 | img = cv2.imread(fn).astype(np.float32) / 255. 65 | img = img[..., [2, 1, 0]] 66 | return img 67 | 68 | def read_seg(self, sid): 69 | fn = os.path.join(self.seg_dir, sid + '.bmp') 70 | seg = cv2.imread(fn, cv2.IMREAD_GRAYSCALE).astype(np.float32)[..., np.newaxis] 71 | return seg 72 | 73 | def read_seg_pred_cihp(self,sid1, sid2): 74 | 75 | fn = os.path.join(self.seg_cihp_pred_dir, sid1+'___'+sid2+'.png') 76 | #print(fn) 77 | # print(fn) 78 | seg = cv2.imread(fn, cv2.IMREAD_GRAYSCALE).astype(np.float32)[...,np.newaxis] 79 | return seg 80 | 81 | def read_seg_cihp(self, sid1, sid2): 82 | fn = os.path.join(self.seg_cihp_dir, sid1+'___'+sid2+'.png') 83 | 84 | seg = cv2.imread(fn, cv2.IMREAD_GRAYSCALE)#.astype(np.float32)[..., np.newaxis] 85 | seg = seg[:,64:192] 86 | seg = cv2.resize(seg,(64,128),interpolation=cv2.INTER_NEAREST) 87 | seg = seg.astype(np.float32)[..., np.newaxis] 88 | return seg 89 | 90 | def read_flow(self, sid1, sid2): 91 | ''' 92 | Output: 93 | flow_2to1: (h,w,2) correspondence from image 2 to image 1. corr_2to1[y,x] = [u,v], means image2[y,x] -> image1[v,u] 94 | vis_2: (h,w) visibility mask of image 2. 95 | 0: human pixel with correspondence 96 | 1: human pixel without correspondece 97 | 2: background pixel 98 | ''' 99 | fn = os.path.join(self.corr_dir, '%s_%s.corr' % (sid2, sid1)) 100 | corr_2to1, vis_2 = flow_util.read_corr(fn) 101 | vis_2 = vis_2[..., np.newaxis] 102 | flow_2to1 = flow_util.corr_to_flow(corr_2to1, vis_2, order='HWC') 103 | if self.opt.vis_smooth_rate > 0: 104 | vis_2b = cv2.medianBlur(vis_2, self.opt.vis_smooth_rate)[..., np.newaxis] 105 | m = (vis_2 < 2).astype(np.uint8) 106 | vis_2 = vis_2b * m + vis_2 * (1 - m) 107 | return flow_2to1, vis_2 108 | 109 | def read_corr(self, sid1, sid2): 110 | ''' 111 | Output: 112 | corr_2to1: (h, w, 2) 113 | vis_2: (h, w) 114 | ''' 115 | try: 116 | fn = os.path.join(self.corr_dir, '%s_%s.corr' % (sid2, sid1)) 117 | corr_2to1, vis_2 = flow_util.read_corr(fn) 118 | vis_2 = vis_2[..., np.newaxis] 119 | if self.opt.vis_smooth_rate > 0: 120 | vis_2b = cv2.medianBlur(vis_2, self.opt.vis_smooth_rate)[..., np.newaxis] 121 | m = (vis_2 < 2).astype(np.uint8) 122 | vis_2 = vis_2b * m + vis_2 * (1 - m) 123 | return corr_2to1, vis_2 124 | except: 125 | h, w = self.opt.image_size 126 | return np.zeros((h, w, 2), dtype=np.float32), np.ones((h, w, 1), dtype=np.float32) * 2 127 | 128 | def color_jit(self, img_1, img_2): 129 | ''' 130 | Input: 131 | img_1, img_2: Tensor CHW 132 | Output: 133 | img_1, img_2: Tensor CHW 134 | ''' 135 | w1 = img_1.shape[2] 136 | img = torch.cat((img_1, img_2), dim=2) 137 | img = self.to_pil_image(img.add_(1).div_(2)) 138 | img = self.color_jitter(img) 139 | img = self.pil_to_tensor(img).mul_(2).sub_(1) 140 | return img[:, :, :w1], img[:, :, w1:] 141 | 142 | def __getitem__(self, index): 143 | sid1, sid2 = self.id_list[index] 144 | ###################### 145 | # load data 146 | 
###################### 147 | img_1 = self.read_image(sid1) 148 | img_2 = self.read_image(sid2) 149 | 150 | 151 | seg_cihp_label_1 = self.read_seg_cihp(sid1, sid1) 152 | seg_cihp_label_2 = self.read_seg_cihp(sid1, sid2) 153 | 154 | joint_c_1 = np.array(self.pose_label[sid1]) 155 | joint_c_2 = np.array(self.pose_label[sid2]) 156 | corr_2to1, vis_2 = self.read_corr(sid1, sid2) 157 | h, w = self.opt.image_size 158 | ###################### 159 | # augmentation 160 | ###################### 161 | use_augmentation = self.opt.use_augmentation and self.opt.is_train and self.split == 'train' 162 | #print(vis_2.shape) 163 | if use_augmentation: 164 | # apply random shift and scale on img_2 165 | h, w = self.opt.image_size 166 | dx = np.random.randint(-self.opt.aug_shiftx_range, 167 | self.opt.aug_shiftx_range) if self.opt.aug_shiftx_range > 0 else 0 168 | dy = np.random.randint(-self.opt.aug_shifty_range, 169 | self.opt.aug_shifty_range) if self.opt.aug_shifty_range > 0 else 0 170 | sc = self.opt.aug_scale_range ** (np.random.rand() * 2 - 1) 171 | M = np.array([[sc, 0, 0.5 * h * (1 - sc) + dx], [0, sc, 0.5 * w * (1 - sc) + dy]]) 172 | 173 | img_2 = cv2.warpAffine(img_2, M, dsize=(w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE) 174 | seg_cihp_label_2 = \ 175 | cv2.warpAffine(seg_cihp_label_2, M, dsize=(w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE)[ 176 | ..., np.newaxis] 177 | corr_2to1 = cv2.warpAffine(corr_2to1, M, dsize=(w, h), flags=cv2.INTER_LINEAR, 178 | borderMode=cv2.BORDER_REPLICATE) 179 | vis_2 = cv2.warpAffine(vis_2, M, dsize=(w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE)[...,np.newaxis] 180 | 181 | # v = (d[:,0]>=0) & (d[:,1]>=0) & (d[:,0]= 0) & (joint_c_2[:, 1] >= 0) & (joint_c_2[:, 0] < w) & (joint_c_2[:, 1] < h) 183 | jc = joint_c_2.dot(M[:, 0:2].T) + M[:, 2:].T 184 | v_t = (jc[:, 0] >= 0) & (jc[:, 1] >= 0) & (jc[:, 0] < w) & (jc[:, 1] < h) 185 | v_t = v_t & v 186 | jc[~v_t, :] = -1 187 | joint_c_2 = jc 188 | ###################### 189 | # pack output data 190 | ###################### 191 | #print(vis_2.shape) 192 | joint_1 = kp_to_map(img_sz=(w, h), kps=joint_c_1, mode=self.opt.joint_mode, radius=self.opt.joint_radius) 193 | joint_2 = kp_to_map(img_sz=(w, h), kps=joint_c_2, mode=self.opt.joint_mode, radius=self.opt.joint_radius) 194 | seg_cihp_1 = seg_label_to_map(seg_cihp_label_1, nc=20) 195 | seg_cihp_2 = seg_label_to_map(seg_cihp_label_2, nc=20) 196 | flow_2to1 = flow_util.corr_to_flow(corr_2to1, vis_2, order='HWC') 197 | flow_2to1[..., 0] = flow_2to1[..., 0].clip(-w, w) 198 | flow_2to1[..., 1] = flow_2to1[..., 1].clip(-h, h) 199 | 200 | seg_cihp_1 = self.to_tensor(seg_cihp_1) 201 | 202 | seg_cihp_2 = self.to_tensor(seg_cihp_2) 203 | 204 | data = { 205 | 'img_1': self.tensor_normalize_std(self.to_tensor(img_1)), 206 | 'img_2': self.tensor_normalize_std(self.to_tensor(img_2)), 207 | 'joint_1': self.to_tensor(joint_1), 208 | 'joint_2': self.to_tensor(joint_2), 209 | 'seg_cihp_1': seg_cihp_1, 210 | 'seg_cihp_2': seg_cihp_2, 211 | 'flow_2to1': self.to_tensor(flow_2to1), 212 | 'vis_2': self.to_tensor(vis_2), 213 | 'id_1': sid1, 214 | 'id_2': sid2 215 | } 216 | 217 | ###################### 218 | # color jit 219 | ###################### 220 | if use_augmentation and self.opt.aug_color_jit: 221 | data['img_1'], data['img_2'] = self.color_jit(data['img_1'], data['img_2']) 222 | 223 | return data 224 | -------------------------------------------------------------------------------- /eval_deepfashion.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python tools/metrics_deepfashion.py \ 4 | --gt_path datasets/deepfashion/img/ground_truth \ 5 | --distorated_path checkpoints/PoseTransfer_deepfashion/output \ 6 | --fid_real_path datasets/deepfashion/img/train -------------------------------------------------------------------------------- /eval_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python tools/metrics_market.py \ 4 | --gt_path datasets/market1501/img/ground_truth/ \ 5 | --distorated_path checkpoints/PoseTransfer_market/output/ \ 6 | --fid_real_path datasets/market1501/img/train/ -------------------------------------------------------------------------------- /imgs/pipeline_all-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/imgs/pipeline_all-1.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/SPG_net_deepfashion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/SPG_net_deepfashion.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/SPG_net_market.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/SPG_net_market.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/flow_regression_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/flow_regression_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/normalization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/normalization.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/pose_transfer_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/models/__pycache__/pose_transfer_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import torch.nn 5 | import os 6 | 7 | class BaseModel(object): 8 | def name(self): 9 | return 'BaseModel' 10 | 11 | def initialize(self, opt): 12 | self.opt = opt 13 | self.gpu_ids = opt.gpu_ids 14 | self.is_train = opt.is_train 15 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 16 | self.save_dir = os.path.join('checkpoints', opt.id) 17 | 18 | self.input = {} 19 | self.output = {} 20 | 21 | def set_input(self, data): 22 | self.input = data 23 | 24 | def forward(self): 25 | pass 26 | 27 | # used in test time, no backprob 28 | def test(self): 29 | pass 30 | 31 | def optimize_parameters(self): 32 | pass 33 | 34 | 35 | def get_current_visuals(self): 36 | return self.input 37 | 38 | 39 | def get_current_errors(self): 40 | return {} 41 | 42 | def train(self): 43 | pass 44 | 45 | def eval(self): 46 | pass 47 | 48 | def save(self, label): 49 | pass 50 | 51 | # helper loading function that can be used by subclasses 52 | def save_network(self, network, network_label, epoch_label, gpu_ids): 53 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 54 | save_path = os.path.join(self.save_dir, save_filename) 55 | torch.save(network.cpu().state_dict(), save_path) 56 | 57 | if len(gpu_ids) and torch.cuda.is_available(): 58 | network.cuda() 59 | 60 | def load_network(self, network, network_label, epoch_label, model_id = None, forced = True): 61 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 62 | if model_id is None: 63 | # for continue training 64 | save_dir = self.save_dir 65 | else: 66 | # for initialize weight 67 | save_dir = os.path.join('checkpoints', model_id) 68 | save_path = os.path.join(save_dir, save_filename) 69 | if (not forced) and (not os.path.isfile(save_path)): 70 | print('[%s] FAIL to load [%s] parameters from %s' % (self.name(), network_label, save_path)) 71 | else: 72 | network.load_state_dict(torch.load(save_path)) 73 | print('[%s] load [%s] parameters from %s' % (self.name(), network_label, save_path)) 74 | 75 | def save_optim(self, optim, optim_label, epoch_label): 76 | save_filename = '%s_optim_%s.pth'%(epoch_label, optim_label) 77 | save_path = os.path.join(self.save_dir, save_filename) 78 | torch.save(optim.state_dict(), save_path) 79 | 80 | def load_optim(self, optim, optim_label, epoch_label): 81 | save_filename = '%s_optim_%s.pth'%(epoch_label, optim_label) 82 | 
save_path = os.path.join(self.save_dir, save_filename) 83 | optim.load_state_dict(torch.load(save_path)) 84 | 85 | # update learning rate (called once every epoch) 86 | def update_learning_rate(self): 87 | for scheduler in self.schedulers: 88 | scheduler.step() 89 | -------------------------------------------------------------------------------- /models/flow_regression_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from . import networks 7 | from .base_model import BaseModel 8 | from collections import OrderedDict 9 | 10 | class FlowRegressionModel(BaseModel): 11 | def name(self): 12 | return 'FlowRegressionModel' 13 | 14 | def initialize(self, opt): 15 | super(FlowRegressionModel, self).initialize(opt) 16 | ################################### 17 | # define flow networks 18 | ################################### 19 | if opt.which_model == 'unet': 20 | self.netF = networks.FlowUnet( 21 | input_nc = self.get_input_dim(opt.input_type1) + self.get_input_dim(opt.input_type2), 22 | nf = opt.nf, 23 | start_scale = opt.start_scale, 24 | num_scale = opt.num_scale, 25 | norm = opt.norm, 26 | gpu_ids = opt.gpu_ids, 27 | ) 28 | elif opt.which_model == 'unet_v2': 29 | self.netF = networks.FlowUnet_v2( 30 | input_nc = self.get_input_dim(opt.input_type1) + self.get_input_dim(opt.input_type2), 31 | nf = opt.nf, 32 | max_nf = opt.max_nf, 33 | start_scale = opt.start_scale, 34 | num_scales = opt.num_scale, 35 | norm = opt.norm, 36 | gpu_ids = opt.gpu_ids, 37 | ) 38 | if opt.gpu_ids: 39 | self.netF.cuda() 40 | networks.init_weights(self.netF, init_type=opt.init_type) 41 | ################################### 42 | # loss and optimizers 43 | ################################### 44 | self.crit_flow = networks.MultiScaleFlowLoss(start_scale=opt.start_scale, num_scale=opt.num_scale, loss_type=opt.flow_loss_type) 45 | self.crit_vis = nn.CrossEntropyLoss() #(0-visible, 1-invisible, 2-background) 46 | if opt.use_ss_flow_loss: 47 | self.crit_flow_ss = networks.SS_FlowLoss(loss_type='l1') 48 | if self.is_train: 49 | self.optimizers = [] 50 | self.optim = torch.optim.Adam(self.netF.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) 51 | self.optimizers.append(self.optim) 52 | 53 | ################################### 54 | # load trained model 55 | ################################### 56 | if not self.is_train: 57 | # load trained model for test 58 | print('load pretrained model') 59 | self.load_network(self.netF, 'netF', opt.which_epoch) 60 | elif opt.resume_train: 61 | # resume training 62 | print('resume training') 63 | self.load_network(self.netF, 'netF', opt.last_epoch) 64 | self.load_optim(self.optim, 'optim', opt.last_epoch) 65 | ################################### 66 | # schedulers 67 | ################################### 68 | if self.is_train: 69 | self.schedulers = [] 70 | for optim in self.optimizers: 71 | self.schedulers.append(networks.get_scheduler(optim, opt)) 72 | 73 | def set_input(self, data): 74 | input_list = [ 75 | 'img_1', 76 | 'img_2', 77 | 'joint_c_1', 78 | 'joint_c_2', 79 | 'joint_1', 80 | 'joint_2', 81 | 'seg_1', 82 | 'seg_2', 83 | 'seg_label_1', 84 | 'seg_label_2', 85 | 'flow_2to1', 86 | 'vis_2', 87 | 'dissrc', 88 | 'disdst' 89 | ] 90 | for name in input_list: 91 | if name in data: 92 | self.input[name] = self.Tensor(data[name].size()).copy_(data[name]) 93 | self.input['id'] = 
zip(data['id_1'], data['id_2']) 94 | 95 | def forward(self): 96 | # if self.opt.usedismap != '': 97 | # input = [self.get_input_tensor(self.opt.input_type1, '1'),self.get_input_tensor('dissrc'),self.get_input_tensor(self.opt.input_type2, '2'),self.get_input_tensor('disdst')] 98 | # else: 99 | input = [self.get_input_tensor(self.opt.input_type1, '1'), self.get_input_tensor(self.opt.input_type2, '2')] 100 | 101 | input = torch.cat(input, dim=1) 102 | flow_out, vis_out, flow_pyramid_out, flow_feat = self.netF(input) 103 | flow_scale = 20. 104 | flow_out = flow_out * flow_scale 105 | 106 | self.output['flow_pyramid_out'] = flow_pyramid_out 107 | self.output['flow_out'] = flow_out 108 | self.output['vis_out'] = vis_out 109 | self.output['flow_tar'] = self.input['flow_2to1'] 110 | self.output['vis_tar'] = self.input['vis_2'] 111 | self.output['flow_feat'] = flow_feat 112 | self.output['mask_out'] = (self.output['vis_out'].argmax(dim=1, keepdim=True) < 2).float() 113 | self.output['mask_tar'] = (self.output['vis_tar']<2).float() 114 | self.output['flow_final'] = self.output['flow_out'] * self.output['mask_out'] 115 | 116 | 117 | def test(self, compute_loss=False): 118 | with torch.no_grad(): 119 | self.forward() 120 | if compute_loss: 121 | self.compute_loss() 122 | 123 | def compute_loss(self): 124 | # flow loss 125 | self.output['loss_flow'], _ = self.crit_flow(self.output['flow_pyramid_out'], self.output['flow_tar'], self.output['mask_tar']) 126 | # flow_ss loss 127 | if self.opt.use_ss_flow_loss: 128 | self.output['loss_flow_ss'] = self.crit_flow_ss(self.output['flow_out'], self.output['flow_tar'], self.input['seg_1'], self.input['seg_2'], self.output['vis_tar']) 129 | 130 | # visibility loss 131 | self.output['loss_vis'] = self.crit_vis(self.output['vis_out'], self.output['vis_tar'].long().squeeze(dim=1)) 132 | # EPE 133 | self.output['EPE'] = networks.EPE(self.output['flow_out'], self.output['flow_tar'], self.output['mask_tar']) 134 | 135 | def backward(self, check_grad=False): 136 | 137 | if not check_grad: 138 | loss = 0 139 | loss += self.output['loss_flow'] * self.opt.loss_weight_flow 140 | loss += self.output['loss_vis'] * self.opt.loss_weight_vis 141 | if self.opt.use_ss_flow_loss: 142 | loss += self.output['loss_flow_ss'] * self.opt.loss_weight_flow_ss 143 | loss.backward() 144 | else: 145 | with networks.CalcGradNorm(self.netF) as cgn: 146 | (self.output['loss_flow']*self.opt.loss_weight_flow).backward(retain_graph=True) 147 | self.output['grad_flow'] = cgn.get_grad_norm() 148 | (self.output['loss_vis'] * self.opt.loss_weight_vis).backward(retain_graph=True) 149 | self.output['grad_vis'] = cgn.get_grad_norm() 150 | if self.opt.use_ss_flow_loss: 151 | (self.output['loss_flow_ss'] * self.opt.loss_weight_flow_ss).backward(retain_graph=True) 152 | self.output['grad_flow_ss'] = cgn.get_grad_norm() 153 | 154 | def optimize_parameters(self, check_grad=False): 155 | self.output = {} 156 | self.train() 157 | self.forward() 158 | self.optim.zero_grad() 159 | self.compute_loss() 160 | self.backward(check_grad) 161 | self.optim.step() 162 | 163 | def get_input_dim(self, input_type): 164 | dim = 0 165 | input_items = input_type.split('+') 166 | input_items.sort() 167 | for item in input_items: 168 | if item == 'img': 169 | dim += 3 170 | elif item == 'seg': 171 | dim += self.opt.seg_nc 172 | elif item == 'joint': 173 | dim += self.opt.joint_nc 174 | elif item == 'flow' or item == 'flow_gt': 175 | dim += 2 176 | elif item == 'flow_feat': 177 | dim += self.netF.nf_out 178 | elif item == 'vis': 179 
| dim += 3 180 | elif item == 'dissrc' or item=='disdst': 181 | dim += 12 182 | else: 183 | raise Exception('invalid input type %s'%item) 184 | return dim 185 | 186 | def get_input_tensor(self, input_type, index='1'): 187 | assert index in {'1', '2'} 188 | tensor = [] 189 | input_items = input_type.split('+') 190 | input_items.sort() 191 | for item in input_items: 192 | if item == 'img': 193 | tensor.append(self.input['img_%s'%index]) 194 | elif item == 'seg': 195 | tensor.append(self.input['seg_%s'%index]) 196 | elif item == 'joint': 197 | tensor.append(self.input['joint_%s'%index]) 198 | elif item == 'dissrc': 199 | tensor.append(self.input['dissrc']) 200 | elif item == 'disdst': 201 | tensor.append(self.input['disdst']) 202 | else: 203 | raise Exception('invalid input type %s'%item) 204 | tensor = torch.cat(tensor, dim=1) 205 | return tensor 206 | 207 | def get_current_errors(self): 208 | error_list = [ 209 | 'EPE', 210 | 'loss_flow', 211 | 'loss_vis', 212 | 'loss_flow_ss', 213 | 'grad_flow', 214 | 'grad_vis', 215 | 'grad_flow_ss', 216 | ] 217 | errors = OrderedDict() 218 | for item in error_list: 219 | if item in self.output: 220 | errors[item] = self.output[item].item() 221 | 222 | return errors 223 | 224 | def get_current_visuals(self): 225 | visuals = OrderedDict([ 226 | ('img_1', [self.input['img_1'].data.cpu(), 'rgb']), 227 | ('img_2', [self.input['img_2'].data.cpu(), 'rgb']), 228 | ('joint_1', [self.input['joint_1'].data.cpu(), 'pose']), 229 | ('joint_2', [self.input['joint_2'].data.cpu(), 'pose']), 230 | ('seg_1', [self.input['seg_1'].data.cpu(), 'seg']), 231 | ('seg_2', [self.input['seg_2'].data.cpu(), 'seg']), 232 | ('flow_tar', [self.output['flow_tar'].data.cpu(), 'flow']), 233 | ('flow_out', [self.output['flow_out'].data.cpu(), 'flow']), 234 | ('vis_tar', [self.output['vis_tar'].data.cpu(), 'vis']), 235 | ('vis_out', [self.output['vis_out'].data.cpu(), 'vis']), 236 | ('flow_final', [self.output['flow_final'].data.cpu(), 'flow']), 237 | ]) 238 | return visuals 239 | 240 | def save(self, label): 241 | # save networks weights 242 | self.save_network(self.netF, 'netF', label, self.gpu_ids) 243 | # save optimizer status 244 | if self.is_train: 245 | self.save_optim(self.optim, 'optim', label) 246 | 247 | 248 | def train(self): 249 | self.netF.train() 250 | 251 | 252 | def eval(self): 253 | self.netF.eval() 254 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from torch.autograd import Variable 8 | from torch.optim import lr_scheduler 9 | import functools 10 | import numpy as np 11 | from skimage.measure import compare_ssim, compare_psnr 12 | 13 | ############################################################################### 14 | # model helper functions 15 | ############################################################################### 16 | def print_network(net): 17 | num_params = 0 18 | for param in net.parameters(): 19 | num_params += param.numel() 20 | print(net) 21 | print('Total number of parameters: %d' % num_params) 22 | 23 | def get_norm_layer(norm_type = 'instance'): 24 | if norm_type == 'batch': 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 26 | elif norm_type == 'instance': 27 | norm_layer = functools.partial(nn.InstanceNorm2d, 
affine =False) 28 | elif norm_type == 'none': 29 | norm_layer = Identity 30 | else: 31 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 32 | return norm_layer 33 | 34 | ############################################################################### 35 | # parameter initialize 36 | ############################################################################### 37 | def weights_init_normal(m): 38 | classname = m.__class__.__name__ 39 | # print(classname) 40 | if classname.startswith('Conv'): 41 | init.normal_(m.weight, 0.0, 0.02) 42 | elif classname.startswith('Linear'): 43 | init.normal_(m.weight, 0.0, 0.02) 44 | elif classname.startswith('BatchNorm2d'): 45 | init.normal_(m.weight, 1.0, 0.02) 46 | 47 | if 'bias' in m._parameters and m.bias is not None: 48 | init.constant_(m.bias, 0.0) 49 | 50 | def weights_init_normal2(m): 51 | classname = m.__class__.__name__ 52 | # print(classname) 53 | if classname.startswith('Conv'): 54 | init.normal_(m.weight, 0.0, 0.001) 55 | elif classname.startswith('Linear'): 56 | init.normal_(m.weight, 0.0, 0.001) 57 | elif classname.startswith('BatchNorm2d'): 58 | init.normal_(m.weight, 1.0, 0.001) 59 | 60 | if 'bias' in m._parameters and m.bias is not None: 61 | init.constant_(m.bias, 0.0) 62 | 63 | def weights_init_xavier(m): 64 | classname = m.__class__.__name__ 65 | # print(classname) 66 | if classname.startswith('Conv'): 67 | init.xavier_normal_(m.weight, gain=0.02) 68 | elif classname.startswith('Linear'): 69 | init.xavier_normal_(m.weight, gain=0.02) 70 | elif classname.startswith('BatchNorm2d'): 71 | init.normal_(m.weight, 1.0, 0.02) 72 | 73 | if 'bias' in m._parameters and m.bias is not None: 74 | init.constant_(m.bias, 0.0) 75 | 76 | def weights_init_kaiming(m): 77 | classname = m.__class__.__name__ 78 | # print(classname) 79 | if classname.startswith('Conv'): 80 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 81 | elif classname.startswith('Linear'): 82 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 83 | elif classname.startswith('BatchNorm2d'): 84 | if m.affine == True: 85 | init.normal_(m.weight, 1.0, 0.02) 86 | 87 | if 'bias' in m._parameters and m.bias is not None: 88 | init.constant_(m.bias, 0.0) 89 | 90 | def weights_init_orthogonal(m): 91 | classname = m.__class__.__name__ 92 | # print(classname) 93 | if classname.startswith('Conv'): 94 | init.orthogonal_(m.weight, gain=1) 95 | elif classname.startswith('Linear'): 96 | init.orthogonal_(m.weight, gain=1) 97 | elif classname.startswith('BatchNorm2d'): 98 | init.normal_(m.weight, 1.0, 0.02) 99 | 100 | if 'bias' in m._parameters and m.bias is not None: 101 | init.constant_(m.bias, 0.0) 102 | 103 | def init_weights(net, init_type='normal'): 104 | # print('initialization method [%s]' % init_type) 105 | if init_type == 'normal': 106 | net.apply(weights_init_normal) 107 | elif init_type == 'normal2': 108 | net.apply(weights_init_normal2) 109 | elif init_type == 'xavier': 110 | net.apply(weights_init_xavier) 111 | elif init_type == 'kaiming': 112 | net.apply(weights_init_kaiming) 113 | elif init_type == 'orthogonal': 114 | net.apply(weights_init_orthogonal) 115 | else: 116 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 117 | 118 | ############################################################################### 119 | # Optimizer and Scheduler 120 | ############################################################################### 121 | def get_scheduler(optimizer, opt): 122 | if opt.lr_policy == 'lambda': 123 | def 
lambda_rule(epoch): 124 | lr_l = 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1) 125 | return lr_l 126 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 127 | elif opt.lr_policy == 'step': 128 | if opt.resume_train: 129 | last_epoch = int(opt.last_epoch) -1 130 | else: 131 | last_epoch = -1 132 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay, gamma=opt.lr_gamma, last_epoch=last_epoch) 133 | elif opt.lr_policy == 'plateau': 134 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 135 | else: 136 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 137 | return scheduler 138 | 139 | ############################################################################### 140 | # Loss Helper 141 | ############################################################################### 142 | class SmoothLoss(): 143 | ''' 144 | wrapper of pytorch loss layer. 145 | ''' 146 | def __init__(self, crit): 147 | self.crit = crit 148 | self.max_size = 100000 149 | self.clear() 150 | 151 | def __call__(self, input_1, input_2, *extra_input): 152 | loss = self.crit(input_1, input_2, *extra_input) 153 | self.weight_buffer.append(input_1.size(0)) 154 | 155 | if isinstance(loss, Variable): 156 | self.buffer.append(loss.data.item()) 157 | elif isinstance(loss, torch.Tensor): 158 | self.buffer.append(loss.data.item()) 159 | else: 160 | self.buffer.append(loss) 161 | 162 | if len(self.buffer) > self.max_size: 163 | self.buffer = self.buffer[-self.max_size::] 164 | self.weight_buffer = self.weight_buffer[-self.max_size::] 165 | 166 | return loss 167 | 168 | def clear(self): 169 | self.buffer = [] 170 | self.weight_buffer = [] 171 | 172 | def smooth_loss(self, clear = False): 173 | if len(self.weight_buffer) == 0: 174 | loss = 0 175 | else: 176 | loss = sum([l * w for l, w in zip(self.buffer, self.weight_buffer)]) / sum(self.weight_buffer) 177 | if clear: 178 | self.clear() 179 | return loss 180 | 181 | class CalcGradNorm(object): 182 | ''' 183 | example: 184 | y = model(x) 185 | with CalcGradNorm(model) as cgn: 186 | y.backward() 187 | grad_norm = cgn.get_grad_norm() 188 | ''' 189 | def __init__(self, module): 190 | super(CalcGradNorm, self).__init__() 191 | self.module = module 192 | 193 | def __enter__(self): 194 | self.grad_list = [p.grad.clone() if p.grad is not None else None for p in self.module.parameters()] 195 | return self 196 | 197 | def __exit__(self, type, value, traceback): 198 | pass 199 | 200 | def get_grad_norm(self): 201 | grad_norm = self.module.parameters().__next__().new_zeros([]) 202 | new_grad_list = [] 203 | for i, p in enumerate(self.module.parameters()): 204 | if p.grad is None: 205 | assert self.grad_list[i] is None, 'gradient information is missing. 
maybe caused by calling "zero_grad()"' 206 | new_grad_list.append(None) 207 | else: 208 | g = p.grad.clone() 209 | new_grad_list.append(g) 210 | if self.grad_list[i] is None: 211 | grad_norm += g.norm() 212 | else: 213 | grad_norm += (g-self.grad_list[i]).norm() 214 | 215 | self.grad_list = new_grad_list 216 | return grad_norm.detach() 217 | 218 | 219 | 220 | ############################################################################### 221 | # Losses and metrics 222 | ############################################################################### 223 | class GANLoss(nn.Module): 224 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 225 | super(GANLoss, self).__init__() 226 | self.register_buffer('real_label', torch.tensor(target_real_label)) 227 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 228 | if use_lsgan: 229 | self.loss = F.mse_loss 230 | else: 231 | self.loss = F.binary_cross_entropy 232 | 233 | def get_target_tensor(self, input, target_is_real): 234 | if target_is_real: 235 | target_tensor = self.real_label 236 | else: 237 | target_tensor = self.fake_label 238 | return target_tensor.expand_as(input) 239 | 240 | def forward(self, input, target_is_real): 241 | target_tensor = self.get_target_tensor(input, target_is_real) 242 | return self.loss(input, target_tensor) 243 | 244 | 245 | class VGGLoss(nn.Module): 246 | def __init__(self, gpu_ids, content_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0], style_weights=[1.,1.,1.,1.,1.],shifted_style=False): 247 | super(VGGLoss, self).__init__() 248 | self.gpu_ids = gpu_ids 249 | self.shifted_style = shifted_style 250 | self.content_weights = content_weights 251 | self.style_weights = style_weights 252 | self.shift_delta = [[0,2,4,8,16], [0,2,4,8], [0,2,4], [0,2], [0]] 253 | # self.style_weights = [0,0,1,0,0] # use relu-3 layer feature to compure style loss 254 | # define vgg 255 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 256 | self.slice1 = torch.nn.Sequential() 257 | self.slice2 = torch.nn.Sequential() 258 | self.slice3 = torch.nn.Sequential() 259 | self.slice4 = torch.nn.Sequential() 260 | self.slice5 = torch.nn.Sequential() 261 | for x in range(2): 262 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) # relu1_1 263 | for x in range(2, 7): 264 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) # relu2_1 265 | for x in range(7, 12): 266 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) # relu3_1 267 | for x in range(12, 21): 268 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) # relu4_1 269 | for x in range(21, 30): 270 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) # relu5_1 271 | for param in self.parameters(): 272 | param.requires_grad = False 273 | 274 | if len(gpu_ids) > 0: 275 | self.cuda() 276 | 277 | def compute_feature(self, X): 278 | h_relu1 = self.slice1(X) 279 | h_relu2 = self.slice2(h_relu1) 280 | h_relu3 = self.slice3(h_relu2) 281 | h_relu4 = self.slice4(h_relu3) 282 | h_relu5 = self.slice5(h_relu4) 283 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 284 | return out 285 | 286 | def forward(self, X, Y, mask=None, loss_type='content', device_mode=None): 287 | ''' 288 | loss_type: 'content', 'style' 289 | device_mode: multi, single, sub 290 | ''' 291 | bsz = X.size(0) 292 | if device_mode is None: 293 | device_mode = 'multi' if len(self.gpu_ids) > 1 else 'single' 294 | 295 | if device_mode == 'multi': 296 | if mask is None: 297 | return nn.parallel.data_parallel(self, (X, 
Y), module_kwargs={'loss_type': loss_type, 'device_mode': 'sub', 'mask': None}).mean(dim=0) 298 | else: 299 | return nn.parallel.data_parallel(self, (X, Y, mask), module_kwargs={'loss_type': loss_type, 'device_mode': 'sub'}).mean(dim=0) 300 | else: 301 | features_x = self.compute_feature(self.normalize(X)) 302 | features_y = self.compute_feature(self.normalize(Y)) 303 | if mask is not None: 304 | features_x = [feat * F.adaptive_max_pool2d(mask, (feat.size(2), feat.size(3))) for feat in features_x] 305 | features_y = [feat * F.adaptive_max_pool2d(mask, (feat.size(2), feat.size(3))) for feat in features_y] 306 | 307 | # compute content loss 308 | if loss_type == 'content': 309 | loss = 0 310 | for i, (feat_x, feat_y) in enumerate(zip(features_x, features_y)): 311 | loss += self.content_weights[i] * F.l1_loss(feat_x, feat_y, reduce=False).view(bsz, -1).mean(dim=1) 312 | # compute style loss 313 | if loss_type == 'style': 314 | loss = 0 315 | if self.shifted_style: 316 | # with cross_correlation 317 | for i, (feat_x, feat_y) in enumerate(zip(features_x, features_y)): 318 | if self.style_weights[i] > 0: 319 | for delta in self.shift_delta[i]: 320 | if delta == 0: 321 | loss += self.style_weights[i] * F.mse_loss(self.gram_matrix(feat_x), self.gram_matrix(feat_y), reduce=False).view(bsz, -1).sum(dim=1) 322 | else: 323 | loss += 0.5*self.style_weights[i] * \ 324 | (F.mse_loss(self.shifted_gram_matrix(feat_x, delta, 0), self.shifted_gram_matrix(feat_y, delta, 0), reduce=False) \ 325 | +F.mse_loss(self.shifted_gram_matrix(feat_x, 0, delta), self.shifted_gram_matrix(feat_y, 0, delta), reduce=False)).view(bsz, -1).sum(dim=1) 326 | else: 327 | # without cross_correlation 328 | for i, (feat_x, feat_y) in enumerate(zip(features_x, features_y)): 329 | if self.style_weights[i] > 0: 330 | loss += self.style_weights[i] * F.mse_loss(self.gram_matrix(feat_x), self.gram_matrix(feat_y), reduce=False).view(bsz, -1).sum(dim=1) 331 | # loss += self.style_weights[i] * ((self.gram_matrix(feat_x) - self.gram_matrix(feat_y))**2).view(bsz, -1).mean(dim=1) 332 | 333 | if device_mode == 'single': 334 | loss = loss.mean(dim=0) 335 | return loss 336 | 337 | def normalize(self, x): 338 | # normalization parameters of input 339 | mean_1 = x.new([0.5, 0.5, 0.5]).view(1,3,1,1) 340 | std_1 = x.new([0.5, 0.5, 0.5]).view(1,3,1,1) 341 | # normalization parameters of output 342 | mean_2 = x.new([0.485, 0.456, 0.406]).view(1,3,1,1) 343 | std_2 = x.new([0.229, 0.224, 0.225]).view(1,3,1,1) 344 | 345 | return (x*std_1 + mean_1 - mean_2)/std_2 346 | 347 | def gram_matrix(self, feat): 348 | bsz, c, h, w = feat.size() 349 | feat = feat.view(bsz, c, h*w) 350 | feat_T = feat.transpose(1,2) 351 | g = torch.matmul(feat, feat_T) / (c*h*w) 352 | return g 353 | 354 | def shifted_gram_matrix(self, feat, shift_x, shift_y): 355 | bsz, c, h, w = feat.size() 356 | assert shift_x0).float() 454 | mask = (seg_2*(1-seg_1w)).sum(dim=1, keepdim=True) 455 | mask = mask * (vis_2==0).float() 456 | err = (input_flow - target_flow).mul(self.div_flow) * mask 457 | if self.loss_type == 'l1': 458 | loss = err.abs().mean() 459 | elif self.loss_type == 'l2': 460 | loss = err.norm(p=2,dim=1).mean() 461 | return loss 462 | 463 | 464 | class MeanAP(): 465 | ''' 466 | compute meanAP 467 | ''' 468 | 469 | def __init__(self): 470 | self.clear() 471 | 472 | def clear(self): 473 | self.score = None 474 | self.label = None 475 | 476 | def add(self, new_score, new_label): 477 | 478 | inputs = [new_score, new_label] 479 | 480 | for i in range(len(inputs)): 481 | 482 | if 
isinstance(inputs[i], list): 483 | inputs[i] = np.array(inputs[i], dtype = np.float32) 484 | 485 | elif isinstance(inputs[i], np.ndarray): 486 | inputs[i] = inputs[i].astype(np.float32) 487 | 488 | elif isinstance(inputs[i], torch.Tensor): 489 | inputs[i] = inputs[i].cpu().numpy().astype(np.float32) 490 | 491 | elif isinstance(inputs[i], Variable): 492 | inputs[i] = inputs[i].data.cpu().numpy().astype(np.float32) 493 | 494 | new_score, new_label = inputs 495 | assert new_score.shape == new_label.shape, 'shape mismatch: %s vs. %s' % (new_score.shape, new_label.shape) 496 | 497 | self.score = np.concatenate((self.score, new_score), axis = 0) if self.score is not None else new_score 498 | self.label = np.concatenate((self.label, new_label), axis = 0) if self.label is not None else new_label 499 | 500 | def compute_mean_ap(self): 501 | 502 | score, label = self.score, self.label 503 | 504 | assert score is not None and label is not None 505 | assert score.shape == label.shape, 'shape mismatch: %s vs. %s' % (score.shape, label.shape) 506 | assert(score.ndim == 2) 507 | M, N = score.shape[0], score.shape[1] 508 | 509 | # compute tp: column n in tp is the n-th class label in descending order of the sample score. 510 | index = np.argsort(score, axis = 0)[::-1, :] 511 | tp = label.copy().astype(np.float) 512 | for i in xrange(N): 513 | tp[:, i] = tp[index[:,i], i] 514 | tp = tp.cumsum(axis = 0) 515 | 516 | m_grid, n_grid = np.meshgrid(range(M), range(N), indexing = 'ij') 517 | tp_add_fp = m_grid + 1 518 | num_truths = np.sum(label, axis = 0) 519 | # compute recall and precise 520 | rec = tp / (num_truths+1e-8) 521 | prec = tp / (tp_add_fp+1e-8) 522 | 523 | prec = np.append(np.zeros((1,N), dtype = np.float), prec, axis = 0) 524 | for i in xrange(M-1, -1, -1): 525 | prec[i, :] = np.max(prec[i:i+2, :], axis = 0) 526 | rec_1 = np.append(np.zeros((1,N), dtype = np.float), rec, axis = 0) 527 | rec_2 = np.append(rec, np.ones((1,N), dtype = np.float), axis = 0) 528 | AP = np.sum(prec * (rec_2 - rec_1), axis = 0) 529 | AP[np.isnan(AP)] = -1 # avoid error caused by classes that have no positive sample 530 | 531 | assert((AP <= 1).all()) 532 | 533 | AP = AP * 100. 534 | meanAP = AP[AP >= 0].mean() 535 | 536 | return meanAP, AP 537 | 538 | def compute_recall(self, k = 3): 539 | ''' 540 | compute recall using method in DeepFashion Paper 541 | ''' 542 | score, label = self.score, self.label 543 | tag = np.where((-score).argsort().argsort() < k, 1, 0) 544 | tag_rec = tag * label 545 | 546 | count_rec = tag_rec.sum(axis = 1) 547 | count_gt = label.sum(axis = 1) 548 | 549 | # set recall=1 for sample with no positive attribute label 550 | no_pos_attr = (count_gt == 0).astype(count_gt.dtype) 551 | count_rec += no_pos_attr 552 | count_gt += no_pos_attr 553 | 554 | rec = (count_rec / count_gt).mean() * 100. 
555 | 556 | return rec 557 | 558 | ############################################################################### 559 | # image similarity metrics 560 | ############################################################################### 561 | class PSNR(nn.Module): 562 | def forward(self, images_1, images_2): 563 | numpy_imgs_1 = images_1.cpu().detach().numpy().transpose(0,2,3,1) 564 | numpy_imgs_1 = ((numpy_imgs_1 + 1.0) * 127.5).clip(0,255).astype(np.uint8) 565 | numpy_imgs_2 = images_2.cpu().detach().numpy().transpose(0,2,3,1) 566 | numpy_imgs_2 = ((numpy_imgs_2 + 1.0) * 127.5).clip(0,255).astype(np.uint8) 567 | 568 | psnr_score = [] 569 | for img_1, img_2 in zip(numpy_imgs_1, numpy_imgs_2): 570 | psnr_score.append(compare_psnr(img_2, img_1)) 571 | 572 | return Variable(images_1.data.new(1).fill_(np.mean(psnr_score))) 573 | 574 | 575 | class SSIM(nn.Module): 576 | def forward(self, images_1, images_2, mask=None): 577 | numpy_imgs_1 = images_1.cpu().detach().numpy().transpose(0,2,3,1) 578 | numpy_imgs_1 = ((numpy_imgs_1 + 1.0) * 127.5).clip(0,255).astype(np.uint8) 579 | numpy_imgs_2 = images_2.cpu().detach().numpy().transpose(0,2,3,1) 580 | numpy_imgs_2 = ((numpy_imgs_2 + 1.0) * 127.5).clip(0,255).astype(np.uint8) 581 | if mask is not None: 582 | mask = mask.cpu().detach().numpy().transpose(0,2,3,1).astype(np.uint8) 583 | numpy_imgs_1 = numpy_imgs_1 * mask 584 | numpy_imgs_2 = numpy_imgs_2 * mask 585 | 586 | ssim_score = [] 587 | for img_1, img_2 in zip(numpy_imgs_1, numpy_imgs_2): 588 | ssim_score.append(compare_ssim(img_1, img_2, multichannel=True)) 589 | return Variable(images_1.data.new(1).fill_(np.mean(ssim_score))) 590 | 591 | 592 | 593 | ############################################################################### 594 | # flow-based warping 595 | ############################################################################### 596 | def warp_acc_flow(x, flow, mode='bilinear', mask=None, mask_value=-1): 597 | ''' 598 | warp an image/tensor according to given flow. 599 | Input: 600 | x: (bsz, c, h, w) 601 | flow: (bsz, c, h, w) 602 | mask: (bsz, 1, h, w). 1 for valid region and 0 for invalid region. invalid region will be fill with "mask_value" in the output images. 
603 | Output: 604 | y: (bsz, c, h, w) 605 | ''' 606 | bsz, c, h, w = x.size() 607 | # mesh grid 608 | xx = x.new_tensor(range(w)).view(1,-1).repeat(h,1) 609 | yy = x.new_tensor(range(h)).view(-1,1).repeat(1,w) 610 | xx = xx.view(1,1,h,w).repeat(bsz,1,1,1) 611 | yy = yy.view(1,1,h,w).repeat(bsz,1,1,1) 612 | grid = torch.cat((xx,yy), dim=1).float() 613 | grid = grid + flow 614 | # scale to [-1, 1] 615 | grid[:,0,:,:] = 2.0*grid[:,0,:,:]/max(w-1,1) - 1.0 616 | grid[:,1,:,:] = 2.0*grid[:,1,:,:]/max(h-1,1) - 1.0 617 | 618 | grid = grid.permute(0,2,3,1) 619 | output = F.grid_sample(x, grid, mode=mode, padding_mode='zeros') 620 | # mask = F.grid_sample(x.new_ones(x.size()), grid) 621 | # mask = torch.where(mask<0.9999, mask.new_zeros(1), mask.new_ones(1)) 622 | # return output * mask 623 | if mask is not None: 624 | output = torch.where(mask>0.5, output, output.new_ones(1).mul_(mask_value)) 625 | return output 626 | 627 | ############################################################################### 628 | # layers 629 | ############################################################################### 630 | class Identity(nn.Module): 631 | def __init__(self, dim=None): 632 | super(Identity, self).__init__() 633 | def forward(self, x): 634 | return x 635 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.utils.spectral_norm as spectral_norm 10 | 11 | 12 | # Returns a function that creates a normalization function 13 | # that does not condition on semantic map 14 | def get_nonspade_norm_layer(opt, norm_type='instance'): 15 | # helper function to get # output channels of the previous layer 16 | def get_out_channel(layer): 17 | if hasattr(layer, 'out_channels'): 18 | return getattr(layer, 'out_channels') 19 | return layer.weight.size(0) 20 | 21 | # this function will be returned 22 | def add_norm_layer(layer): 23 | nonlocal norm_type 24 | if norm_type.startswith('spectral'): 25 | layer = spectral_norm(layer) 26 | subnorm_type = norm_type[len('spectral'):] 27 | 28 | if subnorm_type == 'none' or len(subnorm_type) == 0: 29 | return layer 30 | 31 | # remove bias in the previous layer, which is meaningless 32 | # since it has no effect after normalization 33 | if getattr(layer, 'bias', None) is not None: 34 | delattr(layer, 'bias') 35 | layer.register_parameter('bias', None) 36 | 37 | if subnorm_type == 'batch': 38 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 39 | # elif subnorm_type == 'sync_batch': 40 | # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 41 | elif subnorm_type == 'instance': 42 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 43 | else: 44 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 45 | 46 | return nn.Sequential(layer, norm_layer) 47 | 48 | return add_norm_layer 49 | 50 | 51 | class SPADE(nn.Module): 52 | def __init__(self, config_text, norm_nc, label_nc): 53 | super().__init__() 54 | 55 | assert config_text.startswith('spade') 56 | # eg:spadeinstance5x5 57 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 58 | param_free_norm_type = str(parsed.group(1)) 59 | ks 
= int(parsed.group(2)) 60 | 61 | if param_free_norm_type == 'instance': 62 | self. param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 63 | # elif param_free_norm_type == 'syncbatch': 64 | # self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 65 | elif param_free_norm_type == 'batch': 66 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 67 | else: 68 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 69 | % param_free_norm_type) 70 | 71 | # The dimension of the intermediate embedding space. Yes, hardcoded. 72 | nhidden = 64 73 | 74 | #ks is odd number, 3//2=1, so output feature map has same shape with input 75 | pw = ks // 2 76 | self.mlp_shared = nn.Sequential( 77 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 78 | nn.ReLU() 79 | ) 80 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 81 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 82 | 83 | def forward(self, x, segmap): 84 | 85 | # Part 1. generate parameter-free normalized activations 86 | normalized = self.param_free_norm(x) 87 | 88 | # Part 2. produce scaling and bias conditioned on semantic map 89 | segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 90 | actv = self.mlp_shared(segmap) 91 | gamma = self.mlp_gamma(actv) 92 | beta = self.mlp_beta(actv) 93 | 94 | # apply scale and bias 95 | out = normalized * (1 + gamma) + beta 96 | 97 | return out -------------------------------------------------------------------------------- /models/pose_transfer_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | from collections import OrderedDict 8 | import argparse 9 | 10 | from . import SPG_net_market 11 | from . import SPG_net_deepfashion 12 | from . import networks 13 | from .base_model import BaseModel 14 | from util import io, pose_util 15 | 16 | class PoseTransferModel(BaseModel): 17 | ''' 18 | Pose transfer framework that cascade a 3d-flow module and a generation module. 
19 | ''' 20 | 21 | def name(self): 22 | return 'PoseTransferModel' 23 | 24 | def initialize(self, opt): 25 | super(PoseTransferModel, self).initialize(opt) 26 | ################################### 27 | # define generator 28 | ################################### 29 | self.seg_cihp_nc = 20 30 | self.use_parsing = True 31 | self.opt = opt 32 | if opt.dataset_name == 'market': 33 | SPGNet = SPG_net_market 34 | else: 35 | SPGNet = SPG_net_deepfashion 36 | self.netG = SPGNet.DualUnetGenerator_SEAN( 37 | pose_nc=self.get_tensor_dim(opt.G_pose_type), 38 | appearance_nc=self.get_tensor_dim(opt.G_appearance_type), 39 | output_nc=3, 40 | aux_output_nc=[], 41 | nf=opt.G_nf, 42 | max_nf=opt.G_max_nf, 43 | num_scales=opt.G_n_scale, 44 | num_warp_scales=opt.G_n_warp_scale, 45 | n_residual_blocks=2, 46 | norm=opt.G_norm, 47 | vis_mode=opt.G_vis_mode, 48 | activation=nn.LeakyReLU(0.1) if opt.G_activation == 'leaky_relu' else nn.ReLU(), 49 | use_dropout=opt.use_dropout, 50 | no_end_norm=opt.G_no_end_norm, 51 | gpu_ids=opt.gpu_ids, 52 | isTrain = self.is_train 53 | ) 54 | # print(self.netG) 55 | if opt.gpu_ids: 56 | self.netG.cuda() 57 | networks.init_weights(self.netG, init_type=opt.init_type) 58 | ################################### 59 | # define external pixel warper 60 | ################################### 61 | if opt.G_pix_warp: 62 | pix_warp_n_scale = opt.G_n_scale 63 | self.netPW = networks.UnetGenerator_MultiOutput( 64 | input_nc=self.get_tensor_dim(opt.G_pix_warp_input_type), 65 | output_nc=[1], # only use one output branch (weight mask) 66 | nf=32, 67 | max_nf=128, 68 | num_scales=pix_warp_n_scale, 69 | n_residual_blocks=2, 70 | norm=opt.G_norm, 71 | activation=nn.ReLU(False), 72 | use_dropout=False, 73 | gpu_ids=opt.gpu_ids 74 | ) 75 | if opt.gpu_ids: 76 | self.netPW.cuda() 77 | networks.init_weights(self.netPW, init_type=opt.init_type) 78 | ################################### 79 | # define discriminator 80 | ################################### 81 | self.use_gan = self.is_train and self.opt.loss_weight_gan > 0 82 | if self.use_gan: 83 | self.netD = networks.NLayerDiscriminator( 84 | input_nc=self.get_tensor_dim(opt.D_input_type_real), 85 | ndf=opt.D_nf, 86 | n_layers=opt.D_n_layers, 87 | use_sigmoid=(opt.gan_type == 'dcgan'), 88 | output_bias=True, 89 | gpu_ids=opt.gpu_ids, 90 | ) 91 | if opt.gpu_ids: 92 | self.netD.cuda() 93 | networks.init_weights(self.netD, init_type=opt.init_type) 94 | ################################### 95 | # load optical flow model 96 | ################################### 97 | if opt.flow_on_the_fly: 98 | self.netF = load_flow_network(opt.pretrained_flow_id, opt.pretrained_flow_epoch, opt.gpu_ids) 99 | self.netF.eval() 100 | if opt.gpu_ids: 101 | self.netF.cuda() 102 | 103 | ################################### 104 | # loss and optimizers 105 | ################################### 106 | # self.crit_psnr = networks.PSNR().cuda() 107 | self.crit_ssim = networks.SSIM().cuda() 108 | 109 | if self.is_train: 110 | self.crit_vgg = networks.VGGLoss(opt.gpu_ids, shifted_style=opt.shifted_style_loss, 111 | content_weights=opt.vgg_content_weights) 112 | if opt.G_pix_warp: 113 | # only optimze netPW 114 | self.optim = torch.optim.Adam(self.netPW.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), 115 | weight_decay=opt.weight_decay) 116 | else: 117 | self.optim = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), 118 | weight_decay=opt.weight_decay) 119 | self.optimizers = [self.optim] 120 | 121 | if self.use_gan: 122 | self.crit_gan = 
networks.GANLoss(use_lsgan=(opt.gan_type == 'lsgan')) 123 | if self.gpu_ids: 124 | self.crit_gan.cuda() 125 | self.optim_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr_D, betas=(opt.beta1, opt.beta2), 126 | weight_decay=opt.weight_decay_D) 127 | self.optimizers += [self.optim_D] 128 | 129 | ################################### 130 | # load trained model 131 | ################################### 132 | if not self.is_train: 133 | # load trained model for testing 134 | self.load_network(self.netG, 'netG', opt.which_epoch) 135 | if opt.G_pix_warp: 136 | self.load_network(self.netPW, 'netPW', opt.which_epoch) 137 | elif opt.pretrained_G_id is not None: 138 | # load pretrained network 139 | self.load_network(self.netG, 'netG', opt.pretrained_G_epoch, opt.pretrained_G_id) 140 | elif opt.resume_train: 141 | # resume training 142 | self.load_network(self.netG, 'netG', opt.last_epoch) 143 | self.load_optim(self.optim, 'optim', opt.last_epoch) 144 | # note 145 | if self.use_gan: 146 | self.load_network(self.netD, 'netD', opt.last_epoch) 147 | self.load_optim(self.optim_D, 'optim_D', opt.last_epoch) 148 | if opt.G_pix_warp: 149 | self.load_network(self.netPW, 'netPW', opt.last_epoch) 150 | ################################### 151 | # schedulers 152 | ################################### 153 | if self.is_train: 154 | self.schedulers = [] 155 | for optim in self.optimizers: 156 | self.schedulers.append(networks.get_scheduler(optim, opt)) 157 | 158 | def set_input(self, data): 159 | self.input_list = [ 160 | 'img_1', 161 | 'img_2', 162 | 'joint_1', 163 | 'joint_2', 164 | ] 165 | if self.use_parsing: 166 | self.input_list += ['seg_cihp_1', 167 | 'seg_cihp_2' 168 | ] 169 | 170 | for item in self.input_list: 171 | self.input[item] = self.Tensor(data[item].size()).copy_(data[item]) 172 | 173 | self.input['id'] = zip(data['id_1'], data['id_2']) 174 | 175 | def forward(self, test=False): 176 | # generate flow 177 | flow_scale = 20. 
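# netF predicts flow in normalized units; multiplying by flow_scale restores pixel displacements
# (the constant 20 presumably matches the div_flow normalization used when training the flow regressor).
# vis_out is the per-pixel visibility class (argmax over three channels); labels >= 2 mean occluded,
# so mask_out zeroes the flow there and later tells warp_acc_flow to fill those pixels with mask_value.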
178 | if self.opt.flow_on_the_fly: 179 | with torch.no_grad(): 180 | input_F = self.get_tensor(self.opt.F_input_type) 181 | flow_out, vis_out, _, _ = self.netF(input_F) 182 | self.output['vis_out'] = vis_out.argmax(dim=1, keepdim=True).float() 183 | self.output['mask_out'] = (self.output['vis_out'] < 2).float() 184 | self.output['flow_out'] = flow_out * flow_scale * self.output['mask_out'] 185 | else: 186 | self.output['flow_out'] = self.input['flow_2to1'] 187 | self.output['vis_out'] = self.input['vis_2'] 188 | self.output['mask_out'] = (self.output['vis_out'] < 2).float() 189 | self.output['flow_tar'] = self.output['flow_out'] 190 | self.output['vis_tar'] = self.output['vis_out'] 191 | self.output['maks_tar'] = self.output['mask_out'] 192 | bsz, _, h, w = self.output['vis_out'].size() 193 | self.output['vismap_out'] = self.output['vis_out'].new(bsz, 3, h, w).scatter_(dim=1, index=self.output[ 194 | 'vis_out'].long(), value=1) 195 | 196 | # warp image 197 | self.output['img_warp'] = networks.warp_acc_flow(self.input['img_1'], self.output['flow_out'], 198 | mask=self.output['mask_out']) 199 | 200 | # generate image 201 | if self.opt.which_model_G == 'unet': 202 | input_G = self.get_tensor('+'.join([self.opt.G_appearance_type, self.opt.G_pose_type])) 203 | out = self.netG(input_G) 204 | self.output['img_out'] = F.tanh(out) 205 | elif self.opt.which_model_G == 'dual_unet': 206 | input_G_pose = self.get_tensor(self.opt.G_pose_type) 207 | input_G_appearance = self.get_tensor(self.opt.G_appearance_type) 208 | input_G_s_seg = self.get_tensor('seg_cihp_1') 209 | 210 | input_G_d_seg = self.get_tensor('seg_cihp_2') 211 | flow_in, vis_in = (self.output['flow_out'], self.output['vis_out']) if self.opt.G_feat_warp else ( 212 | None, None) 213 | 214 | dismap = None 215 | if not self.opt.G_pix_warp: 216 | out = self.netG(input_G_pose, input_G_appearance, input_G_s_seg, input_G_d_seg, flow_in, vis_in, dismap) 217 | self.output['img_out'] = F.tanh(out) 218 | else: 219 | with torch.no_grad(): 220 | out = self.netG(input_G_pose, input_G_appearance, input_G_s_seg, input_G_d_seg, flow_in, vis_in) 221 | self.output['img_out_G'] = F.tanh(out) 222 | pw_out = self.netPW(self.get_tensor(self.opt.G_pix_warp_input_type)) 223 | self.output['pix_mask'] = F.sigmoid(pw_out[0]) 224 | if self.opt.G_pix_warp_detach: 225 | self.output['img_out'] = self.output['img_warp'] * self.output['pix_mask'] + self.output[ 226 | 'img_out_G'].detach() * (1 - self.output['pix_mask']) 227 | else: 228 | self.output['img_out'] = self.output['img_warp'] * self.output['pix_mask'] + self.output[ 229 | 'img_out_G'] * (1 - self.output['pix_mask']) 230 | self.output['img_tar'] = self.input['img_2'] 231 | 232 | def test(self, compute_loss=True, meas_only=True): 233 | ''' meas_only: only compute measurements (psrn, ssim) when computing loss''' 234 | with torch.no_grad(): 235 | self.forward(test=True) 236 | if compute_loss: 237 | assert self.is_train or meas_only, 'when is_train is False, meas_only must be True' 238 | self.compute_loss(meas_only=meas_only, compute_ssim=True) 239 | 240 | def compute_loss(self, meas_only=False, compute_ssim=False): 241 | '''compute_ssim: set True to compute ssim (time consuming)''' 242 | ############################## 243 | # measurements 244 | ############################## 245 | if compute_ssim: 246 | self.output['SSIM'] = self.crit_ssim(self.output['img_out'], self.output['img_tar']) 247 | if meas_only: 248 | return 249 | ############################## 250 | # losses 251 | ############################## 252 | 
self.output['loss_l1'] = F.l1_loss(self.output['img_out'], self.output['img_tar']) 253 | # Content (Perceptual) 254 | self.output['loss_content'] = self.crit_vgg(self.output['img_out'], self.output['img_tar'], loss_type='content') 255 | # Style 256 | if self.opt.loss_weight_style > 0: 257 | self.output['loss_style'] = self.crit_vgg(self.output['img_out'], self.output['img_tar'], loss_type='style') 258 | # GAN 259 | if self.use_gan: 260 | input_D = self.get_tensor(self.opt.D_input_type_fake) 261 | self.output['loss_G'] = self.crit_gan(self.netD(input_D), True) 262 | 263 | def backward(self, check_grad=False): 264 | loss_ce = 0.5 265 | # if not check_grad: 266 | loss = 0 267 | loss += self.output['loss_l1'] * self.opt.loss_weight_l1 268 | loss += self.output['loss_content'] * self.opt.loss_weight_content 269 | if self.opt.loss_weight_style > 0: 270 | loss += self.output['loss_style'] * self.opt.loss_weight_style 271 | if self.use_gan: 272 | loss += self.output['loss_G'] * self.opt.loss_weight_gan 273 | 274 | self.output['total_G_loss'] = loss 275 | loss.backward() 276 | 277 | def backward_D(self): 278 | input_D_real = self.get_tensor(self.opt.D_input_type_real).detach() 279 | input_D_fake = self.get_tensor(self.opt.D_input_type_fake).detach() 280 | self.output['loss_D'] = 0.5 * (self.crit_gan(self.netD(input_D_real), True) + \ 281 | self.crit_gan(self.netD(input_D_fake), False)) 282 | (self.output['loss_D'] * self.opt.loss_weight_gan).backward() 283 | 284 | def optimize_parameters(self, check_grad=False): 285 | self.output = {} 286 | # forward 287 | self.forward() 288 | # optim netD 289 | if self.use_gan: 290 | self.optim_D.zero_grad() 291 | self.backward_D() 292 | self.optim_D.step() 293 | # optim netG 294 | self.optim.zero_grad() 295 | self.compute_loss() 296 | self.backward(check_grad) 297 | self.optim.step() 298 | 299 | def get_tensor_dim(self, tensor_type): 300 | dim = 0 301 | tensor_items = tensor_type.split('+') 302 | for item in tensor_items: 303 | if item in {'img_1', 'img_2', 'img_out', 'img_warp', 'img_out_G'}: 304 | dim += 3 305 | elif item in {'seg_1', 'seg_2'}: 306 | dim += self.opt.seg_nc 307 | elif item in {'seg_cihp_1', 'seg_cihp_2', 'pre_seg_cihp_2'}: 308 | dim += self.seg_cihp_nc 309 | elif item in {'joint_1', 'joint_2'}: 310 | dim += self.opt.joint_nc 311 | elif item in {'flow_out', 'flow_tar'}: 312 | dim += 2 313 | elif item in {'vis_out', 'vis_tar'}: 314 | dim += 1 315 | elif item in {'vismap_out', 'vismap_tar'}: 316 | dim += 3 317 | else: 318 | raise Exception('invalid tensor_type: %s' % item) 319 | return dim 320 | 321 | def get_tensor(self, tensor_type): 322 | tensor = [] 323 | tensor_items = tensor_type.split('+') 324 | for item in tensor_items: 325 | if item == 'img_1': 326 | tensor.append(self.input['img_1']) 327 | elif item == 'img_2': 328 | tensor.append(self.input['img_2']) 329 | elif item == 'img_out': 330 | tensor.append(self.output['img_out']) 331 | elif item == 'img_out_G': 332 | tensor.append(self.output['img_out_G']) 333 | elif item == 'img_warp': 334 | tensor.append(self.output['img_warp']) 335 | elif item == 'seg_1': 336 | tensor.append(self.input['seg_1']) 337 | elif item == 'seg_2': 338 | tensor.append(self.input['seg_2']) 339 | elif item == 'seg_cihp_1': 340 | tensor.append(self.input['seg_cihp_1']) 341 | elif item == 'seg_cihp_2': 342 | tensor.append(self.input['seg_cihp_2']) 343 | elif item == 'pre_seg_cihp_2': 344 | tensor.append(self.output['pre_seg_cihp_2']) 345 | elif item == 'joint_1': 346 | tensor.append(self.input['joint_1']) 347 | elif item 
== 'joint_2': 348 | tensor.append(self.input['joint_2']) 349 | elif item == 'flow_out': 350 | tensor.append(self.output['flow_out']) 351 | elif item == 'flow_tar': 352 | tensor.append(self.output['flow_tar']) 353 | elif item == 'vis_out': 354 | tensor.append(self.output['vis_out']) 355 | elif item == 'vis_tar': 356 | tensor.append(self.output['vis_tar']) 357 | elif item == 'vismap_out': 358 | tensor.append(self.output['vismap_out']) 359 | elif item == 'vismap_tar': 360 | tensor.append(self.output['vismap_tar']) 361 | elif item == 'dis_map': 362 | if self.opt.joint_PATN: 363 | self.input['dis_map'] = torch.exp((-0.1) * self.input['dis_map']) 364 | tensor.append(self.input['dis_map']) 365 | else: 366 | raise Exception('invalid tensor_type: %s' % item) 367 | tensor = torch.cat(tensor, dim=1) 368 | return tensor 369 | 370 | def get_current_errors(self): 371 | error_list = [ 372 | 'PSNR', 373 | 'SSIM', 374 | 'mask_SSIM', 375 | 'loss_l1', 376 | 'loss_content', 377 | 'loss_style', 378 | 'loss_G', 379 | 'loss_D', 380 | 'total_G_loss', 381 | 'grad_l1', 382 | 'grad_content', 383 | 'grad_style', 384 | 'grad_G', 385 | 'loss_ce' 386 | ] 387 | errors = OrderedDict() 388 | for item in error_list: 389 | if item in self.output: 390 | errors[item] = self.output[item].item() 391 | 392 | return errors 393 | 394 | def delvar(self): 395 | for k in self.output.keys(): 396 | del self.output[k] 397 | torch.cuda.empty_cache() 398 | 399 | def get_current_visuals(self): 400 | visual_items = [ 401 | ('img_1', [self.input['img_1'].data.cpu(), 'rgb']), 402 | ('joint_1', [self.input['joint_1'].data.cpu(), 'pose']), 403 | ('joint_2', [self.input['joint_2'].data.cpu(), 'pose']), 404 | ('flow_out', [self.output['flow_out'].data.cpu(), 'flow']), 405 | ('vis_out', [self.output['vis_out'].data.cpu(), 'vis']), 406 | ] 407 | 408 | if self.use_parsing: 409 | visual_items += [ 410 | ('seg_cihp_1', [self.input['seg_cihp_1'].data.cpu(), 'seg']), 411 | ('seg_cihp_2', [self.input['seg_cihp_2'].data.cpu(), 'seg']) 412 | ] 413 | 414 | if self.opt.G_pix_warp: 415 | visual_items += [ 416 | ('img_warp', [self.output['img_warp'].data.cpu(), 'rgb']), 417 | ('img_out_G', [self.output['img_out_G'].data.cpu(), 'rgb']), 418 | ('pix_mask', [self.output['pix_mask'].data.cpu(), 'softmask']), 419 | ('img_out', [self.output['img_out'].data.cpu(), 'rgb']), 420 | ('img_tar', [self.output['img_tar'].data.cpu(), 'rgb']) 421 | ] 422 | else: 423 | visual_items += [ 424 | ('img_warp', [self.output['img_warp'].data.cpu(), 'rgb']), 425 | ('img_out', [self.output['img_out'].data.cpu(), 'rgb']), 426 | ('img_tar', [self.output['img_tar'].data.cpu(), 'rgb']) 427 | ] 428 | 429 | visuals = OrderedDict(visual_items) 430 | return visuals 431 | 432 | def save(self, label): 433 | # save network weights 434 | self.save_network(self.netG, 'netG', label, self.gpu_ids) 435 | if self.opt.G_pix_warp: 436 | self.save_network(self.netPW, 'netPW', label, self.gpu_ids) 437 | # save optimizer status 438 | self.save_optim(self.optim, 'optim', label) 439 | # note 440 | self.save_network(self.netD, 'netD', label, self.gpu_ids) 441 | self.save_optim(self.optim_D, 'optim_D', label) 442 | 443 | def train(self): 444 | # netG and netD will always be in 'train' status 445 | pass 446 | 447 | def eval(self): 448 | # netG and netD will always be in 'train' status 449 | pass 450 | 451 | 452 | ################################################## 453 | # helper functions 454 | ################################################## 455 | def load_flow_network(model_id, epoch='best', gpu_ids=[]): 456 | 
from .flow_regression_model import FlowRegressionModel 457 | opt_dict = io.load_json(os.path.join('checkpoints', model_id, 'train_opt.json')) 458 | opt = argparse.Namespace(**opt_dict) 459 | opt.gpu_ids = gpu_ids 460 | opt.is_train = False # prevent loading discriminator, optimizer... 461 | opt.which_epoch = epoch 462 | # create network 463 | model = FlowRegressionModel() 464 | model.initialize(opt) 465 | return model.netF -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/options/__init__.py -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/pose_transfer_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/options/__pycache__/pose_transfer_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import argparse 5 | import os 6 | import util.io as io 7 | 8 | 9 | def opt_to_str(opt): 10 | return '\n'.join(['%s: %s' % (str(k), str(v)) for k, v in sorted(vars(opt).items())]) 11 | 12 | 13 | 14 | class BaseOptions(object): 15 | 16 | def __init__(self): 17 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | self.initialized = False 19 | self.opt = None 20 | 21 | def initialize(self): 22 | parser = self.parser 23 | # basic experiment options 24 | parser.add_argument('--id', type = str, default = 'default', help = 'experiment ID. the experiment dir will be set as "./checkpoint/id/"') 25 | parser.add_argument('--gpu_ids', type = str, default = '0', help = 'gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | 27 | self.initialized = True 28 | 29 | 30 | def auto_set(self): 31 | ''' 32 | options that will be automatically set 33 | ''' 34 | # set training status 35 | self.opt.is_train = self.is_train 36 | 37 | # set gpu_ids 38 | str_ids = self.opt.gpu_ids.split(',') 39 | self.opt.gpu_ids = [] 40 | for str_id in str_ids: 41 | g_id = int(str_id) 42 | if g_id >= 0: 43 | self.opt.gpu_ids.append(g_id) 44 | # set gpu devices 45 | if len(self.opt.gpu_ids) > 0: 46 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in self.opt.gpu_ids]) 47 | torch.cuda.set_device(0) 48 | 49 | 50 | 51 | def parse(self, ord_str = None, display = True): 52 | ''' 53 | Parse option from terminal command string. If ord_str is given, parse option from it instead. 
54 | ''' 55 | 56 | if not self.initialized: 57 | self.initialize() 58 | 59 | if ord_str is None: 60 | self.opt = self.parser.parse_args() 61 | else: 62 | ord_list = ord_str.split() 63 | self.opt = self.parser.parse_args(ord_list) 64 | 65 | self.auto_set() 66 | # display options 67 | if display: 68 | print('------------ Options -------------') 69 | for k, v in sorted(vars(self.opt).items()): 70 | print('%s: %s' % (str(k), str(v))) 71 | print('-------------- End ----------------') 72 | return self.opt 73 | 74 | def save(self, fn=None): 75 | if self.opt is None: 76 | raise Exception("parse options before saving!") 77 | if fn is None: 78 | expr_dir = os.path.join('checkpoints', self.opt.id) 79 | io.mkdir_if_missing(expr_dir) 80 | if self.opt.is_train: 81 | fn = os.path.join(expr_dir, 'train_opt.json') 82 | else: 83 | fn = os.path.join(expr_dir, 'test_opt.json') 84 | io.save_json(vars(self.opt), fn) 85 | 86 | def load(self, fn): 87 | args = io.load_json(fn) 88 | return argparse.Namespace(**args) 89 | -------------------------------------------------------------------------------- /options/pose_transfer_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class BasePoseTransferOptions(BaseOptions): 4 | def initialize(self): 5 | super(BasePoseTransferOptions, self).initialize() 6 | parser = self.parser 7 | ############################## 8 | # Model Setting 9 | ############################## 10 | parser.add_argument('--init_type', type = str, default = 'kaiming', help = 'network initialization method [normal|xavier|kaiming|orthogonal]') 11 | parser.add_argument('--use_dropout', type=int, default=0, choices=[0,1], help='use dropout in generator') 12 | parser.add_argument('--which_model_G', type=str, default='dual_unet', choices=['unet', 'dual_unet'], help='generator network architecture') 13 | parser.add_argument('--pretrained_G_id', type=str, default=None) 14 | parser.add_argument('--pretrained_G_epoch', type=str, default='best') 15 | parser.add_argument('--G_nf', type=int, default=32, help='feature dimension at the bottom layer') 16 | parser.add_argument('--G_max_nf', type=int, default=128, help='max feature dimension') 17 | parser.add_argument('--G_n_scale', type=int, default=7, help='scale level number') 18 | parser.add_argument('--G_norm', type=str, default='batch', choices=['none', 'batch', 'instance'], help='type of normalization layer') 19 | parser.add_argument('--G_activation', type=str, default='relu', choices=['relu', 'leaky_relu'], help='type of activation function') 20 | parser.add_argument('--G_pose_type', type=str, default='joint_2') 21 | parser.add_argument('--G_appearance_type', type=str, default='img_1') 22 | # netG (only dual unet) 23 | parser.add_argument('--G_feat_warp', type=int, default=1, choices=[0,1], help='set 1 to use feature warping; otherwise the model is a simple unet with 2 encoders for pose and appearance respectively') 24 | parser.add_argument('--G_n_warp_scale', type=int, default=5, help='at scales higher than this, feature warping will not be performed (because the resolution of feature map is too small)') 25 | parser.add_argument('--G_vis_mode', type=str, default='residual', choices=['none', 'hard_gate', 'soft_gate', 'residual', 'res_no_vis'], help='different approaches to integrate visibility map in feature warping module') 26 | parser.add_argument('--G_no_end_norm', type=int, default=0, choices=[0,1], help='if set as 1, convolution at the start and the end of netG will 
not followed by norm_layer like BN.') 27 | # netG (pixel warping module) 28 | parser.add_argument('--G_pix_warp', type=int, default=0, choices=[0,1], help='use pixel warping module') 29 | # parser.add_argument('--G_pix_warp', type=str, default='none', choices=['none', 'mask', 'mask+flow', 'ext_mask', 'ext_mask+flow', 'exth_mask', 'exth_mask+flow'], help='combine generated image_2 and warped image_1 to synthesize final output. "mask": netG output a soft-mask to combine img_gen and img_warp; "mask+flow": netG output a soft-mask and a flow residual') 30 | parser.add_argument('--G_pix_warp_input_type', type=str, default='img_out_G+img_warp+vis_out+flow_out') 31 | parser.add_argument('--G_pix_warp_detach', type=int, default=1, choices=[0,1], help='generated image will be detached when it is used to combine with warped image. Thus the gradient from combined image will only propagate backward to soft-mask') 32 | # netD 33 | parser.add_argument('--D_nf', type=int, default=64, help='feature number of first conv layer in netD') 34 | parser.add_argument('--D_n_layers', type=int, default=3, help='number of conv layers in netD (patch gan)') 35 | parser.add_argument('--gan_type', type=str, default='dcgan', choices=['lsgan', 'dcgan'], help='gan loss type') 36 | parser.add_argument('--D_input_type_real', type=str, default='img_1+img_2+joint_2', help='input data items to netD') 37 | parser.add_argument('--D_input_type_fake', type=str, default='img_1+img_out+joint_2', help='input data items to netD') 38 | # netF 39 | parser.add_argument('--flow_on_the_fly', type=int, default=1, choices=[0,1], help='use a flow3d model to generate flow on-the-fly') 40 | parser.add_argument('--F_input_type', type=str, default='joint_1+joint_2', help='input data items for netF(flow) which flow is generated on-the-fly') 41 | parser.add_argument('--pretrained_flow_id', type=str, default='FlowReg_0.1', help='model id of flow regression model') 42 | parser.add_argument('--pretrained_flow_epoch', type=str, default='best', help='which epoch to load pretrained flow regression module') 43 | ############################## 44 | # Pose Setting 45 | ############################## 46 | parser.add_argument('--joint_nc', type=int, default=18, help='2d joint number. 18 for openpose joint') 47 | parser.add_argument('--joint_mode', type=str, default='binary', choices=['binary', 'gaussian']) 48 | parser.add_argument('--joint_radius', type=int, default=8, help='radius of joint map') 49 | parser.add_argument('--seg_nc', type=int, default=7, help='number of segmentation classes') 50 | ############################## 51 | # data setting (dataset_mode == general_pair) 52 | ############################## 53 | parser.add_argument('--dataset_type', type=str, default='pose_transfer', help='type of dataset. 
see data/data_loader.py') 54 | parser.add_argument('--dataset_name', type=str, default='deepfashion') 55 | parser.add_argument('--image_size', type=int, nargs='+', default=[256,256]) 56 | parser.add_argument('--batch_size', type = int, default = 8, help = 'batch size') 57 | parser.add_argument('--data_root', type=str, default=None, help='Set in Options.auto_set()') 58 | parser.add_argument('--fn_split', type=str, default=None, help='Set in Options.auto_set()') 59 | parser.add_argument('--img_dir', type=str, default=None, help='Set in Options.auto_set()') 60 | parser.add_argument('--seg_dir', type=str, default=None, help='Set in Options.auto_set()') 61 | parser.add_argument('--seg_pred_dir', type=str, default=None, help='dest parsing label preded by our model') 62 | parser.add_argument('--fn_pose', type=str, default=None, help='Set in Options.auto_set()') 63 | parser.add_argument('--debug', action='store_true', help='debug') 64 | 65 | parser.add_argument('--use_augmentation', type=int, default=0, choices=[0,1]) 66 | parser.add_argument('--aug_scale_range', type=float, default=1.2) 67 | parser.add_argument('--aug_shiftx_range', type=int, default=10) 68 | parser.add_argument('--aug_shifty_range', type=int, default=10) 69 | parser.add_argument('--aug_color_jit', type=int, default=0, choices=[0,1]) 70 | parser.add_argument('--vis_smooth_rate', type=int, default=5, help='use a median filter of size # to smooth the visiblity map') 71 | 72 | parser.add_argument('--spade_layers', type=int, default=3) 73 | 74 | def auto_set(self): 75 | super(BasePoseTransferOptions, self).auto_set() 76 | opt = self.opt 77 | ########################################### 78 | # Add id profix 79 | ########################################### 80 | if not opt.id.startswith('PoseTransfer_'): 81 | opt.id = 'PoseTransfer_' + opt.id 82 | ########################################### 83 | # Set dataset path 84 | ########################################### 85 | if opt.dataset_name == 'market': 86 | opt.image_size=[128,64] 87 | opt.joint_radius = 4 88 | opt.G_n_scales = 5 89 | opt.G_n_scale = 5 90 | opt.G_n_warp_scale = 4 91 | 92 | opt.use_augmentation = 1 93 | opt.data_root = 'datasets/market1501/' 94 | opt.fn_split = 'label/split.json' 95 | opt.img_dir = 'img/train' if opt.is_train else 'img/test' 96 | opt.fn_pose = 'label/pose_label.pkl' 97 | opt.pretrained_flow_id = 'FlowReg_market' 98 | opt.seg_dir = 'seg/' 99 | opt.seg_pred_dir = 'seg/' 100 | 101 | else: 102 | opt.data_root = 'datasets/deepfashion/' 103 | opt.fn_split = 'label/split.json' 104 | opt.img_dir = 'img/img' 105 | opt.fn_pose = 'label/pose_label.pkl' 106 | opt.pretrained_flow_id = 'FlowReg_deepfashion' 107 | opt.seg_dir = 'seg/' 108 | opt.seg_pred_dir = 'seg/' if opt.is_train else 'pred_seg/' 109 | 110 | 111 | 112 | class TrainPoseTransferOptions(BasePoseTransferOptions): 113 | def initialize(self): 114 | super(TrainPoseTransferOptions, self).initialize() 115 | self.is_train = True 116 | parser = self.parser 117 | # basic 118 | parser.add_argument('--resume_train', action = 'store_true', default = False, help = 'resume training from saved checkpoint') 119 | parser.add_argument('--last_epoch', type=int, default=1) 120 | parser.add_argument('--small_val_set', type=int, default=1, choices=[0,1], help='use 1/5 test samples as validation set') 121 | # optimizer 122 | parser.add_argument('--lr', type = float, default = 2e-4, help = 'initial learning rate') 123 | parser.add_argument('--beta1', type = float, default = 0.5, help = 'momentum1 term for Adam') 124 | 
parser.add_argument('--beta2', type = float, default = 0.999, help = 'momentum2 term for Adam') 125 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 126 | parser.add_argument('--lr_D', type=float, default=2e-5) 127 | parser.add_argument('--weight_decay_D', type=float, default=4e-4) 128 | # scheduler 129 | parser.add_argument('--lr_policy', type=str, default='step', choices = ['step', 'plateau', 'lambda'], help='learning rate policy: lambda|step|plateau') 130 | parser.add_argument('--n_epoch', type = int, default=30, help = '# of epoch at starting learning rate') 131 | parser.add_argument('--n_epoch_decay', type=int, default=0, help='# of epoch to linearly decay learning rate to zero') 132 | parser.add_argument('--lr_decay', type=int, default=100, help='multiply by a gamma every lr_decay_interval epochs') 133 | parser.add_argument('--lr_gamma', type = float, default = 0.1, help='lr decay rate') 134 | parser.add_argument('--display_freq', type = int, default = 100, help='frequency of showing training results on screen') 135 | parser.add_argument('--test_epoch_freq', type = int, default = 1, help='frequency of testing model') 136 | parser.add_argument('--save_epoch_freq', type = int, default = 1, help='frequency of saving model to disk' ) 137 | parser.add_argument('--vis_epoch_freq', type = int, default = 1, help='frequency of visualizing generated images') 138 | parser.add_argument('--check_grad_freq', type = int, default = 100, help = 'frequency of checking gradient of each loss') 139 | parser.add_argument('--n_vis', type = int, default = 64, help='number of visualized images') 140 | # loss setting 141 | parser.add_argument('--epoch_add_gan', type=int, default=6, help='add gan loss after # epochs of training') 142 | parser.add_argument('--loss_weight_l1', type=float, default=1.) 143 | parser.add_argument('--loss_weight_content', type=float, default=1.) 
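# these weights are combined in PoseTransferModel.backward():
#   total_G_loss = loss_weight_l1 * L1 + loss_weight_content * VGG content
#                  + loss_weight_style * style (if loss_weight_style > 0) + loss_weight_gan * GAN (if enabled)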
144 | parser.add_argument('--loss_weight_style', type=float, default=0) 145 | parser.add_argument('--loss_weight_gan', type=float, default=0.01) 146 | parser.add_argument('--shifted_style_loss', type=int, default=1, choices=[0,1]) 147 | #parser.add_argument('--vgg_content_weights', type=float, nargs='+', default=[1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]) 148 | parser.add_argument('--vgg_content_weights', type=float, nargs='+', default=[0.125, 0.125, 0.125, 0.125, 0.125]) 149 | parser.add_argument('--vgg_content_mode', type=str, default='balance', choices=['balance', 'imbalance', 'special']) 150 | 151 | def auto_set(self): 152 | super(TrainPoseTransferOptions, self).auto_set() 153 | opt = self.opt 154 | if opt.vgg_content_mode == 'balance': 155 | opt.vgg_content_weights = [0.125, 0.125, 0.125, 0.125, 0.125] 156 | elif opt.vgg_content_mode == 'imbalance': 157 | opt.vgg_content_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 158 | 159 | 160 | class TestPoseTransferOptions(BasePoseTransferOptions): 161 | def initialize(self): 162 | super(TestPoseTransferOptions, self).initialize() 163 | self.is_train = False 164 | parser = self.parser 165 | parser.add_argument('--which_epoch', type=str, default='best') 166 | parser.add_argument('--data_split', type=str, default='test') 167 | parser.add_argument('--n_test_batch', type=int, default=-1, help='set number of minibatch used for test') 168 | # visualize samples 169 | parser.add_argument('--n_vis', type = int, default = 64, help='number of visualized images') 170 | # save generated images 171 | parser.add_argument('--save_output', action='store_true', help='save output images in the folder exp_dir/test/') 172 | parser.add_argument('--output_dir', type=str, default='output', help='path to save generated images') 173 | parser.add_argument('--masked', action='store_true', help='also test masked-ssim (for market-1501)') 174 | 175 | 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | torch==1.2.0 3 | torchvision==0.4.0 4 | scikit-image==0.14.2 5 | tqdm==4.25.0 6 | opencv-python==4.4.0.46 7 | imageio==2.5.0 8 | lpips==0.1.3 9 | dominate==2.3.5 10 | h5py==2.9.0 11 | Pillow==5.4.1 12 | pandas==0.25.3 13 | tensorboardX==1.8 14 | 15 | -------------------------------------------------------------------------------- /scripts/test_pose_transfer_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import sys 3 | sys.path.append('.') 4 | import torch 5 | from data.data_loader import CreateDataLoader 6 | from options.pose_transfer_options import TestPoseTransferOptions 7 | from models.pose_transfer_model import PoseTransferModel 8 | from util.visualizer import Visualizer 9 | from util.loss_buffer import LossBuffer 10 | import util.io as io 11 | import os 12 | import numpy as np 13 | import tqdm 14 | import cv2 15 | import time 16 | from collections import OrderedDict 17 | 18 | parser = TestPoseTransferOptions() 19 | opt = parser.parse(display=False) 20 | parser.save() 21 | 22 | model = PoseTransferModel() 23 | model.initialize(opt) 24 | val_loader = CreateDataLoader(opt, split='test') 25 | # create visualizer 26 | visualizer = Visualizer(opt) 27 | 28 | # visualize 29 | if opt.n_vis > 0: 30 | print('visualizing first %d samples' % opt.n_vis) 31 | num_vis_batch = min(int(np.ceil(1.0 * opt.n_vis / opt.batch_size)), len(val_loader)) 
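# just enough whole batches to cover the first n_vis samples, capped at the number of available batches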
32 | val_loader.dataset.set_len(num_vis_batch * opt.batch_size) 33 | val_visuals = None 34 | for i, data in enumerate(tqdm.tqdm(val_loader, desc='Visualize')): 35 | model.eval() 36 | model.netG.eval() 37 | model.netF.eval() 38 | model.set_input(data) 39 | model.test(compute_loss=False) 40 | visuals = model.get_current_visuals() 41 | if val_visuals is None: 42 | val_visuals = visuals 43 | else: 44 | for name, v in visuals.items(): 45 | val_visuals[name][0] = torch.cat((val_visuals[name][0], v[0]), dim=0) 46 | 47 | fn_vis = os.path.join('checkpoints', opt.id, 'vis', 'test_epoch%s.jpg'%opt.which_epoch) 48 | 49 | visualizer.visualize_results(val_visuals, fn_vis) 50 | 51 | # test 52 | if opt.n_test_batch != 0: 53 | val_loader.dataset.set_len(opt.n_test_batch * opt.batch_size) 54 | # print(opt.n_test_batch * opt.batch_size) 55 | loss_buffer = LossBuffer(size=len(val_loader)) 56 | model.output = {} 57 | if opt.save_output: 58 | output_dir = os.path.join(model.save_dir, opt.output_dir) 59 | io.mkdir_if_missing(output_dir) 60 | 61 | total_time = 0 62 | for i, data in enumerate(tqdm.tqdm(val_loader, desc='Test')): 63 | tic = time.time() 64 | model.eval() 65 | model.netG.eval() 66 | model.netF.eval() 67 | model.set_input(data) 68 | model.test() 69 | toc = time.time() 70 | total_time += (toc - tic) 71 | loss_buffer.add(model.get_current_errors()) 72 | # save output 73 | if opt.save_output: 74 | id_list = model.input['id'] 75 | images = model.output['img_out'].cpu().numpy().transpose(0, 2, 3, 1) 76 | images = ((images + 1.0) * 127.5).clip(0, 255).astype(np.uint8) 77 | for (sid1, sid2), img in zip(id_list, images): 78 | img = img[..., [2, 1, 0]] # convert to cv2 format 79 | cv2.imwrite(os.path.join(output_dir, '%s_%s.jpg' % (sid1, sid2)), img) 80 | 81 | test_error = loss_buffer.get_errors() 82 | test_error['sec_per_image'] = total_time / (opt.batch_size * len(val_loader)) 83 | info = OrderedDict([('model_id', opt.id), ('epoch', opt.which_epoch)]) 84 | log_str = visualizer.log(info, test_error, log_in_file=False) 85 | print(log_str) 86 | 87 | -------------------------------------------------------------------------------- /scripts/train_pose_transfer_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import os 3 | import sys 4 | sys.path.append('.') 5 | 6 | import torch 7 | import tensorboardX 8 | from data.data_loader import CreateDataLoader 9 | from models.pose_transfer_model import PoseTransferModel 10 | from options.pose_transfer_options import TrainPoseTransferOptions 11 | from util.visualizer import Visualizer 12 | from util.loss_buffer import LossBuffer 13 | 14 | import util.io as io 15 | import tqdm 16 | import time 17 | from collections import OrderedDict 18 | 19 | # parse and save options 20 | parser = TrainPoseTransferOptions() 21 | opt = parser.parse() 22 | parser.save() 23 | # create model 24 | model = PoseTransferModel() 25 | model.initialize(opt) 26 | # save terminal line 27 | io.save_str_list([' '.join(sys.argv)], os.path.join(model.save_dir, 'order_line.txt')) 28 | # create data loader 29 | train_loader = CreateDataLoader(opt, split='train') 30 | val_loader = CreateDataLoader(opt, split='test' if not opt.small_val_set else 'test_small') 31 | # create visualizer 32 | visualizer = Visualizer(opt) 33 | 34 | logdir = os.path.join('logs', opt.id) 35 | if not os.path.exists(logdir): 36 | os.makedirs(logdir) 37 | writer = tensorboardX.SummaryWriter(logdir) 38 | 39 | # set "saving best" 40 | best_info 
= { 41 | 'meas': 'SSIM', 42 | 'type': 'max', 43 | 'best_value': 0, 44 | 'best_epoch': -1 45 | } 46 | 47 | # set continue training 48 | if not opt.resume_train: 49 | total_steps = 0 50 | epoch_count = 1 51 | else: 52 | last_epoch = int(opt.last_epoch) 53 | total_steps = len(train_loader)*last_epoch 54 | epoch_count = 1 + last_epoch 55 | 56 | if opt.debug: 57 | opt.display_freq = 2 58 | 59 | for epoch in tqdm.trange(epoch_count, opt.n_epoch+opt.n_epoch_decay+1, desc='Epoch'): 60 | #train model 61 | model.train() 62 | model.netG.train() 63 | model.netD.train() 64 | if model.opt.G_pix_warp: 65 | model.netPW.train() 66 | 67 | model.use_gan = (opt.loss_weight_gan > 0) and (epoch >= opt.epoch_add_gan) 68 | for i,data in enumerate(tqdm.tqdm(train_loader, desc='Train')): 69 | total_steps += 1 70 | model.set_input(data) 71 | model.optimize_parameters(check_grad=(opt.check_grad_freq>0 and total_steps%opt.check_grad_freq==0)) 72 | 73 | if total_steps % opt.display_freq == 0: 74 | train_error = model.get_current_errors() 75 | info = OrderedDict([ 76 | ('id', opt.id), 77 | ('iter', total_steps), 78 | ('epoch', epoch), 79 | ('lr', model.optimizers[0].param_groups[0]['lr']), 80 | ]) 81 | tqdm.tqdm.write(visualizer.log(info, train_error)) 82 | for k, v in train_error.items(): 83 | writer.add_scalar(k, v, total_steps) 84 | writer.add_scalar('lr', model.optimizers[0].param_groups[0]['lr'], total_steps) 85 | writer.flush() 86 | 87 | #update learning rate(lr_scheduler.step()) after optim.step(), otherwise lost first lr 88 | model.update_learning_rate() 89 | 90 | if epoch % opt.test_epoch_freq == 0: 91 | # model.get_current_errors() #erase training error information 92 | model.output = {} 93 | loss_buffer = LossBuffer(size=len(val_loader)) 94 | #eval model 95 | model.netG.eval() 96 | model.eval() 97 | if model.opt.G_pix_warp: 98 | model.netPW.eval() 99 | 100 | for i, data in enumerate(tqdm.tqdm(val_loader, desc='Test')): 101 | model.set_input(data) 102 | model.test(compute_loss=True) 103 | loss_buffer.add(model.get_current_errors()) 104 | test_error = loss_buffer.get_errors() 105 | info = OrderedDict([ 106 | ('time', time.ctime()), 107 | ('id', opt.id), 108 | ('epoch', epoch), 109 | ]) 110 | tqdm.tqdm.write(visualizer.log(info, test_error)) 111 | # save best 112 | if best_info['best_epoch']==-1 or (test_error[best_info['meas']].item()<best_info['best_value'] and best_info['type']=='min') or (test_error[best_info['meas']].item()>best_info['best_value'] and best_info['type']=='max'): 113 | tqdm.tqdm.write('save as best epoch!') 114 | best_info['best_epoch'] = epoch 115 | best_info['best_value'] = test_error[best_info['meas']].item() 116 | model.save('best') 117 | tqdm.tqdm.write(visualizer.log(best_info)) 118 | 119 | if epoch % opt.vis_epoch_freq == 0: 120 | #eval model 121 | model.eval() 122 | model.netG.eval() 123 | if model.opt.G_pix_warp: 124 | model.netPW.eval() 125 | 126 | num_vis_batch = int(1.*opt.n_vis/opt.batch_size) 127 | visuals = None 128 | for i, data in enumerate(train_loader): 129 | if i == num_vis_batch: 130 | break 131 | model.set_input(data) 132 | model.test(compute_loss=True) 133 | v = model.get_current_visuals() 134 | if visuals is None: 135 | visuals = v 136 | else: 137 | for name, item in v.items(): 138 | visuals[name][0] = torch.cat((visuals[name][0], item[0]), dim=0) 139 | tqdm.tqdm.write('visualizing training sample') 140 | fn_vis = os.path.join('checkpoints', opt.id, 'vis', 'train_epoch%d.jpg'%epoch) 141 | visualizer.visualize_results(visuals, fn_vis) 142 | 143 | visuals = None 144 | for i, data in enumerate(val_loader): 145 | if i == num_vis_batch: 146 | break 147 |
model.set_input(data) 148 | model.test(compute_loss=True) 149 | v = model.get_current_visuals() 150 | if visuals is None: 151 | visuals = v 152 | else: 153 | for name, item in v.items(): 154 | visuals[name][0] = torch.cat((visuals[name][0], item[0]), dim=0) 155 | tqdm.tqdm.write('visualizing test sample') 156 | fn_vis = os.path.join('checkpoints', opt.id, 'vis', 'test_epoch%d.jpg'%epoch) 157 | visualizer.visualize_results(visuals, fn_vis) 158 | 159 | if epoch % opt.save_epoch_freq == 0: 160 | model.save(epoch) 161 | model.save('latest') 162 | print(best_info) 163 | -------------------------------------------------------------------------------- /test_deepfashion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/test_pose_transfer_model.py \ 4 | --id deepfashion \ 5 | --gpu_ids 0 \ 6 | --dataset_name deepfashion \ 7 | --which_model_G dual_unet \ 8 | --G_feat_warp 1 \ 9 | --G_vis_mode residual \ 10 | --pretrained_flow_id FlowReg_deepfashion \ 11 | --pretrained_flow_epoch best \ 12 | --dataset_type pose_transfer_parsing \ 13 | --which_epoch latest \ 14 | --batch_size 4 \ 15 | --save_output \ 16 | --output_dir output 17 | -------------------------------------------------------------------------------- /test_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/test_pose_transfer_model.py \ 4 | --id market \ 5 | --gpu_ids 0 \ 6 | --dataset_name market \ 7 | --which_model_G dual_unet \ 8 | --G_feat_warp 1 \ 9 | --G_vis_mode residual \ 10 | --pretrained_flow_id FlowReg_market \ 11 | --pretrained_flow_epoch best \ 12 | --dataset_type pose_transfer_parsing_market \ 13 | --which_epoch latest \ 14 | --batch_size 1 \ 15 | --save_output \ 16 | --output_dir output -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/tools/__init__.py -------------------------------------------------------------------------------- /tools/__pycache__/cmd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/tools/__pycache__/cmd.cpython-37.pyc -------------------------------------------------------------------------------- /tools/calPCKH_fashion.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import pandas as pd 3 | import json 4 | 5 | MISSING_VALUE = -1 6 | 7 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 8 | 9 | target_annotation = '../datasets/deepfashion/label/fasion-annotation-test.csv' 10 | pred_annotation = '../checkpoints/PoseTransfer_deepfashion/output_pckh.csv' 11 | 12 | 13 | ''' 14 | hz: head size 15 | alpha: norm factor 16 | px, py: predict coords 17 | tx, ty: target coords 18 | ''' 19 | def isRight(px, py, tx, ty, hz, alpha): 20 | if px == -1 or py == -1 or tx == -1 or ty == -1: 21 | return 0 22 | 23 | if abs(px - tx) < hz[0] * alpha and abs(py - ty) < hz[1] * alpha: 24 | return 1 25 | else: 26 | return 0 27 | 28 | 29 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 30 | nRight = 0 31 | for i in range(len(px)): 32 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 33 | 34 | return nRight 35 | 36 | 37 | def ValidPoints(tx): 38 | nValid = 0 39 | for item 
in tx: 40 | if item != -1: 41 | nValid = nValid + 1 42 | return nValid 43 | 44 | 45 | def get_head_wh(x_coords, y_coords): 46 | final_w, final_h = -1, -1 47 | component_count = 0 48 | save_componets = [] 49 | for component in PARTS_SEL: 50 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 51 | continue 52 | else: 53 | component_count += 1 54 | save_componets.append([x_coords[component], y_coords[component]]) 55 | if component_count >= 2: 56 | x_cords = [] 57 | y_cords = [] 58 | for component in save_componets: 59 | x_cords.append(component[0]) 60 | y_cords.append(component[1]) 61 | xmin = min(x_cords) 62 | xmax = max(x_cords) 63 | ymin = min(y_cords) 64 | ymax = max(y_cords) 65 | final_w = xmax - xmin 66 | final_h = ymax - ymin 67 | return final_w, final_h 68 | 69 | 70 | tAnno = pd.read_csv(target_annotation, sep=':') 71 | pAnno = pd.read_csv(pred_annotation, sep=':') 72 | 73 | pRows = pAnno.shape[0] 74 | 75 | nAll = 0 76 | nCorrect = 0 77 | alpha = 0.5 78 | for i in tqdm.tqdm(range(pRows)): 79 | pValues = pAnno.iloc[i].values 80 | pname = pValues[0] 81 | pycords = json.loads(pValues[1]) # list of numbers 82 | pxcords = json.loads(pValues[2]) 83 | 84 | tname = pname 85 | 86 | #### 87 | tname = tname.replace('.jpg_vis.jpg','.jpg') 88 | tname = tname.replace('.jpg___', '_') 89 | if tname.count('_')==5: 90 | ns = tname.split('_') 91 | tname = ns[0]+ns[1]+'_'+ns[2]+'_'+ns[3]+ns[4]+'_'+ns[5] 92 | tname = tname.replace('fashion', 'fasion') 93 | tValues = tAnno.query('name == "%s"' % (tname)).values[0] 94 | tycords = json.loads(tValues[1]) # list of numbers 95 | txcords = json.loads(tValues[2]) 96 | 97 | 98 | xBox, yBox = get_head_wh(txcords, tycords) 99 | if xBox == -1 or yBox == -1: 100 | continue 101 | 102 | head_size = (xBox, yBox) 103 | nAll = nAll + ValidPoints(tycords) 104 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha) 105 | 106 | print('%d/%d %f' % (nCorrect, nAll, nCorrect * 1.0 / nAll)) 107 | -------------------------------------------------------------------------------- /tools/calPCKH_market.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import tqdm 4 | MISSING_VALUE = -1 5 | 6 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 7 | 8 | # fix the PATH 9 | target_annotation = '../datasets/market1501/label/market-annotation-test.csv' 10 | pred_annotation = '../checkpoints/PoseTransfer_market/output_pckh.csv' 11 | 12 | 13 | ''' 14 | hz: head size 15 | alpha: norm factor 16 | px, py: predict coords 17 | tx, ty: target coords 18 | ''' 19 | def isRight(px, py, tx, ty, hz, alpha): 20 | if px == -1 or py == -1 or tx == -1 or ty == -1: 21 | return 0 22 | 23 | if abs(px-tx) < hz[0]*alpha and abs(py-ty) < hz[1]*alpha: 24 | return 1 25 | else: 26 | return 0 27 | 28 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 29 | nRight = 0 30 | for i in range(len(px)): 31 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 32 | 33 | return nRight 34 | 35 | def ValidPoints(tx): 36 | nValid = 0 37 | for item in tx: 38 | if item != -1: 39 | nValid = nValid + 1 40 | return nValid 41 | 42 | def get_head_wh(x_coords, y_coords): 43 | final_w, final_h = -1, -1 44 | component_count = 0 45 | save_componets = [] 46 | for component in PARTS_SEL: 47 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 48 | continue 49 | else: 50 | component_count += 1 51 | save_componets.append([x_coords[component], y_coords[component]]) 52 
| if component_count >= 2: 53 | x_cords = [] 54 | y_cords = [] 55 | for component in save_componets: 56 | x_cords.append(component[0]) 57 | y_cords.append(component[1]) 58 | xmin = min(x_cords) 59 | xmax = max(x_cords) 60 | ymin = min(y_cords) 61 | ymax = max(y_cords) 62 | final_w = xmax - xmin 63 | final_h = ymax - ymin 64 | return final_w, final_h 65 | 66 | 67 | 68 | 69 | 70 | tAnno = pd.read_csv(target_annotation, sep=':') 71 | pAnno = pd.read_csv(pred_annotation, sep=':') 72 | 73 | pRows = pAnno.shape[0] 74 | 75 | nAll = 0 76 | nCorrect = 0 77 | alpha = 0.5 78 | for i in tqdm.tqdm(range(pRows)): 79 | pValues = pAnno.iloc[i].values 80 | pname = pValues[0] 81 | pycords = json.loads(pValues[1]) #list of numbers 82 | pxcords = json.loads(pValues[2]) 83 | tnames = pname.split('_') 84 | tname = tnames[4]+'_'+tnames[5]+'_'+tnames[6]+'_'+tnames[7] 85 | tValues = tAnno.query('name == "%s"' %(tname)).values[0] 86 | tycords = json.loads(tValues[1]) #list of numbers 87 | txcords = json.loads(tValues[2]) 88 | 89 | xBox, yBox = get_head_wh(txcords, tycords) 90 | if xBox == -1 or yBox == -1: 91 | continue 92 | 93 | head_size = (xBox, yBox) 94 | nAll = nAll + ValidPoints(tycords) 95 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha) 96 | 97 | 98 | print('%d/%d %f' %(nCorrect, nAll, nCorrect*1.0/nAll)) 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /tools/cmd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def args(): 4 | """ 5 | Define args that is used in project 6 | """ 7 | parser = argparse.ArgumentParser(description="Pose guided image generation usign deformable skip layers") 8 | parser.add_argument("--output_dir", default='output/displayed_samples', help="Directory with generated sample images") 9 | parser.add_argument("--batch_size", default=4, type=int, help='Size of the batch') 10 | parser.add_argument("--training_ratio", default=1, type=int, 11 | help="The training ratio is the number of discriminator updates per generator update.") 12 | 13 | parser.add_argument("--l1_penalty_weight", default=100, type=float, help='Weight of l1 loss') 14 | parser.add_argument('--gan_penalty_weight', default=1, type=float, help='Weight of GAN loss') 15 | parser.add_argument('--tv_penalty_weight', default=0, type=float, help='Weight of total variation loss') 16 | parser.add_argument('--lstruct_penalty_weight', default=0, type=float, help="Weight of lstruct") 17 | 18 | parser.add_argument("--number_of_epochs", default=500, type=int, help="Number of training epochs") 19 | 20 | parser.add_argument("--content_loss_layer", default='none', help='Name of content layer (vgg19)' 21 | ' e.g. 
block4_conv1 or none') 22 | 23 | parser.add_argument("--checkpoints_dir", default="output/checkpoints", help="Folder with checkpoints") 24 | parser.add_argument("--checkpoint_ratio", default=30, type=int, help="Number of epochs between consecutive checkpoints") 25 | parser.add_argument("--generator_checkpoint", default=None, help="Previosly saved model of generator") 26 | parser.add_argument("--discriminator_checkpoint", default=None, help="Previosly saved model of discriminator") 27 | parser.add_argument("--nn_loss_area_size", default=1, type=int, help="Use nearest neighbour loss") 28 | parser.add_argument("--use_validation", default=1, type=int, help="Use validation") 29 | 30 | parser.add_argument('--dataset', default='market', choices=['market', 'fasion', 'fasion128', 'fasion128128'], 31 | help='Market or fasion') 32 | 33 | 34 | parser.add_argument("--display_ratio", default=1, type=int, help='Number of epochs between ploting') 35 | parser.add_argument("--start_epoch", default=0, type=int, help='Start epoch for starting from checkpoint') 36 | parser.add_argument("--pose_estimator", default='pose_estimator.h5', 37 | help='Pretrained model for cao pose estimator') 38 | 39 | parser.add_argument("--images_for_test", default=12000, type=int, help="Number of images for testing") 40 | 41 | parser.add_argument("--use_input_pose", default=True, type=int, help='Feed to generator input pose') 42 | parser.add_argument("--warp_skip", default='stn', choices=['none', 'full', 'mask', 'stn'], 43 | help="Type of warping skip layers to use.") 44 | parser.add_argument("--warp_agg", default='max', choices=['max', 'avg'], 45 | help="Type of aggregation.") 46 | 47 | parser.add_argument("--disc_type", default='call', choices=['call', 'sim', 'warp'], 48 | help="Type of discriminator call - concat all, sim - siamease, sharewarp - warp.") 49 | 50 | 51 | parser.add_argument("--generated_images_dir", default='output/generated_images', 52 | help='Folder with generated images from training dataset') 53 | 54 | parser.add_argument('--load_generated_images', default=0, type=int, 55 | help='Load images from generated_images_dir or generate') 56 | 57 | parser.add_argument('--use_dropout_test', default=0, type=int, 58 | help='To use dropout when generate images') 59 | 60 | args = parser.parse_args() 61 | 62 | args.images_dir_train = 'data/' + args.dataset + '-dataset/train' 63 | args.images_dir_test = 'data/' + args.dataset + '-dataset/test' 64 | 65 | args.annotations_file_train = 'data/' + args.dataset + '-annotation-train.csv' 66 | args.annotations_file_test = 'data/' + args.dataset + '-annotation-test.csv' 67 | 68 | args.pairs_file_train = 'data/' + args.dataset + '-pairs-train.csv' 69 | args.pairs_file_test = 'data/' + args.dataset + '-pairs-test.csv' 70 | 71 | if args.dataset == 'fasion': 72 | args.image_size = (256, 256) 73 | elif args.dataset == 'fasion128128': 74 | args.image_size = (128, 128) 75 | else: 76 | args.image_size = (128, 64) 77 | 78 | args.tmp_pose_dir = 'tmp/' + args.dataset + '/' 79 | 80 | del args.dataset 81 | 82 | return args 83 | -------------------------------------------------------------------------------- /tools/compute_coordinates.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pose_utils 4 | 5 | from keras.models import load_model 6 | import skimage.transform as st 7 | import pandas as pd 8 | from tqdm import tqdm 9 | from skimage.io import imread 10 | from skimage.transform import resize 11 | from 
scipy.ndimage import gaussian_filter 12 | 13 | from cmd import args 14 | 15 | args = args() 16 | 17 | model = load_model(args.pose_estimator) 18 | 19 | 20 | mapIdx = [[31,32], [39,40], [33,34], [35,36], [41,42], [43,44], [19,20], [21,22], 21 | [23,24], [25,26], [27,28], [29,30], [47,48], [49,50], [53,54], [51,52], 22 | [55,56], [37,38], [45,46]] 23 | 24 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], 25 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], 26 | [1,16], [16,18], [3,17], [6,18]] 27 | 28 | 29 | def compute_cordinates(heatmap_avg, paf_avg, th1=0.1, th2=0.05): 30 | all_peaks = [] 31 | peak_counter = 0 32 | 33 | for part in range(18): 34 | map_ori = heatmap_avg[:,:,part] 35 | map = gaussian_filter(map_ori, sigma=3) 36 | 37 | map_left = np.zeros(map.shape) 38 | map_left[1:,:] = map[:-1,:] 39 | map_right = np.zeros(map.shape) 40 | map_right[:-1,:] = map[1:,:] 41 | map_up = np.zeros(map.shape) 42 | map_up[:,1:] = map[:,:-1] 43 | map_down = np.zeros(map.shape) 44 | map_down[:,:-1] = map[:,1:] 45 | 46 | peaks_binary = np.logical_and.reduce((map>=map_left, map>=map_right, map>=map_up, map>=map_down, map > th1)) 47 | peaks = zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]) # note reverse 48 | 49 | peaks_with_score = [x + (map_ori[x[1],x[0]],) for x in peaks] 50 | id = range(peak_counter, peak_counter + len(peaks)) 51 | peaks_with_score_and_id = [peaks_with_score[i] + (id[i],) for i in range(len(id))] 52 | 53 | all_peaks.append(peaks_with_score_and_id) 54 | peak_counter += len(peaks) 55 | 56 | connection_all = [] 57 | special_k = [] 58 | mid_num = 10 59 | 60 | for k in range(len(mapIdx)): 61 | score_mid = paf_avg[:,:,[x-19 for x in mapIdx[k]]] 62 | candA = all_peaks[limbSeq[k][0]-1] 63 | candB = all_peaks[limbSeq[k][1]-1] 64 | nA = len(candA) 65 | nB = len(candB) 66 | indexA, indexB = limbSeq[k] 67 | if(nA != 0 and nB != 0): 68 | connection_candidate = [] 69 | for i in range(nA): 70 | for j in range(nB): 71 | vec = np.subtract(candB[j][:2], candA[i][:2]) 72 | norm = np.sqrt(vec[0]*vec[0] + vec[1]*vec[1]) 73 | vec = np.divide(vec, norm) 74 | 75 | startend = zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), 76 | np.linspace(candA[i][1], candB[j][1], num=mid_num)) 77 | 78 | vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] 79 | for I in range(len(startend))]) 80 | vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] 81 | for I in range(len(startend))]) 82 | 83 | score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) 84 | score_with_dist_prior = sum(score_midpts)/len(score_midpts) + min(0.5*oriImg.shape[0]/norm-1, 0) 85 | criterion1 = len(np.nonzero(score_midpts > th2)[0]) > 0.8 * len(score_midpts) 86 | criterion2 = score_with_dist_prior > 0 87 | if criterion1 and criterion2: 88 | connection_candidate.append([i, j, score_with_dist_prior, score_with_dist_prior+candA[i][2]+candB[j][2]]) 89 | 90 | connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) 91 | connection = np.zeros((0,5)) 92 | for c in range(len(connection_candidate)): 93 | i,j,s = connection_candidate[c][0:3] 94 | if(i not in connection[:,3] and j not in connection[:,4]): 95 | connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) 96 | if(len(connection) >= min(nA, nB)): 97 | break 98 | 99 | connection_all.append(connection) 100 | else: 101 | special_k.append(k) 102 | connection_all.append([]) 103 | 104 | # last number in each row is the total 
parts number of that person 105 | # the second last number in each row is the score of the overall configuration 106 | subset = -1 * np.ones((0, 20)) 107 | candidate = np.array([item for sublist in all_peaks for item in sublist]) 108 | 109 | for k in range(len(mapIdx)): 110 | if k not in special_k: 111 | partAs = connection_all[k][:,0] 112 | partBs = connection_all[k][:,1] 113 | indexA, indexB = np.array(limbSeq[k]) - 1 114 | 115 | for i in range(len(connection_all[k])): #= 1:size(temp,1) 116 | found = 0 117 | subset_idx = [-1, -1] 118 | for j in range(len(subset)): #1:size(subset,1): 119 | if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: 120 | subset_idx[found] = j 121 | found += 1 122 | 123 | if found == 1: 124 | j = subset_idx[0] 125 | if(subset[j][indexB] != partBs[i]): 126 | subset[j][indexB] = partBs[i] 127 | subset[j][-1] += 1 128 | subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] 129 | elif found == 2: # if found 2 and disjoint, merge them 130 | j1, j2 = subset_idx 131 | print "found = 2" 132 | # print("found = 2") 133 | membership = ((subset[j1]>=0).astype(int) + (subset[j2]>=0).astype(int))[:-2] 134 | if len(np.nonzero(membership == 2)[0]) == 0: #merge 135 | subset[j1][:-2] += (subset[j2][:-2] + 1) 136 | subset[j1][-2:] += subset[j2][-2:] 137 | subset[j1][-2] += connection_all[k][i][2] 138 | subset = np.delete(subset, j2, 0) 139 | else: # as like found == 1 140 | subset[j1][indexB] = partBs[i] 141 | subset[j1][-1] += 1 142 | subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] 143 | 144 | # if find no partA in the subset, create a new subset 145 | elif not found and k < 17: 146 | row = -1 * np.ones(20) 147 | row[indexA] = partAs[i] 148 | row[indexB] = partBs[i] 149 | row[-1] = 2 150 | row[-2] = sum(candidate[connection_all[k][i,:2].astype(int), 2]) + connection_all[k][i][2] 151 | subset = np.vstack([subset, row]) 152 | 153 | # delete some rows of subset which has few parts occur 154 | deleteIdx = []; 155 | for i in range(len(subset)): 156 | if subset[i][-1] < 4 or subset[i][-2]/subset[i][-1] < 0.4: 157 | deleteIdx.append(i) 158 | subset = np.delete(subset, deleteIdx, axis=0) 159 | 160 | if len(subset) == 0: 161 | return np.array([[-1, -1]] * 18).astype(int) 162 | 163 | cordinates = [] 164 | result_image_index = np.argmax(subset[:, -2]) 165 | 166 | for part in subset[result_image_index, :18]: 167 | if part == -1: 168 | cordinates.append([-1, -1]) 169 | else: 170 | Y = candidate[part.astype(int), 0] 171 | X = candidate[part.astype(int), 1] 172 | cordinates.append([X, Y]) 173 | return np.array(cordinates).astype(int) 174 | 175 | input_folder = './checkpoints/PoseTransfer_market/output' 176 | output_path = './checkpoints/PoseTransfer_market/output_pckh.csv' 177 | 178 | # input_folder = './checkpoints/PoseTransfer_deepfashion/output' 179 | # output_path = './checkpoints/PoseTransfer_deepfashion/output_pckh.csv' 180 | 181 | 182 | img_list = os.listdir(input_folder) 183 | 184 | threshold = 0.1 185 | boxsize = 368 186 | scale_search = [0.5, 1, 1.5, 2] 187 | 188 | if os.path.exists(output_path): 189 | processed_names = set(pd.read_csv(output_path, sep=':')['name']) 190 | result_file = open(output_path, 'a') 191 | else: 192 | result_file = open(output_path, 'w') 193 | processed_names = set() 194 | print >> result_file, 'name:keypoints_y:keypoints_x' 195 | 196 | # for image_name in tqdm(os.listdir(input_folder)): 197 | for image_name in tqdm(img_list): 198 | if image_name in processed_names: 199 | continue 200 | 
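    # Multi-scale inference: run the pose estimator at each scale in scale_search,
    # accumulate the heatmaps and part-affinity fields over all scales, average the
    # heatmaps, then decode them into 18 keypoint coordinates (missing joints are -1).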
201 | oriImg = imread(os.path.join(input_folder, image_name))[:, :, ::-1] # B,G,R order 202 | 203 | multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] 204 | 205 | heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) 206 | paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) 207 | 208 | for m in range(len(multiplier)): 209 | scale = multiplier[m] 210 | 211 | new_size = (np.array(oriImg.shape[:2]) * scale).astype(np.int32) 212 | imageToTest = resize(oriImg, new_size, order=3, preserve_range=True) 213 | imageToTest_padded = imageToTest[np.newaxis, :, :, :]/255 - 0.5 214 | 215 | output1, output2 = model.predict(imageToTest_padded) 216 | 217 | heatmap = st.resize(output2[0], oriImg.shape[:2], preserve_range=True, order=1) 218 | paf = st.resize(output1[0], oriImg.shape[:2], preserve_range=True, order=1) 219 | heatmap_avg += heatmap 220 | paf_avg += paf 221 | 222 | heatmap_avg /= len(multiplier) 223 | 224 | pose_cords = compute_cordinates(heatmap_avg, paf_avg) 225 | 226 | print >> result_file, "%s: %s: %s" % (image_name, str(list(pose_cords[:, 0])), str(list(pose_cords[:, 1]))) 227 | result_file.flush() 228 | -------------------------------------------------------------------------------- /tools/metrics_deepfashion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import numpy as np 5 | from imageio import imread 6 | from scipy import linalg 7 | from torch.nn.functional import adaptive_avg_pool2d 8 | import glob 9 | import argparse 10 | from PIL import Image 11 | import tqdm 12 | import lpips 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torchvision import models 16 | 17 | class InceptionV3(nn.Module): 18 | """Pretrained InceptionV3 network returning feature maps""" 19 | 20 | # Index of default block of inception to return, 21 | # corresponds to output of final average pooling 22 | DEFAULT_BLOCK_INDEX = 3 23 | 24 | # Maps feature dimensionality to their output blocks indices 25 | BLOCK_INDEX_BY_DIM = { 26 | 64: 0, # First max pooling features 27 | 192: 1, # Second max pooling featurs 28 | 768: 2, # Pre-aux classifier features 29 | 2048: 3 # Final average pooling features 30 | } 31 | 32 | def __init__(self, 33 | output_blocks=[DEFAULT_BLOCK_INDEX], 34 | resize_input=True, 35 | normalize_input=True, 36 | requires_grad=False): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, normalizes the input to the statistics the pretrained 53 | Inception network expects 54 | requires_grad : bool 55 | If true, parameters of the model require gradient. 
Possibly useful 56 | for finetuning the network 57 | """ 58 | super(InceptionV3, self).__init__() 59 | 60 | self.resize_input = resize_input 61 | self.normalize_input = normalize_input 62 | self.output_blocks = sorted(output_blocks) 63 | self.last_needed_block = max(output_blocks) 64 | 65 | assert self.last_needed_block <= 3, \ 66 | 'Last possible output block index is 3' 67 | 68 | self.blocks = nn.ModuleList() 69 | 70 | inception = models.inception_v3(pretrained=True) 71 | 72 | # Block 0: input to maxpool1 73 | block0 = [ 74 | inception.Conv2d_1a_3x3, 75 | inception.Conv2d_2a_3x3, 76 | inception.Conv2d_2b_3x3, 77 | nn.MaxPool2d(kernel_size=3, stride=2) 78 | ] 79 | self.blocks.append(nn.Sequential(*block0)) 80 | 81 | # Block 1: maxpool1 to maxpool2 82 | if self.last_needed_block >= 1: 83 | block1 = [ 84 | inception.Conv2d_3b_1x1, 85 | inception.Conv2d_4a_3x3, 86 | nn.MaxPool2d(kernel_size=3, stride=2) 87 | ] 88 | self.blocks.append(nn.Sequential(*block1)) 89 | 90 | # Block 2: maxpool2 to aux classifier 91 | if self.last_needed_block >= 2: 92 | block2 = [ 93 | inception.Mixed_5b, 94 | inception.Mixed_5c, 95 | inception.Mixed_5d, 96 | inception.Mixed_6a, 97 | inception.Mixed_6b, 98 | inception.Mixed_6c, 99 | inception.Mixed_6d, 100 | inception.Mixed_6e, 101 | ] 102 | self.blocks.append(nn.Sequential(*block2)) 103 | 104 | # Block 3: aux classifier to final avgpool 105 | if self.last_needed_block >= 3: 106 | block3 = [ 107 | inception.Mixed_7a, 108 | inception.Mixed_7b, 109 | inception.Mixed_7c, 110 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 111 | ] 112 | self.blocks.append(nn.Sequential(*block3)) 113 | 114 | for param in self.parameters(): 115 | param.requires_grad = requires_grad 116 | 117 | def forward(self, inp): 118 | """Get Inception feature maps 119 | Parameters 120 | ---------- 121 | inp : torch.autograd.Variable 122 | Input tensor of shape Bx3xHxW. Values are expected to be in 123 | range (0, 1) 124 | Returns 125 | ------- 126 | List of torch.autograd.Variable, corresponding to the selected output 127 | block, sorted ascending by index 128 | """ 129 | outp = [] 130 | x = inp 131 | 132 | if self.resize_input: 133 | x = F.upsample(x, size=(299, 299), mode='bilinear') 134 | 135 | if self.normalize_input: 136 | x = x.clone() 137 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 138 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 139 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 140 | 141 | for idx, block in enumerate(self.blocks): 142 | x = block(x) 143 | if idx in self.output_blocks: 144 | outp.append(x) 145 | 146 | if idx == self.last_needed_block: 147 | break 148 | 149 | return outp 150 | 151 | class FID(): 152 | """docstring for FID 153 | Calculates the Frechet Inception Distance (FID) to evalulate GANs 154 | The FID metric calculates the distance between two distributions of images. 155 | Typically, we have summary statistics (mean & covariance matrix) of one 156 | of these distributions, while the 2nd distribution is given by a GAN. 157 | When run as a stand-alone program, it compares the distribution of 158 | images that are stored as PNG/JPEG at a specified location with a 159 | distribution given by summary statistics (in pickle format). 160 | The FID is calculated by assuming that X_1 and X_2 are the activations of 161 | the pool_3 layer of the inception net for generated samples and real world 162 | samples respectivly. 163 | See --help to see further details. 
164 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 165 | of Tensorflow 166 | Copyright 2018 Institute of Bioinformatics, JKU Linz 167 | Licensed under the Apache License, Version 2.0 (the "License"); 168 | you may not use this file except in compliance with the License. 169 | You may obtain a copy of the License at 170 | http://www.apache.org/licenses/LICENSE-2.0 171 | Unless required by applicable law or agreed to in writing, software 172 | distributed under the License is distributed on an "AS IS" BASIS, 173 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 174 | See the License for the specific language governing permissions and 175 | limitations under the License. 176 | """ 177 | def __init__(self): 178 | self.dims = 2048 179 | self.batch_size = 64 180 | self.cuda = True 181 | self.verbose=False 182 | 183 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] 184 | self.model = InceptionV3([block_idx]) 185 | if self.cuda: 186 | # TODO: put model into specific GPU 187 | self.model.cuda() 188 | 189 | def __call__(self, images, gt_path): 190 | """ images: list of the generated image. The values must lie between 0 and 1. 191 | gt_path: the path of the ground truth images. The values must lie between 0 and 1. 192 | """ 193 | if not os.path.exists(gt_path): 194 | raise RuntimeError('Invalid path: %s' % gt_path) 195 | 196 | 197 | print('calculate gt_path statistics...') 198 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 199 | print('calculate generated_images statistics...') 200 | m2, s2 = self.calculate_activation_statistics(images, self.verbose) 201 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 202 | return fid_value 203 | 204 | 205 | def calculate_from_disk(self, generated_path, gt_path): 206 | """ 207 | """ 208 | if not os.path.exists(gt_path): 209 | raise RuntimeError('Invalid path: %s' % gt_path) 210 | if not os.path.exists(generated_path): 211 | raise RuntimeError('Invalid path: %s' % generated_path) 212 | 213 | print('calculate gt_path statistics...') 214 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 215 | print('calculate generated_path statistics...') 216 | m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose) 217 | print('calculate frechet distance...') 218 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 219 | print('fid_distance %f' % (fid_value)) 220 | return fid_value 221 | 222 | 223 | def compute_statistics_of_path(self, path, verbose): 224 | # npz_file = os.path.join(path, 'statistics.npz') 225 | # if os.path.exists(npz_file): 226 | # f = np.load(npz_file) 227 | # m, s = f['mu'][:], f['sigma'][:] 228 | # f.close() 229 | # else: 230 | m, s = self.calculate_activation_statistics(path, verbose) 231 | # np.savez(npz_file, mu=m, sigma=s) 232 | 233 | return m, s 234 | 235 | def calculate_activation_statistics(self, path, verbose): 236 | """Calculation of the statistics used by the FID. 237 | Params: 238 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 239 | must lie between 0 and 1. 240 | -- model : Instance of inception model 241 | -- batch_size : The images numpy array is split into batches with 242 | batch size batch_size. A reasonable batch size 243 | depends on the hardware. 244 | -- dims : Dimensionality of features returned by Inception 245 | -- cuda : If set to True, use GPU 246 | -- verbose : If set to True and parameter out_step is given, the 247 | number of calculated batches is reported. 
248 | Returns: 249 | -- mu : The mean over samples of the activations of the pool_3 layer of 250 | the inception model. 251 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 252 | the inception model. 253 | """ 254 | act = self.get_activations(path, verbose) 255 | mu = np.mean(act, axis=0) 256 | sigma = np.cov(act, rowvar=False) 257 | return mu, sigma 258 | 259 | 260 | 261 | def get_activations(self, path, verbose=False): 262 | """Calculates the activations of the pool_3 layer for all images. 263 | Params: 264 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 265 | must lie between 0 and 1. 266 | -- model : Instance of inception model 267 | -- batch_size : the images numpy array is split into batches with 268 | batch size batch_size. A reasonable batch size depends 269 | on the hardware. 270 | -- dims : Dimensionality of features returned by Inception 271 | -- cuda : If set to True, use GPU 272 | -- verbose : If set to True and parameter out_step is given, the number 273 | of calculated batches is reported. 274 | Returns: 275 | -- A numpy array of dimension (num images, dims) that contains the 276 | activations of the given tensor when feeding inception with the 277 | query tensor. 278 | """ 279 | self.model.eval() 280 | 281 | path = pathlib.Path(path) 282 | filenames = list(path.glob('*.jpg')) + list(path.glob('*.png')) 283 | # filenames = os.listdir(path) 284 | d0 = len(filenames) 285 | 286 | n_batches = d0 // self.batch_size 287 | n_used_imgs = n_batches * self.batch_size 288 | import tqdm 289 | pred_arr = np.empty((n_used_imgs, self.dims)) 290 | for i in tqdm.tqdm(range(n_batches)): 291 | 292 | start = i * self.batch_size 293 | end = start + self.batch_size 294 | 295 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in filenames[start:end]]) 296 | 297 | # Bring images to shape (B, 3, H, W) 298 | imgs = imgs.transpose((0, 3, 1, 2)) 299 | 300 | # Rescale images to be between 0 and 1 301 | imgs /= 255 302 | 303 | batch = torch.from_numpy(imgs).type(torch.FloatTensor) 304 | # batch = Variable(batch, volatile=True) 305 | if self.cuda: 306 | batch = batch.cuda() 307 | 308 | pred = self.model(batch)[0] 309 | 310 | # If model output is not scalar, apply global spatial average pooling. 311 | # This happens if you choose a dimensionality not equal 2048. 312 | if pred.shape[2] != 1 or pred.shape[3] != 1: 313 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 314 | 315 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1) 316 | 317 | if verbose: 318 | print(' done') 319 | 320 | return pred_arr 321 | 322 | 323 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 324 | """Numpy implementation of the Frechet Distance. 325 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 326 | and X_2 ~ N(mu_2, C_2) is 327 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 328 | Stable version by Dougal J. Sutherland. 329 | Params: 330 | -- mu1 : Numpy array containing the activations of a layer of the 331 | inception net (like returned by the function 'get_predictions') 332 | for generated samples. 333 | -- mu2 : The sample mean over activations, precalculated on an 334 | representive data set. 335 | -- sigma1: The covariance matrix over activations for generated samples. 336 | -- sigma2: The covariance matrix over activations, precalculated on an 337 | representive data set. 338 | Returns: 339 | -- : The Frechet Distance. 
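        Note: if sqrtm(sigma1.dot(sigma2)) is not finite, `eps` is added to the
        diagonal of both covariance matrices and the matrix square root is
        recomputed before taking the trace.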
340 | """ 341 | 342 | mu1 = np.atleast_1d(mu1) 343 | mu2 = np.atleast_1d(mu2) 344 | 345 | sigma1 = np.atleast_2d(sigma1) 346 | sigma2 = np.atleast_2d(sigma2) 347 | 348 | assert mu1.shape == mu2.shape, \ 349 | 'Training and test mean vectors have different lengths' 350 | assert sigma1.shape == sigma2.shape, \ 351 | 'Training and test covariances have different dimensions' 352 | 353 | diff = mu1 - mu2 354 | 355 | # Product might be almost singular 356 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 357 | if not np.isfinite(covmean).all(): 358 | msg = ('fid calculation produces singular product; ' 359 | 'adding %s to diagonal of cov estimates') % eps 360 | print(msg) 361 | offset = np.eye(sigma1.shape[0]) * eps 362 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 363 | 364 | # Numerical error might give slight imaginary component 365 | if np.iscomplexobj(covmean): 366 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 367 | m = np.max(np.abs(covmean.imag)) 368 | raise ValueError('Imaginary component {}'.format(m)) 369 | covmean = covmean.real 370 | 371 | tr_covmean = np.trace(covmean) 372 | 373 | return (diff.dot(diff) + np.trace(sigma1) + 374 | np.trace(sigma2) - 2 * tr_covmean) 375 | 376 | def get_image_list(flist): 377 | if isinstance(flist, list): 378 | return flist 379 | 380 | # flist: image file path, image directory path, text file flist path 381 | if isinstance(flist, str): 382 | if os.path.isdir(flist): 383 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 384 | flist.sort() 385 | return flist 386 | 387 | if os.path.isfile(flist): 388 | try: 389 | return np.genfromtxt(flist, dtype=np.str) 390 | except: 391 | return [flist] 392 | print('can not read files from %s return empty list'%flist) 393 | return [] 394 | 395 | def crop_img(path): 396 | bname = os.path.basename(path) 397 | save_dir = path.split(bname)[0]+bname+'_crop' 398 | if not os.path.exists(save_dir): 399 | os.mkdir(save_dir) 400 | for item in os.listdir(path): 401 | if not item.endswith('.jpg') and not item.endswith('.png'): 402 | continue 403 | img = Image.open(os.path.join(path, item)) 404 | imgcrop = img.crop((40, 0, 216, 256)) 405 | imgcrop.save(os.path.join(save_dir, item)) 406 | return save_dir 407 | 408 | class LPIPS(): 409 | def __init__(self, use_gpu=True): 410 | self.model =lpips.LPIPS(net='alex').cuda() 411 | self.use_gpu=use_gpu 412 | 413 | def __call__(self, image_1, image_2): 414 | """ 415 | image_1: images with size (n, 3, w, h) with value [-1, 1] 416 | image_2: images with size (n, 3, w, h) with value [-1, 1] 417 | """ 418 | result = self.model.forward(image_1, image_2) 419 | return result 420 | 421 | def calculate_from_disk(self, path_1, path_2, batch_size=64, verbose=False, sort=True): 422 | if sort: 423 | files_1 = sorted(get_image_list(path_1)) 424 | files_2 = sorted(get_image_list(path_2)) 425 | else: 426 | files_1 = get_image_list(path_1) 427 | files_2 = get_image_list(path_2) 428 | 429 | result=[] 430 | 431 | d0 = len(files_1) 432 | assert len(files_1) == len(files_2) 433 | if batch_size > d0: 434 | print(('Warning: batch size is bigger than the data size. 
' 435 | 'Setting batch size to data size')) 436 | batch_size = d0 437 | 438 | n_batches = d0 // batch_size 439 | n_used_imgs = n_batches * batch_size 440 | 441 | for i in tqdm.tqdm(range(n_batches)): 442 | if verbose: 443 | print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 444 | # end='', flush=True) 445 | start = i * batch_size 446 | end = start + batch_size 447 | imgs_1 = np.array([imread(str(fn)).astype(np.float32) / 127.5 - 1 for fn in files_1[start:end]]) 448 | imgs_2 = np.array([imread(str(fn)).astype(np.float32) / 127.5 - 1 for fn in files_2[start:end]]) 449 | 450 | # Bring images to shape (B, 3, H, W) 451 | imgs_1 = imgs_1.transpose((0, 3, 1, 2)) 452 | imgs_2 = imgs_2.transpose((0, 3, 1, 2)) 453 | img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor) 454 | img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor) 455 | 456 | if self.use_gpu: 457 | img_1_batch = img_1_batch.cuda() 458 | img_2_batch = img_2_batch.cuda() 459 | result.append(self.model.forward(img_1_batch, img_2_batch).detach().cpu().numpy()) 460 | distance = np.average(result) 461 | print('lpips: %.4f'%distance) 462 | return distance 463 | 464 | if __name__ == "__main__": 465 | print('load start') 466 | 467 | lpips = LPIPS() 468 | print('load LPIPS') 469 | 470 | parser = argparse.ArgumentParser(description='script to compute all statistics') 471 | parser.add_argument('--gt_path', help='Path to ground truth data', type=str) 472 | parser.add_argument('--distorated_path', help='Path to output data', type=str) 473 | parser.add_argument('--fid_real_path', help='Path to real images when calculate FID', type=str) 474 | args = parser.parse_args() 475 | 476 | for arg in vars(args): 477 | print('[%s] =' % arg, getattr(args, arg)) 478 | args.distorated_path = crop_img(args.distorated_path) 479 | 480 | fid = FID() 481 | print('load FID') 482 | 483 | print('calculate fid metric...') 484 | fid_score = fid.calculate_from_disk(args.distorated_path, args.fid_real_path) 485 | 486 | print('calculate lpips metric...') 487 | lpips_score = lpips.calculate_from_disk(args.distorated_path, args.gt_path, sort=False) 488 | 489 | -------------------------------------------------------------------------------- /tools/metrics_market.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import numpy as np 5 | from imageio import imread 6 | from scipy import linalg 7 | from torch.nn.functional import adaptive_avg_pool2d 8 | from skimage.morphology import dilation, erosion, square 9 | 10 | import glob 11 | import argparse 12 | import pandas as pd 13 | import json 14 | from skimage.draw import circle, polygon 15 | import tqdm 16 | 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torchvision import models 20 | import lpips 21 | from skimage.measure import compare_ssim 22 | 23 | 24 | class InceptionV3(nn.Module): 25 | """Pretrained InceptionV3 network returning feature maps""" 26 | 27 | # Index of default block of inception to return, 28 | # corresponds to output of final average pooling 29 | DEFAULT_BLOCK_INDEX = 3 30 | 31 | # Maps feature dimensionality to their output blocks indices 32 | BLOCK_INDEX_BY_DIM = { 33 | 64: 0, # First max pooling features 34 | 192: 1, # Second max pooling featurs 35 | 768: 2, # Pre-aux classifier features 36 | 2048: 3 # Final average pooling features 37 | } 38 | 39 | def __init__(self, 40 | output_blocks=[DEFAULT_BLOCK_INDEX], 41 | resize_input=True, 42 | normalize_input=True, 43 | 
requires_grad=False): 44 | """Build pretrained InceptionV3 45 | Parameters 46 | ---------- 47 | output_blocks : list of int 48 | Indices of blocks to return features of. Possible values are: 49 | - 0: corresponds to output of first max pooling 50 | - 1: corresponds to output of second max pooling 51 | - 2: corresponds to output which is fed to aux classifier 52 | - 3: corresponds to output of final average pooling 53 | resize_input : bool 54 | If true, bilinearly resizes input to width and height 299 before 55 | feeding input to model. As the network without fully connected 56 | layers is fully convolutional, it should be able to handle inputs 57 | of arbitrary size, so resizing might not be strictly needed 58 | normalize_input : bool 59 | If true, normalizes the input to the statistics the pretrained 60 | Inception network expects 61 | requires_grad : bool 62 | If true, parameters of the model require gradient. Possibly useful 63 | for finetuning the network 64 | """ 65 | super(InceptionV3, self).__init__() 66 | 67 | self.resize_input = resize_input 68 | self.normalize_input = normalize_input 69 | self.output_blocks = sorted(output_blocks) 70 | self.last_needed_block = max(output_blocks) 71 | 72 | assert self.last_needed_block <= 3, \ 73 | 'Last possible output block index is 3' 74 | 75 | self.blocks = nn.ModuleList() 76 | 77 | inception = models.inception_v3(pretrained=True) 78 | 79 | # Block 0: input to maxpool1 80 | block0 = [ 81 | inception.Conv2d_1a_3x3, 82 | inception.Conv2d_2a_3x3, 83 | inception.Conv2d_2b_3x3, 84 | nn.MaxPool2d(kernel_size=3, stride=2) 85 | ] 86 | self.blocks.append(nn.Sequential(*block0)) 87 | 88 | # Block 1: maxpool1 to maxpool2 89 | if self.last_needed_block >= 1: 90 | block1 = [ 91 | inception.Conv2d_3b_1x1, 92 | inception.Conv2d_4a_3x3, 93 | nn.MaxPool2d(kernel_size=3, stride=2) 94 | ] 95 | self.blocks.append(nn.Sequential(*block1)) 96 | 97 | # Block 2: maxpool2 to aux classifier 98 | if self.last_needed_block >= 2: 99 | block2 = [ 100 | inception.Mixed_5b, 101 | inception.Mixed_5c, 102 | inception.Mixed_5d, 103 | inception.Mixed_6a, 104 | inception.Mixed_6b, 105 | inception.Mixed_6c, 106 | inception.Mixed_6d, 107 | inception.Mixed_6e, 108 | ] 109 | self.blocks.append(nn.Sequential(*block2)) 110 | 111 | # Block 3: aux classifier to final avgpool 112 | if self.last_needed_block >= 3: 113 | block3 = [ 114 | inception.Mixed_7a, 115 | inception.Mixed_7b, 116 | inception.Mixed_7c, 117 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 118 | ] 119 | self.blocks.append(nn.Sequential(*block3)) 120 | 121 | for param in self.parameters(): 122 | param.requires_grad = requires_grad 123 | 124 | def forward(self, inp): 125 | """Get Inception feature maps 126 | Parameters 127 | ---------- 128 | inp : torch.autograd.Variable 129 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 130 | range (0, 1) 131 | Returns 132 | ------- 133 | List of torch.autograd.Variable, corresponding to the selected output 134 | block, sorted ascending by index 135 | """ 136 | outp = [] 137 | x = inp 138 | 139 | if self.resize_input: 140 | x = F.upsample(x, size=(299, 299), mode='bilinear') 141 | 142 | if self.normalize_input: 143 | x = x.clone() 144 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 145 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 146 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 147 | 148 | for idx, block in enumerate(self.blocks): 149 | x = block(x) 150 | if idx in self.output_blocks: 151 | outp.append(x) 152 | 153 | if idx == self.last_needed_block: 154 | break 155 | 156 | return outp 157 | 158 | class FID(): 159 | """docstring for FID 160 | Calculates the Frechet Inception Distance (FID) to evalulate GANs 161 | The FID metric calculates the distance between two distributions of images. 162 | Typically, we have summary statistics (mean & covariance matrix) of one 163 | of these distributions, while the 2nd distribution is given by a GAN. 164 | When run as a stand-alone program, it compares the distribution of 165 | images that are stored as PNG/JPEG at a specified location with a 166 | distribution given by summary statistics (in pickle format). 167 | The FID is calculated by assuming that X_1 and X_2 are the activations of 168 | the pool_3 layer of the inception net for generated samples and real world 169 | samples respectivly. 170 | See --help to see further details. 171 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 172 | of Tensorflow 173 | Copyright 2018 Institute of Bioinformatics, JKU Linz 174 | Licensed under the Apache License, Version 2.0 (the "License"); 175 | you may not use this file except in compliance with the License. 176 | You may obtain a copy of the License at 177 | http://www.apache.org/licenses/LICENSE-2.0 178 | Unless required by applicable law or agreed to in writing, software 179 | distributed under the License is distributed on an "AS IS" BASIS, 180 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 181 | See the License for the specific language governing permissions and 182 | limitations under the License. 183 | """ 184 | def __init__(self): 185 | self.dims = 2048 186 | self.batch_size = 64 187 | self.cuda = True 188 | self.verbose=False 189 | 190 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] 191 | self.model = InceptionV3([block_idx]) 192 | if self.cuda: 193 | # TODO: put model into specific GPU 194 | self.model.cuda() 195 | 196 | def __call__(self, images, gt_path): 197 | """ images: list of the generated image. The values must lie between 0 and 1. 198 | gt_path: the path of the ground truth images. The values must lie between 0 and 1. 
199 | """ 200 | if not os.path.exists(gt_path): 201 | raise RuntimeError('Invalid path: %s' % gt_path) 202 | 203 | 204 | print('calculate gt_path statistics...') 205 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 206 | print('calculate generated_images statistics...') 207 | m2, s2 = self.calculate_activation_statistics(images, self.verbose) 208 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 209 | return fid_value 210 | 211 | 212 | def calculate_from_disk(self, generated_path, gt_path): 213 | """ 214 | """ 215 | if not os.path.exists(gt_path): 216 | raise RuntimeError('Invalid path: %s' % gt_path) 217 | if not os.path.exists(generated_path): 218 | raise RuntimeError('Invalid path: %s' % generated_path) 219 | 220 | print('calculate gt_path statistics...') 221 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 222 | print('calculate generated_path statistics...') 223 | m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose) 224 | print('calculate frechet distance...') 225 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 226 | print('fid_distance %f' % (fid_value)) 227 | return fid_value 228 | 229 | 230 | def compute_statistics_of_path(self, path, verbose): 231 | npz_file = os.path.join(path, 'statistics.npz') 232 | if os.path.exists(npz_file): 233 | f = np.load(npz_file) 234 | m, s = f['mu'][:], f['sigma'][:] 235 | f.close() 236 | else: 237 | path = pathlib.Path(path) 238 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 239 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 240 | 241 | # Bring images to shape (B, 3, H, W) 242 | imgs = imgs.transpose((0, 3, 1, 2)) 243 | 244 | # Rescale images to be between 0 and 1 245 | imgs /= 255 246 | 247 | m, s = self.calculate_activation_statistics(path, verbose) 248 | np.savez(npz_file, mu=m, sigma=s) 249 | 250 | return m, s 251 | 252 | def calculate_activation_statistics(self, path, verbose): 253 | """Calculation of the statistics used by the FID. 254 | Params: 255 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 256 | must lie between 0 and 1. 257 | -- model : Instance of inception model 258 | -- batch_size : The images numpy array is split into batches with 259 | batch size batch_size. A reasonable batch size 260 | depends on the hardware. 261 | -- dims : Dimensionality of features returned by Inception 262 | -- cuda : If set to True, use GPU 263 | -- verbose : If set to True and parameter out_step is given, the 264 | number of calculated batches is reported. 265 | Returns: 266 | -- mu : The mean over samples of the activations of the pool_3 layer of 267 | the inception model. 268 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 269 | the inception model. 270 | """ 271 | act = self.get_activations(path, verbose) 272 | mu = np.mean(act, axis=0) 273 | sigma = np.cov(act, rowvar=False) 274 | return mu, sigma 275 | 276 | 277 | 278 | def get_activations(self, path, verbose=False): 279 | """Calculates the activations of the pool_3 layer for all images. 280 | Params: 281 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 282 | must lie between 0 and 1. 283 | -- model : Instance of inception model 284 | -- batch_size : the images numpy array is split into batches with 285 | batch size batch_size. A reasonable batch size depends 286 | on the hardware. 
287 | -- dims : Dimensionality of features returned by Inception 288 | -- cuda : If set to True, use GPU 289 | -- verbose : If set to True and parameter out_step is given, the number 290 | of calculated batches is reported. 291 | Returns: 292 | -- A numpy array of dimension (num images, dims) that contains the 293 | activations of the given tensor when feeding inception with the 294 | query tensor. 295 | """ 296 | self.model.eval() 297 | 298 | path = pathlib.Path(path) 299 | filenames = list(path.glob('*.jpg')) + list(path.glob('*.png')) 300 | # filenames = os.listdir(path) 301 | d0 = len(filenames) 302 | 303 | n_batches = d0 // self.batch_size 304 | n_used_imgs = n_batches * self.batch_size 305 | import tqdm 306 | pred_arr = np.empty((n_used_imgs, self.dims)) 307 | for i in tqdm.tqdm(range(n_batches)): 308 | # if verbose: 309 | # print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 310 | # end='', flush=True) 311 | start = i * self.batch_size 312 | end = start + self.batch_size 313 | 314 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in filenames[start:end]]) 315 | 316 | # Bring images to shape (B, 3, H, W) 317 | imgs = imgs.transpose((0, 3, 1, 2)) 318 | 319 | # Rescale images to be between 0 and 1 320 | imgs /= 255 321 | 322 | batch = torch.from_numpy(imgs).type(torch.FloatTensor) 323 | # batch = Variable(batch, volatile=True) 324 | if self.cuda: 325 | batch = batch.cuda() 326 | 327 | pred = self.model(batch)[0] 328 | 329 | # If model output is not scalar, apply global spatial average pooling. 330 | # This happens if you choose a dimensionality not equal 2048. 331 | if pred.shape[2] != 1 or pred.shape[3] != 1: 332 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 333 | 334 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1) 335 | 336 | if verbose: 337 | print(' done') 338 | 339 | return pred_arr 340 | 341 | 342 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 343 | """Numpy implementation of the Frechet Distance. 344 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 345 | and X_2 ~ N(mu_2, C_2) is 346 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 347 | Stable version by Dougal J. Sutherland. 348 | Params: 349 | -- mu1 : Numpy array containing the activations of a layer of the 350 | inception net (like returned by the function 'get_predictions') 351 | for generated samples. 352 | -- mu2 : The sample mean over activations, precalculated on an 353 | representive data set. 354 | -- sigma1: The covariance matrix over activations for generated samples. 355 | -- sigma2: The covariance matrix over activations, precalculated on an 356 | representive data set. 357 | Returns: 358 | -- : The Frechet Distance. 
359 | """ 360 | 361 | mu1 = np.atleast_1d(mu1) 362 | mu2 = np.atleast_1d(mu2) 363 | 364 | sigma1 = np.atleast_2d(sigma1) 365 | sigma2 = np.atleast_2d(sigma2) 366 | 367 | assert mu1.shape == mu2.shape, \ 368 | 'Training and test mean vectors have different lengths' 369 | assert sigma1.shape == sigma2.shape, \ 370 | 'Training and test covariances have different dimensions' 371 | 372 | diff = mu1 - mu2 373 | 374 | # Product might be almost singular 375 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 376 | if not np.isfinite(covmean).all(): 377 | msg = ('fid calculation produces singular product; ' 378 | 'adding %s to diagonal of cov estimates') % eps 379 | print(msg) 380 | offset = np.eye(sigma1.shape[0]) * eps 381 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 382 | 383 | # Numerical error might give slight imaginary component 384 | if np.iscomplexobj(covmean): 385 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 386 | m = np.max(np.abs(covmean.imag)) 387 | raise ValueError('Imaginary component {}'.format(m)) 388 | covmean = covmean.real 389 | 390 | tr_covmean = np.trace(covmean) 391 | 392 | return (diff.dot(diff) + np.trace(sigma1) + 393 | np.trace(sigma2) - 2 * tr_covmean) 394 | 395 | def get_image_list(flist): 396 | if isinstance(flist, list): 397 | return flist 398 | 399 | # flist: image file path, image directory path, text file flist path 400 | if isinstance(flist, str): 401 | if os.path.isdir(flist): 402 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 403 | flist.sort() 404 | return flist 405 | 406 | if os.path.isfile(flist): 407 | try: 408 | return np.genfromtxt(flist, dtype=np.str) 409 | except: 410 | return [flist] 411 | print('can not read files from %s return empty list'%flist) 412 | return [] 413 | 414 | class LPIPS(): 415 | def __init__(self, use_gpu=True): 416 | self.model = lpips.LPIPS(net='alex').cuda() 417 | self.use_gpu=use_gpu 418 | 419 | def __call__(self, image_1, image_2): 420 | """ 421 | image_1: images with size (n, 3, w, h) with value [-1, 1] 422 | image_2: images with size (n, 3, w, h) with value [-1, 1] 423 | """ 424 | result = self.model.forward(image_1, image_2) 425 | return result 426 | 427 | def calculate_from_disk(self, files_1, files_2, batch_size=64, verbose=False): 428 | 429 | result=[] 430 | 431 | d0 = len(files_1) 432 | assert len(files_1) == len(files_2) 433 | if batch_size > d0: 434 | print(('Warning: batch size is bigger than the data size. 
' 435 | 'Setting batch size to data size')) 436 | batch_size = d0 437 | 438 | n_batches = d0 // batch_size 439 | 440 | for i in tqdm.tqdm(range(n_batches)): 441 | if verbose: 442 | print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 443 | start = i * batch_size 444 | end = start + batch_size 445 | imgs_1 = np.array(files_1[start:end]) 446 | imgs_2 = np.array(files_2[start:end]) 447 | # Bring images to shape (B, 3, H, W) 448 | imgs_1 = imgs_1.transpose((0, 3, 1, 2)) 449 | imgs_2 = imgs_2.transpose((0, 3, 1, 2)) 450 | img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor) 451 | img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor) 452 | 453 | if self.use_gpu: 454 | img_1_batch = img_1_batch.cuda() 455 | img_2_batch = img_2_batch.cuda() 456 | 457 | result.append(self.model.forward(img_1_batch, img_2_batch).detach().cpu().numpy()) 458 | 459 | distance = np.average(result) 460 | print('lpips: %.4f'%distance) 461 | return distance 462 | 463 | def calculate_mask_lpips(self, files_1, files_2, batch_size=64, verbose=False): 464 | result=[] 465 | d0 = len(files_1) 466 | if batch_size > d0: 467 | print(('Warning: batch size is bigger than the data size. ' 468 | 'Setting batch size to data size')) 469 | batch_size = d0 470 | 471 | n_batches = d0 // batch_size 472 | for i in tqdm.tqdm(range(n_batches)): 473 | if verbose: 474 | print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 475 | start = i * batch_size 476 | end = start + batch_size 477 | imgs_1 = np.array(files_1[start:end]) 478 | imgs_2 = np.array(files_2[start:end]) 479 | # Bring images to shape (B, 3, H, W) 480 | imgs_1 = imgs_1.transpose((0, 3, 1, 2)) 481 | imgs_2 = imgs_2.transpose((0, 3, 1, 2)) 482 | 483 | img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor) 484 | img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor) 485 | 486 | if self.use_gpu: 487 | img_1_batch = img_1_batch.cuda() 488 | img_2_batch = img_2_batch.cuda() 489 | 490 | result.append(self.model.forward(img_1_batch, img_2_batch).detach().cpu().numpy()) 491 | 492 | distance = np.average(result) 493 | print('masked lpips: %.4f'%distance) 494 | return distance 495 | 496 | def produce_ma_mask(kp_array, img_size=(128, 64), point_radius=4): 497 | MISSING_VALUE = -1 498 | mask = np.zeros(shape=img_size, dtype=bool) 499 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], 500 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], 501 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]] 502 | limbs = np.array(limbs) - 1 503 | for f, t in limbs: 504 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE 505 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE 506 | if from_missing or to_missing: 507 | continue 508 | 509 | norm_vec = kp_array[f] - kp_array[t] 510 | norm_vec = np.array([-norm_vec[1], norm_vec[0]]) 511 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec) 512 | vetexes = np.array([ 513 | kp_array[f] + norm_vec, 514 | kp_array[f] - norm_vec, 515 | kp_array[t] - norm_vec, 516 | kp_array[t] + norm_vec 517 | ]) 518 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size) 519 | mask[yy, xx] = True 520 | 521 | for i, joint in enumerate(kp_array): 522 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE: 523 | continue 524 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size) 525 | mask[yy, xx] = True 526 | 527 | mask = dilation(mask, square(5)) 528 | mask = erosion(mask, square(5)) 529 | 
return mask 530 | 531 | def load_pose_cords_from_strings(y_str, x_str): 532 | y_cords = json.loads(y_str) 533 | x_cords = json.loads(x_str) 534 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 535 | 536 | def masked_ssim_score(generated_images, reference_images): 537 | ssim_score_list = [] 538 | for reference_image, generated_image in tqdm.tqdm(zip(reference_images, generated_images)): 539 | ssim = compare_ssim(reference_image, generated_image,gaussian_weights=True, sigma=1.5, 540 | use_sample_covariance=False, multichannel=True, 541 | data_range=generated_image.max() - generated_image.min()) 542 | ssim_score_list.append(ssim) 543 | print ("masked SSIM %.3f" % np.mean(ssim_score_list)) 544 | return np.mean(ssim_score_list) 545 | 546 | def load_generated_images(generated, gt, annotation_file): 547 | target_images,generated_images,target_images_m,generated_images_m, stm, sgm = [],[],[],[],[],[] 548 | df = pd.read_csv(annotation_file, sep=':') 549 | for file in tqdm.tqdm(os.listdir(generated)): 550 | if not file.endswith('.jpg'): 551 | continue 552 | gntimg = imread(os.path.join(generated,file)) 553 | gtimg= imread(os.path.join(gt,file)) 554 | fs = file.split('_') 555 | name=fs[4]+'_'+fs[5]+'_'+fs[6]+'_'+fs[7] 556 | ano_to = df[df['name'] == name].iloc[0] 557 | kp_to = load_pose_cords_from_strings(ano_to['keypoints_y'], ano_to['keypoints_x']) 558 | 559 | mask = produce_ma_mask(kp_to, (128,64)) 560 | 561 | stm.append(gtimg * mask[..., np.newaxis]) 562 | sgm.append(gntimg * mask[..., np.newaxis]) 563 | generated_images.append(gntimg.astype(np.float32) / 127.5 - 1) 564 | target_images.append(gtimg.astype(np.float32) / 127.5 - 1) 565 | generated_images_m.append((gntimg.astype(np.float32) / 127.5 - 1) * mask[..., np.newaxis]) 566 | target_images_m.append((gtimg.astype(np.float32) / 127.5 - 1) * mask[..., np.newaxis]) 567 | 568 | return generated_images,generated_images_m,target_images,target_images_m, stm, sgm 569 | 570 | 571 | if __name__ == "__main__": 572 | print('load start') 573 | 574 | parser = argparse.ArgumentParser(description='script to compute all statistics') 575 | parser.add_argument('--gt_path', help='Path to ground truth data', type=str) 576 | parser.add_argument('--distorated_path', help='Path to output data', type=str) 577 | parser.add_argument('--fid_real_path', help='Path to real images when calculate FID', type=str) 578 | parser.add_argument('--bonesLst',default='datasets/market1501/label/market-annotation-test.csv',help='Path to annotation',type=str) 579 | args = parser.parse_args() 580 | 581 | for arg in vars(args): 582 | print('[%s] =' % arg, getattr(args, arg)) 583 | 584 | lpips = LPIPS() 585 | print('load LPIPS') 586 | 587 | fid = FID() 588 | print('load FID') 589 | 590 | print('calculate fid metric...') 591 | fid_score = fid.calculate_from_disk(args.distorated_path, args.fid_real_path) 592 | 593 | print('load imgs...') 594 | generated_images,generated_images_m,target_images,target_images_m, stm, sgm = load_generated_images(args.distorated_path,args.gt_path, args.bonesLst) 595 | 596 | print('calculate lpips metric...') 597 | lpips_score = lpips.calculate_from_disk(generated_images, target_images) 598 | 599 | print('calculate masked lpips and SSIM metric...') 600 | lpips_masked = lpips.calculate_mask_lpips(generated_images_m,target_images_m) 601 | structured_masked = masked_ssim_score(sgm, stm) 602 | -------------------------------------------------------------------------------- /tools/pose_utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.draw import circle, line_aa, polygon 3 | import json 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import matplotlib.patches as mpatches 9 | 10 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9], 11 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16], 12 | [0,15], [15,17], [2,16], [5,17]] 13 | 14 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 15 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 16 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 17 | 18 | 19 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 20 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 21 | 22 | MISSING_VALUE = -1 23 | 24 | 25 | def map_to_cord(pose_map, threshold=0.1): 26 | all_peaks = [[] for i in range(18)] 27 | pose_map = pose_map[..., :18] 28 | 29 | y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis = (0, 1)), 30 | pose_map > threshold)) 31 | for x_i, y_i, z_i in zip(x, y, z): 32 | all_peaks[z_i].append([x_i, y_i]) 33 | 34 | x_values = [] 35 | y_values = [] 36 | 37 | for i in range(18): 38 | if len(all_peaks[i]) != 0: 39 | x_values.append(all_peaks[i][0][0]) 40 | y_values.append(all_peaks[i][0][1]) 41 | else: 42 | x_values.append(MISSING_VALUE) 43 | y_values.append(MISSING_VALUE) 44 | 45 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1) 46 | 47 | 48 | def cords_to_map(cords, img_size, sigma=6): 49 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') 50 | for i, point in enumerate(cords): 51 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 52 | continue 53 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 54 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 55 | return result 56 | 57 | 58 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True): 59 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8) 60 | mask = np.zeros(shape=img_size, dtype=bool) 61 | 62 | if draw_joints: 63 | for f, t in LIMB_SEQ: 64 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE 65 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE 66 | if from_missing or to_missing: 67 | continue 68 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1]) 69 | colors[yy, xx] = np.expand_dims(val, 1) * 255 70 | mask[yy, xx] = True 71 | 72 | for i, joint in enumerate(pose_joints): 73 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE: 74 | continue 75 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size) 76 | colors[yy, xx] = COLORS[i] 77 | mask[yy, xx] = True 78 | 79 | return colors, mask 80 | 81 | 82 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs): 83 | cords = map_to_cord(pose_map, threshold=threshold) 84 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs) 85 | 86 | 87 | def load_pose_cords_from_strings(y_str, x_str): 88 | y_cords = json.loads(y_str) 89 | x_cords = json.loads(x_str) 90 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 91 | 92 | def mean_inputation(X): 93 | X = X.copy() 
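    # Impute missing joints (marked with -1) using the mean of the observed
    # values at the same joint/coordinate position across all samples.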
94 | for i in range(X.shape[1]): 95 | for j in range(X.shape[2]): 96 | val = np.mean(X[:, i, j][X[:, i, j] != -1]) 97 | X[:, i, j][X[:, i, j] == -1] = val 98 | return X 99 | 100 | def draw_legend(): 101 | handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)] 102 | plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 103 | 104 | def produce_ma_mask(kp_array, img_size, point_radius=4): 105 | from skimage.morphology import dilation, erosion, square 106 | mask = np.zeros(shape=img_size, dtype=bool) 107 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], 108 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], 109 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]] 110 | limbs = np.array(limbs) - 1 111 | for f, t in limbs: 112 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE 113 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE 114 | if from_missing or to_missing: 115 | continue 116 | 117 | norm_vec = kp_array[f] - kp_array[t] 118 | norm_vec = np.array([-norm_vec[1], norm_vec[0]]) 119 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec) 120 | 121 | 122 | vetexes = np.array([ 123 | kp_array[f] + norm_vec, 124 | kp_array[f] - norm_vec, 125 | kp_array[t] - norm_vec, 126 | kp_array[t] + norm_vec 127 | ]) 128 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size) 129 | mask[yy, xx] = True 130 | 131 | for i, joint in enumerate(kp_array): 132 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE: 133 | continue 134 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size) 135 | mask[yy, xx] = True 136 | 137 | mask = dilation(mask, square(5)) 138 | mask = erosion(mask, square(5)) 139 | return mask 140 | 141 | if __name__ == "__main__": 142 | import pandas as pd 143 | from skimage.io import imread 144 | import pylab as plt 145 | import os 146 | i = 5 147 | df = pd.read_csv('data/market-annotation-train.csv', sep=':') 148 | 149 | for index, row in df.iterrows(): 150 | pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x']) 151 | 152 | colors, mask = draw_pose_from_cords(pose_cords, (128, 64)) 153 | 154 | mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1) 155 | # print mmm.shape 156 | print(mmm.shape) 157 | img = imread('data/market-dataset/train/' + row['name']) 158 | 159 | mmm[mask] = colors[mask] 160 | 161 | # print (mmm) 162 | print(mmm) 163 | plt.subplot(1, 1, 1) 164 | plt.imshow(mmm) 165 | plt.show() 166 | -------------------------------------------------------------------------------- /train_deepfashion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/train_pose_transfer_model.py \ 4 | --id deepfashion \ 5 | --gpu_ids 0,1 \ 6 | --dataset_name deepfashion \ 7 | --which_model_G dual_unet \ 8 | --G_feat_warp 1 \ 9 | --G_vis_mode residual \ 10 | --pretrained_flow_id FlowReg_deepfashion \ 11 | --pretrained_flow_epoch best \ 12 | --dataset_type pose_transfer_parsing \ 13 | --check_grad_freq 3000 \ 14 | --batch_size 4 \ 15 | --n_epoch 45 -------------------------------------------------------------------------------- /train_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/train_pose_transfer_model.py \ 4 | --id market_test \ 5 | --gpu_ids 
0,1,2,3 \ 6 | --dataset_name market \ 7 | --which_model_G dual_unet \ 8 | --G_feat_warp 1 \ 9 | --G_vis_mode residual \ 10 | --pretrained_flow_id FlowReg_market \ 11 | --pretrained_flow_epoch best \ 12 | --dataset_type pose_transfer_parsing_market \ 13 | --check_grad_freq 3000 \ 14 | --batch_size 32 \ 15 | --n_epoch 10 -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__init__.py -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/flow_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/flow_util.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/io.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/loss_buffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/loss_buffer.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/pose_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/pose_util.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/SPGNet/3c1f6b5e290e7339ec01181403f23edf1e87eb15/util/__pycache__/visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /util/flow_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Derived from flownet2.0 3 | ''' 4 | import torch 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | def readFlow(fn): 10 | """ 11 | Derived from flownet2.0 12 | """ 13 | f = open(fn, 'rb') 14 | 15 | header = f.read(4) 16 | if header.decode("utf-8") != 'PIEH': 17 | raise Exception('Flow file header does not contain PIEH') 18 | 19 | width = np.fromfile(f, np.int32, 1).squeeze() 20 | height = np.fromfile(f, np.int32, 1).squeeze() 21 | 22 | flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) 23 | 24 | return flow.astype(np.float32) 25 | 26 | def writeFlow(fn, flow): 27 | """ 28 | Derived from flownet2.0 29 | """ 30 | f = open(fn, 'wb') 31 | f.write('PIEH'.encode('utf-8')) 32 | np.array([flow.shape[1], 
flow.shape[0]], dtype=np.int32).tofile(f) 33 | flow = flow.astype(np.float32) 34 | flow.tofile(f) 35 | f.flush() 36 | f.close() 37 | 38 | 39 | def write_corr(fn, corr, mask): 40 | ''' 41 | Save correspondence map (float data) and mask (uint data) to one file. 42 | Input: 43 | fn: file name 44 | corr: (H, W, 2), float32 45 | mask: (H, W), uint8 46 | ''' 47 | assert corr.shape[:2]==mask.shape 48 | with open(fn, 'wb') as f: 49 | np.array([corr.shape[1], corr.shape[0]], dtype=np.int32).tofile(f) 50 | corr.astype(np.float32).tofile(f) 51 | mask.astype(np.uint8).tofile(f) 52 | 53 | def read_corr(fn): 54 | ''' 55 | Recover correspondence map and mask saved by "write_corr" 56 | ''' 57 | with open(fn, 'rb') as f: 58 | width = np.fromfile(f, np.int32, 1).squeeze() 59 | height = np.fromfile(f, np.int32, 1).squeeze() 60 | corr = np.fromfile(f, np.float32, width*height*2).reshape((height, width, 2)) 61 | mask = np.fromfile(f, np.uint8, width*height).reshape((height, width)) 62 | return corr, mask 63 | 64 | def visualize_corr(img_1, img_2, corr_1to2, mask_1=None, grid_step=5): 65 | ''' 66 | Input: 67 | img_1: (h1, w1, 3) 68 | img_2: (h2, w2, 3) 69 | corr_1to2: (h1, w1, 2) 70 | grid_step: scalar 71 | Output: 72 | img_out: (max(h1, h2), w1+w2, 3) 73 | ''' 74 | 75 | h1, w1 = img_1.shape[0:2] 76 | h2, w2 = img_2.shape[0:2] 77 | img_out = np.zeros((max(h1, h2), w1+w2, 3), dtype=img_1.dtype) 78 | img_out[:h1,:w1,:] = img_1 79 | img_out[:h2,w1:(w1+w2),:] = img_2 80 | 81 | mask = ((corr_1to2[...,0]>1) & (corr_1to2[...,0]<w2-1) & (corr_1to2[...,1]>1) & (corr_1to2[...,1]<h2-1)) 82 | if mask_1 is not None: 83 | mask = mask & (mask_1>0) 84 | 85 | pt_y1, pt_x1 = np.where(mask) 86 | if grid_step > 1: 87 | pt_v = (pt_x1%grid_step==0) & (pt_y1%grid_step==0) 88 | pt_x1 = pt_x1[pt_v] 89 | pt_y1 = pt_y1[pt_v] 90 | 91 | pt_x2 = corr_1to2[pt_y1,pt_x1,0] + w1 92 | pt_y2 = corr_1to2[pt_y1,pt_x1,1] 93 | pt_color = points2color(np.stack([pt_x1, pt_y1], axis=1)) 94 | 95 | for x1, y1, x2, y2, c in zip(pt_x1, pt_y1, pt_x2, pt_y2, pt_color): 96 | c = c.tolist() 97 | cv2.arrowedLine(img_out, (x1, y1), (x2, y2), c, line_type=cv2.LINE_AA, tipLength=0.02) 98 | 99 | return img_out 100 | 101 | 102 | def points2color(points, method='Lab'): 103 | ''' 104 | points: (N, 2) point coordinates 105 | method: {'Lab'} 106 | ''' 107 | if method == 'Lab': 108 | range_x = points[:,0].max() - points[:,0].min() 109 | range_y = points[:,1].max() - points[:,1].min() 110 | L = np.ones(points.shape[0]) * 255 111 | A = points[:,0]*255.0/(range_x+0.1) 112 | B = points[:,1]*255.0/(range_y+0.1) 113 | C = np.stack([L,A,B], axis=1).astype(np.uint8) 114 | C = cv2.cvtColor(C.reshape(1,-1,3), cv2.COLOR_LAB2BGR).reshape(-1,3) 115 | return C 116 | else: 117 | raise NotImplementedError() 118 | 119 | def warp_image(img, flow): 120 | h, w = flow.shape[:2] 121 | m = flow.astype(np.float32) 122 | m[:,:,0] += np.arange(w) 123 | m[:,:,1] += np.arange(h)[:,np.newaxis] 124 | res = cv2.remap(img, m, None, cv2.INTER_LINEAR, cv2.BORDER_REPLICATE) 125 | return res 126 | 127 | 128 | def flow_to_rgb(flow): 129 | h, w = flow.shape[:2] 130 | hsv = np.zeros((h, w, 3)) 131 | hsv[..., 1] = 255 132 | 133 | mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) 134 | hsv[..., 0] = ang*180/np.pi/2 135 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 136 | hsv = hsv.astype(np.uint8) 137 | rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 138 | 139 | return rgb 140 | 141 | 142 | def corr_to_flow(corr, vis=None, order='NCHW'): 143 | ''' 144 | order should be one of {'NCHW', 'HWC'} 145 | ''' 146 | if order == 'NCHW': 147 | if isinstance(corr, torch.Tensor): 148 | flow = corr.clone() 149 | flow[:,0,:,:] -= torch.arange(flow.shape[3], 
dtype=flow.dtype, device=flow.device) # x-axis 150 | flow[:,1,:,:] -= torch.arange(flow.shape[2], dtype=flow.dtype, device=flow.device).view(-1,1) # y-axis 151 | elif isinstance(corr, np.ndarray): 152 | flow = corr.copy() 153 | flow[:,0,:,:] -= np.arange(flow.shape[3]) 154 | flow[:,1,:,:] -= np.arange(flow.shape[2]).reshape(-1,1) 155 | elif order == 'HWC': 156 | if isinstance(corr, torch.Tensor): 157 | flow = corr.clone() 158 | flow[:,:,0] -= torch.arange(flow.shape[1], dtype=flow.dtype, device=flow.device) 159 | flow[:,:,1] -= torch.arange(flow.shape[0], dtype=flow.dtype, device=flow.device).view(-1,1) 160 | elif isinstance(corr, np.ndarray): 161 | flow = corr.copy() 162 | flow[:,:,0] -= np.arange(flow.shape[1]).reshape(-1,) 163 | flow[:,:,1] -= np.arange(flow.shape[0]).reshape(-1,1) 164 | if vis is not None: 165 | if isinstance(vis, torch.Tensor): 166 | vis = (vis<2).float() 167 | elif isinstance(vis, np.ndarray): 168 | vis = (vis<2).astype(np.float32) 169 | flow *= vis 170 | return flow 171 | 172 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | 6 | class ImagePool(): 7 | def __init__(self, pool_size): 8 | self.pool_size = pool_size 9 | if self.pool_size > 0: 10 | self.num_imgs = 0 11 | self.images = [] 12 | 13 | def query(self, images): 14 | if self.pool_size == 0: 15 | return Variable(images) 16 | return_images = [] 17 | for image in images: 18 | image = torch.unsqueeze(image, 0) 19 | if self.num_imgs < self.pool_size: 20 | self.num_imgs = self.num_imgs + 1 21 | self.images.append(image) 22 | return_images.append(image) 23 | else: 24 | p = random.uniform(0, 1) 25 | if p > 0.5: 26 | random_id = random.randint(0, self.pool_size-1) 27 | tmp = self.images[random_id].clone() 28 | self.images[random_id] = image 29 | return_images.append(tmp) 30 | else: 31 | return_images.append(image) 32 | return_images = Variable(torch.cat(return_images, 0)) 33 | return return_images -------------------------------------------------------------------------------- /util/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import os 4 | import shutil 5 | 6 | #io functions of SCRC 7 | def load_str_list(filename, end = '\n'): 8 | with open(filename, 'r') as f: 9 | str_list = f.readlines() 10 | str_list = [s[:-len(end)] for s in str_list] 11 | return str_list 12 | 13 | def save_str_list(str_list, filename, end = '\n'): 14 | str_list = [s+end for s in str_list] 15 | with open(filename, 'w') as f: 16 | f.writelines(str_list) 17 | 18 | def load_json(filename): 19 | with open(filename, 'r') as f: 20 | return json.load(f) 21 | 22 | def save_json(json_obj, filename): 23 | with open(filename, 'w') as f: 24 | # json.dump(json_obj, f, separators=(',\n', ':\n')) 25 | json.dump(json_obj, f, indent = 0, separators = (',', ': ')) 26 | 27 | def mkdir_if_missing(output_dir): 28 | """ 29 | def mkdir_if_missing(output_dir) 30 | """ 31 | if not os.path.exists(output_dir): 32 | os.makedirs(output_dir) 33 | 34 | def save_data(data, filename): 35 | with open(filename, 'wb') as f: 36 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 37 | 38 | def load_data(filename): 39 | with open(filename, 'rb') as f: 40 | data = pickle.load(f,encoding='iso-8859-1') 41 | return data 42 | 43 | def load_data_json(filename): 44 | with open(filename, 'rb') as f: 45 | 
data = json.load(f) 46 | return data 47 | 48 | def copy(fn_src, fn_tar): 49 | shutil.copyfile(fn_src, fn_tar) 50 | 51 | -------------------------------------------------------------------------------- /util/loss_buffer.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | class LossBuffer(): 7 | 8 | def __init__(self, size=1000): 9 | self.size = size 10 | self.buffer = OrderedDict() 11 | 12 | def clear(self): 13 | for k in self.buffer: 14 | self.buffer[k] = [] 15 | 16 | def add(self, errors): 17 | if not self.buffer: 18 | for k in errors: 19 | self.buffer[k] = [] 20 | 21 | for k, v in errors.items(): 22 | self.buffer[k].append(v) 23 | if len(self.buffer[k]) > self.size * 2: 24 | self.buffer[k] = self.buffer[k][-self.size::] 25 | 26 | def get_errors(self, clear=True): 27 | errors = OrderedDict() 28 | for k, buff in self.buffer.items(): 29 | errors[k] = np.mean(buff[-self.size::]) 30 | # print('[loss buffer] length: %d'%(len(buff[-self.size::]))) 31 | return errors 32 | -------------------------------------------------------------------------------- /util/pose_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.draw import circle, line_aa, polygon 3 | from skimage.morphology import dilation, erosion, square 4 | from skimage.measure import compare_ssim 5 | import pandas as pd 6 | import json 7 | import os 8 | import cv2 9 | ############################################################################################## 10 | # Derived from Deformable GAN (https://github.com/AliaksandrSiarohin/pose-gan) 11 | ############################################################################################## 12 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9], 13 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16], 14 | [0,15], [15,17], [2,1], [5,1]] 15 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 16 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 17 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 18 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 19 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 20 | MISSING_VALUE = -1 21 | 22 | def map_to_coords(pose_map, threshold=0.1): 23 | ''' 24 | Input: 25 | pose_map: (h, w, channel) 26 | Output: 27 | coord: 28 | ''' 29 | all_peaks = [[] for i in range(18)] 30 | pose_map = pose_map[...,:18] 31 | 32 | y,x,z = np.where(np.logical_and(pose_map==pose_map.max(axis=(0,1)), pose_map>threshold)) 33 | for x_i, y_i, z_i in zip(x, y, z): 34 | all_peaks[z_i].append([x_i, y_i]) 35 | 36 | x_values = [] 37 | y_values = [] 38 | 39 | for i in range(18): 40 | if len(all_peaks[i]) != 0: 41 | x_values.append(all_peaks[i][0][0]) 42 | y_values.append(all_peaks[i][0][1]) 43 | else: 44 | x_values.append(MISSING_VALUE) 45 | y_values.append(MISSING_VALUE) 46 | 47 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1) 48 | 49 | def draw_pose_from_coords(pose_joints, img_size, radius=2, draw_joints=True): 50 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8) 51 | mask = np.zeros(shape=img_size, dtype=bool) 52 | 53 | if draw_joints: 54 | for f, t in LIMB_SEQ: 55 | from_missing = pose_joints[f][0] == 
MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE 56 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE 57 | if from_missing or to_missing: 58 | continue 59 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1]) 60 | colors[yy, xx] = np.expand_dims(val, 1) * 255 61 | mask[yy, xx] = True 62 | for i, joint in enumerate(pose_joints): 63 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE: 64 | continue 65 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size) 66 | colors[yy, xx] = COLORS[i] 67 | mask[yy, xx] = True 68 | 69 | return colors, mask 70 | 71 | def draw_pose_from_map(pose_map, threshold=0.1, radius=2, draw_joints=True): 72 | img_size = pose_map.shape[0:2] 73 | coords = map_to_coords(pose_map, threshold) 74 | return draw_pose_from_coords(coords, img_size, radius, draw_joints) 75 | 76 | def get_pose_mask(pose, img_size, point_radius=4): 77 | mask = np.zeros(shape=img_size, dtype=bool) 78 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],\ 79 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],\ 80 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]] 81 | limbs = np.array(limbs) - 1 82 | for f,t in limbs: 83 | from_missing = pose[f][0] < 0 or pose[f][1] < 0 84 | to_missing = pose[t][0] < 0 or pose[t][1] < 0 85 | if from_missing or to_missing: 86 | continue 87 | norm_vec = pose[f] - pose[t] 88 | norm_vec = np.array([-norm_vec[1], norm_vec[0]]) 89 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec+1e-8) 90 | 91 | vetexes = np.array([ 92 | pose[f] + norm_vec, 93 | pose[f] - norm_vec, 94 | pose[t] - norm_vec, 95 | pose[t] + norm_vec]) 96 | 97 | yy, xx = polygon(vetexes[:,0], vetexes[:,1], shape=img_size) 98 | mask[yy, xx] = True 99 | 100 | for i, joint in enumerate(pose): 101 | if pose[i][0] < 0 or pose[i][1] < 0: 102 | continue 103 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size) 104 | mask[yy, xx] = True 105 | 106 | mask = dilation(mask, square(5)) 107 | mask = erosion(mask, square(5)) 108 | 109 | return mask 110 | 111 | def get_pose_mask_batch(pose, img_size, point_radius=4): 112 | ''' 113 | Input: 114 | pose (tensor): (N,18,2) key points 115 | img_size (tuple): (h, w) 116 | point_radius (int): width of skeleton mask 117 | Output: 118 | mask (tensor): (N, 1, h, w) 119 | ''' 120 | mask = [] 121 | pose_np = pose.cpu().numpy() 122 | for p in pose_np: 123 | m = get_pose_mask(p, img_size, point_radius) 124 | mask.append(m) 125 | mask = np.expand_dims(np.stack(mask), axis=1) 126 | mask = mask.astype(np.float32) 127 | return pose.new(mask) 128 | 129 | def load_pose_cords_from_strings(y_str, x_str): 130 | y_cords = json.loads(y_str) 131 | x_cords = json.loads(x_str) 132 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import torchvision 5 | 6 | import os 7 | import time 8 | import util.io as io 9 | import numpy as np 10 | from util import pose_util, flow_util 11 | 12 | 13 | def seg_to_rgb(seg_map, with_face=False): 14 | if isinstance(seg_map, np.ndarray): 15 | if seg_map.ndim == 3: 16 | seg_map = seg_map[np.newaxis,:] 17 | seg_map = torch.from_numpy(seg_map.transpose([0,3,1,2])) 18 | elif 
isinstance(seg_map, torch.Tensor): 19 | seg_map = seg_map.cpu() 20 | if seg_map.dim() == 3: 21 | seg_map = seg_map.unsqueeze(0) 22 | 23 | if with_face: 24 | face = seg_map[:,-3::] 25 | seg_map = seg_map[:,0:-3] 26 | 27 | if seg_map.size(1) > 1: 28 | seg_map = seg_map.max(dim=1, keepdim=True)[1] 29 | else: 30 | seg_map = seg_map.long() 31 | 32 | b,c,h,w = seg_map.size() 33 | assert c == 1 34 | 35 | cmap = [[73,0,255], [255,0,0], [255,0,219], [255, 219,0], [0,255,146], [0,146,255], [146,0,255], [255,127,80], [0,255,0], [0,0,255], 36 | [37, 0, 127], [127,0,0], [127,0,109], [127,109,0], [0,127,73], [0,73,127], [73,0, 127], [127, 63, 40], [0,127,0], [0,0,127]] 37 | cmap = torch.Tensor(cmap)/255. 38 | cmap = cmap[0:(seg_map.max()+1)] 39 | 40 | rgb_map = cmap[seg_map.view(-1)] 41 | rgb_map = rgb_map.view(b, h, w, 3) 42 | rgb_map = rgb_map.transpose(1,3).transpose(2,3) 43 | rgb_map.sub_(0.5).div_(0.5) 44 | 45 | if with_face: 46 | face_mask = ((seg_map == 1) | (seg_map == 2)).float() 47 | rgb_map = rgb_map * (1 - face_mask) + face * face_mask 48 | 49 | return rgb_map 50 | 51 | def merge_visual(visuals): 52 | imgs = [] 53 | vis_list = [] 54 | for name, (vis, vis_type) in visuals.items(): 55 | vis = vis.cpu() 56 | if vis_type == 'rgb': 57 | vis_ = vis 58 | elif vis_type == 'seg': 59 | vis_ = seg_to_rgb(vis) 60 | elif vis_type == 'pose': 61 | pose_maps = vis.numpy().transpose(0,2,3,1) 62 | vis_ = np.stack([pose_util.draw_pose_from_map(m)[0] for m in pose_maps]) 63 | vis_ = vis.new(vis_.transpose(0,3,1,2)).float()/127.5 - 1. 64 | elif vis_type == 'flow': 65 | flows = vis.numpy().transpose(0,2,3,1) 66 | vis_ = np.stack([flow_util.flow_to_rgb(f) for f in flows]) 67 | vis_ = vis.new(vis_.transpose(0,3,1,2)).float()/127.5 - 1. 68 | elif vis_type == 'vis': 69 | if vis.size(1) == 3: 70 | vis = vis.argmax(dim=1, keepdim=True) 71 | vis_ = vis.new(vis.size(0), 3, vis.size(2), vis.size(3)).float() 72 | vis_[:,0,:,:] = (vis==1).float().squeeze(dim=1)*2-1 # red: not visible 73 | vis_[:,1,:,:] = (vis==0).float().squeeze(dim=1)*2-1 # green: visible 74 | vis_[:,2,:,:] = (vis==2).float().squeeze(dim=1)*2-1 # blue: background 75 | elif vis_type == 'softmask': 76 | vis_ = (vis*2-1).repeat(1,3,1,1) 77 | imgs.append(vis_) 78 | vis_list.append(name) 79 | imgs = torch.stack(imgs, dim=1) 80 | imgs = imgs.view(imgs.size(0)*imgs.size(1), imgs.size(2), imgs.size(3), imgs.size(4)) 81 | imgs.clamp_(-1., 1.) 82 | return imgs, vis_list 83 | 84 | class Visualizer(object): 85 | def __init__(self, opt): 86 | self.opt = opt 87 | self.exp_dir = os.path.join('./checkpoints', opt.id) 88 | self.log_file = None 89 | 90 | def __del__(self): 91 | if self.log_file: 92 | self.log_file.close() 93 | 94 | def _open_log_file(self): 95 | fn = 'train_log.txt' if self.opt.is_train else 'test_log.txt' 96 | self.log_file = open(os.path.join(self.exp_dir, fn), 'a') 97 | print(time.ctime(), file=self.log_file) 98 | print('pytorch version: %s' % torch.__version__, file=self.log_file) 99 | 100 | 101 | def log(self, info='', errors={}, log_in_file=True): 102 | ''' 103 | Save log information into log file 104 | Input: 105 | info (dict or str): model id, iteration number, learning rate, etc. 106 | error (dict): output of loss functions or metrics. 
107 | Output: 108 | log_str (str) 109 | ''' 110 | if isinstance(info, str): 111 | info_str = info 112 | elif isinstance(info, dict): 113 | info_str = ' '.join(['{}: {}'.format(k,v) for k, v in info.items()]) 114 | 115 | error_str = ' '.join(['%s: %.4f'%(k,v) for k, v in errors.items()]) 116 | log_str = '[%s] %s' %(info_str, error_str) 117 | 118 | if log_in_file: 119 | if self.log_file is None: 120 | self._open_log_file() 121 | print(log_str, file=self.log_file) 122 | return log_str 123 | 124 | def visualize_results(self, visuals, filename): 125 | io.mkdir_if_missing(os.path.dirname(filename)) 126 | imgs, vis_item_list = merge_visual(visuals) 127 | torchvision.utils.save_image(imgs, filename, nrow=len(visuals), normalize=True) 128 | fn_list = os.path.join(os.path.dirname(filename), 'vis_item_list.txt') 129 | io.save_str_list(vis_item_list, fn_list) 130 | 131 | --------------------------------------------------------------------------------
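
The pose and flow utilities above are easiest to follow end to end. Below is a minimal usage sketch, not part of the repository: it assumes the repo root is on `PYTHONPATH` and the pinned dependency versions from `requirements.txt` (e.g. a scikit-image release that still provides `skimage.draw.circle` and `skimage.measure.compare_ssim`); the image size and keypoint values are made up for illustration. It uses only functions shown in the files above: `cords_to_map` / `map_to_cord` from `tools/pose_utils.py`, `draw_pose_from_coords` from `util/pose_util.py`, and `flow_to_rgb` from `util/flow_util.py`.

```python
import numpy as np

from tools import pose_utils            # keypoint <-> heatmap helpers (tools/pose_utils.py)
from util import pose_util, flow_util   # skeleton / flow rendering helpers (util/)

H, W = 256, 176                          # illustrative image size (assumption)

# 18 keypoints in (y, x) order; MISSING_VALUE (-1) marks undetected joints.
# The three filled-in joints below are hypothetical example values.
cords = np.full((18, 2), -1, dtype=int)
cords[0] = [40, 88]    # nose
cords[1] = [70, 88]    # neck
cords[2] = [72, 60]    # right shoulder

# Keypoints -> per-joint Gaussian heatmaps -> back to coordinates.
pose_map = pose_utils.cords_to_map(cords, (H, W), sigma=6)    # (H, W, 18) float32
recovered = pose_utils.map_to_cord(pose_map, threshold=0.1)   # (18, 2) in (y, x)

# Render a colored skeleton image and a boolean joint mask from the coordinates.
colors, mask = pose_util.draw_pose_from_coords(recovered, (H, W), radius=2)

# Visualize a dense flow field as an RGB image (constant 5-pixel horizontal shift).
flow = np.zeros((H, W, 2), dtype=np.float32)
flow[..., 0] = 5.0
flow_rgb = flow_util.flow_to_rgb(flow)

print(pose_map.shape, recovered.shape, colors.shape, mask.shape, flow_rgb.shape)
```

These are the same conventions ((y, x) joint order, `-1` for missing joints, channel-last `H x W x 18` heatmaps) that `util/visualizer.py` relies on when it renders `'pose'`-type visuals via `pose_util.draw_pose_from_map` and `'flow'`-type visuals via `flow_util.flow_to_rgb`.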