├── Figures
│   ├── GIF.gif
│   ├── results.png
│   └── structure.jpg
├── README.md
├── checkpoint
│   └── README.md
├── common
│   ├── arguments.py
│   ├── camera.py
│   ├── generators.py
│   ├── h36m_dataset.py
│   ├── loss.py
│   ├── mocap_dataset.py
│   ├── model.py
│   ├── quaternion.py
│   ├── ranger.py
│   ├── skeleton.py
│   ├── utils.py
│   └── visualization.py
├── data
│   └── README.md
└── run.py

/Figures/GIF.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/GIF.gif
--------------------------------------------------------------------------------
/Figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/results.png
--------------------------------------------------------------------------------
/Figures/structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/structure.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention Mechanism Exploits Temporal Contexts: Real-time 3D Human Pose Reconstruction (CVPR 2020 Oral)
2 | More extensive evaluation and code can be found at our lab website: https://sites.google.com/a/udayton.edu/jshen1/cvpr2020
3 | ![network](Figures/structure.jpg)
4 | 

5 | <!-- The demo GIF (Figures/GIF.gif) and result figure (Figures/results.png) are embedded here in the original README; the HTML markup was not preserved in this export. -->
6 | 
7 | 
8 | 
9 | 
10 | 

11 | 
12 | PyTorch code of the paper "Attention Mechanism Exploits Temporal Contexts: Real-time 3D Human Pose Reconstruction". [pdf](http://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_Attention_Mechanism_Exploits_Temporal_Contexts_Real-Time_3D_Human_Pose_Reconstruction_CVPR_2020_paper.pdf)
13 | 
14 | ### [Bibtex](https://scholar.googleusercontent.com/scholar.bib?q=info:sVZlnopW0ZQJ:scholar.google.com/&output=citation&scisdr=CgUvGH_mEIi98y29oOM:AAGBfm0AAAAAXu-4uOOunCSIKKuamAWN5VjFJ_OC0cHs&scisig=AAGBfm0AAAAAXu-4uBa5vr92Yk6AXlKVO0mVXEXZorOx&scisf=4&ct=citation&cd=-1&hl=en)
15 | 
16 | If you find this code useful, please cite the following paper:
17 | 
18 |     @inproceedings{liu2020attention,
19 |       title={Attention Mechanism Exploits Temporal Contexts: Real-Time 3D Human Pose Reconstruction},
20 |       author={Liu, Ruixu and Shen, Ju and Wang, He and Chen, Chen and Cheung, Sen-ching and Asari, Vijayan},
21 |       booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
22 |       pages={5064--5073},
23 |       year={2020}
24 |     }
25 | 
26 | ### Environment
27 | 
28 | The code was developed and tested in the following environment:
29 | 
30 | * Python 3.6
31 | * PyTorch 1.1 or higher
32 | * CUDA 10
33 | 
34 | ### Dataset
35 | 
36 | The source code is for training/evaluating on the [Human3.6M](http://vision.imar.ro/human3.6m) dataset. Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (`./data` directory). We provide the 2D CPN detections for training [here](https://drive.google.com/file/d/131EnG8L0-A9DNy9bfsqCSrG1n5GnzwkO/view?usp=sharing) and the 3D ground-truth data [here](https://drive.google.com/file/d/1nbscv_IlJ-sdug6GU2KWN4MYkPtYj4YX/view?usp=sharing). The 3D avatar model and code are available [here](https://drive.google.com/file/d/1RxhwFHCX4ydf1I1crLnQ_4NEXF84MkMY/view?usp=sharing).
37 | 
38 | 
39 | ### Training new models
40 | 
41 | To train a model from scratch, run:
42 | 
43 | ```bash
44 | python run.py -da -tta
45 | ```
46 | 
47 | `-da` enables data augmentation (horizontal flipping) during training and `-tta` enables test-time augmentation.
48 | 
49 | For example, to train the 243-frame ground-truth model or the causal model from the paper, run:
50 | 
51 | ```bash
52 | python run.py -k gt
53 | ```
54 | 
55 | or
56 | 
57 | ```bash
58 | python run.py -k cpn_ft_h36m_dbb --causal
59 | ```
60 | 
61 | Training takes about 24 hours on a single TITAN RTX GPU.
62 | 
63 | ### Evaluating pre-trained models
64 | 
65 | We provide the pre-trained CPN model [here](https://drive.google.com/file/d/1jiZWqAOJmXoTL8dxhPX8QgK0QeECeoAM/view?usp=sharing) and the ground-truth model [here](https://drive.google.com/file/d/1EAS9PUddznBPqNaEHV6-tCfqsQOHZ1Of/view?usp=sharing). To evaluate them, put them into the `./checkpoint` directory and run:
66 | 
67 | For the CPN model:
68 | ```bash
69 | python run.py -tta --evaluate cpn.bin
70 | ```
71 | 
72 | For the ground-truth model:
73 | ```bash
74 | python run.py -k gt -tta --evaluate gt.bin
75 | ```
76 | 
77 | ### Visualization and other functions
78 | 
79 | We keep our code consistent with [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). Please refer to their project page for further information.
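As a convenience, here is a sketch of a rendering command assembled from the `--viz-*` flags defined in `common/arguments.py`. The flag names are real, but this particular combination of subject, action and output settings is only an illustrative assumption; see the VideoPose3D documentation for the authoritative visualization workflow.

```bash
# Illustrative template only: all --viz-* values below are placeholders.
python run.py -tta --evaluate cpn.bin --render \
    --viz-subject S11 --viz-action Walking --viz-camera 0 \
    --viz-output walking.gif --viz-size 3 --viz-downsample 2 --viz-limit 60
```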
80 | 
81 | 
82 | 
--------------------------------------------------------------------------------
/checkpoint/README.md:
--------------------------------------------------------------------------------
1 | Put the pre-trained models here.
2 | 
--------------------------------------------------------------------------------
/common/arguments.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2018-present, Facebook, Inc.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | #
 7 | 
 8 | import argparse
 9 | 
10 | def parse_args():
11 |     parser = argparse.ArgumentParser(description='Training script')
12 | 
13 |     # General arguments
14 |     parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
15 |     parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
16 |     parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST', help='training subjects separated by comma')
17 |     parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
18 | 
19 |     parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
20 |                         help='actions to train/test on, separated by comma, or * for all')
21 |     parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
22 |                         help='checkpoint directory')
23 |     parser.add_argument('--checkpoint-frequency', default=10, type=int, metavar='N',
24 |                         help='create a checkpoint every N epochs')
25 |     parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
26 |                         help='checkpoint to resume (file name)')
27 |     parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
28 |     parser.add_argument('--render', action='store_true', help='visualize a particular video')
29 |     parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
30 |     parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
31 | 
32 |     # Model arguments
33 |     parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
34 |     parser.add_argument('-e', '--epochs', default=80, type=int, metavar='N', help='number of training epochs')
35 |     parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', help='batch size in terms of predicted frames')
36 |     parser.add_argument('-drop', '--dropout', default=0.2, type=float, metavar='P', help='dropout probability')
37 |     parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate')
38 |     parser.add_argument('-lrd', '--lr-decay', default=0.95, type=float, metavar='LR', help='learning rate decay per epoch')
39 |     parser.add_argument('-da', '--data-augmentation', dest='data_augmentation', action='store_true',
40 |                         help='enable train-time flipping (data augmentation)')
41 |     parser.add_argument('-tta', '--test-time-augmentation', dest='test_time_augmentation', action='store_true',
42 |                         help='enable test-time flipping (test-time augmentation)')
43 |     parser.add_argument('-arc', '--architecture', default='3,3,3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma')
44 | 
parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing') 45 | parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers') 46 | 47 | # Experimental 48 | parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction') 49 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)') 50 | parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)') 51 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions') 52 | parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions') 53 | parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection') 54 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term', 55 | help='disable bone length term in semi-supervised settings') 56 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting') 57 | 58 | # Visualization 59 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render') 60 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render') 61 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render') 62 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video') 63 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video') 64 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)') 65 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos') 66 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses') 67 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames') 68 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N') 69 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size') 70 | 71 | parser.set_defaults(bone_length_term=True) 72 | parser.set_defaults(data_augmentation=False) 73 | parser.set_defaults(test_time_augmentation=False) 74 | 75 | args = parser.parse_args() 76 | # Check invalid configuration 77 | if args.resume and args.evaluate: 78 | print('Invalid flags: --resume and --evaluate cannot be set at the same time') 79 | exit() 80 | 81 | if args.export_training_curves and args.no_eval: 82 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time') 83 | exit() 84 | 85 | return args 86 | -------------------------------------------------------------------------------- /common/camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
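# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the original sources. It
# shows how the options defined in common/arguments.py are typically consumed;
# the exact code in run.py may differ, so treat the details as assumptions.
#
#     from common.arguments import parse_args
#
#     args = parse_args()
#     filter_widths = [int(x) for x in args.architecture.split(',')]  # '3,3,3,3,3' -> [3, 3, 3, 3, 3]
#     # The receptive field is the product of the filter widths, i.e. 3**5 = 243
#     # frames for the default architecture, matching the 243-frame models in the
#     # README. args.causal, args.dropout and args.channels are passed on to the
#     # model constructor in common/model.py.
# ---------------------------------------------------------------------------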
6 | # 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from common.utils import wrap 12 | from common.quaternion import qrot, qinverse 13 | 14 | def normalize_screen_coordinates(X, w, h): 15 | assert X.shape[-1] == 2 16 | 17 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio 18 | return X/w*2 - [1, h/w] 19 | 20 | 21 | def image_coordinates(X, w, h): 22 | assert X.shape[-1] == 2 23 | 24 | # Reverse camera frame normalization 25 | return (X + [1, h/w])*w/2 26 | 27 | 28 | def world_to_camera(X, R, t): 29 | Rt = wrap(qinverse, R) # Invert rotation 30 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate 31 | 32 | 33 | def camera_to_world(X, R, t): 34 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t 35 | 36 | 37 | def project_to_2d(X, camera_params): 38 | """ 39 | Project 3D points to 2D using the Human3.6M camera projection function. 40 | This is a differentiable and batched reimplementation of the original MATLAB script. 41 | 42 | Arguments: 43 | X -- 3D points in *camera space* to transform (N, *, 3) 44 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 45 | """ 46 | assert X.shape[-1] == 3 47 | assert len(camera_params.shape) == 2 48 | assert camera_params.shape[-1] == 9 49 | assert X.shape[0] == camera_params.shape[0] 50 | 51 | while len(camera_params.shape) < len(X.shape): 52 | camera_params = camera_params.unsqueeze(1) 53 | 54 | f = camera_params[..., :2] 55 | c = camera_params[..., 2:4] 56 | k = camera_params[..., 4:7] 57 | p = camera_params[..., 7:] 58 | 59 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 60 | r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True) 61 | 62 | radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True) 63 | tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True) 64 | 65 | XXX = XX*(radial + tan) + p*r2 66 | 67 | return f*XXX + c 68 | 69 | def project_to_2d_linear(X, camera_params): 70 | """ 71 | Project 3D points to 2D using only linear parameters (focal length and principal point). 72 | 73 | Arguments: 74 | X -- 3D points in *camera space* to transform (N, *, 3) 75 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 76 | """ 77 | assert X.shape[-1] == 3 78 | assert len(camera_params.shape) == 2 79 | assert camera_params.shape[-1] == 9 80 | assert X.shape[0] == camera_params.shape[0] 81 | 82 | while len(camera_params.shape) < len(X.shape): 83 | camera_params = camera_params.unsqueeze(1) 84 | 85 | f = camera_params[..., :2] 86 | c = camera_params[..., 2:4] 87 | 88 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 89 | 90 | return f*XX + c -------------------------------------------------------------------------------- /common/generators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from itertools import zip_longest 9 | import numpy as np 10 | 11 | class ChunkedGenerator: 12 | """ 13 | Batched data generator, used for training. 14 | The sequences are split into equal-length chunks and padded as necessary. 
15 | 16 | Arguments: 17 | batch_size -- the batch size to use for training 18 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 19 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 20 | poses_2d -- list of input 2D keypoints, one element for each video 21 | chunk_length -- number of output frames to predict for each training example (usually 1) 22 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 23 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 24 | shuffle -- randomly shuffle the dataset before each epoch 25 | random_seed -- initial seed to use for the random generator 26 | augment -- augment the dataset by flipping poses horizontally 27 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 28 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 29 | """ 30 | def __init__(self, batch_size, cameras, poses_3d, poses_2d, 31 | chunk_length, pad=0, causal_shift=0, 32 | shuffle=True, random_seed=1234, 33 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None, 34 | endless=False, noisy=False): 35 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d)) 36 | assert cameras is None or len(cameras) == len(poses_2d) 37 | 38 | # Build lineage info 39 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples 40 | for i in range(len(poses_2d)): 41 | assert poses_3d is None or poses_3d[i].shape[0] == poses_3d[i].shape[0] 42 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length 43 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2 44 | bounds = np.arange(n_chunks+1)*chunk_length - offset 45 | augment_vector = np.full(len(bounds - 1), False, dtype=bool) 46 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], augment_vector) 47 | if augment: 48 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], ~augment_vector) 49 | 50 | # Initialize buffers 51 | if cameras is not None: 52 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1])) 53 | if poses_3d is not None: 54 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1])) 55 | self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 56 | 57 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size 58 | self.batch_size = batch_size 59 | self.random = np.random.RandomState(random_seed) 60 | self.pairs = pairs 61 | self.shuffle = shuffle 62 | self.pad = pad 63 | self.causal_shift = causal_shift 64 | self.endless = endless 65 | self.state = None 66 | 67 | self.cameras = cameras 68 | self.poses_3d = poses_3d 69 | self.poses_2d = poses_2d 70 | 71 | self.augment = augment 72 | self.noisy = noisy 73 | self.kps_left = kps_left 74 | self.kps_right = kps_right 75 | self.joints_left = joints_left 76 | self.joints_right = joints_right 77 | 78 | def num_frames(self): 79 | return self.num_batches * self.batch_size 80 | 81 | def random_state(self): 82 | return self.random 83 | 84 | def set_random_state(self, random): 85 | self.random = random 86 | 87 | def augment_enabled(self): 88 | return self.augment 89 | 90 | def next_pairs(self): 91 | if self.state is None: 92 | if self.shuffle: 93 | pairs = self.random.permutation(self.pairs) 94 | else: 95 | pairs = 
self.pairs 96 | return 0, pairs 97 | else: 98 | return self.state 99 | 100 | def next_epoch(self): 101 | enabled = True 102 | while enabled: 103 | start_idx, pairs = self.next_pairs() 104 | for b_i in range(start_idx, self.num_batches): 105 | chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size] 106 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks): 107 | start_2d = start_3d - self.pad - self.causal_shift 108 | end_2d = end_3d + self.pad - self.causal_shift 109 | 110 | # 2D poses 111 | seq_2d = self.poses_2d[seq_i] 112 | low_2d = max(start_2d, 0) 113 | high_2d = min(end_2d, seq_2d.shape[0]) 114 | pad_left_2d = low_2d - start_2d 115 | pad_right_2d = end_2d - high_2d 116 | if pad_left_2d != 0 or pad_right_2d != 0: 117 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge') 118 | else: 119 | self.batch_2d[i] = seq_2d[low_2d:high_2d] 120 | 121 | if flip: 122 | # Flip 2D keypoints 123 | # self.batch_2d = np.flip(self.batch_2d, 1) 124 | self.batch_2d[i, :, :, 0] *= -1 125 | self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left] 126 | 127 | # 3D poses 128 | if self.poses_3d is not None: 129 | seq_3d = self.poses_3d[seq_i] 130 | low_3d = max(start_3d, 0) 131 | high_3d = min(end_3d, seq_3d.shape[0]) 132 | pad_left_3d = low_3d - start_3d 133 | pad_right_3d = end_3d - high_3d 134 | if pad_left_3d != 0 or pad_right_3d != 0: 135 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge') 136 | else: 137 | self.batch_3d[i] = seq_3d[low_3d:high_3d] 138 | 139 | if flip: 140 | # Flip 3D joints 141 | self.batch_3d[i, :, :, 0] *= -1 142 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \ 143 | self.batch_3d[i, :, self.joints_right + self.joints_left] 144 | 145 | # Cameras 146 | if self.cameras is not None: 147 | self.batch_cam[i] = self.cameras[seq_i] 148 | if flip: 149 | # Flip horizontal distortion coefficients 150 | self.batch_cam[i, 2] *= -1 151 | self.batch_cam[i, 7] *= -1 152 | 153 | if self.endless: 154 | self.state = (b_i + 1, pairs) 155 | if self.poses_3d is None and self.cameras is None: 156 | yield None, None, self.batch_2d[:len(chunks)] 157 | elif self.poses_3d is not None and self.cameras is None: 158 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 159 | elif self.poses_3d is None: 160 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)] 161 | else: 162 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 163 | 164 | if self.endless: 165 | self.state = None 166 | else: 167 | enabled = False 168 | 169 | 170 | class Evaluate_Generator: 171 | """ 172 | Batched data generator, used for training. 173 | The sequences are split into equal-length chunks and padded as necessary. 
174 | 175 | Arguments: 176 | batch_size -- the batch size to use for training 177 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training) 178 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training) 179 | poses_2d -- list of input 2D keypoints, one element for each video 180 | chunk_length -- number of output frames to predict for each training example (usually 1) 181 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field) 182 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad") 183 | shuffle -- randomly shuffle the dataset before each epoch 184 | random_seed -- initial seed to use for the random generator 185 | augment -- augment the dataset by flipping poses horizontally 186 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled 187 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled 188 | """ 189 | 190 | def __init__(self, batch_size, cameras, poses_3d, poses_2d, 191 | chunk_length, pad=0, causal_shift=0, 192 | shuffle=True, random_seed=1234, 193 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None, 194 | endless=False): 195 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d)) 196 | assert cameras is None or len(cameras) == len(poses_2d) 197 | 198 | # Build lineage info 199 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples 200 | for i in range(len(poses_2d)): 201 | assert poses_3d is None or poses_3d[i].shape[0] == poses_3d[i].shape[0] 202 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length 203 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2 204 | bounds = np.arange(n_chunks + 1) * chunk_length - offset 205 | augment_vector = np.full(len(bounds - 1), False, dtype=bool) 206 | pairs += zip(np.repeat(i, len(bounds - 1)), bounds[:-1], bounds[1:], augment_vector) 207 | 208 | # Initialize buffers 209 | if cameras is not None: 210 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1])) 211 | if poses_3d is not None: 212 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1])) 213 | 214 | if augment: 215 | self.batch_2d_flip = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 216 | self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 217 | else: 218 | self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1])) 219 | 220 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size 221 | self.batch_size = batch_size 222 | self.random = np.random.RandomState(random_seed) 223 | self.pairs = pairs 224 | self.shuffle = shuffle 225 | self.pad = pad 226 | self.causal_shift = causal_shift 227 | self.endless = endless 228 | self.state = None 229 | 230 | self.cameras = cameras 231 | self.poses_3d = poses_3d 232 | self.poses_2d = poses_2d 233 | 234 | self.augment = augment 235 | self.kps_left = kps_left 236 | self.kps_right = kps_right 237 | self.joints_left = joints_left 238 | self.joints_right = joints_right 239 | 240 | def num_frames(self): 241 | return self.num_batches * self.batch_size 242 | 243 | def random_state(self): 244 | return self.random 245 | 246 | def set_random_state(self, random): 247 | self.random = random 248 | 249 
| def augment_enabled(self): 250 | return self.augment 251 | 252 | def next_pairs(self): 253 | if self.state is None: 254 | if self.shuffle: 255 | pairs = self.random.permutation(self.pairs) 256 | else: 257 | pairs = self.pairs 258 | return 0, pairs 259 | else: 260 | return self.state 261 | 262 | def next_epoch(self): 263 | enabled = True 264 | while enabled: 265 | start_idx, pairs = self.next_pairs() 266 | for b_i in range(start_idx, self.num_batches): 267 | chunks = pairs[b_i * self.batch_size: (b_i + 1) * self.batch_size] 268 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks): 269 | start_2d = start_3d - self.pad - self.causal_shift 270 | end_2d = end_3d + self.pad - self.causal_shift 271 | 272 | # 2D poses 273 | seq_2d = self.poses_2d[seq_i] 274 | low_2d = max(start_2d, 0) 275 | high_2d = min(end_2d, seq_2d.shape[0]) 276 | pad_left_2d = low_2d - start_2d 277 | pad_right_2d = end_2d - high_2d 278 | if pad_left_2d != 0 or pad_right_2d != 0: 279 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 280 | 'edge') 281 | if self.augment: 282 | self.batch_2d_flip[i] = np.pad(seq_2d[low_2d:high_2d], 283 | ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 284 | 'edge') 285 | 286 | else: 287 | self.batch_2d[i] = seq_2d[low_2d:high_2d] 288 | if self.augment: 289 | self.batch_2d_flip[i] = seq_2d[low_2d:high_2d] 290 | 291 | if self.augment: 292 | self.batch_2d_flip[i, :, :, 0] *= -1 293 | self.batch_2d_flip[i, :, self.kps_left + self.kps_right] = self.batch_2d_flip[i, :, self.kps_right + self.kps_left] 294 | 295 | # 3D poses 296 | if self.poses_3d is not None: 297 | seq_3d = self.poses_3d[seq_i] 298 | low_3d = max(start_3d, 0) 299 | high_3d = min(end_3d, seq_3d.shape[0]) 300 | pad_left_3d = low_3d - start_3d 301 | pad_right_3d = end_3d - high_3d 302 | if pad_left_3d != 0 or pad_right_3d != 0: 303 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], 304 | ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge') 305 | else: 306 | self.batch_3d[i] = seq_3d[low_3d:high_3d] 307 | 308 | if flip: 309 | self.batch_3d[i, :, :, 0] *= -1 310 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \ 311 | self.batch_3d[i, :, self.joints_right + self.joints_left] 312 | 313 | # Cameras 314 | if self.cameras is not None: 315 | self.batch_cam[i] = self.cameras[seq_i] 316 | if flip: 317 | # Flip horizontal distortion coefficients 318 | self.batch_cam[i, 2] *= -1 319 | self.batch_cam[i, 7] *= -1 320 | 321 | if self.endless: 322 | self.state = (b_i + 1, pairs) 323 | 324 | if self.augment: 325 | if self.poses_3d is None and self.cameras is None: 326 | yield None, None, self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)] 327 | elif self.poses_3d is not None and self.cameras is None: 328 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)] 329 | elif self.poses_3d is None: 330 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)] 331 | else: 332 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)] 333 | else: 334 | if self.poses_3d is None and self.cameras is None: 335 | yield None, None, self.batch_2d[:len(chunks)] 336 | elif self.poses_3d is not None and self.cameras is None: 337 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 338 | elif self.poses_3d is None: 339 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)] 340 | else: 341 | 
yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)] 342 | 343 | if self.endless: 344 | self.state = None 345 | else: 346 | enabled = False 347 | 348 | 349 | -------------------------------------------------------------------------------- /common/h36m_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.skeleton import Skeleton 11 | from common.mocap_dataset import MocapDataset 12 | from common.camera import normalize_screen_coordinates, image_coordinates 13 | 14 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 15 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 16 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 17 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 18 | 19 | h36m_cameras_intrinsic_params = [ 20 | { 21 | 'id': '54138969', 22 | 'center': [512.54150390625, 515.4514770507812], 23 | 'focal_length': [1145.0494384765625, 1143.7811279296875], 24 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], 25 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], 26 | 'res_w': 1000, 27 | 'res_h': 1002, 28 | 'azimuth': 70, # Only used for visualization 29 | }, 30 | { 31 | 'id': '55011271', 32 | 'center': [508.8486328125, 508.0649108886719], 33 | 'focal_length': [1149.6756591796875, 1147.5916748046875], 34 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], 35 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], 36 | 'res_w': 1000, 37 | 'res_h': 1000, 38 | 'azimuth': -70, # Only used for visualization 39 | }, 40 | { 41 | 'id': '58860488', 42 | 'center': [519.8158569335938, 501.40264892578125], 43 | 'focal_length': [1149.1407470703125, 1148.7989501953125], 44 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], 45 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], 46 | 'res_w': 1000, 47 | 'res_h': 1000, 48 | 'azimuth': 110, # Only used for visualization 49 | }, 50 | { 51 | 'id': '60457274', 52 | 'center': [514.9682006835938, 501.88201904296875], 53 | 'focal_length': [1145.5113525390625, 1144.77392578125], 54 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], 55 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], 56 | 'res_w': 1000, 57 | 'res_h': 1002, 58 | 'azimuth': -110, # Only used for visualization 59 | }, 60 | ] 61 | 62 | h36m_cameras_extrinsic_params = { 63 | 'S1': [ 64 | { 65 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 66 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 67 | }, 68 | { 69 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], 70 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], 71 | }, 72 | { 73 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], 74 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], 75 | }, 76 | { 77 | 'orientation': [0.5834008455276489, 
-0.7853162288665771, 0.14548823237419128, -0.14749594032764435], 78 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], 79 | }, 80 | ], 81 | 'S2': [ 82 | {}, 83 | {}, 84 | {}, 85 | {}, 86 | ], 87 | 'S3': [ 88 | {}, 89 | {}, 90 | {}, 91 | {}, 92 | ], 93 | 'S4': [ 94 | {}, 95 | {}, 96 | {}, 97 | {}, 98 | ], 99 | 'S5': [ 100 | { 101 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], 102 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], 103 | }, 104 | { 105 | 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915], 106 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], 107 | }, 108 | { 109 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576], 110 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], 111 | }, 112 | { 113 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], 114 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], 115 | }, 116 | ], 117 | 'S6': [ 118 | { 119 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], 120 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], 121 | }, 122 | { 123 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], 124 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], 125 | }, 126 | { 127 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], 128 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], 129 | }, 130 | { 131 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], 132 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], 133 | }, 134 | ], 135 | 'S7': [ 136 | { 137 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], 138 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], 139 | }, 140 | { 141 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], 142 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], 143 | }, 144 | { 145 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278], 146 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], 147 | }, 148 | { 149 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], 150 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], 151 | }, 152 | ], 153 | 'S8': [ 154 | { 155 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], 156 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], 157 | }, 158 | { 159 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], 160 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], 161 | }, 162 | { 163 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], 164 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], 165 | }, 166 | { 167 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, 
-0.14631995558738708], 168 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], 169 | }, 170 | ], 171 | 'S9': [ 172 | { 173 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], 174 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], 175 | }, 176 | { 177 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], 178 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], 179 | }, 180 | { 181 | 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915], 182 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], 183 | }, 184 | { 185 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], 186 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], 187 | }, 188 | ], 189 | 'S11': [ 190 | { 191 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], 192 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], 193 | }, 194 | { 195 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], 196 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], 197 | }, 198 | { 199 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], 200 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], 201 | }, 202 | { 203 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], 204 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], 205 | }, 206 | ], 207 | } 208 | 209 | class Human36mDataset(MocapDataset): 210 | def __init__(self, path, remove_static_joints=True): 211 | super().__init__(fps=50, skeleton=h36m_skeleton) 212 | 213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) 214 | for cameras in self._cameras.values(): 215 | for i, cam in enumerate(cameras): 216 | cam.update(h36m_cameras_intrinsic_params[i]) 217 | for k, v in cam.items(): 218 | if k not in ['id', 'res_w', 'res_h']: 219 | cam[k] = np.array(v, dtype='float32') 220 | 221 | # Normalize camera frame 222 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32') 223 | cam['focal_length'] = cam['focal_length']/cam['res_w']*2 224 | if 'translation' in cam: 225 | cam['translation'] = cam['translation']/1000 # mm to meters 226 | 227 | # Add intrinsic parameters vector 228 | cam['intrinsic'] = np.concatenate((cam['focal_length'], 229 | cam['center'], 230 | cam['radial_distortion'], 231 | cam['tangential_distortion'])) 232 | 233 | # Load serialized dataset 234 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 235 | 236 | self._data = {} 237 | for subject, actions in data.items(): 238 | self._data[subject] = {} 239 | for action_name, positions in actions.items(): 240 | self._data[subject][action_name] = { 241 | 'positions': positions, 242 | 'cameras': self._cameras[subject], 243 | } 244 | 245 | if remove_static_joints: 246 | # Bring the skeleton to 17 joints instead of the original 32 247 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 248 | 249 | # Rewire shoulders to the correct parents 250 | self._skeleton._parents[11] = 8 251 | self._skeleton._parents[14] = 8 252 | 253 | def supports_semi_supervised(self): 254 | return True 255 | 
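# ---------------------------------------------------------------------------
# [Editor's note] Illustrative usage sketch, not part of the original file.
# The .npz path follows the VideoPose3D data setup referenced in the README and
# is therefore an assumption:
#
#     from common.h36m_dataset import Human36mDataset
#
#     dataset = Human36mDataset('data/data_3d_h36m.npz')
#     for subject in dataset.subjects():
#         for action, anim in dataset[subject].items():
#             positions = anim['positions']   # (n_frames, 17, 3) after static joints are removed
#             cameras = anim['cameras']       # four calibrated cameras per subject
# ---------------------------------------------------------------------------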
256 | def define_actions(self): 257 | all_actions = ["Directions", 258 | "Discussion", 259 | "Eating", 260 | "Greeting", 261 | "Phoning", 262 | "Photo", 263 | "Posing", 264 | "Purchases", 265 | "Sitting", 266 | "SittingDown", 267 | "Smoking", 268 | "Waiting", 269 | "WalkDog", 270 | "Walking", 271 | "WalkTogether"] 272 | 273 | return all_actions 274 | -------------------------------------------------------------------------------- /common/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | def mpjpe(predicted, target): 12 | """ 13 | Mean per-joint position error (i.e. mean Euclidean distance), 14 | often referred to as "Protocol #1" in many papers. 15 | """ 16 | assert predicted.shape == target.shape 17 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1)) 18 | 19 | 20 | def p_mpjpe(predicted, target): 21 | """ 22 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation), 23 | often referred to as "Protocol #2" in many papers. 24 | """ 25 | assert predicted.shape == target.shape 26 | 27 | muX = np.mean(target, axis=1, keepdims=True) 28 | muY = np.mean(predicted, axis=1, keepdims=True) 29 | 30 | X0 = target - muX 31 | Y0 = predicted - muY 32 | 33 | normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True)) 34 | normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True)) 35 | 36 | X0 /= normX 37 | Y0 /= normY 38 | 39 | H = np.matmul(X0.transpose(0, 2, 1), Y0) 40 | U, s, Vt = np.linalg.svd(H) 41 | V = Vt.transpose(0, 2, 1) 42 | R = np.matmul(V, U.transpose(0, 2, 1)) 43 | 44 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 45 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) 46 | V[:, :, -1] *= sign_detR 47 | s[:, -1] *= sign_detR.flatten() 48 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 49 | 50 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) 51 | 52 | a = tr * normX / normY # Scale 53 | t = muX - a*np.matmul(muY, R) # Translation 54 | 55 | # Perform rigid transformation on the input 56 | predicted_aligned = a*np.matmul(predicted, R) + t 57 | 58 | # Return MPJPE 59 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1)) 60 | 61 | -------------------------------------------------------------------------------- /common/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
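# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch for the metrics defined in common/loss.py
# above; not part of the original file. mpjpe operates on torch tensors, while
# p_mpjpe works on NumPy arrays of shape (batch, joints, 3):
#
#     import torch
#     from common.loss import mpjpe, p_mpjpe
#
#     pred = torch.randn(8, 17, 3)
#     gt = torch.randn(8, 17, 3)
#     protocol1 = mpjpe(pred, gt)                    # mean Euclidean joint error
#     protocol2 = p_mpjpe(pred.numpy(), gt.numpy())  # error after rigid alignment
# ---------------------------------------------------------------------------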
6 | # 7 | 8 | import numpy as np 9 | from common.skeleton import Skeleton 10 | 11 | class MocapDataset: 12 | def __init__(self, fps, skeleton): 13 | self._skeleton = skeleton 14 | self._fps = fps 15 | self._data = None # Must be filled by subclass 16 | self._cameras = None # Must be filled by subclass 17 | 18 | def remove_joints(self, joints_to_remove): 19 | kept_joints = self._skeleton.remove_joints(joints_to_remove) 20 | for subject in self._data.keys(): 21 | for action in self._data[subject].keys(): 22 | s = self._data[subject][action] 23 | s['positions'] = s['positions'][:, kept_joints] 24 | 25 | 26 | def __getitem__(self, key): 27 | return self._data[key] 28 | 29 | def subjects(self): 30 | return self._data.keys() 31 | 32 | def fps(self): 33 | return self._fps 34 | 35 | def skeleton(self): 36 | return self._skeleton 37 | 38 | def cameras(self): 39 | return self._cameras 40 | 41 | def supports_semi_supervised(self): 42 | # This method can be overridden 43 | return False -------------------------------------------------------------------------------- /common/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class DWConv(nn.Module): 7 | def __init__(self, in_features, out_features, kernel_size=3, stride=3): 8 | super(DWConv, self).__init__() 9 | self.DW_conv = nn.Conv1d(in_features, in_features, kernel_size=kernel_size, stride=stride, 10 | groups=in_features, bias=False) 11 | self.DW_bn = nn.BatchNorm1d(in_features, momentum=0.1) 12 | self.PW_conv = nn.Conv1d(in_features, out_features, kernel_size=1, stride=1, bias=False) 13 | self.PW_bn = nn.BatchNorm1d(out_features, momentum=0.1) 14 | 15 | def forward(self, x): 16 | x = self.DW_conv(x) 17 | x = self.DW_bn(x) 18 | x = self.PW_conv(x) 19 | x = self.PW_bn(x) 20 | return x 21 | 22 | 23 | class Kernel_Attention(nn.Module): 24 | def __init__(self, in_features, out_features=1024, M=3, G=8, r=128, stride=3): 25 | super(Kernel_Attention, self).__init__() 26 | self.convs = nn.ModuleList([]) 27 | 28 | for i in range(M): 29 | self.convs.append(nn.Sequential( 30 | nn.Conv1d(in_features, in_features, kernel_size=3, dilation=i + 1, stride=stride, padding=0, 31 | groups=in_features, bias=False), 32 | nn.BatchNorm1d(in_features), 33 | nn.Conv1d(in_features, out_features, kernel_size=1, stride=1, padding=0, groups=G, bias=False), 34 | nn.BatchNorm1d(out_features), 35 | Mish() 36 | )) 37 | self.fc = nn.Linear(out_features, r) 38 | 39 | self.fcs = nn.ModuleList([]) 40 | for i in range(M): 41 | self.fcs.append( 42 | nn.Linear(r, out_features) 43 | ) 44 | self.softmax = nn.Softmax(dim=1) 45 | 46 | def forward(self, x): 47 | for i, conv in enumerate(self.convs): 48 | if i == 0: 49 | fea = conv(x).unsqueeze_(dim=1) 50 | feas = fea 51 | else: 52 | fea = F.pad(x, (i, i), 'replicate') 53 | fea = conv(fea).unsqueeze_(dim=1) 54 | feas = torch.cat([feas, fea], dim=1) 55 | fea_U = torch.sum(feas, dim=1) 56 | fea_s = fea_U.mean(-1) 57 | fea_z = self.fc(fea_s) 58 | for i, fc in enumerate(self.fcs): 59 | vector = fc(fea_z).unsqueeze_(dim=1) 60 | if i == 0: 61 | attention_vectors = vector 62 | else: 63 | attention_vectors = torch.cat([attention_vectors, vector], dim=1) 64 | attention_vectors = self.softmax(attention_vectors) 65 | attention_vectors = attention_vectors.unsqueeze(-1) 66 | fea_v = (feas * attention_vectors).sum(dim=1) 67 | return fea_v 68 | 69 | class Mish(nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | 73 | 
def forward(self, x): 74 | x = x * (torch.tanh(F.softplus(x))) 75 | return x 76 | 77 | 78 | class TemporalModelBase(nn.Module): 79 | """ 80 | Do not instantiate this class. 81 | """ 82 | 83 | def __init__(self, num_joints_in, in_features, num_joints_out, 84 | filter_widths, causal, dropout, channels): 85 | super().__init__() 86 | 87 | # Validate input 88 | for fw in filter_widths: 89 | assert fw % 2 != 0, 'Only odd filter widths are supported' 90 | 91 | self.num_joints_in = num_joints_in 92 | self.in_features = in_features 93 | self.num_joints_out = num_joints_out 94 | self.filter_widths = filter_widths 95 | 96 | self.drop = nn.Dropout(dropout) 97 | self.relu = Mish() 98 | self.sigmoid = nn.Sigmoid() 99 | 100 | self.pad = [filter_widths[0] // 2] 101 | self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1) 102 | 103 | def set_bn_momentum(self, momentum): 104 | for bn in self.layers_bn: 105 | bn.momentum = momentum 106 | for bn in self.layers_tem_bn: 107 | bn.momentum = momentum 108 | 109 | def receptive_field(self): 110 | """ 111 | Return the total receptive field of this model as # of frames. 112 | """ 113 | frames = 0 114 | for f in self.pad: 115 | frames += f 116 | return 1 + 2 * frames 117 | 118 | def total_causal_shift(self): 119 | """ 120 | Return the asymmetric offset for sequence padding. 121 | The returned value is typically 0 if causal convolutions are disabled, 122 | otherwise it is half the receptive field. 123 | """ 124 | frames = self.causal_shift[0] 125 | next_dilation = self.filter_widths[0] 126 | for i in range(1, len(self.filter_widths)): 127 | frames += self.causal_shift[i] * next_dilation 128 | next_dilation *= self.filter_widths[i] 129 | return frames 130 | 131 | def forward(self, x): 132 | assert len(x.shape) == 4 133 | assert x.shape[-2] == self.num_joints_in 134 | assert x.shape[-1] == self.in_features 135 | 136 | sz = x.shape[:3] 137 | mean = x[:, :, 0:1, :].expand_as(x) 138 | input_pose_centered = x - mean 139 | 140 | x = x.view(x.shape[0], x.shape[1], -1) 141 | x = x.permute(0, 2, 1) 142 | 143 | input_pose_centered = input_pose_centered.view(input_pose_centered.shape[0], input_pose_centered.shape[1], -1) 144 | input_pose_centered = input_pose_centered.permute(0, 2, 1) 145 | 146 | x = self._forward_blocks(x, input_pose_centered) 147 | 148 | x = x.permute(0, 2, 1) 149 | x = x.view(sz[0], -1, self.num_joints_out, 3) 150 | 151 | return x 152 | 153 | 154 | class TemporalModelOptimized1f(TemporalModelBase): 155 | """ 156 | 3D pose estimation model optimized for single-frame batching, i.e. 157 | where batches have input length = receptive field, and output length = 1. 158 | This scenario is only used for training when stride == 1. 159 | 160 | This implementation replaces dilated convolutions with strided convolutions 161 | to avoid generating unused intermediate results. The weights are interchangeable 162 | with the reference implementation. 163 | """ 164 | 165 | def __init__(self, num_joints_in, in_features, num_joints_out, 166 | filter_widths, causal=False, dropout=0.2, channels=1024, dense=False): 167 | """ 168 | Initialize this model. 169 | 170 | Arguments: 171 | num_joints_in -- number of input joints (e.g. 
17 for Human3.6M) 172 | in_features -- number of input features for each joint (typically 2 for 2D input) 173 | num_joints_out -- number of output joints (can be different than input) 174 | filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field 175 | causal -- use causal convolutions instead of symmetric convolutions (for real-time applications) 176 | dropout -- dropout probability 177 | channels -- number of convolution channels 178 | """ 179 | super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels) 180 | 181 | expand_conv = [] 182 | for i in range(len(filter_widths) - 1): 183 | expand_conv.append(DWConv(num_joints_in * in_features, channels, 184 | kernel_size=filter_widths[0], stride=filter_widths[0])) 185 | self.expand_conv = nn.ModuleList(expand_conv) 186 | 187 | self.cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6) 188 | layers_tem_att = [] 189 | layers_tem_bn = [] 190 | self.frames = self.total_frame() 191 | 192 | layers_conv = [] 193 | layers_bn = [] 194 | 195 | self.causal_shift = [(filter_widths[0] // 2) if causal else 0] 196 | next_dilation = filter_widths[0] 197 | 198 | dilation_conv = [] 199 | dilation_bn = [] 200 | 201 | for i in range(3): 202 | dilation_conv.append(DWConv(channels, channels, kernel_size=filter_widths[i], stride=filter_widths[i])) 203 | dilation_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False)) 204 | dilation_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 205 | 206 | self.dilation_conv = nn.ModuleList(dilation_conv) 207 | self.dilation_bn = nn.ModuleList(dilation_bn) 208 | 209 | for i in range(1, len(filter_widths)): 210 | self.pad.append((filter_widths[i] - 1) * next_dilation // 2) 211 | self.causal_shift.append((filter_widths[i] // 2) if causal else 0) 212 | 213 | layers_tem_att.append(nn.Linear(self.frames, self.frames // next_dilation)) 214 | layers_tem_bn.append(nn.BatchNorm1d(self.frames // next_dilation)) 215 | 216 | layers_conv.append(Kernel_Attention(channels, out_features=channels)) 217 | layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False)) 218 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 219 | 220 | next_dilation *= filter_widths[i] 221 | 222 | self.layers_conv = nn.ModuleList(layers_conv) 223 | self.layers_bn = nn.ModuleList(layers_bn) 224 | self.layers_tem_att = nn.ModuleList(layers_tem_att) 225 | self.layers_tem_bn = nn.ModuleList(layers_tem_bn) 226 | 227 | def set_KA_bn(self, momentum): 228 | for i in range(len(self.layers_conv) // 2): 229 | for j in range(3): 230 | self.layers_conv[2 * i].convs[j][1].momentum = momentum 231 | self.layers_conv[2 * i].convs[j][3].momentum = momentum 232 | 233 | def set_expand_bn(self, momentum): 234 | for i in range(len(self.expand_conv)): 235 | self.expand_conv[i].DW_bn.momentum = momentum 236 | self.expand_conv[i].PW_bn.momentum = momentum 237 | 238 | def set_dilation_bn(self, momentum): 239 | for bn in self.dilation_bn: 240 | bn.momentum = momentum 241 | for i in range(len(self.dilation_conv)//2): 242 | self.dilation_conv[2*i].DW_bn.momentum = momentum 243 | self.dilation_conv[2*i].PW_bn.momentum = momentum 244 | 245 | def total_frame(self): 246 | frames = 1 247 | for i in range(len(self.filter_widths)): 248 | frames *= self.filter_widths[i] 249 | return frames 250 | 251 | def _forward_blocks(self, x, input_2D_centered): 252 | b, c, t = input_2D_centered.size() 253 | x_target = input_2D_centered[:, :, input_2D_centered.shape[2] // 2] 254 | x_target_extend = 
x_target.view(b, c, 1) 255 | x_traget_matrix = x_target_extend.expand_as(input_2D_centered) 256 | cos_score = self.cos_dis(x_traget_matrix, input_2D_centered) 257 | 258 | ''' 259 | Top layers 260 | ''' 261 | x_0_1 = x[:, :, 1::3] 262 | x_0_2 = x[:, :, 4::9] 263 | x_0_3 = x[:, :, 13::27] 264 | 265 | x = self.drop(self.relu(self.expand_conv[0](x))) 266 | x_0_1 = self.drop(self.relu(self.expand_conv[1](x_0_1))) 267 | x_0_2 = self.drop(self.relu(self.expand_conv[2](x_0_2))) 268 | x_0_3 = self.drop(self.relu(self.expand_conv[3](x_0_3))) 269 | 270 | for i in range(len(self.pad) - 1): 271 | res = x[:, :, self.causal_shift[i + 1] + self.filter_widths[i + 1] // 2:: self.filter_widths[i + 1]] 272 | t_attention = self.sigmoid(self.layers_tem_bn[i](self.layers_tem_att[i](cos_score))) # [batches frames] 273 | t_attention_expand = t_attention.unsqueeze(1) # [batches channels frames] 274 | if i == 0: 275 | res_1_1 = res[:, :, 1::3] 276 | res_1_2 = res[:, :, 4::9] 277 | x = x * t_attention_expand # broadcasting dot mul 278 | x_1_1 = x[:, :, 1::3] 279 | x_1_2 = x[:, :, 4::9] 280 | 281 | x = self.drop(self.layers_conv[2 * i](x)) 282 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x)))) 283 | 284 | x_1_1 = self.drop(self.relu(self.dilation_conv[0](x_1_1))) 285 | x_1_1 = res_1_1 + self.drop(self.relu(self.dilation_bn[0](self.dilation_conv[1](x_1_1)))) 286 | 287 | x_1_2 = self.drop(self.relu(self.dilation_conv[2](x_1_2))) 288 | x_1_2 = res_1_2 + self.drop(self.relu(self.dilation_bn[1](self.dilation_conv[3](x_1_2)))) 289 | 290 | elif i == 1: 291 | res_2_1 = res[:, :, 1::3] 292 | x = x * t_attention_expand # broadcasting dot mul 293 | x_2_1 = x[:, :, 1::3] 294 | x_0_1 = x_0_1 * t_attention_expand # broadcasting dot mul 295 | x = x + x_0_1 296 | 297 | x = self.drop(self.layers_conv[2 * i](x)) 298 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x)))) 299 | 300 | x_2_1 = self.drop(self.relu(self.dilation_conv[4](x_2_1))) 301 | x_2_1 = res_2_1 + self.drop(self.relu(self.dilation_bn[2](self.dilation_conv[5](x_2_1)))) 302 | 303 | elif i == 2: 304 | x = x + x_0_2 + x_1_1 305 | x = x * t_attention_expand # broadcasting dot mul 306 | x = self.drop(self.layers_conv[2 * i](x)) 307 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x)))) 308 | elif i == 3: 309 | x = x + x_0_3 + x_1_2 + x_2_1 310 | x = x * t_attention_expand # broadcasting dot mul 311 | x = self.drop(self.layers_conv[2 * i](x)) 312 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x)))) 313 | 314 | x = self.shrink(x) 315 | return x 316 | 317 | -------------------------------------------------------------------------------- /common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | def qrot(q, v): 11 | """ 12 | Rotate vector(s) v about the rotation described by quaternion(s) q. 13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 14 | where * denotes any number of dimensions. 15 | Returns a tensor of shape (*, 3). 
16 | """ 17 | assert q.shape[-1] == 4 18 | assert v.shape[-1] == 3 19 | assert q.shape[:-1] == v.shape[:-1] 20 | 21 | qvec = q[..., 1:] 22 | uv = torch.cross(qvec, v, dim=len(q.shape)-1) 23 | uuv = torch.cross(qvec, uv, dim=len(q.shape)-1) 24 | return (v + 2 * (q[..., :1] * uv + uuv)) 25 | 26 | 27 | def qinverse(q, inplace=False): 28 | # We assume the quaternion to be normalized 29 | if inplace: 30 | q[..., 1:] *= -1 31 | return q 32 | else: 33 | w = q[..., :1] 34 | xyz = q[..., 1:] 35 | return torch.cat((w, -xyz), dim=len(q.shape)-1) -------------------------------------------------------------------------------- /common/ranger.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer #, required 4 | import itertools as it 5 | 6 | class Ranger(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0): 9 | #parameter checks 10 | if not 0.0 <= alpha <= 1.0: 11 | raise ValueError(f'Invalid slow update rate: {alpha}') 12 | if not 1 <= k: 13 | raise ValueError(f'Invalid lookahead steps: {k}') 14 | if not lr > 0: 15 | raise ValueError(f'Invalid Learning Rate: {lr}') 16 | if not eps > 0: 17 | raise ValueError(f'Invalid eps: {eps}') 18 | 19 | #parameter comments: 20 | # beta1 (momentum) of .95 seems to work better than .90... 21 | #N_sma_threshold of 5 seems better in testing than 4. 22 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 23 | 24 | #prep defaults and init torch.optim base 25 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay, amsgrad=True) 26 | super().__init__(params,defaults) 27 | 28 | #adjustable threshold 29 | self.N_sma_threshhold = N_sma_threshhold 30 | 31 | #now we can get to work... 32 | #removed as we now use step from RAdam...no need for duplicate step counting 33 | #for group in self.param_groups: 34 | # group["step_counter"] = 0 35 | #print("group step counter init") 36 | 37 | #look ahead params 38 | self.alpha = alpha 39 | self.k = k 40 | 41 | #radam buffer for state 42 | self.radam_buffer = [[None,None,None] for ind in range(10)] 43 | 44 | #self.first_run_check=0 45 | 46 | #lookahead weights 47 | #9/2/19 - lookahead param tensors have been moved to state storage. 48 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs. 49 | 50 | #self.slow_weights = [[p.clone().detach() for p in group['params']] 51 | # for group in self.param_groups] 52 | 53 | #don't use grad for lookahead weights 54 | #for w in it.chain(*self.slow_weights): 55 | # w.requires_grad = False 56 | 57 | def __setstate__(self, state): 58 | print("set state called") 59 | super(Ranger, self).__setstate__(state) 60 | 61 | 62 | def step(self, closure=None): 63 | loss = None 64 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 65 | #Uncomment if you need to use the actual closure... 
66 | 67 | #if closure is not None: 68 | #loss = closure() 69 | 70 | #Evaluate averages and grad, update param tensors 71 | for group in self.param_groups: 72 | 73 | for p in group['params']: 74 | if p.grad is None: 75 | continue 76 | grad = p.grad.data.float() 77 | if grad.is_sparse: 78 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 79 | 80 | p_data_fp32 = p.data.float() 81 | 82 | state = self.state[p] #get state dict for this param 83 | 84 | if len(state) == 0: #if first time to run...init dictionary with our desired entries 85 | #if self.first_run_check==0: 86 | #self.first_run_check=1 87 | #print("Initializing slow buffer...should not see this at load from saved model!") 88 | state['step'] = 0 89 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 90 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 91 | 92 | #look ahead weight storage now in state dict 93 | state['slow_buffer'] = torch.empty_like(p.data) 94 | state['slow_buffer'].copy_(p.data) 95 | 96 | else: 97 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 98 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 99 | 100 | #begin computations 101 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 102 | beta1, beta2 = group['betas'] 103 | 104 | #compute variance mov avg 105 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 106 | #compute mean moving avg 107 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 108 | 109 | state['step'] += 1 110 | 111 | 112 | buffered = self.radam_buffer[int(state['step'] % 10)] 113 | if state['step'] == buffered[0]: 114 | N_sma, step_size = buffered[1], buffered[2] 115 | else: 116 | buffered[0] = state['step'] 117 | beta2_t = beta2 ** state['step'] 118 | N_sma_max = 2 / (1 - beta2) - 1 119 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 120 | buffered[1] = N_sma 121 | if N_sma > self.N_sma_threshhold: 122 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 123 | else: 124 | step_size = 1.0 / (1 - beta1 ** state['step']) 125 | buffered[2] = step_size 126 | 127 | if group['weight_decay'] != 0: 128 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 129 | 130 | if N_sma > self.N_sma_threshhold: 131 | denom = exp_avg_sq.sqrt().add_(group['eps']) 132 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 133 | else: 134 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 135 | 136 | p.data.copy_(p_data_fp32) 137 | 138 | #integrated look ahead... 139 | #we do it at the param level instead of group level 140 | if state['step'] % group['k'] == 0: 141 | slow_p = state['slow_buffer'] #get access to slow param tensor 142 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha 143 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor 144 | 145 | return loss 146 | -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | 10 | class Skeleton: 11 | def __init__(self, parents, joints_left, joints_right): 12 | assert len(joints_left) == len(joints_right) 13 | 14 | self._parents = np.array(parents) 15 | self._joints_left = joints_left 16 | self._joints_right = joints_right 17 | self._compute_metadata() 18 | 19 | def num_joints(self): 20 | return len(self._parents) 21 | 22 | def parents(self): 23 | return self._parents 24 | 25 | def has_children(self): 26 | return self._has_children 27 | 28 | def children(self): 29 | return self._children 30 | 31 | def remove_joints(self, joints_to_remove): 32 | """ 33 | Remove the joints specified in 'joints_to_remove'. 34 | """ 35 | valid_joints = [] 36 | for joint in range(len(self._parents)): 37 | if joint not in joints_to_remove: 38 | valid_joints.append(joint) 39 | 40 | for i in range(len(self._parents)): 41 | while self._parents[i] in joints_to_remove: 42 | self._parents[i] = self._parents[self._parents[i]] 43 | 44 | index_offsets = np.zeros(len(self._parents), dtype=int) 45 | new_parents = [] 46 | for i, parent in enumerate(self._parents): 47 | if i not in joints_to_remove: 48 | new_parents.append(parent - index_offsets[parent]) 49 | else: 50 | index_offsets[i:] += 1 51 | self._parents = np.array(new_parents) 52 | 53 | 54 | if self._joints_left is not None: 55 | new_joints_left = [] 56 | for joint in self._joints_left: 57 | if joint in valid_joints: 58 | new_joints_left.append(joint - index_offsets[joint]) 59 | self._joints_left = new_joints_left 60 | if self._joints_right is not None: 61 | new_joints_right = [] 62 | for joint in self._joints_right: 63 | if joint in valid_joints: 64 | new_joints_right.append(joint - index_offsets[joint]) 65 | self._joints_right = new_joints_right 66 | 67 | self._compute_metadata() 68 | 69 | return valid_joints 70 | 71 | def joints_left(self): 72 | return self._joints_left 73 | 74 | def joints_right(self): 75 | return self._joints_right 76 | 77 | def _compute_metadata(self): 78 | self._has_children = np.zeros(len(self._parents)).astype(bool) 79 | for i, parent in enumerate(self._parents): 80 | if parent != -1: 81 | self._has_children[parent] = True 82 | 83 | self._children = [] 84 | for i, parent in enumerate(self._parents): 85 | self._children.append([]) 86 | for i, parent in enumerate(self._parents): 87 | if parent != -1: 88 | self._children[parent].append(i) -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | import hashlib 11 | 12 | def wrap(func, *args, unsqueeze=False): 13 | """ 14 | Wrap a torch function so it can be called with NumPy arrays. 15 | Input and return types are seamlessly converted. 
16 | """ 17 | 18 | # Convert input types where applicable 19 | args = list(args) 20 | for i, arg in enumerate(args): 21 | if type(arg) == np.ndarray: 22 | args[i] = torch.from_numpy(arg) 23 | if unsqueeze: 24 | args[i] = args[i].unsqueeze(0) 25 | 26 | result = func(*args) 27 | 28 | # Convert output types where applicable 29 | if isinstance(result, tuple): 30 | result = list(result) 31 | for i, res in enumerate(result): 32 | if type(res) == torch.Tensor: 33 | if unsqueeze: 34 | res = res.squeeze(0) 35 | result[i] = res.numpy() 36 | return tuple(result) 37 | elif type(result) == torch.Tensor: 38 | if unsqueeze: 39 | result = result.squeeze(0) 40 | return result.numpy() 41 | else: 42 | return result 43 | 44 | def deterministic_random(min_value, max_value, data): 45 | digest = hashlib.sha256(data.encode()).digest() 46 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False) 47 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import matplotlib.pyplot as plt 12 | from matplotlib.animation import FuncAnimation, writers 13 | from mpl_toolkits.mplot3d import Axes3D 14 | import numpy as np 15 | import subprocess as sp 16 | 17 | def get_resolution(filename): 18 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 19 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename] 20 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe: 21 | for line in pipe.stdout: 22 | w, h = line.decode().strip().split(',') 23 | return int(w), int(h) 24 | 25 | def read_video(filename, skip=0, limit=-1): 26 | w, h = get_resolution(filename) 27 | 28 | command = ['ffmpeg', 29 | '-i', filename, 30 | '-f', 'image2pipe', 31 | '-pix_fmt', 'rgb24', 32 | '-vsync', '0', 33 | '-vcodec', 'rawvideo', '-'] 34 | 35 | i = 0 36 | with sp.Popen(command, stdout = sp.PIPE, bufsize=-1) as pipe: 37 | while True: 38 | data = pipe.stdout.read(w*h*3) 39 | if not data: 40 | break 41 | i += 1 42 | if i > skip: 43 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3)) 44 | if i == limit: 45 | break 46 | 47 | 48 | 49 | def downsample_tensor(X, factor): 50 | length = X.shape[0]//factor * factor 51 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1) 52 | 53 | def render_animation(keypoints, poses, skeleton, fps, bitrate, azim, output, viewport, 54 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0): 55 | """ 56 | TODO 57 | Render an animation. The supported output modes are: 58 | -- 'interactive': display an interactive figure 59 | (also works on notebooks if associated with %matplotlib inline) 60 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...). 61 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg). 62 | -- 'filename.gif': render and export the animation a gif file (requires imagemagick). 
63 | """ 64 | plt.ioff() 65 | fig = plt.figure(figsize=(size*(1 + len(poses)), size)) 66 | ax_in = fig.add_subplot(1, 1 + len(poses), 1) 67 | ax_in.get_xaxis().set_visible(False) 68 | ax_in.get_yaxis().set_visible(False) 69 | ax_in.set_axis_off() 70 | ax_in.set_title('Input') 71 | 72 | ax_3d = [] 73 | lines_3d = [] 74 | trajectories = [] 75 | radius = 1.7 76 | for index, (title, data) in enumerate(poses.items()): 77 | ax = fig.add_subplot(1, 1 + len(poses), index+2, projection='3d') 78 | ax.view_init(elev=15., azim=azim) 79 | ax.set_xlim3d([-radius/2, radius/2]) 80 | ax.set_zlim3d([0, radius]) 81 | ax.set_ylim3d([-radius/2, radius/2]) 82 | ax.set_aspect('equal') 83 | ax.set_xticklabels([]) 84 | ax.set_yticklabels([]) 85 | ax.set_zticklabels([]) 86 | ax.dist = 7.5 87 | ax.set_title(title) #, pad=35 88 | ax_3d.append(ax) 89 | lines_3d.append([]) 90 | trajectories.append(data[:, 0, [0, 1]]) 91 | poses = list(poses.values()) 92 | 93 | # Decode video 94 | if input_video_path is None: 95 | # Black background 96 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8') 97 | else: 98 | # Load video using ffmpeg 99 | all_frames = [] 100 | for f in read_video(input_video_path, skip=input_video_skip): 101 | all_frames.append(f) 102 | effective_length = min(keypoints.shape[0], len(all_frames)) 103 | all_frames = all_frames[:effective_length] 104 | 105 | if downsample > 1: 106 | keypoints = downsample_tensor(keypoints, downsample) 107 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8') 108 | for idx in range(len(poses)): 109 | poses[idx] = downsample_tensor(poses[idx], downsample) 110 | trajectories[idx] = downsample_tensor(trajectories[idx], downsample) 111 | fps /= downsample 112 | 113 | initialized = False 114 | image = None 115 | lines = [] 116 | points = None 117 | 118 | if limit < 1: 119 | limit = len(all_frames) 120 | else: 121 | limit = min(limit, len(all_frames)) 122 | 123 | parents = skeleton.parents() 124 | def update_video(i): 125 | nonlocal initialized, image, lines, points 126 | 127 | for n, ax in enumerate(ax_3d): 128 | ax.set_xlim3d([-radius/2 + trajectories[n][i, 0], radius/2 + trajectories[n][i, 0]]) 129 | ax.set_ylim3d([-radius/2 + trajectories[n][i, 1], radius/2 + trajectories[n][i, 1]]) 130 | 131 | # Update 2D poses 132 | if not initialized: 133 | image = ax_in.imshow(all_frames[i], aspect='equal') 134 | 135 | for j, j_parent in enumerate(parents): 136 | if j_parent == -1: 137 | continue 138 | 139 | if len(parents) == keypoints.shape[1]: 140 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition) 141 | lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 142 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink')) 143 | 144 | col = 'red' if j in skeleton.joints_right() else 'black' 145 | for n, ax in enumerate(ax_3d): 146 | pos = poses[n][i] 147 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 148 | [pos[j, 1], pos[j_parent, 1]], 149 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col)) 150 | 151 | points = ax_in.scatter(*keypoints[i].T, 5, color='red', edgecolors='white', zorder=10) 152 | 153 | initialized = True 154 | else: 155 | image.set_data(all_frames[i]) 156 | 157 | for j, j_parent in enumerate(parents): 158 | if j_parent == -1: 159 | continue 160 | 161 | if len(parents) == keypoints.shape[1]: 162 | lines[j-1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]], 163 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]]) 164 | 165 | for n, ax 
in enumerate(ax_3d):
166 |                     pos = poses[n][i]
167 |                     lines_3d[n][j-1][0].set_xdata([pos[j, 0], pos[j_parent, 0]])
168 |                     lines_3d[n][j-1][0].set_ydata([pos[j, 1], pos[j_parent, 1]])
169 |                     lines_3d[n][j-1][0].set_3d_properties([pos[j, 2], pos[j_parent, 2]], zdir='z')
170 | 
171 |             points.set_offsets(keypoints[i])
172 | 
173 |         print('{}/{} '.format(i, limit), end='\r')
174 | 
175 | 
176 |     fig.tight_layout()
177 | 
178 |     anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000/fps, repeat=False)
179 |     if output.endswith('.mp4'):
180 |         Writer = writers['ffmpeg']
181 |         writer = Writer(fps=fps, metadata={}, bitrate=bitrate)
182 |         anim.save(output, writer=writer)
183 |     elif output.endswith('.gif'):
184 |         anim.save(output, dpi=80, writer='imagemagick')
185 |     else:
186 |         raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')
187 |     plt.close()
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Put the dataset files here (e.g. `data_3d_h36m.npz` and the corresponding 2D detection files `data_2d_h36m_*.npz`).
2 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | 
8 | import numpy as np
9 | 
10 | from common.arguments import parse_args
11 | import torch
12 | 
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import os
17 | import sys
18 | import errno
19 | 
20 | from common.camera import *
21 | from common.loss import *
22 | from common.generators import ChunkedGenerator, Evaluate_Generator
23 | from time import time
24 | from common.utils import deterministic_random
25 | from common.ranger import Ranger
26 | from torch.optim import lr_scheduler
27 | 
28 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
30 | 
31 | torch.backends.cudnn.benchmark = True
32 | 
33 | args = parse_args()
34 | print(args)
35 | 
36 | try:
37 |     # Create checkpoint directory if it does not exist
38 |     os.makedirs(args.checkpoint)
39 | except OSError as e:
40 |     if e.errno != errno.EEXIST:
41 |         raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint)
42 | 
43 | if args.causal:
44 |     from common.causal_model import *
45 | else:
46 |     from common.model import *
47 | print('Loading dataset...')
48 | dataset_path = 'data/data_3d_' + args.dataset + '.npz'
49 | if args.dataset == 'h36m':
50 |     from common.h36m_dataset import Human36mDataset
51 | 
52 |     dataset = Human36mDataset(dataset_path)
53 | elif args.dataset.startswith('humaneva'):
54 |     from common.humaneva_dataset import HumanEvaDataset
55 | 
56 |     dataset = HumanEvaDataset(dataset_path)
57 | else:
58 |     raise KeyError('Invalid dataset')
59 | 
60 | print('Preparing data...')
61 | for subject in dataset.subjects():
62 |     for action in dataset[subject].keys():
63 |         anim = dataset[subject][action]
64 | 
65 |         positions_3d = []
66 |         for cam in anim['cameras']:
67 |             pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
68 |             pos_3d[:, 1:] -= pos_3d[:, :1]  # Remove global offset, but keep trajectory in first position
69 |             positions_3d.append(pos_3d)
70 |         anim['positions_3d'] = positions_3d
71 | 
72 | print('Loading 2D detections...')
73 | 
keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True) 74 | keypoints_symmetry = keypoints['metadata'].item()['keypoints_symmetry'] 75 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1]) 76 | joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right()) 77 | keypoints = keypoints['positions_2d'].item() 78 | 79 | for subject in dataset.subjects(): 80 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject) 81 | for action in dataset[subject].keys(): 82 | assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format( 83 | action, subject) 84 | for cam_idx in range(len(keypoints[subject][action])): 85 | 86 | # We check for >= instead of == because some videos in H3.6M contain extra frames 87 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0] 88 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length 89 | 90 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length: 91 | # Shorten sequence 92 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length] 93 | 94 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d']) 95 | 96 | for subject in keypoints.keys(): 97 | for action in keypoints[subject]: 98 | for cam_idx, kps in enumerate(keypoints[subject][action]): 99 | # Normalize camera frame 100 | cam = dataset.cameras()[subject][cam_idx] 101 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h']) 102 | keypoints[subject][action][cam_idx] = kps 103 | 104 | subjects_train = args.subjects_train.split(',') 105 | subjects_test = args.subjects_test.split(',') 106 | 107 | 108 | def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True): 109 | out_poses_3d = [] 110 | out_poses_2d = [] 111 | out_camera_params = [] 112 | for subject in subjects: 113 | for action in keypoints[subject].keys(): 114 | if action_filter is not None: 115 | found = False 116 | for a in action_filter: 117 | if action.startswith(a): 118 | found = True 119 | break 120 | if not found: 121 | continue 122 | 123 | poses_2d = keypoints[subject][action] 124 | for i in range(len(poses_2d)): # Iterate across cameras 125 | out_poses_2d.append(poses_2d[i]) 126 | 127 | if subject in dataset.cameras(): 128 | cams = dataset.cameras()[subject] 129 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 130 | for cam in cams: 131 | if 'intrinsic' in cam: 132 | out_camera_params.append(cam['intrinsic']) 133 | 134 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]: 135 | poses_3d = dataset[subject][action]['positions_3d'] 136 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 137 | for i in range(len(poses_3d)): # Iterate across cameras 138 | out_poses_3d.append(poses_3d[i]) 139 | 140 | if len(out_camera_params) == 0: 141 | out_camera_params = None 142 | if len(out_poses_3d) == 0: 143 | out_poses_3d = None 144 | 145 | stride = args.downsample 146 | if subset < 1: 147 | for i in range(len(out_poses_2d)): 148 | n_frames = int(round(len(out_poses_2d[i]) // stride * subset) * stride) 149 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i]))) 150 | out_poses_2d[i] = out_poses_2d[i][start:start + n_frames:stride] 151 | if out_poses_3d is not None: 152 | out_poses_3d[i] = out_poses_3d[i][start:start + 
n_frames:stride] 153 | elif stride > 1: 154 | # Downsample as requested 155 | for i in range(len(out_poses_2d)): 156 | out_poses_2d[i] = out_poses_2d[i][::stride] 157 | if out_poses_3d is not None: 158 | out_poses_3d[i] = out_poses_3d[i][::stride] 159 | 160 | return out_camera_params, out_poses_3d, out_poses_2d 161 | 162 | 163 | action_filter = None if args.actions == '*' else args.actions.split(',') 164 | if action_filter is not None: 165 | print('Selected actions:', action_filter) 166 | 167 | cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter) 168 | 169 | filter_widths = [int(x) for x in args.architecture.split(',')] 170 | if not args.disable_optimizations and not args.dense and args.stride == 1: 171 | # Use optimized model for single-frame predictions 172 | model_pos_train = TemporalModelOptimized1f(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], 173 | poses_valid[0].shape[-2], 174 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout, 175 | channels=args.channels) 176 | else: 177 | # When incompatible settings are detected (stride > 1, dense filters, or disabled optimization) fall back to normal model 178 | model_pos_train = TemporalModel(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], poses_valid[0].shape[-2], 179 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout, 180 | channels=args.channels, 181 | dense=args.dense) 182 | 183 | model_pos = TemporalModelOptimized1f(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], poses_valid[0].shape[-2], 184 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout, 185 | channels=args.channels, dense=args.dense) 186 | 187 | receptive_field = model_pos.receptive_field() 188 | print('INFO: Receptive field: {} frames'.format(receptive_field)) 189 | pad = (receptive_field - 1) // 2 # Padding on each side 190 | if args.causal: 191 | print('INFO: Using causal convolutions') 192 | causal_shift = pad 193 | else: 194 | causal_shift = 0 195 | 196 | model_params = 0 197 | for parameter in model_pos.parameters(): 198 | model_params += parameter.numel() 199 | print('INFO: Trainable parameter count:', model_params) 200 | 201 | if torch.cuda.is_available(): 202 | model_pos = model_pos.cuda() 203 | model_pos_train = model_pos_train.cuda() 204 | 205 | if args.resume or args.evaluate: 206 | chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate) 207 | print('Loading checkpoint', chk_filename) 208 | checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) 209 | model_pos_train.load_state_dict(checkpoint['model_pos']) 210 | model_pos.load_state_dict(checkpoint['model_pos']) 211 | 212 | test_generator = ChunkedGenerator(args.batch_size // args.stride, cameras_valid, poses_valid, poses_valid_2d, 213 | args.stride, 214 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 215 | shuffle=False, 216 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 217 | joints_right=joints_right, noisy=False) 218 | print('INFO: Testing on {} sequences'.format(test_generator.num_frames())) 219 | 220 | if not args.evaluate: 221 | cameras_train, poses_train, poses_train_2d = fetch(subjects_train, action_filter, subset=args.subset) 222 | 223 | lr = args.learning_rate 224 | optimizer = Ranger(model_pos_train.parameters(), lr=lr) 225 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=args.epochs) 226 | 227 | lr_decay = args.lr_decay 228 | 229 | losses_3d_train = [] 
230 | losses_3d_train_eval = [] 231 | losses_3d_valid = [] 232 | 233 | epoch = 0 234 | initial_momentum = 0.1 235 | final_momentum = 0.001 236 | 237 | train_generator = ChunkedGenerator(args.batch_size // args.stride, cameras_train, poses_train, poses_train_2d, 238 | args.stride, 239 | pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation, 240 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 241 | joints_right=joints_right) 242 | train_generator_eval = ChunkedGenerator(args.batch_size // args.stride, cameras_train, poses_train, poses_train_2d, 243 | args.stride, 244 | pad=pad, causal_shift=causal_shift, augment=False, shuffle=True, 245 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 246 | joints_right=joints_right) 247 | print('INFO: Supervision Training on {} frames'.format(train_generator.num_frames())) 248 | 249 | if args.resume: 250 | epoch = checkpoint['epoch'] 251 | if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: 252 | optimizer.load_state_dict(checkpoint['optimizer']) 253 | train_generator.set_random_state(checkpoint['random_state']) 254 | else: 255 | print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') 256 | 257 | lr = checkpoint['lr'] 258 | 259 | print('** Note: reported losses are averaged over all frames and test-time augmentation is not used here.') 260 | print('** The final evaluation will be carried out after the last training epoch.') 261 | 262 | # Pos model only 263 | while epoch < args.epochs: 264 | start_time = time() 265 | epoch_loss_3d_train = 0 266 | epoch_loss_traj_train = 0 267 | epoch_loss_2d_train_unlabeled = 0 268 | N = 0 269 | N_semi = 0 270 | model_pos_train.train() 271 | 272 | for _, batch_3d, batch_2d in train_generator.next_epoch(): 273 | inputs_3d = torch.from_numpy(batch_3d.astype('float32')) 274 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 275 | if torch.cuda.is_available(): 276 | inputs_3d = inputs_3d.cuda() 277 | inputs_2d = inputs_2d.cuda() 278 | inputs_3d[:, :, 0] = 0 279 | 280 | optimizer.zero_grad() 281 | 282 | # Predict 3D poses 283 | predicted_3d_pos = model_pos_train(inputs_2d) 284 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 285 | epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 286 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 287 | 288 | loss_total = loss_3d_pos 289 | loss_total.backward() 290 | 291 | optimizer.step() 292 | 293 | losses_3d_train.append(epoch_loss_3d_train / N) 294 | 295 | # End-of-epoch evaluation 296 | with torch.no_grad(): 297 | model_pos.load_state_dict(model_pos_train.state_dict()) 298 | model_pos.eval() 299 | 300 | epoch_loss_3d_valid = 0 301 | epoch_loss_traj_valid = 0 302 | epoch_loss_2d_valid = 0 303 | N = 0 304 | 305 | if not args.no_eval: 306 | # Evaluate on test set 307 | for cam, batch, batch_2d in test_generator.next_epoch(): 308 | inputs_3d = torch.from_numpy(batch.astype('float32')) 309 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 310 | if torch.cuda.is_available(): 311 | inputs_3d = inputs_3d.cuda() 312 | inputs_2d = inputs_2d.cuda() 313 | inputs_traj = inputs_3d[:, :, :1].clone() 314 | inputs_3d[:, :, 0] = 0 315 | 316 | # Predict 3D poses 317 | predicted_3d_pos = model_pos(inputs_2d) 318 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 319 | epoch_loss_3d_valid += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 320 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 321 | 322 | 
losses_3d_valid.append(epoch_loss_3d_valid / N) 323 | 324 | # Evaluate on training set, this time in evaluation mode 325 | epoch_loss_3d_train_eval = 0 326 | epoch_loss_traj_train_eval = 0 327 | epoch_loss_2d_train_labeled_eval = 0 328 | N = 0 329 | for cam, batch, batch_2d in train_generator_eval.next_epoch(): 330 | if batch_2d.shape[1] == 0: 331 | # This can only happen when downsampling the dataset 332 | continue 333 | 334 | inputs_3d = torch.from_numpy(batch.astype('float32')) 335 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 336 | if torch.cuda.is_available(): 337 | inputs_3d = inputs_3d.cuda() 338 | inputs_2d = inputs_2d.cuda() 339 | inputs_traj = inputs_3d[:, :, :1].clone() 340 | inputs_3d[:, :, 0] = 0 341 | 342 | # Compute 3D poses 343 | predicted_3d_pos = model_pos(inputs_2d) 344 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d) 345 | epoch_loss_3d_train_eval += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item() 346 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 347 | 348 | losses_3d_train_eval.append(epoch_loss_3d_train_eval / N) 349 | 350 | # Evaluate 2D loss on unlabeled training set (in evaluation mode) 351 | epoch_loss_2d_train_unlabeled_eval = 0 352 | N_semi = 0 353 | 354 | elapsed = (time() - start_time) / 60 355 | 356 | if args.no_eval: 357 | print('[%d] time %.2f lr %f 3d_train %f' % ( 358 | epoch + 1, 359 | elapsed, 360 | lr, 361 | losses_3d_train[-1] * 1000)) 362 | else: 363 | print('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % ( 364 | epoch + 1, 365 | elapsed, 366 | lr, 367 | losses_3d_train[-1] * 1000, 368 | losses_3d_train_eval[-1] * 1000, 369 | losses_3d_valid[-1] * 1000)) 370 | 371 | # cosin annealing 372 | scheduler.step() 373 | lr = scheduler.get_lr()[0] 374 | for param_group in optimizer.param_groups: 375 | param_group['lr'] = lr 376 | 377 | epoch += 1 378 | momentum = initial_momentum * np.exp(-epoch / args.epochs * np.log(initial_momentum / final_momentum)) 379 | model_pos_train.set_bn_momentum(momentum) 380 | model_pos_train.set_KA_bn(momentum) 381 | model_pos_train.set_expand_bn(momentum) 382 | model_pos_train.set_dilation_bn(momentum) 383 | 384 | # Save checkpoint if necessary 385 | if epoch % args.checkpoint_frequency == 0: 386 | check_point_name = 'supervised' 387 | 388 | chk_path = os.path.join(args.checkpoint, str(args.channels) + '_' + str(args.keypoints) + 389 | '_' + str(receptive_field) + '_' + check_point_name + '_epoch_{}.bin'.format( 390 | epoch)) 391 | print('Saving checkpoint to', chk_path) 392 | 393 | torch.save({ 394 | 'epoch': epoch, 395 | 'lr': lr, 396 | 'random_state': train_generator.random_state(), 397 | 'optimizer': optimizer.state_dict(), 398 | 'model_pos': model_pos_train.state_dict(), 399 | }, chk_path) 400 | 401 | # Save training curves after every epoch, as .png images (if requested) 402 | if args.export_training_curves and epoch > 3: 403 | if 'matplotlib' not in sys.modules: 404 | import matplotlib 405 | 406 | matplotlib.use('Agg') 407 | import matplotlib.pyplot as plt 408 | 409 | plt.figure() 410 | epoch_x = np.arange(3, len(losses_3d_train)) + 1 411 | plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0') 412 | plt.plot(epoch_x, losses_3d_train_eval[3:], color='C0') 413 | plt.plot(epoch_x, losses_3d_valid[3:], color='C1') 414 | plt.legend(['3d train', '3d train (eval)', '3d valid (eval)']) 415 | plt.ylabel('MPJPE (m)') 416 | plt.xlabel('Epoch') 417 | plt.xlim((3, epoch)) 418 | plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png')) 419 | 420 | plt.close('all') 421 | 422 | 423 | # 
Evaluate 424 | def evaluate(test_generator, action=None, return_predictions=False): 425 | epoch_loss_3d_pos = 0 426 | epoch_loss_3d_pos_procrustes = 0 427 | 428 | with torch.no_grad(): 429 | model_pos.eval() 430 | N = 0 431 | 432 | # Test-time augmentation (if enabled) 433 | if args.test_time_augmentation: 434 | for _, batch, batch_2d, batch_2d_flip in test_generator.next_epoch(): 435 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 436 | inputs_2d_flip = torch.from_numpy(batch_2d_flip.astype('float32')) 437 | if torch.cuda.is_available(): 438 | inputs_2d = inputs_2d.cuda() 439 | inputs_2d_flip = inputs_2d_flip.cuda() 440 | 441 | # Positional model 442 | predicted_3d_pos = model_pos(inputs_2d) 443 | predicted_3d_pos_flip = model_pos(inputs_2d_flip) 444 | predicted_3d_pos_flip[:, :, :, 0] *= -1 445 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :, 446 | joints_right + joints_left] 447 | 448 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1, 449 | keepdim=True) 450 | 451 | if return_predictions: 452 | return predicted_3d_pos.squeeze().cpu().numpy() 453 | 454 | inputs_3d = torch.from_numpy(batch.astype('float32')) 455 | if torch.cuda.is_available(): 456 | inputs_3d = inputs_3d.cuda() 457 | inputs_3d[:, :, 0] = 0 458 | 459 | error = mpjpe(predicted_3d_pos, inputs_3d) 460 | 461 | epoch_loss_3d_pos += inputs_3d.shape[0] * inputs_3d.shape[1] * error.item() 462 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 463 | 464 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 465 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 466 | 467 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0] * inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, 468 | inputs) 469 | 470 | else: 471 | for _, batch, batch_2d in test_generator.next_epoch(): 472 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) 473 | if torch.cuda.is_available(): 474 | inputs_2d = inputs_2d.cuda() 475 | 476 | # Positional model 477 | predicted_3d_pos = model_pos(inputs_2d) 478 | 479 | if return_predictions: 480 | return predicted_3d_pos.squeeze().cpu().numpy() 481 | 482 | inputs_3d = torch.from_numpy(batch.astype('float32')) 483 | if torch.cuda.is_available(): 484 | inputs_3d = inputs_3d.cuda() 485 | inputs_3d[:, :, 0] = 0 486 | 487 | error = mpjpe(predicted_3d_pos, inputs_3d) 488 | 489 | epoch_loss_3d_pos += inputs_3d.shape[0] * inputs_3d.shape[1] * error.item() 490 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 491 | 492 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 493 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1]) 494 | 495 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0] * inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, 496 | inputs) 497 | if action is None: 498 | print('----------') 499 | else: 500 | print('----' + action + '----') 501 | e1 = (epoch_loss_3d_pos / N) * 1000 502 | e2 = (epoch_loss_3d_pos_procrustes / N) * 1000 503 | 504 | print('Test time augmentation:', test_generator.augment_enabled()) 505 | print('Protocol #1 Error (MPJPE):', e1, 'mm') 506 | print('Protocol #2 Error (P-MPJPE):', e2, 'mm') 507 | print('----------') 508 | 509 | return e1, e2 510 | 511 | 512 | if args.render: 513 | print('Rendering...') 514 | 515 | input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy() 516 | if args.viz_subject in 
dataset.subjects() and args.viz_action in dataset[args.viz_subject]: 517 | ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy() 518 | else: 519 | ground_truth = None 520 | print('INFO: this action is unlabeled. Ground truth will not be rendered.') 521 | 522 | gen = Evaluate_Generator(1, None, None, [input_keypoints], args.stride, 523 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 524 | shuffle=False, 525 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 526 | joints_right=joints_right) 527 | prediction = evaluate(gen, return_predictions=True) 528 | 529 | if ground_truth is not None: 530 | # Reapply trajectory 531 | trajectory = ground_truth[:, :1] 532 | ground_truth[:, 1:] += trajectory 533 | prediction += trajectory 534 | 535 | # Invert camera transformation 536 | cam = dataset.cameras()[args.viz_subject][args.viz_camera] 537 | if ground_truth is not None: 538 | prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation']) 539 | ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation']) 540 | else: 541 | # If the ground truth is not available, take the camera extrinsic params from a random subject. 542 | # They are almost the same, and anyway, we only need this for visualization purposes. 543 | for subject in dataset.cameras(): 544 | if 'orientation' in dataset.cameras()[subject][args.viz_camera]: 545 | rot = dataset.cameras()[subject][args.viz_camera]['orientation'] 546 | break 547 | prediction = camera_to_world(prediction, R=rot, t=0) 548 | # We don't have the trajectory, but at least we can rebase the height 549 | prediction[:, :, 2] -= np.min(prediction[:, :, 2]) 550 | 551 | anim_output = {'Reconstruction': prediction} 552 | if ground_truth is not None and not args.viz_no_ground_truth: 553 | anim_output['Ground truth'] = ground_truth 554 | 555 | input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h']) 556 | 557 | from common.visualization import render_animation 558 | 559 | render_animation(input_keypoints, anim_output, 560 | dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output, 561 | limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size, 562 | input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']), 563 | input_video_skip=args.viz_skip) 564 | 565 | else: 566 | print('Evaluating...') 567 | all_actions = {} 568 | all_actions_by_subject = {} 569 | for subject in subjects_test: 570 | if subject not in all_actions_by_subject: 571 | all_actions_by_subject[subject] = {} 572 | 573 | ordered_actions = dataset.define_actions() 574 | for ordered_action in ordered_actions: 575 | for action in dataset[subject].keys(): 576 | action_name = action.split(' ')[0] 577 | if action_name == ordered_action: 578 | if action_name not in all_actions: 579 | all_actions[action_name] = [] 580 | if action_name not in all_actions_by_subject: 581 | all_actions_by_subject[subject][action_name] = [] 582 | all_actions[action_name].append((subject, action)) 583 | all_actions_by_subject[subject][action_name].append((subject, action)) 584 | else: 585 | continue 586 | 587 | 588 | def fetch_actions(actions): 589 | out_poses_3d = [] 590 | out_poses_2d = [] 591 | 592 | for subject, action in actions: 593 | poses_2d = keypoints[subject][action] 594 | for i in range(len(poses_2d)): # Iterate across cameras 595 | out_poses_2d.append(poses_2d[i]) 596 | 597 | poses_3d = 
dataset[subject][action]['positions_3d'] 598 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 599 | for i in range(len(poses_3d)): # Iterate across cameras 600 | out_poses_3d.append(poses_3d[i]) 601 | 602 | stride = args.downsample 603 | if stride > 1: 604 | # Downsample as requested 605 | for i in range(len(out_poses_2d)): 606 | out_poses_2d[i] = out_poses_2d[i][::stride] 607 | if out_poses_3d is not None: 608 | out_poses_3d[i] = out_poses_3d[i][::stride] 609 | 610 | return out_poses_3d, out_poses_2d 611 | 612 | 613 | def run_evaluation(actions, action_filter=None): 614 | errors_p1 = [] 615 | errors_p2 = [] 616 | 617 | for action_key in actions.keys(): 618 | if action_filter is not None: 619 | found = False 620 | for a in action_filter: 621 | if action_key.startswith(a): 622 | found = True 623 | break 624 | if not found: 625 | continue 626 | 627 | poses_act, poses_2d_act = fetch_actions(actions[action_key]) 628 | gen = Evaluate_Generator(1, None, poses_act, poses_2d_act, args.stride, 629 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, 630 | shuffle=False, 631 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, 632 | joints_right=joints_right) 633 | e1, e2 = evaluate(gen, action_key) 634 | errors_p1.append(e1) 635 | errors_p2.append(e2) 636 | 637 | print('Protocol #1 (MPJPE) action-wise average:', round(np.mean(errors_p1), 1), 'mm') 638 | print('Protocol #2 (P-MPJPE) action-wise average:', round(np.mean(errors_p2), 1), 'mm') 639 | 640 | if not args.by_subject: 641 | run_evaluation(all_actions, action_filter) 642 | else: 643 | for subject in all_actions_by_subject.keys(): 644 | print('Evaluating on subject', subject) 645 | run_evaluation(all_actions_by_subject[subject], action_filter) 646 | print('') 647 | 648 | --------------------------------------------------------------------------------
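The test-time augmentation branch of `evaluate()` in `run.py` above runs the model on a mirrored copy of the 2D input, un-mirrors the resulting 3D prediction (negating x and swapping left/right joints), and averages it with the unflipped estimate. The snippet below is a minimal, self-contained sketch of that idea and is not part of the repository: in the real code the mirrored batch (`batch_2d_flip`) is produced by the generator and the two predictions are combined with `torch.mean` over their concatenation, while the `dummy_model` and the toy joint indices here are illustrative placeholders rather than the Human3.6M skeleton.

```python
import torch


def flip_augmented_prediction(model, inputs_2d, kps_left, kps_right,
                              joints_left, joints_right):
    # Mirror the 2D input: negate x and swap left/right keypoints
    inputs_2d_flip = inputs_2d.clone()
    inputs_2d_flip[..., 0] *= -1
    inputs_2d_flip[:, :, kps_left + kps_right] = inputs_2d_flip[:, :, kps_right + kps_left]

    pred = model(inputs_2d)            # (batch, frames, joints, 3)
    pred_flip = model(inputs_2d_flip)

    # Undo the mirroring on the flipped prediction: negate x, swap left/right joints
    pred_flip[..., 0] *= -1
    pred_flip[:, :, joints_left + joints_right] = pred_flip[:, :, joints_right + joints_left]

    # Average the two estimates
    return (pred + pred_flip) / 2


if __name__ == '__main__':
    # Stand-in model: maps (batch, frames, joints, 2) -> zeros of shape (batch, frames, joints, 3)
    dummy_model = lambda x: x.new_zeros(x.shape[0], x.shape[1], x.shape[2], 3)
    out = flip_augmented_prediction(dummy_model,
                                    torch.randn(2, 5, 4, 2),
                                    kps_left=[1], kps_right=[2],
                                    joints_left=[1], joints_right=[2])
    print(out.shape)  # torch.Size([2, 5, 4, 3])
```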