├── README.md
├── checkpoint
│   └── README.md
├── common
│   ├── README.MD
│   ├── arguments.py
│   ├── camera.py
│   ├── custom_dataset.py
│   ├── generators.py
│   ├── h36m_dataset.py
│   ├── humaneva_dataset.py
│   ├── loss.py
│   ├── mocap_dataset.py
│   ├── model_poseformer.py
│   ├── quaternion.py
│   ├── skeleton.py
│   ├── utils.py
│   └── visualization.py
├── data
│   └── README.MD
├── figure
│   ├── H3.6.gif
│   ├── PoseFormer.gif
│   └── wild.gif
├── poseformer.yml
└── run_poseformer.py
/README.md:
--------------------------------------------------------------------------------
1 | # 3D Human Pose Estimation with Spatial and Temporal Transformers
2 | This repo is the official implementation of [3D Human Pose Estimation with Spatial and Temporal Transformers](https://arxiv.org/pdf/2103.10455.pdf). The paper was accepted to [ICCV 2021](http://iccv2021.thecvf.com/home).
3 |
4 | - **Welcome to check out our NeurIPS 2023 work:** [Context-Aware PoseFormer](https://arxiv.org/pdf/2311.03312.pdf)
5 | - **Welcome to check out our CVPR 2023 work:** [PoseFormerV2](https://github.com/QitaoZhao/PoseFormerV2)
6 | - Visualization code for in-the-wild videos can be found here: [PoseFormer_demo](https://github.com/zczcwh/poseformer_demo)
7 |
8 | [Video Demonstration](https://youtu.be/z8HWOdXjGR8)
9 |
10 | ## PoseFormer Architecture
11 |
12 |
13 |
14 | ## Video Demo
15 |
16 |
17 | | ![H3.6](./figure/H3.6.gif) |
18 | |:--:|
19 | | 3D HPE on Human3.6M |
20 |
21 | | ![wild](./figure/wild.gif) |
22 | |:--:|
23 | | 3D HPE on videos in-the-wild using PoseFormer |
24 |
25 |
26 |
27 | Our code is built on top of [VideoPose3D](https://github.com/facebookresearch/VideoPose3D).
28 |
29 | ### Environment
30 |
31 | The code is developed and tested under the following environment:
32 |
33 | * Python 3.8.2
34 | * PyTorch 1.7.1
35 | * CUDA 11.0
36 |
37 | You can create the environment:
38 | ```bash
39 | conda env create -f poseformer.yml
40 | ```
41 |
42 | ### Dataset
43 |
44 | Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset in the `./data` directory.
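
As a quick sanity check after the setup, the sketch below verifies that the expected archives are present (the file names assume the standard VideoPose3D naming for Human3.6M with CPN-detected and ground-truth 2D keypoints):

```python
import os

# Expected archives under ./data (assumed VideoPose3D naming convention).
expected = [
    'data/data_3d_h36m.npz',                  # 3D ground-truth poses
    'data/data_2d_h36m_cpn_ft_h36m_dbb.npz',  # CPN-detected 2D keypoints (-k cpn_ft_h36m_dbb)
    'data/data_2d_h36m_gt.npz',               # ground-truth 2D keypoints (-k gt)
]
for path in expected:
    print(f"{path}: {'found' if os.path.isfile(path) else 'missing'}")
```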
45 |
46 | ### Evaluating pre-trained models
47 |
48 | We provide the pre-trained 81-frame model (CPN-detected 2D poses as input) [here](https://drive.google.com/file/d/1oX5H5QpVoFzyD-Qz9aaP3RDWDb1v1sIy/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run:
49 |
50 | ```bash
51 | python run_poseformer.py -k cpn_ft_h36m_dbb -f 81 -c checkpoint --evaluate detected81f.bin
52 | ```
53 |
54 | We also provide a pre-trained 81-frame model (ground-truth 2D poses as input) [here](https://drive.google.com/file/d/18wW4TdNYxF-zdt9oInmwQK9hEdRJnXzu/view?usp=sharing). To evaluate it, put it into the `./checkpoint` directory and run:
55 |
56 | ```bash
57 | python run_poseformer.py -k gt -f 81 -c checkpoint --evaluate gt81f.bin
58 | ```
59 |
60 |
61 | ### Training new models
62 |
63 | * To train a model from scratch (CPN detected 2D pose as input), run:
64 |
65 | ```bash
66 | python run_poseformer.py -k cpn_ft_h36m_dbb -f 27 -lr 0.00004 -lrd 0.99
67 | ```
68 |
69 | `-f` controls how many frames are used as input. 27 frames achieves 47.0 mm (MPJPE); 81 frames achieves 44.3 mm.
70 |
71 | * To train a model from scratch (Ground truth 2D pose as input), run:
72 |
73 | ```bash
74 | python run_poseformer.py -k gt -f 81 -lr 0.0004 -lrd 0.99
75 | ```
76 |
77 | 81 frames achieves 31.3 mm (MPJPE).
78 |
79 | ### Visualization and other functions
80 |
81 | We keep our code consistent with [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). Please refer to their project page for further information.
82 |
83 | ### Bibtex
84 | If you find our work useful in your research, please consider citing:
85 |
86 | @article{zheng20213d,
87 | title={3D Human Pose Estimation with Spatial and Temporal Transformers},
88 | author={Zheng, Ce and Zhu, Sijie and Mendieta, Matias and Yang, Taojiannan and Chen, Chen and Ding, Zhengming},
89 | journal={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
90 | year={2021}
91 | }
92 |
93 | ## Acknowledgement
94 |
95 | Part of our code is borrowed from [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). We thank the authors for releasing their code.
96 |
--------------------------------------------------------------------------------
/checkpoint/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/common/README.MD:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/common/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import argparse
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description='Training script')
12 |
13 | # General arguments
14 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
15 | parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
16 | parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST',
17 | help='training subjects separated by comma')
18 | parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
19 | parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST',
20 | help='unlabeled subjects separated by comma for self-supervision')
21 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
22 | help='actions to train/test on, separated by comma, or * for all')
23 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
24 | help='checkpoint directory')
25 | parser.add_argument('--checkpoint-frequency', default=40, type=int, metavar='N',
26 | help='create a checkpoint every N epochs')
27 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
28 | help='checkpoint to resume (file name)')
29 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
30 | parser.add_argument('--render', action='store_true', help='visualize a particular video')
31 | parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
32 | parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
33 |
34 |
35 | # Model arguments
36 | parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
37 | parser.add_argument('-e', '--epochs', default=200, type=int, metavar='N', help='number of training epochs')
38 | parser.add_argument('-b', '--batch-size', default=512, type=int, metavar='N', help='batch size in terms of predicted frames')
39 | parser.add_argument('-drop', '--dropout', default=0., type=float, metavar='P', help='dropout probability')
40 | parser.add_argument('-lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate')
41 | parser.add_argument('-lrd', '--lr-decay', default=0.99, type=float, metavar='LR', help='learning rate decay per epoch')
42 | parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false',
43 | help='disable train-time flipping')
44 | # parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false',
45 | # help='disable test-time flipping')
46 | # parser.add_argument('-arc', '--architecture', default='3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma')
47 | parser.add_argument('-frame', '--number-of-frames', default=81, type=int, metavar='N',
48 | help='how many frames are used as input')
49 | # parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing')
50 | # parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers')
51 |
52 | # Experimental
53 | parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
54 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
55 | parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision')
56 | parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
57 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
58 | parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
59 | parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
60 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
61 | help='disable bone length term in semi-supervised settings')
62 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')
63 |
64 | # Visualization
65 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
66 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
67 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render')
68 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
69 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
70 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
71 | parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates')
72 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
73 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
74 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
75 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
76 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')
77 |
78 | parser.set_defaults(bone_length_term=True)
79 | parser.set_defaults(data_augmentation=True)
80 | parser.set_defaults(test_time_augmentation=True)
81 | # parser.set_defaults(test_time_augmentation=False)
82 |
83 | args = parser.parse_args()
84 | # Check invalid configuration
85 | if args.resume and args.evaluate:
86 | print('Invalid flags: --resume and --evaluate cannot be set at the same time')
87 | exit()
88 |
89 | if args.export_training_curves and args.no_eval:
90 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
91 | exit()
92 |
93 | return args
--------------------------------------------------------------------------------
/common/camera.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from common.utils import wrap
12 | from common.quaternion import qrot, qinverse
13 |
14 | def normalize_screen_coordinates(X, w, h):
15 | assert X.shape[-1] == 2
16 |
17 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
18 | return X/w*2 - [1, h/w]
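# Worked example (illustrative values): for a 1000x1002 image, the pixel
# (500, 501) maps to (500/1000*2 - 1, 501/1000*2 - 1002/1000) = (0.0, 0.0),
# i.e. the image centre lands at the origin and the x axis spans [-1, 1].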
19 |
20 |
21 | def image_coordinates(X, w, h):
22 | assert X.shape[-1] == 2
23 |
24 | # Reverse camera frame normalization
25 | return (X + [1, h/w])*w/2
26 |
27 |
28 | def world_to_camera(X, R, t):
29 | Rt = wrap(qinverse, R) # Invert rotation
30 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate
31 |
32 |
33 | def camera_to_world(X, R, t):
34 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t
35 |
36 |
37 | def project_to_2d(X, camera_params):
38 | """
39 | Project 3D points to 2D using the Human3.6M camera projection function.
40 | This is a differentiable and batched reimplementation of the original MATLAB script.
41 |
42 | Arguments:
43 | X -- 3D points in *camera space* to transform (N, *, 3)
44 | camera_params -- intrinsic parameters (N, 2+2+3+2=9)
45 | """
46 | assert X.shape[-1] == 3
47 | assert len(camera_params.shape) == 2
48 | assert camera_params.shape[-1] == 9
49 | assert X.shape[0] == camera_params.shape[0]
50 |
51 | while len(camera_params.shape) < len(X.shape):
52 | camera_params = camera_params.unsqueeze(1)
53 |
54 | f = camera_params[..., :2]
55 | c = camera_params[..., 2:4]
56 | k = camera_params[..., 4:7]
57 | p = camera_params[..., 7:]
58 |
59 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
60 | r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)
61 |
62 | radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
63 | tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)
64 |
65 | XXX = XX*(radial + tan) + p*r2
66 |
67 | return f*XXX + c
68 |
69 | def project_to_2d_linear(X, camera_params):
70 | """
71 | Project 3D points to 2D using only linear parameters (focal length and principal point).
72 |
73 | Arguments:
74 | X -- 3D points in *camera space* to transform (N, *, 3)
75 | camera_params -- intrinsic parameters (N, 2+2+3+2=9)
76 | """
77 | assert X.shape[-1] == 3
78 | assert len(camera_params.shape) == 2
79 | assert camera_params.shape[-1] == 9
80 | assert X.shape[0] == camera_params.shape[0]
81 |
82 | while len(camera_params.shape) < len(X.shape):
83 | camera_params = camera_params.unsqueeze(1)
84 |
85 | f = camera_params[..., :2]
86 | c = camera_params[..., 2:4]
87 |
88 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
89 |
90 | return f*XX + c
--------------------------------------------------------------------------------
/common/custom_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import copy
10 | from common.skeleton import Skeleton
11 | from common.mocap_dataset import MocapDataset
12 | from common.camera import normalize_screen_coordinates, image_coordinates
13 | from common.h36m_dataset import h36m_skeleton
14 |
15 |
16 | custom_camera_params = {
17 | 'id': None,
18 | 'res_w': None, # Pulled from metadata
19 | 'res_h': None, # Pulled from metadata
20 |
21 | # Dummy camera parameters (taken from Human3.6M), only for visualization purposes
22 | 'azimuth': 70, # Only used for visualization
23 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
24 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
25 | }
26 |
27 | class CustomDataset(MocapDataset):
28 | def __init__(self, detections_path, remove_static_joints=True):
29 | super().__init__(fps=None, skeleton=h36m_skeleton)
30 |
31 | # Load serialized dataset
32 | data = np.load(detections_path, allow_pickle=True)
33 | resolutions = data['metadata'].item()['video_metadata']
34 |
35 | self._cameras = {}
36 | self._data = {}
37 | for video_name, res in resolutions.items():
38 | cam = {}
39 | cam.update(custom_camera_params)
40 | cam['orientation'] = np.array(cam['orientation'], dtype='float32')
41 | cam['translation'] = np.array(cam['translation'], dtype='float32')
42 | cam['translation'] = cam['translation']/1000 # mm to meters
43 |
44 | cam['id'] = video_name
45 | cam['res_w'] = res['w']
46 | cam['res_h'] = res['h']
47 |
48 | self._cameras[video_name] = [cam]
49 |
50 | self._data[video_name] = {
51 | 'custom': {
52 | 'cameras': cam
53 | }
54 | }
55 |
56 | if remove_static_joints:
57 | # Bring the skeleton to 17 joints instead of the original 32
58 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])
59 |
60 | # Rewire shoulders to the correct parents
61 | self._skeleton._parents[11] = 8
62 | self._skeleton._parents[14] = 8
63 |
64 | def supports_semi_supervised(self):
65 | return False
66 |
--------------------------------------------------------------------------------
/common/generators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from itertools import zip_longest
9 | import numpy as np
10 |
11 |
12 | # def getbone(seq, boneindex):
13 | # bs = np.shape(seq)[0]
14 | # ss = np.shape(seq)[1]
15 | # seq = np.reshape(seq,(bs*ss,-1,3))
16 | # bone = []
17 | # for index in boneindex:
18 | # bone.append(seq[:,index[0]] - seq[:,index[1]])
19 | # bone = np.stack(bone,1)
20 | # bone = np.power(np.power(bone,2).sum(2),0.5)
21 | # bone = np.reshape(bone, (bs,ss,np.shape(bone)[1]))
22 | # return bone
23 |
24 | class ChunkedGenerator:
25 | """
26 | Batched data generator, used for training.
27 | The sequences are split into equal-length chunks and padded as necessary.
28 |
29 | Arguments:
30 | batch_size -- the batch size to use for training
31 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
32 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
33 | poses_2d -- list of input 2D keypoints, one element for each video
34 | chunk_length -- number of output frames to predict for each training example (usually 1)
35 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
36 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
37 | shuffle -- randomly shuffle the dataset before each epoch
38 | random_seed -- initial seed to use for the random generator
39 | augment -- augment the dataset by flipping poses horizontally
40 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
41 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
42 | """
43 | def __init__(self, batch_size, cameras, poses_3d, poses_2d,
44 | chunk_length, pad=0, causal_shift=0,
45 | shuffle=True, random_seed=1234,
46 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
47 | endless=False):
48 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
49 | assert cameras is None or len(cameras) == len(poses_2d)
50 |
51 | # Build lineage info
52 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples
53 | for i in range(len(poses_2d)):
54 | assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
55 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
56 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
57 | bounds = np.arange(n_chunks+1)*chunk_length - offset
58 | augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
59 | pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
60 | if augment:
61 | pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], ~augment_vector)
62 |
63 | # Initialize buffers
64 | if cameras is not None:
65 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
66 | if poses_3d is not None:
67 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
68 | self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
69 |
70 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size
71 | self.batch_size = batch_size
72 | self.random = np.random.RandomState(random_seed)
73 | self.pairs = pairs
74 | self.shuffle = shuffle
75 | self.pad = pad
76 | self.causal_shift = causal_shift
77 | self.endless = endless
78 | self.state = None
79 |
80 | self.cameras = cameras
81 | self.poses_3d = poses_3d
82 | self.poses_2d = poses_2d
83 |
84 | self.augment = augment
85 | self.kps_left = kps_left
86 | self.kps_right = kps_right
87 | self.joints_left = joints_left
88 | self.joints_right = joints_right
89 |
90 | def num_frames(self):
91 | return self.num_batches * self.batch_size
92 |
93 | def random_state(self):
94 | return self.random
95 |
96 | def set_random_state(self, random):
97 | self.random = random
98 |
99 | def augment_enabled(self):
100 | return self.augment
101 |
102 | def next_pairs(self):
103 | if self.state is None:
104 | if self.shuffle:
105 | pairs = self.random.permutation(self.pairs)
106 | else:
107 | pairs = self.pairs
108 | return 0, pairs
109 | else:
110 | return self.state
111 |
112 | def next_epoch(self):
113 | enabled = True
114 | while enabled:
115 | start_idx, pairs = self.next_pairs()
116 | for b_i in range(start_idx, self.num_batches):
117 | chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size]
118 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
119 | start_2d = start_3d - self.pad - self.causal_shift
120 | end_2d = end_3d + self.pad - self.causal_shift
121 |
122 | # 2D poses
123 | seq_2d = self.poses_2d[seq_i]
124 | low_2d = max(start_2d, 0)
125 | high_2d = min(end_2d, seq_2d.shape[0])
126 | pad_left_2d = low_2d - start_2d
127 | pad_right_2d = end_2d - high_2d
128 | if pad_left_2d != 0 or pad_right_2d != 0:
129 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
130 | else:
131 | self.batch_2d[i] = seq_2d[low_2d:high_2d]
132 |
133 | if flip:
134 | # Flip 2D keypoints
135 | self.batch_2d[i, :, :, 0] *= -1
136 | self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]
137 |
138 | # 3D poses
139 | if self.poses_3d is not None:
140 | seq_3d = self.poses_3d[seq_i]
141 | low_3d = max(start_3d, 0)
142 | high_3d = min(end_3d, seq_3d.shape[0])
143 | pad_left_3d = low_3d - start_3d
144 | pad_right_3d = end_3d - high_3d
145 | if pad_left_3d != 0 or pad_right_3d != 0:
146 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
147 | else:
148 | self.batch_3d[i] = seq_3d[low_3d:high_3d]
149 |
150 | if flip:
151 | # Flip 3D joints
152 | self.batch_3d[i, :, :, 0] *= -1
153 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \
154 | self.batch_3d[i, :, self.joints_right + self.joints_left]
155 |
156 | # Cameras
157 | if self.cameras is not None:
158 | self.batch_cam[i] = self.cameras[seq_i]
159 | if flip:
160 | # Flip horizontal distortion coefficients
161 | self.batch_cam[i, 2] *= -1
162 | self.batch_cam[i, 7] *= -1
163 |
164 | if self.endless:
165 | self.state = (b_i + 1, pairs)
166 | if self.poses_3d is None and self.cameras is None:
167 | yield None, None, self.batch_2d[:len(chunks)]
168 | elif self.poses_3d is not None and self.cameras is None:
169 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
170 | elif self.poses_3d is None:
171 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
172 | else:
173 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
174 |
175 | if self.endless:
176 | self.state = None
177 | else:
178 | enabled = False
179 |
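# Construction sketch (argument values below are illustrative assumptions, not
# the training defaults): poses_2d is a list of (frames, joints, 2) arrays and
# poses_3d a matching list of (frames, joints, 3) arrays. With chunk_length=1
# and pad=40, each yielded 2D batch spans 2*40 + 1 = 81 input frames.
#
#   gen = ChunkedGenerator(512, None, poses_3d, poses_2d, chunk_length=1,
#                          pad=40, shuffle=True, augment=False)
#   for cam, batch_3d, batch_2d in gen.next_epoch():
#       pass  # batch_2d: (<=512, 81, joints, 2), batch_3d: (<=512, 1, joints, 3)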
180 |
181 | class UnchunkedGenerator:
182 | """
183 | Non-batched data generator, used for testing.
184 | Sequences are returned one at a time (i.e. batch size = 1), without chunking.
185 |
186 | If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2),
187 | the second of which is a mirrored version of the first.
188 |
189 | Arguments:
190 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
191 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
192 | poses_2d -- list of input 2D keypoints, one element for each video
193 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
194 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
195 | augment -- augment the dataset by flipping poses horizontally
196 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
197 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
198 | """
199 |
200 | def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0,
201 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None):
202 | assert poses_3d is None or len(poses_3d) == len(poses_2d)
203 | assert cameras is None or len(cameras) == len(poses_2d)
204 |
205 | self.augment = False
206 | self.kps_left = kps_left
207 | self.kps_right = kps_right
208 | self.joints_left = joints_left
209 | self.joints_right = joints_right
210 |
211 | self.pad = pad
212 | self.causal_shift = causal_shift
213 | self.cameras = [] if cameras is None else cameras
214 | self.poses_3d = [] if poses_3d is None else poses_3d
215 | self.poses_2d = poses_2d
216 |
217 | def num_frames(self):
218 | count = 0
219 | for p in self.poses_2d:
220 | count += p.shape[0]
221 | return count
222 |
223 | def augment_enabled(self):
224 | return self.augment
225 |
226 | def set_augment(self, augment):
227 | self.augment = augment
228 |
229 | def next_epoch(self):
230 | for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d):
231 | batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0)
232 | batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0)
233 | batch_2d = np.expand_dims(np.pad(seq_2d,
234 | ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
235 | 'edge'), axis=0)
236 | if self.augment:
237 | # Append flipped version
238 | if batch_cam is not None:
239 | batch_cam = np.concatenate((batch_cam, batch_cam), axis=0)
240 | batch_cam[1, 2] *= -1
241 | batch_cam[1, 7] *= -1
242 |
243 | if batch_3d is not None:
244 | batch_3d = np.concatenate((batch_3d, batch_3d), axis=0)
245 | batch_3d[1, :, :, 0] *= -1
246 | batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left]
247 |
248 | batch_2d = np.concatenate((batch_2d, batch_2d), axis=0)
249 | batch_2d[1, :, :, 0] *= -1
250 | batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left]
251 |
252 | yield batch_cam, batch_3d, batch_2d
--------------------------------------------------------------------------------
/common/h36m_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import copy
10 | from common.skeleton import Skeleton
11 | from common.mocap_dataset import MocapDataset
12 | from common.camera import normalize_screen_coordinates, image_coordinates
13 |
14 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
15 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
16 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
17 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
18 |
19 | h36m_cameras_intrinsic_params = [
20 | {
21 | 'id': '54138969',
22 | 'center': [512.54150390625, 515.4514770507812],
23 | 'focal_length': [1145.0494384765625, 1143.7811279296875],
24 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043],
25 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235],
26 | 'res_w': 1000,
27 | 'res_h': 1002,
28 | 'azimuth': 70, # Only used for visualization
29 | },
30 | {
31 | 'id': '55011271',
32 | 'center': [508.8486328125, 508.0649108886719],
33 | 'focal_length': [1149.6756591796875, 1147.5916748046875],
34 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665],
35 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233],
36 | 'res_w': 1000,
37 | 'res_h': 1000,
38 | 'azimuth': -70, # Only used for visualization
39 | },
40 | {
41 | 'id': '58860488',
42 | 'center': [519.8158569335938, 501.40264892578125],
43 | 'focal_length': [1149.1407470703125, 1148.7989501953125],
44 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427],
45 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998],
46 | 'res_w': 1000,
47 | 'res_h': 1000,
48 | 'azimuth': 110, # Only used for visualization
49 | },
50 | {
51 | 'id': '60457274',
52 | 'center': [514.9682006835938, 501.88201904296875],
53 | 'focal_length': [1145.5113525390625, 1144.77392578125],
54 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783],
55 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643],
56 | 'res_w': 1000,
57 | 'res_h': 1002,
58 | 'azimuth': -110, # Only used for visualization
59 | },
60 | ]
61 |
62 | h36m_cameras_extrinsic_params = {
63 | 'S1': [
64 | {
65 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
66 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
67 | },
68 | {
69 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205],
70 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375],
71 | },
72 | {
73 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696],
74 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375],
75 | },
76 | {
77 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435],
78 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125],
79 | },
80 | ],
81 | 'S2': [
82 | {},
83 | {},
84 | {},
85 | {},
86 | ],
87 | 'S3': [
88 | {},
89 | {},
90 | {},
91 | {},
92 | ],
93 | 'S4': [
94 | {},
95 | {},
96 | {},
97 | {},
98 | ],
99 | 'S5': [
100 | {
101 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332],
102 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875],
103 | },
104 | {
105 | 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915],
106 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125],
107 | },
108 | {
109 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576],
110 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875],
111 | },
112 | {
113 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092],
114 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125],
115 | },
116 | ],
117 | 'S6': [
118 | {
119 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938],
120 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875],
121 | },
122 | {
123 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428],
124 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375],
125 | },
126 | {
127 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334],
128 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125],
129 | },
130 | {
131 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802],
132 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625],
133 | },
134 | ],
135 | 'S7': [
136 | {
137 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778],
138 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625],
139 | },
140 | {
141 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508],
142 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125],
143 | },
144 | {
145 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278],
146 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375],
147 | },
148 | {
149 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523],
150 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125],
151 | },
152 | ],
153 | 'S8': [
154 | {
155 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773],
156 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375],
157 | },
158 | {
159 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615],
160 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125],
161 | },
162 | {
163 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755],
164 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375],
165 | },
166 | {
167 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708],
168 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875],
169 | },
170 | ],
171 | 'S9': [
172 | {
173 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243],
174 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625],
175 | },
176 | {
177 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801],
178 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125],
179 | },
180 | {
181 | 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915],
182 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125],
183 | },
184 | {
185 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822],
186 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625],
187 | },
188 | ],
189 | 'S11': [
190 | {
191 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467],
192 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125],
193 | },
194 | {
195 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085],
196 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125],
197 | },
198 | {
199 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407],
200 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625],
201 | },
202 | {
203 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809],
204 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125],
205 | },
206 | ],
207 | }
208 |
209 | class Human36mDataset(MocapDataset):
210 | def __init__(self, path, remove_static_joints=True):
211 | super().__init__(fps=50, skeleton=h36m_skeleton)
212 |
213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params)
214 | for cameras in self._cameras.values():
215 | for i, cam in enumerate(cameras):
216 | cam.update(h36m_cameras_intrinsic_params[i])
217 | for k, v in cam.items():
218 | if k not in ['id', 'res_w', 'res_h']:
219 | cam[k] = np.array(v, dtype='float32')
220 |
221 | # Normalize camera frame
222 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32')
223 | cam['focal_length'] = cam['focal_length']/cam['res_w']*2
224 | if 'translation' in cam:
225 | cam['translation'] = cam['translation']/1000 # mm to meters
226 |
227 | # Add intrinsic parameters vector
228 | cam['intrinsic'] = np.concatenate((cam['focal_length'],
229 | cam['center'],
230 | cam['radial_distortion'],
231 | cam['tangential_distortion']))
232 |
233 | # Load serialized dataset
234 | data = np.load(path, allow_pickle=True)['positions_3d'].item()
235 |
236 | self._data = {}
237 | for subject, actions in data.items():
238 | self._data[subject] = {}
239 | for action_name, positions in actions.items():
240 | self._data[subject][action_name] = {
241 | 'positions': positions,
242 | 'cameras': self._cameras[subject],
243 | }
244 |
245 | if remove_static_joints:
246 | # Bring the skeleton to 17 joints instead of the original 32
247 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])
248 |
249 | # Rewire shoulders to the correct parents
250 | self._skeleton._parents[11] = 8
251 | self._skeleton._parents[14] = 8
252 |
253 | def supports_semi_supervised(self):
254 | return True
255 |
--------------------------------------------------------------------------------
/common/humaneva_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import copy
10 | from common.skeleton import Skeleton
11 | from common.mocap_dataset import MocapDataset
12 | from common.camera import normalize_screen_coordinates, image_coordinates
13 |
14 | humaneva_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1],
15 | joints_left=[2, 3, 4, 8, 9, 10],
16 | joints_right=[5, 6, 7, 11, 12, 13])
17 |
18 | humaneva_cameras_intrinsic_params = [
19 | {
20 | 'id': 'C1',
21 | 'res_w': 640,
22 | 'res_h': 480,
23 | 'azimuth': 0, # Only used for visualization
24 | },
25 | {
26 | 'id': 'C2',
27 | 'res_w': 640,
28 | 'res_h': 480,
29 | 'azimuth': -90, # Only used for visualization
30 | },
31 | {
32 | 'id': 'C3',
33 | 'res_w': 640,
34 | 'res_h': 480,
35 | 'azimuth': 90, # Only used for visualization
36 | },
37 | ]
38 |
39 | humaneva_cameras_extrinsic_params = {
40 | 'S1': [
41 | {
42 | 'orientation': [0.424207, -0.4983646, -0.5802981, 0.4847012],
43 | 'translation': [4062.227, 663.2477, 1528.397],
44 | },
45 | {
46 | 'orientation': [0.6503354, -0.7481602, -0.0919284, 0.0941766],
47 | 'translation': [844.8131, -3805.2092, 1504.9929],
48 | },
49 | {
50 | 'orientation': [0.0664734, -0.0690535, 0.7416416, -0.6639132],
51 | 'translation': [-797.67377, 3916.3174, 1433.6602],
52 | },
53 | ],
54 | 'S2': [
55 | {
56 | 'orientation': [ 0.4214752, -0.4961493, -0.5838273, 0.4851187 ],
57 | 'translation': [ 4112.9121, 626.4929, 1545.2988],
58 | },
59 | {
60 | 'orientation': [ 0.6501393, -0.7476588, -0.0954617, 0.0959808 ],
61 | 'translation': [ 923.5740, -3877.9243, 1504.5518],
62 | },
63 | {
64 | 'orientation': [ 0.0699353, -0.0712403, 0.7421637, -0.662742 ],
65 | 'translation': [ -781.4915, 3838.8853, 1444.9929],
66 | },
67 | ],
68 | 'S3': [
69 | {
70 | 'orientation': [ 0.424207, -0.4983646, -0.5802981, 0.4847012 ],
71 | 'translation': [ 4062.2271, 663.2477, 1528.3970],
72 | },
73 | {
74 | 'orientation': [ 0.6503354, -0.7481602, -0.0919284, 0.0941766 ],
75 | 'translation': [ 844.8131, -3805.2092, 1504.9929],
76 | },
77 | {
78 | 'orientation': [ 0.0664734, -0.0690535, 0.7416416, -0.6639132 ],
79 | 'translation': [ -797.6738, 3916.3174, 1433.6602],
80 | },
81 | ],
82 | 'S4': [
83 | {},
84 | {},
85 | {},
86 | ],
87 |
88 | }
89 |
90 | class HumanEvaDataset(MocapDataset):
91 | def __init__(self, path):
92 | super().__init__(fps=60, skeleton=humaneva_skeleton)
93 |
94 | self._cameras = copy.deepcopy(humaneva_cameras_extrinsic_params)
95 | for cameras in self._cameras.values():
96 | for i, cam in enumerate(cameras):
97 | cam.update(humaneva_cameras_intrinsic_params[i])
98 | for k, v in cam.items():
99 | if k not in ['id', 'res_w', 'res_h']:
100 | cam[k] = np.array(v, dtype='float32')
101 | if 'translation' in cam:
102 | cam['translation'] = cam['translation']/1000 # mm to meters
103 |
104 | for subject in list(self._cameras.keys()):
105 | data = self._cameras[subject]
106 | del self._cameras[subject]
107 | for prefix in ['Train/', 'Validate/', 'Unlabeled/Train/', 'Unlabeled/Validate/', 'Unlabeled/']:
108 | self._cameras[prefix + subject] = data
109 |
110 | # Load serialized dataset
111 | data = np.load(path, allow_pickle=True)['positions_3d'].item()
112 |
113 | self._data = {}
114 | for subject, actions in data.items():
115 | self._data[subject] = {}
116 | for action_name, positions in actions.items():
117 | self._data[subject][action_name] = {
118 | 'positions': positions,
119 | 'cameras': self._cameras[subject],
120 | }
121 |
--------------------------------------------------------------------------------
/common/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | def mpjpe(predicted, target):
12 | """
13 | Mean per-joint position error (i.e. mean Euclidean distance),
14 | often referred to as "Protocol #1" in many papers.
15 | """
16 | assert predicted.shape == target.shape
17 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1))
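# Usage note (the shapes are an assumption about the typical call, e.g. tensors
# of shape (batch, frames, joints, 3)): the error is the mean Euclidean distance
# over all joints and frames, so displacing every joint by a constant 10 mm
# yields an MPJPE of exactly 10 mm.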
18 |
19 | def weighted_mpjpe(predicted, target, w):
20 | """
21 | Weighted mean per-joint position error (i.e. mean Euclidean distance)
22 | """
23 | assert predicted.shape == target.shape
24 | assert w.shape[0] == predicted.shape[0]
25 | return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1))
26 |
27 | def p_mpjpe(predicted, target):
28 | """
29 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation),
30 | often referred to as "Protocol #2" in many papers.
31 | """
32 | assert predicted.shape == target.shape
33 |
34 | muX = np.mean(target, axis=1, keepdims=True)
35 | muY = np.mean(predicted, axis=1, keepdims=True)
36 |
37 | X0 = target - muX
38 | Y0 = predicted - muY
39 |
40 | normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True))
41 | normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True))
42 |
43 | X0 /= normX
44 | Y0 /= normY
45 |
46 | H = np.matmul(X0.transpose(0, 2, 1), Y0)
47 | U, s, Vt = np.linalg.svd(H)
48 | V = Vt.transpose(0, 2, 1)
49 | R = np.matmul(V, U.transpose(0, 2, 1))
50 |
51 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
52 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1))
53 | V[:, :, -1] *= sign_detR
54 | s[:, -1] *= sign_detR.flatten()
55 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation
56 |
57 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)
58 |
59 | a = tr * normX / normY # Scale
60 | t = muX - a*np.matmul(muY, R) # Translation
61 |
62 | # Perform rigid transformation on the input
63 | predicted_aligned = a*np.matmul(predicted, R) + t
64 |
65 | # Return MPJPE
66 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1))
67 |
68 | def n_mpjpe(predicted, target):
69 | """
70 | Normalized MPJPE (scale only), adapted from:
71 | https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py
72 | """
73 | assert predicted.shape == target.shape
74 |
75 | norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True)
76 | norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True)
77 | scale = norm_target / norm_predicted
78 | return mpjpe(scale * predicted, target)
79 |
80 | def weighted_bonelen_loss(predict_3d_length, gt_3d_length):
81 | loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean()
82 | return loss_length
83 |
84 | def weighted_boneratio_loss(predict_3d_length, gt_3d_length):
85 | loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean()
86 | return loss_length
87 |
88 | def mean_velocity_error(predicted, target):
89 | """
90 | Mean per-joint velocity error (i.e. mean Euclidean distance of the 1st derivative)
91 | """
92 | assert predicted.shape == target.shape
93 |
94 | velocity_predicted = np.diff(predicted, axis=0)
95 | velocity_target = np.diff(target, axis=0)
96 |
97 | return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape)-1))
--------------------------------------------------------------------------------
/common/mocap_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | from common.skeleton import Skeleton
10 |
11 | class MocapDataset:
12 | def __init__(self, fps, skeleton):
13 | self._skeleton = skeleton
14 | self._fps = fps
15 | self._data = None # Must be filled by subclass
16 | self._cameras = None # Must be filled by subclass
17 |
18 | def remove_joints(self, joints_to_remove):
19 | kept_joints = self._skeleton.remove_joints(joints_to_remove)
20 | for subject in self._data.keys():
21 | for action in self._data[subject].keys():
22 | s = self._data[subject][action]
23 | if 'positions' in s:
24 | s['positions'] = s['positions'][:, kept_joints]
25 |
26 |
27 | def __getitem__(self, key):
28 | return self._data[key]
29 |
30 | def subjects(self):
31 | return self._data.keys()
32 |
33 | def fps(self):
34 | return self._fps
35 |
36 | def skeleton(self):
37 | return self._skeleton
38 |
39 | def cameras(self):
40 | return self._cameras
41 |
42 | def supports_semi_supervised(self):
43 | # This method can be overridden
44 | return False
--------------------------------------------------------------------------------
/common/model_poseformer.py:
--------------------------------------------------------------------------------
1 | ## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
2 |
3 | import math
4 | import logging
5 | from functools import partial
6 | from collections import OrderedDict
7 | from einops import rearrange, repeat
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14 | from timm.models.helpers import load_pretrained
15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16 | from timm.models.registry import register_model
17 |
18 |
19 | class Mlp(nn.Module):
20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
21 | super().__init__()
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 | self.fc1 = nn.Linear(in_features, hidden_features)
25 | self.act = act_layer()
26 | self.fc2 = nn.Linear(hidden_features, out_features)
27 | self.drop = nn.Dropout(drop)
28 |
29 | def forward(self, x):
30 | x = self.fc1(x)
31 | x = self.act(x)
32 | x = self.drop(x)
33 | x = self.fc2(x)
34 | x = self.drop(x)
35 | return x
36 |
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
40 | super().__init__()
41 | self.num_heads = num_heads
42 | head_dim = dim // num_heads
43 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
44 | self.scale = qk_scale or head_dim ** -0.5
45 |
46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
47 | self.attn_drop = nn.Dropout(attn_drop)
48 | self.proj = nn.Linear(dim, dim)
49 | self.proj_drop = nn.Dropout(proj_drop)
50 |
51 | def forward(self, x):
52 | B, N, C = x.shape
53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
54 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
55 |
56 | attn = (q @ k.transpose(-2, -1)) * self.scale
57 | attn = attn.softmax(dim=-1)
58 | attn = self.attn_drop(attn)
59 |
60 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
61 | x = self.proj(x)
62 | x = self.proj_drop(x)
63 | return x
64 |
65 |
66 | class Block(nn.Module):
67 |
68 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
69 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
70 | super().__init__()
71 | self.norm1 = norm_layer(dim)
72 | self.attn = Attention(
73 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
74 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
75 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
76 | self.norm2 = norm_layer(dim)
77 | mlp_hidden_dim = int(dim * mlp_ratio)
78 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
79 |
80 | def forward(self, x):
81 | x = x + self.drop_path(self.attn(self.norm1(x)))
82 | x = x + self.drop_path(self.mlp(self.norm2(x)))
83 | return x
84 |
85 | class PoseTransformer(nn.Module):
86 | def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4,
87 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
88 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None):
89 | """ ##########hybrid_backbone=None, representation_size=None,
90 | Args:
91 | num_frame (int, tuple): input frame number
92 | num_joints (int, tuple): joints number
93 | in_chans (int): number of input channels, 2D joints have 2 channels: (x,y)
94 | embed_dim_ratio (int): embedding dimension ratio
95 | depth (int): depth of transformer
96 | num_heads (int): number of attention heads
97 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
98 | qkv_bias (bool): enable bias for qkv if True
99 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
100 | drop_rate (float): dropout rate
101 | attn_drop_rate (float): attention dropout rate
102 | drop_path_rate (float): stochastic depth rate
103 | norm_layer: (nn.Module): normalization layer
104 | """
105 | super().__init__()
106 |
107 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
108 | embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio
109 | out_dim = num_joints * 3 #### output dimension is num_joints * 3
110 |
111 | ### spatial patch embedding
112 | self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio)
113 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))
114 |
115 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim))
116 | self.pos_drop = nn.Dropout(p=drop_rate)
117 |
118 |
119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120 |
121 | self.Spatial_blocks = nn.ModuleList([
122 | Block(
123 | dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
124 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
125 | for i in range(depth)])
126 |
127 | self.blocks = nn.ModuleList([
128 | Block(
129 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
130 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
131 | for i in range(depth)])
132 |
133 | self.Spatial_norm = norm_layer(embed_dim_ratio)
134 | self.Temporal_norm = norm_layer(embed_dim)
135 |
136 | ####### An easy way to implement the weighted mean
137 | self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1)
138 |
139 | self.head = nn.Sequential(
140 | nn.LayerNorm(embed_dim),
141 | nn.Linear(embed_dim , out_dim),
142 | )
143 |
144 |
145 | def Spatial_forward_features(self, x):
146 | b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints
147 | x = rearrange(x, 'b c f p -> (b f) p c', )
148 |
149 | x = self.Spatial_patch_to_embedding(x)
150 | x += self.Spatial_pos_embed
151 | x = self.pos_drop(x)
152 |
153 | for blk in self.Spatial_blocks:
154 | x = blk(x)
155 |
156 | x = self.Spatial_norm(x)
157 | x = rearrange(x, '(b f) w c -> b f (w c)', f=f)
158 | return x
159 |
160 | def forward_features(self, x):
161 | b = x.shape[0]
162 | x += self.Temporal_pos_embed
163 | x = self.pos_drop(x)
164 | for blk in self.blocks:
165 | x = blk(x)
166 |
167 | x = self.Temporal_norm(x)
168 | ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame
169 | x = self.weighted_mean(x)
170 | x = x.view(b, 1, -1)
171 | return x
172 |
173 |
174 | def forward(self, x):
175 | x = x.permute(0, 3, 1, 2)
176 | b, _, _, p = x.shape
177 | ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data
178 | x = self.Spatial_forward_features(x)
179 | x = self.forward_features(x)
180 | x = self.head(x)
181 |
182 | x = x.view(b, 1, p, -1)
183 |
184 | return x
185 |
186 |
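# Instantiation sketch (the hyper-parameter values below are assumptions for
# illustration, not necessarily the released configuration): a 2D input of
# shape (batch, frames, joints, 2) maps to a centre-frame 3D prediction of
# shape (batch, 1, joints, 3).
#
#   model = PoseTransformer(num_frame=81, num_joints=17, in_chans=2,
#                           embed_dim_ratio=32, depth=4, num_heads=8)
#   out = model(torch.randn(2, 81, 17, 2))   # -> torch.Size([2, 1, 17, 3])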
--------------------------------------------------------------------------------
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 |
10 | def qrot(q, v):
11 | """
12 | Rotate vector(s) v about the rotation described by quaternion(s) q.
13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
14 | where * denotes any number of dimensions.
15 | Returns a tensor of shape (*, 3).
16 | """
17 | assert q.shape[-1] == 4
18 | assert v.shape[-1] == 3
19 | assert q.shape[:-1] == v.shape[:-1]
20 |
21 | qvec = q[..., 1:]
22 | uv = torch.cross(qvec, v, dim=len(q.shape)-1)
23 | uuv = torch.cross(qvec, uv, dim=len(q.shape)-1)
24 | return (v + 2 * (q[..., :1] * uv + uuv))
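# Sanity check (illustrative): the identity quaternion (1, 0, 0, 0) has a zero
# vector part, so uv and uuv vanish and qrot returns v unchanged.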
25 |
26 |
27 | def qinverse(q, inplace=False):
28 | # We assume the quaternion to be normalized
29 | if inplace:
30 | q[..., 1:] *= -1
31 | return q
32 | else:
33 | w = q[..., :1]
34 | xyz = q[..., 1:]
35 | return torch.cat((w, -xyz), dim=len(q.shape)-1)
--------------------------------------------------------------------------------
/common/skeleton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 |
10 | class Skeleton:
11 | def __init__(self, parents, joints_left, joints_right):
12 | assert len(joints_left) == len(joints_right)
13 |
14 | self._parents = np.array(parents)
15 | self._joints_left = joints_left
16 | self._joints_right = joints_right
17 | self._compute_metadata()
18 |
19 | def num_joints(self):
20 | return len(self._parents)
21 |
22 | def parents(self):
23 | return self._parents
24 |
25 | def has_children(self):
26 | return self._has_children
27 |
28 | def children(self):
29 | return self._children
30 |
31 | def remove_joints(self, joints_to_remove):
32 | """
33 | Remove the joints specified in 'joints_to_remove'.
34 | """
35 | valid_joints = []
36 | for joint in range(len(self._parents)):
37 | if joint not in joints_to_remove:
38 | valid_joints.append(joint)
39 |
40 | for i in range(len(self._parents)):
41 | while self._parents[i] in joints_to_remove:
42 | self._parents[i] = self._parents[self._parents[i]]
43 |
44 | index_offsets = np.zeros(len(self._parents), dtype=int)
45 | new_parents = []
46 | for i, parent in enumerate(self._parents):
47 | if i not in joints_to_remove:
48 | new_parents.append(parent - index_offsets[parent])
49 | else:
50 | index_offsets[i:] += 1
51 | self._parents = np.array(new_parents)
52 |
53 |
54 | if self._joints_left is not None:
55 | new_joints_left = []
56 | for joint in self._joints_left:
57 | if joint in valid_joints:
58 | new_joints_left.append(joint - index_offsets[joint])
59 | self._joints_left = new_joints_left
60 | if self._joints_right is not None:
61 | new_joints_right = []
62 | for joint in self._joints_right:
63 | if joint in valid_joints:
64 | new_joints_right.append(joint - index_offsets[joint])
65 | self._joints_right = new_joints_right
66 |
67 | self._compute_metadata()
68 |
69 | return valid_joints
70 |
71 | def joints_left(self):
72 | return self._joints_left
73 |
74 | def joints_right(self):
75 | return self._joints_right
76 |
77 | def _compute_metadata(self):
78 | self._has_children = np.zeros(len(self._parents)).astype(bool)
79 | for i, parent in enumerate(self._parents):
80 | if parent != -1:
81 | self._has_children[parent] = True
82 |
83 | self._children = []
84 | for i, parent in enumerate(self._parents):
85 | self._children.append([])
86 | for i, parent in enumerate(self._parents):
87 | if parent != -1:
88 | self._children[parent].append(i)
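89 |
90 |
91 | # Usage sketch (illustrative): a toy 4-joint skeleton with root 0, a left chain
92 | # 0 -> 1 -> 2 and a right joint 3. remove_joints() drops a joint, reattaches its
93 | # children to its parent, and re-indexes the remaining joints.
94 | if __name__ == '__main__':
95 |     sk = Skeleton(parents=[-1, 0, 1, 0], joints_left=[1], joints_right=[3])
96 |     assert sk.num_joints() == 4
97 |     assert sk.children() == [[1, 3], [2], [], []]
98 |     assert sk.remove_joints([2]) == [0, 1, 3]  # indices of the kept joints
99 |     assert list(sk.parents()) == [-1, 0, 0]
100 |     assert sk.joints_left() == [1] and sk.joints_right() == [2]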
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 | import hashlib
11 |
12 | def wrap(func, *args, unsqueeze=False):
13 | """
14 | Wrap a torch function so it can be called with NumPy arrays.
15 | Input and return types are seamlessly converted.
16 | """
17 |
18 | # Convert input types where applicable
19 | args = list(args)
20 | for i, arg in enumerate(args):
21 | if type(arg) == np.ndarray:
22 | args[i] = torch.from_numpy(arg)
23 | if unsqueeze:
24 | args[i] = args[i].unsqueeze(0)
25 |
26 | result = func(*args)
27 |
28 | # Convert output types where applicable
29 | if isinstance(result, tuple):
30 | result = list(result)
31 | for i, res in enumerate(result):
32 | if type(res) == torch.Tensor:
33 | if unsqueeze:
34 | res = res.squeeze(0)
35 | result[i] = res.numpy()
36 | return tuple(result)
37 | elif type(result) == torch.Tensor:
38 | if unsqueeze:
39 | result = result.squeeze(0)
40 | return result.numpy()
41 | else:
42 | return result
43 |
44 | def deterministic_random(min_value, max_value, data):
45 | digest = hashlib.sha256(data.encode()).digest()
46 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
47 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value
48 |
49 | def load_pretrained_weights(model, checkpoint):
50 |     """Load pretrained weights into the model.
51 | Incompatible layers (unmatched in name or size) will be ignored
52 | Args:
53 | - model (nn.Module): network model, which must not be nn.DataParallel
54 |     - checkpoint (dict): a loaded checkpoint (a state_dict, or a dict containing one under 'state_dict')
55 | """
56 | import collections
57 | if 'state_dict' in checkpoint:
58 | state_dict = checkpoint['state_dict']
59 | else:
60 | state_dict = checkpoint
61 | model_dict = model.state_dict()
62 | new_state_dict = collections.OrderedDict()
63 | matched_layers, discarded_layers = [], []
64 | for k, v in state_dict.items():
65 | # If the pretrained state_dict was saved as nn.DataParallel,
66 | # keys would contain "module.", which should be ignored.
67 | if k.startswith('module.'):
68 | k = k[7:]
69 | if k in model_dict and model_dict[k].size() == v.size():
70 | new_state_dict[k] = v
71 | matched_layers.append(k)
72 | else:
73 | discarded_layers.append(k)
74 | # new_state_dict.requires_grad = False
75 | model_dict.update(new_state_dict)
76 |
77 | model.load_state_dict(model_dict)
78 | print('load_weight', len(matched_layers))
79 | # model.state_dict(model_dict).requires_grad = False
80 | return model
81 |
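82 |
83 | # Usage sketch (illustrative): wrap() runs a torch function on NumPy inputs and
84 | # converts the result back to NumPy; deterministic_random() always maps the same
85 | # string to the same value.
86 | if __name__ == '__main__':
87 |     out = wrap(torch.sqrt, np.array([1.0, 4.0, 9.0]))
88 |     assert isinstance(out, np.ndarray) and np.allclose(out, [1.0, 2.0, 3.0])
89 |     assert deterministic_random(0, 100, 'S1') == deterministic_random(0, 100, 'S1')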
--------------------------------------------------------------------------------
/common/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import matplotlib
9 |
10 | matplotlib.use('Agg')
11 |
12 |
13 | import matplotlib.pyplot as plt
14 | from matplotlib.animation import FuncAnimation, writers
15 | from mpl_toolkits.mplot3d import Axes3D
16 | import numpy as np
17 | import subprocess as sp
18 | import cv2
19 |
20 |
21 | def get_resolution(filename):
22 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
23 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename]
24 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
25 | for line in pipe.stdout:
26 | w, h = line.decode().strip().split(',')
27 | return int(w), int(h)
28 |
29 |
30 | def get_fps(filename):
31 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
32 | '-show_entries', 'stream=r_frame_rate', '-of', 'csv=p=0', filename]
33 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
34 | for line in pipe.stdout:
35 | a, b = line.decode().strip().split('/')
36 | return int(a) / int(b)
37 |
38 |
39 | def read_video(filename, skip=0, limit=-1):
40 |     # w, h = get_resolution(filename)  # NOTE: the frame size is hard-coded below rather than probed from the video
41 | w = 1000
42 | h = 1002
43 |
44 | command = ['ffmpeg',
45 | '-i', filename,
46 | '-f', 'image2pipe',
47 | '-pix_fmt', 'rgb24',
48 | '-vsync', '0',
49 | '-vcodec', 'rawvideo', '-']
50 |
51 | i = 0
52 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
53 | while True:
54 | data = pipe.stdout.read(w * h * 3)
55 | if not data:
56 | break
57 | i += 1
58 | if i > limit and limit != -1:
59 | continue
60 | if i > skip:
61 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3))
62 |
63 |
64 | def downsample_tensor(X, factor):
65 | length = X.shape[0] // factor * factor
66 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1)
67 |
68 |
69 | def render_animation(keypoints, keypoints_metadata, poses, skeleton, fps, bitrate, azim, output, viewport,
70 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0):
71 |     """
72 |     Render an animation and save it to 'output'.
73 |     The supported output formats are:
74 |      -- 'filename.mp4': render and export the animation as an h264 video
75 |         (requires ffmpeg).
76 |      -- 'filename.gif': render and export the animation as a gif file
77 |         (requires imagemagick).
78 |     Any other value for 'output' raises a ValueError.
79 |     """
80 | plt.ioff()
81 | fig = plt.figure(figsize=(size * (1 + len(poses)), size))
82 | ax_in = fig.add_subplot(1, 1 + len(poses), 1)
83 | ax_in.get_xaxis().set_visible(False)
84 | ax_in.get_yaxis().set_visible(False)
85 | ax_in.set_axis_off()
86 | ax_in.set_title('Input')
87 |
88 | ax_3d = []
89 | lines_3d = []
90 | trajectories = []
91 | radius = 1.7
92 | for index, (title, data) in enumerate(poses.items()):
93 | ax = fig.add_subplot(1, 1 + len(poses), index + 2, projection='3d')
94 | ax.view_init(elev=15., azim=azim)
95 | ax.set_xlim3d([-radius / 2, radius / 2])
96 | ax.set_zlim3d([0, radius])
97 | ax.set_ylim3d([-radius / 2, radius / 2])
98 | try:
99 | ax.set_aspect('equal')
100 | except NotImplementedError:
101 | ax.set_aspect('auto')
102 | ax.set_xticklabels([])
103 | ax.set_yticklabels([])
104 | ax.set_zticklabels([])
105 | ax.dist = 7.5
106 | ax.set_title(title) # , pad=35
107 | ax_3d.append(ax)
108 | lines_3d.append([])
109 | trajectories.append(data[:, 0, [0, 1]])
110 | poses = list(poses.values())
111 |
112 | # Decode video
113 | if input_video_path is None:
114 | # Black background
115 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8')
116 | else:
117 | # Load video using ffmpeg
118 | all_frames = []
119 | for f in read_video(input_video_path, skip=input_video_skip, limit=limit):
120 | all_frames.append(f)
121 | effective_length = min(keypoints.shape[0], len(all_frames))
122 | all_frames = all_frames[:effective_length]
123 |
124 | keypoints = keypoints[input_video_skip:] # todo remove
125 | for idx in range(len(poses)):
126 | poses[idx] = poses[idx][input_video_skip:]
127 |
128 | if fps is None:
129 | fps = get_fps(input_video_path)
130 |
131 | if downsample > 1:
132 | keypoints = downsample_tensor(keypoints, downsample)
133 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8')
134 | for idx in range(len(poses)):
135 | poses[idx] = downsample_tensor(poses[idx], downsample)
136 | trajectories[idx] = downsample_tensor(trajectories[idx], downsample)
137 | fps /= downsample
138 |
139 | initialized = False
140 | image = None
141 | lines = []
142 | points = None
143 |
144 | if limit < 1:
145 | limit = len(all_frames)
146 | else:
147 | limit = min(limit, len(all_frames))
148 |
149 | parents = skeleton.parents()
150 |
151 | def update_video(i):
152 | nonlocal initialized, image, lines, points
153 |
154 | for n, ax in enumerate(ax_3d):
155 | ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]])
156 | ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]])
157 |
158 | # Update 2D poses
159 | joints_right_2d = keypoints_metadata['keypoints_symmetry'][1]
160 | colors_2d = np.full(keypoints.shape[1], 'black')
161 | colors_2d[joints_right_2d] = 'red'
162 | if not initialized:
163 | image = ax_in.imshow(all_frames[i], aspect='equal')
164 |
165 | for j, j_parent in enumerate(parents):
166 | if j_parent == -1:
167 | continue
168 |
169 | if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
170 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition)
171 | lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
172 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink'))
173 |
174 | col = 'red' if j in skeleton.joints_right() else 'black'
175 | for n, ax in enumerate(ax_3d):
176 | pos = poses[n][i]
177 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]],
178 | [pos[j, 1], pos[j_parent, 1]],
179 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col))
180 |
181 | points = ax_in.scatter(*keypoints[i].T, 10, color=colors_2d, edgecolors='white', zorder=10)
182 |
183 | initialized = True
184 | else:
185 | image.set_data(all_frames[i])
186 |
187 | for j, j_parent in enumerate(parents):
188 | if j_parent == -1:
189 | continue
190 |
191 | if len(parents) == keypoints.shape[1] and keypoints_metadata['layout_name'] != 'coco':
192 | lines[j - 1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
193 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]])
194 |
195 | for n, ax in enumerate(ax_3d):
196 | pos = poses[n][i]
197 | lines_3d[n][j - 1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]]))
198 | lines_3d[n][j - 1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]]))
199 | lines_3d[n][j - 1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z')
200 |
201 | points.set_offsets(keypoints[i])
202 |
203 | print('{}/{} '.format(i, limit), end='\r')
204 |
205 | fig.tight_layout()
206 |
207 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000 / fps, repeat=False)
208 | if output.endswith('.mp4'):
209 | Writer = writers['ffmpeg']
210 | writer = Writer(fps=fps, metadata={}, bitrate=bitrate)
211 | anim.save(output, writer=writer)
212 | elif output.endswith('.gif'):
213 | anim.save(output, dpi=80, writer='imagemagick')
214 | else:
215 | raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')
216 | plt.close()
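217 |
218 |
219 | # A quick sanity sketch (illustrative): downsample_tensor() averages every
220 | # 'factor' consecutive frames and drops any remainder.
221 | if __name__ == '__main__':
222 |     X = np.arange(10, dtype='float64').reshape(10, 1)
223 |     assert np.allclose(downsample_tensor(X, 2), [[0.5], [2.5], [4.5], [6.5], [8.5]])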
--------------------------------------------------------------------------------
/data/README.MD:
--------------------------------------------------------------------------------
1 | Put the dataset files in this directory.
2 |
3 | Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (./data directory).
4 |
5 |
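6 | A minimal sketch of the expected layout (the file names follow the loading code in `run_poseformer.py`):
7 |
8 | ```
9 | data/
10 | ├── data_3d_h36m.npz                    # 3D ground-truth poses (Human3.6M)
11 | ├── data_2d_h36m_cpn_ft_h36m_dbb.npz    # CPN-detected 2D poses (-k cpn_ft_h36m_dbb)
12 | └── data_2d_h36m_gt.npz                 # ground-truth 2D poses (-k gt)
13 | ```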
--------------------------------------------------------------------------------
/figure/H3.6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/H3.6.gif
--------------------------------------------------------------------------------
/figure/PoseFormer.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/PoseFormer.gif
--------------------------------------------------------------------------------
/figure/wild.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zczcwh/PoseFormer/a908b29adca354b2e0acd36606229ca19b6a5e3c/figure/wild.gif
--------------------------------------------------------------------------------
/poseformer.yml:
--------------------------------------------------------------------------------
1 | name: pose2
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _tflow_select=2.3.0=mkl
10 | - absl-py=0.10.0=py38h32f6830_1
11 | - aiohttp=3.7.2=py38h1e0a361_0
12 | - astunparse=1.6.3=py_0
13 | - async-timeout=3.0.1=py_1000
14 | - attrs=20.2.0=pyh9f0ad1d_0
15 | - blas=1.0=mkl
16 | - blinker=1.4=py_1
17 | - brotlipy=0.7.0=py38h7b6447c_1000
18 | - bzip2=1.0.8=h7b6447c_0
19 | - c-ares=1.16.1=h516909a_3
20 | - ca-certificates=2021.1.19=h06a4308_0
21 | - cachetools=4.1.1=py_0
22 | - cairo=1.16.0=h3fc0475_1005
23 | - cdflib=0.3.18=py_0
24 | - certifi=2020.12.5=py38h06a4308_0
25 | - cffi=1.14.0=py38h2e261b9_0
26 | - chardet=3.0.4=py38_1003
27 | - click=7.1.2=pyh9f0ad1d_0
28 | - cryptography=3.1.1=py38h1ba5d50_0
29 | - cudatoolkit=11.0.221=h6bb024c_0
30 | - cycler=0.10.0=py_2
31 | - cython=0.29.21=py38he6710b0_0
32 | - dbus=1.13.16=hb2f20db_0
33 | - einops=0.3.0=py_0
34 | - expat=2.2.9=he6710b0_2
35 | - fontconfig=2.13.1=h1056068_1002
36 | - freetype=2.10.2=h5ab3b9f_0
37 | - gast=0.3.3=py_0
38 | - glib=2.63.1=h5a9c865_0
39 | - gmp=6.2.0=he1b5a44_2
40 | - gnutls=3.6.13=h79a8f9a_0
41 | - google-auth=1.23.0=pyhd8ed1ab_0
42 | - google-auth-oauthlib=0.4.1=py_2
43 | - google-pasta=0.2.0=py_0
44 | - graphite2=1.3.14=h23475e2_0
45 | - grpcio=1.33.2=py38heead2fc_0
46 | - gst-plugins-base=1.14.5=h0935bb2_2
47 | - gstreamer=1.14.5=h36ae1b5_2
48 | - h5py=2.10.0=py38hd6299e0_1
49 | - harfbuzz=2.7.2=hee91db6_0
50 | - hdf5=1.10.6=hb1b8bf9_0
51 | - icu=67.1=he1b5a44_0
52 | - idna=2.10=py_0
53 | - importlib-metadata=2.0.0=py_1
54 | - intel-openmp=2020.2=254
55 | - jasper=1.900.1=hd497a04_4
56 | - jpeg=9d=h516909a_0
57 | - json_tricks=3.15.3=pyh9f0ad1d_0
58 | - keras-preprocessing=1.1.0=py_1
59 | - kiwisolver=1.2.0=py38hbf85e49_0
60 | - krb5=1.17.1=h173b8e3_0
61 | - lame=3.100=h7b6447c_0
62 | - lcms2=2.11=h396b838_0
63 | - ld_impl_linux-64=2.33.1=h53a641e_7
64 | - libblas=3.8.0=16_mkl
65 | - libcblas=3.8.0=16_mkl
66 | - libclang=10.0.1=default_hb85057a_2
67 | - libedit=3.1.20191231=h14c3975_1
68 | - libevent=2.1.10=hcdb4288_2
69 | - libffi=3.2.1=hf484d3e_1007
70 | - libgcc-ng=9.1.0=hdf63c60_0
71 | - libgfortran-ng=7.3.0=hdf63c60_0
72 | - libiconv=1.16=h516909a_0
73 | - liblapack=3.8.0=16_mkl
74 | - liblapacke=3.8.0=16_mkl
75 | - libllvm10=10.0.1=hbcb73fb_5
76 | - libpng=1.6.37=hbc83047_0
77 | - libpq=12.3=h5513abc_0
78 | - libprotobuf=3.13.0=hd408876_0
79 | - libsodium=1.0.18=h7b6447c_0
80 | - libstdcxx-ng=9.1.0=hdf63c60_0
81 | - libtiff=4.1.0=h2733197_1
82 | - libuuid=2.32.1=h14c3975_1000
83 | - libuv=1.40.0=h7b6447c_0
84 | - libwebp-base=1.1.0=h7b6447c_3
85 | - libxcb=1.14=h7b6447c_0
86 | - libxkbcommon=0.10.0=he1b5a44_0
87 | - libxml2=2.9.10=h68273f3_2
88 | - lz4-c=1.9.2=he6710b0_1
89 | - markdown=3.3.3=pyh9f0ad1d_0
90 | - matplotlib=3.3.2=0
91 | - matplotlib-base=3.3.2=py38h91b0d89_0
92 | - mkl=2020.2=256
93 | - mkl-service=2.3.0=py38he904b0f_0
94 | - mkl_fft=1.2.0=py38h23d657b_0
95 | - mkl_random=1.1.1=py38h0573a6f_0
96 | - multidict=4.7.5=py38h1e0a361_2
97 | - mysql-common=8.0.21=2
98 | - mysql-libs=8.0.21=hf3661c5_2
99 | - ncurses=6.2=he6710b0_1
100 | - nettle=3.4.1=hbb512f6_0
101 | - nibabel=3.1.1=py_0
102 | - ninja=1.10.1=py38hfd86e86_0
103 | - nspr=4.29=he1b5a44_0
104 | - nss=3.57=he751ad9_0
105 | - numpy-base=1.19.2=py38hfa32c7d_0
106 | - oauthlib=3.0.1=py_0
107 | - olefile=0.46=py_0
108 | - openh264=2.1.1=h8b12597_0
109 | - openssl=1.1.1i=h27cfd23_0
110 | - opt_einsum=3.1.0=py_0
111 | - packaging=20.4=py_0
112 | - pandas=1.1.1=py38he6710b0_0
113 | - pcre=8.44=he6710b0_0
114 | - pillow=7.2.0=py38hb39fc2d_0
115 | - pip=20.2.2=py38_0
116 | - pixman=0.38.0=h7b6447c_0
117 | - protobuf=3.13.0=py38he6710b0_1
118 | - psutil=5.7.2=py38h7b6447c_0
119 | - pyasn1=0.4.8=py_0
120 | - pyasn1-modules=0.2.7=py_0
121 | - pycparser=2.20=py_2
122 | - pydicom=2.0.0=pyh9f0ad1d_0
123 | - pyjwt=1.7.1=py_0
124 | - pyopenssl=19.1.0=py_1
125 | - pyparsing=2.4.7=py_0
126 | - pysocks=1.7.1=py38_0
127 | - python=3.8.2=hcf32534_0
128 | - python-dateutil=2.8.1=py_0
129 | - python_abi=3.8=1_cp38
130 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
131 | - pytz=2020.1=py_0
132 | - pyyaml=5.3.1=py38h7b6447c_1
133 | - pyzmq=19.0.2=py38he6710b0_1
134 | - qt=5.12.9=h1f2b2cb_0
135 | - readline=8.0=h7b6447c_0
136 | - requests=2.24.0=py_0
137 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0
138 | - rsa=4.6=pyh9f0ad1d_0
139 | - scipy=1.5.2=py38h0b6359f_0
140 | - setuptools=49.6.0=py38_1
141 | - six=1.15.0=py_0
142 | - sqlite=3.33.0=h62c20be_0
143 | - tensorboard=2.3.0=py_0
144 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0
145 | - tensorboardx=2.1=py_0
146 | - tensorflow=2.2.0=mkl_py38h6d3daf0_0
147 | - tensorflow-base=2.2.0=mkl_py38h5059a2d_0
148 | - tensorflow-estimator=2.2.0=pyh208ff02_0
149 | - termcolor=1.1.0=py38_1
150 | - tk=8.6.10=hbc83047_0
151 | - torchfile=0.1.0=py_0
152 | - tornado=6.0.4=py38h7b6447c_1
153 | - tqdm=4.50.0=pyh9f0ad1d_0
154 | - typing-extensions=3.7.4.3=0
155 | - typing_extensions=3.7.4.3=py_0
156 | - urllib3=1.25.10=py_0
157 | - visdom=0.1.8.9=0
158 | - websocket-client=0.57.0=py38_1
159 | - werkzeug=1.0.1=pyh9f0ad1d_0
160 | - wheel=0.35.1=py_0
161 | - wrapt=1.12.1=py38h7b6447c_1
162 | - x264=1!152.20180806=h7b6447c_0
163 | - xorg-kbproto=1.0.7=h14c3975_1002
164 | - xorg-libice=1.0.10=h516909a_0
165 | - xorg-libsm=1.2.3=h84519dc_1000
166 | - xorg-libx11=1.6.12=h516909a_0
167 | - xorg-libxext=1.3.4=h516909a_0
168 | - xorg-libxrender=0.9.10=h516909a_1002
169 | - xorg-renderproto=0.11.1=h14c3975_1002
170 | - xorg-xextproto=7.3.0=h14c3975_1002
171 | - xorg-xproto=7.0.31=h14c3975_1007
172 | - xz=5.2.5=h7b6447c_0
173 | - yacs=0.1.6=py_0
174 | - yaml=0.2.5=h7b6447c_0
175 | - yarl=1.6.2=py38h1e0a361_0
176 | - zeromq=4.3.2=he6710b0_3
177 | - zipp=3.4.0=py_0
178 | - zlib=1.2.11=h7b6447c_3
179 | - zstd=1.4.5=h9ceee32_0
180 | - pip:
181 | - decorator==4.4.2
182 | - easydict==1.7
183 | - fvcore==0.1.3.post20210311
184 | - imageio==2.9.0
185 | - iopath==0.1.4
186 | - munkres==1.1.4
187 | - networkx==2.5
188 | - numpy==1.16.2
189 | - opencv-python==4.4.0.44
190 | - portalocker==2.2.1
191 | - progress==1.5
192 | - pthflops==0.4.0
193 | - pywavelets==1.1.1
194 | - scikit-image==0.17.2
195 | - tabulate==0.8.9
196 | - thop==0.0.31-2005241907
197 | - tifffile==2020.10.1
198 | - timm==0.3.4
199 | - torchscan==0.1.1
200 | - torchstat==0.0.7
201 | - torchvision==0.8.2
202 | prefix: /home/cezheng/anaconda3/envs/pose2
203 |
--------------------------------------------------------------------------------
/run_poseformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 |
10 | from common.arguments import parse_args
11 | import torch
12 |
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import os
17 | import sys
18 | import errno
19 | import math
20 |
21 | from einops import rearrange, repeat
22 | from copy import deepcopy
23 |
24 | from common.camera import *
25 | import collections
26 |
27 | from common.model_poseformer import *
28 |
29 | from common.loss import *
30 | from common.generators import ChunkedGenerator, UnchunkedGenerator
31 | from time import time
32 | from common.utils import *
33 |
34 |
35 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
36 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
37 | os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
38 | # print(torch.cuda.device_count())
39 |
40 |
41 | ###################
42 | args = parse_args()
43 | # print(args)
44 |
45 | try:
46 | # Create checkpoint directory if it does not exist
47 | os.makedirs(args.checkpoint)
48 | except OSError as e:
49 | if e.errno != errno.EEXIST:
50 | raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint)
51 |
52 | print('Loading dataset...')
53 | dataset_path = 'data/data_3d_' + args.dataset + '.npz'
54 | if args.dataset == 'h36m':
55 | from common.h36m_dataset import Human36mDataset
56 | dataset = Human36mDataset(dataset_path)
57 | elif args.dataset.startswith('humaneva'):
58 | from common.humaneva_dataset import HumanEvaDataset
59 | dataset = HumanEvaDataset(dataset_path)
60 | elif args.dataset.startswith('custom'):
61 | from common.custom_dataset import CustomDataset
62 | dataset = CustomDataset('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz')
63 | else:
64 | raise KeyError('Invalid dataset')
65 |
66 | print('Preparing data...')
67 | for subject in dataset.subjects():
68 | for action in dataset[subject].keys():
69 | anim = dataset[subject][action]
70 |
71 | if 'positions' in anim:
72 | positions_3d = []
73 | for cam in anim['cameras']:
74 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
75 | pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position
76 | positions_3d.append(pos_3d)
77 | anim['positions_3d'] = positions_3d
78 |
79 | print('Loading 2D detections...')
80 | keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True)
81 | keypoints_metadata = keypoints['metadata'].item()
82 | keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
83 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
84 | joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
85 | keypoints = keypoints['positions_2d'].item()
86 |
87 | ###################
88 | for subject in dataset.subjects():
89 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
90 | for action in dataset[subject].keys():
91 | assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject)
92 | if 'positions_3d' not in dataset[subject][action]:
93 | continue
94 |
95 | for cam_idx in range(len(keypoints[subject][action])):
96 |
97 | # We check for >= instead of == because some videos in H3.6M contain extra frames
98 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0]
99 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length
100 |
101 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
102 | # Shorten sequence
103 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]
104 |
105 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d'])
106 |
107 | for subject in keypoints.keys():
108 | for action in keypoints[subject]:
109 | for cam_idx, kps in enumerate(keypoints[subject][action]):
110 | # Normalize camera frame
111 | cam = dataset.cameras()[subject][cam_idx]
112 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
113 | keypoints[subject][action][cam_idx] = kps
114 |
115 | subjects_train = args.subjects_train.split(',')
116 | subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(',')
117 | if not args.render:
118 | subjects_test = args.subjects_test.split(',')
119 | else:
120 | subjects_test = [args.viz_subject]
121 |
122 |
123 | def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True):
124 | out_poses_3d = []
125 | out_poses_2d = []
126 | out_camera_params = []
127 | for subject in subjects:
128 | for action in keypoints[subject].keys():
129 | if action_filter is not None:
130 | found = False
131 | for a in action_filter:
132 | if action.startswith(a):
133 | found = True
134 | break
135 | if not found:
136 | continue
137 |
138 | poses_2d = keypoints[subject][action]
139 | for i in range(len(poses_2d)): # Iterate across cameras
140 | out_poses_2d.append(poses_2d[i])
141 |
142 | if subject in dataset.cameras():
143 | cams = dataset.cameras()[subject]
144 | assert len(cams) == len(poses_2d), 'Camera count mismatch'
145 | for cam in cams:
146 | if 'intrinsic' in cam:
147 | out_camera_params.append(cam['intrinsic'])
148 |
149 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]:
150 | poses_3d = dataset[subject][action]['positions_3d']
151 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
152 | for i in range(len(poses_3d)): # Iterate across cameras
153 | out_poses_3d.append(poses_3d[i])
154 |
155 | if len(out_camera_params) == 0:
156 | out_camera_params = None
157 | if len(out_poses_3d) == 0:
158 | out_poses_3d = None
159 |
160 | stride = args.downsample
161 | if subset < 1:
162 | for i in range(len(out_poses_2d)):
163 | n_frames = int(round(len(out_poses_2d[i])//stride * subset)*stride)
164 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i])))
165 | out_poses_2d[i] = out_poses_2d[i][start:start+n_frames:stride]
166 | if out_poses_3d is not None:
167 | out_poses_3d[i] = out_poses_3d[i][start:start+n_frames:stride]
168 | elif stride > 1:
169 | # Downsample as requested
170 | for i in range(len(out_poses_2d)):
171 | out_poses_2d[i] = out_poses_2d[i][::stride]
172 | if out_poses_3d is not None:
173 | out_poses_3d[i] = out_poses_3d[i][::stride]
174 |
175 |
176 | return out_camera_params, out_poses_3d, out_poses_2d
177 |
178 | action_filter = None if args.actions == '*' else args.actions.split(',')
179 | if action_filter is not None:
180 | print('Selected actions:', action_filter)
181 |
182 | cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter)
183 |
184 |
185 | receptive_field = args.number_of_frames
186 | print('INFO: Receptive field: {} frames'.format(receptive_field))
187 | pad = (receptive_field -1) // 2 # Padding on each side
188 | min_loss = 100000
189 | width = cam['res_w']
190 | height = cam['res_h']
191 | num_joints = keypoints_metadata['num_joints']
192 |
193 | #########################################PoseTransformer
194 |
195 | model_pos_train = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4,
196 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0.1)
197 |
198 | model_pos = PoseTransformer(num_frame=receptive_field, num_joints=num_joints, in_chans=2, embed_dim_ratio=32, depth=4,
199 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0)
200 |
201 | ################ load weight ########################
202 | # posetrans_checkpoint = torch.load('./checkpoint/pretrained_posetrans.bin', map_location=lambda storage, loc: storage)
203 | # posetrans_checkpoint = posetrans_checkpoint["model_pos"]
204 | # model_pos_train = load_pretrained_weights(model_pos_train, posetrans_checkpoint)
205 |
206 | #################
207 | causal_shift = 0
208 | model_params = 0
209 | for parameter in model_pos.parameters():
210 | model_params += parameter.numel()
211 | print('INFO: Trainable parameter count:', model_params)
212 |
213 | if torch.cuda.is_available():
214 | model_pos = nn.DataParallel(model_pos)
215 | model_pos = model_pos.cuda()
216 | model_pos_train = nn.DataParallel(model_pos_train)
217 | model_pos_train = model_pos_train.cuda()
218 |
219 |
220 | if args.resume or args.evaluate:
221 | chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
222 | print('Loading checkpoint', chk_filename)
223 | checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
224 | model_pos_train.load_state_dict(checkpoint['model_pos'], strict=False)
225 | model_pos.load_state_dict(checkpoint['model_pos'], strict=False)
226 |
227 |
228 | test_generator = UnchunkedGenerator(cameras_valid, poses_valid, poses_valid_2d,
229 | pad=pad, causal_shift=causal_shift, augment=False,
230 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
231 | print('INFO: Testing on {} frames'.format(test_generator.num_frames()))
232 |
233 | def eval_data_prepare(receptive_field, inputs_2d, inputs_3d):  # slice the 2D sequence into sliding windows of length receptive_field
234 | inputs_2d_p = torch.squeeze(inputs_2d)
235 | inputs_3d_p = inputs_3d.permute(1,0,2,3)
236 | out_num = inputs_2d_p.shape[0] - receptive_field + 1
237 | eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
238 | for i in range(out_num):
239 | eval_input_2d[i,:,:,:] = inputs_2d_p[i:i+receptive_field, :, :]
240 | return eval_input_2d, inputs_3d_p
241 |
242 |
243 | ###################
244 |
245 | if not args.evaluate:
246 | cameras_train, poses_train, poses_train_2d = fetch(subjects_train, action_filter, subset=args.subset)
247 |
248 | lr = args.learning_rate
249 | optimizer = optim.AdamW(model_pos_train.parameters(), lr=lr, weight_decay=0.1)
250 |
251 | lr_decay = args.lr_decay
252 | losses_3d_train = []
253 | losses_3d_train_eval = []
254 | losses_3d_valid = []
255 |
256 | epoch = 0
257 | initial_momentum = 0.1
258 | final_momentum = 0.001
259 |
260 | train_generator = ChunkedGenerator(args.batch_size//args.stride, cameras_train, poses_train, poses_train_2d, args.stride,
261 | pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation,
262 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
263 | train_generator_eval = UnchunkedGenerator(cameras_train, poses_train, poses_train_2d,
264 | pad=pad, causal_shift=causal_shift, augment=False)
265 | print('INFO: Training on {} frames'.format(train_generator_eval.num_frames()))
266 |
267 |
268 | if args.resume:
269 | epoch = checkpoint['epoch']
270 | if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
271 | optimizer.load_state_dict(checkpoint['optimizer'])
272 | train_generator.set_random_state(checkpoint['random_state'])
273 | else:
274 | print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
275 |
276 | lr = checkpoint['lr']
277 |
278 |
279 | print('** Note: reported losses are averaged over all frames.')
280 | print('** The final evaluation will be carried out after the last training epoch.')
281 |
282 | # Pos model only
283 | while epoch < args.epochs:
284 | start_time = time()
285 | epoch_loss_3d_train = 0
286 | epoch_loss_traj_train = 0
287 | epoch_loss_2d_train_unlabeled = 0
288 | N = 0
289 | N_semi = 0
290 | model_pos_train.train()
291 |
292 | for cameras_train, batch_3d, batch_2d in train_generator.next_epoch():
293 | cameras_train = torch.from_numpy(cameras_train.astype('float32'))
294 | inputs_3d = torch.from_numpy(batch_3d.astype('float32'))
295 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
296 |
297 | if torch.cuda.is_available():
298 | inputs_3d = inputs_3d.cuda()
299 | inputs_2d = inputs_2d.cuda()
300 | cameras_train = cameras_train.cuda()
301 | inputs_traj = inputs_3d[:, :, :1].clone()
302 | inputs_3d[:, :, 0] = 0
303 |
304 | optimizer.zero_grad()
305 |
306 | # Predict 3D poses
307 | predicted_3d_pos = model_pos_train(inputs_2d)
308 |
309 | del inputs_2d
310 | torch.cuda.empty_cache()
311 |
312 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
313 | epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
314 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
315 |
316 | loss_total = loss_3d_pos
317 |
318 | loss_total.backward()
319 |
320 | optimizer.step()
321 | del inputs_3d, loss_3d_pos, predicted_3d_pos
322 | torch.cuda.empty_cache()
323 |
324 | losses_3d_train.append(epoch_loss_3d_train / N)
325 | torch.cuda.empty_cache()
326 |
327 | # End-of-epoch evaluation
328 | with torch.no_grad():
329 | model_pos.load_state_dict(model_pos_train.state_dict(), strict=False)
330 | model_pos.eval()
331 |
332 | epoch_loss_3d_valid = 0
333 | epoch_loss_traj_valid = 0
334 | epoch_loss_2d_valid = 0
335 | N = 0
336 | if not args.no_eval:
337 | # Evaluate on test set
338 | for cam, batch, batch_2d in test_generator.next_epoch():
339 | inputs_3d = torch.from_numpy(batch.astype('float32'))
340 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
341 |
342 | ##### apply test-time-augmentation (following Videopose3d)
343 | inputs_2d_flip = inputs_2d.clone()
344 | inputs_2d_flip[:, :, :, 0] *= -1
345 | inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :]
346 |
347 | ##### convert size
348 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d)
349 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d)
350 |
351 | if torch.cuda.is_available():
352 | inputs_2d = inputs_2d.cuda()
353 | inputs_2d_flip = inputs_2d_flip.cuda()
354 | inputs_3d = inputs_3d.cuda()
355 | inputs_3d[:, :, 0] = 0
356 |
357 | predicted_3d_pos = model_pos(inputs_2d)
358 | predicted_3d_pos_flip = model_pos(inputs_2d_flip)
359 | predicted_3d_pos_flip[:, :, :, 0] *= -1
360 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :,
361 | joints_right + joints_left]
362 |
363 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1,
364 | keepdim=True)
365 |
366 | del inputs_2d, inputs_2d_flip
367 | torch.cuda.empty_cache()
368 |
369 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
370 | epoch_loss_3d_valid += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
371 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
372 |
373 | del inputs_3d, loss_3d_pos, predicted_3d_pos
374 | torch.cuda.empty_cache()
375 |
376 | losses_3d_valid.append(epoch_loss_3d_valid / N)
377 |
378 | # Evaluate on training set, this time in evaluation mode
379 | epoch_loss_3d_train_eval = 0
380 | epoch_loss_traj_train_eval = 0
381 | epoch_loss_2d_train_labeled_eval = 0
382 | N = 0
383 | for cam, batch, batch_2d in train_generator_eval.next_epoch():
384 | if batch_2d.shape[1] == 0:
385 | # This can only happen when downsampling the dataset
386 | continue
387 |
388 | inputs_3d = torch.from_numpy(batch.astype('float32'))
389 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
390 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d)
391 |
392 | if torch.cuda.is_available():
393 | inputs_3d = inputs_3d.cuda()
394 | inputs_2d = inputs_2d.cuda()
395 |
396 | inputs_3d[:, :, 0] = 0
397 |
398 | # Compute 3D poses
399 | predicted_3d_pos = model_pos(inputs_2d)
400 |
401 | del inputs_2d
402 | torch.cuda.empty_cache()
403 |
404 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
405 | epoch_loss_3d_train_eval += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
406 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
407 |
408 | del inputs_3d, loss_3d_pos, predicted_3d_pos
409 | torch.cuda.empty_cache()
410 |
411 | losses_3d_train_eval.append(epoch_loss_3d_train_eval / N)
412 |
413 | # Evaluate 2D loss on unlabeled training set (in evaluation mode)
414 | epoch_loss_2d_train_unlabeled_eval = 0
415 | N_semi = 0
416 |
417 | elapsed = (time() - start_time) / 60
418 |
419 | if args.no_eval:
420 | print('[%d] time %.2f lr %f 3d_train %f' % (
421 | epoch + 1,
422 | elapsed,
423 | lr,
424 | losses_3d_train[-1] * 1000))
425 | else:
426 |
427 | print('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % (
428 | epoch + 1,
429 | elapsed,
430 | lr,
431 | losses_3d_train[-1] * 1000,
432 | losses_3d_train_eval[-1] * 1000,
433 | losses_3d_valid[-1] * 1000))
434 |
435 | # Decay learning rate exponentially
436 | lr *= lr_decay
437 | for param_group in optimizer.param_groups:
438 | param_group['lr'] *= lr_decay
439 | epoch += 1
440 |
441 | # Decay BatchNorm momentum
442 | # momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
443 | # model_pos_train.set_bn_momentum(momentum)
444 |
445 | # Save checkpoint if necessary
446 | if epoch % args.checkpoint_frequency == 0:
447 | chk_path = os.path.join(args.checkpoint, 'epoch_{}.bin'.format(epoch))
448 | print('Saving checkpoint to', chk_path)
449 |
450 | torch.save({
451 | 'epoch': epoch,
452 | 'lr': lr,
453 | 'random_state': train_generator.random_state(),
454 | 'optimizer': optimizer.state_dict(),
455 | 'model_pos': model_pos_train.state_dict(),
456 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None,
457 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None,
458 | }, chk_path)
459 |
460 | #### save best checkpoint
461 |         best_chk_path = os.path.join(args.checkpoint, 'best_epoch.bin')
462 | if losses_3d_valid[-1] * 1000 < min_loss:
463 | min_loss = losses_3d_valid[-1] * 1000
464 | print("save best checkpoint")
465 | torch.save({
466 | 'epoch': epoch,
467 | 'lr': lr,
468 | 'random_state': train_generator.random_state(),
469 | 'optimizer': optimizer.state_dict(),
470 | 'model_pos': model_pos_train.state_dict(),
471 | # 'model_traj': model_traj_train.state_dict() if semi_supervised else None,
472 | # 'random_state_semi': semi_generator.random_state() if semi_supervised else None,
473 | }, best_chk_path)
474 |
475 | # Save training curves after every epoch, as .png images (if requested)
476 | if args.export_training_curves and epoch > 3:
477 | if 'matplotlib' not in sys.modules:
478 | import matplotlib
479 |
480 | matplotlib.use('Agg')
481 | import matplotlib.pyplot as plt
482 |
483 | plt.figure()
484 | epoch_x = np.arange(3, len(losses_3d_train)) + 1
485 | plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0')
486 | plt.plot(epoch_x, losses_3d_train_eval[3:], color='C0')
487 | plt.plot(epoch_x, losses_3d_valid[3:], color='C1')
488 | plt.legend(['3d train', '3d train (eval)', '3d valid (eval)'])
489 | plt.ylabel('MPJPE (m)')
490 | plt.xlabel('Epoch')
491 | plt.xlim((3, epoch))
492 | plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png'))
493 |
494 | plt.close('all')
495 |
496 |
497 | # Evaluate
498 | def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False):
499 | epoch_loss_3d_pos = 0
500 | epoch_loss_3d_pos_procrustes = 0
501 | epoch_loss_3d_pos_scale = 0
502 | epoch_loss_3d_vel = 0
503 | with torch.no_grad():
504 | if not use_trajectory_model:
505 | model_pos.eval()
506 | # else:
507 | # model_traj.eval()
508 | N = 0
509 | for _, batch, batch_2d in test_generator.next_epoch():
510 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
511 | inputs_3d = torch.from_numpy(batch.astype('float32'))
512 |
513 |
514 | ##### apply test-time-augmentation (following Videopose3d)
515 | inputs_2d_flip = inputs_2d.clone()
516 | inputs_2d_flip [:, :, :, 0] *= -1
517 | inputs_2d_flip[:, :, kps_left + kps_right,:] = inputs_2d_flip[:, :, kps_right + kps_left,:]
518 |
519 | ##### convert size
520 | inputs_2d, inputs_3d = eval_data_prepare(receptive_field, inputs_2d, inputs_3d)
521 | inputs_2d_flip, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d)
522 |
523 | if torch.cuda.is_available():
524 | inputs_2d = inputs_2d.cuda()
525 | inputs_2d_flip = inputs_2d_flip.cuda()
526 | inputs_3d = inputs_3d.cuda()
527 | inputs_3d[:, :, 0] = 0
528 |
529 | predicted_3d_pos = model_pos(inputs_2d)
530 | predicted_3d_pos_flip = model_pos(inputs_2d_flip)
531 | predicted_3d_pos_flip[:, :, :, 0] *= -1
532 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :,
533 | joints_right + joints_left]
534 |
535 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1,
536 | keepdim=True)
537 |
538 | del inputs_2d, inputs_2d_flip
539 | torch.cuda.empty_cache()
540 |
541 | if return_predictions:
542 | return predicted_3d_pos.squeeze(0).cpu().numpy()
543 |
544 |
545 | error = mpjpe(predicted_3d_pos, inputs_3d)
546 | epoch_loss_3d_pos_scale += inputs_3d.shape[0]*inputs_3d.shape[1] * n_mpjpe(predicted_3d_pos, inputs_3d).item()
547 |
548 | epoch_loss_3d_pos += inputs_3d.shape[0]*inputs_3d.shape[1] * error.item()
549 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
550 |
551 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
552 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
553 |
554 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0]*inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, inputs)
555 |
556 | # Compute velocity error
557 | epoch_loss_3d_vel += inputs_3d.shape[0]*inputs_3d.shape[1] * mean_velocity_error(predicted_3d_pos, inputs)
558 |
559 | if action is None:
560 | print('----------')
561 | else:
562 | print('----'+action+'----')
563 | e1 = (epoch_loss_3d_pos / N)*1000
564 | e2 = (epoch_loss_3d_pos_procrustes / N)*1000
565 | e3 = (epoch_loss_3d_pos_scale / N)*1000
566 | ev = (epoch_loss_3d_vel / N)*1000
567 | print('Protocol #1 Error (MPJPE):', e1, 'mm')
568 | print('Protocol #2 Error (P-MPJPE):', e2, 'mm')
569 | print('Protocol #3 Error (N-MPJPE):', e3, 'mm')
570 | print('Velocity Error (MPJVE):', ev, 'mm')
571 | print('----------')
572 |
573 | return e1, e2, e3, ev
574 |
575 | if args.render:
576 | print('Rendering...')
577 |
578 | input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy()
579 | ground_truth = None
580 | if args.viz_subject in dataset.subjects() and args.viz_action in dataset[args.viz_subject]:
581 | if 'positions_3d' in dataset[args.viz_subject][args.viz_action]:
582 | ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy()
583 | if ground_truth is None:
584 | print('INFO: this action is unlabeled. Ground truth will not be rendered.')
585 |
586 | gen = UnchunkedGenerator(None, [ground_truth], [input_keypoints],
587 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
588 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
589 | prediction = evaluate(gen, return_predictions=True)
590 | # if model_traj is not None and ground_truth is None:
591 | # prediction_traj = evaluate(gen, return_predictions=True, use_trajectory_model=True)
592 | # prediction += prediction_traj
593 |
594 | if args.viz_export is not None:
595 | print('Exporting joint positions to', args.viz_export)
596 | # Predictions are in camera space
597 | np.save(args.viz_export, prediction)
598 |
599 | if args.viz_output is not None:
600 | if ground_truth is not None:
601 | # Reapply trajectory
602 | trajectory = ground_truth[:, :1]
603 | ground_truth[:, 1:] += trajectory
604 | prediction += trajectory
605 |
606 | # Invert camera transformation
607 | cam = dataset.cameras()[args.viz_subject][args.viz_camera]
608 | if ground_truth is not None:
609 | prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation'])
610 | ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation'])
611 | else:
612 | # If the ground truth is not available, take the camera extrinsic params from a random subject.
613 | # They are almost the same, and anyway, we only need this for visualization purposes.
614 | for subject in dataset.cameras():
615 | if 'orientation' in dataset.cameras()[subject][args.viz_camera]:
616 | rot = dataset.cameras()[subject][args.viz_camera]['orientation']
617 | break
618 | prediction = camera_to_world(prediction, R=rot, t=0)
619 | # We don't have the trajectory, but at least we can rebase the height
620 | prediction[:, :, 2] -= np.min(prediction[:, :, 2])
621 |
622 | anim_output = {'Reconstruction': prediction}
623 | if ground_truth is not None and not args.viz_no_ground_truth:
624 | anim_output['Ground truth'] = ground_truth
625 |
626 | input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h'])
627 |
628 | from common.visualization import render_animation
629 |
630 | render_animation(input_keypoints, keypoints_metadata, anim_output,
631 | dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output,
632 | limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size,
633 | input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']),
634 | input_video_skip=args.viz_skip)
635 |
636 | else:
637 | print('Evaluating...')
638 | all_actions = {}
639 | all_actions_by_subject = {}
640 | for subject in subjects_test:
641 | if subject not in all_actions_by_subject:
642 | all_actions_by_subject[subject] = {}
643 |
644 | for action in dataset[subject].keys():
645 | action_name = action.split(' ')[0]
646 | if action_name not in all_actions:
647 | all_actions[action_name] = []
648 | if action_name not in all_actions_by_subject[subject]:
649 | all_actions_by_subject[subject][action_name] = []
650 | all_actions[action_name].append((subject, action))
651 | all_actions_by_subject[subject][action_name].append((subject, action))
652 |
653 |
654 | def fetch_actions(actions):
655 | out_poses_3d = []
656 | out_poses_2d = []
657 |
658 | for subject, action in actions:
659 | poses_2d = keypoints[subject][action]
660 | for i in range(len(poses_2d)): # Iterate across cameras
661 | out_poses_2d.append(poses_2d[i])
662 |
663 | poses_3d = dataset[subject][action]['positions_3d']
664 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
665 | for i in range(len(poses_3d)): # Iterate across cameras
666 | out_poses_3d.append(poses_3d[i])
667 |
668 | stride = args.downsample
669 | if stride > 1:
670 | # Downsample as requested
671 | for i in range(len(out_poses_2d)):
672 | out_poses_2d[i] = out_poses_2d[i][::stride]
673 | if out_poses_3d is not None:
674 | out_poses_3d[i] = out_poses_3d[i][::stride]
675 |
676 | return out_poses_3d, out_poses_2d
677 |
678 |
679 | def run_evaluation(actions, action_filter=None):
680 | errors_p1 = []
681 | errors_p2 = []
682 | errors_p3 = []
683 | errors_vel = []
684 |
685 | for action_key in actions.keys():
686 | if action_filter is not None:
687 | found = False
688 | for a in action_filter:
689 | if action_key.startswith(a):
690 | found = True
691 | break
692 | if not found:
693 | continue
694 |
695 | poses_act, poses_2d_act = fetch_actions(actions[action_key])
696 | gen = UnchunkedGenerator(None, poses_act, poses_2d_act,
697 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
698 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
699 | joints_right=joints_right)
700 | e1, e2, e3, ev = evaluate(gen, action_key)
701 | errors_p1.append(e1)
702 | errors_p2.append(e2)
703 | errors_p3.append(e3)
704 | errors_vel.append(ev)
705 |
706 | print('Protocol #1 (MPJPE) action-wise average:', round(np.mean(errors_p1), 1), 'mm')
707 | print('Protocol #2 (P-MPJPE) action-wise average:', round(np.mean(errors_p2), 1), 'mm')
708 | print('Protocol #3 (N-MPJPE) action-wise average:', round(np.mean(errors_p3), 1), 'mm')
709 | print('Velocity (MPJVE) action-wise average:', round(np.mean(errors_vel), 2), 'mm')
710 |
711 |
712 | if not args.by_subject:
713 | run_evaluation(all_actions, action_filter)
714 | else:
715 | for subject in all_actions_by_subject.keys():
716 | print('Evaluating on subject', subject)
717 | run_evaluation(all_actions_by_subject[subject], action_filter)
718 | print('')
--------------------------------------------------------------------------------