├── DiffHM.py ├── LICENSE ├── README.md ├── Supplementary Material.pdf ├── __pycache__ └── DiffHM.cpython-37.pyc ├── configs └── baseline.yaml ├── dataset └── utils │ ├── __pycache__ │ ├── dataset.cpython-37.pyc │ ├── dataset_generated_motions.cpython-37.pyc │ ├── dataset_h36m.cpython-37.pyc │ └── skeleton.cpython-37.pyc │ ├── dataset.py │ ├── dataset_generated_motions.py │ ├── dataset_h36m.py │ └── skeleton.py ├── main.py ├── models ├── Diffusion.py ├── MotionDiff.py ├── PoseFormer.py ├── __pycache__ │ ├── Diffusion.cpython-37.pyc │ ├── MotionDiff.cpython-37.pyc │ ├── PoseFormer.cpython-37.pyc │ ├── common.cpython-37.pyc │ ├── mao_gcn.cpython-37.pyc │ ├── motion_pred.cpython-37.pyc │ └── rnn.cpython-37.pyc ├── common.py ├── mao_gcn.py ├── motion_pred.py └── rnn.py ├── requirements.txt ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── logger.cpython-37.pyc │ └── torch.cpython-37.pyc ├── logger.py └── torch.py └── visualization ├── __pycache__ └── visualization.cpython-37.pyc └── visualization.py /DiffHM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import pdb 5 | import numpy as np 6 | import os.path as osp 7 | import logging 8 | import time 9 | from torch import nn, optim, utils 10 | import torch.nn as nn 11 | from torch.utils.tensorboard import SummaryWriter 12 | import pickle 13 | from progress.bar import Bar 14 | from torch.autograd import Variable 15 | 16 | from utils.logger import create_logger 17 | from dataset.utils.dataset_h36m import DatasetH36M 18 | from models.motion_pred import * 19 | from models.common import * 20 | from dataset.utils.dataset_generated_motions import CustomH36M 21 | from torch.utils.data import DataLoader 22 | from tqdm import tqdm 23 | import dataset.utils as dutil 24 | 25 | 26 | 27 | class DiffHM(): 28 | def __init__(self, config): 29 | self.config = config 30 | torch.backends.cudnn.benchmark = True 31 | self.device = torch.device('cuda', index=config.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 32 | self.dtype = torch.float32 33 | self._build() 34 | 35 | 36 | def train_diff(self): 37 | model = self.model 38 | config = self.config 39 | t_his = config.obs_frames 40 | t_pred = config.pred_frames 41 | model.train() 42 | print('==========================') 43 | for epoch in range(config.iter_start_diff, config.num_diff_epoch): 44 | t_s = time.time() 45 | train_losses = 0 46 | total_num_sample = 0 47 | loss_names = ['MSE'] 48 | generator = self.dataset_train.sampling_generator(num_samples=config.num_diff_data_sample, batch_size=config.batch_size) 49 | 50 | for traj_np in generator: 51 | traj_np = traj_np[..., 1:, :] 52 | traj = tensor(traj_np, device=self.device, dtype=self.dtype) 53 | X = traj[:, :t_his, :, :] 54 | Y = traj[:, t_his:, :, :] 55 | loss = model.get_loss(X, Y) 56 | self.optimizer.zero_grad() 57 | loss.backward() 58 | # gradient clipped 59 | if config.max_norm: 60 | nn.utils.clip_grad_norm(self.model.parameters(), max_norm=1) 61 | self.optimizer.step() 62 | train_losses += loss 63 | total_num_sample += 1 64 | 65 | self.scheduler.step() 66 | dt = time.time() - t_s 67 | train_losses /= total_num_sample 68 | lr = self.optimizer.param_groups[0]['lr'] 69 | self.logger.info('====> Epoch: {} Time: {:.2f} MSE: {} lr: {:.5f}'.format(epoch+1, dt, train_losses, lr)) 70 | self.log_writer_diff.add_scalar('DiffMotion_' + str(loss_names), train_losses, epoch) 71 | 72 | 73 | ############ Saving model ############### 74 | if 
config.save_model_interval > 0 and (epoch + 1) % config.save_model_interval == 0: 75 | with to_cpu(model): 76 | cp_path = self.pretrained_model_dir_diff % (epoch + 1) 77 | model_cp = {'model_dict': model.state_dict(), 'meta': {'std': self.dataset_train.std, 'mean': self.dataset_train.mean}} 78 | pickle.dump(model_cp, open(cp_path, 'wb')) 79 | 80 | 81 | 82 | def generate_diff(self): 83 | device = self.device 84 | dtype = self.dtype 85 | config = self.config 86 | t_his = config.obs_frames 87 | torch.set_grad_enabled(False) 88 | logger_test = create_logger(os.path.join(self.model_dir_log, "log_eval.txt")) 89 | 90 | # get dataset 91 | dataset = self.dataset_train 92 | 93 | # get models 94 | algos = ['diff'] 95 | models = {} 96 | for algo in algos: 97 | models[algo] = get_diff_model(config, self.dataset_train.traj_dim) 98 | cp_path = self.pretrained_model_dir_diff % config.eval_at_diff 99 | print('loading diffusion model from checkpoint: %s' % cp_path) 100 | diff_cp = pickle.load(open(cp_path, "rb")) 101 | models[algo].load_state_dict(diff_cp['model_dict']) 102 | models[algo].to(device) 103 | models[algo].eval() 104 | 105 | # normalize 106 | if config.normalize_data: 107 | dataset.normalize_data(diff_cp['meta']['mean'], diff_cp['meta']['std']) 108 | 109 | # generate 50 diversity training samples 110 | data_gen = dataset.sampling_generator(num_samples=config.num_generate_diff_data_sample, batch_size=config.generate_diff_batch_size) 111 | num_seeds = config.num_seeds 112 | 113 | data_diff = [] 114 | data_gt = [] 115 | count = 0 116 | 117 | for traj_np in data_gen: 118 | traj_gt = traj_np[..., 1:, :].reshape(traj_np.shape[0], traj_np.shape[1], -1) 119 | traj_gt = torch.squeeze(tensor(traj_gt, dtype=self.dtype)) 120 | data_gt.append(traj_gt) 121 | 122 | traj_np = tensor(traj_np, device=self.device, dtype=self.dtype) 123 | pred = get_prediction(config, models, traj_np, algo="diff", sample_num=config.nk, device=device, dtype=dtype, num_seeds=num_seeds, concat_hist=True) 124 | data_diff.append(torch.squeeze(torch.tensor(pred))) 125 | 126 | count += 1 127 | if count % 500 == 0: 128 | print(count) 129 | 130 | 131 | data_diff = torch.stack(data_diff) 132 | data_gt = torch.stack(data_gt) 133 | generated_diff = osp.join(self.generated_motions, "generate_diversity_diff.pth") 134 | generated_gt = osp.join(self.generated_motions, "generate_diversity_gt.pth") 135 | torch.save(data_diff.to(torch.device('cpu')), generated_diff) 136 | torch.save(data_gt.to(torch.device('cpu')), generated_gt) 137 | 138 | 139 | 140 | def train_refine(self): 141 | device = self.device 142 | dtype = self.dtype 143 | config = self.config 144 | logger_refine = create_logger(os.path.join(self.model_dir_log, "log_refine.txt")) 145 | 146 | """data""" 147 | t_his = config.obs_frames 148 | t_pred = config.pred_frames 149 | nk = config.nk 150 | 151 | # generated dataset and dataloader 152 | dataset = self.dataset_custom 153 | dataloader = self.dataloader_custom 154 | 155 | """model""" 156 | optimizer = self.optimizer 157 | refine = self.refine 158 | lr_now = config.refine_lr 159 | 160 | for epoch in range(config.iter_start_refine, config.num_refine_epoch): 161 | 162 | if (epoch + 1) % config.gcn_lr_decay == 0: 163 | lr_now = lr_decay(optimizer, lr_now, config.gcn_lr_gamma) 164 | 165 | t_l = AccumLoss() 166 | refine.train() 167 | 168 | i = 0 169 | st = time.time() 170 | # bar = Bar('>>>', fill='>', max=len(dataloader)) 171 | train_losses = 0 172 | total_num_sample = 0 173 | loss_names = ['TOTAL', 'RECON', 'R1', 'JL'] 174 | 175 | for (inputs, 
targets, all_seqs) in tqdm(dataloader): 176 | 177 | b, f, c = targets.shape 178 | # batch_size = inputs.shape[0] 179 | bt = time.time() 180 | if torch.cuda.is_available(): 181 | inputs = Variable(inputs.cuda()).float() 182 | targets = Variable(targets.cuda()).float() 183 | all_seqs = Variable(all_seqs.cuda()).float() 184 | 185 | X = inputs.reshape(-1, config.dct_n, c) 186 | X = X.permute(0, 2, 1).contiguous() 187 | outputs = refine(X) 188 | 189 | # IDCT 190 | _, idct_m = dutil.dataset_generated_motions.get_dct_matrix(t_his + t_pred) 191 | idct_m = Variable(torch.from_numpy(idct_m)).float().cuda() 192 | outputs_t = outputs.view(-1, config.dct_n).transpose(0, 1) 193 | outputs_g = torch.matmul(idct_m[:, 0:config.dct_n], outputs_t).transpose(0, 1).contiguous().view(-1, c, t_his + t_pred).transpose(1,2) 194 | 195 | Y = targets.permute(1, 0, 2).contiguous() 196 | Y_g = outputs_g.permute(1, 0, 2).contiguous() 197 | loss, losses = loss_function(config, Y_g, Y, device, dtype, all_seqs) 198 | optimizer.zero_grad() 199 | loss.backward() 200 | if config.gcn_max_norm: 201 | nn.utils.clip_grad_norm(refine.parameters(), max_norm=1) 202 | optimizer.step() 203 | 204 | train_losses += losses 205 | total_num_sample += 1 206 | 207 | dt = time.time() - st 208 | train_losses /= total_num_sample 209 | losses_str = ' '.join(['{}: {:.4f}'.format(x, y) for x, y in zip(loss_names, train_losses)]) 210 | logger_refine.info('====> Epoch: {} Time: {:.2f} {} lr: {:.5f}'.format(epoch, dt, losses_str, lr_now)) 211 | for name, loss in zip(loss_names, train_losses): 212 | self.log_writer_refine.add_scalar('refine_' + name, loss, epoch) 213 | 214 | if config.save_model_interval > 0 and (epoch + 1) % config.save_model_interval == 0: 215 | with to_cpu(refine): 216 | cp_path = self.pretrained_model_dir_refine % (epoch + 1) 217 | model_cp = {'model_dict': refine.state_dict()} 218 | pickle.dump(model_cp, open(cp_path, 'wb')) 219 | 220 | 221 | 222 | def eval(self): 223 | device = self.device 224 | dtype = self.dtype 225 | config = self.config 226 | torch.set_grad_enabled(False) 227 | logger_test = create_logger(os.path.join(self.model_dir_log, "log_eval.txt")) 228 | 229 | algos = [] 230 | all_algos = ['refine', 'diff'] 231 | for algo in all_algos: 232 | iter_algo = 'iter_%s' % algo 233 | num_algo = 'eval_at_%s' % algo 234 | setattr(config, iter_algo, getattr(config, num_algo)) 235 | algos.append(algo) 236 | vis_algos = algos.copy() 237 | 238 | # get dataset 239 | dataset = self.dataset_test 240 | 241 | # get models 242 | model_generator = { 243 | 'refine': get_refine_model, 244 | 'diff': get_diff_model 245 | } 246 | models = {} 247 | for algo in all_algos: 248 | models[algo] = model_generator[algo](config, dataset.traj_dim) 249 | if algo == 'diff': 250 | model_path = self.pretrained_model_dir_diff % getattr(config, f'iter_{algo}') 251 | elif algo == 'refine': 252 | model_path = self.pretrained_model_dir_refine % getattr(config, f'iter_{algo}') 253 | print(f'loading {algo} model from checkpoint: {model_path}') 254 | model_cp = pickle.load(open(model_path, "rb")) 255 | models[algo].load_state_dict(model_cp['model_dict']) 256 | models[algo].to(device) 257 | models[algo].eval() 258 | 259 | # visualization or compute statistics 260 | if config.mode_test == 'vis': 261 | visualize(config, models, dataset, self.device, self.dtype, algos, self.dir_out) 262 | elif config.mode_test == 'stats': 263 | compute_stats(config, models, dataset, self.device, self.dtype, vis_algos, logger_test, self.dir_out) 264 | 265 | 266 | 267 | 268 | 269 | def 
_build(self): 270 | self._build_dir() 271 | 272 | if self.config.mode == "train_diff": 273 | self._build_train_loader() 274 | elif self.config.mode == "generate_diff": 275 | self._build_train_loader() 276 | elif self.config.mode == "train_refine": 277 | self._build_custom_loader() 278 | elif self.config.mode == "test": 279 | self._build_val_loader() 280 | 281 | self._build_model() 282 | self._build_optimizer() 283 | 284 | print("> Everything built. Have fun :)") 285 | 286 | 287 | def _build_dir(self): 288 | self.model_dir = osp.join("./results", self.config.dataset) 289 | os.makedirs(self.model_dir, exist_ok=True) 290 | self.log_writer_diff = SummaryWriter(osp.join(self.model_dir, "tb")) if self.config.mode == "train_diff" else None 291 | self.log_writer_refine = SummaryWriter(osp.join(self.model_dir, "tb")) if self.config.mode == "train_refine" else None 292 | self.model_dir_log = osp.join(self.model_dir, "log") 293 | os.makedirs(self.model_dir_log, exist_ok=True) 294 | self.dir_out = osp.join(self.model_dir, "out") 295 | os.makedirs(self.dir_out, exist_ok=True) 296 | self.logger = create_logger(osp.join(self.model_dir_log, "log_diff.txt")) 297 | tmp = osp.join(self.model_dir, "models") 298 | os.makedirs(tmp, exist_ok=True) 299 | self.pretrained_model_dir_diff = osp.join(tmp, "diffMotion_%04d.p") 300 | self.pretrained_model_dir_refine = osp.join(tmp, "refine_%04d.p") 301 | self.generated_motions = osp.join(self.model_dir, "generated_diff") 302 | 303 | print("> Directory built!") 304 | 305 | 306 | def _build_optimizer(self): 307 | if self.config.mode == "train_diff": 308 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr) 309 | self.scheduler = get_scheduler(self.optimizer, policy='lambda', nepoch_fix=self.config.num_diff_epoch_fix, 310 | nepoch=self.config.num_diff_epoch) 311 | elif self.config.mode == "train_refine": 312 | self.optimizer = optim.Adam(self.refine.parameters(), lr=self.config.refine_lr) 313 | 314 | print("> Optimizer built!") 315 | 316 | 317 | def _build_model(self): 318 | """ Define Model """ 319 | config = self.config 320 | if self.config.mode == "train_diff": 321 | model = get_diff_model(config, self.dataset_train.traj_dim) 322 | self.model = model.to(self.device) 323 | print("> Model built!") 324 | 325 | elif self.config.mode == "train_refine": 326 | refine = get_refine_model(config, 48) 327 | self.refine = refine.to(self.device) 328 | print("> Model built!") 329 | 330 | 331 | # loading model from checkpoint 332 | if config.iter_start_diff > 0: 333 | if self.config.mode == "train_diff": 334 | cp_path = self.pretrained_model_dir_diff % config.iter_start_diff 335 | print('loading diff model from checkpoint: %s' % cp_path) 336 | model_cp = pickle.load(open(cp_path, "rb")) 337 | model.load_state_dict(model_cp['model_dict']) 338 | if config.iter_start_refine > 0: 339 | if self.config.mode == "train_refine": 340 | cp_path = self.pretrained_model_dir_refine % config.iter_start_refine 341 | print('loading refine model from checkpoint: %s' % cp_path) 342 | model_cp = pickle.load(open(cp_path, "rb")) 343 | model.load_state_dict(model_cp['model_dict']) 344 | 345 | 346 | 347 | def _build_train_loader(self): 348 | config = self.config 349 | print(">>> loading data...") 350 | 351 | t_his = config.obs_frames 352 | t_pred = config.pred_frames 353 | 354 | dataset_cls = DatasetH36M 355 | dataset = dataset_cls('train', t_his, t_pred, actions='all', use_vel=config.use_vel) 356 | if config.normalize_data: 357 | dataset.normalize_data() 358 | 359 | self.dataset_train = 
dataset 360 | 361 | 362 | def _build_custom_loader(self): 363 | config = self.config 364 | print(">>> loading data...") 365 | 366 | t_his = config.obs_frames 367 | t_pred = config.pred_frames 368 | train_dataset = CustomH36M(config, path_to_data=self.generated_motions, input_n=t_his, output_n=t_pred, dct_used=config.dct_n) 369 | train_loader = DataLoader( 370 | dataset=train_dataset, 371 | batch_size=config.refine_batch_size, 372 | shuffle=True, 373 | num_workers=0, 374 | pin_memory=True) 375 | self.dataset_custom = train_dataset 376 | self.dataloader_custom = train_loader 377 | 378 | 379 | def _build_val_loader(self): 380 | config = self.config 381 | t_his = config.obs_frames 382 | t_pred = config.pred_frames 383 | 384 | dataset_cls = DatasetH36M 385 | dataset = dataset_cls('test', t_his, t_pred, actions='all', use_vel=config.use_vel) 386 | if config.normalize_data: 387 | dataset.normalize_data() 388 | 389 | self.dataset_test = dataset 390 | 391 | print("> Dataset built!") 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 csdwei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MotionDiff 2 | Code for AAAI2023 paper "Human Joint Kinematics Diffusion-Refinement for Stochastic Motion Prediction" 3 | 4 | By Dong Wei, Huaijiang Sun, Bin Li, Jianfeng Lu, Weiqing Li, Xiaoning Sun and Shengxiang Hu 5 | 6 | > Stochastic human motion prediction aims to forecast multiple plausible future motions given a single pose sequence from the past. Most previous works focus on designing elaborate losses to improve the accuracy, while the diversity is typically characterized by randomly sampling a set of latent variables from the latent prior, which is then decoded into possible motions. This joint training of sampling and decoding, however, suffers from posterior collapse as the learned latent variables tend to be ignored by a strong decoder, leading to limited diversity. 
Alternatively, inspired by the diffusion process in nonequilibrium thermodynamics, we propose MotionDiff, a diffusion probabilistic model that treats the kinematics of human joints as heated particles, which diffuse from their original states to a noise distribution. This process not only offers a natural way to obtain the "whitened" latents without any trainable parameters, but also introduces new noise at each diffusion step, both of which facilitate more diverse motions. Human motion prediction is then regarded as the reverse diffusion process that converts the noise distribution into realistic future motions conditioned on the observed sequence. Specifically, MotionDiff consists of two parts: a spatial-temporal transformer-based diffusion network to generate diverse yet plausible motions, and a flexible refinement network that further enables geometric losses and aligns the predictions with the ground truth. Experimental results on two datasets demonstrate that our model yields competitive performance in terms of both diversity and accuracy.
7 | 
8 | 
9 | # Code
10 | 
11 | ## Environment
12 | PyTorch == 1.7.1
13 | CUDA > 10.1
14 | 
15 | 
16 | ## Training & Evaluation
17 | 
18 | ### Step 1: Modify or create your own config file in ```/configs```
19 | 
20 | You can revise the parameters and seeds in the config file as you like, and change the network architecture of the diffusion model in ```models/Diffusion.py```
21 | 
22 | ### Step 2: Train the Diffusion Network
23 | 
24 | ```python main.py --config configs/baseline.yaml --mode train_diff```
25 | 
26 | Logs and checkpoints will be saved automatically (a short sketch of the forward noising step that this stage optimizes is given at the end of this README).
27 | 
28 | ### Step 3: Generate motions with the Diffusion Network
29 | 
30 | ```python main.py --config configs/baseline.yaml --mode generate_diff```
31 | 
32 | Since the sampling process of the diffusion network may take a long time, we generate future motions in advance. The obtained motions will be saved in ```./results/generated_diff```
33 | 
34 | ### Step 4: Train the Refinement Network
35 | 
36 | ```python main.py --config configs/baseline.yaml --mode train_refine```
37 | 
38 | Logs and checkpoints will be saved automatically.
39 | 
40 | ### Step 5: Evaluation
41 | 
42 | ```python main.py --config configs/baseline.yaml --mode test```
43 | 
44 | Evaluates both the diffusion network and the full diffusion-refinement architecture, reporting statistics (APD, ADE, FDE, MMADE, MMFDE) as well as visualizations.
45 | 
46 | ### Citation
47 | ```
48 | @article{wei2022human,
49 | title={Human Joint Kinematics Diffusion-Refinement for Stochastic Motion Prediction},
50 | author={Wei, Dong and Sun, Huaijiang and Li, Bin and Lu, Jianfeng and Li, Weiqing and Sun, Xiaoning and Hu, Shengxiang},
51 | journal={arXiv preprint arXiv:2210.05976},
52 | year={2022}
53 | }
54 | ```
55 | 
56 | ### Acknowledgements
57 | 
58 | Part of our code is borrowed from [DLow](https://github.com/Khrylx/DLow), [PoseFormer](https://github.com/zczcwh/PoseFormer), [LTD-GCN](https://github.com/wei-mao-2019/LearnTrajDep) and [MID](https://github.com/gutianpei/MID). We thank the authors for releasing their code.
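
### Forward diffusion at a glance

The training stage in Step 2 regresses the noise added by the forward diffusion process. The snippet below is a minimal, unofficial sketch of that noising step, mirroring `VarianceSchedule` and `Diffusion.forward` in `models/Diffusion.py` with the linear schedule from `configs/baseline.yaml`; the motion batch and its shape are dummy values for illustration only.

```python
import torch

# Linear beta schedule as in configs/baseline.yaml (num_steps=100, beta_1=1e-4, beta_T=5e-2).
num_steps, beta_1, beta_T = 100, 1.0e-4, 5.0e-2
betas = torch.linspace(beta_1, beta_T, steps=num_steps)
betas = torch.cat([torch.zeros(1), betas])        # pad so index t corresponds to diffusion step t
alpha_bars = torch.cumprod(1.0 - betas, dim=0)    # cumulative product of alphas

# Dummy batch of future motions: (batch, pred_frames, 16 joints * 3 coordinates).
x_0 = torch.randn(64, 100, 48)
t = torch.randint(1, num_steps + 1, (x_0.shape[0],))   # one diffusion step per sample

c0 = torch.sqrt(alpha_bars[t]).view(-1, 1, 1)
c1 = torch.sqrt(1.0 - alpha_bars[t]).view(-1, 1, 1)
eps = torch.randn_like(x_0)                       # regression target for the diffusion network
x_t = c0 * x_0 + c1 * eps                         # noised motion fed to the network together with t
```

The diffusion network is trained to predict `eps` from `x_t`, the step `t`, and the encoded observed sequence; Step 3 then runs this process in reverse to sample future motions.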
-------------------------------------------------------------------------------- /Supplementary Material.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/Supplementary Material.pdf -------------------------------------------------------------------------------- /__pycache__/DiffHM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/__pycache__/DiffHM.cpython-37.pyc -------------------------------------------------------------------------------- /configs/baseline.yaml: -------------------------------------------------------------------------------- 1 | ############ Experimental Settings ######### 2 | mode: train_diff # order: train_diff, generate_diff, train_refine, test 3 | seed: 1 4 | num_seeds: 1 5 | # device 6 | gpu_index: 0 7 | # dataset 8 | dataset: h36m 9 | data_dir: data/h3.6m/dataset 10 | obs_frames: 25 11 | pred_frames: 100 12 | nk: 50 # the number of generated sequences 13 | use_vel: False 14 | normalize_data: False 15 | # model 16 | model_name: MotionDiff 17 | 18 | ############ DiffMotion Config ############# 19 | batch_size: 64 20 | # optimizer 21 | lr: 0.0005 22 | num_diff_epoch_fix: 100 23 | num_diff_epoch: 1000 24 | # Train 25 | save_model_interval: 50 26 | iter_start_diff: 0 27 | num_diff_data_sample: 4000 28 | max_norm: False # whether gradient clipping 29 | num_generate_diff_data_sample: 5000 30 | generate_diff_batch_size: 1 31 | # DDPM Parameters 32 | num_steps: 100 # the number of denoise procedure 33 | beta_1: 1.0e-4 34 | beta_T: 5.0e-2 35 | flexibility: 0.0 36 | ret_traj: False 37 | pose_embed_dim: 32 38 | drop_path_rate: 0.1 39 | drop_rate_poseformer: 0.0 40 | encoder_rnn: False 41 | rnn_type: gru 42 | rnn_output_dim: 512 43 | tf_layer: 3 44 | 45 | ############ Refinement Config ############# 46 | refine_batch_size: 16 47 | # optimizer 48 | refine_lr: 5.0e-4 49 | gcn_lr_decay: 2 50 | gcn_lr_gamma: 0.96 51 | # Train 52 | iter_start_refine: 0 53 | num_refine_epoch: 100 54 | gcn_max_norm: True 55 | # Refine Parameters 56 | dct_n: 80 57 | d_scale: 100000.0 58 | gcn_linear_size: 256 59 | gcn_dropout: 0.5 60 | gcn_layers: 12 61 | lambda_j: 1.0 62 | lambda_recon: 200.0 63 | gamma: 0.01 64 | 65 | # Testing Parameters: 66 | eval_at_diff: 1000 67 | eval_at_refine: 100 68 | multimodal_threshold: 0.5 69 | mode_test: vis # stats or vis 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /dataset/utils/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/dataset/utils/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/utils/__pycache__/dataset_generated_motions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/dataset/utils/__pycache__/dataset_generated_motions.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/utils/__pycache__/dataset_h36m.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/dataset/utils/__pycache__/dataset_h36m.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/utils/__pycache__/skeleton.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/dataset/utils/__pycache__/skeleton.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Dataset: 5 | 6 | def __init__(self, mode, t_his, t_pred, actions='all'): 7 | self.mode = mode 8 | self.t_his = t_his 9 | self.t_pred = t_pred 10 | self.t_total = t_his + t_pred 11 | self.actions = actions 12 | self.prepare_data() 13 | self.std, self.mean = None, None 14 | self.data_len = sum([seq.shape[0] for data_s in self.data.values() for seq in data_s.values()]) 15 | self.traj_dim = (self.kept_joints.shape[0] - 1) * 3 16 | self.normalized = False 17 | # iterator specific 18 | self.sample_ind = None 19 | 20 | def prepare_data(self): 21 | raise NotImplementedError 22 | 23 | def normalize_data(self, mean=None, std=None): 24 | if mean is None: 25 | all_seq = [] 26 | for data_s in self.data.values(): 27 | for seq in data_s.values(): 28 | all_seq.append(seq[:, 1:]) 29 | all_seq = np.concatenate(all_seq) 30 | self.mean = all_seq.mean(axis=0) 31 | self.std = all_seq.std(axis=0) 32 | else: 33 | self.mean = mean 34 | self.std = std 35 | for data_s in self.data.values(): 36 | for action in data_s.keys(): 37 | data_s[action][:, 1:] = (data_s[action][:, 1:] - self.mean) / self.std 38 | self.normalized = True 39 | 40 | def sample(self): 41 | subject = np.random.choice(self.subjects) 42 | dict_s = self.data[subject] 43 | action = np.random.choice(list(dict_s.keys())) 44 | seq = dict_s[action] 45 | fr_start = np.random.randint(seq.shape[0] - self.t_total) 46 | fr_end = fr_start + self.t_total 47 | traj = seq[fr_start: fr_end] 48 | return traj[None, ...] 
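    # sample() above draws a random window of t_his + t_pred consecutive frames from a
    # randomly chosen (subject, action) sequence and returns it with shape
    # (1, t_total, n_joints, 3); sampling_generator() below stacks batch_size such
    # windows into arrays of shape (batch_size, t_total, n_joints, 3).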
49 | 50 | def sampling_generator(self, num_samples=1000, batch_size=8): 51 | for i in range(num_samples // batch_size): 52 | sample = [] 53 | for i in range(batch_size): 54 | sample_i = self.sample() 55 | sample.append(sample_i) 56 | sample = np.concatenate(sample, axis=0) 57 | yield sample 58 | 59 | def iter_generator(self, step=25): 60 | for data_s in self.data.values(): 61 | for seq in data_s.values(): 62 | seq_len = seq.shape[0] 63 | for i in range(0, seq_len - self.t_total, step): 64 | traj = seq[None, i: i + self.t_total] 65 | yield traj 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /dataset/utils/dataset_generated_motions.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import torch 4 | import os 5 | from h5py import File 6 | import scipy.io as sio 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | def get_dct_matrix(N): 11 | dct_m = np.eye(N) 12 | for k in np.arange(N): 13 | for i in np.arange(N): 14 | w = np.sqrt(2 / N) 15 | if k == 0: 16 | w = np.sqrt(1 / N) 17 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N) 18 | idct_m = np.linalg.inv(dct_m) 19 | return dct_m, idct_m 20 | 21 | 22 | class CustomH36M(Dataset): 23 | 24 | def __init__(self, config, path_to_data, input_n=25, output_n=100, dct_used=None): 25 | """ 26 | :param path_to_data: 27 | :param input_n: 28 | :param output_n: 29 | """ 30 | self.path_to_data = path_to_data 31 | self.input_n = input_n 32 | self.output_n = output_n 33 | 34 | if dct_used is None: 35 | dct_used = input_n + output_n 36 | 37 | # load generated motions 38 | data_diff = os.path.join(path_to_data, 'generate_diversity_diff.pth') 39 | data_gt = os.path.join(path_to_data, 'generate_diversity_gt.pth') 40 | data = torch.load(data_diff) 41 | data_gt = torch.load(data_gt) 42 | num_coordinate = data_gt.shape[2] 43 | 44 | self.all_seqs = data 45 | data = data.reshape(-1, input_n + output_n, data.shape[-1]) 46 | data = data.permute(0, 2, 1) 47 | data = data.reshape(-1, input_n + output_n) 48 | all_seqs = data.transpose(0, 1) 49 | 50 | 51 | 52 | dct_m_in, _ = get_dct_matrix(input_n + output_n) 53 | input_dct_seq = np.matmul(dct_m_in[0:dct_used, :], all_seqs) 54 | input_dct_seq = input_dct_seq.transpose(0, 1).reshape([-1, config.nk, num_coordinate, dct_used]) 55 | 56 | output_seq = data_gt 57 | 58 | self.input_dct_seq = input_dct_seq.permute(0, 1, 3, 2).contiguous() 59 | self.output_seq = output_seq 60 | 61 | 62 | 63 | def __len__(self): 64 | return np.shape(self.input_dct_seq)[0] 65 | 66 | def __getitem__(self, item): 67 | return self.input_dct_seq[item], self.output_seq[item], self.all_seqs[item] 68 | -------------------------------------------------------------------------------- /dataset/utils/dataset_h36m.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from dataset.utils.dataset import Dataset 4 | from dataset.utils.skeleton import Skeleton 5 | 6 | 7 | class DatasetH36M(Dataset): 8 | 9 | def __init__(self, mode, t_his=25, t_pred=100, actions='all', use_vel=False): 10 | self.use_vel = use_vel 11 | super().__init__(mode, t_his, t_pred, actions) 12 | if use_vel: 13 | self.traj_dim += 3 14 | 15 | def prepare_data(self): 16 | self.data_file = os.path.join('data\h3.6m\dataset', 'data_3d_h36m.npz') 17 | self.subjects_split = {'train': [1, 5, 6, 7, 8], 18 | 'test': [9, 11]} 19 | self.subjects = ['S%d' % x for x in 
self.subjects_split[self.mode]] 20 | self.skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 21 | 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], 22 | joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], 23 | joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) 24 | self.removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31} 25 | self.kept_joints = np.array([x for x in range(32) if x not in self.removed_joints]) 26 | self.skeleton.remove_joints(self.removed_joints) 27 | self.skeleton._parents[11] = 8 28 | self.skeleton._parents[14] = 8 29 | self.process_data() 30 | 31 | def process_data(self): 32 | data_o = np.load(self.data_file, allow_pickle=True)['positions_3d'].item() 33 | data_f = dict(filter(lambda x: x[0] in self.subjects, data_o.items())) 34 | if self.actions != 'all': 35 | for key in list(data_f.keys()): 36 | data_f[key] = dict(filter(lambda x: all([a in x[0] for a in self.actions]), data_f[key].items())) 37 | if len(data_f[key]) == 0: 38 | data_f.pop(key) 39 | for data_s in data_f.values(): 40 | for action in data_s.keys(): 41 | seq = data_s[action][:, self.kept_joints, :] 42 | if self.use_vel: 43 | v = (np.diff(seq[:, :1], axis=0) * 50).clip(-5.0, 5.0) 44 | v = np.append(v, v[[-1]], axis=0) 45 | seq[:, 1:] -= seq[:, :1] 46 | if self.use_vel: 47 | seq = np.concatenate((seq, v), axis=1) 48 | data_s[action] = seq 49 | self.data = data_f 50 | 51 | 52 | if __name__ == '__main__': 53 | np.random.seed(0) 54 | actions = {'WalkDog'} 55 | dataset = DatasetH36M('train', actions=actions) 56 | generator = dataset.sampling_generator() 57 | dataset.normalize_data() 58 | # generator = dataset.iter_generator() 59 | for data in generator: 60 | print(data.shape) 61 | 62 | 63 | -------------------------------------------------------------------------------- /dataset/utils/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | 11 | class Skeleton: 12 | def __init__(self, parents, joints_left, joints_right): 13 | assert len(joints_left) == len(joints_right) 14 | 15 | self._parents = np.array(parents) 16 | self._joints_left = joints_left 17 | self._joints_right = joints_right 18 | self._compute_metadata() 19 | 20 | def num_joints(self): 21 | return len(self._parents) 22 | 23 | def parents(self): 24 | return self._parents 25 | 26 | def has_children(self): 27 | return self._has_children 28 | 29 | def children(self): 30 | return self._children 31 | 32 | def remove_joints(self, joints_to_remove): 33 | """ 34 | Remove the joints specified in 'joints_to_remove'. 
35 | """ 36 | valid_joints = [] 37 | for joint in range(len(self._parents)): 38 | if joint not in joints_to_remove: 39 | valid_joints.append(joint) 40 | 41 | for i in range(len(self._parents)): 42 | while self._parents[i] in joints_to_remove: 43 | self._parents[i] = self._parents[self._parents[i]] 44 | 45 | index_offsets = np.zeros(len(self._parents), dtype=int) 46 | new_parents = [] 47 | for i, parent in enumerate(self._parents): 48 | if i not in joints_to_remove: 49 | new_parents.append(parent - index_offsets[parent]) 50 | else: 51 | index_offsets[i:] += 1 52 | self._parents = np.array(new_parents) 53 | 54 | 55 | if self._joints_left is not None: 56 | new_joints_left = [] 57 | for joint in self._joints_left: 58 | if joint in valid_joints: 59 | new_joints_left.append(joint - index_offsets[joint]) 60 | self._joints_left = new_joints_left 61 | if self._joints_right is not None: 62 | new_joints_right = [] 63 | for joint in self._joints_right: 64 | if joint in valid_joints: 65 | new_joints_right.append(joint - index_offsets[joint]) 66 | self._joints_right = new_joints_right 67 | 68 | self._compute_metadata() 69 | 70 | return valid_joints 71 | 72 | def joints_left(self): 73 | return self._joints_left 74 | 75 | def joints_right(self): 76 | return self._joints_right 77 | 78 | def _compute_metadata(self): 79 | self._has_children = np.zeros(len(self._parents)).astype(bool) 80 | for i, parent in enumerate(self._parents): 81 | if parent != -1: 82 | self._has_children[parent] = True 83 | 84 | self._children = [] 85 | for i, parent in enumerate(self._parents): 86 | self._children.append([]) 87 | for i, parent in enumerate(self._parents): 88 | if parent != -1: 89 | self._children[parent].append(i) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from DiffHM import DiffHM 2 | import argparse 3 | import os 4 | import yaml 5 | # from pprint import pprint 6 | from easydict import EasyDict 7 | import numpy as np 8 | import torch 9 | 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Pytorch implementation of MotionDiff') 15 | parser.add_argument('--config', default='configs/baseline.yaml') 16 | return parser.parse_args() 17 | 18 | 19 | def main(): 20 | 21 | # parse arguments and load config 22 | args = parse_args() 23 | 24 | with open(args.config) as f: 25 | config = yaml.safe_load(f) 26 | 27 | for k, v in vars(args).items(): 28 | config[k] = v 29 | 30 | config = EasyDict(config) 31 | 32 | """setup""" 33 | np.random.seed(config.seed) 34 | torch.manual_seed(config.seed) 35 | 36 | 37 | agent = DiffHM(config) 38 | 39 | if config["mode"] == 'train_diff': 40 | agent.train_diff() 41 | elif config["mode"] == 'train_refine': 42 | agent.train_refine() 43 | elif config["mode"] == 'generate_diff': 44 | agent.generate_diff() 45 | else: 46 | agent.eval() 47 | 48 | 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /models/Diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch 6 | import math 7 | import numpy as np 8 | 9 | 10 | 11 | class VarianceSchedule(nn.Module): 12 | 13 | def __init__(self, num_steps, mode_beta='linear', beta_1=1e-4, beta_T=5e-2, cosine_s=8e-3): 14 | super().__init__() 15 | assert mode_beta in ('linear', 
'cosine') 16 | self.num_steps = num_steps 17 | self.beta_1 = beta_1 18 | self.beta_T = beta_T 19 | self.mode = mode_beta 20 | 21 | if mode_beta == 'linear': 22 | betas = torch.linspace(beta_1, beta_T, steps=num_steps) 23 | elif mode_beta == 'cosine': 24 | timesteps = ( 25 | torch.arange(num_steps + 1) / num_steps + cosine_s 26 | ) 27 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 28 | alphas = torch.cos(alphas).pow(2) 29 | alphas = alphas / alphas[0] 30 | betas = 1 - alphas[1:] / alphas[:-1] 31 | betas = betas.clamp(max=0.999) 32 | 33 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding 34 | 35 | alphas = 1 - betas 36 | log_alphas = torch.log(alphas) 37 | for i in range(1, log_alphas.size(0)): # 1 to T 38 | log_alphas[i] += log_alphas[i - 1] 39 | alpha_bars = log_alphas.exp() 40 | 41 | sigmas_flex = torch.sqrt(betas) 42 | sigmas_inflex = torch.zeros_like(sigmas_flex) 43 | for i in range(1, sigmas_flex.size(0)): 44 | sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] 45 | sigmas_inflex = torch.sqrt(sigmas_inflex) 46 | 47 | self.register_buffer('betas', betas) 48 | self.register_buffer('alphas', alphas) 49 | self.register_buffer('alpha_bars', alpha_bars) 50 | self.register_buffer('sigmas_flex', sigmas_flex) 51 | self.register_buffer('sigmas_inflex', sigmas_inflex) 52 | 53 | 54 | def uniform_sample_t(self, batch_size): 55 | ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size) 56 | return ts.tolist() 57 | 58 | def get_sigmas(self, t, flexibility): 59 | assert 0 <= flexibility and flexibility <= 1 60 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility) 61 | return sigmas 62 | 63 | 64 | 65 | class Diffusion(nn.Module): 66 | def __init__(self, config, num_joint): 67 | super().__init__() 68 | self.num_steps = config.num_steps 69 | self.beta_1 = config.beta_1 70 | self.beta_T = config.beta_T 71 | self.num_joint = num_joint 72 | self.config = config 73 | self.var_sched = VarianceSchedule(num_steps=self.num_steps, mode_beta='linear', beta_1=self.beta_1, 74 | beta_T=self.beta_T, cosine_s=8e-3) 75 | 76 | 77 | def sample(self, net, encoded_x, flexibility=0.0, ret_traj=False): 78 | num_sample = encoded_x.shape[0] 79 | t_pred = self.config.pred_frames 80 | dim_each_frame = 3 * self.num_joint 81 | # start from standard Gaussian noise 82 | x_T = torch.randn([num_sample, t_pred, dim_each_frame]).cuda() 83 | traj = {self.var_sched.num_steps: x_T} 84 | 85 | for t in range(self.var_sched.num_steps, 0, -1): 86 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 87 | alpha = self.var_sched.alphas[t] 88 | alpha_bar = self.var_sched.alpha_bars[t] 89 | sigma = self.var_sched.get_sigmas(t, flexibility) 90 | 91 | c0 = 1.0 / torch.sqrt(alpha) 92 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 93 | 94 | x_t = traj[t] 95 | beta = self.var_sched.betas[[t] * num_sample] 96 | e_theta = net(encoded_x, x_t, beta) 97 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 98 | traj[t - 1] = x_next.detach() # Stop gradient and save trajectory. 99 | # traj[t] = traj[t].cpu() # Move previous output to CPU memory. 
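            # Each pass through this loop applies the standard DDPM reverse update computed above:
            #   x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * e_theta) / sqrt(alpha_t) + sigma_t * z,
            # where e_theta is the noise predicted by the network conditioned on the encoded history,
            # so traj[0] finally holds the denoised future motion.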
100 | # if not ret_traj: 101 | # del traj[t] 102 | 103 | if ret_traj: 104 | return traj 105 | else: 106 | return traj[0] 107 | 108 | 109 | 110 | def forward(self, x_0, t=None): 111 | batch_size, _, _ = x_0.size() 112 | if t == None: 113 | t = self.var_sched.uniform_sample_t(batch_size) 114 | 115 | alpha_bar = self.var_sched.alpha_bars[t] 116 | beta = self.var_sched.betas[t].cuda() 117 | 118 | c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (B, 1, 1) 119 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (B, 1, 1) 120 | 121 | e_rand = torch.randn_like(x_0).cuda() 122 | x_T = c0 * x_0 + c1 * e_rand 123 | x_T = x_T.cuda() 124 | 125 | return x_T, e_rand, beta 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /models/MotionDiff.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch 6 | from einops.layers.torch import Rearrange 7 | from models.PoseFormer import PoseFormer 8 | import math 9 | import torch.nn.functional as F 10 | 11 | 12 | class PositionalEncoding(nn.Module): 13 | def __init__(self, d_model, dropout=0.1, max_len=5000): 14 | super().__init__() 15 | 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | pe = torch.zeros(max_len, d_model) 19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 20 | div_term = torch.exp( 21 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 22 | ) 23 | pe[:, 0::2] = torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | pe = pe.unsqueeze(0).transpose(0, 1) 26 | self.register_buffer("pe", pe) 27 | 28 | def forward(self, x): 29 | x = x + self.pe[: x.size(0), :] 30 | return self.dropout(x) 31 | 32 | 33 | class ConcatSquashLinear(nn.Module): 34 | def __init__(self, dim_in, dim_out, dim_ctx): 35 | super(ConcatSquashLinear, self).__init__() 36 | self._layer = nn.Linear(dim_in, dim_out) 37 | self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False) 38 | self._hyper_gate = nn.Linear(dim_ctx, dim_out) 39 | 40 | def forward(self, ctx, x): 41 | gate = torch.sigmoid(self._hyper_gate(ctx)) 42 | bias = self._hyper_bias(ctx) 43 | # if x.dim() == 3: 44 | # gate = gate.unsqueeze(1) 45 | # bias = bias.unsqueeze(1) 46 | ret = self._layer(x) * gate + bias 47 | return ret 48 | 49 | 50 | 51 | class MotionDiff(nn.Module): 52 | def __init__(self, config, num_joint): 53 | super().__init__() 54 | self.dct_n = config.dct_n 55 | self.act = F.leaky_relu 56 | self.pose_embed_dim = config.pose_embed_dim 57 | self.rnn_output_dim = config.rnn_output_dim 58 | self.num_joint = num_joint 59 | self.concat1 = ConcatSquashLinear(self.pose_embed_dim * self.num_joint, self.rnn_output_dim, self.rnn_output_dim + 3) 60 | # concat 61 | self.concat2 = ConcatSquashLinear(self.rnn_output_dim, self.rnn_output_dim // 2, self.rnn_output_dim + 3) 62 | self.concat3 = ConcatSquashLinear(self.rnn_output_dim // 2, self.rnn_output_dim // 4, self.rnn_output_dim + 3) 63 | self.concat4 = ConcatSquashLinear(self.rnn_output_dim // 4, 3 * self.num_joint, self.rnn_output_dim + 3) 64 | # encoder 65 | self.poseformer = PoseFormer(config, num_joint=self.num_joint, in_chans=3, num_frame=config.pred_frames, embed_dim=config.pose_embed_dim, 66 | drop_rate=config.drop_rate_poseformer, drop_path_rate=config.drop_path_rate, norm_layer=None) 67 | # decoder 68 | self.pos_emb = 
PositionalEncoding(d_model=self.rnn_output_dim, dropout=0.1, max_len=200) 69 | self.layer = nn.TransformerEncoderLayer(d_model=self.rnn_output_dim, nhead=4, dim_feedforward=2*self.rnn_output_dim) 70 | self.transformer_encoder = nn.TransformerEncoder(self.layer, num_layers=config.tf_layer) 71 | 72 | 73 | def forward(self, context, x, beta): 74 | batch_size = x.size(0) 75 | beta = beta.view(batch_size, 1, 1) # (B, 1, 1) 76 | context = context.view(batch_size, 1, -1) # (B, 1, F) 77 | 78 | time_emb = torch.cat([beta, torch.sin(beta), torch.cos(beta)], dim=-1) # (B, 1, 3) 79 | ctx_emb = torch.cat([time_emb, context], dim=-1) # (B, 1, F+3) 80 | 81 | x = self.poseformer(x) 82 | x = self.concat1(ctx_emb, x) 83 | 84 | # Transformer Decoder 85 | final_emb = x.permute(1, 0, 2) 86 | final_emb = self.pos_emb(final_emb) 87 | x = self.transformer_encoder(final_emb).permute(1, 0, 2) 88 | 89 | # concat 90 | x = self.concat2(ctx_emb, x) 91 | x = self.act(x) 92 | x = self.concat3(ctx_emb, x) 93 | x = self.act(x) 94 | x = self.concat4(ctx_emb, x) 95 | 96 | return x 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /models/PoseFormer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch 6 | from timm.models.layers import DropPath 7 | from functools import partial 8 | from einops import rearrange 9 | from models.rnn import RNN 10 | 11 | 12 | 13 | class Mlp(nn.Module): 14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = hidden_features or in_features 18 | self.fc1 = nn.Linear(in_features, hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | 23 | def forward(self, x): 24 | x = self.fc1(x) 25 | x = self.act(x) 26 | x = self.drop(x) 27 | x = self.fc2(x) 28 | x = self.drop(x) 29 | return x 30 | 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 35 | super().__init__() 36 | self.num_heads = num_heads 37 | head_dim = dim // num_heads 38 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 39 | self.scale = qk_scale or head_dim ** -0.5 40 | 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | 46 | def forward(self, x): 47 | B, N, C = x.shape 48 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 49 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 50 | 51 | attn = (q @ k.transpose(-2, -1)) * self.scale 52 | attn = attn.softmax(dim=-1) 53 | attn = self.attn_drop(attn) 54 | 55 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 56 | x = self.proj(x) 57 | x = self.proj_drop(x) 58 | return x 59 | 60 | 61 | 62 | class Block(nn.Module): 63 | 64 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 65 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 66 | super().__init__() 67 | self.norm1 = norm_layer(dim) 68 | self.attn = Attention( 69 | dim, num_heads=num_heads, qkv_bias=qkv_bias, 
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 70 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 71 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 72 | self.norm2 = norm_layer(dim) 73 | mlp_hidden_dim = int(dim * mlp_ratio) 74 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 75 | 76 | def forward(self, x): 77 | x = x + self.drop_path(self.attn(self.norm1(x))) 78 | x = x + self.drop_path(self.mlp(self.norm2(x))) 79 | return x 80 | 81 | 82 | 83 | 84 | 85 | class PoseFormer(nn.Module): 86 | def __init__(self, config, num_joint=16, in_chans=3, num_frame=100, embed_dim=16, depth=4, num_heads=8, mlp_ratio=2., 87 | qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): 88 | super(PoseFormer, self).__init__() 89 | 90 | # poseformer 91 | self.embed_dim = embed_dim 92 | self.num_joint = num_joint 93 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 94 | self.joint_embedding_his = nn.Linear(in_chans, embed_dim) 95 | self.joint_embedding_pred = nn.Linear(in_chans, embed_dim) 96 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joint, embed_dim)) 97 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim * self.num_joint)) 98 | self.pos_drop = nn.Dropout(p=drop_rate) 99 | # rnn 100 | self.x_birnn = config.encoder_rnn 101 | self.rnn_type = config.rnn_type 102 | self.rnn_input_dim = num_joint * embed_dim 103 | self.rnn_output_dim = config.rnn_output_dim 104 | self.x_rnn = RNN(input_dim=self.rnn_input_dim, out_dim=self.rnn_output_dim, bi_dir=self.x_birnn, cell_type=self.rnn_type) 105 | 106 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 107 | 108 | self.Spatial_blocks = nn.ModuleList([ 109 | Block( 110 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 111 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 112 | for i in range(depth)]) 113 | 114 | self.blocks = nn.ModuleList([ 115 | Block( 116 | dim=embed_dim * num_joint, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 117 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 118 | for i in range(depth)]) 119 | 120 | self.Spatial_norm = norm_layer(embed_dim) 121 | self.Temporal_norm = norm_layer(embed_dim * self.num_joint) 122 | # self.weighted_mean = torch.nn.Conv1d(in_channels=num_joint, out_channels=1, kernel_size=1) 123 | 124 | 125 | def SpatialTrans(self, x, f): 126 | x += self.Spatial_pos_embed 127 | x = self.pos_drop(x) 128 | 129 | for blk in self.Spatial_blocks: 130 | x = blk(x) 131 | 132 | x = self.Spatial_norm(x) 133 | x = rearrange(x, '(b f) w c -> b f (w c)', f=f) 134 | return x 135 | 136 | def r_encode(self, x): 137 | if self.x_birnn: 138 | h_x = self.x_rnn(x).mean(dim=0) 139 | else: 140 | h_x = self.x_rnn(x)[-1] 141 | return h_x 142 | 143 | 144 | def encode_his(self, x): 145 | b, f, p, c = x.shape ##### b is batch size, f is number of frames, p is number of joints, c is dimension of each joint 146 | x = x.permute(0, 3, 1, 2) 147 | x = rearrange(x, 'b c f p -> (b f) p c', ) 148 | x = self.joint_embedding_his(x) 149 | x = self.SpatialTrans(x, f) 150 | # ####### A easy way to implement weighted mean (batch, joints, d) --> (batch, 1, d) 151 | # x = self.weighted_mean(x) 152 | # x = x.view(-1, self.embed_dim) 153 | ######## rnn to obtain the 
first predicted frame 154 | x = x.permute(1, 0, 2) 155 | x = self.r_encode(x) # (b, d) 156 | return x 157 | 158 | 159 | def TemporalAttention(self, x): 160 | b = x.shape[0] 161 | x += self.Temporal_pos_embed 162 | x = self.pos_drop(x) 163 | for blk in self.blocks: 164 | x = blk(x) 165 | 166 | x = self.Temporal_norm(x) 167 | return x 168 | 169 | 170 | def forward(self, y): 171 | x = y.reshape(y.shape[0], y.shape[1], self.num_joint, -1) 172 | ##### b is batch size, f is number of frames, p is number of joints, c is dimension of each joint 173 | x = x.permute(0, 3, 1, 2) 174 | b, _, f, p = x.shape 175 | x = rearrange(x, 'b c f p -> (b f) p c', ) 176 | x = self.joint_embedding_pred(x) 177 | x = self.SpatialTrans(x, f) 178 | x = self.TemporalAttention(x) 179 | return x 180 | 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /models/__pycache__/Diffusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/Diffusion.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/MotionDiff.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/MotionDiff.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/PoseFormer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/PoseFormer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mao_gcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/mao_gcn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/motion_pred.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/motion_pred.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/rnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/models/__pycache__/rnn.cpython-37.pyc -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.spatial.distance import pdist, squareform 4 | from visualization.visualization import * 5 | from utils import * 
6 | from utils.logger import AverageMeter 7 | import csv 8 | import pickle 9 | import os.path as osp 10 | from torch.nn import functional as F 11 | import dataset.utils as dutil 12 | from torch.autograd import Variable 13 | 14 | ########################################################################################################### 15 | ########################################################################################################### 16 | ################################ Visualization ###################################################### 17 | ########################################################################################################### 18 | ########################################################################################################### 19 | 20 | def denomarlize(dataset, *data): 21 | out = [] 22 | for x in data: 23 | x = x * dataset.std + dataset.mean 24 | out.append(x) 25 | return out 26 | 27 | 28 | def get_prediction(config, models, data, algo, sample_num, device, dtype, num_seeds=1, concat_hist=True): 29 | t_his = config.obs_frames 30 | t_pred = config.pred_frames 31 | traj_np = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1) 32 | traj = torch.tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous() 33 | X = traj[:t_his] 34 | 35 | if algo == 'diff': 36 | X = X.repeat((1, sample_num, 1)) 37 | Y = models[algo].generate(X) 38 | elif algo == 'refine': 39 | X = X.repeat((1, sample_num, 1)) 40 | Y = models['diff'].generate(X) 41 | Y = torch.cat((X, Y), dim=0) 42 | # DCT 43 | c = Y.shape[-1] # 48 44 | Y = Y.permute(1, 2, 0) 45 | Y = Y.reshape(-1, t_his + t_pred) 46 | Y = Y.transpose(0, 1) 47 | dct_m_in, _ = dutil.dataset_generated_motions.get_dct_matrix(t_his + t_pred) 48 | dct_m_in = Variable(torch.from_numpy(dct_m_in)).float().cuda() 49 | input_dct_seq = torch.matmul(dct_m_in[0 : config.dct_n, :], Y) 50 | input_dct_seq = input_dct_seq.transpose(0, 1).reshape([config.nk, c, config.dct_n]) 51 | outputs = models[algo](input_dct_seq) 52 | # IDCT 53 | _, idct_m = dutil.dataset_generated_motions.get_dct_matrix(t_his + t_pred) 54 | idct_m = Variable(torch.from_numpy(idct_m)).float().cuda() 55 | outputs_t = outputs.view(-1, config.dct_n).transpose(0, 1) 56 | Y = torch.matmul(idct_m[:, 0:config.dct_n], outputs_t).transpose(0, 1).contiguous().view(-1, outputs.shape[1], t_his + t_pred).transpose(1, 2) 57 | Y = Y.permute(1, 0, 2).contiguous() 58 | Y = Y[t_his:] 59 | 60 | if concat_hist: 61 | Y = torch.cat((X, Y), dim=0) 62 | Y = Y.permute(1, 0, 2).contiguous().cpu().numpy() 63 | if Y.shape[0] > 1: 64 | Y = Y.reshape(-1, sample_num, Y.shape[-2], Y.shape[-1]) 65 | else: 66 | Y = Y[None, ...] 
67 | return Y 68 | 69 | 70 | 71 | def visualize(config, model, dataset, device, dtype, algos, out_path): 72 | 73 | def post_process(config, pred, data): 74 | pred = pred.reshape(pred.shape[0], pred.shape[1], -1, 3) 75 | if config.normalize_data: 76 | pred = denomarlize(dataset, pred) 77 | pred = np.concatenate((np.tile(data[..., :1, :], (pred.shape[0], 1, 1, 1)), pred), axis=2) 78 | pred[..., :1, :] = 0 79 | return pred 80 | 81 | def pose_generator(config, model, dataset, device, dtype): 82 | 83 | while True: 84 | data = dataset.sample() 85 | 86 | # gt 87 | gt = data[0].copy() 88 | gt[:, :1, :] = 0 89 | poses = {'context': gt, 'gt': gt} 90 | # vae 91 | for algo in vis_algos: 92 | pred = get_prediction(config, model, data, algo, config.nk, device, dtype)[0] 93 | pred = post_process(config, pred, data) 94 | for i in range(pred.shape[0]): 95 | poses[f'{algo}_{i}'] = pred[i] 96 | 97 | yield poses 98 | 99 | vis_algos = algos 100 | t_his = config.obs_frames 101 | # t_pred = config.pred_frames 102 | pose_gen = pose_generator(config, model, dataset, device, dtype) 103 | out = osp.join(out_path, 'video.mo4') 104 | render_animation(dataset.skeleton, pose_gen, vis_algos, t_his, ncol=12, output=out) 105 | 106 | 107 | 108 | 109 | ########################################################################################################### 110 | ########################################################################################################### 111 | ################################### Statistics ###################################################### 112 | ########################################################################################################### 113 | ########################################################################################################### 114 | 115 | def get_gt(data, t_his): 116 | gt = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1) 117 | return gt[:, t_his:, :] 118 | 119 | 120 | def get_multimodal_gt(config, dataset_test, logger_test): 121 | all_data = [] 122 | t_his = config.obs_frames 123 | t_pred = config.pred_frames 124 | data_gen = dataset_test.iter_generator(t_his) 125 | for data in data_gen: 126 | data = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1) 127 | all_data.append(data) 128 | all_data = np.concatenate(all_data, axis=0) 129 | all_start_pose = all_data[:, t_his - 1, :] 130 | pd = squareform(pdist(all_start_pose)) 131 | traj_gt_arr = [] 132 | num_mult = [] 133 | for i in range(pd.shape[0]): 134 | ind = np.nonzero(pd[i] < config.multimodal_threshold) 135 | traj_gt_arr.append(all_data[ind][:, t_his:, :]) 136 | num_mult.append(len(ind[0])) 137 | 138 | # num_mult = np.array(num_mult) 139 | # logger_test.info('') 140 | # logger_test.info('') 141 | # logger_test.info('=' * 80) 142 | # logger_test.info(f'#1 future: {len(np.where(num_mult == 1)[0])}/{pd.shape[0]}') 143 | # logger_test.info(f'#<10 future: {len(np.where(num_mult < 10)[0])}/{pd.shape[0]}') 144 | return traj_gt_arr 145 | 146 | 147 | """metrics""" 148 | 149 | def compute_diversity(pred, *args): 150 | if pred.shape[0] == 1: 151 | return 0.0 152 | dist = pdist(pred.reshape(pred.shape[0], -1)) 153 | a, idx1 = torch.sort(torch.tensor(dist), descending=True) 154 | diversity = a[:50].mean().item() 155 | # diversity = dist.mean().item() 156 | return diversity 157 | 158 | 159 | def compute_ade(pred, gt, *args): 160 | diff = pred - gt 161 | dist = np.linalg.norm(diff, axis=2).mean(axis=1) 162 | return dist.min() 163 | 164 | 165 | def compute_fde(pred, gt, *args): 166 | diff = pred - 
gt 167 | dist = np.linalg.norm(diff, axis=2)[:, -1] 168 | return dist.min() 169 | 170 | 171 | def compute_mmade(pred, gt, gt_multi): 172 | gt_dist = [] 173 | for gt_multi_i in gt_multi: 174 | dist = compute_ade(pred, gt_multi_i) 175 | gt_dist.append(dist) 176 | gt_dist = np.array(gt_dist).mean() 177 | return gt_dist 178 | 179 | 180 | def compute_mmfde(pred, gt, gt_multi): 181 | gt_dist = [] 182 | for gt_multi_i in gt_multi: 183 | dist = compute_fde(pred, gt_multi_i) 184 | gt_dist.append(dist) 185 | gt_dist = np.array(gt_dist).mean() 186 | return gt_dist 187 | 188 | 189 | 190 | def compute_stats(config, model, dataset, device, dtype, algos, logger_test, out_path): 191 | stats_algos = algos 192 | t_his = config.obs_frames 193 | # t_pred = config.pred_frames 194 | num_seeds = config.num_seeds 195 | 196 | stats_func = {'Diversity': compute_diversity, 'ADE': compute_ade, 197 | 'FDE': compute_fde, 'MMADE': compute_mmade, 'MMFDE': compute_mmfde} 198 | stats_names = list(stats_func.keys()) 199 | stats_meter = {x: {y: AverageMeter() for y in stats_algos} for x in stats_names} 200 | # generate multi-modal ground truth (only in test stage) 201 | traj_gt_arr = get_multimodal_gt(config, dataset, logger_test) 202 | 203 | data_gen = dataset.iter_generator(step=t_his) 204 | num_samples = 0 205 | for i, data in enumerate(data_gen): 206 | num_samples += 1 207 | gt = get_gt(data, t_his) 208 | gt_multi = traj_gt_arr[i] 209 | for algo in stats_algos: 210 | pred = get_prediction(config, model, data, algo, sample_num=config.nk, device=device, dtype=dtype, num_seeds=num_seeds, concat_hist=False) 211 | for stats in stats_names: 212 | val = 0 213 | for pred_i in pred: 214 | val += stats_func[stats](pred_i, gt, gt_multi) / num_seeds 215 | stats_meter[stats][algo].update(val) 216 | print('-' * 80) 217 | for stats in stats_names: 218 | str_stats = f'{num_samples:04d} {stats}: ' + ' '.join([f'{x}: {y.val:.4f}({y.avg:.4f})' for x, y in stats_meter[stats].items()]) 219 | print(str_stats) 220 | 221 | logger_test.info('=' * 80) 222 | for stats in stats_names: 223 | str_stats = f'Total {stats}: ' + ' '.join([f'{x}: {y.avg:.4f}' for x, y in stats_meter[stats].items()]) 224 | logger_test.info(str_stats) 225 | logger_test.info('=' * 80) 226 | 227 | with open('%s/stats_%s.csv' % (out_path, config.nk), 'w') as csv_file: 228 | writer = csv.DictWriter(csv_file, fieldnames=['Metric'] + algos) 229 | writer.writeheader() 230 | for stats, meter in stats_meter.items(): 231 | new_meter = {x: y.avg for x, y in meter.items()} 232 | new_meter['Metric'] = stats 233 | writer.writerow(new_meter) 234 | 235 | 236 | 237 | 238 | ######################################################################################### 239 | ########################################################################################## 240 | def lr_decay(optimizer, lr_now, gamma): 241 | lr = lr_now * gamma 242 | for param_group in optimizer.param_groups: 243 | param_group['lr'] = lr 244 | return lr 245 | 246 | 247 | class AccumLoss(object): 248 | def __init__(self): 249 | self.val = 0 250 | self.avg = 0 251 | self.sum = 0 252 | self.count = 0 253 | 254 | def update(self, val, n=1): 255 | self.val = val 256 | self.sum += val 257 | self.count += n 258 | self.avg = self.sum / self.count 259 | 260 | 261 | def loss_function(config, Y_g, Y, device, dtype, X): 262 | t_his = config.obs_frames 263 | # loss 264 | JL = joint_loss(config, Y_g) if config.lambda_j > 0 else 0.0 265 | RECON = recon_loss(config, Y_g, Y) if config.lambda_recon > 0 else 0.0 266 | X = X.reshape(-1, 
X.shape[2], X.shape[3]).permute(1, 0, 2) 267 | loss_r = RECON * config.lambda_recon + JL * config.lambda_j 268 | return loss_r, np.array([loss_r.item()]) 269 | 270 | 271 | def joint_loss(config, Y_g): 272 | loss = 0.0 273 | Y_g = Y_g.permute(1, 0, 2).contiguous() 274 | Y_g = Y_g.view(Y_g.shape[0] // config.nk, config.nk, -1) 275 | for Y in Y_g: 276 | dist = F.pdist(Y, 2) ** 2 277 | loss += (-dist / config.d_scale).exp().mean() 278 | loss /= Y_g.shape[0] 279 | return loss 280 | 281 | 282 | def recon_loss(config, Y_g, Y): 283 | Y_g = Y_g.view(Y_g.shape[0], -1, config.nk, Y_g.shape[2]) 284 | diff = Y_g - Y.unsqueeze(2) 285 | dist = diff.pow(2).sum(dim=-1).sum(dim=0) 286 | loss_recon = dist.min(dim=1)[0].mean() 287 | return loss_recon / 100000.0 288 | 289 | 290 | 291 | 292 | -------------------------------------------------------------------------------- /models/mao_gcn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import 4 | from __future__ import print_function 5 | 6 | import torch.nn as nn 7 | import torch 8 | from torch.nn.parameter import Parameter 9 | import math 10 | 11 | 12 | class GraphConvolution(nn.Module): 13 | """ 14 | adapted from : https://github.com/tkipf/gcn/blob/92600c39797c2bfb61a508e52b88fb554df30177/gcn/layers.py#L132 15 | """ 16 | 17 | def __init__(self, in_features, out_features, bias=True, node_n=48): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) # torch.nn.Parameter marks the tensor as a trainable parameter (requires_grad defaults to True), whereas a plain torch.tensor defaults to requires_grad=False 22 | self.att = Parameter(torch.FloatTensor(node_n, node_n)) 23 | if bias: 24 | self.bias = Parameter(torch.FloatTensor(out_features)) 25 | else: 26 | self.register_parameter('bias', None) 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | stdv = 1.
/ math.sqrt(self.weight.size(1)) 31 | self.weight.data.uniform_(-stdv, stdv) 32 | self.att.data.uniform_(-stdv, stdv) 33 | if self.bias is not None: 34 | self.bias.data.uniform_(-stdv, stdv) 35 | 36 | def forward(self, input): 37 | support = torch.matmul(input, self.weight) 38 | output = torch.matmul(self.att, support) 39 | if self.bias is not None: 40 | return output + self.bias 41 | else: 42 | return output 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + ' (' \ 46 | + str(self.in_features) + ' -> ' \ 47 | + str(self.out_features) + ')' 48 | 49 | 50 | class GC_Block(nn.Module): 51 | def __init__(self, in_features, p_dropout, bias=True, node_n=48): 52 | """ 53 | Define a residual block of GCN 54 | """ 55 | super(GC_Block, self).__init__() 56 | self.in_features = in_features 57 | self.out_features = in_features 58 | 59 | self.gc1 = GraphConvolution(in_features, in_features, node_n=node_n, bias=bias) 60 | self.bn1 = nn.BatchNorm1d(node_n * in_features) 61 | 62 | self.gc2 = GraphConvolution(in_features, in_features, node_n=node_n, bias=bias) 63 | self.bn2 = nn.BatchNorm1d(node_n * in_features) 64 | 65 | self.do = nn.Dropout(p_dropout) 66 | self.act_f = nn.Tanh() 67 | 68 | def forward(self, x): 69 | y = self.gc1(x) 70 | b, n, f = y.shape 71 | y = self.bn1(y.view(b, -1)).view(b, n, f) 72 | y = self.act_f(y) 73 | y = self.do(y) 74 | 75 | y = self.gc2(y) 76 | b, n, f = y.shape 77 | y = self.bn2(y.view(b, -1)).view(b, n, f) 78 | y = self.act_f(y) 79 | y = self.do(y) 80 | 81 | return y + x 82 | 83 | def __repr__(self): 84 | return self.__class__.__name__ + ' (' \ 85 | + str(self.in_features) + ' -> ' \ 86 | + str(self.out_features) + ')' 87 | 88 | 89 | class GCN(nn.Module): 90 | def __init__(self, input_feature, hidden_feature, p_dropout, num_stage=1, node_n=48, gamma=0.1): 91 | """ 92 | 93 | :param input_feature: num of input feature 94 | :param hidden_feature: num of hidden feature 95 | :param p_dropout: drop out prob. 
96 | :param num_stage: number of residual blocks 97 | :param node_n: number of nodes in graph 98 | """ 99 | super(GCN, self).__init__() 100 | self.num_stage = num_stage 101 | self.gamma = gamma 102 | 103 | self.gc1 = GraphConvolution(input_feature, hidden_feature, node_n=node_n) 104 | self.bn1 = nn.BatchNorm1d(node_n * hidden_feature) 105 | 106 | self.gcbs = [] 107 | for i in range(num_stage): 108 | self.gcbs.append(GC_Block(hidden_feature, p_dropout=p_dropout, node_n=node_n)) 109 | 110 | self.gcbs = nn.ModuleList(self.gcbs) 111 | 112 | self.gc7 = GraphConvolution(hidden_feature, input_feature, node_n=node_n) 113 | 114 | self.do = nn.Dropout(p_dropout) 115 | self.act_f = nn.Tanh() 116 | 117 | def forward(self, x): 118 | y = self.gc1(x) 119 | b, n, f = y.shape 120 | y = self.bn1(y.view(b, -1)).view(b, n, f) 121 | y = self.act_f(y) 122 | y = self.do(y) 123 | 124 | for i in range(self.num_stage): 125 | y = self.gcbs[i](y) 126 | 127 | y = self.gc7(y) 128 | y = self.gamma * y + x 129 | 130 | return y 131 | -------------------------------------------------------------------------------- /models/motion_pred.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from utils.torch import * 6 | from models.PoseFormer import PoseFormer 7 | from models.MotionDiff import MotionDiff 8 | from models.Diffusion import Diffusion 9 | import models.mao_gcn as nnmodel 10 | 11 | 12 | class DiffHM(nn.Module): 13 | def __init__(self, config, num_joint): 14 | super(DiffHM, self).__init__() 15 | # Variable 16 | self.config = config 17 | self.num_frame = config.obs_frames 18 | self.pose_embed_dim = config.pose_embed_dim 19 | self.drop_path_rate = config.drop_path_rate 20 | self.drop_rate_poseformer = config.drop_rate_poseformer 21 | self.num_joint = num_joint 22 | # Encoder 23 | self.poseformer = PoseFormer(config, num_joint=self.num_joint, in_chans=3, num_frame=self.num_frame, embed_dim=self.pose_embed_dim, 24 | drop_rate=self.drop_rate_poseformer, drop_path_rate=self.drop_path_rate, norm_layer=None) 25 | # Decoder 26 | self.y_diff = Diffusion(config, num_joint) 27 | self.y_mlp = MotionDiff(config, num_joint=self.num_joint) 28 | 29 | 30 | def diff(self, y): 31 | b, f, _, _ = y.shape 32 | y = y.reshape(b, f, -1) 33 | return self.y_diff(y) 34 | 35 | def encode(self, x, y): 36 | feat_x_encoded = self.poseformer.encode_his(x) 37 | diff_y, e_rand, beta = self.diff(y) 38 | # whether DCT ??? 
39 | 40 | return feat_x_encoded, diff_y, e_rand, beta 41 | 42 | def denoise(self, feat_x_encoded, diff_y, beta): 43 | return self.y_mlp(feat_x_encoded, diff_y, beta) 44 | 45 | def get_e_loss(self, e_rand, e_theta): 46 | loss = F.mse_loss(e_theta.view(-1, 3 * self.num_joint), e_rand.view(-1, 3 * self.num_joint), reduction='mean') 47 | return loss 48 | 49 | 50 | def get_loss(self, x, y): 51 | feat_x_encoded, diff_y, e_rand, beta = self.encode(x, y) 52 | e_theta = self.denoise(feat_x_encoded, diff_y, beta) 53 | loss = self.get_e_loss(e_rand, e_theta) 54 | return loss 55 | 56 | 57 | 58 | def generate(self, x): 59 | x = x.reshape(x.shape[0], x.shape[1], self.num_joint, -1) 60 | x = x.permute(1, 0, 2, 3).contiguous() 61 | encoded_x = self.poseformer.encode_his(x) 62 | predicted_x = self.y_diff.sample(self.y_mlp, encoded_x, flexibility=self.config.flexibility, ret_traj=self.config.ret_traj) 63 | predicted_x = predicted_x.permute(1, 0, 2).contiguous() 64 | return predicted_x 65 | 66 | 67 | 68 | def get_diff_model(config, traj_dim): 69 | model_name = config.model_name 70 | num_joint = traj_dim // 3 71 | if model_name == "MotionDiff": 72 | return DiffHM(config, num_joint) 73 | else: 74 | print("The model doesn't exist: %s" % model_name) 75 | exit(0) 76 | 77 | 78 | 79 | def get_refine_model(config, traj_dim): 80 | return nnmodel.GCN(input_feature=config.dct_n, hidden_feature=config.gcn_linear_size, p_dropout=config.gcn_dropout, 81 | num_stage=config.gcn_layers, node_n=traj_dim, gamma=config.gamma) 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.torch import * 3 | 4 | 5 | class RNN(nn.Module): 6 | def __init__(self, input_dim, out_dim, cell_type='lstm', bi_dir=False): 7 | super().__init__() 8 | self.input_dim = input_dim 9 | self.out_dim = out_dim 10 | self.cell_type = cell_type 11 | self.bi_dir = bi_dir 12 | self.mode = 'batch' 13 | rnn_cls = nn.LSTMCell if cell_type == 'lstm' else nn.GRUCell 14 | hidden_dim = out_dim // 2 if bi_dir else out_dim 15 | self.rnn_f = rnn_cls(self.input_dim, hidden_dim) 16 | if bi_dir: 17 | self.rnn_b = rnn_cls(self.input_dim, hidden_dim) 18 | self.hx, self.cx = None, None 19 | 20 | def set_mode(self, mode): 21 | self.mode = mode 22 | 23 | def initialize(self, batch_size=1, hx=None, cx=None): 24 | if self.mode == 'step': 25 | self.hx = zeros((batch_size, self.rnn_f.hidden_size)) if hx is None else hx 26 | if self.cell_type == 'lstm': 27 | self.cx = zeros((batch_size, self.rnn_f.hidden_size)) if cx is None else cx 28 | 29 | def forward(self, x): 30 | if self.mode == 'step': 31 | self.hx, self.cx = batch_to(x.device, self.hx, self.cx) 32 | if self.cell_type == 'lstm': 33 | self.hx, self.cx = self.rnn_f(x, (self.hx, self.cx)) 34 | else: 35 | self.hx = self.rnn_f(x, self.hx) 36 | rnn_out = self.hx 37 | else: 38 | rnn_out_f = self.batch_forward(x) 39 | if not self.bi_dir: 40 | return rnn_out_f 41 | rnn_out_b = self.batch_forward(x, reverse=True) 42 | rnn_out = torch.cat((rnn_out_f, rnn_out_b), 2) 43 | return rnn_out 44 | 45 | def batch_forward(self, x, reverse=False): 46 | rnn = self.rnn_b if reverse else self.rnn_f 47 | rnn_out = [] 48 | hx = zeros((x.size(1), rnn.hidden_size), device=x.device) 49 | if self.cell_type == 'lstm': 50 | cx = zeros((x.size(1), rnn.hidden_size), device=x.device) 51 | ind = reversed(range(x.size(0))) if reverse else range(x.size(0)) 
52 | for t in ind: 53 | if self.cell_type == 'lstm': 54 | hx, cx = rnn(x[t, ...], (hx, cx)) 55 | else: 56 | hx = rnn(x[t, ...], hx) 57 | rnn_out.append(hx.unsqueeze(0)) 58 | if reverse: 59 | rnn_out.reverse() 60 | rnn_out = torch.cat(rnn_out, 0) 61 | return rnn_out 62 | 63 | 64 | if __name__ == '__main__': 65 | rnn = RNN(12, 24, 'gru', bi_dir=True) 66 | input = zeros(5, 3, 12) 67 | out = rnn(input) 68 | print(out.shape) 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | einops==0.4.1 3 | h5py==3.6.0 4 | matplotlib==3.5.2 5 | numpy==1.21.6 6 | progress==1.6 7 | PyYAML==6.0 8 | scipy==1.7.3 9 | timm==0.5.4 10 | tqdm==4.64.0 11 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.torch import * 2 | from utils.logger import * 3 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/utils/__pycache__/torch.cpython-37.pyc -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def create_logger(filename, file_handle=True): 6 | # create logger 7 | logger = logging.getLogger(filename) 8 | logger.propagate = False 9 | logger.setLevel(logging.DEBUG) 10 | # create console handler with a higher log level 11 | ch = logging.StreamHandler() 12 | ch.setLevel(logging.INFO) 13 | stream_formatter = logging.Formatter('%(message)s') 14 | ch.setFormatter(stream_formatter) 15 | logger.addHandler(ch) 16 | 17 | if file_handle: 18 | # create file handler which logs even debug messages 19 | os.makedirs(os.path.dirname(filename), exist_ok=True) 20 | fh = logging.FileHandler(filename, mode='a') 21 | fh.setLevel(logging.DEBUG) 22 | file_formatter = logging.Formatter('[%(asctime)s] %(message)s') 23 | fh.setFormatter(file_formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger 27 | 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0 37 | self.avg = 0 38 | self.sum = 0 39 | self.count = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- 
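Note (illustrative addition): a minimal usage sketch of the AverageMeter class defined above, assuming the utils.logger import path shown in this dump; compute_stats in the statistics section drives one such meter per metric and per algorithm in exactly this pattern, and the numeric values here are made up.

# Minimal sketch: how AverageMeter tracks a running average during evaluation.
# The numbers are illustrative; compute_stats updates one meter per metric/algorithm.
from utils.logger import AverageMeter

meter = AverageMeter()
for val in [0.42, 0.37, 0.51]:      # e.g. per-sequence ADE values (made up)
    meter.update(val)               # n defaults to 1, so avg = sum / count
print('last=%.2f avg=%.2f' % (meter.val, meter.avg))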
/utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim import lr_scheduler 4 | 5 | tensor = torch.tensor 6 | DoubleTensor = torch.DoubleTensor 7 | FloatTensor = torch.FloatTensor 8 | LongTensor = torch.LongTensor 9 | ByteTensor = torch.ByteTensor 10 | ones = torch.ones 11 | zeros = torch.zeros 12 | 13 | 14 | class to_cpu: 15 | 16 | def __init__(self, *models): 17 | self.models = list(filter(lambda x: x is not None, models)) 18 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 19 | for x in self.models: 20 | x.to(torch.device('cpu')) 21 | 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, *args): 26 | for x, device in zip(self.models, self.prev_devices): 27 | x.to(device) 28 | return False 29 | 30 | 31 | class to_device: 32 | 33 | def __init__(self, device, *models): 34 | self.models = list(filter(lambda x: x is not None, models)) 35 | self.prev_devices = [x.device if hasattr(x, 'device') else next(x.parameters()).device for x in self.models] 36 | for x in self.models: 37 | x.to(device) 38 | 39 | def __enter__(self): 40 | pass 41 | 42 | def __exit__(self, *args): 43 | for x, device in zip(self.models, self.prev_devices): 44 | x.to(device) 45 | return False 46 | 47 | 48 | class to_test: 49 | 50 | def __init__(self, *models): 51 | self.models = list(filter(lambda x: x is not None, models)) 52 | self.prev_modes = [x.training for x in self.models] 53 | for x in self.models: 54 | x.train(False) 55 | 56 | def __enter__(self): 57 | pass 58 | 59 | def __exit__(self, *args): 60 | for x, mode in zip(self.models, self.prev_modes): 61 | x.train(mode) 62 | return False 63 | 64 | 65 | class to_train: 66 | 67 | def __init__(self, *models): 68 | self.models = list(filter(lambda x: x is not None, models)) 69 | self.prev_modes = [x.training for x in self.models] 70 | for x in self.models: 71 | x.train(True) 72 | 73 | def __enter__(self): 74 | pass 75 | 76 | def __exit__(self, *args): 77 | for x, mode in zip(self.models, self.prev_modes): 78 | x.train(mode) 79 | return False 80 | 81 | 82 | def batch_to(dst, *args): 83 | return [x.to(dst) if x is not None else None for x in args] 84 | 85 | 86 | def get_flat_params_from(models): 87 | if not hasattr(models, '__iter__'): 88 | models = (models, ) 89 | params = [] 90 | for model in models: 91 | for param in model.parameters(): 92 | params.append(param.data.view(-1)) 93 | 94 | flat_params = torch.cat(params) 95 | return flat_params 96 | 97 | 98 | def set_flat_params_to(model, flat_params): 99 | prev_ind = 0 100 | for param in model.parameters(): 101 | flat_size = int(np.prod(list(param.size()))) 102 | param.data.copy_( 103 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 104 | prev_ind += flat_size 105 | 106 | 107 | def get_flat_grad_from(inputs, grad_grad=False): 108 | grads = [] 109 | for param in inputs: 110 | if grad_grad: 111 | grads.append(param.grad.grad.view(-1)) 112 | else: 113 | if param.grad is None: 114 | grads.append(zeros(param.view(-1).shape)) 115 | else: 116 | grads.append(param.grad.view(-1)) 117 | 118 | flat_grad = torch.cat(grads) 119 | return flat_grad 120 | 121 | 122 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 123 | if create_graph: 124 | retain_graph = True 125 | 126 | inputs = list(inputs) 127 | params = [] 128 | for i, param in enumerate(inputs): 129 | if i not in filter_input_ids: 130 | 
params.append(param) 131 | 132 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 133 | 134 | j = 0 135 | out_grads = [] 136 | for i, param in enumerate(inputs): 137 | if i in filter_input_ids: 138 | out_grads.append(zeros(param.view(-1).shape)) 139 | else: 140 | out_grads.append(grads[j].view(-1)) 141 | j += 1 142 | grads = torch.cat(out_grads) 143 | 144 | for param in params: 145 | param.grad = None 146 | return grads 147 | 148 | 149 | def set_optimizer_lr(optimizer, lr): 150 | for param_group in optimizer.param_groups: 151 | param_group['lr'] = lr 152 | 153 | 154 | def filter_state_dict(state_dict, filter_keys): 155 | for key in list(state_dict.keys()): 156 | for f_key in filter_keys: 157 | if f_key in key: 158 | del state_dict[key] 159 | break 160 | 161 | 162 | def get_scheduler(optimizer, policy, nepoch_fix=None, nepoch=None, decay_step=None): 163 | if policy == 'lambda': 164 | def lambda_rule(epoch): 165 | lr_l = 1.0 - max(0, epoch - nepoch_fix) / float(nepoch - nepoch_fix + 1) 166 | return lr_l 167 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 168 | elif policy == 'step': 169 | scheduler = lr_scheduler.StepLR( 170 | optimizer, step_size=decay_step, gamma=0.1) 171 | elif policy == 'plateau': 172 | scheduler = lr_scheduler.ReduceLROnPlateau( 173 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 174 | else: 175 | raise NotImplementedError('learning rate policy [%s] is not implemented' % policy) 176 | return scheduler 177 | -------------------------------------------------------------------------------- /visualization/__pycache__/visualization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csdwei/MotionDiff/416a94b37c426d1a042ed72c6b0efe0b79ec0a7a/visualization/__pycache__/visualization.cpython-37.pyc -------------------------------------------------------------------------------- /visualization/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | from matplotlib.animation import FuncAnimation, writers 12 | from mpl_toolkits.mplot3d import Axes3D 13 | import numpy as np 14 | import pickle 15 | 16 | 17 | def render_animation(skeleton, poses_generator, algos, t_hist, fix_0=True, azim=0.0, output=None, size=6, ncol=5, bitrate=3000): 18 | """ 19 | TODO 20 | Render an animation. The supported output modes are: 21 | -- 'interactive': display an interactive figure 22 | (also works on notebooks if associated with %matplotlib inline) 23 | -- 'html': render the animation as HTML5 video. Can be displayed in a notebook using HTML(...). 24 | -- 'filename.mp4': render and export the animation as an h264 video (requires ffmpeg). 25 | -- 'filename.gif': render and export the animation as a gif file (requires imagemagick).
26 | """ 27 | 28 | all_poses = next(poses_generator) 29 | algo = algos[0] if len(algos) > 0 else next(iter(all_poses.keys())) 30 | t_total = next(iter(all_poses.values())).shape[0] 31 | poses = dict(filter(lambda x: x[0] in {'gt', 'context'} or algo == x[0].split('_')[0], all_poses.items())) 32 | plt.ioff() 33 | nrow = int(np.ceil(len(poses) / ncol)) 34 | fig = plt.figure(figsize=(size * ncol, size * nrow)) 35 | ax_3d = [] 36 | lines_3d = [] 37 | trajectories = [] 38 | radius = 1.7 39 | for index, (title, data) in enumerate(poses.items()): 40 | ax = fig.add_subplot(nrow, ncol, index + 1, projection='3d') 41 | ax.view_init(elev=15., azim=azim) 42 | ax.set_xlim3d([-radius / 2, radius / 2]) 43 | ax.set_zlim3d([0, radius]) 44 | ax.set_ylim3d([-radius / 2, radius / 2]) 45 | # ax.set_aspect('equal') 46 | ax.set_xticklabels([]) 47 | ax.set_yticklabels([]) 48 | ax.set_zticklabels([]) 49 | ax.dist = 5.0 50 | # ax.set_title(title, y=1.2) 51 | ax.set_axis_off() 52 | ax.patch.set_alpha(0.0) 53 | ax_3d.append(ax) 54 | lines_3d.append([]) 55 | trajectories.append(data[:, 0, [0, 1]]) 56 | fig.tight_layout() 57 | fig.subplots_adjust(wspace=-0.4, hspace=0) 58 | poses = list(poses.values()) 59 | 60 | anim = None 61 | initialized = False 62 | animating = True 63 | find = 0 64 | hist_lcol, hist_rcol = '#e66f51', '#42a6cb' 65 | pred_lcol, pred_rcol = '#a64036', '#4182a4' 66 | 67 | parents = skeleton.parents() 68 | 69 | def update_video(i): 70 | nonlocal initialized 71 | if i < t_hist: 72 | lcol, rcol = hist_lcol, hist_rcol 73 | else: 74 | lcol, rcol = pred_lcol, pred_rcol 75 | 76 | for n, ax in enumerate(ax_3d): 77 | if fix_0 and n == 0 and i >= t_hist: 78 | continue 79 | trajectories[n] = poses[n][:, 0, [0, 1, 2]] 80 | ax.set_xlim3d([-radius / 2 + trajectories[n][i, 0], radius / 2 + trajectories[n][i, 0]]) 81 | ax.set_ylim3d([-radius / 2 + trajectories[n][i, 1], radius / 2 + trajectories[n][i, 1]]) 82 | ax.set_zlim3d([-radius / 2 + trajectories[n][i, 2], radius / 2 + trajectories[n][i, 2]]) 83 | 84 | if not initialized: 85 | 86 | for j, j_parent in enumerate(parents): 87 | if j_parent == -1: 88 | continue 89 | 90 | col = rcol if j in skeleton.joints_right() else lcol 91 | zo = 1 if col == rcol else 2 92 | # zo = 2 if j in skeleton.joints_right() else 1 93 | for n, ax in enumerate(ax_3d): 94 | pos = poses[n][i] 95 | lines_3d[n].append(ax.plot([pos[j, 0], pos[j_parent, 0]], 96 | [pos[j, 1], pos[j_parent, 1]], 97 | [pos[j, 2], pos[j_parent, 2]], zdir='z', c=col, zorder=zo, lw=6)) 98 | initialized = True 99 | else: 100 | 101 | for j, j_parent in enumerate(parents): 102 | if j_parent == -1: 103 | continue 104 | 105 | col = rcol if j in skeleton.joints_right() else lcol 106 | for n, ax in enumerate(ax_3d): 107 | if fix_0 and n == 0 and i >= t_hist: 108 | continue 109 | pos = poses[n][i] 110 | # pos = np.array(pos) 111 | lines_3d[n][j - 1][0].set_xdata(np.asarray([pos[j, 0], pos[j_parent, 0]])) 112 | lines_3d[n][j - 1][0].set_ydata(np.asarray([pos[j, 1], pos[j_parent, 1]])) 113 | lines_3d[n][j - 1][0].set_3d_properties(np.asarray([pos[j, 2], pos[j_parent, 2]]), zdir='z') 114 | lines_3d[n][j - 1][0].set_color(col) 115 | 116 | def show_animation(): 117 | nonlocal anim 118 | if anim is not None: 119 | anim.event_source.stop() 120 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=0, repeat=True) 121 | plt.draw() 122 | 123 | def reload_poses(): 124 | nonlocal poses 125 | poses = dict(filter(lambda x: x[0] in {'gt', 'context'} or algo == x[0].split('_')[0], all_poses.items())) 126 
| for ax, title in zip(ax_3d, poses.keys()): 127 | ax.set_title(title, y=1.2) 128 | poses = list(poses.values()) 129 | 130 | def save_figs(): 131 | nonlocal algo, find 132 | old_algo = algo 133 | for algo in algos: 134 | reload_poses() 135 | update_video(t_total - 1) 136 | fig.savefig('out/%d_%s.png' % (find, algo), dpi=400, transparent=True) 137 | algo = old_algo 138 | find += 1 139 | 140 | def on_key(event): 141 | nonlocal algo, all_poses, animating, anim 142 | 143 | if event.key == 'd': 144 | all_poses = next(poses_generator) 145 | reload_poses() 146 | show_animation() 147 | elif event.key == 'c': 148 | save() 149 | elif event.key == ' ': 150 | if animating: 151 | anim.event_source.stop() 152 | else: 153 | anim.event_source.start() 154 | animating = not animating 155 | elif event.key == 'v': # save images 156 | if anim is not None: 157 | anim.event_source.stop() 158 | anim = None 159 | save_figs() 160 | elif event.key.isdigit(): 161 | algo = algos[int(event.key) - 1] 162 | reload_poses() 163 | show_animation() 164 | 165 | def save(): 166 | nonlocal anim 167 | 168 | fps = 30 169 | anim = FuncAnimation(fig, update_video, frames=np.arange(0, poses[0].shape[0]), interval=1000 / fps, 170 | repeat=False) 171 | os.makedirs(os.path.dirname(output), exist_ok=True) 172 | if output.endswith('.mp4'): 173 | Writer = writers['ffmpeg'] 174 | writer = Writer(fps=fps, metadata={}, bitrate=bitrate) 175 | anim.save(output, writer=writer) 176 | elif output.endswith('.gif'): 177 | anim.save(output, dpi=80, writer='pillow') 178 | else: 179 | raise ValueError('Unsupported output format (only .mp4 and .gif are supported)') 180 | print(f'video saved to {output}!') 181 | 182 | fig.canvas.mpl_connect('key_press_event', on_key) 183 | show_animation() 184 | # plt.show() 185 | save() --------------------------------------------------------------------------------
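Note (illustrative addition): a small self-contained sketch of the best-of-N ADE/FDE logic from the statistics section above; the array shapes (nk sampled futures, t_pred frames, flattened 3*J joint coordinates) are assumptions inferred from get_gt and get_prediction, and the data is random.

# Illustrative only: toy check of the best-of-N ADE/FDE computations used in compute_stats.
# Shapes are assumed: nk sampled futures x t_pred frames x flattened 3*J joint coordinates.
import numpy as np

nk, t_pred, feat = 10, 100, 48
pred = np.random.randn(nk, t_pred, feat)   # nk candidate future motions
gt = np.random.randn(1, t_pred, feat)      # single ground-truth future (broadcasts over nk)

diff = pred - gt
ade = np.linalg.norm(diff, axis=2).mean(axis=1).min()   # mean per-frame error of the best sample
fde = np.linalg.norm(diff, axis=2)[:, -1].min()         # final-frame error of the best sample
print('ADE=%.3f FDE=%.3f' % (ade, fde))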