├── common ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── quaternion.cpython-38.pyc └── quaternion.py ├── utils ├── __init__.py ├── __pycache__ │ ├── utils.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ ├── plot_script.cpython-38.pyc │ ├── preprocess.cpython-38.pyc │ ├── quaternion.cpython-38.pyc │ └── rotation_conversions.cpython-38.pyc ├── preprocess.py ├── paramUtil.py ├── plot_script.py ├── metrics.py ├── utils.py ├── quaternion.py └── rotation_conversions.py ├── tools ├── __init__.py ├── train.py └── eval.py ├── assets ├── model.png ├── teaser.png └── demo_teaser.mp4 ├── data ├── global_std.npy └── global_mean.npy ├── models ├── __pycache__ │ ├── nets.cpython-38.pyc │ ├── utils.cpython-38.pyc │ ├── blocks.cpython-38.pyc │ ├── layers.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── intergen.cpython-38.pyc │ ├── cfg_sampler.cpython-38.pyc │ └── gaussian_diffusion.cpython-38.pyc ├── __init__.py ├── cfg_sampler.py ├── blocks.py ├── utils.py ├── layers.py ├── intergen.py ├── nets.py └── losses.py ├── configs ├── __pycache__ │ └── __init__.cpython-38.pyc ├── eval_model.yaml ├── infer.yaml ├── train_single.yaml ├── train_inter.yaml ├── model_single.yaml ├── model_inter.yaml ├── datasets_inter.yaml ├── datasets_single.yaml └── __init__.py ├── datasets ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── evaluator.cpython-38.pyc │ ├── interhuman.cpython-38.pyc │ └── evaluator_models.cpython-38.pyc ├── __init__.py ├── evaluator_models.py ├── dataloader.py └── evaluator.py ├── test.sh ├── train_inter.sh ├── train_single.sh ├── prepare └── download_evaluation_model.sh ├── requirements.txt └── README.md /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/model.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /data/global_std.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/data/global_std.npy -------------------------------------------------------------------------------- /data/global_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/data/global_mean.npy -------------------------------------------------------------------------------- /assets/demo_teaser.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/demo_teaser.mp4 
-------------------------------------------------------------------------------- /models/__pycache__/nets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/nets.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/blocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/blocks.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/layers.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/common/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /common/__pycache__/quaternion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/common/__pycache__/quaternion.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/configs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/intergen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/intergen.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_script.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/plot_script.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/preprocess.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/preprocess.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/quaternion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/quaternion.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/evaluator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/evaluator.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/interhuman.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/interhuman.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/cfg_sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/cfg_sampler.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/evaluator_models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/evaluator_models.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/gaussian_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/gaussian_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/rotation_conversions.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/rotation_conversions.cpython-38.pyc -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tools/eval.py --model_config configs/model_inter.yaml --dataset_config configs/datasets_inter.yaml --evalmodel_config configs/eval_model.yaml 2 | 3 | -------------------------------------------------------------------------------- /train_inter.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tools/train.py --model_config configs/model_inter.yaml --dataset_config configs/datasets_inter.yaml --train_config configs/train_inter.yaml 2 | 3 | -------------------------------------------------------------------------------- /train_single.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python tools/train.py --model_config configs/model_single.yaml --dataset_config configs/datasets_single.yaml --train_config configs/train_single.yaml 2 | 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_diffusion import GaussianDiffusion 2 | from .nets import * 3 | from .intergen import * 4 | from .blocks import * 5 | from .cfg_sampler import * 6 | from .utils import * -------------------------------------------------------------------------------- /configs/eval_model.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterCLIP 2 | NUM_LAYERS: 8 3 | NUM_HEADS: 8 4 | DROPOUT: 0.1 5 | INPUT_DIM: 258 6 | LATENT_DIM: 1024 7 | FF_SIZE: 2048 8 | ACTIVATION: gelu 9 | 10 | MOTION_REP: global 11 | FINETUNE: False 12 | -------------------------------------------------------------------------------- /prepare/download_evaluation_model.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ./eval_model/ 2 | cd ./eval_model/ 3 | echo "The pretrained evaluation model will be stored in the 'eval_model' folder\n" 4 | # InterHuman 5 | echo "Downloading the evaluation model..." 
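# Note: the download below uses the `gdown` CLI, which is not listed in requirements.txt;
# install it first if needed (e.g. `pip install gdown`).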
6 | gdown https://drive.google.com/uc?id=1bJv5lTP7otJleaBYZ2byjru_k_wCsGvH 7 | echo "Evaluation model downloaded" -------------------------------------------------------------------------------- /configs/infer.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: IG-S-8 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 1 10 | EPOCH: 2000 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_STEPS: 20000 14 | SAVE_EPOCH: 100 15 | RESUME: #checkpoints/IG-S/8/model/epoch=99-step=17600.ckpt 16 | NUM_WORKERS: 2 17 | MODE: finetune 18 | LAST_EPOCH: 0 19 | LAST_ITER: 0 20 | -------------------------------------------------------------------------------- /configs/train_single.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: train_single 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 80 10 | EPOCH: 2500 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_STEPS: 20000 14 | SAVE_EPOCH: 50 15 | SAVE_TOPK: -1 16 | RESUME: 17 | NUM_WORKERS: 32 18 | MODE: finetune 19 | LAST_EPOCH: 0 20 | LAST_ITER: 0 21 | FROM_PRETRAIN: 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | numpy~=1.23.3 3 | tqdm~=4.65.0 4 | lightning==1.9.1 5 | scipy~=1.4.1 6 | matplotlib==3.2.0 7 | pillow~=7.2.0 8 | yacs~=0.1.8 9 | mmcv~=1.6.2 10 | opencv-python~=4.5.3.56 11 | tabulate~=0.8.9 12 | termcolor~=1.1.0 13 | smplx~=0.1.28 14 | torch==1.13.1+cu117 15 | torchvision==0.14.1+cu117 16 | torchaudio==0.13.1 17 | tensorboard==2.14.0 18 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /configs/train_inter.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: train_inter 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 30 10 | EPOCH: 1000 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_STEPS: 20000 14 | SAVE_EPOCH: 50 15 | SAVE_TOPK: -1 16 | RESUME: 17 | NUM_WORKERS: 32 18 | MODE: finetune 19 | LAST_EPOCH: 0 20 | LAST_ITER: 0 21 | FROM_PRETRAIN: ./checkpoints/train_single/epoch999.ckpt 22 | -------------------------------------------------------------------------------- /configs/model_single.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterGenSpatialControl 2 | ARCHI: single # single # single, multi 3 | NUM_LAYERS: 8 4 | NUM_HEADS: 8 5 | DROPOUT: 0.1 6 | INPUT_DIM: 262 7 | LATENT_DIM: 1024 8 | FF_SIZE: 2048 9 | ACTIVATION: gelu 10 | CHECKPOINT: 11 | 12 | DIFFUSION_STEPS: 1000 13 | BETA_SCHEDULER: cosine 14 | SAMPLER: uniform 15 | 16 | MOTION_REP: global 17 | FINETUNE: False 18 | 19 | TEXT_ENCODER: clip 20 | T_BAR: 700 21 | 22 | CONTROL: text 23 | STRATEGY: ddim50 24 | CFG_WEIGHT: 3.5 25 | 26 | GENERATE_NUM: 2 27 | 28 | -------------------------------------------------------------------------------- /configs/model_inter.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterGenSpatialControl 2 | ARCHI: multi # single # single, multi 3 | NUM_LAYERS: 8 4 | NUM_HEADS: 8 5 | DROPOUT: 0.1 6 | 
INPUT_DIM: 262 7 | LATENT_DIM: 1024 8 | FF_SIZE: 2048 9 | ACTIVATION: gelu 10 | CHECKPOINT: ./checkpoints/epoch999.ckpt 11 | 12 | DIFFUSION_STEPS: 1000 13 | BETA_SCHEDULER: cosine 14 | SAMPLER: uniform 15 | 16 | MOTION_REP: global 17 | FINETUNE: False 18 | 19 | TEXT_ENCODER: clip 20 | T_BAR: 700 21 | 22 | CONTROL: text 23 | STRATEGY: ddim50 24 | CFG_WEIGHT: 3.5 25 | 26 | GENERATE_NUM: 2 27 | 28 |
-------------------------------------------------------------------------------- /configs/datasets_inter.yaml: -------------------------------------------------------------------------------- 1 | interhuman: 2 | NAME: interhuman 3 | DATA_ROOT: ./data 4 | MOTION_REP: global 5 | MODE: train 6 | CACHE: True 7 | SINGLE_TEXT_DESCRIPTION: True 8 | 9 | interhuman_val: 10 | NAME: interhuman 11 | DATA_ROOT: ./data 12 | MOTION_REP: global 13 | MODE: val 14 | CACHE: True 15 | SINGLE_TEXT_DESCRIPTION: True 16 | 17 | interhuman_test: 18 | NAME: interhuman 19 | DATA_ROOT: ./data 20 | MOTION_REP: global 21 | MODE: test 22 | CACHE: True 23 | SINGLE_TEXT_DESCRIPTION: True 24 |
-------------------------------------------------------------------------------- /configs/datasets_single.yaml: -------------------------------------------------------------------------------- 1 | interhuman: 2 | NAME: singlehuman 3 | DATA_ROOT: ./data 4 | MOTION_REP: global 5 | MODE: train 6 | CACHE: True 7 | SINGLE_TEXT_DESCRIPTION: True 8 | 9 | interhuman_val: 10 | NAME: singlehuman 11 | DATA_ROOT: ./data 12 | MOTION_REP: global 13 | MODE: val 14 | CACHE: True 15 | SINGLE_TEXT_DESCRIPTION: True 16 | 17 | interhuman_test: 18 | NAME: singlehuman 19 | DATA_ROOT: ./data 20 | MOTION_REP: global 21 | MODE: test 22 | CACHE: True 23 | SINGLE_TEXT_DESCRIPTION: True 24 |
-------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.utils import * 3 | 4 | FPS = 30 5 | 6 | def load_motion(file_path, min_length, swap=False): 7 | 8 | try: 9 | motion = np.load(file_path).astype(np.float32) 10 | except: 11 | print("error: ", file_path) 12 | return None, None 13 | 14 | motion1 = motion[:, :22 * 3] # 22*3: xyz positions of the 22 joints 15 | motion2 = motion[:, 62 * 3:62 * 3 + 21 * 6] # 21*6: 6D rotation representation (21 joints) 16 | motion = np.concatenate([motion1, motion2], axis=1) 17 | 18 | if motion.shape[0] < min_length: 19 | return None, None 20 | if swap: 21 | motion_swap = swap_left_right(motion, 22) 22 | else: 23 | motion_swap = None 24 | return motion, motion_swap
-------------------------------------------------------------------------------- /models/cfg_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class ClassifierFreeSampleModel(nn.Module): 4 | 5 | def __init__(self, model, cfg_scale): 6 | super().__init__() 7 | self.model = model # model is the actual model to run 8 | self.s = cfg_scale 9 | 10 | def forward(self, x, timesteps, cond=None, mask=None, motion_guidance = None, spatial_condition=None, **kwargs): 11 | B, T, D = x.shape 12 | 13 | x_combined = torch.cat([x, x], dim=0) 14 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0) 15 | if cond is not None: 16 | cond = torch.cat([cond, torch.zeros_like(cond)], dim=0) 17 | if mask is not None: 18 | mask = torch.cat([mask, mask], dim=0) 19 | if motion_guidance is not None: 20 | motion_guidance = torch.cat([motion_guidance, motion_guidance], dim=0) 21 | if spatial_condition is
not None: 22 | spatial_condition = torch.cat([spatial_condition, spatial_condition], dim=0) 23 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask, motion_guidance=motion_guidance, spatial_condition=spatial_condition, **kwargs) 24 | 25 | out_cond = out[:B] 26 | out_uncond = out[B:] 27 | 28 | cfg_out = self.s * out_cond + (1-self.s) *out_uncond 29 | return cfg_out 30 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | def to_lower(x: Dict) -> Dict: 7 | """ 8 | Convert all dictionary keys to lowercase 9 | Args: 10 | x (dict): Input dictionary 11 | Returns: 12 | dict: Output dictionary with all keys converted to lowercase 13 | """ 14 | return {k.lower(): v for k, v in x.items()} 15 | 16 | _C = CN(new_allowed=True) 17 | 18 | def default_config() -> CN: 19 | """ 20 | Get a yacs CfgNode object with the default config values. 21 | """ 22 | # Return a clone so that the defaults will not be altered 23 | # This is for the "local variable" use pattern 24 | return _C.clone() 25 | 26 | def get_config(config_file: str, merge: bool = True) -> CN: 27 | """ 28 | Read a config file and optionally merge it with the default config file. 29 | Args: 30 | config_file (str): Path to config file. 31 | merge (bool): Whether to merge with the default config or not. 32 | Returns: 33 | CfgNode: Config as a yacs CfgNode object. 34 | """ 35 | if merge: 36 | cfg = default_config() 37 | else: 38 | cfg = CN(new_allowed=True) 39 | cfg.merge_from_file(config_file) 40 | cfg.freeze() 41 | return cfg 42 | 43 | def dataset_config() -> CN: 44 | """ 45 | Get dataset config file 46 | Returns: 47 | CfgNode: Dataset config as a yacs CfgNode object. 
48 | """ 49 | cfg = CN(new_allowed=True) 50 | config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets.yaml') 51 | cfg.merge_from_file(config_file) 52 | cfg.freeze() 53 | return cfg 54 | 55 | -------------------------------------------------------------------------------- /utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | import torch 3 | from .interhuman import ( 4 | InterHumanDataset, 5 | SingleHumanDataset, 6 | InterHumanPipelineInferDataset, 7 | ) 8 | from datasets.evaluator import ( 9 | EvaluatorModelWrapper, 10 | EvaluationDataset, 11 | get_dataset_motion_loader, 12 | get_motion_loader) 13 | 14 | 15 | __all__ = [ 16 | 'InterHumanDataset', 17 | 'InterHumanPipelineInferDataset', 18 | 'get_dataset_motion_loader', 'get_motion_loader' 19 | ] 20 | 21 | def build_loader(cfg, data_cfg): 22 | # setup data 23 | if data_cfg.NAME == "interhuman": 24 | train_dataset = InterHumanDataset(data_cfg) 25 | else: 26 | raise NotImplementedError 27 | 28 | loader = torch.utils.data.DataLoader( 29 | train_dataset, 30 | batch_size=cfg.BATCH_SIZE, 31 | num_workers=1, 32 | pin_memory=False, 33 | shuffle=True, 34 | drop_last=True, 35 | ) 36 | 37 | return loader 38 | 39 | class DataModule(pl.LightningDataModule): 40 | def __init__(self, cfg, batch_size, num_workers): 41 | """ 42 | Initialize LightningDataModule for ProHMR training 43 | Args: 44 | cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info. 
45 | dataset_cfg (CfgNode): Dataset configuration file 46 | """ 47 | super().__init__() 48 | self.cfg = cfg 49 | self.batch_size = batch_size 50 | self.num_workers = num_workers 51 | 52 | def setup(self, stage = None): 53 | """ 54 | Create train and validation datasets 55 | """ 56 | 57 | if self.cfg.NAME == "interhuman": 58 | self.train_dataset = InterHumanDataset(self.cfg) 59 | elif self.cfg.NAME == 'singlehuman': 60 | self.train_dataset = SingleHumanDataset(self.cfg) 61 | else: 62 | raise NotImplementedError 63 | 64 | def train_dataloader(self): 65 | """ 66 | Return train dataloader 67 | """ 68 | return torch.utils.data.DataLoader( 69 | self.train_dataset, 70 | batch_size=self.batch_size, 71 | num_workers=self.num_workers, 72 | pin_memory=False, 73 | shuffle=True, 74 | drop_last=True, 75 | ) 76 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | 4 | class TransformerBlock(nn.Module): 5 | def __init__(self, 6 | latent_dim=512, 7 | num_heads=8, 8 | ff_size=1024, 9 | dropout=0., 10 | cond_abl=False, 11 | **kargs): 12 | super().__init__() 13 | self.latent_dim = latent_dim 14 | self.num_heads = num_heads 15 | self.dropout = dropout 16 | self.cond_abl = cond_abl 17 | 18 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout) 19 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim) 20 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim) 21 | 22 | def forward(self, x, y, emb=None, key_padding_mask=None): 23 | h1 = self.sa_block(x, emb, key_padding_mask) 24 | h1 = h1 + x 25 | h2 = self.ca_block(h1, y, emb, key_padding_mask) 26 | h2 = h2 + h1 27 | out = self.ffn(h2, emb) 28 | out = out + h2 29 | return out 30 | 31 | 32 | class TransformerMotionGuidanceBlock(nn.Module): 33 | def __init__(self, 34 | latent_dim=512, 35 | num_heads=8, 36 | ff_size=1024, 37 | dropout=0., 38 | cond_abl=False, 39 | **kargs): 40 | super().__init__() 41 | self.latent_dim = latent_dim 42 | self.num_heads = num_heads 43 | self.dropout = dropout 44 | self.cond_abl = cond_abl 45 | 46 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout) 47 | self.condition_sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout, latent_dim) 48 | 49 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim) 50 | 51 | def forward(self, x, T=300, emb=None, key_padding_mask=None): 52 | 53 | x_a = x[:,:T,...] 
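        # x packs [this person's first T motion tokens | motion-guidance tokens]:
        # self-attention is first run over the first T tokens only, then re-run
        # jointly after the guidance tokens are concatenated back (see below).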
54 | key_padding_mask_a = key_padding_mask[:,:T] 55 | 56 | # self_att first 57 | h1 = self.sa_block(x_a, emb, key_padding_mask_a) 58 | h1 = h1 + x_a 59 | 60 | h1 = torch.cat([h1, x[:,T:,...]], dim=1) 61 | 62 | # add motion guidance to att again 63 | h2 = self.condition_sa_block(h1, emb, key_padding_mask) 64 | h2 = h2 + h1 65 | 66 | out = self.ffn(h2, emb) 67 | out = out + h2 68 | 69 | return out 70 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, warmup, max_iters, verbose=False): 8 | self.warmup = warmup 9 | self.max_num_iters = max_iters 10 | super().__init__(optimizer, verbose=verbose) 11 | 12 | def get_lr(self): 13 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 14 | return [base_lr * lr_factor for base_lr in self.base_lrs] 15 | 16 | def get_lr_factor(self, epoch): 17 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 18 | if epoch <= self.warmup: 19 | lr_factor *= (epoch+1) * 1.0 / self.warmup 20 | return lr_factor 21 | 22 | 23 | 24 | class PositionalEncoding(nn.Module): 25 | def __init__(self, d_model, dropout=0.0, max_len=5000): 26 | super(PositionalEncoding, self).__init__() 27 | self.dropout = nn.Dropout(p=dropout) 28 | 29 | pe = torch.zeros(max_len, d_model) 30 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 31 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 32 | pe[:, 0::2] = torch.sin(position * div_term) 33 | pe[:, 1::2] = torch.cos(position * div_term) 34 | # pe = pe.unsqueeze(0)#.transpose(0, 1) 35 | 36 | self.register_buffer('pe', pe) 37 | 38 | def forward(self, x): 39 | # not used in the final model 40 | x = x + self.pe[:x.shape[1], :].unsqueeze(0) 41 | return self.dropout(x) 42 | 43 | 44 | class TimestepEmbedder(nn.Module): 45 | def __init__(self, latent_dim, sequence_pos_encoder): 46 | super().__init__() 47 | self.latent_dim = latent_dim 48 | self.sequence_pos_encoder = sequence_pos_encoder 49 | 50 | time_embed_dim = self.latent_dim 51 | self.time_embed = nn.Sequential( 52 | nn.Linear(self.latent_dim, time_embed_dim), 53 | nn.SiLU(), 54 | nn.Linear(time_embed_dim, time_embed_dim), 55 | ) 56 | 57 | def forward(self, timesteps): 58 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]) 59 | 60 | 61 | class IdentityEmbedder(nn.Module): 62 | def __init__(self, latent_dim, sequence_pos_encoder): 63 | super().__init__() 64 | self.latent_dim = latent_dim 65 | self.sequence_pos_encoder = sequence_pos_encoder 66 | 67 | time_embed_dim = self.latent_dim 68 | self.time_embed = nn.Sequential( 69 | nn.Linear(self.latent_dim, time_embed_dim), 70 | nn.SiLU(), 71 | nn.Linear(time_embed_dim, time_embed_dim), 72 | ) 73 | 74 | def forward(self, timesteps): 75 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).unsqueeze(1) 76 | 77 | 78 | def set_requires_grad(nets, requires_grad=False): 79 | """Set requies_grad for all the networks. 80 | 81 | Args: 82 | nets (nn.Module | list[nn.Module]): A list of networks or a single 83 | network. 
84 | requires_grad (bool): Whether the networks require gradients or not 85 | """ 86 | if not isinstance(nets, list): 87 | nets = [nets] 88 | for net in nets: 89 | if net is not None: 90 | for param in net.parameters(): 91 | param.requires_grad = requires_grad 92 | 93 | 94 | def zero_module(module): 95 | """ 96 | Zero out the parameters of a module and return it. 97 | """ 98 | for p in module.parameters(): 99 | p.detach().zero_() 100 | return module 101 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | 3 | class AdaLN(nn.Module): 4 | 5 | def __init__(self, latent_dim, embed_dim=None): 6 | super().__init__() 7 | if embed_dim is None: 8 | embed_dim = latent_dim 9 | self.emb_layers = nn.Sequential( 10 | # nn.Linear(embed_dim, latent_dim, bias=True), 11 | nn.SiLU(), 12 | zero_module(nn.Linear(embed_dim, 2 * latent_dim, bias=True)), 13 | ) 14 | self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6) 15 | 16 | def forward(self, h, emb): 17 | """ 18 | h: B, T, D 19 | emb: B, D 20 | """ 21 | # B, 1, 2D 22 | emb_out = self.emb_layers(emb) 23 | # scale: B, 1, D / shift: B, 1, D 24 | scale, shift = torch.chunk(emb_out, 2, dim=-1) 25 | h = self.norm(h) * (1 + scale[:, None]) + shift[:, None] 26 | return h 27 | 28 | 29 | class VanillaSelfAttention(nn.Module): 30 | 31 | def __init__(self, latent_dim, num_head, dropout, embed_dim=None): 32 | super().__init__() 33 | self.num_head = num_head 34 | self.norm = AdaLN(latent_dim, embed_dim) 35 | self.attention = nn.MultiheadAttention(latent_dim, num_head, dropout=dropout, batch_first=True, 36 | add_zero_attn=True) 37 | 38 | def forward(self, x, emb, key_padding_mask=None): 39 | """ 40 | x: B, T, D 41 | """ 42 | x_norm = self.norm(x, emb) 43 | y = self.attention(x_norm, x_norm, x_norm, 44 | attn_mask=None, 45 | key_padding_mask=key_padding_mask, 46 | need_weights=False)[0] 47 | return y 48 | 49 | 50 | class VanillaCrossAttention(nn.Module): 51 | 52 | def __init__(self, latent_dim, xf_latent_dim, num_head, dropout, embed_dim=None): 53 | super().__init__() 54 | self.num_head = num_head 55 | self.norm = AdaLN(latent_dim, embed_dim) 56 | self.xf_norm = AdaLN(xf_latent_dim, embed_dim) 57 | self.attention = nn.MultiheadAttention(latent_dim, num_head, kdim=xf_latent_dim, vdim=xf_latent_dim, 58 | dropout=dropout, batch_first=True, add_zero_attn=True) 59 | 60 | def forward(self, x, xf, emb, key_padding_mask=None): 61 | """ 62 | x: B, T, D 63 | xf: B, N, L 64 | """ 65 | x_norm = self.norm(x, emb) 66 | xf_norm = self.xf_norm(xf, emb) 67 | y = self.attention(x_norm, xf_norm, xf_norm, 68 | attn_mask=None, 69 | key_padding_mask=key_padding_mask, 70 | need_weights=False)[0] 71 | return y 72 | 73 | 74 | class FFN(nn.Module): 75 | def __init__(self, latent_dim, ffn_dim, dropout, embed_dim=None): 76 | super().__init__() 77 | self.norm = AdaLN(latent_dim, embed_dim) 78 | self.linear1 = nn.Linear(latent_dim, ffn_dim, bias=True) 79 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim, bias=True)) 80 | self.activation = nn.GELU() 81 | self.dropout = nn.Dropout(dropout) 82 | 83 | def forward(self, x, emb=None): 84 | if emb is not None: 85 | x_norm = self.norm(x, emb) 86 | else: 87 | x_norm = x 88 | y = self.linear2(self.dropout(self.activation(self.linear1(x_norm)))) 89 | return y 90 | 91 | 92 | class FinalLayer(nn.Module): 93 | def __init__(self, latent_dim, out_dim): 94 | super().__init__() 95 | 
self.linear = zero_module(nn.Linear(latent_dim, out_dim, bias=True)) 96 | 97 | def forward(self, x): 98 | x = self.linear(x) 99 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
Official repo for FreeMotion

# FreeMotion: A Unified Framework for Number-free Text-to-Motion Synthesis

Project Page • Arxiv Paper • Dataset Link • Citation
## Intro FreeMotion

Text-to-motion synthesis is a crucial task in computer vision. Existing methods are limited in their universality: they are tailored to single-person or two-person scenarios and cannot be applied to generate motions for more individuals. To achieve number-free motion synthesis, this paper reconsiders motion generation and proposes to unify single-person and multi-person motion through the conditional motion distribution. Furthermore, a generation module and an interaction module are designed for our FreeMotion framework to decouple the process of conditional motion generation and thereby support number-free motion synthesis. In addition, the current single-person spatial control method can be seamlessly integrated into our framework, enabling precise control of multi-person motion. Extensive experiments demonstrate the superior performance of our method and its capability to infer single-person and multi-person motions simultaneously.

![pipeline](assets/model.png)

## ☑️ Todo List

- [x] Release the FreeMotion training code.
- [x] Release the FreeMotion evaluation code.
- [x] Release the separate_annots dataset.
- [ ] Release the inference code.
- [ ] Release the FreeMotion checkpoints.

## Quick Start

### 1. Conda environment

```
conda create python=3.8 --name freemotion
conda activate freemotion
```

```
pip install -r requirements.txt
```

### 2. Download the Text-to-Motion Evaluation Model

```
bash prepare/download_evaluation_model.sh
```

## Train and Evaluate Your Own Models

### 1. Prepare the datasets

> Download the data from the [InterGen webpage](https://tr3e.github.io/intergen-page/) and put it into ./data/.

> Download the data from [our webpage](https://vankouf.github.io/FreeMotion/) and put it into ./data/.

#### Data Structure
```sh
./annots              // Natural language annotations; each file consists of three sentences.
./motions             // Raw motion data standardized to SMPL, similar to AMASS.
./motions_processed   // Processed motion data: joint positions and 6D rotations for the 22-joint SMPL kinematic structure.
./split               // Train-val-test split.
./separate_annots     // Annotations for each individual person's motion.
```

### 2. Train

#### Stage 1: train the generation module
```
sh train_single.sh
```

#### Stage 2: train the interaction module
```
sh train_inter.sh
```

### 3. Evaluation

```
sh test.sh
```

## 📖 Citation

If you find our code or paper helpful, please consider citing:

```bibtex
@article{fan2024freemotion,
  title={FreeMotion: A Unified Framework for Number-free Text-to-Motion Synthesis},
  author={Ke Fan and Junshu Tang and Weijian Cao and Ran Yi and Moran Li and Jingyu Gong and Jiangning Zhang and Yabiao Wang and Chengjie Wang and Lizhuang Ma},
  year={2024},
  eprint={2405.15763},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

## Acknowledgments

Thanks to [InterGen](https://github.com/tr3e/InterGen) and [MotionGPT](https://github.com/OpenMotionLab/MotionGPT); our code partially borrows from them.

## Licenses
Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. -------------------------------------------------------------------------------- /datasets/evaluator_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import clip 6 | 7 | from models import * 8 | 9 | 10 | loss_ce = nn.CrossEntropyLoss() 11 | class InterCLIP(nn.Module): 12 | def __init__(self, cfg): 13 | super().__init__() 14 | self.cfg = cfg 15 | self.latent_dim = cfg.LATENT_DIM 16 | self.motion_encoder = MotionEncoder(cfg) 17 | 18 | self.latent_dim = self.latent_dim 19 | 20 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False) 21 | self.token_embedding = clip_model.token_embedding 22 | self.positional_embedding = clip_model.positional_embedding 23 | self.dtype = clip_model.dtype 24 | self.latent_scale = nn.Parameter(torch.Tensor([1])) 25 | 26 | set_requires_grad(self.token_embedding, False) 27 | 28 | textTransEncoderLayer = nn.TransformerEncoderLayer( 29 | d_model=768, 30 | nhead=8, 31 | dim_feedforward=cfg.FF_SIZE, 32 | dropout=0.1, 33 | activation="gelu", 34 | batch_first=True) 35 | self.textTransEncoder = nn.TransformerEncoder( 36 | textTransEncoderLayer, 37 | num_layers=8) 38 | self.text_ln = nn.LayerNorm(768) 39 | self.out = nn.Linear(768, 512) 40 | 41 | self.clip_training = "text_" 42 | self.l1_criterion = torch.nn.L1Loss(reduction='mean') 43 | 44 | def compute_loss(self, batch): 45 | losses = {} 46 | losses["total"] = 0 47 | 48 | # compute clip losses 49 | batch = self.encode_text(batch) 50 | batch = self.encode_motion(batch) 51 | 52 | mixed_clip_loss, clip_losses = self.compute_clip_losses(batch) 53 | losses.update(clip_losses) 54 | losses["total"] += mixed_clip_loss 55 | 56 | return losses["total"], losses 57 | 58 | def forward(self, batch): 59 | return self.compute_loss(batch) 60 | 61 | def compute_clip_losses(self, batch): 62 | mixed_clip_loss = 0. 63 | clip_losses = {} 64 | 65 | if 1: 66 | for d in self.clip_training.split('_')[:1]: 67 | if d == 'image': 68 | features = self.clip_model.encode_image(batch['images']).float() # preprocess is done in dataloader 69 | elif d == 'text': 70 | features = batch['text_emb'] 71 | motion_features = batch['motion_emb'] 72 | # normalized features 73 | features_norm = features / features.norm(dim=-1, keepdim=True) 74 | motion_features_norm = motion_features / motion_features.norm(dim=-1, keepdim=True) 75 | 76 | logit_scale = self.latent_scale ** 2 77 | logits_per_motion = logit_scale * motion_features_norm @ features_norm.t() 78 | logits_per_d = logits_per_motion.t() 79 | 80 | batch_size = motion_features.shape[0] 81 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=motion_features.device) 82 | 83 | ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth) 84 | ce_from_d_loss = loss_ce(logits_per_d, ground_truth) 85 | clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2. 
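                # The two cross-entropy terms above implement the symmetric CLIP objective:
                # matching (motion, text) pairs lie on the diagonal of the scaled similarity
                # matrix, so the target for row i is index i, averaged over both directions.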
86 | 87 | clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item() 88 | clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item() 89 | clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item() 90 | mixed_clip_loss += clip_mixed_loss 91 | 92 | return mixed_clip_loss, clip_losses 93 | 94 | def generate_src_mask(self, T, length): 95 | B = length.shape[0] 96 | src_mask = torch.ones(B, T) 97 | for i in range(B): 98 | for j in range(length[i], T): 99 | src_mask[i, j] = 0 100 | return src_mask 101 | 102 | def encode_motion(self, batch): 103 | batch["mask"] = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(batch["motions"].device) 104 | batch.update(self.motion_encoder(batch)) 105 | batch["motion_emb"] = batch["motion_emb"] / batch["motion_emb"].norm(dim=-1, keepdim=True) * self.latent_scale 106 | 107 | return batch 108 | 109 | def encode_text(self, batch): 110 | device = next(self.parameters()).device 111 | raw_text = batch["text"] 112 | 113 | with torch.no_grad(): 114 | text = clip.tokenize(raw_text, truncate=True).to(device) 115 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 116 | pe_tokens = x + self.positional_embedding.type(self.dtype) 117 | 118 | out = self.textTransEncoder(pe_tokens) 119 | out = self.text_ln(out) 120 | 121 | out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] 122 | out = self.out(out) 123 | 124 | batch['text_emb'] = out 125 | batch["text_emb"] = batch["text_emb"] / batch["text_emb"].norm(dim=-1, keepdim=True) * self.latent_scale 126 | 127 | return batch 128 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | from functools import partial 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | from mmcv.runner import get_dist_info 8 | from mmcv.utils import Registry, build_from_cfg 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.dataset import Dataset 11 | 12 | import torch 13 | from torch.utils.data import DistributedSampler as _DistributedSampler 14 | 15 | 16 | class DistributedSampler(_DistributedSampler): 17 | 18 | def __init__(self, 19 | dataset, 20 | num_replicas=None, 21 | rank=None, 22 | shuffle=True, 23 | round_up=True): 24 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 25 | self.shuffle = shuffle 26 | self.round_up = round_up 27 | if self.round_up: 28 | self.total_size = self.num_samples * self.num_replicas 29 | else: 30 | self.total_size = len(self.dataset) 31 | 32 | def __iter__(self): 33 | # deterministically shuffle based on epoch 34 | if self.shuffle: 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 38 | else: 39 | indices = torch.arange(len(self.dataset)).tolist() 40 | 41 | # add extra samples to make it evenly divisible 42 | if self.round_up: 43 | indices = ( 44 | indices * 45 | int(self.total_size / len(indices) + 1))[:self.total_size] 46 | assert len(indices) == self.total_size 47 | 48 | # subsample 49 | indices = indices[self.rank:self.total_size:self.num_replicas] 50 | if self.round_up: 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices) 54 | 55 | 56 | def build_dataloader(dataset: Dataset, 57 | samples_per_gpu: int, 58 | workers_per_gpu: int, 59 | num_gpus: Optional[int] = 1, 60 | dist: Optional[bool] = True, 61 | shuffle: Optional[bool] = True, 62 | 
round_up: Optional[bool] = True, 63 | seed: Optional[Union[int, None]] = None, 64 | persistent_workers: Optional[bool] = True, 65 | **kwargs): 66 | """Build PyTorch DataLoader. 67 | 68 | In distributed training, each GPU/process has a dataloader. 69 | In non-distributed training, there is only one dataloader for all GPUs. 70 | 71 | Args: 72 | dataset (:obj:`Dataset`): A PyTorch dataset. 73 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 74 | batch size of each GPU. 75 | workers_per_gpu (int): How many subprocesses to use for data loading 76 | for each GPU. 77 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed 78 | training. 79 | dist (bool, optional): Distributed training/test or not. Default: True. 80 | shuffle (bool, optional): Whether to shuffle the data at every epoch. 81 | Default: True. 82 | round_up (bool, optional): Whether to round up the length of dataset by 83 | adding extra samples to make it evenly divisible. Default: True. 84 | persistent_workers (bool): If True, the data loader will not shutdown 85 | the worker processes after a dataset has been consumed once. 86 | This allows to maintain the workers Dataset instances alive. 87 | The argument also has effect in PyTorch>=1.7.0. 88 | Default: True 89 | kwargs: any keyword argument to be used to initialize DataLoader 90 | 91 | Returns: 92 | DataLoader: A PyTorch dataloader. 93 | """ 94 | rank, world_size = get_dist_info() 95 | if dist: 96 | sampler = DistributedSampler( 97 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up) 98 | shuffle = False 99 | batch_size = samples_per_gpu 100 | num_workers = workers_per_gpu 101 | else: 102 | sampler = None 103 | batch_size = num_gpus * samples_per_gpu 104 | num_workers = num_gpus * workers_per_gpu 105 | 106 | init_fn = partial( 107 | worker_init_fn, num_workers=num_workers, rank=rank, 108 | seed=seed) if seed is not None else None 109 | 110 | data_loader = DataLoader( 111 | dataset, 112 | batch_size=batch_size, 113 | sampler=sampler, 114 | num_workers=num_workers, 115 | pin_memory=False, 116 | shuffle=shuffle, 117 | worker_init_fn=init_fn, 118 | persistent_workers=persistent_workers, 119 | **kwargs) 120 | 121 | 122 | 123 | return data_loader 124 | 125 | 126 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): 127 | """Init random seed for each worker.""" 128 | # The seed of each worker equals to 129 | # num_worker * rank + worker_id + user_seed 130 | worker_seed = num_workers * rank + worker_id + seed 131 | np.random.seed(worker_seed) 132 | random.seed(worker_seed) 133 | -------------------------------------------------------------------------------- /utils/plot_script.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter 7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 8 | import mpl_toolkits.mplot3d.axes3d as p3 9 | # import cv2 10 | 11 | 12 | def list_cut_average(ll, intervals): 13 | if intervals == 1: 14 | return ll 15 | 16 | bins = math.ceil(len(ll) * 1.0 / intervals) 17 | ll_new = [] 18 | for i in range(bins): 19 | l_low = intervals * i 20 | l_high = l_low + intervals 21 | l_high = l_high if l_high < len(ll) else len(ll) 22 | ll_new.append(np.mean(ll[l_low:l_high])) 23 | return ll_new 24 | 25 | 26 | def plot_3d_motion(save_path, kinematic_tree, 
mp_joints, title, figsize=(10, 10), fps=120, radius=4, hints=None): 27 | 28 | matplotlib.use('Agg') 29 | 30 | title_sp = title.split(' ') 31 | if len(title_sp) > 20: 32 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])]) 33 | elif len(title_sp) > 10: 34 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) 35 | 36 | def init(): 37 | ax.set_xlim3d([-radius / 4, radius / 4]) 38 | ax.set_ylim3d([0, radius / 2]) 39 | ax.set_zlim3d([0, radius / 2]) 40 | 41 | fig.suptitle(title, fontsize=20) 42 | ax.grid(b=False) 43 | 44 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 45 | ## Plot a plane XZ 46 | verts = [ 47 | [minx, miny, minz], 48 | [minx, miny, maxz], 49 | [maxx, miny, maxz], 50 | [maxx, miny, minz] 51 | ] 52 | xz_plane = Poly3DCollection([verts]) 53 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 54 | ax.add_collection3d(xz_plane) 55 | 56 | # if hint is not None: 57 | # mask = hint.sum(-1) != 0 58 | # hint = hint[mask] 59 | 60 | hints_new = [] 61 | for hint in hints: 62 | if min(hint.shape) != 0: 63 | mask = hint.sum(-1) != 0 64 | hint = hint[mask] 65 | hints_new.append(hint) 66 | hints = hints_new 67 | 68 | fig = plt.figure(figsize=figsize) 69 | ax = p3.Axes3D(fig) 70 | init() 71 | 72 | mp_data = [] 73 | frame_number = min([data.shape[0] for data in mp_joints]) 74 | 75 | colors = ['red', 'green', 'black', 'blue', 'red', 76 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', 77 | 'darkred', 'darkred', 'darkred', 'black', 'blue'] 78 | 79 | mp_offset = list(range(-len(mp_joints)//2, len(mp_joints)//2, 1)) 80 | mp_colors = [[colors[i]] * 15 for i in range(len(mp_offset))] 81 | 82 | for i,joints in enumerate(mp_joints): 83 | 84 | # (seq_len, joints_num, 3) 85 | data = joints.copy().reshape(len(joints), -1, 3) 86 | 87 | MINS = data.min(axis=0).min(axis=0) 88 | MAXS = data.max(axis=0).max(axis=0) 89 | 90 | height_offset = MINS[1] 91 | data[:, :, 1] -= height_offset 92 | trajec = data[:, 0, [0, 2]] 93 | 94 | # if hint is not None: 95 | # hint[..., 1] -= height_offset 96 | 97 | hints_new = [] 98 | for hint in hints: 99 | if min(hint.shape) != 0: 100 | hint[..., 1] -= height_offset 101 | hints_new.append(hint) 102 | hints = hints_new 103 | 104 | # data[:, :, 0] -= data[0:1, 0:1, 0] 105 | # data[:, :, 0] += mp_offset[i] 106 | # 107 | # data[:, :, 2] -= data[0:1, 0:1, 2] 108 | mp_data.append({"joints":data, 109 | "MINS":MINS, 110 | "MAXS":MAXS, 111 | "trajec":trajec, }) 112 | 113 | # print(trajec.shape) 114 | 115 | def update(index): 116 | # print(index) 117 | ax.lines = [] 118 | ax.collections = [] 119 | ax.view_init(elev=120, azim=-90) 120 | ax.dist = 15#7.5 121 | # ax = 122 | plot_xzPlane(-3, 3, 0, -3, 3) 123 | 124 | for pid,data in enumerate(mp_data): 125 | for i, (chain, color) in enumerate(zip(kinematic_tree, mp_colors[pid])): 126 | # print(color) 127 | if i < 5: 128 | linewidth = 2.0 129 | else: 130 | linewidth = 1.0 131 | 132 | ax.plot3D(data["joints"][index, chain, 0], data["joints"][index, chain, 1], data["joints"][index, chain, 2], linewidth=linewidth, 133 | color=color) 134 | 135 | # if hint is not None: 136 | # ax.plot3D(hint[:index+1, 0], hint[:index+1, 1], hint[:index+1, 2], color="blue", linewidth=3.0) 137 | 138 | for idx, hint in enumerate(hints): 139 | if min(hint.shape) != 0: 140 | ax.plot3D(hint[:index+1, 0], hint[:index+1, 1], hint[:index+1, 2], color=colors[-1-idx], linewidth=3.0) 141 | 142 | plt.axis('off') 143 | ax.set_xticklabels([]) 144 | ax.set_yticklabels([]) 145 | ax.set_zticklabels([]) 146 | 147 | 
ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) 148 | 149 | # writer = FFMpegFileWriter(fps=fps) 150 | ani.save(save_path, fps=fps) 151 | plt.close() 152 | -------------------------------------------------------------------------------- /models/intergen.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Union, List 2 | 3 | import torch 4 | import clip 5 | 6 | from torch import nn 7 | from models import * 8 | from collections import OrderedDict 9 | 10 | class InterGenSpatialControlNet(nn.Module): 11 | def __init__(self, cfg): 12 | super().__init__() 13 | self.cfg = cfg 14 | self.latent_dim = cfg.LATENT_DIM 15 | self.decoder = InterDiffusionSpatialControlNet(cfg, sampling_strategy=cfg.STRATEGY) 16 | self.archi = cfg.ARCHI 17 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False) 18 | self.token_embedding = clip_model.token_embedding 19 | self.clip_transformer = clip_model.transformer 20 | self.positional_embedding = clip_model.positional_embedding 21 | self.ln_final = clip_model.ln_final 22 | self.dtype = clip_model.dtype 23 | 24 | set_requires_grad(self.clip_transformer, False) 25 | set_requires_grad(self.token_embedding, False) 26 | set_requires_grad(self.ln_final, False) 27 | 28 | clipTransEncoderLayer = nn.TransformerEncoderLayer( 29 | d_model=768, 30 | nhead=8, 31 | dim_feedforward=2048, 32 | dropout=0.1, 33 | activation="gelu", 34 | batch_first=True) 35 | self.clipTransEncoder = nn.TransformerEncoder( 36 | clipTransEncoderLayer, 37 | num_layers=2) 38 | self.clip_ln = nn.LayerNorm(768) 39 | 40 | self.positional_embedding.requires_grad = False 41 | 42 | set_requires_grad(self.clipTransEncoder, False) 43 | set_requires_grad(self.clip_ln, False) 44 | 45 | 46 | def compute_loss(self, batch): 47 | 48 | batch = self.text_process(batch) 49 | losses = self.decoder.compute_loss(batch) 50 | return losses["total"], losses 51 | 52 | def decode_motion(self, batch): 53 | batch.update(self.decoder(batch)) # batch['output'].shape = [1, 210, 524] 54 | return batch 55 | 56 | def forward(self, batch): 57 | return self.compute_loss(batch) 58 | 59 | def forward_test(self, batch): # batch: 'motion_lens', 'prompt', 'text'=['prompt'] 60 | 61 | batch = self.text_process(batch) 62 | batch.update(self.decode_motion(batch)) 63 | return batch 64 | 65 | def text_process(self, batch): 66 | device = next(self.clip_transformer.parameters()).device 67 | 68 | if "text" in batch and batch["text"] is not None and not isinstance(batch["text"], torch.Tensor): 69 | 70 | raw_text = batch["text"] 71 | 72 | with torch.no_grad(): 73 | 74 | text = clip.tokenize(raw_text, truncate=True).to(device) # [batch_szie, n_ctx]=[1,77] 75 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768] 76 | pe_tokens = x + self.positional_embedding.type(self.dtype) 77 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND [n_ctx, batch_size, d_model]=[77,1,768] 78 | x = self.clip_transformer(x) 79 | x = x.permute(1, 0, 2) 80 | clip_out = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768] 81 | 82 | out = self.clipTransEncoder(clip_out) # [batch_size, n_ctx, d_model]=[1,77,768] 83 | out = self.clip_ln(out) 84 | 85 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] 86 | batch["cond"] = cond 87 | 88 | if "text_multi_person" in batch and batch["text_multi_person"] is not None and not isinstance(batch["text_multi_person"], torch.Tensor): 89 | raw_text = 
batch["text_multi_person"] 90 | 91 | with torch.no_grad(): 92 | 93 | text = clip.tokenize(raw_text, truncate=True).to(device) # [batch_szie, n_ctx]=[1,77] 94 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768] 95 | pe_tokens = x + self.positional_embedding.type(self.dtype) 96 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND [n_ctx, batch_size, d_model]=[77,1,768] 97 | x = self.clip_transformer(x) 98 | x = x.permute(1, 0, 2) 99 | clip_out = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768] 100 | 101 | out = self.clipTransEncoder(clip_out) # [batch_size, n_ctx, d_model]=[1,77,768] 102 | out = self.clip_ln(out) 103 | 104 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] 105 | 106 | if "cond" in batch: 107 | batch["cond"] = batch["cond"] + cond 108 | else: 109 | batch["cond"] = cond 110 | 111 | return batch 112 | 113 | def load_state_dict(self, state_dict: Union[List[Mapping[str, Any]],Mapping[str, Any]], strict: bool = True): 114 | 115 | if self.archi == 'single': 116 | return super().load_state_dict(state_dict, strict=strict) 117 | 118 | new_state_dict = OrderedDict() 119 | for name, value in state_dict.items(): 120 | if 'decoder.net.net' in name: 121 | name_new = name.replace('decoder.net.net', 'decoder.net.control_branch') 122 | new_state_dict[name_new] = value 123 | new_state_dict[name] = value 124 | else: 125 | new_state_dict[name] = value 126 | return super().load_state_dict(new_state_dict, strict=strict) 127 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(sys.path[0] + r"/../") 3 | import logging 4 | import torch 5 | import lightning.pytorch as pl 6 | import torch.optim as optim 7 | from collections import OrderedDict 8 | from datasets import DataModule 9 | from configs import get_config 10 | from os.path import join as pjoin 11 | from torch.utils.tensorboard import SummaryWriter 12 | from models import * 13 | import os 14 | import argparse 15 | 16 | os.environ['PL_TORCH_DISTRIBUTED_BACKEND'] = 'nccl' 17 | from lightning.pytorch.strategies import DDPStrategy 18 | torch.set_float32_matmul_precision('medium') 19 | 20 | class LitTrainModel(pl.LightningModule): 21 | def __init__(self, model, cfg): 22 | super().__init__() 23 | # cfg init 24 | self.cfg = cfg 25 | self.mode = cfg.TRAIN.MODE 26 | 27 | self.automatic_optimization = False 28 | 29 | self.save_root = pjoin(self.cfg.GENERAL.CHECKPOINT, self.cfg.GENERAL.EXP_NAME) 30 | self.model_dir = pjoin(self.save_root, 'model') 31 | self.meta_dir = pjoin(self.save_root, 'meta') 32 | self.log_dir = pjoin(self.save_root, 'log') 33 | 34 | os.makedirs(self.model_dir, exist_ok=True) 35 | os.makedirs(self.meta_dir, exist_ok=True) 36 | os.makedirs(self.log_dir, exist_ok=True) 37 | 38 | self.model = model 39 | 40 | self.writer = SummaryWriter(self.log_dir) 41 | 42 | logging.basicConfig(filename=os.path.join(self.log_dir,'train.log'), filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) 43 | 44 | def _configure_optim(self): 45 | 46 | optimizer = optim.AdamW([p for p in self.model.parameters() if p.requires_grad], lr=float(self.cfg.TRAIN.LR), weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) 47 | 48 | scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=10, max_iters=self.cfg.TRAIN.EPOCH, verbose=True) 49 | return [optimizer], [scheduler] 50 | 51 | 
def configure_optimizers(self): 52 | return self._configure_optim() 53 | 54 | def forward(self, batch_data): 55 | 56 | name, text, text_multi_person, motion1, motion2, motion_lens, spatial_condition = batch_data 57 | 58 | motion1 = motion1.detach().float() 59 | 60 | batch = OrderedDict({}) 61 | if min(motion2.shape) == 0: # for single human training 62 | motions = motion1 63 | elif min(spatial_condition.shape) == 0: # for pure double huaman training 64 | motion2 = motion2.detach().float() 65 | motions = torch.cat([motion1, motion2], dim=-1) 66 | 67 | 68 | B, T = motion1.shape[:2] 69 | 70 | batch["motions"] = motions.reshape(B, T, -1).type(torch.float32) 71 | batch["motion_lens"] = motion_lens.long() 72 | batch["person_num"] = motions.shape[-1] // motion1.shape[-1] 73 | 74 | if isinstance(text, torch.Tensor): 75 | batch["text"] = None 76 | else: 77 | batch["text"] = text 78 | 79 | if isinstance(text_multi_person, torch.Tensor): 80 | batch["text_multi_person"] = None 81 | else: 82 | batch["text_multi_person"] = text_multi_person 83 | 84 | loss, loss_logs = self.model(batch) 85 | return loss, loss_logs 86 | 87 | def on_train_start(self): 88 | self.rank = 0 89 | self.world_size = 1 90 | self.start_time = time.time() 91 | self.it = self.cfg.TRAIN.LAST_ITER if self.cfg.TRAIN.LAST_ITER else 0 92 | self.epoch = self.cfg.TRAIN.LAST_EPOCH if self.cfg.TRAIN.LAST_EPOCH else 0 93 | self.logs = OrderedDict() 94 | 95 | 96 | def training_step(self, batch, batch_idx): 97 | loss, loss_logs = self.forward(batch) 98 | opt = self.optimizers() 99 | opt.zero_grad() 100 | self.manual_backward(loss) 101 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5) 102 | opt.step() 103 | 104 | return {"loss": loss, 105 | "loss_logs": loss_logs} 106 | 107 | 108 | def on_train_batch_end(self, outputs, batch, batch_idx): 109 | if outputs.get('skip_batch') or not outputs.get('loss_logs'): 110 | return 111 | for k, v in outputs['loss_logs'].items(): 112 | if k not in self.logs: 113 | self.logs[k] = v.item() 114 | else: 115 | self.logs[k] += v.item() 116 | 117 | self.it += 1 118 | if self.it % self.cfg.TRAIN.LOG_STEPS == 0 and self.device.index == 0: 119 | mean_loss = OrderedDict({}) 120 | for tag, value in self.logs.items(): 121 | mean_loss[tag] = value / self.cfg.TRAIN.LOG_STEPS 122 | self.writer.add_scalar(tag, mean_loss[tag], self.it) 123 | self.logs = OrderedDict() 124 | print_current_loss(self.start_time, self.it, mean_loss, 125 | self.trainer.current_epoch, 126 | inner_iter=batch_idx, 127 | lr=self.trainer.optimizers[0].param_groups[0]['lr']) 128 | 129 | def on_train_epoch_end(self): 130 | # pass 131 | sch = self.lr_schedulers() 132 | if sch is not None: 133 | sch.step() 134 | 135 | def save(self, file_name): 136 | state = {} 137 | try: 138 | state['model'] = self.model.module.state_dict() 139 | except: 140 | state['model'] = self.model.state_dict() 141 | torch.save(state, file_name, _use_new_zipfile_serialization=False) 142 | return 143 | 144 | 145 | def build_models(cfg): 146 | model = InterGenSpatialControlNet(cfg) 147 | return model 148 | 149 | 150 | if __name__ == '__main__': 151 | 152 | parser = argparse.ArgumentParser(description='Process configs.') 153 | parser.add_argument('--model_config', type=str, help='model config') 154 | parser.add_argument('--dataset_config', type=str, help='dataset config') 155 | parser.add_argument('--train_config', type=str, help='train config') 156 | 157 | args = parser.parse_args() 158 | 159 | print(os.getcwd()) 160 | model_cfg = get_config(args.model_config) 161 | train_cfg 
= get_config(args.train_config) 162 | data_cfg = get_config(args.dataset_config).interhuman 163 | 164 | datamodule = DataModule(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS) 165 | model = build_models(model_cfg) 166 | 167 | if train_cfg.TRAIN.FROM_PRETRAIN: 168 | ckpt = torch.load(train_cfg.TRAIN.FROM_PRETRAIN, map_location="cpu") 169 | for k in list(ckpt["state_dict"].keys()): 170 | if "model" in k: 171 | ckpt["state_dict"][k.replace("model.", "")] = ckpt["state_dict"].pop(k) 172 | model.load_state_dict(ckpt["state_dict"], strict=False) 173 | print("checkpoint state loaded!") 174 | 175 | if train_cfg.TRAIN.RESUME: 176 | resume_path = train_cfg.TRAIN.RESUME 177 | else: 178 | resume_path = None 179 | 180 | litmodel = LitTrainModel(model, train_cfg) 181 | 182 | checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=litmodel.model_dir, 183 | every_n_epochs=train_cfg.TRAIN.SAVE_EPOCH, 184 | save_top_k = train_cfg.TRAIN.SAVE_TOPK, 185 | ) 186 | trainer = pl.Trainer( 187 | default_root_dir=litmodel.model_dir, 188 | devices="auto", accelerator='gpu', 189 | max_epochs=train_cfg.TRAIN.EPOCH, 190 | strategy=DDPStrategy(find_unused_parameters=True), 191 | precision=32, 192 | callbacks=[checkpoint_callback], 193 | detect_anomaly=True 194 | ) 195 | 196 | trainer.fit(model=litmodel, datamodule=datamodule, ckpt_path=resume_path) 197 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | from scipy.ndimage import uniform_filter1d 4 | 5 | emb_scale = 6 6 | 7 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 8 | def euclidean_distance_matrix(matrix1, matrix2): 9 | """ 10 | Params: 11 | -- matrix1: N1 x D 12 | -- matrix2: N2 x D 13 | Returns: 14 | -- dist: N1 x N2 15 | dist[i, j] == distance(matrix1[i], matrix2[j]) 16 | """ 17 | assert matrix1.shape[1] == matrix2.shape[1] 18 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 19 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 20 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 21 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 22 | return dists 23 | 24 | def calculate_top_k(mat, top_k): 25 | size = mat.shape[0] 26 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 27 | bool_mat = (mat == gt_mat) 28 | correct_vec = False 29 | top_k_list = [] 30 | for i in range(top_k): 31 | # print(correct_vec, bool_mat[:, i]) 32 | correct_vec = (correct_vec | bool_mat[:, i]) 33 | # print(correct_vec) 34 | top_k_list.append(correct_vec[:, None]) 35 | top_k_mat = np.concatenate(top_k_list, axis=1) 36 | return top_k_mat 37 | 38 | 39 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 40 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 41 | argmax = np.argsort(dist_mat, axis=1) 42 | top_k_mat = calculate_top_k(argmax, top_k) 43 | if sum_all: 44 | return top_k_mat.sum(axis=0) 45 | else: 46 | return top_k_mat 47 | 48 | 49 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 50 | assert len(embedding1.shape) == 2 51 | assert embedding1.shape[0] == embedding2.shape[0] 52 | assert embedding1.shape[1] == embedding2.shape[1] 53 | 54 | dist = linalg.norm(embedding1 - embedding2, axis=1) 55 | if sum_all: 56 | return dist.sum(axis=0) 57 | else: 58 | return dist 59 | 60 | 61 | 62 | def calculate_activation_statistics(activations): 63 | """ 
64 | Params: 65 | -- activation: num_samples x dim_feat 66 | Returns: 67 | -- mu: dim_feat 68 | -- sigma: dim_feat x dim_feat 69 | """ 70 | activations = activations * emb_scale 71 | mu = np.mean(activations, axis=0) 72 | cov = np.cov(activations, rowvar=False) 73 | return mu, cov 74 | 75 | 76 | def calculate_diversity(activation, diversity_times): 77 | assert len(activation.shape) == 2 78 | assert activation.shape[0] > diversity_times 79 | num_samples = activation.shape[0] 80 | 81 | activation = activation * emb_scale 82 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 83 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 84 | dist = linalg.norm((activation[first_indices] - activation[second_indices])/2, axis=1) 85 | return dist.mean() 86 | 87 | 88 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 89 | """Numpy implementation of the Frechet Distance. 90 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 91 | and X_2 ~ N(mu_2, C_2) is 92 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 93 | Stable version by Dougal J. Sutherland. 94 | Params: 95 | -- mu1 : Numpy array containing the activations of a layer of the 96 | inception net (like returned by the function 'get_predictions') 97 | for generated samples. 98 | -- mu2 : The sample mean over activations, precalculated on an 99 | representative data set. 100 | -- sigma1: The covariance matrix over activations for generated samples. 101 | -- sigma2: The covariance matrix over activations, precalculated on an 102 | representative data set. 103 | Returns: 104 | -- : The Frechet Distance. 105 | """ 106 | 107 | mu1 = np.atleast_1d(mu1) 108 | mu2 = np.atleast_1d(mu2) 109 | 110 | sigma1 = np.atleast_2d(sigma1) 111 | sigma2 = np.atleast_2d(sigma2) 112 | 113 | assert mu1.shape == mu2.shape, \ 114 | 'Training and test mean vectors have different lengths' 115 | assert sigma1.shape == sigma2.shape, \ 116 | 'Training and test covariances have different dimensions' 117 | 118 | diff = mu1 - mu2 119 | 120 | # Product might be almost singular 121 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 122 | if not np.isfinite(covmean).all(): 123 | msg = ('fid calculation produces singular product; ' 124 | 'adding %s to diagonal of cov estimates') % eps 125 | print(msg) 126 | offset = np.eye(sigma1.shape[0]) * eps 127 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 128 | 129 | # Numerical error might give slight imaginary component 130 | if np.iscomplexobj(covmean): 131 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 132 | m = np.max(np.abs(covmean.imag)) 133 | raise ValueError('Imaginary component {}'.format(m)) 134 | covmean = covmean.real 135 | 136 | tr_covmean = np.trace(covmean) 137 | 138 | return (diff.dot(diff) + np.trace(sigma1) + 139 | np.trace(sigma2) - 2 * tr_covmean) 140 | 141 | 142 | def calculate_multimodality(activation, multimodality_times): 143 | assert len(activation.shape) == 3 144 | assert activation.shape[1] > multimodality_times 145 | num_per_sent = activation.shape[1] 146 | 147 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 148 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 149 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 150 | return dist.mean() 151 | 152 | def calculate_trajectory_error(dist_error, mean_err_traj, mask, strict=True): 153 | ''' dist_error shape [5]: 
error for each kps in metre 154 | Two threshold: 20 cm and 50 cm. 155 | If mean error in sequence is more then the threshold, fails 156 | return: traj_fail(0.2), traj_fail(0.5), all_kps_fail(0.2), all_kps_fail(0.5), all_mean_err. 157 | Every metrics are already averaged. 158 | ''' 159 | # mean_err_traj = dist_error.mean(1) 160 | if strict: 161 | # Traj fails if any of the key frame fails 162 | traj_fail_02 = 1.0 - (dist_error <= 0.2).all() 163 | traj_fail_05 = 1.0 - (dist_error <= 0.5).all() 164 | else: 165 | # Traj fails if the mean error of all keyframes more than the threshold 166 | traj_fail_02 = (mean_err_traj > 0.2) 167 | traj_fail_05 = (mean_err_traj > 0.5) 168 | all_fail_02 = (dist_error > 0.2).sum() / mask.sum() 169 | all_fail_05 = (dist_error > 0.5).sum() / mask.sum() 170 | 171 | # out = {"traj_fail_02": traj_fail_02, 172 | # "traj_fail_05": traj_fail_05, 173 | # "all_fail_02": all_fail_02, 174 | # "all_fail_05": all_fail_05, 175 | # "all_mean_err": dist_error.mean()} 176 | return np.array([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()]) 177 | 178 | 179 | def calculate_trajectory_diversity(trajectories, lengths): 180 | ''' Standard diviation of point locations in the trajectories 181 | Args: 182 | trajectories: [bs, rep, 196, 2] 183 | lengths: [bs] 184 | ''' 185 | # [32, 2, 196, 2 (xz)] 186 | # mean_trajs = trajectories.mean(1, keepdims=True) 187 | # dist_to_mean = np.linalg.norm(trajectories - mean_trajs, axis=3) 188 | def traj_div(traj, length): 189 | # traj [rep, 196, 2] 190 | # length (int) 191 | traj = traj[:, :length, :] 192 | # point_var = traj.var(axis=0, keepdims=True).mean() 193 | # point_var = np.sqrt(point_var) 194 | # return point_var 195 | 196 | mean_traj = traj.mean(axis=0, keepdims=True) 197 | dist = np.sqrt(((traj - mean_traj)**2).sum(axis=2)) 198 | rms_dist = np.sqrt((dist**2).mean()) 199 | return rms_dist 200 | 201 | div = [] 202 | for i in range(len(trajectories)): 203 | div.append(traj_div(trajectories[i], lengths[i])) 204 | return np.array(div).mean() 205 | 206 | 207 | def calculate_skating_ratio(motions): 208 | thresh_height = 0.05 # 10 209 | fps = 20.0 210 | thresh_vel = 0.50 # 20 cm /s 211 | avg_window = 5 # frames 212 | 213 | batch_size = motions.shape[0] 214 | # 10 left, 11 right foot. 
XZ plane, y up 215 | # motions [bs, 22, 3, max_len] 216 | verts_feet = motions[:, [10, 11], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len] 217 | verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1], axis=2) * fps # [bs, 2, max_len-1] 218 | # [bs, 2, max_len-1] 219 | vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0) 220 | 221 | verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len] 222 | # If feet touch ground in agjecent frames 223 | feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height), (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1] 224 | # skate velocity 225 | skate_vel = feet_contact * vel_avg 226 | 227 | # it must both skating in the current frame 228 | skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel)) 229 | # and also skate in the windows of frames 230 | skating = np.logical_and(skating, (vel_avg > thresh_vel)) 231 | 232 | # Both feet slide 233 | skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1] 234 | skating_ratio = np.sum(skating, axis=1) / skating.shape[1] 235 | 236 | return skating_ratio, skate_vel 237 | 238 | # verts_feet_gt = markers_got[:, [16, 47], :].detach().cpu().numpy() # [119, 2, 3] heels 239 | # verts_feet_horizon_vel_gt = np.linalg.norm(verts_feet_gt[1:, :, :-1] - verts_feet_gt[:-1, :, :-1], axis=-1) * 30 240 | 241 | # verts_feet_height_gt = verts_feet_gt[:, :, -1][0:-1] # [118,2] 242 | # min_z = markers_gt[:, :, 2].min().detach().cpu().numpy() 243 | # verts_feet_height_gt = verts_feet_height_gt - min_z 244 | 245 | # skating_gt = (verts_feet_horizon_vel_gt > thresh_vel) * (verts_feet_height_gt < thresh_height) 246 | # skating_gt = np.sum(np.logival_and(skating_gt[:, 0], skating_gt[:, 1])) / 118 247 | # skating_gt_list.append(skating_gt) 248 | 249 | 250 | def calculate_skating_ratio_kit(motions): 251 | thresh_height = 0.05 # 10 252 | fps = 20.0 253 | thresh_vel = 0.50 # 20 cm /s 254 | avg_window = 5 # frames 255 | 256 | batch_size = motions.shape[0] 257 | # 15 left, 20 right foot. 
XZ plane, y up 258 | # motions [bs, 22, 3, max_len] 259 | verts_feet = motions[:, [15, 20], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len] 260 | verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1], axis=2) * fps # [bs, 2, max_len-1] 261 | # [bs, 2, max_len-1] 262 | vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0) 263 | 264 | verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len] 265 | # If feet touch ground in agjecent frames 266 | feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height), (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1] 267 | # skate velocity 268 | skate_vel = feet_contact * vel_avg 269 | 270 | # it must both skating in the current frame 271 | skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel)) 272 | # and also skate in the windows of frames 273 | skating = np.logical_and(skating, (vel_avg > thresh_vel)) 274 | 275 | # Both feet slide 276 | skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1] 277 | skating_ratio = np.sum(skating, axis=1) / skating.shape[1] 278 | 279 | return skating_ratio, skate_vel 280 | 281 | 282 | def control_l2(motion, hint, hint_mask): 283 | # motion: b, seq, 22, 3 284 | # hint: b, seq, 22, 1 285 | loss = np.linalg.norm((motion - hint) * hint_mask, axis=-1) 286 | # loss = loss.sum() / hint_mask.sum() 287 | return loss -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append(sys.path[0]+r"/../") 4 | import numpy as np 5 | import torch 6 | 7 | from datetime import datetime 8 | from datasets import get_dataset_motion_loader, get_motion_loader 9 | from models import * 10 | # from models_intergen2 import * 11 | 12 | from utils.metrics import * 13 | from datasets import EvaluatorModelWrapper 14 | from collections import OrderedDict 15 | from utils.plot_script import * 16 | from utils.utils import * 17 | from configs import get_config 18 | from os.path import join as pjoin 19 | from tqdm import tqdm 20 | import argparse 21 | 22 | os.environ['WORLD_SIZE'] = '1' 23 | os.environ['RANK'] = '0' 24 | os.environ['MASTER_ADDR'] = 'localhost' 25 | os.environ['MASTER_PORT'] = '12345' 26 | torch.multiprocessing.set_sharing_strategy('file_system') 27 | 28 | def build_models(cfg): 29 | model = InterGenSpatialControlNet(cfg) 30 | return model 31 | 32 | def evaluate_matching_score(motion_loaders, file): 33 | match_score_dict = OrderedDict({}) 34 | R_precision_dict = OrderedDict({}) 35 | activation_dict = OrderedDict({}) 36 | # print(motion_loaders.keys()) 37 | print('========== Evaluating MM Distance ==========') 38 | for motion_loader_name, motion_loader in motion_loaders.items(): 39 | all_motion_embeddings = [] 40 | score_list = [] 41 | all_size = 0 42 | mm_dist_sum = 0 43 | top_k_count = 0 44 | # print(motion_loader_name) 45 | with torch.no_grad(): 46 | for idx, batch in tqdm(enumerate(motion_loader)): 47 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(batch) 48 | # print(text_embeddings.shape) 49 | # print(motion_embeddings.shape) 50 | 51 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), 52 | motion_embeddings.cpu().numpy()) 53 | # print(dist_mat.shape) 54 | mm_dist_sum += dist_mat.trace() 55 | 56 | argsmax = np.argsort(dist_mat, axis=1) 57 | # print(argsmax.shape) 58 | 59 
| top_k_mat = calculate_top_k(argsmax, top_k=3) 60 | top_k_count += top_k_mat.sum(axis=0) 61 | 62 | all_size += text_embeddings.shape[0] 63 | 64 | all_motion_embeddings.append(motion_embeddings.cpu().numpy()) 65 | 66 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) 67 | mm_dist = mm_dist_sum / all_size 68 | R_precision = top_k_count / all_size 69 | match_score_dict[motion_loader_name] = mm_dist 70 | R_precision_dict[motion_loader_name] = R_precision 71 | activation_dict[motion_loader_name] = all_motion_embeddings 72 | 73 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}') 74 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}', file=file, flush=True) 75 | 76 | line = f'---> [{motion_loader_name}] R_precision: ' 77 | for i in range(len(R_precision)): 78 | line += '(top %d): %.4f ' % (i+1, R_precision[i]) 79 | print(line) 80 | print(line, file=file, flush=True) 81 | 82 | return match_score_dict, R_precision_dict, activation_dict 83 | 84 | 85 | def evaluate_fid(groundtruth_loader, activation_dict, file): 86 | eval_dict = OrderedDict({}) 87 | gt_motion_embeddings = [] 88 | print('========== Evaluating FID ==========') 89 | with torch.no_grad(): 90 | for idx, batch in tqdm(enumerate(groundtruth_loader)): 91 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch) 92 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) 93 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) 94 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) 95 | 96 | # print(gt_mu) 97 | for model_name, motion_embeddings in activation_dict.items(): 98 | mu, cov = calculate_activation_statistics(motion_embeddings) 99 | # print(mu) 100 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 101 | print(f'---> [{model_name}] FID: {fid:.4f}') 102 | print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) 103 | eval_dict[model_name] = fid 104 | return eval_dict 105 | 106 | 107 | def evaluate_diversity(activation_dict, file): 108 | eval_dict = OrderedDict({}) 109 | print('========== Evaluating Diversity ==========') 110 | for model_name, motion_embeddings in activation_dict.items(): 111 | diversity = calculate_diversity(motion_embeddings, diversity_times) 112 | eval_dict[model_name] = diversity 113 | print(f'---> [{model_name}] Diversity: {diversity:.4f}') 114 | print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) 115 | return eval_dict 116 | 117 | 118 | def evaluate_multimodality(mm_motion_loaders, file): 119 | eval_dict = OrderedDict({}) 120 | print('========== Evaluating MultiModality ==========') 121 | for model_name, mm_motion_loader in mm_motion_loaders.items(): 122 | mm_motion_embeddings = [] 123 | with torch.no_grad(): 124 | for idx, batch in enumerate(mm_motion_loader): 125 | 126 | batch[-4] = batch[-4][0] 127 | batch[-3] = batch[-3][0] 128 | batch[-2] = batch[-2][0] 129 | motion_embedings = eval_wrapper.get_motion_embeddings(batch) 130 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) 131 | if len(mm_motion_embeddings) == 0: 132 | multimodality = 0 133 | else: 134 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() 135 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) 136 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}') 137 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) 138 | eval_dict[model_name] = multimodality 139 | return eval_dict 140 | 141 | 142 | 
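# --- Editorial sketch (not part of the original repository) -------------------
# The FID reported by evaluate_fid() follows calculate_frechet_distance() in
# utils/metrics.py:
#     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2 * sqrt(C_1 @ C_2)).
# The helper below is a minimal, self-contained sanity check of that pipeline on
# synthetic embeddings; the sample count, dimension and shift are illustrative
# values only and are not used anywhere else in this file.
def _fid_sanity_check(num_samples=1000, dim=16, shift=0.1, seed=0):
    rng = np.random.default_rng(seed)
    gt_act = rng.normal(size=(num_samples, dim))   # stand-in for GT motion embeddings
    gen_act = gt_act + shift                       # a slightly shifted "generated" set
    gt_mu, gt_cov = calculate_activation_statistics(gt_act)
    mu, cov = calculate_activation_statistics(gen_act)
    # Identical statistics give (numerically) zero FID; the shifted set gives a
    # positive value that grows with the shift.
    fid_self = calculate_frechet_distance(gt_mu, gt_cov, gt_mu, gt_cov)
    fid_shift = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
    print(f"FID(gt, gt) ~ {fid_self:.6f}, FID(gt, gt + {shift}) = {fid_shift:.4f}")
# ------------------------------------------------------------------------------
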
def get_metric_statistics(values): 143 | mean = np.mean(values, axis=0) 144 | std = np.std(values, axis=0) 145 | conf_interval = 1.96 * std / np.sqrt(replication_times) 146 | return mean, conf_interval 147 | 148 | 149 | def evaluation(log_file): 150 | with open(log_file, 'w') as f: 151 | all_metrics = OrderedDict({'MM Distance': OrderedDict({}), 152 | 'R_precision': OrderedDict({}), 153 | 'FID': OrderedDict({}), 154 | 'Diversity': OrderedDict({}), 155 | 'MultiModality': OrderedDict({}) 156 | } 157 | ) 158 | for replication in range(replication_times): 159 | motion_loaders = {} 160 | mm_motion_loaders = {} 161 | motion_loaders['ground truth'] = gt_loader 162 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): 163 | motion_loader, mm_motion_loader = motion_loader_getter() 164 | motion_loaders[motion_loader_name] = motion_loader 165 | mm_motion_loaders[motion_loader_name] = mm_motion_loader 166 | 167 | print(f'==================== Replication {replication} ====================') 168 | print(f'==================== Replication {replication} ====================', file=f, flush=True) 169 | print(f'Time: {datetime.now()}') 170 | print(f'Time: {datetime.now()}', file=f, flush=True) 171 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f) 172 | 173 | print(f'Time: {datetime.now()}') 174 | print(f'Time: {datetime.now()}', file=f, flush=True) 175 | fid_score_dict = evaluate_fid(gt_loader, acti_dict, f) 176 | 177 | print(f'Time: {datetime.now()}') 178 | print(f'Time: {datetime.now()}', file=f, flush=True) 179 | div_score_dict = evaluate_diversity(acti_dict, f) 180 | 181 | print(f'Time: {datetime.now()}') 182 | print(f'Time: {datetime.now()}', file=f, flush=True) 183 | mm_score_dict = evaluate_multimodality(mm_motion_loaders, f) 184 | 185 | print(f'!!! DONE !!!') 186 | print(f'!!! 
DONE !!!', file=f, flush=True) 187 | 188 | for key, item in mat_score_dict.items(): 189 | if key not in all_metrics['MM Distance']: 190 | all_metrics['MM Distance'][key] = [item] 191 | else: 192 | all_metrics['MM Distance'][key] += [item] 193 | 194 | for key, item in R_precision_dict.items(): 195 | if key not in all_metrics['R_precision']: 196 | all_metrics['R_precision'][key] = [item] 197 | else: 198 | all_metrics['R_precision'][key] += [item] 199 | 200 | for key, item in fid_score_dict.items(): 201 | if key not in all_metrics['FID']: 202 | all_metrics['FID'][key] = [item] 203 | else: 204 | all_metrics['FID'][key] += [item] 205 | 206 | for key, item in div_score_dict.items(): 207 | if key not in all_metrics['Diversity']: 208 | all_metrics['Diversity'][key] = [item] 209 | else: 210 | all_metrics['Diversity'][key] += [item] 211 | 212 | for key, item in mm_score_dict.items(): 213 | if key not in all_metrics['MultiModality']: 214 | all_metrics['MultiModality'][key] = [item] 215 | else: 216 | all_metrics['MultiModality'][key] += [item] 217 | 218 | 219 | # print(all_metrics['Diversity']) 220 | for metric_name, metric_dict in all_metrics.items(): 221 | print('========== %s Summary ==========' % metric_name) 222 | print('========== %s Summary ==========' % metric_name, file=f, flush=True) 223 | 224 | for model_name, values in metric_dict.items(): 225 | # print(metric_name, model_name) 226 | mean, conf_interval = get_metric_statistics(np.array(values)) 227 | # print(mean, mean.dtype) 228 | if isinstance(mean, np.float64) or isinstance(mean, np.float32): 229 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') 230 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) 231 | elif isinstance(mean, np.ndarray): 232 | line = f'---> [{model_name}]' 233 | for i in range(len(mean)): 234 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) 235 | print(line) 236 | print(line, file=f, flush=True) 237 | 238 | 239 | if __name__ == '__main__': 240 | 241 | parser = argparse.ArgumentParser(description='Process configs.') 242 | parser.add_argument('--model_config', type=str, help='model config') 243 | parser.add_argument('--dataset_config', type=str, help='dataset config') 244 | parser.add_argument('--evalmodel_config', type=str, help='train config') 245 | 246 | args = parser.parse_args() 247 | 248 | mm_num_samples = 100 249 | mm_num_repeats = 30 250 | mm_num_times = 10 251 | 252 | diversity_times = 300 253 | replication_times = 20 254 | 255 | # batch_size is fixed to 96!! 
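    # --- Editorial note (not part of the original file) ------------------------
    # The batch size matters here because R-precision and MM-Distance are
    # computed *within* a batch: evaluate_matching_score() builds the
    # text-vs-motion distance matrix for one batch and does top-k retrieval
    # against that pool, so a larger batch makes retrieval harder and shifts the
    # numbers. Keeping the pool fixed at 96 keeps results comparable across runs
    # and between models evaluated under the same protocol.
    # ----------------------------------------------------------------------------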
256 | batch_size = 96 257 | 258 | data_cfg = get_config(args.dataset_config).interhuman_test 259 | cfg_path_list = [args.model_config] 260 | 261 | 262 | normalizer = MotionNormalizerTorch() 263 | 264 | spatial_mean = normalizer.motion_mean[:66] 265 | spatial_std = normalizer.motion_std[:66] 266 | 267 | eval_motion_loaders = {} 268 | for cfg_path in cfg_path_list: 269 | model_cfg = get_config(cfg_path) 270 | 271 | device = torch.device('cuda:%d' % 0 if torch.cuda.is_available() else 'cpu') 272 | torch.cuda.set_device(0) 273 | model = build_models(model_cfg) 274 | 275 | if model_cfg.CHECKPOINT and isinstance(model_cfg.CHECKPOINT, list): 276 | 277 | ckpts = [] 278 | for ckpt_path in model_cfg.CHECKPOINT: 279 | ckpt = torch.load(ckpt_path, map_location="cpu") 280 | for k in list(ckpt["state_dict"].keys()): 281 | if "model" in k: 282 | ckpt["state_dict"][k.replace("model.", "")] = ckpt["state_dict"].pop(k) 283 | ckpts.append(ckpt['state_dict']) 284 | model.load_state_dict(ckpts, strict=True) 285 | print("checkpoint state loaded!") 286 | elif model_cfg.CHECKPOINT and isinstance(model_cfg.CHECKPOINT,str): 287 | 288 | checkpoint = torch.load(model_cfg.CHECKPOINT, map_location=torch.device("cpu")) 289 | for k in list(checkpoint["state_dict"].keys()): 290 | if "model" in k: 291 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k) 292 | model.load_state_dict(checkpoint['state_dict'], strict=True) 293 | 294 | eval_motion_loaders[model_cfg.NAME] = lambda: get_motion_loader( 295 | batch_size, 296 | model, 297 | gt_dataset, 298 | device, 299 | mm_num_samples, 300 | mm_num_repeats 301 | ) 302 | 303 | device = torch.device('cuda:%d' % 0 if torch.cuda.is_available() else 'cpu') 304 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size) 305 | evalmodel_cfg = get_config(args.evalmodel_config) 306 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device) 307 | 308 | log_file = f'evaluation.log' 309 | evaluation(log_file) 310 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | # import cv2 5 | from PIL import Image 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | from scipy.ndimage import gaussian_filter 10 | from common.quaternion import * 11 | 12 | from utils.rotation_conversions import * 13 | 14 | import logging 15 | 16 | face_joint_indx = [2,1,17,16] 17 | fid_l = [7,10] 18 | fid_r = [8,11] 19 | 20 | 21 | def swap_left_right_position(data): 22 | assert len(data.shape) == 3 and data.shape[-1] == 3 23 | data = data.copy() 24 | data[..., 0] *= -1 25 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21] 26 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20] 27 | left_hand_chain = [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30, 52, 53, 54, 55, 56] 28 | right_hand_chain = [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51, 57, 58, 59, 60, 61] 29 | 30 | tmp = data[:, right_chain] 31 | data[:, right_chain] = data[:, left_chain] 32 | data[:, left_chain] = tmp 33 | if data.shape[1] > 24: 34 | tmp = data[:, right_hand_chain] 35 | data[:, right_hand_chain] = data[:, left_hand_chain] 36 | data[:, left_hand_chain] = tmp 37 | return data 38 | 39 | def swap_left_right_rot(data): 40 | assert len(data.shape) == 3 and data.shape[-1] == 6 41 | data = data.copy() 42 | 43 | data[..., [1,2,4]] *= -1 44 | 45 | right_chain = np.array([2, 5, 8, 11, 14, 17, 19, 
21])-1 46 | left_chain = np.array([1, 4, 7, 10, 13, 16, 18, 20])-1 47 | left_hand_chain = np.array([22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30,])-1 48 | right_hand_chain = np.array([43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51,])-1 49 | 50 | tmp = data[:, right_chain] 51 | data[:, right_chain] = data[:, left_chain] 52 | data[:, left_chain] = tmp 53 | if data.shape[1] > 24: 54 | tmp = data[:, right_hand_chain] 55 | data[:, right_hand_chain] = data[:, left_hand_chain] 56 | data[:, left_hand_chain] = tmp 57 | return data 58 | 59 | 60 | def swap_left_right(data, n_joints): 61 | T = data.shape[0] 62 | new_data = data.copy() 63 | positions = new_data[..., :3*n_joints].reshape(T, n_joints, 3) 64 | rotations = new_data[..., 3*n_joints:].reshape(T, -1, 6) 65 | 66 | positions = swap_left_right_position(positions) 67 | rotations = swap_left_right_rot(rotations) 68 | 69 | new_data = np.concatenate([positions.reshape(T, -1), rotations.reshape(T, -1)], axis=-1) 70 | return new_data 71 | 72 | 73 | def rigid_transform(relative, data): # 这段代码实现了一个刚性变换,将给定的相对变换应用到数据中的全局位置和全局速度上。 74 | 75 | global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3)) 76 | global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3)) 77 | 78 | relative_rot = relative[0] 79 | relative_t = relative[1:3] 80 | relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4,)) 81 | relative_r_rot_quat[..., 0] = np.cos(relative_rot) 82 | relative_r_rot_quat[..., 2] = np.sin(relative_rot) 83 | global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions) 84 | global_positions[..., [0, 2]] += relative_t 85 | data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1,)) 86 | global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel) 87 | data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1,)) 88 | 89 | return data 90 | 91 | 92 | class MotionNormalizer(): 93 | def __init__(self): 94 | mean = np.load("./data/global_mean.npy") 95 | std = np.load("./data/global_std.npy") 96 | 97 | self.motion_mean = mean # (262,) 98 | self.motion_std = std # (262,) 99 | 100 | 101 | def forward(self, x): 102 | 103 | x = (x - self.motion_mean) / self.motion_std 104 | return x 105 | 106 | def backward(self, x): 107 | x = x * self.motion_std + self.motion_mean 108 | return x 109 | 110 | 111 | 112 | class MotionNormalizerTorch(): 113 | def __init__(self): 114 | mean = np.load("./data/global_mean.npy") # shape=(262,) 115 | std = np.load("./data/global_std.npy") 116 | 117 | self.motion_mean = torch.from_numpy(mean).float() 118 | self.motion_std = torch.from_numpy(std).float() 119 | 120 | 121 | def forward(self, x): 122 | device = x.device 123 | x = x.clone() 124 | x = (x - self.motion_mean.to(device)) / self.motion_std.to(device) 125 | return x 126 | 127 | def backward(self, x, global_rt=False): 128 | device = x.device 129 | x = x.clone() 130 | x = x * self.motion_std.to(device) + self.motion_mean.to(device) 131 | return x 132 | 133 | trans_matrix = torch.Tensor([[1.0, 0.0, 0.0], # 绕x转180度 134 | [0.0, 0.0, 1.0], 135 | [0.0, -1.0, 0.0]]) 136 | 137 | 138 | def get_orientation(motion, prev_frames, n_joints): 139 | positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3) # [95,22,3] 140 | root_pos_init = positions[prev_frames] 141 | 142 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 143 | across = root_pos_init[r_hip] - root_pos_init[l_hip] # 计算身体的横向向量 144 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] # 归一化,使得across只表示方向 145 | 146 | # forward 
(3,), rotate around y-axis 147 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) # 将 [0,1,0]和横向向量across叉乘,计算身体的前向向量 148 | # forward (3,) 149 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] # 归一化 150 | return forward_init 151 | 152 | def process_motion_np(motion, feet_thre, prev_frames, n_joints, target=np.array([[0, 0, 1]])): 153 | 154 | # (seq_len, joints_num, 3) 155 | # '''Down Sample''' 156 | # positions = positions[::ds_num] 157 | 158 | '''Uniform Skeleton''' 159 | # positions = uniform_skeleton(positions, tgt_offsets) 160 | 161 | positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3) # [95,22,3] 162 | rotations = motion[:, n_joints*3:] # [95,126=21*6] 163 | 164 | positions = np.einsum("mn, tjn->tjm", trans_matrix, positions) 165 | 166 | '''Put on Floor''' 167 | floor_height = positions.min(axis=0).min(axis=0)[1] # find the lowest point across all frames and joints 168 | positions[:, :, 1] -= floor_height 169 | 170 | 171 | '''XZ at origin''' 172 | # move the root pos of the first frame to (0,0) of xz plane 173 | # for example 174 | # poistion_root_init = [2,3,5] -> [0,3,5] 175 | 176 | root_pos_init = positions[prev_frames] # [95,22,3] -> [22,3] 177 | root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) # [3] 178 | positions = positions - root_pose_init_xz # 把第一帧root关节平移到(0,y,0)处, [95,22,3] 179 | 180 | '''All initially face Z+''' 181 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 182 | across = root_pos_init[r_hip] - root_pos_init[l_hip] # 计算身体的横向向量 183 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] # 归一化,使得across只表示方向 184 | 185 | # forward (3,), rotate around y-axis 186 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) # 将 [0,1,0]和横向向量across叉乘,计算身体的前向向量 187 | # forward (3,) 188 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] # 归一化 189 | 190 | # target = np.array([[0, 0, 1]]) # 目标朝向,即z轴正方向 191 | root_quat_init = qbetween_np(forward_init, target) # 使用qbetween_np函数计算了一个四元数root_quat_init,它表示将初始的前向向量forward_init旋转到目标向量target所需的旋转。 192 | root_quat_init_for_all = np.ones(positions.shape[:-1] + (4,)) * root_quat_init # 代码使用qrot_np函数将所有的姿势向量positions应用了旋转,使得身体的朝向与期望的朝向一致。 193 | 194 | 195 | positions = qrot_np(root_quat_init_for_all, positions) # [95,22,3] 代码使用qrot_np函数将所有的姿势向量positions应用了旋转,使得身体的朝向与期望的朝向一致 196 | 197 | """ Get Foot Contacts """ 198 | 199 | def foot_detect(positions, thres): 200 | velfactor, heightfactor = np.array([thres, thres]), np.array([0.12, 0.05]) 201 | 202 | feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 203 | feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 204 | feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 205 | feet_l_h = positions[:-1,fid_l,1] 206 | feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float32) 207 | 208 | feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 209 | feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 210 | feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 211 | feet_r_h = positions[:-1,fid_r,1] 212 | feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float32) 213 | return feet_l, feet_r 214 | # 215 | feet_l, feet_r = foot_detect(positions, feet_thre) # [94,2], [94,2] 216 | 217 | '''Get Joint Rotation Representation''' 218 | rot_data = rotations # [95,126] 219 | 220 | '''Get Joint 
Rotation Invariant Position Represention''' 221 | joint_positions = positions.reshape(len(positions), -1) # [95,66] 222 | joint_vels = positions[1:] - positions[:-1] # [94,66] 223 | joint_vels = joint_vels.reshape(len(joint_vels), -1) # [94,66] 224 | 225 | data = joint_positions[:-1] # [94,66] 226 | data = np.concatenate([data, joint_vels], axis=-1) # [94,66+66] 227 | data = np.concatenate([data, rot_data[:-1]], axis=-1) # [94,66+66+126] 228 | data = np.concatenate([data, feet_l, feet_r], axis=-1) # [94,66+66+126+2+2] 229 | 230 | return data, root_quat_init, root_pose_init_xz[None] 231 | 232 | def mkdir(path): 233 | if not os.path.exists(path): 234 | os.makedirs(path) 235 | 236 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 237 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 238 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 239 | 240 | MISSING_VALUE = -1 241 | 242 | def save_image(image_numpy, image_path): 243 | img_pil = Image.fromarray(image_numpy) 244 | img_pil.save(image_path) 245 | 246 | 247 | def save_logfile(log_loss, save_path): 248 | with open(save_path, 'wt') as f: 249 | for k, v in log_loss.items(): 250 | w_line = k 251 | for digit in v: 252 | w_line += ' %.3f' % digit 253 | f.write(w_line + '\n') 254 | 255 | 256 | def print_current_loss(start_time, niter_state, losses, epoch=None, inner_iter=None, lr=None): 257 | 258 | def as_minutes(s): 259 | m = math.floor(s / 60) 260 | s -= m * 60 261 | return '%dm %ds' % (m, s) 262 | 263 | def time_since(since, percent): 264 | now = time.time() 265 | s = now - since 266 | es = s / percent 267 | rs = es - s 268 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 269 | 270 | if epoch is not None and lr is not None : 271 | epoch_mes = 'epoch: %3d niter:%6d inner_iter:%4d lr:%5f' % (epoch, niter_state, inner_iter, lr) 272 | print(epoch_mes, end=" ") 273 | elif epoch is not None: 274 | epoch_mes = 'epoch: %3d niter:%6d inner_iter:%4d' % (epoch, niter_state, inner_iter) 275 | print(epoch_mes, end=" ") 276 | 277 | now = time.time() 278 | message = '%s'%(as_minutes(now - start_time)) 279 | 280 | for k, v in losses.items(): 281 | message += ' %s: %.4f ' % (k, v) 282 | print(message) 283 | logging.info(epoch_mes+' '+message) 284 | 285 | 286 | def compose_gif_img_list(img_list, fp_out, duration): 287 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] 288 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, 289 | save_all=True, loop=0, duration=duration) 290 | 291 | 292 | def save_images(visuals, image_path): 293 | if not os.path.exists(image_path): 294 | os.makedirs(image_path) 295 | 296 | for i, (label, img_numpy) in enumerate(visuals.items()): 297 | img_name = '%d_%s.jpg' % (i, label) 298 | save_path = os.path.join(image_path, img_name) 299 | save_image(img_numpy, save_path) 300 | 301 | 302 | def save_images_test(visuals, image_path, from_name, to_name): 303 | if not os.path.exists(image_path): 304 | os.makedirs(image_path) 305 | 306 | for i, (label, img_numpy) in enumerate(visuals.items()): 307 | img_name = "%s_%s_%s" % (from_name, to_name, label) 308 | save_path = os.path.join(image_path, img_name) 309 | save_image(img_numpy, save_path) 310 | 311 | 312 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): 313 | # print(col, row) 314 | compose_img = compose_image(img_list, col, row, img_size) 315 | if not os.path.exists(save_dir): 316 | 
os.makedirs(save_dir) 317 | img_path = os.path.join(save_dir, img_name) 318 | # print(img_path) 319 | compose_img.save(img_path) 320 | 321 | 322 | def compose_image(img_list, col, row, img_size): 323 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) 324 | for y in range(0, row): 325 | for x in range(0, col): 326 | from_img = Image.fromarray(img_list[y * col + x]) 327 | # print((x * img_size[0], y*img_size[1], 328 | # (x + 1) * img_size[0], (y + 1) * img_size[1])) 329 | paste_area = (x * img_size[0], y*img_size[1], 330 | (x + 1) * img_size[0], (y + 1) * img_size[1]) 331 | to_image.paste(from_img, paste_area) 332 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img 333 | return to_image 334 | 335 | 336 | def list_cut_average(ll, intervals): 337 | if intervals == 1: 338 | return ll 339 | 340 | bins = math.ceil(len(ll) * 1.0 / intervals) 341 | ll_new = [] 342 | for i in range(bins): 343 | l_low = intervals * i 344 | l_high = l_low + intervals 345 | l_high = l_high if l_high < len(ll) else len(ll) 346 | ll_new.append(np.mean(ll[l_low:l_high])) 347 | return ll_new 348 | 349 | 350 | def motion_temporal_filter(motion, sigma=1): 351 | motion = motion.reshape(motion.shape[0], -1) 352 | # print(motion.shape) 353 | for i in range(motion.shape[1]): 354 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 355 | return motion.reshape(motion.shape[0], -1, 3) 356 | 357 | -------------------------------------------------------------------------------- /common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(np.float32).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(np.float32).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | 128 | def qmul_np(q, r): 129 | q = torch.from_numpy(q).contiguous().float() 130 | r = torch.from_numpy(r).contiguous().float() 131 | return qmul(q, r).numpy() 132 | 133 | 134 | def qrot_np(q, v): 135 | q = torch.from_numpy(q).contiguous().float() 136 | v = torch.from_numpy(v).contiguous().float() 137 | return qrot(q, v).numpy() 138 | 139 | 140 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 141 | if use_gpu: 142 | q = torch.from_numpy(q).cuda().float() 143 | return qeuler(q, order, epsilon).cpu().numpy() 144 | else: 145 | q = torch.from_numpy(q).contiguous().float() 146 | return qeuler(q, order, epsilon).numpy() 147 | 148 | 149 | def qfix(q): 150 | """ 151 | Enforce quaternion 
continuity across the time dimension by selecting 152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 153 | between two consecutive frames. 154 | 155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 156 | Returns a tensor of the same shape. 157 | """ 158 | assert len(q.shape) == 3 159 | assert q.shape[-1] == 4 160 | 161 | result = q.copy() 162 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 163 | mask = dot_products < 0 164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 165 | result[1:][mask] *= -1 166 | return result 167 | 168 | 169 | def euler2quat(e, order, deg=True): 170 | """ 171 | Convert Euler angles to quaternions. 172 | """ 173 | assert e.shape[-1] == 3 174 | 175 | original_shape = list(e.shape) 176 | original_shape[-1] = 4 177 | 178 | e = e.view(-1, 3) 179 | 180 | ## if euler angles in degrees 181 | if deg: 182 | e = e * np.pi / 180. 183 | 184 | x = e[:, 0] 185 | y = e[:, 1] 186 | z = e[:, 2] 187 | 188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 191 | 192 | result = None 193 | for coord in order: 194 | if coord == 'x': 195 | r = rx 196 | elif coord == 'y': 197 | r = ry 198 | elif coord == 'z': 199 | r = rz 200 | else: 201 | raise 202 | if result is None: 203 | result = r 204 | else: 205 | result = qmul(result, r) 206 | 207 | # Reverse antipodal representation to have a non-negative "w" 208 | if order in ['xyz', 'yzx', 'zxy']: 209 | result *= -1 210 | 211 | return result.view(original_shape) 212 | 213 | 214 | def expmap_to_quaternion(e): 215 | """ 216 | Convert axis-angle rotations (aka exponential maps) to quaternions. 217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 219 | Returns a tensor of shape (*, 4). 220 | """ 221 | assert e.shape[-1] == 3 222 | 223 | original_shape = list(e.shape) 224 | original_shape[-1] = 4 225 | e = e.reshape(-1, 3) 226 | 227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 228 | w = np.cos(0.5 * theta).reshape(-1, 1) 229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 231 | 232 | 233 | def euler_to_quaternion(e, order): 234 | """ 235 | Convert Euler angles to quaternions. 
236 | """ 237 | assert e.shape[-1] == 3 238 | 239 | original_shape = list(e.shape) 240 | original_shape[-1] = 4 241 | 242 | e = e.reshape(-1, 3) 243 | 244 | x = e[:, 0] 245 | y = e[:, 1] 246 | z = e[:, 2] 247 | 248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 251 | 252 | result = None 253 | for coord in order: 254 | if coord == 'x': 255 | r = rx 256 | elif coord == 'y': 257 | r = ry 258 | elif coord == 'z': 259 | r = rz 260 | else: 261 | raise 262 | if result is None: 263 | result = r 264 | else: 265 | result = qmul_np(result, r) 266 | 267 | # Reverse antipodal representation to have a non-negative "w" 268 | if order in ['xyz', 'yzx', 'zxy']: 269 | result *= -1 270 | 271 | return result.reshape(original_shape) 272 | 273 | 274 | def quaternion_to_matrix(quaternions): 275 | """ 276 | Convert rotations given as quaternions to rotation matrices. 277 | Args: 278 | quaternions: quaternions with real part first, 279 | as tensor of shape (..., 4). 280 | Returns: 281 | Rotation matrices as tensor of shape (..., 3, 3). 282 | """ 283 | r, i, j, k = torch.unbind(quaternions, -1) 284 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 285 | 286 | o = torch.stack( 287 | ( 288 | 1 - two_s * (j * j + k * k), 289 | two_s * (i * j - k * r), 290 | two_s * (i * k + j * r), 291 | two_s * (i * j + k * r), 292 | 1 - two_s * (i * i + k * k), 293 | two_s * (j * k - i * r), 294 | two_s * (i * k - j * r), 295 | two_s * (j * k + i * r), 296 | 1 - two_s * (i * i + j * j), 297 | ), 298 | -1, 299 | ) 300 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 301 | 302 | 303 | def quaternion_to_matrix_np(quaternions): 304 | q = torch.from_numpy(quaternions).contiguous().float() 305 | return quaternion_to_matrix(q).numpy() 306 | 307 | 308 | def quaternion_to_cont6d_np(quaternions): 309 | rotation_mat = quaternion_to_matrix_np(quaternions) 310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 311 | return cont_6d 312 | 313 | 314 | def quaternion_to_cont6d(quaternions): 315 | rotation_mat = quaternion_to_matrix(quaternions) 316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 317 | return cont_6d 318 | 319 | 320 | def cont6d_to_matrix(cont6d): 321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 322 | x_raw = cont6d[..., 0:3] 323 | y_raw = cont6d[..., 3:6] 324 | 325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 326 | z = torch.cross(x, y_raw, dim=-1) 327 | z = z / torch.norm(z, dim=-1, keepdim=True) 328 | 329 | y = torch.cross(z, x, dim=-1) 330 | 331 | x = x[..., None] 332 | y = y[..., None] 333 | z = z[..., None] 334 | 335 | mat = torch.cat([x, y, z], dim=-1) 336 | return mat 337 | 338 | 339 | def cont6d_to_matrix_np(cont6d): 340 | q = torch.from_numpy(cont6d).contiguous().float() 341 | return cont6d_to_matrix(q).numpy() 342 | 343 | 344 | def qpow(q0, t, dtype=torch.float): 345 | ''' q0 : tensor of quaternions 346 | t: tensor of powers 347 | ''' 348 | q0 = qnormalize(q0) 349 | theta0 = torch.acos(q0[..., 0]) 350 | 351 | ## if theta0 is close to zero, add epsilon to avoid NaNs 352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 353 | theta0 = (1 - mask) * theta0 + mask * 10e-10 354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 355 | 356 | if isinstance(t, torch.Tensor): 357 | q = torch.zeros(t.shape + 
q0.shape) 358 | theta = t.view(-1, 1) * theta0.view(1, -1) 359 | else: ## if t is a number 360 | q = torch.zeros(q0.shape) 361 | theta = t * theta0 362 | 363 | q[..., 0] = torch.cos(theta) 364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 365 | 366 | return q.to(dtype) 367 | 368 | 369 | def qslerp(q0, q1, t): 370 | ''' 371 | q0: starting quaternion 372 | q1: ending quaternion 373 | t: array of points along the way 374 | 375 | Returns: 376 | Tensor of Slerps: t.shape + q0.shape 377 | ''' 378 | 379 | q0 = qnormalize(q0) 380 | q1 = qnormalize(q1) 381 | q_ = qpow(qmul(q1, qinv(q0)), t) 382 | 383 | return qmul(q_, 384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 385 | 386 | 387 | def qbetween(v0, v1): 388 | ''' 389 | find the quaternion used to rotate v0 to v1 390 | ''' 391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 393 | 394 | v = torch.cross(v0, v1) 395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, 396 | keepdim=True) 397 | return qnormalize(torch.cat([w, v], dim=-1)) 398 | 399 | 400 | def qbetween_np(v0, v1): 401 | ''' 402 | find the quaternion used to rotate v0 to v1 403 | ''' 404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 406 | 407 | v0 = torch.from_numpy(v0).float() 408 | v1 = torch.from_numpy(v1).float() 409 | return qbetween(v0, v1).numpy() 410 | 411 | 412 | def lerp(p0, p1, t): 413 | if not isinstance(t, torch.Tensor): 414 | t = torch.Tensor([t]) 415 | 416 | new_shape = t.shape + p0.shape 417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 419 | p0 = p0.view(new_view_p).expand(new_shape) 420 | p1 = p1.view(new_view_p).expand(new_shape) 421 | t = t.view(new_view_t).expand(new_shape) 422 | 423 | return p0 + t * (p1 - p0) 424 | -------------------------------------------------------------------------------- /utils/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(np.float32).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(np.float32).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 
38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | 128 | def 
qmul_np(q, r): 129 | q = torch.from_numpy(q).contiguous().float() 130 | r = torch.from_numpy(r).contiguous().float() 131 | return qmul(q, r).numpy() 132 | 133 | 134 | def qrot_np(q, v): 135 | q = torch.from_numpy(q).contiguous().float() 136 | v = torch.from_numpy(v).contiguous().float() 137 | return qrot(q, v).numpy() 138 | 139 | 140 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 141 | if use_gpu: 142 | q = torch.from_numpy(q).cuda().float() 143 | return qeuler(q, order, epsilon).cpu().numpy() 144 | else: 145 | q = torch.from_numpy(q).contiguous().float() 146 | return qeuler(q, order, epsilon).numpy() 147 | 148 | 149 | def qfix(q): 150 | """ 151 | Enforce quaternion continuity across the time dimension by selecting 152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 153 | between two consecutive frames. 154 | 155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 156 | Returns a tensor of the same shape. 157 | """ 158 | assert len(q.shape) == 3 159 | assert q.shape[-1] == 4 160 | 161 | result = q.copy() 162 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 163 | mask = dot_products < 0 164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 165 | result[1:][mask] *= -1 166 | return result 167 | 168 | 169 | def euler2quat(e, order, deg=True): 170 | """ 171 | Convert Euler angles to quaternions. 172 | """ 173 | assert e.shape[-1] == 3 174 | 175 | original_shape = list(e.shape) 176 | original_shape[-1] = 4 177 | 178 | e = e.view(-1, 3) 179 | 180 | ## if euler angles in degrees 181 | if deg: 182 | e = e * np.pi / 180. 183 | 184 | x = e[:, 0] 185 | y = e[:, 1] 186 | z = e[:, 2] 187 | 188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 191 | 192 | result = None 193 | for coord in order: 194 | if coord == 'x': 195 | r = rx 196 | elif coord == 'y': 197 | r = ry 198 | elif coord == 'z': 199 | r = rz 200 | else: 201 | raise 202 | if result is None: 203 | result = r 204 | else: 205 | result = qmul(result, r) 206 | 207 | # Reverse antipodal representation to have a non-negative "w" 208 | if order in ['xyz', 'yzx', 'zxy']: 209 | result *= -1 210 | 211 | return result.view(original_shape) 212 | 213 | 214 | def expmap_to_quaternion(e): 215 | """ 216 | Convert axis-angle rotations (aka exponential maps) to quaternions. 217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 219 | Returns a tensor of shape (*, 4). 220 | """ 221 | assert e.shape[-1] == 3 222 | 223 | original_shape = list(e.shape) 224 | original_shape[-1] = 4 225 | e = e.reshape(-1, 3) 226 | 227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 228 | w = np.cos(0.5 * theta).reshape(-1, 1) 229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 231 | 232 | 233 | def euler_to_quaternion(e, order): 234 | """ 235 | Convert Euler angles to quaternions. 
236 | """ 237 | assert e.shape[-1] == 3 238 | 239 | original_shape = list(e.shape) 240 | original_shape[-1] = 4 241 | 242 | e = e.reshape(-1, 3) 243 | 244 | x = e[:, 0] 245 | y = e[:, 1] 246 | z = e[:, 2] 247 | 248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 251 | 252 | result = None 253 | for coord in order: 254 | if coord == 'x': 255 | r = rx 256 | elif coord == 'y': 257 | r = ry 258 | elif coord == 'z': 259 | r = rz 260 | else: 261 | raise 262 | if result is None: 263 | result = r 264 | else: 265 | result = qmul_np(result, r) 266 | 267 | # Reverse antipodal representation to have a non-negative "w" 268 | if order in ['xyz', 'yzx', 'zxy']: 269 | result *= -1 270 | 271 | return result.reshape(original_shape) 272 | 273 | 274 | def quaternion_to_matrix(quaternions): 275 | """ 276 | Convert rotations given as quaternions to rotation matrices. 277 | Args: 278 | quaternions: quaternions with real part first, 279 | as tensor of shape (..., 4). 280 | Returns: 281 | Rotation matrices as tensor of shape (..., 3, 3). 282 | """ 283 | r, i, j, k = torch.unbind(quaternions, -1) 284 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 285 | 286 | o = torch.stack( 287 | ( 288 | 1 - two_s * (j * j + k * k), 289 | two_s * (i * j - k * r), 290 | two_s * (i * k + j * r), 291 | two_s * (i * j + k * r), 292 | 1 - two_s * (i * i + k * k), 293 | two_s * (j * k - i * r), 294 | two_s * (i * k - j * r), 295 | two_s * (j * k + i * r), 296 | 1 - two_s * (i * i + j * j), 297 | ), 298 | -1, 299 | ) 300 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 301 | 302 | 303 | def quaternion_to_matrix_np(quaternions): 304 | q = torch.from_numpy(quaternions).contiguous().float() 305 | return quaternion_to_matrix(q).numpy() 306 | 307 | 308 | def quaternion_to_cont6d_np(quaternions): 309 | rotation_mat = quaternion_to_matrix_np(quaternions) 310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 311 | return cont_6d 312 | 313 | 314 | def quaternion_to_cont6d(quaternions): 315 | rotation_mat = quaternion_to_matrix(quaternions) 316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 317 | return cont_6d 318 | 319 | 320 | def cont6d_to_matrix(cont6d): 321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 322 | x_raw = cont6d[..., 0:3] 323 | y_raw = cont6d[..., 3:6] 324 | 325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 326 | z = torch.cross(x, y_raw, dim=-1) 327 | z = z / torch.norm(z, dim=-1, keepdim=True) 328 | 329 | y = torch.cross(z, x, dim=-1) 330 | 331 | x = x[..., None] 332 | y = y[..., None] 333 | z = z[..., None] 334 | 335 | mat = torch.cat([x, y, z], dim=-1) 336 | return mat 337 | 338 | 339 | def cont6d_to_matrix_np(cont6d): 340 | q = torch.from_numpy(cont6d).contiguous().float() 341 | return cont6d_to_matrix(q).numpy() 342 | 343 | 344 | def qpow(q0, t, dtype=torch.float): 345 | ''' q0 : tensor of quaternions 346 | t: tensor of powers 347 | ''' 348 | q0 = qnormalize(q0) 349 | theta0 = torch.acos(q0[..., 0]) 350 | 351 | ## if theta0 is close to zero, add epsilon to avoid NaNs 352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 353 | theta0 = (1 - mask) * theta0 + mask * 10e-10 354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 355 | 356 | if isinstance(t, torch.Tensor): 357 | q = torch.zeros(t.shape + 
q0.shape) 358 | theta = t.view(-1, 1) * theta0.view(1, -1) 359 | else: ## if t is a number 360 | q = torch.zeros(q0.shape) 361 | theta = t * theta0 362 | 363 | q[..., 0] = torch.cos(theta) 364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 365 | 366 | return q.to(dtype) 367 | 368 | 369 | def qslerp(q0, q1, t): 370 | ''' 371 | q0: starting quaternion 372 | q1: ending quaternion 373 | t: array of points along the way 374 | 375 | Returns: 376 | Tensor of Slerps: t.shape + q0.shape 377 | ''' 378 | 379 | q0 = qnormalize(q0) 380 | q1 = qnormalize(q1) 381 | q_ = qpow(qmul(q1, qinv(q0)), t) 382 | 383 | return qmul(q_, 384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 385 | 386 | 387 | def qbetween(v0, v1): 388 | ''' 389 | find the quaternion used to rotate v0 to v1 390 | ''' 391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 393 | 394 | v = torch.cross(v0, v1) 395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, 396 | keepdim=True) 397 | return qnormalize(torch.cat([w, v], dim=-1)) 398 | 399 | 400 | def qbetween_np(v0, v1): 401 | ''' 402 | find the quaternion used to rotate v0 to v1 403 | ''' 404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 406 | 407 | v0 = torch.from_numpy(v0).float() 408 | v1 = torch.from_numpy(v1).float() 409 | return qbetween(v0, v1).numpy() 410 | 411 | 412 | def lerp(p0, p1, t): 413 | if not isinstance(t, torch.Tensor): 414 | t = torch.Tensor([t]) 415 | 416 | new_shape = t.shape + p0.shape 417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 419 | p0 = p0.view(new_view_p).expand(new_shape) 420 | p1 = p1.view(new_view_p).expand(new_shape) 421 | t = t.view(new_view_t).expand(new_shape) 422 | 423 | return p0 + t * (p1 - p0) 424 | -------------------------------------------------------------------------------- /models/nets.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | import torch 3 | 4 | from models.utils import * 5 | from models.cfg_sampler import ClassifierFreeSampleModel 6 | from models.blocks import * 7 | from utils.utils import * 8 | 9 | from models.gaussian_diffusion import ( 10 | # MotionDiffusion, 11 | MotionSpatialControlNetDiffusion, 12 | space_timesteps, 13 | get_named_beta_schedule, 14 | create_named_schedule_sampler, 15 | ModelMeanType, 16 | ModelVarType, 17 | LossType 18 | ) 19 | import random 20 | 21 | class MotionEncoder(nn.Module): 22 | def __init__(self, cfg): 23 | super().__init__() 24 | 25 | self.cfg = cfg 26 | self.input_feats = cfg.INPUT_DIM 27 | self.latent_dim = cfg.LATENT_DIM 28 | self.ff_size = cfg.FF_SIZE 29 | self.num_layers = cfg.NUM_LAYERS 30 | self.num_heads = cfg.NUM_HEADS 31 | self.dropout = cfg.DROPOUT 32 | self.activation = cfg.ACTIVATION 33 | 34 | self.query_token = nn.Parameter(torch.randn(1, self.latent_dim)) 35 | 36 | self.embed_motion = nn.Linear(self.input_feats*2, self.latent_dim) 37 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=2000) 38 | 39 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, 40 | nhead=self.num_heads, 41 | dim_feedforward=self.ff_size, 42 | dropout=self.dropout, 43 | activation=self.activation, 44 | batch_first=True) 45 | 
self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) 46 | self.out_ln = nn.LayerNorm(self.latent_dim) 47 | self.out = nn.Linear(self.latent_dim, 512) 48 | 49 | 50 | def forward(self, batch): 51 | x, mask = batch["motions"], batch["mask"] 52 | B, T, D = x.shape 53 | 54 | x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1) 55 | 56 | x_emb = self.embed_motion(x) 57 | 58 | emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=x.device)][:,None], x_emb], dim=1) 59 | 60 | seq_mask = (mask>0.5) 61 | token_mask = torch.ones((B, 1), dtype=bool, device=x.device) 62 | valid_mask = torch.cat([token_mask, seq_mask], dim=1) 63 | 64 | h = self.sequence_pos_encoder(emb) 65 | h = self.transformer(h, src_key_padding_mask=~valid_mask) 66 | h = self.out_ln(h) 67 | motion_emb = self.out(h[:,0]) 68 | 69 | batch["motion_emb"] = motion_emb 70 | 71 | return batch 72 | 73 | 74 | class InterDenoiser(nn.Module): 75 | def __init__(self, 76 | input_feats, 77 | latent_dim=512, 78 | num_frames=240, 79 | ff_size=1024, 80 | num_layers=8, 81 | num_heads=8, 82 | dropout=0.1, 83 | activation="gelu", 84 | cfg_weight=0., 85 | archi='single', 86 | return_intermediate=False, 87 | **kargs): 88 | super().__init__() 89 | 90 | self.cfg_weight = cfg_weight 91 | self.num_frames = num_frames 92 | self.latent_dim = latent_dim 93 | self.ff_size = ff_size 94 | self.num_layers = num_layers 95 | self.num_heads = num_heads 96 | self.dropout = dropout 97 | self.activation = activation 98 | self.input_feats = input_feats 99 | self.time_embed_dim = latent_dim 100 | 101 | self.archi = archi 102 | 103 | self.text_emb_dim = 768 104 | self.spatial_emb_dim = 4 105 | 106 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=0) 107 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 108 | 109 | # Input Embedding 110 | self.motion_embed = nn.Linear(self.input_feats, self.latent_dim) 111 | self.text_embed = nn.Linear(self.text_emb_dim, self.latent_dim) 112 | 113 | self.return_intermediate = return_intermediate 114 | 115 | self.blocks = nn.ModuleList() 116 | for _ in range(num_layers): 117 | self.blocks.append(TransformerMotionGuidanceBlock(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size)) 118 | # self.blocks.append(TransformerBlock(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size)) 119 | 120 | # Output Module 121 | self.out = zero_module(FinalLayer(self.latent_dim, self.input_feats)) 122 | 123 | 124 | def mask_motion_guidance(self, motion_mask, cond_mask_prob = 0.1, force_mask=False): 125 | # This function is aims to mask the motion guidance with a probability of cond_mask_prob. 126 | # motion_mask of input has been mask the invalid time. 127 | # mask.shape=[B,P,T] 128 | bs, guidance_num = motion_mask.shape[0], motion_mask.shape[1]-1 129 | if guidance_num == 0: 130 | return motion_mask.clone() 131 | elif force_mask: 132 | mask_new = motion_mask.clone() 133 | mask_new[:,1:,:] = 0 134 | return mask_new 135 | elif cond_mask_prob > 0.: 136 | 137 | mask_new = torch.bernoulli(torch.ones(bs,guidance_num, device=motion_mask.device) * cond_mask_prob).view(bs,guidance_num,1).contiguous() # 1-> use null_cond, 0-> use real cond 138 | mask_new = torch.cat([torch.zeros(bs,1,1, device=motion_mask.device), mask_new], dim=1) 139 | 140 | return motion_mask * (1. 
- mask_new) 141 | else: 142 | return motion_mask 143 | 144 | def forward(self, x, timesteps, mask=None, cond=None, motion_guidance=None, **kwargs): # spatial_condition=None, **kwargs): 145 | """ 146 | x: B, T, D 147 | """ 148 | 149 | B, T = x.shape[0], x.shape[1] 150 | 151 | if motion_guidance is None: # single motion generation 152 | x_cat = x.unsqueeze(1) # B,1,T,D 153 | elif self.return_intermediate: # control branch for two motion 154 | motion_guidance = motion_guidance.permute(0,2,1,3) # B,P,T,D 155 | x_cat = torch.cat([x.unsqueeze(1), motion_guidance], dim=1) # B,1,T,D, B,P,T,D -> B,P+1,T,D 156 | else: # main branch for two motion 157 | x_cat = x.unsqueeze(1) # B,1,T,D 158 | 159 | if mask is not None: 160 | mask = mask.permute(0,2,1).contiguous()[:,:x_cat.shape[1],...] # B,T,P -> B,P,T 161 | else: 162 | mask = torch.ones(B, x_cat.shape[1], T).to(x_cat.device) 163 | 164 | if x_cat.shape[1] > 1: 165 | mask = self.mask_motion_guidance(mask, 0.1) 166 | 167 | emb = self.embed_timestep(timesteps) + self.text_embed(cond) # [4],[4,768] -> [4,1024] 168 | 169 | x_cat_emb = self.motion_embed(x_cat) # B,P+1,T,D -> B,P+1,T,1024 170 | 171 | h_cat_prev = self.sequence_pos_encoder(x_cat_emb.view(-1, T, self.latent_dim)).view(B, -1, T, self.latent_dim) 172 | 173 | key_padding_mask = ~(mask > 0.5) 174 | 175 | h_cat_prev = h_cat_prev.view(B,-1,self.latent_dim) # [B,P*T,D] 176 | key_padding_mask = key_padding_mask.view(B,-1) # [B,P*T] 177 | 178 | if self.return_intermediate: # For control branch | spatial control branch 179 | intermediate_feats = [] 180 | 181 | for idx ,block in enumerate(self.blocks): 182 | 183 | h_cat_prev = block(h_cat_prev, T, emb, key_padding_mask) 184 | 185 | if motion_guidance is None: # single motion generation 186 | pass 187 | elif self.return_intermediate: # control branch 188 | intermediate_feats.append(h_cat_prev.view(B,-1,T,self.latent_dim)[:,:1,...]) 189 | else: # main branch for two motion 190 | h_cat_prev = h_cat_prev.view(B,-1,T,self.latent_dim) 191 | h_cat_prev = h_cat_prev + motion_guidance[idx] 192 | h_cat_prev = h_cat_prev.view(B,-1,self.latent_dim) 193 | 194 | if self.return_intermediate: 195 | return intermediate_feats 196 | 197 | h_cat_prev = h_cat_prev.view(B,-1,T,self.latent_dim) 198 | output = self.out(h_cat_prev) 199 | 200 | return output[:,0,...] # only return the first person for matching the dimension of diffusion process. 
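# --------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the original file): a
# minimal single-person call of InterDenoiser, i.e. the motion_guidance=None
# path of forward() above. Shapes follow the docstring (x: B, T, D); the
# concrete numbers (batch 4, 210 frames, 262-dim motion features, 768-dim
# pooled text embedding, 1024-dim latent) are assumptions for illustration.
#
#   denoiser = InterDenoiser(input_feats=262, latent_dim=1024)
#   x    = torch.randn(4, 210, 262)               # noisy motion, B x T x D
#   t    = torch.randint(0, 1000, (4,))           # diffusion timesteps, B
#   cond = torch.randn(4, 768)                    # pooled text embedding
#   mask = torch.ones(4, 210, 1)                  # validity mask, B x T x P
#   out  = denoiser(x, t, mask=mask, cond=cond)   # denoised sample, B x T x D
# --------------------------------------------------------------------------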
201 | 202 | 203 | class InterDenoiserSpatialControlNet(nn.Module): 204 | def __init__(self, 205 | input_feats, 206 | latent_dim=512, 207 | num_frames=240, 208 | ff_size=1024, 209 | num_layers=8, 210 | num_heads=8, 211 | dropout=0.1, 212 | activation="gelu", 213 | cfg_weight=0., 214 | archi='single', 215 | **kargs): 216 | super().__init__() 217 | 218 | self.cfg_weight = cfg_weight 219 | self.num_frames = num_frames 220 | self.latent_dim = latent_dim 221 | self.ff_size = ff_size 222 | self.num_layers = num_layers 223 | self.num_heads = num_heads 224 | self.dropout = dropout 225 | self.activation = activation 226 | self.input_feats = input_feats 227 | self.time_embed_dim = latent_dim 228 | 229 | self.archi = archi 230 | 231 | if self.input_feats == 262 or self.input_feats == 263: 232 | self.n_joints = 22 233 | else: 234 | self.n_joints = 21 235 | 236 | self.net = InterDenoiser(self.input_feats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers, 237 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation, 238 | cfg_weight=self.cfg_weight, archi=self.archi, return_intermediate=False) 239 | 240 | if self.archi != 'single': 241 | self.control_branch = InterDenoiser(self.input_feats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers, 242 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation, 243 | cfg_weight=self.cfg_weight, archi=self.archi, return_intermediate=True) 244 | 245 | self.zero_linear = zero_module(nn.ModuleList([nn.Linear(self.latent_dim, self.latent_dim) for _ in range(self.num_layers)])) 246 | 247 | set_requires_grad(self.net, False) 248 | 249 | 250 | def forward(self, x, timesteps, mask=None, cond=None, motion_guidance=None, spatial_condition=None, **kwargs): 251 | 252 | """ 253 | x: B, T, D 254 | spatial_condition: B, T, D 255 | """ 256 | 257 | if motion_guidance is not None: 258 | intermediate_feats = self.control_branch(x, timesteps, mask=mask, cond=cond, motion_guidance=motion_guidance, spatial_condition=None,**kwargs) 259 | intermediate_feats = [self.zero_linear[i](intermediate_feats[i]) for i in range(len(intermediate_feats))] 260 | else: 261 | intermediate_feats = None 262 | 263 | output = self.net(x, timesteps, mask=mask, cond=cond, motion_guidance=intermediate_feats,**kwargs) #spatial_condition=None,**kwargs) 264 | 265 | return output 266 | 267 | 268 | class InterDiffusionSpatialControlNet(nn.Module): 269 | def __init__(self, cfg, sampling_strategy="ddim50"): 270 | super().__init__() 271 | self.cfg = cfg 272 | self.archi = cfg.ARCHI 273 | self.nfeats = cfg.INPUT_DIM 274 | self.latent_dim = cfg.LATENT_DIM 275 | self.ff_size = cfg.FF_SIZE 276 | self.num_layers = cfg.NUM_LAYERS 277 | self.num_heads = cfg.NUM_HEADS 278 | self.dropout = cfg.DROPOUT 279 | self.activation = cfg.ACTIVATION 280 | self.motion_rep = cfg.MOTION_REP 281 | 282 | self.cfg_weight = cfg.CFG_WEIGHT 283 | self.diffusion_steps = cfg.DIFFUSION_STEPS 284 | self.beta_scheduler = cfg.BETA_SCHEDULER 285 | self.sampler = cfg.SAMPLER 286 | self.sampling_strategy = sampling_strategy 287 | 288 | self.net = InterDenoiserSpatialControlNet(self.nfeats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers, 289 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation, cfg_weight=self.cfg_weight, archi=self.archi) 290 | 291 | self.diffusion_steps = self.diffusion_steps 292 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps) 293 | 294 | timestep_respacing=[self.diffusion_steps] 295 | 
        self.diffusion = MotionSpatialControlNetDiffusion(
296 |             use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
297 |             betas=self.betas,
298 |             motion_rep=self.motion_rep,
299 |             model_mean_type=ModelMeanType.START_X,
300 |             model_var_type=ModelVarType.FIXED_SMALL,
301 |             loss_type=LossType.MSE,
302 |             rescale_timesteps=False,
303 |             archi=self.archi,
304 |         )
305 |         self.sampler = create_named_schedule_sampler(self.sampler, self.diffusion)
306 | 
307 |     def mask_cond(self, cond, cond_mask_prob=0.1, force_mask=False):
308 |         bs = cond.shape[0]
309 |         if force_mask:
310 |             return torch.zeros_like(cond)
311 |         elif cond_mask_prob > 0.:
312 |             mask = torch.bernoulli(torch.ones(bs, device=cond.device) * cond_mask_prob).view([bs] + [1] * len(cond.shape[1:]))  # 1 -> use null_cond, 0 -> use real cond
313 |             return cond * (1. - mask), (1. - mask)
314 |         else:
315 |             return cond, None
316 | 
317 |     def generate_src_mask(self, T, length, person_num, B=0):
318 |         if B == 0:
319 |             B = length.shape[0]
320 |         else:
321 |             if len(length.shape) == 1:
322 |                 length = torch.cat([length] * B, dim=0)
323 |         src_mask = torch.ones(B, T, person_num)
324 |         for p in range(person_num):
325 |             for i in range(B):
326 |                 for j in range(length[i], T):
327 |                     src_mask[i, j, p] = 0
328 |         return src_mask
329 | 
330 |     def compute_loss(self, batch):  # this function appears to be used during the training stage
331 |         cond = batch["cond"]
332 |         x_and_xCondition = batch["motions"]  # [4,300,524]
333 |         B, T = batch["motions"].shape[:2]
334 | 
335 |         if cond is not None:
336 |             cond, cond_mask = self.mask_cond(cond, 0.1)
337 | 
338 |         seq_mask = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"], batch["person_num"]).to(x_and_xCondition.device)
339 | 
340 |         t, _ = self.sampler.sample(B, x_and_xCondition.device)
341 | 
342 |         output = self.diffusion.training_losses(
343 |             model=self.net,
344 |             x_start=x_and_xCondition,
345 |             t=t,
346 |             mask=seq_mask,
347 |             t_bar=self.cfg.T_BAR,
348 |             cond_mask=cond_mask,
349 |             model_kwargs={"mask": seq_mask,
350 |                           "cond": cond,
351 |                           "person_num": batch["person_num"],
352 |                           # "spatial_condition": batch["spatial_condition"],
353 |                           },
354 |         )
355 |         return output
356 | 
357 |     def forward(self, batch):
358 | 
359 |         cond = batch["cond"]
360 | 
361 |         motion_guidance = batch["motion_guidance"]
362 |         # spatial_condition = batch["spatial_condition"]
363 | 
364 |         if motion_guidance is not None:
365 |             B, T = motion_guidance.shape[:2]
366 |         else:  # T equals the valid motion length, so every entry of the mask is valid.
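            # Editor's note: in the single-person case there is no guidance tensor
            # to read the shape from, so the batch size comes from the text
            # condition and T from the requested motion length.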
367 |             B = cond.shape[0]
368 |             T = batch["motion_lens"].item()
369 | 
370 |         # add valid time length mask
371 |         seq_mask = self.generate_src_mask(T, batch["motion_lens"], batch["person_num"], B=cond.shape[0]).to(batch["motion_lens"].device)
372 | 
373 |         timestep_respacing = self.sampling_strategy
374 |         self.diffusion_test = MotionSpatialControlNetDiffusion(
375 |             use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
376 |             betas=self.betas,
377 |             motion_rep=self.motion_rep,
378 |             model_mean_type=ModelMeanType.START_X,
379 |             model_var_type=ModelVarType.FIXED_SMALL,
380 |             loss_type=LossType.MSE,
381 |             rescale_timesteps=False,
382 |             archi=self.archi,
383 |         )
384 | 
385 |         self.cfg_model = ClassifierFreeSampleModel(self.net, self.cfg_weight)  # cfg_weight=3.5
386 | 
387 |         output = self.diffusion_test.ddim_sample_loop(
388 |             self.cfg_model,
389 |             (B, T, self.nfeats),
390 |             clip_denoised=False,
391 |             progress=True,
392 |             model_kwargs={
393 |                 "mask": seq_mask,  # None,
394 |                 "cond": cond,
395 |                 "motion_guidance": motion_guidance,
396 |                 # "spatial_condition": spatial_condition
397 |             },
398 |             x_start=None)
399 | 
400 |         return {"output": output}  # output: [batch_size, n_ctx, 2*d_model] = [1,210,524]
401 | 
402 | 
403 | 
404 | 
405 | 
-------------------------------------------------------------------------------- /datasets/evaluator.py: --------------------------------------------------------------------------------
1 | from os.path import join as pjoin
2 | from torch.utils.data import Dataset, DataLoader
3 | from datasets import InterHumanPipelineInferDataset
4 | from models import *
5 | import copy
6 | from datasets.evaluator_models import InterCLIP
7 | from tqdm import tqdm
8 | import torch
9 | 
10 | from utils.quaternion import *
11 | 
12 | normalizer = MotionNormalizerTorch()
13 | 
14 | def process_motion_np(motion, feet_thre, prev_frames, n_joints, target=np.array([[0, 0, 1]])):
15 | 
16 |     positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3)  # [95,22,3]
17 | 
18 |     vels = motion[:, n_joints*3:n_joints*6].reshape(-1, n_joints, 3)  # [95,22,3]
19 | 
20 |     '''XZ at origin'''
21 |     # move the root position of the first frame to (0, 0) on the xz plane,
22 |     # for example
23 |     # position_root_init = [2, 3, 5] -> [0, 3, 0]
24 | 
25 |     root_pos_init = positions[prev_frames]  # [95,22,3] -> [22,3]
26 |     root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])  # [3]
27 |     positions = positions - root_pose_init_xz
28 | 
29 |     '''All initially face Z+'''
30 |     r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
31 |     across = root_pos_init[r_hip] - root_pos_init[l_hip]
32 |     across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
33 | 
34 |     # forward (3,), rotate around y-axis
35 |     forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
36 |     # forward (3,)
37 |     forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]  # normalize
38 | 
39 |     # target = np.array([[0, 0, 1]])
40 |     root_quat_init = qbetween_np(forward_init, target)
41 |     root_quat_init_for_all = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
42 | 
43 |     positions = qrot_np(root_quat_init_for_all, positions)
44 | 
45 |     vels = qrot_np(root_quat_init_for_all, vels)
46 | 
47 |     '''Get Joint Rotation Invariant Position Representation'''
48 |     joint_positions = positions.reshape(len(positions), -1)  # [95,66]
49 |     joint_vels = vels.reshape(len(vels), -1)  # [95,66]
50 | 
51 |     motion[:, :n_joints*3] = joint_positions
52 |     motion[:, n_joints*3:n_joints*6] = joint_vels
53 | 
54 |     return motion, root_quat_init, 
root_pose_init_xz[None] 55 | 56 | def rigid_transform(relative, data): 57 | 58 | global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3)) 59 | global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3)) 60 | 61 | relative_rot = relative[0] 62 | relative_t = relative[1:3] 63 | relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4,)) 64 | relative_r_rot_quat[..., 0] = np.cos(relative_rot) 65 | relative_r_rot_quat[..., 2] = np.sin(relative_rot) 66 | global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions) 67 | global_positions[..., [0, 2]] += relative_t 68 | data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1,)) 69 | global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel) 70 | data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1,)) 71 | 72 | return data 73 | 74 | 75 | class EvaluationDataset(Dataset): 76 | 77 | def __init__(self, model, dataset, device, mm_num_samples, mm_num_repeats): 78 | 79 | self.normalizer = MotionNormalizer() 80 | self.device = device 81 | self.model = model.to(device) 82 | self.model.eval() 83 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True) 84 | self.max_length = dataset.max_length 85 | 86 | idxs = list(range(len(dataset))) 87 | random.shuffle(idxs) 88 | mm_idxs = idxs[:mm_num_samples] 89 | 90 | generated_motions = [] 91 | mm_generated_motions = [] 92 | # Pre-process all target captions 93 | with torch.no_grad(): 94 | for i, data in tqdm(enumerate(dataloader)): 95 | name, text, text1, text2, motion1, motion2, motion_lens, hint = data 96 | batch = {} 97 | 98 | if min(hint.shape) == 0: 99 | hint = None 100 | 101 | if i in mm_idxs: 102 | text1 = list(text1) * mm_num_repeats 103 | text2 = list(text2) * mm_num_repeats 104 | text = list(text) * mm_num_repeats 105 | 106 | batch["motion_lens"] = motion_lens 107 | 108 | model_name = self.model.__class__.__name__ 109 | motions_output = self.pipeline_generate(self.model, motion_lens, [text1,text2], text, [motion1,motion2], 1, hint) # [1,27,2,262] 110 | 111 | motions_output = self.normalizer.backward(motions_output.cpu().detach().numpy()) 112 | 113 | B,T = motions_output.shape[0], motions_output.shape[1] 114 | if T < self.max_length: 115 | padding_len = self.max_length - T 116 | D = motions_output.shape[-1] 117 | padding_zeros = np.zeros((B, padding_len, 2, D)) 118 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1) 119 | assert motions_output.shape[1] == self.max_length 120 | 121 | sub_dict = {'motion1': motions_output[0, :,0], 122 | 'motion2': motions_output[0, :,1], 123 | 'motion_lens': motion_lens[0], 124 | 'text': text[0]} 125 | 126 | if hint is not None: 127 | sub_dict['spatial_condition'] = hint[0] 128 | # else: 129 | # sub_dict['hint'] = None 130 | generated_motions.append(sub_dict) 131 | if i in mm_idxs: 132 | mm_sub_dict = {'mm_motions': motions_output, 133 | 'motion_lens': motion_lens[0], 134 | 'text': text[0]} 135 | mm_generated_motions.append(mm_sub_dict) 136 | 137 | 138 | self.generated_motions = generated_motions 139 | self.mm_generated_motions = mm_generated_motions 140 | 141 | def __len__(self): 142 | return len(self.generated_motions) 143 | 144 | def __getitem__(self, item): 145 | data = self.generated_motions[item] 146 | motion1, motion2, motion_lens, text = data['motion1'], data['motion2'], data['motion_lens'], data['text'] 147 | hint = data['spatial_condition'] if 'spatial_condition' in data else torch.zeros(0) 148 | return "generated", text, "placeholder", 
"placeholder", motion1, motion2, motion_lens, hint 149 | 150 | def pipeline_generate(self, model, motion_lens, texts, text_multi_person, motions, FLAG=0, hint=None): 151 | 152 | def generate(motion_lens, text, text_multi_person, person_num=1, motion_guidance=None, hint=None): 153 | T = motion_lens 154 | batch = {} 155 | batch["prompt"] = list(text) 156 | batch["text"] = list(text) 157 | batch["text_multi_person"] = list(text_multi_person) 158 | batch["person_num"] = person_num 159 | batch["motion_lens"] = T # torch.tensor([T]).unsqueeze(0).long().to(torch.device("cuda:0")) 160 | batch["motion_guidance"] = motion_guidance 161 | 162 | if hint is not None: 163 | hint = hint[:,:T,...] 164 | batch["spatial_condition"] = hint 165 | batch = model.forward_test(batch) 166 | output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], 1, -1) 167 | return output 168 | 169 | generated_motions = [] 170 | motion = generate(motion_lens.to(self.device), texts[0], text_multi_person) 171 | 172 | generated_motions.append(motion) 173 | 174 | if FLAG != 0: 175 | for person_idx, text in enumerate(texts[1:]): 176 | tmp_motions = [ m.detach().float().to(self.device) for m in generated_motions] 177 | motion_guidance = torch.cat(tmp_motions, dim=-2) 178 | motion = generate(motion_lens.to(self.device), text, text_multi_person, person_idx+2, motion_guidance, None) #hint.to(self.device)) 179 | generated_motions.append(motion) 180 | return torch.cat(generated_motions, dim=-2) # [:motion_lens.item()] # [B,T,P,D] 181 | 182 | def intergen_generate(self, model, motion_lens, texts, text_multi_person, motions, person_num=0): 183 | 184 | def generate(motion_lens, texts, text_multi_person, person_num=1): 185 | T = motion_lens 186 | batch = {} 187 | batch["text1"] = list(texts[0]) 188 | batch["text2"] = list(texts[1]) 189 | 190 | batch["text"] = list(text_multi_person) 191 | batch["person_num"] = person_num 192 | batch["motion_lens"] = T 193 | 194 | batch = model.forward_test(batch) 195 | output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], person_num, -1) 196 | return output 197 | 198 | motion = generate(motion_lens.to(self.device), texts, text_multi_person, person_num=person_num) 199 | 200 | return motion 201 | 202 | class MMGeneratedDataset(Dataset): 203 | def __init__(self, motion_dataset): 204 | self.dataset = motion_dataset.mm_generated_motions 205 | 206 | def __len__(self): 207 | return len(self.dataset) 208 | 209 | def __getitem__(self, item): 210 | data = self.dataset[item] 211 | mm_motions = data['mm_motions'] 212 | motion_lens = data['motion_lens'] 213 | mm_motions1 = mm_motions[:,:,0] 214 | mm_motions2 = mm_motions[:,:,1] 215 | text = data['text'] 216 | motion_lens = np.array([motion_lens]*mm_motions1.shape[0]) 217 | return "mm_generated", text, "placeholder", "placeholder", mm_motions1, mm_motions2, motion_lens, torch.zeros(0) 218 | 219 | 220 | def get_dataset_motion_loader(opt, batch_size): 221 | opt = copy.deepcopy(opt) 222 | # Configurations of T2M dataset and KIT dataset is almost the same 223 | 224 | if opt.NAME == 'interhuman': 225 | print('Loading dataset %s ...' 
% opt.NAME) 226 | dataset = InterHumanPipelineInferDataset(opt) 227 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=True, shuffle=True) 228 | else: 229 | raise KeyError('Dataset not Recognized !!') 230 | 231 | print('Ground Truth Dataset Loading Completed!!!') 232 | return dataloader, dataset 233 | 234 | def get_motion_loader(batch_size, model, ground_truth_dataset, device, mm_num_samples, mm_num_repeats): 235 | 236 | dataset = EvaluationDataset(model, ground_truth_dataset, device, mm_num_samples=mm_num_samples, mm_num_repeats=mm_num_repeats) 237 | mm_dataset = MMGeneratedDataset(dataset) 238 | 239 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True) 240 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0) 241 | 242 | print('Generated Dataset Loading Completed!!!') 243 | 244 | return motion_loader, mm_motion_loader 245 | 246 | 247 | 248 | 249 | def build_models(cfg): 250 | model = InterCLIP(cfg) 251 | checkpoint = torch.load(pjoin('eval_model/interclip.ckpt'),map_location="cpu") 252 | 253 | # checkpoint = torch.load(pjoin('checkpoints/interclip/model/5.ckpt'),map_location="cpu") 254 | for k in list(checkpoint["state_dict"].keys()): 255 | if "model" in k: 256 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k) 257 | model.load_state_dict(checkpoint["state_dict"], strict=True) 258 | 259 | # print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch'])) 260 | return model 261 | 262 | 263 | class EvaluatorModelWrapper(object): 264 | 265 | def __init__(self, cfg, device): 266 | 267 | self.model = build_models(cfg) 268 | self.cfg = cfg 269 | self.device = device 270 | 271 | self.model = self.model.to(device) 272 | self.model.eval() 273 | 274 | 275 | # Please note that the results does not following the order of inputs 276 | def get_co_embeddings(self, batch_data): 277 | with torch.no_grad(): 278 | name, text, text1, text2, motion1, motion2, motion_lens, _ = batch_data 279 | 280 | motion1 = motion1.detach().float() # .to(self.device) 281 | motion2 = motion2.detach().float() # .to(self.device) 282 | motions = torch.cat([motion1, motion2], dim=-1) 283 | motions = motions.detach().to(self.device).float() 284 | 285 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy() 286 | motions = motions[align_idx] 287 | motion_lens = motion_lens[align_idx] 288 | text = list(text) 289 | 290 | B, T = motions.shape[:2] 291 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 292 | padded_len = cur_len.max() 293 | 294 | batch = {} 295 | batch["text"] = text 296 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 297 | batch["motion_lens"] = motion_lens 298 | 299 | '''Motion Encoding''' 300 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 301 | 302 | '''Text Encoding''' 303 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx] 304 | 305 | return text_embedding, motion_embedding 306 | 307 | # Please note that the results does not following the order of inputs 308 | def get_motion_embeddings(self, batch_data): 309 | with torch.no_grad(): 310 | name, text, text1, text2, motion1, motion2, motion_lens, _ = batch_data 311 | motion1 = motion1.detach().float() # .to(self.device) 312 | motion2 = motion2.detach().float() # .to(self.device) 313 | motions = torch.cat([motion1, motion2], dim=-1) 314 | motions = motions.detach().to(self.device).float() 315 | 316 | align_idx = 
np.argsort(motion_lens.data.tolist())[::-1].copy() 317 | motions = motions[align_idx] 318 | motion_lens = motion_lens[align_idx] 319 | text = list(text) 320 | 321 | B, T = motions.shape[:2] 322 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 323 | padded_len = cur_len.max() 324 | 325 | batch = {} 326 | batch["text"] = text 327 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 328 | batch["motion_lens"] = motion_lens 329 | 330 | '''Motion Encoding''' 331 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 332 | 333 | return motion_embedding 334 | 335 | 336 | class TrainEvaluatorModelWrapper(object): 337 | 338 | def __init__(self, cfg, device): 339 | 340 | self.model = build_models(cfg) 341 | self.cfg = cfg 342 | self.device = device 343 | 344 | self.model = self.model.to(device) 345 | # self.model.eval() 346 | 347 | 348 | # Please note that the results does not following the order of inputs 349 | def get_co_embeddings(self, batch_data): 350 | name, text, motion1, motion2, motion_lens = batch_data 351 | motion1 = motion1.detach().float() # .to(self.device) 352 | motion2 = motion2.detach().float() # .to(self.device) 353 | motions = torch.cat([motion1, motion2], dim=-1) 354 | motions = motions.detach().to(self.device).float() 355 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy() 356 | motions = motions[align_idx] 357 | motion_lens = motion_lens[align_idx] 358 | text = list(text) 359 | B, T = motions.shape[:2] 360 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 361 | padded_len = cur_len.max() 362 | batch = {} 363 | batch["text"] = text 364 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 365 | batch["motion_lens"] = motion_lens 366 | '''Motion Encoding''' 367 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 368 | '''Text Encoding''' 369 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx] 370 | 371 | return text_embedding, motion_embedding 372 | 373 | # Please note that the results does not following the order of inputs 374 | def get_motion_embeddings(self, batch_data): 375 | with torch.no_grad(): 376 | name, text, motion1, motion2, motion_lens = batch_data 377 | motion1 = motion1.detach().float() # .to(self.device) 378 | motion2 = motion2.detach().float() # .to(self.device) 379 | motions = torch.cat([motion1, motion2], dim=-1) 380 | motions = motions.detach().to(self.device).float() 381 | 382 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy() 383 | motions = motions[align_idx] 384 | motion_lens = motion_lens[align_idx] 385 | text = list(text) 386 | 387 | B, T = motions.shape[:2] 388 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 389 | padded_len = cur_len.max() 390 | 391 | batch = {} 392 | batch["text"] = text 393 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 394 | batch["motion_lens"] = motion_lens 395 | 396 | '''Motion Encoding''' 397 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 398 | 399 | return motion_embedding 400 | 401 | def compute_contrastive_loss(self, batch_data): 402 | 403 | loss_total, losses = self.model(batch_data) 404 | 405 | return loss_total, losses 406 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.utils import * 5 | 6 
| kinematic_chain = [[0, 2, 5, 8, 11], 7 | [0, 1, 4, 7, 10], 8 | [0, 3, 6, 9, 12, 15], 9 | [9, 14, 17, 19, 21], 10 | [9, 13, 16, 18, 20]] 11 | 12 | class InterLoss(nn.Module): 13 | def __init__(self, recons_loss, nb_joints): 14 | super(InterLoss, self).__init__() 15 | self.nb_joints = nb_joints 16 | if recons_loss == 'l1': 17 | self.Loss = torch.nn.L1Loss(reduction='none') 18 | elif recons_loss == 'l2': 19 | self.Loss = torch.nn.MSELoss(reduction='none') 20 | elif recons_loss == 'l1_smooth': 21 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 22 | 23 | self.normalizer = MotionNormalizerTorch() 24 | 25 | self.weights = {} 26 | self.weights["RO"] = 0.01 27 | self.weights["JA"] = 3 28 | self.weights["DM"] = 3 29 | 30 | self.losses = {} 31 | 32 | def seq_masked_mse(self, prediction, target, mask): 33 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 34 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7) 35 | return loss 36 | 37 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None): 38 | if dm_mask is not None: 39 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) 40 | else: 41 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1] 42 | if contact_mask is not None: 43 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7) 44 | loss = (loss * mask).sum(dim=(-1, -2, -3)) / (mask.sum(dim=(-1, -2, -3)) + 1.e-7) # [b] 45 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7) 46 | 47 | return loss 48 | 49 | def forward(self, motion_pred, motion_gt, mask, timestep_mask): 50 | B, T = motion_pred.shape[:2] 51 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) 52 | target = self.normalizer.backward(motion_gt, global_rt=True) 53 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 54 | 55 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 56 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 57 | 58 | self.mask = mask 59 | self.timestep_mask = timestep_mask 60 | 61 | self.forward_distance_map(thresh=1) 62 | self.forward_joint_affinity(thresh=0.1) 63 | self.forward_relatvie_rot() 64 | self.accum_loss() 65 | 66 | 67 | def forward_relatvie_rot(self): 68 | 69 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 70 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :] 71 | across = across / across.norm(dim=-1, keepdim=True) 72 | across_gt = self.tgt_g_joints[..., r_hip, :] - self.tgt_g_joints[..., l_hip, :] 73 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True) 74 | 75 | y_axis = torch.zeros_like(across) 76 | y_axis[..., 1] = 1 77 | 78 | forward = torch.cross(y_axis, across, axis=-1) 79 | forward = forward / forward.norm(dim=-1, keepdim=True) 80 | 81 | forward_gt = torch.cross(y_axis, across_gt, axis=-1) 82 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True) 83 | 84 | pred_relative_rot = qbetween(forward[..., 0, :], forward[..., 1, :]) 85 | tgt_relative_rot = qbetween(forward_gt[..., 0, :], forward_gt[..., 1, :]) 86 | 87 | self.losses["RO"] = self.mix_masked_mse(pred_relative_rot[..., [0, 2]], 88 | tgt_relative_rot[..., [0, 2]], 89 | self.mask[..., 0, :], self.timestep_mask) * self.weights["RO"] 90 | 91 | 92 | def forward_distance_map(self, thresh): 93 | pred_g_joints = 
self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 94 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 95 | 96 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 97 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 98 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 99 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 100 | 101 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape( 102 | self.mask.shape[:-2] + (1, -1,)) 103 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape( 104 | self.mask.shape[:-2] + (1, -1,)) 105 | 106 | distance_matrix_mask = (pred_distance_matrix < thresh).float() 107 | 108 | self.losses["DM"] = self.mix_masked_mse(pred_distance_matrix, tgt_distance_matrix, 109 | self.mask[..., 0:1, :], 110 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["DM"] 111 | 112 | def forward_joint_affinity(self, thresh): 113 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 114 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 115 | 116 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 117 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 118 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 119 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 120 | 121 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape( 122 | self.mask.shape[:-2] + (1, -1,)) 123 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape( 124 | self.mask.shape[:-2] + (1, -1,)) 125 | 126 | distance_matrix_mask = (tgt_distance_matrix < thresh).float() 127 | 128 | self.losses["JA"] = self.mix_masked_mse(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix), 129 | self.mask[..., 0:1, :], 130 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["JA"] 131 | 132 | def accum_loss(self): 133 | loss = 0 134 | for term in self.losses.keys(): 135 | loss += self.losses[term] 136 | self.losses["total"] = loss 137 | return self.losses 138 | 139 | 140 | class SingleLoss(nn.Module): 141 | def __init__(self, recons_loss, nb_joints): 142 | super(SingleLoss, self).__init__() 143 | self.nb_joints = nb_joints 144 | if recons_loss == 'l1': 145 | self.Loss = torch.nn.L1Loss(reduction='none') 146 | elif recons_loss == 'l2': 147 | self.Loss = torch.nn.MSELoss(reduction='none') 148 | elif recons_loss == 'l1_smooth': 149 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 150 | 151 | self.normalizer = MotionNormalizerTorch() 152 | 153 | self.losses = {} 154 | 155 | def seq_masked_mse(self, prediction, target, mask): 156 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 157 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7) 158 | return loss 159 | 160 | def forward(self, motion_pred, motion_gt, mask, timestep_mask): 161 | B, T = motion_pred.shape[:2] 162 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) 163 | target = self.normalizer.backward(motion_gt, global_rt=True) 164 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 165 | 166 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 167 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 168 | 169 | 
self.mask = mask 170 | self.timestep_mask = timestep_mask 171 | 172 | self.accum_loss() 173 | 174 | def accum_loss(self): 175 | loss = 0 176 | for term in self.losses.keys(): 177 | loss += self.losses[term] 178 | self.losses["total"] = loss 179 | return self.losses 180 | 181 | 182 | class GeometricLoss(nn.Module): 183 | def __init__(self, recons_loss, nb_joints, name): 184 | super(GeometricLoss, self).__init__() 185 | self.name = name 186 | self.nb_joints = nb_joints 187 | if recons_loss == 'l1': 188 | self.Loss = torch.nn.L1Loss(reduction='none') 189 | elif recons_loss == 'l2': 190 | self.Loss = torch.nn.MSELoss(reduction='none') 191 | elif recons_loss == 'l1_smooth': 192 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 193 | 194 | self.normalizer = MotionNormalizerTorch() 195 | self.fids = [7, 10, 8, 11] 196 | 197 | self.weights = {} 198 | self.weights["VEL"] = 30 199 | self.weights["BL"] = 10 200 | self.weights["FC"] = 30 201 | self.weights["POSE"] = 1 202 | self.weights["TR"] = 100 203 | 204 | self.losses = {} 205 | 206 | def seq_masked_mse(self, prediction, target, mask): 207 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 208 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7) 209 | return loss 210 | 211 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None): 212 | if dm_mask is not None: 213 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) # [b,t,p,4,1] 214 | else: 215 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1] 216 | if contact_mask is not None: 217 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7) 218 | loss = (loss * mask).sum(dim=(-1, -2)) / (mask.sum(dim=(-1, -2)) + 1.e-7) # [b] 219 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7) 220 | 221 | return loss 222 | 223 | def forward(self, motion_pred, motion_gt, mask, timestep_mask): 224 | B, T = motion_pred.shape[:2] 225 | # self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) # * 0.01 226 | target = self.normalizer.backward(motion_gt, global_rt=True) 227 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 228 | 229 | self.first_motion_pred =motion_pred[:,0:1] 230 | self.first_motion_gt =motion_gt[:,0:1] 231 | 232 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3) 233 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3) 234 | self.mask = mask 235 | self.timestep_mask = timestep_mask 236 | 237 | self.forward_vel() 238 | self.forward_bone_length() 239 | self.forward_contact() 240 | self.accum_loss() 241 | # return self.losses["simple"] 242 | 243 | def get_local_positions(self, positions, r_rot): 244 | '''Local pose''' 245 | positions[..., 0] -= positions[..., 0:1, 0] 246 | positions[..., 2] -= positions[..., 0:1, 2] 247 | '''All pose face Z+''' 248 | positions = qrot(r_rot[..., None, :].repeat(1, 1, positions.shape[-2], 1), positions) 249 | return positions 250 | 251 | def forward_local_pose(self): 252 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 253 | 254 | pred_g_joints = self.pred_g_joints.clone() 255 | tgt_g_joints = self.tgt_g_joints.clone() 256 | 257 | across = pred_g_joints[..., r_hip, :] - pred_g_joints[..., l_hip, :] 258 | across = across / across.norm(dim=-1, keepdim=True) 259 | across_gt = tgt_g_joints[..., r_hip, :] - tgt_g_joints[..., l_hip, :] 
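        # Editor's note: the hip-to-hip vectors computed above are normalized and
        # crossed with the world up-axis below to obtain each pose's facing
        # direction, which is then rotated to face Z+ before the local joint
        # positions are compared.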
260 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True) 261 | 262 | y_axis = torch.zeros_like(across) 263 | y_axis[..., 1] = 1 264 | 265 | forward = torch.cross(y_axis, across, axis=-1) 266 | forward = forward / forward.norm(dim=-1, keepdim=True) 267 | forward_gt = torch.cross(y_axis, across_gt, axis=-1) 268 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True) 269 | 270 | z_axis = torch.zeros_like(forward) 271 | z_axis[..., 2] = 1 272 | noise = torch.randn_like(z_axis) *0.0001 273 | z_axis = z_axis+noise 274 | z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True) 275 | 276 | 277 | pred_rot = qbetween(forward, z_axis) 278 | tgt_rot = qbetween(forward_gt, z_axis) 279 | 280 | B, T, J, D = self.pred_g_joints.shape 281 | pred_joints = self.get_local_positions(pred_g_joints, pred_rot).reshape(B, T, -1) 282 | tgt_joints = self.get_local_positions(tgt_g_joints, tgt_rot).reshape(B, T, -1) 283 | 284 | self.losses["POSE_"+self.name] = self.mix_masked_mse(pred_joints, tgt_joints, self.mask, self.timestep_mask) * self.weights["POSE"] 285 | 286 | def forward_vel(self): 287 | B, T = self.pred_g_joints.shape[:2] 288 | 289 | pred_vel = self.pred_g_joints[:, 1:] - self.pred_g_joints[:, :-1] 290 | tgt_vel = self.tgt_g_joints[:, 1:] - self.tgt_g_joints[:, :-1] 291 | 292 | pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,)) 293 | tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,)) 294 | 295 | self.losses["VEL_"+self.name] = self.mix_masked_mse(pred_vel, tgt_vel, self.mask[:, :-1], self.timestep_mask) * self.weights["VEL"] 296 | 297 | 298 | def forward_contact(self): 299 | 300 | feet_vel = self.pred_g_joints[:, 1:, self.fids, :] - self.pred_g_joints[:, :-1, self.fids,:] 301 | feet_h = self.pred_g_joints[:, :-1, self.fids, 1] 302 | # contact = target[:,:-1,:,-8:-4] # [b,t,p,4] 303 | 304 | contact = self.foot_detect(feet_vel, feet_h, 0.001) 305 | 306 | self.losses["FC_"+self.name] = self.mix_masked_mse(feet_vel, torch.zeros_like(feet_vel), self.mask[:, :-1], 307 | self.timestep_mask, 308 | contact) * self.weights["FC"] 309 | 310 | 311 | 312 | def forward_bone_length(self): 313 | pred_g_joints = self.pred_g_joints 314 | tgt_g_joints = self.tgt_g_joints 315 | pred_bones = [] 316 | tgt_bones = [] 317 | for chain in kinematic_chain: 318 | for i, joint in enumerate(chain[:-1]): 319 | pred_bone = (pred_g_joints[..., chain[i], :] - pred_g_joints[..., chain[i + 1], :]).norm(dim=-1, 320 | keepdim=True) # [B,T,P,1] 321 | tgt_bone = (tgt_g_joints[..., chain[i], :] - tgt_g_joints[..., chain[i + 1], :]).norm(dim=-1, 322 | keepdim=True) 323 | pred_bones.append(pred_bone) 324 | tgt_bones.append(tgt_bone) 325 | 326 | pred_bones = torch.cat(pred_bones, dim=-1) 327 | tgt_bones = torch.cat(tgt_bones, dim=-1) 328 | 329 | self.losses["BL_"+self.name] = self.mix_masked_mse(pred_bones, tgt_bones, self.mask, self.timestep_mask) * self.weights[ 330 | "BL"] 331 | 332 | 333 | def forward_traj(self): 334 | B, T = self.pred_g_joints.shape[:2] 335 | 336 | pred_traj = self.pred_g_joints[..., 0, [0, 2]] 337 | tgt_g_traj = self.tgt_g_joints[..., 0, [0, 2]] 338 | 339 | self.losses["TR_"+self.name] = self.mix_masked_mse(pred_traj, tgt_g_traj, self.mask, self.timestep_mask) * self.weights["TR"] 340 | 341 | 342 | def accum_loss(self): 343 | loss = 0 344 | for term in self.losses.keys(): 345 | loss += self.losses[term] 346 | self.losses[self.name] = loss 347 | 348 | def foot_detect(self, feet_vel, feet_h, thres): 349 | velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor( 
350 | [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device) 351 | 352 | feet_x = (feet_vel[..., 0]) ** 2 353 | feet_y = (feet_vel[..., 1]) ** 2 354 | feet_z = (feet_vel[..., 2]) ** 2 355 | 356 | contact = (((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)).float() 357 | return contact 358 | 359 | class InitOrientationAndPositionLoss(nn.Module): 360 | def __init__(self, recons_loss, nb_joints): 361 | super(InitOrientationAndPositionLoss, self).__init__() 362 | self.nb_joints = nb_joints 363 | if recons_loss == 'l1': 364 | self.Loss = torch.nn.L1Loss(reduction='none') 365 | elif recons_loss == 'l2': 366 | self.Loss = torch.nn.MSELoss(reduction='none') 367 | elif recons_loss == 'l1_smooth': 368 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 369 | 370 | self.normalizer = MotionNormalizerTorch() 371 | 372 | self.weights = {} 373 | self.weights["init_ori"] = 0.01 374 | self.weights["init_pos"] = 3 375 | 376 | self.losses = {} 377 | 378 | def mse(self, prediction, target, weights=1.0): 379 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 380 | loss = loss.sum() * weights 381 | return loss 382 | 383 | def forward(self, motion_pred, spatial_gt): 384 | 385 | B, T = motion_pred.shape[:2] 386 | 387 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 388 | 389 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) # [B, T, 1, J, 3] 390 | 391 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 392 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :] 393 | across = across / across.norm(dim=-1, keepdim=True) 394 | 395 | y_axis = torch.zeros_like(across) # [B, T, 1, 3] 396 | y_axis[..., 1] = 1 397 | 398 | pred_init_forward = torch.cross(y_axis, across, axis=-1) # [B, T, 1, 3]???? 399 | 400 | pred_init_forward = (pred_init_forward / pred_init_forward.norm(dim=-1, keepdim=True))[:,0,0,[0,2]] # [B, T, 1, 3] -> [B, 2] 401 | 402 | pred_init_pos = self.pred_g_joints[..., 0, :][:,0,0,[0,2]] # # [B, T, 1, J, 3] -> [B, T, 1, 3] -> [B, 2] 403 | 404 | self.losses['spatial'] = self.mse(pred_init_forward, spatial_gt[..., 0:2], self.weights['init_ori']) + self.mse(pred_init_pos, spatial_gt[..., 2:4], self.weights['init_pos']) 405 | 406 | -------------------------------------------------------------------------------- /utils/rotation_conversions.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | # Check PYTORCH3D_LICENCE before use 4 | 5 | import functools 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | """ 13 | The transformation matrices returned from the functions in this file assume 14 | the points on which the transformation will be applied are column vectors. 15 | i.e. the R matrix is structured as 16 | 17 | R = [ 18 | [Rxx, Rxy, Rxz], 19 | [Ryx, Ryy, Ryz], 20 | [Rzx, Rzy, Rzz], 21 | ] # (3, 3) 22 | 23 | This matrix can be applied to column vectors by post multiplication 24 | by the points e.g. 25 | 26 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point 27 | transformed_points = R * points 28 | 29 | To apply the same matrix to points which are row vectors, the R matrix 30 | can be transposed and pre multiplied by the points: 31 | 32 | e.g. 
33 |     points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34 |     transformed_points = points * R.transpose(1, 0)
35 | """
36 | 
37 | 
38 | def quaternion_to_matrix(quaternions):
39 |     """
40 |     Convert rotations given as quaternions to rotation matrices.
41 | 
42 |     Args:
43 |         quaternions: quaternions with real part first,
44 |             as tensor of shape (..., 4).
45 | 
46 |     Returns:
47 |         Rotation matrices as tensor of shape (..., 3, 3).
48 |     """
49 |     r, i, j, k = torch.unbind(quaternions, -1)
50 |     two_s = 2.0 / (quaternions * quaternions).sum(-1)
51 | 
52 |     o = torch.stack(
53 |         (
54 |             1 - two_s * (j * j + k * k),
55 |             two_s * (i * j - k * r),
56 |             two_s * (i * k + j * r),
57 |             two_s * (i * j + k * r),
58 |             1 - two_s * (i * i + k * k),
59 |             two_s * (j * k - i * r),
60 |             two_s * (i * k - j * r),
61 |             two_s * (j * k + i * r),
62 |             1 - two_s * (i * i + j * j),
63 |         ),
64 |         -1,
65 |     )
66 |     return o.reshape(quaternions.shape[:-1] + (3, 3))
67 | 
68 | 
69 | def _copysign(a, b):
70 |     """
71 |     Return a tensor where each element has the absolute value taken from the
72 |     corresponding element of a, with sign taken from the corresponding
73 |     element of b. This is like the standard copysign floating-point operation,
74 |     but is not careful about negative 0 and NaN.
75 | 
76 |     Args:
77 |         a: source tensor.
78 |         b: tensor whose signs will be used, of the same shape as a.
79 | 
80 |     Returns:
81 |         Tensor of the same shape as a with the signs of b.
82 |     """
83 |     signs_differ = (a < 0) != (b < 0)
84 |     return torch.where(signs_differ, -a, a)
85 | 
86 | 
87 | def _sqrt_positive_part(x):
88 |     """
89 |     Returns torch.sqrt(torch.max(0, x))
90 |     but with a zero subgradient where x is 0.
91 |     """
92 |     ret = torch.zeros_like(x)
93 |     positive_mask = x > 0
94 |     ret[positive_mask] = torch.sqrt(x[positive_mask])
95 |     return ret
96 | 
97 | 
98 | def matrix_to_quaternion(matrix):
99 |     """
100 |     Convert rotations given as rotation matrices to quaternions.
101 | 
102 |     Args:
103 |         matrix: Rotation matrices as tensor of shape (..., 3, 3).
104 | 
105 |     Returns:
106 |         quaternions with real part first, as tensor of shape (..., 4).
107 |     """
108 |     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
109 |         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
110 |     m00 = matrix[..., 0, 0]
111 |     m11 = matrix[..., 1, 1]
112 |     m22 = matrix[..., 2, 2]
113 |     o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
114 |     x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
115 |     y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
116 |     z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
117 |     o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
118 |     o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
119 |     o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
120 |     return torch.stack((o0, o1, o2, o3), -1)
121 | 
122 | 
123 | def _axis_angle_rotation(axis: str, angle):
124 |     """
125 |     Return the rotation matrices for rotations about one of the coordinate
126 |     axes used by Euler angles, for each value of the angle given.
127 | 
128 |     Args:
129 |         axis: Axis label "X" or "Y" or "Z".
130 |         angle: any shape tensor of Euler angles in radians
131 | 
132 |     Returns:
133 |         Rotation matrices as tensor of shape (..., 3, 3).
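
    A minimal illustrative example (it assumes `math` has been imported,
    which this module does not do itself; the helper is private and is
    normally reached through euler_angles_to_matrix below):

        R = _axis_angle_rotation("Z", torch.tensor(math.pi / 2))
        # R is approximately [[0, -1, 0], [1, 0, 0], [0, 0, 1]], so applying
        # it to the column vector [1, 0, 0] gives [0, 1, 0]: a 90-degree
        # counter-clockwise rotation about the Z axis.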
134 | """ 135 | 136 | cos = torch.cos(angle) 137 | sin = torch.sin(angle) 138 | one = torch.ones_like(angle) 139 | zero = torch.zeros_like(angle) 140 | 141 | if axis == "X": 142 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 143 | if axis == "Y": 144 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 145 | if axis == "Z": 146 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 147 | 148 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 149 | 150 | 151 | def euler_angles_to_matrix(euler_angles, convention: str): 152 | """ 153 | Convert rotations given as Euler angles in radians to rotation matrices. 154 | 155 | Args: 156 | euler_angles: Euler angles in radians as tensor of shape (..., 3). 157 | convention: Convention string of three uppercase letters from 158 | {"X", "Y", and "Z"}. 159 | 160 | Returns: 161 | Rotation matrices as tensor of shape (..., 3, 3). 162 | """ 163 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: 164 | raise ValueError("Invalid input euler angles.") 165 | if len(convention) != 3: 166 | raise ValueError("Convention must have 3 letters.") 167 | if convention[1] in (convention[0], convention[2]): 168 | raise ValueError(f"Invalid convention {convention}.") 169 | for letter in convention: 170 | if letter not in ("X", "Y", "Z"): 171 | raise ValueError(f"Invalid letter {letter} in convention string.") 172 | matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) 173 | return functools.reduce(torch.matmul, matrices) 174 | 175 | 176 | def _angle_from_tan( 177 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool 178 | ): 179 | """ 180 | Extract the first or third Euler angle from the two members of 181 | the matrix which are positive constant times its sine and cosine. 182 | 183 | Args: 184 | axis: Axis label "X" or "Y or "Z" for the angle we are finding. 185 | other_axis: Axis label "X" or "Y or "Z" for the middle axis in the 186 | convention. 187 | data: Rotation matrices as tensor of shape (..., 3, 3). 188 | horizontal: Whether we are looking for the angle for the third axis, 189 | which means the relevant entries are in the same row of the 190 | rotation matrix. If not, they are in the same column. 191 | tait_bryan: Whether the first and third axes in the convention differ. 192 | 193 | Returns: 194 | Euler Angles in radians for each matrix in dataset as a tensor 195 | of shape (...). 196 | """ 197 | 198 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] 199 | if horizontal: 200 | i2, i1 = i1, i2 201 | even = (axis + other_axis) in ["XY", "YZ", "ZX"] 202 | if horizontal == even: 203 | return torch.atan2(data[..., i1], data[..., i2]) 204 | if tait_bryan: 205 | return torch.atan2(-data[..., i2], data[..., i1]) 206 | return torch.atan2(data[..., i2], -data[..., i1]) 207 | 208 | 209 | def _index_from_letter(letter: str): 210 | if letter == "X": 211 | return 0 212 | if letter == "Y": 213 | return 1 214 | if letter == "Z": 215 | return 2 216 | 217 | 218 | def matrix_to_euler_angles(matrix, convention: str): 219 | """ 220 | Convert rotations given as rotation matrices to Euler angles in radians. 221 | 222 | Args: 223 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 224 | convention: Convention string of three uppercase letters. 225 | 226 | Returns: 227 | Euler angles in radians as tensor of shape (..., 3). 
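
    A minimal illustrative round trip with euler_angles_to_matrix above,
    using angles well inside the principal range so the decomposition is
    unambiguous:

        angles = torch.tensor([0.1, -0.2, 0.3])
        R = euler_angles_to_matrix(angles, "XYZ")
        matrix_to_euler_angles(R, "XYZ")  # approximately [0.1, -0.2, 0.3]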
228 | """ 229 | if len(convention) != 3: 230 | raise ValueError("Convention must have 3 letters.") 231 | if convention[1] in (convention[0], convention[2]): 232 | raise ValueError(f"Invalid convention {convention}.") 233 | for letter in convention: 234 | if letter not in ("X", "Y", "Z"): 235 | raise ValueError(f"Invalid letter {letter} in convention string.") 236 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 237 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 238 | i0 = _index_from_letter(convention[0]) 239 | i2 = _index_from_letter(convention[2]) 240 | tait_bryan = i0 != i2 241 | if tait_bryan: 242 | central_angle = torch.asin( 243 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) 244 | ) 245 | else: 246 | central_angle = torch.acos(matrix[..., i0, i0]) 247 | 248 | o = ( 249 | _angle_from_tan( 250 | convention[0], convention[1], matrix[..., i2], False, tait_bryan 251 | ), 252 | central_angle, 253 | _angle_from_tan( 254 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan 255 | ), 256 | ) 257 | return torch.stack(o, -1) 258 | 259 | 260 | def random_quaternions( 261 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 262 | ): 263 | """ 264 | Generate random quaternions representing rotations, 265 | i.e. versors with nonnegative real part. 266 | 267 | Args: 268 | n: Number of quaternions in a batch to return. 269 | dtype: Type to return. 270 | device: Desired device of returned tensor. Default: 271 | uses the current device for the default tensor type. 272 | requires_grad: Whether the resulting tensor should have the gradient 273 | flag set. 274 | 275 | Returns: 276 | Quaternions as tensor of shape (N, 4). 277 | """ 278 | o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) 279 | s = (o * o).sum(1) 280 | o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] 281 | return o 282 | 283 | 284 | def random_rotations( 285 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 286 | ): 287 | """ 288 | Generate random rotations as 3x3 rotation matrices. 289 | 290 | Args: 291 | n: Number of rotation matrices in a batch to return. 292 | dtype: Type to return. 293 | device: Device of returned tensor. Default: if None, 294 | uses the current device for the default tensor type. 295 | requires_grad: Whether the resulting tensor should have the gradient 296 | flag set. 297 | 298 | Returns: 299 | Rotation matrices as tensor of shape (n, 3, 3). 300 | """ 301 | quaternions = random_quaternions( 302 | n, dtype=dtype, device=device, requires_grad=requires_grad 303 | ) 304 | return quaternion_to_matrix(quaternions) 305 | 306 | 307 | def random_rotation( 308 | dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 309 | ): 310 | """ 311 | Generate a single random 3x3 rotation matrix. 312 | 313 | Args: 314 | dtype: Type to return 315 | device: Device of returned tensor. Default: if None, 316 | uses the current device for the default tensor type 317 | requires_grad: Whether the resulting tensor should have the gradient 318 | flag set 319 | 320 | Returns: 321 | Rotation matrix as tensor of shape (3, 3). 322 | """ 323 | return random_rotations(1, dtype, device, requires_grad)[0] 324 | 325 | 326 | def standardize_quaternion(quaternions): 327 | """ 328 | Convert a unit quaternion to a standard form: one in which the real 329 | part is non negative. 330 | 331 | Args: 332 | quaternions: Quaternions with real part first, 333 | as tensor of shape (..., 4). 
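
    For example (illustrative): q and -q encode the same rotation, so this
    function maps [-0.5, 0.5, 0.5, 0.5] to [0.5, -0.5, -0.5, -0.5].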
334 | 335 | Returns: 336 | Standardized quaternions as tensor of shape (..., 4). 337 | """ 338 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 339 | 340 | 341 | def quaternion_raw_multiply(a, b): 342 | """ 343 | Multiply two quaternions. 344 | Usual torch rules for broadcasting apply. 345 | 346 | Args: 347 | a: Quaternions as tensor of shape (..., 4), real part first. 348 | b: Quaternions as tensor of shape (..., 4), real part first. 349 | 350 | Returns: 351 | The product of a and b, a tensor of quaternions shape (..., 4). 352 | """ 353 | aw, ax, ay, az = torch.unbind(a, -1) 354 | bw, bx, by, bz = torch.unbind(b, -1) 355 | ow = aw * bw - ax * bx - ay * by - az * bz 356 | ox = aw * bx + ax * bw + ay * bz - az * by 357 | oy = aw * by - ax * bz + ay * bw + az * bx 358 | oz = aw * bz + ax * by - ay * bx + az * bw 359 | return torch.stack((ow, ox, oy, oz), -1) 360 | 361 | 362 | def quaternion_multiply(a, b): 363 | """ 364 | Multiply two quaternions representing rotations, returning the quaternion 365 | representing their composition, i.e. the versor with nonnegative real part. 366 | Usual torch rules for broadcasting apply. 367 | 368 | Args: 369 | a: Quaternions as tensor of shape (..., 4), real part first. 370 | b: Quaternions as tensor of shape (..., 4), real part first. 371 | 372 | Returns: 373 | The product of a and b, a tensor of quaternions of shape (..., 4). 374 | """ 375 | ab = quaternion_raw_multiply(a, b) 376 | return standardize_quaternion(ab) 377 | 378 | 379 | def quaternion_invert(quaternion): 380 | """ 381 | Given a quaternion representing rotation, get the quaternion representing 382 | its inverse. 383 | 384 | Args: 385 | quaternion: Quaternions as tensor of shape (..., 4), with real part 386 | first, which must be versors (unit quaternions). 387 | 388 | Returns: 389 | The inverse, a tensor of quaternions of shape (..., 4). 390 | """ 391 | 392 | return quaternion * quaternion.new_tensor([1, -1, -1, -1]) 393 | 394 | 395 | def quaternion_apply(quaternion, point): 396 | """ 397 | Apply the rotation given by a quaternion to a 3D point. 398 | Usual torch rules for broadcasting apply. 399 | 400 | Args: 401 | quaternion: Tensor of quaternions, real part first, of shape (..., 4). 402 | point: Tensor of 3D points of shape (..., 3). 403 | 404 | Returns: 405 | Tensor of rotated points of shape (..., 3). 406 | """ 407 | if point.size(-1) != 3: 408 | raise ValueError(f"Points are not in 3D, f{point.shape}.") 409 | real_parts = point.new_zeros(point.shape[:-1] + (1,)) 410 | point_as_quaternion = torch.cat((real_parts, point), -1) 411 | out = quaternion_raw_multiply( 412 | quaternion_raw_multiply(quaternion, point_as_quaternion), 413 | quaternion_invert(quaternion), 414 | ) 415 | return out[..., 1:] 416 | 417 | 418 | def axis_angle_to_matrix(axis_angle): 419 | """ 420 | Convert rotations given as axis/angle to rotation matrices. 421 | 422 | Args: 423 | axis_angle: Rotations given as a vector in axis angle form, 424 | as a tensor of shape (..., 3), where the magnitude is 425 | the angle turned anticlockwise in radians around the 426 | vector's direction. 427 | 428 | Returns: 429 | Rotation matrices as tensor of shape (..., 3, 3). 430 | """ 431 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 432 | 433 | 434 | def matrix_to_axis_angle(matrix): 435 | """ 436 | Convert rotations given as rotation matrices to axis/angle. 437 | 438 | Args: 439 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 
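
    For example (illustrative), the identity matrix maps to the zero vector,
    and a 90-degree rotation about the Z axis maps to approximately
    [0, 0, pi / 2]:

        matrix_to_axis_angle(torch.eye(3))  # tensor([0., 0., 0.])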
440 | 441 | Returns: 442 | Rotations given as a vector in axis angle form, as a tensor 443 | of shape (..., 3), where the magnitude is the angle 444 | turned anticlockwise in radians around the vector's 445 | direction. 446 | """ 447 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) 448 | 449 | 450 | def axis_angle_to_quaternion(axis_angle): 451 | """ 452 | Convert rotations given as axis/angle to quaternions. 453 | 454 | Args: 455 | axis_angle: Rotations given as a vector in axis angle form, 456 | as a tensor of shape (..., 3), where the magnitude is 457 | the angle turned anticlockwise in radians around the 458 | vector's direction. 459 | 460 | Returns: 461 | quaternions with real part first, as tensor of shape (..., 4). 462 | """ 463 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 464 | half_angles = 0.5 * angles 465 | eps = 1e-6 466 | small_angles = angles.abs() < eps 467 | sin_half_angles_over_angles = torch.empty_like(angles) 468 | sin_half_angles_over_angles[~small_angles] = ( 469 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 470 | ) 471 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 472 | # so sin(x/2)/x is about 1/2 - (x*x)/48 473 | sin_half_angles_over_angles[small_angles] = ( 474 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 475 | ) 476 | quaternions = torch.cat( 477 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 478 | ) 479 | return quaternions 480 | 481 | 482 | def quaternion_to_axis_angle(quaternions): 483 | """ 484 | Convert rotations given as quaternions to axis/angle. 485 | 486 | Args: 487 | quaternions: quaternions with real part first, 488 | as tensor of shape (..., 4). 489 | 490 | Returns: 491 | Rotations given as a vector in axis angle form, as a tensor 492 | of shape (..., 3), where the magnitude is the angle 493 | turned anticlockwise in radians around the vector's 494 | direction. 495 | """ 496 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 497 | half_angles = torch.atan2(norms, quaternions[..., :1]) 498 | angles = 2 * half_angles 499 | eps = 1e-6 500 | small_angles = angles.abs() < eps 501 | sin_half_angles_over_angles = torch.empty_like(angles) 502 | sin_half_angles_over_angles[~small_angles] = ( 503 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 504 | ) 505 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 506 | # so sin(x/2)/x is about 1/2 - (x*x)/48 507 | sin_half_angles_over_angles[small_angles] = ( 508 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 509 | ) 510 | return quaternions[..., 1:] / sin_half_angles_over_angles 511 | 512 | 513 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: 514 | """ 515 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 516 | using Gram--Schmidt orthogonalisation per Section B of [1]. 517 | Args: 518 | d6: 6D rotation representation, of size (*, 6) 519 | 520 | Returns: 521 | batch of rotation matrices of size (*, 3, 3) 522 | 523 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 524 | On the Continuity of Rotation Representations in Neural Networks. 525 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
526 | Retrieved from http://arxiv.org/abs/1812.07035 527 | """ 528 | 529 | a1, a2 = d6[..., [0,2,4]], d6[..., [1,3,5]] 530 | b1 = F.normalize(a1, dim=-1) 531 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 532 | b2 = F.normalize(b2, dim=-1) 533 | b3 = torch.cross(b1, b2, dim=-1) 534 | return torch.cat((b1[...,None], b2[...,None], b3[...,None]), dim=-1) 535 | 536 | 537 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 538 | """ 539 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 540 | by dropping the last row. Note that 6D representation is not unique. 541 | Args: 542 | matrix: batch of rotation matrices of size (*, 3, 3) 543 | 544 | Returns: 545 | 6D rotation representation, of size (*, 6) 546 | 547 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 548 | On the Continuity of Rotation Representations in Neural Networks. 549 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 550 | Retrieved from http://arxiv.org/abs/1812.07035 551 | """ 552 | return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) 553 | --------------------------------------------------------------------------------
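A minimal usage sketch for utils/rotation_conversions.py (illustrative only: it assumes the repository root is on PYTHONPATH so the file is importable as utils.rotation_conversions, and the asserts are just numerical sanity checks rather than part of the code above):

import torch

from utils.rotation_conversions import (
    axis_angle_to_quaternion,
    matrix_to_quaternion,
    quaternion_apply,
    quaternion_to_axis_angle,
    quaternion_to_matrix,
    random_quaternions,
    standardize_quaternion,
)

# Sample unit quaternions, real part first and non-negative by construction.
q = random_quaternions(4)

# Quaternion -> matrix -> quaternion round trip. Standardizing both sides
# removes the q / -q sign ambiguity before comparing.
R = quaternion_to_matrix(q)
q_back = standardize_quaternion(matrix_to_quaternion(R))
assert torch.allclose(standardize_quaternion(q), q_back, atol=1e-5)

# quaternion_apply agrees with multiplying by the rotation matrix, which
# (per the module docstring) acts on column vectors.
points = torch.randn(4, 3)
assert torch.allclose(
    quaternion_apply(q, points),
    torch.einsum("nij,nj->ni", R, points),
    atol=1e-5,
)

# Axis-angle <-> quaternion round trip.
aa = quaternion_to_axis_angle(q)
assert torch.allclose(
    standardize_quaternion(axis_angle_to_quaternion(aa)),
    standardize_quaternion(q),
    atol=1e-5,
)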