├── common
    ├── __init__.py
    ├── quaternion.py
    ├── mocap_dataset.py
    ├── generators.py
    ├── graph_utils.py
    ├── utils.py
    ├── skeleton.py
    ├── data_utils.py
    ├── log.py
    ├── loss.py
    ├── camera.py
    ├── visualization.py
    └── h36m_dataset.py
├── models
    ├── __init__.py
    ├── gconv
    │   ├── __init__.py
    │   ├── no_sharing_graph_conv.py
    │   ├── conv_style_graph_conv.py
    │   ├── sem_graph_conv.py
    │   ├── vanilla_graph_conv.py
    │   ├── modulated_gcn_conv.py
    │   ├── pre_agg_graph_conv.py
    │   ├── post_agg_graph_conv.py
    │   └── sem_ch_graph_conv.py
    ├── graph_non_local.py
    └── graph_sh.py
├── arch.png
├── requirements.txt
├── data
    ├── convert_cdf_to_mat.m
    ├── README.md
    ├── prepare_data_2d_h36m_sh.py
    └── prepare_data_h36m.py
├── progress
    ├── spinner.py
    ├── counter.py
    ├── bar.py
    ├── helpers.py
    └── __init__.py
├── README.md
└── main_gcn.py
/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/gconv/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamasino52/GraphSH/HEAD/arch.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | scipy
5 | h5py
6 | matplotlib
7 |
--------------------------------------------------------------------------------
/data/convert_cdf_to_mat.m:
--------------------------------------------------------------------------------
1 | % Extract "Poses_D3_Positions_S*.tgz" to the "pose" directory
2 | % and run this script to convert all .cdf files to .mat
3 | 
4 | pose_directory = 'pose';
5 | dirs = dir(strcat(pose_directory, '/*/MyPoseFeatures/D3_Positions/*.cdf'));
6 | 
7 | paths = {dirs.folder};
8 | names = {dirs.name};
9 | 
10 | for i = 1:numel(names)
11 |     data = cdfread(strcat(paths{i}, '/', names{i}));
12 |     save(strcat(paths{i}, '/', names{i}, '.mat'), 'data');
13 | end
--------------------------------------------------------------------------------
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division
2 | 
3 | import torch
4 | 
5 | 
6 | def qrot(q, v):
7 |     """
8 |     Rotate vector(s) v about the rotation described by quaternion(s) q.
9 |     Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
10 |     where * denotes any number of dimensions.
11 |     Returns a tensor of shape (*, 3).
12 | """ 13 | assert q.shape[-1] == 4 14 | assert v.shape[-1] == 3 15 | assert q.shape[:-1] == v.shape[:-1] 16 | 17 | qvec = q[..., 1:] 18 | uv = torch.cross(qvec, v, dim=len(q.shape) - 1) 19 | uuv = torch.cross(qvec, uv, dim=len(q.shape) - 1) 20 | return v + 2 * (q[..., :1] * uv + uuv) 21 | 22 | 23 | def qinverse(q, inplace=False): 24 | # We assume the quaternion to be normalized 25 | if inplace: 26 | q[..., 1:] *= -1 27 | return q 28 | else: 29 | w = q[..., :1] 30 | xyz = q[..., 1:] 31 | return torch.cat((w, -xyz), dim=len(q.shape) - 1) 32 | -------------------------------------------------------------------------------- /common/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class MocapDataset(object): 5 | def __init__(self, skeleton, fps=None): 6 | self._skeleton = skeleton 7 | self._fps = fps 8 | self._data = None # Must be filled by subclass 9 | self._cameras = None # Must be filled by subclass 10 | 11 | def remove_joints(self, joints_to_remove): 12 | kept_joints = self._skeleton.remove_joints(joints_to_remove) 13 | for subject in self._data.keys(): 14 | for action in self._data[subject].keys(): 15 | s = self._data[subject][action] 16 | s['positions'] = s['positions'][:, kept_joints] 17 | 18 | def __getitem__(self, key): 19 | return self._data[key] 20 | 21 | def subjects(self): 22 | return self._data.keys() 23 | 24 | def fps(self): 25 | return self._fps 26 | 27 | def skeleton(self): 28 | return self._skeleton 29 | 30 | def cameras(self): 31 | return self._cameras 32 | 33 | def define_actions(self, action): 34 | # This method can be overridden 35 | return False 36 | -------------------------------------------------------------------------------- /common/generators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from functools import reduce 7 | 8 | 9 | class PoseGenerator(Dataset): 10 | def __init__(self, poses_3d, poses_2d, actions, cam): 11 | assert poses_3d is not None 12 | 13 | self._poses_3d = np.concatenate(poses_3d) 14 | self._poses_2d = np.concatenate(poses_2d) 15 | self._cam = np.concatenate(cam) 16 | 17 | self._actions = reduce(lambda x, y: x + y, actions) 18 | 19 | assert self._poses_3d.shape[0] == self._poses_2d.shape[0] and self._poses_3d.shape[0] == len(self._actions) and self._poses_3d.shape[0] == self._cam.shape[0] 20 | print('Generating {} poses...'.format(len(self._actions))) 21 | 22 | def __getitem__(self, index): 23 | out_pose_3d = self._poses_3d[index] 24 | out_pose_2d = self._poses_2d[index] 25 | out_action = self._actions[index] 26 | out_cam = self._cam[index] 27 | 28 | out_pose_3d = torch.from_numpy(out_pose_3d).float() 29 | out_pose_2d = torch.from_numpy(out_pose_2d).float() 30 | out_cam = torch.from_numpy(out_cam).float() 31 | 32 | return out_pose_3d, out_pose_2d, out_action, out_cam 33 | 34 | def __len__(self): 35 | return len(self._actions) 36 | -------------------------------------------------------------------------------- /models/gconv/no_sharing_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NoSharingGraphConv(nn.Module): 9 | """ 10 | No-sharing graph convolution layer 11 | """ 12 | 13 | 
def __init__(self, in_features, out_features, adj, bias=True): 14 | super(NoSharingGraphConv, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | 18 | self.n_pts = adj.size(1) 19 | self.W = nn.Parameter(torch.zeros(size=(self.n_pts, self.n_pts, in_features, out_features), dtype=torch.float)) 20 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 21 | 22 | self.adj = adj 23 | 24 | if bias: 25 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 26 | stdv = 1. / math.sqrt(self.W.size(2)) 27 | self.bias.data.uniform_(-stdv, stdv) 28 | else: 29 | self.register_parameter('bias', None) 30 | 31 | def forward(self, input): 32 | adj = self.adj[None, :].to(input.device) 33 | 34 | h0 = torch.einsum('bhn,hwnm->bhwm', input, self.W) 35 | output = torch.einsum('bhw, bhwm->bwm', adj, h0) 36 | 37 | if self.bias is not None: 38 | return output + self.bias.view(1, 1, -1) 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /models/gconv/conv_style_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ConvStyleGraphConv(nn.Module): 9 | """ 10 | Convolution-style graph convolution layer 11 | """ 12 | 13 | def __init__(self, in_features, out_features, adj, bias=True): 14 | super(ConvStyleGraphConv, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | 18 | self.W = nn.Parameter(torch.zeros(size=(3, in_features, out_features), dtype=torch.float)) 19 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 20 | 21 | self.adj = adj 22 | 23 | if bias: 24 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 25 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 26 | self.bias.data.uniform_(-stdv, stdv) 27 | else: 28 | self.register_parameter('bias', None) 29 | 30 | def forward(self, input): 31 | adj = self.adj[None, :].to(input.device) 32 | 33 | h0 = torch.matmul(input, self.W[0]) 34 | h1 = torch.matmul(input, self.W[1]) 35 | h2 = torch.matmul(input, self.W[2]) 36 | 37 | E0 = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 38 | E1 = torch.triu(torch.ones_like(adj), diagonal=1) 39 | E2 = 1 - E1 - E0 40 | 41 | output = torch.matmul(adj * E0, h0) + torch.matmul(adj * E1, h1) + torch.matmul(adj * E2, h2) 42 | 43 | if self.bias is not None: 44 | return output + self.bias.view(1, 1, -1) 45 | else: 46 | return output 47 | 48 | def __repr__(self): 49 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /models/gconv/sem_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SemGraphConv(nn.Module): 10 | """ 11 | Semantic graph convolution layer 12 | """ 13 | 14 | def __init__(self, in_features, out_features, adj, bias=True): 15 | super(SemGraphConv, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | 19 | self.W = nn.Parameter(torch.zeros(size=(2, in_features, out_features), dtype=torch.float)) 20 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 21 | 22 | self.adj = adj 23 | self.m = (self.adj > 0) 24 | self.e = nn.Parameter(torch.zeros(1, len(self.m.nonzero()), dtype=torch.float)) 25 | nn.init.constant_(self.e.data, 1) 26 | 27 | if bias: 28 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 29 | stdv = 1. / math.sqrt(self.W.size(2)) 30 | self.bias.data.uniform_(-stdv, stdv) 31 | else: 32 | self.register_parameter('bias', None) 33 | 34 | def forward(self, input): 35 | h0 = torch.matmul(input, self.W[0]) 36 | h1 = torch.matmul(input, self.W[1]) 37 | 38 | adj = -9e15 * torch.ones_like(self.adj).to(input.device) 39 | adj[self.m] = self.e 40 | adj = F.softmax(adj, dim=1) 41 | 42 | M = torch.eye(adj.size(0), dtype=torch.float).to(input.device) 43 | output = torch.matmul(adj * M, h0) + torch.matmul(adj * (1 - M), h1) 44 | 45 | if self.bias is not None: 46 | return output + self.bias.view(1, 1, -1) 47 | else: 48 | return output 49 | 50 | def __repr__(self): 51 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /common/graph_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import numpy as np 5 | import scipy.sparse as sp 6 | 7 | 8 | def normalize(mx): 9 | """Row-normalize sparse matrix""" 10 | rowsum = np.array(mx.sum(1)) 11 | r_inv = np.power(rowsum, -1).flatten() 12 | r_inv[np.isinf(r_inv)] = 0. 
13 | r_mat_inv = sp.diags(r_inv) 14 | mx = r_mat_inv.dot(mx) 15 | return mx 16 | 17 | 18 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 19 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 20 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 21 | indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 22 | values = torch.from_numpy(sparse_mx.data) 23 | shape = torch.Size(sparse_mx.shape) 24 | return torch.sparse.FloatTensor(indices, values, shape) 25 | 26 | 27 | def adj_mx_from_edges(num_pts, edges, sparse=True): 28 | edges = np.array(edges, dtype=np.int32) 29 | data, i, j = np.ones(edges.shape[0]), edges[:, 0], edges[:, 1] 30 | adj_mx = sp.coo_matrix((data, (i, j)), shape=(num_pts, num_pts), dtype=np.float32) 31 | 32 | # build symmetric adjacency matrix 33 | adj_mx = adj_mx + adj_mx.T.multiply(adj_mx.T > adj_mx) - adj_mx.multiply(adj_mx.T > adj_mx) 34 | adj_mx = normalize(adj_mx + sp.eye(adj_mx.shape[0])) 35 | if sparse: 36 | adj_mx = sparse_mx_to_torch_sparse_tensor(adj_mx) 37 | else: 38 | adj_mx = torch.tensor(adj_mx.todense(), dtype=torch.float) 39 | 40 | #adj_mx = adj_mx * (1-torch.eye(adj_mx.shape[0])) + torch.eye(adj_mx.shape[0]) 41 | 42 | return adj_mx 43 | 44 | 45 | def adj_mx_from_skeleton(skeleton): 46 | num_joints = skeleton.num_joints() 47 | edges = list(filter(lambda x: x[1] >= 0, zip(list(range(0, num_joints)), skeleton.parents()))) 48 | 49 | adj = adj_mx_from_edges(num_joints, edges, sparse=False) 50 | 51 | return adj 52 | -------------------------------------------------------------------------------- /models/gconv/vanilla_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class DecoupleVanillaGraphConv(nn.Module): 9 | """ 10 | Vanilla graph convolution layer 11 | """ 12 | 13 | def __init__(self, in_features, out_features, adj, decouple=True, bias=True): 14 | super(DecoupleVanillaGraphConv, self).__init__() 15 | self.decouple = decouple 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | 19 | if decouple: 20 | self.W = nn.Parameter(torch.zeros(size=(2, in_features, out_features), dtype=torch.float)) 21 | else: 22 | self.W = nn.Parameter(torch.zeros(size=(1, in_features, out_features), dtype=torch.float)) 23 | 24 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 25 | 26 | self.adj = adj 27 | 28 | if bias: 29 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 30 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 31 | self.bias.data.uniform_(-stdv, stdv) 32 | else: 33 | self.register_parameter('bias', None) 34 | 35 | def forward(self, input): 36 | adj = self.adj[None, :].to(input.device) 37 | 38 | if self.decouple: 39 | h0 = torch.matmul(input, self.W[0]) 40 | h1 = torch.matmul(input, self.W[1]) 41 | 42 | E = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 43 | output = torch.matmul(adj * E, h0) + torch.matmul(adj * (1 - E), h1) 44 | else: 45 | h0 = torch.matmul(input, self.W[0]) 46 | output = torch.matmul(adj, h0) 47 | 48 | if self.bias is not None: 49 | return output + self.bias.view(1, 1, -1) 50 | else: 51 | return output 52 | 53 | def __repr__(self): 54 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /models/gconv/modulated_gcn_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ModulatedGraphConv(nn.Module): 9 | """ 10 | Modulated graph convolution layer 11 | """ 12 | 13 | def __init__(self, in_features, out_features, adj, bias=True): 14 | super(ModulatedGraphConv, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | 18 | self.W = nn.Parameter(torch.zeros(size=(2, in_features, out_features), dtype=torch.float)) 19 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 20 | 21 | self.M = nn.Parameter(torch.ones(size=(adj.size(0), out_features), dtype=torch.float)) 22 | 23 | self.adj = adj 24 | self.adj2 = nn.Parameter(torch.ones_like(adj)) 25 | nn.init.constant_(self.adj2, 1e-6) 26 | 27 | if bias: 28 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 29 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 30 | self.bias.data.uniform_(-stdv, stdv) 31 | else: 32 | self.register_parameter('bias', None) 33 | 34 | def forward(self, input): 35 | h0 = torch.matmul(input, self.W[0]) 36 | h1 = torch.matmul(input, self.W[1]) 37 | 38 | # add modulation 39 | adj = self.adj.to(input.device) + self.adj2.to(input.device) 40 | 41 | # symmetry modulation 42 | adj = (adj.T + adj)/2 43 | 44 | # mix modulation 45 | E = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 46 | output = torch.matmul(adj * E, self.M * h0) + torch.matmul(adj * (1 - E), self.M * h1) 47 | if self.bias is not None: 48 | return output + self.bias.view(1, 1, -1) 49 | else: 50 | return output 51 | 52 | def __repr__(self): 53 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 54 | -------------------------------------------------------------------------------- /models/gconv/pre_agg_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class DecouplePreAggGraphConv(nn.Module): 9 | """ 10 | Pre-aggregation graph convolution layer 11 | """ 12 | 13 | def __init__(self, in_features, out_features, adj, decouple=True, bias=True): 14 | super(DecouplePreAggGraphConv, self).__init__() 15 | self.decouple = decouple 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.n_pts = adj.size(1) 19 | 20 | if decouple: 21 | self.W = nn.Parameter(torch.zeros(size=(2, self.n_pts, in_features, out_features), dtype=torch.float)) 22 | else: 23 | self.W = nn.Parameter(torch.zeros(size=(1, self.n_pts, in_features, out_features), dtype=torch.float)) 24 | 25 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 26 | 27 | self.adj = adj 28 | 29 | if bias: 30 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 31 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 32 | self.bias.data.uniform_(-stdv, stdv) 33 | else: 34 | self.register_parameter('bias', None) 35 | 36 | def forward(self, input): 37 | adj = self.adj[None, :].to(input.device) 38 | 39 | if self.decouple: 40 | h0 = torch.einsum('bjn,jnm->bjm', input, self.W[0]) 41 | h1 = torch.einsum('bjn,jnm->bjm', input, self.W[1]) 42 | E = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 43 | output = torch.matmul(adj * E, h0) + torch.matmul(adj * (1 - E), h1) 44 | else: 45 | h0 = torch.einsum('bjn,jnm->bjm', input, self.W[0]) 46 | output = torch.matmul(adj, h0) 47 | if self.bias is not None: 48 | return output + self.bias.view(1, 1, -1) 49 | else: 50 | return output 51 | 52 | def __repr__(self): 53 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /models/gconv/post_agg_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class DecouplePostAggGraphConv(nn.Module): 9 | """ 10 | Post-aggregation graph convolution layer 11 | """ 12 | 13 | def __init__(self, in_features, out_features, adj, decouple=True, bias=True): 14 | super(DecouplePostAggGraphConv, self).__init__() 15 | self.decouple = decouple 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.n_pts = adj.size(1) 19 | 20 | if decouple: 21 | self.W = nn.Parameter(torch.zeros(size=(2, self.n_pts, in_features, out_features), dtype=torch.float)) 22 | else: 23 | self.W = nn.Parameter(torch.zeros(size=(1, self.n_pts, in_features, out_features), dtype=torch.float)) 24 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 25 | 26 | self.adj = adj 27 | 28 | if bias: 29 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 30 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 31 | self.bias.data.uniform_(-stdv, stdv) 32 | else: 33 | self.register_parameter('bias', None) 34 | 35 | def forward(self, input): 36 | adj = self.adj[None, :].to(input.device) 37 | 38 | if self.decouple: 39 | E = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 40 | 41 | h0 = torch.matmul(adj * E, input) 42 | h1 = torch.matmul(adj * (1 - E), input) 43 | 44 | output = torch.einsum('bjn,jnm->bjm', h0, self.W[0]) + torch.einsum('bjn,jnm->bjm', h1, self.W[1]) 45 | else: 46 | h0 = torch.matmul(self.adj, input) 47 | output = torch.einsum('bjn,jnm->bjm', h0, self.W) 48 | 49 | if self.bias is not None: 50 | return output + self.bias.view(1, 1, -1) 51 | else: 52 | return output 53 | 54 | def __repr__(self): 55 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | 26 | def lr_decay(optimizer, step, lr, decay_step, gamma): 27 | lr = lr * gamma ** (step / decay_step) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = lr 30 | return lr 31 | 32 | 33 | def save_ckpt(state, ckpt_path, suffix=None): 34 | if suffix is None: 35 | suffix = 'epoch_{:04d}'.format(state['epoch']) 36 | 37 | file_path = os.path.join(ckpt_path, 'ckpt_{}.pth.tar'.format(suffix)) 38 | torch.save(state, file_path) 39 | 40 | 41 | def wrap(func, unsqueeze, *args): 42 | """ 43 | Wrap a torch function so it can be called with NumPy arrays. 44 | Input and return types are seamlessly converted. 
45 | """ 46 | 47 | # Convert input types where applicable 48 | args = list(args) 49 | for i, arg in enumerate(args): 50 | if type(arg) == np.ndarray: 51 | args[i] = torch.from_numpy(arg) 52 | if unsqueeze: 53 | args[i] = args[i].unsqueeze(0) 54 | 55 | result = func(*args) 56 | 57 | # Convert output types where applicable 58 | if isinstance(result, tuple): 59 | result = list(result) 60 | for i, res in enumerate(result): 61 | if type(res) == torch.Tensor: 62 | if unsqueeze: 63 | res = res.squeeze(0) 64 | result[i] = res.numpy() 65 | return tuple(result) 66 | elif type(result) == torch.Tensor: 67 | if unsqueeze: 68 | result = result.squeeze(0) 69 | return result.numpy() 70 | else: 71 | return result 72 | -------------------------------------------------------------------------------- /models/gconv/sem_ch_graph_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SemCHGraphConv(nn.Module): 10 | """ 11 | Semantic channel-wise graph convolution layer 12 | """ 13 | 14 | def __init__(self, in_features, out_features, adj, bias=True): 15 | super(SemCHGraphConv, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | 19 | self.W = nn.Parameter(torch.zeros(size=(2, in_features, out_features), dtype=torch.float)) 20 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 21 | 22 | self.adj = adj.unsqueeze(0).repeat(out_features, 1, 1) 23 | self.m = (self.adj > 0) 24 | self.e = nn.Parameter(torch.zeros(out_features, len(self.m[0].nonzero()), dtype=torch.float)) 25 | nn.init.constant_(self.e.data, 1) 26 | 27 | if bias: 28 | self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float)) 29 | stdv = 1. 
/ math.sqrt(self.W.size(1)) 30 | self.bias.data.uniform_(-stdv, stdv) 31 | else: 32 | self.register_parameter('bias', None) 33 | 34 | def forward(self, input): 35 | h0 = torch.matmul(input, self.W[0]).unsqueeze(1).transpose(1, 3) # B * C * J * 1 36 | h1 = torch.matmul(input, self.W[1]).unsqueeze(1).transpose(1, 3) # B * C * J * 1 37 | 38 | adj = -9e15 * torch.ones_like(self.adj).to(input.device) # C * J * J 39 | adj[self.m] = self.e.view(-1) 40 | adj = F.softmax(adj, dim=2) 41 | 42 | E = torch.eye(adj.size(1), dtype=torch.float).to(input.device) 43 | E = E.unsqueeze(0).repeat(self.out_features, 1, 1) # C * J * J 44 | output = torch.matmul(adj * E, h0) + torch.matmul(adj * (1 - E), h1) 45 | output = output.transpose(1, 3).squeeze(1) 46 | 47 | if self.bias is not None: 48 | return output + self.bias.view(1, 1, -1) 49 | else: 50 | return output 51 | 52 | def __repr__(self): 53 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | 5 | 6 | class Skeleton(object): 7 | def __init__(self, parents, joints_left, joints_right, joints_group=None): 8 | assert len(joints_left) == len(joints_right) 9 | 10 | self._parents = np.array(parents) 11 | self._joints_left = joints_left 12 | self._joints_right = joints_right 13 | self._joints_group = joints_group 14 | self._compute_metadata() 15 | 16 | def num_joints(self): 17 | return len(self._parents) 18 | 19 | def parents(self): 20 | return self._parents 21 | 22 | def has_children(self): 23 | return self._has_children 24 | 25 | def children(self): 26 | return self._children 27 | 28 | def remove_joints(self, joints_to_remove): 29 | """ 30 | Remove the joints specified in 'joints_to_remove'. 
31 | """ 32 | valid_joints = [] 33 | for joint in range(len(self._parents)): 34 | if joint not in joints_to_remove: 35 | valid_joints.append(joint) 36 | 37 | for i in range(len(self._parents)): 38 | while self._parents[i] in joints_to_remove: 39 | self._parents[i] = self._parents[self._parents[i]] 40 | 41 | index_offsets = np.zeros(len(self._parents), dtype=int) 42 | new_parents = [] 43 | for i, parent in enumerate(self._parents): 44 | if i not in joints_to_remove: 45 | new_parents.append(parent - index_offsets[parent]) 46 | else: 47 | index_offsets[i:] += 1 48 | self._parents = np.array(new_parents) 49 | 50 | if self._joints_left is not None: 51 | new_joints_left = [] 52 | for joint in self._joints_left: 53 | if joint in valid_joints: 54 | new_joints_left.append(joint - index_offsets[joint]) 55 | self._joints_left = new_joints_left 56 | if self._joints_right is not None: 57 | new_joints_right = [] 58 | for joint in self._joints_right: 59 | if joint in valid_joints: 60 | new_joints_right.append(joint - index_offsets[joint]) 61 | self._joints_right = new_joints_right 62 | 63 | self._compute_metadata() 64 | 65 | return valid_joints 66 | 67 | def joints_left(self): 68 | return self._joints_left 69 | 70 | def joints_right(self): 71 | return self._joints_right 72 | 73 | def joints_group(self): 74 | return self._joints_group 75 | 76 | def _compute_metadata(self): 77 | self._has_children = np.zeros(len(self._parents)).astype(bool) 78 | for i, parent in enumerate(self._parents): 79 | if parent != -1: 80 | self._has_children[parent] = True 81 | 82 | self._children = [] 83 | for i, parent in enumerate(self._parents): 84 | self._children.append([]) 85 | for i, parent in enumerate(self._parents): 86 | if parent != -1: 87 | self._children[parent].append(i) 88 | -------------------------------------------------------------------------------- /progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | -------------------------------------------------------------------------------- /progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /common/data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import numpy as np 4 | 5 | from .camera import world_to_camera, normalize_screen_coordinates 6 | 7 | 8 | def create_2d_data(data_path, dataset): 9 | keypoints = np.load(data_path, allow_pickle=True) 10 | keypoints = keypoints['positions_2d'].item() 11 | 12 | for subject in keypoints.keys(): 13 | for action in keypoints[subject]: 14 | #print(action) 15 | for cam_idx, kps in enumerate(keypoints[subject][action]): 16 | # Normalize camera frame 17 | cam = dataset.cameras()[subject][cam_idx] 18 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h']) 19 | keypoints[subject][action][cam_idx] = kps 20 | 21 | return keypoints 22 | 23 | 24 | def read_3d_data(dataset): 25 | for subject in dataset.subjects(): 26 | for action in dataset[subject].keys(): 27 | anim = dataset[subject][action] 28 | 29 | positions_3d = [] 30 | for cam in anim['cameras']: 31 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 32 | pos_3d[:, :] -= pos_3d[:, :1] # Remove global offset 33 | positions_3d.append(pos_3d) 34 | anim['positions_3d'] = positions_3d 35 | 36 | return dataset 37 | 38 | 39 | def fetch(subjects, dataset, keypoints, action_filter=None, stride=1, parse_3d_poses=True): 40 | out_poses_3d = [] 41 | 
out_poses_2d = [] 42 | out_actions = [] 43 | out_camera_params = [] 44 | 45 | for subject in subjects: 46 | for action in keypoints[subject].keys(): 47 | if action_filter is not None: 48 | found = False 49 | for a in action_filter: 50 | # if action.startswith(a): 51 | if action.split(' ')[0] == a: 52 | found = True 53 | break 54 | if not found: 55 | continue 56 | 57 | cams = dataset.cameras()[subject] 58 | poses_2d = keypoints[subject][action] 59 | for i in range(len(poses_2d)): # Iterate across cameras 60 | out_poses_2d.append(poses_2d[i]) 61 | out_actions.append([action.split(' ')[0]] * poses_2d[i].shape[0]) 62 | out_camera_params.append([cams[i]['intrinsic']] * poses_2d[i].shape[0]) 63 | 64 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]: 65 | poses_3d = dataset[subject][action]['positions_3d'] 66 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 67 | for i in range(len(poses_3d)): # Iterate across cameras 68 | out_poses_3d.append(poses_3d[i]) 69 | 70 | 71 | if len(out_poses_3d) == 0: 72 | out_poses_3d = None 73 | 74 | if stride > 1: 75 | # Downsample as requested 76 | for i in range(len(out_poses_2d)): 77 | out_poses_2d[i] = out_poses_2d[i][::stride] 78 | out_actions[i] = out_actions[i][::stride] 79 | if out_poses_3d is not None: 80 | out_poses_3d[i] = out_poses_3d[i][::stride] 81 | 82 | return out_poses_3d, out_poses_2d, out_actions, out_camera_params 83 | -------------------------------------------------------------------------------- /progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /common/log.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | 12 | def savefig(fname, dpi=None): 13 | dpi = 150 if dpi == None else dpi 14 | plt.savefig(fname, dpi=dpi) 15 | 16 | 17 | def plot_overlap(logger, names=None): 18 | names = logger.names if names == None else names 19 | numbers = logger.numbers 20 | for _, name in enumerate(names): 21 | x = np.arange(len(numbers[name])) 22 | plt.plot(x, np.asarray(numbers[name])) 23 | return [logger.title + '(' + name + ')' for name in names] 24 | 25 | 26 | class Logger(object): 27 | '''Save training process to log file with simple plot function.''' 28 | 29 | def __init__(self, fpath, title=None, resume=False): 30 | self.file = None 31 | self.resume = resume 32 | self.title = '' if title == None else title 33 | if fpath is not None: 34 | if resume: 35 | 
self.file = open(fpath, 'r') 36 | name = self.file.readline() 37 | self.names = name.rstrip().split('\t') 38 | self.numbers = {} 39 | for _, name in enumerate(self.names): 40 | self.numbers[name] = [] 41 | 42 | for numbers in self.file: 43 | numbers = numbers.rstrip().split('\t') 44 | for i in range(0, len(numbers)): 45 | self.numbers[self.names[i]].append(numbers[i]) 46 | self.file.close() 47 | self.file = open(fpath, 'a') 48 | else: 49 | self.file = open(fpath, 'w') 50 | 51 | def set_names(self, names): 52 | if self.resume: 53 | pass 54 | # initialize numbers as empty list 55 | self.numbers = {} 56 | self.names = names 57 | for _, name in enumerate(self.names): 58 | self.file.write(name) 59 | self.file.write('\t') 60 | self.numbers[name] = [] 61 | self.file.write('\n') 62 | self.file.flush() 63 | 64 | def append(self, numbers): 65 | assert len(self.names) == len(numbers), 'Numbers do not match names' 66 | for index, num in enumerate(numbers): 67 | self.file.write("{0:.6f}".format(num)) 68 | self.file.write('\t') 69 | self.numbers[self.names[index]].append(num) 70 | self.file.write('\n') 71 | self.file.flush() 72 | 73 | def plot(self, names=None): 74 | names = self.names if names == None else names 75 | numbers = self.numbers 76 | for _, name in enumerate(names): 77 | x = np.arange(len(numbers[name])) 78 | plt.plot(x, np.asarray(numbers[name])) 79 | plt.legend([self.title + '(' + name + ')' for name in names]) 80 | plt.grid(True) 81 | 82 | def close(self): 83 | if self.file is not None: 84 | self.file.close() 85 | 86 | 87 | class LoggerMonitor(object): 88 | '''Load and visualize multiple logs.''' 89 | 90 | def __init__(self, paths): 91 | '''paths is a distionary with {name:filepath} pair''' 92 | self.loggers = [] 93 | for title, path in paths.items(): 94 | logger = Logger(path, title=title, resume=True) 95 | self.loggers.append(logger) 96 | 97 | def plot(self, names=None): 98 | plt.figure() 99 | plt.subplot(121) 100 | legend_text = [] 101 | for logger in self.loggers: 102 | legend_text += plot_overlap(logger, names) 103 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 104 | plt.grid(True) 105 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Dataset setup 2 | 3 | ## Human3.6M 4 | We provide two ways to set up the Human3.6M dataset on our pipeline. You can either use the [dataset preprocessed by Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) (fastest way) or convert the original dataset from scratch. The two methods produce the same result. After this step, you should end up with two files in the `data` directory: `data_3d_h36m.npz` for the 3D poses, and `data_2d_h36m_gt.npz` for the ground-truth 2D poses. 5 | 6 | ### Setup from preprocessed dataset 7 | Download the [h36m.zip archive](https://www.dropbox.com/s/e35qv3n6zlkouki/h36m.zip) (source: [3D pose baseline repository](https://github.com/una-dinosauria/3d-pose-baseline)) to the `data` directory, and run the conversion script from the same directory. This step does not require any additional dependency. 8 | 9 | ```sh 10 | cd data 11 | wget https://www.dropbox.com/s/e35qv3n6zlkouki/h36m.zip 12 | python prepare_data_h36m.py --from-archive h36m.zip 13 | cd .. 14 | ``` 15 | 16 | If the Dropbox link does not work, please download the dataset from [Google Drive](https://drive.google.com/drive/folders/1c7Iz6Tt7qbaw0c1snKgcGOD-JGSzuZ4X?usp=sharing). 
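After the conversion finishes, you can optionally sanity-check the generated archives with NumPy. The snippet below is only a sketch: `positions_2d` is the key that `common/data_utils.py` reads, while the `positions_3d` key for the 3D file is an assumption and may differ in your copy of the data.

```python
import numpy as np

# Hypothetical sanity check for the prepared archives (run from the `data` directory).
# 'positions_2d' matches what common/data_utils.py loads; 'positions_3d' is assumed here.
poses_3d = np.load('data_3d_h36m.npz', allow_pickle=True)['positions_3d'].item()
poses_2d = np.load('data_2d_h36m_gt.npz', allow_pickle=True)['positions_2d'].item()

for subject in ('S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'):
    assert subject in poses_3d, 'missing 3D poses for ' + subject
    assert subject in poses_2d, 'missing 2D poses for ' + subject
print('Subjects found:', sorted(poses_2d.keys()))
```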
17 | 18 | ### Setup from original source 19 | Alternatively, you can download the dataset from the [Human3.6m website](http://vision.imar.ro/human3.6m/) and convert it from its original format. This is useful if the other link goes down, or if you want to be sure to use the original source. MATLAB is required for this step. 20 | 21 | First, we need to convert the 3D poses from `.cdf` to `.mat`, so they can be loaded from Python scripts. To this end, we have provided the MATLAB script `convert_cdf_to_mat.m` in the `data` directory. Extract the archives named `Poses_D3_Positions_S*.tgz` (subjects 1, 5, 6, 7, 8, 9, 11) to a directory named `pose`, and set up your directory tree so that it looks like this: 22 | 23 | ``` 24 | /path/to/dataset/convert_cdf_to_mat.m 25 | /path/to/dataset/pose/S1/MyPoseFeatures/D3_Positions/Directions 1.cdf 26 | /path/to/dataset/pose/S1/MyPoseFeatures/D3_Positions/Directions.cdf 27 | ... 28 | ``` 29 | Then run `convert_cdf_to_mat.m` from MATLAB. 30 | 31 | Finally, as before, run the Python conversion script specifying the dataset path: 32 | ```sh 33 | cd data 34 | python prepare_data_h36m.py --from-source /path/to/dataset/pose 35 | cd .. 36 | ``` 37 | 38 | ## 2D detections for Human3.6M 39 | We provide support for the following 2D detections: 40 | 41 | - `gt`: ground-truth 2D poses, extracted through the camera projection parameters. 42 | - `sh_pt_mpii`: Stacked Hourglass detections, pretrained on MPII. 43 | - `sh_ft_h36m`: Stacked Hourglass detections, fine-tuned on Human3.6M. 44 | 45 | The 2D detection source is specified through the `--keypoints` parameter, which loads the file `data_2d_${DATASET}_${DETECTION}.npz` from the `data` directory, where `DATASET` is the dataset name (e.g., `h36m`) and `DETECTION` is the 2D detection source (e.g., `sh_pt_mpii`). Since all the files are encoded according to the same format, it is trivial to create a custom set of 2D detections. 46 | 47 | Ground-truth poses (`gt`) have already been extracted by the previous step. The other detections must be downloaded manually (see instructions below). You only need to download the detections you want to use. 48 | 49 | ### Stacked Hourglass detections 50 | These detections (both pretrained and fine-tuned) are provided by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) in their repository on 3D human pose estimation. The 2D poses produced by the pretrained model are in the same archive as the dataset ([h36m.zip](https://www.dropbox.com/s/e35qv3n6zlkouki/h36m.zip)). The fine-tuned poses can be downloaded [here](https://drive.google.com/open?id=0BxWzojlLp259S2FuUXJ6aUNxZkE). Put the two archives in the `data` directory and run: 51 | 52 | ```sh 53 | cd data 54 | python prepare_data_2d_h36m_sh.py -pt h36m.zip 55 | python prepare_data_2d_h36m_sh.py -ft stacked_hourglass_fined_tuned_240.tar.gz 56 | cd .. 57 | ``` 58 | -------------------------------------------------------------------------------- /common/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def mpjpe(predicted, target): 8 | """ 9 | Mean per-joint position error (i.e. mean Euclidean distance), 10 | often referred to as "Protocol #1" in many papers. 
11 | """ 12 | assert predicted.shape == target.shape 13 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1)) 14 | 15 | 16 | def weighted_mpjpe(predicted, target, w): 17 | """ 18 | Weighted mean per-joint position error (i.e. mean Euclidean distance) 19 | """ 20 | assert predicted.shape == target.shape 21 | assert w.shape[0] == predicted.shape[0] 22 | return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape) - 1)) 23 | 24 | 25 | def p_mpjpe(predicted, target): 26 | """ 27 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation), 28 | often referred to as "Protocol #2" in many papers. 29 | """ 30 | assert predicted.shape == target.shape 31 | 32 | muX = np.mean(target, axis=1, keepdims=True) 33 | muY = np.mean(predicted, axis=1, keepdims=True) 34 | 35 | X0 = target - muX 36 | Y0 = predicted - muY 37 | 38 | normX = np.sqrt(np.sum(X0 ** 2, axis=(1, 2), keepdims=True)) 39 | normY = np.sqrt(np.sum(Y0 ** 2, axis=(1, 2), keepdims=True)) 40 | 41 | X0 /= normX 42 | Y0 /= normY 43 | 44 | H = np.matmul(X0.transpose(0, 2, 1), Y0) 45 | U, s, Vt = np.linalg.svd(H) 46 | V = Vt.transpose(0, 2, 1) 47 | R = np.matmul(V, U.transpose(0, 2, 1)) 48 | 49 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 50 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) 51 | V[:, :, -1] *= sign_detR 52 | s[:, -1] *= sign_detR.flatten() 53 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 54 | 55 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) 56 | 57 | a = tr * normX / normY # Scale 58 | t = muX - a * np.matmul(muY, R) # Translation 59 | 60 | # Perform rigid transformation on the input 61 | predicted_aligned = a * np.matmul(predicted, R) + t 62 | 63 | # Return MPJPE 64 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape) - 1)) 65 | 66 | 67 | def n_mpjpe(predicted, target): 68 | """ 69 | Normalized MPJPE (scale only), adapted from: 70 | https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py 71 | """ 72 | assert predicted.shape == target.shape 73 | 74 | norm_predicted = torch.mean(torch.sum(predicted ** 2, dim=3, keepdim=True), dim=2, keepdim=True) 75 | norm_target = torch.mean(torch.sum(target * predicted, dim=3, keepdim=True), dim=2, keepdim=True) 76 | scale = norm_target / norm_predicted 77 | return mpjpe(scale * predicted, target) 78 | 79 | 80 | def mean_velocity_error(predicted, target): 81 | """ 82 | Mean per-joint velocity error (i.e. 
mean Euclidean distance of the 1st derivative) 83 | """ 84 | assert predicted.shape == target.shape 85 | 86 | velocity_predicted = np.diff(predicted, axis=0) 87 | velocity_target = np.diff(target, axis=0) 88 | 89 | return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape) - 1)) 90 | 91 | 92 | def sym_penalty(pred_out): 93 | """ 94 | get penalty for the symmetry of human body 95 | :return: 96 | """ 97 | loss_sym = 0 98 | left_bone = [(0, 4), (4, 5), (5, 6), (8, 10), (10, 11), (11, 12)] 99 | right_bone = [(0, 1), (1, 2), (2, 3), (8, 13), (13, 14), (14, 15)] 100 | 101 | for (i_left, j_left), (i_right, j_right) in zip(left_bone, right_bone): 102 | left_part = pred_out[:, i_left] - pred_out[:, j_left] 103 | right_part = pred_out[:, i_right] - pred_out[:, j_right] 104 | loss_sym += torch.mean(torch.norm(left_part, dim=- 1) - torch.norm(right_part, dim=- 1)) 105 | 106 | return loss_sym 107 | 108 | -------------------------------------------------------------------------------- /common/camera.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from common.utils import wrap 7 | from common.quaternion import qrot, qinverse 8 | 9 | 10 | def normalize_screen_coordinates(X, w, h): 11 | assert X.shape[-1] == 2 12 | 13 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio 14 | return X / w * 2 - [1, h / w] 15 | 16 | 17 | def image_coordinates(X, w, h): 18 | assert X.shape[-1] == 2 19 | 20 | # Reverse camera frame normalization 21 | return (X + [1, h / w]) * w / 2 22 | 23 | 24 | def world_to_camera(X, R, t): 25 | Rt = wrap(qinverse, False, R) # Invert rotation 26 | return wrap(qrot, False, np.tile(Rt, X.shape[:-1] + (1,)), X - t) # Rotate and translate 27 | 28 | 29 | def camera_to_world(X, R, t): 30 | return wrap(qrot, False, np.tile(R, X.shape[:-1] + (1,)), X) + t 31 | 32 | 33 | def project_to_2d(X, camera_params): 34 | """ 35 | Project 3D points to 2D using the Human3.6M camera projection function. 36 | This is a differentiable and batched reimplementation of the original MATLAB script. 37 | 38 | Arguments: 39 | X -- 3D points in *camera space* to transform (N, *, 3) 40 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 41 | """ 42 | assert X.shape[-1] == 3 43 | assert len(camera_params.shape) == 2 44 | assert camera_params.shape[-1] == 9 45 | assert X.shape[0] == camera_params.shape[0] 46 | 47 | while len(camera_params.shape) < len(X.shape): 48 | camera_params = camera_params.unsqueeze(1) 49 | 50 | f = camera_params[..., :2] 51 | c = camera_params[..., 2:4] 52 | k = camera_params[..., 4:7] 53 | p = camera_params[..., 7:] 54 | 55 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 56 | r2 = torch.sum(XX[..., :2] ** 2, dim=len(XX.shape) - 1, keepdim=True) 57 | 58 | radial = 1 + torch.sum(k * torch.cat((r2, r2 ** 2, r2 ** 3), dim=len(r2.shape) - 1), dim=len(r2.shape) - 1, 59 | keepdim=True) 60 | tan = torch.sum(p * XX, dim=len(XX.shape) - 1, keepdim=True) 61 | 62 | XXX = XX * (radial + tan) + p * r2 63 | 64 | return f * XXX + c 65 | 66 | 67 | def project_to_2d_linear(X, camera_params): 68 | """ 69 | Project 3D points to 2D using only linear parameters (focal length and principal point). 
70 | 71 | Arguments: 72 | X -- 3D points in *camera space* to transform (N, *, 3) 73 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 74 | """ 75 | assert X.shape[-1] == 3 76 | assert len(camera_params.shape) == 2 77 | assert camera_params.shape[-1] == 9 78 | assert X.shape[0] == camera_params.shape[0] 79 | 80 | while len(camera_params.shape) < len(X.shape): 81 | camera_params = camera_params.unsqueeze(1) 82 | 83 | f = camera_params[..., :2] 84 | c = camera_params[..., 2:4] 85 | 86 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 87 | 88 | return f * XX + c 89 | 90 | 91 | def get_uvd2xyz(uvd, gt_3D, cam): 92 | """ 93 | transfer uvd to xyz 94 | :param uvd: N*T*V*3 (uv and z channel) 95 | :param gt_3D: N*T*V*3 (NOTE: V=0 is absolute depth value of root joint) 96 | :return: root-relative xyz results 97 | """ 98 | N, T, V, _ = uvd.size() 99 | 100 | dec_out_all = uvd.view(-1, T, V, 3).clone() # N*T*V*3 101 | root = gt_3D[:, :, 0, :].unsqueeze(-2).repeat(1, 1, V, 1).clone() # N*T*V*3 102 | enc_in_all = uvd[:, :, :, :2].view(-1, T, V, 2).clone() # N*T*V*2 103 | 104 | cam_f_all = cam[..., :2].view(-1, 1, 1, 2).repeat(1, T, V, 1) # N*T*V*2 105 | cam_c_all = cam[..., 2:4].view(-1, 1, 1, 2).repeat(1, T, V, 1) # N*T*V*2 106 | 107 | # change to global 108 | z_global = dec_out_all[:, :, :, 2] # N*T*V 109 | z_global[:, :, 0] = root[:, :, 0, 2] 110 | z_global[:, :, 1:] = dec_out_all[:, :, 1:, 2] + root[:, :, 1:, 2] # N*T*V 111 | z_global = z_global.unsqueeze(-1) # N*T*V*1 112 | 113 | uv = enc_in_all - cam_c_all # N*T*V*2 114 | xy = uv * z_global.repeat(1, 1, 1, 2) / cam_f_all # N*T*V*2 115 | xyz_global = torch.cat((xy, z_global), -1) # N*T*V*3 116 | xyz_offset = (xyz_global - xyz_global[:, :, 0, :].unsqueeze(-2).repeat(1, 1, V, 1)) # N*T*V*3 117 | 118 | return xyz_offset -------------------------------------------------------------------------------- /models/graph_non_local.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class _NonLocalBlock(nn.Module): 8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=1, bn_layer=True): 9 | super(_NonLocalBlock, self).__init__() 10 | 11 | assert dimension in [1, 2, 3] 12 | 13 | self.dimension = dimension 14 | self.sub_sample = sub_sample 15 | 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | 22 | assert self.inter_channels > 0 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | max_pool = nn.MaxPool3d 27 | bn = nn.BatchNorm3d 28 | elif dimension == 2: 29 | conv_nd = nn.Conv2d 30 | max_pool = nn.MaxPool2d 31 | bn = nn.BatchNorm2d 32 | elif dimension == 1: 33 | conv_nd = nn.Conv1d 34 | max_pool = nn.MaxPool1d 35 | bn = nn.BatchNorm1d 36 | else: 37 | raise Exception('Error feature dimension.') 38 | 39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, stride=1, padding=0) 43 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 44 | kernel_size=1, stride=1, padding=0) 45 | 46 | self.concat_project = nn.Sequential( 47 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 48 | nn.ReLU() 49 | ) 50 | 51 | 
nn.init.kaiming_normal_(self.concat_project[0].weight) 52 | nn.init.kaiming_normal_(self.g.weight) 53 | nn.init.constant_(self.g.bias, 0) 54 | nn.init.kaiming_normal_(self.theta.weight) 55 | nn.init.constant_(self.theta.bias, 0) 56 | nn.init.kaiming_normal_(self.phi.weight) 57 | nn.init.constant_(self.phi.bias, 0) 58 | 59 | if bn_layer: 60 | self.W = nn.Sequential( 61 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 62 | kernel_size=1, stride=1, padding=0), 63 | bn(self.in_channels) 64 | ) 65 | nn.init.kaiming_normal_(self.W[0].weight) 66 | nn.init.constant_(self.W[0].bias, 0) 67 | nn.init.constant_(self.W[1].weight, 0) 68 | nn.init.constant_(self.W[1].bias, 0) 69 | else: 70 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 71 | kernel_size=1, stride=1, padding=0) 72 | nn.init.constant_(self.W.weight, 0) 73 | nn.init.constant_(self.W.bias, 0) 74 | 75 | if sub_sample > 1: 76 | self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample)) 77 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample)) 78 | 79 | def forward(self, x): 80 | batch_size = x.size(0) # x: (b, c, t, h, w) 81 | 82 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 83 | g_x = g_x.permute(0, 2, 1) 84 | 85 | # (b, c, N, 1) 86 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 87 | # (b, c, 1, N) 88 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 89 | 90 | h = theta_x.size(2) 91 | w = phi_x.size(3) 92 | theta_x = theta_x.expand(-1, -1, -1, w) 93 | phi_x = phi_x.expand(-1, -1, h, -1) 94 | 95 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 96 | f = self.concat_project(concat_feature) 97 | b, _, h, w = f.size() 98 | f = f.view(b, h, w) 99 | 100 | N = f.size(-1) 101 | f_div_C = f / N 102 | 103 | y = torch.matmul(f_div_C, g_x) 104 | y = y.permute(0, 2, 1).contiguous() 105 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 106 | W_y = self.W(y) 107 | z = W_y + x 108 | 109 | return z 110 | 111 | 112 | class GraphNonLocal(_NonLocalBlock): 113 | def __init__(self, in_channels, inter_channels=None, sub_sample=1, bn_layer=True): 114 | super(GraphNonLocal, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, 115 | sub_sample=sub_sample, bn_layer=bn_layer) 116 | -------------------------------------------------------------------------------- /data/prepare_data_2d_h36m_sh.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import argparse 4 | import os 5 | import zipfile 6 | import tarfile 7 | import numpy as np 8 | import h5py 9 | from glob import glob 10 | from shutil import rmtree 11 | 12 | import sys 13 | 14 | sys.path.append('../') 15 | 16 | from common.h36m_dataset import H36M_NAMES 17 | 18 | output_filename_pt = 'data_2d_h36m_sh_pt_mpii' 19 | output_filename_ft = 'data_2d_h36m_sh_ft_h36m' 20 | subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'] 21 | cam_map = { 22 | '54138969': 0, 23 | '55011271': 1, 24 | '58860488': 2, 25 | '60457274': 3, 26 | } 27 | 28 | metadata = { 29 | 'num_joints': 16, 30 | 'keypoints_symmetry': [ 31 | [3, 4, 5, 13, 14, 15], 32 | [2, 1, 0, 12, 11, 10], 33 | ] 34 | } 35 | 36 | # Stacked Hourglass produces 16 joints. These are the names. 
37 | SH_NAMES = [''] * 16 38 | SH_NAMES[0] = 'RFoot' 39 | SH_NAMES[1] = 'RKnee' 40 | SH_NAMES[2] = 'RHip' 41 | SH_NAMES[3] = 'LHip' 42 | SH_NAMES[4] = 'LKnee' 43 | SH_NAMES[5] = 'LFoot' 44 | SH_NAMES[6] = 'Hip' 45 | SH_NAMES[7] = 'Spine' 46 | SH_NAMES[8] = 'Thorax' 47 | SH_NAMES[9] = 'Head' 48 | SH_NAMES[10] = 'RWrist' 49 | SH_NAMES[11] = 'RElbow' 50 | SH_NAMES[12] = 'RShoulder' 51 | SH_NAMES[13] = 'LShoulder' 52 | SH_NAMES[14] = 'LElbow' 53 | SH_NAMES[15] = 'LWrist' 54 | 55 | # Permutation that goes from SH detections to H36M ordering. 56 | SH_TO_GT_PERM = np.array([SH_NAMES.index(h) for h in H36M_NAMES if h != '' and h in SH_NAMES]) 57 | assert np.all(SH_TO_GT_PERM == np.array([6, 2, 1, 0, 3, 4, 5, 7, 8, 9, 13, 14, 15, 12, 11, 10])) 58 | 59 | metadata['keypoints_symmetry'][0] = [SH_TO_GT_PERM.tolist().index(h) for h in metadata['keypoints_symmetry'][0]] 60 | metadata['keypoints_symmetry'][1] = [SH_TO_GT_PERM.tolist().index(h) for h in metadata['keypoints_symmetry'][1]] 61 | 62 | 63 | def process_subject(subject, file_list, output): 64 | if subject == 'S11': 65 | assert len(file_list) == 119, "Expected 119 files for subject " + subject + ", got " + str(len(file_list)) 66 | else: 67 | assert len(file_list) == 120, "Expected 120 files for subject " + subject + ", got " + str(len(file_list)) 68 | 69 | for f in file_list: 70 | action, cam = os.path.splitext(os.path.basename(f))[0].replace('_', ' ').split('.') 71 | 72 | if subject == 'S11' and action == 'Directions': 73 | continue # Discard corrupted video 74 | 75 | if action not in output[subject]: 76 | output[subject][action] = [None, None, None, None] 77 | 78 | with h5py.File(f) as hf: 79 | positions = hf['poses'].value 80 | positions = positions[:, SH_TO_GT_PERM, :] 81 | output[subject][action][cam_map[cam]] = positions.astype('float32') 82 | 83 | 84 | if __name__ == '__main__': 85 | if os.path.basename(os.getcwd()) != 'data': 86 | print('This script must be launched from the "data" directory') 87 | exit(0) 88 | 89 | parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter') 90 | 91 | parser.add_argument('-pt', '--pretrained', default='', type=str, metavar='PATH', help='convert pretrained dataset') 92 | parser.add_argument('-ft', '--fine-tuned', default='', type=str, metavar='PATH', help='convert fine-tuned dataset') 93 | 94 | args = parser.parse_args() 95 | 96 | if args.pretrained: 97 | print('Converting pretrained dataset from', args.pretrained) 98 | print('Extracting...') 99 | with zipfile.ZipFile(args.pretrained, 'r') as archive: 100 | archive.extractall('sh_pt') 101 | 102 | print('Converting...') 103 | output = {} 104 | for subject in subjects: 105 | output[subject] = {} 106 | file_list = glob('sh_pt/h36m/' + subject + '/StackedHourglass/*.h5') 107 | process_subject(subject, file_list, output) 108 | 109 | print('Saving...') 110 | np.savez_compressed(output_filename_pt, positions_2d=output, metadata=metadata) 111 | 112 | print('Cleaning up...') 113 | rmtree('sh_pt') 114 | 115 | print('Done.') 116 | 117 | if args.fine_tuned: 118 | print('Converting fine-tuned dataset from', args.fine_tuned) 119 | print('Extracting...') 120 | with tarfile.open(args.fine_tuned, 'r:gz') as archive: 121 | archive.extractall('sh_ft') 122 | 123 | print('Converting...') 124 | output = {} 125 | for subject in subjects: 126 | output[subject] = {} 127 | file_list = glob('sh_ft/' + subject + '/StackedHourglassFineTuned240/*.h5') 128 | process_subject(subject, file_list, output) 129 | 130 | print('Saving...') 131 | 
np.savez_compressed(output_filename_ft, positions_2d=output, metadata=metadata) 132 | 133 | print('Cleaning up...') 134 | rmtree('sh_ft') 135 | 136 | print('Done.') 137 | -------------------------------------------------------------------------------- /data/prepare_data_h36m.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import argparse 4 | import os 5 | import zipfile 6 | import numpy as np 7 | from glob import glob 8 | from shutil import rmtree 9 | 10 | import sys 11 | 12 | sys.path.append('../') 13 | 14 | from common.h36m_dataset import Human36mDataset 15 | from common.camera import world_to_camera, project_to_2d, image_coordinates 16 | from common.utils import wrap 17 | 18 | output_filename = 'data_3d_h36m' 19 | output_filename_2d = 'data_2d_h36m_gt' 20 | subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'] 21 | 22 | if __name__ == '__main__': 23 | if os.path.basename(os.getcwd()) != 'data': 24 | print('This script must be launched from the "data" directory') 25 | exit(0) 26 | 27 | parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter') 28 | 29 | # Default: convert dataset preprocessed by Martinez et al. in https://github.com/una-dinosauria/3d-pose-baseline 30 | parser.add_argument('--from-archive', default='', type=str, metavar='PATH', help='convert preprocessed dataset') 31 | 32 | # Alternatively, convert dataset from original source (the Human3.6M dataset path must be specified manually) 33 | parser.add_argument('--from-source', default='', type=str, metavar='PATH', help='convert original dataset') 34 | 35 | args = parser.parse_args() 36 | 37 | if args.from_archive and args.from_source: 38 | print('Please specify only one argument') 39 | exit(0) 40 | 41 | if os.path.exists(output_filename + '.npz'): 42 | print('The dataset already exists at', output_filename + '.npz') 43 | exit(0) 44 | 45 | if args.from_archive: 46 | print('Extracting Human3.6M dataset from', args.from_archive) 47 | with zipfile.ZipFile(args.from_archive, 'r') as archive: 48 | archive.extractall() 49 | 50 | print('Converting...') 51 | output = {} 52 | 53 | import h5py 54 | 55 | for subject in subjects: 56 | output[subject] = {} 57 | file_list = glob('h36m/' + subject + '/MyPoses/3D_positions/*.h5') 58 | assert len(file_list) == 30, "Expected 30 files for subject " + subject + ", got " + str(len(file_list)) 59 | for f in file_list: 60 | action = os.path.splitext(os.path.basename(f))[0] 61 | 62 | if subject == 'S11' and action == 'Directions': 63 | continue # Discard corrupted video 64 | 65 | with h5py.File(f) as hf: 66 | positions = hf['3D_positions'].value.reshape(32, 3, -1).transpose(2, 0, 1) 67 | positions /= 1000 # Meters instead of millimeters 68 | output[subject][action] = positions.astype('float32') 69 | 70 | print('Saving...') 71 | np.savez_compressed(output_filename, positions_3d=output) 72 | 73 | print('Cleaning up...') 74 | rmtree('h36m') 75 | 76 | print('Done.') 77 | 78 | elif args.from_source: 79 | print('Converting original Human3.6M dataset from', args.from_source) 80 | output = {} 81 | 82 | from scipy.io import loadmat 83 | 84 | for subject in subjects: 85 | output[subject] = {} 86 | file_list = glob(args.from_source + '/' + subject + '/MyPoseFeatures/D3_Positions/*.cdf.mat') 87 | assert len(file_list) == 30, "Expected 30 files for subject " + subject + ", got " + str(len(file_list)) 88 | for f in file_list: 89 | action = 
os.path.splitext(os.path.splitext(os.path.basename(f))[0])[0] 90 | 91 | if subject == 'S11' and action == 'Directions': 92 | continue # Discard corrupted video 93 | 94 | # Use consistent naming convention 95 | canonical_name = action.replace('TakingPhoto', 'Photo').replace('WalkingDog', 'WalkDog') 96 | 97 | hf = loadmat(f) 98 | positions = hf['data'][0, 0].reshape(-1, 32, 3) 99 | positions /= 1000 # Meters instead of millimeters 100 | output[subject][canonical_name] = positions.astype('float32') 101 | 102 | print('Saving...') 103 | np.savez_compressed(output_filename, positions_3d=output) 104 | 105 | print('Done.') 106 | 107 | else: 108 | print('Please specify the dataset source') 109 | exit(0) 110 | 111 | # Create 2D pose file 112 | print('') 113 | print('Computing ground-truth 2D poses...') 114 | dataset = Human36mDataset(output_filename + '.npz') 115 | output_2d_poses = {} 116 | for subject in dataset.subjects(): 117 | output_2d_poses[subject] = {} 118 | for action in dataset[subject].keys(): 119 | anim = dataset[subject][action] 120 | 121 | positions_2d = [] 122 | for cam in anim['cameras']: 123 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 124 | pos_2d = wrap(project_to_2d, True, pos_3d, cam['intrinsic']) 125 | pos_2d_pixel_space = image_coordinates(pos_2d, w=cam['res_w'], h=cam['res_h']) 126 | positions_2d.append(pos_2d_pixel_space.astype('float32')) 127 | output_2d_poses[subject][action] = positions_2d 128 | 129 | print('Saving...') 130 | metadata = { 131 | 'num_joints': dataset.skeleton().num_joints(), 132 | 'keypoints_symmetry': [dataset.skeleton().joints_left(), dataset.skeleton().joints_right()] 133 | } 134 | np.savez_compressed(output_filename_2d, positions_2d=output_2d_poses, metadata=metadata) 135 | 136 | print('Done.') 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Stacked Hourglass Network (CVPR 2021) 2 | 3 | 4 | 5 |

6 | 7 | 8 | This repository contains the PyTorch implementation of the approach described in the paper: 9 | > Tianhan Xu and Wataru Takano. 10 | [Graph Stacked Hourglass Networks for 3D Human Pose Estimation](https://arxiv.org/pdf/2103.16385.pdf) 11 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2021, pp. 16105-16114 12 | 13 | 14 | ``` 15 | @InProceedings{Xu_2021_CVPR, 16 | author = {Xu, Tianhan and Takano, Wataru}, 17 | title = {Graph Stacked Hourglass Networks for 3D Human Pose Estimation}, 18 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 19 | month = {June}, 20 | year = {2021}, 21 | pages = {16105-16114} 22 | } 23 | ``` 24 | 25 | ## Introduction 26 | 27 | We evaluate models for 3D human pose estimation on the [Human3.6M Dataset](http://vision.imar.ro/human3.6m/). 28 | 29 | In this repository, only the 2D joints of the human pose are used as inputs. We utilize the method described in Pavllo et al. [2] to normalize the 2D and 3D poses in the dataset. Specifically, 2D poses are scaled according to the image resolution and normalized to [-1, 1], and 3D poses are aligned with respect to the root joint. Please refer to the corresponding part of Pavllo et al. [2] for more details. For the 2D ground truth, we predict 16 joints (the skeleton used in Martinez et al. [1] and Zhao et al. [3], without the 'Neck/Nose' joint). For the 2D pose detections, the 'Neck/Nose' joint is preserved. 30 | 31 | 32 | ## Quickstart 33 | 34 | This repository is built upon Python v3.8 and PyTorch v1.9.0 on Ubuntu 20.04 LTS. All experiments are conducted on a single NVIDIA RTX 3090 GPU. See [`requirements.txt`](requirements.txt) for other dependencies. We recommend installing Python v3.8 from [Anaconda](https://www.anaconda.com/) and installing PyTorch (>= 1.9.0) by following the [official instructions](https://pytorch.org/) for your specific CUDA version. Then you can install the dependencies with the following commands. 35 | 36 | ``` 37 | git clone https://github.com/tamasino52/GraphSH.git 38 | cd GraphSH 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | ### Benchmark setup 43 | CPN 2D detections for the Human3.6M dataset are provided by [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) (Pavllo et al. [2]) and can be downloaded as follows: 44 | 45 | ``` 46 | cd data 47 | wget https://dl.fbaipublicfiles.com/video-pose-3d/data_2d_h36m_cpn_ft_h36m_dbb.npz 48 | wget https://dl.fbaipublicfiles.com/video-pose-3d/data_2d_h36m_detectron_ft_h36m.npz 49 | cd .. 50 | ``` 51 | 52 | 3D labels and ground truth can be downloaded as follows: 53 | ``` 54 | cd data 55 | gdown --id 1P7W3ldx2lxaYJJYcf3RG4Y9PsD4EJ6b0 56 | cd .. 57 | ``` 58 | 59 | ### GT setup 60 | 61 | GT 2D keypoints for the Human3.6M dataset are obtained from [SemGCN](https://github.com/garyzhao/SemGCN) (Zhao et al. [3]) and can be downloaded and preprocessed as follows: 62 | ``` 63 | cd data 64 | pip install gdown 65 | gdown https://drive.google.com/uc?id=1Ac-gUXAg-6UiwThJVaw6yw2151Bot3L1 66 | python prepare_data_h36m.py --from-archive h36m.zip 67 | cd .. 68 | ``` 69 | After this step, you should end up with two files in the `data` directory: `data_3d_h36m.npz` for the 3D poses, and `data_2d_h36m_gt.npz` for the ground-truth 2D poses.
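As a quick sanity check (not part of the original scripts), you can load the two generated files and inspect their structure. The snippet below is a minimal sketch that assumes the files were created in `data/` and that you run it from the repository root; the archive key names (`positions_3d`, `positions_2d`, `metadata`) and the `.item()` unwrapping mirror how `data/prepare_data_h36m.py` writes and `common/h36m_dataset.py` reads these files.

```
import numpy as np

# 3D poses: dict keyed by subject, then action; arrays are (frames, joints, 3) in meters
poses_3d = np.load('data/data_3d_h36m.npz', allow_pickle=True)['positions_3d'].item()

# Ground-truth 2D poses: same subject/action structure, one array per camera, plus metadata
data_2d = np.load('data/data_2d_h36m_gt.npz', allow_pickle=True)
poses_2d = data_2d['positions_2d'].item()
metadata = data_2d['metadata'].item()

subject = 'S1'
action = next(iter(poses_3d[subject]))  # pick an arbitrary action of this subject
print('Subjects:', sorted(poses_3d.keys()))
print('2D joints:', metadata['num_joints'])
print(subject, action, poses_3d[subject][action].shape, '|', len(poses_2d[subject][action]), 'cameras')
```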
70 | 71 | ### GT Evaluation 72 | ``` 73 | python main_gcn.py --gcn {gcn_name} --evaluate checkpoint/{weight_name}.pth.tar 74 | ``` 75 | 76 | ### GT Training 77 | ``` 78 | # Decoupled Vanilla GCN (method used in the paper) 79 | python main_gcn.py --gcn dc_vanilla 80 | 81 | # Decoupled Pre-Aggregation GCN (method used in the paper) 82 | python main_gcn.py --gcn dc_preagg 83 | 84 | # Semantic GCN (method used in the paper) 85 | python main_gcn.py --gcn semantic 86 | 87 | # Decoupled Post-Aggregation GCN 88 | python main_gcn.py --gcn dc_postagg 89 | 90 | # Convolution-style GCN 91 | python main_gcn.py --gcn convst 92 | 93 | # No-sharing GCN 94 | python main_gcn.py --gcn nosharing 95 | 96 | # Modulated GCN 97 | python main_gcn.py --gcn modulated 98 | ``` 99 | 100 | ### Training 101 | 102 | ``` 103 | # Decoupled Vanilla GCN 104 | python main_gcn.py --gcn dc_vanilla --keypoints cpn_ft_h36m_dbb 105 | 106 | # Decoupled Pre-Aggregation GCN 107 | python main_gcn.py --gcn dc_preagg --keypoints cpn_ft_h36m_dbb 108 | ``` 109 | 110 | ### Pre-trained weight 111 | I implemented and tested all of the components proposed in the paper, but I was unable to reproduce the benchmark scores it reports. Instead, a reasonably good pre-trained weight is provided below; use it if you need it. 112 | 113 | [Download Link.](https://drive.google.com/file/d/1FQpAnNyycKXgqlJ7vitFgP7KDwD365sQ/view?usp=sharing) 114 | 115 | ### Acknowledgement 116 | This code is extended from the following repositories. 117 | - [3d-pose-baseline](https://github.com/una-dinosauria/3d-pose-baseline) 118 | - [3d_pose_baseline_pytorch](https://github.com/weigq/3d_pose_baseline_pytorch) 119 | - [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) 120 | - [Semantic GCN](https://github.com/garyzhao/SemGCN) 121 | - [Local-to-Global GCN](https://github.com/vanoracai/Exploiting-Spatial-temporal-Relationships-for-3D-Pose-Estimation-via-Graph-Convolutional-Networks) 122 | - [Modulated-GCN](https://github.com/ZhimingZo/Modulated-GCN) 123 | 124 | Thanks to the authors for releasing their code. Please also consider citing their works.
125 | -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | 6 | import matplotlib.pyplot as plt 7 | from matplotlib.animation import FuncAnimation, writers 8 | from mpl_toolkits.mplot3d import Axes3D 9 | import numpy as np 10 | import subprocess as sp 11 | 12 | 13 | def get_resolution(filename): 14 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 15 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename] 16 | 17 | try: 18 | pipe = sp.Popen(command, stdout=sp.PIPE, bufsize=-1) 19 | for line in pipe.stdout: 20 | w, h = line.decode().strip().split(',') 21 | finally: 22 | pipe.stdout.close() 23 | 24 | return int(w), int(h) 25 | 26 | 27 | def read_video(filename, skip=0, limit=-1): 28 | w, h = get_resolution(filename) 29 | 30 | command = ['ffmpeg', 31 | '-i', filename, 32 | '-f', 'image2pipe', 33 | '-pix_fmt', 'rgb24', 34 | '-vsync', '0', 35 | '-vcodec', 'rawvideo', '-'] 36 | 37 | i = 0 38 | pipe = sp.Popen(command, stdout=sp.PIPE, bufsize=-1) 39 | try: 40 | while True: 41 | data = pipe.stdout.read(w * h * 3) 42 | if not data: 43 | break 44 | i += 1 45 | if i > skip: 46 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3)) 47 | if i == limit: 48 | break 49 | finally: 50 | pipe.stdout.close() 51 | 52 | 53 | def downsample_tensor(X, factor): 54 | length = X.shape[0] // factor * factor 55 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1) 56 | 57 | 58 | def render_animation(keypoints, poses, skeleton, fps, bitrate, azim, output, viewport, 59 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0): 60 | 61 | plt.ioff() 62 | fig = plt.figure(figsize=(size * (1 + len(poses)), size)) 63 | ax_in = fig.add_subplot(1, 1 + len(poses), 1) 64 | ax_in.get_xaxis().set_visible(False) 65 | ax_in.get_yaxis().set_visible(False) 66 | ax_in.set_axis_off() 67 | ax_in.set_title('Input') 68 | 69 | ax_3d = [] 70 | lines_3d = [] 71 | trajectories = [] 72 | radius = 1.7 73 | for index, (title, data) in enumerate(poses.items()): 74 | ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d') 75 | ax.view_init(elev=15., azim=azim) 76 | ax.set_xlim3d([-radius / 2, radius / 2]) 77 | ax.set_zlim3d([0, radius]) 78 | ax.set_ylim3d([-radius / 2, radius / 2]) 79 | #ax.set_aspect('equal') 80 | ax.set_xticklabels([]) 81 | ax.set_yticklabels([]) 82 | ax.set_zticklabels([]) 83 | ax.dist = 7.5 84 | ax.set_title(title) # , pad=35 85 | ax_3d.append(ax) 86 | lines_3d.append([]) 87 | trajectories.append(data[:, 0, [0, 1]]) 88 | poses = list(poses.values()) 89 | 90 | # Decode video 91 | if input_video_path is None: 92 | # Black background 93 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8') 94 | else: 95 | # Load video using ffmpeg 96 | all_frames = [] 97 | for f in read_video(input_video_path, skip=input_video_skip): 98 | all_frames.append(f) 99 | effective_length = min(keypoints.shape[0], len(all_frames)) 100 | all_frames = all_frames[:effective_length] 101 | 102 | if downsample > 1: 103 | keypoints = downsample_tensor(keypoints, downsample) 104 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8') 105 | for idx in range(len(poses)): 106 | poses[idx] = downsample_tensor(poses[idx], downsample) 107 | trajectories[idx] = 
downsample_tensor(trajectories[idx], downsample) 108 | fps /= downsample 109 | 110 | render_animation.initialized = False 111 | render_animation.image = None 112 | render_animation.lines = [] 113 | render_animation.points = None 114 | 115 | if limit < 1: 116 | limit = len(all_frames) 117 | else: 118 | limit = min(limit, len(all_frames)) 119 | 120 | parents = skeleton.parents() 121 | 122 | def update_video(i): 123 | for n, ax in enumerate(ax_3d): 124 | ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]]) 125 | ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]]) 126 | 127 | # Update 2D poses 128 | if not render_animation.initialized: 129 | render_animation.image = ax_in.imshow(all_frames[i], aspect='equal') 130 | 131 | for j, j_parent in enumerate(parents): 132 | if j_parent == -1: 133 | continue 134 | 135 | if len(parents) == keypoints.shape[1]: 136 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition) 137 | render_animation.lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 138 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='b')) 139 | 140 | col = 'red' if j in skeleton.joints_right() else 'black' 141 | for n, ax in enumerate(ax_3d): 142 | pos = poses[n][i] 143 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 144 | [pos[j, 1], pos[j_parent, 1]], 145 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col)) 146 | 147 | render_animation.points = ax_in.scatter(keypoints[i].T[0], keypoints[i].T[1], 5, color='red', 148 | edgecolors='white', zorder=10) 149 | 150 | render_animation.initialized = True 151 | else: 152 | render_animation.image.set_data(all_frames[i]) 153 | 154 | for j, j_parent in enumerate(parents): 155 | if j_parent == -1: 156 | continue 157 | 158 | if len(parents) == keypoints.shape[1]: 159 | render_animation.lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 160 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]]) 161 | 162 | for n, ax in enumerate(ax_3d): 163 | pos = poses[n][i] 164 | lines_3d[n][j - 1][0].set_xdata([pos[j, 0], pos[j_parent, 0]]) 165 | lines_3d[n][j - 1][0].set_ydata([pos[j, 1], pos[j_parent, 1]]) 166 | lines_3d[n][j - 1][0].set_3d_properties([pos[j, 2], pos[j_parent, 2]], zdir='z') 167 | 168 | render_animation.points.set_offsets(keypoints[i]) 169 | 170 | print('{}/{} '.format(i, limit), end='\r') 171 | 172 | fig.tight_layout() 173 | 174 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False) 175 | if output.endswith('.mp4'): 176 | Writer = writers['ffmpeg'] 177 | writer = Writer(fps=fps, metadata={}, bitrate=bitrate) 178 | anim.save(output, writer=writer) 179 | elif output.endswith('.gif'): 180 | anim.save(output, dpi=80, writer='imagemagick') 181 | else: 182 | raise ValueError('Unsupported output format (only .mp4 and .gif are supported)') 183 | plt.close() 184 | -------------------------------------------------------------------------------- /models/graph_sh.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | from functools import reduce 5 | from common.graph_utils import adj_mx_from_edges 6 | 7 | from models.gconv.vanilla_graph_conv import DecoupleVanillaGraphConv 8 | from models.gconv.pre_agg_graph_conv import DecouplePreAggGraphConv 9 | from models.gconv.post_agg_graph_conv import DecouplePostAggGraphConv 10 | from 
models.gconv.conv_style_graph_conv import ConvStyleGraphConv 11 | from models.gconv.no_sharing_graph_conv import NoSharingGraphConv 12 | from models.gconv.modulated_gcn_conv import ModulatedGraphConv 13 | from models.gconv.sem_graph_conv import SemGraphConv 14 | 15 | from models.graph_non_local import GraphNonLocal 16 | 17 | 18 | class _GraphConv(nn.Module): 19 | def __init__(self, adj, input_dim, output_dim, p_dropout=None, gcn_type=None): 20 | super(_GraphConv, self).__init__() 21 | 22 | if gcn_type == 'vanilla': 23 | self.gconv = DecoupleVanillaGraphConv(input_dim, output_dim, adj, decouple=False) 24 | elif gcn_type == 'dc_vanilla': 25 | self.gconv = DecoupleVanillaGraphConv(input_dim, output_dim, adj) 26 | elif gcn_type == 'preagg': 27 | self.gconv = DecouplePreAggGraphConv(input_dim, output_dim, adj, decouple=False) 28 | elif gcn_type == 'dc_preagg': 29 | self.gconv = DecouplePreAggGraphConv(input_dim, output_dim, adj) 30 | elif gcn_type == 'postagg': 31 | self.gconv = DecouplePostAggGraphConv(input_dim, output_dim, adj, decouple=False) 32 | elif gcn_type == 'dc_postagg': 33 | self.gconv = DecouplePostAggGraphConv(input_dim, output_dim, adj) 34 | elif gcn_type == 'convst': 35 | self.gconv = ConvStyleGraphConv(input_dim, output_dim, adj) 36 | elif gcn_type == 'nosharing': 37 | self.gconv = NoSharingGraphConv(input_dim, output_dim, adj) 38 | elif gcn_type == 'modulated': 39 | self.gconv = ModulatedGraphConv(input_dim, output_dim, adj) 40 | elif gcn_type == 'semantic': 41 | self.gconv = SemGraphConv(input_dim, output_dim, adj) 42 | 43 | 44 | else: 45 | assert False, 'Invalid graph convolution type' 46 | 47 | self.bn = nn.BatchNorm1d(output_dim) 48 | self.relu = nn.ReLU() 49 | 50 | if p_dropout is not None: 51 | self.dropout = nn.Dropout(p_dropout) 52 | else: 53 | self.dropout = None 54 | 55 | def forward(self, x): 56 | x = self.gconv(x).transpose(1, 2) 57 | x = self.bn(x).transpose(1, 2) 58 | if self.dropout is not None: 59 | x = self.dropout(self.relu(x)) 60 | 61 | x = self.relu(x) 62 | return x 63 | 64 | 65 | class _Hourglass(nn.Module): 66 | def __init__(self, adj, input_dim, output_dim, hid_dim1, hid_dim2, nodes_group, p_dropout, gcn_type): 67 | super(_Hourglass, self).__init__() 68 | 69 | adj_mid = adj_mx_from_edges(8, [[0, 2], [1, 2], [2, 3], [3, 7], [4, 7], [5, 7], [6, 7]], sparse=False) 70 | adj_low = adj_mx_from_edges(4, [[0, 1], [1, 2], [2, 3]], sparse=False) 71 | 72 | self.gconv1 = _GraphConv(adj, input_dim, hid_dim1, p_dropout, gcn_type) 73 | self.gconv2 = _GraphConv(adj_mid, hid_dim1, hid_dim2, p_dropout, gcn_type) 74 | self.gconv3 = _GraphConv(adj_low, hid_dim2, hid_dim2, p_dropout, gcn_type) 75 | self.gconv4 = _GraphConv(adj_mid, hid_dim2, hid_dim1, p_dropout, gcn_type) 76 | self.gconv5 = _GraphConv(adj, hid_dim1, output_dim, p_dropout, gcn_type) 77 | 78 | self.pool = _SkeletalPool(nodes_group) 79 | self.unpool = _SkeletalUnpool(nodes_group) 80 | 81 | def forward(self, x): 82 | skip1 = x 83 | skip2 = self.gconv1(skip1) 84 | skip3 = self.gconv2(self.pool(skip2)) 85 | out = self.gconv3(self.pool(skip3)) 86 | out = self.gconv4(self.unpool(out) + skip3) 87 | out = self.gconv5(self.unpool(out) + skip2) 88 | return out + skip1 89 | 90 | 91 | class _SkeletalPool(nn.Module): 92 | def __init__(self, nodes_group): 93 | super(_SkeletalPool, self).__init__() 94 | self.high_group = sum(nodes_group, []) 95 | self.mid_group = [0, 1, 2, 3, 5, 6, 4, 7] 96 | self.pool = nn.MaxPool1d(kernel_size=2, stride=2) 97 | 98 | def forward(self, x): 99 | if x.shape[1] == 16: 100 | out = self.pool(x[:, 
self.high_group].transpose(1, 2)) 101 | return out.transpose(1, 2) 102 | elif x.shape[1] == 8: 103 | out = self.pool(x[:, self.mid_group].transpose(1, 2)) 104 | return out.transpose(1, 2) 105 | else: 106 | assert False, 'Invalid Type in Skeletal Pooling : x.shape is {}'.format(x.shape) 107 | 108 | 109 | class _SkeletalUnpool(nn.Module): 110 | def __init__(self, nodes_group): 111 | super(_SkeletalUnpool, self).__init__() 112 | self.nodes_group = sum(nodes_group, []) 113 | self.inv_low = [0, 0, 1, 1, 3, 2, 2, 3] 114 | self.inv_mid = [3, 2, 1, 1, 2, 0, 0, 3, 4, 4, 7, 6, 6, 7, 5, 5] 115 | 116 | def forward(self, x): 117 | if x.shape[1] == 8: 118 | return x[:, self.inv_mid] 119 | elif x.shape[1] == 4: 120 | return x[:, self.inv_low] 121 | else: 122 | assert False, 'Invalid Type in Skeletal Unpooling : x.shape is {}'.format(x.shape) 123 | 124 | 125 | class _GraphNonLocal(nn.Module): 126 | def __init__(self, hid_dim, grouped_order, restored_order, group_size): 127 | super(_GraphNonLocal, self).__init__() 128 | 129 | self.non_local = GraphNonLocal(hid_dim, sub_sample=group_size) 130 | self.grouped_order = grouped_order 131 | self.restored_order = restored_order 132 | 133 | def forward(self, x): 134 | out = x[:, self.grouped_order, :] 135 | out = self.non_local(out.transpose(1, 2)).transpose(1, 2) 136 | out = out[:, self.restored_order, :] 137 | return out 138 | 139 | 140 | class SEBlock(nn.Module): 141 | def __init__(self, adj, input_dim, reduction_ratio=8): 142 | super(SEBlock, self).__init__() 143 | hid_dim = input_dim // reduction_ratio 144 | self.fc1 = nn.Linear(input_dim, hid_dim, bias=True) 145 | self.fc2 = nn.Linear(hid_dim, input_dim, bias=True) 146 | self.gap = nn.AvgPool1d(kernel_size=adj.shape[-1]) 147 | self.relu = nn.ReLU() 148 | self.sigmoid = nn.Sigmoid() 149 | 150 | def forward(self, x): 151 | out = self.gap(x) 152 | out = self.relu(self.fc1(out.squeeze())) 153 | out = self.sigmoid(self.fc2(out)) 154 | 155 | return x * out[:, :, None] 156 | 157 | 158 | class GraphSH(nn.Module): 159 | def __init__(self, adj, hid_dim, nodes_group, coords_dim=(2, 3), num_layers=4, p_dropout=None, gcn_type=None): 160 | super(GraphSH, self).__init__() 161 | 162 | self.gconv_input = _GraphConv(adj, coords_dim[0], hid_dim, p_dropout=p_dropout, gcn_type=gcn_type) 163 | self.num_layers = num_layers 164 | _gconv_layers = [] 165 | _conv_layers = [] 166 | 167 | group_size = len(nodes_group[0]) 168 | assert group_size > 1 169 | 170 | grouped_order = list(reduce(lambda x, y: x + y, nodes_group)) 171 | restored_order = [0] * len(grouped_order) 172 | for i in range(len(restored_order)): 173 | for j in range(len(grouped_order)): 174 | if grouped_order[j] == i: 175 | restored_order[i] = j 176 | break 177 | 178 | for i in range(num_layers): 179 | _gconv_layers.append(_Hourglass(adj, hid_dim, hid_dim, int(hid_dim * 1.5), hid_dim * 2, nodes_group, p_dropout, gcn_type)) 180 | _conv_layers.append(nn.Conv1d(hid_dim, hid_dim // num_layers, 1)) 181 | 182 | self.gconv_layers = nn.ModuleList(_gconv_layers) 183 | self.conv_layers = nn.ModuleList(_conv_layers) 184 | 185 | self.se_blocks = SEBlock(adj, hid_dim) 186 | self.gconv_output = nn.Conv1d(hid_dim, coords_dim[1], 1) 187 | 188 | def forward(self, x): 189 | out = self.gconv_input(x) 190 | inter_fs = [] 191 | for l in range(self.num_layers): 192 | out = self.gconv_layers[l](out) 193 | inter_fs.append(self.conv_layers[l](out.transpose(1, 2)).transpose(1, 2)) 194 | f_out = torch.cat(inter_fs, dim=2) 195 | out = self.se_blocks(f_out.transpose(1, 2)) 196 | out = 
self.gconv_output(out).transpose(1, 2) 197 | return out 198 | -------------------------------------------------------------------------------- /common/h36m_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import numpy as np 4 | import copy 5 | from common.skeleton import Skeleton 6 | from common.mocap_dataset import MocapDataset 7 | from common.camera import normalize_screen_coordinates 8 | 9 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 10 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 11 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 12 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 13 | 14 | h36m_skeleton_joints_group = [[2, 3], [5, 6], [1, 4], [0, 7], [8, 9], [14, 15], [11, 12], [10, 13]] 15 | 16 | # Joints in H3.6M -- data has 32 joints, but only 17 that move; these are the indices. 17 | H36M_NAMES = [''] * 32 18 | H36M_NAMES[0] = 'Hip' 19 | H36M_NAMES[1] = 'RHip' 20 | H36M_NAMES[2] = 'RKnee' 21 | H36M_NAMES[3] = 'RFoot' 22 | H36M_NAMES[6] = 'LHip' 23 | H36M_NAMES[7] = 'LKnee' 24 | H36M_NAMES[8] = 'LFoot' 25 | H36M_NAMES[12] = 'Spine' 26 | H36M_NAMES[13] = 'Thorax' 27 | H36M_NAMES[14] = 'Neck/Nose' 28 | H36M_NAMES[15] = 'Head' 29 | H36M_NAMES[17] = 'LShoulder' 30 | H36M_NAMES[18] = 'LElbow' 31 | H36M_NAMES[19] = 'LWrist' 32 | H36M_NAMES[25] = 'RShoulder' 33 | H36M_NAMES[26] = 'RElbow' 34 | H36M_NAMES[27] = 'RWrist' 35 | 36 | # Human3.6m IDs for training and testing 37 | TRAIN_SUBJECTS = ['S1', 'S5', 'S6', 'S7', 'S8'] 38 | TEST_SUBJECTS = ['S9', 'S11'] 39 | 40 | h36m_cameras_intrinsic_params = [ 41 | { 42 | 'id': '54138969', 43 | 'center': [512.54150390625, 515.4514770507812], 44 | 'focal_length': [1145.0494384765625, 1143.7811279296875], 45 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], 46 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], 47 | 'res_w': 1000, 48 | 'res_h': 1002, 49 | 'azimuth': 70, # Only used for visualization 50 | }, 51 | { 52 | 'id': '55011271', 53 | 'center': [508.8486328125, 508.0649108886719], 54 | 'focal_length': [1149.6756591796875, 1147.5916748046875], 55 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], 56 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], 57 | 'res_w': 1000, 58 | 'res_h': 1000, 59 | 'azimuth': -70, # Only used for visualization 60 | }, 61 | { 62 | 'id': '58860488', 63 | 'center': [519.8158569335938, 501.40264892578125], 64 | 'focal_length': [1149.1407470703125, 1148.7989501953125], 65 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], 66 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], 67 | 'res_w': 1000, 68 | 'res_h': 1000, 69 | 'azimuth': 110, # Only used for visualization 70 | }, 71 | { 72 | 'id': '60457274', 73 | 'center': [514.9682006835938, 501.88201904296875], 74 | 'focal_length': [1145.5113525390625, 1144.77392578125], 75 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], 76 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], 77 | 'res_w': 1000, 78 | 'res_h': 1002, 79 | 'azimuth': -110, # Only used for visualization 80 | }, 81 | ] 82 | 83 | h36m_cameras_extrinsic_params = { 84 | 'S1': [ 85 | { 86 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 
0.6223280429840088], 87 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 88 | }, 89 | { 90 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], 91 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], 92 | }, 93 | { 94 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], 95 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], 96 | }, 97 | { 98 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435], 99 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], 100 | }, 101 | ], 102 | 'S2': [ 103 | {}, 104 | {}, 105 | {}, 106 | {}, 107 | ], 108 | 'S3': [ 109 | {}, 110 | {}, 111 | {}, 112 | {}, 113 | ], 114 | 'S4': [ 115 | {}, 116 | {}, 117 | {}, 118 | {}, 119 | ], 120 | 'S5': [ 121 | { 122 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], 123 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], 124 | }, 125 | { 126 | 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915], 127 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], 128 | }, 129 | { 130 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576], 131 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], 132 | }, 133 | { 134 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], 135 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], 136 | }, 137 | ], 138 | 'S6': [ 139 | { 140 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], 141 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], 142 | }, 143 | { 144 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], 145 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], 146 | }, 147 | { 148 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], 149 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], 150 | }, 151 | { 152 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], 153 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], 154 | }, 155 | ], 156 | 'S7': [ 157 | { 158 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], 159 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], 160 | }, 161 | { 162 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], 163 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], 164 | }, 165 | { 166 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278], 167 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], 168 | }, 169 | { 170 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], 171 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], 172 | }, 173 | ], 174 | 'S8': [ 175 | { 176 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], 177 | 
'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], 178 | }, 179 | { 180 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], 181 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], 182 | }, 183 | { 184 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], 185 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], 186 | }, 187 | { 188 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708], 189 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], 190 | }, 191 | ], 192 | 'S9': [ 193 | { 194 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], 195 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], 196 | }, 197 | { 198 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], 199 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], 200 | }, 201 | { 202 | 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915], 203 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], 204 | }, 205 | { 206 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], 207 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], 208 | }, 209 | ], 210 | 'S11': [ 211 | { 212 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], 213 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], 214 | }, 215 | { 216 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], 217 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], 218 | }, 219 | { 220 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], 221 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], 222 | }, 223 | { 224 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], 225 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], 226 | }, 227 | ], 228 | } 229 | 230 | 231 | class Human36mDataset(MocapDataset): 232 | def __init__(self, path, remove_static_joints=True): 233 | super(Human36mDataset, self).__init__(skeleton=h36m_skeleton, fps=50) 234 | 235 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) 236 | for cameras in self._cameras.values(): 237 | for i, cam in enumerate(cameras): 238 | cam.update(h36m_cameras_intrinsic_params[i]) 239 | for k, v in cam.items(): 240 | if k not in ['id', 'res_w', 'res_h']: 241 | cam[k] = np.array(v, dtype='float32') 242 | 243 | # Normalize camera frame 244 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype( 245 | 'float32') 246 | cam['focal_length'] = cam['focal_length'] / cam['res_w'] * 2.0 247 | if 'translation' in cam: 248 | cam['translation'] = cam['translation'] / 1000 # mm to meters 249 | 250 | # Add intrinsic parameters vector 251 | cam['intrinsic'] = np.concatenate((cam['focal_length'], 252 | cam['center'], 253 | cam['radial_distortion'], 254 | cam['tangential_distortion'])) 255 | 256 | # Load serialized dataset 257 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 258 | 259 | self._data = {} 260 | for subject, 
actions in data.items(): 261 | self._data[subject] = {} 262 | for action_name, positions in actions.items(): 263 | self._data[subject][action_name] = { 264 | 'positions': positions, 265 | 'cameras': self._cameras[subject], 266 | } 267 | 268 | if remove_static_joints: 269 | # Bring the skeleton to 16 joints instead of the original 32 270 | joints = [] 271 | for i, x in enumerate(H36M_NAMES): 272 | if x == '' or x == 'Neck/Nose': # Remove 'Nose' to make SH and H36M 2D poses have the same dimension 273 | joints.append(i) 274 | self.remove_joints(joints) 275 | 276 | # Rewire shoulders to the correct parents 277 | self._skeleton._parents[10] = 8 278 | self._skeleton._parents[13] = 8 279 | 280 | # Set joints group 281 | self._skeleton._joints_group = h36m_skeleton_joints_group 282 | 283 | def define_actions(self, action=None): 284 | all_actions = ["Directions", 285 | "Discussion", 286 | "Eating", 287 | "Greeting", 288 | "Phoning", 289 | "Photo", 290 | "Posing", 291 | "Purchases", 292 | "Sitting", 293 | "SittingDown", 294 | "Smoking", 295 | "Waiting", 296 | "WalkDog", 297 | "Walking", 298 | "WalkTogether"] 299 | 300 | if action is None: 301 | return all_actions 302 | 303 | if action not in all_actions: 304 | raise (ValueError, "Undefined action: {}".format(action)) 305 | 306 | return [action] 307 | -------------------------------------------------------------------------------- /main_gcn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import os 4 | import time 5 | import datetime 6 | import argparse 7 | import numpy as np 8 | import os.path as path 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.data import DataLoader 14 | 15 | from progress.bar import Bar 16 | from common.log import Logger, savefig 17 | from common.utils import AverageMeter, lr_decay, save_ckpt 18 | from common.graph_utils import adj_mx_from_skeleton 19 | from common.data_utils import fetch, read_3d_data, create_2d_data 20 | from common.generators import PoseGenerator 21 | from common.loss import mpjpe, p_mpjpe, sym_penalty 22 | from common.camera import get_uvd2xyz 23 | 24 | from models.graph_sh import GraphSH 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='PyTorch training script') 29 | 30 | # General arguments 31 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') 32 | parser.add_argument('-k', '--keypoints', default='gt', type=str, metavar='NAME', help='2D detections to use') 33 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST', 34 | help='actions to train/test on, separated by comma, or * for all') 35 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', 36 | help='checkpoint to evaluate (file name)') 37 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', 38 | help='checkpoint to resume (file name)') 39 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 40 | help='checkpoint directory') 41 | parser.add_argument('--snapshot', default=5, type=int, help='save models for every #snapshot epochs (default: 20)') 42 | 43 | # Model arguments 44 | parser.add_argument('-l', '--num_layers', default=4, type=int, metavar='N', help='num of residual layers') 45 | parser.add_argument('-z', '--hid_dim', default=64, type=int, metavar='N', help='num of 
hidden dimensions') 46 | parser.add_argument('-b', '--batch_size', default=256, type=int, metavar='N', 47 | help='batch size in terms of predicted frames') 48 | parser.add_argument('-e', '--epochs', default=50, type=int, metavar='N', help='number of training epochs') 49 | parser.add_argument('--lamda', '--weight_L1_norm', default=0.0, type=float, metavar='N', help='scale of L1 Norm') 50 | parser.add_argument('--num_workers', default=8, type=int, metavar='N', help='num of workers for data loading') 51 | parser.add_argument('--lr', default=1.0e-4, type=float, metavar='LR', help='initial learning rate') 52 | parser.add_argument('--lr_decay', type=int, default=20000, help='num of steps of learning rate decay') 53 | parser.add_argument('--lr_gamma', type=float, default=0.92, help='gamma of learning rate decay') 54 | parser.add_argument('--no_max', dest='max_norm', action='store_false', help='if use max_norm clip on grad') 55 | parser.set_defaults(max_norm=True) 56 | parser.add_argument('--post_refine', dest='post_refine', action='store_true', help='if use post-refine layers') 57 | parser.set_defaults(post_refine=False) 58 | parser.add_argument('--dropout', default=0.25, type=float, help='dropout rate') 59 | parser.add_argument('--gcn', default='', type=str, metavar='NAME', help='type of gcn') 60 | parser.add_argument('-n', '--name', default='', type=str, metavar='NAME', help='name of model') 61 | 62 | # Experimental 63 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor') 64 | 65 | args = parser.parse_args() 66 | 67 | # Check invalid configuration 68 | if args.resume and args.evaluate: 69 | print('Invalid flags: --resume and --evaluate cannot be set at the same time') 70 | exit() 71 | 72 | return args 73 | 74 | 75 | def main(args): 76 | print('==> Using settings {}'.format(args)) 77 | 78 | print('==> Loading dataset...') 79 | dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz') 80 | if args.dataset == 'h36m': 81 | from common.h36m_dataset import Human36mDataset, TRAIN_SUBJECTS, TEST_SUBJECTS 82 | dataset = Human36mDataset(dataset_path) 83 | subjects_train = TRAIN_SUBJECTS 84 | subjects_test = TEST_SUBJECTS 85 | else: 86 | raise KeyError('Invalid dataset') 87 | 88 | print('==> Preparing data...') 89 | dataset = read_3d_data(dataset) 90 | 91 | print('==> Loading 2D detections...') 92 | keypoints = create_2d_data(path.join('data', 'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'), dataset) 93 | 94 | action_filter = None if args.actions == '*' else args.actions.split(',') 95 | if action_filter is not None: 96 | action_filter = map(lambda x: dataset.define_actions(x)[0], action_filter) 97 | print('==> Selected actions: {}'.format(action_filter)) 98 | 99 | stride = args.downsample 100 | cudnn.benchmark = True 101 | device = torch.device("cuda:0") 102 | 103 | # Create model 104 | print("==> Creating model...") 105 | 106 | p_dropout = (None if args.dropout == 0.0 else args.dropout) 107 | adj = adj_mx_from_skeleton(dataset.skeleton()).to(device) 108 | 109 | # Post refinement model 110 | if args.post_refine: 111 | model_post_refine = PostRefine(2, 3, 16).to(device) 112 | else: 113 | model_post_refine = None 114 | 115 | model_pos = GraphSH(adj, args.hid_dim, dataset.skeleton().joints_group(), num_layers=args.num_layers, p_dropout=p_dropout, gcn_type=args.gcn).to(device) 116 | 117 | print("==> Total parameters: {:.2f}M".format(sum(p.numel() for p in model_pos.parameters()) / 1000000.0)) 118 | 119 | criterion = 
nn.MSELoss(reduction='mean').to(device) 120 | criterionL1 = nn.L1Loss(reduction='mean').to(device) 121 | 122 | optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr) 123 | 124 | # Optionally resume from a checkpoint 125 | if args.resume or args.evaluate: 126 | ckpt_path = (args.resume if args.resume else args.evaluate) 127 | 128 | if path.isfile(ckpt_path): 129 | print("==> Loading checkpoint '{}'".format(ckpt_path)) 130 | ckpt = torch.load(ckpt_path) 131 | start_epoch = ckpt['epoch'] 132 | error_best = ckpt['error'] 133 | glob_step = ckpt['step'] 134 | lr_now = ckpt['lr'] 135 | model_pos.load_state_dict(ckpt['state_dict']) 136 | optimizer.load_state_dict(ckpt['optimizer']) 137 | print("==> Loaded checkpoint (Epoch: {} | Error: {})".format(start_epoch, error_best)) 138 | # for name, p in model_pos.named_parameters(): 139 | # print(name)#, p.data) 140 | # exit(0) 141 | 142 | if args.resume: 143 | ckpt_dir_path = path.dirname(ckpt_path) 144 | logger = Logger(path.join(ckpt_dir_path, 'log.txt'), resume=True) 145 | else: 146 | raise RuntimeError("==> No checkpoint found at '{}'".format(ckpt_path)) 147 | else: 148 | start_epoch = 0 149 | error_best = None 150 | glob_step = 0 151 | lr_now = args.lr 152 | ckpt_dir_path = path.join(args.checkpoint, args.name + '-' + args.gcn + '-' + datetime.date.today().isoformat()) 153 | 154 | if not path.exists(ckpt_dir_path): 155 | os.makedirs(ckpt_dir_path) 156 | print('==> Making checkpoint dir: {}'.format(ckpt_dir_path)) 157 | 158 | logger = Logger(os.path.join(ckpt_dir_path, 'log.txt')) 159 | logger.set_names(['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2']) 160 | 161 | if args.evaluate: 162 | print('==> Evaluating...') 163 | 164 | if action_filter is None: 165 | action_filter = dataset.define_actions() 166 | 167 | errors_p1 = np.zeros(len(action_filter)) 168 | errors_p2 = np.zeros(len(action_filter)) 169 | 170 | for i, action in enumerate(action_filter): 171 | poses_valid, poses_valid_2d, actions_valid, cam_valid = fetch(subjects_test, dataset, keypoints, [action], stride) 172 | valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid, cam_valid), 173 | batch_size=args.batch_size, shuffle=False, 174 | num_workers=args.num_workers, pin_memory=True) 175 | errors_p1[i], errors_p2[i] = evaluate(valid_loader, model_pos, device) 176 | 177 | print('Protocol #1 (MPJPE) action-wise average: {:.2f} (mm)'.format(np.mean(errors_p1).item())) 178 | print('Protocol #2 (P-MPJPE) action-wise average: {:.2f} (mm)'.format(np.mean(errors_p2).item())) 179 | exit(0) 180 | 181 | poses_train, poses_train_2d, actions_train, cam_train = fetch(subjects_train, dataset, keypoints, action_filter, stride) 182 | train_loader = DataLoader(PoseGenerator(poses_train, poses_train_2d, actions_train, cam_train), batch_size=args.batch_size, 183 | shuffle=True, num_workers=args.num_workers, pin_memory=True) 184 | 185 | poses_valid, poses_valid_2d, actions_valid, cam_valid = fetch(subjects_test, dataset, keypoints, action_filter, stride) 186 | valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid, cam_valid), batch_size=args.batch_size, 187 | shuffle=False, num_workers=args.num_workers, pin_memory=True) 188 | 189 | for epoch in range(start_epoch, args.epochs): 190 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now)) 191 | 192 | # Train for one epoch 193 | epoch_loss, lr_now, glob_step = train(train_loader, model_pos, model_post_refine, 194 | args.lamda, criterion, criterionL1, optimizer, 195 | device, args.lr, 
lr_now, 196 | glob_step, args.lr_decay, args.lr_gamma, max_norm=args.max_norm) 197 | 198 | # Evaluate 199 | error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos, device) 200 | 201 | # Update log file 202 | logger.append([epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2]) 203 | 204 | # Save checkpoint 205 | if error_best is None or error_best > error_eval_p1: 206 | error_best = error_eval_p1 207 | save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(), 208 | 'optimizer': optimizer.state_dict(), 'post_refine': model_post_refine, 'error': error_eval_p1}, 209 | ckpt_dir_path, suffix='best') 210 | 211 | if (epoch + 1) % args.snapshot == 0: 212 | save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(), 213 | 'optimizer': optimizer.state_dict(), 'post_refine': model_post_refine, 'error': error_eval_p1}, 214 | ckpt_dir_path) 215 | 216 | logger.close() 217 | logger.plot(['loss_train', 'error_eval_p1']) 218 | savefig(path.join(ckpt_dir_path, 'log.eps')) 219 | 220 | return 221 | 222 | 223 | def train(data_loader, model_pos, model_post_refine, lamda, criterion, criterionL1, optimizer, device, lr_init, lr_now, step, decay, gamma, 224 | max_norm=True): 225 | batch_time = AverageMeter() 226 | data_time = AverageMeter() 227 | epoch_loss_3d_pos = AverageMeter() 228 | 229 | # Switch to train mode 230 | torch.set_grad_enabled(True) 231 | model_pos.train() 232 | if model_post_refine is not None: 233 | model_post_refine.train() 234 | end = time.time() 235 | 236 | bar = Bar('Train', max=len(data_loader)) 237 | for i, (targets_3d, inputs_2d, _, batch_cam) in enumerate(data_loader): 238 | # Measure data loading time 239 | data_time.update(time.time() - end) 240 | num_poses = targets_3d.size(0) 241 | 242 | step += 1 243 | if step % decay == 0 or step == 1: 244 | lr_now = lr_decay(optimizer, step, lr_init, decay, gamma) 245 | 246 | targets_3d, inputs_2d, batch_cam = targets_3d.to(device), inputs_2d.to(device), batch_cam.to(device) 247 | outputs_3d = model_pos(inputs_2d) 248 | 249 | if model_post_refine is not None: 250 | uvd = torch.cat((inputs_2d.unsqueeze(1), outputs_3d[:, None, :, 2].unsqueeze(-1)), -1) 251 | xyz = get_uvd2xyz(uvd, targets_3d[:, None], batch_cam).squeeze() 252 | xyz[:, 0, :] = 0 253 | 254 | post_out = model_post_refine(outputs_3d, xyz) 255 | loss_sym = sym_penalty(post_out) 256 | loss_post_refine = mpjpe(post_out, targets_3d) + 0.01*loss_sym 257 | else: 258 | loss_post_refine = 0 259 | 260 | optimizer.zero_grad() 261 | loss_3d_pos = (1 - lamda) * criterion(outputs_3d, targets_3d) + lamda * criterionL1(outputs_3d, targets_3d) + loss_post_refine 262 | loss_3d_pos.backward() 263 | if max_norm: 264 | nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1) 265 | optimizer.step() 266 | 267 | epoch_loss_3d_pos.update(loss_3d_pos.item(), num_poses) 268 | 269 | # Measure elapsed time 270 | batch_time.update(time.time() - end) 271 | end = time.time() 272 | 273 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \ 274 | '| Loss: {loss: .6f}' \ 275 | .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg, 276 | ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg) 277 | bar.next() 278 | 279 | bar.finish() 280 | return epoch_loss_3d_pos.avg, lr_now, step 281 | 282 | 283 | def evaluate(data_loader, model_pos, device): 284 | batch_time = AverageMeter() 285 | data_time = AverageMeter() 286 | epoch_loss_3d_pos 
= AverageMeter() 287 | epoch_loss_3d_pos_procrustes = AverageMeter() 288 | 289 | # Switch to evaluate mode 290 | torch.set_grad_enabled(False) 291 | model_pos.eval() 292 | end = time.time() 293 | 294 | bar = Bar('Eval ', max=len(data_loader)) 295 | for i, (targets_3d, inputs_2d, _, _) in enumerate(data_loader): 296 | # Measure data loading time 297 | data_time.update(time.time() - end) 298 | num_poses = targets_3d.size(0) 299 | 300 | inputs_2d = inputs_2d.to(device) 301 | outputs_3d = model_pos(inputs_2d).cpu() 302 | 303 | outputs_3d[:, :, :] = outputs_3d[:, :, :] - outputs_3d[:, :1, :] # Zero-centre the root (hip) 304 | 305 | epoch_loss_3d_pos.update(mpjpe(outputs_3d, targets_3d).item() * 1000.0, num_poses) 306 | epoch_loss_3d_pos_procrustes.update(p_mpjpe(outputs_3d.numpy(), targets_3d.numpy()).item() * 1000.0, num_poses) 307 | 308 | # Measure elapsed time 309 | batch_time.update(time.time() - end) 310 | end = time.time() 311 | 312 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \ 313 | '| MPJPE: {e1: .4f} | P-MPJPE: {e2: .4f}' \ 314 | .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg, 315 | ttl=bar.elapsed_td, eta=bar.eta_td, e1=epoch_loss_3d_pos.avg, e2=epoch_loss_3d_pos_procrustes.avg) 316 | bar.next() 317 | 318 | bar.finish() 319 | return epoch_loss_3d_pos.avg, epoch_loss_3d_pos_procrustes.avg 320 | 321 | 322 | if __name__ == '__main__': 323 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 324 | 325 | main(parse_args()) 326 | --------------------------------------------------------------------------------