├── common ├── dataset │ ├── __init__.py │ ├── pre_process │ │ ├── __init__.py │ │ ├── mpi-inf.py │ │ ├── get_mpi_inf.py │ │ ├── norm_data.py │ │ ├── utils.py │ │ ├── hm36.py │ │ └── get_3dpw.py │ ├── post_process │ │ ├── __init__.py │ │ └── process3d.py │ ├── mocap_dataset.py │ ├── skeleton.py │ └── h36m_dataset.py ├── common_pytorch │ ├── __init__.py │ ├── loss │ │ ├── __init__.py │ │ └── loss_family.py │ ├── model │ │ ├── __init__.py │ │ ├── srnet_utils │ │ │ ├── __init__.py │ │ │ └── group_index.py │ │ └── fc_baseline.py │ ├── experiment │ │ ├── __init__.py │ │ ├── .inference.py.swp │ │ ├── tools.py │ │ ├── train.py │ │ └── eval_metrics.py │ └── utils.py ├── transformation │ ├── __init__.py │ ├── quaternion.py │ ├── kpt_trans.py │ ├── aug_rotate.py │ └── cam_utils.py ├── visualization │ ├── __init__.py │ ├── plot_pose2d.py │ ├── plot_log_epoch.py │ ├── plot_log_kpt.py │ └── plot_pose3d.py └── arguments │ └── basic_args.py ├── img ├── framework.png ├── comparison.png └── observation.png ├── config ├── single_l6_gp5_mul_p1.sh ├── single_l8_gp5_add_p1.sh ├── single_l8_gp5_concat_p1.sh ├── temporal_f243_gp5_mul_p1.sh ├── single_eval_l8_gp5_add.sh ├── single_eval_l6_gp5_mul.sh ├── temporal_eval_f243_gp5_mul_p1.sh └── single_l6_gp5_mul_crossaction_norm.sh ├── run_os.py ├── data ├── convert_cdf_to_mat.m ├── prepare_data_2d_h36m_generic.py ├── data_utils.py ├── prepare_data_2d_custom.py ├── prepare_data_2d_h36m_sh.py ├── prepare_data_h36m.py ├── ConvertHumanEva.m └── prepare_data_humaneva.py ├── .gitignore ├── README.md └── LICENSE /common/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/common_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/common_pytorch/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/dataset/pre_process/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/common_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/dataset/post_process/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/common_pytorch/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/common_pytorch/model/srnet_utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ailingzengzzz/Split-and-Recombine-Net/HEAD/img/framework.png -------------------------------------------------------------------------------- /img/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ailingzengzzz/Split-and-Recombine-Net/HEAD/img/comparison.png -------------------------------------------------------------------------------- /img/observation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ailingzengzzz/Split-and-Recombine-Net/HEAD/img/observation.png -------------------------------------------------------------------------------- /config/single_l6_gp5_mul_p1.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 1,1,1 --group 5 --recombine multiply > log/exp1.log 2>&1& -------------------------------------------------------------------------------- /config/single_l8_gp5_add_p1.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 1,1,1,1 --group 5 --recombine add > log/exp2.log 2>&1& 2 | -------------------------------------------------------------------------------- /config/single_l8_gp5_concat_p1.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 1,1,1,1 --group 5 --recombine concat > log/exp3.log 2>&1& 2 | -------------------------------------------------------------------------------- /config/temporal_f243_gp5_mul_p1.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 3,3,3,3,3 --group 5 --recombine multiply > log/train_t243.log 2>&1& 2 | -------------------------------------------------------------------------------- /config/single_eval_l8_gp5_add.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 1,1,1,1 --evaluate srnet_gp5_T1_add.bin --group 5 --recombine add > log/eval_t1_add.log 2>&1& 2 | -------------------------------------------------------------------------------- /config/single_eval_l6_gp5_mul.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 1,1,1 --evaluate srnet_gp5_T1_mul.bin --group 5 --recombine multiply > log/eval_t1_mul.log 2>&1& 2 | -------------------------------------------------------------------------------- /common/common_pytorch/experiment/.inference.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ailingzengzzz/Split-and-Recombine-Net/HEAD/common/common_pytorch/experiment/.inference.py.swp -------------------------------------------------------------------------------- /config/temporal_eval_f243_gp5_mul_p1.sh: -------------------------------------------------------------------------------- 1 | nohup python -u run.py -arc 3,3,3,3,3 --evaluate srnet_gp5_t243_mul.bin --group 5 --recombine multiply > log/eval_t243.log 2>&1& 2 | -------------------------------------------------------------------------------- /config/single_l6_gp5_mul_crossaction_norm.sh: -------------------------------------------------------------------------------- 1 
| nohup python -u run.py -arc 1,1,1 --group 5 --recombine multiply --use-action-split True --train-action Discussion --norm lcn > log/exp4.log 2>&1& 2 | -------------------------------------------------------------------------------- /run_os.py: -------------------------------------------------------------------------------- 1 | from subprocess import call 2 | import sys 3 | action = ['Greeting','Sitting','SittingDown','WalkTogether','Phoning','Posing','WalkDog','Walking','Purchases','Waiting','Directions','Smoking','Photo','Eating','Discussion'] 4 | group = [1, 2, 3, 5] 5 | 6 | for i in action: 7 | cmd = """nohup python -u run.py --model srnet -arc 1,1,1 --use-action-split True --train-action {} -mn sr_t1_crossaction_act{} > log/sr_t1_crossaction_act{}.log 2>&1&""".format(i,i,i) 8 | print(cmd) 9 | call(cmd, shell=True) 10 | print('Finish!') 11 | 12 | -------------------------------------------------------------------------------- /data/convert_cdf_to_mat.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2018-present, Facebook, Inc. 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | % Extract "Poses_D3_Positions_S*.tgz" to the "pose" directory 9 | % and run this script to convert all .cdf files to .mat 10 | 11 | pose_directory = 'pose'; 12 | dirs = dir(strcat(pose_directory, '/*/MyPoseFeatures/D3_Positions/*.cdf')); 13 | 14 | paths = {dirs.folder}; 15 | names = {dirs.name}; 16 | 17 | for i = 1:numel(names) 18 | data = cdfread(strcat(paths{i}, '/', names{i})); 19 | save(strcat(paths{i}, '/', names{i}, '.mat'), 'data'); 20 | end -------------------------------------------------------------------------------- /common/common_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def wrap(func, *args, unsqueeze=False): 6 | """ 7 | Wrap a torch function so it can be called with NumPy arrays. 8 | Input and return types are seamlessly converted. 9 | """ 10 | 11 | # Convert input types where applicable 12 | args = list(args) 13 | for i, arg in enumerate(args): 14 | if type(arg) == np.ndarray: 15 | args[i] = torch.from_numpy(arg) 16 | if unsqueeze: 17 | args[i] = args[i].unsqueeze(0) 18 | 19 | result = func(*args) 20 | 21 | # Convert output types where applicable 22 | if isinstance(result, tuple): 23 | result = list(result) 24 | for i, res in enumerate(result): 25 | if type(res) == torch.Tensor: 26 | if unsqueeze: 27 | res = res.squeeze(0) 28 | result[i] = res.numpy() 29 | return tuple(result) 30 | elif type(result) == torch.Tensor: 31 | if unsqueeze: 32 | result = result.squeeze(0) 33 | return result.numpy() 34 | else: 35 | return result -------------------------------------------------------------------------------- /common/transformation/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | def qrot(q, v): 11 | """ 12 | Rotate vector(s) v about the rotation described by quaternion(s) q. 13 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 14 | where * denotes any number of dimensions. 
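Assuming q is a unit quaternion (w, x, y, z), the rotation below uses the cross-product identity v' = v + 2 * (w * (q_vec x v) + q_vec x (q_vec x v)). As a quick sanity check with hypothetical values: rotating v = (1, 0, 0) by q = (0.7071, 0, 0, 0.7071), a 90-degree rotation about the z-axis, yields approximately (0, 1, 0).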
15 | Returns a tensor of shape (*, 3). 16 | """ 17 | assert q.shape[-1] == 4 18 | assert v.shape[-1] == 3 19 | assert q.shape[:-1] == v.shape[:-1] 20 | 21 | qvec = q[..., 1:] 22 | uv = torch.cross(qvec, v, dim=len(q.shape) - 1) 23 | uuv = torch.cross(qvec, uv, dim=len(q.shape) - 1) 24 | return (v + 2 * (q[..., :1] * uv + uuv)) 25 | 26 | 27 | def qinverse(q, inplace=False): 28 | # We assume the quaternion to be normalized 29 | if inplace: 30 | q[..., 1:] *= -1 31 | return q 32 | else: 33 | w = q[..., :1] 34 | xyz = q[..., 1:] 35 | return torch.cat((w, -xyz), dim=len(q.shape) - 1) -------------------------------------------------------------------------------- /common/transformation/kpt_trans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def kpt_to_bone_vector(pose_3d, parent_index=None): 5 | if parent_index is not None: 6 | hm36_parent = parent_index 7 | 8 | else: 9 | hm36_parent = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15] #by body kinematic connections 10 | #print('random parent index:',hm36_parent) 11 | bone = [] 12 | for i in range(1, len(hm36_parent)): 13 | bone_3d = pose_3d[:, :, i] - pose_3d[:,:,hm36_parent[i]] 14 | bone.append(bone_3d.unsqueeze(dim=-2)) 15 | bone_out = torch.cat(bone, dim=-2) 16 | return bone_out 17 | 18 | def two_order_bone_vector(bone_3d, parent_index=None): 19 | if parent_index is not None: 20 | hm36_parent = parent_index 21 | else: 22 | hm36_parent = [-1, 0, 1, 0, 3, 4, 0, 6, 7, 8, 7, 10, 11, 7, 13, 14] #by body kinematic connections, same to calculate angles 23 | #print('random parent index:',hm36_parent) 24 | bone = [] 25 | for i in range(1, len(hm36_parent)): 26 | bone_3d_2 = bone_3d[:, :, i] - bone_3d[:,:,hm36_parent[i]] 27 | bone.append(bone_3d_2.unsqueeze(dim=-2)) 28 | bone_out = torch.cat(bone, dim=-2) 29 | return bone_out -------------------------------------------------------------------------------- /common/dataset/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | from common.dataset.skeleton import Skeleton 10 | 11 | 12 | class MocapDataset: 13 | def __init__(self, fps, skeleton): 14 | self._skeleton = skeleton 15 | self._fps = fps 16 | self._data = None # Must be filled by subclass 17 | self._cameras = None # Must be filled by subclass 18 | 19 | def remove_joints(self, joints_to_remove): 20 | kept_joints = self._skeleton.remove_joints(joints_to_remove) 21 | for subject in self._data.keys(): 22 | for action in self._data[subject].keys(): 23 | s = self._data[subject][action] 24 | if 'positions' in s: 25 | s['positions'] = s['positions'][:, kept_joints] 26 | 27 | def __getitem__(self, key): 28 | return self._data[key] 29 | 30 | def subjects(self): 31 | return self._data.keys() 32 | 33 | def fps(self): 34 | return self._fps 35 | 36 | def skeleton(self): 37 | return self._skeleton 38 | 39 | def cameras(self): 40 | return self._cameras 41 | 42 | def supports_semi_supervised(self): 43 | # This method can be overridden 44 | return False -------------------------------------------------------------------------------- /common/common_pytorch/model/srnet_utils/group_index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_input(group): 4 | if group == 2: 5 | print('Now group is:', group) 6 | conv_seq = [range(0, 16), [0, 1, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]] 7 | final_outc = 55 8 | elif group == 3: 9 | print('Now group is:', group) 10 | conv_seq = [range(0, 14), [0, 1, 14, 15, 16, 17, 18, 19, 20, 21], 11 | [0, 1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]] 12 | final_outc = 58 13 | elif group == 5: 14 | print('Now group is:', group) 15 | conv_seq = [range(0, 8), [0, 1, 8, 9, 10, 11, 12, 13], [0, 1, 14, 15, 16, 17, 18, 19, 20, 21], 16 | [0, 1, 22, 23, 24, 25, 26, 27], [0, 1, 28, 29, 30, 31, 32, 33]] 17 | final_outc = 64 18 | elif group == 1: 19 | print('Now group is:', group) 20 | conv_seq = [range(0, 34)] 21 | final_outc = 51 22 | else: 23 | raise KeyError('Invalid group number!') 24 | 25 | return conv_seq, final_outc 26 | 27 | # # 28 | def shrink_output(x): 29 | num_joints_out = x.shape[-1] 30 | pose_dim = 3 # means [X,Y,Z]: three values 31 | if num_joints_out == 1: 32 | x = x[:, :, :pose_dim] 33 | elif num_joints_out == 64: #Group = 5 34 | x = torch.cat([x[:, :, :(4*pose_dim)], x[:, :, (5*pose_dim):(8*pose_dim)], x[:, :, (9*pose_dim):(13*pose_dim)], x[:, :, (14*pose_dim):(17*pose_dim)], x[:, :, (18*pose_dim):(21*pose_dim)]], dim=-1) 35 | 36 | elif num_joints_out == 58: #Group = 3 37 | x = torch.cat([x[:, :, :(7*pose_dim)], x[:, :, (8*pose_dim):(12*pose_dim)], x[:, :, (13*pose_dim):(19*pose_dim)]], dim=-1) 38 | 39 | elif num_joints_out == 55: #Group = 2 40 | x = torch.cat([x[:, :, :(8*pose_dim)], x[:, :, (9*pose_dim):(18*pose_dim)]], dim=-1) 41 | 42 | elif num_joints_out == 52: #Group = 1 43 | x = x[:, :, :(17*pose_dim)] 44 | else: 45 | raise KeyError('Invalid outputs!') 46 | return x -------------------------------------------------------------------------------- /common/dataset/post_process/process3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.transformation.cam_utils import image_coordinates, reprojection 3 | 4 | # For inverse inference to get final 3d XYZ 5 | def get_final_3d_coord(pos_3d_out, abs_root_3d, relaive_root_3d, camera, rescale_bbox_ratio, pixel_depth_ratio,norm): 6 | if norm=='lcn': # Another way to process pixel XY -> 
normalize xy by image height and width. Same way with LCN 7 | img_w, img_h = int(camera[0,2]*2), int(camera[0,3]*2) 8 | pixel_pose_3d = image_coordinates(pos_3d_out, img_w, img_h) 9 | pos_3d_stage3 = torch.zeros_like(pos_3d_out) 10 | pos_3d_stage3[...,2:3] = pixel_pose_3d[...,2:3]/pixel_depth_ratio 11 | pos_3d_stage3[...,:2] = pixel_pose_3d[...,:2] 12 | 13 | else: 14 | pose_relative = torch.zeros_like(pos_3d_out) 15 | pose_relative[..., :2] = relaive_root_3d 16 | pos_3d_stage1 = pos_3d_out / rescale_bbox_ratio # To recover xyz bbox scale 17 | pos_3d_stage2 = pos_3d_stage1 + pose_relative # (2000,1,17,3) #To recover xy first. 18 | 19 | pos_3d_stage3 = torch.zeros_like(pos_3d_out) 20 | pos_3d_stage3[:, :, :, 2:3] = pos_3d_stage2[:, :, :, 2:3] / pixel_depth_ratio 21 | pos_3d_stage3[..., :2] = pos_3d_stage2[..., :2].clone() 22 | 23 | abs_depth_z = pos_3d_stage3[..., 2:3].clone() 24 | abs_depth = abs_depth_z + abs_root_3d 25 | # Reprojection to get 3d X,Y 26 | reproject_3d = reprojection(pos_3d_stage3, abs_depth, camera) 27 | final_3d = reproject_3d - reproject_3d[:, :, :1] 28 | return final_3d/1000 #Use meters 29 | 30 | def post_process3d(predicted_3d, inputs3d, cam, normalize_param, norm): 31 | inputs_3d_depth = normalize_param[..., 4:5] 32 | inputs_3d_relative_xy = normalize_param[..., 2:4] 33 | rescale_bbox_ratio, pixel_depth_ratio = normalize_param[..., 1:2], normalize_param[..., 0:1] 34 | predicted_3d_pos = get_final_3d_coord(predicted_3d, inputs_3d_depth, inputs_3d_relative_xy, cam, rescale_bbox_ratio, 35 | pixel_depth_ratio,norm) 36 | inputs_3d = get_final_3d_coord(inputs3d, inputs_3d_depth, inputs_3d_relative_xy, cam, rescale_bbox_ratio, 37 | pixel_depth_ratio,norm) 38 | return predicted_3d_pos, inputs_3d 39 | -------------------------------------------------------------------------------- /common/visualization/plot_pose2d.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | # (R,G,B) 5 | color1 = [(28,26,228),(28,26,228),(28,26,228), 6 | (74,175,77), (74,175,77),(74,175,77), 7 | (153,255,255),(153,255,255),(153,255,255),(153,255,255), 8 | (163,78,152),(163,78,152),(163,78,152), 9 | (0,127,255),(0,127,255),(0,127,255)] 10 | 11 | link_pairs1 = [ 12 | [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 13 | [5, 6], [0, 7],[7, 8], [8, 9], [9, 10], 14 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 15 | ] 16 | 17 | point_color1 = [(0,0,0), (0,0,255), (0,0,255), (0,0,255), 18 | (0,255,0),(0,255,0),(0,255,0), 19 | (138,41,231),(138,41,231),(138,41,231),(138,41,231), 20 | (179,112,117),(179,112,117),(179,112,117), 21 | (2,95,217),(2,95,217),(2,95,217)] 22 | 23 | class ColorStyle: 24 | def __init__(self, color, link_pairs, point_color): 25 | self.color = color 26 | self.link_pairs = link_pairs 27 | self.point_color = point_color 28 | self.line_color = [] 29 | for i in range(len(self.color)): 30 | self.line_color.append(self.color[i]) 31 | 32 | self.ring_color = [] 33 | for i in range(len(self.point_color)): 34 | self.ring_color.append(self.point_color[i]) 35 | 36 | def show_2d_hm36_pose(img_path, pose_2d, index=0): 37 | # plot single pose from a image 38 | colorstyle = ColorStyle(color1, link_pairs1, point_color1) 39 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 40 | [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], 41 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 42 | if img_path is None: 43 | img = np.zeros((1000, 1000,3), dtype=np.uint8) 44 | else: 45 | img = cv2.imread(img_path) 46 | 47 | kps = 
pose_2d # 2d pose in pixel unit, shape [17, 2] 48 | for j, c in enumerate(connections): 49 | start = kps[c[0]] 50 | end = kps[c[1]] 51 | cv2.line(img, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), colorstyle.line_color[j], 3) 52 | cv2.circle(img, (int(kps[j, 0]), int(kps[j, 1])), 4, colorstyle.ring_color[j], 2) 53 | cv2.imshow('3DPW Example', img) 54 | #cv2.imwrite('data/3dpw/validation/{}_{}_{:05d}.jpg'.format(seq, p_id, index), img) 55 | cv2.waitKey(0) 56 | cv2.destroyAllWindows() 57 | -------------------------------------------------------------------------------- /common/dataset/pre_process/mpi-inf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from common.dataset.pre_process.norm_data import norm_to_pixel 3 | 4 | def load_mpi_test(file_path, seq, baseline_normalize): 5 | """ 6 | Usage: Load a section once 7 | :param dataset_root: root path 8 | :param section: There are six sequences in this (seq=0,1,2,3,4,5). And 2935 poses in a unique set(seq==7). 9 | If you want to evaluate by scene setting, you can use the sequencewise evaluation 10 | to convert to these numbers by doing 11 | #1:Studio with Green Screen (TS1*603 + TS2 *540)/ (603+540) 12 | #2:Studio without Green Screen (TS3*505+TS4*553)/(505+553) 13 | #3:Outdoor (TS5*276+TS6*452)/(276+452) 14 | :return: Normalized 2d/3d pose, normalization params and camera intrinics. All types: List 15 | """ 16 | info = np.load(file_path, allow_pickle=True) 17 | if seq in range(0,6): 18 | pose_3d = info['pose3d_univ'][seq] 19 | pose_2d = info['pose2d'][seq] 20 | if seq in [0, 1, 2, 3]: 21 | img_w, img_h = 2048, 2048 22 | cam_intri = np.array([1500.0686135995716, 1500.6590966853348, 1017.3794860438494, 1043.062824876024, 1,1,1,1,1]) 23 | elif seq in [4, 5]: 24 | img_w, img_h = 1920, 1080 25 | cam_intri = np.array([1683.482559482185, 1671.927242063379, 939.9278168524228, 560.2072491988034, 1,1,1,1,1]) 26 | 27 | elif seq == 7: 28 | pose_3d = info['pose3d_univ'][0] 29 | pose_2d = info['pose2d'][0] 30 | img_w, img_h = 2048, 2048 31 | cam_intri = np.array([1504.1479043534127, 1556.86936732066, 991.7469587022122, 872.994958045596, 1, 1, 1, 1, 1]) 32 | params = {} 33 | if baseline_normalize: 34 | # Remove global offset, but keep trajectory in first position 35 | pose_3d[:, 1:] -= pose_3d[:, :1] 36 | normed_pose_3d = pose_3d/1000 37 | normed_pose_2d = normalize_screen_coordinates(pose_2d[..., :2], w=img_w, h=img_h) 38 | params['intrinsic'] = cam_intri 39 | else: 40 | normed_pose_3d, normed_pose_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel(pose_3d/1000, pose_2d, cam_intri, norm) 41 | norm_params=np.concatenate((pixel_ratio, rescale_ratio, offset_2d, abs_root_Z), axis=-1) # [T, 1, 5], len()==4 42 | params['intrinsic'] = cam_intri 43 | params['normalization_params'] = norm_params 44 | return normed_pose_3d, normed_pose_2d, params -------------------------------------------------------------------------------- /common/dataset/pre_process/get_mpi_inf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from common.dataset.pre_process.norm_data import norm_to_pixel 3 | from common.transformation.cam_utils import normalize_screen_coordinates 4 | 5 | def load_mpi_test(file_path, seq, norm): 6 | """ 7 | Usage: Load a section once 8 | :param dataset_root: root path 9 | :param section: There are six sequences in this (seq=0,1,2,3,4,5). And 2935 poses in a unique set(seq==7). 
10 | If you want to evaluate by scene setting, you can use the sequencewise evaluation 11 | to convert to these numbers by doing 12 | #1:Studio with Green Screen (TS1*603 + TS2 *540)/ (603+540) 13 | #2:Studio without Green Screen (TS3*505+TS4*553)/(505+553) 14 | #3:Outdoor (TS5*276+TS6*452)/(276+452) 15 | :return: Normalized 2d/3d pose, normalization params and camera intrinics. All types: List 16 | """ 17 | info = np.load(file_path, allow_pickle=True) 18 | if seq in range(0,6): 19 | pose_3d = info['pose3d_univ'][seq] 20 | pose_2d = info['pose2d'][seq] 21 | if seq in [0, 1, 2, 3]: 22 | img_w, img_h = 2048, 2048 23 | cam_intri = np.array([1500.0686135995716, 1500.6590966853348, 1017.3794860438494, 1043.062824876024, 1,1,1,1,1]) 24 | elif seq in [4, 5]: 25 | img_w, img_h = 1920, 1080 26 | cam_intri = np.array([1683.482559482185, 1671.927242063379, 939.9278168524228, 560.2072491988034, 1,1,1,1,1]) 27 | 28 | elif seq == 7: 29 | pose_3d = info['pose3d_univ'][0] 30 | pose_2d = info['pose2d'][0] 31 | img_w, img_h = 2048, 2048 32 | cam_intri = np.array([1504.1479043534127, 1556.86936732066, 991.7469587022122, 872.994958045596, 1, 1, 1, 1, 1]) 33 | params = {} 34 | if norm == 'base': 35 | # Remove global offset, but keep trajectory in first position 36 | pose_3d[:, 1:] -= pose_3d[:, :1] 37 | normed_pose_3d = pose_3d/1000 38 | normed_pose_2d = normalize_screen_coordinates(pose_2d[..., :2], w=img_w, h=img_h) 39 | params['intrinsic'] = cam_intri 40 | else: 41 | normed_pose_3d, normed_pose_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel(pose_3d/1000, pose_2d, cam_intri, norm) 42 | norm_params=np.concatenate((pixel_ratio, rescale_ratio, offset_2d, abs_root_Z), axis=-1) # [T, 1, 5], len()==4 43 | params['intrinsic'] = cam_intri 44 | params['normalization_params'] = norm_params 45 | return normed_pose_3d, normed_pose_2d, params -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint/ 2 | best_checkpoint/ 3 | log/ 4 | img/ 5 | viz/ 6 | figs/ 7 | .git/ 8 | *.npy 9 | *.png 10 | *.npz 11 | *.zip 12 | *.pyc 13 | *.swp 14 | *.mp4 15 | *.bin 16 | *.pickle 17 | .DS_Store 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | -------------------------------------------------------------------------------- /common/dataset/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | class Skeleton: 10 | def __init__(self, parents, joints_left, joints_right): 11 | assert len(joints_left) == len(joints_right) 12 | 13 | self._parents = np.array(parents) 14 | self._joints_left = joints_left 15 | self._joints_right = joints_right 16 | self._compute_metadata() 17 | 18 | def num_joints(self): 19 | return len(self._parents) 20 | 21 | def parents(self): 22 | return self._parents 23 | 24 | def has_children(self): 25 | return self._has_children 26 | 27 | def children(self): 28 | return self._children 29 | 30 | def remove_joints(self, joints_to_remove): 31 | """ 32 | Remove the joints specified in 'joints_to_remove'. 
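The internal parent, joints_left and joints_right lists are remapped to the compacted joint numbering, and the original indices of the kept joints are returned.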
33 | """ 34 | valid_joints = [] 35 | for joint in range(len(self._parents)): 36 | if joint not in joints_to_remove: 37 | valid_joints.append(joint) 38 | 39 | for i in range(len(self._parents)): 40 | while self._parents[i] in joints_to_remove: 41 | self._parents[i] = self._parents[self._parents[i]] 42 | 43 | index_offsets = np.zeros(len(self._parents), dtype=int) 44 | new_parents = [] 45 | for i, parent in enumerate(self._parents): 46 | if i not in joints_to_remove: 47 | new_parents.append(parent - index_offsets[parent]) 48 | else: 49 | index_offsets[i:] += 1 50 | self._parents = np.array(new_parents) 51 | 52 | if self._joints_left is not None: 53 | new_joints_left = [] 54 | for joint in self._joints_left: 55 | if joint in valid_joints: 56 | new_joints_left.append(joint - index_offsets[joint]) 57 | self._joints_left = new_joints_left 58 | if self._joints_right is not None: 59 | new_joints_right = [] 60 | for joint in self._joints_right: 61 | if joint in valid_joints: 62 | new_joints_right.append(joint - index_offsets[joint]) 63 | self._joints_right = new_joints_right 64 | 65 | self._compute_metadata() 66 | 67 | return valid_joints 68 | 69 | def joints_left(self): 70 | return self._joints_left 71 | 72 | def joints_right(self): 73 | return self._joints_right 74 | 75 | def _compute_metadata(self): 76 | self._has_children = np.zeros(len(self._parents)).astype(bool) 77 | for i, parent in enumerate(self._parents): 78 | if parent != -1: 79 | self._has_children[parent] = True 80 | 81 | self._children = [] 82 | for i, parent in enumerate(self._parents): 83 | self._children.append([]) 84 | for i, parent in enumerate(self._parents): 85 | if parent != -1: 86 | self._children[parent].append(i) -------------------------------------------------------------------------------- /data/prepare_data_2d_h36m_generic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import os 10 | import zipfile 11 | import numpy as np 12 | import h5py 13 | import re 14 | from glob import glob 15 | from shutil import rmtree 16 | from data_utils import suggest_metadata, suggest_pose_importer 17 | 18 | import sys 19 | sys.path.append('../') 20 | from common.utils import wrap 21 | from itertools import groupby 22 | 23 | output_prefix_2d = 'data_2d_h36m_' 24 | cam_map = { 25 | '54138969': 0, 26 | '55011271': 1, 27 | '58860488': 2, 28 | '60457274': 3, 29 | } 30 | 31 | if __name__ == '__main__': 32 | if os.path.basename(os.getcwd()) != 'data': 33 | print('This script must be launched from the "data" directory') 34 | exit(0) 35 | 36 | parser = argparse.ArgumentParser(description='Human3.6M dataset converter') 37 | 38 | parser.add_argument('-i', '--input', default='', type=str, metavar='PATH', help='input path to 2D detections') 39 | parser.add_argument('-o', '--output', default='', type=str, metavar='PATH', help='output suffix for 2D detections (e.g. detectron_pt_coco)') 40 | 41 | args = parser.parse_args() 42 | 43 | if not args.input: 44 | print('Please specify the input directory') 45 | exit(0) 46 | 47 | if not args.output: 48 | print('Please specify an output suffix (e.g. 
detectron_pt_coco)') 49 | exit(0) 50 | 51 | import_func = suggest_pose_importer(args.output) 52 | metadata = suggest_metadata(args.output) 53 | 54 | print('Parsing 2D detections from', args.input) 55 | 56 | output = {} 57 | file_list = glob(args.input + '/S*/*.mp4.npz') 58 | for f in file_list: 59 | path, fname = os.path.split(f) 60 | subject = os.path.basename(path) 61 | assert subject.startswith('S'), subject + ' does not look like a subject directory' 62 | 63 | if '_ALL' in fname: 64 | continue 65 | 66 | m = re.search('(.*)\\.([0-9]+)\\.mp4\\.npz', fname) 67 | action = m.group(1) 68 | camera = m.group(2) 69 | camera_idx = cam_map[camera] 70 | 71 | if subject == 'S11' and action == 'Directions': 72 | continue # Discard corrupted video 73 | 74 | # Use consistent naming convention 75 | canonical_name = action.replace('TakingPhoto', 'Photo') \ 76 | .replace('WalkingDog', 'WalkDog') 77 | 78 | keypoints = import_func(f) 79 | assert keypoints.shape[1] == metadata['num_joints'] 80 | 81 | if subject not in output: 82 | output[subject] = {} 83 | if canonical_name not in output[subject]: 84 | output[subject][canonical_name] = [None, None, None, None] 85 | output[subject][canonical_name][camera_idx] = keypoints.astype('float32') 86 | 87 | print('Saving...') 88 | np.savez_compressed(output_prefix_2d + args.output, positions_2d=output, metadata=metadata) 89 | print('Done.') -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import h5py 10 | 11 | mpii_metadata = { 12 | 'layout_name': 'mpii', 13 | 'num_joints': 16, 14 | 'keypoints_symmetry': [ 15 | [3, 4, 5, 13, 14, 15], 16 | [0, 1, 2, 10, 11, 12], 17 | ] 18 | } 19 | 20 | coco_metadata = { 21 | 'layout_name': 'coco', 22 | 'num_joints': 17, 23 | 'keypoints_symmetry': [ 24 | [1, 3, 5, 7, 9, 11, 13, 15], 25 | [2, 4, 6, 8, 10, 12, 14, 16], 26 | ] 27 | } 28 | 29 | h36m_metadata = { 30 | 'layout_name': 'h36m', 31 | 'num_joints': 17, 32 | 'keypoints_symmetry': [ 33 | [4, 5, 6, 11, 12, 13], 34 | [1, 2, 3, 14, 15, 16], 35 | ] 36 | } 37 | 38 | humaneva15_metadata = { 39 | 'layout_name': 'humaneva15', 40 | 'num_joints': 15, 41 | 'keypoints_symmetry': [ 42 | [2, 3, 4, 8, 9, 10], 43 | [5, 6, 7, 11, 12, 13] 44 | ] 45 | } 46 | 47 | humaneva20_metadata = { 48 | 'layout_name': 'humaneva20', 49 | 'num_joints': 20, 50 | 'keypoints_symmetry': [ 51 | [3, 4, 5, 6, 11, 12, 13, 14], 52 | [7, 8, 9, 10, 15, 16, 17, 18] 53 | ] 54 | } 55 | 56 | def suggest_metadata(name): 57 | names = [] 58 | for metadata in [mpii_metadata, coco_metadata, h36m_metadata, humaneva15_metadata, humaneva20_metadata]: 59 | if metadata['layout_name'] in name: 60 | return metadata 61 | names.append(metadata['layout_name']) 62 | raise KeyError('Cannot infer keypoint layout from name "{}". 
Tried {}.'.format(name, names)) 63 | 64 | def import_detectron_poses(path): 65 | # Latin1 encoding because Detectron runs on Python 2.7 66 | data = np.load(path, encoding='latin1') 67 | kp = data['keypoints'] 68 | bb = data['boxes'] 69 | results = [] 70 | for i in range(len(bb)): 71 | if len(bb[i][1]) == 0: 72 | assert i > 0 73 | # Use last pose in case of detection failure 74 | results.append(results[-1]) 75 | continue 76 | best_match = np.argmax(bb[i][1][:, 4]) 77 | keypoints = kp[i][1][best_match].T.copy() 78 | results.append(keypoints) 79 | results = np.array(results) 80 | return results[:, :, 4:6] # Soft-argmax 81 | #return results[:, :, [0, 1, 3]] # Argmax + score 82 | 83 | 84 | def import_cpn_poses(path): 85 | data = np.load(path) 86 | kp = data['keypoints'] 87 | return kp[:, :, :2] 88 | 89 | 90 | def import_sh_poses(path): 91 | with h5py.File(path) as hf: 92 | positions = hf['poses'].value 93 | return positions.astype('float32') 94 | 95 | def suggest_pose_importer(name): 96 | if 'detectron' in name: 97 | return import_detectron_poses 98 | if 'cpn' in name: 99 | return import_cpn_poses 100 | if 'sh' in name: 101 | return import_sh_poses 102 | raise KeyError('Cannot infer keypoint format from name "{}". Tried detectron, cpn, sh.'.format(name)) 103 | -------------------------------------------------------------------------------- /common/transformation/aug_rotate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | import numpy as np 5 | from common.transformation.cam_utils import * 6 | from common.arguments.basic_args import parse_args 7 | from common.visualization.plot_pose3d import plot17j 8 | 9 | args = parse_args() 10 | np.random.seed(4321) 11 | 12 | def rotate(batch_2d_in, batch_3d_in, cam, repeat_num): 13 | pose_3d = [] 14 | pose_2d = [] 15 | for i in range(repeat_num): 16 | batch_3d, batch_2d = axis_rotation(batch_3d_in*1000, cam) 17 | pose_3d.append(batch_3d) 18 | pose_2d.append(batch_2d) 19 | rotate_3d = np.concatenate(pose_3d, axis=0) 20 | batch_3d_out = np.concatenate((batch_3d_in, rotate_3d), axis=0) 21 | batch_3d_relative = batch_3d_out-batch_3d_out[:,:,:1] 22 | 23 | w, h = batch_2d_in[:,:,17:18,0:1], batch_2d_in[:,:,17:18,1:2] 24 | batch_2d_in_norm = process_2d(batch_2d_in) 25 | 26 | pose_out = [] 27 | for i in range(repeat_num): 28 | batch_2d_norm = norm_pixel(pose_2d[i], w, h) 29 | pose_out.append(batch_2d_norm) 30 | batch_2d_out = np.concatenate((batch_2d_in_norm, np.concatenate(pose_out,axis=0)), axis=0) 31 | return batch_2d_out, batch_3d_relative 32 | 33 | def axis_rotation(batch_3d,cam): 34 | # Input batch 3d pose is presented by the Relative value in pose model coordination. 
Root=[0,0,0] 35 | batch_root = batch_3d[:,:,:1].copy() 36 | batch_size = batch_3d.shape[0] 37 | batch_pose = batch_3d - batch_root 38 | theta = np.random.uniform(-np.pi, np.pi, batch_size).astype('f') # Y axis - roll 39 | beta = np.random.uniform(-np.pi/5, np.pi/5, batch_size).astype('f') # X axis - pitch 40 | alpha = np.random.uniform(-np.pi/5, np.pi/5, batch_size).astype('f') #Z axis - yaw 41 | 42 | cos_theta = np.cos(theta)[:, None,None,None] 43 | sin_theta = np.sin(theta)[:, None,None,None] 44 | 45 | cos_beta = np.cos(beta)[:, None,None,None] 46 | sin_beta = np.sin(beta)[:, None,None,None] 47 | 48 | cos_alpha = np.cos(alpha)[:, None,None,None] 49 | sin_alpha = np.sin(alpha)[:, None,None,None] 50 | 51 | X = batch_pose[...,0:1] 52 | Y = batch_pose[...,1:2] 53 | Z = batch_pose[...,2:3] 54 | 55 | # rotate around Y axis 56 | new_x = X * cos_theta + Z * sin_theta 57 | new_y = Y 58 | new_z = - X * sin_theta + Z * cos_theta 59 | 60 | # rotate around X axis 61 | new_x = new_x 62 | new_y = new_y * cos_beta - new_z * sin_beta 63 | new_z = new_y * sin_beta + new_z * cos_beta 64 | 65 | # rotate around Z axis 66 | new_x = new_x * cos_alpha - new_y *sin_alpha 67 | new_y = new_x * sin_alpha + new_y * cos_alpha 68 | new_z = new_z 69 | 70 | rotated_pose = np.concatenate((new_x,new_y,new_z),axis=-1) 71 | rotated_abs_3d = rotated_pose + batch_root 72 | rotated_2d = wrap(project_to_2d, rotated_abs_3d, cam) 73 | rotated_3d = rotated_abs_3d / 1000.0 #change unit from mm to m 74 | return rotated_3d, rotated_2d 75 | 76 | def process_3d(pose_3d_in): 77 | pose_3d_out = pose_3d_in - pose_3d_in[:, :, :1] 78 | return pose_3d_out 79 | 80 | def process_2d(pose_2d_in): 81 | pose_2d_joint = pose_2d_in[:,:,:17] 82 | w, h = pose_2d_in[:,:,17:18,0:1], pose_2d_in[:,:,17:18,1:2] 83 | pose_2d_in_norm = norm_pixel(pose_2d_joint, w, h) 84 | return pose_2d_in_norm 85 | 86 | def norm_pixel(pose_2d, w, h): 87 | X = pose_2d[...,0:1] 88 | Y = pose_2d[...,1:2] 89 | norm_X = X/w * 2 - 1 90 | w_abs = np.abs(w) 91 | norm_Y = Y/w_abs *2 - h/w_abs #The flip influences the w value. 
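# Quick numeric check of this mapping (hypothetical square image, w = h = 1000):
# a pixel at (500, 500) maps to (0, 0) and a pixel at (0, 0) maps to (-1, -1),
# i.e. [0, w] is rescaled to [-1, 1] while y keeps the image aspect ratio.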
92 | norm_2d = np.concatenate((norm_X,norm_Y),axis=-1) 93 | return norm_2d 94 | -------------------------------------------------------------------------------- /common/common_pytorch/experiment/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import sys 5 | sys.path.append("../../..") 6 | 7 | from common.arguments.basic_args import parse_args 8 | args = parse_args() 9 | 10 | from tensorboardX import SummaryWriter 11 | tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.model_name)) 12 | 13 | def count_params(model, ): 14 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 15 | 16 | 17 | def check_rootfolder(): 18 | ####### Create log and model folder #######3 19 | folders_util = [args.root_log, args.checkpoint, 20 | os.path.join(args.root_log, args.model_name), 21 | os.path.join(args.checkpoint, args.model_name)] 22 | for folder in folders_util: 23 | if not os.path.exists(folder): 24 | print('creating folder: '+folder) 25 | os.mkdir(folder) 26 | 27 | def deterministic_random(min_value, max_value, data): 28 | digest = hashlib.sha256(data.encode()).digest() 29 | raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False) 30 | return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value 31 | 32 | def print_result(epoch, elapsed, lr, losses_3d_train, losses_3d_train_eval, losses_3d_valid): 33 | if args.no_eval: 34 | print('[%d] time %.2f lr %f 3d_train %f' % ( 35 | epoch + 1, 36 | elapsed, 37 | lr, 38 | losses_3d_train[-1])) 39 | else: 40 | output = ('[%d] time %.2f lr %f 3d_train %f 3d_eval %f 3d_valid %f' % ( 41 | epoch + 1, 42 | elapsed, 43 | lr, 44 | losses_3d_train[-1], 45 | losses_3d_train_eval[-1], 46 | losses_3d_valid[-1])) 47 | 48 | tf_writer.add_scalar('loss/valid', losses_3d_train_eval[-1], epoch + 1) 49 | tf_writer.add_scalar('loss/test', losses_3d_valid[-1], epoch + 1) 50 | tf_writer.add_scalar('lr', lr, epoch + 1) 51 | tf_writer.add_scalar('loss/train', losses_3d_train[-1], epoch + 1) 52 | print(output) 53 | 54 | def save_model(losses_3d_train, losses_3d_train_eval, losses_3d_valid, train_generator, optimizer, model_pos_train, epoch, lr, Best_model=False): 55 | if epoch % args.checkpoint_frequency == 0: 56 | chk_path = os.path.join(args.checkpoint, 'latest_epoch_{}.bin'.format(args.model_name)) 57 | print('Saving checkpoint to', chk_path) 58 | torch.save({ 59 | 'epoch': epoch, 60 | 'lr': lr, 61 | 'loss 3d train': losses_3d_train[-1], 62 | 'loss 3d eval': losses_3d_train_eval[-1], 63 | 'loss 3d test': losses_3d_valid[-1], 64 | 'random_state': train_generator.random_state(), 65 | 'optimizer': optimizer.state_dict(), 66 | 'model_pos': model_pos_train.state_dict(), 67 | }, chk_path) 68 | if Best_model: 69 | print('Best model got in epoch', epoch, ' with test error:', losses_3d_valid[-1]) 70 | best_path = os.path.join(args.best_checkpoint, 'model_best' + args.model_name + '.bin') 71 | print('Saving best checkpoint to', best_path) 72 | out = 'Best model got in epoch {0} \n with test error: {1} \n Saving best checkpoint to{2}'.format(epoch+1, losses_3d_valid[-1], best_path) 73 | print(out) 74 | torch.save({ 75 | 'epoch': epoch, 76 | 'lr': lr, 77 | 'loss 3d train': losses_3d_train[-1] * 1000, 78 | 'loss 3d eval': losses_3d_train_eval[-1] * 1000, 79 | 'loss 3d test': losses_3d_valid[-1], 80 | 'random_state': train_generator.random_state(), 81 | 'optimizer': optimizer.state_dict(), 82 | 'model_pos': model_pos_train.state_dict(), 83 | }, best_path) 84 
| -------------------------------------------------------------------------------- /data/prepare_data_2d_custom.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | from glob import glob 10 | import os 11 | import sys 12 | 13 | import argparse 14 | from data_utils import suggest_metadata 15 | 16 | output_prefix_2d = 'data_2d_custom_' 17 | 18 | def decode(filename): 19 | # Latin1 encoding because Detectron runs on Python 2.7 20 | print('Processing {}'.format(filename)) 21 | data = np.load(filename, encoding='latin1', allow_pickle=True) 22 | bb = data['boxes'] 23 | kp = data['keypoints'] 24 | metadata = data['metadata'].item() 25 | results_bb = [] 26 | results_kp = [] 27 | for i in range(len(bb)): 28 | if len(bb[i][1]) == 0 or len(kp[i][1]) == 0: 29 | # No bbox/keypoints detected for this frame -> will be interpolated 30 | results_bb.append(np.full(4, np.nan, dtype=np.float32)) # 4 bounding box coordinates 31 | results_kp.append(np.full((17, 4), np.nan, dtype=np.float32)) # 17 COCO keypoints 32 | continue 33 | best_match = np.argmax(bb[i][1][:, 4]) 34 | best_bb = bb[i][1][best_match, :4] 35 | best_kp = kp[i][1][best_match].T.copy() 36 | results_bb.append(best_bb) 37 | results_kp.append(best_kp) 38 | 39 | bb = np.array(results_bb, dtype=np.float32) 40 | kp = np.array(results_kp, dtype=np.float32) 41 | kp = kp[:, :, :2] # Extract (x, y) 42 | 43 | # Fix missing bboxes/keypoints by linear interpolation 44 | mask = ~np.isnan(bb[:, 0]) 45 | indices = np.arange(len(bb)) 46 | for i in range(4): 47 | bb[:, i] = np.interp(indices, indices[mask], bb[mask, i]) 48 | for i in range(17): 49 | for j in range(2): 50 | kp[:, i, j] = np.interp(indices, indices[mask], kp[mask, i, j]) 51 | 52 | print('{} total frames processed'.format(len(bb))) 53 | print('{} frames were interpolated'.format(np.sum(~mask))) 54 | print('----------') 55 | 56 | return [{ 57 | 'start_frame': 0, # Inclusive 58 | 'end_frame': len(kp), # Exclusive 59 | 'bounding_boxes': bb, 60 | 'keypoints': kp, 61 | }], metadata 62 | 63 | 64 | if __name__ == '__main__': 65 | if os.path.basename(os.getcwd()) != 'data': 66 | print('This script must be launched from the "data" directory') 67 | exit(0) 68 | 69 | parser = argparse.ArgumentParser(description='Custom dataset creator') 70 | parser.add_argument('-i', '--input', type=str, default='', metavar='PATH', help='detections directory') 71 | parser.add_argument('-o', '--output', type=str, default='', metavar='PATH', help='output suffix for 2D detections') 72 | args = parser.parse_args() 73 | 74 | if not args.input: 75 | print('Please specify the input directory') 76 | exit(0) 77 | 78 | if not args.output: 79 | print('Please specify an output suffix (e.g. 
detectron_pt_coco)') 80 | exit(0) 81 | 82 | print('Parsing 2D detections from', args.input) 83 | 84 | metadata = suggest_metadata('coco') 85 | metadata['video_metadata'] = {} 86 | 87 | output = {} 88 | file_list = glob(args.input + '/*.npz') 89 | for f in file_list: 90 | canonical_name = os.path.splitext(os.path.basename(f))[0] 91 | data, video_metadata = decode(f) 92 | output[canonical_name] = {} 93 | output[canonical_name]['custom'] = [data[0]['keypoints'].astype('float32')] 94 | metadata['video_metadata'][canonical_name] = video_metadata 95 | 96 | print('Saving...') 97 | np.savez_compressed(output_prefix_2d + args.output, positions_2d=output, metadata=metadata) 98 | print('Done.') -------------------------------------------------------------------------------- /data/prepare_data_2d_h36m_sh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import os 10 | import zipfile 11 | import tarfile 12 | import numpy as np 13 | import h5py 14 | from glob import glob 15 | from shutil import rmtree 16 | 17 | import sys 18 | sys.path.append('../') 19 | from common.h36m_dataset import Human36mDataset 20 | from common.camera import world_to_camera, project_to_2d, image_coordinates 21 | from common.utils import wrap 22 | 23 | output_filename_pt = 'data_2d_h36m_sh_pt_mpii' 24 | output_filename_ft = 'data_2d_h36m_sh_ft_h36m' 25 | subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'] 26 | cam_map = { 27 | '54138969': 0, 28 | '55011271': 1, 29 | '58860488': 2, 30 | '60457274': 3, 31 | } 32 | 33 | metadata = { 34 | 'num_joints': 16, 35 | 'keypoints_symmetry': [ 36 | [3, 4, 5, 13, 14, 15], 37 | [0, 1, 2, 10, 11, 12], 38 | ] 39 | } 40 | 41 | def process_subject(subject, file_list, output): 42 | if subject == 'S11': 43 | assert len(file_list) == 119, "Expected 119 files for subject " + subject + ", got " + str(len(file_list)) 44 | else: 45 | assert len(file_list) == 120, "Expected 120 files for subject " + subject + ", got " + str(len(file_list)) 46 | 47 | for f in file_list: 48 | action, cam = os.path.splitext(os.path.basename(f))[0].replace('_', ' ').split('.') 49 | 50 | if subject == 'S11' and action == 'Directions': 51 | continue # Discard corrupted video 52 | 53 | if action not in output[subject]: 54 | output[subject][action] = [None, None, None, None] 55 | 56 | with h5py.File(f) as hf: 57 | positions = hf['poses'].value 58 | output[subject][action][cam_map[cam]] = positions.astype('float32') 59 | 60 | if __name__ == '__main__': 61 | if os.path.basename(os.getcwd()) != 'data': 62 | print('This script must be launched from the "data" directory') 63 | exit(0) 64 | 65 | parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter') 66 | 67 | parser.add_argument('-pt', '--pretrained', default='', type=str, metavar='PATH', help='convert pretrained dataset') 68 | parser.add_argument('-ft', '--fine-tuned', default='', type=str, metavar='PATH', help='convert fine-tuned dataset') 69 | 70 | args = parser.parse_args() 71 | 72 | if args.pretrained: 73 | print('Converting pretrained dataset from', args.pretrained) 74 | print('Extracting...') 75 | with zipfile.ZipFile(args.pretrained, 'r') as archive: 76 | archive.extractall('sh_pt') 77 | 78 | print('Converting...') 79 | output = {} 80 | for subject in subjects: 81 | 
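# Build a nested dict per subject; process_subject() fills output[subject][action][camera_index] with float32 keypoint arrays.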
output[subject] = {} 82 | file_list = glob('sh_pt/h36m/' + subject + '/StackedHourglass/*.h5') 83 | process_subject(subject, file_list, output) 84 | 85 | print('Saving...') 86 | np.savez_compressed(output_filename_pt, positions_2d=output, metadata=metadata) 87 | 88 | print('Cleaning up...') 89 | rmtree('sh_pt') 90 | 91 | print('Done.') 92 | 93 | if args.fine_tuned: 94 | print('Converting fine-tuned dataset from', args.fine_tuned) 95 | print('Extracting...') 96 | with tarfile.open(args.fine_tuned, 'r:gz') as archive: 97 | archive.extractall('sh_ft') 98 | 99 | print('Converting...') 100 | output = {} 101 | for subject in subjects: 102 | output[subject] = {} 103 | file_list = glob('sh_ft/' + subject + '/StackedHourglassFineTuned240/*.h5') 104 | process_subject(subject, file_list, output) 105 | 106 | print('Saving...') 107 | np.savez_compressed(output_filename_ft, positions_2d=output, metadata=metadata) 108 | 109 | print('Cleaning up...') 110 | rmtree('sh_ft') 111 | 112 | print('Done.') 113 | -------------------------------------------------------------------------------- /common/common_pytorch/experiment/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import os 5 | 6 | from common.common_pytorch.loss.loss_family import * 7 | from common.dataset.post_process.process3d import post_process3d 8 | from tensorboardX import SummaryWriter 9 | from common.arguments.basic_args import parse_args 10 | args = parse_args() 11 | tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.model_name)) 12 | 13 | def train(train_generator, model_pos_train, dataset, optimizer, epoch, norm, i_train=0, bone_length_term=False): 14 | N = 0 15 | 16 | epoch_loss_3d_train_l1 = 0 17 | epoch_loss_left_right = 0 18 | epoch_loss_3d_train_l2 = 0 19 | model_pos_train.train() 20 | 21 | # Regular supervised scenario 22 | i = 0 23 | 24 | for use_params, batch_3d, batch_2d in train_generator.next_epoch(): 25 | if norm != 'base': 26 | normalize_param = use_params['normalization_params'] 27 | cam_intri = use_params['intrinsic'] 28 | inputs_3d = torch.from_numpy(batch_3d.astype('float32')) # torch.Size([1024,1,17,3]) 29 | inputs_2d = torch.from_numpy(batch_2d.astype('float32')) # torch.Size([1024,27,17,2]) 30 | if torch.cuda.is_available(): 31 | inputs_3d = inputs_3d.cuda() 32 | inputs_2d = inputs_2d.cuda() 33 | if norm != 'base': 34 | cam = torch.from_numpy(cam_intri.astype('float32')).cuda() # torch.Size([1024,9]) 35 | normalize_param = torch.from_numpy(normalize_param.astype('float32')).cuda() 36 | 37 | optimizer.zero_grad() 38 | # Train model 39 | predicted_3d_pos = model_pos_train(inputs_2d) 40 | if norm == 'base': 41 | inputs_3d[:, :, 0] = 0 42 | # Calculate L1 Loss 43 | 44 | loss_3d_pos_l1 = L1_loss(predicted_3d_pos, inputs_3d) 45 | epoch_loss_3d_train_l1 += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos_l1.item() 46 | N += inputs_3d.shape[0] * inputs_3d.shape[1] 47 | 48 | # denorm 3d pose for proposed normalization 49 | if norm != 'base': 50 | predicted_3d_pos, inputs_3d = post_process3d(predicted_3d_pos, inputs_3d, cam, normalize_param, norm) 51 | # Calculate L2 error with denormed 3d pose in meters unit. 
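# mpjpe comes from loss_family (not shown here); it is assumed to be the standard Mean Per-Joint Position Error, i.e. the mean Euclidean distance between predicted and ground-truth joints.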
52 | loss_l2 = mpjpe(predicted_3d_pos, inputs_3d) 53 | epoch_loss_3d_train_l2 += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_l2.item() 54 | 55 | print('each joint l1 loss', loss_3d_pos_l1.item( ) *1000 ,'Total epoch l1 loss average', epoch_loss_3d_train_l1/ N* 1000) 56 | 57 | # Bone length term to enforce kinematic constraints 58 | if bone_length_term: #not used by default 59 | if epoch > 100: 60 | left = [4, 5, 6, 11, 12, 13] 61 | right = [1, 2, 3, 14, 15, 16] 62 | bone_lengths_lift = [] 63 | bone_lengths_right = [] 64 | each_bone_error = [] 65 | left_right_error = 0 66 | # error = [0.001, 0.0018, 0.0008, 0.0019, 0.0043, 0.0011] 67 | for i in left: 68 | dists_l = predicted_3d_pos[:, :, i, :] - predicted_3d_pos[:, :, dataset.skeleton().parents()[i], :] 69 | bone_lengths_lift.append(torch.mean(torch.norm(dists_l, dim=-1), dim=1)) 70 | for i in right: 71 | dists_r = predicted_3d_pos[:, :, i, :] - predicted_3d_pos[:, :, dataset.skeleton().parents()[i], :] 72 | bone_lengths_right.append(torch.mean(torch.norm(dists_r, dim=-1), dim=1)) 73 | for i in range(len(left)): 74 | left_right_error += torch.abs( 75 | torch.abs(bone_lengths_right[i] - bone_lengths_lift[i])) 76 | each_bone_error.append(torch.mean(torch.abs(bone_lengths_right[i] - bone_lengths_lift[i]))) 77 | # print('each bone error', each_bone_error[-1] * 1000) 78 | left_right_err_mean = torch.mean(left_right_error) 79 | 80 | epoch_loss_left_right += inputs_3d.shape[0] * inputs_3d.shape[1] * left_right_err_mean.item() 81 | print('Each epoch left right error average', (epoch_loss_left_right / N) * 1000) 82 | else: 83 | left_right_err_mean = 0 84 | else: 85 | left_right_err_mean = 0 86 | i_train += 1 87 | tf_writer.add_scalar('loss/training part', (epoch_loss_3d_train_l1 / N) * 1000, i_train) 88 | 89 | # loss_total = loss_3d_pos_l1 + 0.1*left_right_err_mean 90 | loss_total = loss_3d_pos_l1 91 | loss_total.backward() 92 | optimizer.step() 93 | return epoch_loss_3d_train_l2/N*1000, i_train 94 | -------------------------------------------------------------------------------- /common/dataset/pre_process/norm_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from common.transformation.cam_utils import project_to_2d_linear, normalize_screen_coordinates 5 | from common.common_pytorch.utils import wrap 6 | 7 | 8 | def week_perspective_scale(camera_params, depth): 9 | fx = camera_params[..., 0:1] 10 | return fx / depth 11 | 12 | def change_to_mm(input): 13 | return input*1000 14 | 15 | def change_to_m(input): 16 | return input/1000 17 | 18 | def norm_to_pixel(pose_3d, pose_2d, camera, norm): 19 | pose_3d = change_to_mm(pose_3d) # change m into mm 20 | pose3d_pixel, pixel_ratio = norm_to_pixel_s1(pose_3d, camera, norm) 21 | normed_3d, normed_2d, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel_s2(pose3d_pixel, pose_3d[:, 0:1], pose_2d, camera) 22 | if norm =='lcn': 23 | c_x, c_y = camera[2], camera[3] 24 | img_w = int(2 * c_x) 25 | img_h = int(2* c_y) 26 | normed_3d = normalize_screen_coordinates(pose3d_pixel,img_w,img_h) 27 | normed_2d = normalize_screen_coordinates(pose_2d,img_w,img_h) 28 | return normed_3d, normed_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z 29 | 30 | 31 | def norm_to_pixel_s1(pose_3d, camera, norm): 32 | """ 33 | pose_3d: 3d joints with absolute location in the camera coordinate system (meters) 34 | pose_3d.shape = [T, K, N], e.g. 
[1500, 17, 3] 35 | pose_2d: 2d joints with pixel location in the images coordinate system (pixels) 36 | pose_3d.shape = [T, K, M], e.g. [1500, 17, 2] 37 | return: normed_3d: root joint contain relative [x,y] offset and absolute depth of root Z. others joints are normed 3d joints in pixel unit 38 | normed_2d: zero-center root with resize into a fixed bbox 39 | """ 40 | # stage1: linear project 3d X,Y to pixel unit, corresponding scale Z to keep the same 3d scale 41 | pose3d_root_Z = pose_3d[:, 0:1, 2:3].copy() 42 | 43 | camera = np.repeat(camera[np.newaxis, :], pose3d_root_Z.shape[0], axis=0) 44 | if norm == 'lcn': 45 | ratio1 = week_perspective_scale(camera[:,np.newaxis], pose3d_root_Z)+1 #[T,1,1] project depth as the same scale with XY 46 | else: 47 | ratio1 = week_perspective_scale(camera[:,np.newaxis], pose3d_root_Z) #[T,1,1] project depth as the same scale with XY 48 | 49 | pose3d_pixel = np.zeros_like(pose_3d) 50 | if norm == 'weak_proj': 51 | pose3d_root = np.repeat(pose3d_root_Z, 17, axis=-2) # (T,17,1) # For weak perspective projection 52 | pose3d_pixel[..., :2] = pose_3d[..., :2]/pose3d_root * camera[:, np.newaxis, :2] + camera[:, np.newaxis, 2:4] 53 | else: 54 | pose3d_pixel[..., :2] = wrap(project_to_2d_linear, pose_3d.copy(), camera) # Keep all depth from each joints, projected 2d xy are more precise. 55 | pose3d_relative_depth = minus_root(pose_3d[..., 2:3]) # Make root depth=0 56 | pose3d_stage1_depth = pose3d_relative_depth * ratio1 # Root_depth=0 [2000,17,1] 57 | pose3d_pixel[..., 2:3] = pose3d_stage1_depth.copy() 58 | return pose3d_pixel, ratio1 59 | 60 | def norm_to_pixel_s2(pose3d_pixel, root_joint, pose_2d, camera, bbox_scale=2): 61 | # stage2: Resize 2d and 3d pixel position into one fixed bbox_scale 62 | pose3d_root_Z = root_joint[:, :, 2:3].copy() 63 | tl_3d_joint, br_3d_joint = make_3d_bbox(root_joint) 64 | 65 | camera = np.repeat(camera[np.newaxis, :], root_joint.shape[0], axis=0) 66 | tl2d = wrap(project_to_2d_linear, tl_3d_joint, camera) # Use weak perspective 67 | br2d = wrap(project_to_2d_linear, br_3d_joint, camera) # Use weak perspective 68 | bbox_2d = np.concatenate((tl2d.squeeze(), br2d.squeeze()), axis=-1) 69 | 70 | diff_bbox_2d = bbox_2d[..., 2:3] - bbox_2d[..., 0:1] # (x_br - x_tl) 71 | ratio2 = bbox_scale / diff_bbox_2d # ratio2.all() == (bbox_scale/(ratio2 * rectange_3d_size).all()) 72 | 73 | # Get normed 3d joints 74 | pixel_xy_root = pose3d_pixel[:, 0:1, 0:2] # [T,1,2] 75 | reshape_3d = pose3d_pixel * ratio2[:, np.newaxis, :] 76 | normed_3d = minus_root(reshape_3d) 77 | 78 | # Get normed 2d joints 79 | reshape_2d = pose_2d * ratio2[:, :, np.newaxis] 80 | normed_2d = minus_root(reshape_2d) 81 | return normed_3d, normed_2d, ratio2[:, np.newaxis], pixel_xy_root, pose3d_root_Z 82 | 83 | def make_3d_bbox(pose_3d_root, rectangle_3d_size=2000): 84 | tl_joint = pose_3d_root.copy() 85 | tl_joint[..., :2] -= rectangle_3d_size/2 # 1000mm 86 | br_joint = pose_3d_root.copy() 87 | br_joint[..., :2] += rectangle_3d_size/2 # 1000mm 88 | return tl_joint, br_joint 89 | 90 | def minus_root(pose): 91 | # Assume pose.shape = [T, K ,N] 92 | pose_root = pose[:,:1] 93 | relative_pose = pose - pose_root 94 | return relative_pose 95 | 96 | def get_ratio(abs_root_3d, camera): 97 | # abs_root_3d.shape = [B, T, 1, 1] 98 | # camera.shape = [B, 9] & [2, 9] 99 | bbox_scale = 1 100 | rectangle_3d_size = 2000 101 | camera = camera.unsqueeze(dim=1).unsqueeze(dim=1) 102 | fx, fy = camera[:,:,:,0:1], camera[:,:,:,1:2] 103 | pixel_depth_ratio = fx / abs_root_3d 104 | rescale_bbox = 
bbox_scale / pixel_depth_ratio 105 | rescale_bbox_ratio = rescale_bbox / rectangle_3d_size 106 | return rescale_bbox_ratio, pixel_depth_ratio 107 | 108 | -------------------------------------------------------------------------------- /common/transformation/cam_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from common.common_pytorch.utils import wrap 12 | from common.transformation.quaternion import qrot, qinverse 13 | 14 | 15 | def normalize_screen_coordinates(X, w, h): 16 | # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio 17 | if X.shape[-1] == 3: #input 3d pose 18 | X_norm = X[..., :2] 19 | X_norm = X_norm / w * 2 - [1, h / w] 20 | X_out = np.concatenate((X_norm, X[..., 2:3] / 1000), -1) 21 | else: 22 | assert X.shape[-1] == 2 23 | X_out = X / w * 2 - [1, h / w] 24 | return X_out 25 | 26 | 27 | def image_coordinates(X, w, h): 28 | # Reverse camera frame normalization 29 | if X.shape[-1] == 3: #input 3d pose 30 | X_norm = X[..., :2] 31 | X_norm[..., :1] = (X_norm[..., :1] + 1) * w / 2 32 | X_norm[..., 1:2] = (X_norm[..., 1:2] + h / w) * w / 2 33 | X_out = torch.cat([X_norm, X[..., 2:3] * 1000], -1) 34 | 35 | else: 36 | assert X.shape[-1] == 2 37 | X_out = (X + [1, h / w]) * w / 2 38 | return X_out 39 | 40 | 41 | def world_to_camera(X, R, t): 42 | Rt = wrap(qinverse, R) # Invert rotation 43 | return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate 44 | 45 | 46 | def camera_to_world(X, R, t): 47 | return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t 48 | 49 | 50 | def get_intrinsic(camera_params): 51 | assert len(camera_params.shape) == 2 52 | assert camera_params.shape[-1] == 9 53 | fx, fy, cx, cy = camera_params[..., :1], camera_params[..., 1:2], camera_params[..., 2:3], camera_params[..., 3:4] 54 | return fx, fy, cx, cy 55 | 56 | 57 | def infer_camera_intrinsics(points2d, points3d): 58 | """Infer camera instrinsics from 2D<->3D point correspondences.""" 59 | pose2d = points2d.reshape(-1, 2) 60 | pose3d = points3d.reshape(-1, 3) 61 | x3d = np.stack([pose3d[:, 0], pose3d[:, 2]], axis=-1) 62 | x2d = (pose2d[:, 0] * pose3d[:, 2]) 63 | alpha_x, x_0 = list(np.linalg.lstsq(x3d, x2d, rcond=-1)[0].flatten()) 64 | y3d = np.stack([pose3d[:, 1], pose3d[:, 2]], axis=-1) 65 | y2d = (pose2d[:, 1] * pose3d[:, 2]) 66 | alpha_y, y_0 = list(np.linalg.lstsq(y3d, y2d, rcond=-1)[0].flatten()) 67 | return np.array([alpha_x, x_0, alpha_y, y_0]) 68 | 69 | 70 | def project_to_2d(X, camera_params): 71 | """ 72 | Project 3D points to 2D using the Human3.6M camera projection function. 73 | This is a differentiable and batched reimplementation of the original MATLAB script. 
74 | 75 | Arguments: 76 | X -- 3D points in *camera space* to transform (N, *, 3) 77 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 78 | """ 79 | assert X.shape[-1] == 3 80 | assert len(camera_params.shape) == 2 81 | assert camera_params.shape[-1] == 9 82 | assert X.shape[0] == camera_params.shape[0] 83 | 84 | while len(camera_params.shape) < len(X.shape): 85 | camera_params = camera_params.unsqueeze(1) 86 | 87 | f = camera_params[..., :2] 88 | c = camera_params[..., 2:4] 89 | k = camera_params[..., 4:7] 90 | p = camera_params[..., 7:] 91 | 92 | XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 93 | r2 = torch.sum(XX[..., :2] ** 2, dim=len(XX.shape) - 1, keepdim=True) 94 | 95 | radial = 1 + torch.sum(k * torch.cat((r2, r2 ** 2, r2 ** 3), dim=len(r2.shape) - 1), dim=len(r2.shape) - 1, 96 | keepdim=True) 97 | tan = torch.sum(p * XX, dim=len(XX.shape) - 1, keepdim=True) 98 | 99 | XXX = XX * (radial + tan) + p * r2 100 | 101 | return f * XXX + c 102 | 103 | 104 | def project_to_2d_linear(X, camera_params): 105 | """ 106 | Project 3D points to 2D using only linear parameters (focal length and principal point). 107 | 108 | Arguments: 109 | X -- 3D points in *camera space* to transform (N, *, 3) 110 | camera_params -- intrinsic parameteres (N, 2+2+3+2=9) 111 | """ 112 | assert X.shape[-1] == 3 113 | assert len(camera_params.shape) == 2 114 | assert camera_params.shape[-1] == 9 115 | assert X.shape[0] == camera_params.shape[0] 116 | 117 | while len(camera_params.shape) < len(X.shape): 118 | if type(camera_params) == torch: 119 | camera_params = camera_params.unsqueeze(1) 120 | else: 121 | camera_params = camera_params[:, np.newaxis] 122 | 123 | f = camera_params[..., :2] 124 | c = camera_params[..., 2:4] 125 | XX = X[..., :2] / X[..., 2:] 126 | # XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1) 127 | if np.array(XX).any() > 1 or np.array(XX).any() < -1: 128 | print(np.array(XX).any() > 1 or np.array(XX).any() < -1) 129 | print('Attention for this pose!!!') 130 | return f * XX + c 131 | 132 | 133 | def reprojection(pose_3d, abs_depth, camera): 134 | """ 135 | :param pose_3d: predicted 3d or normed 3d with pixel unit 136 | :param abs_depth: absolute depth root Z in the camera coordinate 137 | :param camera: camera intrinsic parameters 138 | :return: 3d pose in the camera cooridinate with millimeter unit, root joint: zero-center 139 | """ 140 | camera = camera.unsqueeze(dim=1).unsqueeze(dim=1) 141 | cx, cy, fx, fy = camera[:,:,:,2:3], camera[:,:,:,3:4], camera[:,:,:,0:1], camera[:,:,:,1:2] 142 | final_3d = torch.zeros_like(pose_3d) 143 | final_3d_x = (pose_3d[:, :, :, 0:1] - cx) / fx 144 | final_3d_y = (pose_3d[:, :, :, 1:2] - cy) / fy 145 | final_3d[:, :, :, 0:1] = final_3d_x * abs_depth 146 | final_3d[:, :, :, 1:2] = final_3d_y * abs_depth 147 | final_3d[:, :, :, 2:3] = abs_depth 148 | return final_3d -------------------------------------------------------------------------------- /data/prepare_data_h36m.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import argparse 9 | import os 10 | import zipfile 11 | import numpy as np 12 | import h5py 13 | from glob import glob 14 | from shutil import rmtree 15 | 16 | import sys 17 | sys.path.append('../') 18 | from common.h36m_dataset import Human36mDataset 19 | from common.camera import world_to_camera, project_to_2d, image_coordinates 20 | from common.utils import wrap 21 | 22 | output_filename = 'data_3d_h36m' 23 | output_filename_2d = 'data_2d_h36m_gt' 24 | subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'] 25 | 26 | if __name__ == '__main__': 27 | if os.path.basename(os.getcwd()) != 'data': 28 | print('This script must be launched from the "data" directory') 29 | exit(0) 30 | 31 | parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter') 32 | 33 | # Default: convert dataset preprocessed by Martinez et al. in https://github.com/una-dinosauria/3d-pose-baseline 34 | parser.add_argument('--from-archive', default='', type=str, metavar='PATH', help='convert preprocessed dataset') 35 | 36 | # Alternatively, convert dataset from original source (the Human3.6M dataset path must be specified manually) 37 | parser.add_argument('--from-source', default='', type=str, metavar='PATH', help='convert original dataset') 38 | 39 | args = parser.parse_args() 40 | 41 | if args.from_archive and args.from_source: 42 | print('Please specify only one argument') 43 | exit(0) 44 | 45 | if os.path.exists(output_filename + '.npz'): 46 | print('The dataset already exists at', output_filename + '.npz') 47 | exit(0) 48 | 49 | if args.from_archive: 50 | print('Extracting Human3.6M dataset from', args.from_archive) 51 | with zipfile.ZipFile(args.from_archive, 'r') as archive: 52 | archive.extractall() 53 | 54 | print('Converting...') 55 | output = {} 56 | for subject in subjects: 57 | output[subject] = {} 58 | file_list = glob('h36m/' + subject + '/MyPoses/3D_positions/*.h5') 59 | assert len(file_list) == 30, "Expected 30 files for subject " + subject + ", got " + str(len(file_list)) 60 | for f in file_list: 61 | action = os.path.splitext(os.path.basename(f))[0] 62 | 63 | if subject == 'S11' and action == 'Directions': 64 | continue # Discard corrupted video 65 | 66 | with h5py.File(f) as hf: 67 | positions = hf['3D_positions'].value.reshape(32, 3, -1).transpose(2, 0, 1) 68 | positions /= 1000 # Meters instead of millimeters 69 | output[subject][action] = positions.astype('float32') 70 | 71 | print('Saving...') 72 | np.savez_compressed(output_filename, positions_3d=output) 73 | 74 | print('Cleaning up...') 75 | rmtree('h36m') 76 | 77 | print('Done.') 78 | 79 | elif args.from_source: 80 | print('Converting original Human3.6M dataset from', args.from_source) 81 | output = {} 82 | 83 | from scipy.io import loadmat 84 | 85 | for subject in subjects: 86 | output[subject] = {} 87 | file_list = glob(args.from_source + '/' + subject + '/MyPoseFeatures/D3_Positions/*.cdf.mat') 88 | assert len(file_list) == 30, "Expected 30 files for subject " + subject + ", got " + str(len(file_list)) 89 | for f in file_list: 90 | action = os.path.splitext(os.path.splitext(os.path.basename(f))[0])[0] 91 | 92 | if subject == 'S11' and action == 'Directions': 93 | continue # Discard corrupted video 94 | 95 | # Use consistent naming convention 96 | canonical_name = action.replace('TakingPhoto', 'Photo') \ 97 | .replace('WalkingDog', 'WalkDog') 98 | 99 | hf = loadmat(f) 100 | positions = hf['data'][0, 0].reshape(-1, 32, 3) 101 | positions /= 1000 # Meters instead of millimeters 102 | 
output[subject][canonical_name] = positions.astype('float32') 103 | 104 | print('Saving...') 105 | np.savez_compressed(output_filename, positions_3d=output) 106 | 107 | print('Done.') 108 | 109 | else: 110 | print('Please specify the dataset source') 111 | exit(0) 112 | 113 | # Create 2D pose file 114 | print('') 115 | print('Computing ground-truth 2D poses...') 116 | dataset = Human36mDataset(output_filename + '.npz') 117 | output_2d_poses = {} 118 | for subject in dataset.subjects(): 119 | output_2d_poses[subject] = {} 120 | for action in dataset[subject].keys(): 121 | anim = dataset[subject][action] 122 | 123 | positions_2d = [] 124 | for cam in anim['cameras']: 125 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 126 | pos_2d = wrap(project_to_2d, pos_3d, cam['intrinsic'], unsqueeze=True) 127 | pos_2d_pixel_space = image_coordinates(pos_2d, w=cam['res_w'], h=cam['res_h']) 128 | positions_2d.append(pos_2d_pixel_space.astype('float32')) 129 | output_2d_poses[subject][action] = positions_2d 130 | 131 | print('Saving...') 132 | metadata = { 133 | 'num_joints': dataset.skeleton().num_joints(), 134 | 'keypoints_symmetry': [dataset.skeleton().joints_left(), dataset.skeleton().joints_right()] 135 | } 136 | np.savez_compressed(output_filename_2d, positions_2d=output_2d_poses, metadata=metadata) 137 | 138 | print('Done.') 139 | -------------------------------------------------------------------------------- /data/ConvertHumanEva.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2018-present, Facebook, Inc. 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | function [] = ConvertDataset() 9 | 10 | N_JOINTS = 15; % Set to 20 if you want to export a 20-joint skeleton 11 | 12 | function [pose_out] = ExtractPose15(pose, dimensions) 13 | % We use the same 15-joint skeleton as in the evaluation 14 | % script "@body_pose/error.m". Proximal and Distal joints 15 | % are averaged. 
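% For example, the thorax below is the midpoint of torsoProximal and headProximal,
% and each elbow/knee is the midpoint of the adjacent Distal/Proximal markers, so
% every output joint is either a single marker or the mean of two neighbouring ones.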
16 | pose_out = NaN(15, dimensions); 17 | pose_out(1, :) = pose.torsoDistal; % Pelvis (root) 18 | pose_out(2, :) = (pose.torsoProximal + pose.headProximal) / 2; % Thorax 19 | pose_out(3, :) = pose.upperLArmProximal; % Left shoulder 20 | pose_out(4, :) = (pose.upperLArmDistal + pose.lowerLArmProximal) / 2; % Left elbow 21 | pose_out(5, :) = pose.lowerLArmDistal; % Left wrist 22 | pose_out(6, :) = pose.upperRArmProximal; % Right shoulder 23 | pose_out(7, :) = (pose.upperRArmDistal + pose.lowerRArmProximal) / 2; % Right elbow 24 | pose_out(8, :) = pose.lowerRArmDistal; % Right wrist 25 | pose_out(9, :) = pose.upperLLegProximal; % Left hip 26 | pose_out(10, :) = (pose.upperLLegDistal + pose.lowerLLegProximal) / 2; % Left knee 27 | pose_out(11, :) = pose.lowerLLegDistal; % Left ankle 28 | pose_out(12, :) = pose.upperRLegProximal; % Right hip 29 | pose_out(13, :) = (pose.upperRLegDistal + pose.lowerRLegProximal) / 2; % Right knee 30 | pose_out(14, :) = pose.lowerRLegDistal; % Right ankle 31 | pose_out(15, :) = pose.headDistal; % Head 32 | end 33 | 34 | function [pose_out] = ExtractPose20(pose, dimensions) 35 | pose_out = NaN(20, dimensions); 36 | pose_out(1, :) = pose.torsoDistal; % Pelvis (root) 37 | pose_out(2, :) = pose.torsoProximal; 38 | pose_out(3, :) = pose.headProximal; 39 | pose_out(4, :) = pose.upperLArmProximal; % Left shoulder 40 | pose_out(5, :) = pose.upperLArmDistal; 41 | pose_out(6, :) = pose.lowerLArmProximal; 42 | pose_out(7, :) = pose.lowerLArmDistal; % Left wrist 43 | pose_out(8, :) = pose.upperRArmProximal; % Right shoulder 44 | pose_out(9, :) = pose.upperRArmDistal; 45 | pose_out(10, :) = pose.lowerRArmProximal; 46 | pose_out(11, :) = pose.lowerRArmDistal; % Right wrist 47 | pose_out(12, :) = pose.upperLLegProximal; % Left hip 48 | pose_out(13, :) = pose.upperLLegDistal; 49 | pose_out(14, :) = pose.lowerLLegProximal; 50 | pose_out(15, :) = pose.lowerLLegDistal; % Left ankle 51 | pose_out(16, :) = pose.upperRLegProximal; % Right hip 52 | pose_out(17, :) = pose.upperRLegDistal; 53 | pose_out(18, :) = pose.lowerRLegProximal; 54 | pose_out(19, :) = pose.lowerRLegDistal; % Right ankle 55 | pose_out(20, :) = pose.headDistal; % Head 56 | end 57 | 58 | addpath('./TOOLBOX_calib/'); 59 | addpath('./TOOLBOX_common/'); 60 | addpath('./TOOLBOX_dxAvi/'); 61 | addpath('./TOOLBOX_readc3d/'); 62 | 63 | % Create the output directory for the converted dataset 64 | OUT_DIR = ['./converted_', int2str(N_JOINTS), 'j']; 65 | warning('off', 'MATLAB:MKDIR:DirectoryExists'); 66 | mkdir(OUT_DIR); 67 | 68 | % We use the validation set as the test set 69 | for SPLIT = {'Train', 'Validate'} 70 | mkdir([OUT_DIR, '/', SPLIT{1}]); 71 | CurrentDataset = he_dataset('HumanEvaI', SPLIT{1}); 72 | 73 | for SEQ = 1:length(CurrentDataset) 74 | 75 | Subject = char(get(CurrentDataset(SEQ), 'SubjectName')); 76 | Action = char(get(CurrentDataset(SEQ), 'ActionType')); 77 | Trial = char(get(CurrentDataset(SEQ), 'Trial')); 78 | DatasetBasePath = char(get(CurrentDataset(SEQ), 'DatasetBasePath')); 79 | if Trial ~= '1' 80 | % We are only interested in fully-annotated data 81 | continue; 82 | end 83 | 84 | if strcmp(Action, 'ThrowCatch') && strcmp(Subject, 'S3') 85 | % Damaged mocap stream 86 | continue; 87 | end 88 | 89 | fprintf('Converting...\n') 90 | fprintf('\tSplit: %s\n', SPLIT{1}); 91 | fprintf('\tSubject: %s\n', Subject); 92 | fprintf('\tAction: %s\n', Action); 93 | fprintf('\tTrial: %s\n', Trial); 94 | 95 | % Create subject directory if it does not exist 96 | mkdir([OUT_DIR, '/', SPLIT{1}, '/', Subject]); 97 | 98 | 
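% The per-sequence code below loads the synchronized mocap stream, walks every
% annotated frame, stores the 15- or 20-joint 3D pose, and projects it into each
% of the three cameras (C1-C3) with the per-subject calibration files to obtain
% the ground-truth 2D poses.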
% Load the sequence 99 | [~, ~, MocapStream, MocapStream_Enabled] ... 100 | = sync_stream(CurrentDataset(SEQ)); 101 | 102 | % Set frame range 103 | FrameStart = get(CurrentDataset(SEQ), 'FrameStart'); 104 | FrameStart = [FrameStart{:}]; 105 | FrameEnd = get(CurrentDataset(SEQ), 'FrameEnd'); 106 | FrameEnd = [FrameEnd{:}]; 107 | 108 | fprintf('\tNum. frames: %d\n', FrameEnd - FrameStart + 1); 109 | poses_3d = NaN(FrameEnd - FrameStart + 1, N_JOINTS, 3); 110 | poses_2d = NaN(3, FrameEnd - FrameStart + 1, N_JOINTS, 2); 111 | corrupt = 0; 112 | for FRAME = FrameStart:FrameEnd 113 | 114 | if (MocapStream_Enabled) 115 | [MocapStream, pose, ValidPose] = cur_frame(MocapStream, FRAME, 'body_pose'); 116 | 117 | if (ValidPose) 118 | i = FRAME - FrameStart + 1; 119 | 120 | % Extract 3D pose 121 | if N_JOINTS == 15 122 | poses_3d(i, :, :) = ExtractPose15(pose, 3); 123 | else 124 | poses_3d(i, :, :) = ExtractPose20(pose, 3); 125 | end 126 | 127 | % Extract ground-truth 2D pose via camera 128 | % projection 129 | for CAM = 1:3 130 | if (CAM == 1) 131 | CameraName = 'C1'; 132 | elseif (CAM == 2) 133 | CameraName = 'C2'; 134 | elseif (CAM == 3) 135 | CameraName = 'C3'; 136 | end 137 | CalibrationFilename = [DatasetBasePath, Subject, '/Calibration_Data/', CameraName, '.cal']; 138 | pose_2d = project2d(pose, CalibrationFilename); 139 | if N_JOINTS == 15 140 | poses_2d(CAM, i, :, :) = ExtractPose15(pose_2d, 2); 141 | else 142 | poses_2d(CAM, i, :, :) = ExtractPose20(pose_2d, 2); 143 | end 144 | end 145 | 146 | else 147 | corrupt = corrupt + 1; 148 | end 149 | end 150 | end 151 | fprintf('\n%d out of %d frames are damaged\n', corrupt, FrameEnd - FrameStart + 1); 152 | FileName = [OUT_DIR, '/', SPLIT{1}, '/', Subject, '/', Action, '_', Trial, '.mat']; 153 | save(FileName, 'poses_3d', 'poses_2d'); 154 | fprintf('... saved to %s\n\n', FileName); 155 | end 156 | end 157 | end 158 | -------------------------------------------------------------------------------- /common/common_pytorch/experiment/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from common.common_pytorch.loss.loss_family import p_mpjpe 5 | from common.visualization.plot_pose3d import plot17j 6 | 7 | def gather_3d_metrics(expected, actual): 8 | """ 9 | 10 | :param expected: Predicted pose 11 | :param actual: Ground Truth 12 | :return: evaluation results 13 | """ 14 | unaligned_pck = pck(actual, expected) 15 | unaligned_auc = auc(actual, expected) 16 | expect_np = expected.cpu().numpy().reshape(-1, expected.shape[-2], expected.shape[-1]) 17 | actual_np = actual.cpu().numpy().reshape(-1, expected.shape[-2], expected.shape[-1]) 18 | aligned_mpjpe, aligned = p_mpjpe(expect_np, actual_np) 19 | #plot17j(np.concatenate((actual_np[100:104],expect_np[100:104],aligned[0,100:104].cpu().numpy()),axis=0),'aa','aa') 20 | aligned_pck = pck(aligned, actual) 21 | aligned_auc = auc(aligned, actual) 22 | return dict( 23 | pck=unaligned_pck, 24 | auc=unaligned_auc, 25 | aligned_mpjpe=aligned_mpjpe, 26 | aligned_pck=aligned_pck, 27 | aligned_auc=aligned_auc, 28 | ) 29 | 30 | def pck(actual, expected,threshold=150): 31 | dists = torch.norm((actual - expected), dim=len(actual.shape)-1) 32 | error = (dists < threshold).double().mean().item() 33 | return error 34 | 35 | def auc(actual, expected): 36 | # This range of thresholds mimics `mpii_compute_3d_pck.m`, which is provided as part of the 37 | # MPI-INF-3DHP test data release. 
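# In other words, the AUC reported here is PCK averaged over 31 evenly spaced
# thresholds between 0 and 150 mm, roughly:
#   auc ~ mean(pck(actual, expected, threshold=t) for t in torch.linspace(0, 150, 31))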
38 | thresholds = torch.linspace(0, 150, 31).tolist() 39 | pck_values = torch.DoubleTensor(len(thresholds)) 40 | for i, threshold in enumerate(thresholds): 41 | pck_values[i] = pck(actual, expected, threshold=threshold) 42 | return pck_values.mean().item() 43 | 44 | def kpt_to_bone_vector(pose_3d, parent_index=None): 45 | if parent_index is not None: 46 | hm36_parent = parent_index 47 | 48 | else: 49 | hm36_parent = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15] #by body kinematic connections 50 | #print('random parent index:',hm36_parent) 51 | bone = [] 52 | for i in range(1, len(hm36_parent)): 53 | bone_3d = pose_3d[:, :, i] - pose_3d[:,:,hm36_parent[i]] 54 | bone.append(bone_3d.unsqueeze(dim=-2)) 55 | bone_out = torch.cat(bone, dim=-2) 56 | return bone_out 57 | 58 | def cal_bone_sym(predicted_3d_pos): 59 | # calculate bone length symmetry 60 | left = [4,5,6,11,12,13] 61 | right = [1,2,3,14,15,16] 62 | hm36_parent = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15] 63 | 64 | bone_lengths_lift = [] 65 | bone_lengths_right = [] 66 | each_bone_error = [] 67 | left_right_error = 0 68 | for i in left: 69 | dists_l = predicted_3d_pos[:, :, i, :] - predicted_3d_pos[:, :, hm36_parent[i], :] 70 | bone_lengths_lift.append(torch.mean(torch.norm(dists_l, dim=-1), dim=1)) 71 | for i in right: 72 | dists_r = predicted_3d_pos[:, :, i, :] - predicted_3d_pos[:, :, hm36_parent[i], :] 73 | bone_lengths_right.append(torch.mean(torch.norm(dists_r, dim=-1), dim=1)) 74 | for i in range(len(left)): 75 | left_right_error += torch.abs(bone_lengths_right[i] - bone_lengths_lift[i]) 76 | each_bone_error.append(torch.mean(torch.abs(bone_lengths_right[i] - bone_lengths_lift[i]))) 77 | txt1 = 'Bone symmetric error (right-left): Hip {0}mm, Upper Leg {1}mm, Lower Leg {2}mm, '\ 78 | 'Shoulder {3}mm, Upper elbow {4}mm, Lower elbow {5}mm'.format(each_bone_error[0]*1000,each_bone_error[1]*1000,each_bone_error[2]*1000, 79 | each_bone_error[3]*1000,each_bone_error[4]*1000,each_bone_error[5]*1000) 80 | print(txt1) 81 | left_right_err_mean = torch.mean(left_right_error/6) 82 | print('all parts mean symmetric error: ',left_right_err_mean.item()*1000) 83 | return torch.tensor(each_bone_error) 84 | 85 | def angle_np(v1, v2, acute=False): 86 | # v1 is your firsr vector 87 | # v2 is your second vector 88 | angle = np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1, axis=-1) * np.linalg.norm(v2, axis=-1))) 89 | #angle = angle[:, :, np.newaxis] 90 | if (acute == True): 91 | return angle 92 | else: 93 | return 2 * np.pi - angle 94 | 95 | def angle_torch(v1, v2, torch_pi, acute=False): 96 | # v1 is your firsr 3d vector, v.shape: [B, T, 1, 3] 97 | # v2 is your second 3d vector 98 | v1_len = torch.norm(v1, dim=-1) 99 | v2_len = torch.norm(v2, dim=-1) 100 | angle = torch.mean(torch.acos(torch.sum(v1.mul(v2), dim=-1) / (v1_len * v2_len+torch.Tensor([1e-8]).cuda()))) #shape: [B, T. 
1] 101 | if (acute == True): 102 | return angle 103 | else: 104 | return 2 * torch_pi - angle 105 | 106 | def cal_bone_angle(bone_3d, indexes=None): 107 | torch_pi = torch.acos(torch.zeros(1)).item() * 2 108 | bone_parent_index = [-1, 0, 1, 0, 3, 4, 0, 6, 7, 8, 7, 10, 11, 7, 13, 14] 109 | bone_angle = [] 110 | text = [] 111 | # init_vector = [0,-1,0] #relative to root joint 112 | 113 | if indexes: 114 | # calculate specific pair joint angle 115 | for index in indexes: 116 | bone_child = bone_3d[index] 117 | bone_parent = bone_3d[bone_parent_index[index]] 118 | bone_angle.append(180 * angle_torch(bone_child, bone_parent, torch_pi, acute=True)/torch.pi) 119 | text.append('The bone angle between child bone {} and parent bone {} is :{}'.format(index, bone_parent_index[index], bone_angle[-1])) 120 | print(text) 121 | 122 | else: 123 | # calculate each pair joint angle 124 | for index in range(1, 16): 125 | bone_child = bone_3d[:, :, index] 126 | bone_parent = bone_3d[:, :, bone_parent_index[index]] 127 | #angle1 = 180*angle(bone_child.squeeze().cpu().detach().numpy()[0], bone_parent.squeeze().cpu().detach().numpy()[0], acute=True)/np.pi 128 | joint_angle = 180 * angle_torch(bone_child.squeeze(), bone_parent.squeeze(), torch_pi, acute=True) / torch_pi 129 | bone_angle.append(joint_angle.unsqueeze(dim=-1)) 130 | text.append('The bone angle between child bone {} and parent bone {} is :{}'.format(index, bone_parent_index[index],bone_angle[-1])) 131 | 132 | body_angle = torch.cat(bone_angle, dim=-1) 133 | print(text) 134 | return body_angle 135 | 136 | def cal_bone_length_np(pose_3d): 137 | hm36_parent = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15] 138 | hm36_num = 17 139 | e4 = [] 140 | for i in range(1, hm36_num): 141 | bone_3d = pose_3d[:, :, i] - pose_3d[:,:,hm36_parent[i]] 142 | e4.append(np.mean(np.mean(np.linalg.norm(bone_3d, axis=-1)*100,0),0)) 143 | print('std of each bone length',np.std(np.mean(np.linalg.norm(bone_3d, axis=-1)*100,0),0),'cm') 144 | kpt_txt = 'Bone length of RHip is {0}cm,RUleg is {1}cm, RLleg is {2}cm, Lhip is {3}cm, LUleg is {4}cm, LLleg is {5}cm, ' \ 145 | 'Lspine is {6}cm, Uspine is {7}cm, Neck is {8}cm, Head is {9}cm, Lshoulder is {10}cm, LUelbow is {11}cm, LLelbow is {12}cm, '\ 146 | 'Rshoudle is {13}cm, RUelbow is {14}cm, RLelbow is {15}cm:'.format(e4[0], e4[1], e4[2], e4[3], 147 | e4[4], e4[5], 148 | e4[6], 149 | e4[7], e4[8], 150 | e4[9], e4[10], 151 | e4[11], e4[12], 152 | e4[13], 153 | e4[14], e4[15]) 154 | print(kpt_txt) 155 | return e4 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Split-and-Recombine-Net 2 | This is the original PyTorch implementation of the following work: [SRNet: Improving Generalization in 3D Human Pose Estimation with a Split-and-Recombine Approach](https://arxiv.org/pdf/2007.09389.pdf) in ECCV 2020. 3 | 4 | ![News](https://img.shields.io/badge/-News!-red) Recently, our method has been verified in [3D Human Mesh Recovery](https://openaccess.thecvf.com/content/ICCV2021/papers/Lee_Uncertainty-Aware_Human_Mesh_Recovery_From_Video_by_Learning_Part-Based_3D_ICCV_2021_paper.pdf) as a decoder to obtain both per-frame accuracy and motion smoothness in ICCV 2021! 5 | 6 | Beyond this task, You can make full use of prior knowledge in your task to design the **group** strategies. 
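For illustration, one plausible way to split the 17 Human3.6M joints into five local groups (torso/head plus the four limbs) is sketched below; the grouping actually used by this repo is presumably the one defined in `common/common_pytorch/model/srnet_utils/group_index.py`:
```
# Illustrative 5-group split of the 17 H36M joints (a sketch, not necessarily the
# exact indices used in group_index.py):
groups = [
    [0, 7, 8, 9, 10],   # torso/head: pelvis (root), spine, thorax, neck, head
    [1, 2, 3],          # right leg: hip, knee, ankle
    [4, 5, 6],          # left leg: hip, knee, ankle
    [11, 12, 13],       # left arm: shoulder, elbow, wrist
    [14, 15, 16],       # right arm: shoulder, elbow, wrist
]
```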
Our proposed method (Split-and-Recombine) is an efficient and effective way to replace a fully connected layer with about **1/group of the parameters (group is 5 in this task) and better performance**.
7 | 
8 | ## Features
9 | - [x] Support the single-frame setting (e.g., -arc 1,1,1)
10 | - [x] Support the multi-frame setting (e.g., -arc 3,3,3,3,3 for 243 frames)
11 | - [x] Support four normalization schemes (--norm {base,proj,weak_proj,lcn})
12 | - [x] Support cross-subject, cross-action, and cross-camera settings
13 | - [x] Support [VideoPose3d](https://arxiv.org/abs/1811.11742) and [SimpleBaseline](https://arxiv.org/pdf/1705.03098.pdf) as our baselines.
14 | 
15 | 
16 | ## Introduction
17 | Monocular 3D human pose estimation takes 2D poses as input and lifts them into relative 3D poses. By default, the root joint (index=0) is taken as the zero position in the camera coordinate system.\
18 | Human poses that are rare or unseen in a training set are challenging for a network to predict. Similar to the long-tailed distribution problem in visual recognition, the small number of examples for such poses limits the ability of networks to model them. Interestingly, local pose distributions suffer less from the long-tail problem, i.e., local
19 | joint configurations within a rare pose may appear within other poses in the training set, making them less rare.
20 | ![observation](img/observation.png)
21 | 
22 | We propose to take advantage of this fact for better generalization to rare and unseen poses. To be specific, our method splits the body into local regions and processes them in
23 | separate network branches, utilizing the property that a joint's position depends mainly on the joints within its local body region. Global coherence is maintained by recombining the global context from the rest of the body into each branch as a low-dimensional vector. With the reduced dimensionality of less relevant body areas, the training set distribution within network branches more closely reflects the statistics of local poses instead of global body poses, without sacrificing information important for joint inference. The proposed split-and-recombine approach, called SRNet, can be easily adapted to **both single-image and temporal models**, and it leads to appreciable improvements in the prediction of rare and unseen poses.
24 | ![framework](img/framework.png)
25 | 
26 | A comparison of the different network structures used for 2D-to-3D pose estimation:
27 | ![comparison](img/comparison.png)
28 | 
29 | 
30 | ## Get started
31 | 
32 | To get started as quickly as possible, follow the instructions in this section. It allows you to train a model from scratch, test our pretrained models, and produce basic visualizations. For more detailed instructions, please refer to [DOCUMENTATION.md](https://github.com/facebookresearch/VideoPose3D/blob/master/DOCUMENTATION.md).
33 | 
34 | ### Dependencies
35 | Make sure you have the following dependencies installed before proceeding:
36 | 
37 | * Python 3+ distribution
38 | * PyTorch >= 0.4.0
39 | * matplotlib 3.1.1 (`pip install matplotlib==3.1.1`)
40 | 
41 | ### Directory
42 | 
43 | First, create two new directories to store models:
44 | 
45 | ```
46 | mkdir checkpoint
47 | mkdir best_checkpoint
48 | ```
49 | 
50 | The ${ROOT} directory is organized as below.
51 | 
52 | ```${ROOT}
53 | |-- data/
54 | |-- checkpoint/
55 | |-- best_checkpoint/
56 | |-- common/
57 | |-- config/
58 | |-- run.py
59 | ```
60 | 
61 | ### Dataset preparation
62 | 
63 | Please follow the instructions from [VideoPose3D](https://github.com/facebookresearch/VideoPose3D/blob/main/DATASETS.md) to process the data from the official [Human3.6M](http://vision.imar.ro/human3.6m/description.php) website.
64 | Alternatively, you can download the processed skeleton-based Human3.6M dataset from this [link](https://drive.google.com/drive/folders/17kXk6rK84-wdDTc1LLemlHZIvFQy4oKj?usp=sharing). Put the data into the directory data/.
65 | ```
66 | mkdir data
67 | cd data
68 | ```
69 | 
70 | The data directory structure is shown as follows.
71 | ```
72 | ./
73 | └── data/
74 | ├── data_2d_h36m_gt.npz
75 | ├── data_3d_h36m.npz
76 | ```
77 | 
78 | `data_2d_h36m_gt.npz` contains the 2D ground-truth poses of the Human3.6M dataset.
79 | 
80 | `data_3d_h36m.npz` contains the 3D ground-truth poses of the Human3.6M dataset.
81 | 
82 | ### Pretrained Models
83 | We provide single-frame and multi-frame models in this [link](https://drive.google.com/drive/folders/11HhwYZpYhMMK7IdlyFekDqZB9ii9BA53?usp=sharing) for inference and finetuning:
84 | 
85 | Pretrained models should be put in the `checkpoint/` directory, \
86 | e.g.,
87 | `latest_epoch_fc.bin` is the last checkpoint of the [fully connected network](https://arxiv.org/pdf/1705.03098.pdf)\
88 | `latest_epoch_sr.bin` is the last checkpoint (1 frame) of the [split-and-recombine network](https://arxiv.org/pdf/2007.09389.pdf)\
89 | `srnet_gp5_t243_mul.bin` is the last checkpoint (243 frames) of the [split-and-recombine network](https://arxiv.org/pdf/2007.09389.pdf)
90 | 
91 | ### Inference
92 | Use **--evaluate {model_name}** to test a model. \
93 | Use **--resume {model_name}** to resume from a checkpoint and finetune the model.
94 | 
95 | For example:
96 | ```
97 | python run.py -arc 3,3,3,3,3 --model srnet --evaluate srnet_gp5_t243_mul.bin
98 | python run.py -arc 1,1,1 --model srnet --evaluate latest_epoch_sr.bin
99 | python run.py -arc 1,1,1 --model fc --evaluate latest_epoch_fc.bin
100 | ```
101 | 
102 | ### Training
103 | 
104 | There are three training and test settings; the most commonly used is **cross-subject** (the default). We train on **five subjects** with all four cameras and all fifteen actions, and test on the other **two subjects** with all cameras and all actions. \
105 | To train the [split-and-recombine network](https://arxiv.org/pdf/2007.09389.pdf) with 243 frames as input and 1 frame as output from scratch, run:
106 | 
107 | ```
108 | python run.py -arc 3,3,3,3,3 --model srnet -mn {given_model_name}
109 | ```
110 | `-mn` sets the model name under which the checkpoint is saved.
111 | 
112 | To train [VideoPose3d](https://arxiv.org/abs/1811.11742) with 243 frames as input and 1 frame as output from scratch, run:
113 | ```
114 | python run.py -arc 3,3,3,3,3 --model fc -mn {given_model_name}
115 | ```
116 | 
117 | To use the **cross-action** setting, we train on only **one action** with all subjects and all cameras, and test on the other **fourteen actions** with all subjects and all cameras.
118 | 
119 | You can add the corresponding arguments to the command, e.g.: `--use-action-split True --train-action Discussion`
120 | 
121 | To use the **cross-camera** setting, we train on only **one camera** with all subjects and all actions, and test on the other **three cameras** with all subjects and all actions.
122 | 123 | You can add the arguments in the command like: `--cam-train [0] --cam-test [1,2,3]` 124 | 125 | For convenience of different hyper-parameter settings, you can edit the scipt **run_os.py** to run experiments for once. 126 | #### We also put some configuration examples in the dictory config/. To facilitate reproduction, we provide the training logs for single-frame and multi-frame settings [here](https://drive.google.com/drive/folders/1Z1xZt9n749cW89eKPcR4WJ8RgRS3ar-8?usp=sharing). You can check **the hyperparameters, training loss and test results for each epoch** in these logs as well. 127 | 128 | If you find this repository useful for your work, please consider citing it as follows: 129 | 130 | ``` 131 | @inproceedings{Zeng2020SRNet, 132 | title={SRNet: Improving Generalization in 3D Human Pose Estimation with a Split-and-Recombine Approach}, 133 | author={Ailing Zeng and Xiao Sun and Fuyang Huang and Minhao Liu and Qiang Xu and Stephen Ching-Feng Lin}, 134 | booktitle={ECCV}, 135 | year={2020} 136 | } 137 | ``` 138 | -------------------------------------------------------------------------------- /common/dataset/pre_process/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Copyright (c) 2018-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import numpy as np 11 | #import h5py 12 | 13 | mpii_metadata = { 14 | 'layout_name': 'mpii', 15 | 'num_joints': 16, 16 | 'keypoints_symmetry': [ 17 | [3, 4, 5, 13, 14, 15], 18 | [0, 1, 2, 10, 11, 12], 19 | ] 20 | } 21 | 22 | coco_metadata = { 23 | 'layout_name': 'coco', 24 | 'num_joints': 17, 25 | 'keypoints_symmetry': [ 26 | [1, 3, 5, 7, 9, 11, 13, 15], 27 | [2, 4, 6, 8, 10, 12, 14, 16], 28 | ] 29 | } 30 | 31 | h36m_metadata = { 32 | 'layout_name': 'h36m', 33 | 'num_joints': 17, 34 | 'keypoints_symmetry': [ 35 | [4, 5, 6, 11, 12, 13], 36 | [1, 2, 3, 14, 15, 16], 37 | ] 38 | } 39 | 40 | humaneva15_metadata = { 41 | 'layout_name': 'humaneva15', 42 | 'num_joints': 15, 43 | 'keypoints_symmetry': [ 44 | [2, 3, 4, 8, 9, 10], 45 | [5, 6, 7, 11, 12, 13] 46 | ] 47 | } 48 | 49 | humaneva20_metadata = { 50 | 'layout_name': 'humaneva20', 51 | 'num_joints': 20, 52 | 'keypoints_symmetry': [ 53 | [3, 4, 5, 6, 11, 12, 13, 14], 54 | [7, 8, 9, 10, 15, 16, 17, 18] 55 | ] 56 | } 57 | 58 | 59 | def suggest_metadata(name): 60 | names = [] 61 | for metadata in [mpii_metadata, coco_metadata, h36m_metadata, humaneva15_metadata, humaneva20_metadata]: 62 | if metadata['layout_name'] in name: 63 | return metadata 64 | names.append(metadata['layout_name']) 65 | raise KeyError('Cannot infer keypoint layout from name "{}". 
Tried {}.'.format(name, names)) 66 | 67 | 68 | def import_detectron_poses(path): 69 | # Latin1 encoding because Detectron runs on Python 2.7 70 | data = np.load(path, encoding='latin1') 71 | kp = data['keypoints'] 72 | bb = data['boxes'] 73 | results = [] 74 | for i in range(len(bb)): 75 | if len(bb[i][1]) == 0: 76 | assert i > 0 77 | # Use last pose in case of detection failure 78 | results.append(results[-1]) 79 | continue 80 | best_match = np.argmax(bb[i][1][:, 4]) 81 | keypoints = kp[i][1][best_match].T.copy() 82 | results.append(keypoints) 83 | results = np.array(results) 84 | return results[:, :, 4:6] # Soft-argmax 85 | # return results[:, :, [0, 1, 3]] # Argmax + score 86 | 87 | 88 | def import_cpn_poses(path): 89 | data = np.load(path) 90 | kp = data['keypoints'] 91 | return kp[:, :, :2] 92 | 93 | 94 | def import_sh_poses(path): 95 | with h5py.File(path) as hf: 96 | positions = hf['poses'].value 97 | return positions.astype('float32') 98 | 99 | 100 | def suggest_pose_importer(name): 101 | if 'detectron' in name: 102 | return import_detectron_poses 103 | if 'cpn' in name: 104 | return import_cpn_poses 105 | if 'sh' in name: 106 | return import_sh_poses 107 | raise KeyError('Cannot infer keypoint format from name "{}". Tried detectron, cpn, sh.'.format(name)) 108 | 109 | 110 | def fetch(subjects, keypoints, dataset, downsample, action_filter=None, cam_filter=None, subset=1, parse_3d_poses=True): 111 | out_poses_3d = [] 112 | out_poses_2d = [] 113 | out_camera_params = [] 114 | for subject in subjects: 115 | for action in keypoints[subject].keys(): 116 | action_split = action.split(' ')[0] 117 | if action_filter is not None: 118 | found = False 119 | # distinguish the actions:'Sitting' and 'SittingDown' 120 | for act in action_filter: 121 | act = act.split(' ')[0] 122 | if action_split == act: 123 | found = True 124 | break 125 | if not found: 126 | continue 127 | poses_2d = keypoints[subject][action] 128 | index = np.random.randint(0,4) 129 | if cam_filter==[5]: #random camera index 130 | out_poses_2d.append(poses_2d[index]) 131 | print('choose a camera index for each action:', index) 132 | 133 | elif cam_filter: 134 | for j in cam_filter: # Select by some camera viewpoints 135 | out_poses_2d.append(poses_2d[j]) 136 | 137 | else: 138 | for i in range(len(poses_2d)): # Iterate across cameras 139 | out_poses_2d.append(poses_2d[i]) 140 | 141 | if subject in dataset.cameras(): 142 | cams = dataset.cameras()[subject] 143 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 144 | if cam_filter==[5]: 145 | cam = cams[index] 146 | if 'intrinsic' in cam: 147 | use_params = {} 148 | use_params['intrinsic'] = cam['intrinsic'] 149 | if 'normalization_params' in dataset[subject][action]: 150 | use_params['normalization_params'] = \ 151 | dataset[subject][action]['normalization_params'][index] 152 | out_camera_params.append(use_params) 153 | elif cam_filter: 154 | for j in cam_filter: 155 | for i, cam in enumerate(cams): 156 | if j == i: 157 | if 'intrinsic' in cam: 158 | use_params = {} 159 | use_params['intrinsic'] = cam['intrinsic'] 160 | if 'normalization_params' in dataset[subject][action]: 161 | use_params['normalization_params'] = \ 162 | dataset[subject][action]['normalization_params'][i] 163 | out_camera_params.append(use_params) 164 | else: 165 | for i, cam in enumerate(cams): 166 | if 'intrinsic' in cam: 167 | use_params = {} 168 | use_params['intrinsic'] = cam['intrinsic'] 169 | if 'normalization_params' in dataset[subject][action]: 170 | use_params['normalization_params'] = 
dataset[subject][action]['normalization_params'][i] 171 | out_camera_params.append(use_params) 172 | 173 | if parse_3d_poses and 'positions_3d' in dataset[subject][action]: 174 | poses_3d = dataset[subject][action]['positions_3d'] 175 | assert len(poses_3d) == len(poses_2d), 'Camera count mismatch' 176 | if cam_filter==[5]: 177 | out_poses_3d.append(poses_3d[index]) 178 | elif cam_filter : 179 | for j in cam_filter: 180 | out_poses_3d.append(poses_3d[j]) 181 | else: 182 | for i in range(len(poses_3d)): # Iterate across cameras 183 | out_poses_3d.append(poses_3d[i]) 184 | 185 | if len(out_camera_params) == 0: 186 | out_camera_params = None 187 | if len(out_poses_3d) == 0: 188 | out_poses_3d = None 189 | 190 | stride = downsample 191 | if subset < 1: 192 | for i in range(len(out_poses_2d)): 193 | n_frames = int(round(len(out_poses_2d[i]) // stride * subset) * stride) 194 | start = deterministic_random(0, len(out_poses_2d[i]) - n_frames + 1, str(len(out_poses_2d[i]))) 195 | out_poses_2d[i] = out_poses_2d[i][start:start + n_frames:stride] 196 | if out_poses_3d is not None: 197 | out_poses_3d[i] = out_poses_3d[i][start:start + n_frames:stride] 198 | elif stride > 1: 199 | # Downsample as requested 200 | for i in range(len(out_poses_2d)): 201 | out_poses_2d[i] = out_poses_2d[i][::stride] 202 | if out_poses_3d is not None: 203 | out_poses_3d[i] = out_poses_3d[i][::stride] 204 | 205 | return out_camera_params, out_poses_3d, out_poses_2d 206 | -------------------------------------------------------------------------------- /common/dataset/pre_process/hm36.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from common.transformation.cam_utils import * 3 | from common.dataset.pre_process.norm_data import norm_to_pixel 4 | from common.dataset.h36m_dataset import Human36mDataset 5 | 6 | def load_data(dataset_root, dataset_name, kpt_name): 7 | dataset_path = dataset_root + 'data_3d_' + dataset_name + '.npz' 8 | if dataset_name == 'h36m': 9 | dataset = Human36mDataset(dataset_path) 10 | else: 11 | raise KeyError('Invalid dataset') 12 | return dataset 13 | 14 | 15 | # dataset is modified. 
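# prepare_dataset converts the world-space 'positions' of every action into each
# camera's coordinate frame and stores the result as anim['positions_3d'] (one
# array per camera), which is what the normalization and fetch code expects.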
16 | def prepare_dataset(dataset): 17 | for subject in dataset.subjects(): 18 | for action in dataset[subject].keys(): 19 | anim = dataset[subject][action] 20 | if 'positions' in anim: 21 | positions_3d = [] 22 | for cam in anim['cameras']: 23 | pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation']) 24 | positions_3d.append(pos_3d) 25 | anim['positions_3d'] = positions_3d 26 | 27 | 28 | def load_2d_data(dataset_root, dataset_name, kpt_name): 29 | keypoints = np.load(dataset_root + 'data_2d_' + dataset_name + '_' + kpt_name + '.npz', allow_pickle=True) 30 | keypoints_metadata = keypoints['metadata'].item() 31 | keypoints_symmetry = keypoints_metadata['keypoints_symmetry'] 32 | kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1]) 33 | keypoints = keypoints['positions_2d'].item() 34 | use_smooth_2d = False 35 | if use_smooth_2d: 36 | print('use smooth 2d pose:') 37 | smooth_2d = np.load('common/dataset/pre_process/smooth_cpn_ft_81_all.npz', allow_pickle=True) 38 | keypoints = smooth_2d['positions_2d'].item() 39 | 40 | return keypoints, keypoints_metadata, kps_left, kps_right 41 | 42 | def load_hard_test(file_path, eval_num): 43 | # Used for load hard test set in our evaluation 44 | hard_pose = np.load(file_path, allow_pickle=True) 45 | pose_3d = hard_pose['pose_3d'] # Have normalized, type:list 46 | pose_2d = hard_pose['pose_2d'] # Have normalized; type:list 47 | if len(pose_3d)==1: 48 | num = pose_3d[0].shape[0] 49 | size = num // eval_num 50 | t_3d = [] 51 | t_2d = [] 52 | t = 0 53 | for j in range(size): 54 | t_3d.append(pose_3d[0][t:t + eval_num]) 55 | t_2d.append(pose_2d[0][t:t + eval_num]) 56 | t += eval_num 57 | t_3d.append(pose_3d[0][t:]) 58 | t_2d.append(pose_2d[0][t:]) 59 | return t_3d, t_2d 60 | else: 61 | return pose_3d, pose_2d 62 | 63 | # keypoints are midified. dataset remains. 
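# prepare_2d_data only performs consistency checks between the 2D detections and
# the 3D ground truth: every subject/action must exist in both, and 2D sequences
# with a few extra frames are truncated in place to the mocap length.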
64 | def prepare_2d_data(keypoints, dataset): 65 | for subject in dataset.subjects(): 66 | assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject) 67 | for action in dataset[subject].keys(): 68 | assert action in keypoints[ 69 | subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format( 70 | action, subject) 71 | if 'positions_3d' not in dataset[subject][action]: 72 | continue 73 | 74 | for cam_idx in range(len(keypoints[subject][action])): 75 | 76 | # We check for >= instead of == because some videos in H3.6M contain extra frames 77 | mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0] 78 | assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length 79 | 80 | if keypoints[subject][action][cam_idx].shape[0] > mocap_length: 81 | # Shorten sequence 82 | keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length] 83 | 84 | assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d']) 85 | 86 | def random_rotate(dataset, keypoints, subjects, action_filter=None, cam_filter=None): 87 | print('Random rotate 3d pose around Y axis, output rotated 2d and 3d poses ') 88 | for subject in subjects: 89 | for action in dataset[subject].keys(): 90 | action_split = action.split(' ')[0] 91 | if action_filter is not None: 92 | found = False 93 | # distinguish the actions:'Sitting' and 'SittingDown' 94 | for act in action_filter: 95 | act = act.split(' ')[0] 96 | if action_split == act: 97 | found = True 98 | break 99 | if not found: 100 | continue 101 | cams = dataset.cameras()[subject] 102 | poses_3d = dataset[subject][action]['positions_3d'] 103 | poses_2d = keypoints[subject][action] 104 | assert len(poses_3d) == len(cams), 'Camera count mismatch' 105 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 106 | 107 | if cam_filter: 108 | for i in cam_filter: # Select by some camera viewpoints 109 | dataset[subject][action]['positions_3d'][i] = poses_3d[i] 110 | w, h = np.repeat(np.array(cams[i]['res_w'])[np.newaxis,np.newaxis,np.newaxis], poses_2d[i].shape[0], axis=0), \ 111 | np.repeat(np.array(cams[i]['res_h'])[np.newaxis,np.newaxis,np.newaxis], poses_2d[i].shape[0], axis=0) 112 | wh = np.concatenate((w,h),axis=-1) 113 | keypoints[subject][action][i] = np.concatenate((poses_2d[i],wh),axis=1) 114 | else: 115 | for i in range(len(poses_3d)): 116 | dataset[subject][action]['positions_3d'][i] = poses_3d[i] 117 | w, h = np.repeat(np.array(cams[i]['res_w'])[np.newaxis,np.newaxis,np.newaxis], poses_2d[i].shape[0], axis=0), \ 118 | np.repeat(np.array(cams[i]['res_h'])[np.newaxis,np.newaxis,np.newaxis], poses_2d[i].shape[0], axis=0) 119 | wh = np.concatenate((w,h),axis=-1) 120 | keypoints[subject][action][i] = np.concatenate((poses_2d[i],wh),axis=1) 121 | 122 | def normalization(dataset, keypoints, subjects, action_filter, cam_filter, norm): 123 | print('Start to normalize input 2d and 3d pose: ') 124 | for subject in subjects: 125 | for action in keypoints[subject].keys(): 126 | action_split = action.split(' ')[0] 127 | if action_filter is not None: 128 | found = False 129 | # distinguish the actions:'Sitting' and 'SittingDown' 130 | for act in action_filter: 131 | act = act.split(' ')[0] 132 | if action_split == act: 133 | found = True 134 | break 135 | if not found: 136 | continue 137 | 138 | poses_2d = keypoints[subject][action] 139 | cams = dataset.cameras()[subject] 140 | poses_3d = dataset[subject][action]['positions_3d'] 141 | assert 
len(poses_3d) == len(poses_2d), 'Camera count mismatch' 142 | assert len(cams) == len(poses_2d), 'Camera count mismatch' 143 | norm_params = [] 144 | if cam_filter: 145 | for i in cam_filter: # Select by some camera viewpoints 146 | if norm == 'base': 147 | # Remove global offset, but keep trajectory in first position 148 | poses_3d[i][:, 1:] -= poses_3d[i][:, :1] 149 | normed_pose_3d = poses_3d[i] 150 | normed_pose_2d = normalize_screen_coordinates(poses_2d[i][..., :2], w=cams[i]['res_w'], h=cams[i]['res_h']) 151 | 152 | else: 153 | normed_pose_3d, normed_pose_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel( 154 | poses_3d[i], poses_2d[i], cams[i]['intrinsic'], norm) 155 | norm_params.append(np.concatenate((pixel_ratio, rescale_ratio, offset_2d, abs_root_Z), axis=-1)) # [T, 1, 5], len()==4 156 | keypoints[subject][action][i] = normed_pose_2d 157 | dataset[subject][action]['positions_3d'][i] = normed_pose_3d 158 | if norm_params: 159 | dataset[subject][action]['normalization_params'] = norm_params 160 | else: 161 | for i in range(len(poses_2d)): 162 | if norm == 'base': 163 | # Remove global offset, but keep trajectory in first position 164 | poses_3d[i][:, 1:] -= poses_3d[i][:, :1] 165 | normed_pose_3d = poses_3d[i] 166 | normed_pose_2d = normalize_screen_coordinates(poses_2d[i][..., :2], w=cams[i]['res_w'], h=cams[i]['res_h']) 167 | 168 | else: 169 | normed_pose_3d, normed_pose_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel(poses_3d[i], poses_2d[i], cams[i]['intrinsic'], norm) 170 | norm_params.append(np.concatenate((pixel_ratio, rescale_ratio, offset_2d, abs_root_Z), axis=-1)) # [T, 1, 5], len()==4 171 | keypoints[subject][action][i] = normed_pose_2d 172 | dataset[subject][action]['positions_3d'][i] = normed_pose_3d 173 | if norm_params: 174 | dataset[subject][action]['normalization_params'] = norm_params 175 | 176 | -------------------------------------------------------------------------------- /common/visualization/plot_log_epoch.py: -------------------------------------------------------------------------------- 1 | # This code is used for plot the training/test errors with each epoch for each model. 
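# It scans the opened .log files for lines containing '3d_valid', parses the error
# values out of fixed word positions in those lines, and plots one curve per model
# against the training epoch.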
2 | # I have made these four picture in one picture 3 | 4 | import matplotlib 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | 10 | # input_1 = open('log/1031_cpn_l1.log','r') 11 | # input_2 = open('log/1031_cpn_lr_l1.log','r') 12 | # input_3 = open('log/1031_cpn_lr_l2.log','r') 13 | # input_4 = open('log/1031_cpn_lr_l1_243.log','r') 14 | # input_5 = open('log/1031_hg_l2.log','r') 15 | # input_6 = open('log/1031_hg_l1.log','r') 16 | # input_7 = open('log/1031_hg_lr_l1.log','r') 17 | # input_8 = open('log/1031_hg_lr_l2.log','r') 18 | # input_9 = open('log/1101_cpn_l2_243.log','r') 19 | # input_10 = open('log/1101_cpn_l1_lr002_243.log','r') 20 | 21 | input_1 = open('log/r_100_gp5_lr1e-3.log','r') 22 | input_2 = open('log/r_15_gp5.log','r') 23 | # input_3 = open('log/test_3dpw_gp3_t243.log','r') 24 | # input_4 = open('log/test_3dpw_fc2.log','r') 25 | #input_4 = open('log/test_3dpw_t3_v3.log','r') 26 | 27 | # input_5 = open('log/1019_ori_33333_gt_gp2.log','r') 28 | # input_6 = open('log/1101_mask_l2_ori_243.log','r') 29 | # input_7 = open('log/1101_hg_l2_ori_243.log','r') 30 | # input_8 = open('log/1101_mask_l1_lr002_243.log','r') 31 | # input_9 = open('log/1101_cpn_l2_243.log','r') 32 | # input_10 = open('log/1101_cpn_l1_lr002_243.log','r') 33 | 34 | Training_result_1 = [] 35 | Test_result_1 = [] 36 | Training_result_2 = [] 37 | Test_result_2 = [] 38 | Training_result_3 = [] 39 | Test_result_3 = [] 40 | Training_result_4 = [] 41 | Test_result_4 = [] 42 | Training_result_5 = [] 43 | Test_result_5 = [] 44 | Training_result_6 = [] 45 | Test_result_6 = [] 46 | Training_result_7 = [] 47 | Test_result_7 = [] 48 | Training_result_8 = [] 49 | Test_result_8 = [] 50 | Training_result_9 = [] 51 | Test_result_9 = [] 52 | Test_result_10 = [] 53 | Test_result_11 = [] 54 | Test_result_12 = [] 55 | Test_result_13 = [] 56 | Test_result_14 = [] 57 | Test_result_15 = [] 58 | Test_result_16 = [] 59 | Test_result_17 = [] 60 | Test_result_18 = [] 61 | Test_details = [] 62 | i = [] 63 | j = 0 64 | for line in input_1: 65 | if '3d_valid' in line: 66 | word = line.split(' ') 67 | Test_result_1.append(float(word[-5][:5])) 68 | # if 'aligned_pck' in line: 69 | # word = line.split(' ') 70 | Test_result_2.append(float(word[-1][:5])) 71 | # if 'previous 9' in line: 72 | # word = line.split(' ') 73 | # Test_result_1.append(float(word[6][7:12])) 74 | # if 'after 8' in line: 75 | # word = line.split(' ') 76 | # Test_result_2.append(float(word[6][7:12])) 77 | # if 'Hip(root)' in line: 78 | # word = line.split(' ') 79 | # Test_result_3.append(float(word[7][0:4])) 80 | # Test_result_4.append(float(word[10][0:4])) 81 | # Test_result_5.append(float(word[13][0:4])) 82 | # Test_result_6.append(float(word[16][0:4])) 83 | # Test_result_7.append(float(word[19][0:4])) 84 | # Test_result_8.append(float(word[22][0:4])) 85 | # Test_result_9.append(float(word[25][0:5])) 86 | # Test_result_10.append(float(word[28][0:4])) 87 | # Test_result_11.append(float(word[31][0:4])) 88 | # Test_result_12.append(float(word[34][0:4])) 89 | # Test_result_13.append(float(word[37][0:4])) 90 | # Test_result_14.append(float(word[40][0:4])) 91 | # Test_result_15.append(float(word[43][0:3])) 92 | # Test_result_16.append(float(word[46][0:4])) 93 | # Test_result_17.append(float(word[49][0:4])) 94 | # # Test_result_18.append(float(word[52][0:4])) 95 | for line in input_2: 96 | # if '3d_train' in line: 97 | # word = line.split(' ') 98 | # Test_result_2.append(float(word[-1])) 99 | if '3d_valid' in line: 100 | word = line.split(' ') 101 | 
Test_result_3.append(float(word[-5][:5])) 102 | # if 'aligned_pck' in line: 103 | # word = line.split(' ') 104 | Test_result_4.append(float(word[-1][:5])) 105 | # 106 | # for line in input_3: 107 | # # if '3d_train' in line: 108 | # # word = line.split(' ') 109 | # # Test_result_2.append(float(word[-1])) 110 | # if 'mean pck' in line: 111 | # word = line.split(' ') 112 | # Test_result_5.append(float(word[-2][:5])) 113 | # # if 'aligned_pck' in line: 114 | # # word = line.split(' ') 115 | # Test_result_6.append(float(word[-1][:5])) 116 | # # 117 | # # 118 | # for line in input_4: 119 | # if 'mean pck' in line: 120 | # word = line.split(' ') 121 | # Test_result_7.append(float(word[-2][:5])) 122 | # # if 'aligned_pck' in line: 123 | # # word = line.split(' ') 124 | # Test_result_8.append(float(word[-1][:5])) 125 | # if '3d_train' in line: 126 | # word = line.split(' ') 127 | # Test_result_4.append(float(word[-1])) 128 | # 129 | # for line in input_5: 130 | # if '3d_train' in line: 131 | # word = line.split(' ') 132 | # Test_result_5.append(float(word[-1])) 133 | # 134 | # for line in input_6: 135 | # if '3d_train' in line: 136 | # word = line.split(' ') 137 | # Test_result_6.append(float(word[-1])) 138 | # # 139 | # for line in input_7: 140 | # if '3d_train' in line: 141 | # # Training_result_7.append(float(line[36:46])) 142 | # # # elif '3d_valid' in line: 143 | # word = line.split(' ') 144 | # Test_result_7.append(float(word[-1])) 145 | # 146 | # for line in input_8: 147 | # if '3d_train' in line: 148 | # word = line.split(' ') 149 | # Test_result_8.append(float(word[-1])) 150 | # # 151 | # for line in input_9: 152 | # if '3d_train' in line: 153 | # word = line.split(' ') 154 | # Test_result_9.append(float(word[-1])) 155 | # 156 | # for line in input_10: 157 | # if '3d_train' in line: 158 | # word = line.split(' ') 159 | # Test_result_10.append(float(word[-1])) 160 | # 161 | 162 | fig = plt.figure() 163 | plt.title('Train and Test on 3dpw testset(24 scenes with 37 people) Model') 164 | # ax_1 = plt.subplot(311) 165 | # ax_1.set_title("matte input with 4 channels") 166 | #plt.plot(Training_result_1,'r-x', label = 'Training error of origin_256_384/mm') 167 | plt.plot(Test_result_1,'r-^', label = 'Monocular setup-gp5-Training L1 loss - finetune 100epoch by LR=1e-4') 168 | # plt.plot(Training_result_2,'b-x', label = 'Training error of origin_256_1024/mm') 169 | plt.plot(Test_result_2,'r-x', label = 'Monocular setup-gp5-Test error/mm') 170 | # # plt.plot(Training_result_3,'g-x', label = 'Training error of origin_1024_1024/mm') 171 | plt.plot(Test_result_3,'g-^', label = 'Monocular setup-gp5-Training L1 loss, finetune 15epoch by LR=5e-4') 172 | plt.plot(Test_result_4,'g-x', label = 'Temporal setup-gp5-Test error/mm') 173 | # plt.plot(Test_result_5,'b-^', label = 'Temporal setup-gp3-Test error of AUC/mm') 174 | # # # # #plt.plot(Training_result_5,'m-x', label = 'Training error of dcn_1024_1024/mm') 175 | # # plt.plot(Test_result_4,'b-^', label = 'Test error of ori gp1/mm') 176 | # plt.plot(Test_result_6,'b-x', label = 'Temporal setup-gp3-Test error of aligned AUC/mm') 177 | # plt.plot(Test_result_7,'k-^', label = 'Monocular setup-Video3d-Test error of AUC/mm') 178 | # plt.plot(Test_result_8,'k-x', label = 'Monocular setup-Video3d-Test error of aligned AUC/mm') 179 | # plt.plot(Test_result_9,'tab:orange', label = 'Test error of Lfoot/mm') 180 | # plt.plot(Test_result_10,'tab:blue', label = 'Test error of Spine') 181 | # 182 | # plt.plot(Test_result_11,'y-x', label = 'Training error of Thorax/mm') 
183 | # plt.plot(Test_result_12,'k-x', label = 'Training error of Neck/mm') 184 | # plt.plot(Test_result_13,'m-x', label = 'Training error of Head/mm') 185 | # plt.plot(Test_result_14,'r-x', label = 'Training error of Lshoulder/mm') 186 | # plt.plot(Test_result_15,'b-x', label = 'Training error of Lelbow/mm') 187 | # plt.plot(Test_result_16,'g-x', label = 'Training error of Lwrist/mm') 188 | # plt.plot(Test_result_15,'c-x', label = 'Training error of Rshoulder/mm') 189 | # plt.plot(Test_result_16,'y-^', label = 'Training error of Relbow/mm') 190 | # plt.plot(Test_result_17,'tab:green', label = 'Training error of Rwrist/mm') 191 | 192 | # #plt.plot(Training_result_7,'k-x', label = 'Training error of refine_dcn_1024/mm') 193 | 194 | plt.legend() 195 | plt.grid(True) 196 | plt.xlabel('epoch') 197 | plt.ylabel('MPJPE/mm') 198 | # my_x_ticks = np.arange(0,20,1) 199 | # my_y_ticks = np.arange(0,70,10) 200 | # plt.xticks(my_x_ticks) 201 | # plt.yticks(my_y_ticks) 202 | # 203 | # ax_2 = plt.subplot(312) 204 | # ax_2.set_title('rgb2matte input with 4 channels') 205 | # plt.plot(Training_result_2,'r-x', label = 'Training error/mm') 206 | # plt.plot(Test_result_2,'b-^', label = 'Test error/mm') 207 | # plt.legend() 208 | # plt.grid(True) 209 | # plt.xlabel('Epoch') 210 | # plt.ylabel('MPJPE/mm') 211 | # my_x_ticks = np.arange(0,20,1) 212 | # my_y_ticks = np.arange(0,70,10) 213 | # plt.xticks(my_x_ticks) 214 | # plt.yticks(my_y_ticks) 215 | # 216 | # ax_3 = plt.subplot(313) 217 | # ax_3.set_title('rgb input with 12 channels') 218 | # plt.plot(Training_result_3,'r-x', label = 'Training error/mm') 219 | # plt.plot(Test_result_3,'b-^', label = 'Test error/mm') 220 | # plt.legend() 221 | # plt.grid(True) 222 | # plt.xlabel('Epoch') 223 | # plt.ylabel('MPJPE/mm') 224 | # my_x_ticks = np.arange(0,20,1) 225 | # my_y_ticks = np.arange(0,70,10) 226 | # plt.xticks(my_x_ticks) 227 | # plt.yticks(my_y_ticks) 228 | 229 | # ax_4 = plt.subplot(224) 230 | # ax_4.set_title('p2c4_frame5') 231 | # plt.plot(Training_result_4,'r-x', label = 'Training error/mm') 232 | # plt.plot(Test_result_4,'b-^', label = 'Test error/mm') 233 | # plt.legend() 234 | # plt.grid(True) 235 | # plt.xlabel('Epoch') 236 | # plt.ylabel('MPJPE/mm') 237 | # my_x_ticks = np.arange(0,20,1) 238 | # my_y_ticks = np.arange(0,120,10) 239 | # plt.xticks(my_x_ticks) 240 | # plt.yticks(my_y_ticks) 241 | plt.tight_layout() 242 | plt.show() -------------------------------------------------------------------------------- /common/visualization/plot_log_kpt.py: -------------------------------------------------------------------------------- 1 | import re 2 | import matplotlib 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | # This code is used to plot the detailed test errors of each frame for each model.
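# Editor's note: the per-bone mean/std section further down in this file repeats the same
# three-line numpy block sixteen times (once per Test_result_* list). A minimal sketch of an
# equivalent loop is given below; the function name and the optional label list are
# illustrative additions, not part of the original code.
def print_bone_stats(bone_lists, labels=None):
    """Print mean and standard deviation of each list of parsed bone lengths."""
    for idx, values in enumerate(bone_lists):
        label = labels[idx] if labels else 'bone_{}'.format(idx)
        print(label, np.mean(values), np.std(values))
# Example (after the parsing loop below has filled the lists):
# print_bone_stats([Test_result_1, Test_result_2, Test_result_3])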
7 | # I have made these four picture in one picture 8 | 9 | input_1 = open('log/test_lcn.log','r') 10 | # input_2 = open('log/eval_gp5_mul.log','r') 11 | # input_3 = open('log/eval_gp5_add.log','r') 12 | # input_4 = open('log/eval_fc_l8_bone0.1.log','r') 13 | # input_5 = open('log/eval_gp5_mul_bone1.log','r') 14 | # input_6 = open('log/eval_gp5_mul_bone0.11.log','r') 15 | # input_7 = open('log/eval_gp5_mul_bone0.01.log','r') 16 | # input_8 = open('log/eval_gp5_add_bone0.01.log','r') 17 | # input_9 = open('log/1210_gp1_test.log','r') 18 | # input_10 = open('log/1210_gp2_test.log','r') 19 | # input_5 = open('log/1210_gp10_test.log','r') 20 | # input_6 = open('log/1210_gp3_test.log','r') 21 | # input_7 = open('log/1210_gp4_test.log','r') 22 | # input_8 = open('log/1210_gp5_test.log','r') 23 | #input_8 = open('log/1013_eva_epoch_20_all_in_bn_nbias_gp2.log','r') 24 | #Training_result means the accuracy of each frame in testset. 25 | Training_result_1 = [] 26 | Test_result_1 = [] 27 | Training_result_2 = [] 28 | Test_result_2 = [] 29 | Training_result_3 = [] 30 | Test_result_3 = [] 31 | Training_result_4 = [] 32 | Test_result_4 = [] 33 | Training_result_5 = [] 34 | Test_result_5 = [] 35 | Training_result_6 = [] 36 | Test_result_6 = [] 37 | Training_result_7 = [] 38 | Test_result_7 = [] 39 | Training_result_8 = [] 40 | Test_result_8 = [] 41 | Test_result_9 = [] 42 | Test_result_10 = [] 43 | Test_result_11 = [] 44 | Test_result_12 = [] 45 | Test_result_13 = [] 46 | Test_result_14 = [] 47 | Test_result_15 = [] 48 | Test_result_16 = [] 49 | 50 | # lines_3 = input_3.readlines() 51 | # lines_last_3 = lines_3[-16:-1] 52 | # lines_last_3.append(lines_3[-1]) 53 | # for v in lines_last_3: 54 | # v = v.split() 55 | # Test_result_3.append(float(v[-1])) 56 | # 57 | 58 | def average(*args): 59 | l = len(args) 60 | sum = 0 61 | if l==0: 62 | return 0.0 63 | i = 0 64 | while i < l: 65 | sum += args[i] 66 | i += 1 67 | return sum*1.0/l 68 | 69 | # Group=2,1rd Group: average(float(word[4][:8]),float(word[7][:8]),float(word[10][:8]),float(word[13][:8]),float(word[16][:8]),float(word[19][:8]),float(word[22][:8]),float(word[25][:8])) float(word[28][:8]), 70 | # Group=2,2rd Group:average(float(word[28][:8]),float(word[31][:8]),float(word[34][:8]),float(word[37][:8]),float(word[40][:8]),float(word[43][:8]),float(word[46][:8]),float(word[49][:8]),float(word[52][:8]))) 71 | 72 | # Group=3; 1rd Group: average(float(word[4][:8]),float(word[7][:8]),float(word[10][:8]),float(word[13][:8]),float(word[16][:8]),float(word[19][:8]),float(word[22][:8]))) 73 | # Group=3; 2rd Group: average(float(word[25][:8]),float(word[28][:8]),float(word[31][:8]),float(word[34][:8]))) 74 | # Group=3; 3rd Group: average(float(word[37][:8]),float(word[40][:8]),float(word[43][:8]),float(word[46][:8]),float(word[49][:8]),float(word[52][:8]))) 75 | a=14 76 | for line in input_1: 77 | if 'Bone length' in line: 78 | word = line.split(' ') 79 | Test_result_1.append(float(word[5][:5])) 80 | Test_result_2.append(float(word[7][:5])) 81 | Test_result_3.append(float(word[10][:5])) 82 | Test_result_4.append(float(word[13][:5])) 83 | Test_result_5.append(float(word[17][:5])) 84 | Test_result_6.append(float(word[20][:5])) 85 | Test_result_7.append(float(word[23][:5])) 86 | Test_result_8.append(float(word[26][:5])) 87 | Test_result_9.append(float(word[29][:5])) 88 | Test_result_10.append(float(word[32][:5])) 89 | Test_result_11.append(float(word[35][:5])) 90 | Test_result_12.append(float(word[38][:5])) 91 | Test_result_13.append(float(word[41][:5])) 92 | 
Test_result_14.append(float(word[44][:5])) 93 | Test_result_15.append(float(word[47][:5])) 94 | Test_result_16.append(float(word[50][:5])) 95 | 96 | # calculate mean/variance 97 | cal = Test_result_1 98 | part_mean = np.mean(cal) 99 | part_var = np.std(cal) 100 | print(part_mean,part_var) 101 | cal = Test_result_2 102 | part_mean = np.mean(cal) 103 | part_var = np.std(cal) 104 | print(part_mean,part_var) 105 | cal = Test_result_3 106 | part_mean = np.mean(cal) 107 | part_var = np.std(cal) 108 | print(part_mean,part_var) 109 | cal = Test_result_4 110 | part_mean = np.mean(cal) 111 | part_var = np.std(cal) 112 | print(part_mean,part_var) 113 | cal = Test_result_5 114 | part_mean = np.mean(cal) 115 | part_var = np.std(cal) 116 | print(part_mean,part_var) 117 | cal = Test_result_6 118 | part_mean = np.mean(cal) 119 | part_var = np.std(cal) 120 | print(part_mean,part_var) 121 | cal = Test_result_7 122 | part_mean = np.mean(cal) 123 | part_var = np.std(cal) 124 | print(part_mean,part_var) 125 | cal = Test_result_8 126 | part_mean = np.mean(cal) 127 | part_var = np.std(cal) 128 | print(part_mean,part_var) 129 | cal = Test_result_9 130 | part_mean = np.mean(cal) 131 | part_var = np.std(cal) 132 | print(part_mean,part_var) 133 | cal = Test_result_10 134 | part_mean = np.mean(cal) 135 | part_var = np.std(cal) 136 | print(part_mean,part_var) 137 | cal = Test_result_11 138 | part_mean = np.mean(cal) 139 | part_var = np.std(cal) 140 | print(part_mean,part_var) 141 | cal = Test_result_12 142 | part_mean = np.mean(cal) 143 | part_var = np.std(cal) 144 | print(part_mean,part_var) 145 | cal = Test_result_13 146 | part_mean = np.mean(cal) 147 | part_var = np.std(cal) 148 | print(part_mean,part_var) 149 | cal = Test_result_14 150 | part_mean = np.mean(cal) 151 | part_var = np.std(cal) 152 | print(part_mean,part_var) 153 | cal = Test_result_15 154 | part_mean = np.mean(cal) 155 | part_var = np.std(cal) 156 | print(part_mean,part_var) 157 | cal = Test_result_16 158 | part_mean = np.mean(cal) 159 | part_var = np.std(cal) 160 | print(part_mean,part_var) 161 | 162 | 163 | # for line in input_2: 164 | # if 'Action bone' in line: 165 | # word = line.split(' ') 166 | # Test_result_2.append(float(word[7+a][:6])) 167 | # 168 | # for line in input_3: 169 | # if 'Action bone' in line: 170 | # word = line.split(' ') 171 | # Test_result_3.append(float(word[7+a][:6])) 172 | # 173 | # for line in input_4: 174 | # if 'Action bone' in line: 175 | # word = line.split(' ') 176 | # Test_result_4.append(float(word[7+a][:6])) 177 | # 178 | # for line in input_5: 179 | # if 'Action bone' in line: 180 | # word = line.split(' ') 181 | # Test_result_5.append(float(word[7+a][:6])) 182 | # for line in input_6: 183 | # if 'Action bone' in line: 184 | # word = line.split(' ') 185 | # Test_result_6.append(float(word[7+a][:6])) 186 | # 187 | # for line in input_7: 188 | # if 'Action bone' in line: 189 | # word = line.split(' ') 190 | # Test_result_7.append(float(word[7+a][:6])) 191 | # for line in input_8: 192 | # if 'Action bone' in line: 193 | # word = line.split(' ') 194 | # Test_result_8.append(float(word[7+a][:6])) 195 | # # 196 | # for line in input_9: 197 | # if 'Hip(root)' in line: 198 | # word = line.split(' ') 199 | # Test_result_9.append(average(float(word[4][:8]),float(word[7][:8]),float(word[10][:8]),float(word[13][:8]),float(word[16][:8]),float(word[19][:8]),float(word[22][:8]),float(word[25][:8]))) 200 | # 201 | # for line in input_10: 202 | # if 'Hip(root)' in line: 203 | # word = line.split(' ') 204 | # 
Test_result_10.append(average(float(word[4][:8]),float(word[7][:8]),float(word[10][:8]),float(word[13][:8]),float(word[16][:8]),float(word[19][:8]),float(word[22][:8]),float(word[25][:8]),float(word[28][:8]))) 205 | fig = plt.figure() 206 | plt.title("all subjects data in human3.6M") 207 | plt.plot(Test_result_1,'r-x', label = 'RHip') 208 | plt.plot(Test_result_2,'b-x', label = 'URLeg') 209 | plt.plot(Test_result_3,'g-x', label = 'LRLeg') 210 | plt.plot(Test_result_4,'y-x', label = 'Lhip') 211 | plt.plot(Test_result_5,'m-x', label = 'ULLeg') 212 | plt.plot(Test_result_6,'c-x', label = 'LLleg') 213 | plt.plot(Test_result_7,'k-x', label = 'Lspine') 214 | plt.plot(Test_result_8,'tab:pink', label = 'Uspine') 215 | plt.plot(Test_result_9,'tab:blue', label = 'Neck') 216 | plt.plot(Test_result_10,'tab:green', label = 'Head') 217 | plt.plot(Test_result_11,'c-^', label = 'Rshoulder') 218 | plt.plot(Test_result_12,'k-^', label = 'URelbow') 219 | plt.plot(Test_result_13,'r-^', label = 'LRelbow') 220 | plt.plot(Test_result_14,'b-^', label = 'Lshoulder') 221 | plt.plot(Test_result_15,'g-^', label = 'ULelbow') 222 | plt.plot(Test_result_16,'y-^', label = 'LLelbow') 223 | 224 | 225 | plt.legend() 226 | plt.grid(True) 227 | plt.xlabel('each subaction-person') 228 | plt.ylabel('Bone length/cm') 229 | plt.tight_layout() 230 | # my_x_ticks = np.arange(0,2181,10) 231 | # #my_y_ticks = np.arange(0,300,10) 232 | # plt.xticks(my_x_ticks) 233 | # #plt.yticks(my_y_ticks) 234 | # 235 | # ax_2 = plt.subplot(212) 236 | # ax_2.set_title('Test error for each action') 237 | # plt.plot(Test_result_1,'r-x', label = 'p2c12_f4_nfeat128_ep11') 238 | # plt.plot(Test_result_2,'b-x', label = 'p2c16_f4_nfeat128_ep20') 239 | # plt.plot(Test_result_3,'g-x', label = 'p2c12_f5_nfeat128_ep18') 240 | # plt.plot(Test_result_4,'y-x', label = 'p2c4_test_ep17') 241 | # plt.plot(Test_result_5,'m-x', label = 'p2c16_f4_nfeat128_ep15_best') 242 | # plt.plot(Test_result_6,'c-x', label = 'p2c4_f4_nfeat 256_ep19_best') 243 | # plt.legend() 244 | # plt.grid(True) 245 | # plt.xlabel('Epoch') 246 | # plt.ylabel('MPMJE/mm') 247 | # scale_ls = range(16) 248 | # #index_ls = ['Directions','Discussion','Eating','Greeting','Phoning','Photo','Posing','Purchases','Sitting','SittingDown','Smoking','Waiting','WalkDog','WalkTogether','Walking','Mean error'] 249 | # index_ls = ['Greeting','Sitting','SittingDown','WalkTogether','Phoning','Posing','WalkDog','Walking','Purchases','Waiting','Directions','Smoking','Photo','Eating','Discussion','Average'] 250 | # plt.xticks(scale_ls,index_ls) 251 | 252 | plt.show() 253 | -------------------------------------------------------------------------------- /data/prepare_data_humaneva.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import argparse 9 | import os 10 | import zipfile 11 | import numpy as np 12 | import h5py 13 | import re 14 | from glob import glob 15 | from shutil import rmtree 16 | from data_utils import suggest_metadata, suggest_pose_importer 17 | 18 | import sys 19 | sys.path.append('../') 20 | from common.utils import wrap 21 | from itertools import groupby 22 | 23 | subjects = ['Train/S1', 'Train/S2', 'Train/S3', 'Validate/S1', 'Validate/S2', 'Validate/S3'] 24 | 25 | cam_map = { 26 | 'C1': 0, 27 | 'C2': 1, 28 | 'C3': 2, 29 | } 30 | 31 | # Frame numbers for train/test split 32 | # format: [start_frame, end_frame[ (inclusive, exclusive) 33 | index = { 34 | 'Train/S1': { 35 | 'Walking 1': (590, 1203), 36 | 'Jog 1': (367, 740), 37 | 'ThrowCatch 1': (473, 945), 38 | 'Gestures 1': (395, 801), 39 | 'Box 1': (385, 789), 40 | }, 41 | 'Train/S2': { 42 | 'Walking 1': (438, 876), 43 | 'Jog 1': (398, 795), 44 | 'ThrowCatch 1': (550, 1128), 45 | 'Gestures 1': (500, 901), 46 | 'Box 1': (382, 734), 47 | }, 48 | 'Train/S3': { 49 | 'Walking 1': (448, 939), 50 | 'Jog 1': (401, 842), 51 | 'ThrowCatch 1': (493, 1027), 52 | 'Gestures 1': (533, 1102), 53 | 'Box 1': (512, 1021), 54 | }, 55 | 'Validate/S1': { 56 | 'Walking 1': (5, 590), 57 | 'Jog 1': (5, 367), 58 | 'ThrowCatch 1': (5, 473), 59 | 'Gestures 1': (5, 395), 60 | 'Box 1': (5, 385), 61 | }, 62 | 'Validate/S2': { 63 | 'Walking 1': (5, 438), 64 | 'Jog 1': (5, 398), 65 | 'ThrowCatch 1': (5, 550), 66 | 'Gestures 1': (5, 500), 67 | 'Box 1': (5, 382), 68 | }, 69 | 'Validate/S3': { 70 | 'Walking 1': (5, 448), 71 | 'Jog 1': (5, 401), 72 | 'ThrowCatch 1': (5, 493), 73 | 'Gestures 1': (5, 533), 74 | 'Box 1': (5, 512), 75 | }, 76 | } 77 | 78 | # Frames to skip for each video (synchronization) 79 | sync_data = { 80 | 'S1': { 81 | 'Walking 1': (82, 81, 82), 82 | 'Jog 1': (51, 51, 50), 83 | 'ThrowCatch 1': (61, 61, 60), 84 | 'Gestures 1': (45, 45, 44), 85 | 'Box 1': (57, 57, 56), 86 | }, 87 | 'S2': { 88 | 'Walking 1': (115, 115, 114), 89 | 'Jog 1': (100, 100, 99), 90 | 'ThrowCatch 1': (127, 127, 127), 91 | 'Gestures 1': (122, 122, 121), 92 | 'Box 1': (119, 119, 117), 93 | }, 94 | 'S3': { 95 | 'Walking 1': (80, 80, 80), 96 | 'Jog 1': (65, 65, 65), 97 | 'ThrowCatch 1': (79, 79, 79), 98 | 'Gestures 1': (83, 83, 82), 99 | 'Box 1': (1, 1, 1), 100 | }, 101 | 'S4': {} 102 | } 103 | 104 | if __name__ == '__main__': 105 | if os.path.basename(os.getcwd()) != 'data': 106 | print('This script must be launched from the "data" directory') 107 | exit(0) 108 | 109 | parser = argparse.ArgumentParser(description='HumanEva dataset converter') 110 | 111 | parser.add_argument('-p', '--path', default='', type=str, metavar='PATH', help='path to the processed HumanEva dataset') 112 | parser.add_argument('--convert-3d', action='store_true', help='convert 3D mocap data') 113 | parser.add_argument('--convert-2d', default='', type=str, metavar='PATH', help='convert user-supplied 2D detections') 114 | parser.add_argument('-o', '--output', default='', type=str, metavar='PATH', help='output suffix for 2D detections (e.g. 
detectron_pt_coco)') 115 | 116 | args = parser.parse_args() 117 | 118 | if not args.convert_2d and not args.convert_3d: 119 | print('Please specify one conversion mode') 120 | exit(0) 121 | 122 | 123 | if args.path: 124 | print('Parsing HumanEva dataset from', args.path) 125 | output = {} 126 | output_2d = {} 127 | frame_mapping = {} 128 | 129 | from scipy.io import loadmat 130 | 131 | num_joints = None 132 | 133 | for subject in subjects: 134 | output[subject] = {} 135 | output_2d[subject] = {} 136 | split, subject_name = subject.split('/') 137 | if subject_name not in frame_mapping: 138 | frame_mapping[subject_name] = {} 139 | 140 | file_list = glob(args.path + '/' + subject + '/*.mat') 141 | for f in file_list: 142 | action = os.path.splitext(os.path.basename(f))[0] 143 | 144 | # Use consistent naming convention 145 | canonical_name = action.replace('_', ' ') 146 | 147 | hf = loadmat(f) 148 | positions = hf['poses_3d'] 149 | positions_2d = hf['poses_2d'].transpose(1, 0, 2, 3) # Ground-truth 2D poses 150 | assert positions.shape[0] == positions_2d.shape[0] and positions.shape[1] == positions_2d.shape[2] 151 | assert num_joints is None or num_joints == positions.shape[1], "Joint number inconsistency among files" 152 | num_joints = positions.shape[1] 153 | 154 | # Sanity check for the sequence length 155 | assert positions.shape[0] == index[subject][canonical_name][1] - index[subject][canonical_name][0] 156 | 157 | # Split corrupted motion capture streams into contiguous chunks 158 | # e.g. 012XX567X9 is split into "012", "567", and "9". 159 | all_chunks = [list(v) for k, v in groupby(positions, lambda x: np.isfinite(x).all())] 160 | all_chunks_2d = [list(v) for k, v in groupby(positions_2d, lambda x: np.isfinite(x).all())] 161 | assert len(all_chunks) == len(all_chunks_2d) 162 | current_index = index[subject][canonical_name][0] 163 | chunk_indices = [] 164 | for i, chunk in enumerate(all_chunks): 165 | next_index = current_index + len(chunk) 166 | name = canonical_name + ' chunk' + str(i) 167 | if np.isfinite(chunk).all(): 168 | output[subject][name] = np.array(chunk, dtype='float32') / 1000 169 | output_2d[subject][name] = list(np.array(all_chunks_2d[i], dtype='float32').transpose(1, 0, 2, 3)) 170 | chunk_indices.append((current_index, next_index, np.isfinite(chunk).all(), split, name)) 171 | current_index = next_index 172 | assert current_index == index[subject][canonical_name][1] 173 | if canonical_name not in frame_mapping[subject_name]: 174 | frame_mapping[subject_name][canonical_name] = [] 175 | frame_mapping[subject_name][canonical_name] += chunk_indices 176 | 177 | metadata = suggest_metadata('humaneva' + str(num_joints)) 178 | output_filename = 'data_3d_' + metadata['layout_name'] 179 | output_prefix_2d = 'data_2d_' + metadata['layout_name'] + '_' 180 | 181 | if args.convert_3d: 182 | print('Saving...') 183 | np.savez_compressed(output_filename, positions_3d=output) 184 | np.savez_compressed(output_prefix_2d + 'gt', positions_2d=output_2d, metadata=metadata) 185 | print('Done.') 186 | 187 | else: 188 | print('Please specify the dataset source') 189 | exit(0) 190 | 191 | if args.convert_2d: 192 | if not args.output: 193 | print('Please specify an output suffix (e.g. 
detectron_pt_coco)') 194 | exit(0) 195 | 196 | import_func = suggest_pose_importer(args.output) 197 | metadata = suggest_metadata(args.output) 198 | 199 | print('Parsing 2D detections from', args.convert_2d) 200 | 201 | output = {} 202 | file_list = glob(args.convert_2d + '/S*/*.avi.npz') 203 | for f in file_list: 204 | path, fname = os.path.split(f) 205 | subject = os.path.basename(path) 206 | assert subject.startswith('S'), subject + ' does not look like a subject directory' 207 | 208 | m = re.search('(.*) \\((.*)\\)', fname.replace('_', ' ')) 209 | action = m.group(1) 210 | camera = m.group(2) 211 | camera_idx = cam_map[camera] 212 | 213 | keypoints = import_func(f) 214 | assert keypoints.shape[1] == metadata['num_joints'] 215 | 216 | if action in sync_data[subject]: 217 | sync_offset = sync_data[subject][action][camera_idx] - 1 218 | else: 219 | sync_offset = 0 220 | 221 | if subject in frame_mapping and action in frame_mapping[subject]: 222 | chunks = frame_mapping[subject][action] 223 | for (start_idx, end_idx, labeled, split, name) in chunks: 224 | canonical_subject = split + '/' + subject 225 | if not labeled: 226 | canonical_subject = 'Unlabeled/' + canonical_subject 227 | if canonical_subject not in output: 228 | output[canonical_subject] = {} 229 | kps = keypoints[start_idx+sync_offset:end_idx+sync_offset] 230 | assert len(kps) == end_idx - start_idx, "Got len {}, expected {}".format(len(kps), end_idx - start_idx) 231 | 232 | if name not in output[canonical_subject]: 233 | output[canonical_subject][name] = [None, None, None] 234 | 235 | output[canonical_subject][name][camera_idx] = kps.astype('float32') 236 | else: 237 | canonical_subject = 'Unlabeled/' + subject 238 | if canonical_subject not in output: 239 | output[canonical_subject] = {} 240 | if action not in output[canonical_subject]: 241 | output[canonical_subject][action] = [None, None, None] 242 | output[canonical_subject][action][camera_idx] = keypoints.astype('float32') 243 | 244 | print('Saving...') 245 | np.savez_compressed(output_prefix_2d + args.output, positions_2d=output, metadata=metadata) 246 | print('Done.') -------------------------------------------------------------------------------- /common/common_pytorch/loss/loss_family.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import numpy as np 11 | import math 12 | 13 | 14 | def mpjpe(predicted, target): 15 | """ 16 | Mean per-joint position error (i.e. mean Euclidean distance), 17 | often referred to as "Protocol #1" in many papers. 
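Concretely, this is the mean over all batch/frame/joint dimensions of the Euclidean norm taken along the last (x, y, z) axis, i.e. MPJPE = mean_{n,j} ||predicted_{n,j} - target_{n,j}||_2, matching the torch.norm/torch.mean call below.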
18 | """ 19 | assert predicted.shape == target.shape 20 | #l2_error = torch.mean(torch.norm((predicted - target), dim=len(target.shape) - 1), -1).squeeze() 21 | #print('each joint error:', torch.norm((predicted - target), dim=len(target.shape) - 1)) 22 | #index = np.where(l2_error.cpu().detach().numpy() > 0.3) # mean body l2 distance larger than 300mm 23 | #value = l2_error[l2_error > 0.3] 24 | #print('Index of mean body l2 distance larger than 300mm', index, value) 25 | return torch.mean(torch.norm((predicted - target), dim=len(target.shape) - 1)) 26 | 27 | 28 | def mpjae(predicted, target): 29 | """ 30 | Mean per-joint angle error (3d bone vector angle error between gt and predicted one) 31 | """ 32 | assert predicted.shape == target.shape # [B,T, K] 33 | joint_error = torch.mean(torch.abs(predicted - target).cuda(), dim=0) # Calculate each joint angle 34 | print('each bone angle error:', joint_error) 35 | return torch.mean(joint_error) 36 | 37 | 38 | # def weighted_mpjpe(predicted, target): 39 | # take each joint with a weight 40 | 41 | def mpjpe_smooth(predicted, target, threshold, mi, L1): 42 | """ 43 | Referred in triangulation 3d pose paper 44 | """ 45 | assert predicted.shape == target.shape 46 | if L1: 47 | diff_norm = torch.abs((predicted - target), dim=len(target.shape) - 1) 48 | diff = diff_norm.clone() 49 | else: # MSE 50 | diff = (predicted - target) ** 2 51 | diff[diff > threshold] = torch.pow(diff[diff > threshold], mi) * (threshold ** (1 - mi)) 52 | loss = torch.mean(diff) 53 | return loss 54 | 55 | 56 | def L1_loss(predicted, target): 57 | assert predicted.shape == target.shape 58 | abs_error = torch.mean(torch.mean(torch.abs(predicted - target).cuda(), dim=-2), dim=0) 59 | error = torch.mean(abs_error) 60 | return error 61 | 62 | 63 | def kpt_mpjpe(predicted, target): 64 | # Mean per-joint position error for each keypoint(i.e. mean Euclidean distance) 65 | # This function is just for evaluate!! input shape is (1,t,17,3) 66 | assert predicted.shape == target.shape 67 | kpt_error = torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1)) 68 | kpt_xyz = torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 2), dim=1) 69 | print('X,Y,Z error of T input frames:', kpt_xyz / np.sqrt(17)) 70 | kpt_17 = torch.mean(torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1), dim=0), dim=0) 71 | return kpt_17 72 | 73 | 74 | def kpt_test(predicted, target): 75 | # Mean per-joint position error for each keypoint(i.e. mean Euclidean distance) 76 | # This function is just for evaluate!! 
input shape is (2,t,17,3) 77 | assert predicted.shape == target.shape 78 | print('The frame number is', target.size()) 79 | kpt_xyz = torch.mean(torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 2), dim=0), dim=0) 80 | kpt_17 = torch.mean(torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1), dim=0), dim=0) 81 | return kpt_xyz, kpt_17 82 | 83 | def class_accuracy(predicted, target, confidence, threshold=0.04): 84 | # confidence.shape = [B,T,1,1] 85 | confidence = torch.mean(torch.mean(torch.mean(confidence, -1), -1), 0) 86 | sig = nn.Sigmoid() 87 | 88 | diff = torch.mean(torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1), dim=0), dim=-1) 89 | class0 = (diff<=threshold).cpu() 90 | conf = (sig(confidence)>0.5).cpu() 91 | correct = (conf == class0).sum() 92 | print('ooo',diff.shape,correct, predicted.shape[1],correct/predicted.shape[1]) 93 | return correct 94 | 95 | 96 | 97 | def Uncertain_CE(predicted, target, confidence, threshold=0.04, L1_loss=True): 98 | # confidence.shape = [B,T,1,1] 99 | confidence = torch.mean(torch.mean(torch.mean(confidence, -1), -1), -1) 100 | above_thre = torch.zeros_like(confidence).cuda() 101 | below_thre = torch.zeros_like(confidence).cuda() 102 | # class0 = torch.zeros((len(confidence), 2)).cuda() 103 | class0 = torch.zeros_like(confidence).cuda() 104 | if L1_loss: 105 | diff = torch.mean(torch.mean(torch.mean(torch.abs(predicted - target), dim=1), dim=-1), dim=-1) #[B] 106 | else: 107 | diff = torch.mean(torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1), dim=1), dim=-1) #[B] 108 | threshold = torch.mean(diff) 109 | above_thre = diff * (diff>threshold).float() 110 | below_thre = diff * (diff<=threshold).float() 111 | # print('ccc',above_thre,'ppp',below_thre) 112 | # class0[:, 0] = (diff>threshold).long().cuda() 113 | # class0[:, 1] = (diff<=threshold).long().cuda() 114 | class0 = (diff<=threshold).float().cuda() 115 | sig = nn.Sigmoid() 116 | a1 = sig(confidence) 117 | # a1 = 0.5 118 | a2 = 1 - a1 119 | 120 | weight = 0.1 121 | # pos_weight = torch.sum(diff>threshold)/torch.sum(diff<=threshold) 122 | # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 123 | BCE = nn.BCEWithLogitsLoss() 124 | item1 = torch.mean(a1 * below_thre) 125 | item2 = torch.mean(a2 * above_thre) 126 | item3 = torch.mean(weight*BCE(confidence, class0)) 127 | print('xxx',item3,item2,item1) 128 | loss = item1 + item2 + item3 129 | return loss 130 | 131 | 132 | 133 | 134 | class L1GaussianRegressionNewFlow(nn.Module): 135 | ''' L1 Joint Gaussian Regression Loss 136 | ''' 137 | 138 | def __init__(self, OUTPUT_3D=False, size_average=True): 139 | super(L1GaussianRegressionNewFlow, self).__init__() 140 | self.size_average = size_average 141 | self.amp = 1 / math.sqrt(2 * math.pi) 142 | 143 | def weighted_l2_loss(self, pred, gt, weight): 144 | diff = (pred - gt) ** 2 145 | diff = diff * weight 146 | return diff.sum() / (weight.sum() + 1e-9) 147 | 148 | def _generate_activation(self, gt_coords, pred_coords): 149 | sigma = 2 * 2 / 64 150 | # (B, K, 1, 2) 151 | gt_coords = gt_coords.permute(0,2,1,3) 152 | # (B, 1, K, 2) 153 | pred_coords = pred_coords 154 | 155 | diff = torch.sum((gt_coords - pred_coords) ** 2, dim=-1) 156 | activation = torch.exp(-(diff / (2 * (sigma ** 2)))) 157 | 158 | return activation 159 | 160 | def forward(self, output, labels): 161 | nf_loss = output.nf_loss 162 | pred_jts = output.pred_jts 163 | sigma = output.sigma 164 | 165 | gt_uv = labels 166 | weight = 1 167 | #gaussian = weight * torch.log(sigma / 
self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9) 168 | gaussian = torch.abs(gt_uv - pred_jts) 169 | residual = True 170 | if residual: 171 | weight1 = 1 172 | nf_loss = weight1 * nf_loss + gaussian 173 | 174 | if self.size_average > 0: 175 | regression_loss = nf_loss.sum() / len(nf_loss) 176 | #regression_loss = torch.mean(nf_loss)#todo 177 | 178 | else: 179 | regression_loss = nf_loss.sum() 180 | 181 | loss = regression_loss 182 | 183 | return loss 184 | 185 | 186 | def p_mpjpe(predicted, target): 187 | """ 188 | Pose error: MPJPE after rigid alignment (scale, rotation, and translation), 189 | often referred to as "Protocol #2" in many papers. 190 | """ 191 | assert predicted.shape == target.shape # (3071, 17, 3) 192 | muX = np.mean(target, axis=1, keepdims=True) 193 | muY = np.mean(predicted, axis=1, keepdims=True) 194 | 195 | X0 = target - muX 196 | Y0 = predicted - muY 197 | 198 | # Remove scale 199 | normX = np.sqrt(np.sum(X0 ** 2, axis=(1, 2), keepdims=True)) 200 | normY = np.sqrt(np.sum(Y0 ** 2, axis=(1, 2), keepdims=True)) 201 | # print('target',normX,'predice',normY) 202 | X0 /= (normX + 1e-8) 203 | if normY.any() == 0: 204 | normY = normY + 1e-8 205 | 206 | Y0 /= (normY + 1e-8) 207 | # Optimum rotation matrix of Y0 208 | H = np.matmul(X0.transpose(0, 2, 1), Y0) 209 | U, s, Vt = np.linalg.svd(H) 210 | V = Vt.transpose(0, 2, 1) 211 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 212 | 213 | # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 214 | sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) 215 | V[:, :, -1] *= sign_detR 216 | s[:, -1] *= sign_detR.flatten() 217 | R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation 218 | 219 | tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) 220 | 221 | a = tr * normX / normY # Scale 222 | t = muX - a * np.matmul(muY, R) # Translation 223 | 224 | # Standarized Distance between X0 and a*Y0*R+c 225 | d = 1 - tr ** 2 226 | 227 | # Perform rigid transformation on the input 228 | predicted_aligned = a * np.matmul(predicted, R) 229 | trans_aligned = predicted_aligned + t 230 | error = np.mean(np.linalg.norm(trans_aligned - target, axis=len(target.shape) - 1)) 231 | # Return MPJPE 232 | return error, torch.from_numpy(trans_aligned).unsqueeze(dim=0).cuda() 233 | 234 | 235 | def n_mpjpe(predicted, target): 236 | """ 237 | Normalized MPJPE (scale only), adapted from: 238 | https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py 239 | """ 240 | assert predicted.shape == target.shape # [1, 1703, 17, 3] 241 | norm_predicted = torch.mean(torch.sum(predicted ** 2, dim=3, keepdim=True), dim=2, keepdim=True) 242 | norm_target = torch.mean(torch.sum(target * predicted, dim=3, keepdim=True), dim=2, keepdim=True) 243 | scale = norm_target / norm_predicted 244 | out = torch.mean(torch.norm((scale * predicted - target), dim=len(target.shape) - 1)) 245 | return out 246 | 247 | 248 | def mean_velocity_error(predicted, target): 249 | """ 250 | Mean per-joint velocity error (i.e. 
mean Euclidean distance of the 1st derivative) 251 | """ 252 | assert predicted.shape == target.shape 253 | velocity_predicted = np.diff(predicted, axis=0) 254 | velocity_target = np.diff(target, axis=0) 255 | return np.mean(np.linalg.norm(velocity_predicted - velocity_target, axis=len(target.shape) - 1)) 256 | -------------------------------------------------------------------------------- /common/dataset/pre_process/get_3dpw.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import os 5 | import cv2 6 | import numpy as np 7 | import pickle as pkl 8 | import os.path as osp 9 | 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | 13 | from common.dataset.pre_process.kpt_index import get_perm_idxs 14 | from common.dataset.pre_process.norm_data import norm_to_pixel 15 | 16 | from common.transformation.cam_utils import normalize_screen_coordinates 17 | from common.visualization.plot_pose3d import plot17j 18 | from common.visualization.plot_pose2d import ColorStyle, color1, link_pairs1, point_color1 19 | 20 | def get_3dpw(part): 21 | # part can be :'train', 'validation', 'test' 22 | 23 | folder = '/data/ailing/Video3d/data/3dpw/' 24 | NUM_JOINTS = 24 25 | VIS_THRESH = 0.3 26 | MIN_KP = 6 27 | 28 | sequences = [x.split('.')[0] for x in os.listdir(osp.join(folder, 'sequenceFiles',part))] 29 | print('action sequence:',sequences,len(sequences)) 30 | imgs_path = [] 31 | pose_3d = [] 32 | pose_2d = [] 33 | cam_ex = [] 34 | cam_intri = [] 35 | 36 | # start to process 3dpw raw data 37 | for i, seq in enumerate(sequences): 38 | print('sub sequence index:',i) 39 | data_file = osp.join(folder, 'sequenceFiles', part, seq + '.pkl') 40 | data = pkl.load(open(data_file, 'rb'), encoding='latin1') 41 | img_dir = osp.join(folder, 'imageFiles', seq) 42 | 43 | num_people = len(data['poses']) 44 | num_frames = len(data['img_frame_ids']) 45 | print('open action file:',data_file,img_dir,'has number people:',num_people,'with frame number:',num_frames) 46 | 47 | assert (data['poses2d'][0].shape[0] == num_frames) 48 | 49 | for p_id in range(num_people): 50 | print('person number:',p_id) 51 | j3d = data['jointPositions'][p_id].reshape(-1, 24,3) 52 | j2d = data['poses2d'][p_id].transpose(0,2,1) 53 | cam_in = data['cam_intrinsics'] #[3,3] 54 | cam_pose = data['cam_poses'] #[T, 4, 4] all people in a image will share the same 55 | 56 | campose_valid = data['campose_valid'][p_id] #[T,] 57 | print('invalid frames:',np.where(campose_valid==0),'valid frame number:',np.count_nonzero(campose_valid)) 58 | new_j2d = np.zeros((j2d.shape[0],17,3)) 59 | new_j3d = np.zeros((j3d.shape[0],17,3)) 60 | 61 | # process 2d 3dpw keypoints into hm36 style 62 | perm_idxs = get_perm_idxs('3dpw', 'h36m') 63 | j2d = j2d[:, perm_idxs] 64 | new_j2d[:, 0] = (j2d[:,0] + j2d[:,3])/2 65 | new_j2d[:,1:7] = j2d[:,0:6] 66 | # new_j2d[:,4:7] = j2d[:,0:3] 67 | # new_j2d[:,1:4] = j2d[:,3:6] 68 | 69 | new_j2d[:,8] = (j2d[:,7]+j2d[:,10])/2 #neck 70 | new_j2d[:,7] = 0.7*new_j2d[:,0]+0.3*new_j2d[:,8] 71 | new_j2d[:,9] = j2d[:,6] 72 | new_j2d[:, 10] = 2*j2d[:, 6] - new_j2d[:,9] 73 | new_j2d[:,11:14] = j2d[:,10:13] 74 | new_j2d[:,14:17] = j2d[:,7:10] 75 | 76 | new_j2d[:, :, 2] = new_j2d[:, :, 2] > 0.3 # set the visibility flags 77 | 78 | # process 3d 3dpw_smpl joints into hm36 style 79 | perm_idxs = get_perm_idxs('smpl', 'h36m') 80 | j3d = j3d[:, perm_idxs] 81 | new_j3d[:,10] = 2*j3d[:, 9] - j3d[:,8] 82 | new_j3d[:,:10] = j3d[:,:10] 83 | new_j3d[:,11:] = j3d[:,10:] 84 | new_j3d[:,7] 
= 0.7*new_j3d[:,0] + 0.3*new_j3d[:,8] #update lower spine position 85 | #print('new pose 2d/3d shape:',new_j2d.shape, new_j3d.shape) 86 | 87 | # get camere params. 88 | cam_rt = cam_pose[:,0:3, 0:3] 89 | cam_t = cam_pose[:,0:3, 3:4] 90 | cam_pose3d = np.zeros_like(new_j3d) # get 3d pose under camere coordination system 91 | for j in range(len(new_j3d)): 92 | for k, kpt in enumerate(new_j3d[0]): 93 | cam_pose3d[j,k][:,np.newaxis] = np.dot(cam_rt[j], new_j3d[j,k][:,np.newaxis])+cam_t[j] 94 | 95 | #cam_pose3d[:,8]=(cam_pose3d[:,11]+cam_pose3d[:,14])/2 96 | cam_pose3d[:,0]=(cam_pose3d[:,1]+cam_pose3d[:,4])/2 97 | 98 | cam_f = np.array([cam_in[0, 0], cam_in[1, 1]]) 99 | cam_c = cam_in[0:2, 2] 100 | h = int(2 * cam_c[1]) 101 | w = int(2 * cam_c[0]) 102 | 103 | # verify cam_pose is right: 104 | XX = cam_pose3d[:, :,:2] / cam_pose3d[:,:, 2:] 105 | if np.array(XX).any() > 1 or np.array(XX).any() < -1: 106 | print(np.array(XX).any() > 1 or np.array(XX).any() < -1) 107 | print('Attention for this pose!!!') 108 | pose_2 = cam_f * XX + cam_c 109 | 110 | 111 | show_2d = False 112 | show_3d = False 113 | for index in range(0, len(pose_2)): 114 | #index = 350 115 | img_path = os.path.join(img_dir + '/image_%05d.jpg' % index) 116 | text = "Root 3d: ({:04.2f},{:04.2f},{:04.2f})m".format(cam_pose3d[index, 0, 0], cam_pose3d[index, 0, 1], 117 | cam_pose3d[index, 0, 2]) 118 | print(text, 'seq', seq, 'person_id', p_id, 'index', index) 119 | if show_2d: 120 | colorstyle = ColorStyle(color1, link_pairs1, point_color1) 121 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 122 | [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], 123 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 124 | img = cv2.imread(img_path) 125 | 126 | kps = pose_2 # projected 2d pose 127 | kps_gt = new_j2d #given 2d pose 128 | for j, c in enumerate(connections): 129 | start = kps[index,c[0]] 130 | end = kps[index,c[1]] 131 | cv2.line(img, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), colorstyle.line_color[j], 3) 132 | cv2.circle(img, (int(kps[index,j,0]), int(kps[index,j,1])), 4, colorstyle.ring_color[j], 2) 133 | 134 | start_gt = kps_gt[index, c[0]] 135 | end_gt = kps_gt[index, c[1]] 136 | cv2.line(img, (int(start_gt[0]), int(start_gt[1])), (int(end_gt[0]), int(end_gt[1])), (255, 0, 0), 3) 137 | cv2.circle(img, (int(kps_gt[index, j, 0]), int(kps_gt[index, j, 1])), 3, (255, 100, 0), 2) 138 | text = "Root 3d: ({:04.2f}, {:04.2f}, {:04.2f})m".format(cam_pose3d[index,0,0],cam_pose3d[index,0,1],cam_pose3d[index,0,2]) 139 | print(part, text, 'seq',seq, 'person_id',p_id, 'index',index) 140 | # cv2.putText(img, text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA) 141 | cv2.imshow('3DPW Example', img) 142 | 143 | # cv2.imwrite('data/3dpw/validation/{}_{}_{:05d}.jpg'.format(seq, p_id, index), img) 144 | # cv2.waitKey(0) 145 | # cv2.destroyAllWindows() 146 | 147 | if show_3d: 148 | plot17j(np.concatenate((new_j3d[345:349], cam_pose3d[345:349]),axis=0), None,'a','a') 149 | 150 | 151 | # Filter out keypoints 152 | indices_to_use = np.where((j2d[:, :, 2] > VIS_THRESH).sum(-1) > MIN_KP)[0] # you can change the VIS_THRESH to get pose_2d with different quality 153 | print('selected indexes:',indices_to_use) 154 | print('selected valid frame number:',len(indices_to_use)) 155 | 156 | #norm pose 3d use zero-root 157 | #cam_pose_norm = cam_pose3d-cam_pose3d[:,:1] 158 | #pose_2_norm = normalize_screen_coordinates(pose_2, w, h) 159 | #pose_2_norm[indices_to_use] = normalize_screen_coordinates(new_j2d[indices_to_use,:,:2], 
w, h) 160 | #pose_3d.append(cam_pose_norm) 161 | #pose_2d.append(pose_2_norm) 162 | 163 | if indices_to_use.any(): 164 | pose_2 = pose_2[indices_to_use] 165 | cam_pose3d = cam_pose3d[indices_to_use] 166 | print('final pose shape:',pose_2.shape, cam_pose3d.shape) 167 | cam_int = np.zeros((9)) 168 | cam_int[:2] = cam_f 169 | cam_int[2:4] = cam_c 170 | 171 | pose_2d.append(pose_2) 172 | pose_3d.append(cam_pose3d) 173 | cam_intri.append(cam_int) 174 | 175 | 176 | print('total length:',len(pose_3d)) 177 | file_name = 'data/3dpw_{}'.format(part) 178 | np.savez_compressed(file_name, pose_3d=pose_3d, pose_2d=pose_2d,intrinsic=cam_intri) 179 | print('Saved as:', file_name) 180 | print('Done') 181 | 182 | 183 | def load_3dpw(part, norm): 184 | data_3dpw_test = np.load('data/3dpw_{}_valid.npz'.format(part), allow_pickle=True) 185 | poses_valid = data_3dpw_test['pose_3d'] 186 | poses_valid_2d = data_3dpw_test['pose_2d'] 187 | valid_cam_in = data_3dpw_test['intrinsic'] 188 | # normalize 189 | norm_val = [] 190 | cameras_valid = [] 191 | for i in range(len(poses_valid)): 192 | if norm == 'base': 193 | poses_valid[i][:, 1:] -= poses_valid[i][:, :1] 194 | normed_pose_3d = poses_valid[i] 195 | c_x, c_y = valid_cam_in[i][2], valid_cam_in[i][3] 196 | img_w = int(2 * c_x) 197 | img_h = int(2 * c_y) 198 | normed_pose_2d = normalize_screen_coordinates(poses_valid_2d[i][..., :2], w=img_w, h=img_h) 199 | cameras_valid = None 200 | else: 201 | normed_pose_3d, normed_pose_2d, pixel_ratio, rescale_ratio, offset_2d, abs_root_Z = norm_to_pixel(poses_valid[i], 202 | poses_valid_2d[i], 203 | valid_cam_in[i], 204 | norm) 205 | norm_val.append(np.concatenate((pixel_ratio, rescale_ratio, offset_2d, abs_root_Z), axis=-1)) # [T, 1, 5], len()==4 206 | use_params = {} 207 | use_params['intrinsic'] = valid_cam_in[i] 208 | use_params['normalization_params'] = norm_val[i] 209 | cameras_valid.append(use_params) 210 | poses_valid_2d[i] = normed_pose_2d 211 | poses_valid[i] = normed_pose_3d 212 | return poses_valid, poses_valid_2d, cameras_valid -------------------------------------------------------------------------------- /common/visualization/plot_pose3d.py: -------------------------------------------------------------------------------- 1 | ### Many thanks to: https://raw.githubusercontent.com/bastianwandt/RepNet/7b9185cadd12f850e9fa1754505fca68c34be4ed/plot17j.py 2 | ### Changed the index to Human3.6M 17 keypoints 3 | import numpy as np 4 | 5 | import matplotlib as mpl 6 | mpl.use('Qt5Agg') 7 | import matplotlib.pyplot as plt 8 | import matplotlib.animation as anim 9 | from mpl_toolkits.mplot3d import axes3d, Axes3D 10 | 11 | def plot17j(poses, ax=None, subject=None, action=None, show_animation=False): 12 | if not show_animation: 13 | plot_idx = 1 14 | if len(poses.shape)>2: 15 | fig = plt.figure() 16 | frames = np.linspace(start=0, stop=poses.shape[0]-1, num=6).astype(int) 17 | for i in frames: 18 | ax = fig.add_subplot(2, 3, plot_idx, projection='3d') 19 | pose = poses[i] 20 | x = pose[:, 0] 21 | y = pose[:, 1] 22 | z = pose[:, 2] 23 | ax.scatter(x, y, z) 24 | ax.plot(x[([0, 1])], y[([0, 1])], z[([0, 1])]) 25 | ax.plot(x[([1, 2])], y[([1, 2])], z[([1, 2])]) 26 | ax.plot(x[([2, 3])], y[([2, 3])], z[([2, 3])]) 27 | ax.plot(x[([0, 4])], y[([0, 4])], z[([0, 4])]) 28 | ax.plot(x[([4, 5])], y[([4, 5])], z[([4, 5])]) 29 | ax.plot(x[([5, 6])], y[([5, 6])], z[([5, 6])]) 30 | ax.plot(x[([0, 7])], y[([0, 7])], z[([0, 7])]) 31 | ax.plot(x[([7, 8])], y[([7, 8])], z[([7, 8])]) 32 | ax.plot(x[([8, 9])], y[([8, 9])], z[([8, 9])]) 33 | 
ax.plot(x[([9, 10])], y[([9, 10])], z[([9, 10])]) 34 | ax.plot(x[([8, 11])], y[([8, 11])], z[([8, 11])]) 35 | ax.plot(x[([11, 12])], y[([11, 12])], z[([11, 12])]) 36 | ax.plot(x[([12, 13])], y[([12, 13])], z[([12, 13])]) 37 | ax.plot(x[([8, 14])], y[([8, 14])], z[([8, 14])]) 38 | ax.plot(x[([14, 15])], y[([14, 15])], z[([14, 15])]) 39 | ax.plot(x[([15, 16])], y[([15, 16])], z[([15, 16])]) 40 | # Create cubic bounding box to simulate equal aspect ratio 41 | max_range = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()]).max() 42 | Xb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][0].flatten() + 0.5 * (x.max() + x.min()) 43 | Yb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][1].flatten() + 0.5 * (y.max() + y.min()) 44 | Zb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][2].flatten() + 0.5 * (z.max() + z.min()) 45 | 46 | for xb, yb, zb in zip(Xb, Yb, Zb): 47 | ax.plot([xb], [yb], [zb], 'w') 48 | radius = 2 49 | ax.view_init(elev=75, azim=110) 50 | ax.set_xlim3d([-radius / 2, radius / 2]) 51 | ax.set_zlim3d([0, radius]) 52 | ax.set_ylim3d([-radius / 2, radius / 2]) 53 | 54 | ax.set_xlabel("x") 55 | ax.set_ylabel("y") 56 | ax.set_zlabel("z") 57 | # ax.invert_zaxis() 58 | ax.axis('equal') 59 | # ax.axis('off') 60 | 61 | # ax.set_title('camera = ' + str(i)) 62 | 63 | plot_idx += 1 64 | else: 65 | pose = poses 66 | x = pose[:, 0] 67 | y = pose[:, 1] 68 | z = pose[:, 2] 69 | ax.scatter(x, y, z) 70 | ax.plot(x[([0, 1])], y[([0, 1])], z[([0, 1])]) 71 | ax.plot(x[([1, 2])], y[([1, 2])], z[([1, 2])]) 72 | ax.plot(x[([2, 3])], y[([2, 3])], z[([2, 3])]) 73 | ax.plot(x[([0, 4])], y[([0, 4])], z[([0, 4])]) 74 | ax.plot(x[([4, 5])], y[([4, 5])], z[([4, 5])]) 75 | ax.plot(x[([5, 6])], y[([5, 6])], z[([5, 6])]) 76 | ax.plot(x[([0, 7])], y[([0, 7])], z[([0, 7])]) 77 | ax.plot(x[([7, 8])], y[([7, 8])], z[([7, 8])]) 78 | ax.plot(x[([8, 9])], y[([8, 9])], z[([8, 9])]) 79 | ax.plot(x[([9, 10])], y[([9, 10])], z[([9, 10])]) 80 | ax.plot(x[([8, 11])], y[([8, 11])], z[([8, 11])]) 81 | ax.plot(x[([11, 12])], y[([11, 12])], z[([11, 12])]) 82 | ax.plot(x[([12, 13])], y[([12, 13])], z[([12, 13])]) 83 | ax.plot(x[([8, 14])], y[([8, 14])], z[([8, 14])]) 84 | ax.plot(x[([14, 15])], y[([14, 15])], z[([14, 15])]) 85 | ax.plot(x[([15, 16])], y[([15, 16])], z[([15, 16])]) 86 | # Create cubic bounding box to simulate equal aspect ratio 87 | max_range = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()]).max() 88 | Xb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][0].flatten() + 0.5 * (x.max() + x.min()) 89 | Yb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][1].flatten() + 0.5 * (y.max() + y.min()) 90 | Zb = 0.5 * max_range * np.mgrid[-1:1:1, -1:1:1, -1:1:1][2].flatten() + 0.5 * (z.max() + z.min()) 91 | 92 | for xb, yb, zb in zip(Xb, Yb, Zb): 93 | ax.plot([xb], [yb], [zb], 'w') 94 | # radius = 2 95 | # ax.view_init(elev=15., azim=110) 96 | # ax.set_xlim3d([-radius / 2, radius / 2]) 97 | # ax.set_zlim3d([0, radius]) 98 | # ax.set_ylim3d([-radius / 2, radius / 2]) 99 | 100 | ax.set_xlabel("x") 101 | ax.set_ylabel("y") 102 | ax.set_zlabel("z") 103 | # ax.invert_zaxis() 104 | ax.axis('equal') 105 | #ax.axis('off') 106 | 107 | #ax.set_title('camera = ' + str(i)) 108 | 109 | plot_idx += 1 110 | 111 | # this uses QT5Agg backend 112 | # you can identify the backend using plt.get_backend() 113 | # delete the following two lines and resize manually if it throws an error 114 | figManager = plt.get_current_fig_manager() 115 | figManager.window.showMaximized() 116 | 117 | plt.show() 
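# Editor's note: plt.show() above blocks with the Qt5Agg backend forced at the top of this file,
# and the figure window is usually closed by the time it returns, so the savefig() call below may
# write out an empty image. If the saved file matters, calling savefig() before show() is the usual fix.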
118 | plt.savefig('show/mean_train_pose_{}_{}'.format(subject, action), bbox_inches='tight') 119 | plt.close() 120 | 121 | else: 122 | def update(i): 123 | 124 | ax.clear() 125 | 126 | pose = poses[i] 127 | 128 | x = pose[:, 0] 129 | y = pose[:, 1] 130 | z = pose[:, 2] 131 | ax.scatter(x, y, z) 132 | 133 | ax.plot(x[([0, 1])], y[([0, 1])], z[([0, 1])]) 134 | ax.plot(x[([1, 2])], y[([1, 2])], z[([1, 2])]) 135 | ax.plot(x[([2, 3])], y[([2, 3])], z[([2, 3])]) 136 | ax.plot(x[([0, 4])], y[([0, 4])], z[([0, 4])]) 137 | ax.plot(x[([4, 5])], y[([4, 5])], z[([4, 5])]) 138 | ax.plot(x[([5, 6])], y[([5, 6])], z[([5, 6])]) 139 | ax.plot(x[([0, 7])], y[([0, 7])], z[([0, 7])]) 140 | ax.plot(x[([7, 8])], y[([7, 8])], z[([7, 8])]) 141 | ax.plot(x[([8, 9])], y[([8, 9])], z[([8, 9])]) 142 | ax.plot(x[([9, 10])], y[([9, 10])], z[([9, 10])]) 143 | ax.plot(x[([8, 11])], y[([8, 11])], z[([8, 11])]) 144 | ax.plot(x[([11, 12])], y[([11, 12])], z[([11, 12])]) 145 | ax.plot(x[([12, 13])], y[([12, 13])], z[([12, 13])]) 146 | ax.plot(x[([8, 14])], y[([8, 14])], z[([8, 14])]) 147 | ax.plot(x[([14, 15])], y[([14, 15])], z[([14, 15])]) 148 | ax.plot(x[([15, 16])], y[([15, 16])], z[([15, 16])]) 149 | 150 | # Create cubic bounding box to simulate equal aspect ratio 151 | max_range = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()]).max() 152 | Xb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][0].flatten() + 0.5 * (x.max() + x.min()) 153 | Yb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][1].flatten() + 0.5 * (y.max() + y.min()) 154 | Zb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][2].flatten() + 0.5 * (z.max() + z.min()) 155 | 156 | for xb, yb, zb in zip(Xb, Yb, Zb): 157 | ax.plot([xb], [yb], [zb], 'w') 158 | 159 | plt.axis('equal') 160 | 161 | a = anim.FuncAnimation(fig, update, frames=len(poses), repeat=False) 162 | plt.show() 163 | plt.savefig('show/mean_train_pose_{}_{}'.format(subject, action), bbox_inches='tight') 164 | plt.close() 165 | 166 | return 167 | 168 | 169 | def drawskeleton(img, kps, thickness=3, lcolor=(255,0,0), rcolor=(0,0,255), mpii=2): 170 | 171 | if mpii == 0: # h36m with mpii joints 172 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 173 | [5, 6], [0, 8], [8, 9], [9, 10], 174 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 175 | LR = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=bool) 176 | elif mpii == 1: # only mpii 177 | connections = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7], 178 | [7, 8], [8, 9], [7, 12], [12, 11], [11, 10], [7, 13], [13, 14], [14, 15]] 179 | LR = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], dtype=bool) 180 | else: # default h36m 181 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 182 | [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], 183 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 184 | 185 | LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool) 186 | 187 | for j,c in enumerate(connections): 188 | start = map(int, kps[c[0]]) 189 | end = map(int, kps[c[1]]) 190 | start = list(start) 191 | end = list(end) 192 | cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness) 193 | 194 | 195 | def show3Dpose(channels, ax, radius=40, mpii=2, lcolor='#ff0000', rcolor='#0000ff'): 196 | vals = channels 197 | 198 | if mpii == 0: # h36m with mpii joints 199 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 200 | [5, 6], [0, 8], [8, 9], [9, 10], 201 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 
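# LR marks, per entry of `connections` above, whether that bone is drawn with lcolor (True) or rcolor (False) in the plotting loop below.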
202 | LR = np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=bool) 203 | elif mpii == 1: # only mpii 204 | connections = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7], 205 | [7, 8], [8, 9], [7, 12], [12, 11], [11, 10], [7, 13], [13, 14], [14, 15]] 206 | LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1], dtype=bool) 207 | else: # default h36m 208 | connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], 209 | [5, 6], [0, 7], [7, 8], [8, 9], [9, 10], 210 | [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]] 211 | 212 | LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool) 213 | 214 | for ind, (i,j) in enumerate(connections): 215 | x, y, z = [np.array([vals[i, c], vals[j, c]]) for c in range(3)] 216 | ax.plot(x, y, z, lw=2, c=lcolor if LR[ind] else rcolor) 217 | 218 | RADIUS = radius # space around the subject 219 | if mpii == 1: 220 | xroot, yroot, zroot = vals[6, 0], vals[6, 1], vals[6, 2] 221 | else: 222 | xroot, yroot, zroot = vals[0, 0], vals[0, 1], vals[0, 2] 223 | ax.set_xlim3d([-RADIUS + xroot, RADIUS + xroot]) 224 | ax.set_zlim3d([-RADIUS + zroot, RADIUS + zroot]) 225 | ax.set_ylim3d([-RADIUS + yroot, RADIUS + yroot]) 226 | 227 | ax.set_xlabel("x") 228 | ax.set_ylabel("y") 229 | ax.set_zlabel("z") 230 | -------------------------------------------------------------------------------- /common/dataset/h36m_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import copy 10 | from common.dataset.skeleton import Skeleton 11 | from common.dataset.mocap_dataset import MocapDataset 12 | 13 | h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 14 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 15 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 16 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 17 | 18 | h36m_cameras_intrinsic_params = [ 19 | { 20 | 'id': '54138969', 21 | 'center': [512.54150390625, 515.4514770507812], 22 | 'focal_length': [1145.0494384765625, 1143.7811279296875], 23 | 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], 24 | 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], 25 | 'res_w': 1000, 26 | 'res_h': 1002, 27 | 'azimuth': 70, # Only used for visualization 28 | }, 29 | { 30 | 'id': '55011271', 31 | 'center': [508.8486328125, 508.0649108886719], 32 | 'focal_length': [1149.6756591796875, 1147.5916748046875], 33 | 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], 34 | 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], 35 | 'res_w': 1000, 36 | 'res_h': 1000, 37 | 'azimuth': -70, # Only used for visualization 38 | }, 39 | { 40 | 'id': '58860488', 41 | 'center': [519.8158569335938, 501.40264892578125], 42 | 'focal_length': [1149.1407470703125, 1148.7989501953125], 43 | 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], 44 | 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], 45 | 'res_w': 1000, 46 | 'res_h': 1000, 47 | 'azimuth': 110, # Only used for visualization 48 | }, 49 | { 50 | 'id': '60457274', 51 | 'center': [514.9682006835938, 501.88201904296875], 52 | 
'focal_length': [1145.5113525390625, 1144.77392578125], 53 | 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], 54 | 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], 55 | 'res_w': 1000, 56 | 'res_h': 1002, 57 | 'azimuth': -110, # Only used for visualization 58 | }, 59 | ] 60 | 61 | h36m_cameras_extrinsic_params = { 62 | 'S1': [ 63 | { 64 | 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], 65 | 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], 66 | }, 67 | { 68 | 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], 69 | 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], 70 | }, 71 | { 72 | 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], 73 | 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], 74 | }, 75 | { 76 | 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435], 77 | 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], 78 | }, 79 | ], 80 | 'S2': [ 81 | {}, 82 | {}, 83 | {}, 84 | {}, 85 | ], 86 | 'S3': [ 87 | {}, 88 | {}, 89 | {}, 90 | {}, 91 | ], 92 | 'S4': [ 93 | {}, 94 | {}, 95 | {}, 96 | {}, 97 | ], 98 | 'S5': [ 99 | { 100 | 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], 101 | 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], 102 | }, 103 | { 104 | 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915], 105 | 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], 106 | }, 107 | { 108 | 'orientation': [0.14291371405124664, -0.12907841801643372, 0.7678384780883789, -0.6110143065452576], 109 | 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], 110 | }, 111 | { 112 | 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], 113 | 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], 114 | }, 115 | ], 116 | 'S6': [ 117 | { 118 | 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], 119 | 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], 120 | }, 121 | { 122 | 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], 123 | 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], 124 | }, 125 | { 126 | 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], 127 | 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], 128 | }, 129 | { 130 | 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], 131 | 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], 132 | }, 133 | ], 134 | 'S7': [ 135 | { 136 | 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], 137 | 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], 138 | }, 139 | { 140 | 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], 141 | 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], 142 | }, 143 | { 144 | 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, 
-0.6127139329910278], 145 | 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], 146 | }, 147 | { 148 | 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], 149 | 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], 150 | }, 151 | ], 152 | 'S8': [ 153 | { 154 | 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], 155 | 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], 156 | }, 157 | { 158 | 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], 159 | 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], 160 | }, 161 | { 162 | 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], 163 | 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], 164 | }, 165 | { 166 | 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708], 167 | 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], 168 | }, 169 | ], 170 | 'S9': [ 171 | { 172 | 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], 173 | 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], 174 | }, 175 | { 176 | 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], 177 | 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], 178 | }, 179 | { 180 | 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915], 181 | 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], 182 | }, 183 | { 184 | 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], 185 | 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], 186 | }, 187 | ], 188 | 'S11': [ 189 | { 190 | 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], 191 | 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], 192 | }, 193 | { 194 | 'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], 195 | 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], 196 | }, 197 | { 198 | 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], 199 | 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], 200 | }, 201 | { 202 | 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], 203 | 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], 204 | }, 205 | ], 206 | } 207 | 208 | 209 | class Human36mDataset(MocapDataset): 210 | def __init__(self, path, remove_static_joints=True): 211 | super().__init__(fps=50, skeleton=h36m_skeleton) 212 | 213 | self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) 214 | for cameras in self._cameras.values(): 215 | for i, cam in enumerate(cameras): 216 | cam.update(h36m_cameras_intrinsic_params[i]) 217 | for k, v in cam.items(): 218 | if k not in ['id', 'res_w', 'res_h']: 219 | cam[k] = np.array(v, dtype='float32') 220 | 221 | if 'translation' in cam: 222 | cam['translation'] = cam['translation'] / 1000 # mm to meters 223 | 224 | # Add intrinsic parameters vector 225 | cam['intrinsic'] = np.concatenate((cam['focal_length'], 226 | cam['center'], 227 | 
cam['radial_distortion'], 228 | cam['tangential_distortion'])) 229 | 230 | # Load serialized dataset 231 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 232 | 233 | self._data = {} 234 | for subject, actions in data.items(): 235 | self._data[subject] = {} 236 | for action_name, positions in actions.items(): 237 | self._data[subject][action_name] = { 238 | 'positions': positions, 239 | 'cameras': self._cameras[subject], 240 | } 241 | 242 | if remove_static_joints: 243 | # Bring the skeleton to 17 joints instead of the original 32 244 | self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) 245 | 246 | # Rewire shoulders to the correct parents 247 | self._skeleton._parents[11] = 8 248 | self._skeleton._parents[14] = 8 249 | 250 | def supports_semi_supervised(self): 251 | return True 252 | 253 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /common/common_pytorch/model/fc_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch 8 | import torch.nn as nn 9 | from common.arguments.basic_args import parse_args 10 | args = parse_args() 11 | 12 | class TemporalModelBase(nn.Module): 13 | """ 14 | Do not instantiate this class. 15 | """ 16 | 17 | def __init__(self, num_joints_in, in_features, num_joints_out, 18 | filter_widths, causal, dropout, channels): 19 | super().__init__() 20 | 21 | # Validate input 22 | for fw in filter_widths: 23 | assert fw % 2 != 0, 'Only odd filter widths are supported' 24 | 25 | self.num_joints_in = num_joints_in 26 | self.in_features = in_features 27 | self.num_joints_out = num_joints_out 28 | self.filter_widths = filter_widths 29 | 30 | self.drop = nn.Dropout(dropout) 31 | self.relu = nn.LeakyReLU(negative_slope=0.01,inplace=True) 32 | 33 | self.pad = [filter_widths[0] // 2] 34 | self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1) 35 | self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1) 36 | 37 | def set_bn_momentum(self, momentum): 38 | self.expand_bn.momentum = momentum 39 | for bn in self.layers_bn: 40 | bn.momentum = momentum 41 | 42 | def receptive_field(self): 43 | """ 44 | Return the total receptive field of this model as # of frames. 45 | """ 46 | frames = 0 47 | for f in self.pad: 48 | frames += f 49 | return 1 + 2 * frames 50 | 51 | def total_causal_shift(self): 52 | """ 53 | Return the asymmetric offset for sequence padding. 54 | The returned value is typically 0 if causal convolutions are disabled, 55 | otherwise it is half the receptive field. 56 | """ 57 | frames = self.causal_shift[0] 58 | next_dilation = self.filter_widths[0] 59 | for i in range(1, len(self.filter_widths)): 60 | frames += self.causal_shift[i] * next_dilation 61 | next_dilation *= self.filter_widths[i] 62 | return frames 63 | 64 | def forward(self, x): 65 | assert len(x.shape) == 4 66 | assert x.shape[-2] == self.num_joints_in 67 | assert x.shape[-1] == self.in_features 68 | 69 | sz = x.shape[:3] 70 | 71 | x_out = x 72 | 73 | x = x.view(x.shape[0], x.shape[1], -1) 74 | x = x.permute(0, 2, 1) 75 | 76 | y = self._forward_blocks(x) 77 | 78 | y = y.permute(0, 2, 1) #[1024,1,3K] 79 | 80 | y = y.view(sz[0], -1, self.num_joints_out, 3) 81 | if args.norm == 'lcn': 82 | pose_2d = x_out + y[...,:2] 83 | y = torch.cat([pose_2d, y[...,2:3]], dim=-1) 84 | return y 85 | 86 | 87 | class TemporalModel(TemporalModelBase): 88 | """ 89 | Reference 3D pose estimation model with temporal convolutions. 90 | This implementation can be used for all use-cases. 91 | """ 92 | 93 | def __init__(self, num_joints_in, in_features, num_joints_out, 94 | filter_widths, causal=False, dropout=0.25, channels=1024, dense=False): 95 | """ 96 | Initialize this model. 97 | 98 | Arguments: 99 | num_joints_in -- number of input joints (e.g. 
17 for Human3.6M) 100 | in_features -- number of input features for each joint (typically 2 for 2D input) 101 | num_joints_out -- number of output joints (can be different than input) 102 | filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field 103 | causal -- use causal convolutions instead of symmetric convolutions (for real-time applications) 104 | dropout -- dropout probability 105 | channels -- number of convolution channels 106 | dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment) 107 | """ 108 | super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels) 109 | 110 | self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0],groups=1, bias=False) 111 | 112 | layers_conv = [] 113 | layers_bn = [] 114 | 115 | self.causal_shift = [(filter_widths[0]) // 2 if causal else 0] 116 | next_dilation = filter_widths[0] 117 | for i in range(1, len(filter_widths)): 118 | self.pad.append((filter_widths[i] - 1) * next_dilation // 2) 119 | self.causal_shift.append((filter_widths[i] // 2 * next_dilation) if causal else 0) 120 | 121 | layers_conv.append(nn.Conv1d(channels, channels, 122 | filter_widths[i] if not dense else (2 * self.pad[-1] + 1), 123 | dilation=next_dilation if not dense else 1,groups=1, 124 | bias=False)) 125 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 126 | layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False,groups=1)) 127 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 128 | 129 | next_dilation *= filter_widths[i] 130 | 131 | self.layers_conv = nn.ModuleList(layers_conv) 132 | self.layers_bn = nn.ModuleList(layers_bn) 133 | 134 | def _forward_blocks(self, x): 135 | x = self.drop(self.relu(self.expand_bn(self.expand_conv(x)))) 136 | 137 | for i in range(len(self.pad) - 1): 138 | pad = self.pad[i + 1] 139 | shift = self.causal_shift[i + 1] 140 | res = x[:, :, pad + shift: x.shape[2] - pad + shift] 141 | 142 | x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x)))) 143 | x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x)))) 144 | 145 | x = self.shrink(x) 146 | return x 147 | 148 | 149 | class TemporalModelOptimized1f(TemporalModelBase): 150 | """ 151 | 3D pose estimation model optimized for single-frame batching, i.e. 152 | where batches have input length = receptive field, and output length = 1. 153 | This scenario is only used for training when stride == 1. 154 | 155 | This implementation replaces dilated convolutions with strided convolutions 156 | to avoid generating unused intermediate results. The weights are interchangeable 157 | with the reference implementation. 158 | """ 159 | 160 | def __init__(self, num_joints_in, in_features, num_joints_out, 161 | filter_widths, causal=False, dropout=0.25, channels=1024): 162 | """ 163 | Initialize this model. 164 | 165 | Arguments: 166 | num_joints_in -- number of input joints (e.g. 
17 for Human3.6M) 167 | in_features -- number of input features for each joint (typically 2 for 2D input) 168 | num_joints_out -- number of output joints (can be different than input) 169 | filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field 170 | causal -- use causal convolutions instead of symmetric convolutions (for real-time applications) 171 | dropout -- dropout probability 172 | channels -- number of convolution channels 173 | """ 174 | super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels) 175 | 176 | self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0], stride=filter_widths[0], groups=1, bias=False) 177 | 178 | layers_conv = [] 179 | layers_bn = [] 180 | 181 | self.causal_shift = [(filter_widths[0] // 2) if causal else 0] 182 | next_dilation = filter_widths[0] 183 | for i in range(1, len(filter_widths)): 184 | self.pad.append((filter_widths[i] - 1) * next_dilation // 2) 185 | self.causal_shift.append((filter_widths[i] // 2) if causal else 0) 186 | 187 | layers_conv.append(nn.Conv1d(channels, channels, filter_widths[i], stride=filter_widths[i], groups=1,bias=False)) 188 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 189 | layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1,groups=1, bias=False)) 190 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 191 | next_dilation *= filter_widths[i] 192 | 193 | self.layers_conv = nn.ModuleList(layers_conv) 194 | self.layers_bn = nn.ModuleList(layers_bn) 195 | 196 | def _forward_blocks(self, x): 197 | x = self.drop(self.relu(self.expand_bn(self.expand_conv(x)))) 198 | 199 | for i in range(len(self.pad) - 1): 200 | res = x[:, :, self.causal_shift[i + 1] + self.filter_widths[i + 1] // 2:: self.filter_widths[i + 1]] 201 | x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x)))) 202 | x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x)))) 203 | 204 | x = self.shrink(x) 205 | return x 206 | 207 | 208 | 209 | class Same_Model(TemporalModelBase): 210 | """ 211 | Reference 3D pose estimation model with temporal convolutions. 212 | This implementation can be used for all use-cases. 213 | The amount and type of padding can be changed; padding is applied before each convolution (self.zeropadding or self.rep_pad, with nn.ReplicationPad1d, nn.ConstantPad1d or nn.ReflectionPad1d): 214 | if padding == 0: the same setting as TemporalModel 215 | if padding == the dilation (or stride): the temporal size of the 2D input is kept 216 | """ 217 | 218 | def __init__(self, num_joints_in, in_features, num_joints_out, 219 | filter_widths, causal=False, dropout=0.25, channels=1024): 220 | """ 221 | Initialize this model. 222 | 223 | New Arguments: 224 | FlexGroupLayer: use this layer with different group strategies 225 | self.rep_pad: nn.ReflectionPad1d is recommended to keep the same temporal size as the 2D inputs.
226 | 227 | """ 228 | mode = 'replicate' #padding mode: reflect, replicate, zeros 229 | super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels) 230 | self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, kernel_size=filter_widths[0], bias=False) 231 | 232 | layers_conv = [] 233 | layers_bn = [] 234 | 235 | self.ref_pad = [] 236 | 237 | self.causal_shift = [(filter_widths[0]) // 2 if causal else 0] 238 | next_dilation = filter_widths[0] 239 | for i in range(1, len(filter_widths)): 240 | self.pad.append((filter_widths[i] - 1) * next_dilation // 2) 241 | self.causal_shift.append((filter_widths[i] // 2 * next_dilation) if causal else 0) 242 | layers_conv.append(nn.Conv1d(channels, channels, kernel_size=filter_widths[0], dilation=next_dilation, bias=False)) 243 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 244 | layers_conv.append(nn.Conv1d(channels, channels, kernel_size=1, dilation=1, bias=False)) 245 | layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) 246 | self.ref_pad.append(nn.ReplicationPad1d(next_dilation)) 247 | #self.ref_pad.append(nn.ReflectionPad1d(next_dilation)) 248 | next_dilation *= filter_widths[i] 249 | #self.reflec = nn.ReflectionPad1d(1) 250 | self.reflec = nn.ReplicationPad1d(1) 251 | self.layers_conv = nn.ModuleList(layers_conv) 252 | self.layers_bn = nn.ModuleList(layers_bn) 253 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 254 | self.final_layer = nn.Conv1d(channels, num_joints_out * 3, kernel_size=1, bias=True) 255 | 256 | def _forward_blocks(self, x): 257 | x = self.drop(self.relu(self.expand_bn(self.expand_conv(self.reflec(x))))) 258 | for i in range(len(self.pad) - 1): 259 | res = x 260 | x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](self.ref_pad[i](x))))) 261 | x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x)))) 262 | x = self.final_layer(x) 263 | return x 264 | 265 | -------------------------------------------------------------------------------- /common/arguments/basic_args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Training SRNet script') 12 | 13 | # General arguments 14 | parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') 15 | parser.add_argument('-k', '--keypoints', default='gt', type=str, metavar='NAME', 16 | help='2D detections to use', choices=['gt','cpn_ft_h36m_dbb']) 17 | parser.add_argument('--rand-seed', default=4321, type=int, metavar='N', help='random seeds') 18 | ### Protocol settings 19 | # Differ from subjects (people), e.g. 
standard Protocol 1 (MPJPE) & Protocol 2 (PA-MPJPE) # S5,S6,S7,S8 20 | parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST', 21 | help='training subjects separated by comma') 22 | parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', 23 | help='test subjects separated by comma') 24 | parser.add_argument('--subjects-full', default='S1,S5,S6,S7,S8,S9,S11', type=str, metavar='LIST', 25 | help='all subjects separated by comma') 26 | parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST', 27 | help='unlabeled subjects separated by comma for self-supervision') 28 | 29 | # Differ from actions, e.g. cross-action validation protocol (one action for training, others for test) 30 | parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST', 31 | help='actions to train/test on, separated by comma, or * for all') 32 | parser.add_argument('--use-action-split', default=False, help='Train on some actions, test on others') 33 | parser.add_argument('--train-action', default='Discussion', type=str, metavar='LIST', 34 | help='action name for training') 35 | parser.add_argument('--test-action', 36 | default='Greeting,Sitting,SittingDown,WalkTogether,Phoning,Posing,WalkDog,Walking,Purchases,Waiting,Directions,Smoking,Photo,Eating', 37 | type=str, metavar='LIST', help='action names for test') 38 | parser.add_argument('--all-action', 39 | default='Greeting,Sitting,SittingDown,WalkTogether,Phoning,Posing,WalkDog,Walking,Purchases,Waiting,Directions,Smoking,Photo,Eating,Discussion', 40 | type=str, metavar='LIST', help='all action names') 41 | parser.add_argument('--action_unlabeled', default='', type=str, metavar='LIST', help='unlabeled actions separated by comma for self-supervision') 42 | 43 | # Differ from camera settings, e.g.
cross-camera validation 44 | parser.add_argument('--cam-test', default='', type=list, metavar='LIST', 45 | help='test camera viewpoint; if None, use all cameras; if [5], choose four of them randomly', 46 | choices=[0, 1, 2, 3, 5]) 47 | parser.add_argument('--cam-train', default='', type=list, metavar='LIST', 48 | help='train camera viewpoint; if None, use all cameras; if [5], choose four of them randomly', 49 | choices=[0, 1, 2, 3, 5]) 50 | 51 | #### The data to test: 52 | parser.add_argument('--three-dpw', default=False, help='Cross train/test on the 3DPW test set') 53 | parser.add_argument('--use-hard-test', default=False, 54 | help='For evaluation, use the rarest N% of the S9/S11 test set') 55 | parser.add_argument('--use-mpi-test', default=False, help='Cross test on MPI-INF-3DHP') 56 | 57 | #### Data normalization 58 | parser.add_argument('--norm', choices=['base', 'proj', 'weak_proj', 'lcn'], type=str, help='way of data normalization', default='base') 59 | 60 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 61 | help='checkpoint directory to store models') 62 | parser.add_argument('-bc', '--best-checkpoint', default='best_checkpoint', type=str, metavar='PATH', 63 | help='best checkpoint directory to store the best models') 64 | parser.add_argument('--checkpoint-frequency', default=1, type=int, metavar='N', 65 | help='create a checkpoint every N epochs') 66 | parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', 67 | help='checkpoint to resume (file name)') 68 | parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', 69 | help='checkpoint to evaluate (file name)') 70 | parser.add_argument('-ft', '--finetune', default='', type=str, metavar='FILENAME', 71 | help='checkpoint to finetune (file name)') 72 | parser.add_argument('--render', action='store_true', help='visualize a particular video') 73 | parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)') 74 | parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images') 75 | 76 | # Model arguments 77 | parser.add_argument('--model', default='srnet', type=str, choices=['srnet', 'fc'], 78 | help='which model to train') 79 | parser.add_argument('-mn', '--model-name', default='sr_h36m_gt2d', type=str, 80 | help='name under which to save the trained model') 81 | parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training') 82 | parser.add_argument('-e', '--epochs', default=60, type=int, metavar='N', help='number of training epochs') 83 | parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', 84 | help='batch size in terms of predicted frames') 85 | parser.add_argument('-tb', '--test-batch-size', default=100, type=int, metavar='N', 86 | help='test batch size in terms of predicted frames') 87 | parser.add_argument('-arc', '--architecture', default='1,1,1', type=str, metavar='LAYERS', 88 | help='filter widths separated by comma') 89 | parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing') 90 | parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', 91 | help='number of channels in convolution layers') 92 | 93 | parser.add_argument('-drop', '--dropout', default=0, type=float, metavar='P', help='dropout probability') 94 | parser.add_argument('-lr', '--learning-rate', default=0.001, type=float,
metavar='LR', help='initial learning rate') 95 | parser.add_argument('-lrd', '--lr-decay', default=0.95, type=float, metavar='LR', 96 | help='learning rate decay per epoch') 97 | parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false', 98 | help='disable train-time flipping') 99 | parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false', 100 | help='disable test-time flipping') 101 | parser.add_argument('--conf', default=0, type=int, metavar='N',help='confidence score number') 102 | #### basic model settings, Experimental 103 | parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction') 104 | parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', 105 | help='downsample frame rate by factor (semi-supervised)') 106 | parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision') 107 | parser.add_argument('--no-eval', action='store_true', 108 | help='disable epoch evaluation while training (small speed-up)') 109 | parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions') 110 | parser.add_argument('--disable-optimizations', action='store_true', 111 | help='disable optimized model for single-frame predictions') 112 | parser.add_argument('--linear-projection', action='store_true', 113 | help='use only linear coefficients for semi-supervised projection') 114 | parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term', 115 | help='disable bone length term in semi-supervised settings') 116 | parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting') 117 | parser.add_argument('--root-log', default='log', type=str) 118 | 119 | parser.add_argument('--train-rotation', default=False, 120 | help='Use random Y-axis rotation at training time; disable train-time flip augmentation when this is enabled') 121 | parser.add_argument('--repeat-num', default=1, type=int, metavar='N', help='number of repeated rotations') 122 | 123 | # Temporal Pose settings 124 | parser.add_argument('--use-same-3d-input', default=False, help='make the input frame number equal to the output frame number') 125 | 126 | #### For smooth loss: 127 | parser.add_argument('--threshold', default=0.0004, type=float, metavar='LR', 128 | help='The threshold of the smooth loss used to control the loss functions') 129 | parser.add_argument('--mi', default=0.1, type=float, metavar='LR', help='The exponent of the smooth loss') 130 | 131 | parser.add_argument('--scale', default=0.001, type=float, metavar='LR', help='') 132 | parser.add_argument('--rnum', default=0, type=int, metavar='LR', help='') 133 | 134 | # Render function Visualization 135 | parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render') 136 | parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render') 137 | parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render') 138 | parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video') 139 | parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video') 140 | parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)') 141 | parser.add_argument('--viz-export', type=str, metavar='PATH', help='output file name for coordinates')
142 | parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos') 143 | parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses') 144 | parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames') 145 | parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N') 146 | parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size') 147 | 148 | parser.set_defaults(bone_length_term=True) 149 | parser.set_defaults(data_augmentation=True) 150 | parser.set_defaults(test_time_augmentation=True) 151 | 152 | ### SRNet arguments 153 | ### split features 154 | parser.add_argument('-mo', '--modulation', default=False, 155 | help='Use the modulation module as temporal-mask self-attention that multiplies the whole channel [all joint inputs]') 156 | parser.add_argument('--group-modulation', default=False, 157 | help='Use the modulation module to multiply each group as local attention [group-wise joint inputs]') 158 | parser.add_argument('--split-modulation', default=True, 159 | help='Use the modulation module to multiply each group as global attention [all except the local joint inputs]') 160 | parser.add_argument('--channelwise', default=False, 161 | help='Use the modulation module to multiply each group with channel-wise attention [all joint inputs]') 162 | ### recombine feature source 163 | 164 | parser.add_argument('--split', choices=['all', 'others', 'none'], type=str, 165 | help='way of splitting features', default='others') 166 | 167 | ### recombine operators 168 | parser.add_argument('--recombine', choices=['multiply', 'add', 'concat'], type=str, 169 | help='how to recombine the low-dimensional global features with the local features', default='multiply') 170 | parser.add_argument('--mean-func', default=False, help='Use mean function [other joint inputs]') 171 | parser.add_argument('--repeat-concat', default=False, 172 | help='Repeat-and-concatenate to fuse the group feature with the other-joint features; requires the concat recombine mode') 173 | 174 | parser.add_argument('--ups-mean', default=False, help='Use flexible mean function [other joint inputs]') 175 | 176 | # Group number 177 | parser.add_argument('--group', type=int, default=5, metavar='N', help='number of joint groups, which selects the grouping strategy',choices=[1,2,3,5]) 178 | 179 | args = parser.parse_args() 180 | # Check invalid configuration 181 | if args.resume and args.evaluate: 182 | print('Invalid flags: --resume and --evaluate cannot be set at the same time') 183 | exit() 184 | 185 | if args.export_training_curves and args.no_eval: 186 | print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time') 187 | exit() 188 | 189 | return args 190 | --------------------------------------------------------------------------------
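A minimal usage sketch of how the files above fit together, assuming the repository root is on PYTHONPATH and the script is run without extra command-line flags so that parse_args() falls back to its defaults. The '-arc' string becomes the filter_widths list, and receptive_field() reports how many input frames feed one output frame (1 for '1,1,1', 243 for '3,3,3,3,3'); the model used here is the reference fc_baseline.TemporalModel shown earlier, not the SRNet variant.

# Illustrative sketch only: wiring basic_args.parse_args() into fc_baseline.TemporalModel.
# Assumes the repo root is on PYTHONPATH and no CLI flags are passed (argparse defaults apply).
import torch

from common.arguments.basic_args import parse_args
from common.common_pytorch.model.fc_baseline import TemporalModel

args = parse_args()                                       # defaults: architecture='1,1,1', channels=1024, dropout=0
filter_widths = [int(fw) for fw in args.architecture.split(',')]

model = TemporalModel(num_joints_in=17, in_features=2, num_joints_out=17,
                      filter_widths=filter_widths, causal=args.causal,
                      dropout=args.dropout, channels=args.channels)
model.eval()

receptive_field = model.receptive_field()                 # 1 for '1,1,1', 243 for '3,3,3,3,3'
x = torch.randn(2, receptive_field, 17, 2)                # (batch, frames, joints, xy)
with torch.no_grad():
    y = model(x)                                          # (batch, 1, joints, xyz)
print(receptive_field, y.shape)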
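The docstring of TemporalModelOptimized1f notes that its weights are interchangeable with the reference implementation: both models register the same parameter names (expand_conv, expand_bn, layers_conv, layers_bn, shrink) with identical shapes, so a model trained with the strided variant can be loaded into the dilated variant for full-sequence evaluation. A sketch of that hand-off, under the same assumptions as above:

# Illustrative sketch only: copy weights from the strided training model into the dilated
# reference model. These filter widths correspond to a 243-frame receptive field.
from common.common_pytorch.model.fc_baseline import TemporalModel, TemporalModelOptimized1f

filter_widths = [3, 3, 3, 3, 3]
model_train = TemporalModelOptimized1f(17, 2, 17, filter_widths)  # strided: one output frame per sample
model_eval = TemporalModel(17, 2, 17, filter_widths)              # dilated: arbitrary sequence length

# ... train model_train here ...
model_eval.load_state_dict(model_train.state_dict())             # parameter names and shapes match one-to-one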
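On the dataset side, Human36mDataset normalizes the hard-coded camera parameters at construction time: translations are converted from millimetres to metres, the remaining numeric entries become float32 arrays, and a 9-element 'intrinsic' vector is assembled as focal_length (2) + center (2) + radial_distortion (3) + tangential_distortion (2). A small sketch of inspecting the result; the archive path is an assumption (any .npz with a 'positions_3d' dictionary produced by the pre-processing scripts works), and the access goes through the _cameras attribute set in the constructor shown above.

# Illustrative sketch only: inspect the normalized cameras built in Human36mDataset.__init__.
from common.dataset.h36m_dataset import Human36mDataset

dataset = Human36mDataset('data/data_3d_h36m.npz')   # path is an assumption; see the pre-processing scripts
cam = dataset._cameras['S1'][0]                      # first of the four Human3.6M camera views

print(cam['translation'])                            # in metres now, roughly [1.84, 4.96, 1.56] for this view
print(cam['intrinsic'].shape)                        # (9,) = focal(2) + center(2) + radial(3) + tangential(2)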