├── models
│   ├── __init__.py
│   ├── Orthogonal.py
│   ├── motion_pred.py
│   ├── LinNF.py
│   └── model.py
├── utils
│   ├── __init__.py
│   ├── logger.py
│   ├── metrics.py
│   ├── vis_util.py
│   ├── util.py
│   ├── torch.py
│   ├── vis_poses.py
│   └── valid_angle_check.py
├── images
│   └── intro.png
├── motion_pred
│   ├── cfg
│   │   ├── h36m.yml
│   │   └── humaneva.yml
│   └── utils
│       ├── config.py
│       ├── dataset.py
│       ├── dataset_humaneva.py
│       ├── dataset_h36m.py
│       ├── skeleton.py
│       ├── visualization.py
│       ├── dataset_humaneva_multimodal.py
│       └── dataset_h36m_multimodal.py
├── LICENSE
├── README.md
├── train_nf.py
└── main.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.torch import *
2 | from utils.logger import *
--------------------------------------------------------------------------------
/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoweiXu368/SLD-HMP/HEAD/images/intro.png
--------------------------------------------------------------------------------
/motion_pred/cfg/h36m.yml:
--------------------------------------------------------------------------------
1 | dataset: h36m
2 | nz: 128
3 | t_his: 25
4 | t_pred: 100
5 | dropout: 0.1
6 | nk: 50
7 | nk1: 5
8 | nk2: 10
9 | lambdas: [ 0, 500.0, 0, 64, 2.0, 1.0, 0.01, 50, 100]
10 | specs:
11 |   num_flow_layer: 1
12 |   n_pre: 20
13 |   multimodal_path: ./data/data_multi_modal/t_his25_1_thre0.500_t_pred100_thre0.100_filtered_dlow.npz
14 |   data_candi_path: ./data/data_multi_modal/data_candi_t_his25_t_pred100_skiprate20.npz
15 | 
16 | lr: 1.e-3
17 | batch_size: 16
18 | num_epoch: 500
19 | num_epoch_fix: 100
20 | num_data_sample: 5000
21 | normalize_data: False
--------------------------------------------------------------------------------
/motion_pred/cfg/humaneva.yml:
--------------------------------------------------------------------------------
1 | dataset: humaneva
2 | nz: 128
3 | t_his: 15
4 | t_pred: 60
5 | dropout: 0.1
6 | nk: 50
7 | nk1: 5
8 | nk2: 10
9 | lambdas: [ 0, 50.0, 0, 16, 8.0, 4.0, 0.002, 10, 10]
10 | specs:
11 |   num_flow_layer: 1
12 |   n_pre: 16
13 |   multimodal_path: ./data/humaneva_multi_modal/t_his15_1_thre0.500_t_pred60_thre0.010_index_filterd.npz
14 |   data_candi_path: ./data/humaneva_multi_modal/data_candi_t_his15_t_pred60_skiprate15.npz
15 | 
16 | lr: 1.e-3
17 | batch_size: 16
18 | num_epoch: 500
19 | num_epoch_fix: 100
20 | num_data_sample: 2000
21 | normalize_data: False
22 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Guowei Xu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/models/Orthogonal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class Orthogonal(nn.Module):
6 |     """Orthogonal linear map parameterized as a product of d Householder reflections."""
7 | 
8 |     def __init__(self, d):
9 |         super(Orthogonal, self).__init__()
10 |         self.d = d
11 |         # each row of U is one Householder reflection vector
12 |         self.U = torch.nn.Parameter(torch.zeros((d, d)).normal_(0, 0.5))
13 | 
14 |         self.reset_parameters()
15 | 
16 |     def reset_parameters(self):
17 |         # normalize every reflection vector to unit norm
18 |         self.U.data = self.U.data / torch.norm(self.U.data, dim=1, keepdim=True)
19 | 
20 |     def sequential_mult(self, V, X):
21 |         # apply the Householder reflections in V row by row, from last to first:
22 |         # X <- X (I - 2 v^T v / (v v^T)) for each row vector v
23 |         for row in range(V.shape[0] - 1, -1, -1):
24 |             X = X - 2 * (X @ V[row:row + 1, :].t() @ V[row:row + 1, :]) / \
25 |                 (V[row:row + 1, :] @ V[row:row + 1, :].t())[0, 0]
26 |         return X
27 | 
28 |     def forward(self, X, invert=False):
29 |         """
30 |         @param X: [batch, d] input; multiplied by the orthogonal matrix (or its inverse if invert)
31 |         @return: transformed tensor of the same shape
32 |         """
33 |         if not invert:
34 |             X = self.sequential_mult(self.U, X)
35 |         else:
36 |             X = self.inverse(X)
37 |         return X
38 | 
39 |     def inverse(self, X):
40 |         # the inverse of a product of reflections is the same reflections in reverse order
41 |         X = self.sequential_mult(torch.flip(self.U, dims=[0]), X)
42 |         return X
43 | 
44 |     def lgdet(self, X):
45 |         # an orthogonal map has |det| = 1, so log|det| = 0
46 |         return 0
--------------------------------------------------------------------------------
/models/motion_pred.py:
--------------------------------------------------------------------------------
1 | from models import LinNF, model
2 | import numpy as np
3 | 
4 | def get_model(cfg, dataset, model_type='h36m'):
5 |     traj_dim = dataset.traj_dim // 3
6 |     specs = {'model_name': 'NFDiag', 'rnn_type': 'gru', 'nh_mlp': [1024, 512], 'x_birnn': False}
7 |     if model_type == 'h36m' or model_type == 'humaneva':
8 |         keep_joints = dataset.kept_joints[1:]
9 | 
10 |         if model_type == 'h36m':
11 |             parents = [-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
12 |                        16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30]
13 |             joints_left = [6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23]
14 |             joints_right = [1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]
15 |         elif model_type == 'humaneva':
16 |             parents = [-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1]
17 |             joints_left = [2, 3, 4, 8, 9, 10]
18 |             joints_right = [5, 6, 7, 11, 12, 13]
19 | 
20 |         pose_info = {'keep_joints': keep_joints,
21 |                      'parents': parents,
22 |                      'joints_left': joints_left,
23 |                      'joints_right': joints_right}
24 |         print("Human pose information: ", pose_info)
25 |         return model.Model(traj_dim * 3, 128, input_channels=3, st_gcnn_dropout=cfg.dropout, joints_to_consider=traj_dim, pose_info=pose_info), \
26 |                LinNF.LinNF(data_dim=traj_dim * 3, num_layer=3)
27 | 
28 |     elif model_type == 'h36m_nf' or model_type == 'humaneva_nf':
29 |         return LinNF.LinNF(data_dim=dataset.traj_dim, num_layer=cfg.specs['num_flow_layer'])
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | 
5 | 
6 | def create_logger(filename, file_handle=True):
7 |     # create
logger 8 | logger = logging.getLogger(filename) 9 | logger.propagate = False 10 | logger.setLevel(logging.DEBUG) 11 | # create console handler with a higher log level 12 | ch = logging.StreamHandler() 13 | ch.setLevel(logging.INFO) 14 | stream_formatter = logging.Formatter('%(message)s') 15 | ch.setFormatter(stream_formatter) 16 | logger.addHandler(ch) 17 | 18 | if file_handle: 19 | # create file handler which logs even debug messages 20 | os.makedirs(os.path.dirname(filename), exist_ok=True) 21 | fh = logging.FileHandler(filename, mode='a') 22 | fh.setLevel(logging.DEBUG) 23 | file_formatter = logging.Formatter('[%(asctime)s] %(message)s') 24 | fh.setFormatter(file_formatter) 25 | logger.addHandler(fh) 26 | 27 | return logger 28 | 29 | 30 | def combine_dict(new_dict, old_dict=None): 31 | if old_dict is None: 32 | return new_dict 33 | for k in old_dict.keys(): 34 | if k in new_dict.keys(): 35 | old_dict[k] = np.concatenate([old_dict[k], new_dict[k]], axis=0) 36 | return old_dict 37 | 38 | 39 | class AverageMeter(object): 40 | """Computes and stores the average and current value""" 41 | 42 | def __init__(self): 43 | self.reset() 44 | 45 | def reset(self): 46 | self.val = 0 47 | self.avg = 0 48 | self.sum = 0 49 | self.count = 0 50 | 51 | def update(self, val, n=1): 52 | self.val = val 53 | self.sum += val * n 54 | self.count += n 55 | self.avg = self.sum / self.count 56 | -------------------------------------------------------------------------------- /motion_pred/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | 5 | class Config: 6 | 7 | def __init__(self, cfg_id, test=False, nf=False): 8 | self.id = cfg_id 9 | cfg_name = 'motion_pred/cfg/%s.yml' % cfg_id 10 | if not os.path.exists(cfg_name): 11 | print("Config file doesn't exist: %s" % cfg_name) 12 | exit(0) 13 | cfg = yaml.safe_load(open(cfg_name, 'r')) 14 | if nf: 15 | cfg_id += '_nf' 16 | # create dirs 17 | 18 | self.base_dir = 'results' 19 | self.cfg_dir = '%s/%s' % (self.base_dir, cfg_id) 20 | self.model_dir = '%s/models' % self.cfg_dir 21 | self.result_dir = '%s/results' % self.cfg_dir 22 | self.log_dir = '%s/log' % self.cfg_dir 23 | self.tb_dir = '%s/tb' % self.cfg_dir 24 | os.makedirs(self.model_dir, exist_ok=True) 25 | os.makedirs(self.result_dir, exist_ok=True) 26 | os.makedirs(self.log_dir, exist_ok=True) 27 | os.makedirs(self.tb_dir, exist_ok=True) 28 | 29 | # common 30 | self.dataset = cfg.get('dataset', 'h36m') 31 | self.batch_size = cfg.get('batch_size', 8) 32 | self.normalize_data = cfg.get('normalize_data', False) 33 | self.save_model_interval = cfg.get('save_model_interval', 20) 34 | self.t_his = cfg['t_his'] 35 | self.t_pred = cfg['t_pred'] 36 | self.use_vel = cfg.get('use_vel', False) 37 | 38 | self.nz = cfg['nz'] 39 | self.lr = cfg['lr'] 40 | self.dropout = cfg.get('dropout', 0.1) 41 | self.num_epoch = cfg['num_epoch'] 42 | self.num_epoch_fix = cfg.get('num_epoch_fix', self.num_epoch) 43 | self.num_data_sample = cfg['num_data_sample'] 44 | self.model_path = os.path.join(self.model_dir, '%04d.p') 45 | self.poses_prediction_path = os.path.join(self.model_dir, 'poses_prediction%04d.p') 46 | self.poses_vae_path = os.path.join(self.model_dir, 'poses_vae_%04d.p') 47 | self.poses_mapping_path = os.path.join(self.model_dir, 'poses_mapping_%04d.p') 48 | self.nk = cfg.get('nk', 10) 49 | self.nk1 = cfg.get('nk1', 5) 50 | self.nk2 = cfg.get('nk2', 2) 51 | self.lambdas = cfg.get('lambdas', []) 52 | 53 | self.specs = cfg.get('specs', dict()) 54 | 
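
A minimal usage sketch for the `Config` class above, assuming the repo root as the working directory; the `'h36m'` id and the values in the comments come from `motion_pred/cfg/h36m.yml` shown earlier:

```python
# Hedged illustration, not repo code: constructing a Config and reading
# the fields that the training/testing scripts would typically consume.
from motion_pred.utils.config import Config

cfg = Config('h36m')                 # parses motion_pred/cfg/h36m.yml, creates results/h36m/{models,results,log,tb}
print(cfg.t_his, cfg.t_pred)         # 25 100 -- observed / predicted horizon in frames
print(cfg.specs['num_flow_layer'])   # 1 -- model-specific options live in the 'specs' block
ckpt = cfg.model_path % 500          # 'results/h36m/models/0500.p'
```
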
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import pdist
3 | 
4 | def mpjpe_error(pred, gt, *args):  # best-of-K joint error on the final frame only, in mm
5 |     indexs = np.zeros(pred.shape[0])
6 |     sample_num, total_len, feature_size = pred.shape
7 |     pred = pred.reshape([sample_num, total_len, feature_size//3, 3])[:, -1:, :, :] * 1000
8 |     gt = gt.reshape([1, total_len, feature_size//3, 3])[:, -1:, :, :] * 1000
9 |     dist = np.linalg.norm(pred - gt, axis=3).mean(axis=2).mean(axis=1)
10 |     index = dist.argmin()
11 |     indexs[index] += 1
12 |     return dist[index], indexs
13 | 
14 | def compute_diversity(pred, *args):
15 |     if pred.shape[0] == 1:
16 |         return 0.0, None  # keep the (value, index) return convention of the other metrics
17 |     dist = pdist(pred.reshape(pred.shape[0], -1))
18 |     diversity = dist.mean().item()
19 |     return diversity, None
20 | 
21 | 
22 | def compute_ade(pred, gt, *args):  # best-of-K average displacement error
23 |     indexs = np.zeros(pred.shape[0])
24 |     diff = pred - gt
25 |     dist = np.linalg.norm(diff, axis=2).mean(axis=1)
26 |     index = dist.argmin()
27 |     indexs[index] += 1
28 |     return dist[index], indexs
29 | 
30 | 
31 | def compute_fde(pred, gt, *args):  # best-of-K final displacement error
32 |     indexs = np.zeros(pred.shape[0])
33 |     diff = pred - gt
34 |     dist = np.linalg.norm(diff, axis=2)[:, -1]
35 |     index = dist.argmin()
36 |     indexs[index] += 1
37 |     return dist[index], indexs
38 | 
39 | def compute_amse(pred, gt, *args):
40 |     diff = pred - gt  # sample_num * total_len * ((num_key-1)*3)
41 |     dist = (diff*diff).sum() / diff.shape[0]
42 |     return dist.mean(), None
43 | 
44 | 
45 | def compute_fmse(pred, gt, *args):
46 |     diff = pred[:, -1, :] - gt[:, -1, :]  # sample_num * total_len * ((num_key-1)*3)
47 |     dist = (diff*diff).sum() / diff.shape[0]
48 |     return dist.mean(), None
49 | 
50 | def compute_mmade(pred, gt, gt_multi):  # multi-modal ADE: average best-of-K ADE over all pseudo-GT futures
51 |     gt_dist = []
52 |     indexs = np.zeros(pred.shape[0])
53 |     for gt_multi_i in gt_multi:
54 |         dist, index = compute_ade(pred, gt_multi_i)
55 |         gt_dist.append(dist)
56 |         indexs += index
57 |     gt_dist = np.array(gt_dist).mean()
58 |     return gt_dist, indexs
59 | 
60 | 
61 | def compute_mmfde(pred, gt, gt_multi):  # multi-modal FDE, analogous to compute_mmade
62 |     gt_dist = []
63 |     indexs = np.zeros(pred.shape[0])
64 |     for gt_multi_i in gt_multi:
65 |         dist, index = compute_fde(pred, gt_multi_i)
66 |         gt_dist.append(dist)
67 |         indexs += index
68 |     gt_dist = np.array(gt_dist).mean()
69 |     return gt_dist, indexs
--------------------------------------------------------------------------------
/motion_pred/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class Dataset:
5 | 
6 |     def __init__(self, mode, t_his, t_pred, actions='all'):
7 |         self.mode = mode
8 |         self.t_his = t_his
9 |         self.t_pred = t_pred
10 |         self.t_total = t_his + t_pred
11 |         self.actions = actions
12 |         self.prepare_data()
13 |         self.std, self.mean = None, None
14 |         self.data_len = sum([seq.shape[0] for data_s in self.data.values() for seq in data_s.values()])
15 |         self.traj_dim = (self.kept_joints.shape[0] - 1) * 3
16 |         self.normalized = False
17 |         # iterator specific
18 |         self.sample_ind = None
19 | 
20 |     def prepare_data(self):
21 |         raise NotImplementedError
22 | 
23 |     def normalize_data(self, mean=None, std=None):
24 |         if mean is None:
25 |             all_seq = []
26 |             for data_s in self.data.values():
27 |                 for seq in data_s.values():
28 |                     all_seq.append(seq[:, 1:])
29 |             all_seq = np.concatenate(all_seq)
30 |             self.mean = all_seq.mean(axis=0)
31 |             self.std = all_seq.std(axis=0)
32 |         else:
33 |             self.mean = mean
34 |             self.std = std
35 |         for data_s in self.data.values():
36 |             for action in data_s.keys():
37 |                 data_s[action][:, 1:] = (data_s[action][:, 1:] - self.mean) / self.std
38 |         self.normalized = True
39 | 
40 |     def sample(self):
41 |         subject = np.random.choice(self.subjects)
42 |         dict_s = self.data[subject]
43 |         action = np.random.choice(list(dict_s.keys()))
44 |         seq = dict_s[action]
45 |         fr_start = np.random.randint(seq.shape[0] - self.t_total)
46 |         fr_end = fr_start + self.t_total
47 |         traj = seq[fr_start: fr_end]
48 |         return traj[None, ...]
49 | 
50 |     def sampling_generator(self, num_samples=1000, batch_size=8):
51 |         for i in range(num_samples // batch_size):
52 |             sample = []
53 |             for _ in range(batch_size):
54 |                 sample_i = self.sample()
55 |                 sample.append(sample_i)
56 |             sample = np.concatenate(sample, axis=0)
57 |             yield sample
58 | 
59 |     def iter_generator(self, step=25):
60 |         for data_s in self.data.values():
61 |             for seq in data_s.values():
62 |                 seq_len = seq.shape[0]
63 |                 for i in range(0, seq_len - self.t_total, step):
64 |                     traj = seq[None, i: i + self.t_total]
65 |                     yield traj
66 | 
67 | 
68 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction" (**ECCV 2024**)
2 | 
3 | 
4 | 
5 | ---
6 | This repo contains the official implementation of the paper:
7 | 
8 | Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction
9 | 
10 | ECCV 2024
11 | [[arxiv](https://arxiv.org/abs/2407.11494)]
12 | ### Dependencies
13 | * Python >= 3.8
14 | * [PyTorch](https://pytorch.org) >= 1.9
15 | * Tensorboard
16 | * matplotlib
17 | * tqdm
18 | * argparse
19 | 
20 | ### Get the data
21 | We adapt the data preprocessing from [GSPS](https://github.com/wei-mao-2019/gsps).
22 | * We follow the data preprocessing steps ([DATASETS.md](https://github.com/facebookresearch/VideoPose3D/blob/master/DATASETS.md)) inside the [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) repo.
23 | * Given the processed dataset, we further compute the multi-modal future for each motion sequence. All required data can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1sb1n9l0Na5EqtapDVShOJJ-v6o-GZrIJ?usp=sharing) and should be placed in the ``data`` folder in the root of this repo.
24 | 
25 | ### Get the pretrained models
26 | * All pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1YAa3Lpei0V3-JTEZwSw0WqRfSTYbY-z2?usp=drive_link) and should be placed in the ``results`` folder in the root of this repo.
27 | 
28 | ### Train
29 | We use the following commands to train the network on Human3.6M or HumanEva-I with the skeleton representation:
30 | ```bash
31 | python train_nf.py --cfg [h36m/humaneva] --gpu_index 0
32 | python main.py --cfg [h36m/humaneva] --gpu_index 0
33 | ```
34 | ### Test
35 | To test a pretrained model, we use the following command:
36 | ```bash
37 | python main.py --cfg [h36m/humaneva] --mode test --iter 500 --gpu_index 0
38 | ```
39 | ### Visualization
40 | To visualize predictions from a pretrained model, we use the following command:
41 | 
42 | ```bash
43 | python main.py --cfg [h36m/humaneva] --mode viz --iter 500 --gpu_index 0
44 | ```
45 | 
46 | ### Acknowledgments
47 | 
48 | This code is based on the implementations of [STARS](https://github.com/Sirui-Xu/STARS).
49 | 50 | ## Citation 51 | If you find this work useful in your research, please cite: 52 | 53 | ```bibtex 54 | @article{xu2024learning, 55 | title={Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction}, 56 | author={Xu, Guowei and Tao, Jiale and Li, Wen and Duan, Lixin}, 57 | journal={arXiv preprint arXiv:2407.11494}, 58 | year={2024} 59 | } 60 | ``` 61 | 62 | ## License 63 | 64 | This repo is distributed under an [MIT LICENSE](LICENSE) 65 | -------------------------------------------------------------------------------- /motion_pred/utils/dataset_humaneva.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from motion_pred.utils.dataset import Dataset 4 | from motion_pred.utils.skeleton import Skeleton 5 | 6 | 7 | class DatasetHumanEva(Dataset): 8 | 9 | def __init__(self, mode, t_his=15, t_pred=60, actions='all', **kwargs): 10 | super().__init__(mode, t_his, t_pred, actions) 11 | 12 | def prepare_data(self): 13 | self.data_file = os.path.join('data', 'data_3d_humaneva15.npz') 14 | self.subjects_split = {'train': ['Train/S1', 'Train/S2', 'Train/S3'], 15 | 'test': ['Validate/S1', 'Validate/S2', 'Validate/S3']} 16 | self.subjects = [x for x in self.subjects_split[self.mode]] 17 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1], 18 | joints_left=[2, 3, 4, 8, 9, 10], 19 | joints_right=[5, 6, 7, 11, 12, 13]) 20 | self.kept_joints = np.arange(15) 21 | self.process_data() 22 | 23 | def process_data(self): 24 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item() 25 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items())) 26 | # these takes have wrong head position, excluded from training and testing 27 | if self.mode == 'train': 28 | data_f['Train/S3'].pop('Walking 1 chunk0') 29 | data_f['Train/S3'].pop('Walking 1 chunk2') 30 | else: 31 | data_f['Validate/S3'].pop('Walking 1 chunk4') 32 | for key in list(data_f.keys()): 33 | # data_f[key] = dict(filter(lambda x: (self.actions == 'all' or 34 | # all([a in x[0] for a in self.actions])) 35 | # and x[1].shape[0] >= self.t_total, data_f[key].items())) 36 | data_f[key] = dict(filter(lambda x: (self.actions == 'all' or 37 | any([a in x[0] for a in self.actions])) 38 | and x[1].shape[0] >= self.t_total, data_f[key].items())) 39 | if len(data_f[key]) == 0: 40 | data_f.pop(key) 41 | for data_s in data_f.values(): 42 | for action in data_s.keys(): 43 | seq = data_s[action][:, self.kept_joints, :] 44 | seq[:, 1:] -= seq[:, :1] 45 | data_s[action] = seq 46 | self.data = data_f 47 | 48 | 49 | if __name__ == '__main__': 50 | np.random.seed(0) 51 | actions = 'all' 52 | dataset = DatasetHumanEva('test', actions=actions) 53 | generator = dataset.sampling_generator() 54 | dataset.normalize_data() 55 | # generator = dataset.iter_generator() 56 | for data in generator: 57 | print(data.shape) 58 | 59 | 60 | -------------------------------------------------------------------------------- /motion_pred/utils/dataset_h36m.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from motion_pred.utils.dataset import Dataset 4 | from motion_pred.utils.skeleton import Skeleton 5 | 6 | 7 | class DatasetH36M(Dataset): 8 | 9 | def __init__(self, mode, t_his=25, t_pred=100, actions='all', use_vel=False): 10 | self.use_vel = use_vel 11 | super().__init__(mode, t_his, t_pred, actions) 12 | if use_vel: 13 | self.traj_dim += 3 14 | 15 
| def prepare_data(self): 16 | self.data_file = os.path.join('data', 'data_3d_h36m.npz') 17 | self.subjects_split = {'train': [1, 5, 6, 7, 8], 18 | 'test': [9, 11]} 19 | self.subjects = ['S%d' % x for x in self.subjects_split[self.mode]] 20 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 21 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 22 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 23 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 24 | self.removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31} 25 | self.kept_joints = np.array([x for x in range(32) if x not in self.removed_joints]) 26 | self.skeleton.remove_joints(self.removed_joints) 27 | self.skeleton._parents[11] = 8 28 | self.skeleton._parents[14] = 8 29 | self.process_data() 30 | 31 | def process_data(self): 32 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item() 33 | self.S1_skeleton = data_o['S1']['Directions'][:1, self.kept_joints].copy() 34 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items())) 35 | if self.actions != 'all': 36 | for key in list(data_f.keys()): 37 | # data_f[key] = dict(filter(lambda x: all([a in x[0] for a in self.actions]), data_f[key].items())) 38 | data_f[key] = dict(filter(lambda x: any([a in str.lower(x[0]) for a in self.actions]), data_f[key].items())) 39 | if len(data_f[key]) == 0: 40 | data_f.pop(key) 41 | for data_s in data_f.values(): 42 | for action in data_s.keys(): 43 | seq = data_s[action][:, self.kept_joints, :] 44 | if self.use_vel: 45 | v = (np.diff(seq[:, :1], axis=0) * 50).clip(-5.0, 5.0) 46 | v = np.append(v, v[[-1]], axis=0) 47 | seq[:, 1:] -= seq[:, :1] 48 | if self.use_vel: 49 | seq = np.concatenate((seq, v), axis=1) 50 | data_s[action] = seq 51 | self.data = data_f 52 | 53 | 54 | if __name__ == '__main__': 55 | np.random.seed(0) 56 | actions = {'WalkDog'} 57 | dataset = DatasetH36M('train', actions=actions) 58 | generator = dataset.sampling_generator() 59 | dataset.normalize_data() 60 | # generator = dataset.iter_generator() 61 | for data in generator: 62 | print(data.shape) 63 | -------------------------------------------------------------------------------- /motion_pred/utils/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | 11 | class Skeleton: 12 | def __init__(self, parents, joints_left, joints_right): 13 | assert len(joints_left) == len(joints_right) 14 | 15 | self._parents = np.array(parents) 16 | self._joints_left = joints_left 17 | self._joints_right = joints_right 18 | self._compute_metadata() 19 | 20 | def num_joints(self): 21 | return len(self._parents) 22 | 23 | def parents(self): 24 | return self._parents 25 | 26 | def has_children(self): 27 | return self._has_children 28 | 29 | def children(self): 30 | return self._children 31 | 32 | def remove_joints(self, joints_to_remove): 33 | """ 34 | Remove the joints specified in 'joints_to_remove'. 
35 |         """
36 |         valid_joints = []
37 |         for joint in range(len(self._parents)):
38 |             if joint not in joints_to_remove:
39 |                 valid_joints.append(joint)
40 | 
41 |         for i in range(len(self._parents)):
42 |             while self._parents[i] in joints_to_remove:
43 |                 self._parents[i] = self._parents[self._parents[i]]
44 | 
45 |         index_offsets = np.zeros(len(self._parents), dtype=int)
46 |         new_parents = []
47 |         for i, parent in enumerate(self._parents):
48 |             if i not in joints_to_remove:
49 |                 new_parents.append(parent - index_offsets[parent])
50 |             else:
51 |                 index_offsets[i:] += 1
52 |         self._parents = np.array(new_parents)
53 | 
54 | 
55 |         if self._joints_left is not None:
56 |             new_joints_left = []
57 |             for joint in self._joints_left:
58 |                 if joint in valid_joints:
59 |                     new_joints_left.append(joint - index_offsets[joint])
60 |             self._joints_left = new_joints_left
61 |         if self._joints_right is not None:
62 |             new_joints_right = []
63 |             for joint in self._joints_right:
64 |                 if joint in valid_joints:
65 |                     new_joints_right.append(joint - index_offsets[joint])
66 |             self._joints_right = new_joints_right
67 | 
68 |         self._compute_metadata()
69 | 
70 |         return valid_joints
71 | 
72 |     def joints_left(self):
73 |         return self._joints_left
74 | 
75 |     def joints_right(self):
76 |         return self._joints_right
77 | 
78 |     def _compute_metadata(self):
79 |         self._has_children = np.zeros(len(self._parents)).astype(bool)
80 |         for i, parent in enumerate(self._parents):
81 |             if parent != -1:
82 |                 self._has_children[parent] = True
83 | 
84 |         self._children = []
85 |         for i, parent in enumerate(self._parents):
86 |             self._children.append([])
87 |         for i, parent in enumerate(self._parents):
88 |             if parent != -1:
89 |                 self._children[parent].append(i)
--------------------------------------------------------------------------------
/models/LinNF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from torch.nn import functional as F
5 | 
6 | from models import Orthogonal
7 | from torch.nn.parameter import Parameter
8 | import math
9 | 
10 | 
11 | class LinQR(nn.Module):
12 |     """
13 |     Invertible linear flow layer y = Q(Rx) + b, with Householder-parameterized orthogonal Q
14 |     and upper-triangular R whose diagonal is kept positive via exp (not NICE's additive coupling layer).
15 | """ 16 | 17 | def __init__(self, data_dim): 18 | super().__init__() 19 | 20 | self.Q = Orthogonal.Orthogonal(d=data_dim) 21 | self.R = Parameter(torch.Tensor(data_dim, data_dim)) 22 | self.bias = Parameter(torch.Tensor(data_dim)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | torch.nn.init.kaiming_uniform_(self.R, a=math.sqrt(5)) 27 | if self.bias is not None: 28 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.R) 29 | bound = 1 / math.sqrt(fan_in) 30 | torch.nn.init.uniform_(self.bias, -bound, bound) 31 | 32 | def forward(self, x, logdet): 33 | """ 34 | x,x_cond: [bs,data_dim] 35 | 36 | """ 37 | diagR = torch.diag(self.R) 38 | R = torch.triu(self.R, diagonal=1) + torch.diag(torch.exp(diagR)) 39 | x = x.matmul(R.t()) 40 | x = self.Q(x) 41 | x = x + self.bias 42 | logdet += diagR[None, :].repeat([x.shape[0], 1]).sum(dim=1) 43 | return x, logdet 44 | 45 | def inverse(self, x, logdet): 46 | """ 47 | x,x_cond: [bs,data_dim] 48 | 49 | """ 50 | 51 | x = x - self.bias 52 | x = self.Q.inverse(x) 53 | diagR = torch.diag(self.R) 54 | R = torch.triu(self.R, diagonal=1) + torch.diag(torch.exp(diagR)) 55 | invR = torch.inverse(R) 56 | x = x.matmul(invR.t()) 57 | 58 | logdet += diagR[None, :].repeat([x.shape[0], 1]).sum(dim=1) 59 | return x, logdet 60 | 61 | 62 | class prelu(nn.Module): 63 | 64 | def __init__(self, num_parameters: int = 1, init: float = 0.25): 65 | self.num_parameters = num_parameters 66 | super(prelu, self).__init__() 67 | self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) 68 | 69 | def forward(self, input, logdet): 70 | s = torch.zeros_like(input) 71 | s[input < 0] = torch.log(self.weight) 72 | logdet += torch.sum(s, dim=1) 73 | return F.prelu(input, self.weight), logdet 74 | 75 | def inverse(self, input, logdet): 76 | s = torch.zeros_like(input) 77 | s[input < 0] = torch.log(self.weight) 78 | logdet += torch.sum(s, dim=1) 79 | return F.prelu(input, 1 / self.weight), logdet 80 | 81 | 82 | class LinNF(nn.Module): 83 | def __init__(self, data_dim, num_layer=5, with_prelu=True): 84 | super().__init__() 85 | self.num_layer = num_layer 86 | self.w1 = nn.ModuleList() 87 | for i in range(num_layer - 1): 88 | self.w1.append(LinQR(data_dim=data_dim)) 89 | if with_prelu: 90 | self.w1.append(prelu()) 91 | self.w1.append(LinQR(data_dim=data_dim)) 92 | 93 | def forward(self, x): 94 | 95 | z = x 96 | log_det_jacobian = 0 97 | 98 | for i, w in enumerate(self.w1): 99 | z, log_det_jacobian = w(z, log_det_jacobian) 100 | 101 | return z, log_det_jacobian 102 | 103 | def inverse(self, z): 104 | x = z 105 | log_det_jacobian = 0 106 | 107 | for i in range(len(self.w1) - 1, -1, -1): 108 | x, log_det_jacobian = self.w1[i].inverse(x, log_det_jacobian) 109 | 110 | return x, log_det_jacobian 111 | 112 | 113 | if __name__ == '__main__': 114 | bs = 32 115 | data_dim = 25 116 | 117 | sf = LinNF(data_dim=data_dim) 118 | # a = sf.prior.sample([10000, 48, 25]) 119 | # print(torch.mean(sf.prior.log_prob(a).sum([1, 2]))) 120 | sf.double() 121 | sf.cuda() 122 | for i in range(10): 123 | x = torch.randn([bs, data_dim]).double().cuda() 124 | 125 | y1, logdet = sf(x) 126 | x1, logdet = sf.inverse(y1) 127 | err = (x1 / x - 1).abs().max() 128 | print(err) 129 | print(1) 130 | -------------------------------------------------------------------------------- /utils/vis_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.mplot3d import Axes3D 3 | import numpy as np 4 | 5 | 6 | 
def draw_skeleton(ax, kpts, parents=[], is_right=[], cols=["#3498db", "#e74c3c"], marker='o', line_style='-', 7 | label=None): 8 | """ 9 | 10 | :param kpts: joint_n*(3 or 2) 11 | :param parents: 12 | :return: 13 | """ 14 | # ax = plt.subplot(111) 15 | joint_n, dims = kpts.shape 16 | # by default it is human 3.6m joints 17 | # [0, 1, 2, 3, 6, 7, 8, 12, 13, 14, 15, 25, 26, 27, 17, 18, 19] 18 | # if len(parents) == 0: 19 | # parents = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15] 20 | # is_right = [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0] 21 | # if cols == []: 22 | # cols = ["#3498db", "#e74c3c"] 23 | # if parents == 'op': 24 | # parents = [1, -1, 1, 2, 3, 1, 5, 6, 1, 8, 9, 1, 11, 12, 0, 0, 14, 15] 25 | # if parents == 'smpl': 26 | # # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21] 27 | # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] 28 | # # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1] 29 | # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1] 30 | # if cols == []: 31 | # cols = ["#3498db", "#e74c3c"] 32 | # if parents == 'smpl_add': 33 | # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21, 15] 34 | # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0] 35 | # cols = ["#3498db", "#e74c3c"] 36 | # ax.set_xlabel('X Label') 37 | # ax.set_ylabel('Y Label') 38 | if dims > 2: 39 | ax.view_init(75, 90) 40 | ax.set_zlabel('Z Label') 41 | # if dims == 2: 42 | # # idx_choosed = np.intersect1d(np.where(kpts[:, 0] > 0)[0], np.where(kpts[:, 1] > 0)[0]) 43 | # # ax.scatter(kpts[idx_choosed, 0], kpts[idx_choosed, 1], c=c, marker=marker, s=10) 44 | # ax.scatter(kpts[:, 0], kpts[:, 1], c=cols[0], marker=marker, s=10) 45 | # # for i in idx_choosed: 46 | # # ax.text(kpts[i, 0], kpts[i, 1], "{:d}".format(i), color=c) 47 | # else: 48 | # ax.scatter(kpts[:, 0], kpts[:, 1], kpts[:, 2], c=cols[0], marker=marker, s=10) 49 | # for i in range(kpts.shape[0]): 50 | # ax.text(kpts[i, 0], kpts[i, 1], kpts[i, 2], "{:d}".format(i), color=cols[0]) 51 | is_label = True 52 | for i in range(len(parents)): 53 | if parents[i] < 0: 54 | continue 55 | # if dims == 2: 56 | # if not (parents[i] in idx_choosed and i in idx_choosed): 57 | # continue 58 | 59 | if dims == 2: 60 | # ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]], 61 | # linestyle=line_style, 62 | # alpha=0.5 if is_right[i] else 1, linewidth=3) 63 | if label is not None and is_label: 64 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]], 65 | linestyle=line_style, 66 | alpha=1 if is_right[i] else 0.6, label=label) 67 | is_label = False 68 | else: 69 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]], 70 | linestyle=line_style, 71 | alpha=1 if is_right[i] else 0.6) 72 | else: 73 | if label is not None and is_label: 74 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], 75 | [kpts[parents[i], 2], kpts[i, 2]], linestyle=line_style, c=cols[is_right[i]], 76 | alpha=1 if is_right[i] else 0.6, linewidth=3, label=label) 77 | is_label = False 78 | else: 79 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], 80 | [kpts[parents[i], 2], kpts[i, 2]], linestyle=line_style, c=cols[is_right[i]], 81 | alpha=1 if is_right[i] else 0.6, linewidth=3) 
82 | 
83 |     return None
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | def get_dct_matrix(N, is_torch=True):
6 |     dct_m = np.eye(N)
7 |     for k in np.arange(N):
8 |         for i in np.arange(N):
9 |             w = np.sqrt(2 / N)
10 |             if k == 0:
11 |                 w = np.sqrt(1 / N)
12 |             dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
13 |     idct_m = np.linalg.inv(dct_m)
14 |     if is_torch:
15 |         dct_m = torch.from_numpy(dct_m)
16 |         idct_m = torch.from_numpy(idct_m)
17 |     return dct_m, idct_m
18 | 
19 | 
20 | def _pairwise_distances(embeddings, squared=False):
21 |     """Compute the 2D matrix of distances between all the embeddings.
22 | 
23 |     Args:
24 |         embeddings: tensor of shape (batch_size, embed_dim)
25 |         squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
26 |                  If false, output is the pairwise euclidean distance matrix.
27 | 
28 |     Returns:
29 |         pairwise_distances: tensor of shape (batch_size, batch_size)
30 |     """
31 |     dot_product = torch.matmul(embeddings, embeddings.t())
32 | 
33 |     # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
34 |     # This also provides more numerical stability (the diagonal of the result will be exactly 0).
35 |     # shape (batch_size,)
36 |     square_norm = torch.diag(dot_product)
37 | 
38 |     # Compute the pairwise distance matrix as we have:
39 |     # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
40 |     # shape (batch_size, batch_size)
41 |     distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
42 | 
43 |     # Because of computation errors, some distances might be negative so we put everything >= 0.0
44 |     distances[distances < 0] = 0
45 | 
46 |     if not squared:
47 |         # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
48 |         # we need to add a small epsilon where distances == 0.0
49 |         mask = distances.eq(0).float()
50 |         distances = distances + mask * 1e-16
51 | 
52 |         distances = (1.0 - mask) * torch.sqrt(distances)
53 | 
54 |     return distances
55 | 
56 | 
57 | def _pairwise_distances_l1(embeddings, squared=False):
58 |     """Compute the 2D matrix of distances between all the embeddings.
59 | 
60 |     Args:
61 |         embeddings: tensor of shape (batch_size, embed_dim)
62 |         squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
63 |                  If false, output is the pairwise euclidean distance matrix.
64 | 65 | Returns: 66 | pairwise_distances: tensor of shape (batch_size, batch_size) 67 | """ 68 | distances = torch.abs(embeddings[None, :, :] - embeddings[:, None, :]) 69 | return distances 70 | 71 | 72 | def expmap2rotmat(r): 73 | """ 74 | Converts an exponential map angle to a rotation matrix 75 | Matlab port to python for evaluation purposes 76 | I believe this is also called Rodrigues' formula 77 | https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/mhmublv/Motion/expmap2rotmat.m 78 | 79 | Args 80 | r: 1x3 exponential map 81 | Returns 82 | R: 3x3 rotation matrix 83 | """ 84 | theta = np.linalg.norm(r) 85 | r0 = np.divide(r, theta + np.finfo(np.float32).eps) 86 | r0x = np.array([0, -r0[2], r0[1], 0, 0, -r0[0], 0, 0, 0]).reshape(3, 3) 87 | r0x = r0x - r0x.T 88 | R = np.eye(3, 3) + np.sin(theta) * r0x + (1 - np.cos(theta)) * (r0x).dot(r0x); 89 | return R 90 | 91 | 92 | def absolute2relative(x, parents, invert=False, x0=None): 93 | """ 94 | x: [bs,..., jn, 3] or [bs,..., jn-1, 3] if invert 95 | x0: [1,..., jn, 3] 96 | parents: [-1,0,1 ...] 97 | """ 98 | if not invert: 99 | xt = x[..., 1:, :] - x[..., parents[1:], :] 100 | xt = xt / np.linalg.norm(xt, axis=-1, keepdims=True) 101 | return xt 102 | else: 103 | jn = x0.shape[-2] 104 | limb_l = np.linalg.norm(x0[..., 1:, :] - x0[..., parents[1:], :], axis=-1, keepdims=True) 105 | xt = x * limb_l 106 | xt0 = np.zeros_like(xt[..., :1, :]) 107 | xt = np.concatenate([xt0, xt], axis=-2) 108 | for i in range(1, jn): 109 | xt[..., i, :] = xt[..., parents[i], :] + xt[..., i, :] 110 | return xt 111 | 112 | 113 | def absolute2relative_torch(x, parents, invert=False, x0=None): 114 | """ 115 | x: [bs,..., jn, 3] or [bs,..., jn-1, 3] if invert 116 | x0: [1,..., jn, 3] 117 | parents: [-1,0,1 ...] 
118 | """ 119 | if not invert: 120 | xt = x[..., 1:, :] - x[..., parents[1:], :] 121 | xt = xt / torch.norm(xt, dim=-1, keepdim=True) 122 | return xt 123 | else: 124 | jn = x0.shape[-2] 125 | limb_l = torch.norm(x0[..., 1:, :] - x0[..., parents[1:], :], dim=-1, keepdim=True) 126 | xt = x * limb_l 127 | xt0 = torch.zeros_like(xt[..., :1, :]) 128 | xt = torch.cat([xt0, xt], dim=-2) 129 | for i in range(1, jn): 130 | xt[..., i, :] = xt[..., parents[i], :] + xt[..., i, :] 131 | return xt 132 | -------------------------------------------------------------------------------- /utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim import lr_scheduler 4 | 5 | tensor = torch.tensor 6 | DoubleTensor = torch.DoubleTensor 7 | FloatTensor = torch.FloatTensor 8 | LongTensor = torch.LongTensor 9 | ByteTensor = torch.ByteTensor 10 | ones = torch.ones 11 | zeros = torch.zeros 12 | 13 | 14 | class to_cpu: 15 | 16 | def __init__(self, *models): 17 | self.models = list(filter(lambda x: x is not None, models)) 18 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 19 | for x in self.models: 20 | x.to(torch.device('cpu')) 21 | 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, *args): 26 | for x, device in zip(self.models, self.prev_devices): 27 | x.to(device) 28 | return False 29 | 30 | 31 | class to_device: 32 | 33 | def __init__(self, device, *models): 34 | self.models = list(filter(lambda x: x is not None, models)) 35 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 36 | for x in self.models: 37 | x.to(device) 38 | 39 | def __enter__(self): 40 | pass 41 | 42 | def __exit__(self, *args): 43 | for x, device in zip(self.models, self.prev_devices): 44 | x.to(device) 45 | return False 46 | 47 | 48 | class to_test: 49 | 50 | def __init__(self, *models): 51 | self.models = list(filter(lambda x: x is not None, models)) 52 | self.prev_modes = [x.training for x in self.models] 53 | for x in self.models: 54 | x.train(False) 55 | 56 | def __enter__(self): 57 | pass 58 | 59 | def __exit__(self, *args): 60 | for x, mode in zip(self.models, self.prev_modes): 61 | x.train(mode) 62 | return False 63 | 64 | 65 | class to_train: 66 | 67 | def __init__(self, *models): 68 | self.models = list(filter(lambda x: x is not None, models)) 69 | self.prev_modes = [x.training for x in self.models] 70 | for x in self.models: 71 | x.train(True) 72 | 73 | def __enter__(self): 74 | pass 75 | 76 | def __exit__(self, *args): 77 | for x, mode in zip(self.models, self.prev_modes): 78 | x.train(mode) 79 | return False 80 | 81 | 82 | def batch_to(dst, *args): 83 | return [x.to(dst) if x is not None else None for x in args] 84 | 85 | 86 | def get_flat_params_from(models): 87 | if not hasattr(models, '__iter__'): 88 | models = (models, ) 89 | params = [] 90 | for model in models: 91 | for param in model.parameters(): 92 | params.append(param.data.view(-1)) 93 | 94 | flat_params = torch.cat(params) 95 | return flat_params 96 | 97 | 98 | def set_flat_params_to(model, flat_params): 99 | prev_ind = 0 100 | for param in model.parameters(): 101 | flat_size = int(np.prod(list(param.size()))) 102 | param.data.copy_( 103 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 104 | prev_ind += flat_size 105 | 106 | 107 | def get_flat_grad_from(inputs, grad_grad=False): 108 | grads = [] 109 | for param in inputs: 
110 | if grad_grad: 111 | grads.append(param.grad.grad.view(-1)) 112 | else: 113 | if param.grad is None: 114 | grads.append(zeros(param.view(-1).shape)) 115 | else: 116 | grads.append(param.grad.view(-1)) 117 | 118 | flat_grad = torch.cat(grads) 119 | return flat_grad 120 | 121 | 122 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 123 | if create_graph: 124 | retain_graph = True 125 | 126 | inputs = list(inputs) 127 | params = [] 128 | for i, param in enumerate(inputs): 129 | if i not in filter_input_ids: 130 | params.append(param) 131 | 132 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 133 | 134 | j = 0 135 | out_grads = [] 136 | for i, param in enumerate(inputs): 137 | if i in filter_input_ids: 138 | out_grads.append(zeros(param.view(-1).shape)) 139 | else: 140 | out_grads.append(grads[j].view(-1)) 141 | j += 1 142 | grads = torch.cat(out_grads) 143 | 144 | for param in params: 145 | param.grad = None 146 | return grads 147 | 148 | 149 | def set_optimizer_lr(optimizer, lr): 150 | for param_group in optimizer.param_groups: 151 | param_group['lr'] = lr 152 | 153 | 154 | def filter_state_dict(state_dict, filter_keys): 155 | for key in list(state_dict.keys()): 156 | for f_key in filter_keys: 157 | if f_key in key: 158 | del state_dict[key] 159 | break 160 | 161 | 162 | def get_scheduler(optimizer, policy, nepoch_fix=None, nepoch=None, decay_step=None): 163 | if policy == 'lambda': 164 | def lambda_rule(epoch): 165 | lr_l = 1.0 - max(0, epoch - nepoch_fix) / float(nepoch - nepoch_fix + 1) 166 | return lr_l 167 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 168 | elif policy == 'step': 169 | scheduler = lr_scheduler.StepLR( 170 | optimizer, step_size=decay_step, gamma=0.1) 171 | elif policy == 'plateau': 172 | scheduler = lr_scheduler.ReduceLROnPlateau( 173 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 174 | else: 175 | return NotImplementedError('learning rate policy [%s] is not implemented', policy) 176 | return scheduler 177 | -------------------------------------------------------------------------------- /motion_pred/utils/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | from matplotlib.animation import FuncAnimation, writers 12 | from mpl_toolkits.mplot3d import Axes3D 13 | import numpy as np 14 | 15 | 16 | 17 | def render_animation(skeleton, poses_generator, t_hist, fix_0=True, azim=0.0, output=None, size=6, ncol=5, bitrate=3000, index_i=0): 18 | """ 19 | TODO 20 | Render an animation. The supported output modes are: 21 | -- 'interactive': display an interactive figure 22 | (also works on notebooks if associated with %matplotlib inline) 23 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...). 24 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg). 25 | -- 'filename.gif': render and export the animation a gif file (requires imagemagick). 
26 | """ 27 | os.makedirs(output, exist_ok=True) 28 | all_poses = next(poses_generator) 29 | action = all_poses.pop('action') 30 | t_total = next(iter(all_poses.values())).shape[0] 31 | poses = all_poses 32 | plt.ioff() 33 | nrow = int(np.ceil(len(poses) / ncol)) 34 | fig = plt.figure(figsize=(size*ncol, size*nrow)) 35 | ax_3d = [] 36 | lines_3d = [] 37 | trajectories = [] 38 | radius = 1.7 39 | for index, (title, data) in enumerate(poses.items()): 40 | ax = fig.add_subplot(nrow, ncol, index+1, projection='3d') 41 | ax.view_init(elev=15., azim=azim) 42 | ax.set_xlim3d([-radius/2, radius/2]) 43 | ax.set_zlim3d([0, radius]) 44 | ax.set_ylim3d([-radius/2, radius/2]) 45 | ax.set_aspect('auto') 46 | ax.set_xticklabels([]) 47 | ax.set_yticklabels([]) 48 | ax.set_zticklabels([]) 49 | ax.dist = 5.0 50 | # ax.set_title(title, y=1.2) 51 | ax.set_axis_off() 52 | ax.patch.set_alpha(0.0) 53 | ax_3d.append(ax) 54 | lines_3d.append([]) 55 | trajectories.append(data[:, 0, [0, 1]]) 56 | fig.tight_layout() 57 | fig.subplots_adjust(wspace=-0.4, hspace=0) 58 | poses = list(poses.values()) 59 | 60 | anim = None 61 | initialized = False 62 | animating = True 63 | find = 0 64 | hist_lcol, hist_rcol = 'black', 'red' 65 | pred_lcol, pred_rcol = 'purple', 'green' 66 | 67 | parents = skeleton.parents() 68 | 69 | def update_video(i): 70 | nonlocal initialized 71 | if i < t_hist: 72 | lcol, rcol = hist_lcol, hist_rcol 73 | else: 74 | lcol, rcol = pred_lcol, pred_rcol 75 | 76 | for n, ax in enumerate(ax_3d): 77 | if fix_0 and n == 0 and i >= t_hist: 78 | continue 79 | trajectories[n] = poses[n][:, 0, [0, 1, 2]] 80 | ax.set_xlim3d([-radius/2 + trajectories[n][i, 0], radius/2 + trajectories[n][i, 0]]) 81 | ax.set_ylim3d([-radius/2 + trajectories[n][i, 1], radius/2 + trajectories[n][i, 1]]) 82 | ax.set_zlim3d([-radius/2 + trajectories[n][i, 2], radius/2 + trajectories[n][i, 2]]) 83 | 84 | if not initialized: 85 | 86 | for j, j_parent in enumerate(parents): 87 | if j_parent == -1: 88 | continue 89 | 90 | col = rcol if j in skeleton.joints_right() else lcol 91 | for n, ax in enumerate(ax_3d): 92 | pos = poses[n][i] 93 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 94 | [pos[j, 1], pos[j_parent, 1]], 95 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col, lw=8, dash_capstyle='round', marker='o', markersize=12, alpha=0.5, aa=True)) 96 | initialized = True 97 | else: 98 | 99 | for j, j_parent in enumerate(parents): 100 | if j_parent == -1: 101 | continue 102 | 103 | col = rcol if j in skeleton.joints_right() else lcol 104 | for n, ax in enumerate(ax_3d): 105 | if fix_0 and n == 0 and i >= t_hist: 106 | continue 107 | pos = poses[n][i] 108 | lines_3d[n][j-1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]])) 109 | lines_3d[n][j-1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]])) 110 | lines_3d[n][j-1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z') 111 | lines_3d[n][j-1][0].set_color(col) 112 | 113 | def show_animation(): 114 | nonlocal anim 115 | if anim is not None: 116 | anim.event_source.stop() 117 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=0, repeat=True) 118 | plt.draw() 119 | 120 | def reload_poses(): 121 | nonlocal poses, action 122 | if 'action' in all_poses: 123 | action = all_poses.pop('action') 124 | poses = all_poses 125 | # for ax, title in zip(ax_3d, poses.keys()): 126 | # ax.set_title(title, y=1.2) 127 | poses = list(poses.values()) 128 | 129 | def save_figs(): 130 | nonlocal find 131 | update_video(0) 132 | 
update_video(t_total - 1) 133 | os.makedirs(output + 'image', exist_ok=True) 134 | fig.savefig(output + 'image/%d_%s.png' % (index_i, action), dpi=80, transparent=True) 135 | find += 1 136 | 137 | def on_key(event): 138 | nonlocal all_poses, animating, anim 139 | 140 | if event.key == 'd': 141 | all_poses = next(poses_generator) 142 | reload_poses() 143 | show_animation() 144 | elif event.key == 'c': 145 | save() 146 | elif event.key == ' ': 147 | if animating: 148 | anim.event_source.stop() 149 | else: 150 | anim.event_source.start() 151 | animating = not animating 152 | elif event.key == 'v': # save images 153 | if anim is not None: 154 | anim.event_source.stop() 155 | anim = None 156 | save_figs() 157 | 158 | def save(): 159 | nonlocal anim 160 | 161 | fps = 30 162 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=1000 / fps, repeat=False) 163 | os.makedirs(output+'video', exist_ok=True) 164 | anim.save(output + 'video/%d_%s.gif' % (index_i, action), dpi=80, writer='imagemagick') 165 | print(f'video saved to {output}video/{index_i}_{action}.gif!') 166 | save() 167 | # fig.canvas.mpl_connect('key_press_event', on_key) 168 | # show_animation() 169 | # plt.show() -------------------------------------------------------------------------------- /utils/vis_poses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | 4 | matplotlib.use("Agg") 5 | import matplotlib.pyplot as plt 6 | import os 7 | import subprocess 8 | 9 | from utils.vis_util import draw_skeleton 10 | 11 | 12 | def vis_traj(gt_pos, pre_pos=[], labels=[], comments='', title='', 13 | joint_to_plot=[2, 3, 7, 8, 15, 16, 17, 20, 21, 22], sub_title=[''] * 10): 14 | """ 15 | @param gt_pos: seq_l x jn x 3 16 | @param pre_pos: n seq_l x jn x 3 17 | @param labels: n 18 | @param comments: 19 | @return: 20 | """ 21 | assert len(pre_pos) == len(labels) 22 | assert len(gt_pos.shape) == 3 23 | assert len(sub_title) == len(joint_to_plot) 24 | seq_n, joint_n, _ = gt_pos.shape 25 | joint_name = np.array(["Hips", "rUpLeg", "rLeg", "rFoot", "rToeBase", "rToeSite", "lUpLeg", "lLeg", 26 | "lFoot", "lToeBase", "lToeSite", "Spine", "Spine1", "Neck", "Head", "Site", "lShoulder", 27 | "lArm", "lForeArm", "lHand", "lHandThumb", "lHandSite", "lWristEnd", "lWristSite", 28 | "rShoulder", "rArm", "rForeArm", "rHand", "rHandThumb", "rHandSite", 29 | "rWristEnd", "rWristSite"]) 30 | joint_to_ignore = np.array([11, 16, 20, 23, 24, 28, 31]) 31 | joint_used = np.setdiff1d(np.arange(32), joint_to_ignore) 32 | parents = np.array([-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 12, 15, 16, 17, 17, 12, 20, 21, 22, 22]) 33 | is_right = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 34 | # joint_to_plot = [2, 3, 7, 8, 16, 17, 21, 22] 35 | linestyles = ['-', '--', '-.', ':', '--', '-.'] 36 | markers = ['o', 'v', '*', 'D', 's', 'p'] 37 | colors = ['k', 'r', 'g', 'b', 'c', 'm'] 38 | coord = ['x', 'y', 'z'] 39 | fig = plt.figure(0, figsize=[6 * 3, 3 * len(joint_to_plot)]) 40 | # plt.title(title, pad=1) 41 | fig.suptitle('{}_{}'.format(title.split('/')[-1].split('.pdf')[0], comments), fontsize=14) 42 | axs = fig.subplots(len(joint_to_plot), 3) 43 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.4) 44 | for i, jn in enumerate(joint_to_plot): 45 | axs[i, 1].set_title("{}_{}".format(joint_name[joint_used][jn], sub_title[i]), x=0.5, y=1) 46 | # axs[i, 1].legend() 47 | for j in range(3): 48 | 
axs[i, j].plot(np.arange(1, seq_n + 1), gt_pos[:, jn, j], linestyle=linestyles[0], marker=markers[0], 49 | c=colors[0], label='GT') 50 | axs[i, j].set_xticks(np.arange(1, seq_n + 1)) 51 | for k, pp in enumerate(pre_pos): 52 | axs[i, j].plot(np.arange(1, seq_n + 1), pp[:, jn, j], linestyle=linestyles[k + 1], 53 | marker=markers[k + 1], c=colors[k + 1], 54 | label=labels[k]) 55 | axs[0, 1].legend() 56 | plt.savefig('{}'.format(title)) 57 | # plt.savefig('test1.pdf'.format(title)) 58 | plt.clf() 59 | plt.close() 60 | # plt.show() 61 | 62 | 63 | def vis_poses(gt_pos, pre_pos=[], labels=[], comments='', title='', skeleton=None): 64 | """ 65 | @param gt_pos: seq_l x jn x 3 66 | @param pre_pos: n seq_l x jn x 3 67 | @param labels: n 68 | @param comments: 69 | @return: 70 | """ 71 | assert len(pre_pos) == len(labels) 72 | assert len(gt_pos.shape) == 3 73 | if not len(comments) == gt_pos.shape[0]: 74 | comments = [] * gt_pos.shape[0] 75 | seq_n, joint_n, _ = gt_pos.shape 76 | # joint_name = np.array(["Hips", "rUpLeg", "rLeg", "rFoot", "rToeBase", "rToeSite", "lUpLeg", "lLeg", 77 | # "lFoot", "lToeBase", "lToeSite", "Spine", "Spine1", "Neck", "Head", "Site", "lShoulder", 78 | # "lArm", "lForeArm", "lHand", "lHandThumb", "lHandSite", "lWristEnd", "lWristSite", 79 | # "rShoulder", "rArm", "rForeArm", "rHand", "rHandThumb", "rHandSite", 80 | # "rWristEnd", "rWristSite"]) 81 | # joint_to_ignore = np.array([11, 16, 20, 23, 24, 28, 31]) 82 | # joint_used = np.setdiff1d(np.arange(32), joint_to_ignore) 83 | # parents = np.array([-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 12, 15, 16, 17, 17, 12, 20, 21, 22, 22]) 84 | # is_right = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 85 | parents = skeleton._parents 86 | is_right = np.zeros_like(parents) 87 | is_right[skeleton._joints_right] = 1 88 | 89 | joint_to_plot = [2, 3, 15, 16, 17] 90 | linestyles = ['-', '--', '-.', ':', '--', '-.'] 91 | markers = ['o', 'v', '*', 'D', 's', 'p'] 92 | colors = ['k', 'r', 'g', 'b', 'c', 'm'] 93 | coord = ['x', 'y', 'z'] 94 | rot1 = np.array([[0.9239557, -0.0000000, -0.3824995], 95 | [-0.1463059, 0.9239557, -0.3534126], 96 | [0.3534126, 0.3824995, 0.8536941]]) 97 | rot2 = np.array([[1.0000000, 0.0000000, 0.0000000], 98 | [0.0000000, 0.9239557, -0.3824995], 99 | [0.0000000, 0.3824995, 0.9239557]]) 100 | rot3 = np.array([[0.9239557, 0.0000000, 0.3824995], 101 | [0.1463059, 0.9239557, -0.3534126], 102 | [-0.3534126, 0.3824995, 0.8536941]]) 103 | rot = [rot3, rot2, rot1] 104 | trans = 4 105 | # axs = fig.subplots(1, len(rot)) 106 | # get value scope 107 | pgt = [] 108 | ppred = [] 109 | scope = [] 110 | pp_tmp = [] 111 | for i, rr in enumerate(rot): 112 | pt = np.matmul(np.expand_dims(rr, axis=0), gt_pos.transpose([0, 2, 1])).transpose([0, 2, 1]) 113 | pt[:, :, 2] = pt[:, :, 2] + trans 114 | pt = pt[:, :, :2] / pt[:, :, 2:] 115 | pgt.append(pt) 116 | pp_tmp.append(pt) 117 | ppred.append([]) 118 | for j, pp in enumerate(pre_pos): 119 | pt = np.matmul(np.expand_dims(rr, axis=0), pp.transpose([0, 2, 1])).transpose([0, 2, 1]) 120 | pt[:, :, 2] = pt[:, :, 2] + trans 121 | pt = pt[:, :, :2] / pt[:, :, 2:] 122 | pp_tmp.append(pt) 123 | ppred[i].append(pt) 124 | pp_tmp = np.vstack(pp_tmp) 125 | max_x = np.max(pp_tmp[:, :, 0]) 126 | min_x = np.min(pp_tmp[:, :, 0]) 127 | max_y = np.max(pp_tmp[:, :, 1]) 128 | min_y = np.min(pp_tmp[:, :, 1]) 129 | dx = (max_x - min_x) + 0.1 * (max_x - min_x) 130 | max_x = min_x + dx * len(rot) 131 | 132 | scope = [max_x, min_x, max_y, min_y] 133 | for jj in range(seq_n): 134 
| fig = plt.figure(0, figsize=[3 * len(rot), 6]) 135 | axs = fig.subplots(1, 1) 136 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.4) 137 | # for i in range(len(rot)): 138 | # axs[i].set_axis_off() 139 | # axs[i].axis('equal') 140 | # axs[i].plot(scope[i][:2], scope[i][2:], c='w') 141 | axs.set_axis_off() 142 | axs.axis('equal') 143 | axs.plot(scope[:2], scope[2:], c='w') 144 | for i, rr in enumerate(rot): 145 | # pt = np.matmul(rr, gt_pos[jj].transpose([1, 0])).transpose([1, 0]) 146 | # pt[:, 2] = pt[:, 2] + trans 147 | # pt = pt[:, :2] / pt[:, 2:] 148 | pgt[i][jj][:, 0] = pgt[i][jj][:, 0] + dx * i 149 | draw_skeleton(axs, pgt[i][jj], parents=parents, is_right=is_right, cols=[colors[0], colors[0]], 150 | line_style=linestyles[0], label='GT' if i == 0 else None) 151 | for j, pp in enumerate(pre_pos): 152 | # pt = np.matmul(rr, pp[jj].transpose([1, 0])).transpose([1, 0]) 153 | # pt[:, 2] = pt[:, 2] + trans 154 | # pt = pt[:, :2] / pt[:, 2:] 155 | ppred[i][j][jj][:, 0] = ppred[i][j][jj][:, 0] + dx * i 156 | draw_skeleton(axs, ppred[i][j][jj], parents=parents, is_right=is_right, 157 | cols=[colors[j + 1], colors[j + 1]], 158 | line_style=linestyles[j + 1], label=labels[j] if i == 0 else None) 159 | # axs[1].legend() 160 | # axs[1].set_title("f{}_{}".format(jj + 1, comments[jj]), x=0.5, y=1) 161 | # plt.savefig('{}/{}.jpg'.format(title, jj)) 162 | axs.legend() 163 | axs.set_title("f{}_{}".format(jj + 1, comments[jj]), x=0.5, y=1) 164 | plt.savefig('{}/{}.jpg'.format(title, jj)) 165 | # plt.show(block=False) 166 | # plt.pause(0.1) 167 | # for i in range(len(rot)): 168 | # axs[i].clear() 169 | axs.clear() 170 | plt.clf() 171 | plt.close() 172 | # cmd = [ 173 | # 'ffmpeg', 174 | # '-i', vid_filename, 175 | # f'{output_folder}/%06d.jpg', 176 | # '-threads', '16' 177 | # ] 178 | # 179 | # print(' '.join(cmd)) 180 | # try: 181 | # subprocess.call(cmd) 182 | # except OSError: 183 | # print('OSError') 184 | 185 | # # Set up formatting for the movie files 186 | # Writer = animation.writers['ffmpeg'] 187 | # writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800) 188 | # fig = plt.figure() 189 | # im = [] 190 | # for jj in range(seq_n): 191 | # img = plt.imread('{}_{}.jpg'.format(title.replace('.mp4', ''), jj)) 192 | # im_tmp = plt.imshow(img) 193 | # im.append((im_tmp,)) 194 | # im_ani = animation.ArtistAnimation(fig, im, interval=50, repeat_delay=3000, 195 | # blit=True) 196 | # im_ani.save(title, writer=writer) 197 | -------------------------------------------------------------------------------- /motion_pred/utils/dataset_humaneva_multimodal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from motion_pred.utils.dataset import Dataset 4 | from motion_pred.utils.skeleton import Skeleton 5 | from utils import util 6 | 7 | 8 | class DatasetHumanEva(Dataset): 9 | 10 | def __init__(self, mode, t_his=15, t_pred=60, actions='all', **kwargs): 11 | if 'multimodal_path' in kwargs.keys(): 12 | self.multimodal_path = kwargs['multimodal_path'] 13 | else: 14 | self.multimodal_path = None 15 | 16 | if 'data_candi_path' in kwargs.keys(): 17 | self.data_candi_path = kwargs['data_candi_path'] 18 | else: 19 | self.data_candi_path = None 20 | super().__init__(mode, t_his, t_pred, actions) 21 | 22 | def prepare_data(self): 23 | self.data_file = os.path.join('data', 'data_3d_humaneva15.npz') 24 | self.subjects_split = {'train': ['Train/S1', 'Train/S2', 'Train/S3'], 25 | 'test': ['Validate/S1', 
'Validate/S2', 'Validate/S3']} 26 | self.subjects = [x for x in self.subjects_split[self.mode]] 27 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1], 28 | joints_left=[2, 3, 4, 8, 9, 10], 29 | joints_right=[5, 6, 7, 11, 12, 13]) 30 | self.kept_joints = np.arange(15) 31 | self.process_data() 32 | 33 | def process_data(self): 34 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item() 35 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items())) 36 | # these takes have wrong head position, excluded from training and testing 37 | if self.mode == 'train': 38 | data_f['Train/S3'].pop('Walking 1 chunk0') 39 | data_f['Train/S3'].pop('Walking 1 chunk2') 40 | else: 41 | data_f['Validate/S3'].pop('Walking 1 chunk4') 42 | if self.multimodal_path is None: 43 | self.data_multimodal = \ 44 | np.load('./data/humaneva_multi_modal/t_his15_1_thre0.050_t_pred60_thre0.100_index_filterd.npz', 45 | allow_pickle=True)['data_multimodal'].item() 46 | data_candi = \ 47 | np.load('./data/humaneva_multi_modal/data_candi_t_his15_t_pred60_skiprate1.npz', allow_pickle=True)[ 48 | 'data_candidate.npy'] 49 | else: 50 | self.data_multimodal = np.load(self.multimodal_path, allow_pickle=True)['data_multimodal'].item() 51 | data_candi = np.load(self.data_candi_path, allow_pickle=True)['data_candidate.npy'] 52 | 53 | self.data_candi = {} 54 | 55 | for key in list(data_f.keys()): 56 | # data_f[key] = dict(filter(lambda x: (self.actions == 'all' or 57 | # all([a in x[0] for a in self.actions])) 58 | # and x[1].shape[0] >= self.t_total, data_f[key].items())) 59 | data_f[key] = dict(filter(lambda x: (self.actions == 'all' or 60 | any([a in x[0] for a in self.actions])) 61 | and x[1].shape[0] >= self.t_total, data_f[key].items())) 62 | if len(data_f[key]) == 0: 63 | data_f.pop(key) 64 | for sub in data_f.keys(): 65 | data_s = data_f[sub] 66 | # for data_s in data_f.values(): 67 | for action in data_s.keys(): 68 | seq = data_s[action][:, self.kept_joints, :] 69 | seq[:, 1:] -= seq[:, :1] 70 | data_s[action] = seq 71 | 72 | if sub not in self.data_candi.keys(): 73 | x0 = np.copy(seq[None, :1, ...]) 74 | x0[:, :, 0] = 0 75 | self.data_candi[sub] = util.absolute2relative(data_candi, parents=self.skeleton.parents(), 76 | invert=True, x0=x0) 77 | self.data = data_f 78 | 79 | def sample(self, n_modality=5): 80 | while True: 81 | subject = np.random.choice(self.subjects) 82 | dict_s = self.data[subject] 83 | action = np.random.choice(list(dict_s.keys())) 84 | seq = dict_s[action] 85 | if seq.shape[0] > self.t_total: 86 | break 87 | fr_start = np.random.randint(seq.shape[0] - self.t_total) 88 | fr_end = fr_start + self.t_total 89 | traj = seq[fr_start: fr_end] 90 | if n_modality > 0: 91 | # margin_f = 1 92 | # thre_his = 0.05 93 | # thre_pred = 0.1 94 | # x0 = np.copy(traj[None, ...]) 95 | # x0[:, :, 0] = 0 96 | # # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0) 97 | candi_tmp = self.data_candi[subject] 98 | # # observation distance 99 | # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] - 100 | # candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3), axis=(1, 2)) 101 | # idx_his = np.where(dist_his <= thre_his)[0] 102 | # 103 | # # future distance 104 | # dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] - 105 | # candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2)) 106 | # 107 | # idx_pred = np.where(dist_pred >= thre_pred)[0] 108 | # # idxs = np.intersect1d(idx_his, 
idx_pred)
109 |             idx_multi = self.data_multimodal[subject][action][fr_start]
110 |             traj_multi = candi_tmp[idx_multi]
111 | 
112 |             # confirm it is the right one
113 |             if len(idx_multi) > 0:
114 |                 margin_f = 1
115 |                 thre_his = 0.05
116 |                 thre_pred = 0.1
117 |                 x0 = np.copy(traj[None, ...])
118 |                 x0[:, :, 0] = 0
119 |                 dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
120 |                                                   traj_multi[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
121 |                                    axis=(1, 2))
122 |                 # if np.any(dist_his > thre_his):
123 |                 #     print(f'===> wrong multi-modality sequences {dist_his[dist_his > thre_his].max():.3f}')
124 | 
125 |             if len(traj_multi) > 0:
126 |                 traj_multi[:, :self.t_his] = traj[None, ...][:, :self.t_his]
127 |                 if traj_multi.shape[0] > n_modality:
128 |                     st0 = np.random.get_state()
129 |                     idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
130 |                     traj_multi = traj_multi[idxtmp]
131 |                     np.random.set_state(st0)
132 |                     # traj_multi = traj_multi[:n_modality]
133 |                 traj_multi = np.concatenate(
134 |                     [traj_multi, np.zeros_like(traj[None, ...][[0] * (n_modality - traj_multi.shape[0])])], axis=0)
135 | 
136 |             return traj[None, ...], traj_multi, action
137 |         else:
138 |             return traj[None, ...], None, action
139 | 
140 |     def sampling_generator(self, num_samples=1000, batch_size=8, n_modality=5):
141 |         for i in range(num_samples // batch_size):
142 |             sample = []
143 |             sample_multi = []
144 |             for i in range(batch_size):
145 |                 sample_i, sample_multi_i, _ = self.sample(n_modality=n_modality)
146 |                 sample.append(sample_i)
147 |                 sample_multi.append(sample_multi_i[None, ...])
148 |             sample = np.concatenate(sample, axis=0)
149 |             sample_multi = np.concatenate(sample_multi, axis=0)
150 |             yield sample, sample_multi
151 | 
152 |     #
153 |     # def iter_generator(self, step=25, n_modality=10):
154 |     #     for sub in self.data.keys():
155 |     #         data_s = self.data[sub]
156 |     #         candi_tmp = self.data_candi[sub]
157 |     #         for act in data_s.keys():
158 |     #             seq = data_s[act]
159 |     #             seq_len = seq.shape[0]
160 |     #             for i in range(0, seq_len - self.t_total, step):
161 |     #                 # idx_multi = self.data_multimodal[sub][act][i]
162 |     #                 # traj_multi = candi_tmp[idx_multi]
163 |     #                 traj = seq[None, i: i + self.t_total]
164 |     #                 if n_modality > 0:
165 |     #                     margin_f = 1
166 |     #                     thre_his = 0.05
167 |     #                     thre_pred = 0.1
168 |     #                     x0 = np.copy(traj)
169 |     #                     x0[:, :, 0] = 0
170 |     #                     # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
171 |     #                     # candi_tmp = self.data_candi[subject]
172 |     #                     # observation distance
173 |     #                     dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
174 |     #                                                       candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
175 |     #                                        axis=(1, 2))
176 |     #                     idx_his = np.where(dist_his <= thre_his)[0]
177 |     #
178 |     #                     # future distance
179 |     #                     dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
180 |     #                                                        candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
181 |     #
182 |     #                     idx_pred = np.where(dist_pred >= thre_pred)[0]
183 |     #                     # idxs = np.intersect1d(idx_his, idx_pred)
184 |     #                     traj_multi = candi_tmp[idx_his[idx_pred]]
185 |     #                     if len(traj_multi) > 0:
186 |     #                         traj_multi[:, :self.t_his] = traj[:, :self.t_his]
187 |     #                         if traj_multi.shape[0] > n_modality:
188 |     #                             idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
189 |     #                             traj_multi = traj_multi[idxtmp]
190 |     #                         traj_multi = np.concatenate(
191 |     #                             [traj_multi, np.zeros_like(traj[[0] * (n_modality - traj_multi.shape[0])])],
192 |     #                             axis=0)
193 |     #                 else:
194 |     #                     traj_multi = None
195 |     #
196 |     #             yield traj, traj_multi
197 | 
198 |     def iter_generator(self, step=25):
199 |         for data_s in self.data.values():
200 |             for seq in data_s.values():
201 |                 seq_len = seq.shape[0]
202 |                 for i in range(0, seq_len - self.t_total, step):
203 |                     traj = seq[None, i: i + self.t_total]
204 |                     yield traj, None
205 | 
206 | 
207 | if __name__ == '__main__':
208 |     np.random.seed(0)
209 |     actions = 'all'
210 |     dataset = DatasetHumanEva('test', actions=actions)
211 |     generator = dataset.sampling_generator()
212 |     dataset.normalize_data()
213 |     # generator = dataset.iter_generator()
214 |     for data, data_multi in generator:  # sampling_generator yields (sample, sample_multi) pairs
215 |         print(data.shape, data_multi.shape)
216 | 
--------------------------------------------------------------------------------
/motion_pred/utils/dataset_h36m_multimodal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from motion_pred.utils.dataset import Dataset
4 | from motion_pred.utils.skeleton import Skeleton
5 | from utils import util
6 | 
7 | 
8 | class DatasetH36M(Dataset):
9 | 
10 |     def __init__(self, mode, t_his=25, t_pred=100, actions='all', use_vel=False, **kwargs):
11 |         self.use_vel = use_vel
12 |         if 'multimodal_path' in kwargs.keys():
13 |             self.multimodal_path = kwargs['multimodal_path']
14 |         else:
15 |             self.multimodal_path = None
16 | 
17 |         if 'data_candi_path' in kwargs.keys():
18 |             self.data_candi_path = kwargs['data_candi_path']
19 |         else:
20 |             self.data_candi_path = None
21 |         super().__init__(mode, t_his, t_pred, actions)
22 |         if use_vel:
23 |             self.traj_dim += 3
24 | 
25 |     def prepare_data(self):
26 |         self.data_file = os.path.join('data', 'data_3d_h36m.npz')
27 |         self.subjects_split = {'train': [1, 5, 6, 7, 8],
28 |                                'test': [9, 11]}
29 |         self.subjects = ['S%d' % x for x in self.subjects_split[self.mode]]
30 |         self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
31 |                                           16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
32 |                                  joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
33 |                                  joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
34 |         self.removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31}
35 |         self.kept_joints = np.array([x for x in range(32) if x not in self.removed_joints])
36 |         self.skeleton.remove_joints(self.removed_joints)
37 |         self.skeleton._parents[11] = 8
38 |         self.skeleton._parents[14] = 8
39 |         self.process_data()
40 | 
41 |     def process_data(self):
42 |         data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item()
43 |         self.S1_skeleton = data_o['S1']['Directions'][:1, self.kept_joints].copy()
44 |         data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items()))
45 |         if self.actions != 'all':
46 |             for key in list(data_f.keys()):
47 |                 # data_f[key] = dict(
48 |                 #     filter(lambda x: all([a in str.lower(x[0]) for a in self.actions]), data_f[key].items()))
49 |                 data_f[key] = dict(filter(lambda x: any([a in str.lower(x[0]) for a in self.actions]), data_f[key].items()))
50 |                 if len(data_f[key]) == 0:
51 |                     data_f.pop(key)
52 |         # possible candidate
53 |         # skip_rate = 10
54 |         # data_candi = []
55 |         if self.multimodal_path is None:
56 |             self.data_multimodal = \
57 |                 np.load('./data/data_multi_modal/t_his25_1_thre0.050_t_pred100_thre0.100_filtered.npz',
58 |                         allow_pickle=True)[
59 |                     'data_multimodal'].item()
60 |             data_candi = \
61 |                 np.load('./data/data_multi_modal/data_candi_t_his25_t_pred100_skiprate20.npz', allow_pickle=True)[
62 |                     'data_candidate.npy']
63 |         else:
64 |             self.data_multimodal = np.load(self.multimodal_path, allow_pickle=True)['data_multimodal'].item()
65 |             data_candi = np.load(self.data_candi_path, allow_pickle=True)['data_candidate.npy']
66 | 
67 |         self.data_candi = {}
68 | 
69 |         for sub in data_f.keys():
70 |             data_s = data_f[sub]
71 |             for action in data_s.keys():
72 |                 seq = data_s[action][:, self.kept_joints, :]
73 |                 if self.use_vel:
74 |                     v = (np.diff(seq[:, :1], axis=0) * 50).clip(-5.0, 5.0)
75 |                     v = np.append(v, v[[-1]], axis=0)
76 |                 seq[:, 1:] -= seq[:, :1]
77 | 
78 |                 # # get relative candidate
79 |                 # data_tmp = np.copy(seq)
80 |                 # data_tmp[:, 0] = 0
81 |                 # nf = data_tmp.shape[0]
82 |                 # idxs = np.arange(0, nf - self.t_his - self.t_pred, skip_rate)[:, None] + np.arange(
83 |                 #     self.t_his + self.t_pred)[None, :]
84 |                 # data_tmp = data_tmp[idxs]
85 |                 # data_tmp = util.absolute2relative(data_tmp, parents=self.skeleton.parents())
86 |                 # data_candi.append(data_tmp)
87 | 
88 |                 if self.use_vel:
89 |                     seq = np.concatenate((seq, v), axis=1)
90 |                 data_s[action] = seq
91 | 
92 |                 if sub not in self.data_candi.keys():
93 |                     x0 = np.copy(seq[None, :1, ...])
94 |                     x0[:, :, 0] = 0
95 |                     self.data_candi[sub] = util.absolute2relative(data_candi, parents=self.skeleton.parents(),
96 |                                                                   invert=True, x0=x0)
97 | 
98 |         self.data = data_f
99 |         # self.data_candi = np.concatenate(data_candi, axis=0)
100 | 
101 |     def sample(self, n_modality=5):
102 |         subject = np.random.choice(self.subjects)
103 |         dict_s = self.data[subject]
104 |         action = np.random.choice(list(dict_s.keys()))
105 |         seq = dict_s[action]
106 |         fr_start = np.random.randint(seq.shape[0] - self.t_total)
107 |         fr_end = fr_start + self.t_total
108 |         traj = seq[fr_start: fr_end]
109 |         if n_modality > 0 and subject in self.data_multimodal.keys():
110 |             # margin_f = 1
111 |             # thre_his = 0.05
112 |             # thre_pred = 0.1
113 |             # x0 = np.copy(traj[None, ...])
114 |             # x0[:, :, 0] = 0
115 |             # # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
116 |             candi_tmp = self.data_candi[subject]
117 |             # # observation distance
118 |             # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
119 |             #                                   candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3), axis=(1, 2))
120 |             # idx_his = np.where(dist_his <= thre_his)[0]
121 |             #
122 |             # # future distance
123 |             # dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
124 |             #                                    candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
125 |             #
126 |             # idx_pred = np.where(dist_pred >= thre_pred)[0]
127 |             # # idxs = np.intersect1d(idx_his, idx_pred)
128 |             idx_multi = self.data_multimodal[subject][action][fr_start]
129 |             traj_multi = candi_tmp[idx_multi]
130 | 
131 |             # # confirm it is the right one
132 |             # if len(idx_multi) > 0:
133 |             #     margin_f = 1
134 |             #     thre_his = 0.05
135 |             #     thre_pred = 0.1
136 |             #     x0 = np.copy(traj[None, ...])
137 |             #     x0[:, :, 0] = 0
138 |             #     dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
139 |             #                                       traj_multi[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
140 |             #                        axis=(1, 2))
141 |             #     if np.any(dist_his > thre_his):
142 |             #         print(f'===> wrong multi-modality sequences {dist_his[dist_his > thre_his].max():.3f}')
143 | 
144 |             if len(traj_multi) > 0:
145 |                 traj_multi[:, :self.t_his] = traj[None, ...][:, :self.t_his]
146 |                 if traj_multi.shape[0] > n_modality:
147 |                     st0 = np.random.get_state()
148 |                     idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
149 |                     traj_multi = traj_multi[idxtmp]
150 |                     np.random.set_state(st0)
151 |                     # traj_multi = traj_multi[:n_modality]
152 |                 traj_multi = np.concatenate(
153 |                     [traj_multi, np.zeros_like(traj[None, ...][[0] * (n_modality - traj_multi.shape[0])])], axis=0)
154 | 
155 |             return traj[None, ...], traj_multi, action
156 |         else:
157 |             return traj[None, ...], None, action
158 | 
159 |     def sampling_generator(self, num_samples=1000, batch_size=8, n_modality=5):
160 |         for i in range(num_samples // batch_size):
161 |             sample = []
162 |             sample_multi = []
163 |             for i in range(batch_size):
164 |                 sample_i, sample_multi_i, _ = self.sample(n_modality=n_modality)
165 |                 sample.append(sample_i)
166 |                 sample_multi.append(sample_multi_i[None, ...])
167 |             sample = np.concatenate(sample, axis=0)
168 |             sample_multi = np.concatenate(sample_multi, axis=0)
169 |             yield sample, sample_multi
170 | 
171 |     def iter_generator(self, step=25, n_modality=10):
172 |         for sub in self.data.keys():
173 |             data_s = self.data[sub]
174 |             candi_tmp = self.data_candi[sub]
175 |             for act in data_s.keys():
176 |                 seq = data_s[act]
177 |                 seq_len = seq.shape[0]
178 |                 for i in range(0, seq_len - self.t_total, step):
179 |                     # idx_multi = self.data_multimodal[sub][act][i]
180 |                     # traj_multi = candi_tmp[idx_multi]
181 |                     traj = seq[None, i: i + self.t_total]
182 |                     if n_modality > 0:
183 |                         margin_f = 1
184 |                         thre_his = 0.05
185 |                         thre_pred = 0.1
186 |                         x0 = np.copy(traj)
187 |                         x0[:, :, 0] = 0
188 |                         # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
189 |                         # candi_tmp = self.data_candi[subject]
190 |                         # observation distance
191 |                         dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
192 |                                                           candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
193 |                                            axis=(1, 2))
194 |                         idx_his = np.where(dist_his <= thre_his)[0]
195 | 
196 |                         # future distance
197 |                         dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
198 |                                                            candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
199 | 
200 |                         idx_pred = np.where(dist_pred >= thre_pred)[0]
201 |                         # idxs = np.intersect1d(idx_his, idx_pred)
202 |                         traj_multi = candi_tmp[idx_his[idx_pred]]
203 |                         if len(traj_multi) > 0:
204 |                             traj_multi[:, :self.t_his] = traj[:, :self.t_his]
205 |                             if traj_multi.shape[0] > n_modality:
206 |                                 # idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
207 |                                 # traj_multi = traj_multi[idxtmp]
208 |                                 traj_multi = traj_multi[:n_modality]
209 |                             traj_multi = np.concatenate(
210 |                                 [traj_multi, np.zeros_like(traj[[0] * (n_modality - traj_multi.shape[0])])],
211 |                                 axis=0)
212 |                     else:
213 |                         traj_multi = None
214 | 
215 |                     yield traj, traj_multi
216 | 
217 | 
218 | if __name__ == '__main__':
219 |     np.random.seed(0)
220 |     actions = {'walkdog'}  # queries are matched against lower-cased take names in process_data()
221 |     dataset = DatasetH36M('train', actions=actions)
222 |     generator = dataset.sampling_generator()
223 |     dataset.normalize_data()
224 |     # generator = dataset.iter_generator()
225 |     for data, data_multi in generator:  # sampling_generator yields (sample, sample_multi) pairs
226 |         print(data.shape, data_multi.shape)
227 | 
--------------------------------------------------------------------------------
/train_nf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | import argparse
5 | import time
6 | from torch import optim
7 | from torch.utils.tensorboard import SummaryWriter
8 | import matplotlib.pyplot as plt
9 | 
10 | sys.path.append(os.getcwd())
11 | from utils import *
12 | from motion_pred.utils.config import Config
13 | from motion_pred.utils.dataset_h36m import DatasetH36M
14 | from motion_pred.utils.dataset_humaneva import DatasetHumanEva
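# NOTE (sketch): train_nf.py fits the LinNF normalizing flow as a pose prior by
# maximum likelihood. For an invertible map f with z = f(x) and a standard-normal
# prior p(z), the change-of-variables formula gives
#     log p(x) = log p(z) + log|det J_f(x)|,
# so a minimal sketch of the per-batch objective, assuming (as below) that the
# flow returns the pair (z, log_det_jacobian), is:
#
#     z, log_det_jacobian = model(x)
#     nll = -(prior.log_prob(z).sum(dim=1) + log_det_jacobian).mean()
#
# loss_function() below computes exactly this quantity, split into its prior
# and Jacobian terms for logging.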
15 | from models.motion_pred import * 16 | from utils import util 17 | 18 | 19 | def loss_function(prior_lkh, log_det_jacobian): 20 | loss_p = -prior_lkh.mean() 21 | loss_jac = - log_det_jacobian.mean() 22 | loss_r = loss_p + loss_jac 23 | 24 | return loss_r, np.array([loss_r.item(), loss_p.item(), loss_jac.item()]) 25 | 26 | 27 | def train(epoch): 28 | model.train() 29 | t_s = time.time() 30 | train_losses = 0 31 | train_grad = 0 32 | total_num_sample = 0 33 | loss_names = ['LKH', 'log_p(z)', 'log_det'] 34 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample, batch_size=cfg.batch_size) 35 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device), 36 | torch.tensor(1, dtype=dtype, device=device)) 37 | for traj_np in generator: 38 | with torch.no_grad(): 39 | traj_np = traj_np[:, 0] 40 | traj_np[:, 0] = 0 41 | traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents()) 42 | traj = tensor(traj_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous() 43 | bs, nj, _ = traj.shape 44 | x = traj.reshape([bs, -1]) 45 | 46 | z, log_det_jacobian = model(x) 47 | prior_likelihood = prior.log_prob(z).sum(dim=1) 48 | 49 | loss, losses = loss_function(prior_likelihood, log_det_jacobian) 50 | optimizer.zero_grad() 51 | loss.backward() 52 | # grad_norm = torch.nn.utils.clip_grad_norm_(list(model.parameters()), max_norm=10000) 53 | grad_norm = 0 54 | train_grad += grad_norm 55 | optimizer.step() 56 | train_losses += losses 57 | total_num_sample += 1 58 | del loss, z, prior_likelihood, log_det_jacobian 59 | 60 | scheduler.step() 61 | # dt = time.time() - t_s 62 | train_losses /= total_num_sample 63 | lr = optimizer.param_groups[0]['lr'] 64 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)]) 65 | 66 | # average cost of log time 20s 67 | tb_logger.add_scalar('train_grad', train_grad / total_num_sample, epoch) 68 | for name, loss in zip(loss_names, train_losses): 69 | tb_logger.add_scalars(name, {'train': loss}, epoch) 70 | 71 | logger.info('====> Epoch: {} Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr)) 72 | 73 | 74 | def val(epoch): 75 | model.eval() 76 | t_s = time.time() 77 | train_losses = 0 78 | total_num_sample = 0 79 | loss_names = ['LKH', 'log_p(z)', 'log_det'] 80 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample // 2, batch_size=cfg.batch_size) 81 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device), 82 | torch.tensor(1, dtype=dtype, device=device)) 83 | loginfos = None 84 | with torch.no_grad(): 85 | xx = [] 86 | for traj_np in generator: 87 | traj_np = traj_np[:, 0] 88 | traj_np[:, 0] = 0 89 | traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents()) 90 | traj = tensor(traj_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous() 91 | bs, nj, _ = traj.shape 92 | x = traj.reshape([bs, -1]) 93 | 94 | z, log_det_jacobian = model(x) 95 | prior_likelihood = prior.log_prob(z).sum(dim=1) 96 | loginf = {} 97 | loginf[f'z'] = z.cpu().data.numpy() 98 | 99 | loss, losses = loss_function(prior_likelihood, log_det_jacobian) 100 | 101 | train_losses += losses 102 | total_num_sample += 1 103 | del loss, z, prior_likelihood, log_det_jacobian 104 | loginfos = combine_dict(loginf, loginfos) 105 | 106 | # dt = time.time() - t_s 107 | train_losses /= total_num_sample 108 | lr = optimizer.param_groups[0]['lr'] 109 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, 
train_losses)])
110 | 
111 |     # average cost of log time 20s
112 |     for name, loss in zip(loss_names, train_losses):
113 |         tb_logger.add_scalars(name, {'val': loss}, epoch)
114 |     logger.info('====> Epoch: {} Val Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr))
115 | 
116 |     t_s = time.time()
117 |     generator = dataset_test.sampling_generator(num_samples=cfg.num_data_sample // 2, batch_size=cfg.batch_size)
118 |     train_losses, total_num_sample, loginfos_test = 0, 0, None  # reset accumulators so the test losses are not mixed with the validation ones
119 |     with torch.no_grad():
120 |         xx = []
121 |         for traj_np in generator:
122 |             traj_np = traj_np[:, 0]
123 |             traj_np[:, 0] = 0
124 |             traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents())
125 |             traj = tensor(traj_np, device=device, dtype=dtype)  # .permute(0, 2, 1).contiguous()
126 |             bs, nj, _ = traj.shape
127 |             x = traj.reshape([bs, -1])
128 | 
129 |             z, log_det_jacobian = model(x)
130 |             prior_likelihood = prior.log_prob(z).sum(dim=1)
131 |             loginf = {}
132 |             loginf[f'z'] = z.cpu().data.numpy()
133 | 
134 |             loss, losses = loss_function(prior_likelihood, log_det_jacobian)
135 | 
136 |             train_losses += losses
137 |             total_num_sample += 1
138 |             del loss, z, prior_likelihood, log_det_jacobian
139 |             loginfos_test = combine_dict(loginf, loginfos_test)
140 | 
141 |     # dt = time.time() - t_s
142 |     train_losses /= total_num_sample
143 |     lr = optimizer.param_groups[0]['lr']
144 |     losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)])
145 | 
146 |     # average cost of log time 20s
147 |     for name, loss in zip(loss_names, train_losses):
148 |         tb_logger.add_scalars(name, {'test': loss}, epoch)
149 |     logger.info('====> Epoch: {} Test Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr))
150 | 
151 |     t_s = time.time()
152 |     zz = loginfos['z']
153 |     zz = zz.reshape([zz.shape[0], -1])
154 |     bs, data_dim = zz.shape
155 |     zz_test = loginfos_test['z']
156 |     zz_test = zz_test.reshape([zz_test.shape[0], -1])
157 |     # bs, data_dim = zz_test.shape
158 |     for ii in range(data_dim):
159 |         fig = plt.figure()
160 |         ax1 = plt.subplot(121)
161 |         _ = plt.hist(zz[:, ii].reshape(-1), bins=100, density=True, alpha=0.5, color='b')
162 |         x = torch.from_numpy(np.arange(-5, 5, 0.01)).float().to(device)
163 |         y = prior.cdf(x)
164 |         x = x[1:]
165 |         y = (y[1:] - y[:-1]) * 100
166 |         plt.plot(x.cpu().data, y.cpu().data)
167 |         ax1.set_title(f'z_val_{ii}')
168 |         ax2 = plt.subplot(122)
169 |         _ = plt.hist(zz_test[:, ii].reshape(-1), bins=100, density=True, alpha=0.5, color='b')
170 |         x = torch.from_numpy(np.arange(-5, 5, 0.01)).float().to(device)
171 |         y = prior.cdf(x)
172 |         x = x[1:]
173 |         y = (y[1:] - y[:-1]) * 100
174 |         plt.plot(x.cpu().data, y.cpu().data)
175 |         ax2.set_title(f'z_test_{ii}')
176 |         tb_logger.add_figure(f'z_{ii}', fig, epoch)
177 |         plt.clf()
178 |         plt.cla()
179 |         plt.close(fig)
180 | 
181 |     # plot covariance matrix
182 |     fig = plt.figure()
183 |     ax1 = plt.subplot(121)
184 |     # zz = zz.reshape([bs, -1])
185 |     zz = zz - zz.mean(axis=0)
186 |     cov = 1 / bs * np.abs(np.matmul(zz.transpose([1, 0]), zz))
187 |     std = np.sqrt(np.diag(cov))[:, None] * np.sqrt(np.diag(cov))[None, :]
188 |     corr = cov / (std + 1e-10) - np.eye(cov.shape[0])
189 |     plt.imshow(corr)
190 |     plt.colorbar()
191 |     ax1.set_title('z_val_corr')
192 | 
193 |     ax2 = plt.subplot(122)
194 |     # zz = zz.reshape([bs, -1])
195 |     zz_test = zz_test - zz_test.mean(axis=0)
196 |     cov = 1 / zz_test.shape[0] * np.abs(np.matmul(zz_test.transpose([1, 0]), zz_test))
197 |     std = np.sqrt(np.diag(cov))[:, None] * np.sqrt(np.diag(cov))[None, :]
198 |     corr = cov / (std + 1e-10) - np.eye(cov.shape[0])
199 |     plt.imshow(corr)
200 |     plt.colorbar()
201 |     ax2.set_title('z_test_corr')
202 |     tb_logger.add_figure('z_corr', plt.gcf(), epoch)
203 |     plt.close(fig)
204 |     print(f'>>>>log time {time.time() - t_s:.3f}')
205 | 
206 | 
207 | if __name__ == '__main__':
208 | 
209 |     parser = argparse.ArgumentParser()
210 |     parser.add_argument('--cfg', default='h36m')  # base config name; '_nf' is appended when selecting the model below
211 |     parser.add_argument('--mode', default='train')
212 |     parser.add_argument('--test', action='store_true', default=False)
213 |     parser.add_argument('--iter', type=int, default=0)
214 |     parser.add_argument('--seed', type=int, default=0)
215 |     parser.add_argument('--gpu_index', type=int, default=1)
216 |     parser.add_argument('--n_pre', type=int, default=10)
217 |     parser.add_argument('--n_his', type=int, default=5)
218 |     parser.add_argument('--trial', type=int, default=1)
219 |     parser.add_argument('--num_coupling_layer', type=int, default=6)
220 |     parser.add_argument('--nz', type=int, default=10)
221 |     args = parser.parse_args()
222 | 
223 |     """setup"""
224 |     np.random.seed(args.seed)
225 |     torch.manual_seed(args.seed)
226 |     dtype = torch.float32
227 |     torch.set_default_dtype(dtype)
228 |     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
229 |     # if torch.cuda.is_available():
230 |     #     torch.cuda.set_device(args.gpu_index)
231 |     cfg = Config(f'{args.cfg}', test=args.test, nf=True)
232 |     tb_logger = SummaryWriter(cfg.tb_dir) if args.mode == 'train' else None
233 |     logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
234 | 
235 |     """parameter"""
236 |     mode = args.mode
237 |     nz = cfg.nz
238 |     t_his = cfg.t_his
239 |     t_pred = cfg.t_pred
240 |     cfg.n_his = args.n_his
241 |     cfg.n_pre = args.n_pre
242 |     cfg.num_coupling_layer = args.num_coupling_layer
243 |     cfg.nz = args.nz
244 |     """data"""
245 |     dataset_cls = DatasetH36M if cfg.dataset == 'h36m' else DatasetHumanEva
246 |     dataset = dataset_cls('train', t_his, t_pred, actions='all', use_vel=cfg.use_vel)
247 |     dataset_test = dataset_cls('test', t_his, t_pred, actions='all', use_vel=cfg.use_vel)
248 |     if cfg.normalize_data:
249 |         dataset.normalize_data()
250 |         dataset_test.normalize_data(dataset.mean, dataset.std)
251 | 
252 |     """model"""
253 |     model = get_model(cfg, dataset, args.cfg + '_nf')
254 |     print(model)
255 |     model.float()
256 |     optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-3)
257 |     scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=cfg.num_epoch_fix, nepoch=cfg.num_epoch)
258 |     logger.info(">>> total params: {:.5f}M".format(sum(p.numel() for p in list(model.parameters())) / 1000000.0))
259 | 
260 |     if args.iter > 0:
261 |         cp_path = cfg.model_path % args.iter
262 |         print('loading model from checkpoint: %s' % cp_path)
263 |         model_cp = pickle.load(open(cp_path, "rb"))
264 |         model.load_state_dict(model_cp['model_dict'])
265 | 
266 |     if mode == 'train':
267 |         model.to(device)
268 |         overall_iter = 0
269 |         for i in range(args.iter, cfg.num_epoch):
270 |             model.train()
271 |             train(i)
272 |             model.eval()
273 |             val(i)
274 |             # test(i)
275 |             if cfg.save_model_interval > 0 and (i + 1) % cfg.save_model_interval == 0:
276 |                 with to_cpu(model):
277 |                     cp_path = cfg.model_path % (i + 1)
278 |                     model_cp = {'model_dict': model.state_dict(), 'meta': {'std': dataset.std, 'mean': dataset.mean}}
279 |                     pickle.dump(model_cp, open(cp_path, 'wb'))
280 | 
--------------------------------------------------------------------------------
/utils/valid_angle_check.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | 5 | 6 | def h36m_valid_angle_check(p3d): 7 | """ 8 | p3d: [bs,16,3] or [bs,48] 9 | """ 10 | if p3d.shape[-1] == 48: 11 | p3d = p3d.reshape([p3d.shape[0], 16, 3]) 12 | 13 | cos_func = lambda p1, p2: np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 14 | data_all = p3d 15 | valid_cos = {} 16 | # Spine2LHip 17 | p1 = data_all[:, 3] 18 | p2 = data_all[:, 6] 19 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 20 | # Spine2RHip 21 | p1 = data_all[:, 0] 22 | p2 = data_all[:, 6] 23 | cos_gt_r = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 24 | valid_cos['Spine2Hip'] = np.vstack((cos_gt_l, cos_gt_r)) 25 | 26 | # LLeg2LeftHipPlane 27 | p0 = data_all[:, 3] 28 | p1 = data_all[:, 4] - data_all[:, 3] 29 | p2 = data_all[:, 5] - data_all[:, 4] 30 | n0 = np.cross(p0, p1) 31 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1) 32 | # RLeg2RHipPlane 33 | p0 = data_all[:, 0] 34 | p1 = data_all[:, 1] - data_all[:, 0] 35 | p2 = data_all[:, 2] - data_all[:, 1] 36 | n0 = np.cross(p1, p0) 37 | cos_gt_r = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1) 38 | valid_cos['Leg2HipPlane'] = np.vstack((cos_gt_l, cos_gt_r)) 39 | 40 | # Shoulder2Hip 41 | p1 = data_all[:, 10] - data_all[:, 7] 42 | p2 = data_all[:, 3] 43 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 44 | p1 = data_all[:, 13] - data_all[:, 7] 45 | p2 = data_all[:, 0] 46 | cos_gt_r = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 47 | valid_cos['Shoulder2Hip'] = np.vstack((cos_gt_l, cos_gt_r)) 48 | 49 | # Leg2ShoulderPlane 50 | p0 = data_all[:, 13] 51 | p1 = data_all[:, 10] 52 | p2 = data_all[:, 4] 53 | p3 = data_all[:, 1] 54 | n0 = np.cross(p0, p1) 55 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1) 56 | cos_gt_r = np.sum(n0 * p3, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p3, axis=1) 57 | valid_cos['Leg2ShoulderPlane'] = np.vstack((cos_gt_l, cos_gt_r)) 58 | 59 | # Shoulder2Shoulder 60 | p0 = data_all[:, 13] - data_all[:, 7] 61 | p1 = data_all[:, 10] - data_all[:, 7] 62 | cos_gt = np.sum(p0 * p1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(p1, axis=1) 63 | valid_cos['Shoulder2Shoulder'] = cos_gt 64 | 65 | # Neck2Spine 66 | p0 = data_all[:, 7] - data_all[:, 6] 67 | p1 = data_all[:, 6] 68 | cos_gt = np.sum(p0 * p1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(p1, axis=1) 69 | valid_cos['Neck2Spine'] = cos_gt 70 | 71 | # Spine2HipPlane1 72 | p0 = data_all[:, 3] 73 | p1 = data_all[:, 4] - data_all[:, 3] 74 | n0 = np.cross(p1, p0) 75 | p2 = data_all[:, 6] 76 | n1 = np.cross(p2, n0) 77 | cos_dir_l = np.sum(p0 * n1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(n1, axis=1) 78 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1) 79 | p0 = data_all[:, 0] 80 | p1 = data_all[:, 1] - data_all[:, 0] 81 | n0 = np.cross(p0, p1) 82 | p2 = data_all[:, 6] 83 | n1 = np.cross(n0, p2) 84 | cos_dir_r = np.sum(p0 * n1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(n1, axis=1) 85 | cos_gt_r = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1) 86 | cos_gt_l1 = np.ones_like(cos_gt_l) * 0.5 87 | cos_gt_r1 = 
np.ones_like(cos_gt_r) * 0.5 88 | cos_gt_l1[cos_dir_l < 0] = cos_gt_l[cos_dir_l < 0] 89 | cos_gt_r1[cos_dir_r < 0] = cos_gt_r[cos_dir_r < 0] 90 | valid_cos['Spine2HipPlane1'] = np.vstack((cos_gt_l1, cos_gt_r1)) 91 | 92 | # Spine2HipPlane2 93 | cos_gt_l2 = np.ones_like(cos_gt_l) * 0.5 94 | cos_gt_r2 = np.ones_like(cos_gt_r) * 0.5 95 | cos_gt_l2[cos_dir_l >= 0] = cos_gt_l[cos_dir_l >= 0] 96 | cos_gt_r2[cos_dir_r >= 0] = cos_gt_r[cos_dir_r >= 0] 97 | valid_cos['Spine2HipPlane2'] = np.vstack((cos_gt_l2, cos_gt_r2)) 98 | 99 | # ShoulderPlane2HipPlane (25 Jan) 100 | p1 = data_all[:, 7] - data_all[:, 3] 101 | p2 = data_all[:, 7] - data_all[:, 0] 102 | p3 = data_all[:, 10] 103 | p4 = data_all[:, 13] 104 | n0 = np.cross(p2, p1) 105 | n1 = np.cross(p3, p4) 106 | cos_gt_l = np.sum(n0 * n1, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(n1, axis=1) 107 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l 108 | 109 | # Head2Neck 110 | p1 = data_all[:, 7] - data_all[:, 6] 111 | p2 = data_all[:, 8] - data_all[:, 7] 112 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 113 | valid_cos['Head2Neck'] = cos_gt_l 114 | 115 | # Head2HeadTop 116 | p1 = data_all[:, 9] - data_all[:, 8] 117 | p2 = data_all[:, 8] - data_all[:, 7] 118 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 119 | valid_cos['Head2HeadTop'] = cos_gt_l 120 | 121 | # HeadVerticalPlane2HipPlane 122 | p1 = data_all[:, 9] - data_all[:, 8] 123 | p2 = data_all[:, 8] - data_all[:, 7] 124 | n0 = np.cross(p1, p2) 125 | p3 = data_all[:, 9] - data_all[:, 7] 126 | n1 = np.cross(n0, p3) 127 | p4 = data_all[:, 7] - data_all[:, 0] 128 | p5 = data_all[:, 7] - data_all[:, 3] 129 | n2 = np.cross(p4, p5) 130 | cos_gt_l = cos_func(n1, n2) 131 | valid_cos['HeadVerticalPlane2HipPlane'] = cos_gt_l 132 | 133 | # Shoulder2Neck 134 | p1 = data_all[:, 10] - data_all[:, 7] 135 | p2 = data_all[:, 6] - data_all[:, 7] 136 | cos_gt_l = cos_func(p1, p2) 137 | p1 = data_all[:, 13] - data_all[:, 7] 138 | p2 = data_all[:, 6] - data_all[:, 7] 139 | cos_gt_r = cos_func(p1, p2) 140 | valid_cos['Shoulder2Neck'] = np.vstack((cos_gt_l, cos_gt_r)) 141 | 142 | return valid_cos 143 | 144 | 145 | def h36m_valid_angle_check_torch(p3d): 146 | """ 147 | p3d: [bs,16,3] or [bs,48] 148 | """ 149 | if p3d.shape[-1] == 48: 150 | p3d = p3d.reshape([p3d.shape[0], 16, 3]) 151 | data_all = p3d 152 | cos_func = lambda p1, p2: torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 153 | 154 | valid_cos = {} 155 | # Spine2LHip 156 | p1 = data_all[:, 3] 157 | p2 = data_all[:, 6] 158 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 159 | # Spine2RHip 160 | p1 = data_all[:, 0] 161 | p2 = data_all[:, 6] 162 | cos_gt_r = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 163 | valid_cos['Spine2Hip'] = torch.vstack((cos_gt_l, cos_gt_r)) 164 | 165 | # LLeg2LeftHipPlane 166 | p0 = data_all[:, 3] 167 | p1 = data_all[:, 4] - data_all[:, 3] 168 | p2 = data_all[:, 5] - data_all[:, 4] 169 | n0 = torch.cross(p0, p1, dim=1) 170 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1) 171 | # RLeg2RHipPlane 172 | p0 = data_all[:, 0] 173 | p1 = data_all[:, 1] - data_all[:, 0] 174 | p2 = data_all[:, 2] - data_all[:, 1] 175 | n0 = torch.cross(p1, p0) 176 | cos_gt_r = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1) 177 | valid_cos['Leg2HipPlane'] = torch.vstack((cos_gt_l, cos_gt_r)) 178 | 179 | # 
Shoulder2Hip 180 | p1 = data_all[:, 10] - data_all[:, 7] 181 | p2 = data_all[:, 3] 182 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 183 | p1 = data_all[:, 13] - data_all[:, 7] 184 | p2 = data_all[:, 0] 185 | cos_gt_r = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 186 | valid_cos['Shoulder2Hip'] = torch.vstack((cos_gt_l, cos_gt_r)) 187 | 188 | # Leg2ShoulderPlane 189 | p0 = data_all[:, 13] 190 | p1 = data_all[:, 10] 191 | p2 = data_all[:, 4] 192 | p3 = data_all[:, 1] 193 | n0 = torch.cross(p0, p1) 194 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1) 195 | cos_gt_r = torch.sum(n0 * p3, dim=1) / torch.norm(n0, dim=1) / torch.norm(p3, dim=1) 196 | valid_cos['Leg2ShoulderPlane'] = torch.vstack((cos_gt_l, cos_gt_r)) 197 | 198 | # Shoulder2Shoulder 199 | p0 = data_all[:, 13] - data_all[:, 7] 200 | p1 = data_all[:, 10] - data_all[:, 7] 201 | cos_gt = torch.sum(p0 * p1, dim=1) / torch.norm(p0, dim=1) / torch.norm(p1, dim=1) 202 | valid_cos['Shoulder2Shoulder'] = cos_gt 203 | 204 | # Neck2Spine 205 | p0 = data_all[:, 7] - data_all[:, 6] 206 | p1 = data_all[:, 6] 207 | cos_gt = torch.sum(p0 * p1, dim=1) / torch.norm(p0, dim=1) / torch.norm(p1, dim=1) 208 | valid_cos['Neck2Spine'] = cos_gt 209 | 210 | # Spine2HipPlane1 211 | p0 = data_all[:, 3] 212 | p1 = data_all[:, 4] - data_all[:, 3] 213 | n0 = torch.cross(p1, p0) 214 | p2 = data_all[:, 6] 215 | n1 = torch.cross(p2, n0) 216 | cos_dir_l = torch.sum(p0 * n1, dim=1) / torch.norm(p0, dim=1) / torch.norm(n1, dim=1) 217 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1) 218 | p0 = data_all[:, 0] 219 | p1 = data_all[:, 1] - data_all[:, 0] 220 | n0 = torch.cross(p0, p1) 221 | p2 = data_all[:, 6] 222 | n1 = torch.cross(n0, p2) 223 | cos_dir_r = torch.sum(p0 * n1, dim=1) / torch.norm(p0, dim=1) / torch.norm(n1, dim=1) 224 | cos_gt_r = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1) 225 | cos_gt_l1 = cos_gt_l[cos_dir_l < 0] 226 | cos_gt_r1 = cos_gt_r[cos_dir_r < 0] 227 | valid_cos['Spine2HipPlane1'] = torch.hstack((cos_gt_l1, cos_gt_r1)) 228 | 229 | # Spine2HipPlane2 230 | cos_gt_l2 = cos_gt_l[cos_dir_l >= 0] 231 | cos_gt_r2 = cos_gt_r[cos_dir_r >= 0] 232 | valid_cos['Spine2HipPlane2'] = torch.hstack((cos_gt_l2, cos_gt_r2)) 233 | 234 | # ShoulderPlane2HipPlane (25 Jan) 235 | p1 = data_all[:, 7] - data_all[:, 3] 236 | p2 = data_all[:, 7] - data_all[:, 0] 237 | p3 = data_all[:, 10] 238 | p4 = data_all[:, 13] 239 | n0 = torch.cross(p2, p1) 240 | n1 = torch.cross(p3, p4) 241 | cos_gt_l = torch.sum(n0 * n1, dim=1) / torch.norm(n0, dim=1) / torch.norm(n1, dim=1) 242 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l 243 | 244 | # Head2Neck 245 | p1 = data_all[:, 7] - data_all[:, 6] 246 | p2 = data_all[:, 8] - data_all[:, 7] 247 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 248 | valid_cos['Head2Neck'] = cos_gt_l 249 | 250 | # Head2HeadTop 251 | p1 = data_all[:, 9] - data_all[:, 8] 252 | p2 = data_all[:, 8] - data_all[:, 7] 253 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 254 | valid_cos['Head2HeadTop'] = cos_gt_l 255 | 256 | # HeadVerticalPlane2HipPlane 257 | p1 = data_all[:, 9] - data_all[:, 8] 258 | p2 = data_all[:, 8] - data_all[:, 7] 259 | n0 = torch.cross(p1, p2) 260 | p3 = data_all[:, 9] - data_all[:, 7] 261 | n1 = torch.cross(n0, p3) 262 | p4 = data_all[:, 7] - data_all[:, 0] 263 | p5 = data_all[:, 7] - 
data_all[:, 3] 264 | n2 = torch.cross(p4, p5) 265 | cos_gt_l = cos_func(n1, n2) 266 | valid_cos['HeadVerticalPlane2HipPlane'] = cos_gt_l 267 | 268 | # Shoulder2Neck 269 | p1 = data_all[:, 10] - data_all[:, 7] 270 | p2 = data_all[:, 6] - data_all[:, 7] 271 | cos_gt_l = cos_func(p1, p2) 272 | p1 = data_all[:, 13] - data_all[:, 7] 273 | p2 = data_all[:, 6] - data_all[:, 7] 274 | cos_gt_r = cos_func(p1, p2) 275 | valid_cos['Shoulder2Neck'] = torch.vstack((cos_gt_l, cos_gt_r)) 276 | 277 | return valid_cos 278 | 279 | 280 | def humaneva_valid_angle_check(p3d): 281 | """ 282 | p3d: [bs,14,3] or [bs,42] 283 | """ 284 | if p3d.shape[-1] == 42: 285 | p3d = p3d.reshape([p3d.shape[0], 14, 3]) 286 | 287 | cos_func = lambda p1, p2: np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1) 288 | data_all = p3d 289 | valid_cos = {} 290 | 291 | # LHip2RHip 292 | p1 = data_all[:, 7] 293 | p2 = data_all[:, 10] 294 | cos_gt_l = cos_func(p1, p2) 295 | valid_cos['LHip2RHip'] = cos_gt_l 296 | 297 | # Neck2HipPlane 298 | p1 = data_all[:, 7] 299 | p2 = data_all[:, 10] 300 | n0 = np.cross(p1, p2) 301 | p3 = data_all[:, 0] 302 | cos_gt_l = cos_func(n0, p3) 303 | valid_cos['Neck2HipPlane'] = cos_gt_l 304 | 305 | # Head2Neck 306 | p1 = data_all[:, 13] - data_all[:, 0] 307 | p2 = data_all[:, 0] 308 | cos_gt_l = cos_func(p1, p2) 309 | valid_cos['Head2Neck'] = cos_gt_l 310 | 311 | # Shoulder2Shoulder 312 | p1 = data_all[:, 1] - data_all[:, 0] 313 | p2 = data_all[:, 4] - data_all[:, 0] 314 | cos_gt_l = cos_func(p1, p2) 315 | valid_cos['Shoulder2Shoulder'] = cos_gt_l 316 | 317 | # ShoulderPlane2HipPlane 318 | p1 = data_all[:, 7] - data_all[:, 0] 319 | p2 = data_all[:, 10] - data_all[:, 0] 320 | n0 = np.cross(p1, p2) 321 | p3 = data_all[:, 1] 322 | p4 = data_all[:, 4] 323 | n1 = np.cross(p3, p4) 324 | cos_gt_l = cos_func(n0, n1) 325 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l 326 | 327 | # Shoulder2Neck 328 | p1 = data_all[:, 1] - data_all[:, 0] 329 | p2 = data_all[:, 0] 330 | cos_gt_l = cos_func(p1, p2) 331 | p1 = data_all[:, 4] - data_all[:, 0] 332 | p2 = data_all[:, 0] 333 | cos_gt_r = cos_func(p1, p2) 334 | valid_cos['Shoulder2Neck'] = np.vstack((cos_gt_l, cos_gt_r)) 335 | 336 | # Leg2HipPlane 337 | p1 = data_all[:, 7] 338 | p2 = data_all[:, 10] 339 | n0 = np.cross(p1, p2) 340 | p3 = data_all[:, 8] - data_all[:, 7] 341 | cos_gt_l = cos_func(n0, p3) 342 | p3 = data_all[:, 11] - data_all[:, 10] 343 | cos_gt_r = cos_func(n0, p3) 344 | valid_cos['Leg2HipPlane'] = np.vstack((cos_gt_l, cos_gt_r)) 345 | 346 | # Foot2LegPlane 347 | p1 = data_all[:, 7] - data_all[:, 10] 348 | p2 = data_all[:, 11] - data_all[:, 10] 349 | n0 = np.cross(p1, p2) 350 | p3 = data_all[:, 12] - data_all[:, 11] 351 | cos_gt_l = cos_func(n0, p3) 352 | p1 = data_all[:, 7] - data_all[:, 10] 353 | p2 = data_all[:, 8] - data_all[:, 7] 354 | n0 = np.cross(p1, p2) 355 | p3 = data_all[:, 9] - data_all[:, 8] 356 | cos_gt_r = cos_func(n0, p3) 357 | valid_cos['Foot2LegPlane'] = np.vstack((cos_gt_l, cos_gt_r)) 358 | 359 | # ForeArm2ShoulderPlane 360 | p1 = data_all[:, 4] - data_all[:, 0] 361 | p2 = data_all[:, 5] - data_all[:, 4] 362 | n0 = np.cross(p1, p2) 363 | p3 = data_all[:, 6] - data_all[:, 5] 364 | cos_gt_l = cos_func(n0, p3) 365 | p1 = data_all[:, 1] - data_all[:, 0] 366 | p2 = data_all[:, 2] - data_all[:, 1] 367 | n0 = np.cross(p2, p1) 368 | p3 = data_all[:, 3] - data_all[:, 2] 369 | cos_gt_r = cos_func(n0, p3) 370 | valid_cos['ForeArm2ShoulderPlane'] = np.vstack((cos_gt_l, cos_gt_r)) 371 | 372 | return valid_cos 373 | 374 | 375 
| def humaneva_valid_angle_check_torch(p3d): 376 | """ 377 | p3d: [bs,14,3] or [bs,42] 378 | """ 379 | if p3d.shape[-1] == 42: 380 | p3d = p3d.reshape([p3d.shape[0], 14, 3]) 381 | 382 | cos_func = lambda p1, p2: torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1) 383 | data_all = p3d 384 | valid_cos = {} 385 | 386 | # LHip2RHip 387 | p1 = data_all[:, 7] 388 | p2 = data_all[:, 10] 389 | cos_gt_l = cos_func(p1, p2) 390 | valid_cos['LHip2RHip'] = cos_gt_l 391 | 392 | # Neck2HipPlane 393 | p1 = data_all[:, 7] 394 | p2 = data_all[:, 10] 395 | n0 = torch.cross(p1, p2) 396 | p3 = data_all[:, 0] 397 | cos_gt_l = cos_func(n0, p3) 398 | valid_cos['Neck2HipPlane'] = cos_gt_l 399 | 400 | # Head2Neck 401 | p1 = data_all[:, 13] - data_all[:, 0] 402 | p2 = data_all[:, 0] 403 | cos_gt_l = cos_func(p1, p2) 404 | valid_cos['Head2Neck'] = cos_gt_l 405 | 406 | # Shoulder2Shoulder 407 | p1 = data_all[:, 1] - data_all[:, 0] 408 | p2 = data_all[:, 4] - data_all[:, 0] 409 | cos_gt_l = cos_func(p1, p2) 410 | valid_cos['Shoulder2Shoulder'] = cos_gt_l 411 | 412 | # ShoulderPlane2HipPlane 413 | p1 = data_all[:, 7] - data_all[:, 0] 414 | p2 = data_all[:, 10] - data_all[:, 0] 415 | n0 = torch.cross(p1, p2) 416 | p3 = data_all[:, 1] 417 | p4 = data_all[:, 4] 418 | n1 = torch.cross(p3, p4) 419 | cos_gt_l = cos_func(n0, n1) 420 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l 421 | 422 | # Shoulder2Neck 423 | p1 = data_all[:, 1] - data_all[:, 0] 424 | p2 = data_all[:, 0] 425 | cos_gt_l = cos_func(p1, p2) 426 | p1 = data_all[:, 4] - data_all[:, 0] 427 | p2 = data_all[:, 0] 428 | cos_gt_r = cos_func(p1, p2) 429 | valid_cos['Shoulder2Neck'] = torch.vstack((cos_gt_l, cos_gt_r)) 430 | 431 | # Leg2HipPlane 432 | p1 = data_all[:, 7] 433 | p2 = data_all[:, 10] 434 | n0 = torch.cross(p1, p2) 435 | p3 = data_all[:, 8] - data_all[:, 7] 436 | cos_gt_l = cos_func(n0, p3) 437 | p3 = data_all[:, 11] - data_all[:, 10] 438 | cos_gt_r = cos_func(n0, p3) 439 | valid_cos['Leg2HipPlane'] = torch.vstack((cos_gt_l, cos_gt_r)) 440 | 441 | # Foot2LegPlane 442 | p1 = data_all[:, 7] - data_all[:, 10] 443 | p2 = data_all[:, 11] - data_all[:, 10] 444 | n0 = torch.cross(p1, p2) 445 | p3 = data_all[:, 12] - data_all[:, 11] 446 | cos_gt_l = cos_func(n0, p3) 447 | p1 = data_all[:, 7] - data_all[:, 10] 448 | p2 = data_all[:, 8] - data_all[:, 7] 449 | n0 = torch.cross(p1, p2) 450 | p3 = data_all[:, 9] - data_all[:, 8] 451 | cos_gt_r = cos_func(n0, p3) 452 | valid_cos['Foot2LegPlane'] = torch.vstack((cos_gt_l, cos_gt_r)) 453 | 454 | # ForeArm2ShoulderPlane 455 | p1 = data_all[:, 4] - data_all[:, 0] 456 | p2 = data_all[:, 5] - data_all[:, 4] 457 | n0 = torch.cross(p1, p2) 458 | p3 = data_all[:, 6] - data_all[:, 5] 459 | cos_gt_l = cos_func(n0, p3) 460 | p1 = data_all[:, 1] - data_all[:, 0] 461 | p2 = data_all[:, 2] - data_all[:, 1] 462 | n0 = torch.cross(p2, p1) 463 | p3 = data_all[:, 3] - data_all[:, 2] 464 | cos_gt_r = cos_func(n0, p3) 465 | valid_cos['ForeArm2ShoulderPlane'] = torch.vstack((cos_gt_l, cos_gt_r)) 466 | 467 | return valid_cos 468 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import argparse 5 | import time 6 | from torch import maximum, optim 7 | from torch.utils.tensorboard import SummaryWriter 8 | import itertools 9 | sys.path.append(os.getcwd()) 10 | from utils import * 11 | from motion_pred.utils.config import Config 12 | from 
motion_pred.utils.dataset_h36m_multimodal import DatasetH36M 13 | from motion_pred.utils.dataset_humaneva_multimodal import DatasetHumanEva 14 | from motion_pred.utils.visualization import render_animation 15 | from models.motion_pred import * 16 | from utils import util, valid_angle_check 17 | from utils.metrics import * 18 | from scipy.spatial.distance import pdist, squareform 19 | from tqdm import tqdm 20 | import random 21 | import time 22 | 23 | def recon_loss(Y_g, Y, Y_mm, Y_hg=None, Y_h=None): 24 | 25 | stat = torch.zeros(Y_g.shape[2]) 26 | diff = Y_g - Y.unsqueeze(2) # TBMV 27 | dist = diff.pow(2).sum(dim=-1).sum(dim=0) # BM 28 | 29 | value, indices = dist.min(dim=1) 30 | 31 | loss_recon_1 = value.mean() 32 | 33 | diff = Y_hg - Y_h.unsqueeze(2) # TBMC 34 | loss_recon_2 = diff.pow(2).sum(dim=-1).sum(dim=0).mean() 35 | 36 | 37 | with torch.no_grad(): 38 | ade = torch.norm(diff, dim=-1).mean(dim=0).min(dim=1)[0].mean() 39 | 40 | diff = Y_g[:, :, :, None, :] - Y_mm[:, :, None, :, :] 41 | 42 | mask = Y_mm.abs().sum(-1).sum(0) > 1e-6 43 | 44 | dist = diff.pow(2) 45 | with torch.no_grad(): 46 | zeros = torch.zeros([dist.shape[1], dist.shape[2]], requires_grad=False).to(dist.device)# [b,m] 47 | zeros.scatter_(dim=1, index=indices.unsqueeze(1).repeat(1, dist.shape[2]), src=zeros+dist.max()-dist.min()+1) 48 | zeros = zeros.unsqueeze(0).unsqueeze(3).unsqueeze(4) 49 | dist += zeros 50 | dist = dist.sum(dim=-1).sum(dim=0) 51 | 52 | value_2, indices_2 = dist.min(dim=1) 53 | loss_recon_multi = value_2[mask].mean() 54 | if torch.isnan(loss_recon_multi): 55 | loss_recon_multi = torch.zeros_like(loss_recon_1) 56 | 57 | mask = torch.tril(torch.ones([cfg.nk, cfg.nk], device=device)) == 0 58 | # TBMC 59 | 60 | yt = Y_g.reshape([-1, cfg.nk, Y_g.shape[3]]).contiguous() 61 | pdist = torch.cdist(yt, yt, p=1)[:, mask] 62 | return loss_recon_1, loss_recon_2, loss_recon_multi, ade, stat, (-pdist / 100).exp().mean() 63 | 64 | 65 | def angle_loss(y): 66 | ang_names = list(valid_ang.keys()) 67 | y = y.reshape([-1, y.shape[-1]]) 68 | ang_cos = valid_angle_check.h36m_valid_angle_check_torch( 69 | y) if cfg.dataset == 'h36m' else valid_angle_check.humaneva_valid_angle_check_torch(y) 70 | loss = tensor(0, dtype=dtype, device=device) 71 | b = 1 72 | for an in ang_names: 73 | lower_bound = valid_ang[an][0] 74 | if lower_bound >= -0.98: 75 | # loss += torch.exp(-b * (ang_cos[an] - lower_bound)).mean() 76 | if torch.any(ang_cos[an] < lower_bound): 77 | # loss += b * torch.exp(-(ang_cos[an][ang_cos[an] < lower_bound] - lower_bound)).mean() 78 | loss += (ang_cos[an][ang_cos[an] < lower_bound] - lower_bound).pow(2).mean() 79 | upper_bound = valid_ang[an][1] 80 | if upper_bound <= 0.98: 81 | # loss += torch.exp(b * (ang_cos[an] - upper_bound)).mean() 82 | if torch.any(ang_cos[an] > upper_bound): 83 | # loss += b * torch.exp(ang_cos[an][ang_cos[an] > upper_bound] - upper_bound).mean() 84 | loss += (ang_cos[an][ang_cos[an] > upper_bound] - upper_bound).pow(2).mean() 85 | return loss 86 | 87 | 88 | def loss_function(traj_est, traj, traj_multimodal, prior_lkh, prior_logdetjac, _lambda): 89 | lambdas = cfg.lambdas 90 | nj = dataset.traj_dim // 3 91 | 92 | Y_g = traj_est.view(traj_est.shape[0], traj.shape[1], traj_est.shape[1]//traj.shape[1], -1)[t_his:] # T B M V 93 | Y = traj[t_his:] 94 | Y_multimodal = traj_multimodal[t_his:] 95 | Y_hg=traj_est.view(traj_est.shape[0], traj.shape[1], traj_est.shape[1]//traj.shape[1], -1)[:t_his] 96 | Y_h= traj[:t_his] 97 | RECON, RECON_2, RECON_mm, ade, stat, JL = recon_loss(Y_g, Y, 
Y_multimodal,Y_hg, Y_h) 98 | # maintain limb length 99 | parent = dataset.skeleton.parents() 100 | tmp = traj[0].reshape([cfg.batch_size, nj, 3]) 101 | pgt = torch.zeros([cfg.batch_size, nj + 1, 3], dtype=dtype, device=device) 102 | pgt[:, 1:] = tmp 103 | limbgt = torch.norm(pgt[:, 1:] - pgt[:, parent[1:]], dim=2)[None, :, None, :] 104 | tmp = traj_est.reshape([-1, cfg.batch_size, cfg.nk, nj, 3]) 105 | pest = torch.zeros([tmp.shape[0], cfg.batch_size, cfg.nk, nj + 1, 3], dtype=dtype, device=device) 106 | pest[:, :, :, 1:] = tmp 107 | limbest = torch.norm(pest[:, :, :, 1:] - pest[:, :, :, parent[1:]], dim=4) 108 | loss_limb = torch.mean((limbgt - limbest).pow(2).sum(dim=3)) 109 | 110 | # angle loss 111 | loss_ang = angle_loss(Y_g) 112 | if _lambda < 0.1: 113 | _lambda *= 10 114 | else: 115 | _lambda = 1 116 | 117 | loss_r = loss_limb * lambdas[1] + JL * lambdas[3] * _lambda + RECON * lambdas[4] + RECON_mm * lambdas[5] \ 118 | - prior_lkh.mean() * lambdas[6] + RECON_2 * lambdas[7]# - prior_logdetjac.mean() * lambdas[7] 119 | if loss_ang > 0: 120 | loss_r += loss_ang * lambdas[8] 121 | return loss_r, np.array([loss_r.item(), loss_limb.item(), loss_ang.item(), 122 | JL.item(), RECON.item(), RECON_2.item(), RECON_mm.item(), ade.item(), 123 | prior_lkh.mean().item(), prior_logdetjac.mean().item()]), stat#, indices_key, indices_2_key 124 | 125 | 126 | def dct_transform_torch(data, dct_m, dct_n): 127 | ''' 128 | B, 60, 35 129 | ''' 130 | batch_size, features, seq_len = data.shape 131 | 132 | data = data.contiguous().view(-1, seq_len) # [180077*60, 35] 133 | data = data.permute(1, 0) # [35, b*60] 134 | 135 | out_data = torch.matmul(dct_m[:dct_n, :], data) # [dct_n, 180077*60] 136 | out_data = out_data.permute(1, 0).contiguous().view(-1, features, dct_n) # [b, 60, dct_n] 137 | return out_data 138 | 139 | def get_dct_matrix(N): 140 | dct_m = np.eye(N) 141 | for k in np.arange(N): 142 | for i in np.arange(N): 143 | w = np.sqrt(2 / N) 144 | if k == 0: 145 | w = np.sqrt(1 / N) 146 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N) 147 | idct_m = np.linalg.inv(dct_m) 148 | return dct_m, idct_m 149 | 150 | def train(epoch, stats): 151 | 152 | dct_m, i_dct_m = get_dct_matrix(cfg.t_his+cfg.t_pred) 153 | dct_m = torch.from_numpy(dct_m).float().to(device) 154 | i_dct_m = torch.from_numpy(i_dct_m).float().to(device) 155 | 156 | model.train() 157 | t_s = time.time() 158 | train_losses = 0 159 | train_grad = 0 160 | total_num_sample = 0 161 | n_modality = 10 162 | loss_names = ['LOSS', 'loss_limb', 'loss_ang', 'loss_DIV', 163 | 'RECON', 'RECON_2', 'RECON_multi', "ADE", 'p(z)', 'logdet'] 164 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample, batch_size=cfg.batch_size, 165 | n_modality=n_modality) 166 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device), 167 | torch.tensor(1, dtype=dtype, device=device)) 168 | 169 | for traj_np, traj_multimodal_np in tqdm(generator): 170 | with torch.no_grad(): 171 | 172 | bs, _, nj, _ = traj_np[..., 1:, :].shape # [bs, t_full, numJoints, 3] 173 | traj_np = traj_np[..., 1:, :].reshape(traj_np.shape[0], traj_np.shape[1], -1) # bs, T, NumJoints*3 174 | traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous() # T, bs, NumJoints*3 175 | 176 | traj_multimodal_np = traj_multimodal_np[..., 1:, :] # [bs, n_modality, t_full, NumJoints, 3] 177 | traj_multimodal_np = traj_multimodal_np.reshape([bs, n_modality, t_his + t_pred, -1]).transpose( 178 | [2, 0, 1, 3]) # [t_full, bs, n_modality, NumJoints*3] 179 | 
traj_multimodal = tensor(traj_multimodal_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous() 180 | 181 | X = traj[:t_his] 182 | Y = traj[t_his:] 183 | 184 | pred, a, b = model(traj) 185 | 186 | pred_tmp1 = pred.reshape([-1, pred.shape[-1] // 3, 3]) 187 | pred_tmp = torch.zeros_like(pred_tmp1[:, :1, :]) 188 | pred_tmp1 = torch.cat([pred_tmp, pred_tmp1], dim=1) 189 | pred_tmp1 = util.absolute2relative_torch(pred_tmp1, parents=dataset.skeleton.parents()).reshape( 190 | [-1, pred.shape[-1]]) 191 | z, prior_logdetjac = pose_prior(pred_tmp1) 192 | prior_lkh = prior.log_prob(z).sum(dim=-1) 193 | 194 | loss, losses, stat = loss_function(pred.unsqueeze(2), traj, traj_multimodal, prior_lkh, prior_logdetjac, epoch / cfg.num_epoch) 195 | 196 | optimizer.zero_grad() 197 | loss.backward() 198 | grad_norm = torch.nn.utils.clip_grad_norm_(list(model.parameters()), max_norm=100) 199 | train_grad += grad_norm 200 | optimizer.step() 201 | train_losses += losses 202 | 203 | total_num_sample += 1 204 | del loss 205 | 206 | scheduler.step() 207 | train_losses /= total_num_sample 208 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)]) 209 | lr = optimizer.param_groups[0]['lr'] 210 | # average cost of log time 20s 211 | tb_logger.add_scalar('train_grad', train_grad / total_num_sample, epoch) 212 | 213 | logger.info('====> Epoch: {} Time: {:.2f} {} lr: {:.5f} branch_stats: {}'.format(epoch, time.time() - t_s, losses_str , lr, stats)) 214 | 215 | return stats 216 | 217 | def get_multimodal_gt(dataset_test): 218 | all_data = [] 219 | data_gen = dataset_test.iter_generator(step=cfg.t_his) 220 | for data, _ in tqdm(data_gen): 221 | # print(data.shape) 222 | data = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1) 223 | all_data.append(data) 224 | all_data = np.concatenate(all_data, axis=0) 225 | all_start_pose = all_data[:, t_his - 1, :] 226 | pd = squareform(pdist(all_start_pose)) 227 | traj_gt_arr = [] 228 | num_mult = [] 229 | for i in range(pd.shape[0]): 230 | ind = np.nonzero(pd[i] < args.multimodal_threshold) 231 | traj_gt_arr.append(all_data[ind][:, t_his:, :]) 232 | num_mult.append(len(ind[0])) 233 | 234 | num_mult = np.array(num_mult) 235 | logger.info('') 236 | logger.info('') 237 | logger.info('=' * 80) 238 | logger.info(f'#1 future: {len(np.where(num_mult == 1)[0])}/{pd.shape[0]}') 239 | logger.info(f'#<10 future: {len(np.where(num_mult < 10)[0])}/{pd.shape[0]}') 240 | return traj_gt_arr 241 | 242 | def get_prediction(data, model, sample_num, num_seeds=1, concat_hist=True): 243 | # 1 * total_len * num_key * 3 244 | dct_m, i_dct_m = get_dct_matrix(cfg.t_his+cfg.t_pred) 245 | dct_m = torch.from_numpy(dct_m).float().to(device) 246 | i_dct_m = torch.from_numpy(i_dct_m).float().to(device) 247 | traj_np = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1) 248 | # 1 * total_len * ((num_key-1)*3) 249 | traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous() 250 | # total_len * 1 * ((num_key-1)*3) 251 | X = traj[:t_his] 252 | Y_gt = traj[t_his:] 253 | X = X.repeat((1, sample_num * num_seeds, 1)) 254 | Y_gt = Y_gt.repeat((1, sample_num * num_seeds, 1)) 255 | 256 | 257 | Y, mu, logvar = model(X) 258 | 259 | if concat_hist: 260 | 261 | X = X.unsqueeze(2).repeat(1, sample_num * num_seeds, cfg.nk, 1) 262 | Y = Y[t_his:].unsqueeze(1) 263 | Y = torch.cat((X, Y), dim=0) 264 | # total_len * batch_size * feature_size 265 | Y = Y.squeeze(1).permute(1, 0, 2).contiguous().cpu().numpy() 266 | # batch_size * total_len * feature_size 
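# NOTE (sketch): get_prediction() returns every sampled future. The model emits
# cfg.nk candidate futures per input sequence, so after the permute above the
# leading axis of Y is batch-like; a minimal sketch of the intent of the
# reshape below, assuming B input sequences of length seq_len with feature
# size feat (names hypothetical):
#
#     Y = Y.reshape(B, cfg.nk * sample_num, seq_len, feat)
#
# For a single input sequence (Y.shape[0] == 1) only a leading batch axis is
# added via Y[None, ...].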
267 |     if Y.shape[0] > 1:
268 |         Y = Y.reshape(-1, cfg.nk * sample_num, Y.shape[-2], Y.shape[-1])
269 |     else:
270 |         Y = Y[None, ...]
271 |     return Y
272 | 
273 | 
274 | def test(model, epoch):
275 |     stats_func = {'Diversity': compute_diversity, 'AMSE': compute_amse, 'FMSE': compute_fmse, 'ADE': compute_ade,
276 |                   'FDE': compute_fde, 'MMADE': compute_mmade, 'MMFDE': compute_mmfde, 'MPJPE': mpjpe_error}
277 |     stats_names = list(stats_func.keys())
278 |     stats_names.extend(['ADE_stat', 'FDE_stat', 'MMADE_stat', 'MMFDE_stat', 'MPJPE_stat'])
279 |     stats_meter = {x: AverageMeter() for x in stats_names}
280 | 
281 |     data_gen = dataset_test.iter_generator(step=cfg.t_his)
282 |     num_samples = 0
283 |     num_seeds = 1
284 | 
285 |     for i, (data, _) in tqdm(enumerate(data_gen)):
286 |         if args.mode == 'train' and (i >= 500 and (epoch + 1) % 50 != 0 and (epoch + 1) < cfg.num_epoch - 100):
287 |             break
288 |         num_samples += 1
289 |         gt = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1)[:, t_his:, :]
290 |         gt_multi = traj_gt_arr[i]
291 |         if gt_multi.shape[0] == 1:
292 |             continue
293 | 
294 |         pred = get_prediction(data, model, sample_num=1, num_seeds=num_seeds, concat_hist=False)
295 |         pred = pred[:, :, t_his:, :]
296 |         for stats in stats_names[:8]:
297 |             val = 0
298 |             branches = 0
299 |             for pred_i in pred:
300 |                 # sample_num * total_len * ((num_key-1)*3), 1 * total_len * ((num_key-1)*3)
301 |                 v = stats_func[stats](pred_i, gt, gt_multi)
302 |                 val += v[0] / num_seeds
303 |                 if v[1] is not None:
304 |                     branches += v[1] / num_seeds
305 |             stats_meter[stats].update(val)
306 |             if type(branches) is not int:
307 |                 stats_meter[stats + '_stat'].update(branches)
308 | 
309 |     logger.info('=' * 80)
310 |     for stats in stats_names:
311 |         str_stats = f'Total {stats}: ' + f'{stats_meter[stats].avg}'
312 |         logger.info(str_stats)
313 |     logger.info('=' * 80)
314 | 
315 | 
316 | def visualize():
317 |     def denormalize(*data):
318 |         out = []
319 |         for x in data:
320 |             x = x * dataset.std + dataset.mean
321 |             out.append(x)
322 |         return out
323 | 
324 |     def post_process(pred, data):
325 |         pred = pred.reshape(pred.shape[0], pred.shape[1], -1, 3)
326 |         if cfg.normalize_data:
327 |             pred = denormalize(pred)[0]  # denormalize returns a list
328 |         pred = np.concatenate((np.tile(data[..., :1, :], (pred.shape[0], 1, 1, 1)), pred), axis=2)
329 |         pred[..., :1, :] = 0
330 |         return pred
331 | 
332 |     def pose_generator():
333 | 
334 |         while True:
335 |             data, data_multimodal, action = dataset_test.sample(n_modality=10)
336 |             gt = data[0].copy()
337 |             gt[:, :1, :] = 0
338 | 
339 |             poses = {'action': action, 'context': gt, 'gt': gt}
340 |             with torch.no_grad():
341 |                 pred = get_prediction(data, model, 1)[0]
342 |                 pred = post_process(pred, data)
343 |                 for i in range(pred.shape[0]):
344 |                     poses[f'{i}'] = pred[i]
345 | 
346 |             yield poses
347 | 
348 |     pose_gen = pose_generator()
349 |     for i in tqdm(range(args.n_viz)):
350 |         render_animation(dataset.skeleton, pose_gen, cfg.t_his, ncol=12, output='./results/{}/results/'.format(args.cfg), index_i=i)
351 | 
352 | 
353 | 
354 | 
355 | 
356 | if __name__ == '__main__':
357 | 
358 |     parser = argparse.ArgumentParser()
359 |     parser.add_argument('--cfg',
360 |                         default='h36m')
361 |     parser.add_argument('--mode', default='train')
362 |     parser.add_argument('--test', action='store_true', default=False)
363 |     parser.add_argument('--iter', type=int, default=0)
364 |     parser.add_argument('--seed', type=int, default=1)
365 |     parser.add_argument('--gpu_index', type=int, default=0)
366 |     parser.add_argument('--n_pre', type=int, default=8)
367 
| parser.add_argument('--n_his', type=int, default=5) 368 | parser.add_argument('--n_viz', type=int, default=100) 369 | parser.add_argument('--num_coupling_layer', type=int, default=4) 370 | parser.add_argument('--multimodal_threshold', type=float, default=0.5) 371 | args = parser.parse_args() 372 | 373 | """setup""" 374 | np.random.seed(args.seed) 375 | torch.manual_seed(args.seed) 376 | dtype = torch.float32 377 | torch.set_default_dtype(dtype) 378 | 379 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 380 | 381 | cfg = Config(f'{args.cfg}', test=args.test) 382 | tb_logger = SummaryWriter(cfg.tb_dir) if args.mode == 'train' else None 383 | logger = create_logger(os.path.join(cfg.log_dir, 'log.txt')) 384 | 385 | """parameter""" 386 | mode = args.mode 387 | nz = cfg.nz 388 | t_his = cfg.t_his 389 | t_pred = cfg.t_pred 390 | cfg.n_his = args.n_his 391 | if 'n_pre' not in cfg.specs.keys(): 392 | cfg.n_pre = args.n_pre 393 | else: 394 | cfg.n_pre = cfg.specs['n_pre'] 395 | cfg.num_coupling_layer = args.num_coupling_layer 396 | # cfg.nz = args.nz 397 | """data""" 398 | if 'actions' in cfg.specs.keys(): 399 | act = cfg.specs['actions'] 400 | else: 401 | act = 'all' 402 | dataset_cls = DatasetH36M if cfg.dataset == 'h36m' else DatasetHumanEva 403 | dataset = dataset_cls('train', t_his, t_pred, actions=act, use_vel=cfg.use_vel, 404 | multimodal_path=cfg.specs[ 405 | 'multimodal_path'] if 'multimodal_path' in cfg.specs.keys() else None, 406 | data_candi_path=cfg.specs[ 407 | 'data_candi_path'] if 'data_candi_path' in cfg.specs.keys() else None) 408 | dataset_test = dataset_cls('test', t_his, t_pred, actions=act, use_vel=cfg.use_vel, 409 | multimodal_path=cfg.specs[ 410 | 'multimodal_path'] if 'multimodal_path' in cfg.specs.keys() else None, 411 | data_candi_path=cfg.specs[ 412 | 'data_candi_path'] if 'data_candi_path' in cfg.specs.keys() else None) 413 | if cfg.normalize_data: 414 | dataset.normalize_data() 415 | dataset_test.normalize_data(dataset.mean, dataset.std) 416 | traj_gt_arr = get_multimodal_gt(dataset_test) 417 | """model""" 418 | model, pose_prior = get_model(cfg, dataset, cfg.dataset) 419 | 420 | model.float() 421 | pose_prior.float() 422 | 423 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr) 424 | 425 | scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=cfg.num_epoch_fix, nepoch=cfg.num_epoch) 426 | 427 | 428 | cp_path = 'results/h36m_nf/models/0025.p' if cfg.dataset == 'h36m' else 'results/humaneva_nf/models/0025.p' 429 | print('loading model from checkpoint: %s' % cp_path) 430 | model_cp = pickle.load(open(cp_path, "rb")) 431 | pose_prior.load_state_dict(model_cp['model_dict']) 432 | pose_prior.to(device) 433 | 434 | valid_ang = pickle.load(open('./data/h36m_valid_angle.p', "rb")) if cfg.dataset == 'h36m' else pickle.load( 435 | open('./data/humaneva_valid_angle.p', "rb")) 436 | if args.iter > 0: 437 | cp_path = cfg.model_path % args.iter 438 | print('loading model from checkpoint: %s' % cp_path) 439 | model_cp = pickle.load(open(cp_path, "rb")) 440 | model.load_state_dict(model_cp['model_dict']) 441 | print("load done") 442 | 443 | if mode == 'train': 444 | model.to(device) 445 | overall_iter = 0 446 | stats = torch.zeros(cfg.nk) 447 | model.train() 448 | 449 | for i in range(args.iter, cfg.num_epoch): 450 | stats = train(i, stats) 451 | if cfg.save_model_interval > 0 and (i + 1) % 10 == 0: 452 | model.eval() 453 | with torch.no_grad(): 454 | test(model, i) 455 | model.train() 456 | with to_cpu(model): 
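                    # (added note) `to_cpu` is presumably a context manager from utils.torch that
                    # moves parameters to CPU inside this block, so the pickled checkpoint can be
                    # loaded later on any device.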
457 |                     cp_path = cfg.model_path % (i + 1)
458 |                     model_cp = {'model_dict': model.state_dict(), 'meta': {'std': dataset.std, 'mean': dataset.mean}}
459 | 
460 |                     pickle.dump(model_cp, open(cp_path, 'wb'))
461 | 
462 |     elif mode == 'test':
463 |         model.to(device)
464 |         model.eval()
465 | 
466 |         with torch.no_grad():
467 |             test(model, args.iter)
468 | 
469 | 
470 |     elif mode == 'viz':
471 |         model.to(device)
472 |         model.eval()
473 |         with torch.no_grad():
474 |             visualize()
475 | 
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | from torch.nn import functional as F
9 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
10 |     return F.leaky_relu(input + bias, negative_slope) * scale
11 | class ST_GCNN_layer_down(nn.Module):
12 |     """
13 |     Shape:
14 |         - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
15 |         - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
16 |         - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
17 |         where
18 |             :math:`N` is a batch size,
19 |             :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
20 |             :math:`T_{in}/T_{out}` is a length of input/output sequence,
21 |             :math:`V` is the number of graph nodes.
22 |         :in_channels: dimension of input coordinates
23 |         :out_channels: dimension of output coordinates
24 | 
25 |     """
26 |     def __init__(self,
27 |                  in_channels,
28 |                  out_channels,
29 |                  kernel_size,
30 |                  stride,
31 |                  time_dim,
32 |                  joints_dim,
33 |                  dropout,
34 |                  bias=True,
35 |                  version=0,
36 |                  pose_info=None):
37 | 
38 |         super(ST_GCNN_layer_down, self).__init__()
39 |         self.kernel_size = kernel_size
40 |         # assert self.kernel_size[0] % 2 == 1
41 |         # assert self.kernel_size[1] % 2 == 1
42 |         # padding = ((self.kernel_size[0] - 1) // 2,(self.kernel_size[1] - 1) // 2)
43 |         padding = (0, 0)  # no padding: this downsampling variant shrinks the (T, V) grid
44 | 
45 |         if version == 0:
46 |             self.gcn = ConvTemporalGraphical(time_dim, joints_dim)  # the convolution layer
47 |         elif version == 1:
48 |             self.gcn = ConvTemporalGraphicalV1(time_dim, joints_dim, pose_info=pose_info)
49 |         if not isinstance(stride, list):
50 |             self.tcn = nn.Sequential(
51 |                 nn.Conv2d(
52 |                     in_channels,
53 |                     out_channels,
54 |                     (self.kernel_size[0], self.kernel_size[1]),
55 |                     (stride, stride),
56 |                     padding,
57 |                 ),
58 |                 nn.BatchNorm2d(out_channels),
59 |                 nn.Dropout(dropout, inplace=True),
60 |             )
61 |         else:
62 |             self.tcn = nn.Sequential(
63 |                 nn.Conv2d(
64 |                     in_channels,
65 |                     out_channels,
66 |                     (self.kernel_size[0], self.kernel_size[1]),
67 |                     (stride[0], stride[1]),
68 |                     padding,
69 |                 ),
70 |                 nn.BatchNorm2d(out_channels),
71 |                 nn.Dropout(dropout, inplace=True),
72 |             )
73 | 
74 | 
75 | 
76 |         if stride != 1 or in_channels != out_channels:
77 | 
78 |             self.residual = nn.Sequential(nn.Conv2d(
79 |                 in_channels,
80 |                 out_channels,
81 |                 kernel_size=1,
82 |                 stride=(1, 1)),
83 |                 nn.BatchNorm2d(out_channels),
84 |             )
85 | 
86 | 
87 |         else:
88 |             self.residual = nn.Identity()
89 | 
90 | 
91 |         self.prelu = nn.PReLU()
92 | 
93 | 
94 | 
95 |     def forward(self, x):
96 |         # assert A.shape[0] == self.kernel_size[1], print(A.shape[0],self.kernel_size)
97 |         res = self.residual(x)  # computed but intentionally not added back in this variant
98 |         x = self.gcn(x)
99 |         x = self.tcn(x)
100 |         # x = x + res
101 |         x = self.prelu(x)
102 |         return x
103 | def count_parameters(model):
104 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
105 | 
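# (added note) EqualLinear below appears to be the StyleGAN2-style equalized
# learning-rate linear layer: weights are stored at unit scale and rescaled by
# 1/sqrt(in_dim) * lr_mul on every forward pass, i.e.
#     out = F.linear(x, self.weight * self.scale, self.bias * self.lr_mul)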
106 | class EqualLinear(nn.Module):
107 |     def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
108 |         super().__init__()
109 | 
110 |         self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
111 | 
112 |         if bias:
113 |             self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
114 |         else:
115 |             self.bias = None
116 | 
117 |         self.activation = activation
118 | 
119 |         self.scale = (1 / math.sqrt(in_dim)) * lr_mul
120 |         self.lr_mul = lr_mul
121 | 
122 |     def forward(self, input):
123 | 
124 |         if self.activation:
125 |             out = F.linear(input, self.weight * self.scale)
126 |             out = fused_leaky_relu(out, self.bias * self.lr_mul)  # assumes bias=True whenever an activation is used
127 |         else:
128 |             out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul if self.bias is not None else None)
129 | 
130 |         return out
131 | 
132 |     def __repr__(self):
133 |         return f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
134 | 
135 | class GraphConv(nn.Module):
136 |     """
137 |     adapted from : https://github.com/tkipf/gcn/blob/92600c39797c2bfb61a508e52b88fb554df30177/gcn/layers.py#L132
138 |     """
139 | 
140 |     def __init__(self, in_len, out_len, in_node_n=66, out_node_n=66, bias=True):
141 |         super(GraphConv, self).__init__()
142 |         self.in_len = in_len
143 |         self.out_len = out_len
144 |         self.in_node_n = in_node_n
145 |         self.out_node_n = out_node_n
146 |         self.weight = nn.Parameter(torch.FloatTensor(in_len, out_len))
147 |         self.att = nn.Parameter(torch.FloatTensor(in_node_n, out_node_n))
148 | 
149 |         if bias:
150 |             self.bias = nn.Parameter(torch.FloatTensor(out_len))
151 |         else:
152 |             self.register_parameter('bias', None)
153 |         self.reset_parameters()
154 | 
155 | 
156 |     def reset_parameters(self):
157 |         stdv = 1. / math.sqrt(self.weight.size(1))
158 |         self.weight.data.uniform_(-stdv, stdv)
159 |         self.att.data.uniform_(-stdv, stdv)
160 |         if self.bias is not None:
161 |             self.bias.data.uniform_(-stdv, stdv)
162 | 
163 |     def forward(self, input):
164 |         '''
165 |         input: b, cv, t
166 |         '''
167 | 
168 |         features = torch.matmul(input, self.weight)  # temporal mixing: 35 -> 256
169 |         output = torch.matmul(features.permute(0, 2, 1).contiguous(), self.att).permute(0, 2, 1).contiguous()  # node mixing: 66 -> 66
170 | 
171 |         if self.bias is not None:
172 |             output = output + self.bias
173 | 
174 |         return output
175 | 
176 |     def __repr__(self):
177 |         return self.__class__.__name__ + ' (' + str(self.in_len) + ' -> ' + str(self.out_len) + ')' + ' (' + str(self.in_node_n) + ' -> ' + str(self.out_node_n) + ')'
178 | 
179 | class GraphConvBlock(nn.Module):
180 |     def __init__(self, in_len, out_len, in_node_n, out_node_n, dropout_rate=0, leaky=0.1, bias=True, residual=False):
181 |         super(GraphConvBlock, self).__init__()
182 |         self.dropout_rate = dropout_rate
183 |         self.residual = residual
184 | 
185 |         self.out_len = out_len
186 | 
187 |         self.gcn = GraphConv(in_len, out_len, in_node_n=in_node_n, out_node_n=out_node_n, bias=bias)
188 |         self.bn = nn.BatchNorm1d(out_node_n * out_len)
189 |         self.act = nn.Tanh()
190 |         if self.dropout_rate > 0:
191 |             self.drop = nn.Dropout(dropout_rate)
192 | 
193 |     def forward(self, input):
194 |         '''
195 | 
196 |         Args:
197 |             input: b, cv, t
198 | 
199 |         Returns:
200 | 
201 |         '''
202 |         x = self.gcn(input)
203 |         b, vc, t = x.shape
204 |         x = self.bn(x.view(b, -1)).view(b, vc, t)
205 |         # x = self.bn(x.view(b, -1, 3, t).permute(0, 2, 1, 3).contiguous()).permute(0, 2, 1, 3).contiguous().view(b, vc, t)
206 |         x = self.act(x)
207 |         if self.dropout_rate > 0:
208 |             x = self.drop(x)
209 | 
210 |         if self.residual:
211 |             return x + input
212 |         else:
213 |             return x
214 | 
215 | 
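# (added sketch) GraphConv above factorizes one layer into a temporal map and a
# node map: features = input @ weight mixes the time axis (e.g. 35 -> 256), then
# att mixes the node axis (66 -> 66), so an input (B, 66, 35) becomes (B, 66, 256).
# ResGCB below chains two GraphConvBlocks; with residual=True the input is added
# back, which requires in_len == out_len and in_node_n == out_node_n.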
216 | class ResGCB(nn.Module):
217 |     def __init__(self, in_len, out_len, in_node_n, out_node_n, dropout_rate=0, leaky=0.1, bias=True, residual=False):
218 |         super(ResGCB, self).__init__()
219 |         self.residual = residual
220 |         self.gcb1 = GraphConvBlock(in_len, in_len, in_node_n=in_node_n, out_node_n=in_node_n, dropout_rate=dropout_rate, bias=bias, residual=False)
221 |         self.gcb2 = GraphConvBlock(in_len, out_len, in_node_n=in_node_n, out_node_n=out_node_n, dropout_rate=dropout_rate, bias=bias, residual=False)
222 | 
223 | 
224 |     def forward(self, input):
225 |         '''
226 | 
227 |         Args:
228 |             x: B,CV,T
229 | 
230 |         Returns:
231 | 
232 |         '''
233 | 
234 |         x = self.gcb1(input)
235 |         x = self.gcb2(x)
236 | 
237 |         if self.residual:
238 |             return x + input
239 |         else:
240 |             return x
241 | 
242 | class ConvTemporalGraphical(nn.Module):
243 |     # Source: https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
244 |     r"""The basic module for applying a graph convolution.
245 |     Shape:
246 |         - Input: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
247 |         - Output: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
248 |         where
249 |             :math:`N` is a batch size,
250 |             :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
251 |             :math:`T_{in}/T_{out}` is a length of input/output sequence,
252 |             :math:`V` is the number of graph nodes.
253 |     """
254 |     def __init__(self,
255 |                  time_dim,
256 |                  joints_dim
257 |                  ):
258 |         super(ConvTemporalGraphical, self).__init__()
259 | 
260 |         self.A = nn.Parameter(torch.FloatTensor(time_dim, joints_dim, joints_dim))  # learnable, graph-agnostic 3-d adjacency matrix (or edge importance matrix)
261 |         stdv = 1. / math.sqrt(self.A.size(1))
262 |         self.A.data.uniform_(-stdv, stdv)
263 | 
264 |         self.T = nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim))
265 |         stdv = 1. / math.sqrt(self.T.size(1))
266 |         self.T.data.uniform_(-stdv, stdv)
267 |         '''
268 |         self.prelu = nn.PReLU()
269 | 
270 |         self.Z = nn.Parameter(torch.FloatTensor(joints_dim, joints_dim, time_dim, time_dim))
271 |         stdv = 1. / math.sqrt(self.Z.size(2))
272 |         self.Z.data.uniform_(-stdv, stdv)
273 |         '''
274 |         self.joints_dim = joints_dim
275 |         self.time_dim = time_dim
276 | 
277 |     def forward(self, x):
278 |         x = torch.einsum('nctv,vtq->ncqv', (x, self.T))  # mix along time, per joint
279 |         ## x = self.prelu(x)
280 |         x = torch.einsum('nctv,tvw->nctw', (x, self.A))  # mix along joints, per time step
281 |         ## x = torch.einsum('nctv,wvtq->ncqw', (x, self.Z))
282 |         return x.contiguous()
283 | 
284 | 
285 | class ConvTemporalGraphicalV1(nn.Module):
286 |     # Source: https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
287 |     r"""The basic module for applying a graph convolution.
288 |     Shape:
289 |         - Input: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
290 |         - Output: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
291 |         where
292 |             :math:`N` is a batch size,
293 |             :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
294 |             :math:`T_{in}/T_{out}` is a length of input/output sequence,
295 |             :math:`V` is the number of graph nodes.
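    Example (an added sketch, not from the original code; the HumanEva skeleton
    from models/motion_pred.py is assumed, with ``keep_joints`` standing in for
    ``dataset.kept_joints[1:]``):
        >>> pose_info = {'keep_joints': list(range(1, 15)),
        ...              'parents': [-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1],
        ...              'joints_left': [2, 3, 4, 8, 9, 10],
        ...              'joints_right': [5, 6, 7, 11, 12, 13]}
        >>> layer = ConvTemporalGraphicalV1(time_dim=16, joints_dim=14, pose_info=pose_info)
        >>> layer(torch.randn(8, 128, 16, 14)).shape  # (N, C, T, V) in, same shape out
        torch.Size([8, 128, 16, 14])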
296 | """ 297 | def __init__(self, 298 | time_dim, 299 | joints_dim, 300 | pose_info 301 | ): 302 | super(ConvTemporalGraphicalV1,self).__init__() 303 | parents=pose_info['parents'] 304 | joints_left=list(pose_info['joints_left']) 305 | joints_right=list(pose_info['joints_right']) 306 | keep_joints=pose_info['keep_joints'] 307 | dim_use = list(keep_joints) 308 | # print(dim_use) 309 | self.A=nn.Parameter(torch.FloatTensor(time_dim, joints_dim, joints_dim)) #learnable, graph-agnostic 3-d adjacency matrix(or edge importance matrix) 310 | stdv = 1. / math.sqrt(self.A.size(1)) 311 | self.A.data.uniform_(-stdv,stdv) 312 | 313 | self.T=nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim)) 314 | stdv = 1. / math.sqrt(self.T.size(1)) 315 | self.T.data.uniform_(-stdv,stdv) 316 | ''' 317 | self.prelu = nn.PReLU() 318 | 319 | self.Z=nn.Parameter(torch.FloatTensor(joints_dim, joints_dim, time_dim, time_dim)) 320 | stdv = 1. / math.sqrt(self.Z.size(2)) 321 | self.Z.data.uniform_(-stdv,stdv) 322 | ''' 323 | self.A_s = torch.zeros((1,joints_dim,joints_dim), requires_grad=False, dtype=torch.float) 324 | for i, dim in enumerate(dim_use): 325 | self.A_s[0][i][i] = 1 326 | if parents[dim] in dim_use: 327 | parent_index = dim_use.index(parents[dim]) 328 | self.A_s[0][i][parent_index] = 1 329 | self.A_s[0][parent_index][i] = 1 330 | if dim in joints_left: 331 | index = joints_left.index(dim) 332 | right_dim = joints_right[index] 333 | right_index = dim_use.index(right_dim) 334 | if right_dim in dim_use: 335 | self.A_s[0][i][right_index] = 1 336 | self.A_s[0][right_index][i] = 1 337 | 338 | self.joints_dim = joints_dim 339 | self.time_dim = time_dim 340 | 341 | def forward(self, x): 342 | A = self.A * self.A_s.to(x.device) 343 | x = torch.einsum('nctv,vtq->ncqv', (x, self.T)) 344 | x = torch.einsum('nctv,tvw->nctw', (x, A)) 345 | return x.contiguous() 346 | 347 | 348 | class ST_GCNN_layer(nn.Module): 349 | """ 350 | Shape: 351 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format 352 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format 353 | - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format 354 | where 355 | :math:`N` is a batch size, 356 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, 357 | :math:`T_{in}/T_{out}` is a length of input/output sequence, 358 | :math:`V` is the number of graph nodes. 
359 |         :in_channels: dimension of input coordinates
360 |         :out_channels: dimension of output coordinates
361 | 
362 |     """
363 |     def __init__(self,
364 |                  in_channels,
365 |                  out_channels,
366 |                  kernel_size,
367 |                  stride,
368 |                  time_dim,
369 |                  joints_dim,
370 |                  dropout,
371 |                  bias=True,
372 |                  version=0,
373 |                  pose_info=None):
374 | 
375 |         super(ST_GCNN_layer, self).__init__()
376 |         self.kernel_size = kernel_size
377 |         assert self.kernel_size[0] % 2 == 1
378 |         assert self.kernel_size[1] % 2 == 1
379 |         padding = ((self.kernel_size[0] - 1) // 2, (self.kernel_size[1] - 1) // 2)  # "same" padding
380 | 
381 |         if version == 0:
382 |             self.gcn = ConvTemporalGraphical(time_dim, joints_dim)  # the convolution layer
383 |         elif version == 1:
384 |             self.gcn = ConvTemporalGraphicalV1(time_dim, joints_dim, pose_info=pose_info)
385 | 
386 |         self.tcn = nn.Sequential(
387 |             nn.Conv2d(
388 |                 in_channels,
389 |                 out_channels,
390 |                 (self.kernel_size[0], self.kernel_size[1]),
391 |                 (stride, stride),
392 |                 padding,
393 |             ),
394 |             nn.BatchNorm2d(out_channels),
395 |             nn.Dropout(dropout, inplace=True),
396 |         )
397 | 
398 | 
399 | 
400 |         if stride != 1 or in_channels != out_channels:
401 | 
402 |             self.residual = nn.Sequential(nn.Conv2d(
403 |                 in_channels,
404 |                 out_channels,
405 |                 kernel_size=1,
406 |                 stride=(1, 1)),
407 |                 nn.BatchNorm2d(out_channels),
408 |             )
409 | 
410 | 
411 |         else:
412 |             self.residual = nn.Identity()
413 | 
414 | 
415 |         self.prelu = nn.PReLU()
416 | 
417 | 
418 | 
419 |     def forward(self, x):
420 |         # assert A.shape[0] == self.kernel_size[1], print(A.shape[0],self.kernel_size)
421 |         res = self.residual(x)
422 |         x = self.gcn(x)
423 |         x = self.tcn(x)
424 |         x = x + res
425 |         x = self.prelu(x)
426 |         return x
427 | 
428 | 
429 | 
430 | class Direction(nn.Module):
431 |     def __init__(self, motion_dim):
432 |         super(Direction, self).__init__()
433 | 
434 |         self.weight = nn.Parameter(torch.randn(256, motion_dim))
435 | 
436 |     def forward(self, input):
437 |         # input: (bs * nk) x motion_dim coefficients; None returns the basis Q
438 | 
439 |         weight = self.weight + 1e-8
440 |         Q, R = torch.linalg.qr(weight)  # orthonormal basis of the weight's column space [n1, n2, ..., n_D]
441 | 
442 |         if input is None:
443 |             return Q
444 |         else:
445 |             input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
446 |             out = torch.matmul(input_diag, Q.T)
447 |             out = torch.sum(out, dim=1)
448 | 
449 |             return out
450 | 
451 | 
452 | class Model(nn.Module):
453 |     def __init__(self, nx, ny, input_channels, st_gcnn_dropout,
454 |                  joints_to_consider,
455 |                  pose_info):
456 |         super(Model, self).__init__()
457 |         self.nx = nx
458 |         self.ny = ny
459 |         self.output_len = 20
460 |         self.num_joints = joints_to_consider
461 |         self.num_D = 30
462 |         if nx == 48:  # H3.6M: 16 joints * 3
463 |             self.t_his = 25
464 |             self.t_pred = 100
465 |         elif nx == 42:  # HumanEva: 14 joints * 3
466 |             self.t_his = 15
467 |             self.t_pred = 60
468 |         self.nk = 50
469 |         self.num_anchor = self.nk
470 |         self.anchor_input = nn.ParameterDict()
471 |         stdv_anchor = 1. / math.sqrt(128)
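        # (added note) nk = 50 learnable 128-d anchor codes: forward() repeats the
        # encoded history nk times and pairs each copy with one anchor, which is how
        # a single observation yields nk distinct future samples.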
472 |         for i in range(self.num_anchor):
473 |             self.anchor_input[f'anchor_{i}'] = nn.Parameter(torch.FloatTensor(1, 128).uniform_(-stdv_anchor, stdv_anchor))
474 | 
475 | 
476 |         self.st_gcnns_encoder_past_motion = nn.ModuleList()
477 |         #0
478 |         self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(input_channels, 128, [3, 1], 1, self.output_len,
479 |                                                                joints_to_consider, st_gcnn_dropout, pose_info=pose_info))
480 |         #1
481 |         self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(128, 64, [3, 1], 1, self.output_len,
482 |                                                                joints_to_consider, st_gcnn_dropout, version=1, pose_info=pose_info))
483 |         #2
484 |         self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(64, 128, [3, 1], 1, self.output_len,
485 |                                                                joints_to_consider, st_gcnn_dropout, version=1, pose_info=pose_info))
486 |         #3
487 |         self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(128, 128, [3, 1], 1, self.output_len,
488 |                                                                joints_to_consider, st_gcnn_dropout, pose_info=pose_info))
489 | 
490 |         self.st_gcnns_compress = nn.ModuleList()
491 |         #0
492 |         self.st_gcnns_compress.append(ST_GCNN_layer_down(256, 512, [2, 2], 2, self.output_len,
493 |                                                          joints_to_consider, st_gcnn_dropout, pose_info=pose_info))
494 |         #1
495 |         self.st_gcnns_compress.append(ST_GCNN_layer_down(512, 768, [2, 2], 2, self.output_len // 2,
496 |                                                          joints_to_consider // 2, st_gcnn_dropout, pose_info=pose_info))
497 |         #2
498 |         self.st_gcnns_compress.append(ST_GCNN_layer_down(768, 1024, [2, 2], 2, self.output_len // 4,
499 |                                                          joints_to_consider // 4, st_gcnn_dropout, pose_info=pose_info))
500 | 
501 | 
502 |         down_fc = [EqualLinear(1024, 1024, activation=True)]
503 |         for i in range(1):
504 |             down_fc.append(EqualLinear(1024, 512, activation=True))
505 | 
506 |         down_fc.append(EqualLinear(512, self.num_D))
507 |         self.down_fc = nn.Sequential(*down_fc)
508 | 
509 | 
510 |         self.direction = Direction(motion_dim=self.num_D)
511 | 
512 |         self.st_gcnns_decoder = nn.ModuleList()
513 | 
514 |         #4
515 |         self.st_gcnns_decoder.append(ST_GCNN_layer(128 + 256, 128, [3, 1], 1, self.output_len,
516 |                                                    joints_to_consider, st_gcnn_dropout, version=1, pose_info=pose_info))
517 |         self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_encoder_past_motion[-2].gcn.A  # share adjacency with encoder layer #2
518 | 
519 |         #5
520 |         self.st_gcnns_decoder.append(ST_GCNN_layer(128, 64, [3, 1], 1, self.output_len,
521 |                                                    joints_to_consider, st_gcnn_dropout, pose_info=pose_info))
522 |         self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_encoder_past_motion[-1].gcn.A  # share adjacency with encoder layer #3
523 |         #6
524 |         self.st_gcnns_decoder.append(ST_GCNN_layer(64, 128, [3, 1], 1, self.output_len,
525 |                                                    joints_to_consider, st_gcnn_dropout, version=1, pose_info=pose_info))
526 |         self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_decoder[-3].gcn.A  # share adjacency with decoder layer #4
527 |         #7
528 |         self.st_gcnns_decoder.append(ST_GCNN_layer(128, input_channels, [3, 1], 1, self.output_len,
529 |                                                    joints_to_consider, st_gcnn_dropout, pose_info=pose_info))
530 | 
531 | 
532 |         self.dct_m, self.idct_m = self.get_dct_matrix(self.t_his + self.t_pred)
533 | 
534 | 
535 | 
536 |     def encode_past_motion(self, x_input):
537 |         # x_input: [t_full, bs, V*C]
538 | 
539 |         # [t_full, bs, V*C] -> [t_full, bs, V, C] -> [bs, C, t_full, V]
540 |         x_input = x_input.view(x_input.shape[0], x_input.shape[1], -1, 3).permute(1, 3, 0, 2)
541 |         y = torch.zeros((x_input.shape[0], x_input.shape[1], self.t_pred, x_input.shape[3])).to(x_input.device)
542 |         # zero out the future, then go [bs, C, t_full, V] -> [bs, t_full, C, V]
543 |         x_padding = torch.cat([x_input[:, :, :self.t_his, :], y], dim=2).permute(0, 2, 1, 3)
544 | 
545 |         N, T, C, V = x_padding.shape
546 | 
547 |         # [bs, t_full, C, V] -> [bs, t_full, C*V]
548 | 
549 |         x_padding = x_padding.reshape([N, T, C * V])
550 | 
551 | 
552 |         dct_m = self.dct_m.to(x_input.device)
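        # (added note) the unobserved future is filled by repeating the last seen frame
        # (idx_pad below), and only the lowest `output_len` DCT coefficients are kept,
        # a compact frequency-domain summary of the whole t_his + t_pred trajectory.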
553 |         idx_pad = list(range(self.t_his)) + [self.t_his - 1] * self.t_pred
554 | 
555 |         # keep the lowest output_len DCT coefficients: [bs, t_full, C*V] -> [bs, C, output_len, V]
556 |         x_pad = torch.matmul(dct_m[:self.output_len], x_padding[:, idx_pad, :]).reshape([N, -1, C, V]).permute(0, 2, 1, 3)
557 |         x = x_pad  # [N, C, T=output_len, V]
558 | 
559 |         for gcn in self.st_gcnns_encoder_past_motion:  # encoder layers 0-3
560 |             x = gcn(x)
561 |         N, C, T, V = x.shape
562 | 
563 |         return x
564 | 
565 |     def decoding(self, z, condition=None):
566 |         idct_m = self.idct_m.to(z.device)
567 | 
568 |         condition = condition.view(condition.shape[0], condition.shape[1], -1, 3).permute(1, 3, 0, 2)  # [t_full, bs, V*C] -> [bs, 3, t_full, V]
569 |         y_condition = torch.zeros((condition.shape[0], condition.shape[1], self.t_pred, condition.shape[3])).to(condition.device)
570 |         condition_padding = torch.cat([condition[:, :, :self.t_his, :], y_condition], dim=2).permute(0, 2, 1, 3)
571 |         N, T, C, V = condition_padding.shape
572 | 
573 |         condition_padding = condition_padding.reshape([N, T, C * V])
574 |         dct_m = self.dct_m.to(condition.device)
575 |         idx_pad = list(range(self.t_his)) + [self.t_his - 1] * self.t_pred
576 |         condition_p = torch.matmul(dct_m[:self.output_len], condition_padding[:, idx_pad, :]).reshape([N, -1, C, V]).permute(0, 2, 1, 3)
577 |         if condition_p.shape[0] != z.shape[0]:
578 |             condition_p = condition_p.repeat_interleave(self.nk, dim=0)  # one copy of the condition per sampled future
579 | 
580 |         for gcn in self.st_gcnns_decoder:  # decoder layers
581 |             z = gcn(z)
582 | 
583 |         output = z + condition_p  # the decoder predicts a residual around the padded history, in DCT space
584 |         N, C, N_fre, V = output.shape
585 | 
586 |         output = output.permute(0, 2, 1, 3).reshape([N, -1, C * V])
587 | 
588 |         outputs = torch.matmul(idct_m[:, :self.output_len], output).reshape([N, -1, C, V]).permute(1, 0, 3, 2).contiguous().view(-1, N, C * V)  # back to the time domain: [t_full, N, C*V]
589 | 
590 |         return outputs
591 | 
592 | 
593 |     def forward(self, x, z=None, epoch=None):  # z and epoch are accepted but unused
594 |         bs = x.shape[1]
595 |         z = self.encode_past_motion(x).repeat_interleave(self.nk, dim=0)  # one copy of the encoded history per anchor
596 |         replicated_parameters = torch.cat([self.anchor_input[f'anchor_{i}'].expand(self.nk // self.num_anchor, -1) for i in range(self.num_anchor)], dim=0)
597 |         anchors_input = replicated_parameters.repeat(bs, 1)
598 |         z1 = torch.cat((anchors_input.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.output_len, self.num_joints), z), dim=1)
599 |         for gcn in self.st_gcnns_compress:  # compression layers 0-2
600 |             z1 = gcn(z1)
601 |         z1 = z1.mean(-1).mean(-1).view(bs * self.nk, -1)  # global average pool over (T, V)
602 |         alpha = self.down_fc(z1)
603 |         directions = self.direction(alpha)
604 | 
605 |         N, C, T, V = z.shape
606 |         feature = torch.cat((directions.unsqueeze(2).unsqueeze(3).repeat(1, 1, T, V), z), dim=1)
607 | 
608 | 
609 |         outputs = self.decoding(feature, x)
610 | 
611 |         return outputs, feature, feature
612 | 
613 | 
614 |     def get_dct_matrix(self, N, is_torch=True):
615 |         dct_m = np.eye(N, dtype=np.float32)  # orthonormal DCT-II basis
616 |         for k in np.arange(N):
617 |             for i in np.arange(N):
618 |                 w = np.sqrt(2 / N)
619 |                 if k == 0:
620 |                     w = np.sqrt(1 / N)
621 |                 dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
622 |         idct_m = np.linalg.inv(dct_m)
623 |         if is_torch:
624 |             dct_m = torch.from_numpy(dct_m)
625 |             idct_m = torch.from_numpy(idct_m)
626 |         return dct_m, idct_m
627 | 
628 | 
629 | 
630 | 
--------------------------------------------------------------------------------
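Usage sketch (added for illustration; not a file in the repository). A minimal
forward pass through Model under the H3.6M configuration (nx=48, so t_his=25,
t_pred=100, nk=50); `keep_joints` below is a hypothetical stand-in for the real
dataset.kept_joints[1:] that models/motion_pred.py passes in:

    import torch
    from models.model import Model

    pose_info = {'keep_joints': list(range(1, 17)),   # stand-in joint ids
                 'parents': [-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
                             16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
                 'joints_left': [6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
                 'joints_right': [1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]}
    model = Model(48, 128, input_channels=3, st_gcnn_dropout=0.1,
                  joints_to_consider=16, pose_info=pose_info).eval()

    x = torch.randn(125, 4, 48)       # [t_his + t_pred, batch, num_kept_joints * 3]
    with torch.no_grad():
        out, feature, _ = model(x)    # every batch element yields nk futures
    print(out.shape)                  # torch.Size([125, 200, 48]) = [T, batch * nk, 48]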