├── README.md
├── images
│   ├── main.py
│   ├── mesh1.png
│   ├── mesh2.png
│   ├── mesh3.png
│   ├── paper_picked.png
│   └── r.png
└── src
    ├── Discriminator.py
    ├── HourGlass.py
    ├── LinearModel.py
    ├── PRNetEncoder.py
    ├── Resnet.py
    ├── SMPL.py
    ├── config.py
    ├── dataloader
    │   ├── AICH_dataloader.py
    │   ├── COCO2017_dataloader.py
    │   ├── eval_dataloader.py
    │   ├── hum36m_dataloader.py
    │   ├── lsp_dataloader.py
    │   ├── lsp_ext_dataloader.py
    │   ├── mosh_dataloader.py
    │   └── mpi_inf_3dhp_dataloader.py
    ├── densenet.py
    ├── do_train.sh
    ├── model.py
    ├── timer.py
    ├── trainer.py
    └── util.py
/README.md:
--------------------------------------------------------------------------------
1 | ## HMR
2 |
3 |
4 |
5 |
6 | This is a **PyTorch** implementation of [End-to-end Recovery of Human Shape and Pose](https://arxiv.org/abs/1712.06584) by *Angjoo Kanazawa, Michael J. Black, David W. Jacobs*, and *Jitendra Malik*, accompanied by several well-known human pose estimation networks and datasets.
7 | HMR is an end-to-end framework for reconstructing a full 3D mesh of a human body from a single RGB image. In contrast to most current methods that compute 2D or 3D joint locations, HMR produces a richer and more useful mesh representation that is parameterized by shape and 3D joint angles. The main objective is to minimize the reprojection loss of keypoints, which allows the model to be trained on in-the-wild images that only have ground-truth 2D annotations. For a visual impression, please watch the [author's original video.](https://www.youtube.com/watch?v=bmMV9aJKa-c)
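As a rough illustration (a minimal sketch, not necessarily the exact loss implemented in this repository), the 2D reprojection term can be written as a visibility-weighted L1 distance between the annotated keypoints and the keypoints projected from the predicted mesh:

```python
import torch

def kp_2d_reprojection_loss(gt_kp_2d, pred_kp_2d):
    '''
    gt_kp_2d:   N x K x 3 ground-truth keypoints (x, y, visibility), coordinates normalized to [-1, 1]
    pred_kp_2d: N x K x 2 keypoints projected from the predicted SMPL joints
    '''
    vis      = gt_kp_2d[:, :, 2:]                             # visibility acts as a per-keypoint weight
    abs_diff = torch.abs(gt_kp_2d[:, :, :2] - pred_kp_2d)     # per-coordinate L1 distance
    return (vis * abs_diff).sum() / (2.0 * vis.sum() + 1e-8)  # average over visible coordinates
```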
8 |
9 | ## training steps (the following links are currently unavailable due to license limitations)
10 | #### 1. download the following datasets.
11 | - [AI challenger keypoint dataset](https://challenger.ai/datasets/keypoint)
12 | - [lsp 14-keypoint dataset](https://pan.baidu.com/s/1BgKRJfggJcObHXkzHH5I5A)
13 | - [lsp 14-keypoint extension dataset](https://pan.baidu.com/s/1uUcsdCKbzIwKCc9SzVFXAA)
14 | - [COCO-2017-keypoint dataset](http://cocodataset.org/)
15 | - [mpi_inf_3dhp 3d keypoint dataset](https://pan.baidu.com/s/1XQZNV3KPtiBi5ODnr7RB9A)
16 | - [mosh dataset, which is used for adversarial training](https://pan.baidu.com/s/1OWzeMeLS5tKx1XGAiyZ0XA)
17 | #### 2. download the human3.6m dataset.
18 | - [hum3.6m_part_1.zip](https://pan.baidu.com/s/1oeO213vrKyYEr46P1OBEgw)
19 | - [hum3.6m_part_2.zip](https://pan.baidu.com/s/1XRnNn0qJeo5TECacjiJv4g)
20 | - [hum3.6m_part_3.zip](https://pan.baidu.com/s/15AOngXr3zya2XsK7Sry97g)
21 | - [hum3.6m_part_4.zip](https://pan.baidu.com/s/1RNqWSP1KREBhvPHn6-pCbA)
22 | - [hum3.6m_part_5.zip](https://pan.baidu.com/s/109RwxgpWxEraXzIXf7iYkg)
23 | - [hum3.6m_anno.zip](https://pan.baidu.com/s/1kCOQ2qzf69RLX3VN4cw5Mw)
24 | #### 3. unzip the downloaded datasets.
25 | #### 4. unzip the [model.zip](https://pan.baidu.com/s/1PUv5kUydmx5RG1E0KsQBkw)
26 | #### 5. configure the environment by modifying src/config.py and do_train.sh (see the example below)
27 | #### 6. run ./do_train.sh directly
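For step 5, the entries that usually need editing in src/config.py are the model and dataset paths (the defaults in this repository point to `E:/HMR/...`). The snippet below is only an illustration with placeholder paths; point them at wherever you unzipped the datasets and model.zip:

```python
# in src/config.py -- replace the default Windows-style paths with your own locations;
# the defaults of --smpl-model, --smpl-mean-theta-path and --save-folder need the same treatment
data_set_path = {
    'coco'         : '/data/HMR/data/COCO/',
    'lsp'          : '/data/HMR/data/lsp',
    'lsp_ext'      : '/data/HMR/data/lsp_ext',
    'ai-ch'        : '/data/HMR/data/ai_challenger_keypoint_train_20170902',
    'mpi-inf-3dhp' : '/data/HMR/data/mpi_inf_3dhp',
    'hum3.6m'      : '/data/HMR/data/human3.6m',
    'mosh'         : '/data/HMR/data/mosh_gen',
    'up3d'         : '/data/HMR/data/up3d_mpii'
}
```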
28 |
29 | ## environment configurations.
30 | - install **pytorch0.4**
31 | - install torchvision
32 | - install numpy
33 | - install scipy
34 | - install h5py
35 | - install opencv-python
36 |
37 | ## result
38 |
39 |
40 |
41 |
42 |
43 | ## reference papers
44 | - [Stacked Hourglass Networks for Human Pose Estimation](https://arxiv.org/abs/1603.06937)
45 | - [SMPL: A Skinned Multi-Person Linear Model](http://files.is.tue.mpg.de/black/papers/SMPL2015.pdf)
46 | - [Keep it SMPL: Automatic Estimation of 3D Human Pose and Shape from a Single Image](https://pdfs.semanticscholar.org/4cea/52b44fc5cb1803a07fa466b6870c25535313.pdf)
47 | - [MoSh: Motion and Shape Capture from Sparse Markers](http://files.is.tue.mpg.de/black/papers/MoSh.pdf)
48 | - [Unite the People: Closing the Loop Between 3D and 2D Human Representations](https://arxiv.org/abs/1701.02468)
49 | - [End-to-end Recovery of Human Shape and Pose](https://arxiv.org/abs/1712.06584)
50 |
51 | ## reference resources
52 | - [up-3d dataset](http://files.is.tuebingen.mpg.de/classner/up/)
53 | - [coco-2017 dataset](http://cocodataset.org/)
54 | - [human 3.6m dataset](http://vision.imar.ro/human3.6m/description.php)
55 | - [ai challenger dataset](https://challenger.ai/)
56 |
57 |
--------------------------------------------------------------------------------
/images/main.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
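# stitch mesh1.png, mesh2.png and mesh3.png side by side into a single preview image r.png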
5 | sa = cv2.imread('./mesh1.png')
6 | sb = cv2.imread('./mesh2.png')
7 | sc = cv2.imread('./mesh3.png')
8 |
9 | da = cv2.resize(sa, (300, 400), interpolation = cv2.INTER_CUBIC)
10 | db = cv2.resize(sb, (300, 400), interpolation = cv2.INTER_CUBIC)
11 | dc = cv2.resize(sc, (300, 400), interpolation = cv2.INTER_CUBIC)
12 |
13 | d = np.zeros((400, 900, 3))
14 | d[:, :300, :] = da[:, :, :]
15 | d[:, 300:600,:] = db[:, :, :]
16 | d[:, 600:, :] = dc[:, :, :]
17 |
18 | cv2.imwrite('r.png', d)
--------------------------------------------------------------------------------
/images/mesh1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh1.png
--------------------------------------------------------------------------------
/images/mesh2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh2.png
--------------------------------------------------------------------------------
/images/mesh3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/mesh3.png
--------------------------------------------------------------------------------
/images/paper_picked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/paper_picked.png
--------------------------------------------------------------------------------
/images/r.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MandyMo/pytorch_HMR/7bf18d619aeafd97e9df7364e354cd7e9480966f/images/r.png
--------------------------------------------------------------------------------
/src/Discriminator.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: Discriminator.py
4 |
5 | date: 2017_04_29
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 | import sys
9 | from LinearModel import LinearModel
10 | import config
11 | import util
12 | import torch
13 | import numpy as np
14 | import torch.nn as nn
15 | from config import args
16 |
17 | '''
18 |     the shape discriminator operates on the SMPL shape parameters (betas)
19 |     the input is N x 10
20 | '''
21 | class ShapeDiscriminator(LinearModel):
22 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
23 | if fc_layers[-1] != 1:
24 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1])
25 | sys.exit(msg)
26 |
27 | super(ShapeDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func)
28 |
29 | def forward(self, inputs):
30 | return self.fc_blocks(inputs)
31 |
32 | class PoseDiscriminator(nn.Module):
33 | def __init__(self, channels):
34 | super(PoseDiscriminator, self).__init__()
35 |
36 | if channels[-1] != 1:
37 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(channels[-1])
38 | sys.exit(msg)
39 |
40 | self.conv_blocks = nn.Sequential()
41 | l = len(channels)
42 | for idx in range(l - 2):
43 | self.conv_blocks.add_module(
44 | name = 'conv_{}'.format(idx),
45 | module = nn.Conv2d(in_channels = channels[idx], out_channels = channels[idx + 1], kernel_size = 1, stride = 1)
46 | )
47 |
48 | self.fc_layer = nn.ModuleList()
49 | for idx in range(23):
50 | self.fc_layer.append(nn.Linear(in_features = channels[l - 2], out_features = 1))
51 |
52 | # N x 23 x 9
53 | def forward(self, inputs):
54 | batch_size = inputs.shape[0]
55 | inputs = inputs.transpose(1, 2).unsqueeze(2) # to N x 9 x 1 x 23
56 | internal_outputs = self.conv_blocks(inputs) # to N x c x 1 x 23
57 | o = []
58 | for idx in range(23):
59 | o.append(self.fc_layer[idx](internal_outputs[:,:,0,idx]))
60 |
61 | return torch.cat(o, 1), internal_outputs
62 |
63 | class FullPoseDiscriminator(LinearModel):
64 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
65 | if fc_layers[-1] != 1:
66 | msg = 'the neuron count of the last layer must be 1, but got {}'.format(fc_layers[-1])
67 | sys.exit(msg)
68 |
69 | super(FullPoseDiscriminator, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func)
70 |
71 | def forward(self, inputs):
72 | return self.fc_blocks(inputs)
73 |
74 | class Discriminator(nn.Module):
75 | def __init__(self):
76 | super(Discriminator, self).__init__()
77 | self._read_configs()
78 |
79 | self._create_sub_modules()
80 |
81 | def _read_configs(self):
82 | self.beta_count = args.beta_count
83 | self.smpl_model = args.smpl_model
84 | self.smpl_mean_theta_path = args.smpl_mean_theta_path
85 | self.total_theta_count = args.total_theta_count
86 | self.joint_count = args.joint_count
87 | self.feature_count = args.feature_count
88 |
89 | def _create_sub_modules(self):
90 | '''
91 | create theta discriminator for 23 joint
92 | '''
93 |
94 | self.pose_discriminator = PoseDiscriminator([9, 32, 32, 1])
95 |
96 | '''
97 | create full pose discriminator for total 23 joints
98 | '''
99 | fc_layers = [23 * 32, 1024, 1024, 1]
100 | use_dropout = [False, False, False]
101 | drop_prob = [0.5, 0.5, 0.5]
102 | use_ac_func = [True, True, False]
103 | self.full_pose_discriminator = FullPoseDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func)
104 |
105 | '''
106 | shape discriminator for betas
107 | '''
108 | fc_layers = [self.beta_count, 5, 1]
109 | use_dropout = [False, False]
110 | drop_prob = [0.5, 0.5]
111 | use_ac_func = [True, False]
112 | self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout, drop_prob, use_ac_func)
113 |
114 | print('finished create the discriminator modules...')
115 |
116 |
117 | '''
118 |     thetas is N x 85: 3 camera parameters, 72 pose parameters (24 joints x 3 axis-angle) and 10 shape betas
119 | '''
120 | def forward(self, thetas):
121 | batch_size = thetas.shape[0]
122 | cams, poses, shapes = thetas[:, :3], thetas[:, 3:75], thetas[:, 75:]
123 | shape_disc_value = self.shape_discriminator(shapes)
124 | rotate_matrixs = util.batch_rodrigues(poses.contiguous().view(-1, 3)).view(-1, 24, 9)[:, 1:, :]
125 | pose_disc_value, pose_inter_disc_value = self.pose_discriminator(rotate_matrixs)
126 | full_pose_disc_value = self.full_pose_discriminator(pose_inter_disc_value.contiguous().view(batch_size, -1))
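# the concatenated output is N x 25: 23 per-joint scores, 1 full-pose score and 1 shape score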
127 | return torch.cat((pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
128 |
129 | if __name__ == '__main__':
130 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
131 | net = Discriminator().to(device)
132 | inputs = torch.ones((100, 85), device = device)
133 | disc_value = net(inputs)
134 | print(net)
--------------------------------------------------------------------------------
/src/HourGlass.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: HourGlass.py
4 |
5 | date: 2018_05_12
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 |
9 | from __future__ import print_function
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | class Residual(nn.Module):
16 | def __init__(self, use_bn, input_channels, out_channels, mid_channels):
17 | super(Residual, self).__init__()
18 | self.use_bn = use_bn
19 | self.out_channels = out_channels
20 | self.input_channels = input_channels
21 | self.mid_channels = mid_channels
22 |
23 | self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1)
24 | self.AcFunc = nn.ReLU()
25 | if use_bn:
26 | self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels)
27 | self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels)
28 | self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels)
29 |
30 | self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = 3, padding = 1)
31 |
32 | self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size= 1)
33 |
34 | if input_channels != out_channels:
35 | self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1)
36 |
37 | def forward(self, inputs):
38 | x = self.down_channel(inputs)
39 | if self.use_bn:
40 | x = self.bn_0(x)
41 | x = self.AcFunc(x)
42 |
43 | x = self.conv(x)
44 | if self.use_bn:
45 | x = self.bn_1(x)
46 | x = self.AcFunc(x)
47 |
48 | x = self.up_channel(x)
49 |
50 | if self.input_channels != self.out_channels:
51 | x += self.trans(inputs)
52 | else:
53 | x += inputs
54 |
55 | if self.use_bn:
56 | x = self.bn_2(x)
57 |
58 | return self.AcFunc(x)
59 |
60 | class HourGlassBlock(nn.Module):
61 | def __init__(self, block_count, residual_each_block, input_channels, mid_channels, use_bn, stack_index):
62 | super(HourGlassBlock, self).__init__()
63 |
64 | self.block_count = block_count
65 | self.residual_each_block = residual_each_block
66 | self.use_bn = use_bn
67 | self.stack_index = stack_index
68 | self.input_channels = input_channels
69 | self.mid_channels = mid_channels
70 |
71 | if self.block_count == 0: #inner block
72 | self.process = nn.Sequential()
73 | for _ in range(residual_each_block * 3):
74 | self.process.add_module(
75 | name = 'inner_{}_{}'.format(self.stack_index, _),
76 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn)
77 | )
78 | else:
79 | #down sampling
80 | self.down_sampling = nn.Sequential()
81 | self.down_sampling.add_module(
82 | name = 'down_sample_{}_{}'.format(self.stack_index, self.block_count),
83 | module = nn.MaxPool2d(kernel_size = 2, stride = 2)
84 | )
85 | for _ in range(residual_each_block):
86 | self.down_sampling.add_module(
87 | name = 'residual_{}_{}_{}'.format(self.stack_index, self.block_count, _),
88 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn)
89 | )
90 |
91 | #up sampling
92 | self.up_sampling = nn.Sequential()
93 | self.up_sampling.add_module(
94 | name = 'up_sample_{}_{}'.format(self.stack_index, self.block_count),
95 | module = nn.Upsample(scale_factor=2, mode='bilinear')
96 | )
97 | for _ in range(residual_each_block):
98 | self.up_sampling.add_module(
99 | name = 'residual_{}_{}_{}'.format(self.stack_index, self.block_count, _),
100 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn)
101 | )
102 |
103 | #sub hour glass
104 | self.sub_hg = HourGlassBlock(
105 | block_count = self.block_count - 1,
106 | residual_each_block = self.residual_each_block,
107 | input_channels = self.input_channels,
108 | mid_channels = self.mid_channels,
109 | use_bn = self.use_bn,
110 | stack_index = self.stack_index
111 | )
112 |
113 | # trans
114 | self.trans = nn.Sequential()
115 | for _ in range(residual_each_block):
116 | self.trans.add_module(
117 | name = 'trans_{}_{}_{}'.format(self.stack_index, self.block_count, _),
118 | module = Residual(input_channels = input_channels, out_channels = input_channels, mid_channels = mid_channels, use_bn = use_bn)
119 | )
120 |
121 |
122 | def forward(self, inputs):
123 | if self.block_count == 0:
124 | return self.process(inputs)
125 | else:
126 | down_sampled = self.down_sampling(inputs)
127 | transed = self.trans(down_sampled)
128 | sub_net_output = self.sub_hg(down_sampled)
129 | return self.up_sampling(transed + sub_net_output)
130 |
131 | '''
132 | the input is a 256 x 256 x 3 image
133 | '''
134 | class HourGlass(nn.Module):
135 | def __init__(self, nStack, nBlockCount, nResidualEachBlock, nMidChannels, nChannels, nJointCount, bUseBn):
136 | super(HourGlass, self).__init__()
137 |
138 | self.nStack = nStack
139 | self.nBlockCount = nBlockCount
140 | self.nResidualEachBlock = nResidualEachBlock
141 | self.nChannels = nChannels
142 | self.nMidChannels = nMidChannels
143 | self.nJointCount = nJointCount
144 | self.bUseBn = bUseBn
145 |
146 | self.pre_process = nn.Sequential(
147 | nn.Conv2d(3, nChannels, kernel_size = 3, padding = 1),
148 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels),
149 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 128 x 128 x c
150 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels),
151 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 64 x 64 x c
152 | Residual(use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels)
153 | )
154 |
155 | self.hg = nn.ModuleList()
156 | for _ in range(nStack):
157 | self.hg.append(
158 | HourGlassBlock(
159 | block_count = nBlockCount,
160 | residual_each_block = nResidualEachBlock,
161 | input_channels = nChannels,
162 | mid_channels = nMidChannels,
163 | use_bn = bUseBn,
164 | stack_index = _
165 | )
166 | )
167 |
168 | self.blocks = nn.ModuleList()
169 | for _ in range(nStack - 1):
170 | self.blocks.append(
171 | nn.Sequential(
172 | Residual(
173 | use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels
174 | ),
175 | Residual(
176 | use_bn = bUseBn, input_channels = nChannels, out_channels = nChannels, mid_channels = nMidChannels
177 | )
178 | )
179 | )
180 |
181 | self.intermediate_supervision = nn.ModuleList()
182 | for _ in range(nStack): # to 64 x 64 x joint_count
183 | self.intermediate_supervision.append(
184 | nn.Conv2d(nChannels, nJointCount, kernel_size = 1, stride = 1)
185 | )
186 |
187 | self.normal_feature_channel = nn.ModuleList()
188 | for _ in range(nStack - 1):
189 | self.normal_feature_channel.append(
190 | Residual(
191 | use_bn = bUseBn, input_channels = nJointCount, out_channels = nChannels, mid_channels = nMidChannels
192 | )
193 | )
194 |
195 | def forward(self, inputs):
196 | o = [] #outputs include intermediate supervision result
197 | x = self.pre_process(inputs)
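# each stack produces a 64 x 64 x nJointCount heatmap that is flattened and collected for
# intermediate supervision; for all but the last stack the heatmap is mapped back to nChannels
# features and merged with the residual branch and the stack input before the next stack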
198 | for _ in range(self.nStack):
199 | o1 = self.hg[_](x)
200 | o2 = self.intermediate_supervision[_](o1)
201 | o.append(o2.view(-1, 4096))
202 | if _ == self.nStack - 1:
203 | break
204 | o2 = self.normal_feature_channel[_](o2)
205 | o1 = self.blocks[_](o1)
206 | x = o1 + o2 + x
207 | return o
208 |
209 | def _create_hourglass_net():
210 | return HourGlass(
211 | nStack = 2,
212 | nBlockCount = 4,
213 | nResidualEachBlock = 1,
214 | nMidChannels = 128,
215 | nChannels = 256,
216 | nJointCount = 1,
217 | bUseBn = True,
218 | )
--------------------------------------------------------------------------------
/src/LinearModel.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | '''
4 | file: LinearModel.py
5 |
6 | date: 2018_04_29
7 | author: zhangxiong(1025679612@qq.com)
8 | '''
9 |
10 | import torch.nn as nn
11 | import numpy as np
12 | import sys
13 | import torch
14 |
15 | class LinearModel(nn.Module):
16 | '''
17 | input param:
18 | fc_layers: a list of neuron count, such as [2133, 1024, 1024, 85]
19 | use_dropout: a list of bool define use dropout or not for each layer, such as [True, True, False]
20 | drop_prob: a list of float defined the drop prob, such as [0.5, 0.5, 0]
21 | use_ac_func: a list of bool define use active function or not, such as [True, True, False]
22 | '''
23 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func):
24 | super(LinearModel, self).__init__()
25 | self.fc_layers = fc_layers
26 | self.use_dropout = use_dropout
27 | self.drop_prob = drop_prob
28 | self.use_ac_func = use_ac_func
29 |
30 | if not self._check():
31 | msg = 'wrong LinearModel parameters!'
32 | print(msg)
33 | sys.exit(msg)
34 |
35 | self.create_layers()
36 |
37 | def _check(self):
38 | while True:
39 | if not isinstance(self.fc_layers, list):
40 | print('fc_layers require list, get {}'.format(type(self.fc_layers)))
41 | break
42 |
43 | if not isinstance(self.use_dropout, list):
44 | print('use_dropout require list, get {}'.format(type(self.use_dropout)))
45 | break
46 |
47 | if not isinstance(self.drop_prob, list):
48 | print('drop_prob require list, get {}'.format(type(self.drop_prob)))
49 | break
50 |
51 | if not isinstance(self.use_ac_func, list):
52 | print('use_ac_func require list, get {}'.format(type(self.use_ac_func)))
53 | break
54 |
55 | l_fc_layer = len(self.fc_layers)
56 | l_use_drop = len(self.use_dropout)
57 | l_drop_porb = len(self.drop_prob)
58 | l_use_ac_func = len(self.use_ac_func)
59 |
60 | return l_fc_layer >= 2 and l_use_drop < l_fc_layer and l_drop_porb < l_fc_layer and l_use_ac_func < l_fc_layer and l_drop_porb == l_use_drop
61 |
62 | return False
63 |
64 | def create_layers(self):
65 | l_fc_layer = len(self.fc_layers)
66 | l_use_drop = len(self.use_dropout)
67 | l_drop_porb = len(self.drop_prob)
68 | l_use_ac_func = len(self.use_ac_func)
69 |
70 | self.fc_blocks = nn.Sequential()
71 |
72 | for _ in range(l_fc_layer - 1):
73 | self.fc_blocks.add_module(
74 | name = 'regressor_fc_{}'.format(_),
75 | module = nn.Linear(in_features = self.fc_layers[_], out_features = self.fc_layers[_ + 1])
76 | )
77 |
78 | if _ < l_use_ac_func and self.use_ac_func[_]:
79 | self.fc_blocks.add_module(
80 | name = 'regressor_af_{}'.format(_),
81 | module = nn.ReLU()
82 | )
83 |
84 | if _ < l_use_drop and self.use_dropout[_]:
85 | self.fc_blocks.add_module(
86 | name = 'regressor_fc_dropout_{}'.format(_),
87 | module = nn.Dropout(p = self.drop_prob[_])
88 | )
89 |
90 | def forward(self, inputs):
91 | msg = 'the base class [LinearModel] is not callable!'
92 | sys.exit(msg)
93 |
94 | if __name__ == '__main__':
95 | fc_layers = [2133, 1024, 1024, 85]
96 | iterations = 3
97 | use_dropout = [True, True, False]
98 | drop_prob = [0.5, 0.5, 0]
99 | use_ac_func = [True, True, False]
100 | device = torch.device('cuda')
101 | net = LinearModel(fc_layers, use_dropout, drop_prob, use_ac_func).to(device)
102 | print(net)
103 | nx = np.zeros([2, 2133])  # matches fc_layers[0]; note the base class forward() is intentionally not callable
104 | vx = torch.from_numpy(nx).float().to(device)
105 |
--------------------------------------------------------------------------------
/src/PRNetEncoder.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | '''
4 | file: PRNetEncoder.py
5 |
6 | date: 2018_05_22
7 | author: zhangxiong(1025679612@qq.com)
8 | mark: the algorithm is cited from PRNet code
9 | '''
10 |
11 | from __future__ import print_function
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | class Residual(nn.Module):
18 | def __init__(self, use_bn, input_channels, out_channels, mid_channels, kernel_size = 3, padding = 1, stride = 1):
19 | super(Residual, self).__init__()
20 | self.use_bn = use_bn
21 | self.out_channels = out_channels
22 | self.input_channels = input_channels
23 | self.mid_channels = mid_channels
24 |
25 | self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size = 1)
26 | self.AcFunc = nn.ReLU()
27 | if use_bn:
28 | self.bn_0 = nn.BatchNorm2d(num_features = self.mid_channels)
29 | self.bn_1 = nn.BatchNorm2d(num_features = self.mid_channels)
30 | self.bn_2 = nn.BatchNorm2d(num_features = self.out_channels)
31 |
32 | self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size = kernel_size, padding = padding, stride = stride)
33 |
34 | self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size= 1)
35 |
36 | if input_channels != out_channels:
37 | self.trans = nn.Conv2d(input_channels, out_channels, kernel_size = 1)
38 |
39 | def forward(self, inputs):
40 | x = self.down_channel(inputs)
41 | if self.use_bn:
42 | x = self.bn_0(x)
43 | x = self.AcFunc(x)
44 |
45 | x = self.conv(x)
46 | if self.use_bn:
47 | x = self.bn_1(x)
48 | x = self.AcFunc(x)
49 |
50 | x = self.up_channel(x)
51 |
52 | if self.input_channels != self.out_channels:
53 | x += self.trans(inputs)
54 | else:
55 | x += inputs
56 |
57 | if self.use_bn:
58 | x = self.bn_2(x)
59 |
60 | return self.AcFunc(x)
61 |
62 | class PRNetEncoder(nn.Module):
63 | def __init__(self):
64 | super(PRNetEncoder, self).__init__()
65 | self.conv_blocks = nn.Sequential(
66 | nn.Conv2d(in_channels = 3, out_channels = 8, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 8
67 | nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = 3, stride = 1, padding = 1), # to 256 x 256 x 16
68 | Residual(use_bn = True, input_channels = 16, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 256 x 256 x 32
69 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 128 x 128 x 32
70 | Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32
71 | Residual(use_bn = True, input_channels = 32, out_channels = 32, mid_channels = 16, stride = 1, padding = 1), # to 128 x 128 x 32
72 | Residual(use_bn = True, input_channels = 32, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 128 x 128 x 64
73 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 64 x 64 x 64
74 | Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64
75 | Residual(use_bn = True, input_channels = 64, out_channels = 64, mid_channels = 32, stride = 1, padding = 1), # to 64 x 64 x 64
76 | Residual(use_bn = True, input_channels = 64, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 64 x 64 x 128
77 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 32 x 32 x 128
78 | Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128
79 | Residual(use_bn = True, input_channels = 128, out_channels = 128, mid_channels = 64, stride = 1, padding = 1), # to 32 x 32 x 128
80 | Residual(use_bn = True, input_channels = 128, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 32 x 32 x 256
81 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 16 x 16 x 256
82 | Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256
83 | Residual(use_bn = True, input_channels = 256, out_channels = 256, mid_channels = 128, stride = 1, padding = 1), # to 16 x 16 x 256
84 | Residual(use_bn = True, input_channels = 256, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 16 x 16 x 512
85 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 8 x 8 x 512
86 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 8 x 8 x 512
87 | nn.MaxPool2d(kernel_size = 2, stride = 2) , # to 4 x 4 x 512
88 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1), # to 4 x 4 x 512
89 | nn.MaxPool2d(kernel_size = 2, stride = 2), # to 2 x 2 x 512
90 | Residual(use_bn = True, input_channels = 512, out_channels = 512, mid_channels = 256, stride = 1, padding = 1) # to 2 x 2 x 512
91 | )
92 |
93 | def forward(self, inputs):
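# flatten the final 2 x 2 x 512 feature map into a 2048-D encoding vector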
94 | return self.conv_blocks(inputs).view(-1, 2048)
95 |
96 |
97 | if __name__ == '__main__':
98 | net = PRNetEncoder()
99 | inputs = torch.ones(size = (10, 3, 256, 256)).float()
100 | r = net(inputs)
101 | print(r.shape)
--------------------------------------------------------------------------------
/src/Resnet.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: Resnet.py
4 |
5 | date: 2018_05_02
6 | author: zhangxiong(1025679612@qq.com)
7 | mark: copied from pytorch source code
8 | '''
9 |
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch
13 | from torch.nn.parameter import Parameter
14 | import torch.optim as optim
15 | import numpy as np
16 | import math
17 | import torchvision
18 |
19 | class ResNet(nn.Module):
20 | def __init__(self, block, layers, num_classes=1000):
21 | self.inplanes = 64
22 | super(ResNet, self).__init__()
23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
24 | bias=False)
25 | self.bn1 = nn.BatchNorm2d(64)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
28 | self.layer1 = self._make_layer(block, 64, layers[0])
29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
31 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
32 | self.avgpool = nn.AvgPool2d(7, stride=1)
33 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
34 |
35 | for m in self.modules():
36 | if isinstance(m, nn.Conv2d):
37 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
38 | m.weight.data.normal_(0, math.sqrt(2. / n))
39 | elif isinstance(m, nn.BatchNorm2d):
40 | m.weight.data.fill_(1)
41 | m.bias.data.zero_()
42 |
43 | def _make_layer(self, block, planes, blocks, stride=1):
44 | downsample = None
45 | if stride != 1 or self.inplanes != planes * block.expansion:
46 | downsample = nn.Sequential(
47 | nn.Conv2d(self.inplanes, planes * block.expansion,
48 | kernel_size=1, stride=stride, bias=False),
49 | nn.BatchNorm2d(planes * block.expansion),
50 | )
51 |
52 | layers = []
53 | layers.append(block(self.inplanes, planes, stride, downsample))
54 | self.inplanes = planes * block.expansion
55 | for i in range(1, blocks):
56 | layers.append(block(self.inplanes, planes))
57 |
58 | return nn.Sequential(*layers)
59 |
60 | def forward(self, x):
61 | x = self.conv1(x)
62 | x = self.bn1(x)
63 | x = self.relu(x)
64 | x = self.maxpool(x)
65 |
66 | x = self.layer1(x)
67 | x = self.layer2(x)
68 | x = self.layer3(x)
69 | x = self.layer4(x)
70 |
71 | x = self.avgpool(x)
72 | x = x.view(x.size(0), -1)
73 | # x = self.fc(x)
74 |
75 | return x
76 |
77 | class Bottleneck(nn.Module):
78 | expansion = 4
79 |
80 | def __init__(self, inplanes, planes, stride=1, downsample=None):
81 | super(Bottleneck, self).__init__()
82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
83 | self.bn1 = nn.BatchNorm2d(planes)
84 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
85 | padding=1, bias=False)
86 | self.bn2 = nn.BatchNorm2d(planes)
87 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
88 | self.bn3 = nn.BatchNorm2d(planes * 4)
89 | self.relu = nn.ReLU(inplace=True)
90 | self.downsample = downsample
91 | self.stride = stride
92 |
93 | def forward(self, x):
94 | residual = x
95 |
96 | out = self.conv1(x)
97 | out = self.bn1(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv2(out)
101 | out = self.bn2(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv3(out)
105 | out = self.bn3(out)
106 |
107 | if self.downsample is not None:
108 | residual = self.downsample(x)
109 |
110 | out += residual
111 | out = self.relu(out)
112 |
113 | return out
114 |
115 | def conv3x3(in_planes, out_planes, stride=1):
116 | """3x3 convolution with padding"""
117 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
118 | padding=1, bias=False)
119 |
120 | class BasicBlock(nn.Module):
121 | expansion = 1
122 |
123 | def __init__(self, inplanes, planes, stride=1, downsample=None):
124 | super(BasicBlock, self).__init__()
125 | self.conv1 = conv3x3(inplanes, planes, stride)
126 | self.bn1 = nn.BatchNorm2d(planes)
127 | self.relu = nn.ReLU(inplace=True)
128 | self.conv2 = conv3x3(planes, planes)
129 | self.bn2 = nn.BatchNorm2d(planes)
130 | self.downsample = downsample
131 | self.stride = stride
132 |
133 | def forward(self, x):
134 | residual = x
135 |
136 | out = self.conv1(x)
137 | out = self.bn1(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv2(out)
141 | out = self.bn2(out)
142 |
143 | if self.downsample is not None:
144 | residual = self.downsample(x)
145 |
146 | out += residual
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 | def copy_parameter_from_resnet50(model, res50_dict):
152 | cur_state_dict = model.state_dict()
153 | for name, param in res50_dict.items():
154 | if name not in cur_state_dict:
155 | print('unexpected ', name, ' !')
156 | continue
157 | if isinstance(param, Parameter):
158 | param = param.data
159 | try:
160 | cur_state_dict[name].copy_(param)
161 | except:
162 | print(name, ' is inconsistent!')
163 | continue
164 | print('copy state dict finished!')
165 |
166 | def load_Res50Model():
167 | model = ResNet(Bottleneck, [3, 4, 6, 3])
168 | copy_parameter_from_resnet50(model, torchvision.models.resnet50(pretrained = True).state_dict())
169 | return model
170 |
171 | if __name__ == '__main__':
172 | vx = torch.autograd.Variable(torch.from_numpy(np.array([1, 1, 1])))
173 | vy = torch.autograd.Variable(torch.from_numpy(np.array([2, 2, 2])))
174 | vz = torch.cat([vx, vy], 0)
175 | vz[0] = 100
176 | print(vz)
177 | print(vx)
--------------------------------------------------------------------------------
/src/SMPL.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: SMPL.py
4 |
5 | date: 2018_05_03
6 | author: zhangxiong(1025679612@qq.com)
7 | mark: the algorithm is cited from original SMPL
8 | '''
9 | import torch
10 | from config import args
11 | import json
12 | import sys
13 | import numpy as np
14 | from util import batch_global_rigid_transformation, batch_rodrigues, batch_lrotmin, reflect_pose
15 | import torch.nn as nn
16 |
17 | class SMPL(nn.Module):
18 | def __init__(self, model_path, joint_type = 'cocoplus', obj_saveable = False):
19 | super(SMPL, self).__init__()
20 |
21 | if joint_type not in ['cocoplus', 'lsp']:
22 | msg = 'unknown joint type: {}, it must be either "cocoplus" or "lsp"'.format(joint_type)
23 | sys.exit(msg)
24 |
25 | self.model_path = model_path
26 | self.joint_type = joint_type
27 | with open(model_path, 'r') as reader:
28 | model = json.load(reader)
29 |
30 | if obj_saveable:
31 | self.faces = model['f']
32 | else:
33 | self.faces = None
34 |
35 | np_v_template = np.array(model['v_template'], dtype = np.float)
36 | self.register_buffer('v_template', torch.from_numpy(np_v_template).float())
37 | self.size = [np_v_template.shape[0], 3]
38 |
39 | np_shapedirs = np.array(model['shapedirs'], dtype = np.float)
40 | self.num_betas = np_shapedirs.shape[-1]
41 | np_shapedirs = np.reshape(np_shapedirs, [-1, self.num_betas]).T
42 | self.register_buffer('shapedirs', torch.from_numpy(np_shapedirs).float())
43 |
44 | np_J_regressor = np.array(model['J_regressor'], dtype = np.float)
45 | self.register_buffer('J_regressor', torch.from_numpy(np_J_regressor).float())
46 |
47 | np_posedirs = np.array(model['posedirs'], dtype = np.float)
48 | num_pose_basis = np_posedirs.shape[-1]
49 | np_posedirs = np.reshape(np_posedirs, [-1, num_pose_basis]).T
50 | self.register_buffer('posedirs', torch.from_numpy(np_posedirs).float())
51 |
52 | self.parents = np.array(model['kintree_table'])[0].astype(np.int32)
53 |
54 | np_joint_regressor = np.array(model['cocoplus_regressor'], dtype = np.float)
55 | if joint_type == 'lsp':
56 | self.register_buffer('joint_regressor', torch.from_numpy(np_joint_regressor[:, :14]).float())
57 | else:
58 | self.register_buffer('joint_regressor', torch.from_numpy(np_joint_regressor).float())
59 |
60 | np_weights = np.array(model['weights'], dtype = np.float)
61 |
62 | vertex_count = np_weights.shape[0]
63 | vertex_component = np_weights.shape[1]
64 |
65 | batch_size = max(args.batch_size + args.batch_3d_size, args.eval_batch_size)
66 | np_weights = np.tile(np_weights, (batch_size, 1))
67 | self.register_buffer('weight', torch.from_numpy(np_weights).float().reshape(-1, vertex_count, vertex_component))
68 |
69 | self.register_buffer('e3', torch.eye(3).float())
70 |
71 | self.cur_device = None
72 |
73 | def save_obj(self, verts, obj_mesh_name):
74 | if not self.faces:
75 | msg = 'obj not saveable!'
76 | sys.exit(msg)
77 |
78 | with open(obj_mesh_name, 'w') as fp:
79 | for v in verts:
80 | fp.write( 'v %f %f %f\n' % ( v[0], v[1], v[2]) )
81 |
82 | for f in self.faces: # Faces are 1-based, not 0-based in obj files
83 | fp.write( 'f %d %d %d\n' % (f[0] + 1, f[1] + 1, f[2] + 1) )
84 |
85 | def forward(self, beta, theta, get_skin = False):
86 | if not self.cur_device:
87 | device = beta.device
88 | self.cur_device = torch.device(device.type, device.index)
89 |
90 | num_batch = beta.shape[0]
91 |
92 | v_shaped = torch.matmul(beta, self.shapedirs).view(-1, self.size[0], self.size[1]) + self.v_template
93 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor)
94 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor)
95 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor)
96 | J = torch.stack([Jx, Jy, Jz], dim = 2)
97 |
98 | Rs = batch_rodrigues(theta.view(-1, 3)).view(-1, 24, 3, 3)
99 | pose_feature = (Rs[:, 1:, :, :] - self.e3).view(-1, 207) # pose blend shape coefficients: (R - I) for the 23 non-root joints
100 | v_posed = torch.matmul(pose_feature, self.posedirs).view(-1, self.size[0], self.size[1]) + v_shaped
101 | self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents, rotate_base = True)
102 |
103 | weight = self.weight[:num_batch]
104 | W = weight.view(num_batch, -1, 24)
105 | T = torch.matmul(W, A.view(num_batch, 24, 16)).view(num_batch, -1, 4, 4)
106 |
107 | v_posed_homo = torch.cat([v_posed, torch.ones(num_batch, v_posed.shape[1], 1, device = self.cur_device)], dim = 2)
108 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, -1))
109 |
110 | verts = v_homo[:, :, :3, 0]
111 |
112 | joint_x = torch.matmul(verts[:, :, 0], self.joint_regressor)
113 | joint_y = torch.matmul(verts[:, :, 1], self.joint_regressor)
114 | joint_z = torch.matmul(verts[:, :, 2], self.joint_regressor)
115 |
116 | joints = torch.stack([joint_x, joint_y, joint_z], dim = 2)
117 |
118 | if get_skin:
119 | return verts, joints, Rs
120 | else:
121 | return joints
122 |
123 | if __name__ == '__main__':
124 | device = torch.device('cuda', 0)
125 | smpl = SMPL(args.smpl_model, obj_saveable = True).to(device)
126 | pose= np.array([
127 | 1.22162998e+00, 1.17162502e+00, 1.16706634e+00,
128 | -1.20581151e-03, 8.60930011e-02, 4.45963144e-02,
129 | -1.52801601e-02, -1.16911056e-02, -6.02894090e-03,
130 | 1.62427306e-01, 4.26302850e-02, -1.55304456e-02,
131 | 2.58729942e-02, -2.15941742e-01, -6.59851432e-02,
132 | 7.79098943e-02, 1.96353287e-01, 6.44420758e-02,
133 | -5.43042570e-02, -3.45508829e-02, 1.13200583e-02,
134 | -5.60734887e-04, 3.21716577e-01, -2.18840033e-01,
135 | -7.61821344e-02, -3.64610642e-01, 2.97633410e-01,
136 | 9.65453908e-02, -5.54007106e-03, 2.83410680e-02,
137 | -9.57194716e-02, 9.02515948e-02, 3.31488043e-01,
138 | -1.18847653e-01, 2.96623230e-01, -4.76809204e-01,
139 | -1.53382001e-02, 1.72342166e-01, -1.44332021e-01,
140 | -8.10869411e-02, 4.68325168e-02, 1.42248288e-01,
141 | -4.60898802e-02, -4.05981280e-02, 5.28727695e-02,
142 | 3.20133418e-02, -5.23784310e-02, 2.41559884e-03,
143 | -3.08033824e-01, 2.31431410e-01, 1.62540793e-01,
144 | 6.28208935e-01, -1.94355965e-01, 7.23800480e-01,
145 | -6.49612308e-01, -4.07179184e-02, -1.46422181e-02,
146 | 4.51475441e-01, 1.59122205e+00, 2.70355493e-01,
147 | 2.04248756e-01, -6.33800551e-02, -5.50178960e-02,
148 | -1.00920045e+00, 2.39532292e-01, 3.62904727e-01,
149 | -3.38783532e-01, 9.40650925e-02, -8.44506770e-02,
150 | 3.55101633e-03, -2.68924050e-02, 4.93676625e-02],dtype = np.float)
151 |
152 | beta = np.array([-0.25349993, 0.25009069, 0.21440795, 0.78280628, 0.08625954,
153 | 0.28128183, 0.06626327, -0.26495767, 0.09009246, 0.06537955 ])
154 |
155 | vbeta = torch.tensor(np.array([beta])).float().to(device)
156 | vpose = torch.tensor(np.array([pose])).float().to(device)
157 |
158 | verts, j, r = smpl(vbeta, vpose, get_skin = True)
159 |
160 | smpl.save_obj(verts[0].cpu().numpy(), './mesh.obj')
161 |
162 | rpose = reflect_pose(pose)
163 | vpose = torch.tensor(np.array([rpose])).float().to(device)
164 |
165 | verts, j, r = smpl(vbeta, vpose, get_skin = True)
166 | smpl.save_obj(verts[0].cpu().numpy(), './rmesh.obj')
167 |
168 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: config.py
4 |
5 | date: 2018_04_29
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 |
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser(description = 'hmr model')
12 |
13 | parser.add_argument(
14 | '--fine-tune',
15 | default = True,
16 | type = bool,
17 | help = 'fine tune or not.'
18 | )
19 |
20 | parser.add_argument(
21 | '--encoder-network',
22 | type = str,
23 | default = 'resnet50',
24 | help = 'the encoder network name'
25 | )
26 |
27 | parser.add_argument(
28 | '--smpl-mean-theta-path',
29 | type = str,
30 | default = 'E:/HMR/model/neutral_smpl_mean_params.h5',
31 | help = 'the path for mean smpl theta value'
32 | )
33 |
34 | parser.add_argument(
35 | '--smpl-model',
36 | type = str,
37 | default = 'E:/HMR/model/neutral_smpl_with_cocoplus_reg.txt',
38 | help = 'smpl model path'
39 | )
40 |
41 | parser.add_argument(
42 | '--total-theta-count',
43 | type = int,
44 | default = 85,
45 | help = 'the count of theta param'
46 | )
47 |
48 | parser.add_argument(
49 | '--batch-size',
50 | type = int,
51 | default = 8,
52 | help = 'batch size'
53 | )
54 |
55 | parser.add_argument(
56 | '--batch-3d-size',
57 | type = int,
58 | default = 8,
59 | help = '3d data batch size'
60 | )
61 |
62 | parser.add_argument(
63 | '--adv-batch-size',
64 | type = int,
65 | default = 24,
66 | help = 'default adv batch size'
67 | )
68 |
69 | parser.add_argument(
70 | '--eval-batch-size',
71 | type = int,
72 | default = 400,
73 | help = 'default eval batch size'
74 | )
75 |
76 | parser.add_argument(
77 | '--joint-count',
78 | type = int,
79 | default = 24,
80 | help = 'the count of joints'
81 | )
82 |
83 | parser.add_argument(
84 | '--beta-count',
85 | type = int,
86 | default = 10,
87 | help = 'the count of beta'
88 | )
89 |
90 | parser.add_argument(
91 | '--use-adv-train',
92 | type = bool,
93 | default = True,
94 | help = 'use adv training or not'
95 | )
96 |
97 | parser.add_argument(
98 | '--scale-min',
99 | type = float,
100 | default = 1.1,
101 | help = 'min scale'
102 | )
103 |
104 | parser.add_argument(
105 | '--scale-max',
106 | type = float,
107 | default = 1.5,
108 | help = 'max scale'
109 | )
110 |
111 | parser.add_argument(
112 | '--num-worker',
113 | type = int,
114 | default = 1,
115 | help = 'pytorch number worker.'
116 | )
117 |
118 | parser.add_argument(
119 | '--iter-count',
120 | type = int,
121 | default = 500001,
122 | help = 'iter count, each iteration contains batch-size samples'
123 | )
124 |
125 | parser.add_argument(
126 | '--e-lr',
127 | type = float,
128 | default = 0.00001,
129 | help = 'encoder learning rate.'
130 | )
131 |
132 | parser.add_argument(
133 | '--d-lr',
134 | type = float,
135 | default = 0.0001,
136 | help = 'Adversarial prior learning rate.'
137 | )
138 |
139 | parser.add_argument(
140 | '--e-wd',
141 | type = float,
142 | default = 0.0001,
143 | help = 'encoder weight decay rate.'
144 | )
145 |
146 | parser.add_argument(
147 | '--d-wd',
148 | type = float,
149 | default = 0.0001,
150 | help = 'Adversarial prior weight decay'
151 | )
152 |
153 | parser.add_argument(
154 | '--e-loss-weight',
155 | type = float,
156 | default = 60,
157 | help = 'weight on encoder 2d kp losses.'
158 | )
159 |
160 | parser.add_argument(
161 | '--d-loss-weight',
162 | type = float,
163 | default = 1,
164 | help = 'weight on discriminator losses'
165 | )
166 |
167 |
168 | parser.add_argument(
169 | '--d-disc-ratio',
170 | type = float,
171 | default = 1.0,
172 | help = 'multiple weight of discriminator loss'
173 | )
174 |
175 | parser.add_argument(
176 | '--e-3d-loss-weight',
177 | type = float,
178 | default = 60,
179 | help = 'weight on encoder thetas losses.'
180 | )
181 |
182 | parser.add_argument(
183 | '--e-shape-ratio',
184 | type = float,
185 | default = 5,
186 | help = 'multiple weight of shape loss'
187 | )
188 |
189 | parser.add_argument(
190 | '--e-3d-kp-ratio',
191 | type = float,
192 | default = 10.0,
193 | help = 'multiple weight of 3d key point.'
194 | )
195 |
196 | parser.add_argument(
197 | '--e-pose-ratio',
198 | type = float,
199 | default = 20,
200 | help = 'multiple weight of pose'
201 | )
202 |
203 | parser.add_argument(
204 | '--save-folder',
205 | type = str,
206 | default = 'E:/HMR/data_advanced/trained_model',
207 | help = 'save model path'
208 | )
209 |
210 | parser.add_argument(
211 | '--enable-inter-supervision',
212 | type = bool,
213 | default = False,
214 | help = 'enable inter supervision or not.'
215 | )
216 |
217 | train_2d_set = ['coco', 'lsp', 'lsp_ext', 'ai-ch']
218 | train_3d_set = ['mpi-inf-3dhp', 'hum3.6m']
219 | train_adv_set = ['mosh']
220 | eval_set = ['up3d']
221 |
222 | allowed_encoder_net = ['hourglass', 'resnet50', 'densenet169']
223 |
224 | encoder_feature_count = {
225 | 'hourglass' : 4096,
226 | 'resnet50' : 2048,
227 | 'densenet169' : 1664
228 | }
229 |
230 | crop_size = {
231 | 'hourglass':256,
232 | 'resnet50':224,
233 | 'densenet169':224
234 | }
235 |
236 | data_set_path = {
237 | 'coco':'E:/HMR/data/COCO/',
238 | 'lsp':'E:/HMR/data/lsp',
239 | 'lsp_ext':'E:/HMR/data/lsp_ext',
240 | 'ai-ch':'E:/HMR/data/ai_challenger_keypoint_train_20170902',
241 | 'mpi-inf-3dhp':'E:/HMR/data/mpi_inf_3dhp',
242 | 'hum3.6m':'E:/HMR/data/human3.6m',
243 | 'mosh':'E:/HMR/data/mosh_gen',
244 | 'up3d':'E:/HMR/data/up3d_mpii'
245 | }
246 |
247 | pre_trained_model = {
248 | 'generator' : '/media/disk1/zhangxiong/HMR/hmr_resnet50/fine_tuned/3500_generator.pkl',
249 | 'discriminator' : '/media/disk1/zhangxiong/HMR/hmr_resnet50/fine_tuned/3500_discriminator.pkl'
250 | }
251 |
252 | args = parser.parse_args()
253 | encoder_network = args.encoder_network
254 | args.feature_count = encoder_feature_count[encoder_network]
255 | args.crop_size = crop_size[encoder_network]
256 |
--------------------------------------------------------------------------------
/src/dataloader/AICH_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: AICH_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_08
7 | purpose: load ai challenge dataset
8 | '''
9 |
10 | import sys
11 | from torch.utils.data import Dataset, DataLoader
12 | import scipy.io as scio
13 | import os
14 | import glob
15 | import numpy as np
16 | import random
17 | import cv2
18 | import json
19 | import torch
20 |
21 | sys.path.append('./src')
22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize
23 | from config import args
24 | from timer import Clock
25 |
26 | class AICH_dataloader(Dataset):
27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, only_single_person, min_pts_required, max_intersec_ratio = 0.1, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
28 | self.data_folder = data_set_path
29 | self.use_crop = use_crop
30 | self.scale_range = scale_range
31 | self.use_flip = use_flip
32 | self.flip_prob = flip_prob
33 | self.only_single_person = only_single_person
34 | self.min_pts_required = min_pts_required
35 | self.max_intersec_ratio = max_intersec_ratio
36 | self.img_ext = '.jpg'
37 | self.pix_format = pix_format
38 | self.normalize = normalize
39 | self._load_data_set()
40 |
41 | def _load_data_set(self):
42 | clk = Clock()
43 |
44 | self.images = []
45 | self.kp2ds = []
46 | self.boxs = []
47 | print('start loading AI CH keypoint data.')
48 | anno_file_path = os.path.join(self.data_folder, 'keypoint_train_annotations_20170902.json')
49 | with open(anno_file_path, 'r') as reader:
50 | anno = json.load(reader)
51 | for record in anno:
52 | image_name = record['image_id'] + self.img_ext
53 | image_path = os.path.join(self.data_folder, 'keypoint_train_images_20170902', image_name)
54 | kp_set = record['keypoint_annotations']
55 | box_set = record['human_annotations']
56 | self._handle_image(image_path, kp_set, box_set)
57 |
58 | print('finished loading AI CH keypoint data, total {} samples'.format(len(self)))
59 |
60 | clk.stop()
61 |
62 | def _ai_ch_to_lsp(self, pts):
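# reorder AI Challenger's 14 keypoints into LSP order; AI Challenger visibility labels are
# 1 (visible), 2 (occluded) and 3 (not labeled), so (3.0 - v) / 2.0 maps them to weights 1.0 / 0.5 / 0.0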
63 | kp_map = [8, 7, 6, 9, 10, 11, 2, 1, 0, 3, 4, 5, 13, 12]
64 | pts = np.array(pts, dtype = np.float).reshape(14, 3).copy()
65 | pts[:, 2] = (3.0 - pts[:, 2]) / 2.0
66 | return pts[kp_map].copy()
67 |
68 | def _handle_image(self, image_path, kp_set, box_set):
69 | assert len(kp_set) == len(box_set)
70 |
71 | if len(kp_set) > 1:
72 | if self.only_single_person:
73 | print('only single person supported now!')
74 | return
75 | for key in kp_set.keys():
76 | kps = kp_set[key]
77 | box = box_set[key]
78 | self._handle_sample(key, image_path, kps, [ [box[0], box[1]], [box[2], box[3]] ], box_set)
79 |
80 | def _handle_sample(self, key, image_path, pts, box, boxs):
81 | def _collect_box(key, boxs):
82 | r = []
83 | for k, v in boxs.items():
84 | if k == key:
85 | continue
86 | r.append([[v[0],v[1]], [v[2],v[3]]])
87 | return r
88 |
89 | def _collide_heavily(box, boxs):
90 | for it in boxs:
91 | if get_rectangle_intersect_ratio(box[0], box[1], it[0], it[1]) > self.max_intersec_ratio:
92 | return True
93 | return False
94 | pts = self._ai_ch_to_lsp(pts)
95 | valid_pt_count = np.sum(pts[:, 2])
96 | if valid_pt_count < self.min_pts_required:
97 | return
98 |
99 | boxs = _collect_box(key, boxs)
100 | if _collide_heavily(box, boxs):
101 | return
102 |
103 | self.images.append(image_path)
104 | self.kp2ds.append(pts)
105 | lt, rb = box[0], box[1]
106 | self.boxs.append((np.array(lt), np.array(rb)))
107 |
108 | def __len__(self):
109 | return len(self.images)
110 |
111 | def __getitem__(self, index):
112 | image_path = self.images[index]
113 | kps = self.kp2ds[index].copy()
114 | box = self.boxs[index]
115 |
116 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
117 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
118 | ratio = 1.0 * args.crop_size / image.shape[0]
119 | kps[:, :2] *= ratio
120 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
121 |
122 | if self.use_flip and random.random() <= self.flip_prob:
123 | dst_image, kps = flip_image(dst_image, kps)
124 |
125 | #normalize kp to [-1, 1]
126 | ratio = 1.0 / args.crop_size
127 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
128 |
129 | return {
130 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
131 | 'kp_2d': torch.tensor(kps).float(),
132 | 'image_name': self.images[index],
133 | 'data_set':'AI Ch'
134 | }
135 |
136 | if __name__ == '__main__':
137 | aic = AICH_dataloader(
138 | data_set_path = 'E:/HMR/data/ai_challenger_keypoint_train_20170902',
139 | use_crop = True,
140 | scale_range = [1.1, 1.5],
141 | use_flip = True,
142 | only_single_person = False,
143 | min_pts_required = 5,
144 | flip_prob = 1.0
145 | )
146 | l = len(aic)
147 | for _ in range(l):
148 | r = aic.__getitem__(_)
149 | image = r['image'].cpu().numpy().astype(np.uint8)
150 | kps = r['kp_2d'].cpu().numpy()
151 | kps[:, :2] = (kps[:, :2] + 1) * args.crop_size / 2.0
152 | base_name = os.path.basename(r['image_name'])
153 | draw_lsp_14kp__bone(image, kps)
154 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC))
155 | cv2.waitKey(0)
156 |
--------------------------------------------------------------------------------
/src/dataloader/COCO2017_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: COCO2017_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_09
7 | purpose: load COCO 2017 keypoint dataset
8 | '''
9 |
10 | import sys
11 | from torch.utils.data import Dataset, DataLoader
12 | import scipy.io as scio
13 | import os
14 | import glob
15 | import numpy as np
16 | import random
17 | import cv2
18 | import json
19 | import torch
20 |
21 | sys.path.append('./src')
22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize
23 | from config import args
24 | from timer import Clock
25 |
26 | class COCO2017_dataloader(Dataset):
27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, only_single_person, min_pts_required, max_intersec_ratio = 0.1, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
28 | self.data_folder = data_set_path
29 | self.use_crop = use_crop
30 | self.scale_range = scale_range
31 | self.use_flip = use_flip
32 | self.flip_prob = flip_prob
33 | self.only_single_person = only_single_person
34 | self.min_pts_required = min_pts_required
35 | self.max_intersec_ratio = max_intersec_ratio
36 | self.pix_format = pix_format
37 | self.normalize = normalize
38 | self._load_data_set()
39 |
40 | def _load_data_set(self):
41 | self.images = []
42 | self.kp2ds = []
43 | self.boxs = []
44 | clk = Clock()
45 | print('start loading coco 2017 dataset.')
46 | anno_file_path = os.path.join(self.data_folder, 'annotations', 'person_keypoints_train2017.json')
47 | with open(anno_file_path, 'r') as reader:
48 | anno = json.load(reader)
49 |
50 | def _hash_image_id_(image_id_to_info, coco_images_info):
51 | for image_info in coco_images_info:
52 | image_id = image_info['id']
53 | image_name = image_info['file_name']
54 | _anno = {}
55 | _anno['image_path'] = os.path.join(self.data_folder, 'images', 'train-valid2017', image_name)
56 | _anno['kps'] = []
57 | _anno['box'] = []
58 | assert not (image_id in image_id_to_info)
59 | image_id_to_info[image_id] = _anno
60 |
61 | images = anno['images']
62 |
63 | image_id_to_info = {}
64 | _hash_image_id_(image_id_to_info, images)
65 |
66 |
67 | annos = anno['annotations']
68 | for anno_info in annos:
69 | self._handle_anno_info(anno_info, image_id_to_info)
70 |
71 | for k, v in image_id_to_info.items():
72 | self._handle_image_info_(v)
73 |
74 | print('finished loading coco 2017 dataset, total {} samples.'.format(len(self.images)))
75 |
76 | clk.stop()
77 |
78 | def _handle_image_info_(self, image_info):
79 | image_path = image_info['image_path']
80 | kp_set = image_info['kps']
81 | box_set = image_info['box']
82 | if len(box_set) > 1:
83 | if self.only_single_person:
84 | return
85 |
86 | for _ in range(len(box_set)):
87 | self._handle_sample(_, kp_set, box_set, image_path)
88 |
89 | def _handle_sample(self, key, kps, boxs, image_path):
90 | def _collect_box(l, boxs):
91 | r = []
92 | for _ in range(len(boxs)):
93 | if _ == l:
94 | continue
95 | r.append(boxs[_])
96 | return r
97 |
98 | def _collide_heavily(box, boxs):
99 | for it in boxs:
100 | if get_rectangle_intersect_ratio(box[0], box[1], it[0], it[1]) > self.max_intersec_ratio:
101 | return True
102 | return False
103 |
104 | kp = kps[key]
105 | box = boxs[key]
106 |
107 | valid_pt_count = np.sum(kp[:, 2])
108 | if valid_pt_count < self.min_pts_required:
109 | return
110 |
111 | r = _collect_box(key, boxs)
112 | if _collide_heavily(box, r):
113 | return
114 |
115 | self.images.append(image_path)
116 | self.kp2ds.append(kp.copy())
117 | self.boxs.append(box.copy())
118 |
119 | def _handle_anno_info(self, anno_info, image_id_to_info):
120 | image_id = anno_info['image_id']
121 | kps = anno_info['keypoints']
122 | box_info = anno_info['bbox']
123 | box = [np.array([int(box_info[0]), int(box_info[1])]), np.array([int(box_info[0] + box_info[2]), int(box_info[1] + box_info[3])])]
124 | assert image_id in image_id_to_info
125 | _anno = image_id_to_info[image_id]
126 | _anno['box'].append(box)
127 | _anno['kps'].append(self._convert_to_lsp14_pts(kps))
128 |
129 | def _convert_to_lsp14_pts(self, coco_pts):
130 | # COCO keypoints reordered to LSP-14; the last two slots (neck, head top) are placeholders zeroed below
131 | kp_map = [16, 14, 12, 11, 13, 15, 10, 8, 6, 5, 7, 9, 0, 0]
132 | kps = np.array(coco_pts, dtype = np.float).reshape(-1, 3)[kp_map].copy()
133 | kps[12:, 2] = 0.0 # no neck, top head in COCO
134 | kps[:, 2] /= 2.0  # COCO visibility 0/1/2 -> weights 0.0/0.5/1.0
135 | return kps
136 |
137 | def __len__(self):
138 | return len(self.images)
139 |
140 | def __getitem__(self, index):
141 | image_path = self.images[index]
142 | kps = self.kp2ds[index].copy()
143 | box = self.boxs[index]
144 |
145 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
146 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
147 | ratio = 1.0 * args.crop_size / image.shape[0]
148 | kps[:, :2] *= ratio
149 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
150 |
151 | if self.use_flip and random.random() <= self.flip_prob:
152 | dst_image, kps = flip_image(dst_image, kps)
153 |
154 | #normalize kp to [-1, 1]
155 | ratio = 1.0 / args.crop_size
156 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
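    |         # worked example of the normalization above, assuming crop_size = 224:
    |         # x = 0 -> -1.0, x = 112 -> 0.0, x = 224 -> 1.0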
157 |
158 | return {
159 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
160 | 'kp_2d': torch.tensor(kps).float(),
161 | 'image_name': self.images[index],
162 | 'data_set':'COCO 2017'
163 | }
164 |
165 | if __name__ == '__main__':
166 | coco = COCO2017_dataloader('E:/HMR/data/COCO/', True, [1.1, 1.5], False, False, 10, 0.1)
167 | l = len(coco)
168 | for _ in range(l):
169 |         r = coco.__getitem__(_)
170 | image = r['image'].cpu().numpy().astype(np.uint8)
171 | kps = r['kp_2d'].cpu().numpy()
172 | base_name = os.path.basename(r['image_name'])
173 | draw_lsp_14kp__bone(image, kps)
174 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC))
175 | cv2.waitKey(0)
176 |
--------------------------------------------------------------------------------
/src/dataloader/eval_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: eval_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_20
7 | purpose: load evaluation data
8 | '''
9 | import sys
10 | from torch.utils.data import Dataset, DataLoader
11 | import scipy.io as scio
12 | import os
13 | import glob
14 | import numpy as np
15 | import random
16 | import cv2
17 | import json
18 | import torch
19 |
20 | sys.path.append('./src')
21 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose
22 | from config import args
23 | # from timer import Clock
24 |
25 | class eval_dataloader(Dataset):
26 | def __init__(self, data_set_path, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
27 | self.use_flip = use_flip
28 | self.flip_prob = flip_prob
29 | self.data_folder = data_set_path
30 | self.pix_format = pix_format
31 | self.normalize = normalize
32 |
33 | self._load_data_set()
34 |
35 | def _load_data_set(self):
36 | # clk = Clock()
37 |
38 | self.images = sorted(glob.glob(os.path.join(self.data_folder, 'image/*.png')))
39 | self.kp2ds = []
40 | self.poses = []
41 | self.betas = []
42 |
43 | for idx in range(len(self.images)):
44 | image_name = os.path.basename(self.images[idx])[:5]
45 | anno_path = os.path.join(self.data_folder, 'annos', image_name + '_joints.npy')
46 | self.kp2ds.append(np.load(anno_path).T)
47 | anno_path = os.path.join(self.data_folder, 'annos', image_name +'.json')
48 | with open(anno_path, 'r') as fp:
49 | annos = json.load(fp)
50 | self.poses.append(np.array(annos['pose']))
51 | self.betas.append(np.array(annos['betas']))
52 |
53 | # clk.stop()
54 |
55 | def __len__(self):
56 | return len(self.images)
57 |
58 | def __getitem__(self, index):
59 | image_path = self.images[index]
60 | kps = self.kp2ds[index].copy()
61 | pose = self.poses[index].copy()
62 | shape = self.betas[index].copy()
63 |
64 | dst_image = cv2.imread(image_path)
65 |
66 |
67 | if self.use_flip and random.random() <= self.flip_prob:
68 | dst_image, kps = flip_image(dst_image, kps)
69 |             pose = reflect_pose(pose)
70 |
71 | #normalize kp to [-1, 1]
72 | ratio = 1.0 / args.crop_size
73 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
74 |
75 | return {
76 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
77 | 'kp_2d': torch.tensor(kps).float(),
78 | 'pose': torch.tensor(pose).float(),
79 | 'shape': torch.tensor(shape).float(),
80 | 'image_name': self.images[index],
81 | 'data_set':'up_3d_evaluation'
82 | }
83 |
84 | if __name__ == '__main__':
85 | evl = eval_dataloader('E:/HMR/data/up3d_mpii', True)
86 | l = evl.__len__()
87 | data_loader = DataLoader(evl, batch_size=10,shuffle=True)
88 | for _ in range(l):
89 | r = evl.__getitem__(_)
90 | pass
91 |
92 |
93 |
--------------------------------------------------------------------------------
/src/dataloader/hum36m_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: hum36m_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_09
7 | purpose: load hum3.6m data
8 | '''
9 |
10 | import sys
11 | from torch.utils.data import Dataset, DataLoader
12 | import os
13 | import glob
14 | import numpy as np
15 | import random
16 | import cv2
17 | import json
18 | import h5py
19 | import torch
20 |
21 | sys.path.append('./src')
22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose, reflect_lsp_kp
23 | from config import args
24 | from timer import Clock
25 |
26 | class hum36m_dataloader(Dataset):
27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, min_pts_required, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
28 | self.data_folder = data_set_path
29 | self.use_crop = use_crop
30 | self.scale_range = scale_range
31 | self.use_flip = use_flip
32 | self.flip_prob = flip_prob
33 | self.min_pts_required = min_pts_required
34 | self.pix_format = pix_format
35 | self.normalize = normalize
36 | self._load_data_set()
37 |
38 | def _load_data_set(self):
39 |
40 | clk = Clock()
41 |
42 | self.images = []
43 | self.kp2ds = []
44 | self.boxs = []
45 | self.kp3ds = []
46 | self.shapes = []
47 | self.poses = []
48 |
49 | print('start loading hum3.6m data.')
50 |
51 | anno_file_path = os.path.join(self.data_folder, 'annot.h5')
52 | with h5py.File(anno_file_path) as fp:
53 | total_kp2d = np.array(fp['gt2d'])
54 | total_kp3d = np.array(fp['gt3d'])
55 | total_shap = np.array(fp['shape'])
56 | total_pose = np.array(fp['pose'])
57 | total_image_names = np.array(fp['imagename'])
58 |
59 | assert len(total_kp2d) == len(total_kp3d) and len(total_kp2d) == len(total_image_names) and \
60 | len(total_kp2d) == len(total_shap) and len(total_kp2d) == len(total_pose)
61 |
62 | l = len(total_kp2d)
63 | def _collect_valid_pts(pts):
64 | r = []
65 | for pt in pts:
66 | if pt[2] != 0:
67 | r.append(pt)
68 | return r
69 |
70 | for index in range(l):
71 | kp2d = total_kp2d[index].reshape((-1, 3))
72 | if np.sum(kp2d[:, 2]) < self.min_pts_required:
73 | continue
74 |
75 | lt, rb, v = calc_aabb(_collect_valid_pts(kp2d))
76 | self.kp2ds.append(np.array(kp2d.copy(), dtype = np.float))
77 | self.boxs.append((lt, rb))
78 | self.kp3ds.append(total_kp3d[index].copy().reshape(-1, 3))
79 | self.shapes.append(total_shap[index].copy())
80 | self.poses.append(total_pose[index].copy())
81 | self.images.append(os.path.join(self.data_folder, 'image') + total_image_names[index].decode())
82 |
83 | print('finished load hum3.6m data, total {} samples'.format(len(self.kp3ds)))
84 |
85 | clk.stop()
86 |
87 | def __len__(self):
88 | return len(self.images)
89 |
90 | def __getitem__(self, index):
91 | image_path = self.images[index]
92 | kps = self.kp2ds[index].copy()
93 | box = self.boxs[index]
94 | kp_3d = self.kp3ds[index].copy()
95 |
96 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
97 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
98 |
99 | ratio = 1.0 * args.crop_size / image.shape[0]
100 | kps[:, :2] *= ratio
101 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
102 |
103 | trival, shape, pose = np.zeros(3), self.shapes[index], self.poses[index]
104 |
105 | if self.use_flip and random.random() <= self.flip_prob:
106 | dst_image, kps = flip_image(dst_image, kps)
107 | pose = reflect_pose(pose)
108 | kp_3d = reflect_lsp_kp(kp_3d)
109 |
110 | #normalize kp to [-1, 1]
111 | ratio = 1.0 / args.crop_size
112 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
113 |
114 | theta = np.concatenate((trival, pose, shape), axis = 0)
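    |         # theta is the 85-dim parameter vector used throughout this repo:
    |         # 3 camera values (zero placeholders here), 72 pose values (24 joints x 3 axis-angle) and 10 shape betas.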
115 |
116 | return {
117 | 'image': torch.from_numpy(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
118 | 'kp_2d': torch.from_numpy(kps).float(),
119 | 'kp_3d': torch.from_numpy(kp_3d).float(),
120 | 'theta': torch.from_numpy(theta).float(),
121 | 'image_name': self.images[index],
122 | 'w_smpl':1.0,
123 | 'w_3d':1.0,
124 | 'data_set':'hum3.6m'
125 | }
126 |
127 | if __name__ == '__main__':
128 | h36m = hum36m_dataloader('E:/HMR/data/human3.6m', True, [1.1, 2.0], True, 5, flip_prob = 1)
129 | l = len(h36m)
130 | for _ in range(l):
131 | r = h36m.__getitem__(_)
132 | pass
--------------------------------------------------------------------------------
/src/dataloader/lsp_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: lsp_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_07
7 | '''
8 | import sys
9 | from torch.utils.data import Dataset, DataLoader
10 | import scipy.io as scio
11 | import os
12 | import glob
13 | import numpy as np
14 | import random
15 | import cv2
16 | import torch
17 |
18 | sys.path.append('./src')
19 |
20 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, convert_image_by_pixformat_normalize, reflect_lsp_kp
21 | from config import args
22 | from timer import Clock
23 |
24 | class LspLoader(Dataset):
25 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
26 | '''
27 | marks:
28 | data_set path links to the parent folder to lsp, which contains images, joints.mat, README.txt
29 |
30 | inputs:
31 | use_crop crop the image or not, it should be True by default
32 | scale_range, contain the scale range
33 | use_flip, left right flip is allowed
34 | '''
35 | self.use_crop = use_crop
36 | self.scale_range = scale_range
37 | self.use_flip = use_flip
38 | self.flip_prob = flip_prob
39 | self.data_folder = data_set_path
40 | self.pix_format = pix_format
41 | self.normalize = normalize
42 |
43 | self._load_data_set()
44 |
45 | def _load_data_set(self):
46 | clk = Clock()
47 | print('loading LSP data.')
48 | self.images = []
49 | self.kp2ds = []
50 | self.boxs = []
51 |
52 | anno_file_path = os.path.join(self.data_folder, 'joints.mat')
53 | anno = scio.loadmat(anno_file_path)
54 | kp2d = anno['joints'].transpose(2, 1, 0) # N x k x 3
55 |         visible = np.logical_not(kp2d[:, :, 2])   # invert the third channel so that 1 means visible (unlike the extended LSP set, whose flag is used directly)
56 | kp2d[:, :, 2] = visible.astype(kp2d.dtype)
57 | image_folder = os.path.join(self.data_folder, 'images')
58 | images = sorted(glob.glob(image_folder + '/im*.jpg'))
59 | for _ in range(len(images)):
60 | self._handle_image(images[_], kp2d[_])
61 |
62 | print('finished load LSP data.')
63 | clk.stop()
64 |
65 | def _handle_image(self, image_path, kps):
66 | pt_valid = []
67 | for pt in kps:
68 | if pt[2] == 1:
69 | pt_valid.append(pt)
70 | lt, rb, valid = calc_aabb(pt_valid)
71 |
72 | if not valid:
73 | return
74 |
75 | self.kp2ds.append(kps.copy().astype(np.float))
76 | self.images.append(image_path)
77 | self.boxs.append((lt, rb))
78 |
79 | def __len__(self):
80 | return len(self.images)
81 |
82 | def __getitem__(self, index):
83 | image_path = self.images[index]
84 | kps = self.kp2ds[index].copy()
85 | box = self.boxs[index]
86 |
87 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
88 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
89 | ratio = 1.0 * args.crop_size / image.shape[0]
90 | kps[:, :2] *= ratio
91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
92 |
93 | if self.use_flip and random.random() <= self.flip_prob:
94 | dst_image, kps = flip_image(dst_image, kps)
95 |
96 | #normalize kp to [-1, 1]
97 | ratio = 1.0 / args.crop_size
98 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
99 | return {
100 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
101 | 'kp_2d': torch.tensor(kps).float(),
102 | 'image_name': self.images[index],
103 | 'data_set':'lsp'
104 | }
105 |
106 | if __name__ == '__main__':
107 | lsp = LspLoader(
108 | data_set_path = 'E:/HMR/data/lsp',
109 | use_crop = True,
110 | scale_range = [1.05, 1.2],
111 | use_flip = True,
112 | flip_prob = 1.0
113 | )
114 | l = lsp.__len__()
115 | data_loader = DataLoader(lsp, batch_size=10,shuffle=True)
116 | for _ in range(l):
117 | r = lsp.__getitem__(_)
118 | image = r['image'].cpu().numpy().astype(np.uint8)
119 | kps = r['kp_2d'].cpu().numpy()
120 | kps[:, :2] = (kps[:, :2] + 1) * args.crop_size / 2.0
121 | base_name = os.path.basename(r['image_name'])
122 | draw_lsp_14kp__bone(image, kps)
123 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC))
124 | cv2.waitKey(0)
125 |
126 |
--------------------------------------------------------------------------------
/src/dataloader/lsp_ext_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: lsp_ext_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_07
7 | '''
8 | import sys
9 | from torch.utils.data import Dataset, DataLoader
10 | import scipy.io as scio
11 | import os
12 | import glob
13 | import numpy as np
14 | import random
15 | import cv2
16 | import torch
17 |
18 | sys.path.append('./src')
19 |
20 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, convert_image_by_pixformat_normalize
21 | from config import args
22 | from timer import Clock
23 |
24 |
25 | class LspExtLoader(Dataset):
26 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
27 | '''
28 | marks:
29 | data_set path links to the parent folder to lsp, which contains images, joints.mat, README.txt
30 |
31 | inputs:
32 | use_crop crop the image or not, it should be True by default
33 | scale_range, contain the scale range
34 | use_flip, left right flip is allowed
35 | '''
36 | self.use_crop = use_crop
37 | self.scale_range = scale_range
38 | self.use_flip = use_flip
39 | self.flip_prob = flip_prob
40 | self.data_folder = data_set_path
41 | self.pix_format = pix_format
42 | self.normalize = normalize
43 |
44 | self._load_data_set()
45 |
46 | def _load_data_set(self):
47 | clk = Clock()
48 |
49 | print('loading LSP ext data.')
50 | self.images = []
51 | self.kp2ds = []
52 | self.boxs = []
53 |
54 | anno_file_path = os.path.join(self.data_folder, 'joints.mat')
55 | anno = scio.loadmat(anno_file_path)
56 | kp2d = anno['joints'].transpose(2, 0, 1) # N x k x 3
57 | image_folder = os.path.join(self.data_folder, 'images')
58 | images = sorted(glob.glob(image_folder + '/im*.jpg'))
59 | for _ in range(len(images)):
60 | self._handle_image(images[_], kp2d[_])
61 |
62 | print('finished load LSP ext data.')
63 | clk.stop()
64 |
65 | def _handle_image(self, image_path, kps):
66 | pt_valid = []
67 | for pt in kps:
68 | if pt[2] == 1:
69 | pt_valid.append(pt)
70 | lt, rb, valid = calc_aabb(pt_valid)
71 |
72 | if not valid:
73 | return
74 |
75 | self.kp2ds.append(kps.copy().astype(np.float))
76 | self.images.append(image_path)
77 | self.boxs.append((lt, rb))
78 |
79 | def __len__(self):
80 | return len(self.images)
81 |
82 | def __getitem__(self, index):
83 | image_path = self.images[index]
84 | kps = self.kp2ds[index].copy()
85 | box = self.boxs[index]
86 |
87 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
88 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
89 | ratio = 1.0 * args.crop_size / image.shape[0]
90 | kps[:, :2] *= ratio
91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
92 |
93 | if self.use_flip and random.random() <= self.flip_prob:
94 | dst_image, kps = flip_image(dst_image, kps)
95 |
96 | #normalize kp to [-1, 1]
97 | ratio = 1.0 / args.crop_size
98 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
99 |
100 | return {
101 | 'image': torch.tensor(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
102 | 'kp_2d': torch.tensor(kps).float(),
103 | 'image_name': self.images[index],
104 | 'data_set':'lsp_ext'
105 | }
106 |
107 | if __name__ == '__main__':
108 | lsp = LspExtLoader('E:/HMR/data/lsp_ext', True, [1.05, 1.5], False, flip_prob = 1.0)
109 | l = lsp.__len__()
110 |
111 | data_loader = DataLoader(lsp, batch_size=10,shuffle=True)
112 |
113 | for _ in range(l):
114 | r = lsp.__getitem__(_)
115 | image = r['image'].cpu().numpy().astype(np.uint8)
116 | kps = r['kp_2d'].cpu().numpy()
117 | base_name = os.path.basename(r['image_name'])
118 | draw_lsp_14kp__bone(image, kps)
119 | cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC))
120 | cv2.waitKey(0)
121 |
122 |
--------------------------------------------------------------------------------
/src/dataloader/mosh_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: mosh_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_09
7 | purpose: load mosh data (SMPL pose and shape parameters used for adversarial training)
8 | '''
9 |
10 | import sys
11 | from torch.utils.data import Dataset, DataLoader
12 | import scipy.io as scio
13 | import os
14 | import glob
15 | import numpy as np
16 | import random
17 | import cv2
18 | import json
19 | import h5py
20 | import torch
21 |
22 | sys.path.append('./src')
23 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, reflect_pose
24 | from config import args
25 | from timer import Clock
26 |
27 |
28 | class mosh_dataloader(Dataset):
29 | def __init__(self, data_set_path, use_flip = True, flip_prob = 0.3):
30 | self.data_folder = data_set_path
31 | self.use_flip = use_flip
32 | self.flip_prob = flip_prob
33 |
34 | self._load_data_set()
35 |
36 | def _load_data_set(self):
37 | clk = Clock()
38 | print('start loading mosh data.')
39 | anno_file_path = os.path.join(self.data_folder, 'mosh_annot.h5')
40 | with h5py.File(anno_file_path) as fp:
41 | self.shapes = np.array(fp['shape'])
42 | self.poses = np.array(fp['pose'])
43 | print('finished load mosh data, total {} samples'.format(len(self.poses)))
44 | clk.stop()
45 |
46 | def __len__(self):
47 | return len(self.poses)
48 |
49 | def __getitem__(self, index):
50 | trival, pose, shape = np.zeros(3), self.poses[index], self.shapes[index]
51 |
52 | if self.use_flip and random.uniform(0, 1) <= self.flip_prob:#left-right reflect the pose
53 | pose = reflect_pose(pose)
54 |
55 | return {
56 | 'theta': torch.tensor(np.concatenate((trival, pose, shape), axis = 0)).float()
57 | }
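    |     # The 3 leading zeros stand in for the camera slot, so mosh thetas share the 85-dim
    |     # layout (3 cam + 72 pose + 10 shape) of the generator output fed to the discriminator.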
58 |
59 | if __name__ == '__main__':
60 |     print(np.random.rand(1))
61 | mosh = mosh_dataloader('E:/HMR/data/mosh_gen')
62 | l = len(mosh)
63 | import time
64 | for _ in range(l):
65 | r = mosh.__getitem__(_)
66 | print(r)
--------------------------------------------------------------------------------
/src/dataloader/mpi_inf_3dhp_dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: mpi_inf_3dhp_dataloader.py
4 |
5 | author: zhangxiong(1025679612@qq.com)
6 | date: 2018_05_09
7 | purpose: load mpi inf 3dhp data
8 | '''
9 |
10 | import sys
11 | from torch.utils.data import Dataset, DataLoader
12 | import os
13 | import glob
14 | import numpy as np
15 | import random
16 | import cv2
17 | import json
18 | import h5py
19 | import torch
20 |
21 | sys.path.append('./src')
22 | from util import calc_aabb, cut_image, flip_image, draw_lsp_14kp__bone, rectangle_intersect, get_rectangle_intersect_ratio, convert_image_by_pixformat_normalize, reflect_pose, reflect_lsp_kp
23 | from config import args
24 | from timer import Clock
25 |
26 | class mpi_inf_3dhp_dataloader(Dataset):
27 | def __init__(self, data_set_path, use_crop, scale_range, use_flip, min_pts_required, pix_format = 'NHWC', normalize = False, flip_prob = 0.3):
28 | self.data_folder = data_set_path
29 | self.use_crop = use_crop
30 | self.scale_range = scale_range
31 | self.use_flip = use_flip
32 |         self.flip_prob = flip_prob
33 | self.min_pts_required = min_pts_required
34 | self.pix_format = pix_format
35 | self.normalize = normalize
36 | self._load_data_set()
37 |
38 | def _load_data_set(self):
39 | clk = Clock()
40 |
41 | self.images = []
42 | self.kp2ds = []
43 | self.boxs = []
44 | self.kp3ds = []
45 |
46 | print('start loading mpii-inf-3dhp data.')
47 | anno_file_path = os.path.join(self.data_folder, 'annot.h5')
48 | with h5py.File(anno_file_path) as fp:
49 | total_kp2d = np.array(fp['gt2d'])
50 | total_kp3d = np.array(fp['gt3d'])
51 | total_image_names = np.array(fp['imagename'])
52 |
53 | assert len(total_kp2d) == len(total_kp3d) and len(total_kp2d) == len(total_image_names)
54 |
55 | l = len(total_kp2d)
56 | def _collect_valid_pts(pts):
57 | r = []
58 | for pt in pts:
59 | if pt[2] != 0:
60 | r.append(pt)
61 | return r
62 |
63 | for index in range(l):
64 | kp2d = total_kp2d[index].reshape((-1, 3))
65 | if np.sum(kp2d[:, 2]) < self.min_pts_required:
66 | continue
67 |
68 | lt, rb, v = calc_aabb(_collect_valid_pts(kp2d))
69 | self.kp2ds.append(np.array(kp2d.copy(), dtype = np.float))
70 | self.boxs.append((lt, rb))
71 | self.kp3ds.append(total_kp3d[index].copy().reshape(-1, 3))
72 | self.images.append(os.path.join(self.data_folder, 'image') + total_image_names[index].decode())
73 |
74 | print('finished load mpii-inf-3dhp data, total {} samples'.format(len(self.images)))
75 | clk.stop()
76 |
77 | def __len__(self):
78 | return len(self.images)
79 |
80 | def __getitem__(self, index):
81 | image_path = self.images[index]
82 | kps = self.kp2ds[index].copy()
83 | box = self.boxs[index]
84 | kp_3d = self.kp3ds[index].copy()
85 |
86 | scale = np.random.rand(4) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
87 | image, kps = cut_image(image_path, kps, scale, box[0], box[1])
88 |
89 | ratio = 1.0 * args.crop_size / image.shape[0]
90 | kps[:, :2] *= ratio
91 | dst_image = cv2.resize(image, (args.crop_size, args.crop_size), interpolation = cv2.INTER_CUBIC)
92 |
93 | if self.use_flip and random.random() <= self.flip_prob:
94 | dst_image, kps = flip_image(dst_image, kps)
95 | kp_3d = reflect_lsp_kp(kp_3d)
96 |
97 | #normalize kp to [-1, 1]
98 | ratio = 1.0 / args.crop_size
99 | kps[:, :2] = 2.0 * kps[:, :2] * ratio - 1.0
100 |
101 | return {
102 | 'image': torch.from_numpy(convert_image_by_pixformat_normalize(dst_image, self.pix_format, self.normalize)).float(),
103 | 'kp_2d': torch.from_numpy(kps).float(),
104 | 'kp_3d': torch.from_numpy(kp_3d).float(),
105 | 'theta': torch.zeros(85).float(),
106 | 'image_name': self.images[index],
107 | 'w_smpl':0.0,
108 | 'w_3d':1.0,
109 | 'data_set':'mpi inf 3dhp'
110 | }
111 |
112 | if __name__ == '__main__':
113 | mpi = mpi_inf_3dhp_dataloader('E:/HMR/data/mpii_inf_3dhp', True, [1.1, 2.0], False, 5)
114 | l = len(mpi)
115 | for _ in range(l):
116 | r = mpi.__getitem__(_)
117 |         image = r['image'].cpu().numpy().astype(np.uint8)
118 |         kps = r['kp_2d'].cpu().numpy()
119 |         kps[:, :2] = (kps[:, :2] + 1) * args.crop_size / 2.0
120 |         base_name = os.path.basename(r['image_name'])
121 |         draw_lsp_14kp__bone(image, kps)
122 |         cv2.imshow(base_name, cv2.resize(image, (512, 512), interpolation = cv2.INTER_CUBIC))
123 |         cv2.waitKey(0)
124 |
--------------------------------------------------------------------------------
/src/densenet.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.utils.model_zoo as model_zoo
7 | from collections import OrderedDict
8 | import sys
9 |
10 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
11 |
12 |
13 | model_urls = {
14 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
15 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
16 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
17 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
18 | }
19 |
20 |
21 | def densenet121(pretrained=False, **kwargs):
22 | r"""Densenet-121 model from
23 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
24 |
25 | Args:
26 | pretrained (bool): If True, returns a model pre-trained on ImageNet
27 | """
28 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
29 | **kwargs)
30 | if pretrained:
31 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
32 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
33 | # They are also in the checkpoints in model_urls. This pattern is used
34 | # to find such keys.
35 | pattern = re.compile(
36 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
37 | state_dict = model_zoo.load_url(model_urls['densenet121'])
38 | for key in list(state_dict.keys()):
39 | res = pattern.match(key)
40 | if res:
41 | new_key = res.group(1) + res.group(2)
42 | state_dict[new_key] = state_dict[key]
43 | del state_dict[key]
44 | model.load_state_dict(state_dict)
45 | return model
46 |
47 |
48 | def densenet169(pretrained=False, **kwargs):
49 | r"""Densenet-169 model from
50 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
51 |
52 | Args:
53 | pretrained (bool): If True, returns a model pre-trained on ImageNet
54 | """
55 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
56 | **kwargs)
57 | if pretrained:
58 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
59 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
60 | # They are also in the checkpoints in model_urls. This pattern is used
61 | # to find such keys.
62 | pattern = re.compile(
63 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
64 | state_dict = model_zoo.load_url(model_urls['densenet169'])
65 | for key in list(state_dict.keys()):
66 | res = pattern.match(key)
67 | if res:
68 | new_key = res.group(1) + res.group(2)
69 | state_dict[new_key] = state_dict[key]
70 | del state_dict[key]
71 | model.load_state_dict(state_dict)
72 | return model
73 |
74 |
75 | def densenet201(pretrained=False, **kwargs):
76 | r"""Densenet-201 model from
77 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
78 |
79 | Args:
80 | pretrained (bool): If True, returns a model pre-trained on ImageNet
81 | """
82 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
83 | **kwargs)
84 | if pretrained:
85 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
86 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
87 | # They are also in the checkpoints in model_urls. This pattern is used
88 | # to find such keys.
89 | pattern = re.compile(
90 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
91 | state_dict = model_zoo.load_url(model_urls['densenet201'])
92 | for key in list(state_dict.keys()):
93 | res = pattern.match(key)
94 | if res:
95 | new_key = res.group(1) + res.group(2)
96 | state_dict[new_key] = state_dict[key]
97 | del state_dict[key]
98 | model.load_state_dict(state_dict)
99 | return model
100 |
101 |
102 | def densenet161(pretrained=False, **kwargs):
103 | r"""Densenet-161 model from
104 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
105 |
106 | Args:
107 | pretrained (bool): If True, returns a model pre-trained on ImageNet
108 | """
109 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
110 | **kwargs)
111 | if pretrained:
112 | # '.'s are no longer allowed in module names, but pervious _DenseLayer
113 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
114 | # They are also in the checkpoints in model_urls. This pattern is used
115 | # to find such keys.
116 | pattern = re.compile(
117 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
118 | state_dict = model_zoo.load_url(model_urls['densenet161'])
119 | for key in list(state_dict.keys()):
120 | res = pattern.match(key)
121 | if res:
122 | new_key = res.group(1) + res.group(2)
123 | state_dict[new_key] = state_dict[key]
124 | del state_dict[key]
125 | model.load_state_dict(state_dict)
126 | return model
127 |
128 |
129 | class _DenseLayer(nn.Sequential):
130 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
131 | super(_DenseLayer, self).__init__()
132 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
133 | self.add_module('relu1', nn.ReLU(inplace=True)),
134 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
135 | growth_rate, kernel_size=1, stride=1, bias=False)),
136 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
137 | self.add_module('relu2', nn.ReLU(inplace=True)),
138 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
139 | kernel_size=3, stride=1, padding=1, bias=False)),
140 | self.drop_rate = drop_rate
141 |
142 | def forward(self, x):
143 | new_features = super(_DenseLayer, self).forward(x)
144 | if self.drop_rate > 0:
145 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
146 | return torch.cat([x, new_features], 1)
147 |
148 |
149 | class _DenseBlock(nn.Sequential):
150 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
151 | super(_DenseBlock, self).__init__()
152 | for i in range(num_layers):
153 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
154 | self.add_module('denselayer%d' % (i + 1), layer)
155 |
156 |
157 | class _Transition(nn.Sequential):
158 | def __init__(self, num_input_features, num_output_features):
159 | super(_Transition, self).__init__()
160 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
161 | self.add_module('relu', nn.ReLU(inplace=True))
162 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
163 | kernel_size=1, stride=1, bias=False))
164 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
165 |
166 |
167 | class DenseNet(nn.Module):
168 | r"""Densenet-BC model class, based on
169 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
170 |
171 | Args:
172 | growth_rate (int) - how many filters to add each layer (`k` in paper)
173 | block_config (list of 4 ints) - how many layers in each pooling block
174 | num_init_features (int) - the number of filters to learn in the first convolution layer
175 | bn_size (int) - multiplicative factor for number of bottle neck layers
176 | (i.e. bn_size * k features in the bottleneck layer)
177 | drop_rate (float) - dropout rate after each dense layer
178 | num_classes (int) - number of classification classes
179 | """
180 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
181 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
182 |
183 | super(DenseNet, self).__init__()
184 |
185 | # First convolution
186 | self.features = nn.Sequential(OrderedDict([
187 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
188 | ('norm0', nn.BatchNorm2d(num_init_features)),
189 | ('relu0', nn.ReLU(inplace=True)),
190 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
191 | ]))
192 |
193 | # Each denseblock
194 | num_features = num_init_features
195 | for i, num_layers in enumerate(block_config):
196 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
197 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
198 | self.features.add_module('denseblock%d' % (i + 1), block)
199 | num_features = num_features + num_layers * growth_rate
200 | if i != len(block_config) - 1:
201 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
202 | self.features.add_module('transition%d' % (i + 1), trans)
203 | num_features = num_features // 2
204 |
205 | # Final batch norm
206 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
207 |
208 | # Linear layer
209 | self.classifier = nn.Linear(num_features, num_classes)
210 |
211 | # Official init from torch repo.
212 | for m in self.modules():
213 | if isinstance(m, nn.Conv2d):
214 | nn.init.kaiming_normal(m.weight.data)
215 | elif isinstance(m, nn.BatchNorm2d):
216 | m.weight.data.fill_(1)
217 | m.bias.data.zero_()
218 | elif isinstance(m, nn.Linear):
219 | m.bias.data.zero_()
220 |
221 | def forward(self, x):
222 | features = self.features(x)
223 | out = F.relu(features, inplace=True)
224 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
225 |         #out = self.classifier(out)   # classification head skipped: the pooled features are returned as the image encoding
226 | return out
227 |
228 | def load_denseNet(net_type):
229 | if net_type == 'densenet121':
230 | return densenet121(pretrained=True)
231 | elif net_type == 'densenet169':
232 | return densenet169(pretrained=True)
233 | elif net_type == 'densenet201':
234 | return densenet201(pretrained=True)
235 | elif net_type == 'densenet161':
236 | return densenet161(pretrained=True)
237 | else:
238 |         msg = 'invalid densenet type'
239 | sys.exit(msg)
240 |
241 | if __name__ == '__main__':
242 | net = load_denseNet('densenet169')
243 | print(net)
--------------------------------------------------------------------------------
/src/do_train.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python3 trainer.py
4 |
5 | CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python3 trainer.py > train.log &
6 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: model.py
4 |
5 | date: 2018_05_03
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 |
9 | from LinearModel import LinearModel
10 | import torch.nn as nn
11 | import numpy as np
12 | import torch
13 | import util
14 | from Discriminator import ShapeDiscriminator, PoseDiscriminator, FullPoseDiscriminator
15 | from SMPL import SMPL
16 | from config import args
17 | import config
18 | import Resnet
19 | from HourGlass import _create_hourglass_net
20 | from densenet import load_denseNet
21 | import sys
22 |
23 | class ThetaRegressor(LinearModel):
24 | def __init__(self, fc_layers, use_dropout, drop_prob, use_ac_func, iterations):
25 | super(ThetaRegressor, self).__init__(fc_layers, use_dropout, drop_prob, use_ac_func)
26 | self.iterations = iterations
27 | batch_size = max(args.batch_size + args.batch_3d_size, args.eval_batch_size)
28 | mean_theta = np.tile(util.load_mean_theta(), batch_size).reshape((batch_size, -1))
29 | self.register_buffer('mean_theta', torch.from_numpy(mean_theta).float())
30 |     '''
31 |     param:
32 |         inputs: the feature vector produced by the encoder (2048 features for resnet50)
33 |
34 |     return:
35 |         a list with one entry per iteration; each entry is the batch of current theta estimates, so the stacked shape is iterations X N X 85 (or other theta count)
36 |     '''
37 | def forward(self, inputs):
38 | thetas = []
39 | shape = inputs.shape
40 | theta = self.mean_theta[:shape[0], :]
41 | for _ in range(self.iterations):
42 | total_inputs = torch.cat([inputs, theta], 1)
43 | theta = theta + self.fc_blocks(total_inputs)
44 | thetas.append(theta)
45 | return thetas
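    |     # In effect this is iterative error feedback: starting from the mean theta, each step predicts a
    |     # correction from the concatenated [encoder features, current theta], i.e. theta <- theta + fc(...),
    |     # and every intermediate estimate is kept so later stages can supervise or inspect it.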
46 |
47 | class HMRNetBase(nn.Module):
48 | def __init__(self):
49 | super(HMRNetBase, self).__init__()
50 | self._read_configs()
51 |
52 | print('start creating sub modules...')
53 | self._create_sub_modules()
54 |
55 | def _read_configs(self):
56 | def _check_config():
57 | encoder_name = args.encoder_network
58 |             enable_inter_supervisions = args.enable_inter_supervision
59 |             feature_count = args.feature_count
60 |             if encoder_name == 'hourglass':
61 |                 assert args.crop_size == 256
62 |             elif encoder_name == 'resnet50':
63 |                 assert args.crop_size == 224
64 |                 assert not enable_inter_supervisions
65 |             elif encoder_name.startswith('densenet'):
66 |                 assert args.crop_size == 224
67 |                 assert not enable_inter_supervisions
68 |             else:
69 |                 msg = 'invalid encoder network, only {} is allowed, got {}'.format(args.allowed_encoder_net, encoder_name)
70 | sys.exit(msg)
71 | assert config.encoder_feature_count[encoder_name] == feature_count
72 |
73 | _check_config()
74 |
75 | self.encoder_name = args.encoder_network
76 | self.beta_count = args.beta_count
77 | self.smpl_model = args.smpl_model
78 | self.smpl_mean_theta_path = args.smpl_mean_theta_path
79 | self.total_theta_count = args.total_theta_count
80 | self.joint_count = args.joint_count
81 | self.feature_count = args.feature_count
82 |
83 | def _create_sub_modules(self):
84 |         '''
85 |         add the smpl model; SMPL creates a body mesh and 3D joints from beta (shape) & theta (pose)
86 |         '''
87 | self.smpl = SMPL(self.smpl_model, obj_saveable = True)
88 |
89 |         '''
90 |         only resnet50, hourglass and densenet encoders are supported currently; other encoders may be added later.
91 |         '''
92 | if self.encoder_name == 'resnet50':
93 | print('creating resnet50')
94 | self.encoder = Resnet.load_Res50Model()
95 | elif self.encoder_name == 'hourglass':
96 | print('creating hourglass')
97 | self.encoder = _create_hourglass_net()
98 | elif self.encoder_name.startswith('densenet'):
99 | print('creating densenet')
100 | self.encoder = load_denseNet(self.encoder_name)
101 | else:
102 | assert 0
103 |         '''
104 |         the regressor predicts the full theta vector (camera, SMPL pose and shape) from the encoder features in an iterative way
105 |         '''
106 | fc_layers = [self.feature_count + self.total_theta_count, 1024, 1024, 85]
107 | use_dropout = [True, True, False]
108 | drop_prob = [0.5, 0.5, 0.5]
109 |         use_ac_func = [True, True, False]   # no activation after the last layer
110 | iterations = 3
111 | self.regressor = ThetaRegressor(fc_layers, use_dropout, drop_prob, use_ac_func, iterations)
112 | self.iterations = iterations
113 |
114 | print('finished create the encoder modules...')
115 |
116 | def forward(self, inputs):
117 | if self.encoder_name == 'resnet50':
118 | feature = self.encoder(inputs)
119 | thetas = self.regressor(feature)
120 | detail_info = []
121 | for theta in thetas:
122 | detail_info.append(self._calc_detail_info(theta))
123 | return detail_info
124 | elif self.encoder_name.startswith('densenet'):
125 | feature = self.encoder(inputs)
126 | thetas = self.regressor(feature)
127 | detail_info = []
128 | for theta in thetas:
129 | detail_info.append(self._calc_detail_info(theta))
130 | return detail_info
131 | elif self.encoder_name == 'hourglass':
132 | if args.enable_inter_supervision:
133 | features = self.encoder(inputs)
134 | detail_info = []
135 | for feature in features:
136 | thetas = self.regressor(feature)
137 | detail_info.append(self._calc_detail_info(thetas[-1]))
138 | return detail_info
139 | else:
140 | features = self.encoder(inputs)
141 | thetas = self.regressor(features[-1]) #only the last block
142 | detail_info = []
143 | for theta in thetas:
144 | detail_info.append(self._calc_detail_info(theta))
145 | return detail_info
146 | else:
147 | assert 0
148 |
149 | '''
150 | purpose:
151 | calc verts, joint2d, joint3d, Rotation matrix
152 |
153 | inputs:
154 | theta: N X (3 + 72 + 10)
155 |
156 | return:
157 | thetas, verts, j2d, j3d, Rs
158 | '''
159 |
160 | def _calc_detail_info(self, theta):
161 | cam = theta[:, 0:3].contiguous()
162 | pose = theta[:, 3:75].contiguous()
163 | shape = theta[:, 75:].contiguous()
164 | verts, j3d, Rs = self.smpl(beta = shape, theta = pose, get_skin = True)
165 | j2d = util.batch_orth_proj(j3d, cam)
166 |
167 | return (theta, verts, j2d, j3d, Rs)
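    |     # theta layout: [:, 0:3] camera (the first value is a scale, cf. load_mean_theta in util.py),
    |     # [:, 3:75] 72 axis-angle pose values (24 joints x 3), [:, 75:85] 10 SMPL shape betas;
    |     # j2d is obtained by projecting the SMPL 3D joints with the predicted camera.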
168 |
169 | if __name__ == '__main__':
170 | cam = np.array([[0.9, 0, 0]], dtype = np.float)
171 | pose= np.array([[-9.44920200e+01, -4.25263865e+01, -1.30050643e+01, -2.79970490e-01,
172 | 3.24995661e-01, 5.03083125e-01, -6.90573755e-01, -4.12994214e-01,
173 | -4.21870093e-01, 5.98717416e-01, -1.48420885e-02, -3.85911139e-02,
174 | 1.13642605e-01, 2.30647176e-01, -2.11843286e-01, 1.31767149e+00,
175 | -6.61596447e-01, 4.02174644e-01, 3.03129424e-02, 5.91100770e-02,
176 | -8.04416564e-02, -1.12944653e-01, 3.15045050e-01, -1.32838375e-01,
177 | -1.33748209e-01, -4.99408923e-01, 1.40508643e-01, 6.10867911e-02,
178 | -2.22951915e-02, -4.73448564e-02, -1.48489055e-01, 1.47620442e-01,
179 | 3.24157346e-01, 7.78414851e-04, 1.70687935e-01, -1.54716815e-01,
180 | 2.95053507e-01, -2.91967776e-01, 1.26000780e-01, 8.09572677e-02,
181 | 1.54710846e-02, -4.21941758e-01, 7.44124075e-02, 1.17146423e-01,
182 | 3.16305389e-01, 5.04810448e-01, -3.65526364e-01, 1.31366428e-01,
183 | -2.76658949e-02, -9.17315987e-03, -1.88285742e-01, 7.86409877e-03,
184 | -9.41106758e-02, 2.08424367e-01, 1.62278709e-01, -7.98170265e-01,
185 | -3.97403587e-03, 1.11321421e-01, 6.07793270e-01, 1.42215980e-01,
186 | 4.48185010e-01, -1.38429048e-01, 3.77056061e-02, 4.48877661e-01,
187 | 1.31445158e-01, 5.07427503e-02, -3.80920772e-01, -2.52292254e-02,
188 | -5.27745375e-02, -7.43903887e-02, 7.22498075e-02, -6.35824487e-03]])
189 |
190 | beta = np.array([[-3.54196257, 0.90870435, -1.0978663 , -0.20436199, 0.18589762, 0.55789026, -0.18163599, 0.12002746, -0.09172286, 0.4430783 ]])
191 | real_shapes = torch.from_numpy(beta).float().cuda()
192 | real_poses = torch.from_numpy(pose).float().cuda()
193 |
194 | net = HMRNetBase().cuda()
195 | nx = torch.rand(2, 3, 224, 224).float().cuda()
196 |
--------------------------------------------------------------------------------
/src/timer.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: timer.py
4 |
5 | date: 2018_05_09
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 |
9 | import time
10 |
11 | class Clock:
12 | def __init__(self, start_tick = True):
13 | self.pre_time = 0
14 | if start_tick:
15 | self.start()
16 |
17 | def start(self):
18 | self.pre_time = time.time()
19 |
20 | def stop(self):
21 | self.cur_time = time.time()
22 | print('time {} elapsed!'.format(self.cur_time - self.pre_time))
23 | self.pre_time = self.cur_time
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | '''
4 | file: trainer.py
5 |
6 | date: 2018_05_07
7 | author: zhangxiong(1025679612@qq.com)
8 | '''
9 |
10 | import sys
11 | from model import HMRNetBase
12 | from Discriminator import Discriminator
13 | from config import args
14 | import config
15 | import torch
16 | import torch.nn as nn
17 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
18 |
19 | from dataloader.AICH_dataloader import AICH_dataloader
20 | from dataloader.COCO2017_dataloader import COCO2017_dataloader
21 | from dataloader.hum36m_dataloader import hum36m_dataloader
22 | from dataloader.lsp_dataloader import LspLoader
23 | from dataloader.lsp_ext_dataloader import LspExtLoader
24 | from dataloader.mosh_dataloader import mosh_dataloader
25 | from dataloader.mpi_inf_3dhp_dataloader import mpi_inf_3dhp_dataloader
26 | from dataloader.eval_dataloader import eval_dataloader
27 |
28 | from util import align_by_pelvis, batch_rodrigues, copy_state_dict
29 | from timer import Clock
30 | import time
31 | import datetime
32 | from collections import OrderedDict
33 | import os
34 |
35 | class HMRTrainer(object):
36 | def __init__(self):
37 | self.pix_format = 'NCHW'
38 | self.normalize = True
39 | self.flip_prob = 0.5
40 | self.use_flip = False
41 | self.w_smpl = torch.ones((config.args.eval_batch_size)).float().cuda()
42 |
43 | self._build_model()
44 | self._create_data_loader()
45 |
46 | def _create_data_loader(self):
47 | self.loader_2d = self._create_2d_data_loader(config.train_2d_set)
48 | self.loader_mosh = self._create_adv_data_loader(config.train_adv_set)
49 | self.loader_3d = self._create_3d_data_loader(config.train_3d_set)
50 |
51 | def _build_model(self):
52 |         print('start building model.')
53 |
54 | '''
55 | load pretrain model
56 | '''
57 | generator = HMRNetBase()
58 | model_path = config.pre_trained_model['generator']
59 | if os.path.exists(model_path):
60 | copy_state_dict(
61 | generator.state_dict(),
62 | torch.load(model_path),
63 | prefix = 'module.'
64 | )
65 | else:
66 | print('model {} not exist!'.format(model_path))
67 |
68 | discriminator = Discriminator()
69 | model_path = config.pre_trained_model['discriminator']
70 | if os.path.exists(model_path):
71 | copy_state_dict(
72 | discriminator.state_dict(),
73 | torch.load(model_path),
74 | prefix = 'module.'
75 | )
76 | else:
77 | print('model {} not exist!'.format(model_path))
78 |
79 | self.generator = nn.DataParallel(generator).cuda()
80 | self.discriminator = nn.DataParallel(discriminator).cuda()
81 |
82 | self.e_opt = torch.optim.Adam(
83 | self.generator.parameters(),
84 | lr = args.e_lr,
85 | weight_decay = args.e_wd
86 | )
87 |
88 | self.d_opt = torch.optim.Adam(
89 | self.discriminator.parameters(),
90 | lr = args.d_lr,
91 | weight_decay = args.d_wd
92 | )
93 |
94 | self.e_sche = torch.optim.lr_scheduler.StepLR(
95 | self.e_opt,
96 | step_size = 500,
97 | gamma = 0.9
98 | )
99 |
100 | self.d_sche = torch.optim.lr_scheduler.StepLR(
101 | self.d_opt,
102 | step_size = 500,
103 | gamma = 0.9
104 | )
105 |
106 | print('finished build model.')
107 |
108 | def _create_2d_data_loader(self, data_2d_set):
109 | data_set = []
110 | for data_set_name in data_2d_set:
111 | data_set_path = config.data_set_path[data_set_name]
112 | if data_set_name == 'coco':
113 | coco = COCO2017_dataloader(
114 | data_set_path = data_set_path,
115 | use_crop = True,
116 | scale_range = [1.05, 1.3],
117 | use_flip = self.use_flip,
118 | only_single_person = False,
119 | min_pts_required = 7,
120 | max_intersec_ratio = 0.5,
121 | pix_format = self.pix_format,
122 | normalize = self.normalize,
123 | flip_prob = self.flip_prob
124 | )
125 | data_set.append(coco)
126 | elif data_set_name == 'lsp':
127 | lsp = LspLoader(
128 | data_set_path = data_set_path,
129 | use_crop = True,
130 | scale_range = [1.05, 1.3],
131 | use_flip = self.use_flip,
132 | pix_format = self.pix_format,
133 | normalize = self.normalize,
134 | flip_prob = self.flip_prob
135 | )
136 | data_set.append(lsp)
137 | elif data_set_name == 'lsp_ext':
138 | lsp_ext = LspExtLoader(
139 | data_set_path = data_set_path,
140 | use_crop = True,
141 | scale_range = [1.1, 1.2],
142 | use_flip = self.use_flip,
143 | pix_format = self.pix_format,
144 | normalize = self.normalize,
145 | flip_prob = self.flip_prob
146 | )
147 | data_set.append(lsp_ext)
148 | elif data_set_name == 'ai-ch':
149 | ai_ch = AICH_dataloader(
150 | data_set_path = data_set_path,
151 | use_crop = True,
152 | scale_range = [1.1, 1.2],
153 | use_flip = self.use_flip,
154 | only_single_person = False,
155 | min_pts_required = 5,
156 | max_intersec_ratio = 0.1,
157 | pix_format = self.pix_format,
158 | normalize = self.normalize,
159 | flip_prob = self.flip_prob
160 | )
161 | data_set.append(ai_ch)
162 | else:
163 | msg = 'invalid 2d dataset'
164 | sys.exit(msg)
165 |
166 | con_2d_dataset = ConcatDataset(data_set)
167 |
168 | return DataLoader(
169 | dataset = con_2d_dataset,
170 | batch_size = config.args.batch_size,
171 | shuffle = True,
172 | drop_last = True,
173 | pin_memory = True,
174 | num_workers = config.args.num_worker
175 | )
176 |
177 | def _create_3d_data_loader(self, data_3d_set):
178 | data_set = []
179 | for data_set_name in data_3d_set:
180 | data_set_path = config.data_set_path[data_set_name]
181 | if data_set_name == 'mpi-inf-3dhp':
182 | mpi_inf_3dhp = mpi_inf_3dhp_dataloader(
183 | data_set_path = data_set_path,
184 | use_crop = True,
185 | scale_range = [1.1, 1.2],
186 | use_flip = self.use_flip,
187 | min_pts_required = 5,
188 | pix_format = self.pix_format,
189 | normalize = self.normalize,
190 | flip_prob = self.flip_prob
191 | )
192 | data_set.append(mpi_inf_3dhp)
193 | elif data_set_name == 'hum3.6m':
194 | hum36m = hum36m_dataloader(
195 | data_set_path = data_set_path,
196 | use_crop = True,
197 | scale_range = [1.1, 1.2],
198 | use_flip = self.use_flip,
199 | min_pts_required = 5,
200 | pix_format = self.pix_format,
201 | normalize = self.normalize,
202 | flip_prob = self.flip_prob
203 | )
204 | data_set.append(hum36m)
205 | else:
206 | msg = 'invalid 3d dataset'
207 | sys.exit(msg)
208 |
209 | con_3d_dataset = ConcatDataset(data_set)
210 |
211 | return DataLoader(
212 | dataset = con_3d_dataset,
213 | batch_size = config.args.batch_3d_size,
214 | shuffle = True,
215 | drop_last = True,
216 | pin_memory = True,
217 | num_workers = config.args.num_worker
218 | )
219 |
220 | def _create_adv_data_loader(self, data_adv_set):
221 | data_set = []
222 | for data_set_name in data_adv_set:
223 | data_set_path = config.data_set_path[data_set_name]
224 | if data_set_name == 'mosh':
225 | mosh = mosh_dataloader(
226 | data_set_path = data_set_path,
227 | use_flip = self.use_flip,
228 | flip_prob = self.flip_prob
229 | )
230 | data_set.append(mosh)
231 | else:
232 | msg = 'invalid adv dataset'
233 | sys.exit(msg)
234 |
235 | con_adv_dataset = ConcatDataset(data_set)
236 | return DataLoader(
237 | dataset = con_adv_dataset,
238 | batch_size = config.args.adv_batch_size,
239 | shuffle = True,
240 | drop_last = True,
241 | pin_memory = True,
242 | )
243 |
244 | def _create_eval_data_loader(self, data_eval_set):
245 | data_set = []
246 | for data_set_name in data_eval_set:
247 | data_set_path = config.data_set_path[data_set_name]
248 | if data_set_name == 'up3d':
249 | up3d = eval_dataloader(
250 | data_set_path = data_set_path,
251 | use_flip = False,
252 | flip_prob = self.flip_prob,
253 | pix_format = self.pix_format,
254 | normalize = self.normalize
255 | )
256 | data_set.append(up3d)
257 | else:
258 | msg = 'invalid eval dataset'
259 | sys.exit(msg)
260 | con_eval_dataset = ConcatDataset(data_set)
261 | return DataLoader(
262 | dataset = con_eval_dataset,
263 | batch_size = config.args.eval_batch_size,
264 | shuffle = False,
265 | drop_last = False,
266 | pin_memory = True,
267 | num_workers = config.args.num_worker
268 | )
269 |
270 | def train(self):
271 | def save_model(result):
272 | exclude_key = 'module.smpl'
273 | def exclude_smpl(model_dict):
274 | result = OrderedDict()
275 | for (k, v) in model_dict.items():
276 | if exclude_key in k:
277 | continue
278 | result[k] = v
279 | return result
280 |
281 | parent_folder = args.save_folder
282 | if not os.path.exists(parent_folder):
283 | os.makedirs(parent_folder)
284 |
285 | title = result['title']
286 | generator_save_path = os.path.join(parent_folder, title + 'generator.pkl')
287 | torch.save(exclude_smpl(self.generator.state_dict()), generator_save_path)
288 | disc_save_path = os.path.join(parent_folder, title + 'discriminator.pkl')
289 | torch.save(exclude_smpl(self.discriminator.state_dict()), disc_save_path)
290 | with open(os.path.join(parent_folder, title + '.txt'), 'w') as fp:
291 | fp.write(str(result))
292 |
293 | #pre_best_loss = None
294 |
295 | torch.backends.cudnn.benchmark = True
296 | loader_2d, loader_3d, loader_mosh = iter(self.loader_2d), iter(self.loader_3d), iter(self.loader_mosh)
297 | e_opt, d_opt = self.e_opt, self.d_opt
298 |
299 | self.generator.train()
300 | self.discriminator.train()
301 |
302 | for iter_index in range(config.args.iter_count):
303 | try:
304 | data_2d = next(loader_2d)
305 | except StopIteration:
306 | loader_2d = iter(self.loader_2d)
307 | data_2d = next(loader_2d)
308 |
309 | try:
310 | data_3d = next(loader_3d)
311 | except StopIteration:
312 | loader_3d = iter(self.loader_3d)
313 | data_3d = next(loader_3d)
314 |
315 | try:
316 | data_mosh = next(loader_mosh)
317 | except StopIteration:
318 | loader_mosh = iter(self.loader_mosh)
319 | data_mosh = next(loader_mosh)
320 |
321 | image_from_2d, image_from_3d = data_2d['image'], data_3d['image']
322 | sample_2d_count, sample_3d_count, sample_mosh_count = image_from_2d.shape[0], image_from_3d.shape[0], data_mosh['theta'].shape[0]
323 | images = torch.cat((image_from_2d, image_from_3d), dim = 0).cuda()
324 |
325 | generator_outputs = self.generator(images)
326 |
327 | loss_kp_2d, loss_kp_3d, loss_shape, loss_pose, e_disc_loss, d_disc_loss, d_disc_real, d_disc_predict = self._calc_loss(generator_outputs, data_2d, data_3d, data_mosh)
328 |
329 | e_loss = loss_kp_2d + loss_kp_3d + loss_shape + loss_pose + e_disc_loss
330 | d_loss = d_disc_loss
331 |
332 | e_opt.zero_grad()
333 | e_loss.backward()
334 | e_opt.step()
335 |
336 | d_opt.zero_grad()
337 | d_loss.backward()
338 | d_opt.step()
339 |
340 | loss_kp_2d = float(loss_kp_2d)
341 | loss_shape = float(loss_shape / args.e_shape_ratio)
342 | loss_kp_3d = float(loss_kp_3d / args.e_3d_kp_ratio)
343 | loss_pose = float(loss_pose / args.e_pose_ratio)
344 | e_disc_loss = float(e_disc_loss / args.d_disc_ratio)
345 | d_disc_loss = float(d_disc_loss / args.d_disc_ratio)
346 |
347 | d_disc_real = float(d_disc_real / args.d_disc_ratio)
348 | d_disc_predict = float(d_disc_predict / args.d_disc_ratio)
349 |
350 | e_loss = loss_kp_2d + loss_kp_3d + loss_shape + loss_pose + e_disc_loss
351 | d_loss = d_disc_loss
352 |
353 | iter_msg = OrderedDict(
354 | [
355 | ('time',datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
356 | ('iter',iter_index),
357 | ('e_loss', e_loss),
358 | ('2d_loss',loss_kp_2d),
359 | ('3d_loss',loss_kp_3d),
360 | ('shape_loss',loss_shape),
361 | ('pose_loss', loss_pose),
362 | ('e_disc_loss',float(e_disc_loss)),
363 | ('d_disc_loss',float(d_disc_loss)),
364 | ('d_disc_real', float(d_disc_real)),
365 | ('d_disc_predict', float(d_disc_predict))
366 | ]
367 | )
368 |
369 | print(iter_msg)
370 |
371 | if iter_index % 500 == 0:
372 | iter_msg['title'] = '{}_{}_'.format(iter_msg['iter'], iter_msg['e_loss'])
373 | save_model(iter_msg)
374 |
375 | def _calc_loss(self, generator_outputs, data_2d, data_3d, data_mosh):
376 | def _accumulate_thetas(generator_outputs):
377 | thetas = []
378 | for (theta, verts, j2d, j3d, Rs) in generator_outputs:
379 | thetas.append(theta)
380 | return torch.cat(thetas, 0)
381 |
382 |         sample_2d_count, sample_3d_count, sample_mosh_count = data_2d['kp_2d'].shape[0], data_3d['kp_2d'].shape[0], data_mosh['theta'].shape[0]
383 | data_3d_theta, w_3d, w_smpl = data_3d['theta'].cuda(), data_3d['w_3d'].float().cuda(), data_3d['w_smpl'].float().cuda()
384 |
385 | total_predict_thetas = _accumulate_thetas(generator_outputs)
386 | (predict_theta, predict_verts, predict_j2d, predict_j3d, predict_Rs) = generator_outputs[-1]
387 |
388 | real_2d, real_3d = torch.cat((data_2d['kp_2d'], data_3d['kp_2d']), 0).cuda(), data_3d['kp_3d'].float().cuda()
389 |         predict_j2d, predict_j3d, predict_theta = predict_j2d, predict_j3d[sample_2d_count:, :], predict_theta[sample_2d_count:, :]   # the 2d loss uses the whole batch, while the 3d / smpl losses use only the 3d-annotated samples
390 |
391 | loss_kp_2d = self.batch_kp_2d_l1_loss(real_2d, predict_j2d[:,:14,:]) * args.e_loss_weight
392 | loss_kp_3d = self.batch_kp_3d_l2_loss(real_3d, predict_j3d[:,:14,:], w_3d) * args.e_3d_loss_weight * args.e_3d_kp_ratio
393 |
394 | real_shape, predict_shape = data_3d_theta[:, 75:], predict_theta[:, 75:]
395 | loss_shape = self.batch_shape_l2_loss(real_shape, predict_shape, w_smpl) * args.e_3d_loss_weight * args.e_shape_ratio
396 |
397 | real_pose, predict_pose = data_3d_theta[:, 3:75], predict_theta[:, 3:75]
398 | loss_pose = self.batch_pose_l2_loss(real_pose.contiguous(), predict_pose.contiguous(), w_smpl) * args.e_3d_loss_weight * args.e_pose_ratio
399 |
400 | e_disc_loss = self.batch_encoder_disc_l2_loss(self.discriminator(total_predict_thetas)) * args.d_loss_weight * args.d_disc_ratio
401 |
402 | mosh_real_thetas = data_mosh['theta'].cuda()
403 | fake_thetas = total_predict_thetas.detach()
404 | fake_disc_value, real_disc_value = self.discriminator(fake_thetas), self.discriminator(mosh_real_thetas)
405 | d_disc_real, d_disc_fake, d_disc_loss = self.batch_adv_disc_l2_loss(real_disc_value, fake_disc_value)
406 | d_disc_real, d_disc_fake, d_disc_loss = d_disc_real * args.d_loss_weight * args.d_disc_ratio, d_disc_fake * args.d_loss_weight * args.d_disc_ratio, d_disc_loss * args.d_loss_weight * args.d_disc_ratio
407 |
408 | return loss_kp_2d, loss_kp_3d, loss_shape, loss_pose, e_disc_loss, d_disc_loss, d_disc_real, d_disc_fake
409 |
410 | """
411 | purpose:
412 | calc L1 error
413 | Inputs:
414 | kp_gt : N x K x 3
415 | kp_pred: N x K x 2
416 | """
417 | def batch_kp_2d_l1_loss(self, real_2d_kp, predict_2d_kp):
418 | kp_gt = real_2d_kp.view(-1, 3)
419 | kp_pred = predict_2d_kp.contiguous().view(-1, 2)
420 | vis = kp_gt[:, 2]
421 | k = torch.sum(vis) * 2.0 + 1e-8
422 | dif_abs = torch.abs(kp_gt[:, :2] - kp_pred).sum(1)
423 | return torch.matmul(dif_abs, vis) * 1.0 / k
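    |     # i.e. a visibility-weighted mean absolute error:
    |     #   loss = sum_i vis_i * |gt_i - pred_i|_1 / (2 * sum_i vis_i)
    |     # where the factor 2 averages over the x and y coordinate of each visible keypoint.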
424 |
425 | '''
426 | purpose:
427 | calc mse * 0.5
428 |
429 | Inputs:
430 | real_3d_kp : N x k x 3
431 | fake_3d_kp : N x k x 3
432 | w_3d : N x 1
433 | '''
434 | def batch_kp_3d_l2_loss(self, real_3d_kp, fake_3d_kp, w_3d):
435 | shape = real_3d_kp.shape
436 | k = torch.sum(w_3d) * shape[1] * 3.0 * 2.0 + 1e-8
437 |
438 | #first align it
439 | real_3d_kp, fake_3d_kp = align_by_pelvis(real_3d_kp), align_by_pelvis(fake_3d_kp)
440 | kp_gt = real_3d_kp
441 | kp_pred = fake_3d_kp
442 | kp_dif = (kp_gt - kp_pred) ** 2
443 | return torch.matmul(kp_dif.sum(1).sum(1), w_3d) * 1.0 / k
444 |
445 | '''
446 | purpose:
447 | calc mse * 0.5
448 |
449 | Inputs:
450 | real_shape : N x 10
451 | fake_shape : N x 10
452 | w_shape : N x 1
453 | '''
454 | def batch_shape_l2_loss(self, real_shape, fake_shape, w_shape):
455 | k = torch.sum(w_shape) * 10.0 * 2.0 + 1e-8
456 | shape_dif = (real_shape - fake_shape) ** 2
457 | return torch.matmul(shape_dif.sum(1), w_shape) * 1.0 / k
458 |
459 | '''
460 | Input:
461 | real_pose : N x 72
462 | fake_pose : N x 72
463 | '''
464 | def batch_pose_l2_loss(self, real_pose, fake_pose, w_pose):
465 | k = torch.sum(w_pose) * 207.0 * 2.0 + 1e-8
466 | real_rs, fake_rs = batch_rodrigues(real_pose.view(-1, 3)).view(-1, 24, 9)[:,1:,:], batch_rodrigues(fake_pose.view(-1, 3)).view(-1, 24, 9)[:,1:,:]
467 | dif_rs = ((real_rs - fake_rs) ** 2).view(-1, 207)
468 | return torch.matmul(dif_rs.sum(1), w_pose) * 1.0 / k
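# Note: the normalizer uses 207 = 23 x 9 because the global rotation (joint 0)
# is dropped and each of the remaining 23 joints contributes a full 3x3
# rotation matrix to the squared difference.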
469 | '''
470 | Inputs:
471 | disc_value: N x 25
472 | '''
473 | def batch_encoder_disc_l2_loss(self, disc_value):
474 | k = disc_value.shape[0]
475 | return torch.sum((disc_value - 1.0) ** 2) * 1.0 / k
476 | '''
477 | Inputs:
478 | disc_value: N x 25
479 | '''
480 | def batch_adv_disc_l2_loss(self, real_disc_value, fake_disc_value):
481 | ka = real_disc_value.shape[0]
482 | kb = fake_disc_value.shape[0]
483 | lb, la = torch.sum(fake_disc_value ** 2) / kb, torch.sum((real_disc_value - 1) ** 2) / ka
484 | return la, lb, la + lb
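# Both discriminator losses follow the least-squares GAN objective used by HMR:
# the encoder loss above pushes discriminator scores of generated thetas towards 1,
# while the adversarial loss pushes real (mosh) thetas towards 1 and generated
# thetas towards 0.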
485 |
486 | def main():
487 | trainer = HMRTrainer()
488 | trainer.train()
489 |
490 | if __name__ == '__main__':
491 | main()
492 |
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | file: util.py
4 |
5 | date: 2018_04_29
6 | author: zhangxiong(1025679612@qq.com)
7 | '''
8 |
9 | import h5py
10 | import torch
11 | import numpy as np
12 | from config import args
13 | import json
14 | from torch.autograd import Variable
15 | import torch.nn.functional as F
16 | import cv2
17 | import math
18 | from scipy import interpolate
19 |
20 | def load_mean_theta():
21 | mean = np.zeros(args.total_theta_count, dtype = np.float)
22 |
23 | mean_values = h5py.File(args.smpl_mean_theta_path, 'r')
24 | mean_pose = mean_values['pose'][:]
25 | mean_pose[:3] = 0
26 | mean_shape = mean_values['shape'][:]
27 | mean_pose[0] = np.pi
28 |
29 | # init scale is 0.9
30 | mean[0] = 0.9
31 |
32 | mean[3:75] = mean_pose[:]
33 | mean[75:] = mean_shape[:]
34 |
35 | return mean
36 |
37 | def batch_rodrigues(theta):
38 | #theta N x 3
39 | batch_size = theta.shape[0]
40 | l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
41 | angle = torch.unsqueeze(l1norm, -1)
42 | normalized = torch.div(theta, angle)
43 | angle = angle * 0.5
44 | v_cos = torch.cos(angle)
45 | v_sin = torch.sin(angle)
46 | quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
47 |
48 | return quat2mat(quat)
49 |
50 | def quat2mat(quat):
51 | """Convert quaternion coefficients to rotation matrix.
52 | Args:
53 | quat: size = [B, 4] 4 <===>(w, x, y, z)
54 | Returns:
55 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
56 | """
57 | norm_quat = quat
58 | norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
59 | w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
60 |
61 | B = quat.size(0)
62 |
63 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
64 | wx, wy, wz = w*x, w*y, w*z
65 | xy, xz, yz = x*y, x*z, y*z
66 |
67 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
68 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
69 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
70 | return rotMat
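# A minimal sketch of the axis-angle -> rotation-matrix path above
# (hypothetical input):
#   aa = torch.zeros(2, 3)
#   aa[1, 2] = np.pi / 2          # 90 degrees about the z axis
#   R = batch_rodrigues(aa)       # shape (2, 3, 3)
# R[0] is (numerically) the identity and R[1] maps the x axis onto the y axis.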
71 |
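# batch_global_rigid_transformation (below) walks the SMPL kinematic tree given
# by `parent`: each joint's rotation is chained with its parent's world
# transform. It returns the posed joint locations new_J together with per-joint
# transforms A expressed relative to the rest-pose joint positions (results - init_bone).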
72 | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False):
73 | N = Rs.shape[0]
74 | if rotate_base:
75 | np_rot_x = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype = np.float)
76 | np_rot_x = np.reshape(np.tile(np_rot_x, [N, 1]), [N, 3, 3])
77 | rot_x = Variable(torch.from_numpy(np_rot_x).float()).cuda()
78 | root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x)
79 | else:
80 | root_rotation = Rs[:, 0, :, :]
81 | Js = torch.unsqueeze(Js, -1)
82 |
83 | def make_A(R, t):
84 | R_homo = F.pad(R, [0, 0, 0, 1, 0, 0])
85 | t_homo = torch.cat([t, Variable(torch.ones(N, 1, 1)).cuda()], dim = 1)
86 | return torch.cat([R_homo, t_homo], 2)
87 |
88 | A0 = make_A(root_rotation, Js[:, 0])
89 | results = [A0]
90 |
91 | for i in range(1, parent.shape[0]):
92 | j_here = Js[:, i] - Js[:, parent[i]]
93 | A_here = make_A(Rs[:, i], j_here)
94 | res_here = torch.matmul(results[parent[i]], A_here)
95 | results.append(res_here)
96 |
97 | results = torch.stack(results, dim = 1)
98 |
99 | new_J = results[:, :, :3, 3]
100 | Js_w0 = torch.cat([Js, Variable(torch.zeros(N, 24, 1, 1)).cuda()], dim = 2)
101 | init_bone = torch.matmul(results, Js_w0)
102 | init_bone = F.pad(init_bone, [3, 0, 0, 0, 0, 0, 0, 0])
103 | A = results - init_bone
104 |
105 | return new_J, A
106 |
107 |
108 | def batch_lrotmin(theta):
109 | theta = theta[:,3:].contiguous()
110 | Rs = batch_rodrigues(theta.view(-1, 3))
111 |
112 | e = Variable(torch.eye(3).float())
113 | Rs = Rs.sub(1.0, e)
114 |
115 | return Rs.view(-1, 23 * 9)
116 |
117 | def batch_orth_proj(X, camera):
118 | '''
119 | X is N x num_points x 3
120 | '''
121 | camera = camera.view(-1, 1, 3)
122 | X_trans = X[:, :, :2] + camera[:, :, 1:]
123 | shape = X_trans.shape
124 | return (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
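# Weak-perspective projection: camera is (s, tx, ty) per sample, so each 3d
# point X is mapped to s * (X[:2] + (tx, ty)); the z coordinate is simply dropped.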
125 |
126 | def calc_aabb(ptSets):
127 | if not ptSets or len(ptSets) == 0:
128 | return False, False, False
129 |
130 | ptLeftTop = np.array([ptSets[0][0], ptSets[0][1]])
131 | ptRightBottom = ptLeftTop.copy()
132 | for pt in ptSets:
133 | ptLeftTop[0] = min(ptLeftTop[0], pt[0])
134 | ptLeftTop[1] = min(ptLeftTop[1], pt[1])
135 | ptRightBottom[0] = max(ptRightBottom[0], pt[0])
136 | ptRightBottom[1] = max(ptRightBottom[1], pt[1])
137 |
138 | return ptLeftTop, ptRightBottom, len(ptSets) >= 5
139 |
140 | '''
141 | calculate an obb (oriented bounding box) for a set of points
142 |
143 | inputs:
144 | ptSets: a set of points
145 |
146 | returns the 4 corners of the obb
147 | '''
148 | def calc_obb(ptSets):
149 | ca = np.cov(ptSets,y = None,rowvar = 0,bias = 1)
150 | v, vect = np.linalg.eig(ca)
151 | tvect = np.transpose(vect)
152 | ar = np.dot(ptSets,np.linalg.inv(tvect))
153 | mina = np.min(ar,axis=0)
154 | maxa = np.max(ar,axis=0)
155 | diff = (maxa - mina)*0.5
156 | center = mina + diff
157 | corners = np.array([center+[-diff[0],-diff[1]],center+[diff[0],-diff[1]],center+[diff[0],diff[1]],center+[-diff[0],diff[1]]])
158 | corners = np.dot(corners, tvect)
159 | return corners[0], corners[1], corners[2], corners[3]
160 |
161 | def get_image_cut_box(leftTop, rightBottom, ExpandsRatio, Center = None):
162 | try:
163 | l = len(ExpandsRatio)
164 | except TypeError:
165 | ExpandsRatio = [ExpandsRatio, ExpandsRatio, ExpandsRatio, ExpandsRatio]
166 |
167 | def _expand_crop_box(lt, rb, scale):
168 | center = (lt + rb) / 2.0
169 | xl, xr, yt, yb = lt[0] - center[0], rb[0] - center[0], lt[1] - center[1], rb[1] - center[1]
170 | xl, xr, yt, yb = xl * scale[0], xr * scale[1], yt * scale[2], yb * scale[3]
171 | #expand it
172 | lt, rb = np.array([center[0] + xl, center[1] + yt]), np.array([center[0] + xr, center[1] + yb])
173 | lb, rt = np.array([center[0] + xl, center[1] + yb]), np.array([center[0] + xr, center[1] + yt])
174 | center = (lt + rb) / 2
175 | return center, lt, rt, rb, lb
176 |
177 | if Center is None:
178 | Center = (leftTop + rightBottom) // 2
179 |
180 | Center, leftTop, rightTop, rightBottom, leftBottom = _expand_crop_box(leftTop, rightBottom, ExpandsRatio)
181 |
182 | offset = (rightBottom - leftTop) // 2
183 |
184 | cx = offset[0]
185 | cy = offset[1]
186 |
187 | r = max(cx, cy)
188 |
189 | cx = r
190 | cy = r
191 |
192 | x = int(Center[0])
193 | y = int(Center[1])
194 |
195 | return [x - cx, y - cy], [x + cx, y + cy]
196 |
197 | def shrink(leftTop, rightBottom, width, height):
198 | xl = -leftTop[0]
199 | xr = rightBottom[0] - width
200 |
201 | yt = -leftTop[1]
202 | yb = rightBottom[1] - height
203 |
204 | cx = (leftTop[0] + rightBottom[0]) / 2
205 | cy = (leftTop[1] + rightBottom[1]) / 2
206 |
207 | r = (rightBottom[0] - leftTop[0]) / 2
208 |
209 | sx = max(xl, 0) + max(xr, 0)
210 | sy = max(yt, 0) + max(yb, 0)
211 |
212 | if (xl <= 0 and xr <= 0) or (yt <= 0 and yb <=0):
213 | return leftTop, rightBottom
214 | elif leftTop[0] >= 0 and leftTop[1] >= 0 : # left top corner is in box
215 | l = min(yb, xr)
216 | r = r - l / 2
217 | cx = cx - l / 2
218 | cy = cy - l / 2
219 | elif rightBottom[0] <= width and rightBottom[1] <= height : # right bottom corner is in box
220 | l = min(yt, xl)
221 | r = r - l / 2
222 | cx = cx + l / 2
223 | cy = cy + l / 2
224 | elif leftTop[0] >= 0 and rightBottom[1] <= height : #left bottom corner is in box
225 | l = min(xr, yt)
226 | r = r - l / 2
227 | cx = cx - l / 2
228 | cy = cy + l / 2
229 | elif rightBottom[0] <= width and leftTop[1] >= 0 : #right top corner is in box
230 | l = min(xl, yb)
231 | r = r - l / 2
232 | cx = cx + l / 2
233 | cy = cy - l / 2
234 | elif xl < 0 or xr < 0 or yb < 0 or yt < 0:
235 | return leftTop, rightBottom
236 | elif sx >= sy:
237 | sx = max(xl, 0) + max(0, xr)
238 | sy = max(yt, 0) + max(0, yb)
239 | # cy = height / 2
240 | if yt >= 0 and yb >= 0:
241 | cy = height / 2
242 | elif yt >= 0:
243 | cy = cy + sy / 2
244 | else:
245 | cy = cy - sy / 2
246 | r = r - sy / 2
247 |
248 | if xl >= sy / 2 and xr >= sy / 2:
249 | pass
250 | elif xl < sy / 2:
251 | cx = cx - (sy / 2 - xl)
252 | else:
253 | cx = cx + (sy / 2 - xr)
254 | elif sx < sy:
255 | cx = width / 2
256 | r = r - sx / 2
257 | if yt >= sx / 2 and yb >= sx / 2:
258 | pass
259 | elif yt < sx / 2:
260 | cy = cy - (sx / 2 - yt)
261 | else:
262 | cy = cy + (sx / 2 - yb)
263 |
264 |
265 | return [cx - r, cy - r], [cx + r, cy + r]
266 |
267 | '''
268 | offset the keypoints by leftTop
269 | '''
270 | def off_set_pts(keyPoints, leftTop):
271 | result = keyPoints.copy()
272 | result[:, 0] -= leftTop[0]
273 | result[:, 1] -= leftTop[1]
274 | return result
275 |
276 | '''
277 | cut out the person region of an image by expanding the given bounding box
278 | '''
279 | def cut_image(filePath, kps, expand_ratio, leftTop, rightBottom):
280 | originImage = cv2.imread(filePath)
281 | height = originImage.shape[0]
282 | width = originImage.shape[1]
283 | channels = originImage.shape[2] if len(originImage.shape) >= 3 else 1
284 |
285 | leftTop, rightBottom = get_image_cut_box(leftTop, rightBottom, expand_ratio)
286 |
287 | #remove extra space.
288 | #leftTop, rightBottom = shrink(leftTop, rightBottom, width, height)
289 |
290 | lt = [int(leftTop[0]), int(leftTop[1])]
291 | rb = [int(rightBottom[0]), int(rightBottom[1])]
292 |
293 | lt[0] = max(0, lt[0])
294 | lt[1] = max(0, lt[1])
295 | rb[0] = min(rb[0], width)
296 | rb[1] = min(rb[1], height)
297 |
298 | leftTop = [int(leftTop[0]), int(leftTop[1])]
299 | rightBottom = [int(rightBottom[0] + 0.5), int(rightBottom[1] + 0.5)]
300 |
301 | dstImage = np.zeros(shape = [rightBottom[1] - leftTop[1], rightBottom[0] - leftTop[0], channels], dtype = np.uint8)
302 | dstImage[:,:,:] = 0
303 |
304 | offset = [lt[0] - leftTop[0], lt[1] - leftTop[1]]
305 | size = [rb[0] - lt[0], rb[1] - lt[1]]
306 |
307 | dstImage[offset[1]:size[1] + offset[1], offset[0]:size[0] + offset[0], :] = originImage[lt[1]:rb[1], lt[0]:rb[0],:]
308 | return dstImage, off_set_pts(kps, leftTop)
309 |
310 | '''
311 | purpose:
312 | reflect the key points when the image is flipped left-right
313 |
314 | inputs:
315 | kps: 3d key points (14 x 3)
316 |
317 | marks:
318 | the key points are given in lsp order.
319 | '''
320 | def reflect_lsp_kp(kps):
321 | kp_map = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13]
322 | joint_ref = kps[kp_map]
323 | joint_ref[:,0] = -joint_ref[:,0]
324 |
325 | return joint_ref - np.mean(joint_ref, axis = 0)
326 |
327 | '''
328 | purpose:
329 | reflect the SMPL pose when the image is flipped left-right
330 |
331 | inputs:
332 | poses: 72 real numbers (24 axis-angle joint rotations)
333 | '''
334 | def reflect_pose(poses):
335 | swap_inds = np.array([
336 | 0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 14, 18,
337 | 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 34, 35, 30, 31, 32,
338 | 36, 37, 38, 42, 43, 44, 39, 40, 41, 45, 46, 47, 51, 52, 53, 48, 49,
339 | 50, 57, 58, 59, 54, 55, 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66,
340 | 67, 68
341 | ])
342 |
343 | sign_flip = np.array([
344 | 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1,
345 | -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1,
346 | -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
347 | 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1,
348 | -1, 1, -1, -1
349 | ])
350 |
351 | return poses[swap_inds] * sign_flip
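# swap_inds exchanges the axis-angle triplets of left/right SMPL joints and
# sign_flip keeps the x component while negating y and z of each triplet,
# which is how an axis-angle rotation transforms under a left-right mirror.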
352 |
353 | '''
354 | purpose:
355 | crop the image
356 | inputs:
357 | image_path: path of the image to crop
358 | '''
359 | def crop_image(image_path, angle, lt, rb, scale, kp_2d, crop_size):
360 | '''
361 | given a crop box, expand it in 4 directions (left, right, top, bottom)
362 | '''
363 | assert 0, 'crop_image contains a known bug and is intentionally disabled.'
364 |
365 | def _expand_crop_box(lt, rb, scale):
366 | center = (lt + rb) / 2.0
367 | xl, xr, yt, yb = lt[0] - center[0], rb[0] - center[0], lt[1] - center[1], rb[1] - center[1]
368 | xl, xr, yt, yb = xl * scale[0], xr * scale[1], yt * scale[2], yb * scale[3]
369 | #expand it
370 | lt, rb = np.array([center[0] + xl, center[1] + yt]), np.array([center[0] + xr, center[1] + yb])
371 | lb, rt = np.array([center[0] + xl, center[1] + yb]), np.array([center[0] + xr, center[1] + yt])
372 | center = (lt + rb) / 2
373 | return center, lt, rt, rb, lb
374 |
375 | '''
376 | extend the box to square
377 | '''
378 | def _extend_box(center, lt, rt, rb, lb, crop_size):
379 | lx, ly = np.linalg.norm(rt - lt), np.linalg.norm(lb - lt)
380 | dx, dy = (rt - lt) / lx, (lb - lt) / ly
381 | l = max(lx, ly) / 2.0
382 | return center - l * dx - l * dy, center + l * dx - l *dy, center + l * dx + l * dy, center - l * dx + l * dy, dx, dy, crop_size * 1.0 / l
383 |
384 | def _get_sample_points(lt, rt, rb, lb, crop_size):
385 | vec_x = rt - lt
386 | vec_y = lb - lt
387 | i_x, i_y = np.meshgrid(range(crop_size), range(crop_size))
388 | i_x = i_x.astype(np.float)
389 | i_y = i_y.astype(np.float)
390 | i_x /= float(crop_size)
391 | i_y /= float(crop_size)
392 | interp_points = i_x[..., np.newaxis].repeat(2, axis=2) * vec_x + i_y[..., np.newaxis].repeat(2, axis=2) * vec_y
393 | interp_points += lt
394 | return interp_points
395 |
396 | def _sample_image(src_image, interp_points):
397 | sample_method = 'nearest'
398 | interp_image = np.zeros((interp_points.shape[0] * interp_points.shape[1], src_image.shape[2]))
399 | i_x = range(src_image.shape[1])
400 | i_y = range(src_image.shape[0])
401 | flatten_interp_points = interp_points.reshape([interp_points.shape[0]*interp_points.shape[1], 2])
402 | for i_channel in range(src_image.shape[2]):
403 | interp_image[:, i_channel] = interpolate.interpn((i_y, i_x), src_image[:, :, i_channel],
404 | flatten_interp_points[:, [1, 0]], method = sample_method,
405 | bounds_error=False, fill_value=0)
406 | interp_image = interp_image.reshape((interp_points.shape[0], interp_points.shape[1], src_image.shape[2]))
407 |
408 | return interp_image
409 |
410 | def _trans_kp_2d(kps, center, dx, dy, lt, ratio):
411 | kp2d_offset = kps[:, :2] - center
412 | proj_x, proj_y = np.dot(kp2d_offset, dx), np.dot(kp2d_offset, dy)
413 | #kp2d = (dx * proj_x + dy * proj_y + lt) * ratio
414 | for idx in range(len(kps)):
415 | kps[idx, :2] = (dx * proj_x[idx] + dy * proj_y[idx] + lt) * ratio
416 | return kps
417 |
418 |
419 | src_image = cv2.imread(image_path)
420 |
421 | center, lt, rt, rb, lb = _expand_crop_box(lt, rb, scale)
422 |
423 | #calc rotated box
424 | radian = angle * np.pi / 180.0
425 | v_sin, v_cos = math.sin(radian), math.cos(radian)
426 |
427 | rot_matrix = np.array(
428 | [
429 | [v_cos, v_sin],
430 | [-v_sin, v_cos]
431 | ]
432 | )
433 |
434 | n_corner = (np.dot(rot_matrix, np.array([lt - center, rt - center, rb - center, lb - center]).T).T) + center
435 | n_lt, n_rt, n_rb, n_lb = n_corner[0], n_corner[1], n_corner[2], n_corner[3]
436 |
437 | lt, rt, rb, lb = calc_obb(np.array([lt, rt, rb, lb, n_lt, n_rt, n_rb, n_lb]))
438 | lt, rt, rb, lb, dx, dy, ratio = _extend_box(center, lt, rt, rb, lb, crop_size = crop_size)
439 | s_pts = _get_sample_points(lt, rt, rb, lb, crop_size)
440 | dst_image = _sample_image(src_image, s_pts)
441 | kp_2d = _trans_kp_2d(kp_2d, center, dx, dy, lt, ratio)
442 |
443 | return dst_image, kp_2d
444 |
445 |
446 | '''
447 | purpose:
448 | flip an image horizontally (left-right) together with its 2d key points
449 |
450 | marks:
451 | the key points are remapped so the lsp left/right joint order is preserved
452 |
453 | '''
454 | def flip_image(src_image, kps):
455 | h, w = src_image.shape[0], src_image.shape[1]
456 | src_image = cv2.flip(src_image, 1)
457 | if kps is not None:
458 | kps[:, 0] = w - 1 - kps[:, 0]
459 | kp_map = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13]
460 | kps[:, :] = kps[kp_map]
461 |
462 | return src_image, kps
463 |
464 | '''
465 | src_image: h x w x c
466 | pts: 14 x 3
467 | '''
468 | def draw_lsp_14kp__bone(src_image, pts):
469 | bones = [
470 | [0, 1, 255, 0, 0],
471 | [1, 2, 255, 0, 0],
472 | [2, 12, 255, 0, 0],
473 | [3, 12, 0, 0, 255],
474 | [3, 4, 0, 0, 255],
475 | [4, 5, 0, 0, 255],
476 | [12, 9, 0, 0, 255],
477 | [9,10, 0, 0, 255],
478 | [10,11, 0, 0, 255],
479 | [12, 8, 255, 0, 0],
480 | [8,7, 255, 0, 0],
481 | [7,6, 255, 0, 0],
482 | [12, 13, 0, 255, 0]
483 | ]
484 |
485 | for pt in pts:
486 | if pt[2] > 0.2:
487 | cv2.circle(src_image,(int(pt[0]), int(pt[1])),2,(0,255,255),-1)
488 |
489 | for line in bones:
490 | pa = pts[line[0]]
491 | pb = pts[line[1]]
492 | xa,ya,xb,yb = int(pa[0]),int(pa[1]),int(pb[0]),int(pb[1])
493 | if pa[2] > 0.2 and pb[2] > 0.2:
494 | cv2.line(src_image,(xa,ya),(xb,yb),(line[2], line[3], line[4]),2)
495 |
496 | '''
497 | return whether two 1d segments overlap
498 | '''
499 |
500 | def line_intersect(sa, sb):
501 | al, ar, bl, br = sa[0], sa[1], sb[0], sb[1]
502 | assert al <= ar and bl <= br
503 | if al >= br or bl >= ar:
504 | return False
505 | return True
506 |
507 | '''
508 | return whether two rectangles intersect
509 | ra, rb: (left_top point, right_bottom point)
510 | '''
511 | def rectangle_intersect(ra, rb):
512 | ax = [ra[0][0], ra[1][0]]
513 | ay = [ra[0][1], ra[1][1]]
514 |
515 | bx = [rb[0][0], rb[1][0]]
516 | by = [rb[0][1], rb[1][1]]
517 |
518 | return line_intersect(ax, bx) and line_intersect(ay, by)
519 |
520 | def get_intersected_rectangle(lt0, rb0, lt1, rb1):
521 | if not rectangle_intersect([lt0, rb0], [lt1, rb1]):
522 | return None, None
523 |
524 | lt = lt0.copy()
525 | rb = rb0.copy()
526 |
527 | lt[0] = max(lt[0], lt1[0])
528 | lt[1] = max(lt[1], lt1[1])
529 |
530 | rb[0] = min(rb[0], rb1[0])
531 | rb[1] = min(rb[1], rb1[1])
532 | return lt, rb
533 |
534 | def get_union_rectangle(lt0, rb0, lt1, rb1):
535 | lt = lt0.copy()
536 | rb = rb0.copy()
537 |
538 | lt[0] = min(lt[0], lt1[0])
539 | lt[1] = min(lt[1], lt1[1])
540 |
541 | rb[0] = max(rb[0], rb1[0])
542 | rb[1] = max(rb[1], rb1[1])
543 | return lt, rb
544 |
545 | def get_rectangle_area(lt, rb):
546 | return (rb[0] - lt[0]) * (rb[1] - lt[1])
547 |
548 | def get_rectangle_intersect_ratio(lt0, rb0, lt1, rb1):
549 | (lt0, rb0), (lt1, rb1) = get_intersected_rectangle(lt0, rb0, lt1, rb1), get_union_rectangle(lt0, rb0, lt1, rb1)
550 |
551 | if lt0 is None:
552 | return 0.0
553 | else:
554 | return 1.0 * get_rectangle_area(lt0, rb0) / get_rectangle_area(lt1, rb1)
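# i.e. the intersection-over-union (IoU) of the two rectangles.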
555 |
556 | def convert_image_by_pixformat_normalize(src_image, pix_format, normalize):
557 | if pix_format == 'NCHW':
558 | src_image = src_image.transpose((2, 0, 1))
559 |
560 | if normalize:
561 | src_image = (src_image.astype(np.float) / 255) * 2.0 - 1.0
562 |
563 | return src_image
564 |
565 | '''
566 | align by pelvis
567 | joints: n x 14 x 3, in lsp order
568 | '''
569 | def align_by_pelvis(joints):
570 | left_id = 3
571 | right_id = 2
572 | pelvis = (joints[:, left_id, :] + joints[:, right_id, :]) / 2.0
573 | return joints - torch.unsqueeze(pelvis, dim=1)
574 |
575 | def copy_state_dict(cur_state_dict, pre_state_dict, prefix = ''):
576 | def _get_params(key):
577 | key = prefix + key
578 | if key in pre_state_dict:
579 | return pre_state_dict[key]
580 | return None
581 |
582 | for k in cur_state_dict.keys():
583 | v = _get_params(k)
584 | try:
585 | if v is None:
586 | print('parameter {} not found'.format(k))
587 | continue
588 | cur_state_dict[k].copy_(v)
589 | except:
590 | print('copy param {} failed'.format(k))
591 | continue
592 |
593 | if __name__ == '__main__':
594 | image_path = 'E:/HMR/data/COCO/images/train-valid2017/000000000009.jpg'
595 | lt = np.array([-10, -10], dtype = np.float)
596 | rb = np.array([10,10], dtype = np.float)
597 | print(crop_image(image_path, 45, lt, rb, [1, 1, 1, 1], None)) # NOTE: crop_image is disabled by its leading assert, and this call omits the crop_size argument
598 |
--------------------------------------------------------------------------------