├── commona
├── common
│   ├── README.MD
│   ├── __pycache__
│   │   ├── loss.cpython-38.pyc
│   │   ├── utils.cpython-38.pyc
│   │   ├── camera.cpython-38.pyc
│   │   ├── skeleton.cpython-38.pyc
│   │   ├── arguments.cpython-38.pyc
│   │   ├── generators.cpython-38.pyc
│   │   ├── quaternion.cpython-38.pyc
│   │   ├── reversible.cpython-38.pyc
│   │   ├── h36m_dataset.cpython-38.pyc
│   │   ├── mocap_dataset.cpython-38.pyc
│   │   ├── visualization.cpython-38.pyc
│   │   └── model_crossformer.cpython-38.pyc
│   ├── quaternion.py
│   ├── mocap_dataset.py
│   ├── custom_dataset.py
│   ├── utils.py
│   ├── camera.py
│   ├── skeleton.py
│   ├── loss.py
│   ├── humaneva_dataset.py
│   ├── reversible.py
│   ├── arguments.py
│   ├── visualization.py
│   ├── h36m_dataset.py
│   ├── generators.py
│   └── model_crossformer.py
├── README.md
├── crossformer.yml
├── LICENSE
└── run_crossformer.py
/commona: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/README.MD: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/camera.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/camera.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/skeleton.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/skeleton.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/arguments.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/arguments.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/generators.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/generators.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/quaternion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/quaternion.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/reversible.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/reversible.cpython-38.pyc
-------------------------------------------------------------------------------- /common/__pycache__/h36m_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/h36m_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/mocap_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/mocap_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/visualization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/visualization.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/model_crossformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfawzy/CrossFormer/HEAD/common/__pycache__/model_crossformer.cpython-38.pyc -------------------------------------------------------------------------------- /common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | def qrot(q, v): 11 | """ 12 | Rotate vector(s) v about the rotation described by quaternion(s) q. 13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 14 | where * denotes any number of dimensions. 15 | Returns a tensor of shape (*, 3). 16 | """ 17 | assert q.shape[-1] == 4 18 | assert v.shape[-1] == 3 19 | assert q.shape[:-1] == v.shape[:-1] 20 | 21 | qvec = q[..., 1:] 22 | uv = torch.cross(qvec, v, dim=len(q.shape)-1) 23 | uuv = torch.cross(qvec, uv, dim=len(q.shape)-1) 24 | return (v + 2 * (q[..., :1] * uv + uuv)) 25 | 26 | 27 | def qinverse(q, inplace=False): 28 | # We assume the quaternion to be normalized 29 | if inplace: 30 | q[..., 1:] *= -1 31 | return q 32 | else: 33 | w = q[..., :1] 34 | xyz = q[..., 1:] 35 | return torch.cat((w, -xyz), dim=len(q.shape)-1) -------------------------------------------------------------------------------- /common/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
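As an aside on the `qrot`/`qinverse` helpers in `common/quaternion.py` above: the minimal sketch below rotates a point by a unit quaternion and then undoes the rotation. The example values are ours, not from the repo; shapes follow the asserts in `qrot` (q is `(*, 4)`, v is `(*, 3)`).

```python
import torch
from common.quaternion import qrot, qinverse

# (w, x, y, z) unit quaternion for a 90-degree rotation about the y-axis
q = torch.tensor([[0.7071, 0.0, 0.7071, 0.0]])
v = torch.tensor([[1.0, 0.0, 0.0]])

v_rot = qrot(q, v)                  # approximately [0, 0, -1]
v_back = qrot(qinverse(q), v_rot)   # recovers approximately [1, 0, 0]
```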
6 | # 7 | 8 | import numpy as np 9 | from common.skeleton import Skeleton 10 | 11 | class MocapDataset: 12 | def __init__(self, fps, skeleton): 13 | self._skeleton = skeleton 14 | self._fps = fps 15 | self._data = None # Must be filled by subclass 16 | self._cameras = None # Must be filled by subclass 17 | 18 | def remove_joints(self, joints_to_remove): 19 | kept_joints = self._skeleton.remove_joints(joints_to_remove) 20 | for subject in self._data.keys(): 21 | for action in self._data[subject].keys(): 22 | s = self._data[subject][action] 23 | if 'positions' in s: 24 | s['positions'] = s['positions'][:, kept_joints] 25 | 26 | 27 | def __getitem__(self, key): 28 | return self._data[key] 29 | 30 | def subjects(self): 31 | return self._data.keys() 32 | 33 | def fps(self): 34 | return self._fps 35 | 36 | def skeleton(self): 37 | return self._skeleton 38 | 39 | def cameras(self): 40 | return self._cameras 41 | 42 | def supports_semi_supervised(self): 43 | # This method can be overridden 44 | return False -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D Human Pose Estimation with Spatial and Temporal Transformers 2 | This repo is the official implementation for [CrossFormer: Cross Spatio-Temporal Transformer for 3D Human Pose Estimation](https://arxiv.org/abs/2203.13387) 3 | 4 | 5 | 6 | 7 | Our code is built on top of [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). 8 | 9 | ### Environment 10 | 11 | The code is developed and tested under the following environment 12 | 13 | * Python 3.8.2 14 | * PyTorch 1.7.1 15 | * CUDA 11.0 16 | 17 | You can create the environment: 18 | ```bash 19 | conda env create -f crossformer.yml 20 | ``` 21 | 22 | ### Dataset 23 | 24 | Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (./data directory). 25 | 26 | ### Evaluating pre-trained models 27 | 28 | We provide the pre-trained 81-frame model (CPN detected 2D pose as input) [here](https://drive.google.com/file/d/1eNmpfTAhc-6hKLQjXv7qeCdpVvZlI8D1/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run: 29 | 30 | ```bash 31 | python run_crossformer.py -k cpn_ft_h36m_dbb -f 81 -c checkpoint --evaluate best_epoch44.4.bin 32 | ``` 33 | 34 | We also provide pre-trained 81-frame model (Ground truth 2D pose as input) [here](https://drive.google.com/file/d/1LF-HVcyqMWC8VBWDXrL4oVfGzGnpT8aN/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run: 35 | 36 | ```bash 37 | python run_crossformer.py -k gt -f 81 -c checkpoint --evaluate best_epoch_gt_28.5.bin 38 | ``` 39 | 40 | 41 | ### Training new models 42 | 43 | * To train a model from scratch (CPN detected 2D pose as input), run: 44 | 45 | ```bash 46 | python run_crossformer.py -k cpn_ft_h36m_dbb -f 27 -lr 0.00004 -lrd 0.99 47 | ``` 48 | 49 | 50 | 51 | * To train a model from scratch (Ground truth 2D pose as input), run: 52 | 53 | ```bash 54 | python run_crossformer.py -k gt -f 81 -lr 0.0004 -lrd 0.99 55 | ``` 56 | 57 | 81 frames achieves 28.5 mm (MPJPE). 58 | 59 | ### Visualization and other functions 60 | 61 | We keep our code consistent with [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). 
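As a concrete example (the subject, action, and output name here are placeholders; all flags are defined in `common/arguments.py` below), rendering a prediction for one Human3.6M sequence could look like:

```bash
python run_crossformer.py -k cpn_ft_h36m_dbb -f 81 -c checkpoint --evaluate best_epoch44.4.bin \
    --render --viz-subject S11 --viz-action Walking --viz-camera 0 \
    --viz-output output.gif --viz-size 3 --viz-downsample 2
```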
Please refer to their project page for further information. 62 | 63 | ## Acknowledgement 64 | 65 | Part of our code is borrowed from [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). We thank the authors for releasing the codes. 66 | -------------------------------------------------------------------------------- /common/custom_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | from common.h36m_dataset import h36m_skeleton 14 | 15 | 16 | custom_camera_params = { 17 | 'id': None, 18 | 'res_w': None, # Pulled from metadata 19 | 'res_h': None, # Pulled from metadata 20 | 21 | # Dummy camera parameters (taken from Human3.6M), only for visualization purposes 22 | 'azimuth': 70, # Only used for visualization 23 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 24 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 25 | } 26 | 27 | class CustomDataset(MocapDataset): 28 | def __init__(self, detections_path, remove_static_joints=True): 29 | super().__init__(fps=None, skeleton=h36m_skeleton) 30 | 31 | # Load serialized dataset 32 | data = np.load(detections_path, allow_pickle=True) 33 | resolutions = data['metadata'].item()['video_metadata'] 34 | 35 | self._cameras = {} 36 | self._data = {} 37 | for video_name, res in resolutions.items(): 38 | cam = {} 39 | cam.update(custom_camera_params) 40 | cam['orientation'] = np.array(cam['orientation'], dtype='float32') 41 | cam['translation'] = np.array(cam['translation'], dtype='float32') 42 | cam['translation'] = cam['translation']/1000 # mm to meters 43 | 44 | cam['id'] = video_name 45 | cam['res_w'] = res['w'] 46 | cam['res_h'] = res['h'] 47 | 48 | self._cameras[video_name] = [cam] 49 | 50 | self._data[video_name] = { 51 | 'custom': { 52 | 'cameras': cam 53 | } 54 | } 55 | 56 | if remove_static_joints: 57 | # Bring the skeleton to 17 joints instead of the original 32 58 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 59 | 60 | # Rewire shoulders to the correct parents 61 | self._skeleton._parents[11] = 8 62 | self._skeleton._parents[14] = 8 63 | 64 | def supports_semi_supervised(self): 65 | return False 66 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | import hashlib 11 | 12 | def wrap(func, *args, unsqueeze=False): 13 | """ 14 | Wrap a torch function so it can be called with NumPy arrays. 15 | Input and return types are seamlessly converted. 
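The `wrap` helper being documented here is what lets the NumPy-based camera code call the PyTorch quaternion functions. A minimal sketch with our own example values (identity rotations, so the output equals the input):

```python
import numpy as np
from common.utils import wrap
from common.quaternion import qrot

# NumPy in, NumPy out: wrap() converts the arrays to tensors, calls qrot,
# and converts the result back to a NumPy array.
q = np.tile(np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), (5, 1))  # identity quaternions
v = np.random.randn(5, 3).astype(np.float32)

v_rot = wrap(qrot, q, v)
print(type(v_rot), v_rot.shape)   # <class 'numpy.ndarray'> (5, 3)
```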
16 | """ 17 | 18 | # Convert input types where applicable 19 | args = list(args) 20 | for i, arg in enumerate(args): 21 | if type(arg) == np.ndarray: 22 | args[i] = torch.from_numpy(arg) 23 | if unsqueeze: 24 | args[i] = args[i].unsqueeze(0) 25 | 26 | result = func(*args) 27 | 28 | # Convert output types where applicable 29 | if isinstance(result, tuple): 30 | result = list(result) 31 | for i, res in enumerate(result): 32 | if type(res) == torch.Tensor: 33 | if unsqueeze: 34 | res = res.squeeze(0) 35 | result[i] = res.numpy() 36 | return tuple(result) 37 | elif type(result) == torch.Tensor: 38 | if unsqueeze: 39 | result = result.squeeze(0) 40 | return result.numpy() 41 | else: 42 | return result 43 | 44 | def deterministic_random(min_value, max_value, data): 45 | digest = hashlib.sha256(data.encode()).digest() 46 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False) 47 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value 48 | 49 | def load_pretrained_weights(model, checkpoint): 50 | """Load pretrianed weights to model 51 | Incompatible layers (unmatched in name or size) will be ignored 52 | Args: 53 | - model (nn.Module): network model, which must not be nn.DataParallel 54 | - weight_path (str): path to pretrained weights 55 | """ 56 | import collections 57 | if 'state_dict' in checkpoint: 58 | state_dict = checkpoint['state_dict'] 59 | else: 60 | state_dict = checkpoint 61 | model_dict = model.state_dict() 62 | new_state_dict = collections.OrderedDict() 63 | matched_layers, discarded_layers = [], [] 64 | for k, v in state_dict.items(): 65 | # If the pretrained state_dict was saved as nn.DataParallel, 66 | # keys would contain "module.", which should be ignored. 67 | if k.startswith('module.'): 68 | k = k[7:] 69 | if k in model_dict and model_dict[k].size() == v.size(): 70 | new_state_dict[k] = v 71 | matched_layers.append(k) 72 | else: 73 | discarded_layers.append(k) 74 | # new_state_dict.requires_grad = False 75 | model_dict.update(new_state_dict) 76 | 77 | model.load_state_dict(model_dict) 78 | print('load_weight', len(matched_layers)) 79 | # model.state_dict(model_dict).requires_grad = False 80 | return model 81 | -------------------------------------------------------------------------------- /common/camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from common.utils import wrap 12 | from common.quaternion import qrot, qinverse 13 | 14 | def normalize_screen_coordinates(X, w, h): 15 | assert X.shape[-1] == 2 16 | 17 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio 18 | return X/w*2 - [1, h/w] 19 | 20 | 21 | def image_coordinates(X, w, h): 22 | assert X.shape[-1] == 2 23 | 24 | # Reverse camera frame normalization 25 | return (X + [1, h/w])*w/2 26 | 27 | 28 | def world_to_camera(X, R, t): 29 | Rt = wrap(qinverse, R) # Invert rotation 30 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate 31 | 32 | 33 | def camera_to_world(X, R, t): 34 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t 35 | 36 | 37 | def project_to_2d(X, camera_params): 38 | """ 39 | Project 3D points to 2D using the Human3.6M camera projection function. 
40 | This is a differentiable and batched reimplementation of the original MATLAB script. 41 | 42 | Arguments: 43 | X -- 3D points in *camera space* to transform (N, *, 3) 44 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 45 | """ 46 | assert X.shape[-1] == 3 47 | assert len(camera_params.shape) == 2 48 | assert camera_params.shape[-1] == 9 49 | assert X.shape[0] == camera_params.shape[0] 50 | 51 | while len(camera_params.shape) < len(X.shape): 52 | camera_params = camera_params.unsqueeze(1) 53 | 54 | f = camera_params[..., :2] 55 | c = camera_params[..., 2:4] 56 | k = camera_params[..., 4:7] 57 | p = camera_params[..., 7:] 58 | 59 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 60 | r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True) 61 | 62 | radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True) 63 | tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True) 64 | 65 | XXX = XX*(radial + tan) + p*r2 66 | 67 | return f*XXX + c 68 | 69 | def project_to_2d_linear(X, camera_params): 70 | """ 71 | Project 3D points to 2D using only linear parameters (focal length and principal point). 72 | 73 | Arguments: 74 | X -- 3D points in *camera space* to transform (N, *, 3) 75 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 76 | """ 77 | assert X.shape[-1] == 3 78 | assert len(camera_params.shape) == 2 79 | assert camera_params.shape[-1] == 9 80 | assert X.shape[0] == camera_params.shape[0] 81 | 82 | while len(camera_params.shape) < len(X.shape): 83 | camera_params = camera_params.unsqueeze(1) 84 | 85 | f = camera_params[..., :2] 86 | c = camera_params[..., 2:4] 87 | 88 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 89 | 90 | return f*XX + c -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | class Skeleton: 11 | def __init__(self, parents, joints_left, joints_right): 12 | assert len(joints_left) == len(joints_right) 13 | 14 | self._parents = np.array(parents) 15 | self._joints_left = joints_left 16 | self._joints_right = joints_right 17 | self._compute_metadata() 18 | 19 | def num_joints(self): 20 | return len(self._parents) 21 | 22 | def parents(self): 23 | return self._parents 24 | 25 | def has_children(self): 26 | return self._has_children 27 | 28 | def children(self): 29 | return self._children 30 | 31 | def remove_joints(self, joints_to_remove): 32 | """ 33 | Remove the joints specified in 'joints_to_remove'. 
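A small usage sketch for `project_to_2d` from `common/camera.py` above. The intrinsic values are only plausible placeholders (roughly Human3.6M-like focal length and principal point), chosen to show the expected layout of the 9-dimensional parameter vector:

```python
import torch
from common.camera import project_to_2d

N, J = 2, 17
X = torch.randn(N, J, 3)
X[..., 2] = X[..., 2].abs() + 3.0        # keep points in front of the camera

# 9 intrinsics per camera: focal (2), principal point (2), radial k1-k3, tangential p1-p2
cam = torch.tensor([[1145.0, 1144.0, 512.0, 515.0, -0.2, 0.02, 0.0, 0.0, 0.0]]).repeat(N, 1)

pts_2d = project_to_2d(X, cam)
print(pts_2d.shape)                      # torch.Size([2, 17, 2])
```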
34 | """ 35 | valid_joints = [] 36 | for joint in range(len(self._parents)): 37 | if joint not in joints_to_remove: 38 | valid_joints.append(joint) 39 | 40 | for i in range(len(self._parents)): 41 | while self._parents[i] in joints_to_remove: 42 | self._parents[i] = self._parents[self._parents[i]] 43 | 44 | index_offsets = np.zeros(len(self._parents), dtype=int) 45 | new_parents = [] 46 | for i, parent in enumerate(self._parents): 47 | if i not in joints_to_remove: 48 | new_parents.append(parent - index_offsets[parent]) 49 | else: 50 | index_offsets[i:] += 1 51 | self._parents = np.array(new_parents) 52 | 53 | 54 | if self._joints_left is not None: 55 | new_joints_left = [] 56 | for joint in self._joints_left: 57 | if joint in valid_joints: 58 | new_joints_left.append(joint - index_offsets[joint]) 59 | self._joints_left = new_joints_left 60 | if self._joints_right is not None: 61 | new_joints_right = [] 62 | for joint in self._joints_right: 63 | if joint in valid_joints: 64 | new_joints_right.append(joint - index_offsets[joint]) 65 | self._joints_right = new_joints_right 66 | 67 | self._compute_metadata() 68 | 69 | return valid_joints 70 | 71 | def joints_left(self): 72 | return self._joints_left 73 | 74 | def joints_right(self): 75 | return self._joints_right 76 | 77 | def _compute_metadata(self): 78 | self._has_children = np.zeros(len(self._parents)).astype(bool) 79 | for i, parent in enumerate(self._parents): 80 | if parent != -1: 81 | self._has_children[parent] = True 82 | 83 | self._children = [] 84 | for i, parent in enumerate(self._parents): 85 | self._children.append([]) 86 | for i, parent in enumerate(self._parents): 87 | if parent != -1: 88 | self._children[parent].append(i) -------------------------------------------------------------------------------- /common/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | def mpjpe(predicted, target): 12 | """ 13 | Mean per-joint position error (i.e. mean Euclidean distance), 14 | often referred to as "Protocol #1" in many papers. 15 | """ 16 | assert predicted.shape == target.shape 17 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1)) 18 | 19 | def weighted_mpjpe(predicted, target, w): 20 | """ 21 | Weighted mean per-joint position error (i.e. mean Euclidean distance) 22 | """ 23 | assert predicted.shape == target.shape 24 | assert w.shape[0] == predicted.shape[0] 25 | return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1)) 26 | 27 | def p_mpjpe(predicted, target): 28 | """ 29 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation), 30 | often referred to as "Protocol #2" in many papers. 
31 | """ 32 | assert predicted.shape == target.shape 33 | 34 | muX = np.mean(target, axis=1, keepdims=True) 35 | muY = np.mean(predicted, axis=1, keepdims=True) 36 | 37 | X0 = target - muX 38 | Y0 = predicted - muY 39 | 40 | normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True)) 41 | normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True)) 42 | 43 | X0 /= normX 44 | Y0 /= normY 45 | 46 | H = np.matmul(X0.transpose(0, 2, 1), Y0) 47 | U, s, Vt = np.linalg.svd(H) 48 | V = Vt.transpose(0, 2, 1) 49 | R = np.matmul(V, U.transpose(0, 2, 1)) 50 | 51 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 52 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) 53 | V[:, :, -1] *= sign_detR 54 | s[:, -1] *= sign_detR.flatten() 55 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 56 | 57 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) 58 | 59 | a = tr * normX / normY # Scale 60 | t = muX - a*np.matmul(muY, R) # Translation 61 | 62 | # Perform rigid transformation on the input 63 | predicted_aligned = a*np.matmul(predicted, R) + t 64 | 65 | # Return MPJPE 66 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)) 67 | 68 | def n_mpjpe(predicted, target): 69 | """ 70 | Normalized MPJPE (scale only), adapted from: 71 | https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py 72 | """ 73 | assert predicted.shape == target.shape 74 | 75 | norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True) 76 | norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True) 77 | scale = norm_target / norm_predicted 78 | return mpjpe(scale * predicted, target) 79 | 80 | def weighted_bonelen_loss(predict_3d_length, gt_3d_length): 81 | loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean() 82 | return loss_length 83 | 84 | def weighted_boneratio_loss(predict_3d_length, gt_3d_length): 85 | loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean() 86 | return loss_length 87 | 88 | def mean_velocity_error(predicted, target): 89 | """ 90 | Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative) 91 | """ 92 | assert predicted.shape == target.shape 93 | 94 | velocity_predicted = np.diff(predicted, axis=0) 95 | velocity_target = np.diff(target, axis=0) 96 | 97 | return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1)) -------------------------------------------------------------------------------- /common/humaneva_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | 14 | humaneva_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1], 15 | joints_left=[2, 3, 4, 8, 9, 10], 16 | joints_right=[5, 6, 7, 11, 12, 13]) 17 | 18 | humaneva_cameras_intrinsic_params = [ 19 | { 20 | 'id': 'C1', 21 | 'res_w': 640, 22 | 'res_h': 480, 23 | 'azimuth': 0, # Only used for visualization 24 | }, 25 | { 26 | 'id': 'C2', 27 | 'res_w': 640, 28 | 'res_h': 480, 29 | 'azimuth': -90, # Only used for visualization 30 | }, 31 | { 32 | 'id': 'C3', 33 | 'res_w': 640, 34 | 'res_h': 480, 35 | 'azimuth': 90, # Only used for visualization 36 | }, 37 | ] 38 | 39 | humaneva_cameras_extrinsic_params = { 40 | 'S1': [ 41 | { 42 | 'orientation': [0.424207, -0.4983646, -0.5802981, 0.4847012], 43 | 'translation': [4062.227, 663.2477, 1528.397], 44 | }, 45 | { 46 | 'orientation': [0.6503354, -0.7481602, -0.0919284, 0.0941766], 47 | 'translation': [844.8131, -3805.2092, 1504.9929], 48 | }, 49 | { 50 | 'orientation': [0.0664734, -0.0690535, 0.7416416, -0.6639132], 51 | 'translation': [-797.67377, 3916.3174, 1433.6602], 52 | }, 53 | ], 54 | 'S2': [ 55 | { 56 | 'orientation': [ 0.4214752, -0.4961493, -0.5838273, 0.4851187 ], 57 | 'translation': [ 4112.9121, 626.4929, 1545.2988], 58 | }, 59 | { 60 | 'orientation': [ 0.6501393, -0.7476588, -0.0954617, 0.0959808 ], 61 | 'translation': [ 923.5740, -3877.9243, 1504.5518], 62 | }, 63 | { 64 | 'orientation': [ 0.0699353, -0.0712403, 0.7421637, -0.662742 ], 65 | 'translation': [ -781.4915, 3838.8853, 1444.9929], 66 | }, 67 | ], 68 | 'S3': [ 69 | { 70 | 'orientation': [ 0.424207, -0.4983646, -0.5802981, 0.4847012 ], 71 | 'translation': [ 4062.2271, 663.2477, 1528.3970], 72 | }, 73 | { 74 | 'orientation': [ 0.6503354, -0.7481602, -0.0919284, 0.0941766 ], 75 | 'translation': [ 844.8131, -3805.2092, 1504.9929], 76 | }, 77 | { 78 | 'orientation': [ 0.0664734, -0.0690535, 0.7416416, -0.6639132 ], 79 | 'translation': [ -797.6738, 3916.3174, 1433.6602], 80 | }, 81 | ], 82 | 'S4': [ 83 | {}, 84 | {}, 85 | {}, 86 | ], 87 | 88 | } 89 | 90 | class HumanEvaDataset(MocapDataset): 91 | def __init__(self, path): 92 | super().__init__(fps=60, skeleton=humaneva_skeleton) 93 | 94 | self._cameras = copy.deepcopy(humaneva_cameras_extrinsic_params) 95 | for cameras in self._cameras.values(): 96 | for i, cam in enumerate(cameras): 97 | cam.update(humaneva_cameras_intrinsic_params[i]) 98 | for k, v in cam.items(): 99 | if k not in ['id', 'res_w', 'res_h']: 100 | cam[k] = np.array(v, dtype='float32') 101 | if 'translation' in cam: 102 | cam['translation'] = cam['translation']/1000 # mm to meters 103 | 104 | for subject in list(self._cameras.keys()): 105 | data = self._cameras[subject] 106 | del self._cameras[subject] 107 | for prefix in ['Train/', 'Validate/', 'Unlabeled/Train/', 'Unlabeled/Validate/', 'Unlabeled/']: 108 | self._cameras[prefix + subject] = data 109 | 110 | # Load serialized dataset 111 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 112 | 113 | self._data = {} 114 | for subject, actions in data.items(): 115 | self._data[subject] = {} 116 | for action_name, positions in actions.items(): 117 | self._data[subject][action_name] = { 118 | 'positions': positions, 119 | 'cameras': self._cameras[subject], 120 | } 121 | -------------------------------------------------------------------------------- 
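A minimal sketch of how `HumanEvaDataset` above is typically constructed. The `.npz` path is hypothetical; it is produced by the VideoPose3D dataset-preparation scripts referenced in the README:

```python
from common.humaneva_dataset import HumanEvaDataset

dataset = HumanEvaDataset('data/data_3d_humaneva15.npz')   # placeholder path
print(dataset.fps())                        # 60
print(dataset.skeleton().num_joints())      # 15
print(list(dataset.subjects()))             # e.g. 'Train/S1', 'Validate/S1', ...
```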
/crossformer.yml: -------------------------------------------------------------------------------- 1 | name: pose2 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _tflow_select=2.3.0=mkl 10 | - absl-py=0.10.0=py38h32f6830_1 11 | - aiohttp=3.7.2=py38h1e0a361_0 12 | - astunparse=1.6.3=py_0 13 | - async-timeout=3.0.1=py_1000 14 | - attrs=20.2.0=pyh9f0ad1d_0 15 | - blas=1.0=mkl 16 | - blinker=1.4=py_1 17 | - brotlipy=0.7.0=py38h7b6447c_1000 18 | - bzip2=1.0.8=h7b6447c_0 19 | - c-ares=1.16.1=h516909a_3 20 | - ca-certificates=2021.1.19=h06a4308_0 21 | - cachetools=4.1.1=py_0 22 | - cairo=1.16.0=h3fc0475_1005 23 | - cdflib=0.3.18=py_0 24 | - certifi=2020.12.5=py38h06a4308_0 25 | - cffi=1.14.0=py38h2e261b9_0 26 | - chardet=3.0.4=py38_1003 27 | - click=7.1.2=pyh9f0ad1d_0 28 | - cryptography=3.1.1=py38h1ba5d50_0 29 | - cudatoolkit=11.0.221=h6bb024c_0 30 | - cycler=0.10.0=py_2 31 | - cython=0.29.21=py38he6710b0_0 32 | - dbus=1.13.16=hb2f20db_0 33 | - einops=0.3.0=py_0 34 | - expat=2.2.9=he6710b0_2 35 | - fontconfig=2.13.1=h1056068_1002 36 | - freetype=2.10.2=h5ab3b9f_0 37 | - gast=0.3.3=py_0 38 | - glib=2.63.1=h5a9c865_0 39 | - gmp=6.2.0=he1b5a44_2 40 | - gnutls=3.6.13=h79a8f9a_0 41 | - google-auth=1.23.0=pyhd8ed1ab_0 42 | - google-auth-oauthlib=0.4.1=py_2 43 | - google-pasta=0.2.0=py_0 44 | - graphite2=1.3.14=h23475e2_0 45 | - grpcio=1.33.2=py38heead2fc_0 46 | - gst-plugins-base=1.14.5=h0935bb2_2 47 | - gstreamer=1.14.5=h36ae1b5_2 48 | - h5py=2.10.0=py38hd6299e0_1 49 | - harfbuzz=2.7.2=hee91db6_0 50 | - hdf5=1.10.6=hb1b8bf9_0 51 | - icu=67.1=he1b5a44_0 52 | - idna=2.10=py_0 53 | - importlib-metadata=2.0.0=py_1 54 | - intel-openmp=2020.2=254 55 | - jasper=1.900.1=hd497a04_4 56 | - jpeg=9d=h516909a_0 57 | - json_tricks=3.15.3=pyh9f0ad1d_0 58 | - keras-preprocessing=1.1.0=py_1 59 | - kiwisolver=1.2.0=py38hbf85e49_0 60 | - krb5=1.17.1=h173b8e3_0 61 | - lame=3.100=h7b6447c_0 62 | - lcms2=2.11=h396b838_0 63 | - ld_impl_linux-64=2.33.1=h53a641e_7 64 | - libblas=3.8.0=16_mkl 65 | - libcblas=3.8.0=16_mkl 66 | - libclang=10.0.1=default_hb85057a_2 67 | - libedit=3.1.20191231=h14c3975_1 68 | - libevent=2.1.10=hcdb4288_2 69 | - libffi=3.2.1=hf484d3e_1007 70 | - libgcc-ng=9.1.0=hdf63c60_0 71 | - libgfortran-ng=7.3.0=hdf63c60_0 72 | - libiconv=1.16=h516909a_0 73 | - liblapack=3.8.0=16_mkl 74 | - liblapacke=3.8.0=16_mkl 75 | - libllvm10=10.0.1=hbcb73fb_5 76 | - libpng=1.6.37=hbc83047_0 77 | - libpq=12.3=h5513abc_0 78 | - libprotobuf=3.13.0=hd408876_0 79 | - libsodium=1.0.18=h7b6447c_0 80 | - libstdcxx-ng=9.1.0=hdf63c60_0 81 | - libtiff=4.1.0=h2733197_1 82 | - libuuid=2.32.1=h14c3975_1000 83 | - libuv=1.40.0=h7b6447c_0 84 | - libwebp-base=1.1.0=h7b6447c_3 85 | - libxcb=1.14=h7b6447c_0 86 | - libxkbcommon=0.10.0=he1b5a44_0 87 | - libxml2=2.9.10=h68273f3_2 88 | - lz4-c=1.9.2=he6710b0_1 89 | - markdown=3.3.3=pyh9f0ad1d_0 90 | - matplotlib=3.3.2=0 91 | - matplotlib-base=3.3.2=py38h91b0d89_0 92 | - mkl=2020.2=256 93 | - mkl-service=2.3.0=py38he904b0f_0 94 | - mkl_fft=1.2.0=py38h23d657b_0 95 | - mkl_random=1.1.1=py38h0573a6f_0 96 | - multidict=4.7.5=py38h1e0a361_2 97 | - mysql-common=8.0.21=2 98 | - mysql-libs=8.0.21=hf3661c5_2 99 | - ncurses=6.2=he6710b0_1 100 | - nettle=3.4.1=hbb512f6_0 101 | - nibabel=3.1.1=py_0 102 | - ninja=1.10.1=py38hfd86e86_0 103 | - nspr=4.29=he1b5a44_0 104 | - nss=3.57=he751ad9_0 105 | - numpy-base=1.19.2=py38hfa32c7d_0 106 | - oauthlib=3.0.1=py_0 107 | - olefile=0.46=py_0 108 | - openh264=2.1.1=h8b12597_0 109 | - 
openssl=1.1.1i=h27cfd23_0 110 | - opt_einsum=3.1.0=py_0 111 | - packaging=20.4=py_0 112 | - pandas=1.1.1=py38he6710b0_0 113 | - pcre=8.44=he6710b0_0 114 | - pillow=7.2.0=py38hb39fc2d_0 115 | - pip=20.2.2=py38_0 116 | - pixman=0.38.0=h7b6447c_0 117 | - protobuf=3.13.0=py38he6710b0_1 118 | - psutil=5.7.2=py38h7b6447c_0 119 | - pyasn1=0.4.8=py_0 120 | - pyasn1-modules=0.2.7=py_0 121 | - pycparser=2.20=py_2 122 | - pydicom=2.0.0=pyh9f0ad1d_0 123 | - pyjwt=1.7.1=py_0 124 | - pyopenssl=19.1.0=py_1 125 | - pyparsing=2.4.7=py_0 126 | - pysocks=1.7.1=py38_0 127 | - python=3.8.2=hcf32534_0 128 | - python-dateutil=2.8.1=py_0 129 | - python_abi=3.8=1_cp38 130 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 131 | - pytz=2020.1=py_0 132 | - pyyaml=5.3.1=py38h7b6447c_1 133 | - pyzmq=19.0.2=py38he6710b0_1 134 | - qt=5.12.9=h1f2b2cb_0 135 | - readline=8.0=h7b6447c_0 136 | - requests=2.24.0=py_0 137 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 138 | - rsa=4.6=pyh9f0ad1d_0 139 | - scipy=1.5.2=py38h0b6359f_0 140 | - setuptools=49.6.0=py38_1 141 | - six=1.15.0=py_0 142 | - sqlite=3.33.0=h62c20be_0 143 | - tensorboard=2.3.0=py_0 144 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 145 | - tensorboardx=2.1=py_0 146 | - tensorflow=2.2.0=mkl_py38h6d3daf0_0 147 | - tensorflow-base=2.2.0=mkl_py38h5059a2d_0 148 | - tensorflow-estimator=2.2.0=pyh208ff02_0 149 | - termcolor=1.1.0=py38_1 150 | - tk=8.6.10=hbc83047_0 151 | - torchfile=0.1.0=py_0 152 | - tornado=6.0.4=py38h7b6447c_1 153 | - tqdm=4.50.0=pyh9f0ad1d_0 154 | - typing-extensions=3.7.4.3=0 155 | - typing_extensions=3.7.4.3=py_0 156 | - urllib3=1.25.10=py_0 157 | - visdom=0.1.8.9=0 158 | - websocket-client=0.57.0=py38_1 159 | - werkzeug=1.0.1=pyh9f0ad1d_0 160 | - wheel=0.35.1=py_0 161 | - wrapt=1.12.1=py38h7b6447c_1 162 | - x264=1!152.20180806=h7b6447c_0 163 | - xorg-kbproto=1.0.7=h14c3975_1002 164 | - xorg-libice=1.0.10=h516909a_0 165 | - xorg-libsm=1.2.3=h84519dc_1000 166 | - xorg-libx11=1.6.12=h516909a_0 167 | - xorg-libxext=1.3.4=h516909a_0 168 | - xorg-libxrender=0.9.10=h516909a_1002 169 | - xorg-renderproto=0.11.1=h14c3975_1002 170 | - xorg-xextproto=7.3.0=h14c3975_1002 171 | - xorg-xproto=7.0.31=h14c3975_1007 172 | - xz=5.2.5=h7b6447c_0 173 | - yacs=0.1.6=py_0 174 | - yaml=0.2.5=h7b6447c_0 175 | - yarl=1.6.2=py38h1e0a361_0 176 | - zeromq=4.3.2=he6710b0_3 177 | - zipp=3.4.0=py_0 178 | - zlib=1.2.11=h7b6447c_3 179 | - zstd=1.4.5=h9ceee32_0 180 | - pip: 181 | - decorator==4.4.2 182 | - easydict==1.7 183 | - fvcore==0.1.3.post20210311 184 | - imageio==2.9.0 185 | - iopath==0.1.4 186 | - munkres==1.1.4 187 | - networkx==2.5 188 | - numpy==1.16.2 189 | - opencv-python==4.4.0.44 190 | - portalocker==2.2.1 191 | - progress==1.5 192 | - pthflops==0.4.0 193 | - pywavelets==1.1.1 194 | - scikit-image==0.17.2 195 | - tabulate==0.8.9 196 | - thop==0.0.31-2005241907 197 | - tifffile==2020.10.1 198 | - timm==0.3.4 199 | - torchscan==0.1.1 200 | - torchstat==0.0.7 201 | - torchvision==0.8.2 202 | prefix: /home/cezheng/anaconda3/envs/pose2 203 | -------------------------------------------------------------------------------- /common/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) 
for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | def layer_drop(layers, prob): 20 | to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob 21 | blocks = [block for block, drop in zip(layers, to_drop) if not drop] 22 | blocks = layers[:1] if len(blocks) == 0 else blocks 23 | return blocks 24 | 25 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 26 | class Deterministic(nn.Module): 27 | def __init__(self, net): 28 | super().__init__() 29 | self.net = net 30 | self.cpu_state = None 31 | self.cuda_in_fwd = None 32 | self.gpu_devices = None 33 | self.gpu_states = None 34 | 35 | def record_rng(self, *args): 36 | self.cpu_state = torch.get_rng_state() 37 | if torch.cuda._initialized: 38 | self.cuda_in_fwd = True 39 | self.gpu_devices, self.gpu_states = get_device_states(*args) 40 | 41 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 42 | if record_rng: 43 | self.record_rng(*args) 44 | 45 | if not set_rng: 46 | return self.net(*args, **kwargs) 47 | 48 | rng_devices = [] 49 | if self.cuda_in_fwd: 50 | rng_devices = self.gpu_devices 51 | 52 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 53 | torch.set_rng_state(self.cpu_state) 54 | if self.cuda_in_fwd: 55 | set_device_states(self.gpu_devices, self.gpu_states) 56 | return self.net(*args, **kwargs) 57 | 58 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 59 | # once multi-GPU is confirmed working, refactor and send PR back to source 60 | class ReversibleBlock(nn.Module): 61 | def __init__(self, f, g): 62 | super().__init__() 63 | self.f = Deterministic(f) 64 | self.g = Deterministic(g) 65 | 66 | def forward(self, x, f_args = {}, g_args = {}): 67 | x1, x2 = torch.chunk(x, 2, dim=2) 68 | y1, y2 = None, None 69 | 70 | with torch.no_grad(): 71 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 72 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 73 | 74 | return torch.cat([y1, y2], dim=2) 75 | 76 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 77 | y1, y2 = torch.chunk(y, 2, dim=2) 78 | del y 79 | 80 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 81 | del dy 82 | 83 | with torch.enable_grad(): 84 | y1.requires_grad = True 85 | gy1 = self.g(y1, set_rng=True, **g_args) 86 | torch.autograd.backward(gy1, dy2) 87 | 88 | with torch.no_grad(): 89 | x2 = y2 - gy1 90 | del y2, gy1 91 | 92 | dx1 = dy1 + y1.grad 93 | del dy1 94 | y1.grad = None 95 | 96 | with torch.enable_grad(): 97 | x2.requires_grad = True 98 | fx2 = self.f(x2, set_rng=True, **f_args) 99 | torch.autograd.backward(fx2, dx1, retain_graph=True) 100 | 101 | with torch.no_grad(): 102 | x1 = y1 - fx2 103 | del y1, fx2 104 | 105 | dx2 = dy2 + x2.grad 106 | del dy2 107 | x2.grad = None 108 | 109 | x = torch.cat([x1, x2.detach()], dim=2) 110 | dx = torch.cat([dx1, dx2], dim=2) 111 | 112 | return x, dx 113 | 114 | class _ReversibleFunction(Function): 115 | @staticmethod 116 | def forward(ctx, x, blocks, args): 117 | ctx.args = args 118 | for block, kwarg in zip(blocks, args): 119 | x = block(x, **kwarg) 120 | ctx.y = x.detach() 121 | 
ctx.blocks = blocks 122 | return x 123 | 124 | @staticmethod 125 | def backward(ctx, dy): 126 | y = ctx.y 127 | args = ctx.args 128 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 129 | y, dy = block.backward_pass(y, dy, **kwargs) 130 | return dy, None, None 131 | 132 | 133 | class SequentialSequence(nn.Module): 134 | def __init__(self, layers, args_route = {}, layer_dropout = 0.): 135 | super().__init__() 136 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 137 | self.layers = layers 138 | self.args_route = args_route 139 | self.layer_dropout = layer_dropout 140 | 141 | def forward(self, x, **kwargs): 142 | args = route_args(self.args_route, kwargs, len(self.layers)) 143 | layers_and_args = list(zip(self.layers, args)) 144 | 145 | if self.training and self.layer_dropout > 0: 146 | layers_and_args = layer_drop(layers_and_args, self.layer_dropout) 147 | 148 | for (f, g), (f_args, g_args) in layers_and_args: 149 | x = x + f(x, **f_args) 150 | x = x + g(x, **g_args) 151 | return x 152 | 153 | class ReversibleSequence(nn.Module): 154 | def __init__(self, blocks, args_route = {}, layer_dropout = 0.): 155 | super().__init__() 156 | self.args_route = args_route 157 | self.layer_dropout = layer_dropout 158 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 159 | 160 | def forward(self, x, **kwargs): 161 | x = torch.cat([x, x], dim=-1) 162 | 163 | blocks = self.blocks 164 | args = route_args(self.args_route, kwargs, len(blocks)) 165 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 166 | 167 | layers_and_args = list(zip(blocks, args)) 168 | 169 | if self.training and self.layer_dropout > 0: 170 | layers_and_args = layer_drop(layers_and_args, self.layer_dropout) 171 | blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1)) 172 | 173 | out = _ReversibleFunction.apply(x, blocks, args) 174 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 175 | -------------------------------------------------------------------------------- /common/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
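A minimal sketch of driving `ReversibleSequence` from `common/reversible.py` above. The `(f, g)` pairs here are toy LayerNorm+Linear stand-ins, not the model's actual attention/MLP blocks; both halves must map `(B, N, dim)` to `(B, N, dim)`:

```python
import torch
import torch.nn as nn
from common.reversible import ReversibleSequence

dim = 64
blocks = nn.ModuleList([
    nn.ModuleList([nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)),
                   nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))])
    for _ in range(4)
])
net = ReversibleSequence(blocks)

x = torch.randn(2, 81, dim)      # (batch, frames, channels)
out = net(x)                     # input is duplicated, run reversibly, halves summed
print(out.shape)                 # torch.Size([2, 81, 64])
```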
6 | # 7 | 8 | import argparse 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Training script') 12 | 13 | # General arguments 14 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva 15 | parser.add_argument('-k', '--keypoints', default='gt', type=str, metavar='NAME', help='2D detections to use') 16 | parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST', 17 | help='training subjects separated by comma') 18 | parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma') 19 | parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST', 20 | help='unlabeled subjects separated by comma for self-supervision') 21 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST', 22 | help='actions to train/test on, separated by comma, or * for all') 23 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 24 | help='checkpoint directory') 25 | parser.add_argument('--checkpoint-frequency', default=40, type=int, metavar='N', 26 | help='create a checkpoint every N epochs') 27 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', 28 | help='checkpoint to resume (file name)') 29 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') 30 | parser.add_argument('--render', action='store_true', help='visualize a particular video') 31 | parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)') 32 | parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images') 33 | 34 | 35 | # Model arguments 36 | parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training') 37 | parser.add_argument('-e', '--epochs', default=200, type=int, metavar='N', help='number of training epochs') 38 | parser.add_argument('-b', '--batch-size', default=512, type=int, metavar='N', help='batch size in terms of predicted frames') 39 | parser.add_argument('-drop', '--dropout', default=0.1, type=float, metavar='P', help='dropout probability') 40 | parser.add_argument('-lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate') 41 | parser.add_argument('-lrd', '--lr-decay', default=0.99, type=float, metavar='LR', help='learning rate decay per epoch') 42 | parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false', 43 | help='disable train-time flipping') 44 | # parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false', 45 | # help='disable test-time flipping') 46 | # parser.add_argument('-arc', '--architecture', default='3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma') 47 | parser.add_argument('-frame', '--number-of-frames', default='81', type=int, metavar='N', 48 | help='how many frames used as input') 49 | # parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing') 50 | # parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers') 51 | 52 | # Experimental 53 | parser.add_argument('--subset', default=1, type=float, 
metavar='FRACTION', help='reduce dataset size by fraction') 54 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)') 55 | parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision') 56 | parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)') 57 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions') 58 | parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions') 59 | parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection') 60 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term', 61 | help='disable bone length term in semi-supervised settings') 62 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting') 63 | 64 | # Visualization 65 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render') 66 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render') 67 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render') 68 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video') 69 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video') 70 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)') 71 | parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates') 72 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos') 73 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses') 74 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames') 75 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N') 76 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size') 77 | 78 | parser.set_defaults(bone_length_term=True) 79 | parser.set_defaults(data_augmentation=True) 80 | parser.set_defaults(test_time_augmentation=True) 81 | # parser.set_defaults(test_time_augmentation=False) 82 | 83 | args = parser.parse_args() 84 | # Check invalid configuration 85 | if args.resume and args.evaluate: 86 | print('Invalid flags: --resume and --evaluate cannot be set at the same time') 87 | exit() 88 | 89 | if args.export_training_curves and args.no_eval: 90 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time') 91 | exit() 92 | 93 | return args 94 | -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
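For completeness, a small sketch showing how the flags above are parsed; `sys.argv` is overridden only to make the example self-contained (`run_crossformer.py` presumably just calls `parse_args()` on the real command line):

```python
import sys
from common.arguments import parse_args

sys.argv = ['run_crossformer.py', '-k', 'cpn_ft_h36m_dbb', '-frame', '81', '-lr', '0.00004']
args = parse_args()
print(args.keypoints, args.number_of_frames, args.learning_rate)  # cpn_ft_h36m_dbb 81 4e-05
```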
6 | # 7 | 8 | import matplotlib 9 | 10 | matplotlib.use('Agg') 11 | 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.animation import FuncAnimation, writers 15 | from mpl_toolkits.mplot3d import Axes3D 16 | import numpy as np 17 | import subprocess as sp 18 | import cv2 19 | 20 | 21 | def get_resolution(filename): 22 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 23 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename] 24 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 25 | for line in pipe.stdout: 26 | w, h = line.decode().strip().split(',') 27 | return int(w), int(h) 28 | 29 | 30 | def get_fps(filename): 31 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 32 | '-show_entries', 'stream=r_frame_rate', '-of', 'csv=p=0', filename] 33 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 34 | for line in pipe.stdout: 35 | a, b = line.decode().strip().split('/') 36 | return int(a) / int(b) 37 | 38 | 39 | def read_video(filename, skip=0, limit=-1): 40 | # w, h = get_resolution(filename) 41 | w = 1000 42 | h = 1002 43 | 44 | command = ['ffmpeg', 45 | '-i', filename, 46 | '-f', 'image2pipe', 47 | '-pix_fmt', 'rgb24', 48 | '-vsync', '0', 49 | '-vcodec', 'rawvideo', '-'] 50 | 51 | i = 0 52 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 53 | while True: 54 | data = pipe.stdout.read(w * h * 3) 55 | if not data: 56 | break 57 | i += 1 58 | if i > limit and limit != -1: 59 | continue 60 | if i > skip: 61 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3)) 62 | 63 | 64 | def downsample_tensor(X, factor): 65 | length = X.shape[0] // factor * factor 66 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1) 67 | 68 | 69 | def render_animation(keypoints, keypoints_metadata, poses, skeleton, fps, bitrate, azim, output, viewport, 70 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0): 71 | """ 72 | TODO 73 | Render an animation. The supported output modes are: 74 | -- 'interactive': display an interactive figure 75 | (also works on notebooks if associated with %matplotlib inline) 76 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...). 77 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg). 78 | -- 'filename.gif': render and export the animation a gif file (requires imagemagick). 
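A short sketch of the ffmpeg-based helpers above (`ffprobe` must be on PATH; `'input.mp4'` is a placeholder path, not a file shipped with the repo):

```python
from common.visualization import get_resolution, get_fps

w, h = get_resolution('input.mp4')
fps = get_fps('input.mp4')
print(w, h, fps)

# Note: read_video() above currently hardcodes w=1000, h=1002 instead of calling
# get_resolution(), so it only decodes videos of exactly that resolution.
```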
79 | """ 80 | plt.ioff() 81 | fig = plt.figure(figsize=(size * (1 + len(poses)), size)) 82 | ax_in = fig.add_subplot(1, 1 + len(poses), 1) 83 | ax_in.get_xaxis().set_visible(False) 84 | ax_in.get_yaxis().set_visible(False) 85 | ax_in.set_axis_off() 86 | ax_in.set_title('Input') 87 | 88 | ax_3d = [] 89 | lines_3d = [] 90 | trajectories = [] 91 | radius = 1.7 92 | for index, (title, data) in enumerate(poses.items()): 93 | ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d') 94 | ax.view_init(elev=15., azim=azim) 95 | ax.set_xlim3d([-radius / 2, radius / 2]) 96 | ax.set_zlim3d([0, radius]) 97 | ax.set_ylim3d([-radius / 2, radius / 2]) 98 | try: 99 | ax.set_aspect('equal') 100 | except NotImplementedError: 101 | ax.set_aspect('auto') 102 | ax.set_xticklabels([]) 103 | ax.set_yticklabels([]) 104 | ax.set_zticklabels([]) 105 | ax.dist = 7.5 106 | ax.set_title(title) # , pad=35 107 | ax_3d.append(ax) 108 | lines_3d.append([]) 109 | trajectories.append(data[:, 0, [0, 1]]) 110 | poses = list(poses.values()) 111 | 112 | # Decode video 113 | if input_video_path is None: 114 | # Black background 115 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8') 116 | else: 117 | # Load video using ffmpeg 118 | all_frames = [] 119 | for f in read_video(input_video_path, skip=input_video_skip, limit=limit): 120 | all_frames.append(f) 121 | effective_length = min(keypoints.shape[0], len(all_frames)) 122 | all_frames = all_frames[:effective_length] 123 | 124 | keypoints = keypoints[input_video_skip:] # todo remove 125 | for idx in range(len(poses)): 126 | poses[idx] = poses[idx][input_video_skip:] 127 | 128 | if fps is None: 129 | fps = get_fps(input_video_path) 130 | 131 | if downsample > 1: 132 | keypoints = downsample_tensor(keypoints, downsample) 133 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8') 134 | for idx in range(len(poses)): 135 | poses[idx] = downsample_tensor(poses[idx], downsample) 136 | trajectories[idx] = downsample_tensor(trajectories[idx], downsample) 137 | fps /= downsample 138 | 139 | initialized = False 140 | image = None 141 | lines = [] 142 | points = None 143 | 144 | if limit < 1: 145 | limit = len(all_frames) 146 | else: 147 | limit = min(limit, len(all_frames)) 148 | 149 | parents = skeleton.parents() 150 | 151 | def update_video(i): 152 | nonlocal initialized, image, lines, points 153 | 154 | for n, ax in enumerate(ax_3d): 155 | ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]]) 156 | ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]]) 157 | 158 | # Update 2D poses 159 | joints_right_2d = keypoints_metadata['keypoints_symmetry'][1] 160 | colors_2d = np.full(keypoints.shape[1], 'black') 161 | colors_2d[joints_right_2d] = 'red' 162 | if not initialized: 163 | image = ax_in.imshow(all_frames[i], aspect='equal') 164 | 165 | for j, j_parent in enumerate(parents): 166 | if j_parent == -1: 167 | continue 168 | 169 | if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco': 170 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition) 171 | lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 172 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink')) 173 | 174 | col = 'red' if j in skeleton.joints_right() else 'black' 175 | for n, ax in enumerate(ax_3d): 176 | pos = poses[n][i] 177 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 178 | 
[pos[j, 1], pos[j_parent, 1]], 179 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col)) 180 | 181 | points = ax_in.scatter(*keypoints[i].T, 10, color=colors_2d, edgecolors='white', zorder=10) 182 | 183 | initialized = True 184 | else: 185 | image.set_data(all_frames[i]) 186 | 187 | for j, j_parent in enumerate(parents): 188 | if j_parent == -1: 189 | continue 190 | 191 | if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco': 192 | lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 193 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]]) 194 | 195 | for n, ax in enumerate(ax_3d): 196 | pos = poses[n][i] 197 | lines_3d[n][j - 1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]])) 198 | lines_3d[n][j - 1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]])) 199 | lines_3d[n][j - 1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z') 200 | 201 | points.set_offsets(keypoints[i]) 202 | 203 | print('{}/{} '.format(i, limit), end='\r') 204 | 205 | fig.tight_layout() 206 | 207 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False) 208 | if output.endswith('.mp4'): 209 | Writer = writers['ffmpeg'] 210 | writer = Writer(fps=fps, metadata={}, bitrate=bitrate) 211 | anim.save(output, writer=writer) 212 | elif output.endswith('.gif'): 213 | anim.save(output, dpi=80, writer='imagemagick') 214 | else: 215 | raise ValueError('Unsupported output format (only .mp4 and .gif are supported)') 216 | plt.close() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /common/h36m_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | 14 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 15 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 16 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 17 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 18 | 19 | h36m_cameras_intrinsic_params = [ 20 | { 21 | 'id': '54138969', 22 | 'center': [512.54150390625, 515.4514770507812], 23 | 'focal_length': [1145.0494384765625, 1143.7811279296875], 24 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], 25 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], 26 | 'res_w': 1000, 27 | 'res_h': 1002, 28 | 'azimuth': 70, # Only used for visualization 29 | }, 30 | { 31 | 'id': '55011271', 32 | 'center': [508.8486328125, 508.0649108886719], 33 | 'focal_length': [1149.6756591796875, 1147.5916748046875], 34 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], 35 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], 36 | 'res_w': 1000, 37 | 'res_h': 1000, 38 | 'azimuth': -70, # Only used for visualization 39 | }, 40 | { 41 | 'id': '58860488', 42 | 'center': [519.8158569335938, 501.40264892578125], 43 | 'focal_length': [1149.1407470703125, 1148.7989501953125], 44 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], 45 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], 46 | 'res_w': 1000, 47 | 'res_h': 1000, 48 | 'azimuth': 110, # Only used for visualization 49 | }, 50 | { 51 | 'id': '60457274', 52 | 'center': [514.9682006835938, 501.88201904296875], 53 | 'focal_length': [1145.5113525390625, 1144.77392578125], 54 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], 55 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], 56 | 'res_w': 1000, 57 | 'res_h': 1002, 58 | 'azimuth': -110, # Only used for visualization 59 | }, 60 | ] 61 | 62 | h36m_cameras_extrinsic_params = { 63 | 'S1': [ 64 | { 65 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 66 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 67 | }, 68 | { 69 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], 70 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], 71 | }, 72 | { 73 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], 74 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], 75 | }, 76 | { 77 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435], 78 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], 79 | }, 80 | ], 81 | 'S2': [ 82 | {}, 83 | {}, 84 | {}, 85 | {}, 86 | ], 87 | 'S3': [ 88 | {}, 89 | {}, 90 | {}, 91 | {}, 92 | ], 93 | 'S4': [ 94 | {}, 95 | {}, 96 | {}, 97 | {}, 98 | ], 99 | 'S5': [ 100 | { 101 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], 102 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], 103 | }, 104 | { 105 | 'orientation': [0.6159758567810059, 
-0.7626792192459106, -0.15728192031383514, 0.1189815029501915], 106 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], 107 | }, 108 | { 109 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576], 110 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], 111 | }, 112 | { 113 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], 114 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], 115 | }, 116 | ], 117 | 'S6': [ 118 | { 119 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], 120 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], 121 | }, 122 | { 123 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], 124 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], 125 | }, 126 | { 127 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], 128 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], 129 | }, 130 | { 131 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], 132 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], 133 | }, 134 | ], 135 | 'S7': [ 136 | { 137 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], 138 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], 139 | }, 140 | { 141 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], 142 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], 143 | }, 144 | { 145 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278], 146 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], 147 | }, 148 | { 149 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], 150 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], 151 | }, 152 | ], 153 | 'S8': [ 154 | { 155 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], 156 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], 157 | }, 158 | { 159 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], 160 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], 161 | }, 162 | { 163 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], 164 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], 165 | }, 166 | { 167 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708], 168 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], 169 | }, 170 | ], 171 | 'S9': [ 172 | { 173 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], 174 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], 175 | }, 176 | { 177 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], 178 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], 179 | }, 180 | { 181 | 'orientation': [0.13357827067375183, 
-0.1367100477218628, 0.7689454555511475, -0.6100738644599915], 182 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], 183 | }, 184 | { 185 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], 186 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], 187 | }, 188 | ], 189 | 'S11': [ 190 | { 191 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], 192 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], 193 | }, 194 | { 195 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], 196 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], 197 | }, 198 | { 199 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], 200 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], 201 | }, 202 | { 203 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], 204 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], 205 | }, 206 | ], 207 | } 208 | 209 | class Human36mDataset(MocapDataset): 210 | def __init__(self, path, remove_static_joints=True): 211 | super().__init__(fps=50, skeleton=h36m_skeleton) 212 | 213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) 214 | for cameras in self._cameras.values(): 215 | for i, cam in enumerate(cameras): 216 | cam.update(h36m_cameras_intrinsic_params[i]) 217 | for k, v in cam.items(): 218 | if k not in ['id', 'res_w', 'res_h']: 219 | cam[k] = np.array(v, dtype='float32') 220 | 221 | # Normalize camera frame 222 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32') 223 | cam['focal_length'] = cam['focal_length']/cam['res_w']*2 224 | if 'translation' in cam: 225 | cam['translation'] = cam['translation']/1000 # mm to meters 226 | 227 | # Add intrinsic parameters vector 228 | cam['intrinsic'] = np.concatenate((cam['focal_length'], 229 | cam['center'], 230 | cam['radial_distortion'], 231 | cam['tangential_distortion'])) 232 | 233 | # Load serialized dataset 234 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 235 | 236 | self._data = {} 237 | for subject, actions in data.items(): 238 | self._data[subject] = {} 239 | for action_name, positions in actions.items(): 240 | self._data[subject][action_name] = { 241 | 'positions': positions, 242 | 'cameras': self._cameras[subject], 243 | } 244 | 245 | if remove_static_joints: 246 | # Bring the skeleton to 17 joints instead of the original 32 247 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 248 | # for gt 249 | # self.remove_joints([4, 5, 9, 10, 11, 14, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 250 | # Rewire shoulders to the correct parents 251 | self._skeleton._parents[11] = 8 252 | self._skeleton._parents[14] = 8 253 | 254 | def supports_semi_supervised(self): 255 | return True 256 | -------------------------------------------------------------------------------- /common/generators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
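# A minimal sketch of the camera preprocessing performed in Human36mDataset above, assuming the
# usual VideoPose3D convention normalize_screen_coordinates(X, w, h) = X/w*2 - [1, h/w]; the helper
# name below is local to this sketch (not part of common/camera.py) and the distortion terms are
# zero-stubbed for brevity.
import numpy as np

def _normalize_screen_coordinates(X, w, h):
    # Map [0, w] to [-1, 1] while preserving the aspect ratio (assumed convention).
    return X / w * 2 - np.array([1, h / w])

center = np.array([512.54150390625, 515.4514770507812], dtype='float32')    # camera '54138969'
focal = np.array([1145.0494384765625, 1143.7811279296875], dtype='float32')
res_w, res_h = 1000, 1002

center_n = _normalize_screen_coordinates(center, res_w, res_h)   # ~[0.0251, 0.0289]
focal_n = focal / res_w * 2                                       # ~[2.2901, 2.2876]
# The 9-dimensional 'intrinsic' vector packs 2 focal + 2 center + 3 radial + 2 tangential terms.
intrinsic = np.concatenate((focal_n, center_n,
                            np.zeros(3, dtype='float32'),         # radial distortion (stubbed)
                            np.zeros(2, dtype='float32')))        # tangential distortion (stubbed)
assert intrinsic.shape == (9,)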
6 | # 7 | 8 | from itertools import zip_longest 9 | import numpy as np 10 | 11 | 12 | # def getbone(seq, boneindex): 13 | # bs = np.shape(seq)[0] 14 | # ss = np.shape(seq)[1] 15 | # seq = np.reshape(seq,(bs*ss,-1,3)) 16 | # bone = [] 17 | # for index in boneindex: 18 | # bone.append(seq[:,index[0]] - seq[:,index[1]]) 19 | # bone = np.stack(bone,1) 20 | # bone = np.power(np.power(bone,2).sum(2),0.5) 21 | # bone = np.reshape(bone, (bs,ss,np.shape(bone)[1])) 22 | # return bone 23 | 24 | class ChunkedGenerator: 25 | """ 26 | Batched data generator, used for training. 27 | The sequences are split into equal-length chunks and padded as necessary. 28 | 29 | Arguments: 30 | batch_size -- the batch size to use for training 31 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 32 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 33 | poses_2d -- list of input 2D keypoints, one element for each video 34 | chunk_length -- number of output frames to predict for each training example (usually 1) 35 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 36 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 37 | shuffle -- randomly shuffle the dataset before each epoch 38 | random_seed -- initial seed to use for the random generator 39 | augment -- augment the dataset by flipping poses horizontally 40 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 41 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 42 | """ 43 | def __init__(self, batch_size, cameras, poses_3d, poses_2d, 44 | chunk_length, pad=0, causal_shift=0, 45 | shuffle=True, random_seed=1234, 46 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None, 47 | endless=False): 48 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d)) 49 | assert cameras is None or len(cameras) == len(poses_2d) 50 | 51 | # Build lineage info 52 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples 53 | for i in range(len(poses_2d)): 54 | assert poses_3d is None or poses_3d[i].shape[0] == poses_3d[i].shape[0] 55 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length 56 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2 57 | bounds = np.arange(n_chunks+1)*chunk_length - offset 58 | augment_vector = np.full(len(bounds - 1), False, dtype=bool) 59 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], augment_vector) 60 | if augment: 61 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], ~augment_vector) 62 | 63 | # Initialize buffers 64 | if cameras is not None: 65 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1])) 66 | if poses_3d is not None: 67 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1])) 68 | self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 69 | 70 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size 71 | self.batch_size = batch_size 72 | self.random = np.random.RandomState(random_seed) 73 | self.pairs = pairs 74 | self.shuffle = shuffle 75 | self.pad = pad 76 | self.causal_shift = causal_shift 77 | self.endless = endless 78 | self.state = None 79 | 80 | self.cameras = cameras 81 | self.poses_3d = poses_3d 82 | 
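        # Worked example of the chunk bookkeeping built above (a sketch, assuming chunk_length=3
        # and a single 7-frame video): n_chunks = (7 + 3 - 1) // 3 = 3, offset = (3*3 - 7) // 2 = 1,
        # bounds = np.arange(4) * 3 - 1 = [-1, 2, 5, 8], so pairs becomes
        # [(0, -1, 2, False), (0, 2, 5, False), (0, 5, 8, False)] (plus flipped copies when
        # augment=True); out-of-range frame indices are resolved in next_epoch() by 'edge' padding.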
self.poses_2d = poses_2d 83 | 84 | self.augment = augment 85 | self.kps_left = kps_left 86 | self.kps_right = kps_right 87 | self.joints_left = joints_left 88 | self.joints_right = joints_right 89 | 90 | def num_frames(self): 91 | return self.num_batches * self.batch_size 92 | 93 | def random_state(self): 94 | return self.random 95 | 96 | def set_random_state(self, random): 97 | self.random = random 98 | 99 | def augment_enabled(self): 100 | return self.augment 101 | 102 | def next_pairs(self): 103 | if self.state is None: 104 | if self.shuffle: 105 | pairs = self.random.permutation(self.pairs) 106 | else: 107 | pairs = self.pairs 108 | return 0, pairs 109 | else: 110 | return self.state 111 | 112 | def next_epoch(self): 113 | enabled = True 114 | while enabled: 115 | start_idx, pairs = self.next_pairs() 116 | for b_i in range(start_idx, self.num_batches): 117 | chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size] 118 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks): 119 | start_2d = start_3d - self.pad - self.causal_shift 120 | end_2d = end_3d + self.pad - self.causal_shift 121 | 122 | # 2D poses 123 | seq_2d = self.poses_2d[seq_i] 124 | low_2d = max(start_2d, 0) 125 | high_2d = min(end_2d, seq_2d.shape[0]) 126 | pad_left_2d = low_2d - start_2d 127 | pad_right_2d = end_2d - high_2d 128 | if pad_left_2d != 0 or pad_right_2d != 0: 129 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge') 130 | else: 131 | self.batch_2d[i] = seq_2d[low_2d:high_2d] 132 | 133 | if flip: 134 | # Flip 2D keypoints 135 | self.batch_2d[i, :, :, 0] *= -1 136 | self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left] 137 | 138 | # 3D poses 139 | if self.poses_3d is not None: 140 | seq_3d = self.poses_3d[seq_i] 141 | low_3d = max(start_3d, 0) 142 | high_3d = min(end_3d, seq_3d.shape[0]) 143 | pad_left_3d = low_3d - start_3d 144 | pad_right_3d = end_3d - high_3d 145 | if pad_left_3d != 0 or pad_right_3d != 0: 146 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge') 147 | else: 148 | self.batch_3d[i] = seq_3d[low_3d:high_3d] 149 | 150 | if flip: 151 | # Flip 3D joints 152 | self.batch_3d[i, :, :, 0] *= -1 153 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \ 154 | self.batch_3d[i, :, self.joints_right + self.joints_left] 155 | 156 | # Cameras 157 | if self.cameras is not None: 158 | self.batch_cam[i] = self.cameras[seq_i] 159 | if flip: 160 | # Flip horizontal distortion coefficients 161 | self.batch_cam[i, 2] *= -1 162 | self.batch_cam[i, 7] *= -1 163 | 164 | if self.endless: 165 | self.state = (b_i + 1, pairs) 166 | if self.poses_3d is None and self.cameras is None: 167 | yield None, None, self.batch_2d[:len(chunks)] 168 | elif self.poses_3d is not None and self.cameras is None: 169 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 170 | elif self.poses_3d is None: 171 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)] 172 | else: 173 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 174 | 175 | if self.endless: 176 | self.state = None 177 | else: 178 | enabled = False 179 | 180 | 181 | class UnchunkedGenerator: 182 | """ 183 | Non-batched data generator, used for testing. 184 | Sequences are returned one at a time (i.e. batch size = 1), without chunking. 185 | 186 | If data augmentation is enabled, the batches contain two sequences (i.e. 
batch size = 2), 187 | the second of which is a mirrored version of the first. 188 | 189 | Arguments: 190 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 191 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 192 | poses_2d -- list of input 2D keypoints, one element for each video 193 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 194 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 195 | augment -- augment the dataset by flipping poses horizontally 196 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 197 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 198 | """ 199 | 200 | def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0, 201 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None): 202 | assert poses_3d is None or len(poses_3d) == len(poses_2d) 203 | assert cameras is None or len(cameras) == len(poses_2d) 204 | 205 | self.augment = False 206 | self.kps_left = kps_left 207 | self.kps_right = kps_right 208 | self.joints_left = joints_left 209 | self.joints_right = joints_right 210 | 211 | self.pad = pad 212 | self.causal_shift = causal_shift 213 | self.cameras = [] if cameras is None else cameras 214 | self.poses_3d = [] if poses_3d is None else poses_3d 215 | self.poses_2d = poses_2d 216 | 217 | def num_frames(self): 218 | count = 0 219 | for p in self.poses_2d: 220 | count += p.shape[0] 221 | return count 222 | 223 | def augment_enabled(self): 224 | return self.augment 225 | 226 | def set_augment(self, augment): 227 | self.augment = augment 228 | 229 | def next_epoch(self): 230 | for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d): 231 | batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0) 232 | batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0) 233 | batch_2d = np.expand_dims(np.pad(seq_2d, 234 | ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)), 235 | 'edge'), axis=0) 236 | if self.augment: 237 | # Append flipped version 238 | if batch_cam is not None: 239 | batch_cam = np.concatenate((batch_cam, batch_cam), axis=0) 240 | batch_cam[1, 2] *= -1 241 | batch_cam[1, 7] *= -1 242 | 243 | if batch_3d is not None: 244 | batch_3d = np.concatenate((batch_3d, batch_3d), axis=0) 245 | batch_3d[1, :, :, 0] *= -1 246 | batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left] 247 | 248 | batch_2d = np.concatenate((batch_2d, batch_2d), axis=0) 249 | batch_2d[1, :, :, 0] *= -1 250 | batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left] 251 | 252 | yield batch_cam, batch_3d, batch_2d -------------------------------------------------------------------------------- /common/model_crossformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from einops import rearrange, repeat 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.models.helpers import load_pretrained 13 | from timm.models.layers import 
DropPath, to_2tuple, trunc_normal_ 14 | from timm.models.registry import register_model 15 | 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | 38 | class SpatialCGNL(nn.Module): 39 | """Spatial CGNL block with dot production kernel for image classfication. 40 | """ 41 | def __init__(self, inplanes, planes, use_scale=False, groups=1): 42 | self.use_scale = use_scale 43 | self.groups = groups 44 | 45 | super(SpatialCGNL, self).__init__() 46 | # conv theta 47 | self.t = nn.Conv1d(inplanes, planes, kernel_size=3, padding=1, stride=1, bias=False) 48 | # conv phi 49 | self.p = nn.Conv1d(inplanes, planes, kernel_size=3, padding=1, stride=1, bias=False) 50 | # conv g 51 | self.g = nn.Conv1d(inplanes, planes, kernel_size=3, padding=1, stride=1, bias=False) 52 | # conv z 53 | self.z = nn.Conv1d(planes, inplanes, kernel_size=3, padding=1, stride=1, 54 | groups=self.groups, bias=False) 55 | self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) 56 | 57 | def kernel(self, t, p, g, b, c, w): 58 | """The linear kernel (dot production). 59 | Args: 60 | t: output of conv theata 61 | p: output of conv phi 62 | g: output of conv g 63 | b: batch size 64 | c: channels number 65 | h: height of featuremaps 66 | w: width of featuremaps 67 | """ 68 | # t = t.view(b, 1,c//4, 4 * w) 69 | # p = p.view(b, 1, c//4, 4 * w) 70 | # g = g.view(b, c//4, 4 * w, 1) 71 | 72 | # t = t.view(b, c//4, 4 * w) 73 | # p = p.view(b, c//4, 4 * w) 74 | # g = g.view(b, 4 * w, c//4) 75 | 76 | t = t.view(b, 1,c * w) 77 | p = p.view(b, 1, c * w) 78 | g = g.view(b, c * w, 1) 79 | 80 | att = torch.bmm(p, g) 81 | 82 | if self.use_scale: 83 | att = att.div((c*w)**0.5) 84 | 85 | x = torch.bmm(att, t) 86 | x = x.view(b, c, w) 87 | return x 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | t = self.t(x) 93 | p = self.p(x) 94 | g = self.g(x) 95 | 96 | b, c, w = t.size() 97 | 98 | if self.groups and self.groups > 1: 99 | _c = int(c / self.groups) 100 | 101 | ts = torch.split(t, split_size_or_sections=_c, dim=1) 102 | ps = torch.split(p, split_size_or_sections=_c, dim=1) 103 | gs = torch.split(g, split_size_or_sections=_c, dim=1) 104 | 105 | _t_sequences = [] 106 | for i in range(self.groups): 107 | _x = self.kernel(ts[i], ps[i], gs[i], 108 | b, _c, w) 109 | _t_sequences.append(_x) 110 | 111 | x = torch.cat(_t_sequences, dim=1) 112 | else: 113 | x = self.kernel(t, p, g, 114 | b, c, w) 115 | 116 | x = self.z(x) 117 | x = self.gn(x) + residual 118 | 119 | return x 120 | 121 | 122 | class LPI(nn.Module): 123 | """ 124 | Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows 125 | to augment the implicit communcation performed by the block diagonal scatter attention. 
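    # Shape walk-through of SpatialCGNL.kernel above (a sketch, for groups == 1): the Conv1d
    # outputs t, p, g each have shape (b, c, w); kernel() flattens them to p: (b, 1, c*w),
    # g: (b, c*w, 1) and t: (b, 1, c*w), so att = torch.bmm(p, g) is a single scalar per sample
    # (divided by sqrt(c*w) when use_scale is set), and x = torch.bmm(att, t) rescales t by that
    # scalar before the z projection, GroupNorm and residual addition in forward().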
126 | Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d 127 | """ 128 | 129 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 130 | drop=0., kernel_size=5, dim=17): 131 | super().__init__() 132 | out_features = out_features or in_features 133 | 134 | padding = kernel_size // 2 135 | 136 | self.conv1 = torch.nn.Conv1d(in_features, out_features, kernel_size=kernel_size, 137 | padding=padding, groups=1) 138 | self.act = act_layer() 139 | # self.bn = nn.SyncBatchNorm(in_features) 140 | self.gn = nn.GroupNorm(num_groups=out_features, num_channels=in_features) 141 | self.bn = nn.BatchNorm1d(in_features) 142 | self.conv2 = torch.nn.Conv1d(in_features, out_features, kernel_size=kernel_size, 143 | padding=padding, groups=1) 144 | 145 | self.conv3 = torch.nn.Conv1d(dim, dim, kernel_size=kernel_size, 146 | padding=padding, groups=1) 147 | self.gn1 = nn.GroupNorm(num_groups=out_features, num_channels=out_features) 148 | # self.gn2 = nn.GroupNorm(num_groups=out_features, num_channels=out_features) 149 | # self.conv3 = torch.nn.Conv1d(2*in_features, out_features, kernel_size=kernel_size, 150 | # padding=padding, groups=1) 151 | 152 | def forward(self, x): 153 | res = x 154 | B, N, C = x.shape 155 | x = x.permute(0, 2, 1) 156 | x = self.conv1(x) 157 | x = self.act(x) 158 | x = self.gn(x) 159 | x = self.conv2(x) 160 | x = x.reshape(B, C, N).permute(0, 2, 1) 161 | x += res 162 | return x 163 | 164 | 165 | 166 | class Block(nn.Module): 167 | 168 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 169 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, dim_conv=1): 170 | super().__init__() 171 | self.norm1 = norm_layer(dim) 172 | self.attn = Attention( 173 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 174 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 175 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 176 | self.norm2 = norm_layer(dim) 177 | mlp_hidden_dim = int(dim * mlp_ratio) 178 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 179 | # self.causal = TemporalModelOptimized1f(17, dim, 17, 1) 180 | if dim_conv == 81: 181 | self.local = SpatialCGNL(dim_conv, int(dim_conv), use_scale=False, groups=3) 182 | else: 183 | self.local_mp = LPI(in_features=dim, act_layer=act_layer, dim=dim_conv) 184 | 185 | 186 | self.norm3 = norm_layer(dim) 187 | self.norm4 = norm_layer(81) 188 | 189 | eta=1e-5 190 | self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) 191 | self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) 192 | self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) 193 | self.gamma4 = nn.Parameter(eta * torch.ones(81), requires_grad=True) 194 | 195 | def forward(self, x): 196 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) 197 | 198 | if x.shape[2] == 544: 199 | # x = x.transpose(-2,-1) 200 | x = x + self.drop_path(self.gamma3 * self.local(self.norm3(x))) 201 | # x = x.transpose(-2,-1) 202 | else: 203 | x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x))) 204 | 205 | x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) 206 | return x 207 | 208 | 209 | 210 | # def forward(self, x): 211 | # x = x + self.drop_path(self.attn(self.norm1(x))) 212 | # x = x + self.drop_path(self.mlp(self.norm2(x))) 213 | # return x 214 | 215 | 216 | class SE(nn.Module): 217 | def __init__(self, dim, hidden_ratio=None): 218 | super().__init__() 219 | hidden_ratio = hidden_ratio or 1 220 | self.dim = dim 221 | hidden_dim = int(dim * hidden_ratio) 222 | self.fc = nn.Sequential( 223 | nn.LayerNorm(dim), 224 | nn.Linear(dim, hidden_dim), 225 | nn.ReLU(inplace=True), 226 | nn.Linear(hidden_dim, dim), 227 | nn.Tanh() 228 | ) 229 | 230 | def forward(self, x): 231 | a = x.mean(dim=1, keepdim=True) # B, 1, C 232 | a = self.fc(a) 233 | x = a * x 234 | return x 235 | 236 | 237 | class Attention(nn.Module): 238 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 239 | super().__init__() 240 | self.num_heads = num_heads 241 | head_dim = dim // num_heads 242 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 243 | self.scale = qk_scale or head_dim ** -0.5 244 | 245 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 246 | self.attn_drop = nn.Dropout(attn_drop) 247 | self.proj = nn.Linear(dim, dim) 248 | self.proj_drop = nn.Dropout(proj_drop) 249 | 250 | def forward(self, x): 251 | B, N, C = x.shape 252 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 253 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 254 | 255 | attn = (q @ k.transpose(-2, -1)) * self.scale 256 | attn = attn.softmax(dim=-1) 257 | attn = self.attn_drop(attn) 258 | 259 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 260 | x = self.proj(x) 261 | x = self.proj_drop(x) 262 | return x 263 | 264 | attns = [] 265 | class PoseTransformer(nn.Module): 266 | def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, 267 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, 268 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): 269 | """ ##########hybrid_backbone=None, representation_size=None, 270 | Args: 271 | num_frame (int, tuple): input frame number 272 | num_joints (int, 
tuple): joints number 273 | in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) 274 | embed_dim_ratio (int): embedding dimension ratio 275 | depth (int): depth of transformer 276 | num_heads (int): number of attention heads 277 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 278 | qkv_bias (bool): enable bias for qkv if True 279 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 280 | drop_rate (float): dropout rate 281 | attn_drop_rate (float): attention dropout rate 282 | drop_path_rate (float): stochastic depth rate 283 | norm_layer: (nn.Module): normalization layer 284 | """ 285 | super().__init__() 286 | 287 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 288 | embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio 289 | out_dim = num_joints * 3 #### output dimension is num_joints * 3 290 | 291 | ### spatial patch embedding 292 | self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) 293 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) 294 | 295 | self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) 296 | self.Spatial_patch_to_embedding1 = nn.Linear(in_chans, embed_dim_ratio//2) 297 | 298 | self.temporal_patch_to_embedding1 = nn.Linear(544, 544//4) 299 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) 300 | 301 | self.Temporal_pos_embed1 = nn.Parameter(torch.zeros(1, 544)) 302 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, 544)) 303 | 304 | self.chennels_pos_embed = nn.Parameter(torch.zeros(32, 17)) 305 | 306 | 307 | self.top_pos_embed = nn.Parameter(torch.zeros(1, 544)) 308 | self.pos_drop = nn.Dropout(p=drop_rate) 309 | 310 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 311 | 312 | 313 | outer_dim=32 314 | inner_dim=16 315 | # depth=12 316 | outer_num_heads=8 317 | inner_num_heads=4 318 | mlp_ratio=4. 319 | qkv_bias=False 320 | qk_scale=None 321 | drop_rate=0. 322 | attn_drop_rate=0. 323 | drop_path_rate=0. 
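        # Dimension bookkeeping for the constructor defaults (a sketch, assuming num_joints=17 and
        # embed_dim_ratio=32): each frame is embedded as 17 joint tokens of width 32, and flattening
        # them per frame gives the temporal token width 17 * 32 = 544 hard-coded for the temporal
        # stream (Temporal_pos_embed, blocks1, Temporal_norm1). dim_conv=81 in blocks1 appears to
        # correspond to an 81-frame input sequence and selects the SpatialCGNL branch in
        # Block.forward (triggered when the token width is 544), while the head maps the pooled
        # 544-d token to out_dim = 17 * 3 = 51 values, i.e. one 3D pose for the centre frame.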
324 | norm_layer=nn.LayerNorm 325 | inner_stride=4 326 | se=1 327 | 328 | self.proj_norm1 = norm_layer(inner_dim) 329 | self.proj = nn.Linear(inner_dim, outer_dim) 330 | self.proj_norm2 = norm_layer(outer_dim) 331 | 332 | self.proj_norm12 = norm_layer(136) 333 | self.proj12 = nn.Linear(136, 544) 334 | self.proj_norm13 = norm_layer(544) 335 | 336 | self.outer_tokens = nn.Parameter(torch.zeros(1, 17, outer_dim), requires_grad=False) 337 | self.outer_pos = nn.Parameter(torch.zeros(1, 17, outer_dim)) 338 | 339 | self.outer_temp = nn.Parameter(torch.zeros(1, 1, 544)) 340 | self.inner_pos = nn.Parameter(torch.zeros(1, 17, inner_dim)) 341 | 342 | self.inner_temp = nn.Parameter(torch.zeros(1, 1, 136)) 343 | self.pos_drop = nn.Dropout(p=drop_rate) 344 | 345 | self.Spatial_blocks = nn.ModuleList([ 346 | Block( 347 | dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 348 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, dim_conv=17) 349 | for i in range(depth)]) 350 | 351 | 352 | self.blocks1 = nn.ModuleList([ 353 | Block( 354 | dim=544, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 355 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, dim_conv=81) 356 | for i in range(depth)]) 357 | 358 | self.lins = nn.Linear(2,32) 359 | self.Spatial_norm = norm_layer(embed_dim_ratio) 360 | self.Temporal_norm = norm_layer(544) 361 | self.channels_norm = norm_layer(17) 362 | 363 | self.Temporal = norm_layer(352) 364 | self.channels = norm_layer(11) 365 | 366 | self.Temporal_norm1 = norm_layer(544) 367 | ####### A easy way to implement weighted mean 368 | self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) 369 | self.weighted_mean1 = torch.nn.Conv1d(in_channels=17, out_channels=17, kernel_size=1) 370 | self.weighted_mean2 = torch.nn.Conv1d(in_channels=11, out_channels=11, kernel_size=1) 371 | 372 | 373 | 374 | 375 | self.spatial_embed = nn.Parameter(torch.zeros(1, 11, embed_dim_ratio)) 376 | 377 | self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) 378 | self.Spatial_patch_to_embedding1 = nn.Linear(in_chans, embed_dim_ratio//2) 379 | 380 | self.temporal_patch_to_embedding1 = nn.Linear(544, 544//4) 381 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) 382 | 383 | self.Temporal_pos_embed1 = nn.Parameter(torch.zeros(1, 544)) 384 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, 544)) 385 | self.Temporal_embed = nn.Parameter(torch.zeros(1, 352)) 386 | 387 | 388 | self.chennels_pos_embed = nn.Parameter(torch.zeros(32, 17)) 389 | 390 | self.head = nn.Sequential( 391 | nn.LayerNorm(embed_dim), 392 | nn.Linear(embed_dim , out_dim), 393 | ) 394 | 395 | 396 | 397 | def forward_features1(self, x): 398 | b = x.shape[0] 399 | x += self.Temporal_pos_embed1 400 | x = self.pos_drop(x) 401 | attn = None 402 | for blk in self.blocks1: 403 | x = blk(x) 404 | attns.append(attn) 405 | 406 | x = self.Temporal_norm1(x) 407 | ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame 408 | x = self.weighted_mean(x) 409 | x = x.view(b, 1, -1) 410 | return x 411 | 412 | 413 | def Spatial_forward_features(self, x): 414 | b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints 415 | x = rearrange(x, 'b c f p -> (b f) p c', ) 416 | 417 | x = self.Spatial_patch_to_embedding(x) 418 | 419 | x += 
self.Spatial_pos_embed 420 | x = self.pos_drop(x) 421 | for blk in self.Spatial_blocks: 422 | x = blk(x) 423 | 424 | x = self.Spatial_norm(x) 425 | x = rearrange(x, '(b f) w c -> b f (w c)', f=f) 426 | return x 427 | 428 | def forward_features(self, x): 429 | b = x.shape[0] 430 | x += self.Spatial_pos_embed 431 | x = self.pos_drop(x) 432 | for blk in self.Spatial_blocks: 433 | x = blk(x) 434 | 435 | x = self.Spatial_norm(x) 436 | ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame 437 | x = self.weighted_mean1(x) 438 | x = x.view(b, 1, -1) 439 | return x 440 | 441 | 442 | def forward(self, x): 443 | attns.clear() 444 | x1 = x 445 | x = x.permute(0, 3, 1, 2) 446 | b, _, _, p = x.shape 447 | 448 | 449 | x[:,:,10:16]=0 450 | 451 | #b,17, 32 452 | x = self.Spatial_forward_features(x) 453 | 454 | #b, 1, 544 455 | x2=x 456 | x = self.forward_features1(x) 457 | 458 | 459 | x = self.head(x) 460 | 461 | x = x.view(b, 1, p, -1) 462 | return x 463 | 464 | 465 | 466 | -------------------------------------------------------------------------------- /run_crossformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | from common.arguments import parse_args 11 | import torch 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import os 17 | import sys 18 | import errno 19 | import math 20 | 21 | from einops import rearrange, repeat 22 | from copy import deepcopy 23 | 24 | from common.camera import * 25 | import collections 26 | 27 | from common.model_crossformer import * 28 | 29 | from common.loss import * 30 | from common.generators import ChunkedGenerator, UnchunkedGenerator 31 | from time import time 32 | from common.utils import * 33 | 34 | 35 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 36 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 37 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 38 | # print(torch.cuda.device_count()) 39 | 40 | 41 | ################### 42 | args = parse_args() 43 | # print(args) 44 | 45 | try: 46 | # Create checkpoint directory if it does not exist 47 | os.makedirs(args.checkpoint) 48 | except OSError as e: 49 | if e.errno != errno.EEXIST: 50 | raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint) 51 | 52 | print('Loading dataset...') 53 | dataset_path = 'data/data_3d_' + args.dataset + '.npz' 54 | if args.dataset == 'h36m': 55 | from common.h36m_dataset import Human36mDataset 56 | dataset = Human36mDataset(dataset_path) 57 | elif args.dataset.startswith('humaneva'): 58 | from common.humaneva_dataset import HumanEvaDataset 59 | dataset = HumanEvaDataset(dataset_path) 60 | elif args.dataset.startswith('custom'): 61 | from common.custom_dataset import CustomDataset 62 | dataset = CustomDataset('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz') 63 | else: 64 | raise KeyError('Invalid dataset') 65 | 66 | print('Preparing data...') 67 | for subject in dataset.subjects(): 68 | for action in dataset[subject].keys(): 69 | anim = dataset[subject][action] 70 | if 'positions' in anim: 71 | positions_3d = [] 72 | for cam in anim['cameras']: 73 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 74 | pos_3d[:, 1:] -= 
pos_3d[:, :1] # Remove global offset, but keep trajectory in first position 75 | positions_3d.append(pos_3d) 76 | anim['positions_3d'] = positions_3d 77 | 78 | print('Loading 2D detections...') 79 | keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True) 80 | keypoints_metadata = keypoints['metadata'].item() 81 | keypoints_symmetry = keypoints_metadata['keypoints_symmetry'] 82 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1]) 83 | 84 | joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right()) 85 | # print(kps_left, joints_right) 86 | keypoints = keypoints['positions_2d'].item() 87 | 88 | ################### 89 | for subject in dataset.subjects(): 90 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject) 91 | for action in dataset[subject].keys(): 92 | assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject) 93 | if 'positions_3d' not in dataset[subject][action]: 94 | continue 95 | 96 | for cam_idx in range(len(keypoints[subject][action])): 97 | 98 | # We check for >= instead of == because some videos in H3.6M contain extra frames 99 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0] 100 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length 101 | 102 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length: 103 | # Shorten sequence 104 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length] 105 | 106 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d']) 107 | 108 | for subject in keypoints.keys(): 109 | for action in keypoints[subject]: 110 | for cam_idx, kps in enumerate(keypoints[subject][action]): 111 | # Normalize camera frame 112 | cam = dataset.cameras()[subject][cam_idx] 113 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h']) 114 | keypoints[subject][action][cam_idx] = kps 115 | 116 | subjects_train = args.subjects_train.split(',') 117 | subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(',') 118 | if not args.render: 119 | subjects_test = args.subjects_test.split(',') 120 | else: 121 | subjects_test = [args.viz_subject] 122 | 123 | 124 | def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True): 125 | out_poses_3d = [] 126 | out_poses_2d = [] 127 | out_camera_params = [] 128 | for subject in subjects: 129 | for action in keypoints[subject].keys(): 130 | if action_filter is not None: 131 | found = False 132 | for a in action_filter: 133 | if action.startswith(a): 134 | found = True 135 | break 136 | if not found: 137 | continue 138 | 139 | poses_2d = keypoints[subject][action] 140 | for i in range(len(poses_2d)): # Iterate across cameras 141 | out_poses_2d.append(poses_2d[i]) 142 | 143 | if subject in dataset.cameras(): 144 | cams = dataset.cameras()[subject] 145 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 146 | for cam in cams: 147 | if 'intrinsic' in cam: 148 | out_camera_params.append(cam['intrinsic']) 149 | 150 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]: 151 | poses_3d = dataset[subject][action]['positions_3d'] 152 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 153 | for i in range(len(poses_3d)): # Iterate across cameras 154 | out_poses_3d.append(poses_3d[i]) 155 | 156 | if 
len(out_camera_params) == 0: 157 | out_camera_params = None 158 | if len(out_poses_3d) == 0: 159 | out_poses_3d = None 160 | 161 | stride = args.downsample 162 | if subset < 1: 163 | for i in range(len(out_poses_2d)): 164 | n_frames = int(round(len(out_poses_2d[i])//stride * subset)*stride) 165 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i]))) 166 | out_poses_2d[i] = out_poses_2d[i][start:start+n_frames:stride] 167 | if out_poses_3d is not None: 168 | out_poses_3d[i] = out_poses_3d[i][start:start+n_frames:stride] 169 | elif stride > 1: 170 | # Downsample as requested 171 | for i in range(len(out_poses_2d)): 172 | out_poses_2d[i] = out_poses_2d[i][::stride] 173 | if out_poses_3d is not None: 174 | out_poses_3d[i] = out_poses_3d[i][::stride] 175 | 176 | 177 | return out_camera_params, out_poses_3d, out_poses_2d 178 | 179 | action_filter = None if args.actions == '*' else args.actions.split(',') 180 | if action_filter is not None: 181 | print('Selected actions:', action_filter) 182 | 183 | cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter) 184 | 185 | 186 | receptive_field = args.number_of_frames 187 | print('INFO: Receptive field: {} frames'.format(receptive_field)) 188 | pad = (receptive_field -1) // 2 # Padding on each side 189 | min_loss = 100000 190 | width = cam['res_w'] 191 | height = cam['res_h'] 192 | num_joints = keypoints_metadata['num_joints'] 193 | #########################################PoseTransformer 194 | 195 | model_pos_train = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4, 196 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0.1) 197 | 198 | model_pos = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4, 199 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) 200 | 201 | ################ load weight ######################## 202 | # posetrans_checkpoint = torch.load('./checkpoint/best_epoch_gt_28.5.bin', map_location=lambda storage, loc: storage) 203 | # posetrans_checkpoint = posetrans_checkpoint["model_pos"] 204 | # model_pos_train = load_pretrained_weights(model_pos_train, posetrans_checkpoint) 205 | 206 | ################# 207 | causal_shift = 0 208 | model_params = 0 209 | for parameter in model_pos.parameters(): 210 | model_params += parameter.numel() 211 | print('INFO: Trainable parameter count:', model_params) 212 | 213 | if torch.cuda.is_available(): 214 | model_pos = nn.DataParallel(model_pos) 215 | model_pos = model_pos.cuda() 216 | model_pos_train = nn.DataParallel(model_pos_train) 217 | model_pos_train = model_pos_train.cuda() 218 | 219 | 220 | if args.resume or args.evaluate: 221 | chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate) 222 | print('Loading checkpoint', chk_filename) 223 | checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) 224 | model_pos_train.load_state_dict(checkpoint['model_pos'], strict=False) 225 | model_pos.load_state_dict(checkpoint['model_pos'], strict=False) 226 | 227 | 228 | test_generator = UnchunkedGenerator(cameras_valid, poses_valid, poses_valid_2d, 229 | pad=pad, causal_shift=causal_shift, augment=False, 230 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 231 | print('INFO: Testing on {} frames'.format(test_generator.num_frames())) 232 | 233 | def eval_data_prepare(receptive_field, 
inputs_2d, inputs_3d): 234 | inputs_2d_p = torch.squeeze(inputs_2d) 235 | inputs_3d_p = inputs_3d.permute(1,0,2,3) 236 | out_num = inputs_2d_p.shape[0] - receptive_field + 1 237 | eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2]) 238 | for i in range(out_num): 239 | eval_input_2d[i,:,:,:] = inputs_2d_p[i:i+receptive_field, :, :] 240 | return eval_input_2d, inputs_3d_p 241 | 242 | 243 | ################### 244 | 245 | if not args.evaluate: 246 | cameras_train, poses_train, poses_train_2d = fetch(subjects_train, action_filter, subset=args.subset) 247 | 248 | lr = args.learning_rate 249 | optimizer = optim.AdamW(model_pos_train.parameters(), lr=lr, weight_decay=0.1) 250 | # optimizer = optim.SGD(model_pos_train.parameters(), lr = lr, weight_decay=0.01) 251 | 252 | lr_decay = args.lr_decay 253 | losses_3d_train = [] 254 | losses_3d_train_eval = [] 255 | losses_3d_valid = [] 256 | 257 | epoch = 0 258 | initial_momentum = 0.1 259 | final_momentum = 0.001 260 | 261 | train_generator = ChunkedGenerator(args.batch_size//args.stride, cameras_train, poses_train, poses_train_2d, args.stride, 262 | pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation, 263 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 264 | train_generator_eval = UnchunkedGenerator(cameras_train, poses_train, poses_train_2d, 265 | pad=pad, causal_shift=causal_shift, augment=False) 266 | print('INFO: Training on {} frames'.format(train_generator_eval.num_frames())) 267 | 268 | 269 | if args.resume: 270 | epoch = checkpoint['epoch'] 271 | if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: 272 | optimizer.load_state_dict(checkpoint['optimizer']) 273 | train_generator.set_random_state(checkpoint['random_state']) 274 | else: 275 | print('WARNING: this checkpoint does not contain an optimizer state. 
The optimizer will be reinitialized.') 276 | 277 | lr = checkpoint['lr'] 278 | 279 | 280 | print('** Note: reported losses are averaged over all frames.') 281 | print('** The final evaluation will be carried out after the last training epoch.') 282 | 283 | # Pos model only 284 | while epoch < args.epochs: 285 | start_time = time() 286 | epoch_loss_3d_train = 0 287 | epoch_loss_traj_train = 0 288 | epoch_loss_2d_train_unlabeled = 0 289 | N = 0 290 | N_semi = 0 291 | model_pos_train.train() 292 | 293 | for cameras_train, batch_3d, batch_2d in train_generator.next_epoch(): 294 | cameras_train = torch.from_numpy(cameras_train.astype('float32')) 295 | inputs_3d = torch.from_numpy(batch_3d.astype('float32')) 296 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 297 | 298 | if torch.cuda.is_available(): 299 | inputs_3d = inputs_3d.cuda() 300 | inputs_2d = inputs_2d.cuda() 301 | cameras_train = cameras_train.cuda() 302 | inputs_traj = inputs_3d[:, :, :1].clone() 303 | inputs_3d[:, :, 0] = 0 304 | 305 | optimizer.zero_grad() 306 | 307 | # Predict 3D poses 308 | predicted_3d_pos = model_pos_train(inputs_2d) 309 | 310 | del inputs_2d 311 | torch.cuda.empty_cache() 312 | 313 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 314 | epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 315 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 316 | 317 | loss_total = loss_3d_pos 318 | 319 | loss_total.backward() 320 | 321 | optimizer.step() 322 | del inputs_3d, loss_3d_pos, predicted_3d_pos 323 | torch.cuda.empty_cache() 324 | 325 | losses_3d_train.append(epoch_loss_3d_train / N) 326 | torch.cuda.empty_cache() 327 | 328 | # End-of-epoch evaluation 329 | with torch.no_grad(): 330 | model_pos.load_state_dict(model_pos_train.state_dict(), strict=False) 331 | model_pos.eval() 332 | 333 | epoch_loss_3d_valid = 0 334 | epoch_loss_traj_valid = 0 335 | epoch_loss_2d_valid = 0 336 | N = 0 337 | if not args.no_eval: 338 | # Evaluate on test set 339 | for cam, batch, batch_2d in test_generator.next_epoch(): 340 | inputs_3d = torch.from_numpy(batch.astype('float32')) 341 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 342 | 343 | ##### apply test-time-augmentation (following Videopose3d) 344 | inputs_2d_flip = inputs_2d.clone() 345 | inputs_2d_flip[:, :, :, 0] *= -1 346 | inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :] 347 | 348 | ##### convert size 349 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 350 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d) 351 | 352 | if torch.cuda.is_available(): 353 | inputs_2d = inputs_2d.cuda() 354 | inputs_2d_flip = inputs_2d_flip.cuda() 355 | inputs_3d = inputs_3d.cuda() 356 | inputs_3d[:, :, 0] = 0 357 | 358 | predicted_3d_pos = model_pos(inputs_2d) 359 | predicted_3d_pos_flip = model_pos(inputs_2d_flip) 360 | predicted_3d_pos_flip[:, :, :, 0] *= -1 361 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :, 362 | joints_right + joints_left] 363 | 364 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1, 365 | keepdim=True) 366 | 367 | del inputs_2d, inputs_2d_flip 368 | torch.cuda.empty_cache() 369 | 370 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 371 | epoch_loss_3d_valid += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 372 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 373 | 374 | del inputs_3d, loss_3d_pos, 
predicted_3d_pos 375 | torch.cuda.empty_cache() 376 | 377 | losses_3d_valid.append(epoch_loss_3d_valid / N) 378 | 379 | # Evaluate on training set, this time in evaluation mode 380 | epoch_loss_3d_train_eval = 0 381 | epoch_loss_traj_train_eval = 0 382 | epoch_loss_2d_train_labeled_eval = 0 383 | N = 0 384 | for cam, batch, batch_2d in train_generator_eval.next_epoch(): 385 | if batch_2d.shape[1] == 0: 386 | # This can only happen when downsampling the dataset 387 | continue 388 | 389 | inputs_3d = torch.from_numpy(batch.astype('float32')) 390 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 391 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 392 | 393 | if torch.cuda.is_available(): 394 | inputs_3d = inputs_3d.cuda() 395 | inputs_2d = inputs_2d.cuda() 396 | 397 | inputs_3d[:, :, 0] = 0 398 | 399 | # Compute 3D poses 400 | predicted_3d_pos = model_pos(inputs_2d) 401 | 402 | del inputs_2d 403 | torch.cuda.empty_cache() 404 | 405 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 406 | epoch_loss_3d_train_eval += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 407 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 408 | 409 | del inputs_3d, loss_3d_pos, predicted_3d_pos 410 | torch.cuda.empty_cache() 411 | 412 | losses_3d_train_eval.append(epoch_loss_3d_train_eval / N) 413 | 414 | # Evaluate 2D loss on unlabeled training set (in evaluation mode) 415 | epoch_loss_2d_train_unlabeled_eval = 0 416 | N_semi = 0 417 | 418 | elapsed = (time() - start_time) / 60 419 | 420 | if args.no_eval: 421 | print('[%d] time %.2f lr %f 3d_train %f' % ( 422 | epoch + 1, 423 | elapsed, 424 | lr, 425 | losses_3d_train[-1] * 1000)) 426 | else: 427 | 428 | print('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % ( 429 | epoch + 1, 430 | elapsed, 431 | lr, 432 | losses_3d_train[-1] * 1000, 433 | losses_3d_train_eval[-1] * 1000, 434 | losses_3d_valid[-1] * 1000)) 435 | 436 | # Decay learning rate exponentially 437 | lr *= lr_decay 438 | for param_group in optimizer.param_groups: 439 | param_group['lr'] *= lr_decay 440 | epoch += 1 441 | 442 | # Decay BatchNorm momentum 443 | # momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum)) 444 | # model_pos_train.set_bn_momentum(momentum) 445 | 446 | # Save checkpoint if necessary 447 | if epoch % args.checkpoint_frequency == 0: 448 | chk_path = os.path.join(args.checkpoint, 'epoch_{}.bin'.format(epoch)) 449 | print('Saving checkpoint to', chk_path) 450 | 451 | torch.save({ 452 | 'epoch': epoch, 453 | 'lr': lr, 454 | 'random_state': train_generator.random_state(), 455 | 'optimizer': optimizer.state_dict(), 456 | 'model_pos': model_pos_train.state_dict(), 457 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None, 458 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None, 459 | }, chk_path) 460 | 461 | #### save best checkpoint 462 | best_chk_path = os.path.join(args.checkpoint, 'best_epoch.bin') 463 | if losses_3d_valid and losses_3d_valid[-1] * 1000 < min_loss: # losses_3d_valid stays empty when args.no_eval is set 464 | min_loss = losses_3d_valid[-1] * 1000 465 | print("save best checkpoint") 466 | torch.save({ 467 | 'epoch': epoch, 468 | 'lr': lr, 469 | 'random_state': train_generator.random_state(), 470 | 'optimizer': optimizer.state_dict(), 471 | 'model_pos': model_pos_train.state_dict(), 472 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None, 473 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None, 474 | }, 
best_chk_path) 475 | 476 | # Save training curves after every epoch, as .png images (if requested) 477 | if args.export_training_curves and epoch > 3: 478 | if 'matplotlib' not in sys.modules: 479 | import matplotlib 480 | 481 | matplotlib.use('Agg') 482 | import matplotlib.pyplot as plt 483 | 484 | plt.figure() 485 | epoch_x = np.arange(3, len(losses_3d_train)) + 1 486 | plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0') 487 | plt.plot(epoch_x, losses_3d_train_eval[3:], color='C0') 488 | plt.plot(epoch_x, losses_3d_valid[3:], color='C1') 489 | plt.legend(['3d train', '3d train (eval)', '3d valid (eval)']) 490 | plt.ylabel('MPJPE (m)') 491 | plt.xlabel('Epoch') 492 | plt.xlim((3, epoch)) 493 | plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png')) 494 | 495 | plt.close('all') 496 | 497 | 498 | # Evaluate 499 | def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False): 500 | epoch_loss_3d_pos = 0 501 | epoch_loss_3d_pos_procrustes = 0 502 | epoch_loss_3d_pos_scale = 0 503 | epoch_loss_3d_vel = 0 504 | with torch.no_grad(): 505 | if not use_trajectory_model: 506 | model_pos.eval() 507 | # else: 508 | # model_traj.eval() 509 | N = 0 510 | for _, batch, batch_2d in test_generator.next_epoch(): 511 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 512 | inputs_3d = torch.from_numpy(batch.astype('float32')) 513 | 514 | 515 | ##### apply test-time-augmentation (following Videopose3d) 516 | inputs_2d_flip = inputs_2d.clone() 517 | inputs_2d_flip [:, :, :, 0] *= -1 518 | inputs_2d_flip[:, :, kps_left + kps_right,:] = inputs_2d_flip[:, :, kps_right + kps_left,:] 519 | 520 | ##### convert size 521 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 522 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d) 523 | 524 | if torch.cuda.is_available(): 525 | inputs_2d = inputs_2d.cuda() 526 | inputs_2d_flip = inputs_2d_flip.cuda() 527 | inputs_3d = inputs_3d.cuda() 528 | inputs_3d[:, :, 0] = 0 529 | 530 | predicted_3d_pos = model_pos(inputs_2d) 531 | predicted_3d_pos_flip = model_pos(inputs_2d_flip) 532 | predicted_3d_pos_flip[:, :, :, 0] *= -1 533 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :, 534 | joints_right + joints_left] 535 | 536 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1, 537 | keepdim=True) 538 | 539 | del inputs_2d, inputs_2d_flip 540 | torch.cuda.empty_cache() 541 | 542 | if return_predictions: 543 | return predicted_3d_pos.squeeze(0).cpu().numpy() 544 | 545 | 546 | error = mpjpe(predicted_3d_pos, inputs_3d) 547 | epoch_loss_3d_pos_scale += inputs_3d.shape[0]*inputs_3d.shape[1] * n_mpjpe(predicted_3d_pos, inputs_3d).item() 548 | 549 | epoch_loss_3d_pos += inputs_3d.shape[0]*inputs_3d.shape[1] * error.item() 550 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 551 | 552 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 553 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 554 | 555 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0]*inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, inputs) 556 | 557 | # Compute velocity error 558 | epoch_loss_3d_vel += inputs_3d.shape[0]*inputs_3d.shape[1] * mean_velocity_error(predicted_3d_pos, inputs) 559 | 560 | if action is None: 561 | print('----------') 562 | else: 563 | print('----'+action+'----') 564 | e1 = (epoch_loss_3d_pos / N)*1000 565 | e2 = 
(epoch_loss_3d_pos_procrustes / N)*1000 566 | e3 = (epoch_loss_3d_pos_scale / N)*1000 567 | ev = (epoch_loss_3d_vel / N)*1000 568 | print('Protocol #1 Error (MPJPE):', e1, 'mm') 569 | print('Protocol #2 Error (P-MPJPE):', e2, 'mm') 570 | print('Protocol #3 Error (N-MPJPE):', e3, 'mm') 571 | print('Velocity Error (MPJVE):', ev, 'mm') 572 | print('----------') 573 | 574 | return e1, e2, e3, ev 575 | 576 | if args.render: 577 | print('Rendering...') 578 | 579 | input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy() 580 | ground_truth = None 581 | if args.viz_subject in dataset.subjects() and args.viz_action in dataset[args.viz_subject]: 582 | if 'positions_3d' in dataset[args.viz_subject][args.viz_action]: 583 | ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy() 584 | if ground_truth is None: 585 | print('INFO: this action is unlabeled. Ground truth will not be rendered.') 586 | 587 | gen = UnchunkedGenerator(None, [ground_truth], [input_keypoints], 588 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 589 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 590 | prediction = evaluate(gen, return_predictions=True) 591 | # if model_traj is not None and ground_truth is None: 592 | # prediction_traj = evaluate(gen, return_predictions=True, use_trajectory_model=True) 593 | # prediction += prediction_traj 594 | 595 | if args.viz_export is not None: 596 | print('Exporting joint positions to', args.viz_export) 597 | # Predictions are in camera space 598 | np.save(args.viz_export, prediction) 599 | 600 | gen = UnchunkedGenerator(None, [ground_truth], [input_keypoints], 601 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 602 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 603 | prediction = evaluate(gen, return_predictions=True) 604 | prediction = prediction[:,0,:,:] 605 | 606 | if args.viz_output is not None: 607 | if ground_truth is not None: 608 | # Reapply trajectory 609 | trajectory = ground_truth[:, :1] 610 | ground_truth[:, 1:] += trajectory 611 | prediction += trajectory 612 | 613 | # Invert camera transformation 614 | cam = dataset.cameras()[args.viz_subject][args.viz_camera] 615 | if ground_truth is not None: 616 | prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation']) 617 | ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation']) 618 | else: 619 | # If the ground truth is not available, take the camera extrinsic params from a random subject. 620 | # They are almost the same, and anyway, we only need this for visualization purposes. 
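# Note: in VideoPose3D-style code such as common/camera.py here, camera_to_world applies the
# camera-to-world rigid transform, roughly world = qrot(R, X) + t, with R the camera orientation
# quaternion and t its translation. With t=0 in the fallback below, the pose is only rotated into
# world axes (its global position is unknown), which is why the height is re-based to the minimum z afterwards.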
621 | for subject in dataset.cameras(): 622 | if 'orientation' in dataset.cameras()[subject][args.viz_camera]: 623 | rot = dataset.cameras()[subject][args.viz_camera]['orientation'] 624 | break 625 | prediction = camera_to_world(prediction, R=rot, t=0) 626 | # We don't have the trajectory, but at least we can rebase the height 627 | prediction[:, :, 2] -= np.min(prediction[:, :, 2]) 628 | 629 | anim_output = {'Reconstruction': prediction} 630 | if ground_truth is not None and not args.viz_no_ground_truth: 631 | anim_output['Ground truth'] = ground_truth 632 | 633 | input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h']) 634 | 635 | from common.visualization import render_animation 636 | 637 | render_animation(input_keypoints, keypoints_metadata, anim_output, 638 | dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output, 639 | limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size, 640 | input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']), 641 | input_video_skip=args.viz_skip) 642 | 643 | else: 644 | print('Evaluating...') 645 | all_actions = {} 646 | all_actions_by_subject = {} 647 | for subject in subjects_test: 648 | if subject not in all_actions_by_subject: 649 | all_actions_by_subject[subject] = {} 650 | 651 | for action in dataset[subject].keys(): 652 | action_name = action.split(' ')[0] 653 | if action_name not in all_actions: 654 | all_actions[action_name] = [] 655 | if action_name not in all_actions_by_subject[subject]: 656 | all_actions_by_subject[subject][action_name] = [] 657 | all_actions[action_name].append((subject, action)) 658 | all_actions_by_subject[subject][action_name].append((subject, action)) 659 | 660 | 661 | def fetch_actions(actions): 662 | out_poses_3d = [] 663 | out_poses_2d = [] 664 | 665 | for subject, action in actions: 666 | poses_2d = keypoints[subject][action] 667 | for i in range(len(poses_2d)): # Iterate across cameras 668 | out_poses_2d.append(poses_2d[i]) 669 | 670 | poses_3d = dataset[subject][action]['positions_3d'] 671 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 672 | for i in range(len(poses_3d)): # Iterate across cameras 673 | out_poses_3d.append(poses_3d[i]) 674 | 675 | stride = args.downsample 676 | if stride > 1: 677 | # Downsample as requested 678 | for i in range(len(out_poses_2d)): 679 | out_poses_2d[i] = out_poses_2d[i][::stride] 680 | if out_poses_3d is not None: 681 | out_poses_3d[i] = out_poses_3d[i][::stride] 682 | 683 | return out_poses_3d, out_poses_2d 684 | 685 | 686 | def run_evaluation(actions, action_filter=None): 687 | errors_p1 = [] 688 | errors_p2 = [] 689 | errors_p3 = [] 690 | errors_vel = [] 691 | 692 | for action_key in actions.keys(): 693 | if action_filter is not None: 694 | found = False 695 | for a in action_filter: 696 | if action_key.startswith(a): 697 | found = True 698 | break 699 | if not found: 700 | continue 701 | 702 | poses_act, poses_2d_act = fetch_actions(actions[action_key]) 703 | gen = UnchunkedGenerator(None, poses_act, poses_2d_act, 704 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 705 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 706 | joints_right=joints_right) 707 | e1, e2, e3, ev = evaluate(gen, action_key) 708 | errors_p1.append(e1) 709 | errors_p2.append(e2) 710 | errors_p3.append(e3) 711 | errors_vel.append(ev) 712 | 713 | print('Protocol #1 (MPJPE) action-wise average:', round(np.mean(errors_p1), 1), 'mm') 714 | 
print('Protocol #2 (P-MPJPE) action-wise average:', round(np.mean(errors_p2), 1), 'mm') 715 | print('Protocol #3 (N-MPJPE) action-wise average:', round(np.mean(errors_p3), 1), 'mm') 716 | print('Velocity (MPJVE) action-wise average:', round(np.mean(errors_vel), 2), 'mm') 717 | 718 | 719 | if not args.by_subject: 720 | run_evaluation(all_actions, action_filter) 721 | else: 722 | for subject in all_actions_by_subject.keys(): 723 | print('Evaluating on subject', subject) 724 | run_evaluation(all_actions_by_subject[subject], action_filter) 725 | print('') 726 | --------------------------------------------------------------------------------
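For reference, the short sketch below reproduces, in isolation, the two evaluation-time steps used by run_crossformer.py: the sliding-window conversion performed by eval_data_prepare and the horizontal-flip test-time augmentation that evaluate() averages. The window length, joint count, left/right index lists, and the placeholder model are illustrative assumptions, not values from the repository; the script flips the sequence before windowing, which is equivalent to flipping the windows as done here.

import torch

# Assumed sizes for illustration only
receptive_field = 9            # window length (args.number_of_frames in the script)
frames, joints = 20, 17
kps_left, kps_right = [4, 5, 6], [1, 2, 3]        # assumed 2D keypoint symmetry
joints_left, joints_right = [4, 5, 6], [1, 2, 3]  # assumed 3D joint symmetry

seq_2d = torch.randn(1, frames, joints, 2)        # (batch=1, frames, joints, xy)

# Sliding-window conversion, as in eval_data_prepare()
seq = seq_2d.squeeze(0)                           # (frames, joints, 2)
out_num = seq.shape[0] - receptive_field + 1
windows = torch.empty(out_num, receptive_field, joints, 2)
for i in range(out_num):
    windows[i] = seq[i:i + receptive_field]
print(windows.shape)                              # torch.Size([12, 9, 17, 2])

# Placeholder standing in for PoseTransformer: (B, receptive_field, joints, 2) -> (B, 1, joints, 3)
def toy_model(x):
    b, _, j, _ = x.shape
    return torch.randn(b, 1, j, 3)

# Horizontal-flip test-time augmentation, as in evaluate()
windows_flip = windows.clone()
windows_flip[:, :, :, 0] *= -1                                        # mirror the x coordinate
windows_flip[:, :, kps_left + kps_right] = windows_flip[:, :, kps_right + kps_left]

pred = toy_model(windows)
pred_flip = toy_model(windows_flip)
pred_flip[:, :, :, 0] *= -1                                           # undo the mirroring
pred_flip[:, :, joints_left + joints_right] = pred_flip[:, :, joints_right + joints_left]

# Average the original and flipped predictions, as evaluate() does
pred_avg = torch.mean(torch.cat((pred, pred_flip), dim=1), dim=1, keepdim=True)
print(pred_avg.shape)                                                 # torch.Size([12, 1, 17, 3])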