├── coma
│   ├── __init__.py
│   └── mesh_sampling.py
├── data
│   ├── readme.txt
│   └── mesh_down_sampling_4.npz
├── model
│   ├── __init__.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── linear_model.py
│   │   └── graph_layers.py
│   ├── model_v1.py
│   ├── graph_hg.py
│   ├── mesh_graph_hg.py
│   └── smal_mesh_net_img.py
├── smal
│   ├── __init__.py
│   ├── smal_basics.py
│   ├── mesh.py
│   ├── batch_lbs.py
│   └── smal_torch.py
├── util
│   ├── __init__.py
│   ├── utils.py
│   ├── misc.py
│   ├── meter.py
│   ├── pose_prior.py
│   ├── geom_utils.py
│   ├── metrics.py
│   ├── helpers
│   │   ├── draw_smal_joints.py
│   │   ├── conversions.py
│   │   └── visualize.py
│   ├── config.py
│   ├── loss_utils.py
│   ├── loss_sdf.py
│   ├── logger.py
│   ├── joint_limits_prior.py
│   ├── nmr.py
│   └── net_blocks.py
├── datasets
│   ├── __init__.py
│   ├── imutils.py
│   └── stanford.py
├── network.png
├── result_examples.png
├── LICENSE
├── requirements.txt
├── README.md
├── eval.py
├── main.py
└── main_meshgraph.py
/coma/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/smal/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/network.png
--------------------------------------------------------------------------------
/result_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/result_examples.png
--------------------------------------------------------------------------------
/data/mesh_down_sampling_4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/data/mesh_down_sampling_4.npz
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 |
4 | def print_options(args):
5 | message = ''
6 | message += '----------------- Options ---------------\n'
7 | for k, v in sorted(vars(args).items()):
8 | comment = ''
9 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
10 | message += '----------------- End -------------------'
11 | print(message)
12 |
13 | # save to the disk
14 | file_name = osp.join(args.output_dir, 'opt.txt')
15 | with open(file_name, 'wt') as args_file:
16 | args_file.write(message)
17 | args_file.write('\n')
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Li Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import shutil
3 | import os
4 |
5 |
6 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
7 |
8 | filepath = os.path.join(checkpoint, filename)
9 | torch.save(state, filepath)
10 | if is_best:
11 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
12 |
13 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
14 | """Sets the learning rate to the initial LR decayed by schedule"""
15 | if epoch in schedule:
16 | lr *= gamma
17 | for param_group in optimizer.param_groups:
18 | param_group['lr'] = lr
19 | return lr
20 |
21 |
22 | def lr_poly(base_lr, epoch, max_epoch, power):
23 | """ Poly_LR scheduler
24 | """
25 | return base_lr * ((1 - float(epoch) / max_epoch) ** power)
26 |
27 |
28 | def adjust_learning_rate_main(optimizer, epoch, args):
29 | lr = lr_poly(args.lr, epoch, args.max_epoch, args.power)
30 | for param_group in optimizer.param_groups:
31 | param_group['lr'] = lr
32 | return lr
33 |
34 |
35 | def adjust_learning_rate_exponential(optimizer, epoch, epoch_decay, learning_rate, decay_rate):
36 | lr = learning_rate * (decay_rate ** (epoch / epoch_decay))
37 | for param_group in optimizer.param_groups:
38 | param_group['lr'] = lr
39 | return lr
--------------------------------------------------------------------------------
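A minimal sketch of the polynomial schedule implemented by `lr_poly` above, with hypothetical hyperparameter values:
```
from util.misc import lr_poly

# Polynomial decay: lr shrinks smoothly from base_lr towards 0 as epoch approaches max_epoch.
base_lr, max_epoch, power = 1e-4, 200, 0.9
for epoch in (0, 100, 199):
    print(epoch, lr_poly(base_lr, epoch, max_epoch, power))
```
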
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.14.1
2 | cached-property==1.5.2
3 | cachetools==4.2.4
4 | certifi==2020.12.5
5 | chamferdist==0.3.0
6 | chardet==4.0.0
7 | chumpy==0.70
8 | cycler==0.10.0
9 | Cython==0.29.23
10 | decorator==4.4.2
11 | freetype-py==2.2.0
12 | future==0.18.2
13 | google-auth==1.35.0
14 | google-auth-oauthlib==0.4.6
15 | grpcio==1.41.0
16 | h5py==3.2.1
17 | idna==2.10
18 | imageio==2.9.0
19 | importlib-metadata==4.8.1
20 | jsonpatch==1.32
21 | jsonpointer==2.1
22 | kiwisolver==1.3.1
23 | Markdown==3.3.4
24 | matplotlib==3.3.4
25 | networkx==2.5.1
26 | neural-renderer-pytorch==1.1.3
27 | nibabel==3.2.1
28 | numpy==1.20.1
29 | oauthlib==3.1.1
30 | opencv-python==4.5.1.48
31 | packaging==21.0
32 | Pillow==8.1.2
33 | pointnet2-ops==3.0.0
34 | protobuf==3.18.1
35 | pyasn1==0.4.8
36 | pyasn1-modules==0.2.8
37 | pycocotools==2.0.2
38 | pyglet==1.5.18
39 | PyOpenGL==3.1.0
40 | pyparsing==2.4.7
41 | PyQt5==5.9
42 | PyQt5-Qt5==5.15.2
43 | PyQt5-sip==12.9.0
44 | pyrender==0.1.45
45 | python-dateutil==2.8.1
46 | PyWavelets==1.1.1
47 | pyzmq==22.0.3
48 | requests==2.25.1
49 | requests-oauthlib==1.3.0
50 | rsa==4.7.2
51 | scikit-image==0.18.1
52 | scipy==1.2.0
53 | sip==4.19.8
54 | six==1.15.0
55 | tensorboard==2.6.0
56 | tensorboard-data-server==0.6.1
57 | tensorboard-plugin-wit==1.8.0
58 | tifffile==2021.4.8
59 | torch==1.5.0+cu101
60 | torchfile==0.1.0
61 | torchvision==0.6.0+cu101
62 | tornado==6.1
63 | tqdm==4.60.0
64 | transforms3d==0.3.1
65 | trimesh==3.9.26
66 | typing-extensions==3.10.0.2
67 | urllib3==1.26.3
68 | visdom==0.1.8.9
69 | websocket-client==0.58.0
70 | Werkzeug==2.0.2
71 | zipp==3.6.0
72 |
--------------------------------------------------------------------------------
/util/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter:
2 | """Computes and stores the average and current value"""
3 |
4 | def __init__(self):
5 | self.reset()
6 |
7 | def reset(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def update(self, val, n=1):
14 | self.val = val
15 | self.sum += val * n
16 | self.count += n
17 | self.avg = self.sum / self.count
18 |
19 | def __format__(self, format):
20 | return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)
21 |
22 |
23 | class AverageMeterSet:
24 | def __init__(self):
25 | self.meters = {}
26 |
27 | def __getitem__(self, key):
28 | return self.meters[key]
29 |
30 | def update(self, name, value, n=1):
31 | if not name in self.meters:
32 | self.meters[name] = AverageMeter()
33 | self.meters[name].update(value, n)
34 |
35 | def reset(self):
36 | for meter in self.meters.values():
37 | meter.reset()
38 |
39 | def values(self, postfix=''):
40 | return {name + postfix: meter.val for name, meter in self.meters.items()}
41 |
42 | def averages(self, postfix='/avg'):
43 | return {name + postfix: meter.avg for name, meter in self.meters.items()}
44 |
45 | def sums(self, postfix='/sum'):
46 | return {name + postfix: meter.sum for name, meter in self.meters.items()}
47 |
48 | def counts(self, postfix='/count'):
49 | return {name + postfix: meter.count for name, meter in self.meters.items()}
--------------------------------------------------------------------------------
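A short, hypothetical usage sketch for the meters above (assumes the repository root is on `PYTHONPATH`):
```
from util.meter import AverageMeterSet

meters = AverageMeterSet()
for step in range(3):
    # record a per-batch loss value, weighted by the batch size
    meters.update('loss', 1.0 / (step + 1), n=32)
print(meters['loss'].avg)      # running average over all updates
print(meters.averages())       # {'loss/avg': ...}
```
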
/util/pose_prior.py:
--------------------------------------------------------------------------------
1 | import pickle as pkl
2 | import numpy as np
3 | from chumpy import Ch
4 | import cv2
5 | import torch
6 |
7 | name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1, 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15, 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11, 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27, 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar':33, 'REar':34}
8 | id2name35 = {v: k for k, v in name2id35.items()}
9 |
10 |
11 | class Prior(object):
12 | def __init__(self, prior_path, device):
13 | with open(prior_path, 'rb') as f:
14 | res = pkl.load(f, encoding='latin1')
15 |
16 | self.mean_ch = res['mean_pose']
17 | self.precs_ch = res['pic']
18 |
19 | self.precs = torch.from_numpy(res['pic'].r.copy()).float().to(device)
20 | self.mean = torch.from_numpy(res['mean_pose']).float().to(device)
21 |
22 | prefix = 3
23 | pose_len = 105
24 | id2name = id2name35
25 | name2id = name2id35
26 |
27 | self.use_ind = np.ones(pose_len, dtype=bool)
28 | self.use_ind[:prefix] = False
29 | self.use_ind_tch = torch.from_numpy(self.use_ind).float().to(device)
30 |
31 | def __call__(self, x):
32 | mean_sub = x.reshape(-1, 35*3) - self.mean.unsqueeze(0)
33 | res = torch.tensordot(mean_sub, self.precs, dims=([1], [0])) * self.use_ind_tch
34 | return (res**2).mean()
--------------------------------------------------------------------------------
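A hedged usage sketch for the pose prior above; it assumes the prior file referenced by `config.WALKING_PRIOR_FILE` has been downloaded as described in the README, and the zero pose batch is only there to show the expected input shape:
```
import torch
from util import config
from util.pose_prior import Prior

device = torch.device('cpu')
prior = Prior(config.WALKING_PRIOR_FILE, device)   # loads the mean pose and precision matrix
pose_pred = torch.zeros(2, 35 * 3)                 # batch of axis-angle poses (35 joints x 3)
loss_pose_prior = prior(pose_pred)                 # scalar penalty on deviation from the mean pose
```
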
/smal/smal_basics.py:
--------------------------------------------------------------------------------
1 | import pickle as pkl
2 | import numpy as np
3 |
4 |
5 | def align_smal_template_to_symmetry_axis(v, sym_file):
6 | # These are the indexes of the points that are on the symmetry axis
7 | I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003]
8 |
9 | v = v - np.mean(v)
10 | y = np.mean(v[I,1])
11 | v[:,1] = v[:,1] - y
12 | v[I,1] = 0
13 |
14 | # symIdx = pkl.load(open(sym_path))
15 | with open(sym_file, 'rb') as f:
16 | u = pkl._Unpickler(f)
17 | u.encoding = 'latin1'
18 | symIdx = u.load()
19 |
20 |
21 | left = v[:, 1] < 0
22 | right = v[:, 1] > 0
23 | center = v[:, 1] == 0
24 | v[left[symIdx]] = np.array([1,-1,1])*v[left]
25 |
26 | left_inds = np.where(left)[0]
27 | right_inds = np.where(right)[0]
28 | center_inds = np.where(center)[0]
29 |
30 | try:
31 | assert(len(left_inds) == len(right_inds))
32 | except:
33 | import pdb; pdb.set_trace()
34 |
35 | return v, left_inds, right_inds, center_inds
--------------------------------------------------------------------------------
/util/geom_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils related to geometry, such as projection.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 |
10 | def sample_textures(texture_flow, images):
11 | """
12 | texture_flow: B x F x T x T x 2
13 | (In normalized coordinate [-1, 1])
14 | images: B x 3 x N x N
15 |
16 | output: B x F x T x T x 3
17 | """
18 | # Reshape into B x F x T*T x 2
19 | T = texture_flow.size(-2)
20 | F = texture_flow.size(1)
21 | flow_grid = texture_flow.view(-1, F, T * T, 2)
22 | # B x 3 x F x T*T
23 | samples = torch.nn.functional.grid_sample(images, flow_grid)
24 | # B x 3 x F x T x T
25 | samples = samples.view(-1, 3, F, T, T)
26 | # B x F x T x T x 3
27 | return samples.permute(0, 2, 3, 4, 1)
28 |
29 |
30 | def perspective_proj_withz(X, cam, offset_z=0, cuda_device=0,norm_f=1., norm_z=0.,norm_f0=0.):
31 | """
32 | X: B x N x 3
33 | cam: B x 3: [f, cx, cy]
34 |     offset_z is kept for backward compatibility; it is unused and should be removed
35 | """
36 |
37 | # B x 1 x 1
38 | #f = norm_f * cam[:, 0].contiguous().view(-1, 1, 1)
39 | f = norm_f0+norm_f * cam[:, 0].contiguous().view(-1, 1, 1)
40 | # f = norm_f0
41 | # B x N x 1
42 | z = norm_z + X[:, :, 2, None]
43 |
44 | # Will z ever be 0? We probably should max it..
45 | eps = 1e-6 * torch.ones(1).cuda(device=cuda_device)
46 | z = torch.max(z, eps)
47 | image_size_half = cam[0,1]
48 | scale = f / (z*image_size_half)
49 |
50 | # Offset is because cam is at -1
51 | return torch.cat((scale * X[:, :, :2], z+offset_z),2)
52 |
53 |
--------------------------------------------------------------------------------
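A small sketch of calling `perspective_proj_withz` with dummy inputs; it assumes a CUDA device is available (the epsilon tensor is allocated on the GPU) and reuses the normalization constants from `util/config.py`:
```
import torch
from util.geom_utils import perspective_proj_withz

X = torch.rand(2, 100, 3).cuda()                               # B x N x 3 vertices
cam = torch.tensor([[1.0, 112.0, 112.0]]).repeat(2, 1).cuda()  # [f, cx, cy] per image
proj = perspective_proj_withz(X, cam, norm_f=2700.0, norm_z=20.0, norm_f0=2700.0)
print(proj.shape)                                              # B x N x 3: scaled x, y plus depth z
```
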
/model/networks/linear_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import torch.nn as nn
5 |
6 |
7 | def weight_init(m):
8 | if isinstance(m, nn.Linear):
9 |         nn.init.kaiming_normal_(m.weight)
10 |
11 |
12 | class Linear(nn.Module):
13 | def __init__(self, linear_size, p_dropout=0.5):
14 | super(Linear, self).__init__()
15 | self.l_size = linear_size
16 |
17 | self.relu = nn.ReLU(inplace=True)
18 | self.dropout = nn.Dropout(p_dropout)
19 |
20 | self.w1 = nn.Linear(self.l_size, self.l_size)
21 | self.batch_norm1 = nn.BatchNorm1d(self.l_size)
22 |
23 | self.w2 = nn.Linear(self.l_size, self.l_size)
24 | self.batch_norm2 = nn.BatchNorm1d(self.l_size)
25 |
26 | def forward(self, x):
27 | y = self.w1(x)
28 | y = self.batch_norm1(y)
29 | y = self.relu(y)
30 | y = self.dropout(y)
31 |
32 | y = self.w2(y)
33 | y = self.batch_norm2(y)
34 | y = self.relu(y)
35 | y = self.dropout(y)
36 |
37 | out = x + y
38 |
39 | return out
40 |
41 |
42 | class LinearModel(nn.Module):
43 | def __init__(self, input_size, output_size, linear_size=1024, num_stage=2, p_dropout=0.5):
44 | super(LinearModel, self).__init__()
45 | self.linear_size = linear_size
46 | self.p_dropout = p_dropout
47 | self.num_stage = num_stage
48 |
49 | self.input_size = input_size
50 | self.output_size = output_size
51 |
52 | self.w1 = nn.Linear(self.input_size, self.linear_size)
53 | self.batch_norm1 = nn.BatchNorm1d(self.linear_size)
54 |
55 | self.linear_stages = []
56 | for _ in range(num_stage):
57 | self.linear_stages.append(Linear(self.linear_size, self.p_dropout))
58 | self.linear_stages = nn.ModuleList(self.linear_stages)
59 |
60 | self.w2 = nn.Linear(self.linear_size, self.output_size)
61 |
62 | self.relu = nn.ReLU(inplace=True)
63 | self.dropout = nn.Dropout(self.p_dropout)
64 |
65 | def forward(self, x):
66 | y = self.w1(x)
67 | y = self.batch_norm1(y)
68 | y = self.relu(y)
69 | y = self.dropout(y)
70 |
71 | for i in range(self.num_stage):
72 | y = self.linear_stages[i](y)
73 |
74 | y = self.w2(y)
75 |
76 | return y
77 |
78 |
--------------------------------------------------------------------------------
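A quick shape check for the residual linear network above (sizes are dummy values chosen only for illustration):
```
import torch
from model.networks.linear_model import LinearModel

model = LinearModel(input_size=34, output_size=51, linear_size=1024, num_stage=2)
x = torch.randn(8, 34)          # batch of flattened inputs
y = model(x)
print(y.shape)                  # torch.Size([8, 51])
```
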
/model/model_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from util.nmr import NeuralRenderer
5 | from smal.smal_torch import SMAL
6 | from model.smal_mesh_net_img import MeshNet_img
7 | from util import config
8 |
9 | unity_shape_prior = np.load('data/priors/unity_betas.npz')
10 |
11 |
12 | class MeshModel(nn.Module):
13 | def __init__(self, device, shape_family_id, betas_scale=False, shape_init=None, render_rgb=False):
14 | '''
15 | Args:
16 | device: specify device for training
17 | shape_family_id: specify animal category id
18 |             betas_scale: whether to predict the additional shape parameter proposed by WLDO
19 |             shape_init: whether to initialize the bias with a mean shape; choose from smal or unity
20 |             render_rgb: whether to render the 3D mesh into a 2D RGB image. Only set to True when
21 |                 generating visualizations, since rendering adds inference time.
22 | '''
23 |
24 | super(MeshModel, self).__init__()
25 |
26 | self.model_renderer = NeuralRenderer(config.IMG_RES, proj_type=config.PROJECTION,
27 | norm_f0=config.NORM_F0,
28 | norm_f=config.NORM_F,
29 | norm_z=config.NORM_Z,
30 | render_rgb=render_rgb,
31 | device=device)
32 | self.model_renderer.directional_light_only()
33 | self.smal = SMAL(device, shape_family_id=shape_family_id)
34 | if shape_init == 'smal':
35 | print("Initiate shape with smal prior")
36 | shape_init = self.smal.shape_cluster_means
37 | elif shape_init == 'unity':
38 | print("Initiate shape with unity prior ")
39 | shape_init = unity_shape_prior['mean'][:-1]
40 | shape_init = torch.from_numpy(shape_init).float().to(device)
41 | else:
42 | print("No initialization for shape")
43 | shape_init = None
44 | input_size = [config.IMG_RES, config.IMG_RES]
45 | self.meshnet = MeshNet_img(input_size, betas_scale=betas_scale, norm_f0=config.NORM_F0, nz_feat=config.NZ_FEAT,
46 | shape_init=shape_init)
47 | print('INITIALIZED')
48 |
49 | def forward(self, inp):
50 |
51 | pred_codes = self.meshnet(inp)
52 |
53 | return pred_codes
--------------------------------------------------------------------------------
/util/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util import config
3 | import numpy as np
4 |
5 |
6 | class Metrics():
7 |
8 | @staticmethod
9 | def PCK_thresh(
10 | pred_keypoints, gt_keypoints,
11 | gtseg, has_seg,
12 | thresh, idxs):
13 |
14 | pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg]
15 |
16 | if idxs is None:
17 | idxs = list(range(pred_keypoints.shape[1]))
18 |
19 | idxs = np.array(idxs).astype(int)
20 |
21 | pred_keypoints = pred_keypoints[:, idxs]
22 | gt_keypoints = gt_keypoints[:, idxs]
23 |
24 | keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * config.IMG_RES
25 | dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim=-1)
26 | seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim=-1).unsqueeze(-1)
27 |
28 | hits = (dist / torch.sqrt(seg_area)) < thresh
29 | total_visible = torch.sum(gt_keypoints[:, :, -1], dim=-1)
30 | pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim=-1) / total_visible
31 |
32 | return pck
33 |
34 | @staticmethod
35 | def PCK(
36 | pred_keypoints, keypoints,
37 | gtseg, has_seg,
38 | thresh_range=[0.15],
39 | idxs: list = None):
40 |
41 | """Calc PCK with same method as in eval.
42 | idxs = optional list of subset of keypoints to index from
43 | """
44 |
45 | cumulative_pck = []
46 | for thresh in thresh_range:
47 | pck = Metrics.PCK_thresh(
48 | pred_keypoints, keypoints,
49 | gtseg, has_seg, thresh, idxs)
50 | cumulative_pck.append(pck)
51 |
52 | pck_mean = torch.stack(cumulative_pck, dim=0).mean(dim=0)
53 | return pck_mean
54 |
55 | @staticmethod
56 | def IOU(synth_silhouettes, gt_seg, img_border_mask, mask):
57 | for i in range(mask.shape[0]):
58 | synth_silhouettes[i] *= mask[i]
59 |
60 | # Do not penalize parts of the segmentation outside the img range
61 | gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask)
62 |
63 | intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim=-1)
64 | union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim=-1)
65 | acc_IOU_SCORE = intersection / union
66 |
67 | return acc_IOU_SCORE
--------------------------------------------------------------------------------
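A hedged sketch of calling `Metrics.PCK` with dummy tensors, only to document the expected shapes: predicted keypoints in pixels, ground truth in normalized [-1, 1] coordinates with a visibility flag in the last channel, and one silhouette per image:
```
import torch
from util import config
from util.metrics import Metrics

B, J, R = 4, config.NUM_JOINTS, config.IMG_RES
pred_kp = torch.rand(B, J, 2) * R                          # predicted keypoints in pixels
gt_kp = torch.cat([torch.rand(B, J, 2) * 2 - 1,            # normalized ground-truth coordinates
                   torch.ones(B, J, 1)], dim=-1)           # visibility flags
gtseg = torch.ones(B, R, R)                                # dummy silhouettes
has_seg = torch.ones(B, dtype=torch.bool)
pck = Metrics.PCK(pred_kp, gt_kp, gtseg, has_seg)          # one PCK value per image
```
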
/util/helpers/draw_smal_joints.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | from torchvision.utils import make_grid
5 | from matplotlib import cm
6 | import matplotlib.pyplot as plt
7 | from matplotlib.colors import Normalize
8 | import torch.nn.functional as F
9 | import torch
10 |
11 |
12 | # draw keypoints in the input image, code adapted from WLDO
13 | class SMALJointDrawer():
14 | def __init__(self):
15 | self.jet_colormap = cm.ScalarMappable(norm = Normalize(0, 1), cmap = 'jet')
16 |
17 | def draw_joints(self, image, landmarks, visible = None, normalized = True, marker_size = 8, thickness = 3):
18 | image_np = np.transpose(image.cpu().data.numpy(), (0, 2, 3, 1))
19 | landmarks_np = landmarks.cpu().data.numpy()
20 | if visible is not None:
21 | visible_np = visible.cpu().data.numpy()
22 | else:
23 | visible_np = visible
24 |
25 | return_stack = self.draw_joints_np(image_np, landmarks_np, visible_np, normalized, marker_size=marker_size, thickness=thickness)
26 | return torch.FloatTensor(np.transpose(return_stack, (0, 3, 1, 2)))
27 |
28 | def draw_joints_np(self, image_np, landmarks_np, visible_np = None, normalized = False, marker_size = 8, thickness = 3):
29 | if normalized:
30 | image_np = (image_np * 0.5) + 0.5
31 |
32 | image_np = (image_np * 255.0).astype(np.uint8)
33 |
34 | bs, nj, _ = landmarks_np.shape
35 | if visible_np is None:
36 | visible_np = np.ones((bs, nj), dtype=bool)
37 |
38 | return_images = []
39 | for image_sgl, landmarks_sgl, visible_sgl in zip(image_np, landmarks_np, visible_np):
40 | image_sgl = image_sgl.copy()
41 | for joint_id, ((y_co, x_co), vis) in enumerate(zip(landmarks_sgl, visible_sgl)):
42 |                 if x_co < 0 or y_co < 0 or x_co > 223 or y_co > 223:  # skip keypoints outside the 224x224 image
43 | continue
44 | else:
45 | color = np.array([255, 0, 0])
46 | marker_type = 0
47 | if not vis:
48 | continue
49 | cv2.drawMarker(image_sgl, (int(x_co), int(y_co)), (int(color[0]), int(color[1]), int(color[2])), marker_type, marker_size, thickness = thickness)
50 |
51 | return_images.append(image_sgl)
52 |
53 | return_stack = np.stack(return_images, 0)
54 | return_stack = return_stack / 255.0
55 | if normalized:
56 | return_stack = (return_stack - 0.5) / 0.5 # Scale and re-normalize for pytorch
57 |
58 | return return_stack
59 |
60 | def draw_heatmap_grids(self, heatmaps, silhouettes, upsample_val = 1, alpha_blend = 0.9):
61 | bs, jts, h, w = heatmaps.shape
62 | heatmap_grids = []
63 | for heatmap, silhouette in zip(heatmaps, silhouettes):
64 | heatmap_rgb = heatmap[:, None, :, :].expand(jts, 3, h, w) * alpha_blend + silhouette * (1 - alpha_blend)
65 | grid = make_grid(heatmap_rgb, nrow = 5)
66 | heatmap_jet = self.jet_colormap.to_rgba(grid[0].cpu().numpy())[:, :, :3]
67 | heatmap_jet = np.transpose(heatmap_jet, (2, 0, 1))
68 | heatmap_grids.append(torch.FloatTensor(heatmap_jet))
69 |
70 | heatmap_stack = torch.stack(heatmap_grids, dim = 0)
71 | heatmap_stack = (heatmap_stack - 0.5) / 0.5
72 |
73 | return F.interpolate(heatmap_stack, [h * upsample_val, w * upsample_val])
74 |
75 |
--------------------------------------------------------------------------------
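A minimal, hypothetical call to the drawer above with random tensors; landmark coordinates are interpreted as (y, x) pixel positions inside a 224 x 224 image:
```
import torch
from util.helpers.draw_smal_joints import SMALJointDrawer

drawer = SMALJointDrawer()
images = torch.rand(2, 3, 224, 224) * 2 - 1        # images normalized to [-1, 1]
landmarks = torch.rand(2, 20, 2) * 224             # (y, x) keypoint locations in pixels
visible = torch.ones(2, 20)
drawn = drawer.draw_joints(images, landmarks, visible=visible, normalized=True)
print(drawn.shape)                                 # 2 x 3 x 224 x 224 with markers drawn in
```
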
/util/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains definitions of useful data structures and the paths
3 | for the datasets and data files necessary to run the code.
4 | Things you need to change: *_ROOT that indicate the path to each dataset
5 | """
6 | from os.path import join
7 | import os
8 |
9 | # Define paths to each dataset
10 |
11 | CODE_DIR = os.getcwd()
12 | BASE_FOLDER = join(CODE_DIR, 'data')
13 | # Output folder to save test/train npz files
14 | DATASET_NPZ_PATH = join(BASE_FOLDER, 'splits')
15 |
16 | # Path to test/train npy files
17 | DATASET_FILES = [
18 | {
19 | 'stanford': join(BASE_FOLDER, 'StanfordExtra_v12', 'test_stanford_StanfordExtra_v12.npy'),
20 | 'animal_pose' : join(BASE_FOLDER, 'animal_pose', 'test_animal_pose.npy')
21 | },
22 | {
23 | 'stanford': join(BASE_FOLDER, 'StanfordExtra_v12', 'train_stanford_StanfordExtra_v12.npy'),
24 | 'animal_pose': join(BASE_FOLDER, 'animal_pose', 'train_animal_pose.npy')
25 | }
26 | ]
27 |
28 |
29 | DATASET_FOLDERS = {
30 | 'stanford' : join(BASE_FOLDER, 'StanfordExtra_v12'),
31 | 'animal_pose' : join(BASE_FOLDER, 'animal_pose')
32 | }
33 |
34 | JSON_NAME = {
35 | 'stanford': 'StanfordExtra_v12.json', # the latest version of the StanfordExtra dataset
36 | 'animal_pose': 'animal_pose_data.json'
37 | }
38 |
39 | BREEDS_CSV = join(BASE_FOLDER, 'breeds.csv')
40 |
41 | EM_DATASET_NAME = "stanford" # the dataset to learn the EM prior on
42 |
43 | data_path = join(CODE_DIR, 'data')
44 | # SMAL
45 | SMAL_FILE = join(data_path, 'smal', 'my_smpl_00781_4_all.pkl')
46 | SMAL_DATA_FILE = join(data_path, 'smal', 'my_smpl_data_00781_4_all.pkl')
47 | SMAL_UV_FILE = join(data_path, 'smal', 'my_smpl_00781_4_all_template_w_tex_uv_001.pkl')
48 | SMAL_SYM_FILE = join(data_path, 'smal', 'symIdx.pkl')
49 | SHAPE_FAMILY_ID = 1 # the dog shape family
50 |
51 | # PRIORS
52 | WALKING_PRIOR_FILE = join(data_path, 'priors', 'walking_toy_symmetric_pose_prior_with_cov_35parts.pkl')
53 | UNITY_POSE_PRIOR = join(data_path, 'priors', 'unity_pose_prior_with_cov_35parts.pkl')
54 | UNITY_SHAPE_PRIOR = join(data_path, 'priors', 'unity_betas.npz')
55 | SMAL_DOG_TOY_IDS = [0, 1, 2]
56 |
57 | # DATALOADER
58 | IMG_RES = 224
59 | NUM_JOINTS = 20
60 | # Mean and standard deviation for normalizing input image
61 | IMG_NORM_MEAN = [0.485, 0.456, 0.406]
62 | IMG_NORM_STD = [0.229, 0.224, 0.225]
63 |
64 | # RENDERER
65 | PROJECTION = 'perspective'
66 | NORM_F0 = 2700.0
67 | NORM_F = 2700.0
68 | NORM_Z = 20.0
69 |
70 | MESH_COLOR = [0, 172, 223]
71 | # MESH_COLOR=[234, 156, 199.]
72 |
73 | # MESH_NET
74 | NZ_FEAT = 100
75 |
76 | # ASSOCIATING SMAL TO ANNOTATED JOINTS
77 | MODEL_JOINTS = [
78 | 14, 13, 12, # left front (0, 1, 2)
79 | 24, 23, 22, # left rear (3, 4, 5)
80 | 10, 9, 8, # right front (6, 7, 8)
81 | 20, 19, 18, # right rear (9, 10, 11)
82 | 25, 31, # tail start -> end (12, 13)
83 | 34, 33, # right ear, left ear (14, 15)
84 | 35, 36, # nose, chin (16, 17)
85 | 37, 38] # right tip, left tip (18, 19)
86 |
87 | CANONICAL_MODEL_JOINTS = [
88 | 10, 9, 8, # upper_left [paw, middle, top]
89 | 20, 19, 18, # lower_left [paw, middle, top]
90 | 14, 13, 12, # upper_right [paw, middle, top]
91 | 24, 23, 22, # lower_right [paw, middle, top]
92 | 25, 31, # tail [start, end]
93 | 33, 34, # ear base [left, right]
94 | 35, 36, # nose, chin
95 | 38, 37, # ear tip [left, right]
96 | 39, 40, # eyes [left, right]
97 |     15, 15, # withers, throat (TODO: withers labelled the same as throat for now)
98 | 28] # tail middle
99 |
100 |
101 | EVAL_KEYPOINTS = [
102 | 0, 1, 2, # left front
103 | 3, 4, 5, # left rear
104 | 6, 7, 8, # right front
105 | 9, 10, 11, # right rear
106 | 12, 13, # tail start -> end
107 | 14, 15, # left ear, right ear
108 | 16, 17, # nose, chin
109 | 18, 19] # left tip, right tip
110 |
111 | KEYPOINT_GROUPS = {
112 | 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs
113 | 'tail': [12, 13], # tail
114 | 'ears': [14, 15, 18, 19], # ears
115 | 'face': [16, 17] # face
116 | }
117 |
--------------------------------------------------------------------------------
/util/loss_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import numpy as np
7 | from util import config
8 | import pickle as pkl
9 | criterionL2 = torch.nn.MSELoss()
10 | criterionL1 = torch.nn.L1Loss()
11 |
12 |
13 | # L1 or L2 based silhouette loss
14 | def mask_loss(mask_pred, mask_gt):
15 | # return torch.nn.L1Loss()(mask_pred, mask_gt)
16 | return criterionL2(mask_pred, mask_gt)
17 |
18 |
19 | # 2D keypoint loss
20 | def kp_l2_loss(kp_pred, kp_gt, num_joints):
21 | vis = (kp_gt[:, :, 2, None] > 0).float()
22 | if not kp_pred.ndim == 3:
23 |         kp_pred = kp_pred.reshape((kp_pred.shape[0], num_joints, -1))  # reshape returns a new tensor, so assign it back
24 | return criterionL2(vis * kp_pred, vis * kp_gt[:, :, :2])
25 |
26 |
27 | # compute shape prior based on Mahalanobis distance, formulation taken from
28 | # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py
29 | class Shape_prior(torch.nn.Module):
30 | def __init__(self, prior, shape_family_id, device, data_path=None):
31 | '''
32 | Args:
33 | prior: specify the prior to use, smal or unity or self-defined data
34 | shape_family_id: specify animal category id
35 | device: specify device
36 |             data_path: path to self-defined prior data, used when prior is neither smal nor unity
37 | '''
38 | super(Shape_prior, self).__init__()
39 | if prior == 'smal':
40 | nbetas=20
41 | with open(config.SMAL_DATA_FILE, 'rb') as f:
42 | u = pkl._Unpickler(f)
43 | u.encoding = 'latin1'
44 | data = u.load()
45 | shape_cluster_means = data['cluster_means'][shape_family_id]
46 | betas_cov = data['cluster_cov'][shape_family_id]
47 | betas_mean = torch.from_numpy(shape_cluster_means).float().to(device)
48 | elif prior == 'unity':
49 | nbetas=26
50 | unity_data = np.load(config.UNITY_SHAPE_PRIOR)
51 | betas_cov = unity_data['cov'][:-1, :-1]
52 | betas_mean = torch.from_numpy(unity_data['mean'][:-1]).float().to(device)
53 | else:
54 | assert data_path is not None
55 | nbetas=26
56 | prior_data = np.load(data_path, allow_pickle=True)
57 | betas_mean = torch.from_numpy(prior_data.item()['mean']).float().to(device)
58 | betas_cov = prior_data.item()['cov']
59 |
60 | invcov = np.linalg.inv(betas_cov + 1e-5 * np.eye(betas_cov.shape[0]))
61 | prec = np.linalg.cholesky(invcov)
62 | self.betas_prec = torch.Tensor(prec)[:nbetas, :nbetas].to(device)
63 | self.betas_mean = betas_mean[:nbetas]
64 |
65 | def __call__(self, betas_pred):
66 | diff = betas_pred - self.betas_mean.unsqueeze(0)
67 | res = torch.tensordot(diff, self.betas_prec, dims=([1], [0]))
68 | return (res**2).mean()
69 |
70 |
71 | # Laplacian loss: compute the Laplacian coordinates of both coarse and refined vertices and compare the difference
72 | class Laplacian(torch.nn.Module):
73 | def __init__(self, adjmat, device):
74 | '''
75 | Args:
76 | adjmat: adjacency matrix of the input graph data
77 | device: specify device for training
78 | '''
79 | super(Laplacian, self).__init__()
80 | adjmat.data = np.ones_like(adjmat.data)
81 | adjmat = torch.from_numpy(adjmat.todense()).float()
82 | dg = torch.sum(adjmat, dim=-1)
83 | dg_m = torch.diag(dg)
84 | ls = dg_m - adjmat
85 | self.ls = ls.unsqueeze(0).to(device) # Should be normalized by the diagonal elements according to
86 |         # the original definition, but this unnormalized version also works fine.
87 |
88 | def forward(self, verts_pred, verts_gt, smooth=False):
89 | verts_pred = torch.matmul(self.ls, verts_pred)
90 | verts_gt = torch.matmul(self.ls, verts_gt)
91 | loss = torch.norm(verts_pred - verts_gt, dim=-1).mean()
92 | if smooth:
93 | loss_smooth = torch.norm(torch.matmul(self.ls, verts_pred), dim=-1).mean()
94 | return loss, loss_smooth
95 | return loss, None
96 |
--------------------------------------------------------------------------------
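A hedged usage sketch for `Shape_prior`; it assumes the SMAL data file referenced in `util/config.py` has been downloaded, and the zero batch of betas is only there to show the call convention:
```
import torch
from util import config
from util.loss_utils import Shape_prior

device = torch.device('cpu')
shape_prior = Shape_prior('smal', config.SHAPE_FAMILY_ID, device)  # 20 betas for the SMAL prior
betas_pred = torch.zeros(2, 20)          # dummy shape coefficients
loss_shape = shape_prior(betas_pred)     # Mahalanobis-style distance to the cluster mean
```
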
/model/graph_hg.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the Definition of GraphCNN
3 | GraphCNN includes ResNet50 as a submodule
4 | """
5 | from __future__ import division
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from model.networks.graph_layers import GraphResBlock, GraphLinear
11 | from smal.mesh import Mesh
12 | from smal.smal_torch import SMAL
13 |
14 | # encoder-decoder structured GCN with skip connections
15 | class GraphCNN_hg(nn.Module):
16 |
17 | def __init__(self, mesh, num_channels=256, local_feat=False, num_downsample=0):
18 | '''
19 | Args:
20 |             mesh: mesh data that stores the adjacency matrices
21 |             num_channels: number of channels of the GCN
22 |             local_feat: whether to use local features for refinement
23 |             num_downsample: number of downsampling steps applied to the input mesh
24 | '''
25 | super(GraphCNN_hg, self).__init__()
26 | self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
27 | self.num_layers = len(self.A) - 1
28 | print("Number of downsampling layer: {}".format(self.num_layers))
29 | self.num_downsample = num_downsample
30 | if local_feat:
31 | self.lin1 = GraphLinear(3 + 2048 + 3840, 2 * num_channels)
32 | else:
33 | self.lin1 = GraphLinear(3 + 2048, 2 * num_channels)
34 | self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
35 | encode_layers = []
36 | decode_layers = []
37 |
38 | for i in range(len(self.A)):
39 | encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
40 |
41 | decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
42 | self.A[len(self.A) - i - 1]))
43 | current_channels = (i+1)*num_channels
44 | # number of channels for the input is different because of the concatenation operation
45 | self.shape = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
46 | GraphResBlock(64, 32, self.A[0]),
47 | nn.GroupNorm(32 // 8, 32),
48 | nn.ReLU(inplace=True),
49 | GraphLinear(32, 3))
50 |
51 | self.encoder = nn.Sequential(*encode_layers)
52 | self.decoder = nn.Sequential(*decode_layers)
53 | self.mesh = mesh
54 |
55 | def forward(self, verts_c, img_fea_global, img_fea_multiscale=None, points_local=None):
56 | '''
57 | Args:
58 | verts_c: vertices from the coarse estimation
59 | img_fea_global: global feature for mesh refinement
60 | img_fea_multiscale: multi-scale feature from the encoder, used for local feature extraction
61 | points_local: 2D keypoint for local feature extraction
62 | Returns: refined mesh
63 | '''
64 | batch_size = img_fea_global.shape[0]
65 | ref_vertices = verts_c.transpose(1, 2)
66 | image_enc = img_fea_global.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1])
67 | if points_local is not None:
68 | feat_local = torch.nn.functional.grid_sample(img_fea_multiscale, points_local)
69 | x = torch.cat([ref_vertices, image_enc, feat_local.squeeze(2)], dim=1)
70 | else:
71 | x = torch.cat([ref_vertices, image_enc], dim=1)
72 | x = self.lin1(x)
73 | x = self.res1(x)
74 | x_ = [x]
75 | for i in range(self.num_layers + 1):
76 | if i == self.num_layers:
77 | x = self.encoder[i](x)
78 | else:
79 | x = self.encoder[i](x)
80 | x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
81 | x = x.transpose(1, 2)
82 | if i < self.num_layers-1:
83 | x_.append(x)
84 | for i in range(self.num_layers + 1):
85 | if i == self.num_layers:
86 | x = self.decoder[i](x)
87 | else:
88 | x = self.decoder[i](x)
89 | x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
90 | n2=self.num_layers-i-1+self.num_downsample)
91 | x = x.transpose(1, 2)
92 | x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
93 |
94 | shape = self.shape(x)
95 | return shape
96 |
--------------------------------------------------------------------------------
/util/loss_sdf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.ndimage import distance_transform_edt as distance
4 | from skimage import segmentation as skimage_seg
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def dice_loss(score, target):
9 | # implemented from paper https://arxiv.org/pdf/1606.04797.pdf
10 | target = target.float()
11 | smooth = 1e-5
12 | intersect = torch.sum(score * target)
13 | y_sum = torch.sum(target * target)
14 | z_sum = torch.sum(score * score)
15 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
16 | loss = 1 - loss
17 | return loss
18 |
19 |
20 | class tversky_loss(torch.nn.Module):
21 | # implemented from https://arxiv.org/pdf/1706.05721.pdf
22 | def __init__(self, alpha, beta):
23 | '''
24 | Args:
25 | alpha: coefficient for false positive prediction
26 |             beta: coefficient for false negative prediction
27 | '''
28 | super(tversky_loss, self).__init__()
29 | self.alpha = alpha
30 | self.beta = beta
31 |
32 | def __call__(self, score, target):
33 | target = target.float()
34 | smooth = 1e-5
35 | tp = torch.sum(score * target)
36 | fn = torch.sum(target * (1 - score))
37 | fp = torch.sum((1-target) * score)
38 | loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
39 | loss = 1 - loss
40 | return loss
41 |
42 |
43 | def compute_sdf1_1(img_gt, out_shape):
44 | """
45 | compute the normalized signed distance map of binary mask
46 | input: segmentation, shape = (batch_size, x, y, z)
47 | output: the Signed Distance Map (SDM)
48 | sdf(x) = 0; x in segmentation boundary
49 | -inf|x-y|; x in segmentation
50 | +inf|x-y|; x out of segmentation
51 | normalize sdf to [-1, 1]
52 | """
53 |
54 | img_gt = img_gt.astype(np.uint8)
55 |
56 | normalized_sdf = np.zeros(out_shape)
57 |
58 | for b in range(out_shape[0]): # batch size
59 | # ignore background
60 | for c in range(1, out_shape[1]):
61 | posmask = img_gt[b]
62 | negmask = 1-posmask
63 | posdis = distance(posmask)
64 | negdis = distance(negmask)
65 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
66 | sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
67 | sdf[boundary==1] = 0
68 | normalized_sdf[b][c] = sdf
69 | assert np.min(sdf) == -1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
70 | assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
71 |
72 | return normalized_sdf
73 |
74 |
75 | def compute_sdf(img_gt, out_shape):
76 | """
77 | compute the signed distance map of binary mask
78 | input: segmentation, shape = (batch_size, x, y, z)
79 | output: the Signed Distance Map (SDM)
80 | sdf(x) = 0; x in segmentation boundary
81 | -inf|x-y|; x in segmentation
82 | +inf|x-y|; x out of segmentation
83 | """
84 |
85 | img_gt = img_gt.astype(np.uint8)
86 |
87 | gt_sdf = np.zeros(out_shape)
88 | debug = False
89 | for b in range(out_shape[0]): # batch size
90 | for c in range(0, out_shape[1]):
91 | posmask = img_gt[b]
92 | negmask = 1-posmask
93 | posdis = distance(posmask)
94 | negdis = distance(negmask)
95 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
96 | sdf = negdis - posdis
97 | sdf[boundary==1] = 0
98 | gt_sdf[b][c] = sdf
99 | if debug:
100 | plt.figure()
101 | plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar()
102 | plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar()
103 | plt.show()
104 |
105 | return gt_sdf
106 |
107 |
108 | def boundary_loss(output, gt):
109 | """
110 | compute boundary loss for binary segmentation
111 | input: outputs_soft: softmax results, shape=(b,2,x,y,z)
112 | gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y,z)
113 |     output: boundary_loss; scalar
114 |     adapted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf
115 | """
116 | multipled = torch.einsum('bcxy, bcxy->bcxy', output, gt)
117 | bd_loss = multipled.mean()
118 |
119 | return bd_loss
--------------------------------------------------------------------------------
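A small sketch combining `compute_sdf` and `boundary_loss` on a dummy batch of square masks; shapes follow the docstrings above (batch, channel, height, width):
```
import numpy as np
import torch
from util.loss_sdf import compute_sdf, boundary_loss, tversky_loss

B, H, W = 2, 64, 64
seg_gt = np.zeros((B, H, W), dtype=np.uint8)
seg_gt[:, 16:48, 16:48] = 1                                 # dummy square masks
sdf_gt = torch.from_numpy(compute_sdf(seg_gt, (B, 1, H, W))).float()
pred = torch.rand(B, 1, H, W)                               # dummy soft silhouette prediction
loss_bd = boundary_loss(pred, sdf_gt)                       # mean of pred * sdf
loss_tv = tversky_loss(alpha=0.5, beta=0.5)(pred, torch.from_numpy(seg_gt).unsqueeze(1))
```
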
/README.md:
--------------------------------------------------------------------------------
1 | # Coarse-to-fine-3D-Animal
2 |
3 | **About**
4 |
5 | This is the source code for our paper
6 |
7 | Chen Li, Gim Hee Lee. Coarse-to-fine Animal Pose and Shape Estimation. In NeurIPS 2021.
8 |
9 | The shape space of the SMAL model is learned from 41 scans of toy animals and thus lacks pose and shape variations. This may limit the representation capacity of the SMAL model and result in poor fits of the estimated shapes to the 2D observations, as shown in the second to fourth columns of the figure below. To mitigate this problem, we propose a coarse-to-fine approach, which combines model-based and model-free representations. Some refined results are shown in the fifth to seventh columns.
10 |
11 |
12 | ![Result examples](result_examples.png)
13 |
14 |
15 | Our network consists of a coarse estimation stage and a mesh refinement stage. The SMAL model parameters and camera parameters are regressed from the input image in the first stage for coarse estimation. This coarse estimation is further refined by an encoder-decoder structured GCN in the second stage.
16 |
17 |
18 | ![Network architecture](network.png)
19 |
20 |
21 | For more details, please refer to [our paper](https://arxiv.org/pdf/2111.08176.pdf).
22 |
23 | **Bibtex**
24 | ```
25 | @article{li2021coarse,
26 | title={Coarse-to-fine animal pose and shape estimation},
27 | author={Li, Chen and Lee, Gim Hee},
28 | journal={Advances in Neural Information Processing Systems},
29 | volume={34},
30 | pages={11757--11768},
31 | year={2021}
32 | }
33 | ```
34 |
35 | **Dependencies**
36 | 1. Python 3.7.10
37 | 2. PyTorch 1.5.0
38 |
39 | Please refer to requirements.txt for more details on dependencies.
40 |
41 | **Download datasets**
42 | * Download the [StanfordExtra dataset](https://github.com/benjiebob/StanfordExtra) and put it under the folder ./data/.
43 |
44 | * Download the [Animal Pose dataset](https://sites.google.com/view/animal-pose/) and the test split from [WLDO](https://github.com/benjiebob/WLDO/tree/master/data/animal_pose) and put them under the folder ./data/.
45 |
46 | **Download SMAL and priors**
47 | * Download the [SMAL](https://github.com/benjiebob/WLDO/tree/master/data) template and put the downloaded smal folder under ./data/.
48 | * Download the [pose prior data](https://github.com/benjiebob/SMALify/tree/master/data) and put the downloaded priors folder under ./data/.
49 |
50 | **Train**
51 |
52 | We provide the [pretrained model](https://drive.google.com/file/d/1mvr7iYkyKVUxPdFExE0HOsrVNl_sc1O1/view?usp=sharing) for each stage. You can download the pretrained models and put them under the folder ./logs/pretrained_models/. To save training time, you can directly train from stage 2 using our pretrained model for stage 1 ('stage1.pth.tar') by running:
53 | ```
54 | python main_meshgraph.py --output_dir logs/stage2 --nEpochs 10 --local_feat --batch_size 32 --freezecoarse --gpu_ids 0 --pretrained logs/pretrained_models/stage1.pth.tar
55 | ```
56 | Then you can continue to train stage 3 by running:
57 | ```
58 | python main_meshgraph.py --output_dir logs/stage3 --nEpochs 200 --lr 1e-5 --local_feat --w_arap 10000 --w_dice 1000 --w_dice_refine 100 --w_pose_limit_prior 5000 --resume logs/pretrained_models/stage2.pth.tar --gpu_ids 0
59 | ```
60 | Note that you will need to change '--resume' to the path of your own model if you want to use your own stage 2 model.
61 |
62 | Alternatively, you can also train from scratch. In this case, you will need to pretrain the coarse estimation part first by running:
63 | ```
64 | python main.py --batch_size 32 --output_dir logs/stage1 --gpu_ids 0
65 | ```
66 | Then you can continue to train stage 2 and stage 3 as explained above. Note that you will need to change '--pretrained' in stage 2 and '--resume' in stage 3 to the path of your own model.
67 |
68 | **Test**
69 |
70 | Test our model on the StanfordExtra dataset by running:
71 | ```
72 | python eval.py --output_dir logs/test --resume logs/pretrained_models/stage3.pth.tar --gpu_ids 0 --local_feat
73 | ```
74 | or on the Animal Pose dataset by running:
75 | ```
76 | python eval.py --output_dir logs/test --resume logs/pretrained_models/stage3.pth.tar --gpu_ids 0 --local_feat --dataset animal_pose
77 | ```
78 | Qualitative results can be generated and saved by adding '--save_results' to the command.
79 |
80 | **Acknowledgements**
81 |
82 | The code for the coarse estimation stage is adapted from [WLDO](https://github.com/benjiebob/WLDO/tree/master/data/animal_pose). If you use the coarse estimation pipeline, please cite:
83 | ```
84 | @inproceedings{biggs2020wldo,
85 | title={{W}ho left the dogs out?: {3D} animal reconstruction with expectation maximization in the loop},
86 | author={Biggs, Benjamin and Boyne, Oliver and Charles, James and Fitzgibbon, Andrew and Cipolla, Roberto},
87 | booktitle={ECCV},
88 | year={2020}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/smal/mesh.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import numpy as np
4 | import scipy.sparse
5 |
6 | from model.networks.graph_layers import spmm
7 |
8 |
9 | def scipy_to_pytorch(A, U, D):
10 | """Convert scipy sparse matrices to pytorch sparse matrix."""
11 | ptU = []
12 | ptD = []
13 |
14 | for i in range(len(U)):
15 | u = scipy.sparse.coo_matrix(U[i])
16 | i = torch.LongTensor(np.array([u.row, u.col]))
17 | v = torch.FloatTensor(u.data)
18 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
19 |
20 | for i in range(len(D)):
21 | d = scipy.sparse.coo_matrix(D[i])
22 | i = torch.LongTensor(np.array([d.row, d.col]))
23 | v = torch.FloatTensor(d.data)
24 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
25 |
26 | return ptU, ptD
27 |
28 |
29 | def adjmat_sparse(adjmat, nsize=1):
30 | """Create row-normalized sparse graph adjacency matrix."""
31 | adjmat = scipy.sparse.csr_matrix(adjmat)
32 | if nsize > 1:
33 | orig_adjmat = adjmat.copy()
34 | for _ in range(1, nsize):
35 | adjmat = adjmat * orig_adjmat
36 | adjmat.data = np.ones_like(adjmat.data)
37 | for i in range(adjmat.shape[0]):
38 | adjmat[i, i] = 1
39 | num_neighbors = np.array(1 / adjmat.sum(axis=-1))
40 | adjmat = adjmat.multiply(num_neighbors)
41 | adjmat = scipy.sparse.coo_matrix(adjmat)
42 | row = adjmat.row
43 | col = adjmat.col
44 | data = adjmat.data
45 | i = torch.LongTensor(np.array([row, col]))
46 | v = torch.from_numpy(data).float()
47 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
48 | return adjmat
49 |
50 |
51 | def get_graph_params(filename, nsize=1):
52 | """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
53 | data = np.load(filename, encoding='latin1', allow_pickle=True)
54 | A = data['A']
55 | U = data['U']
56 | D = data['D']
57 | U, D = scipy_to_pytorch(A, U, D)
58 | A = [adjmat_sparse(a, nsize=nsize) for a in A]
59 | return A, U, D
60 |
61 |
62 | class Mesh(object):
63 | """Mesh object that is used for handling certain graph operations."""
64 |
65 | def __init__(self, smal, filename='./data/mesh_down_sampling.npz',
66 | num_downsampling=1, nsize=1, device=torch.device('cuda')):
67 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
68 | self._A = [a.to(device) for a in self._A]
69 | self._U = [u.to(device) for u in self._U]
70 | self._D = [d.to(device) for d in self._D]
71 | self.num_downsampling = num_downsampling
72 |
73 |         # load template vertices from SMAL and normalize them
74 | ref_vertices = smal.v_template.clone()
75 | center = 0.5 * (ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
76 | ref_vertices -= center
77 | ref_vertices /= ref_vertices.abs().max().item()
78 |
79 | self._ref_vertices = ref_vertices.to(device)
80 | self.faces = smal.faces.int()
81 |
82 | @property
83 | def adjmat(self):
84 | """Return the graph adjacency matrix at the specified subsampling level."""
85 | return self._A[self.num_downsampling].float()
86 |
87 | @property
88 | def ref_vertices(self):
89 | """Return the template vertices at the specified subsampling level."""
90 | ref_vertices = self._ref_vertices
91 | for i in range(self.num_downsampling):
92 | ref_vertices = torch.spmm(self._D[i], ref_vertices)
93 | return ref_vertices
94 |
95 | def downsample(self, x, n1=0, n2=None):
96 | """Downsample mesh."""
97 | if n2 is None:
98 | n2 = self.num_downsampling
99 | if x.ndimension() < 3:
100 | for i in range(n1, n2):
101 | x = spmm(self._D[i], x)
102 | elif x.ndimension() == 3:
103 | out = []
104 | for i in range(x.shape[0]):
105 | y = x[i]
106 | for j in range(n1, n2):
107 | y = spmm(self._D[j], y)
108 | out.append(y)
109 | x = torch.stack(out, dim=0)
110 | return x
111 |
112 | def upsample(self, x, n1=None, n2=0):
113 | """Upsample mesh."""
114 | if n1 is None:
115 | n1 = self.num_downsampling
116 | if x.ndimension() < 3:
117 | for i in reversed(range(n2, n1)):
118 | x = spmm(self._U[i], x)
119 | elif x.ndimension() == 3:
120 | out = []
121 | for i in range(x.shape[0]):
122 | y = x[i]
123 | for j in reversed(range(n2, n1)):
124 | y = spmm(self._U[j], y)
125 | out.append(y)
126 | x = torch.stack(out, dim=0)
127 | return x
128 |
--------------------------------------------------------------------------------
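A hedged sketch of the down/upsampling utilities above; it assumes the SMAL pickle files and the precomputed `mesh_down_sampling_4.npz` are under ./data/ and that a CUDA device is available (the default for `Mesh`):
```
import torch
from smal.smal_torch import SMAL
from smal.mesh import Mesh
from util import config

device = torch.device('cuda')
smal = SMAL(device, shape_family_id=config.SHAPE_FAMILY_ID)
mesh = Mesh(smal, filename='./data/mesh_down_sampling_4.npz', device=device)
verts = smal.v_template.unsqueeze(0).to(device)        # 1 x V x 3 template vertices
verts_d = mesh.downsample(verts)                       # vertices on the coarser mesh
verts_u = mesh.upsample(verts_d)                       # mapped back to full resolution
```
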
/model/mesh_graph_hg.py:
--------------------------------------------------------------------------------
1 | """
2 | Mesh net model.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import numpy as np
9 | import torch.nn as nn
10 | import torch
11 | from smal.smal_torch import SMAL
12 | from util import config
13 | from model.graph_hg import GraphCNN_hg
14 | from model.smal_mesh_net_img import MeshNet_img
15 | from util.nmr import NeuralRenderer
16 | from smal.mesh import Mesh
17 | import cv2
18 | # ------------- Modules ------------#
19 | # ----------------------------------#
20 |
21 | unity_shape_prior = np.load('data/priors/unity_betas.npz')
22 |
23 |
24 | class MeshGraph_hg(nn.Module):
25 | def __init__(self, device, shape_family_id, number_channels, num_layers, betas_scale=False, shape_init=None
26 | , local_feat=False, num_downsampling=0, render_rgb=False):
27 | '''
28 |
29 | Args:
30 | device: specify device for training
31 | shape_family_id: specify animal category id
32 | number_channels: specify number of channels for GCN
33 |             betas_scale: whether to predict the additional shape parameters proposed by WLDO
34 |             shape_init: whether to initialize the bias weights for the coarse stage with a mean shape
35 |             local_feat: whether to use local features for the refinement step
36 |             num_downsampling: number of downsamplings before input to the GCN.
37 |                 We downsample the original mesh once before going through the GCN to save memory.
38 |             render_rgb: whether to render the 3D mesh into a 2D RGB image. Only set to True when
39 |                 generating visualizations, since rendering adds inference time.
40 | '''
41 | super(MeshGraph_hg, self).__init__()
42 |
43 | self.model_renderer = NeuralRenderer(config.IMG_RES, proj_type=config.PROJECTION,
44 | norm_f0=config.NORM_F0,
45 | norm_f=config.NORM_F,
46 | norm_z=config.NORM_Z, render_rgb=render_rgb, device=device)
47 | self.model_renderer.directional_light_only()
48 | self.smal = SMAL(device, shape_family_id=shape_family_id)
49 | self.local_feat = local_feat
50 | if shape_init == 'smal':
51 | print("Initiate shape with smal prior")
52 | shape_init = self.smal.shape_cluster_means
53 | elif shape_init == 'unity':
54 | print("Initiate shape with unity prior ")
55 | shape_init = unity_shape_prior['mean'][:-1]
56 | shape_init = torch.from_numpy(shape_init).float().to(device)
57 | else:
58 | print("No initialization for shape")
59 | shape_init = None
60 |
61 | input_size = [config.IMG_RES, config.IMG_RES]
62 | self.meshnet = MeshNet_img(input_size, betas_scale=betas_scale, norm_f0=config.NORM_F0, nz_feat=config.NZ_FEAT,
63 | shape_init=shape_init, return_feat=True)
64 | self.mesh = Mesh(self.smal, num_downsampling=num_downsampling, filename='./data/mesh_down_sampling_4.npz',
65 | device=device)
66 | self.graphnet = GraphCNN_hg(self.mesh, num_channels=number_channels,
67 | local_feat=local_feat, num_downsample=num_downsampling).to(device)
68 |
69 | def forward(self, img):
70 | pred_codes, enc_feat, feat_multiscale = self.meshnet(img)
71 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
72 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(scale_pred.shape[0], 2).cuda() * config.IMG_RES / 2],
73 | dim=1)
74 | verts, joints, _, _ = self.smal(betas_pred, pose_pred, trans=trans_pred,
75 | betas_logscale=betas_scale_pred)
76 | enc_feat_copy = enc_feat.detach().clone()
77 | verts_d = self.mesh.downsample(verts)
78 | verts_copy = verts_d.detach().clone()
79 | if self.local_feat:
80 | feat_multiscale_copy = feat_multiscale.detach().clone()
81 | points_img = self.model_renderer.project_points(verts_copy, pred_camera, normalize_kpts=True)
82 |             # use normalized image coordinates as required by torch.nn.functional.grid_sample
83 | verts_re = self.graphnet(verts_d, enc_feat_copy, feat_multiscale_copy, points_img.unsqueeze(1))
84 | else:
85 | verts_re = self.graphnet(verts_d, enc_feat_copy)
86 | verts_re = self.mesh.upsample(verts_re.transpose(1, 2))
87 | return verts, joints, verts_re, pred_codes
88 |
89 |
90 | def init_pretrained(model, checkpoint):
91 | pretrained_dict = checkpoint['state_dict']
92 | model_dict = model.state_dict()
93 | model_dict.update(pretrained_dict)
94 | model.load_state_dict(model_dict)
95 | print("Init MeshNet")
96 |
97 |
--------------------------------------------------------------------------------
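A heavily hedged construction sketch for the full two-stage model above; it assumes the SMAL, prior, and mesh downsampling files from the README are under ./data/ and a CUDA device is available. The `num_layers` value is a placeholder (the argument is accepted but not referenced in this file):
```
import torch
from model.mesh_graph_hg import MeshGraph_hg

device = torch.device('cuda')
model = MeshGraph_hg(device, shape_family_id=1, number_channels=256, num_layers=4,
                     local_feat=True, num_downsampling=1).to(device)
img = torch.rand(2, 3, 224, 224).to(device)             # normalized input crops
verts, joints, verts_refined, pred_codes = model(img)    # coarse verts, joints, refined verts, SMAL codes
```
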
/util/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
11 |
12 | def savefig(fname, dpi=None):
13 | dpi = 150 if dpi == None else dpi
14 | plt.savefig(fname, dpi=dpi)
15 |
16 | def plot_overlap(logger, names=None):
17 | names = logger.names if names == None else names
18 | numbers = logger.numbers
19 | for _, name in enumerate(names):
20 | x = np.arange(len(numbers[name]))
21 | plt.plot(x, np.asarray(numbers[name]))
22 | return [logger.title + '(' + name + ')' for name in names]
23 |
24 | class Logger(object):
25 | '''Save training process to log file with simple plot function.'''
26 | def __init__(self, fpath, title=None, resume=False):
27 | self.file = None
28 | self.resume = resume
29 | self.title = '' if title == None else title
30 | if fpath is not None:
31 | if resume:
32 | self.file = open(fpath, 'r')
33 | name = self.file.readline()
34 | self.names = name.rstrip().split('\t')
35 | self.numbers = {}
36 | for _, name in enumerate(self.names):
37 | self.numbers[name] = []
38 |
39 | for numbers in self.file:
40 | numbers = numbers.rstrip().split('\t')
41 | for i in range(0, len(numbers)):
42 | self.numbers[self.names[i]].append(numbers[i])
43 | self.file.close()
44 | self.file = open(fpath, 'a')
45 | else:
46 | self.file = open(fpath, 'w')
47 |
48 | def set_names(self, names):
49 | if self.resume:
50 | pass
51 | # initialize numbers as empty list
52 | self.numbers = {}
53 | self.names = names
54 | for _, name in enumerate(self.names):
55 | self.file.write(name)
56 | self.file.write('\t')
57 | self.numbers[name] = []
58 | self.file.write('\n')
59 | self.file.flush()
60 |
61 | def append(self, numbers):
62 | assert len(self.names) == len(numbers), 'Numbers do not match names'
63 | for index, num in enumerate(numbers):
64 | self.file.write("{0:.6f}".format(num))
65 | self.file.write('\t')
66 | self.numbers[self.names[index]].append(num)
67 | self.file.write('\n')
68 | self.file.flush()
69 |
70 | def log_arguments(self, args):
71 |
72 | self.file.write('Command:{}'.format(sys.argv))
73 | self.file.write('\n')
74 | s = '\n'.join(['{}: {}'.format(arg, getattr(args, arg)) for arg in vars(args)])
75 | s = 'Arguments:\n' + s
76 | self.file.write(s)
77 | self.file.write('\n')
78 | self.file.flush()
79 |
80 | def plot(self, names=None):
81 |         names = self.names if names is None else names
82 | numbers = self.numbers
83 | for _, name in enumerate(names):
84 | x = np.arange(len(numbers[name]))
85 | plt.plot(x, np.asarray(numbers[name]))
86 | plt.legend([self.title + '(' + name + ')' for name in names])
87 | plt.grid(True)
88 | plt.gca().invert_yaxis()
89 | plt.show()
90 |
91 | def close(self):
92 | if self.file is not None:
93 | self.file.close()
94 |
95 |
96 | class LoggerMonitor(object):
97 | '''Load and visualize multiple logs.'''
98 |     def __init__(self, paths):
99 |         '''paths is a dictionary with {name: filepath} pairs'''
100 | self.loggers = []
101 | for title, path in paths.items():
102 | logger = Logger(path, title=title, resume=True)
103 | self.loggers.append(logger)
104 |
105 | def plot(self, names=None):
106 | plt.figure()
107 | plt.subplot(121)
108 | legend_text = []
109 | for logger in self.loggers:
110 | legend_text += plot_overlap(logger, names)
111 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
112 | plt.grid(True)
113 |
114 | if __name__ == '__main__':
115 | # # Example
116 | # logger = Logger('test.txt')
117 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
118 |
119 | # length = 100
120 | # t = np.arange(length)
121 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
122 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
123 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
124 |
125 | # for i in range(0, length):
126 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
127 | # logger.plot()
128 |
129 | # Example: logger monitor
130 | # paths = {
131 | # 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
132 | # 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
133 | # 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
134 | # }
135 | #
136 | # field = ['Valid Acc.']
137 | #
138 | # monitor = LoggerMonitor(paths)
139 | # monitor.plot(names=field)
140 | # savefig('test.eps')
141 | logger = Logger(fpath='/media/haleh/Harddisk1/checkpoint/horse/syn2real_grl/log.txt', resume=True)
142 | logger.plot(names=['LR'])
--------------------------------------------------------------------------------
/model/networks/graph_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains definitions of layers used to build the GraphCNN
3 | """
4 | from __future__ import division
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import math
10 | # from torch_geometric.nn.conv import MessagePassing
11 |
12 | class GraphConvolution(nn.Module):
13 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
14 | def __init__(self, in_features, out_features, adjmat, bias=True):
15 | super(GraphConvolution, self).__init__()
16 | self.in_features = in_features
17 | self.out_features = out_features
18 | self.adjmat = adjmat
19 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
20 | if bias:
21 | self.bias = nn.Parameter(torch.FloatTensor(out_features))
22 | else:
23 | self.register_parameter('bias', None)
24 | self.reset_parameters()
25 |
26 | def reset_parameters(self):
27 | # stdv = 1. / math.sqrt(self.weight.size(1))
28 | stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
29 | self.weight.data.uniform_(-stdv, stdv)
30 | if self.bias is not None:
31 | self.bias.data.uniform_(-stdv, stdv)
32 |
33 | def forward(self, x):
34 | if x.ndimension() == 2:
35 | support = torch.matmul(x, self.weight)
36 | output = torch.matmul(self.adjmat, support)
37 | if self.bias is not None:
38 | output = output + self.bias
39 | return output
40 | else:
41 | output = []
42 | for i in range(x.shape[0]):
43 | support = torch.matmul(x[i], self.weight)
44 | # output.append(torch.matmul(self.adjmat, support))
45 | output.append(spmm(self.adjmat, support))
46 | output = torch.stack(output, dim=0)
47 | if self.bias is not None:
48 | output = output + self.bias
49 | return output
50 |
51 | def __repr__(self):
52 | return self.__class__.__name__ + ' (' \
53 | + str(self.in_features) + ' -> ' \
54 | + str(self.out_features) + ')'
55 |
56 |
57 | class GraphLinear(nn.Module):
58 | """
59 | Generalization of 1x1 convolutions on Graphs
60 | """
61 | def __init__(self, in_channels, out_channels):
62 | super(GraphLinear, self).__init__()
63 | self.in_channels = in_channels
64 | self.out_channels = out_channels
65 | self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels))
66 | self.b = nn.Parameter(torch.FloatTensor(out_channels))
67 | self.reset_parameters()
68 |
69 | def reset_parameters(self):
70 | w_stdv = 1 / (self.in_channels * self.out_channels)
71 | self.W.data.uniform_(-w_stdv, w_stdv)
72 | self.b.data.uniform_(-w_stdv, w_stdv)
73 |
74 | def forward(self, x):
75 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
76 |
77 |
78 | class GraphResBlock(nn.Module):
79 | """
80 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet
81 | """
82 |
83 | def __init__(self, in_channels, out_channels, A):
84 | super(GraphResBlock, self).__init__()
85 | self.in_channels = in_channels
86 | self.out_channels = out_channels
87 | self.lin1 = GraphLinear(in_channels, out_channels // 2)
88 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A)
89 | self.lin2 = GraphLinear(out_channels // 2, out_channels)
90 | self.skip_conv = GraphLinear(in_channels, out_channels)
91 | self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels)
92 | self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
93 | self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
94 |
95 | def forward(self, x):
96 | y = F.relu(self.pre_norm(x))
97 | y = self.lin1(y)
98 |
99 | y = F.relu(self.norm1(y))
100 | y = self.conv(y.transpose(1,2)).transpose(1,2)
101 |
102 | y = F.relu(self.norm2(y))
103 | y = self.lin2(y)
104 | if self.in_channels != self.out_channels:
105 | x = self.skip_conv(x)
106 | return x+y
107 |
108 |
109 | class SparseMM(torch.autograd.Function):
110 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
111 | The builtin matrix multiplication operation does not support backpropagation in some cases.
112 | """
113 | @staticmethod
114 | def forward(ctx, sparse, dense):
115 | ctx.req_grad = dense.requires_grad
116 | ctx.save_for_backward(sparse)
117 | return torch.matmul(sparse, dense)
118 |
119 | @staticmethod
120 | def backward(ctx, grad_output):
121 | grad_input = None
122 | sparse, = ctx.saved_tensors
123 | if ctx.req_grad:
124 | grad_input = torch.matmul(sparse.t(), grad_output)
125 | return None, grad_input
126 |
127 |
128 | def spmm(sparse, dense):
129 | return SparseMM.apply(sparse, dense)
130 |
131 |
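A quick gradient check for SparseMM/spmm (an illustrative sketch, not part of the repository file): it multiplies a toy sparse adjacency with a dense feature matrix and verifies that gradients reach the dense input, which is what the custom autograd Function is for.

import torch

adj = torch.eye(4).to_sparse()                  # toy 4 x 4 sparse adjacency
feat = torch.randn(4, 8, requires_grad=True)    # dense per-vertex features
out = spmm(adj, feat)                           # same result as adj @ feat
out.sum().backward()
assert feat.grad is not None                    # gradient flows to the dense input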
132 | # class Pool(MessagePassing):
133 | # def __init__(self, pool_mat):
134 | # super(Pool, self).__init__(flow='target_to_source')
135 | # self.pool_mat = pool_mat
136 | #
137 | # def forward(self, x):
138 | # x = x.transpose(0, 1)
139 | # out = self.propagate(edge_index=self.pool_mat._indices(), x=x, norm=self.pool_mat._values(),
140 | # size=self.pool_mat.size())
141 | # return out.transpose(0, 1)
142 | #
143 | # def message(self, x_j, norm):
144 | # return norm.view(-1, 1, 1) * x_j
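A minimal shape check for the layers above (illustrative sketch): GraphResBlock is run on a toy graph with an identity adjacency, passed densely here for simplicity, whereas the project normally supplies a sparse mesh adjacency matrix.

import torch

V = 16                                          # number of graph vertices
adj = torch.eye(V)                              # toy adjacency (self-loops only)
block = GraphResBlock(in_channels=32, out_channels=64, A=adj)
x = torch.randn(2, 32, V)                       # batch x channels x vertices
out = block(x)
assert out.shape == (2, 64, V)                  # channels expanded, vertex count kept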
--------------------------------------------------------------------------------
/util/helpers/conversions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import torchgeometry as tgm
5 |
6 | def rotation_matrix_to_angle_axis(rotation_matrix):
7 | """Convert 3x4 rotation matrix to Rodrigues vector
8 |
9 | Args:
10 | rotation_matrix (Tensor): rotation matrix.
11 |
12 | Returns:
13 | Tensor: Rodrigues vector transformation.
14 |
15 | Shape:
16 | - Input: :math:`(N, 3, 4)`
17 | - Output: :math:`(N, 3)`
18 |
19 | Example:
20 |         >>> input = torch.rand(2, 3, 4)  # Nx3x4
21 | >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
22 | """
23 | # todo add check that matrix is a valid rotation matrix
24 | quaternion = rotation_matrix_to_quaternion(rotation_matrix)
25 | return quaternion_to_angle_axis(quaternion)
26 |
27 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
28 | """Convert 3x4 rotation matrix to 4d quaternion vector
29 |
30 | This algorithm is based on algorithm described in
31 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
32 |
33 | Args:
34 | rotation_matrix (Tensor): the rotation matrix to convert.
35 |
36 | Return:
37 | Tensor: the rotation in quaternion
38 |
39 | Shape:
40 | - Input: :math:`(N, 3, 4)`
41 | - Output: :math:`(N, 4)`
42 |
43 | Example:
44 | >>> input = torch.rand(4, 3, 4) # Nx3x4
45 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
46 | """
47 | if not torch.is_tensor(rotation_matrix):
48 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
49 | type(rotation_matrix)))
50 |
51 | if len(rotation_matrix.shape) > 3:
52 | raise ValueError(
53 | "Input size must be a three dimensional tensor. Got {}".format(
54 | rotation_matrix.shape))
55 | if not rotation_matrix.shape[-2:] == (3, 4):
56 | raise ValueError(
57 | "Input size must be a N x 3 x 4 tensor. Got {}".format(
58 | rotation_matrix.shape))
59 |
60 | rmat_t = torch.transpose(rotation_matrix, 1, 2)
61 |
62 | mask_d2 = rmat_t[:, 2, 2] < eps
63 |
64 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
65 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
66 |
67 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
68 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
69 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
70 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
71 | t0_rep = t0.repeat(4, 1).t()
72 |
73 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
74 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
75 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
76 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
77 | t1_rep = t1.repeat(4, 1).t()
78 |
79 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
80 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
81 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
82 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
83 | t2_rep = t2.repeat(4, 1).t()
84 |
85 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
86 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
87 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
88 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
89 | t3_rep = t3.repeat(4, 1).t()
90 |
91 | mask_c0 = mask_d2 * mask_d0_d1
92 | # mask_c1 = mask_d2 * (1 - mask_d0_d1)
93 | mask_c1 = mask_d2 * (~mask_d0_d1)
94 | # mask_c2 = (1 - mask_d2) * mask_d0_nd1
95 | mask_c2 = (~mask_d2) * mask_d0_nd1
96 | # mask_c3 = (1 - mask_d2) * (1 - mask_d0_nd1)
97 | mask_c3 = (~mask_d2) * (~mask_d0_nd1)
98 | mask_c0 = mask_c0.view(-1, 1).type_as(q0)
99 | mask_c1 = mask_c1.view(-1, 1).type_as(q1)
100 | mask_c2 = mask_c2.view(-1, 1).type_as(q2)
101 | mask_c3 = mask_c3.view(-1, 1).type_as(q3)
102 |
103 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
104 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
105 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
106 | q *= 0.5
107 | return q
108 |
109 | def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
110 | """Convert quaternion vector to angle axis of rotation.
111 |
112 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
113 |
114 | Args:
115 | quaternion (torch.Tensor): tensor with quaternions.
116 |
117 | Return:
118 | torch.Tensor: tensor with angle axis of rotation.
119 |
120 | Shape:
121 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions
122 | - Output: :math:`(*, 3)`
123 |
124 | Example:
125 | >>> quaternion = torch.rand(2, 4) # Nx4
126 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
127 | """
128 | if not torch.is_tensor(quaternion):
129 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
130 | type(quaternion)))
131 |
132 | if not quaternion.shape[-1] == 4:
133 | raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
134 | .format(quaternion.shape))
135 | # unpack input and compute conversion
136 | q1: torch.Tensor = quaternion[..., 1]
137 | q2: torch.Tensor = quaternion[..., 2]
138 | q3: torch.Tensor = quaternion[..., 3]
139 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
140 |
141 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
142 | cos_theta: torch.Tensor = quaternion[..., 0]
143 | two_theta: torch.Tensor = 2.0 * torch.where(
144 | cos_theta < 0.0,
145 | torch.atan2(-sin_theta, -cos_theta),
146 | torch.atan2(sin_theta, cos_theta))
147 |
148 | k_pos: torch.Tensor = two_theta / sin_theta
149 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
150 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
151 |
152 | angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
153 | angle_axis[..., 0] += q1 * k
154 | angle_axis[..., 1] += q2 * k
155 | angle_axis[..., 2] += q3 * k
156 | return angle_axis
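A round-trip sketch for the conversions above (illustrative; it assumes torchgeometry's angle_axis_to_rotation_matrix, which returns N x 4 x 4 homogeneous matrices, so only the top 3 x 4 block is forwarded):

import torch
import torchgeometry as tgm

aa = torch.tensor([[0.1, -0.2, 0.3]])                     # one axis-angle vector
rot = tgm.angle_axis_to_rotation_matrix(aa)[:, :3, :]     # N x 3 x 4, as expected above
aa_back = rotation_matrix_to_angle_axis(rot)
assert torch.allclose(aa, aa_back, atol=1e-5)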
--------------------------------------------------------------------------------
/smal/batch_lbs.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import numpy as np
7 |
8 |
9 | def batch_skew(vec, batch_size=None, opts=None):
10 | """
11 | vec is N x 3, batch_size is int
12 |
13 | returns N x 3 x 3. Skew_sym version of each matrix.
14 | """
15 | if batch_size is None:
16 |         batch_size = vec.shape[0]  # vec is a torch tensor; TF's shape.as_list() does not exist here
17 | col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7])
18 | indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1])
19 | updates = torch.reshape(
20 | torch.stack(
21 | [
22 | -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1],
23 | vec[:, 0]
24 | ],
25 | dim=1), [-1])
26 | out_shape = [batch_size * 9]
27 | res = torch.Tensor(np.zeros(out_shape[0])).cuda(device=vec.device)
28 | res[np.array(indices.flatten())] = updates
29 | res = torch.reshape(res, [batch_size, 3, 3])
30 |
31 | return res
32 |
33 | def batch_rodrigues(theta, opts=None):
34 | """
35 | Theta is Nx3
36 | """
37 | batch_size = theta.shape[0]
38 |
39 | angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1)
40 | r = (torch.div(theta, angle)).unsqueeze(-1)
41 |
42 | angle = angle.unsqueeze(-1)
43 | cos = torch.cos(angle)
44 | sin = torch.sin(angle)
45 |
46 | outer = torch.matmul(r, r.transpose(1,2))
47 |
48 | eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).cuda(device=theta.device)
49 | H = batch_skew(r, batch_size=batch_size, opts=opts)
50 | R = cos * eyes + (1 - cos) * outer + sin * H
51 |
52 | return R
53 |
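A small sanity check for batch_rodrigues (illustrative sketch; it assumes a CUDA device, since batch_skew and batch_rodrigues allocate intermediate tensors on the GPU): the returned matrices should be orthonormal with determinant 1.

import torch

theta = torch.randn(4, 3, device='cuda') * 0.3             # four axis-angle vectors
R = batch_rodrigues(theta)                                  # -> 4 x 3 x 3
eye = torch.eye(3, device='cuda').expand(4, 3, 3)
assert torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5)     # R is orthonormal
assert torch.allclose(torch.det(R), torch.ones(4, device='cuda'), atol=1e-5)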
54 | def batch_lrotmin(theta):
55 | """
56 | Output of this is used to compute joint-to-pose blend shape mapping.
57 | Equation 9 in SMPL paper.
58 |
59 |
60 | Args:
61 | pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints.
62 | This includes the global rotation so K=24
63 |
64 | Returns
65 | diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted.,
66 | """
67 | # Ignore global rotation
68 | theta = theta[:,3:]
69 |
70 | Rs = batch_rodrigues(torch.reshape(theta, [-1,3]))
71 | lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207])
72 |
73 | return lrotmin
74 |
75 | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False, betas_logscale=None, opts=None):
76 | """
77 | Computes absolute joint locations given pose.
78 |
79 | rotate_base: if True, rotates the global rotation by 90 deg in x axis.
80 | if False, this is the original SMPL coordinate.
81 |
82 | Args:
83 |       Rs: N x K x 3 x 3 rotation matrices of the K joints (K = 35 for SMAL)
84 |       Js: N x K x 3, joint locations before posing
85 |       parent: vector of length K holding the parent id for each joint
86 |
87 |     Returns
88 |       new_J : `Tensor`: N x K x 3 locations of the absolute joints
89 |       A     : `Tensor`: N x K x 4 x 4 relative joint transformations for LBS.
90 | """
91 |     if rotate_base:
92 |         print('Flipping the SMPL coordinate frame!!!!')
93 |         rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]).to(Rs.device)
94 |         rot_x = rot_x.unsqueeze(0).repeat(Rs.shape[0], 1, 1)  # was tf.tile; torch has no module-level repeat
95 |         root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x)
96 | else:
97 | root_rotation = Rs[:, 0, :, :]
98 |
99 |     # Make Js N x 35 x 3 x 1 for the homogeneous transforms below
100 | Js = Js.unsqueeze(-1)
101 | N = Rs.shape[0]
102 |
103 | Js_orig = Js.clone()
104 |
105 | scaling_factors = torch.ones(N, parent.shape[0], 3).to(Rs.device)
106 | if betas_logscale is not None:
107 | leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25))
108 | tail_joints = list(range(25, 32))
109 | ear_joints = [33, 34]
110 |
111 | beta_scale_mask = torch.zeros(35, 3, 6).to(betas_logscale.device)
112 | beta_scale_mask[leg_joints, [2], [0]] = 1.0 # Leg lengthening
113 | beta_scale_mask[leg_joints, [0], [1]] = 1.0 # Leg fatness
114 | beta_scale_mask[leg_joints, [1], [1]] = 1.0 # Leg fatness
115 |
116 | beta_scale_mask[tail_joints, [0], [2]] = 1.0 # Tail lengthening
117 | beta_scale_mask[tail_joints, [1], [3]] = 1.0 # Tail fatness
118 | beta_scale_mask[tail_joints, [2], [3]] = 1.0 # Tail fatness
119 |
120 | beta_scale_mask[ear_joints, [1], [4]] = 1.0 # Ear y
121 | beta_scale_mask[ear_joints, [2], [5]] = 1.0 # Ear z
122 |
123 | beta_scale_mask = torch.transpose(
124 | beta_scale_mask.reshape(35*3, 6), 0, 1)
125 |
126 | betas_scale = torch.exp(betas_logscale @ beta_scale_mask)
127 | scaling_factors = betas_scale.reshape(-1, 35, 3)
128 |
129 | scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1)
130 |
131 | def make_A(R, t):
132 | # Rs is N x 3 x 3, ts is N x 3 x 1
133 | R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0))
134 | t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1)
135 | return torch.cat([R_homo, t_homo], 2)
136 |
137 | A0 = make_A(root_rotation, Js[:, 0])
138 | results = [A0]
139 | for i in range(1, parent.shape[0]):
140 | j_here = Js[:, i] - Js[:, parent[i]]
141 |
142 | s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]])
143 | rot = Rs[:, i]
144 | s = scale_factors_3x3[:, i]
145 |
146 | rot_new = s_par_inv @ rot @ s
147 |
148 | A_here = make_A(rot_new, j_here)
149 | res_here = torch.matmul(
150 | results[parent[i]], A_here)
151 |
152 | results.append(res_here)
153 |
154 |     # N x 35 x 4 x 4 stacked per-joint transforms
155 | results = torch.stack(results, dim=1)
156 |
157 | # scale updates
158 | new_J = results[:, :, :3, 3]
159 |
160 | # --- Compute relative A: Skinning is based on
161 | # how much the bone moved (not the final location of the bone)
162 | # but (final_bone - init_bone)
163 | # ---
164 | Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2)
165 | init_bone = torch.matmul(results, Js_w0)
166 | # Append empty 4 x 3:
167 | init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0))
168 | A = results - init_bone
169 |
170 | return new_J, A
171 |
--------------------------------------------------------------------------------
/util/joint_limits_prior.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | Ranges = {
4 | 'pelvis': [[0, 0], [0, 0], [0, 0]],
5 | 'pelvis0': [[-0.3, 0.3], [-1.2, 0.5], [-0.1, 0.1]],
6 | 'spine': [[-0.4, 0.4], [-1.0, 0.9], [-0.8, 0.8]],
7 | 'spine0': [[-0.4, 0.4], [-1.0, 0.9], [-0.8, 0.8]],
8 | 'spine1': [[-0.4, 0.4], [-0.5, 1.2], [-0.4, 0.4]],
9 | 'spine3': [[-0.5, 0.5], [-0.6, 1.4], [-0.8, 0.8]],
10 | 'spine2': [[-0.5, 0.5], [-0.4, 1.4], [-0.5, 0.5]],
11 | 'RFootBack': [[-0.2, 0.3], [-0.3, 1.1], [-0.3, 0.5]],
12 | 'LFootBack': [[-0.3, 0.2], [-0.3, 1.1], [-0.5, 0.3]],
13 | 'LLegBack1': [[-0.2, 0.3], [-0.5, 0.8], [-0.5, 0.4]],
14 | 'RLegBack1': [[-0.3, 0.2], [-0.5, 0.8], [-0.4, 0.5]],
15 | 'Head': [[-0.5, 0.5], [-1.0, 0.9], [-0.9, 0.9]],
16 | 'RLegBack2': [[-0.3, 0.2], [-0.6, 0.8], [-0.5, 0.6]],
17 | 'LLegBack2': [[-0.2, 0.3], [-0.6, 0.8], [-0.6, 0.5]],
18 | 'RLegBack3': [[-0.2, 0.3], [-0.8, 0.2], [-0.4, 0.5]],
19 | 'LLegBack3': [[-0.3, 0.2], [-0.8, 0.2], [-0.5, 0.4]],
20 | 'Mouth': [[-0.1, 0.1], [-1.1, 0.5], [-0.1, 0.1]],
21 | 'Neck': [[-0.8, 0.8], [-1.0, 1.0], [-1.1, 1.1]],
22 | 'LLeg1': [[-0.05, 0.05], [-1.3, 0.8], [-0.6, 0.6]], # Extreme
23 | 'RLeg1': [[-0.05, 0.05], [-1.3, 0.8], [-0.6, 0.6]],
24 | 'RLeg2': [[-0.05, 0.05], [-1.0, 0.9], [-0.6, 0.6]], # Extreme
25 | 'LLeg2': [[-0.05, 0.05], [-1.0, 1.1], [-0.6, 0.6]],
26 | 'RLeg3': [[-0.1, 0.4], [-0.3, 1.4], [-0.4, 0.7]], # Extreme
27 | 'LLeg3': [[-0.4, 0.1], [-0.3, 1.4], [-0.7, 0.4]],
28 | 'LFoot': [[-0.3, 0.1], [-0.4, 1.5], [-0.7, 0.3]], # Extreme
29 | 'RFoot': [[-0.1, 0.3], [-0.4, 1.5], [-0.3, 0.7]],
30 | 'Tail7': [[-0.1, 0.1], [-0.7, 1.1], [-0.9, 0.8]],
31 | 'Tail6': [[-0.1, 0.1], [-1.4, 1.4], [-1.0, 1.0]],
32 | 'Tail5': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]],
33 | 'Tail4': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]],
34 | 'Tail3': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]],
35 | 'Tail2': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]],
36 | 'Tail1': [[-0.1, 0.1], [-1.5, 1.4], [-1.2, 1.2]],
37 | }
38 |
39 |
40 | class LimitPrior(object):
41 | def __init__(self, device, n_pose=32):
42 | self.parts = {
43 | 'pelvis0': 0,
44 | 'spine': 1,
45 | 'spine0': 2,
46 | 'spine1': 3,
47 | 'spine2': 4,
48 | 'spine3': 5,
49 | 'LLeg1': 6,
50 | 'LLeg2': 7,
51 | 'LLeg3': 8,
52 | 'LFoot': 9,
53 | 'RLeg1': 10,
54 | 'RLeg2': 11,
55 | 'RLeg3': 12,
56 | 'RFoot': 13,
57 | 'Neck': 14,
58 | 'Head': 15,
59 | 'LLegBack1': 16,
60 | 'LLegBack2': 17,
61 | 'LLegBack3': 18,
62 | 'LFootBack': 19,
63 | 'RLegBack1': 20,
64 | 'RLegBack2': 21,
65 | 'RLegBack3': 22,
66 | 'RFootBack': 23,
67 | 'Tail1': 24,
68 | 'Tail2': 25,
69 | 'Tail3': 26,
70 | 'Tail4': 27,
71 | 'Tail5': 28,
72 | 'Tail6': 29,
73 | 'Tail7': 30,
74 | 'Mouth': 31
75 | }
76 | self.id2name = {v: k for k, v in self.parts.items()}
77 | # Ignore the first joint.
78 | self.prefix = 3
79 | self.postfix= 99
80 | self.part_ids = np.array(sorted(self.parts.values()))
81 | min_values = np.hstack([np.array(np.array(Ranges[self.id2name[part_id]])[:, 0]) for part_id in self.part_ids])
82 | max_values = np.hstack([
83 | np.array(np.array(Ranges[self.id2name[part_id]])[:, 1])
84 | for part_id in self.part_ids
85 | ])
86 | self.ranges = Ranges
87 | self.device = device
88 | self.min_values = torch.from_numpy(min_values).view(n_pose, 3).float().to(device)
89 | self.max_values = torch.from_numpy(max_values).view(n_pose, 3).float().to(device)
90 |
91 | def __call__(self, x):
92 | '''
93 | Given x, rel rotation of 31 joints, for each parts compute the limit value.
94 | k is steepness of the curve, max_val + margin is the midpoint of the curve (val 0.5)
95 | Using Logistic:
96 | max limit: 1/(1 + exp(k * ((max_val + margin) - x)))
97 | min limit: 1/(1 + exp(k * (x - (min_val - margin))))
98 | With max/min:
99 | minlimit: max( min_vals - x , 0 )
100 | maxlimit: max( x - max_vals , 0 )
101 | With exponential:
102 | min: exp(k * (minval - x) )
103 | max: exp(k * (x - maxval) )
104 | '''
105 | ## Max/min discontinous but fast. (flat + L2 past the limit)
106 |
107 | x = x[:, self.prefix:self.postfix].view(x.shape[0], -1, 3)
108 | zeros = torch.zeros_like(x).to(self.device)
109 | # return np.maximum(x - self.max_values, zeros) + np.maximum(self.min_values - x, zeros)
110 | return torch.mean(torch.max(x - self.max_values.unsqueeze(0), zeros) + torch.max(self.min_values.unsqueeze(0) - x, zeros))
111 |
112 |     def report(self, x):
113 |         # Per-joint violation report; __call__ only returns the mean penalty,
114 |         # so the per-part excess is recomputed here (assumes a single pose in x).
115 |         values = x[:, self.prefix:self.postfix].view(-1, 3)
116 |         res = (torch.clamp(values - self.max_values, min=0) +
117 |                torch.clamp(self.min_values - values, min=0)).detach().cpu().numpy()
118 |         values = values.detach().cpu().numpy()
119 |         bad = np.any(res > 0, axis=1)
120 |         bad_ids = np.array(self.part_ids)[bad]
121 |         np.set_printoptions(precision=3)
122 |         for bad_id in bad_ids:
123 |             name = self.id2name[bad_id]
124 |             limits = self.ranges[name]
125 |             print('%s over! Over by: %s Limits: %s Values: %s'
126 |                   % (name, res[bad_id, :], limits, values[bad_id, :]))
127 |
128 | if __name__ == '__main__':
129 | name2id33 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
130 | 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15,
131 | 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
132 | 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
133 | 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0}
134 | name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
135 | 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15,
136 | 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
137 | 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
138 | 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34}
139 |
140 | import os
141 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
142 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143 | limit_prior = LimitPrior(device, 32)
144 | for k,v in limit_prior.parts.items():
145 | id33 = name2id33[k]-1
146 | id35 = name2id35[k]-1
147 | assert id33 == id35 and id33==v
148 |     x = torch.zeros((1, 35 * 3)).float().to(device)  # batched (1 x 105) pose, as __call__ expects
149 | limit_loss = limit_prior(x)
150 | print('done')
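A short CPU-only penalty sketch for LimitPrior (illustrative): a zero pose lies inside every limit so the penalty is zero, while pushing one joint past its range makes it positive.

import torch

prior = LimitPrior(torch.device('cpu'))
pose = torch.zeros(1, 35 * 3)                   # batched, flattened axis-angle pose
assert prior(pose).item() == 0.0
pose[0, 3 + 1] = 2.0                            # pelvis0 y-rotation beyond its 0.5 upper limit
assert prior(pose).item() > 0.0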
--------------------------------------------------------------------------------
/smal/smal_torch.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | PyTorch implementation of the SMAL/SMPL model
4 |
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 | import pickle as pkl
15 | from smal.batch_lbs import batch_rodrigues, batch_global_rigid_transformation
16 | from smal.smal_basics import align_smal_template_to_symmetry_axis
17 | from util import config
18 |
19 | # There are chumpy variables so convert them to numpy.
20 | def undo_chumpy(x):
21 | return x if isinstance(x, np.ndarray) else x.r
22 |
23 | class SMAL(nn.Module):
24 | def __init__(self, device, shape_family_id=1, dtype=torch.float):
25 | super(SMAL, self).__init__()
26 |
27 | # -- Load SMPL params --
28 | # with open(pkl_path, 'r') as f:
29 | # dd = pkl.load(f)
30 |
31 | print (f"Loading SMAL with shape family: {shape_family_id}")
32 |
33 | with open(config.SMAL_FILE, 'rb') as f:
34 | u = pkl._Unpickler(f)
35 | u.encoding = 'latin1'
36 | dd = u.load()
37 |
38 | self.f = dd['f']
39 |
40 | self.faces = torch.from_numpy(self.f.astype(int)).to(device)
41 |
42 | v_template = dd['v_template']
43 |
44 | # Size of mesh [Number of vertices, 3]
45 | self.size = [v_template.shape[0], 3]
46 | self.num_betas = dd['shapedirs'].shape[-1]
47 | # Shape blend shape basis
48 |
49 | shapedir = np.reshape(
50 | undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T.copy()
51 | self.shapedirs = Variable(
52 | torch.Tensor(shapedir), requires_grad=False).to(device)
53 |
54 | with open(config.SMAL_DATA_FILE, 'rb') as f:
55 | u = pkl._Unpickler(f)
56 | u.encoding = 'latin1'
57 | data = u.load()
58 |
59 | # Zero_Betas -> V_Template -> Aligned
60 | # Zero_Betas -> V_Template -> V_Template + ShapeCluster * ShapeDirs -> Aligned
61 |
62 | # Aligned(V_T + ShapeCluster * ShapeDirs) - ShapeCluster * ShapeDirs
63 |
64 | # Select mean shape for quadruped type
65 | shape_cluster_means = data['cluster_means'][shape_family_id]
66 | self.shape_cluster_cov = data['cluster_cov'][shape_family_id]
67 | self.shape_cluster_means = torch.from_numpy(shape_cluster_means).float().to(device)
68 |
69 | v_sym, self.left_inds, self.right_inds, self.center_inds = align_smal_template_to_symmetry_axis(
70 | v_template, sym_file=config.SMAL_SYM_FILE)
71 |
72 | # Mean template vertices
73 | self.v_template = Variable(
74 | torch.Tensor(v_sym),
75 | requires_grad=False).to(device)
76 |
77 | # Regressor for joint locations given shape
78 | self.J_regressor = Variable(
79 | torch.Tensor(dd['J_regressor'].T.todense()),
80 | requires_grad=False).to(device)
81 |
82 | # Pose blend shape basis
83 | num_pose_basis = dd['posedirs'].shape[-1]
84 |
85 | posedirs = np.reshape(
86 | undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T
87 | self.posedirs = Variable(
88 | torch.Tensor(posedirs), requires_grad=False).to(device)
89 |
90 | # indices of parents for each joints
91 | self.parents = dd['kintree_table'][0].astype(np.int32)
92 |
93 | # LBS weights
94 | self.weights = Variable(
95 | torch.Tensor(undo_chumpy(dd['weights'])),
96 | requires_grad=False).to(device)
97 |
98 |
99 | def __call__(self, beta, theta, trans=None, del_v=None, betas_logscale=None, get_skin=True, v_template=None):
100 |
101 |         # number of shape coefficients provided with the input betas
102 |         nBetas = beta.shape[1]
105 |
106 |
107 | # v_template = self.v_template.unsqueeze(0).expand(beta.shape[0], 3889, 3)
108 | if v_template is None:
109 | v_template = self.v_template
110 |
111 | # 1. Add shape blend shapes
112 |
113 | if nBetas > 0:
114 | if del_v is None:
115 | v_shaped = v_template + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]])
116 | else:
117 | v_shaped = v_template + del_v + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]])
118 | else:
119 | if del_v is None:
120 | v_shaped = v_template.unsqueeze(0)
121 | else:
122 | v_shaped = v_template + del_v
123 |
124 | # 2. Infer shape-dependent joint locations.
125 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor)
126 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor)
127 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor)
128 | J = torch.stack([Jx, Jy, Jz], dim=2)
129 |
130 | # 3. Add pose blend shapes
131 |         # N x 35 x 3 x 3
132 | if len(theta.shape) == 4:
133 | Rs = theta
134 | else:
135 | Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3])
136 |
137 | # Ignore global rotation.
138 | pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(beta.device), [-1, 306])
139 |
140 | v_posed = torch.reshape(
141 | torch.matmul(pose_feature, self.posedirs),
142 | [-1, self.size[0], self.size[1]]) + v_shaped
143 |
144 | #4. Get the global joint location
145 | self.J_transformed, A = batch_global_rigid_transformation(
146 | Rs, J, self.parents, betas_logscale=betas_logscale)
147 |
148 |
149 | # 5. Do skinning:
150 | num_batch = theta.shape[0]
151 |
152 | weights_t = self.weights.repeat([num_batch, 1])
153 | W = torch.reshape(weights_t, [num_batch, -1, 35])
154 |
155 |
156 | T = torch.reshape(
157 | torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])),
158 | [num_batch, -1, 4, 4])
159 | v_posed_homo = torch.cat(
160 | [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=beta.device)], 2)
161 | v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1))
162 |
163 | verts = v_homo[:, :, :3, 0]
164 |
165 | if trans is None:
166 | trans = torch.zeros((num_batch,3)).to(device=beta.device)
167 |
168 | verts = verts + trans[:,None,:]
169 |
170 | # Get joints:
171 | joint_x = torch.matmul(verts[:, :, 0], self.J_regressor)
172 | joint_y = torch.matmul(verts[:, :, 1], self.J_regressor)
173 | joint_z = torch.matmul(verts[:, :, 2], self.J_regressor)
174 | joints = torch.stack([joint_x, joint_y, joint_z], dim=2)
175 |
176 | joints = torch.cat([
177 | joints,
178 | verts[:, None, 1863], # end_of_nose
179 | verts[:, None, 26], # chin
180 | verts[:, None, 2124], # right ear tip
181 | verts[:, None, 150], # left ear tip
182 | verts[:, None, 3055], # left eye
183 | verts[:, None, 1097], # right eye
184 | ], dim = 1)
185 |
186 | if get_skin:
187 | return verts, joints, Rs, v_shaped
188 | else:
189 | return joints
190 |
191 | def verts2joints(self, verts):
192 | # Get joints:
193 | joint_x = torch.matmul(verts[:, :, 0], self.J_regressor)
194 | joint_y = torch.matmul(verts[:, :, 1], self.J_regressor)
195 | joint_z = torch.matmul(verts[:, :, 2], self.J_regressor)
196 | joints = torch.stack([joint_x, joint_y, joint_z], dim=2)
197 |
198 | joints = torch.cat([
199 | joints,
200 | verts[:, None, 1863], # end_of_nose
201 | verts[:, None, 26], # chin
202 | verts[:, None, 2124], # right ear tip
203 | verts[:, None, 150], # left ear tip
204 | verts[:, None, 3055], # left eye
205 | verts[:, None, 1097], # right eye
206 | ], dim=1)
207 |
208 | return joints
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
/datasets/imutils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains functions that are used to perform data augmentation.
3 | """
4 | import torch
5 | import numpy as np
6 | import scipy.misc
7 | import cv2
8 | import imageio
9 | from scipy import ndimage
10 |
11 | def get_transform(center, scale, res, rot=0):
12 | """Generate transformation matrix."""
13 | h = 200 * scale
14 | t = np.zeros((3, 3))
15 | t[0, 0] = float(res[1]) / h
16 | t[1, 1] = float(res[0]) / h
17 | t[0, 2] = res[1] * (-float(center[0]) / h + .5)
18 | t[1, 2] = res[0] * (-float(center[1]) / h + .5)
19 | t[2, 2] = 1
20 | if not rot == 0:
21 | rot = -rot # To match direction of rotation from cropping
22 | rot_mat = np.zeros((3, 3))
23 | rot_rad = rot * np.pi / 180
24 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
25 | rot_mat[0, :2] = [cs, -sn]
26 | rot_mat[1, :2] = [sn, cs]
27 | rot_mat[2, 2] = 1
28 | # Need to rotate around center
29 | t_mat = np.eye(3)
30 | t_mat[0, 2] = -res[1] / 2
31 | t_mat[1, 2] = -res[0] / 2
32 | t_inv = t_mat.copy()
33 | t_inv[:2, 2] *= -1
34 | t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
35 | return t
36 |
37 |
38 | def transform(pt, center, scale, res, invert=0, rot=0):
39 | """Transform pixel location to different reference."""
40 | t = get_transform(center, scale, res, rot=rot)
41 | if invert:
42 | t = np.linalg.inv(t)
43 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
44 | new_pt = np.dot(t, new_pt)
45 | return new_pt[:2].astype(int) + 1
46 |
47 |
48 | def crop(img, center, scale, res, rot=0, border_grey_intensity=0.0):
49 | """Crop image according to the supplied bounding box."""
50 | # Upper left point
51 | ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
52 | # Bottom right point
53 | br = np.array(transform([res[0] + 1,
54 | res[1] + 1], center, scale, res, invert=1)) - 1
55 |
56 | # Padding so that when rotated proper amount of context is included
57 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
58 | if not rot == 0:
59 | ul -= pad
60 | br += pad
61 |
62 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
63 | if len(img.shape) > 2:
64 | new_shape += [img.shape[2]]
65 | new_img = np.ones(new_shape) * border_grey_intensity
66 |
67 | # Range to fill new array
68 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
69 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
70 | # Range to sample from original image
71 | old_x = max(0, ul[0]), min(len(img[0]), br[0])
72 | old_y = max(0, ul[1]), min(len(img), br[1])
73 | try:
74 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
75 | old_x[0]:old_x[1]]
76 |     except ValueError:
77 |         print('Warning: crop window does not overlap the image; returning an empty crop')
78 | if not rot == 0:
79 | # Remove padding
80 | new_img = scipy.misc.imrotate(new_img, rot)
81 | new_img = new_img[pad:-pad, pad:-pad]
82 |
83 | new_img = cv2.resize(new_img, (*res,))
84 | return new_img
85 |
86 |
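A quick usage sketch for crop (illustrative values): cut a 256 x 256 patch around a chosen center; following get_transform above, the crop covers 200 * scale pixels of the original image.

import numpy as np

img = np.random.rand(480, 640, 3)                          # H x W x 3 dummy image
patch = crop(img, center=[320, 240], scale=1.28, res=[256, 256])
assert patch.shape == (256, 256, 3)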
87 | def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
88 | """'Undo' the image cropping/resizing.
89 | This function is used when evaluating mask/part segmentation.
90 | """
91 | res = img.shape[:2]
92 | # Upper left point
93 | ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
94 | # Bottom right point
95 | br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
96 | # size of cropped image
97 | crop_shape = [br[1] - ul[1], br[0] - ul[0]]
98 |
99 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
100 | if len(img.shape) > 2:
101 | new_shape += [img.shape[2]]
102 | new_img = np.zeros(orig_shape, dtype=np.uint8)
103 | # Range to fill new array
104 | new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
105 | new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
106 | # Range to sample from original image
107 | old_x = max(0, ul[0]), min(orig_shape[1], br[0])
108 | old_y = max(0, ul[1]), min(orig_shape[0], br[1])
109 | img = scipy.misc.imresize(img, crop_shape, interp='nearest')
110 | new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
111 | return new_img
112 |
113 |
114 | def rot_aa(aa, rot):
115 | """Rotate axis angle parameters."""
116 | # pose parameters
117 | R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
118 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
119 | [0, 0, 1]])
120 | # find the rotation of the body in camera frame
121 | per_rdg, _ = cv2.Rodrigues(aa)
122 | # apply the global rotation to the global orientation
123 | resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
124 | aa = (resrot.T)[0]
125 | return aa
126 |
127 |
128 | def flip_img(img):
129 | """Flip rgb images or masks.
130 | channels come last, e.g. (256,256,3).
131 | """
132 | img = np.fliplr(img)
133 | return img
134 |
135 |
136 | def flip_kp(kp, width):
137 | """Flip keypoints."""
138 | kp[:, 0] = width - kp[:, 0]
139 | flipped_parts = [6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22,
140 | 23] # order of kp indices for L-R flipping
141 | kp = kp[flipped_parts]
142 |
143 | return kp
144 |
145 |
146 | def flip_pose(pose):
147 | """Flip pose.
148 | The flipping is based on SMPL parameters.
149 | """
150 | flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
151 | 14, 18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
152 | 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
153 | 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
154 | 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
155 | pose = pose[flippedParts]
156 | # we also negate the second and the third dimension of the axis-angle
157 | pose[1::3] = -pose[1::3]
158 | pose[2::3] = -pose[2::3]
159 | return pose
160 |
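A consistency sketch for flip_pose (illustrative; the function as written expects a 72-dimensional SMPL-style pose vector): flipping twice recovers the original parameters, since the left/right swap is an involution and the sign flips cancel.

import numpy as np

pose = np.random.randn(72)
assert np.allclose(flip_pose(flip_pose(pose.copy())), pose)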
161 |
162 | def flip_aa(aa):
163 | """Flip axis-angle representation.
164 | We negate the second and the third dimension of the axis-angle.
165 | """
166 | aa[1] = -aa[1]
167 | aa[2] = -aa[2]
168 | return aa
169 |
170 |
171 | def draw_labelmap(img, pt, sigma, type='Gaussian'):
172 | # Draw a 2D gaussian
173 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
174 | # img = to_numpy(img)
175 | img = img.numpy()
176 | # Check that any part of the gaussian is in-bounds
177 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
178 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
179 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
180 | br[0] < 0 or br[1] < 0):
181 | # If not, just return the image as is
182 | return torch.from_numpy(img).float(), 0
183 |
184 | # Generate gaussian
185 | size = 6 * sigma + 1
186 | x = np.arange(0, size, 1, float)
187 | y = x[:, np.newaxis]
188 | x0 = y0 = size // 2
189 | # The gaussian is not normalized, we want the center value to equal 1
190 | if type == 'Gaussian':
191 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
192 | elif type == 'Cauchy':
193 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
194 |
195 |
196 | # Usable gaussian range
197 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
198 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
199 | # Image range
200 | img_x = max(0, ul[0]), min(br[0], img.shape[1])
201 | img_y = max(0, ul[1]), min(br[1], img.shape[0])
202 |
203 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
204 | # return to_torch(img), 1
205 | return torch.from_numpy(img).float(), 1
206 |
207 |
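A small usage sketch for draw_labelmap (illustrative): draw a Gaussian keypoint heatmap on an empty 64 x 64 map and check that the peak equals 1 and the point is reported as visible.

import torch

heatmap = torch.zeros(64, 64)
heatmap, visible = draw_labelmap(heatmap, pt=(32, 40), sigma=2)
assert visible == 1
assert float(heatmap.max()) == 1.0              # the Gaussian is unnormalized; its center value is 1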
208 | def flip_back(flip_output):
209 | """
210 | flip output map
211 | """
212 |
213 | flipped_parts = [6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23]
214 | # flip output horizontally
215 | flip_output = fliplr(flip_output.numpy())
216 |
217 | # Change left-right parts
218 | flip_output = flip_output[:, flipped_parts, :, :]
219 |
220 | return torch.from_numpy(flip_output).float()
221 |
222 |
223 | def fliplr(x):
224 | if x.ndim == 3:
225 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
226 | elif x.ndim == 4:
227 | for i in range(x.shape[0]):
228 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
229 | return x.astype(float)
230 |
--------------------------------------------------------------------------------
/coma/mesh_sampling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import heapq
3 | import numpy as np
4 | import scipy.sparse as sp
5 | from opendr.topology import get_vert_connectivity, get_vertices_per_edge
6 | from psbody.mesh import Mesh
7 |
8 | def vertex_quadrics(mesh):
9 | """Computes a quadric for each vertex in the Mesh.
10 | Returns:
11 | v_quadrics: an (N x 4 x 4) array, where N is # vertices.
12 | """
13 |
14 | # Allocate quadrics
15 | v_quadrics = np.zeros((len(mesh.v), 4, 4,))
16 |
17 | # For each face...
18 | for f_idx in range(len(mesh.f)):
19 |
20 | # Compute normalized plane equation for that face
21 | vert_idxs = mesh.f[f_idx]
22 | verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 1]).reshape(-1, 1)))
23 | u, s, v = np.linalg.svd(verts)
24 | eq = v[-1, :].reshape(-1, 1)
25 | eq = eq / (np.linalg.norm(eq[0:3]))
26 |
27 | # Add the outer product of the plane equation to the
28 | # quadrics of the vertices for this face
29 | for k in range(3):
30 | v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq)
31 |
32 | return v_quadrics
33 |
34 |
35 | def setup_deformation_transfer(source, target, use_normals=False):
36 | rows = np.zeros(3 * target.v.shape[0])
37 | cols = np.zeros(3 * target.v.shape[0])
38 | coeffs_v = np.zeros(3 * target.v.shape[0])
39 | coeffs_n = np.zeros(3 * target.v.shape[0])
40 |
41 | nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree().nearest(target.v, True)
42 | nearest_faces = nearest_faces.ravel().astype(np.int64)
43 | nearest_parts = nearest_parts.ravel().astype(np.int64)
44 | nearest_vertices = nearest_vertices.ravel()
45 |
46 | for i in range(target.v.shape[0]):
47 | # Closest triangle index
48 | f_id = nearest_faces[i]
49 | # Closest triangle vertex ids
50 | nearest_f = source.f[f_id]
51 |
52 | # Closest surface point
53 | nearest_v = nearest_vertices[3 * i:3 * i + 3]
54 | # Distance vector to the closest surface point
55 | dist_vec = target.v[i] - nearest_v
56 |
57 | rows[3 * i:3 * i + 3] = i * np.ones(3)
58 | cols[3 * i:3 * i + 3] = nearest_f
59 |
60 | n_id = nearest_parts[i]
61 | if n_id == 0:
62 | # Closest surface point in triangle
63 | A = np.vstack((source.v[nearest_f])).T
64 | coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v)[0]
65 | elif n_id > 0 and n_id <= 3:
66 | # Closest surface point on edge
67 | A = np.vstack((source.v[nearest_f[n_id - 1]], source.v[nearest_f[n_id % 3]])).T
68 | tmp_coeffs = np.linalg.lstsq(A, target.v[i])[0]
69 | coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0]
70 | coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1]
71 | else:
72 | # Closest surface point a vertex
73 | coeffs_v[3 * i + n_id - 4] = 1.0
74 |
75 | # if use_normals:
76 | # A = np.vstack((vn[nearest_f])).T
77 | # coeffs_n[3 * i:3 * i + 3] = np.linalg.lstsq(A, dist_vec)[0]
78 |
79 | # coeffs = np.hstack((coeffs_v, coeffs_n))
80 | # rows = np.hstack((rows, rows))
81 | # cols = np.hstack((cols, source.v.shape[0] + cols))
82 | matrix = sp.csc_matrix((coeffs_v, (rows, cols)), shape=(target.v.shape[0], source.v.shape[0]))
83 | return matrix
84 |
85 |
86 | def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None):
87 | """Return a simplified version of this mesh.
88 | A Qslim-style approach is used here.
89 | :param factor: fraction of the original vertices to retain
90 | :param n_verts_desired: number of the original vertices to retain
91 | :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix
92 | """
93 |
94 | if factor is None and n_verts_desired is None:
95 | raise Exception('Need either factor or n_verts_desired.')
96 |
97 | if n_verts_desired is None:
98 | n_verts_desired = math.ceil(len(mesh.v) * factor)
99 |
100 | Qv = vertex_quadrics(mesh)
101 |
102 | # fill out a sparse matrix indicating vertex-vertex adjacency
103 | # from psbody.mesh.topology.connectivity import get_vertices_per_edge
104 | vert_adj = get_vertices_per_edge(mesh.v, mesh.f)
105 | # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v)))
106 | # for f_idx in range(len(mesh.f)):
107 | # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1
108 |
109 | vert_adj = sp.csc_matrix((vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])),
110 | shape=(len(mesh.v), len(mesh.v)))
111 | vert_adj = vert_adj + vert_adj.T
112 | vert_adj = vert_adj.tocoo()
113 |
114 | def collapse_cost(Qv, r, c, v):
115 | Qsum = Qv[r, :, :] + Qv[c, :, :]
116 | p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
117 | p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
118 |
119 | destroy_c_cost = p1.T.dot(Qsum).dot(p1)
120 | destroy_r_cost = p2.T.dot(Qsum).dot(p2)
121 | result = {
122 | 'destroy_c_cost': destroy_c_cost,
123 | 'destroy_r_cost': destroy_r_cost,
124 | 'collapse_cost': min([destroy_c_cost, destroy_r_cost]),
125 | 'Qsum': Qsum}
126 | return result
127 |
128 | # construct a queue of edges with costs
129 | queue = []
130 | for k in range(vert_adj.nnz):
131 | r = vert_adj.row[k]
132 | c = vert_adj.col[k]
133 |
134 | if r > c:
135 | continue
136 |
137 | cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost']
138 | heapq.heappush(queue, (cost, (r, c)))
139 |
140 | # decimate
141 | collapse_list = []
142 | nverts_total = len(mesh.v)
143 | faces = mesh.f.copy()
144 | while nverts_total > n_verts_desired:
145 | e = heapq.heappop(queue)
146 | r = e[1][0]
147 | c = e[1][1]
148 | if r == c:
149 | continue
150 |
151 | cost = collapse_cost(Qv, r, c, mesh.v)
152 | if cost['collapse_cost'] > e[0]:
153 | heapq.heappush(queue, (cost['collapse_cost'], e[1]))
154 | # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost'])
155 | continue
156 | else:
157 |
158 | # update old vert idxs to new one,
159 | # in queue and in face list
160 | if cost['destroy_c_cost'] < cost['destroy_r_cost']:
161 | to_destroy = c
162 | to_keep = r
163 | else:
164 | to_destroy = r
165 | to_keep = c
166 |
167 | collapse_list.append([to_keep, to_destroy])
168 |
169 | # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx
170 | np.place(faces, faces == to_destroy, to_keep)
171 |
172 | # same for queue
173 | which1 = [idx for idx in range(len(queue)) if queue[idx][1][0] == to_destroy]
174 | which2 = [idx for idx in range(len(queue)) if queue[idx][1][1] == to_destroy]
175 | for k in which1:
176 | queue[k] = (queue[k][0], (to_keep, queue[k][1][1]))
177 | for k in which2:
178 | queue[k] = (queue[k][0], (queue[k][1][0], to_keep))
179 |
180 | Qv[r, :, :] = cost['Qsum']
181 | Qv[c, :, :] = cost['Qsum']
182 |
183 | a = faces[:, 0] == faces[:, 1]
184 | b = faces[:, 1] == faces[:, 2]
185 | c = faces[:, 2] == faces[:, 0]
186 |
187 | # remove degenerate faces
188 | def logical_or3(x, y, z):
189 | return np.logical_or(x, np.logical_or(y, z))
190 |
191 | faces_to_keep = np.logical_not(logical_or3(a, b, c))
192 | faces = faces[faces_to_keep, :].copy()
193 |
194 | nverts_total = (len(np.unique(faces.flatten())))
195 |
196 | new_faces, mtx = _get_sparse_transform(faces, len(mesh.v))
197 | return new_faces, mtx
198 |
199 |
200 | def _get_sparse_transform(faces, num_original_verts):
201 | verts_left = np.unique(faces.flatten())
202 | IS = np.arange(len(verts_left))
203 | JS = verts_left
204 | data = np.ones(len(JS))
205 |
206 | mp = np.arange(0, np.max(faces.flatten()) + 1)
207 | mp[JS] = IS
208 | new_faces = mp[faces.copy().flatten()].reshape((-1, 3))
209 |
210 | ij = np.vstack((IS.flatten(), JS.flatten()))
211 | mtx = sp.csc_matrix((data, ij), shape=(len(verts_left), num_original_verts))
212 |
213 | return (new_faces, mtx)
214 |
215 |
216 | def generate_transform_matrices(mesh, factors):
217 | """Generates len(factors) meshes, each of them is scaled by factors[i] and
218 | computes the transformations between them.
219 |
220 | Returns:
221 | M: a set of meshes downsampled from mesh by a factor specified in factors.
222 | A: Adjacency matrix for each of the meshes
223 | D: Downsampling transforms between each of the meshes
224 | U: Upsampling transforms between each of the meshes
225 | """
226 |
227 | factors = map(lambda x: 1.0 / x, factors)
228 | M, A, D, U = [], [], [], []
229 | A.append(get_vert_connectivity(mesh.v, mesh.f))
230 | M.append(mesh)
231 |
232 | for factor in factors:
233 | ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor)
234 | D.append(ds_D)
235 | new_mesh_v = ds_D.dot(M[-1].v)
236 | new_mesh = Mesh(v=new_mesh_v, f=ds_f)
237 | M.append(new_mesh)
238 | A.append(get_vert_connectivity(new_mesh.v, new_mesh.f))
239 | U.append(setup_deformation_transfer(M[-1], M[-2]))
240 |
241 | return M, A, D, U
--------------------------------------------------------------------------------
/util/nmr.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import scipy.misc
7 | import tqdm
8 |
9 | # import chainer
10 | import torch
11 |
12 | import neural_renderer
13 |
14 | from util import geom_utils
15 | import pickle as pkl
16 |
17 | from util import config
18 |
19 | #############
20 | ### Utils ###
21 | #############
22 | def convert_as(src, trg):
23 | src = src.type_as(trg)
24 | if src.is_cuda:
25 | src = src.cuda(device=trg.get_device())
26 | return src
27 |
28 | ########################################################################
29 | ############ Wrapper class for the chainer Neural Renderer #############
30 | ##### All functions must only use numpy arrays as inputs/outputs #######
31 | ########################################################################
32 | # class NMR(object):
33 | # def __init__(self):
34 | # # setup renderer
35 | # renderer = neural_renderer.Renderer()
36 | # self.renderer = renderer
37 |
38 | # def to_gpu(self, device=0):
39 | # # self.renderer.to_gpu(device)
40 | # self.cuda_device = device
41 |
42 | # def forward_mask(self, vertices, faces):
43 | # ''' Renders masks.
44 | # Args:
45 | # vertices: B X N X 3 numpy array
46 | # faces: B X F X 3 numpy array
47 | # Returns:
48 | # masks: B X 256 X 256 numpy array
49 | # '''
50 | # # self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device))
51 | # # self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device))
52 | # # self.masks = self.renderer.render_silhouettes(self.vertices, self.faces)
53 |
54 | # masks = self.renderer.render_silhouettes(vertices, faces)
55 |
56 | # # masks = self.masks.data.get()
57 | # return masks
58 |
59 | # def backward_mask(self, grad_masks):
60 | # ''' Compute gradient of vertices given mask gradients.
61 | # Args:
62 | # grad_masks: B X 256 X 256 numpy array
63 | # Returns:
64 | # grad_vertices: B X N X 3 numpy array
65 | # '''
66 | # self.masks.grad = chainer.cuda.to_gpu(grad_masks, self.cuda_device)
67 | # self.masks.backward()
68 | # return self.vertices.grad.get()
69 |
70 | # def forward_img(self, vertices, faces, textures):
71 | # ''' Renders masks.
72 | # Args:
73 | # vertices: B X N X 3 numpy array
74 | # faces: B X F X 3 numpy array
75 | # textures: B X F X T X T X T X 3 numpy array
76 | # Returns:
77 | # images: B X 3 x 256 X 256 numpy array
78 | # '''
79 | # self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device))
80 | # self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device))
81 | # self.textures = chainer.Variable(chainer.cuda.to_gpu(textures, self.cuda_device))
82 | # self.images = self.renderer.render(self.vertices, self.faces, self.textures)
83 |
84 | # images = self.images.data.get()
85 | # return images
86 |
87 |
88 | # def backward_img(self, grad_images):
89 | # ''' Compute gradient of vertices given image gradients.
90 | # Args:
91 | # grad_images: B X 3? X 256 X 256 numpy array
92 | # Returns:
93 | # grad_vertices: B X N X 3 numpy array
94 | # grad_textures: B X F X T X T X T X 3 numpy array
95 | # '''
96 | # self.images.grad = chainer.cuda.to_gpu(grad_images, self.cuda_device)
97 | # self.images.backward()
98 | # return self.vertices.grad.get(), self.textures.grad.get()
99 |
100 | ########################################################################
101 | ################# Wrapper class a rendering PythonOp ###################
102 | ##### All functions must only use torch Tensors as inputs/outputs ######
103 | ########################################################################
104 | # class Render(torch.autograd.Function):
105 | # # TODO(Shubham): Make sure the outputs/gradients are on the GPU
106 | # def __init__(self, renderer):
107 | # super(Render, self).__init__()
108 | # self.renderer = renderer
109 |
110 | # def forward(self, vertices, faces, textures=None):
111 | # # B x N x 3
112 | # # Flipping the y-axis here to make it align with the image coordinate system!
113 | # vs = vertices.cpu().numpy()
114 | # vs[:, :, 1] *= -1
115 | # fs = faces.cpu().numpy()
116 | # if textures is None:
117 | # self.mask_only = True
118 | # masks = self.renderer.forward_mask(vs, fs)
119 | # return convert_as(torch.Tensor(masks), vertices)
120 | # else:
121 | # self.mask_only = False
122 | # ts = textures.cpu().numpy()
123 | # imgs = self.renderer.forward_img(vs, fs, ts)
124 | # return convert_as(torch.Tensor(imgs), vertices)
125 |
126 | # def backward(self, grad_out):
127 | # g_o = grad_out.cpu().numpy()
128 | # if self.mask_only:
129 | # grad_verts = self.renderer.backward_mask(g_o)
130 | # grad_verts = convert_as(torch.Tensor(grad_verts), grad_out)
131 | # grad_tex = None
132 | # else:
133 | # grad_verts, grad_tex = self.renderer.backward_img(g_o)
134 | # grad_verts = convert_as(torch.Tensor(grad_verts), grad_out)
135 | # grad_tex = convert_as(torch.Tensor(grad_tex), grad_out)
136 |
137 | # grad_verts[:, :, 1] *= -1
138 | # return grad_verts, None, grad_tex
139 |
140 |
141 | ########################################################################
142 | ############## Wrapper torch module for Neural Renderer ################
143 | ########################################################################
144 | class NeuralRenderer(torch.nn.Module):
145 |     """
146 |     Core PyTorch wrapper around the neural_renderer Renderer.
147 |     The chainer-based NMR wrappers above are kept only as commented-out reference.
148 |     Only one forward/backward pass per iteration.
149 |     """
150 | def __init__(self, img_size=256, proj_type='perspective', norm_f=1., norm_z=0.,norm_f0=0.,
151 | render_rgb=False, device=None):
152 | super(NeuralRenderer, self).__init__()
153 | # self.renderer = NMR()
154 |
155 | self.renderer = neural_renderer.Renderer(camera_mode='look_at', background_color=[1,1,1])
156 |
157 | # self.renderer = nr.Renderer(camera_mode='look_at')
158 |
159 | self.norm_f = norm_f
160 | self.norm_f0 = norm_f0
161 | self.norm_z = norm_z
162 |
163 | # Adjust the core renderer
164 | self.renderer.image_size = img_size
165 | self.renderer.perspective = False
166 |
167 |         # Set a default camera to be at (0, 0, -1.0)
168 | self.renderer.eye = [0, 0, -1.0]
169 |
170 | # Make it a bit brighter for vis
171 | self.renderer.light_intensity_ambient = 0.8
172 |
173 | # self.renderer.to_gpu()
174 |
175 | # Silvia
176 | if proj_type == 'perspective':
177 | self.proj_fn = geom_utils.perspective_proj_withz
178 |         else:
179 |             # only the perspective projection is implemented
180 |             raise ValueError('unknown projection type: {}'.format(proj_type))
181 |
182 | self.offset_z = -1.0
183 | self.textures = torch.ones(7774, 4, 4, 4, 3) * torch.FloatTensor(config.MESH_COLOR) / 255.0 # light blue
184 | # self.textures = self.textures.cuda()
185 | self.textures = self.textures.to(device)
186 | self.render_rgb = render_rgb
187 |
188 | def ambient_light_only(self):
189 | # Make light only ambient.
190 | self.renderer.light_intensity_ambient = 1
191 | self.renderer.light_intensity_directional = 0
192 |
193 | def directional_light_only(self):
194 |         # Add a directional light (ambient light is kept on for visibility).
195 | self.renderer.light_intensity_ambient = 0.8
196 | self.renderer.light_intensity_directional = 0.8
197 | self.renderer.light_direction = [0, 1, 0] # up-to-down, this is the default
198 |
199 | def set_bgcolor(self, color):
200 | self.renderer.background_color = color
201 |
202 | def project_points(self, verts, cams, normalize_kpts=False):
203 | proj = self.proj_fn(verts, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z,
204 | norm_f0=self.norm_f0)
205 | image_size_half = cams[0, 1]
206 | proj_points = proj[:, :, :2]
207 | if not normalize_kpts: # output 2d keypoint in the image coordinate
208 | proj_points = (proj_points[:, :, :2] + 1) * image_size_half
209 | proj_points = torch.stack([
210 | image_size_half * 2 - proj_points[:, :, 1],
211 | proj_points[:, :, 0]], dim=-1)
212 | else: # output 2d keypoint in the normalized image coordinate
213 | proj_points = torch.stack([
214 | proj_points[:, :, 0],
215 | -proj_points[:, :, 1]], dim=-1)
216 |
217 | return proj_points
218 |
219 | def forward(self, vertices, faces, cams, textures=None):
220 | verts = self.proj_fn(vertices, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z, norm_f0=self.norm_f0)
221 | if textures is not None:
222 | if self.render_rgb:
223 | img = self.renderer.render(verts, faces, textures)
224 | else:
225 | img = None
226 | sil = self.renderer.render_silhouettes(verts, faces)
227 | return img, sil
228 |
229 | else:
230 | textures = self.textures.unsqueeze(0).expand(verts.shape[0], -1, -1, -1, -1, -1)
231 | if self.render_rgb:
232 | img = self.renderer.render(verts, faces, textures)
233 | else:
234 | img = None
235 | sil = self.renderer.render_silhouettes(verts, faces)
236 | return img, sil
237 |
238 |
239 |
240 |
--------------------------------------------------------------------------------
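A minimal, self-contained sketch (illustrative only, not part of the repository; the helper name unnormalize_keypoints is hypothetical) of the coordinate mapping NeuralRenderer.project_points applies when normalize_kpts=False: the normalized projection in roughly [-1, 1] is scaled by half the image size, then the axes are flipped/swapped to match the renderer's image convention.

import torch

def unnormalize_keypoints(proj_norm, image_size):
    # Mirrors the un-normalization step in NeuralRenderer.project_points
    # (normalize_kpts=False); proj_norm is B x N x 2 in roughly [-1, 1].
    half = image_size / 2.0
    pix = (proj_norm + 1.0) * half                        # to [0, image_size]
    # flip one axis and swap the order, as done in the renderer
    return torch.stack([image_size - pix[:, :, 1], pix[:, :, 0]], dim=-1)

kp_norm = torch.tensor([[[0.0, 0.0], [0.5, -0.5]]])       # B=1, N=2
print(unnormalize_keypoints(kp_norm, 256))                # the centre maps to (128, 128)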
/datasets/stanford.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import division
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torchvision.transforms import Normalize
7 | import numpy as np
8 | import cv2
9 | from os.path import join
10 | import os
11 | import json
12 |
13 | from util import config
14 | from datasets.imutils import crop, flip_img, flip_pose, flip_kp, transform, rot_aa
15 | from pycocotools.mask import decode as decode_RLE
16 |
17 |
18 | def seg_from_anno(entry):
19 | """Given a .json entry, returns the binary mask as a numpy array"""
20 |
21 | rle = {
22 | "size": [entry['img_height'], entry['img_width']],
23 | "counts": entry['seg']
24 | }
25 |
26 | decoded = decode_RLE(rle)
27 | return decoded
28 |
29 |
30 | class BaseDataset(Dataset):
31 | """
32 | Base Dataset Class - Handles data loading and augmentation.
33 | Able to handle heterogeneous datasets (different annotations available for different datasets).
34 | You need to update the path to each dataset in utils/config.py.
35 | """
36 |
37 | def __init__(self,
38 | dataset,
39 | param_dir=None,
40 | use_augmentation=True,
41 | is_train=True,
42 | img_res=224):
43 |
44 | super(BaseDataset, self).__init__()
45 | self.dataset = dataset
46 | self.param_dir = param_dir
47 | self.is_train = is_train
48 | BASE_FOLDER = config.DATASET_FOLDERS[dataset]
49 |
50 | self.img_dir = os.path.join(BASE_FOLDER, 'Images')
51 | self.jsonfile = os.path.join(BASE_FOLDER, config.JSON_NAME[dataset]) # accessing new version of keypoints.json
52 |         # load the annotation file (train/test split indices are loaded below)
53 | with open(self.jsonfile) as anno_file:
54 | self.anno = json.load(anno_file)
55 |
56 | self.data_idx = np.load(os.path.join(config.DATASET_FILES[is_train][dataset]))
57 | print("Number of images: {}".format(len(self.data_idx)))
58 | # self.options = options
59 | self.normalize_img = Normalize(
60 | mean=config.IMG_NORM_MEAN, std=config.IMG_NORM_STD)
61 |
62 |         self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor]
63 |         self.noise_factor = 0.4 # Per-channel pixel noise multiplier in the range [1-noise_factor, 1+noise_factor]
64 | self.scale_factor = 0.25 # Rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
65 | self.img_res = img_res
66 |
67 | # If False, do not do augmentation
68 | self.use_augmentation = use_augmentation
69 |
70 | def augm_params(self):
71 | """Get augmentation parameters."""
72 | flip = 0 # flipping
73 | pn = np.ones(3) # per channel pixel-noise
74 | rot = 0 # rotation
75 | sc = 1 # scaling
76 | if self.is_train and self.use_augmentation:
77 | # We flip with probability 1/2
78 | # if np.random.uniform() <= 0.5:
79 | # flip = 1
80 |
81 | # Each channel is multiplied with a number
82 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
83 | # pn = np.random.uniform(1-self.options.noise_factor, 1+self.options.noise_factor, 3)
84 | pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)
85 |
86 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
87 | # rot = min(2*self.options.rot_factor,
88 | # max(-2*self.options.rot_factor, np.random.randn()*self.options.rot_factor))
89 |
90 | rot = min(2 * self.rot_factor,
91 | max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))
92 |
93 | # The scale is multiplied with a number
94 | # in the area [1-scaleFactor,1+scaleFactor]
95 | # sc = min(1+self.options.scale_factor,
96 | # max(1-self.options.scale_factor, np.random.randn()*self.options.scale_factor+1))
97 |
98 | sc = min(1 + self.scale_factor,
99 | max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))
100 | # but it is zero with probability 3/5
101 | if np.random.uniform() <= 0.6:
102 | rot = 0
103 |
104 | return flip, pn, rot, sc
105 |
106 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn, border_grey_intensity=0.0):
107 | """Process rgb image and do augmentation."""
108 | rgb_img = crop(rgb_img, center, scale,
109 | [self.img_res, self.img_res], rot=rot,
110 | border_grey_intensity=border_grey_intensity)
111 |
112 | # flip the image
113 | if flip:
114 | rgb_img = flip_img(rgb_img)
115 | # in the rgb image we add pixel noise in a channel-wise manner
116 | rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0]))
117 | rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1]))
118 | rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2]))
119 | # (3,224,224),float,[0,1]
120 | rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
121 | return rgb_img
122 |
123 | def j2d_processing(self, kp, center, scale, r, f):
124 | """Process gt 2D keypoints and apply all augmentation transforms."""
125 | nparts = kp.shape[0]
126 | for i in range(nparts):
127 | kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
128 | [self.img_res, self.img_res], rot=r)
129 | # flip the x coordinates
130 | if f:
131 | kp = flip_kp(kp, config.IMG_RES)
132 | kp_norm = kp.copy()
133 | # convert to normalized coordinates
134 | kp_norm[:, :-1] = 2. * kp_norm[:, :-1] / self.img_res - 1.
135 |
136 | kp = kp.astype('float32')
137 | return kp, kp_norm.astype('float32')
138 |
139 | def j3d_processing(self, S, r, f):
140 | """Process gt 3D keypoints and apply all augmentation transforms."""
141 | # in-plane rotation
142 | rot_mat = np.eye(3)
143 | if not r == 0:
144 | rot_rad = -r * np.pi / 180
145 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
146 | rot_mat[0, :2] = [cs, -sn]
147 | rot_mat[1, :2] = [sn, cs]
148 | S = np.einsum('ij,kj->ki', rot_mat, S)
149 | # flip the x coordinates
150 | if f:
151 | S = flip_kp(S)
152 | S = S.astype('float32')
153 | return S
154 |
155 | def pose_processing(self, pose, r, f):
156 | """Process SMPL theta parameters and apply all augmentation transforms."""
157 |         # rotation of the global pose parameters
158 | pose[:3] = rot_aa(pose[:3], r)
159 | # flip the pose parameters
160 | if f:
161 | pose = flip_pose(pose)
162 | # (72),float
163 | pose = pose.astype('float32')
164 | return pose
165 |
166 | def __getitem__(self, index):
167 | idx = index
168 | img_idx = self.data_idx[idx]
169 | a = self.anno[img_idx]
170 |
171 | # Get augmentation parameters
172 | flip, pn, rot, sc = self.augm_params()
173 |
174 | imgname_raw = a['img_path']
175 |
176 | # Load image
177 | imgname = join(self.img_dir, imgname_raw)
178 |
179 |         # Some datasets store \ instead of /, so convert
180 | imgname = imgname.replace("\\", "/")
181 | filename = imgname_raw.split('/')[1].split('.')[0] if self.dataset == 'stanford' else imgname_raw.split('.')[0]
182 | assert os.path.exists(imgname), "Cannot find image: {0}".format(imgname)
183 | img = cv2.imread(imgname)[:, :, ::-1].copy().astype('float32')
184 |
185 | seg = seg_from_anno(a) # (H, W) bool
186 | seg = seg.astype('float32') * 255.
187 | assert (np.nonzero(seg)[0].shape[0] > 1)
188 | seg = np.dstack([seg, seg, seg]) # (H, W, 3) as float
189 | x0, y0, width, height = a['img_bbox']
190 |
191 | scaleFactor = 1.2
192 | scale = scaleFactor * max(width, height) / 200
193 |
194 | kp_S24 = np.array(a['joints'])
195 | center = np.array([x0 + width / 2, y0 + height / 2]) # Center of dog
196 |
197 | kp_S24, kp_S24_norm = self.j2d_processing(kp_S24, center, sc * scale, rot, flip)
198 |
199 | kp_S24 = torch.from_numpy(kp_S24)
200 | kp_S24_norm = torch.from_numpy(kp_S24_norm)
201 | img_crop = self.rgb_processing(img, center, sc * scale, rot, flip, pn, border_grey_intensity=255.0)
202 | seg_crop = self.rgb_processing(seg, center, sc * scale, rot, flip,
203 | np.array([1.0, 1.0, 1.0])) # No pixel noise multiplier
204 |
205 | item = {}
206 |
207 | item['has_pose_3d'] = False
208 | item['has_smpl'] = False
209 | item['keypoints_3d'] = np.zeros((24, 4))
210 |
211 | item['pred_pose'] = np.zeros((105))
212 | item['pred_shape'] = np.zeros((26))
213 | item['pred_camera'] = np.zeros((3))
214 | item['pred_trans'] = np.zeros((3))
215 |
216 | if self.param_dir is not None:
217 | inp_path = imgname_raw.replace("/", "_").replace(".jpg", ".npz")
218 | if self.dataset == 'animal_pose':
219 | inp_path = "images_{0}".format(inp_path)
220 |
221 | with np.load(os.path.join(self.param_dir, inp_path)) as f:
222 | item['pred_pose'] = f.f.pose
223 | item['pred_shape'] = f.f.betas
224 | item['pred_camera'] = f.f.camera
225 | item['pred_trans'] = f.f.trans
226 |
227 | item['imgname'] = imgname
228 | item['keypoints_norm'] = kp_S24_norm[config.EVAL_KEYPOINTS]
229 | item['keypoints'] = kp_S24[config.EVAL_KEYPOINTS]
230 | item['scale'] = float(sc * scale)
231 | item['center'] = center.astype('float32')
232 | item['index'] = img_idx
233 |
234 | img_crop = torch.from_numpy(img_crop).float()
235 | seg_crop = torch.from_numpy(seg_crop[[0]]).float() # [3, h, w] -> [1, h, w]
236 |
237 | item['img_orig'] = img_crop.clone()
238 | item['img'] = self.normalize_img(img_crop)
239 | item['img_border_mask'] = torch.all(img_crop < 1.0, dim=0).unsqueeze(0).float()
240 | item['seg'] = seg_crop.clone()
241 | item['has_seg'] = True
242 | item['dataset'] = self.dataset
243 | item['filename'] = filename
244 |
245 | return item
246 |
247 | def __len__(self):
248 | return len(self.data_idx)
249 |
--------------------------------------------------------------------------------
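A small, self-contained sketch (NumPy only; the helper name and the example box are illustrative) of how BaseDataset.__getitem__ above turns an (x0, y0, width, height) bounding box from the annotation file into the (center, scale) pair consumed by crop(): the box is enlarged by a factor of 1.2 and normalized by 200 pixels.

import numpy as np

def bbox_to_center_scale(bbox, scale_factor=1.2, ref_size=200.0):
    # Mirrors the bounding-box handling in BaseDataset.__getitem__.
    x0, y0, width, height = bbox
    center = np.array([x0 + width / 2.0, y0 + height / 2.0])  # centre of the dog
    scale = scale_factor * max(width, height) / ref_size      # crop() expects scale * 200 pixels
    return center, scale

center, scale = bbox_to_center_scale((50, 80, 300, 200))
print(center, scale)  # [200. 180.] 1.8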
/util/net_blocks.py:
--------------------------------------------------------------------------------
1 | '''
2 | CNN building blocks.
3 | Taken from https://github.com/shubhtuls/factored3d/
4 | '''
5 | from __future__ import division
6 | from __future__ import print_function
7 | import torch
8 | import torch.nn as nn
9 | import math
10 |
11 |
12 | class Flatten(nn.Module):
13 | def forward(self, x):
14 | return x.view(x.size()[0], -1)
15 |
16 |
17 | class Unsqueeze(nn.Module):
18 | def __init__(self, dim):
19 | super(Unsqueeze, self).__init__()
20 | self.dim = dim
21 |
22 | def forward(self, x):
23 | return x.unsqueeze(self.dim)
24 |
25 |
26 | ## fc layers
27 | def fc(norm_type, nc_inp, nc_out):
28 | if norm_type == 'batch':
29 | return nn.Sequential(
30 | nn.Linear(nc_inp, nc_out, bias=True),
31 | nn.BatchNorm1d(nc_out),
32 | nn.LeakyReLU(0.2, inplace=True)
33 | )
34 | else:
35 | return nn.Sequential(
36 | nn.Linear(nc_inp, nc_out),
37 | nn.LeakyReLU(0.1, inplace=True)
38 | )
39 |
40 |
41 | def fc_stack(nc_inp, nc_out, nlayers, norm_type='batch'):
42 | modules = []
43 | for l in range(nlayers):
44 | modules.append(fc(norm_type, nc_inp, nc_out))
45 | nc_inp = nc_out
46 | encoder = nn.Sequential(*modules)
47 | net_init(encoder)
48 | return encoder
49 |
50 |
51 | def fc_stack_dropout(nc_inp, nc_out, nlayers):
52 | modules = []
53 | modules.append(nn.Linear(nc_inp, 1024, bias=True))
54 | modules.append(nn.ReLU())
55 | modules.append(nn.Dropout())
56 | modules.append(nn.Linear(1024, 1024, bias=True))
57 | modules.append(nn.ReLU())
58 | modules.append(nn.Dropout())
59 | modules.append(nn.Linear(1024, nc_out, bias=True))
60 |
61 | encoder = nn.Sequential(*modules)
62 | net_init(encoder)
63 | nl = 1
64 | for m in encoder.modules():
65 | if isinstance(m, nn.Linear):
66 | if nl == nlayers:
67 |                 torch.nn.init.xavier_normal_(m.weight, gain=0.01)
68 |             else:
69 |                 torch.nn.init.xavier_normal_(m.weight)
70 | if m.bias is not None:
71 | m.bias.data.zero_()
72 | nl += 1
73 |
74 | return encoder
75 |
76 |
77 | ## 2D convolution layers
78 | def conv2d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2):
79 | if norm_type == 'batch':
80 | return nn.Sequential(
81 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
82 | bias=True),
83 | nn.BatchNorm2d(out_planes),
84 | nn.LeakyReLU(0.2, inplace=True)
85 | )
86 | elif norm_type == 'group':
87 | return nn.Sequential(
88 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
89 | bias=True),
90 | nn.GroupNorm(num_groups, out_planes),
91 | nn.LeakyReLU(0.2, inplace=True)
92 | )
93 | else:
94 | return nn.Sequential(
95 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
96 | bias=True),
97 | nn.LeakyReLU(0.2, inplace=True)
98 | )
99 |
100 |
101 | def deconv2d(in_planes, out_planes):
102 | return nn.Sequential(
103 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
104 | nn.LeakyReLU(0.2, inplace=True)
105 | )
106 |
107 |
108 | def upconv2d(in_planes, out_planes, mode='bilinear'):
109 | if mode == 'nearest':
110 | print('Using NN upsample!!')
111 | upconv = nn.Sequential(
112 | nn.Upsample(scale_factor=2, mode=mode),
113 | nn.ReflectionPad2d(1),
114 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0),
115 | nn.LeakyReLU(0.2, inplace=True)
116 | )
117 | return upconv
118 |
119 |
120 | def decoder2d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True,
121 | use_deconv=False, upconv_mode='bilinear', num_groups=2):
122 |     ''' Simple 2D decoder with nlayers.
123 | 
124 |     Args:
125 |         nlayers: number of decoder layers
126 |         nz_shape: size of the bottleneck feature
127 |         nc_input: number of channels to start upconvolution from
128 |         norm_type: normalization layer ('batch', 'group' or none)
129 |         nc_final: number of output channels
130 |         nc_min: minimum number of channels
131 |         nc_step: halve the number of channels every nc_step layers
132 |         init_fc: initial features are not spatial, use an fc & unsqueezing to make them spatial
133 |     '''
134 | modules = []
135 | if init_fc:
136 | modules.append(fc('batch', nz_shape, nc_input))
137 |         for d in range(2):  # add two singleton spatial dims: B x C -> B x C x 1 x 1
138 | modules.append(Unsqueeze(2))
139 | nc_output = nc_input
140 | for nl in range(nlayers):
141 | if (nl % nc_step == 0) and (nc_output // 2 >= nc_min):
142 | nc_output = nc_output // 2
143 | if use_deconv:
144 | print('Using deconv decoder!')
145 | modules.append(deconv2d(nc_input, nc_output))
146 | nc_input = nc_output
147 | modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups // 2))
148 | else:
149 | modules.append(upconv2d(nc_input, nc_output, mode=upconv_mode))
150 | nc_input = nc_output
151 | modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups // 2))
152 |
153 | modules.append(nn.Conv2d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True))
154 | decoder = nn.Sequential(*modules)
155 | net_init(decoder)
156 | return decoder
157 |
158 |
159 | ## 3D convolution layers
160 | def conv3d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2):
161 | if norm_type == 'batch':
162 | return nn.Sequential(
163 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
164 | bias=True),
165 | nn.BatchNorm3d(out_planes),
166 | nn.LeakyReLU(0.2, inplace=True)
167 | )
168 | elif norm_type == 'group':
169 | return nn.Sequential(
170 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
171 | bias=True),
172 | nn.GroupNorm(num_groups, out_planes),
173 | nn.LeakyReLU(0.2, inplace=True)
174 | )
175 | else:
176 | return nn.Sequential(
177 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
178 | bias=True),
179 | nn.LeakyReLU(0.2, inplace=True)
180 | )
181 |
182 |
183 | def deconv3d(norm_type, in_planes, out_planes, num_groups=2):
184 | if norm_type == 'batch':
185 | return nn.Sequential(
186 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
187 | nn.BatchNorm3d(out_planes),
188 | nn.LeakyReLU(0.2, inplace=True)
189 | )
190 | elif norm_type == 'group':
191 | return nn.Sequential(
192 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
193 | nn.GroupNorm(num_groups, out_planes),
194 | nn.LeakyReLU(0.2, inplace=True)
195 | )
196 | else:
197 | return nn.Sequential(
198 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
199 | nn.LeakyReLU(0.2, inplace=True)
200 | )
201 |
202 |
203 | ## 3D Network Modules
204 | def encoder3d(nlayers, norm_type='batch', nc_input=1, nc_max=128, nc_l1=8, nc_step=1, nz_shape=20):
205 | ''' Simple 3D encoder with nlayers.
206 |
207 | Args:
208 | nlayers: number of encoder layers
209 |         norm_type: normalization layer ('batch', 'group' or none)
210 | nc_input: number of input channels
211 | nc_max: number of max channels
212 | nc_l1: number of channels in layer 1
213 | nc_step: double number of channels every nc_step layers
214 | nz_shape: size of bottleneck layer
215 | '''
216 | modules = []
217 | nc_output = nc_l1
218 | for nl in range(nlayers):
219 | if (nl >= 1) and (nl % nc_step == 0) and (nc_output <= nc_max * 2):
220 | nc_output *= 2
221 |
222 | modules.append(conv3d(norm_type, nc_input, nc_output, stride=1))
223 | nc_input = nc_output
224 | modules.append(conv3d(norm_type, nc_input, nc_output, stride=1))
225 | modules.append(torch.nn.MaxPool3d(kernel_size=2, stride=2))
226 |
227 | modules.append(Flatten())
228 | modules.append(fc_stack(nc_output, nz_shape, 2, norm_type))
229 | encoder = nn.Sequential(*modules)
230 | net_init(encoder)
231 | return encoder, nc_output
232 |
233 |
234 | def decoder3d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True):
235 |     ''' Simple 3D decoder with nlayers.
236 | 
237 |     Args:
238 |         nlayers: number of decoder layers
239 |         nz_shape: size of the bottleneck feature
240 |         nc_input: number of channels to start upconvolution from
241 |         norm_type: normalization layer ('batch', 'group' or none)
242 |         nc_final: number of output channels
243 |         nc_min: minimum number of channels
244 |         nc_step: halve the number of channels every nc_step layers
245 |         init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D
246 | '''
247 | modules = []
248 | if init_fc:
249 | modules.append(fc('batch', nz_shape, nc_input))
250 | for d in range(3):
251 | modules.append(Unsqueeze(2))
252 | nc_output = nc_input
253 | for nl in range(nlayers):
254 | if (nl % nc_step == 0) and (nc_output // 2 >= nc_min):
255 | nc_output = nc_output // 2
256 |
257 | modules.append(deconv3d(norm_type, nc_input, nc_output))
258 | nc_input = nc_output
259 | modules.append(conv3d(norm_type, nc_input, nc_output))
260 |
261 | modules.append(nn.Conv3d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True))
262 | decoder = nn.Sequential(*modules)
263 | net_init(decoder)
264 | return decoder
265 |
266 |
267 | def net_init(net):
268 | for m in net.modules():
269 | if isinstance(m, nn.Linear):
270 | m.weight.data.normal_(0, 0.02)
271 | if m.bias is not None:
272 | m.bias.data.zero_()
273 |
274 | if isinstance(m, nn.Conv2d): # or isinstance(m, nn.ConvTranspose2d):
275 | m.weight.data.normal_(0, 0.02)
276 | if m.bias is not None:
277 | m.bias.data.zero_()
278 |
279 | if isinstance(m, nn.ConvTranspose2d):
280 | # Initialize Deconv with bilinear weights.
281 | base_weights = bilinear_init(m.weight.data.size(-1))
282 | base_weights = base_weights.unsqueeze(0).unsqueeze(0)
283 | m.weight.data = base_weights.repeat(m.weight.data.size(0), m.weight.data.size(1), 1, 1)
284 | if m.bias is not None:
285 | m.bias.data.zero_()
286 |
287 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
288 | m.weight.data.normal_(0, 0.02)
289 | if m.bias is not None:
290 | m.bias.data.zero_()
291 |
292 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
293 | m.weight.data.fill_(1)
294 | m.bias.data.zero_()
295 |
296 |
297 | def bilinear_init(kernel_size=4):
298 | # Following Caffe's BilinearUpsamplingFiller
299 | # https://github.com/BVLC/caffe/pull/2213/files
300 | import numpy as np
301 | width = kernel_size
302 | height = kernel_size
303 | f = int(np.ceil(width / 2.))
304 | cc = (2 * f - 1 - f % 2) / (2. * f)
305 | weights = torch.zeros((height, width))
306 | for y in range(height):
307 | for x in range(width):
308 | weights[y, x] = (1 - np.abs(x / f - cc)) * (1 - np.abs(y / f - cc))
309 |
310 | return weights
311 |
312 |
313 | if __name__ == '__main__':
314 | decoder2d(5, None, 256, use_deconv=True, init_fc=False)
315 | bilinear_init()
316 |
--------------------------------------------------------------------------------
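A quick shape check for the 2D decoder above (a sketch, not part of the repository; it assumes the repository root is on PYTHONPATH so util.net_blocks imports, and it uses the default bilinear up-convolutions rather than the deconv path of the __main__ smoke test): each of the five stages doubles the spatial resolution and halves the channel count down to nc_min.

import torch
from util import net_blocks as nb

dec = nb.decoder2d(5, None, 256, init_fc=False)
dec.eval()                        # keep batch-norm statistics fixed for the demo
x = torch.randn(2, 256, 4, 4)     # B x C x H x W latent feature map
with torch.no_grad():
    y = dec(x)
print(y.shape)                    # expected: torch.Size([2, 1, 128, 128])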
/util/helpers/visualize.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from util.helpers.draw_smal_joints import SMALJointDrawer
4 | import numpy as np
5 | import plotly.graph_objects as go
6 |
7 | # generate visualizations, code adapted from WLDO
8 | class Visualizer():
9 |
10 | @staticmethod
11 | def generate_output_figures(preds, vis_refine=False):
12 | '''
13 | Args:
14 | preds: predictions from model
15 | vis_refine: whether return refined visualizations
16 |
17 | Returns: visualization with predicted 2D keypoints and silhouettes
18 | '''
19 | marker_size = 8
20 | thickness = 4
21 | smal_drawer = SMALJointDrawer()
22 | real_rgb_vis = preds['img_orig']
23 | keypoints_gt = preds['keypoints']
24 | real_rgb_vis_kp = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks'],
25 | visible=preds['keypoints'][:, :, [2]],
26 | marker_size=marker_size,
27 | thickness=thickness, normalized=False)
28 | pic_on_pic = preds['img_orig'].cpu() * (1-preds['synth_silhouettes'].cpu()) + \
29 | preds['synth_xyz'].cpu() * preds['synth_silhouettes'].cpu()
30 | pic_npy = pic_on_pic.data.cpu()
31 |
32 | sil_err = torch.cat([
33 | preds['seg'].cpu(), preds['synth_silhouettes'].cpu(), torch.zeros_like(preds['seg']).cpu()
34 | ], dim=1)
35 |
36 | sil_err_npy = sil_err.data.cpu()
37 | sil_err_npy = sil_err_npy + (1.0 - preds['img_border_mask'].cpu())
38 | output_figs = torch.stack(
39 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy],dim = 1) # Batch Size, 4 (Images), RGB, H, W
40 |
41 | if vis_refine:
42 | real_rgb_vis_kp_re = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks_re'],
43 | visible=preds['keypoints'][:, :, [2]],
44 | marker_size=marker_size, thickness=thickness, normalized=False)
45 |
46 | pic_on_pic = preds['img_orig'].cpu() * (1 - preds['synth_silhouettes_re'].cpu()) + \
47 | preds['synth_xyz_re'].cpu() * preds['synth_silhouettes_re'].cpu()
48 | pic_npy_re = pic_on_pic.data.cpu()
49 |
50 | sil_err = torch.cat([
51 | preds['seg'].cpu(), preds['synth_silhouettes_re'].cpu(), torch.zeros_like(preds['seg']).cpu()
52 | ], dim=1)
53 | sil_err_npy_re = sil_err.data.cpu()
54 | sil_err_npy_re = sil_err_npy_re + (1.0 - preds['img_border_mask'].cpu())
55 |
56 | output_figs = torch.stack(
57 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy, pic_npy_re, real_rgb_vis_kp_re, sil_err_npy_re
58 | ], dim=1)
59 | return output_figs
60 |
61 | @staticmethod
62 | def generate_output_figures_v2(preds, vis_refine=False):
63 | # add recovered mesh in an alternative view
64 | marker_size = 8
65 | thickness = 4
66 |
67 | smal_drawer = SMALJointDrawer()
68 | real_rgb_vis = preds['img_orig']
69 | keypoints_gt = preds['keypoints']
70 | real_rgb_vis_kp = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks'],
71 | visible=preds['keypoints'][:, :, [2]],
72 | marker_size=marker_size,
73 | thickness=thickness, normalized=False)
74 | pic_on_pic = preds['img_orig'].cpu() * (1-preds['synth_silhouettes'].cpu()) + \
75 | preds['synth_xyz'].cpu() * preds['synth_silhouettes'].cpu()
76 | pic_npy = pic_on_pic.data.cpu()
77 |
78 | sil_err = torch.cat([
79 | preds['seg'].cpu(), preds['synth_silhouettes'].cpu(), torch.zeros_like(preds['seg']).cpu()
80 | ], dim=1)
81 |
82 | sil_err_npy = sil_err.data.cpu()
83 | sil_err_npy = sil_err_npy + (1.0 - preds['img_border_mask'].cpu())
84 | output_figs = torch.stack(
85 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy],dim = 1) # Batch Size, 4 (Images), RGB, H, W
86 |
87 | if vis_refine:
88 | real_rgb_vis_kp_re = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks_re'],
89 | visible=preds['keypoints'][:, :, [2]],
90 | marker_size=marker_size, thickness=thickness, normalized=False)
91 |
92 | pic_on_pic = preds['img_orig'].cpu() * (1 - preds['synth_silhouettes_re'].cpu()) + \
93 | preds['synth_xyz_re'].cpu() * preds['synth_silhouettes_re'].cpu()
94 | pic_npy_re = pic_on_pic.data.cpu()
95 |
96 | sil_err = torch.cat([
97 | preds['seg'].cpu(), preds['synth_silhouettes_re'].cpu(), torch.zeros_like(preds['seg']).cpu()
98 | ], dim=1)
99 | sil_err_npy_re = sil_err.data.cpu()
100 | sil_err_npy_re = sil_err_npy_re + (1.0 - preds['img_border_mask'].cpu())
101 |
102 | synth_xyz_re_cano = preds['synth_xyz_re_cano'].cpu()
103 |
104 | output_figs = torch.stack(
105 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy, pic_npy_re, real_rgb_vis_kp_re, sil_err_npy_re,
106 | synth_xyz_re_cano], dim=1)
107 | return output_figs
108 |
109 | @staticmethod
110 | def generate_demo_output(preds):
111 | """Figure output of: [raw_img, cropped, mesh_view, mesh & raw, silh_view]"""
112 | marker_size = 8
113 | thickness = 4
114 |
115 | smal_drawer = SMALJointDrawer()
116 |
117 | #synth_render_vis = smal_drawer.draw_joints(preds['synth_xyz'], preds['synth_landmarks'], marker_size=marker_size, thickness=thickness, normalized=False)
118 | synth_render_vis = preds['synth_xyz'].cpu()
119 |
120 | # Real image (with bbox overlayed)
121 | real_rgb_vis = preds['img_orig'].cpu()
122 |
123 | #Overlaid bbox
124 | # bbox_overlay = np.zeros(real_rgb_vis.shape).astype(np.float32)
125 | # for n, bbox in enumerate(preds['bbox']):
126 | # (x0, y0, width, height) = list(map(int, bbox.cpu().numpy()))
127 | # bbox_overlay[n, 0, y0:y0+height, x0:x0+width] = 1. # red channel overlay
128 | #
129 | # real_rgb_vis += .4 * bbox_overlay # overlay bbox
130 |
131 | # real cropped
132 | real_rgb_vis_cropped = preds['img'].cpu()
133 |
134 | pic_on_pic = preds['img'] * 0.4 + preds['synth_xyz'].cpu() * 0.6
135 | pic_npy = pic_on_pic.data.cpu()
136 |
137 | batch, _, H, W = pic_npy.shape
138 | sil_err = torch.cat([
139 | preds['synth_silhouettes'], preds['synth_silhouettes'], preds['synth_silhouettes']], dim = 1)
140 |
141 | sil_err_npy = sil_err.data.cpu()
142 | output_figs = torch.stack(
143 | [real_rgb_vis, synth_render_vis, pic_npy, sil_err_npy], dim = 1) # Batch Size, 4 (Images), RGB, H, W
144 |
145 | return output_figs
146 |
147 | @staticmethod
148 | def draw_mesh_plotly(
149 | title,
150 | verts, faces,
151 | up=dict(x=0,y=1,z=0), eye=dict(x=0.0, y=0.0, z=1.0),
152 | hack_box_size = 1.0,
153 | center_mesh = True):
154 |
155 | camera = dict(up=up, center=dict(x=0, y=0, z=0), eye=eye)
156 | scene = dict(
157 | xaxis = dict(nticks=10, range=[-1,1],),
158 | yaxis = dict(nticks=10, range=[-1,1],),
159 | zaxis = dict(nticks=10, range=[-1,1],),
160 | camera = camera)
161 |
162 | centre_of_mass = torch.mean(verts, dim = 0, keepdim=True)
163 | if not center_mesh:
164 | centre_of_mass = torch.zeros_like(centre_of_mass)
165 |
166 | output_verts = (verts - centre_of_mass).data.cpu().numpy()
167 | output_faces = faces.data.cpu().numpy()
168 |
169 | hack_points = np.array([
170 | [-1.0, -1.0, -1.0],
171 | [-1.0, -1.0, 1.0],
172 | [-1.0, 1.0, -1.0],
173 | [-1.0, 1.0, 1.0],
174 | [1.0, -1.0, -1.0],
175 | [1.0, -1.0, 1.0],
176 | [1.0, 1.0, -1.0],
177 | [1.0, 1.0, 1.0]]) * hack_box_size
178 |
179 | vis_fig = go.Figure(data = [
180 | go.Mesh3d(
181 | x = output_verts[:, 0],
182 | y = output_verts[:, 1],
183 | z = -1 * output_verts[:, 2],
184 | i = output_faces[:, 0],
185 | j = output_faces[:, 1],
186 | k = output_faces[:, 2],
187 | color='cornflowerblue'
188 | ),
189 | go.Scatter3d(
190 | x = hack_points[:, 0],
191 | y = hack_points[:, 1],
192 | z = hack_points[:, 2],
193 | mode='markers',
194 | name='_fake_pts',
195 | visible=True,
196 | marker=dict(
197 | size=1,
198 | opacity = 0,
199 | color=(0.0, 0.0, 0.0),
200 | )
201 | )])
202 |
203 | vis_fig.update_scenes(patch = scene)
204 | vis_fig.update_layout(title=title)
205 | return vis_fig
206 |
207 | def draw_double_mesh_plotly(
208 | viz,
209 | title,
210 | verts, faces,
211 | verts2,
212 | joints_3d,
213 | gt_joints_3d,
214 | visdom_env_imgs,
215 | up=dict(x=0,y=1,z=0), eye=dict(x=0.0, y=0.0, z=1.0),
216 | hack_box_size = 1.0,
217 | center_mesh = True):
218 |
219 | camera = dict(up=up, center=dict(x=0, y=0, z=0), eye=eye)
220 | scene = dict(
221 | xaxis = dict(nticks=10, range=[-1,1],),
222 | yaxis = dict(nticks=10, range=[-1,1],),
223 | zaxis = dict(nticks=10, range=[-1,1],),
224 | camera = camera)
225 |
226 | centre_of_mass = torch.mean(verts, dim = 0, keepdim=True)
227 | if not center_mesh:
228 | centre_of_mass = torch.zeros_like(centre_of_mass)
229 |
230 | output_joints_3d = (joints_3d - centre_of_mass).data.cpu().numpy()
231 | output_verts = (verts - centre_of_mass).data.cpu().numpy()
232 | output_verts2 = (verts2 - centre_of_mass).data.cpu().numpy()
233 | output_joints_3d2 = (gt_joints_3d - centre_of_mass).data.cpu().numpy()
234 | output_faces = faces.data.cpu().numpy()
235 |
236 | hack_points = np.array([
237 | [-1.0, -1.0, -1.0],
238 | [-1.0, -1.0, 1.0],
239 | [-1.0, 1.0, -1.0],
240 | [-1.0, 1.0, 1.0],
241 | [1.0, -1.0, -1.0],
242 | [1.0, -1.0, 1.0],
243 | [1.0, 1.0, -1.0],
244 | [1.0, 1.0, 1.0]]) * hack_box_size
245 |
246 | vis_fig = go.Figure(data = [
247 | go.Mesh3d(
248 | x = output_verts[:, 0],
249 | y = output_verts[:, 1],
250 | z = -1 * output_verts[:, 2],
251 | i = output_faces[:, 0],
252 | j = output_faces[:, 1],
253 | k = output_faces[:, 2],
254 | color='cornflowerblue',
255 | opacity=0.5,
256 | ),
257 | go.Scatter3d(
258 | x = output_joints_3d[:, 0],
259 | y = output_joints_3d[:, 1],
260 | z = -1 * output_joints_3d[:, 2],
261 | mode='markers',
262 | marker = dict(
263 | size=8,
264 | color='red'
265 | )
266 | ),
267 | go.Mesh3d(
268 | x = output_verts2[:, 0],
269 | y = output_verts2[:, 1],
270 | z = -1 * output_verts2[:, 2],
271 | i = output_faces[:, 0],
272 | j = output_faces[:, 1],
273 | k = output_faces[:, 2],
274 | color='green',
275 | opacity=0.5,
276 | ),
277 | go.Scatter3d(
278 | x = output_joints_3d2[:, 0],
279 | y = output_joints_3d2[:, 1],
280 | z = -1 * output_joints_3d2[:, 2],
281 | mode='markers',
282 | marker = dict(
283 | size=8,
284 | color='purple'
285 | )
286 | ),
287 | go.Scatter3d(
288 | x = hack_points[:, 0],
289 | y = hack_points[:, 1],
290 | z = hack_points[:, 2],
291 | mode='markers',
292 | name='_fake_pts',
293 | visible=True,
294 | marker=dict(
295 | size=1,
296 | opacity = 0,
297 | color=(0.0, 0.0, 0.0),
298 | )
299 | )])
300 |
301 | vis_fig.update_scenes(patch = scene)
302 | vis_fig.update_layout(title=title)
303 | return vis_fig
--------------------------------------------------------------------------------
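A minimal sketch (not part of the repository; it assumes plotly and the repository package are importable, and the output filename is arbitrary) of driving Visualizer.draw_mesh_plotly above with a toy tetrahedron instead of a SMAL mesh.

import torch
from util.helpers.visualize import Visualizer

# a unit tetrahedron: 4 vertices, 4 triangular faces
verts = torch.tensor([[0., 0., 0.],
                      [1., 0., 0.],
                      [0., 1., 0.],
                      [0., 0., 1.]])
faces = torch.tensor([[0, 1, 2],
                      [0, 1, 3],
                      [0, 2, 3],
                      [1, 2, 3]])
fig = Visualizer.draw_mesh_plotly('toy tetrahedron', verts, faces)
fig.write_html('tetrahedron.html')  # open in a browser to inspect the mesh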
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import, division
2 |
3 | import os
4 | import sys
5 | import time
6 | import numpy as np
7 | from tqdm import tqdm
8 | import cv2
9 | import argparse
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim
13 | from torch.utils.data import DataLoader
14 | from model.mesh_graph_hg import MeshGraph_hg
15 | from util import config
16 | from util.helpers.visualize import Visualizer
17 | from util.metrics import Metrics
18 | from datasets.stanford import BaseDataset
19 | from scipy.spatial.transform import Rotation as R
20 |
21 | def main(args):
22 |
23 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 |
26 | if not os.path.exists(args.output_dir):
27 | os.mkdir(args.output_dir)
28 | # set model
29 | model = MeshGraph_hg(device, args.shape_family_id, args.num_channels, args.num_layers, args.betas_scale,
30 | args.shape_init, args.local_feat, num_downsampling=args.num_downsampling,
31 | render_rgb=args.save_results)
32 | model = nn.DataParallel(model).to(device)
33 | # set data
34 | print("Evaluate on {} dataset".format(args.dataset))
35 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False)
36 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works)
37 | # set optimizer
38 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
39 |
40 | if os.path.isfile(args.resume):
41 | print("=> loading checkpoint {}".format(args.resume))
42 | checkpoint = torch.load(args.resume)
43 | model.load_state_dict(checkpoint['state_dict'])
44 | if args.load_optimizer:
45 | optimizer.load_state_dict(checkpoint['optimizer'])
46 | args.start_epoch = checkpoint['epoch'] + 1
47 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch']))
48 | else:
49 | print("No checkpoint found")
50 |
51 | pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args)
52 | print("Evaluate only, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}"
53 | .format(pck, iou_silh, pck_re, iou_re))
54 | return
55 |
56 |
57 | def run_evaluation(model, dataset, data_loader, device, args):
58 |
59 | model.eval()
60 | result_dir = args.output_dir
61 | batch_size = args.batch_size
62 |
63 | pck = np.zeros((len(dataset)))
64 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS}
65 | pck_by_part_re = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS}
66 | acc_sil_2d = np.zeros(len(dataset))
67 |
68 | pck_re = np.zeros((len(dataset)))
69 | acc_sil_2d_re = np.zeros(len(dataset))
70 |
71 | smal_pose = np.zeros((len(dataset), 105))
72 | smal_betas = np.zeros((len(dataset), 20))
73 | smal_camera = np.zeros((len(dataset), 3))
74 | smal_imgname = []
75 | # rotate estimated mesh to visualize in an alternative view
76 | rot_matrix = torch.from_numpy(R.from_euler('y', -90, degrees=True).as_dcm()).float().to(device)
77 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader))
78 |
79 | for step, batch in enumerate(tqdm_iterator):
80 | with torch.no_grad():
81 | preds = {}
82 |
83 | keypoints = batch['keypoints'].to(device)
84 | keypoints_norm = batch['keypoints_norm'].to(device)
85 | seg = batch['seg'].to(device)
86 | has_seg = batch['has_seg']
87 | img = batch['img'].to(device)
88 | img_border_mask = batch['img_border_mask'].to(device)
89 | verts, joints, shape, pred_codes = model(img)
90 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
91 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2],
92 | dim=1)
93 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3)
94 | labelled_joints_3d = joints[:, config.MODEL_JOINTS]
95 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera)
96 | synth_silhouettes = synth_silhouettes.unsqueeze(1)
97 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera)
98 |
99 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred,
100 | del_v=shape,
101 | betas_logscale=betas_scale_pred)
102 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS]
103 | synth_rgb_refine, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera)
104 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1)
105 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine,
106 | pred_camera)
107 |
108 | if args.save_results:
109 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0)
110 | synth_rgb_refine = torch.clamp(synth_rgb_refine[0], 0.0, 1.0)
111 | # visualize in another view
112 | verts_refine_cano = verts_refine - torch.mean(verts_refine, dim=1, keepdim=True)
113 | verts_refine_cano = (rot_matrix @ verts_refine_cano.unsqueeze(-1)).squeeze(-1)
114 |                     # increase the depth such that the rendered shapes lie within the image
115 | verts_refine_cano[:, :, 2] = verts_refine_cano[:, :, 2] + 15
116 | synth_rgb_refine_cano, _ = model.module.model_renderer(verts_refine_cano, faces,
117 | pred_camera)
118 | synth_rgb_refine_cano = torch.clamp(synth_rgb_refine_cano[0], 0.0, 1.0)
119 | preds['synth_xyz_re_cano'] = synth_rgb_refine_cano
120 |
121 | preds['pose'] = pose_pred
122 | preds['betas'] = betas_pred
123 | preds['camera'] = pred_camera
124 | preds['trans'] = trans_pred
125 |
126 | preds['verts'] = verts
127 | preds['joints_3d'] = labelled_joints_3d
128 | preds['faces'] = faces
129 |
130 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg)
131 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg)
132 |
133 | preds['acc_PCK_re'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg)
134 | preds['acc_IOU_re'] = Metrics.IOU(synth_silhouettes_refine, seg, img_border_mask, mask=has_seg)
135 |
136 | for group, group_kps in config.KEYPOINT_GROUPS.items():
137 | preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg,
138 | thresh_range=[0.15],
139 | idxs=group_kps)
140 | preds[f'{group}_PCK_RE'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg,
141 | thresh_range=[0.15],
142 | idxs=group_kps)
143 | preds['synth_xyz'] = synth_rgb
144 | preds['synth_silhouettes'] = synth_silhouettes
145 | preds['synth_landmarks'] = synth_landmarks
146 | preds['synth_xyz_re'] = synth_rgb_refine
147 | preds['synth_landmarks_re'] = synth_landmarks_refine
148 | preds['synth_silhouettes_re'] = synth_silhouettes_refine
149 |
150 | assert not any(k in preds for k in batch.keys())
151 | preds.update(batch)
152 |
153 | curr_batch_size = preds['synth_landmarks'].shape[0]
154 | # compute accuracy for coarse stage
155 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
156 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
157 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy()
158 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy()
159 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy()
160 | # compute accuracy for refinement stage
161 | pck_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK_re'].data.cpu().numpy()
162 | acc_sil_2d_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU_re'].data.cpu().numpy()
163 | for part in pck_by_part:
164 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
165 | pck_by_part_re[part][step * batch_size:step * batch_size + curr_batch_size] = preds[
166 | f'{part}_PCK_RE'].data.cpu().numpy()
167 |
168 | if args.save_results:
169 | output_figs = np.transpose(
170 | Visualizer.generate_output_figures_v2(preds, vis_refine=True).data.cpu().numpy(),
171 | (0, 1, 3, 4, 2))
172 | for img_id in range(len(preds['imgname'])):
173 | imgname = preds['imgname'][img_id]
174 | output_fig_list = output_figs[img_id]
175 |
176 | path_parts = imgname.split('/')
177 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1])
178 | img_file = os.path.join(result_dir, path_suffix)
179 | output_fig = np.hstack(output_fig_list)
180 | smal_imgname.append(path_suffix)
181 | npz_file = "{0}.npz".format(os.path.splitext(img_file)[0])
182 |
183 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0)
184 | # np.savez_compressed(npz_file,
185 | # imgname=preds['imgname'][img_id],
186 | # pose=preds['pose'][img_id].data.cpu().numpy(),
187 | # betas=preds['betas'][img_id].data.cpu().numpy(),
188 | # camera=preds['camera'][img_id].data.cpu().numpy(),
189 | # trans=preds['trans'][img_id].data.cpu().numpy(),
190 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(),
191 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(),
192 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part}
193 | # )
194 | report = f"""*** Final Results ***
195 |
196 | SIL IOU 2D: {np.nanmean(acc_sil_2d):.5f}
197 | PCK 2D: {np.nanmean(pck):.5f}
198 |
199 | SIL IOU 2D REFINE: {np.nanmean(acc_sil_2d_re):.5f}
200 | PCK 2D REFINE: {np.nanmean(pck_re):.5f}"""
201 |
202 | for part in pck_by_part:
203 | report += f'\n {part} PCK 2D: {np.nanmean(pck_by_part[part]):.5f}'
204 |
205 | for part in pck_by_part:
206 | report += f'\n {part} PCK 2D RE: {np.nanmean(pck_by_part_re[part]):.5f}'
207 | print(report)
208 |
209 | # save report to file
210 | with open(os.path.join(result_dir, '{}_report.txt'.format(args.dataset)), 'w') as outfile:
211 | print(report, file=outfile)
212 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part, np.nanmean(pck_re), np.nanmean(acc_sil_2d_re)
213 |
214 |
215 | if __name__ == '__main__':
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument('--lr', default=0.0001, type=float)
218 | parser.add_argument('--output_dir', default='./logs/', type=str)
219 | parser.add_argument('--batch_size', default=32, type=int)
220 | parser.add_argument('--num_works', default=4, type=int)
221 | parser.add_argument('--gpu_ids', default='0', type=str)
222 | parser.add_argument('--resume', default=None, type=str)
223 | parser.add_argument('--load_optimizer', action='store_true')
224 | parser.add_argument('--shape_family_id', default=1, type=int)
225 | parser.add_argument('--dataset', default='stanford', type=str)
226 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load')
227 | parser.add_argument('--save_results', action='store_true')
228 | parser.add_argument('--prior_betas', default='smal', type=str)
229 | parser.add_argument('--prior_pose', default='smal', type=str)
230 | parser.add_argument('--betas_scale', action='store_true')
231 | parser.add_argument('--num_channels', type=int, default=256, help='Number of channels in Graph Residual layers')
232 | parser.add_argument('--num_layers', type=int, default=5, help='Number of residuals blocks in the Graph CNN')
233 | parser.add_argument('--local_feat', action='store_true')
234 |     parser.add_argument('--shape_init', default='smal', help='initialize the shape prediction with the mean shape')
235 | parser.add_argument('--num_downsampling', default=1, type=int)
236 |
237 | args = parser.parse_args()
238 | main(args)
--------------------------------------------------------------------------------
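A small, self-contained sketch (NumPy only; all names and numbers are illustrative) of the bookkeeping pattern used by run_evaluation above: per-sample metrics are written into pre-allocated arrays by batch slices and averaged with nanmean at the end, which tolerates a smaller final batch and missing values.

import numpy as np

num_samples, batch_size = 10, 4
pck = np.zeros(num_samples)

for step in range(int(np.ceil(num_samples / batch_size))):
    start = step * batch_size
    curr = min(batch_size, num_samples - start)   # the last batch may be smaller
    batch_pck = np.random.rand(curr)              # stand-in for Metrics.PCK(...)
    pck[start:start + curr] = batch_pck

print("PCK 2D: {:.5f}".format(np.nanmean(pck)))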
/model/smal_mesh_net_img.py:
--------------------------------------------------------------------------------
1 | """
2 | Mesh net model.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 | import torch
12 | import torchvision
13 | import torch.nn as nn
14 | from util import net_blocks as nb
15 |
16 |
17 |
18 | # ------------- Modules ------------#
19 | # ----------------------------------#
20 | class ResNetConv(nn.Module):
21 | def __init__(self, n_blocks=4):
22 | super(ResNetConv, self).__init__()
23 | # if opts.use_resnet50:
24 | self.resnet = torchvision.models.resnet50(pretrained=True)
25 | # else:
26 | # self.resnet = torchvision.models.resnet18(pretrained=True)
27 | self.n_blocks = n_blocks
28 | # if self.opts.use_double_input:
29 | # self.fc = nb.fc_stack(512*16*8, 512*8*8, 2)
30 |
31 | def forward(self, x, y=None):
32 | # if self.opts.use_double_input and y is not None:
33 | # x = torch.cat([x, y], 2)
34 |         img_feat_multiscale = [] # collect multi-scale features for local feature retrieval
35 | n_blocks = self.n_blocks
36 | x = self.resnet.conv1(x)
37 | x = self.resnet.bn1(x)
38 | x = self.resnet.relu(x)
39 | x = self.resnet.maxpool(x)
40 |
41 | if n_blocks >= 1:
42 | x = self.resnet.layer1(x)
43 | img_feat_multiscale.append(x)
44 | if n_blocks >= 2:
45 | x = self.resnet.layer2(x)
46 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x))
47 | if n_blocks >= 3:
48 | x = self.resnet.layer3(x)
49 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x))
50 | if n_blocks >= 4:
51 | x = self.resnet.layer4(x)
52 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x))
53 | # if self.opts.use_double_input and y is not None:
54 | # x = x.view(x.size(0), -1)
55 | # x = self.fc.forward(x)
56 | # x = x.view(x.size(0), 512, 8, 8)
57 |
58 | return x, torch.cat(img_feat_multiscale, dim=1)
59 |
60 |
61 | class Encoder(nn.Module):
62 | """
63 | Current:
64 | Resnet with 4 blocks (x32 spatial dim reduction)
65 | Another conv with stride 2 (x64)
66 | This is sent to 2 fc layers with final output nz_feat.
67 | """
68 |
69 | def __init__(self, input_shape, channels_per_group=16, n_blocks=4, nz_feat=100, bott_size=256):
70 | super(Encoder, self).__init__()
71 | self.resnet_conv = ResNetConv(n_blocks=4)
72 | num_norm_groups = bott_size // channels_per_group
73 | # if opts.use_resnet50:
74 | self.enc_conv1 = nb.conv2d('group', 2048, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups)
75 | # else:
76 | # self.enc_conv1 = nb.conv2d('group', 512, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups)
77 |
78 | nc_input = bott_size * (input_shape[0] // 64) * (input_shape[1] // 64)
79 | self.enc_fc = nb.fc_stack(nc_input, nz_feat, 2, 'batch')
80 | self.nenc_feat = nc_input
81 |
82 | nb.net_init(self.enc_conv1)
83 | self.avgpool = nn.AvgPool2d(7, stride=1)
84 |
85 | def forward(self, img, fg_img):
86 | resnet_feat, feat_multiscale = self.resnet_conv.forward(img, fg_img) # multi-scale feature is used to extract local feature for refinement
87 | out_enc_conv1 = self.enc_conv1(resnet_feat) # feature for predicting SMAL parameters
88 |         out_resnet = self.avgpool(resnet_feat) # add a pooling layer to get a global feature for mesh refinement
89 | out_enc_conv1 = out_enc_conv1.view(img.size(0), -1)
90 | feat = self.enc_fc.forward(out_enc_conv1)
91 | return feat, out_enc_conv1, out_resnet, feat_multiscale
92 |
93 |
94 | class ShapePredictor(nn.Module):
95 | """
96 | Outputs mesh deformations
97 | """
98 |
99 | def __init__(self, nz_feat, num_verts, left_idx, right_idx, shapedirs, use_delta_v=False, use_sym_idx=False,
100 | use_smal_betas=False, n_shape_feat=40):
101 | super(ShapePredictor, self).__init__()
102 | self.use_delta_v = use_delta_v
103 | self.use_sym_idx = use_sym_idx
104 | self.use_smal_betas = use_smal_betas
105 | self.ref_delta_v = torch.Tensor(np.zeros((num_verts, 3))).cuda()
106 |
107 | def forward(self, feat):
108 | if self.use_sym_idx:
109 | batch_size = feat.shape[0]
110 | delta_v = torch.Tensor(np.zeros((batch_size, self.num_verts, 3))).cuda()
111 | feat = self.fc(feat)
112 | self.shape_f = feat
113 |
114 | half_delta_v = self.pred_layer.forward(feat)
115 | half_delta_v = half_delta_v.view(half_delta_v.size(0), -1, 3)
116 | delta_v[:, self.left_idx, :] = half_delta_v
117 | half_delta_v[:, :, 1] = -1. * half_delta_v[:, :, 1]
118 | delta_v[:, self.right_idx, :] = half_delta_v
119 | else:
120 | delta_v = self.pred_layer.forward(feat)
121 | # Make it B x num_verts x 3
122 | delta_v = delta_v.view(delta_v.size(0), -1, 3)
123 | # print('shape: ( Mean = {}, Var = {} )'.format(delta_v.mean().data[0], delta_v.var().data[0]))
124 | return delta_v
125 |
126 |
127 | class PosePredictor(nn.Module):
128 | """
129 | """
130 |
131 | def __init__(self, nz_feat, num_joints=35):
132 | super(PosePredictor, self).__init__()
133 | self.pose_var = 1.0
134 | self.num_joints = num_joints
135 | self.pred_layer = nn.Linear(nz_feat, num_joints * 3)
136 | # bjb_edit
137 | self.pred_layer.weight.data.normal_(0, 1e-4)
138 | self.pred_layer.bias.data.normal_(0, 1e-4)
139 |
140 | def forward(self, feat):
141 | pose = self.pose_var * self.pred_layer.forward(feat)
142 |
143 |         # Add offsets so that a zero prediction corresponds to a frontal-facing pose
144 | # edit by lic, frontal facing and upright
145 | pose[:, 0] += -1.20919958
146 | pose[:, 1] += 1.20919958
147 | pose[:, 2] += 1.20919958
148 | return pose
149 |
150 |
151 | class BetaScalePredictor(nn.Module):
152 | def __init__(self, nz_feat, nenc_feat, num_beta_scale=6, model_mean=None):
153 | super(BetaScalePredictor, self).__init__()
154 | self.model_mean = model_mean
155 | self.pred_layer = nn.Linear(nenc_feat, num_beta_scale)
156 | # bjb_edit
157 | self.pred_layer.weight.data.normal_(0, 1e-4)
158 | if model_mean is not None:
159 | self.pred_layer.bias.data = model_mean + torch.randn_like(model_mean) * 1e-4
160 | else:
161 | self.pred_layer.bias.data.normal_(0, 1e-4)
162 |
163 | def forward(self, feat, enc_feat):
164 | betas = self.pred_layer.forward(enc_feat)
165 |
166 | return betas
167 |
168 |
169 | class BetasPredictor(nn.Module):
170 | def __init__(self, nz_feat, nenc_feat, num_betas=20, model_mean=None):
171 | super(BetasPredictor, self).__init__()
172 | self.model_mean = model_mean
173 | self.pred_layer = nn.Linear(nenc_feat, num_betas)
174 | # bjb_edit
175 | self.pred_layer.weight.data.normal_(0, 1e-4)
176 | if model_mean is not None:
177 | self.pred_layer.bias.data = model_mean + torch.randn_like(model_mean) * 1e-4
178 | else:
179 | self.pred_layer.bias.data.normal_(0, 1e-4)
180 |
181 | def forward(self, feat, enc_feat):
182 | betas = self.pred_layer.forward(enc_feat)
183 | return betas
184 |
185 |
186 | class ScalePredictor(nn.Module):
187 | '''
188 |     In the case of perspective projection, the scale is the focal length.
189 | '''
190 |
191 | def __init__(self, nz, norm_f0, use_camera=True, scale_bias=1):
192 | super(ScalePredictor, self).__init__()
193 | self.use_camera = use_camera
194 | self.norm_f0 = norm_f0
195 | if self.use_camera:
196 | self.pred_layer = nn.Linear(nz, scale_bias)
197 | # else:
198 | # scale = np.zeros((opts.batch_size,1))
199 | # scale[:,0] = 0.
200 | # self.ref_camera = torch.Tensor(scale).cuda()
201 |
202 | def forward(self, feat):
203 | if not self.use_camera:
204 | scale = np.zeros((feat.shape[0], 1))
205 | scale[:, 0] = 0.
206 | return torch.Tensor(scale).cuda()
207 | if self.norm_f0 != 0:
208 | off = 0.
209 | else:
210 | off = 1.
211 | scale = self.pred_layer.forward(feat) + off
212 | return scale
213 |
214 |
215 | class TransPredictor(nn.Module):
216 | """
217 | Outputs [tx, ty] or [tx, ty, tz]
218 | """
219 |
220 | def __init__(self, nz, projection_type, fix_trans=False):
221 | super(TransPredictor, self).__init__()
222 | self.fix_trans = fix_trans
223 | if projection_type == 'orth':
224 | self.pred_layer = nn.Linear(nz, 2)
225 | elif projection_type == 'perspective':
226 | self.pred_layer_xy = nn.Linear(nz, 2)
227 | self.pred_layer_z = nn.Linear(nz, 1)
228 | self.pred_layer_xy.weight.data.normal_(0, 0.0001)
229 | self.pred_layer_xy.bias.data.normal_(0, 0.0001)
230 | self.pred_layer_z.weight.data.normal_(0, 0.0001)
231 | self.pred_layer_z.bias.data.normal_(0, 0.0001)
232 | else:
233 | print('Unknown projection type')
234 |
235 | def forward(self, feat):
236 | trans = torch.Tensor(np.zeros((feat.shape[0], 3))).cuda()
237 | f = torch.Tensor(np.zeros((feat.shape[0], 1))).cuda()
238 | feat_xy = feat
239 | feat_z = feat
240 |
241 | trans[:, :2] = self.pred_layer_xy(feat_xy)
242 | trans[:, 2] = 1 + self.pred_layer_z(feat_z)[:, 0]
243 |
244 | if self.fix_trans:
245 | trans[:, 2] = 1.
246 |
247 | # print('trans: ( Mean = {}, Var = {} )'.format(trans.mean().data[0], trans.var().data[0]))
248 | return trans
249 |
250 |
251 | class CodePredictor(nn.Module):
252 | def __init__(
253 | self, norm_f0, nz_feat=100, nenc_feat=2048,
254 | use_smal_betas=True,
255 | num_betas=27, # bjb_edit
256 | use_camera=True, scale_bias=1,
257 | fix_trans=False, betas_scale=False,
258 | use_smal_pose=True,
259 | shape_init=None):
260 |
261 | super(CodePredictor, self).__init__()
262 | self.use_smal_pose = use_smal_pose
263 | self.use_smal_betas = use_smal_betas
264 | self.use_camera = use_camera
265 | self.betas_scale = betas_scale
266 | self.scale_predictor = ScalePredictor(
267 | nz_feat, norm_f0, use_camera=use_camera, scale_bias=scale_bias)
268 | self.trans_predictor = TransPredictor(
269 | nz_feat, 'perspective', fix_trans=fix_trans)
270 |
271 | if self.use_smal_pose:
272 | self.pose_predictor = PosePredictor(nz_feat)
273 |
274 | if self.use_smal_betas:
275 | scale_init = None
276 | shape_betas_init = None
277 | if shape_init is not None:
278 | shape_betas_init = shape_init[:20]
279 | if shape_init.shape[0] == 26:
280 | scale_init = shape_init[20:]
281 |
282 | self.betas_predictor = BetasPredictor(
283 | nz_feat, nenc_feat, num_betas=20, model_mean=shape_betas_init)
284 | if self.betas_scale:
285 | self.betas_scale_predictor = BetaScalePredictor(
286 | nz_feat, nenc_feat, num_beta_scale=6, model_mean=scale_init)
287 |
288 | def forward(self, feat, enc_feat):
289 | if self.use_camera:
290 | scale_pred = self.scale_predictor.forward(feat)
291 | else:
292 | scale_pred = self.scale_predictor.ref_camera
293 |
294 | trans_pred = self.trans_predictor.forward(feat)
295 |
296 | if self.use_smal_pose:
297 | pose_pred = self.pose_predictor.forward(feat)
298 | else:
299 | pose_pred = None
300 |
301 | if self.use_smal_betas:
302 | betas_pred = self.betas_predictor.forward(feat, enc_feat)
303 | if self.betas_scale:
304 | betas_scale_pred = self.betas_scale_predictor.forward(feat, enc_feat)[:,
305 | :6] # choose first 6 for backward compat
306 | else:
307 | betas_scale_pred = None
308 | else:
309 | betas_pred = None
310 | betas_scale_pred = None
311 |
312 | return scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred
313 |
314 |
315 | # ------------ Mesh Net ------------#
316 | # ----------------------------------#
317 | class MeshNet_img(nn.Module):
318 | def __init__(self,
319 | input_shape, betas_scale=False,
320 | norm_f0=2700., nz_feat=100,
321 | shape_init=None, return_feat=False):
322 | # Input shape is H x W of the image.
323 | super(MeshNet_img, self).__init__()
324 |
325 | self.bottleneck_size = 2048
326 | self.channels_per_group = 16
327 | self.shape_init = shape_init
328 |
329 | self.encoder = Encoder(
330 | input_shape,
331 | channels_per_group=self.channels_per_group,
332 | n_blocks=4, nz_feat=nz_feat, bott_size=self.bottleneck_size)
333 |
334 | self.code_predictor = CodePredictor(
335 | norm_f0, nz_feat=nz_feat, nenc_feat=self.encoder.nenc_feat, betas_scale=betas_scale,
336 | use_smal_betas=True, shape_init=self.shape_init)
337 | self.return_feat = return_feat
338 |
339 | def forward(self, img, masks=None, is_optimization=False, is_var_opt=False):
340 | img_feat, enc_feat, feat_resnet, feat_multiscale = self.encoder.forward(img, masks)
341 | codes_pred = self.code_predictor.forward(img_feat, enc_feat)
342 | if self.return_feat:
343 | return codes_pred, feat_resnet, feat_multiscale
344 | else:
345 | return codes_pred
346 |
347 |
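
For reference, a self-contained sketch of the translation-head logic implemented by TransPredictor.forward above (illustrative only, not the repository's exact module): the x/y image-plane offsets come from one linear head, and the depth is predicted as a residual around 1 (and pinned to exactly 1 when fix_trans is set).

import torch
import torch.nn as nn

nz_feat = 100                                # default feature size used throughout this file
pred_layer_xy = nn.Linear(nz_feat, 2)        # stand-in for the x/y prediction head
pred_layer_z = nn.Linear(nz_feat, 1)         # stand-in for the depth prediction head

feat = torch.randn(4, nz_feat)               # dummy batch of encoder features
trans = feat.new_zeros(4, 3)
trans[:, :2] = pred_layer_xy(feat)           # x/y translation
trans[:, 2] = 1 + pred_layer_z(feat)[:, 0]   # depth predicted as an offset around 1
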
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import, division
2 |
3 | import os
4 | import numpy as np
5 | from tqdm import tqdm
6 | import cv2
7 | import argparse
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim
11 | from torch.utils.data import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from model.model_v1 import MeshModel
14 | from util import config
15 | from util.helpers.visualize import Visualizer
16 | from util.loss_utils import kp_l2_loss, Shape_prior, mask_loss
17 | from util.metrics import Metrics
18 | from datasets.stanford import BaseDataset
19 | from util.logger import Logger
20 | from util.meter import AverageMeterSet
21 | from util.misc import save_checkpoint, adjust_learning_rate_exponential
22 | from util.pose_prior import Prior
23 | from util.joint_limits_prior import LimitPrior
24 | import pickle
25 |
26 | # Set some global variables
27 | global_step = 0
28 | best_pck = 0
29 | best_pck_epoch = 0
30 |
31 |
32 | def main(args):
33 | global best_pck
34 | global best_pck_epoch
35 | global global_step
36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
37 | print("RESULTS: {0}".format(args.output_dir))
38 | if not os.path.exists(args.output_dir):
39 | os.mkdir(args.output_dir)
40 | # set up device
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | # set up model
43 | model = MeshModel(device, args.shape_family_id, args.betas_scale, args.shape_init, render_rgb=args.save_results)
44 | model = nn.DataParallel(model).to(device)
45 | # set up datasets
46 | dataset_train = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=True, use_augmentation=True)
47 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False)
48 | data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works)
49 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works)
50 |
51 | # set up optimizer
52 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
53 |
54 | writer = SummaryWriter(os.path.join(args.output_dir, 'train'))
55 |
56 | # set up criterion
57 | joint_limit_prior = LimitPrior(device)
58 | shape_prior = Shape_prior(args.prior_betas, args.shape_family_id, device)
59 | if args.resume:
60 | if os.path.isfile(args.resume):
61 | print("=> loading checkpoint {}".format(args.resume))
62 | checkpoint = torch.load(args.resume)
63 | model.load_state_dict(checkpoint['state_dict'])
64 | if args.load_optimizer:
65 | optimizer.load_state_dict(checkpoint['optimizer'])
66 | args.start_epoch = checkpoint['epoch'] + 1
67 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch']))
68 | logger = Logger(os.path.join(args.output_dir, 'log.txt'))
69 | logger.log_arguments(args)
70 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU'])
71 | else:
72 | print("=> no checkpoint found at {}".format(args.resume))
73 | else:
74 | logger = Logger(os.path.join(args.output_dir, 'log.txt'))
75 | logger.log_arguments(args)
76 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU'])
77 |
78 | if args.evaluate:
79 | pck, iou_silh, pck_by_part = run_evaluation(model, dataset_eval, data_loader_eval, device, args)
80 | print("Evaluate only, PCK: {}, IOU: {}".format(pck, iou_silh))
81 | return
82 |
83 | lr = args.lr
84 | for epoch in range(args.start_epoch, args.nEpochs):
85 |
86 | model.train()
87 | tqdm_iterator = tqdm(data_loader_train, desc='Train', total=len(data_loader_train))
88 | meters = AverageMeterSet()
89 |
90 | for step, batch in enumerate(tqdm_iterator):
91 | keypoints = batch['keypoints'].to(device)
92 | seg = batch['seg'].to(device)
93 | img = batch['img'].to(device)
94 |
95 | pred_codes = model(img)
96 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
97 |             pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2, device=device) * config.IMG_RES / 2],
98 | dim=1)
99 | # recover 3D mesh from SMAL parameters
100 | verts, joints, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred,
101 | betas_logscale=betas_scale_pred)
102 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3)
103 | labelled_joints_3d = joints[:, config.MODEL_JOINTS]
104 | # project 3D joints onto 2D space and apply 2D keypoints supervision
105 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera)
106 | loss_kpts = args.w_kpts * kp_l2_loss(synth_landmarks, keypoints[:, :, [1, 0, 2]], config.NUM_JOINTS)
107 | meters.update('loss_kpt', loss_kpts.item())
108 | loss = loss_kpts
109 |
110 |             # apply shape prior constraint, coming either from SMAL or from the WLDO unity prior
111 | if args.w_betas_prior > 0:
112 | if args.prior_betas == 'smal':
113 | s_prior = args.w_betas_prior * shape_prior(betas_pred)
114 | elif args.prior_betas == 'unity':
115 | betas_pred = torch.cat([betas_pred, betas_scale_pred], dim=1)
116 | s_prior = args.w_betas_prior * shape_prior(betas_pred)
117 | else:
118 |                     raise ValueError(
119 |                         "Shape prior should come from either smal or unity")
120 | meters.update('loss_prior', s_prior.item())
121 | loss += s_prior
122 |
123 |             # apply pose prior constraint, coming either from SMAL or from the WLDO unity prior
124 | if args.w_pose_prior > 0:
125 | if args.prior_pose == 'smal':
126 | pose_prior_path = config.WALKING_PRIOR_FILE
127 | elif args.prior_pose == 'unity':
128 | pose_prior_path = config.UNITY_POSE_PRIOR
129 | else:
130 |                     raise ValueError(
131 |                         "The pose prior should come from either smal or unity")
132 | pose_prior = Prior(pose_prior_path, device)
133 | p_prior = args.w_pose_prior * pose_prior(pose_pred)
134 | meters.update('pose_prior', p_prior.item())
135 | loss += p_prior
136 |
137 | meters.update('loss_all', loss.item())
138 | optimizer.zero_grad()
139 | loss.backward()
140 | optimizer.step()
141 | global_step += 1
142 | if step % 20 == 0:
143 | loss_values = meters.averages()
144 | for name, meter in loss_values.items():
145 | writer.add_scalar(name, meter, global_step)
146 | writer.flush()
147 |
148 | pck, iou_silh, pck_by_part = run_evaluation(model, dataset_eval, data_loader_eval, device, args)
149 |
150 | print("Epoch: {}, LR: {}, PCK: {}, IOU: {}".format(epoch, lr, pck, iou_silh))
151 | logger.append([epoch, lr, pck, iou_silh])
152 |
153 | is_best = pck > best_pck
154 | if pck > best_pck:
155 | best_pck_epoch = epoch
156 | best_pck = max(pck, best_pck)
157 | save_checkpoint({'epoch': epoch,
158 | 'state_dict': model.state_dict(),
159 | 'best_pck': best_pck,
160 | 'optimizer': optimizer.state_dict()},
161 | is_best, checkpoint=args.output_dir, filename='checkpoint.pth.tar')
162 | writer.close()
163 | logger.close()
164 |
165 |
166 | def run_evaluation(model, dataset, data_loader, device, args):
167 | model.eval()
168 | result_dir = args.output_dir
169 | batch_size = args.batch_size
170 |
171 | pck = np.zeros((len(dataset)))
172 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS}
173 | acc_sil_2d = np.zeros(len(dataset))
174 |
175 | smal_pose = np.zeros((len(dataset), 105))
176 | smal_betas = np.zeros((len(dataset), 20))
177 | smal_camera = np.zeros((len(dataset), 3))
178 | smal_imgname = []
179 |
180 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader))
181 |
182 | for step, batch in enumerate(tqdm_iterator):
183 | with torch.no_grad():
184 | keypoints = batch['keypoints'].to(device)
185 | keypoints_norm = batch['keypoints_norm'].to(device)
186 | seg = batch['seg'].to(device)
187 | has_seg = batch['has_seg']
188 | img = batch['img'].to(device)
189 | img_border_mask = batch['img_border_mask'].to(device)
190 | pred_codes = model(img)
191 |
192 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
193 |             pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2, device=device) * config.IMG_RES / 2],
194 | dim=1)
195 | verts, joints, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred,
196 | betas_logscale=betas_scale_pred)
197 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3)
198 | labelled_joints_3d = joints[:, config.MODEL_JOINTS]
199 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera)
200 | if args.save_results:
201 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0)
202 | synth_silhouettes = synth_silhouettes.unsqueeze(1)
203 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera)
204 |
205 | preds = {}
206 | preds['pose'] = pose_pred
207 | preds['betas'] = betas_pred
208 | preds['camera'] = pred_camera
209 | preds['trans'] = trans_pred
210 |
211 | preds['verts'] = verts
212 | preds['joints_3d'] = labelled_joints_3d
213 | preds['faces'] = faces
214 |
215 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg)
216 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg)
217 |
218 | for group, group_kps in config.KEYPOINT_GROUPS.items():
219 | preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg,
220 | thresh_range=[0.15],
221 | idxs=group_kps)
222 |
223 | preds['synth_xyz'] = synth_rgb
224 | preds['synth_silhouettes'] = synth_silhouettes
225 | preds['synth_landmarks'] = synth_landmarks
226 |
227 | assert not any(k in preds for k in batch.keys())
228 | preds.update(batch)
229 |
230 | curr_batch_size = preds['synth_landmarks'].shape[0]
231 |
232 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
233 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
234 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy()
235 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy()
236 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy()
237 |
238 | for part in pck_by_part:
239 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
240 |
241 | # save results as well as visualization
242 | if args.save_results:
243 | output_figs = np.transpose(
244 | Visualizer.generate_output_figures(preds).data.cpu().numpy(),
245 | (0, 1, 3, 4, 2))
246 |
247 | for img_id in range(len(preds['imgname'])):
248 | imgname = preds['imgname'][img_id]
249 | output_fig_list = output_figs[img_id]
250 |
251 | path_parts = imgname.split('/')
252 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1])
253 | img_file = os.path.join(result_dir, path_suffix)
254 | output_fig = np.hstack(output_fig_list)
255 | smal_imgname.append(path_suffix)
256 | # npz_file = "{0}.npz".format(os.path.splitext(img_file)[0])
257 |
258 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0)
259 | # np.savez_compressed(npz_file,
260 | # imgname=preds['imgname'][img_id],
261 | # pose=preds['pose'][img_id].data.cpu().numpy(),
262 | # betas=preds['betas'][img_id].data.cpu().numpy(),
263 | # camera=preds['camera'][img_id].data.cpu().numpy(),
264 | # trans=preds['trans'][img_id].data.cpu().numpy(),
265 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(),
266 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(),
267 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part}
268 | # )
269 |
270 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part
271 |
272 |
273 | if __name__ == '__main__':
274 | parser = argparse.ArgumentParser()
275 | parser.add_argument('--lr', default=0.0001, type=float)
276 | parser.add_argument('--output_dir', default='./logs/', type=str)
277 | parser.add_argument('--nEpochs', default=200, type=int)
278 | parser.add_argument('--w_kpts', default=10, type=float)
279 | parser.add_argument('--w_betas_prior', default=1, type=float)
280 | parser.add_argument('--w_pose_prior', default=1, type=float)
281 | parser.add_argument('--batch_size', default=16, type=int)
282 | parser.add_argument('--num_works', default=4, type=int)
283 | parser.add_argument('--start_epoch', default=0, type=int)
284 | parser.add_argument('--gpu_ids', default='0', type=str)
285 | parser.add_argument('--evaluate', action='store_true')
286 | parser.add_argument('--resume', default=None, type=str)
287 | parser.add_argument('--load_optimizer', action='store_true')
288 | parser.add_argument('--shape_family_id', default=1, type=int)
289 | parser.add_argument('--dataset', default='stanford', type=str)
290 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load')
291 |     parser.add_argument('--shape_init', default='smal', help='initialize the shape prediction with the mean shape')
292 | parser.add_argument('--save_results', action='store_true')
293 | parser.add_argument('--prior_betas', default='smal', type=str)
294 | parser.add_argument('--prior_pose', default='smal', type=str)
295 | parser.add_argument('--betas_scale', action='store_true')
296 |
297 | args = parser.parse_args()
298 | main(args)
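
A minimal sketch of how this script could be launched (illustrative only: the paths are placeholders, only a few of the flags defined above are shown, and the defaults cover the rest).

import subprocess

# train the coarse SMAL-parameter regressor with the default loss weights defined above
subprocess.run(["python", "main.py", "--output_dir", "./logs/", "--gpu_ids", "0"], check=True)

# evaluate a saved checkpoint and write visualizations to the output directory
subprocess.run(["python", "main.py", "--evaluate", "--save_results",
                "--resume", "./logs/checkpoint.pth.tar"], check=True)
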
--------------------------------------------------------------------------------
/main_meshgraph.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import, division
2 |
3 | import os
4 | import numpy as np
5 | from tqdm import tqdm
6 | import cv2
7 | import argparse
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim
11 | from torch.utils.data import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from model.mesh_graph_hg import MeshGraph_hg, init_pretrained
14 |
15 | from util import config
16 | from util.helpers.visualize import Visualizer
17 | from util.loss_utils import kp_l2_loss, Shape_prior, Laplacian
18 | from util.loss_sdf import tversky_loss
19 | from util.metrics import Metrics
20 |
21 | from datasets.stanford import BaseDataset
22 |
23 | from util.logger import Logger
24 | from util.meter import AverageMeterSet
25 | from util.misc import save_checkpoint
26 | from util.pose_prior import Prior
27 | from util.joint_limits_prior import LimitPrior
28 | from util.utils import print_options
29 |
30 | # Set some global variables
31 | global_step = 0
32 | best_pck = 0
33 | best_pck_epoch = 0
34 |
35 |
36 | def main(args):
37 | global best_pck
38 | global best_pck_epoch
39 | global global_step
40 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
41 |
42 | print("RESULTS: {0}".format(args.output_dir))
43 | if not os.path.exists(args.output_dir):
44 | os.mkdir(args.output_dir)
45 |
46 | # set up device
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 |
49 | # set up model
50 | model = MeshGraph_hg(device, args.shape_family_id, args.num_channels, args.num_layers, args.betas_scale,
51 | args.shape_init, args.local_feat, num_downsampling=args.num_downsampling,
52 | render_rgb=args.save_results)
53 |
54 | model = nn.DataParallel(model).to(device)
55 |
56 | # set up dataset
57 | dataset_train = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=True, use_augmentation=True)
58 | data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works)
59 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False)
60 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works)
61 |
62 | # set up optimizer
63 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
64 |
65 | writer = SummaryWriter(os.path.join(args.output_dir, 'train'))
66 |
67 | # set up priors
68 | joint_limit_prior = LimitPrior(device)
69 | shape_prior = Shape_prior(args.prior_betas, args.shape_family_id, device)
70 | tversky = tversky_loss(args.alpha, args.beta)
71 |
72 |     # read the adjacency matrix, which will be used in the Laplacian regularizer
73 | data = np.load('./data/mesh_down_sampling_4.npz', encoding='latin1', allow_pickle=True)
74 | adjmat = data['A'][0]
75 | laplacianloss = Laplacian(adjmat, device)
76 |
77 | if args.resume:
78 | if os.path.isfile(args.resume):
79 | print("=> loading checkpoint {}".format(args.resume))
80 | checkpoint = torch.load(args.resume)
81 | model.load_state_dict(checkpoint['state_dict'])
82 | if args.load_optimizer:
83 | optimizer.load_state_dict(checkpoint['optimizer'])
84 | args.start_epoch = checkpoint['epoch'] + 1
85 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch']))
86 | # logger = Logger(os.path.join(args.output_dir, 'log.txt'), resume=True)
87 | logger = Logger(os.path.join(args.output_dir, 'log.txt'))
88 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU', 'PCK_re', 'IOU_re'])
89 | else:
90 | print("=> no checkpoint found at {}".format(args.resume))
91 | else:
92 | logger = Logger(os.path.join(args.output_dir, 'log.txt'))
93 |         logger.set_names(['Epoch', 'LR', 'PCK', 'IOU', 'PCK_re', 'IOU_re'])
94 |
95 | if args.freezecoarse:
96 | for p in model.module.meshnet.parameters():
97 | p.requires_grad = False
98 | if args.pretrained:
99 | if os.path.isfile(args.pretrained):
100 | print("=> loading checkpoint {}".format(args.pretrained))
101 | checkpoint_pre = torch.load(args.pretrained)
102 | init_pretrained(model, checkpoint_pre)
103 | print("=> loaded checkpoint {} (epoch {})".format(args.pretrained, checkpoint_pre['epoch']))
104 | # logger = Logger(os.path.join(args.output_dir, 'log.txt'), resume=True)
105 | logger = Logger(os.path.join(args.output_dir, 'log.txt'))
106 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU', 'PCK_re', 'IOU_re'])
107 |
108 | print_options(args)
109 | if args.evaluate:
110 | pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args)
111 | print("Evaluate only, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}"
112 | .format(pck, iou_silh, pck_re, iou_re))
113 | return
114 |
115 | lr = args.lr
116 | # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
117 | # optimizer, [160, 190], 0.5,
118 | # )
119 | for epoch in range(args.start_epoch, args.nEpochs):
120 | # lr_scheduler.step()
121 | # lr = lr_scheduler.get_last_lr()[0]
122 | model.train()
123 | tqdm_iterator = tqdm(data_loader_train, desc='Train', total=len(data_loader_train))
124 | meters = AverageMeterSet()
125 | for step, batch in enumerate(tqdm_iterator):
126 | keypoints = batch['keypoints'].to(device)
127 | keypoints_norm = batch['keypoints_norm'].to(device)
128 | seg = batch['seg'].to(device)
129 | img = batch['img'].to(device)
130 |
131 | verts, joints, shape, pred_codes = model(img)
132 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
133 |             pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2, device=device) * config.IMG_RES / 2],
134 | dim=1)
135 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3)
136 | labelled_joints_3d = joints[:, config.MODEL_JOINTS]
137 |
138 | # project 3D joints onto 2D space and apply 2D keypoints supervision
139 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera)
140 | loss_kpts = args.w_kpts * kp_l2_loss(synth_landmarks, keypoints[:, :, [1, 0, 2]], config.NUM_JOINTS)
141 | meters.update('loss_kpt', loss_kpts.item())
142 | loss = loss_kpts
143 |
144 | # use tversky for silhouette loss
145 |             if args.w_dice > 0:
146 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera)
147 | synth_silhouettes = synth_silhouettes.unsqueeze(1)
148 | loss_dice = args.w_dice * tversky(synth_silhouettes, seg)
149 | meters.update('loss_dice', loss_dice.item())
150 | loss += loss_dice
151 |
152 |             # apply shape prior constraint, coming either from SMAL or from the WLDO unity prior
153 | if args.w_betas_prior > 0:
154 | if args.prior_betas == 'smal':
155 | s_prior = args.w_betas_prior * shape_prior(betas_pred)
156 | elif args.prior_betas == 'unity':
157 | betas_pred = torch.cat([betas_pred, betas_scale_pred], dim=1)
158 | s_prior = args.w_betas_prior * shape_prior(betas_pred)
159 | else:
160 |                     raise ValueError(
161 |                         "Shape prior should come from either smal or unity")
162 | meters.update('loss_prior', s_prior.item())
163 | loss += s_prior
164 |
165 |             # apply pose prior constraint, coming either from SMAL or from the WLDO unity prior
166 | if args.w_pose_prior > 0:
167 | if args.prior_pose == 'smal':
168 | pose_prior_path = config.WALKING_PRIOR_FILE
169 | elif args.prior_pose == 'unity':
170 | pose_prior_path = config.UNITY_POSE_PRIOR
171 | else:
172 |                     raise ValueError(
173 |                         "The pose prior should come from either smal or unity")
174 | pose_prior = Prior(pose_prior_path, device)
175 | p_prior = args.w_pose_prior * pose_prior(pose_pred)
176 | meters.update('pose_prior', p_prior.item())
177 | loss += p_prior
178 |
179 | # apply pose limit constraint
180 | if args.w_pose_limit_prior > 0:
181 | pose_limit_loss = args.w_pose_limit_prior * joint_limit_prior(pose_pred)
182 | meters.update('pose_limit', pose_limit_loss.item())
183 | loss += pose_limit_loss
184 |
185 | # get refined meshes by adding del_v to the coarse mesh from SMAL
186 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred,
187 | del_v=shape,
188 | betas_logscale=betas_scale_pred)
189 | # apply 2D keypoint and silhouette supervision
190 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS]
191 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine,
192 | pred_camera)
193 | loss_kpts_refine = args.w_kpts_refine * kp_l2_loss(synth_landmarks_refine, keypoints[:, :, [1, 0, 2]],
194 | config.NUM_JOINTS)
195 |
196 | meters.update('loss_kpt_refine', loss_kpts_refine.item())
197 | loss += loss_kpts_refine
198 |             if args.w_dice_refine > 0:
199 | _, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera)
200 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1)
201 | loss_dice_refine = args.w_dice_refine * tversky(synth_silhouettes_refine, seg)
202 | meters.update('loss_dice_refine', loss_dice_refine.item())
203 | loss += loss_dice_refine
204 |
205 | # apply Laplacian constraint to prevent large deformation predictions
206 | if args.w_arap > 0:
207 | verts_clone = verts.detach().clone()
208 | loss_arap, loss_smooth = laplacianloss(verts_refine, verts_clone)
209 | loss_arap = args.w_arap * loss_arap
210 | meters.update('loss_arap', loss_arap.item())
211 | loss += loss_arap
212 |
213 | meters.update('loss_all', loss.item())
214 | optimizer.zero_grad()
215 | loss.backward()
216 | optimizer.step()
217 | global_step += 1
218 | if step % 20 == 0:
219 | loss_values = meters.averages()
220 | for name, meter in loss_values.items():
221 | writer.add_scalar(name, meter, global_step)
222 | writer.flush()
223 |
224 | pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args)
225 |
226 | print("Epoch: {:3d}, LR: {:6.5f}, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}"
227 | .format(epoch, lr, pck, iou_silh, pck_re, iou_re))
228 | logger.append([epoch, lr, pck, iou_silh, pck_re, iou_re])
229 |
230 | is_best = pck_re > best_pck
231 | if pck_re > best_pck:
232 | best_pck_epoch = epoch
233 | best_pck = max(pck_re, best_pck)
234 | save_checkpoint({'epoch': epoch,
235 | 'state_dict': model.state_dict(),
236 | 'best_pck': best_pck,
237 | 'optimizer': optimizer.state_dict()},
238 | is_best, checkpoint=args.output_dir, filename='checkpoint.pth.tar')
239 | writer.close()
240 | logger.close()
241 |
242 |
243 | def run_evaluation(model, dataset, data_loader, device, args):
244 |
245 | model.eval()
246 | result_dir = args.output_dir
247 | batch_size = args.batch_size
248 |
249 | pck = np.zeros((len(dataset)))
250 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS}
251 | acc_sil_2d = np.zeros(len(dataset))
252 |
253 | pck_re = np.zeros((len(dataset)))
254 | acc_sil_2d_re = np.zeros(len(dataset))
255 |
256 | smal_pose = np.zeros((len(dataset), 105))
257 | smal_betas = np.zeros((len(dataset), 20))
258 | smal_camera = np.zeros((len(dataset), 3))
259 | smal_imgname = []
260 |
261 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader))
262 |
263 | for step, batch in enumerate(tqdm_iterator):
264 | with torch.no_grad():
265 | preds = {}
266 | keypoints = batch['keypoints'].to(device)
267 | keypoints_norm = batch['keypoints_norm'].to(device)
268 | seg = batch['seg'].to(device)
269 | has_seg = batch['has_seg']
270 | img = batch['img'].to(device)
271 | img_border_mask = batch['img_border_mask'].to(device)
272 | # get coarse meshes and project onto 2D
273 | verts, joints, shape, pred_codes = model(img)
274 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes
275 |             pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2, device=device) * config.IMG_RES / 2],
276 | dim=1)
277 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3)
278 | labelled_joints_3d = joints[:, config.MODEL_JOINTS]
279 |
280 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera)
281 | synth_silhouettes = synth_silhouettes.unsqueeze(1)
282 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera)
283 |
284 | # get refined meshes by adding del_v to coarse estimations
285 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred,
286 | del_v=shape,
287 | betas_logscale=betas_scale_pred)
288 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS]
289 | # project refined 3D meshes onto 2D
290 | synth_rgb_refine, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera)
291 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1)
292 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine,
293 | pred_camera)
294 |
295 | if args.save_results:
296 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0)
297 | synth_rgb_refine = torch.clamp(synth_rgb_refine[0], 0.0, 1.0)
298 |
299 | preds['pose'] = pose_pred
300 | preds['betas'] = betas_pred
301 | preds['camera'] = pred_camera
302 | preds['trans'] = trans_pred
303 |
304 | preds['verts'] = verts
305 | preds['joints_3d'] = labelled_joints_3d
306 | preds['faces'] = faces
307 |
308 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg)
309 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg)
310 |
311 | preds['acc_PCK_re'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg)
312 | preds['acc_IOU_re'] = Metrics.IOU(synth_silhouettes_refine, seg, img_border_mask, mask=has_seg)
313 |
314 | for group, group_kps in config.KEYPOINT_GROUPS.items():
315 | preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg,
316 | thresh_range=[0.15],
317 | idxs=group_kps)
318 |
319 | preds['synth_xyz'] = synth_rgb
320 | preds['synth_silhouettes'] = synth_silhouettes
321 | preds['synth_landmarks'] = synth_landmarks
322 | preds['synth_xyz_re'] = synth_rgb_refine
323 | preds['synth_landmarks_re'] = synth_landmarks_refine
324 | preds['synth_silhouettes_re'] = synth_silhouettes_refine
325 |
326 | assert not any(k in preds for k in batch.keys())
327 | preds.update(batch)
328 |
329 | curr_batch_size = preds['synth_landmarks'].shape[0]
330 |
331 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
332 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
333 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy()
334 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy()
335 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy()
336 |
337 | pck_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK_re'].data.cpu().numpy()
338 | acc_sil_2d_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU_re'].data.cpu().numpy()
339 | for part in pck_by_part:
340 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
341 |
342 | if args.save_results:
343 | output_figs = np.transpose(
344 | Visualizer.generate_output_figures(preds, vis_refine=True).data.cpu().numpy(),
345 | (0, 1, 3, 4, 2))
346 |
347 | for img_id in range(len(preds['imgname'])):
348 | imgname = preds['imgname'][img_id]
349 | output_fig_list = output_figs[img_id]
350 |
351 | path_parts = imgname.split('/')
352 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1])
353 | img_file = os.path.join(result_dir, path_suffix)
354 | output_fig = np.hstack(output_fig_list)
355 | smal_imgname.append(path_suffix)
356 | npz_file = "{0}.npz".format(os.path.splitext(img_file)[0])
357 |
358 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0)
359 | # np.savez_compressed(npz_file,
360 | # imgname=preds['imgname'][img_id],
361 | # pose=preds['pose'][img_id].data.cpu().numpy(),
362 | # betas=preds['betas'][img_id].data.cpu().numpy(),
363 | # camera=preds['camera'][img_id].data.cpu().numpy(),
364 | # trans=preds['trans'][img_id].data.cpu().numpy(),
365 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(),
366 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(),
367 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part}
368 | # )
369 |
370 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part, np.nanmean(pck_re), np.nanmean(acc_sil_2d_re)
371 |
372 |
373 | if __name__ == '__main__':
374 | parser = argparse.ArgumentParser()
375 | parser.add_argument('--lr', default=0.0001, type=float)
376 | parser.add_argument('--output_dir', default='./logs/', type=str)
377 | parser.add_argument('--nEpochs', default=250, type=int)
378 |
379 | parser.add_argument('--w_kpts', default=10, type=float)
380 | parser.add_argument('--w_betas_prior', default=1, type=float)
381 | parser.add_argument('--w_pose_prior', default=1, type=float)
382 | parser.add_argument('--w_pose_limit_prior', default=0, type=float)
383 | parser.add_argument('--w_kpts_refine', default=1, type=float)
384 |
385 | parser.add_argument('--batch_size', default=16, type=int)
386 | parser.add_argument('--num_works', default=4, type=int)
387 | parser.add_argument('--start_epoch', default=0, type=int)
388 | parser.add_argument('--gpu_ids', default='0', type=str)
389 | parser.add_argument('--evaluate', action='store_true')
390 | parser.add_argument('--resume', default=None, type=str)
391 | parser.add_argument('--load_optimizer', action='store_true')
392 | parser.add_argument('--dataset', default='stanford', type=str)
393 | parser.add_argument('--shape_family_id', default=1, type=int)
394 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load')
395 |
396 |     parser.add_argument('--shape_init', default='smal', help='initialize the shape prediction with the mean shape')
397 | parser.add_argument('--save_results', action='store_true')
398 | parser.add_argument('--prior_betas', default='smal', type=str)
399 | parser.add_argument('--prior_pose', default='smal', type=str)
400 | parser.add_argument('--betas_scale', action='store_true')
401 |
402 | parser.add_argument('--num_channels', type=int, default=256, help='Number of channels in Graph Residual layers')
403 | parser.add_argument('--num_layers', type=int, default=5, help='Number of residuals blocks in the Graph CNN')
404 | parser.add_argument('--pretrained', default=None, type=str)
405 | parser.add_argument('--local_feat', action='store_true')
406 |
407 | parser.add_argument('--num_downsampling', default=1, type=int)
408 | parser.add_argument('--freezecoarse', action='store_true')
409 |
410 | parser.add_argument('--w_arap', default=0, type=float)
411 | parser.add_argument('--w_dice', default=0, type=float)
412 | parser.add_argument('--w_dice_refine', default=0, type=float)
413 | parser.add_argument('--alpha', default=0.6, type=float)
414 | parser.add_argument('--beta', default=0.4, type=float)
415 |
416 | args = parser.parse_args()
417 | main(args)
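
The silhouette terms above (w_dice and w_dice_refine) use a Tversky loss, which generalizes the Dice loss by weighting false positives with alpha and false negatives with beta (alpha = beta = 0.5 recovers Dice). A minimal sketch of the standard formulation on soft silhouettes is shown below for reference; the repository's own tversky_loss in util/loss_sdf.py may differ in its details.

import torch

def tversky_index_loss(pred, target, alpha=0.6, beta=0.4, eps=1e-7):
    # pred, target: (B, 1, H, W) silhouettes with values in [0, 1]
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    tp = (pred * target).sum(dim=1)            # true positives (overlap)
    fp = (pred * (1 - target)).sum(dim=1)      # false positives
    fn = ((1 - pred) * target).sum(dim=1)      # false negatives
    return (1 - tp / (tp + alpha * fp + beta * fn + eps)).mean()
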
--------------------------------------------------------------------------------
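
Finally, a minimal sketch of launching the coarse-to-fine refinement stage, optionally initialized from a coarse checkpoint trained with main.py. Paths and loss weights below are illustrative placeholders, not values taken from the paper or the repository.

import subprocess

subprocess.run(["python", "main_meshgraph.py", "--output_dir", "./logs_refine/",
                "--pretrained", "./logs/checkpoint.pth.tar",   # coarse-stage weights (placeholder)
                "--local_feat", "--betas_scale",
                "--w_dice", "0.5", "--w_dice_refine", "0.5",   # illustrative weights
                "--w_arap", "0.1"],
               check=True)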