├── models
│   ├── __init__.py
│   ├── Orthogonal.py
│   ├── motion_pred.py
│   ├── LinNF.py
│   └── model.py
├── utils
│   ├── __init__.py
│   ├── logger.py
│   ├── metrics.py
│   ├── vis_util.py
│   ├── util.py
│   ├── torch.py
│   ├── vis_poses.py
│   └── valid_angle_check.py
├── images
│   └── intro.png
├── motion_pred
│   ├── cfg
│   │   ├── h36m.yml
│   │   └── humaneva.yml
│   └── utils
│       ├── config.py
│       ├── dataset.py
│       ├── dataset_humaneva.py
│       ├── dataset_h36m.py
│       ├── skeleton.py
│       ├── visualization.py
│       ├── dataset_humaneva_multimodal.py
│       └── dataset_h36m_multimodal.py
├── LICENSE
├── README.md
├── train_nf.py
└── main.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.torch import *
2 | from utils.logger import *
3 |
--------------------------------------------------------------------------------
/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuoweiXu368/SLD-HMP/HEAD/images/intro.png
--------------------------------------------------------------------------------
/motion_pred/cfg/h36m.yml:
--------------------------------------------------------------------------------
1 | dataset: h36m
2 | nz: 128
3 | t_his: 25
4 | t_pred: 100
5 | dropout: 0.1
6 | nk: 50
7 | nk1: 5
8 | nk2: 10
9 | lambdas: [ 0, 500.0, 0, 64, 2.0, 1.0, 0.01, 50, 100]
10 | specs:
11 | num_flow_layer: 1
12 | n_pre: 20
13 | multimodal_path: ./data/data_multi_modal/t_his25_1_thre0.500_t_pred100_thre0.100_filtered_dlow.npz
14 | data_candi_path: ./data/data_multi_modal/data_candi_t_his25_t_pred100_skiprate20.npz
15 |
16 | lr: 1.e-3
17 | batch_size: 16
18 | num_epoch: 500
19 | num_epoch_fix: 100
20 | num_data_sample: 5000
21 | normalize_data: False
--------------------------------------------------------------------------------
/motion_pred/cfg/humaneva.yml:
--------------------------------------------------------------------------------
1 | dataset: humaneva
2 | nz: 128
3 | t_his: 15
4 | t_pred: 60
5 | dropout: 0.1
6 | nk: 50
7 | nk1: 5
8 | nk2: 10
9 | lambdas: [ 0, 50.0, 0, 16, 8.0, 4.0, 0.002, 10, 10]
10 | specs:
11 | num_flow_layer: 1
12 | n_pre: 16
13 | multimodal_path: ./data/humaneva_multi_modal/t_his15_1_thre0.500_t_pred60_thre0.010_index_filterd.npz
14 | data_candi_path: ./data/humaneva_multi_modal/data_candi_t_his15_t_pred60_skiprate15.npz
15 |
16 | lr: 1.e-3
17 | batch_size: 16
18 | num_epoch: 500
19 | num_epoch_fix: 100
20 | num_data_sample: 2000
21 | normalize_data: False
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Guowei Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/Orthogonal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Orthogonal(nn.Module):
6 |
7 | def __init__(self, d):
8 | super(Orthogonal, self).__init__()
9 | self.d = d
10 | self.U = torch.nn.Parameter(torch.zeros((d, d)).normal_(0, 0.5))
11 |
12 | self.reset_parameters()
13 |
14 | def reset_parameters(self):
15 | self.U.data = self.U.data / torch.norm(self.U.data, dim=1, keepdim=True)
16 |
17 | def sequential_mult(self, V, X):
18 | for row in range(V.shape[0] - 1, -1, -1):
19 | X = X - 2 * (X @ V[row:row + 1, :].t() @ V[row:row + 1, :]) / \
20 | (V[row:row + 1, :] @ V[row:row + 1, :].t())[0, 0]
21 | return X
22 |
23 | def forward(self, X, invert=False):
24 |         """
25 |         Apply the product of Householder reflections defined by the rows of U.
26 |         @param X: [bs, d] input batch
27 |         @return: transformed X; the inverse transform is applied when invert=True
28 |         """
29 | if not invert:
30 | X = self.sequential_mult(self.U, X)
31 | else:
32 | X = self.inverse(X)
33 | return X
34 |
35 | def inverse(self, X):
36 | X = self.sequential_mult(torch.flip(self.U, dims=[0]), X)
37 | return X
38 |
39 | def lgdet(self, X):
40 |         return 0  # Householder reflections are orthogonal, so log|det| = 0
41 |
--------------------------------------------------------------------------------
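
A quick numerical sanity check of this layer (an illustrative sketch, assuming the repo root is on PYTHONPATH): the stacked Householder reflections form an orthogonal map, which is why `lgdet` can return 0.

```python
import torch
from models.Orthogonal import Orthogonal

torch.manual_seed(0)
layer = Orthogonal(d=8)
x = torch.randn(4, 8)

y = layer(x)
# Orthogonal maps preserve norms, and inverse() undoes forward().
print(torch.allclose(x.norm(dim=1), y.norm(dim=1), atol=1e-5))  # True
print(torch.allclose(layer(y, invert=True), x, atol=1e-5))      # True
```
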
/models/motion_pred.py:
--------------------------------------------------------------------------------
1 | from models import LinNF, model
2 | import numpy as np
3 |
4 | def get_model(cfg, dataset, model_type='h36m'):
5 | traj_dim = dataset.traj_dim // 3
6 | specs = {'model_name': 'NFDiag', 'rnn_type': 'gru', 'nh_mlp': [1024, 512], 'x_birnn': False}
7 | if model_type == 'h36m' or model_type == 'humaneva':
8 | keep_joints = dataset.kept_joints[1:]
9 |
10 | if model_type == 'h36m':
11 | parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
12 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30]
13 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23]
14 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]
15 | elif model_type == 'humaneva':
16 | parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1]
17 | joints_left=[2, 3, 4, 8, 9, 10]
18 | joints_right=[5, 6, 7, 11, 12, 13]
19 |
20 | pose_info = {'keep_joints': keep_joints,
21 | 'parents': parents,
22 | 'joints_left': joints_left,
23 | 'joints_right': joints_right}
24 | print("Human pose information: ", pose_info)
25 |         return model.Model(traj_dim * 3, 128, input_channels=3, st_gcnn_dropout=cfg.dropout, joints_to_consider=traj_dim, pose_info=pose_info), \
26 | LinNF.LinNF(data_dim=traj_dim * 3, num_layer=3)
27 |
28 | elif model_type == 'h36m_nf' or model_type == 'humaneva_nf':
29 | return LinNF.LinNF(data_dim=dataset.traj_dim, num_layer=cfg.specs['num_flow_layer'])
--------------------------------------------------------------------------------
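
A sketch of building the flow branch (assumes the preprocessed `data/data_3d_h36m.npz` from the README is in place; the non-flow branch additionally needs `models/model.py`, not shown above):

```python
from motion_pred.utils.config import Config
from motion_pred.utils.dataset_h36m import DatasetH36M
from models.motion_pred import get_model

cfg = Config('h36m')
dataset = DatasetH36M('train', t_his=cfg.t_his, t_pred=cfg.t_pred)
flow = get_model(cfg, dataset, model_type='h36m_nf')
print(flow)  # LinNF over the 48-d ((17 - 1) joints x 3) pose vector
```
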
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 |
5 |
6 | def create_logger(filename, file_handle=True):
7 | # create logger
8 | logger = logging.getLogger(filename)
9 | logger.propagate = False
10 | logger.setLevel(logging.DEBUG)
11 | # create console handler with a higher log level
12 | ch = logging.StreamHandler()
13 | ch.setLevel(logging.INFO)
14 | stream_formatter = logging.Formatter('%(message)s')
15 | ch.setFormatter(stream_formatter)
16 | logger.addHandler(ch)
17 |
18 | if file_handle:
19 | # create file handler which logs even debug messages
20 | os.makedirs(os.path.dirname(filename), exist_ok=True)
21 | fh = logging.FileHandler(filename, mode='a')
22 | fh.setLevel(logging.DEBUG)
23 | file_formatter = logging.Formatter('[%(asctime)s] %(message)s')
24 | fh.setFormatter(file_formatter)
25 | logger.addHandler(fh)
26 |
27 | return logger
28 |
29 |
30 | def combine_dict(new_dict, old_dict=None):
31 | if old_dict is None:
32 | return new_dict
33 | for k in old_dict.keys():
34 | if k in new_dict.keys():
35 | old_dict[k] = np.concatenate([old_dict[k], new_dict[k]], axis=0)
36 | return old_dict
37 |
38 |
39 | class AverageMeter(object):
40 | """Computes and stores the average and current value"""
41 |
42 | def __init__(self):
43 | self.reset()
44 |
45 | def reset(self):
46 | self.val = 0
47 | self.avg = 0
48 | self.sum = 0
49 | self.count = 0
50 |
51 | def update(self, val, n=1):
52 | self.val = val
53 | self.sum += val * n
54 | self.count += n
55 | self.avg = self.sum / self.count
56 |
--------------------------------------------------------------------------------
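
Typical usage of the two helpers above (a minimal sketch; the log path is illustrative):

```python
from utils.logger import create_logger, AverageMeter

logger = create_logger('results/example/log/train.log')  # hypothetical path
meter = AverageMeter()
for loss in [0.9, 0.7, 0.4]:
    meter.update(loss, n=16)  # n is typically the batch size
logger.info('avg loss %.4f over %d samples', meter.avg, meter.count)
```
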
/motion_pred/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 |
4 |
5 | class Config:
6 |
7 | def __init__(self, cfg_id, test=False, nf=False):
8 | self.id = cfg_id
9 | cfg_name = 'motion_pred/cfg/%s.yml' % cfg_id
10 | if not os.path.exists(cfg_name):
11 |             print("Config file doesn't exist: %s" % cfg_name)
12 |             exit(1)
13 | cfg = yaml.safe_load(open(cfg_name, 'r'))
14 | if nf:
15 | cfg_id += '_nf'
16 | # create dirs
17 |
18 | self.base_dir = 'results'
19 | self.cfg_dir = '%s/%s' % (self.base_dir, cfg_id)
20 | self.model_dir = '%s/models' % self.cfg_dir
21 | self.result_dir = '%s/results' % self.cfg_dir
22 | self.log_dir = '%s/log' % self.cfg_dir
23 | self.tb_dir = '%s/tb' % self.cfg_dir
24 | os.makedirs(self.model_dir, exist_ok=True)
25 | os.makedirs(self.result_dir, exist_ok=True)
26 | os.makedirs(self.log_dir, exist_ok=True)
27 | os.makedirs(self.tb_dir, exist_ok=True)
28 |
29 | # common
30 | self.dataset = cfg.get('dataset', 'h36m')
31 | self.batch_size = cfg.get('batch_size', 8)
32 | self.normalize_data = cfg.get('normalize_data', False)
33 | self.save_model_interval = cfg.get('save_model_interval', 20)
34 | self.t_his = cfg['t_his']
35 | self.t_pred = cfg['t_pred']
36 | self.use_vel = cfg.get('use_vel', False)
37 |
38 | self.nz = cfg['nz']
39 | self.lr = cfg['lr']
40 | self.dropout = cfg.get('dropout', 0.1)
41 | self.num_epoch = cfg['num_epoch']
42 | self.num_epoch_fix = cfg.get('num_epoch_fix', self.num_epoch)
43 | self.num_data_sample = cfg['num_data_sample']
44 | self.model_path = os.path.join(self.model_dir, '%04d.p')
45 | self.poses_prediction_path = os.path.join(self.model_dir, 'poses_prediction%04d.p')
46 | self.poses_vae_path = os.path.join(self.model_dir, 'poses_vae_%04d.p')
47 | self.poses_mapping_path = os.path.join(self.model_dir, 'poses_mapping_%04d.p')
48 | self.nk = cfg.get('nk', 10)
49 | self.nk1 = cfg.get('nk1', 5)
50 | self.nk2 = cfg.get('nk2', 2)
51 | self.lambdas = cfg.get('lambdas', [])
52 |
53 | self.specs = cfg.get('specs', dict())
54 |
--------------------------------------------------------------------------------
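
Instantiating the config resolves the derived paths used throughout training (sketch; requires the yml files above and creates the results directories as a side effect):

```python
from motion_pred.utils.config import Config

cfg = Config('h36m')           # loads motion_pred/cfg/h36m.yml
print(cfg.t_his, cfg.t_pred)   # 25 100
print(cfg.model_path % 500)    # results/h36m/models/0500.p
```
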
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import pdist
3 |
4 | def mpjpe_error(pred, gt, *args):
5 | indexs = np.zeros(pred.shape[0])
6 | sample_num, total_len, feature_size = pred.shape
7 | pred = pred.reshape([sample_num, total_len, feature_size//3, 3])[:, -1:, :, :] * 1000
8 | gt = gt.reshape([1, total_len, feature_size//3, 3])[:, -1:, :, :] * 1000
9 | dist = np.linalg.norm(pred - gt, axis=3).mean(axis=2).mean(axis=1)
10 | index = dist.argmin()
11 | indexs[index] += 1
12 | return dist[index], indexs
13 |
14 | def compute_diversity(pred, *args):
15 | if pred.shape[0] == 1:
16 |         return 0.0, None  # keep the (value, index histogram) return convention
17 | dist = pdist(pred.reshape(pred.shape[0], -1))
18 | diversity = dist.mean().item()
19 | return diversity, None
20 |
21 |
22 | def compute_ade(pred, gt, *args):
23 | indexs = np.zeros(pred.shape[0])
24 | diff = pred - gt
25 | dist = np.linalg.norm(diff, axis=2).mean(axis=1)
26 | index = dist.argmin()
27 | indexs[index] += 1
28 | return dist[index], indexs
29 |
30 |
31 | def compute_fde(pred, gt, *args):
32 | indexs = np.zeros(pred.shape[0])
33 | diff = pred - gt
34 | dist = np.linalg.norm(diff, axis=2)[:, -1]
35 | index = dist.argmin()
36 | indexs[index] += 1
37 | return dist[index], indexs
38 |
39 | def compute_amse(pred, gt, *args):
40 | diff = pred - gt # sample_num * total_len * ((num_key-1)*3)
41 | dist = (diff*diff).sum() / diff.shape[0]
42 | return dist.mean(), None
43 |
44 |
45 | def compute_fmse(pred, gt, *args):
46 | diff = pred[:, -1, :] - gt[:, -1, :] # sample_num * total_len * ((num_key-1)*3)
47 | dist = (diff*diff).sum() / diff.shape[0]
48 | return dist.mean(), None
49 |
50 | def compute_mmade(pred, gt, gt_multi):
51 | gt_dist = []
52 | indexs = np.zeros(pred.shape[0])
53 | for gt_multi_i in gt_multi:
54 | dist, index = compute_ade(pred, gt_multi_i)
55 | gt_dist.append(dist)
56 | indexs += index
57 | gt_dist = np.array(gt_dist).mean()
58 | return gt_dist, indexs
59 |
60 |
61 | def compute_mmfde(pred, gt, gt_multi):
62 | gt_dist = []
63 | indexs = np.zeros(pred.shape[0])
64 | for gt_multi_i in gt_multi:
65 | dist, index = compute_fde(pred, gt_multi_i)
66 | gt_dist.append(dist)
67 | indexs += index
68 | gt_dist = np.array(gt_dist).mean()
69 | return gt_dist, indexs
70 |
--------------------------------------------------------------------------------
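
The sample-based metrics above share a (best-of-K value, selection histogram) return convention; a toy sketch on random data:

```python
import numpy as np
from utils.metrics import compute_ade, compute_fde, compute_diversity

K, T, D = 50, 100, 48            # samples, horizon, pose dimension
pred = np.random.randn(K, T, D)
gt = np.random.randn(1, T, D)    # broadcasts against pred

ade, hist = compute_ade(pred, gt)   # error of the best of the K samples
fde, _ = compute_fde(pred, gt)
div, _ = compute_diversity(pred)
print(ade, fde, div, int(hist.sum()))  # hist marks the winning sample (sums to 1)
```
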
/motion_pred/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Dataset:
5 |
6 | def __init__(self, mode, t_his, t_pred, actions='all'):
7 | self.mode = mode
8 | self.t_his = t_his
9 | self.t_pred = t_pred
10 | self.t_total = t_his + t_pred
11 | self.actions = actions
12 | self.prepare_data()
13 | self.std, self.mean = None, None
14 | self.data_len = sum([seq.shape[0] for data_s in self.data.values() for seq in data_s.values()])
15 | self.traj_dim = (self.kept_joints.shape[0] - 1) * 3
16 | self.normalized = False
17 | # iterator specific
18 | self.sample_ind = None
19 |
20 | def prepare_data(self):
21 | raise NotImplementedError
22 |
23 | def normalize_data(self, mean=None, std=None):
24 | if mean is None:
25 | all_seq = []
26 | for data_s in self.data.values():
27 | for seq in data_s.values():
28 | all_seq.append(seq[:, 1:])
29 | all_seq = np.concatenate(all_seq)
30 | self.mean = all_seq.mean(axis=0)
31 | self.std = all_seq.std(axis=0)
32 | else:
33 | self.mean = mean
34 | self.std = std
35 | for data_s in self.data.values():
36 | for action in data_s.keys():
37 | data_s[action][:, 1:] = (data_s[action][:, 1:] - self.mean) / self.std
38 | self.normalized = True
39 |
40 | def sample(self):
41 | subject = np.random.choice(self.subjects)
42 | dict_s = self.data[subject]
43 | action = np.random.choice(list(dict_s.keys()))
44 | seq = dict_s[action]
45 | fr_start = np.random.randint(seq.shape[0] - self.t_total)
46 | fr_end = fr_start + self.t_total
47 | traj = seq[fr_start: fr_end]
48 | return traj[None, ...]
49 |
50 | def sampling_generator(self, num_samples=1000, batch_size=8):
51 |         for _ in range(num_samples // batch_size):
52 |             sample = []
53 |             for _ in range(batch_size):
54 | sample_i = self.sample()
55 | sample.append(sample_i)
56 | sample = np.concatenate(sample, axis=0)
57 | yield sample
58 |
59 | def iter_generator(self, step=25):
60 | for data_s in self.data.values():
61 | for seq in data_s.values():
62 | seq_len = seq.shape[0]
63 | for i in range(0, seq_len - self.t_total, step):
64 | traj = seq[None, i: i + self.t_total]
65 | yield traj
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction" (**ECCV 2024**)
2 |
3 |
4 |
5 | ---
6 | This repo contains the official implementation of the paper:
7 |
8 | Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction
9 |
10 | ECCV 2024
11 | [[arXiv](https://arxiv.org/abs/2407.11494)]
12 | ### Dependencies
13 | * Python >= 3.8
14 | * [PyTorch](https://pytorch.org) >= 1.9
15 | * Tensorboard
16 | * matplotlib
17 | * tqdm
18 | * argparse
19 |
20 | ### Get the data
21 | We adapt the data preprocessing from [GSPS](https://github.com/wei-mao-2019/gsps).
22 | * We follow the data preprocessing steps ([DATASETS.md](https://github.com/facebookresearch/VideoPose3D/blob/master/DATASETS.md)) inside the [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) repo.
23 | * Given the processed dataset, we further compute the multi-modal futures for each motion sequence. All required data can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1sb1n9l0Na5EqtapDVShOJJ-v6o-GZrIJ?usp=sharing); place the datasets in the ``data`` folder at the root of this repo.
24 |
25 | ### Get the pretrained models
26 | * All pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1YAa3Lpei0V3-JTEZwSw0WqRfSTYbY-z2?usp=drive_link); place them in the ``results`` folder at the root of this repo.
27 |
28 | ### Train
29 | We have used the following commands for training the network on Human3.6M or HumanEva-I with the skeleton representation:
30 | ```bash
31 | python train_nf.py --cfg [h36m/humaneva] --gpu_index 0
32 | python main.py --cfg [h36m/humaneva] --gpu_index 0
33 | ```
34 | ### Test
35 | To test with a pretrained model, we have used the following command:
36 | ```bash
37 | python main.py --cfg [h36m/humaneva] --mode test --iter 500 --gpu_index 0
38 | ```
39 | ### Visualization
40 | For visualizing from a pretrained model, we have used the following commands:
41 |
42 | ```bash
43 | python main.py --cfg [h36m/humaneva] --mode viz --iter 500 --gpu_index 0
44 | ```
45 |
46 | ### Acknowledgments
47 |
48 | This code is based on the implementations of [STARS](https://github.com/Sirui-Xu/STARS).
49 |
50 | ## Citation
51 | If you find this work useful in your research, please cite:
52 |
53 | ```bibtex
54 | @article{xu2024learning,
55 | title={Learning Semantic Latent Directions for Accurate and Controllable Human Motion Prediction},
56 | author={Xu, Guowei and Tao, Jiale and Li, Wen and Duan, Lixin},
57 | journal={arXiv preprint arXiv:2407.11494},
58 | year={2024}
59 | }
60 | ```
61 |
62 | ## License
63 |
64 | This repo is distributed under an [MIT LICENSE](LICENSE).
65 |
--------------------------------------------------------------------------------
/motion_pred/utils/dataset_humaneva.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from motion_pred.utils.dataset import Dataset
4 | from motion_pred.utils.skeleton import Skeleton
5 |
6 |
7 | class DatasetHumanEva(Dataset):
8 |
9 | def __init__(self, mode, t_his=15, t_pred=60, actions='all', **kwargs):
10 | super().__init__(mode, t_his, t_pred, actions)
11 |
12 | def prepare_data(self):
13 | self.data_file = os.path.join('data', 'data_3d_humaneva15.npz')
14 | self.subjects_split = {'train': ['Train/S1', 'Train/S2', 'Train/S3'],
15 | 'test': ['Validate/S1', 'Validate/S2', 'Validate/S3']}
16 | self.subjects = [x for x in self.subjects_split[self.mode]]
17 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1],
18 | joints_left=[2, 3, 4, 8, 9, 10],
19 | joints_right=[5, 6, 7, 11, 12, 13])
20 | self.kept_joints = np.arange(15)
21 | self.process_data()
22 |
23 | def process_data(self):
24 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item()
25 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items()))
26 |         # these takes have wrong head positions and are excluded from training and testing
27 | if self.mode == 'train':
28 | data_f['Train/S3'].pop('Walking 1 chunk0')
29 | data_f['Train/S3'].pop('Walking 1 chunk2')
30 | else:
31 | data_f['Validate/S3'].pop('Walking 1 chunk4')
32 | for key in list(data_f.keys()):
33 | # data_f[key] = dict(filter(lambda x: (self.actions == 'all' or
34 | # all([a in x[0] for a in self.actions]))
35 | # and x[1].shape[0] >= self.t_total, data_f[key].items()))
36 | data_f[key] = dict(filter(lambda x: (self.actions == 'all' or
37 | any([a in x[0] for a in self.actions]))
38 | and x[1].shape[0] >= self.t_total, data_f[key].items()))
39 | if len(data_f[key]) == 0:
40 | data_f.pop(key)
41 | for data_s in data_f.values():
42 | for action in data_s.keys():
43 | seq = data_s[action][:, self.kept_joints, :]
44 | seq[:, 1:] -= seq[:, :1]
45 | data_s[action] = seq
46 | self.data = data_f
47 |
48 |
49 | if __name__ == '__main__':
50 | np.random.seed(0)
51 | actions = 'all'
52 | dataset = DatasetHumanEva('test', actions=actions)
53 | generator = dataset.sampling_generator()
54 | dataset.normalize_data()
55 | # generator = dataset.iter_generator()
56 | for data in generator:
57 | print(data.shape)
58 |
59 |
60 |
--------------------------------------------------------------------------------
/motion_pred/utils/dataset_h36m.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from motion_pred.utils.dataset import Dataset
4 | from motion_pred.utils.skeleton import Skeleton
5 |
6 |
7 | class DatasetH36M(Dataset):
8 |
9 | def __init__(self, mode, t_his=25, t_pred=100, actions='all', use_vel=False):
10 | self.use_vel = use_vel
11 | super().__init__(mode, t_his, t_pred, actions)
12 | if use_vel:
13 | self.traj_dim += 3
14 |
15 | def prepare_data(self):
16 | self.data_file = os.path.join('data', 'data_3d_h36m.npz')
17 | self.subjects_split = {'train': [1, 5, 6, 7, 8],
18 | 'test': [9, 11]}
19 | self.subjects = ['S%d' % x for x in self.subjects_split[self.mode]]
20 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
21 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
22 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
23 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
24 | self.removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31}
25 | self.kept_joints = np.array([x for x in range(32) if x not in self.removed_joints])
26 | self.skeleton.remove_joints(self.removed_joints)
27 | self.skeleton._parents[11] = 8
28 | self.skeleton._parents[14] = 8
29 | self.process_data()
30 |
31 | def process_data(self):
32 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item()
33 | self.S1_skeleton = data_o['S1']['Directions'][:1, self.kept_joints].copy()
34 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items()))
35 | if self.actions != 'all':
36 | for key in list(data_f.keys()):
37 | # data_f[key] = dict(filter(lambda x: all([a in x[0] for a in self.actions]), data_f[key].items()))
38 | data_f[key] = dict(filter(lambda x: any([a in str.lower(x[0]) for a in self.actions]), data_f[key].items()))
39 | if len(data_f[key]) == 0:
40 | data_f.pop(key)
41 | for data_s in data_f.values():
42 | for action in data_s.keys():
43 | seq = data_s[action][:, self.kept_joints, :]
44 | if self.use_vel:
45 | v = (np.diff(seq[:, :1], axis=0) * 50).clip(-5.0, 5.0)
46 | v = np.append(v, v[[-1]], axis=0)
47 | seq[:, 1:] -= seq[:, :1]
48 | if self.use_vel:
49 | seq = np.concatenate((seq, v), axis=1)
50 | data_s[action] = seq
51 | self.data = data_f
52 |
53 |
54 | if __name__ == '__main__':
55 | np.random.seed(0)
56 | actions = {'WalkDog'}
57 | dataset = DatasetH36M('train', actions=actions)
58 | generator = dataset.sampling_generator()
59 | dataset.normalize_data()
60 | # generator = dataset.iter_generator()
61 | for data in generator:
62 | print(data.shape)
63 |
--------------------------------------------------------------------------------
/motion_pred/utils/skeleton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import numpy as np
9 |
10 |
11 | class Skeleton:
12 | def __init__(self, parents, joints_left, joints_right):
13 | assert len(joints_left) == len(joints_right)
14 |
15 | self._parents = np.array(parents)
16 | self._joints_left = joints_left
17 | self._joints_right = joints_right
18 | self._compute_metadata()
19 |
20 | def num_joints(self):
21 | return len(self._parents)
22 |
23 | def parents(self):
24 | return self._parents
25 |
26 | def has_children(self):
27 | return self._has_children
28 |
29 | def children(self):
30 | return self._children
31 |
32 | def remove_joints(self, joints_to_remove):
33 | """
34 | Remove the joints specified in 'joints_to_remove'.
35 | """
36 | valid_joints = []
37 | for joint in range(len(self._parents)):
38 | if joint not in joints_to_remove:
39 | valid_joints.append(joint)
40 |
41 | for i in range(len(self._parents)):
42 | while self._parents[i] in joints_to_remove:
43 | self._parents[i] = self._parents[self._parents[i]]
44 |
45 | index_offsets = np.zeros(len(self._parents), dtype=int)
46 | new_parents = []
47 | for i, parent in enumerate(self._parents):
48 | if i not in joints_to_remove:
49 | new_parents.append(parent - index_offsets[parent])
50 | else:
51 | index_offsets[i:] += 1
52 | self._parents = np.array(new_parents)
53 |
54 |
55 | if self._joints_left is not None:
56 | new_joints_left = []
57 | for joint in self._joints_left:
58 | if joint in valid_joints:
59 | new_joints_left.append(joint - index_offsets[joint])
60 | self._joints_left = new_joints_left
61 | if self._joints_right is not None:
62 | new_joints_right = []
63 | for joint in self._joints_right:
64 | if joint in valid_joints:
65 | new_joints_right.append(joint - index_offsets[joint])
66 | self._joints_right = new_joints_right
67 |
68 | self._compute_metadata()
69 |
70 | return valid_joints
71 |
72 | def joints_left(self):
73 | return self._joints_left
74 |
75 | def joints_right(self):
76 | return self._joints_right
77 |
78 | def _compute_metadata(self):
79 | self._has_children = np.zeros(len(self._parents)).astype(bool)
80 | for i, parent in enumerate(self._parents):
81 | if parent != -1:
82 | self._has_children[parent] = True
83 |
84 | self._children = []
85 | for i, parent in enumerate(self._parents):
86 | self._children.append([])
87 | for i, parent in enumerate(self._parents):
88 | if parent != -1:
89 | self._children[parent].append(i)
--------------------------------------------------------------------------------
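
A sketch of the joint-removal logic, mirroring what `DatasetH36M` above does to the 32-joint Human3.6M skeleton:

```python
from motion_pred.utils.skeleton import Skeleton

skel = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
                         16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
                joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
                joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
kept = skel.remove_joints({4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31})
print(skel.num_joints(), len(kept))  # 17 17; parents are re-indexed to the kept joints
```
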
/models/LinNF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch.nn import functional as F
5 |
6 | from models import Orthogonal
7 | from torch.nn.parameter import Parameter
8 | import math
9 |
10 |
11 | class LinQR(nn.Module):
12 | """
13 |     Invertible linear layer: x is multiplied by an upper-triangular R (with
14 |     positive diagonal), an orthogonal Householder product Q, then shifted by a bias.
15 | """
16 |
17 | def __init__(self, data_dim):
18 | super().__init__()
19 |
20 | self.Q = Orthogonal.Orthogonal(d=data_dim)
21 | self.R = Parameter(torch.Tensor(data_dim, data_dim))
22 | self.bias = Parameter(torch.Tensor(data_dim))
23 | self.reset_parameters()
24 |
25 | def reset_parameters(self):
26 | torch.nn.init.kaiming_uniform_(self.R, a=math.sqrt(5))
27 | if self.bias is not None:
28 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.R)
29 | bound = 1 / math.sqrt(fan_in)
30 | torch.nn.init.uniform_(self.bias, -bound, bound)
31 |
32 | def forward(self, x, logdet):
33 | """
34 |         x: [bs, data_dim]
35 |
36 | """
37 | diagR = torch.diag(self.R)
38 | R = torch.triu(self.R, diagonal=1) + torch.diag(torch.exp(diagR))
39 | x = x.matmul(R.t())
40 | x = self.Q(x)
41 | x = x + self.bias
42 | logdet += diagR[None, :].repeat([x.shape[0], 1]).sum(dim=1)
43 | return x, logdet
44 |
45 | def inverse(self, x, logdet):
46 | """
47 |         x: [bs, data_dim]
48 |
49 | """
50 |
51 | x = x - self.bias
52 | x = self.Q.inverse(x)
53 | diagR = torch.diag(self.R)
54 | R = torch.triu(self.R, diagonal=1) + torch.diag(torch.exp(diagR))
55 | invR = torch.inverse(R)
56 | x = x.matmul(invR.t())
57 |
58 | logdet += diagR[None, :].repeat([x.shape[0], 1]).sum(dim=1)
59 | return x, logdet
60 |
61 |
62 | class prelu(nn.Module):
63 |
64 | def __init__(self, num_parameters: int = 1, init: float = 0.25):
65 | self.num_parameters = num_parameters
66 | super(prelu, self).__init__()
67 | self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
68 |
69 | def forward(self, input, logdet):
70 | s = torch.zeros_like(input)
71 | s[input < 0] = torch.log(self.weight)
72 | logdet += torch.sum(s, dim=1)
73 | return F.prelu(input, self.weight), logdet
74 |
75 | def inverse(self, input, logdet):
76 | s = torch.zeros_like(input)
77 | s[input < 0] = torch.log(self.weight)
78 | logdet += torch.sum(s, dim=1)
79 | return F.prelu(input, 1 / self.weight), logdet
80 |
81 |
82 | class LinNF(nn.Module):
83 | def __init__(self, data_dim, num_layer=5, with_prelu=True):
84 | super().__init__()
85 | self.num_layer = num_layer
86 | self.w1 = nn.ModuleList()
87 | for i in range(num_layer - 1):
88 | self.w1.append(LinQR(data_dim=data_dim))
89 | if with_prelu:
90 | self.w1.append(prelu())
91 | self.w1.append(LinQR(data_dim=data_dim))
92 |
93 | def forward(self, x):
94 |
95 | z = x
96 | log_det_jacobian = 0
97 |
98 | for i, w in enumerate(self.w1):
99 | z, log_det_jacobian = w(z, log_det_jacobian)
100 |
101 | return z, log_det_jacobian
102 |
103 | def inverse(self, z):
104 | x = z
105 | log_det_jacobian = 0
106 |
107 | for i in range(len(self.w1) - 1, -1, -1):
108 | x, log_det_jacobian = self.w1[i].inverse(x, log_det_jacobian)
109 |
110 | return x, log_det_jacobian
111 |
112 |
113 | if __name__ == '__main__':
114 | bs = 32
115 | data_dim = 25
116 |
117 | sf = LinNF(data_dim=data_dim)
118 | # a = sf.prior.sample([10000, 48, 25])
119 | # print(torch.mean(sf.prior.log_prob(a).sum([1, 2])))
120 | sf.double()
121 | sf.cuda()
122 | for i in range(10):
123 | x = torch.randn([bs, data_dim]).double().cuda()
124 |
125 | y1, logdet = sf(x)
126 | x1, logdet = sf.inverse(y1)
127 | err = (x1 / x - 1).abs().max()
128 | print(err)
129 | print(1)
130 |
--------------------------------------------------------------------------------
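
Beyond the round-trip check in `__main__`, a flow like this is used for likelihood evaluation via the change of variables formula; a sketch (the standard normal prior here is an assumption — the actual objective lives in `train_nf.py`, not shown):

```python
import torch
from models.LinNF import LinNF

flow = LinNF(data_dim=48, num_layer=3)
x = torch.randn(32, 48)

z, logdet = flow(x)
# log p(x) = log N(z; 0, I) + log|det dz/dx|, assuming a standard normal prior
prior = torch.distributions.Normal(0.0, 1.0)
log_px = prior.log_prob(z).sum(dim=1) + logdet
print(log_px.shape)  # torch.Size([32])
```
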
/utils/vis_util.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from mpl_toolkits.mplot3d import Axes3D
3 | import numpy as np
4 |
5 |
6 | def draw_skeleton(ax, kpts, parents=[], is_right=[], cols=["#3498db", "#e74c3c"], marker='o', line_style='-',
7 | label=None):
8 | """
9 |
10 | :param kpts: joint_n*(3 or 2)
11 | :param parents:
12 | :return:
13 | """
14 | # ax = plt.subplot(111)
15 | joint_n, dims = kpts.shape
16 | # by default it is human 3.6m joints
17 | # [0, 1, 2, 3, 6, 7, 8, 12, 13, 14, 15, 25, 26, 27, 17, 18, 19]
18 | # if len(parents) == 0:
19 | # parents = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15]
20 | # is_right = [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]
21 | # if cols == []:
22 | # cols = ["#3498db", "#e74c3c"]
23 | # if parents == 'op':
24 | # parents = [1, -1, 1, 2, 3, 1, 5, 6, 1, 8, 9, 1, 11, 12, 0, 0, 14, 15]
25 | # if parents == 'smpl':
26 | # # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21]
27 | # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
28 | # # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1]
29 | # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1]
30 | # if cols == []:
31 | # cols = ["#3498db", "#e74c3c"]
32 | # if parents == 'smpl_add':
33 | # parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21, 15]
34 | # is_right = [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]
35 | # cols = ["#3498db", "#e74c3c"]
36 | # ax.set_xlabel('X Label')
37 | # ax.set_ylabel('Y Label')
38 | if dims > 2:
39 | ax.view_init(75, 90)
40 | ax.set_zlabel('Z Label')
41 | # if dims == 2:
42 | # # idx_choosed = np.intersect1d(np.where(kpts[:, 0] > 0)[0], np.where(kpts[:, 1] > 0)[0])
43 | # # ax.scatter(kpts[idx_choosed, 0], kpts[idx_choosed, 1], c=c, marker=marker, s=10)
44 | # ax.scatter(kpts[:, 0], kpts[:, 1], c=cols[0], marker=marker, s=10)
45 | # # for i in idx_choosed:
46 | # # ax.text(kpts[i, 0], kpts[i, 1], "{:d}".format(i), color=c)
47 | # else:
48 | # ax.scatter(kpts[:, 0], kpts[:, 1], kpts[:, 2], c=cols[0], marker=marker, s=10)
49 | # for i in range(kpts.shape[0]):
50 | # ax.text(kpts[i, 0], kpts[i, 1], kpts[i, 2], "{:d}".format(i), color=cols[0])
51 | is_label = True
52 | for i in range(len(parents)):
53 | if parents[i] < 0:
54 | continue
55 | # if dims == 2:
56 | # if not (parents[i] in idx_choosed and i in idx_choosed):
57 | # continue
58 |
59 | if dims == 2:
60 | # ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]],
61 | # linestyle=line_style,
62 | # alpha=0.5 if is_right[i] else 1, linewidth=3)
63 | if label is not None and is_label:
64 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]],
65 | linestyle=line_style,
66 | alpha=1 if is_right[i] else 0.6, label=label)
67 | is_label = False
68 | else:
69 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]], c=cols[is_right[i]],
70 | linestyle=line_style,
71 | alpha=1 if is_right[i] else 0.6)
72 | else:
73 | if label is not None and is_label:
74 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]],
75 | [kpts[parents[i], 2], kpts[i, 2]], linestyle=line_style, c=cols[is_right[i]],
76 | alpha=1 if is_right[i] else 0.6, linewidth=3, label=label)
77 | is_label = False
78 | else:
79 | ax.plot([kpts[parents[i], 0], kpts[i, 0]], [kpts[parents[i], 1], kpts[i, 1]],
80 | [kpts[parents[i], 2], kpts[i, 2]], linestyle=line_style, c=cols[is_right[i]],
81 | alpha=1 if is_right[i] else 0.6, linewidth=3)
82 |
83 | return None
--------------------------------------------------------------------------------
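
A toy call showing the expected shapes (sketch; the 5-joint chain, sidedness flags, and output filename are made up):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from utils.vis_util import draw_skeleton

kpts = np.cumsum(np.random.randn(5, 3) * 0.1, axis=0)  # a 5-joint kinematic chain
parents = [-1, 0, 1, 2, 3]
is_right = [0, 0, 0, 1, 1]
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
draw_skeleton(ax, kpts, parents=parents, is_right=is_right, label='toy')
fig.savefig('toy_skeleton.png')
```
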
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def get_dct_matrix(N, is_torch=True):
6 | dct_m = np.eye(N)
7 | for k in np.arange(N):
8 | for i in np.arange(N):
9 | w = np.sqrt(2 / N)
10 | if k == 0:
11 | w = np.sqrt(1 / N)
12 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
13 | idct_m = np.linalg.inv(dct_m)
14 | if is_torch:
15 | dct_m = torch.from_numpy(dct_m)
16 | idct_m = torch.from_numpy(idct_m)
17 | return dct_m, idct_m
18 |
19 |
20 | def _pairwise_distances(embeddings, squared=False):
21 | """Compute the 2D matrix of distances between all the embeddings.
22 |
23 | Args:
24 | embeddings: tensor of shape (batch_size, embed_dim)
25 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
26 | If false, output is the pairwise euclidean distance matrix.
27 |
28 | Returns:
29 | pairwise_distances: tensor of shape (batch_size, batch_size)
30 | """
31 | dot_product = torch.matmul(embeddings, embeddings.t())
32 |
33 | # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
34 | # This also provides more numerical stability (the diagonal of the result will be exactly 0).
35 | # shape (batch_size,)
36 | square_norm = torch.diag(dot_product)
37 |
38 | # Compute the pairwise distance matrix as we have:
39 |     # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
40 | # shape (batch_size, batch_size)
41 | distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
42 |
43 | # Because of computation errors, some distances might be negative so we put everything >= 0.0
44 | distances[distances < 0] = 0
45 |
46 | if not squared:
47 | # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
48 | # we need to add a small epsilon where distances == 0.0
49 | mask = distances.eq(0).float()
50 | distances = distances + mask * 1e-16
51 |
52 | distances = (1.0 - mask) * torch.sqrt(distances)
53 |
54 | return distances
55 |
56 |
57 | def _pairwise_distances_l1(embeddings, squared=False):
58 | """Compute the 2D matrix of distances between all the embeddings.
59 |
60 | Args:
61 | embeddings: tensor of shape (batch_size, embed_dim)
62 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
63 | If false, output is the pairwise euclidean distance matrix.
64 |
65 | Returns:
66 | pairwise_distances: tensor of shape (batch_size, batch_size)
67 | """
68 | distances = torch.abs(embeddings[None, :, :] - embeddings[:, None, :])
69 | return distances
70 |
71 |
72 | def expmap2rotmat(r):
73 | """
74 | Converts an exponential map angle to a rotation matrix
75 | Matlab port to python for evaluation purposes
76 | I believe this is also called Rodrigues' formula
77 | https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/mhmublv/Motion/expmap2rotmat.m
78 |
79 | Args
80 | r: 1x3 exponential map
81 | Returns
82 | R: 3x3 rotation matrix
83 | """
84 | theta = np.linalg.norm(r)
85 | r0 = np.divide(r, theta + np.finfo(np.float32).eps)
86 | r0x = np.array([0, -r0[2], r0[1], 0, 0, -r0[0], 0, 0, 0]).reshape(3, 3)
87 | r0x = r0x - r0x.T
88 |     R = np.eye(3, 3) + np.sin(theta) * r0x + (1 - np.cos(theta)) * r0x.dot(r0x)
89 | return R
90 |
91 |
92 | def absolute2relative(x, parents, invert=False, x0=None):
93 | """
94 | x: [bs,..., jn, 3] or [bs,..., jn-1, 3] if invert
95 | x0: [1,..., jn, 3]
96 | parents: [-1,0,1 ...]
97 | """
98 | if not invert:
99 | xt = x[..., 1:, :] - x[..., parents[1:], :]
100 | xt = xt / np.linalg.norm(xt, axis=-1, keepdims=True)
101 | return xt
102 | else:
103 | jn = x0.shape[-2]
104 | limb_l = np.linalg.norm(x0[..., 1:, :] - x0[..., parents[1:], :], axis=-1, keepdims=True)
105 | xt = x * limb_l
106 | xt0 = np.zeros_like(xt[..., :1, :])
107 | xt = np.concatenate([xt0, xt], axis=-2)
108 | for i in range(1, jn):
109 | xt[..., i, :] = xt[..., parents[i], :] + xt[..., i, :]
110 | return xt
111 |
112 |
113 | def absolute2relative_torch(x, parents, invert=False, x0=None):
114 | """
115 | x: [bs,..., jn, 3] or [bs,..., jn-1, 3] if invert
116 | x0: [1,..., jn, 3]
117 | parents: [-1,0,1 ...]
118 | """
119 | if not invert:
120 | xt = x[..., 1:, :] - x[..., parents[1:], :]
121 | xt = xt / torch.norm(xt, dim=-1, keepdim=True)
122 | return xt
123 | else:
124 | jn = x0.shape[-2]
125 | limb_l = torch.norm(x0[..., 1:, :] - x0[..., parents[1:], :], dim=-1, keepdim=True)
126 | xt = x * limb_l
127 | xt0 = torch.zeros_like(xt[..., :1, :])
128 | xt = torch.cat([xt0, xt], dim=-2)
129 | for i in range(1, jn):
130 | xt[..., i, :] = xt[..., parents[i], :] + xt[..., i, :]
131 | return xt
132 |
--------------------------------------------------------------------------------
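
Two quick checks of the helpers above (sketch): the DCT matrix is exactly invertible, and `absolute2relative` round-trips up to the root translation when the reference pose `x0` supplies the limb lengths:

```python
import numpy as np
from utils.util import get_dct_matrix, absolute2relative

dct_m, idct_m = get_dct_matrix(20, is_torch=False)
print(np.allclose(dct_m @ idct_m, np.eye(20)))  # True

parents = [-1, 0, 1, 2]
x0 = np.random.randn(1, 4, 3)                  # reference pose fixing limb lengths
rel = absolute2relative(x0, parents)           # unit bone directions, shape [1, 3, 3]
back = absolute2relative(rel, parents, invert=True, x0=x0)
print(np.allclose(back, x0 - x0[..., :1, :]))  # True: pose recovered with root at origin
```
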
/utils/torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.optim import lr_scheduler
4 |
5 | tensor = torch.tensor
6 | DoubleTensor = torch.DoubleTensor
7 | FloatTensor = torch.FloatTensor
8 | LongTensor = torch.LongTensor
9 | ByteTensor = torch.ByteTensor
10 | ones = torch.ones
11 | zeros = torch.zeros
12 |
13 |
14 | class to_cpu:
15 |
16 | def __init__(self, *models):
17 | self.models = list(filter(lambda x: x is not None, models))
18 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models]
19 | for x in self.models:
20 | x.to(torch.device('cpu'))
21 |
22 | def __enter__(self):
23 | pass
24 |
25 | def __exit__(self, *args):
26 | for x, device in zip(self.models, self.prev_devices):
27 | x.to(device)
28 | return False
29 |
30 |
31 | class to_device:
32 |
33 | def __init__(self, device, *models):
34 | self.models = list(filter(lambda x: x is not None, models))
35 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models]
36 | for x in self.models:
37 | x.to(device)
38 |
39 | def __enter__(self):
40 | pass
41 |
42 | def __exit__(self, *args):
43 | for x, device in zip(self.models, self.prev_devices):
44 | x.to(device)
45 | return False
46 |
47 |
48 | class to_test:
49 |
50 | def __init__(self, *models):
51 | self.models = list(filter(lambda x: x is not None, models))
52 | self.prev_modes = [x.training for x in self.models]
53 | for x in self.models:
54 | x.train(False)
55 |
56 | def __enter__(self):
57 | pass
58 |
59 | def __exit__(self, *args):
60 | for x, mode in zip(self.models, self.prev_modes):
61 | x.train(mode)
62 | return False
63 |
64 |
65 | class to_train:
66 |
67 | def __init__(self, *models):
68 | self.models = list(filter(lambda x: x is not None, models))
69 | self.prev_modes = [x.training for x in self.models]
70 | for x in self.models:
71 | x.train(True)
72 |
73 | def __enter__(self):
74 | pass
75 |
76 | def __exit__(self, *args):
77 | for x, mode in zip(self.models, self.prev_modes):
78 | x.train(mode)
79 | return False
80 |
81 |
82 | def batch_to(dst, *args):
83 | return [x.to(dst) if x is not None else None for x in args]
84 |
85 |
86 | def get_flat_params_from(models):
87 | if not hasattr(models, '__iter__'):
88 | models = (models, )
89 | params = []
90 | for model in models:
91 | for param in model.parameters():
92 | params.append(param.data.view(-1))
93 |
94 | flat_params = torch.cat(params)
95 | return flat_params
96 |
97 |
98 | def set_flat_params_to(model, flat_params):
99 | prev_ind = 0
100 | for param in model.parameters():
101 | flat_size = int(np.prod(list(param.size())))
102 | param.data.copy_(
103 | flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
104 | prev_ind += flat_size
105 |
106 |
107 | def get_flat_grad_from(inputs, grad_grad=False):
108 | grads = []
109 | for param in inputs:
110 | if grad_grad:
111 | grads.append(param.grad.grad.view(-1))
112 | else:
113 | if param.grad is None:
114 | grads.append(zeros(param.view(-1).shape))
115 | else:
116 | grads.append(param.grad.view(-1))
117 |
118 | flat_grad = torch.cat(grads)
119 | return flat_grad
120 |
121 |
122 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False):
123 | if create_graph:
124 | retain_graph = True
125 |
126 | inputs = list(inputs)
127 | params = []
128 | for i, param in enumerate(inputs):
129 | if i not in filter_input_ids:
130 | params.append(param)
131 |
132 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph)
133 |
134 | j = 0
135 | out_grads = []
136 | for i, param in enumerate(inputs):
137 | if i in filter_input_ids:
138 | out_grads.append(zeros(param.view(-1).shape))
139 | else:
140 | out_grads.append(grads[j].view(-1))
141 | j += 1
142 | grads = torch.cat(out_grads)
143 |
144 | for param in params:
145 | param.grad = None
146 | return grads
147 |
148 |
149 | def set_optimizer_lr(optimizer, lr):
150 | for param_group in optimizer.param_groups:
151 | param_group['lr'] = lr
152 |
153 |
154 | def filter_state_dict(state_dict, filter_keys):
155 | for key in list(state_dict.keys()):
156 | for f_key in filter_keys:
157 | if f_key in key:
158 | del state_dict[key]
159 | break
160 |
161 |
162 | def get_scheduler(optimizer, policy, nepoch_fix=None, nepoch=None, decay_step=None):
163 | if policy == 'lambda':
164 | def lambda_rule(epoch):
165 | lr_l = 1.0 - max(0, epoch - nepoch_fix) / float(nepoch - nepoch_fix + 1)
166 | return lr_l
167 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
168 | elif policy == 'step':
169 | scheduler = lr_scheduler.StepLR(
170 | optimizer, step_size=decay_step, gamma=0.1)
171 | elif policy == 'plateau':
172 | scheduler = lr_scheduler.ReduceLROnPlateau(
173 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
174 | else:
175 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
176 | return scheduler
177 |
--------------------------------------------------------------------------------
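
How the 'lambda' policy behaves (sketch): the learning rate stays flat for `nepoch_fix` epochs, then decays roughly linearly to zero at `nepoch`, matching the `num_epoch_fix`/`num_epoch` pair in the configs:

```python
import torch
from utils.torch import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=100, nepoch=500)

for epoch in range(500):
    # ... one epoch of training would go here ...
    scheduler.step()
    if epoch in (0, 99, 299, 499):
        print(epoch, optimizer.param_groups[0]['lr'])
```
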
/motion_pred/utils/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import matplotlib
10 | import matplotlib.pyplot as plt
11 | from matplotlib.animation import FuncAnimation, writers
12 | from mpl_toolkits.mplot3d import Axes3D
13 | import numpy as np
14 |
15 |
16 |
17 | def render_animation(skeleton, poses_generator, t_hist, fix_0=True, azim=0.0, output=None, size=6, ncol=5, bitrate=3000, index_i=0):
18 | """
19 | TODO
20 | Render an animation. The supported output modes are:
21 | -- 'interactive': display an interactive figure
22 | (also works on notebooks if associated with %matplotlib inline)
23 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...).
24 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg).
25 | -- 'filename.gif': render and export the animation a gif file (requires imagemagick).
26 | """
27 | os.makedirs(output, exist_ok=True)
28 | all_poses = next(poses_generator)
29 | action = all_poses.pop('action')
30 | t_total = next(iter(all_poses.values())).shape[0]
31 | poses = all_poses
32 | plt.ioff()
33 | nrow = int(np.ceil(len(poses) / ncol))
34 | fig = plt.figure(figsize=(size*ncol, size*nrow))
35 | ax_3d = []
36 | lines_3d = []
37 | trajectories = []
38 | radius = 1.7
39 | for index, (title, data) in enumerate(poses.items()):
40 | ax = fig.add_subplot(nrow, ncol, index+1, projection='3d')
41 | ax.view_init(elev=15., azim=azim)
42 | ax.set_xlim3d([-radius/2, radius/2])
43 | ax.set_zlim3d([0, radius])
44 | ax.set_ylim3d([-radius/2, radius/2])
45 | ax.set_aspect('auto')
46 | ax.set_xticklabels([])
47 | ax.set_yticklabels([])
48 | ax.set_zticklabels([])
49 | ax.dist = 5.0
50 | # ax.set_title(title, y=1.2)
51 | ax.set_axis_off()
52 | ax.patch.set_alpha(0.0)
53 | ax_3d.append(ax)
54 | lines_3d.append([])
55 | trajectories.append(data[:, 0, [0, 1]])
56 | fig.tight_layout()
57 | fig.subplots_adjust(wspace=-0.4, hspace=0)
58 | poses = list(poses.values())
59 |
60 | anim = None
61 | initialized = False
62 | animating = True
63 | find = 0
64 | hist_lcol, hist_rcol = 'black', 'red'
65 | pred_lcol, pred_rcol = 'purple', 'green'
66 |
67 | parents = skeleton.parents()
68 |
69 | def update_video(i):
70 | nonlocal initialized
71 | if i < t_hist:
72 | lcol, rcol = hist_lcol, hist_rcol
73 | else:
74 | lcol, rcol = pred_lcol, pred_rcol
75 |
76 | for n, ax in enumerate(ax_3d):
77 | if fix_0 and n == 0 and i >= t_hist:
78 | continue
79 | trajectories[n] = poses[n][:, 0, [0, 1, 2]]
80 | ax.set_xlim3d([-radius/2 + trajectories[n][i, 0], radius/2 + trajectories[n][i, 0]])
81 | ax.set_ylim3d([-radius/2 + trajectories[n][i, 1], radius/2 + trajectories[n][i, 1]])
82 | ax.set_zlim3d([-radius/2 + trajectories[n][i, 2], radius/2 + trajectories[n][i, 2]])
83 |
84 | if not initialized:
85 |
86 | for j, j_parent in enumerate(parents):
87 | if j_parent == -1:
88 | continue
89 |
90 | col = rcol if j in skeleton.joints_right() else lcol
91 | for n, ax in enumerate(ax_3d):
92 | pos = poses[n][i]
93 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]],
94 | [pos[j, 1], pos[j_parent, 1]],
95 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col, lw=8, dash_capstyle='round', marker='o', markersize=12, alpha=0.5, aa=True))
96 | initialized = True
97 | else:
98 |
99 | for j, j_parent in enumerate(parents):
100 | if j_parent == -1:
101 | continue
102 |
103 | col = rcol if j in skeleton.joints_right() else lcol
104 | for n, ax in enumerate(ax_3d):
105 | if fix_0 and n == 0 and i >= t_hist:
106 | continue
107 | pos = poses[n][i]
108 | lines_3d[n][j-1][0].set_xdata(np.array([pos[j, 0], pos[j_parent, 0]]))
109 | lines_3d[n][j-1][0].set_ydata(np.array([pos[j, 1], pos[j_parent, 1]]))
110 | lines_3d[n][j-1][0].set_3d_properties(np.array([pos[j, 2], pos[j_parent, 2]]), zdir='z')
111 | lines_3d[n][j-1][0].set_color(col)
112 |
113 | def show_animation():
114 | nonlocal anim
115 | if anim is not None:
116 | anim.event_source.stop()
117 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=0, repeat=True)
118 | plt.draw()
119 |
120 | def reload_poses():
121 | nonlocal poses, action
122 | if 'action' in all_poses:
123 | action = all_poses.pop('action')
124 | poses = all_poses
125 | # for ax, title in zip(ax_3d, poses.keys()):
126 | # ax.set_title(title, y=1.2)
127 | poses = list(poses.values())
128 |
129 | def save_figs():
130 | nonlocal find
131 | update_video(0)
132 | update_video(t_total - 1)
133 | os.makedirs(output + 'image', exist_ok=True)
134 | fig.savefig(output + 'image/%d_%s.png' % (index_i, action), dpi=80, transparent=True)
135 | find += 1
136 |
137 | def on_key(event):
138 | nonlocal all_poses, animating, anim
139 |
140 | if event.key == 'd':
141 | all_poses = next(poses_generator)
142 | reload_poses()
143 | show_animation()
144 | elif event.key == 'c':
145 | save()
146 | elif event.key == ' ':
147 | if animating:
148 | anim.event_source.stop()
149 | else:
150 | anim.event_source.start()
151 | animating = not animating
152 | elif event.key == 'v': # save images
153 | if anim is not None:
154 | anim.event_source.stop()
155 | anim = None
156 | save_figs()
157 |
158 | def save():
159 | nonlocal anim
160 |
161 | fps = 30
162 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=1000 / fps, repeat=False)
163 | os.makedirs(output+'video', exist_ok=True)
164 | anim.save(output + 'video/%d_%s.gif' % (index_i, action), dpi=80, writer='imagemagick')
165 | print(f'video saved to {output}video/{index_i}_{action}.gif!')
166 | save()
167 | # fig.canvas.mpl_connect('key_press_event', on_key)
168 | # show_animation()
169 | # plt.show()
--------------------------------------------------------------------------------
/utils/vis_poses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 |
4 | matplotlib.use("Agg")
5 | import matplotlib.pyplot as plt
6 | import os
7 | import subprocess
8 |
9 | from utils.vis_util import draw_skeleton
10 |
11 |
12 | def vis_traj(gt_pos, pre_pos=[], labels=[], comments='', title='',
13 | joint_to_plot=[2, 3, 7, 8, 15, 16, 17, 20, 21, 22], sub_title=[''] * 10):
14 | """
15 | @param gt_pos: seq_l x jn x 3
16 | @param pre_pos: n seq_l x jn x 3
17 | @param labels: n
18 | @param comments:
19 | @return:
20 | """
21 | assert len(pre_pos) == len(labels)
22 | assert len(gt_pos.shape) == 3
23 | assert len(sub_title) == len(joint_to_plot)
24 | seq_n, joint_n, _ = gt_pos.shape
25 | joint_name = np.array(["Hips", "rUpLeg", "rLeg", "rFoot", "rToeBase", "rToeSite", "lUpLeg", "lLeg",
26 | "lFoot", "lToeBase", "lToeSite", "Spine", "Spine1", "Neck", "Head", "Site", "lShoulder",
27 | "lArm", "lForeArm", "lHand", "lHandThumb", "lHandSite", "lWristEnd", "lWristSite",
28 | "rShoulder", "rArm", "rForeArm", "rHand", "rHandThumb", "rHandSite",
29 | "rWristEnd", "rWristSite"])
30 | joint_to_ignore = np.array([11, 16, 20, 23, 24, 28, 31])
31 | joint_used = np.setdiff1d(np.arange(32), joint_to_ignore)
32 | parents = np.array([-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 12, 15, 16, 17, 17, 12, 20, 21, 22, 22])
33 | is_right = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
34 | # joint_to_plot = [2, 3, 7, 8, 16, 17, 21, 22]
35 | linestyles = ['-', '--', '-.', ':', '--', '-.']
36 | markers = ['o', 'v', '*', 'D', 's', 'p']
37 | colors = ['k', 'r', 'g', 'b', 'c', 'm']
38 | coord = ['x', 'y', 'z']
39 | fig = plt.figure(0, figsize=[6 * 3, 3 * len(joint_to_plot)])
40 | # plt.title(title, pad=1)
41 | fig.suptitle('{}_{}'.format(title.split('/')[-1].split('.pdf')[0], comments), fontsize=14)
42 | axs = fig.subplots(len(joint_to_plot), 3)
43 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.4)
44 | for i, jn in enumerate(joint_to_plot):
45 | axs[i, 1].set_title("{}_{}".format(joint_name[joint_used][jn], sub_title[i]), x=0.5, y=1)
46 | # axs[i, 1].legend()
47 | for j in range(3):
48 | axs[i, j].plot(np.arange(1, seq_n + 1), gt_pos[:, jn, j], linestyle=linestyles[0], marker=markers[0],
49 | c=colors[0], label='GT')
50 | axs[i, j].set_xticks(np.arange(1, seq_n + 1))
51 | for k, pp in enumerate(pre_pos):
52 | axs[i, j].plot(np.arange(1, seq_n + 1), pp[:, jn, j], linestyle=linestyles[k + 1],
53 | marker=markers[k + 1], c=colors[k + 1],
54 | label=labels[k])
55 | axs[0, 1].legend()
56 | plt.savefig('{}'.format(title))
57 | # plt.savefig('test1.pdf'.format(title))
58 | plt.clf()
59 | plt.close()
60 | # plt.show()
61 |
62 |
63 | def vis_poses(gt_pos, pre_pos=[], labels=[], comments='', title='', skeleton=None):
64 | """
65 | @param gt_pos: seq_l x jn x 3
66 | @param pre_pos: n seq_l x jn x 3
67 | @param labels: n
68 | @param comments:
69 | @return:
70 | """
71 | assert len(pre_pos) == len(labels)
72 | assert len(gt_pos.shape) == 3
73 |     if not len(comments) == gt_pos.shape[0]:
74 |         comments = [''] * gt_pos.shape[0]  # [] * n is always the empty list; pad with empty strings
75 | seq_n, joint_n, _ = gt_pos.shape
76 | # joint_name = np.array(["Hips", "rUpLeg", "rLeg", "rFoot", "rToeBase", "rToeSite", "lUpLeg", "lLeg",
77 | # "lFoot", "lToeBase", "lToeSite", "Spine", "Spine1", "Neck", "Head", "Site", "lShoulder",
78 | # "lArm", "lForeArm", "lHand", "lHandThumb", "lHandSite", "lWristEnd", "lWristSite",
79 | # "rShoulder", "rArm", "rForeArm", "rHand", "rHandThumb", "rHandSite",
80 | # "rWristEnd", "rWristSite"])
81 | # joint_to_ignore = np.array([11, 16, 20, 23, 24, 28, 31])
82 | # joint_used = np.setdiff1d(np.arange(32), joint_to_ignore)
83 | # parents = np.array([-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 12, 15, 16, 17, 17, 12, 20, 21, 22, 22])
84 | # is_right = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
85 | parents = skeleton._parents
86 | is_right = np.zeros_like(parents)
87 | is_right[skeleton._joints_right] = 1
88 |
89 | joint_to_plot = [2, 3, 15, 16, 17]
90 | linestyles = ['-', '--', '-.', ':', '--', '-.']
91 | markers = ['o', 'v', '*', 'D', 's', 'p']
92 | colors = ['k', 'r', 'g', 'b', 'c', 'm']
93 | coord = ['x', 'y', 'z']
94 | rot1 = np.array([[0.9239557, -0.0000000, -0.3824995],
95 | [-0.1463059, 0.9239557, -0.3534126],
96 | [0.3534126, 0.3824995, 0.8536941]])
97 | rot2 = np.array([[1.0000000, 0.0000000, 0.0000000],
98 | [0.0000000, 0.9239557, -0.3824995],
99 | [0.0000000, 0.3824995, 0.9239557]])
100 | rot3 = np.array([[0.9239557, 0.0000000, 0.3824995],
101 | [0.1463059, 0.9239557, -0.3534126],
102 | [-0.3534126, 0.3824995, 0.8536941]])
103 | rot = [rot3, rot2, rot1]
104 | trans = 4
105 | # axs = fig.subplots(1, len(rot))
106 | # get value scope
107 | pgt = []
108 | ppred = []
109 | scope = []
110 | pp_tmp = []
111 | for i, rr in enumerate(rot):
112 | pt = np.matmul(np.expand_dims(rr, axis=0), gt_pos.transpose([0, 2, 1])).transpose([0, 2, 1])
113 | pt[:, :, 2] = pt[:, :, 2] + trans
114 | pt = pt[:, :, :2] / pt[:, :, 2:]
115 | pgt.append(pt)
116 | pp_tmp.append(pt)
117 | ppred.append([])
118 | for j, pp in enumerate(pre_pos):
119 | pt = np.matmul(np.expand_dims(rr, axis=0), pp.transpose([0, 2, 1])).transpose([0, 2, 1])
120 | pt[:, :, 2] = pt[:, :, 2] + trans
121 | pt = pt[:, :, :2] / pt[:, :, 2:]
122 | pp_tmp.append(pt)
123 | ppred[i].append(pt)
124 | pp_tmp = np.vstack(pp_tmp)
125 | max_x = np.max(pp_tmp[:, :, 0])
126 | min_x = np.min(pp_tmp[:, :, 0])
127 | max_y = np.max(pp_tmp[:, :, 1])
128 | min_y = np.min(pp_tmp[:, :, 1])
129 | dx = (max_x - min_x) + 0.1 * (max_x - min_x)
130 | max_x = min_x + dx * len(rot)
131 |
132 | scope = [max_x, min_x, max_y, min_y]
133 | for jj in range(seq_n):
134 | fig = plt.figure(0, figsize=[3 * len(rot), 6])
135 | axs = fig.subplots(1, 1)
136 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.2, hspace=0.4)
137 | # for i in range(len(rot)):
138 | # axs[i].set_axis_off()
139 | # axs[i].axis('equal')
140 | # axs[i].plot(scope[i][:2], scope[i][2:], c='w')
141 | axs.set_axis_off()
142 | axs.axis('equal')
143 | axs.plot(scope[:2], scope[2:], c='w')
144 | for i, rr in enumerate(rot):
145 | # pt = np.matmul(rr, gt_pos[jj].transpose([1, 0])).transpose([1, 0])
146 | # pt[:, 2] = pt[:, 2] + trans
147 | # pt = pt[:, :2] / pt[:, 2:]
148 | pgt[i][jj][:, 0] = pgt[i][jj][:, 0] + dx * i
149 | draw_skeleton(axs, pgt[i][jj], parents=parents, is_right=is_right, cols=[colors[0], colors[0]],
150 | line_style=linestyles[0], label='GT' if i == 0 else None)
151 | for j, pp in enumerate(pre_pos):
152 | # pt = np.matmul(rr, pp[jj].transpose([1, 0])).transpose([1, 0])
153 | # pt[:, 2] = pt[:, 2] + trans
154 | # pt = pt[:, :2] / pt[:, 2:]
155 | ppred[i][j][jj][:, 0] = ppred[i][j][jj][:, 0] + dx * i
156 | draw_skeleton(axs, ppred[i][j][jj], parents=parents, is_right=is_right,
157 | cols=[colors[j + 1], colors[j + 1]],
158 | line_style=linestyles[j + 1], label=labels[j] if i == 0 else None)
159 | # axs[1].legend()
160 | # axs[1].set_title("f{}_{}".format(jj + 1, comments[jj]), x=0.5, y=1)
161 | # plt.savefig('{}/{}.jpg'.format(title, jj))
162 | axs.legend()
163 | axs.set_title("f{}_{}".format(jj + 1, comments[jj]), x=0.5, y=1)
164 | plt.savefig('{}/{}.jpg'.format(title, jj))
165 | # plt.show(block=False)
166 | # plt.pause(0.1)
167 | # for i in range(len(rot)):
168 | # axs[i].clear()
169 | axs.clear()
170 | plt.clf()
171 | plt.close()
172 | # cmd = [
173 | # 'ffmpeg',
174 | # '-i', vid_filename,
175 | # f'{output_folder}/%06d.jpg',
176 | # '-threads', '16'
177 | # ]
178 | #
179 | # print(' '.join(cmd))
180 | # try:
181 | # subprocess.call(cmd)
182 | # except OSError:
183 | # print('OSError')
184 |
185 | # # Set up formatting for the movie files
186 | # Writer = animation.writers['ffmpeg']
187 | # writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
188 | # fig = plt.figure()
189 | # im = []
190 | # for jj in range(seq_n):
191 | # img = plt.imread('{}_{}.jpg'.format(title.replace('.mp4', ''), jj))
192 | # im_tmp = plt.imshow(img)
193 | # im.append((im_tmp,))
194 | # im_ani = animation.ArtistAnimation(fig, im, interval=50, repeat_delay=3000,
195 | # blit=True)
196 | # im_ani.save(title, writer=writer)
197 |
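198 | # Illustrative sketch (not part of the original file): the camera model used
199 | # above, in compact form. A root-centred pose is rotated into a virtual
200 | # camera frame, pushed `trans` units along z, then perspective-divided by
201 | # depth. `project_pose` is a hypothetical helper name.
202 | def project_pose(pose3d, rot_mat, trans=4):
203 |     """pose3d: [J, 3] joints; rot_mat: [3, 3] rotation. Returns [J, 2]."""
204 |     pt = pose3d @ rot_mat.T        # equivalent to (rot_mat @ pose3d.T).T
205 |     pt[:, 2] = pt[:, 2] + trans    # move the skeleton in front of the camera
206 |     return pt[:, :2] / pt[:, 2:]   # perspective divide by depth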
--------------------------------------------------------------------------------
/motion_pred/utils/dataset_humaneva_multimodal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from motion_pred.utils.dataset import Dataset
4 | from motion_pred.utils.skeleton import Skeleton
5 | from utils import util
6 |
7 |
8 | class DatasetHumanEva(Dataset):
9 |
10 | def __init__(self, mode, t_his=15, t_pred=60, actions='all', **kwargs):
11 | if 'multimodal_path' in kwargs.keys():
12 | self.multimodal_path = kwargs['multimodal_path']
13 | else:
14 | self.multimodal_path = None
15 |
16 | if 'data_candi_path' in kwargs.keys():
17 | self.data_candi_path = kwargs['data_candi_path']
18 | else:
19 | self.data_candi_path = None
20 | super().__init__(mode, t_his, t_pred, actions)
21 |
22 | def prepare_data(self):
23 | self.data_file = os.path.join('data', 'data_3d_humaneva15.npz')
24 | self.subjects_split = {'train': ['Train/S1', 'Train/S2', 'Train/S3'],
25 | 'test': ['Validate/S1', 'Validate/S2', 'Validate/S3']}
26 | self.subjects = [x for x in self.subjects_split[self.mode]]
27 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 1, 5, 6, 0, 8, 9, 0, 11, 12, 1],
28 | joints_left=[2, 3, 4, 8, 9, 10],
29 | joints_right=[5, 6, 7, 11, 12, 13])
30 | self.kept_joints = np.arange(15)
31 | self.process_data()
32 |
33 | def process_data(self):
34 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item()
35 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items()))
36 |         # these takes have wrong head positions and are excluded from training and testing
37 | if self.mode == 'train':
38 | data_f['Train/S3'].pop('Walking 1 chunk0')
39 | data_f['Train/S3'].pop('Walking 1 chunk2')
40 | else:
41 | data_f['Validate/S3'].pop('Walking 1 chunk4')
42 | if self.multimodal_path is None:
43 | self.data_multimodal = \
44 | np.load('./data/humaneva_multi_modal/t_his15_1_thre0.050_t_pred60_thre0.100_index_filterd.npz',
45 | allow_pickle=True)['data_multimodal'].item()
46 | data_candi = \
47 | np.load('./data/humaneva_multi_modal/data_candi_t_his15_t_pred60_skiprate1.npz', allow_pickle=True)[
48 | 'data_candidate.npy']
49 | else:
50 | self.data_multimodal = np.load(self.multimodal_path, allow_pickle=True)['data_multimodal'].item()
51 | data_candi = np.load(self.data_candi_path, allow_pickle=True)['data_candidate.npy']
52 |
53 | self.data_candi = {}
54 |
55 | for key in list(data_f.keys()):
56 | # data_f[key] = dict(filter(lambda x: (self.actions == 'all' or
57 | # all([a in x[0] for a in self.actions]))
58 | # and x[1].shape[0] >= self.t_total, data_f[key].items()))
59 | data_f[key] = dict(filter(lambda x: (self.actions == 'all' or
60 | any([a in x[0] for a in self.actions]))
61 | and x[1].shape[0] >= self.t_total, data_f[key].items()))
62 | if len(data_f[key]) == 0:
63 | data_f.pop(key)
64 | for sub in data_f.keys():
65 | data_s = data_f[sub]
66 | # for data_s in data_f.values():
67 | for action in data_s.keys():
68 | seq = data_s[action][:, self.kept_joints, :]
69 | seq[:, 1:] -= seq[:, :1]
70 | data_s[action] = seq
71 |
72 | if sub not in self.data_candi.keys():
73 | x0 = np.copy(seq[None, :1, ...])
74 | x0[:, :, 0] = 0
75 | self.data_candi[sub] = util.absolute2relative(data_candi, parents=self.skeleton.parents(),
76 | invert=True, x0=x0)
77 | self.data = data_f
78 |
79 | def sample(self, n_modality=5):
80 | while True:
81 | subject = np.random.choice(self.subjects)
82 | dict_s = self.data[subject]
83 | action = np.random.choice(list(dict_s.keys()))
84 | seq = dict_s[action]
85 | if seq.shape[0] > self.t_total:
86 | break
87 | fr_start = np.random.randint(seq.shape[0] - self.t_total)
88 | fr_end = fr_start + self.t_total
89 | traj = seq[fr_start: fr_end]
90 | if n_modality > 0:
91 | # margin_f = 1
92 | # thre_his = 0.05
93 | # thre_pred = 0.1
94 | # x0 = np.copy(traj[None, ...])
95 | # x0[:, :, 0] = 0
96 | # # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
97 | candi_tmp = self.data_candi[subject]
98 | # # observation distance
99 | # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
100 | # candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3), axis=(1, 2))
101 | # idx_his = np.where(dist_his <= thre_his)[0]
102 | #
103 | # # future distance
104 | # dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
105 | # candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
106 | #
107 | # idx_pred = np.where(dist_pred >= thre_pred)[0]
108 | # # idxs = np.intersect1d(idx_his, idx_pred)
109 | idx_multi = self.data_multimodal[subject][action][fr_start]
110 | traj_multi = candi_tmp[idx_multi]
111 |
112 |             # sanity check: retrieved candidates should match the observed history
113 | if len(idx_multi) > 0:
114 | margin_f = 1
115 | thre_his = 0.05
116 | thre_pred = 0.1
117 | x0 = np.copy(traj[None, ...])
118 | x0[:, :, 0] = 0
119 | dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
120 | traj_multi[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
121 | axis=(1, 2))
122 | # if np.any(dist_his > thre_his):
123 |                 # print(f'===> wrong multi-modality sequences {dist_his[dist_his > thre_his].max():.3f}')
124 |
125 | if len(traj_multi) > 0:
126 | traj_multi[:, :self.t_his] = traj[None, ...][:, :self.t_his]
127 | if traj_multi.shape[0] > n_modality:
128 | st0 = np.random.get_state()
129 | idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
130 | traj_multi = traj_multi[idxtmp]
131 | np.random.set_state(st0)
132 | # traj_multi = traj_multi[:n_modality]
133 | traj_multi = np.concatenate(
134 | [traj_multi, np.zeros_like(traj[None, ...][[0] * (n_modality - traj_multi.shape[0])])], axis=0)
135 |
136 | return traj[None, ...], traj_multi, action
137 | else:
138 | return traj[None, ...], None, action
139 |
140 | def sampling_generator(self, num_samples=1000, batch_size=8, n_modality=5):
141 | for i in range(num_samples // batch_size):
142 | sample = []
143 | sample_multi = []
144 | for i in range(batch_size):
145 | sample_i, sample_multi_i, _ = self.sample(n_modality=n_modality)
146 | sample.append(sample_i)
147 | sample_multi.append(sample_multi_i[None, ...])
148 | sample = np.concatenate(sample, axis=0)
149 | sample_multi = np.concatenate(sample_multi, axis=0)
150 | yield sample, sample_multi
151 |
152 | #
153 | # def iter_generator(self, step=25, n_modality=10):
154 | # for sub in self.data.keys():
155 | # data_s = self.data[sub]
156 | # candi_tmp = self.data_candi[sub]
157 | # for act in data_s.keys():
158 | # seq = data_s[act]
159 | # seq_len = seq.shape[0]
160 | # for i in range(0, seq_len - self.t_total, step):
161 | # # idx_multi = self.data_multimodal[sub][act][i]
162 | # # traj_multi = candi_tmp[idx_multi]
163 | # traj = seq[None, i: i + self.t_total]
164 | # if n_modality > 0:
165 | # margin_f = 1
166 | # thre_his = 0.05
167 | # thre_pred = 0.1
168 | # x0 = np.copy(traj)
169 | # x0[:, :, 0] = 0
170 | # # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
171 | # # candi_tmp = self.data_candi[subject]
172 | # # observation distance
173 | # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
174 | # candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
175 | # axis=(1, 2))
176 | # idx_his = np.where(dist_his <= thre_his)[0]
177 | #
178 | # # future distance
179 | # dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
180 | # candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
181 | #
182 | # idx_pred = np.where(dist_pred >= thre_pred)[0]
183 | # # idxs = np.intersect1d(idx_his, idx_pred)
184 | # traj_multi = candi_tmp[idx_his[idx_pred]]
185 | # if len(traj_multi) > 0:
186 | # traj_multi[:, :self.t_his] = traj[:, :self.t_his]
187 | # if traj_multi.shape[0] > n_modality:
188 | # idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
189 | # traj_multi = traj_multi[idxtmp]
190 | # traj_multi = np.concatenate(
191 | # [traj_multi, np.zeros_like(traj[[0] * (n_modality - traj_multi.shape[0])])],
192 | # axis=0)
193 | # else:
194 | # traj_multi = None
195 | #
196 | # yield traj, traj_multi
197 |
198 | def iter_generator(self, step=25):
199 | for data_s in self.data.values():
200 | for seq in data_s.values():
201 | seq_len = seq.shape[0]
202 | for i in range(0, seq_len - self.t_total, step):
203 | traj = seq[None, i: i + self.t_total]
204 | yield traj, None
205 |
206 |
207 | if __name__ == '__main__':
208 | np.random.seed(0)
209 | actions = 'all'
210 | dataset = DatasetHumanEva('test', actions=actions)
211 | generator = dataset.sampling_generator()
212 | dataset.normalize_data()
213 | # generator = dataset.iter_generator()
214 |     for data, data_multi in generator:
215 |         print(data.shape, data_multi.shape)
216 |
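217 |     # Illustrative sketch (not part of the original file): draw one sequence
218 |     # together with its multi-modal futures, which sample() zero-pads to
219 |     # exactly n_modality entries when any candidate is found.
220 |     traj, traj_multi, action = dataset.sample(n_modality=5)
221 |     print(action, traj.shape)         # (1, t_his + t_pred, 15, 3)
222 |     if traj_multi is not None and len(traj_multi) > 0:
223 |         print(traj_multi.shape)       # (5, t_his + t_pred, 15, 3)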
--------------------------------------------------------------------------------
/motion_pred/utils/dataset_h36m_multimodal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from motion_pred.utils.dataset import Dataset
4 | from motion_pred.utils.skeleton import Skeleton
5 | from utils import util
6 |
7 |
8 | class DatasetH36M(Dataset):
9 |
10 | def __init__(self, mode, t_his=25, t_pred=100, actions='all', use_vel=False, **kwargs):
11 | self.use_vel = use_vel
12 | if 'multimodal_path' in kwargs.keys():
13 | self.multimodal_path = kwargs['multimodal_path']
14 | else:
15 | self.multimodal_path = None
16 |
17 | if 'data_candi_path' in kwargs.keys():
18 | self.data_candi_path = kwargs['data_candi_path']
19 | else:
20 | self.data_candi_path = None
21 | super().__init__(mode, t_his, t_pred, actions)
22 | if use_vel:
23 | self.traj_dim += 3
24 |
25 | def prepare_data(self):
26 | self.data_file = os.path.join('data', 'data_3d_h36m.npz')
27 | self.subjects_split = {'train': [1, 5, 6, 7, 8],
28 | 'test': [9, 11]}
29 | self.subjects = ['S%d' % x for x in self.subjects_split[self.mode]]
30 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12,
31 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30],
32 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
33 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
34 | self.removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31}
35 | self.kept_joints = np.array([x for x in range(32) if x not in self.removed_joints])
36 | self.skeleton.remove_joints(self.removed_joints)
37 | self.skeleton._parents[11] = 8
38 | self.skeleton._parents[14] = 8
39 | self.process_data()
40 |
41 | def process_data(self):
42 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item()
43 | self.S1_skeleton = data_o['S1']['Directions'][:1, self.kept_joints].copy()
44 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items()))
45 | if self.actions != 'all':
46 | for key in list(data_f.keys()):
47 | # data_f[key] = dict(
48 | # filter(lambda x: all([a in str.lower(x[0]) for a in self.actions]), data_f[key].items()))
49 | data_f[key] = dict(filter(lambda x: any([a in str.lower(x[0]) for a in self.actions]), data_f[key].items()))
50 | if len(data_f[key]) == 0:
51 | data_f.pop(key)
52 | # possible candidate
53 | # skip_rate = 10
54 | # data_candi = []
55 | if self.multimodal_path is None:
56 | self.data_multimodal = \
57 | np.load('./data/data_multi_modal/t_his25_1_thre0.050_t_pred100_thre0.100_filtered.npz',
58 | allow_pickle=True)[
59 | 'data_multimodal'].item()
60 | data_candi = \
61 | np.load('./data/data_multi_modal/data_candi_t_his25_t_pred100_skiprate20.npz', allow_pickle=True)[
62 | 'data_candidate.npy']
63 | else:
64 | self.data_multimodal = np.load(self.multimodal_path, allow_pickle=True)['data_multimodal'].item()
65 | data_candi = np.load(self.data_candi_path, allow_pickle=True)['data_candidate.npy']
66 |
67 | self.data_candi = {}
68 |
69 | for sub in data_f.keys():
70 | data_s = data_f[sub]
71 | for action in data_s.keys():
72 | seq = data_s[action][:, self.kept_joints, :]
73 | if self.use_vel:
74 | v = (np.diff(seq[:, :1], axis=0) * 50).clip(-5.0, 5.0)
75 | v = np.append(v, v[[-1]], axis=0)
76 | seq[:, 1:] -= seq[:, :1]
77 |
78 | # # get relative candidate
79 | # data_tmp = np.copy(seq)
80 | # data_tmp[:, 0] = 0
81 | # nf = data_tmp.shape[0]
82 | # idxs = np.arange(0, nf - self.t_his - self.t_pred, skip_rate)[:, None] + np.arange(
83 | # self.t_his + self.t_pred)[None, :]
84 | # data_tmp = data_tmp[idxs]
85 | # data_tmp = util.absolute2relative(data_tmp, parents=self.skeleton.parents())
86 | # data_candi.append(data_tmp)
87 |
88 | if self.use_vel:
89 | seq = np.concatenate((seq, v), axis=1)
90 | data_s[action] = seq
91 |
92 | if sub not in self.data_candi.keys():
93 | x0 = np.copy(seq[None, :1, ...])
94 | x0[:, :, 0] = 0
95 | self.data_candi[sub] = util.absolute2relative(data_candi, parents=self.skeleton.parents(),
96 | invert=True, x0=x0)
97 |
98 | self.data = data_f
99 | # self.data_candi = np.concatenate(data_candi, axis=0)
100 |
101 | def sample(self, n_modality=5):
102 | subject = np.random.choice(self.subjects)
103 | dict_s = self.data[subject]
104 | action = np.random.choice(list(dict_s.keys()))
105 | seq = dict_s[action]
106 | fr_start = np.random.randint(seq.shape[0] - self.t_total)
107 | fr_end = fr_start + self.t_total
108 | traj = seq[fr_start: fr_end]
109 | if n_modality > 0 and subject in self.data_multimodal.keys():
110 | # margin_f = 1
111 | # thre_his = 0.05
112 | # thre_pred = 0.1
113 | # x0 = np.copy(traj[None, ...])
114 | # x0[:, :, 0] = 0
115 | # # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
116 | candi_tmp = self.data_candi[subject]
117 | # # observation distance
118 | # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
119 | # candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3), axis=(1, 2))
120 | # idx_his = np.where(dist_his <= thre_his)[0]
121 | #
122 | # # future distance
123 | # dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
124 | # candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
125 | #
126 | # idx_pred = np.where(dist_pred >= thre_pred)[0]
127 | # # idxs = np.intersect1d(idx_his, idx_pred)
128 | idx_multi = self.data_multimodal[subject][action][fr_start]
129 | traj_multi = candi_tmp[idx_multi]
130 |
131 | # # confirm if it is the right one
132 | # if len(idx_multi) > 0:
133 | # margin_f = 1
134 | # thre_his = 0.05
135 | # thre_pred = 0.1
136 | # x0 = np.copy(traj[None, ...])
137 | # x0[:, :, 0] = 0
138 | # dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
139 | # traj_multi[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
140 | # axis=(1, 2))
141 | # if np.any(dist_his > thre_his):
142 |             # print(f'===> wrong multi-modality sequences {dist_his[dist_his > thre_his].max():.3f}')
143 |
144 | if len(traj_multi) > 0:
145 | traj_multi[:, :self.t_his] = traj[None, ...][:, :self.t_his]
146 | if traj_multi.shape[0] > n_modality:
147 | st0 = np.random.get_state()
148 | idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
149 | traj_multi = traj_multi[idxtmp]
150 | np.random.set_state(st0)
151 | # traj_multi = traj_multi[:n_modality]
152 | traj_multi = np.concatenate(
153 | [traj_multi, np.zeros_like(traj[None, ...][[0] * (n_modality - traj_multi.shape[0])])], axis=0)
154 |
155 | return traj[None, ...], traj_multi, action
156 | else:
157 | return traj[None, ...], None, action
158 |
159 | def sampling_generator(self, num_samples=1000, batch_size=8, n_modality=5):
160 | for i in range(num_samples // batch_size):
161 | sample = []
162 | sample_multi = []
163 | for i in range(batch_size):
164 | sample_i, sample_multi_i, _ = self.sample(n_modality=n_modality)
165 | sample.append(sample_i)
166 | sample_multi.append(sample_multi_i[None, ...])
167 | sample = np.concatenate(sample, axis=0)
168 | sample_multi = np.concatenate(sample_multi, axis=0)
169 | yield sample, sample_multi
170 |
171 | def iter_generator(self, step=25, n_modality=10):
172 | for sub in self.data.keys():
173 | data_s = self.data[sub]
174 | candi_tmp = self.data_candi[sub]
175 | for act in data_s.keys():
176 | seq = data_s[act]
177 | seq_len = seq.shape[0]
178 | for i in range(0, seq_len - self.t_total, step):
179 | # idx_multi = self.data_multimodal[sub][act][i]
180 | # traj_multi = candi_tmp[idx_multi]
181 | traj = seq[None, i: i + self.t_total]
182 | if n_modality > 0:
183 | margin_f = 1
184 | thre_his = 0.05
185 | thre_pred = 0.1
186 | x0 = np.copy(traj)
187 | x0[:, :, 0] = 0
188 | # candi_tmp = util.absolute2relative(self.data_candi, parents=self.skeleton.parents(), invert=True, x0=x0)
189 | # candi_tmp = self.data_candi[subject]
190 | # observation distance
191 | dist_his = np.mean(np.linalg.norm(x0[:, self.t_his - margin_f:self.t_his, 1:] -
192 | candi_tmp[:, self.t_his - margin_f:self.t_his, 1:], axis=3),
193 | axis=(1, 2))
194 | idx_his = np.where(dist_his <= thre_his)[0]
195 |
196 | # future distance
197 | dist_pred = np.mean(np.linalg.norm(x0[:, self.t_his:, 1:] -
198 | candi_tmp[idx_his, self.t_his:, 1:], axis=3), axis=(1, 2))
199 |
200 | idx_pred = np.where(dist_pred >= thre_pred)[0]
201 | # idxs = np.intersect1d(idx_his, idx_pred)
202 | traj_multi = candi_tmp[idx_his[idx_pred]]
203 | if len(traj_multi) > 0:
204 | traj_multi[:, :self.t_his] = traj[:, :self.t_his]
205 | if traj_multi.shape[0] > n_modality:
206 | # idxtmp = np.random.choice(np.arange(traj_multi.shape[0]), n_modality, replace=False)
207 | # traj_multi = traj_multi[idxtmp]
208 | traj_multi = traj_multi[:n_modality]
209 | traj_multi = np.concatenate(
210 | [traj_multi, np.zeros_like(traj[[0] * (n_modality - traj_multi.shape[0])])],
211 | axis=0)
212 | else:
213 | traj_multi = None
214 |
215 | yield traj, traj_multi
216 |
217 |
218 | if __name__ == '__main__':
219 | np.random.seed(0)
220 |     actions = {'walkdog'}  # action filters are matched against lower-cased take names
221 | dataset = DatasetH36M('train', actions=actions)
222 | generator = dataset.sampling_generator()
223 | dataset.normalize_data()
224 | # generator = dataset.iter_generator()
225 |     for data, data_multi in generator:
226 |         print(data.shape, data_multi.shape)
227 |
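228 |     # Illustrative sketch (not part of the original file): iterate
229 |     # deterministic windows together with their multi-modal neighbours
230 |     # via iter_generator instead of random sampling.
231 |     for traj, traj_multi in dataset.iter_generator(step=25, n_modality=10):
232 |         print(traj.shape, None if traj_multi is None else traj_multi.shape)
233 |         break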
--------------------------------------------------------------------------------
/train_nf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | import argparse
5 | import time
6 | from torch import optim
7 | from torch.utils.tensorboard import SummaryWriter
8 | import matplotlib.pyplot as plt
9 |
10 | sys.path.append(os.getcwd())
11 | from utils import *
12 | from motion_pred.utils.config import Config
13 | from motion_pred.utils.dataset_h36m import DatasetH36M
14 | from motion_pred.utils.dataset_humaneva import DatasetHumanEva
15 | from models.motion_pred import *
16 | from utils import util
17 |
18 |
19 | def loss_function(prior_lkh, log_det_jacobian):
20 | loss_p = -prior_lkh.mean()
21 |     loss_jac = -log_det_jacobian.mean()
22 | loss_r = loss_p + loss_jac
23 |
24 | return loss_r, np.array([loss_r.item(), loss_p.item(), loss_jac.item()])
25 |
26 |
27 | def train(epoch):
28 | model.train()
29 | t_s = time.time()
30 | train_losses = 0
31 | train_grad = 0
32 | total_num_sample = 0
33 | loss_names = ['LKH', 'log_p(z)', 'log_det']
34 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample, batch_size=cfg.batch_size)
35 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device),
36 | torch.tensor(1, dtype=dtype, device=device))
37 | for traj_np in generator:
38 | with torch.no_grad():
39 | traj_np = traj_np[:, 0]
40 | traj_np[:, 0] = 0
41 | traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents())
42 | traj = tensor(traj_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous()
43 | bs, nj, _ = traj.shape
44 | x = traj.reshape([bs, -1])
45 |
46 | z, log_det_jacobian = model(x)
47 | prior_likelihood = prior.log_prob(z).sum(dim=1)
48 |
49 | loss, losses = loss_function(prior_likelihood, log_det_jacobian)
50 | optimizer.zero_grad()
51 | loss.backward()
52 | # grad_norm = torch.nn.utils.clip_grad_norm_(list(model.parameters()), max_norm=10000)
53 | grad_norm = 0
54 | train_grad += grad_norm
55 | optimizer.step()
56 | train_losses += losses
57 | total_num_sample += 1
58 | del loss, z, prior_likelihood, log_det_jacobian
59 |
60 | scheduler.step()
61 | # dt = time.time() - t_s
62 | train_losses /= total_num_sample
63 | lr = optimizer.param_groups[0]['lr']
64 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)])
65 |
66 |     # logging adds roughly 20s per epoch on average
67 | tb_logger.add_scalar('train_grad', train_grad / total_num_sample, epoch)
68 | for name, loss in zip(loss_names, train_losses):
69 | tb_logger.add_scalars(name, {'train': loss}, epoch)
70 |
71 | logger.info('====> Epoch: {} Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr))
72 |
73 |
74 | def val(epoch):
75 | model.eval()
76 | t_s = time.time()
77 | train_losses = 0
78 | total_num_sample = 0
79 | loss_names = ['LKH', 'log_p(z)', 'log_det']
80 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample // 2, batch_size=cfg.batch_size)
81 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device),
82 | torch.tensor(1, dtype=dtype, device=device))
83 | loginfos = None
84 | with torch.no_grad():
85 | xx = []
86 | for traj_np in generator:
87 | traj_np = traj_np[:, 0]
88 | traj_np[:, 0] = 0
89 | traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents())
90 | traj = tensor(traj_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous()
91 | bs, nj, _ = traj.shape
92 | x = traj.reshape([bs, -1])
93 |
94 | z, log_det_jacobian = model(x)
95 | prior_likelihood = prior.log_prob(z).sum(dim=1)
96 | loginf = {}
97 |             loginf['z'] = z.cpu().data.numpy()
98 |
99 | loss, losses = loss_function(prior_likelihood, log_det_jacobian)
100 |
101 | train_losses += losses
102 | total_num_sample += 1
103 | del loss, z, prior_likelihood, log_det_jacobian
104 | loginfos = combine_dict(loginf, loginfos)
105 |
106 | # dt = time.time() - t_s
107 | train_losses /= total_num_sample
108 | lr = optimizer.param_groups[0]['lr']
109 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)])
110 |
111 |     # logging adds roughly 20s per epoch on average
112 | for name, loss in zip(loss_names, train_losses):
113 | tb_logger.add_scalars(name, {'val': loss}, epoch)
114 | logger.info('====> Epoch: {} Val Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr))
115 |     train_losses, total_num_sample = 0, 0  # reset accumulators so test stats are not mixed into the val averages
116 |     t_s = time.time()
117 | generator = dataset_test.sampling_generator(num_samples=cfg.num_data_sample // 2, batch_size=cfg.batch_size)
118 | loginfos_test = None
119 | with torch.no_grad():
120 | xx = []
121 | for traj_np in generator:
122 | traj_np = traj_np[:, 0]
123 | traj_np[:, 0] = 0
124 | traj_np = util.absolute2relative(traj_np, parents=dataset.skeleton.parents())
125 | traj = tensor(traj_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous()
126 | bs, nj, _ = traj.shape
127 | x = traj.reshape([bs, -1])
128 |
129 | z, log_det_jacobian = model(x)
130 | prior_likelihood = prior.log_prob(z).sum(dim=1)
131 | loginf = {}
132 |             loginf['z'] = z.cpu().data.numpy()
133 |
134 | loss, losses = loss_function(prior_likelihood, log_det_jacobian)
135 |
136 | train_losses += losses
137 | total_num_sample += 1
138 | del loss, z, prior_likelihood, log_det_jacobian
139 | loginfos_test = combine_dict(loginf, loginfos_test)
140 |
141 | # dt = time.time() - t_s
142 | train_losses /= total_num_sample
143 | lr = optimizer.param_groups[0]['lr']
144 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)])
145 |
146 |     # logging adds roughly 20s per epoch on average
147 | for name, loss in zip(loss_names, train_losses):
148 | tb_logger.add_scalars(name, {'test': loss}, epoch)
149 | logger.info('====> Epoch: {} Test Time: {:.2f} {} lr: {:.5f}'.format(epoch, time.time() - t_s, losses_str, lr))
150 |
151 | t_s = time.time()
152 | zz = loginfos['z']
153 | zz = zz.reshape([zz.shape[0], -1])
154 | bs, data_dim = zz.shape
155 | zz_test = loginfos_test['z']
156 |     zz_test = zz_test.reshape([zz_test.shape[0], -1])
157 | # bs, data_dim = zz_test.shape
158 | for ii in range(data_dim):
159 | fig = plt.figure()
160 | ax1 = plt.subplot(121)
161 | _ = plt.hist(zz[:, ii].reshape(-1), bins=100, density=True, alpha=0.5, color='b')
162 | x = torch.from_numpy(np.arange(-5, 5, 0.01)).float().to(device)
163 | y = prior.cdf(x)
164 | x = x[1:]
165 | y = (y[1:] - y[:-1]) * 100
166 | plt.plot(x.cpu().data, y.cpu().data)
167 | ax1.set_title(f'z_val_{ii}')
168 | ax2 = plt.subplot(122)
169 | _ = plt.hist(zz_test[:, ii].reshape(-1), bins=100, density=True, alpha=0.5, color='b')
170 | x = torch.from_numpy(np.arange(-5, 5, 0.01)).float().to(device)
171 | y = prior.cdf(x)
172 | x = x[1:]
173 | y = (y[1:] - y[:-1]) * 100
174 | plt.plot(x.cpu().data, y.cpu().data)
175 | ax2.set_title(f'z_test_{ii}')
176 | tb_logger.add_figure(f'z_{ii}', fig, epoch)
177 | plt.clf()
178 | plt.cla()
179 | plt.close(fig)
180 |
181 | # plot covariance matrix
182 | fig = plt.figure()
183 | ax1 = plt.subplot(121)
184 | # zz = zz.reshape([bs, -1])
185 | zz = zz - zz.mean(axis=0)
186 | cov = 1 / bs * np.abs(np.matmul(zz.transpose([1, 0]), zz))
187 | std = np.sqrt(np.diag(cov))[:, None] * np.sqrt(np.diag(cov))[None, :]
188 | corr = cov / (std + 1e-10) - np.eye(cov.shape[0])
189 | plt.imshow(corr)
190 | plt.colorbar()
191 | ax1.set_title('z_val_corr')
192 |
193 | ax2 = plt.subplot(122)
194 | # zz = zz.reshape([bs, -1])
195 | zz_test = zz_test - zz_test.mean(axis=0)
196 |     cov = 1 / zz_test.shape[0] * np.abs(np.matmul(zz_test.transpose([1, 0]), zz_test))
197 | std = np.sqrt(np.diag(cov))[:, None] * np.sqrt(np.diag(cov))[None, :]
198 | corr = cov / (std + 1e-10) - np.eye(cov.shape[0])
199 | plt.imshow(corr)
200 | plt.colorbar()
201 | ax2.set_title('z_test_corr')
202 | tb_logger.add_figure('z_corr', plt.gcf(), epoch)
203 | plt.close(fig)
204 | print(f'>>>>log time {time.time() - t_s:.3f}')
205 |
206 |
207 | if __name__ == '__main__':
208 |
209 | parser = argparse.ArgumentParser()
210 | parser.add_argument('--cfg', default='h36m_nf')
211 | parser.add_argument('--mode', default='train')
212 | parser.add_argument('--test', action='store_true', default=False)
213 | parser.add_argument('--iter', type=int, default=0)
214 | parser.add_argument('--seed', type=int, default=0)
215 | parser.add_argument('--gpu_index', type=int, default=1)
216 | parser.add_argument('--n_pre', type=int, default=10)
217 | parser.add_argument('--n_his', type=int, default=5)
218 | parser.add_argument('--trial', type=int, default=1)
219 | parser.add_argument('--num_coupling_layer', type=int, default=6)
220 | parser.add_argument('--nz', type=int, default=10)
221 | args = parser.parse_args()
222 |
223 | """setup"""
224 | np.random.seed(args.seed)
225 | torch.manual_seed(args.seed)
226 | dtype = torch.float32
227 | torch.set_default_dtype(dtype)
228 |     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
229 | # if torch.cuda.is_available():
230 | # torch.cuda.set_device(args.gpu_index)
231 | cfg = Config(f'{args.cfg}', test=args.test, nf=True)
232 | tb_logger = SummaryWriter(cfg.tb_dir) if args.mode == 'train' else None
233 | logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
234 |
235 | """parameter"""
236 | mode = args.mode
237 | nz = cfg.nz
238 | t_his = cfg.t_his
239 | t_pred = cfg.t_pred
240 | cfg.n_his = args.n_his
241 | cfg.n_pre = args.n_pre
242 | cfg.num_coupling_layer = args.num_coupling_layer
243 | cfg.nz = args.nz
244 | """data"""
245 | dataset_cls = DatasetH36M if cfg.dataset == 'h36m' else DatasetHumanEva
246 | dataset = dataset_cls('train', t_his, t_pred, actions='all', use_vel=cfg.use_vel)
247 | dataset_test = dataset_cls('test', t_his, t_pred, actions='all', use_vel=cfg.use_vel)
248 | if cfg.normalize_data:
249 | dataset.normalize_data()
250 | dataset_test.normalize_data(dataset.mean, dataset.std)
251 |
252 | """model"""
253 | model = get_model(cfg, dataset, args.cfg+'_nf')
254 | print(model)
255 | model.float()
256 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=1e-3)
257 | scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=cfg.num_epoch_fix, nepoch=cfg.num_epoch)
258 | logger.info(">>> total params: {:.5f}M".format(sum(p.numel() for p in list(model.parameters())) / 1000000.0))
259 |
260 | if args.iter > 0:
261 | cp_path = cfg.model_path % args.iter
262 | print('loading model from checkpoint: %s' % cp_path)
263 | model_cp = pickle.load(open(cp_path, "rb"))
264 | model.load_state_dict(model_cp['model_dict'])
265 |
266 | if mode == 'train':
267 | model.to(device)
268 | overall_iter = 0
269 | for i in range(args.iter, cfg.num_epoch):
270 | model.train()
271 | train(i)
272 | model.eval()
273 | val(i)
274 | # test(i)
275 | if cfg.save_model_interval > 0 and (i + 1) % cfg.save_model_interval == 0:
276 | with to_cpu(model):
277 | cp_path = cfg.model_path % (i + 1)
278 | model_cp = {'model_dict': model.state_dict(), 'meta': {'std': dataset.std, 'mean': dataset.mean}}
279 | pickle.dump(model_cp, open(cp_path, 'wb'))
280 |
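281 | # Illustrative note (not part of the original file): the flow is trained by
282 | # maximum likelihood under the change-of-variables formula
283 | #     log p(x) = log p_z(f(x)) + log|det df/dx|,
284 | # which loss_function above negates and averages. An equivalent one-liner:
285 | def flow_nll(prior, z, log_det_jacobian):
286 |     return -(prior.log_prob(z).sum(dim=1) + log_det_jacobian).mean()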
--------------------------------------------------------------------------------
/utils/valid_angle_check.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 |
5 |
6 | def h36m_valid_angle_check(p3d):
7 | """
8 | p3d: [bs,16,3] or [bs,48]
9 | """
10 | if p3d.shape[-1] == 48:
11 | p3d = p3d.reshape([p3d.shape[0], 16, 3])
12 |
13 | cos_func = lambda p1, p2: np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
14 | data_all = p3d
15 | valid_cos = {}
16 | # Spine2LHip
17 | p1 = data_all[:, 3]
18 | p2 = data_all[:, 6]
19 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
20 | # Spine2RHip
21 | p1 = data_all[:, 0]
22 | p2 = data_all[:, 6]
23 | cos_gt_r = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
24 | valid_cos['Spine2Hip'] = np.vstack((cos_gt_l, cos_gt_r))
25 |
26 | # LLeg2LeftHipPlane
27 | p0 = data_all[:, 3]
28 | p1 = data_all[:, 4] - data_all[:, 3]
29 | p2 = data_all[:, 5] - data_all[:, 4]
30 | n0 = np.cross(p0, p1)
31 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1)
32 | # RLeg2RHipPlane
33 | p0 = data_all[:, 0]
34 | p1 = data_all[:, 1] - data_all[:, 0]
35 | p2 = data_all[:, 2] - data_all[:, 1]
36 | n0 = np.cross(p1, p0)
37 | cos_gt_r = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1)
38 | valid_cos['Leg2HipPlane'] = np.vstack((cos_gt_l, cos_gt_r))
39 |
40 | # Shoulder2Hip
41 | p1 = data_all[:, 10] - data_all[:, 7]
42 | p2 = data_all[:, 3]
43 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
44 | p1 = data_all[:, 13] - data_all[:, 7]
45 | p2 = data_all[:, 0]
46 | cos_gt_r = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
47 | valid_cos['Shoulder2Hip'] = np.vstack((cos_gt_l, cos_gt_r))
48 |
49 | # Leg2ShoulderPlane
50 | p0 = data_all[:, 13]
51 | p1 = data_all[:, 10]
52 | p2 = data_all[:, 4]
53 | p3 = data_all[:, 1]
54 | n0 = np.cross(p0, p1)
55 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1)
56 | cos_gt_r = np.sum(n0 * p3, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p3, axis=1)
57 | valid_cos['Leg2ShoulderPlane'] = np.vstack((cos_gt_l, cos_gt_r))
58 |
59 | # Shoulder2Shoulder
60 | p0 = data_all[:, 13] - data_all[:, 7]
61 | p1 = data_all[:, 10] - data_all[:, 7]
62 | cos_gt = np.sum(p0 * p1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(p1, axis=1)
63 | valid_cos['Shoulder2Shoulder'] = cos_gt
64 |
65 | # Neck2Spine
66 | p0 = data_all[:, 7] - data_all[:, 6]
67 | p1 = data_all[:, 6]
68 | cos_gt = np.sum(p0 * p1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(p1, axis=1)
69 | valid_cos['Neck2Spine'] = cos_gt
70 |
71 | # Spine2HipPlane1
72 | p0 = data_all[:, 3]
73 | p1 = data_all[:, 4] - data_all[:, 3]
74 | n0 = np.cross(p1, p0)
75 | p2 = data_all[:, 6]
76 | n1 = np.cross(p2, n0)
77 | cos_dir_l = np.sum(p0 * n1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(n1, axis=1)
78 | cos_gt_l = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1)
79 | p0 = data_all[:, 0]
80 | p1 = data_all[:, 1] - data_all[:, 0]
81 | n0 = np.cross(p0, p1)
82 | p2 = data_all[:, 6]
83 | n1 = np.cross(n0, p2)
84 | cos_dir_r = np.sum(p0 * n1, axis=1) / np.linalg.norm(p0, axis=1) / np.linalg.norm(n1, axis=1)
85 | cos_gt_r = np.sum(n0 * p2, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(p2, axis=1)
86 | cos_gt_l1 = np.ones_like(cos_gt_l) * 0.5
87 | cos_gt_r1 = np.ones_like(cos_gt_r) * 0.5
88 | cos_gt_l1[cos_dir_l < 0] = cos_gt_l[cos_dir_l < 0]
89 | cos_gt_r1[cos_dir_r < 0] = cos_gt_r[cos_dir_r < 0]
90 | valid_cos['Spine2HipPlane1'] = np.vstack((cos_gt_l1, cos_gt_r1))
91 |
92 | # Spine2HipPlane2
93 | cos_gt_l2 = np.ones_like(cos_gt_l) * 0.5
94 | cos_gt_r2 = np.ones_like(cos_gt_r) * 0.5
95 | cos_gt_l2[cos_dir_l >= 0] = cos_gt_l[cos_dir_l >= 0]
96 | cos_gt_r2[cos_dir_r >= 0] = cos_gt_r[cos_dir_r >= 0]
97 | valid_cos['Spine2HipPlane2'] = np.vstack((cos_gt_l2, cos_gt_r2))
98 |
99 | # ShoulderPlane2HipPlane (25 Jan)
100 | p1 = data_all[:, 7] - data_all[:, 3]
101 | p2 = data_all[:, 7] - data_all[:, 0]
102 | p3 = data_all[:, 10]
103 | p4 = data_all[:, 13]
104 | n0 = np.cross(p2, p1)
105 | n1 = np.cross(p3, p4)
106 | cos_gt_l = np.sum(n0 * n1, axis=1) / np.linalg.norm(n0, axis=1) / np.linalg.norm(n1, axis=1)
107 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l
108 |
109 | # Head2Neck
110 | p1 = data_all[:, 7] - data_all[:, 6]
111 | p2 = data_all[:, 8] - data_all[:, 7]
112 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
113 | valid_cos['Head2Neck'] = cos_gt_l
114 |
115 | # Head2HeadTop
116 | p1 = data_all[:, 9] - data_all[:, 8]
117 | p2 = data_all[:, 8] - data_all[:, 7]
118 | cos_gt_l = np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
119 | valid_cos['Head2HeadTop'] = cos_gt_l
120 |
121 | # HeadVerticalPlane2HipPlane
122 | p1 = data_all[:, 9] - data_all[:, 8]
123 | p2 = data_all[:, 8] - data_all[:, 7]
124 | n0 = np.cross(p1, p2)
125 | p3 = data_all[:, 9] - data_all[:, 7]
126 | n1 = np.cross(n0, p3)
127 | p4 = data_all[:, 7] - data_all[:, 0]
128 | p5 = data_all[:, 7] - data_all[:, 3]
129 | n2 = np.cross(p4, p5)
130 | cos_gt_l = cos_func(n1, n2)
131 | valid_cos['HeadVerticalPlane2HipPlane'] = cos_gt_l
132 |
133 | # Shoulder2Neck
134 | p1 = data_all[:, 10] - data_all[:, 7]
135 | p2 = data_all[:, 6] - data_all[:, 7]
136 | cos_gt_l = cos_func(p1, p2)
137 | p1 = data_all[:, 13] - data_all[:, 7]
138 | p2 = data_all[:, 6] - data_all[:, 7]
139 | cos_gt_r = cos_func(p1, p2)
140 | valid_cos['Shoulder2Neck'] = np.vstack((cos_gt_l, cos_gt_r))
141 |
142 | return valid_cos
143 |
144 |
145 | def h36m_valid_angle_check_torch(p3d):
146 | """
147 | p3d: [bs,16,3] or [bs,48]
148 | """
149 | if p3d.shape[-1] == 48:
150 | p3d = p3d.reshape([p3d.shape[0], 16, 3])
151 | data_all = p3d
152 | cos_func = lambda p1, p2: torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
153 |
154 | valid_cos = {}
155 | # Spine2LHip
156 | p1 = data_all[:, 3]
157 | p2 = data_all[:, 6]
158 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
159 | # Spine2RHip
160 | p1 = data_all[:, 0]
161 | p2 = data_all[:, 6]
162 | cos_gt_r = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
163 | valid_cos['Spine2Hip'] = torch.vstack((cos_gt_l, cos_gt_r))
164 |
165 | # LLeg2LeftHipPlane
166 | p0 = data_all[:, 3]
167 | p1 = data_all[:, 4] - data_all[:, 3]
168 | p2 = data_all[:, 5] - data_all[:, 4]
169 | n0 = torch.cross(p0, p1, dim=1)
170 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1)
171 | # RLeg2RHipPlane
172 | p0 = data_all[:, 0]
173 | p1 = data_all[:, 1] - data_all[:, 0]
174 | p2 = data_all[:, 2] - data_all[:, 1]
175 |     n0 = torch.cross(p1, p0, dim=1)
176 | cos_gt_r = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1)
177 | valid_cos['Leg2HipPlane'] = torch.vstack((cos_gt_l, cos_gt_r))
178 |
179 | # Shoulder2Hip
180 | p1 = data_all[:, 10] - data_all[:, 7]
181 | p2 = data_all[:, 3]
182 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
183 | p1 = data_all[:, 13] - data_all[:, 7]
184 | p2 = data_all[:, 0]
185 | cos_gt_r = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
186 | valid_cos['Shoulder2Hip'] = torch.vstack((cos_gt_l, cos_gt_r))
187 |
188 | # Leg2ShoulderPlane
189 | p0 = data_all[:, 13]
190 | p1 = data_all[:, 10]
191 | p2 = data_all[:, 4]
192 | p3 = data_all[:, 1]
193 |     n0 = torch.cross(p0, p1, dim=1)
194 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1)
195 | cos_gt_r = torch.sum(n0 * p3, dim=1) / torch.norm(n0, dim=1) / torch.norm(p3, dim=1)
196 | valid_cos['Leg2ShoulderPlane'] = torch.vstack((cos_gt_l, cos_gt_r))
197 |
198 | # Shoulder2Shoulder
199 | p0 = data_all[:, 13] - data_all[:, 7]
200 | p1 = data_all[:, 10] - data_all[:, 7]
201 | cos_gt = torch.sum(p0 * p1, dim=1) / torch.norm(p0, dim=1) / torch.norm(p1, dim=1)
202 | valid_cos['Shoulder2Shoulder'] = cos_gt
203 |
204 | # Neck2Spine
205 | p0 = data_all[:, 7] - data_all[:, 6]
206 | p1 = data_all[:, 6]
207 | cos_gt = torch.sum(p0 * p1, dim=1) / torch.norm(p0, dim=1) / torch.norm(p1, dim=1)
208 | valid_cos['Neck2Spine'] = cos_gt
209 |
210 | # Spine2HipPlane1
211 | p0 = data_all[:, 3]
212 | p1 = data_all[:, 4] - data_all[:, 3]
213 |     n0 = torch.cross(p1, p0, dim=1)
214 |     p2 = data_all[:, 6]
215 |     n1 = torch.cross(p2, n0, dim=1)
216 | cos_dir_l = torch.sum(p0 * n1, dim=1) / torch.norm(p0, dim=1) / torch.norm(n1, dim=1)
217 | cos_gt_l = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1)
218 | p0 = data_all[:, 0]
219 | p1 = data_all[:, 1] - data_all[:, 0]
220 |     n0 = torch.cross(p0, p1, dim=1)
221 |     p2 = data_all[:, 6]
222 |     n1 = torch.cross(n0, p2, dim=1)
223 | cos_dir_r = torch.sum(p0 * n1, dim=1) / torch.norm(p0, dim=1) / torch.norm(n1, dim=1)
224 | cos_gt_r = torch.sum(n0 * p2, dim=1) / torch.norm(n0, dim=1) / torch.norm(p2, dim=1)
225 | cos_gt_l1 = cos_gt_l[cos_dir_l < 0]
226 | cos_gt_r1 = cos_gt_r[cos_dir_r < 0]
227 | valid_cos['Spine2HipPlane1'] = torch.hstack((cos_gt_l1, cos_gt_r1))
228 |
229 | # Spine2HipPlane2
230 | cos_gt_l2 = cos_gt_l[cos_dir_l >= 0]
231 | cos_gt_r2 = cos_gt_r[cos_dir_r >= 0]
232 | valid_cos['Spine2HipPlane2'] = torch.hstack((cos_gt_l2, cos_gt_r2))
233 |
234 | # ShoulderPlane2HipPlane (25 Jan)
235 | p1 = data_all[:, 7] - data_all[:, 3]
236 | p2 = data_all[:, 7] - data_all[:, 0]
237 | p3 = data_all[:, 10]
238 | p4 = data_all[:, 13]
239 |     n0 = torch.cross(p2, p1, dim=1)
240 |     n1 = torch.cross(p3, p4, dim=1)
241 | cos_gt_l = torch.sum(n0 * n1, dim=1) / torch.norm(n0, dim=1) / torch.norm(n1, dim=1)
242 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l
243 |
244 | # Head2Neck
245 | p1 = data_all[:, 7] - data_all[:, 6]
246 | p2 = data_all[:, 8] - data_all[:, 7]
247 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
248 | valid_cos['Head2Neck'] = cos_gt_l
249 |
250 | # Head2HeadTop
251 | p1 = data_all[:, 9] - data_all[:, 8]
252 | p2 = data_all[:, 8] - data_all[:, 7]
253 | cos_gt_l = torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
254 | valid_cos['Head2HeadTop'] = cos_gt_l
255 |
256 | # HeadVerticalPlane2HipPlane
257 | p1 = data_all[:, 9] - data_all[:, 8]
258 | p2 = data_all[:, 8] - data_all[:, 7]
259 |     n0 = torch.cross(p1, p2, dim=1)
260 |     p3 = data_all[:, 9] - data_all[:, 7]
261 |     n1 = torch.cross(n0, p3, dim=1)
262 |     p4 = data_all[:, 7] - data_all[:, 0]
263 |     p5 = data_all[:, 7] - data_all[:, 3]
264 |     n2 = torch.cross(p4, p5, dim=1)
265 | cos_gt_l = cos_func(n1, n2)
266 | valid_cos['HeadVerticalPlane2HipPlane'] = cos_gt_l
267 |
268 | # Shoulder2Neck
269 | p1 = data_all[:, 10] - data_all[:, 7]
270 | p2 = data_all[:, 6] - data_all[:, 7]
271 | cos_gt_l = cos_func(p1, p2)
272 | p1 = data_all[:, 13] - data_all[:, 7]
273 | p2 = data_all[:, 6] - data_all[:, 7]
274 | cos_gt_r = cos_func(p1, p2)
275 | valid_cos['Shoulder2Neck'] = torch.vstack((cos_gt_l, cos_gt_r))
276 |
277 | return valid_cos
278 |
279 |
280 | def humaneva_valid_angle_check(p3d):
281 | """
282 | p3d: [bs,14,3] or [bs,42]
283 | """
284 | if p3d.shape[-1] == 42:
285 | p3d = p3d.reshape([p3d.shape[0], 14, 3])
286 |
287 | cos_func = lambda p1, p2: np.sum(p1 * p2, axis=1) / np.linalg.norm(p1, axis=1) / np.linalg.norm(p2, axis=1)
288 | data_all = p3d
289 | valid_cos = {}
290 |
291 | # LHip2RHip
292 | p1 = data_all[:, 7]
293 | p2 = data_all[:, 10]
294 | cos_gt_l = cos_func(p1, p2)
295 | valid_cos['LHip2RHip'] = cos_gt_l
296 |
297 | # Neck2HipPlane
298 | p1 = data_all[:, 7]
299 | p2 = data_all[:, 10]
300 | n0 = np.cross(p1, p2)
301 | p3 = data_all[:, 0]
302 | cos_gt_l = cos_func(n0, p3)
303 | valid_cos['Neck2HipPlane'] = cos_gt_l
304 |
305 | # Head2Neck
306 | p1 = data_all[:, 13] - data_all[:, 0]
307 | p2 = data_all[:, 0]
308 | cos_gt_l = cos_func(p1, p2)
309 | valid_cos['Head2Neck'] = cos_gt_l
310 |
311 | # Shoulder2Shoulder
312 | p1 = data_all[:, 1] - data_all[:, 0]
313 | p2 = data_all[:, 4] - data_all[:, 0]
314 | cos_gt_l = cos_func(p1, p2)
315 | valid_cos['Shoulder2Shoulder'] = cos_gt_l
316 |
317 | # ShoulderPlane2HipPlane
318 | p1 = data_all[:, 7] - data_all[:, 0]
319 | p2 = data_all[:, 10] - data_all[:, 0]
320 | n0 = np.cross(p1, p2)
321 | p3 = data_all[:, 1]
322 | p4 = data_all[:, 4]
323 | n1 = np.cross(p3, p4)
324 | cos_gt_l = cos_func(n0, n1)
325 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l
326 |
327 | # Shoulder2Neck
328 | p1 = data_all[:, 1] - data_all[:, 0]
329 | p2 = data_all[:, 0]
330 | cos_gt_l = cos_func(p1, p2)
331 | p1 = data_all[:, 4] - data_all[:, 0]
332 | p2 = data_all[:, 0]
333 | cos_gt_r = cos_func(p1, p2)
334 | valid_cos['Shoulder2Neck'] = np.vstack((cos_gt_l, cos_gt_r))
335 |
336 | # Leg2HipPlane
337 | p1 = data_all[:, 7]
338 | p2 = data_all[:, 10]
339 | n0 = np.cross(p1, p2)
340 | p3 = data_all[:, 8] - data_all[:, 7]
341 | cos_gt_l = cos_func(n0, p3)
342 | p3 = data_all[:, 11] - data_all[:, 10]
343 | cos_gt_r = cos_func(n0, p3)
344 | valid_cos['Leg2HipPlane'] = np.vstack((cos_gt_l, cos_gt_r))
345 |
346 | # Foot2LegPlane
347 | p1 = data_all[:, 7] - data_all[:, 10]
348 | p2 = data_all[:, 11] - data_all[:, 10]
349 | n0 = np.cross(p1, p2)
350 | p3 = data_all[:, 12] - data_all[:, 11]
351 | cos_gt_l = cos_func(n0, p3)
352 | p1 = data_all[:, 7] - data_all[:, 10]
353 | p2 = data_all[:, 8] - data_all[:, 7]
354 | n0 = np.cross(p1, p2)
355 | p3 = data_all[:, 9] - data_all[:, 8]
356 | cos_gt_r = cos_func(n0, p3)
357 | valid_cos['Foot2LegPlane'] = np.vstack((cos_gt_l, cos_gt_r))
358 |
359 | # ForeArm2ShoulderPlane
360 | p1 = data_all[:, 4] - data_all[:, 0]
361 | p2 = data_all[:, 5] - data_all[:, 4]
362 | n0 = np.cross(p1, p2)
363 | p3 = data_all[:, 6] - data_all[:, 5]
364 | cos_gt_l = cos_func(n0, p3)
365 | p1 = data_all[:, 1] - data_all[:, 0]
366 | p2 = data_all[:, 2] - data_all[:, 1]
367 | n0 = np.cross(p2, p1)
368 | p3 = data_all[:, 3] - data_all[:, 2]
369 | cos_gt_r = cos_func(n0, p3)
370 | valid_cos['ForeArm2ShoulderPlane'] = np.vstack((cos_gt_l, cos_gt_r))
371 |
372 | return valid_cos
373 |
374 |
375 | def humaneva_valid_angle_check_torch(p3d):
376 | """
377 | p3d: [bs,14,3] or [bs,42]
378 | """
379 | if p3d.shape[-1] == 42:
380 | p3d = p3d.reshape([p3d.shape[0], 14, 3])
381 |
382 | cos_func = lambda p1, p2: torch.sum(p1 * p2, dim=1) / torch.norm(p1, dim=1) / torch.norm(p2, dim=1)
383 | data_all = p3d
384 | valid_cos = {}
385 |
386 | # LHip2RHip
387 | p1 = data_all[:, 7]
388 | p2 = data_all[:, 10]
389 | cos_gt_l = cos_func(p1, p2)
390 | valid_cos['LHip2RHip'] = cos_gt_l
391 |
392 | # Neck2HipPlane
393 | p1 = data_all[:, 7]
394 | p2 = data_all[:, 10]
395 |     n0 = torch.cross(p1, p2, dim=1)
396 | p3 = data_all[:, 0]
397 | cos_gt_l = cos_func(n0, p3)
398 | valid_cos['Neck2HipPlane'] = cos_gt_l
399 |
400 | # Head2Neck
401 | p1 = data_all[:, 13] - data_all[:, 0]
402 | p2 = data_all[:, 0]
403 | cos_gt_l = cos_func(p1, p2)
404 | valid_cos['Head2Neck'] = cos_gt_l
405 |
406 | # Shoulder2Shoulder
407 | p1 = data_all[:, 1] - data_all[:, 0]
408 | p2 = data_all[:, 4] - data_all[:, 0]
409 | cos_gt_l = cos_func(p1, p2)
410 | valid_cos['Shoulder2Shoulder'] = cos_gt_l
411 |
412 | # ShoulderPlane2HipPlane
413 | p1 = data_all[:, 7] - data_all[:, 0]
414 | p2 = data_all[:, 10] - data_all[:, 0]
415 |     n0 = torch.cross(p1, p2, dim=1)
416 |     p3 = data_all[:, 1]
417 |     p4 = data_all[:, 4]
418 |     n1 = torch.cross(p3, p4, dim=1)
419 | cos_gt_l = cos_func(n0, n1)
420 | valid_cos['ShoulderPlane2HipPlane'] = cos_gt_l
421 |
422 | # Shoulder2Neck
423 | p1 = data_all[:, 1] - data_all[:, 0]
424 | p2 = data_all[:, 0]
425 | cos_gt_l = cos_func(p1, p2)
426 | p1 = data_all[:, 4] - data_all[:, 0]
427 | p2 = data_all[:, 0]
428 | cos_gt_r = cos_func(p1, p2)
429 | valid_cos['Shoulder2Neck'] = torch.vstack((cos_gt_l, cos_gt_r))
430 |
431 | # Leg2HipPlane
432 | p1 = data_all[:, 7]
433 | p2 = data_all[:, 10]
434 |     n0 = torch.cross(p1, p2, dim=1)
435 | p3 = data_all[:, 8] - data_all[:, 7]
436 | cos_gt_l = cos_func(n0, p3)
437 | p3 = data_all[:, 11] - data_all[:, 10]
438 | cos_gt_r = cos_func(n0, p3)
439 | valid_cos['Leg2HipPlane'] = torch.vstack((cos_gt_l, cos_gt_r))
440 |
441 | # Foot2LegPlane
442 | p1 = data_all[:, 7] - data_all[:, 10]
443 | p2 = data_all[:, 11] - data_all[:, 10]
444 |     n0 = torch.cross(p1, p2, dim=1)
445 |     p3 = data_all[:, 12] - data_all[:, 11]
446 |     cos_gt_l = cos_func(n0, p3)
447 |     p1 = data_all[:, 7] - data_all[:, 10]
448 |     p2 = data_all[:, 8] - data_all[:, 7]
449 |     n0 = torch.cross(p1, p2, dim=1)
450 | p3 = data_all[:, 9] - data_all[:, 8]
451 | cos_gt_r = cos_func(n0, p3)
452 | valid_cos['Foot2LegPlane'] = torch.vstack((cos_gt_l, cos_gt_r))
453 |
454 | # ForeArm2ShoulderPlane
455 | p1 = data_all[:, 4] - data_all[:, 0]
456 | p2 = data_all[:, 5] - data_all[:, 4]
457 |     n0 = torch.cross(p1, p2, dim=1)
458 |     p3 = data_all[:, 6] - data_all[:, 5]
459 |     cos_gt_l = cos_func(n0, p3)
460 |     p1 = data_all[:, 1] - data_all[:, 0]
461 |     p2 = data_all[:, 2] - data_all[:, 1]
462 |     n0 = torch.cross(p2, p1, dim=1)
463 | p3 = data_all[:, 3] - data_all[:, 2]
464 | cos_gt_r = cos_func(n0, p3)
465 | valid_cos['ForeArm2ShoulderPlane'] = torch.vstack((cos_gt_l, cos_gt_r))
466 |
467 | return valid_cos
468 |
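469 | # Illustrative usage sketch (not part of the original file): the angle bounds
470 | # used as `valid_ang` in main.py are assumed to be per-angle cosine ranges
471 | # estimated from training poses, along these lines:
472 | def estimate_angle_bounds(p3d_train):
473 |     """p3d_train: [bs, 16, 3] H36M training poses -> {name: (min, max)}."""
474 |     cos = h36m_valid_angle_check(p3d_train)
475 |     return {name: (float(np.min(c)), float(np.max(c))) for name, c in cos.items()}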
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | import argparse
5 | import time
6 | from torch import maximum, optim
7 | from torch.utils.tensorboard import SummaryWriter
8 | import itertools
9 | sys.path.append(os.getcwd())
10 | from utils import *
11 | from motion_pred.utils.config import Config
12 | from motion_pred.utils.dataset_h36m_multimodal import DatasetH36M
13 | from motion_pred.utils.dataset_humaneva_multimodal import DatasetHumanEva
14 | from motion_pred.utils.visualization import render_animation
15 | from models.motion_pred import *
16 | from utils import util, valid_angle_check
17 | from utils.metrics import *
18 | from scipy.spatial.distance import pdist, squareform
19 | from tqdm import tqdm
20 | import random
21 | import time
22 |
23 | def recon_loss(Y_g, Y, Y_mm, Y_hg=None, Y_h=None):
24 |
25 | stat = torch.zeros(Y_g.shape[2])
26 | diff = Y_g - Y.unsqueeze(2) # TBMV
27 | dist = diff.pow(2).sum(dim=-1).sum(dim=0) # BM
28 |
29 | value, indices = dist.min(dim=1)
30 |
31 | loss_recon_1 = value.mean()
32 |
33 | diff = Y_hg - Y_h.unsqueeze(2) # TBMC
34 | loss_recon_2 = diff.pow(2).sum(dim=-1).sum(dim=0).mean()
35 |
36 |
37 |     with torch.no_grad():
38 |         ade = torch.norm(diff, dim=-1).mean(dim=0).min(dim=1)[0].mean()  # note: `diff` is the history-term difference computed just above
39 |
40 | diff = Y_g[:, :, :, None, :] - Y_mm[:, :, None, :, :]
41 |
42 | mask = Y_mm.abs().sum(-1).sum(0) > 1e-6
43 |
44 | dist = diff.pow(2)
45 |     with torch.no_grad():
46 |         zeros = torch.zeros([dist.shape[1], dist.shape[2]], requires_grad=False).to(dist.device)  # [b, m]
47 |         zeros.scatter_(dim=1, index=indices.unsqueeze(1).repeat(1, dist.shape[2]), src=zeros + dist.max() - dist.min() + 1)  # mark the sample already matched to the GT future
48 |         zeros = zeros.unsqueeze(0).unsqueeze(3).unsqueeze(4)
49 |     dist += zeros  # inflate its distance so a different sample covers each multi-modal target
50 | dist = dist.sum(dim=-1).sum(dim=0)
51 |
52 | value_2, indices_2 = dist.min(dim=1)
53 | loss_recon_multi = value_2[mask].mean()
54 | if torch.isnan(loss_recon_multi):
55 | loss_recon_multi = torch.zeros_like(loss_recon_1)
56 |
57 | mask = torch.tril(torch.ones([cfg.nk, cfg.nk], device=device)) == 0
58 | # TBMC
59 |
60 | yt = Y_g.reshape([-1, cfg.nk, Y_g.shape[3]]).contiguous()
61 | pdist = torch.cdist(yt, yt, p=1)[:, mask]
62 | return loss_recon_1, loss_recon_2, loss_recon_multi, ade, stat, (-pdist / 100).exp().mean()
63 |
64 |
65 | def angle_loss(y):
66 | ang_names = list(valid_ang.keys())
67 | y = y.reshape([-1, y.shape[-1]])
68 | ang_cos = valid_angle_check.h36m_valid_angle_check_torch(
69 | y) if cfg.dataset == 'h36m' else valid_angle_check.humaneva_valid_angle_check_torch(y)
70 | loss = tensor(0, dtype=dtype, device=device)
71 | b = 1
72 | for an in ang_names:
73 | lower_bound = valid_ang[an][0]
74 | if lower_bound >= -0.98:
75 | # loss += torch.exp(-b * (ang_cos[an] - lower_bound)).mean()
76 | if torch.any(ang_cos[an] < lower_bound):
77 | # loss += b * torch.exp(-(ang_cos[an][ang_cos[an] < lower_bound] - lower_bound)).mean()
78 | loss += (ang_cos[an][ang_cos[an] < lower_bound] - lower_bound).pow(2).mean()
79 | upper_bound = valid_ang[an][1]
80 | if upper_bound <= 0.98:
81 | # loss += torch.exp(b * (ang_cos[an] - upper_bound)).mean()
82 | if torch.any(ang_cos[an] > upper_bound):
83 | # loss += b * torch.exp(ang_cos[an][ang_cos[an] > upper_bound] - upper_bound).mean()
84 | loss += (ang_cos[an][ang_cos[an] > upper_bound] - upper_bound).pow(2).mean()
85 | return loss
86 |
87 |
88 | def loss_function(traj_est, traj, traj_multimodal, prior_lkh, prior_logdetjac, _lambda):
89 | lambdas = cfg.lambdas
90 | nj = dataset.traj_dim // 3
91 |
92 | Y_g = traj_est.view(traj_est.shape[0], traj.shape[1], traj_est.shape[1]//traj.shape[1], -1)[t_his:] # T B M V
93 | Y = traj[t_his:]
94 | Y_multimodal = traj_multimodal[t_his:]
95 |     Y_hg = traj_est.view(traj_est.shape[0], traj.shape[1], traj_est.shape[1] // traj.shape[1], -1)[:t_his]
96 |     Y_h = traj[:t_his]
97 |     RECON, RECON_2, RECON_mm, ade, stat, JL = recon_loss(Y_g, Y, Y_multimodal, Y_hg, Y_h)
98 | # maintain limb length
99 | parent = dataset.skeleton.parents()
100 | tmp = traj[0].reshape([cfg.batch_size, nj, 3])
101 | pgt = torch.zeros([cfg.batch_size, nj + 1, 3], dtype=dtype, device=device)
102 | pgt[:, 1:] = tmp
103 | limbgt = torch.norm(pgt[:, 1:] - pgt[:, parent[1:]], dim=2)[None, :, None, :]
104 | tmp = traj_est.reshape([-1, cfg.batch_size, cfg.nk, nj, 3])
105 | pest = torch.zeros([tmp.shape[0], cfg.batch_size, cfg.nk, nj + 1, 3], dtype=dtype, device=device)
106 | pest[:, :, :, 1:] = tmp
107 | limbest = torch.norm(pest[:, :, :, 1:] - pest[:, :, :, parent[1:]], dim=4)
108 | loss_limb = torch.mean((limbgt - limbest).pow(2).sum(dim=3))
109 |
110 | # angle loss
111 | loss_ang = angle_loss(Y_g)
112 | if _lambda < 0.1:
113 | _lambda *= 10
114 | else:
115 | _lambda = 1
116 |
117 |     loss_r = loss_limb * lambdas[1] + JL * lambdas[3] * _lambda + RECON * lambdas[4] + RECON_mm * lambdas[5] \
118 |              - prior_lkh.mean() * lambdas[6] + RECON_2 * lambdas[7]  # - prior_logdetjac.mean() * lambdas[7]
119 | if loss_ang > 0:
120 | loss_r += loss_ang * lambdas[8]
121 | return loss_r, np.array([loss_r.item(), loss_limb.item(), loss_ang.item(),
122 | JL.item(), RECON.item(), RECON_2.item(), RECON_mm.item(), ade.item(),
123 | prior_lkh.mean().item(), prior_logdetjac.mean().item()]), stat#, indices_key, indices_2_key
124 |
125 |
126 | def dct_transform_torch(data, dct_m, dct_n):
127 | '''
128 |     data: [batch_size, features, seq_len] -> DCT coefficients [batch_size, features, dct_n] along time
129 | '''
130 | batch_size, features, seq_len = data.shape
131 |
132 |     data = data.contiguous().view(-1, seq_len)  # [b*features, seq_len]
133 |     data = data.permute(1, 0)  # [seq_len, b*features]
134 |
135 |     out_data = torch.matmul(dct_m[:dct_n, :], data)  # [dct_n, b*features]
136 |     out_data = out_data.permute(1, 0).contiguous().view(-1, features, dct_n)  # [b, features, dct_n]
137 | return out_data
138 |
139 | def get_dct_matrix(N):
140 | dct_m = np.eye(N)
141 | for k in np.arange(N):
142 | for i in np.arange(N):
143 | w = np.sqrt(2 / N)
144 | if k == 0:
145 | w = np.sqrt(1 / N)
146 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
147 | idct_m = np.linalg.inv(dct_m)
148 | return dct_m, idct_m
149 |
150 | def train(epoch, stats):
151 |
152 | dct_m, i_dct_m = get_dct_matrix(cfg.t_his+cfg.t_pred)
153 | dct_m = torch.from_numpy(dct_m).float().to(device)
154 | i_dct_m = torch.from_numpy(i_dct_m).float().to(device)
155 |
156 | model.train()
157 | t_s = time.time()
158 | train_losses = 0
159 | train_grad = 0
160 | total_num_sample = 0
161 | n_modality = 10
162 | loss_names = ['LOSS', 'loss_limb', 'loss_ang', 'loss_DIV',
163 | 'RECON', 'RECON_2', 'RECON_multi', "ADE", 'p(z)', 'logdet']
164 | generator = dataset.sampling_generator(num_samples=cfg.num_data_sample, batch_size=cfg.batch_size,
165 | n_modality=n_modality)
166 | prior = torch.distributions.Normal(torch.tensor(0, dtype=dtype, device=device),
167 | torch.tensor(1, dtype=dtype, device=device))
168 |
169 | for traj_np, traj_multimodal_np in tqdm(generator):
170 | with torch.no_grad():
171 |
172 | bs, _, nj, _ = traj_np[..., 1:, :].shape # [bs, t_full, numJoints, 3]
173 | traj_np = traj_np[..., 1:, :].reshape(traj_np.shape[0], traj_np.shape[1], -1) # bs, T, NumJoints*3
174 | traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous() # T, bs, NumJoints*3
175 |
176 | traj_multimodal_np = traj_multimodal_np[..., 1:, :] # [bs, n_modality, t_full, NumJoints, 3]
177 | traj_multimodal_np = traj_multimodal_np.reshape([bs, n_modality, t_his + t_pred, -1]).transpose(
178 | [2, 0, 1, 3]) # [t_full, bs, n_modality, NumJoints*3]
179 | traj_multimodal = tensor(traj_multimodal_np, device=device, dtype=dtype) # .permute(0, 2, 1).contiguous()
180 |
181 | X = traj[:t_his]
182 | Y = traj[t_his:]
183 |
184 | pred, a, b = model(traj)
185 |
186 | pred_tmp1 = pred.reshape([-1, pred.shape[-1] // 3, 3])
187 | pred_tmp = torch.zeros_like(pred_tmp1[:, :1, :])
188 | pred_tmp1 = torch.cat([pred_tmp, pred_tmp1], dim=1)
189 | pred_tmp1 = util.absolute2relative_torch(pred_tmp1, parents=dataset.skeleton.parents()).reshape(
190 | [-1, pred.shape[-1]])
191 | z, prior_logdetjac = pose_prior(pred_tmp1)
192 | prior_lkh = prior.log_prob(z).sum(dim=-1)
193 |
194 | loss, losses, stat = loss_function(pred.unsqueeze(2), traj, traj_multimodal, prior_lkh, prior_logdetjac, epoch / cfg.num_epoch)
195 |
196 | optimizer.zero_grad()
197 | loss.backward()
198 | grad_norm = torch.nn.utils.clip_grad_norm_(list(model.parameters()), max_norm=100)
199 | train_grad += grad_norm
200 | optimizer.step()
201 | train_losses += losses
202 |
203 | total_num_sample += 1
204 | del loss
205 |
206 | scheduler.step()
207 | train_losses /= total_num_sample
208 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)])
209 | lr = optimizer.param_groups[0]['lr']
210 |     # logging adds roughly 20s per epoch on average
211 | tb_logger.add_scalar('train_grad', train_grad / total_num_sample, epoch)
212 |
213 |     logger.info('====> Epoch: {} Time: {:.2f} {} lr: {:.5f} branch_stats: {}'.format(epoch, time.time() - t_s, losses_str, lr, stats))
214 |
215 | return stats
216 |
217 | def get_multimodal_gt(dataset_test):
218 | all_data = []
219 | data_gen = dataset_test.iter_generator(step=cfg.t_his)
220 | for data, _ in tqdm(data_gen):
221 | # print(data.shape)
222 | data = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1)
223 | all_data.append(data)
224 | all_data = np.concatenate(all_data, axis=0)
225 | all_start_pose = all_data[:, t_his - 1, :]
226 | pd = squareform(pdist(all_start_pose))
227 | traj_gt_arr = []
228 | num_mult = []
229 | for i in range(pd.shape[0]):
230 | ind = np.nonzero(pd[i] < args.multimodal_threshold)
231 | traj_gt_arr.append(all_data[ind][:, t_his:, :])
232 | num_mult.append(len(ind[0]))
233 |
234 | num_mult = np.array(num_mult)
235 | logger.info('')
236 | logger.info('')
237 | logger.info('=' * 80)
238 | logger.info(f'#1 future: {len(np.where(num_mult == 1)[0])}/{pd.shape[0]}')
239 | logger.info(f'#<10 future: {len(np.where(num_mult < 10)[0])}/{pd.shape[0]}')
240 | return traj_gt_arr
241 |
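# Editor's sketch (hypothetical toy data, not in the original file): how the grouping
# above behaves. Test sequences whose last observed pose lies within
# args.multimodal_threshold of sequence i's are collected as i's "multimodal" futures.
def _multimodal_grouping_demo(threshold=0.5):
    start_poses = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
    pd = squareform(pdist(start_poses))
    ind = np.nonzero(pd[0] < threshold)
    return ind[0]  # array([0, 1]): pose 2 is too far from pose 0 to share futures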
242 | def get_prediction(data, model, sample_num, num_seeds=1, concat_hist=True):
243 | # 1 * total_len * num_key * 3
244 | dct_m, i_dct_m = get_dct_matrix(cfg.t_his+cfg.t_pred)
245 | dct_m = torch.from_numpy(dct_m).float().to(device)
246 | i_dct_m = torch.from_numpy(i_dct_m).float().to(device)
247 | traj_np = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1)
248 | # 1 * total_len * ((num_key-1)*3)
249 | traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous()
250 | # total_len * 1 * ((num_key-1)*3)
251 | X = traj[:t_his]
252 | Y_gt = traj[t_his:]
253 | X = X.repeat((1, sample_num * num_seeds, 1))
254 | Y_gt = Y_gt.repeat((1, sample_num * num_seeds, 1))
255 |
256 |
257 | Y, mu, logvar = model(X)
258 |
259 | if concat_hist:
260 |
261 | X = X.unsqueeze(2).repeat(1, sample_num * num_seeds, cfg.nk, 1)
262 | Y = Y[t_his:].unsqueeze(1)
263 | Y = torch.cat((X, Y), dim=0)
264 | # total_len * batch_size * feature_size
265 | Y = Y.squeeze(1).permute(1, 0, 2).contiguous().cpu().numpy()
266 | # batch_size * total_len * feature_size
267 | if Y.shape[0] > 1:
268 | Y = Y.reshape(-1, cfg.nk * sample_num, Y.shape[-2], Y.shape[-1])
269 | else:
270 | Y = Y[None, ...]
271 | return Y
272 |
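# Editor's note: the model is deterministic given its learned anchors -- one forward
# pass on X already yields cfg.nk candidate futures per sequence, which is why
# sample_num=1 is used at the call sites and the nk axis shows up in the reshape above.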
273 |
274 | def test(model, epoch):
275 | stats_func = {'Diversity': compute_diversity, 'AMSE': compute_amse, 'FMSE': compute_fmse, 'ADE': compute_ade,
276 | 'FDE': compute_fde, 'MMADE': compute_mmade, 'MMFDE': compute_mmfde, 'MPJPE': mpjpe_error}
277 | stats_names = list(stats_func.keys())
278 | stats_names.extend(['ADE_stat', 'FDE_stat', 'MMADE_stat', 'MMFDE_stat', 'MPJPE_stat'])
279 | stats_meter = {x: AverageMeter() for x in stats_names}
280 |
281 | data_gen = dataset_test.iter_generator(step=cfg.t_his)
282 | num_samples = 0
283 | num_seeds = 1
284 |
285 | for i, (data, _) in tqdm(enumerate(data_gen)):
286 | if args.mode == 'train' and (i >= 500 and (epoch + 1) % 50 != 0 and (epoch + 1) < cfg.num_epoch - 100):
287 | break
288 | num_samples += 1
289 | gt = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1)[:, t_his:, :]
290 | gt_multi = traj_gt_arr[i]
291 | if gt_multi.shape[0] == 1:
292 | continue
293 |
294 | pred = get_prediction(data, model, sample_num=1, num_seeds=num_seeds, concat_hist=False)
295 | pred = pred[:,:,t_his:,:]
296 | for stats in stats_names[:8]:
297 | val = 0
298 | branches = 0
299 | for pred_i in pred:
300 | # sample_num * total_len * ((num_key-1)*3), 1 * total_len * ((num_key-1)*3)
301 | v = stats_func[stats](pred_i, gt, gt_multi)
302 | val += v[0] / num_seeds
303 |                 if v[1] is not None:
304 |                     branches += v[1] / num_seeds
305 | stats_meter[stats].update(val)
306 |             if not isinstance(branches, int):
307 | stats_meter[stats + '_stat'].update(branches)
308 |
309 | logger.info('=' * 80)
310 | for stats in stats_names:
311 | str_stats = f'Total {stats}: ' + f'{stats_meter[stats].avg}'
312 | logger.info(str_stats)
313 | logger.info('=' * 80)
314 |
315 |
316 | def visualize():
317 |     def denormalize(*data):
318 | out = []
319 | for x in data:
320 | x = x * dataset.std + dataset.mean
321 | out.append(x)
322 | return out
323 |
324 | def post_process(pred, data):
325 | pred = pred.reshape(pred.shape[0], pred.shape[1], -1, 3)
326 | if cfg.normalize_data:
327 |             pred = denormalize(pred)
328 | pred = np.concatenate((np.tile(data[..., :1, :], (pred.shape[0], 1, 1, 1)), pred), axis=2)
329 | pred[..., :1, :] = 0
330 | return pred
331 |
332 | def pose_generator():
333 |
334 | while True:
335 | data, data_multimodal, action = dataset_test.sample(n_modality=10)
336 | gt = data[0].copy()
337 | gt[:, :1, :] = 0
338 |
339 | poses = {'action': action, 'context': gt, 'gt': gt}
340 | with torch.no_grad():
341 | pred = get_prediction(data, model, 1)[0]
342 | pred = post_process(pred, data)
343 | for i in range(pred.shape[0]):
344 | poses[f'{i}'] = pred[i]
345 |
346 | yield poses
347 |
348 | pose_gen = pose_generator()
349 | for i in tqdm(range(args.n_viz)):
350 | render_animation(dataset.skeleton, pose_gen, cfg.t_his, ncol=12, output='./results/{}/results/'.format(args.cfg), index_i=i)
351 |
352 |
353 |
354 |
355 |
356 | if __name__ == '__main__':
357 |
358 | parser = argparse.ArgumentParser()
359 | parser.add_argument('--cfg',
360 | default='h36m')
361 | parser.add_argument('--mode', default='train')
362 | parser.add_argument('--test', action='store_true', default=False)
363 | parser.add_argument('--iter', type=int, default=0)
364 | parser.add_argument('--seed', type=int, default=1)
365 | parser.add_argument('--gpu_index', type=int, default=0)
366 | parser.add_argument('--n_pre', type=int, default=8)
367 | parser.add_argument('--n_his', type=int, default=5)
368 | parser.add_argument('--n_viz', type=int, default=100)
369 | parser.add_argument('--num_coupling_layer', type=int, default=4)
370 | parser.add_argument('--multimodal_threshold', type=float, default=0.5)
371 | args = parser.parse_args()
372 |
373 | """setup"""
374 | np.random.seed(args.seed)
375 | torch.manual_seed(args.seed)
376 | dtype = torch.float32
377 | torch.set_default_dtype(dtype)
378 |
379 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu')
380 |
381 | cfg = Config(f'{args.cfg}', test=args.test)
382 | tb_logger = SummaryWriter(cfg.tb_dir) if args.mode == 'train' else None
383 | logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
384 |
385 | """parameter"""
386 | mode = args.mode
387 | nz = cfg.nz
388 | t_his = cfg.t_his
389 | t_pred = cfg.t_pred
390 | cfg.n_his = args.n_his
391 | if 'n_pre' not in cfg.specs.keys():
392 | cfg.n_pre = args.n_pre
393 | else:
394 | cfg.n_pre = cfg.specs['n_pre']
395 | cfg.num_coupling_layer = args.num_coupling_layer
396 | # cfg.nz = args.nz
397 | """data"""
398 | if 'actions' in cfg.specs.keys():
399 | act = cfg.specs['actions']
400 | else:
401 | act = 'all'
402 | dataset_cls = DatasetH36M if cfg.dataset == 'h36m' else DatasetHumanEva
403 | dataset = dataset_cls('train', t_his, t_pred, actions=act, use_vel=cfg.use_vel,
404 | multimodal_path=cfg.specs[
405 | 'multimodal_path'] if 'multimodal_path' in cfg.specs.keys() else None,
406 | data_candi_path=cfg.specs[
407 | 'data_candi_path'] if 'data_candi_path' in cfg.specs.keys() else None)
408 | dataset_test = dataset_cls('test', t_his, t_pred, actions=act, use_vel=cfg.use_vel,
409 | multimodal_path=cfg.specs[
410 | 'multimodal_path'] if 'multimodal_path' in cfg.specs.keys() else None,
411 | data_candi_path=cfg.specs[
412 | 'data_candi_path'] if 'data_candi_path' in cfg.specs.keys() else None)
413 | if cfg.normalize_data:
414 | dataset.normalize_data()
415 | dataset_test.normalize_data(dataset.mean, dataset.std)
416 | traj_gt_arr = get_multimodal_gt(dataset_test)
417 | """model"""
418 | model, pose_prior = get_model(cfg, dataset, cfg.dataset)
419 |
420 | model.float()
421 | pose_prior.float()
422 |
423 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
424 |
425 | scheduler = get_scheduler(optimizer, policy='lambda', nepoch_fix=cfg.num_epoch_fix, nepoch=cfg.num_epoch)
426 |
427 |
428 | cp_path = 'results/h36m_nf/models/0025.p' if cfg.dataset == 'h36m' else 'results/humaneva_nf/models/0025.p'
429 | print('loading model from checkpoint: %s' % cp_path)
430 | model_cp = pickle.load(open(cp_path, "rb"))
431 | pose_prior.load_state_dict(model_cp['model_dict'])
432 | pose_prior.to(device)
433 |
434 | valid_ang = pickle.load(open('./data/h36m_valid_angle.p', "rb")) if cfg.dataset == 'h36m' else pickle.load(
435 | open('./data/humaneva_valid_angle.p', "rb"))
436 | if args.iter > 0:
437 | cp_path = cfg.model_path % args.iter
438 | print('loading model from checkpoint: %s' % cp_path)
439 | model_cp = pickle.load(open(cp_path, "rb"))
440 | model.load_state_dict(model_cp['model_dict'])
441 | print("load done")
442 |
443 | if mode == 'train':
444 | model.to(device)
445 | overall_iter = 0
446 | stats = torch.zeros(cfg.nk)
447 | model.train()
448 |
449 | for i in range(args.iter, cfg.num_epoch):
450 | stats = train(i, stats)
451 | if cfg.save_model_interval > 0 and (i + 1) % 10 == 0:
452 | model.eval()
453 | with torch.no_grad():
454 | test(model, i)
455 | model.train()
456 | with to_cpu(model):
457 | cp_path = cfg.model_path % (i + 1)
458 | model_cp = {'model_dict': model.state_dict(), 'meta': {'std': dataset.std, 'mean': dataset.mean}}
459 |
460 | pickle.dump(model_cp, open(cp_path, 'wb'))
461 |
462 | elif mode == 'test':
463 | model.to(device)
464 | model.eval()
465 |
466 | with torch.no_grad():
467 |             test(model, args.iter)
468 |
469 |
470 | elif mode == 'viz':
471 | model.to(device)
472 | model.eval()
473 | with torch.no_grad():
474 | visualize()
475 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | from torch.nn import functional as F
9 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
10 | return F.leaky_relu(input + bias, negative_slope) * scale
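# Editor's note: the 2 ** 0.5 scale follows the StyleGAN2 convention -- a leaky ReLU
# roughly halves the variance of a zero-mean signal, and multiplying by sqrt(2)
# restores it, so activations keep near-unit variance through stacked EqualLinear layers.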
11 | class ST_GCNN_layer_down(nn.Module):
12 | """
13 | Shape:
14 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
15 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
16 |     - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
17 | where
18 | :math:`N` is a batch size,
19 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
20 | :math:`T_{in}/T_{out}` is a length of input/output sequence,
21 | :math:`V` is the number of graph nodes.
22 |     in_channels: dimension of input coordinates
23 |     out_channels: dimension of output coordinates
24 |
25 | """
26 | def __init__(self,
27 | in_channels,
28 | out_channels,
29 | kernel_size,
30 | stride,
31 | time_dim,
32 | joints_dim,
33 | dropout,
34 | bias=True,
35 | version=0,
36 | pose_info=None):
37 |
38 | super(ST_GCNN_layer_down,self).__init__()
39 | self.kernel_size = kernel_size
40 | # assert self.kernel_size[0] % 2 == 1
41 | # assert self.kernel_size[1] % 2 == 1
42 | # padding = ((self.kernel_size[0] - 1) // 2,(self.kernel_size[1] - 1) // 2)
43 | padding = (0,0)
44 |
45 | if version == 0:
46 | self.gcn=ConvTemporalGraphical(time_dim,joints_dim) # the convolution layer
47 | elif version == 1:
48 | self.gcn = ConvTemporalGraphicalV1(time_dim,joints_dim,pose_info=pose_info)
49 |         if not isinstance(stride, list):
50 | self.tcn = nn.Sequential(
51 | nn.Conv2d(
52 | in_channels,
53 | out_channels,
54 | (self.kernel_size[0], self.kernel_size[1]),
55 | (stride, stride),
56 | padding,
57 | ),
58 | nn.BatchNorm2d(out_channels),
59 | nn.Dropout(dropout, inplace=True),
60 | )
61 | else:
62 | self.tcn = nn.Sequential(
63 | nn.Conv2d(
64 | in_channels,
65 | out_channels,
66 | (self.kernel_size[0], self.kernel_size[1]),
67 | (stride[0], stride[1]),
68 | padding,
69 | ),
70 | nn.BatchNorm2d(out_channels),
71 | nn.Dropout(dropout, inplace=True),
72 | )
73 |
74 |
75 |
76 | if stride != 1 or in_channels != out_channels:
77 |
78 | self.residual=nn.Sequential(nn.Conv2d(
79 | in_channels,
80 | out_channels,
81 | kernel_size=1,
82 | stride=(1, 1)),
83 | nn.BatchNorm2d(out_channels),
84 | )
85 |
86 |
87 | else:
88 | self.residual=nn.Identity()
89 |
90 |
91 | self.prelu = nn.PReLU()
92 |
93 |
94 |
95 | def forward(self, x):
96 | # assert A.shape[0] == self.kernel_size[1], print(A.shape[0],self.kernel_size)
97 | res=self.residual(x)
98 | x=self.gcn(x)
99 | x=self.tcn(x)
100 | # x=x+res
101 | x=self.prelu(x)
102 | return x
103 | def count_parameters(model):
104 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
105 |
106 | class EqualLinear(nn.Module):
107 | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
108 | super().__init__()
109 |
110 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
111 |
112 | if bias:
113 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
114 | else:
115 | self.bias = None
116 |
117 | self.activation = activation
118 |
119 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
120 | self.lr_mul = lr_mul
121 |
122 | def forward(self, input):
123 |
124 | if self.activation:
125 | out = F.linear(input, self.weight * self.scale)
126 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
127 | else:
128 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
129 |
130 | return out
131 |
132 | def __repr__(self):
133 | return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
134 |
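# Editor's sketch (illustration, not original code): EqualLinear implements the
# "equalized learning rate" trick. Weights are stored as N(0, 1) / lr_mul and rescaled
# at runtime by (1 / sqrt(in_dim)) * lr_mul, so the effective weight std is always
# 1 / sqrt(in_dim) while optimizer updates on the stored tensor scale with lr_mul.
def _equal_linear_demo():
    layer = EqualLinear(in_dim=256, out_dim=128, lr_mul=0.01)
    x = torch.randn(4, 256)
    return layer(x).shape  # torch.Size([4, 128]); effective weight std ~ 1/16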
135 | class GraphConv(nn.Module):
136 | """
137 | adapted from : https://github.com/tkipf/gcn/blob/92600c39797c2bfb61a508e52b88fb554df30177/gcn/layers.py#L132
138 | """
139 |
140 | def __init__(self, in_len, out_len, in_node_n=66, out_node_n=66, bias=True):
141 | super(GraphConv, self).__init__()
142 | self.in_len = in_len
143 | self.out_len = out_len
144 | self.in_node_n = in_node_n
145 | self.out_node_n = out_node_n
146 | self.weight = nn.Parameter(torch.FloatTensor(in_len, out_len))
147 | self.att = nn.Parameter(torch.FloatTensor(in_node_n, out_node_n))
148 |
149 | if bias:
150 | self.bias = nn.Parameter(torch.FloatTensor(out_len))
151 | else:
152 | self.register_parameter('bias', None)
153 | self.reset_parameters()
154 |
155 |
156 | def reset_parameters(self):
157 | stdv = 1. / math.sqrt(self.weight.size(1))
158 | self.weight.data.uniform_(-stdv, stdv)
159 | self.att.data.uniform_(-stdv, stdv)
160 | if self.bias is not None:
161 | self.bias.data.uniform_(-stdv, stdv)
162 |
163 | def forward(self, input):
164 |         '''
165 |         input: [b, in_node_n, in_len] -> output: [b, out_node_n, out_len]
166 |         '''
167 |
168 | features = torch.matmul(input, self.weight) # 35 -> 256
169 | output = torch.matmul(features.permute(0, 2, 1).contiguous(), self.att).permute(0, 2, 1).contiguous() # 66 -> 66
170 |
171 | if self.bias is not None:
172 | output = output + self.bias
173 |
174 | return output
175 |
176 | def __repr__(self):
177 | return self.__class__.__name__ + ' ('+ str(self.in_len) + ' -> ' + str(self.out_len) + ')' + ' ('+ str(self.in_node_n) + ' -> ' + str(self.out_node_n) + ')'
178 |
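# Editor's sketch (illustrative usage, not in the original file): GraphConv mixes along
# two axes -- `weight` maps the temporal length (in_len -> out_len) and `att` acts as a
# learned dense adjacency over the nodes (in_node_n -> out_node_n).
def _graph_conv_demo():
    layer = GraphConv(in_len=35, out_len=256, in_node_n=66, out_node_n=66)
    x = torch.randn(8, 66, 35)  # [batch, nodes, time]
    return layer(x).shape       # torch.Size([8, 66, 256])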
179 | class GraphConvBlock(nn.Module):
180 | def __init__(self, in_len, out_len, in_node_n, out_node_n, dropout_rate=0, leaky=0.1, bias=True, residual=False):
181 | super(GraphConvBlock, self).__init__()
182 | self.dropout_rate = dropout_rate
183 |         self.residual = residual
184 |
185 | self.out_len = out_len
186 |
187 | self.gcn = GraphConv(in_len, out_len, in_node_n=in_node_n, out_node_n=out_node_n, bias=bias)
188 | self.bn = nn.BatchNorm1d(out_node_n * out_len)
189 | self.act = nn.Tanh()
190 | if self.dropout_rate > 0:
191 | self.drop = nn.Dropout(dropout_rate)
192 |
193 | def forward(self, input):
194 |         '''
195 |         Args:
196 |             input: [b, in_node_n, in_len]
197 |
198 |         Returns:
199 |             [b, out_node_n, out_len], plus the input when residual=True
200 |
201 |         '''
202 | x = self.gcn(input)
203 | b, vc, t = x.shape
204 | x = self.bn(x.view(b, -1)).view(b, vc, t)
205 | # x = self.bn(x.view(b, -1, 3, t).permute(0, 2, 1, 3).contiguous()).permute(0, 2, 1, 3).contiguous().view(b, vc, t)
206 | x = self.act(x)
207 | if self.dropout_rate > 0:
208 | x = self.drop(x)
209 |
210 |         if self.residual:
211 | return x + input
212 | else:
213 | return x
214 |
215 |
216 | class ResGCB(nn.Module):
217 | def __init__(self, in_len, out_len, in_node_n, out_node_n, dropout_rate=0, leaky=0.1, bias=True, residual=False):
218 | super(ResGCB, self).__init__()
219 |         self.residual = residual
220 | self.gcb1 = GraphConvBlock(in_len, in_len, in_node_n=in_node_n, out_node_n=in_node_n, dropout_rate=dropout_rate, bias=bias, residual=False)
221 | self.gcb2 = GraphConvBlock(in_len, out_len, in_node_n=in_node_n, out_node_n=out_node_n, dropout_rate=dropout_rate, bias=bias, residual=False)
222 |
223 |
224 | def forward(self, input):
225 |         '''
226 |         Args:
227 |             input: [B, in_node_n, in_len]
228 |
229 |         Returns:
230 |             [B, out_node_n, out_len], plus the input when residual=True
231 |
232 |         '''
233 |
234 | x = self.gcb1(input)
235 | x = self.gcb2(x)
236 |
237 |         if self.residual:
238 | return x + input
239 | else:
240 | return x
241 |
242 | class ConvTemporalGraphical(nn.Module):
243 | #Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
244 | r"""The basic module for applying a graph convolution.
245 | Shape:
246 | - Input: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
247 | - Output: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
248 | where
249 | :math:`N` is a batch size,
250 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
251 | :math:`T_{in}/T_{out}` is a length of input/output sequence,
252 | :math:`V` is the number of graph nodes.
253 | """
254 | def __init__(self,
255 | time_dim,
256 | joints_dim
257 | ):
258 | super(ConvTemporalGraphical,self).__init__()
259 |
260 | self.A=nn.Parameter(torch.FloatTensor(time_dim, joints_dim,joints_dim)) #learnable, graph-agnostic 3-d adjacency matrix(or edge importance matrix)
261 | stdv = 1. / math.sqrt(self.A.size(1))
262 | self.A.data.uniform_(-stdv,stdv)
263 |
264 | self.T=nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim))
265 | stdv = 1. / math.sqrt(self.T.size(1))
266 | self.T.data.uniform_(-stdv,stdv)
267 | '''
268 | self.prelu = nn.PReLU()
269 |
270 | self.Z=nn.Parameter(torch.FloatTensor(joints_dim, joints_dim, time_dim, time_dim))
271 | stdv = 1. / math.sqrt(self.Z.size(2))
272 | self.Z.data.uniform_(-stdv,stdv)
273 | '''
274 | self.joints_dim = joints_dim
275 | self.time_dim = time_dim
276 |
277 | def forward(self, x):
278 | x = torch.einsum('nctv,vtq->ncqv', (x, self.T))
279 | ## x=self.prelu(x)
280 | x = torch.einsum('nctv,tvw->nctw', (x, self.A))
281 | ## x = torch.einsum('nctv,wvtq->ncqw', (x, self.Z))
282 | return x.contiguous()
283 |
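# Editor's note (equivalence check, not original code): the two einsums above perform
# per-node temporal mixing then per-frame spatial mixing. For a single node v,
# 'nctv,vtq->ncqv' is x[..., v] @ T[v]; for a single frame t, 'nctv,tvw->nctw'
# is x[:, :, t, :] @ A[t].
def _einsum_equivalence_check():
    n, c, t, v = 2, 3, 4, 5
    x, T = torch.randn(n, c, t, v), torch.randn(v, t, t)
    y = torch.einsum('nctv,vtq->ncqv', (x, T))
    y_ref = torch.stack([x[..., j] @ T[j] for j in range(v)], dim=-1)
    assert torch.allclose(y, y_ref, atol=1e-5)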
284 |
285 | class ConvTemporalGraphicalV1(nn.Module):
286 | #Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
287 | r"""The basic module for applying a graph convolution.
288 | Shape:
289 | - Input: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
290 |     - Output: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
291 | where
292 | :math:`N` is a batch size,
293 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
294 | :math:`T_{in}/T_{out}` is a length of input/output sequence,
295 | :math:`V` is the number of graph nodes.
296 | """
297 | def __init__(self,
298 | time_dim,
299 | joints_dim,
300 | pose_info
301 | ):
302 | super(ConvTemporalGraphicalV1,self).__init__()
303 | parents=pose_info['parents']
304 | joints_left=list(pose_info['joints_left'])
305 | joints_right=list(pose_info['joints_right'])
306 | keep_joints=pose_info['keep_joints']
307 | dim_use = list(keep_joints)
308 | # print(dim_use)
309 | self.A=nn.Parameter(torch.FloatTensor(time_dim, joints_dim, joints_dim)) #learnable, graph-agnostic 3-d adjacency matrix(or edge importance matrix)
310 | stdv = 1. / math.sqrt(self.A.size(1))
311 | self.A.data.uniform_(-stdv,stdv)
312 |
313 | self.T=nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim))
314 | stdv = 1. / math.sqrt(self.T.size(1))
315 | self.T.data.uniform_(-stdv,stdv)
316 | '''
317 | self.prelu = nn.PReLU()
318 |
319 | self.Z=nn.Parameter(torch.FloatTensor(joints_dim, joints_dim, time_dim, time_dim))
320 | stdv = 1. / math.sqrt(self.Z.size(2))
321 | self.Z.data.uniform_(-stdv,stdv)
322 | '''
323 | self.A_s = torch.zeros((1,joints_dim,joints_dim), requires_grad=False, dtype=torch.float)
324 | for i, dim in enumerate(dim_use):
325 | self.A_s[0][i][i] = 1
326 | if parents[dim] in dim_use:
327 | parent_index = dim_use.index(parents[dim])
328 | self.A_s[0][i][parent_index] = 1
329 | self.A_s[0][parent_index][i] = 1
330 | if dim in joints_left:
331 | index = joints_left.index(dim)
332 |                 right_dim = joints_right[index]
333 |                 if right_dim in dim_use:  # guard before index() to avoid a ValueError on dropped joints
334 |                     right_index = dim_use.index(right_dim)
335 |                     self.A_s[0][i][right_index] = 1
336 |                     self.A_s[0][right_index][i] = 1
337 |
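        # Editor's note: A_s is a fixed binary skeleton mask (self-loops, parent/child
        # links, left/right mirror joints). forward() gates the learned A elementwise
        # with this mask, so only anatomically meaningful edges carry weight.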
338 | self.joints_dim = joints_dim
339 | self.time_dim = time_dim
340 |
341 | def forward(self, x):
342 | A = self.A * self.A_s.to(x.device)
343 | x = torch.einsum('nctv,vtq->ncqv', (x, self.T))
344 | x = torch.einsum('nctv,tvw->nctw', (x, A))
345 | return x.contiguous()
346 |
347 |
348 | class ST_GCNN_layer(nn.Module):
349 | """
350 | Shape:
351 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
352 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
353 |     - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
354 | where
355 | :math:`N` is a batch size,
356 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
357 | :math:`T_{in}/T_{out}` is a length of input/output sequence,
358 | :math:`V` is the number of graph nodes.
359 |     in_channels: dimension of input coordinates
360 |     out_channels: dimension of output coordinates
361 |
362 | """
363 | def __init__(self,
364 | in_channels,
365 | out_channels,
366 | kernel_size,
367 | stride,
368 | time_dim,
369 | joints_dim,
370 | dropout,
371 | bias=True,
372 | version=0,
373 | pose_info=None):
374 |
375 | super(ST_GCNN_layer,self).__init__()
376 | self.kernel_size = kernel_size
377 | assert self.kernel_size[0] % 2 == 1
378 | assert self.kernel_size[1] % 2 == 1
379 | padding = ((self.kernel_size[0] - 1) // 2,(self.kernel_size[1] - 1) // 2)
380 |
381 | if version == 0:
382 | self.gcn=ConvTemporalGraphical(time_dim,joints_dim) # the convolution layer
383 | elif version == 1:
384 | self.gcn = ConvTemporalGraphicalV1(time_dim,joints_dim,pose_info=pose_info)
385 |
386 | self.tcn = nn.Sequential(
387 | nn.Conv2d(
388 | in_channels,
389 | out_channels,
390 | (self.kernel_size[0], self.kernel_size[1]),
391 | (stride, stride),
392 | padding,
393 | ),
394 | nn.BatchNorm2d(out_channels),
395 | nn.Dropout(dropout, inplace=True),
396 | )
397 |
398 |
399 |
400 | if stride != 1 or in_channels != out_channels:
401 |
402 | self.residual=nn.Sequential(nn.Conv2d(
403 | in_channels,
404 | out_channels,
405 | kernel_size=1,
406 | stride=(1, 1)),
407 | nn.BatchNorm2d(out_channels),
408 | )
409 |
410 |
411 | else:
412 | self.residual=nn.Identity()
413 |
414 |
415 | self.prelu = nn.PReLU()
416 |
417 |
418 |
419 | def forward(self, x):
420 | # assert A.shape[0] == self.kernel_size[1], print(A.shape[0],self.kernel_size)
421 | res=self.residual(x)
422 | x=self.gcn(x)
423 | x=self.tcn(x)
424 | x=x+res
425 | x=self.prelu(x)
426 | return x
427 |
428 |
429 |
430 | class Direction(nn.Module):
431 | def __init__(self, motion_dim):
432 | super(Direction, self).__init__()
433 |
434 | self.weight = nn.Parameter(torch.randn(256, motion_dim))
435 |
436 | def forward(self, input):
437 | # input: (bs*t) x 256
438 |
439 | weight = self.weight + 1e-8
440 |         Q, R = torch.linalg.qr(weight)  # columns of Q are an orthonormal direction basis (torch.qr is deprecated)
441 |
442 | if input is None:
443 | return Q
444 | else:
445 | input_diag = torch.diag_embed(input) # alpha, diagonal matrix
446 | out = torch.matmul(input_diag, Q.T)
447 | out = torch.sum(out, dim=1)
448 |
449 | return out
450 |
451 |
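# Editor's note (equivalence sketch, not original code): QR makes the columns of Q an
# orthonormal set of direction vectors in the 256-d feature space, and the
# diag_embed + sum construction is just their weighted sum: out = alpha @ Q.T.
def _direction_equivalence_check():
    d = Direction(motion_dim=30)
    alpha = torch.randn(4, 30)
    Q = torch.linalg.qr(d.weight + 1e-8)[0]  # mirrors the forward pass
    assert torch.allclose(d(alpha), alpha @ Q.T, atol=1e-5)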
452 | class Model(nn.Module):
453 | def __init__(self, nx, ny,input_channels,st_gcnn_dropout,
454 | joints_to_consider,
455 | pose_info):
456 | super(Model, self).__init__()
457 | self.nx = nx
458 | self.ny = ny
459 | self.output_len = 20
460 | self.num_joints = joints_to_consider
461 | self.num_D = 30
462 | if nx == 48:
463 | self.t_his = 25
464 | self.t_pred = 100
465 | elif nx == 42:
466 | self.t_his = 15
467 | self.t_pred = 60
468 | self.nk = 50
469 | self.num_anchor = self.nk
470 | self.anchor_input = nn.ParameterDict()
471 | stdv_anchor = 1. / math.sqrt(128)
472 | for i in range(self.num_anchor):
473 | self.anchor_input[f'anchor_{i}'] = nn.Parameter(torch.FloatTensor(1, 128).uniform_(-stdv_anchor, stdv_anchor))
474 |
475 |
476 | self.st_gcnns_encoder_past_motion=nn.ModuleList()
477 | #0
478 | self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(input_channels,128,[3,1],1,self.output_len,
479 | joints_to_consider,st_gcnn_dropout,pose_info=pose_info))
480 | #1
481 | self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(128,64,[3,1],1,self.output_len,
482 | joints_to_consider,st_gcnn_dropout, version=1, pose_info=pose_info))
483 | #2
484 | self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(64,128,[3,1],1,self.output_len,
485 | joints_to_consider,st_gcnn_dropout, version=1, pose_info=pose_info))
486 | #3
487 | self.st_gcnns_encoder_past_motion.append(ST_GCNN_layer(128,128,[3,1],1,self.output_len,
488 | joints_to_consider,st_gcnn_dropout, pose_info=pose_info))
489 |
490 | self.st_gcnns_compress=nn.ModuleList()
491 | #0
492 | self.st_gcnns_compress.append(ST_GCNN_layer_down(256,512,[2,2],2,self.output_len,
493 | joints_to_consider,st_gcnn_dropout, pose_info=pose_info))
494 |         #1
495 | self.st_gcnns_compress.append(ST_GCNN_layer_down(512,768,[2,2],2,self.output_len//2,
496 | joints_to_consider//2,st_gcnn_dropout, pose_info=pose_info))
497 |
498 | self.st_gcnns_compress.append(ST_GCNN_layer_down(768,1024,[2,2],2,self.output_len//4,
499 | joints_to_consider//4,st_gcnn_dropout, pose_info=pose_info))
500 |
501 |
502 | down_fc = [EqualLinear(1024, 1024,activation=True)]
503 | for i in range(1):
504 | down_fc.append(EqualLinear(1024, 512,activation=True))
505 |
506 | down_fc.append(EqualLinear(512, self.num_D))
507 | self.down_fc = nn.Sequential(*down_fc)
508 |
509 |
510 | self.direction = Direction(motion_dim=self.num_D)
511 |
512 | self.st_gcnns_decoder=nn.ModuleList()
513 |
514 | #4
515 | self.st_gcnns_decoder.append(ST_GCNN_layer(128+256,128,[3,1],1,self.output_len,
516 | joints_to_consider,st_gcnn_dropout, version=1, pose_info=pose_info))
517 | self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_encoder_past_motion[-2].gcn.A
518 |
519 | #5
520 | self.st_gcnns_decoder.append(ST_GCNN_layer(128,64,[3,1],1,self.output_len,
521 | joints_to_consider,st_gcnn_dropout, pose_info=pose_info))
522 | self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_encoder_past_motion[-1].gcn.A
523 | #6
524 | self.st_gcnns_decoder.append(ST_GCNN_layer(64,128,[3,1],1,self.output_len,
525 | joints_to_consider,st_gcnn_dropout, version=1, pose_info=pose_info))
526 | self.st_gcnns_decoder[-1].gcn.A = self.st_gcnns_decoder[-3].gcn.A
527 | #7
528 | self.st_gcnns_decoder.append(ST_GCNN_layer(128,input_channels,[3,1],1,self.output_len,
529 | joints_to_consider,st_gcnn_dropout, pose_info=pose_info))
530 |
531 |
532 | self.dct_m, self.idct_m = self.get_dct_matrix(self.t_his + self.t_pred)
533 |
534 |
535 |
536 | def encode_past_motion(self,x_input):
537 | #x_input: [t_full, bs, V*C]
538 |
539 | # [t_full, bs, V*C] -> [t_full, bs, V, C] -> [bs, c, t_full, v]
540 | x_input = x_input.view(x_input.shape[0], x_input.shape[1], -1, 3).permute(1, 3, 0, 2)
541 | y = torch.zeros((x_input.shape[0], x_input.shape[1], self.t_pred, x_input.shape[3])).to(x_input.device)
542 | # [bs, c, t_full, v] -> [bs, t_full, c, v]
543 | x_padding = torch.cat([x_input[:,:,:self.t_his,], y], dim=2).permute(0, 2, 1, 3)
544 |
545 | N, T, C, V = x_padding.shape
546 |
547 | # [bs, t_full, C, V] -> [bs, t_full, C*V]
548 |
549 | x_padding = x_padding.reshape([N, T, C * V])
550 |
551 |
552 | dct_m = self.dct_m.to(x_input.device)
553 | idx_pad = list(range(self.t_his)) + [self.t_his - 1] * self.t_pred
554 |
555 | # [bs, t_full, C*V] -> [bs, t_full, C*V]
556 | x_pad = torch.matmul(dct_m[:self.output_len], x_padding[:, idx_pad, :]).reshape([N, -1, C, V]).permute(0, 2, 1, 3)
557 | x = x_pad # [N, C, T, V]
558 |
559 |         for gcn in self.st_gcnns_encoder_past_motion:  # encoder layers 0-3
560 | x = gcn(x)
561 | N, C, T, V = x.shape
562 |
563 | return x
564 |
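    # Editor's note: idx_pad above freezes the unknown future by repeating the last
    # observed frame (e.g. t_his=3, t_pred=4 gives [0, 1, 2, 2, 2, 2, 2]); the padded
    # sequence is then projected onto the first `output_len` DCT coefficients, so the
    # encoder works entirely in a compact frequency representation of the history.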
565 | def decoding(self,z,condition=None):
566 | idct_m = self.idct_m.to(z.device)
567 |
568 | condition = condition.view(condition.shape[0], condition.shape[1], -1, 3).permute(1, 3, 0, 2) #(T_his, bs, Num_Joints, 3) -> [bs, t_full, 3, Num_joints]
569 | y_condition = torch.zeros((condition.shape[0], condition.shape[1], self.t_pred, condition.shape[3])).to(condition.device)
570 | condition_padding = torch.cat([condition[:,:,:self.t_his,:], y_condition], dim=2).permute(0, 2, 1, 3)
571 | N, T, C, V = condition_padding.shape
572 |
573 | condition_padding = condition_padding.reshape([N, T, C * V])
574 | dct_m = self.dct_m.to(condition.device)
575 | idx_pad = list(range(self.t_his)) + [self.t_his - 1] * self.t_pred
576 | condition_p = torch.matmul(dct_m[:self.output_len], condition_padding[:, idx_pad, :]).reshape([N, -1, C, V]).permute(0, 2, 1, 3)
577 | if condition_p.shape[0] != z.shape[0]:
578 | condition_p = condition_p.repeat_interleave(self.nk, dim=0)
579 |
580 |         for gcn in self.st_gcnns_decoder:  # decoder layers 0-3
581 | z = gcn(z)
582 |
583 | output = (z + condition_p)
584 | N, C, N_fre, V = output.shape
585 |
586 | output = output.permute(0, 2, 1, 3).reshape([N, -1, C * V])
587 |
588 | outputs = torch.matmul(idct_m[:, :self.output_len], output).reshape([N, -1, C, V]).permute(1, 0, 3, 2).contiguous().view(-1,N,C*V)
589 |
590 | return outputs
591 |
592 |
593 | def forward(self, x, z=None,epoch=None):
594 | bs = x.shape[1]
595 | z = self.encode_past_motion(x).repeat_interleave(self.nk,dim=0)
596 | replicated_parameters = torch.cat([self.anchor_input[f'anchor_{i}'].expand(self.nk//self.num_anchor, -1) for i in range(self.num_anchor)], dim=0)
597 | anchors_input = replicated_parameters.repeat(bs, 1)
598 | z1 = torch.cat((anchors_input.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.output_len, self.num_joints),z),dim=1)
599 |         for gcn in self.st_gcnns_compress:  # compression layers 0-2
600 | z1 = gcn(z1)
601 | z1 = z1.mean(-1).mean(-1).view(bs*self.nk,-1)
602 | alpha = self.down_fc(z1)
603 | directions = self.direction(alpha)
604 |
605 | N, C, T, V = z.shape
606 | feature = torch.cat((directions.unsqueeze(2).unsqueeze(3).repeat(1, 1, T, V),z),dim=1)
607 |
608 |
609 | outputs = self.decoding(feature, x)
610 |
611 |         return outputs, feature, feature
612 |
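    # Editor's shape walkthrough (annotation, not original code), with bs = batch size,
    # nk = 50 anchors, output_len = 20 DCT coefficients, V = number of joints:
    #   encode_past_motion(x)         -> [bs, 128, 20, V], repeated to [bs*nk, 128, 20, V]
    #   anchors + compress + down_fc  -> alpha [bs*nk, num_D]
    #   direction(alpha)              -> [bs*nk, 256], a mixture of orthonormal directions
    #   decoding(cat([dir, z]), x)    -> [t_full, bs*nk, V*3] candidate future sequences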
613 |
614 | def get_dct_matrix(self, N, is_torch=True):
615 | dct_m = np.eye(N, dtype=np.float32)
616 | for k in np.arange(N):
617 | for i in np.arange(N):
618 | w = np.sqrt(2 / N)
619 | if k == 0:
620 | w = np.sqrt(1 / N)
621 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
622 | idct_m = np.linalg.inv(dct_m)
623 | if is_torch:
624 | dct_m = torch.from_numpy(dct_m)
625 | idct_m = torch.from_numpy(idct_m)
626 | return dct_m, idct_m
627 |
628 |
629 |
630 |
--------------------------------------------------------------------------------