├── README.md
├── demo.py
├── images
    └── timg.jpeg
└── src
    ├── __init__.py
    ├── breakout25.py
    ├── model.py
    ├── torch_openpose.py
    └── util.py


/README.md:
--------------------------------------------------------------------------------
# pytorch_openpose_body_25
PyTorch implementation of OpenPose body pose estimation for both the COCO (18 keypoints) and BODY_25 models. The PyTorch weights are converted directly from the OpenPose Caffe models with caffemodel2pytorch. I implemented the BODY_25 network and worked out which output channels of the Part Confidence Maps and Part Affinity Fields correspond to which body parts and limbs. Some code comes from PyTorch OpenPose, with several bugs fixed.

Download the PyTorch models and put them in a `model` folder (create it yourself), so that the files are `model/body_25.pth` and `model/body_coco.pth`.

# Demo:

    python3 demo.py images/timg.jpeg

The image path argument is optional and defaults to `images/timg.jpeg`.

# Downloads:
* [body_25](https://pan.baidu.com/s/1CopeW-Em4Tm9H-Wl_hzVfg) extraction code: 9g4p ; [Google Drive](https://drive.google.com/file/d/1ghXakEXhBMCdV78K6tCFTPp_vjJDWmcE/view?usp=sharing)
* [body_coco](https://pan.baidu.com/s/19Hjo5qEsNPoRt6zY6Ly4Lw) extraction code: kav3 ; [Google Drive](https://drive.google.com/file/d/1VPiIxXk5KWEwdJlVVe5PDQ1QufMS1Zpk/view?usp=sharing)


## References
* [OpenPose paper](https://arxiv.org/abs/1812.08008)
* [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose)
* [PyTorch OpenPose](https://github.com/Hzzone/pytorch-openpose)

## License
* [OpenPose License](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/LICENSE)

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 28 19:17:47 2020

@author: joe

Run BODY_25 pose estimation on one image and display the drawn skeletons.

Usage: python3 demo.py [image_path]   (default: images/timg.jpeg)
"""

import sys

import cv2

from src import torch_openpose, util

if __name__ == "__main__":
    # The README tells users to pass the image path; fall back to the bundled
    # sample image when no argument is given.
    img_path = sys.argv[1] if len(sys.argv) > 1 else "images/timg.jpeg"
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        sys.exit("cannot read image: %s" % img_path)
    tp = torch_openpose.torch_openpose('body_25')
    poses = tp(img)
    img = util.draw_bodypose(img, poses, 'body_25')
    cv2.imshow('v', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/images/timg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beingjoey/pytorch_openpose_body_25/066c67dfe954230da1e5673ec05a27192229b012/images/timg.jpeg
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 28 18:47:14 2020

@author: joe
"""



#__all__ = ["torch_openpose", "util", "model"]
--------------------------------------------------------------------------------
/src/breakout25.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 28 15:13:51 2020

@author: joe

Helper for working out which pair of BODY_25 PAF channels belongs to which
limb: for every limb it tries all 26 candidate channel pairs and keeps the
one whose field best aligns with the segment between the two detected joints.
"""
import numpy as np
import math

# BODY_25 limbs as [joint A, joint B] index pairs.
jointpairs = [[1, 0], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [10, 11],
              [8, 12], [12, 13], [13, 14], [0, 15], [0, 16],
              [15, 17], [16, 18], [11, 24], [11, 22], [14, 21], [14, 19], [22, 23], [19, 20]]

# PAF channel pairs matching the limbs above (the values used as mapIdx in
# torch_openpose.py). First 16 limbs:
#[[1,0], [1,2], [2,3], [3,4], [1,5], [5,6], [6,7], [1,8], [8,9], [9,10],[10,11], [8,12],[12,13], [13,14], [0,15], [0,16]]
#[[30, 31],[14, 15],[16, 17],[18, 19],[22, 23],[24, 25],[26, 27],[0, 1],[6, 7],[2, 3],[4, 5], [8, 9],[10, 11],[12, 13],[32, 33],[34, 35]]
# Remaining 8 limbs:
#[[15,17],[16,18],[11,24],[11,22],[14,21],[14,19],[22,23],[19,20]]
#[[36,37],[38,39],[50,51],[46,47],[44,45],[40,41],[48,49],[42,43]]

# Every consecutive (x, y) channel pair of the 52-channel BODY_25 PAF output.
map25 = [[i, i + 1] for i in range(0, 52, 2)]


def findoutmappair(all_peaks, paf):
    """For each limb in ``jointpairs``, find the PAF channel pair that best matches it.

    Only the first peak of each joint is used (presumably intended for
    single-person images).

    Args:
        all_peaks: per-joint sequences of peaks; each peak starts with (x, y).
        paf: PAF array indexed as ``paf[y, x, channel]``.

    Returns:
        A list with one entry per limb in ``jointpairs``: ``[ch_x, ch_y, score]``
        for the best-aligned channel pair when its summed alignment score is
        above 0.5, otherwise ``[]`` (also ``[]`` when either joint is missing).
    """
    mid_num = 10  # number of points sampled along each limb
    pairmap = []
    for joint_a, joint_b in jointpairs:
        candA = all_peaks[joint_a]
        candB = all_peaks[joint_b]
        if len(candA) == 0 or len(candB) == 0:
            pairmap.append([])
            continue
        candA = candA[0]
        candB = candB[0]
        startend = list(zip(np.linspace(candA[0], candB[0], num=mid_num),
                            np.linspace(candA[1], candB[1], num=mid_num)))
        # The sample pixels do not depend on the channel pair, so compute them once.
        xs = [int(round(px)) for px, _ in startend]
        ys = [int(round(py)) for _, py in startend]

        vec = np.subtract(candB[:2], candA[:2])
        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
        # Coincident joints would divide by zero (NaN + RuntimeWarning); with the
        # guard a zero-length limb simply scores 0 and is rejected below.
        vec = np.divide(vec, max(norm, 1e-3))

        score = 0.
        tmp = []
        for mp in map25:
            vec_x = paf[ys, xs, mp[0]]
            vec_y = paf[ys, xs, mp[1]]
            # Sum of the projections of the field onto the unit limb direction.
            score_midpts = (np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])).sum()
            if score < score_midpts:
                score = score_midpts
                tmp = mp
        if score > 0.5:
            pairmap.append(tmp + [score])
        else:
            pairmap.append([])
    return pairmap

--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
import torch
from collections import OrderedDict

import torch.nn as nn


def make_layers(block, no_relu_layers, prelu_layers=()):
    """Build an ``nn.Sequential`` from an ordered ``{layer_name: spec}`` mapping.

    Names containing ``'pool'`` become ``nn.MaxPool2d`` from a
    ``[kernel, stride, padding]`` spec; every other name becomes ``nn.Conv2d``
    from an ``[in, out, kernel, stride, padding]`` spec. Each convolution is
    followed by an activation unless its name is in ``no_relu_layers``:
    ``nn.PReLU(out)`` when the name is in ``prelu_layers``, ``nn.ReLU``
    otherwise.

    The PReLU submodules carry parameters, so their generated names
    (``'prelu' + layer_name[4:]``) are part of the state dict loaded by
    torch_openpose and must not change.
    """
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
                                 padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
                               kernel_size=v[2], stride=v[3],
                               padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                if layer_name not in prelu_layers:
                    layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
                else:
                    # e.g. 'conv4_2' -> 'prelu4_2'
                    layers.append(('prelu' + layer_name[4:], nn.PReLU(v[1])))

    return nn.Sequential(OrderedDict(layers))


def make_layers_Mconv(block, no_relu_layers):
    """Build an ``nn.ModuleList`` holding one ``nn.Sequential`` per entry of *block*.

    Unlike ``make_layers``, each layer stays a separate list item, so a caller
    can collect every intermediate output (``bodypose_25_model._Mconv_forward``
    concatenates them). Convolutions not in ``no_relu_layers`` are followed by
    ``nn.PReLU`` named ``'Mprelu' + layer_name[5:]``
    (e.g. 'Mconv1_stage0_L2_0' -> 'Mprelu1_stage0_L2_0'); these names are
    state-dict keys and must not change.
    """
    modules = []
    for layer_name, v in block.items():
        layers = []
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
                                 padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
                               kernel_size=v[2], stride=v[3],
                               padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('Mprelu' + layer_name[5:], nn.PReLU(v[1])))
        modules.append(nn.Sequential(OrderedDict(layers)))
    return nn.ModuleList(modules)


class bodypose_25_model(nn.Module):
    """OpenPose BODY_25 network.

    A VGG-style backbone (``model0``, 128 output channels, three stride-2
    pools) feeds four refinement stages of the PAF branch (``*_L2``,
    52 channels) followed by two stages of the confidence-map branch
    (``*_L1``, 26 channels). Every ``MconvN`` block is three 3x3
    convolutions whose outputs are concatenated.
    """

    def __init__(self):
        super(bodypose_25_model, self).__init__()
        # Final 1x1 conv of each stage: these layers have no activation.
        no_relu_layers = ['Mconv7_stage0_L1', 'Mconv7_stage0_L2',
                          'Mconv7_stage1_L1', 'Mconv7_stage1_L2',
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L2']
        prelu_layers = ['conv4_2', 'conv4_3_CPM', 'conv4_4_CPM']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])
        self.model0 = make_layers(block0, no_relu_layers, prelu_layers)

        # L2 (PAF) branch
        # stage0: input is the 128-channel backbone output
        blocks['Mconv1_stage0_L2'] = OrderedDict([
            ('Mconv1_stage0_L2_0', [128, 96, 3, 1, 1]),
            ('Mconv1_stage0_L2_1', [96, 96, 3, 1, 1]),
            ('Mconv1_stage0_L2_2', [96, 96, 3, 1, 1])
        ])
        for i in range(2, 6):
            blocks['Mconv%d_stage0_L2' % i] = OrderedDict([
                ('Mconv%d_stage0_L2_0' % i, [288, 96, 3, 1, 1]),
                ('Mconv%d_stage0_L2_1' % i, [96, 96, 3, 1, 1]),
                ('Mconv%d_stage0_L2_2' % i, [96, 96, 3, 1, 1])
            ])
        blocks['Mconv6_7_stage0_L2'] = OrderedDict([
            ('Mconv6_stage0_L2', [288, 256, 1, 1, 0]),
            ('Mconv7_stage0_L2', [256, 52, 1, 1, 0])
        ])
        # stage1~3: input is backbone (128) + previous PAF (52) = 180 channels
        for s in range(1, 4):
            blocks['Mconv1_stage%d_L2' % s] = OrderedDict([
                ('Mconv1_stage%d_L2_0' % s, [180, 128, 3, 1, 1]),
                ('Mconv1_stage%d_L2_1' % s, [128, 128, 3, 1, 1]),
                ('Mconv1_stage%d_L2_2' % s, [128, 128, 3, 1, 1])
            ])
            for i in range(2, 6):
                blocks['Mconv%d_stage%d_L2' % (i, s)] = OrderedDict([
                    ('Mconv%d_stage%d_L2_0' % (i, s), [384, 128, 3, 1, 1]),
                    ('Mconv%d_stage%d_L2_1' % (i, s), [128, 128, 3, 1, 1]),
                    ('Mconv%d_stage%d_L2_2' % (i, s), [128, 128, 3, 1, 1])
                ])
            blocks['Mconv6_7_stage%d_L2' % s] = OrderedDict([
                ('Mconv6_stage%d_L2' % s, [384, 512, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % s, [512, 52, 1, 1, 0])
            ])

        # L1 (confidence-map) branch
        # stage0: input is backbone (128) + final PAF (52) = 180 channels
        blocks['Mconv1_stage0_L1'] = OrderedDict([
            ('Mconv1_stage0_L1_0', [180, 96, 3, 1, 1]),
            ('Mconv1_stage0_L1_1', [96, 96, 3, 1, 1]),
            ('Mconv1_stage0_L1_2', [96, 96, 3, 1, 1])
        ])
        for i in range(2, 6):
            blocks['Mconv%d_stage0_L1' % i] = OrderedDict([
                ('Mconv%d_stage0_L1_0' % i, [288, 96, 3, 1, 1]),
                ('Mconv%d_stage0_L1_1' % i, [96, 96, 3, 1, 1]),
                ('Mconv%d_stage0_L1_2' % i, [96, 96, 3, 1, 1])
            ])
        blocks['Mconv6_7_stage0_L1'] = OrderedDict([
            ('Mconv6_stage0_L1', [288, 256, 1, 1, 0]),
            ('Mconv7_stage0_L1', [256, 26, 1, 1, 0])
        ])
        # stage1: input is backbone (128) + stage0 maps (26) + PAF (52) = 206 channels
        blocks['Mconv1_stage1_L1'] = OrderedDict([
            ('Mconv1_stage1_L1_0', [206, 128, 3, 1, 1]),
            ('Mconv1_stage1_L1_1', [128, 128, 3, 1, 1]),
            ('Mconv1_stage1_L1_2', [128, 128, 3, 1, 1])
        ])
        for i in range(2, 6):
            blocks['Mconv%d_stage1_L1' % i] = OrderedDict([
                ('Mconv%d_stage1_L1_0' % i, [384, 128, 3, 1, 1]),
                ('Mconv%d_stage1_L1_1' % i, [128, 128, 3, 1, 1]),
                ('Mconv%d_stage1_L1_2' % i, [128, 128, 3, 1, 1])
            ])
        blocks['Mconv6_7_stage1_L1'] = OrderedDict([
            ('Mconv6_stage1_L1', [384, 512, 1, 1, 0]),
            ('Mconv7_stage1_L1', [512, 26, 1, 1, 0])
        ])

        for k in blocks.keys():
            blocks[k] = make_layers_Mconv(blocks[k], no_relu_layers)
        self.models = nn.ModuleDict(blocks)

    def _Mconv_forward(self, x, models):
        """Run *models* in sequence and concatenate all of their outputs on dim 1."""
        outs = []
        out = x
        for m in models:
            out = m(out)
            outs.append(out)
        return torch.cat(outs, 1)

    def forward(self, x):
        """Compute (heatmaps, pafs) for an image batch *x* of shape (N, 3, H, W).

        Returns the 26-channel L1 output and the 52-channel L2 output, both
        at 1/8 of the input resolution (three stride-2 pools in ``model0``).
        """
        out0 = self.model0(x)
        # L2 (PAF) stages 0-3; each later stage sees backbone + previous PAF.
        tout = out0
        for s in range(4):
            tout = self._Mconv_forward(tout, self.models['Mconv1_stage%d_L2' % s])
            for v in range(2, 6):
                tout = self._Mconv_forward(tout, self.models['Mconv%d_stage%d_L2' % (v, s)])
            tout = self.models['Mconv6_7_stage%d_L2' % s][0](tout)
            tout = self.models['Mconv6_7_stage%d_L2' % s][1](tout)
            outL2 = tout
            tout = torch.cat([out0, tout], 1)
        # L1 stage0: input is cat(backbone, final PAF)
        tout = self._Mconv_forward(tout, self.models['Mconv1_stage0_L1'])
        for v in range(2, 6):
            tout = self._Mconv_forward(tout, self.models['Mconv%d_stage0_L1' % v])
        tout = self.models['Mconv6_7_stage0_L1'][0](tout)
        tout = self.models['Mconv6_7_stage0_L1'][1](tout)
        outS0L1 = tout
        tout = torch.cat([out0, outS0L1, outL2], 1)
        # L1 stage1
        tout = self._Mconv_forward(tout, self.models['Mconv1_stage1_L1'])
        for v in range(2, 6):
            tout = self._Mconv_forward(tout, self.models['Mconv%d_stage1_L1' % v])
        tout = self.models['Mconv6_7_stage1_L1'][0](tout)
        outS1L1 = self.models['Mconv6_7_stage1_L1'][1](tout)

        return outS1L1, outL2



class bodypose_model(nn.Module):
    """OpenPose COCO body network.

    A VGG-style backbone (``model0``, 128 channels) followed by six stages.
    Each stage has an L1 branch (38 channels, PAFs) and an L2 branch
    (19 channels, confidence maps). From stage 2 on, the input is the previous
    stage's two outputs concatenated with the backbone features
    (38 + 19 + 128 = 185 channels).
    """

    def __init__(self):
        super(bodypose_model, self).__init__()

        # The final conv of every stage outputs raw maps: no ReLU after it.
        # (Previously 'Mconv7_stage6_L1' was listed twice and
        # 'Mconv7_stage6_L2' was missing, so the last heat-map conv got a ReLU.)
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):
        """Return (heatmaps, pafs): the stage-6 L2 (19-ch) and L1 (38-ch) outputs."""
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_2, out6_1


class handpose_model(nn.Module):
    """OpenPose hand network: backbone + stage 1 + five refinement stages.

    Every stage outputs 22 channels. Stages 2-6 take the previous stage's
    output concatenated with the 128-channel backbone features (150 channels).
    Not used by torch_openpose in this repository.
    """

    def __init__(self):
        super(handpose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',
                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
        # stage 1
        block1_0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1])
        ])

        block1_1 = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0])
        ])

        blocks = {}
        blocks['block1_0'] = block1_0
        blocks['block1_1'] = block1_1

        # stage 2-6
        for i in range(2, 7):
            blocks['block%d' % i] = OrderedDict([
                ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
                ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks['block1_0']
        self.model1_1 = blocks['block1_1']
        self.model2 = blocks['block2']
        self.model3 = blocks['block3']
        self.model4 = blocks['block4']
        self.model5 = blocks['block5']
        self.model6 = blocks['block6']

    def forward(self, x):
        """Return the 22-channel output of the last stage."""
        out1_0 = self.model1_0(x)
        out1_1 = self.model1_1(out1_0)
        concat_stage2 = torch.cat([out1_1, out1_0], 1)
        out_stage2 = self.model2(concat_stage2)
        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
        out_stage3 = self.model3(concat_stage3)
        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
        out_stage4 = self.model4(concat_stage4)
        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
        out_stage5 = self.model5(concat_stage5)
        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
        out_stage6 = self.model6(concat_stage6)
        return out_stage6


--------------------------------------------------------------------------------
/src/torch_openpose.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import math
# scipy.ndimage.filters is a deprecated alias namespace; use scipy.ndimage directly.
from scipy.ndimage import gaussian_filter
import torch

from src import util
from src.model import bodypose_model, bodypose_25_model

model_coco = 'model/body_coco.pth'
model_body25 = 'model/body_25.pth'


class torch_openpose(object):
    """Multi-person body pose estimator built on the converted OpenPose models.

    ``model_type == 'body_25'`` loads the BODY_25 network from ``model_body25``.
    Any other value loads the COCO network from ``model_coco``. Call the
    instance on an image to get the detected poses.
    """

    def __init__(self, model_type):
        if model_type == 'body_25':
            self.model = bodypose_25_model()
            self.njoint = 26
            self.npaf = 52
            weights = model_body25
        else:
            self.model = bodypose_model()
            self.njoint = 19
            self.npaf = 38
            weights = model_coco
        # map_location='cpu' lets weights saved from a GPU load on CPU-only
        # machines; the model is moved to the GPU below when one is available.
        self.model.load_state_dict(torch.load(weights, map_location='cpu'))
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

        # limbSeq[k] = [joint A, joint B]; mapIdx[k] = the PAF (x, y) channels for that limb.
        if self.njoint == 19:  # coco
            self.limbSeq = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9],
                            [9, 10], [1, 11], [11, 12], [12, 13], [1, 0], [0, 14], [14, 16],
                            [0, 15], [15, 17]]
            self.mapIdx = [[12, 13], [20, 21], [14, 15], [16, 17], [22, 23], [24, 25], [0, 1], [2, 3],
                           [4, 5], [6, 7], [8, 9], [10, 11], [28, 29], [30, 31], [34, 35], [32, 33],
                           [36, 37]]
        elif self.njoint == 26:  # body_25
            self.limbSeq = [[1, 0], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10],
                            [10, 11], [8, 12], [12, 13], [13, 14], [0, 15], [0, 16], [15, 17], [16, 18],
                            [11, 24], [11, 22], [14, 21], [14, 19], [22, 23], [19, 20]]
            self.mapIdx = [[30, 31], [14, 15], [16, 17], [18, 19], [22, 23], [24, 25], [26, 27], [0, 1], [6, 7],
                           [2, 3], [4, 5], [8, 9], [10, 11], [12, 13], [32, 33], [34, 35], [36, 37], [38, 39],
                           [50, 51], [46, 47], [44, 45], [40, 41], [48, 49], [42, 43]]

    def __call__(self, oriImg):
        """Detect every person's body keypoints in *oriImg*.

        Args:
            oriImg: H x W x 3 uint8 image, e.g. as returned by ``cv2.imread``.

        Returns:
            A list with one entry per detected person. Each entry is a list of
            ``self.njoint - 1`` joints (25 for BODY_25, 18 for COCO); each joint
            is ``[x, y, score]`` in original-image pixels, or ``[0., 0., 0.]``
            when that joint was not found for the person.
        """
        # scale_search = [0.5, 1.0, 1.5, 2.0]
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1   # minimum smoothed heat-map value for a peak
        thre2 = 0.05  # minimum PAF alignment per sample point
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], self.njoint))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], self.npaf))

        # ---- 1. Run the network at every scale and average the resized maps.
        for scale in multiplier:
            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            # HWC uint8 -> 1xCxHxW float32 in [-0.5, 0.5)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            with torch.no_grad():
                heatmap, paf = self.model(data)

            heatmap = heatmap.detach().cpu().numpy()
            paf = paf.detach().cpu().numpy()

            # CHW -> HWC for the single batch item, upsample by the network
            # stride, strip the padding, and resize to the original image size.
            # ([0] instead of np.squeeze so a 1-pixel map dimension is kept.)
            heatmap = np.transpose(heatmap[0], (1, 2, 0))
            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            paf = np.transpose(paf[0], (1, 2, 0))
            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            # Running mean over scales (each scale contributes 1/len(multiplier)).
            heatmap_avg += heatmap / len(multiplier)
            paf_avg += paf / len(multiplier)

        # ---- 2. Find joint candidates as local maxima of each smoothed heat map.
        all_peaks = []
        peak_counter = 0

        # The last channel is skipped (OpenPose's background map).
        for part in range(self.njoint - 1):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            # A peak is >= its four neighbours and above thre1.
            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up,
                 one_heatmap >= map_down, one_heatmap > thre1))
            rows, cols = np.nonzero(peaks_binary)
            peaks = list(zip(cols, rows))  # (x, y): note the reversal
            peaks_with_score = [p + (map_ori[p[1], p[0]],) for p in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            # Each candidate is (x, y, score, global id).
            peaks_with_score_and_id = [p + (pid,) for p, pid in zip(peaks_with_score, peak_id)]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # ---- 3. Score every candidate limb with the PAF line integral.
        limbSeq = self.limbSeq
        # the PAF channel pair for each limb
        mapIdx = self.mapIdx

        connection_all = []
        special_k = []  # limbs with no candidates for one of their joints
        mid_num = 10    # sample points along each candidate limb

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, mapIdx[k]]
            candA = all_peaks[limbSeq[k][0]]
            candB = all_peaks[limbSeq[k][1]]
            nA = len(candA)
            nB = len(candB)
            if nA != 0 and nB != 0:
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        # Two joints on the same pixel give norm == 0, which would
                        # raise ZeroDivisionError in the distance prior below.
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num),
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0]
                                          for I in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1]
                                          for I in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        # Mean alignment, penalised for limbs longer than half the image height.
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                # Greedy matching: best-scoring pairs first, each joint used at most once.
                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))  # rows: [idA, idB, score, i, j]
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if i not in connection[:, 3] and j not in connection[:, 4]:
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if len(connection) >= min(nA, nB):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # ---- 4. Assemble limbs into people.
        # Each subset row holds a candidate id per joint (-1 if missing); the
        # second-to-last entry is the overall score and the last entry is the
        # number of parts found for that person.
        subset = -1 * np.ones((0, self.njoint + 1))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k])

                for i in range(len(connection_all[k])):
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found:
                        row = -1 * np.ones(self.njoint + 1)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])

        # Drop people with fewer than 4 parts or a low average score.
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        poses = []
        for per in subset:
            pose = []
            for po in per[:-2]:
                if po >= 0:
                    joint = list(candidate[int(po)][:3])
                else:
                    joint = [0., 0., 0.]
                pose.append(joint)
            poses.append(pose)

        return poses


--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import cv2


# draw the body keypoints and limbs
def draw_bodypose(img, poses, model_type='coco'):
    """Draw joints and limbs of *poses* on *img* and return the blended image.

    Args:
        img: BGR image. Joint circles are drawn onto this array in place;
            limbs are alpha-blended into new arrays.
        poses: list of per-person joint lists, each joint ``[x, y, score]``;
            joints with ``score <= 0`` are skipped.
        model_type: ``'body_25'`` for BODY_25 poses; anything else is
            treated as COCO (18 joints).

    Returns:
        The image with the skeletons drawn.
    """
    stickwidth = 4

    limbSeq = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9],
               [9, 10], [1, 11], [11, 12], [12, 13], [1, 0], [0, 14], [14, 16],
               [0, 15], [15, 17]]
    njoint = 18
    if model_type == 'body_25':
        limbSeq = [[1, 0], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10],
                   [10, 11], [8, 12], [12, 13], [13, 14], [0, 15], [0, 16], [15, 17], [16, 18],
                   [11, 24], [11, 22], [14, 21], [14, 19], [22, 23], [19, 20]]
        njoint = 25

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 255, 0], [255, 255, 85], [255, 255, 170],
              [255, 255, 255], [170, 255, 255], [85, 255, 255], [0, 255, 255]]
    for i in range(njoint):
        for n in range(len(poses)):
            pose = poses[n][i]
            if pose[2] <= 0:
                continue
            x, y = pose[:2]
            cv2.circle(img, (int(x), int(y)), 4, colors[i], thickness=-1)

    for pose in poses:
        for limb, color in zip(limbSeq, colors):
            p1 = pose[limb[0]]
            p2 = pose[limb[1]]
            if p1[2] <= 0 or p2[2] <= 0:
                continue
            cur_canvas = img.copy()
            # X holds the row (y) coordinates and Y the column (x) coordinates.
            X = [p1[1], p2[1]]
            Y = [p1[0], p2[0]]
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(cur_canvas, polygon, color)
            img = cv2.addWeighted(img, 0.4, cur_canvas, 0.6, 0)

    return img


def padRightDownCorner(img, stride, padValue):
    """Pad *img* on the bottom and right with *padValue* to multiples of *stride*.

    Returns:
        ``(img_padded, pad)`` where ``pad = [up, left, down, right]`` in pixels
        (up and left are always 0).
    """
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad

--------------------------------------------------------------------------------