├── src ├── __init__.py ├── metrics │ ├── __init__.py │ ├── nlg.py │ ├── motion_generation.py │ └── common.py ├── models │ ├── __init__.py │ ├── model_base.py │ └── motion_clip.py ├── modules │ ├── __init__.py │ ├── mask.py │ ├── embeddings.py │ └── resnet.py ├── datasets │ ├── __init__.py │ ├── dataset_base.py │ ├── motion_clip_dataset.py │ └── motion_vqvae_dataset.py ├── utils │ ├── __init__.py │ ├── log.py │ ├── utils.py │ ├── rotation.py │ ├── normalizer.py │ ├── constants.py │ └── plot.py ├── configs │ ├── datasets │ │ ├── motion_clip.yaml │ │ ├── motion_vqvae.yaml │ │ └── lm.yaml │ ├── trainer │ │ └── default.yaml │ └── models │ │ ├── motion_clip.yaml │ │ ├── lm.yaml │ │ └── motion_vqvae.yaml └── losses │ └── __init__.py ├── data_preprocessing ├── __init__.py ├── 3.5-visualize_intergen_262.py ├── 1-preprocess_text.py ├── 0.5-visualize_joints3d_22.py ├── 7-calculate_xzr_range.py ├── 2-mirror_joints3d_22_and_text.py ├── 3-prepare_motions.py ├── 4-prepare_xzr.py ├── 6.5-visualize_tokens.py ├── utils.py ├── 5-prepare_normalizer.py ├── 6-prepare_tokens.py ├── check_prompt_with_llm.py ├── 0-smpl_to_joints3d_22.py └── split_interaction_caption_with_llm.py ├── third_party └── HumanML3D │ ├── common │ ├── __init__.py │ └── skeleton.py │ ├── human_body_prior │ ├── body_model │ │ ├── parts_segm │ │ │ └── readme │ │ ├── __init__.py │ │ └── rigid_object_model.py │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── model_components.py │ │ └── vposer_model.py │ ├── tools │ │ ├── __init__.py │ │ ├── configurations.py │ │ ├── angle_continuous_repres.py │ │ ├── model_loader.py │ │ ├── rotation_tools.py │ │ └── omni_tools.py │ ├── train │ │ ├── __init__.py │ │ ├── V02_05 │ │ │ ├── __init__.py │ │ │ ├── V02_05.py │ │ │ └── V02_05.yaml │ │ └── README.md │ └── visualizations │ │ ├── __init__.py │ │ └── training_visualization.py │ └── paramUtil.py ├── assets └── teaser.png ├── .gitignore ├── LICENSE ├── README.md ├── eval.py └── train.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/HumanML3D/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .log import * 3 | 
-------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlbertTan404/Think-Then-React/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pyc 3 | __pycache__ 4 | 5 | outputs 6 | data 7 | 8 | *.out 9 | *.log 10 | *.LOG 11 | *_logs 12 | temp* 13 | 14 | wandb 15 | logs 16 | checkpoints 17 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/body_model/parts_segm/readme: -------------------------------------------------------------------------------- 1 | ### Parts segmentation file obtained from https://github.com/vchoutas/torch-mesh-isect#examples and put here for convenience -------------------------------------------------------------------------------- /src/configs/datasets/motion_clip.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: src.datasets.motion_clip_dataset.MotionCLIPDataset 3 | 4 | # kwargs: 5 | dataset_dir: ~/data/data/motion/interx 6 | epoch_scaling: 1 7 | max_motion_length: 256 8 | min_motion_length: 32 9 | motion_representation: intergen_262 10 | split: 11 | tiny_dataset: False -------------------------------------------------------------------------------- /src/configs/datasets/motion_vqvae.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: src.datasets.motion_vqvae_dataset.MotionVQVAEDataset 3 | 4 | # kwargs: 5 | dataset_dir: ~/data/data/motion/interx 6 | epoch_scaling: 10 # scaling up the length of dataset for each epoch, reduce val interval 7 | max_motion_length: 256 8 | min_motion_length: 64 # this is very important for VQVAE training (empirically)! 
9 | motion_representation: intergen_262 10 | split: 11 | tiny_dataset: False 12 | use_h3d: False 13 | abs_action: False 14 | -------------------------------------------------------------------------------- /src/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_masked_seq2seq_loss(gt_seq, pred_seq, loss_mask, loss_mask_sum=None, loss_fn = F.smooth_l1_loss): 6 | shape = gt_seq.shape 7 | if loss_mask_sum == None: 8 | loss_mask_sum = loss_mask.sum() 9 | gt_seq = gt_seq.reshape(shape[0], shape[1], -1) 10 | pred_seq = pred_seq.reshape(shape[0], shape[1], -1) 11 | loss = loss_fn(pred_seq, gt_seq, reduction='none').mean(-1) 12 | loss = loss_mask * loss 13 | return torch.sum(loss) / loss_mask_sum 14 | -------------------------------------------------------------------------------- /src/configs/datasets/lm.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: src.datasets.lm_dataset.LMDataset 3 | 4 | # kwargs: 5 | dataset_dir: ~/data/data/motion/interx 6 | epoch_scaling: 1 7 | max_motion_length: 256 8 | min_motion_length: 32 9 | motion_representation: intergen_262 10 | split: 11 | tiny_dataset: False 12 | vqvae_ckpt_path: /path/to/ckpt 13 | n_x_bins: 10 14 | n_z_bins: 10 15 | n_r_bins: 10 16 | stage: pretrain 17 | use_h3d: True 18 | 19 | motion_token_template: 20 | x_template: 21 | z_template: 22 | r_template: 23 | -------------------------------------------------------------------------------- /src/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: auto 3 | target: lightning.pytorch.trainer.Trainer 4 | devices: auto 5 | max_steps: -1 6 | check_val_every_n_epoch: 1 7 | log_every_n_steps: 50 8 | num_sanity_val_steps: 2 9 | gradient_clip_val: 0.5 10 | precision: 32-true 11 | accumulate_grad_batches: 1 12 | 13 | logger: 14 | target: lightning.pytorch.loggers.TensorBoardLogger 15 | save_dir: 16 | name: 17 | version: 18 | 19 | callbacks: 20 | - target: lightning.pytorch.callbacks.ModelCheckpoint 21 | save_last: True 22 | save_top_k: 3 23 | mode: max 24 | monitor: monitor 25 | auto_insert_metric_name: False 26 | filename: "epoch{epoch}__step{step}__monitor{monitor:.3f}" 27 | save_weights_only: True 28 | -------------------------------------------------------------------------------- /src/configs/models/motion_clip.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | model: 4 | target: src.models.motion_clip.MotionCLIP 5 | model_kwargs: 6 | output_size: 768 7 | n_heads: 8 8 | n_encoder_layers: 8 9 | init_latent_scale: 1 10 | text_feature_name: openai/clip-vit-large-patch14 11 | motion_representation: intergen_262 12 | dropout: 0.25 13 | n_labels: 40 14 | cls_weight: 0.1 15 | action_mask_coef: 31 16 | 17 | training_kwargs: 18 | optimizer: 19 | target: torch.optim.Adam 20 | lr: 1e-4 21 | scheduler: constant_schedule_with_warmup 22 | warmup_steps: 1000 23 | 24 | 25 | trainer: 26 | max_epochs: 40 27 | 28 | 29 | dataloader: 30 | batch_size: 128 31 | val_batch_size: 32 32 | num_workers: 32 33 | pin_memory: True 34 | persistent_workers: True 35 | -------------------------------------------------------------------------------- /src/configs/models/lm.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | model: 4 | target: 
src.models.lm.LMReactiveMotionGenerator 5 | model_kwargs: 6 | lm: google/flan-t5-base # google/flan-t5-large 7 | vqvae_ckpt_path: /path/to/ckpt 8 | 9 | evaluator_ckpt_path: /path/to/ckpt 10 | n_x_bins: 20 11 | n_z_bins: 20 12 | n_r_bins: 20 13 | mask_ratio: 0.15 14 | use_h3d: True # set this to False works as well 15 | stage: pretrain 16 | pretrained_path: 17 | unit_size: 1 18 | rethinking_interval: 4 19 | use_adaptive_sampling: True 20 | 21 | training_kwargs: 22 | optimizer: 23 | target: torch.optim.Adam 24 | lr: 1e-4 25 | scheduler: constant_schedule_with_warmup 26 | warmup_steps: 1000 27 | 28 | trainer: 29 | max_epochs: 200 30 | 31 | dataloader: 32 | batch_size: 32 33 | val_batch_size: 32 34 | num_workers: 32 35 | pin_memory: True 36 | persistent_workers: True 37 | -------------------------------------------------------------------------------- /src/modules/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_triu_mask(h, w, device='cpu', batch_size=None, dtype=torch.float32, diagonal=1, step_length=1): 5 | fill_value = True if dtype == bool or dtype == torch.bool else float('-inf') 6 | if step_length == 1: 7 | mask = torch.triu(torch.full((h, w), fill_value=fill_value, dtype=dtype, device=device), diagonal=diagonal) 8 | elif step_length > 1: 9 | h_pad = ((h + step_length - 1) // step_length) * step_length 10 | w_pad = ((w + step_length - 1) // step_length) * step_length 11 | mask = torch.triu(torch.ones(h_pad // step_length, w_pad // step_length, dtype=dtype, device=device), diagonal=1) 12 | mask = mask.repeat_interleave(step_length, dim=0).repeat_interleave(step_length, dim=1)[:h, :w] 13 | else: 14 | raise ValueError(f'{step_length} is not a valid step_length') 15 | 16 | if batch_size is None: 17 | return mask 18 | else: 19 | return torch.stack([mask] * batch_size, dim=0) 20 | -------------------------------------------------------------------------------- /src/configs/models/motion_vqvae.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | 4 | model: 5 | target: src.models.motion_vqvae.MotionVQVAE 6 | model_kwargs: 7 | motion_representation: intergen_262 8 | nb_code: 256 9 | code_dim: 512 10 | width: 512 11 | output_emb_width: 512 12 | v2: False 13 | with_first_frame: True 14 | mu: 0.99 15 | down_t: 2 16 | stride_t: 2 17 | depth: 3 18 | dilation_growth_rate: 3 19 | vq_act: relu 20 | vq_norm: ~ 21 | quantizer: ema_reset 22 | beta: 1.0 23 | evaluator_ckpt_path: /path/to/ckpt 24 | 25 | training_kwargs: 26 | loss_kwargs: 27 | commit_weight: 0.02 28 | vel_weight: 0.5 29 | optimizer: 30 | target: torch.optim.Adam 31 | lr: 1e-4 32 | scheduler: constant_schedule_with_warmup 33 | warmup_steps: 1000 34 | 35 | 36 | trainer: 37 | max_epochs: 1000 38 | 39 | 40 | dataloader: 41 | batch_size: 512 42 | val_batch_size: 32 43 | num_workers: 32 44 | pin_memory: True 45 | persistent_workers: True 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenMotionLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and 
to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.01.02 23 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/train/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.01.02 23 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/body_model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. 
You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.01.02 23 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/train/V02_05/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/visualizations/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | -------------------------------------------------------------------------------- /data_preprocessing/3.5-visualize_intergen_262.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | import os 4 | import pickle 5 | import random 6 | from pathlib import Path 7 | 8 | sys.path.append(os.getcwd()) 9 | sys.path.append(os.getcwd() + '/../') 10 | from src.utils.motion_representation_converter import MotionRepresentationConverter 11 | from src.utils.plot import animate_multiple_joints3d_22 12 | 13 | 14 | ds = 'interx' 15 | joints_dirs = list(Path(f'~/data/data/motion/{ds}/intergen_262').expanduser().glob('*.pkl')) 16 | text_dir = Path(f'~/data/data/motion/{ds}/texts').expanduser() 17 | mrc = MotionRepresentationConverter() 18 | 19 | # %% 20 | for i in range(2): 21 | p = random.choice(joints_dirs) 22 | text = (text_dir / f'{p.stem}.txt').read_text().split('\n')[0] 23 | print(text) 24 | 25 | motion_dict = pickle.load(p.open('rb')) 26 | 27 | if 'inter' in ds.lower(): 28 | action, reaction = mrc('i262', 'j3d', motion_dict['action']), mrc('i262', 'j3d', motion_dict['reaction']) 29 | animate_multiple_joints3d_22([action, reaction], ['b', 'g'], title=text, file_path=f'temp_{i}.mp4') 30 | else: 31 | motion = mrc('i262', 'j3d', motion_dict['reaction']) 32 | -------------------------------------------------------------------------------- /src/datasets/dataset_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | 4 | 5 | class DatasetBase(Dataset): 6 | def __init__( 7 | self, 8 | dataset_dir: str, 9 | split: str = 'train', 10 | epoch_scaling=1, 11 | tiny_dataset: bool = False, 12 | ): 13 | if len(dataset_dir.split(',')) == 1: 14 | self.dataset_dir = Path(dataset_dir).expanduser() 15 | else: 16 | self.dataset_dir = [Path(d).expanduser() for d in dataset_dir.split(',')] 17 | self.split = split 18 | self.epoch_scaling = epoch_scaling 19 | self.tiny_dataset = tiny_dataset 20 | 21 | @property 22 | def real_length(self): 23 | raise NotImplementedError("Implement this in the child class") 24 | 25 | def __len__(self): 26 | if self.split == 'train': 27 | return self.real_length * self.epoch_scaling 28 | else: 29 | return self.real_length 30 | 31 | def __getitem__(self, index): 32 | index = index % self.real_length 33 | return self.getitem(index=index) 34 | 35 | def getitem(self, index): 36 | raise NotImplementedError("Implement this in the child class") 37 | -------------------------------------------------------------------------------- /src/modules/embeddings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def get_sincos_pe(hidden_size, max_len=1000): 8 | pe = torch.zeros(max_len, hidden_size) 9 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 10 | div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / hidden_size)) 11 | pe[:, 0::2] = torch.sin(position * div_term) 12 | pe[:, 1::2] = torch.cos(position * div_term) 13 | 
pe = pe.unsqueeze(0) 14 | return pe 15 | 16 | 17 | class PositionalEncoding(nn.Module): 18 | def __init__(self, d_model, dropout=0.0, max_len=1000): 19 | super(PositionalEncoding, self).__init__() 20 | self.dropout = nn.Dropout(p=dropout) 21 | 22 | pe = torch.zeros(max_len, d_model) 23 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 24 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 25 | pe[:, 0::2] = torch.sin(position * div_term) 26 | pe[:, 1::2] = torch.cos(position * div_term) 27 | 28 | self.register_buffer('pe', pe) 29 | 30 | def forward(self, x): 31 | x = x + self.pe[:x.shape[1], :].unsqueeze(0) 32 | return self.dropout(x) 33 | -------------------------------------------------------------------------------- /data_preprocessing/1-preprocess_text.py: -------------------------------------------------------------------------------- 1 | #%% joints3d_22 to intergen_262 2 | import sys 3 | import os 4 | import pickle 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | from concurrent.futures import ProcessPoolExecutor as PPE 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | 13 | 14 | dataset = 'interx' 15 | data_root_dir = Path(f'~/data/data/motion/{dataset}').expanduser() 16 | text_dir = data_root_dir / 'texts' 17 | 18 | def single_process(text_path: Path): 19 | try: 20 | texts = text_path.read_text().strip('\n').split('\n') 21 | target_texts = [] 22 | for text in texts: 23 | target_texts.append( 24 | text.strip(' \n,\t').replace( 25 | 'his/her', 'his').replace('him/her', 'him').replace('he/she', 'he').replace( 26 | 'counter-clockwise', 'counterclockwise').replace('counter clockwise', 'counterclockwise').replace( 27 | 'anti-clockwise', 'counterclockwise').replace('anti clockwise', 'counterclockwise') 28 | ) 29 | text_path.write_text('\n'.join(target_texts)) 30 | except ValueError as e: 31 | print(e) 32 | 33 | text_path_list = list(text_dir.glob('*.txt')) 34 | with PPE() as ppe: 35 | list(ppe.map(single_process, text_path_list)) 36 | -------------------------------------------------------------------------------- /data_preprocessing/0.5-visualize_joints3d_22.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | import numpy as np 4 | import os 5 | import pickle 6 | import random 7 | from pathlib import Path 8 | 9 | sys.path.append(os.getcwd()) 10 | sys.path.append(os.getcwd() + '/../') 11 | from src.utils.plot import animate_multiple_joints3d_22 12 | 13 | ds = 'interx' 14 | joints_dirs = list(Path(f'~/data/data/motion/{ds}/joints3d_22').expanduser().glob('*')) 15 | text_dir = Path(f'~/data/data/motion/{ds}/texts').expanduser() 16 | 17 | # %% 18 | for i in range(1): 19 | p = random.choice(joints_dirs) 20 | text = (text_dir / f'{p.stem}.txt').read_text().strip().split('\n') 21 | 22 | if 'inter' in ds.lower(): 23 | dual_person = pickle.load(p.open('rb')) 24 | action, reaction, naction = dual_person['action'], dual_person['reaction'], dual_person['naction'] 25 | # animate_multiple_joints3d_22([action, reaction, naction], ['r', 'g', 'b'], title=text[0], file_path=f'temp_{ds}_{i}.mp4', show_axis=True) 26 | animate_multiple_joints3d_22([action, reaction], ['r', 'g'], title=text[0], file_path=f'temp_{ds}_{i}.mp4', show_axis=True) 27 | else: 28 | motion = pickle.load(p.open('rb'))['reaction'] 29 | animate_multiple_joints3d_22([motion], ['b'], title=text[0], file_path=f'temp_{ds}_{p.stem}.mp4', show_axis=True) 30 | 31 | # %% 
32 | -------------------------------------------------------------------------------- /data_preprocessing/7-calculate_xzr_range.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | import os 4 | import pickle 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | tgt_motion = 'intergen_262' 12 | dataset = 'interx' 13 | 14 | data_root_dir = Path(f'~/data/data/motion/{dataset}').expanduser() 15 | src_data_dir = data_root_dir / 'joints3d_22' 16 | 17 | x, z, r = [], [] , [] 18 | for p in src_data_dir.glob('*.pkl'): 19 | m = pickle.load(p.open('rb')) 20 | x.append(m['action_x']) 21 | z.append(m['action_z']) 22 | r.append(m['action_r']) 23 | x.append(m['reaction_x']) 24 | z.append(m['reaction_z']) 25 | r.append(m['reaction_r']) 26 | 27 | x = np.concatenate(x) 28 | z = np.concatenate(z) 29 | r = np.concatenate(r) 30 | 31 | #%% 32 | eps = 1e-4 33 | print(f'x: [{min(x) - eps}, {max(x) + eps}]') 34 | print(f'z: [{min(z) - eps}, {max(z) + eps}]') 35 | print(f'r: [{min(r) - eps}, {max(r) + eps}]') 36 | # %% 37 | def visualize_distribution(data, num_bins=100): 38 | data_min = np.min(data) 39 | data_max = np.max(data) 40 | bin_width = (data_max - data_min) / num_bins 41 | bins = np.arange(data_min, data_max + bin_width, bin_width) 42 | bin_indices = np.digitize(data, bins) 43 | plt.figure(figsize=(10, 6)) 44 | plt.hist(data, bins=bins, edgecolor='black') 45 | plt.xticks(bins) 46 | plt.show() 47 | 48 | #%% 49 | visualize_distribution(x, 100) 50 | visualize_distribution(z, 100) 51 | visualize_distribution(r, 100) 52 | # %% 53 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/models/model_components.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | 24 | from torch import nn 25 | 26 | class View(nn.Module): 27 | def __init__(self, *args): 28 | super(View, self).__init__() 29 | self.shape = args 30 | self._name = 'reshape' 31 | 32 | def forward(self, x): 33 | return x.view(self.shape) 34 | 35 | class BatchFlatten(nn.Module): 36 | def __init__(self): 37 | super(BatchFlatten, self).__init__() 38 | self._name = 'batch_flatten' 39 | 40 | def forward(self, x): 41 | return x.view(x.shape[0], -1) -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/train/README.md: -------------------------------------------------------------------------------- 1 | # Train VPoser from Scratch 2 | To train your own VPoser with new configuration duplicate the provided **V02_05** folder while setting a new experiment ID 3 | and change the settings as you desire. 4 | First you would need to download the 5 | [AMASS](https://amass.is.tue.mpg.de/) dataset, then following the [data preparation tutorial](../data/README.md) 6 | prepare the data for training. 7 | Following is a code snippet for training that can be found in the [example training experiment](https://github.com/nghorbani/human_body_prior/blob/master/src/human_body_prior/train/V02_05/V02_05.py): 8 | 9 | ```python 10 | import glob 11 | import os.path as osp 12 | 13 | from human_body_prior.tools.configurations import load_config 14 | from human_body_prior.train.vposer_trainer import train_vposer_once 15 | 16 | def main(): 17 | expr_id = 'V02_05' 18 | 19 | default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0] 20 | 21 | vp_ps = load_config(default_ps_fname) 22 | 23 | vp_ps.train_parms.batch_size = 128 24 | 25 | vp_ps.general.expr_id = expr_id 26 | 27 | total_jobs = [] 28 | total_jobs.append(vp_ps.toDict().copy()) 29 | 30 | print('#training_jobs to be done: {}'.format(len(total_jobs))) 31 | if len(total_jobs) == 0: 32 | print('No jobs to be done') 33 | return 34 | 35 | for job in total_jobs: 36 | train_vposer_once(job) 37 | ``` 38 | The above code uses yaml configuration files to handle experiment settings. 39 | It loads the default settings in *.yaml* and overloads it with your new args. 40 | 41 | The training code, will dump a log file along with tensorboard readable events file. -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/configurations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 
10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | from dotmap import DotMap 24 | import os 25 | import yaml 26 | 27 | def load_config(default_ps_fname=None, **kwargs): 28 | if isinstance(default_ps_fname, str): 29 | assert os.path.exists(default_ps_fname), FileNotFoundError(default_ps_fname) 30 | assert default_ps_fname.lower().endswith('.yaml'), NotImplementedError('Only .yaml files are accepted.') 31 | default_ps = yaml.safe_load(open(default_ps_fname, 'r')) 32 | else: 33 | default_ps = {} 34 | 35 | default_ps.update(kwargs) 36 | 37 | return DotMap(default_ps, _dynamic=False) 38 | 39 | def dump_config(data, fname): 40 | ''' 41 | dump current configuration to an ini file 42 | :param fname: 43 | :return: 44 | ''' 45 | with open(fname, 'w') as file: 46 | yaml.dump(data.toDict(), file) 47 | return fname 48 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/train/V02_05/V02_05.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | 24 | import glob 25 | import os.path as osp 26 | 27 | from human_body_prior.tools.configurations import load_config 28 | from human_body_prior.train.vposer_trainer import train_vposer_once 29 | 30 | def main(): 31 | expr_id = 'V02_05' 32 | 33 | default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0] 34 | 35 | vp_ps = load_config(default_ps_fname) 36 | 37 | vp_ps.train_parms.batch_size = 128 38 | 39 | vp_ps.general.expr_id = expr_id 40 | 41 | total_jobs = [] 42 | total_jobs.append(vp_ps.toDict().copy()) 43 | 44 | print('#training_jobs to be done: {}'.format(len(total_jobs))) 45 | if len(total_jobs) == 0: 46 | print('No jobs to be done') 47 | return 48 | 49 | for job in total_jobs: 50 | train_vposer_once(job) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/train/V02_05/V02_05.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | body_model: 3 | gender: neutral 4 | bm_fname: ../../../../support_data/dowloads/models/smplx/neutral/model.npz 5 | 6 | general: 7 | verbosity: 0 8 | expr_id: 9 | dataset_id: V02_03 #SMPLx neutral 10 | rnd_seed: 100 11 | work_basedir: ../../../../support_data/training/training_experiments 12 | dataset_basedir: ../../../../support_data/training/data 13 | 14 | logging: 15 | expr_msg: 16 | num_bodies_to_display: 25 17 | work_dir: 18 | dataset_dir: 19 | render_during_training: False 20 | best_model_fname: 21 | 22 | train_parms: 23 | batch_size: 24 | num_epochs: 100 25 | restore_optimizer: False 26 | gen_optimizer: 27 | type: Adam 28 | args: 29 | lr: 0.001 30 | weight_decay: 0.00001 31 | lr_scheduler: 32 | type: ReduceLROnPlateau 33 | args: 34 | # metrics: val_loss 35 | verbose: true 36 | patience: 5 37 | early_stopping: 38 | monitor: val_loss 39 | min_delta: 0.0 40 | patience: 10 41 | verbose: True 42 | mode: min 43 | keep_extra_loss_terms_until_epoch: 15 44 | loss_weights: 45 | loss_kl_wt: 0.005 46 | loss_rec_wt: 4 47 | loss_matrot_wt: 2 48 | loss_jtr_wt: 2 49 | 50 | 51 | data_parms: 52 | num_workers: 5 # Used for dataloaders 53 | amass_dir: support_data/dowloads/amass/smplx_neutral 54 | num_timeseq_frames: 1 55 | amass_splits: 56 | vald: 57 | # - HumanEva 58 | # - MPI_HDM05 59 | # - SFU 60 | # - MPI_mosh 61 | - BMLrub_vald 62 | train: 63 | - CMU 64 | - BMLrub_train 65 | # - MPI_Limits 66 | # - TotalCapture 67 | # - Eyes_Japan_Dataset 68 | # - KIT 69 | # - BMLrub 70 | # - EKUT 71 | # - TCD_handMocap 72 | # - ACCAD 73 | # - BMLmovi 74 | test: 75 | - BMLrub_test 76 | # - Transitions_mocap 77 | # - SSM_synced 78 | # - DFaust_67 79 | 80 | 81 | model_params: 82 | num_neurons : 512 83 | latentD : 32 84 | 85 | -------------------------------------------------------------------------------- /src/utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pickle 4 | from pathlib import Path 5 | import lightning.pytorch as pl 6 | 7 | 8 | class JsonLogger: 9 | def __init__(self, pl_class: pl.LightningModule): 10 | try: 11 | save_dir = pl_class.logger.log_dir 12 | except: 13 | save_dir = None 14 | 
15 | if save_dir != None: 16 | self.log_path = Path(save_dir) / 'outputs.json' 17 | else: 18 | self.log_path = Path('temp_debug_log.json') 19 | 20 | def log(self, message: dict): 21 | json_message = json.dumps(message, indent=2) 22 | with self.log_path.open('a') as f: 23 | f.write(json_message + '\n') 24 | 25 | 26 | class PickleLogger: 27 | def __init__(self, pl_class: pl.LightningModule, log_dir=None): 28 | if log_dir: 29 | self.log_dir = Path(log_dir).expanduser() 30 | else: 31 | try: 32 | exp_dir = pl_class.logger.log_dir 33 | except: 34 | self.log_dir = Path('outputs') 35 | else: 36 | self.log_dir = Path(exp_dir) / 'outputs' 37 | 38 | self.log_dir.mkdir(exist_ok=True) 39 | 40 | def log(self, data: dict, file_name: str): 41 | save_path = self.log_dir / file_name 42 | with save_path.open('wb') as f: 43 | pickle.dump(data, f) 44 | 45 | 46 | def setup_logger(name, log_file=None): 47 | logger = logging.getLogger(name) 48 | logger.setLevel(logging.INFO) 49 | 50 | console_handler = logging.StreamHandler() 51 | console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 52 | 53 | if log_file: 54 | file_handler = logging.FileHandler(log_file) 55 | file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 56 | 57 | logger.addHandler(console_handler) 58 | if log_file: 59 | logger.addHandler(file_handler) 60 | 61 | return logger 62 | -------------------------------------------------------------------------------- /third_party/HumanML3D/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /src/metrics/nlg.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | from typing import List, Dict 4 | from nlgmetricverse import load_metric, NLGMetricverse 5 | from nlgmetricverse.metrics import Bertscore 6 | 7 | 8 | 9 | class NLGEvaluator: 10 | def __init__(self, device='cpu'): 11 | # Initialize ROUGE, CIDEr, and BERT scorers 12 | 
self.nlg_metricverse = NLGMetricverse(metrics=[ 13 | load_metric("bleu", resulting_name="bleu_1", compute_kwargs={"max_order": 1}), 14 | load_metric("bleu", resulting_name="bleu_4", compute_kwargs={"max_order": 4}), 15 | # load_metric("bertscore", compute_kwargs={'device': device, 'idf': True}), 16 | # load_metric('meteor'), 17 | load_metric("rouge"), 18 | load_metric("cider"), 19 | # load_metric('recall') 20 | ]) 21 | 22 | def evaluate(self, pred_sentences: List[str], reference_sentences: List[List[str]]): 23 | metricverse_results = self.nlg_metricverse(predictions=pred_sentences, references=reference_sentences) 24 | log_dict = { 25 | 'nlg/bleu_1': metricverse_results['bleu_1']['score'], 26 | 'nlg/bleu_4': metricverse_results['bleu_4']['score'], 27 | 'nlg/rouge_1': metricverse_results['rouge']['rouge1'], 28 | 'nlg/rouge_2': metricverse_results['rouge']['rouge2'], 29 | 'nlg/rouge_L': metricverse_results['rouge']['rougeL'], 30 | 'nlg/cider': metricverse_results['cider']['score'], 31 | # 'nlg/bertscore': metricverse_results['bertscore']['score'], 32 | # 'nlg/bertscore_p': np.mean(metricverse_results['bertscore']['precision']), 33 | # 'nlg/bertscore_r': np.mean(metricverse_results['bertscore']['recall']), 34 | # 'nlg/bertscore_f1': np.mean(metricverse_results['bertscore']['f1']), 35 | # 'nlg/meteor': metricverse_results['meteor']['score'], 36 | # 'nlg/recall': metricverse_results['recall']['score'],  # disabled: the 'recall' metric is not loaded above 37 | } 38 | log_dict['nlg/scores_sum'] = sum(log_dict.values()) 39 | return log_dict 40 | 41 | 42 | #%% 43 | if __name__ == '__main__': 44 | m = NLGEvaluator() 45 | pred = ['a person waves left hand', 'a person is dancing Waltz', 'i love python', 'world peace'] * 8 46 | src = [['the human is reaching out his left hand', 'hello world']] * 32 47 | res = m.evaluate(pred, src) 48 | 49 | for k, v in res.items(): 50 | print(f'{k}: {v}') 51 | # %% 52 | -------------------------------------------------------------------------------- /data_preprocessing/2-mirror_joints3d_22_and_text.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | import numpy as np 4 | import os 5 | import pickle 6 | import random 7 | from pathlib import Path 8 | from concurrent.futures import ProcessPoolExecutor as PPE 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | from data_preprocessing.utils import mirror_joints3d_22, mirror_text 13 | 14 | 15 | ds = 'interx' 16 | data_dir = Path(f'~/data/data/motion/{ds}').expanduser() 17 | joints_dir = data_dir / 'joints3d_22' 18 | text_dir = data_dir / 'texts' 19 | splits_dir = data_dir / 'splits' 20 | 21 | process_text = True 22 | 23 | 24 | def single_process(file_id): 25 | try: 26 | motion_path = joints_dir / f'{file_id}.pkl' 27 | motion_dict = pickle.load(motion_path.open('rb')) 28 | x, z, r = motion_dict['action_x'], motion_dict['action_z'], motion_dict['action_r'] 29 | mirrored_action, mirrored_reaction, mirrored_naction =\ 30 | mirror_joints3d_22(motion_dict['action']), mirror_joints3d_22(motion_dict['reaction']), mirror_joints3d_22(motion_dict['naction']) 31 | if process_text: 32 | texts = (text_dir / f'{file_id}.txt').read_text().split('\n') 33 | mirrored_texts = [mirror_text(t) for t in texts] 34 | except Exception as e: 35 | print(e) 36 | return None 37 | else: 38 | with (joints_dir / f'M{file_id}.pkl').open('wb') as f: 39 | pickle.dump( 40 | obj={ 41 | 'action': mirrored_action, 42 | 'reaction': mirrored_reaction, 43 | 'naction': mirrored_naction, 44 | 'action_x': -x, 45 | 'action_z': z, 46
| 'action_r': r 47 | }, 48 | file=f 49 | ) 50 | if process_text: 51 | (text_dir / f'M{file_id}.txt').write_text('\n'.join(mirrored_texts)) 52 | return f'M{file_id}' 53 | 54 | # %% 55 | raw_train_ids = (splits_dir / 'train.txt').read_text().strip('\n').split('\n') 56 | 57 | with PPE() as ppe: 58 | new_train_ids = list(ppe.map(single_process, raw_train_ids)) 59 | 60 | if process_text: 61 | with (splits_dir / 'train.txt').open('a') as f: 62 | f.writelines([f'\n{t}' for t in new_train_ids if t is not None]) 63 | 64 | with (splits_dir / 'all.txt').open('a') as f: 65 | f.writelines([f'\n{t}' for t in new_train_ids if t is not None]) 66 | -------------------------------------------------------------------------------- /data_preprocessing/3-prepare_motions.py: -------------------------------------------------------------------------------- 1 | #%% joints3d_22 to intergen_262 2 | import sys 3 | import os 4 | import pickle 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | from concurrent.futures import ProcessPoolExecutor as PPE 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | from src.utils.motion_representation_converter import MotionRepresentationConverter 13 | 14 | motion_map = { 15 | 'joints3d_22': 'j3d', 16 | 'joints12d_22': 'j12d', 17 | 'intergen_262': 'i262', 18 | 'humanml3d_263': 'h263' 19 | } 20 | 21 | 22 | tgt_motion = 'intergen_262' 23 | dataset = 'interx' 24 | 25 | data_root_dir = Path(f'~/data/data/motion/{dataset}').expanduser() 26 | src_data_dir = data_root_dir / 'joints3d_22' 27 | save_dir = data_root_dir / tgt_motion 28 | save_dir.mkdir(exist_ok=True) 29 | 30 | 31 | def single_process(motion_data_path): 32 | try: 33 | mrc = MotionRepresentationConverter() 34 | if motion_data_path.name.endswith('pkl'): 35 | with motion_data_path.open('rb') as f: 36 | motion_dict = pickle.load(f) 37 | reaction = motion_dict['reaction'] 38 | tgt_reaction = mrc('j3d', motion_map[tgt_motion], reaction) 39 | res = {'reaction': tgt_reaction} 40 | if 'action' in motion_dict: 41 | action = motion_dict['action'] 42 | tgt_action = mrc('j3d', motion_map[tgt_motion], action) 43 | res['action'] = tgt_action 44 | 45 | naction = motion_dict['naction'] 46 | tgt_naction = mrc('j3d', motion_map[tgt_motion], naction) 47 | res['naction'] = tgt_naction 48 | 49 | res['action_x'] = motion_dict['action_x'] 50 | res['action_z'] = motion_dict['action_z'] 51 | res['action_r'] = motion_dict['action_r'] 52 | with (save_dir / f'{motion_data_path.stem}.pkl').open('wb') as f: 53 | pickle.dump(res, f) 54 | else: 55 | if (save_dir / f'{motion_data_path.stem}.npy').exists(): 56 | return 57 | with motion_data_path.open('rb') as f: 58 | motion = np.load(f) 59 | tgt_reaction = mrc('j3d', motion_map[tgt_motion], motion) 60 | res = tgt_reaction 61 | with (save_dir / f'{motion_data_path.stem}.npy').open('wb') as f: 62 | np.save(f, res) 63 | except ValueError as e: 64 | print(e) 65 | 66 | 67 | motion_data_path_list = list(src_data_dir.glob('*')) 68 | # single_process(motion_data_path_list[0]) 69 | with PPE() as ppe: 70 | list(ppe.map(single_process, motion_data_path_list)) 71 | -------------------------------------------------------------------------------- /data_preprocessing/4-prepare_xzr.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | import numpy as np 4 | import os 5 | import pickle 6 | from pathlib import Path 7 | from concurrent.futures import ProcessPoolExecutor as PPE 8 | 9 | sys.path.append(os.getcwd()) 10 | 
sys.path.append(os.getcwd() + '/../') 11 | 12 | 13 | ds = 'interx' 14 | data_dir = Path(f'~/data/data/motion/{ds}').expanduser() 15 | joints_dir = data_dir / 'joints3d_22' 16 | i262_dir = data_dir / 'intergen_262' 17 | splits_dir = data_dir / 'splits' 18 | 19 | face_joint_indx = [2, 1, 17, 16] 20 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 21 | 22 | 23 | def get_xzr(motion): 24 | x, z, r = [], [] , [] 25 | for m in motion: 26 | x.append(m[0, 0]) 27 | z.append(m[0, 2]) 28 | across = m[r_hip] - m[l_hip] 29 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] 30 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 31 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] 32 | r.append(np.arctan2(forward_init[0, 0], forward_init[0, 2])) 33 | return np.array(x), np.array(z), np.array(r) 34 | 35 | 36 | def single_process(file_id): 37 | try: 38 | j3d_path = joints_dir / f'{file_id}.pkl' 39 | j3d_dict = pickle.load(j3d_path.open('rb')) 40 | 41 | i262_path = i262_dir / f'{file_id}.pkl' 42 | i262_dict = pickle.load(i262_path.open('rb')) 43 | 44 | reaction_xzr = get_xzr(j3d_dict['reaction']) 45 | action_xzr = get_xzr(j3d_dict['action']) 46 | 47 | j3d_dict['reaction_x'] = reaction_xzr[0] 48 | j3d_dict['reaction_z'] = reaction_xzr[1] 49 | j3d_dict['reaction_r'] = reaction_xzr[2] 50 | j3d_dict['action_x'] = action_xzr[0] 51 | j3d_dict['action_z'] = action_xzr[1] 52 | j3d_dict['action_r'] = action_xzr[2] 53 | 54 | i262_dict['reaction_x'] = reaction_xzr[0] 55 | i262_dict['reaction_z'] = reaction_xzr[1] 56 | i262_dict['reaction_r'] = reaction_xzr[2] 57 | i262_dict['action_x'] = action_xzr[0] 58 | i262_dict['action_z'] = action_xzr[1] 59 | i262_dict['action_r'] = action_xzr[2] 60 | 61 | except Exception as e: 62 | print(e) 63 | return None 64 | else: 65 | with j3d_path.open('wb') as f: 66 | pickle.dump( 67 | obj=j3d_dict, 68 | file=f 69 | ) 70 | with i262_path.open('wb') as f: 71 | pickle.dump( 72 | obj=i262_dict, 73 | file=f 74 | ) 75 | 76 | # %% 77 | ids = (splits_dir / 'all.txt').read_text().strip('\n').split('\n') 78 | 79 | with PPE() as ppe: 80 | list(ppe.map(single_process, ids)) 81 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/body_model/rigid_object_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.12.13 23 | 24 | import numpy as np 25 | 26 | import torch 27 | import torch.nn as nn 28 | 29 | # from smplx.lbs import lbs 30 | from human_body_prior.body_model.lbs import lbs 31 | # import trimesh # dont use this package for loading meshes since it messes up the order of vertices 32 | from psbody.mesh import Mesh 33 | from human_body_prior.body_model.lbs import batch_rodrigues 34 | 35 | class RigidObjectModel(nn.Module): 36 | 37 | def __init__(self, plpath, batch_size=1, dtype=torch.float32): 38 | super(RigidObjectModel, self).__init__() 39 | 40 | trans = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) 41 | self.register_parameter('trans', nn.Parameter(trans, requires_grad=True)) 42 | 43 | root_orient = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) 44 | self.register_parameter('root_orient', nn.Parameter(root_orient, requires_grad=True)) 45 | 46 | mesh = Mesh(filename=plpath) 47 | 48 | self.rigid_v = torch.from_numpy(np.repeat(mesh.v[np.newaxis], batch_size, axis=0)).type(dtype) 49 | self.f = torch.from_numpy(mesh.f.astype(np.int32)) 50 | 51 | def forward(self, root_orient, trans): 52 | if root_orient is None: root_orient = self.root_orient 53 | if trans is None: trans = self.trans 54 | verts = torch.bmm(self.rigid_v, batch_rodrigues(root_orient)) + trans.view(-1,1,3) 55 | 56 | res = {} 57 | res['v'] = verts 58 | res['f'] = self.f 59 | 60 | class result_meta(object): pass 61 | 62 | res_class = result_meta() 63 | for k, v in res.items(): 64 | res_class.__setattr__(k, v) 65 | res = res_class 66 | 67 | return res 68 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/angle_continuous_repres.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | import torch.nn.functional as F 24 | import torch 25 | from torch import nn 26 | 27 | import numpy as np 28 | 29 | # numpy implementation of yi zhou's method 30 | def norm(v): 31 | return v/np.linalg.norm(v) 32 | 33 | def gs(M): 34 | a1 = M[:,0] 35 | a2 = M[:,1] 36 | b1 = norm(a1) 37 | b2 = norm((a2-np.dot(b1,a2)*b1)) 38 | b3 = np.cross(b1,b2) 39 | return np.vstack([b1,b2,b3]).T 40 | 41 | # input sz bszx3x2 42 | def bgs(d6s): 43 | 44 | bsz = d6s.shape[0] 45 | b1 = F.normalize(d6s[:,:,0], p=2, dim=1) 46 | a2 = d6s[:,:,1] 47 | c = torch.bmm(b1.view(bsz,1,-1),a2.view(bsz,-1,1)).view(bsz,1)*b1 48 | b2 = F.normalize(a2-c,p=2,dim=1) 49 | b3=torch.cross(b1,b2,dim=1) 50 | return torch.stack([b1,b2,b3],dim=1).permute(0,2,1) 51 | 52 | 53 | class geodesic_loss_R(nn.Module): 54 | def __init__(self, reduction='batchmean'): 55 | super(geodesic_loss_R, self).__init__() 56 | 57 | self.reduction = reduction 58 | self.eps = 1e-6 59 | 60 | # batch geodesic loss for rotation matrices 61 | def bgdR(self,m1,m2): 62 | batch = m1.shape[0] 63 | m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 64 | 65 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 66 | cos = torch.min(cos, m1.new(np.ones(batch))) 67 | cos = torch.max(cos, m1.new(np.ones(batch)) * -1) 68 | 69 | return torch.acos(cos) 70 | 71 | def forward(self, ypred, ytrue): 72 | theta = self.bgdR(ypred,ytrue) 73 | if self.reduction == 'mean': 74 | return torch.mean(theta) 75 | if self.reduction == 'batchmean': 76 | breakpoint() 77 | return torch.mean(torch.sum(theta, dim=theta.shape[1:])) 78 | 79 | else: 80 | return theta -------------------------------------------------------------------------------- /src/models/model_base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import lightning.pytorch as pl 3 | from transformers.optimization import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup 4 | 5 | from ..utils import instantiate_from_config 6 | 7 | 8 | class ModelBase(pl.LightningModule): 9 | def __init__( 10 | self, 11 | model_kwargs, 12 | training_kwargs, 13 | all_config=None, 14 | ): 15 | super().__init__() 16 | 17 | self.all_config = all_config 18 | self.training_kwargs = training_kwargs 19 | self.model_kwargs = model_kwargs 20 | self.save_hyperparameters() 21 | 22 | def configure_optimizers(self): 23 | kwargs = self.training_kwargs 24 | tuned_parameters = [p for p in self.parameters() if p.requires_grad] 25 | 26 | optimizer = instantiate_from_config(kwargs.optimizer, extra_kwargs={'params': tuned_parameters}) 27 | 28 | if kwargs.scheduler == 'cosine_schedule_with_warmup': 29 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=kwargs.warmup_steps, num_training_steps=kwargs.num_training_steps) 30 | else: 31 | scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=kwargs.warmup_steps) 32 | 33 | self.lr_scheduler = scheduler 34 | return { 35 | 'optimizer': optimizer, 36 | 'lr_scheduler': { 37 | 'scheduler': scheduler, 38 | 'interval': 'step' 39 | } 40 | } 41 | 42 | def training_step(self, batch, batch_idx): 43 | log_dict = self.get_log_dict(batch, batch_idx, 'train') 44 | 
log_dict.update(self.extra_training_step(batch=batch, batch_idx=batch_idx))
45 |         log_dict['lr'] = self.lr_scheduler.get_last_lr()[0]
46 |         self.log_dict(log_dict, sync_dist=True, prog_bar=True, batch_size=self.all_config.dataloader.batch_size)
47 |         return log_dict['train/total_loss']
48 | 
49 |     def validation_step(self, batch, batch_idx):
50 |         log_dict = self.get_log_dict(batch, batch_idx, 'val')
51 |         log_dict.update(self.extra_validation_step(batch=batch, batch_idx=batch_idx))
52 |         if 'monitor' not in log_dict.keys():
53 |             log_dict['monitor'] = -log_dict['val/total_loss']
54 |         self.log_dict(log_dict, sync_dist=True, prog_bar=True, batch_size=self.all_config.dataloader.batch_size)
55 |         return log_dict['val/total_loss']
56 | 
57 |     def test_step(self, batch, batch_idx=None) -> Dict:
58 |         raise NotImplementedError("Implement this in the child class")
59 | 
60 |     def get_log_dict(self, batch, batch_idx=None, split: str = 'train') -> Dict:  # called positionally as (batch, batch_idx, split) in training_step/validation_step
61 |         raise NotImplementedError("Implement this in the child class")
62 | 
63 |     def extra_training_step(self, batch, batch_idx=None) -> Dict:
64 |         return {}
65 | 
66 |     def extra_validation_step(self, batch, batch_idx=None) -> Dict:
67 |         return {}
68 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # [ICLR 2025] Think Then React: Towards Unconstrained Action-to-Reaction Motion Generation
2 | 
3 | Wenhui Tan, Boyuan Li, Chuhao Jin, Wenbing Huang, Xiting Wang, Ruihua Song @ RUC-GSAI
4 | 
5 | [Paper Link](https://openreview.net/pdf?id=UxzKcIZedp), [Project Link](https://think-then-react.github.io)
6 | 
7 | # Introduction
8 | ![teaser image](./assets/teaser.png)
9 | Given a human action as input, our Think-Then-React model first thinks: it generates a description of the observed action and reasons out a reaction prompt. It then reacts to the action conditioned on the result of this thinking process. TTR reacts in a real-time manner at every timestep and periodically re-thinks, updating its action description and reaction prompt as the action unfolds.
10 | 
11 | # Environment
12 | ```
13 | pip install -r requirements.txt
14 | ```
15 | 
16 | Alternatively, use your own PyTorch environment with **transformers**, **lightning**, **smplx**, **omegaconf**, and **matplotlib** installed.
17 | 
18 | # Data
19 | ### Download Dataset
20 | Download the Inter-X dataset from https://github.com/liangxuy/Inter-X and place it at ~/data/data/motion/Inter-X_Dataset.
21 | 
22 | ### Preprocessing
23 | Run all the scripts under _data_preprocessing_ sequentially, following their numeric prefixes, e.g., first run
24 | ```
25 | python data_preprocessing/0-smpl_to_joints3d_22.py
26 | ```
27 | and then run
28 | ```
29 | python data_preprocessing/1-preprocess_text.py
30 | ```
31 | 
32 | Scripts with fractional indices (e.g., 0.5 and 6.5) help you visualize and sanity-check the processed data. These intermediate checks are crucial to ensure the data is in the right format before training.
33 | 
34 | # Training
35 | 
36 | We use lightning.pytorch to train our models, OmegaConf to manage configurations, and TensorBoard to monitor training.
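Configuration values can be overridden directly from the command line in OmegaConf's `key=value` form (e.g., `stage=pretrain` in the commands below). Assuming the Lightning logger writes to the default `logs/` directory (as suggested by `.gitignore` and the checkpoint paths in `eval.py`), you can monitor training curves with:
```
tensorboard --logdir logs
```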
37 | 38 | ### 0: Evaluation Model Training 39 | ``` 40 | python train.py --model=motion_clip --dataset=motion_clip --devices=0,1,2,3 41 | ``` 42 | 43 | ### 1: Motion VQ-VAE Training 44 | ``` 45 | python train.py --model=motion_vqvae --dataset=motion_vqvae --devices=1 46 | ``` 47 | 48 | ### 2: LM Pre-Training 49 | ``` 50 | python train.py --model=lm --dataset=lm --devices=0,1,2,3,4,5,6,7 stage=pretrain 51 | ``` 52 | 53 | ### 3: LM Fine-Tuning 54 | ``` 55 | python train.py --model=lm --dataset=lm --devices=0,1,2,3,4,5,6,7 stage=finetune pretrained_ckpt=/path/to/your/lm/pretrained_ckpt 56 | ``` 57 | 58 | # Eval & Inference 59 | ``` 60 | python eval.py --ckpt_path=/path/to/your/ckpt 61 | ``` 62 | 63 | # Acknowledgement 64 | We thank [MotionGPT](https://github.com/OpenMotionLab/MotionGPT) and [HumanML3D](https://github.com/EricGuo5513/HumanML3D) for their useful code for data processing. 65 | 66 | # Citation 67 | If you use this code base in your research, please cite our paper with the following BibTex entry: 68 | ```bibtex 69 | @inproceedings{ 70 | tan2025think, 71 | title={Think Then React: Towards Unconstrained Action-to-Reaction Motion Generation}, 72 | author={Wenhui Tan and Boyuan Li and Chuhao Jin and Wenbing Huang and Xiting Wang and Ruihua Song}, 73 | booktitle={The Thirteenth International Conference on Learning Representations}, 74 | year={2025}, 75 | url={https://openreview.net/forum?id=UxzKcIZedp} 76 | } 77 | ``` 78 | 79 | # Licenses 80 | This project is licensed under the MIT LICENSE - see the [LICENSE](LICENSE) file for details 81 | 82 | Note that our code depends on other libraries, including SMPL, SMPL-X, PyTorch3D, and uses datasets which each have their own respective licenses that must also be followed. 83 | -------------------------------------------------------------------------------- /src/metrics/motion_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import Levenshtein 4 | 5 | from ..utils.utils import get_model_and_config_from_ckpt_path 6 | from .common import calculate_fid, euclidean_distance_matrix, calculate_top_k, calculate_diversity 7 | 8 | 9 | class MotionGenerationEvaluator(torch.nn.Module): 10 | def __init__(self, ckpt_path: str, device='cpu'): 11 | super().__init__() 12 | model, _ = get_model_and_config_from_ckpt_path(ckpt_path) 13 | self.model = model.to(device) 14 | for p in self.model.parameters(): 15 | p.requires_grad_(False) 16 | 17 | def calculate_fid(self, gt_embeddings, pred_embeddings): 18 | return { 19 | 'mg/fid': calculate_fid(gt_embeddings, pred_embeddings) 20 | } 21 | 22 | def calculate_div(self, motion_embeddings, diversity_times): 23 | return { 24 | 'mg/div': calculate_diversity(motion_embeddings, diversity_times) 25 | } 26 | 27 | def calculate_ranking_and_mm_dist(self, text_embeddings, motion_embeddings): 28 | res = {} 29 | gt_dist_mat = euclidean_distance_matrix(text_embeddings, motion_embeddings) 30 | mm_dist = gt_dist_mat.trace() / text_embeddings.shape[0] 31 | res['mg/mm_dist'] = mm_dist 32 | 33 | argsmax = np.argsort(gt_dist_mat, axis=1) 34 | top_k_mat = calculate_top_k(argsmax, top_k=3) 35 | r_prec = top_k_mat.sum(axis=0) / text_embeddings.shape[0] 36 | 37 | for i in range(3): 38 | res[f'mg/top {i + 1}'] = r_prec[i] 39 | 40 | return res 41 | 42 | def calculate_acc(self, logits, labels): 43 | log_dict = dict() 44 | acc = (logits.argmax(-1) == labels).sum() / labels.shape[0] 45 | log_dict['mg/acc_1'] = acc.item() 46 | 47 | _, top5_preds = 
torch.topk(logits, 5, dim=1) 48 | top5_correct = (labels.unsqueeze(1) == top5_preds).any(dim=1).float() 49 | top5_accuracy = top5_correct.sum() / labels.shape[0] 50 | log_dict['mg/acc_5'] = top5_accuracy.item() 51 | return log_dict 52 | 53 | # def calculate_edit_distance(self, gt_chars_list, pred_chars_list, clustering=True): 54 | # eds = [] 55 | # for g, p in zip(gt_chars_list, pred_chars_list): 56 | # if clustering: 57 | # eds.append(Levenshtein.distance(chars_clustering(g), chars_clustering(p)) / len(g)) 58 | # else: 59 | # eds.append(Levenshtein.distance(g, p) / len(g)) 60 | # return {'mg/ed': np.mean(eds)} 61 | 62 | def evaluate(self, gt_action, gt_reaction, pred_reaction, boolean_mask, text_list, labels=None): 63 | gt_motion_embeddings = self.model.encode_motion(reaction=gt_reaction, action=gt_action, boolean_mask=boolean_mask).cpu().numpy() 64 | pred_motion_embeddings = self.model.encode_motion(reaction=pred_reaction, action=gt_action, boolean_mask=boolean_mask) 65 | 66 | if labels is not None: 67 | logits = self.model.motion_cls_head(pred_motion_embeddings) 68 | 69 | pred_motion_embeddings = pred_motion_embeddings.cpu().numpy() 70 | text_embeddings = self.model.encode_text(text_list).cpu().numpy() 71 | 72 | results = {} 73 | results.update(self.calculate_ranking_and_mm_dist(text_embeddings, pred_motion_embeddings)) 74 | results.update(self.calculate_div(pred_motion_embeddings, pred_motion_embeddings.shape[0] - 1)) 75 | results.update(self.calculate_fid(gt_motion_embeddings, pred_motion_embeddings)) 76 | 77 | if labels is not None: 78 | results.update(self.calculate_acc(logits, labels)) 79 | 80 | return results 81 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Dict, Any 3 | import copy 4 | import importlib 5 | from pathlib import Path 6 | from datetime import datetime 7 | from omegaconf import OmegaConf 8 | from omegaconf.dictconfig import DictConfig 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def get_timestamp(): 14 | return datetime.now().strftime('%Y%m%d-%H%M%S') 15 | 16 | 17 | def get_obj_from_str(string, reload=False): 18 | module, cls = string.rsplit(".", 1) 19 | if reload: 20 | module_imp = importlib.import_module(module) 21 | importlib.reload(module_imp) 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | 24 | 25 | def instantiate_from_config(config, extra_kwargs=dict()): 26 | config_dict = dict(config) 27 | if not "target" in config_dict: 28 | raise ValueError(f'target not found in {config}') 29 | 30 | target_kwargs = copy.deepcopy(config_dict) 31 | target_kwargs.pop('target') 32 | 33 | for k, v in target_kwargs.items(): 34 | if isinstance(v, DictConfig) and 'target' in v.keys(): 35 | target_kwargs[k] = instantiate_from_config(v) 36 | target_kwargs.update(extra_kwargs) 37 | 38 | return get_obj_from_str(config_dict["target"])(**target_kwargs) 39 | 40 | 41 | def dict_apply(x, func): 42 | result = dict() 43 | for key, value in x.items(): 44 | if isinstance(value, dict): 45 | result[key] = dict_apply(value, func) 46 | else: 47 | result[key] = func(value) 48 | return result 49 | 50 | 51 | def dict_to_device(d: Dict[str, Any], device): 52 | for k, v in d.items(): 53 | if isinstance(v, torch.Tensor): 54 | d[k] = v.to(device=device) 55 | return d 56 | 57 | 58 | def list_subdirs(path: Path): 59 | return [d for d in path.glob('*') if not d.is_file()] 60 | 61 | 62 | def 
is_debug_mode(): 63 | return hasattr(sys, 'gettrace') and sys.gettrace() is not None 64 | 65 | 66 | def get_clones(module, N): 67 | return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 68 | 69 | 70 | def pad(data: torch.Tensor, length: int, dim: int, value: Any = 0, right_side_padding=True, get_boolean_mask=True): 71 | raw_shape = data.shape 72 | 73 | if get_boolean_mask: 74 | boolean_mask = torch.ones(length, dtype=torch.bool) 75 | boolean_mask[:raw_shape[dim]] = False 76 | if raw_shape[dim] == length: 77 | return data, boolean_mask, raw_shape[dim] 78 | else: 79 | boolean_mask = None 80 | 81 | padding_shape = list(raw_shape) 82 | padding_shape[dim] = length - raw_shape[dim] 83 | paddings = torch.ones(size=padding_shape, device=data.device, dtype=data.dtype) * value 84 | 85 | if right_side_padding: 86 | return torch.cat([data, paddings], dim=dim), boolean_mask, raw_shape[dim] 87 | else: 88 | return torch.cat([paddings, data], dim=dim), boolean_mask, raw_shape[dim] 89 | 90 | 91 | def get_metric_statistics(values, replication_times): 92 | mean = np.mean(values, axis=0) 93 | std = np.std(values, axis=0) 94 | conf_interval = 1.96 * std / np.sqrt(replication_times) 95 | return mean, conf_interval 96 | 97 | 98 | def get_model_and_config_from_ckpt_path(ckpt_path: str, strict=False): 99 | ckpt_path = Path(ckpt_path) 100 | log_dir = ckpt_path.parent.parent 101 | config = OmegaConf.load(log_dir / 'hparams.yaml').all_config 102 | 103 | model_cls = get_obj_from_str(config.model.target) 104 | model = model_cls.load_from_checkpoint(str(ckpt_path), map_location='cpu', strict=strict).eval() 105 | 106 | return model, config 107 | -------------------------------------------------------------------------------- /src/utils/rotation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is actually never accessed, but might be useful in the future 3 | """ 4 | #%% 5 | from typing import Union 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from scipy.spatial.transform import Rotation 10 | 11 | 12 | def dtype_maintaining_wrapper(func): 13 | """ 14 | Maintain the data type of input and output (torch or numpy) for functions in numpy space 15 | """ 16 | def wrapper(rotation: Union[np.ndarray, torch.Tensor], **kwargs): 17 | torch_input = False 18 | if isinstance(rotation, torch.Tensor): 19 | torch_input = True 20 | rotation = rotation.cpu().numpy() 21 | 22 | result = func(rotation, **kwargs) 23 | 24 | if torch_input: 25 | result = torch.from_numpy(result) 26 | 27 | return result 28 | return wrapper 29 | 30 | 31 | def multi_dim_rotmat_input_wrapper(func): 32 | def wrapper(rotation: np.ndarray, **kwargs): 33 | raw_shape = rotation.shape 34 | if len(raw_shape) <= 3: # ([bs,] 3, 3) 35 | result = func(rotation, **kwargs) 36 | else: 37 | result = func(rotation.reshape(-1, 3, 3), **kwargs).reshape(*raw_shape[: -2], -1) 38 | return result 39 | return wrapper 40 | 41 | 42 | class RotationHelper: 43 | 44 | @staticmethod 45 | @dtype_maintaining_wrapper 46 | def quat_to_rotmat(q: np.ndarray): 47 | r = Rotation.from_quat(q) 48 | return r.as_matrix() 49 | 50 | @staticmethod 51 | @dtype_maintaining_wrapper 52 | @multi_dim_rotmat_input_wrapper 53 | def rotmat_to_quat(rotmat: np.ndarray): 54 | r = Rotation.from_matrix(rotmat) 55 | return r.as_quat() 56 | 57 | @staticmethod 58 | @dtype_maintaining_wrapper 59 | def axis_angle_to_rotmat(aa: np.ndarray): 60 | r = Rotation.from_rotvec(aa) 61 | return r.as_matrix() 62 | 63 | @staticmethod 
64 | @dtype_maintaining_wrapper 65 | @multi_dim_rotmat_input_wrapper 66 | def rotmat_to_axis_angle(rotmat: np.ndarray): 67 | r = Rotation.from_matrix(rotmat) 68 | return r.as_rotvec() 69 | 70 | @staticmethod 71 | @dtype_maintaining_wrapper 72 | def euler_angle_to_rotmat(euler: np.ndarray, euler_format='xyz'): 73 | r = Rotation.from_euler(euler_format, euler) 74 | return r.as_matrix() 75 | 76 | @staticmethod 77 | @dtype_maintaining_wrapper 78 | @multi_dim_rotmat_input_wrapper 79 | def rotmat_to_euler_angle(rotmat: np.ndarray, euler_format='xyz'): 80 | r = Rotation.from_matrix(rotmat) 81 | return r.as_euler(euler_format) 82 | 83 | @staticmethod 84 | @dtype_maintaining_wrapper 85 | def sixd_to_rotmat(sixd: np.ndarray): 86 | rotmat = np.zeros((sixd.shape[0], 3, 3)) 87 | rotmat[:, :, :2] = sixd.reshape(-1, 3, 2) 88 | rotmat[:, :, 2] = np.cross(rotmat[:, :, 0], rotmat[:, :, 1], axis=-1) 89 | return rotmat 90 | 91 | @staticmethod 92 | @dtype_maintaining_wrapper 93 | @multi_dim_rotmat_input_wrapper 94 | def rotmat_to_6d(rotmat: np.ndarray): 95 | sixd: np.ndarray = rotmat[None, :, 0:2] 96 | return sixd.reshape(-1, 6) 97 | 98 | #%% 99 | if __name__ == '__main__': 100 | # euler = np.array([0, 1, 2]) 101 | euler = torch.tensor([[0, 1, 2]]) 102 | rotmat = RotationHelper.euler_angle_to_rotmat(euler) 103 | euler1 = RotationHelper.rotmat_to_euler_angle(rotmat) 104 | 105 | aa = RotationHelper.rotmat_to_axis_angle(rotmat) 106 | quat = RotationHelper.rotmat_to_quat(rotmat) 107 | sixd = RotationHelper.rotmat_to_6d(rotmat) 108 | rot_t = RotationHelper.sixd_to_rotmat(sixd) 109 | rot1 = RotationHelper.quat_to_rotmat(quat) 110 | # %% 111 | -------------------------------------------------------------------------------- /data_preprocessing/6.5-visualize_tokens.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | import random 5 | import tqdm 6 | import numpy as np 7 | from pathlib import Path 8 | import torch 9 | import argparse 10 | 11 | sys.path.append(os.getcwd()) 12 | sys.path.append(os.getcwd() + '/../') 13 | from src.utils.normalizer import TorchNormalizer 14 | from src.utils.motion_representation_converter import MotionRepresentationConverter 15 | from src.utils.utils import get_model_and_config_from_ckpt_path 16 | from src.utils.plot import animate_multiple_joints3d_22 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument( 23 | '--ckpt_path', 24 | default='./path/to/vqvae.ckpt' 25 | ) 26 | 27 | parser.add_argument( 28 | '--dataset', 29 | default='interx' 30 | ) 31 | 32 | parser.add_argument( 33 | '--n_samples', 34 | default=20 35 | ) 36 | 37 | parser.add_argument( 38 | '--device', 39 | type=str, 40 | default=1 41 | ) 42 | 43 | args = parser.parse_args() 44 | 45 | args.device = torch.device(f'cuda:{args.device}') 46 | return args 47 | 48 | 49 | @torch.no_grad() 50 | def main(): 51 | args = get_args() 52 | model, model_config = get_model_and_config_from_ckpt_path(args.ckpt_path) 53 | model = model.to(args.device) 54 | 55 | mrc = MotionRepresentationConverter() 56 | 57 | data_dir = Path(f'~/data/data/motion/{args.dataset}').expanduser() 58 | src_dir = data_dir / model_config.model.model_kwargs.motion_representation 59 | tgt_dir = data_dir / args.ckpt_path.replace('/', '__slash__') 60 | 61 | normalizer_dict = pickle.load((data_dir / 'normalizers' / f'{model_config.model.model_kwargs.motion_representation}.pkl').open('rb')) 62 | normalizer = TorchNormalizer(normalizer_dict) 63 | 64 | 
src_motion_paths = random.choices(list(src_dir.glob('*.pkl')), k=args.n_samples) 65 | 66 | for mp in tqdm.tqdm(src_motion_paths): 67 | try: 68 | with mp.open('rb') as f: 69 | gt_motion_dict = pickle.load(f) 70 | with (tgt_dir / mp.name).open('rb') as f: 71 | tokens_dict = pickle.load(f) 72 | 73 | gt_reaction = gt_motion_dict['reaction'] 74 | gt_naction = gt_motion_dict['naction'] 75 | pred_reaction_tokens = tokens_dict['reaction'] 76 | pred_naction_tokens = tokens_dict['naction'] 77 | 78 | pred_reaction = normalizer.denormalize( 79 | model.decode(pred_reaction_tokens.unsqueeze(0).to(args.device)), 80 | key='all_motion' 81 | ).squeeze().cpu().numpy() 82 | pred_naction = normalizer.denormalize( 83 | model.decode(pred_naction_tokens.unsqueeze(0).to(args.device)), 84 | key='all_motion' 85 | ).squeeze().cpu().numpy() 86 | 87 | animate_multiple_joints3d_22( 88 | motions=[ 89 | mrc.convert('i262', 'j3d', gt_reaction[:pred_reaction.shape[0], ...]), 90 | mrc.convert('i262', 'j3d', pred_reaction) 91 | ], 92 | colors=['r', 'g'], 93 | title='reaction', 94 | file_path=f'temp_reaction_{mp.stem}.mp4' 95 | ) 96 | animate_multiple_joints3d_22( 97 | motions=[ 98 | mrc.convert('i262', 'j3d', gt_naction[:pred_naction.shape[0], ...]), 99 | mrc.convert('i262', 'j3d', pred_naction) 100 | ], 101 | colors=['r', 'g'], 102 | title='naction', 103 | file_path=f'temp_naction_{mp.stem}.mp4' 104 | ) 105 | 106 | except Exception as e: 107 | if isinstance(e, KeyboardInterrupt): 108 | raise 109 | print(e) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /data_preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | #%% joints3d_22 to intergen_262 2 | import sys 3 | import os 4 | import pickle 5 | import copy 6 | import numpy as np 7 | from pathlib import Path 8 | import torch 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | import third_party.HumanML3D.common.quaternion as quat 13 | from src.utils.constants import JOINTS3D_22_KINEMATIC_CHAIN 14 | 15 | 16 | face_joint_indx = [2, 1, 17, 16] 17 | 18 | 19 | def normalize_single_joints3d_22(motion): 20 | motion = copy.copy(motion) 21 | # put on floor 22 | floor_height = motion.min(axis=0).min(axis=0)[1] 23 | motion[:, :, 1] -= floor_height 24 | 25 | # reactor xz at origin 26 | re_root_pose_init = motion[0] 27 | re_root_xz_init = re_root_pose_init[0] * np.array([1, 0, 1]) 28 | motion -= re_root_xz_init 29 | 30 | # reactor face Z+ 31 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 32 | across = re_root_pose_init[r_hip] - re_root_pose_init[l_hip] 33 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] 34 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 35 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] 36 | target = np.array([[0, 0, 1]]) 37 | root_quat_init = quat.qbetween_np(forward_init, target) 38 | root_quat_init_for_all = np.ones(motion.shape[:-1] + (4,)) * root_quat_init 39 | motion = quat.qrot_np(root_quat_init_for_all, motion) 40 | 41 | return motion, (re_root_xz_init[0], re_root_xz_init[2], np.arctan2(forward_init[0, 2], forward_init[0, 0])) 42 | 43 | 44 | def denormalize_single_joints3d_22(motion, x, z, r): 45 | motion = copy.copy(motion) 46 | # recover rotation 47 | forward_init = np.array([[0, 0, 1]]) 48 | target = np.array([[np.cos(r), 0, np.sin(r)]]) 49 | root_quat_init = quat.qbetween_np(forward_init, 
target) 50 | root_quat_init_for_all = np.ones(motion.shape[:-1] + (4,)) * root_quat_init 51 | motion = quat.qrot_np(root_quat_init_for_all, motion) 52 | motion += np.array([x, 0, z]) 53 | return motion 54 | 55 | 56 | def normalize_dual_joints3d_22(action, reaction): 57 | action = copy.copy(action) 58 | reaction = copy.copy(reaction) 59 | # put on floor 60 | re_floor_height = reaction.min(axis=0).min(axis=0)[1] 61 | a_floor_height = action.min(axis=0).min(axis=0)[1] 62 | 63 | reaction[:, :, 1] -= re_floor_height 64 | action[:, :, 1] -= a_floor_height 65 | 66 | # reactor xz at origin 67 | re_root_pose_init = reaction[0] 68 | re_root_xz_init = re_root_pose_init[0] * np.array([1, 0, 1]) 69 | reaction -= re_root_xz_init 70 | action -= re_root_xz_init 71 | 72 | # reactor face Z+ 73 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 74 | across = re_root_pose_init[r_hip] - re_root_pose_init[l_hip] 75 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] 76 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 77 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] 78 | target = np.array([[0, 0, 1]]) 79 | root_quat_init = quat.qbetween_np(forward_init, target) 80 | root_quat_init_for_all = np.ones(reaction.shape[:-1] + (4,)) * root_quat_init 81 | reaction = quat.qrot_np(root_quat_init_for_all, reaction) 82 | action = quat.qrot_np(root_quat_init_for_all, action) 83 | 84 | return action, reaction 85 | 86 | 87 | def mirror_joints3d_22(motion): 88 | assert len(motion.shape) == 3 89 | mirrored = copy.copy(motion) 90 | mirrored[..., 0] *= -1 91 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21] 92 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20] 93 | tmp = mirrored[:, right_chain] 94 | mirrored[:, right_chain] = mirrored[:, left_chain] 95 | mirrored[:, left_chain] = tmp 96 | return mirrored 97 | 98 | 99 | def mirror_text(text: str): 100 | return text.replace( 101 | "left", "tmp").replace("right", "left").replace("tmp", "right").replace( 102 | "clockwise", "tmp").replace("counterclockwise", "clockwise").replace("tmp", "counterclockwise") 103 | -------------------------------------------------------------------------------- /src/utils/normalizer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class TorchNormalizer(): 7 | def __init__(self, statistics_dict: Dict[str, Dict[str, torch.Tensor]]): 8 | self.statistics_dict = statistics_dict 9 | for k, v in statistics_dict.items(): 10 | for kk, vv in v.items(): 11 | vv.requires_grad_(False) 12 | 13 | def normalize(self, data: Union[torch.Tensor, np.ndarray], key: str): 14 | if self.statistics_dict == None: 15 | return data 16 | 17 | if isinstance(data, np.ndarray): 18 | is_np_input = True 19 | data = torch.from_numpy(data) 20 | else: 21 | is_np_input = False 22 | 23 | mean = self.statistics_dict[key]['mean'].to(data.device) 24 | std = self.statistics_dict[key]['std'].to(data.device) 25 | res = (data - mean) / std 26 | res = torch.nan_to_num(res, nan=0.0) # TODO: change to fill with mean 27 | 28 | return res.cpu().numpy() if is_np_input else res 29 | 30 | def denormalize(self, data: Union[torch.Tensor, np.ndarray], key: str): 31 | if isinstance(data, np.ndarray): 32 | data = torch.from_numpy(data) 33 | is_np_input = True 34 | else: 35 | is_np_input = False 36 | 37 | mean = self.statistics_dict[key]['mean'].to(data.device) 38 | std = self.statistics_dict[key]['std'].to(data.device) 
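        # denormalize inverts the per-key z-score applied in normalize(): x = x_norm * std + mean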
39 |         res = data * std + mean
40 | 
41 |         return res.cpu().numpy() if is_np_input else res
42 | 
43 |     def norm_batch(self, batch: Dict[str, torch.Tensor], keys: List[str] = None, device=torch.device('cuda:0')):
44 |         if keys is None:
45 |             keys = list(batch.keys())
46 | 
47 |         raw_device = batch[keys[0]].device  # assume all tensor values are on the same device
48 | 
49 |         if raw_device != torch.device('cpu'):
50 |             device = raw_device
51 |         # else if raw_device is cpu, move data to cuda:0
52 | 
53 |         for k in keys:
54 |             batch[k] = self.normalize(data=batch[k].to(device), key=k).to(raw_device)
55 |         return batch
56 | 
57 |     def denorm_batch(self, batch: Dict[str, torch.Tensor], keys: List[str] = None, device=torch.device('cuda:0')):
58 |         if keys is None:
59 |             keys = list(batch.keys())
60 | 
61 |         raw_device = batch[keys[0]].device
62 | 
63 |         if raw_device != torch.device('cpu'):
64 |             device = raw_device
65 |         # else if raw_device is cpu, move data to cuda:0
66 | 
67 |         for k in keys:
68 |             batch[k] = self.denormalize(data=batch[k].to(device), key=k).to(raw_device)
69 |         return batch
70 | 
71 |     def norm_list_dict(self, data: List[Dict[str, torch.Tensor]], keys: List[str] = None, device=torch.device('cuda:0')):
72 |         if keys is None:
73 |             keys = list(data[0].keys())
74 | 
75 |         data_length = len(data)
76 |         big_batch = {k: [] for k in keys}
77 |         for k in keys:
78 |             for idx in range(data_length):
79 |                 big_batch[k].append(data[idx][k])
80 |             big_batch[k] = torch.stack(big_batch[k])
81 | 
82 |         big_batch = self.norm_batch(batch=big_batch, keys=keys, device=device)
83 | 
84 |         for idx in range(data_length):
85 |             for k in keys:
86 |                 data[idx][k] = big_batch[k][idx]
87 | 
88 |         return data
89 | 
90 |     def denorm_list_dict(self, data: List[Dict[str, torch.Tensor]], keys: List[str] = None, device=torch.device('cuda:0')):
91 |         if keys is None:
92 |             keys = list(data[0].keys())
93 | 
94 |         data_length = len(data)
95 |         big_batch = {k: [] for k in keys}
96 |         for k in keys:
97 |             for idx in range(data_length):
98 |                 big_batch[k].append(data[idx][k])
99 |             big_batch[k] = torch.stack(big_batch[k])
100 | 
101 |         big_batch = self.denorm_batch(batch=big_batch, keys=keys, device=device)
102 | 
103 |         for idx in range(data_length):
104 |             for k in keys:
105 |                 data[idx][k] = big_batch[k][idx]
106 | 
107 |         return data
108 | 
-------------------------------------------------------------------------------- /src/utils/constants.py: --------------------------------------------------------------------------------
1 | """
2 | Throw anything here that is a constant,
3 | or when you want to get rid of thinking about "where should I put this?"
4 | """ 5 | 6 | from itertools import pairwise 7 | 8 | SMPLX_KEY_SHAPE = { 9 | 'transl': (3), 'global_orient': (1, 3, 3), 'body_pose': (21, 3, 3), 'betas': (10), 'left_hand_pose': (15, 3, 3), 'right_hand_pose': (15, 3, 3), 'jaw_pose': (1, 3, 3), 'leye_pose': (1, 3, 3), 'reye_pose': (1, 3, 3), 'expression': (10) 10 | } 11 | 12 | SMPLX_ROTATION_KEYS = ('global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose', 'leye_pose', 'reye_pose') 13 | 14 | JOINTS3D_22_KINEMATIC_CHAIN = [ 15 | [0, 2, 5, 8, 11], 16 | [0, 1, 4, 7, 10], 17 | [0, 3, 6, 9, 12, 15], 18 | [9, 14, 17, 19, 21], 19 | [9, 13, 16, 18, 20] 20 | ] 21 | 22 | EDGE22_INDICES_UNDIRCTIONAL = [] 23 | for chain in JOINTS3D_22_KINEMATIC_CHAIN: 24 | chain_edges = [] 25 | for i, j in pairwise(chain): 26 | EDGE22_INDICES_UNDIRCTIONAL.append([i, j]) 27 | 28 | EDGE22_INDICES = [] 29 | for chain in JOINTS3D_22_KINEMATIC_CHAIN: 30 | chain_edges = [] 31 | for i, j in pairwise(chain): 32 | chain_edges.extend([(i, j), (j, i)]) 33 | EDGE22_INDICES.extend(chain_edges) 34 | for i in range(22): 35 | EDGE22_INDICES.append((i, i)) 36 | 37 | EDGE_INDEX_INFO = { 38 | 'joints3d_22': EDGE22_INDICES, 39 | 'joints12d_22': EDGE22_INDICES 40 | } 41 | 42 | MOTION_REPRESENTATION_INFO = { 43 | 'intergen_262': { 44 | 'feature_size': 262, 45 | 'key_to_range': { 46 | 'pos': [0, 66], 47 | 'vel': [66, 132], 48 | 'rot': [132, 258], 49 | 'foot': [258, 262] 50 | } 51 | }, 52 | 53 | 'joints3d_22': { 54 | 'feature_size': [22, 3] 55 | }, 56 | 57 | 'joints12d_22': { 58 | 'feature_size': [22, 12], 59 | 'key_to_range': { 60 | 'pos': [0, 3], 61 | 'vel': [3, 6], 62 | 'rot': [6, 12], 63 | }, 64 | }, 65 | 66 | 'foot_indices' : { 67 | 'left': [7, 10], 68 | 'right': [8, 11] 69 | }, 70 | 71 | 'tokens': { 72 | 'feature_size': 512 73 | }, 74 | 75 | 'tokens_512': { 76 | 'feature_size': 512 77 | } 78 | } 79 | 80 | TEXT_FEATURE_INFO = { 81 | 'google-bert/bert-base-uncased': { 82 | 'feature_size': 768 83 | }, 84 | 'openai/clip-vit-base-patch32': { 85 | 'feature_size': 512 86 | }, 87 | 'openai/clip-vit-large-patch14': { 88 | 'feature_size': 768 89 | }, 90 | } 91 | 92 | VALUE_RANGES = { 93 | 'x': [-4.268601281738281, 4.268601281738281], 94 | 'z': [-3.7807260585784914, 4.461718423461914], 95 | 'r': [-3.1416857620286556, 3.1416867933963655] 96 | } 97 | 98 | INTERX_LABEL_MAPPING = { 99 | 0: 'Hug', 1: 'Handshake', 2: 'Wave', 3: 'Grab', 4: 'Hit', 100 | 5: 'Kick', 6: 'Posing', 7: 'Push', 8: 'Pull', 9: 'Sit-on-leg', 101 | 10: 'Slap', 11: 'Pat-on-back', 12: 'Point-finger-at', 13: 'Walk-towards', 14: 'Knock-over', 102 | 15: 'Step-on-foot', 16: 'High-five', 17: 'Chase', 18: 'Whisper-in-ear', 19: 'Support-with-hand', 103 | 20: 'Finger-guessing', 21: 'Dance', 22: 'Link-arms', 23: 'Shoulder-to-shoulder', 24: 'Bend', 104 | 25: 'Carry-on-back', 26: 'Massage-shoulder', 27: 'Massage-leg', 28: 'Hand-wrestling', 29: 'Chat', 105 | 30: 'Pat-on-cheek', 31: 'Thumb-up', 32: 'Touch-head', 33: 'Imitate', 34: 'Kiss-on-cheek', 106 | 35: 'Help-up', 36: 'Cover-mouth', 37: 'Look-back', 38: 'Block', 39: 'Fly-kiss' 107 | } 108 | 109 | # just suppose 110 | INTERX_FAMILIARITY_MAPPING = { 111 | 1: 'stranger', 112 | 2: 'schoolmate', 113 | 3: 'friend', 114 | 4: 'lover', 115 | } 116 | 117 | INTERX_GROUP_TO_FAMILIARITY = { 118 | 1: 1, 2: 4, 3: 1, 4: 4, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 119 | 11: 1, 12: 4, 13: 4, 14: 1, 15: 1, 16: 2, 17: 1, 18: 1, 19: 3, 20: 3, 120 | 21: 3, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 4, 28: 1, 29: 4, 30: 1, 121 | 31: 2, 32: 4, 33: 4, 34: 3, 35: 4, 36: 2, 37: 1, 38: 3, 39: 1, 40: 1, 122 
| 41: 1, 42: 1, 43: 1, 44: 4, 45: 2, 46: 1, 47: 2, 48: 1, 49: 3, 50: 1, 123 | 51: 4, 52: 3, 53: 3, 54: 2, 55: 2, 56: 1, 57: 3, 58: 1, 59: 4 124 | } 125 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/model_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: Nima Ghorbani 20 | # 2018.01.02 21 | 22 | import os, glob 23 | import numpy as np 24 | from human_body_prior.tools.configurations import load_config, dump_config 25 | import os.path as osp 26 | 27 | def exprdir2model(expr_dir): 28 | 29 | if not os.path.exists(expr_dir): raise ValueError('Could not find the experiment directory: %s' % expr_dir) 30 | 31 | model_snapshots_dir = osp.join(expr_dir, 'snapshots') 32 | available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime) 33 | assert len(available_ckpts) > 0, ValueError('No checck points found at {}'.format(model_snapshots_dir)) 34 | trained_weigths_fname = available_ckpts[-1] 35 | 36 | model_ps_fname = glob.glob(osp.join('/', '/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml')) 37 | if len(model_ps_fname) == 0: 38 | model_ps_fname = glob.glob(osp.join('/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml')) 39 | 40 | model_ps_fname = model_ps_fname[0] 41 | model_ps = load_config(default_ps_fname=model_ps_fname) 42 | 43 | model_ps.logging.best_model_fname = trained_weigths_fname 44 | 45 | return model_ps, trained_weigths_fname 46 | 47 | 48 | def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False, disable_grad=True, custom_ps = None): 49 | ''' 50 | 51 | :param expr_dir: 52 | :param model_code: an imported module 53 | from supercap.train.supercap_smpl import SuperCap, then pass SuperCap to this function 54 | :param if True will load the model definition used for training, and not the one in current repository 55 | :return: 56 | ''' 57 | import importlib 58 | import torch 59 | 60 | model_ps, trained_weigths_fname = exprdir2model(expr_dir) 61 | if load_only_ps: return model_ps 62 | if custom_ps is not None: model_ps = custom_ps 63 | assert model_code is not None, ValueError('mode_code should be provided') 64 | model_instance = model_code(model_ps) 65 | if disable_grad: # i had to do this. 
torch.no_grad() couldnt achieve what i was looking for 66 | for param in model_instance.parameters(): 67 | param.requires_grad = False 68 | state_dict = torch.load(trained_weigths_fname)['state_dict'] 69 | if remove_words_in_model_weights is not None: 70 | words = '{}'.format(remove_words_in_model_weights) 71 | state_dict = {k.replace(words, '') if k.startswith(words) else k: v for k, v in state_dict.items()} 72 | 73 | ## keys that were in the model trained file and not in the current model 74 | instance_model_keys = list(model_instance.state_dict().keys()) 75 | trained_model_keys = list(state_dict.keys()) 76 | wts_in_model_not_in_file = set(instance_model_keys).difference(set(trained_model_keys)) 77 | ## keys that are in the current model not in the training weights 78 | wts_in_file_not_in_model = set(trained_model_keys).difference(set(instance_model_keys)) 79 | # assert len(wts_in_model_not_in_file) == 0, ValueError('Some model weights are not present in the pretrained file. {}'.format(wts_in_model_not_in_file)) 80 | 81 | state_dict = {k:v for k, v in state_dict.items() if k in instance_model_keys} 82 | model_instance.load_state_dict(state_dict, strict=False) # Todo fix the issues so that we can set the strict to true. The body model uses unnecessary registered buffers 83 | model_instance.eval() 84 | 85 | return model_instance, model_ps 86 | 87 | 88 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import tqdm 5 | from collections import defaultdict 6 | import lightning.pytorch as pl 7 | import numpy as np 8 | from pathlib import Path 9 | from omegaconf import OmegaConf 10 | import torch 11 | import pprint 12 | from torch.utils.data import DataLoader 13 | 14 | from src.utils import get_obj_from_str, instantiate_from_config, setup_logger, get_metric_statistics, dict_to_device 15 | from src.metrics.motion_generation import MotionGenerationEvaluator 16 | 17 | 18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 19 | logger = None 20 | 21 | 22 | @torch.no_grad() 23 | def evaluate(args, config, model, test_dataloader): 24 | nlg_generation_config = {'max_new_tokens': 200} if args.eval_nlg_only else {} 25 | final_res = defaultdict(list) 26 | final_statistics = defaultdict(list) 27 | for i in range(args.replication_times): 28 | pl.seed_everything(config.seed + i) 29 | 30 | results = defaultdict(list) 31 | for batch in tqdm.tqdm(test_dataloader): 32 | batch = dict_to_device(batch, device=model.device) 33 | step_results = model.test_step( 34 | batch, use_gt_prompt=args.use_gt_prompt, 35 | eval_nlg_only=args.eval_nlg_only, 36 | eval_nlg_action_ratio=args.eval_nlg_action_ratio, 37 | nlg_generation_config=nlg_generation_config 38 | ) 39 | for k, v in step_results.items(): 40 | results[k].append(v) 41 | 42 | for k, v in results.items(): 43 | final_res[k].append(torch.stack(v).cpu().numpy().mean()) 44 | print(final_res) 45 | 46 | for k, v in final_res.items(): 47 | mean, conf_interval = get_metric_statistics(v, replication_times=args.replication_times) 48 | final_statistics[k] = f'{mean:.3f},{conf_interval:.3f}' 49 | 50 | logger.info(f'evaluation results: {final_statistics}') 51 | 52 | 53 | def get_args_and_config(): 54 | parser = argparse.ArgumentParser() 55 | 56 | parser.add_argument( 57 | '--ckpt_path', 58 | type=str, 59 | default='logs/model_name/dataset_name/signature_trained/checkpoints/best.ckpt' 60 | ) 61 | 62 | 
parser.add_argument( 63 | '--device', 64 | type=int, 65 | default=1 66 | ) 67 | 68 | parser.add_argument( 69 | '--batch_size', 70 | type=int, 71 | default=32 72 | ) 73 | 74 | parser.add_argument( 75 | '--replication_times', 76 | type=int, 77 | default=1 78 | ) 79 | 80 | parser.add_argument( 81 | '--rethinking_interval', 82 | type=int, 83 | default=None 84 | ) 85 | 86 | parser.add_argument( 87 | '--evaluator_ckpt_path', 88 | type=str, 89 | default=None 90 | ) 91 | 92 | parser.add_argument( 93 | '--eval_nlg_only', 94 | action='store_true' 95 | ) 96 | 97 | parser.add_argument( 98 | '--eval_nlg_action_ratio', 99 | type=float, 100 | default=1 101 | ) 102 | 103 | parser.add_argument( 104 | '--use_gt_prompt', 105 | default=False, 106 | ) 107 | 108 | args = parser.parse_args() 109 | 110 | args.ckpt_path = Path(args.ckpt_path) 111 | log_dir = args.ckpt_path.parent.parent 112 | 113 | args.device = torch.device(args.device) 114 | 115 | args.log_file_path = str((log_dir / 'results.log').expanduser()) 116 | 117 | config = OmegaConf.load(log_dir / 'hparams.yaml').all_config 118 | 119 | return args, config 120 | 121 | 122 | def main(): 123 | global logger 124 | args, config = get_args_and_config() 125 | 126 | logger = setup_logger(__file__, log_file=args.log_file_path) 127 | 128 | logger.info(f'\n-----------------------------------------------\n') 129 | logger.info(f'Evaluating with ckpt: {args.ckpt_path}') 130 | logger.info(f'Evaluation config: {pprint.pformat(config)}') 131 | 132 | model_cls = get_obj_from_str(config.model.target) 133 | model = model_cls.load_from_checkpoint(str(args.ckpt_path), map_location=args.device, strict=False).eval() 134 | 135 | if p := args.evaluator_ckpt_path: 136 | model._mg_evaluator = MotionGenerationEvaluator(ckpt_path=p, device=args.device) 137 | 138 | if isinstance(args.rethinking_interval, int): 139 | model.model_kwargs.rethinking_interval = args.rethinking_interval 140 | 141 | # load val data 142 | test_dataset = instantiate_from_config(config.dataset, extra_kwargs={'split': 'test'}) 143 | model.normalizer = test_dataset.normalizer 144 | 145 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, persistent_workers=True, drop_last=True) 146 | 147 | evaluate(args=args, config=config, model=model, test_dataloader=test_dataloader) 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/rotation_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | import numpy as np 24 | 25 | from torch.nn import functional as F 26 | from human_body_prior.tools import tgm_conversion as tgm 27 | import torch 28 | 29 | def local2global_pose(local_pose, kintree): 30 | bs = local_pose.shape[0] 31 | 32 | local_pose = local_pose.view(bs, -1, 3, 3) 33 | 34 | global_pose = local_pose.clone() 35 | 36 | for jId in range(len(kintree)): 37 | parent_id = kintree[jId] 38 | if parent_id >= 0: 39 | global_pose[:, jId] = torch.matmul(global_pose[:, parent_id], global_pose[:, jId]) 40 | 41 | return global_pose 42 | 43 | def em2euler(em): 44 | ''' 45 | 46 | :param em: rotation in expo-map (3,) 47 | :return: rotation in euler angles (3,) 48 | ''' 49 | from transforms3d.euler import axangle2euler 50 | 51 | theta = np.sqrt((em ** 2).sum()) 52 | axis = em / theta 53 | return np.array(axangle2euler(axis, theta)) 54 | 55 | 56 | def euler2em(ea): 57 | ''' 58 | 59 | :param ea: rotation in euler angles (3,) 60 | :return: rotation in expo-map (3,) 61 | ''' 62 | from transforms3d.euler import euler2axangle 63 | axis, theta = euler2axangle(*ea) 64 | return np.array(axis*theta) 65 | 66 | 67 | def remove_zrot(pose): 68 | noZ = em2euler(pose[:3].copy()) 69 | noZ[2] = 0 70 | pose[:3] = euler2em(noZ).copy() 71 | return pose 72 | 73 | def matrot2aa(pose_matrot): 74 | ''' 75 | :param pose_matrot: Nx3x3 76 | :return: Nx3 77 | ''' 78 | bs = pose_matrot.size(0) 79 | homogen_matrot = F.pad(pose_matrot, [0,1]) 80 | pose = tgm.rotation_matrix_to_angle_axis(homogen_matrot) 81 | return pose 82 | 83 | def aa2matrot(pose): 84 | ''' 85 | :param Nx3 86 | :return: pose_matrot: Nx3x3 87 | ''' 88 | bs = pose.size(0) 89 | num_joints = pose.size(1)//3 90 | pose_body_matrot = tgm.angle_axis_to_rotation_matrix(pose)[:, :3, :3].contiguous()#.view(bs, num_joints*9) 91 | return pose_body_matrot 92 | 93 | def noisy_zrot(rot_in): 94 | ''' 95 | 96 | :param rot_in: np.array Nx3 rotations in axis-angle representation 97 | :return: 98 | will add a degree from a full circle to the zrotations 99 | ''' 100 | is_batched = False 101 | if rot_in.ndim == 2: is_batched = True 102 | if not is_batched: 103 | rot_in = rot_in[np.newaxis] 104 | 105 | rnd_zrot = np.random.uniform(-np.pi, np.pi) 106 | rot_out = [] 107 | for bId in range(len(rot_in)): 108 | pose_cpu = rot_in[bId] 109 | pose_euler = em2euler(pose_cpu) 110 | 111 | pose_euler[2] += rnd_zrot 112 | 113 | pose_aa = euler2em(pose_euler) 114 | rot_out.append(pose_aa.copy()) 115 | 116 | return np.array(rot_out) 117 | 118 | def rotate_points_xyz(mesh_v, Rxyz): 119 | ''' 120 | 121 | :param mesh_v: Nxnum_vx3 122 | :param Rxyz: Nx3 123 | :return: 124 | ''' 125 | 126 | mesh_v_rotated = [] 127 | 128 | for fId in range(mesh_v.shape[0]): 129 | angle = np.radians(Rxyz[fId, 0]) 130 | rx = np.array([ 131 | [1., 0., 0. ], 132 | [0., np.cos(angle), -np.sin(angle)], 133 | [0., np.sin(angle), np.cos(angle) ] 134 | ]) 135 | 136 | angle = np.radians(Rxyz[fId, 1]) 137 | ry = np.array([ 138 | [np.cos(angle), 0., np.sin(angle)], 139 | [0., 1., 0. ], 140 | [-np.sin(angle), 0., np.cos(angle)] 141 | ]) 142 | 143 | angle = np.radians(Rxyz[fId, 2]) 144 | rz = np.array([ 145 | [np.cos(angle), -np.sin(angle), 0. ], 146 | [np.sin(angle), np.cos(angle), 0. ], 147 | [0., 0., 1. 
] 148 | ]) 149 | mesh_v_rotated.append(rz.dot(ry.dot(rx.dot(mesh_v[fId].T))).T) 150 | 151 | return np.array(mesh_v_rotated) -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/models/vposer_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | 24 | import numpy as np 25 | import torch 26 | from human_body_prior.models.model_components import BatchFlatten 27 | from human_body_prior.tools.rotation_tools import matrot2aa 28 | from torch import nn 29 | from torch.nn import functional as F 30 | 31 | 32 | class ContinousRotReprDecoder(nn.Module): 33 | def __init__(self): 34 | super(ContinousRotReprDecoder, self).__init__() 35 | 36 | def forward(self, module_input): 37 | reshaped_input = module_input.view(-1, 3, 2) 38 | 39 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1) 40 | 41 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True) 42 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1) 43 | b3 = torch.cross(b1, b2, dim=1) 44 | 45 | return torch.stack([b1, b2, b3], dim=-1) 46 | 47 | 48 | class NormalDistDecoder(nn.Module): 49 | def __init__(self, num_feat_in, latentD): 50 | super(NormalDistDecoder, self).__init__() 51 | 52 | self.mu = nn.Linear(num_feat_in, latentD) 53 | self.logvar = nn.Linear(num_feat_in, latentD) 54 | 55 | def forward(self, Xout): 56 | return torch.distributions.normal.Normal(self.mu(Xout), F.softplus(self.logvar(Xout))) 57 | 58 | 59 | class VPoser(nn.Module): 60 | def __init__(self, model_ps): 61 | super(VPoser, self).__init__() 62 | 63 | num_neurons, self.latentD = model_ps.model_params.num_neurons, model_ps.model_params.latentD 64 | 65 | self.num_joints = 21 66 | n_features = self.num_joints * 3 67 | 68 | self.encoder_net = nn.Sequential( 69 | BatchFlatten(), 70 | nn.BatchNorm1d(n_features), 71 | nn.Linear(n_features, num_neurons), 72 | nn.LeakyReLU(), 73 | nn.BatchNorm1d(num_neurons), 74 | nn.Dropout(0.1), 75 | nn.Linear(num_neurons, num_neurons), 76 | nn.Linear(num_neurons, num_neurons), 77 | NormalDistDecoder(num_neurons, self.latentD) 78 | ) 79 | 80 | self.decoder_net = nn.Sequential( 81 | nn.Linear(self.latentD, num_neurons), 82 | nn.LeakyReLU(), 83 | nn.Dropout(0.1), 84 | nn.Linear(num_neurons, num_neurons), 85 | nn.LeakyReLU(), 86 | nn.Linear(num_neurons, self.num_joints * 6), 87 | ContinousRotReprDecoder(), 88 | ) 89 | 90 | def encode(self, pose_body): 91 | ''' 92 | :param Pin: 
Nx(numjoints*3) 93 | :param rep_type: 'matrot'/'aa' for matrix rotations or axis-angle 94 | :return: 95 | ''' 96 | return self.encoder_net(pose_body) 97 | 98 | def decode(self, Zin): 99 | bs = Zin.shape[0] 100 | 101 | prec = self.decoder_net(Zin) 102 | 103 | return { 104 | 'pose_body': matrot2aa(prec.view(-1, 3, 3)).view(bs, -1, 3), 105 | 'pose_body_matrot': prec.view(bs, -1, 9) 106 | } 107 | 108 | 109 | def forward(self, pose_body): 110 | ''' 111 | :param Pin: aa: Nx1xnum_jointsx3 / matrot: Nx1xnum_jointsx9 112 | :param input_type: matrot / aa for matrix rotations or axis angles 113 | :param output_type: matrot / aa 114 | :return: 115 | ''' 116 | 117 | q_z = self.encode(pose_body) 118 | q_z_sample = q_z.rsample() 119 | decode_results = self.decode(q_z_sample) 120 | decode_results.update({'poZ_body_mean': q_z.mean, 'poZ_body_std': q_z.scale, 'q_z': q_z}) 121 | return decode_results 122 | 123 | def sample_poses(self, num_poses, seed=None): 124 | np.random.seed(seed) 125 | 126 | some_weight = [a for a in self.parameters()][0] 127 | dtype = some_weight.dtype 128 | device = some_weight.device 129 | self.eval() 130 | with torch.no_grad(): 131 | Zgen = torch.tensor(np.random.normal(0., 1., size=(num_poses, self.latentD)), dtype=dtype, device=device) 132 | 133 | return self.decode(Zgen) 134 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/visualizations/training_visualization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2020.12.12 23 | 24 | def pyrenderer(imw=2048, imh=2048): 25 | 26 | from body_visualizer.mesh.mesh_viewer import MeshViewer 27 | import cv2 28 | 29 | import numpy as np 30 | import trimesh 31 | 32 | try: 33 | mv = MeshViewer(width=imw, height=imh, use_offscreen=True) 34 | except: 35 | import os 36 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 37 | os.environ['EGL_DEVICE_ID'] = os.environ['GPU_DEVICE_ORDINAL'].split(',')[0] 38 | 39 | mv = MeshViewer(width=imw, height=imh, use_offscreen=True) 40 | 41 | mv.set_cam_trans([0, -0.5, 2.]) 42 | 43 | def render_an_image(meshes): 44 | n_all = len(meshes) 45 | nc = int(np.sqrt(n_all)) 46 | 47 | out_image = np.zeros([1, 1, 1, mv.width, mv.height, 4]) 48 | 49 | scale_percent = 100./nc 50 | width = int(mv.width * scale_percent / 100) 51 | height = int(mv.height * scale_percent / 100) 52 | dim = (width, height) 53 | 54 | for rId in range(nc): 55 | for cId in range(nc): 56 | i = (nc*rId) + cId 57 | if i>len(meshes): break 58 | 59 | mesh = meshes[i] 60 | 61 | # mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(-90), (1, 0, 0))) 62 | mesh.vertices -= np.median(np.array(mesh.vertices), axis=0) 63 | mv.set_dynamic_meshes([mesh]) 64 | img = mv.render(render_wireframe=False, RGBA=True) 65 | img_resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA) 66 | 67 | out_image[0, 0, 0, (rId*width):((rId+1)*width), (cId*height):((cId+1)*height)] = cv2.cvtColor(img_resized, cv2.COLOR_BGRA2RGBA) 68 | 69 | return out_image.astype(np.uint8) 70 | 71 | return render_an_image 72 | 73 | def vposer_trainer_renderer(bm, num_bodies_to_display=5): 74 | import numpy as np 75 | import trimesh 76 | import torch 77 | 78 | from body_visualizer.tools.vis_tools import imagearray2file, colors 79 | from human_body_prior.tools.omni_tools import copy2cpu as c2c 80 | from human_body_prior.tools.omni_tools import makepath 81 | from trimesh import Trimesh as Mesh 82 | from trimesh.util import concatenate as mesh_cat 83 | 84 | renderer = pyrenderer(1024, 1024) 85 | 86 | faces = c2c(bm.f) 87 | 88 | def render_once(body_parms, body_colors=[colors['grey'], colors['brown-light']], out_fname=None): 89 | ''' 90 | 91 | :param body_parms: list of dictionaries of body parameters. 
92 | :param body_colors: list of np arrays of color rgb values 93 | :param movie_outpath: a mp4 path 94 | :return: 95 | ''' 96 | 97 | if out_fname is not None: makepath(out_fname, isfile=True) 98 | assert len(body_parms) <= len(body_colors), ValueError('Not enough colors provided for #{} body_parms'.format(len(body_parms))) 99 | 100 | bs = body_parms[0]['pose_body'].shape[0] 101 | 102 | body_ids = np.random.choice(bs, num_bodies_to_display) 103 | 104 | body_evals = [c2c(bm(root_orient=v['root_orient'].view(bs, -1) if 'root_orient' in v else torch.zeros(bs, 3).type_as(v['pose_body']), 105 | pose_body=v['pose_body'].contiguous().view(bs, -1)).v) for v in body_parms] 106 | num_verts = body_evals[0].shape[1] 107 | 108 | render_meshes = [] 109 | for bId in body_ids: 110 | concat_cur_meshes = None 111 | for body, body_color in zip(body_evals, body_colors): 112 | cur_body_mesh = Mesh(body[bId], faces, vertex_colors=np.ones([num_verts, 3]) * body_color) 113 | concat_cur_meshes = cur_body_mesh if concat_cur_meshes is None else mesh_cat(concat_cur_meshes, cur_body_mesh) 114 | render_meshes.append(concat_cur_meshes) 115 | 116 | img = renderer(render_meshes) 117 | 118 | if out_fname is not None: imagearray2file(img, out_fname, fps=10) 119 | 120 | 121 | return 122 | 123 | return render_once 124 | -------------------------------------------------------------------------------- /src/metrics/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 6 | def euclidean_distance_matrix(matrix1, matrix2): 7 | """ 8 | Params: 9 | -- matrix1: N1 x D 10 | -- matrix2: N2 x D 11 | Returns: 12 | -- dist: N1 x N2 13 | dist[i, j] == distance(matrix1[i], matrix2[j]) 14 | """ 15 | assert matrix1.shape[1] == matrix2.shape[1] 16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 19 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 20 | return dists 21 | 22 | 23 | def calculate_top_k(mat, top_k): 24 | size = mat.shape[0] 25 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 26 | bool_mat = (mat == gt_mat) 27 | correct_vec = False 28 | top_k_list = [] 29 | for i in range(top_k): 30 | # print(correct_vec, bool_mat[:, i]) 31 | correct_vec = (correct_vec | bool_mat[:, i]) 32 | # print(correct_vec) 33 | top_k_list.append(correct_vec[:, None]) 34 | top_k_mat = np.concatenate(top_k_list, axis=1) 35 | return top_k_mat 36 | 37 | 38 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 39 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 40 | argmax = np.argsort(dist_mat, axis=1) 41 | top_k_mat = calculate_top_k(argmax, top_k) 42 | if sum_all: 43 | return top_k_mat.sum(axis=0) 44 | else: 45 | return top_k_mat 46 | 47 | 48 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 49 | assert len(embedding1.shape) == 2 50 | assert embedding1.shape[0] == embedding2.shape[0] 51 | assert embedding1.shape[1] == embedding2.shape[1] 52 | 53 | dist = linalg.norm(embedding1 - embedding2, axis=1) 54 | if sum_all: 55 | return dist.sum(axis=0) 56 | else: 57 | return dist 58 | 59 | 60 | def calculate_activation_statistics(activations): 61 | """ 62 | Params: 63 | -- activation: num_samples x dim_feat 64 | Returns: 65 | -- mu: dim_feat 66 | -- sigma: dim_feat x 
dim_feat 67 | """ 68 | mu = np.mean(activations, axis=0) 69 | cov = np.cov(activations, rowvar=False) 70 | return mu, cov 71 | 72 | 73 | def calculate_diversity(activations, diversity_times): 74 | assert len(activations.shape) == 2 75 | assert activations.shape[0] > diversity_times 76 | num_samples = activations.shape[0] 77 | 78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 80 | dist = linalg.norm((activations[first_indices] - activations[second_indices])/2, axis=1) 81 | return dist.mean() 82 | 83 | 84 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 85 | """Numpy implementation of the Frechet Distance. 86 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 87 | and X_2 ~ N(mu_2, C_2) is 88 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 89 | Stable version by Dougal J. Sutherland. 90 | Params: 91 | -- mu1 : Numpy array containing the activations of a layer of the 92 | inception net (like returned by the function 'get_predictions') 93 | for generated samples. 94 | -- mu2 : The sample mean over activations, precalculated on an 95 | representative data set. 96 | -- sigma1: The covariance matrix over activations for generated samples. 97 | -- sigma2: The covariance matrix over activations, precalculated on an 98 | representative data set. 99 | Returns: 100 | -- : The Frechet Distance. 101 | """ 102 | 103 | mu1 = np.atleast_1d(mu1) 104 | mu2 = np.atleast_1d(mu2) 105 | 106 | sigma1 = np.atleast_2d(sigma1) 107 | sigma2 = np.atleast_2d(sigma2) 108 | 109 | assert mu1.shape == mu2.shape, \ 110 | 'Training and test mean vectors have different lengths' 111 | assert sigma1.shape == sigma2.shape, \ 112 | 'Training and test covariances have different dimensions' 113 | 114 | diff = mu1 - mu2 115 | 116 | # Product might be almost singular 117 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 118 | if not np.isfinite(covmean).all(): 119 | msg = ('fid calculation produces singular product; ' 120 | 'adding %s to diagonal of cov estimates') % eps 121 | print(msg) 122 | offset = np.eye(sigma1.shape[0]) * eps 123 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 124 | 125 | # Numerical error might give slight imaginary component 126 | if np.iscomplexobj(covmean): 127 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 128 | m = np.max(np.abs(covmean.imag)) 129 | print('Imaginary component {}'.format(m)) 130 | # raise ValueError('Imaginary component {}'.format(m)) 131 | covmean = covmean.real 132 | 133 | tr_covmean = np.trace(covmean) 134 | 135 | return (diff.dot(diff) + np.trace(sigma1) + 136 | np.trace(sigma2) - 2 * tr_covmean) 137 | 138 | 139 | def calculate_fid(gt_embeddings, pred_embeddings): 140 | gt_mu, gt_cov = calculate_activation_statistics(gt_embeddings) 141 | pred_mu, pred_cov = calculate_activation_statistics(pred_embeddings) 142 | fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov) 143 | return fid 144 | 145 | 146 | def calculate_multimodality(activation, multimodality_times): 147 | assert len(activation.shape) == 3 148 | assert activation.shape[1] > multimodality_times 149 | num_per_sent = activation.shape[1] 150 | 151 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 152 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 153 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], 
axis=2) 154 | return dist.mean() 155 | -------------------------------------------------------------------------------- /data_preprocessing/5-prepare_normalizer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | import tqdm 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | import argparse 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument( 18 | '--dataset', 19 | default='interx' 20 | ) 21 | 22 | parser.add_argument( 23 | '--motion_representations', 24 | default='joints3d_22,intergen_262' 25 | ) 26 | 27 | args = parser.parse_args() 28 | args.motion_representations = args.motion_representations.split(',') 29 | return args 30 | 31 | 32 | def process_single_dataset(args): 33 | data_dir = Path(f'~/data/data/motion/{args.dataset}').expanduser() 34 | save_dir = data_dir / 'normalizers' 35 | save_dir.mkdir(exist_ok=True) 36 | 37 | for motion_representation in args.motion_representations: 38 | src_dir = data_dir / motion_representation 39 | motion_paths = src_dir.glob('*') 40 | actions = [] 41 | nactions = [] 42 | reactions = [] 43 | all_motions = [] 44 | for mp in tqdm.tqdm(motion_paths): 45 | try: 46 | if mp.name.endswith('pkl'): 47 | with mp.open('rb') as f: 48 | motion = pickle.load(f) 49 | if not 32 <= len(motion['reaction']) <= 256: 50 | continue 51 | reactions.append(torch.from_numpy(motion['reaction'])) 52 | all_motions.append(torch.from_numpy(motion['reaction'])) 53 | if 'action' in motion.keys(): 54 | actions.append(torch.from_numpy(motion['action'])) 55 | nactions.append(torch.from_numpy(motion['naction'])) 56 | all_motions.append(torch.from_numpy(motion['naction'])) 57 | else: 58 | with mp.open('rb') as f: 59 | motion = np.load(f) 60 | reactions.append(torch.from_numpy(motion)) 61 | except Exception as e: 62 | print(e) 63 | 64 | reactions = torch.cat(reactions, dim=0) 65 | 66 | sd = { 67 | 'reaction': { 68 | 'mean': reactions.mean(0), 69 | 'std': reactions.std(0) 70 | } 71 | } 72 | if actions != []: 73 | actions = torch.cat(actions, dim=0) 74 | sd.update({ 75 | 'action': { 76 | 'mean': actions.mean(0), 77 | 'std': actions.std(0) 78 | } 79 | }) 80 | nactions = torch.cat(nactions, dim=0) 81 | sd.update({ 82 | 'naction': { 83 | 'mean': nactions.mean(0), 84 | 'std': nactions.std(0) 85 | } 86 | }) 87 | all_motions = torch.cat(all_motions, dim=0) 88 | sd.update({ 89 | 'all_motion': { 90 | 'mean': nactions.mean(0), 91 | 'std': nactions.std(0) 92 | } 93 | }) 94 | with (save_dir / f'{motion_representation}.pkl').open('wb') as f: 95 | pickle.dump(sd, f) 96 | 97 | 98 | def process_multi_datasets(args): 99 | root_dir = Path(f'~/data/data/motion').expanduser() 100 | save_dir = root_dir / 'normalizers' 101 | save_dir.mkdir(exist_ok=True) 102 | 103 | for motion_representation in args.motion_representations: 104 | actions = [] 105 | nactions = [] 106 | reactions = [] 107 | all_motions = [] 108 | for dataset in args.dataset.split(','): 109 | dataset_dir = root_dir / dataset 110 | motion_dir = dataset_dir / motion_representation 111 | motion_paths = list(motion_dir.glob('*')) 112 | for mp in tqdm.tqdm(motion_paths): 113 | try: 114 | if mp.name.endswith('pkl'): 115 | with mp.open('rb') as f: 116 | motion = pickle.load(f) 117 | reactions.append(torch.from_numpy(motion['reaction'])) 118 | all_motions.append(torch.from_numpy(motion['reaction'])) 119 | if 'action' in 
motion.keys(): 120 | actions.append(torch.from_numpy(motion['action'])) 121 | nactions.append(torch.from_numpy(motion['naction'])) 122 | all_motions.append(torch.from_numpy(motion['naction'])) 123 | else: 124 | with mp.open('rb') as f: 125 | motion = np.load(f) 126 | reactions.append(torch.from_numpy(motion)) 127 | except Exception as e: 128 | print(e) 129 | 130 | reactions = torch.cat(reactions, dim=0) 131 | all_motions = torch.cat(all_motions, dim=0) 132 | 133 | sd = { 134 | 'reaction': { 135 | 'mean': reactions.mean(0), 136 | 'std': reactions.std(0) 137 | }, 138 | 'all_motion': { 139 | 'mean': all_motions.mean(0), 140 | 'std': all_motions.std(0), 141 | } 142 | } 143 | if actions != []: 144 | actions = torch.cat(actions, dim=0) 145 | sd.update({ 146 | 'action': { 147 | 'mean': actions.mean(0), 148 | 'std': actions.std(0) 149 | } 150 | }) 151 | nactions = torch.cat(nactions, dim=0) 152 | sd.update({ 153 | 'naction': { 154 | 'mean': nactions.mean(0), 155 | 'std': nactions.std(0) 156 | } 157 | }) 158 | 159 | with (save_dir / f'{motion_representation}.pkl').open('wb') as f: 160 | pickle.dump(sd, f) 161 | 162 | 163 | if __name__ == '__main__': 164 | args = get_args() 165 | datasets = args.dataset.split(',') 166 | if len(datasets) == 1: 167 | process_single_dataset(args) 168 | else: 169 | process_multi_datasets(args) 170 | -------------------------------------------------------------------------------- /src/utils/plot.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import copy 3 | import numpy as np 4 | from pathlib import Path 5 | import matplotlib.pyplot as plt 6 | from matplotlib.animation import FuncAnimation 7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 8 | import pickle 9 | 10 | from src.utils.constants import JOINTS3D_22_KINEMATIC_CHAIN, EDGE22_INDICES_UNDIRCTIONAL 11 | from src.utils.motion_representation_converter import MotionRepresentationConverter 12 | 13 | 14 | optional_colors = ['r', 'g', 'b', 'c', 'm', 'y'] 15 | mrc = MotionRepresentationConverter() 16 | 17 | 18 | def visualize_all_pkl(src_dir: str, file_pattern: str = '*.pkl'): 19 | src_dir = Path(src_dir) 20 | for p in src_dir.glob(file_pattern): 21 | animate_from_pkl(p) 22 | 23 | 24 | def animate_from_pkl(file_path: str): 25 | with open(file_path, 'rb') as f: 26 | data_dict = pickle.load(f) 27 | 28 | shift = np.array([2, 0, 0]) 29 | motions = [] 30 | 31 | gt_action = data_dict.get('action', data_dict.get('gt_action', None)) 32 | if gt_action is not None: 33 | if len(gt_action.shape) == 2 and gt_action.shape[-1] == 262: 34 | gt_action = mrc('i262', 'j3d', gt_action) 35 | motions.append(gt_action) 36 | 37 | gt_reaction = data_dict.get('reaction', data_dict.get('gt_reaction', None)) 38 | if gt_reaction is not None: 39 | if len(gt_reaction.shape) == 2 and gt_reaction.shape[-1] == 262: 40 | gt_reaction = mrc('i262', 'j3d', gt_reaction) 41 | motions.append(gt_reaction) 42 | 43 | pred_reaction = data_dict.get('pred_reaction', None) 44 | if pred_reaction is not None: 45 | if len(pred_reaction.shape) == 2 and pred_reaction.shape[-1] == 262: 46 | pred_reaction = mrc('i262', 'j3d', pred_reaction) 47 | pred_action = gt_action + shift 48 | pred_reaction = pred_reaction + shift 49 | motions.append(pred_action) 50 | motions.append(pred_reaction) 51 | 52 | text = data_dict.get('caption', 'None') 53 | 54 | animate_multiple_joints3d_22( 55 | motions=motions, 56 | colors=optional_colors[:len(motions)], 57 | title=text, 58 | file_path=file_path.replace('pkl', 
'mp4') 59 | ) 60 | 61 | 62 | def animate_multiple_joints3d_22(motions, colors, title, file_path, fps=20, downsample_rate=4, show_axis=False): 63 | motions = copy.deepcopy(motions) 64 | for i, m in enumerate(motions): 65 | motions[i] = m[::downsample_rate, ...] 66 | if len(motions[i].shape) == 2 and motions[i].shape[1] == 262: 67 | motions[i] = mrc('i262', 'j3d', motions[i]) 68 | 69 | if isinstance(title, str): 70 | words = title.split(' ') 71 | else: 72 | words = [] 73 | title = '' 74 | for i, word in enumerate(words): 75 | if i % 10 == 0 and i > 0: 76 | title += '\n' 77 | title += word + ' ' 78 | 79 | fig = plt.figure() 80 | ax = fig.add_subplot(111, projection='3d') 81 | # Clear the axis before drawing new frame 82 | def init(): 83 | ax.set_xlim(-2, 2) 84 | ax.set_ylim(0, 2) 85 | ax.set_zlim(-2, 2) 86 | if not show_axis: 87 | ax.set_axis_off() 88 | fig.suptitle(title, fontsize=10) 89 | return [] 90 | 91 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 92 | ## Plot a plane XZ 93 | verts = [ 94 | [minx, miny, minz], 95 | [minx, miny, maxz], 96 | [maxx, miny, maxz], 97 | [maxx, miny, minz] 98 | ] 99 | xz_plane = Poly3DCollection([verts]) 100 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 101 | ax.add_collection3d(xz_plane) 102 | 103 | # Update function for animation 104 | def update(frame): 105 | ax.clear() 106 | ax.set_xlim(-2, 2) 107 | ax.set_ylim(0, 2) 108 | ax.set_zlim(-2, 2) 109 | if not show_axis: 110 | ax.set_axis_off() 111 | # ax.view_init(elev=120, azim=-30, roll=90, vertical_axis='y') 112 | ax.view_init(vertical_axis='y') 113 | plot_xzPlane(-2, 2, 0, -2, 2) 114 | 115 | # Draw reaction skeleton 116 | for m, c in zip(motions, colors): 117 | plot_skeleton(ax, m[frame], color=c) 118 | 119 | return ax, 120 | 121 | # Function to plot a single skeleton 122 | def plot_skeleton(ax, joints, color): 123 | for chain in JOINTS3D_22_KINEMATIC_CHAIN: 124 | for i in range(len(chain) - 1): 125 | ax.plot([joints[chain[i]][0], joints[chain[i + 1]][0]], 126 | [joints[chain[i]][1], joints[chain[i + 1]][1]], 127 | [joints[chain[i]][2], joints[chain[i + 1]][2]], 128 | '-k', lw=2) # Plot lines between joints 129 | 130 | for joint in joints: 131 | ax.scatter(joint[0], joint[1], joint[2], c=color, s=30) # Plot joints 132 | 133 | # Create the animation 134 | ani = FuncAnimation(fig, update, frames=np.arange(0, motions[0].shape[0]), init_func=init, blit=False) 135 | 136 | ani.save(file_path, fps=fps // downsample_rate) 137 | plt.close() 138 | 139 | 140 | def visualize_3d_skeleton(joints: np.ndarray, save_path): 141 | """ 142 | Visualizes a 3D skeleton. 143 | 144 | Parameters: 145 | - joints: A numpy array of shape (22, 3) representing the XYZ coordinates of 22 joints. 146 | - edge_indices: A numpy array of shape (n_edges, 2) representing the indices of connected joints. 
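    Note: in the current implementation the edge connectivity comes from the module-level
    EDGE22_INDICES_UNDIRCTIONAL constant rather than from an argument.

    Example (illustrative only; the joint positions are random placeholders):
        joints = np.random.rand(22, 3)          # XYZ coordinates of the 22 joints
        visualize_3d_skeleton(joints, 'skeleton.png')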
147 | """ 148 | # Create a new matplotlib figure and 3D axis 149 | fig = plt.figure() 150 | ax = fig.add_subplot(111, projection='3d') 151 | 152 | # Plot the joints as points 153 | for i in range(joints.shape[0]): 154 | ax.scatter(joints[i, 0], joints[i, 1], joints[i, 2], color='r', s=30) 155 | 156 | # Plot the edges between joints 157 | for edge in EDGE22_INDICES_UNDIRCTIONAL: 158 | start_joint = joints[edge[0]] 159 | end_joint = joints[edge[1]] 160 | ax.plot([start_joint[0], end_joint[0]], [start_joint[1], end_joint[1]], [start_joint[2], end_joint[2]], color='b') 161 | 162 | # Set labels for the axes 163 | ax.set_xlabel('X') 164 | ax.set_ylabel('Y') 165 | ax.set_zlabel('Z') 166 | 167 | # Show the plot 168 | fig.savefig(save_path) 169 | plt.close() 170 | -------------------------------------------------------------------------------- /third_party/HumanML3D/human_body_prior/tools/omni_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.01.02 23 | import numpy as np 24 | import random 25 | import torch 26 | import os 27 | import sys 28 | import os.path as osp 29 | 30 | def copy2cpu(tensor): 31 | if isinstance(tensor, np.ndarray): return tensor 32 | return tensor.detach().cpu().numpy() 33 | 34 | def create_list_chunks(list_, group_size, overlap_size, cut_smaller_batches=True): 35 | if cut_smaller_batches: 36 | return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size) if len(list_[i:i + group_size])==group_size] 37 | else: 38 | return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size)] 39 | 40 | 41 | def trainable_params_count(params): 42 | return sum([p.numel() for p in params if p.requires_grad]) 43 | 44 | def flatten_list(l): 45 | return [item for sublist in l for item in sublist] 46 | 47 | def get_support_data_dir(current_fname=__file__): 48 | support_data_dir = osp.abspath(current_fname) 49 | support_data_dir_split = support_data_dir.split('/') 50 | support_data_dir = '/'.join(support_data_dir_split[:support_data_dir_split.index('src')]) 51 | support_data_dir = osp.join(support_data_dir, 'support_data') 52 | assert osp.exists(support_data_dir) 53 | return support_data_dir 54 | 55 | def make_deterministic(seed): 56 | random.seed(seed) 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | np.random.seed(seed) 60 | os.environ['PYTHONHASHSEED'] = str(seed) 61 | torch.backends.cudnn.deterministic = True 62 | torch.backends.cudnn.benchmark = 
False 63 | 64 | def id_generator(size=13): 65 | import string 66 | import random 67 | chars = string.ascii_uppercase + string.digits 68 | return ''.join(random.choice(chars) for _ in range(size)) 69 | 70 | def logger_sequencer(logger_list, prefix=None): 71 | def post_text(text): 72 | if prefix is not None: text = '{} -- '.format(prefix) + text 73 | for logger_call in logger_list: logger_call(text) 74 | return post_text 75 | 76 | class log2file(): 77 | def __init__(self,logpath=None, prefix='', auto_newline = True, write2file_only=False): 78 | if logpath is not None: 79 | makepath(logpath, isfile=True) 80 | self.fhandle = open(logpath,'a+') 81 | else: 82 | self.fhandle = None 83 | 84 | self.prefix = prefix 85 | self.auto_newline = auto_newline 86 | self.write2file_only = write2file_only 87 | 88 | def __call__(self, text): 89 | if text is None: return 90 | if self.prefix != '': text = '{} -- '.format(self.prefix) + text 91 | # breakpoint() 92 | if self.auto_newline: 93 | if not text.endswith('\n'): 94 | text = text + '\n' 95 | if not self.write2file_only: sys.stderr.write(text) 96 | if self.fhandle is not None: 97 | self.fhandle.write(text) 98 | self.fhandle.flush() 99 | 100 | 101 | def makepath(*args, **kwargs): 102 | ''' 103 | if the path does not exist make it 104 | :param desired_path: can be path to a file or a folder name 105 | :return: 106 | ''' 107 | isfile = kwargs.get('isfile', False) 108 | import os 109 | desired_path = os.path.join(*args) 110 | if isfile: 111 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 112 | else: 113 | if not os.path.exists(desired_path): os.makedirs(desired_path) 114 | return desired_path 115 | 116 | def matrot2axisangle(matrots): 117 | ''' 118 | :param matrots: N*T*num_joints*9 119 | :return: N*T*num_joints*3 120 | ''' 121 | import cv2 122 | N = matrots.shape[0] 123 | T = matrots.shape[1] 124 | n_joints = matrots.shape[2] 125 | out_axisangle = [] 126 | for tIdx in range(T): 127 | T_axisangle = [] 128 | for mIdx in range(N): 129 | cur_axisangle = [] 130 | for jIdx in range(n_joints): 131 | cur_axisangle.append(cv2.Rodrigues(matrots[mIdx, tIdx, jIdx:jIdx + 1, :].reshape(3, 3))[0].T) 132 | T_axisangle.append(np.vstack(cur_axisangle)[np.newaxis]) 133 | out_axisangle.append(np.vstack(T_axisangle).reshape([N,1, -1,3])) 134 | return np.concatenate(out_axisangle, axis=1) 135 | 136 | def axisangle2matrots(axisangle): 137 | ''' 138 | :param matrots: N*1*num_joints*3 139 | :return: N*num_joints*9 140 | ''' 141 | import cv2 142 | batch_size = axisangle.shape[0] 143 | axisangle = axisangle.reshape([batch_size,1,-1,3]) 144 | out_matrot = [] 145 | for mIdx in range(axisangle.shape[0]): 146 | cur_axisangle = [] 147 | for jIdx in range(axisangle.shape[2]): 148 | a = cv2.Rodrigues(axisangle[mIdx, 0, jIdx:jIdx + 1, :].reshape(1, 3))[0].T 149 | cur_axisangle.append(a) 150 | 151 | out_matrot.append(np.array(cur_axisangle).reshape([batch_size,1,-1,9])) 152 | return np.vstack(out_matrot) 153 | 154 | 155 | def apply_mesh_tranfsormations_(meshes, transf): 156 | ''' 157 | apply inplace translations to meshes 158 | :param meshes: list of trimesh meshes 159 | :param transf: 160 | :return: 161 | ''' 162 | for i in range(len(meshes)): 163 | meshes[i] = meshes[i].apply_transform(transf) -------------------------------------------------------------------------------- /data_preprocessing/6-prepare_tokens.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | 
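# The statistics pickled by this script (per-dimension mean/std under keys such as
# 'reaction', 'action', 'naction', 'all_motion' and 'egocentric_motion') are consumed by
# TorchNormalizer in src/utils/normalizer.py. A minimal sketch of the assumed z-score
# usage -- the wrapper's exact API is not shown in this excerpt, so treat the eps guard
# and call shape as assumptions rather than the project's verbatim code:
#
#     import pickle
#     import torch
#     with open('normalizers/intergen_262.pkl', 'rb') as f:
#         stats = pickle.load(f)
#     mean, std = stats['all_motion']['mean'], stats['all_motion']['std']
#     motion = torch.randn(128, 262)                 # placeholder (T, 262) sequence
#     normalized = (motion - mean) / (std + 1e-8)
#     restored = normalized * (std + 1e-8) + mean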
import tqdm 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | import argparse 9 | 10 | sys.path.append(os.getcwd()) 11 | sys.path.append(os.getcwd() + '/../') 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument( 18 | '--dataset', 19 | default='interx' 20 | ) 21 | 22 | parser.add_argument( 23 | '--motion_representations', 24 | default='intergen_262' 25 | ) 26 | 27 | args, _ = parser.parse_known_args() 28 | args.motion_representations = args.motion_representations.split(',') 29 | return args 30 | 31 | 32 | def process_single_dataset(args): 33 | data_dir = Path(f'~/data/data/motion/{args.dataset}').expanduser() 34 | save_dir = data_dir / 'normalizers' 35 | save_dir.mkdir(exist_ok=True) 36 | 37 | for motion_representation in args.motion_representations: 38 | src_dir = data_dir / motion_representation 39 | motion_paths = src_dir.glob('*') 40 | actions = [] 41 | nactions = [] 42 | reactions = [] 43 | for mp in tqdm.tqdm(motion_paths): 44 | try: 45 | if mp.name.endswith('pkl'): 46 | with mp.open('rb') as f: 47 | motion = pickle.load(f) 48 | if not 32 <= len(motion['reaction']) <= 256: 49 | continue 50 | reactions.append(torch.from_numpy(motion['reaction'])) 51 | if 'action' in motion.keys(): 52 | actions.append(torch.from_numpy(motion['action'])) 53 | nactions.append(torch.from_numpy(motion['naction'])) 54 | else: 55 | with mp.open('rb') as f: 56 | motion = np.load(f) 57 | reactions.append(torch.from_numpy(motion)) 58 | except Exception as e: 59 | print(e) 60 | 61 | reactions = torch.cat(reactions, dim=0) 62 | 63 | sd = { 64 | 'reaction': { 65 | 'mean': reactions.mean(0), 66 | 'std': reactions.std(0) 67 | } 68 | } 69 | if actions != []: 70 | actions = torch.cat(actions, dim=0) 71 | sd.update({ 72 | 'action': { 73 | 'mean': actions.mean(0), 74 | 'std': actions.std(0) 75 | } 76 | }) 77 | nactions = torch.cat(nactions, dim=0) 78 | sd.update({ 79 | 'naction': { 80 | 'mean': nactions.mean(0), 81 | 'std': nactions.std(0) 82 | } 83 | }) 84 | all_motion = torch.cat([actions, reactions], dim=0) 85 | sd.update({ 86 | 'all_motion': { 87 | 'mean': all_motion.mean(0), 88 | 'std': all_motion.std(0) 89 | } 90 | }) 91 | egocentric_motion = torch.cat([nactions, reactions], dim=0) 92 | sd.update({ 93 | 'egocentric_motion': { 94 | 'mean': egocentric_motion.mean(0), 95 | 'std': egocentric_motion.std(0) 96 | } 97 | }) 98 | with (save_dir / f'{motion_representation}.pkl').open('wb') as f: 99 | pickle.dump(sd, f) 100 | 101 | 102 | def process_multi_datasets(args): 103 | root_dir = Path(f'~/data/data/motion').expanduser() 104 | save_dir = root_dir / 'normalizers' 105 | save_dir.mkdir(exist_ok=True) 106 | 107 | for motion_representation in args.motion_representations: 108 | actions = [] 109 | nactions = [] 110 | reactions = [] 111 | all_motions = [] 112 | for dataset in args.dataset.split(','): 113 | dataset_dir = root_dir / dataset 114 | motion_dir = dataset_dir / motion_representation 115 | motion_paths = list(motion_dir.glob('*')) 116 | for mp in tqdm.tqdm(motion_paths): 117 | try: 118 | if mp.name.endswith('pkl'): 119 | with mp.open('rb') as f: 120 | motion = pickle.load(f) 121 | reactions.append(torch.from_numpy(motion['reaction'])) 122 | all_motions.append(torch.from_numpy(motion['reaction'])) 123 | if 'action' in motion.keys(): 124 | actions.append(torch.from_numpy(motion['action'])) 125 | nactions.append(torch.from_numpy(motion['naction'])) 126 | all_motions.append(torch.from_numpy(motion['naction'])) 127 | else: 128 | with mp.open('rb') as 
f: 129 | motion = np.load(f) 130 | reactions.append(torch.from_numpy(motion)) 131 | except Exception as e: 132 | print(e) 133 | 134 | reactions = torch.cat(reactions, dim=0) 135 | all_motions = torch.cat(all_motions, dim=0) 136 | 137 | sd = { 138 | 'reaction': { 139 | 'mean': reactions.mean(0), 140 | 'std': reactions.std(0) 141 | }, 142 | 'all_motion': { 143 | 'mean': all_motions.mean(0), 144 | 'std': all_motions.std(0), 145 | } 146 | } 147 | if actions != []: 148 | actions = torch.cat(actions, dim=0) 149 | sd.update({ 150 | 'action': { 151 | 'mean': actions.mean(0), 152 | 'std': actions.std(0) 153 | } 154 | }) 155 | nactions = torch.cat(nactions, dim=0) 156 | sd.update({ 157 | 'naction': { 158 | 'mean': nactions.mean(0), 159 | 'std': nactions.std(0) 160 | } 161 | }) 162 | 163 | with (save_dir / f'{motion_representation}.pkl').open('wb') as f: 164 | pickle.dump(sd, f) 165 | 166 | 167 | if __name__ == '__main__': 168 | args = get_args() 169 | datasets = args.dataset.split(',') 170 | if len(datasets) == 1: 171 | process_single_dataset(args) 172 | else: 173 | process_multi_datasets(args) 174 | -------------------------------------------------------------------------------- /data_preprocessing/check_prompt_with_llm.py: -------------------------------------------------------------------------------- 1 | #%% prepare LLMs 2 | import transformers 3 | import torch 4 | 5 | 6 | class LLM: 7 | insruction = \ 8 | """ 9 | You are tasked with analyzing a narrative that describes the interactions between two individuals through their movements. Your goal is to identify whether their the two persons are doing intense exercise like boxing, fencing or fighting. 10 | Respond "Yes." or "No." only without any other text. 11 | """ 12 | 13 | query_template = \ 14 | """ 15 | Here's the set of descriptions of the interaction: 16 | {} 17 | """ 18 | 19 | example_description_set_1 = \ 20 | """ 21 | the two guys grip swords with the right hand. one strikes to the left twice, and the other moves the sword to the right twice. 22 | two humans grip the swords in their right hand. the first one lunges twice to the left with the sword while the second one lunges twice to the right with their sword. 23 | two performers wield swords in their right hands, while the first person swipes the sword twice to the left, the second slashes two times in the opposite direction. 24 | """ 25 | 26 | example_response_1 = \ 27 | """ 28 | Yes. 29 | """ 30 | 31 | example_description_set_2 = \ 32 | """ 33 | one squats down and picks up an object from the ground, as the other approaches with head down. 34 | the first person bends down and picks up an item from the floor with both hands, while the second lowers their head and walks towards the first person. 35 | one person crouches down and picks up an item from the ground with both hands, while the other approaches and lowers their head towards the first. 36 | """ 37 | 38 | example_response_2 = \ 39 | """ 40 | No. 
41 | """ 42 | 43 | 44 | def __init__(self, device, model_dir): 45 | self.pipeline = transformers.pipeline( 46 | "text-generation", 47 | model=model_dir, 48 | model_kwargs={"torch_dtype": torch.bfloat16}, 49 | device=device, 50 | use_fast=False 51 | ) 52 | 53 | def preprocess_lines(self, lines): 54 | res = [] 55 | for line in lines: 56 | res.append(line.replace('his/her', 'his').replace('him/her', 'him').replace('he/she', 'he')) 57 | return res 58 | 59 | @torch.no_grad() 60 | def one_round_qa(self, lines): 61 | lines = self.preprocess_lines(lines) 62 | description_set = '' 63 | for line in lines: 64 | description_set += line.strip() + '\n' 65 | messages = [ 66 | {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"}, 67 | {"role": "user", "content": self.insruction + self.query_template.format(self.example_description_set_1)}, 68 | {"role": "assistant", "content": self.example_response_1}, 69 | {"role": "user", "content": self.query_template.format(self.example_description_set_2)}, 70 | {"role": "assistant", "content": self.example_response_2}, 71 | {"role": "user", "content": self.query_template.format(description_set)} 72 | ] 73 | 74 | prompt = self.pipeline.tokenizer.apply_chat_template( 75 | messages, 76 | tokenize=False, 77 | add_generation_prompt=True 78 | ) 79 | 80 | terminators = [ 81 | self.pipeline.tokenizer.eos_token_id, 82 | self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") 83 | ] 84 | 85 | outputs = self.pipeline( 86 | prompt, 87 | max_new_tokens=256, 88 | eos_token_id=terminators, 89 | do_sample=True, 90 | temperature=0.1, 91 | top_p=0.9, 92 | pad_token_id=self.pipeline.tokenizer.eos_token_id 93 | ) 94 | response = outputs[0]["generated_text"][len(prompt):].strip().split('\n') 95 | 96 | assert response[0].startswith('[Initiator]') 97 | assert response[1].startswith('[Receiver]') 98 | return { 99 | 'action': [response[0][len('[Initiator]'):].strip(' \n.')], 100 | 'reaction': [response[1][len('[Receiver]'):].strip(' \n.')] 101 | } 102 | 103 | #%% Prepare src data 104 | from pathlib import Path 105 | 106 | data_root_dir = Path('~/data/data/motion/interhuman').expanduser() 107 | src_txt_path_list = (data_root_dir / 'texts').glob('*.txt') 108 | tgt_dir = data_root_dir / 'annots' / 'short_thinkings' 109 | tgt_dir.mkdir(exist_ok=True, parents=True) 110 | 111 | name_lines = [] 112 | for src_txt_path in src_txt_path_list: 113 | with src_txt_path.open('r') as f: 114 | 115 | try: 116 | lines = f.readlines() 117 | except: 118 | continue 119 | name_lines.append((src_txt_path.stem, lines)) 120 | 121 | #%% Start parallel processing 122 | import os 123 | import torch 124 | import json 125 | from concurrent.futures import ProcessPoolExecutor as PPE 126 | 127 | MODEL_DIR = os.path.expanduser('~/data/pretrained_models/llm/Meta-Llama-3-8B-Instruct') 128 | 129 | 130 | def single_process(device_idx, name_lines_chunk, model_dir=MODEL_DIR): 131 | llm = LLM(device=device_idx, model_dir=model_dir) 132 | res = {} 133 | for i, (name, lines) in enumerate(name_lines_chunk): 134 | if i % 100 == 0: 135 | print(i) 136 | try: 137 | result = llm.one_round_qa(lines) 138 | result['interaction'] = lines 139 | except Exception as e: 140 | print(e) 141 | else: 142 | res[name] = result 143 | return res 144 | 145 | devices = [torch.device(f'cuda:{i}') for i in '1,2,5,6,7'.split(',')] 146 | n_devices = len(devices) 147 | 148 | name_lines_chunks = [ 149 | name_lines[i: : n_devices] for i in range(n_devices) 150 | ] 151 | 152 | # with PPE(max_workers=n_devices) as ppe: 
153 | # list(ppe.map(single_process, devices, name_lines_chunks)) 154 | single_process(devices[0], name_lines_chunks[0]) 155 | 156 | #%% check split data 157 | if False: 158 | import json 159 | import random 160 | from pathlib import Path 161 | 162 | data_root_dir = Path('~/data/data/motion/Inter-X_Dataset').expanduser() 163 | src_txt_path_list = list((data_root_dir / 'texts').glob('*.txt')) 164 | random.shuffle(src_txt_path_list) 165 | tgt_dir = data_root_dir / 'texts_action_reaction' 166 | 167 | for src_txt_path in src_txt_path_list[:10]: 168 | stem = src_txt_path.stem 169 | with src_txt_path.open('r') as f: 170 | src_lines = f.readlines() 171 | with (tgt_dir / f'{stem}.json').open('r') as f: 172 | tgt_lines = json.load(f) 173 | 174 | for src, tgt in zip(src_lines, tgt_lines): 175 | print(f'{src.strip()}\n{tgt}\n') 176 | print('-----------------------------------') 177 | # %% 178 | -------------------------------------------------------------------------------- /data_preprocessing/0-smpl_to_joints3d_22.py: -------------------------------------------------------------------------------- 1 | #%% smpl to joints3d_22 2 | from typing import List 3 | import sys 4 | import copy 5 | import os 6 | import pickle 7 | import numpy as np 8 | from pathlib import Path 9 | import torch 10 | import tqdm 11 | from concurrent.futures import ProcessPoolExecutor as PPE 12 | import argparse 13 | 14 | sys.path.append(os.getcwd()) 15 | sys.path.append(os.getcwd() + '/../') 16 | from third_party.HumanML3D.human_body_prior.body_model.body_model import BodyModel 17 | from src.utils.motion_representation_converter import MotionRepresentationConverter 18 | from data_preprocessing.utils import normalize_dual_joints3d_22, normalize_single_joints3d_22 19 | 20 | 21 | args = None 22 | trans_matrix = np.array([[1.0, 0.0, 0.0], 23 | [0.0, 0.0, -1.0], 24 | [0.0, 1.0, 0.0]]) 25 | mrc = MotionRepresentationConverter() 26 | 27 | #%% 28 | args = argparse.ArgumentParser() 29 | args.add_argument( 30 | '--dataset', 31 | default='interx', 32 | ) 33 | 34 | args.add_argument( 35 | '--devices', 36 | default='1' 37 | ) 38 | args = args.parse_args() 39 | 40 | #%% 41 | dataset = 'interx' if args is None else args.dataset 42 | devices = '1' if args is None else args.devices 43 | devices = [torch.device(f'cuda:{i}') for i in devices.split(',')] 44 | 45 | 46 | def single_process_smpl_to_joint3d_22(smpl_paths: List[Path], pose_save_dir, body_models_path, interaction_order, device, fps_downsample_rate, length_range, n_joints=22): 47 | def get_pose_seq_np(person, down_sample_rate): 48 | data = { 49 | 'trans': torch.Tensor(person['trans'][::down_sample_rate, ...]).to(device), 50 | 'pose_body': torch.Tensor(person['pose_body'][::down_sample_rate, ...]).view(-1, 21 * 3).to(device), 51 | 'root_orient': torch.Tensor(person['root_orient'][::down_sample_rate, ...]).to(device), 52 | } 53 | pose_seq_np = body_model(**data).Jtr.detach().cpu().numpy() 54 | if dataset == 'interhuman': 55 | pose_seq_np = np.dot(pose_seq_np, trans_matrix) 56 | return pose_seq_np 57 | 58 | body_model = BodyModel(bm_fname=body_models_path).to(device) 59 | 60 | for dual_smpl_path in tqdm.tqdm(smpl_paths): 61 | file_id = dual_smpl_path.stem 62 | # if (pose_save_dir / f'{file_id}.pkl').exists(): 63 | # continue 64 | try: 65 | # interx 66 | if dual_smpl_path.is_dir(): 67 | path_1 = dual_smpl_path / 'P1.npz' 68 | path_2 = dual_smpl_path / 'P2.npz' 69 | with path_1.open('rb') as f1, path_2.open('rb') as f2: 70 | person1 = np.load(f1) 71 | person2 = np.load(f2) 72 | n_frames = 
len(person1['pose_body']) 73 | if n_frames < length_range[0] or n_frames > length_range[1]: 74 | continue 75 | if interaction_order[file_id] == 0: 76 | person1, person2 = person2, person1 77 | 78 | action = get_pose_seq_np(person1, down_sample_rate=fps_downsample_rate)[:, :n_joints, :] 79 | reaction = get_pose_seq_np(person2, down_sample_rate=fps_downsample_rate)[:, :n_joints, :] 80 | action, reaction = mrc.norm_dual_joints3d_22(action, reaction) 81 | naction, (x, z, r) = mrc.norm_joint3d_22(action) 82 | pose_data = { 83 | 'naction': naction, 84 | 'action': action, 85 | 'reaction': reaction, 86 | 'action_x': x, 87 | 'action_z': z, 88 | 'action_r': r, 89 | } 90 | else: 91 | # interhuman 92 | with dual_smpl_path.open('rb') as f: 93 | dual_smpl = pickle.load(f) 94 | n_frames = dual_smpl['frames'] 95 | if n_frames < length_range[0] or n_frames > length_range[1]: 96 | continue 97 | person1 = dual_smpl['person1'] 98 | person2 = dual_smpl['person2'] 99 | if interaction_order[file_id] == 0: 100 | person1, person2 = person2, person1 101 | 102 | action = get_pose_seq_np(person1, down_sample_rate=fps_downsample_rate)[:, :n_joints, :] 103 | reaction = get_pose_seq_np(person2, down_sample_rate=fps_downsample_rate)[:, :n_joints, :] 104 | action, reaction = normalize_dual_joints3d_22(action, reaction) 105 | naction = normalize_single_joints3d_22(action) 106 | pose_data = { 107 | 'naction': naction, 108 | 'action': action, 109 | 'reaction': reaction 110 | } 111 | 112 | with (pose_save_dir / f'{file_id}.pkl').open('wb') as f: 113 | pickle.dump(pose_data, f) 114 | except Exception as e: 115 | print(f'{dual_smpl_path}: {e}') 116 | 117 | 118 | if __name__ == '__main__': 119 | data_root_dir = Path(f'~/data/data/motion/{dataset}').expanduser() 120 | smpl_dir = data_root_dir / 'motions' 121 | smpl_paths = [p for p in smpl_dir.glob('*') if p.is_dir()] 122 | if smpl_paths == []: 123 | smpl_paths = [p for p in smpl_dir.glob('*.pkl')] 124 | 125 | pose_save_dir = data_root_dir / 'joints3d_22' 126 | pose_save_dir.mkdir(exist_ok=True) 127 | 128 | n_proc = len(devices) 129 | 130 | src_fps = 60 if dataset.lower() == 'interhuman' else 120 # interx 131 | tgt_fps = 20 132 | fps_downsample_rate = src_fps // tgt_fps 133 | length_range = [32 * fps_downsample_rate, 256 * fps_downsample_rate] 134 | 135 | body_models_path = os.path.expanduser('~/data/pretrained_models/motion/body_models/smplx/SMPLX_NEUTRAL.npz') 136 | 137 | try: 138 | with (data_root_dir / 'annots' / 'interaction_order.pkl').open('rb') as f: 139 | interaction_order = pickle.load(f) 140 | except: 141 | interaction_order = None 142 | 143 | smpl_path_chunks = [ 144 | smpl_paths[i::n_proc] for i in range(n_proc) 145 | ] 146 | single_process_smpl_to_joint3d_22(smpl_paths, pose_save_dir, body_models_path, interaction_order, devices[0], fps_downsample_rate, length_range) 147 | with PPE(max_workers=n_proc) as ppe: 148 | list(ppe.map( 149 | single_process_smpl_to_joint3d_22, 150 | smpl_path_chunks, 151 | [pose_save_dir] * n_proc, 152 | [body_models_path] * n_proc, 153 | [interaction_order] * n_proc, 154 | devices, 155 | [fps_downsample_rate] * n_proc, 156 | [length_range] * n_proc 157 | )) 158 | 159 | # %% 160 | -------------------------------------------------------------------------------- /data_preprocessing/split_interaction_caption_with_llm.py: -------------------------------------------------------------------------------- 1 | #%% prepare LLMs 2 | import transformers 3 | import torch 4 | 5 | 6 | class LLM: 7 | insruction = \ 8 | """ 9 | You are tasked with analyzing 
a narrative that describes the interactions between two individuals through their movements. Your goal is to identify the initiator and the receiver of the motion and to provide separate, distinct descriptions for each person's actions. 10 | Please adhere to the following response format: 11 | "[Initiator] The person ... 12 | [Receiver] The person ..." 13 | Key Guidelines: 14 | - Refrain from using "first/second person" in your descriptions. Instead, exclusively use "the person" and "another person" when referring to the individuals involved. 15 | - Each description should start with "The person." 16 | - Ensure that you capture the entirety of each person's motion, including all actions in the order they occur within the interaction. 17 | - Strictly follow the response template, and deliver precise and formal captions without any extra words. 18 | - Limit your response to a single line. 19 | Tip: Utilize the active voice for the initiator's actions and the passive voice for the receiver's reactions when appropriate to clearly convey the dynamics of the interaction. 20 | """ 21 | 22 | query_template = \ 23 | """ 24 | Here's the set of descriptions of the interaction: 25 | {} 26 | """ 27 | 28 | example_description_set = \ 29 | """ 30 | One individual stands with arms crossed while another massages his right leg, and the first person softly pats the other's right arm. 31 | One person stands with his arms crossed while the other person massages his right leg, and the first person gently pats the other person's right arm. 32 | Two people stand facing each other, one person bends over to massage the other person's thighs, while the other person taps his shoulder. 33 | """ 34 | # """ 35 | # The first person walks forward and the second person blocks him by crossing their arms in front of their chest 36 | # The first person crosses his/her arms in front of his/her chest. The second person walks towards the first person, and the first person blocks the second person's chest with his/her hands 37 | # Two people face each other, one person walks forward, and the other person crosses his/her hands in front of his/her chest to block, then the first person stops 38 | # """ 39 | 40 | example_response = \ 41 | """ 42 | [Initiator] The person bends over to massage another person's right leg. 43 | [Receiver] The person stands with his arms crossed and is being massaged, and softly pats another person's right arm. 44 | """ 45 | # """ 46 | # [Initiator] The person is walking forward. [Receiver] The person is crossing their arms in front of their chest to block another person. 47 | # [Initiator] The person walks towards another person. [Receiver] The person blocks the another person's chest with his/her hands wit his/her arms in front of his/her chest. 48 | # [Initiator] The person walks forward [Receiver] The person crosses his/her hands in front of his/her chest to block another person. 
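#
# A minimal usage sketch for this class (illustrative only: the device index and the
# captions are placeholders, and MODEL_DIR is defined further down in this script):
#
#     llm = LLM(device=0, model_dir=MODEL_DIR)
#     out = llm.one_round_qa([
#         'one person pushes the other, who stumbles backwards.',
#         'the first person shoves another person and the second person staggers back.',
#     ])
#     # -> {'action': ['The person ...'], 'reaction': ['The person ...']}
#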
49 | # """ 50 | 51 | def __init__(self, device, model_dir): 52 | self.pipeline = transformers.pipeline( 53 | "text-generation", 54 | model=model_dir, 55 | model_kwargs={"torch_dtype": torch.bfloat16}, 56 | device=device, 57 | use_fast=False 58 | ) 59 | 60 | def preprocess_lines(self, lines): 61 | res = [] 62 | for line in lines: 63 | res.append(line.replace('his/her', 'his').replace('him/her', 'him').replace('he/she', 'he')) 64 | return res 65 | 66 | @torch.no_grad() 67 | def one_round_qa(self, lines): 68 | lines = self.preprocess_lines(lines) 69 | description_set = '' 70 | for line in lines: 71 | description_set += line.strip() + '\n' 72 | messages = [ 73 | {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"}, 74 | {"role": "user", "content": self.insruction + self.query_template.format(self.example_description_set)}, 75 | {"role": "assistant", "content": self.example_response}, 76 | {"role": "user", "content": self.query_template.format(description_set)} 77 | ] 78 | 79 | prompt = self.pipeline.tokenizer.apply_chat_template( 80 | messages, 81 | tokenize=False, 82 | add_generation_prompt=True 83 | ) 84 | 85 | terminators = [ 86 | self.pipeline.tokenizer.eos_token_id, 87 | self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") 88 | ] 89 | 90 | outputs = self.pipeline( 91 | prompt, 92 | max_new_tokens=256, 93 | eos_token_id=terminators, 94 | do_sample=True, 95 | temperature=0.1, 96 | top_p=0.9, 97 | pad_token_id=self.pipeline.tokenizer.eos_token_id 98 | ) 99 | response = outputs[0]["generated_text"][len(prompt):].strip().split('\n') 100 | 101 | assert response[0].startswith('[Initiator]') 102 | assert response[1].startswith('[Receiver]') 103 | return { 104 | 'action': [response[0][len('[Initiator]'):].strip(' \n.')], 105 | 'reaction': [response[1][len('[Receiver]'):].strip(' \n.')] 106 | } 107 | 108 | #%% Prepare src data 109 | from pathlib import Path 110 | 111 | data_root_dir = Path('~/data/data/motion/interx').expanduser() 112 | src_txt_path_list = (data_root_dir / 'texts').glob('*.txt') 113 | tgt_dir = data_root_dir / 'texts_all' 114 | tgt_dir.mkdir(exist_ok=True) 115 | 116 | name_lines = [] 117 | for src_txt_path in src_txt_path_list: 118 | with src_txt_path.open('r') as f: 119 | lines = f.readlines() 120 | name_lines.append((src_txt_path.stem, lines)) 121 | 122 | #%% Start parallel processing 123 | import os 124 | import torch 125 | import json 126 | from concurrent.futures import ProcessPoolExecutor as PPE 127 | 128 | MODEL_DIR = os.path.expanduser('~/data/pretrained_models/llm/Meta-Llama-3-8B-Instruct') 129 | 130 | 131 | def single_process(device_idx, name_lines_chunk, save_dir=tgt_dir, model_dir=MODEL_DIR): 132 | llm = LLM(device=device_idx, model_dir=model_dir) 133 | for i, (name, lines) in enumerate(name_lines_chunk): 134 | if i % 100 == 0: 135 | print(i) 136 | try: 137 | result = llm.one_round_qa(lines) 138 | result['interaction'] = lines 139 | except Exception as e: 140 | print(e) 141 | else: 142 | with (save_dir / f'{name}.json').open('w') as f: 143 | json.dump(result, f) 144 | 145 | n_devices = torch.cuda.device_count() 146 | 147 | device_list = list(range(n_devices)) 148 | name_lines_chunks = [ 149 | name_lines[i: : n_devices] for i in range(n_devices) 150 | ] 151 | 152 | with PPE(max_workers=n_devices) as ppe: 153 | list(ppe.map(single_process, device_list, name_lines_chunks)) 154 | # single_process(device_list[0], name_lines_chunks[0]) 155 | 156 | #%% check split data 157 | if False: 158 | import json 159 | import random 160 | 
from pathlib import Path 161 | 162 | data_root_dir = Path('~/data/data/motion/Inter-X_Dataset').expanduser() 163 | src_txt_path_list = list((data_root_dir / 'texts').glob('*.txt')) 164 | random.shuffle(src_txt_path_list) 165 | tgt_dir = data_root_dir / 'texts_action_reaction' 166 | 167 | for src_txt_path in src_txt_path_list[:10]: 168 | stem = src_txt_path.stem 169 | with src_txt_path.open('r') as f: 170 | src_lines = f.readlines() 171 | with (tgt_dir / f'{stem}.json').open('r') as f: 172 | tgt_lines = json.load(f) 173 | 174 | for src, tgt in zip(src_lines, tgt_lines): 175 | print(f'{src.strip()}\n{tgt}\n') 176 | print('-----------------------------------') 177 | # %% 178 | -------------------------------------------------------------------------------- /src/datasets/motion_clip_dataset.py: -------------------------------------------------------------------------------- 1 | #%% 2 | if __name__ == "__main__": 3 | import sys 4 | sys.path.append(sys.path[0] + r"/../../") 5 | 6 | 7 | import tqdm 8 | import re 9 | import numpy as np 10 | import random 11 | import pickle 12 | import torch 13 | 14 | from src.utils import setup_logger, pad 15 | from src.utils.motion_representation_converter import MotionRepresentationConverter 16 | from src.utils.normalizer import TorchNormalizer 17 | from src.datasets.dataset_base import DatasetBase 18 | 19 | 20 | logger = setup_logger(__file__) 21 | mrc = MotionRepresentationConverter() 22 | 23 | 24 | class MotionCLIPDataset(DatasetBase): 25 | def __init__( 26 | self, 27 | dataset_dir, 28 | split='train', 29 | epoch_scaling=1, 30 | max_motion_length=256, 31 | min_motion_length=32, 32 | motion_representation='intergen_262', 33 | tiny_dataset=False, 34 | test_ar_correspondence='', 35 | ): 36 | super().__init__(dataset_dir=dataset_dir, split=split, epoch_scaling=epoch_scaling, tiny_dataset=tiny_dataset) 37 | self.max_motion_length = max_motion_length 38 | self.min_motion_length = min_motion_length 39 | self.motion_representation = motion_representation 40 | self.test_ar_correspondence = test_ar_correspondence 41 | 42 | logger.info(f'{dataset_dir.split("/")[-1]}/{split} initializing...') 43 | 44 | # 0. load ids 45 | ids = (self.dataset_dir / 'splits' / f'{split}.txt').read_text().strip('\n').split('\n') 46 | if tiny_dataset: 47 | ids = ids[:200] 48 | 49 | # 1. load text and features 50 | logger.info('Loading texts') 51 | texts = {} 52 | valid_ids = [] 53 | for file_id in ids: 54 | texts_dir = self.dataset_dir / 'texts' 55 | try: 56 | texts[file_id] = (texts_dir / f'{file_id}.txt').read_text().strip().split('\n') 57 | valid_ids.append(file_id) 58 | except: 59 | pass 60 | ids = valid_ids 61 | 62 | # 3. load normalized data 63 | logger.info('Loading motion data') 64 | self.normalizer = TorchNormalizer( 65 | statistics_dict=pickle.load( 66 | (self.dataset_dir.parent / 'normalizers' / f'{motion_representation}.pkl').open('rb') 67 | ) 68 | ) 69 | data_dict, valid_ids = self._get_data_dict(ids=ids) 70 | ids = valid_ids 71 | self.motion_dict = data_dict 72 | 73 | # 4. 
done 74 | self.ids = sorted(ids, key=lambda k: data_dict[k]['length']) 75 | self.texts = texts 76 | 77 | self.familiarity = {i+1: int(label) for i, label in enumerate((self.dataset_dir / 'annots' / 'familiarity.txt').read_text().strip().split('\n'))} 78 | 79 | logger.info(f'{dataset_dir.split("/")[-1]}/{split} initialization done.') 80 | 81 | def _get_data_dict(self, ids): 82 | data_dict = {} 83 | valid_ids = [] 84 | for file_id in tqdm.tqdm(ids): 85 | motion_path = self.dataset_dir / self.motion_representation / f'{file_id}.pkl' 86 | j3d_path = self.dataset_dir / 'joints3d_22' / f'{file_id}.pkl' 87 | try: 88 | ar_test = self.test_ar_correspondence 89 | with motion_path.open('rb') as f: 90 | data = pickle.load(f) 91 | if ar_test: 92 | with j3d_path.open('rb') as f: 93 | j3d_data = pickle.load(f) 94 | 95 | data_len = len(data['reaction']) 96 | if data_len > self.max_motion_length or data_len < self.min_motion_length: 97 | continue 98 | 99 | for k, v in data.items(): 100 | if isinstance(v, np.ndarray): 101 | data[k] = torch.from_numpy(v) 102 | 103 | action = data['action'] 104 | action = self.normalizer.normalize(action, key='all_motion') 105 | 106 | reaction = data['reaction'] 107 | if ar_test != '': 108 | reaction_shifted = j3d_data['reaction'] 109 | if ar_test.startswith('pos'): 110 | delta = float(ar_test[3:]) 111 | x = np.random.random() > 0.5 112 | if x: 113 | delta = np.array([np.random.choice([-delta, delta]), 0, 0]) 114 | else: 115 | delta = np.array([0, 0, np.random.choice([-delta, delta])]) 116 | reaction_shifted += delta 117 | elif ar_test.startswith('time'): 118 | delta = int(ar_test[4:]) 119 | reaction_shifted = np.concatenate([reaction_shifted[delta:], reaction_shifted[-1:].repeat(delta, 0)], axis=0) 120 | reaction_shifted = torch.from_numpy(mrc.convert('j3d', 'i262', reaction_shifted)) 121 | reaction_shifted = self.normalizer.normalize(reaction_shifted, key='all_motion') 122 | reaction_shifted, boolean_mask, length = pad(reaction_shifted, length=self.max_motion_length, dim=0, value=0) 123 | 124 | reaction = self.normalizer.normalize(reaction, key='all_motion') 125 | 126 | action, boolean_mask, length = pad(action, length=self.max_motion_length, dim=0, value=0) 127 | reaction, boolean_mask, length = pad(reaction, length=self.max_motion_length, dim=0, value=0) 128 | label = int(re.findall(r'A(\d+)', file_id)[0]) 129 | 130 | data_dict[file_id] = {'action': action, 'reaction': reaction, 'boolean_mask': boolean_mask, 'length': length, 'label': label} 131 | if ar_test: 132 | data_dict[file_id].update({ 133 | 'reaction_shifted': reaction_shifted 134 | }) 135 | valid_ids.append(file_id) 136 | except FileNotFoundError: 137 | continue 138 | return data_dict, valid_ids 139 | 140 | @property 141 | def real_length(self): 142 | return len(self.ids) 143 | 144 | def getitem(self, index): 145 | real_index = index % self.real_length 146 | file_id = self.ids[real_index] 147 | 148 | res = dict() 149 | res['id'] = file_id 150 | res.update(self.motion_dict[file_id]) 151 | 152 | familiarity = self.familiarity[int(re.findall(r'G(\d+)', file_id)[0])] 153 | res.update({ 154 | 'text': random.choice(self.texts[file_id]), 155 | 'familiarity': familiarity 156 | }) 157 | 158 | rand_id = random.choice(self.ids) 159 | res.update({ 160 | 'random_reaction': self.motion_dict[rand_id]['reaction'], 161 | 'random_length': self.motion_dict[rand_id]['length'], 162 | }) 163 | 164 | return res 165 | 166 | if __name__ == '__main__': 167 | from torch.utils.data import DataLoader 168 | seed = 42 169 | random.seed(seed) 
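    # For reference, each item returned by MotionCLIPDataset.getitem carries (shapes assume
    # the default intergen_262 representation and max_motion_length=256):
    #   'action', 'reaction'                normalized, zero-padded float tensors of shape (256, 262)
    #   'boolean_mask', 'length'            valid-frame mask and original frame count from pad()
    #   'label'                             action class parsed from the 'A<id>' part of the file id
    #   'text'                              one randomly chosen caption
    #   'familiarity'                       group-level label looked up via the 'G<id>' part of the file id
    #   'random_reaction', 'random_length'  a reaction sampled from a random other file id
    #   'reaction_shifted'                  only present when test_ar_correspondence is set
    # The quick checks below just build tiny datasets and print one batch.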
170 | torch.manual_seed(seed) 171 | np.random.seed(seed) 172 | 173 | # for d in ['interx']: 174 | # for split in ['train', 'val', 'test']: 175 | # ds = MotionCLIPDataset( 176 | # dataset_dir=f'~/data/data/motion/{d}', 177 | # split=split, 178 | # tiny_dataset=True, 179 | # ) 180 | # dl = DataLoader(ds, batch_size=32, shuffle=False) 181 | # print(f'len: {next(iter(dl))["length"]}') 182 | 183 | for d in ['interx']: 184 | for ar_test in ['time10']: 185 | ds = MotionCLIPDataset( 186 | dataset_dir=f'~/data/data/motion/{d}', 187 | split='test', 188 | tiny_dataset=True, 189 | test_ar_correspondence=ar_test 190 | ) 191 | dl = DataLoader(ds, batch_size=32, shuffle=True) 192 | # print(f'len: {next(iter(dl))["length"]}') 193 | print(next(iter(dl))) 194 | -------------------------------------------------------------------------------- /src/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Nonlinearity(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x): 10 | # swish 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | class ResConv1DBlock(nn.Module): 15 | def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): 16 | super().__init__() 17 | padding = dilation 18 | self.norm = norm 19 | if norm == "LN": 20 | self.norm1 = nn.LayerNorm(n_in) 21 | self.norm2 = nn.LayerNorm(n_in) 22 | elif norm == "GN": 23 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 24 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 25 | elif norm == "BN": 26 | self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 27 | self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 28 | 29 | else: 30 | self.norm1 = nn.Identity() 31 | self.norm2 = nn.Identity() 32 | 33 | if activation == "relu": 34 | self.activation1 = nn.ReLU() 35 | self.activation2 = nn.ReLU() 36 | 37 | elif activation == "silu": 38 | self.activation1 = Nonlinearity() 39 | self.activation2 = Nonlinearity() 40 | 41 | elif activation == "gelu": 42 | self.activation1 = nn.GELU() 43 | self.activation2 = nn.GELU() 44 | 45 | self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) 46 | self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,) 47 | 48 | def forward(self, x): 49 | x_orig = x 50 | if self.norm == "LN": 51 | x = self.norm1(x.transpose(-2, -1)) 52 | x = self.activation1(x.transpose(-2, -1)) 53 | else: 54 | x = self.norm1(x) 55 | x = self.activation1(x) 56 | 57 | x = self.conv1(x) 58 | 59 | if self.norm == "LN": 60 | x = self.norm2(x.transpose(-2, -1)) 61 | x = self.activation2(x.transpose(-2, -1)) 62 | else: 63 | x = self.norm2(x) 64 | x = self.activation2(x) 65 | 66 | x = self.conv2(x) 67 | x = x + x_orig 68 | return x 69 | 70 | 71 | class Resnet1D(nn.Module): 72 | def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): 73 | super().__init__() 74 | 75 | blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)] 76 | if reverse_dilation: 77 | blocks = blocks[::-1] 78 | 79 | self.model = nn.Sequential(*blocks) 80 | 81 | def forward(self, x): 82 | return self.model(x) 83 | 84 | 85 | class Res1DEncoder(nn.Module): 86 | 87 | def __init__(self, 88 | input_emb_width=3, 89 | output_emb_width=512, 90 | down_t=2, 91 | stride_t=2, 92 | width=512, 93 | depth=3, 94 | 
dilation_growth_rate=3, 95 | activation='relu', 96 | norm=None): 97 | super().__init__() 98 | 99 | blocks = [] 100 | filter_t, pad_t = stride_t * 2, stride_t // 2 101 | blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) 102 | blocks.append(nn.ReLU()) 103 | 104 | for i in range(down_t): 105 | input_dim = width 106 | block = nn.Sequential( 107 | nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), 108 | Resnet1D(width, 109 | depth, 110 | dilation_growth_rate, 111 | activation=activation, 112 | norm=norm), 113 | ) 114 | blocks.append(block) 115 | blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) 116 | self.model = nn.Sequential(*blocks) 117 | 118 | def forward(self, x): 119 | return self.model(x) 120 | 121 | 122 | class Res1DDecoder(nn.Module): 123 | 124 | def __init__(self, 125 | input_emb_width=3, 126 | output_emb_width=512, 127 | down_t=2, 128 | stride_t=2, 129 | width=512, 130 | depth=3, 131 | dilation_growth_rate=3, 132 | activation='relu', 133 | norm=None): 134 | super().__init__() 135 | blocks = [] 136 | 137 | filter_t, pad_t = stride_t * 2, stride_t // 2 138 | blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) 139 | blocks.append(nn.ReLU()) 140 | for i in range(down_t): 141 | out_dim = width 142 | block = nn.Sequential( 143 | Resnet1D(width, 144 | depth, 145 | dilation_growth_rate, 146 | reverse_dilation=True, 147 | activation=activation, 148 | norm=norm), nn.Upsample(scale_factor=2, 149 | mode='nearest'), 150 | nn.Conv1d(width, out_dim, 3, 1, 1)) 151 | blocks.append(block) 152 | blocks.append(nn.Conv1d(width, width, 3, 1, 1)) 153 | blocks.append(nn.ReLU()) 154 | blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) 155 | self.model = nn.Sequential(*blocks) 156 | 157 | def forward(self, x): 158 | return self.model(x) 159 | 160 | 161 | class EncoderV2(nn.Module): 162 | def __init__(self, 163 | input_emb_width=3, 164 | output_emb_width=512, 165 | down_t=2, 166 | stride_t=2, 167 | width=512, 168 | depth=3, 169 | dilation_growth_rate=3, 170 | activation='relu', 171 | norm=None): 172 | super().__init__() 173 | 174 | blocks = [] 175 | filter_t, pad_t = stride_t * 2, stride_t // 2 176 | blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) 177 | blocks.append(nn.ReLU()) 178 | 179 | input_dim = width 180 | block = nn.Sequential( 181 | nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), 182 | Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), 183 | ) 184 | blocks.append(block) 185 | 186 | input_dim = width 187 | block = nn.Sequential( 188 | nn.Conv1d(input_dim, width, 3, 1, 1), 189 | Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), 190 | ) 191 | blocks.append(block) 192 | 193 | blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) 194 | self.model = nn.Sequential(*blocks) 195 | 196 | def forward(self, x): 197 | return self.model(x) 198 | 199 | 200 | class DecoderV2(nn.Module): 201 | def __init__(self, 202 | input_emb_width=3, 203 | output_emb_width=512, 204 | down_t=2, 205 | stride_t=2, 206 | width=512, 207 | depth=3, 208 | dilation_growth_rate=3, 209 | activation='relu', 210 | norm=None): 211 | super().__init__() 212 | blocks = [] 213 | 214 | filter_t, pad_t = stride_t * 2, stride_t // 2 215 | blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) 216 | blocks.append(nn.ReLU()) 217 | 218 | out_dim = width 219 | block = nn.Sequential( 220 | Resnet1D(width, 221 | depth, 222 | dilation_growth_rate, 223 | reverse_dilation=True, 224 | activation=activation, 225 | norm=norm), 226 | 
nn.Upsample(scale_factor=2, mode='nearest'), 227 | nn.Conv1d(width, out_dim, 3, 1, 1)) 228 | blocks.append(block) 229 | 230 | out_dim = width 231 | block = nn.Sequential( 232 | Resnet1D(width, 233 | depth, 234 | dilation_growth_rate, 235 | reverse_dilation=True, 236 | activation=activation, 237 | norm=norm), 238 | nn.Conv1d(width, out_dim, 3, 1, 1)) 239 | blocks.append(block) 240 | 241 | blocks.append(nn.Conv1d(width, width, 3, 1, 1)) 242 | blocks.append(nn.ReLU()) 243 | blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) 244 | self.model = nn.Sequential(*blocks) 245 | 246 | def forward(self, x): 247 | return self.model(x) -------------------------------------------------------------------------------- /src/datasets/motion_vqvae_dataset.py: -------------------------------------------------------------------------------- 1 | #%% 2 | if __name__ == "__main__": 3 | import sys 4 | sys.path.append(sys.path[0] + r"/../../") 5 | 6 | 7 | import tqdm 8 | import numpy as np 9 | import random 10 | import pickle 11 | import re 12 | import torch 13 | 14 | from src.utils import setup_logger, pad 15 | from src.utils.normalizer import TorchNormalizer 16 | from src.datasets.dataset_base import DatasetBase 17 | 18 | 19 | logger = setup_logger(__file__) 20 | 21 | 22 | class MotionVQVAEDataset(DatasetBase): 23 | def __init__( 24 | self, 25 | dataset_dir, 26 | split='train', 27 | epoch_scaling=1, 28 | max_motion_length=256, 29 | min_motion_length=32, 30 | motion_representation='intergen_262', 31 | tiny_dataset=False, 32 | use_h3d=False, 33 | abs_action=False, 34 | ): 35 | super().__init__(dataset_dir=dataset_dir, split=split, epoch_scaling=epoch_scaling, tiny_dataset=tiny_dataset) 36 | self.max_motion_length = max_motion_length 37 | self.min_motion_length = min_motion_length 38 | self.motion_representation = motion_representation 39 | self.use_h3d = use_h3d 40 | self.abs_action = abs_action 41 | if self.split != 'train': 42 | self.min_motion_length = 32 43 | 44 | logger.info(f'{dataset_dir.split("/")[-1]}/{split} initializing...') 45 | 46 | # 0. load ids 47 | ids = (self.dataset_dir / 'splits' / f'{split}.txt').read_text().strip('\n').split('\n') 48 | if self.tiny_dataset: 49 | ids = ids[:200] 50 | 51 | # 1. load text 52 | logger.info('Loading texts') 53 | texts_dir = self.dataset_dir / 'texts' 54 | texts = {} 55 | valid_ids = [] 56 | for file_id in ids: 57 | try: 58 | texts[file_id] = (texts_dir / f'{file_id}.txt').read_text().strip().split('\n') 59 | valid_ids.append(file_id) 60 | except: 61 | pass 62 | self.ids = valid_ids 63 | self.texts = texts 64 | 65 | # 2. 
load normalized data 66 | logger.info('Loading motion data') 67 | self.normalizer = TorchNormalizer( 68 | statistics_dict=pickle.load( 69 | (self.dataset_dir / 'normalizers' / f'{motion_representation}.pkl').open('rb') 70 | ) 71 | ) 72 | 73 | if self.split == 'train': 74 | self.motions = self._load_humanml3d_motions() if self.use_h3d else [] 75 | self.motions.extend(self._load_training_data()) 76 | else: 77 | self.motion_dict, self.padded_motion_dict, self.ids = self._load_val_data() 78 | 79 | logger.info(f'{dataset_dir.split("/")[-1]}/{split} initialization done.') 80 | 81 | def _load_training_data(self): 82 | motions = [] 83 | for file_id in tqdm.tqdm(self.ids): 84 | motion_path = self.dataset_dir / self.motion_representation / f'{file_id}.pkl' 85 | try: 86 | with motion_path.open('rb') as f: 87 | data = pickle.load(f) 88 | 89 | data_len = len(data['reaction']) 90 | if data_len > self.max_motion_length or data_len < self.min_motion_length: 91 | continue 92 | 93 | for k, v in data.items(): 94 | if isinstance(v, np.ndarray): 95 | data[k] = torch.from_numpy(v) 96 | 97 | reaction = self.normalizer.normalize(data['reaction'], key='all_motion') 98 | motions.append(reaction) 99 | motions.append(reaction) # double reaction 100 | 101 | if self.abs_action: 102 | action = self.normalizer.normalize(data['action'], key='all_motion') 103 | motions.append(action) 104 | else: 105 | naction = self.normalizer.normalize(data['naction'], key='all_motion') 106 | motions.append(naction) 107 | 108 | except FileNotFoundError: 109 | continue 110 | return motions 111 | 112 | def _load_humanml3d_motions(self): 113 | h3d_ids = (self.dataset_dir.parent / 'humanml3d' / 'splits' / 'all.txt').read_text().strip('\n').split('\n') 114 | if self.tiny_dataset: 115 | h3d_ids = h3d_ids[:200] 116 | 117 | motions = [] 118 | for file_id in tqdm.tqdm(h3d_ids, desc='load h3d motion'): 119 | motion_path = self.dataset_dir.parent / 'humanml3d' / self.motion_representation / f'{file_id}.pkl' 120 | try: 121 | with motion_path.open('rb') as f: 122 | data = pickle.load(f) 123 | 124 | data_len = len(data['reaction']) 125 | if data_len > self.max_motion_length or data_len < self.min_motion_length: 126 | continue 127 | 128 | for k, v in data.items(): 129 | if isinstance(v, np.ndarray): 130 | data[k] = torch.from_numpy(v) 131 | 132 | reaction = self.normalizer.normalize(data['reaction'], key='all_motion') 133 | motions.append(reaction) 134 | except: 135 | pass 136 | return motions 137 | 138 | def _load_val_data(self): 139 | motion_dict = {} 140 | padded_motion_dict = {} 141 | valid_ids = [] 142 | for file_id in tqdm.tqdm(self.ids): 143 | motion_path = self.dataset_dir / self.motion_representation / f'{file_id}.pkl' 144 | try: 145 | with motion_path.open('rb') as f: 146 | data = pickle.load(f) 147 | 148 | data_len = len(data['reaction']) 149 | if data_len > self.max_motion_length or data_len < self.min_motion_length: 150 | continue 151 | 152 | for k, v in data.items(): 153 | if isinstance(v, np.ndarray): 154 | data[k] = torch.from_numpy(v) 155 | 156 | reaction = self.normalizer.normalize(data['reaction'], key='all_motion') 157 | action = self.normalizer.normalize(data['action'], key='all_motion') 158 | motion_dict[file_id] = { 159 | 'reaction': reaction, 160 | 'length': reaction.shape[0] 161 | } 162 | padded_action, boolean_mask, _ = pad(action, length=self.max_motion_length, dim=0, value=0) 163 | padded_reaction, _, _ = pad(reaction, length=self.max_motion_length, dim=0, value=0, get_boolean_mask=False) 164 | padded_motion_dict[file_id] = { 
165 | 'action': padded_action, 'reaction': padded_reaction, 'boolean_mask': boolean_mask, 'label': int(re.findall(r'A(\d+)', file_id)[0]) 166 | } 167 | valid_ids.append(file_id) 168 | 169 | except FileNotFoundError: 170 | continue 171 | 172 | return motion_dict, padded_motion_dict, valid_ids 173 | 174 | @property 175 | def real_length(self): 176 | if self.split == 'train': 177 | return len(self.motions) 178 | else: 179 | return len(self.ids) 180 | 181 | def getitem(self, index): 182 | if self.split == 'train': 183 | return self.get_train_item(index=index) 184 | else: 185 | return self.get_val_item(index=index) 186 | 187 | def get_train_item(self, index): 188 | motion = self.motions[index] 189 | length = len(motion) 190 | idx = random.randint(0, length - self.min_motion_length) 191 | return { 192 | 'motion': motion[idx: idx + self.min_motion_length, :] 193 | } 194 | 195 | def get_val_item(self, index): 196 | file_id = self.ids[index] 197 | 198 | res = dict() 199 | res['id'] = file_id 200 | 201 | motion_dict = self.motion_dict[file_id] 202 | length = motion_dict['length'] 203 | idx = random.randint(0, length - self.min_motion_length) 204 | vq_reaction = motion_dict['reaction'][idx: idx + self.min_motion_length, :] 205 | 206 | res.update({ 207 | 'motion': vq_reaction, 208 | 'length': length 209 | }) 210 | if self.split != 'train': 211 | padded_motion_dict = self.padded_motion_dict[file_id] 212 | res.update({ 213 | 'padded_action': padded_motion_dict['action'], 214 | 'padded_reaction': padded_motion_dict['reaction'], 215 | 'boolean_mask': padded_motion_dict['boolean_mask'], 216 | 'label': padded_motion_dict['label'] 217 | }) 218 | 219 | res.update({'text': random.choice(self.texts[file_id])}) 220 | 221 | return res 222 | 223 | if __name__ == '__main__': 224 | from torch.utils.data import DataLoader 225 | seed = 42 226 | random.seed(seed) 227 | torch.manual_seed(seed) 228 | np.random.seed(seed) 229 | 230 | for d in ['interx']: 231 | for split in ['train', 'val']: 232 | ds = MotionVQVAEDataset( 233 | dataset_dir=f'~/data/data/motion/{d}', 234 | split=split, 235 | tiny_dataset=True, 236 | use_h3d=True 237 | ) 238 | dl = DataLoader(ds, batch_size=2, shuffle=True) 239 | print(f'len: {len(ds), next(iter(dl))}') 240 | 241 | # %% 242 | -------------------------------------------------------------------------------- /third_party/HumanML3D/common/skeleton.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ..common.quaternion import * 3 | except: 4 | from common.quaternion import * 5 | import scipy.ndimage.filters as filters 6 | 7 | class Skeleton(object): 8 | def __init__(self, offset, kinematic_tree, device): 9 | self.device = device 10 | self._raw_offset_np = offset.numpy() 11 | self._raw_offset = offset.clone().detach().to(device).float() 12 | self._kinematic_tree = kinematic_tree 13 | self._offset = None 14 | self._parents = [0] * len(self._raw_offset) 15 | self._parents[0] = -1 16 | for chain in self._kinematic_tree: 17 | for j in range(1, len(chain)): 18 | self._parents[chain[j]] = chain[j-1] 19 | 20 | def njoints(self): 21 | return len(self._raw_offset) 22 | 23 | def offset(self): 24 | return self._offset 25 | 26 | def set_offset(self, offsets): 27 | self._offset = offsets.clone().detach().to(self.device).float() 28 | 29 | def kinematic_tree(self): 30 | return self._kinematic_tree 31 | 32 | def parents(self): 33 | return self._parents 34 | 35 | # joints (batch_size, joints_num, 3) 36 | def get_offsets_joints_batch(self, joints): 37 | assert 
len(joints.shape) == 3 38 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 39 | for i in range(1, self._raw_offset.shape[0]): 40 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 41 | 42 | self._offset = _offsets.detach() 43 | return _offsets 44 | 45 | # joints (joints_num, 3) 46 | def get_offsets_joints(self, joints): 47 | assert len(joints.shape) == 2 48 | _offsets = self._raw_offset.clone() 49 | for i in range(1, self._raw_offset.shape[0]): 50 | # print(joints.shape) 51 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 52 | 53 | self._offset = _offsets.detach() 54 | return _offsets 55 | 56 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 57 | # joints (batch_size, joints_num, 3) 58 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 59 | assert len(face_joint_idx) == 4 60 | '''Get Forward Direction''' 61 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 62 | across1 = joints[:, r_hip] - joints[:, l_hip] 63 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 64 | across = across1 + across2 65 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 66 | # print(across1.shape, across2.shape) 67 | 68 | # forward (batch_size, 3) 69 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 70 | if smooth_forward: 71 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 72 | # forward (batch_size, 3) 73 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 74 | 75 | '''Get Root Rotation''' 76 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 77 | root_quat = qbetween_np(forward, target) 78 | 79 | '''Inverse Kinematics''' 80 | # quat_params (batch_size, joints_num, 4) 81 | # print(joints.shape[:-1]) 82 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 83 | # print(quat_params.shape) 84 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 85 | quat_params[:, 0] = root_quat 86 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 87 | for chain in self._kinematic_tree: 88 | R = root_quat 89 | for j in range(len(chain) - 1): 90 | # (batch, 3) 91 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 92 | # print(u.shape) 93 | # (batch, 3) 94 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 95 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 96 | # print(u.shape, v.shape) 97 | rot_u_v = qbetween_np(u, v) 98 | 99 | R_loc = qmul_np(qinv_np(R), rot_u_v) 100 | 101 | quat_params[:,chain[j + 1], :] = R_loc 102 | R = qmul_np(R, R_loc) 103 | 104 | return quat_params 105 | 106 | # Be sure root joint is at the beginning of kinematic chains 107 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 108 | # quat_params (batch_size, joints_num, 4) 109 | # joints (batch_size, joints_num, 3) 110 | # root_pos (batch_size, 3) 111 | if skel_joints is not None: 112 | offsets = self.get_offsets_joints_batch(skel_joints) 113 | if len(self._offset.shape) == 2: 114 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 115 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 116 | joints[:, 0] = root_pos 117 | for chain in self._kinematic_tree: 118 | if do_root_R: 119 | R = quat_params[:, 0] 120 | else: 121 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 122 | for i in range(1, len(chain)): 123 | R = qmul(R, quat_params[:, 
chain[i]]) 124 | offset_vec = offsets[:, chain[i]] 125 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 126 | return joints 127 | 128 | # Be sure root joint is at the beginning of kinematic chains 129 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 130 | # quat_params (batch_size, joints_num, 4) 131 | # joints (batch_size, joints_num, 3) 132 | # root_pos (batch_size, 3) 133 | if skel_joints is not None: 134 | skel_joints = torch.from_numpy(skel_joints) 135 | offsets = self.get_offsets_joints_batch(skel_joints) 136 | if len(self._offset.shape) == 2: 137 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 138 | offsets = offsets.numpy() 139 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 140 | joints[:, 0] = root_pos 141 | for chain in self._kinematic_tree: 142 | if do_root_R: 143 | R = quat_params[:, 0] 144 | else: 145 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 146 | for i in range(1, len(chain)): 147 | R = qmul_np(R, quat_params[:, chain[i]]) 148 | offset_vec = offsets[:, chain[i]] 149 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 150 | return joints 151 | 152 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 153 | # cont6d_params (batch_size, joints_num, 6) 154 | # joints (batch_size, joints_num, 3) 155 | # root_pos (batch_size, 3) 156 | if skel_joints is not None: 157 | skel_joints = torch.from_numpy(skel_joints) 158 | offsets = self.get_offsets_joints_batch(skel_joints) 159 | if len(self._offset.shape) == 2: 160 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 161 | offsets = offsets.numpy() 162 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 163 | joints[:, 0] = root_pos 164 | for chain in self._kinematic_tree: 165 | if do_root_R: 166 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 167 | else: 168 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 169 | for i in range(1, len(chain)): 170 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 171 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 172 | # print(matR.shape, offset_vec.shape) 173 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 174 | return joints 175 | 176 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 177 | # cont6d_params (batch_size, joints_num, 6) 178 | # joints (batch_size, joints_num, 3) 179 | # root_pos (batch_size, 3) 180 | if skel_joints is not None: 181 | # skel_joints = torch.from_numpy(skel_joints) 182 | offsets = self.get_offsets_joints_batch(skel_joints) 183 | if len(self._offset.shape) == 2: 184 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 185 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 186 | joints[..., 0, :] = root_pos 187 | for chain in self._kinematic_tree: 188 | if do_root_R: 189 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 190 | else: 191 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 192 | for i in range(1, len(chain)): 193 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) 194 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 195 | # print(matR.shape, offset_vec.shape) 196 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 197 | return joints 198 | 199 | 200 | 201 | 202 | 203 | 
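# Shape conventions for the Skeleton class above (inferred from this file only):
# - Skeleton(offset, kinematic_tree, device) stores raw bone offsets and a kinematic tree given
#   as a list of joint-index chains; joint parents are derived by walking those chains.
# - inverse_kinematics_np(joints, face_joint_idx, smooth_forward): joints (batch, joints_num, 3)
#   plus the four hip/shoulder indices -> per-joint quaternions (batch, joints_num, 4); the root
#   rotation maps the body's forward direction (up x the hip/shoulder "across" vector) onto +z.
# - forward_kinematics / forward_kinematics_np / forward_kinematics_cont6d(_np): per-joint
#   rotations (quaternions or continuous 6D) plus root positions -> joint positions
#   (batch, joints_num, 3); passing skel_joints first rescales the raw offsets to that skeleton
#   via get_offsets_joints_batch.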
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import shutil 4 | from collections import defaultdict 5 | import argparse 6 | import copy 7 | from omegaconf import OmegaConf, DictConfig, ListConfig 8 | import numpy as np 9 | import torch 10 | import torch.distributed 11 | from torch.utils.data import DataLoader 12 | import lightning.pytorch as pl 13 | 14 | from src.utils import instantiate_from_config, get_timestamp, setup_logger 15 | 16 | 17 | logger = setup_logger(__file__) 18 | 19 | 20 | def instantiate_callbacks(callback_configs: ListConfig): 21 | callbacks = [] 22 | for callback_cfg in callback_configs: 23 | callbacks.append(instantiate_from_config(callback_cfg)) 24 | 25 | return callbacks 26 | 27 | 28 | def get_dataloaders(config, args): 29 | train_ds = instantiate_from_config(config.dataset, extra_kwargs={'split': 'train'}) 30 | val_ds = instantiate_from_config(config.dataset, extra_kwargs={'split': 'val'}) 31 | test_ds = instantiate_from_config(config.dataset, extra_kwargs={'split': 'test'}) 32 | 33 | dataloader_config = copy.copy(config.dataloader) 34 | val_batch_size = dataloader_config.pop('val_batch_size', dataloader_config.batch_size) 35 | train_dataloader = DataLoader(train_ds, **dataloader_config, shuffle= not args.no_shuffle_train, drop_last=True) 36 | dataloader_config.batch_size = val_batch_size 37 | val_dataloader = DataLoader(val_ds, **dataloader_config, shuffle=False, drop_last=True) 38 | test_dataloader = DataLoader(test_ds, **dataloader_config, shuffle=False, drop_last=True) 39 | 40 | return train_dataloader, val_dataloader, test_dataloader 41 | 42 | 43 | def _preprocess_config(config, args, unknown_args): 44 | # global logger 45 | def set_config_key_value(inplace_dict, key_path, value): 46 | def bfs_set_config_key_value(inplace_dict, key, value): 47 | at_least_one_kv_is_set = False 48 | if not isinstance(inplace_dict, (DictConfig, dict)): 49 | return False 50 | if key in inplace_dict.keys(): 51 | inplace_dict[key] = value 52 | at_least_one_kv_is_set = True 53 | for v in inplace_dict.values(): 54 | if isinstance(v, (DictConfig, dict)): 55 | at_least_one_kv_is_set |= bfs_set_config_key_value(inplace_dict=v, key=key, value=value) 56 | elif isinstance(v, ListConfig): 57 | for item in v: 58 | at_least_one_kv_is_set |= bfs_set_config_key_value(inplace_dict=item, key=key, value=value) 59 | return at_least_one_kv_is_set 60 | 61 | keys = key_path.split('.') # e.g., dataset.a.b=1 62 | len_keys = len(keys) 63 | if len_keys == 1: # e.g., batch_size=32 64 | success = bfs_set_config_key_value(inplace_dict, key=key_path, value=value) 65 | if success: 66 | return 67 | else: 68 | raise ValueError(f'{key_path} is not found in config') 69 | 70 | # else len_keys > 1: 71 | for key_idx in range(len_keys - 1): 72 | inplace_dict = inplace_dict[keys[key_idx]] 73 | 74 | if isinstance(inplace_dict, ListConfig): 75 | for item in inplace_dict: 76 | for sub_key_idx in range(key_idx + 1, len_keys - 1): 77 | item = item[keys[sub_key_idx]] 78 | item[keys[-1]] = value 79 | return 80 | 81 | inplace_dict[keys[-1]] = value 82 | 83 | # set unknown args to config 84 | for unknown in unknown_args: 85 | k, v = unknown.split('=') 86 | try: 87 | v = int(v) # maybe int has the highest priority 88 | except: 89 | try: 90 | v = float(v) 91 | except: 92 | # Python constants: True, False, None 93 | # it should not be v = bool(v) as bool('False') 
-> True 94 | if (vlower := v.lower()) == 'true': 95 | v = True 96 | elif vlower == 'false': 97 | v = False 98 | elif vlower == 'none': 99 | v = None 100 | # else v = v, the str itself 101 | set_config_key_value(config, k, v) 102 | 103 | # devices 104 | devices = args.devices 105 | if devices is None: 106 | config.trainer.accelerator = 'cpu' # bet you won't run into this line 107 | else: 108 | config.trainer.devices = [int(rank) for rank in devices.split(',')] 109 | 110 | # set project name and signature for logging 111 | if args.no_log: 112 | config.trainer.logger = False 113 | else: 114 | config.trainer.logger.save_dir = f'logs/{args.model}' 115 | config.trainer.logger.name = f'{args.dataset}' 116 | config.trainer.logger.version = get_timestamp() + (f'_{args.log_suffix}' if args.log_suffix != '' else '') 117 | 118 | # batch size for ddp 119 | total_bs = config.dataloader.batch_size 120 | num_devices = len(config.trainer.devices) 121 | bs_per_device = total_bs // num_devices 122 | real_bs = bs_per_device * num_devices 123 | if real_bs != total_bs: 124 | logger.warning(f'real batch size is {real_bs}') 125 | config.dataloader.batch_size = bs_per_device 126 | 127 | # epoch scaling: scaling up the epoch length while reducing the number of epochs 128 | # this is useful when an epoch is too short and val is too frequent 129 | epoch_scaling = config.dataset.get('epoch_scaling') 130 | if epoch_scaling is not None and epoch_scaling != 1: 131 | config.trainer.max_epochs = int(config.trainer.max_epochs / epoch_scaling) 132 | logger.info(f'Training epoch length is scaled by {epoch_scaling}, thus the num of epochs is decreased to {config.trainer.max_epochs}') 133 | 134 | # process the config here 135 | config = preprocess_config_hook(config) 136 | 137 | logger.info(f'running with config: {config}') 138 | return config 139 | 140 | 141 | def preprocess_config_hook(config): 142 | return config 143 | 144 | 145 | def get_processed_args_and_config(): 146 | args, unknown_args = get_args() 147 | 148 | OmegaConf.register_new_resolver("eval", eval) 149 | 150 | # load trainer config 151 | trainer_config = OmegaConf.load(f'src/configs/trainer/{args.trainer}.yaml') 152 | 153 | # load model config 154 | model_config = OmegaConf.load(f'src/configs/models/{args.model}.yaml') 155 | config = OmegaConf.merge(trainer_config, model_config) 156 | 157 | # load dataset config 158 | dataset_config = OmegaConf.load(f'src/configs/datasets/{args.dataset}.yaml') 159 | config = OmegaConf.merge(config, DictConfig(dataset_config)) 160 | OmegaConf.resolve(config) 161 | 162 | config = _preprocess_config(config, args, unknown_args) 163 | 164 | return args, config 165 | 166 | 167 | def get_args(): 168 | parser = argparse.ArgumentParser() 169 | 170 | parser.add_argument( 171 | '--model', 172 | default='motion_clip' 173 | ) 174 | 175 | parser.add_argument( 176 | '--dataset', 177 | default='motion_clip' 178 | ) 179 | 180 | parser.add_argument( 181 | '--trainer', 182 | default='default' # actually this is the only trainer 183 | ) 184 | 185 | parser.add_argument( 186 | '--devices', 187 | type=str, 188 | default=None, 189 | ) 190 | 191 | parser.add_argument( 192 | '--resume_ckpt_path', 193 | type=str, 194 | default=None 195 | ) 196 | 197 | parser.add_argument( 198 | '--load_ckpt_path', 199 | type=str, 200 | default=None 201 | ) 202 | 203 | parser.add_argument( 204 | '--no_log', # when debugging, setting this to False helps. 
(Recommend to add this to launch.json) 205 | help='disable training log', 206 | action='store_true' 207 | ) 208 | 209 | parser.add_argument( 210 | '--log_suffix', 211 | help='append a suffix to log dir', 212 | default='' 213 | ) 214 | 215 | parser.add_argument( 216 | '--no_shuffle_train', 217 | action='store_true' 218 | ) 219 | 220 | args, unknown_args = parser.parse_known_args() 221 | return args, unknown_args 222 | 223 | 224 | def main(): 225 | args, config = get_processed_args_and_config() 226 | pl.seed_everything(config.seed) 227 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 228 | 229 | train_dataloader, val_dataloader, test_dataloader = get_dataloaders(config, args) 230 | epoch_length = len(train_dataloader) // len(config.trainer.devices) 231 | config.model.training_kwargs['num_training_steps'] = epoch_length * config.trainer.max_epochs 232 | 233 | model: pl.LightningModule = instantiate_from_config(config.model, extra_kwargs={"all_config": config}) 234 | if p := args.load_ckpt_path: 235 | model.load_state_dict(state_dict=torch.load(p, map_location='cpu')['state_dict'], strict=False) 236 | 237 | trainer: pl.Trainer = instantiate_from_config(config.trainer, extra_kwargs={'callbacks': instantiate_callbacks(config.callbacks)}) 238 | 239 | try: 240 | try: 241 | if trainer.global_rank == 0: 242 | shutil.copytree('src', os.path.join(trainer.logger.log_dir, 'src_backup')) # backup src directory 243 | except: pass 244 | 245 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=args.resume_ckpt_path) 246 | 247 | # evaluation 248 | results = trainer.test(ckpt_path='best', dataloaders=test_dataloader)[0] # the first dataloader 249 | logger = setup_logger('results', log_file=f'{trainer.logger.log_dir}/eval_after_train.log') 250 | logger.info(f'evaluation results: {results}') 251 | 252 | except Exception as e: 253 | raise e 254 | else: 255 | # mark log dir as trained 256 | if trainer.global_rank == 0: 257 | shutil.move(trainer.logger.log_dir, trainer.logger.log_dir + '_trained') 258 | 259 | 260 | if __name__ == '__main__': 261 | main() 262 | -------------------------------------------------------------------------------- /src/models/motion_clip.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import random 3 | import copy 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import transformers 9 | 10 | from .model_base import ModelBase 11 | from ..modules.embeddings import PositionalEncoding 12 | from ..utils.constants import MOTION_REPRESENTATION_INFO, TEXT_FEATURE_INFO 13 | from ..metrics.common import calculate_diversity, euclidean_distance_matrix, calculate_top_k, calculate_fid 14 | 15 | 16 | class MotionCLIP(ModelBase): 17 | def __init__( 18 | self, 19 | model_kwargs, 20 | training_kwargs, 21 | all_config=None, 22 | ): 23 | super().__init__(model_kwargs=model_kwargs, training_kwargs=training_kwargs, all_config=all_config) 24 | 25 | self.motion_representation = model_kwargs.motion_representation 26 | self.motion_representation_info = MOTION_REPRESENTATION_INFO[model_kwargs.motion_representation] 27 | self.motion_feature_size = motion_feature_size = self.motion_representation_info['feature_size'] 28 | self.output_size = output_size = model_kwargs.output_size 29 | self.n_labels = n_labels = model_kwargs.n_labels 30 | 31 | # text 32 | text_feature_name = model_kwargs.text_feature_name 33 | self.text_feature_info = 
TEXT_FEATURE_INFO[text_feature_name] 34 | text_model = transformers.CLIPTextModel.from_pretrained(text_feature_name) 35 | self.text_emb = copy.deepcopy(text_model.text_model.embeddings).eval() 36 | for p in self.text_emb.parameters(): 37 | p.requires_grad_(False) 38 | del text_model 39 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(text_feature_name) 40 | self.text_feature_size = text_feature_size = self.text_feature_info['feature_size'] 41 | self.text_transformer = nn.TransformerEncoder( 42 | encoder_layer=nn.TransformerEncoderLayer( 43 | d_model=text_feature_size, 44 | nhead=model_kwargs.n_heads, 45 | activation='gelu', 46 | batch_first=True, 47 | dropout=model_kwargs.dropout 48 | ), 49 | num_layers=model_kwargs.n_encoder_layers, 50 | ) 51 | self.text_linear_out = nn.Linear(text_feature_size, output_size) 52 | 53 | # motion 54 | self.action_mask_coef = model_kwargs.get('action_mask_coef', 0) 55 | input_motion_size = (motion_feature_size - 4) * 2 if self.action_mask_coef >=0 else (motion_feature_size - 4) 56 | self.motion_lin_in = nn.Linear(input_motion_size, text_feature_size) 57 | self.motion_transformer = nn.TransformerEncoder( 58 | encoder_layer=nn.TransformerEncoderLayer( 59 | d_model=text_feature_size, 60 | nhead=model_kwargs.n_heads, 61 | activation='gelu', 62 | batch_first=True, 63 | dropout=model_kwargs.dropout 64 | ), 65 | num_layers=model_kwargs.n_encoder_layers, 66 | ) 67 | 68 | # classification head 69 | self.motion_linear_out = nn.Linear(text_feature_size, output_size) 70 | if model_kwargs.cls_weight > 0: 71 | self.motion_cls_head = nn.Sequential( 72 | nn.Dropout(model_kwargs.dropout), 73 | nn.Linear(output_size, output_size), 74 | nn.Tanh(), 75 | nn.Dropout(model_kwargs.dropout), 76 | nn.Linear(output_size, n_labels) 77 | ) 78 | 79 | self.pe = PositionalEncoding(d_model=text_feature_size, dropout=model_kwargs.dropout) 80 | self.cls = torch.nn.Parameter(torch.zeros(size=(1, 1, text_feature_size))) 81 | torch.nn.init.normal_(self.cls, mean=0, std=0.02) 82 | 83 | self.ce_loss = torch.nn.CrossEntropyLoss() 84 | self.latent_scale = torch.nn.Parameter(torch.Tensor([model_kwargs.get('init_latent_scale', 1)])) 85 | 86 | def combine_motion(self, reaction, action): 87 | seq_length = reaction.shape[1] 88 | 89 | reaction = reaction[:, :, :-4] 90 | action = action[:, :, :-4] 91 | if self.action_mask_coef > 0: 92 | action_mask = [1] + [0] * self.action_mask_coef 93 | action_mask = action_mask * (seq_length // len(action_mask) + 1) 94 | action_mask = torch.tensor(action_mask[:seq_length], device=reaction.device) 95 | action = torch.einsum('bsh,s->bsh', action, action_mask) 96 | 97 | if self.action_mask_coef >= 0: 98 | motion = torch.cat([reaction, action], dim=-1) 99 | else: 100 | motion = reaction 101 | 102 | return motion 103 | 104 | def encode_motion(self, reaction, action, boolean_mask): 105 | motion = self.combine_motion(reaction=reaction, action=action) 106 | bsz, seq_length, _ = reaction.shape 107 | 108 | encoder_input = self.motion_lin_in(motion) 109 | encoder_input = torch.concat([self.cls.expand(size=(bsz, 1, self.text_feature_size)), encoder_input], dim=1) 110 | 111 | encoder_input = self.pe(encoder_input) 112 | src_pad_mask = torch.cat( 113 | [torch.zeros(size=(bsz, 1), dtype=motion.dtype, device=motion.device), boolean_mask], dim=1 114 | ) 115 | 116 | encoder_output = self.motion_transformer.forward(src=encoder_input, src_key_padding_mask=src_pad_mask)[:, 0, :] 117 | motion_feature = self.motion_linear_out(encoder_output) 118 | return motion_feature 119 | 120 | 
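# encode_motion (above): combine_motion drops the last 4 feature channels of reaction and action,
# optionally zeroes all but one action frame out of every (action_mask_coef + 1), and concatenates
# the two streams along the feature dim (reaction only if action_mask_coef < 0). The result is
# projected to text_feature_size, a learned CLS token is prepended, positional encoding is added,
# and the CLS output of the transformer encoder is projected to output_size as the motion embedding.
# encode_text (below): reuses frozen CLIP token embeddings, runs a trainable transformer encoder,
# pools the hidden state at the EOS token position (CLIP-style), and projects it to output_size.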
def encode_text(self, text_list): 121 | inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", truncation=True) 122 | input_ids = inputs.input_ids.to(self.device) 123 | attention_mask = ~inputs.attention_mask.to(device=self.device, dtype=bool) 124 | 125 | with torch.no_grad(): 126 | hidden_states = self.text_emb(input_ids) 127 | 128 | last_hidden_state = self.text_transformer.forward( 129 | src=hidden_states, 130 | src_key_padding_mask=attention_mask 131 | ) 132 | 133 | pooled_output = last_hidden_state[ 134 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 135 | (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.tokenizer.eos_token_id) 136 | .int() 137 | .argmax(dim=-1), 138 | ] 139 | text_feature = self.text_linear_out(pooled_output) 140 | return text_feature 141 | 142 | def get_log_dict(self, batch, batch_idx, split) -> Dict: 143 | boolean_mask = batch[f'boolean_mask'] 144 | reaction = batch['reaction'] 145 | action = batch['action'] 146 | text_list = batch['text'] 147 | 148 | if 'label' in batch: 149 | labels = batch['label'] 150 | 151 | motion_embeddings = self.encode_motion(reaction=reaction, action=action, boolean_mask=boolean_mask) 152 | text_embeddings = self.encode_text(text_list=text_list) 153 | 154 | logits_per_motion = self.latent_scale.exp() * motion_embeddings @ text_embeddings.t() 155 | 156 | logits_per_d = logits_per_motion.t() 157 | batch_size = motion_embeddings.shape[0] 158 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=motion_embeddings.device) 159 | ce_from_motion_loss = self.ce_loss(logits_per_motion, ground_truth) 160 | ce_from_d_loss = self.ce_loss(logits_per_d, ground_truth) 161 | 162 | res = { 163 | f'{split}/d_ce': ce_from_d_loss, 164 | f'{split}/motion_ce': ce_from_motion_loss, 165 | } 166 | 167 | total_loss = (ce_from_motion_loss + ce_from_d_loss) / 2 168 | if self.model_kwargs.cls_weight > 0: 169 | logits = self.motion_cls_head(motion_embeddings) 170 | ce_from_cls_loss = self.ce_loss(input=logits, target=labels) 171 | res[f'{split}/cls_ce'] = ce_from_cls_loss 172 | total_loss = ce_from_cls_loss * self.model_kwargs.cls_weight + total_loss 173 | 174 | res[f'{split}/total_loss'] = total_loss 175 | 176 | return res 177 | 178 | def get_metrics(self, batch, split, shift=False, q=False): 179 | log_dict = {} 180 | 181 | boolean_mask = batch[f'boolean_mask'] 182 | action = batch['action'] 183 | reaction = batch['reaction'] 184 | if shift or q: # ignore this when training your own model 185 | if q: 186 | reaction_shifted = batch['reaction_shifted'] 187 | else: 188 | reaction_shifted = batch['random_reaction'] 189 | min_lengthes = torch.stack([batch['length'], batch['random_length']], dim=-1).min(-1).values 190 | shifted_boolean_mask = boolean_mask.clone().detach() 191 | for b in range(min_lengthes.shape[0]): 192 | shifted_boolean_mask[b, min_lengthes[b]:] = True 193 | 194 | motion_embeddings = self.encode_motion(reaction=reaction, action=action, boolean_mask=boolean_mask) 195 | shifted_motion_embeddings = self.encode_motion(reaction=reaction_shifted, action=action, boolean_mask=shifted_boolean_mask) 196 | fid = calculate_fid(motion_embeddings.detach().cpu().numpy(), shifted_motion_embeddings.detach().cpu().numpy()) 197 | reaction = reaction_shifted 198 | mask = shifted_boolean_mask 199 | else: 200 | fid = 0.0 201 | log_dict[f'{split}/fid'] = fid 202 | 203 | text_list = batch['text'] 204 | labels = batch['label'] 205 | 206 | motion_embeddings = self.encode_motion(reaction=reaction, 
action=action, boolean_mask=boolean_mask) 207 | 208 | if self.model_kwargs.cls_weight > 0: 209 | logits = self.motion_cls_head(motion_embeddings) 210 | acc_1 = (logits.argmax(-1) == labels).sum() / labels.shape[0] 211 | log_dict[f'{split}/acc_1'] = acc_1 212 | 213 | _, top5_preds = torch.topk(logits, 5, dim=1) 214 | top5_correct = (labels.unsqueeze(1) == top5_preds).any(dim=1).float() 215 | acc_5 = top5_correct.sum() / labels.shape[0] 216 | log_dict[f'{split}/acc_5'] = acc_5 217 | 218 | motion_embeddings = motion_embeddings.detach().cpu().numpy() 219 | 220 | text_embeddings = self.encode_text(text_list=text_list).detach().cpu().numpy() 221 | batch_size = boolean_mask.shape[0] 222 | 223 | gt_dist_mat = euclidean_distance_matrix(text_embeddings, motion_embeddings) 224 | mm_dist = gt_dist_mat.trace() / batch_size 225 | log_dict[f'{split}/mm_dist'] = mm_dist 226 | 227 | argsmax = np.argsort(gt_dist_mat, axis=1) 228 | top_k_mat = calculate_top_k(argsmax, top_k=3) 229 | r_prec = top_k_mat.sum(axis=0) / batch_size 230 | 231 | total_ranking = 0 232 | for i in range(3): 233 | log_dict[f'{split}/top {i + 1}'] = r_prec[i] 234 | total_ranking += r_prec[i] 235 | 236 | log_dict['monitor'] = total_ranking 237 | if self.model_kwargs.cls_weight > 0: 238 | log_dict['monitor'] += acc_1 239 | 240 | div = calculate_diversity(activations=motion_embeddings, diversity_times=batch_size - 1) 241 | log_dict[f'{split}/div'] = div 242 | 243 | for k, v in log_dict.items(): 244 | try: 245 | log_dict[k] = torch.tensor(v).to(self.device) 246 | except: pass 247 | 248 | return log_dict 249 | 250 | def extra_validation_step(self, batch, batch_idx=None) -> Dict: 251 | return self.get_metrics(batch, 'val') 252 | 253 | def test_step(self, batch, batch_idx=None): 254 | res = self.get_metrics(batch, 'test') 255 | self.log_dict(res, sync_dist=True) 256 | return res 257 | --------------------------------------------------------------------------------
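The contrastive objective in MotionCLIP.get_log_dict is the standard symmetric CLIP loss over a batch of paired motion and text embeddings. The sketch below reproduces just that computation in isolation; the batch size, embedding size, random embeddings, and initial latent scale are illustrative placeholders, not values taken from the project configs.

import torch
import torch.nn.functional as F

batch_size, output_size = 8, 512                           # placeholder sizes
motion_embeddings = torch.randn(batch_size, output_size)   # stands in for encode_motion(...)
text_embeddings = torch.randn(batch_size, output_size)     # stands in for encode_text(...)
latent_scale = torch.nn.Parameter(torch.tensor([1.0]))     # mirrors init_latent_scale = 1

# Scaled pairwise similarities; row i should be largest at column i.
logits_per_motion = latent_scale.exp() * motion_embeddings @ text_embeddings.t()
logits_per_text = logits_per_motion.t()

# Matching pairs sit on the diagonal, so the targets are simply 0..batch_size-1.
ground_truth = torch.arange(batch_size)
loss = (F.cross_entropy(logits_per_motion, ground_truth) +
        F.cross_entropy(logits_per_text, ground_truth)) / 2

In the model itself the same two cross-entropy terms are averaged and, when cls_weight > 0, combined with an action-classification cross-entropy before logging.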