├── coma ├── __init__.py └── mesh_sampling.py ├── data ├── readme.txt └── mesh_down_sampling_4.npz ├── model ├── __init__.py ├── networks │ ├── __init__.py │ ├── linear_model.py │ └── graph_layers.py ├── model_v1.py ├── graph_hg.py ├── mesh_graph_hg.py └── smal_mesh_net_img.py ├── smal ├── __init__.py ├── smal_basics.py ├── mesh.py ├── batch_lbs.py └── smal_torch.py ├── util ├── __init__.py ├── utils.py ├── misc.py ├── meter.py ├── pose_prior.py ├── geom_utils.py ├── metrics.py ├── helpers │ ├── draw_smal_joints.py │ ├── conversions.py │ └── visualize.py ├── config.py ├── loss_utils.py ├── loss_sdf.py ├── logger.py ├── joint_limits_prior.py ├── nmr.py └── net_blocks.py ├── datasets ├── __init__.py ├── imutils.py └── stanford.py ├── network.png ├── result_examples.png ├── LICENSE ├── requirements.txt ├── README.md ├── eval.py ├── main.py └── main_meshgraph.py /coma/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /smal/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/network.png -------------------------------------------------------------------------------- /result_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/result_examples.png -------------------------------------------------------------------------------- /data/mesh_down_sampling_4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/Coarse-to-fine-3D-Animal/HEAD/data/mesh_down_sampling_4.npz -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | 4 | def print_options(args): 5 | message = '' 6 | message += '----------------- Options ---------------\n' 7 | for k, v in sorted(vars(args).items()): 8 | comment = '' 9 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 10 | message += '----------------- End -------------------' 11 | print(message) 12 | 13 | # save to the disk 14 | file_name = 
osp.join(args.output_dir, 'opt.txt') 15 | with open(file_name, 'wt') as args_file: 16 | args_file.write(message) 17 | args_file.write('\n') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Li Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | import os 4 | 5 | 6 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 7 | 8 | filepath = os.path.join(checkpoint, filename) 9 | torch.save(state, filepath) 10 | if is_best: 11 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 12 | 13 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): 14 | """Sets the learning rate to the initial LR decayed by schedule""" 15 | if epoch in schedule: 16 | lr *= gamma 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | return lr 20 | 21 | 22 | def lr_poly(base_lr, epoch, max_epoch, power): 23 | """ Poly_LR scheduler 24 | """ 25 | return base_lr * ((1 - float(epoch) / max_epoch) ** power) 26 | 27 | 28 | def adjust_learning_rate_main(optimizer, epoch, args): 29 | lr = lr_poly(args.lr, epoch, args.max_epoch, args.power) 30 | for param_group in optimizer.param_groups: 31 | param_group['lr'] = lr 32 | return lr 33 | 34 | 35 | def adjust_learning_rate_exponential(optimizer, epoch, epoch_decay, learning_rate, decay_rate): 36 | lr = learning_rate * (decay_rate ** (epoch / epoch_decay)) 37 | for param_group in optimizer.param_groups: 38 | param_group['lr'] = lr 39 | return lr -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.1 2 | cached-property==1.5.2 3 | cachetools==4.2.4 4 | certifi==2020.12.5 5 | chamferdist==0.3.0 6 | chardet==4.0.0 7 | chumpy==0.70 8 | cycler==0.10.0 9 | Cython==0.29.23 10 | decorator==4.4.2 11 | freetype-py==2.2.0 12 | future==0.18.2 13 | google-auth==1.35.0 14 | google-auth-oauthlib==0.4.6 15 | grpcio==1.41.0 16 | h5py==3.2.1 17 | idna==2.10 18 | imageio==2.9.0 19 | importlib-metadata==4.8.1 20 | jsonpatch==1.32 21 | 
jsonpointer==2.1 22 | kiwisolver==1.3.1 23 | Markdown==3.3.4 24 | matplotlib==3.3.4 25 | networkx==2.5.1 26 | neural-renderer-pytorch==1.1.3 27 | nibabel==3.2.1 28 | numpy==1.20.1 29 | oauthlib==3.1.1 30 | opencv-python==4.5.1.48 31 | packaging==21.0 32 | Pillow==8.1.2 33 | pointnet2-ops==3.0.0 34 | protobuf==3.18.1 35 | pyasn1==0.4.8 36 | pyasn1-modules==0.2.8 37 | pycocotools==2.0.2 38 | pyglet==1.5.18 39 | PyOpenGL==3.1.0 40 | pyparsing==2.4.7 41 | PyQt5==5.9 42 | PyQt5-Qt5==5.15.2 43 | PyQt5-sip==12.9.0 44 | pyrender==0.1.45 45 | python-dateutil==2.8.1 46 | PyWavelets==1.1.1 47 | pyzmq==22.0.3 48 | requests==2.25.1 49 | requests-oauthlib==1.3.0 50 | rsa==4.7.2 51 | scikit-image==0.18.1 52 | scipy==1.2.0 53 | sip==4.19.8 54 | six==1.15.0 55 | tensorboard==2.6.0 56 | tensorboard-data-server==0.6.1 57 | tensorboard-plugin-wit==1.8.0 58 | tifffile==2021.4.8 59 | torch==1.5.0+cu101 60 | torchfile==0.1.0 61 | torchvision==0.6.0+cu101 62 | tornado==6.1 63 | tqdm==4.60.0 64 | transforms3d==0.3.1 65 | trimesh==3.9.26 66 | typing-extensions==3.10.0.2 67 | urllib3==1.26.3 68 | visdom==0.1.8.9 69 | websocket-client==0.58.0 70 | Werkzeug==2.0.2 71 | zipp==3.6.0 72 | -------------------------------------------------------------------------------- /util/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter: 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | def __format__(self, format): 20 | return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format) 21 | 22 | 23 | class AverageMeterSet: 24 | def __init__(self): 25 | self.meters = {} 26 | 27 | def __getitem__(self, key): 28 | return self.meters[key] 29 | 30 | def update(self, name, value, n=1): 31 | if not name in self.meters: 32 | self.meters[name] = AverageMeter() 33 | self.meters[name].update(value, n) 34 | 35 | def reset(self): 36 | for meter in self.meters.values(): 37 | meter.reset() 38 | 39 | def values(self, postfix=''): 40 | return {name + postfix: meter.val for name, meter in self.meters.items()} 41 | 42 | def averages(self, postfix='/avg'): 43 | return {name + postfix: meter.avg for name, meter in self.meters.items()} 44 | 45 | def sums(self, postfix='/sum'): 46 | return {name + postfix: meter.sum for name, meter in self.meters.items()} 47 | 48 | def counts(self, postfix='/count'): 49 | return {name + postfix: meter.count for name, meter in self.meters.items()} -------------------------------------------------------------------------------- /util/pose_prior.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | from chumpy import Ch 4 | import cv2 5 | import torch 6 | 7 | name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1, 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15, 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11, 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27, 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar':33, 'REar':34} 8 | id2name35 = {v: k for k, 
v in name2id35.items()} 9 | 10 | 11 | class Prior(object): 12 | def __init__(self, prior_path, device): 13 | with open(prior_path, 'rb') as f: 14 | res = pkl.load(f, encoding='latin1') 15 | 16 | self.mean_ch = res['mean_pose'] 17 | self.precs_ch = res['pic'] 18 | 19 | self.precs = torch.from_numpy(res['pic'].r.copy()).float().to(device) 20 | self.mean = torch.from_numpy(res['mean_pose']).float().to(device) 21 | 22 | prefix = 3 23 | pose_len = 105 24 | id2name = id2name35 25 | name2id = name2id35 26 | 27 | self.use_ind = np.ones(pose_len, dtype=bool) 28 | self.use_ind[:prefix] = False 29 | self.use_ind_tch = torch.from_numpy(self.use_ind).float().to(device) 30 | 31 | def __call__(self, x): 32 | mean_sub = x.reshape(-1, 35*3) - self.mean.unsqueeze(0) 33 | res = torch.tensordot(mean_sub, self.precs, dims=([1], [0])) * self.use_ind_tch 34 | return (res**2).mean() -------------------------------------------------------------------------------- /smal/smal_basics.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | 4 | 5 | def align_smal_template_to_symmetry_axis(v, sym_file): 6 | # These are the indexes of the points that are on the symmetry axis 7 | I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003] 8 | 9 | v = v - np.mean(v) 10 | y = np.mean(v[I,1]) 11 | v[:,1] = v[:,1] - y 12 | v[I,1] = 0 13 | 14 | # symIdx = pkl.load(open(sym_path)) 15 | with open(sym_file, 'rb') as f: 16 | u = pkl._Unpickler(f) 17 | u.encoding = 'latin1' 18 | symIdx = u.load() 19 | 20 | 21 | left = v[:, 1] < 0 22 | right = v[:, 1] > 0 23 | center = v[:, 1] == 0 24 | v[left[symIdx]] = np.array([1,-1,1])*v[left] 25 | 26 | left_inds = np.where(left)[0] 27 | right_inds = np.where(right)[0] 28 | center_inds = np.where(center)[0] 29 | 30 | try: 31 | assert(len(left_inds) == len(right_inds)) 32 | except: 33 | import pdb; pdb.set_trace() 34 | 35 | return v, left_inds, right_inds, center_inds -------------------------------------------------------------------------------- /util/geom_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils related to geometry like projection,, 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | 10 | def sample_textures(texture_flow, images): 11 | """ 12 | texture_flow: B x F x T x T x 2 13 | (In normalized coordinate [-1, 1]) 14 | images: B x 3 x N x N 15 | 16 | output: B x F x T x T x 3 17 | """ 18 | # Reshape into B x F x T*T x 2 19 | T = texture_flow.size(-2) 20 | F = texture_flow.size(1) 21 | flow_grid = texture_flow.view(-1, F, T * T, 2) 22 | # B x 3 x F x T*T 23 | samples = torch.nn.functional.grid_sample(images, flow_grid) 24 | # B x 3 x F x T x T 25 | samples = samples.view(-1, 3, F, T, T) 26 | # B x F x T 
x T x 3 27 | return samples.permute(0, 2, 3, 4, 1) 28 | 29 | 30 | def perspective_proj_withz(X, cam, offset_z=0, cuda_device=0,norm_f=1., norm_z=0.,norm_f0=0.): 31 | """ 32 | X: B x N x 3 33 | cam: B x 3: [f, cx, cy] 34 | offset_z is for being compatible with previous code and is not used and should be removed 35 | """ 36 | 37 | # B x 1 x 1 38 | #f = norm_f * cam[:, 0].contiguous().view(-1, 1, 1) 39 | f = norm_f0+norm_f * cam[:, 0].contiguous().view(-1, 1, 1) 40 | # f = norm_f0 41 | # B x N x 1 42 | z = norm_z + X[:, :, 2, None] 43 | 44 | # Will z ever be 0? We probably should max it.. 45 | eps = 1e-6 * torch.ones(1).cuda(device=cuda_device) 46 | z = torch.max(z, eps) 47 | image_size_half = cam[0,1] 48 | scale = f / (z*image_size_half) 49 | 50 | # Offset is because cam is at -1 51 | return torch.cat((scale * X[:, :, :2], z+offset_z),2) 52 | 53 | -------------------------------------------------------------------------------- /model/networks/linear_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch.nn as nn 5 | 6 | 7 | def weight_init(m): 8 | if isinstance(m, nn.Linear): 9 | nn.init.kaiming_normal(m.weight) 10 | 11 | 12 | class Linear(nn.Module): 13 | def __init__(self, linear_size, p_dropout=0.5): 14 | super(Linear, self).__init__() 15 | self.l_size = linear_size 16 | 17 | self.relu = nn.ReLU(inplace=True) 18 | self.dropout = nn.Dropout(p_dropout) 19 | 20 | self.w1 = nn.Linear(self.l_size, self.l_size) 21 | self.batch_norm1 = nn.BatchNorm1d(self.l_size) 22 | 23 | self.w2 = nn.Linear(self.l_size, self.l_size) 24 | self.batch_norm2 = nn.BatchNorm1d(self.l_size) 25 | 26 | def forward(self, x): 27 | y = self.w1(x) 28 | y = self.batch_norm1(y) 29 | y = self.relu(y) 30 | y = self.dropout(y) 31 | 32 | y = self.w2(y) 33 | y = self.batch_norm2(y) 34 | y = self.relu(y) 35 | y = self.dropout(y) 36 | 37 | out = x + y 38 | 39 | return out 40 | 41 | 42 | class LinearModel(nn.Module): 43 | def __init__(self, input_size, output_size, linear_size=1024, num_stage=2, p_dropout=0.5): 44 | super(LinearModel, self).__init__() 45 | self.linear_size = linear_size 46 | self.p_dropout = p_dropout 47 | self.num_stage = num_stage 48 | 49 | self.input_size = input_size 50 | self.output_size = output_size 51 | 52 | self.w1 = nn.Linear(self.input_size, self.linear_size) 53 | self.batch_norm1 = nn.BatchNorm1d(self.linear_size) 54 | 55 | self.linear_stages = [] 56 | for _ in range(num_stage): 57 | self.linear_stages.append(Linear(self.linear_size, self.p_dropout)) 58 | self.linear_stages = nn.ModuleList(self.linear_stages) 59 | 60 | self.w2 = nn.Linear(self.linear_size, self.output_size) 61 | 62 | self.relu = nn.ReLU(inplace=True) 63 | self.dropout = nn.Dropout(self.p_dropout) 64 | 65 | def forward(self, x): 66 | y = self.w1(x) 67 | y = self.batch_norm1(y) 68 | y = self.relu(y) 69 | y = self.dropout(y) 70 | 71 | for i in range(self.num_stage): 72 | y = self.linear_stages[i](y) 73 | 74 | y = self.w2(y) 75 | 76 | return y 77 | 78 | -------------------------------------------------------------------------------- /model/model_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from util.nmr import NeuralRenderer 5 | from smal.smal_torch import SMAL 6 | from model.smal_mesh_net_img import MeshNet_img 7 | from util import config 8 | 9 | unity_shape_prior = 
np.load('data/priors/unity_betas.npz') 10 | 11 | 12 | class MeshModel(nn.Module): 13 | def __init__(self, device, shape_family_id, betas_scale=False, shape_init=None, render_rgb=False): 14 | ''' 15 | Args: 16 | device: specify device for training 17 | shape_family_id: specify animal category id 18 | betas_scale: whether predict the additional shape parameter proposed by WLDO 19 | shape_init: whether intialize the bias with a mean shape, choose from smal or unity 20 | render_rgb: whether render 3D mesh into 2D to get rgb image. Only set to true when generating 21 | visualization to save inference time. 22 | ''' 23 | 24 | super(MeshModel, self).__init__() 25 | 26 | self.model_renderer = NeuralRenderer(config.IMG_RES, proj_type=config.PROJECTION, 27 | norm_f0=config.NORM_F0, 28 | norm_f=config.NORM_F, 29 | norm_z=config.NORM_Z, 30 | render_rgb=render_rgb, 31 | device=device) 32 | self.model_renderer.directional_light_only() 33 | self.smal = SMAL(device, shape_family_id=shape_family_id) 34 | if shape_init == 'smal': 35 | print("Initiate shape with smal prior") 36 | shape_init = self.smal.shape_cluster_means 37 | elif shape_init == 'unity': 38 | print("Initiate shape with unity prior ") 39 | shape_init = unity_shape_prior['mean'][:-1] 40 | shape_init = torch.from_numpy(shape_init).float().to(device) 41 | else: 42 | print("No initialization for shape") 43 | shape_init = None 44 | input_size = [config.IMG_RES, config.IMG_RES] 45 | self.meshnet = MeshNet_img(input_size, betas_scale=betas_scale, norm_f0=config.NORM_F0, nz_feat=config.NZ_FEAT, 46 | shape_init=shape_init) 47 | print('INITIALIZED') 48 | 49 | def forward(self, inp): 50 | 51 | pred_codes = self.meshnet(inp) 52 | 53 | return pred_codes -------------------------------------------------------------------------------- /util/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util import config 3 | import numpy as np 4 | 5 | 6 | class Metrics(): 7 | 8 | @staticmethod 9 | def PCK_thresh( 10 | pred_keypoints, gt_keypoints, 11 | gtseg, has_seg, 12 | thresh, idxs): 13 | 14 | pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg] 15 | 16 | if idxs is None: 17 | idxs = list(range(pred_keypoints.shape[1])) 18 | 19 | idxs = np.array(idxs).astype(int) 20 | 21 | pred_keypoints = pred_keypoints[:, idxs] 22 | gt_keypoints = gt_keypoints[:, idxs] 23 | 24 | keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * config.IMG_RES 25 | dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim=-1) 26 | seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim=-1).unsqueeze(-1) 27 | 28 | hits = (dist / torch.sqrt(seg_area)) < thresh 29 | total_visible = torch.sum(gt_keypoints[:, :, -1], dim=-1) 30 | pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim=-1) / total_visible 31 | 32 | return pck 33 | 34 | @staticmethod 35 | def PCK( 36 | pred_keypoints, keypoints, 37 | gtseg, has_seg, 38 | thresh_range=[0.15], 39 | idxs: list = None): 40 | 41 | """Calc PCK with same method as in eval. 
42 | idxs = optional list of subset of keypoints to index from 43 | """ 44 | 45 | cumulative_pck = [] 46 | for thresh in thresh_range: 47 | pck = Metrics.PCK_thresh( 48 | pred_keypoints, keypoints, 49 | gtseg, has_seg, thresh, idxs) 50 | cumulative_pck.append(pck) 51 | 52 | pck_mean = torch.stack(cumulative_pck, dim=0).mean(dim=0) 53 | return pck_mean 54 | 55 | @staticmethod 56 | def IOU(synth_silhouettes, gt_seg, img_border_mask, mask): 57 | for i in range(mask.shape[0]): 58 | synth_silhouettes[i] *= mask[i] 59 | 60 | # Do not penalize parts of the segmentation outside the img range 61 | gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask) 62 | 63 | intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim=-1) 64 | union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim=-1) 65 | acc_IOU_SCORE = intersection / union 66 | 67 | return acc_IOU_SCORE -------------------------------------------------------------------------------- /util/helpers/draw_smal_joints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | from torchvision.utils import make_grid 5 | from matplotlib import cm 6 | import matplotlib.pyplot as plt 7 | from matplotlib.colors import Normalize 8 | import torch.nn.functional as F 9 | import torch 10 | 11 | 12 | # draw keypoints in the input image, code adopted from WLDO 13 | class SMALJointDrawer(): 14 | def __init__(self): 15 | self.jet_colormap = cm.ScalarMappable(norm = Normalize(0, 1), cmap = 'jet') 16 | 17 | def draw_joints(self, image, landmarks, visible = None, normalized = True, marker_size = 8, thickness = 3): 18 | image_np = np.transpose(image.cpu().data.numpy(), (0, 2, 3, 1)) 19 | landmarks_np = landmarks.cpu().data.numpy() 20 | if visible is not None: 21 | visible_np = visible.cpu().data.numpy() 22 | else: 23 | visible_np = visible 24 | 25 | return_stack = self.draw_joints_np(image_np, landmarks_np, visible_np, normalized, marker_size=marker_size, thickness=thickness) 26 | return torch.FloatTensor(np.transpose(return_stack, (0, 3, 1, 2))) 27 | 28 | def draw_joints_np(self, image_np, landmarks_np, visible_np = None, normalized = False, marker_size = 8, thickness = 3): 29 | if normalized: 30 | image_np = (image_np * 0.5) + 0.5 31 | 32 | image_np = (image_np * 255.0).astype(np.uint8) 33 | 34 | bs, nj, _ = landmarks_np.shape 35 | if visible_np is None: 36 | visible_np = np.ones((bs, nj), dtype=bool) 37 | 38 | return_images = [] 39 | for image_sgl, landmarks_sgl, visible_sgl in zip(image_np, landmarks_np, visible_np): 40 | image_sgl = image_sgl.copy() 41 | for joint_id, ((y_co, x_co), vis) in enumerate(zip(landmarks_sgl, visible_sgl)): 42 | if x_co<0 or y_co<0 or x_co>223 and y_co>223: 43 | continue 44 | else: 45 | color = np.array([255, 0, 0]) 46 | marker_type = 0 47 | if not vis: 48 | continue 49 | cv2.drawMarker(image_sgl, (int(x_co), int(y_co)), (int(color[0]), int(color[1]), int(color[2])), marker_type, marker_size, thickness = thickness) 50 | 51 | return_images.append(image_sgl) 52 | 53 | return_stack = np.stack(return_images, 0) 54 | return_stack = return_stack / 255.0 55 | if normalized: 56 | return_stack = (return_stack - 0.5) / 0.5 # Scale and re-normalize for pytorch 57 | 58 | return return_stack 59 | 60 | def draw_heatmap_grids(self, heatmaps, silhouettes, upsample_val = 1, alpha_blend = 0.9): 61 | bs, jts, h, w = heatmaps.shape 62 | heatmap_grids = [] 63 | for 
heatmap, silhouette in zip(heatmaps, silhouettes): 64 | heatmap_rgb = heatmap[:, None, :, :].expand(jts, 3, h, w) * alpha_blend + silhouette * (1 - alpha_blend) 65 | grid = make_grid(heatmap_rgb, nrow = 5) 66 | heatmap_jet = self.jet_colormap.to_rgba(grid[0].cpu().numpy())[:, :, :3] 67 | heatmap_jet = np.transpose(heatmap_jet, (2, 0, 1)) 68 | heatmap_grids.append(torch.FloatTensor(heatmap_jet)) 69 | 70 | heatmap_stack = torch.stack(heatmap_grids, dim = 0) 71 | heatmap_stack = (heatmap_stack - 0.5) / 0.5 72 | 73 | return F.interpolate(heatmap_stack, [h * upsample_val, w * upsample_val]) 74 | 75 | -------------------------------------------------------------------------------- /util/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of useful data stuctures and the paths 3 | for the datasets and data files necessary to run the code. 4 | Things you need to change: *_ROOT that indicate the path to each dataset 5 | """ 6 | from os.path import join 7 | import os 8 | 9 | # Define paths to each dataset 10 | 11 | CODE_DIR = os.getcwd() 12 | BASE_FOLDER = join(CODE_DIR, 'data') 13 | # Output folder to save test/train npz files 14 | DATASET_NPZ_PATH = join(BASE_FOLDER, 'splits') 15 | 16 | # Path to test/train npy files 17 | DATASET_FILES = [ 18 | { 19 | 'stanford': join(BASE_FOLDER, 'StanfordExtra_v12', 'test_stanford_StanfordExtra_v12.npy'), 20 | 'animal_pose' : join(BASE_FOLDER, 'animal_pose', 'test_animal_pose.npy') 21 | }, 22 | { 23 | 'stanford': join(BASE_FOLDER, 'StanfordExtra_v12', 'train_stanford_StanfordExtra_v12.npy'), 24 | 'animal_pose': join(BASE_FOLDER, 'animal_pose', 'train_animal_pose.npy') 25 | } 26 | ] 27 | 28 | 29 | DATASET_FOLDERS = { 30 | 'stanford' : join(BASE_FOLDER, 'StanfordExtra_v12'), 31 | 'animal_pose' : join(BASE_FOLDER, 'animal_pose') 32 | } 33 | 34 | JSON_NAME = { 35 | 'stanford': 'StanfordExtra_v12.json', # the latest version of the StanfordExtra dataset 36 | 'animal_pose': 'animal_pose_data.json' 37 | } 38 | 39 | BREEDS_CSV = join(BASE_FOLDER, 'breeds.csv') 40 | 41 | EM_DATASET_NAME = "stanford" # the dataset to learn the EM prior on 42 | 43 | data_path = join(CODE_DIR, 'data') 44 | # SMAL 45 | SMAL_FILE = join(data_path, 'smal', 'my_smpl_00781_4_all.pkl') 46 | SMAL_DATA_FILE = join(data_path, 'smal', 'my_smpl_data_00781_4_all.pkl') 47 | SMAL_UV_FILE = join(data_path, 'smal', 'my_smpl_00781_4_all_template_w_tex_uv_001.pkl') 48 | SMAL_SYM_FILE = join(data_path, 'smal', 'symIdx.pkl') 49 | SHAPE_FAMILY_ID = 1 # the dog shape family 50 | 51 | # PRIORS 52 | WALKING_PRIOR_FILE = join(data_path, 'priors', 'walking_toy_symmetric_pose_prior_with_cov_35parts.pkl') 53 | UNITY_POSE_PRIOR = join(data_path, 'priors', 'unity_pose_prior_with_cov_35parts.pkl') 54 | UNITY_SHAPE_PRIOR = join(data_path, 'priors', 'unity_betas.npz') 55 | SMAL_DOG_TOY_IDS = [0, 1, 2] 56 | 57 | # DATALOADER 58 | IMG_RES = 224 59 | NUM_JOINTS = 20 60 | # Mean and standard deviation for normalizing input image 61 | IMG_NORM_MEAN = [0.485, 0.456, 0.406] 62 | IMG_NORM_STD = [0.229, 0.224, 0.225] 63 | 64 | # RENDERER 65 | PROJECTION = 'perspective' 66 | NORM_F0 = 2700.0 67 | NORM_F = 2700.0 68 | NORM_Z = 20.0 69 | 70 | MESH_COLOR = [0, 172, 223] 71 | # MESH_COLOR=[234, 156, 199.] 
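# Illustrative note on the RENDERER constants above (a hedged sketch of how they are
# consumed, not additional config values): util/geom_utils.perspective_proj_withz
# builds the perspective camera roughly as
#   f = NORM_F0 + NORM_F * f_pred        # f_pred: camera value regressed by the network (assumed name)
#   z = NORM_Z + Z                       # Z: vertex depth of the SMAL mesh
#   x_proj = f * X / (z * IMG_RES / 2)   # and likewise for y
# where the camera vector is assembled as [f_pred, IMG_RES/2, IMG_RES/2]
# (see model/mesh_graph_hg.py). For example, with the defaults above, a point at
# X = 0.5, Z = 0 and f_pred = 0 projects to about 2700 * 0.5 / (20 * 112) ≈ 0.60.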
72 | 73 | # MESH_NET 74 | NZ_FEAT = 100 75 | 76 | # ASSOCIATING SMAL TO ANNOTATED JOINTS 77 | MODEL_JOINTS = [ 78 | 14, 13, 12, # left front (0, 1, 2) 79 | 24, 23, 22, # left rear (3, 4, 5) 80 | 10, 9, 8, # right front (6, 7, 8) 81 | 20, 19, 18, # right rear (9, 10, 11) 82 | 25, 31, # tail start -> end (12, 13) 83 | 34, 33, # right ear, left ear (14, 15) 84 | 35, 36, # nose, chin (16, 17) 85 | 37, 38] # right tip, left tip (18, 19) 86 | 87 | CANONICAL_MODEL_JOINTS = [ 88 | 10, 9, 8, # upper_left [paw, middle, top] 89 | 20, 19, 18, # lower_left [paw, middle, top] 90 | 14, 13, 12, # upper_right [paw, middle, top] 91 | 24, 23, 22, # lower_right [paw, middle, top] 92 | 25, 31, # tail [start, end] 93 | 33, 34, # ear base [left, right] 94 | 35, 36, # nose, chin 95 | 38, 37, # ear tip [left, right] 96 | 39, 40, # eyes [left, right] 97 | 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat 98 | 28] # tail middle 99 | 100 | 101 | EVAL_KEYPOINTS = [ 102 | 0, 1, 2, # left front 103 | 3, 4, 5, # left rear 104 | 6, 7, 8, # right front 105 | 9, 10, 11, # right rear 106 | 12, 13, # tail start -> end 107 | 14, 15, # left ear, right ear 108 | 16, 17, # nose, chin 109 | 18, 19] # left tip, right tip 110 | 111 | KEYPOINT_GROUPS = { 112 | 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs 113 | 'tail': [12, 13], # tail 114 | 'ears': [14, 15, 18, 19], # ears 115 | 'face': [16, 17] # face 116 | } 117 | -------------------------------------------------------------------------------- /util/loss_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import numpy as np 7 | from util import config 8 | import pickle as pkl 9 | criterionL2 = torch.nn.MSELoss() 10 | criterionL1 = torch.nn.L1Loss() 11 | 12 | 13 | # L1 or L2 based sihouette loss 14 | def mask_loss(mask_pred, mask_gt): 15 | # return torch.nn.L1Loss()(mask_pred, mask_gt) 16 | return criterionL2(mask_pred, mask_gt) 17 | 18 | 19 | # 2D keypoint loss 20 | def kp_l2_loss(kp_pred, kp_gt, num_joints): 21 | vis = (kp_gt[:, :, 2, None] > 0).float() 22 | if not kp_pred.ndim == 3: 23 | kp_pred.reshape((kp_pred.shape[0], num_joints, -1)) 24 | return criterionL2(vis * kp_pred, vis * kp_gt[:, :, :2]) 25 | 26 | 27 | # compute shape prior based on Mahalanobis distance, formulation taken from 28 | # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py 29 | class Shape_prior(torch.nn.Module): 30 | def __init__(self, prior, shape_family_id, device, data_path=None): 31 | ''' 32 | Args: 33 | prior: specify the prior to use, smal or unity or self-defined data 34 | shape_family_id: specify animal category id 35 | device: specify device 36 | data_path: specify self-defined data path if do not use smal or unity 37 | ''' 38 | super(Shape_prior, self).__init__() 39 | if prior == 'smal': 40 | nbetas=20 41 | with open(config.SMAL_DATA_FILE, 'rb') as f: 42 | u = pkl._Unpickler(f) 43 | u.encoding = 'latin1' 44 | data = u.load() 45 | shape_cluster_means = data['cluster_means'][shape_family_id] 46 | betas_cov = data['cluster_cov'][shape_family_id] 47 | betas_mean = torch.from_numpy(shape_cluster_means).float().to(device) 48 | elif prior == 'unity': 49 | nbetas=26 50 | unity_data = np.load(config.UNITY_SHAPE_PRIOR) 51 | betas_cov = unity_data['cov'][:-1, :-1] 52 | betas_mean = torch.from_numpy(unity_data['mean'][:-1]).float().to(device) 53 | else: 54 | assert 
data_path is not None 55 | nbetas=26 56 | prior_data = np.load(data_path, allow_pickle=True) 57 | betas_mean = torch.from_numpy(prior_data.item()['mean']).float().to(device) 58 | betas_cov = prior_data.item()['cov'] 59 | 60 | invcov = np.linalg.inv(betas_cov + 1e-5 * np.eye(betas_cov.shape[0])) 61 | prec = np.linalg.cholesky(invcov) 62 | self.betas_prec = torch.Tensor(prec)[:nbetas, :nbetas].to(device) 63 | self.betas_mean = betas_mean[:nbetas] 64 | 65 | def __call__(self, betas_pred): 66 | diff = betas_pred - self.betas_mean.unsqueeze(0) 67 | res = torch.tensordot(diff, self.betas_prec, dims=([1], [0])) 68 | return (res**2).mean() 69 | 70 | 71 | # Laplacian loss, calculate the Laplacian coordiante of both coarse and refined vertices and then compare the difference 72 | class Laplacian(torch.nn.Module): 73 | def __init__(self, adjmat, device): 74 | ''' 75 | Args: 76 | adjmat: adjacency matrix of the input graph data 77 | device: specify device for training 78 | ''' 79 | super(Laplacian, self).__init__() 80 | adjmat.data = np.ones_like(adjmat.data) 81 | adjmat = torch.from_numpy(adjmat.todense()).float() 82 | dg = torch.sum(adjmat, dim=-1) 83 | dg_m = torch.diag(dg) 84 | ls = dg_m - adjmat 85 | self.ls = ls.unsqueeze(0).to(device) # Should be normalized by the diagonal elements according to 86 | # the origial definition, this one also works fine. 87 | 88 | def forward(self, verts_pred, verts_gt, smooth=False): 89 | verts_pred = torch.matmul(self.ls, verts_pred) 90 | verts_gt = torch.matmul(self.ls, verts_gt) 91 | loss = torch.norm(verts_pred - verts_gt, dim=-1).mean() 92 | if smooth: 93 | loss_smooth = torch.norm(torch.matmul(self.ls, verts_pred), dim=-1).mean() 94 | return loss, loss_smooth 95 | return loss, None 96 | -------------------------------------------------------------------------------- /model/graph_hg.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the Definition of GraphCNN 3 | GraphCNN includes ResNet50 as a submodule 4 | """ 5 | from __future__ import division 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from model.networks.graph_layers import GraphResBlock, GraphLinear 11 | from smal.mesh import Mesh 12 | from smal.smal_torch import SMAL 13 | 14 | # encoder-decoder structured GCN with skip connections 15 | class GraphCNN_hg(nn.Module): 16 | 17 | def __init__(self, mesh, num_channels=256, local_feat=False, num_downsample=0): 18 | ''' 19 | Args: 20 | mesh: mesh data that store the adjacency matrix 21 | num_channels: number of channels of GCN 22 | local_feat: whether use local feature for refinement 23 | num_downsample: number of downsampling of the input mesh 24 | ''' 25 | super(GraphCNN_hg, self).__init__() 26 | self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled 27 | self.num_layers = len(self.A) - 1 28 | print("Number of downsampling layer: {}".format(self.num_layers)) 29 | self.num_downsample = num_downsample 30 | if local_feat: 31 | self.lin1 = GraphLinear(3 + 2048 + 3840, 2 * num_channels) 32 | else: 33 | self.lin1 = GraphLinear(3 + 2048, 2 * num_channels) 34 | self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0]) 35 | encode_layers = [] 36 | decode_layers = [] 37 | 38 | for i in range(len(self.A)): 39 | encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i])) 40 | 41 | decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels, 42 | self.A[len(self.A) - i - 1])) 43 | current_channels = 
(i+1)*num_channels 44 | # number of channels for the input is different because of the concatenation operation 45 | self.shape = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]), 46 | GraphResBlock(64, 32, self.A[0]), 47 | nn.GroupNorm(32 // 8, 32), 48 | nn.ReLU(inplace=True), 49 | GraphLinear(32, 3)) 50 | 51 | self.encoder = nn.Sequential(*encode_layers) 52 | self.decoder = nn.Sequential(*decode_layers) 53 | self.mesh = mesh 54 | 55 | def forward(self, verts_c, img_fea_global, img_fea_multiscale=None, points_local=None): 56 | ''' 57 | Args: 58 | verts_c: vertices from the coarse estimation 59 | img_fea_global: global feature for mesh refinement 60 | img_fea_multiscale: multi-scale feature from the encoder, used for local feature extraction 61 | points_local: 2D keypoint for local feature extraction 62 | Returns: refined mesh 63 | ''' 64 | batch_size = img_fea_global.shape[0] 65 | ref_vertices = verts_c.transpose(1, 2) 66 | image_enc = img_fea_global.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1]) 67 | if points_local is not None: 68 | feat_local = torch.nn.functional.grid_sample(img_fea_multiscale, points_local) 69 | x = torch.cat([ref_vertices, image_enc, feat_local.squeeze(2)], dim=1) 70 | else: 71 | x = torch.cat([ref_vertices, image_enc], dim=1) 72 | x = self.lin1(x) 73 | x = self.res1(x) 74 | x_ = [x] 75 | for i in range(self.num_layers + 1): 76 | if i == self.num_layers: 77 | x = self.encoder[i](x) 78 | else: 79 | x = self.encoder[i](x) 80 | x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1) 81 | x = x.transpose(1, 2) 82 | if i < self.num_layers-1: 83 | x_.append(x) 84 | for i in range(self.num_layers + 1): 85 | if i == self.num_layers: 86 | x = self.decoder[i](x) 87 | else: 88 | x = self.decoder[i](x) 89 | x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample, 90 | n2=self.num_layers-i-1+self.num_downsample) 91 | x = x.transpose(1, 2) 92 | x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder 93 | 94 | shape = self.shape(x) 95 | return shape 96 | -------------------------------------------------------------------------------- /util/loss_sdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.ndimage import distance_transform_edt as distance 4 | from skimage import segmentation as skimage_seg 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def dice_loss(score, target): 9 | # implemented from paper https://arxiv.org/pdf/1606.04797.pdf 10 | target = target.float() 11 | smooth = 1e-5 12 | intersect = torch.sum(score * target) 13 | y_sum = torch.sum(target * target) 14 | z_sum = torch.sum(score * score) 15 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 16 | loss = 1 - loss 17 | return loss 18 | 19 | 20 | class tversky_loss(torch.nn.Module): 21 | # implemented from https://arxiv.org/pdf/1706.05721.pdf 22 | def __init__(self, alpha, beta): 23 | ''' 24 | Args: 25 | alpha: coefficient for false positive prediction 26 | beta: coefficient for false negtive prediction 27 | ''' 28 | super(tversky_loss, self).__init__() 29 | self.alpha = alpha 30 | self.beta = beta 31 | 32 | def __call__(self, score, target): 33 | target = target.float() 34 | smooth = 1e-5 35 | tp = torch.sum(score * target) 36 | fn = torch.sum(target * (1 - score)) 37 | fp = torch.sum((1-target) * score) 38 | loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth) 39 
| loss = 1 - loss 40 | return loss 41 | 42 | 43 | def compute_sdf1_1(img_gt, out_shape): 44 | """ 45 | compute the normalized signed distance map of binary mask 46 | input: segmentation, shape = (batch_size, x, y, z) 47 | output: the Signed Distance Map (SDM) 48 | sdf(x) = 0; x in segmentation boundary 49 | -inf|x-y|; x in segmentation 50 | +inf|x-y|; x out of segmentation 51 | normalize sdf to [-1, 1] 52 | """ 53 | 54 | img_gt = img_gt.astype(np.uint8) 55 | 56 | normalized_sdf = np.zeros(out_shape) 57 | 58 | for b in range(out_shape[0]): # batch size 59 | # ignore background 60 | for c in range(1, out_shape[1]): 61 | posmask = img_gt[b] 62 | negmask = 1-posmask 63 | posdis = distance(posmask) 64 | negdis = distance(negmask) 65 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) 66 | sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis)) 67 | sdf[boundary==1] = 0 68 | normalized_sdf[b][c] = sdf 69 | assert np.min(sdf) == -1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) 70 | assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) 71 | 72 | return normalized_sdf 73 | 74 | 75 | def compute_sdf(img_gt, out_shape): 76 | """ 77 | compute the signed distance map of binary mask 78 | input: segmentation, shape = (batch_size, x, y, z) 79 | output: the Signed Distance Map (SDM) 80 | sdf(x) = 0; x in segmentation boundary 81 | -inf|x-y|; x in segmentation 82 | +inf|x-y|; x out of segmentation 83 | """ 84 | 85 | img_gt = img_gt.astype(np.uint8) 86 | 87 | gt_sdf = np.zeros(out_shape) 88 | debug = False 89 | for b in range(out_shape[0]): # batch size 90 | for c in range(0, out_shape[1]): 91 | posmask = img_gt[b] 92 | negmask = 1-posmask 93 | posdis = distance(posmask) 94 | negdis = distance(negmask) 95 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) 96 | sdf = negdis - posdis 97 | sdf[boundary==1] = 0 98 | gt_sdf[b][c] = sdf 99 | if debug: 100 | plt.figure() 101 | plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar() 102 | plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar() 103 | plt.show() 104 | 105 | return gt_sdf 106 | 107 | 108 | def boundary_loss(output, gt): 109 | """ 110 | compute boundary loss for binary segmentation 111 | input: outputs_soft: softmax results, shape=(b,2,x,y,z) 112 | gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y,z) 113 | output: boundary_loss; sclar 114 | adopted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf 115 | """ 116 | multipled = torch.einsum('bcxy, bcxy->bcxy', output, gt) 117 | bd_loss = multipled.mean() 118 | 119 | return bd_loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coarse-to-fine-3D-Animal 2 | 3 | **About** 4 | 5 | This is the source code for our paper 6 | 7 | Chen Li, Gim Hee Lee. Coarse-to-fine Animal Pose and Shape Estimation. In Neurips 2021. 8 | 9 | The shape space of the SMAL model is learned from 41 scans of toy animals and thus lacks pose ans shape variations. This may limit the representation capacity of the SMAL model and result in poor fittings of the estimated shapes to the 2D observations, as shown in the second to fourth columns of the Figure below. 
To mitigate this problem, we propose a coarse-to-fine approach, which combines model-based and model-free representations. Some refined results are shown in the fifth to seventh columns. 10 | 11 |

12 | 13 |

14 | 15 | Our network consists of a coarse estimation stage and a mesh refinement stage. The SMAL model parameters and camera parameters are regressed from the input image in the first stage for coarse estimation. This coarse estimation is further refined by an encoder-decoder structured GCN in the second stage. 16 | 17 |

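At inference time the two stages interact as in the sketch below. This is a simplified, hypothetical wrapper around the components constructed in model/mesh_graph_hg.py (MeshNet_img, SMAL, Mesh and GraphCNN_hg are assumed to be built as in MeshGraph_hg.__init__); the optional local-feature branch is omitted here.
```
def coarse_to_fine_forward(img, meshnet, smal, mesh, graphnet):
    """Illustrative two-stage forward pass (see MeshGraph_hg.forward for the real one)."""
    # Stage 1: regress SMAL shape/pose and camera parameters from the image,
    # then decode them into a coarse mesh with the differentiable SMAL layer.
    pred_codes, enc_feat, _ = meshnet(img)
    scale, trans, pose, betas, betas_scale = pred_codes
    verts_coarse, joints, _, _ = smal(betas, pose, trans=trans,
                                      betas_logscale=betas_scale)

    # Stage 2: refine a downsampled copy of the coarse mesh with the
    # encoder-decoder GCN, conditioned on the global image feature, and
    # upsample the refinement back to the full SMAL vertex resolution.
    verts_down = mesh.downsample(verts_coarse)
    verts_refined = graphnet(verts_down, enc_feat.detach().clone())
    verts_refined = mesh.upsample(verts_refined.transpose(1, 2))
    return verts_coarse, joints, verts_refined
```
The refinement operates on a once-downsampled mesh (num_downsampling in MeshGraph_hg) to keep GCN memory manageable, and the refined vertices are upsampled back to the full SMAL resolution.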
18 | 19 |

 20 | 21 | For more details, please refer to [our paper](https://arxiv.org/pdf/2111.08176.pdf). 22 | 23 | **Bibtex** 24 | ``` 25 | @article{li2021coarse, 26 | title={Coarse-to-fine animal pose and shape estimation}, 27 | author={Li, Chen and Lee, Gim Hee}, 28 | journal={Advances in Neural Information Processing Systems}, 29 | volume={34}, 30 | pages={11757--11768}, 31 | year={2021} 32 | } 33 | ``` 34 | 35 | **Dependencies** 36 | 1. Python 3.7.10 37 | 2. PyTorch 1.5.0 38 | 39 | Please refer to requirements.txt for more details on dependencies. 40 | 41 | **Download datasets** 42 | * Download the [StanfordExtra dataset](https://github.com/benjiebob/StanfordExtra) and put it under the folder ./data/. 43 | 44 | * Download the [Animal Pose dataset](https://sites.google.com/view/animal-pose/) and the test split from [WLDO](https://github.com/benjiebob/WLDO/tree/master/data/animal_pose) and put them under the folder ./data/. 45 | 46 | **Download SMAL and priors** 47 | * Download the [SMAL](https://github.com/benjiebob/WLDO/tree/master/data) template and put the downloaded smal folder under ./data/. 48 | * Download the [pose prior data](https://github.com/benjiebob/SMALify/tree/master/data) and put the downloaded priors folder under ./data/. 49 | 50 | **Train** 51 | 52 | We provide the [pretrained model](https://drive.google.com/file/d/1mvr7iYkyKVUxPdFExE0HOsrVNl_sc1O1/view?usp=sharing) for each stage. You can download the pretrained models and put them under the folder ./logs/pretrained_models/. To save training time, you can directly train from stage 2 using our pretrained model for stage 1 ('stage1.pth.tar') by running: 53 | ``` 54 | python main_meshgraph.py --output_dir logs/stage2 --nEpochs 10 --local_feat --batch_size 32 --freezecoarse --gpu_ids 0 --pretrained logs/pretrained_models/stage1.pth.tar 55 | ``` 56 | Then you can continue to train stage 3 by running: 57 | ``` 58 | python main_meshgraph.py --output_dir logs/stage3 --nEpochs 200 --lr 1e-5 --local_feat --w_arap 10000 --w_dice 1000 --w_dice_refine 100 --w_pose_limit_prior 5000 --resume logs/pretrained_models/stage2.pth.tar --gpu_ids 0 59 | ``` 60 | Note that you will need to change the '--resume' to the path of your own model if you want to use your own model from stage 2. 61 | 62 | Alternatively, you can also train from scratch. In this case, you will need to pretrain the coarse estimation part first by running: 63 | ``` 64 | python main.py --batch_size 32 --output_dir logs/stage1 --gpu_ids 0 65 | ``` 66 | Then you can continue to train stage 2 and stage 3 as we have explained. Note that you will need to change the '--pretrained' in stage 2 and '--resume' in stage 3 to the path of your own model. 67 | 68 | **Test** 69 | 70 | Test our model on the StanfordExtra dataset by running: 71 | ``` 72 | python eval.py --output_dir logs/test --resume logs/pretrained_models/stage3.pth.tar --gpu_ids 0 --local_feat 73 | ``` 74 | or on the Animal Pose dataset by running: 75 | ``` 76 | python eval.py --output_dir logs/test --resume logs/pretrained_models/stage3.pth.tar --gpu_ids 0 --local_feat --dataset animal_pose 77 | ``` 78 | Qualitative results can be generated and saved by adding '--save_results' to the command. 79 | 80 | **Acknowledgements** 81 | 82 | The code for the coarse estimation stage is adapted from [WLDO](https://github.com/benjiebob/WLDO/tree/master/data/animal_pose). 
If you use the coarse estimation pipeline, please cite: 83 | ``` 84 | @inproceedings{biggs2020wldo, 85 | title={{W}ho left the dogs out?: {3D} animal reconstruction with expectation maximization in the loop}, 86 | author={Biggs, Benjamin and Boyne, Oliver and Charles, James and Fitzgibbon, Andrew and Cipolla, Roberto}, 87 | booktitle={ECCV}, 88 | year={2020} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /smal/mesh.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import numpy as np 4 | import scipy.sparse 5 | 6 | from model.networks.graph_layers import spmm 7 | 8 | 9 | def scipy_to_pytorch(A, U, D): 10 | """Convert scipy sparse matrices to pytorch sparse matrix.""" 11 | ptU = [] 12 | ptD = [] 13 | 14 | for i in range(len(U)): 15 | u = scipy.sparse.coo_matrix(U[i]) 16 | i = torch.LongTensor(np.array([u.row, u.col])) 17 | v = torch.FloatTensor(u.data) 18 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) 19 | 20 | for i in range(len(D)): 21 | d = scipy.sparse.coo_matrix(D[i]) 22 | i = torch.LongTensor(np.array([d.row, d.col])) 23 | v = torch.FloatTensor(d.data) 24 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) 25 | 26 | return ptU, ptD 27 | 28 | 29 | def adjmat_sparse(adjmat, nsize=1): 30 | """Create row-normalized sparse graph adjacency matrix.""" 31 | adjmat = scipy.sparse.csr_matrix(adjmat) 32 | if nsize > 1: 33 | orig_adjmat = adjmat.copy() 34 | for _ in range(1, nsize): 35 | adjmat = adjmat * orig_adjmat 36 | adjmat.data = np.ones_like(adjmat.data) 37 | for i in range(adjmat.shape[0]): 38 | adjmat[i, i] = 1 39 | num_neighbors = np.array(1 / adjmat.sum(axis=-1)) 40 | adjmat = adjmat.multiply(num_neighbors) 41 | adjmat = scipy.sparse.coo_matrix(adjmat) 42 | row = adjmat.row 43 | col = adjmat.col 44 | data = adjmat.data 45 | i = torch.LongTensor(np.array([row, col])) 46 | v = torch.from_numpy(data).float() 47 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) 48 | return adjmat 49 | 50 | 51 | def get_graph_params(filename, nsize=1): 52 | """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" 53 | data = np.load(filename, encoding='latin1', allow_pickle=True) 54 | A = data['A'] 55 | U = data['U'] 56 | D = data['D'] 57 | U, D = scipy_to_pytorch(A, U, D) 58 | A = [adjmat_sparse(a, nsize=nsize) for a in A] 59 | return A, U, D 60 | 61 | 62 | class Mesh(object): 63 | """Mesh object that is used for handling certain graph operations.""" 64 | 65 | def __init__(self, smal, filename='./data/mesh_down_sampling.npz', 66 | num_downsampling=1, nsize=1, device=torch.device('cuda')): 67 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) 68 | self._A = [a.to(device) for a in self._A] 69 | self._U = [u.to(device) for u in self._U] 70 | self._D = [d.to(device) for d in self._D] 71 | self.num_downsampling = num_downsampling 72 | 73 | # load template vertices from SMPL and normalize them 74 | ref_vertices = smal.v_template.clone() 75 | center = 0.5 * (ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] 76 | ref_vertices -= center 77 | ref_vertices /= ref_vertices.abs().max().item() 78 | 79 | self._ref_vertices = ref_vertices.to(device) 80 | self.faces = smal.faces.int() 81 | 82 | @property 83 | def adjmat(self): 84 | """Return the graph adjacency matrix at the specified subsampling level.""" 85 | return self._A[self.num_downsampling].float() 86 | 87 | @property 88 | def 
ref_vertices(self): 89 | """Return the template vertices at the specified subsampling level.""" 90 | ref_vertices = self._ref_vertices 91 | for i in range(self.num_downsampling): 92 | ref_vertices = torch.spmm(self._D[i], ref_vertices) 93 | return ref_vertices 94 | 95 | def downsample(self, x, n1=0, n2=None): 96 | """Downsample mesh.""" 97 | if n2 is None: 98 | n2 = self.num_downsampling 99 | if x.ndimension() < 3: 100 | for i in range(n1, n2): 101 | x = spmm(self._D[i], x) 102 | elif x.ndimension() == 3: 103 | out = [] 104 | for i in range(x.shape[0]): 105 | y = x[i] 106 | for j in range(n1, n2): 107 | y = spmm(self._D[j], y) 108 | out.append(y) 109 | x = torch.stack(out, dim=0) 110 | return x 111 | 112 | def upsample(self, x, n1=None, n2=0): 113 | """Upsample mesh.""" 114 | if n1 is None: 115 | n1 = self.num_downsampling 116 | if x.ndimension() < 3: 117 | for i in reversed(range(n2, n1)): 118 | x = spmm(self._U[i], x) 119 | elif x.ndimension() == 3: 120 | out = [] 121 | for i in range(x.shape[0]): 122 | y = x[i] 123 | for j in reversed(range(n2, n1)): 124 | y = spmm(self._U[j], y) 125 | out.append(y) 126 | x = torch.stack(out, dim=0) 127 | return x 128 | -------------------------------------------------------------------------------- /model/mesh_graph_hg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mesh net model. 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch 11 | from smal.smal_torch import SMAL 12 | from util import config 13 | from model.graph_hg import GraphCNN_hg 14 | from model.smal_mesh_net_img import MeshNet_img 15 | from util.nmr import NeuralRenderer 16 | from smal.mesh import Mesh 17 | import cv2 18 | # ------------- Modules ------------# 19 | # ----------------------------------# 20 | 21 | unity_shape_prior = np.load('data/priors/unity_betas.npz') 22 | 23 | 24 | class MeshGraph_hg(nn.Module): 25 | def __init__(self, device, shape_family_id, number_channels, num_layers, betas_scale=False, shape_init=None 26 | , local_feat=False, num_downsampling=0, render_rgb=False): 27 | ''' 28 | 29 | Args: 30 | device: specify device for training 31 | shape_family_id: specify animal category id 32 | number_channels: specify number of channels for GCN 33 | betas_scale: whether predict additional shape parameters proposed by WLDO 34 | shape_init: whether initiate the bias weights for the coarse stage as mean shape 35 | local_feat: whether use local feature for refinement step 36 | num_downsampling: number of donwsamplings before input to GCN. 37 | We downsample the original mesh once before going through GCN to save memory 38 | render_rgb: wehther render the 3D mesh onto 2D to get RGB image. Only set to true when generating 39 | visualization to save inference time. 
40 | ''' 41 | super(MeshGraph_hg, self).__init__() 42 | 43 | self.model_renderer = NeuralRenderer(config.IMG_RES, proj_type=config.PROJECTION, 44 | norm_f0=config.NORM_F0, 45 | norm_f=config.NORM_F, 46 | norm_z=config.NORM_Z, render_rgb=render_rgb, device=device) 47 | self.model_renderer.directional_light_only() 48 | self.smal = SMAL(device, shape_family_id=shape_family_id) 49 | self.local_feat = local_feat 50 | if shape_init == 'smal': 51 | print("Initiate shape with smal prior") 52 | shape_init = self.smal.shape_cluster_means 53 | elif shape_init == 'unity': 54 | print("Initiate shape with unity prior ") 55 | shape_init = unity_shape_prior['mean'][:-1] 56 | shape_init = torch.from_numpy(shape_init).float().to(device) 57 | else: 58 | print("No initialization for shape") 59 | shape_init = None 60 | 61 | input_size = [config.IMG_RES, config.IMG_RES] 62 | self.meshnet = MeshNet_img(input_size, betas_scale=betas_scale, norm_f0=config.NORM_F0, nz_feat=config.NZ_FEAT, 63 | shape_init=shape_init, return_feat=True) 64 | self.mesh = Mesh(self.smal, num_downsampling=num_downsampling, filename='./data/mesh_down_sampling_4.npz', 65 | device=device) 66 | self.graphnet = GraphCNN_hg(self.mesh, num_channels=number_channels, 67 | local_feat=local_feat, num_downsample=num_downsampling).to(device) 68 | 69 | def forward(self, img): 70 | pred_codes, enc_feat, feat_multiscale = self.meshnet(img) 71 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 72 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(scale_pred.shape[0], 2).cuda() * config.IMG_RES / 2], 73 | dim=1) 74 | verts, joints, _, _ = self.smal(betas_pred, pose_pred, trans=trans_pred, 75 | betas_logscale=betas_scale_pred) 76 | enc_feat_copy = enc_feat.detach().clone() 77 | verts_d = self.mesh.downsample(verts) 78 | verts_copy = verts_d.detach().clone() 79 | if self.local_feat: 80 | feat_multiscale_copy = feat_multiscale.detach().clone() 81 | points_img = self.model_renderer.project_points(verts_copy, pred_camera, normalize_kpts=True) 82 | # used normalized image coordinate for the requirement of torch.nn.functional.grid_sample 83 | verts_re = self.graphnet(verts_d, enc_feat_copy, feat_multiscale_copy, points_img.unsqueeze(1)) 84 | else: 85 | verts_re = self.graphnet(verts_d, enc_feat_copy) 86 | verts_re = self.mesh.upsample(verts_re.transpose(1, 2)) 87 | return verts, joints, verts_re, pred_codes 88 | 89 | 90 | def init_pretrained(model, checkpoint): 91 | pretrained_dict = checkpoint['state_dict'] 92 | model_dict = model.state_dict() 93 | model_dict.update(pretrained_dict) 94 | model.load_state_dict(model_dict) 95 | print("Init MeshNet") 96 | 97 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 11 | 12 | def savefig(fname, dpi=None): 13 | dpi = 150 if dpi == None else dpi 14 | plt.savefig(fname, dpi=dpi) 15 | 16 | def plot_overlap(logger, names=None): 17 | names = logger.names if names == None else names 18 | numbers = logger.numbers 19 | for _, name in enumerate(names): 20 | x = np.arange(len(numbers[name])) 21 | plt.plot(x, np.asarray(numbers[name])) 22 | return [logger.title + '(' + name + ')' for name in names] 23 | 24 | class 
Logger(object): 25 | '''Save training process to log file with simple plot function.''' 26 | def __init__(self, fpath, title=None, resume=False): 27 | self.file = None 28 | self.resume = resume 29 | self.title = '' if title == None else title 30 | if fpath is not None: 31 | if resume: 32 | self.file = open(fpath, 'r') 33 | name = self.file.readline() 34 | self.names = name.rstrip().split('\t') 35 | self.numbers = {} 36 | for _, name in enumerate(self.names): 37 | self.numbers[name] = [] 38 | 39 | for numbers in self.file: 40 | numbers = numbers.rstrip().split('\t') 41 | for i in range(0, len(numbers)): 42 | self.numbers[self.names[i]].append(numbers[i]) 43 | self.file.close() 44 | self.file = open(fpath, 'a') 45 | else: 46 | self.file = open(fpath, 'w') 47 | 48 | def set_names(self, names): 49 | if self.resume: 50 | pass 51 | # initialize numbers as empty list 52 | self.numbers = {} 53 | self.names = names 54 | for _, name in enumerate(self.names): 55 | self.file.write(name) 56 | self.file.write('\t') 57 | self.numbers[name] = [] 58 | self.file.write('\n') 59 | self.file.flush() 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def log_arguments(self, args): 71 | 72 | self.file.write('Command:{}'.format(sys.argv)) 73 | self.file.write('\n') 74 | s = '\n'.join(['{}: {}'.format(arg, getattr(args, arg)) for arg in vars(args)]) 75 | s = 'Arguments:\n' + s 76 | self.file.write(s) 77 | self.file.write('\n') 78 | self.file.flush() 79 | 80 | def plot(self, names=None): 81 | names = self.names if names == None else names 82 | numbers = self.numbers 83 | for _, name in enumerate(names): 84 | x = np.arange(len(numbers[name])) 85 | plt.plot(x, np.asarray(numbers[name])) 86 | plt.legend([self.title + '(' + name + ')' for name in names]) 87 | plt.grid(True) 88 | plt.gca().invert_yaxis() 89 | plt.show() 90 | 91 | def close(self): 92 | if self.file is not None: 93 | self.file.close() 94 | 95 | 96 | class LoggerMonitor(object): 97 | '''Load and visualize multiple logs.''' 98 | def __init__ (self, paths): 99 | '''paths is a distionary with {name:filepath} pair''' 100 | self.loggers = [] 101 | for title, path in paths.items(): 102 | logger = Logger(path, title=title, resume=True) 103 | self.loggers.append(logger) 104 | 105 | def plot(self, names=None): 106 | plt.figure() 107 | plt.subplot(121) 108 | legend_text = [] 109 | for logger in self.loggers: 110 | legend_text += plot_overlap(logger, names) 111 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
112 | plt.grid(True) 113 | 114 | if __name__ == '__main__': 115 | # # Example 116 | # logger = Logger('test.txt') 117 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 118 | 119 | # length = 100 120 | # t = np.arange(length) 121 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 122 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 123 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 124 | 125 | # for i in range(0, length): 126 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 127 | # logger.plot() 128 | 129 | # Example: logger monitor 130 | # paths = { 131 | # 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 132 | # 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 133 | # 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 134 | # } 135 | # 136 | # field = ['Valid Acc.'] 137 | # 138 | # monitor = LoggerMonitor(paths) 139 | # monitor.plot(names=field) 140 | # savefig('test.eps') 141 | logger = Logger(fpath='/media/haleh/Harddisk1/checkpoint/horse/syn2real_grl/log.txt', resume=True) 142 | logger.plot(names=['LR']) -------------------------------------------------------------------------------- /model/networks/graph_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of layers used to build the GraphCNN 3 | """ 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import math 10 | # from torch_geometric.nn.conv import MessagePassing 11 | 12 | class GraphConvolution(nn.Module): 13 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" 14 | def __init__(self, in_features, out_features, adjmat, bias=True): 15 | super(GraphConvolution, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.adjmat = adjmat 19 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 20 | if bias: 21 | self.bias = nn.Parameter(torch.FloatTensor(out_features)) 22 | else: 23 | self.register_parameter('bias', None) 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | # stdv = 1. / math.sqrt(self.weight.size(1)) 28 | stdv = 6. 
/ math.sqrt(self.weight.size(0) + self.weight.size(1)) 29 | self.weight.data.uniform_(-stdv, stdv) 30 | if self.bias is not None: 31 | self.bias.data.uniform_(-stdv, stdv) 32 | 33 | def forward(self, x): 34 | if x.ndimension() == 2: 35 | support = torch.matmul(x, self.weight) 36 | output = torch.matmul(self.adjmat, support) 37 | if self.bias is not None: 38 | output = output + self.bias 39 | return output 40 | else: 41 | output = [] 42 | for i in range(x.shape[0]): 43 | support = torch.matmul(x[i], self.weight) 44 | # output.append(torch.matmul(self.adjmat, support)) 45 | output.append(spmm(self.adjmat, support)) 46 | output = torch.stack(output, dim=0) 47 | if self.bias is not None: 48 | output = output + self.bias 49 | return output 50 | 51 | def __repr__(self): 52 | return self.__class__.__name__ + ' (' \ 53 | + str(self.in_features) + ' -> ' \ 54 | + str(self.out_features) + ')' 55 | 56 | 57 | class GraphLinear(nn.Module): 58 | """ 59 | Generalization of 1x1 convolutions on Graphs 60 | """ 61 | def __init__(self, in_channels, out_channels): 62 | super(GraphLinear, self).__init__() 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels)) 66 | self.b = nn.Parameter(torch.FloatTensor(out_channels)) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | w_stdv = 1 / (self.in_channels * self.out_channels) 71 | self.W.data.uniform_(-w_stdv, w_stdv) 72 | self.b.data.uniform_(-w_stdv, w_stdv) 73 | 74 | def forward(self, x): 75 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None] 76 | 77 | 78 | class GraphResBlock(nn.Module): 79 | """ 80 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet 81 | """ 82 | 83 | def __init__(self, in_channels, out_channels, A): 84 | super(GraphResBlock, self).__init__() 85 | self.in_channels = in_channels 86 | self.out_channels = out_channels 87 | self.lin1 = GraphLinear(in_channels, out_channels // 2) 88 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A) 89 | self.lin2 = GraphLinear(out_channels // 2, out_channels) 90 | self.skip_conv = GraphLinear(in_channels, out_channels) 91 | self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels) 92 | self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) 93 | self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) 94 | 95 | def forward(self, x): 96 | y = F.relu(self.pre_norm(x)) 97 | y = self.lin1(y) 98 | 99 | y = F.relu(self.norm1(y)) 100 | y = self.conv(y.transpose(1,2)).transpose(1,2) 101 | 102 | y = F.relu(self.norm2(y)) 103 | y = self.lin2(y) 104 | if self.in_channels != self.out_channels: 105 | x = self.skip_conv(x) 106 | return x+y 107 | 108 | 109 | class SparseMM(torch.autograd.Function): 110 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 111 | The builtin matrix multiplication operation does not support backpropagation in some cases. 
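    Here the sparse operand is treated as a constant: backward returns a gradient only for the dense input, computed as sparse.t() @ grad_output.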
112 | """ 113 | @staticmethod 114 | def forward(ctx, sparse, dense): 115 | ctx.req_grad = dense.requires_grad 116 | ctx.save_for_backward(sparse) 117 | return torch.matmul(sparse, dense) 118 | 119 | @staticmethod 120 | def backward(ctx, grad_output): 121 | grad_input = None 122 | sparse, = ctx.saved_tensors 123 | if ctx.req_grad: 124 | grad_input = torch.matmul(sparse.t(), grad_output) 125 | return None, grad_input 126 | 127 | 128 | def spmm(sparse, dense): 129 | return SparseMM.apply(sparse, dense) 130 | 131 | 132 | # class Pool(MessagePassing): 133 | # def __init__(self, pool_mat): 134 | # super(Pool, self).__init__(flow='target_to_source') 135 | # self.pool_mat = pool_mat 136 | # 137 | # def forward(self, x): 138 | # x = x.transpose(0, 1) 139 | # out = self.propagate(edge_index=self.pool_mat._indices(), x=x, norm=self.pool_mat._values(), 140 | # size=self.pool_mat.size()) 141 | # return out.transpose(0, 1) 142 | # 143 | # def message(self, x_j, norm): 144 | # return norm.view(-1, 1, 1) * x_j -------------------------------------------------------------------------------- /util/helpers/conversions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchgeometry as tgm 5 | 6 | def rotation_matrix_to_angle_axis(rotation_matrix): 7 | """Convert 3x4 rotation matrix to Rodrigues vector 8 | 9 | Args: 10 | rotation_matrix (Tensor): rotation matrix. 11 | 12 | Returns: 13 | Tensor: Rodrigues vector transformation. 14 | 15 | Shape: 16 | - Input: :math:`(N, 3, 4)` 17 | - Output: :math:`(N, 3)` 18 | 19 | Example: 20 | >>> input = torch.rand(2, 3, 4) # Nx4x4 21 | >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 22 | """ 23 | # todo add check that matrix is a valid rotation matrix 24 | quaternion = rotation_matrix_to_quaternion(rotation_matrix) 25 | return quaternion_to_angle_axis(quaternion) 26 | 27 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): 28 | """Convert 3x4 rotation matrix to 4d quaternion vector 29 | 30 | This algorithm is based on algorithm described in 31 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 32 | 33 | Args: 34 | rotation_matrix (Tensor): the rotation matrix to convert. 35 | 36 | Return: 37 | Tensor: the rotation in quaternion 38 | 39 | Shape: 40 | - Input: :math:`(N, 3, 4)` 41 | - Output: :math:`(N, 4)` 42 | 43 | Example: 44 | >>> input = torch.rand(4, 3, 4) # Nx3x4 45 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 46 | """ 47 | if not torch.is_tensor(rotation_matrix): 48 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 49 | type(rotation_matrix))) 50 | 51 | if len(rotation_matrix.shape) > 3: 52 | raise ValueError( 53 | "Input size must be a three dimensional tensor. Got {}".format( 54 | rotation_matrix.shape)) 55 | if not rotation_matrix.shape[-2:] == (3, 4): 56 | raise ValueError( 57 | "Input size must be a N x 3 x 4 tensor. 
Got {}".format( 58 | rotation_matrix.shape)) 59 | 60 | rmat_t = torch.transpose(rotation_matrix, 1, 2) 61 | 62 | mask_d2 = rmat_t[:, 2, 2] < eps 63 | 64 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] 65 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] 66 | 67 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 68 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 69 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 70 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) 71 | t0_rep = t0.repeat(4, 1).t() 72 | 73 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 74 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 75 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 76 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) 77 | t1_rep = t1.repeat(4, 1).t() 78 | 79 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 80 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], 81 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2], 82 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) 83 | t2_rep = t2.repeat(4, 1).t() 84 | 85 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 86 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 87 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 88 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) 89 | t3_rep = t3.repeat(4, 1).t() 90 | 91 | mask_c0 = mask_d2 * mask_d0_d1 92 | # mask_c1 = mask_d2 * (1 - mask_d0_d1) 93 | mask_c1 = mask_d2 * (~mask_d0_d1) 94 | # mask_c2 = (1 - mask_d2) * mask_d0_nd1 95 | mask_c2 = (~mask_d2) * mask_d0_nd1 96 | # mask_c3 = (1 - mask_d2) * (1 - mask_d0_nd1) 97 | mask_c3 = (~mask_d2) * (~mask_d0_nd1) 98 | mask_c0 = mask_c0.view(-1, 1).type_as(q0) 99 | mask_c1 = mask_c1.view(-1, 1).type_as(q1) 100 | mask_c2 = mask_c2.view(-1, 1).type_as(q2) 101 | mask_c3 = mask_c3.view(-1, 1).type_as(q3) 102 | 103 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 104 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa 105 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa 106 | q *= 0.5 107 | return q 108 | 109 | def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: 110 | """Convert quaternion vector to angle axis of rotation. 111 | 112 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 113 | 114 | Args: 115 | quaternion (torch.Tensor): tensor with quaternions. 116 | 117 | Return: 118 | torch.Tensor: tensor with angle axis of rotation. 119 | 120 | Shape: 121 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 122 | - Output: :math:`(*, 3)` 123 | 124 | Example: 125 | >>> quaternion = torch.rand(2, 4) # Nx4 126 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 127 | """ 128 | if not torch.is_tensor(quaternion): 129 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 130 | type(quaternion))) 131 | 132 | if not quaternion.shape[-1] == 4: 133 | raise ValueError("Input must be a tensor of shape Nx4 or 4. 
Got {}" 134 | .format(quaternion.shape)) 135 | # unpack input and compute conversion 136 | q1: torch.Tensor = quaternion[..., 1] 137 | q2: torch.Tensor = quaternion[..., 2] 138 | q3: torch.Tensor = quaternion[..., 3] 139 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 140 | 141 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) 142 | cos_theta: torch.Tensor = quaternion[..., 0] 143 | two_theta: torch.Tensor = 2.0 * torch.where( 144 | cos_theta < 0.0, 145 | torch.atan2(-sin_theta, -cos_theta), 146 | torch.atan2(sin_theta, cos_theta)) 147 | 148 | k_pos: torch.Tensor = two_theta / sin_theta 149 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) 150 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 151 | 152 | angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] 153 | angle_axis[..., 0] += q1 * k 154 | angle_axis[..., 1] += q2 * k 155 | angle_axis[..., 2] += q3 * k 156 | return angle_axis -------------------------------------------------------------------------------- /smal/batch_lbs.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def batch_skew(vec, batch_size=None, opts=None): 10 | """ 11 | vec is N x 3, batch_size is int 12 | 13 | returns N x 3 x 3. Skew_sym version of each matrix. 14 | """ 15 | if batch_size is None: 16 | batch_size = vec.shape.as_list()[0] 17 | col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) 18 | indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) 19 | updates = torch.reshape( 20 | torch.stack( 21 | [ 22 | -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], 23 | vec[:, 0] 24 | ], 25 | dim=1), [-1]) 26 | out_shape = [batch_size * 9] 27 | res = torch.Tensor(np.zeros(out_shape[0])).cuda(device=vec.device) 28 | res[np.array(indices.flatten())] = updates 29 | res = torch.reshape(res, [batch_size, 3, 3]) 30 | 31 | return res 32 | 33 | def batch_rodrigues(theta, opts=None): 34 | """ 35 | Theta is Nx3 36 | """ 37 | batch_size = theta.shape[0] 38 | 39 | angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) 40 | r = (torch.div(theta, angle)).unsqueeze(-1) 41 | 42 | angle = angle.unsqueeze(-1) 43 | cos = torch.cos(angle) 44 | sin = torch.sin(angle) 45 | 46 | outer = torch.matmul(r, r.transpose(1,2)) 47 | 48 | eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).cuda(device=theta.device) 49 | H = batch_skew(r, batch_size=batch_size, opts=opts) 50 | R = cos * eyes + (1 - cos) * outer + sin * H 51 | 52 | return R 53 | 54 | def batch_lrotmin(theta): 55 | """ 56 | Output of this is used to compute joint-to-pose blend shape mapping. 57 | Equation 9 in SMPL paper. 58 | 59 | 60 | Args: 61 | pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. 62 | This includes the global rotation so K=24 63 | 64 | Returns 65 | diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., 66 | """ 67 | # Ignore global rotation 68 | theta = theta[:,3:] 69 | 70 | Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) 71 | lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) 72 | 73 | return lrotmin 74 | 75 | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False, betas_logscale=None, opts=None): 76 | """ 77 | Computes absolute joint locations given pose. 
78 | 79 | rotate_base: if True, rotates the global rotation by 90 deg in x axis. 80 | if False, this is the original SMPL coordinate. 81 | 82 | Args: 83 | Rs: N x 24 x 3 x 3 rotation vector of K joints 84 | Js: N x 24 x 3, joint locations before posing 85 | parent: 24 holding the parent id for each index 86 | 87 | Returns 88 | new_J : `Tensor`: N x 24 x 3 location of absolute joints 89 | A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. 90 | """ 91 | if rotate_base: 92 | print('Flipping the SMPL coordinate frame!!!!') 93 | rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) 94 | rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile 95 | root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) 96 | else: 97 | root_rotation = Rs[:, 0, :, :] 98 | 99 | # Now Js is N x 24 x 3 x 1 100 | Js = Js.unsqueeze(-1) 101 | N = Rs.shape[0] 102 | 103 | Js_orig = Js.clone() 104 | 105 | scaling_factors = torch.ones(N, parent.shape[0], 3).to(Rs.device) 106 | if betas_logscale is not None: 107 | leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) 108 | tail_joints = list(range(25, 32)) 109 | ear_joints = [33, 34] 110 | 111 | beta_scale_mask = torch.zeros(35, 3, 6).to(betas_logscale.device) 112 | beta_scale_mask[leg_joints, [2], [0]] = 1.0 # Leg lengthening 113 | beta_scale_mask[leg_joints, [0], [1]] = 1.0 # Leg fatness 114 | beta_scale_mask[leg_joints, [1], [1]] = 1.0 # Leg fatness 115 | 116 | beta_scale_mask[tail_joints, [0], [2]] = 1.0 # Tail lengthening 117 | beta_scale_mask[tail_joints, [1], [3]] = 1.0 # Tail fatness 118 | beta_scale_mask[tail_joints, [2], [3]] = 1.0 # Tail fatness 119 | 120 | beta_scale_mask[ear_joints, [1], [4]] = 1.0 # Ear y 121 | beta_scale_mask[ear_joints, [2], [5]] = 1.0 # Ear z 122 | 123 | beta_scale_mask = torch.transpose( 124 | beta_scale_mask.reshape(35*3, 6), 0, 1) 125 | 126 | betas_scale = torch.exp(betas_logscale @ beta_scale_mask) 127 | scaling_factors = betas_scale.reshape(-1, 35, 3) 128 | 129 | scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1) 130 | 131 | def make_A(R, t): 132 | # Rs is N x 3 x 3, ts is N x 3 x 1 133 | R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) 134 | t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1) 135 | return torch.cat([R_homo, t_homo], 2) 136 | 137 | A0 = make_A(root_rotation, Js[:, 0]) 138 | results = [A0] 139 | for i in range(1, parent.shape[0]): 140 | j_here = Js[:, i] - Js[:, parent[i]] 141 | 142 | s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]]) 143 | rot = Rs[:, i] 144 | s = scale_factors_3x3[:, i] 145 | 146 | rot_new = s_par_inv @ rot @ s 147 | 148 | A_here = make_A(rot_new, j_here) 149 | res_here = torch.matmul( 150 | results[parent[i]], A_here) 151 | 152 | results.append(res_here) 153 | 154 | # 10 x 24 x 4 x 4 155 | results = torch.stack(results, dim=1) 156 | 157 | # scale updates 158 | new_J = results[:, :, :3, 3] 159 | 160 | # --- Compute relative A: Skinning is based on 161 | # how much the bone moved (not the final location of the bone) 162 | # but (final_bone - init_bone) 163 | # --- 164 | Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2) 165 | init_bone = torch.matmul(results, Js_w0) 166 | # Append empty 4 x 3: 167 | init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) 168 | A = results - init_bone 169 | 170 | return new_J, A 171 | -------------------------------------------------------------------------------- /util/joint_limits_prior.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | Ranges = { 4 | 'pelvis': [[0, 0], [0, 0], [0, 0]], 5 | 'pelvis0': [[-0.3, 0.3], [-1.2, 0.5], [-0.1, 0.1]], 6 | 'spine': [[-0.4, 0.4], [-1.0, 0.9], [-0.8, 0.8]], 7 | 'spine0': [[-0.4, 0.4], [-1.0, 0.9], [-0.8, 0.8]], 8 | 'spine1': [[-0.4, 0.4], [-0.5, 1.2], [-0.4, 0.4]], 9 | 'spine3': [[-0.5, 0.5], [-0.6, 1.4], [-0.8, 0.8]], 10 | 'spine2': [[-0.5, 0.5], [-0.4, 1.4], [-0.5, 0.5]], 11 | 'RFootBack': [[-0.2, 0.3], [-0.3, 1.1], [-0.3, 0.5]], 12 | 'LFootBack': [[-0.3, 0.2], [-0.3, 1.1], [-0.5, 0.3]], 13 | 'LLegBack1': [[-0.2, 0.3], [-0.5, 0.8], [-0.5, 0.4]], 14 | 'RLegBack1': [[-0.3, 0.2], [-0.5, 0.8], [-0.4, 0.5]], 15 | 'Head': [[-0.5, 0.5], [-1.0, 0.9], [-0.9, 0.9]], 16 | 'RLegBack2': [[-0.3, 0.2], [-0.6, 0.8], [-0.5, 0.6]], 17 | 'LLegBack2': [[-0.2, 0.3], [-0.6, 0.8], [-0.6, 0.5]], 18 | 'RLegBack3': [[-0.2, 0.3], [-0.8, 0.2], [-0.4, 0.5]], 19 | 'LLegBack3': [[-0.3, 0.2], [-0.8, 0.2], [-0.5, 0.4]], 20 | 'Mouth': [[-0.1, 0.1], [-1.1, 0.5], [-0.1, 0.1]], 21 | 'Neck': [[-0.8, 0.8], [-1.0, 1.0], [-1.1, 1.1]], 22 | 'LLeg1': [[-0.05, 0.05], [-1.3, 0.8], [-0.6, 0.6]], # Extreme 23 | 'RLeg1': [[-0.05, 0.05], [-1.3, 0.8], [-0.6, 0.6]], 24 | 'RLeg2': [[-0.05, 0.05], [-1.0, 0.9], [-0.6, 0.6]], # Extreme 25 | 'LLeg2': [[-0.05, 0.05], [-1.0, 1.1], [-0.6, 0.6]], 26 | 'RLeg3': [[-0.1, 0.4], [-0.3, 1.4], [-0.4, 0.7]], # Extreme 27 | 'LLeg3': [[-0.4, 0.1], [-0.3, 1.4], [-0.7, 0.4]], 28 | 'LFoot': [[-0.3, 0.1], [-0.4, 1.5], [-0.7, 0.3]], # Extreme 29 | 'RFoot': [[-0.1, 0.3], [-0.4, 1.5], [-0.3, 0.7]], 30 | 'Tail7': [[-0.1, 0.1], [-0.7, 1.1], [-0.9, 0.8]], 31 | 'Tail6': [[-0.1, 0.1], [-1.4, 1.4], [-1.0, 1.0]], 32 | 'Tail5': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]], 33 | 'Tail4': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]], 34 | 'Tail3': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]], 35 | 'Tail2': [[-0.1, 0.1], [-1.0, 1.0], [-0.8, 0.8]], 36 | 'Tail1': [[-0.1, 0.1], [-1.5, 1.4], [-1.2, 1.2]], 37 | } 38 | 39 | 40 | class LimitPrior(object): 41 | def __init__(self, device, n_pose=32): 42 | self.parts = { 43 | 'pelvis0': 0, 44 | 'spine': 1, 45 | 'spine0': 2, 46 | 'spine1': 3, 47 | 'spine2': 4, 48 | 'spine3': 5, 49 | 'LLeg1': 6, 50 | 'LLeg2': 7, 51 | 'LLeg3': 8, 52 | 'LFoot': 9, 53 | 'RLeg1': 10, 54 | 'RLeg2': 11, 55 | 'RLeg3': 12, 56 | 'RFoot': 13, 57 | 'Neck': 14, 58 | 'Head': 15, 59 | 'LLegBack1': 16, 60 | 'LLegBack2': 17, 61 | 'LLegBack3': 18, 62 | 'LFootBack': 19, 63 | 'RLegBack1': 20, 64 | 'RLegBack2': 21, 65 | 'RLegBack3': 22, 66 | 'RFootBack': 23, 67 | 'Tail1': 24, 68 | 'Tail2': 25, 69 | 'Tail3': 26, 70 | 'Tail4': 27, 71 | 'Tail5': 28, 72 | 'Tail6': 29, 73 | 'Tail7': 30, 74 | 'Mouth': 31 75 | } 76 | self.id2name = {v: k for k, v in self.parts.items()} 77 | # Ignore the first joint. 78 | self.prefix = 3 79 | self.postfix= 99 80 | self.part_ids = np.array(sorted(self.parts.values())) 81 | min_values = np.hstack([np.array(np.array(Ranges[self.id2name[part_id]])[:, 0]) for part_id in self.part_ids]) 82 | max_values = np.hstack([ 83 | np.array(np.array(Ranges[self.id2name[part_id]])[:, 1]) 84 | for part_id in self.part_ids 85 | ]) 86 | self.ranges = Ranges 87 | self.device = device 88 | self.min_values = torch.from_numpy(min_values).view(n_pose, 3).float().to(device) 89 | self.max_values = torch.from_numpy(max_values).view(n_pose, 3).float().to(device) 90 | 91 | def __call__(self, x): 92 | ''' 93 | Given x, rel rotation of 31 joints, for each parts compute the limit value. 
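        Only the max/min (hinge) variant is implemented below; the logistic and exponential forms are listed for reference.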
94 | k is steepness of the curve, max_val + margin is the midpoint of the curve (val 0.5) 95 | Using Logistic: 96 | max limit: 1/(1 + exp(k * ((max_val + margin) - x))) 97 | min limit: 1/(1 + exp(k * (x - (min_val - margin)))) 98 | With max/min: 99 | minlimit: max( min_vals - x , 0 ) 100 | maxlimit: max( x - max_vals , 0 ) 101 | With exponential: 102 | min: exp(k * (minval - x) ) 103 | max: exp(k * (x - maxval) ) 104 | ''' 105 | ## Max/min discontinous but fast. (flat + L2 past the limit) 106 | 107 | x = x[:, self.prefix:self.postfix].view(x.shape[0], -1, 3) 108 | zeros = torch.zeros_like(x).to(self.device) 109 | # return np.maximum(x - self.max_values, zeros) + np.maximum(self.min_values - x, zeros) 110 | return torch.mean(torch.max(x - self.max_values.unsqueeze(0), zeros) + torch.max(self.min_values.unsqueeze(0) - x, zeros)) 111 | 112 | def report(self, x): 113 | res = self(x).r.reshape(-1, 3) 114 | values = x[self.prefix:].r.reshape(-1, 3) 115 | bad = np.any(res > 0, axis=1) 116 | bad_ids = np.array(self.part_ids)[bad] 117 | np.set_printoptions(precision=3) 118 | for bad_id in bad_ids: 119 | name = self.id2name[bad_id] 120 | limits = self.ranges[name] 121 | print('%s over! Overby:' % name), 122 | print(res[bad_id - 1, :]), 123 | print(' Limits:'), 124 | print(limits), 125 | print(' Values:'), 126 | print(values[bad_id - 1, :]) 127 | 128 | if __name__ == '__main__': 129 | name2id33 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1, 130 | 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15, 131 | 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11, 132 | 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27, 133 | 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0} 134 | name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1, 135 | 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15, 136 | 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11, 137 | 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27, 138 | 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34} 139 | 140 | import os 141 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 142 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 143 | limit_prior = LimitPrior(device, 32) 144 | for k,v in limit_prior.parts.items(): 145 | id33 = name2id33[k]-1 146 | id35 = name2id35[k]-1 147 | assert id33 == id35 and id33==v 148 | x = torch.zeros((35*3,)).float().to(device) 149 | limit_loss = limit_prior(x) 150 | print('done') -------------------------------------------------------------------------------- /smal/smal_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | PyTorch implementation of the SMAL/SMPL model 4 | 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | import pickle as pkl 15 | from smal.batch_lbs import batch_rodrigues, batch_global_rigid_transformation 16 | from smal.smal_basics import align_smal_template_to_symmetry_axis 17 | from util import config 18 | 19 | # 
There are chumpy variables so convert them to numpy. 20 | def undo_chumpy(x): 21 | return x if isinstance(x, np.ndarray) else x.r 22 | 23 | class SMAL(nn.Module): 24 | def __init__(self, device, shape_family_id=1, dtype=torch.float): 25 | super(SMAL, self).__init__() 26 | 27 | # -- Load SMPL params -- 28 | # with open(pkl_path, 'r') as f: 29 | # dd = pkl.load(f) 30 | 31 | print (f"Loading SMAL with shape family: {shape_family_id}") 32 | 33 | with open(config.SMAL_FILE, 'rb') as f: 34 | u = pkl._Unpickler(f) 35 | u.encoding = 'latin1' 36 | dd = u.load() 37 | 38 | self.f = dd['f'] 39 | 40 | self.faces = torch.from_numpy(self.f.astype(int)).to(device) 41 | 42 | v_template = dd['v_template'] 43 | 44 | # Size of mesh [Number of vertices, 3] 45 | self.size = [v_template.shape[0], 3] 46 | self.num_betas = dd['shapedirs'].shape[-1] 47 | # Shape blend shape basis 48 | 49 | shapedir = np.reshape( 50 | undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T.copy() 51 | self.shapedirs = Variable( 52 | torch.Tensor(shapedir), requires_grad=False).to(device) 53 | 54 | with open(config.SMAL_DATA_FILE, 'rb') as f: 55 | u = pkl._Unpickler(f) 56 | u.encoding = 'latin1' 57 | data = u.load() 58 | 59 | # Zero_Betas -> V_Template -> Aligned 60 | # Zero_Betas -> V_Template -> V_Template + ShapeCluster * ShapeDirs -> Aligned 61 | 62 | # Aligned(V_T + ShapeCluster * ShapeDirs) - ShapeCluster * ShapeDirs 63 | 64 | # Select mean shape for quadruped type 65 | shape_cluster_means = data['cluster_means'][shape_family_id] 66 | self.shape_cluster_cov = data['cluster_cov'][shape_family_id] 67 | self.shape_cluster_means = torch.from_numpy(shape_cluster_means).float().to(device) 68 | 69 | v_sym, self.left_inds, self.right_inds, self.center_inds = align_smal_template_to_symmetry_axis( 70 | v_template, sym_file=config.SMAL_SYM_FILE) 71 | 72 | # Mean template vertices 73 | self.v_template = Variable( 74 | torch.Tensor(v_sym), 75 | requires_grad=False).to(device) 76 | 77 | # Regressor for joint locations given shape 78 | self.J_regressor = Variable( 79 | torch.Tensor(dd['J_regressor'].T.todense()), 80 | requires_grad=False).to(device) 81 | 82 | # Pose blend shape basis 83 | num_pose_basis = dd['posedirs'].shape[-1] 84 | 85 | posedirs = np.reshape( 86 | undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T 87 | self.posedirs = Variable( 88 | torch.Tensor(posedirs), requires_grad=False).to(device) 89 | 90 | # indices of parents for each joints 91 | self.parents = dd['kintree_table'][0].astype(np.int32) 92 | 93 | # LBS weights 94 | self.weights = Variable( 95 | torch.Tensor(undo_chumpy(dd['weights'])), 96 | requires_grad=False).to(device) 97 | 98 | 99 | def __call__(self, beta, theta, trans=None, del_v=None, betas_logscale=None, get_skin=True, v_template=None): 100 | 101 | if True: 102 | nBetas = beta.shape[1] 103 | else: 104 | nBetas = 0 105 | 106 | 107 | # v_template = self.v_template.unsqueeze(0).expand(beta.shape[0], 3889, 3) 108 | if v_template is None: 109 | v_template = self.v_template 110 | 111 | # 1. Add shape blend shapes 112 | 113 | if nBetas > 0: 114 | if del_v is None: 115 | v_shaped = v_template + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]]) 116 | else: 117 | v_shaped = v_template + del_v + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]]) 118 | else: 119 | if del_v is None: 120 | v_shaped = v_template.unsqueeze(0) 121 | else: 122 | v_shaped = v_template + del_v 123 | 124 | # 2. Infer shape-dependent joint locations. 
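        # J_regressor is stored transposed, i.e. (num_verts, num_joints); multiplying each coordinate channel of
        # v_shaped (N, num_verts) by it gives per-axis joint coordinates, stacked below into J of shape (N, num_joints, 3).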
125 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) 126 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) 127 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) 128 | J = torch.stack([Jx, Jy, Jz], dim=2) 129 | 130 | # 3. Add pose blend shapes 131 | # N x 24 x 3 x 3 132 | if len(theta.shape) == 4: 133 | Rs = theta 134 | else: 135 | Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3]) 136 | 137 | # Ignore global rotation. 138 | pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(beta.device), [-1, 306]) 139 | 140 | v_posed = torch.reshape( 141 | torch.matmul(pose_feature, self.posedirs), 142 | [-1, self.size[0], self.size[1]]) + v_shaped 143 | 144 | #4. Get the global joint location 145 | self.J_transformed, A = batch_global_rigid_transformation( 146 | Rs, J, self.parents, betas_logscale=betas_logscale) 147 | 148 | 149 | # 5. Do skinning: 150 | num_batch = theta.shape[0] 151 | 152 | weights_t = self.weights.repeat([num_batch, 1]) 153 | W = torch.reshape(weights_t, [num_batch, -1, 35]) 154 | 155 | 156 | T = torch.reshape( 157 | torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])), 158 | [num_batch, -1, 4, 4]) 159 | v_posed_homo = torch.cat( 160 | [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=beta.device)], 2) 161 | v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1)) 162 | 163 | verts = v_homo[:, :, :3, 0] 164 | 165 | if trans is None: 166 | trans = torch.zeros((num_batch,3)).to(device=beta.device) 167 | 168 | verts = verts + trans[:,None,:] 169 | 170 | # Get joints: 171 | joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) 172 | joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) 173 | joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) 174 | joints = torch.stack([joint_x, joint_y, joint_z], dim=2) 175 | 176 | joints = torch.cat([ 177 | joints, 178 | verts[:, None, 1863], # end_of_nose 179 | verts[:, None, 26], # chin 180 | verts[:, None, 2124], # right ear tip 181 | verts[:, None, 150], # left ear tip 182 | verts[:, None, 3055], # left eye 183 | verts[:, None, 1097], # right eye 184 | ], dim = 1) 185 | 186 | if get_skin: 187 | return verts, joints, Rs, v_shaped 188 | else: 189 | return joints 190 | 191 | def verts2joints(self, verts): 192 | # Get joints: 193 | joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) 194 | joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) 195 | joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) 196 | joints = torch.stack([joint_x, joint_y, joint_z], dim=2) 197 | 198 | joints = torch.cat([ 199 | joints, 200 | verts[:, None, 1863], # end_of_nose 201 | verts[:, None, 26], # chin 202 | verts[:, None, 2124], # right ear tip 203 | verts[:, None, 150], # left ear tip 204 | verts[:, None, 3055], # left eye 205 | verts[:, None, 1097], # right eye 206 | ], dim=1) 207 | 208 | return joints 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /datasets/imutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains functions that are used to perform data augmentation. 
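It covers cropping/uncropping with optional rotation, left-right flipping of images, keypoints and SMPL pose parameters, axis-angle rotation, and Gaussian/Cauchy label-map drawing.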
3 | """ 4 | import torch 5 | import numpy as np 6 | import scipy.misc 7 | import cv2 8 | import imageio 9 | from scipy import ndimage 10 | 11 | def get_transform(center, scale, res, rot=0): 12 | """Generate transformation matrix.""" 13 | h = 200 * scale 14 | t = np.zeros((3, 3)) 15 | t[0, 0] = float(res[1]) / h 16 | t[1, 1] = float(res[0]) / h 17 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 18 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 19 | t[2, 2] = 1 20 | if not rot == 0: 21 | rot = -rot # To match direction of rotation from cropping 22 | rot_mat = np.zeros((3, 3)) 23 | rot_rad = rot * np.pi / 180 24 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 25 | rot_mat[0, :2] = [cs, -sn] 26 | rot_mat[1, :2] = [sn, cs] 27 | rot_mat[2, 2] = 1 28 | # Need to rotate around center 29 | t_mat = np.eye(3) 30 | t_mat[0, 2] = -res[1] / 2 31 | t_mat[1, 2] = -res[0] / 2 32 | t_inv = t_mat.copy() 33 | t_inv[:2, 2] *= -1 34 | t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) 35 | return t 36 | 37 | 38 | def transform(pt, center, scale, res, invert=0, rot=0): 39 | """Transform pixel location to different reference.""" 40 | t = get_transform(center, scale, res, rot=rot) 41 | if invert: 42 | t = np.linalg.inv(t) 43 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T 44 | new_pt = np.dot(t, new_pt) 45 | return new_pt[:2].astype(int) + 1 46 | 47 | 48 | def crop(img, center, scale, res, rot=0, border_grey_intensity=0.0): 49 | """Crop image according to the supplied bounding box.""" 50 | # Upper left point 51 | ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 52 | # Bottom right point 53 | br = np.array(transform([res[0] + 1, 54 | res[1] + 1], center, scale, res, invert=1)) - 1 55 | 56 | # Padding so that when rotated proper amount of context is included 57 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) 58 | if not rot == 0: 59 | ul -= pad 60 | br += pad 61 | 62 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 63 | if len(img.shape) > 2: 64 | new_shape += [img.shape[2]] 65 | new_img = np.ones(new_shape) * border_grey_intensity 66 | 67 | # Range to fill new array 68 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 69 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 70 | # Range to sample from original image 71 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 72 | old_y = max(0, ul[1]), min(len(img), br[1]) 73 | try: 74 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], 75 | old_x[0]:old_x[1]] 76 | except ValueError: 77 | print('here') 78 | if not rot == 0: 79 | # Remove padding 80 | new_img = scipy.misc.imrotate(new_img, rot) 81 | new_img = new_img[pad:-pad, pad:-pad] 82 | 83 | new_img = cv2.resize(new_img, (*res,)) 84 | return new_img 85 | 86 | 87 | def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): 88 | """'Undo' the image cropping/resizing. 89 | This function is used when evaluating mask/part segmentation. 
90 | """ 91 | res = img.shape[:2] 92 | # Upper left point 93 | ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 94 | # Bottom right point 95 | br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 96 | # size of cropped image 97 | crop_shape = [br[1] - ul[1], br[0] - ul[0]] 98 | 99 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 100 | if len(img.shape) > 2: 101 | new_shape += [img.shape[2]] 102 | new_img = np.zeros(orig_shape, dtype=np.uint8) 103 | # Range to fill new array 104 | new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] 105 | new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] 106 | # Range to sample from original image 107 | old_x = max(0, ul[0]), min(orig_shape[1], br[0]) 108 | old_y = max(0, ul[1]), min(orig_shape[0], br[1]) 109 | img = scipy.misc.imresize(img, crop_shape, interp='nearest') 110 | new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] 111 | return new_img 112 | 113 | 114 | def rot_aa(aa, rot): 115 | """Rotate axis angle parameters.""" 116 | # pose parameters 117 | R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], 118 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], 119 | [0, 0, 1]]) 120 | # find the rotation of the body in camera frame 121 | per_rdg, _ = cv2.Rodrigues(aa) 122 | # apply the global rotation to the global orientation 123 | resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg)) 124 | aa = (resrot.T)[0] 125 | return aa 126 | 127 | 128 | def flip_img(img): 129 | """Flip rgb images or masks. 130 | channels come last, e.g. (256,256,3). 131 | """ 132 | img = np.fliplr(img) 133 | return img 134 | 135 | 136 | def flip_kp(kp, width): 137 | """Flip keypoints.""" 138 | kp[:, 0] = width - kp[:, 0] 139 | flipped_parts = [6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 140 | 23] # order of kp indices for L-R flipping 141 | kp = kp[flipped_parts] 142 | 143 | return kp 144 | 145 | 146 | def flip_pose(pose): 147 | """Flip pose. 148 | The flipping is based on SMPL parameters. 149 | """ 150 | flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 151 | 14, 18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 152 | 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, 153 | 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, 154 | 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] 155 | pose = pose[flippedParts] 156 | # we also negate the second and the third dimension of the axis-angle 157 | pose[1::3] = -pose[1::3] 158 | pose[2::3] = -pose[2::3] 159 | return pose 160 | 161 | 162 | def flip_aa(aa): 163 | """Flip axis-angle representation. 164 | We negate the second and the third dimension of the axis-angle. 
165 | """ 166 | aa[1] = -aa[1] 167 | aa[2] = -aa[2] 168 | return aa 169 | 170 | 171 | def draw_labelmap(img, pt, sigma, type='Gaussian'): 172 | # Draw a 2D gaussian 173 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py 174 | # img = to_numpy(img) 175 | img = img.numpy() 176 | # Check that any part of the gaussian is in-bounds 177 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 178 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 179 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or 180 | br[0] < 0 or br[1] < 0): 181 | # If not, just return the image as is 182 | return torch.from_numpy(img).float(), 0 183 | 184 | # Generate gaussian 185 | size = 6 * sigma + 1 186 | x = np.arange(0, size, 1, float) 187 | y = x[:, np.newaxis] 188 | x0 = y0 = size // 2 189 | # The gaussian is not normalized, we want the center value to equal 1 190 | if type == 'Gaussian': 191 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 192 | elif type == 'Cauchy': 193 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) 194 | 195 | 196 | # Usable gaussian range 197 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 198 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 199 | # Image range 200 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 201 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 202 | 203 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 204 | # return to_torch(img), 1 205 | return torch.from_numpy(img).float(), 1 206 | 207 | 208 | def flip_back(flip_output): 209 | """ 210 | flip output map 211 | """ 212 | 213 | flipped_parts = [6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23] 214 | # flip output horizontally 215 | flip_output = fliplr(flip_output.numpy()) 216 | 217 | # Change left-right parts 218 | flip_output = flip_output[:, flipped_parts, :, :] 219 | 220 | return torch.from_numpy(flip_output).float() 221 | 222 | 223 | def fliplr(x): 224 | if x.ndim == 3: 225 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1)) 226 | elif x.ndim == 4: 227 | for i in range(x.shape[0]): 228 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1)) 229 | return x.astype(float) 230 | -------------------------------------------------------------------------------- /coma/mesh_sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import heapq 3 | import numpy as np 4 | import scipy.sparse as sp 5 | from opendr.topology import get_vert_connectivity, get_vertices_per_edge 6 | from psbody.mesh import Mesh 7 | 8 | def vertex_quadrics(mesh): 9 | """Computes a quadric for each vertex in the Mesh. 10 | Returns: 11 | v_quadrics: an (N x 4 x 4) array, where N is # vertices. 12 | """ 13 | 14 | # Allocate quadrics 15 | v_quadrics = np.zeros((len(mesh.v), 4, 4,)) 16 | 17 | # For each face... 
18 | for f_idx in range(len(mesh.f)): 19 | 20 | # Compute normalized plane equation for that face 21 | vert_idxs = mesh.f[f_idx] 22 | verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 1]).reshape(-1, 1))) 23 | u, s, v = np.linalg.svd(verts) 24 | eq = v[-1, :].reshape(-1, 1) 25 | eq = eq / (np.linalg.norm(eq[0:3])) 26 | 27 | # Add the outer product of the plane equation to the 28 | # quadrics of the vertices for this face 29 | for k in range(3): 30 | v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq) 31 | 32 | return v_quadrics 33 | 34 | 35 | def setup_deformation_transfer(source, target, use_normals=False): 36 | rows = np.zeros(3 * target.v.shape[0]) 37 | cols = np.zeros(3 * target.v.shape[0]) 38 | coeffs_v = np.zeros(3 * target.v.shape[0]) 39 | coeffs_n = np.zeros(3 * target.v.shape[0]) 40 | 41 | nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree().nearest(target.v, True) 42 | nearest_faces = nearest_faces.ravel().astype(np.int64) 43 | nearest_parts = nearest_parts.ravel().astype(np.int64) 44 | nearest_vertices = nearest_vertices.ravel() 45 | 46 | for i in range(target.v.shape[0]): 47 | # Closest triangle index 48 | f_id = nearest_faces[i] 49 | # Closest triangle vertex ids 50 | nearest_f = source.f[f_id] 51 | 52 | # Closest surface point 53 | nearest_v = nearest_vertices[3 * i:3 * i + 3] 54 | # Distance vector to the closest surface point 55 | dist_vec = target.v[i] - nearest_v 56 | 57 | rows[3 * i:3 * i + 3] = i * np.ones(3) 58 | cols[3 * i:3 * i + 3] = nearest_f 59 | 60 | n_id = nearest_parts[i] 61 | if n_id == 0: 62 | # Closest surface point in triangle 63 | A = np.vstack((source.v[nearest_f])).T 64 | coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v)[0] 65 | elif n_id > 0 and n_id <= 3: 66 | # Closest surface point on edge 67 | A = np.vstack((source.v[nearest_f[n_id - 1]], source.v[nearest_f[n_id % 3]])).T 68 | tmp_coeffs = np.linalg.lstsq(A, target.v[i])[0] 69 | coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0] 70 | coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1] 71 | else: 72 | # Closest surface point a vertex 73 | coeffs_v[3 * i + n_id - 4] = 1.0 74 | 75 | # if use_normals: 76 | # A = np.vstack((vn[nearest_f])).T 77 | # coeffs_n[3 * i:3 * i + 3] = np.linalg.lstsq(A, dist_vec)[0] 78 | 79 | # coeffs = np.hstack((coeffs_v, coeffs_n)) 80 | # rows = np.hstack((rows, rows)) 81 | # cols = np.hstack((cols, source.v.shape[0] + cols)) 82 | matrix = sp.csc_matrix((coeffs_v, (rows, cols)), shape=(target.v.shape[0], source.v.shape[0])) 83 | return matrix 84 | 85 | 86 | def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None): 87 | """Return a simplified version of this mesh. 88 | A Qslim-style approach is used here. 
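    Candidate edge collapses are kept in a heap ordered by quadric error; each collapse removes the endpoint that is cheaper to destroy and remaps its faces onto the kept vertex, until the desired vertex count is reached.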
89 | :param factor: fraction of the original vertices to retain 90 | :param n_verts_desired: number of the original vertices to retain 91 | :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix 92 | """ 93 | 94 | if factor is None and n_verts_desired is None: 95 | raise Exception('Need either factor or n_verts_desired.') 96 | 97 | if n_verts_desired is None: 98 | n_verts_desired = math.ceil(len(mesh.v) * factor) 99 | 100 | Qv = vertex_quadrics(mesh) 101 | 102 | # fill out a sparse matrix indicating vertex-vertex adjacency 103 | # from psbody.mesh.topology.connectivity import get_vertices_per_edge 104 | vert_adj = get_vertices_per_edge(mesh.v, mesh.f) 105 | # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v))) 106 | # for f_idx in range(len(mesh.f)): 107 | # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1 108 | 109 | vert_adj = sp.csc_matrix((vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), 110 | shape=(len(mesh.v), len(mesh.v))) 111 | vert_adj = vert_adj + vert_adj.T 112 | vert_adj = vert_adj.tocoo() 113 | 114 | def collapse_cost(Qv, r, c, v): 115 | Qsum = Qv[r, :, :] + Qv[c, :, :] 116 | p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 117 | p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 118 | 119 | destroy_c_cost = p1.T.dot(Qsum).dot(p1) 120 | destroy_r_cost = p2.T.dot(Qsum).dot(p2) 121 | result = { 122 | 'destroy_c_cost': destroy_c_cost, 123 | 'destroy_r_cost': destroy_r_cost, 124 | 'collapse_cost': min([destroy_c_cost, destroy_r_cost]), 125 | 'Qsum': Qsum} 126 | return result 127 | 128 | # construct a queue of edges with costs 129 | queue = [] 130 | for k in range(vert_adj.nnz): 131 | r = vert_adj.row[k] 132 | c = vert_adj.col[k] 133 | 134 | if r > c: 135 | continue 136 | 137 | cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost'] 138 | heapq.heappush(queue, (cost, (r, c))) 139 | 140 | # decimate 141 | collapse_list = [] 142 | nverts_total = len(mesh.v) 143 | faces = mesh.f.copy() 144 | while nverts_total > n_verts_desired: 145 | e = heapq.heappop(queue) 146 | r = e[1][0] 147 | c = e[1][1] 148 | if r == c: 149 | continue 150 | 151 | cost = collapse_cost(Qv, r, c, mesh.v) 152 | if cost['collapse_cost'] > e[0]: 153 | heapq.heappush(queue, (cost['collapse_cost'], e[1])) 154 | # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost']) 155 | continue 156 | else: 157 | 158 | # update old vert idxs to new one, 159 | # in queue and in face list 160 | if cost['destroy_c_cost'] < cost['destroy_r_cost']: 161 | to_destroy = c 162 | to_keep = r 163 | else: 164 | to_destroy = r 165 | to_keep = c 166 | 167 | collapse_list.append([to_keep, to_destroy]) 168 | 169 | # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx 170 | np.place(faces, faces == to_destroy, to_keep) 171 | 172 | # same for queue 173 | which1 = [idx for idx in range(len(queue)) if queue[idx][1][0] == to_destroy] 174 | which2 = [idx for idx in range(len(queue)) if queue[idx][1][1] == to_destroy] 175 | for k in which1: 176 | queue[k] = (queue[k][0], (to_keep, queue[k][1][1])) 177 | for k in which2: 178 | queue[k] = (queue[k][0], (queue[k][1][0], to_keep)) 179 | 180 | Qv[r, :, :] = cost['Qsum'] 181 | Qv[c, :, :] = cost['Qsum'] 182 | 183 | a = faces[:, 0] == faces[:, 1] 184 | b = faces[:, 1] == faces[:, 2] 185 | c = faces[:, 2] == faces[:, 0] 186 | 187 | # remove degenerate faces 188 | def logical_or3(x, y, z): 189 | return np.logical_or(x, np.logical_or(y, z)) 190 | 191 | faces_to_keep = np.logical_not(logical_or3(a, b, c)) 192 | faces = 
faces[faces_to_keep, :].copy() 193 | 194 | nverts_total = (len(np.unique(faces.flatten()))) 195 | 196 | new_faces, mtx = _get_sparse_transform(faces, len(mesh.v)) 197 | return new_faces, mtx 198 | 199 | 200 | def _get_sparse_transform(faces, num_original_verts): 201 | verts_left = np.unique(faces.flatten()) 202 | IS = np.arange(len(verts_left)) 203 | JS = verts_left 204 | data = np.ones(len(JS)) 205 | 206 | mp = np.arange(0, np.max(faces.flatten()) + 1) 207 | mp[JS] = IS 208 | new_faces = mp[faces.copy().flatten()].reshape((-1, 3)) 209 | 210 | ij = np.vstack((IS.flatten(), JS.flatten())) 211 | mtx = sp.csc_matrix((data, ij), shape=(len(verts_left), num_original_verts)) 212 | 213 | return (new_faces, mtx) 214 | 215 | 216 | def generate_transform_matrices(mesh, factors): 217 | """Generates len(factors) meshes, each of them is scaled by factors[i] and 218 | computes the transformations between them. 219 | 220 | Returns: 221 | M: a set of meshes downsampled from mesh by a factor specified in factors. 222 | A: Adjacency matrix for each of the meshes 223 | D: Downsampling transforms between each of the meshes 224 | U: Upsampling transforms between each of the meshes 225 | """ 226 | 227 | factors = map(lambda x: 1.0 / x, factors) 228 | M, A, D, U = [], [], [], [] 229 | A.append(get_vert_connectivity(mesh.v, mesh.f)) 230 | M.append(mesh) 231 | 232 | for factor in factors: 233 | ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor) 234 | D.append(ds_D) 235 | new_mesh_v = ds_D.dot(M[-1].v) 236 | new_mesh = Mesh(v=new_mesh_v, f=ds_f) 237 | M.append(new_mesh) 238 | A.append(get_vert_connectivity(new_mesh.v, new_mesh.f)) 239 | U.append(setup_deformation_transfer(M[-1], M[-2])) 240 | 241 | return M, A, D, U -------------------------------------------------------------------------------- /util/nmr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import scipy.misc 7 | import tqdm 8 | 9 | # import chainer 10 | import torch 11 | 12 | import neural_renderer 13 | 14 | from util import geom_utils 15 | import pickle as pkl 16 | 17 | from util import config 18 | 19 | ############# 20 | ### Utils ### 21 | ############# 22 | def convert_as(src, trg): 23 | src = src.type_as(trg) 24 | if src.is_cuda: 25 | src = src.cuda(device=trg.get_device()) 26 | return src 27 | 28 | ######################################################################## 29 | ############ Wrapper class for the chainer Neural Renderer ############# 30 | ##### All functions must only use numpy arrays as inputs/outputs ####### 31 | ######################################################################## 32 | # class NMR(object): 33 | # def __init__(self): 34 | # # setup renderer 35 | # renderer = neural_renderer.Renderer() 36 | # self.renderer = renderer 37 | 38 | # def to_gpu(self, device=0): 39 | # # self.renderer.to_gpu(device) 40 | # self.cuda_device = device 41 | 42 | # def forward_mask(self, vertices, faces): 43 | # ''' Renders masks. 
44 | # Args: 45 | # vertices: B X N X 3 numpy array 46 | # faces: B X F X 3 numpy array 47 | # Returns: 48 | # masks: B X 256 X 256 numpy array 49 | # ''' 50 | # # self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device)) 51 | # # self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device)) 52 | # # self.masks = self.renderer.render_silhouettes(self.vertices, self.faces) 53 | 54 | # masks = self.renderer.render_silhouettes(vertices, faces) 55 | 56 | # # masks = self.masks.data.get() 57 | # return masks 58 | 59 | # def backward_mask(self, grad_masks): 60 | # ''' Compute gradient of vertices given mask gradients. 61 | # Args: 62 | # grad_masks: B X 256 X 256 numpy array 63 | # Returns: 64 | # grad_vertices: B X N X 3 numpy array 65 | # ''' 66 | # self.masks.grad = chainer.cuda.to_gpu(grad_masks, self.cuda_device) 67 | # self.masks.backward() 68 | # return self.vertices.grad.get() 69 | 70 | # def forward_img(self, vertices, faces, textures): 71 | # ''' Renders masks. 72 | # Args: 73 | # vertices: B X N X 3 numpy array 74 | # faces: B X F X 3 numpy array 75 | # textures: B X F X T X T X T X 3 numpy array 76 | # Returns: 77 | # images: B X 3 x 256 X 256 numpy array 78 | # ''' 79 | # self.faces = chainer.Variable(chainer.cuda.to_gpu(faces, self.cuda_device)) 80 | # self.vertices = chainer.Variable(chainer.cuda.to_gpu(vertices, self.cuda_device)) 81 | # self.textures = chainer.Variable(chainer.cuda.to_gpu(textures, self.cuda_device)) 82 | # self.images = self.renderer.render(self.vertices, self.faces, self.textures) 83 | 84 | # images = self.images.data.get() 85 | # return images 86 | 87 | 88 | # def backward_img(self, grad_images): 89 | # ''' Compute gradient of vertices given image gradients. 90 | # Args: 91 | # grad_images: B X 3? X 256 X 256 numpy array 92 | # Returns: 93 | # grad_vertices: B X N X 3 numpy array 94 | # grad_textures: B X F X T X T X T X 3 numpy array 95 | # ''' 96 | # self.images.grad = chainer.cuda.to_gpu(grad_images, self.cuda_device) 97 | # self.images.backward() 98 | # return self.vertices.grad.get(), self.textures.grad.get() 99 | 100 | ######################################################################## 101 | ################# Wrapper class a rendering PythonOp ################### 102 | ##### All functions must only use torch Tensors as inputs/outputs ###### 103 | ######################################################################## 104 | # class Render(torch.autograd.Function): 105 | # # TODO(Shubham): Make sure the outputs/gradients are on the GPU 106 | # def __init__(self, renderer): 107 | # super(Render, self).__init__() 108 | # self.renderer = renderer 109 | 110 | # def forward(self, vertices, faces, textures=None): 111 | # # B x N x 3 112 | # # Flipping the y-axis here to make it align with the image coordinate system! 
113 | # vs = vertices.cpu().numpy() 114 | # vs[:, :, 1] *= -1 115 | # fs = faces.cpu().numpy() 116 | # if textures is None: 117 | # self.mask_only = True 118 | # masks = self.renderer.forward_mask(vs, fs) 119 | # return convert_as(torch.Tensor(masks), vertices) 120 | # else: 121 | # self.mask_only = False 122 | # ts = textures.cpu().numpy() 123 | # imgs = self.renderer.forward_img(vs, fs, ts) 124 | # return convert_as(torch.Tensor(imgs), vertices) 125 | 126 | # def backward(self, grad_out): 127 | # g_o = grad_out.cpu().numpy() 128 | # if self.mask_only: 129 | # grad_verts = self.renderer.backward_mask(g_o) 130 | # grad_verts = convert_as(torch.Tensor(grad_verts), grad_out) 131 | # grad_tex = None 132 | # else: 133 | # grad_verts, grad_tex = self.renderer.backward_img(g_o) 134 | # grad_verts = convert_as(torch.Tensor(grad_verts), grad_out) 135 | # grad_tex = convert_as(torch.Tensor(grad_tex), grad_out) 136 | 137 | # grad_verts[:, :, 1] *= -1 138 | # return grad_verts, None, grad_tex 139 | 140 | 141 | ######################################################################## 142 | ############## Wrapper torch module for Neural Renderer ################ 143 | ######################################################################## 144 | class NeuralRenderer(torch.nn.Module): 145 | """ 146 | This is the core pytorch function to call. 147 | Every torch NMR has a chainer NMR. 148 | Only fwd/bwd once per iteration. 149 | """ 150 | def __init__(self, img_size=256, proj_type='perspective', norm_f=1., norm_z=0.,norm_f0=0., 151 | render_rgb=False, device=None): 152 | super(NeuralRenderer, self).__init__() 153 | # self.renderer = NMR() 154 | 155 | self.renderer = neural_renderer.Renderer(camera_mode='look_at', background_color=[1,1,1]) 156 | 157 | # self.renderer = nr.Renderer(camera_mode='look_at') 158 | 159 | self.norm_f = norm_f 160 | self.norm_f0 = norm_f0 161 | self.norm_z = norm_z 162 | 163 | # Adjust the core renderer 164 | self.renderer.image_size = img_size 165 | self.renderer.perspective = False 166 | 167 | # Set a default camera to be at (0, 0, -2.732) 168 | self.renderer.eye = [0, 0, -1.0] 169 | 170 | # Make it a bit brighter for vis 171 | self.renderer.light_intensity_ambient = 0.8 172 | 173 | # self.renderer.to_gpu() 174 | 175 | # Silvia 176 | if proj_type == 'perspective': 177 | self.proj_fn = geom_utils.perspective_proj_withz 178 | else: 179 | print('unknown projection type') 180 | import pdb; pdb.set_trace() 181 | 182 | self.offset_z = -1.0 183 | self.textures = torch.ones(7774, 4, 4, 4, 3) * torch.FloatTensor(config.MESH_COLOR) / 255.0 # light blue 184 | # self.textures = self.textures.cuda() 185 | self.textures = self.textures.to(device) 186 | self.render_rgb = render_rgb 187 | 188 | def ambient_light_only(self): 189 | # Make light only ambient. 190 | self.renderer.light_intensity_ambient = 1 191 | self.renderer.light_intensity_directional = 0 192 | 193 | def directional_light_only(self): 194 | # Make light only directional. 
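        # (ambient intensity is kept at 0.8 here as well, so the scene is lit by ambient plus directional light rather than directional only)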
195 | self.renderer.light_intensity_ambient = 0.8 196 | self.renderer.light_intensity_directional = 0.8 197 | self.renderer.light_direction = [0, 1, 0] # up-to-down, this is the default 198 | 199 | def set_bgcolor(self, color): 200 | self.renderer.background_color = color 201 | 202 | def project_points(self, verts, cams, normalize_kpts=False): 203 | proj = self.proj_fn(verts, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z, 204 | norm_f0=self.norm_f0) 205 | image_size_half = cams[0, 1] 206 | proj_points = proj[:, :, :2] 207 | if not normalize_kpts: # output 2d keypoint in the image coordinate 208 | proj_points = (proj_points[:, :, :2] + 1) * image_size_half 209 | proj_points = torch.stack([ 210 | image_size_half * 2 - proj_points[:, :, 1], 211 | proj_points[:, :, 0]], dim=-1) 212 | else: # output 2d keypoint in the normalized image coordinate 213 | proj_points = torch.stack([ 214 | proj_points[:, :, 0], 215 | -proj_points[:, :, 1]], dim=-1) 216 | 217 | return proj_points 218 | 219 | def forward(self, vertices, faces, cams, textures=None): 220 | verts = self.proj_fn(vertices, cams, offset_z=self.offset_z, norm_f=self.norm_f, norm_z=self.norm_z, norm_f0=self.norm_f0) 221 | if textures is not None: 222 | if self.render_rgb: 223 | img = self.renderer.render(verts, faces, textures) 224 | else: 225 | img = None 226 | sil = self.renderer.render_silhouettes(verts, faces) 227 | return img, sil 228 | 229 | else: 230 | textures = self.textures.unsqueeze(0).expand(verts.shape[0], -1, -1, -1, -1, -1) 231 | if self.render_rgb: 232 | img = self.renderer.render(verts, faces, textures) 233 | else: 234 | img = None 235 | sil = self.renderer.render_silhouettes(verts, faces) 236 | return img, sil 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /datasets/stanford.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import division 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms import Normalize 7 | import numpy as np 8 | import cv2 9 | from os.path import join 10 | import os 11 | import json 12 | 13 | from util import config 14 | from datasets.imutils import crop, flip_img, flip_pose, flip_kp, transform, rot_aa 15 | from pycocotools.mask import decode as decode_RLE 16 | 17 | 18 | def seg_from_anno(entry): 19 | """Given a .json entry, returns the binary mask as a numpy array""" 20 | 21 | rle = { 22 | "size": [entry['img_height'], entry['img_width']], 23 | "counts": entry['seg'] 24 | } 25 | 26 | decoded = decode_RLE(rle) 27 | return decoded 28 | 29 | 30 | class BaseDataset(Dataset): 31 | """ 32 | Base Dataset Class - Handles data loading and augmentation. 33 | Able to handle heterogeneous datasets (different annotations available for different datasets). 34 | You need to update the path to each dataset in utils/config.py. 
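    A minimal construction sketch, assuming the 'stanford' entries in util/config.py point to valid data:
        dataset = BaseDataset('stanford', is_train=True, img_res=224)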
35 | """ 36 | 37 | def __init__(self, 38 | dataset, 39 | param_dir=None, 40 | use_augmentation=True, 41 | is_train=True, 42 | img_res=224): 43 | 44 | super(BaseDataset, self).__init__() 45 | self.dataset = dataset 46 | self.param_dir = param_dir 47 | self.is_train = is_train 48 | BASE_FOLDER = config.DATASET_FOLDERS[dataset] 49 | 50 | self.img_dir = os.path.join(BASE_FOLDER, 'Images') 51 | self.jsonfile = os.path.join(BASE_FOLDER, config.JSON_NAME[dataset]) # accessing new version of keypoints.json 52 | # create train/test split 53 | with open(self.jsonfile) as anno_file: 54 | self.anno = json.load(anno_file) 55 | 56 | self.data_idx = np.load(os.path.join(config.DATASET_FILES[is_train][dataset])) 57 | print("Number of images: {}".format(len(self.data_idx))) 58 | # self.options = options 59 | self.normalize_img = Normalize( 60 | mean=config.IMG_NORM_MEAN, std=config.IMG_NORM_STD) 61 | 62 | self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor]' 63 | self.noise_factor = 0.4 # Random rotation in the range [-rot_factor, rot_factor] 64 | self.scale_factor = 0.25 # Rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] 65 | self.img_res = img_res 66 | 67 | # If False, do not do augmentation 68 | self.use_augmentation = use_augmentation 69 | 70 | def augm_params(self): 71 | """Get augmentation parameters.""" 72 | flip = 0 # flipping 73 | pn = np.ones(3) # per channel pixel-noise 74 | rot = 0 # rotation 75 | sc = 1 # scaling 76 | if self.is_train and self.use_augmentation: 77 | # We flip with probability 1/2 78 | # if np.random.uniform() <= 0.5: 79 | # flip = 1 80 | 81 | # Each channel is multiplied with a number 82 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor] 83 | # pn = np.random.uniform(1-self.options.noise_factor, 1+self.options.noise_factor, 3) 84 | pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3) 85 | 86 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] 87 | # rot = min(2*self.options.rot_factor, 88 | # max(-2*self.options.rot_factor, np.random.randn()*self.options.rot_factor)) 89 | 90 | rot = min(2 * self.rot_factor, 91 | max(-2 * self.rot_factor, np.random.randn() * self.rot_factor)) 92 | 93 | # The scale is multiplied with a number 94 | # in the area [1-scaleFactor,1+scaleFactor] 95 | # sc = min(1+self.options.scale_factor, 96 | # max(1-self.options.scale_factor, np.random.randn()*self.options.scale_factor+1)) 97 | 98 | sc = min(1 + self.scale_factor, 99 | max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1)) 100 | # but it is zero with probability 3/5 101 | if np.random.uniform() <= 0.6: 102 | rot = 0 103 | 104 | return flip, pn, rot, sc 105 | 106 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn, border_grey_intensity=0.0): 107 | """Process rgb image and do augmentation.""" 108 | rgb_img = crop(rgb_img, center, scale, 109 | [self.img_res, self.img_res], rot=rot, 110 | border_grey_intensity=border_grey_intensity) 111 | 112 | # flip the image 113 | if flip: 114 | rgb_img = flip_img(rgb_img) 115 | # in the rgb image we add pixel noise in a channel-wise manner 116 | rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0])) 117 | rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1])) 118 | rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2])) 119 | # (3,224,224),float,[0,1] 120 | rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0 121 | return rgb_img 122 | 123 | def 
j2d_processing(self, kp, center, scale, r, f): 124 | """Process gt 2D keypoints and apply all augmentation transforms.""" 125 | nparts = kp.shape[0] 126 | for i in range(nparts): 127 | kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale, 128 | [self.img_res, self.img_res], rot=r) 129 | # flip the x coordinates 130 | if f: 131 | kp = flip_kp(kp, config.IMG_RES) 132 | kp_norm = kp.copy() 133 | # convert to normalized coordinates 134 | kp_norm[:, :-1] = 2. * kp_norm[:, :-1] / self.img_res - 1. 135 | 136 | kp = kp.astype('float32') 137 | return kp, kp_norm.astype('float32') 138 | 139 | def j3d_processing(self, S, r, f): 140 | """Process gt 3D keypoints and apply all augmentation transforms.""" 141 | # in-plane rotation 142 | rot_mat = np.eye(3) 143 | if not r == 0: 144 | rot_rad = -r * np.pi / 180 145 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 146 | rot_mat[0, :2] = [cs, -sn] 147 | rot_mat[1, :2] = [sn, cs] 148 | S = np.einsum('ij,kj->ki', rot_mat, S) 149 | # flip the x coordinates 150 | if f: 151 | S = flip_kp(S) 152 | S = S.astype('float32') 153 | return S 154 | 155 | def pose_processing(self, pose, r, f): 156 | """Process SMPL theta parameters and apply all augmentation transforms.""" 157 | # rotation or the pose parameters 158 | pose[:3] = rot_aa(pose[:3], r) 159 | # flip the pose parameters 160 | if f: 161 | pose = flip_pose(pose) 162 | # (72),float 163 | pose = pose.astype('float32') 164 | return pose 165 | 166 | def __getitem__(self, index): 167 | idx = index 168 | img_idx = self.data_idx[idx] 169 | a = self.anno[img_idx] 170 | 171 | # Get augmentation parameters 172 | flip, pn, rot, sc = self.augm_params() 173 | 174 | imgname_raw = a['img_path'] 175 | 176 | # Load image 177 | imgname = join(self.img_dir, imgname_raw) 178 | 179 | # Some datatsets store \ instead of /, so convert 180 | imgname = imgname.replace("\\", "/") 181 | filename = imgname_raw.split('/')[1].split('.')[0] if self.dataset == 'stanford' else imgname_raw.split('.')[0] 182 | assert os.path.exists(imgname), "Cannot find image: {0}".format(imgname) 183 | img = cv2.imread(imgname)[:, :, ::-1].copy().astype('float32') 184 | 185 | seg = seg_from_anno(a) # (H, W) bool 186 | seg = seg.astype('float32') * 255. 
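        # The mask is scaled to [0, 255] here and stacked to three channels just
        # below, so it can be cropped and flipped by the same rgb_processing()
        # pipeline as the image; it is reduced back to a single [0, 1] channel
        # before being returned.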
187 | assert (np.nonzero(seg)[0].shape[0] > 1) 188 | seg = np.dstack([seg, seg, seg]) # (H, W, 3) as float 189 | x0, y0, width, height = a['img_bbox'] 190 | 191 | scaleFactor = 1.2 192 | scale = scaleFactor * max(width, height) / 200 193 | 194 | kp_S24 = np.array(a['joints']) 195 | center = np.array([x0 + width / 2, y0 + height / 2]) # Center of dog 196 | 197 | kp_S24, kp_S24_norm = self.j2d_processing(kp_S24, center, sc * scale, rot, flip) 198 | 199 | kp_S24 = torch.from_numpy(kp_S24) 200 | kp_S24_norm = torch.from_numpy(kp_S24_norm) 201 | img_crop = self.rgb_processing(img, center, sc * scale, rot, flip, pn, border_grey_intensity=255.0) 202 | seg_crop = self.rgb_processing(seg, center, sc * scale, rot, flip, 203 | np.array([1.0, 1.0, 1.0])) # No pixel noise multiplier 204 | 205 | item = {} 206 | 207 | item['has_pose_3d'] = False 208 | item['has_smpl'] = False 209 | item['keypoints_3d'] = np.zeros((24, 4)) 210 | 211 | item['pred_pose'] = np.zeros((105)) 212 | item['pred_shape'] = np.zeros((26)) 213 | item['pred_camera'] = np.zeros((3)) 214 | item['pred_trans'] = np.zeros((3)) 215 | 216 | if self.param_dir is not None: 217 | inp_path = imgname_raw.replace("/", "_").replace(".jpg", ".npz") 218 | if self.dataset == 'animal_pose': 219 | inp_path = "images_{0}".format(inp_path) 220 | 221 | with np.load(os.path.join(self.param_dir, inp_path)) as f: 222 | item['pred_pose'] = f.f.pose 223 | item['pred_shape'] = f.f.betas 224 | item['pred_camera'] = f.f.camera 225 | item['pred_trans'] = f.f.trans 226 | 227 | item['imgname'] = imgname 228 | item['keypoints_norm'] = kp_S24_norm[config.EVAL_KEYPOINTS] 229 | item['keypoints'] = kp_S24[config.EVAL_KEYPOINTS] 230 | item['scale'] = float(sc * scale) 231 | item['center'] = center.astype('float32') 232 | item['index'] = img_idx 233 | 234 | img_crop = torch.from_numpy(img_crop).float() 235 | seg_crop = torch.from_numpy(seg_crop[[0]]).float() # [3, h, w] -> [1, h, w] 236 | 237 | item['img_orig'] = img_crop.clone() 238 | item['img'] = self.normalize_img(img_crop) 239 | item['img_border_mask'] = torch.all(img_crop < 1.0, dim=0).unsqueeze(0).float() 240 | item['seg'] = seg_crop.clone() 241 | item['has_seg'] = True 242 | item['dataset'] = self.dataset 243 | item['filename'] = filename 244 | 245 | return item 246 | 247 | def __len__(self): 248 | return len(self.data_idx) 249 | -------------------------------------------------------------------------------- /util/net_blocks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CNN building blocks. 
3 | Taken from https://github.com/shubhtuls/factored3d/ 4 | ''' 5 | from __future__ import division 6 | from __future__ import print_function 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | 11 | 12 | class Flatten(nn.Module): 13 | def forward(self, x): 14 | return x.view(x.size()[0], -1) 15 | 16 | 17 | class Unsqueeze(nn.Module): 18 | def __init__(self, dim): 19 | super(Unsqueeze, self).__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | return x.unsqueeze(self.dim) 24 | 25 | 26 | ## fc layers 27 | def fc(norm_type, nc_inp, nc_out): 28 | if norm_type == 'batch': 29 | return nn.Sequential( 30 | nn.Linear(nc_inp, nc_out, bias=True), 31 | nn.BatchNorm1d(nc_out), 32 | nn.LeakyReLU(0.2, inplace=True) 33 | ) 34 | else: 35 | return nn.Sequential( 36 | nn.Linear(nc_inp, nc_out), 37 | nn.LeakyReLU(0.1, inplace=True) 38 | ) 39 | 40 | 41 | def fc_stack(nc_inp, nc_out, nlayers, norm_type='batch'): 42 | modules = [] 43 | for l in range(nlayers): 44 | modules.append(fc(norm_type, nc_inp, nc_out)) 45 | nc_inp = nc_out 46 | encoder = nn.Sequential(*modules) 47 | net_init(encoder) 48 | return encoder 49 | 50 | 51 | def fc_stack_dropout(nc_inp, nc_out, nlayers): 52 | modules = [] 53 | modules.append(nn.Linear(nc_inp, 1024, bias=True)) 54 | modules.append(nn.ReLU()) 55 | modules.append(nn.Dropout()) 56 | modules.append(nn.Linear(1024, 1024, bias=True)) 57 | modules.append(nn.ReLU()) 58 | modules.append(nn.Dropout()) 59 | modules.append(nn.Linear(1024, nc_out, bias=True)) 60 | 61 | encoder = nn.Sequential(*modules) 62 | net_init(encoder) 63 | nl = 1 64 | for m in encoder.modules(): 65 | if isinstance(m, nn.Linear): 66 | if nl == nlayers: 67 | torch.nn.init.xavier_normal(m.weight, gain=0.01) 68 | else: 69 | torch.nn.init.xavier_normal(m.weight) 70 | if m.bias is not None: 71 | m.bias.data.zero_() 72 | nl += 1 73 | 74 | return encoder 75 | 76 | 77 | ## 2D convolution layers 78 | def conv2d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2): 79 | if norm_type == 'batch': 80 | return nn.Sequential( 81 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 82 | bias=True), 83 | nn.BatchNorm2d(out_planes), 84 | nn.LeakyReLU(0.2, inplace=True) 85 | ) 86 | elif norm_type == 'group': 87 | return nn.Sequential( 88 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 89 | bias=True), 90 | nn.GroupNorm(num_groups, out_planes), 91 | nn.LeakyReLU(0.2, inplace=True) 92 | ) 93 | else: 94 | return nn.Sequential( 95 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 96 | bias=True), 97 | nn.LeakyReLU(0.2, inplace=True) 98 | ) 99 | 100 | 101 | def deconv2d(in_planes, out_planes): 102 | return nn.Sequential( 103 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 104 | nn.LeakyReLU(0.2, inplace=True) 105 | ) 106 | 107 | 108 | def upconv2d(in_planes, out_planes, mode='bilinear'): 109 | if mode == 'nearest': 110 | print('Using NN upsample!!') 111 | upconv = nn.Sequential( 112 | nn.Upsample(scale_factor=2, mode=mode), 113 | nn.ReflectionPad2d(1), 114 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0), 115 | nn.LeakyReLU(0.2, inplace=True) 116 | ) 117 | return upconv 118 | 119 | 120 | def decoder2d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True, 121 | use_deconv=False, upconv_mode='bilinear', num_groups=2): 122 | 
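    # (Builds a 2D up-convolutional decoder; the docstring below appears to have
    # been copied from the 3D encoder/decoder helpers later in this file.)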
''' Simple 3D encoder with nlayers. 123 | 124 | Args: 125 | nlayers: number of decoder layers 126 | nz_shape: number of bottleneck 127 | nc_input: number of channels to start upconvolution from 128 | use_bn: whether to use batch_norm 129 | nc_final: number of output channels 130 | nc_min: number of min channels 131 | nc_step: double number of channels every nc_step layers 132 | init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D 133 | ''' 134 | modules = [] 135 | if init_fc: 136 | modules.append(fc('batch', nz_shape, nc_input)) 137 | for d in range(3): 138 | modules.append(Unsqueeze(2)) 139 | nc_output = nc_input 140 | for nl in range(nlayers): 141 | if (nl % nc_step == 0) and (nc_output // 2 >= nc_min): 142 | nc_output = nc_output // 2 143 | if use_deconv: 144 | print('Using deconv decoder!') 145 | modules.append(deconv2d(nc_input, nc_output)) 146 | nc_input = nc_output 147 | modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups // 2)) 148 | else: 149 | modules.append(upconv2d(nc_input, nc_output, mode=upconv_mode)) 150 | nc_input = nc_output 151 | modules.append(conv2d(norm_type, nc_input, nc_output, num_groups=num_groups // 2)) 152 | 153 | modules.append(nn.Conv2d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True)) 154 | decoder = nn.Sequential(*modules) 155 | net_init(decoder) 156 | return decoder 157 | 158 | 159 | ## 3D convolution layers 160 | def conv3d(norm_type, in_planes, out_planes, kernel_size=3, stride=1, num_groups=2): 161 | if norm_type == 'batch': 162 | return nn.Sequential( 163 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 164 | bias=True), 165 | nn.BatchNorm3d(out_planes), 166 | nn.LeakyReLU(0.2, inplace=True) 167 | ) 168 | elif norm_type == 'group': 169 | return nn.Sequential( 170 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 171 | bias=True), 172 | nn.GroupNorm(num_groups, out_planes), 173 | nn.LeakyReLU(0.2, inplace=True) 174 | ) 175 | else: 176 | return nn.Sequential( 177 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 178 | bias=True), 179 | nn.LeakyReLU(0.2, inplace=True) 180 | ) 181 | 182 | 183 | def deconv3d(norm_type, in_planes, out_planes, num_groups=2): 184 | if norm_type == 'batch': 185 | return nn.Sequential( 186 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 187 | nn.BatchNorm3d(out_planes), 188 | nn.LeakyReLU(0.2, inplace=True) 189 | ) 190 | elif norm_type == 'group': 191 | return nn.Sequential( 192 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 193 | nn.GroupNorm(num_groups, out_planes), 194 | nn.LeakyReLU(0.2, inplace=True) 195 | ) 196 | else: 197 | return nn.Sequential( 198 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 199 | nn.LeakyReLU(0.2, inplace=True) 200 | ) 201 | 202 | 203 | ## 3D Network Modules 204 | def encoder3d(nlayers, norm_type='batch', nc_input=1, nc_max=128, nc_l1=8, nc_step=1, nz_shape=20): 205 | ''' Simple 3D encoder with nlayers. 
206 | 207 | Args: 208 | nlayers: number of encoder layers 209 | use_bn: whether to use batch_norm 210 | nc_input: number of input channels 211 | nc_max: number of max channels 212 | nc_l1: number of channels in layer 1 213 | nc_step: double number of channels every nc_step layers 214 | nz_shape: size of bottleneck layer 215 | ''' 216 | modules = [] 217 | nc_output = nc_l1 218 | for nl in range(nlayers): 219 | if (nl >= 1) and (nl % nc_step == 0) and (nc_output <= nc_max * 2): 220 | nc_output *= 2 221 | 222 | modules.append(conv3d(norm_type, nc_input, nc_output, stride=1)) 223 | nc_input = nc_output 224 | modules.append(conv3d(norm_type, nc_input, nc_output, stride=1)) 225 | modules.append(torch.nn.MaxPool3d(kernel_size=2, stride=2)) 226 | 227 | modules.append(Flatten()) 228 | modules.append(fc_stack(nc_output, nz_shape, 2, norm_type)) 229 | encoder = nn.Sequential(*modules) 230 | net_init(encoder) 231 | return encoder, nc_output 232 | 233 | 234 | def decoder3d(nlayers, nz_shape, nc_input, norm_type='batch', nc_final=1, nc_min=8, nc_step=1, init_fc=True): 235 | ''' Simple 3D encoder with nlayers. 236 | 237 | Args: 238 | nlayers: number of decoder layers 239 | nz_shape: number of bottleneck 240 | nc_input: number of channels to start upconvolution from 241 | use_bn: whether to use batch_norm 242 | nc_final: number of output channels 243 | nc_min: number of min channels 244 | nc_step: double number of channels every nc_step layers 245 | init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D 246 | ''' 247 | modules = [] 248 | if init_fc: 249 | modules.append(fc('batch', nz_shape, nc_input)) 250 | for d in range(3): 251 | modules.append(Unsqueeze(2)) 252 | nc_output = nc_input 253 | for nl in range(nlayers): 254 | if (nl % nc_step == 0) and (nc_output // 2 >= nc_min): 255 | nc_output = nc_output // 2 256 | 257 | modules.append(deconv3d(norm_type, nc_input, nc_output)) 258 | nc_input = nc_output 259 | modules.append(conv3d(norm_type, nc_input, nc_output)) 260 | 261 | modules.append(nn.Conv3d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True)) 262 | decoder = nn.Sequential(*modules) 263 | net_init(decoder) 264 | return decoder 265 | 266 | 267 | def net_init(net): 268 | for m in net.modules(): 269 | if isinstance(m, nn.Linear): 270 | m.weight.data.normal_(0, 0.02) 271 | if m.bias is not None: 272 | m.bias.data.zero_() 273 | 274 | if isinstance(m, nn.Conv2d): # or isinstance(m, nn.ConvTranspose2d): 275 | m.weight.data.normal_(0, 0.02) 276 | if m.bias is not None: 277 | m.bias.data.zero_() 278 | 279 | if isinstance(m, nn.ConvTranspose2d): 280 | # Initialize Deconv with bilinear weights. 281 | base_weights = bilinear_init(m.weight.data.size(-1)) 282 | base_weights = base_weights.unsqueeze(0).unsqueeze(0) 283 | m.weight.data = base_weights.repeat(m.weight.data.size(0), m.weight.data.size(1), 1, 1) 284 | if m.bias is not None: 285 | m.bias.data.zero_() 286 | 287 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): 288 | m.weight.data.normal_(0, 0.02) 289 | if m.bias is not None: 290 | m.bias.data.zero_() 291 | 292 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d): 293 | m.weight.data.fill_(1) 294 | m.bias.data.zero_() 295 | 296 | 297 | def bilinear_init(kernel_size=4): 298 | # Following Caffe's BilinearUpsamplingFiller 299 | # https://github.com/BVLC/caffe/pull/2213/files 300 | import numpy as np 301 | width = kernel_size 302 | height = kernel_size 303 | f = int(np.ceil(width / 2.)) 304 | cc = (2 * f - 1 - f % 2) / (2. 
* f) 305 | weights = torch.zeros((height, width)) 306 | for y in range(height): 307 | for x in range(width): 308 | weights[y, x] = (1 - np.abs(x / f - cc)) * (1 - np.abs(y / f - cc)) 309 | 310 | return weights 311 | 312 | 313 | if __name__ == '__main__': 314 | decoder2d(5, None, 256, use_deconv=True, init_fc=False) 315 | bilinear_init() 316 | -------------------------------------------------------------------------------- /util/helpers/visualize.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from util.helpers.draw_smal_joints import SMALJointDrawer 4 | import numpy as np 5 | 6 | 7 | # generate visualizations, code adopted from WLDO 8 | class Visualizer(): 9 | 10 | @staticmethod 11 | def generate_output_figures(preds, vis_refine=False): 12 | ''' 13 | Args: 14 | preds: predictions from model 15 | vis_refine: whether return refined visualizations 16 | 17 | Returns: visualization with predicted 2D keypoints and silhouettes 18 | ''' 19 | marker_size = 8 20 | thickness = 4 21 | smal_drawer = SMALJointDrawer() 22 | real_rgb_vis = preds['img_orig'] 23 | keypoints_gt = preds['keypoints'] 24 | real_rgb_vis_kp = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks'], 25 | visible=preds['keypoints'][:, :, [2]], 26 | marker_size=marker_size, 27 | thickness=thickness, normalized=False) 28 | pic_on_pic = preds['img_orig'].cpu() * (1-preds['synth_silhouettes'].cpu()) + \ 29 | preds['synth_xyz'].cpu() * preds['synth_silhouettes'].cpu() 30 | pic_npy = pic_on_pic.data.cpu() 31 | 32 | sil_err = torch.cat([ 33 | preds['seg'].cpu(), preds['synth_silhouettes'].cpu(), torch.zeros_like(preds['seg']).cpu() 34 | ], dim=1) 35 | 36 | sil_err_npy = sil_err.data.cpu() 37 | sil_err_npy = sil_err_npy + (1.0 - preds['img_border_mask'].cpu()) 38 | output_figs = torch.stack( 39 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy],dim = 1) # Batch Size, 4 (Images), RGB, H, W 40 | 41 | if vis_refine: 42 | real_rgb_vis_kp_re = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks_re'], 43 | visible=preds['keypoints'][:, :, [2]], 44 | marker_size=marker_size, thickness=thickness, normalized=False) 45 | 46 | pic_on_pic = preds['img_orig'].cpu() * (1 - preds['synth_silhouettes_re'].cpu()) + \ 47 | preds['synth_xyz_re'].cpu() * preds['synth_silhouettes_re'].cpu() 48 | pic_npy_re = pic_on_pic.data.cpu() 49 | 50 | sil_err = torch.cat([ 51 | preds['seg'].cpu(), preds['synth_silhouettes_re'].cpu(), torch.zeros_like(preds['seg']).cpu() 52 | ], dim=1) 53 | sil_err_npy_re = sil_err.data.cpu() 54 | sil_err_npy_re = sil_err_npy_re + (1.0 - preds['img_border_mask'].cpu()) 55 | 56 | output_figs = torch.stack( 57 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy, pic_npy_re, real_rgb_vis_kp_re, sil_err_npy_re 58 | ], dim=1) 59 | return output_figs 60 | 61 | @staticmethod 62 | def generate_output_figures_v2(preds, vis_refine=False): 63 | # add recovered mesh in an alternative view 64 | marker_size = 8 65 | thickness = 4 66 | 67 | smal_drawer = SMALJointDrawer() 68 | real_rgb_vis = preds['img_orig'] 69 | keypoints_gt = preds['keypoints'] 70 | real_rgb_vis_kp = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks'], 71 | visible=preds['keypoints'][:, :, [2]], 72 | marker_size=marker_size, 73 | thickness=thickness, normalized=False) 74 | pic_on_pic = preds['img_orig'].cpu() * (1-preds['synth_silhouettes'].cpu()) + \ 75 | preds['synth_xyz'].cpu() * preds['synth_silhouettes'].cpu() 76 | pic_npy = pic_on_pic.data.cpu() 77 | 78 | sil_err = 
torch.cat([ 79 | preds['seg'].cpu(), preds['synth_silhouettes'].cpu(), torch.zeros_like(preds['seg']).cpu() 80 | ], dim=1) 81 | 82 | sil_err_npy = sil_err.data.cpu() 83 | sil_err_npy = sil_err_npy + (1.0 - preds['img_border_mask'].cpu()) 84 | output_figs = torch.stack( 85 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy],dim = 1) # Batch Size, 4 (Images), RGB, H, W 86 | 87 | if vis_refine: 88 | real_rgb_vis_kp_re = smal_drawer.draw_joints(preds['img_orig'], preds['synth_landmarks_re'], 89 | visible=preds['keypoints'][:, :, [2]], 90 | marker_size=marker_size, thickness=thickness, normalized=False) 91 | 92 | pic_on_pic = preds['img_orig'].cpu() * (1 - preds['synth_silhouettes_re'].cpu()) + \ 93 | preds['synth_xyz_re'].cpu() * preds['synth_silhouettes_re'].cpu() 94 | pic_npy_re = pic_on_pic.data.cpu() 95 | 96 | sil_err = torch.cat([ 97 | preds['seg'].cpu(), preds['synth_silhouettes_re'].cpu(), torch.zeros_like(preds['seg']).cpu() 98 | ], dim=1) 99 | sil_err_npy_re = sil_err.data.cpu() 100 | sil_err_npy_re = sil_err_npy_re + (1.0 - preds['img_border_mask'].cpu()) 101 | 102 | synth_xyz_re_cano = preds['synth_xyz_re_cano'].cpu() 103 | 104 | output_figs = torch.stack( 105 | [real_rgb_vis, pic_npy, real_rgb_vis_kp, sil_err_npy, pic_npy_re, real_rgb_vis_kp_re, sil_err_npy_re, 106 | synth_xyz_re_cano], dim=1) 107 | return output_figs 108 | 109 | @staticmethod 110 | def generate_demo_output(preds): 111 | """Figure output of: [raw_img, cropped, mesh_view, mesh & raw, silh_view]""" 112 | marker_size = 8 113 | thickness = 4 114 | 115 | smal_drawer = SMALJointDrawer() 116 | 117 | #synth_render_vis = smal_drawer.draw_joints(preds['synth_xyz'], preds['synth_landmarks'], marker_size=marker_size, thickness=thickness, normalized=False) 118 | synth_render_vis = preds['synth_xyz'].cpu() 119 | 120 | # Real image (with bbox overlayed) 121 | real_rgb_vis = preds['img_orig'].cpu() 122 | 123 | #Overlaid bbox 124 | # bbox_overlay = np.zeros(real_rgb_vis.shape).astype(np.float32) 125 | # for n, bbox in enumerate(preds['bbox']): 126 | # (x0, y0, width, height) = list(map(int, bbox.cpu().numpy())) 127 | # bbox_overlay[n, 0, y0:y0+height, x0:x0+width] = 1. 
# red channel overlay 128 | # 129 | # real_rgb_vis += .4 * bbox_overlay # overlay bbox 130 | 131 | # real cropped 132 | real_rgb_vis_cropped = preds['img'].cpu() 133 | 134 | pic_on_pic = preds['img'] * 0.4 + preds['synth_xyz'].cpu() * 0.6 135 | pic_npy = pic_on_pic.data.cpu() 136 | 137 | batch, _, H, W = pic_npy.shape 138 | sil_err = torch.cat([ 139 | preds['synth_silhouettes'], preds['synth_silhouettes'], preds['synth_silhouettes']], dim = 1) 140 | 141 | sil_err_npy = sil_err.data.cpu() 142 | output_figs = torch.stack( 143 | [real_rgb_vis, synth_render_vis, pic_npy, sil_err_npy], dim = 1) # Batch Size, 4 (Images), RGB, H, W 144 | 145 | return output_figs 146 | 147 | @staticmethod 148 | def draw_mesh_plotly( 149 | title, 150 | verts, faces, 151 | up=dict(x=0,y=1,z=0), eye=dict(x=0.0, y=0.0, z=1.0), 152 | hack_box_size = 1.0, 153 | center_mesh = True): 154 | 155 | camera = dict(up=up, center=dict(x=0, y=0, z=0), eye=eye) 156 | scene = dict( 157 | xaxis = dict(nticks=10, range=[-1,1],), 158 | yaxis = dict(nticks=10, range=[-1,1],), 159 | zaxis = dict(nticks=10, range=[-1,1],), 160 | camera = camera) 161 | 162 | centre_of_mass = torch.mean(verts, dim = 0, keepdim=True) 163 | if not center_mesh: 164 | centre_of_mass = torch.zeros_like(centre_of_mass) 165 | 166 | output_verts = (verts - centre_of_mass).data.cpu().numpy() 167 | output_faces = faces.data.cpu().numpy() 168 | 169 | hack_points = np.array([ 170 | [-1.0, -1.0, -1.0], 171 | [-1.0, -1.0, 1.0], 172 | [-1.0, 1.0, -1.0], 173 | [-1.0, 1.0, 1.0], 174 | [1.0, -1.0, -1.0], 175 | [1.0, -1.0, 1.0], 176 | [1.0, 1.0, -1.0], 177 | [1.0, 1.0, 1.0]]) * hack_box_size 178 | 179 | vis_fig = go.Figure(data = [ 180 | go.Mesh3d( 181 | x = output_verts[:, 0], 182 | y = output_verts[:, 1], 183 | z = -1 * output_verts[:, 2], 184 | i = output_faces[:, 0], 185 | j = output_faces[:, 1], 186 | k = output_faces[:, 2], 187 | color='cornflowerblue' 188 | ), 189 | go.Scatter3d( 190 | x = hack_points[:, 0], 191 | y = hack_points[:, 1], 192 | z = hack_points[:, 2], 193 | mode='markers', 194 | name='_fake_pts', 195 | visible=True, 196 | marker=dict( 197 | size=1, 198 | opacity = 0, 199 | color=(0.0, 0.0, 0.0), 200 | ) 201 | )]) 202 | 203 | vis_fig.update_scenes(patch = scene) 204 | vis_fig.update_layout(title=title) 205 | return vis_fig 206 | 207 | def draw_double_mesh_plotly( 208 | viz, 209 | title, 210 | verts, faces, 211 | verts2, 212 | joints_3d, 213 | gt_joints_3d, 214 | visdom_env_imgs, 215 | up=dict(x=0,y=1,z=0), eye=dict(x=0.0, y=0.0, z=1.0), 216 | hack_box_size = 1.0, 217 | center_mesh = True): 218 | 219 | camera = dict(up=up, center=dict(x=0, y=0, z=0), eye=eye) 220 | scene = dict( 221 | xaxis = dict(nticks=10, range=[-1,1],), 222 | yaxis = dict(nticks=10, range=[-1,1],), 223 | zaxis = dict(nticks=10, range=[-1,1],), 224 | camera = camera) 225 | 226 | centre_of_mass = torch.mean(verts, dim = 0, keepdim=True) 227 | if not center_mesh: 228 | centre_of_mass = torch.zeros_like(centre_of_mass) 229 | 230 | output_joints_3d = (joints_3d - centre_of_mass).data.cpu().numpy() 231 | output_verts = (verts - centre_of_mass).data.cpu().numpy() 232 | output_verts2 = (verts2 - centre_of_mass).data.cpu().numpy() 233 | output_joints_3d2 = (gt_joints_3d - centre_of_mass).data.cpu().numpy() 234 | output_faces = faces.data.cpu().numpy() 235 | 236 | hack_points = np.array([ 237 | [-1.0, -1.0, -1.0], 238 | [-1.0, -1.0, 1.0], 239 | [-1.0, 1.0, -1.0], 240 | [-1.0, 1.0, 1.0], 241 | [1.0, -1.0, -1.0], 242 | [1.0, -1.0, 1.0], 243 | [1.0, 1.0, -1.0], 244 | [1.0, 1.0, 1.0]]) * 
hack_box_size 245 | 246 | vis_fig = go.Figure(data = [ 247 | go.Mesh3d( 248 | x = output_verts[:, 0], 249 | y = output_verts[:, 1], 250 | z = -1 * output_verts[:, 2], 251 | i = output_faces[:, 0], 252 | j = output_faces[:, 1], 253 | k = output_faces[:, 2], 254 | color='cornflowerblue', 255 | opacity=0.5, 256 | ), 257 | go.Scatter3d( 258 | x = output_joints_3d[:, 0], 259 | y = output_joints_3d[:, 1], 260 | z = -1 * output_joints_3d[:, 2], 261 | mode='markers', 262 | marker = dict( 263 | size=8, 264 | color='red' 265 | ) 266 | ), 267 | go.Mesh3d( 268 | x = output_verts2[:, 0], 269 | y = output_verts2[:, 1], 270 | z = -1 * output_verts2[:, 2], 271 | i = output_faces[:, 0], 272 | j = output_faces[:, 1], 273 | k = output_faces[:, 2], 274 | color='green', 275 | opacity=0.5, 276 | ), 277 | go.Scatter3d( 278 | x = output_joints_3d2[:, 0], 279 | y = output_joints_3d2[:, 1], 280 | z = -1 * output_joints_3d2[:, 2], 281 | mode='markers', 282 | marker = dict( 283 | size=8, 284 | color='purple' 285 | ) 286 | ), 287 | go.Scatter3d( 288 | x = hack_points[:, 0], 289 | y = hack_points[:, 1], 290 | z = hack_points[:, 2], 291 | mode='markers', 292 | name='_fake_pts', 293 | visible=True, 294 | marker=dict( 295 | size=1, 296 | opacity = 0, 297 | color=(0.0, 0.0, 0.0), 298 | ) 299 | )]) 300 | 301 | vis_fig.update_scenes(patch = scene) 302 | vis_fig.update_layout(title=title) 303 | return vis_fig -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import os 4 | import sys 5 | import time 6 | import numpy as np 7 | from tqdm import tqdm 8 | import cv2 9 | import argparse 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim 13 | from torch.utils.data import DataLoader 14 | from model.mesh_graph_hg import MeshGraph_hg 15 | from util import config 16 | from util.helpers.visualize import Visualizer 17 | from util.metrics import Metrics 18 | from datasets.stanford import BaseDataset 19 | from scipy.spatial.transform import Rotation as R 20 | 21 | def main(args): 22 | 23 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | if not os.path.exists(args.output_dir): 27 | os.mkdir(args.output_dir) 28 | # set model 29 | model = MeshGraph_hg(device, args.shape_family_id, args.num_channels, args.num_layers, args.betas_scale, 30 | args.shape_init, args.local_feat, num_downsampling=args.num_downsampling, 31 | render_rgb=args.save_results) 32 | model = nn.DataParallel(model).to(device) 33 | # set data 34 | print("Evaluate on {} dataset".format(args.dataset)) 35 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False) 36 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works) 37 | # set optimizer 38 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 39 | 40 | if os.path.isfile(args.resume): 41 | print("=> loading checkpoint {}".format(args.resume)) 42 | checkpoint = torch.load(args.resume) 43 | model.load_state_dict(checkpoint['state_dict']) 44 | if args.load_optimizer: 45 | optimizer.load_state_dict(checkpoint['optimizer']) 46 | args.start_epoch = checkpoint['epoch'] + 1 47 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch'])) 48 | else: 49 | print("No checkpoint found") 50 | 51 | 
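    # run_evaluation returns mean PCK / silhouette IoU for both the coarse and
    # refinement stages, plus a per-part PCK dict keyed by config.KEYPOINT_GROUPS.
    # Illustrative invocation (flag values are only an example; every flag shown
    # is defined in the argparse section at the bottom of this file):
    #   python eval.py --dataset stanford --resume logs/model_best.pth.tar \
    #       --output_dir ./logs/eval --batch_size 32 --betas_scale --local_feat \
    #       --save_results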
pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args) 52 | print("Evaluate only, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}" 53 | .format(pck, iou_silh, pck_re, iou_re)) 54 | return 55 | 56 | 57 | def run_evaluation(model, dataset, data_loader, device, args): 58 | 59 | model.eval() 60 | result_dir = args.output_dir 61 | batch_size = args.batch_size 62 | 63 | pck = np.zeros((len(dataset))) 64 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS} 65 | pck_by_part_re = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS} 66 | acc_sil_2d = np.zeros(len(dataset)) 67 | 68 | pck_re = np.zeros((len(dataset))) 69 | acc_sil_2d_re = np.zeros(len(dataset)) 70 | 71 | smal_pose = np.zeros((len(dataset), 105)) 72 | smal_betas = np.zeros((len(dataset), 20)) 73 | smal_camera = np.zeros((len(dataset), 3)) 74 | smal_imgname = [] 75 | # rotate estimated mesh to visualize in an alternative view 76 | rot_matrix = torch.from_numpy(R.from_euler('y', -90, degrees=True).as_dcm()).float().to(device) 77 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader)) 78 | 79 | for step, batch in enumerate(tqdm_iterator): 80 | with torch.no_grad(): 81 | preds = {} 82 | 83 | keypoints = batch['keypoints'].to(device) 84 | keypoints_norm = batch['keypoints_norm'].to(device) 85 | seg = batch['seg'].to(device) 86 | has_seg = batch['has_seg'] 87 | img = batch['img'].to(device) 88 | img_border_mask = batch['img_border_mask'].to(device) 89 | verts, joints, shape, pred_codes = model(img) 90 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 91 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2], 92 | dim=1) 93 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3) 94 | labelled_joints_3d = joints[:, config.MODEL_JOINTS] 95 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera) 96 | synth_silhouettes = synth_silhouettes.unsqueeze(1) 97 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera) 98 | 99 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred, 100 | del_v=shape, 101 | betas_logscale=betas_scale_pred) 102 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS] 103 | synth_rgb_refine, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera) 104 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1) 105 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine, 106 | pred_camera) 107 | 108 | if args.save_results: 109 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0) 110 | synth_rgb_refine = torch.clamp(synth_rgb_refine[0], 0.0, 1.0) 111 | # visualize in another view 112 | verts_refine_cano = verts_refine - torch.mean(verts_refine, dim=1, keepdim=True) 113 | verts_refine_cano = (rot_matrix @ verts_refine_cano.unsqueeze(-1)).squeeze(-1) 114 | # increase the depth such that the rendered the shapes are in within the image 115 | verts_refine_cano[:, :, 2] = verts_refine_cano[:, :, 2] + 15 116 | synth_rgb_refine_cano, _ = model.module.model_renderer(verts_refine_cano, faces, 117 | pred_camera) 118 | synth_rgb_refine_cano = torch.clamp(synth_rgb_refine_cano[0], 0.0, 1.0) 119 | preds['synth_xyz_re_cano'] = synth_rgb_refine_cano 120 | 121 | preds['pose'] = pose_pred 122 | preds['betas'] 
= betas_pred 123 | preds['camera'] = pred_camera 124 | preds['trans'] = trans_pred 125 | 126 | preds['verts'] = verts 127 | preds['joints_3d'] = labelled_joints_3d 128 | preds['faces'] = faces 129 | 130 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg) 131 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg) 132 | 133 | preds['acc_PCK_re'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg) 134 | preds['acc_IOU_re'] = Metrics.IOU(synth_silhouettes_refine, seg, img_border_mask, mask=has_seg) 135 | 136 | for group, group_kps in config.KEYPOINT_GROUPS.items(): 137 | preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg, 138 | thresh_range=[0.15], 139 | idxs=group_kps) 140 | preds[f'{group}_PCK_RE'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg, 141 | thresh_range=[0.15], 142 | idxs=group_kps) 143 | preds['synth_xyz'] = synth_rgb 144 | preds['synth_silhouettes'] = synth_silhouettes 145 | preds['synth_landmarks'] = synth_landmarks 146 | preds['synth_xyz_re'] = synth_rgb_refine 147 | preds['synth_landmarks_re'] = synth_landmarks_refine 148 | preds['synth_silhouettes_re'] = synth_silhouettes_refine 149 | 150 | assert not any(k in preds for k in batch.keys()) 151 | preds.update(batch) 152 | 153 | curr_batch_size = preds['synth_landmarks'].shape[0] 154 | # compute accuracy for coarse stage 155 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() 156 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() 157 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy() 158 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy() 159 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy() 160 | # compute accuracy for refinement stage 161 | pck_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK_re'].data.cpu().numpy() 162 | acc_sil_2d_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU_re'].data.cpu().numpy() 163 | for part in pck_by_part: 164 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() 165 | pck_by_part_re[part][step * batch_size:step * batch_size + curr_batch_size] = preds[ 166 | f'{part}_PCK_RE'].data.cpu().numpy() 167 | 168 | if args.save_results: 169 | output_figs = np.transpose( 170 | Visualizer.generate_output_figures_v2(preds, vis_refine=True).data.cpu().numpy(), 171 | (0, 1, 3, 4, 2)) 172 | for img_id in range(len(preds['imgname'])): 173 | imgname = preds['imgname'][img_id] 174 | output_fig_list = output_figs[img_id] 175 | 176 | path_parts = imgname.split('/') 177 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1]) 178 | img_file = os.path.join(result_dir, path_suffix) 179 | output_fig = np.hstack(output_fig_list) 180 | smal_imgname.append(path_suffix) 181 | npz_file = "{0}.npz".format(os.path.splitext(img_file)[0]) 182 | 183 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0) 184 | # np.savez_compressed(npz_file, 185 | # imgname=preds['imgname'][img_id], 186 | # pose=preds['pose'][img_id].data.cpu().numpy(), 187 | # betas=preds['betas'][img_id].data.cpu().numpy(), 188 | # camera=preds['camera'][img_id].data.cpu().numpy(), 189 | # 
trans=preds['trans'][img_id].data.cpu().numpy(), 190 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(), 191 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(), 192 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part} 193 | # ) 194 | report = f"""*** Final Results *** 195 | 196 | SIL IOU 2D: {np.nanmean(acc_sil_2d):.5f} 197 | PCK 2D: {np.nanmean(pck):.5f} 198 | 199 | SIL IOU 2D REFINE: {np.nanmean(acc_sil_2d_re):.5f} 200 | PCK 2D REFINE: {np.nanmean(pck_re):.5f}""" 201 | 202 | for part in pck_by_part: 203 | report += f'\n {part} PCK 2D: {np.nanmean(pck_by_part[part]):.5f}' 204 | 205 | for part in pck_by_part: 206 | report += f'\n {part} PCK 2D RE: {np.nanmean(pck_by_part_re[part]):.5f}' 207 | print(report) 208 | 209 | # save report to file 210 | with open(os.path.join(result_dir, '{}_report.txt'.format(args.dataset)), 'w') as outfile: 211 | print(report, file=outfile) 212 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part, np.nanmean(pck_re), np.nanmean(acc_sil_2d_re) 213 | 214 | 215 | if __name__ == '__main__': 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument('--lr', default=0.0001, type=float) 218 | parser.add_argument('--output_dir', default='./logs/', type=str) 219 | parser.add_argument('--batch_size', default=32, type=int) 220 | parser.add_argument('--num_works', default=4, type=int) 221 | parser.add_argument('--gpu_ids', default='0', type=str) 222 | parser.add_argument('--resume', default=None, type=str) 223 | parser.add_argument('--load_optimizer', action='store_true') 224 | parser.add_argument('--shape_family_id', default=1, type=int) 225 | parser.add_argument('--dataset', default='stanford', type=str) 226 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load') 227 | parser.add_argument('--save_results', action='store_true') 228 | parser.add_argument('--prior_betas', default='smal', type=str) 229 | parser.add_argument('--prior_pose', default='smal', type=str) 230 | parser.add_argument('--betas_scale', action='store_true') 231 | parser.add_argument('--num_channels', type=int, default=256, help='Number of channels in Graph Residual layers') 232 | parser.add_argument('--num_layers', type=int, default=5, help='Number of residuals blocks in the Graph CNN') 233 | parser.add_argument('--local_feat', action='store_true') 234 | parser.add_argument('--shape_init', default='smal', help='enable to initiate shape with mean shape') 235 | parser.add_argument('--num_downsampling', default=1, type=int) 236 | 237 | args = parser.parse_args() 238 | main(args) -------------------------------------------------------------------------------- /model/smal_mesh_net_img.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mesh net model. 
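Predicts SMAL shape, pose, translation and camera codes from an input image
(see CodePredictor below).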
3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | import torch.nn as nn 14 | from util import net_blocks as nb 15 | 16 | 17 | 18 | # ------------- Modules ------------# 19 | # ----------------------------------# 20 | class ResNetConv(nn.Module): 21 | def __init__(self, n_blocks=4): 22 | super(ResNetConv, self).__init__() 23 | # if opts.use_resnet50: 24 | self.resnet = torchvision.models.resnet50(pretrained=True) 25 | # else: 26 | # self.resnet = torchvision.models.resnet18(pretrained=True) 27 | self.n_blocks = n_blocks 28 | # if self.opts.use_double_input: 29 | # self.fc = nb.fc_stack(512*16*8, 512*8*8, 2) 30 | 31 | def forward(self, x, y=None): 32 | # if self.opts.use_double_input and y is not None: 33 | # x = torch.cat([x, y], 2) 34 | img_feat_multiscale = [] # collect multi-scale features for local feature retrieve 35 | n_blocks = self.n_blocks 36 | x = self.resnet.conv1(x) 37 | x = self.resnet.bn1(x) 38 | x = self.resnet.relu(x) 39 | x = self.resnet.maxpool(x) 40 | 41 | if n_blocks >= 1: 42 | x = self.resnet.layer1(x) 43 | img_feat_multiscale.append(x) 44 | if n_blocks >= 2: 45 | x = self.resnet.layer2(x) 46 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x)) 47 | if n_blocks >= 3: 48 | x = self.resnet.layer3(x) 49 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x)) 50 | if n_blocks >= 4: 51 | x = self.resnet.layer4(x) 52 | img_feat_multiscale.append(torch.nn.Upsample(size=(56, 56), mode='bilinear')(x)) 53 | # if self.opts.use_double_input and y is not None: 54 | # x = x.view(x.size(0), -1) 55 | # x = self.fc.forward(x) 56 | # x = x.view(x.size(0), 512, 8, 8) 57 | 58 | return x, torch.cat(img_feat_multiscale, dim=1) 59 | 60 | 61 | class Encoder(nn.Module): 62 | """ 63 | Current: 64 | Resnet with 4 blocks (x32 spatial dim reduction) 65 | Another conv with stride 2 (x64) 66 | This is sent to 2 fc layers with final output nz_feat. 
67 | """ 68 | 69 | def __init__(self, input_shape, channels_per_group=16, n_blocks=4, nz_feat=100, bott_size=256): 70 | super(Encoder, self).__init__() 71 | self.resnet_conv = ResNetConv(n_blocks=4) 72 | num_norm_groups = bott_size // channels_per_group 73 | # if opts.use_resnet50: 74 | self.enc_conv1 = nb.conv2d('group', 2048, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups) 75 | # else: 76 | # self.enc_conv1 = nb.conv2d('group', 512, bott_size, stride=2, kernel_size=4, num_groups=num_norm_groups) 77 | 78 | nc_input = bott_size * (input_shape[0] // 64) * (input_shape[1] // 64) 79 | self.enc_fc = nb.fc_stack(nc_input, nz_feat, 2, 'batch') 80 | self.nenc_feat = nc_input 81 | 82 | nb.net_init(self.enc_conv1) 83 | self.avgpool = nn.AvgPool2d(7, stride=1) 84 | 85 | def forward(self, img, fg_img): 86 | resnet_feat, feat_multiscale = self.resnet_conv.forward(img, fg_img) # multi-scale feature is used to extract local feature for refinement 87 | out_enc_conv1 = self.enc_conv1(resnet_feat) # feature for predicting SMAL parameters 88 | out_resnet = self.avgpool(resnet_feat) # add an pooling layer to get global feature for mesh refinement 89 | out_enc_conv1 = out_enc_conv1.view(img.size(0), -1) 90 | feat = self.enc_fc.forward(out_enc_conv1) 91 | return feat, out_enc_conv1, out_resnet, feat_multiscale 92 | 93 | 94 | class ShapePredictor(nn.Module): 95 | """ 96 | Outputs mesh deformations 97 | """ 98 | 99 | def __init__(self, nz_feat, num_verts, left_idx, right_idx, shapedirs, use_delta_v=False, use_sym_idx=False, 100 | use_smal_betas=False, n_shape_feat=40): 101 | super(ShapePredictor, self).__init__() 102 | self.use_delta_v = use_delta_v 103 | self.use_sym_idx = use_sym_idx 104 | self.use_smal_betas = use_smal_betas 105 | self.ref_delta_v = torch.Tensor(np.zeros((num_verts, 3))).cuda() 106 | 107 | def forward(self, feat): 108 | if self.use_sym_idx: 109 | batch_size = feat.shape[0] 110 | delta_v = torch.Tensor(np.zeros((batch_size, self.num_verts, 3))).cuda() 111 | feat = self.fc(feat) 112 | self.shape_f = feat 113 | 114 | half_delta_v = self.pred_layer.forward(feat) 115 | half_delta_v = half_delta_v.view(half_delta_v.size(0), -1, 3) 116 | delta_v[:, self.left_idx, :] = half_delta_v 117 | half_delta_v[:, :, 1] = -1. 
* half_delta_v[:, :, 1] 118 | delta_v[:, self.right_idx, :] = half_delta_v 119 | else: 120 | delta_v = self.pred_layer.forward(feat) 121 | # Make it B x num_verts x 3 122 | delta_v = delta_v.view(delta_v.size(0), -1, 3) 123 | # print('shape: ( Mean = {}, Var = {} )'.format(delta_v.mean().data[0], delta_v.var().data[0])) 124 | return delta_v 125 | 126 | 127 | class PosePredictor(nn.Module): 128 | """ 129 | """ 130 | 131 | def __init__(self, nz_feat, num_joints=35): 132 | super(PosePredictor, self).__init__() 133 | self.pose_var = 1.0 134 | self.num_joints = num_joints 135 | self.pred_layer = nn.Linear(nz_feat, num_joints * 3) 136 | # bjb_edit 137 | self.pred_layer.weight.data.normal_(0, 1e-4) 138 | self.pred_layer.bias.data.normal_(0, 1e-4) 139 | 140 | def forward(self, feat): 141 | pose = self.pose_var * self.pred_layer.forward(feat) 142 | 143 | # Add this to have zero to correspond to frontal facing 144 | # edit by lic, frontal facing and upright 145 | pose[:, 0] += -1.20919958 146 | pose[:, 1] += 1.20919958 147 | pose[:, 2] += 1.20919958 148 | return pose 149 | 150 | 151 | class BetaScalePredictor(nn.Module): 152 | def __init__(self, nz_feat, nenc_feat, num_beta_scale=6, model_mean=None): 153 | super(BetaScalePredictor, self).__init__() 154 | self.model_mean = model_mean 155 | self.pred_layer = nn.Linear(nenc_feat, num_beta_scale) 156 | # bjb_edit 157 | self.pred_layer.weight.data.normal_(0, 1e-4) 158 | if model_mean is not None: 159 | self.pred_layer.bias.data = model_mean + torch.randn_like(model_mean) * 1e-4 160 | else: 161 | self.pred_layer.bias.data.normal_(0, 1e-4) 162 | 163 | def forward(self, feat, enc_feat): 164 | betas = self.pred_layer.forward(enc_feat) 165 | 166 | return betas 167 | 168 | 169 | class BetasPredictor(nn.Module): 170 | def __init__(self, nz_feat, nenc_feat, num_betas=20, model_mean=None): 171 | super(BetasPredictor, self).__init__() 172 | self.model_mean = model_mean 173 | self.pred_layer = nn.Linear(nenc_feat, num_betas) 174 | # bjb_edit 175 | self.pred_layer.weight.data.normal_(0, 1e-4) 176 | if model_mean is not None: 177 | self.pred_layer.bias.data = model_mean + torch.randn_like(model_mean) * 1e-4 178 | else: 179 | self.pred_layer.bias.data.normal_(0, 1e-4) 180 | 181 | def forward(self, feat, enc_feat): 182 | betas = self.pred_layer.forward(enc_feat) 183 | return betas 184 | 185 | 186 | class ScalePredictor(nn.Module): 187 | ''' 188 | In case of perspective projection scale is focal length 189 | ''' 190 | 191 | def __init__(self, nz, norm_f0, use_camera=True, scale_bias=1): 192 | super(ScalePredictor, self).__init__() 193 | self.use_camera = use_camera 194 | self.norm_f0 = norm_f0 195 | if self.use_camera: 196 | self.pred_layer = nn.Linear(nz, scale_bias) 197 | # else: 198 | # scale = np.zeros((opts.batch_size,1)) 199 | # scale[:,0] = 0. 200 | # self.ref_camera = torch.Tensor(scale).cuda() 201 | 202 | def forward(self, feat): 203 | if not self.use_camera: 204 | scale = np.zeros((feat.shape[0], 1)) 205 | scale[:, 0] = 0. 206 | return torch.Tensor(scale).cuda() 207 | if self.norm_f0 != 0: 208 | off = 0. 209 | else: 210 | off = 1. 
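        # The +1 offset is applied only when no reference focal normalisation
        # (norm_f0) is provided, so a zero network output maps to a scale of 1
        # rather than 0.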
211 | scale = self.pred_layer.forward(feat) + off 212 | return scale 213 | 214 | 215 | class TransPredictor(nn.Module): 216 | """ 217 | Outputs [tx, ty] or [tx, ty, tz] 218 | """ 219 | 220 | def __init__(self, nz, projection_type, fix_trans=False): 221 | super(TransPredictor, self).__init__() 222 | self.fix_trans = fix_trans 223 | if projection_type == 'orth': 224 | self.pred_layer = nn.Linear(nz, 2) 225 | elif projection_type == 'perspective': 226 | self.pred_layer_xy = nn.Linear(nz, 2) 227 | self.pred_layer_z = nn.Linear(nz, 1) 228 | self.pred_layer_xy.weight.data.normal_(0, 0.0001) 229 | self.pred_layer_xy.bias.data.normal_(0, 0.0001) 230 | self.pred_layer_z.weight.data.normal_(0, 0.0001) 231 | self.pred_layer_z.bias.data.normal_(0, 0.0001) 232 | else: 233 | print('Unknown projection type') 234 | 235 | def forward(self, feat): 236 | trans = torch.Tensor(np.zeros((feat.shape[0], 3))).cuda() 237 | f = torch.Tensor(np.zeros((feat.shape[0], 1))).cuda() 238 | feat_xy = feat 239 | feat_z = feat 240 | 241 | trans[:, :2] = self.pred_layer_xy(feat_xy) 242 | trans[:, 2] = 1 + self.pred_layer_z(feat_z)[:, 0] 243 | 244 | if self.fix_trans: 245 | trans[:, 2] = 1. 246 | 247 | # print('trans: ( Mean = {}, Var = {} )'.format(trans.mean().data[0], trans.var().data[0])) 248 | return trans 249 | 250 | 251 | class CodePredictor(nn.Module): 252 | def __init__( 253 | self, norm_f0, nz_feat=100, nenc_feat=2048, 254 | use_smal_betas=True, 255 | num_betas=27, # bjb_edit 256 | use_camera=True, scale_bias=1, 257 | fix_trans=False, betas_scale=False, 258 | use_smal_pose=True, 259 | shape_init=None): 260 | 261 | super(CodePredictor, self).__init__() 262 | self.use_smal_pose = use_smal_pose 263 | self.use_smal_betas = use_smal_betas 264 | self.use_camera = use_camera 265 | self.betas_scale = betas_scale 266 | self.scale_predictor = ScalePredictor( 267 | nz_feat, norm_f0, use_camera=use_camera, scale_bias=scale_bias) 268 | self.trans_predictor = TransPredictor( 269 | nz_feat, 'perspective', fix_trans=fix_trans) 270 | 271 | if self.use_smal_pose: 272 | self.pose_predictor = PosePredictor(nz_feat) 273 | 274 | if self.use_smal_betas: 275 | scale_init = None 276 | shape_betas_init = None 277 | if shape_init is not None: 278 | shape_betas_init = shape_init[:20] 279 | if shape_init.shape[0] == 26: 280 | scale_init = shape_init[20:] 281 | 282 | self.betas_predictor = BetasPredictor( 283 | nz_feat, nenc_feat, num_betas=20, model_mean=shape_betas_init) 284 | if self.betas_scale: 285 | self.betas_scale_predictor = BetaScalePredictor( 286 | nz_feat, nenc_feat, num_beta_scale=6, model_mean=scale_init) 287 | 288 | def forward(self, feat, enc_feat): 289 | if self.use_camera: 290 | scale_pred = self.scale_predictor.forward(feat) 291 | else: 292 | scale_pred = self.scale_predictor.ref_camera 293 | 294 | trans_pred = self.trans_predictor.forward(feat) 295 | 296 | if self.use_smal_pose: 297 | pose_pred = self.pose_predictor.forward(feat) 298 | else: 299 | pose_pred = None 300 | 301 | if self.use_smal_betas: 302 | betas_pred = self.betas_predictor.forward(feat, enc_feat) 303 | if self.betas_scale: 304 | betas_scale_pred = self.betas_scale_predictor.forward(feat, enc_feat)[:, 305 | :6] # choose first 6 for backward compat 306 | else: 307 | betas_scale_pred = None 308 | else: 309 | betas_pred = None 310 | betas_scale_pred = None 311 | 312 | return scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred 313 | 314 | 315 | # ------------ Mesh Net ------------# 316 | # ----------------------------------# 317 | class 
MeshNet_img(nn.Module): 318 | def __init__(self, 319 | input_shape, betas_scale=False, 320 | norm_f0=2700., nz_feat=100, 321 | shape_init=None, return_feat=False): 322 | # Input shape is H x W of the image. 323 | super(MeshNet_img, self).__init__() 324 | 325 | self.bottleneck_size = 2048 326 | self.channels_per_group = 16 327 | self.shape_init = shape_init 328 | 329 | self.encoder = Encoder( 330 | input_shape, 331 | channels_per_group=self.channels_per_group, 332 | n_blocks=4, nz_feat=nz_feat, bott_size=self.bottleneck_size) 333 | 334 | self.code_predictor = CodePredictor( 335 | norm_f0, nz_feat=nz_feat, nenc_feat=self.encoder.nenc_feat, betas_scale=betas_scale, 336 | use_smal_betas=True, shape_init=self.shape_init) 337 | self.return_feat = return_feat 338 | 339 | def forward(self, img, masks=None, is_optimization=False, is_var_opt=False): 340 | img_feat, enc_feat, feat_resnet, feat_multiscale = self.encoder.forward(img, masks) 341 | codes_pred = self.code_predictor.forward(img_feat, enc_feat) 342 | if self.return_feat: 343 | return codes_pred, feat_resnet, feat_multiscale 344 | else: 345 | return codes_pred 346 | 347 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import cv2 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from model.model_v1 import MeshModel 14 | from util import config 15 | from util.helpers.visualize import Visualizer 16 | from util.loss_utils import kp_l2_loss, Shape_prior, mask_loss 17 | from util.metrics import Metrics 18 | from datasets.stanford import BaseDataset 19 | from util.logger import Logger 20 | from util.meter import AverageMeterSet 21 | from util.misc import save_checkpoint, adjust_learning_rate_exponential 22 | from util.pose_prior import Prior 23 | from util.joint_limits_prior import LimitPrior 24 | import pickle 25 | 26 | # Set some global varibles 27 | global_step = 0 28 | best_pck = 0 29 | best_pck_epoch = 0 30 | 31 | 32 | def main(args): 33 | global best_pck 34 | global best_pck_epoch 35 | global global_step 36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids 37 | print("RESULTS: {0}".format(args.output_dir)) 38 | if not os.path.exists(args.output_dir): 39 | os.mkdir(args.output_dir) 40 | # set up device 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | # set up model 43 | model = MeshModel(device, args.shape_family_id, args.betas_scale, args.shape_init, render_rgb=args.save_results) 44 | model = nn.DataParallel(model).to(device) 45 | # set up datasets 46 | dataset_train = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=True, use_augmentation=True) 47 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False) 48 | data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works) 49 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works) 50 | 51 | # set up optimizer 52 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 53 | 54 | writer = SummaryWriter(os.path.join(args.output_dir, 'train')) 55 | 56 | # set up criterion 57 | 
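    # The objective combines a 2D keypoint reprojection loss with optional shape
    # and pose priors; each prior can come either from SMAL or from the
    # Unity-derived priors of WLDO, selected through args.prior_betas and
    # args.prior_pose (see the loss terms in the training loop below).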
joint_limit_prior = LimitPrior(device) 58 | shape_prior = Shape_prior(args.prior_betas, args.shape_family_id, device) 59 | if args.resume: 60 | if os.path.isfile(args.resume): 61 | print("=> loading checkpoint {}".format(args.resume)) 62 | checkpoint = torch.load(args.resume) 63 | model.load_state_dict(checkpoint['state_dict']) 64 | if args.load_optimizer: 65 | optimizer.load_state_dict(checkpoint['optimizer']) 66 | args.start_epoch = checkpoint['epoch'] + 1 67 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch'])) 68 | logger = Logger(os.path.join(args.output_dir, 'log.txt')) 69 | logger.log_arguments(args) 70 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU']) 71 | else: 72 | print("=> no checkpoint found at {}".format(args.resume)) 73 | else: 74 | logger = Logger(os.path.join(args.output_dir, 'log.txt')) 75 | logger.log_arguments(args) 76 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU']) 77 | 78 | if args.evaluate: 79 | pck, iou_silh, pck_by_part = run_evaluation(model, dataset_eval, data_loader_eval, device, args) 80 | print("Evaluate only, PCK: {}, IOU: {}".format(pck, iou_silh)) 81 | return 82 | 83 | lr = args.lr 84 | for epoch in range(args.start_epoch, args.nEpochs): 85 | 86 | model.train() 87 | tqdm_iterator = tqdm(data_loader_train, desc='Train', total=len(data_loader_train)) 88 | meters = AverageMeterSet() 89 | 90 | for step, batch in enumerate(tqdm_iterator): 91 | keypoints = batch['keypoints'].to(device) 92 | seg = batch['seg'].to(device) 93 | img = batch['img'].to(device) 94 | 95 | pred_codes = model(img) 96 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 97 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2], 98 | dim=1) 99 | # recover 3D mesh from SMAL parameters 100 | verts, joints, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred, 101 | betas_logscale=betas_scale_pred) 102 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3) 103 | labelled_joints_3d = joints[:, config.MODEL_JOINTS] 104 | # project 3D joints onto 2D space and apply 2D keypoints supervision 105 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera) 106 | loss_kpts = args.w_kpts * kp_l2_loss(synth_landmarks, keypoints[:, :, [1, 0, 2]], config.NUM_JOINTS) 107 | meters.update('loss_kpt', loss_kpts.item()) 108 | loss = loss_kpts 109 | 110 | # apply shape prior constraint, either come from SMAL or unity from WLDO 111 | if args.w_betas_prior > 0: 112 | if args.prior_betas == 'smal': 113 | s_prior = args.w_betas_prior * shape_prior(betas_pred) 114 | elif args.prior_betas == 'unity': 115 | betas_pred = torch.cat([betas_pred, betas_scale_pred], dim=1) 116 | s_prior = args.w_betas_prior * shape_prior(betas_pred) 117 | else: 118 | Exception("Shape prior should come from either smal or unity") 119 | s_prior = 0 120 | meters.update('loss_prior', s_prior.item()) 121 | loss += s_prior 122 | 123 | # apply pose prior constraint, either come from SMAL or unity from WLDO 124 | if args.w_pose_prior > 0: 125 | if args.prior_pose == 'smal': 126 | pose_prior_path = config.WALKING_PRIOR_FILE 127 | elif args.prior_pose == 'unity': 128 | pose_prior_path = config.UNITY_POSE_PRIOR 129 | else: 130 | Exception('The prior should come from either smal or unity') 131 | pose_prior_path = None 132 | pose_prior = Prior(pose_prior_path, device) 133 | p_prior = args.w_pose_prior * pose_prior(pose_pred) 134 | meters.update('pose_prior', 
p_prior.item()) 135 | loss += p_prior 136 | 137 | meters.update('loss_all', loss.item()) 138 | optimizer.zero_grad() 139 | loss.backward() 140 | optimizer.step() 141 | global_step += 1 142 | if step % 20 == 0: 143 | loss_values = meters.averages() 144 | for name, meter in loss_values.items(): 145 | writer.add_scalar(name, meter, global_step) 146 | writer.flush() 147 | 148 | pck, iou_silh, pck_by_part = run_evaluation(model, dataset_eval, data_loader_eval, device, args) 149 | 150 | print("Epoch: {}, LR: {}, PCK: {}, IOU: {}".format(epoch, lr, pck, iou_silh)) 151 | logger.append([epoch, lr, pck, iou_silh]) 152 | 153 | is_best = pck > best_pck 154 | if pck > best_pck: 155 | best_pck_epoch = epoch 156 | best_pck = max(pck, best_pck) 157 | save_checkpoint({'epoch': epoch, 158 | 'state_dict': model.state_dict(), 159 | 'best_pck': best_pck, 160 | 'optimizer': optimizer.state_dict()}, 161 | is_best, checkpoint=args.output_dir, filename='checkpoint.pth.tar') 162 | writer.close() 163 | logger.close() 164 | 165 | 166 | def run_evaluation(model, dataset, data_loader, device, args): 167 | model.eval() 168 | result_dir = args.output_dir 169 | batch_size = args.batch_size 170 | 171 | pck = np.zeros((len(dataset))) 172 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS} 173 | acc_sil_2d = np.zeros(len(dataset)) 174 | 175 | smal_pose = np.zeros((len(dataset), 105)) 176 | smal_betas = np.zeros((len(dataset), 20)) 177 | smal_camera = np.zeros((len(dataset), 3)) 178 | smal_imgname = [] 179 | 180 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader)) 181 | 182 | for step, batch in enumerate(tqdm_iterator): 183 | with torch.no_grad(): 184 | keypoints = batch['keypoints'].to(device) 185 | keypoints_norm = batch['keypoints_norm'].to(device) 186 | seg = batch['seg'].to(device) 187 | has_seg = batch['has_seg'] 188 | img = batch['img'].to(device) 189 | img_border_mask = batch['img_border_mask'].to(device) 190 | pred_codes = model(img) 191 | 192 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 193 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2], 194 | dim=1) 195 | verts, joints, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred, 196 | betas_logscale=betas_scale_pred) 197 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3) 198 | labelled_joints_3d = joints[:, config.MODEL_JOINTS] 199 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera) 200 | if args.save_results: 201 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0) 202 | synth_silhouettes = synth_silhouettes.unsqueeze(1) 203 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera) 204 | 205 | preds = {} 206 | preds['pose'] = pose_pred 207 | preds['betas'] = betas_pred 208 | preds['camera'] = pred_camera 209 | preds['trans'] = trans_pred 210 | 211 | preds['verts'] = verts 212 | preds['joints_3d'] = labelled_joints_3d 213 | preds['faces'] = faces 214 | 215 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg) 216 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg) 217 | 218 | for group, group_kps in config.KEYPOINT_GROUPS.items(): 219 | preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg, 220 | thresh_range=[0.15], 221 | idxs=group_kps) 222 | 223 | preds['synth_xyz'] = synth_rgb 224 | preds['synth_silhouettes'] = 
synth_silhouettes 225 | preds['synth_landmarks'] = synth_landmarks 226 | 227 | assert not any(k in preds for k in batch.keys()) 228 | preds.update(batch) 229 | 230 | curr_batch_size = preds['synth_landmarks'].shape[0] 231 | 232 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() 233 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() 234 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy() 235 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy() 236 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy() 237 | 238 | for part in pck_by_part: 239 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() 240 | 241 | # save results as well as visualization 242 | if args.save_results: 243 | output_figs = np.transpose( 244 | Visualizer.generate_output_figures(preds).data.cpu().numpy(), 245 | (0, 1, 3, 4, 2)) 246 | 247 | for img_id in range(len(preds['imgname'])): 248 | imgname = preds['imgname'][img_id] 249 | output_fig_list = output_figs[img_id] 250 | 251 | path_parts = imgname.split('/') 252 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1]) 253 | img_file = os.path.join(result_dir, path_suffix) 254 | output_fig = np.hstack(output_fig_list) 255 | smal_imgname.append(path_suffix) 256 | # npz_file = "{0}.npz".format(os.path.splitext(img_file)[0]) 257 | 258 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0) 259 | # np.savez_compressed(npz_file, 260 | # imgname=preds['imgname'][img_id], 261 | # pose=preds['pose'][img_id].data.cpu().numpy(), 262 | # betas=preds['betas'][img_id].data.cpu().numpy(), 263 | # camera=preds['camera'][img_id].data.cpu().numpy(), 264 | # trans=preds['trans'][img_id].data.cpu().numpy(), 265 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(), 266 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(), 267 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part} 268 | # ) 269 | 270 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part 271 | 272 | 273 | if __name__ == '__main__': 274 | parser = argparse.ArgumentParser() 275 | parser.add_argument('--lr', default=0.0001, type=float) 276 | parser.add_argument('--output_dir', default='./logs/', type=str) 277 | parser.add_argument('--nEpochs', default=200, type=int) 278 | parser.add_argument('--w_kpts', default=10, type=float) 279 | parser.add_argument('--w_betas_prior', default=1, type=float) 280 | parser.add_argument('--w_pose_prior', default=1, type=float) 281 | parser.add_argument('--batch_size', default=16, type=int) 282 | parser.add_argument('--num_works', default=4, type=int) 283 | parser.add_argument('--start_epoch', default=0, type=int) 284 | parser.add_argument('--gpu_ids', default='0', type=str) 285 | parser.add_argument('--evaluate', action='store_true') 286 | parser.add_argument('--resume', default=None, type=str) 287 | parser.add_argument('--load_optimizer', action='store_true') 288 | parser.add_argument('--shape_family_id', default=1, type=int) 289 | parser.add_argument('--dataset', default='stanford', type=str) 290 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load') 291 | parser.add_argument('--shape_init', default='smal', help='enable to initiate shape with mean 
shape') 292 | parser.add_argument('--save_results', action='store_true') 293 | parser.add_argument('--prior_betas', default='smal', type=str) 294 | parser.add_argument('--prior_pose', default='smal', type=str) 295 | parser.add_argument('--betas_scale', action='store_true') 296 | 297 | args = parser.parse_args() 298 | main(args) -------------------------------------------------------------------------------- /main_meshgraph.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import cv2 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from model.mesh_graph_hg import MeshGraph_hg, init_pretrained 14 | 15 | from util import config 16 | from util.helpers.visualize import Visualizer 17 | from util.loss_utils import kp_l2_loss, Shape_prior, Laplacian 18 | from util.loss_sdf import tversky_loss 19 | from util.metrics import Metrics 20 | 21 | from datasets.stanford import BaseDataset 22 | 23 | from util.logger import Logger 24 | from util.meter import AverageMeterSet 25 | from util.misc import save_checkpoint 26 | from util.pose_prior import Prior 27 | from util.joint_limits_prior import LimitPrior 28 | from util.utils import print_options 29 | 30 | # Set some global varibles 31 | global_step = 0 32 | best_pck = 0 33 | best_pck_epoch = 0 34 | 35 | 36 | def main(args): 37 | global best_pck 38 | global best_pck_epoch 39 | global global_step 40 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids 41 | 42 | print("RESULTS: {0}".format(args.output_dir)) 43 | if not os.path.exists(args.output_dir): 44 | os.mkdir(args.output_dir) 45 | 46 | # set up device 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | 49 | # set up model 50 | model = MeshGraph_hg(device, args.shape_family_id, args.num_channels, args.num_layers, args.betas_scale, 51 | args.shape_init, args.local_feat, num_downsampling=args.num_downsampling, 52 | render_rgb=args.save_results) 53 | 54 | model = nn.DataParallel(model).to(device) 55 | 56 | # set up dataset 57 | dataset_train = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=True, use_augmentation=True) 58 | data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works) 59 | dataset_eval = BaseDataset(args.dataset, param_dir=args.param_dir, is_train=False, use_augmentation=False) 60 | data_loader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works) 61 | 62 | # set up optimizer 63 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 64 | 65 | writer = SummaryWriter(os.path.join(args.output_dir, 'train')) 66 | 67 | # set up priors 68 | joint_limit_prior = LimitPrior(device) 69 | shape_prior = Shape_prior(args.prior_betas, args.shape_family_id, device) 70 | tversky = tversky_loss(args.alpha, args.beta) 71 | 72 | # read the adjacency matrix, which will used in the Laplacian regularizer 73 | data = np.load('./data/mesh_down_sampling_4.npz', encoding='latin1', allow_pickle=True) 74 | adjmat = data['A'][0] 75 | laplacianloss = Laplacian(adjmat, device) 76 | 77 | if args.resume: 78 | if os.path.isfile(args.resume): 79 | print("=> loading checkpoint {}".format(args.resume)) 80 | checkpoint = torch.load(args.resume) 81 | 
model.load_state_dict(checkpoint['state_dict']) 82 | if args.load_optimizer: 83 | optimizer.load_state_dict(checkpoint['optimizer']) 84 | args.start_epoch = checkpoint['epoch'] + 1 85 | print("=> loaded checkpoint {} (epoch {})".format(args.resume, checkpoint['epoch'])) 86 | # logger = Logger(os.path.join(args.output_dir, 'log.txt'), resume=True) 87 | logger = Logger(os.path.join(args.output_dir, 'log.txt')) 88 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU', 'PCK_re', 'IOU_re']) 89 | else: 90 | print("=> no checkpoint found at {}".format(args.resume)) 91 | else: 92 | logger = Logger(os.path.join(args.output_dir, 'log.txt')) 93 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU','PCK_re', 'IOU_re']) 94 | 95 | if args.freezecoarse: 96 | for p in model.module.meshnet.parameters(): 97 | p.requires_grad = False 98 | if args.pretrained: 99 | if os.path.isfile(args.pretrained): 100 | print("=> loading checkpoint {}".format(args.pretrained)) 101 | checkpoint_pre = torch.load(args.pretrained) 102 | init_pretrained(model, checkpoint_pre) 103 | print("=> loaded checkpoint {} (epoch {})".format(args.pretrained, checkpoint_pre['epoch'])) 104 | # logger = Logger(os.path.join(args.output_dir, 'log.txt'), resume=True) 105 | logger = Logger(os.path.join(args.output_dir, 'log.txt')) 106 | logger.set_names(['Epoch', 'LR', 'PCK', 'IOU', 'PCK_re', 'IOU_re']) 107 | 108 | print_options(args) 109 | if args.evaluate: 110 | pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args) 111 | print("Evaluate only, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}" 112 | .format(pck, iou_silh, pck_re, iou_re)) 113 | return 114 | 115 | lr = args.lr 116 | # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 117 | # optimizer, [160, 190], 0.5, 118 | # ) 119 | for epoch in range(args.start_epoch, args.nEpochs): 120 | # lr_scheduler.step() 121 | # lr = lr_scheduler.get_last_lr()[0] 122 | model.train() 123 | tqdm_iterator = tqdm(data_loader_train, desc='Train', total=len(data_loader_train)) 124 | meters = AverageMeterSet() 125 | for step, batch in enumerate(tqdm_iterator): 126 | keypoints = batch['keypoints'].to(device) 127 | keypoints_norm = batch['keypoints_norm'].to(device) 128 | seg = batch['seg'].to(device) 129 | img = batch['img'].to(device) 130 | 131 | verts, joints, shape, pred_codes = model(img) 132 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 133 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2], 134 | dim=1) 135 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3) 136 | labelled_joints_3d = joints[:, config.MODEL_JOINTS] 137 | 138 | # project 3D joints onto 2D space and apply 2D keypoints supervision 139 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera) 140 | loss_kpts = args.w_kpts * kp_l2_loss(synth_landmarks, keypoints[:, :, [1, 0, 2]], config.NUM_JOINTS) 141 | meters.update('loss_kpt', loss_kpts.item()) 142 | loss = loss_kpts 143 | 144 | # use tversky for silhouette loss 145 | if args.w_dice>0: 146 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera) 147 | synth_silhouettes = synth_silhouettes.unsqueeze(1) 148 | loss_dice = args.w_dice * tversky(synth_silhouettes, seg) 149 | meters.update('loss_dice', loss_dice.item()) 150 | loss += loss_dice 151 | 152 | # apply shape prior constraint, either come from SMAL or unity from WLDO 153 | if 
args.w_betas_prior > 0: 154 | if args.prior_betas == 'smal': 155 | s_prior = args.w_betas_prior * shape_prior(betas_pred) 156 | elif args.prior_betas == 'unity': 157 | betas_pred = torch.cat([betas_pred, betas_scale_pred], dim=1) 158 | s_prior = args.w_betas_prior * shape_prior(betas_pred) 159 | else: 160 | raise Exception("Shape prior should come from either smal or unity") 161 | s_prior = 0 162 | meters.update('loss_prior', s_prior.item()) 163 | loss += s_prior 164 | 165 | # apply pose prior constraint, either come from SMAL or unity from WLDO 166 | if args.w_pose_prior > 0: 167 | if args.prior_pose == 'smal': 168 | pose_prior_path = config.WALKING_PRIOR_FILE 169 | elif args.prior_pose == 'unity': 170 | pose_prior_path = config.UNITY_POSE_PRIOR 171 | else: 172 | raise Exception('The prior should come from either smal or unity') 173 | pose_prior_path = None 174 | pose_prior = Prior(pose_prior_path, device) 175 | p_prior = args.w_pose_prior * pose_prior(pose_pred) 176 | meters.update('pose_prior', p_prior.item()) 177 | loss += p_prior 178 | 179 | # apply pose limit constraint 180 | if args.w_pose_limit_prior > 0: 181 | pose_limit_loss = args.w_pose_limit_prior * joint_limit_prior(pose_pred) 182 | meters.update('pose_limit', pose_limit_loss.item()) 183 | loss += pose_limit_loss 184 | 185 | # get refined meshes by adding del_v to the coarse mesh from SMAL 186 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred, 187 | del_v=shape, 188 | betas_logscale=betas_scale_pred) 189 | # apply 2D keypoint and silhouette supervision 190 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS] 191 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine, 192 | pred_camera) 193 | loss_kpts_refine = args.w_kpts_refine * kp_l2_loss(synth_landmarks_refine, keypoints[:, :, [1, 0, 2]], 194 | config.NUM_JOINTS) 195 | 196 | meters.update('loss_kpt_refine', loss_kpts_refine.item()) 197 | loss += loss_kpts_refine 198 | if args.w_dice_refine > 0: 199 | _, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera) 200 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1) 201 | loss_dice_refine = args.w_dice_refine * tversky(synth_silhouettes_refine, seg) 202 | meters.update('loss_dice_refine', loss_dice_refine.item()) 203 | loss += loss_dice_refine 204 | 205 | # apply Laplacian constraint to prevent large deformation predictions 206 | if args.w_arap > 0: 207 | verts_clone = verts.detach().clone() 208 | loss_arap, loss_smooth = laplacianloss(verts_refine, verts_clone) 209 | loss_arap = args.w_arap * loss_arap 210 | meters.update('loss_arap', loss_arap.item()) 211 | loss += loss_arap 212 | 213 | meters.update('loss_all', loss.item()) 214 | optimizer.zero_grad() 215 | loss.backward() 216 | optimizer.step() 217 | global_step += 1 218 | if step % 20 == 0: 219 | loss_values = meters.averages() 220 | for name, meter in loss_values.items(): 221 | writer.add_scalar(name, meter, global_step) 222 | writer.flush() 223 | 224 | pck, iou_silh, pck_by_part, pck_re, iou_re = run_evaluation(model, dataset_eval, data_loader_eval, device, args) 225 | 226 | print("Epoch: {:3d}, LR: {:6.5f}, PCK: {:6.4f}, IOU: {:6.4f}, PCK_re: {:6.4f}, IOU_re: {:6.4f}" 227 | .format(epoch, lr, pck, iou_silh, pck_re, iou_re)) 228 | logger.append([epoch, lr, pck, iou_silh, pck_re, iou_re]) 229 | 230 | is_best = pck_re > best_pck 231 | if pck_re > best_pck: 232 | best_pck_epoch = epoch 233 | best_pck = max(pck_re, best_pck) 234 | 
save_checkpoint({'epoch': epoch, 235 | 'state_dict': model.state_dict(), 236 | 'best_pck': best_pck, 237 | 'optimizer': optimizer.state_dict()}, 238 | is_best, checkpoint=args.output_dir, filename='checkpoint.pth.tar') 239 | writer.close() 240 | logger.close() 241 | 242 | 243 | def run_evaluation(model, dataset, data_loader, device, args): 244 | 245 | model.eval() 246 | result_dir = args.output_dir 247 | batch_size = args.batch_size 248 | 249 | pck = np.zeros((len(dataset))) 250 | pck_by_part = {group: np.zeros((len(dataset))) for group in config.KEYPOINT_GROUPS} 251 | acc_sil_2d = np.zeros(len(dataset)) 252 | 253 | pck_re = np.zeros((len(dataset))) 254 | acc_sil_2d_re = np.zeros(len(dataset)) 255 | 256 | smal_pose = np.zeros((len(dataset), 105)) 257 | smal_betas = np.zeros((len(dataset), 20)) 258 | smal_camera = np.zeros((len(dataset), 3)) 259 | smal_imgname = [] 260 | 261 | tqdm_iterator = tqdm(data_loader, desc='Eval', total=len(data_loader)) 262 | 263 | for step, batch in enumerate(tqdm_iterator): 264 | with torch.no_grad(): 265 | preds = {} 266 | keypoints = batch['keypoints'].to(device) 267 | keypoints_norm = batch['keypoints_norm'].to(device) 268 | seg = batch['seg'].to(device) 269 | has_seg = batch['has_seg'] 270 | img = batch['img'].to(device) 271 | img_border_mask = batch['img_border_mask'].to(device) 272 | # get coarse meshes and project onto 2D 273 | verts, joints, shape, pred_codes = model(img) 274 | scale_pred, trans_pred, pose_pred, betas_pred, betas_scale_pred = pred_codes 275 | pred_camera = torch.cat([scale_pred[:, [0]], torch.ones(keypoints.shape[0], 2).cuda() * config.IMG_RES / 2], 276 | dim=1) 277 | faces = model.module.smal.faces.unsqueeze(0).expand(verts.shape[0], 7774, 3) 278 | labelled_joints_3d = joints[:, config.MODEL_JOINTS] 279 | 280 | synth_rgb, synth_silhouettes = model.module.model_renderer(verts, faces, pred_camera) 281 | synth_silhouettes = synth_silhouettes.unsqueeze(1) 282 | synth_landmarks = model.module.model_renderer.project_points(labelled_joints_3d, pred_camera) 283 | 284 | # get refined meshes by adding del_v to coarse estimations 285 | verts_refine, joints_refine, _, _ = model.module.smal(betas_pred, pose_pred, trans=trans_pred, 286 | del_v=shape, 287 | betas_logscale=betas_scale_pred) 288 | labelled_joints_3d_refine = joints_refine[:, config.MODEL_JOINTS] 289 | # project refined 3D meshes onto 2D 290 | synth_rgb_refine, synth_silhouettes_refine = model.module.model_renderer(verts_refine, faces, pred_camera) 291 | synth_silhouettes_refine = synth_silhouettes_refine.unsqueeze(1) 292 | synth_landmarks_refine = model.module.model_renderer.project_points(labelled_joints_3d_refine, 293 | pred_camera) 294 | 295 | if args.save_results: 296 | synth_rgb = torch.clamp(synth_rgb[0], 0.0, 1.0) 297 | synth_rgb_refine = torch.clamp(synth_rgb_refine[0], 0.0, 1.0) 298 | 299 | preds['pose'] = pose_pred 300 | preds['betas'] = betas_pred 301 | preds['camera'] = pred_camera 302 | preds['trans'] = trans_pred 303 | 304 | preds['verts'] = verts 305 | preds['joints_3d'] = labelled_joints_3d 306 | preds['faces'] = faces 307 | 308 | preds['acc_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg) 309 | preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, seg, img_border_mask, mask=has_seg) 310 | 311 | preds['acc_PCK_re'] = Metrics.PCK(synth_landmarks_refine, keypoints_norm, seg, has_seg) 312 | preds['acc_IOU_re'] = Metrics.IOU(synth_silhouettes_refine, seg, img_border_mask, mask=has_seg) 313 | 314 | for group, group_kps in config.KEYPOINT_GROUPS.items(): 315 | 
preds[f'{group}_PCK'] = Metrics.PCK(synth_landmarks, keypoints_norm, seg, has_seg, 316 | thresh_range=[0.15], 317 | idxs=group_kps) 318 | 319 | preds['synth_xyz'] = synth_rgb 320 | preds['synth_silhouettes'] = synth_silhouettes 321 | preds['synth_landmarks'] = synth_landmarks 322 | preds['synth_xyz_re'] = synth_rgb_refine 323 | preds['synth_landmarks_re'] = synth_landmarks_refine 324 | preds['synth_silhouettes_re'] = synth_silhouettes_refine 325 | 326 | assert not any(k in preds for k in batch.keys()) 327 | preds.update(batch) 328 | 329 | curr_batch_size = preds['synth_landmarks'].shape[0] 330 | 331 | pck[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() 332 | acc_sil_2d[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() 333 | smal_pose[step * batch_size:step * batch_size + curr_batch_size] = preds['pose'].data.cpu().numpy() 334 | smal_betas[step * batch_size:step * batch_size + curr_batch_size, :preds['betas'].shape[1]] = preds['betas'].data.cpu().numpy() 335 | smal_camera[step * batch_size:step * batch_size + curr_batch_size] = preds['camera'].data.cpu().numpy() 336 | 337 | pck_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_PCK_re'].data.cpu().numpy() 338 | acc_sil_2d_re[step * batch_size:step * batch_size + curr_batch_size] = preds['acc_IOU_re'].data.cpu().numpy() 339 | for part in pck_by_part: 340 | pck_by_part[part][step * batch_size:step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() 341 | 342 | if args.save_results: 343 | output_figs = np.transpose( 344 | Visualizer.generate_output_figures(preds, vis_refine=True).data.cpu().numpy(), 345 | (0, 1, 3, 4, 2)) 346 | 347 | for img_id in range(len(preds['imgname'])): 348 | imgname = preds['imgname'][img_id] 349 | output_fig_list = output_figs[img_id] 350 | 351 | path_parts = imgname.split('/') 352 | path_suffix = "{0}_{1}".format(path_parts[-2], path_parts[-1]) 353 | img_file = os.path.join(result_dir, path_suffix) 354 | output_fig = np.hstack(output_fig_list) 355 | smal_imgname.append(path_suffix) 356 | npz_file = "{0}.npz".format(os.path.splitext(img_file)[0]) 357 | 358 | cv2.imwrite(img_file, output_fig[:, :, ::-1] * 255.0) 359 | # np.savez_compressed(npz_file, 360 | # imgname=preds['imgname'][img_id], 361 | # pose=preds['pose'][img_id].data.cpu().numpy(), 362 | # betas=preds['betas'][img_id].data.cpu().numpy(), 363 | # camera=preds['camera'][img_id].data.cpu().numpy(), 364 | # trans=preds['trans'][img_id].data.cpu().numpy(), 365 | # acc_PCK=preds['acc_PCK'][img_id].data.cpu().numpy(), 366 | # # acc_SIL_2D=preds['acc_IOU'][img_id].data.cpu().numpy(), 367 | # **{f'{part}_PCK': preds[f'{part}_PCK'].data.cpu().numpy() for part in pck_by_part} 368 | # ) 369 | 370 | return np.nanmean(pck), np.nanmean(acc_sil_2d), pck_by_part, np.nanmean(pck_re), np.nanmean(acc_sil_2d_re) 371 | 372 | 373 | if __name__ == '__main__': 374 | parser = argparse.ArgumentParser() 375 | parser.add_argument('--lr', default=0.0001, type=float) 376 | parser.add_argument('--output_dir', default='./logs/', type=str) 377 | parser.add_argument('--nEpochs', default=250, type=int) 378 | 379 | parser.add_argument('--w_kpts', default=10, type=float) 380 | parser.add_argument('--w_betas_prior', default=1, type=float) 381 | parser.add_argument('--w_pose_prior', default=1, type=float) 382 | parser.add_argument('--w_pose_limit_prior', default=0, type=float) 383 | parser.add_argument('--w_kpts_refine', default=1, type=float) 384 | 385 | 
parser.add_argument('--batch_size', default=16, type=int) 386 | parser.add_argument('--num_works', default=4, type=int) 387 | parser.add_argument('--start_epoch', default=0, type=int) 388 | parser.add_argument('--gpu_ids', default='0', type=str) 389 | parser.add_argument('--evaluate', action='store_true') 390 | parser.add_argument('--resume', default=None, type=str) 391 | parser.add_argument('--load_optimizer', action='store_true') 392 | parser.add_argument('--dataset', default='stanford', type=str) 393 | parser.add_argument('--shape_family_id', default=1, type=int) 394 | parser.add_argument('--param_dir', default=None, type=str, help='Exported parameter folder to load') 395 | 396 | parser.add_argument('--shape_init', default='smal', help='enable to initiate shape with mean shape') 397 | parser.add_argument('--save_results', action='store_true') 398 | parser.add_argument('--prior_betas', default='smal', type=str) 399 | parser.add_argument('--prior_pose', default='smal', type=str) 400 | parser.add_argument('--betas_scale', action='store_true') 401 | 402 | parser.add_argument('--num_channels', type=int, default=256, help='Number of channels in Graph Residual layers') 403 | parser.add_argument('--num_layers', type=int, default=5, help='Number of residuals blocks in the Graph CNN') 404 | parser.add_argument('--pretrained', default=None, type=str) 405 | parser.add_argument('--local_feat', action='store_true') 406 | 407 | parser.add_argument('--num_downsampling', default=1, type=int) 408 | parser.add_argument('--freezecoarse', action='store_true') 409 | 410 | parser.add_argument('--w_arap', default=0, type=float) 411 | parser.add_argument('--w_dice', default=0, type=float) 412 | parser.add_argument('--w_dice_refine', default=0, type=float) 413 | parser.add_argument('--alpha', default=0.6, type=float) 414 | parser.add_argument('--beta', default=0.4, type=float) 415 | 416 | args = parser.parse_args() 417 | main(args) --------------------------------------------------------------------------------
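
A few illustrative sketches of the building blocks used by main.py and main_meshgraph.py follow. None of them are part of the repository; all concrete values, helper names, and shapes are assumptions made for illustration only.

Both training scripts build pred_camera by concatenating the predicted scale with a fixed principal point at config.IMG_RES / 2, project the labelled SMAL joints with model_renderer.project_points, and compare the result against the annotated keypoints (reordered with [1, 0, 2], which suggests a (y, x, visibility) storage order) using kp_l2_loss. The sketch below shows, under the assumption of a plain weak-perspective camera [scale, cx, cy] and a visibility-weighted squared error, how such a projection and loss could look; the repository's renderer-based project_points and kp_l2_loss may differ in detail, and IMG_RES here is a placeholder.

import torch

IMG_RES = 224  # placeholder; the repository reads this from util/config.py

def weak_perspective_project(points3d, camera):
    # points3d: (B, J, 3); camera: (B, 3) = [scale, cx, cy]
    scale = camera[:, 0].view(-1, 1, 1)
    center = camera[:, 1:].view(-1, 1, 2)
    return scale * points3d[:, :, :2] + center

def kp_l2_loss_sketch(pred_kp2d, gt_kp):
    # gt_kp: (B, J, 3) with the last channel a visibility flag in {0, 1}
    vis = gt_kp[:, :, 2:].float()
    diff = (pred_kp2d - gt_kp[:, :, :2]) ** 2
    return (vis * diff).sum() / (vis.sum() * 2 + 1e-8)

# usage with random tensors (shapes are assumptions)
joints3d = torch.randn(2, 20, 3)
camera = torch.tensor([[150.0, IMG_RES / 2, IMG_RES / 2],
                       [140.0, IMG_RES / 2, IMG_RES / 2]])
kp_gt = torch.cat([torch.rand(2, 20, 2) * IMG_RES, torch.ones(2, 20, 1)], dim=2)
loss = kp_l2_loss_sketch(weak_perspective_project(joints3d, camera), kp_gt)

Normalizing by the number of visible keypoints keeps the loss scale comparable across images with different numbers of annotated joints.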
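
When --w_dice > 0, main_meshgraph.py supervises the rendered silhouette with tversky(synth_silhouettes, seg), where tversky_loss is constructed from util/loss_sdf.py with --alpha 0.6 and --beta 0.4 by default. The Tversky index generalizes the Dice score by weighting false positives and false negatives separately. The snippet below is a minimal sketch of that idea, not the repository's implementation.

import torch

class TverskyLossSketch:
    def __init__(self, alpha=0.6, beta=0.4, eps=1e-7):
        self.alpha = alpha  # weight on false positives
        self.beta = beta    # weight on false negatives
        self.eps = eps

    def __call__(self, pred, target):
        # pred, target: (B, 1, H, W) with values in [0, 1]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        tp = (pred * target).sum(dim=1)
        fp = (pred * (1.0 - target)).sum(dim=1)
        fn = ((1.0 - pred) * target).sum(dim=1)
        tversky = (tp + self.eps) / (tp + self.alpha * fp + self.beta * fn + self.eps)
        return (1.0 - tversky).mean()

# with alpha == beta == 0.5 this reduces to a soft Dice loss
criterion = TverskyLossSketch(alpha=0.6, beta=0.4)
loss_dice = criterion(torch.rand(4, 1, 224, 224), (torch.rand(4, 1, 224, 224) > 0.5).float())

Setting alpha above beta penalizes false positives more strongly, which tends to produce tighter predicted silhouettes.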
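
The Shape_prior and Prior terms penalize predicted betas and poses that drift far from statistics fitted offline, either to the SMAL model (config.WALKING_PRIOR_FILE) or to the Unity data used by WLDO (config.UNITY_POSE_PRIOR). A common realization of such a term is a Gaussian prior whose negative log-likelihood is a Mahalanobis distance; the sketch below uses that form with a random mean and covariance as stand-ins for the fitted statistics, and is not the code in util/pose_prior.py or util/loss_utils.py.

import torch

class GaussianPriorSketch(torch.nn.Module):
    def __init__(self, mean, cov):
        super().__init__()
        # precompute the Cholesky factor of the precision matrix
        prec = torch.linalg.inv(cov)
        self.register_buffer('mean', mean)
        self.register_buffer('prec_chol', torch.linalg.cholesky(prec))

    def forward(self, x):
        # x: (B, D); mean squared Mahalanobis distance to the prior mean
        diff = x - self.mean
        proj = diff @ self.prec_chol
        return (proj ** 2).sum(dim=1).mean()

D = 20                                  # assumed parameter dimension
A = torch.randn(D, D)
cov = A @ A.t() + D * torch.eye(D)      # random SPD covariance as a stand-in
prior = GaussianPriorSketch(mean=torch.zeros(D), cov=cov)
betas_pred = torch.randn(8, D, requires_grad=True)
loss_prior = prior(betas_pred)

Because the term is differentiable in the predicted parameters, it can simply be added to the keypoint and silhouette losses with a weight such as --w_betas_prior or --w_pose_prior.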
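
The --w_arap term builds a Laplacian from the adjacency matrix stored in data/mesh_down_sampling_4.npz (data['A'][0]) and compares the refined vertices against the detached coarse vertices, discouraging large per-vertex deformations del_v; the repository's Laplacian also returns a separate smoothness term. A minimal sketch of a uniform graph-Laplacian version of this consistency idea, with a toy adjacency matrix, is shown below; util/loss_utils.py may use a different normalization.

import torch

def uniform_laplacian(adjmat):
    # adjmat: dense (V, V) 0/1 adjacency; returns L = I - D^-1 A
    deg = adjmat.sum(dim=1, keepdim=True).clamp(min=1.0)
    return torch.eye(adjmat.shape[0]) - adjmat / deg

def laplacian_consistency_loss(L, verts_refine, verts_coarse):
    # verts_*: (B, V, 3); compare the differential coordinates of both meshes
    delta_refine = torch.einsum('vw,bwc->bvc', L, verts_refine)
    delta_coarse = torch.einsum('vw,bwc->bvc', L, verts_coarse)
    return ((delta_refine - delta_coarse) ** 2).sum(dim=2).mean()

# toy example: a tetrahedron-like graph with 4 vertices
adj = torch.tensor([[0., 1., 1., 0.],
                    [1., 0., 1., 1.],
                    [1., 1., 0., 1.],
                    [0., 1., 1., 0.]])
L = uniform_laplacian(adj)
verts_coarse = torch.randn(2, 4, 3)
verts_refine = verts_coarse + 0.01 * torch.randn(2, 4, 3)
loss_arap = laplacian_consistency_loss(L, verts_refine, verts_coarse.detach())

Detaching the coarse vertices matches what the training loop does with verts.detach().clone(): the regularizer shapes the refinement without pulling the coarse prediction toward it.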
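
run_evaluation accumulates per-image PCK (the fraction of visible keypoints whose reprojection error falls within a threshold, with 0.15 used for the per-part breakdown) and silhouette IoU. The repository's Metrics.PCK and Metrics.IOU additionally take the segmentation, an image-border mask, and a has_seg flag into account; the sketch below only illustrates the basic computations, under the assumption that the PCK threshold is a fraction of the square root of the silhouette area.

import torch

def pck_sketch(pred_kp, gt_kp, seg, thresh=0.15):
    # pred_kp: (B, J, 2); gt_kp: (B, J, 3) with visibility flag; seg: (B, 1, H, W) in [0, 1]
    size = seg.flatten(1).sum(dim=1).sqrt()            # per-image object scale (assumption)
    dist = (pred_kp - gt_kp[:, :, :2]).norm(dim=2)     # (B, J) pixel errors
    vis = gt_kp[:, :, 2] > 0.5
    correct = (dist < thresh * size[:, None]) & vis
    return correct.sum(dim=1).float() / vis.sum(dim=1).clamp(min=1).float()

def iou_sketch(pred_sil, gt_sil, thresh=0.5):
    # pred_sil, gt_sil: (B, 1, H, W) in [0, 1]
    p = (pred_sil > thresh).flatten(1)
    g = (gt_sil > thresh).flatten(1)
    inter = (p & g).sum(dim=1).float()
    union = (p | g).sum(dim=1).float().clamp(min=1)
    return inter / union

Both functions return one value per image, which matches how the evaluation loop writes results into per-image arrays before reporting np.nanmean over the dataset.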