├── .gitignore
├── cal_score.py
├── checkpoints
├── config.py
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── fastitem.py
│   ├── item.py
│   ├── tool.py
│   └── visual.py
├── data_exp.ipynb
├── eval.py
├── eval_tool.py
├── models
│   ├── __init__.py
│   ├── basic_module.py
│   ├── loss.py
│   ├── mobilenet.py
│   └── model.py
├── test.ipynb
├── test.py
├── test_save.py
├── train.py
└── utils
    ├── __init__.py
    ├── draw_limb.py
    ├── test_tool.py
    ├── util.py
    └── visualize.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.zip
*.png
*.jpg
.*
_del/
*.json

--------------------------------------------------------------------------------
/cal_score.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2017 challenger.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utility for the human skeleton keypoint task.

This script calculates the final score (mAP) of a test result, based on
your submitted file and a reference file containing the ground truth.

Usage:
    python keypoint_eval.py --submit SUBMIT_FILEPATH --ref REF_FILEPATH

A test case is provided: the submitted file is submit.json and the reference
file is ref.json. Test it with:

    python keypoint_eval.py --submit ./keypoint_sample_predictions.json \
        --ref ./keypoint_sample_annotations.json

The final score of the submitted result, plus any error and warning
messages, will be printed.
"""

import json
import time
import argparse
import pprint

import numpy as np
from config import opt


def parse(kwargs):
    # Apply configuration overrides passed in as keyword arguments.
    for k, v in kwargs.iteritems():
        if not hasattr(opt, k):
            print("Warning: opt has no attribute %s" % k)
        setattr(opt, k, v)
    for k, v in opt.__class__.__dict__.iteritems():
        if not k.startswith('__'): print(k, getattr(opt, k))
def load_annotations(anno_file, return_dict):
    """Convert annotation JSON file."""

    annotations = dict()
    annotations['image_ids'] = set([])
    annotations['annos'] = dict()
    annotations['delta'] = 2*np.array([0.01388152, 0.01515228, 0.01057665, 0.01417709, \
        0.01497891, 0.01402144, 0.03909642, 0.03686941, 0.01981803, \
        0.03843971, 0.03412318, 0.02415081, 0.01291456, 0.01236173])
    try:
        annos = json.load(open(anno_file, 'r'))
    except Exception:
        return_dict['error'] = 'Annotation file does not exist or is an invalid JSON file.'
        exit(return_dict['error'])

    for anno in annos:
        annotations['image_ids'].add(anno['image_id'])
        annotations['annos'][anno['image_id']] = dict()
        annotations['annos'][anno['image_id']]['human_annos'] = anno[
            'human_annotations']
        annotations['annos'][anno['image_id']]['keypoint_annos'] = anno[
            'keypoint_annotations']

    return annotations


def load_predictions(prediction_file, return_dict):
    """Convert prediction JSON file."""

    predictions = dict()
    predictions['image_ids'] = []
    predictions['annos'] = dict()
    id_set = set([])

    try:
        preds = json.load(open(prediction_file, 'r'))
    except Exception:
        return_dict['error'] = 'Prediction file does not exist or is an invalid JSON file.'
        exit(return_dict['error'])

    for pred in preds:
        if 'image_id' not in pred.keys():
            return_dict['warning'].append(
                'There is an invalid annotation info, \
                likely missing key \'image_id\'.')
            continue
        if 'keypoint_annotations' not in pred.keys():
            return_dict['warning'].append(pred['image_id']+\
                ' does not have key \'keypoint_annotations\'.')
            continue
        image_id = pred['image_id'].split('.')[0]
        if image_id in id_set:
            return_dict['warning'].append(pred['image_id']+\
                ' is duplicated in prediction JSON file.')
        else:
            id_set.add(image_id)
            predictions['image_ids'].append(image_id)
            predictions['annos'][pred['image_id']] = dict()
            predictions['annos'][pred['image_id']]['keypoint_annos'] = pred[
                'keypoint_annotations']

    return predictions


def compute_oks(anno, predict, delta):
    """Compute the OKS matrix (size gtN*pN)."""

    anno_count = len(anno['keypoint_annos'].keys())
    predict_count = len(predict.keys())
    oks = np.zeros((anno_count, predict_count))

    # for every human keypoint annotation
    for i in range(anno_count):
        anno_key = anno['keypoint_annos'].keys()[i]
        anno_keypoints = np.reshape(anno['keypoint_annos'][anno_key], (14, 3))
        visible = anno_keypoints[:, 2] == 1
        # print ("visible: ", visible)
        bbox = anno['human_annos'][anno_key]
        scale = np.float32((bbox[3] - bbox[1]) * (bbox[2] - bbox[0]))
        if np.sum(visible) == 0:
            for j in range(predict_count):
                oks[i, j] = 0
        else:
            # for every predicted human
            for j in range(predict_count):
                predict_key = predict.keys()[j]
                predict_keypoints = np.reshape(predict[predict_key], (14, 3))
                dis = np.sum((anno_keypoints[visible, :2] \
                    - predict_keypoints[visible, :2])**2, axis=1)
                # print ("dis: ", dis)
                # print ("in exp:", np.exp(-dis / 2 / delta[visible]**2 / (scale + 1)))
                oks[i, j] = np.mean(
                    np.exp(-dis / 2 / delta[visible]**2 / (scale + 1)))
    if anno_count < 50:
        print ("oks: ", oks)
    return oks
def keypoint_eval(predictions, annotations, return_dict):
    """Evaluate the predicted file and return mAP."""

    oks_all = np.zeros((0))
    oks_num = 0
    # Construct a set to speed up id searching.
    prediction_id_set = set(predictions['image_ids'])

    # for every annotation in our test/validation set
    for image_id in annotations['image_ids']:
        # if the image is in the predictions, compute oks
        if image_id in prediction_id_set:
            oks = compute_oks(anno=annotations['annos'][image_id], \
                predict=predictions['annos'][image_id]['keypoint_annos'], \
                delta=annotations['delta'])
            # treat pairs with max OKS as matched ones, add to oks_all
            oks_all = np.concatenate((oks_all, np.max(oks, axis=0)), axis=0)
            # accumulate total num by max(gtN, pN)
            oks_num += np.max(oks.shape)
        else:
            # otherwise report a warning
            return_dict['warning'].append(
                image_id + ' is not in the prediction JSON file.')
            # number of humans in ground truth annotations
            gt_n = len(annotations['annos'][image_id]['human_annos'].keys())
            # fill 0 in oks scores
            oks_all = np.concatenate((oks_all, np.zeros((gt_n))), axis=0)
            # accumulate total num by ground truth number
            #oks_num += gt_n

    # compute mAP by averaging APs under different oks thresholds
    print (oks_num)
    average_precision = []
    for threshold in np.linspace(0.5, 0.95, 10):
        num_thre = np.sum(oks_all > threshold)
        average_precision.append(
            num_thre / np.float32(oks_num))
        print (num_thre, num_thre / np.float32(oks_num))
    #print (average_precision)
    return_dict['score'] = np.mean(average_precision)

    return return_dict


def main(**kwargs):
    """The evaluator."""

    # Arguments parser
    parse(kwargs)
    # Initialize return_dict
    return_dict = dict()
    return_dict['error'] = None
    return_dict['warning'] = []
    return_dict['score'] = None

    # Load annotation JSON file
    start_time = time.time()
    annotations = load_annotations(anno_file=opt.ref, return_dict=return_dict)
    print ('Complete reading annotation JSON file in %.2f seconds.' % (
        time.time() - start_time))

    # Load prediction JSON file
    start_time = time.time()
    predictions = load_predictions(
        prediction_file=opt.submit, return_dict=return_dict)
    print ('Complete reading prediction JSON file in %.2f seconds.' % (
        time.time() - start_time))

    # Keypoint evaluation
    start_time = time.time()
    return_dict = keypoint_eval(
        predictions=predictions,
        annotations=annotations,
        return_dict=return_dict)
    print ('Complete evaluation in %.2f seconds.' % (time.time() - start_time))

    # Print return_dict and the final score
    # pprint.pprint(return_dict)
    print ('Score: ', '%.8f' % return_dict['score'])


if __name__ == "__main__":
    import fire
    fire.Fire()
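For a single ground-truth/prediction pair, `compute_oks` averages exp(-d_k^2 / (2 * delta_k^2 * (s + 1))) over the visible keypoints k, where d_k is the pixel distance and s the box area. A minimal standalone check of that term (toy numbers, not dataset values):

import numpy as np

# Toy OKS check mirroring the inner loop of compute_oks (values are made up).
delta = 2 * np.array([0.01388152] * 14)   # per-keypoint normalization constants
gt = np.zeros((14, 3)); gt[:, 2] = 1      # 14 keypoints, all marked visible
pred = gt.copy()
pred[:, :2] += 5.0                        # every prediction is off by 5 px
scale = np.float32(200 * 100)             # bbox area (width * height)

visible = gt[:, 2] == 1
dis = np.sum((gt[visible, :2] - pred[visible, :2]) ** 2, axis=1)
oks = np.mean(np.exp(-dis / 2 / delta[visible] ** 2 / (scale + 1)))
print(oks)  # 1.0 would be a perfect match; this pair scores ~0.2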
--------------------------------------------------------------------------------
/checkpoints:
--------------------------------------------------------------------------------
/data/checkpoint/aicha/keypoint

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# encoding:utf-8
import time
tfmt = '%m%d_%H%M'


class Config:
    crop_size_x = 384
    crop_size_y = 384

    sigma = 1.0
    boxsize = 384
    env = 'keypoint'  # Visdom env
    model = 'KeypointModel'
    model_path = None
    batch_size = 32
    num_workers = 10
    shuffle = True
    debug_file = '/tmp/debugk'
    lr1 = 0
    lr2 = 1e-3
    max_epoch = 100
    ref = '/mnt/6/ai_challenger_keypoint_validation_20170911/keypoint_validation_annotations_20170911.json'
    submit = 'result/val_result.json'
    plot_every = 10
    save_every = 3000
    decay_every = 1500
    stride = 8
    thre2 = 0.05
    thre1 = 0.1
    scale_search = [1.0]
    val_result_dir = '/home/x/dcsb/KeyPoint/result/'
    val_dir = '/mnt/6/ai_challenger_keypoint_validation_20170911/keypoint_validation_images_20170911/'
    img_root = '/data_ssd/ai_challenger/ai_challenger_keypoint_train_20170909/keypoint_train_images_20170902/'
    # img_root = '/mnt/7/ai_challenger_keypoint_train_20170909/keypoint_train_images_20170902/'
    anno_path = '/data_ssd/ai_challenger/ai_challenger_keypoint_train_20170909/train.pth'

    downsample_rate = 8


def parse(self, kwargs):
    # Apply configuration overrides passed in as keyword arguments.
    for k, v in kwargs.iteritems():
        if not hasattr(self, k):
            print("Warning: opt has no attribute %s" % k)
        setattr(self, k, v)

    keys = sorted(self.state_dict().keys())
    for key in keys:
        if not key.startswith('__'):
            print(key, getattr(self, key))


def state_dict(self):
    return {k: getattr(self, k) for k in dir(self)
            if not k.startswith('_') and k != 'parse' and k != 'state_dict'}

Config.parse = parse
Config.state_dict = state_dict
opt = Config()
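Every script mutates this shared `opt` singleton; fire forwards command-line flags as the kwargs. A small illustrative sketch of the override mechanism (the values here are arbitrary):

from config import opt

# Equivalent to what fire does for `python train.py train --batch_size=16 --env=kp-debug`.
opt.parse({'batch_size': 16, 'env': 'kp-debug'})  # also prints the resulting config
assert opt.batch_size == 16 and opt.state_dict()['env'] == 'kp-debug'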
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from dataset import Dataset

--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
# encoding:utf-8
import torch as t
import torch
from torch.utils import data
import numpy as np
from skimage import exposure
from PIL import Image, ImageEnhance
import random
from config import opt
from fastitem import FastItem
from torchvision import transforms
from skimage import transform as sktrf

class Dataset(data.Dataset):
    def __init__(self, opt):
        self.annotations = t.load(opt.anno_path)
        self.opt = opt
        self.transforms = transforms.Compose([
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, index):
        """
        :return img: 3xHxW
        :return gt: 41xHxW (15 confidence-map channels + 26 PAF channels)
        :return weight: HxW
        """
        anno = self.annotations[index]
        item = FastItem(self.opt, anno)
        img, cf, paf, weight = item.get()

        # scale
        # paf = item.rescale(paf, 0.125)
        # cf = item.rescale(cf, 0.125)
        # weight = item.rescale(weight, 0.125)

        # to tensor
        img = t.from_numpy(img.transpose((2, 0, 1)))
        cf = t.from_numpy(cf.transpose((2, 0, 1)))
        paf = t.from_numpy(paf.transpose((2, 0, 1)))
        weight = t.from_numpy(weight).squeeze()

        # other processing
        img = self.transforms(img)
        gt = t.cat([cf, paf], dim=0)

        return img.float(), gt.float(), weight.unsqueeze(0).float().expand_as(gt)

    def __len__(self):
        return len(self.annotations)
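A quick shape sanity check for one sample, assuming `opt.anno_path` and `opt.img_root` in config.py point at the downloaded data:

from config import opt
from data import Dataset

ds = Dataset(opt)
img, gt, weight = ds[0]
# With crop_size 384 and downsample_rate 8, the targets live on a 48x48 grid:
print(img.size())     # (3, 384, 384)
print(gt.size())      # (41, 48, 48) -- 15 heatmaps + 26 PAF channels
print(weight.size())  # (41, 48, 48), the person-box mask broadcast over gt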
--------------------------------------------------------------------------------
/data/fastitem.py:
--------------------------------------------------------------------------------
#coding=utf-8
import numpy as np
from PIL import Image
import copy
from scipy import stats
from skimage import draw, transform, util


class FastItem(object):
    '''
    One Item object per image.
    '''
    def __init__(self, opt, anno):
        '''
        @param: opt
        @param: anno annotation dict
        '''
        self.anno = copy.deepcopy(anno)  # never modify the caller's dict
        self._img = Image.open(opt.img_root + anno['image_id'] + '.jpg')

        self.size = list(self._img.size)[::-1]  # image shape: height x width
        self.output_size = opt.crop_size_x, opt.crop_size_y  # output shape
        self.downsample_rate = opt.downsample_rate
        self.opt = opt

        size, output_size = np.array(self.size), np.array(self.output_size)
        ratio = (size+0.0)/output_size
        ratio = min(ratio)
        re_size = (size/ratio)
        self.re_size = (re_size/opt.downsample_rate).astype(np.int32)
        self.update_anno(opt.downsample_rate*ratio)

        # !TODO: use a different width ratio for each limb type.
        # Arms can take a higher ratio; shoulder-to-left-hip spans are long,
        # so that limb can be drawn thinner.
        self.ratio = 0.1  # limb width as a fraction of limb length

    def update_anno(self, ratio):
        '''
        Adjust the annotation info:
        the original coordinates are column-first; convert everything to
        row-first. The annotation looks like:
        {u'image_id': u'043758c591b58f39a01648c49b5154ad1e01d400',
         u'keypoint_annotations': {u'human1': [(x1, y1, type1), ....(x14, y14, type14)]},
         u'human_annotations': {u'human1': [x, y, x, y]}}
        '''
        for human, human_anno in self.anno['human_annotations'].items():
            human_anno[1], human_anno[0], human_anno[3], human_anno[2] = \
                [int(_/ratio) for _ in human_anno]

        key_annos = self.anno['keypoint_annotations']
        for human, key_anno in key_annos.items():
            # swap row/column order while rescaling
            coord2 = [_/(0.+ratio) for _ in key_anno[::3]]
            coord1 = [_/(0.+ratio) for _ in key_anno[1::3]]
            stat_ = key_anno[2::3]
            key_annos[human] = zip(coord1, coord2, stat_)

    def confidence_map(self):
        '''
        Generate a Gaussian-like activation around every joint.
        !TODO: is the output scale right? It is quite small now; enlarge it?
        '''
        _h, _w = self.re_size
        _h, _w = int(_h), int(_w)
        # _h, _w = _h/self.opt.downsample_rate, _w/self.opt.downsample_rate
        part_maps = np.zeros([_h, _w, 15])
        for part in range(14):
            keypoint_anno = self.anno['keypoint_annotations']
            person_num = len(keypoint_anno)
            maps = np.zeros([_h, _w, person_num])
            for ii, (human, middle) in enumerate(keypoint_anno.items()):
                if middle[part][2] == 3:
                    # joints outside the image stay all-zero
                    map = np.zeros([_h, _w])
                else:
                    middle_ = middle[part][:2]
                    map = self._gaussion_map((_h, _w), middle_, self.opt.sigma)
                maps[:,:,ii] = map
            part_map = maps.max(axis=2)
            part_maps[:,:,part] = part_map

        part_maps[:,:,-1] = part_maps[:,:,:14].max(axis=2)
        self.cf = part_maps
        return self.cf

    def _gaussion_map(self, shape, middle, sigma):
        '''
        Build a Gaussian heat map from (mu, sigma) and the target shape.

        @sigma: the standard deviation (not the variance!)
        @shape: shape of the generated map
        @middle: center position (the mean)
        '''
        middle = np.array(middle).reshape([1,1,-1])
        def gaussion(x, y, middle=middle):
            x = x[...,np.newaxis]
            y = y[...,np.newaxis]
            coord = np.concatenate([x,y], axis=2)
            distance = np.linalg.norm(coord - middle, axis=2)
            # scaled by sqrt(2*pi)*sigma so the peak value is exactly 1
            #return stats.norm.pdf(distance, 0, sigma)*(2*np.pi*(sigma**0.5))
            return stats.norm.pdf(distance, 0, sigma)*((2*np.pi)**0.5)*sigma
        return np.fromfunction(gaussion, (shape))

    def PAF_map(self):
        '''
        Generate the part affinity fields (PAFs).
        '''
        mid_1 = [12, 0, 3, 0, 1, 3, 4, 0, 6, 7, 3, 9, 10]
        mid_2 = [13, 13, 13, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11]
        h_, w_ = self.re_size
        # h_, w_ = h_/self.opt.downsample_rate/self.ratio, w_/self.opt.downsample_rate/self.ratio

        keypoint_anno = self.anno['keypoint_annotations']
        person_num = len(keypoint_anno)
        masks = np.zeros([person_num, h_, w_, 13, 2])
        person_num_mask = np.zeros([h_, w_])  # how many people's limbs overlap at each pixel
        for ii_, (person, middle) in enumerate(keypoint_anno.items()):
            for jj_, (m1, m2) in enumerate(zip(mid_1, mid_2)):
                center1, center2 = middle[m1][:2], middle[m2][:2]
                if middle[m1][2] == 3 or middle[m2][2] == 3:
                    # skip joints outside the image (left as zero)
                    continue
                mask = self._PAF(center1, center2, self.ratio)
                #!TODO: even with a single person one pixel may still receive
                # more than one value -- is that a problem? (3)
                person_num_mask = person_num_mask + ((mask != 0).sum(axis=2) > 0)
                masks[ii_,:,:,jj_,:] = mask
        masks = masks.sum(axis=0)/(person_num_mask[:,:,np.newaxis,np.newaxis]+1e-100)
        self.paf = masks
        self.person_num_mask = person_num_mask
        return self.paf
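    # Illustrative note on the rasterization used below (not part of the
    # original file): skimage.draw.line_aa returns the pixel coordinates (and
    # anti-aliasing weights) of a line segment, and _PAF writes the limb's
    # unit vector into exactly those pixels, e.g.
    #
    #   >>> from skimage import draw
    #   >>> rr, cc, val = draw.line_aa(0, 0, 0, 3)   # a short horizontal limb
    #   >>> mask = np.zeros((4, 4, 2))
    #   >>> mask[rr, cc] = np.array([0.0, 1.0])      # unit vector, image order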
    def _PAF(self, center1, center2, ratio):
        '''
        @param center1: start point
        @param center2: end point
        @param ratio: limb length/width ratio

        @return mask: paf, shape (h, w, 2)

        #!NOTE: center1 is (row, col), so array indexing [x][y] hits the image
        # pixel; the stored vector, however, uses image order (column first,
        # then row) -- the reverse of the array order.
        '''
        center1, center2 = np.array(center1), np.array(center2)
        h_, w_ = self.re_size
        mask = np.zeros([h_, w_, 2])

        length = np.linalg.norm(center1-center2)
        limb_vec_unit = (center2-center1)/(length+1e-100)
        # limb_perp_unit = np.array([-limb_vec_unit[1], limb_vec_unit[0]])
        # width = int(length*ratio/2)+1
        # if width > 1: print width

        # coords = np.zeros([4,2])
        # coords[0] = center1 - limb_perp_unit*width
        # coords[1] = center1 + limb_perp_unit*width
        # coords[2] = center2 + limb_perp_unit*width
        # coords[3] = center2 - limb_perp_unit*width

        # keep the polygon inside the image:
        # for _c in coords:
        #     if _c[0] < 0: _c[0] = 0
        #     if _c[0] > h_-1: _c[0] = h_-1
        #     if _c[1] < 0: _c[1] = 0
        #     if _c[1] > w_-1: _c[1] = w_-1
        # x, y = coords[:,0], coords[:,1]
        # x, y = np.array(x).round().astype(np.int32), np.array(y).round().astype(np.int32)

        center1, center2 = np.round(center1).astype(np.int32), np.round(center2).astype(np.int32)

        x = np.array([center1[0], center2[0]])
        y = np.array([center1[1], center2[1]])
        x = x.clip(min=0, max=h_-1)
        y = y.clip(min=0, max=w_-1)

        rr, cc, val = draw.line_aa(x[0], y[0], x[1], y[1])
        # image coordinates and array coordinates disagree, hence the reversal
        mask[rr,cc] = limb_vec_unit[::-1][np.newaxis, :]  #*val[:,np.newaxis]

        # rr, cc = draw.polygon(x, y)
        # mask[rr,cc] = limb_vec_unit[::-1]  # x and y are reversed
        return mask

    def resize(self, img, size, mode='reflect'):
        return transform.resize(img, size, mode=mode)

    def random_crop(self, imgs):
        # TODO: unfinished random-crop augmentation; not usable yet.
        raise NotImplementedError

    def weight_mask(self):
        mask = np.zeros(self.re_size)
        human_annotations = self.anno['human_annotations']
        for human, anno in human_annotations.items():
            x1, y1, x2, y2 = anno
            mask[x1:x2, y1:y2] = 1
        return mask

    def get(self):
        '''
        @return img: the image, (opt.output_size, opt.output_size, 3)
        @return cf:
        @return paf:
        @return mask:
        '''
        img = self.resize(np.asarray(self._img), self.re_size*self.opt.downsample_rate)

        paf = self.PAF_map()
        _1, _2, _3, _4 = paf.shape
        paf = paf.reshape(_1, _2, -1)
        cf = self.confidence_map()
        mask = self.weight_mask()[...,np.newaxis]
        size, output_size = np.array(self.size), np.array(self.output_size)

        crop_width = ((self.re_size*self.opt.downsample_rate - output_size)/self.opt.downsample_rate).astype(np.int32)
        crop_before = np.floor(crop_width/2).astype(np.int32)
        crop_after = (crop_width.astype(np.int32)-crop_before)
        crop_before = crop_before.tolist()+[0]
        crop_after = crop_after.tolist()+[0]

        paf = self.crop(paf, crop_before, crop_after)
        cf = self.crop(cf, crop_before, crop_after)
        mask = self.crop(mask, crop_before, crop_after)

        crop_after = [_*self.opt.downsample_rate for _ in crop_after]
        crop_before = [_*self.opt.downsample_rate for _ in crop_before]

        img = self.crop(img, crop_before, crop_after)

        return (img, cf, paf, mask)

    def rescale(self, img, scale_, mode='reflect'):
        return transform.rescale(img, scale_, mode=mode)

    def crop(self, img, before, after):
        return util.crop(img, zip(before, after))

    def flip(self, img, vec=False):
        """
        Flip helper for PAF tensors: the x components (every other channel)
        must change sign when the image is mirrored.
        :param img:
        :param vec: True when img holds PAF vector channels
        """
        if vec:
            img[:,:,::2] *= -1
        else:
            pass
        return img

    def rotate(self, ratio):
        pass


if __name__ == '__main__':
    # smoke test: build one training sample
    import torch as t
    from config import opt
    annos = t.load(opt.anno_path)
    item = FastItem(opt, annos[999])
    img, cf, paf, mask = item.get()
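`_gaussion_map` multiplies the normal pdf by sqrt(2*pi)*sigma so the value at the keypoint itself is exactly 1, regardless of sigma. A standalone check of that normalization (numpy/scipy only):

import numpy as np
from scipy import stats

sigma = 1.0
d = np.array([0.0, sigma, 3 * sigma])  # distances from the keypoint
g = stats.norm.pdf(d, 0, sigma) * ((2 * np.pi) ** 0.5) * sigma
print(g)  # [1.     0.6065 0.0111] -- peak 1, decaying as exp(-d**2 / (2*sigma**2))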
--------------------------------------------------------------------------------
/data/item.py:
--------------------------------------------------------------------------------
#coding=utf-8
import numpy as np
from PIL import Image
import copy
from scipy import stats
from skimage import draw, transform, util

class Item(object):
    '''
    One Item object per image.
    '''
    def __init__(self, opt, anno):
        '''
        @param: opt
        @param: anno annotation dict
        '''
        self.anno = copy.deepcopy(anno)  # never modify the caller's dict
        self._img = Image.open(opt.img_root + anno['image_id'] + '.jpg')
        self.size = list(self._img.size)[::-1]  # image shape: height x width
        self.output_size = opt.crop_size_x, opt.crop_size_y  # output shape
        self.opt = opt
        self.update_anno()
        # !TODO: use a different width ratio for each limb type.
        # Arms can take a higher ratio; shoulder-to-left-hip spans are long,
        # so that limb can be drawn thinner.
        self.ratio = 0.1  # limb width as a fraction of limb length

    def update_anno(self):
        '''
        Adjust the annotation info:
        the original coordinates are column-first; convert everything to
        row-first. The annotation looks like:
        {u'image_id': u'043758c591b58f39a01648c49b5154ad1e01d400',
         u'keypoint_annotations': {u'human1': [(x1, y1, type1), ....(x14, y14, type14)]},
         u'human_annotations': {u'human1': [x, y, x, y]}}
        '''
        for human, human_anno in self.anno['human_annotations'].items():
            human_anno[1], human_anno[0], human_anno[3], human_anno[2] = human_anno

        key_annos = self.anno['keypoint_annotations']
        for human, key_anno in key_annos.items():
            coord2 = key_anno[::3]
            coord1 = key_anno[1::3]
            stat_ = key_anno[2::3]
            key_annos[human] = zip(coord1, coord2, stat_)

    def confidence_map(self):
        '''
        Generate a Gaussian-like activation around every joint.
        !TODO: is the output scale right? It is quite small now; enlarge it?
        '''
        _h, _w = self.size
        part_maps = np.zeros([_h, _w, 15])
        for part in range(14):
            maps = []
            for human, middle in self.anno['keypoint_annotations'].items():
                if middle[part][2] == 3:
                    # joints outside the image stay all-zero
                    map = np.zeros([_h, _w])
                else:
                    middle_ = middle[part][:2]
                    map = self._gaussion_map((_h, _w), middle_, self.opt.sigma)
                maps.append(map)
            part_map = np.stack(maps, -1).max(axis=2)
            part_maps[:,:,part] = part_map

        part_maps[:,:,-1] = part_maps[:,:,:14].max(axis=2)
        self.cf = part_maps
        return self.cf

    def _gaussion_map(self, shape, middle, sigma):
        '''
        Build a Gaussian heat map from (mu, sigma) and the target shape.

        @sigma: the standard deviation (not the variance!)
        @shape: shape of the generated map
        @middle: center position (the mean)
        '''
        middle = np.array(middle).reshape([1,1,-1])
        def gaussion(x, y, middle=middle):
            x = x[...,np.newaxis]
            y = y[...,np.newaxis]
            coord = np.concatenate([x,y], axis=2)
            distance = np.linalg.norm(coord - middle, axis=2)
            # scaled by sqrt(2*pi)*sigma so the peak value is exactly 1
            # (matches the corrected normalization in fastitem.py)
            return stats.norm.pdf(distance, 0, sigma)*((2*np.pi)**0.5)*sigma
        return np.fromfunction(gaussion, (shape))
    def confidence_map2(self):
        '''
        Generate a Gaussian-like activation around every joint
        (array-based variant of confidence_map).
        '''
        _h, _w = self.size
        part_maps = np.zeros([_h, _w, 15])
        for part in range(14):
            keypoint_anno = self.anno['keypoint_annotations']
            person_num = len(keypoint_anno)
            maps = np.zeros([_h, _w, person_num])
            for ii, (human, middle) in enumerate(keypoint_anno.items()):
                if middle[part][2] == 3:
                    # joints outside the image stay all-zero
                    map = np.zeros([_h, _w])
                else:
                    middle_ = middle[part][:2]
                    map = self._gaussion_map((_h, _w), middle_, self.opt.sigma)
                maps[:,:,ii] = map
            part_map = maps.max(axis=2)
            part_maps[:,:,part] = part_map

        part_maps[:,:,-1] = part_maps[:,:,:14].max(axis=2)
        self.cf = part_maps
        return self.cf

    def PAF_map(self):
        '''
        Generate the part affinity fields (PAFs).
        '''
        mid_1 = [12, 0, 3, 0, 1, 3, 4, 0, 6, 7, 3, 9, 10]
        mid_2 = [13, 13, 13, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11]
        h_, w_ = self.size
        keypoint_anno = self.anno['keypoint_annotations']
        person_num = len(keypoint_anno)
        masks = np.zeros([person_num, h_, w_, 13, 2])
        person_num_mask = np.zeros([h_, w_])  # how many people's limbs overlap at each pixel
        for ii_, (person, middle) in enumerate(keypoint_anno.items()):
            for jj_, (m1, m2) in enumerate(zip(mid_1, mid_2)):
                center1, center2 = middle[m1][:2], middle[m2][:2]
                if middle[m1][2] == 3 or middle[m2][2] == 3:
                    # skip joints outside the image (left as zero)
                    continue
                mask = self._PAF(center1, center2, self.ratio)
                person_num_mask = person_num_mask + ((mask != 0).sum(axis=2) > 0)
                masks[ii_,:,:,jj_,:] = mask
        masks = masks.sum(axis=0)/(person_num_mask[:,:,np.newaxis,np.newaxis]+0.000001)

        self.paf = masks
        return self.paf

    def _PAF(self, center1, center2, ratio):
        '''
        @param center1: start point
        @param center2: end point
        @param ratio: limb length/width ratio

        @return mask: paf, shape (h, w, 2)
        '''
        center1, center2 = np.array(center1), np.array(center2)
        h_, w_ = self.size
        mask = np.zeros([h_, w_, 2])

        length = np.linalg.norm(center1-center2)
        limb_vec_unit = (center2-center1)/(length+0.00001)
        limb_perp_unit = np.array([-limb_vec_unit[1], limb_vec_unit[0]])
        width = int(length*ratio/2)+2

        coords = np.zeros([4,2])
        coords[0] = center1 - limb_perp_unit*width
        coords[1] = center1 + limb_perp_unit*width
        coords[2] = center2 + limb_perp_unit*width
        coords[3] = center2 - limb_perp_unit*width

        # keep the polygon inside the image
        for _c in coords:
            if _c[0] < 0: _c[0] = 0
            if _c[0] > h_-2: _c[0] = h_-2
            if _c[1] < 0: _c[1] = 0
            if _c[1] > w_-2: _c[1] = w_-2
        x, y = coords[:,0], coords[:,1]
        rr, cc = draw.polygon(x, y)
        mask[rr,cc] = limb_vec_unit
        return mask

    def resize(self, img, size):
        return transform.resize(img, size)

    def random_crop(self, imgs):
        # TODO: unfinished random-crop augmentation; not usable yet.
        raise NotImplementedError

    def weight_mask(self):
        mask = np.zeros(self.size)
        human_annotations = self.anno['human_annotations']
        for human, anno in human_annotations.items():
            x1, y1, x2, y2 = anno
            mask[x1:x2, y1:y2] = 1
        return mask
    def get(self):
        '''
        @return img: the image, (opt.output_size, opt.output_size, 3)
        @return cf:
        @return paf:
        @return mask:
        '''
        img = np.asarray(self._img)
        paf = self.PAF_map()
        cf = self.confidence_map()
        mask = self.weight_mask()[...,np.newaxis]

        size, output_size = np.array(self.size), np.array(self.output_size)

        ratio = (size+0.0)/output_size
        ratio = min(ratio)
        re_size = (size/ratio).astype(np.int32)

        img = self.resize(img, re_size.tolist())
        paf = self.resize(paf.reshape(paf.shape[0], -1, 13*2), re_size.tolist())
        cf = self.resize(cf, re_size.tolist())
        mask = self.resize(mask, re_size.tolist())

        crop_width = re_size - output_size
        crop_before = np.floor(crop_width/2).astype(np.int32)
        crop_after = crop_width.astype(np.int32)-crop_before

        crop_after = crop_after.tolist()+[0]
        crop_before = crop_before.tolist()+[0]

        img = self.crop(img, crop_before, crop_after)
        paf = self.crop(paf, crop_before, crop_after)
        cf = self.crop(cf, crop_before, crop_after)
        mask = self.crop(mask, crop_before, crop_after)

        return (img, cf, paf, mask)

    def rescale(self, img, scale_, mode='reflect'):
        return transform.rescale(img, scale_, mode=mode)

    def crop(self, img, before, after):
        return util.crop(img, zip(before, after))

    def flip(self, img, vec=False):
        """
        Flip helper for PAF tensors: the x components (every other channel)
        must change sign when the image is mirrored.
        :param img:
        :param vec: True when img holds PAF vector channels
        """
        if vec:
            img[:,:,::2] *= -1
        else:
            pass
        return img

    def rotate(self, ratio):
        pass

    def test_cf(self):
        '''quick visual test'''
        from config import opt
        opt.img_root = '/mnt/7/ai_challenger_keypoint_train_20170909/keypoint_train_images_20170902/'
        import torch as t
        from pylab import plt
        annos = t.load('/mnt/6/train.pth')
        anno = annos[58]
        a = Item(opt, anno)
        cf = a.confidence_map()
        plt.imshow(cf[:,:,-1])
        a._img

--------------------------------------------------------------------------------
/data/tool.py:
--------------------------------------------------------------------------------
#coding:utf8
import torch as t
import numpy as np
from skimage import transform
from pylab import plt
from PIL import Image

def show_paf(img, paf, stride=5, thres=0.1):
    """
    @param img: ndarray, HxWx3
    @param paf: ndarray, HxWxN
    """
    paf = transform.rescale(paf, img.shape[0]/paf.shape[0])
    h, w, n = paf.shape
    mask = (paf**2).reshape(h, w, n/2, 2).sum(axis=3).sum(axis=2)

[... dump truncated here: the rest of data/tool.py and the beginning of eval.py are missing ...]

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
# (only the tail of the file survives in this dump; the fragment duplicates
#  the mAP loop of cal_score.py)
        num_thre = np.sum(oks_all > threshold)
        average_precision.append(
            num_thre / np.float32(oks_num))
        print (num_thre, num_thre / np.float32(oks_num))
    #print (average_precision)
    return_dict['score'] = np.mean(average_precision)

    return return_dict
def main(**kwargs):
    """The evaluator."""

    # Arguments parser
    parse(kwargs)
    # Initialize return_dict
    return_dict = dict()
    return_dict['error'] = None
    return_dict['warning'] = []
    return_dict['score'] = None

    # Load annotation JSON file
    start_time = time.time()
    annotations = load_annotations(anno_file=opt.ref, return_dict=return_dict)
    print ('Complete reading annotation JSON file in %.2f seconds.' % (
        time.time() - start_time))

    # Load prediction JSON file
    start_time = time.time()
    predictions = load_predictions(
        prediction_file=opt.submit, return_dict=return_dict)
    print ('Complete reading prediction JSON file in %.2f seconds.' % (
        time.time() - start_time))

    # Keypoint evaluation
    start_time = time.time()
    return_dict = keypoint_eval(
        predictions=predictions,
        annotations=annotations,
        return_dict=return_dict)
    print ('Complete evaluation in %.2f seconds.' % (time.time() - start_time))

    # Print return_dict and the final score
    # pprint.pprint(return_dict)
    print ('Score: ', '%.8f' % return_dict['score'])


if __name__ == "__main__":
    import fire
    fire.Fire()

--------------------------------------------------------------------------------
/eval_tool.py:
--------------------------------------------------------------------------------
#coding:utf8
import json
import time
import argparse
import pprint

import numpy as np
from config import opt
from warnings import warn

def parse(kwargs):
    # Apply configuration overrides passed in as keyword arguments.
    for k, v in kwargs.items():
        if not hasattr(opt, k):
            print("Warning: opt has no attribute %s" % k)
        setattr(opt, k, v)
    for k, v in opt.__class__.__dict__.items():
        if not k.startswith('__'): print(k, getattr(opt, k))


def load_annotations(annos):
    """Convert annotation JSON file."""

    annotations = dict()
    annotations['image_ids'] = set([])
    annotations['annos'] = dict()
    annotations['delta'] = 2*np.array([0.01388152, 0.01515228, 0.01057665, 0.01417709, \
        0.01497891, 0.01402144, 0.03909642, 0.03686941, 0.01981803, \
        0.03843971, 0.03412318, 0.02415081, 0.01291456, 0.01236173])

    for anno in annos:
        annotations['image_ids'].add(anno['image_id'])
        annotations['annos'][anno['image_id']] = dict()
        annotations['annos'][anno['image_id']]['human_annos'] = anno[
            'human_annotations']
        annotations['annos'][anno['image_id']]['keypoint_annos'] = anno[
            'keypoint_annotations']

    return annotations


def load_predictions(preds):
    """Convert prediction JSON file."""

    predictions = dict()
    predictions['image_ids'] = []
    predictions['annos'] = dict()
    id_set = set([])

    for pred in preds:
        if 'image_id' not in pred.keys():
            warn('There is an invalid annotation info, \
                likely missing key \'image_id\'.')
            continue
        if 'keypoint_annotations' not in pred.keys():
            warn(pred['image_id']+\
                ' does not have key \'keypoint_annotations\'.')
            continue
        image_id = pred['image_id'].split('.')[0]
        if image_id in id_set:
            warn(pred['image_id']+\
                ' is duplicated in prediction JSON file.')
        else:
            id_set.add(image_id)
            predictions['image_ids'].append(image_id)
            predictions['annos'][pred['image_id']] = dict()
            predictions['annos'][pred['image_id']]['keypoint_annos'] = pred[
                'keypoint_annotations']

    return predictions
def compute_oks(anno, predict, delta):
    """Compute the OKS matrix (size gtN*pN)."""

    anno_count = len(anno['keypoint_annos'].keys())
    predict_count = len(predict.keys())
    oks = np.zeros((anno_count, predict_count))
    if predict_count == 0: return oks.T

    # for every human keypoint annotation
    for i in range(anno_count):
        anno_key = anno['keypoint_annos'].keys()[i]
        anno_keypoints = np.reshape(anno['keypoint_annos'][anno_key], (14, 3))
        visible = anno_keypoints[:, 2] == 1
        bbox = anno['human_annos'][anno_key]
        scale = np.float32((bbox[3] - bbox[1]) * (bbox[2] - bbox[0]))
        if np.sum(visible) == 0:
            for j in range(predict_count):
                oks[i, j] = 0
        else:
            # for every predicted human
            for j in range(predict_count):
                predict_key = predict.keys()[j]
                predict_keypoints = np.reshape(predict[predict_key], (14, 3))
                dis = np.sum((anno_keypoints[visible, :2] \
                    - predict_keypoints[visible, :2])**2, axis=1)
                oks[i, j] = np.mean(
                    np.exp(-dis / 2 / delta[visible]**2 / (scale + 1)))
    # if anno_count < 50:
    #     print ("oks: ", oks)
    return oks


def keypoint_eval(predictions, annotations):
    """Evaluate predictions and return mAP."""

    oks_all = np.zeros((0))
    oks_num = 0
    prediction_id_set = set(predictions['image_ids'])
    _oks_ = []

    # for every annotation in our test/validation set
    for image_id in annotations['image_ids']:
        # if the image is in the predictions, compute oks
        if image_id in prediction_id_set:
            oks = compute_oks(anno=annotations['annos'][image_id], \
                predict=predictions['annos'][image_id]['keypoint_annos'], \
                delta=annotations['delta'])
            _oks_.append(oks)
            # treat pairs with max OKS as matched ones, add to oks_all
            oks_all = np.concatenate((oks_all, np.max(oks, axis=1)), axis=0)
            # accumulate total num by max(gtN, pN)
            oks_num += np.max(oks.shape)
        else:
            continue
            # the fallback below is disabled by the continue above:
            # report a warning and count the missing humans as zeros
            warn(image_id + ' is not in the prediction JSON file.')
            # number of humans in ground truth annotations
            gt_n = len(annotations['annos'][image_id]['human_annos'].keys())
            # fill 0 in oks scores
            oks_all = np.concatenate((oks_all, np.zeros((gt_n))), axis=0)
            # accumulate total num by ground truth number
            #oks_num += gt_n

    # compute mAP by averaging APs under different oks thresholds
    #print (oks_num)
    average_precision = []
    for threshold in np.linspace(0.5, 0.95, 10):
        num_thre = np.sum(oks_all > threshold)
        average_precision.append(
            num_thre / np.float32(oks_num))
        #print (num_thre, num_thre / np.float32(oks_num))
    #print (average_precision)
    return np.mean(average_precision), (oks_all, oks_num, _oks_)


def test(preds, annos):
    # annos = load_annotations(json.load(open(opt.ref)))
    score, _ = keypoint_eval(
        predictions=preds,
        annotations=annos,
    )
    print ('Score: ', '%.8f' % score)


if __name__ == "__main__":
    import fire
    fire.Fire()
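Unlike cal_score.py, `keypoint_eval` here returns a `(score, debug)` tuple instead of a dict. A minimal driver sketch, with `ref.json`/`submit.json` as illustrative file names in the challenge format:

import json
from eval_tool import load_annotations, load_predictions, keypoint_eval

annos = load_annotations(json.load(open('ref.json')))     # ground truth
preds = load_predictions(json.load(open('submit.json')))  # predictions
score, (oks_all, oks_num, oks_mats) = keypoint_eval(preds, annos)
print('mAP: %.6f over %d matches' % (score, oks_num))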
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from loss import l2_loss
from model import KeypointModel

--------------------------------------------------------------------------------
/models/basic_module.py:
--------------------------------------------------------------------------------
#coding:utf8
import torch as t
import time

class BasicModule(t.nn.Module):
    '''
    Wraps nn.Module, mainly adding save() and load() methods.
    '''

    def __init__(self, opt=None):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self).__name__)  # default name
        self.opt = opt

    def load(self, path, map_location=lambda storage, loc: storage):
        checkpoint = t.load(path, map_location=map_location)
        if 'opt' in checkpoint:
            self.load_state_dict(checkpoint['d'])
            print('old config:')
            print(checkpoint['opt'])
        else:
            self.load_state_dict(checkpoint)
        # for k, v in checkpoint['opt'].items():
        #     setattr(self.opt, k, v)
        return self

    def save(self, name=''):
        file_name = 'checkpoints/{model_name}_{time_}_{name}.pth'.format(
            time_=time.strftime('%m%d_%H%M'),
            model_name=self.model_name,
            name=str(name))
        state_dict = self.state_dict()
        opt_state_dict = self.opt.state_dict()
        t.save({'d': state_dict, 'opt': opt_state_dict}, file_name)
        return file_name

--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
# coding:utf8
import torch as t
from torch import nn
from torch.nn.functional import mse_loss

def l2_loss(outputs, target, mask):
    '''
    @param outputs: tuple of Variables (length 6), the outputs of the 6 stages
    @param target: tensor of size (batch, 41, h, w)
    @param mask: Variable of the same size, zero outside the person boxes
    @return loss: the sum over stages, plus the per-stage list
    '''
    losses = []
    batch_size, c_, h_, w_ = outputs[0].size()

    for output in outputs:
        dist = (output - target) * mask  #.view(-1, 1, h_, w_).expand_as(output)
        losses.append((dist**2).sum()/dist.numel())
    loss = sum(losses)
    return loss, losses
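A shape-level sketch of how the loss is consumed, using random tensors and the old Variable API this repo targets (PyTorch 0.x):

import torch as t
from torch.autograd import Variable
from models import l2_loss

# Fake outputs of the 6 stages: batch 2, 41 channels (15 heatmap + 26 PAF), 48x48.
outputs = [Variable(t.randn(2, 41, 48, 48)) for _ in range(6)]
target = Variable(t.randn(2, 41, 48, 48))
mask = Variable((t.rand(2, 41, 48, 48) > 0.5).float())  # 1 inside person boxes

loss, per_stage = l2_loss(outputs, target, mask)
print(loss.data[0], len(per_stage))  # total scalar loss, list of 6 stage losses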
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
#coding:utf8

import torchvision as tv
import torch as t
from torch import nn


class MobileNet(nn.Module):

    def __init__(self):
        super(MobileNet, self).__init__()
        self.model_name = 'mobilenet'

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.features = nn.Sequential(
            conv_bn(   3,   32, 2),
            conv_dw(  32,   64, 1),
            conv_dw(  64,  128, 2),
            conv_dw( 128,  128, 1),
            conv_dw( 128,  256, 2),
            conv_dw( 256,  256, 1),
            conv_dw( 256,  512, 2),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512, 1024, 2),
            conv_dw(1024, 1024, 1),
        )
        self.classifier = nn.Linear(1024, 2)
        self.load_state_dict(t.load('/data_ssd/ai_challenger/ai_challenger_keypoint_train_20170909/mobilenet.pth'))
        self.pack()

    def pack(self):
        # Keep only the first six blocks (output stride 8, 256 channels);
        # drop the classifier and the deeper, higher-stride blocks.
        del self.classifier
        for ii in range(13, 5, -1):
            del self.features._modules[str(ii)]
        # del self.features._modules[14]

    def forward(self, x):
        x = self.features(x)
        # x = self.classifier(x)
        return x

--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
#coding:utf8

import torchvision as tv
import torch as t
from torch import nn
from mobilenet import MobileNet
from collections import namedtuple
from basic_module import BasicModule

class Stage(nn.Module):

    def __init__(self, inplanes, outplanes=41):
        super(Stage, self).__init__()
        self.cf = nn.Sequential(
            nn.Conv2d(inplanes, 128, 7, 1, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, 128, 7, 1, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, 128, 7, 1, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, 128, 7, 1, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, 128, 7, 1, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.Conv2d(128, outplanes, 1)
        )

    def forward(self, x):
        return self.cf(x)


class KeypointModel(BasicModule):
    def __init__(self, opt):
        super(KeypointModel, self).__init__(opt)
        self.pretrained = MobileNet()
        self.trf = nn.Sequential(
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        # self.ReturnType = namedtuple('ReturnType', ['out1','out2','out3','out4','out5','out6'])
        # Stage 1 sees the 128-ch features; later stages see features + the
        # previous 41-ch output (128 + 41 = 169).
        stages = [Stage(128)] + [Stage(169) for _ in range(2, 7)]
        self.stages = nn.ModuleList(stages)

    def forward(self, img):
        img = self.pretrained(img)
        # if self.optimizer.param_groups[0]['lr'] == 0:
        #     img = img.detach()
        features = self.trf(img)

        output = self.stages[0](features)
        outputs = [output]
        for ii in range(1, 6):
            stage = self.stages[ii]
            input = t.cat([features, output], dim=1)
            output = stage(input)
            outputs.append(output)

        return outputs

    def get_optimizer(self, lr1, lr2):
        param_groups = [
            {'params': self.pretrained.parameters(), 'lr': lr1},
            {'params': self.stages.parameters(), 'lr': lr2},
            {'params': self.trf.parameters(), 'lr': lr2}
        ]

        self.optimizer = t.optim.Adam(param_groups)
        return self.optimizer
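A quick forward-pass shape check, assuming the hard-coded mobilenet.pth checkpoint path in mobilenet.py exists on disk:

import torch as t
from torch.autograd import Variable
from config import opt
from models import KeypointModel

model = KeypointModel(opt)          # builds and trims the MobileNet backbone
x = Variable(t.randn(1, 3, 384, 384))
outs = model(x)
print(len(outs), outs[-1].size())   # 6 stages, each (1, 41, 48, 48) at stride 8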
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#coding:utf8
import math
import sys
import time
from config import opt as opt_
import glob
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy
import cv2 as cv
import torch as t

import torchvision
from data.tool import show_paf
from models import KeypointModel
from utils.test_tool import *
import tqdm
import json

class TestConfig:
    scale_search = [0.8, 1, 1.5, 2.0]
    boxsize = 384
    stride = 8
    val = True
    val_dir = '/data/image/ai_cha/old_del/ai_challenger_keypoint_validation_20170911/keypoint_validation_images_20170911/'
    ckpt_path = 'checkpoints/KeypointModel_1030_0332_0.0100910376112.pth'


opt = TestConfig()


def test_file(pose_model, file_name='206998b0cca06ec6f3951e4acf7e178e114bdba7.jpg'):
    """
    @param pose_model: the loaded KeypointModel
    @param file_name: path to the test image
    """
    test_image = file_name
    oriImg = cv.imread(test_image)  # B,G,R order
    multiplier = [x * opt.boxsize * 1.0 / oriImg.shape[0] for x in opt.scale_search]

    # accumulate heatmap + paf over scales
    heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 15))
    paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 26))
    for scale in multiplier:
        (heatmap1, heatmap2), (paf1, paf2) = get_output(pose_model, scale, oriImg, opt.stride)
        heatmap = flip_heatmap(heatmap1, heatmap2)
        paf = flip_paf(paf1, paf2)
        heatmap_avg += heatmap
        paf_avg += paf
    heatmap_avg /= len(multiplier)
    paf_avg /= len(multiplier)

    # utils.test_tool now expects a HyperParameter instance (see find() there)
    hp = HyperParameter()
    hp.h_thre = 0.1
    # find the keypoints
    all_peaks, peak_counter = find_peaks(heatmap_avg, hp)
    # find the limb connections
    special_k, connection_all = find_connection(all_peaks, paf_avg, hp)
    # group them into people
    subsets, candidate = find_person(all_peaks, special_k, connection_all, hp)
    return get_result(subsets, candidate, file_name.split('/')[-1][:-4])


def main(**kwargs):
    pose_model = KeypointModel(opt_)
    pose_model.load(opt.ckpt_path).eval().cuda()

    img_list = glob.glob(opt.val_dir + '*.jpg')
    results = []
    for ii, img in tqdm.tqdm(enumerate(img_list)):
        img_result = test_file(pose_model, img)
        results.append(img_result)
        if (ii+1) % 100 == 0:
            save_json(results, 'tmp%s.json' % ii)
    # save the final (possibly partial) batch as well
    save_json(results, 'result.json')


if __name__ == '__main__':
    import fire
    fire.Fire()
--------------------------------------------------------------------------------
/test_save.py:
--------------------------------------------------------------------------------
#coding:utf8
import math
import sys
import time
from config import opt as opt_
import glob
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy
import cv2 as cv
import torch as t

import torchvision
from data.tool import show_paf
from models import KeypointModel
from utils.test_tool import *
import tqdm
import json

class TestConfig:
    scale_search = [0.8, 1, 1.5, 2.0]
    boxsize = 384
    stride = 8
    val = True
    val_dir = '/data/image/ai_cha/old_del/ai_challenger_keypoint_validation_20170911/keypoint_validation_images_20170911/'
    ckpt_path = 'checkpoints/KeypointModel_1030_0332_0.0100910376112.pth'
    output_path = 'result/'
    test_num = 1000


opt = TestConfig()


def test_file(pose_model, file_name='206998b0cca06ec6f3951e4acf7e178e114bdba7.jpg'):
    """
    Run the model on one image and save the averaged heatmap/paf to disk.
    @param pose_model: the loaded KeypointModel
    @param file_name: path to the test image
    """
    oriImg = cv.imread(file_name)  # B,G,R order
    multiplier = [x * opt.boxsize * 1.0 / oriImg.shape[0] for x in opt.scale_search]

    # accumulate heatmap + paf over scales
    heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 15))
    paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 26))
    for scale in multiplier:
        (heatmap1, heatmap2), (paf1, paf2) = get_output(pose_model, scale, oriImg, opt.stride)
        heatmap = flip_heatmap(heatmap1, heatmap2)
        paf = flip_paf(paf1, paf2)
        heatmap_avg += heatmap
        paf_avg += paf
    heatmap_avg /= len(multiplier)
    paf_avg /= len(multiplier)
    result = dict(heatmap=heatmap_avg, paf=paf_avg)
    save_path = '%s/%s.npz' % (opt.output_path, file_name.split('/')[-1][:-4])
    np.savez_compressed(save_path, paf=paf_avg, heatmap=heatmap_avg)
    # # find the keypoints
    # all_peaks, peak_counter = find_peaks(heatmap_avg, 0.1)
    # # find the limb connections
    # special_k, connection_all = find_connection(all_peaks, paf_avg)
    # # group them into people
    # subsets, candidate = find_person(all_peaks, special_k, connection_all)
    # return get_result(subsets, candidate, file_name.split('/')[-1][:-4])


def main(**kwargs):
    pose_model = KeypointModel(opt_)
    pose_model.load(opt.ckpt_path).eval().cuda()

    img_list = glob.glob(opt.val_dir + '*.jpg')[:opt.test_num]

    for ii, img in tqdm.tqdm(enumerate(img_list)):
        test_file(pose_model, img)


if __name__ == '__main__':
    import fire
    fire.Fire()

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding:utf8
import os
import time
import fire
import ipdb

import numpy as np
import torch as t
import torchnet as tnt
from tqdm import tqdm
from skimage import transform
import matplotlib
matplotlib.use('agg')

from config import opt
from data import Dataset
from models import l2_loss
from utils import Visualizer, mask_img
import models
from data import tool


def train(**kwargs):
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    model = models.KeypointModel(opt)
    if opt.model_path is not None:
        model.load(opt.model_path)

    model.cuda()
    dataset = Dataset(opt)
    dataloader = t.utils.data.DataLoader(
        dataset,
        opt.batch_size,
        num_workers=opt.num_workers,
        shuffle=True,
        drop_last=True)

    lr1, lr2 = opt.lr1, opt.lr2
    optimizer = model.get_optimizer(lr1, lr2)
    loss_meter = tnt.meter.AverageValueMeter()
    pre_loss = 1e100
    model.save()
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        start = time.time()

        for ii, (img, gt, weight) in tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            img = t.autograd.Variable(img).cuda()
            target = t.autograd.Variable(gt).cuda()
            weight = t.autograd.Variable(weight).cuda()
            outputs = model(img)
            loss, loss_list = l2_loss(outputs, target, weight)
            loss.backward()
            loss_meter.add(loss.data[0])
            optimizer.step()

            # visualization and logging
            if ii % opt.plot_every == 0 and ii > 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                vis_plots = {'loss': loss_meter.value()[0], 'ii': ii}
                vis.plot_many(vis_plots)

                # show one random image from the batch
                k = t.randperm(img.size(0))[0]
                show = img.data[k].cpu()
                # rough un-normalization for display
                raw = (show*0.225+0.45).clamp(min=0, max=1)

                train_masked_img = mask_img(raw, outputs[-1].data[k][14])
                origin_masked_img = mask_img(raw, gt[k][14])

                vis.img('target', origin_masked_img)
                vis.img('train', train_masked_img)
                vis.img('label', gt[k][14])
                vis.img('predict', outputs[-1].data[k][14].clamp(max=1, min=0))
                paf_img = tool.vis_paf(raw, gt[k][15:])
                train_paf_img = tool.vis_paf(raw, outputs[-1][k].data[15:].clamp(min=-1, max=1))
                vis.img('paf_train', train_paf_img)
                # fig = tool.show_paf(np.transpose(raw.cpu().numpy(), (1, 2, 0)),
                #                     gt[k][15:].cpu().numpy().transpose((1, 2, 0))).get_figure()
                # paf_img = tool.fig2data(fig).astype(np.int32)
                # vis.img('paf', t.from_numpy(paf_img/255).float())
                vis.img('paf', paf_img)
        model.save(loss_meter.value()[0])
        vis.save([opt.env])


if __name__ == "__main__":
    fire.Fire()
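Training is driven through fire, so any attribute of `Config` can be overridden from the command line. The programmatic equivalent looks like this (illustrative values; a GPU and the configured data paths are required):

# Same as: python train.py train --batch_size=16 --lr2=1e-4 --env=kp-debug
import train
train.train(batch_size=16, lr2=1e-4, env='kp-debug')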
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .visualize import Visualizer, mask_img

--------------------------------------------------------------------------------
/utils/draw_limb.py:
--------------------------------------------------------------------------------
#coding:utf8
import numpy as np
import json
from scipy.misc import imread
import cv2
from tqdm import tqdm
from pylab import plt
'''
Keypoint order:
1/right shoulder, 2/right elbow, 3/right wrist, 4/left shoulder, 5/left elbow,
6/left wrist, 7/right hip, 8/right knee, 9/right ankle, 10/left hip,
11/left knee, 12/left ankle, 13/top of head, 14/neck.
'''
part_labels = ['sho_r','elb_r','wri_r','sho_l','elb_l','wri_l','hip_r','kne_r','ank_r','hip_l','kne_l','ank_l','head','neck']
# part_labels = ['nose','eye_l','eye_r','ear_l','ear_r',
#                'sho_l','sho_r','elb_l','elb_r','wri_l','wri_r',
#                'hip_l','hip_r','kne_l','kne_r','ank_l','ank_r']
part_idx = {b: a for a, b in enumerate(part_labels)}


def show_limb(pred, img_path):
    img = imread(img_path)
    annos = pred.get('keypoint_annotations', pred.get('keypoint_annos'))
    for human in annos.values():
        draw_limbs(img, human)
    return plt.imshow(img)

# def show_limb2(pred, root):
#     img = imread(root + pred['image_id'] + '.jpg')
#     annos = pred.get('keypoint_annotations', pred.get('keypoint_annos'))
#     for human in annos.values():
#         draw_limbs(img, human)
#         break
#     return plt.imshow(img)

def draw_limbs(inp, pred):
    def link(a, b, color):
        if part_idx[a] < pred.shape[0] and part_idx[b] < pred.shape[0]:
            a = pred[part_idx[a]]
            b = pred[part_idx[b]]
            # if a[2] > 0.07 and b[2] > 0.07:
            if a[2] < 3 and b[2] < 3 and a[0]*b[0]:
                cv2.line(inp, (int(a[0]), int(a[1])), (int(b[0]), int(b[1])), color, 6)

    pred = np.array(pred).reshape(-1, 3)
    bbox = pred[pred[:, 2] > 0]

    a, b, c, d = bbox[:, 0][bbox[:, 0] > 0].min(), bbox[:, 1][bbox[:, 1] > 0].min(),\
                 bbox[:, 0][bbox[:, 0] > 0].max(), bbox[:, 1][bbox[:, 1] > 0].max()
    cv2.rectangle(inp, (int(a), int(b)), (int(c), int(d)), (255, 255, 255), 5)

    # face links from the 17-keypoint COCO layout, kept for reference:
    # link('nose', 'eye_l', (255, 0, 0))
    # link('eye_l', 'eye_r', (255, 0, 0))
    # link('eye_r', 'nose', (255, 0, 0))
    # link('eye_l', 'ear_l', (255, 0, 0))
    # link('eye_r', 'ear_r', (255, 0, 0))
    # link('ear_l', 'sho_l', (255, 0, 0))
    # link('ear_r', 'sho_r', (255, 0, 0))
    link('head', 'neck', (255, 0, 0))
    link('neck', 'sho_l', (255, 0, 0))
    link('neck', 'sho_r', (255, 0, 0))
    link('sho_l', 'sho_r', (255, 0, 0))
    link('sho_l', 'hip_l', (0, 255, 0))
    link('sho_r', 'hip_r', (0, 255, 0))
    link('hip_l', 'hip_r', (0, 255, 0))

    link('sho_l', 'elb_l', (0, 0, 255))
    link('elb_l', 'wri_l', (0, 0, 255))

    link('sho_r', 'elb_r', (0, 0, 255))
    link('elb_r', 'wri_r', (0, 0, 255))

    link('hip_l', 'kne_l', (255, 255, 0))
    link('kne_l', 'ank_l', (255, 255, 0))

    link('hip_r', 'kne_r', (255, 255, 0))
    link('kne_r', 'ank_r', (255, 255, 0))
class Config:
    pred = '/data/image/ai_cha/old_del/ai_challenger_keypoint_validation_20170911/keypoint_validation_annotations_20170911.json'
    # pred = '/home/a/code/pytorch/ai_challenge/dc_keypoint/KeyPoint/result/val_result.json'
    file_dir = '/data/image/ai_cha/old_del/ai_challenger_keypoint_validation_20170911/keypoint_validation_images_20170911'
    result_img = 'out.jpg'
    limit = 100

opt = Config()

if __name__ == '__main__':

    with open(opt.pred) as f:
        data = json.load(f)

    for ii, d in tqdm(enumerate(data[:opt.limit]), total=opt.limit):
        img_path = '%s/%s.jpg' % (opt.file_dir, d['image_id'])
        img = imread(img_path, mode='RGB')
        for human in d['keypoint_annotations'].values():
            draw_limbs(img, human)
        cv2.imwrite('%s.jpg' % ii, img[:, :, ::-1])

--------------------------------------------------------------------------------
/utils/test_tool.py:
--------------------------------------------------------------------------------
#coding:utf8
import torch as t
import numpy as np
import cv2 as cv
from pylab import plt
import torchvision as tv
import copy
import scipy
from collections import namedtuple
from scipy.ndimage.filters import gaussian_filter
import math
import json

Result = namedtuple('Result', ['all_peaks','peak_counter',\
    'special_k','connection_all','subsets','candidate','pred','rrs','file_name'])

class HyperParameter:
    h_thre = 0.1         # heatmap peak threshold
    p_thre = 0.05        # paf threshold
    sigma = 5            # sigma of the Gaussian smoothing filter
    min_limb_num = 5     # a person needs at least 5 detected limbs
    min_avg_score = 0.5  # minimum average connection confidence per person
    criterion1 = None
    criterion2 = None
    search_order = None

# hp = HyperParameter()

def padRightDownCorner(img, stride, padValue):
    """
    Pad the bottom-right corner of an image.
    @param img: the image
    @param stride: both sides are padded up to a multiple of this
    @param padValue: fill value -- 0 is black, 128 gray, 255 white
    @return img_padded: ndarray
    @return pad: list of pad lengths [up, left, down, right]
    """
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
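# Example (illustrative, not part of the original file): padding a 100x150
# RGB image up to multiples of 64 adds 28 rows and 42 columns bottom-right:
#
#   >>> img = np.zeros((100, 150, 3), dtype=np.uint8)
#   >>> padded, pad = padRightDownCorner(img, 64, 128)
#   >>> padded.shape, pad
#   ((128, 192, 3), [0, 0, 28, 42])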

def flip_paf(paf1, paf2):
    '''
    @param paf1: PAF of the original image
    @param paf2: PAF of the flipped image
    @return output: averaged PAF
    '''
    # channel permutation: (x, y) channel pairs per limb, with the left/right
    # limbs swapped (the second pair is 4,5; the 4,6 in the original skipped a
    # y channel)
    map_index = [0,1, 4,5, 2,3, 10,11, 12,13, 6,7, 8,9,
                 20,21, 22,23, 24,25, 14,15, 16,17, 18,19]
    # x components (even channels) change sign under a horizontal flip
    map_ = np.array([-1 if ii % 2 == 0 else 1 for ii in range(26)])
    output2 = paf1 + paf2[:, ::-1, map_index] * map_
    return output2 / 2.0


normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def get_output(pose_model, scale, oriImg, stride):
    """
    Run the model on a scaled image and on its horizontal flip.
    @param pose_model: model
    @param scale: image scaling factor
    @param oriImg: original image
    @param stride: total downsampling stride of the network
    """
    image_hwc = cv.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv.INTER_CUBIC)
    image_hwc_flip = image_hwc[:, ::-1, :]
    image_hwc_padded, pad1 = padRightDownCorner(image_hwc, 64, 128)
    image_hwc_pad_flip, pad2 = padRightDownCorner(image_hwc_flip, 64, 128)

    image_chw = np.transpose(np.float32(image_hwc_padded), (2, 0, 1)) / 255.0
    image_tensor = t.from_numpy(image_chw)
    image_tensor = normalize(image_tensor).unsqueeze(0)
    output1 = pose_model(t.autograd.Variable(image_tensor.cuda(), volatile=True))

    image_chw_pad_flip = np.transpose(np.float32(image_hwc_pad_flip), (2, 0, 1)) / 255.0
    image_tensor = t.from_numpy(image_chw_pad_flip)
    image_tensor = normalize(image_tensor).unsqueeze(0)
    output2 = pose_model(t.autograd.Variable(image_tensor.cuda(), volatile=True))

    # upscale to input resolution, cut off the padded region, resize to the
    # original image shape (heatmaps)
    heatmap1 = np.transpose(output1[-1][0][:15].data.cpu().numpy(), [1, 2, 0])
    heatmap2 = np.transpose(output2[-1][0][:15].data.cpu().numpy(), [1, 2, 0])
    heatmap1 = cv.resize(heatmap1, (0, 0), fx=stride, fy=stride, interpolation=cv.INTER_CUBIC)
    heatmap1 = heatmap1[:image_hwc_padded.shape[0]-pad1[2], :image_hwc_padded.shape[1]-pad1[3], :]
    heatmap1 = cv.resize(heatmap1, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)
    heatmap2 = cv.resize(heatmap2, (0, 0), fx=stride, fy=stride, interpolation=cv.INTER_CUBIC)
    heatmap2 = heatmap2[:image_hwc_pad_flip.shape[0]-pad2[2], :image_hwc_pad_flip.shape[1]-pad2[3], :]
    heatmap2 = cv.resize(heatmap2, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)

    # same for the PAFs
    paf1 = np.transpose(output1[-1][0][15:].data.cpu().numpy(), [1, 2, 0])
    paf2 = np.transpose(output2[-1][0][15:].data.cpu().numpy(), [1, 2, 0])
    paf1 = cv.resize(paf1, (0, 0), fx=stride, fy=stride, interpolation=cv.INTER_CUBIC)
    paf1 = paf1[:image_hwc_padded.shape[0]-pad1[2], :image_hwc_padded.shape[1]-pad1[3], :]
    paf1 = cv.resize(paf1, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)
    paf2 = cv.resize(paf2, (0, 0), fx=stride, fy=stride, interpolation=cv.INTER_CUBIC)
    paf2 = paf2[:image_hwc_pad_flip.shape[0]-pad2[2], :image_hwc_pad_flip.shape[1]-pad2[3], :]
    paf2 = cv.resize(paf2, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)

    return (heatmap1, heatmap2), (paf1, paf2)
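
# A usage sketch (illustration only): `pose_model` is hypothetical, and
# stride=8 assumes the network downsamples its input by 8 -- adjust both to
# the real model. Runs the network on an image and on its horizontal flip,
# then merges the two predictions before peak finding.
def _demo_merge_predictions(pose_model, img_path):
    oriImg = cv.imread(img_path)
    (h1, h2), (p1, p2) = get_output(pose_model, scale=1.0, oriImg=oriImg, stride=8)
    heatmap_avg = flip_heatmap(h1, h2)  # 15-channel merged heatmap
    paf_avg = flip_paf(p1, p2)          # 26-channel merged PAF
    return heatmap_avg, paf_avg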

def find_peaks(heatmap_avg, hp):
    """
    Extract real keypoints from the heatmap:
    - smooth with a Gaussian filter
    - a peak must be >= each of its four neighbours
    - a peak must exceed the threshold hp.h_thre
    @param heatmap_avg: merged heatmap
    @param hp: hyper parameters (uses hp.sigma and hp.h_thre)
    @return all_peaks: list of length 14 (one entry per keypoint type); each
        entry is a list of (coord_x, coord_y, score, id) tuples, e.g.
        [(119, 355, 0.7165, 0), (330, 393, 0.7110, 1)]
    @return peak_counter: total number of peaks; one keypoint type can have
        several peaks because the image may contain several people
    """
    all_peaks = []
    peak_counter = 0

    for part in range(14):
        map_ori = heatmap_avg[:, :, part]
        map = gaussian_filter(map_ori, hp.sigma)

        map_left = np.zeros(map.shape)
        map_left[1:, :] = map[:-1, :]
        map_right = np.zeros(map.shape)
        map_right[:-1, :] = map[1:, :]
        map_up = np.zeros(map.shape)
        map_up[:, 1:] = map[:, :-1]
        map_down = np.zeros(map.shape)
        map_down[:, :-1] = map[:, 1:]
        # keep points that dominate their four neighbours and exceed h_thre
        peaks_binary = np.logical_and.reduce(
            (map >= map_left, map >= map_right, map >= map_up, map >= map_down, map > hp.h_thre))
        # note the order: nonzero returns (row, col) = (y, x)
        peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))
        ll = len(peaks)

        peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
        ids = range(peak_counter, peak_counter + ll)
        peaks_with_score_and_id = [peaks_with_score[i] + (id_,) for i, id_ in enumerate(ids)]

        all_peaks.append(peaks_with_score_and_id)
        peak_counter += ll
    return all_peaks, peak_counter
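
# A toy check (illustration only): a single bright spot in the head channel
# should yield exactly one peak. sigma is lowered because the toy map is tiny.
def _demo_find_peaks():
    hp = HyperParameter()
    hp.sigma = 1
    heatmap = np.zeros((32, 32, 15))
    heatmap[10, 20, 12] = 1.0  # one 'head' response at (x=20, y=10)
    all_peaks, n = find_peaks(heatmap, hp)
    # channel 12 ('head') holds one peak, reported as (x, y, score, id)
    assert n == 1 and all_peaks[12][0][:2] == (20, 10)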

def find_connection(all_peaks, paf_avg, hp):
    """
    @param all_peaks: list; each peak is a (coord_x, coord_y, score, id) tuple
    @param paf_avg: ndarray of shape h x w x 26
    @return special_k: list of limb indices k for which no connection was found
    @return connection_all: list of length 13; element k holds the accepted
        connections for limb type k (one limb type can occur several times,
        once per person). Each connection row is
        [peak id of endpoint A, peak id of endpoint B, score, index of A, index of B]
    """
    # 13 = top of head, 14 = neck; head->neck counts as one of the torso limbs
    EPS = 1e-30
    # find connections in the specified sequence
    limbSeq = [[13,14], [1,14], [4,14], [1,2], [2,3], [4,5], [5,6], [1,7],
               [7,8], [8,9], [4,10], [10,11], [11,12]]
    # PAF channel pair belonging to each limb
    mapIdx = [[i, i+1] for i in range(0, 26, 2)]

    connection_all = []
    special_k = []
    mid_num = 11  # number of points sampled along a candidate limb

    for k in range(len(mapIdx)):
        score_mid = paf_avg[:, :, [x for x in mapIdx[k]]]
        candA = all_peaks[limbSeq[k][0]-1]  # all peaks of the A keypoint type
        candB = all_peaks[limbSeq[k][1]-1]

        nA = len(candA)
        nB = len(candB)
        # score every A-B pair
        if nA != 0 and nB != 0:
            connection_candidate = []
            for i in range(nA):
                for j in range(nB):
                    vec = np.subtract(candB[j][:2], candA[i][:2])
                    norm = math.sqrt(vec[0]*vec[0] + vec[1]*vec[1]) + 1e-100
                    vec = np.divide(vec, norm)  # unit vector
                    # points sampled on the segment between the two peaks, each (x, y)
                    startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num),
                                        np.linspace(candA[i][1], candB[j][1], num=mid_num)))
                    lll = len(startend)
                    vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0]
                                      for I in range(lll)])
                    vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1]
                                      for I in range(lll)])
                    # dot product between the PAF and the limb direction
                    score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                    # E from Eq. 10 of the paper, with a penalty for limbs
                    # longer than half the image height
                    score_with_dist_prior = sum(score_midpts)/(len(score_midpts)+EPS) \
                        + min(0.5*paf_avg.shape[0]/(norm+EPS)-1, 0)
                    # at least conn_thre (80%) of the sampled points must agree
                    # with the limb direction
                    criterion1 = len(np.nonzero(score_midpts > hp.p_thre)[0]) > hp.conn_thre * len(score_midpts)
                    # the PAF along the segment must point the same way overall
                    criterion2 = score_with_dist_prior > 0

                    if criterion1 and criterion2:
                        connection_candidate.append(
                            [i, j, score_with_dist_prior, score_with_dist_prior+candA[i][2]+candB[j][2]])

            connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
            connection = np.zeros((0, 5))
            for c in range(len(connection_candidate)):
                # greedily keep the best-scoring connections, never reusing a
                # peak (similar to NMS): each peak joins at most one limb of
                # this type
                i, j, s = connection_candidate[c][0:3]
                if i not in connection[:, 3] and j not in connection[:, 4]:
                    connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                    if len(connection) >= min(nA, nB):
                        break
            connection_all.append(connection)
        else:
            special_k.append(k)
            connection_all.append([])

    return special_k, connection_all


def save_json(d, fname):
    with open(fname, 'w') as f:
        json.dump(d, f)


def get_result(subsets, candidate, image_id):
    '''
    Build the submission dictionary for one image.
    @param subsets: people found by find_person
    @param candidate: all peaks
    @param image_id: image id
    @return img_result: dict in the submission format
    '''
    keypoint = {}
    img_result = {}
    img_result['image_id'] = image_id

    for ii, sub in enumerate(subsets):
        parts = sub[:-2]
        key_ii = []
        for jj, part in enumerate(parts):
            if part > -1:
                pp = candidate[int(part)]
                key_ii += [int(pp[0]), int(pp[1]), 1]
            else:
                key_ii += [0, 0, 0]
        keypoint["human" + str(ii + 1)] = key_ii
    img_result['keypoint_annotations'] = keypoint
    return img_result
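
# A small sketch (illustration only) of the submission format produced by
# get_result: one fully connected person whose 14 parts point at peaks 0..13
# in `candidate`. All values and the image id are fabricated.
def _demo_get_result():
    candidate = np.array([[10.0*i, 20.0*i, 0.9, i] for i in range(14)])  # x, y, score, id
    person = np.hstack([np.arange(14), [12.5, 14]])  # 14 peak ids + score + part count
    img_result = get_result([person], candidate, image_id='demo_image')
    # -> {'image_id': 'demo_image',
    #     'keypoint_annotations': {'human1': [x1, y1, 1, ..., x14, y14, 1]}}
    save_json([img_result], 'demo_pred.json')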

def find_person(all_peaks, special_k, connection_all, hp):
    """
    Assemble limbs into people.
    @param all_peaks: all detected peaks
    @param special_k: limb types for which no connection was found
    @param connection_all: accepted connections per limb type
    @return subset: ndarray of shape human_num x 16; entry [i][j] is the index
        into `candidate` of person i's j-th keypoint (-1 if not found); the
        last two columns are the person's score and the number of parts found
    @return candidate: all peaks (one keypoint type can have several)
    @return rrs: intermediate snapshots of subset, one per limb type
    """
    # 14 keypoints + the person's score + how many keypoints were found
    subset = -1 * np.ones((0, 16))
    # ordered so that every limb shares at least one endpoint with an earlier one
    limbSeq = [[13,14], [1,14], [4,14], [1,2], [2,3], [4,5], [5,6], [1,7],
               [7,8], [8,9], [4,10], [10,11], [11,12]]
    mapIdx = [[i, i+1] for i in range(0, 26, 2)]
    rrs = []
    candidate = np.array([item for sublist in all_peaks for item in sublist])
    for k in range(len(mapIdx)):
        rrs.append(subset)
        if k not in special_k:
            partAs = connection_all[k][:, 0]  # peak ids of the A endpoints
            partBs = connection_all[k][:, 1]  # peak ids of the B endpoints
            indexA, indexB = np.array(limbSeq[k]) - 1

            for i in range(len(connection_all[k])):
                found = 0
                subset_idx = [-1, -1]  # which existing people share an endpoint
                for j in range(len(subset)):
                    # NOTE: found can never exceed 2; people sharing a part are
                    # merged immediately, so no part ever stays shared
                    if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                        subset_idx[found] = j
                        found += 1

                if found == 1:
                    j0 = subset_idx[0]
                    if subset[j0][indexB] != partBs[i]:
                        # endpoint A already belongs to this person; attach the new B
                        subset[j0][indexB] = partBs[i]
                        subset[j0][-1] += 1
                        subset[j0][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    else:
                        # endpoint B already belongs to this person; attach the new A
                        subset[j0][indexA] = partAs[i]
                        subset[j0][-1] += 1
                        subset[j0][-2] += candidate[partAs[i].astype(int), 2] + connection_all[k][i][2]
                elif found == 2:
                    # the limb bridges two partial people; if they are disjoint,
                    # merge them into one person
                    j1, j2 = subset_idx
                    membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                    if len(np.nonzero(membership == 2)[0]) == 0:
                        # the two share no peak (every keypoint was found by only
                        # one of them), e.g. a person split by occlusion
                        # NOTE: with only 13 limb types (a spanning tree, the
                        # minimum) the two halves can never reconnect, so this
                        # branch cannot actually trigger here
                        subset[j1][:-2] += (subset[j2][:-2] + 1)
                        subset[j1][-2:] += subset[j2][-2:]
                        subset[j1][-2] += connection_all[k][i][2]
                        subset = np.delete(subset, j2, 0)
                    else:  # treat like found == 1
                        # two candidates for the same part (e.g. two heads) can
                        # never merge into one person; assign to the first one
                        # TODO this heuristic could be improved
                        subset[j1][indexB] = partBs[i]
                        subset[j1][-1] += 1
                        subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                # if neither endpoint belongs to anyone yet, start a new person
                elif not found and k < 13:
                    row = -1 * np.ones(16)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    row[-1] = 2  # number of keypoints found so far
                    # person score = peak scores + limb score
                    row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                    subset = np.vstack([subset, row])

    # optionally drop people with too few parts or too low an average score
    # (disabled in the original; hp.min_limb_num / hp.min_avg_score control it):
    # keep = [i for i in range(len(subset))
    #         if subset[i][-1] >= hp.min_limb_num
    #         and subset[i][-2] / subset[i][-1] >= hp.min_avg_score]
    # subset = subset[keep]
    return subset, candidate, rrs


def get_dataloader(test_files, num_workers=2):
    dataset = Dataset(test_files)
    dataloader = t.utils.data.DataLoader(dataset, num_workers=num_workers)
    return dataloader


class Dataset:

    def __init__(self, files):
        self.files = files

    def __getitem__(self, index):
        _d = np.load(self.files[index])
        heatmap, paf = _d['heatmap'], _d['paf']
        return (heatmap, paf, self.files[index])

    def __len__(self):
        return len(self.files)
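
# An end-to-end sketch (illustration only; the .npz paths are hypothetical and
# assume files that store 'heatmap' and 'paf' arrays, as Dataset expects):
# decode every stored prediction and write a submission file.
def _demo_decode_all(test_files):
    hp = HyperParameter()
    results = []
    for heatmap, paf, fname in get_dataloader(test_files):
        # the DataLoader wraps arrays in batched tensors; unwrap to numpy
        r = find(heatmap[0].numpy(), paf[0].numpy(), fname[0], hp)
        results.append(r.pred)
    save_json(results, 'submission.json')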
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
import torch


def get_optimizer(model, lr, lr1=5e-5, weight_decay=1e-5):
    # model0 and every stage's pre_model train with the smaller rate lr1;
    # the cmp_model / paf_model heads of each stage use the base lr
    optimizer = torch.optim.Adam([
        {'params': model.model0.model.parameters(), 'lr': lr1},
        {'params': model.model1.pre_model.parameters(), 'lr': lr1},
        {'params': model.model1.cmp_model.parameters()},
        {'params': model.model1.paf_model.parameters()},

        {'params': model.model2.pre_model.parameters(), 'lr': lr1},
        {'params': model.model2.cmp_model.parameters()},
        {'params': model.model2.paf_model.parameters()},

        {'params': model.model3.pre_model.parameters(), 'lr': lr1},
        {'params': model.model3.cmp_model.parameters()},
        {'params': model.model3.paf_model.parameters()},

        {'params': model.model4.pre_model.parameters(), 'lr': lr1},
        {'params': model.model4.cmp_model.parameters()},
        {'params': model.model4.paf_model.parameters()},

        {'params': model.model5.pre_model.parameters(), 'lr': lr1},
        {'params': model.model5.cmp_model.parameters()},
        {'params': model.model5.paf_model.parameters()},

        {'params': model.model6.pre_model.parameters(), 'lr': lr1},
        {'params': model.model6.cmp_model.parameters()},
        {'params': model.model6.paf_model.parameters()},
    ], lr=lr, weight_decay=weight_decay, betas=(0.9, 0.99))
    return optimizer


def get_yellow(model, lr):
    # YellowFin tunes its own learning rate, so lr is unused here;
    # imported lazily because it needs utils/yellowfin.py
    from .yellowfin import YFOptimizer
    return YFOptimizer(model.parameters(), weight_decay=1e-4)
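
# A usage sketch (illustration only; `PoseModel` stands in for the real model
# class from models/model.py, whose stages expose pre_model / cmp_model /
# paf_model as assumed above):
#
#   model = PoseModel().cuda()
#   optimizer = get_optimizer(model, lr=1e-4, lr1=5e-5)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()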
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
#coding:utf8
import time
import visdom
import torch
import torchvision as tv
import numpy as np
from PIL import Image
from skimage import transform

to_pil = tv.transforms.ToPILImage()
to_ten = tv.transforms.ToTensor()

def mask_img(img, mask, scale=8):
    """Overlay a low-resolution mask on an image, upscaling it by `scale`."""
    img = to_pil(img)
    mask = np.abs(mask.cpu().numpy())

    mask = mask / mask.max()
    alpha = transform.rescale((mask > 0.5) * 1., scale, mode='reflect') * 192
    alpha = np.uint8(alpha)
    alpha = Image.fromarray(alpha)

    mask_ = np.uint8(transform.rescale(mask, scale, mode='reflect') * 255)
    mask_ = Image.fromarray(mask_)
    masked_img = Image.composite(mask_, img, alpha)

    return to_ten(masked_img)

class Visualizer():
    '''
    Wraps the basic visdom operations; the native visdom interface is still
    available through `self.vis.function`.
    '''

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, **kwargs)

        # index of the next point of each plot, i.e. the x coordinate;
        # stores ('loss', 23), meaning the 23rd point of the loss curve
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        '''
        Change the visdom configuration.
        '''
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        '''
        Plot several values at once.
        @param d: dict (name, value), e.g. ('loss', 0.11)
        '''
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        '''
        self.plot('loss', 1.00)
        '''
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=str(name),
                      opts=dict(title=name),
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        '''
        self.img('input_img', torch.Tensor(64, 64))
        '''
        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.image(img_.cpu(),
                       win=str(name),
                       opts=dict(title=name)
                       )

    def img_grid_many(self, d):
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        '''
        Turn a batch of images into one grid image; e.g. an input of shape
        (36, 64, 64) becomes a 6x6 grid whose cells are 64x64 each.
        '''
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log_text'):
        '''
        self.log({'loss': 1, 'lr': 0.0001})
        '''
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win=win)

    def __getattr__(self, name):
        return getattr(self.vis, name)

--------------------------------------------------------------------------------
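# A usage sketch for utils/visualize.py (illustration only; assumes a visdom
# server is running, e.g. started with `python -m visdom.server`):
#
#   vis = Visualizer(env='keypoint')
#   vis.plot('loss', 0.42)                      # appends one point to the curve
#   vis.img('heatmap', torch.rand(3, 64, 64))   # shows an image tensor
#   vis.log({'epoch': 1, 'lr': 1e-4})           # timestamped HTML log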