├── README.md
├── aic_test.py
├── config.py
├── data
│   ├── __init__.py
│   ├── augmentation.py
│   ├── augmentation2.py
│   ├── det_data_noexpand.py
│   ├── detection_data.py
│   ├── generate_256img.py
│   ├── generate_data.py
│   ├── hg_data.py
│   ├── image_handle.py
│   ├── regression_data.py
│   └── test_data_noexpand.py
├── eval_train_val.py
├── jupyter_file
│   ├── .ipynb_checkpoints
│   │   ├── annotations数据处理-checkpoint.ipynb
│   │   ├── 数据增强测试-checkpoint.ipynb
│   │   └── 数据增强测试2-checkpoint.ipynb
│   ├── annotations数据处理.ipynb
│   ├── 数据增强测试.ipynb
│   └── 数据增强测试2.ipynb
├── logs
│   ├── AIC-HGNet_log.txt
│   └── AIC-HGNet_log.txt.old
├── main_hg.py
├── models
│   ├── BasicModule.py
│   ├── Conv_part_hm_reg_model.py
│   ├── HPE_det_reg_model.py
│   ├── __init__.py
│   ├── hourglass.py
│   └── new_ResNet.py
└── utils
    ├── __init__.py
    ├── eval_score.py
    ├── helper.py
    ├── ian_eval.py
    ├── ian_eval_tensor.py
    ├── logger.py
    ├── net_validation.py
    ├── prediction_handle.py
    └── visualize.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hourglass implementation
- Paper: [Stacked Hourglass Networks for Human Pose Estimation.](https://arxiv.org/pdf/1603.06937.pdf) Alejandro Newell et al.
- PyTorch implementation
- Datasets: AI Challenger human skeletal keypoint detection

--------------------------------------------------------------------------------
/aic_test.py:
--------------------------------------------------------------------------------
# coding:utf8
from config import opt
import models
from data import HPEPoseTestDataset
import torch
from torch.utils import data
from torch.autograd import Variable
from utils.prediction_handle import get_pred_kps
from tqdm import tqdm
from torch import nn
import json


def aic_test():
    opt.model_id = 4
    model = getattr(models, opt.model[opt.model_id])(num_stacks=6)
    model = model.cuda()
    # with open('checkpoints/AIC-HGNet_progress.json', 'r') as f:
    #     progress = json.load(f)
    # model.load_state_dict(torch.load(progress['best_path']))
    best_path = 'checkpoints/AIC-HGNet_0.567476429117.model'
    model.load_state_dict(torch.load(best_path))
    # model = nn.DataParallel(model, device_ids=[0, 1])

    model.eval()

    opt.test_anno_file = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_test_b_20171120/test_b_0.4.pkl'
    opt.test_img_dir = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_test_b_20171120/keypoint_test_b_images_20171120/'
    dataset = HPEPoseTestDataset(opt.test_anno_file, opt.test_img_dir)

    print(len(dataset))
    dataloader = data.DataLoader(dataset, batch_size=opt.val_bs, num_workers=opt.num_workers)
    print("processing test data...")
    pred_list = []
    for processed_img, processed_info in tqdm(dataloader, ncols=50):
        processed_img = processed_img.float()
        processed_img = Variable(processed_img.cuda())
        pred_list += get_pred_kps(processed_info, model(processed_img)[-1].cpu())
    print("processing test data done.")
    pred_list_file = opt.interim_data_path + 'pred_test_list.pkl'
    with open(pred_list_file, 'wb') as f:
        import pickle
        pickle.dump(pred_list, f)
    submit = get_keypoints(pred_list)
    with open('12_03_submit.json', 'w') as f:
        json.dump(submit, f)


def get_keypoints(pred_list):
    predictions = dict()
    for pred in pred_list:
        if pred[0] in predictions:
            predictions[pred[0]]['keypoint_annotations'].update(pred[1])
        else:
            predictions[pred[0]] = {
                'image_id': pred[0],
                'keypoint_annotations': pred[1]
            }
    return list(predictions.values())


if __name__ == '__main__':
    aic_test()
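For reference, `get_keypoints` above folds the per-crop `(image_id, {human: keypoints})` pairs coming out of `get_pred_kps` into the AI Challenger submission format, one entry per image. A minimal sketch of the shapes involved (the image id and coordinate values here are made up):

```python
pred_list = [
    ('a0f6bd9...', {'human1': [53, 64, 1] * 14}),  # 14 joints, flattened (x, y, v) triples
    ('a0f6bd9...', {'human2': [88, 41, 1] * 14}),  # second person in the same image
]
submit = get_keypoints(pred_list)
# -> [{'image_id': 'a0f6bd9...',
#      'keypoint_annotations': {'human1': [...], 'human2': [...]}}]
```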
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# coding:utf8
import time
import numpy as np


class Config:
    def __init__(self):
        self.demo = False  # whether to run on a small demo subset of the data
        self.reuse = False
        self.resume = False
        self.device = 0
        self.progress_file = 'train_progress.json'
        self.title = "AIC"
        self.progress = {
            'best_path': '',
            'best_mAP': None,
            'count': 0,
            'lr': 2.5e-4,
            'epoch': 0,
        }

        self.train_mean = (108, 115, 122)
        self.val_mean = (103, 111, 119)  # [ 102.72338076, 111.23301759, 119.31230327]
        self.test_mean = (114, 122, 129)

        self.batch_size = 8  # training batch size
        self.val_bs = 8  # validation batch size
        self.epoch = 20  # number of epochs
        self.start_epoch = 0
        self.start_count = 0
        self.lr = 2.5e-4  # learning rate
        self.min_lr = 1e-6
        self.lr_decay = 0.99  # learning-rate decay factor
        self.shuffle = True
        self.augument = False  # whether to use data augmentation
        self.check_every = 5000  # check mAP and adjust the learning rate every this many batches
        self.is_train = True  # whether this is the training phase
        self.plot_every = 100  # update visdom every this many batches

        self.root_dir = './'
        # self.root_dir = '/home/hadoop/deeplearning/pytorch/pytorch_ai_challenger_HPE/'
        self.dataset_root_dir = '/media/bnrc2/_backup/ai/'

        self.img_dir = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
        self.annotations_file = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/official/' \
                                'keypoint_train_annotations_newclear.json'

        self.val_anno_file = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/official/' \
                             'keypoint_validation_annotations_newclear.json'
        self.val_img_dir = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_validation_20170911' \
                           '/keypoint_validation_images_20170911/'
        self.test_anno_file = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_test_a_20170923/test_anno.pkl'
        self.test_img_dir = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_test_a_20170923/' \
                            'keypoint_test_a_images_20170923/'
        # self.val_anno_file = '/home/bnrc2/ai_challenge/ian/Pytorch_Human_Pose_Estimation/interim_data/val_dataset/'
        # self.val_img_dir = '/home/bnrc2/ai_challenge/ian/Pytorch_Human_Pose_Estimation/interim_data/val_dataset' \
        #                    '/val_imgs/'

        self.interim_data_path = self.root_dir + 'interim_data/'  # scratch space for intermediate data produced during training
        self.model_path = self.root_dir + 'models/'
        self.checkpoints = self.root_dir + 'checkpoints/'
        self.logs_path = self.root_dir + 'logs/'
        self.model_id = 0  # index of the selected model
        self.model = ['Part_detection_subnet_model', 'Regression_subnet', 'Part_detection_subnet101',
                      'Regression_subnet101', 'HourglassNet', 'AIC-HGNet']  # available models

        self.threshold = 0.0

        self.num_workers = 4  # number of worker threads for data loading
        self.pin_memory = False  # speed up CPU -> pinned memory -> GPU transfers

        self.env = time.strftime('%m%d_%H%M%S')  # Visdom env
        # 1/r_shoulder, 2/r_elbow, 3/r_wrist, 4/l_shoulder, 5/l_elbow, 6/l_wrist,
        # 7/r_hip, 8/r_knee, 9/r_ankle, 10/l_hip, 11/l_knee, 12/l_ankle, 13/head top, 14/neck
        self.part_id = {1: 'r_shoulder', 2: 'r_elbow', 3: 'r_wrist', 4: 'l_shoulder', 5: 'l_elbow', 6: 'l_wrist',
                        7: 'r_hip', 8: 'r_knee', 9: 'r_ankle', 10: 'l_hip', 11: 'l_knee', 12: 'l_ankle', 13: 'head',
                        14: 'neck'}
        self.delta = 2 * np.array([0.01388152, 0.01515228, 0.01057665, 0.01417709,
                                   0.01497891, 0.01402144, 0.03909642, 0.03686941, 0.01981803,
                                   0.03843971, 0.03412318, 0.02415081, 0.01291456, 0.01236173])

    def parse(self, kwargs):
        """Update config attributes from the dict kwargs."""
        for k, v in kwargs.items():
            if not hasattr(self, k):
                raise Exception("opt has not attribute <%s>" % k)
            setattr(self, k, v)

    def config_info_print(self):
        print("Train&Val Batch size: ".rjust(30, ' '), self.batch_size, self.val_bs)
        print("Epochs: ".rjust(30, ' '), self.epoch)
        print("GPU device: ".rjust(30, ' '), self.device)
        print("Beginning Learning Rate: ".rjust(30, ' '), self.lr)
        print("Check&Plot every: ".rjust(30, ' '), self.check_every, self.plot_every)
        print("Train Data Dir: ".rjust(30, ' '), self.annotations_file)
        print("Val Data Dir: ".rjust(30, ' '), self.val_anno_file)
        if self.demo:
            print(" ".rjust(30, ' '), "NOTE: Using demo data!")
        else:
            print(" ".rjust(30, ' '), "NOTE: Not using demo data!")


opt = Config()
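config.py exposes a single shared `opt` instance; `parse` overrides attributes from a dict and rejects unknown keys, so typos fail fast. A short usage sketch:

```python
from config import opt

opt.parse({'batch_size': 16, 'lr': 1e-4})  # update known attributes in place
opt.config_info_print()                    # print the effective configuration
# opt.parse({'batchsize': 16})             # unknown key -> Exception("opt has not attribute <batchsize>")
```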
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .generate_data import HumanPoseDetectionDataset, HumanPoseValDataset, HumanPoseRegressionDataset, \
    RegressionValDataset
from .image_handle import *
from .augmentation import *
from .detection_data import HPEDetDataset
from .regression_data import HPEPoseDataset, HPEPoseValDataset, HPEPoseTestDataset
from .augmentation2 import HPEAugmentation2
from .det_data_noexpand import HPEDetDataset_NE
from .hg_data import hgDataset, hgValDataset, EvalDataset

--------------------------------------------------------------------------------
/data/augmentation.py:
--------------------------------------------------------------------------------
# coding=utf-8
import cv2
import numpy as np
from numpy import random
import math


class RandomRotation(object):
    def __call__(self, image, keypoints):
        h, w, c = image.shape
        center = (w / 2, h / 2)
        if random.randint(2):
            degree = random.randint(-30, 30)
            w_r = int(w * math.fabs(math.cos(math.radians(degree))) +
                      h * math.fabs(math.sin(math.radians(degree))))
            h_r = int(w * math.fabs(math.sin(math.radians(degree))) +
                      h * math.fabs(math.cos(math.radians(degree))))
            m = cv2.getRotationMatrix2D(center, degree, 1)  # rotation matrix
            m[0][2] += (w_r - w) / 2
            m[1][2] += (h_r - h) / 2
            image = cv2.warpAffine(image, m, (w_r, h_r))
            # recompute the joint positions after the rotation
            kps = np.reshape(keypoints, (-1, 3))[:, 0:2]
            v = np.reshape(keypoints, (-1, 3))[:, 2]
            kps = np.hstack((kps, np.ones(14, dtype=int)[:, np.newaxis]))
            kps = np.dot(kps, m.T).astype('int')
            keypoints = np.hstack((kps, v[:, np.newaxis]))
        return image, keypoints


class RandomMirror(object):  # horizontal flip of the image
    def __call__(self, image, keypoints):
        h, w, c = image.shape

        if random.randint(2):
            image = image[:, ::-1].copy()
            keypoints = keypoints.copy()
            keypoints[:, 0] = w - keypoints[:, 0]
        return image, keypoints
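# Note on the keypoint math in RandomRotation: the joints are moved with the same
# 2x3 affine matrix that cv2.warpAffine applies to the pixels -- each point is
# lifted to homogeneous coordinates (x, y, 1) and multiplied by m.T. A worked
# check with a made-up matrix (90-degree rotation plus a translation of (5, 7)):
#   >>> m = np.array([[0., -1., 5.],
#   ...               [1.,  0., 7.]])
#   >>> np.array([[2., 0., 1.]]) @ m.T       # point (x=2, y=0)
#   array([[5., 9.]])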
# ----------------------------------------------------------------------
class RandomContrast(object):
    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    # expects float image
    def __call__(self, image, keypoints=None):
        if random.randint(2):
            alpha = random.uniform(self.lower, self.upper)
            # image *= alpha
            np.multiply(image, alpha, out=image, casting='unsafe')
        return image, keypoints


class ConvertColor(object):
    """Color-space conversion."""

    def __init__(self, current='BGR', transform='HSV'):
        self.transform = transform
        self.current = current

    def __call__(self, image, keypoints=None):
        if self.current == 'BGR' and self.transform == 'HSV':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        elif self.current == 'HSV' and self.transform == 'BGR':
            image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
        else:
            raise NotImplementedError
        return image, keypoints


class RandomBrightness(object):
    """Randomly shift the image brightness."""

    def __init__(self, delta=32):
        assert delta >= 0.0
        assert delta <= 255.0
        self.delta = delta

    def __call__(self, image, keypoints=None):
        if random.randint(2):
            delta = random.uniform(-self.delta, self.delta)
            # image += delta
            np.add(image, delta, out=image, casting="unsafe")
        return image, keypoints


class RandomSaturation(object):
    """Random saturation jitter."""

    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "saturation upper must be >= lower."
        assert self.lower >= 0, "saturation lower must be non-negative."

    def __call__(self, image, keypoints=None):
        if random.randint(2):
            # image[:, :, 1] *= random.uniform(self.lower, self.upper)
            np.multiply(image[:, :, 1], random.uniform(self.lower, self.upper),
                        out=image[:, :, 1], casting="unsafe")

        return image, keypoints


class RandomHue(object):
    """Random hue jitter."""

    def __init__(self, delta=18.0):
        assert 0.0 <= delta <= 360.0
        self.delta = delta

    def __call__(self, image, keypoints=None):
        if random.randint(2):
            # image[:, :, 0] += random.uniform(-self.delta, self.delta)
            np.add(image[:, :, 0], random.uniform(-self.delta, self.delta),
                   out=image[:, :, 0], casting="unsafe")
            # wrap the hue channel back into [0, 360); boolean-mask indexing with
            # out= would write into a temporary copy, so use in-place masked ops
            hue = image[:, :, 0]
            hue[hue > 360.0] -= 360.0
            hue[hue < 0.0] += 360.0
        return image, keypoints


class RandomLightingNoise(object):
    """Randomly permute the image channels (shuffle the channel order)."""

    def __init__(self):
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def __call__(self, image, keypoints=None):
        if random.randint(2):
            swap = self.perms[random.randint(len(self.perms))]
            shuffle = SwapChannels(swap)  # shuffle channels
            image = shuffle(image)
        return image, keypoints


class SwapChannels(object):
    def __init__(self, swaps):
        self.swaps = swaps

    def __call__(self, image):
        image = image[:, :, self.swaps]
        return image


class PhotometricDistort(object):
    def __init__(self):
        self.pd = [
            RandomContrast(),
            ConvertColor(transform='HSV'),
            RandomSaturation(),
            RandomHue(),
            ConvertColor(current='HSV', transform='BGR'),
            RandomContrast()
        ]
        self.rand_brightness = RandomBrightness()
        self.rand_light_noise = RandomLightingNoise()

    def __call__(self, image, keypoints):
        im = image.copy()
        im, keypoints = self.rand_brightness(im, keypoints)
        if random.randint(2):
            distort = Compose(self.pd[:-1])
        else:
            distort = Compose(self.pd[1:])
        im, keypoints = distort(im, keypoints)
        return self.rand_light_noise(im, keypoints)
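# Usage note: PhotometricDistort always applies the brightness jitter first, then
# randomly runs either the first five or the last five distortions (each order
# skips one RandomContrast, so contrast is adjusted at most once per call), and
# finishes with the channel shuffle. A minimal sketch on a fake float BGR crop:
#   >>> distort = PhotometricDistort()
#   >>> img = np.random.uniform(0, 255, (256, 256, 3)).astype(np.float32)
#   >>> out, kps = distort(img, np.zeros((14, 3), dtype=int))  # keypoints pass through
#   >>> out.shape
#   (256, 256, 3)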
# ----------------------------------------------------------------------


class Make_padding(object):
    def __call__(self, image, mean):
        h, w, c = image.shape
        if h > w:
            pad = (np.zeros((h, h - w, 3)) + mean).astype(image.dtype)
            image = np.concatenate((image, pad), axis=1)
        elif w > h:
            pad = (np.zeros((w - h, w, 3)) + mean).astype(image.dtype)
            image = np.concatenate((image, pad), axis=0)
        return image


class Resize(object):
    def __init__(self, size=256, mean=(0, 0, 0)):
        self.size = size
        self.mean = mean
        self.mp = Make_padding()

    def __call__(self, image, keypoints):
        image = self.mp(image, self.mean)
        image = cv2.resize(image, (self.size, self.size))
        return image, keypoints


class KeypointTransform(object):
    def __call__(self, image, keypoints, side=256.0):
        h, w, _ = image.shape
        if h > w:
            scale = side / h
            m = np.array([[scale, 0., 0.], [0., scale, 0.]])  # scaling matrix
            keypoints = np.reshape(keypoints.copy(), (-1, 3))
            v = keypoints[:, 2]
            # keypoints = np.dot(keypoints, m.T)
            keypoints = np.hstack((np.dot(keypoints, m.T), v[:, np.newaxis]))
        else:
            scale = side / w
            m = np.array([[scale, 0., 0.], [0., scale, 0.]])  # scaling matrix
            keypoints = np.reshape(keypoints.copy(), (-1, 3))
            v = keypoints[:, 2]
            keypoints = np.hstack((np.dot(keypoints, m.T), v[:, np.newaxis]))
        return image, keypoints.astype('int')


class GenerateLabel(object):
    def __call__(self, image, keypoints):
        radius = 10
        heatmap_res = None
        x = np.arange(0, 256, dtype=np.uint32)
        y = np.arange(0, 256, dtype=np.uint32)[:, np.newaxis]
        for kp in keypoints:
            if kp[2] == 1:
                if heatmap_res is None:
                    heatmap_res = ((x - kp[0]) ** 2 + (y - kp[1]) ** 2) <= radius ** 2
                else:
                    heatmap_res = np.vstack((heatmap_res, (((x - kp[0]) ** 2 + (y - kp[1]) ** 2) <= radius ** 2)))
            else:
                if heatmap_res is None:
                    heatmap_res = np.zeros((256, 256), dtype=np.uint8)
                else:
                    heatmap_res = np.vstack((heatmap_res, np.zeros((256, 256), dtype=np.uint8)))
        heatmap_res = heatmap_res.astype(np.uint8)

        return image, np.reshape(heatmap_res, (-1, 256, 256))


def makeGaussian(height, width, sigma=5, center=None):
    """
    Make a square gaussian kernel.
    :param height: side length of the heatmap
    :param width: side length of the heatmap
    :param sigma: spread (standard deviation) of the distribution
    :param center: center of the gaussian kernel
    :return: heatmap with a gaussian kernel on it
    """
    x = np.arange(0, width, 1, float)
    y = np.arange(0, height, 1, float)[:, np.newaxis]
    if center is None:
        x0 = width // 2
        y0 = height // 2
    else:
        x0 = center[0]
        y0 = center[1]
    return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2)
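# Note on the parameterization above: with exp(-4*ln(2) * d^2 / sigma^2) the value
# drops to exactly 0.5 at distance d = sigma/2 from the center, so `sigma` acts as
# the full width at half maximum rather than a true standard deviation:
#   >>> hm = makeGaussian(64, 64, sigma=3, center=(32, 32))
#   >>> hm[32, 32]                                        # peak value at the center
#   1.0
#   >>> float(np.exp(-4 * np.log(2) * 1.5 ** 2 / 3 ** 2))  # value at d = sigma/2
#   0.5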
class GenerateHeatMap(object):
    def __call__(self, image, keypoints, H=256, W=256, sigma=3):
        hm = np.zeros((14, H, W), dtype=np.float32)
        for i, kp in enumerate(keypoints):
            if kp[2] == 1:
                hm[i] = makeGaussian(H, W, sigma=sigma, center=(kp[0], kp[1]))
        return image, hm


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, keypoints):
        # chain the transforms, feeding each one the previous output
        for t in self.transforms:
            image, keypoints = t(image, keypoints)
        return image, keypoints


class SubtractMeans(object):
    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, keypoints=None):
        image = image.astype(np.float32)
        image -= self.mean
        return image.astype(np.float32), keypoints


class HPEAugmentation(object):
    def __init__(self, mean=(0, 0, 0)):
        self.rm = RandomMirror()
        self.resize = Resize(mean=mean)
        self.submean = SubtractMeans(mean=mean)
        self.kptrans = KeypointTransform()

    def __call__(self, image, keypoints):
        image, keypoints = self.rm(image, keypoints)
        image, keypoints = self.kptrans(image, keypoints)
        image, keypoints = self.resize(image, keypoints)
        return self.submean(image, keypoints)


class HPEBaseTransform(object):
    def __init__(self, mean=(0, 0, 0), hm_side=256.0):
        self.kptrans = KeypointTransform()
        self.resize = Resize(mean=mean)
        self.submean = SubtractMeans(mean=mean)
        self.hm_size = hm_side

    def __call__(self, image, keypoints):
        image, keypoints = self.kptrans(image, keypoints, self.hm_size)
        image, keypoints = self.resize(image, keypoints)
        return self.submean(image, keypoints)


class HPETest(object):
    def __init__(self):
        self.rm = RandomMirror()
        self.pd = PhotometricDistort()
        self.rr = RandomRotation()
        self.kt = KeypointTransform()

    def __call__(self, image, keypoints):
        image, keypoints = self.rm(image, keypoints)
        image, keypoints = self.pd(image, keypoints)
        image, keypoints = self.rr(image, keypoints)
        image, keypoints = self.kt(image, keypoints)
        return image, keypoints
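For context on how these pieces compose: hgDataset (in data/hg_data.py below) runs its transform first and then GenerateHeatMap, with the keypoints rescaled to the heatmap grid rather than the image size. A minimal sketch of that wiring on a fake crop, assuming HPEBaseTransform with hm_side=64 as the transform (which matches the 64x64 labels hgDataset builds):

```python
import numpy as np
from config import opt
from data.augmentation import HPEBaseTransform, GenerateHeatMap

base = HPEBaseTransform(mean=opt.train_mean, hm_side=64.0)  # keypoints -> 64x64 grid
make_hm = GenerateHeatMap()

img = np.random.randint(0, 256, (480, 360, 3)).astype(np.float32)  # fake person crop
kps = np.array([[180, 60, 1]] * 14)          # fake joints: (x, y, visibility)
img, kps = base(img, kps)                    # pad to square, resize to 256, subtract mean
img, heatmaps = make_hm(img, kps, H=64, W=64, sigma=3)
print(img.shape, heatmaps.shape)             # (256, 256, 3) (14, 64, 64)
```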
--------------------------------------------------------------------------------
/data/augmentation2.py:
--------------------------------------------------------------------------------
# coding=utf-8
import cv2
import numpy as np
from numpy import random
import math


class RandomRotation(object):
    def __call__(self, image, keypoints):
        h, w, c = image.shape
        center = (w / 2, h / 2)
        if random.randint(2):
            degree = random.randint(-30, 30)
            w_r = int(w * math.fabs(math.cos(math.radians(degree))) +
                      h * math.fabs(math.sin(math.radians(degree))))
            h_r = int(w * math.fabs(math.sin(math.radians(degree))) +
                      h * math.fabs(math.cos(math.radians(degree))))
            m = cv2.getRotationMatrix2D(center, degree, 1)  # rotation matrix
            m[0][2] += (w_r - w) / 2
            m[1][2] += (h_r - h) / 2
            image = cv2.warpAffine(image, m, (w_r, h_r))
            # recompute the joint positions after the rotation
            kps = np.reshape(keypoints, (-1, 3))[:, 0:2]
            v = np.reshape(keypoints, (-1, 3))[:, 2]
            kps = np.hstack((kps, np.ones(14, dtype=int)[:, np.newaxis]))
            kps = np.dot(kps, m.T).astype('int')
            keypoints = np.hstack((kps, v[:, np.newaxis]))
            print(keypoints)
            print('--------RR---------')
        return image, keypoints


class RandomMirror(object):  # horizontal flip of the image
    def __call__(self, image, keypoints):
        h, w, c = image.shape

        if random.randint(2):
            image = image[:, ::-1].copy()
            keypoints = keypoints.copy()
            keypoints[:, 0] = w - keypoints[:, 0]
            print(keypoints)
            print('--------RM---------')
        return image, keypoints


class Make_padding(object):
    def __call__(self, image):
        h, w, c = image.shape
        if h > w:
            image = np.concatenate((image, np.zeros((h, h - w, 3), dtype=image.dtype)), axis=1)
        elif w > h:
            image = np.concatenate((image, np.zeros((w - h, w, 3), dtype=image.dtype)), axis=0)
        return image


class Resize(object):
    def __init__(self, size=256):
        self.size = size
        self.mp = Make_padding()

    def __call__(self, image, keypoints):
        image = self.mp(image)
        image = cv2.resize(image, (self.size, self.size))
        return image, keypoints


class KeypointTransform(object):
    def __call__(self, image, keypoints):
        h, w, c = image.shape
        if h > w:
            scale = 256.0 / h
            m = np.array([[scale, 0, 0], [0, scale, 0]])  # scaling matrix
            keypoints = np.reshape(keypoints.copy(), (-1, 3))
            v = keypoints[:, 2]
            # keypoints = np.dot(keypoints, m.T)
            keypoints = np.hstack((np.dot(keypoints, m.T), v[:, np.newaxis]))
        else:
            scale = 256.0 / w
            m = np.array([[scale, 0, 0], [0, scale, 0]])  # scaling matrix
            keypoints = np.reshape(keypoints.copy(), (-1, 3))
            v = keypoints[:, 2]
            keypoints = np.hstack((np.dot(keypoints, m.T), v[:, np.newaxis]))
        return image, keypoints.astype('int')


def makeGaussian(height, width, sigma=5, center=None):
    """
    Make a square gaussian kernel.
    :param height: side length of the heatmap
    :param width: side length of the heatmap
    :param sigma: spread (standard deviation) of the distribution
    :param center: center of the gaussian kernel
    :return: heatmap with a gaussian kernel on it
    """
    x = np.arange(0, width, 1, float)
    y = np.arange(0, height, 1, float)[:, np.newaxis]
    if center is None:
        x0 = width // 2
        y0 = height // 2
    else:
        x0 = center[0]
        y0 = center[1]
    return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2)
class GenerateHeatMap(object):
    def __call__(self, image, keypoints):
        hm = np.zeros((14, 256, 256), dtype=np.float32)
        for i, kp in enumerate(keypoints):
            if kp[2] == 1:
                hm[i] = makeGaussian(256, 256, sigma=5, center=(kp[0], kp[1]))
        return image, hm


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, keypoints):
        # chain the transforms, feeding each one the previous output
        for t in self.transforms:
            image, keypoints = t(image, keypoints)
        return image, keypoints


class HPEAugmentation2(object):
    def __init__(self):
        self.augment = Compose([
            RandomMirror(),
            RandomRotation()
        ])

    def __call__(self, image, keypoints):
        return self.augment(image, keypoints)

--------------------------------------------------------------------------------
/data/det_data_noexpand.py:
--------------------------------------------------------------------------------
# coding:utf8
from torch.utils.data import Dataset
import pandas as pd
import os
import pickle
from tqdm import tqdm
import numpy as np
import cv2
from config import opt
from .augmentation import GenerateLabel
from utils.helper import Helper


class HPEDetDataset_NE(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, phase='train'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.phase = phase
        self.helper = Helper()

        self.pro_annos = []
        self.gen_intermediate_file()
        self.generatelabel = GenerateLabel()

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0]
        if self.transform is not None:
            img, kps = self.transform(img, kps)
        img, label = self.generatelabel(img, kps)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))

        return img, label

    def __len__(self):
        return len(self.pro_annos)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
            # if self.phase == 'val':
            #     self.pro_annos = self.pro_annos[:5000]
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0])):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': v,
                                           'height_width': (h, w),
                                           'keypoints': anno.keypoint_annotations[i][k]})
            self.pro_annos = list(filter(lambda x: x['image_id'] not in self.helper.img_list, self.pro_annos))
            del anno
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

--------------------------------------------------------------------------------
/data/detection_data.py:
--------------------------------------------------------------------------------
# coding:utf8
from torch.utils.data import Dataset
import pandas as pd
import os
import pickle
from tqdm import tqdm
import numpy as np
import cv2
from config import opt
from .augmentation import GenerateLabel, Resize, KeypointTransform, SubtractMeans
from utils.helper import Helper


class HPEDetDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, phase='train'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.phase = phase
        self.helper = Helper()

        self.pro_annos = []
        self.gen_intermediate_file()
        self.generatelabel = GenerateLabel()
        self.kptransform = KeypointTransform()
        # if phase == 'train':
        #     self.resize = Resize(mean=opt.train_mean)
        #     self.submean = SubtractMeans(mean=opt.train_mean)
        # else:
        #     self.resize = Resize(mean=opt.val_mean)
        #     self.submean = SubtractMeans(mean=opt.val_mean)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0]
        if self.transform is not None:
            img, kps = self.transform(img, kps)
        # img, kps = self.kptransform(img, kps)
        # img, kps = self.resize(img, kps)
        # img, kps = self.submean(img, kps)
        img, label = self.generatelabel(img, kps)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        # return img, kps, anno['image_id'], anno['keypoints']

        return img, label

    def __len__(self):
        return len(self.pro_annos)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
            # if self.phase == 'val':
            #     self.pro_annos = self.pro_annos[:5000]
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0])):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    coords = np.array(v).reshape(-1, 2)
                    offset = (coords[1] - coords[0]) * 0.15  # expand the box by 15% per side (30% per dimension)
                    coords = v + np.concatenate((-offset, offset))
                    coords = coords.astype("int")
                    coords[np.where(coords < 0)] = 0
                    if coords[2] > w:
                        coords[2] = w
                    if coords[3] > h:
                        coords[3] = h
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': coords,
                                           'height_width': (h, w),
                                           'keypoints': anno.keypoint_annotations[i][k]})
            self.pro_annos = list(filter(lambda x: x['image_id'] not in self.helper.img_list, self.pro_annos))
            del anno
            # print(self.pro_annos[:3])
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)
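Both this dataset and the test/eval datasets below enlarge each human box by 15% of its width and height on every side (30% per dimension in total) before cropping, then clamp to the image. A self-contained check of that arithmetic on a made-up box:

```python
import numpy as np

v = [100, 50, 300, 450]                     # (lx, ly, rx, ry) from human_annotations
h, w = 480, 360                             # image height and width
corners = np.array(v).reshape(-1, 2)
offset = (corners[1] - corners[0]) * 0.15   # (30., 60.): 15% of width and height
coords = (v + np.concatenate((-offset, offset))).astype("int")
coords[coords < 0] = 0                      # clamp the top-left corner
coords[2], coords[3] = min(coords[2], w), min(coords[3], h)  # clamp bottom-right
print(coords)                               # [ 70   0 330 480]
```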
--------------------------------------------------------------------------------
/data/generate_256img.py:
--------------------------------------------------------------------------------
# coding:utf8
from torch.utils.data import Dataset
import os
import pickle
import numpy as np
import cv2
from config import opt
from .augmentation import GenerateHeatMap, Resize, SubtractMeans
from tqdm import tqdm
import pandas as pd
import random


def gen_intermediate_file(img_dir, phase='train', transform=None):
    # Standalone helper mirroring hgDataset.__getitem__: yields every cached
    # (image, heatmap label, annotation) triple. The original body looped over
    # range(2), referenced an undefined `generatehm`, and returned after the
    # first sample; iterating over all annotations is assumed to be the intent.
    _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(phase)
    pro_annos = pickle.load(open(_pkl_file, 'rb'))
    generatehm = GenerateHeatMap()
    for anno in tqdm(pro_annos):
        image = cv2.imread(img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0]
        if transform is not None:
            img, kps = transform(img, kps)
        img, label = generatehm(img, kps, H=64, W=64, sigma=3)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        yield img, label, anno


class hgDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, phase='train'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.phase = phase
        self.pro_annos = []
        self.gen_intermediate_file()
        self.generatehm = GenerateHeatMap()

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0]), ncols=20):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': np.array(v),
                                           'height_width': (h, w),
                                           'keypoints': np.array(anno.keypoint_annotations[i][k])})
            del anno
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        # print(self.img_dir + anno['image_id'] + '.jpg')
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0]
        if self.transform is not None:
            img, kps = self.transform(img, kps)
        img, label = self.generatehm(img, kps, H=64, W=64, sigma=3)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, label, anno

    def __len__(self):
        return len(self.pro_annos)
class hgValDataset(Dataset):
    def __init__(self, annotations_file, img_dir, num=None, phase='val'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        assert isinstance(num, int) or num is None  # sample size; num must be an int (or None for all)
        self.num = num

        self.mean = opt.val_mean if phase == 'val' else opt.train_mean
        self.phase = phase

        self.pro_annos = []
        self.gen_intermediate_file()

        self.resize = Resize(mean=self.mean)
        self.submean = SubtractMeans(mean=self.mean)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
            if self.num is not None:
                self.pro_annos = random.sample(self.pro_annos, self.num)
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0]), ncols=20):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': np.array(v),
                                           'height_width': (h, w),
                                           'keypoints': np.array(anno.keypoint_annotations[i][k])})
            del anno
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        img, _ = self.resize(img, None)
        img, _ = self.submean(img, None)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, anno

    def __len__(self):
        return len(self.pro_annos)


class EvalDataset(Dataset):
    def __init__(self, annotations_file, test_img_dir, phase='val'):
        self.anno_file = annotations_file
        self.img_dir = test_img_dir
        self.phase = phase

        self.pro_annos = []
        self.gen_intermediate_file()
        if phase == 'val':
            self.mean = opt.val_mean
        elif phase == 'train':
            self.mean = opt.train_mean
        else:
            self.mean = opt.test_mean
        self.resize = Resize(mean=self.mean)
        self.submean = SubtractMeans(mean=self.mean)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + 'eval_{}30000_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
        else:
            with open(self.anno_file, 'rb') as f:
                anno = pickle.load(f)
            for i in tqdm(range(len(anno)), ncols=50):
                img_np = cv2.imread(self.img_dir + anno[i]['image_id'] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno[i]['human_annotations'].items():
                    coords = np.array(v).reshape(-1, 2)
                    offset = (coords[1] - coords[0]) * 0.15  # expand the box by 15% per side (30% per dimension)
                    coords = v + np.concatenate((-offset, offset))
                    coords = coords.astype("int")
                    coords[np.where(coords < 0)] = 0
                    if coords[2] > w:
                        coords[2] = w
                    if coords[3] > h:
                        coords[3] = h
                    self.pro_annos.append({'image_id': anno[i]['image_id'],
                                           'human': k,
                                           'coords': coords,
                                           'height_width': (h, w)})
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = None
        img, kps = self.resize(img, kps)
        img, kps = self.submean(img, kps)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, anno

    def __len__(self):
        return len(self.pro_annos)
--------------------------------------------------------------------------------
/data/generate_data.py:
--------------------------------------------------------------------------------
# coding:utf8
from torch.utils.data import Dataset
import pandas as pd
import os
import pickle
from skimage import io
from tqdm import tqdm
from . import image_handle
from config import opt
from utils import Helper
import numpy as np


class HumanPoseDetectionDataset(Dataset):
    """Human Pose Detection Dataset"""

    def __init__(self, annotations_file, img_dir, transform=None):
        """
        :param annotations_file: path to the annotation file
        :param img_dir: directory containing the images
        :param transform(callable, optional): data augmentation
        """
        self.processed_annotations = []
        self.annotations_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.img_list = Helper().img_list  # list of problematic images

        self.gen_intermediate_file()

    def gen_intermediate_file(self):
        if os.path.exists(opt.interim_data_path + 'train_processed_dataset.pkl'):
            self.processed_annotations = pickle.load(open(opt.interim_data_path + 'train_processed_dataset.pkl', 'rb'))
            # print(self.processed_annotations[:2])
        else:
            anno = pd.read_json(self.annotations_file + 'keypoint_train_annotations_20170909.json')
            for i in tqdm(range(anno.shape[0])):
                self.processed_annotations.extend(image_handle.annotations_handle(anno.image_id[i],
                                                                                  anno.human_annotations[i],
                                                                                  anno.keypoint_annotations[i]))
            # filter out problematic images; they may be handled manually later
            # print(self.processed_annotations[:2])
            self.processed_annotations = list(filter(lambda x: x['img_id'] not in self.img_list,
                                                     self.processed_annotations))
            pickle.dump(self.processed_annotations,
                        open(opt.interim_data_path + 'train_processed_dataset.pkl', 'wb'))
            del anno

    def __getitem__(self, idx):
        img_name = self.processed_annotations[idx]['img_id']
        isupright = self.processed_annotations[idx]['info'][0][2]
        offset = self.processed_annotations[idx]['info'][2]
        lx, ly = self.processed_annotations[idx]['info'][0][:2]
        rx, ry = self.processed_annotations[idx]['info'][1][:2]
        img = io.imread(self.img_dir + img_name + '.jpg')
        img = img[ly:ry, lx: rx, :]
        processed_img = image_handle.crop_and_scale(img, offset, isupright)
        processed_img = np.transpose(processed_img, (2, 0, 1))
        label = image_handle.generate_part_label(self.processed_annotations[idx]['info'][3:])

        if self.transform is not None:
            processed_img, label = self.transform(processed_img, label)
        return processed_img.astype(np.float), label.astype(np.float)

    def __len__(self):
        return len(self.processed_annotations)
class HumanPoseValDataset(Dataset):
    """Validation set for the detection network."""

    def __init__(self, annotations_file, img_dir):
        self.annotations_file = annotations_file
        self.img_dir = img_dir
        self.problem_imgs = []  # problematic images
        self.processed_annotations = []

        self.official_data = pd.read_json(self.annotations_file + 'keypoint_validation_annotations_20170911.json')
        self.process_problem_data()
        self.gen_intermediate_file()

    def __len__(self):
        return len(self.processed_annotations)

    def __getitem__(self, idx):
        img_name = self.processed_annotations[idx]['img_id']
        isupright = self.processed_annotations[idx]['info'][0][2]
        offset = self.processed_annotations[idx]['info'][2]
        lx, ly = self.processed_annotations[idx]['info'][0][:2]
        rx, ry = self.processed_annotations[idx]['info'][1][:2]
        img = io.imread(self.img_dir + img_name + '.jpg')
        img = img[ly:ry, lx: rx, :]
        processed_img = image_handle.crop_and_scale(img, offset, isupright)
        processed_img = np.transpose(processed_img, (2, 0, 1))
        label = image_handle.generate_part_label(self.processed_annotations[idx]['info'][3:])
        return processed_img.astype(np.float), label.astype(np.float)

    def process_problem_data(self):
        for i in range(self.official_data.shape[0]):
            for k, v in self.official_data.human_annotations[i].items():
                if v[0] >= v[2] or v[1] >= v[3]:
                    self.problem_imgs.append(self.official_data.image_id[i])

    def gen_intermediate_file(self):
        if os.path.exists(opt.interim_data_path + 'val_processed_dataset.pkl'):
            self.processed_annotations = pickle.load(open(opt.interim_data_path + 'val_processed_dataset.pkl', 'rb'))
        else:
            for i in tqdm(range(self.official_data.shape[0])):
                if self.official_data.image_id[i] in self.problem_imgs:
                    continue
                else:
                    self.processed_annotations.extend(
                        image_handle.val_annotations_handle(
                            self.official_data.image_id[i],
                            self.official_data.human_annotations[i],
                            self.official_data.keypoint_annotations[i]
                        )
                    )
            pickle.dump(self.processed_annotations,
                        open(opt.interim_data_path + 'val_processed_dataset.pkl', 'wb'))
        del self.official_data
# Datasets for the regression network ------------------------------------
class HumanPoseRegressionDataset(Dataset):
    """Dataset for the regression sub-network."""

    def __len__(self):
        return len(self.processed_annotations)

    def __init__(self, annotations_file, img_dir, transform=None):
        """
        :param annotations_file: path to the annotation file
        :param img_dir: directory containing the images
        :param transform(callable, optional): Optional transform to be applied on a sample.
        """
        self.annotations_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.img_list = Helper().img_list  # list of problematic images

        # self.processed_annotations = pickle.load(open(self.annotations_file + 'train_processed_dataset.pkl', 'rb'))
        self.processed_annotations = []
        self.gen_intermediate_file()

    def __getitem__(self, idx):
        img_name = self.processed_annotations[idx]['img_id']
        isupright = self.processed_annotations[idx]['info'][0][2]
        offset = self.processed_annotations[idx]['info'][2]
        lx, ly = self.processed_annotations[idx]['info'][0][:2]
        rx, ry = self.processed_annotations[idx]['info'][1][:2]
        img = io.imread(self.img_dir + img_name + '.jpg')
        img = img[ly:ry, lx: rx, :]
        processed_img = image_handle.crop_and_scale(img, offset, isupright)
        processed_img = np.transpose(processed_img, (2, 0, 1))

        label = image_handle.generate_regression_hm(self.processed_annotations[idx]['info'][3:])

        if self.transform is not None:
            processed_img, label = self.transform(processed_img, label)

        return processed_img.astype(np.float), label.astype(np.float)

    def gen_intermediate_file(self):
        if os.path.exists(opt.interim_data_path + 'train_processed_dataset.pkl'):
            self.processed_annotations = pickle.load(open(opt.interim_data_path + 'train_processed_dataset.pkl', 'rb'))
            # print(self.processed_annotations[:2])
        else:
            anno = pd.read_json(self.annotations_file + 'keypoint_train_annotations_20170909.json')
            for i in tqdm(range(anno.shape[0])):
                self.processed_annotations.extend(image_handle.annotations_handle(anno.image_id[i],
                                                                                  anno.human_annotations[i],
                                                                                  anno.keypoint_annotations[i]))
            # filter out problematic images; they may be handled manually later
            # print(self.processed_annotations[:2])
            self.processed_annotations = list(filter(lambda x: x['img_id'] not in self.img_list,
                                                     self.processed_annotations))
            pickle.dump(self.processed_annotations,
                        open(opt.interim_data_path + 'train_processed_dataset.pkl', 'wb'))
            del anno


class RegressionValDataset(Dataset):
    """Validation set for the regression network."""

    def __init__(self, annotations_file, img_dir):
        self.annotations_file = annotations_file
        self.img_dir = img_dir
        self.problem_imgs = []  # problematic images

        # self.processed_annotations = pickle.load(open(opt.interim_data_path + 'val_processed_dataset.pkl', 'rb'))
        self.processed_annotations = []

        self.official_data = pd.read_json(self.annotations_file + 'keypoint_validation_annotations_20170911.json')
        self.process_problem_data()
        self.gen_intermediate_file()

    def __len__(self):
        return len(self.processed_annotations)

    def __getitem__(self, idx):
        img_name = self.processed_annotations[idx]['img_id']
        isupright = self.processed_annotations[idx]['info'][0][2]
        offset = self.processed_annotations[idx]['info'][2]
        lx, ly = self.processed_annotations[idx]['info'][0][:2]
        rx, ry = self.processed_annotations[idx]['info'][1][:2]
        img = io.imread(self.img_dir + img_name + '.jpg')
        img = img[ly:ry, lx: rx, :]
        processed_img = image_handle.crop_and_scale(img, offset, isupright)
        processed_img = np.transpose(processed_img, (2, 0, 1))
        return processed_img.astype(np.float), self.processed_annotations[idx]

    def process_problem_data(self):
        for i in range(self.official_data.shape[0]):
            for k, v in self.official_data.human_annotations[i].items():
                if v[0] >= v[2] or v[1] >= v[3]:
                    self.problem_imgs.append(self.official_data.image_id[i])

    def gen_intermediate_file(self):
        if os.path.exists(opt.interim_data_path + 'val_processed_dataset.pkl'):
            self.processed_annotations = pickle.load(open(opt.interim_data_path + 'val_processed_dataset.pkl', 'rb'))
        else:
            for i in tqdm(range(self.official_data.shape[0])):
                if self.official_data.image_id[i] in self.problem_imgs:
                    continue
                else:
                    self.processed_annotations.extend(
                        image_handle.val_annotations_handle(
                            self.official_data.image_id[i],
                            self.official_data.human_annotations[i],
                            self.official_data.keypoint_annotations[i]
                        )
                    )
            pickle.dump(self.processed_annotations,
                        open(opt.interim_data_path + 'val_processed_dataset.pkl', 'wb'))
        del self.official_data
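Every dataset in this package uses the same lazy-cache idiom: the first construction parses the official JSON, preprocesses each (image, human) pair, and pickles the result; later constructions just load the pickle. A condensed sketch of the pattern, with a hypothetical helper name:

```python
import os
import pickle

def load_or_build(cache_path, build_fn):
    """Return cached records, building and pickling them on the first call.
    (Hypothetical helper condensing the gen_intermediate_file methods above.)"""
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    records = build_fn()                 # e.g. parse annotations and expand boxes
    with open(cache_path, 'wb') as f:
        pickle.dump(records, f)
    return records
```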
--------------------------------------------------------------------------------
/data/hg_data.py:
--------------------------------------------------------------------------------
# coding:utf8
from torch.utils.data import Dataset
import os
import pickle
import numpy as np
import cv2
from config import opt
from .augmentation import GenerateHeatMap, Resize, SubtractMeans
from tqdm import tqdm
import pandas as pd
import random


class hgDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, phase='train'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.phase = phase
        self.pro_annos = []
        self.gen_intermediate_file()
        self.generatehm = GenerateHeatMap()

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0]), ncols=20):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': np.array(v),
                                           'height_width': (h, w),
                                           'keypoints': np.array(anno.keypoint_annotations[i][k])})
            del anno
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        # print(self.img_dir + anno['image_id'] + '.jpg')
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0]
        if self.transform is not None:
            img, kps = self.transform(img, kps)
        img, label = self.generatehm(img, kps, H=64, W=64, sigma=3)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, label, anno

    def __len__(self):
        return len(self.pro_annos)


class hgValDataset(Dataset):
    def __init__(self, annotations_file, img_dir, num=None, phase='val'):
        self.anno_file = annotations_file
        self.img_dir = img_dir
        assert isinstance(num, int) or num is None  # sample size; num must be an int (or None for all)
        self.num = num

        self.mean = opt.val_mean if phase == 'val' else opt.train_mean
        self.phase = phase

        self.pro_annos = []
        self.gen_intermediate_file()

        self.resize = Resize(mean=self.mean)
        self.submean = SubtractMeans(mean=self.mean)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
            if self.num is not None:
                self.pro_annos = random.sample(self.pro_annos, self.num)
        else:
            anno = pd.read_json(self.anno_file)
            for i in tqdm(range(anno.shape[0]), ncols=20):
                img_np = cv2.imread(self.img_dir + anno.image_id[i] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno.human_annotations[i].items():
                    self.pro_annos.append({'image_id': anno.image_id[i],
                                           'human': k,
                                           'coords': np.array(v),
                                           'height_width': (h, w),
                                           'keypoints': np.array(anno.keypoint_annotations[i][k])})
            del anno
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        img, _ = self.resize(img, None)
        img, _ = self.submean(img, None)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, anno

    def __len__(self):
        return len(self.pro_annos)
class EvalDataset(Dataset):
    def __init__(self, annotations_file, test_img_dir, phase='val'):
        self.anno_file = annotations_file
        self.img_dir = test_img_dir
        self.phase = phase

        self.pro_annos = []
        self.gen_intermediate_file()
        if phase == 'val':
            self.mean = opt.val_mean
        elif phase == 'train':
            self.mean = opt.train_mean
        else:
            self.mean = opt.test_mean
        self.resize = Resize(mean=self.mean)
        self.submean = SubtractMeans(mean=self.mean)

    def gen_intermediate_file(self):
        _pkl_file = opt.interim_data_path + 'eval_{}30000_preprocessed.pkl'.format(self.phase)
        if os.path.exists(_pkl_file):
            self.pro_annos = pickle.load(open(_pkl_file, 'rb'))
        else:
            with open(self.anno_file, 'rb') as f:
                anno = pickle.load(f)
            for i in tqdm(range(len(anno)), ncols=50):
                img_np = cv2.imread(self.img_dir + anno[i]['image_id'] + '.jpg')
                h, w = np.shape(img_np)[:2]
                for k, v in anno[i]['human_annotations'].items():
                    coords = np.array(v).reshape(-1, 2)
                    offset = (coords[1] - coords[0]) * 0.15  # expand the box by 15% per side (30% per dimension)
                    coords = v + np.concatenate((-offset, offset))
                    coords = coords.astype("int")
                    coords[np.where(coords < 0)] = 0
                    if coords[2] > w:
                        coords[2] = w
                    if coords[3] > h:
                        coords[3] = h
                    self.pro_annos.append({'image_id': anno[i]['image_id'],
                                           'human': k,
                                           'coords': coords,
                                           'height_width': (h, w)})
            with open(_pkl_file, 'wb') as f:
                pickle.dump(self.pro_annos, f)

    def __getitem__(self, idx):
        anno = self.pro_annos[idx]
        image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg')
        lx, ly = anno['coords'][:2]
        rx, ry = anno['coords'][2:]
        img = image[ly:ry, lx:rx, :]
        kps = None
        img, kps = self.resize(img, kps)
        img, kps = self.submean(img, kps)
        img = img[:, :, (2, 1, 0)]
        img = np.transpose(img, (2, 0, 1))
        return img, anno

    def __len__(self):
        return len(self.pro_annos)

--------------------------------------------------------------------------------
/data/image_handle.py:
--------------------------------------------------------------------------------
# coding:utf8
from skimage import transform
import numpy as np


def trans_coordinate(hm_position, keypoints):
    """
    Transform coordinates from the original image to 256x256 resolution, including the annotated joints.
    :param hm_position: position of the human box, two diagonal corner points
    :param keypoints: original joint annotations
    :return: dict: 'scale' - the box span used as the scaling basis
             'info' - numpy array of shape (17, 3) combining box and joint information:
             ((top-left corner in the original image, upright flag), (bottom-right corner, 0),
             (bottom-right corner after shift and scale), then the 14 joints)
    """
    span_x = hm_position[2] - hm_position[0]
    span_y = hm_position[3] - hm_position[1]
    isupright = True if span_y - span_x >= 0 else False
    keypoints_res = [hm_position[2], hm_position[3], 0] + keypoints
    keypoints_res = np.reshape(keypoints_res, (15, 3))

    if isupright:
        scale_ratio = 256.0 / span_y
        keypoints_res = keypoints_res - [hm_position[0], hm_position[1], 0]
        keypoints_res = keypoints_res * [scale_ratio, scale_ratio, 1]
        keypoints_res = np.vstack(
            ([[hm_position[0], hm_position[1], 1], [hm_position[2], hm_position[3], 0]], keypoints_res))
        keypoints_res = keypoints_res.astype(np.int32)
        return {'scale': span_y, 'info': keypoints_res}
    else:
        scale_ratio = 256.0 / span_x
        keypoints_res = keypoints_res - [hm_position[0], hm_position[1], 0]
        keypoints_res = keypoints_res * [scale_ratio, scale_ratio, 1]
        keypoints_res = np.vstack(
            ([[hm_position[0], hm_position[1], 0], [hm_position[2], hm_position[3], 0]], keypoints_res))
        keypoints_res = keypoints_res.astype(np.int32)
        return {'scale': span_x, 'info': keypoints_res}
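# Worked example (made-up numbers): an upright box spanning 128 x 256 px is scaled
# so its taller side maps to 256, and two bookkeeping rows (the box corners, with
# the third field of row 0 flagging "upright") are stacked ahead of the 14 joints:
#   >>> res = trans_coordinate([100, 100, 228, 356], [150, 200, 1] * 14)
#   >>> res['scale']                 # the larger span, used as the scaling basis
#   256
#   >>> res['info'][0]               # top-left corner, upright flag
#   array([100, 100,   1], dtype=int32)
#   >>> res['info'][2]               # bottom-right corner after shift and scale
#   array([128, 256,   0], dtype=int32)
#   >>> res['info'][3]               # first joint: ((150-100)*1.0, (200-100)*1.0, 1)
#   array([ 50, 100,   1], dtype=int32)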
def annotations_handle(img_name, human_positions, keypoints):
    """
    Generate cropped and proportionally rescaled keypoints.
    :param img_name: image name
    :param human_positions: human box positions
    :param keypoints: original joint annotations
    :return: List [{'scale': scale_ratio, 'info': keypoints_res, 'img_id': img_name}]
    """
    processed_list = []
    for k, v in human_positions.items():
        # e.g. (k, v) = (u'human1', [185, 161, 418, 936])
        dict_temp = trans_coordinate(v, keypoints[k])
        dict_temp.update({'img_id': img_name})
        processed_list.append(dict_temp)
    # print(processed_list)
    return processed_list


def val_annotations_handle(img_name, human_positions, keypoints):
    """
    Validation set: generate cropped and proportionally rescaled keypoints.
    :param img_name: image name
    :param human_positions: human box positions
    :param keypoints: original joint annotations
    :return: List [{'scale': scale_ratio, 'info': keypoints_res, 'img_id': img_name, 'human': human}]
    """
    processed_list = []
    for k, v in human_positions.items():
        # e.g. (k, v) = (u'human1', [185, 161, 418, 936])
        dict_temp = trans_coordinate(v, keypoints[k])
        dict_temp.update({'img_id': img_name, 'human': k})
        processed_list.append(dict_temp)
    return processed_list


def crop_and_scale(img, offset, isupright):
    """
    Crop and proportionally rescale a person crop to produce the network input.
    :param img: original image, 3-D array
    :param offset: scaled extent of the person region; the remainder up to 256 is zero-padded (to the right or bottom)
    :param isupright: whether the box is upright (taller than wide)
    :return: cropped 256x256 image (one per person in the original image), shape (256, 256, 3)
    """
    if isupright:
        func_off = offset[0]
        # print (func_off)
        return np.concatenate(
            (transform.resize(img, (256, func_off), mode='reflect'), np.zeros((256, 256 - func_off, 3))), axis=1)
    else:
        func_off = offset[1]
        # print (func_off)
        return np.concatenate(
            (transform.resize(img, (func_off, 256), mode='reflect'), np.zeros((256 - func_off, 256, 3))), axis=0)


def generate_part_label(keypoints, height=256, width=256, radius=10):
    """
    Generate label data for the part detection subnet.
    :param keypoints: transformed annotation points
    :param height: output height
    :param width: output width
    :param radius: radius of the binary disc around each joint
    :return: labels
    """
    heatmap_res = None
    x = np.arange(0, width, dtype=np.uint32)
    y = np.arange(0, height, dtype=np.uint32)[:, np.newaxis]
    for kp in keypoints:
        if kp[2] == 1:
            if heatmap_res is None:
                heatmap_res = ((x - kp[0]) ** 2 + (y - kp[1]) ** 2) <= radius ** 2
            else:
                heatmap_res = np.vstack((heatmap_res, (((x - kp[0]) ** 2 + (y - kp[1]) ** 2) <= radius ** 2)))
        else:
            if heatmap_res is None:
                heatmap_res = np.zeros((height, width), dtype=np.uint8)
            else:
                heatmap_res = np.vstack((heatmap_res, np.zeros((height, width), dtype=np.uint8)))
    heatmap_res = heatmap_res.astype(np.uint8)

    return np.reshape(heatmap_res, (-1, height, width))
125 | :param height: 边长 126 | :param width: 边长 127 | :param sigma: 分布的幅度,标准差 128 | :param center: 高斯核中心 129 | :return: heatmap 带有高斯核 130 | """ 131 | x = np.arange(0, width, 1, float) 132 | y = np.arange(0, height, 1, float)[:, np.newaxis] 133 | if center is None: 134 | x0 = width // 2 135 | y0 = height // 2 136 | else: 137 | x0 = center[0] 138 | y0 = center[1] 139 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2) 140 | 141 | 142 | def generate_regression_hm(keypoint, height=256, width=256, sigma=5): 143 | """ 144 | 生成回归子网络的label 145 | :param keypoint: 关节点 146 | :param height: 边长 147 | :param width: 边长 148 | :param sigma: 分布的幅度,标准差 149 | :return: label 150 | """ 151 | hm = np.zeros((14, height, width), dtype=np.float32) 152 | for i, kp in enumerate(keypoint): 153 | if kp[2] == 1: 154 | hm[i] = makeGaussian(height, width, sigma=sigma, center=(kp[0], kp[1])) 155 | return hm 156 | -------------------------------------------------------------------------------- /data/regression_data.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | from torch.utils.data import Dataset 3 | import os 4 | import pickle 5 | import numpy as np 6 | import cv2 7 | from config import opt 8 | from .augmentation import GenerateHeatMap, Resize, SubtractMeans 9 | from tqdm import tqdm 10 | import pandas as pd 11 | from utils.helper import Helper 12 | 13 | 14 | class HPEPoseDataset(Dataset): 15 | def __init__(self, annotations_file, img_dir, transform=None, phase='train'): 16 | self.anno_file = annotations_file 17 | self.img_dir = img_dir 18 | self.transform = transform 19 | self.phase = phase 20 | 21 | self.pro_annos = [] 22 | self.gen_intermediate_file() 23 | 24 | self.resize = Resize() 25 | self.generatehm = GenerateHeatMap() 26 | 27 | def gen_intermediate_file(self): 28 | _pkl_file = opt.interim_data_path + '{}_preprocessed.pkl'.format(self.phase) 29 | if os.path.exists(_pkl_file): 30 | self.pro_annos = pickle.load(open(_pkl_file, 'rb')) 31 | # if self.phase is 'val': 32 | # self.pro_annos = self.pro_annos[:50] 33 | else: 34 | pass 35 | 36 | def __getitem__(self, idx): 37 | anno = self.pro_annos[idx] 38 | image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg') 39 | lx, ly = anno['coords'][:2] 40 | rx, ry = anno['coords'][2:] 41 | img = image[ly:ry, lx:rx, :] 42 | kps = np.array(anno['keypoints']).reshape(-1, 3) - [lx, ly, 0] 43 | if self.transform is not None: 44 | img, kps = self.transform(img, kps) 45 | # img, kps = self.kptransform(img, kps) 46 | # img, kps = self.resize(img, kps) 47 | img, label = self.generatehm(img, kps) 48 | img = img[:, :, (2, 1, 0)] 49 | img = np.transpose(img, (2, 0, 1)) 50 | return img, label 51 | 52 | def __len__(self): 53 | return len(self.pro_annos) 54 | 55 | def data_sample(self, n): 56 | pass 57 | 58 | 59 | class HPEPoseValDataset(Dataset): 60 | def __init__(self, annotations_file, img_dir): 61 | self.anno_file = annotations_file 62 | self.img_dir = img_dir 63 | self.helper = Helper() 64 | 65 | self.pro_annos = [] 66 | self.gen_intermediate_file() 67 | 68 | self.resize = Resize(mean=opt.val_mean) 69 | # self.resize = Resize() 70 | self.submean = SubtractMeans(mean=opt.val_mean) 71 | 72 | def gen_intermediate_file(self): 73 | _pkl_file = opt.interim_data_path + 'val_preprocessed.pkl' 74 | if os.path.exists(_pkl_file): 75 | self.pro_annos = pickle.load(open(_pkl_file, 'rb')) 76 | else: 77 | anno = pd.read_json(self.anno_file) 78 | for i in tqdm(range(anno.shape[0])): 79 | img_np = 
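# ----------------------------------------------------------------------
# NOTE: a quick sanity check (mine, not repo code) of makeGaussian() above:
# with the exp(-4*ln(2) * r^2 / sigma^2) form, `sigma` acts as the full
# width at half maximum, so the peak decays to exactly 0.5 at a distance of
# sigma / 2 from the center.
import numpy as np

sigma, r = 5.0, 2.5                      # r = sigma / 2
val = np.exp(-4 * np.log(2) * r ** 2 / sigma ** 2)
assert abs(val - 0.5) < 1e-12            # half maximum reached at sigma / 2
# ----------------------------------------------------------------------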
cv2.imread(self.img_dir + anno.image_id[i] + '.jpg') 80 | h, w = np.shape(img_np)[:2] 81 | for k, v in anno.human_annotations[i].items(): 82 | self.pro_annos.append({'image_id': anno.image_id[i], 83 | 'human': k, 84 | 'coords': np.array(v), 85 | 'height_width': (h, w), 86 | 'keypoints': anno.keypoint_annotations[i][k]}) 87 | self.pro_annos = list(filter(lambda x: x['image_id'] not in self.helper.img_list, self.pro_annos)) 88 | del anno 89 | with open(_pkl_file, 'wb') as f: 90 | pickle.dump(self.pro_annos, f) 91 | 92 | def __getitem__(self, idx): 93 | anno = self.pro_annos[idx] 94 | image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg') 95 | lx, ly = anno['coords'][:2] 96 | rx, ry = anno['coords'][2:] 97 | img = image[ly:ry, lx:rx, :] 98 | img, _ = self.resize(img, None) 99 | img, _ = self.submean(img, None) 100 | img = img[:, :, (2, 1, 0)] 101 | img = np.transpose(img, (2, 0, 1)) 102 | return img, anno 103 | 104 | def __len__(self): 105 | return len(self.pro_annos) 106 | 107 | 108 | class HPEPoseTestDataset(Dataset): 109 | def __init__(self, annotations_file, test_img_dir): 110 | self.anno_file = annotations_file 111 | self.img_dir = test_img_dir 112 | 113 | self.pro_annos = [] 114 | self.gen_intermediate_file() 115 | 116 | self.resize = Resize(mean=opt.test_mean) 117 | self.submean = SubtractMeans(mean=opt.test_mean) 118 | 119 | def gen_intermediate_file(self): 120 | _pkl_file = opt.interim_data_path + 'aic_testset_preprocessed.pkl' 121 | if os.path.exists(_pkl_file): 122 | self.pro_annos = pickle.load(open(_pkl_file, 'rb')) 123 | else: 124 | with open(self.anno_file, 'rb') as f: 125 | anno = pickle.load(f) 126 | for i in tqdm(range(len(anno)), ncols=50): 127 | img_np = cv2.imread(self.img_dir + anno[i]['image_id'] + '.jpg') 128 | h, w = np.shape(img_np)[:2] 129 | for k, v in anno[i]['human_annotations'].items(): 130 | coords = np.array(v).reshape(-1, 2) 131 | offset = (coords[1] - coords[0]) * 0.15 # 沿着长和宽扩大30% 132 | coords = v + np.concatenate((-offset, offset)) 133 | coords = coords.astype("int") 134 | coords[np.where(coords < 0)] = 0 135 | if coords[2] > w: 136 | coords[2] = w 137 | if coords[3] > h: 138 | coords[3] = h 139 | self.pro_annos.append({'image_id': anno[i]['image_id'], 140 | 'human': k, 141 | 'coords': coords, 142 | 'height_width': (h, w)}) 143 | with open(_pkl_file, 'wb') as f: 144 | pickle.dump(self.pro_annos, f) 145 | 146 | def __getitem__(self, idx): 147 | anno = self.pro_annos[idx] 148 | image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg') 149 | lx, ly = anno['coords'][:2] 150 | rx, ry = anno['coords'][2:] 151 | img = image[ly:ry, lx:rx, :] 152 | kps = None 153 | img, kps = self.resize(img, kps) 154 | img, kps = self.submean(img, kps) 155 | img = img[:, :, (2, 1, 0)] 156 | img = np.transpose(img, (2, 0, 1)) 157 | return img, anno 158 | 159 | def __len__(self): 160 | return len(self.pro_annos) 161 | -------------------------------------------------------------------------------- /data/test_data_noexpand.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | from torch.utils.data import Dataset 3 | import os 4 | import pickle 5 | import numpy as np 6 | import cv2 7 | from config import opt 8 | from .augmentation import Resize, SubtractMeans 9 | from tqdm import tqdm 10 | 11 | 12 | class HPEPoseTestDataset_NE(Dataset): 13 | def __init__(self, annotations_file, test_img_dir): 14 | self.anno_file = annotations_file 15 | self.img_dir = test_img_dir 16 | 17 | self.pro_annos = [] 18 | 
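# ----------------------------------------------------------------------
# NOTE: a condensed sketch (mine; a plain cv2.resize stands in for the
# repo's Resize/SubtractMeans transforms, which may letterbox instead) of
# the steps the val/test __getitem__ methods above apply to a person crop:
# resize to 256x256, subtract the per-channel BGR mean, reorder BGR -> RGB,
# and move channels first for PyTorch.
import numpy as np
import cv2

def preprocess(crop_bgr, mean, side=256):
    img = cv2.resize(crop_bgr, (side, side)).astype(np.float32)
    img -= np.array(mean, dtype=np.float32)   # e.g. opt.val_mean / opt.test_mean
    img = img[:, :, (2, 1, 0)]                # BGR -> RGB
    return np.transpose(img, (2, 0, 1))       # HWC -> CHW
# ----------------------------------------------------------------------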
self.gen_intermediate_file() 19 | 20 | self.resize = Resize(mean=opt.test_mean) 21 | self.submean = SubtractMeans(mean=opt.test_mean) 22 | 23 | def gen_intermediate_file(self): 24 | _pkl_file = opt.interim_data_path + 'test_preprocessed.pkl' 25 | if os.path.exists(_pkl_file): 26 | self.pro_annos = pickle.load(open(_pkl_file, 'rb')) 27 | else: 28 | with open(self.anno_file, 'rb') as f: 29 | anno = pickle.load(f) 30 | for i in tqdm(range(len(anno))): 31 | img_np = cv2.imread(self.img_dir + anno[i]['image_id'] + '.jpg') 32 | h, w = np.shape(img_np)[:2] 33 | for k, v in anno[i]['human_annotations'].items(): 34 | self.pro_annos.append({'image_id': anno[i]['image_id'], 35 | 'human': k, 36 | 'coords': v, 37 | 'height_width': (h, w)}) 38 | with open(_pkl_file, 'wb') as f: 39 | pickle.dump(self.pro_annos, f) 40 | 41 | def __getitem__(self, idx): 42 | anno = self.pro_annos[idx] 43 | image = cv2.imread(self.img_dir + anno['image_id'] + '.jpg') 44 | lx, ly = anno['coords'][:2] 45 | rx, ry = anno['coords'][2:] 46 | img = image[ly:ry, lx:rx, :] 47 | kps = None 48 | img, kps = self.resize(img, kps) 49 | img, kps = self.submean(img, kps) 50 | img = img[:, :, (2, 1, 0)] 51 | img = np.transpose(img, (2, 0, 1)) 52 | return img, anno 53 | 54 | def __len__(self): 55 | return len(self.pro_annos) 56 | -------------------------------------------------------------------------------- /eval_train_val.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | from config import opt 3 | import models 4 | from data import EvalDataset 5 | import torch 6 | from torch.utils import data 7 | from torch.autograd import Variable 8 | from utils.prediction_handle import get_pred_kps, val_input_convert 9 | from utils import eval_score 10 | from tqdm import tqdm 11 | import json 12 | 13 | 14 | def eval_train_val(): 15 | opt.model_id = 4 16 | opt.val_bs = 8 17 | model = getattr(models, opt.model[opt.model_id])(num_stacks=4) 18 | torch.cuda.set_device(1) 19 | model = model.cuda(1) 20 | with open('checkpoints/AIC-HGNet_progress.json', 'r') as f: 21 | progress = json.load(f) 22 | model.load_state_dict(torch.load(progress['best_path'])) 23 | # best_path = 'checkpoints/AIC-HGNet_0.567476429117.model' 24 | # model.load_state_dict(torch.load(best_path)) 25 | 26 | val_anno_path = 'official/keypoint_validation_annotations_newclear.json' 27 | annotations = eval_score.load_annotations(val_anno_path) 28 | val_anno_file = '/media/bnrc2/_backup/ai/ai_challenger_keypoint_test_a_20170923/val10000_anno-newclear_thr4.5.pkl' 29 | # val_anno_file = '/home/bnrc2/mu/mxnet-ssd/22338-test-b-data.json' 30 | # val_anno_file = '/media/bnrc2/_backup/ai/mu/abiao_liang/res24_anno.pkl' 31 | print(val_anno_file) 32 | dataset = EvalDataset(val_anno_file, opt.val_img_dir) 33 | # dataset = EvalDataset(val_anno_file, '/media/bnrc2/_backup/ai/mu/abiao_liang/', 'test') 34 | 35 | dataloader = data.DataLoader(dataset, batch_size=opt.val_bs, num_workers=opt.num_workers) 36 | # switch to inference mode (fixes BatchNorm statistics) before predicting 37 | model.eval() 38 | print("processing data begin...") 39 | pred_list = [] 40 | for processed_img, processed_info in tqdm(dataloader, ncols=50): 41 | processed_img = processed_img.float() 42 | processed_img = Variable(processed_img.cuda()) 43 | pred_list += get_pred_kps(processed_info, model(processed_img)[-1].cpu()) 44 | print("processing data end...") 45 | 46 | predictions = val_input_convert(pred_list) 47 | with open('./res12-03_keypoints.json', 'w') as f: 48 | json.dump(predictions, f) 49 | model.train() 50 | mAP_value =
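# ----------------------------------------------------------------------
# NOTE: a minimal restatement (not repo code) of how the evaluation above
# turns per-instance OKS scores into mAP: precision at the ten OKS
# thresholds 0.50, 0.55, ..., 0.95, averaged over thresholds.
import numpy as np

def oks_to_map(oks_all, num_instances):
    thresholds = np.linspace(0.5, 0.95, 10)
    aps = [np.sum(oks_all > t) / np.float32(num_instances) for t in thresholds]
    return np.mean(aps)

# oks_to_map(np.array([0.9, 0.6, 0.3]), 3) -> ~0.333
# ----------------------------------------------------------------------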
eval_score.keypoint_eval(predictions, annotations) 51 | print('mAP_value:', mAP_value) 52 | 53 | 54 | if __name__ == '__main__': 55 | eval_train_val() 56 | -------------------------------------------------------------------------------- /jupyter_file/.ipynb_checkpoints/鏁版嵁澧炲己娴嬭瘯-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /jupyter_file/.ipynb_checkpoints/鏁版嵁澧炲己娴嬭瘯2-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /main_hg.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | from os.path import join 3 | from config import opt 4 | import models 5 | from data import HPEBaseTransform 6 | from data import hgDataset, hgValDataset 7 | import torch 8 | import fire 9 | from torch import nn 10 | from torch.utils import data 11 | from torch.autograd import Variable 12 | from tensorboardX import SummaryWriter 13 | import os 14 | import numpy as np 15 | from tqdm import tqdm 16 | from utils.ian_eval import compute_batch_oks, compute_batch_oks_gpu 17 | from utils.logger import Logger 18 | import json 19 | import time 20 | 21 | 22 | def main(**kwargs): 23 | writer = SummaryWriter(opt.logs_path, comment=opt.env) # tensorboard 24 | opt.title = 'AIC-HGNet' 25 | 26 | use_gpu = torch.cuda.is_available() 27 | opt.model_id = 4 28 | opt.parse(kwargs) 29 | model = getattr(models, opt.model[opt.model_id])(num_stacks=4) 30 | if use_gpu: 31 | torch.cuda.set_device(opt.device) 32 | model = model.cuda() 33 | if opt.resume: 34 | logger = Logger(join(opt.logs_path, '{}_log.txt'.format(opt.title)), title=opt.title, resume=True) # 生成日志文件 35 | with open(opt.checkpoints + '{}_progress.json'.format(opt.title), 'r') as f: 36 | opt.progress = json.load(f) 37 | model.load(opt.progress['best_path'], opt.device) 38 | opt.lr = opt.progress['lr'] 39 | else: 40 | with open(opt.checkpoints + '{}_progress.json'.format(opt.title), 'w') as f: 41 | json.dump(opt.progress, f) 42 | opt.progress['lr'] = opt.lr 43 | logger = Logger(join(opt.logs_path, '{}_log.txt'.format(opt.title)), title=opt.title) 44 | logger.set_names(['Epoch', '--Time', '--TrainLoss', '--TrainmAP', '--ValmAP']) 45 | 46 | # 是否使用少量数据跑 Demo 47 | if opt.demo: 48 | opt.img_dir = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/demo_data/train_images/' 49 | opt.annotations_file = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/demo_data/' \ 50 | 'keypoint_train_annotations_20170909.json' 51 | opt.val_img_dir = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/demo_data/validation_images/' 52 | opt.val_anno_file = '/home/bnrc2/ai_challenge/ian/hg.aic.pytorch/demo_data/' \ 53 | 'keypoint_validation_annotations_20170911.json' 54 | opt.config_info_print() 55 | # 数据集 56 | trainset = hgDataset(opt.annotations_file, opt.img_dir, HPEBaseTransform(opt.train_mean, hm_side=64.0)) 57 | valset = hgValDataset(opt.val_anno_file, opt.val_img_dir) 58 | valloader = data.DataLoader(valset, batch_size=opt.val_bs, num_workers=opt.num_workers) 59 | 60 | optimizer = model.get_optimizer() 61 | criterion = nn.MSELoss() 62 | 63 | for epoch in range(opt.progress['epoch'], opt.epoch): 64 | if epoch in [5, 10, 
15, 20]: 65 | opt.lr *= 0.5 66 | print(opt.lr) 67 | optimizer = model.get_optimizer(opt.lr) # rebuild the optimizer so the new learning rate actually takes effect 68 | if epoch in range(5, 200, 5): 69 | model_path = 'checkpoints/{}_epoch_{}'.format(opt.title, epoch) + '.model' 70 | torch.save(model.state_dict(), model_path) 71 | opt.progress['epoch'] = epoch 72 | opt.progress['count'] = train(model, optimizer, criterion, trainset, valloader, logger, writer) 73 | 74 | 75 | def train(model, optimizer, criterion, trainset, valloader, logger, writer): 76 | model.train() 77 | trainloader = data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=4) 78 | stride = int(len(trainset) / opt.batch_size) 79 | epoch = opt.progress['epoch'] 80 | count = opt.progress['count'] + 1 81 | start_count = count % stride 82 | print('bs:', opt.batch_size) 83 | for i, (img, label, anno) in tqdm(enumerate(trainloader, start_count), ncols=30): 84 | img, label = Variable(img.float().cuda()), Variable(label.float().cuda()) 85 | optimizer.zero_grad() 86 | output = model(img) 87 | # score_map = output[-1].cpu() 88 | score_map = output[-1] 89 | loss = criterion(output[0], label) 90 | for j in range(1, len(output)): 91 | loss += criterion(output[j], label) 92 | loss.backward() # backward pass 93 | optimizer.step() # parameter update 94 | if count % opt.plot_every == 0: 95 | oks_all = compute_batch_oks_gpu(score_map, anno) 96 | average_precision = [] 97 | for threshold in np.linspace(0.5, 0.95, 10): 98 | average_precision.append(np.sum(oks_all > threshold) / np.float32(opt.batch_size)) 99 | trainmAP = np.mean(average_precision) 100 | logger.append([epoch, time.strftime('%m/%d-%H:%M:%S'), (loss.data[0] * 10000), trainmAP, '-']) 101 | writer.add_scalar('hg_loss', loss.data[0], count) 102 | writer.add_scalar('train_mAP', trainmAP, count) 103 | if count % opt.check_every == 0 and count != start_count: 104 | valmAP = val(model, valloader) 105 | logger.append([epoch, time.strftime('%m/%d-%H:%M:%S'), (loss.data[0] * 10000), '-', valmAP]) 106 | writer.add_scalar('mAP_value', valmAP, count) 107 | if opt.progress['best_mAP'] is None or valmAP > opt.progress['best_mAP']: 108 | opt.progress['best_mAP'] = valmAP 109 | opt.progress['count'] = count 110 | opt.progress['lr'] = opt.lr 111 | if opt.progress['best_path'] != '': # compare with '!=', not 'is not': string identity is unreliable 112 | os.system('rm {}'.format(opt.progress['best_path'])) # delete the previous, lower-mAP checkpoint 113 | best_path = model.save(opt.title + '_' + str(valmAP)) 114 | opt.progress['best_path'] = best_path 115 | with open(opt.checkpoints + '{}_progress.json'.format(opt.title), 'w') as f: 116 | json.dump(opt.progress, f) 117 | # tra_loader = data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=4) 118 | count += 1 119 | if i >= stride: 120 | break 121 | return count 122 | 123 | 124 | def val(model, loader): 125 | """Validate the regression (hourglass) subnet and return mAP.""" 126 | model.eval() 127 | set_len = len(loader.dataset) 128 | oks_all = np.zeros(0) 129 | for img, anno in loader: 130 | img = img.float().cuda() 131 | img = Variable(img) 132 | oks_all = np.concatenate((oks_all, compute_batch_oks_gpu(model(img)[-1], anno)), axis=0) 133 | average_precision = [] 134 | for threshold in np.linspace(0.5, 0.95, 10): 135 | average_precision.append(np.sum(oks_all > threshold) / np.float32(set_len)) 136 | model.train() 137 | return np.mean(average_precision) 138 | 139 | 140 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, decay_rate=0.95, lr_decay_epoch=2): 141 | """Decay the learning rate by `decay_rate` every `lr_decay_epoch` epochs.""" 142 | 143 | lr = init_lr * (decay_rate ** (epoch // lr_decay_epoch)) 144 | 145 | if epoch % lr_decay_epoch == 0: 146 |
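# ----------------------------------------------------------------------
# NOTE: the loss in train() above implements stacked-hourglass intermediate
# supervision: every stack's heatmap output is compared against the same
# target and the per-stack MSE losses are summed. A compact sketch of just
# that idea (modern-PyTorch style; the repo itself uses the old Variable API):
import torch
from torch import nn

criterion = nn.MSELoss()
outputs = [torch.rand(2, 14, 64, 64) for _ in range(4)]  # one map per stack
target = torch.rand(2, 14, 64, 64)
loss = sum(criterion(o, target) for o in outputs)        # supervise all stacks
# ----------------------------------------------------------------------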
print('LR is set to {}'.format(lr)) 147 | 148 | for param_group in optimizer.param_groups: 149 | param_group['lr'] = lr 150 | 151 | return optimizer 152 | 153 | 154 | def detection_loss_func(detection_result, label): 155 | s = torch.sum( 156 | torch.mul(label, torch.log(detection_result)) + torch.mul((1 - label), 157 | torch.log((1 - detection_result)))) 158 | return torch.div(-s, 14.0 * opt.batch_size) 159 | 160 | 161 | def l2_loss_func(regressiong_result, label): 162 | s = torch.sum(torch.pow((regressiong_result - label), 2)) 163 | return torch.div(s, 14.0 * opt.batch_size) 164 | 165 | 166 | if __name__ == "__main__": 167 | fire.Fire() 168 | -------------------------------------------------------------------------------- /models/BasicModule.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch 3 | import time 4 | 5 | 6 | class BasicModule(torch.nn.Module): 7 | def __init__(self): 8 | super(BasicModule, self).__init__() 9 | 10 | def load(self, path, device=0): 11 | data = torch.load(path, map_location=lambda storage, loc: storage.cuda(device)) 12 | return self.load_state_dict(data) 13 | 14 | def save(self, name=None): 15 | prefix = 'checkpoints/' 16 | if name is None: 17 | name = time.strftime('%m%d_%H:%M:%S') 18 | path = prefix + name + '.model' 19 | data = self.state_dict() 20 | torch.save(data, path) 21 | return path 22 | 23 | def get_optimizer(self, lr=2.5e-4, weight_decay=0, momentum=0): 24 | optimizer = torch.optim.RMSprop(self.parameters(), 25 | lr=lr, 26 | momentum=momentum, 27 | weight_decay=weight_decay) 28 | return optimizer 29 | -------------------------------------------------------------------------------- /models/Conv_part_hm_reg_model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | from config import opt 7 | from .new_ResNet import resnet152 8 | from .hourglass import Hourglass, Bottleneck 9 | from collections import OrderedDict 10 | 11 | __all__ = ['Part_detection_subnet_model', 'Regression_subnet'] 12 | 13 | 14 | class Part_detection_subnet_model(nn.Module): 15 | def __init__(self, opt=None): 16 | super(Part_detection_subnet_model, self).__init__() 17 | self.part_resnet = resnet152(pretrained=True) 18 | 19 | for param in self.part_resnet.parameters(): 20 | param.requires_grad = False 21 | self.conv_b6 = nn.Conv2d(2048, 16, kernel_size=1, stride=1) 22 | self.bn_b6 = nn.BatchNorm2d(16) 23 | self.relu_b6 = nn.ReLU(inplace=True) 24 | self.deconv = nn.ConvTranspose2d(16, 14, kernel_size=4, stride=4) 25 | self.upsample = nn.Upsample(size=(256, 256)) 26 | 27 | def forward(self, x): 28 | x = self.part_resnet(x) 29 | x = self.conv_b6(x) 30 | x = self.deconv(x) 31 | x = self.upsample(x) 32 | x = F.sigmoid(x) 33 | 34 | return x 35 | 36 | 37 | class Regression_subnet(nn.Module): 38 | """Regression subnet""" 39 | 40 | def __init__(self, opt=None, block=Bottleneck, num_blocks=3, depth=4, bias=True): 41 | super(Regression_subnet, self).__init__() 42 | self.inplanes = 64 43 | self.num_feats = 128 44 | self.detection_subnet = load_detection_subnet() 45 | 46 | self.conv1 = nn.Conv2d(17, 64, kernel_size=7, stride=2, padding=3, bias=bias) # out:[batch_size,64,128,128] 47 | self.bn1 = nn.BatchNorm2d(self.inplanes) 48 | self.relu1 = nn.ReLU(inplace=True) 49 | # 使用一个简单的网络结构,降采样->上采样 50 | # self.down1 = nn.Sequential( 51 | # nn.Conv2d(64, 128, kernel_size=7, 
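# ----------------------------------------------------------------------
# NOTE: a shape walk-through (my own check, not repo code) of the detection
# head in Part_detection_subnet_model above: the modified ResNet trunk keeps
# stride 16, so a 256x256 input yields 2048x16x16 features; the 1x1 conv,
# stride-4 deconv, and final upsample then produce one 256x256 heatmap per
# joint (14 in total).
import torch
from torch import nn

feat = torch.rand(1, 2048, 16, 16)          # stand-in for the ResNet output
head = nn.Sequential(
    nn.Conv2d(2048, 16, kernel_size=1),
    nn.ConvTranspose2d(16, 14, kernel_size=4, stride=4),   # 16x16 -> 64x64
    nn.Upsample(size=(256, 256)),
)
assert head(feat).shape == (1, 14, 256, 256)
# ----------------------------------------------------------------------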
stride=2, padding=3, bias=bias), # out:[bs, 128, 64, 64] 52 | # nn.BatchNorm2d(128), 53 | # nn.ReLU(inplace=True) 54 | # ) 55 | # self.down2 = nn.Sequential( 56 | # nn.Conv2d(128, 256, kernel_size=1, stride=1, bias=bias), # out:[bs, 256, 64, 64] 57 | # nn.BatchNorm2d(256), 58 | # nn.ReLU(inplace=True) 59 | # ) 60 | # self.up1 = nn.Sequential( 61 | # nn.Upsample(scale_factor=2), 62 | # nn.BatchNorm2d(256), 63 | # nn.ReLU(inplace=True) 64 | # ) 65 | # self.up2 = nn.Sequential( 66 | # nn.Upsample(scale_factor=2), 67 | # nn.BatchNorm2d(256), 68 | # nn.ReLU(inplace=True) 69 | # ) 70 | # self.conv2 = nn.Sequential( 71 | # nn.Conv2d(256, 128, kernel_size=1, stride=1, bias=bias), 72 | # nn.BatchNorm2d(128), 73 | # nn.ReLU(inplace=True) 74 | # ) 75 | # self.conv3 = nn.Conv2d(128, 14, kernel_size=1, stride=1, bias=bias) 76 | 77 | ### ----------------------------------- 78 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) # out:[bs,64,64,64] 79 | downsample = nn.Sequential( 80 | nn.Conv2d(self.inplanes, self.num_feats * 2, 81 | kernel_size=1, bias=False), 82 | nn.BatchNorm2d(self.num_feats * 2), 83 | ) 84 | self.d2 = nn.Sequential( 85 | block(self.inplanes, self.num_feats, downsample=downsample), 86 | block(self.num_feats * 2, self.num_feats), 87 | block(self.num_feats * 2, self.num_feats) 88 | ) # out:[bs,256,64,64] 89 | 90 | self.hg = Hourglass(block, num_blocks, self.num_feats, depth) 91 | 92 | self.d61 = nn.Conv2d(256, 512, kernel_size=1, stride=1) 93 | self.bn_d61 = nn.BatchNorm2d(512) 94 | self.relu_d61 = nn.ReLU(inplace=True) 95 | 96 | # self.d62 = nn.Conv2d(512, 512, kernel_size=1, stride=1) 97 | # self.bn_d62 = nn.BatchNorm2d(512) 98 | # self.relu_d62 = nn.ReLU(inplace=True) 99 | 100 | self.up1 = nn.Sequential( 101 | nn.Upsample(scale_factor=2), 102 | nn.BatchNorm2d(512), 103 | nn.ReLU(inplace=True) 104 | ) 105 | self.up2 = nn.Sequential( 106 | nn.Upsample(scale_factor=2), 107 | nn.BatchNorm2d(512), 108 | nn.ReLU(inplace=True) 109 | ) 110 | self.d7 = nn.Conv2d(512, 14, kernel_size=1, stride=1) # out:[bs,14,256,256] 111 | 112 | # self.d5 = nn.ConvTranspose2d(14, 14, kernel_size=4, stride=4) 113 | 114 | ### ----------------------------------- 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d) and m.weight.requires_grad: 117 | init.xavier_uniform(m.weight.data) 118 | elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad: 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def forward(self, x): 123 | detec = self.detection_subnet(x) 124 | # print("Detection size:", detec.size()) 125 | # a = detec.cpu().data.numpy()[0][0] 126 | # import numpy as np 127 | # print("Detec Data:", np.max(a)) 128 | # print("Input size:", x.size()) 129 | # print("Input Data:", x.cpu().data.numpy()[0][0][0][:10]) 130 | x = torch.cat((detec, x), 1) 131 | # print("Cat size:", x.size()) 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu1(x) 135 | x = self.maxpool1(x) 136 | 137 | x = self.d2(x) 138 | x = self.hg(x) 139 | 140 | x = self.d61(x) 141 | x = self.bn_d61(x) 142 | x = self.relu_d61(x) 143 | x = self.up1(x) 144 | x = self.up2(x) 145 | 146 | # x = self.d62(x) 147 | # x = self.bn_d62(x) 148 | # x = self.relu_d62(x) 149 | 150 | x = self.d7(x) 151 | # -------- 152 | # x = self.down1(x) 153 | # x = self.down2(x) 154 | # x = self.up1(x) 155 | # x = self.up2(x) 156 | # x = self.conv2(x) 157 | # x = self.conv3(x) 158 | 159 | return F.sigmoid(x) 160 | 161 | 162 | def load_detection_subnet(): 163 | model = Part_detection_subnet_model() 164 | state_dict = 
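# ----------------------------------------------------------------------
# NOTE: the freeze-and-reuse pattern used by load_detection_subnet() here,
# reduced to its core (a sketch; `net` is any nn.Module). Disabling
# requires_grad stops gradient flow, but BatchNorm running statistics still
# update in train mode, so calling eval() on the frozen subnet is usually
# wanted as well; the repo relies on requires_grad alone.
from torch import nn

def freeze(net: nn.Module) -> nn.Module:
    for p in net.parameters():
        p.requires_grad = False   # exclude from optimization
    net.eval()                    # also freeze BN running stats / dropout
    return net
# ----------------------------------------------------------------------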
torch.load(opt.checkpoints + opt.model[0] + '.pkl', 165 | map_location=lambda storage, loc: storage.cuda(0)) 166 | # new_state_dict = OrderedDict() 167 | # for k, v in state_dict.items(): 168 | # name = k[7:] # remove `module.` 169 | # new_state_dict[name] = v 170 | model.load_state_dict(state_dict) 171 | for param in model.parameters(): 172 | param.requires_grad = False 173 | 174 | return model 175 | 176 | 177 | if __name__ == '__main__': 178 | # reg_model = Regression_subnet() 179 | # x = torch.autograd.Variable(torch.arange(0, 17 * 256 * 256).view(1, 17, 256, 256)).float() 180 | # o = reg_model(x) 181 | # print (o.size()) 182 | model = load_detection_subnet() 183 | print(model) 184 | -------------------------------------------------------------------------------- /models/HPE_det_reg_model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | from config import opt 7 | from .new_ResNet import resnet101 8 | from .hourglass import Hourglass, Bottleneck 9 | 10 | __all__ = ['Part_detection_subnet101', 'Regression_subnet101'] 11 | 12 | 13 | class Part_detection_subnet101(nn.Module): 14 | def __init__(self): 15 | super(Part_detection_subnet101, self).__init__() 16 | self.part_resnet = resnet101(pretrained=True) 17 | self.add_extras = Add_extras() 18 | # for param in self.part_resnet.parameters(): 19 | # param.requires_grad = False 20 | 21 | def forward(self, x): 22 | x = self.part_resnet(x) # [16, 2048, 16, 16] 23 | return self.add_extras(x) 24 | 25 | 26 | class Add_extras(nn.Module): 27 | def __init__(self): 28 | super(Add_extras, self).__init__() 29 | self.b6 = nn.Sequential( 30 | nn.Conv2d(2048, 256, kernel_size=1, stride=1), 31 | nn.BatchNorm2d(256), 32 | nn.ReLU(inplace=True) 33 | ) 34 | self.b7 = nn.Sequential( 35 | # nn.Upsample(size=(128, 128), mode='bilinear'), 36 | nn.ConvTranspose2d(256, 14, kernel_size=4, stride=4), 37 | nn.Upsample(size=(256, 256)) 38 | ) 39 | 40 | def forward(self, x): 41 | x = self.b6(x) 42 | x = self.b7(x) 43 | return F.sigmoid(x) 44 | 45 | 46 | class Regression_subnet101(nn.Module): 47 | """Regression subnet""" 48 | 49 | def __init__(self, block=Bottleneck, num_blocks=3, depth=4, bias=True): 50 | super(Regression_subnet101, self).__init__() 51 | self.inplanes = 64 52 | self.num_feats = 128 53 | self.detection_subnet = load_detection_subnet() 54 | self.d1 = nn.Sequential( 55 | nn.Conv2d(17, 64, kernel_size=7, stride=2, padding=3, bias=bias), # out:[batch_size,64,128,128] 56 | nn.BatchNorm2d(self.inplanes), 57 | nn.ReLU(inplace=True) 58 | ) 59 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) # out:[bs,64,64,64] 60 | downsample = nn.Sequential( 61 | nn.Conv2d(self.inplanes, self.num_feats * 2, kernel_size=1, bias=False), 62 | nn.BatchNorm2d(self.num_feats * 2), 63 | ) 64 | self.d2 = nn.Sequential( 65 | block(self.inplanes, self.num_feats, downsample=downsample), 66 | block(self.num_feats * 2, self.num_feats), 67 | block(self.num_feats * 2, self.num_feats) 68 | ) # out:[bs,256,64,64] 69 | self.hg = Hourglass(block, num_blocks, self.num_feats, depth) 70 | 71 | self.d6 = nn.Sequential( 72 | nn.Conv2d(256, 512, kernel_size=1, stride=1), 73 | nn.BatchNorm2d(512), 74 | nn.ReLU(inplace=True) 75 | ) 76 | self.up1 = nn.Sequential( 77 | nn.Upsample(scale_factor=2, mode='bilinear'), 78 | nn.BatchNorm2d(512), 79 | nn.ReLU(inplace=True) 80 | ) 81 | self.up2 = nn.Sequential( 82 | nn.Upsample(scale_factor=2, 
mode='bilinear'), 83 | nn.BatchNorm2d(512), 84 | nn.ReLU(inplace=True) 85 | ) 86 | self.d7 = nn.Conv2d(512, 14, kernel_size=1, stride=1) # out:[bs,14,256,256] 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d) and m.weight.requires_grad: 90 | init.xavier_uniform(m.weight.data) 91 | elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad: 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | 95 | def forward(self, x): 96 | detec = self.detection_subnet(x) 97 | x = torch.cat((detec, x), 1) 98 | x = self.d1(x) 99 | x = self.maxpool1(x) 100 | x = self.d2(x) 101 | x = self.hg(x) 102 | x = self.d6(x) 103 | x = self.up1(x) 104 | x = self.up2(x) 105 | x = self.d7(x) 106 | return F.sigmoid(x) 107 | 108 | 109 | def load_detection_subnet(): 110 | model = Part_detection_subnet101() 111 | state_dict = torch.load(opt.checkpoints + opt.model[2] + '.model', 112 | map_location=lambda storage, loc: storage.cuda(1)) 113 | model.load_state_dict(state_dict) 114 | for param in model.parameters(): 115 | param.requires_grad = False 116 | return model 117 | 118 | 119 | if __name__ == '__main__': 120 | # reg_model = Regression_subnet() 121 | # x = torch.autograd.Variable(torch.arange(0, 17 * 256 * 256).view(1, 17, 256, 256)).float() 122 | # o = reg_model(x) 123 | # print (o.size()) 124 | model = load_detection_subnet() 125 | print(model) 126 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .Conv_part_hm_reg_model import Part_detection_subnet_model, Regression_subnet 2 | from .new_ResNet import resnet152 3 | from .HPE_det_reg_model import Part_detection_subnet101, Regression_subnet101 4 | from .hourglass import HourglassNet 5 | -------------------------------------------------------------------------------- /models/hourglass.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .BasicModule import BasicModule 5 | 6 | __all__ = ['Hourglass', 'Bottleneck', 'HourglassNet'] 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 2 11 | 12 | def __init__(self, inplanes, planes, stride=1, downsample=None): # (256, 128) 13 | super(Bottleneck, self).__init__() 14 | 15 | self.bn1 = nn.BatchNorm2d(inplanes) 16 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 19 | padding=1, bias=True) 20 | self.bn3 = nn.BatchNorm2d(planes) 21 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.bn1(x) 30 | out = self.relu(out) 31 | out = self.conv1(out) 32 | 33 | out = self.bn2(out) 34 | out = self.relu(out) 35 | out = self.conv2(out) 36 | 37 | out = self.bn3(out) 38 | out = self.relu(out) 39 | out = self.conv3(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | 46 | return out 47 | 48 | 49 | class Hourglass(nn.Module): 50 | def __init__(self, block, num_blocks, planes, depth): # (_, 4, 128, 4) 51 | super(Hourglass, self).__init__() 52 | self.depth = depth # 4 53 | self.block = block 54 | self.upsample = nn.Upsample(scale_factor=2) 55 | self.hg = self._make_hour_glass(block, 
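# ----------------------------------------------------------------------
# NOTE: a quick shape check (mine) of the pre-activation Bottleneck above:
# with expansion = 2, a block maps its input to planes * 2 channels, so at
# planes=128 the hourglass works on 256-channel, 64x64 feature maps
# throughout. Assumes the repo root is on PYTHONPATH.
import torch
from models.hourglass import Bottleneck

block = Bottleneck(256, 128)            # in: 256 ch, out: 128 * 2 = 256 ch
x = torch.rand(2, 256, 64, 64)
assert block(x).shape == (2, 256, 64, 64)
# ----------------------------------------------------------------------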
num_blocks, planes, depth) 56 | 57 | def _make_residual(self, block, num_blocks, planes): 58 | layers = [] 59 | for i in range(0, num_blocks): 60 | layers.append(block(planes * block.expansion, planes)) 61 | return nn.Sequential(*layers) 62 | 63 | def _make_hour_glass(self, block, num_blocks, planes, depth): 64 | hg = [] 65 | for i in range(depth): 66 | res = [] 67 | for j in range(3): 68 | res.append(self._make_residual(block, num_blocks, planes)) 69 | if i == 0: 70 | res.append(self._make_residual(block, num_blocks, planes)) 71 | hg.append(nn.ModuleList(res)) 72 | return nn.ModuleList(hg) 73 | 74 | def _hour_glass_forward(self, n, x): 75 | up1 = self.hg[n - 1][0](x) 76 | low1 = F.max_pool2d(x, 2, stride=2) 77 | low1 = self.hg[n - 1][1](low1) 78 | 79 | if n > 1: 80 | low2 = self._hour_glass_forward(n - 1, low1) 81 | else: 82 | low2 = self.hg[n - 1][3](low1) 83 | low3 = self.hg[n - 1][2](low2) 84 | up2 = self.upsample(low3) 85 | out = up1 + up2 86 | return out 87 | 88 | def forward(self, x): 89 | return self._hour_glass_forward(self.depth, x) 90 | 91 | 92 | class HourglassNet(BasicModule): 93 | '''Hourglass model from Newell et al ECCV 2016''' 94 | 95 | def __init__(self, block=Bottleneck, num_stacks=2, num_blocks=4, num_classes=14): 96 | """ 97 | 参数解释 98 | :param block: hg块元素 99 | :param num_stacks: 有几个hg 100 | :param num_blocks: 在两个hg之间有几个block块 101 | :param num_classes: keypoint个数,也就是最后的heatmap个数 102 | """ 103 | super(HourglassNet, self).__init__() 104 | 105 | self.inplanes = 64 106 | self.num_feats = 128 107 | self.num_stacks = num_stacks 108 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 109 | bias=True) 110 | self.bn1 = nn.BatchNorm2d(self.inplanes) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.layer1 = self._make_residual(block, self.inplanes, 1) # inplanes = 128 downsample 113 | self.layer2 = self._make_residual(block, self.inplanes, 1) # inplanes = 256 downsample 114 | self.layer3 = self._make_residual(block, self.num_feats, 1) # inplanes = 256 115 | self.maxpool = nn.MaxPool2d(2, stride=2) 116 | 117 | # build hourglass modules 118 | ch = self.num_feats * block.expansion # 256 119 | hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] 120 | for i in range(num_stacks): 121 | hg.append(Hourglass(block, num_blocks, self.num_feats, 4)) 122 | res.append(self._make_residual(block, self.num_feats, num_blocks)) # inplanes = 256, 123 | fc.append(self._make_fc(ch, ch)) 124 | score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True)) 125 | if i < num_stacks - 1: 126 | fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True)) 127 | score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True)) 128 | self.hg = nn.ModuleList(hg) 129 | self.res = nn.ModuleList(res) 130 | self.fc = nn.ModuleList(fc) 131 | self.score = nn.ModuleList(score) 132 | self.fc_ = nn.ModuleList(fc_) 133 | self.score_ = nn.ModuleList(score_) 134 | 135 | def _make_residual(self, block, planes, blocks, stride=1): # [64,1] 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | nn.Conv2d(self.inplanes, planes * block.expansion, 140 | kernel_size=1, stride=stride, bias=True), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def _make_fc(self, inplanes, outplanes): 152 | 
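# ----------------------------------------------------------------------
# NOTE: _hour_glass_forward() above, restated as a sketch (mine) with the
# residual blocks replaced by identity, to expose the recursive
# pool -> recurse -> upsample -> add skeleton of a single hourglass;
# F.interpolate stands in for the repo's nn.Upsample.
import torch
import torch.nn.functional as F

def hourglass_skeleton(x, depth):
    up1 = x                                    # skip branch at this scale
    low = F.max_pool2d(x, 2, stride=2)         # descend one scale
    low = hourglass_skeleton(low, depth - 1) if depth > 1 else low
    return up1 + F.interpolate(low, scale_factor=2)  # ascend and fuse

# hourglass_skeleton(torch.rand(1, 256, 64, 64), depth=4).shape -> (1, 256, 64, 64)
# ----------------------------------------------------------------------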
bn = nn.BatchNorm2d(inplanes) 153 | conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True) 154 | return nn.Sequential( 155 | conv, 156 | bn, 157 | self.relu, 158 | ) 159 | 160 | def forward(self, x): # torch.Size([8, 3, 256, 256]) 161 | out = [] 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | 166 | x = self.layer1(x) 167 | x = self.maxpool(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) # torch.Size([8, 256, 64, 64]) 170 | for i in range(self.num_stacks): 171 | y = self.hg[i](x) 172 | y = self.res[i](y) 173 | y = self.fc[i](y) 174 | score = self.score[i](y) # [8, 14, 64, 64] 175 | # print("score.size:", score.size(), i) 176 | out.append(score) 177 | if i < self.num_stacks - 1: 178 | fc_ = self.fc_[i](y) 179 | score_ = self.score_[i](score) 180 | x = x + fc_ + score_ 181 | # else: 182 | # print(score.size()) 183 | 184 | return out # (bs, 256, 64, 64) 185 | 186 | 187 | def hg(**kwargs): 188 | model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'], 189 | num_classes=kwargs['num_classes']) 190 | return model 191 | -------------------------------------------------------------------------------- /models/new_ResNet.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet152', 'resnet18', 'resnet101'] 7 | 8 | model_urls = { 9 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class Bottleneck(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, downsample=None): 57 | super(Bottleneck, self).__init__() 58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 61 | padding=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes * 4) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = 
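# ----------------------------------------------------------------------
# NOTE: the inter-stack wiring of HourglassNet.forward() above, isolated as
# a sketch (mine): each stack emits a heatmap `score`; all but the last also
# project features (fc_) and heatmaps (score_) back to 256 channels and add
# both to the running feature map that feeds the next stack.
import torch
from torch import nn

fc_ = nn.Conv2d(256, 256, kernel_size=1)
score_head = nn.Conv2d(256, 14, kernel_size=1)   # `score` layer in the repo
score_back = nn.Conv2d(14, 256, kernel_size=1)   # `score_` layer in the repo
x = torch.rand(1, 256, 64, 64)
y = x                                  # stands in for hg -> res -> fc output
score = score_head(y)                  # this stack's prediction
x = x + fc_(y) + score_back(score)     # input to the next stack
assert x.shape == (1, 256, 64, 64)
# ----------------------------------------------------------------------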
self.bn2(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv3(out) 81 | out = self.bn3(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class ResNet(nn.Module): 93 | def __init__(self, block, layers, num_classes=1000): 94 | self.inplanes = 64 95 | super(ResNet, self).__init__() 96 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 97 | bias=False) # 128,128,64 98 | self.bn1 = nn.BatchNorm2d(64) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 64,64,64 101 | self.layer1 = self._make_layer(block, 64, layers[0]) 102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) # 原resnet中stride是2 105 | self.avgpool = nn.AvgPool2d(7) 106 | self.fc = nn.Linear(512 * block.expansion, num_classes) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | m.weight.data.normal_(0, math.sqrt(2. / n)) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | nn.BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | # x = self.avgpool(x) 144 | # x = x.view(x.size(0), -1) 145 | # x = self.fc(x) 146 | 147 | return x 148 | 149 | 150 | def resnet152(pretrained=False, **kwargs): 151 | """Constructs a ResNet-152 model. 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | """ 155 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 156 | if pretrained: 157 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 158 | 159 | return model 160 | 161 | 162 | def resnet101(pretrained=False, **kwargs): 163 | """Constructs a ResNet-101 model. 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 168 | print("model loading...") 169 | if pretrained: 170 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 171 | print("model load end...") 172 | return model 173 | 174 | 175 | def resnet18(pretrained=False, **kwargs): 176 | """Constructs a ResNet-18 model. 
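# ----------------------------------------------------------------------
# NOTE: a sanity check (mine) of the ResNet modification above: layer4 uses
# stride=1 where stock torchvision uses 2, so the trunk has overall stride
# 16 and forward() returns the 2048-channel feature map instead of pooled
# logits. Assumes the repo root is on PYTHONPATH; pretrained weights are
# not needed for a shape check.
import torch
from models.new_ResNet import resnet101

net = resnet101(pretrained=False)
assert net(torch.rand(1, 3, 256, 256)).shape == (1, 2048, 16, 16)
# ----------------------------------------------------------------------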
177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 181 | if pretrained: 182 | # old_resnet = models.resnet18(pretrained=True) 183 | # old_resnet = nn.Sequential(*list(old_resnet.children())[:-2]) 184 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 185 | 186 | return model 187 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .helper import Helper 2 | from .logger import * 3 | from .prediction_handle import * 4 | from .visualize import Visualizer 5 | from .ian_eval import compute_batch_oks, compute_batch_oks_gpu 6 | # from .net_validation import part1_val, part2_val 7 | -------------------------------------------------------------------------------- /utils/eval_score.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import numpy as np 4 | 5 | 6 | def load_annotations(anno_file): 7 | """Convert annotation JSON file.""" 8 | 9 | annotations = dict() 10 | annotations['image_ids'] = set([]) 11 | annotations['annos'] = dict() 12 | annotations['delta'] = 2 * np.array([0.01388152, 0.01515228, 0.01057665, 0.01417709, \ 13 | 0.01497891, 0.01402144, 0.03909642, 0.03686941, 0.01981803, \ 14 | 0.03843971, 0.03412318, 0.02415081, 0.01291456, 0.01236173]) 15 | try: 16 | # print (anno_file) 17 | annos = json.load(open(anno_file, 'r')) 18 | except Exception: 19 | print('Annotation file does not exist or is an invalid JSON file.') 20 | 21 | for anno in annos: 22 | annotations['image_ids'].add(anno['image_id']) 23 | annotations['annos'][anno['image_id']] = dict() 24 | annotations['annos'][anno['image_id']]['human_annos'] = anno['human_annotations'] 25 | annotations['annos'][anno['image_id']]['keypoint_annos'] = anno['keypoint_annotations'] 26 | 27 | return annotations 28 | 29 | 30 | def old_compute_oks(anno, predict, delta): 31 | """Compute oks matrix (size gtN*pN).""" 32 | 33 | anno_count = len(anno['keypoint_annos'].keys()) 34 | predict_count = len(predict.keys()) 35 | oks = np.zeros((anno_count, predict_count)) 36 | # for every human keypoint annotation 37 | for i in range(anno_count): 38 | anno_key = list(anno['keypoint_annos'].keys())[i] 39 | anno_keypoints = np.reshape(anno['keypoint_annos'][anno_key], (14, 3)) 40 | visible = anno_keypoints[:, 2] == 1 41 | bbox = anno['human_annos'][anno_key] 42 | scale = np.float32((bbox[3] - bbox[1]) * (bbox[2] - bbox[0])) 43 | if np.sum(visible) == 0: 44 | for j in range(predict_count): 45 | oks[i, j] = 0 46 | else: 47 | # for every predicted human 48 | for j in range(predict_count): 49 | predict_key = list(predict.keys())[j] 50 | predict_keypoints = np.reshape(predict[predict_key], (14, 3)) 51 | dis = np.sum((anno_keypoints[visible, :2] \ 52 | - predict_keypoints[visible, :2]) ** 2, axis=1) 53 | oks[i, j] = np.mean(np.exp(-dis / 2 / delta[visible] ** 2 / scale)) 54 | return oks 55 | 56 | 57 | def compute_oks(anno, predict, delta): 58 | """Compute oks matrix (size gtN*pN).""" 59 | 60 | anno_count = len(anno['keypoint_annos'].keys()) 61 | predict_count = len(predict.keys()) 62 | oks = np.zeros((anno_count, predict_count)) 63 | if predict_count == 0: 64 | return oks.T 65 | 66 | # for every human keypoint annotation 67 | for i in range(anno_count): 68 | anno_key = list(anno['keypoint_annos'].keys())[i] 69 | anno_keypoints = 
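# ----------------------------------------------------------------------
# NOTE: a worked single-joint OKS value (numbers invented) matching the
# formula used here in compute_oks, the variant with the (scale + 1)
# regularizer (old_compute_oks omits the + 1): exp(-d^2 / (2 * delta^2 *
# (s + 1))), where d is the pixel error, delta the per-joint constant from
# config.py, and s the ground-truth box area.
import numpy as np

delta = 2 * 0.01388152        # right-shoulder constant from opt.delta
d2 = 100.0                    # squared pixel distance, prediction vs GT
s = 40000.0                   # person box area in pixels
oks = np.exp(-d2 / 2 / delta ** 2 / (s + 1))
# -> exp(-100 / (2 * 0.00077 * 40001)) ~ 0.20
# ----------------------------------------------------------------------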
np.reshape(anno['keypoint_annos'][anno_key], (14, 3)) 70 | visible = anno_keypoints[:, 2] == 1 71 | bbox = anno['human_annos'][anno_key] 72 | scale = np.float32((bbox[3] - bbox[1]) * (bbox[2] - bbox[0])) # 框的面积 73 | if np.sum(visible) == 0: 74 | for j in range(predict_count): 75 | oks[i, j] = 0 76 | else: 77 | # for every predicted human 78 | for j in range(predict_count): 79 | predict_key = list(predict.keys())[j] 80 | predict_keypoints = np.reshape(predict[predict_key], (14, 3)) 81 | dis = np.sum((anno_keypoints[visible, :2] - predict_keypoints[visible, :2]) ** 2, axis=1) 82 | oks[i, j] = np.mean(np.exp(-dis / 2 / delta[visible] ** 2 / (scale + 1))) 83 | return oks 84 | 85 | 86 | def keypoint_eval(predictions, annotations): 87 | """Evaluate predicted_file and return mAP.""" 88 | 89 | oks_all = np.zeros((0)) 90 | oks_num = 0 91 | 92 | prediction_id_set = set(predictions['image_ids']) 93 | # for every annotation in our test/validation set 94 | for image_id in annotations['image_ids']: 95 | # if the image in the predictions, then compute oks 96 | # print(image_id) 97 | if image_id in prediction_id_set: 98 | oks = compute_oks(anno=annotations['annos'][image_id], \ 99 | predict=predictions['annos'][image_id]['keypoint_annos'], \ 100 | delta=annotations['delta']) 101 | # view pairs with max OKSs as match ones, add to oks_all 102 | oks_all = np.concatenate((oks_all, np.max(oks, axis=1)), axis=0) 103 | # accumulate total num by max(gtN,pN) 104 | oks_num += np.max(oks.shape) 105 | else: 106 | # otherwise report warning 107 | # number of humen in ground truth annotations 108 | gt_n = len(annotations['annos'][image_id]['human_annos'].keys()) 109 | # fill 0 in oks scores 110 | oks_all = np.concatenate((oks_all, np.zeros((gt_n))), axis=0) 111 | # accumulate total num by ground truth number 112 | oks_num += gt_n 113 | 114 | # compute mAP by APs under different oks thresholds 115 | average_precision = [] 116 | for threshold in np.linspace(0.5, 0.95, 10): 117 | average_precision.append(np.sum(oks_all > threshold) / np.float32(oks_num)) 118 | 119 | return np.mean(average_precision) 120 | 121 | 122 | def val_input_convert(anno): 123 | """ 格式转换,转换成eval脚本能使用的prediction 124 | :return: 例如:{"image_ids": ['img1','img2',...], "annos": {'img1':{"human3": [254, 203, 1, ...],"huma2": ...}}} 125 | """ 126 | predictions = dict() 127 | predictions['image_ids'] = [] 128 | predictions['annos'] = dict() 129 | for pred in anno: 130 | predictions['image_ids'].append(pred[0]) 131 | predictions['annos'][pred[0]] = dict() 132 | predictions['annos'][pred[0]]['keypoint_annos'] = pred[1] 133 | return predictions 134 | 135 | 136 | if __name__ == '__main__': 137 | path = '/home/bnrc2/ai_challenge/ian/Pytorch_Human_Pose_Estimation/interim_data/val_dataset' \ 138 | '/keypoint_validation_annotations_20170911.json' 139 | anno = load_annotations(path) 140 | # print keypoint_eval(anno, anno) 141 | print(anno) 142 | -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import pickle 3 | from config import opt 4 | 5 | 6 | class Helper(object): 7 | """ 8 | 实现一个helper类,用于对数据的预处理和清洗 9 | """ 10 | 11 | def __init__(self): 12 | self.img_list = [u'97fb2cb75320cb0094681715dbb5aa7a13b27fb7', 13 | u'bc4f5b8d1daef85d542a4dee4843c045d9511e9f', 14 | u'9b4694e434c41e7aab3b921b7397293d68311a61', 15 | u'ff585c6fbdb0672e0173afa08e689a98af22aa82', 16 | u'c298cd4a867943e79e03ba92b52204e19c0cbe16', 17 | 
u'46d05fc2173ae77970ba1eb33107092021753654', 18 | u'9b1d18cae697da8160beed40834ee3b1bfb7386e', 19 | u'd39b35193ef4d7f587ffe0cb4de9d2203b59fa63', 20 | u'aa8e4ce1f69018eaeebac4fa8714a3c00ee85cba', 21 | u'fd0a492dd5aa8d9bd73758f323bbbd1d88e2b03b', 22 | u'b125b9cda788d1a02a11131f7aa1b0f835e13cbf', 23 | u'e622c27c0760ed757e7f60b0fac37595ec538506', 24 | u'f94c9f8d14f1c432e38b282012a570e9504c1239', 25 | u'ba8ca016b0000b99f44176cb5c2636a951796621', 26 | u'3f8957d7948790c29886f29e27e3809d8acd3ccc', 27 | u'6805fee7416a6bc5003d291dc06956d5a5b06dc3', 28 | u'daae63a17f06df617d7681d78968dc856685c1d4', 29 | u'aee27632b9990f08cc39da6c8ce595544de96d16', 30 | u'f3f1402e49251ddbc079ae208fa80ae6036eda94', 31 | u'3aef0e2d3a64f2d45b2f0b2b4d38d202aff098cf', 32 | u'89269460a718cd1902c525084d0ba2424ad1a348'] 33 | 34 | def del_problem_imgs(self): 35 | # rebuild the list instead of popping while iterating (which skips entries) 36 | with open(opt.annotations_file + 'processed_dataset.pkl', 'rb') as f: 37 | new_anno = pickle.load(f) 38 | new_anno = [na for na in new_anno if na[0] not in self.img_list] 39 | with open(opt.annotations_file + 'processed_dataset.pkl', 'wb') as f: 40 | pickle.dump(new_anno, f) 41 | -------------------------------------------------------------------------------- /utils/ian_eval.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | 3 | import numpy as np 4 | from config import opt 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | 9 | def compute_batch_oks(preds, annos): 10 | oks_all = np.zeros(0) 11 | # preds: [bs, 14, 64, 64] 12 | preds = F.upsample(preds, scale_factor=4, mode='bilinear').data.numpy() 13 | for i, pred in enumerate(preds): 14 | coord = annos['coords'][i].numpy() 15 | kps = annos['keypoints'][i].numpy() 16 | span_x = coord[2] - coord[0] 17 | span_y = coord[3] - coord[1] 18 | isupright, scale = (True, span_y) if span_y >= span_x else (False, span_x) 19 | _kps = get_keypoint_coordinate(isupright, pred, opt.threshold) 20 | pred_kps = convert_coordinate(_kps, coord, scale) 21 | oks_all = np.append(oks_all, compute_oks(pred_kps, kps, opt.delta, span_x * span_y)) 22 | return oks_all 23 | 24 | 25 | def compute_oks(pred_kps, kps, delta, scale): 26 | """Compute OKS for a single person instance (14 joints).""" 27 | anno_keypoints = np.reshape(kps, (14, 3)) 28 | visible = anno_keypoints[:, 2] == 1 29 | predict_keypoints = np.reshape(pred_kps, (14, -1)) 30 | dis = np.sum((anno_keypoints[visible, :2] - predict_keypoints[visible, :2]) ** 2, axis=1) 31 | oks = np.mean(np.exp(-dis / 2 / delta[visible] ** 2 / (scale + 1))) 32 | return oks 33 | 34 | 35 | def get_keypoint_coordinate(isupright, pred, threshold=0.0): 36 | kps = [] 37 | if isupright: # note: both branches are currently identical 38 | for p in pred: 39 | if np.max(p) > threshold: 40 | x = np.argmax(p) % 256 41 | y = np.argmax(p) // 256 42 | kps.append([x, y, 1]) 43 | else: 44 | kps.append([0, 0, 0]) 45 | else: 46 | for p in pred: 47 | if np.max(p) > threshold: 48 | x = np.argmax(p) % 256 49 | y = np.argmax(p) // 256 50 | kps.append([x, y, 1]) 51 | else: 52 | kps.append([0, 0, 0]) 53 | return kps 54 | 55 | 56 | def convert_coordinate(keypoints, human_position, scale): 57 | kps = np.reshape(keypoints, (-1, 3)) 58 | kps = (kps * [scale / 256.0, scale / 256.0, 1]).astype(np.int16) 59 | kps = kps + [human_position[0], human_position[1], 0] 60 | return kps.reshape(-1).tolist() 61 | 62 | 63 | def compute_batch_oks_gpu(preds, annos): 64 | preds = F.upsample(preds, scale_factor=4, mode='bilinear') 65 | maxval, idx = torch.max(preds.view(preds.size(0), preds.size(1), -1), 2) 66 | y = idx // 256 67 | x
= idx % 256 68 | new_kps = torch.stack([x, y], 2) # [bs, 14, 2] 69 | coord = annos['coords'].cuda() # bsx4 70 | kps = annos['keypoints'].view(-1, 14, 3) # 4x14x3 71 | span_x = coord[:, 2] - coord[:, 0] 72 | span_y = coord[:, 3] - coord[:, 1] 73 | s = span_x * span_y 74 | scale = torch.max(span_y, span_x).float() / 256.0 # bs*4 75 | scale = scale.view(-1, 1).expand(preds.size(0), 28).contiguous().view(-1, 14, 2) # 4x14x2 (GPU 1) 76 | new_kps = new_kps.data.float() * scale.cuda() # bsx14x2 (GPU 1) 77 | oks_all = np.zeros(0) 78 | for i in range(preds.size(0)): 79 | new_kps[i] = new_kps[i] + coord[i][:2].repeat(14, 1).float() 80 | oks_all = np.append(oks_all, compute_oks(new_kps[i].cpu().numpy(), kps[i].numpy(), opt.delta, s[i])) 81 | return oks_all 82 | 83 | -------------------------------------------------------------------------------- /utils/ian_eval_tensor.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import numpy as np 3 | from config import opt 4 | import torch.nn.functional as F 5 | 6 | 7 | def compute_batch_oks(preds, annos): 8 | oks_all = np.zeros(0) 9 | # 8, 14, 64, 64] 10 | preds = F.upsample(preds, scale_factor=4, mode='bilinear').data.numpy() 11 | for i, pred in enumerate(preds): 12 | coord = annos['coords'][i].numpy() 13 | kps = annos['keypoints'][i].numpy() 14 | span_x = coord[2] - coord[0] 15 | span_y = coord[3] - coord[1] 16 | isupright, scale = (True, span_y) if span_y >= span_x else (False, span_x) 17 | _kps = get_keypoint_coordinate(isupright, pred, opt.threshold) 18 | pred_kps = convert_coordinate(_kps, coord, scale) 19 | oks_all = np.append(oks_all, compute_oks(pred_kps, kps, opt.delta, span_x * span_y)) 20 | return oks_all 21 | 22 | 23 | def compute_oks(pred_kps, kps, delta, scale): 24 | """Compute oks matrix (size gtN*pN).""" 25 | anno_keypoints = np.reshape(kps, (14, 3)) 26 | visible = anno_keypoints[:, 2] == 1 27 | predict_keypoints = np.reshape(pred_kps, (14, 3)) 28 | dis = np.sum((anno_keypoints[visible, :2] - predict_keypoints[visible, :2]) ** 2, axis=1) 29 | oks = np.mean(np.exp(-dis / 2 / delta[visible] ** 2 / (scale + 1))) 30 | return oks 31 | 32 | 33 | def get_keypoint_coordinate(isupright, pred, threshold=0.0): 34 | kps = [] 35 | if isupright: 36 | for p in pred: 37 | if np.max(p) > threshold: 38 | x = np.argmax(p) % 256 39 | y = np.argmax(p) / 256 40 | kps.append([x, y, 1]) 41 | else: 42 | kps.append([0, 0, 0]) 43 | else: 44 | for p in pred: 45 | if np.max(p) > threshold: 46 | x = np.argmax(p) % 256 47 | y = np.argmax(p) / 256 48 | kps.append([x, y, 1]) 49 | else: 50 | kps.append([0, 0, 0]) 51 | return kps 52 | 53 | 54 | def convert_coordinate(keypoints, human_position, scale): 55 | kps = np.reshape(keypoints, (-1, 3)) 56 | kps = (kps * [scale / 256.0, scale / 256.0, 1]).astype(np.int16) 57 | kps = kps + [human_position[0], human_position[1], 0] 58 | return kps.reshape(-1).tolist() 59 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 11 | 12 | 13 | def savefig(fname, dpi=None): 14 | dpi = 150 if dpi is None else dpi 15 | plt.savefig(fname, dpi=dpi) 16 | 17 | 18 | def plot_overlap(logger, names=None): 19 | names = 
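# ----------------------------------------------------------------------
# NOTE: the peak decoding used by get_keypoint_coordinate() and
# compute_batch_oks_gpu() above, as a tiny check (mine): on a row-major
# 256x256 heatmap the flat argmax index decodes to x = idx % 256 and
# y = idx // 256 (integer division; in Python 3 a bare `/` would yield a
# float row index).
import numpy as np

hm = np.zeros((256, 256), dtype=np.float32)
hm[40, 200] = 1.0                      # peak at row y=40, column x=200
idx = int(np.argmax(hm))
x, y = idx % 256, idx // 256
assert (x, y) == (200, 40)
# ----------------------------------------------------------------------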
logger.names if names is None else names 20 | numbers = logger.numbers 21 | for _, name in enumerate(names): 22 | x = np.arange(len(numbers[name])) 23 | plt.plot(x, np.asarray(numbers[name])) 24 | return [logger.title + '(' + name + ')' for name in names] 25 | 26 | 27 | class Logger(object): 28 | '''Save training process to log file with simple plot function.''' 29 | 30 | def __init__(self, fpath, title=None, resume=False): 31 | self.file = None 32 | self.resume = resume 33 | self.title = '' if title is None else title 34 | if fpath is not None: 35 | if resume: 36 | with open(fpath, 'r') as f: 37 | log_info = f.readlines() 38 | self.names = log_info[0].rstrip().split('\t') 39 | self.numbers = {} 40 | for _, name in enumerate(self.names): 41 | self.numbers[name] = [] 42 | for numbers in log_info[1:]: 43 | numbers = numbers.rstrip().split('\t') 44 | for i in range(0, len(numbers)): 45 | self.numbers[self.names[i]].append(numbers[i]) 46 | self.file = open(fpath, 'a') 47 | else: 48 | self.file = open(fpath, 'w') 49 | 50 | def set_names(self, names): 51 | if self.resume: 52 | pass 53 | # initialize numbers as empty list 54 | self.numbers = {} 55 | self.names = names 56 | for _, name in enumerate(self.names): 57 | self.file.write(name) 58 | self.file.write('\t') 59 | self.numbers[name] = [] 60 | self.file.write('\n') 61 | self.file.flush() 62 | 63 | def append(self, numbers): 64 | assert len(self.names) == len(numbers), 'Numbers do not match names' 65 | for index, num in enumerate(numbers): 66 | self.file.write("{}".format(num)) 67 | self.file.write('\t') 68 | self.numbers[self.names[index]].append(num) 69 | self.file.write('\n') 70 | self.file.flush() 71 | 72 | def plot(self, names=None): 73 | names = self.names if names is None else names 74 | numbers = self.numbers 75 | for _, name in enumerate(names): 76 | x = np.arange(len(numbers[name])) 77 | plt.plot(x, np.asarray(numbers[name])) 78 | plt.legend([self.title + '(' + name + ')' for name in names]) 79 | plt.grid(True) 80 | 81 | def close(self): 82 | if self.file is not None: 83 | self.file.close() 84 | 85 | 86 | class LoggerMonitor(object): 87 | '''Load and visualize multiple logs.''' 88 | 89 | def __init__(self, paths): 90 | '''paths is a distionary with {name:filepath} pair''' 91 | self.loggers = [] 92 | for title, path in paths.items(): 93 | logger = Logger(path, title=title, resume=True) 94 | self.loggers.append(logger) 95 | 96 | def plot(self, names=None): 97 | plt.figure() 98 | plt.subplot(121) 99 | legend_text = [] 100 | for logger in self.loggers: 101 | legend_text += plot_overlap(logger, names) 102 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
/utils/net_validation.py:
--------------------------------------------------------------------------------
# coding:utf8
from data import HPEDetDataset_NE, HPEAugmentation, HPEPoseDataset, HPEPoseValDataset, HPEBaseTransform
import torch
from torch.utils import data
from torch.autograd import Variable
import numpy as np
from utils.prediction_handle import get_pred_kps, val_input_convert


def part1_val(model, writer, count, opt):
    """Validation for the part detection sub-network."""
    model.eval()
    if opt.demo:
        opt.val_anno_file = opt.root_dir + 'demo_data/keypoint_validation_annotations_20170911.json'
    dataset = HPEDetDataset_NE(opt.val_anno_file, opt.val_img_dir, HPEBaseTransform(opt.val_mean), phase='val')
    dataloader = data.DataLoader(dataset,
                                 batch_size=opt.batch_size,
                                 num_workers=opt.num_workers,
                                 pin_memory=opt.pin_memory,
                                 )
    loss_sum = 0
    for processed_img, label in dataloader:
        processed_img, label = Variable(processed_img.float().cuda()), Variable(label.float().cuda())
        detection_result = model(processed_img)
        val_loss = detection_loss_func(detection_result, label, opt)
        loss_sum += val_loss.data[0]
    model.train()
    return loss_sum / np.ceil(len(dataset) / opt.batch_size)


def part2_val(model, writer, count, opt):
    """Validation for the regression sub-network."""
    model.eval()
    if opt.demo:
        pass
    dataset = HPEPoseValDataset(opt.val_anno_file, opt.val_img_dir)
    dataloader = data.DataLoader(dataset, batch_size=opt.val_bs, num_workers=opt.num_workers)

    pred_list = []
    for processed_img, processed_info in dataloader:
        processed_img = processed_img.float()
        processed_img = Variable(processed_img.cuda())
        # get_pred_kps upsamples the raw heatmaps itself, so hand it the
        # CPU Variable rather than a numpy array
        pred_list += get_pred_kps(processed_info, model(processed_img).cpu())

    predictions = val_input_convert(pred_list)

    model.train()
    return predictions


def detection_loss_func(detection_result, label, opt):
    # binary cross-entropy summed over all pixels, normalized by the
    # 14 joint maps of every sample in the batch
    s = torch.sum(
        torch.mul(label, torch.log(detection_result)) + torch.mul((1 - label),
                                                                  torch.log((1 - detection_result))))
    return torch.div(-s, 14.0 * opt.batch_size)
--------------------------------------------------------------------------------
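detection_loss_func above is a hand-rolled binary cross-entropy. A quick equivalence check against PyTorch's built-in, as a sketch (shapes assumed; reduction='sum' needs a reasonably recent PyTorch):

import torch
import torch.nn.functional as F

pred = torch.rand(8, 14, 64, 64).clamp(1e-6, 1 - 1e-6)  # stand-in for sigmoid outputs
label = (torch.rand(8, 14, 64, 64) > 0.5).float()       # stand-in for binary joint masks

s = torch.sum(label * torch.log(pred) + (1 - label) * torch.log(1 - pred))
manual = -s / (14.0 * 8)  # detection_loss_func with batch_size = 8

builtin = F.binary_cross_entropy(pred, label, reduction='sum') / (14.0 * 8)
assert torch.allclose(manual, builtin)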
/utils/prediction_handle.py:
--------------------------------------------------------------------------------
# coding:utf8
import numpy as np
from config import opt
import torch.nn.functional as F


def val_input_convert(predictions_in):
    """Convert raw predictions into the format the eval script consumes.
    :param predictions_in: list of (image_id, {human: keypoints}) pairs
    :return: e.g. {"image_ids": ['img1', 'img2', ...],
                   "annos": {'img1': {"keypoint_annos": {"human3": [254, 203, 1, ...], "human2": ...}}}}
    """
    predictions = dict()
    predictions['image_ids'] = []
    predictions['annos'] = dict()
    for pred in predictions_in:
        if pred[0] in predictions['image_ids']:
            predictions['annos'][pred[0]]['keypoint_annos'].update(pred[1])
        else:
            predictions['image_ids'].append(pred[0])
            predictions['annos'][pred[0]] = dict()
            predictions['annos'][pred[0]]['keypoint_annos'] = pred[1]
    return predictions


def get_prediction_keypoints(processed_info, preds):
    """Compute the keypoint coordinates predicted for one batch.
    :param processed_info: list of dicts {'scale': scale_ratio, 'info': keypoints_res, 'img_id': img_name, 'human': human}
    :param preds: predicted heatmaps, 14 per sample
    :return: items of ("image name", {"human3": [254, 203, 1, ...], "human2": ...})
    """
    predictions = dict()
    for i, pred in enumerate(preds):
        if processed_info['img_id'][i] not in predictions:
            predictions[processed_info['img_id'][i]] = {}
        # the two original branches here were identical, so they are merged
        kp_annos = get_keypoint_coordinate(processed_info['info'][i][0][2], pred, opt.threshold)
        new_kp_annos = convert_coordinate(kp_annos, processed_info['info'][i][0], processed_info['scale'][i])
        predictions[processed_info['img_id'][i]][processed_info['human'][i]] = new_kp_annos
    return predictions.items()


def get_pred_kps(processed_info, preds):
    """Compute the keypoint coordinates predicted for one batch.
    :param processed_info: batch dict with 'coords', 'image_id' and 'human' entries
    :param preds: predicted heatmaps, 14 per sample
    :return: list of ("image name", {"human3": [254, 203, 1, ...], "human2": ...}) pairs
    """
    predictions = dict()
    # [bs, 14, 64, 64] -> [bs, 14, 256, 256]
    preds = F.upsample(preds, scale_factor=4, mode='bilinear').data.numpy()
    for i, pred in enumerate(preds):
        coord = processed_info['coords'][i].numpy()
        span_x = coord[2] - coord[0]
        span_y = coord[3] - coord[1]
        isupright, scale = (True, span_y) if span_y >= span_x else (False, span_x)

        if processed_info['image_id'][i] not in predictions:
            predictions[processed_info['image_id'][i]] = {}
        kp_annos = get_keypoint_coordinate(isupright, pred, opt.threshold)
        new_kp_annos = convert_coordinate(kp_annos, coord, scale)
        predictions[processed_info['image_id'][i]][processed_info['human'][i]] = new_kp_annos
    return list(predictions.items())


def get_keypoint_coordinate(isupright, pred, threshold=0.0):
    """Inverse of trans_coordinate(): decode heatmap peaks into coordinates.
    :param threshold: minimum peak response for a joint to count as detected
    :param isupright: whether the person crop is upright; kept for interface
        compatibility (the two original branches were identical and are collapsed)
    :param pred: predicted heatmaps
    :return: predicted keypoints in 256x256 crop space
    """
    kps = []
    for p in pred:
        if np.max(p) > threshold:
            idx = np.argmax(p)
            kps.append([idx % 256, idx // 256, 1])  # integer division: row index
        else:
            kps.append([0, 0, 0])  # no response above threshold: mark invisible
    return kps


def convert_coordinate(keypoints, human_position, scale):
    """Map predicted keypoints back to original-image coordinates.
    :param scale: scaling ratio of the person crop
    :param keypoints: keypoints in 256x256 crop space
    :param human_position: the person box coordinates in the original image
    :return: the remapped keypoints, flattened
    """
    kps = np.reshape(keypoints, (-1, 3))
    kps = (kps * [scale / 256.0, scale / 256.0, 1]).astype(np.int16)
    kps = kps + [human_position[0], human_position[1], 0]
    return kps.reshape(-1).tolist()


if __name__ == '__main__':
    a = [('sdfg', {'h1': [1, 4, 5], 'h2': [23, 7]}), ('wer', {'h1': [2, 43, 5]}), ('sdfg', {'h3': [23, 435]})]
    print(val_input_convert(a))
--------------------------------------------------------------------------------
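To see the decode path end to end: plant one hot pixel per joint in a synthetic heatmap stack, recover the peaks with get_keypoint_coordinate, then map them into a person box with convert_coordinate. A minimal sketch (the box coordinates are hypothetical):

import numpy as np

true_xy = np.random.randint(0, 256, size=(14, 2))
heatmaps = np.zeros((14, 256, 256), dtype=np.float32)
for j, (x, y) in enumerate(true_xy):
    heatmaps[j, y, x] = 1.0  # one peak per joint

kps = get_keypoint_coordinate(True, heatmaps)          # peaks in 256x256 crop space
coord = np.array([50, 80, 250, 380])                   # hypothetical person box (x1, y1, x2, y2)
scale = max(coord[2] - coord[0], coord[3] - coord[1])
print(convert_coordinate(kps, coord, scale))           # flat [x1, y1, v1, x2, y2, v2, ...]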
/utils/visualize.py:
--------------------------------------------------------------------------------
# coding:utf8
import visdom
import time
import numpy as np


class Visualizer(object):
    def __init__(self, part_id, env='default', **kwargs):
        self.viz = visdom.Visdom(env=env, **kwargs)
        self.index = {}
        self.log_text = ""
        self.part_id = part_id

    def plot(self, name, y):
        '''
        self.plot('loss', 1.00)
        '''
        x = self.index.get(name, 0)
        self.viz.line(Y=np.array([y]),
                      X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def log(self, info, win='log_text'):
        '''
        self.log({'loss': 1, 'lr': 0.0001})
        '''
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H:%M:%S'),
            info=info))
        self.viz.text(self.log_text, win=win)
    def heatmap_many(self, heatmap_out):
        for i, hm in enumerate(heatmap_out, 1):
            self.viz.heatmap(
                X=np.array(hm),
                win=self.part_id[i],
                opts=dict(
                    columnnames=list(range(256)),
                    rownames=list(range(256)),
                    colormap='Viridis',
                    title=self.part_id[i],
                )
            )

    def img_vis(self, img):
        self.viz.image(
            img,
            opts=dict(title='Image', caption='Img'),
        )
--------------------------------------------------------------------------------
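Typical wiring of this Visualizer during training, as a sketch (values illustrative; it assumes a visdom server is already running, e.g. started with python -m visdom.server):

from config import opt
from utils.visualize import Visualizer

vis = Visualizer(opt.part_id, env=opt.env)
vis.plot('train_loss', 0.37)         # each call appends one point; x advances automatically
vis.log({'epoch': 3, 'lr': opt.lr})  # timestamped line appended to a rolling text window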