├── README.md ├── images ├── main.py ├── mesh1.png ├── mesh2.png ├── mesh3.png ├── paper_picked.png └── r.png └── src ├── Discriminator.py ├── HourGlass.py ├── LinearModel.py ├── PRNetEncoder.py ├── Resnet.py ├── SMPL.py ├── config.py ├── dataloader ├── AICH_dataloader.py ├── COCO2017_dataloader.py ├── eval_dataloader.py ├── hum36m_dataloader.py ├── lsp_dataloader.py ├── lsp_ext_dataloader.py ├── mosh_dataloader.py └── mpi_inf_3dhp_dataloader.py ├── densenet.py ├── do_train.sh ├── model.py ├── timer.py ├── trainer.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | ## HMR 2 |

3 | 4 |

5 | 6 | This is a **pytorch** implementation of [End-to-end Recovery of Human Shape and Pose](https://arxiv.org/abs/1712.06584) by *Angjoo Kanazawa, Michael J. Black, David W. Jacobs*, and *Jitendra Malik*, accompanying by some famous human pose estimation networks and datasets. 7 | HMR is an end-to end framework for reconstructing a full 3D mesh of a human body from a single RGB image. In contrast to most current methods that compute 2D or 3D joint locations, HMR produce a richer and more useful mesh representation that is parameterized by shape and 3D joint angles. The main objective is to minimize the reprojection loss of keypoints, which allow model to be trained using in-the-wild images that only have ground truth 2D annotations. For visual impact, please visit the [author's original video.](https://www.youtube.com/watch?v=bmMV9aJKa-c) 8 | 9 | ## training step (the following links are not available now due to license limitation) 10 | #### 1. download the following datasets. 11 | - [AI challenger keypoint dataset](https://challenger.ai/datasets/keypoint) 12 | - [lsp 14-keypoint dataset](https://pan.baidu.com/s/1BgKRJfggJcObHXkzHH5I5A) 13 | - [lsp 14-keypoint extension dataset](https://pan.baidu.com/s/1uUcsdCKbzIwKCc9SzVFXAA) 14 | - [COCO-2017-keypoint dataset](http://cocodataset.org/) 15 | - [mpi_inf_3dhp 3d keypoint dataset](https://pan.baidu.com/s/1XQZNV3KPtiBi5ODnr7RB9A) 16 | - [mosh dataset, which used for adv training](https://pan.baidu.com/s/1OWzeMeLS5tKx1XGAiyZ0XA) 17 | #### 2. download human3.6 datasets. 18 | - [hum3.6m_part_1.zip](https://pan.baidu.com/s/1oeO213vrKyYEr46P1OBEgw) 19 | - [hum3.6m_part_2.zip](https://pan.baidu.com/s/1XRnNn0qJeo5TECacjiJv4g) 20 | - [hum3.6m_part_3.zip](https://pan.baidu.com/s/15AOngXr3zya2XsK7Sry97g) 21 | - [hum3.6m_part_4.zip](https://pan.baidu.com/s/1RNqWSP1KREBhvPHn6-pCbA) 22 | - [hum3.6m_part_5.zip](https://pan.baidu.com/s/109RwxgpWxEraXzIXf7iYkg) 23 | - [hum3.6m_anno.zip](https://pan.baidu.com/s/1kCOQ2qzf69RLX3VN4cw5Mw) 24 | #### 3. unzip the downloaded datasets. 25 | #### 4. unzip the [model.zip](https://pan.baidu.com/s/1PUv5kUydmx5RG1E0KsQBkw) 26 | #### 5. config the environment by modify the src/config.py and do_train.sh 27 | #### 6. run ./do_train.sh directly 28 | 29 | ## environment configurations. 30 | - install **pytorch0.4** 31 | - install torchvision 32 | - install numpy 33 | - install scipy 34 | - install h5py 35 | - install opencv-python 36 | 37 | ## result 38 |

39 | 40 |

41 | 42 | 43 | ## reference papers 44 | - [Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937) 45 | - [SMPL: A Skinned Multi-Person Linear Model](http://files.is.tue.mpg.de/black/papers/SMPL2015.pdf) 46 | - [Keep it SMPL: Automatic Estimation of 3D Human Pose and Shape from a Single Image](https://pdfs.semanticscholar.org/4cea/52b44fc5cb1803a07fa466b6870c25535313.pdf) 47 | - [motion and shape capture from sparse markers](http://files.is.tue.mpg.de/black/papers/MoSh.pdf) 48 | - [Unite the People: Closing the Loop Between 3D and 2D Human Representations](https://arxiv.org/abs/1701.02468) 49 | - [End-to-end Recovery of Human Shape and Pose](https://arxiv.org/abs/1712.06584) 50 | 51 | ## reference resources 52 | - [up-3d dataset](http://files.is.tuebingen.mpg.de/classner/up/) 53 | - [coco-2017 dataset](http://cocodataset.org/) 54 | - [human 3.6m datas](http://vision.imar.ro/human3.6m/description.php) 55 | - [ai challenger dataset](https://challenger.ai/) 56 | 57 | -------------------------------------------------------------------------------- /images/main.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | sa = cv2.imread('./mesh1.png') 6 | sb = cv2.imread('./mesh2.png') 7 | sc = cv2.imread('./mesh3.png') 8 | 9 | da = dst_image = cv2.resize(sa, (300, 400), interpolation = cv2.INTER_CUBIC) 10 | db = dst_image = cv2.resize(sb, (300, 400), interpolation = cv2.INTER_CUBIC) 11 | dc = dst_image = cv2.resize(sc, (300, 400), interpolation = cv2.INTER_CUBIC) 12 | 13 | d = np.zeros((400, 900, 3)) 14 | d[:, :300, :] = da[:, :, :] 15 | d[:, 300:600,:] = db[:, :, :] 16 | d[:, 600:, :] = dc[:, :, :] 17 | 18 | cv2.imwrite('r.png', d) -------------------------------------------------------------------------------- /images/mesh1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh1.png -------------------------------------------------------------------------------- /images/mesh2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh2.png -------------------------------------------------------------------------------- /images/mesh3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh3.png -------------------------------------------------------------------------------- /images/paper_picked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/paper_picked.png -------------------------------------------------------------------------------- /images/r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/r.png -------------------------------------------------------------------------------- /src/Discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: Discriminator.py 4 | 5 | date: 2017_04_29 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | from LinearModel import LinearModel 10 | import config 11 | import util 12 | import torch 13 | import numpy as np 14 | import torch.nn as nn 15 | from config import args 16 | 17 | ''' 18 | shape discriminator is used for shape discriminator 19 | the inputs if N x 10 20 | ''' 21 | class ShapeDiscriminator(LinearModel): 22 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func): 23 | if fc_layers[-1] != 1: 24 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1]) 25 | sys.exit(msg) 26 | 27 | super(ShapeDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func) 28 | 29 | def forward(self, inputs): 30 | return self.fc_blocks(inputs) 31 | 32 | class PoseDiscriminator(nn.Module): 33 | def __init__(self, channels): 34 | super(PoseDiscriminator, self).__init__() 35 | 36 | if channels[-1] != 1: 37 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(channels[-1]) 38 | sys.exit(msg) 39 | 40 | self.conv_blocks = nn.Sequential() 41 | l = len(channels) 42 | for idx in range(l - 2): 43 | self.conv_blocks.add_module( 44 | name = 'conv_{}'.format(idx), 45 | module = nn.Conv2d(in_channels = channels[idx], out_channels = channels[idx + 1], kernel_size = 1, stride = 1) 46 | ) 47 | 48 | self.fc_layer = nn.ModuleList() 49 | for idx in range(23): 50 | self.fc_layer.append(nn.Linear(in_features = channels[l - 2], out_features = 1)) 51 | 52 | # N x 23 x 9 53 | def forward(self, inputs): 54 | batch_size = inputs.shape[0] 55 | inputs = inputs.transpose(1, 2).unsqueeze(2) # to N x 9 x 1 x 23 56 | internal_outputs = self.conv_blocks(inputs) # to N x c x 1 x 23 57 | o = [] 58 | for idx in range(23): 59 | o.append(self.fc_layer[idx](internal_outputs[:,:,0,idx])) 60 | 61 | return torch.cat(o, 1), internal_outputs 62 | 63 | class FullPoseDiscriminator(LinearModel): 64 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func): 65 | if fc_layers[-1] != 1: 66 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1]) 67 | sys.exit(msg) 68 | 69 | super(FullPoseDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func) 70 | 71 | def forward(self, inputs): 72 | return self.fc_blocks(inputs) 73 | 74 | class Discriminator(nn.Module): 75 | def __init__(self): 76 | super(Discriminator, self).__init__() 77 | self._read_configs() 78 | 79 | self._create_sub_modules() 80 | 81 | def _read_configs(self): 82 | self.beta_count = args.beta_count 83 | self.smpl_model = args.smpl_model 84 | self.smpl_mean_theta_path = args.smpl_mean_theta_path 85 | self.total_theta_count = args.total_theta_count 86 | self.joint_count = args.joint_count 87 | self.feature_count = args.feature_count 88 | 89 | def _create_sub_modules(self): 90 | ''' 91 | create theta discriminator for 23 joint 92 | ''' 93 | 94 | self.pose_discriminator = PoseDiscriminator([9, 32, 32, 1]) 95 | 96 | ''' 97 | create full pose discriminator for total 23 joints 98 | ''' 99 | fc_layers = [23 * 32, 1024, 1024, 1] 100 | use_dropout = [False, False, False] 101 | drop_prob = [0.5, 0.5, 0.5] 102 | use_ac_func = [True, True, False] 103 | self.full_pose_discriminator = FullPoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func) 104 | 105 | ''' 106 | shape discriminator for betas 107 | ''' 108 | fc_layers = [self.beta_count, 5, 1] 109 | use_dropout = [False, False] 110 | drop_prob = [0.5, 0.5] 111 | use_ac_func = [True, False] 112 | self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func) 113 | 114 | print('finished create the discriminator modules...') 115 | 116 | 117 | ''' 118 | inputs is N x 85(3 + 72 + 10) 119 | ''' 120 | def forward(self, thetas): 121 | batch_size = thetas.shape[0] 122 | cams, poses, shapes = thetas[:, :3], thetas[:, 3:75], thetas[:, 75:] 123 | shape_disc_value = self.shape_discriminator(shapes) 124 | rotate_matrixs = util.batch_rodrigues(poses.contiguous().view(-1, 3)).view(-1, 24, 9)[:, 1:, :] 125 | pose_disc_value, pose_inter_disc_value = self.pose_discriminator(rotate_matrixs) 126 | full_pose_disc_value = self.full_pose_discriminator(pose_inter_disc_value.contiguous().view(batch_size, -1)) 127 | return torch.cat((pose_disc_value, full_pose_disc_value, shape_disc_value), 1) 128 | 129 | if __name__ == '__main__': 130 | device = torch.device('cuda') 131 | net = Discriminator() 132 | inputs = torch.ones((100, 85)) 133 | disc_value = net(inputs) 134 | print(net) -------------------------------------------------------------------------------- /src/HourGlass.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: hourglass.py 4 | 5 | date: 2018_05_12 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | from __future__ import print_function 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | class Residual(nn.Module): 16 | def __init__(self, use_bn, input_channels, out_channels, mid_channels): 17 | super(Residual, self).__init__() 18 | self.use_bn = use_bn 19 | self.out_channels = out_channels 20 | self.input_channels = input_channels 21 | self.mid_channels = mid_channels 22 | 23 | self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1) 24 | self.AcFunc = nn.ReLU() 25 | if use_bn: 26 | self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels) 27 | self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels) 28 | self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels) 29 | 30 | self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = 3, padding = 1) 31 | 32 | self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size= 1) 33 | 34 | if input_channels != out_channels: 35 | self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1) 36 | 37 | def forward(self, inputs): 38 | x = self.down_channel(inputs) 39 | if self.use_bn: 40 | x = self.bn_0(x) 41 | x = self.AcFunc(x) 42 | 43 | x = self.conv(x) 44 | if self.use_bn: 45 | x = self.bn_1(x) 46 | x = self.AcFunc(x) 47 | 48 | x = self.up_channel(x) 49 | 50 | if self.input_channels != self.out_channels: 51 | x += self.trans(inputs) 52 | else: 53 | x += inputs 54 | 55 | if self.use_bn: 56 | x = self.bn_2(x) 57 | 58 | return self.AcFunc(x) 59 | 60 | class HourGlassBlock(nn.Module): 61 | def __init__(self, block_count, residual_each_block, input_channels, mid_channels, use_bn, stack_index): 62 | super(HourGlassBlock, self).__init__() 63 | 64 | self.block_count = block_count 65 | self.residual_each_block = residual_each_block 66 | self.use_bn = use_bn 67 | self.stack_index = stack_index 68 | self.input_channels = input_channels 69 | self.mid_channels = mid_channels 70 | 71 | if self.block_count == 0: #inner block 72 | self.process = nn.Sequential() 73 | for _ in range(residual_each_block * 3): 74 | self.process.add_module( 75 | name = 'inner_{}_{}'.format(self.stack_index, _), 76 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn) 77 | ) 78 | else: 79 | #down sampling 80 | self.down_sampling = nn.Sequential() 81 | self.down_sampling.add_module( 82 | name = 'down_sample_{}_{}'.format(self.stack_index, self.block_count), 83 | module = nn.MaxPool2d(kernel_size = 2, stride = 2) 84 | ) 85 | for _ in range(residual_each_block): 86 | self.down_sampling.add_module( 87 | name = 'residual_{}_{}_{}'.format(self.stack_index, self.block_count, _), 88 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn) 89 | ) 90 | 91 | #up sampling 92 | self.up_sampling = nn.Sequential() 93 | self.up_sampling.add_module( 94 | name = 'up_sample_{}_{}'.format(self.stack_index, self.block_count), 95 | module = nn.Upsample(scale_factor=2, mode='bilinear') 96 | ) 97 | for _ in range(residual_each_block): 98 | self.up_sampling.add_module( 99 | name = 'residual_{}_{}_{}'.format(self.stack_index, self.block_count, _), 100 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn) 101 | ) 102 | 103 | #sub hour glass 104 | self.sub_hg = HourGlassBlock( 105 | block_count = self.block_count - 1, 106 | residual_each_block = self.residual_each_block, 107 | input_channels = self.input_channels, 108 | mid_channels = self.mid_channels, 109 | use_bn = self.use_bn, 110 | stack_index = self.stack_index 111 | ) 112 | 113 | # trans 114 | self.trans = nn.Sequential() 115 | for _ in range(residual_each_block): 116 | self.trans.add_module( 117 | name = 'trans_{}_{}_{}'.format(self.stack_index, self.block_count, _), 118 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn) 119 | ) 120 | 121 | 122 | def forward(self, inputs): 123 | if self.block_count == 0: 124 | return self.process(inputs) 125 | else: 126 | down_sampled = self.down_sampling(inputs) 127 | transed = self.trans(down_sampled) 128 | sub_net_output = self.sub_hg(down_sampled) 129 | return self.up_sampling(transed + sub_net_output) 130 | 131 | ''' 132 | the input is a 256 x 256 x 3 image 133 | ''' 134 | class HourGlass(nn.Module): 135 | def __init__(self, nStack, nBlockCount, nResidualEachBlock, nMidChannels, nChannels, nJointCount, bUseBn): 136 | super(HourGlass, self).__init__() 137 | 138 | self.nStack = nStack 139 | self.nBlockCount = nBlockCount 140 | self.nResidualEachBlock = nResidualEachBlock 141 | self.nChannels = nChannels 142 | self.nMidChannels = nMidChannels 143 | self.nJointCount = nJointCount 144 | self.bUseBn = bUseBn 145 | 146 | self.pre_process = nn.Sequential( 147 | nn.Conv2d(3, nChannels, kernel_size = 3, padding = 1), 148 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels), 149 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 128 x 128 x c 150 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels), 151 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 64 x 64 x c 152 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels) 153 | ) 154 | 155 | self.hg = nn.ModuleList() 156 | for _ in range(nStack): 157 | self.hg.append( 158 | HourGlassBlock( 159 | block_count = nBlockCount, 160 | residual_each_block = nResidualEachBlock, 161 | input_channels = nChannels, 162 | mid_channels = nMidChannels, 163 | use_bn = bUseBn, 164 | stack_index = _ 165 | ) 166 | ) 167 | 168 | self.blocks = nn.ModuleList() 169 | for _ in range(nStack - 1): 170 | self.blocks.append( 171 | nn.Sequential( 172 | Residual( 173 | use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels 174 | ), 175 | Residual( 176 | use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels 177 | ) 178 | ) 179 | ) 180 | 181 | self.intermediate_supervision = nn.ModuleList() 182 | for _ in range(nStack): # to 64 x 64 x joint_count 183 | self.intermediate_supervision.append( 184 | nn.Conv2d(nChannels, nJointCount, kernel_size = 1, stride = 1) 185 | ) 186 | 187 | self.normal_feature_channel = nn.ModuleList() 188 | for _ in range(nStack - 1): 189 | self.normal_feature_channel.append( 190 | Residual( 191 | use_bn = bUseBn, input_channels = nJointCount, out_channels = nChannels, mid_channels = nMidChannels 192 | ) 193 | ) 194 | 195 | def forward(self, inputs): 196 | o = [] #outputs include intermediate supervision result 197 | x = self.pre_process(inputs) 198 | for _ in range(self.nStack): 199 | o1 = self.hg[_](x) 200 | o2 = self.intermediate_supervision[_](o1) 201 | o.append(o2.view(-1, 4096)) 202 | if _ == self.nStack - 1: 203 | break 204 | o2 = self.normal_feature_channel[_](o2) 205 | o1 = self.blocks[_](o1) 206 | x = o1 + o2 + x 207 | return o 208 | 209 | def _create_hourglass_net(): 210 | return HourGlass( 211 | nStack = 2, 212 | nBlockCount = 4, 213 | nResidualEachBlock = 1, 214 | nMidChannels = 128, 215 | nChannels = 256, 216 | nJointCount = 1, 217 | bUseBn = True, 218 | ) -------------------------------------------------------------------------------- /src/LinearModel.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ''' 4 | file: LinearModel.py 5 | 6 | date: 2018_04_29 7 | author: zhangxiong(1025679612@qq.com) 8 | ''' 9 | 10 | import torch.nn as nn 11 | import numpy as np 12 | import sys 13 | import torch 14 | 15 | class LinearModel(nn.Module): 16 | ''' 17 | input param: 18 | fc_layers: a list of neuron count, such as [2133, 1024, 1024, 85] 19 | use_dropout: a list of bool define use dropout or not for each layer, such as [True, True, False] 20 | drop_prob: a list of float defined the drop prob, such as [0.5, 0.5, 0] 21 | use_ac_func: a list of bool define use active function or not, such as [True, True, False] 22 | ''' 23 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func): 24 | super(LinearModel, self).__init__() 25 | self.fc_layers = fc_layers 26 | self.use_dropout = use_dropout 27 | self.drop_prob = drop_prob 28 | self.use_ac_func = use_ac_func 29 | 30 | if not self._check(): 31 | msg = 'wrong LinearModel parameters!' 32 | print(msg) 33 | sys.exit(msg) 34 | 35 | self.create_layers() 36 | 37 | def _check(self): 38 | while True: 39 | if not isinstance(self.fc_layers, list): 40 | print('fc_layers require list, get {}'.format(type(self.fc_layers))) 41 | break 42 | 43 | if not isinstance(self.use_dropout, list): 44 | print('use_dropout require list, get {}'.format(type(self.use_dropout))) 45 | break 46 | 47 | if not isinstance(self.drop_prob, list): 48 | print('drop_prob require list, get {}'.format(type(self.drop_prob))) 49 | break 50 | 51 | if not isinstance(self.use_ac_func, list): 52 | print('use_ac_func require list, get {}'.format(type(self.use_ac_func))) 53 | break 54 | 55 | l_fc_layer = len(self.fc_layers) 56 | l_use_drop = len(self.use_dropout) 57 | l_drop_porb = len(self.drop_prob) 58 | l_use_ac_func = len(self.use_ac_func) 59 | 60 | return l_fc_layer >= 2 and l_use_drop < l_fc_layer and l_drop_porb < l_fc_layer and l_use_ac_func < l_fc_layer and l_drop_porb == l_use_drop 61 | 62 | return False 63 | 64 | def create_layers(self): 65 | l_fc_layer = len(self.fc_layers) 66 | l_use_drop = len(self.use_dropout) 67 | l_drop_porb = len(self.drop_prob) 68 | l_use_ac_func = len(self.use_ac_func) 69 | 70 | self.fc_blocks = nn.Sequential() 71 | 72 | for _ in range(l_fc_layer - 1): 73 | self.fc_blocks.add_module( 74 | name = 'regressor_fc_{}'.format(_), 75 | module = nn.Linear(in_features = self.fc_layers[_], out_features = self.fc_layers[_ + 1]) 76 | ) 77 | 78 | if _ < l_use_ac_func and self.use_ac_func[_]: 79 | self.fc_blocks.add_module( 80 | name = 'regressor_af_{}'.format(_), 81 | module = nn.ReLU() 82 | ) 83 | 84 | if _ < l_use_drop and self.use_dropout[_]: 85 | self.fc_blocks.add_module( 86 | name = 'regressor_fc_dropout_{}'.format(_), 87 | module = nn.Dropout(p = self.drop_prob[_]) 88 | ) 89 | 90 | def forward(self, inputs): 91 | msg = 'the base class [LinearModel] is not callable!' 92 | sys.exit(msg) 93 | 94 | if __name__ == '__main__': 95 | fc_layers = [2133, 1024, 1024, 85] 96 | iterations = 3 97 | use_dropout = [True, True, False] 98 | drop_prob = [0.5, 0.5, 0] 99 | use_ac_func = [True, True, False] 100 | device = torch.device('cuda') 101 | net = LinearModel(fc_layers, use_dropout, drop_prob, use_ac_func).to(device) 102 | print(net) 103 | nx = np.zeros([2, 2048]) 104 | vx = torch.from_numpy(nx).to(device) 105 | -------------------------------------------------------------------------------- /src/PRNetEncoder.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ''' 4 | file: PRnetEncoder.py 5 | 6 | date: 2018_05_22 7 | author: zhangxiong(1025679612@qq.com) 8 | mark: the algorithm is cited from PRNet code 9 | ''' 10 | 11 | from __future__ import print_function 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | class Residual(nn.Module): 18 | def __init__(self, use_bn, input_channels, out_channels, mid_channels, kernel_size = 3, padding = 1, stride = 1): 19 | super(Residual, self).__init__() 20 | self.use_bn = use_bn 21 | self.out_channels = out_channels 22 | self.input_channels = input_channels 23 | self.mid_channels = mid_channels 24 | 25 | self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1) 26 | self.AcFunc = nn.ReLU() 27 | if use_bn: 28 | self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels) 29 | self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels) 30 | self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels) 31 | 32 | self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = kernel_size, padding = padding, stride = stride) 33 | 34 | self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size= 1) 35 | 36 | if input_channels != out_channels: 37 | self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1) 38 | 39 | def forward(self, inputs): 40 | x = self.down_channel(inputs) 41 | if self.use_bn: 42 | x = self.bn_0(x) 43 | x = self.AcFunc(x) 44 | 45 | x = self.conv(x) 46 | if self.use_bn: 47 | x = self.bn_1(x) 48 | x = self.AcFunc(x) 49 | 50 | x = self.up_channel(x) 51 | 52 | if self.input_channels != self.out_channels: 53 | x += self.trans(inputs) 54 | else: 55 | x += inputs 56 | 57 | if self.use_bn: 58 | x = self.bn_2(x) 59 | 60 | return self.AcFunc(x) 61 | 62 | class PRNetEncoder(nn.Module): 63 | def __init__(self): 64 | super(PRNetEncoder, self).__init__() 65 | self.conv_blocks = nn.Sequential( 66 | nn.Conv2d(in_channels = 3, out_channels = 8, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 8 67 | nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 16 68 | Residual(use_bn = True, input_channels = 16, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 256 x 256 x 32 69 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 128 x 128 x 32 70 | Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32 71 | Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32 72 | Residual(use_bn = True, input_channels = 32, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 128 x 128 x 64 73 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 64 x 64 x 64 74 | Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64 75 | Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64 76 | Residual(use_bn = True, input_channels = 64, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 64 x 64 x 128 77 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 32 x 32 x 128 78 | Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128 79 | Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128 80 | Residual(use_bn = True, input_channels = 128, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 32 x 32 x 256 81 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 16 x 16 x 256 82 | Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256 83 | Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256 84 | Residual(use_bn = True, input_channels = 256, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 16 x 16 x 512 85 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 8 x 8 x 512 86 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 8 x 8 x 512 87 | nn.MaxPool2d(kernel_size = 2, stride = 2) , # to 4 x 4 x 512 88 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 4 x 4 x 512 89 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 2 x 2 x 512 90 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1) # to 2 x 2 x 512 91 | ) 92 | 93 | def forward(self, inputs): 94 | return self.conv_blocks(inputs).view(-1, 2048) 95 | 96 | 97 | if __name__ == '__main__': 98 | net = PRNetEncoder() 99 | inputs = torch.ones(size = (10, 3, 256, 256)).float() 100 | r = net(inputs) 101 | print(r.shape) -------------------------------------------------------------------------------- /src/Resnet.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: Resnet.py 4 | 5 | date: 2018_05_02 6 | author: zhangxiong(1025679612@qq.com) 7 | mark: copied from pytorch sourc code 8 | ''' 9 | 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch 13 | from torch.nn.parameter import Parameter 14 | import torch.optim as optim 15 | import numpy as np 16 | import math 17 | import torchvision 18 | 19 | class ResNet(nn.Module): 20 | def __init__(self, block, layers, num_classes=1000): 21 | self.inplanes = 64 22 | super(ResNet, self).__init__() 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 32 | self.avgpool = nn.AvgPool2d(7, stride=1) 33 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 38 | m.weight.data.normal_(0, math.sqrt(2. / n)) 39 | elif isinstance(m, nn.BatchNorm2d): 40 | m.weight.data.fill_(1) 41 | m.bias.data.zero_() 42 | 43 | def _make_layer(self, block, planes, blocks, stride=1): 44 | downsample = None 45 | if stride != 1 or self.inplanes != planes * block.expansion: 46 | downsample = nn.Sequential( 47 | nn.Conv2d(self.inplanes, planes * block.expansion, 48 | kernel_size=1, stride=stride, bias=False), 49 | nn.BatchNorm2d(planes * block.expansion), 50 | ) 51 | 52 | layers = [] 53 | layers.append(block(self.inplanes, planes, stride, downsample)) 54 | self.inplanes = planes * block.expansion 55 | for i in range(1, blocks): 56 | layers.append(block(self.inplanes, planes)) 57 | 58 | return nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.bn1(x) 63 | x = self.relu(x) 64 | x = self.maxpool(x) 65 | 66 | x = self.layer1(x) 67 | x = self.layer2(x) 68 | x = self.layer3(x) 69 | x = self.layer4(x) 70 | 71 | x = self.avgpool(x) 72 | x = x.view(x.size(0), -1) 73 | # x = self.fc(x) 74 | 75 | return x 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None): 81 | super(Bottleneck, self).__init__() 82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 83 | self.bn1 = nn.BatchNorm2d(planes) 84 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 85 | padding=1, bias=False) 86 | self.bn2 = nn.BatchNorm2d(planes) 87 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 88 | self.bn3 = nn.BatchNorm2d(planes * 4) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | def forward(self, x): 94 | residual = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | residual = self.downsample(x) 109 | 110 | out += residual 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | def conv3x3(in_planes, out_planes, stride=1): 116 | """3x3 convolution with padding""" 117 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 118 | padding=1, bias=False) 119 | 120 | class BasicBlock(nn.Module): 121 | expansion = 1 122 | 123 | def __init__(self, inplanes, planes, stride=1, downsample=None): 124 | super(BasicBlock, self).__init__() 125 | self.conv1 = conv3x3(inplanes, planes, stride) 126 | self.bn1 = nn.BatchNorm2d(planes) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.conv2 = conv3x3(planes, planes) 129 | self.bn2 = nn.BatchNorm2d(planes) 130 | self.downsample = downsample 131 | self.stride = stride 132 | 133 | def forward(self, x): 134 | residual = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | 143 | if self.downsample is not None: 144 | residual = self.downsample(x) 145 | 146 | out += residual 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | def copy_parameter_from_resnet50(model, res50_dict): 152 | cur_state_dict = model.state_dict() 153 | for name, param in list(res50_dict.items())[0:None]: 154 | if name not in cur_state_dict: 155 | print('unexpected ', name, ' !') 156 | continue 157 | if isinstance(param, Parameter): 158 | param = param.data 159 | try: 160 | cur_state_dict[name].copy_(param) 161 | except: 162 | print(name, ' is inconsistent!') 163 | continue 164 | print('copy state dict finished!') 165 | 166 | def load_Res50Model(): 167 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 168 | copy_parameter_from_resnet50(model, torchvision.models.resnet50(pretrained = True).state_dict()) 169 | return model 170 | 171 | if __name__ == '__main__': 172 | vx = torch.autograd.Variable(torch.from_numpy(np.array([1, 1, 1]))) 173 | vy = torch.autograd.Variable(torch.from_numpy(np.array([2, 2, 2]))) 174 | vz = torch.cat([vx, vy], 0) 175 | vz[0] = 100 176 | print(vz) 177 | print(vx) -------------------------------------------------------------------------------- /src/SMPL.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: SMPL.py 4 | 5 | date: 2018_05_03 6 | author: zhangxiong(1025679612@qq.com) 7 | mark: the algorithm is cited from original SMPL 8 | ''' 9 | import torch 10 | from config import args 11 | import json 12 | import sys 13 | import numpy as np 14 | from util import batch_global_rigid_transformation, batch_rodrigues, batch_lrotmin, reflect_pose 15 | import torch.nn as nn 16 | 17 | class SMPL(nn.Module): 18 | def __init__(self, model_path, joint_type = 'cocoplus', obj_saveable = False): 19 | super(SMPL, self).__init__() 20 | 21 | if joint_type not in ['cocoplus', 'lsp']: 22 | msg = 'unknow joint type: {}, it must be either "cocoplus" or "lsp"'.format(joint_type) 23 | sys.exit(msg) 24 | 25 | self.model_path = model_path 26 | self.joint_type = joint_type 27 | with open(model_path, 'r') as reader: 28 | model = json.load(reader) 29 | 30 | if obj_saveable: 31 | self.faces = model['f'] 32 | else: 33 | self.faces = None 34 | 35 | np_v_template = np.array(model['v_template'], dtype = np.float) 36 | self.register_buffer('v_template', torch.from_numpy(np_v_template).float()) 37 | self.size = [np_v_template.shape[0], 3] 38 | 39 | np_shapedirs = np.array(model['shapedirs'], dtype = np.float) 40 | self.num_betas = np_shapedirs.shape[-1] 41 | np_shapedirs = np.reshape(np_shapedirs, [-1, self.num_betas]).T 42 | self.register_buffer('shapedirs', torch.from_numpy(np_shapedirs).float()) 43 | 44 | np_J_regressor = np.array(model['J_regressor'], dtype = np.float) 45 | self.register_buffer('J_regressor', torch.from_numpy(np_J_regressor).float()) 46 | 47 | np_posedirs = np.array(model['posedirs'], dtype = np.float) 48 | num_pose_basis = np_posedirs.shape[-1] 49 | np_posedirs = np.reshape(np_posedirs, [-1, num_pose_basis]).T 50 | self.register_buffer('posedirs', torch.from_numpy(np_posedirs).float()) 51 | 52 | self.parents = np.array(model['kintree_table'])[0].astype(np.int32) 53 | 54 | np_joint_regressor = np.array(model['cocoplus_regressor'], dtype = np.float) 55 | if joint_type == 'lsp': 56 | self.register_buffer('joint_regressor', torch.from_numpy(np_joint_regressor[:, :14]).float()) 57 | else: 58 | self.register_buffer('joint_regressor', torch.from_numpy(np_joint_regressor).float()) 59 | 60 | np_weights = np.array(model['weights'], dtype = np.float) 61 | 62 | vertex_count = np_weights.shape[0] 63 | vertex_component = np_weights.shape[1] 64 | 65 | batch_size = max(args.batch_size + args.batch_3d_size, args.eval_batch_size) 66 | np_weights = np.tile(np_weights, (batch_size, 1)) 67 | self.register_buffer('weight', torch.from_numpy(np_weights).float().reshape(-1, vertex_count, vertex_component)) 68 | 69 | self.register_buffer('e3', torch.eye(3).float()) 70 | 71 | self.cur_device = None 72 | 73 | def save_obj(self, verts, obj_mesh_name): 74 | if not self.faces: 75 | msg = 'obj not saveable!' 76 | sys.exit(msg) 77 | 78 | with open(obj_mesh_name, 'w') as fp: 79 | for v in verts: 80 | fp.write( 'v %f %f %f\n' % ( v[0], v[1], v[2]) ) 81 | 82 | for f in self.faces: # Faces are 1-based, not 0-based in obj files 83 | fp.write( 'f %d %d %d\n' % (f[0] + 1, f[1] + 1, f[2] + 1) ) 84 | 85 | def forward(self, beta, theta, get_skin = False): 86 | if not self.cur_device: 87 | device = beta.device 88 | self.cur_device = torch.device(device.type, device.index) 89 | 90 | num_batch = beta.shape[0] 91 | 92 | v_shaped = torch.matmul(beta, self.shapedirs).view(-1, self.size[0], self.size[1]) + self.v_template 93 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) 94 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) 95 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) 96 | J = torch.stack([Jx, Jy, Jz], dim = 2) 97 | 98 | Rs = batch_rodrigues(theta.view(-1, 3)).view(-1, 24, 3, 3) 99 | pose_feature = (Rs[:, 1:, :, :]).sub(1.0, self.e3).view(-1, 207) 100 | v_posed = torch.matmul(pose_feature, self.posedirs).view(-1, self.size[0], self.size[1]) + v_shaped 101 | self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents, rotate_base = True) 102 | 103 | weight = self.weight[:num_batch] 104 | W = weight.view(num_batch, -1, 24) 105 | T = torch.matmul(W, A.view(num_batch, 24, 16)).view(num_batch, -1, 4, 4) 106 | 107 | v_posed_homo = torch.cat([v_posed, torch.ones(num_batch, v_posed.shape[1], 1, device = self.cur_device)], dim = 2) 108 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, -1)) 109 | 110 | verts = v_homo[:, :, :3, 0] 111 | 112 | joint_x = torch.matmul(verts[:, :, 0], self.joint_regressor) 113 | joint_y = torch.matmul(verts[:, :, 1], self.joint_regressor) 114 | joint_z = torch.matmul(verts[:, :, 2], self.joint_regressor) 115 | 116 | joints = torch.stack([joint_x, joint_y, joint_z], dim = 2) 117 | 118 | if get_skin: 119 | return verts, joints, Rs 120 | else: 121 | return joints 122 | 123 | if __name__ == '__main__': 124 | device = torch.device('cuda', 0) 125 | smpl = SMPL(args.smpl_model, obj_saveable = True).to(device) 126 | pose= np.array([ 127 | 1.22162998e+00, 1.17162502e+00, 1.16706634e+00, 128 | -1.20581151e-03, 8.60930011e-02, 4.45963144e-02, 129 | -1.52801601e-02, -1.16911056e-02, -6.02894090e-03, 130 | 1.62427306e-01, 4.26302850e-02, -1.55304456e-02, 131 | 2.58729942e-02, -2.15941742e-01, -6.59851432e-02, 132 | 7.79098943e-02, 1.96353287e-01, 6.44420758e-02, 133 | -5.43042570e-02, -3.45508829e-02, 1.13200583e-02, 134 | -5.60734887e-04, 3.21716577e-01, -2.18840033e-01, 135 | -7.61821344e-02, -3.64610642e-01, 2.97633410e-01, 136 | 9.65453908e-02, -5.54007106e-03, 2.83410680e-02, 137 | -9.57194716e-02, 9.02515948e-02, 3.31488043e-01, 138 | -1.18847653e-01, 2.96623230e-01, -4.76809204e-01, 139 | -1.53382001e-02, 1.72342166e-01, -1.44332021e-01, 140 | -8.10869411e-02, 4.68325168e-02, 1.42248288e-01, 141 | -4.60898802e-02, -4.05981280e-02, 5.28727695e-02, 142 | 3.20133418e-02, -5.23784310e-02, 2.41559884e-03, 143 | -3.08033824e-01, 2.31431410e-01, 1.62540793e-01, 144 | 6.28208935e-01, -1.94355965e-01, 7.23800480e-01, 145 | -6.49612308e-01, -4.07179184e-02, -1.46422181e-02, 146 | 4.51475441e-01, 1.59122205e+00, 2.70355493e-01, 147 | 2.04248756e-01, -6.33800551e-02, -5.50178960e-02, 148 | -1.00920045e+00, 2.39532292e-01, 3.62904727e-01, 149 | -3.38783532e-01, 9.40650925e-02, -8.44506770e-02, 150 | 3.55101633e-03, -2.68924050e-02, 4.93676625e-02],dtype = np.float) 151 | 152 | beta = np.array([-0.25349993, 0.25009069, 0.21440795, 0.78280628, 0.08625954, 153 | 0.28128183, 0.06626327, -0.26495767, 0.09009246, 0.06537955 ]) 154 | 155 | vbeta = torch.tensor(np.array([beta])).float().to(device) 156 | vpose = torch.tensor(np.array([pose])).float().to(device) 157 | 158 | verts, j, r = smpl(vbeta, vpose, get_skin = True) 159 | 160 | smpl.save_obj(verts[0].cpu().numpy(), './mesh.obj') 161 | 162 | rpose = reflect_pose(pose) 163 | vpose = torch.tensor(np.array([rpose])).float().to(device) 164 | 165 | verts, j, r = smpl(vbeta, vpose, get_skin = True) 166 | smpl.save_obj(verts[0].cpu().numpy(), './rmesh.obj') 167 | 168 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: config.py 4 | 5 | date: 2018_04_29 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description = 'hmr model') 12 | 13 | parser.add_argument( 14 | '--fine-tune', 15 | default = True, 16 | type = bool, 17 | help = 'fine tune or not.' 18 | ) 19 | 20 | parser.add_argument( 21 | '--encoder-network', 22 | type = str, 23 | default = 'resnet50', 24 | help = 'the encoder network name' 25 | ) 26 | 27 | parser.add_argument( 28 | '--smpl-mean-theta-path', 29 | type = str, 30 | default = 'E:/HMR/model/neutral_smpl_mean_params.h5', 31 | help = 'the path for mean smpl theta value' 32 | ) 33 | 34 | parser.add_argument( 35 | '--smpl-model', 36 | type = str, 37 | default = 'E:/HMR/model/neutral_smpl_with_cocoplus_reg.txt', 38 | help = 'smpl model path' 39 | ) 40 | 41 | parser.add_argument( 42 | '--total-theta-count', 43 | type = int, 44 | default = 85, 45 | help = 'the count of theta param' 46 | ) 47 | 48 | parser.add_argument( 49 | '--batch-size', 50 | type = int, 51 | default = 8, 52 | help = 'batch size' 53 | ) 54 | 55 | parser.add_argument( 56 | '--batch-3d-size', 57 | type = int, 58 | default = 8, 59 | help = '3d data batch size' 60 | ) 61 | 62 | parser.add_argument( 63 | '--adv-batch-size', 64 | type = int, 65 | default = 24, 66 | help = 'default adv batch size' 67 | ) 68 | 69 | parser.add_argument( 70 | '--eval-batch-size', 71 | type = int, 72 | default = 400, 73 | help = 'default eval batch size' 74 | ) 75 | 76 | parser.add_argument( 77 | '--joint-count', 78 | type = int, 79 | default = 24, 80 | help = 'the count of joints' 81 | ) 82 | 83 | parser.add_argument( 84 | '--beta-count', 85 | type = int, 86 | default = 10, 87 | help = 'the count of beta' 88 | ) 89 | 90 | parser.add_argument( 91 | '--use-adv-train', 92 | type = bool, 93 | default = True, 94 | help = 'use adv traing or not' 95 | ) 96 | 97 | parser.add_argument( 98 | '--scale-min', 99 | type = float, 100 | default = 1.1, 101 | help = 'min scale' 102 | ) 103 | 104 | parser.add_argument( 105 | '--scale-max', 106 | type = float, 107 | default = 1.5, 108 | help = 'max scale' 109 | ) 110 | 111 | parser.add_argument( 112 | '--num-worker', 113 | type = int, 114 | default = 1, 115 | help = 'pytorch number worker.' 116 | ) 117 | 118 | parser.add_argument( 119 | '--iter-count', 120 | type = int, 121 | default = 500001, 122 | help = 'iter count, eatch contains batch-size samples' 123 | ) 124 | 125 | parser.add_argument( 126 | '--e-lr', 127 | type = float, 128 | default = 0.00001, 129 | help = 'encoder learning rate.' 130 | ) 131 | 132 | parser.add_argument( 133 | '--d-lr', 134 | type = float, 135 | default = 0.0001, 136 | help = 'Adversarial prior learning rate.' 137 | ) 138 | 139 | parser.add_argument( 140 | '--e-wd', 141 | type = float, 142 | default = 0.0001, 143 | help = 'encoder weight decay rate.' 144 | ) 145 | 146 | parser.add_argument( 147 | '--d-wd', 148 | type = float, 149 | default = 0.0001, 150 | help = 'Adversarial prior weight decay' 151 | ) 152 | 153 | parser.add_argument( 154 | '--e-loss-weight', 155 | type = float, 156 | default = 60, 157 | help = 'weight on encoder 2d kp losses.' 158 | ) 159 | 160 | parser.add_argument( 161 | '--d-loss-weight', 162 | type = float, 163 | default = 1, 164 | help = 'weight on discriminator losses' 165 | ) 166 | 167 | 168 | parser.add_argument( 169 | '--d-disc-ratio', 170 | type = float, 171 | default = 1.0, 172 | help = 'multiple weight of discriminator loss' 173 | ) 174 | 175 | parser.add_argument( 176 | '--e-3d-loss-weight', 177 | type = float, 178 | default = 60, 179 | help = 'weight on encoder thetas losses.' 180 | ) 181 | 182 | parser.add_argument( 183 | '--e-shape-ratio', 184 | type = float, 185 | default = 5, 186 | help = 'multiple weight of shape loss' 187 | ) 188 | 189 | parser.add_argument( 190 | '--e-3d-kp-ratio', 191 | type = float, 192 | default = 10.0, 193 | help = 'multiple weight of 3d key point.' 194 | ) 195 | 196 | parser.add_argument( 197 | '--e-pose-ratio', 198 | type = float, 199 | default = 20, 200 | help = 'multiple weight of pose' 201 | ) 202 | 203 | parser.add_argument( 204 | '--save-folder', 205 | type = str, 206 | default = 'E:/HMR/data_advanced/trained_model', 207 | help = 'save model path' 208 | ) 209 | 210 | parser.add_argument( 211 | '--enable-inter-supervision', 212 | type = bool, 213 | default = False, 214 | help = 'enable inter supervision or not.' 215 | ) 216 | 217 | train_2d_set = ['coco', 'lsp', 'lsp_ext', 'ai-ch'] 218 | train_3d_set = ['mpi-inf-3dhp', 'hum3.6m'] 219 | train_adv_set = ['mosh'] 220 | eval_set = ['up3d'] 221 | 222 | allowed_encoder_net = ['hourglass', 'resnet50', 'densenet169'] 223 | 224 | encoder_feature_count = { 225 | 'hourglass' : 4096, 226 | 'resnet50' : 2048, 227 | 'densenet169' : 1664 228 | } 229 | 230 | crop_size = { 231 | 'hourglass':256, 232 | 'resnet50':224, 233 | 'densenet169':224 234 | } 235 | 236 | data_set_path = { 237 | 'coco':'E:/HMR/data/COCO/', 238 | 'lsp':'E:/HMR/data/lsp', 239 | 'lsp_ext':'E:/HMR/data/lsp_ext', 240 | 'ai-ch':'E:/HMR/data/ai_challenger_keypoint_train_20170902', 241 | 'mpi-inf-3dhp':'E:/HMR/data/mpi_inf_3dhp', 242 | 'hum3.6m':'E:/HMR/data/human3.6m', 243 | 'mosh':'E:/HMR/data/mosh_gen', 244 | 'up3d':'E:/HMR/data/up3d_mpii' 245 | } 246 | 247 | pre_trained_model = { 248 | 'generator' : '/media/disk1/zhangxiong/HMR/hmr_resnet50/fine_tuned/3500_generator.pkl', 249 | 'discriminator' : '/media/disk1/zhangxiong/HMR/hmr_resnet50/fine_tuned/3500_discriminator.pkl' 250 | } 251 | 252 | args = parser.parse_args() 253 | encoder_network = args.encoder_network 254 | args.feature_count = encoder_feature_count[encoder_network] 255 | args.crop_size = crop_size[encoder_network] 256 | -------------------------------------------------------------------------------- /src/dataloader/AICH_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: AICH_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_08 7 | purpose: load ai challenge dataset 8 | ''' 9 | 10 | import sys 11 | from torch.utils.data import Dataset, DataLoader 12 | import scipy.io as scio 13 | import os 14 | import glob 15 | import numpy as np 16 | import random 17 | import cv2 18 | import json 19 | import torch 20 | 21 | sys.path.append('./src') 22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize 23 | from config import args 24 | from timer import Clock 25 | 26 | class AICH_dataloader(Dataset): 27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, only_single_person, min_pts_required, max_intersec_ratio = 0.1, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 28 | self.data_folder = data_set_path 29 | self.use_crop = use_crop 30 | self.scale_range = scale_range 31 | self.use_flip = use_flip 32 | self.flip_prob = flip_prob 33 | self.only_single_person = only_single_person 34 | self.min_pts_required = min_pts_required 35 | self.max_intersec_ratio = max_intersec_ratio 36 | self.img_ext = '.jpg' 37 | self.pix_format = pix_format 38 | self.normalize = normalize 39 | self._load_data_set() 40 | 41 | def _load_data_set(self): 42 | clk = Clock() 43 | 44 | self.images = [] 45 | self.kp2ds = [] 46 | self.boxs = [] 47 | print('start loading AI CH keypoint data.') 48 | anno_file_path = os.path.join(self.data_folder, 'keypoint_train_annotations_20170902.json') 49 | with open(anno_file_path, 'r') as reader: 50 | anno = json.load(reader) 51 | for record in anno: 52 | image_name = record['image_id'] + self.img_ext 53 | image_path = os.path.join(self.data_folder, 'keypoint_train_images_20170902', image_name) 54 | kp_set = record['keypoint_annotations'] 55 | box_set = record['human_annotations'] 56 | self._handle_image(image_path, kp_set, box_set) 57 | 58 | print('finished load Ai CH keypoint data, total {} samples'.format(len(self))) 59 | 60 | clk.stop() 61 | 62 | def _ai_ch_to_lsp(self, pts): 63 | kp_map = [8, 7, 6, 9, 10, 11, 2, 1, 0, 3, 4, 5, 13, 12] 64 | pts = np.array(pts, dtype = np.float).reshape(14, 3).copy() 65 | pts[:, 2] = (3.0 - pts[:, 2]) / 2.0 66 | return pts[kp_map].copy() 67 | 68 | def _handle_image(self, image_path, kp_set, box_set): 69 | assert len(kp_set) == len(box_set) 70 | 71 | if len(kp_set) > 1: 72 | if self.only_single_person: 73 | print('only single person supported now!') 74 | return 75 | for key in kp_set.keys(): 76 | kps = kp_set[key] 77 | box = box_set[key] 78 | self._handle_sample(key, image_path, kps, [ [box[0], box[1]], [box[2], box[3]] ], box_set) 79 | 80 | def _handle_sample(self, key, image_path, pts, box, boxs): 81 | def _collect_box(key, boxs): 82 | r = [] 83 | for k, v in boxs.items(): 84 | if k == key: 85 | continue 86 | r.append([[v[0],v[1]], [v[2],v[3]]]) 87 | return r 88 | 89 | def _collide_heavily(box, boxs): 90 | for it in boxs: 91 | if get_rectangle_intersect_ratio(box[0], box[1], it[0], it[1]) > self.max_intersec_ratio: 92 | return True 93 | return False 94 | pts = self._ai_ch_to_lsp(pts) 95 | valid_pt_cound = np.sum(pts[:, 2]) 96 | if valid_pt_cound < self.min_pts_required: 97 | return 98 | 99 | boxs = _collect_box(key, boxs) 100 | if _collide_heavily(box, boxs): 101 | return 102 | 103 | self.images.append(image_path) 104 | self.kp2ds.append(pts) 105 | lt, rb = box[0], box[1] 106 | self.boxs.append((np.array(lt), np.array(rb))) 107 | 108 | def __len__(self): 109 | return len(self.images) 110 | 111 | def __getitem__(self, index): 112 | image_path = self.images[index] 113 | kps = self.kp2ds[index].copy() 114 | box = self.boxs[index] 115 | 116 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 117 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 118 | ratio = 1.0 * args.crop_size / image.shape[0] 119 | kps[:, :2] *= ratio 120 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 121 | 122 | if self.use_flip and random.random() <= self.flip_prob: 123 | dst_image, kps = flip_image(dst_image, kps) 124 | 125 | #normalize kp to [-1, 1] 126 | ratio = 1.0 / args.crop_size 127 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 128 | 129 | return { 130 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 131 | 'kp_2d': torch.tensor(kps).float(), 132 | 'image_name': self.images[index], 133 | 'data_set':'AI Ch' 134 | } 135 | 136 | if __name__ == '__main__': 137 | aic = AICH_dataloader( 138 | data_set_path = 'E:/HMR/data/ai_challenger_keypoint_train_20170902', 139 | use_crop = True, 140 | scale_range = [1.1, 1.5], 141 | use_flip = True, 142 | only_single_person = False, 143 | min_pts_required = 5, 144 | flip_prob = 1.0 145 | ) 146 | l = len(aic) 147 | for _ in range(l): 148 | r = aic.__getitem__(_) 149 | image = r['image'].cpu().numpy().astype(np.uint8) 150 | kps = r['kp_2d'].cpu().numpy() 151 | kps[:, :2] = (kps[:, :2] + 1) * args.crop_size / 2.0 152 | base_name = os.path.basename(r['image_name']) 153 | draw_lsp_14kp__bone(image, kps) 154 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC)) 155 | cv2.waitKey(0) 156 | -------------------------------------------------------------------------------- /src/dataloader/COCO2017_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: COCO2017_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_09 7 | purpose: load COCO 2017 keypoint dataset 8 | ''' 9 | 10 | import sys 11 | from torch.utils.data import Dataset, DataLoader 12 | import scipy.io as scio 13 | import os 14 | import glob 15 | import numpy as np 16 | import random 17 | import cv2 18 | import json 19 | import torch 20 | 21 | sys.path.append('./src') 22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize 23 | from config import args 24 | from timer import Clock 25 | 26 | class COCO2017_dataloader(Dataset): 27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, only_single_person, min_pts_required, max_intersec_ratio = 0.1, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 28 | self.data_folder = data_set_path 29 | self.use_crop = use_crop 30 | self.scale_range = scale_range 31 | self.use_flip = use_flip 32 | self.flip_prob = flip_prob 33 | self.only_single_person = only_single_person 34 | self.min_pts_required = min_pts_required 35 | self.max_intersec_ratio = max_intersec_ratio 36 | self.pix_format = pix_format 37 | self.normalize = normalize 38 | self._load_data_set() 39 | 40 | def _load_data_set(self): 41 | self.images = [] 42 | self.kp2ds = [] 43 | self.boxs = [] 44 | clk = Clock() 45 | print('start loading coco 2017 dataset.') 46 | anno_file_path = os.path.join(self.data_folder, 'annotations', 'person_keypoints_train2017.json') 47 | with open(anno_file_path, 'r') as reader: 48 | anno = json.load(reader) 49 | 50 | def _hash_image_id_(image_id_to_info, coco_images_info): 51 | for image_info in coco_images_info: 52 | image_id = image_info['id'] 53 | image_name = image_info['file_name'] 54 | _anno = {} 55 | _anno['image_path'] = os.path.join(self.data_folder, 'images', 'train-valid2017', image_name) 56 | _anno['kps'] = [] 57 | _anno['box'] = [] 58 | assert not (image_id in image_id_to_info) 59 | image_id_to_info[image_id] = _anno 60 | 61 | images = anno['images'] 62 | 63 | image_id_to_info = {} 64 | _hash_image_id_(image_id_to_info, images) 65 | 66 | 67 | annos = anno['annotations'] 68 | for anno_info in annos: 69 | self._handle_anno_info(anno_info, image_id_to_info) 70 | 71 | for k, v in image_id_to_info.items(): 72 | self._handle_image_info_(v) 73 | 74 | print('finished load coco 2017 dataset, total {} samples.'.format(len(self.images))) 75 | 76 | clk.stop() 77 | 78 | def _handle_image_info_(self, image_info): 79 | image_path = image_info['image_path'] 80 | kp_set = image_info['kps'] 81 | box_set = image_info['box'] 82 | if len(box_set) > 1: 83 | if self.only_single_person: 84 | return 85 | 86 | for _ in range(len(box_set)): 87 | self._handle_sample(_, kp_set, box_set, image_path) 88 | 89 | def _handle_sample(self, key, kps, boxs, image_path): 90 | def _collect_box(l, boxs): 91 | r = [] 92 | for _ in range(len(boxs)): 93 | if _ == l: 94 | continue 95 | r.append(boxs[_]) 96 | return r 97 | 98 | def _collide_heavily(box, boxs): 99 | for it in boxs: 100 | if get_rectangle_intersect_ratio(box[0], box[1], it[0], it[1]) > self.max_intersec_ratio: 101 | return True 102 | return False 103 | 104 | kp = kps[key] 105 | box = boxs[key] 106 | 107 | valid_pt_cound = np.sum(kp[:, 2]) 108 | if valid_pt_cound < self.min_pts_required: 109 | return 110 | 111 | r = _collect_box(key, boxs) 112 | if _collide_heavily(box, r): 113 | return 114 | 115 | self.images.append(image_path) 116 | self.kp2ds.append(kp.copy()) 117 | self.boxs.append(box.copy()) 118 | 119 | def _handle_anno_info(self, anno_info, image_id_to_info): 120 | image_id = anno_info['image_id'] 121 | kps = anno_info['keypoints'] 122 | box_info = anno_info['bbox'] 123 | box = [np.array([int(box_info[0]), int(box_info[1])]), np.array([int(box_info[0] + box_info[2]), int(box_info[1] + box_info[3])])] 124 | assert image_id in image_id_to_info 125 | _anno = image_id_to_info[image_id] 126 | _anno['box'].append(box) 127 | _anno['kps'].append(self._convert_to_lsp14_pts(kps)) 128 | 129 | def _convert_to_lsp14_pts(self, coco_pts): 130 | kp_map = [15, 13, 11, 10, 12, 14, 9, 7, 5, 4, 6, 8, 0, 0] 131 | kp_map = [16, 14, 12, 11, 13, 15, 10, 8, 6, 5, 7, 9, 0, 0] 132 | kps = np.array(coco_pts, dtype = np.float).reshape(-1, 3)[kp_map].copy() 133 | kps[12: ,2] = 0.0 #no neck, top head 134 | kps[:, 2] /= 2.0 135 | return kps 136 | 137 | def __len__(self): 138 | return len(self.images) 139 | 140 | def __getitem__(self, index): 141 | image_path = self.images[index] 142 | kps = self.kp2ds[index].copy() 143 | box = self.boxs[index] 144 | 145 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 146 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 147 | ratio = 1.0 * args.crop_size / image.shape[0] 148 | kps[:, :2] *= ratio 149 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 150 | 151 | if self.use_flip and random.random() <= self.flip_prob: 152 | dst_image, kps = flip_image(dst_image, kps) 153 | 154 | #normalize kp to [-1, 1] 155 | ratio = 1.0 / args.crop_size 156 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 157 | 158 | return { 159 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 160 | 'kp_2d': torch.tensor(kps).float(), 161 | 'image_name': self.images[index], 162 | 'data_set':'COCO 2017' 163 | } 164 | 165 | if __name__ == '__main__': 166 | coco = COCO2017_dataloader('E:/HMR/data/COCO/', True, [1.1, 1.5], False, False, 10, 0.1) 167 | l = len(coco) 168 | for _ in range(l): 169 | r = lsp.__getitem__(_) 170 | image = r['image'].cpu().numpy().astype(np.uint8) 171 | kps = r['kp_2d'].cpu().numpy() 172 | base_name = os.path.basename(r['image_name']) 173 | draw_lsp_14kp__bone(image, kps) 174 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC)) 175 | cv2.waitKey(0) 176 | -------------------------------------------------------------------------------- /src/dataloader/eval_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: eval_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_20 7 | purpose: load evaluation data 8 | ''' 9 | import sys 10 | from torch.utils.data import Dataset, DataLoader 11 | import scipy.io as scio 12 | import os 13 | import glob 14 | import numpy as np 15 | import random 16 | import cv2 17 | import json 18 | import torch 19 | 20 | sys.path.append('./src') 21 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose 22 | from config import args 23 | # from timer import Clock 24 | 25 | class eval_dataloader(Dataset): 26 | def __init__(self, data_set_path, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 27 | self.use_flip = use_flip 28 | self.flip_prob = flip_prob 29 | self.data_folder = data_set_path 30 | self.pix_format = pix_format 31 | self.normalize = normalize 32 | 33 | self._load_data_set() 34 | 35 | def _load_data_set(self): 36 | # clk = Clock() 37 | 38 | self.images = sorted(glob.glob(os.path.join(self.data_folder, 'image/*.png'))) 39 | self.kp2ds = [] 40 | self.poses = [] 41 | self.betas = [] 42 | 43 | for idx in range(len(self.images)): 44 | image_name = os.path.basename(self.images[idx])[:5] 45 | anno_path = os.path.join(self.data_folder, 'annos', image_name + '_joints.npy') 46 | self.kp2ds.append(np.load(anno_path).T) 47 | anno_path = os.path.join(self.data_folder, 'annos', image_name +'.json') 48 | with open(anno_path, 'r') as fp: 49 | annos = json.load(fp) 50 | self.poses.append(np.array(annos['pose'])) 51 | self.betas.append(np.array(annos['betas'])) 52 | 53 | # clk.stop() 54 | 55 | def __len__(self): 56 | return len(self.images) 57 | 58 | def __getitem__(self, index): 59 | image_path = self.images[index] 60 | kps = self.kp2ds[index].copy() 61 | pose = self.poses[index].copy() 62 | shape = self.betas[index].copy() 63 | 64 | dst_image = cv2.imread(image_path) 65 | 66 | 67 | if self.use_flip and random.random() <= self.flip_prob: 68 | dst_image, kps = flip_image(dst_image, kps) 69 | pose = reflect_poses(pose) 70 | 71 | #normalize kp to [-1, 1] 72 | ratio = 1.0 / args.crop_size 73 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 74 | 75 | return { 76 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 77 | 'kp_2d': torch.tensor(kps).float(), 78 | 'pose': torch.tensor(pose).float(), 79 | 'shape': torch.tensor(shape).float(), 80 | 'image_name': self.images[index], 81 | 'data_set':'up_3d_evaluation' 82 | } 83 | 84 | if __name__ == '__main__': 85 | evl = eval_dataloader('E:/HMR/data/up3d_mpii', True) 86 | l = evl.__len__() 87 | data_loader = DataLoader(evl, batch_size=10,shuffle=True) 88 | for _ in range(l): 89 | r = evl.__getitem__(_) 90 | pass 91 | 92 | 93 | -------------------------------------------------------------------------------- /src/dataloader/hum36m_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: hum36m_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_09 7 | purpose: load hum3.6m data 8 | ''' 9 | 10 | import sys 11 | from torch.utils.data import Dataset, DataLoader 12 | import os 13 | import glob 14 | import numpy as np 15 | import random 16 | import cv2 17 | import json 18 | import h5py 19 | import torch 20 | 21 | sys.path.append('./src') 22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose, reflect_lsp_kp 23 | from config import args 24 | from timer import Clock 25 | 26 | class hum36m_dataloader(Dataset): 27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, min_pts_required, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 28 | self.data_folder = data_set_path 29 | self.use_crop = use_crop 30 | self.scale_range = scale_range 31 | self.use_flip = use_flip 32 | self.flip_prob = flip_prob 33 | self.min_pts_required = min_pts_required 34 | self.pix_format = pix_format 35 | self.normalize = normalize 36 | self._load_data_set() 37 | 38 | def _load_data_set(self): 39 | 40 | clk = Clock() 41 | 42 | self.images = [] 43 | self.kp2ds = [] 44 | self.boxs = [] 45 | self.kp3ds = [] 46 | self.shapes = [] 47 | self.poses = [] 48 | 49 | print('start loading hum3.6m data.') 50 | 51 | anno_file_path = os.path.join(self.data_folder, 'annot.h5') 52 | with h5py.File(anno_file_path) as fp: 53 | total_kp2d = np.array(fp['gt2d']) 54 | total_kp3d = np.array(fp['gt3d']) 55 | total_shap = np.array(fp['shape']) 56 | total_pose = np.array(fp['pose']) 57 | total_image_names = np.array(fp['imagename']) 58 | 59 | assert len(total_kp2d) == len(total_kp3d) and len(total_kp2d) == len(total_image_names) and \ 60 | len(total_kp2d) == len(total_shap) and len(total_kp2d) == len(total_pose) 61 | 62 | l = len(total_kp2d) 63 | def _collect_valid_pts(pts): 64 | r = [] 65 | for pt in pts: 66 | if pt[2] != 0: 67 | r.append(pt) 68 | return r 69 | 70 | for index in range(l): 71 | kp2d = total_kp2d[index].reshape((-1, 3)) 72 | if np.sum(kp2d[:, 2]) < self.min_pts_required: 73 | continue 74 | 75 | lt, rb, v = calc_aabb(_collect_valid_pts(kp2d)) 76 | self.kp2ds.append(np.array(kp2d.copy(), dtype = np.float)) 77 | self.boxs.append((lt, rb)) 78 | self.kp3ds.append(total_kp3d[index].copy().reshape(-1, 3)) 79 | self.shapes.append(total_shap[index].copy()) 80 | self.poses.append(total_pose[index].copy()) 81 | self.images.append(os.path.join(self.data_folder, 'image') + total_image_names[index].decode()) 82 | 83 | print('finished load hum3.6m data, total {} samples'.format(len(self.kp3ds))) 84 | 85 | clk.stop() 86 | 87 | def __len__(self): 88 | return len(self.images) 89 | 90 | def __getitem__(self, index): 91 | image_path = self.images[index] 92 | kps = self.kp2ds[index].copy() 93 | box = self.boxs[index] 94 | kp_3d = self.kp3ds[index].copy() 95 | 96 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 97 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 98 | 99 | ratio = 1.0 * args.crop_size / image.shape[0] 100 | kps[:, :2] *= ratio 101 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 102 | 103 | trival, shape, pose = np.zeros(3), self.shapes[index], self.poses[index] 104 | 105 | if self.use_flip and random.random() <= self.flip_prob: 106 | dst_image, kps = flip_image(dst_image, kps) 107 | pose = reflect_pose(pose) 108 | kp_3d = reflect_lsp_kp(kp_3d) 109 | 110 | #normalize kp to [-1, 1] 111 | ratio = 1.0 / args.crop_size 112 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 113 | 114 | theta = np.concatenate((trival, pose, shape), axis = 0) 115 | 116 | return { 117 | 'image': torch.from_numpy(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 118 | 'kp_2d': torch.from_numpy(kps).float(), 119 | 'kp_3d': torch.from_numpy(kp_3d).float(), 120 | 'theta': torch.from_numpy(theta).float(), 121 | 'image_name': self.images[index], 122 | 'w_smpl':1.0, 123 | 'w_3d':1.0, 124 | 'data_set':'hum3.6m' 125 | } 126 | 127 | if __name__ == '__main__': 128 | h36m = hum36m_dataloader('E:/HMR/data/human3.6m', True, [1.1, 2.0], True, 5, flip_prob = 1) 129 | l = len(h36m) 130 | for _ in range(l): 131 | r = h36m.__getitem__(_) 132 | pass -------------------------------------------------------------------------------- /src/dataloader/lsp_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: lsp_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_07 7 | ''' 8 | import sys 9 | from torch.utils.data import Dataset, DataLoader 10 | import scipy.io as scio 11 | import os 12 | import glob 13 | import numpy as np 14 | import random 15 | import cv2 16 | import torch 17 | 18 | sys.path.append('./src') 19 | 20 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, convert_image_by_pixformat_normalize, reflect_lsp_kp 21 | from config import args 22 | from timer import Clock 23 | 24 | class LspLoader(Dataset): 25 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 26 | ''' 27 | marks: 28 | data_set path links to the parent folder to lsp, which contains images, joints.mat, README.txt 29 | 30 | inputs: 31 | use_crop crop the image or not, it should be True by default 32 | scale_range, contain the scale range 33 | use_flip, left right flip is allowed 34 | ''' 35 | self.use_crop = use_crop 36 | self.scale_range = scale_range 37 | self.use_flip = use_flip 38 | self.flip_prob = flip_prob 39 | self.data_folder = data_set_path 40 | self.pix_format = pix_format 41 | self.normalize = normalize 42 | 43 | self._load_data_set() 44 | 45 | def _load_data_set(self): 46 | clk = Clock() 47 | print('loading LSP data.') 48 | self.images = [] 49 | self.kp2ds = [] 50 | self.boxs = [] 51 | 52 | anno_file_path = os.path.join(self.data_folder, 'joints.mat') 53 | anno = scio.loadmat(anno_file_path) 54 | kp2d = anno['joints'].transpose(2, 1, 0) # N x k x 3 55 | visible = np.logical_not(kp2d[:, :, 2]) 56 | kp2d[:, :, 2] = visible.astype(kp2d.dtype) 57 | image_folder = os.path.join(self.data_folder, 'images') 58 | images = sorted(glob.glob(image_folder + '/im*.jpg')) 59 | for _ in range(len(images)): 60 | self._handle_image(images[_], kp2d[_]) 61 | 62 | print('finished load LSP data.') 63 | clk.stop() 64 | 65 | def _handle_image(self, image_path, kps): 66 | pt_valid = [] 67 | for pt in kps: 68 | if pt[2] == 1: 69 | pt_valid.append(pt) 70 | lt, rb, valid = calc_aabb(pt_valid) 71 | 72 | if not valid: 73 | return 74 | 75 | self.kp2ds.append(kps.copy().astype(np.float)) 76 | self.images.append(image_path) 77 | self.boxs.append((lt, rb)) 78 | 79 | def __len__(self): 80 | return len(self.images) 81 | 82 | def __getitem__(self, index): 83 | image_path = self.images[index] 84 | kps = self.kp2ds[index].copy() 85 | box = self.boxs[index] 86 | 87 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 88 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 89 | ratio = 1.0 * args.crop_size / image.shape[0] 90 | kps[:, :2] *= ratio 91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 92 | 93 | if self.use_flip and random.random() <= self.flip_prob: 94 | dst_image, kps = flip_image(dst_image, kps) 95 | 96 | #normalize kp to [-1, 1] 97 | ratio = 1.0 / args.crop_size 98 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 99 | return { 100 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 101 | 'kp_2d': torch.tensor(kps).float(), 102 | 'image_name': self.images[index], 103 | 'data_set':'lsp' 104 | } 105 | 106 | if __name__ == '__main__': 107 | lsp = LspLoader( 108 | data_set_path = 'E:/HMR/data/lsp', 109 | use_crop = True, 110 | scale_range = [1.05, 1.2], 111 | use_flip = True, 112 | flip_prob = 1.0 113 | ) 114 | l = lsp.__len__() 115 | data_loader = DataLoader(lsp, batch_size=10,shuffle=True) 116 | for _ in range(l): 117 | r = lsp.__getitem__(_) 118 | image = r['image'].cpu().numpy().astype(np.uint8) 119 | kps = r['kp_2d'].cpu().numpy() 120 | kps[:, :2] = (kps[:, :2] + 1) * args.crop_size / 2.0 121 | base_name = os.path.basename(r['image_name']) 122 | draw_lsp_14kp__bone(image, kps) 123 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC)) 124 | cv2.waitKey(0) 125 | 126 | -------------------------------------------------------------------------------- /src/dataloader/lsp_ext_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: lsp_ext_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_07 7 | ''' 8 | import sys 9 | from torch.utils.data import Dataset, DataLoader 10 | import scipy.io as scio 11 | import os 12 | import glob 13 | import numpy as np 14 | import random 15 | import cv2 16 | import torch 17 | 18 | sys.path.append('./src') 19 | 20 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, convert_image_by_pixformat_normalize 21 | from config import args 22 | from timer import Clock 23 | 24 | 25 | class LspExtLoader(Dataset): 26 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 27 | ''' 28 | marks: 29 | data_set path links to the parent folder to lsp, which contains images, joints.mat, README.txt 30 | 31 | inputs: 32 | use_crop crop the image or not, it should be True by default 33 | scale_range, contain the scale range 34 | use_flip, left right flip is allowed 35 | ''' 36 | self.use_crop = use_crop 37 | self.scale_range = scale_range 38 | self.use_flip = use_flip 39 | self.flip_prob = flip_prob 40 | self.data_folder = data_set_path 41 | self.pix_format = pix_format 42 | self.normalize = normalize 43 | 44 | self._load_data_set() 45 | 46 | def _load_data_set(self): 47 | clk = Clock() 48 | 49 | print('loading LSP ext data.') 50 | self.images = [] 51 | self.kp2ds = [] 52 | self.boxs = [] 53 | 54 | anno_file_path = os.path.join(self.data_folder, 'joints.mat') 55 | anno = scio.loadmat(anno_file_path) 56 | kp2d = anno['joints'].transpose(2, 0, 1) # N x k x 3 57 | image_folder = os.path.join(self.data_folder, 'images') 58 | images = sorted(glob.glob(image_folder + '/im*.jpg')) 59 | for _ in range(len(images)): 60 | self._handle_image(images[_], kp2d[_]) 61 | 62 | print('finished load LSP ext data.') 63 | clk.stop() 64 | 65 | def _handle_image(self, image_path, kps): 66 | pt_valid = [] 67 | for pt in kps: 68 | if pt[2] == 1: 69 | pt_valid.append(pt) 70 | lt, rb, valid = calc_aabb(pt_valid) 71 | 72 | if not valid: 73 | return 74 | 75 | self.kp2ds.append(kps.copy().astype(np.float)) 76 | self.images.append(image_path) 77 | self.boxs.append((lt, rb)) 78 | 79 | def __len__(self): 80 | return len(self.images) 81 | 82 | def __getitem__(self, index): 83 | image_path = self.images[index] 84 | kps = self.kp2ds[index].copy() 85 | box = self.boxs[index] 86 | 87 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 88 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 89 | ratio = 1.0 * args.crop_size / image.shape[0] 90 | kps[:, :2] *= ratio 91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 92 | 93 | if self.use_flip and random.random() <= self.flip_prob: 94 | dst_image, kps = flip_image(dst_image, kps) 95 | 96 | #normalize kp to [-1, 1] 97 | ratio = 1.0 / args.crop_size 98 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 99 | 100 | return { 101 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 102 | 'kp_2d': torch.tensor(kps).float(), 103 | 'image_name': self.images[index], 104 | 'data_set':'lsp_ext' 105 | } 106 | 107 | if __name__ == '__main__': 108 | lsp = LspExtLoader('E:/HMR/data/lsp_ext', True, [1.05, 1.5], False, flip_prob = 1.0) 109 | l = lsp.__len__() 110 | 111 | data_loader = DataLoader(lsp, batch_size=10,shuffle=True) 112 | 113 | for _ in range(l): 114 | r = lsp.__getitem__(_) 115 | image = r['image'].cpu().numpy().astype(np.uint8) 116 | kps = r['kp_2d'].cpu().numpy() 117 | base_name = os.path.basename(r['image_name']) 118 | draw_lsp_14kp__bone(image, kps) 119 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC)) 120 | cv2.waitKey(0) 121 | 122 | -------------------------------------------------------------------------------- /src/dataloader/mosh_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: mosh_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_09 7 | purpose: load COCO 2017 keypoint dataset 8 | ''' 9 | 10 | import sys 11 | from torch.utils.data import Dataset, DataLoader 12 | import scipy.io as scio 13 | import os 14 | import glob 15 | import numpy as np 16 | import random 17 | import cv2 18 | import json 19 | import h5py 20 | import torch 21 | 22 | sys.path.append('./src') 23 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, reflect_pose 24 | from config import args 25 | from timer import Clock 26 | 27 | 28 | class mosh_dataloader(Dataset): 29 | def __init__(self, data_set_path, use_flip = True, flip_prob = 0.3): 30 | self.data_folder = data_set_path 31 | self.use_flip = use_flip 32 | self.flip_prob = flip_prob 33 | 34 | self._load_data_set() 35 | 36 | def _load_data_set(self): 37 | clk = Clock() 38 | print('start loading mosh data.') 39 | anno_file_path = os.path.join(self.data_folder, 'mosh_annot.h5') 40 | with h5py.File(anno_file_path) as fp: 41 | self.shapes = np.array(fp['shape']) 42 | self.poses = np.array(fp['pose']) 43 | print('finished load mosh data, total {} samples'.format(len(self.poses))) 44 | clk.stop() 45 | 46 | def __len__(self): 47 | return len(self.poses) 48 | 49 | def __getitem__(self, index): 50 | trival, pose, shape = np.zeros(3), self.poses[index], self.shapes[index] 51 | 52 | if self.use_flip and random.uniform(0, 1) <= self.flip_prob:#left-right reflect the pose 53 | pose = reflect_pose(pose) 54 | 55 | return { 56 | 'theta': torch.tensor(np.concatenate((trival, pose, shape), axis = 0)).float() 57 | } 58 | 59 | if __name__ == '__main__': 60 | print(random.rand(1)) 61 | mosh = mosh_dataloader('E:/HMR/data/mosh_gen') 62 | l = len(mosh) 63 | import time 64 | for _ in range(l): 65 | r = mosh.__getitem__(_) 66 | print(r) -------------------------------------------------------------------------------- /src/dataloader/mpi_inf_3dhp_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: mpi_inf_3dhp_dataloader.py 4 | 5 | author: zhangxiong(1025679612@qq.com) 6 | date: 2018_05_09 7 | purpose: load mpi inf 3dhp data 8 | ''' 9 | 10 | import sys 11 | from torch.utils.data import Dataset, DataLoader 12 | import os 13 | import glob 14 | import numpy as np 15 | import random 16 | import cv2 17 | import json 18 | import h5py 19 | import torch 20 | 21 | sys.path.append('./src') 22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose, reflect_lsp_kp 23 | from config import args 24 | from timer import Clock 25 | 26 | class mpi_inf_3dhp_dataloader(Dataset): 27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, min_pts_required, pix_format = 'NHWC', normalize = False, flip_prob = 0.3): 28 | self.data_folder = data_set_path 29 | self.use_crop = use_crop 30 | self.scale_range = scale_range 31 | self.use_flip = use_flip 32 | self.flip_prob = 0.3 33 | self.min_pts_required = min_pts_required 34 | self.pix_format = pix_format 35 | self.normalize = normalize 36 | self._load_data_set() 37 | 38 | def _load_data_set(self): 39 | clk = Clock() 40 | 41 | self.images = [] 42 | self.kp2ds = [] 43 | self.boxs = [] 44 | self.kp3ds = [] 45 | 46 | print('start loading mpii-inf-3dhp data.') 47 | anno_file_path = os.path.join(self.data_folder, 'annot.h5') 48 | with h5py.File(anno_file_path) as fp: 49 | total_kp2d = np.array(fp['gt2d']) 50 | total_kp3d = np.array(fp['gt3d']) 51 | total_image_names = np.array(fp['imagename']) 52 | 53 | assert len(total_kp2d) == len(total_kp3d) and len(total_kp2d) == len(total_image_names) 54 | 55 | l = len(total_kp2d) 56 | def _collect_valid_pts(pts): 57 | r = [] 58 | for pt in pts: 59 | if pt[2] != 0: 60 | r.append(pt) 61 | return r 62 | 63 | for index in range(l): 64 | kp2d = total_kp2d[index].reshape((-1, 3)) 65 | if np.sum(kp2d[:, 2]) < self.min_pts_required: 66 | continue 67 | 68 | lt, rb, v = calc_aabb(_collect_valid_pts(kp2d)) 69 | self.kp2ds.append(np.array(kp2d.copy(), dtype = np.float)) 70 | self.boxs.append((lt, rb)) 71 | self.kp3ds.append(total_kp3d[index].copy().reshape(-1, 3)) 72 | self.images.append(os.path.join(self.data_folder, 'image') + total_image_names[index].decode()) 73 | 74 | print('finished load mpii-inf-3dhp data, total {} samples'.format(len(self.images))) 75 | clk.stop() 76 | 77 | def __len__(self): 78 | return len(self.images) 79 | 80 | def __getitem__(self, index): 81 | image_path = self.images[index] 82 | kps = self.kp2ds[index].copy() 83 | box = self.boxs[index] 84 | kp_3d = self.kp3ds[index].copy() 85 | 86 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] 87 | image, kps = cut_image(image_path, kps, scale, box[0], box[1]) 88 | 89 | ratio = 1.0 * args.crop_size / image.shape[0] 90 | kps[:, :2] *= ratio 91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC) 92 | 93 | if self.use_flip and random.random() <= self.flip_prob: 94 | dst_image, kps = flip_image(dst_image, kps) 95 | kp_3d = reflect_lsp_kp(kp_3d) 96 | 97 | #normalize kp to [-1, 1] 98 | ratio = 1.0 / args.crop_size 99 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0 100 | 101 | return { 102 | 'image': torch.from_numpy(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(), 103 | 'kp_2d': torch.from_numpy(kps).float(), 104 | 'kp_3d': torch.from_numpy(kp_3d).float(), 105 | 'theta': torch.zeros(85).float(), 106 | 'image_name': self.images[index], 107 | 'w_smpl':0.0, 108 | 'w_3d':1.0, 109 | 'data_set':'mpi inf 3dhp' 110 | } 111 | 112 | if __name__ == '__main__': 113 | mpi = mpi_inf_3dhp_dataloader('E:/HMR/data/mpii_inf_3dhp', True, [1.1, 2.0], False, 5) 114 | l = len(mpi) 115 | for _ in range(l): 116 | r = mpi.__getitem__(_) 117 | base_name = os.path.basename(r['image_name']) 118 | draw_lsp_14kp__bone(r['image'], r['kp_2d']) 119 | cv2.imshow(base_name, cv2.resize(r['image'], (512, 512), interpolation = cv2.INTER_CUBIC)) 120 | cv2.waitKey(0) 121 | -------------------------------------------------------------------------------- /src/densenet.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.model_zoo as model_zoo 7 | from collections import OrderedDict 8 | import sys 9 | 10 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 11 | 12 | 13 | model_urls = { 14 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 15 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 16 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 17 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 18 | } 19 | 20 | 21 | def densenet121(pretrained=False, **kwargs): 22 | r"""Densenet-121 model from 23 | `"Densely Connected Convolutional Networks" `_ 24 | 25 | Args: 26 | pretrained (bool): If True, returns a model pre-trained on ImageNet 27 | """ 28 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 29 | **kwargs) 30 | if pretrained: 31 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 32 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 33 | # They are also in the checkpoints in model_urls. This pattern is used 34 | # to find such keys. 35 | pattern = re.compile( 36 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 37 | state_dict = model_zoo.load_url(model_urls['densenet121']) 38 | for key in list(state_dict.keys()): 39 | res = pattern.match(key) 40 | if res: 41 | new_key = res.group(1) + res.group(2) 42 | state_dict[new_key] = state_dict[key] 43 | del state_dict[key] 44 | model.load_state_dict(state_dict) 45 | return model 46 | 47 | 48 | def densenet169(pretrained=False, **kwargs): 49 | r"""Densenet-169 model from 50 | `"Densely Connected Convolutional Networks" `_ 51 | 52 | Args: 53 | pretrained (bool): If True, returns a model pre-trained on ImageNet 54 | """ 55 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 56 | **kwargs) 57 | if pretrained: 58 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 59 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 60 | # They are also in the checkpoints in model_urls. This pattern is used 61 | # to find such keys. 62 | pattern = re.compile( 63 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 64 | state_dict = model_zoo.load_url(model_urls['densenet169']) 65 | for key in list(state_dict.keys()): 66 | res = pattern.match(key) 67 | if res: 68 | new_key = res.group(1) + res.group(2) 69 | state_dict[new_key] = state_dict[key] 70 | del state_dict[key] 71 | model.load_state_dict(state_dict) 72 | return model 73 | 74 | 75 | def densenet201(pretrained=False, **kwargs): 76 | r"""Densenet-201 model from 77 | `"Densely Connected Convolutional Networks" `_ 78 | 79 | Args: 80 | pretrained (bool): If True, returns a model pre-trained on ImageNet 81 | """ 82 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 83 | **kwargs) 84 | if pretrained: 85 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 86 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 87 | # They are also in the checkpoints in model_urls. This pattern is used 88 | # to find such keys. 89 | pattern = re.compile( 90 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 91 | state_dict = model_zoo.load_url(model_urls['densenet201']) 92 | for key in list(state_dict.keys()): 93 | res = pattern.match(key) 94 | if res: 95 | new_key = res.group(1) + res.group(2) 96 | state_dict[new_key] = state_dict[key] 97 | del state_dict[key] 98 | model.load_state_dict(state_dict) 99 | return model 100 | 101 | 102 | def densenet161(pretrained=False, **kwargs): 103 | r"""Densenet-161 model from 104 | `"Densely Connected Convolutional Networks" `_ 105 | 106 | Args: 107 | pretrained (bool): If True, returns a model pre-trained on ImageNet 108 | """ 109 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 110 | **kwargs) 111 | if pretrained: 112 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 113 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 114 | # They are also in the checkpoints in model_urls. This pattern is used 115 | # to find such keys. 116 | pattern = re.compile( 117 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 118 | state_dict = model_zoo.load_url(model_urls['densenet161']) 119 | for key in list(state_dict.keys()): 120 | res = pattern.match(key) 121 | if res: 122 | new_key = res.group(1) + res.group(2) 123 | state_dict[new_key] = state_dict[key] 124 | del state_dict[key] 125 | model.load_state_dict(state_dict) 126 | return model 127 | 128 | 129 | class _DenseLayer(nn.Sequential): 130 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 131 | super(_DenseLayer, self).__init__() 132 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 133 | self.add_module('relu1', nn.ReLU(inplace=True)), 134 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 135 | growth_rate, kernel_size=1, stride=1, bias=False)), 136 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 137 | self.add_module('relu2', nn.ReLU(inplace=True)), 138 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 139 | kernel_size=3, stride=1, padding=1, bias=False)), 140 | self.drop_rate = drop_rate 141 | 142 | def forward(self, x): 143 | new_features = super(_DenseLayer, self).forward(x) 144 | if self.drop_rate > 0: 145 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 146 | return torch.cat([x, new_features], 1) 147 | 148 | 149 | class _DenseBlock(nn.Sequential): 150 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 151 | super(_DenseBlock, self).__init__() 152 | for i in range(num_layers): 153 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 154 | self.add_module('denselayer%d' % (i + 1), layer) 155 | 156 | 157 | class _Transition(nn.Sequential): 158 | def __init__(self, num_input_features, num_output_features): 159 | super(_Transition, self).__init__() 160 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 161 | self.add_module('relu', nn.ReLU(inplace=True)) 162 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 163 | kernel_size=1, stride=1, bias=False)) 164 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 165 | 166 | 167 | class DenseNet(nn.Module): 168 | r"""Densenet-BC model class, based on 169 | `"Densely Connected Convolutional Networks" `_ 170 | 171 | Args: 172 | growth_rate (int) - how many filters to add each layer (`k` in paper) 173 | block_config (list of 4 ints) - how many layers in each pooling block 174 | num_init_features (int) - the number of filters to learn in the first convolution layer 175 | bn_size (int) - multiplicative factor for number of bottle neck layers 176 | (i.e. bn_size * k features in the bottleneck layer) 177 | drop_rate (float) - dropout rate after each dense layer 178 | num_classes (int) - number of classification classes 179 | """ 180 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 181 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 182 | 183 | super(DenseNet, self).__init__() 184 | 185 | # First convolution 186 | self.features = nn.Sequential(OrderedDict([ 187 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 188 | ('norm0', nn.BatchNorm2d(num_init_features)), 189 | ('relu0', nn.ReLU(inplace=True)), 190 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 191 | ])) 192 | 193 | # Each denseblock 194 | num_features = num_init_features 195 | for i, num_layers in enumerate(block_config): 196 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 197 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 198 | self.features.add_module('denseblock%d' % (i + 1), block) 199 | num_features = num_features + num_layers * growth_rate 200 | if i != len(block_config) - 1: 201 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 202 | self.features.add_module('transition%d' % (i + 1), trans) 203 | num_features = num_features // 2 204 | 205 | # Final batch norm 206 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 207 | 208 | # Linear layer 209 | self.classifier = nn.Linear(num_features, num_classes) 210 | 211 | # Official init from torch repo. 212 | for m in self.modules(): 213 | if isinstance(m, nn.Conv2d): 214 | nn.init.kaiming_normal(m.weight.data) 215 | elif isinstance(m, nn.BatchNorm2d): 216 | m.weight.data.fill_(1) 217 | m.bias.data.zero_() 218 | elif isinstance(m, nn.Linear): 219 | m.bias.data.zero_() 220 | 221 | def forward(self, x): 222 | features = self.features(x) 223 | out = F.relu(features, inplace=True) 224 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 225 | #out = self.classifier(out) 226 | return out 227 | 228 | def load_denseNet(net_type): 229 | if net_type == 'densenet121': 230 | return densenet121(pretrained=True) 231 | elif net_type == 'densenet169': 232 | return densenet169(pretrained=True) 233 | elif net_type == 'densenet201': 234 | return densenet201(pretrained=True) 235 | elif net_type == 'densenet161': 236 | return densenet161(pretrained=True) 237 | else: 238 | msg = 'invalid denset net type' 239 | sys.exit(msg) 240 | 241 | if __name__ == '__main__': 242 | net = load_denseNet('densenet169') 243 | print(net) -------------------------------------------------------------------------------- /src/do_train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python3 trainer.py 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python3 trainer.py > train.log & 6 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: model.py 4 | 5 | date: 2018_05_03 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | from LinearModel import LinearModel 10 | import torch.nn as nn 11 | import numpy as np 12 | import torch 13 | import util 14 | from Discriminator import ShapeDiscriminator, PoseDiscriminator, FullPoseDiscriminator 15 | from SMPL import SMPL 16 | from config import args 17 | import config 18 | import Resnet 19 | from HourGlass import _create_hourglass_net 20 | from densenet import load_denseNet 21 | import sys 22 | 23 | class ThetaRegressor(LinearModel): 24 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func, iterations): 25 | super(ThetaRegressor, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func) 26 | self.iterations = iterations 27 | batch_size = max(args.batch_size + args.batch_3d_size, args.eval_batch_size) 28 | mean_theta = np.tile(util.load_mean_theta(), batch_size).reshape((batch_size, -1)) 29 | self.register_buffer('mean_theta', torch.from_numpy(mean_theta).float()) 30 | ''' 31 | param: 32 | inputs: is the output of encoder, which has 2048 features 33 | 34 | return: 35 | a list contains [ [theta1, theta1, ..., theta1], [theta2, theta2, ..., theta2], ... , ], shape is iterations X N X 85(or other theta count) 36 | ''' 37 | def forward(self, inputs): 38 | thetas = [] 39 | shape = inputs.shape 40 | theta = self.mean_theta[:shape[0], :] 41 | for _ in range(self.iterations): 42 | total_inputs = torch.cat([inputs, theta], 1) 43 | theta = theta + self.fc_blocks(total_inputs) 44 | thetas.append(theta) 45 | return thetas 46 | 47 | class HMRNetBase(nn.Module): 48 | def __init__(self): 49 | super(HMRNetBase, self).__init__() 50 | self._read_configs() 51 | 52 | print('start creating sub modules...') 53 | self._create_sub_modules() 54 | 55 | def _read_configs(self): 56 | def _check_config(): 57 | encoder_name = args.encoder_network 58 | enable_inter_supervions = args.enable_inter_supervision 59 | feature_count = args.feature_count 60 | if encoder_name == 'hourglass': 61 | assert args.crop_size == 256 62 | elif encoder_name == 'resnet50': 63 | assert args.crop_size == 224 64 | assert not enable_inter_supervions 65 | elif encoder_name.startswith('densenet'): 66 | assert args.crop_size == 224 67 | assert not enable_inter_supervions 68 | else: 69 | msg = 'invalid encoder network, only {} is allowd, got {}'.format(args.allowed_encoder_net, encoder_name) 70 | sys.exit(msg) 71 | assert config.encoder_feature_count[encoder_name] == feature_count 72 | 73 | _check_config() 74 | 75 | self.encoder_name = args.encoder_network 76 | self.beta_count = args.beta_count 77 | self.smpl_model = args.smpl_model 78 | self.smpl_mean_theta_path = args.smpl_mean_theta_path 79 | self.total_theta_count = args.total_theta_count 80 | self.joint_count = args.joint_count 81 | self.feature_count = args.feature_count 82 | 83 | def _create_sub_modules(self): 84 | ''' 85 | ddd smpl model, SMPL can create a mesh from beta & theta 86 | ''' 87 | self.smpl = SMPL(self.smpl_model, obj_saveable = True) 88 | 89 | ''' 90 | only resnet50 and hourglass is allowd currently, maybe other encoder will be allowd later. 91 | ''' 92 | if self.encoder_name == 'resnet50': 93 | print('creating resnet50') 94 | self.encoder = Resnet.load_Res50Model() 95 | elif self.encoder_name == 'hourglass': 96 | print('creating hourglass') 97 | self.encoder = _create_hourglass_net() 98 | elif self.encoder_name.startswith('densenet'): 99 | print('creating densenet') 100 | self.encoder = load_denseNet(self.encoder_name) 101 | else: 102 | assert 0 103 | ''' 104 | regressor can predict betas(include beta and theta which needed by SMPL) from coder extracted from encoder in a iteratirve way 105 | ''' 106 | fc_layers = [self.feature_count + self.total_theta_count, 1024, 1024, 85] 107 | use_dropout = [True, True, False] 108 | drop_prob = [0.5, 0.5, 0.5] 109 | use_ac_func = [True, True, False] #unactive the last layer 110 | iterations = 3 111 | self.regressor = ThetaRegressor(fc_layers, use_dropout, drop_prob, use_ac_func, iterations) 112 | self.iterations = iterations 113 | 114 | print('finished create the encoder modules...') 115 | 116 | def forward(self, inputs): 117 | if self.encoder_name == 'resnet50': 118 | feature = self.encoder(inputs) 119 | thetas = self.regressor(feature) 120 | detail_info = [] 121 | for theta in thetas: 122 | detail_info.append(self._calc_detail_info(theta)) 123 | return detail_info 124 | elif self.encoder_name.startswith('densenet'): 125 | feature = self.encoder(inputs) 126 | thetas = self.regressor(feature) 127 | detail_info = [] 128 | for theta in thetas: 129 | detail_info.append(self._calc_detail_info(theta)) 130 | return detail_info 131 | elif self.encoder_name == 'hourglass': 132 | if args.enable_inter_supervision: 133 | features = self.encoder(inputs) 134 | detail_info = [] 135 | for feature in features: 136 | thetas = self.regressor(feature) 137 | detail_info.append(self._calc_detail_info(thetas[-1])) 138 | return detail_info 139 | else: 140 | features = self.encoder(inputs) 141 | thetas = self.regressor(features[-1]) #only the last block 142 | detail_info = [] 143 | for theta in thetas: 144 | detail_info.append(self._calc_detail_info(theta)) 145 | return detail_info 146 | else: 147 | assert 0 148 | 149 | ''' 150 | purpose: 151 | calc verts, joint2d, joint3d, Rotation matrix 152 | 153 | inputs: 154 | theta: N X (3 + 72 + 10) 155 | 156 | return: 157 | thetas, verts, j2d, j3d, Rs 158 | ''' 159 | 160 | def _calc_detail_info(self, theta): 161 | cam = theta[:, 0:3].contiguous() 162 | pose = theta[:, 3:75].contiguous() 163 | shape = theta[:, 75:].contiguous() 164 | verts, j3d, Rs = self.smpl(beta = shape, theta = pose, get_skin = True) 165 | j2d = util.batch_orth_proj(j3d, cam) 166 | 167 | return (theta, verts, j2d, j3d, Rs) 168 | 169 | if __name__ == '__main__': 170 | cam = np.array([[0.9, 0, 0]], dtype = np.float) 171 | pose= np.array([[-9.44920200e+01, -4.25263865e+01, -1.30050643e+01, -2.79970490e-01, 172 | 3.24995661e-01, 5.03083125e-01, -6.90573755e-01, -4.12994214e-01, 173 | -4.21870093e-01, 5.98717416e-01, -1.48420885e-02, -3.85911139e-02, 174 | 1.13642605e-01, 2.30647176e-01, -2.11843286e-01, 1.31767149e+00, 175 | -6.61596447e-01, 4.02174644e-01, 3.03129424e-02, 5.91100770e-02, 176 | -8.04416564e-02, -1.12944653e-01, 3.15045050e-01, -1.32838375e-01, 177 | -1.33748209e-01, -4.99408923e-01, 1.40508643e-01, 6.10867911e-02, 178 | -2.22951915e-02, -4.73448564e-02, -1.48489055e-01, 1.47620442e-01, 179 | 3.24157346e-01, 7.78414851e-04, 1.70687935e-01, -1.54716815e-01, 180 | 2.95053507e-01, -2.91967776e-01, 1.26000780e-01, 8.09572677e-02, 181 | 1.54710846e-02, -4.21941758e-01, 7.44124075e-02, 1.17146423e-01, 182 | 3.16305389e-01, 5.04810448e-01, -3.65526364e-01, 1.31366428e-01, 183 | -2.76658949e-02, -9.17315987e-03, -1.88285742e-01, 7.86409877e-03, 184 | -9.41106758e-02, 2.08424367e-01, 1.62278709e-01, -7.98170265e-01, 185 | -3.97403587e-03, 1.11321421e-01, 6.07793270e-01, 1.42215980e-01, 186 | 4.48185010e-01, -1.38429048e-01, 3.77056061e-02, 4.48877661e-01, 187 | 1.31445158e-01, 5.07427503e-02, -3.80920772e-01, -2.52292254e-02, 188 | -5.27745375e-02, -7.43903887e-02, 7.22498075e-02, -6.35824487e-03]]) 189 | 190 | beta = np.array([[-3.54196257, 0.90870435, -1.0978663 , -0.20436199, 0.18589762, 0.55789026, -0.18163599, 0.12002746, -0.09172286, 0.4430783 ]]) 191 | real_shapes = torch.from_numpy(beta).float().cuda() 192 | real_poses = torch.from_numpy(pose).float().cuda() 193 | 194 | net = HMRNetBase().cuda() 195 | nx = torch.rand(2, 3, 224, 224).float().cuda() 196 | -------------------------------------------------------------------------------- /src/timer.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: timer.py 4 | 5 | date: 2018_05_09 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | import time 10 | 11 | class Clock: 12 | def __init__(self, start_tick = True): 13 | self.pre_time = 0 14 | if start_tick: 15 | self.start() 16 | 17 | def start(self): 18 | self.pre_time = time.time() 19 | 20 | def stop(self): 21 | self.cur_time = time.time() 22 | print('time {} elapsed!'.format(self.cur_time - self.pre_time)) 23 | self.pre_time = self.cur_time -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ''' 4 | file: trainer.py 5 | 6 | date: 2018_05_07 7 | author: zhangxiong(1025679612@qq.com) 8 | ''' 9 | 10 | import sys 11 | from model import HMRNetBase 12 | from Discriminator import Discriminator 13 | from config import args 14 | import config 15 | import torch 16 | import torch.nn as nn 17 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 18 | 19 | from dataloader.AICH_dataloader import AICH_dataloader 20 | from dataloader.COCO2017_dataloader import COCO2017_dataloader 21 | from dataloader.hum36m_dataloader import hum36m_dataloader 22 | from dataloader.lsp_dataloader import LspLoader 23 | from dataloader.lsp_ext_dataloader import LspExtLoader 24 | from dataloader.mosh_dataloader import mosh_dataloader 25 | from dataloader.mpi_inf_3dhp_dataloader import mpi_inf_3dhp_dataloader 26 | from dataloader.eval_dataloader import eval_dataloader 27 | 28 | from util import align_by_pelvis, batch_rodrigues, copy_state_dict 29 | from timer import Clock 30 | import time 31 | import datetime 32 | from collections import OrderedDict 33 | import os 34 | 35 | class HMRTrainer(object): 36 | def __init__(self): 37 | self.pix_format = 'NCHW' 38 | self.normalize = True 39 | self.flip_prob = 0.5 40 | self.use_flip = False 41 | self.w_smpl = torch.ones((config.args.eval_batch_size)).float().cuda() 42 | 43 | self._build_model() 44 | self._create_data_loader() 45 | 46 | def _create_data_loader(self): 47 | self.loader_2d = self._create_2d_data_loader(config.train_2d_set) 48 | self.loader_mosh = self._create_adv_data_loader(config.train_adv_set) 49 | self.loader_3d = self._create_3d_data_loader(config.train_3d_set) 50 | 51 | def _build_model(self): 52 | print('start building modle.') 53 | 54 | ''' 55 | load pretrain model 56 | ''' 57 | generator = HMRNetBase() 58 | model_path = config.pre_trained_model['generator'] 59 | if os.path.exists(model_path): 60 | copy_state_dict( 61 | generator.state_dict(), 62 | torch.load(model_path), 63 | prefix = 'module.' 64 | ) 65 | else: 66 | print('model {} not exist!'.format(model_path)) 67 | 68 | discriminator = Discriminator() 69 | model_path = config.pre_trained_model['discriminator'] 70 | if os.path.exists(model_path): 71 | copy_state_dict( 72 | discriminator.state_dict(), 73 | torch.load(model_path), 74 | prefix = 'module.' 75 | ) 76 | else: 77 | print('model {} not exist!'.format(model_path)) 78 | 79 | self.generator = nn.DataParallel(generator).cuda() 80 | self.discriminator = nn.DataParallel(discriminator).cuda() 81 | 82 | self.e_opt = torch.optim.Adam( 83 | self.generator.parameters(), 84 | lr = args.e_lr, 85 | weight_decay = args.e_wd 86 | ) 87 | 88 | self.d_opt = torch.optim.Adam( 89 | self.discriminator.parameters(), 90 | lr = args.d_lr, 91 | weight_decay = args.d_wd 92 | ) 93 | 94 | self.e_sche = torch.optim.lr_scheduler.StepLR( 95 | self.e_opt, 96 | step_size = 500, 97 | gamma = 0.9 98 | ) 99 | 100 | self.d_sche = torch.optim.lr_scheduler.StepLR( 101 | self.d_opt, 102 | step_size = 500, 103 | gamma = 0.9 104 | ) 105 | 106 | print('finished build model.') 107 | 108 | def _create_2d_data_loader(self, data_2d_set): 109 | data_set = [] 110 | for data_set_name in data_2d_set: 111 | data_set_path = config.data_set_path[data_set_name] 112 | if data_set_name == 'coco': 113 | coco = COCO2017_dataloader( 114 | data_set_path = data_set_path, 115 | use_crop = True, 116 | scale_range = [1.05, 1.3], 117 | use_flip = self.use_flip, 118 | only_single_person = False, 119 | min_pts_required = 7, 120 | max_intersec_ratio = 0.5, 121 | pix_format = self.pix_format, 122 | normalize = self.normalize, 123 | flip_prob = self.flip_prob 124 | ) 125 | data_set.append(coco) 126 | elif data_set_name == 'lsp': 127 | lsp = LspLoader( 128 | data_set_path = data_set_path, 129 | use_crop = True, 130 | scale_range = [1.05, 1.3], 131 | use_flip = self.use_flip, 132 | pix_format = self.pix_format, 133 | normalize = self.normalize, 134 | flip_prob = self.flip_prob 135 | ) 136 | data_set.append(lsp) 137 | elif data_set_name == 'lsp_ext': 138 | lsp_ext = LspExtLoader( 139 | data_set_path = data_set_path, 140 | use_crop = True, 141 | scale_range = [1.1, 1.2], 142 | use_flip = self.use_flip, 143 | pix_format = self.pix_format, 144 | normalize = self.normalize, 145 | flip_prob = self.flip_prob 146 | ) 147 | data_set.append(lsp_ext) 148 | elif data_set_name == 'ai-ch': 149 | ai_ch = AICH_dataloader( 150 | data_set_path = data_set_path, 151 | use_crop = True, 152 | scale_range = [1.1, 1.2], 153 | use_flip = self.use_flip, 154 | only_single_person = False, 155 | min_pts_required = 5, 156 | max_intersec_ratio = 0.1, 157 | pix_format = self.pix_format, 158 | normalize = self.normalize, 159 | flip_prob = self.flip_prob 160 | ) 161 | data_set.append(ai_ch) 162 | else: 163 | msg = 'invalid 2d dataset' 164 | sys.exit(msg) 165 | 166 | con_2d_dataset = ConcatDataset(data_set) 167 | 168 | return DataLoader( 169 | dataset = con_2d_dataset, 170 | batch_size = config.args.batch_size, 171 | shuffle = True, 172 | drop_last = True, 173 | pin_memory = True, 174 | num_workers = config.args.num_worker 175 | ) 176 | 177 | def _create_3d_data_loader(self, data_3d_set): 178 | data_set = [] 179 | for data_set_name in data_3d_set: 180 | data_set_path = config.data_set_path[data_set_name] 181 | if data_set_name == 'mpi-inf-3dhp': 182 | mpi_inf_3dhp = mpi_inf_3dhp_dataloader( 183 | data_set_path = data_set_path, 184 | use_crop = True, 185 | scale_range = [1.1, 1.2], 186 | use_flip = self.use_flip, 187 | min_pts_required = 5, 188 | pix_format = self.pix_format, 189 | normalize = self.normalize, 190 | flip_prob = self.flip_prob 191 | ) 192 | data_set.append(mpi_inf_3dhp) 193 | elif data_set_name == 'hum3.6m': 194 | hum36m = hum36m_dataloader( 195 | data_set_path = data_set_path, 196 | use_crop = True, 197 | scale_range = [1.1, 1.2], 198 | use_flip = self.use_flip, 199 | min_pts_required = 5, 200 | pix_format = self.pix_format, 201 | normalize = self.normalize, 202 | flip_prob = self.flip_prob 203 | ) 204 | data_set.append(hum36m) 205 | else: 206 | msg = 'invalid 3d dataset' 207 | sys.exit(msg) 208 | 209 | con_3d_dataset = ConcatDataset(data_set) 210 | 211 | return DataLoader( 212 | dataset = con_3d_dataset, 213 | batch_size = config.args.batch_3d_size, 214 | shuffle = True, 215 | drop_last = True, 216 | pin_memory = True, 217 | num_workers = config.args.num_worker 218 | ) 219 | 220 | def _create_adv_data_loader(self, data_adv_set): 221 | data_set = [] 222 | for data_set_name in data_adv_set: 223 | data_set_path = config.data_set_path[data_set_name] 224 | if data_set_name == 'mosh': 225 | mosh = mosh_dataloader( 226 | data_set_path = data_set_path, 227 | use_flip = self.use_flip, 228 | flip_prob = self.flip_prob 229 | ) 230 | data_set.append(mosh) 231 | else: 232 | msg = 'invalid adv dataset' 233 | sys.exit(msg) 234 | 235 | con_adv_dataset = ConcatDataset(data_set) 236 | return DataLoader( 237 | dataset = con_adv_dataset, 238 | batch_size = config.args.adv_batch_size, 239 | shuffle = True, 240 | drop_last = True, 241 | pin_memory = True, 242 | ) 243 | 244 | def _create_eval_data_loader(self, data_eval_set): 245 | data_set = [] 246 | for data_set_name in data_eval_set: 247 | data_set_path = config.data_set_path[data_set_name] 248 | if data_set_name == 'up3d': 249 | up3d = eval_dataloader( 250 | data_set_path = data_set_path, 251 | use_flip = False, 252 | flip_prob = self.flip_prob, 253 | pix_format = self.pix_format, 254 | normalize = self.normalize 255 | ) 256 | data_set.append(up3d) 257 | else: 258 | msg = 'invalid eval dataset' 259 | sys.exit(msg) 260 | con_eval_dataset = ConcatDataset(data_set) 261 | return DataLoader( 262 | dataset = con_eval_dataset, 263 | batch_size = config.args.eval_batch_size, 264 | shuffle = False, 265 | drop_last = False, 266 | pin_memory = True, 267 | num_workers = config.args.num_worker 268 | ) 269 | 270 | def train(self): 271 | def save_model(result): 272 | exclude_key = 'module.smpl' 273 | def exclude_smpl(model_dict): 274 | result = OrderedDict() 275 | for (k, v) in model_dict.items(): 276 | if exclude_key in k: 277 | continue 278 | result[k] = v 279 | return result 280 | 281 | parent_folder = args.save_folder 282 | if not os.path.exists(parent_folder): 283 | os.makedirs(parent_folder) 284 | 285 | title = result['title'] 286 | generator_save_path = os.path.join(parent_folder, title + 'generator.pkl') 287 | torch.save(exclude_smpl(self.generator.state_dict()), generator_save_path) 288 | disc_save_path = os.path.join(parent_folder, title + 'discriminator.pkl') 289 | torch.save(exclude_smpl(self.discriminator.state_dict()), disc_save_path) 290 | with open(os.path.join(parent_folder, title + '.txt'), 'w') as fp: 291 | fp.write(str(result)) 292 | 293 | #pre_best_loss = None 294 | 295 | torch.backends.cudnn.benchmark = True 296 | loader_2d, loader_3d, loader_mosh = iter(self.loader_2d), iter(self.loader_3d), iter(self.loader_mosh) 297 | e_opt, d_opt = self.e_opt, self.d_opt 298 | 299 | self.generator.train() 300 | self.discriminator.train() 301 | 302 | for iter_index in range(config.args.iter_count): 303 | try: 304 | data_2d = next(loader_2d) 305 | except StopIteration: 306 | loader_2d = iter(self.loader_2d) 307 | data_2d = next(loader_2d) 308 | 309 | try: 310 | data_3d = next(loader_3d) 311 | except StopIteration: 312 | loader_3d = iter(self.loader_3d) 313 | data_3d = next(loader_3d) 314 | 315 | try: 316 | data_mosh = next(loader_mosh) 317 | except StopIteration: 318 | loader_mosh = iter(self.loader_mosh) 319 | data_mosh = next(loader_mosh) 320 | 321 | image_from_2d, image_from_3d = data_2d['image'], data_3d['image'] 322 | sample_2d_count, sample_3d_count, sample_mosh_count = image_from_2d.shape[0], image_from_3d.shape[0], data_mosh['theta'].shape[0] 323 | images = torch.cat((image_from_2d, image_from_3d), dim = 0).cuda() 324 | 325 | generator_outputs = self.generator(images) 326 | 327 | loss_kp_2d, loss_kp_3d, loss_shape, loss_pose, e_disc_loss, d_disc_loss, d_disc_real, d_disc_predict = self._calc_loss(generator_outputs, data_2d, data_3d, data_mosh) 328 | 329 | e_loss = loss_kp_2d + loss_kp_3d + loss_shape + loss_pose + e_disc_loss 330 | d_loss = d_disc_loss 331 | 332 | e_opt.zero_grad() 333 | e_loss.backward() 334 | e_opt.step() 335 | 336 | d_opt.zero_grad() 337 | d_loss.backward() 338 | d_opt.step() 339 | 340 | loss_kp_2d = float(loss_kp_2d) 341 | loss_shape = float(loss_shape / args.e_shape_ratio) 342 | loss_kp_3d = float(loss_kp_3d / args.e_3d_kp_ratio) 343 | loss_pose = float(loss_pose / args.e_pose_ratio) 344 | e_disc_loss = float(e_disc_loss / args.d_disc_ratio) 345 | d_disc_loss = float(d_disc_loss / args.d_disc_ratio) 346 | 347 | d_disc_real = float(d_disc_real / args.d_disc_ratio) 348 | d_disc_predict = float(d_disc_predict / args.d_disc_ratio) 349 | 350 | e_loss = loss_kp_2d + loss_kp_3d + loss_shape + loss_pose + e_disc_loss 351 | d_loss = d_disc_loss 352 | 353 | iter_msg = OrderedDict( 354 | [ 355 | ('time',datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), 356 | ('iter',iter_index), 357 | ('e_loss', e_loss), 358 | ('2d_loss',loss_kp_2d), 359 | ('3d_loss',loss_kp_3d), 360 | ('shape_loss',loss_shape), 361 | ('pose_loss', loss_pose), 362 | ('e_disc_loss',float(e_disc_loss)), 363 | ('d_disc_loss',float(d_disc_loss)), 364 | ('d_disc_real', float(d_disc_real)), 365 | ('d_disc_predict', float(d_disc_predict)) 366 | ] 367 | ) 368 | 369 | print(iter_msg) 370 | 371 | if iter_index % 500 == 0: 372 | iter_msg['title'] = '{}_{}_'.format(iter_msg['iter'], iter_msg['e_loss']) 373 | save_model(iter_msg) 374 | 375 | def _calc_loss(self, generator_outputs, data_2d, data_3d, data_mosh): 376 | def _accumulate_thetas(generator_outputs): 377 | thetas = [] 378 | for (theta, verts, j2d, j3d, Rs) in generator_outputs: 379 | thetas.append(theta) 380 | return torch.cat(thetas, 0) 381 | 382 | sample_2d_count, sample_3d_count, sample_mosh_count = data_2d['kp_2d'].shape[0], data_3d['kp_2d'].shape[0], data_mosh['theta'].shape 383 | data_3d_theta, w_3d, w_smpl = data_3d['theta'].cuda(), data_3d['w_3d'].float().cuda(), data_3d['w_smpl'].float().cuda() 384 | 385 | total_predict_thetas = _accumulate_thetas(generator_outputs) 386 | (predict_theta, predict_verts, predict_j2d, predict_j3d, predict_Rs) = generator_outputs[-1] 387 | 388 | real_2d, real_3d = torch.cat((data_2d['kp_2d'], data_3d['kp_2d']), 0).cuda(), data_3d['kp_3d'].float().cuda() 389 | predict_j2d, predict_j3d, predict_theta = predict_j2d, predict_j3d[sample_2d_count:, :], predict_theta[sample_2d_count:, :] 390 | 391 | loss_kp_2d = self.batch_kp_2d_l1_loss(real_2d, predict_j2d[:,:14,:]) * args.e_loss_weight 392 | loss_kp_3d = self.batch_kp_3d_l2_loss(real_3d, predict_j3d[:,:14,:], w_3d) * args.e_3d_loss_weight * args.e_3d_kp_ratio 393 | 394 | real_shape, predict_shape = data_3d_theta[:, 75:], predict_theta[:, 75:] 395 | loss_shape = self.batch_shape_l2_loss(real_shape, predict_shape, w_smpl) * args.e_3d_loss_weight * args.e_shape_ratio 396 | 397 | real_pose, predict_pose = data_3d_theta[:, 3:75], predict_theta[:, 3:75] 398 | loss_pose = self.batch_pose_l2_loss(real_pose.contiguous(), predict_pose.contiguous(), w_smpl) * args.e_3d_loss_weight * args.e_pose_ratio 399 | 400 | e_disc_loss = self.batch_encoder_disc_l2_loss(self.discriminator(total_predict_thetas)) * args.d_loss_weight * args.d_disc_ratio 401 | 402 | mosh_real_thetas = data_mosh['theta'].cuda() 403 | fake_thetas = total_predict_thetas.detach() 404 | fake_disc_value, real_disc_value = self.discriminator(fake_thetas), self.discriminator(mosh_real_thetas) 405 | d_disc_real, d_disc_fake, d_disc_loss = self.batch_adv_disc_l2_loss(real_disc_value, fake_disc_value) 406 | d_disc_real, d_disc_fake, d_disc_loss = d_disc_real * args.d_loss_weight * args.d_disc_ratio, d_disc_fake * args.d_loss_weight * args.d_disc_ratio, d_disc_loss * args.d_loss_weight * args.d_disc_ratio 407 | 408 | return loss_kp_2d, loss_kp_3d, loss_shape, loss_pose, e_disc_loss, d_disc_loss, d_disc_real, d_disc_fake 409 | 410 | """ 411 | purpose: 412 | calc L1 error 413 | Inputs: 414 | kp_gt : N x K x 3 415 | kp_pred: N x K x 2 416 | """ 417 | def batch_kp_2d_l1_loss(self, real_2d_kp, predict_2d_kp): 418 | kp_gt = real_2d_kp.view(-1, 3) 419 | kp_pred = predict_2d_kp.contiguous().view(-1, 2) 420 | vis = kp_gt[:, 2] 421 | k = torch.sum(vis) * 2.0 + 1e-8 422 | dif_abs = torch.abs(kp_gt[:, :2] - kp_pred).sum(1) 423 | return torch.matmul(dif_abs, vis) * 1.0 / k 424 | 425 | ''' 426 | purpose: 427 | calc mse * 0.5 428 | 429 | Inputs: 430 | real_3d_kp : N x k x 3 431 | fake_3d_kp : N x k x 3 432 | w_3d : N x 1 433 | ''' 434 | def batch_kp_3d_l2_loss(self, real_3d_kp, fake_3d_kp, w_3d): 435 | shape = real_3d_kp.shape 436 | k = torch.sum(w_3d) * shape[1] * 3.0 * 2.0 + 1e-8 437 | 438 | #first align it 439 | real_3d_kp, fake_3d_kp = align_by_pelvis(real_3d_kp), align_by_pelvis(fake_3d_kp) 440 | kp_gt = real_3d_kp 441 | kp_pred = fake_3d_kp 442 | kp_dif = (kp_gt - kp_pred) ** 2 443 | return torch.matmul(kp_dif.sum(1).sum(1), w_3d) * 1.0 / k 444 | 445 | ''' 446 | purpose: 447 | calc mse * 0.5 448 | 449 | Inputs: 450 | real_shape : N x 10 451 | fake_shape : N x 10 452 | w_shape : N x 1 453 | ''' 454 | def batch_shape_l2_loss(self, real_shape, fake_shape, w_shape): 455 | k = torch.sum(w_shape) * 10.0 * 2.0 + 1e-8 456 | shape_dif = (real_shape - fake_shape) ** 2 457 | return torch.matmul(shape_dif.sum(1), w_shape) * 1.0 / k 458 | 459 | ''' 460 | Input: 461 | real_pose : N x 72 462 | fake_pose : N x 72 463 | ''' 464 | def batch_pose_l2_loss(self, real_pose, fake_pose, w_pose): 465 | k = torch.sum(w_pose) * 207.0 * 2.0 + 1e-8 466 | real_rs, fake_rs = batch_rodrigues(real_pose.view(-1, 3)).view(-1, 24, 9)[:,1:,:], batch_rodrigues(fake_pose.view(-1, 3)).view(-1, 24, 9)[:,1:,:] 467 | dif_rs = ((real_rs - fake_rs) ** 2).view(-1, 207) 468 | return torch.matmul(dif_rs.sum(1), w_pose) * 1.0 / k 469 | ''' 470 | Inputs: 471 | disc_value: N x 25 472 | ''' 473 | def batch_encoder_disc_l2_loss(self, disc_value): 474 | k = disc_value.shape[0] 475 | return torch.sum((disc_value - 1.0) ** 2) * 1.0 / k 476 | ''' 477 | Inputs: 478 | disc_value: N x 25 479 | ''' 480 | def batch_adv_disc_l2_loss(self, real_disc_value, fake_disc_value): 481 | ka = real_disc_value.shape[0] 482 | kb = fake_disc_value.shape[0] 483 | lb, la = torch.sum(fake_disc_value ** 2) / kb, torch.sum((real_disc_value - 1) ** 2) / ka 484 | return la, lb, la + lb 485 | 486 | def main(): 487 | trainer = HMRTrainer() 488 | trainer.train() 489 | 490 | if __name__ == '__main__': 491 | main() 492 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | file: util.py 4 | 5 | date: 2018_04_29 6 | author: zhangxiong(1025679612@qq.com) 7 | ''' 8 | 9 | import h5py 10 | import torch 11 | import numpy as np 12 | from config import args 13 | import json 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | import cv2 17 | import math 18 | from scipy import interpolate 19 | 20 | def load_mean_theta(): 21 | mean = np.zeros(args.total_theta_count, dtype = np.float) 22 | 23 | mean_values = h5py.File(args.smpl_mean_theta_path) 24 | mean_pose = mean_values['pose'] 25 | mean_pose[:3] = 0 26 | mean_shape = mean_values['shape'] 27 | mean_pose[0]=np.pi 28 | 29 | #init sacle is 0.9 30 | mean[0] = 0.9 31 | 32 | mean[3:75] = mean_pose[:] 33 | mean[75:] = mean_shape[:] 34 | 35 | return mean 36 | 37 | def batch_rodrigues(theta): 38 | #theta N x 3 39 | batch_size = theta.shape[0] 40 | l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) 41 | angle = torch.unsqueeze(l1norm, -1) 42 | normalized = torch.div(theta, angle) 43 | angle = angle * 0.5 44 | v_cos = torch.cos(angle) 45 | v_sin = torch.sin(angle) 46 | quat = torch.cat([v_cos, v_sin * normalized], dim = 1) 47 | 48 | return quat2mat(quat) 49 | 50 | def quat2mat(quat): 51 | """Convert quaternion coefficients to rotation matrix. 52 | Args: 53 | quat: size = [B, 4] 4 <===>(w, x, y, z) 54 | Returns: 55 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 56 | """ 57 | norm_quat = quat 58 | norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) 59 | w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] 60 | 61 | B = quat.size(0) 62 | 63 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 64 | wx, wy, wz = w*x, w*y, w*z 65 | xy, xz, yz = x*y, x*z, y*z 66 | 67 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 68 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 69 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 70 | return rotMat 71 | 72 | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False): 73 | N = Rs.shape[0] 74 | if rotate_base: 75 | np_rot_x = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype = np.float) 76 | np_rot_x = np.reshape(np.tile(np_rot_x, [N, 1]), [N, 3, 3]) 77 | rot_x = Variable(torch.from_numpy(np_rot_x).float()).cuda() 78 | root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) 79 | else: 80 | root_rotation = Rs[:, 0, :, :] 81 | Js = torch.unsqueeze(Js, -1) 82 | 83 | def make_A(R, t): 84 | R_homo = F.pad(R, [0, 0, 0, 1, 0, 0]) 85 | t_homo = torch.cat([t, Variable(torch.ones(N, 1, 1)).cuda()], dim = 1) 86 | return torch.cat([R_homo, t_homo], 2) 87 | 88 | A0 = make_A(root_rotation, Js[:, 0]) 89 | results = [A0] 90 | 91 | for i in range(1, parent.shape[0]): 92 | j_here = Js[:, i] - Js[:, parent[i]] 93 | A_here = make_A(Rs[:, i], j_here) 94 | res_here = torch.matmul(results[parent[i]], A_here) 95 | results.append(res_here) 96 | 97 | results = torch.stack(results, dim = 1) 98 | 99 | new_J = results[:, :, :3, 3] 100 | Js_w0 = torch.cat([Js, Variable(torch.zeros(N, 24, 1, 1)).cuda()], dim = 2) 101 | init_bone = torch.matmul(results, Js_w0) 102 | init_bone = F.pad(init_bone, [3, 0, 0, 0, 0, 0, 0, 0]) 103 | A = results - init_bone 104 | 105 | return new_J, A 106 | 107 | 108 | def batch_lrotmin(theta): 109 | theta = theta[:,3:].contiguous() 110 | Rs = batch_rodrigues(theta.view(-1, 3)) 111 | print(Rs.shape) 112 | e = Variable(torch.eye(3).float()) 113 | Rs = Rs.sub(1.0, e) 114 | 115 | return Rs.view(-1, 23 * 9) 116 | 117 | def batch_orth_proj(X, camera): 118 | ''' 119 | X is N x num_points x 3 120 | ''' 121 | camera = camera.view(-1, 1, 3) 122 | X_trans = X[:, :, :2] + camera[:, :, 1:] 123 | shape = X_trans.shape 124 | return (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) 125 | 126 | def calc_aabb(ptSets): 127 | if not ptSets or len(ptSets) == 0: 128 | return False, False, False 129 | 130 | ptLeftTop = np.array([ptSets[0][0], ptSets[0][1]]) 131 | ptRightBottom = ptLeftTop.copy() 132 | for pt in ptSets: 133 | ptLeftTop[0] = min(ptLeftTop[0], pt[0]) 134 | ptLeftTop[1] = min(ptLeftTop[1], pt[1]) 135 | ptRightBottom[0] = max(ptRightBottom[0], pt[0]) 136 | ptRightBottom[1] = max(ptRightBottom[1], pt[1]) 137 | 138 | return ptLeftTop, ptRightBottom, len(ptSets) >= 5 139 | 140 | ''' 141 | calculate a obb for a set of points 142 | 143 | inputs: 144 | ptSets: a set of points 145 | 146 | return the center and 4 corners of a obb 147 | ''' 148 | def calc_obb(ptSets): 149 | ca = np.cov(ptSets,y = None,rowvar = 0,bias = 1) 150 | v, vect = np.linalg.eig(ca) 151 | tvect = np.transpose(vect) 152 | ar = np.dot(ptSets,np.linalg.inv(tvect)) 153 | mina = np.min(ar,axis=0) 154 | maxa = np.max(ar,axis=0) 155 | diff = (maxa - mina)*0.5 156 | center = mina + diff 157 | corners = np.array([center+[-diff[0],-diff[1]],center+[diff[0],-diff[1]],center+[diff[0],diff[1]],center+[-diff[0],diff[1]]]) 158 | corners = np.dot(corners, tvect) 159 | return corners[0], corners[1], corners[2], corners[3] 160 | 161 | def get_image_cut_box(leftTop, rightBottom, ExpandsRatio, Center = None): 162 | try: 163 | l = len(ExpandsRatio) 164 | except: 165 | ExpandsRatio = [ExpandsRatio, ExpandsRatio, ExpandsRatio, ExpandsRatio] 166 | 167 | def _expand_crop_box(lt, rb, scale): 168 | center = (lt + rb) / 2.0 169 | xl, xr, yt, yb = lt[0] - center[0], rb[0] - center[0], lt[1] - center[1], rb[1] - center[1] 170 | xl, xr, yt, yb = xl * scale[0], xr * scale[1], yt * scale[2], yb * scale[3] 171 | #expand it 172 | lt, rb = np.array([center[0] + xl, center[1] + yt]), np.array([center[0] + xr, center[1] + yb]) 173 | lb, rt = np.array([center[0] + xl, center[1] + yb]), np.array([center[0] + xr, center[1] + yt]) 174 | center = (lt + rb) / 2 175 | return center, lt, rt, rb, lb 176 | 177 | if Center == None: 178 | Center = (leftTop + rightBottom) // 2 179 | 180 | Center, leftTop, rightTop, rightBottom, leftBottom = _expand_crop_box(leftTop, rightBottom, ExpandsRatio) 181 | 182 | offset = (rightBottom - leftTop) // 2 183 | 184 | cx = offset[0] 185 | cy = offset[1] 186 | 187 | r = max(cx, cy) 188 | 189 | cx = r 190 | cy = r 191 | 192 | x = int(Center[0]) 193 | y = int(Center[1]) 194 | 195 | return [x - cx, y - cy], [x + cx, y + cy] 196 | 197 | def shrink(leftTop, rightBottom, width, height): 198 | xl = -leftTop[0] 199 | xr = rightBottom[0] - width 200 | 201 | yt = -leftTop[1] 202 | yb = rightBottom[1] - height 203 | 204 | cx = (leftTop[0] + rightBottom[0]) / 2 205 | cy = (leftTop[1] + rightBottom[1]) / 2 206 | 207 | r = (rightBottom[0] - leftTop[0]) / 2 208 | 209 | sx = max(xl, 0) + max(xr, 0) 210 | sy = max(yt, 0) + max(yb, 0) 211 | 212 | if (xl <= 0 and xr <= 0) or (yt <= 0 and yb <=0): 213 | return leftTop, rightBottom 214 | elif leftTop[0] >= 0 and leftTop[1] >= 0 : # left top corner is in box 215 | l = min(yb, xr) 216 | r = r - l / 2 217 | cx = cx - l / 2 218 | cy = cy - l / 2 219 | elif rightBottom[0] <= width and rightBottom[1] <= height : # right bottom corner is in box 220 | l = min(yt, xl) 221 | r = r - l / 2 222 | cx = cx + l / 2 223 | cy = cy + l / 2 224 | elif leftTop[0] >= 0 and rightBottom[1] <= height : #left bottom corner is in box 225 | l = min(xr, yt) 226 | r = r - l / 2 227 | cx = cx - l / 2 228 | cy = cy + l / 2 229 | elif rightBottom[0] <= width and leftTop[1] >= 0 : #right top corner is in box 230 | l = min(xl, yb) 231 | r = r - l / 2 232 | cx = cx + l / 2 233 | cy = cy - l / 2 234 | elif xl < 0 or xr < 0 or yb < 0 or yt < 0: 235 | return leftTop, rightBottom 236 | elif sx >= sy: 237 | sx = max(xl, 0) + max(0, xr) 238 | sy = max(yt, 0) + max(0, yb) 239 | # cy = height / 2 240 | if yt >= 0 and yb >= 0: 241 | cy = height / 2 242 | elif yt >= 0: 243 | cy = cy + sy / 2 244 | else: 245 | cy = cy - sy / 2 246 | r = r - sy / 2 247 | 248 | if xl >= sy / 2 and xr >= sy / 2: 249 | pass 250 | elif xl < sy / 2: 251 | cx = cx - (sy / 2 - xl) 252 | else: 253 | cx = cx + (sy / 2 - xr) 254 | elif sx < sy: 255 | cx = width / 2 256 | r = r - sx / 2 257 | if yt >= sx / 2 and yb >= sx / 2: 258 | pass 259 | elif yt < sx / 2: 260 | cy = cy - (sx / 2 - yt) 261 | else: 262 | cy = cy + (sx / 2 - yb) 263 | 264 | 265 | return [cx - r, cy - r], [cx + r, cy + r] 266 | 267 | ''' 268 | offset the keypoint by leftTop 269 | ''' 270 | def off_set_pts(keyPoints, leftTop): 271 | result = keyPoints.copy() 272 | result[:, 0] -= leftTop[0] 273 | result[:, 1] -= leftTop[1] 274 | return result 275 | 276 | ''' 277 | cut the image, by expanding a bounding box 278 | ''' 279 | def cut_image(filePath, kps, expand_ratio, leftTop, rightBottom): 280 | originImage = cv2.imread(filePath) 281 | height = originImage.shape[0] 282 | width = originImage.shape[1] 283 | channels = originImage.shape[2] if len(originImage.shape) >= 3 else 1 284 | 285 | leftTop, rightBottom = get_image_cut_box(leftTop, rightBottom, expand_ratio) 286 | 287 | #remove extra space. 288 | #leftTop, rightBottom = shrink(leftTop, rightBottom, width, height) 289 | 290 | lt = [int(leftTop[0]), int(leftTop[1])] 291 | rb = [int(rightBottom[0]), int(rightBottom[1])] 292 | 293 | lt[0] = max(0, lt[0]) 294 | lt[1] = max(0, lt[1]) 295 | rb[0] = min(rb[0], width) 296 | rb[1] = min(rb[1], height) 297 | 298 | leftTop = [int(leftTop[0]), int(leftTop[1])] 299 | rightBottom = [int(rightBottom[0] + 0.5), int(rightBottom[1] + 0.5)] 300 | 301 | dstImage = np.zeros(shape = [rightBottom[1] - leftTop[1], rightBottom[0] - leftTop[0], channels], dtype = np.uint8) 302 | dstImage[:,:,:] = 0 303 | 304 | offset = [lt[0] - leftTop[0], lt[1] - leftTop[1]] 305 | size = [rb[0] - lt[0], rb[1] - lt[1]] 306 | 307 | dstImage[offset[1]:size[1] + offset[1], offset[0]:size[0] + offset[0], :] = originImage[lt[1]:rb[1], lt[0]:rb[0],:] 308 | return dstImage, off_set_pts(kps, leftTop) 309 | 310 | ''' 311 | purpose: 312 | reflect key point, when the image is reflect by left-right 313 | 314 | inputs: 315 | kps:3d key point(14 x 3) 316 | 317 | marks: 318 | the key point is given by lsp order. 319 | ''' 320 | def reflect_lsp_kp(kps): 321 | kp_map = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13] 322 | joint_ref = kps[kp_map] 323 | joint_ref[:,0] = -joint_ref[:,0] 324 | 325 | return joint_ref - np.mean(joint_ref, axis = 0) 326 | 327 | ''' 328 | purpose: 329 | reflect poses, when the image is reflect by left-right 330 | 331 | inputs: 332 | poses: 72 real number 333 | ''' 334 | def reflect_pose(poses): 335 | swap_inds = np.array([ 336 | 0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 14, 18, 337 | 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 34, 35, 30, 31, 32, 338 | 36, 37, 38, 42, 43, 44, 39, 40, 41, 45, 46, 47, 51, 52, 53, 48, 49, 339 | 50, 57, 58, 59, 54, 55, 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 340 | 67, 68 341 | ]) 342 | 343 | sign_flip = np.array([ 344 | 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 345 | -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, 346 | -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 347 | 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 348 | -1, 1, -1, -1 349 | ]) 350 | 351 | return poses[swap_inds] * sign_flip 352 | 353 | ''' 354 | purpose: 355 | crop the image 356 | inputs: 357 | image_path: the 358 | ''' 359 | def crop_image(image_path, angle, lt, rb, scale, kp_2d, crop_size): 360 | ''' 361 | given a crop box, expand it at 4 directions.(left, right, top, bottom) 362 | ''' 363 | assert 'error algorithm exist.' and 0 364 | 365 | def _expand_crop_box(lt, rb, scale): 366 | center = (lt + rb) / 2.0 367 | xl, xr, yt, yb = lt[0] - center[0], rb[0] - center[0], lt[1] - center[1], rb[1] - center[1] 368 | xl, xr, yt, yb = xl * scale[0], xr * scale[1], yt * scale[2], yb * scale[3] 369 | #expand it 370 | lt, rb = np.array([center[0] + xl, center[1] + yt]), np.array([center[0] + xr, center[1] + yb]) 371 | lb, rt = np.array([center[0] + xl, center[1] + yb]), np.array([center[0] + xr, center[1] + yt]) 372 | center = (lt + rb) / 2 373 | return center, lt, rt, rb, lb 374 | 375 | ''' 376 | extend the box to square 377 | ''' 378 | def _extend_box(center, lt, rt, rb, lb, crop_size): 379 | lx, ly = np.linalg.norm(rt - lt), np.linalg.norm(lb - lt) 380 | dx, dy = (rt - lt) / lx, (lb - lt) / ly 381 | l = max(lx, ly) / 2.0 382 | return center - l * dx - l * dy, center + l * dx - l *dy, center + l * dx + l * dy, center - l * dx + l * dy, dx, dy, crop_size * 1.0 / l 383 | 384 | def _get_sample_points(lt, rt, rb, lb, crop_size): 385 | vec_x = rt - lt 386 | vec_y = lb - lt 387 | i_x, i_y = np.meshgrid(range(crop_size), range(crop_size)) 388 | i_x = i_x.astype(np.float) 389 | i_y = i_y.astype(np.float) 390 | i_x /= float(crop_size) 391 | i_y /= float(crop_size) 392 | interp_points = i_x[..., np.newaxis].repeat(2, axis=2) * vec_x + i_y[..., np.newaxis].repeat(2, axis=2) * vec_y 393 | interp_points += lt 394 | return interp_points 395 | 396 | def _sample_image(src_image, interp_points): 397 | sample_method = 'nearest' 398 | interp_image = np.zeros((interp_points.shape[0] * interp_points.shape[1], src_image.shape[2])) 399 | i_x = range(src_image.shape[1]) 400 | i_y = range(src_image.shape[0]) 401 | flatten_interp_points = interp_points.reshape([interp_points.shape[0]*interp_points.shape[1], 2]) 402 | for i_channel in range(src_image.shape[2]): 403 | interp_image[:, i_channel] = interpolate.interpn((i_y, i_x), src_image[:, :, i_channel], 404 | flatten_interp_points[:, [1, 0]], method = sample_method, 405 | bounds_error=False, fill_value=0) 406 | interp_image = interp_image.reshape((interp_points.shape[0], interp_points.shape[1], src_image.shape[2])) 407 | 408 | return interp_image 409 | 410 | def _trans_kp_2d(kps, center, dx, dy, lt, ratio): 411 | kp2d_offset = kps[:, :2] - center 412 | proj_x, proj_y = np.dot(kp2d_offset, dx), np.dot(kp2d_offset, dy) 413 | #kp2d = (dx * proj_x + dy * proj_y + lt) * ratio 414 | for idx in range(len(kps)): 415 | kps[idx, :2] = (dx * proj_x[idx] + dy * proj_y[idx] + lt) * ratio 416 | return kps 417 | 418 | 419 | src_image = cv2.imread(image_path) 420 | 421 | center, lt, rt, rb, lb = _expand_crop_box(lt, rb, scale) 422 | 423 | #calc rotated box 424 | radian = angle * np.pi / 180.0 425 | v_sin, v_cos = math.sin(radian), math.cos(radian) 426 | 427 | rot_matrix = np.array( 428 | [ 429 | [v_cos, v_sin], 430 | [-v_sin, v_cos] 431 | ] 432 | ) 433 | 434 | n_corner = (np.dot(rot_matrix, np.array([lt - center, rt - center, rb - center, lb - center]).T).T) + center 435 | n_lt, n_rt, n_rb, n_lb = n_corner[0], n_corner[1], n_corner[2], n_corner[3] 436 | 437 | lt, rt, rb, lb = calc_obb(np.array([lt, rt, rb, lb, n_lt, n_rt, n_rb, n_lb])) 438 | lt, rt, rb, lb, dx, dy, ratio = _extend_box(center, lt, rt, rb, lb, crop_size = crop_size) 439 | s_pts = _get_sample_points(lt, rt, rb, lb, crop_size) 440 | dst_image = _sample_image(src_image, s_pts) 441 | kp_2d = _trans_kp_2d(kp_2d, center, dx, dy, lt, ratio) 442 | 443 | return dst_image, kp_2d 444 | 445 | 446 | ''' 447 | purpose: 448 | flip a image given by src_image and the 2d keypoints 449 | flip_mode: 450 | 0: horizontal flip 451 | >0: vertical flip 452 | <0: horizontal & vertical flip 453 | ''' 454 | def flip_image(src_image, kps): 455 | h, w = src_image.shape[0], src_image.shape[1] 456 | src_image = cv2.flip(src_image, 1) 457 | if kps is not None: 458 | kps[:, 0] = w - 1 - kps[:, 0] 459 | kp_map = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13] 460 | kps[:, :] = kps[kp_map] 461 | 462 | return src_image, kps 463 | 464 | ''' 465 | src_image: h x w x c 466 | pts: 14 x 3 467 | ''' 468 | def draw_lsp_14kp__bone(src_image, pts): 469 | bones = [ 470 | [0, 1, 255, 0, 0], 471 | [1, 2, 255, 0, 0], 472 | [2, 12, 255, 0, 0], 473 | [3, 12, 0, 0, 255], 474 | [3, 4, 0, 0, 255], 475 | [4, 5, 0, 0, 255], 476 | [12, 9, 0, 0, 255], 477 | [9,10, 0, 0, 255], 478 | [10,11, 0, 0, 255], 479 | [12, 8, 255, 0, 0], 480 | [8,7, 255, 0, 0], 481 | [7,6, 255, 0, 0], 482 | [12, 13, 0, 255, 0] 483 | ] 484 | 485 | for pt in pts: 486 | if pt[2] > 0.2: 487 | cv2.circle(src_image,(int(pt[0]), int(pt[1])),2,(0,255,255),-1) 488 | 489 | for line in bones: 490 | pa = pts[line[0]] 491 | pb = pts[line[1]] 492 | xa,ya,xb,yb = int(pa[0]),int(pa[1]),int(pb[0]),int(pb[1]) 493 | if pa[2] > 0.2 and pb[2] > 0.2: 494 | cv2.line(src_image,(xa,ya),(xb,yb),(line[2], line[3], line[4]),2) 495 | 496 | ''' 497 | return whether two segment intersect 498 | ''' 499 | 500 | def line_intersect(sa, sb): 501 | al, ar, bl, br = sa[0], sa[1], sb[0], sb[1] 502 | assert al <= ar and bl <= br 503 | if al >= br or bl >= ar: 504 | return False 505 | return True 506 | 507 | ''' 508 | return whether two rectangle intersect 509 | ra, rb left_top point, right_bottom point 510 | ''' 511 | def rectangle_intersect(ra, rb): 512 | ax = [ra[0][0], ra[1][0]] 513 | ay = [ra[0][1], ra[1][1]] 514 | 515 | bx = [rb[0][0], rb[1][0]] 516 | by = [rb[0][1], rb[1][1]] 517 | 518 | return line_intersect(ax, bx) and line_intersect(ay, by) 519 | 520 | def get_intersected_rectangle(lt0, rb0, lt1, rb1): 521 | if not rectangle_intersect([lt0, rb0], [lt1, rb1]): 522 | return None, None 523 | 524 | lt = lt0.copy() 525 | rb = rb0.copy() 526 | 527 | lt[0] = max(lt[0], lt1[0]) 528 | lt[1] = max(lt[1], lt1[1]) 529 | 530 | rb[0] = min(rb[0], rb1[0]) 531 | rb[1] = min(rb[1], rb1[1]) 532 | return lt, rb 533 | 534 | def get_union_rectangle(lt0, rb0, lt1, rb1): 535 | lt = lt0.copy() 536 | rb = rb0.copy() 537 | 538 | lt[0] = min(lt[0], lt1[0]) 539 | lt[1] = min(lt[1], lt1[1]) 540 | 541 | rb[0] = max(rb[0], rb1[0]) 542 | rb[1] = max(rb[1], rb1[1]) 543 | return lt, rb 544 | 545 | def get_rectangle_area(lt, rb): 546 | return (rb[0] - lt[0]) * (rb[1] - lt[1]) 547 | 548 | def get_rectangle_intersect_ratio(lt0, rb0, lt1, rb1): 549 | (lt0, rb0), (lt1, rb1) = get_intersected_rectangle(lt0, rb0, lt1, rb1), get_union_rectangle(lt0, rb0, lt1, rb1) 550 | 551 | if lt0 is None: 552 | return 0.0 553 | else: 554 | return 1.0 * get_rectangle_area(lt0, rb0) / get_rectangle_area(lt1, rb1) 555 | 556 | def convert_image_by_pixformat_normalize(src_image, pix_format, normalize): 557 | if pix_format == 'NCHW': 558 | src_image = src_image.transpose((2, 0, 1)) 559 | 560 | if normalize: 561 | src_image = (src_image.astype(np.float) / 255) * 2.0 - 1.0 562 | 563 | return src_image 564 | 565 | ''' 566 | align ty pelvis 567 | joints: n x 14 x 3, by lsp order 568 | ''' 569 | def align_by_pelvis(joints): 570 | left_id = 3 571 | right_id = 2 572 | pelvis = (joints[:, left_id, :] + joints[:, right_id, :]) / 2.0 573 | return joints - torch.unsqueeze(pelvis, dim=1) 574 | 575 | def copy_state_dict(cur_state_dict, pre_state_dict, prefix = ''): 576 | def _get_params(key): 577 | key = prefix + key 578 | if key in pre_state_dict: 579 | return pre_state_dict[key] 580 | return None 581 | 582 | for k in cur_state_dict.keys(): 583 | v = _get_params(k) 584 | try: 585 | if v is None: 586 | print('parameter {} not found'.format(k)) 587 | continue 588 | cur_state_dict[k].copy_(v) 589 | except: 590 | print('copy param {} failed'.format(k)) 591 | continue 592 | 593 | if __name__ == '__main__': 594 | image_path = 'E:/HMR/data/COCO/images/train-valid2017/000000000009.jpg' 595 | lt = np.array([-10, -10], dtype = np.float) 596 | rb = np.array([10,10], dtype = np.float) 597 | print(crop_image(image_path, 45, lt, rb, [1, 1, 1, 1], None)) 598 | --------------------------------------------------------------------------------