├── README.md
├── checkpoint
│   └── README.md
├── common
│   ├── README.MD
│   ├── arguments.py
│   ├── camera.py
│   ├── custom_dataset.py
│   ├── generators.py
│   ├── h36m_dataset.py
│   ├── humaneva_dataset.py
│   ├── loss.py
│   ├── mocap_dataset.py
│   ├── model_poseformer.py
│   ├── quaternion.py
│   ├── skeleton.py
│   ├── utils.py
│   └── visualization.py
├── data
│   └── README.MD
├── figure
│   ├── H3.6.gif
│   ├── PoseFormer.gif
│   └── wild.gif
├── poseformer.yml
└── run_poseformer.py

/README.md:
--------------------------------------------------------------------------------
1 | # 3D Human Pose Estimation with Spatial and Temporal Transformers
2 | This repo is the official implementation of [3D Human Pose Estimation with Spatial and Temporal Transformers](https://arxiv.org/pdf/2103.10455.pdf). The paper was accepted to [ICCV 2021](http://iccv2021.thecvf.com/home).
3 | 
4 | - **Welcome to check out our NeurIPS 2023 work:** [Context-Aware PoseFormer](https://arxiv.org/pdf/2311.03312.pdf)
5 | - **Welcome to check out our CVPR 2023 work:** [PoseFormerV2](https://github.com/QitaoZhao/PoseFormerV2)
6 | - Visualization code for in-the-wild videos can be found at [PoseFormer_demo](https://github.com/zczcwh/poseformer_demo)
7 | 
8 | [Video Demonstration](https://youtu.be/z8HWOdXjGR8)
9 | 
10 | ## PoseFormer Architecture
11 | ![PoseFormer architecture](./figure/PoseFormer.gif)
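PoseFormer factorizes attention into a spatial transformer that encodes the joints within each frame and a temporal transformer that aggregates the per-frame embeddings to regress the 3D pose of the center frame (see `common/model_poseformer.py`). As a quick orientation to the expected tensor shapes, here is a minimal sketch using the constructor defaults from that file; the released checkpoints may have been trained with different hyperparameters:

```python
import torch
from common.model_poseformer import PoseTransformer

# 81-frame model; the remaining arguments mirror the constructor
# defaults in common/model_poseformer.py.
model = PoseTransformer(num_frame=81, num_joints=17, in_chans=2,
                        embed_dim_ratio=32, depth=4, num_heads=8,
                        mlp_ratio=2., qkv_bias=True)

# Input: a window of 2D keypoints, shaped (batch, frames, joints, xy).
x = torch.randn(2, 81, 17, 2)
with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([2, 1, 17, 3]) -- 3D pose of the center frame only
```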
12 | 
13 | 
14 | ## Video Demo
15 | 
16 | 
17 | | ![3D HPE on Human3.6M](./figure/H3.6.gif) |
18 | |:--:|
19 | | 3D HPE on Human3.6M |
20 | 
21 | | ![3D HPE in the wild](./figure/wild.gif) |
22 | |:--:|
23 | | 3D HPE on videos in-the-wild using PoseFormer |
24 | 
25 | 
26 | 
27 | Our code is built on top of [VideoPose3D](https://github.com/facebookresearch/VideoPose3D).
28 | 
29 | ### Environment
30 | 
31 | The code is developed and tested under the following environment:
32 | 
33 | * Python 3.8.2
34 | * PyTorch 1.7.1
35 | * CUDA 11.0
36 | 
37 | You can create the environment via:
38 | ```bash
39 | conda env create -f poseformer.yml
40 | ```
41 | 
42 | ### Dataset
43 | 
44 | Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (./data directory).
45 | 
46 | ### Evaluating pre-trained models
47 | 
48 | We provide the pre-trained 81-frame model (CPN-detected 2D poses as input) [here](https://drive.google.com/file/d/1oX5H5QpVoFzyD-Qz9aaP3RDWDb1v1sIy/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run:
49 | 
50 | ```bash
51 | python run_poseformer.py -k cpn_ft_h36m_dbb -f 81 -c checkpoint --evaluate detected81f.bin
52 | ```
53 | 
54 | We also provide a pre-trained 81-frame model (ground-truth 2D poses as input) [here](https://drive.google.com/file/d/18wW4TdNYxF-zdt9oInmwQK9hEdRJnXzu/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run:
55 | 
56 | ```bash
57 | python run_poseformer.py -k gt -f 81 -c checkpoint --evaluate gt81f.bin
58 | ```
59 | 
60 | 
61 | ### Training new models
62 | 
63 | * To train a model from scratch (CPN-detected 2D poses as input), run:
64 | 
65 | ```bash
66 | python run_poseformer.py -k cpn_ft_h36m_dbb -f 27 -lr 0.00004 -lrd 0.99
67 | ```
68 | 
69 | `-f` controls how many frames are used as input. The 27-frame model achieves 47.0 mm and the 81-frame model achieves 44.3 mm (MPJPE).
70 | 
71 | * To train a model from scratch (ground-truth 2D poses as input), run:
72 | 
73 | ```bash
74 | python run_poseformer.py -k gt -f 81 -lr 0.0004 -lrd 0.99
75 | ```
76 | 
77 | The 81-frame model achieves 31.3 mm (MPJPE).
78 | 
79 | ### Visualization and other functions
80 | 
81 | We keep our code consistent with [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). Please refer to their project page for further information.
82 | 
83 | ### Bibtex
84 | If you find our work useful in your research, please consider citing:
85 | 
86 |     @inproceedings{zheng20213d,
87 |       title={3D Human Pose Estimation with Spatial and Temporal Transformers},
88 |       author={Zheng, Ce and Zhu, Sijie and Mendieta, Matias and Yang, Taojiannan and Chen, Chen and Ding, Zhengming},
89 |       booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
90 |       year={2021}
91 |     }
92 | 
93 | ## Acknowledgement
94 | 
95 | Part of our code is borrowed from [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). We thank the authors for releasing their code.
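For readers who want to poke at the released checkpoints outside of `run_poseformer.py`: the building blocks used throughout this repo are `PoseTransformer` (`common/model_poseformer.py`), `load_pretrained_weights` (`common/utils.py`), and the MPJPE metric (`common/loss.py`). The sketch below is only illustrative — in particular, the key under which the weights are stored inside the `.bin` checkpoint is an assumption (`run_poseformer.py` handles that detail), and `load_pretrained_weights` skips any layers that do not match in name or size:

```python
import torch
from common.model_poseformer import PoseTransformer
from common.utils import load_pretrained_weights
from common.loss import mpjpe

# 81-frame model, matching the -f 81 evaluation commands above.
model = PoseTransformer(num_frame=81, num_joints=17, in_chans=2,
                        embed_dim_ratio=32, depth=4, num_heads=8,
                        mlp_ratio=2., qkv_bias=True)

checkpoint = torch.load('checkpoint/detected81f.bin', map_location='cpu')
# Assumption: the weights may be nested under a 'model_pos' entry
# (VideoPose3D convention); otherwise fall back to the raw dict.
state = checkpoint.get('model_pos', checkpoint) if isinstance(checkpoint, dict) else checkpoint
load_pretrained_weights(model, state)
model.eval()

# Toy batch just to show the tensor shapes and the metric call;
# real evaluation feeds normalized Human3.6M keypoints via the generators.
inputs_2d = torch.randn(4, 81, 17, 2)   # (batch, frames, joints, xy)
target_3d = torch.randn(4, 1, 17, 3)    # 3D pose of the center frame
with torch.no_grad():
    pred_3d = model(inputs_2d)
print('MPJPE:', mpjpe(pred_3d, target_3d).item())
```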
96 | -------------------------------------------------------------------------------- /checkpoint/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/README.MD: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Training script') 12 | 13 | # General arguments 14 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva 15 | parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use') 16 | parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST', 17 | help='training subjects separated by comma') 18 | parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma') 19 | parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST', 20 | help='unlabeled subjects separated by comma for self-supervision') 21 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST', 22 | help='actions to train/test on, separated by comma, or * for all') 23 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 24 | help='checkpoint directory') 25 | parser.add_argument('--checkpoint-frequency', default=40, type=int, metavar='N', 26 | help='create a checkpoint every N epochs') 27 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', 28 | help='checkpoint to resume (file name)') 29 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') 30 | parser.add_argument('--render', action='store_true', help='visualize a particular video') 31 | parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)') 32 | parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images') 33 | 34 | 35 | # Model arguments 36 | parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training') 37 | parser.add_argument('-e', '--epochs', default=200, type=int, metavar='N', help='number of training epochs') 38 | parser.add_argument('-b', '--batch-size', default=512, type=int, metavar='N', help='batch size in terms of predicted frames') 39 | parser.add_argument('-drop', '--dropout', default=0., type=float, metavar='P', help='dropout probability') 40 | parser.add_argument('-lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate') 41 | parser.add_argument('-lrd', '--lr-decay', default=0.99, type=float, metavar='LR', help='learning rate decay per epoch') 42 | parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false', 43 | help='disable train-time 
flipping') 44 | # parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false', 45 | # help='disable test-time flipping') 46 | # parser.add_argument('-arc', '--architecture', default='3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma') 47 | parser.add_argument('-frame', '--number-of-frames', default='81', type=int, metavar='N', 48 | help='how many frames used as input') 49 | # parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing') 50 | # parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers') 51 | 52 | # Experimental 53 | parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction') 54 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)') 55 | parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision') 56 | parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)') 57 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions') 58 | parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions') 59 | parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection') 60 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term', 61 | help='disable bone length term in semi-supervised settings') 62 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting') 63 | 64 | # Visualization 65 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render') 66 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render') 67 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render') 68 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video') 69 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video') 70 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)') 71 | parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates') 72 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos') 73 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses') 74 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames') 75 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N') 76 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size') 77 | 78 | parser.set_defaults(bone_length_term=True) 79 | parser.set_defaults(data_augmentation=True) 80 | parser.set_defaults(test_time_augmentation=True) 81 | # parser.set_defaults(test_time_augmentation=False) 82 | 83 | args = parser.parse_args() 84 | # Check invalid configuration 85 | if args.resume and args.evaluate: 86 | print('Invalid flags: --resume and --evaluate cannot be set at 
the same time') 87 | exit() 88 | 89 | if args.export_training_curves and args.no_eval: 90 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time') 91 | exit() 92 | 93 | return args -------------------------------------------------------------------------------- /common/camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from common.utils import wrap 12 | from common.quaternion import qrot, qinverse 13 | 14 | def normalize_screen_coordinates(X, w, h): 15 | assert X.shape[-1] == 2 16 | 17 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio 18 | return X/w*2 - [1, h/w] 19 | 20 | 21 | def image_coordinates(X, w, h): 22 | assert X.shape[-1] == 2 23 | 24 | # Reverse camera frame normalization 25 | return (X + [1, h/w])*w/2 26 | 27 | 28 | def world_to_camera(X, R, t): 29 | Rt = wrap(qinverse, R) # Invert rotation 30 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate 31 | 32 | 33 | def camera_to_world(X, R, t): 34 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t 35 | 36 | 37 | def project_to_2d(X, camera_params): 38 | """ 39 | Project 3D points to 2D using the Human3.6M camera projection function. 40 | This is a differentiable and batched reimplementation of the original MATLAB script. 41 | 42 | Arguments: 43 | X -- 3D points in *camera space* to transform (N, *, 3) 44 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 45 | """ 46 | assert X.shape[-1] == 3 47 | assert len(camera_params.shape) == 2 48 | assert camera_params.shape[-1] == 9 49 | assert X.shape[0] == camera_params.shape[0] 50 | 51 | while len(camera_params.shape) < len(X.shape): 52 | camera_params = camera_params.unsqueeze(1) 53 | 54 | f = camera_params[..., :2] 55 | c = camera_params[..., 2:4] 56 | k = camera_params[..., 4:7] 57 | p = camera_params[..., 7:] 58 | 59 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 60 | r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True) 61 | 62 | radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True) 63 | tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True) 64 | 65 | XXX = XX*(radial + tan) + p*r2 66 | 67 | return f*XXX + c 68 | 69 | def project_to_2d_linear(X, camera_params): 70 | """ 71 | Project 3D points to 2D using only linear parameters (focal length and principal point). 72 | 73 | Arguments: 74 | X -- 3D points in *camera space* to transform (N, *, 3) 75 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 76 | """ 77 | assert X.shape[-1] == 3 78 | assert len(camera_params.shape) == 2 79 | assert camera_params.shape[-1] == 9 80 | assert X.shape[0] == camera_params.shape[0] 81 | 82 | while len(camera_params.shape) < len(X.shape): 83 | camera_params = camera_params.unsqueeze(1) 84 | 85 | f = camera_params[..., :2] 86 | c = camera_params[..., 2:4] 87 | 88 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 89 | 90 | return f*XX + c -------------------------------------------------------------------------------- /common/custom_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | from common.h36m_dataset import h36m_skeleton 14 | 15 | 16 | custom_camera_params = { 17 | 'id': None, 18 | 'res_w': None, # Pulled from metadata 19 | 'res_h': None, # Pulled from metadata 20 | 21 | # Dummy camera parameters (taken from Human3.6M), only for visualization purposes 22 | 'azimuth': 70, # Only used for visualization 23 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 24 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 25 | } 26 | 27 | class CustomDataset(MocapDataset): 28 | def __init__(self, detections_path, remove_static_joints=True): 29 | super().__init__(fps=None, skeleton=h36m_skeleton) 30 | 31 | # Load serialized dataset 32 | data = np.load(detections_path, allow_pickle=True) 33 | resolutions = data['metadata'].item()['video_metadata'] 34 | 35 | self._cameras = {} 36 | self._data = {} 37 | for video_name, res in resolutions.items(): 38 | cam = {} 39 | cam.update(custom_camera_params) 40 | cam['orientation'] = np.array(cam['orientation'], dtype='float32') 41 | cam['translation'] = np.array(cam['translation'], dtype='float32') 42 | cam['translation'] = cam['translation']/1000 # mm to meters 43 | 44 | cam['id'] = video_name 45 | cam['res_w'] = res['w'] 46 | cam['res_h'] = res['h'] 47 | 48 | self._cameras[video_name] = [cam] 49 | 50 | self._data[video_name] = { 51 | 'custom': { 52 | 'cameras': cam 53 | } 54 | } 55 | 56 | if remove_static_joints: 57 | # Bring the skeleton to 17 joints instead of the original 32 58 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 59 | 60 | # Rewire shoulders to the correct parents 61 | self._skeleton._parents[11] = 8 62 | self._skeleton._parents[14] = 8 63 | 64 | def supports_semi_supervised(self): 65 | return False 66 | -------------------------------------------------------------------------------- /common/generators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from itertools import zip_longest 9 | import numpy as np 10 | 11 | 12 | # def getbone(seq, boneindex): 13 | # bs = np.shape(seq)[0] 14 | # ss = np.shape(seq)[1] 15 | # seq = np.reshape(seq,(bs*ss,-1,3)) 16 | # bone = [] 17 | # for index in boneindex: 18 | # bone.append(seq[:,index[0]] - seq[:,index[1]]) 19 | # bone = np.stack(bone,1) 20 | # bone = np.power(np.power(bone,2).sum(2),0.5) 21 | # bone = np.reshape(bone, (bs,ss,np.shape(bone)[1])) 22 | # return bone 23 | 24 | class ChunkedGenerator: 25 | """ 26 | Batched data generator, used for training. 27 | The sequences are split into equal-length chunks and padded as necessary. 
28 | 29 | Arguments: 30 | batch_size -- the batch size to use for training 31 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 32 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 33 | poses_2d -- list of input 2D keypoints, one element for each video 34 | chunk_length -- number of output frames to predict for each training example (usually 1) 35 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 36 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 37 | shuffle -- randomly shuffle the dataset before each epoch 38 | random_seed -- initial seed to use for the random generator 39 | augment -- augment the dataset by flipping poses horizontally 40 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 41 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 42 | """ 43 | def __init__(self, batch_size, cameras, poses_3d, poses_2d, 44 | chunk_length, pad=0, causal_shift=0, 45 | shuffle=True, random_seed=1234, 46 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None, 47 | endless=False): 48 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d)) 49 | assert cameras is None or len(cameras) == len(poses_2d) 50 | 51 | # Build lineage info 52 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples 53 | for i in range(len(poses_2d)): 54 | assert poses_3d is None or poses_3d[i].shape[0] == poses_3d[i].shape[0] 55 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length 56 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2 57 | bounds = np.arange(n_chunks+1)*chunk_length - offset 58 | augment_vector = np.full(len(bounds - 1), False, dtype=bool) 59 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], augment_vector) 60 | if augment: 61 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], ~augment_vector) 62 | 63 | # Initialize buffers 64 | if cameras is not None: 65 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1])) 66 | if poses_3d is not None: 67 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1])) 68 | self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 69 | 70 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size 71 | self.batch_size = batch_size 72 | self.random = np.random.RandomState(random_seed) 73 | self.pairs = pairs 74 | self.shuffle = shuffle 75 | self.pad = pad 76 | self.causal_shift = causal_shift 77 | self.endless = endless 78 | self.state = None 79 | 80 | self.cameras = cameras 81 | self.poses_3d = poses_3d 82 | self.poses_2d = poses_2d 83 | 84 | self.augment = augment 85 | self.kps_left = kps_left 86 | self.kps_right = kps_right 87 | self.joints_left = joints_left 88 | self.joints_right = joints_right 89 | 90 | def num_frames(self): 91 | return self.num_batches * self.batch_size 92 | 93 | def random_state(self): 94 | return self.random 95 | 96 | def set_random_state(self, random): 97 | self.random = random 98 | 99 | def augment_enabled(self): 100 | return self.augment 101 | 102 | def next_pairs(self): 103 | if self.state is None: 104 | if self.shuffle: 105 | pairs = self.random.permutation(self.pairs) 106 | else: 107 | pairs = self.pairs 108 | return 0, 
pairs 109 | else: 110 | return self.state 111 | 112 | def next_epoch(self): 113 | enabled = True 114 | while enabled: 115 | start_idx, pairs = self.next_pairs() 116 | for b_i in range(start_idx, self.num_batches): 117 | chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size] 118 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks): 119 | start_2d = start_3d - self.pad - self.causal_shift 120 | end_2d = end_3d + self.pad - self.causal_shift 121 | 122 | # 2D poses 123 | seq_2d = self.poses_2d[seq_i] 124 | low_2d = max(start_2d, 0) 125 | high_2d = min(end_2d, seq_2d.shape[0]) 126 | pad_left_2d = low_2d - start_2d 127 | pad_right_2d = end_2d - high_2d 128 | if pad_left_2d != 0 or pad_right_2d != 0: 129 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge') 130 | else: 131 | self.batch_2d[i] = seq_2d[low_2d:high_2d] 132 | 133 | if flip: 134 | # Flip 2D keypoints 135 | self.batch_2d[i, :, :, 0] *= -1 136 | self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left] 137 | 138 | # 3D poses 139 | if self.poses_3d is not None: 140 | seq_3d = self.poses_3d[seq_i] 141 | low_3d = max(start_3d, 0) 142 | high_3d = min(end_3d, seq_3d.shape[0]) 143 | pad_left_3d = low_3d - start_3d 144 | pad_right_3d = end_3d - high_3d 145 | if pad_left_3d != 0 or pad_right_3d != 0: 146 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge') 147 | else: 148 | self.batch_3d[i] = seq_3d[low_3d:high_3d] 149 | 150 | if flip: 151 | # Flip 3D joints 152 | self.batch_3d[i, :, :, 0] *= -1 153 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \ 154 | self.batch_3d[i, :, self.joints_right + self.joints_left] 155 | 156 | # Cameras 157 | if self.cameras is not None: 158 | self.batch_cam[i] = self.cameras[seq_i] 159 | if flip: 160 | # Flip horizontal distortion coefficients 161 | self.batch_cam[i, 2] *= -1 162 | self.batch_cam[i, 7] *= -1 163 | 164 | if self.endless: 165 | self.state = (b_i + 1, pairs) 166 | if self.poses_3d is None and self.cameras is None: 167 | yield None, None, self.batch_2d[:len(chunks)] 168 | elif self.poses_3d is not None and self.cameras is None: 169 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 170 | elif self.poses_3d is None: 171 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)] 172 | else: 173 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 174 | 175 | if self.endless: 176 | self.state = None 177 | else: 178 | enabled = False 179 | 180 | 181 | class UnchunkedGenerator: 182 | """ 183 | Non-batched data generator, used for testing. 184 | Sequences are returned one at a time (i.e. batch size = 1), without chunking. 185 | 186 | If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2), 187 | the second of which is a mirrored version of the first. 
188 | 189 | Arguments: 190 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 191 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 192 | poses_2d -- list of input 2D keypoints, one element for each video 193 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 194 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 195 | augment -- augment the dataset by flipping poses horizontally 196 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 197 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 198 | """ 199 | 200 | def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0, 201 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None): 202 | assert poses_3d is None or len(poses_3d) == len(poses_2d) 203 | assert cameras is None or len(cameras) == len(poses_2d) 204 | 205 | self.augment = False 206 | self.kps_left = kps_left 207 | self.kps_right = kps_right 208 | self.joints_left = joints_left 209 | self.joints_right = joints_right 210 | 211 | self.pad = pad 212 | self.causal_shift = causal_shift 213 | self.cameras = [] if cameras is None else cameras 214 | self.poses_3d = [] if poses_3d is None else poses_3d 215 | self.poses_2d = poses_2d 216 | 217 | def num_frames(self): 218 | count = 0 219 | for p in self.poses_2d: 220 | count += p.shape[0] 221 | return count 222 | 223 | def augment_enabled(self): 224 | return self.augment 225 | 226 | def set_augment(self, augment): 227 | self.augment = augment 228 | 229 | def next_epoch(self): 230 | for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d): 231 | batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0) 232 | batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0) 233 | batch_2d = np.expand_dims(np.pad(seq_2d, 234 | ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)), 235 | 'edge'), axis=0) 236 | if self.augment: 237 | # Append flipped version 238 | if batch_cam is not None: 239 | batch_cam = np.concatenate((batch_cam, batch_cam), axis=0) 240 | batch_cam[1, 2] *= -1 241 | batch_cam[1, 7] *= -1 242 | 243 | if batch_3d is not None: 244 | batch_3d = np.concatenate((batch_3d, batch_3d), axis=0) 245 | batch_3d[1, :, :, 0] *= -1 246 | batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left] 247 | 248 | batch_2d = np.concatenate((batch_2d, batch_2d), axis=0) 249 | batch_2d[1, :, :, 0] *= -1 250 | batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left] 251 | 252 | yield batch_cam, batch_3d, batch_2d -------------------------------------------------------------------------------- /common/h36m_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | 14 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 15 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 16 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 17 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 18 | 19 | h36m_cameras_intrinsic_params = [ 20 | { 21 | 'id': '54138969', 22 | 'center': [512.54150390625, 515.4514770507812], 23 | 'focal_length': [1145.0494384765625, 1143.7811279296875], 24 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], 25 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], 26 | 'res_w': 1000, 27 | 'res_h': 1002, 28 | 'azimuth': 70, # Only used for visualization 29 | }, 30 | { 31 | 'id': '55011271', 32 | 'center': [508.8486328125, 508.0649108886719], 33 | 'focal_length': [1149.6756591796875, 1147.5916748046875], 34 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], 35 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], 36 | 'res_w': 1000, 37 | 'res_h': 1000, 38 | 'azimuth': -70, # Only used for visualization 39 | }, 40 | { 41 | 'id': '58860488', 42 | 'center': [519.8158569335938, 501.40264892578125], 43 | 'focal_length': [1149.1407470703125, 1148.7989501953125], 44 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], 45 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], 46 | 'res_w': 1000, 47 | 'res_h': 1000, 48 | 'azimuth': 110, # Only used for visualization 49 | }, 50 | { 51 | 'id': '60457274', 52 | 'center': [514.9682006835938, 501.88201904296875], 53 | 'focal_length': [1145.5113525390625, 1144.77392578125], 54 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], 55 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], 56 | 'res_w': 1000, 57 | 'res_h': 1002, 58 | 'azimuth': -110, # Only used for visualization 59 | }, 60 | ] 61 | 62 | h36m_cameras_extrinsic_params = { 63 | 'S1': [ 64 | { 65 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 66 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 67 | }, 68 | { 69 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], 70 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], 71 | }, 72 | { 73 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], 74 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], 75 | }, 76 | { 77 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435], 78 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], 79 | }, 80 | ], 81 | 'S2': [ 82 | {}, 83 | {}, 84 | {}, 85 | {}, 86 | ], 87 | 'S3': [ 88 | {}, 89 | {}, 90 | {}, 91 | {}, 92 | ], 93 | 'S4': [ 94 | {}, 95 | {}, 96 | {}, 97 | {}, 98 | ], 99 | 'S5': [ 100 | { 101 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], 102 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], 103 | }, 104 | { 105 | 'orientation': [0.6159758567810059, 
-0.7626792192459106, -0.15728192031383514, 0.1189815029501915], 106 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], 107 | }, 108 | { 109 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576], 110 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], 111 | }, 112 | { 113 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], 114 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], 115 | }, 116 | ], 117 | 'S6': [ 118 | { 119 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], 120 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], 121 | }, 122 | { 123 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], 124 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], 125 | }, 126 | { 127 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], 128 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], 129 | }, 130 | { 131 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], 132 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], 133 | }, 134 | ], 135 | 'S7': [ 136 | { 137 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], 138 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], 139 | }, 140 | { 141 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], 142 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], 143 | }, 144 | { 145 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278], 146 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], 147 | }, 148 | { 149 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], 150 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], 151 | }, 152 | ], 153 | 'S8': [ 154 | { 155 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], 156 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], 157 | }, 158 | { 159 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], 160 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], 161 | }, 162 | { 163 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], 164 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], 165 | }, 166 | { 167 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708], 168 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], 169 | }, 170 | ], 171 | 'S9': [ 172 | { 173 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], 174 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], 175 | }, 176 | { 177 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], 178 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], 179 | }, 180 | { 181 | 'orientation': [0.13357827067375183, 
-0.1367100477218628, 0.7689454555511475, -0.6100738644599915], 182 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], 183 | }, 184 | { 185 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], 186 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], 187 | }, 188 | ], 189 | 'S11': [ 190 | { 191 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], 192 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], 193 | }, 194 | { 195 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], 196 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], 197 | }, 198 | { 199 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], 200 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], 201 | }, 202 | { 203 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], 204 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], 205 | }, 206 | ], 207 | } 208 | 209 | class Human36mDataset(MocapDataset): 210 | def __init__(self, path, remove_static_joints=True): 211 | super().__init__(fps=50, skeleton=h36m_skeleton) 212 | 213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) 214 | for cameras in self._cameras.values(): 215 | for i, cam in enumerate(cameras): 216 | cam.update(h36m_cameras_intrinsic_params[i]) 217 | for k, v in cam.items(): 218 | if k not in ['id', 'res_w', 'res_h']: 219 | cam[k] = np.array(v, dtype='float32') 220 | 221 | # Normalize camera frame 222 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32') 223 | cam['focal_length'] = cam['focal_length']/cam['res_w']*2 224 | if 'translation' in cam: 225 | cam['translation'] = cam['translation']/1000 # mm to meters 226 | 227 | # Add intrinsic parameters vector 228 | cam['intrinsic'] = np.concatenate((cam['focal_length'], 229 | cam['center'], 230 | cam['radial_distortion'], 231 | cam['tangential_distortion'])) 232 | 233 | # Load serialized dataset 234 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 235 | 236 | self._data = {} 237 | for subject, actions in data.items(): 238 | self._data[subject] = {} 239 | for action_name, positions in actions.items(): 240 | self._data[subject][action_name] = { 241 | 'positions': positions, 242 | 'cameras': self._cameras[subject], 243 | } 244 | 245 | if remove_static_joints: 246 | # Bring the skeleton to 17 joints instead of the original 32 247 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 248 | 249 | # Rewire shoulders to the correct parents 250 | self._skeleton._parents[11] = 8 251 | self._skeleton._parents[14] = 8 252 | 253 | def supports_semi_supervised(self): 254 | return True 255 | -------------------------------------------------------------------------------- /common/humaneva_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | 14 | humaneva_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1], 15 | joints_left=[2, 3, 4, 8, 9, 10], 16 | joints_right=[5, 6, 7, 11, 12, 13]) 17 | 18 | humaneva_cameras_intrinsic_params = [ 19 | { 20 | 'id': 'C1', 21 | 'res_w': 640, 22 | 'res_h': 480, 23 | 'azimuth': 0, # Only used for visualization 24 | }, 25 | { 26 | 'id': 'C2', 27 | 'res_w': 640, 28 | 'res_h': 480, 29 | 'azimuth': -90, # Only used for visualization 30 | }, 31 | { 32 | 'id': 'C3', 33 | 'res_w': 640, 34 | 'res_h': 480, 35 | 'azimuth': 90, # Only used for visualization 36 | }, 37 | ] 38 | 39 | humaneva_cameras_extrinsic_params = { 40 | 'S1': [ 41 | { 42 | 'orientation': [0.424207, -0.4983646, -0.5802981, 0.4847012], 43 | 'translation': [4062.227, 663.2477, 1528.397], 44 | }, 45 | { 46 | 'orientation': [0.6503354, -0.7481602, -0.0919284, 0.0941766], 47 | 'translation': [844.8131, -3805.2092, 1504.9929], 48 | }, 49 | { 50 | 'orientation': [0.0664734, -0.0690535, 0.7416416, -0.6639132], 51 | 'translation': [-797.67377, 3916.3174, 1433.6602], 52 | }, 53 | ], 54 | 'S2': [ 55 | { 56 | 'orientation': [ 0.4214752, -0.4961493, -0.5838273, 0.4851187 ], 57 | 'translation': [ 4112.9121, 626.4929, 1545.2988], 58 | }, 59 | { 60 | 'orientation': [ 0.6501393, -0.7476588, -0.0954617, 0.0959808 ], 61 | 'translation': [ 923.5740, -3877.9243, 1504.5518], 62 | }, 63 | { 64 | 'orientation': [ 0.0699353, -0.0712403, 0.7421637, -0.662742 ], 65 | 'translation': [ -781.4915, 3838.8853, 1444.9929], 66 | }, 67 | ], 68 | 'S3': [ 69 | { 70 | 'orientation': [ 0.424207, -0.4983646, -0.5802981, 0.4847012 ], 71 | 'translation': [ 4062.2271, 663.2477, 1528.3970], 72 | }, 73 | { 74 | 'orientation': [ 0.6503354, -0.7481602, -0.0919284, 0.0941766 ], 75 | 'translation': [ 844.8131, -3805.2092, 1504.9929], 76 | }, 77 | { 78 | 'orientation': [ 0.0664734, -0.0690535, 0.7416416, -0.6639132 ], 79 | 'translation': [ -797.6738, 3916.3174, 1433.6602], 80 | }, 81 | ], 82 | 'S4': [ 83 | {}, 84 | {}, 85 | {}, 86 | ], 87 | 88 | } 89 | 90 | class HumanEvaDataset(MocapDataset): 91 | def __init__(self, path): 92 | super().__init__(fps=60, skeleton=humaneva_skeleton) 93 | 94 | self._cameras = copy.deepcopy(humaneva_cameras_extrinsic_params) 95 | for cameras in self._cameras.values(): 96 | for i, cam in enumerate(cameras): 97 | cam.update(humaneva_cameras_intrinsic_params[i]) 98 | for k, v in cam.items(): 99 | if k not in ['id', 'res_w', 'res_h']: 100 | cam[k] = np.array(v, dtype='float32') 101 | if 'translation' in cam: 102 | cam['translation'] = cam['translation']/1000 # mm to meters 103 | 104 | for subject in list(self._cameras.keys()): 105 | data = self._cameras[subject] 106 | del self._cameras[subject] 107 | for prefix in ['Train/', 'Validate/', 'Unlabeled/Train/', 'Unlabeled/Validate/', 'Unlabeled/']: 108 | self._cameras[prefix + subject] = data 109 | 110 | # Load serialized dataset 111 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 112 | 113 | self._data = {} 114 | for subject, actions in data.items(): 115 | self._data[subject] = {} 116 | for action_name, positions in actions.items(): 117 | self._data[subject][action_name] = { 118 | 'positions': positions, 119 | 'cameras': self._cameras[subject], 120 | } 121 | -------------------------------------------------------------------------------- 
/common/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | def mpjpe(predicted, target): 12 | """ 13 | Mean per-joint position error (i.e. mean Euclidean distance), 14 | often referred to as "Protocol #1" in many papers. 15 | """ 16 | assert predicted.shape == target.shape 17 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1)) 18 | 19 | def weighted_mpjpe(predicted, target, w): 20 | """ 21 | Weighted mean per-joint position error (i.e. mean Euclidean distance) 22 | """ 23 | assert predicted.shape == target.shape 24 | assert w.shape[0] == predicted.shape[0] 25 | return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1)) 26 | 27 | def p_mpjpe(predicted, target): 28 | """ 29 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation), 30 | often referred to as "Protocol #2" in many papers. 31 | """ 32 | assert predicted.shape == target.shape 33 | 34 | muX = np.mean(target, axis=1, keepdims=True) 35 | muY = np.mean(predicted, axis=1, keepdims=True) 36 | 37 | X0 = target - muX 38 | Y0 = predicted - muY 39 | 40 | normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True)) 41 | normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True)) 42 | 43 | X0 /= normX 44 | Y0 /= normY 45 | 46 | H = np.matmul(X0.transpose(0, 2, 1), Y0) 47 | U, s, Vt = np.linalg.svd(H) 48 | V = Vt.transpose(0, 2, 1) 49 | R = np.matmul(V, U.transpose(0, 2, 1)) 50 | 51 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 52 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) 53 | V[:, :, -1] *= sign_detR 54 | s[:, -1] *= sign_detR.flatten() 55 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 56 | 57 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) 58 | 59 | a = tr * normX / normY # Scale 60 | t = muX - a*np.matmul(muY, R) # Translation 61 | 62 | # Perform rigid transformation on the input 63 | predicted_aligned = a*np.matmul(predicted, R) + t 64 | 65 | # Return MPJPE 66 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)) 67 | 68 | def n_mpjpe(predicted, target): 69 | """ 70 | Normalized MPJPE (scale only), adapted from: 71 | https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py 72 | """ 73 | assert predicted.shape == target.shape 74 | 75 | norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True) 76 | norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True) 77 | scale = norm_target / norm_predicted 78 | return mpjpe(scale * predicted, target) 79 | 80 | def weighted_bonelen_loss(predict_3d_length, gt_3d_length): 81 | loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean() 82 | return loss_length 83 | 84 | def weighted_boneratio_loss(predict_3d_length, gt_3d_length): 85 | loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean() 86 | return loss_length 87 | 88 | def mean_velocity_error(predicted, target): 89 | """ 90 | Mean per-joint velocity error (i.e. 
mean Euclidean distance of the 1st derivative) 91 | """ 92 | assert predicted.shape == target.shape 93 | 94 | velocity_predicted = np.diff(predicted, axis=0) 95 | velocity_target = np.diff(target, axis=0) 96 | 97 | return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1)) -------------------------------------------------------------------------------- /common/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | from common.skeleton import Skeleton 10 | 11 | class MocapDataset: 12 | def __init__(self, fps, skeleton): 13 | self._skeleton = skeleton 14 | self._fps = fps 15 | self._data = None # Must be filled by subclass 16 | self._cameras = None # Must be filled by subclass 17 | 18 | def remove_joints(self, joints_to_remove): 19 | kept_joints = self._skeleton.remove_joints(joints_to_remove) 20 | for subject in self._data.keys(): 21 | for action in self._data[subject].keys(): 22 | s = self._data[subject][action] 23 | if 'positions' in s: 24 | s['positions'] = s['positions'][:, kept_joints] 25 | 26 | 27 | def __getitem__(self, key): 28 | return self._data[key] 29 | 30 | def subjects(self): 31 | return self._data.keys() 32 | 33 | def fps(self): 34 | return self._fps 35 | 36 | def skeleton(self): 37 | return self._skeleton 38 | 39 | def cameras(self): 40 | return self._cameras 41 | 42 | def supports_semi_supervised(self): 43 | # This method can be overridden 44 | return False -------------------------------------------------------------------------------- /common/model_poseformer.py: -------------------------------------------------------------------------------- 1 | ## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 2 | 3 | import math 4 | import logging 5 | from functools import partial 6 | from collections import OrderedDict 7 | from einops import rearrange, repeat 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.models.helpers import load_pretrained 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | from timm.models.registry import register_model 17 | 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.act(x) 32 | x = self.drop(x) 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 40 | super().__init__() 41 | self.num_heads = num_heads 42 | head_dim = dim // num_heads 43 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 44 | self.scale = qk_scale or head_dim ** -0.5 45 | 46 | self.qkv = 
nn.Linear(dim, dim * 3, bias=qkv_bias) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | 51 | def forward(self, x): 52 | B, N, C = x.shape 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 55 | 56 | attn = (q @ k.transpose(-2, -1)) * self.scale 57 | attn = attn.softmax(dim=-1) 58 | attn = self.attn_drop(attn) 59 | 60 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 61 | x = self.proj(x) 62 | x = self.proj_drop(x) 63 | return x 64 | 65 | 66 | class Block(nn.Module): 67 | 68 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 69 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 70 | super().__init__() 71 | self.norm1 = norm_layer(dim) 72 | self.attn = Attention( 73 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 74 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 75 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 76 | self.norm2 = norm_layer(dim) 77 | mlp_hidden_dim = int(dim * mlp_ratio) 78 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 79 | 80 | def forward(self, x): 81 | x = x + self.drop_path(self.attn(self.norm1(x))) 82 | x = x + self.drop_path(self.mlp(self.norm2(x))) 83 | return x 84 | 85 | class PoseTransformer(nn.Module): 86 | def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, 87 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, 88 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): 89 | """ ##########hybrid_backbone=None, representation_size=None, 90 | Args: 91 | num_frame (int, tuple): input frame number 92 | num_joints (int, tuple): joints number 93 | in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) 94 | embed_dim_ratio (int): embedding dimension ratio 95 | depth (int): depth of transformer 96 | num_heads (int): number of attention heads 97 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 98 | qkv_bias (bool): enable bias for qkv if True 99 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 100 | drop_rate (float): dropout rate 101 | attn_drop_rate (float): attention dropout rate 102 | drop_path_rate (float): stochastic depth rate 103 | norm_layer: (nn.Module): normalization layer 104 | """ 105 | super().__init__() 106 | 107 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 108 | embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio 109 | out_dim = num_joints * 3 #### output dimension is num_joints * 3 110 | 111 | ### spatial patch embedding 112 | self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) 113 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) 114 | 115 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim)) 116 | self.pos_drop = nn.Dropout(p=drop_rate) 117 | 118 | 119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 120 | 121 | self.Spatial_blocks = nn.ModuleList([ 122 | Block( 123 | dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 124 | drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 125 | for i in range(depth)]) 126 | 127 | self.blocks = nn.ModuleList([ 128 | Block( 129 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 130 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 131 | for i in range(depth)]) 132 | 133 | self.Spatial_norm = norm_layer(embed_dim_ratio) 134 | self.Temporal_norm = norm_layer(embed_dim) 135 | 136 | ####### A easy way to implement weighted mean 137 | self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) 138 | 139 | self.head = nn.Sequential( 140 | nn.LayerNorm(embed_dim), 141 | nn.Linear(embed_dim , out_dim), 142 | ) 143 | 144 | 145 | def Spatial_forward_features(self, x): 146 | b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints 147 | x = rearrange(x, 'b c f p -> (b f) p c', ) 148 | 149 | x = self.Spatial_patch_to_embedding(x) 150 | x += self.Spatial_pos_embed 151 | x = self.pos_drop(x) 152 | 153 | for blk in self.Spatial_blocks: 154 | x = blk(x) 155 | 156 | x = self.Spatial_norm(x) 157 | x = rearrange(x, '(b f) w c -> b f (w c)', f=f) 158 | return x 159 | 160 | def forward_features(self, x): 161 | b = x.shape[0] 162 | x += self.Temporal_pos_embed 163 | x = self.pos_drop(x) 164 | for blk in self.blocks: 165 | x = blk(x) 166 | 167 | x = self.Temporal_norm(x) 168 | ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame 169 | x = self.weighted_mean(x) 170 | x = x.view(b, 1, -1) 171 | return x 172 | 173 | 174 | def forward(self, x): 175 | x = x.permute(0, 3, 1, 2) 176 | b, _, _, p = x.shape 177 | ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data 178 | x = self.Spatial_forward_features(x) 179 | x = self.forward_features(x) 180 | x = self.head(x) 181 | 182 | x = x.view(b, 1, p, -1) 183 | 184 | return x 185 | 186 | -------------------------------------------------------------------------------- /common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | def qrot(q, v): 11 | """ 12 | Rotate vector(s) v about the rotation described by quaternion(s) q. 13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 14 | where * denotes any number of dimensions. 15 | Returns a tensor of shape (*, 3). 16 | """ 17 | assert q.shape[-1] == 4 18 | assert v.shape[-1] == 3 19 | assert q.shape[:-1] == v.shape[:-1] 20 | 21 | qvec = q[..., 1:] 22 | uv = torch.cross(qvec, v, dim=len(q.shape)-1) 23 | uuv = torch.cross(qvec, uv, dim=len(q.shape)-1) 24 | return (v + 2 * (q[..., :1] * uv + uuv)) 25 | 26 | 27 | def qinverse(q, inplace=False): 28 | # We assume the quaternion to be normalized 29 | if inplace: 30 | q[..., 1:] *= -1 31 | return q 32 | else: 33 | w = q[..., :1] 34 | xyz = q[..., 1:] 35 | return torch.cat((w, -xyz), dim=len(q.shape)-1) -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | class Skeleton: 11 | def __init__(self, parents, joints_left, joints_right): 12 | assert len(joints_left) == len(joints_right) 13 | 14 | self._parents = np.array(parents) 15 | self._joints_left = joints_left 16 | self._joints_right = joints_right 17 | self._compute_metadata() 18 | 19 | def num_joints(self): 20 | return len(self._parents) 21 | 22 | def parents(self): 23 | return self._parents 24 | 25 | def has_children(self): 26 | return self._has_children 27 | 28 | def children(self): 29 | return self._children 30 | 31 | def remove_joints(self, joints_to_remove): 32 | """ 33 | Remove the joints specified in 'joints_to_remove'. 34 | """ 35 | valid_joints = [] 36 | for joint in range(len(self._parents)): 37 | if joint not in joints_to_remove: 38 | valid_joints.append(joint) 39 | 40 | for i in range(len(self._parents)): 41 | while self._parents[i] in joints_to_remove: 42 | self._parents[i] = self._parents[self._parents[i]] 43 | 44 | index_offsets = np.zeros(len(self._parents), dtype=int) 45 | new_parents = [] 46 | for i, parent in enumerate(self._parents): 47 | if i not in joints_to_remove: 48 | new_parents.append(parent - index_offsets[parent]) 49 | else: 50 | index_offsets[i:] += 1 51 | self._parents = np.array(new_parents) 52 | 53 | 54 | if self._joints_left is not None: 55 | new_joints_left = [] 56 | for joint in self._joints_left: 57 | if joint in valid_joints: 58 | new_joints_left.append(joint - index_offsets[joint]) 59 | self._joints_left = new_joints_left 60 | if self._joints_right is not None: 61 | new_joints_right = [] 62 | for joint in self._joints_right: 63 | if joint in valid_joints: 64 | new_joints_right.append(joint - index_offsets[joint]) 65 | self._joints_right = new_joints_right 66 | 67 | self._compute_metadata() 68 | 69 | return valid_joints 70 | 71 | def joints_left(self): 72 | return self._joints_left 73 | 74 | def joints_right(self): 75 | return self._joints_right 76 | 77 | def _compute_metadata(self): 78 | self._has_children = np.zeros(len(self._parents)).astype(bool) 79 | for i, parent in enumerate(self._parents): 80 | if parent != -1: 81 | self._has_children[parent] = True 82 | 83 | self._children = [] 84 | for i, parent in enumerate(self._parents): 85 | self._children.append([]) 86 | for i, parent in enumerate(self._parents): 87 | if parent != -1: 88 | self._children[parent].append(i) -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | import hashlib 11 | 12 | def wrap(func, *args, unsqueeze=False): 13 | """ 14 | Wrap a torch function so it can be called with NumPy arrays. 15 | Input and return types are seamlessly converted. 
16 | """ 17 | 18 | # Convert input types where applicable 19 | args = list(args) 20 | for i, arg in enumerate(args): 21 | if type(arg) == np.ndarray: 22 | args[i] = torch.from_numpy(arg) 23 | if unsqueeze: 24 | args[i] = args[i].unsqueeze(0) 25 | 26 | result = func(*args) 27 | 28 | # Convert output types where applicable 29 | if isinstance(result, tuple): 30 | result = list(result) 31 | for i, res in enumerate(result): 32 | if type(res) == torch.Tensor: 33 | if unsqueeze: 34 | res = res.squeeze(0) 35 | result[i] = res.numpy() 36 | return tuple(result) 37 | elif type(result) == torch.Tensor: 38 | if unsqueeze: 39 | result = result.squeeze(0) 40 | return result.numpy() 41 | else: 42 | return result 43 | 44 | def deterministic_random(min_value, max_value, data): 45 | digest = hashlib.sha256(data.encode()).digest() 46 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False) 47 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value 48 | 49 | def load_pretrained_weights(model, checkpoint): 50 | """Load pretrianed weights to model 51 | Incompatible layers (unmatched in name or size) will be ignored 52 | Args: 53 | - model (nn.Module): network model, which must not be nn.DataParallel 54 | - weight_path (str): path to pretrained weights 55 | """ 56 | import collections 57 | if 'state_dict' in checkpoint: 58 | state_dict = checkpoint['state_dict'] 59 | else: 60 | state_dict = checkpoint 61 | model_dict = model.state_dict() 62 | new_state_dict = collections.OrderedDict() 63 | matched_layers, discarded_layers = [], [] 64 | for k, v in state_dict.items(): 65 | # If the pretrained state_dict was saved as nn.DataParallel, 66 | # keys would contain "module.", which should be ignored. 67 | if k.startswith('module.'): 68 | k = k[7:] 69 | if k in model_dict and model_dict[k].size() == v.size(): 70 | new_state_dict[k] = v 71 | matched_layers.append(k) 72 | else: 73 | discarded_layers.append(k) 74 | # new_state_dict.requires_grad = False 75 | model_dict.update(new_state_dict) 76 | 77 | model.load_state_dict(model_dict) 78 | print('load_weight', len(matched_layers)) 79 | # model.state_dict(model_dict).requires_grad = False 80 | return model 81 | -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import matplotlib 9 | 10 | matplotlib.use('Agg') 11 | 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.animation import FuncAnimation, writers 15 | from mpl_toolkits.mplot3d import Axes3D 16 | import numpy as np 17 | import subprocess as sp 18 | import cv2 19 | 20 | 21 | def get_resolution(filename): 22 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 23 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename] 24 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 25 | for line in pipe.stdout: 26 | w, h = line.decode().strip().split(',') 27 | return int(w), int(h) 28 | 29 | 30 | def get_fps(filename): 31 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 32 | '-show_entries', 'stream=r_frame_rate', '-of', 'csv=p=0', filename] 33 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 34 | for line in pipe.stdout: 35 | a, b = line.decode().strip().split('/') 36 | return int(a) / int(b) 37 | 38 | 39 | def read_video(filename, skip=0, limit=-1): 40 | # w, h = get_resolution(filename) 41 | w = 1000 42 | h = 1002 43 | 44 | command = ['ffmpeg', 45 | '-i', filename, 46 | '-f', 'image2pipe', 47 | '-pix_fmt', 'rgb24', 48 | '-vsync', '0', 49 | '-vcodec', 'rawvideo', '-'] 50 | 51 | i = 0 52 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 53 | while True: 54 | data = pipe.stdout.read(w * h * 3) 55 | if not data: 56 | break 57 | i += 1 58 | if i > limit and limit != -1: 59 | continue 60 | if i > skip: 61 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3)) 62 | 63 | 64 | def downsample_tensor(X, factor): 65 | length = X.shape[0] // factor * factor 66 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1) 67 | 68 | 69 | def render_animation(keypoints, keypoints_metadata, poses, skeleton, fps, bitrate, azim, output, viewport, 70 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0): 71 | """ 72 | TODO 73 | Render an animation. The supported output modes are: 74 | -- 'interactive': display an interactive figure 75 | (also works on notebooks if associated with %matplotlib inline) 76 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...). 77 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg). 78 | -- 'filename.gif': render and export the animation a gif file (requires imagemagick). 
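    Note: the implementation below currently handles only '.mp4' and '.gif' outputs.

    Illustrative call (argument values are hypothetical; shapes follow the code below,
    i.e. keypoints is [T, J, 2] and each entry of poses is [T, J, 3]):

        render_animation(keypoints_2d, keypoints_metadata, {'Reconstruction': pred_3d},
                         skeleton, fps=50, bitrate=3000, azim=70.0, output='out.mp4',
                         viewport=(1000, 1002), input_video_path='input.mp4')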
79 | """ 80 | plt.ioff() 81 | fig = plt.figure(figsize=(size * (1 + len(poses)), size)) 82 | ax_in = fig.add_subplot(1, 1 + len(poses), 1) 83 | ax_in.get_xaxis().set_visible(False) 84 | ax_in.get_yaxis().set_visible(False) 85 | ax_in.set_axis_off() 86 | ax_in.set_title('Input') 87 | 88 | ax_3d = [] 89 | lines_3d = [] 90 | trajectories = [] 91 | radius = 1.7 92 | for index, (title, data) in enumerate(poses.items()): 93 | ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d') 94 | ax.view_init(elev=15., azim=azim) 95 | ax.set_xlim3d([-radius / 2, radius / 2]) 96 | ax.set_zlim3d([0, radius]) 97 | ax.set_ylim3d([-radius / 2, radius / 2]) 98 | try: 99 | ax.set_aspect('equal') 100 | except NotImplementedError: 101 | ax.set_aspect('auto') 102 | ax.set_xticklabels([]) 103 | ax.set_yticklabels([]) 104 | ax.set_zticklabels([]) 105 | ax.dist = 7.5 106 | ax.set_title(title) # , pad=35 107 | ax_3d.append(ax) 108 | lines_3d.append([]) 109 | trajectories.append(data[:, 0, [0, 1]]) 110 | poses = list(poses.values()) 111 | 112 | # Decode video 113 | if input_video_path is None: 114 | # Black background 115 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8') 116 | else: 117 | # Load video using ffmpeg 118 | all_frames = [] 119 | for f in read_video(input_video_path, skip=input_video_skip, limit=limit): 120 | all_frames.append(f) 121 | effective_length = min(keypoints.shape[0], len(all_frames)) 122 | all_frames = all_frames[:effective_length] 123 | 124 | keypoints = keypoints[input_video_skip:] # todo remove 125 | for idx in range(len(poses)): 126 | poses[idx] = poses[idx][input_video_skip:] 127 | 128 | if fps is None: 129 | fps = get_fps(input_video_path) 130 | 131 | if downsample > 1: 132 | keypoints = downsample_tensor(keypoints, downsample) 133 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8') 134 | for idx in range(len(poses)): 135 | poses[idx] = downsample_tensor(poses[idx], downsample) 136 | trajectories[idx] = downsample_tensor(trajectories[idx], downsample) 137 | fps /= downsample 138 | 139 | initialized = False 140 | image = None 141 | lines = [] 142 | points = None 143 | 144 | if limit < 1: 145 | limit = len(all_frames) 146 | else: 147 | limit = min(limit, len(all_frames)) 148 | 149 | parents = skeleton.parents() 150 | 151 | def update_video(i): 152 | nonlocal initialized, image, lines, points 153 | 154 | for n, ax in enumerate(ax_3d): 155 | ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]]) 156 | ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]]) 157 | 158 | # Update 2D poses 159 | joints_right_2d = keypoints_metadata['keypoints_symmetry'][1] 160 | colors_2d = np.full(keypoints.shape[1], 'black') 161 | colors_2d[joints_right_2d] = 'red' 162 | if not initialized: 163 | image = ax_in.imshow(all_frames[i], aspect='equal') 164 | 165 | for j, j_parent in enumerate(parents): 166 | if j_parent == -1: 167 | continue 168 | 169 | if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco': 170 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition) 171 | lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 172 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink')) 173 | 174 | col = 'red' if j in skeleton.joints_right() else 'black' 175 | for n, ax in enumerate(ax_3d): 176 | pos = poses[n][i] 177 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 178 | 
                                           [pos[j, 1], pos[j_parent, 1]], 179 |                                            [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col)) 180 | 
181 |             points = ax_in.scatter(*keypoints[i].T, 10, color=colors_2d, edgecolors='white', zorder=10) 182 | 
183 |             initialized = True 184 |         else: 185 |             image.set_data(all_frames[i]) 186 | 
187 |             for j, j_parent in enumerate(parents): 188 |                 if j_parent == -1: 189 |                     continue 190 | 
191 |                 if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco': 192 |                     lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 193 |                                              [keypoints[i, j, 1], keypoints[i, j_parent, 1]]) 194 | 
195 |                 for n, ax in enumerate(ax_3d): 196 |                     pos = poses[n][i] 197 |                     lines_3d[n][j - 1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]])) 198 |                     lines_3d[n][j - 1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]])) 199 |                     lines_3d[n][j - 1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z') 200 | 
201 |             points.set_offsets(keypoints[i]) 202 | 
203 |         print('{}/{} '.format(i, limit), end='\r') 204 | 
205 |     fig.tight_layout() 206 | 
207 |     anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False) 208 |     if output.endswith('.mp4'): 209 |         Writer = writers['ffmpeg'] 210 |         writer = Writer(fps=fps, metadata={}, bitrate=bitrate) 211 |         anim.save(output, writer=writer) 212 |     elif output.endswith('.gif'): 213 |         anim.save(output, dpi=80, writer='imagemagick') 214 |     else: 215 |         raise ValueError('Unsupported output format (only .mp4 and .gif are supported)') 216 |     plt.close() -------------------------------------------------------------------------------- /data/README.MD: -------------------------------------------------------------------------------- 1 | Put the dataset files in this directory. 2 | 3 | Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (./data directory). 
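For reference, `run_poseformer.py` expects files named `data/data_3d_<dataset>.npz` and `data/data_2d_<dataset>_<keypoints>.npz` (e.g. `data_3d_h36m.npz` and `data_2d_h36m_cpn_ft_h36m_dbb.npz` for Human3.6M with CPN detections). A quick sanity check of the 2D file, assuming that setup:

```python
import numpy as np

kp2d = np.load('data/data_2d_h36m_cpn_ft_h36m_dbb.npz', allow_pickle=True)
print(kp2d['metadata'].item()['num_joints'])        # joint count used by run_poseformer.py
print(sorted(kp2d['positions_2d'].item().keys()))   # subjects, e.g. ['S1', 'S11', 'S5', ...]
```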
4 | 5 | -------------------------------------------------------------------------------- /figure/H3.6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/H3.6.gif -------------------------------------------------------------------------------- /figure/PoseFormer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/PoseFormer.gif -------------------------------------------------------------------------------- /figure/wild.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/wild.gif -------------------------------------------------------------------------------- /poseformer.yml: -------------------------------------------------------------------------------- 1 | name: pose2 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _tflow_select=2.3.0=mkl 10 | - absl-py=0.10.0=py38h32f6830_1 11 | - aiohttp=3.7.2=py38h1e0a361_0 12 | - astunparse=1.6.3=py_0 13 | - async-timeout=3.0.1=py_1000 14 | - attrs=20.2.0=pyh9f0ad1d_0 15 | - blas=1.0=mkl 16 | - blinker=1.4=py_1 17 | - brotlipy=0.7.0=py38h7b6447c_1000 18 | - bzip2=1.0.8=h7b6447c_0 19 | - c-ares=1.16.1=h516909a_3 20 | - ca-certificates=2021.1.19=h06a4308_0 21 | - cachetools=4.1.1=py_0 22 | - cairo=1.16.0=h3fc0475_1005 23 | - cdflib=0.3.18=py_0 24 | - certifi=2020.12.5=py38h06a4308_0 25 | - cffi=1.14.0=py38h2e261b9_0 26 | - chardet=3.0.4=py38_1003 27 | - click=7.1.2=pyh9f0ad1d_0 28 | - cryptography=3.1.1=py38h1ba5d50_0 29 | - cudatoolkit=11.0.221=h6bb024c_0 30 | - cycler=0.10.0=py_2 31 | - cython=0.29.21=py38he6710b0_0 32 | - dbus=1.13.16=hb2f20db_0 33 | - einops=0.3.0=py_0 34 | - expat=2.2.9=he6710b0_2 35 | - fontconfig=2.13.1=h1056068_1002 36 | - freetype=2.10.2=h5ab3b9f_0 37 | - gast=0.3.3=py_0 38 | - glib=2.63.1=h5a9c865_0 39 | - gmp=6.2.0=he1b5a44_2 40 | - gnutls=3.6.13=h79a8f9a_0 41 | - google-auth=1.23.0=pyhd8ed1ab_0 42 | - google-auth-oauthlib=0.4.1=py_2 43 | - google-pasta=0.2.0=py_0 44 | - graphite2=1.3.14=h23475e2_0 45 | - grpcio=1.33.2=py38heead2fc_0 46 | - gst-plugins-base=1.14.5=h0935bb2_2 47 | - gstreamer=1.14.5=h36ae1b5_2 48 | - h5py=2.10.0=py38hd6299e0_1 49 | - harfbuzz=2.7.2=hee91db6_0 50 | - hdf5=1.10.6=hb1b8bf9_0 51 | - icu=67.1=he1b5a44_0 52 | - idna=2.10=py_0 53 | - importlib-metadata=2.0.0=py_1 54 | - intel-openmp=2020.2=254 55 | - jasper=1.900.1=hd497a04_4 56 | - jpeg=9d=h516909a_0 57 | - json_tricks=3.15.3=pyh9f0ad1d_0 58 | - keras-preprocessing=1.1.0=py_1 59 | - kiwisolver=1.2.0=py38hbf85e49_0 60 | - krb5=1.17.1=h173b8e3_0 61 | - lame=3.100=h7b6447c_0 62 | - lcms2=2.11=h396b838_0 63 | - ld_impl_linux-64=2.33.1=h53a641e_7 64 | - libblas=3.8.0=16_mkl 65 | - libcblas=3.8.0=16_mkl 66 | - libclang=10.0.1=default_hb85057a_2 67 | - libedit=3.1.20191231=h14c3975_1 68 | - libevent=2.1.10=hcdb4288_2 69 | - libffi=3.2.1=hf484d3e_1007 70 | - libgcc-ng=9.1.0=hdf63c60_0 71 | - libgfortran-ng=7.3.0=hdf63c60_0 72 | - libiconv=1.16=h516909a_0 73 | - liblapack=3.8.0=16_mkl 74 | - liblapacke=3.8.0=16_mkl 75 | - libllvm10=10.0.1=hbcb73fb_5 76 | - libpng=1.6.37=hbc83047_0 77 | - libpq=12.3=h5513abc_0 78 | - libprotobuf=3.13.0=hd408876_0 79 | - libsodium=1.0.18=h7b6447c_0 80 
| - libstdcxx-ng=9.1.0=hdf63c60_0 81 | - libtiff=4.1.0=h2733197_1 82 | - libuuid=2.32.1=h14c3975_1000 83 | - libuv=1.40.0=h7b6447c_0 84 | - libwebp-base=1.1.0=h7b6447c_3 85 | - libxcb=1.14=h7b6447c_0 86 | - libxkbcommon=0.10.0=he1b5a44_0 87 | - libxml2=2.9.10=h68273f3_2 88 | - lz4-c=1.9.2=he6710b0_1 89 | - markdown=3.3.3=pyh9f0ad1d_0 90 | - matplotlib=3.3.2=0 91 | - matplotlib-base=3.3.2=py38h91b0d89_0 92 | - mkl=2020.2=256 93 | - mkl-service=2.3.0=py38he904b0f_0 94 | - mkl_fft=1.2.0=py38h23d657b_0 95 | - mkl_random=1.1.1=py38h0573a6f_0 96 | - multidict=4.7.5=py38h1e0a361_2 97 | - mysql-common=8.0.21=2 98 | - mysql-libs=8.0.21=hf3661c5_2 99 | - ncurses=6.2=he6710b0_1 100 | - nettle=3.4.1=hbb512f6_0 101 | - nibabel=3.1.1=py_0 102 | - ninja=1.10.1=py38hfd86e86_0 103 | - nspr=4.29=he1b5a44_0 104 | - nss=3.57=he751ad9_0 105 | - numpy-base=1.19.2=py38hfa32c7d_0 106 | - oauthlib=3.0.1=py_0 107 | - olefile=0.46=py_0 108 | - openh264=2.1.1=h8b12597_0 109 | - openssl=1.1.1i=h27cfd23_0 110 | - opt_einsum=3.1.0=py_0 111 | - packaging=20.4=py_0 112 | - pandas=1.1.1=py38he6710b0_0 113 | - pcre=8.44=he6710b0_0 114 | - pillow=7.2.0=py38hb39fc2d_0 115 | - pip=20.2.2=py38_0 116 | - pixman=0.38.0=h7b6447c_0 117 | - protobuf=3.13.0=py38he6710b0_1 118 | - psutil=5.7.2=py38h7b6447c_0 119 | - pyasn1=0.4.8=py_0 120 | - pyasn1-modules=0.2.7=py_0 121 | - pycparser=2.20=py_2 122 | - pydicom=2.0.0=pyh9f0ad1d_0 123 | - pyjwt=1.7.1=py_0 124 | - pyopenssl=19.1.0=py_1 125 | - pyparsing=2.4.7=py_0 126 | - pysocks=1.7.1=py38_0 127 | - python=3.8.2=hcf32534_0 128 | - python-dateutil=2.8.1=py_0 129 | - python_abi=3.8=1_cp38 130 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 131 | - pytz=2020.1=py_0 132 | - pyyaml=5.3.1=py38h7b6447c_1 133 | - pyzmq=19.0.2=py38he6710b0_1 134 | - qt=5.12.9=h1f2b2cb_0 135 | - readline=8.0=h7b6447c_0 136 | - requests=2.24.0=py_0 137 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 138 | - rsa=4.6=pyh9f0ad1d_0 139 | - scipy=1.5.2=py38h0b6359f_0 140 | - setuptools=49.6.0=py38_1 141 | - six=1.15.0=py_0 142 | - sqlite=3.33.0=h62c20be_0 143 | - tensorboard=2.3.0=py_0 144 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 145 | - tensorboardx=2.1=py_0 146 | - tensorflow=2.2.0=mkl_py38h6d3daf0_0 147 | - tensorflow-base=2.2.0=mkl_py38h5059a2d_0 148 | - tensorflow-estimator=2.2.0=pyh208ff02_0 149 | - termcolor=1.1.0=py38_1 150 | - tk=8.6.10=hbc83047_0 151 | - torchfile=0.1.0=py_0 152 | - tornado=6.0.4=py38h7b6447c_1 153 | - tqdm=4.50.0=pyh9f0ad1d_0 154 | - typing-extensions=3.7.4.3=0 155 | - typing_extensions=3.7.4.3=py_0 156 | - urllib3=1.25.10=py_0 157 | - visdom=0.1.8.9=0 158 | - websocket-client=0.57.0=py38_1 159 | - werkzeug=1.0.1=pyh9f0ad1d_0 160 | - wheel=0.35.1=py_0 161 | - wrapt=1.12.1=py38h7b6447c_1 162 | - x264=1!152.20180806=h7b6447c_0 163 | - xorg-kbproto=1.0.7=h14c3975_1002 164 | - xorg-libice=1.0.10=h516909a_0 165 | - xorg-libsm=1.2.3=h84519dc_1000 166 | - xorg-libx11=1.6.12=h516909a_0 167 | - xorg-libxext=1.3.4=h516909a_0 168 | - xorg-libxrender=0.9.10=h516909a_1002 169 | - xorg-renderproto=0.11.1=h14c3975_1002 170 | - xorg-xextproto=7.3.0=h14c3975_1002 171 | - xorg-xproto=7.0.31=h14c3975_1007 172 | - xz=5.2.5=h7b6447c_0 173 | - yacs=0.1.6=py_0 174 | - yaml=0.2.5=h7b6447c_0 175 | - yarl=1.6.2=py38h1e0a361_0 176 | - zeromq=4.3.2=he6710b0_3 177 | - zipp=3.4.0=py_0 178 | - zlib=1.2.11=h7b6447c_3 179 | - zstd=1.4.5=h9ceee32_0 180 | - pip: 181 | - decorator==4.4.2 182 | - easydict==1.7 183 | - fvcore==0.1.3.post20210311 184 | - imageio==2.9.0 185 | - iopath==0.1.4 186 | - munkres==1.1.4 187 | - 
networkx==2.5 188 | - numpy==1.16.2 189 | - opencv-python==4.4.0.44 190 | - portalocker==2.2.1 191 | - progress==1.5 192 | - pthflops==0.4.0 193 | - pywavelets==1.1.1 194 | - scikit-image==0.17.2 195 | - tabulate==0.8.9 196 | - thop==0.0.31-2005241907 197 | - tifffile==2020.10.1 198 | - timm==0.3.4 199 | - torchscan==0.1.1 200 | - torchstat==0.0.7 201 | - torchvision==0.8.2 202 | prefix: /home/cezheng/anaconda3/envs/pose2 203 | -------------------------------------------------------------------------------- /run_poseformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | from common.arguments import parse_args 11 | import torch 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import os 17 | import sys 18 | import errno 19 | import math 20 | 21 | from einops import rearrange, repeat 22 | from copy import deepcopy 23 | 24 | from common.camera import * 25 | import collections 26 | 27 | from common.model_poseformer import * 28 | 29 | from common.loss import * 30 | from common.generators import ChunkedGenerator, UnchunkedGenerator 31 | from time import time 32 | from common.utils import * 33 | 34 | 35 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 36 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 37 | os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 38 | # print(torch.cuda.device_count()) 39 | 40 | 41 | ################### 42 | args = parse_args() 43 | # print(args) 44 | 45 | try: 46 | # Create checkpoint directory if it does not exist 47 | os.makedirs(args.checkpoint) 48 | except OSError as e: 49 | if e.errno != errno.EEXIST: 50 | raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint) 51 | 52 | print('Loading dataset...') 53 | dataset_path = 'data/data_3d_' + args.dataset + '.npz' 54 | if args.dataset == 'h36m': 55 | from common.h36m_dataset import Human36mDataset 56 | dataset = Human36mDataset(dataset_path) 57 | elif args.dataset.startswith('humaneva'): 58 | from common.humaneva_dataset import HumanEvaDataset 59 | dataset = HumanEvaDataset(dataset_path) 60 | elif args.dataset.startswith('custom'): 61 | from common.custom_dataset import CustomDataset 62 | dataset = CustomDataset('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz') 63 | else: 64 | raise KeyError('Invalid dataset') 65 | 66 | print('Preparing data...') 67 | for subject in dataset.subjects(): 68 | for action in dataset[subject].keys(): 69 | anim = dataset[subject][action] 70 | 71 | if 'positions' in anim: 72 | positions_3d = [] 73 | for cam in anim['cameras']: 74 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 75 | pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position 76 | positions_3d.append(pos_3d) 77 | anim['positions_3d'] = positions_3d 78 | 79 | print('Loading 2D detections...') 80 | keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True) 81 | keypoints_metadata = keypoints['metadata'].item() 82 | keypoints_symmetry = keypoints_metadata['keypoints_symmetry'] 83 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1]) 84 | joints_left, joints_right = list(dataset.skeleton().joints_left()), 
list(dataset.skeleton().joints_right()) 85 | keypoints = keypoints['positions_2d'].item() 86 | 87 | ################### 88 | for subject in dataset.subjects(): 89 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject) 90 | for action in dataset[subject].keys(): 91 | assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject) 92 | if 'positions_3d' not in dataset[subject][action]: 93 | continue 94 | 95 | for cam_idx in range(len(keypoints[subject][action])): 96 | 97 | # We check for >= instead of == because some videos in H3.6M contain extra frames 98 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0] 99 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length 100 | 101 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length: 102 | # Shorten sequence 103 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length] 104 | 105 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d']) 106 | 107 | for subject in keypoints.keys(): 108 | for action in keypoints[subject]: 109 | for cam_idx, kps in enumerate(keypoints[subject][action]): 110 | # Normalize camera frame 111 | cam = dataset.cameras()[subject][cam_idx] 112 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h']) 113 | keypoints[subject][action][cam_idx] = kps 114 | 115 | subjects_train = args.subjects_train.split(',') 116 | subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(',') 117 | if not args.render: 118 | subjects_test = args.subjects_test.split(',') 119 | else: 120 | subjects_test = [args.viz_subject] 121 | 122 | 123 | def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True): 124 | out_poses_3d = [] 125 | out_poses_2d = [] 126 | out_camera_params = [] 127 | for subject in subjects: 128 | for action in keypoints[subject].keys(): 129 | if action_filter is not None: 130 | found = False 131 | for a in action_filter: 132 | if action.startswith(a): 133 | found = True 134 | break 135 | if not found: 136 | continue 137 | 138 | poses_2d = keypoints[subject][action] 139 | for i in range(len(poses_2d)): # Iterate across cameras 140 | out_poses_2d.append(poses_2d[i]) 141 | 142 | if subject in dataset.cameras(): 143 | cams = dataset.cameras()[subject] 144 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 145 | for cam in cams: 146 | if 'intrinsic' in cam: 147 | out_camera_params.append(cam['intrinsic']) 148 | 149 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]: 150 | poses_3d = dataset[subject][action]['positions_3d'] 151 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 152 | for i in range(len(poses_3d)): # Iterate across cameras 153 | out_poses_3d.append(poses_3d[i]) 154 | 155 | if len(out_camera_params) == 0: 156 | out_camera_params = None 157 | if len(out_poses_3d) == 0: 158 | out_poses_3d = None 159 | 160 | stride = args.downsample 161 | if subset < 1: 162 | for i in range(len(out_poses_2d)): 163 | n_frames = int(round(len(out_poses_2d[i])//stride * subset)*stride) 164 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i]))) 165 | out_poses_2d[i] = out_poses_2d[i][start:start+n_frames:stride] 166 | if out_poses_3d is not None: 167 | out_poses_3d[i] = out_poses_3d[i][start:start+n_frames:stride] 168 | elif stride > 1: 169 | # Downsample as requested 170 
| for i in range(len(out_poses_2d)): 171 | out_poses_2d[i] = out_poses_2d[i][::stride] 172 | if out_poses_3d is not None: 173 | out_poses_3d[i] = out_poses_3d[i][::stride] 174 | 175 | 176 | return out_camera_params, out_poses_3d, out_poses_2d 177 | 178 | action_filter = None if args.actions == '*' else args.actions.split(',') 179 | if action_filter is not None: 180 | print('Selected actions:', action_filter) 181 | 182 | cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter) 183 | 184 | 185 | receptive_field = args.number_of_frames 186 | print('INFO: Receptive field: {} frames'.format(receptive_field)) 187 | pad = (receptive_field -1) // 2 # Padding on each side 188 | min_loss = 100000 189 | width = cam['res_w'] 190 | height = cam['res_h'] 191 | num_joints = keypoints_metadata['num_joints'] 192 | 193 | #########################################PoseTransformer 194 | 195 | model_pos_train = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4, 196 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0.1) 197 | 198 | model_pos = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4, 199 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) 200 | 201 | ################ load weight ######################## 202 | # posetrans_checkpoint = torch.load('./checkpoint/pretrained_posetrans.bin', map_location=lambda storage, loc: storage) 203 | # posetrans_checkpoint = posetrans_checkpoint["model_pos"] 204 | # model_pos_train = load_pretrained_weights(model_pos_train, posetrans_checkpoint) 205 | 206 | ################# 207 | causal_shift = 0 208 | model_params = 0 209 | for parameter in model_pos.parameters(): 210 | model_params += parameter.numel() 211 | print('INFO: Trainable parameter count:', model_params) 212 | 213 | if torch.cuda.is_available(): 214 | model_pos = nn.DataParallel(model_pos) 215 | model_pos = model_pos.cuda() 216 | model_pos_train = nn.DataParallel(model_pos_train) 217 | model_pos_train = model_pos_train.cuda() 218 | 219 | 220 | if args.resume or args.evaluate: 221 | chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate) 222 | print('Loading checkpoint', chk_filename) 223 | checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) 224 | model_pos_train.load_state_dict(checkpoint['model_pos'], strict=False) 225 | model_pos.load_state_dict(checkpoint['model_pos'], strict=False) 226 | 227 | 228 | test_generator = UnchunkedGenerator(cameras_valid, poses_valid, poses_valid_2d, 229 | pad=pad, causal_shift=causal_shift, augment=False, 230 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 231 | print('INFO: Testing on {} frames'.format(test_generator.num_frames())) 232 | 233 | def eval_data_prepare(receptive_field, inputs_2d, inputs_3d): 234 | inputs_2d_p = torch.squeeze(inputs_2d) 235 | inputs_3d_p = inputs_3d.permute(1,0,2,3) 236 | out_num = inputs_2d_p.shape[0] - receptive_field + 1 237 | eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2]) 238 | for i in range(out_num): 239 | eval_input_2d[i,:,:,:] = inputs_2d_p[i:i+receptive_field, :, :] 240 | return eval_input_2d, inputs_3d_p 241 | 242 | 243 | ################### 244 | 245 | if not args.evaluate: 246 | cameras_train, poses_train, poses_train_2d = fetch(subjects_train, action_filter, subset=args.subset) 247 | 248 | lr = 
args.learning_rate 249 | optimizer = optim.AdamW(model_pos_train.parameters(), lr=lr, weight_decay=0.1) 250 | 251 | lr_decay = args.lr_decay 252 | losses_3d_train = [] 253 | losses_3d_train_eval = [] 254 | losses_3d_valid = [] 255 | 256 | epoch = 0 257 | initial_momentum = 0.1 258 | final_momentum = 0.001 259 | 260 | train_generator = ChunkedGenerator(args.batch_size//args.stride, cameras_train, poses_train, poses_train_2d, args.stride, 261 | pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation, 262 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 263 | train_generator_eval = UnchunkedGenerator(cameras_train, poses_train, poses_train_2d, 264 | pad=pad, causal_shift=causal_shift, augment=False) 265 | print('INFO: Training on {} frames'.format(train_generator_eval.num_frames())) 266 | 267 | 268 | if args.resume: 269 | epoch = checkpoint['epoch'] 270 | if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: 271 | optimizer.load_state_dict(checkpoint['optimizer']) 272 | train_generator.set_random_state(checkpoint['random_state']) 273 | else: 274 | print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') 275 | 276 | lr = checkpoint['lr'] 277 | 278 | 279 | print('** Note: reported losses are averaged over all frames.') 280 | print('** The final evaluation will be carried out after the last training epoch.') 281 | 282 | # Pos model only 283 | while epoch < args.epochs: 284 | start_time = time() 285 | epoch_loss_3d_train = 0 286 | epoch_loss_traj_train = 0 287 | epoch_loss_2d_train_unlabeled = 0 288 | N = 0 289 | N_semi = 0 290 | model_pos_train.train() 291 | 292 | for cameras_train, batch_3d, batch_2d in train_generator.next_epoch(): 293 | cameras_train = torch.from_numpy(cameras_train.astype('float32')) 294 | inputs_3d = torch.from_numpy(batch_3d.astype('float32')) 295 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 296 | 297 | if torch.cuda.is_available(): 298 | inputs_3d = inputs_3d.cuda() 299 | inputs_2d = inputs_2d.cuda() 300 | cameras_train = cameras_train.cuda() 301 | inputs_traj = inputs_3d[:, :, :1].clone() 302 | inputs_3d[:, :, 0] = 0 303 | 304 | optimizer.zero_grad() 305 | 306 | # Predict 3D poses 307 | predicted_3d_pos = model_pos_train(inputs_2d) 308 | 309 | del inputs_2d 310 | torch.cuda.empty_cache() 311 | 312 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 313 | epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 314 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 315 | 316 | loss_total = loss_3d_pos 317 | 318 | loss_total.backward() 319 | 320 | optimizer.step() 321 | del inputs_3d, loss_3d_pos, predicted_3d_pos 322 | torch.cuda.empty_cache() 323 | 324 | losses_3d_train.append(epoch_loss_3d_train / N) 325 | torch.cuda.empty_cache() 326 | 327 | # End-of-epoch evaluation 328 | with torch.no_grad(): 329 | model_pos.load_state_dict(model_pos_train.state_dict(), strict=False) 330 | model_pos.eval() 331 | 332 | epoch_loss_3d_valid = 0 333 | epoch_loss_traj_valid = 0 334 | epoch_loss_2d_valid = 0 335 | N = 0 336 | if not args.no_eval: 337 | # Evaluate on test set 338 | for cam, batch, batch_2d in test_generator.next_epoch(): 339 | inputs_3d = torch.from_numpy(batch.astype('float32')) 340 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 341 | 342 | ##### apply test-time-augmentation (following Videopose3d) 343 | inputs_2d_flip = inputs_2d.clone() 344 | inputs_2d_flip[:, :, :, 0] *= -1 345 | 
inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :] 346 | 347 | ##### convert size 348 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 349 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d) 350 | 351 | if torch.cuda.is_available(): 352 | inputs_2d = inputs_2d.cuda() 353 | inputs_2d_flip = inputs_2d_flip.cuda() 354 | inputs_3d = inputs_3d.cuda() 355 | inputs_3d[:, :, 0] = 0 356 | 357 | predicted_3d_pos = model_pos(inputs_2d) 358 | predicted_3d_pos_flip = model_pos(inputs_2d_flip) 359 | predicted_3d_pos_flip[:, :, :, 0] *= -1 360 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :, 361 | joints_right + joints_left] 362 | 363 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1, 364 | keepdim=True) 365 | 366 | del inputs_2d, inputs_2d_flip 367 | torch.cuda.empty_cache() 368 | 369 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 370 | epoch_loss_3d_valid += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 371 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 372 | 373 | del inputs_3d, loss_3d_pos, predicted_3d_pos 374 | torch.cuda.empty_cache() 375 | 376 | losses_3d_valid.append(epoch_loss_3d_valid / N) 377 | 378 | # Evaluate on training set, this time in evaluation mode 379 | epoch_loss_3d_train_eval = 0 380 | epoch_loss_traj_train_eval = 0 381 | epoch_loss_2d_train_labeled_eval = 0 382 | N = 0 383 | for cam, batch, batch_2d in train_generator_eval.next_epoch(): 384 | if batch_2d.shape[1] == 0: 385 | # This can only happen when downsampling the dataset 386 | continue 387 | 388 | inputs_3d = torch.from_numpy(batch.astype('float32')) 389 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 390 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 391 | 392 | if torch.cuda.is_available(): 393 | inputs_3d = inputs_3d.cuda() 394 | inputs_2d = inputs_2d.cuda() 395 | 396 | inputs_3d[:, :, 0] = 0 397 | 398 | # Compute 3D poses 399 | predicted_3d_pos = model_pos(inputs_2d) 400 | 401 | del inputs_2d 402 | torch.cuda.empty_cache() 403 | 404 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 405 | epoch_loss_3d_train_eval += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 406 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 407 | 408 | del inputs_3d, loss_3d_pos, predicted_3d_pos 409 | torch.cuda.empty_cache() 410 | 411 | losses_3d_train_eval.append(epoch_loss_3d_train_eval / N) 412 | 413 | # Evaluate 2D loss on unlabeled training set (in evaluation mode) 414 | epoch_loss_2d_train_unlabeled_eval = 0 415 | N_semi = 0 416 | 417 | elapsed = (time() - start_time) / 60 418 | 419 | if args.no_eval: 420 | print('[%d] time %.2f lr %f 3d_train %f' % ( 421 | epoch + 1, 422 | elapsed, 423 | lr, 424 | losses_3d_train[-1] * 1000)) 425 | else: 426 | 427 | print('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % ( 428 | epoch + 1, 429 | elapsed, 430 | lr, 431 | losses_3d_train[-1] * 1000, 432 | losses_3d_train_eval[-1] * 1000, 433 | losses_3d_valid[-1] * 1000)) 434 | 435 | # Decay learning rate exponentially 436 | lr *= lr_decay 437 | for param_group in optimizer.param_groups: 438 | param_group['lr'] *= lr_decay 439 | epoch += 1 440 | 441 | # Decay BatchNorm momentum 442 | # momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum)) 443 | # model_pos_train.set_bn_momentum(momentum) 444 | 445 | # Save checkpoint if necessary 446 | 
if epoch % args.checkpoint_frequency == 0: 447 | chk_path = os.path.join(args.checkpoint, 'epoch_{}.bin'.format(epoch)) 448 | print('Saving checkpoint to', chk_path) 449 | 450 | torch.save({ 451 | 'epoch': epoch, 452 | 'lr': lr, 453 | 'random_state': train_generator.random_state(), 454 | 'optimizer': optimizer.state_dict(), 455 | 'model_pos': model_pos_train.state_dict(), 456 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None, 457 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None, 458 | }, chk_path) 459 | 460 | #### save best checkpoint 461 | best_chk_path = os.path.join(args.checkpoint, 'best_epoch.bin'.format(epoch)) 462 | if losses_3d_valid[-1] * 1000 < min_loss: 463 | min_loss = losses_3d_valid[-1] * 1000 464 | print("save best checkpoint") 465 | torch.save({ 466 | 'epoch': epoch, 467 | 'lr': lr, 468 | 'random_state': train_generator.random_state(), 469 | 'optimizer': optimizer.state_dict(), 470 | 'model_pos': model_pos_train.state_dict(), 471 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None, 472 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None, 473 | }, best_chk_path) 474 | 475 | # Save training curves after every epoch, as .png images (if requested) 476 | if args.export_training_curves and epoch > 3: 477 | if 'matplotlib' not in sys.modules: 478 | import matplotlib 479 | 480 | matplotlib.use('Agg') 481 | import matplotlib.pyplot as plt 482 | 483 | plt.figure() 484 | epoch_x = np.arange(3, len(losses_3d_train)) + 1 485 | plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0') 486 | plt.plot(epoch_x, losses_3d_train_eval[3:], color='C0') 487 | plt.plot(epoch_x, losses_3d_valid[3:], color='C1') 488 | plt.legend(['3d train', '3d train (eval)', '3d valid (eval)']) 489 | plt.ylabel('MPJPE (m)') 490 | plt.xlabel('Epoch') 491 | plt.xlim((3, epoch)) 492 | plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png')) 493 | 494 | plt.close('all') 495 | 496 | 497 | # Evaluate 498 | def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False): 499 | epoch_loss_3d_pos = 0 500 | epoch_loss_3d_pos_procrustes = 0 501 | epoch_loss_3d_pos_scale = 0 502 | epoch_loss_3d_vel = 0 503 | with torch.no_grad(): 504 | if not use_trajectory_model: 505 | model_pos.eval() 506 | # else: 507 | # model_traj.eval() 508 | N = 0 509 | for _, batch, batch_2d in test_generator.next_epoch(): 510 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 511 | inputs_3d = torch.from_numpy(batch.astype('float32')) 512 | 513 | 514 | ##### apply test-time-augmentation (following Videopose3d) 515 | inputs_2d_flip = inputs_2d.clone() 516 | inputs_2d_flip [:, :, :, 0] *= -1 517 | inputs_2d_flip[:, :, kps_left + kps_right,:] = inputs_2d_flip[:, :, kps_right + kps_left,:] 518 | 519 | ##### convert size 520 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d) 521 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d) 522 | 523 | if torch.cuda.is_available(): 524 | inputs_2d = inputs_2d.cuda() 525 | inputs_2d_flip = inputs_2d_flip.cuda() 526 | inputs_3d = inputs_3d.cuda() 527 | inputs_3d[:, :, 0] = 0 528 | 529 | predicted_3d_pos = model_pos(inputs_2d) 530 | predicted_3d_pos_flip = model_pos(inputs_2d_flip) 531 | predicted_3d_pos_flip[:, :, :, 0] *= -1 532 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :, 533 | joints_right + joints_left] 534 | 535 | predicted_3d_pos = 
torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1, 536 | keepdim=True) 537 | 538 | del inputs_2d, inputs_2d_flip 539 | torch.cuda.empty_cache() 540 | 541 | if return_predictions: 542 | return predicted_3d_pos.squeeze(0).cpu().numpy() 543 | 544 | 545 | error = mpjpe(predicted_3d_pos, inputs_3d) 546 | epoch_loss_3d_pos_scale += inputs_3d.shape[0]*inputs_3d.shape[1] * n_mpjpe(predicted_3d_pos, inputs_3d).item() 547 | 548 | epoch_loss_3d_pos += inputs_3d.shape[0]*inputs_3d.shape[1] * error.item() 549 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 550 | 551 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 552 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 553 | 554 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0]*inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, inputs) 555 | 556 | # Compute velocity error 557 | epoch_loss_3d_vel += inputs_3d.shape[0]*inputs_3d.shape[1] * mean_velocity_error(predicted_3d_pos, inputs) 558 | 559 | if action is None: 560 | print('----------') 561 | else: 562 | print('----'+action+'----') 563 | e1 = (epoch_loss_3d_pos / N)*1000 564 | e2 = (epoch_loss_3d_pos_procrustes / N)*1000 565 | e3 = (epoch_loss_3d_pos_scale / N)*1000 566 | ev = (epoch_loss_3d_vel / N)*1000 567 | print('Protocol #1 Error (MPJPE):', e1, 'mm') 568 | print('Protocol #2 Error (P-MPJPE):', e2, 'mm') 569 | print('Protocol #3 Error (N-MPJPE):', e3, 'mm') 570 | print('Velocity Error (MPJVE):', ev, 'mm') 571 | print('----------') 572 | 573 | return e1, e2, e3, ev 574 | 575 | if args.render: 576 | print('Rendering...') 577 | 578 | input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy() 579 | ground_truth = None 580 | if args.viz_subject in dataset.subjects() and args.viz_action in dataset[args.viz_subject]: 581 | if 'positions_3d' in dataset[args.viz_subject][args.viz_action]: 582 | ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy() 583 | if ground_truth is None: 584 | print('INFO: this action is unlabeled. Ground truth will not be rendered.') 585 | 586 | gen = UnchunkedGenerator(None, [ground_truth], [input_keypoints], 587 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 588 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right) 589 | prediction = evaluate(gen, return_predictions=True) 590 | # if model_traj is not None and ground_truth is None: 591 | # prediction_traj = evaluate(gen, return_predictions=True, use_trajectory_model=True) 592 | # prediction += prediction_traj 593 | 594 | if args.viz_export is not None: 595 | print('Exporting joint positions to', args.viz_export) 596 | # Predictions are in camera space 597 | np.save(args.viz_export, prediction) 598 | 599 | if args.viz_output is not None: 600 | if ground_truth is not None: 601 | # Reapply trajectory 602 | trajectory = ground_truth[:, :1] 603 | ground_truth[:, 1:] += trajectory 604 | prediction += trajectory 605 | 606 | # Invert camera transformation 607 | cam = dataset.cameras()[args.viz_subject][args.viz_camera] 608 | if ground_truth is not None: 609 | prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation']) 610 | ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation']) 611 | else: 612 | # If the ground truth is not available, take the camera extrinsic params from a random subject. 
613 | # They are almost the same, and anyway, we only need this for visualization purposes. 614 | for subject in dataset.cameras(): 615 | if 'orientation' in dataset.cameras()[subject][args.viz_camera]: 616 | rot = dataset.cameras()[subject][args.viz_camera]['orientation'] 617 | break 618 | prediction = camera_to_world(prediction, R=rot, t=0) 619 | # We don't have the trajectory, but at least we can rebase the height 620 | prediction[:, :, 2] -= np.min(prediction[:, :, 2]) 621 | 622 | anim_output = {'Reconstruction': prediction} 623 | if ground_truth is not None and not args.viz_no_ground_truth: 624 | anim_output['Ground truth'] = ground_truth 625 | 626 | input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h']) 627 | 628 | from common.visualization import render_animation 629 | 630 | render_animation(input_keypoints, keypoints_metadata, anim_output, 631 | dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output, 632 | limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size, 633 | input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']), 634 | input_video_skip=args.viz_skip) 635 | 636 | else: 637 | print('Evaluating...') 638 | all_actions = {} 639 | all_actions_by_subject = {} 640 | for subject in subjects_test: 641 | if subject not in all_actions_by_subject: 642 | all_actions_by_subject[subject] = {} 643 | 644 | for action in dataset[subject].keys(): 645 | action_name = action.split(' ')[0] 646 | if action_name not in all_actions: 647 | all_actions[action_name] = [] 648 | if action_name not in all_actions_by_subject[subject]: 649 | all_actions_by_subject[subject][action_name] = [] 650 | all_actions[action_name].append((subject, action)) 651 | all_actions_by_subject[subject][action_name].append((subject, action)) 652 | 653 | 654 | def fetch_actions(actions): 655 | out_poses_3d = [] 656 | out_poses_2d = [] 657 | 658 | for subject, action in actions: 659 | poses_2d = keypoints[subject][action] 660 | for i in range(len(poses_2d)): # Iterate across cameras 661 | out_poses_2d.append(poses_2d[i]) 662 | 663 | poses_3d = dataset[subject][action]['positions_3d'] 664 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 665 | for i in range(len(poses_3d)): # Iterate across cameras 666 | out_poses_3d.append(poses_3d[i]) 667 | 668 | stride = args.downsample 669 | if stride > 1: 670 | # Downsample as requested 671 | for i in range(len(out_poses_2d)): 672 | out_poses_2d[i] = out_poses_2d[i][::stride] 673 | if out_poses_3d is not None: 674 | out_poses_3d[i] = out_poses_3d[i][::stride] 675 | 676 | return out_poses_3d, out_poses_2d 677 | 678 | 679 | def run_evaluation(actions, action_filter=None): 680 | errors_p1 = [] 681 | errors_p2 = [] 682 | errors_p3 = [] 683 | errors_vel = [] 684 | 685 | for action_key in actions.keys(): 686 | if action_filter is not None: 687 | found = False 688 | for a in action_filter: 689 | if action_key.startswith(a): 690 | found = True 691 | break 692 | if not found: 693 | continue 694 | 695 | poses_act, poses_2d_act = fetch_actions(actions[action_key]) 696 | gen = UnchunkedGenerator(None, poses_act, poses_2d_act, 697 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 698 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 699 | joints_right=joints_right) 700 | e1, e2, e3, ev = evaluate(gen, action_key) 701 | errors_p1.append(e1) 702 | errors_p2.append(e2) 703 | errors_p3.append(e3) 704 | errors_vel.append(ev) 705 | 706 | 
print('Protocol #1 (MPJPE) action-wise average:', round(np.mean(errors_p1), 1), 'mm') 707 | print('Protocol #2 (P-MPJPE) action-wise average:', round(np.mean(errors_p2), 1), 'mm') 708 | print('Protocol #3 (N-MPJPE) action-wise average:', round(np.mean(errors_p3), 1), 'mm') 709 | print('Velocity (MPJVE) action-wise average:', round(np.mean(errors_vel), 2), 'mm') 710 | 711 | 712 | if not args.by_subject: 713 | run_evaluation(all_actions, action_filter) 714 | else: 715 | for subject in all_actions_by_subject.keys(): 716 | print('Evaluating on subject', subject) 717 | run_evaluation(all_actions_by_subject[subject], action_filter) 718 | print('') --------------------------------------------------------------------------------
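As a quick end-to-end sanity check, the snippet below instantiates `PoseTransformer` with the same hyper-parameters used in `run_poseformer.py` and verifies the input/output shapes. This is a minimal sketch: it assumes the 17-joint Human3.6M layout and that `common/` is importable from the repository root.

```python
import torch
from common.model_poseformer import PoseTransformer

model = PoseTransformer(num_frame=81, num_joints=17, in_chans=2, embed_dim_ratio=32,
                        depth=4, num_heads=8, mlp_ratio=2., qkv_bias=True,
                        qk_scale=None, drop_path_rate=0.1)
model.eval()

x = torch.randn(2, 81, 17, 2)   # [batch, frames, joints, (u, v)] normalized 2D keypoints
with torch.no_grad():
    y = model(x)
print(y.shape)                  # expected: [2, 1, 17, 3] -> 3D pose of the center frame
```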