├── Figures
│   ├── GIF.gif
│   ├── results.png
│   └── structure.jpg
├── README.md
├── checkpoint
│   └── README.md
├── common
│   ├── arguments.py
│   ├── camera.py
│   ├── generators.py
│   ├── h36m_dataset.py
│   ├── loss.py
│   ├── mocap_dataset.py
│   ├── model.py
│   ├── quaternion.py
│   ├── ranger.py
│   ├── skeleton.py
│   ├── utils.py
│   └── visualization.py
├── data
│   └── README.md
└── run.py
/Figures/GIF.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/GIF.gif
--------------------------------------------------------------------------------
/Figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/results.png
--------------------------------------------------------------------------------
/Figures/structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lrxjason/Attention3DHumanPose/dc921991dc1700597511f9588c09c0aff43f1448/Figures/structure.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention Mechanism Exploits Temporal Contexts: Real-time 3D Human Pose Reconstruction (CVPR 2020 Oral)
2 | More extensive evaluation and code can be found on our lab website: https://sites.google.com/a/udayton.edu/jshen1/cvpr2020
3 | 
4 |
5 |
6 |  
7 |  
8 |  
9 |
10 |
11 |
12 | PyTorch code of the paper "Attention Mechanism Exploits Temporal Contexts: Real-time 3D Human Pose Reconstruction". [pdf](http://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_Attention_Mechanism_Exploits_Temporal_Contexts_Real-Time_3D_Human_Pose_Reconstruction_CVPR_2020_paper.pdf)
13 |
14 | ### [Bibtex](https://scholar.googleusercontent.com/scholar.bib?q=info:sVZlnopW0ZQJ:scholar.google.com/&output=citation&scisdr=CgUvGH_mEIi98y29oOM:AAGBfm0AAAAAXu-4uOOunCSIKKuamAWN5VjFJ_OC0cHs&scisig=AAGBfm0AAAAAXu-4uBa5vr92Yk6AXlKVO0mVXEXZorOx&scisf=4&ct=citation&cd=-1&hl=en)
15 |
16 | If you find this code useful, please cite the following paper:
17 |
18 |     @inproceedings{liu2020attention,
19 |       title={Attention Mechanism Exploits Temporal Contexts: Real-Time 3D Human Pose Reconstruction},
20 |       author={Liu, Ruixu and Shen, Ju and Wang, He and Chen, Chen and Cheung, Sen-ching and Asari, Vijayan},
21 |       booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
22 |       pages={5064--5073},
23 |       year={2020}
24 |     }
25 |
26 | ### Environment
27 |
28 | The code is developed and tested on the following environment
29 |
30 | * Python 3.6
31 | * PyTorch 1.1 or higher
32 | * CUDA 10
33 |
34 | ### Dataset
35 |
36 | The source code is for training/evaluating on the [Human3.6M](http://vision.imar.ro/human3.6m) dataset. Our code is compatible with the dataset setup introduced by [Martinez et al.](https://github.com/una-dinosauria/3d-pose-baseline) and [Pavllo et al.](https://github.com/facebookresearch/VideoPose3D). Please refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (`./data` directory). We upload the training 2D cpn data [here](https://drive.google.com/file/d/131EnG8L0-A9DNy9bfsqCSrG1n5GnzwkO/view?usp=sharing) and the 3D gt data [here](https://drive.google.com/file/d/1nbscv_IlJ-sdug6GU2KWN4MYkPtYj4YX/view?usp=sharing). The 3D Avatar model and code are avaliable [here](https://drive.google.com/file/d/1RxhwFHCX4ydf1I1crLnQ_4NEXF84MkMY/view?usp=sharing).
37 |
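As a rough sketch, after following the VideoPose3D setup the `./data` directory should look roughly like the following (file names follow the VideoPose3D convention `data_2d_h36m_<keypoints>.npz`, matching the `-k` flag; verify against your own download):

```
data/
├── data_3d_h36m.npz                  # 3D ground-truth poses
├── data_2d_h36m_gt.npz               # 2D ground-truth keypoints (-k gt)
└── data_2d_h36m_cpn_ft_h36m_dbb.npz  # CPN 2D detections (-k cpn_ft_h36m_dbb, default)
```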
38 |
39 | ### Training new models
40 |
41 | To train a model from scratch, run:
42 |
43 | ```bash
44 | python run.py -da -tta
45 | ```
46 |
47 | `-da` enables data augmentation (horizontal flipping) during training and `-tta` enables test-time augmentation during evaluation.
48 |
49 | For example, to train the 243-frame ground-truth model or the causal model from our paper, run:
50 |
51 | ```bash
52 | python run.py -k gt
53 | ```
54 |
55 | or
56 |
57 | ```bash
58 | python run.py -k cpn_ft_h36m_dbb --causal
59 | ```
60 |
61 | Training takes about 24 hours on a single TITAN RTX GPU.
62 |
63 | ### Evaluating pre-trained models
64 |
65 | We provide the pre-trained CPN model [here](https://drive.google.com/file/d/1jiZWqAOJmXoTL8dxhPX8QgK0QeECeoAM/view?usp=sharing) and the ground-truth model [here](https://drive.google.com/file/d/1EAS9PUddznBPqNaEHV6-tCfqsQOHZ1Of/view?usp=sharing). To evaluate them, put them in the `./checkpoint` directory and run:
66 |
67 | For the CPN model:
68 | ```bash
69 | python run.py -tta --evaluate cpn.bin
70 | ```
71 |
72 | For the ground-truth model:
73 | ```bash
74 | python run.py -k gt -tta --evaluate gt.bin
75 | ```
76 |
77 | ### Visualization and other functions
78 |
79 | We keep our code consistent with [VideoPose3D](https://github.com/facebookresearch/VideoPose3D). Please refer to their project page for further information.
80 |
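As a rough example, rendering uses the `--viz-*` flags defined in `common/arguments.py`; the exact flag combination below mirrors the VideoPose3D interface and may need adjusting to your checkpoint, subject, and action:

```bash
python run.py -k cpn_ft_h36m_dbb -tta --evaluate cpn.bin --render \
    --viz-subject S11 --viz-action Walking --viz-camera 0 \
    --viz-output output.mp4 --viz-size 5 --viz-limit 60
```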
81 |
82 |
--------------------------------------------------------------------------------
/checkpoint/README.md:
--------------------------------------------------------------------------------
1 | Put the pre-trained models here.
2 |
--------------------------------------------------------------------------------
/common/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import argparse
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description='Training script')
12 |
13 | # General arguments
14 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
15 | parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
16 | parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST', help='training subjects separated by comma')
17 | parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
18 |
19 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
20 | help='actions to train/test on, separated by comma, or * for all')
21 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
22 | help='checkpoint directory')
23 | parser.add_argument('--checkpoint-frequency', default=10, type=int, metavar='N',
24 | help='create a checkpoint every N epochs')
25 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
26 | help='checkpoint to resume (file name)')
27 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
28 | parser.add_argument('--render', action='store_true', help='visualize a particular video')
29 | parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
30 | parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
31 |
32 | # Model arguments
33 | parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
34 | parser.add_argument('-e', '--epochs', default=80, type=int, metavar='N', help='number of training epochs')
35 | parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', help='batch size in terms of predicted frames')
36 | parser.add_argument('-drop', '--dropout', default=0.2, type=float, metavar='P', help='dropout probability')
37 | parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate')
38 | parser.add_argument('-lrd', '--lr-decay', default=0.95, type=float, metavar='LR', help='learning rate decay per epoch')
39 | parser.add_argument('-da', '--data-augmentation', dest='data_augmentation', action='store_true',
40 | help='enable train-time flipping (data augmentation)')
41 | parser.add_argument('-tta', '--test-time-augmentation', dest='test_time_augmentation', action='store_true',
42 | help='enable test-time flipping (test-time augmentation)')
43 | parser.add_argument('-arc', '--architecture', default='3,3,3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma')
44 | parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing')
45 | parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers')
46 |
47 | # Experimental
48 | parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
49 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
50 | parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
51 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
52 | parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
53 | parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
54 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
55 | help='disable bone length term in semi-supervised settings')
56 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')
57 |
58 | # Visualization
59 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
60 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
61 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render')
62 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
63 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
64 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
65 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
66 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
67 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
68 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
69 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')
70 |
71 | parser.set_defaults(bone_length_term=True)
72 | parser.set_defaults(data_augmentation=False)
73 | parser.set_defaults(test_time_augmentation=False)
74 |
75 | args = parser.parse_args()
76 | # Check invalid configuration
77 | if args.resume and args.evaluate:
78 | print('Invalid flags: --resume and --evaluate cannot be set at the same time')
79 | exit()
80 |
81 | if args.export_training_curves and args.no_eval:
82 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
83 | exit()
84 |
85 | return args
86 |
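# Illustrative sketch (not part of the original file): how run.py is assumed to
# consume these flags. Note that flipping augmentation is off by default and must be
# turned on explicitly with -da (training) and -tta (evaluation).
if __name__ == '__main__':
    args = parse_args()  # e.g. invoked as `python run.py -da -tta -k cpn_ft_h36m_dbb`
    # '-arc 3,3,3,3,3' is assumed to be parsed into a list of filter widths (receptive field 243)
    filter_widths = [int(x) for x in args.architecture.split(',')]
    print(args.keypoints, filter_widths, args.data_augmentation, args.test_time_augmentation)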
--------------------------------------------------------------------------------
/common/camera.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from common.utils import wrap
12 | from common.quaternion import qrot, qinverse
13 |
14 | def normalize_screen_coordinates(X, w, h):
15 | assert X.shape[-1] == 2
16 |
17 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
18 | return X/w*2 - [1, h/w]
19 |
20 |
21 | def image_coordinates(X, w, h):
22 | assert X.shape[-1] == 2
23 |
24 | # Reverse camera frame normalization
25 | return (X + [1, h/w])*w/2
26 |
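# Worked example (illustrative comment, not part of the original code): with a
# 1000x1002 camera, the pixel (500, 501) sits at the image centre, so
# normalize_screen_coordinates(np.array([500., 501.]), w=1000, h=1002) gives
# approximately [0., 0.], and image_coordinates(np.array([0., 0.]), w=1000, h=1002)
# maps it back to [500., 501.].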
27 |
28 | def world_to_camera(X, R, t):
29 | Rt = wrap(qinverse, R) # Invert rotation
30 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate
31 |
32 |
33 | def camera_to_world(X, R, t):
34 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t
35 |
36 |
37 | def project_to_2d(X, camera_params):
38 | """
39 | Project 3D points to 2D using the Human3.6M camera projection function.
40 | This is a differentiable and batched reimplementation of the original MATLAB script.
41 |
42 | Arguments:
43 | X -- 3D points in *camera space* to transform (N, *, 3)
44 | camera_params -- intrinsic parameters (N, 2+2+3+2=9)
45 | """
46 | assert X.shape[-1] == 3
47 | assert len(camera_params.shape) == 2
48 | assert camera_params.shape[-1] == 9
49 | assert X.shape[0] == camera_params.shape[0]
50 |
51 | while len(camera_params.shape) < len(X.shape):
52 | camera_params = camera_params.unsqueeze(1)
53 |
54 | f = camera_params[..., :2]
55 | c = camera_params[..., 2:4]
56 | k = camera_params[..., 4:7]
57 | p = camera_params[..., 7:]
58 |
59 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
60 | r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)
61 |
62 | radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
63 | tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)
64 |
65 | XXX = XX*(radial + tan) + p*r2
66 |
67 | return f*XXX + c
68 |
69 | def project_to_2d_linear(X, camera_params):
70 | """
71 | Project 3D points to 2D using only linear parameters (focal length and principal point).
72 |
73 | Arguments:
74 | X -- 3D points in *camera space* to transform (N, *, 3)
75 | camera_params -- intrinsic parameters (N, 2+2+3+2=9)
76 | """
77 | assert X.shape[-1] == 3
78 | assert len(camera_params.shape) == 2
79 | assert camera_params.shape[-1] == 9
80 | assert X.shape[0] == camera_params.shape[0]
81 |
82 | while len(camera_params.shape) < len(X.shape):
83 | camera_params = camera_params.unsqueeze(1)
84 |
85 | f = camera_params[..., :2]
86 | c = camera_params[..., 2:4]
87 |
88 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
89 |
90 | return f*XX + c
--------------------------------------------------------------------------------
/common/generators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from itertools import zip_longest
9 | import numpy as np
10 |
11 | class ChunkedGenerator:
12 | """
13 | Batched data generator, used for training.
14 | The sequences are split into equal-length chunks and padded as necessary.
15 |
16 | Arguments:
17 | batch_size -- the batch size to use for training
18 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
19 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
20 | poses_2d -- list of input 2D keypoints, one element for each video
21 | chunk_length -- number of output frames to predict for each training example (usually 1)
22 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
23 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
24 | shuffle -- randomly shuffle the dataset before each epoch
25 | random_seed -- initial seed to use for the random generator
26 | augment -- augment the dataset by flipping poses horizontally
27 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
28 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
29 | """
30 | def __init__(self, batch_size, cameras, poses_3d, poses_2d,
31 | chunk_length, pad=0, causal_shift=0,
32 | shuffle=True, random_seed=1234,
33 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
34 | endless=False, noisy=False):
35 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
36 | assert cameras is None or len(cameras) == len(poses_2d)
37 |
38 | # Build lineage info
39 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples
40 | for i in range(len(poses_2d)):
41 | assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
42 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
43 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
44 | bounds = np.arange(n_chunks+1)*chunk_length - offset
45 | augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
46 | pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
47 | if augment:
48 | pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], ~augment_vector)
49 |
50 | # Initialize buffers
51 | if cameras is not None:
52 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
53 | if poses_3d is not None:
54 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
55 | self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
56 |
57 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size
58 | self.batch_size = batch_size
59 | self.random = np.random.RandomState(random_seed)
60 | self.pairs = pairs
61 | self.shuffle = shuffle
62 | self.pad = pad
63 | self.causal_shift = causal_shift
64 | self.endless = endless
65 | self.state = None
66 |
67 | self.cameras = cameras
68 | self.poses_3d = poses_3d
69 | self.poses_2d = poses_2d
70 |
71 | self.augment = augment
72 | self.noisy = noisy
73 | self.kps_left = kps_left
74 | self.kps_right = kps_right
75 | self.joints_left = joints_left
76 | self.joints_right = joints_right
77 |
78 | def num_frames(self):
79 | return self.num_batches * self.batch_size
80 |
81 | def random_state(self):
82 | return self.random
83 |
84 | def set_random_state(self, random):
85 | self.random = random
86 |
87 | def augment_enabled(self):
88 | return self.augment
89 |
90 | def next_pairs(self):
91 | if self.state is None:
92 | if self.shuffle:
93 | pairs = self.random.permutation(self.pairs)
94 | else:
95 | pairs = self.pairs
96 | return 0, pairs
97 | else:
98 | return self.state
99 |
100 | def next_epoch(self):
101 | enabled = True
102 | while enabled:
103 | start_idx, pairs = self.next_pairs()
104 | for b_i in range(start_idx, self.num_batches):
105 | chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size]
106 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
107 | start_2d = start_3d - self.pad - self.causal_shift
108 | end_2d = end_3d + self.pad - self.causal_shift
109 |
110 | # 2D poses
111 | seq_2d = self.poses_2d[seq_i]
112 | low_2d = max(start_2d, 0)
113 | high_2d = min(end_2d, seq_2d.shape[0])
114 | pad_left_2d = low_2d - start_2d
115 | pad_right_2d = end_2d - high_2d
116 | if pad_left_2d != 0 or pad_right_2d != 0:
117 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
118 | else:
119 | self.batch_2d[i] = seq_2d[low_2d:high_2d]
120 |
121 | if flip:
122 | # Flip 2D keypoints
123 | # self.batch_2d = np.flip(self.batch_2d, 1)
124 | self.batch_2d[i, :, :, 0] *= -1
125 | self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]
126 |
127 | # 3D poses
128 | if self.poses_3d is not None:
129 | seq_3d = self.poses_3d[seq_i]
130 | low_3d = max(start_3d, 0)
131 | high_3d = min(end_3d, seq_3d.shape[0])
132 | pad_left_3d = low_3d - start_3d
133 | pad_right_3d = end_3d - high_3d
134 | if pad_left_3d != 0 or pad_right_3d != 0:
135 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
136 | else:
137 | self.batch_3d[i] = seq_3d[low_3d:high_3d]
138 |
139 | if flip:
140 | # Flip 3D joints
141 | self.batch_3d[i, :, :, 0] *= -1
142 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \
143 | self.batch_3d[i, :, self.joints_right + self.joints_left]
144 |
145 | # Cameras
146 | if self.cameras is not None:
147 | self.batch_cam[i] = self.cameras[seq_i]
148 | if flip:
149 | # Flip horizontal distortion coefficients
150 | self.batch_cam[i, 2] *= -1
151 | self.batch_cam[i, 7] *= -1
152 |
153 | if self.endless:
154 | self.state = (b_i + 1, pairs)
155 | if self.poses_3d is None and self.cameras is None:
156 | yield None, None, self.batch_2d[:len(chunks)]
157 | elif self.poses_3d is not None and self.cameras is None:
158 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
159 | elif self.poses_3d is None:
160 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
161 | else:
162 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
163 |
164 | if self.endless:
165 | self.state = None
166 | else:
167 | enabled = False
168 |
169 |
170 | class Evaluate_Generator:
171 | """
172 | Batched data generator, used for evaluation. When augment is enabled, it also yields horizontally flipped 2D inputs for test-time augmentation.
173 | The sequences are split into equal-length chunks and padded as necessary.
174 |
175 | Arguments:
176 | batch_size -- the batch size to use for training
177 | cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
178 | poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
179 | poses_2d -- list of input 2D keypoints, one element for each video
180 | chunk_length -- number of output frames to predict for each training example (usually 1)
181 | pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
182 | causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
183 | shuffle -- randomly shuffle the dataset before each epoch
184 | random_seed -- initial seed to use for the random generator
185 | augment -- augment the dataset by flipping poses horizontally
186 | kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
187 | joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
188 | """
189 |
190 | def __init__(self, batch_size, cameras, poses_3d, poses_2d,
191 | chunk_length, pad=0, causal_shift=0,
192 | shuffle=True, random_seed=1234,
193 | augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
194 | endless=False):
195 | assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
196 | assert cameras is None or len(cameras) == len(poses_2d)
197 |
198 | # Build lineage info
199 | pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples
200 | for i in range(len(poses_2d)):
201 | assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
202 | n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
203 | offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
204 | bounds = np.arange(n_chunks + 1) * chunk_length - offset
205 | augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
206 | pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
207 |
208 | # Initialize buffers
209 | if cameras is not None:
210 | self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
211 | if poses_3d is not None:
212 | self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
213 |
214 | if augment:
215 | self.batch_2d_flip = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
216 | self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
217 | else:
218 | self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
219 |
220 | self.num_batches = (len(pairs) + batch_size - 1) // batch_size
221 | self.batch_size = batch_size
222 | self.random = np.random.RandomState(random_seed)
223 | self.pairs = pairs
224 | self.shuffle = shuffle
225 | self.pad = pad
226 | self.causal_shift = causal_shift
227 | self.endless = endless
228 | self.state = None
229 |
230 | self.cameras = cameras
231 | self.poses_3d = poses_3d
232 | self.poses_2d = poses_2d
233 |
234 | self.augment = augment
235 | self.kps_left = kps_left
236 | self.kps_right = kps_right
237 | self.joints_left = joints_left
238 | self.joints_right = joints_right
239 |
240 | def num_frames(self):
241 | return self.num_batches * self.batch_size
242 |
243 | def random_state(self):
244 | return self.random
245 |
246 | def set_random_state(self, random):
247 | self.random = random
248 |
249 | def augment_enabled(self):
250 | return self.augment
251 |
252 | def next_pairs(self):
253 | if self.state is None:
254 | if self.shuffle:
255 | pairs = self.random.permutation(self.pairs)
256 | else:
257 | pairs = self.pairs
258 | return 0, pairs
259 | else:
260 | return self.state
261 |
262 | def next_epoch(self):
263 | enabled = True
264 | while enabled:
265 | start_idx, pairs = self.next_pairs()
266 | for b_i in range(start_idx, self.num_batches):
267 | chunks = pairs[b_i * self.batch_size: (b_i + 1) * self.batch_size]
268 | for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
269 | start_2d = start_3d - self.pad - self.causal_shift
270 | end_2d = end_3d + self.pad - self.causal_shift
271 |
272 | # 2D poses
273 | seq_2d = self.poses_2d[seq_i]
274 | low_2d = max(start_2d, 0)
275 | high_2d = min(end_2d, seq_2d.shape[0])
276 | pad_left_2d = low_2d - start_2d
277 | pad_right_2d = end_2d - high_2d
278 | if pad_left_2d != 0 or pad_right_2d != 0:
279 | self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)),
280 | 'edge')
281 | if self.augment:
282 | self.batch_2d_flip[i] = np.pad(seq_2d[low_2d:high_2d],
283 | ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)),
284 | 'edge')
285 |
286 | else:
287 | self.batch_2d[i] = seq_2d[low_2d:high_2d]
288 | if self.augment:
289 | self.batch_2d_flip[i] = seq_2d[low_2d:high_2d]
290 |
291 | if self.augment:
292 | self.batch_2d_flip[i, :, :, 0] *= -1
293 | self.batch_2d_flip[i, :, self.kps_left + self.kps_right] = self.batch_2d_flip[i, :, self.kps_right + self.kps_left]
294 |
295 | # 3D poses
296 | if self.poses_3d is not None:
297 | seq_3d = self.poses_3d[seq_i]
298 | low_3d = max(start_3d, 0)
299 | high_3d = min(end_3d, seq_3d.shape[0])
300 | pad_left_3d = low_3d - start_3d
301 | pad_right_3d = end_3d - high_3d
302 | if pad_left_3d != 0 or pad_right_3d != 0:
303 | self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d],
304 | ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
305 | else:
306 | self.batch_3d[i] = seq_3d[low_3d:high_3d]
307 |
308 | if flip:
309 | self.batch_3d[i, :, :, 0] *= -1
310 | self.batch_3d[i, :, self.joints_left + self.joints_right] = \
311 | self.batch_3d[i, :, self.joints_right + self.joints_left]
312 |
313 | # Cameras
314 | if self.cameras is not None:
315 | self.batch_cam[i] = self.cameras[seq_i]
316 | if flip:
317 | # Flip horizontal distortion coefficients
318 | self.batch_cam[i, 2] *= -1
319 | self.batch_cam[i, 7] *= -1
320 |
321 | if self.endless:
322 | self.state = (b_i + 1, pairs)
323 |
324 | if self.augment:
325 | if self.poses_3d is None and self.cameras is None:
326 | yield None, None, self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)]
327 | elif self.poses_3d is not None and self.cameras is None:
328 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)]
329 | elif self.poses_3d is None:
330 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)]
331 | else:
332 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)], self.batch_2d_flip[:len(chunks)]
333 | else:
334 | if self.poses_3d is None and self.cameras is None:
335 | yield None, None, self.batch_2d[:len(chunks)]
336 | elif self.poses_3d is not None and self.cameras is None:
337 | yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
338 | elif self.poses_3d is None:
339 | yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
340 | else:
341 | yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
342 |
343 | if self.endless:
344 | self.state = None
345 | else:
346 | enabled = False
347 |
348 |
349 |
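# Usage sketch (added for illustration; it mirrors how run.py is assumed to build the
# training generator, and the variable names below are hypothetical). `pad` compensates
# for the receptive field, e.g. a 243-frame model would use pad = (243 - 1) // 2 = 121.
#
# train_generator = ChunkedGenerator(
#     batch_size=1024, cameras=None, poses_3d=poses_train_3d, poses_2d=poses_train_2d,
#     chunk_length=1, pad=121, causal_shift=0, shuffle=True,
#     augment=True, kps_left=kps_left, kps_right=kps_right,
#     joints_left=joints_left, joints_right=joints_right)
# for cam, batch_3d, batch_2d in train_generator.next_epoch():
#     ...  # feed batch_2d to the model, supervise with batch_3d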
--------------------------------------------------------------------------------
/common/h36m_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | import copy
10 | from common.skeleton import Skeleton
11 | from common.mocap_dataset import MocapDataset
12 | from common.camera import normalize_screen_coordinates, image_coordinates
13 |
14 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
15 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
16 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
17 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
18 |
19 | h36m_cameras_intrinsic_params = [
20 | {
21 | 'id': '54138969',
22 | 'center': [512.54150390625, 515.4514770507812],
23 | 'focal_length': [1145.0494384765625, 1143.7811279296875],
24 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043],
25 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235],
26 | 'res_w': 1000,
27 | 'res_h': 1002,
28 | 'azimuth': 70, # Only used for visualization
29 | },
30 | {
31 | 'id': '55011271',
32 | 'center': [508.8486328125, 508.0649108886719],
33 | 'focal_length': [1149.6756591796875, 1147.5916748046875],
34 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665],
35 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233],
36 | 'res_w': 1000,
37 | 'res_h': 1000,
38 | 'azimuth': -70, # Only used for visualization
39 | },
40 | {
41 | 'id': '58860488',
42 | 'center': [519.8158569335938, 501.40264892578125],
43 | 'focal_length': [1149.1407470703125, 1148.7989501953125],
44 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427],
45 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998],
46 | 'res_w': 1000,
47 | 'res_h': 1000,
48 | 'azimuth': 110, # Only used for visualization
49 | },
50 | {
51 | 'id': '60457274',
52 | 'center': [514.9682006835938, 501.88201904296875],
53 | 'focal_length': [1145.5113525390625, 1144.77392578125],
54 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783],
55 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643],
56 | 'res_w': 1000,
57 | 'res_h': 1002,
58 | 'azimuth': -110, # Only used for visualization
59 | },
60 | ]
61 |
62 | h36m_cameras_extrinsic_params = {
63 | 'S1': [
64 | {
65 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088],
66 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125],
67 | },
68 | {
69 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205],
70 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375],
71 | },
72 | {
73 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696],
74 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375],
75 | },
76 | {
77 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435],
78 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125],
79 | },
80 | ],
81 | 'S2': [
82 | {},
83 | {},
84 | {},
85 | {},
86 | ],
87 | 'S3': [
88 | {},
89 | {},
90 | {},
91 | {},
92 | ],
93 | 'S4': [
94 | {},
95 | {},
96 | {},
97 | {},
98 | ],
99 | 'S5': [
100 | {
101 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332],
102 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875],
103 | },
104 | {
105 | 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915],
106 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125],
107 | },
108 | {
109 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576],
110 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875],
111 | },
112 | {
113 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092],
114 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125],
115 | },
116 | ],
117 | 'S6': [
118 | {
119 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938],
120 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875],
121 | },
122 | {
123 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428],
124 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375],
125 | },
126 | {
127 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334],
128 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125],
129 | },
130 | {
131 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802],
132 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625],
133 | },
134 | ],
135 | 'S7': [
136 | {
137 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778],
138 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625],
139 | },
140 | {
141 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508],
142 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125],
143 | },
144 | {
145 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278],
146 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375],
147 | },
148 | {
149 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523],
150 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125],
151 | },
152 | ],
153 | 'S8': [
154 | {
155 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773],
156 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375],
157 | },
158 | {
159 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615],
160 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125],
161 | },
162 | {
163 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755],
164 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375],
165 | },
166 | {
167 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708],
168 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875],
169 | },
170 | ],
171 | 'S9': [
172 | {
173 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243],
174 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625],
175 | },
176 | {
177 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801],
178 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125],
179 | },
180 | {
181 | 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915],
182 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125],
183 | },
184 | {
185 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822],
186 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625],
187 | },
188 | ],
189 | 'S11': [
190 | {
191 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467],
192 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125],
193 | },
194 | {
195 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085],
196 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125],
197 | },
198 | {
199 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407],
200 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625],
201 | },
202 | {
203 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809],
204 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125],
205 | },
206 | ],
207 | }
208 |
209 | class Human36mDataset(MocapDataset):
210 | def __init__(self, path, remove_static_joints=True):
211 | super().__init__(fps=50, skeleton=h36m_skeleton)
212 |
213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params)
214 | for cameras in self._cameras.values():
215 | for i, cam in enumerate(cameras):
216 | cam.update(h36m_cameras_intrinsic_params[i])
217 | for k, v in cam.items():
218 | if k not in ['id', 'res_w', 'res_h']:
219 | cam[k] = np.array(v, dtype='float32')
220 |
221 | # Normalize camera frame
222 | cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype('float32')
223 | cam['focal_length'] = cam['focal_length']/cam['res_w']*2
224 | if 'translation' in cam:
225 | cam['translation'] = cam['translation']/1000 # mm to meters
226 |
227 | # Add intrinsic parameters vector
228 | cam['intrinsic'] = np.concatenate((cam['focal_length'],
229 | cam['center'],
230 | cam['radial_distortion'],
231 | cam['tangential_distortion']))
232 |
233 | # Load serialized dataset
234 | data = np.load(path, allow_pickle=True)['positions_3d'].item()
235 |
236 | self._data = {}
237 | for subject, actions in data.items():
238 | self._data[subject] = {}
239 | for action_name, positions in actions.items():
240 | self._data[subject][action_name] = {
241 | 'positions': positions,
242 | 'cameras': self._cameras[subject],
243 | }
244 |
245 | if remove_static_joints:
246 | # Bring the skeleton to 17 joints instead of the original 32
247 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])
248 |
249 | # Rewire shoulders to the correct parents
250 | self._skeleton._parents[11] = 8
251 | self._skeleton._parents[14] = 8
252 |
253 | def supports_semi_supervised(self):
254 | return True
255 |
256 | def define_actions(self):
257 | all_actions = ["Directions",
258 | "Discussion",
259 | "Eating",
260 | "Greeting",
261 | "Phoning",
262 | "Photo",
263 | "Posing",
264 | "Purchases",
265 | "Sitting",
266 | "SittingDown",
267 | "Smoking",
268 | "Waiting",
269 | "WalkDog",
270 | "Walking",
271 | "WalkTogether"]
272 |
273 | return all_actions
274 |
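# Loading sketch (added for illustration, not part of the original file). It assumes
# the serialized archive produced by the VideoPose3D preprocessing is available at
# data/data_3d_h36m.npz (the exact file name depends on your setup).
#
# dataset = Human36mDataset('data/data_3d_h36m.npz')
# for subject in dataset.subjects():
#     for action, anim in dataset[subject].items():
#         positions, cameras = anim['positions'], anim['cameras']
#         # positions: (num_frames, 17, 3) world-space joints after remove_static_joints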
--------------------------------------------------------------------------------
/common/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | def mpjpe(predicted, target):
12 | """
13 | Mean per-joint position error (i.e. mean Euclidean distance),
14 | often referred to as "Protocol #1" in many papers.
15 | """
16 | assert predicted.shape == target.shape
17 | return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1))
18 |
19 |
20 | def p_mpjpe(predicted, target):
21 | """
22 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation),
23 | often referred to as "Protocol #2" in many papers.
24 | """
25 | assert predicted.shape == target.shape
26 |
27 | muX = np.mean(target, axis=1, keepdims=True)
28 | muY = np.mean(predicted, axis=1, keepdims=True)
29 |
30 | X0 = target - muX
31 | Y0 = predicted - muY
32 |
33 | normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True))
34 | normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True))
35 |
36 | X0 /= normX
37 | Y0 /= normY
38 |
39 | H = np.matmul(X0.transpose(0, 2, 1), Y0)
40 | U, s, Vt = np.linalg.svd(H)
41 | V = Vt.transpose(0, 2, 1)
42 | R = np.matmul(V, U.transpose(0, 2, 1))
43 |
44 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
45 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1))
46 | V[:, :, -1] *= sign_detR
47 | s[:, -1] *= sign_detR.flatten()
48 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation
49 |
50 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)
51 |
52 | a = tr * normX / normY # Scale
53 | t = muX - a*np.matmul(muY, R) # Translation
54 |
55 | # Perform rigid transformation on the input
56 | predicted_aligned = a*np.matmul(predicted, R) + t
57 |
58 | # Return MPJPE
59 | return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1))
60 |
61 |
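if __name__ == '__main__':
    # Sanity-check sketch (added for illustration): a rigidly transformed pose should
    # have near-zero P-MPJPE (Protocol #2) while its plain MPJPE (Protocol #1) stays large.
    rng = np.random.RandomState(0)
    target = rng.randn(1, 17, 3)
    R, _ = np.linalg.qr(rng.randn(3, 3))          # random orthogonal matrix
    predicted = 0.5 * np.matmul(target, R) + 1.0  # scaled, rotated and translated copy
    print('P-MPJPE:', p_mpjpe(predicted, target))  # ~0 after rigid alignment
    print('MPJPE:  ', mpjpe(torch.from_numpy(predicted), torch.from_numpy(target)).item())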
--------------------------------------------------------------------------------
/common/mocap_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 | from common.skeleton import Skeleton
10 |
11 | class MocapDataset:
12 | def __init__(self, fps, skeleton):
13 | self._skeleton = skeleton
14 | self._fps = fps
15 | self._data = None # Must be filled by subclass
16 | self._cameras = None # Must be filled by subclass
17 |
18 | def remove_joints(self, joints_to_remove):
19 | kept_joints = self._skeleton.remove_joints(joints_to_remove)
20 | for subject in self._data.keys():
21 | for action in self._data[subject].keys():
22 | s = self._data[subject][action]
23 | s['positions'] = s['positions'][:, kept_joints]
24 |
25 |
26 | def __getitem__(self, key):
27 | return self._data[key]
28 |
29 | def subjects(self):
30 | return self._data.keys()
31 |
32 | def fps(self):
33 | return self._fps
34 |
35 | def skeleton(self):
36 | return self._skeleton
37 |
38 | def cameras(self):
39 | return self._cameras
40 |
41 | def supports_semi_supervised(self):
42 | # This method can be overridden
43 | return False
--------------------------------------------------------------------------------
/common/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class DWConv(nn.Module):
7 | def __init__(self, in_features, out_features, kernel_size=3, stride=3):
8 | super(DWConv, self).__init__()
9 | self.DW_conv = nn.Conv1d(in_features, in_features, kernel_size=kernel_size, stride=stride,
10 | groups=in_features, bias=False)
11 | self.DW_bn = nn.BatchNorm1d(in_features, momentum=0.1)
12 | self.PW_conv = nn.Conv1d(in_features, out_features, kernel_size=1, stride=1, bias=False)
13 | self.PW_bn = nn.BatchNorm1d(out_features, momentum=0.1)
14 |
15 | def forward(self, x):
16 | x = self.DW_conv(x)
17 | x = self.DW_bn(x)
18 | x = self.PW_conv(x)
19 | x = self.PW_bn(x)
20 | return x
21 |
22 |
23 | class Kernel_Attention(nn.Module):
24 | def __init__(self, in_features, out_features=1024, M=3, G=8, r=128, stride=3):
25 | super(Kernel_Attention, self).__init__()
26 | self.convs = nn.ModuleList([])
27 |
28 | for i in range(M):
29 | self.convs.append(nn.Sequential(
30 | nn.Conv1d(in_features, in_features, kernel_size=3, dilation=i + 1, stride=stride, padding=0,
31 | groups=in_features, bias=False),
32 | nn.BatchNorm1d(in_features),
33 | nn.Conv1d(in_features, out_features, kernel_size=1, stride=1, padding=0, groups=G, bias=False),
34 | nn.BatchNorm1d(out_features),
35 | Mish()
36 | ))
37 | self.fc = nn.Linear(out_features, r)
38 |
39 | self.fcs = nn.ModuleList([])
40 | for i in range(M):
41 | self.fcs.append(
42 | nn.Linear(r, out_features)
43 | )
44 | self.softmax = nn.Softmax(dim=1)
45 |
46 | def forward(self, x):
47 | for i, conv in enumerate(self.convs):
48 | if i == 0:
49 | fea = conv(x).unsqueeze_(dim=1)
50 | feas = fea
51 | else:
52 | fea = F.pad(x, (i, i), 'replicate')
53 | fea = conv(fea).unsqueeze_(dim=1)
54 | feas = torch.cat([feas, fea], dim=1)
55 | fea_U = torch.sum(feas, dim=1)
56 | fea_s = fea_U.mean(-1)
57 | fea_z = self.fc(fea_s)
58 | for i, fc in enumerate(self.fcs):
59 | vector = fc(fea_z).unsqueeze_(dim=1)
60 | if i == 0:
61 | attention_vectors = vector
62 | else:
63 | attention_vectors = torch.cat([attention_vectors, vector], dim=1)
64 | attention_vectors = self.softmax(attention_vectors)
65 | attention_vectors = attention_vectors.unsqueeze(-1)
66 | fea_v = (feas * attention_vectors).sum(dim=1)
67 | return fea_v
68 |
69 | class Mish(nn.Module):
70 | def __init__(self):
71 | super().__init__()
72 |
73 | def forward(self, x):
74 | x = x * (torch.tanh(F.softplus(x)))
75 | return x
76 |
77 |
78 | class TemporalModelBase(nn.Module):
79 | """
80 | Do not instantiate this class.
81 | """
82 |
83 | def __init__(self, num_joints_in, in_features, num_joints_out,
84 | filter_widths, causal, dropout, channels):
85 | super().__init__()
86 |
87 | # Validate input
88 | for fw in filter_widths:
89 | assert fw % 2 != 0, 'Only odd filter widths are supported'
90 |
91 | self.num_joints_in = num_joints_in
92 | self.in_features = in_features
93 | self.num_joints_out = num_joints_out
94 | self.filter_widths = filter_widths
95 |
96 | self.drop = nn.Dropout(dropout)
97 | self.relu = Mish()
98 | self.sigmoid = nn.Sigmoid()
99 |
100 | self.pad = [filter_widths[0] // 2]
101 | self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1)
102 |
103 | def set_bn_momentum(self, momentum):
104 | for bn in self.layers_bn:
105 | bn.momentum = momentum
106 | for bn in self.layers_tem_bn:
107 | bn.momentum = momentum
108 |
109 | def receptive_field(self):
110 | """
111 | Return the total receptive field of this model as # of frames.
112 | """
113 | frames = 0
114 | for f in self.pad:
115 | frames += f
116 | return 1 + 2 * frames
117 |
118 | def total_causal_shift(self):
119 | """
120 | Return the asymmetric offset for sequence padding.
121 | The returned value is typically 0 if causal convolutions are disabled,
122 | otherwise it is half the receptive field.
123 | """
124 | frames = self.causal_shift[0]
125 | next_dilation = self.filter_widths[0]
126 | for i in range(1, len(self.filter_widths)):
127 | frames += self.causal_shift[i] * next_dilation
128 | next_dilation *= self.filter_widths[i]
129 | return frames
130 |
131 | def forward(self, x):
132 | assert len(x.shape) == 4
133 | assert x.shape[-2] == self.num_joints_in
134 | assert x.shape[-1] == self.in_features
135 |
136 | sz = x.shape[:3]
137 | mean = x[:, :, 0:1, :].expand_as(x)
138 | input_pose_centered = x - mean
139 |
140 | x = x.view(x.shape[0], x.shape[1], -1)
141 | x = x.permute(0, 2, 1)
142 |
143 | input_pose_centered = input_pose_centered.view(input_pose_centered.shape[0], input_pose_centered.shape[1], -1)
144 | input_pose_centered = input_pose_centered.permute(0, 2, 1)
145 |
146 | x = self._forward_blocks(x, input_pose_centered)
147 |
148 | x = x.permute(0, 2, 1)
149 | x = x.view(sz[0], -1, self.num_joints_out, 3)
150 |
151 | return x
152 |
153 |
154 | class TemporalModelOptimized1f(TemporalModelBase):
155 | """
156 | 3D pose estimation model optimized for single-frame batching, i.e.
157 | where batches have input length = receptive field, and output length = 1.
158 | This scenario is only used for training when stride == 1.
159 |
160 | This implementation replaces dilated convolutions with strided convolutions
161 | to avoid generating unused intermediate results. The weights are interchangeable
162 | with the reference implementation.
163 | """
164 |
165 | def __init__(self, num_joints_in, in_features, num_joints_out,
166 | filter_widths, causal=False, dropout=0.2, channels=1024, dense=False):
167 | """
168 | Initialize this model.
169 |
170 | Arguments:
171 | num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
172 | in_features -- number of input features for each joint (typically 2 for 2D input)
173 | num_joints_out -- number of output joints (can be different than input)
174 | filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
175 | causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
176 | dropout -- dropout probability
177 | channels -- number of convolution channels
178 | """
179 | super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)
180 |
181 | expand_conv = []
182 | for i in range(len(filter_widths) - 1):
183 | expand_conv.append(DWConv(num_joints_in * in_features, channels,
184 | kernel_size=filter_widths[0], stride=filter_widths[0]))
185 | self.expand_conv = nn.ModuleList(expand_conv)
186 |
187 | self.cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6)
188 | layers_tem_att = []
189 | layers_tem_bn = []
190 | self.frames = self.total_frame()
191 |
192 | layers_conv = []
193 | layers_bn = []
194 |
195 | self.causal_shift = [(filter_widths[0] // 2) if causal else 0]
196 | next_dilation = filter_widths[0]
197 |
198 | dilation_conv = []
199 | dilation_bn = []
200 |
201 | for i in range(3):
202 | dilation_conv.append(DWConv(channels, channels, kernel_size=filter_widths[i], stride=filter_widths[i]))
203 | dilation_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
204 | dilation_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
205 |
206 | self.dilation_conv = nn.ModuleList(dilation_conv)
207 | self.dilation_bn = nn.ModuleList(dilation_bn)
208 |
209 | for i in range(1, len(filter_widths)):
210 | self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
211 | self.causal_shift.append((filter_widths[i] // 2) if causal else 0)
212 |
213 | layers_tem_att.append(nn.Linear(self.frames, self.frames // next_dilation))
214 | layers_tem_bn.append(nn.BatchNorm1d(self.frames // next_dilation))
215 |
216 | layers_conv.append(Kernel_Attention(channels, out_features=channels))
217 | layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
218 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
219 |
220 | next_dilation *= filter_widths[i]
221 |
222 | self.layers_conv = nn.ModuleList(layers_conv)
223 | self.layers_bn = nn.ModuleList(layers_bn)
224 | self.layers_tem_att = nn.ModuleList(layers_tem_att)
225 | self.layers_tem_bn = nn.ModuleList(layers_tem_bn)
226 |
227 | def set_KA_bn(self, momentum):
228 | for i in range(len(self.layers_conv) // 2):
229 | for j in range(3):
230 | self.layers_conv[2 * i].convs[j][1].momentum = momentum
231 | self.layers_conv[2 * i].convs[j][3].momentum = momentum
232 |
233 | def set_expand_bn(self, momentum):
234 | for i in range(len(self.expand_conv)):
235 | self.expand_conv[i].DW_bn.momentum = momentum
236 | self.expand_conv[i].PW_bn.momentum = momentum
237 |
238 | def set_dilation_bn(self, momentum):
239 | for bn in self.dilation_bn:
240 | bn.momentum = momentum
241 | for i in range(len(self.dilation_conv)//2):
242 | self.dilation_conv[2*i].DW_bn.momentum = momentum
243 | self.dilation_conv[2*i].PW_bn.momentum = momentum
244 |
245 | def total_frame(self):
246 | frames = 1
247 | for i in range(len(self.filter_widths)):
248 | frames *= self.filter_widths[i]
249 | return frames
250 |
251 | def _forward_blocks(self, x, input_2D_centered):
252 | b, c, t = input_2D_centered.size()
253 | x_target = input_2D_centered[:, :, input_2D_centered.shape[2] // 2]
254 | x_target_extend = x_target.view(b, c, 1)
255 | x_target_matrix = x_target_extend.expand_as(input_2D_centered)
256 | cos_score = self.cos_dis(x_target_matrix, input_2D_centered)
257 |
258 | '''
259 | Top layers
260 | '''
261 | x_0_1 = x[:, :, 1::3]
262 | x_0_2 = x[:, :, 4::9]
263 | x_0_3 = x[:, :, 13::27]
264 |
265 | x = self.drop(self.relu(self.expand_conv[0](x)))
266 | x_0_1 = self.drop(self.relu(self.expand_conv[1](x_0_1)))
267 | x_0_2 = self.drop(self.relu(self.expand_conv[2](x_0_2)))
268 | x_0_3 = self.drop(self.relu(self.expand_conv[3](x_0_3)))
269 |
270 | for i in range(len(self.pad) - 1):
271 | res = x[:, :, self.causal_shift[i + 1] + self.filter_widths[i + 1] // 2:: self.filter_widths[i + 1]]
272 | t_attention = self.sigmoid(self.layers_tem_bn[i](self.layers_tem_att[i](cos_score))) # [batches frames]
273 | t_attention_expand = t_attention.unsqueeze(1) # [batches channels frames]
274 | if i == 0:
275 | res_1_1 = res[:, :, 1::3]
276 | res_1_2 = res[:, :, 4::9]
277 | x = x * t_attention_expand # broadcasting dot mul
278 | x_1_1 = x[:, :, 1::3]
279 | x_1_2 = x[:, :, 4::9]
280 |
281 | x = self.drop(self.layers_conv[2 * i](x))
282 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x))))
283 |
284 | x_1_1 = self.drop(self.relu(self.dilation_conv[0](x_1_1)))
285 | x_1_1 = res_1_1 + self.drop(self.relu(self.dilation_bn[0](self.dilation_conv[1](x_1_1))))
286 |
287 | x_1_2 = self.drop(self.relu(self.dilation_conv[2](x_1_2)))
288 | x_1_2 = res_1_2 + self.drop(self.relu(self.dilation_bn[1](self.dilation_conv[3](x_1_2))))
289 |
290 | elif i == 1:
291 | res_2_1 = res[:, :, 1::3]
292 | x = x * t_attention_expand # broadcasting dot mul
293 | x_2_1 = x[:, :, 1::3]
294 | x_0_1 = x_0_1 * t_attention_expand # broadcasting dot mul
295 | x = x + x_0_1
296 |
297 | x = self.drop(self.layers_conv[2 * i](x))
298 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x))))
299 |
300 | x_2_1 = self.drop(self.relu(self.dilation_conv[4](x_2_1)))
301 | x_2_1 = res_2_1 + self.drop(self.relu(self.dilation_bn[2](self.dilation_conv[5](x_2_1))))
302 |
303 | elif i == 2:
304 | x = x + x_0_2 + x_1_1
305 | x = x * t_attention_expand # broadcasting dot mul
306 | x = self.drop(self.layers_conv[2 * i](x))
307 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x))))
308 | elif i == 3:
309 | x = x + x_0_3 + x_1_2 + x_2_1
310 | x = x * t_attention_expand # broadcasting dot mul
311 | x = self.drop(self.layers_conv[2 * i](x))
312 | x = res + self.drop(self.relu(self.layers_bn[i](self.layers_conv[2 * i + 1](x))))
313 |
314 | x = self.shrink(x)
315 | return x
316 |
317 |
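# Instantiation sketch (added for illustration; hyper-parameters follow the README's
# 243-frame configuration, i.e. -arc 3,3,3,3,3 with 17 Human3.6M joints and 2D inputs).
#
# model = TemporalModelOptimized1f(num_joints_in=17, in_features=2, num_joints_out=17,
#                                  filter_widths=[3, 3, 3, 3, 3], causal=False,
#                                  dropout=0.2, channels=1024)
# x = torch.randn(8, model.receptive_field(), 17, 2)  # (batch, 243 frames, joints, 2D coords)
# out = model(x)                                       # -> (8, 1, 17, 3) predicted 3D pose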
--------------------------------------------------------------------------------
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 |
10 | def qrot(q, v):
11 | """
12 | Rotate vector(s) v about the rotation described by quaternion(s) q.
13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
14 | where * denotes any number of dimensions.
15 | Returns a tensor of shape (*, 3).
16 | """
17 | assert q.shape[-1] == 4
18 | assert v.shape[-1] == 3
19 | assert q.shape[:-1] == v.shape[:-1]
20 |
21 | qvec = q[..., 1:]
22 | uv = torch.cross(qvec, v, dim=len(q.shape)-1)
23 | uuv = torch.cross(qvec, uv, dim=len(q.shape)-1)
24 | return (v + 2 * (q[..., :1] * uv + uuv))
25 |
26 |
27 | def qinverse(q, inplace=False):
28 | # We assume the quaternion to be normalized
29 | if inplace:
30 | q[..., 1:] *= -1
31 | return q
32 | else:
33 | w = q[..., :1]
34 | xyz = q[..., 1:]
35 | return torch.cat((w, -xyz), dim=len(q.shape)-1)
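
# Quick check (added for illustration): the identity quaternion [1, 0, 0, 0] leaves a
# vector unchanged, and rotating by q and then by qinverse(q) recovers the input.
#
# q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])   # (w, x, y, z)
# v = torch.tensor([[1.0, 2.0, 3.0]])
# assert torch.allclose(qrot(q, v), v)
# assert torch.allclose(qrot(qinverse(q), qrot(q, v)), v)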
--------------------------------------------------------------------------------
/common/ranger.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer #, required
4 | import itertools as it
5 |
6 | class Ranger(Optimizer):
7 |
8 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0):
9 | #parameter checks
10 | if not 0.0 <= alpha <= 1.0:
11 | raise ValueError(f'Invalid slow update rate: {alpha}')
12 | if not 1 <= k:
13 | raise ValueError(f'Invalid lookahead steps: {k}')
14 | if not lr > 0:
15 | raise ValueError(f'Invalid Learning Rate: {lr}')
16 | if not eps > 0:
17 | raise ValueError(f'Invalid eps: {eps}')
18 |
19 | #parameter comments:
20 | # beta1 (momentum) of .95 seems to work better than .90...
21 | # N_sma_threshold of 5 seems better in testing than 4.
22 | # In both cases, it is worth testing on your dataset (.90 vs .95, 4 vs 5) to see which works best for you.
23 |
24 | #prep defaults and init torch.optim base
25 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay, amsgrad=True)
26 | super().__init__(params,defaults)
27 |
28 | #adjustable threshold
29 | self.N_sma_threshhold = N_sma_threshhold
30 |
31 | #now we can get to work...
32 | #removed as we now use step from RAdam...no need for duplicate step counting
33 | #for group in self.param_groups:
34 | # group["step_counter"] = 0
35 | #print("group step counter init")
36 |
37 | #look ahead params
38 | self.alpha = alpha
39 | self.k = k
40 |
41 | #radam buffer for state
42 | self.radam_buffer = [[None,None,None] for ind in range(10)]
43 |
44 | #self.first_run_check=0
45 |
46 | #lookahead weights
47 | #9/2/19 - lookahead param tensors have been moved to state storage.
48 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
49 |
50 | #self.slow_weights = [[p.clone().detach() for p in group['params']]
51 | # for group in self.param_groups]
52 |
53 | #don't use grad for lookahead weights
54 | #for w in it.chain(*self.slow_weights):
55 | # w.requires_grad = False
56 |
57 | def __setstate__(self, state):
58 | print("set state called")
59 | super(Ranger, self).__setstate__(state)
60 |
61 |
62 | def step(self, closure=None):
63 | loss = None
64 | # Note: the closure call below is commented out because the surrounding training code passes the loss back as a float, not a callable closure.
65 | # Uncomment it if you need to use an actual closure...
66 |
67 | #if closure is not None:
68 | #loss = closure()
69 |
70 | #Evaluate averages and grad, update param tensors
71 | for group in self.param_groups:
72 |
73 | for p in group['params']:
74 | if p.grad is None:
75 | continue
76 | grad = p.grad.data.float()
77 | if grad.is_sparse:
78 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
79 |
80 | p_data_fp32 = p.data.float()
81 |
82 | state = self.state[p] #get state dict for this param
83 |
84 | if len(state) == 0: #if first time to run...init dictionary with our desired entries
85 | #if self.first_run_check==0:
86 | #self.first_run_check=1
87 | #print("Initializing slow buffer...should not see this at load from saved model!")
88 | state['step'] = 0
89 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
90 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
91 |
92 | #look ahead weight storage now in state dict
93 | state['slow_buffer'] = torch.empty_like(p.data)
94 | state['slow_buffer'].copy_(p.data)
95 |
96 | else:
97 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
98 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
99 |
100 | #begin computations
101 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
102 | beta1, beta2 = group['betas']
103 |
104 | #compute variance mov avg
105 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
106 | #compute mean moving avg
107 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
108 |
109 | state['step'] += 1
110 |
111 |
112 | buffered = self.radam_buffer[int(state['step'] % 10)]
113 | if state['step'] == buffered[0]:
114 | N_sma, step_size = buffered[1], buffered[2]
115 | else:
116 | buffered[0] = state['step']
117 | beta2_t = beta2 ** state['step']
118 | N_sma_max = 2 / (1 - beta2) - 1
119 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
120 | buffered[1] = N_sma
121 | if N_sma > self.N_sma_threshhold:
122 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
123 | else:
124 | step_size = 1.0 / (1 - beta1 ** state['step'])
125 | buffered[2] = step_size
126 |
127 | if group['weight_decay'] != 0:
128 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
129 |
130 | if N_sma > self.N_sma_threshhold:
131 | denom = exp_avg_sq.sqrt().add_(group['eps'])
132 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
133 | else:
134 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
135 |
136 | p.data.copy_(p_data_fp32)
137 |
138 | #integrated look ahead...
139 | #we do it at the param level instead of group level
140 | if state['step'] % group['k'] == 0:
141 | slow_p = state['slow_buffer'] #get access to slow param tensor
142 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha
143 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
144 |
145 | return loss
146 |
--------------------------------------------------------------------------------
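Ranger combines RAdam with Lookahead: every k optimizer steps the fast weights are interpolated towards a slowly updated copy with rate alpha. A minimal usage sketch on a toy regression problem follows; the model and data are placeholders, not part of the training pipeline.

    import torch
    import torch.nn as nn
    from common.ranger import Ranger

    model = nn.Linear(10, 1)
    optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)

    x, y = torch.randn(64, 10), torch.randn(64, 1)
    for _ in range(100):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()    # every 6th step also blends in the slow (lookahead) weights
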
/common/skeleton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 |
10 | class Skeleton:
11 | def __init__(self, parents, joints_left, joints_right):
12 | assert len(joints_left) == len(joints_right)
13 |
14 | self._parents = np.array(parents)
15 | self._joints_left = joints_left
16 | self._joints_right = joints_right
17 | self._compute_metadata()
18 |
19 | def num_joints(self):
20 | return len(self._parents)
21 |
22 | def parents(self):
23 | return self._parents
24 |
25 | def has_children(self):
26 | return self._has_children
27 |
28 | def children(self):
29 | return self._children
30 |
31 | def remove_joints(self, joints_to_remove):
32 | """
33 | Remove the joints specified in 'joints_to_remove'.
34 | """
35 | valid_joints = []
36 | for joint in range(len(self._parents)):
37 | if joint not in joints_to_remove:
38 | valid_joints.append(joint)
39 |
40 | for i in range(len(self._parents)):
41 | while self._parents[i] in joints_to_remove:
42 | self._parents[i] = self._parents[self._parents[i]]
43 |
44 | index_offsets = np.zeros(len(self._parents), dtype=int)
45 | new_parents = []
46 | for i, parent in enumerate(self._parents):
47 | if i not in joints_to_remove:
48 | new_parents.append(parent - index_offsets[parent])
49 | else:
50 | index_offsets[i:] += 1
51 | self._parents = np.array(new_parents)
52 |
53 |
54 | if self._joints_left is not None:
55 | new_joints_left = []
56 | for joint in self._joints_left:
57 | if joint in valid_joints:
58 | new_joints_left.append(joint - index_offsets[joint])
59 | self._joints_left = new_joints_left
60 | if self._joints_right is not None:
61 | new_joints_right = []
62 | for joint in self._joints_right:
63 | if joint in valid_joints:
64 | new_joints_right.append(joint - index_offsets[joint])
65 | self._joints_right = new_joints_right
66 |
67 | self._compute_metadata()
68 |
69 | return valid_joints
70 |
71 | def joints_left(self):
72 | return self._joints_left
73 |
74 | def joints_right(self):
75 | return self._joints_right
76 |
77 | def _compute_metadata(self):
78 | self._has_children = np.zeros(len(self._parents)).astype(bool)
79 | for i, parent in enumerate(self._parents):
80 | if parent != -1:
81 | self._has_children[parent] = True
82 |
83 | self._children = []
84 | for i, parent in enumerate(self._parents):
85 | self._children.append([])
86 | for i, parent in enumerate(self._parents):
87 | if parent != -1:
88 | self._children[parent].append(i)
--------------------------------------------------------------------------------
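A short sketch of how remove_joints behaves: removed joints are skipped, their children are re-parented to the nearest surviving ancestor, and the remaining indices are compacted. The 5-joint chain below is a toy example, not the Human3.6M layout.

    from common.skeleton import Skeleton

    # Toy kinematic chain 0 -> 1 -> 2 -> 3 -> 4.
    skel = Skeleton(parents=[-1, 0, 1, 2, 3], joints_left=[1], joints_right=[2])

    kept = skel.remove_joints([3])   # drop joint 3
    print(kept)                      # [0, 1, 2, 4]
    print(skel.parents())            # [-1  0  1  2]: old joint 4 is now a child of joint 2
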
/common/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 | import hashlib
11 |
12 | def wrap(func, *args, unsqueeze=False):
13 | """
14 | Wrap a torch function so it can be called with NumPy arrays.
15 | Input and return types are seamlessly converted.
16 | """
17 |
18 | # Convert input types where applicable
19 | args = list(args)
20 | for i, arg in enumerate(args):
21 | if type(arg) == np.ndarray:
22 | args[i] = torch.from_numpy(arg)
23 | if unsqueeze:
24 | args[i] = args[i].unsqueeze(0)
25 |
26 | result = func(*args)
27 |
28 | # Convert output types where applicable
29 | if isinstance(result, tuple):
30 | result = list(result)
31 | for i, res in enumerate(result):
32 | if type(res) == torch.Tensor:
33 | if unsqueeze:
34 | res = res.squeeze(0)
35 | result[i] = res.numpy()
36 | return tuple(result)
37 | elif type(result) == torch.Tensor:
38 | if unsqueeze:
39 | result = result.squeeze(0)
40 | return result.numpy()
41 | else:
42 | return result
43 |
44 | def deterministic_random(min_value, max_value, data):
45 | digest = hashlib.sha256(data.encode()).digest()
46 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
47 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value
--------------------------------------------------------------------------------
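A quick sketch of wrap, which lets the torch-based quaternion helpers be called directly on NumPy arrays (optionally adding a temporary batch dimension), and of deterministic_random, which derives a reproducible integer between min_value and max_value from a string:

    import numpy as np
    from common.quaternion import qrot
    from common.utils import wrap, deterministic_random

    q = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)   # identity quaternion
    v = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    print(wrap(qrot, q, v, unsqueeze=True))                # NumPy in, NumPy out: [1. 2. 3.]
    print(deterministic_random(0, 100, 'S1 Walking'))      # same string, same result every run
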
/common/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 |
11 | import matplotlib.pyplot as plt
12 | from matplotlib.animation import FuncAnimation, writers
13 | from mpl_toolkits.mplot3d import Axes3D
14 | import numpy as np
15 | import subprocess as sp
16 |
17 | def get_resolution(filename):
18 | command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
19 | '-show_entries', 'stream=width,height', '-of', 'csv=p=0', filename]
20 | with sp.Popen(command, stdout=sp.PIPE, bufsize=-1) as pipe:
21 | for line in pipe.stdout:
22 | w, h = line.decode().strip().split(',')
23 | return int(w), int(h)
24 |
25 | def read_video(filename, skip=0, limit=-1):
26 | w, h = get_resolution(filename)
27 |
28 | command = ['ffmpeg',
29 | '-i', filename,
30 | '-f', 'image2pipe',
31 | '-pix_fmt', 'rgb24',
32 | '-vsync', '0',
33 | '-vcodec', 'rawvideo', '-']
34 |
35 | i = 0
36 | with sp.Popen(command, stdout = sp.PIPE, bufsize=-1) as pipe:
37 | while True:
38 | data = pipe.stdout.read(w*h*3)
39 | if not data:
40 | break
41 | i += 1
42 | if i > skip:
43 | yield np.frombuffer(data, dtype='uint8').reshape((h, w, 3))
44 | if i == limit:
45 | break
46 |
47 |
48 |
49 | def downsample_tensor(X, factor):
50 | length = X.shape[0]//factor * factor
51 | return np.mean(X[:length].reshape(-1, factor, *X.shape[1:]), axis=1)
52 |
53 | def render_animation(keypoints, poses, skeleton, fps, bitrate, azim, output, viewport,
54 | limit=-1, downsample=1, size=6, input_video_path=None, input_video_skip=0):
55 | """
56 | TODO
57 | Render an animation. The supported output modes are:
58 | -- 'interactive': display an interactive figure
59 | (also works on notebooks if associated with %matplotlib inline)
60 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...).
61 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg).
62 | -- 'filename.gif': render and export the animation as a gif file (requires imagemagick).
63 | """
64 | plt.ioff()
65 | fig = plt.figure(figsize=(size*(1 + len(poses)), size))
66 | ax_in = fig.add_subplot(1, 1 + len(poses), 1)
67 | ax_in.get_xaxis().set_visible(False)
68 | ax_in.get_yaxis().set_visible(False)
69 | ax_in.set_axis_off()
70 | ax_in.set_title('Input')
71 |
72 | ax_3d = []
73 | lines_3d = []
74 | trajectories = []
75 | radius = 1.7
76 | for index, (title, data) in enumerate(poses.items()):
77 | ax = fig.add_subplot(1, 1 + len(poses), index+2, projection='3d')
78 | ax.view_init(elev=15., azim=azim)
79 | ax.set_xlim3d([-radius/2, radius/2])
80 | ax.set_zlim3d([0, radius])
81 | ax.set_ylim3d([-radius/2, radius/2])
82 | ax.set_aspect('equal')
83 | ax.set_xticklabels([])
84 | ax.set_yticklabels([])
85 | ax.set_zticklabels([])
86 | ax.dist = 7.5
87 | ax.set_title(title) #, pad=35
88 | ax_3d.append(ax)
89 | lines_3d.append([])
90 | trajectories.append(data[:, 0, [0, 1]])
91 | poses = list(poses.values())
92 |
93 | # Decode video
94 | if input_video_path is None:
95 | # Black background
96 | all_frames = np.zeros((keypoints.shape[0], viewport[1], viewport[0]), dtype='uint8')
97 | else:
98 | # Load video using ffmpeg
99 | all_frames = []
100 | for f in read_video(input_video_path, skip=input_video_skip):
101 | all_frames.append(f)
102 | effective_length = min(keypoints.shape[0], len(all_frames))
103 | all_frames = all_frames[:effective_length]
104 |
105 | if downsample > 1:
106 | keypoints = downsample_tensor(keypoints, downsample)
107 | all_frames = downsample_tensor(np.array(all_frames), downsample).astype('uint8')
108 | for idx in range(len(poses)):
109 | poses[idx] = downsample_tensor(poses[idx], downsample)
110 | trajectories[idx] = downsample_tensor(trajectories[idx], downsample)
111 | fps /= downsample
112 |
113 | initialized = False
114 | image = None
115 | lines = []
116 | points = None
117 |
118 | if limit < 1:
119 | limit = len(all_frames)
120 | else:
121 | limit = min(limit, len(all_frames))
122 |
123 | parents = skeleton.parents()
124 | def update_video(i):
125 | nonlocal initialized, image, lines, points
126 |
127 | for n, ax in enumerate(ax_3d):
128 | ax.set_xlim3d([-radius/2 + trajectories[n][i, 0], radius/2 + trajectories[n][i, 0]])
129 | ax.set_ylim3d([-radius/2 + trajectories[n][i, 1], radius/2 + trajectories[n][i, 1]])
130 |
131 | # Update 2D poses
132 | if not initialized:
133 | image = ax_in.imshow(all_frames[i], aspect='equal')
134 |
135 | for j, j_parent in enumerate(parents):
136 | if j_parent == -1:
137 | continue
138 |
139 | if len(parents) == keypoints.shape[1]:
140 | # Draw skeleton only if keypoints match (otherwise we don't have the parents definition)
141 | lines.append(ax_in.plot([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
142 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]], color='pink'))
143 |
144 | col = 'red' if j in skeleton.joints_right() else 'black'
145 | for n, ax in enumerate(ax_3d):
146 | pos = poses[n][i]
147 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]],
148 | [pos[j, 1], pos[j_parent, 1]],
149 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col))
150 |
151 | points = ax_in.scatter(*keypoints[i].T, 5, color='red', edgecolors='white', zorder=10)
152 |
153 | initialized = True
154 | else:
155 | image.set_data(all_frames[i])
156 |
157 | for j, j_parent in enumerate(parents):
158 | if j_parent == -1:
159 | continue
160 |
161 | if len(parents) == keypoints.shape[1]:
162 | lines[j-1][0].set_data([keypoints[i, j, 0], keypoints[i, j_parent, 0]],
163 | [keypoints[i, j, 1], keypoints[i, j_parent, 1]])
164 |
165 | for n, ax in enumerate(ax_3d):
166 | pos = poses[n][i]
167 | lines_3d[n][j-1][0].set_xdata([pos[j, 0], pos[j_parent, 0]])
168 | lines_3d[n][j-1][0].set_ydata([pos[j, 1], pos[j_parent, 1]])
169 | lines_3d[n][j-1][0].set_3d_properties([pos[j, 2], pos[j_parent, 2]], zdir='z')
170 |
171 | points.set_offsets(keypoints[i])
172 |
173 | print('{}/{} '.format(i, limit), end='\r')
174 |
175 |
176 | fig.tight_layout()
177 |
178 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, limit), interval=1000/fps, repeat=False)
179 | if output.endswith('.mp4'):
180 | Writer = writers['ffmpeg']
181 | writer = Writer(fps=fps, metadata={}, bitrate=bitrate)
182 | anim.save(output, writer=writer)
183 | elif output.endswith('.gif'):
184 | anim.save(output, dpi=80, writer='imagemagick')
185 | else:
186 | raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')
187 | plt.close()
--------------------------------------------------------------------------------
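downsample_tensor above reduces the temporal resolution by averaging consecutive groups of frames; render_animation applies it to the keypoints, poses, and video frames (and divides the fps) when args.viz_downsample > 1. A minimal check of its behaviour:

    import numpy as np
    from common.visualization import downsample_tensor

    X = np.arange(10, dtype=np.float32).reshape(10, 1)   # 10 frames, 1 feature
    print(downsample_tensor(X, 2).ravel())               # averages pairs: [0.5 2.5 4.5 6.5 8.5]
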
/data/README.md:
--------------------------------------------------------------------------------
1 | Place the dataset files here (see the Dataset section of the main README for setup instructions).
2 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 |
10 | from common.arguments import parse_args
11 | import torch
12 |
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import os
17 | import sys
18 | import errno
19 |
20 | from common.camera import *
21 | from common.loss import *
22 | from common.generators import ChunkedGenerator, Evaluate_Generator
23 | from time import time
24 | from common.utils import deterministic_random
25 | from common.ranger import Ranger
26 | from torch.optim import lr_scheduler
27 |
28 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
30 |
31 | torch.backends.cudnn.benchmark = True
32 |
33 | args = parse_args()
34 | print(args)
35 |
36 | try:
37 | # Create checkpoint directory if it does not exist
38 | os.makedirs(args.checkpoint)
39 | except OSError as e:
40 | if e.errno != errno.EEXIST:
41 | raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint)
42 |
43 | if args.causal:
44 | from common.causal_model import *
45 | else:
46 | from common.model import *
47 | print('Loading dataset...')
48 | dataset_path = 'data/data_3d_' + args.dataset + '.npz'
49 | if args.dataset == 'h36m':
50 | from common.h36m_dataset import Human36mDataset
51 |
52 | dataset = Human36mDataset(dataset_path)
53 | elif args.dataset.startswith('humaneva'):
54 | from common.humaneva_dataset import HumanEvaDataset
55 |
56 | dataset = HumanEvaDataset(dataset_path)
57 | else:
58 | raise KeyError('Invalid dataset')
59 |
60 | print('Preparing data...')
61 | for subject in dataset.subjects():
62 | for action in dataset[subject].keys():
63 | anim = dataset[subject][action]
64 |
65 | positions_3d = []
66 | for cam in anim['cameras']:
67 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
68 | pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position
69 | positions_3d.append(pos_3d)
70 | anim['positions_3d'] = positions_3d
71 |
72 | print('Loading 2D detections...')
73 | keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True)
74 | keypoints_symmetry = keypoints['metadata'].item()['keypoints_symmetry']
75 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
76 | joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
77 | keypoints = keypoints['positions_2d'].item()
78 |
79 | for subject in dataset.subjects():
80 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
81 | for action in dataset[subject].keys():
82 | assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(
83 | action, subject)
84 | for cam_idx in range(len(keypoints[subject][action])):
85 |
86 | # We check for >= instead of == because some videos in H3.6M contain extra frames
87 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0]
88 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length
89 |
90 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
91 | # Shorten sequence
92 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]
93 |
94 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d'])
95 |
96 | for subject in keypoints.keys():
97 | for action in keypoints[subject]:
98 | for cam_idx, kps in enumerate(keypoints[subject][action]):
99 | # Normalize camera frame
100 | cam = dataset.cameras()[subject][cam_idx]
101 | kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
102 | keypoints[subject][action][cam_idx] = kps
103 |
104 | subjects_train = args.subjects_train.split(',')
105 | subjects_test = args.subjects_test.split(',')
106 |
107 |
108 | def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True):
109 | out_poses_3d = []
110 | out_poses_2d = []
111 | out_camera_params = []
112 | for subject in subjects:
113 | for action in keypoints[subject].keys():
114 | if action_filter is not None:
115 | found = False
116 | for a in action_filter:
117 | if action.startswith(a):
118 | found = True
119 | break
120 | if not found:
121 | continue
122 |
123 | poses_2d = keypoints[subject][action]
124 | for i in range(len(poses_2d)): # Iterate across cameras
125 | out_poses_2d.append(poses_2d[i])
126 |
127 | if subject in dataset.cameras():
128 | cams = dataset.cameras()[subject]
129 | assert len(cams) == len(poses_2d), 'Camera count mismatch'
130 | for cam in cams:
131 | if 'intrinsic' in cam:
132 | out_camera_params.append(cam['intrinsic'])
133 |
134 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]:
135 | poses_3d = dataset[subject][action]['positions_3d']
136 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
137 | for i in range(len(poses_3d)): # Iterate across cameras
138 | out_poses_3d.append(poses_3d[i])
139 |
140 | if len(out_camera_params) == 0:
141 | out_camera_params = None
142 | if len(out_poses_3d) == 0:
143 | out_poses_3d = None
144 |
145 | stride = args.downsample
146 | if subset < 1:
147 | for i in range(len(out_poses_2d)):
148 | n_frames = int(round(len(out_poses_2d[i]) // stride * subset) * stride)
149 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i])))
150 | out_poses_2d[i] = out_poses_2d[i][start:start + n_frames:stride]
151 | if out_poses_3d is not None:
152 | out_poses_3d[i] = out_poses_3d[i][start:start + n_frames:stride]
153 | elif stride > 1:
154 | # Downsample as requested
155 | for i in range(len(out_poses_2d)):
156 | out_poses_2d[i] = out_poses_2d[i][::stride]
157 | if out_poses_3d is not None:
158 | out_poses_3d[i] = out_poses_3d[i][::stride]
159 |
160 | return out_camera_params, out_poses_3d, out_poses_2d
161 |
162 |
163 | action_filter = None if args.actions == '*' else args.actions.split(',')
164 | if action_filter is not None:
165 | print('Selected actions:', action_filter)
166 |
167 | cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter)
168 |
169 | filter_widths = [int(x) for x in args.architecture.split(',')]
170 | if not args.disable_optimizations and not args.dense and args.stride == 1:
171 | # Use optimized model for single-frame predictions
172 | model_pos_train = TemporalModelOptimized1f(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1],
173 | poses_valid[0].shape[-2],
174 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout,
175 | channels=args.channels)
176 | else:
177 | # When incompatible settings are detected (stride > 1, dense filters, or disabled optimizations), fall back to the normal model
178 | model_pos_train = TemporalModel(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], poses_valid[0].shape[-2],
179 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout,
180 | channels=args.channels,
181 | dense=args.dense)
182 |
183 | model_pos = TemporalModelOptimized1f(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], poses_valid[0].shape[-2],
184 | filter_widths=filter_widths, causal=args.causal, dropout=args.dropout,
185 | channels=args.channels, dense=args.dense)
186 |
187 | receptive_field = model_pos.receptive_field()
188 | print('INFO: Receptive field: {} frames'.format(receptive_field))
189 | pad = (receptive_field - 1) // 2 # Padding on each side
190 | if args.causal:
191 | print('INFO: Using causal convolutions')
192 | causal_shift = pad
193 | else:
194 | causal_shift = 0
195 |
196 | model_params = 0
197 | for parameter in model_pos.parameters():
198 | model_params += parameter.numel()
199 | print('INFO: Trainable parameter count:', model_params)
200 |
201 | if torch.cuda.is_available():
202 | model_pos = model_pos.cuda()
203 | model_pos_train = model_pos_train.cuda()
204 |
205 | if args.resume or args.evaluate:
206 | chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
207 | print('Loading checkpoint', chk_filename)
208 | checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
209 | model_pos_train.load_state_dict(checkpoint['model_pos'])
210 | model_pos.load_state_dict(checkpoint['model_pos'])
211 |
212 | test_generator = ChunkedGenerator(args.batch_size // args.stride, cameras_valid, poses_valid, poses_valid_2d,
213 | args.stride,
214 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
215 | shuffle=False,
216 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
217 | joints_right=joints_right, noisy=False)
218 | print('INFO: Testing on {} sequences'.format(test_generator.num_frames()))
219 |
220 | if not args.evaluate:
221 | cameras_train, poses_train, poses_train_2d = fetch(subjects_train, action_filter, subset=args.subset)
222 |
223 | lr = args.learning_rate
224 | optimizer = Ranger(model_pos_train.parameters(), lr=lr)
225 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=args.epochs)
226 |
227 | lr_decay = args.lr_decay
228 |
229 | losses_3d_train = []
230 | losses_3d_train_eval = []
231 | losses_3d_valid = []
232 |
233 | epoch = 0
234 | initial_momentum = 0.1
235 | final_momentum = 0.001
236 |
237 | train_generator = ChunkedGenerator(args.batch_size // args.stride, cameras_train, poses_train, poses_train_2d,
238 | args.stride,
239 | pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation,
240 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
241 | joints_right=joints_right)
242 | train_generator_eval = ChunkedGenerator(args.batch_size // args.stride, cameras_train, poses_train, poses_train_2d,
243 | args.stride,
244 | pad=pad, causal_shift=causal_shift, augment=False, shuffle=True,
245 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
246 | joints_right=joints_right)
247 | print('INFO: Supervision Training on {} frames'.format(train_generator.num_frames()))
248 |
249 | if args.resume:
250 | epoch = checkpoint['epoch']
251 | if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
252 | optimizer.load_state_dict(checkpoint['optimizer'])
253 | train_generator.set_random_state(checkpoint['random_state'])
254 | else:
255 | print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
256 |
257 | lr = checkpoint['lr']
258 |
259 | print('** Note: reported losses are averaged over all frames and test-time augmentation is not used here.')
260 | print('** The final evaluation will be carried out after the last training epoch.')
261 |
262 | # Pos model only
263 | while epoch < args.epochs:
264 | start_time = time()
265 | epoch_loss_3d_train = 0
266 | epoch_loss_traj_train = 0
267 | epoch_loss_2d_train_unlabeled = 0
268 | N = 0
269 | N_semi = 0
270 | model_pos_train.train()
271 |
272 | for _, batch_3d, batch_2d in train_generator.next_epoch():
273 | inputs_3d = torch.from_numpy(batch_3d.astype('float32'))
274 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
275 | if torch.cuda.is_available():
276 | inputs_3d = inputs_3d.cuda()
277 | inputs_2d = inputs_2d.cuda()
278 | inputs_3d[:, :, 0] = 0
279 |
280 | optimizer.zero_grad()
281 |
282 | # Predict 3D poses
283 | predicted_3d_pos = model_pos_train(inputs_2d)
284 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
285 | epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
286 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
287 |
288 | loss_total = loss_3d_pos
289 | loss_total.backward()
290 |
291 | optimizer.step()
292 |
293 | losses_3d_train.append(epoch_loss_3d_train / N)
294 |
295 | # End-of-epoch evaluation
296 | with torch.no_grad():
297 | model_pos.load_state_dict(model_pos_train.state_dict())
298 | model_pos.eval()
299 |
300 | epoch_loss_3d_valid = 0
301 | epoch_loss_traj_valid = 0
302 | epoch_loss_2d_valid = 0
303 | N = 0
304 |
305 | if not args.no_eval:
306 | # Evaluate on test set
307 | for cam, batch, batch_2d in test_generator.next_epoch():
308 | inputs_3d = torch.from_numpy(batch.astype('float32'))
309 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
310 | if torch.cuda.is_available():
311 | inputs_3d = inputs_3d.cuda()
312 | inputs_2d = inputs_2d.cuda()
313 | inputs_traj = inputs_3d[:, :, :1].clone()
314 | inputs_3d[:, :, 0] = 0
315 |
316 | # Predict 3D poses
317 | predicted_3d_pos = model_pos(inputs_2d)
318 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
319 | epoch_loss_3d_valid += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
320 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
321 |
322 | losses_3d_valid.append(epoch_loss_3d_valid / N)
323 |
324 | # Evaluate on training set, this time in evaluation mode
325 | epoch_loss_3d_train_eval = 0
326 | epoch_loss_traj_train_eval = 0
327 | epoch_loss_2d_train_labeled_eval = 0
328 | N = 0
329 | for cam, batch, batch_2d in train_generator_eval.next_epoch():
330 | if batch_2d.shape[1] == 0:
331 | # This can only happen when downsampling the dataset
332 | continue
333 |
334 | inputs_3d = torch.from_numpy(batch.astype('float32'))
335 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
336 | if torch.cuda.is_available():
337 | inputs_3d = inputs_3d.cuda()
338 | inputs_2d = inputs_2d.cuda()
339 | inputs_traj = inputs_3d[:, :, :1].clone()
340 | inputs_3d[:, :, 0] = 0
341 |
342 | # Compute 3D poses
343 | predicted_3d_pos = model_pos(inputs_2d)
344 | loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)
345 | epoch_loss_3d_train_eval += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
346 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
347 |
348 | losses_3d_train_eval.append(epoch_loss_3d_train_eval / N)
349 |
350 | # Evaluate 2D loss on unlabeled training set (in evaluation mode)
351 | epoch_loss_2d_train_unlabeled_eval = 0
352 | N_semi = 0
353 |
354 | elapsed = (time() - start_time) / 60
355 |
356 | if args.no_eval:
357 | print('[%d] time %.2f lr %f 3d_train %f' % (
358 | epoch + 1,
359 | elapsed,
360 | lr,
361 | losses_3d_train[-1] * 1000))
362 | else:
363 | print('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % (
364 | epoch + 1,
365 | elapsed,
366 | lr,
367 | losses_3d_train[-1] * 1000,
368 | losses_3d_train_eval[-1] * 1000,
369 | losses_3d_valid[-1] * 1000))
370 |
371 | # cosine annealing
372 | scheduler.step()
373 | lr = scheduler.get_lr()[0]
374 | for param_group in optimizer.param_groups:
375 | param_group['lr'] = lr
376 |
377 | epoch += 1
378 | momentum = initial_momentum * np.exp(-epoch / args.epochs * np.log(initial_momentum / final_momentum))
379 | model_pos_train.set_bn_momentum(momentum)
380 | model_pos_train.set_KA_bn(momentum)
381 | model_pos_train.set_expand_bn(momentum)
382 | model_pos_train.set_dilation_bn(momentum)
383 |
384 | # Save checkpoint if necessary
385 | if epoch % args.checkpoint_frequency == 0:
386 | check_point_name = 'supervised'
387 |
388 | chk_path = os.path.join(args.checkpoint, str(args.channels) + '_' + str(args.keypoints) +
389 | '_' + str(receptive_field) + '_' + check_point_name + '_epoch_{}.bin'.format(
390 | epoch))
391 | print('Saving checkpoint to', chk_path)
392 |
393 | torch.save({
394 | 'epoch': epoch,
395 | 'lr': lr,
396 | 'random_state': train_generator.random_state(),
397 | 'optimizer': optimizer.state_dict(),
398 | 'model_pos': model_pos_train.state_dict(),
399 | }, chk_path)
400 |
401 | # Save training curves after every epoch, as .png images (if requested)
402 | if args.export_training_curves and epoch > 3:
403 | if 'matplotlib' not in sys.modules:
404 | import matplotlib
405 |
406 | matplotlib.use('Agg')
407 | import matplotlib.pyplot as plt
408 |
409 | plt.figure()
410 | epoch_x = np.arange(3, len(losses_3d_train)) + 1
411 | plt.plot(epoch_x, losses_3d_train[3:], '--', color='C0')
412 | plt.plot(epoch_x, losses_3d_train_eval[3:], color='C0')
413 | plt.plot(epoch_x, losses_3d_valid[3:], color='C1')
414 | plt.legend(['3d train', '3d train (eval)', '3d valid (eval)'])
415 | plt.ylabel('MPJPE (m)')
416 | plt.xlabel('Epoch')
417 | plt.xlim((3, epoch))
418 | plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png'))
419 |
420 | plt.close('all')
421 |
422 |
423 | # Evaluate
424 | def evaluate(test_generator, action=None, return_predictions=False):
425 | epoch_loss_3d_pos = 0
426 | epoch_loss_3d_pos_procrustes = 0
427 |
428 | with torch.no_grad():
429 | model_pos.eval()
430 | N = 0
431 |
432 | # Test-time augmentation (if enabled)
433 | if args.test_time_augmentation:
434 | for _, batch, batch_2d, batch_2d_flip in test_generator.next_epoch():
435 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
436 | inputs_2d_flip = torch.from_numpy(batch_2d_flip.astype('float32'))
437 | if torch.cuda.is_available():
438 | inputs_2d = inputs_2d.cuda()
439 | inputs_2d_flip = inputs_2d_flip.cuda()
440 |
441 | # Positional model
442 | predicted_3d_pos = model_pos(inputs_2d)
443 | predicted_3d_pos_flip = model_pos(inputs_2d_flip)
444 | predicted_3d_pos_flip[:, :, :, 0] *= -1
445 | predicted_3d_pos_flip[:, :, joints_left + joints_right] = predicted_3d_pos_flip[:, :,
446 | joints_right + joints_left]
447 |
448 | predicted_3d_pos = torch.mean(torch.cat((predicted_3d_pos, predicted_3d_pos_flip), dim=1), dim=1,
449 | keepdim=True)
450 |
451 | if return_predictions:
452 | return predicted_3d_pos.squeeze().cpu().numpy()
453 |
454 | inputs_3d = torch.from_numpy(batch.astype('float32'))
455 | if torch.cuda.is_available():
456 | inputs_3d = inputs_3d.cuda()
457 | inputs_3d[:, :, 0] = 0
458 |
459 | error = mpjpe(predicted_3d_pos, inputs_3d)
460 |
461 | epoch_loss_3d_pos += inputs_3d.shape[0] * inputs_3d.shape[1] * error.item()
462 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
463 |
464 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
465 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
466 |
467 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0] * inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos,
468 | inputs)
469 |
470 | else:
471 | for _, batch, batch_2d in test_generator.next_epoch():
472 | inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
473 | if torch.cuda.is_available():
474 | inputs_2d = inputs_2d.cuda()
475 |
476 | # Positional model
477 | predicted_3d_pos = model_pos(inputs_2d)
478 |
479 | if return_predictions:
480 | return predicted_3d_pos.squeeze().cpu().numpy()
481 |
482 | inputs_3d = torch.from_numpy(batch.astype('float32'))
483 | if torch.cuda.is_available():
484 | inputs_3d = inputs_3d.cuda()
485 | inputs_3d[:, :, 0] = 0
486 |
487 | error = mpjpe(predicted_3d_pos, inputs_3d)
488 |
489 | epoch_loss_3d_pos += inputs_3d.shape[0] * inputs_3d.shape[1] * error.item()
490 | N += inputs_3d.shape[0] * inputs_3d.shape[1]
491 |
492 | inputs = inputs_3d.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
493 | predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(-1, inputs_3d.shape[-2], inputs_3d.shape[-1])
494 |
495 | epoch_loss_3d_pos_procrustes += inputs_3d.shape[0] * inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos,
496 | inputs)
497 | if action is None:
498 | print('----------')
499 | else:
500 | print('----' + action + '----')
501 | e1 = (epoch_loss_3d_pos / N) * 1000
502 | e2 = (epoch_loss_3d_pos_procrustes / N) * 1000
503 |
504 | print('Test time augmentation:', test_generator.augment_enabled())
505 | print('Protocol #1 Error (MPJPE):', e1, 'mm')
506 | print('Protocol #2 Error (P-MPJPE):', e2, 'mm')
507 | print('----------')
508 |
509 | return e1, e2
510 |
511 |
512 | if args.render:
513 | print('Rendering...')
514 |
515 | input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy()
516 | if args.viz_subject in dataset.subjects() and args.viz_action in dataset[args.viz_subject]:
517 | ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy()
518 | else:
519 | ground_truth = None
520 | print('INFO: this action is unlabeled. Ground truth will not be rendered.')
521 |
522 | gen = Evaluate_Generator(1, None, None, [input_keypoints], args.stride,
523 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
524 | shuffle=False,
525 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
526 | joints_right=joints_right)
527 | prediction = evaluate(gen, return_predictions=True)
528 |
529 | if ground_truth is not None:
530 | # Reapply trajectory
531 | trajectory = ground_truth[:, :1]
532 | ground_truth[:, 1:] += trajectory
533 | prediction += trajectory
534 |
535 | # Invert camera transformation
536 | cam = dataset.cameras()[args.viz_subject][args.viz_camera]
537 | if ground_truth is not None:
538 | prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation'])
539 | ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation'])
540 | else:
541 | # If the ground truth is not available, take the camera extrinsic params from a random subject.
542 | # They are almost the same, and anyway, we only need this for visualization purposes.
543 | for subject in dataset.cameras():
544 | if 'orientation' in dataset.cameras()[subject][args.viz_camera]:
545 | rot = dataset.cameras()[subject][args.viz_camera]['orientation']
546 | break
547 | prediction = camera_to_world(prediction, R=rot, t=0)
548 | # We don't have the trajectory, but at least we can rebase the height
549 | prediction[:, :, 2] -= np.min(prediction[:, :, 2])
550 |
551 | anim_output = {'Reconstruction': prediction}
552 | if ground_truth is not None and not args.viz_no_ground_truth:
553 | anim_output['Ground truth'] = ground_truth
554 |
555 | input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h'])
556 |
557 | from common.visualization import render_animation
558 |
559 | render_animation(input_keypoints, anim_output,
560 | dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output,
561 | limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size,
562 | input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']),
563 | input_video_skip=args.viz_skip)
564 |
565 | else:
566 | print('Evaluating...')
567 | all_actions = {}
568 | all_actions_by_subject = {}
569 | for subject in subjects_test:
570 | if subject not in all_actions_by_subject:
571 | all_actions_by_subject[subject] = {}
572 |
573 | ordered_actions = dataset.define_actions()
574 | for ordered_action in ordered_actions:
575 | for action in dataset[subject].keys():
576 | action_name = action.split(' ')[0]
577 | if action_name == ordered_action:
578 | if action_name not in all_actions:
579 | all_actions[action_name] = []
580 | if action_name not in all_actions_by_subject[subject]:
581 | all_actions_by_subject[subject][action_name] = []
582 | all_actions[action_name].append((subject, action))
583 | all_actions_by_subject[subject][action_name].append((subject, action))
584 | else:
585 | continue
586 |
587 |
588 | def fetch_actions(actions):
589 | out_poses_3d = []
590 | out_poses_2d = []
591 |
592 | for subject, action in actions:
593 | poses_2d = keypoints[subject][action]
594 | for i in range(len(poses_2d)): # Iterate across cameras
595 | out_poses_2d.append(poses_2d[i])
596 |
597 | poses_3d = dataset[subject][action]['positions_3d']
598 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
599 | for i in range(len(poses_3d)): # Iterate across cameras
600 | out_poses_3d.append(poses_3d[i])
601 |
602 | stride = args.downsample
603 | if stride > 1:
604 | # Downsample as requested
605 | for i in range(len(out_poses_2d)):
606 | out_poses_2d[i] = out_poses_2d[i][::stride]
607 | if out_poses_3d is not None:
608 | out_poses_3d[i] = out_poses_3d[i][::stride]
609 |
610 | return out_poses_3d, out_poses_2d
611 |
612 |
613 | def run_evaluation(actions, action_filter=None):
614 | errors_p1 = []
615 | errors_p2 = []
616 |
617 | for action_key in actions.keys():
618 | if action_filter is not None:
619 | found = False
620 | for a in action_filter:
621 | if action_key.startswith(a):
622 | found = True
623 | break
624 | if not found:
625 | continue
626 |
627 | poses_act, poses_2d_act = fetch_actions(actions[action_key])
628 | gen = Evaluate_Generator(1, None, poses_act, poses_2d_act, args.stride,
629 | pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
630 | shuffle=False,
631 | kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
632 | joints_right=joints_right)
633 | e1, e2 = evaluate(gen, action_key)
634 | errors_p1.append(e1)
635 | errors_p2.append(e2)
636 |
637 | print('Protocol #1 (MPJPE) action-wise average:', round(np.mean(errors_p1), 1), 'mm')
638 | print('Protocol #2 (P-MPJPE) action-wise average:', round(np.mean(errors_p2), 1), 'mm')
639 |
640 | if not args.by_subject:
641 | run_evaluation(all_actions, action_filter)
642 | else:
643 | for subject in all_actions_by_subject.keys():
644 | print('Evaluating on subject', subject)
645 | run_evaluation(all_actions_by_subject[subject], action_filter)
646 | print('')
647 |
648 |
--------------------------------------------------------------------------------
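For reference, the test-time augmentation used in evaluate() above predicts 3D poses for both the original and the horizontally flipped 2D input, mirrors the flipped prediction back (negating x and swapping left/right joints), and averages the two. A condensed sketch of that step, assuming predictions shaped [batch, frames, joints, 3] and the joints_left / joints_right lists from the dataset skeleton:

    import torch

    def flip_augmented_prediction(model_pos, inputs_2d, inputs_2d_flip, joints_left, joints_right):
        # Predict on the original and on the mirrored 2D keypoints.
        pred = model_pos(inputs_2d)
        pred_flip = model_pos(inputs_2d_flip)
        # Undo the mirroring on the 3D output: negate x and swap left/right joints.
        pred_flip[:, :, :, 0] *= -1
        pred_flip[:, :, joints_left + joints_right] = pred_flip[:, :, joints_right + joints_left]
        # Average the two predictions (concatenated along the frame dimension).
        return torch.mean(torch.cat((pred, pred_flip), dim=1), dim=1, keepdim=True)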